[
  {
    "path": ".gitignore",
    "content": "# Repo-specific GitIgnore ----------------------------------------------------------------------------------------------\n*.jpg\n*.jpeg\n*.png\n*.bmp\n*.tif\n*.tiff\n*.heic\n*.JPG\n*.JPEG\n*.PNG\n*.BMP\n*.TIF\n*.TIFF\n*.HEIC\n*.mp4\n*.mov\n*.MOV\n*.avi\n*.data\n*.json\n*.pth\n*.cfg\n!cfg/yolov3*.cfg\n\nstorage.googleapis.com\nruns/*\ndata/*\n!data/images/zidane.jpg\n!data/images/bus.jpg\n!data/coco.names\n!data/coco_paper.names\n!data/coco.data\n!data/coco_*.data\n!data/coco_*.txt\n!data/trainvalno5k.shapes\n!data/*.sh\n\ntest.py\ntest_imgs/\n\npycocotools/*\nresults*.txt\ngcp_test*.sh\n\ncheckpoints/\n# output/\n# output*/\n*events*\nassests/*/\n\n# Datasets -------------------------------------------------------------------------------------------------------------\ncoco/\ncoco128/\nVOC/\n\n# MATLAB GitIgnore -----------------------------------------------------------------------------------------------------\n*.m~\n*.mat\n!targets*.mat\n\n# Neural Network weights -----------------------------------------------------------------------------------------------\n*.weights\n*.pt\n*.onnx\n*.mlmodel\n*.torchscript\ndarknet53.conv.74\nyolov3-tiny.conv.15\n\n# GitHub Python GitIgnore ----------------------------------------------------------------------------------------------\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nenv/\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\nwandb/\n.installed.cfg\n*.egg\n\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n\n# 
Translations\n*.mo\n*.pot\n\n# Django stuff:\n# *.log\nlocal_settings.py\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# dotenv\n.env\n\n# virtualenv\n.venv*\nvenv*/\nENV*/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n\n\n# https://github.com/github/gitignore/blob/master/Global/macOS.gitignore -----------------------------------------------\n\n# General\n.DS_Store\n.AppleDouble\n.LSOverride\n\n# Icon must end with two \\r\nIcon\nIcon?\n\n# Thumbnails\n._*\n\n# Files that might appear in the root of a volume\n.DocumentRevisions-V100\n.fseventsd\n.Spotlight-V100\n.TemporaryItems\n.Trashes\n.VolumeIcon.icns\n.com.apple.timemachine.donotpresent\n\n# Directories potentially created on remote AFP share\n.AppleDB\n.AppleDesktop\nNetwork Trash Folder\nTemporary Items\n.apdisk\n\n\n# https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore\n# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm\n# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839\n\n# User-specific stuff:\n.idea/*\n.idea/**/workspace.xml\n.idea/**/tasks.xml\n.idea/dictionaries\n.html  # Bokeh Plots\n.pg  # TensorFlow Frozen Graphs\n.avi # videos\n\n# Sensitive or high-churn files:\n.idea/**/dataSources/\n.idea/**/dataSources.ids\n.idea/**/dataSources.local.xml\n.idea/**/sqlDataSources.xml\n.idea/**/dynamic.xml\n.idea/**/uiDesigner.xml\n\n# Gradle:\n.idea/**/gradle.xml\n.idea/**/libraries\n\n# CMake\ncmake-build-debug/\ncmake-build-release/\n\n# Mongo Explorer plugin:\n.idea/**/mongoSettings.xml\n\n## File-based project format:\n*.iws\n\n## 
Plugin-specific files:\n\n# IntelliJ\nout/\n\n# mpeltonen/sbt-idea plugin\n.idea_modules/\n\n# JIRA plugin\natlassian-ide-plugin.xml\n\n# Cursive Clojure plugin\n.idea/replstate.xml\n\n# Crashlytics plugin (for Android Studio and IntelliJ)\ncom_crashlytics_export_strings.xml\ncrashlytics.properties\ncrashlytics-build.properties\nfabric.properties\n\noutput/\ndata/"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [2023] [Jiaming Zhang]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\"> \n\n## Delivering Arbitrary-Modal Semantic Segmentation (CVPR 2023)\n\n</div>\n\n<p align=\"center\">\n<a href=\"https://arxiv.org/pdf/2303.01480.pdf\">\n    <img src=\"https://img.shields.io/badge/arXiv-2303.01480-red\" /></a>\n<a href=\"https://jamycheung.github.io/DELIVER.html\">\n    <img src=\"https://img.shields.io/badge/Project-page-green\" /></a>\n<a href=\"https://www.youtube.com/watch?v=X-VeSLsEToA\">\n    <img src=\"https://img.shields.io/badge/Video-YouTube-%23FF0000.svg\" /></a>\n<a href=\"https://pytorch.org/\">\n    <img src=\"https://img.shields.io/badge/Framework-PyTorch-orange.svg\" /></a>\n<a href=\"https://github.com/jamycheung/DELIVER/blob/main/LICENSE\">\n    <img src=\"https://img.shields.io/badge/License-Apache_2.0-blue.svg\" /></a>\n</p>\n\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/delivering-arbitrary-modal-semantic/semantic-segmentation-on-deliver)](https://paperswithcode.com/sota/semantic-segmentation-on-deliver?p=delivering-arbitrary-modal-semantic)  \n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/delivering-arbitrary-modal-semantic/semantic-segmentation-on-kitti-360)](https://paperswithcode.com/sota/semantic-segmentation-on-kitti-360?p=delivering-arbitrary-modal-semantic)  \n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/delivering-arbitrary-modal-semantic/semantic-segmentation-on-nyu-depth-v2)](https://paperswithcode.com/sota/semantic-segmentation-on-nyu-depth-v2?p=delivering-arbitrary-modal-semantic)  \n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/delivering-arbitrary-modal-semantic/thermal-image-segmentation-on-mfn-dataset)](https://paperswithcode.com/sota/thermal-image-segmentation-on-mfn-dataset?p=delivering-arbitrary-modal-semantic)  
\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/delivering-arbitrary-modal-semantic/semantic-segmentation-on-mcubes)](https://paperswithcode.com/sota/semantic-segmentation-on-mcubes?p=delivering-arbitrary-modal-semantic)  \n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/delivering-arbitrary-modal-semantic/semantic-segmentation-on-mcubes-p)](https://paperswithcode.com/sota/semantic-segmentation-on-mcubes-p?p=delivering-arbitrary-modal-semantic)  \n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/delivering-arbitrary-modal-semantic/semantic-segmentation-on-urbanlf)](https://paperswithcode.com/sota/semantic-segmentation-on-urbanlf?p=delivering-arbitrary-modal-semantic)    \n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/delivering-arbitrary-modal-semantic/thermal-image-segmentation-on-noisy-rs-rgb-t)](https://paperswithcode.com/sota/thermal-image-segmentation-on-noisy-rs-rgb-t?p=delivering-arbitrary-modal-semantic)  \n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/delivering-arbitrary-modal-semantic/semantic-segmentation-on-ddd17)](https://paperswithcode.com/sota/semantic-segmentation-on-ddd17?p=delivering-arbitrary-modal-semantic)  \n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/delivering-arbitrary-modal-semantic/semantic-segmentation-on-dsec)](https://paperswithcode.com/sota/semantic-segmentation-on-dsec?p=delivering-arbitrary-modal-semantic)  \n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/delivering-arbitrary-modal-semantic/semantic-segmentation-on-bjroad)](https://paperswithcode.com/sota/semantic-segmentation-on-bjroad?p=delivering-arbitrary-modal-semantic)  
\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/delivering-arbitrary-modal-semantic/semantic-segmentation-on-tlcgis)](https://paperswithcode.com/sota/semantic-segmentation-on-tlcgis?p=delivering-arbitrary-modal-semantic)  \n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/delivering-arbitrary-modal-semantic/semantic-segmentation-on-porto)](https://paperswithcode.com/sota/semantic-segmentation-on-porto?p=delivering-arbitrary-modal-semantic)\n\n## Introduction\n\nTo conduct arbitrary-modal semantic segmentation, we create **DeLiVER** benchmark, covering **De**pth, **Li**DAR, multiple **V**iews, **E**vents, and **R**GB. It has four *severe weather conditions* as well as five *sensor failure cases* to exploit modal complementarity and resolve partial outages. Besides, we present the arbitrary cross-modal segmentation model **CMNeXt**, allowing to scale from 1 to 81 modalities on the DeLiVER, KITTI-360, MFNet, NYU Depth V2, UrbanLF, and MCubeS datasets.\n\nFor more details, please check our [arXiv](https://arxiv.org/pdf/2303.01480.pdf) paper.\n\n## Updates\n- [x] 03/2023, init repository.\n- [x] 04/2023, release front-view DeLiVER. Download from [**GoogleDrive**](https://drive.google.com/file/d/1P-glCmr-iFSYrzCfNawgVI9qKWfP94pm/view?usp=share_link).\n- [x] 04/2023, release CMNeXt model weights. Download from [**GoogleDrive**](https://drive.google.com/drive/folders/1MZaaZ5_rEVSjns3TBM0UDt6IW4X-HPIN?usp=share_link).\n\n## DeLiVER dataset\n\n![DELIVER](figs/DeLiVER.gif)\n\n![DELIVER](figs/DELIVER.png)\n\nDeLiVER multimodal dataset including (a) four adverse conditions out of five conditions(**cloudy, foggy, night-time, rainy and sunny**). Apart from normal cases, each condition has five corner cases (**MB: Motion Blur; OE: Over-Exposure; UE: Under-Exposure; LJ: LiDAR-Jitter; and EL: Event Low-resolution**). Each sample has six views. Each view has four modalities and two labels (semantic and instance). 
(b) is the data statistics. (c) is the data distribution of 25 semantic classes.\n\n\n### DELIVER splitting\n![DELIVER](figs/DELIVER_more.png)\n\n\n### Data folder structure\n**Download DELIVER dataset from [**GoogleDrive**](https://drive.google.com/file/d/1P-glCmr-iFSYrzCfNawgVI9qKWfP94pm/view?usp=share_link) (~12.2 GB)**. \n\nThe `data/DELIVER` folder is structured as:\n```text\nDELIVER\n├── depth\n│   ├── cloud\n│   │   ├── test\n│   │   │   ├── MAP_10_point102\n│   │   │   │   ├── 045050_depth_front.png\n│   │   │   │   ├── ...\n│   │   ├── train\n│   │   └── val\n│   ├── fog\n│   ├── night\n│   ├── rain\n│   └── sun\n├── event\n├── hha\n├── img\n├── lidar\n└── semantic\n```\n\n## CMNeXt model\n\n![CMNeXt](figs/CMNeXt.png)\n\nCMNeXt architecture in Hub2Fuse paradigm and asymmetric branches, having e.g., Multi-Head Self-Attention (MHSA) blocks in the RGB branch and our Parallel Pooling Mixer (PPX) blocks in the accompanying branch. At the hub step, the Self-Query Hub selects informative features from the supplementary modalities. At the fusion step, the feature rectification module (FRM) and feature fusion module (FFM) are used for feature fusion. Between stages, features of each modality are restored via adding the fused feature. 
The four-stage fused features are forwarded to the segmentation head for the final prediction.\n\n\n## Environment\n\n```bash\nconda env create -f environment.yml\nconda activate cmnext\n# Optional: install apex by following https://github.com/NVIDIA/apex\n```\n\n## Data preparation\nPrepare six datasets:\n- [DELIVER](https://github.com/jamycheung/DELIVER), for RGB-Depth-Event-LiDAR semantic segmentation.\n- [KITTI-360](https://www.cvlibs.net/datasets/kitti-360/), for RGB-Depth-Event-LiDAR semantic segmentation.\n- [NYU Depth V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html), for RGB-Depth semantic segmentation.\n- [MFNet](https://github.com/haqishen/MFNet-pytorch), for RGB-Thermal semantic segmentation.\n- [UrbanLF](https://github.com/HAWKEYE-Group/UrbanLF), for light-field segmentation based on sub-aperture images.\n- [MCubeS](https://github.com/kyotovision-public/multimodal-material-segmentation), for multimodal material segmentation with RGB-A-D-N modalities.\n\nThen, all datasets are structured as:\n\n```\ndata/\n├── DELIVER\n│   ├── img\n│   ├── hha\n│   ├── event\n│   ├── lidar\n│   └── semantic\n├── KITTI-360\n│   ├── data_2d_raw\n│   ├── data_2d_hha\n│   ├── data_2d_event\n│   ├── data_2d_lidar\n│   └── data_2d_semantics\n├── NYUDepthv2\n│   ├── RGB\n│   ├── HHA\n│   └── Label\n├── MFNet\n│   ├── rgb\n│   ├── ther\n│   └── labels\n├── UrbanLF\n│   ├── Syn\n│   └── real\n├── MCubeS\n│   ├── polL_color\n│   ├── polL_aolp\n│   ├── polL_dolp\n│   ├── NIR_warped\n│   └── SS\n```\n\n*For RGB-Depth, the [HHA format](https://github.com/charlesCXK/Depth2HHA-python) is generated from the depth image.*\n\n## Model Zoo\n\n### DELIVER dataset \n\n| Model-Modal      | #Params(M) | GFLOPs | mIoU   | weight |\n| :--------------- | :---------- | :----- | :----- | :------ |\n| CMNeXt-RGB       | 25.79      | 38.93 | 57.20 | [GoogleDrive](https://drive.google.com/drive/folders/1OWteEOrjfrC3VNg3sxJFZPHz9urkZ3lm?usp=share_link) |\n| CMNeXt-RGB-E     | 58.69      | 62.94 | 
57.48 | [GoogleDrive](https://drive.google.com/drive/folders/1OWteEOrjfrC3VNg3sxJFZPHz9urkZ3lm?usp=share_link) |\n| CMNeXt-RGB-L     | 58.69      | 62.94 | 58.04 | [GoogleDrive](https://drive.google.com/drive/folders/1OWteEOrjfrC3VNg3sxJFZPHz9urkZ3lm?usp=share_link) |\n| CMNeXt-RGB-D     | 58.69      | 62.94 | 63.58 | [GoogleDrive](https://drive.google.com/drive/folders/1OWteEOrjfrC3VNg3sxJFZPHz9urkZ3lm?usp=share_link) |\n| CMNeXt-RGB-D-E   | 58.72      | 64.19 | 64.44 | [GoogleDrive](https://drive.google.com/drive/folders/1OWteEOrjfrC3VNg3sxJFZPHz9urkZ3lm?usp=share_link) |\n| CMNeXt-RGB-D-L   | 58.72      | 64.19 | 65.50 | [GoogleDrive](https://drive.google.com/drive/folders/1OWteEOrjfrC3VNg3sxJFZPHz9urkZ3lm?usp=share_link) |\n| CMNeXt-RGB-D-E-L | 58.73      | 65.42 | 66.30 | [GoogleDrive](https://drive.google.com/drive/folders/1OWteEOrjfrC3VNg3sxJFZPHz9urkZ3lm?usp=share_link) |\n\n\n### KITTI360 dataset\n\n| Model-Modal      | mIoU   | weight |\n| :--------------- | :----- | :------ |\n| CMNeXt-RGB       | 67.04 | [GoogleDrive](https://drive.google.com/drive/folders/1OXfBNNcwMCQWGmvOqvkAHd8xViKJ6djf?usp=share_link) |\n| CMNeXt-RGB-E     | 66.13 | [GoogleDrive](https://drive.google.com/drive/folders/1OXfBNNcwMCQWGmvOqvkAHd8xViKJ6djf?usp=share_link) |\n| CMNeXt-RGB-L     | 65.26 | [GoogleDrive](https://drive.google.com/drive/folders/1OXfBNNcwMCQWGmvOqvkAHd8xViKJ6djf?usp=share_link) |\n| CMNeXt-RGB-D     | 65.09 | [GoogleDrive](https://drive.google.com/drive/folders/1OXfBNNcwMCQWGmvOqvkAHd8xViKJ6djf?usp=share_link) |\n| CMNeXt-RGB-D-E   | 67.73 | [GoogleDrive](https://drive.google.com/drive/folders/1OXfBNNcwMCQWGmvOqvkAHd8xViKJ6djf?usp=share_link) |\n| CMNeXt-RGB-D-L   | 66.55 | [GoogleDrive](https://drive.google.com/drive/folders/1OXfBNNcwMCQWGmvOqvkAHd8xViKJ6djf?usp=share_link) |\n| CMNeXt-RGB-D-E-L | 67.84 | [GoogleDrive](https://drive.google.com/drive/folders/1OXfBNNcwMCQWGmvOqvkAHd8xViKJ6djf?usp=share_link) |\n\n### NYU Depth V2\n\n| Model-Modal      | mIoU   | 
weight |\n| :--------------- | :----- | :------ |\n| CMNeXt-RGB-D (MiT-B4)     | 56.9 | [GoogleDrive](https://drive.google.com/drive/folders/1OXrMv1Mi6E-vyedlkNpfc5ibJeJDfwqq?usp=share_link) |\n\n### MFNet\n\n| Model-Modal      | mIoU   | weight |\n| :--------------- | :----- | :------ |\n| CMNeXt-RGB-T (MiT-B4)     | 59.9 | [GoogleDrive](https://drive.google.com/drive/folders/1OaOHMbD5P_HPwTzzXwdYRUFIiGkyFQfP?usp=share_link) |\n\n### UrbanLF\n\nThere are **real** and **synthetic** datasets. \n\n| Model-Modal     | Real   | weight | Syn    | weight |\n| :-------------- | :----- | :----- | :----- | :----- |\n| CMNeXt-RGB      | 82.20 | [GoogleDrive](https://drive.google.com/drive/folders/1OfepYOYaM8I0itjuHK4csqmAuu_zHol-?usp=share_link) | 78.53 | [GoogleDrive](https://drive.google.com/drive/folders/1OfepYOYaM8I0itjuHK4csqmAuu_zHol-?usp=share_link) |\n| CMNeXt-RGB-LF8  | 83.22 | [GoogleDrive](https://drive.google.com/drive/folders/1OfepYOYaM8I0itjuHK4csqmAuu_zHol-?usp=share_link) | 80.74 | [GoogleDrive](https://drive.google.com/drive/folders/1OfepYOYaM8I0itjuHK4csqmAuu_zHol-?usp=share_link) |\n| CMNeXt-RGB-LF33 | 82.62 | [GoogleDrive](https://drive.google.com/drive/folders/1OfepYOYaM8I0itjuHK4csqmAuu_zHol-?usp=share_link) | 80.98 | [GoogleDrive](https://drive.google.com/drive/folders/1OfepYOYaM8I0itjuHK4csqmAuu_zHol-?usp=share_link) |\n| CMNeXt-RGB-LF80 | 83.11 | [GoogleDrive](https://drive.google.com/drive/folders/1OfepYOYaM8I0itjuHK4csqmAuu_zHol-?usp=share_link) | 81.02 | [GoogleDrive](https://drive.google.com/drive/folders/1OfepYOYaM8I0itjuHK4csqmAuu_zHol-?usp=share_link) |\n\n### MCubeS\n\n| Model-Modal      | mIoU   | weight |\n| :--------------- | :----- | :----- |\n| CMNeXt-RGB       | 48.16 | [GoogleDrive](https://drive.google.com/drive/folders/1OgbqpT6TSCPPsoJn0sh2wfXNy5mL5nk9?usp=share_link) |\n| CMNeXt-RGB-A     | 48.42 | [GoogleDrive](https://drive.google.com/drive/folders/1OgbqpT6TSCPPsoJn0sh2wfXNy5mL5nk9?usp=share_link) |\n| CMNeXt-RGB-A-D   | 49.48 | 
[GoogleDrive](https://drive.google.com/drive/folders/1OgbqpT6TSCPPsoJn0sh2wfXNy5mL5nk9?usp=share_link) |\n| CMNeXt-RGB-A-D-N | 51.54 | [GoogleDrive](https://drive.google.com/drive/folders/1OgbqpT6TSCPPsoJn0sh2wfXNy5mL5nk9?usp=share_link) |\n\n\n\n## Training\n\nBefore training, please download the [pre-trained SegFormer](https://drive.google.com/drive/folders/10XgSW8f7ghRs9fJ0dE-EV8G2E_guVsT5?usp=sharing) weights, such as `checkpoints/pretrained/segformer/mit_b2.pth`.\n\n```text\ncheckpoints/pretrained/segformer\n├── mit_b2.pth\n└── mit_b4.pth\n```\n\nTo train the CMNeXt model, please change the yaml file passed to `--cfg`. Several training examples using 4 A100 GPUs are:  \n\n```bash\ncd path/to/DELIVER\nconda activate cmnext\nexport PYTHONPATH=\"path/to/DELIVER\"\npython -m torch.distributed.launch --nproc_per_node=4 --use_env tools/train_mm.py --cfg configs/deliver_rgbdel.yaml\npython -m torch.distributed.launch --nproc_per_node=4 --use_env tools/train_mm.py --cfg configs/kitti360_rgbdel.yaml\npython -m torch.distributed.launch --nproc_per_node=4 --use_env tools/train_mm.py --cfg configs/nyu_rgbd.yaml\npython -m torch.distributed.launch --nproc_per_node=4 --use_env tools/train_mm.py --cfg configs/mfnet_rgbt.yaml\npython -m torch.distributed.launch --nproc_per_node=4 --use_env tools/train_mm.py --cfg configs/mcubes_rgbadn.yaml\npython -m torch.distributed.launch --nproc_per_node=4 --use_env tools/train_mm.py --cfg configs/urbanlf.yaml\n```\n\n\n## Evaluation\nTo evaluate CMNeXt models, please download the respective model weights ([**GoogleDrive**](https://drive.google.com/drive/folders/1MZaaZ5_rEVSjns3TBM0UDt6IW4X-HPIN?usp=share_link)) and place them as:\n\n\n```text\noutput/\n├── DELIVER\n│   ├── cmnext_b2_deliver_rgb.pth\n│   ├── cmnext_b2_deliver_rgbd.pth\n│   ├── cmnext_b2_deliver_rgbde.pth\n│   ├── cmnext_b2_deliver_rgbdel.pth\n│   ├── cmnext_b2_deliver_rgbdl.pth\n│   ├── cmnext_b2_deliver_rgbe.pth\n│   └── cmnext_b2_deliver_rgbl.pth\n├── KITTI360\n│   ├── cmnext_b2_kitti360_rgb.pth\n│   ├── 
cmnext_b2_kitti360_rgbd.pth\n│   ├── cmnext_b2_kitti360_rgbde.pth\n│   ├── cmnext_b2_kitti360_rgbdel.pth\n│   ├── cmnext_b2_kitti360_rgbdl.pth\n│   ├── cmnext_b2_kitti360_rgbe.pth\n│   └── cmnext_b2_kitti360_rgbl.pth\n├── MCubeS\n│   ├── cmnext_b2_mcubes_rgb.pth\n│   ├── cmnext_b2_mcubes_rgba.pth\n│   ├── cmnext_b2_mcubes_rgbad.pth\n│   └── cmnext_b2_mcubes_rgbadn.pth\n├── MFNet\n│   └── cmnext_b4_mfnet_rgbt.pth\n├── NYU_Depth_V2\n│   └── cmnext_b4_nyu_rgbd.pth\n├── UrbanLF\n│   ├── cmnext_b4_urbanlf_real_rgblf1.pth\n│   ├── cmnext_b4_urbanlf_real_rgblf33.pth\n│   ├── cmnext_b4_urbanlf_real_rgblf8.pth\n│   ├── cmnext_b4_urbanlf_real_rgblf80.pth\n│   ├── cmnext_b4_urbanlf_syn_rgblf1.pth\n│   ├── cmnext_b4_urbanlf_syn_rgblf33.pth\n│   ├── cmnext_b4_urbanlf_syn_rgblf8.pth\n│   └── cmnext_b4_urbanlf_syn_rgblf80.pth\n```\n\nThen, modify `--cfg` to respective config file, and run:\n```bash\ncd path/to/DELIVER\nconda activate cmnext\nexport PYTHONPATH=\"path/to/DELIVER\"\nCUDA_VISIBLE_DEVICES=0 python tools/val_mm.py --cfg configs/deliver_rgbdel.yaml\n```\n\nOn DeLiVER dataset, there are **validation** and **test** sets. Please check [*val_mm.py*](tools/val_mm.py) to modify the dataset for validation and test sets.\n\nTo evaluate the different cases (adverse weather conditions, sensor failures), modify the `cases` list at [*val_mm.py*](tools/val_mm.py), as shown below:\n\n```python\n# cases = ['cloud', 'fog', 'night', 'rain', 'sun']\n# cases = ['motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres']\ncases = [None] # all\n```\nNote that the default value is `[None]` for all cases.\n\n\n### DELIVER visualization\n\n<img src=\"figs/DELIVER_vis.png\" width=\"500px\">\n\nThe visualization results on DELIVER dataset. 
From left to right are the respective *cloudy*, *foggy*, *night* and *rainy* scenes.\n\n## Acknowledgements\nThanks to the public repositories:\n\n- [RGBX-semantic-segmentation](https://github.com/huaaaliu/RGBX_Semantic_Segmentation)\n- [Semantic-segmentation](https://github.com/sithu31296/semantic-segmentation)\n\n## License\n\nThis repository is under the Apache-2.0 license. For commercial use, please contact the authors.\n\n\n## Citations\n\nIf you use the DeLiVER dataset and the CMNeXt model, please cite the following works:\n\n- **DeLiVER & CMNeXt** [[**PDF**](https://arxiv.org/pdf/2303.01480.pdf)]\n```\n@inproceedings{zhang2023delivering,\n  title={Delivering Arbitrary-Modal Semantic Segmentation},\n  author={Zhang, Jiaming and Liu, Ruiping and Shi, Hao and Yang, Kailun and Rei{\\ss}, Simon and Peng, Kunyu and Fu, Haodong and Wang, Kaiwei and Stiefelhagen, Rainer},\n  booktitle={CVPR},\n  year={2023}\n}\n```\n\n- **CMX** [[**PDF**](https://arxiv.org/pdf/2203.04838.pdf)]\n```\n@article{zhang2023cmx,\n  title={CMX: Cross-modal fusion for RGB-X semantic segmentation with transformers},\n  author={Zhang, Jiaming and Liu, Huayao and Yang, Kailun and Hu, Xinxin and Liu, Ruiping and Stiefelhagen, Rainer},\n  journal={IEEE Transactions on Intelligent Transportation Systems},\n  year={2023}\n}\n```\n"
  },
  {
    "path": "configs/deliver_rgbdel.yaml",
    "content": "DEVICE          : cuda              # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)\nSAVE_DIR        : 'output'          # output folder name used for saving the model, logs and inference results\n\nMODEL:\n  NAME          : CMNeXt                                            # name of the model you are using\n  BACKBONE      : CMNeXt-B2                                         # model variant\n  PRETRAINED    : 'checkpoints/pretrained/segformer/mit_b2.pth'     # backbone model's weight \n  RESUME        : ''                                                # checkpoint file \n\nDATASET:\n  NAME          : DELIVER                                          # dataset name to be trained with (camvid, cityscapes, ade20k)\n  ROOT          : 'data/DELIVER'                                   # dataset root path\n  IGNORE_LABEL  : 255\n  # MODALS        : ['img']\n  # MODALS        : ['img', 'depth']\n  # MODALS        : ['img', 'event']\n  # MODALS        : ['img', 'lidar']\n  # MODALS        : ['img', 'depth', 'event']\n  # MODALS        : ['img', 'depth', 'lidar']\n  MODALS        : ['img', 'depth', 'event', 'lidar']\n\nTRAIN:\n  IMAGE_SIZE    : [1024, 1024]    # training image size in (h, w)\n  BATCH_SIZE    : 2               # batch size used to train\n  EPOCHS        : 200             # number of epochs to train\n  EVAL_START    : 100             # evaluation interval start\n  EVAL_INTERVAL : 1               # evaluation interval during training\n  AMP           : false           # use AMP in training\n  DDP           : true            # use DDP training\n\nLOSS:\n  NAME          : OhemCrossEntropy          # loss function name \n  CLS_WEIGHTS   : false            # use class weights in loss calculation\n\nOPTIMIZER:\n  NAME          : adamw           # optimizer name\n  LR            : 0.00006         # initial learning rate used in optimizer\n  WEIGHT_DECAY  : 0.01            # decay rate used in optimizer\n\nSCHEDULER:\n  NAME     
     : warmuppolylr    # scheduler name\n  POWER         : 0.9             # scheduler power\n  WARMUP        : 10              # warmup epochs used in scheduler\n  WARMUP_RATIO  : 0.1             # warmup ratio\n  \n\nEVAL:\n  # MODEL_PATH    : 'output/DELIVER/cmnext_b2_deliver_rgb.pth'\n  # MODEL_PATH    : 'output/DELIVER/cmnext_b2_deliver_rgbd.pth'\n  # MODEL_PATH    : 'output/DELIVER/cmnext_b2_deliver_rgbe.pth'\n  # MODEL_PATH    : 'output/DELIVER/cmnext_b2_deliver_rgbl.pth'\n  # MODEL_PATH    : 'output/DELIVER/cmnext_b2_deliver_rgbde.pth'\n  # MODEL_PATH    : 'output/DELIVER/cmnext_b2_deliver_rgbdl.pth'\n  MODEL_PATH    : 'output/DELIVER/cmnext_b2_deliver_rgbdel.pth'\n  IMAGE_SIZE    : [1024, 1024]                            # evaluation image size in (h, w)        \n  BATCH_SIZE    : 4                                       # batch size used for evaluation               \n  MSF: \n    ENABLE      : false                                   # multi-scale and flip evaluation  \n    FLIP        : true                                    # use flip in evaluation  \n    SCALES      : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]       # scales used in MSF evaluation                \n\n\nTEST:\n  MODEL_PATH    : 'output/DELIVER/cmnext_b2_deliver_rgbdel.pth'    # trained model file path\n  FILE          : 'data/DELIVER'                          # filename or foldername \n  IMAGE_SIZE    : [1024, 1024]                            # inference image size in (h, w)\n  OVERLAY       : false                                   # save the overlay result (image_alpha+label_alpha)"
  },
  {
    "path": "configs/kitti360_rgbdel.yaml",
    "content": "DEVICE          : cuda              # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)\nSAVE_DIR        : 'output'          # output folder name used for saving the model, logs and inference results\n\nMODEL:\n  NAME          : CMNeXt                                            # name of the model you are using\n  BACKBONE      : CMNeXt-B2                                         # model variant\n  PRETRAINED    : 'checkpoints/pretrained/segformer/mit_b2.pth'     # backbone model's weight \n  RESUME        : ''                                                # checkpoint file \n\nDATASET:\n  NAME          : KITTI360                                          # dataset name to be trained with (camvid, cityscapes, ade20k)\n  ROOT          : 'data/KITTI360'                                   # dataset root path\n  IGNORE_LABEL  : 255\n  # MODALS        : ['img']\n  # MODALS        : ['img', 'depth']\n  # MODALS        : ['img', 'event']\n  # MODALS        : ['img', 'lidar']\n  # MODALS        : ['img', 'depth', 'event']\n  # MODALS        : ['img', 'depth', 'lidar']\n  MODALS        : ['img', 'depth', 'event', 'lidar']\n\nTRAIN:\n  IMAGE_SIZE    : [376, 1408]     # training image size in (h, w)\n  BATCH_SIZE    : 4               # batch size used to train --- KD\n  EPOCHS        : 40              # number of epochs to train\n  EVAL_START    : 10              # evaluation interval during training\n  EVAL_INTERVAL : 1               # evaluation interval during training\n  AMP           : false           # use AMP in training\n  DDP           : true            # use DDP training\n\nLOSS:\n  NAME          : OhemCrossEntropy          # loss function name\n  CLS_WEIGHTS   : false            # use class weights in loss calculation\n\nOPTIMIZER:\n  NAME          : adamw           # optimizer name\n  LR            : 0.00006         # initial learning rate used in optimizer\n  WEIGHT_DECAY  : 0.01            # decay rate used in optimizer 
\n\nSCHEDULER:\n  NAME          : warmuppolylr    # scheduler name\n  POWER         : 0.9             # scheduler power\n  WARMUP        : 10              # warmup epochs used in scheduler\n  WARMUP_RATIO  : 0.1             # warmup ratio\n  \n\nEVAL:\n  # MODEL_PATH    : 'output/KITTI360/cmnext_b2_kitti360_rgb.pth'\n  # MODEL_PATH    : 'output/KITTI360/cmnext_b2_kitti360_rgbd.pth'\n  # MODEL_PATH    : 'output/KITTI360/cmnext_b2_kitti360_rgbe.pth'\n  # MODEL_PATH    : 'output/KITTI360/cmnext_b2_kitti360_rgbl.pth'\n  # MODEL_PATH    : 'output/KITTI360/cmnext_b2_kitti360_rgbde.pth'\n  # MODEL_PATH    : 'output/KITTI360/cmnext_b2_kitti360_rgbdl.pth'\n  MODEL_PATH    : 'output/KITTI360/cmnext_b2_kitti360_rgbdel.pth'\n  IMAGE_SIZE    : [376, 1408]                            # evaluation image size in (h, w)                       \n  BATCH_SIZE    : 4                                      # batch size used for evaluation\n  MSF: \n    ENABLE      : false                                   # multi-scale and flip evaluation  \n    FLIP        : true                                    # use flip in evaluation  \n    SCALES      : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]       # scales used in MSF evaluation                \n\n\nTEST:\n  MODEL_PATH    : 'output/KITTI360/cmnext_b2_kitti360_rgbdel.pth'    # trained model file path\n  FILE          : 'data/KITTI360'                    # filename or foldername \n  IMAGE_SIZE    : [376, 1408]                            # inference image size in (h, w)\n  OVERLAY       : false                                    # save the overlay result (image_alpha+label_alpha)"
  },
  {
    "path": "configs/mcubes_rgbadn.yaml",
    "content": "DEVICE          : cuda              # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)\nSAVE_DIR        : 'output'          # output folder name used for saving the model, logs and inference results\n\nMODEL:\n  NAME          : CMNeXt                                            # name of the model you are using\n  BACKBONE      : CMNeXt-B2                                         # model variant\n  PRETRAINED    : 'checkpoints/pretrained/segformer/mit_b4.pth'     # backbone model's weight \n  RESUME        : ''                                                # checkpoint file\n\nDATASET:\n  NAME          : MCubeS                                          # dataset name to be trained with (camvid, cityscapes, ade20k)\n  ROOT          : 'data/MCubeS/multimodal_dataset'                # dataset root path\n  IGNORE_LABEL  : 255\n  # MODALS        : ['image'] # \n  # MODALS        : ['image', 'aolp']\n  # MODALS        : ['image', 'aolp', 'dolp']\n  MODALS        : ['image', 'aolp', 'dolp', 'nir']\n\nTRAIN:\n  IMAGE_SIZE    : [512, 512]      # training image size in (h, w) === Fixed in dataloader, following MCubeSNet\n  BATCH_SIZE    : 4               # batch size used to train\n  EPOCHS        : 500             # number of epochs to train\n  EVAL_START    : 400             # evaluation interval during training\n  EVAL_INTERVAL : 1               # evaluation interval during training\n  AMP           : false           # use AMP in training\n  DDP           : true            # use DDP training\n\nLOSS:\n  NAME          : OhemCrossEntropy     # loss function name\n  CLS_WEIGHTS   : false            # use class weights in loss calculation\n\nOPTIMIZER:\n  NAME          : adamw           # optimizer name\n  LR            : 0.00006         # initial learning rate used in optimizer\n  WEIGHT_DECAY  : 0.01            # decay rate used in optimizer \n\nSCHEDULER:\n  NAME          : warmuppolylr    # scheduler name\n  POWER         : 0.9          
   # scheduler power\n  WARMUP        : 10              # warmup epochs used in scheduler\n  WARMUP_RATIO  : 0.1             # warmup ratio\n  \n\nEVAL:\n  # MODEL_PATH    : 'output/MCubeS/cmnext_b2_mcubes_rgb.pth'\n  # MODEL_PATH    : 'output/MCubeS/cmnext_b2_mcubes_rgba.pth'\n  # MODEL_PATH    : 'output/MCubeS/cmnext_b2_mcubes_rgbad.pth'\n  MODEL_PATH    : 'output/MCubeS/cmnext_b2_mcubes_rgbadn.pth'\n  IMAGE_SIZE    : [1024, 1024]    # evaluation image size in (h, w)                       \n  BATCH_SIZE    : 2               # batch size used for evaluation\n  MSF: \n    ENABLE      : false                                   # multi-scale and flip evaluation  \n    FLIP        : true                                    # use flip in evaluation  \n    SCALES      : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]       # scales used in MSF evaluation                \n"
  },
  {
    "path": "configs/mfnet_rgbt.yaml",
    "content": "DEVICE          : cuda              # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)\nSAVE_DIR        : 'output'         # output folder name used for saving the model, logs and inference results\n\nMODEL:\n  NAME          : CMNeXt                                            # name of the model you are using\n  BACKBONE      : CMNeXt-B4                                         # model variant\n  PRETRAINED    : 'checkpoints/pretrained/segformer/mit_b4.pth'     # backbone model's weight \n  RESUME        : ''                                                # checkpoint file \n\nDATASET:\n  NAME          : MFNet                                          # dataset name to be trained with (camvid, cityscapes, ade20k)\n  ROOT          : 'data/MFNet'                                   # dataset root path\n  IGNORE_LABEL  : 255\n  # MODALS        : ['img']\n  MODALS        : ['img', 'thermal']\n\nTRAIN:\n  IMAGE_SIZE    : [480, 640]      # training image size in (h, w)\n  BATCH_SIZE    : 4               # batch size used to train\n  EPOCHS        : 500             # number of epochs to train\n  EVAL_START    : 300             # evaluation interval during training\n  EVAL_INTERVAL : 1               # evaluation interval during training\n  AMP           : false           # use AMP in training\n  DDP           : true            # use DDP training\n\nLOSS:\n  NAME          : CrossEntropy     # loss function name (ohemce, ce, dice)\n  CLS_WEIGHTS   : false            # use class weights in loss calculation\n\nOPTIMIZER:\n  NAME          : adamw           # optimizer name\n  LR            : 0.00006         # initial learning rate used in optimizer\n  WEIGHT_DECAY  : 0.01            # decay rate used in optimizer \n\nSCHEDULER:\n  NAME          : warmuppolylr    # scheduler name\n  POWER         : 0.9             # scheduler power\n  WARMUP        : 10              # warmup epochs used in scheduler\n  WARMUP_RATIO  : 0.1             # warmup 
ratio\n  \n\nEVAL:\n  MODEL_PATH    : 'output/MFNet/cmnext_b4_mfnet_rgbt.pth'\n  IMAGE_SIZE    : [480, 640]      # evaluation image size in (h, w)                       \n  BATCH_SIZE    : 2               # batch size used for evaluation\n  MSF: \n    ENABLE      : false                                   # multi-scale and flip evaluation  \n    FLIP        : true                                    # use flip in evaluation  \n    SCALES      : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]       # scales used in MSF evaluation                \n"
  },
  {
    "path": "configs/nyu_rgbd.yaml",
    "content": "DEVICE          : cuda              # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)\nSAVE_DIR        : 'output'          # output folder name used for saving the model, logs and inference results\n\nMODEL:\n  NAME          : CMNeXt                                            # name of the model you are using\n  BACKBONE      : CMNeXt-B4                                         # model variant\n  PRETRAINED    : 'checkpoints/pretrained/segformer/mit_b4.pth'     # backbone model's weight \n  RESUME        : ''                                                # checkpoint file \n\nDATASET:\n  NAME          : NYU                                          # dataset name to be trained with (camvid, cityscapes, ade20k)\n  ROOT          : 'data/NYUDepthv2'                            # dataset root path\n  IGNORE_LABEL  : 255\n  # MODALS        : ['img'] \n  MODALS        : ['img', 'depth']\n\nTRAIN:\n  IMAGE_SIZE    : [480, 640]      # training image size in (h, w)\n  BATCH_SIZE    : 4               # batch size used to train\n  EPOCHS        : 500             # number of epochs to train\n  EVAL_START    : 300             # evaluation interval during training\n  EVAL_INTERVAL : 1               # evaluation interval during training\n  AMP           : false           # use AMP in training\n  DDP           : true            # use DDP training\n\nLOSS:\n  NAME          : CrossEntropy     # loss function name\n  CLS_WEIGHTS   : false            # use class weights in loss calculation\n\nOPTIMIZER:\n  NAME          : adamw           # optimizer name\n  LR            : 0.00006         # initial learning rate used in optimizer\n  WEIGHT_DECAY  : 0.01            # decay rate used in optimizer \n\nSCHEDULER:\n  NAME          : warmuppolylr    # scheduler name\n  POWER         : 0.9             # scheduler power\n  WARMUP        : 10              # warmup epochs used in scheduler\n  WARMUP_RATIO  : 0.1             # warmup ratio\n  \n\nEVAL:\n  
MODEL_PATH    : 'output/NYU_Depth_V2/cmnext_b4_nyu_rgbd.pth'\n  IMAGE_SIZE    : [480, 640]      # evaluation image size in (h, w)                       \n  BATCH_SIZE    : 2               # batch size used for evaluation\n  MSF: \n    ENABLE      : true                                    # multi-scale and flip evaluation  \n    FLIP        : true                                    # use flip in evaluation  \n    SCALES      : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]       # scales used in MSF evaluation                \n"
  },
  {
    "path": "configs/urbanlf.yaml",
    "content": "DEVICE          : cuda              # device used for training and evaluation (cpu, cuda, cuda0, cuda1, ...)\nSAVE_DIR        : 'output'          # output folder name used for saving the model, logs and inference results\n\nMODEL:\n  NAME          : CMNeXt                                            # name of the model you are using\n  BACKBONE      : CMNeXt-B4                                         # model variant\n  PRETRAINED    : 'checkpoints/pretrained/segformer/mit_b4.pth'     # backbone model's weight \n  RESUME        : ''                                                # checkpoint file \n\nDATASET:\n  NAME          : UrbanLF                                          # dataset name to be trained with (camvid, cityscapes, ade20k)\n  # ROOT          : 'data/UrBanLF/real'                              # dataset root path, for real dataset\n  ROOT          : 'data/UrBanLF/Syn'                             # dataset root path, for synthetic dataset\n  IGNORE_LABEL  : 255\n  # MODALS        : ['img'] \n  # MODALS        : ['img', '5_1', '5_2', '5_3', '5_4', '5_6', '5_7', '5_8', '5_9']\n  # MODALS        : ['img', '1_1', '1_5', '1_9', '2_2', '2_5', '2_8', '3_3', '3_5', '3_7', '4_4', '4_5', '4_6', '5_1', '5_2', '5_3', '5_4', '5_6', '5_7', '5_8', '5_9', '6_4', '6_5', '6_6', '7_3', '7_5', '7_7', '8_2', '8_5', '8_8', '9_1', '9_5', '9_9']\n  MODALS        : ['img', '1_1', '1_2', '1_3', '1_4', '1_5', '1_6', '1_7', '1_8', '1_9', '2_1', '2_2', '2_3', '2_4', '2_5', '2_6', '2_7', '2_8', '2_9', '3_1', '3_2', '3_3', '3_4', '3_5', '3_6', '3_7', '3_8', '3_9', '4_1', '4_2', '4_3', '4_4', '4_5', '4_6', '4_7', '4_8', '4_9', '5_1', '5_2', '5_3', '5_4', '5_6', '5_7', '5_8', '5_9', '6_1', '6_2', '6_3', '6_4', '6_5', '6_6', '6_7', '6_8', '6_9', '7_1', '7_2', '7_3', '7_4', '7_5', '7_6', '7_7', '7_8', '7_9', '8_1', '8_2', '8_3', '8_4', '8_5', '8_6', '8_7', '8_8', '8_9', '9_1', '9_2', '9_3', '9_4', '9_5', '9_6', '9_7', '9_8', '9_9']\n\nTRAIN:\n  IMAGE_SIZE    : [480, 640]    
  # training image size in (h, w)\n  BATCH_SIZE    : 2               # batch size used to train\n  EPOCHS        : 500             # number of epochs to train\n  EVAL_START    : 300             # evaluation interval start\n  EVAL_INTERVAL : 1               # evaluation interval during training\n  AMP           : false           # use AMP in training\n  DDP           : true            # use DDP training\n\nLOSS:\n  NAME          : OhemCrossEntropy          # loss function name\n  CLS_WEIGHTS   : false            # use class weights in loss calculation\n\nOPTIMIZER:\n  NAME          : adamw           # optimizer name\n  LR            : 0.00006         # initial learning rate used in optimizer\n  WEIGHT_DECAY  : 0.01            # decay rate used in optimizer \n\nSCHEDULER:\n  NAME          : warmuppolylr    # scheduler name\n  POWER         : 0.9             # scheduler power\n  WARMUP        : 10              # warmup epochs used in scheduler\n  WARMUP_RATIO  : 0.1             # warmup ratio\n  \n\nEVAL:\n  # MODEL_PATH    : 'output/UrbanLF/cmnext_b4_urbanlf_real_rgblf1.pth'\n  # MODEL_PATH    : 'output/UrbanLF/cmnext_b4_urbanlf_real_rgblf8.pth'\n  # MODEL_PATH    : 'output/UrbanLF/cmnext_b4_urbanlf_real_rgblf33.pth'\n  # MODEL_PATH    : 'output/UrbanLF/cmnext_b4_urbanlf_real_rgblf80.pth'\n\n  # MODEL_PATH    : 'output/UrbanLF/cmnext_b4_urbanlf_syn_rgblf1.pth'\n  # MODEL_PATH    : 'output/UrbanLF/cmnext_b4_urbanlf_syn_rgblf8.pth'\n  # MODEL_PATH    : 'output/UrbanLF/cmnext_b4_urbanlf_syn_rgblf33.pth'\n  MODEL_PATH    : 'output/UrbanLF/cmnext_b4_urbanlf_syn_rgblf80.pth'\n  IMAGE_SIZE    : [480, 640]      # eval image size in (h, w)            \n  BATCH_SIZE    : 2               # batch size used to train\n  MSF: \n    ENABLE      : false                                   # multi-scale and flip evaluation  \n    FLIP        : true                                    # use flip in evaluation  \n    SCALES      : [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]       # scales used in 
MSF evaluation                \n"
  },
  {
    "path": "environment.yaml",
    "content": "name: cmnext\nchannels:\n  - pytorch\n  - conda-forge\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - _openmp_mutex=4.5=1_gnu\n  - _pytorch_select=0.1=cpu_0\n  - blas=1.0=mkl\n  - bzip2=1.0.8=h7f98852_4\n  - ca-certificates=2021.10.26=h06a4308_2\n  - certifi=2021.10.8=py38h06a4308_2\n  - cffi=1.14.6=py38ha65f79e_0\n  - cudatoolkit=11.3.1=h2bc3f7f_2\n  - cudnn=8.2.1.32=h86fa8c9_0\n  - ffmpeg=4.3=hf484d3e_0\n  - freetype=2.10.4=h5ab3b9f_0\n  - future=0.18.2=py38h578d9bd_4\n  - gmp=6.2.1=h58526e2_0\n  - gnutls=3.6.13=h85f3911_1\n  - intel-openmp=2021.3.0=h06a4308_3350\n  - jpeg=9d=h7f8727e_0\n  - lame=3.100=h7f98852_1001\n  - lcms2=2.12=h3be6417_0\n  - ld_impl_linux-64=2.35.1=h7274673_9\n  - libblas=3.9.0=11_linux64_mkl\n  - libffi=3.3=he6710b0_2\n  - libgcc-ng=9.3.0=h5101ec6_17\n  - libgomp=9.3.0=h5101ec6_17\n  - libiconv=1.16=h516909a_0\n  - liblapack=3.9.0=11_linux64_mkl\n  - libpng=1.6.37=hbc83047_0\n  - libprotobuf=3.16.0=h780b84a_0\n  - libstdcxx-ng=9.3.0=hd4cf53a_17\n  - libtiff=4.2.0=h85742a9_0\n  - libuv=1.40.0=h7b6447c_0\n  - libwebp-base=1.2.0=h27cfd23_0\n  - lz4-c=1.9.3=h295c915_1\n  - magma=2.5.4=h6103c52_2\n  - mkl=2021.3.0=h06a4308_520\n  - mkl-service=2.4.0=py38h7f8727e_0\n  - mkl_fft=1.3.0=py38h42c9631_2\n  - mkl_random=1.2.2=py38h51133e4_0\n  - nccl=2.11.4.1=hdc17891_0\n  - ncurses=6.2=he6710b0_1\n  - nettle=3.6=he412f7d_0\n  - ninja=1.10.2=hff7bd54_1\n  - numpy=1.21.2=py38h20f2e39_0\n  - numpy-base=1.21.2=py38h79a1101_0\n  - olefile=0.46=pyhd3eb1b0_0\n  - openh264=2.1.1=h780b84a_0\n  - openjpeg=2.4.0=h3ad879b_0\n  - openssl=1.1.1m=h7f8727e_0\n  - pillow=8.3.1=py38h2c7a002_0\n  - pycparser=2.21=pyhd8ed1ab_0\n  - python=3.8.12=h12debd9_0\n  - python_abi=3.8=2_cp38\n  - pytorch=1.9.0=cuda112py38h3d13190_1\n  - pytorch-gpu=1.9.0=cuda112py38h0bbbad9_1\n  - pytorch-mutex=1.0=cuda\n  - readline=8.1=h27cfd23_0\n  - six=1.16.0=pyhd3eb1b0_0\n  - sleef=3.5.1=h7f98852_1\n  - sqlite=3.36.0=hc218d9a_0\n  - tk=8.6.11=h1ccaba5_0\n  - 
torchaudio=0.9.0=py38\n  - torchvision=0.10.0=py38cuda112h04b465a_0_cuda\n  - typing_extensions=3.10.0.2=pyh06a4308_0\n  - xz=5.2.5=h7b6447c_0\n  - yaml=0.2.5=h7b6447c_0\n  - zlib=1.2.11=h7b6447c_3\n  - zstd=1.4.9=haebb681_0\n  - pip:\n    - absl-py==1.2.0\n    - addict==2.4.0\n    - argon2-cffi==21.3.0\n    - argon2-cffi-bindings==21.2.0\n    - asttokens==2.0.5\n    - attrs==21.4.0\n    - backcall==0.2.0\n    - bleach==4.1.0\n    - cachetools==5.0.0\n    - charset-normalizer==2.1.1\n    - cycler==0.11.0\n    - dataclasses==0.6\n    - debugpy==1.5.1\n    - decorator==5.1.1\n    - defusedxml==0.7.1\n    - descartes==1.1.0\n    - easydict==1.9\n    - einops==0.4.1\n    - entrypoints==0.4\n    - executing==0.8.3\n    - fire==0.4.0\n    - fvcore==0.1.5.post20220512\n    - google-auth==2.11.0\n    - google-auth-oauthlib==0.4.6\n    - grpcio==1.48.1\n    - idna==3.3\n    - importlib-metadata==4.12.0\n    - importlib-resources==5.4.0\n    - iopath==0.1.10\n    - ipykernel==6.9.1\n    - ipython==8.1.0\n    - ipython-genutils==0.2.0\n    - ipywidgets==7.6.5\n    - jedi==0.18.1\n    - jinja2==3.0.3\n    - joblib==1.1.0\n    - jsonschema==4.4.0\n    - jupyter==1.0.0\n    - jupyter-client==7.1.2\n    - jupyter-console==6.4.0\n    - jupyter-core==4.9.2\n    - jupyterlab-pygments==0.1.2\n    - jupyterlab-widgets==1.0.2\n    - kiwisolver==1.3.2\n    - markdown==3.4.1\n    - markupsafe==2.1.1\n    - matplotlib==3.4.3\n    - matplotlib-inline==0.1.3\n    - mistune==0.8.4\n    - mmcv-full==1.6.1\n    - nbclient==0.5.11\n    - nbconvert==6.4.2\n    - nbformat==5.1.3\n    - nest-asyncio==1.5.4\n    - notebook==6.4.8\n    - nuscenes-devkit==1.1.9\n    - oauthlib==3.2.1\n    - opencv-python==4.5.3.56\n    - packaging==21.3\n    - pandocfilters==1.5.0\n    - parso==0.8.3\n    - pexpect==4.8.0\n    - pickleshare==0.7.5\n    - pip==22.0.3\n    - plyfile==0.7.4\n    - portalocker==2.5.1\n    - prometheus-client==0.13.1\n    - prompt-toolkit==3.0.28\n    - protobuf==3.18.1\n    - 
ptyprocess==0.7.0\n    - pure-eval==0.2.2\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - pycocotools==2.0.4\n    - pygments==2.11.2\n    - pyparsing==3.0.6\n    - pyquaternion==0.9.9\n    - pyrsistent==0.18.1\n    - python-dateutil==2.8.2\n    - pyyaml==6.0\n    - pyzmq==22.3.0\n    - qtconsole==5.2.2\n    - qtpy==2.0.1\n    - requests==2.28.1\n    - requests-oauthlib==1.3.1\n    - rsa==4.9\n    - scikit-learn==1.0.2\n    - scipy==1.7.1\n    - send2trash==1.8.0\n    - setuptools==59.5.0\n    - shapely==1.8.1.post1\n    - stack-data==0.2.0\n    - tabulate==0.8.10\n    - tensorboard==2.10.0\n    - tensorboard-data-server==0.6.1\n    - tensorboard-plugin-wit==1.8.1\n    - tensorboardx==2.4\n    - termcolor==1.1.0\n    - terminado==0.13.1\n    - testpath==0.6.0\n    - threadpoolctl==3.1.0\n    - timm==0.4.12\n    - tornado==6.1\n    - tqdm==4.62.3\n    - traitlets==5.1.1\n    - urllib3==1.26.12\n    - wcwidth==0.2.5\n    - webencodings==0.5.1\n    - werkzeug==2.2.2\n    - wheel==0.37.1\n    - widgetsnbextension==3.5.2\n    - yacs==0.1.8\n    - yapf==0.32.0\n    - zipp==3.7.0"
  },
  {
    "path": "requirements.txt",
    "content": "absl-py==1.2.0\naddict==2.4.0\nargon2-cffi==21.3.0\nargon2-cffi-bindings==21.2.0\nasttokens==2.0.5\nattrs==21.4.0\nbackcall==0.2.0\nbleach==4.1.0\ncachetools==5.0.0\ncharset-normalizer==2.1.1\ncycler==0.11.0\ndataclasses==0.6\ndebugpy==1.5.1\ndecorator==5.1.1\ndefusedxml==0.7.1\ndescartes==1.1.0\neasydict==1.9\neinops==0.4.1\nentrypoints==0.4\nexecuting==0.8.3\nfire==0.4.0\nfvcore==0.1.5.post20220512\ngoogle-auth==2.11.0\ngoogle-auth-oauthlib==0.4.6\ngrpcio==1.48.1\nidna==3.3\nimportlib-metadata==4.12.0\nimportlib-resources==5.4.0\niopath==0.1.10\nipykernel==6.9.1\nipython==8.1.0\nipython-genutils==0.2.0\nipywidgets==7.6.5\njedi==0.18.1\njinja2==3.0.3\njoblib==1.1.0\njsonschema==4.4.0\njupyter==1.0.0\njupyter-client==7.1.2\njupyter-console==6.4.0\njupyter-core==4.9.2\njupyterlab-pygments==0.1.2\njupyterlab-widgets==1.0.2\nkiwisolver==1.3.2\nmarkdown==3.4.1\nmarkupsafe==2.1.1\nmatplotlib==3.4.3\nmatplotlib-inline==0.1.3\nmistune==0.8.4\nmmcv-full==1.6.1\nnbclient==0.5.11\nnbconvert==6.4.2\nnbformat==5.1.3\nnest-asyncio==1.5.4\nnotebook==6.4.8\nnuscenes-devkit==1.1.9\noauthlib==3.2.1\nopencv-python==4.5.3.56\npackaging==21.3\npandocfilters==1.5.0\nparso==0.8.3\npexpect==4.8.0\npickleshare==0.7.5\npip==22.0.3\nplyfile==0.7.4\nportalocker==2.5.1\nprometheus-client==0.13.1\nprompt-toolkit==3.0.28\nprotobuf==3.18.1\nptyprocess==0.7.0\npure-eval==0.2.2\npyasn1==0.4.8\npyasn1-modules==0.2.8\npycocotools==2.0.4\npygments==2.11.2\npyparsing==3.0.6\npyquaternion==0.9.9\npyrsistent==0.18.1\npython-dateutil==2.8.2\npyyaml==6.0\npyzmq==22.3.0\nqtconsole==5.2.2\nqtpy==2.0.1\nrequests==2.28.1\nrequests-oauthlib==1.3.1\nrsa==4.9\nscikit-learn==1.0.2\nscipy==1.7.1\nsend2trash==1.8.0\nsetuptools==59.5.0\nshapely==1.8.1.post1\nstack-data==0.2.0\ntabulate==0.8.10\ntensorboard==2.10.0\ntensorboard-data-server==0.6.1\ntensorboard-plugin-wit==1.8.1\ntensorboardx==2.4\ntermcolor==1.1.0\nterminado==0.13.1\ntestpath==0.6.0\nthreadpoolctl==3.1.0\ntimm==0.4.12\ntornado==6.1\ntqd
m==4.62.3\ntraitlets==5.1.1\nurllib3==1.26.12\nwcwidth==0.2.5\nwebencodings==0.5.1\nwerkzeug==2.2.2\nwheel==0.37.1\nwidgetsnbextension==3.5.2\nyacs==0.1.8\nyapf==0.32.0\nzipp==3.7.0"
  },
  {
    "path": "semseg/__init__.py",
    "content": "from tabulate import tabulate\nfrom semseg import models\nfrom semseg import datasets\nfrom semseg.models import backbones, heads\n\n\ndef show_models():\n    model_names = models.__all__\n    numbers = list(range(1, len(model_names)+1))\n    print(tabulate({'No.': numbers, 'Model Names': model_names}, headers='keys'))\n\n\ndef show_backbones():\n    backbone_names = backbones.__all__\n    variants = []\n    for name in backbone_names:\n        try:\n            variants.append(list(eval(f\"backbones.{name.lower()}_settings\").keys()))\n        except:\n            variants.append('-')\n    print(tabulate({'Backbone Names': backbone_names, 'Variants': variants}, headers='keys'))\n\n\ndef show_heads():\n    head_names = heads.__all__\n    numbers = list(range(1, len(head_names)+1))\n    print(tabulate({'No.': numbers, 'Heads': head_names}, headers='keys'))\n\n\ndef show_datasets():\n    dataset_names = datasets.__all__\n    numbers = list(range(1, len(dataset_names)+1))\n    print(tabulate({'No.': numbers, 'Datasets': dataset_names}, headers='keys'))\n"
  },
  {
    "path": "semseg/augmentations.py",
    "content": "import torchvision.transforms.functional as TF \nimport random\nimport math\nimport torch\nfrom torch import Tensor\nfrom typing import Tuple, List, Union, Tuple, Optional\n\n\nclass Compose:\n    def __init__(self, transforms: list) -> None:\n        self.transforms = transforms\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        if mask.ndim == 2:\n            assert img.shape[1:] == mask.shape\n        else:\n            assert img.shape[1:] == mask.shape[1:]\n\n        for transform in self.transforms:\n            img, mask = transform(img, mask)\n\n        return img, mask\n\n\nclass Normalize:\n    def __init__(self, mean: list = (0.485, 0.456, 0.406), std: list = (0.229, 0.224, 0.225)):\n        self.mean = mean\n        self.std = std\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        img = img.float()\n        img /= 255\n        img = TF.normalize(img, self.mean, self.std)\n        return img, mask\n\n\nclass ColorJitter:\n    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0) -> None:\n        self.brightness = brightness\n        self.contrast = contrast\n        self.saturation = saturation\n        self.hue = hue\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        if self.brightness > 0:\n            img = TF.adjust_brightness(img, self.brightness)\n        if self.contrast > 0:\n            img = TF.adjust_contrast(img, self.contrast)\n        if self.saturation > 0:\n            img = TF.adjust_saturation(img, self.saturation)\n        if self.hue > 0:\n            img = TF.adjust_hue(img, self.hue)\n        return img, mask\n\n\nclass AdjustGamma:\n    def __init__(self, gamma: float, gain: float = 1) -> None:\n        \"\"\"\n        Args:\n            gamma: Non-negative real number. 
gamma larger than 1 make the shadows darker, while gamma smaller than 1 make dark regions lighter.\n            gain: constant multiplier\n        \"\"\"\n        self.gamma = gamma\n        self.gain = gain\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        return TF.adjust_gamma(img, self.gamma, self.gain), mask\n\n\nclass RandomAdjustSharpness:\n    def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:\n        self.sharpness = sharpness_factor\n        self.p = p\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        if random.random() < self.p:\n            img = TF.adjust_sharpness(img, self.sharpness)\n        return img, mask\n\n\nclass RandomAutoContrast:\n    def __init__(self, p: float = 0.5) -> None:\n        self.p = p\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        if random.random() < self.p:\n            img = TF.autocontrast(img)\n        return img, mask\n\n\nclass RandomGaussianBlur:\n    def __init__(self, kernel_size: int = 3, p: float = 0.5) -> None:\n        self.kernel_size = kernel_size\n        self.p = p\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        if random.random() < self.p:\n            img = TF.gaussian_blur(img, self.kernel_size)\n        return img, mask\n\n\nclass RandomHorizontalFlip:\n    def __init__(self, p: float = 0.5) -> None:\n        self.p = p\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        if random.random() < self.p:\n            return TF.hflip(img), TF.hflip(mask)\n        return img, mask\n\n\nclass RandomVerticalFlip:\n    def __init__(self, p: float = 0.5) -> None:\n        self.p = p\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        if random.random() < self.p:\n            return TF.vflip(img), TF.vflip(mask)\n        return img, mask\n\n\nclass RandomGrayscale:\n    def 
__init__(self, p: float = 0.5) -> None:\n        self.p = p\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        if random.random() < self.p:\n            img = TF.rgb_to_grayscale(img, 3)\n        return img, mask\n\n\nclass Equalize:\n    def __call__(self, image, label):\n        return TF.equalize(image), label\n\n\nclass Posterize:\n    def __init__(self, bits=2):\n        self.bits = bits # 0-8\n        \n    def __call__(self, image, label):\n        return TF.posterize(image, self.bits), label\n\n\nclass Affine:\n    def __init__(self, angle=0, translate=[0, 0], scale=1.0, shear=[0, 0], seg_fill=0):\n        self.angle = angle\n        self.translate = translate\n        self.scale = scale\n        self.shear = shear\n        self.seg_fill = seg_fill\n        \n    def __call__(self, img, label):\n        return TF.affine(img, self.angle, self.translate, self.scale, self.shear, TF.InterpolationMode.BILINEAR, 0), TF.affine(label, self.angle, self.translate, self.scale, self.shear, TF.InterpolationMode.NEAREST, self.seg_fill) \n\n\nclass RandomRotation:\n    def __init__(self, degrees: float = 10.0, p: float = 0.2, seg_fill: int = 0, expand: bool = False) -> None:\n        \"\"\"Rotate the image by a random angle between -degrees and degrees with probability p\n\n        Args:\n            p: probability\n            degrees: maximum rotation angle in degrees; the angle is drawn from [-degrees, degrees], counter-clockwise.\n            seg_fill: fill value for the mask area outside the rotated image.\n            expand: Optional expansion flag. \n                    If true, expands the output image to make it large enough to hold the entire rotated image.\n                    If false or omitted, make the output image the same size as the input image. 
\n                    Note that the expand flag assumes rotation around the center and no translation.\n        \"\"\"\n        self.p = p\n        self.angle = degrees\n        self.expand = expand\n        self.seg_fill = seg_fill\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        random_angle = random.random() * 2 * self.angle - self.angle\n        if random.random() < self.p:\n            img = TF.rotate(img, random_angle, TF.InterpolationMode.BILINEAR, self.expand, fill=0)\n            mask = TF.rotate(mask, random_angle, TF.InterpolationMode.NEAREST, self.expand, fill=self.seg_fill)\n        return img, mask\n    \n\nclass CenterCrop:\n    def __init__(self, size: Union[int, List[int], Tuple[int]]) -> None:\n        \"\"\"Crops the image at the center\n\n        Args:\n            output_size: height and width of the crop box. If int, this size is used for both directions.\n        \"\"\"\n        self.size = (size, size) if isinstance(size, int) else size\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        return TF.center_crop(img, self.size), TF.center_crop(mask, self.size)\n\n\nclass RandomCrop:\n    def __init__(self, size: Union[int, List[int], Tuple[int]], p: float = 0.5) -> None:\n        \"\"\"Randomly Crops the image.\n\n        Args:\n            output_size: height and width of the crop box. 
If int, this size is used for both directions.\n        \"\"\"\n        self.size = (size, size) if isinstance(size, int) else size\n        self.p = p\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        H, W = img.shape[1:]\n        tH, tW = self.size\n\n        if random.random() < self.p:\n            margin_h = max(H - tH, 0)\n            margin_w = max(W - tW, 0)\n            y1 = random.randint(0, margin_h+1)\n            x1 = random.randint(0, margin_w+1)\n            y2 = y1 + tH\n            x2 = x1 + tW\n            img = img[:, y1:y2, x1:x2]\n            mask = mask[:, y1:y2, x1:x2]\n        return img, mask\n\n\nclass Pad:\n    def __init__(self, size: Union[List[int], Tuple[int], int], seg_fill: int = 0) -> None:\n        \"\"\"Pad the image and mask on the right and bottom so that they reach the given size.\n        Args:\n            size: expected output image size (h, w)\n            seg_fill: Pixel fill value for the mask. Default is 0. The image is always padded with 0.\n        \"\"\"\n        self.size = size\n        self.seg_fill = seg_fill\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        padding = (0, 0, self.size[1]-img.shape[2], self.size[0]-img.shape[1])\n        return TF.pad(img, padding), TF.pad(mask, padding, self.seg_fill)\n\n\nclass ResizePad:\n    def __init__(self, size: Union[int, Tuple[int], List[int]], seg_fill: int = 0) -> None:\n        \"\"\"Resize the input image to the given size.\n        Args:\n            size: Desired output size. \n                If size is a sequence, the output size will be matched to this. 
\n                An int is not accepted here: __call__ unpacks size into (height, width).\n        \"\"\"\n        self.size = size\n        self.seg_fill = seg_fill\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        H, W = img.shape[1:]\n        tH, tW = self.size\n\n        # scale the image \n        scale_factor = min(tH/H, tW/W) if W > H else max(tH/H, tW/W)\n        # nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5)\n        nH, nW = round(H*scale_factor), round(W*scale_factor)\n        img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR)\n        mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST)\n\n        # pad the image\n        padding = [0, 0, tW - nW, tH - nH]\n        img = TF.pad(img, padding, fill=0)\n        mask = TF.pad(mask, padding, fill=self.seg_fill)\n        return img, mask \n\n\nclass Resize:\n    def __init__(self, size: Union[int, Tuple[int], List[int]]) -> None:\n        \"\"\"Resize the input image to the given size.\n        Args:\n            size: Desired output size. \n                If size is a sequence, the output size will be matched to this. 
\n                If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio.\n        \"\"\"\n        self.size = size\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        H, W = img.shape[1:]\n\n        # scale the image \n        scale_factor = self.size[0] / min(H, W)\n        nH, nW = round(H*scale_factor), round(W*scale_factor)\n        img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR)\n        mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST)\n\n        # make the image divisible by stride\n        alignH, alignW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32\n        img = TF.resize(img, (alignH, alignW), TF.InterpolationMode.BILINEAR)\n        mask = TF.resize(mask, (alignH, alignW), TF.InterpolationMode.NEAREST)\n        return img, mask \n\n\nclass RandomResizedCrop:\n    def __init__(self, size: Union[int, Tuple[int], List[int]], scale: Tuple[float, float] = (0.5, 2.0), seg_fill: int = 0) -> None:\n        \"\"\"Resize the input image to the given size.\n        \"\"\"\n        self.size = size\n        self.scale = scale\n        self.seg_fill = seg_fill\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        H, W = img.shape[1:]\n        tH, tW = self.size\n\n        # get the scale\n        ratio = random.random() * (self.scale[1] - self.scale[0]) + self.scale[0]\n        # ratio = random.uniform(min(self.scale), max(self.scale))\n        scale = int(tH*ratio), int(tW*4*ratio)\n\n        # scale the image \n        scale_factor = min(max(scale)/max(H, W), min(scale)/min(H, W))\n        nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5)\n        # nH, nW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32\n        img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR)\n        mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST)\n\n        # random crop\n    
    margin_h = max(img.shape[1] - tH, 0)\n        margin_w = max(img.shape[2] - tW, 0)\n        y1 = random.randint(0, margin_h+1)\n        x1 = random.randint(0, margin_w+1)\n        y2 = y1 + tH\n        x2 = x1 + tW\n        img = img[:, y1:y2, x1:x2]\n        mask = mask[:, y1:y2, x1:x2]\n\n        # pad the image\n        if img.shape[1:] != self.size:\n            padding = [0, 0, tW - img.shape[2], tH - img.shape[1]]\n            img = TF.pad(img, padding, fill=0)\n            mask = TF.pad(mask, padding, fill=self.seg_fill)\n        return img, mask \n\n\n\ndef get_train_augmentation(size: Union[int, Tuple[int], List[int]], seg_fill: int = 0):\n    return Compose([\n        # ColorJitter(brightness=0.0, contrast=0.5, saturation=0.5, hue=0.5),\n        # RandomAdjustSharpness(sharpness_factor=0.1, p=0.5),\n        # RandomAutoContrast(p=0.2),\n        RandomHorizontalFlip(p=0.5),\n        # RandomVerticalFlip(p=0.5),\n        # RandomGaussianBlur((3, 3), p=0.5),\n        # RandomGrayscale(p=0.5),\n        # RandomRotation(degrees=10, p=0.3, seg_fill=seg_fill),\n        RandomResizedCrop(size, scale=(0.5, 2.0), seg_fill=seg_fill),\n        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n\ndef get_val_augmentation(size: Union[int, Tuple[int], List[int]]):\n    return Compose([\n        Resize(size),\n        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n\n\nif __name__ == '__main__':\n    h = 230\n    w = 420\n    img = torch.randn(3, h, w)\n    mask = torch.randn(1, h, w)\n    aug = Compose([\n        RandomResizedCrop((512, 512)),\n        # RandomCrop((512, 512), p=1.0),\n        # Pad((512, 512))\n    ])\n    img, mask = aug(img, mask)\n    print(img.shape, mask.shape)"
  },
  {
    "path": "semseg/augmentations_mm.py",
    "content": "import torchvision.transforms.functional as TF \nimport random\nimport math\nimport torch\nfrom torch import Tensor\nfrom typing import Tuple, List, Union, Tuple, Optional\n\n\nclass Compose:\n    def __init__(self, transforms: list) -> None:\n        self.transforms = transforms\n\n    def __call__(self, sample: list) -> list:\n        img, mask = sample['img'], sample['mask']\n        if mask.ndim == 2:\n            assert img.shape[1:] == mask.shape\n        else:\n            assert img.shape[1:] == mask.shape[1:]\n\n        for transform in self.transforms:\n            sample = transform(sample)\n\n        return sample\n\n\nclass Normalize:\n    def __init__(self, mean: list = (0.485, 0.456, 0.406), std: list = (0.229, 0.224, 0.225)):\n        self.mean = mean\n        self.std = std\n\n    def __call__(self, sample: list) -> list:\n        for k, v in sample.items():\n            if k == 'mask':\n                continue\n            elif k == 'img':\n                sample[k] = sample[k].float()\n                sample[k] /= 255\n                sample[k] = TF.normalize(sample[k], self.mean, self.std)\n            else:\n                sample[k] = sample[k].float()\n                sample[k] /= 255\n        \n        return sample\n\n\nclass RandomColorJitter:\n    def __init__(self, p=0.5) -> None:\n        self.p = p\n\n    def __call__(self, sample: list) -> list:\n        if random.random() < self.p:\n            self.brightness = random.uniform(0.5, 1.5)\n            sample['img'] = TF.adjust_brightness(sample['img'], self.brightness)\n            self.contrast = random.uniform(0.5, 1.5)\n            sample['img'] = TF.adjust_contrast(sample['img'], self.contrast)\n            self.saturation = random.uniform(0.5, 1.5)\n            sample['img'] = TF.adjust_saturation(sample['img'], self.saturation)\n        return sample\n\n\nclass AdjustGamma:\n    def __init__(self, gamma: float, gain: float = 1) -> None:\n        \"\"\"\n        
Args:\n            gamma: Non-negative real number. gamma larger than 1 make the shadows darker, while gamma smaller than 1 make dark regions lighter.\n            gain: constant multiplier\n        \"\"\"\n        self.gamma = gamma\n        self.gain = gain\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        return TF.adjust_gamma(img, self.gamma, self.gain), mask\n\n\nclass RandomAdjustSharpness:\n    def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:\n        self.sharpness = sharpness_factor\n        self.p = p\n\n    def __call__(self, sample: list) -> list:\n        if random.random() < self.p:\n            sample['img'] = TF.adjust_sharpness(sample['img'], self.sharpness)\n        return sample\n\n\nclass RandomAutoContrast:\n    def __init__(self, p: float = 0.5) -> None:\n        self.p = p\n\n    def __call__(self, sample: list) -> list:\n        if random.random() < self.p:\n            sample['img'] = TF.autocontrast(sample['img'])\n        return sample\n\n\nclass RandomGaussianBlur:\n    def __init__(self, kernel_size: int = 3, p: float = 0.5) -> None:\n        self.kernel_size = kernel_size\n        self.p = p\n\n    def __call__(self, sample: list) -> list:\n        if random.random() < self.p:\n            sample['img'] = TF.gaussian_blur(sample['img'], self.kernel_size)\n            # img = TF.gaussian_blur(img, self.kernel_size)\n        return sample\n\n\nclass RandomHorizontalFlip:\n    def __init__(self, p: float = 0.5) -> None:\n        self.p = p\n\n    def __call__(self, sample: list) -> list:\n        if random.random() < self.p:\n            for k, v in sample.items():\n                sample[k] = TF.hflip(v)\n            return sample\n        return sample\n\n\nclass RandomVerticalFlip:\n    def __init__(self, p: float = 0.5) -> None:\n        self.p = p\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        if random.random() < self.p:\n            
return TF.vflip(img), TF.vflip(mask)\n        return img, mask\n\n\nclass RandomGrayscale:\n    def __init__(self, p: float = 0.5) -> None:\n        self.p = p\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        if random.random() < self.p:\n            img = TF.rgb_to_grayscale(img, 3)\n        return img, mask\n\n\nclass Equalize:\n    def __call__(self, image, label):\n        return TF.equalize(image), label\n\n\nclass Posterize:\n    def __init__(self, bits=2):\n        self.bits = bits # 0-8\n        \n    def __call__(self, image, label):\n        return TF.posterize(image, self.bits), label\n\n\nclass Affine:\n    def __init__(self, angle=0, translate=[0, 0], scale=1.0, shear=[0, 0], seg_fill=0):\n        self.angle = angle\n        self.translate = translate\n        self.scale = scale\n        self.shear = shear\n        self.seg_fill = seg_fill\n        \n    def __call__(self, img, label):\n        return TF.affine(img, self.angle, self.translate, self.scale, self.shear, TF.InterpolationMode.BILINEAR, 0), TF.affine(label, self.angle, self.translate, self.scale, self.shear, TF.InterpolationMode.NEAREST, self.seg_fill) \n\n\nclass RandomRotation:\n    def __init__(self, degrees: float = 10.0, p: float = 0.2, seg_fill: int = 0, expand: bool = False) -> None:\n        \"\"\"Rotate the image by a random angle between -degrees and degrees with probability p\n\n        Args:\n            p: probability\n            degrees: maximum rotation angle in degrees; the angle is drawn from [-degrees, degrees], counter-clockwise.\n            seg_fill: fill value for the mask area outside the rotated image.\n            expand: Optional expansion flag. \n                    If true, expands the output image to make it large enough to hold the entire rotated image.\n                    If false or omitted, make the output image the same size as the input image. 
\n                    Note that the expand flag assumes rotation around the center and no translation.\n        \"\"\"\n        self.p = p\n        self.angle = degrees\n        self.expand = expand\n        self.seg_fill = seg_fill\n\n    def __call__(self, sample: list) -> list:\n        random_angle = random.random() * 2 * self.angle - self.angle\n        if random.random() < self.p:\n            for k, v in sample.items():\n                if k == 'mask':                \n                    sample[k] = TF.rotate(v, random_angle, TF.InterpolationMode.NEAREST, self.expand, fill=self.seg_fill)\n                else:\n                    sample[k] = TF.rotate(v, random_angle, TF.InterpolationMode.BILINEAR, self.expand, fill=0)\n            # img = TF.rotate(img, random_angle, TF.InterpolationMode.BILINEAR, self.expand, fill=0)\n            # mask = TF.rotate(mask, random_angle, TF.InterpolationMode.NEAREST, self.expand, fill=self.seg_fill)\n        return sample\n    \n\nclass CenterCrop:\n    def __init__(self, size: Union[int, List[int], Tuple[int]]) -> None:\n        \"\"\"Crops the image at the center\n\n        Args:\n            output_size: height and width of the crop box. If int, this size is used for both directions.\n        \"\"\"\n        self.size = (size, size) if isinstance(size, int) else size\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        return TF.center_crop(img, self.size), TF.center_crop(mask, self.size)\n\n\nclass RandomCrop:\n    def __init__(self, size: Union[int, List[int], Tuple[int]], p: float = 0.5) -> None:\n        \"\"\"Randomly Crops the image.\n\n        Args:\n            output_size: height and width of the crop box. 
If int, this size is used for both directions.\n        \"\"\"\n        self.size = (size, size) if isinstance(size, int) else size\n        self.p = p\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        H, W = img.shape[1:]\n        tH, tW = self.size\n\n        if random.random() < self.p:\n            margin_h = max(H - tH, 0)\n            margin_w = max(W - tW, 0)\n            y1 = random.randint(0, margin_h+1)\n            x1 = random.randint(0, margin_w+1)\n            y2 = y1 + tH\n            x2 = x1 + tW\n            img = img[:, y1:y2, x1:x2]\n            mask = mask[:, y1:y2, x1:x2]\n        return img, mask\n\n\nclass Pad:\n    def __init__(self, size: Union[List[int], Tuple[int], int], seg_fill: int = 0) -> None:\n        \"\"\"Pad the image and mask on the right and bottom so that they reach the given size.\n        Args:\n            size: expected output image size (h, w)\n            seg_fill: Pixel fill value for the mask. Default is 0. The image is always padded with 0.\n        \"\"\"\n        self.size = size\n        self.seg_fill = seg_fill\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        padding = (0, 0, self.size[1]-img.shape[2], self.size[0]-img.shape[1])\n        return TF.pad(img, padding), TF.pad(mask, padding, self.seg_fill)\n\n\nclass ResizePad:\n    def __init__(self, size: Union[int, Tuple[int], List[int]], seg_fill: int = 0) -> None:\n        \"\"\"Resize the input image to the given size.\n        Args:\n            size: Desired output size. \n                If size is a sequence, the output size will be matched to this. 
\n                An int is not accepted here: __call__ unpacks size into (height, width).\n        \"\"\"\n        self.size = size\n        self.seg_fill = seg_fill\n\n    def __call__(self, img: Tensor, mask: Tensor) -> Tuple[Tensor, Tensor]:\n        H, W = img.shape[1:]\n        tH, tW = self.size\n\n        # scale the image \n        scale_factor = min(tH/H, tW/W) if W > H else max(tH/H, tW/W)\n        # nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5)\n        nH, nW = round(H*scale_factor), round(W*scale_factor)\n        img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR)\n        mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST)\n\n        # pad the image\n        padding = [0, 0, tW - nW, tH - nH]\n        img = TF.pad(img, padding, fill=0)\n        mask = TF.pad(mask, padding, fill=self.seg_fill)\n        return img, mask \n\n\nclass Resize:\n    def __init__(self, size: Union[int, Tuple[int], List[int]]) -> None:\n        \"\"\"Resize the input image to the given size.\n        Args:\n            size: Desired output size. \n                If size is a sequence, the output size will be matched to this. 
\n                Only size[0] is used: the smaller edge of the image will be matched to it maintaining the aspect ratio, then both sides are rounded up to a multiple of 32.\n        \"\"\"\n        self.size = size\n\n    def __call__(self, sample:list) -> list:\n        H, W = sample['img'].shape[1:]\n\n        # scale the image \n        scale_factor = self.size[0] / min(H, W)\n        nH, nW = round(H*scale_factor), round(W*scale_factor)\n        for k, v in sample.items():\n            if k == 'mask':                \n                sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.NEAREST)\n            else:\n                sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.BILINEAR)\n        # img = TF.resize(img, (nH, nW), TF.InterpolationMode.BILINEAR)\n        # mask = TF.resize(mask, (nH, nW), TF.InterpolationMode.NEAREST)\n\n        # make the image divisible by stride\n        alignH, alignW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32\n        \n        for k, v in sample.items():\n            if k == 'mask':                \n                sample[k] = TF.resize(v, (alignH, alignW), TF.InterpolationMode.NEAREST)\n            else:\n                sample[k] = TF.resize(v, (alignH, alignW), TF.InterpolationMode.BILINEAR)\n        # img = TF.resize(img, (alignH, alignW), TF.InterpolationMode.BILINEAR)\n        # mask = TF.resize(mask, (alignH, alignW), TF.InterpolationMode.NEAREST)\n        return sample\n\n\nclass RandomResizedCrop:\n    def __init__(self, size: Union[int, Tuple[int], List[int]], scale: Tuple[float, float] = (0.5, 2.0), seg_fill: int = 0) -> None:\n        \"\"\"Randomly rescale every entry of the sample, then randomly crop and pad it to the given (h, w) size.\n        \"\"\"\n        self.size = size\n        self.scale = scale\n        self.seg_fill = seg_fill\n\n    def __call__(self, sample: list) -> list:\n        # img, mask = sample['img'], sample['mask']\n        H, W = sample['img'].shape[1:]\n        tH, tW = self.size\n\n        # get the scale\n        ratio = random.random() * (self.scale[1] - 
self.scale[0]) + self.scale[0]\n        # ratio = random.uniform(min(self.scale), max(self.scale))\n        scale = int(tH*ratio), int(tW*4*ratio)\n        # scale the image \n        scale_factor = min(max(scale)/max(H, W), min(scale)/min(H, W))\n        nH, nW = int(H * scale_factor + 0.5), int(W * scale_factor + 0.5)\n        # nH, nW = int(math.ceil(nH / 32)) * 32, int(math.ceil(nW / 32)) * 32\n        for k, v in sample.items():\n            if k == 'mask':                \n                sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.NEAREST)\n            else:\n                sample[k] = TF.resize(v, (nH, nW), TF.InterpolationMode.BILINEAR)\n\n        # random crop\n        margin_h = max(sample['img'].shape[1] - tH, 0)\n        margin_w = max(sample['img'].shape[2] - tW, 0)\n        y1 = random.randint(0, margin_h+1)\n        x1 = random.randint(0, margin_w+1)\n        y2 = y1 + tH\n        x2 = x1 + tW\n        for k, v in sample.items():\n            sample[k] = v[:, y1:y2, x1:x2]\n\n        # pad the image\n        if sample['img'].shape[1:] != self.size:\n            padding = [0, 0, tW - sample['img'].shape[2], tH - sample['img'].shape[1]]\n            for k, v in sample.items():\n                if k == 'mask':                \n                    sample[k] = TF.pad(v, padding, fill=self.seg_fill)\n                else:\n                    sample[k] = TF.pad(v, padding, fill=0)\n\n        return sample\n\n\n\ndef get_train_augmentation(size: Union[int, Tuple[int], List[int]], seg_fill: int = 0):\n    return Compose([\n        RandomColorJitter(p=0.2), # \n        RandomHorizontalFlip(p=0.5), #\n        RandomGaussianBlur((3, 3), p=0.2), #\n        RandomResizedCrop(size, scale=(0.5, 2.0), seg_fill=seg_fill), #\n        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n    ])\n\ndef get_val_augmentation(size: Union[int, Tuple[int], List[int]]):\n    return Compose([\n        Resize(size),\n        Normalize((0.485, 0.456, 0.406), 
(0.229, 0.224, 0.225))\n    ])\n\n\nif __name__ == '__main__':\n    h = 230\n    w = 420\n    sample = {}\n    sample['img'] = torch.randn(3, h, w)\n    sample['depth'] = torch.randn(3, h, w)\n    sample['lidar'] = torch.randn(3, h, w)\n    sample['event'] = torch.randn(3, h, w)\n    sample['mask'] = torch.randn(1, h, w)\n    aug = Compose([\n        RandomHorizontalFlip(p=0.5),\n        RandomResizedCrop((512, 512)),\n        Resize((224, 224)),\n        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n    ])\n    sample = aug(sample)\n    for k, v in sample.items():\n        print(k, v.shape)"
  },
  {
    "path": "semseg/datasets/__init__.py",
    "content": "from .deliver import DELIVER\nfrom .kitti360 import KITTI360\nfrom .nyu import NYU\nfrom .mfnet import MFNet\nfrom .urbanlf import UrbanLF\nfrom .mcubes import MCubeS\n\n__all__ = [\n    'DELIVER',\n    'KITTI360',\n    'NYU',\n    'MFNet',\n    'UrbanLF',\n    'MCubeS'\n]"
  },
  {
    "path": "semseg/datasets/deliver.py",
    "content": "import os\nimport torch \nimport numpy as np\nfrom torch import Tensor\nfrom torch.utils.data import Dataset\nimport torchvision.transforms.functional as TF \nfrom torchvision import io\nfrom pathlib import Path\nfrom typing import Tuple\nimport glob\nimport einops\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data import DistributedSampler, RandomSampler\nfrom semseg.augmentations_mm import get_train_augmentation\n\nclass DELIVER(Dataset):\n    \"\"\"\n    num_classes: 25\n    \"\"\"\n    CLASSES = [\"Building\", \"Fence\", \"Other\", \"Pedestrian\", \"Pole\", \"RoadLine\", \"Road\", \"SideWalk\", \"Vegetation\", \n                \"Cars\", \"Wall\", \"TrafficSign\", \"Sky\", \"Ground\", \"Bridge\", \"RailTrack\", \"GroundRail\", \n                \"TrafficLight\", \"Static\", \"Dynamic\", \"Water\", \"Terrain\", \"TwoWheeler\", \"Bus\", \"Truck\"]\n\n    PALETTE = torch.tensor([[70, 70, 70],\n            [100, 40, 40],\n            [55, 90, 80],\n            [220, 20, 60],\n            [153, 153, 153],\n            [157, 234, 50],\n            [128, 64, 128],\n            [244, 35, 232],\n            [107, 142, 35],\n            [0, 0, 142],\n            [102, 102, 156],\n            [220, 220, 0],\n            [70, 130, 180],\n            [81, 0, 81],\n            [150, 100, 100],\n            [230, 150, 140],\n            [180, 165, 180],\n            [250, 170, 30],\n            [110, 190, 160],\n            [170, 120, 50],\n            [45, 60, 150],\n            [145, 170, 100],\n            [  0,  0, 230], \n            [  0, 60, 100],\n            [  0,  0, 70],\n            ])\n    \n    def __init__(self, root: str = 'data/DELIVER', split: str = 'train', transform = None, modals = ['img'], case = None) -> None:\n        super().__init__()\n        assert split in ['train', 'val', 'test']\n        self.transform = transform\n        self.n_classes = len(self.CLASSES)\n        self.ignore_label = 255\n        self.modals = 
modals\n        self.files = sorted(glob.glob(os.path.join(*[root, 'img', '*', split, '*', '*.png'])))\n        # --- debug\n        # self.files = sorted(glob.glob(os.path.join(*[root, 'img', '*', split, '*', '*.png'])))[:100]\n        # --- split as case\n        if case is not None:\n            assert case in ['cloud', 'fog', 'night', 'rain', 'sun', 'motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres'], \"Case name not available.\"\n            _temp_files = [f for f in self.files if case in f]\n            self.files = _temp_files\n        if not self.files:\n            raise Exception(f\"No images found in {img_path}\")\n        print(f\"Found {len(self.files)} {split} {case} images.\")\n\n    def __len__(self) -> int:\n        return len(self.files)\n    \n    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:\n        rgb = str(self.files[index])\n        x1 = rgb.replace('/img', '/hha').replace('_rgb', '_depth')\n        x2 = rgb.replace('/img', '/lidar').replace('_rgb', '_lidar')\n        x3 = rgb.replace('/img', '/event').replace('_rgb', '_event')\n        lbl_path = rgb.replace('/img', '/semantic').replace('_rgb', '_semantic')\n\n        sample = {}\n        sample['img'] = io.read_image(rgb)[:3, ...]\n        H, W = sample['img'].shape[1:]\n        if 'depth' in self.modals:\n            sample['depth'] = self._open_img(x1)\n        if 'lidar' in self.modals:\n            sample['lidar'] = self._open_img(x2)\n        if 'event' in self.modals:\n            eimg = self._open_img(x3)\n            sample['event'] = TF.resize(eimg, (H, W), TF.InterpolationMode.NEAREST)\n        label = io.read_image(lbl_path)[0,...].unsqueeze(0)\n        label[label==255] = 0\n        label -= 1\n        sample['mask'] = label\n        \n        if self.transform:\n            sample = self.transform(sample)\n        label = sample['mask']\n        del sample['mask']\n        label = self.encode(label.squeeze().numpy()).long()\n        
sample = [sample[k] for k in self.modals]\n        return sample, label\n\n    def _open_img(self, file):\n        img = io.read_image(file)\n        C, H, W = img.shape\n        if C == 4:\n            img = img[:3, ...]\n        if C == 1:\n            img = img.repeat(3, 1, 1)\n        return img\n\n    def encode(self, label: Tensor) -> Tensor:\n        return torch.from_numpy(label)\n\n\nif __name__ == '__main__':\n    cases = ['cloud', 'fog', 'night', 'rain', 'sun', 'motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres']\n    traintransform = get_train_augmentation((1024, 1024), seg_fill=255)\n    for case in cases:\n\n        trainset = DELIVER(transform=traintransform, split='val', case=case)\n        trainloader = DataLoader(trainset, batch_size=2, num_workers=2, drop_last=False, pin_memory=False)\n\n        for i, (sample, lbl) in enumerate(trainloader):\n            print(torch.unique(lbl))"
  },
  {
    "path": "semseg/datasets/kitti360.py",
    "content": "import os\nimport torch \nimport numpy as np\nfrom torch import Tensor\nfrom torch.utils.data import Dataset\nfrom torchvision import io\nfrom pathlib import Path\nfrom typing import Tuple\nimport glob\nimport einops\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data import DistributedSampler, RandomSampler\nfrom semseg.augmentations_mm import get_train_augmentation\n\nclass KITTI360(Dataset):\n    \"\"\"\n    num_classes: 19\n    \"\"\"\n    CLASSES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation', \n                'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle']\n\n    PALETTE = torch.tensor([[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], \n                [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]])\n\n    ID2TRAINID = {0:255, 1:255, 2:255, 3:255, 4:255, 5:255, 6:255, 7:0, 8:1, 9:255, 10:255, 11:2, 12:3, 13:4, 14:255, 15:255, 16:255, 17:5, 18:255, 19:6, \n    20:7, 21:8, 22:9, 23:10, 24:11, 25:12, 26:13, 27:14, 28:15, 29:255, 30:255, 31:16, 32:17, 33:18, 34:2, 35:4, 36:255, 37:5, 38:255, 39:255, 40:255, 41:255, 42:255, 43:255, 44:255, -1:255}\n\n    def __init__(self, root: str = 'data/KITTI360', split: str = 'train', transform = None, modals = ['img', 'depth', 'event', 'lidar'], case = None) -> None:\n        super().__init__()\n        assert split in ['train', 'val']\n        self.root = root\n        self.transform = transform\n        self.n_classes = len(self.CLASSES)\n        self.ignore_label = 255\n        self.modals = modals\n\n        self.label_map = np.arange(256)\n        for id, trainid in self.ID2TRAINID.items():\n            self.label_map[id] = trainid\n\n        self.files = self._get_file_names(split)\n        # --- 
debug\n        # self.files = self._get_file_names(split)[:100]\n    \n        if not self.files:\n            raise Exception(f\"No images found in {img_path}\")\n        print(f\"Found {len(self.files)} {split} images.\")\n\n    def __len__(self) -> int:\n        return len(self.files)\n    \n    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:\n        item_name = str(self.files[index])\n        rgb = os.path.join(self.root, item_name)\n        x1 = os.path.join(self.root, item_name.replace('data_2d_raw', 'data_2d_hha'))\n        x2 = os.path.join(self.root, item_name.replace('data_2d_raw', 'data_2d_lidar'))\n        x2 = x2.replace('.png', '_color.png')\n        x3 = os.path.join(self.root, item_name.replace('data_2d_raw', 'data_2d_event'))\n        x3 = x3.replace('/image_00/data_rect/', '/').replace('.png', '_event_image.png')\n        lbl_path = os.path.join(*[self.root, item_name.replace('data_2d_raw', 'data_2d_semantics/train').replace('data_rect', 'semantic')])\n\n        sample = {}\n        sample['img'] = io.read_image(rgb)[:3, ...]\n        if 'depth' in self.modals:\n            sample['depth'] = self._open_img(x1)\n        if 'lidar' in self.modals:\n            sample['lidar'] = self._open_img(x2)\n        if 'event' in self.modals:\n            sample['event'] = self._open_img(x3)\n        label = io.read_image(lbl_path)[0,...].unsqueeze(0)\n        sample['mask'] = label\n        \n        if self.transform:\n            sample = self.transform(sample)\n        label = sample['mask']\n        del sample['mask']\n        label = self.encode(label.squeeze().numpy()).long()\n        sample = [sample[k] for k in self.modals]\n        return sample, label\n\n    def _open_img(self, file):\n        img = io.read_image(file)\n        C, H, W = img.shape\n        if C == 4:\n            img = img[:3, ...]\n        if C == 1:\n            img = img.repeat(3, 1, 1)\n        return img\n\n    def encode(self, label: Tensor) -> Tensor:\n        
label = self.label_map[label]\n        return torch.from_numpy(label)\n\n    def _get_file_names(self, split_name):\n        assert split_name in ['train', 'val']\n        source = os.path.join(self.root, '{}.txt'.format(split_name))\n        file_names = []\n        with open(source) as f:\n            files = f.readlines()\n        for item in files:\n            file_name = item.strip()\n            if ' ' in file_name:\n                # --- KITTI-360\n                file_name = file_name.split(' ')[0]\n            file_names.append(file_name)\n        return file_names\n\n\nif __name__ == '__main__':\n    traintransform = get_train_augmentation((376, 1408), seg_fill=255)\n\n    trainset = KITTI360(transform=traintransform)\n    trainloader = DataLoader(trainset, batch_size=2, num_workers=2, drop_last=True, pin_memory=False)\n\n    for i, (sample, lbl) in enumerate(trainloader):\n        print(torch.unique(lbl))"
  },
  {
    "path": "semseg/datasets/mcubes.py",
    "content": "import os\nimport torch \nimport numpy as np\nfrom torch import Tensor\nfrom torch.utils.data import Dataset\nfrom torchvision import io\nfrom torchvision import transforms\nfrom pathlib import Path\nfrom typing import Tuple\nimport glob\nimport einops\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data import DistributedSampler, RandomSampler\nfrom semseg.augmentations_mm import get_train_augmentation\nimport cv2\nimport random\nfrom PIL import Image, ImageOps, ImageFilter\n\nclass MCubeS(Dataset):\n    \"\"\"\n    num_classes: 20\n    \"\"\"\n    CLASSES = ['asphalt','concrete','metal','road_marking','fabric','glass','plaster','plastic','rubber','sand',\n    'gravel','ceramic','cobblestone','brick','grass','wood','leaf','water','human','sky',]\n\n    PALETTE = torch.tensor([[ 44, 160,  44],\n                [ 31, 119, 180],\n                [255, 127,  14],\n                [214,  39,  40],\n                [140,  86,  75],\n                [127, 127, 127],\n                [188, 189,  34],\n                [255, 152, 150],\n                [ 23, 190, 207],\n                [174, 199, 232],\n                [196, 156, 148],\n                [197, 176, 213],\n                [247, 182, 210],\n                [199, 199, 199],\n                [219, 219, 141],\n                [158, 218, 229],\n                [ 57,  59, 121],\n                [107, 110, 207],\n                [156, 158, 222],\n                [ 99, 121,  57]])\n\n    def __init__(self, root: str = 'data/MCubeS/multimodal_dataset', split: str = 'train', transform = None, modals = ['image', 'aolp', 'dolp', 'nir'], case = None) -> None:\n        super().__init__()\n        assert split in ['train', 'val']\n        self.split = split\n        self.root = root\n        self.transform = transform\n        self.n_classes = len(self.CLASSES)\n        self.ignore_label = 255\n        self.modals = modals\n        self._left_offset = 192\n        \t\n        self.img_h = 1024\n    
    self.img_w = 1224\n        max_dim = max(self.img_h, self.img_w)\n        u_vec = (np.arange(self.img_w)-self.img_w/2)/max_dim*2\n        v_vec = (np.arange(self.img_h)-self.img_h/2)/max_dim*2\n        self.u_map, self.v_map = np.meshgrid(u_vec, v_vec)\n        self.u_map = self.u_map[:,:self._left_offset]\n\n        self.base_size = 512\n        self.crop_size = 512\n        self.files = self._get_file_names(split)\n    \n        if not self.files:\n            raise Exception(f\"No images found in {img_path}\")\n        print(f\"Found {len(self.files)} {split} images.\")\n\n    def __len__(self) -> int:\n        return len(self.files)\n    \n    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:\n        item_name = str(self.files[index])\n        rgb = os.path.join(*[self.root, 'polL_color', item_name+'.png'])\n        x1 = os.path.join(*[self.root, 'polL_aolp_sin', item_name+'.npy'])\n        x1_1 = os.path.join(*[self.root, 'polL_aolp_cos', item_name+'.npy'])\n        x2 = os.path.join(*[self.root, 'polL_dolp', item_name+'.npy'])\n        x3 = os.path.join(*[self.root, 'NIR_warped', item_name+'.png'])\n        lbl_path = os.path.join(*[self.root, 'GT', item_name+'.png'])\n        nir_mask = os.path.join(*[self.root, 'NIR_warped_mask', item_name+'.png'])\n        _mask = os.path.join(*[self.root, 'SS', item_name+'.png'])\n\n        _img = cv2.imread(rgb,-1)[:,:,::-1]\n        _img = _img.astype(np.float32)/65535 if _img.dtype==np.uint16 else _img.astype(np.float32)/255\n        _target = cv2.imread(lbl_path,-1)\n        _mask = cv2.imread(_mask,-1)\n        _aolp_sin = np.load(x1)\n        _aolp_cos = np.load(x1_1)\n        _aolp = np.stack([_aolp_sin, _aolp_cos, _aolp_sin], axis=2) # H x W x 3\n        dolp = np.load(x2)\n        _dolp = np.stack([dolp, dolp, dolp], axis=2) # H x W x 3\n        nir  = cv2.imread(x3,-1)\n        nir = nir.astype(np.float32)/65535 if nir.dtype==np.uint16 else nir.astype(np.float32)/255\n        _nir = np.stack([nir, 
nir, nir], axis=2) # H x W x 3\n\n        _nir_mask = cv2.imread(nir_mask,0)\n\n        _img, _target, _aolp, _dolp, _nir, _nir_mask, _mask = _img[:,self._left_offset:], _target[:,self._left_offset:], \\\n               _aolp[:,self._left_offset:], _dolp[:,self._left_offset:], \\\n               _nir[:,self._left_offset:], _nir_mask[:,self._left_offset:], _mask[:,self._left_offset:]\n        sample = {'image': _img, 'label': _target, 'aolp': _aolp, 'dolp': _dolp, 'nir': _nir, 'nir_mask': _nir_mask, 'u_map': self.u_map, 'v_map': self.v_map, 'mask':_mask}\n\n        if self.split == \"train\":\n            sample = self.transform_tr(sample)\n        elif self.split == 'val':\n            sample = self.transform_val(sample)\n        elif self.split == 'test':\n            sample = self.transform_val(sample)\n        else:\n            raise NotImplementedError()\n        label = sample['label'].long()\n        sample = [sample[k] for k in self.modals]\n        return sample, label\n\n    def transform_tr(self, sample):\n        composed_transforms = transforms.Compose([\n            RandomHorizontalFlip(),\n            RandomScaleCrop(base_size=self.base_size, crop_size=self.crop_size, fill=255),\n            RandomGaussianBlur(),\n            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n            ToTensor()])\n\n        return composed_transforms(sample)\n\n    def transform_val(self, sample):\n        composed_transforms = transforms.Compose([\n            FixScaleCrop(crop_size=1024),\n            Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n            ToTensor()])\n\n        return composed_transforms(sample)\n\n    def _get_file_names(self, split_name):\n        assert split_name in ['train', 'val']\n        source = os.path.join(self.root, 'list_folder/test.txt') if split_name == 'val' else os.path.join(self.root, 'list_folder/train.txt')\n        file_names = []\n        with open(source) as f:\n            files = 
f.readlines()\n        for item in files:\n            file_name = item.strip()\n            if ' ' in file_name:\n                # --- KITTI-360\n                file_name = file_name.split(' ')[0]\n            file_names.append(file_name)\n        return file_names\n\nclass Normalize(object):\n    \"\"\"Normalize a tensor image with mean and standard deviation.\n    Args:\n        mean (tuple): means for each channel.\n        std (tuple): standard deviations for each channel.\n    \"\"\"\n    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):\n        self.mean = mean\n        self.std = std\n\n    def __call__(self, sample):\n        img = sample['image']\n        mask = sample['label']\n        img = np.array(img).astype(np.float32)\n        mask = np.array(mask).astype(np.float32)\n        img -= self.mean\n        img /= self.std\n\n        nir = sample['nir']\n        nir = np.array(nir).astype(np.float32)\n        # nir /= 255\n\n        return {'image': img,\n                'label': mask,\n                'aolp' : sample['aolp'], \n                'dolp' : sample['dolp'], \n                'nir'  : nir, \n                'nir_mask': sample['nir_mask'],\n                'u_map': sample['u_map'],\n                'v_map': sample['v_map'],\n                'mask':sample['mask']}\n\n\nclass ToTensor(object):\n    \"\"\"Convert ndarrays in sample to Tensors.\"\"\"\n\n    def __call__(self, sample):\n        # swap color axis because\n        # numpy image: H x W x C\n        # torch image: C X H X W\n        img = sample['image']\n        mask = sample['label']\n        aolp = sample['aolp']\n        dolp = sample['dolp']\n        nir  = sample['nir']\n        nir_mask  = sample['nir_mask']\n        SS=sample['mask']\n\n        img = np.array(img).astype(np.float32).transpose((2, 0, 1))\n        mask = np.array(mask).astype(np.float32)\n        aolp = np.array(aolp).astype(np.float32).transpose((2, 0, 1))\n        dolp = 
np.array(dolp).astype(np.float32).transpose((2, 0, 1))\n        SS = np.array(SS).astype(np.float32)\n        nir = np.array(nir).astype(np.float32).transpose((2, 0, 1))\n        nir_mask = np.array(nir_mask).astype(np.float32)\n        \n        img = torch.from_numpy(img).float()\n        mask = torch.from_numpy(mask).float()\n        aolp = torch.from_numpy(aolp).float()\n        dolp = torch.from_numpy(dolp).float()\n        SS = torch.from_numpy(SS).float()\n        nir = torch.from_numpy(nir).float()\n        nir_mask = torch.from_numpy(nir_mask).float()\n\n        u_map = sample['u_map']\n        v_map = sample['v_map']\n        u_map = torch.from_numpy(u_map.astype(np.float32)).float()\n        v_map = torch.from_numpy(v_map.astype(np.float32)).float()\n\n        return {'image': img,\n                'label': mask,\n                'aolp' : aolp,\n                'dolp' : dolp,\n                'nir'  : nir,\n                'nir_mask'  : nir_mask,\n                'u_map': u_map,\n                'v_map': v_map,\n                'mask':SS}\n\n\nclass RandomHorizontalFlip(object):\n    def __call__(self, sample):\n        img = sample['image']\n        mask = sample['label']\n        aolp = sample['aolp']\n        dolp = sample['dolp']\n        nir  = sample['nir']\n        nir_mask  = sample['nir_mask']\n        u_map = sample['u_map']\n        v_map = sample['v_map']\n        SS=sample['mask']\n        if random.random() < 0.5:\n            # img = img.transpose(Image.FLIP_LEFT_RIGHT)\n            # mask = mask.transpose(Image.FLIP_LEFT_RIGHT)\n            # nir = nir.transpose(Image.FLIP_LEFT_RIGHT)\n\n            img = img[:,::-1]\n            mask = mask[:,::-1]\n            nir = nir[:,::-1]\n            nir_mask = nir_mask[:,::-1]\n            aolp  = aolp[:,::-1]\n            dolp  = dolp[:,::-1]\n            SS  = SS[:,::-1]\n            u_map = u_map[:,::-1]\n\n        return {'image': img,\n                'label': mask,\n                'aolp' 
: aolp,\n                'dolp' : dolp,\n                'nir'  : nir,\n                'nir_mask'  : nir_mask,\n                'u_map': u_map,\n                'v_map': v_map,\n                'mask':SS}\n\nclass RandomGaussianBlur(object):\n    def __call__(self, sample):\n        img = sample['image']\n        mask = sample['label']\n        nir  = sample['nir']\n        if random.random() < 0.5:\n            radius = random.random()\n            # img = img.filter(ImageFilter.GaussianBlur(radius=radius))\n            # nir = nir.filter(ImageFilter.GaussianBlur(radius=radius))\n            img = cv2.GaussianBlur(img, (0,0), radius)\n            nir = cv2.GaussianBlur(nir, (0,0), radius)\n\n        return {'image': img,\n                'label': mask,\n                'aolp' : sample['aolp'], \n                'dolp' : sample['dolp'], \n                'nir'  : nir, \n                'nir_mask': sample['nir_mask'],\n                'u_map': sample['u_map'],\n                'v_map': sample['v_map'],\n                'mask':sample['mask']}\n\nclass RandomScaleCrop(object):\n    def __init__(self, base_size, crop_size, fill=255):\n        self.base_size = base_size\n        self.crop_size = crop_size\n        self.fill = fill\n\n    def __call__(self, sample):\n        img = sample['image']\n        mask = sample['label']\n        aolp = sample['aolp']\n        dolp = sample['dolp']\n        nir = sample['nir']\n        nir_mask = sample['nir_mask']\n        SS=sample['mask']\n        # random scale (short edge)\n        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))\n        # w, h = img.size\n        h, w = img.shape[:2]\n        if h > w:\n            ow = short_size\n            oh = int(1.0 * h * ow / w)\n        else:\n            oh = short_size\n            ow = int(1.0 * w * oh / h)\n\n        # pad crop\n        if short_size < self.crop_size:\n            padh = self.crop_size - oh if oh < self.crop_size else 0\n       
     padw = self.crop_size - ow if ow < self.crop_size else 0\n            \n        # random crop crop_size\n        # w, h = img.size\n        h, w = img.shape[:2]\n\n        # x1 = random.randint(0, w - self.crop_size)\n        # y1 = random.randint(0, h - self.crop_size)\n        x1 = random.randint(0, max(0, ow - self.crop_size))\n        y1 = random.randint(0, max(0, oh - self.crop_size))\n\n        u_map = sample['u_map']\n        v_map = sample['v_map']\n        u_map    = cv2.resize(u_map,(ow,oh))\n        v_map    = cv2.resize(v_map,(ow,oh))\n        aolp     = cv2.resize(aolp ,(ow,oh))\n        dolp     = cv2.resize(dolp ,(ow,oh))\n        SS     = cv2.resize(SS ,(ow,oh))\n        img      = cv2.resize(img  ,(ow,oh), interpolation=cv2.INTER_LINEAR)\n        mask     = cv2.resize(mask ,(ow,oh), interpolation=cv2.INTER_NEAREST)\n        nir      = cv2.resize(nir  ,(ow,oh), interpolation=cv2.INTER_LINEAR)\n        nir_mask = cv2.resize(nir_mask  ,(ow,oh), interpolation=cv2.INTER_NEAREST)\n        if short_size < self.crop_size:\n            u_map_ = np.zeros((oh+padh,ow+padw))\n            u_map_[:oh,:ow] = u_map\n            u_map = u_map_\n            v_map_ = np.zeros((oh+padh,ow+padw))\n            v_map_[:oh,:ow] = v_map\n            v_map = v_map_\n            aolp_ = np.zeros((oh+padh,ow+padw,3))\n            aolp_[:oh,:ow] = aolp\n            aolp = aolp_\n            dolp_ = np.zeros((oh+padh,ow+padw,3))\n            dolp_[:oh,:ow] = dolp\n            dolp = dolp_\n\n            img_ = np.zeros((oh+padh,ow+padw,3))\n            img_[:oh,:ow] = img\n            img = img_\n            SS_ = np.zeros((oh+padh,ow+padw))\n            SS_[:oh,:ow] = SS\n            SS = SS_\n            mask_ = np.full((oh+padh,ow+padw),self.fill)\n            mask_[:oh,:ow] = mask\n            mask = mask_\n            nir_ = np.zeros((oh+padh,ow+padw,3))\n            nir_[:oh,:ow] = nir\n            nir = nir_\n            nir_mask_ = np.zeros((oh+padh,ow+padw))\n     
       nir_mask_[:oh,:ow] = nir_mask\n            nir_mask = nir_mask_\n\n        u_map = u_map[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        v_map = v_map[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        aolp  =  aolp[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        dolp  =  dolp[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        img   =   img[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        mask  =  mask[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        nir   =   nir[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        SS   =   SS[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        nir_mask = nir_mask[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        return {'image': img,\n                'label': mask,\n                'aolp' : aolp,\n                'dolp' : dolp,\n                'nir'  : nir,\n                'nir_mask'  : nir_mask,\n                'u_map': u_map,\n                'v_map': v_map,\n                'mask':SS}\n\nclass FixScaleCrop(object):\n    def __init__(self, crop_size):\n        self.crop_size = crop_size\n\n    def __call__(self, sample):\n        img = sample['image']\n        mask = sample['label']\n        aolp = sample['aolp']\n        dolp = sample['dolp']\n        nir = sample['nir']\n        nir_mask = sample['nir_mask']\n        SS = sample['mask']\n\n        # w, h = img.size\n        h, w = img.shape[:2]\n\n        if w > h:\n            oh = self.crop_size\n            ow = int(1.0 * w * oh / h)\n        else:\n            ow = self.crop_size\n            oh = int(1.0 * h * ow / w)\n        # img = img.resize((ow, oh), Image.BILINEAR)\n        # mask = mask.resize((ow, oh), Image.NEAREST)\n        # nir = nir.resize((ow, oh), Image.BILINEAR)\n\n        # center crop\n        # w, h = img.size\n        # h, w = img.shape[:2]\n        x1 = int(round((ow - self.crop_size) / 2.))\n        y1 = int(round((oh - self.crop_size) / 2.))\n        # img = img.crop((x1, y1, x1 + self.crop_size, y1 + 
self.crop_size))\n        # mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))\n        # nir = nir.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))\n\n        u_map = sample['u_map']\n        v_map = sample['v_map']\n        u_map = cv2.resize(u_map,(ow,oh))\n        v_map = cv2.resize(v_map,(ow,oh))\n        aolp  = cv2.resize(aolp ,(ow,oh))\n        dolp  = cv2.resize(dolp ,(ow,oh))\n        SS  = cv2.resize(SS ,(ow,oh))\n        img   = cv2.resize(img  ,(ow,oh), interpolation=cv2.INTER_LINEAR)\n        mask  = cv2.resize(mask ,(ow,oh), interpolation=cv2.INTER_NEAREST)\n        nir   = cv2.resize(nir  ,(ow,oh), interpolation=cv2.INTER_LINEAR)\n        nir_mask = cv2.resize(nir_mask,(ow,oh), interpolation=cv2.INTER_NEAREST)\n        u_map = u_map[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        v_map = v_map[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        aolp  =  aolp[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        dolp  =  dolp[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        img   =   img[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        mask  =  mask[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        SS  =  SS[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        nir   =   nir[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        nir_mask = nir_mask[y1:y1+self.crop_size, x1:x1+self.crop_size]\n        return {'image': img,\n                'label': mask,\n                'aolp' : aolp,\n                'dolp' : dolp,\n                'nir'  : nir,\n                'nir_mask'  : nir_mask,\n                'u_map': u_map,\n                'v_map': v_map,\n                'mask':SS}\n\n\nif __name__ == '__main__':\n    traintransform = get_train_augmentation((1024, 1224), seg_fill=255)\n\n    trainset = MCubeS(transform=traintransform, split='val')\n    trainloader = DataLoader(trainset, batch_size=1, num_workers=0, drop_last=False, pin_memory=False)\n\n    for i, (sample, lbl) in enumerate(trainloader):\n        
print(torch.unique(lbl))"
  },
  {
    "path": "semseg/datasets/mfnet.py",
    "content": "import os\nimport torch \nimport numpy as np\nfrom torch import Tensor\nfrom torch.utils.data import Dataset\nfrom torchvision import io\nfrom pathlib import Path\nfrom typing import Tuple\nimport glob\nimport einops\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data import DistributedSampler, RandomSampler\nfrom semseg.augmentations_mm import get_train_augmentation\n\nclass MFNet(Dataset):\n    \"\"\"\n    num_classes: 9\n    \"\"\"\n    CLASSES = ['unlabeled', 'car', 'person', 'bike', 'curve', 'car_stop', 'guardrail', 'color_cone', 'bump']\n    PALETTE = torch.tensor([[64,0,128],[64,64,0],[0,128,192],[0,0,192],[128,128,0],[64,64,128],[192,128,128],[192,64,0]])\n\n    def __init__(self, root: str = 'data/MFNet', split: str = 'train', transform = None, modals = ['img', 'thermal'], case = None) -> None:\n        super().__init__()\n        assert split in ['train', 'val']\n        self.root = root\n        self.transform = transform\n        self.n_classes = len(self.CLASSES)\n        self.ignore_label = 255\n        self.modals = modals\n        self.files = self._get_file_names(split)\n    \n        if not self.files:\n            raise Exception(f\"No images found in {img_path}\")\n        print(f\"Found {len(self.files)} {split} images.\")\n\n    def __len__(self) -> int:\n        return len(self.files)\n    \n    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:\n        item_name = str(self.files[index])\n        rgb = os.path.join(*[self.root, 'rgb', item_name+'.jpg'])\n        x1 = os.path.join(*[self.root, 'ther', item_name+'.jpg'])\n        lbl_path = os.path.join(*[self.root, 'labels', item_name+'.png'])\n        sample = {}\n        sample['img'] = io.read_image(rgb)[:3, ...]\n        if 'thermal' in self.modals:\n            sample['thermal'] = self._open_img(x1)\n        label = io.read_image(lbl_path)[0,...].unsqueeze(0)\n        sample['mask'] = label\n        \n        if self.transform:\n            sample = 
self.transform(sample)\n        label = sample['mask']\n        del sample['mask']\n        label = self.encode(label.squeeze().numpy()).long()\n        sample = [sample[k] for k in self.modals]\n        return sample, label\n\n    def _open_img(self, file):\n        img = io.read_image(file)\n        C, H, W = img.shape\n        if C == 4:\n            img = img[:3, ...]\n        if C == 1:\n            img = img.repeat(3, 1, 1)\n        return img\n\n    def encode(self, label: Tensor) -> Tensor:\n        return torch.from_numpy(label)\n\n    def _get_file_names(self, split_name):\n        assert split_name in ['train', 'val']\n        source = os.path.join(self.root, 'test.txt') if split_name == 'val' else os.path.join(self.root, 'train.txt')\n        file_names = []\n        with open(source) as f:\n            files = f.readlines()\n        for item in files:\n            file_name = item.strip()\n            if ' ' in file_name:\n                file_name = file_name.split(' ')[0]\n            file_names.append(file_name)\n        return file_names\n\n\nif __name__ == '__main__':\n    traintransform = get_train_augmentation((480, 640), seg_fill=255)\n\n    trainset = MFNet(transform=traintransform)\n    trainloader = DataLoader(trainset, batch_size=2, num_workers=2, drop_last=True, pin_memory=False)\n\n    for i, (sample, lbl) in enumerate(trainloader):\n        print(torch.unique(lbl))"
  },
  {
    "path": "semseg/datasets/nyu.py",
    "content": "import os\nimport torch \nimport numpy as np\nfrom torch import Tensor\nfrom torch.utils.data import Dataset\nimport torchvision.transforms.functional as TF \nfrom torchvision import io\nfrom pathlib import Path\nfrom typing import Tuple\nimport glob\nimport einops\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data import DistributedSampler, RandomSampler\nfrom semseg.augmentations_mm import get_train_augmentation\n\nclass NYU(Dataset):\n    \"\"\"\n    num_classes: 40\n    \"\"\"\n    CLASSES = ['wall','floor','cabinet','bed','chair','sofa','table','door','window','bookshelf','picture','counter','blinds',\n    'desk','shelves','curtain','dresser','pillow','mirror','floor mat','clothes','ceiling','books','refridgerator',\n    'television','paper','towel','shower curtain','box','whiteboard','person','night stand','toilet',\n    'sink','lamp','bathtub','bag','otherstructure','otherfurniture','otherprop']\n\n    PALETTE = None\n\n    def __init__(self, root: str = 'data/NYUDepthv2', split: str = 'train', transform = None, modals = ['img', 'depth'], case = None) -> None:\n        super().__init__()\n        assert split in ['train', 'val']\n        self.root = root\n        self.transform = transform\n        self.n_classes = len(self.CLASSES)\n        self.ignore_label = 255\n        self.modals = modals\n        self.files = self._get_file_names(split)\n        if not self.files:\n            raise Exception(f\"No images found in {img_path}\")\n        print(f\"Found {len(self.files)} {split} images.\")\n\n    def __len__(self) -> int:\n        return len(self.files)\n    \n    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:\n        item_name = str(self.files[index])\n        rgb = os.path.join(*[self.root, 'RGB', item_name+'.jpg'])\n        x1 = os.path.join(*[self.root, 'HHA', item_name+'.jpg'])\n        lbl_path = os.path.join(*[self.root, 'Label', item_name+'.png'])\n\n        sample = {}\n        sample['img'] = 
io.read_image(rgb)[:3, ...]\n        if 'depth' in self.modals:\n            sample['depth'] = self._open_img(x1)\n        if 'lidar' in self.modals:\n            raise NotImplementedError()\n        if 'event' in self.modals:\n            raise NotImplementedError()\n        label = io.read_image(lbl_path)[0,...].unsqueeze(0)\n        label[label==255] = 0\n        label -= 1\n        sample['mask'] = label\n        \n        if self.transform:\n            sample = self.transform(sample)\n        label = sample['mask']\n        del sample['mask']\n        label = self.encode(label.squeeze().numpy()).long()\n        sample = [sample[k] for k in self.modals]\n        return sample, label\n\n    def _open_img(self, file):\n        img = io.read_image(file)\n        C, H, W = img.shape\n        if C == 4:\n            img = img[:3, ...]\n        if C == 1:\n            img = img.repeat(3, 1, 1)\n        return img\n\n    def encode(self, label: Tensor) -> Tensor:\n        return torch.from_numpy(label)\n\n    def _get_file_names(self, split_name):\n        assert split_name in ['train', 'val']\n        source = os.path.join(self.root, 'test.txt') if split_name == 'val' else os.path.join(self.root, 'train.txt')\n        file_names = []\n        with open(source) as f:\n            files = f.readlines()\n        for item in files:\n            file_name = item.strip()\n            if ' ' in file_name:\n                file_name = file_name.split(' ')[0]\n            file_names.append(file_name)\n        return file_names\n\n\nif __name__ == '__main__':\n    traintransform = get_train_augmentation((480, 640), seg_fill=255)\n\n    trainset = NYU(transform=traintransform, split='val')\n    trainloader = DataLoader(trainset, batch_size=2, num_workers=2, drop_last=True, pin_memory=False)\n\n    for i, (sample, lbl) in enumerate(trainloader):\n        print(torch.unique(lbl))\n"
  },
  {
    "path": "semseg/datasets/unzip.py",
    "content": "import zipfile\n\nwith zipfile.ZipFile(\"data/MCubeS/multimodal_dataset.zip\", \"r\") as zip_ref:\n    for name in zip_ref.namelist():\n        try:\n            zip_ref.extract(name, \"multimodal_dataset_extracted/\")\n        except zipfile.BadZipFile as e:\n            print(e)"
  },
  {
    "path": "semseg/datasets/urbanlf.py",
    "content": "import os\nimport torch \nimport numpy as np\nfrom torch import Tensor\nfrom torch.utils.data import Dataset\nimport torchvision.transforms.functional as TF \nfrom torchvision import io\nfrom pathlib import Path\nfrom typing import Tuple\nimport glob\nimport einops\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data import DistributedSampler, RandomSampler\nfrom semseg.augmentations_mm import get_train_augmentation\n\nclass UrbanLF(Dataset):\n    \"\"\"\n    num_classes: 14\n    \"\"\"\n    CLASSES = ['bike','building','fence','others','person','pole','road','sidewalk','traffic sign','vegetation','vehicle','bridge','rider','sky']\n\n    PALETTE = [[168,198,168],[198,0,0],[202,154,198],[0,0,0],[100,198,198],[198,100,0],[52,42,198],[154,52,192],[198,0,168],[0,198,0],[198,186,90],[108,107,161],[156,200,26],[158,179,202]]\n\n    def __init__(self, root: str = 'data/UrBanLF/Syn', split: str = 'train', transform = None, modals = ['img', '5_1', '5_2', '5_3', '5_4', '5_6', '5_7', '5_8', '5_9'], case = None) -> None:\n        super().__init__()\n        assert split in ['train', 'val']\n        self.root = root\n        self.transform = transform\n        self.n_classes = len(self.CLASSES)\n        self.ignore_label = 255\n        self.modals = modals\n        self.files = sorted(glob.glob(os.path.join(*[root, split, '*', '5_5.png'])))\n    \n        if not self.files:\n            raise Exception(f\"No images found in {img_path}\")\n        print(f\"Found {len(self.files)} {split} images.\")\n\n    def __len__(self) -> int:\n        return len(self.files)\n    \n    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:\n        item_name = str(self.files[index])\n        rgb = item_name\n        rgb_dir_name = os.path.dirname(rgb)\n        lf_names = []\n        lf_paths = []\n        for i in range(1, 10):\n            for j in range(1, 10):\n                lf_name = '{}_{}'.format(i, j)\n                if lf_name != '5_5':\n            
        if lf_name in self.modals:\n                        lf_names.append(lf_name)\n                        lf_paths.append(os.path.join(rgb_dir_name, lf_name+'.png'))\n        if 'real' in self.root:\n            lbl_path = item_name.replace('5_5', 'label')\n        elif 'Syn' in self.root:\n            lbl_path = item_name.replace('5_5.png', '5_5_label.npy')\n        else:\n            raise NotImplemented\n        sample = {}\n        sample['img'] = io.read_image(rgb)[:3, ...]\n        if len(self.modals) > 1:\n            for i, lf_name in enumerate(lf_names):\n                assert lf_name in lf_paths[i], \"Not matched.\"\n                sample[lf_name] = self._open_img(lf_paths[i])\n        \n        if 'real' in self.root:\n            label = io.read_image(lbl_path)\n            label = self.encode(label.numpy())\n        elif 'Syn' in self.root:\n            label = np.load(lbl_path)\n            label[label==255] = 0\n            label -= 1\n            label = torch.tensor(label[None,...])\n        else:\n            raise NotImplemented\n        sample['mask'] = label\n        \n        if self.transform:\n            sample = self.transform(sample)\n        label = sample['mask']\n        del sample['mask']\n        label = label.long().squeeze(0)\n        sample_list = [sample['img']]\n        sample_list += [sample[k] for k in lf_names]\n        return sample_list, label\n\n    def _open_img(self, file):\n        img = io.read_image(file)\n        C, H, W = img.shape\n        if C == 4:\n            img = img[:3, ...]\n        if C == 1:\n            img = img.repeat(3, 1, 1)\n        return img\n\n    def encode(self, label: Tensor) -> Tensor:\n        label = label.transpose(1,2,0) # C, H, W -> H, W, C\n        label_mask = np.zeros((label.shape[0], label.shape[1]), dtype=np.int16)\n        for ii, lb in enumerate(self.PALETTE):\n            label_mask[np.where(np.all(label == lb, axis=-1))[:2]] = ii\n        label_mask = 
label_mask[None,...].astype(int)\n        return torch.from_numpy(label_mask)\n\n\nif __name__ == '__main__':\n    traintransform = get_train_augmentation((432, 623), seg_fill=255)\n\n    trainset = UrbanLF(transform=traintransform, modals=['img', '1_2'])\n    trainloader = DataLoader(trainset, batch_size=2, num_workers=2, drop_last=True, pin_memory=False)\n\n    for i, (sample, lbl) in enumerate(trainloader):\n        print(torch.unique(lbl))\n"
  },
  {
    "path": "semseg/losses.py",
    "content": "import torch\nfrom torch import nn, Tensor\nfrom torch.nn import functional as F\n\n\nclass CrossEntropy(nn.Module):\n    def __init__(self, ignore_label: int = 255, weight: Tensor = None, aux_weights: list = [1, 0.4, 0.4]) -> None:\n        super().__init__()\n        self.aux_weights = aux_weights\n        self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label)\n\n    def _forward(self, preds: Tensor, labels: Tensor) -> Tensor:\n        # preds in shape [B, C, H, W] and labels in shape [B, H, W]\n        return self.criterion(preds, labels)\n\n    def forward(self, preds, labels: Tensor) -> Tensor:\n        if isinstance(preds, tuple):\n            return sum([w * self._forward(pred, labels) for (pred, w) in zip(preds, self.aux_weights)])\n        return self._forward(preds, labels)\n\n\nclass OhemCrossEntropy(nn.Module):\n    def __init__(self, ignore_label: int = 255, weight: Tensor = None, thresh: float = 0.7, aux_weights: list = [1, 1]) -> None:\n        super().__init__()\n        self.ignore_label = ignore_label\n        self.aux_weights = aux_weights\n        self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float))\n        self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label, reduction='none')\n\n    def _forward(self, preds: Tensor, labels: Tensor) -> Tensor:\n        # preds in shape [B, C, H, W] and labels in shape [B, H, W]\n        n_min = labels[labels != self.ignore_label].numel() // 16\n        loss = self.criterion(preds, labels).view(-1)\n        loss_hard = loss[loss > self.thresh]\n\n        if loss_hard.numel() < n_min:\n            loss_hard, _ = loss.topk(n_min)\n\n        return torch.mean(loss_hard)\n\n    def forward(self, preds, labels: Tensor) -> Tensor:\n        if isinstance(preds, tuple):\n            return sum([w * self._forward(pred, labels) for (pred, w) in zip(preds, self.aux_weights)])\n        return self._forward(preds, labels)\n\n\nclass 
Dice(nn.Module):\n    def __init__(self, delta: float = 0.5, aux_weights: list = [1, 0.4, 0.4]):\n        \"\"\"\n        delta: Controls weight given to FP and FN. This equals to dice score when delta=0.5\n        \"\"\"\n        super().__init__()\n        self.delta = delta\n        self.aux_weights = aux_weights\n\n    def _forward(self, preds: Tensor, labels: Tensor) -> Tensor:\n        # preds in shape [B, C, H, W] and labels in shape [B, H, W]\n        num_classes = preds.shape[1]\n        labels = F.one_hot(labels, num_classes).permute(0, 3, 1, 2)\n        tp = torch.sum(labels*preds, dim=(2, 3))\n        fn = torch.sum(labels*(1-preds), dim=(2, 3))\n        fp = torch.sum((1-labels)*preds, dim=(2, 3))\n\n        dice_score = (tp + 1e-6) / (tp + self.delta * fn + (1 - self.delta) * fp + 1e-6)\n        dice_score = torch.sum(1 - dice_score, dim=-1)\n\n        dice_score = dice_score / num_classes\n        return dice_score.mean()\n\n    def forward(self, preds, targets: Tensor) -> Tensor:\n        if isinstance(preds, tuple):\n            return sum([w * self._forward(pred, targets) for (pred, w) in zip(preds, self.aux_weights)])\n        return self._forward(preds, targets)\n\n\n__all__ = ['CrossEntropy', 'OhemCrossEntropy', 'Dice']\n\n\ndef get_loss(loss_fn_name: str = 'CrossEntropy', ignore_label: int = 255, cls_weights: Tensor = None):\n    assert loss_fn_name in __all__, f\"Unavailable loss function name >> {loss_fn_name}.\\nAvailable loss functions: {__all__}\"\n    if loss_fn_name == 'Dice':\n        return Dice()\n    return eval(loss_fn_name)(ignore_label, cls_weights)\n\n\nif __name__ == '__main__':\n    pred = torch.randint(0, 19, (2, 19, 480, 640), dtype=torch.float)\n    label = torch.randint(0, 19, (2, 480, 640), dtype=torch.long)\n    loss_fn = Dice()\n    y = loss_fn(pred, label)\n    print(y)"
  },
  {
    "path": "semseg/metrics.py",
    "content": "import torch\nfrom torch import Tensor\nfrom typing import Tuple\n\n\nclass Metrics:\n    def __init__(self, num_classes: int, ignore_label: int, device) -> None:\n        self.ignore_label = ignore_label\n        self.num_classes = num_classes\n        self.hist = torch.zeros(num_classes, num_classes).to(device)\n\n    def update(self, pred: Tensor, target: Tensor) -> None:\n        pred = pred.argmax(dim=1)\n        keep = target != self.ignore_label\n        self.hist += torch.bincount(target[keep] * self.num_classes + pred[keep], minlength=self.num_classes**2).view(self.num_classes, self.num_classes)\n\n    def compute_iou(self) -> Tuple[Tensor, Tensor]:\n        ious = self.hist.diag() / (self.hist.sum(0) + self.hist.sum(1) - self.hist.diag())\n        ious[ious.isnan()]=0.\n        miou = ious.mean().item()\n        # miou = ious[~ious.isnan()].mean().item()\n        ious *= 100\n        miou *= 100\n        return ious.cpu().numpy().round(2).tolist(), round(miou, 2)\n\n    def compute_f1(self) -> Tuple[Tensor, Tensor]:\n        f1 = 2 * self.hist.diag() / (self.hist.sum(0) + self.hist.sum(1))\n        f1[f1.isnan()]=0.\n        mf1 = f1.mean().item()\n        # mf1 = f1[~f1.isnan()].mean().item()\n        f1 *= 100\n        mf1 *= 100\n        return f1.cpu().numpy().round(2).tolist(), round(mf1, 2)\n\n    def compute_pixel_acc(self) -> Tuple[Tensor, Tensor]:\n        acc = self.hist.diag() / self.hist.sum(1)\n        acc[acc.isnan()]=0.\n        macc = acc.mean().item()\n        # macc = acc[~acc.isnan()].mean().item()\n        acc *= 100\n        macc *= 100\n        return acc.cpu().numpy().round(2).tolist(), round(macc, 2)\n\n"
  },
  {
    "path": "semseg/models/__init__.py",
    "content": "from .cmx import CMX\nfrom .cmnext import CMNeXt\n\n__all__ = [\n    'CMX',\n    'CMNeXt',\n]"
  },
  {
    "path": "semseg/models/backbones/__init__.py",
    "content": "from .cmx import CMX\nfrom .cmnext import CMNeXt\n\n__all__ = [\n    'CMX', \n    'CMNeXt',\n]"
  },
  {
    "path": "semseg/models/backbones/cmnext.py",
    "content": "import torch\nfrom torch import nn, Tensor\nfrom torch.nn import functional as F\nfrom semseg.models.layers import DropPath\nimport functools\nfrom functools import partial\nfrom fvcore.nn import flop_count_table, FlopCountAnalysis\nfrom semseg.models.modules.ffm import FeatureFusionModule as FFM\nfrom semseg.models.modules.ffm import FeatureRectifyModule as FRM\nfrom semseg.models.modules.ffm import ChannelEmbed\nfrom semseg.models.modules.mspa import MSPABlock\nfrom semseg.utils.utils import nchw_to_nlc, nlc_to_nchw\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, head, sr_ratio):\n        super().__init__()\n        self.head = head\n        self.sr_ratio = sr_ratio \n        self.scale = (dim // head) ** -0.5\n        self.q = nn.Linear(dim, dim)\n        self.kv = nn.Linear(dim, dim*2)\n        self.proj = nn.Linear(dim, dim)\n\n        if sr_ratio > 1:\n            self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio)\n            self.norm = nn.LayerNorm(dim)\n\n    def forward(self, x: Tensor, H, W) -> Tensor:\n        B, N, C = x.shape\n        q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)\n\n        if self.sr_ratio > 1:\n            x = x.permute(0, 2, 1).reshape(B, C, H, W)\n            x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)\n            x = self.norm(x)\n            \n        k, v = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        return x\n\n\nclass DWConv(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)\n\n    def forward(self, x: Tensor, H, W) -> Tensor:\n        B, _, C = x.shape\n        x = x.transpose(1, 2).view(B, C, H, W)\n        x = self.dwconv(x)\n        return 
x.flatten(2).transpose(1, 2)\n\n\nclass MLP(nn.Module):\n    def __init__(self, c1, c2):\n        super().__init__()\n        self.fc1 = nn.Linear(c1, c2)\n        self.dwconv = DWConv(c2)\n        self.fc2 = nn.Linear(c2, c1)\n        \n    def forward(self, x: Tensor, H, W) -> Tensor:\n        return self.fc2(F.gelu(self.dwconv(self.fc1(x), H, W)))\n\n\nclass PatchEmbed(nn.Module):\n    def __init__(self, c1=3, c2=32, patch_size=7, stride=4, padding=0):\n        super().__init__()\n        self.proj = nn.Conv2d(c1, c2, patch_size, stride, padding)    # padding=(ps[0]//2, ps[1]//2)\n        self.norm = nn.LayerNorm(c2)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = self.proj(x)\n        _, _, H, W = x.shape\n        x = x.flatten(2).transpose(1, 2)\n        x = self.norm(x)\n        return x, H, W\n\nclass PatchEmbedParallel(nn.Module):\n    def __init__(self, c1=3, c2=32, patch_size=7, stride=4, padding=0, num_modals=4):\n        super().__init__()\n        self.proj = ModuleParallel(nn.Conv2d(c1, c2, patch_size, stride, padding))    # padding=(ps[0]//2, ps[1]//2)\n        self.norm = LayerNormParallel(c2, num_modals)\n\n    def forward(self, x: list) -> list:\n        x = self.proj(x)\n        _, _, H, W = x[0].shape\n        x = self.norm(x)\n        return x, H, W\n\nclass Block(nn.Module):\n    def __init__(self, dim, head, sr_ratio=1, dpr=0., is_fan=False):\n        super().__init__()\n        self.norm1 = nn.LayerNorm(dim)\n        self.attn = Attention(dim, head, sr_ratio)\n        self.drop_path = DropPath(dpr) if dpr > 0. 
else nn.Identity()\n        self.norm2 = nn.LayerNorm(dim)\n        self.mlp = MLP(dim, int(dim*4)) if not is_fan else ChannelProcessing(dim, mlp_hidden_dim=int(dim*4))\n\n    def forward(self, x: Tensor, H, W) -> Tensor:\n        x = x + self.drop_path(self.attn(self.norm1(x), H, W))\n        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))\n        return x\n\n\nclass ChannelProcessing(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., drop_path=0., mlp_hidden_dim=None, norm_layer=nn.LayerNorm):\n        super().__init__()\n        assert dim % num_heads == 0, f\"dim {dim} should be divided by num_heads {num_heads}.\"\n        self.dim = dim\n        self.num_heads = num_heads\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.mlp_v = MLP(dim, mlp_hidden_dim)\n        self.norm_v = norm_layer(dim)\n\n        self.q = nn.Linear(dim, dim, bias=qkv_bias)\n        self.pool = nn.AdaptiveAvgPool2d((None, 1))\n        self.sigmoid = nn.Sigmoid()\n        \n    def forward(self, x, H, W, atten=None):\n        B, N, C = x.shape\n\n        v = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)\n        k = x.reshape(B, N, self.num_heads,  C // self.num_heads).permute(0, 2, 1, 3)\n\n        q = q.softmax(-2).transpose(-1,-2)\n        _, _, Nk, Ck  = k.shape\n        k = k.softmax(-2)\n        k = torch.nn.functional.avg_pool2d(k, (1, Ck))\n        \n        attn = self.sigmoid(q @ k)\n\n        Bv, Hd, Nv, Cv = v.shape\n        v = self.norm_v(self.mlp_v(v.transpose(1, 2).reshape(Bv, Nv, Hd*Cv), H, W)).reshape(Bv, Nv, Hd, Cv).transpose(1, 2)\n        x = (attn * v.transpose(-1, -2)).permute(0, 3, 1, 2).reshape(B, N, C)\n        return x \n\n\nclass PredictorConv(nn.Module):\n    def __init__(self, embed_dim=384, num_modals=4):\n        super().__init__()\n        
self.num_modals = num_modals\n        self.score_nets = nn.ModuleList([nn.Sequential(\n            nn.Conv2d(embed_dim, embed_dim, 3, 1, 1, groups=(embed_dim)),\n            nn.Conv2d(embed_dim, 1, 1),\n            nn.Sigmoid()\n        )for _ in range(num_modals)])\n\n    def forward(self, x):\n        B, C, H, W = x[0].shape\n        x_ = [torch.zeros((B, 1, H, W)) for _ in range(self.num_modals)]\n        for i in range(self.num_modals):\n            x_[i] = self.score_nets[i](x[i])\n        return x_\n\nclass ModuleParallel(nn.Module):\n    def __init__(self, module):\n        super(ModuleParallel, self).__init__()\n        self.module = module\n\n    def forward(self, x_parallel):\n        return [self.module(x) for x in x_parallel]\n\nclass ConvLayerNorm(nn.Module):\n    \"\"\"Channel first layer norm\n    \"\"\"\n    def __init__(self, normalized_shape, eps=1e-6) -> None:\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(normalized_shape))\n        self.bias = nn.Parameter(torch.zeros(normalized_shape))\n        self.eps = eps\n\n    def forward(self, x: Tensor) -> Tensor:\n        u = x.mean(1, keepdim=True)\n        s = (x - u).pow(2).mean(1, keepdim=True)\n        x = (x - u) / torch.sqrt(s + self.eps)\n        x = self.weight[:, None, None] * x + self.bias[:, None, None]\n        return x\n\nclass LayerNormParallel(nn.Module):\n    def __init__(self, num_features, num_modals=4):\n        super(LayerNormParallel, self).__init__()\n        # self.num_modals = num_modals\n        for i in range(num_modals):\n            setattr(self, 'ln_' + str(i), ConvLayerNorm(num_features, eps=1e-6))\n\n    def forward(self, x_parallel):\n        return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)]\n\n\ncmnext_settings = {\n    # 'B0': [[32, 64, 160, 256], [2, 2, 2, 2]],\n    # 'B1': [[64, 128, 320, 512], [2, 2, 2, 2]],\n    'B2': [[64, 128, 320, 512], [3, 4, 6, 3]],\n    # 'B3': [[64, 128, 320, 512], [3, 4, 18, 3]],\n    
'B4': [[64, 128, 320, 512], [3, 8, 27, 3]],\n    'B5': [[64, 128, 320, 512], [3, 6, 40, 3]]\n}\n\n\nclass CMNeXt(nn.Module):\n    def __init__(self, model_name: str = 'B0', modals: list = ['rgb', 'depth', 'event', 'lidar']):\n        super().__init__()\n        assert model_name in cmnext_settings.keys(), f\"Model name should be in {list(cmnext_settings.keys())}\"\n        embed_dims, depths = cmnext_settings[model_name]\n        extra_depths = depths \n        self.modals = modals[1:] if len(modals)>1 else []  \n        self.num_modals = len(self.modals)\n        drop_path_rate = 0.1\n        self.channels = embed_dims\n        norm_cfg = dict(type='BN', requires_grad=True)\n\n        # patch_embed\n        self.patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4, 7//2)\n        self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2, 3//2)\n        self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2, 3//2)\n        self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2, 3//2)\n   \n        if self.num_modals > 0:\n            self.extra_downsample_layers = nn.ModuleList([\n                PatchEmbedParallel(3, embed_dims[0], 7, 4, 7//2, self.num_modals),\n                *[PatchEmbedParallel(embed_dims[i], embed_dims[i+1], 3, 2, 3//2, self.num_modals) for i in range(3)]\n            ])\n        if self.num_modals > 1:\n            self.extra_score_predictor = nn.ModuleList([PredictorConv(embed_dims[i], self.num_modals) for i in range(len(depths))])\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]\n        \n        cur = 0\n        self.block1 = nn.ModuleList([Block(embed_dims[0], 1, 8, dpr[cur+i]) for i in range(depths[0])])\n        self.norm1 = nn.LayerNorm(embed_dims[0])\n        if self.num_modals > 0:\n            self.extra_block1 = nn.ModuleList([MSPABlock(embed_dims[0], mlp_ratio=8, drop_path=dpr[cur+i], norm_cfg=norm_cfg) for i in range(extra_depths[0])]) # --- MSPABlock\n            
self.extra_norm1 = ConvLayerNorm(embed_dims[0])\n            \n        cur += depths[0]\n        self.block2 = nn.ModuleList([Block(embed_dims[1], 2, 4, dpr[cur+i]) for i in range(depths[1])])\n        self.norm2 = nn.LayerNorm(embed_dims[1])\n        if self.num_modals > 0:\n            self.extra_block2 = nn.ModuleList([MSPABlock(embed_dims[1], mlp_ratio=8, drop_path=dpr[cur+i], norm_cfg=norm_cfg) for i in range(extra_depths[1])])\n            self.extra_norm2 = ConvLayerNorm(embed_dims[1])\n\n        cur += depths[1]\n        self.block3 = nn.ModuleList([Block(embed_dims[2], 5, 2, dpr[cur+i]) for i in range(depths[2])])\n        self.norm3 = nn.LayerNorm(embed_dims[2])\n        if self.num_modals > 0:\n            self.extra_block3 = nn.ModuleList([MSPABlock(embed_dims[2], mlp_ratio=4, drop_path=dpr[cur+i], norm_cfg=norm_cfg) for i in range(extra_depths[2])])\n            self.extra_norm3 = ConvLayerNorm(embed_dims[2])\n\n        cur += depths[2]\n        self.block4 = nn.ModuleList([Block(embed_dims[3], 8, 1, dpr[cur+i]) for i in range(depths[3])])\n        self.norm4 = nn.LayerNorm(embed_dims[3])\n        if self.num_modals > 0:\n            self.extra_block4 = nn.ModuleList([MSPABlock(embed_dims[3], mlp_ratio=4, drop_path=dpr[cur+i], norm_cfg=norm_cfg) for i in range(extra_depths[3])])\n            self.extra_norm4 = ConvLayerNorm(embed_dims[3])\n\n        if self.num_modals > 0:\n            num_heads = [1,2,5,8]\n            self.FRMs = nn.ModuleList([\n                FRM(dim=embed_dims[0], reduction=1),\n                FRM(dim=embed_dims[1], reduction=1),\n                FRM(dim=embed_dims[2], reduction=1),\n                FRM(dim=embed_dims[3], reduction=1)])\n            self.FFMs = nn.ModuleList([\n                FFM(dim=embed_dims[0], reduction=1, num_heads=num_heads[0], norm_layer=nn.BatchNorm2d),\n                FFM(dim=embed_dims[1], reduction=1, num_heads=num_heads[1], norm_layer=nn.BatchNorm2d),\n                FFM(dim=embed_dims[2], 
reduction=1, num_heads=num_heads[2], norm_layer=nn.BatchNorm2d),\n                FFM(dim=embed_dims[3], reduction=1, num_heads=num_heads[3], norm_layer=nn.BatchNorm2d)])\n\n    def tokenselect(self, x_ext, module):    \n        x_scores = module(x_ext) \n        for i in range(len(x_ext)):\n            x_ext[i] = x_scores[i] * x_ext[i] + x_ext[i]\n        x_f = functools.reduce(torch.max, x_ext)\n        return x_f\n     \n    def forward(self, x: list) -> list:\n        x_cam = x[0]        \n        if self.num_modals > 0:\n            x_ext = x[1:]\n        B = x_cam.shape[0]\n        outs = []\n        # stage 1\n        x_cam, H, W = self.patch_embed1(x_cam)\n        for blk in self.block1:\n            x_cam = blk(x_cam, H, W)\n        x1_cam = self.norm1(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)\n        if self.num_modals > 0:\n            x_ext, _, _ = self.extra_downsample_layers[0](x_ext)\n            x_f = self.tokenselect(x_ext, self.extra_score_predictor[0]) if self.num_modals > 1 else x_ext[0] \n            for blk in self.extra_block1:\n                x_f = blk(x_f)\n            x1_f = self.extra_norm1(x_f)\n            x1_cam, x1_f = self.FRMs[0](x1_cam, x1_f)\n            x_fused = self.FFMs[0](x1_cam, x1_f)\n            outs.append(x_fused)\n            x_ext = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2) + x1_f for x_ in x_ext] if self.num_modals > 1 else [x1_f]\n        else:\n            outs.append(x1_cam)\n\n        # stage 2\n        x_cam, H, W = self.patch_embed2(x1_cam)\n        for blk in self.block2:\n            x_cam = blk(x_cam, H, W)\n        x2_cam = self.norm2(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)\n        if self.num_modals > 0:\n            x_ext, _, _ = self.extra_downsample_layers[1](x_ext)\n            x_f = self.tokenselect(x_ext, self.extra_score_predictor[1]) if self.num_modals > 1 else x_ext[0] \n            for blk in self.extra_block2:\n                x_f = blk(x_f)\n            \n            x2_f = 
self.extra_norm2(x_f)\n            x2_cam, x2_f = self.FRMs[1](x2_cam, x2_f)\n            x_fused = self.FFMs[1](x2_cam, x2_f)\n            outs.append(x_fused)\n            x_ext = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2) + x2_f for x_ in x_ext] if self.num_modals > 1 else [x2_f]\n        else:\n            outs.append(x2_cam)\n\n        # stage 3\n        x_cam, H, W = self.patch_embed3(x2_cam)\n        for blk in self.block3:\n            x_cam = blk(x_cam, H, W)\n        x3_cam = self.norm3(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)\n        if self.num_modals > 0:\n            x_ext, _, _ = self.extra_downsample_layers[2](x_ext)\n            x_f = self.tokenselect(x_ext, self.extra_score_predictor[2]) if self.num_modals > 1 else x_ext[0] \n            for blk in self.extra_block3:\n                x_f = blk(x_f)\n            \n            x3_f = self.extra_norm3(x_f)\n            x3_cam, x3_f = self.FRMs[2](x3_cam, x3_f)\n            x_fused = self.FFMs[2](x3_cam, x3_f)\n            outs.append(x_fused)\n            x_ext = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2) + x3_f for x_ in x_ext] if self.num_modals > 1 else [x3_f]\n        else:\n            outs.append(x3_cam)\n\n        # stage 4\n        x_cam, H, W = self.patch_embed4(x3_cam)\n        for blk in self.block4:\n            x_cam = blk(x_cam, H, W)\n        x4_cam = self.norm4(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)\n        if self.num_modals > 0:\n            x_ext, _, _ = self.extra_downsample_layers[3](x_ext)\n            x_f = self.tokenselect(x_ext, self.extra_score_predictor[3]) if self.num_modals > 1 else x_ext[0] \n            for blk in self.extra_block4:\n                x_f = blk(x_f)\n            \n            x4_f = self.extra_norm4(x_f)\n            x4_cam, x4_f = self.FRMs[3](x4_cam, x4_f)\n            x_fused = self.FFMs[3](x4_cam, x4_f)\n            outs.append(x_fused)\n        else:\n            outs.append(x4_cam)\n\n        return outs\n\n\nif __name__ == 
'__main__':\n    modals = ['img', 'depth', 'event', 'lidar']\n    x = [torch.zeros(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024)*2, torch.ones(1, 3, 1024, 1024) *3]\n    model = CMNeXt('B2', modals)\n    outs = model(x)\n    for y in outs:\n        print(y.shape)\n\n"
  },
  {
    "path": "semseg/models/backbones/cmx.py",
    "content": "import torch\nfrom torch import nn, Tensor\nfrom torch.nn import functional as F\nimport einops\nfrom semseg.models.layers import DropPath\nfrom semseg.models.modules.ffm import FeatureFusionModule as FFM\nfrom semseg.models.modules.ffm import FeatureRectifyModule as FRM\n\nclass Attention(nn.Module):\n    def __init__(self, dim, head, sr_ratio):\n        super().__init__()\n        self.head = head\n        self.sr_ratio = sr_ratio \n        self.scale = (dim // head) ** -0.5\n        self.q = nn.Linear(dim, dim)\n        self.kv = nn.Linear(dim, dim*2)\n        self.proj = nn.Linear(dim, dim)\n\n        if sr_ratio > 1:\n            self.sr = nn.Conv2d(dim, dim, sr_ratio, sr_ratio)\n            self.norm = nn.LayerNorm(dim)\n\n    def forward(self, x: Tensor, H, W) -> Tensor:\n        B, N, C = x.shape\n        q = self.q(x).reshape(B, N, self.head, C // self.head).permute(0, 2, 1, 3)\n\n        if self.sr_ratio > 1:\n            x = x.permute(0, 2, 1).reshape(B, C, H, W)\n            x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)\n            x = self.norm(x)\n            \n        k, v = self.kv(x).reshape(B, -1, 2, self.head, C // self.head).permute(2, 0, 3, 1, 4)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        return x\n\n\nclass DWConv(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim)\n\n    def forward(self, x: Tensor, H, W) -> Tensor:\n        B, _, C = x.shape\n        x = x.transpose(1, 2).view(B, C, H, W)\n        x = self.dwconv(x)\n        return x.flatten(2).transpose(1, 2)\n\n\nclass MLP(nn.Module):\n    def __init__(self, c1, c2):\n        super().__init__()\n        self.fc1 = nn.Linear(c1, c2)\n        self.dwconv = DWConv(c2)\n        self.fc2 = nn.Linear(c2, c1)\n        \n    def forward(self, x: Tensor, 
H, W) -> Tensor:\n        return self.fc2(F.gelu(self.dwconv(self.fc1(x), H, W)))\n\n\nclass PatchEmbed(nn.Module):\n    def __init__(self, c1=3, c2=32, patch_size=7, stride=4):\n        super().__init__()\n        self.proj = nn.Conv2d(c1, c2, patch_size, stride, patch_size//2)    # padding=(ps[0]//2, ps[1]//2)\n        self.norm = nn.LayerNorm(c2)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = self.proj(x)\n        _, _, H, W = x.shape\n        x = x.flatten(2).transpose(1, 2)\n        x = self.norm(x)\n        return x, H, W\n\n\nclass Block(nn.Module):\n    def __init__(self, dim, head, sr_ratio=1, dpr=0.):\n        super().__init__()\n        self.norm1 = nn.LayerNorm(dim)\n        self.attn = Attention(dim, head, sr_ratio)\n        self.drop_path = DropPath(dpr) if dpr > 0. else nn.Identity()\n        self.norm2 = nn.LayerNorm(dim)\n        self.mlp = MLP(dim, int(dim*4))\n\n    def forward(self, x: Tensor, H, W) -> Tensor:\n        x = x + self.drop_path(self.attn(self.norm1(x), H, W))\n        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))\n        return x\n\n# class PredictorConv(nn.Module):\n#     \"\"\" Image to modality score, in spatial selection\n#     b, h, w, c -> b, h, w, 1\n#     \"\"\"\n#     def __init__(self, embed_dim=384, num_modals=4):\n#         super().__init__()\n#         self.num_modals = num_modals\n#         self.dwconv = ModuleParallel(nn.Conv2d(embed_dim, embed_dim, 7, 1, 3, groups=embed_dim))\n#         self.score_nets = nn.ModuleList([nn.Sequential(\n#             nn.LayerNorm(embed_dim, eps=1e-6),\n#             # nn.Linear(embed_dim, embed_dim),\n#             # nn.GELU(),\n#             # nn.Linear(embed_dim, embed_dim // 2),\n#             # nn.GELU(),\n#             # nn.Linear(embed_dim // 2, embed_dim // 4),\n#             # nn.GELU(),\n#             # nn.Linear(embed_dim // 4, 1),\n#             # nn.Linear(embed_dim, embed_dim // 4),\n#             # nn.GELU(),\n#             nn.Linear(embed_dim, 
1),\n#             # nn.Sigmoid()\n#             # nn.LogSoftmax(dim=-1)\n#             nn.Softmax(dim=-1)\n#         ) for _ in range(num_modals)])\n\n\n#     def forward(self, x):\n#         x = self.dwconv(x)\n#         x = [x_.permute(0, 2, 3, 1) for x_ in x]   # NCHW to NHWC\n#         x = [self.score_nets[i](x[i]) for i in range(self.num_modals)]\n#         x = [x_.permute(0, 3, 1, 2) for x_ in x]   # NHWC to NCHW \n#         # for i, (xi, rat) in enumerate(zip(x, [1, 1, 0.5, 0.2])):\n#             # x[i] = xi * rat\n#         return x\n\nclass PredictorLG(nn.Module):\n    \"\"\" Image to Patch Embedding from DydamicVit\n    \"\"\"\n    def __init__(self, embed_dim=384, num_modals=4):\n        super().__init__()\n        self.num_modals = num_modals\n        self.score_nets = nn.ModuleList([nn.Sequential(\n            nn.LayerNorm(embed_dim),\n            nn.Linear(embed_dim, embed_dim // 4),\n            nn.GELU(),\n            nn.Linear(embed_dim // 4, 1),\n            nn.Softmax(dim=-1)\n        ) for _ in range(num_modals)])\n\n    def forward(self, x):\n        x = [self.score_nets[i](x[i]) for i in range(self.num_modals)]\n        return x\n\nmit_settings = {\n    'B0': [[32, 64, 160, 256], [2, 2, 2, 2]],        # [embed_dims, depths]\n    'B1': [[64, 128, 320, 512], [2, 2, 2, 2]],\n    'B2': [[64, 128, 320, 512], [3, 4, 6, 3]],\n    'B3': [[64, 128, 320, 512], [3, 4, 18, 3]],\n    'B4': [[64, 128, 320, 512], [3, 8, 27, 3]],\n    'B5': [[64, 128, 320, 512], [3, 6, 40, 3]]\n}\n\n\nclass CMX(nn.Module):\n    def __init__(self, model_name: str = 'B0', modals: list = ['rgb', 'depth', 'event', 'lidar']):\n        super().__init__()\n        assert model_name in mit_settings.keys(), f\"Model name should be in {list(mit_settings.keys())}\"\n        embed_dims, depths = mit_settings[model_name]\n        extra_depths = depths # for fusion branch\n        self.modals = modals[1:] if len(modals)>1 else []  # remove rgb\n        self.num_modals = len(self.modals)\n 
       drop_path_rate = 0.1\n        self.channels = embed_dims\n\n        # patch_embed\n        self.patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4)\n        self.patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2)\n        self.patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2)\n        self.patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2)\n\n        \n        if self.num_modals > 0:\n            self.extra_patch_embed1 = PatchEmbed(3, embed_dims[0], 7, 4)\n            self.extra_patch_embed2 = PatchEmbed(embed_dims[0], embed_dims[1], 3, 2)\n            self.extra_patch_embed3 = PatchEmbed(embed_dims[1], embed_dims[2], 3, 2)\n            self.extra_patch_embed4 = PatchEmbed(embed_dims[2], embed_dims[3], 3, 2)\n\n        if self.num_modals > 1:\n            # self.extra_score_predictor = nn.ModuleList([PredictorConv(embed_dims[i], self.num_modals) for i in range(len(depths))])\n            self.extra_score_predictor = nn.ModuleList([PredictorLG(embed_dims[i], self.num_modals) for i in range(len(depths))])\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]\n        \n        cur = 0\n        self.block1 = nn.ModuleList([Block(embed_dims[0], 1, 8, dpr[cur+i]) for i in range(depths[0])])\n        self.norm1 = nn.LayerNorm(embed_dims[0])\n        if self.num_modals > 0:\n            self.extra_block1 = nn.ModuleList([Block(embed_dims[0], 1, 8, dpr[cur+i]) for i in range(depths[0])])\n            self.extra_norm1 = nn.LayerNorm(embed_dims[0])\n\n        cur += depths[0]\n        self.block2 = nn.ModuleList([Block(embed_dims[1], 2, 4, dpr[cur+i]) for i in range(depths[1])])\n        self.norm2 = nn.LayerNorm(embed_dims[1])\n        if self.num_modals > 0:\n            self.extra_block2 = nn.ModuleList([Block(embed_dims[1], 2, 4, dpr[cur+i]) for i in range(depths[1])])\n            self.extra_norm2 = nn.LayerNorm(embed_dims[1])\n\n        cur += depths[1]\n        self.block3 = 
nn.ModuleList([Block(embed_dims[2], 5, 2, dpr[cur+i]) for i in range(depths[2])])\n        self.norm3 = nn.LayerNorm(embed_dims[2])\n        if self.num_modals > 0:\n            self.extra_block3 = nn.ModuleList([Block(embed_dims[2], 5, 2, dpr[cur+i]) for i in range(depths[2])])\n            self.extra_norm3 = nn.LayerNorm(embed_dims[2])\n\n        cur += depths[2]\n        self.block4 = nn.ModuleList([Block(embed_dims[3], 8, 1, dpr[cur+i]) for i in range(depths[3])])\n        self.norm4 = nn.LayerNorm(embed_dims[3])\n        if self.num_modals > 0:\n            self.extra_block4 = nn.ModuleList([Block(embed_dims[3], 8, 1, dpr[cur+i]) for i in range(depths[3])])\n            self.extra_norm4 = nn.LayerNorm(embed_dims[3])\n\n\n        if self.num_modals > 0:\n            num_heads = [1,2,5,8]\n            self.FRMs = nn.ModuleList([\n                FRM(dim=embed_dims[0], reduction=1),\n                FRM(dim=embed_dims[1], reduction=1),\n                FRM(dim=embed_dims[2], reduction=1),\n                FRM(dim=embed_dims[3], reduction=1)])\n            self.FFMs = nn.ModuleList([\n                FFM(dim=embed_dims[0], reduction=1, num_heads=num_heads[0], norm_layer=nn.BatchNorm2d),\n                FFM(dim=embed_dims[1], reduction=1, num_heads=num_heads[1], norm_layer=nn.BatchNorm2d),\n                FFM(dim=embed_dims[2], reduction=1, num_heads=num_heads[2], norm_layer=nn.BatchNorm2d),\n                FFM(dim=embed_dims[3], reduction=1, num_heads=num_heads[3], norm_layer=nn.BatchNorm2d)])\n\n\n    # ---Hard selection\n    def tokenselect(self, x_ext, module):        \n        x_scores = module(x_ext)\n        # select tokens according to the max score of multiple modals, regarding H, W\n        x_stack = torch.stack(x_ext, dim=-1) # B, N, C, N_modals\n        B, N, C, N_modals = x_stack.shape\n        x_scores = torch.stack(x_scores, dim=-1) # B, N, 1, N_modals\n        x_index = torch.argmax(x_scores, dim=-1, keepdim=True) # B, N, 1, N_modals\n        
x_index = einops.repeat(x_index, 'b n 1 m  -> b n c m', c=C) # B, C, H, W, N_modals\n        # --- token selection\n        x_select = x_stack.gather(-1, x_index)\n        return x_select.squeeze(-1) # B, C, H, W\n    \n    def forward(self, x: list) -> list:\n        x_cam = x[0]        \n        if self.num_modals > 0:\n            x_ext = x[1:]\n        B = x_cam.shape[0]\n        outs = []        \n        # stage 1\n        x_cam, H, W = self.patch_embed1(x_cam)\n        for blk in self.block1:\n            x_cam = blk(x_cam, H, W)\n        x_cam = self.norm1(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)\n        if self.num_modals > 0:\n            x_ext = [self.extra_patch_embed1(x_ext_)[0] for x_ext_ in x_ext]\n            x_f = self.tokenselect(x_ext, self.extra_score_predictor[0]) if self.num_modals > 1 else x_ext[0] \n            for blk in self.extra_block1:\n                x_f = blk(x_f, H, W)\n            x_f = self.extra_norm1(x_f).reshape(B, H, W, -1).permute(0, 3, 1, 2)\n            # --- FFM\n            x_cam, x_f = self.FRMs[0](x_cam, x_f)\n            x_fused = self.FFMs[0](x_cam, x_f)\n            outs.append(x_fused)\n            x_ext = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2) + x_f for x_ in x_ext] if self.num_modals > 1 else [x_f] \n        else:\n            outs.append(x_cam)\n\n        # stage 2\n        x_cam, H, W = self.patch_embed2(x_cam)\n        for blk in self.block2:\n            x_cam = blk(x_cam, H, W)\n        x_cam = self.norm2(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)\n        if self.num_modals > 0:\n            x_ext = [self.extra_patch_embed2(x_ext_)[0] for x_ext_ in x_ext]\n            x_f = self.tokenselect(x_ext, self.extra_score_predictor[1]) if self.num_modals > 1 else x_ext[0] \n            for blk in self.extra_block2:\n                x_f = blk(x_f, H, W)\n            x_f = self.extra_norm2(x_f).reshape(B, H, W, -1).permute(0, 3, 1, 2)\n            # --- FFM\n            x_cam, x_f = 
self.FRMs[1](x_cam, x_f)\n            x_fused = self.FFMs[1](x_cam, x_f)\n            outs.append(x_fused)\n            x_ext = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2) + x_f for x_ in x_ext] if self.num_modals > 1 else [x_f] \n        else:\n            outs.append(x_cam)\n\n        # stage 3\n        x_cam, H, W = self.patch_embed3(x_cam)\n        for blk in self.block3:\n            x_cam = blk(x_cam, H, W)\n        x_cam = self.norm3(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)\n        if self.num_modals > 0:\n            x_ext = [self.extra_patch_embed3(x_ext_)[0] for x_ext_ in x_ext]\n            x_f = self.tokenselect(x_ext, self.extra_score_predictor[2]) if self.num_modals > 1 else x_ext[0] \n            for blk in self.extra_block3:\n                x_f = blk(x_f, H, W)\n            x_f = self.extra_norm3(x_f).reshape(B, H, W, -1).permute(0, 3, 1, 2)\n            # --- FFM\n            x_cam, x_f = self.FRMs[2](x_cam, x_f)\n            x_fused = self.FFMs[2](x_cam, x_f)\n            outs.append(x_fused)\n            x_ext = [x_.reshape(B, H, W, -1).permute(0, 3, 1, 2) + x_f for x_ in x_ext] if self.num_modals > 1 else [x_f] \n        else:\n            outs.append(x_cam)\n\n        # stage 4\n        x_cam, H, W = self.patch_embed4(x_cam)\n        for blk in self.block4:\n            x_cam = blk(x_cam, H, W)\n        x_cam = self.norm4(x_cam).reshape(B, H, W, -1).permute(0, 3, 1, 2)\n        if self.num_modals > 0:\n            x_ext = [self.extra_patch_embed4(x_ext_)[0] for x_ext_ in x_ext]\n            x_f = self.tokenselect(x_ext, self.extra_score_predictor[3]) if self.num_modals > 1 else x_ext[0] \n            for blk in self.extra_block4:\n                x_f = blk(x_f, H, W)\n            x_f = self.extra_norm4(x_f).reshape(B, H, W, -1).permute(0, 3, 1, 2)\n            # --- FFM\n            x_cam, x_f = self.FRMs[3](x_cam, x_f)\n            x_fused = self.FFMs[3](x_cam, x_f)\n            outs.append(x_fused)\n            # x_ext = 
[x_.reshape(B, H, W, -1).permute(0, 3, 1, 2) + x_f for x_ in x_ext] if self.num_modals > 1 else [x_f] \n        else:\n            outs.append(x_cam)\n\n        return outs\n\n\nif __name__ == '__main__':\n    modals = ['img']\n    # modals = ['img', 'depth', 'event', 'lidar']\n    # x = [torch.zeros(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024)*2, torch.ones(1, 3, 1024, 1024) *3]\n    # modals = ['img', 'depth']\n    # x = [torch.zeros(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024)]\n    x = [torch.zeros(1, 3, 1024, 1024)]\n    model = CMX('B2', modals)\n    outs = model(x)\n    print(model)\n    for y in outs:\n        print(y.shape)\n\n    # print(flop_count_table(FlopCountAnalysis(model, x)))        \n\n\n"
  },
  {
    "path": "semseg/models/base.py",
    "content": "import torch\nimport math\nfrom torch import nn\nfrom semseg.models.backbones import *\nfrom semseg.models.layers import trunc_normal_\nfrom collections import OrderedDict\n\ndef load_dualpath_model(model, model_file):\n    # load raw state_dict\n    if isinstance(model_file, str):\n        raw_state_dict = torch.load(model_file, map_location=torch.device('cpu'))\n        #raw_state_dict = torch.load(model_file)\n        if 'model' in raw_state_dict.keys():\n            raw_state_dict = raw_state_dict['model']\n    else:\n        raw_state_dict = model_file\n    \n    state_dict = {}\n    for k, v in raw_state_dict.items():\n        if k.find('patch_embed') >= 0:\n            state_dict[k] = v\n            # patch_embedx, proj, weight = k.split('.')\n            # state_dict[k.replace('patch_embed', 'extra_patch_embed')] = v\n            # state_dict[new_k] = v\n        elif k.find('block') >= 0:\n            state_dict[k] = v\n            # state_dict[k.replace('block', 'extra_block')] = v\n        elif k.find('norm') >= 0:\n            state_dict[k] = v\n            # state_dict[k.replace('norm', 'extra_norm')] = v\n\n    msg = model.load_state_dict(state_dict, strict=False)\n    print(msg)\n    del state_dict\n    \n\nclass BaseModel(nn.Module):\n    def __init__(self, backbone: str = 'MiT-B0', num_classes: int = 19, modals: list = ['rgb', 'depth', 'event', 'lidar']) -> None:\n        super().__init__()\n        backbone, variant = backbone.split('-')\n        self.backbone = eval(backbone)(variant, modals)\n        # self.backbone = eval(backbone)(variant)\n        self.modals = modals\n\n    def _init_weights(self, m: nn.Module) -> None:\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.zeros_(m.bias)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out // m.groups\n   
         m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                nn.init.zeros_(m.bias)\n        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):\n            nn.init.ones_(m.weight)\n            nn.init.zeros_(m.bias)\n\n    def init_pretrained(self, pretrained: str = None) -> None:\n        if pretrained:\n            if len(self.modals)>1:\n                load_dualpath_model(self.backbone, pretrained)\n            else:\n                checkpoint = torch.load(pretrained, map_location='cpu')\n                if 'state_dict' in checkpoint.keys():\n                    checkpoint = checkpoint['state_dict']\n                # if 'PoolFormer' in self.__class__.__name__:\n                #     new_dict = OrderedDict()\n                #     for k, v in checkpoint.items():\n                #         if not 'backbone.' in k:\n                #             new_dict['backbone.'+k] = v\n                #         else:\n                #             new_dict[k] = v\n                #     checkpoint = new_dict\n                if 'model' in checkpoint.keys(): # --- for HorNet\n                    checkpoint = checkpoint['model']\n                msg = self.backbone.load_state_dict(checkpoint, strict=False)\n                print(msg)\n"
  },
  {
    "path": "semseg/models/cmnext.py",
    "content": "import torch\nfrom torch import Tensor\nfrom torch.nn import functional as F\nfrom semseg.models.base import BaseModel\nfrom semseg.models.heads import SegFormerHead\nfrom semseg.models.heads import LightHamHead\nfrom semseg.models.heads import UPerHead\nfrom fvcore.nn import flop_count_table, FlopCountAnalysis\n\n\nclass CMNeXt(BaseModel):\n    def __init__(self, backbone: str = 'CMNeXt-B0', num_classes: int = 25, modals: list = ['img', 'depth', 'event', 'lidar']) -> None:\n        super().__init__(backbone, num_classes, modals)\n        self.decode_head = SegFormerHead(self.backbone.channels, 256 if 'B0' in backbone or 'B1' in backbone else 512, num_classes)\n        self.apply(self._init_weights)\n\n    def forward(self, x: list) -> list:\n        y = self.backbone(x)\n        y = self.decode_head(y)\n        y = F.interpolate(y, size=x[0].shape[2:], mode='bilinear', align_corners=False)\n        return y\n\n    def init_pretrained(self, pretrained: str = None) -> None:\n        if pretrained:\n            if self.backbone.num_modals > 0:\n                load_dualpath_model(self.backbone, pretrained)\n            else:\n                checkpoint = torch.load(pretrained, map_location='cpu')\n                if 'state_dict' in checkpoint.keys():\n                    checkpoint = checkpoint['state_dict']\n                if 'model' in checkpoint.keys():\n                    checkpoint = checkpoint['model']\n                msg = self.backbone.load_state_dict(checkpoint, strict=False)\n                print(msg)\n\ndef load_dualpath_model(model, model_file):\n    extra_pretrained = None\n    if isinstance(extra_pretrained, str):\n        raw_state_dict_ext = torch.load(extra_pretrained, map_location=torch.device('cpu'))\n        if 'state_dict' in raw_state_dict_ext.keys():\n            raw_state_dict_ext = raw_state_dict_ext['state_dict']\n    if isinstance(model_file, str):\n        raw_state_dict = torch.load(model_file, 
map_location=torch.device('cpu'))\n        if 'model' in raw_state_dict.keys(): \n            raw_state_dict = raw_state_dict['model']\n    else:\n        raw_state_dict = model_file\n    \n    state_dict = {}\n    for k, v in raw_state_dict.items():\n        if k.find('patch_embed') >= 0:\n            state_dict[k] = v\n        elif k.find('block') >= 0:\n            state_dict[k] = v\n        elif k.find('norm') >= 0:\n            state_dict[k] = v\n\n    if isinstance(extra_pretrained, str):\n        for k, v in raw_state_dict_ext.items():\n            if k.find('patch_embed1.proj') >= 0:\n                state_dict[k.replace('patch_embed1.proj', 'extra_downsample_layers.0.proj.module')] = v \n            if k.find('patch_embed2.proj') >= 0:\n                state_dict[k.replace('patch_embed2.proj', 'extra_downsample_layers.1.proj.module')] = v \n            if k.find('patch_embed3.proj') >= 0:\n                state_dict[k.replace('patch_embed3.proj', 'extra_downsample_layers.2.proj.module')] = v \n            if k.find('patch_embed4.proj') >= 0:\n                state_dict[k.replace('patch_embed4.proj', 'extra_downsample_layers.3.proj.module')] = v \n            \n            if k.find('patch_embed1.norm') >= 0:\n                for i in range(model.num_modals):\n                    state_dict[k.replace('patch_embed1.norm', 'extra_downsample_layers.0.norm.ln_{}'.format(i))] = v \n            if k.find('patch_embed2.norm') >= 0:\n                for i in range(model.num_modals):\n                    state_dict[k.replace('patch_embed2.norm', 'extra_downsample_layers.1.norm.ln_{}'.format(i))] = v \n            if k.find('patch_embed3.norm') >= 0:\n                for i in range(model.num_modals):\n                    state_dict[k.replace('patch_embed3.norm', 'extra_downsample_layers.2.norm.ln_{}'.format(i))] = v \n            if k.find('patch_embed4.norm') >= 0:\n                for i in range(model.num_modals):\n                    
state_dict[k.replace('patch_embed4.norm', 'extra_downsample_layers.3.norm.ln_{}'.format(i))] = v \n            elif k.find('block') >= 0:\n                state_dict[k.replace('block', 'extra_block')] = v\n            elif k.find('norm') >= 0:\n                state_dict[k.replace('norm', 'extra_norm')] = v\n\n\n    msg = model.load_state_dict(state_dict, strict=False)\n    del state_dict\n\nif __name__ == '__main__':\n    modals = ['img', 'depth', 'event', 'lidar']\n    model = CMNeXt('CMNeXt-B2', 25, modals)\n    model.init_pretrained('checkpoints/pretrained/segformer/mit_b2.pth')\n    x = [torch.zeros(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024), torch.ones(1, 3, 1024, 1024)*2, torch.ones(1, 3, 1024, 1024) *3]\n    y = model(x)\n    print(y.shape)\n"
  },
  {
    "path": "semseg/models/cmx.py",
    "content": "import torch\nfrom torch import Tensor\nfrom torch.nn import functional as F\nfrom semseg.models.base import BaseModel\nfrom semseg.models.heads import SegFormerHead\nfrom fvcore.nn import flop_count_table, FlopCountAnalysis\n\nclass CMX(BaseModel):\n    def __init__(self, backbone: str = 'CMX-B0', num_classes: int = 25, modals: list = ['img', 'depth', 'event', 'lidar']) -> None:\n        super().__init__(backbone, num_classes, modals)\n        self.decode_head = SegFormerHead(self.backbone.channels, 256 if 'B0' in backbone or 'B1' in backbone else 512, num_classes)\n        self.apply(self._init_weights)\n\n    def forward(self, x: list) -> list:\n        y = self.backbone(x)\n        y = self.decode_head(y) \n        y = F.interpolate(y, size=x[0].shape[2:], mode='bilinear', align_corners=False)    # to original image shape\n        return y\n\n    \n    def init_pretrained(self, pretrained: str = None) -> None:\n        if pretrained:\n            if self.backbone.num_modals > 0:\n                load_dualpath_model(self.backbone, pretrained)\n            else:\n                checkpoint = torch.load(pretrained, map_location='cpu')\n                if 'state_dict' in checkpoint.keys():\n                    checkpoint = checkpoint['state_dict']\n                if 'model' in checkpoint.keys():\n                    checkpoint = checkpoint['model']\n                msg = self.backbone.load_state_dict(checkpoint, strict=False)\n                print(msg)\n\ndef load_dualpath_model(model, model_file):\n    if isinstance(model_file, str):\n        raw_state_dict = torch.load(model_file, map_location=torch.device('cpu'))\n        if 'model' in raw_state_dict.keys():\n            raw_state_dict = raw_state_dict['model']\n    else:\n        raw_state_dict = model_file\n    \n    state_dict = {}\n    for k, v in raw_state_dict.items():\n        if k.find('patch_embed') >= 0:\n            state_dict[k] = v\n        elif k.find('block') >= 0:\n            
state_dict[k] = v\n        elif k.find('norm') >= 0:\n            state_dict[k] = v\n\n    msg = model.load_state_dict(state_dict, strict=False)\n    del state_dict\n\nif __name__ == '__main__':\n    modals = ['img']\n    # modals = ['img', 'depth', 'event', 'lidar']\n    model = CMX('CMX-B2', 25, modals)\n    model.init_pretrained('checkpoints/pretrained/segformer/mit_b2.pth')\n    x = [torch.zeros(1, 3, 512, 512)]\n    y = model(x)\n    print(y.shape)\n"
  },
  {
    "path": "semseg/models/heads/__init__.py",
    "content": "from .upernet import UPerHead\nfrom .segformer import SegFormerHead\nfrom .sfnet import SFHead\nfrom .fpn import FPNHead\nfrom .fapn import FaPNHead\nfrom .fcn import FCNHead\nfrom .condnet import CondHead\nfrom .lawin import LawinHead\nfrom .hem import LightHamHead\n\n__all__ = ['UPerHead', 'SegFormerHead']"
  },
  {
    "path": "semseg/models/heads/condnet.py",
    "content": "import torch\nfrom torch import nn, Tensor\nfrom torch.nn import functional as F\nfrom semseg.models.layers import ConvModule\n\n\nclass CondHead(nn.Module):\n    def __init__(self, in_channel: int = 2048, channel: int = 512, num_classes: int = 19):\n        super().__init__()\n        self.num_classes = num_classes\n        self.weight_num = channel * num_classes\n        self.bias_num = num_classes\n\n        self.conv = ConvModule(in_channel, channel, 1)\n        self.dropout = nn.Dropout2d(0.1)\n\n        self.guidance_project = nn.Conv2d(channel, num_classes, 1)\n        self.filter_project = nn.Conv2d(channel*num_classes, self.weight_num + self.bias_num, 1, groups=num_classes)\n\n    def forward(self, features) -> Tensor:\n        x = self.dropout(self.conv(features[-1]))\n        B, C, H, W = x.shape\n        guidance_mask = self.guidance_project(x)\n        cond_logit = guidance_mask\n        \n        key = x\n        value = x\n        guidance_mask = guidance_mask.softmax(dim=1).view(*guidance_mask.shape[:2], -1)\n        key = key.view(B, C, -1).permute(0, 2, 1)\n\n        cond_filters = torch.matmul(guidance_mask, key)\n        cond_filters /= H * W\n        cond_filters = cond_filters.view(B, -1, 1, 1)\n        cond_filters = self.filter_project(cond_filters)\n        cond_filters = cond_filters.view(B, -1)\n\n        weight, bias = torch.split(cond_filters, [self.weight_num, self.bias_num], dim=1)\n        weight = weight.reshape(B * self.num_classes, -1, 1, 1)\n        bias = bias.reshape(B * self.num_classes)\n\n        value = value.view(-1, H, W).unsqueeze(0)\n        seg_logit = F.conv2d(value, weight, bias, 1, 0, groups=B).view(B, self.num_classes, H, W)\n        \n        if self.training:\n            return cond_logit, seg_logit\n        return seg_logit\n\n\nif __name__ == '__main__':\n    from semseg.models.backbones import ResNetD\n    backbone = ResNetD('50')\n    head = CondHead()\n    x = torch.randn(2, 3, 224, 224)\n   
 features = backbone(x)\n    outs = head(features)\n    for out in outs:\n        out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False)\n        print(out.shape)"
  },
  {
    "path": "semseg/models/heads/fapn.py",
    "content": "import torch\nfrom torch import nn, Tensor\nfrom torch.nn import functional as F\nfrom torchvision.ops import DeformConv2d\nfrom semseg.models.layers import ConvModule\n\n\nclass DCNv2(nn.Module):\n    def __init__(self, c1, c2, k, s, p, g=1):\n        super().__init__()\n        self.dcn = DeformConv2d(c1, c2, k, s, p, groups=g)\n        self.offset_mask = nn.Conv2d(c2,  g* 3 * k * k, k, s, p)\n        self._init_offset()\n\n    def _init_offset(self):\n        self.offset_mask.weight.data.zero_()\n        self.offset_mask.bias.data.zero_()\n\n    def forward(self, x, offset):\n        out = self.offset_mask(offset)\n        o1, o2, mask = torch.chunk(out, 3, dim=1)\n        offset = torch.cat([o1, o2], dim=1)\n        mask = mask.sigmoid()\n        return self.dcn(x, offset, mask)\n\n\nclass FSM(nn.Module):\n    def __init__(self, c1, c2):\n        super().__init__()\n        self.conv_atten = nn.Conv2d(c1, c1, 1, bias=False)\n        self.conv = nn.Conv2d(c1, c2, 1, bias=False)\n\n    def forward(self, x: Tensor) -> Tensor:\n        atten = self.conv_atten(F.avg_pool2d(x, x.shape[2:])).sigmoid()\n        feat = torch.mul(x, atten)\n        x = x + feat\n        return self.conv(x)\n\n\nclass FAM(nn.Module):\n    def __init__(self, c1, c2):\n        super().__init__()\n        self.lateral_conv = FSM(c1, c2)\n        self.offset = nn.Conv2d(c2*2, c2, 1, bias=False)\n        self.dcpack_l2 = DCNv2(c2, c2, 3, 1, 1, 8)\n    \n    def forward(self, feat_l, feat_s):\n        feat_up = feat_s\n        if feat_l.shape[2:] != feat_s.shape[2:]:\n            feat_up = F.interpolate(feat_s, size=feat_l.shape[2:], mode='bilinear', align_corners=False)\n        \n        feat_arm = self.lateral_conv(feat_l)\n        offset = self.offset(torch.cat([feat_arm, feat_up*2], dim=1))\n\n        feat_align = F.relu(self.dcpack_l2(feat_up, offset))\n        return feat_align + feat_arm\n\n\nclass FaPNHead(nn.Module):\n    def __init__(self, in_channels, channel=128, 
num_classes=19):\n        super().__init__()\n        in_channels = in_channels[::-1]\n        self.align_modules = nn.ModuleList([ConvModule(in_channels[0], channel, 1)])\n        self.output_convs = nn.ModuleList([])\n\n        for ch in in_channels[1:]:\n            self.align_modules.append(FAM(ch, channel))\n            self.output_convs.append(ConvModule(channel, channel, 3, 1, 1))\n\n        self.conv_seg = nn.Conv2d(channel, num_classes, 1)\n        self.dropout = nn.Dropout2d(0.1)\n\n    def forward(self, features) -> Tensor:\n        features = features[::-1]\n        out = self.align_modules[0](features[0])\n        \n        for feat, align_module, output_conv in zip(features[1:], self.align_modules[1:], self.output_convs):\n            out = align_module(feat, out)\n            out = output_conv(out)\n        out = self.conv_seg(self.dropout(out))\n        return out\n\n\nif __name__ == '__main__':\n    from semseg.models.backbones import ResNet\n    backbone = ResNet('50')\n    head = FaPNHead([256, 512, 1024, 2048], 128, 19)\n    x = torch.randn(2, 3, 224, 224)\n    features = backbone(x)\n    out = head(features)\n    out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False)\n    print(out.shape)"
  },
  {
    "path": "semseg/models/heads/fcn.py",
    "content": "import torch\nfrom torch import nn, Tensor\nfrom torch.nn import functional as F\nfrom semseg.models.layers import ConvModule\n\n\nclass FCNHead(nn.Module):\n    def __init__(self, c1, c2, num_classes: int = 19):\n        super().__init__()\n        self.conv = ConvModule(c1, c2, 1)\n        self.cls = nn.Conv2d(c2, num_classes, 1)\n\n    def forward(self, features) -> Tensor:\n        x = self.conv(features[-1])\n        x = self.cls(x)\n        return x\n\n\nif __name__ == '__main__':\n    from semseg.models.backbones import ResNet\n    backbone = ResNet('50')\n    head = FCNHead(2048, 256, 19)\n    x = torch.randn(2, 3, 224, 224)\n    features = backbone(x)\n    out = head(features)\n    out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False)\n    print(out.shape)\n"
  },
  {
    "path": "semseg/models/heads/fpn.py",
    "content": "import torch\nfrom torch import nn, Tensor\nfrom torch.nn import functional as F\nfrom semseg.models.layers import ConvModule\n\n\nclass FPNHead(nn.Module):\n    \"\"\"Panoptic Feature Pyramid Networks\n    https://arxiv.org/abs/1901.02446\n    \"\"\"\n    def __init__(self, in_channels, channel=128, num_classes=19):\n        super().__init__()\n        self.lateral_convs = nn.ModuleList([])\n        self.output_convs = nn.ModuleList([])\n\n        for ch in in_channels[::-1]:\n            self.lateral_convs.append(ConvModule(ch, channel, 1))\n            self.output_convs.append(ConvModule(channel, channel, 3, 1, 1))\n\n        self.conv_seg = nn.Conv2d(channel, num_classes, 1)\n        self.dropout = nn.Dropout2d(0.1)\n\n    def forward(self, features) -> Tensor:\n        features = features[::-1]\n        out = self.lateral_convs[0](features[0])\n        \n        for i in range(1, len(features)):\n            out = F.interpolate(out, scale_factor=2.0, mode='nearest')\n            out = out + self.lateral_convs[i](features[i])\n            out = self.output_convs[i](out)\n        out = self.conv_seg(self.dropout(out))\n        return out\n\n\nif __name__ == '__main__':\n    from semseg.models.backbones import ResNet\n    backbone = ResNet('50')\n    head = FPNHead([256, 512, 1024, 2048], 128, 19)\n    x = torch.randn(2, 3, 224, 224)\n    features = backbone(x)\n    out = head(features)\n    out = F.interpolate(out, size=x.shape[-2:], mode='bilinear', align_corners=False)\n    print(out.shape)"
  },
  {
    "path": "semseg/models/heads/hem.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmcv.cnn import ConvModule\nfrom fvcore.nn import flop_count_table, FlopCountAnalysis\n\nclass _MatrixDecomposition2DBase(nn.Module):\n    def __init__(self, args=dict()):\n        super().__init__()\n\n        self.spatial = args.setdefault('SPATIAL', True)\n\n        self.S = args.setdefault('MD_S', 1)\n        self.D = args.setdefault('MD_D', 512)\n        self.R = args.setdefault('MD_R', 64)\n\n        self.train_steps = args.setdefault('TRAIN_STEPS', 6)\n        self.eval_steps = args.setdefault('EVAL_STEPS', 7)\n\n        self.inv_t = args.setdefault('INV_T', 100)\n        self.eta = args.setdefault('ETA', 0.9)\n\n        self.rand_init = args.setdefault('RAND_INIT', True)\n\n        print('spatial', self.spatial)\n        print('S', self.S)\n        print('D', self.D)\n        print('R', self.R)\n        print('train_steps', self.train_steps)\n        print('eval_steps', self.eval_steps)\n        print('inv_t', self.inv_t)\n        print('eta', self.eta)\n        print('rand_init', self.rand_init)\n\n    def _build_bases(self, B, S, D, R, cuda=False):\n        raise NotImplementedError\n\n    def local_step(self, x, bases, coef):\n        raise NotImplementedError\n\n    # @torch.no_grad()\n    def local_inference(self, x, bases):\n        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)\n        coef = torch.bmm(x.transpose(1, 2), bases)\n        coef = F.softmax(self.inv_t * coef, dim=-1)\n\n        steps = self.train_steps if self.training else self.eval_steps\n        for _ in range(steps):\n            bases, coef = self.local_step(x, bases, coef)\n\n        return bases, coef\n\n    def compute_coef(self, x, bases, coef):\n        raise NotImplementedError\n\n    def forward(self, x, return_bases=False):\n        B, C, H, W = x.shape\n\n        # (B, C, H, W) -> (B * S, D, N)\n        if self.spatial:\n            D = C // self.S\n            N = H * W\n        
    x = x.view(B * self.S, D, N)\n        else:\n            D = H * W\n            N = C // self.S\n            x = x.view(B * self.S, N, D).transpose(1, 2)\n\n        if not self.rand_init and not hasattr(self, 'bases'):\n            bases = self._build_bases(1, self.S, D, self.R, cuda=True)\n            self.register_buffer('bases', bases)\n\n        # (S, D, R) -> (B * S, D, R)\n        if self.rand_init:\n            bases = self._build_bases(B, self.S, D, self.R, cuda=True)\n        else:\n            bases = self.bases.repeat(B, 1, 1)\n\n        bases, coef = self.local_inference(x, bases)\n\n        # (B * S, N, R)\n        coef = self.compute_coef(x, bases, coef)\n\n        # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)\n        x = torch.bmm(bases, coef.transpose(1, 2))\n\n        # (B * S, D, N) -> (B, C, H, W)\n        if self.spatial:\n            x = x.view(B, C, H, W)\n        else:\n            x = x.transpose(1, 2).view(B, C, H, W)\n\n        # (B * H, D, R) -> (B, H, N, D)\n        bases = bases.view(B, self.S, D, self.R)\n\n        return x\n\n\nclass NMF2D(_MatrixDecomposition2DBase):\n    def __init__(self, args=dict()):\n        super().__init__(args)\n\n        self.inv_t = 1\n\n    def _build_bases(self, B, S, D, R, cuda=False):\n        if cuda:\n            bases = torch.rand((B * S, D, R)).cuda()\n        else:\n            bases = torch.rand((B * S, D, R))\n\n        bases = F.normalize(bases, dim=1)\n\n        return bases\n\n    # @torch.no_grad()\n    def local_step(self, x, bases, coef):\n        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)\n        numerator = torch.bmm(x.transpose(1, 2), bases)\n        # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)\n        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))\n        # Multiplicative Update\n        coef = coef * numerator / (denominator + 1e-6)\n\n        # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)\n        numerator = torch.bmm(x, 
coef)\n        # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)\n        denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))\n        # Multiplicative Update\n        bases = bases * numerator / (denominator + 1e-6)\n\n        return bases, coef\n\n    def compute_coef(self, x, bases, coef):\n        # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)\n        numerator = torch.bmm(x.transpose(1, 2), bases)\n        # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)\n        denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))\n        # multiplication update\n        coef = coef * numerator / (denominator + 1e-6)\n\n        return coef\n\n\nclass Hamburger(nn.Module):\n    def __init__(self, ham_channels=512, ham_kwargs=dict(), norm_cfg=None):\n        super().__init__()\n        self.ham_in = ConvModule(ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None)\n        self.ham = NMF2D(ham_kwargs)\n        self.ham_out = ConvModule(ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None)\n\n    def forward(self, x):\n        enjoy = self.ham_in(x)\n        enjoy = F.relu(enjoy, inplace=True)\n        enjoy = self.ham(enjoy)\n        enjoy = self.ham_out(enjoy)\n        ham = F.relu(x + enjoy, inplace=True)\n        return ham\n\nclass LightHamHead(nn.Module):\n    def __init__(self, in_channels=[64, 128, 320, 512], ham_channels=512, ham_kwargs=dict(), num_classes=25):\n        super().__init__()\n        self.in_channels = in_channels[1:]\n        self.in_index = [1,2,3] \n        self.ham_channels = self.channels = ham_channels\n        self.conv_cfg = None\n        self.norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)\n        self.act_cfg = dict(type='ReLU')\n\n        self.ham_channels = ham_channels\n        self.squeeze = ConvModule(sum(self.in_channels), self.ham_channels, 1, conv_cfg=self.conv_cfg,\n                               norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)\n        
self.hamburger = Hamburger(ham_channels, ham_kwargs, self.norm_cfg)\n        self.align = ConvModule(self.ham_channels, self.channels, 1, conv_cfg=self.conv_cfg,\n                                norm_cfg=self.norm_cfg, act_cfg=self.act_cfg)\n        self.conv_seg = nn.Conv2d(self.channels, num_classes, kernel_size=1)\n\n    def forward(self, inputs):\n        \"\"\"Forward function.\"\"\"\n        inputs = [inputs[i] for i in self.in_index]\n\n        inputs = [F.interpolate(level, size=inputs[0].shape[2:], mode='bilinear', align_corners=False) for level in inputs]\n\n        inputs = torch.cat(inputs, dim=1)\n        x = self.squeeze(inputs)\n\n        x = self.hamburger(x)\n\n        output = self.align(x)\n        output = self.conv_seg(output)\n        return output\n\nif __name__ == '__main__':\n    model = LightHamHead(num_classes=25)\n    model = model.cuda()\n    x = [torch.zeros(1, 64, 256, 256), torch.ones(1, 128, 128, 128), torch.ones(1, 320, 64, 64)*2, torch.ones(1, 512, 32, 32) *3]\n    x = [xi.cuda() for xi in x]\n    outs = model(x)\n    print(model)\n    for y in outs:\n        print(y.shape)\n    print(flop_count_table(FlopCountAnalysis(model, x)))        \n\n"
  },
  {
    "path": "semseg/models/heads/lawin.py",
    "content": "import torch\nfrom torch import nn, Tensor\nfrom torch.nn import functional as F\nfrom einops import rearrange\n\n\nclass MLP(nn.Module):\n    \"\"\"Flattens the spatial dims and linearly projects channels: (B, C, H, W) -> (B, H*W, embed_dim).\"\"\"\n    def __init__(self, dim=2048, embed_dim=768):\n        super().__init__()\n        self.proj = nn.Linear(dim, embed_dim)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = x.flatten(2).transpose(1, 2)\n        x = self.proj(x)\n        return x\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"Downsamples by patch_size and applies LayerNorm over channels.\n\n    type == 'conv' uses a strided convolution; any other value averages a max-pool and an avg-pool.\n    Inputs whose H or W is not a multiple of patch_size are zero-padded on the bottom/right first.\n    \"\"\"\n    def __init__(self, patch_size=4, in_ch=3, dim=96, type='pool') -> None:\n        super().__init__()\n        self.patch_size = patch_size\n        self.type = type\n        self.dim = dim\n\n        if type == 'conv':\n            self.proj = nn.Conv2d(in_ch, dim, patch_size, patch_size, groups=patch_size*patch_size)\n        else:\n            self.proj = nn.ModuleList([\n                nn.MaxPool2d(patch_size, patch_size),\n                nn.AvgPool2d(patch_size, patch_size)\n            ])\n\n        self.norm = nn.LayerNorm(dim)\n\n    def forward(self, x: Tensor) -> Tensor:\n        _, _, H, W = x.shape\n        if W % self.patch_size != 0:\n            x = F.pad(x, (0, self.patch_size - W % self.patch_size))\n        if H % self.patch_size != 0:\n            x = F.pad(x, (0, 0, 0, self.patch_size - H % self.patch_size))\n\n        if self.type == 'conv':\n            x = self.proj(x)\n        else:\n            x = 0.5 * (self.proj[0](x) + self.proj[1](x))\n        Wh, Ww = x.size(2), x.size(3)\n        # LayerNorm works on the last dim, so normalise in (B, N, C) layout and restore (B, C, H, W)\n        x = x.flatten(2).transpose(1, 2)\n        x = self.norm(x)\n        x = x.transpose(1, 2).view(-1, self.dim, Wh, Ww)\n        return x\n\n\nclass LawinAttn(nn.Module):\n    \"\"\"Multi-head attention from a query patch to a context patch.\n\n    Before attention, each head's slice of context channels is passed through its own Linear over the\n    flattened spatial positions ('position mixing') and added back to the context.\n    Returns query + conv_out(attention output).\n    \"\"\"\n    def __init__(self, in_ch=512, head=4, patch_size=8, reduction=2) -> None:\n        super().__init__()\n        self.head = head\n\n        self.position_mixing = nn.ModuleList([\n            nn.Linear(patch_size * patch_size, patch_size * patch_size)\n        for _ in range(self.head)])\n\n        self.inter_channels = max(in_ch // reduction, 1)\n        self.g = nn.Conv2d(in_ch, self.inter_channels, 1)\n        self.theta = nn.Conv2d(in_ch, self.inter_channels, 1)\n        self.phi = nn.Conv2d(in_ch, self.inter_channels, 1)\n        self.conv_out = nn.Sequential(\n            nn.Conv2d(self.inter_channels, in_ch, 1, bias=False),\n            nn.BatchNorm2d(in_ch)\n        )\n\n\n    def forward(self, query: Tensor, context: Tensor) -> Tensor:\n        B, C, H, W = context.shape\n        context = context.reshape(B, C, -1)\n        context_mlp = []\n\n        # position mixing: one Linear per head, applied to that head's channel slice\n        for i, pm in enumerate(self.position_mixing):\n            context_crt = context[:, (C//self.head)*i:(C//self.head)*(i+1), :]\n            context_mlp.append(pm(context_crt))\n\n        context_mlp = torch.cat(context_mlp, dim=1)\n        context = context + context_mlp\n        context = context.reshape(B, C, H, W)\n\n        g_x = self.g(context).view(B, self.inter_channels, -1)\n        g_x = rearrange(g_x, \"b (h dim) n -> (b h) dim n\", h=self.head)\n        g_x = g_x.permute(0, 2, 1)\n\n        theta_x = self.theta(query).view(B, self.inter_channels, -1)\n        theta_x = rearrange(theta_x, \"b (h dim) n -> (b h) dim n\", h=self.head)\n        theta_x = theta_x.permute(0, 2, 1)\n\n        phi_x = self.phi(context).view(B, self.inter_channels, -1)\n        phi_x = rearrange(phi_x, \"b (h dim) n -> (b h) dim n\", h=self.head)\n\n        # scaled dot-product attention, scaled by sqrt of the per-head channel count\n        pairwise_weight = torch.matmul(theta_x, phi_x)\n        pairwise_weight /= theta_x.shape[-1]**0.5\n        pairwise_weight = pairwise_weight.softmax(dim=-1)\n\n        y = torch.matmul(pairwise_weight, g_x)\n        y = rearrange(y, \"(b h) n dim -> b n (h dim)\", h=self.head)\n        y = y.permute(0, 2, 1).contiguous().reshape(B, self.inter_channels, *query.shape[-2:])\n\n        output = query + self.conv_out(y)\n        return output\n\n\nclass ConvModule(nn.Module):\n    \"\"\"1x1 Conv2d (no bias) -> BatchNorm2d -> in-place ReLU.\"\"\"\n    def __init__(self, c1, c2):\n        super().__init__()\n        self.conv = nn.Conv2d(c1, c2, 1, bias=False)\n        self.bn = nn.BatchNorm2d(c2)        # use SyncBN in original\n        self.activate = nn.ReLU(True)\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self.activate(self.bn(self.conv(x)))\n\n\nclass LawinHead(nn.Module):\n    \"\"\"Lawin decode head.\n\n    Feature levels 2..N are embedded, upsampled to level 2's resolution and fused; the fused map goes\n    through a short path, an image-pool branch and large-window attention at context ratios 8/4/2.\n    The result is merged with the embedded first (low-level) feature and mapped to class logits.\n\n    Args:\n        in_channels: channel count of each input feature level\n        embed_dim: common embedding width (the first level is embedded to 48 channels instead)\n        num_classes: number of output classes\n    \"\"\"\n    def __init__(self, in_channels: list, embed_dim=512, num_classes=19) -> None:\n        super().__init__()\n        for i, dim in enumerate(in_channels):\n            self.add_module(f\"linear_c{i+1}\", MLP(dim, 48 if i == 0 else embed_dim))\n\n        self.lawin_8 = LawinAttn(embed_dim, 64)\n        self.lawin_4 = LawinAttn(embed_dim, 16)\n        self.lawin_2 = LawinAttn(embed_dim, 4)\n        self.ds_8 = PatchEmbed(8, embed_dim, embed_dim)\n        self.ds_4 = PatchEmbed(4, embed_dim, embed_dim)\n        self.ds_2 = PatchEmbed(2, embed_dim, embed_dim)\n\n        self.image_pool = nn.Sequential(\n            nn.AdaptiveAvgPool2d(1),\n            ConvModule(embed_dim, embed_dim)\n        )\n        # every level except the first contributes one embed_dim-wide map to the fusion\n        self.linear_fuse = ConvModule(embed_dim*(len(in_channels)-1), embed_dim)\n        self.short_path = ConvModule(embed_dim, embed_dim)\n        self.cat = ConvModule(embed_dim*5, embed_dim)\n\n        self.low_level_fuse = ConvModule(embed_dim+48, embed_dim)\n        self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1)\n        self.dropout = nn.Dropout2d(0.1)\n\n    def get_lawin_att_feats(self, x: Tensor, patch_size: int):\n        \"\"\"Returns one attention output per context ratio (8, 4, 2), each with the spatial size of x.\"\"\"\n        _, _, H, W = x.shape\n        query = F.unfold(x, patch_size, stride=patch_size)\n        query = rearrange(query, 'b (c ph pw) (nh nw) -> (b nh nw) c ph pw', ph=patch_size, pw=patch_size, nh=H//patch_size, nw=W//patch_size)\n        outs = []\n\n        for r in [8, 4, 2]:\n            # context window is r times the query patch, centred on it via padding, then pooled back to patch size\n            context = F.unfold(x, patch_size*r, stride=patch_size, padding=int((r-1)/2*patch_size))\n            context = rearrange(context, \"b (c ph pw) (nh nw) -> (b nh nw) c ph pw\", ph=patch_size*r, pw=patch_size*r, nh=H//patch_size, nw=W//patch_size)\n            context = getattr(self, f\"ds_{r}\")(context)\n            output = getattr(self, f\"lawin_{r}\")(query, context)\n            output = rearrange(output, \"(b nh nw) c ph pw -> b c (nh ph) (nw pw)\", ph=patch_size, pw=patch_size, nh=H//patch_size, nw=W//patch_size)\n            outs.append(output)\n        return outs\n\n\n    def forward(self, features):\n        B, _, H, W = features[1].shape\n        outs = [self.linear_c2(features[1]).permute(0, 2, 1).reshape(B, -1, *features[1].shape[-2:])]\n\n        for i, feature in enumerate(features[2:]):\n            # plain attribute lookup, consistent with get_lawin_att_feats; no need to eval() generated source\n            cf = getattr(self, f\"linear_c{i+3}\")(feature).permute(0, 2, 1).reshape(B, -1, *feature.shape[-2:])\n            outs.append(F.interpolate(cf, size=(H, W), mode='bilinear', align_corners=False))\n\n        feat = self.linear_fuse(torch.cat(outs[::-1], dim=1))\n        B, _, H, W = feat.shape\n\n        ## Lawin attention spatial pyramid pooling\n        feat_short = self.short_path(feat)\n        feat_pool = F.interpolate(self.image_pool(feat), size=(H, W), mode='bilinear', align_corners=False)\n        feat_lawin = self.get_lawin_att_feats(feat, 8)\n        output = self.cat(torch.cat([feat_short, feat_pool, *feat_lawin], dim=1))\n\n        ## Low-level feature enhancement\n        c1 = self.linear_c1(features[0]).permute(0, 2, 1).reshape(B, -1, *features[0].shape[-2:])\n        output = F.interpolate(output, size=features[0].shape[-2:], mode='bilinear', align_corners=False)\n        fused = self.low_level_fuse(torch.cat([output, c1], dim=1))\n\n        seg = self.linear_pred(self.dropout(fused))\n        return seg\n"
  },
  {
    "path": "semseg/models/heads/segformer.py",
    "content": "import torch\nfrom torch import nn, Tensor\nfrom typing import Tuple\nfrom torch.nn import functional as F\n\n\nclass MLP(nn.Module):\n    \"\"\"Flattens the spatial dims and linearly projects channels: (B, C, H, W) -> (B, H*W, embed_dim).\"\"\"\n    def __init__(self, dim, embed_dim):\n        super().__init__()\n        self.proj = nn.Linear(dim, embed_dim)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = x.flatten(2).transpose(1, 2)\n        x = self.proj(x)\n        return x\n\n\nclass ConvModule(nn.Module):\n    \"\"\"1x1 Conv2d (no bias) -> BatchNorm2d -> in-place ReLU.\"\"\"\n    def __init__(self, c1, c2):\n        super().__init__()\n        self.conv = nn.Conv2d(c1, c2, 1, bias=False)\n        self.bn = nn.BatchNorm2d(c2)        # use SyncBN in original\n        self.activate = nn.ReLU(True)\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self.activate(self.bn(self.conv(x)))\n\n\nclass SegFormerHead(nn.Module):\n    \"\"\"SegFormer all-MLP decode head.\n\n    Every feature level is embedded to embed_dim, upsampled to the resolution of the first level,\n    concatenated, fused by a 1x1 ConvModule and mapped to per-pixel class logits.\n\n    Args:\n        dims: channel count of each input feature level\n        embed_dim: common embedding width\n        num_classes: number of output classes\n    \"\"\"\n    def __init__(self, dims: list, embed_dim: int = 256, num_classes: int = 19):\n        super().__init__()\n        for i, dim in enumerate(dims):\n            self.add_module(f\"linear_c{i+1}\", MLP(dim, embed_dim))\n\n        # one embed_dim-wide map per level is concatenated before fusion\n        self.linear_fuse = ConvModule(embed_dim*len(dims), embed_dim)\n        self.linear_pred = nn.Conv2d(embed_dim, num_classes, 1)\n        self.dropout = nn.Dropout2d(0.1)\n\n    def forward(self, features: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tensor:\n        B, _, H, W = features[0].shape\n        outs = [self.linear_c1(features[0]).permute(0, 2, 1).reshape(B, -1, *features[0].shape[-2:])]\n\n        for i, feature in enumerate(features[1:]):\n            # plain attribute lookup; no need to eval() generated source\n            cf = getattr(self, f\"linear_c{i+2}\")(feature).permute(0, 2, 1).reshape(B, -1, *feature.shape[-2:])\n            outs.append(F.interpolate(cf, size=(H, W), mode='bilinear', align_corners=False))\n\n        seg = self.linear_fuse(torch.cat(outs[::-1], dim=1))\n        seg = self.linear_pred(self.dropout(seg))\n        return seg"
  },
  {
    "path": "semseg/models/heads/sfnet.py",
    "content": "import torch\nfrom torch import nn, Tensor\nfrom torch.nn import functional as F\nfrom semseg.models.layers import ConvModule\nfrom semseg.models.modules import PPM\n\n\nclass AlignedModule(nn.Module):\n    \"\"\"Flow alignment module.\n\n    Predicts a 2-channel flow field from the channel-reduced low-level feature and the upsampled,\n    channel-reduced high-level feature, then warps the original high-level feature to the\n    low-level feature's resolution with that flow.\n\n    Args:\n        c1: channels of both input features\n        c2: reduced channel count used for flow prediction\n        k: kernel size of the flow-prediction conv\n    \"\"\"\n    def __init__(self, c1, c2, k=3):\n        super().__init__()\n        self.down_h = nn.Conv2d(c1, c2, 1, bias=False)\n        self.down_l = nn.Conv2d(c1, c2, 1, bias=False)\n        # padding k // 2 keeps the flow at the input resolution for any odd k (equals the former 1 when k=3)\n        self.flow_make = nn.Conv2d(c2 * 2, 2, k, 1, k // 2, bias=False)\n\n    def forward(self, low_feature: Tensor, high_feature: Tensor) -> Tensor:\n        high_feature_origin = high_feature\n        H, W = low_feature.shape[-2:]\n        low_feature = self.down_l(low_feature)\n        high_feature = self.down_h(high_feature)\n        high_feature = F.interpolate(high_feature, size=(H, W), mode='bilinear', align_corners=True)\n        flow = self.flow_make(torch.cat([high_feature, low_feature], dim=1))\n        high_feature = self.flow_warp(high_feature_origin, flow, (H, W))\n        return high_feature\n\n    def flow_warp(self, x: Tensor, flow: Tensor, size: tuple) -> Tensor:\n        \"\"\"Samples x on a regular grid of the given size, displaced by flow.\n\n        Args:\n            x: feature map to warp, [B, C, h, w]\n            flow: offsets, [B, 2, H, W]\n            size: output size (H, W)\n        \"\"\"\n        out_h, out_w = size\n        # grid_sample reads the last grid dim as (x, y), and the grid below is built in that order,\n        # so the per-axis normaliser has to be (W, H), not (H, W)\n        norm = torch.tensor([[[[out_w, out_h]]]]).type_as(x).to(x.device)\n        H = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w)\n        W = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1)\n        grid = torch.cat((W.unsqueeze(2), H.unsqueeze(2)), dim=2)\n        grid = grid.repeat(x.shape[0], 1, 1, 1).type_as(x).to(x.device)\n        grid = grid + flow.permute(0, 2, 3, 1) / norm\n        output = F.grid_sample(x, grid, align_corners=False)\n        return output\n\n\nclass SFHead(nn.Module):\n    \"\"\"SFNet decode head.\n\n    Applies PPM to the last feature, then walks the remaining levels from coarse to fine: each level is\n    projected to `channel` channels and summed with the flow-aligned coarser map. All levels are\n    upsampled to the finest resolution, concatenated and mapped to class logits.\n\n    Args:\n        in_channels: channel count of each input feature level\n        channel: working channel count of the head\n        num_classes: number of output classes\n        scales: pooling scales passed to PPM\n    \"\"\"\n    def __init__(self, in_channels, channel=256, num_classes=19, scales=(1, 2, 3, 6)):\n        super().__init__()\n        self.ppm = PPM(in_channels[-1], channel, scales)\n\n        self.fpn_in = nn.ModuleList([])\n        self.fpn_out = nn.ModuleList([])\n        self.fpn_out_align = nn.ModuleList([])\n\n        for in_ch in in_channels[:-1]:\n            self.fpn_in.append(ConvModule(in_ch, channel, 1))\n            self.fpn_out.append(ConvModule(channel, channel, 3, 1, 1))\n            self.fpn_out_align.append(AlignedModule(channel, channel//2))\n\n        self.bottleneck = ConvModule(len(in_channels) * channel, channel, 3, 1, 1)\n        self.dropout = nn.Dropout2d(0.1)\n        self.conv_seg = nn.Conv2d(channel, num_classes, 1)\n\n    def forward(self, features: list) -> Tensor:\n        f = self.ppm(features[-1])\n        fpn_features = [f]\n\n        # top-down pass: coarsest remaining level first\n        for i in reversed(range(len(features) - 1)):\n            feature = self.fpn_in[i](features[i])\n            f = feature + self.fpn_out_align[i](feature, f)\n            fpn_features.append(self.fpn_out[i](f))\n\n        fpn_features.reverse()\n\n        for i in range(1, len(fpn_features)):\n            fpn_features[i] = F.interpolate(fpn_features[i], size=fpn_features[0].shape[-2:], mode='bilinear', align_corners=True)\n\n        output = self.bottleneck(torch.cat(fpn_features, dim=1))\n        output = self.conv_seg(self.dropout(output))\n        return output\n\n"
  },
  {
    "path": "semseg/models/heads/upernet.py",
    "content": "import torch\nfrom torch import nn, Tensor\nfrom torch.nn import functional as F\nfrom typing import Tuple\nfrom semseg.models.layers import ConvModule\nfrom semseg.models.modules import PPM\n\n\nclass UPerHead(nn.Module):\n    \"\"\"Unified Perceptual Parsing for Scene Understanding\n    https://arxiv.org/abs/1807.10221\n    scales: Pooling scales used in PPM module applied on the last feature\n    \"\"\"\n    def __init__(self, in_channels, channel=128, num_classes: int = 19, scales=(1, 2, 3, 6)):\n        super().__init__()\n        # PPM Module\n        self.ppm = PPM(in_channels[-1], channel, scales)\n\n        # FPN Module\n        self.fpn_in = nn.ModuleList()\n        self.fpn_out = nn.ModuleList()\n\n        for in_ch in in_channels[:-1]: # skip the top layer\n            self.fpn_in.append(ConvModule(in_ch, channel, 1))\n            self.fpn_out.append(ConvModule(channel, channel, 3, 1, 1))\n\n        self.bottleneck = ConvModule(len(in_channels)*channel, channel, 3, 1, 1)\n        self.dropout = nn.Dropout2d(0.1)\n        self.conv_seg = nn.Conv2d(channel, num_classes, 1)\n\n\n    def forward(self, features: Tuple[Tensor, Tensor, Tensor, Tensor]) -> Tensor:\n        f = self.ppm(features[-1])\n        fpn_features = [f]\n\n        for i in reversed(range(len(features)-1)):\n            feature = self.fpn_in[i](features[i])\n            f = feature + F.interpolate(f, size=feature.shape[-2:], mode='bilinear', align_corners=False)\n            fpn_features.append(self.fpn_out[i](f))\n\n        fpn_features.reverse()\n        for i in range(1, len(features)):\n            fpn_features[i] = F.interpolate(fpn_features[i], size=fpn_features[0].shape[-2:], mode='bilinear', align_corners=False)\n \n        output = self.bottleneck(torch.cat(fpn_features, dim=1))\n        output = self.conv_seg(self.dropout(output))\n        return output\n\n\nif __name__ == '__main__':\n    model = UPerHead([64, 128, 256, 512], 128)\n    x1 = torch.randn(2, 64, 
56, 56)\n    x2 = torch.randn(2, 128, 28, 28)\n    x3 = torch.randn(2, 256, 14, 14)\n    x4 = torch.randn(2, 512, 7, 7)\n    y = model([x1, x2, x3, x4])\n    print(y.shape)"
  },
  {
    "path": "semseg/models/layers/__init__.py",
    "content": "from .common import *\nfrom .initialize import *"
  },
  {
    "path": "semseg/models/layers/common.py",
    "content": "import torch\nfrom torch import nn, Tensor\n\n\nclass ConvModule(nn.Sequential):\n    \"\"\"Conv2d (bias-free) -> BatchNorm2d -> in-place ReLU, packaged as an nn.Sequential.\n\n    Args:\n        c1: input channels\n        c2: output channels\n        k: kernel size\n        s: stride\n        p: padding\n        d: dilation\n        g: groups\n    \"\"\"\n    def __init__(self, c1, c2, k, s=1, p=0, d=1, g=1):\n        super().__init__(\n            nn.Conv2d(c1, c2, k, s, p, d, g, bias=False),\n            nn.BatchNorm2d(c2),\n            nn.ReLU(True)\n        )\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\n    Copied from timm\n    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,\n    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...\n    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for\n    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use\n    'survival rate' as the argument.\n\n    Args:\n        p: probability of dropping a sample's path; None or 0 disables dropping\n    \"\"\"\n    def __init__(self, p: float = None):\n        super().__init__()\n        self.p = p\n\n    def forward(self, x: Tensor) -> Tensor:\n        # p=None is the constructor default: treat it like p=0 (identity) instead of\n        # failing on `1 - None` the first time the module runs in training mode\n        if self.p is None or self.p == 0. or not self.training:\n            return x\n        kp = 1 - self.p\n        # one random draw per sample, broadcast over all remaining dims\n        shape = (x.shape[0],) + (1,) * (x.ndim - 1)\n        random_tensor = kp + torch.rand(shape, dtype=x.dtype, device=x.device)\n        random_tensor.floor_()  # binarize\n        # kept samples are rescaled by 1/kp\n        return x.div(kp) * random_tensor"
  },
  {
    "path": "semseg/models/layers/initialize.py",
    "content": "import torch\nimport math\nimport warnings\nfrom torch import nn, Tensor\n\n\ndef _no_grad_trunc_normal_(tensor, mean, std, a, b):\n    # Cut & paste from PyTorch official master until it's in a few official releases - RW\n    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf\n    def norm_cdf(x):\n        # Computes standard normal cumulative distribution function\n        return (1. + math.erf(x / math.sqrt(2.))) / 2.\n\n    if (mean < a - 2 * std) or (mean > b + 2 * std):\n        warnings.warn(\"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. \"\n                      \"The distribution of values may be incorrect.\",\n                      stacklevel=2)\n\n    with torch.no_grad():\n        # Values are generated by using a truncated uniform distribution and\n        # then using the inverse CDF for the normal distribution.\n        # Get upper and lower cdf values\n        l = norm_cdf((a - mean) / std)\n        u = norm_cdf((b - mean) / std)\n\n        # Uniformly fill tensor with values from [l, u], then translate to\n        # [2l-1, 2u-1].\n        tensor.uniform_(2 * l - 1, 2 * u - 1)\n\n        # Use inverse cdf transform for normal distribution to get truncated\n        # standard normal\n        tensor.erfinv_()\n\n        # Transform to proper mean, std\n        tensor.mul_(std * math.sqrt(2.))\n        tensor.add_(mean)\n\n        # Clamp to ensure it's in the proper range\n        tensor.clamp_(min=a, max=b)\n        return tensor\n\n\ndef trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):\n    # type: (Tensor, float, float, float, float) -> Tensor\n    r\"\"\"Fills the input Tensor with values drawn from a truncated\n    normal distribution. The values are effectively drawn from the\n    normal distribution :math:`\\mathcal{N}(\\text{mean}, \\text{std}^2)`\n    with values outside :math:`[a, b]` redrawn until they are within\n    the bounds. 
The method used for generating the random values works\n    best when :math:`a \\leq \\text{mean} \\leq b`.\n    Args:\n        tensor: an n-dimensional `torch.Tensor`\n        mean: the mean of the normal distribution\n        std: the standard deviation of the normal distribution\n        a: the minimum cutoff value\n        b: the maximum cutoff value\n    Examples:\n        >>> w = torch.empty(3, 5)\n        >>> nn.init.trunc_normal_(w)\n    \"\"\"\n    return _no_grad_trunc_normal_(tensor, mean, std, a, b)\n"
  },
  {
    "path": "semseg/models/modules/__init__.py",
    "content": "from .ppm import PPM\nfrom .psa import PSAP, PSAS\n\n__all__ = ['PPM', 'PSAP', 'PSAS']"
  },
  {
    "path": "semseg/models/modules/crossatt.py",
    "content": "import torch\nfrom torch import nn\nfrom einops import rearrange\nfrom torch import einsum\n\ndef exists(val):\n    return val is not None\n\ndef default(val, d):\n    return val if exists(val) else d\n\ndef stable_softmax(t, dim = -1):\n    t = t - t.amax(dim = dim, keepdim = True)\n    return t.softmax(dim = dim)\n\n# bidirectional cross attention - have two sequences attend to each other with 1 attention step\n\nclass BidirectionalCrossAttention(nn.Module):\n    def __init__(self, dim, heads=8, dim_head=64, context_dim=None, dropout=0., talking_heads=False, prenorm=True,):\n        super().__init__()\n        context_dim = default(context_dim, dim)\n\n        self.norm = nn.LayerNorm(dim) if prenorm else nn.Identity()\n        self.context_norm = nn.LayerNorm(context_dim) if prenorm else nn.Identity()\n\n        self.heads = heads\n        self.scale = dim_head ** -0.5\n        inner_dim = dim_head * heads\n\n        self.dropout = nn.Dropout(dropout)\n        self.context_dropout = nn.Dropout(dropout)\n\n        self.to_qk = nn.Linear(dim, inner_dim, bias = False)\n        self.context_to_qk = nn.Linear(context_dim, inner_dim, bias = False)\n\n        self.to_v = nn.Linear(dim, inner_dim, bias = False)\n        self.context_to_v = nn.Linear(context_dim, inner_dim, bias = False)\n\n        self.to_out = nn.Linear(inner_dim, dim)\n        self.context_to_out = nn.Linear(inner_dim, context_dim)\n\n        self.talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()\n        self.context_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) if talking_heads else nn.Identity()\n\n    def forward(self, x, context, return_attn=False, rel_pos_bias=None):\n        b, i, j, h, device = x.shape[0], x.shape[-2], context.shape[-2], self.heads, x.device\n\n        x = self.norm(x)\n        context = self.context_norm(context)\n\n        # get shared query/keys and values for sequence and context\n        qk, v = 
self.to_qk(x), self.to_v(x)\n        context_qk, context_v = self.context_to_qk(context), self.context_to_v(context)\n\n        # split out head\n        qk, context_qk, v, context_v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (qk, context_qk, v, context_v))\n\n        # get similarities\n        sim = einsum('b h i d, b h j d -> b h i j', qk, context_qk) * self.scale\n\n        # relative positional bias, if supplied\n        if exists(rel_pos_bias):\n            sim = sim + rel_pos_bias\n\n        # get attention along both sequence length and context length dimensions\n        # shared similarity matrix\n        attn = stable_softmax(sim, dim = -1)\n        context_attn = stable_softmax(sim, dim = -2)\n\n        # dropouts\n        attn = self.dropout(attn)\n        context_attn = self.context_dropout(context_attn)\n\n        # talking heads\n        attn = self.talking_heads(attn)\n        context_attn = self.context_talking_heads(context_attn)\n\n        # src sequence aggregates values from context, context aggregates values from src sequence\n        out = einsum('b h i j, b h j d -> b h i d', attn, context_v)\n        context_out = einsum('b h j i, b h j d -> b h i d', context_attn, v)\n\n        # merge heads and combine out\n        out, context_out = map(lambda t: rearrange(t, 'b h n d -> b n (h d)'), (out, context_out))\n\n        out = self.to_out(out)\n        context_out = self.context_to_out(context_out)\n\n        if return_attn:\n            return out, context_out, attn, context_attn\n\n        return out, context_out\n\nif __name__ == '__main__':\n    video = torch.randn(1, 4096, 512)\n    audio = torch.randn(1, 8192, 386)\n\n    joint_cross_attn = BidirectionalCrossAttention(dim = 512, heads = 8, dim_head = 64, context_dim = 386)\n    video_out, audio_out = joint_cross_attn(video, audio)\n\n    # attended output should have the same shape as input\n    assert video_out.shape == video.shape\n    assert audio_out.shape == 
audio.shape"
  },
  {
    "path": "semseg/models/modules/ffm.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom timm.models.layers import trunc_normal_\nimport math\n\n\n# Feature Rectify Module\nclass ChannelWeights(nn.Module):\n    def __init__(self, dim, reduction=1):\n        super(ChannelWeights, self).__init__()\n        self.dim = dim\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.max_pool = nn.AdaptiveMaxPool2d(1)\n        self.mlp = nn.Sequential(\n                    nn.Linear(self.dim * 4, self.dim * 4 // reduction),\n                    nn.ReLU(inplace=True),\n                    nn.Linear(self.dim * 4 // reduction, self.dim * 2), \n                    nn.Sigmoid())\n\n    def forward(self, x1, x2):\n        B, _, H, W = x1.shape\n        x = torch.cat((x1, x2), dim=1)\n        avg = self.avg_pool(x).view(B, self.dim * 2)\n        max = self.max_pool(x).view(B, self.dim * 2)\n        y = torch.cat((avg, max), dim=1) # B 4C\n        y = self.mlp(y).view(B, self.dim * 2, 1)\n        channel_weights = y.reshape(B, 2, self.dim, 1, 1).permute(1, 0, 2, 3, 4) # 2 B C 1 1\n        return channel_weights\n\n\nclass SpatialWeights(nn.Module):\n    def __init__(self, dim, reduction=1):\n        super(SpatialWeights, self).__init__()\n        self.dim = dim\n        self.mlp = nn.Sequential(\n                    nn.Conv2d(self.dim * 2, self.dim // reduction, kernel_size=1),\n                    nn.ReLU(inplace=True),\n                    nn.Conv2d(self.dim // reduction, 2, kernel_size=1), \n                    nn.Sigmoid())\n\n    def forward(self, x1, x2):\n        B, _, H, W = x1.shape\n        x = torch.cat((x1, x2), dim=1) # B 2C H W\n        spatial_weights = self.mlp(x).reshape(B, 2, 1, H, W).permute(1, 0, 2, 3, 4) # 2 B 1 H W\n        return spatial_weights\n\n\nclass FeatureRectifyModule(nn.Module):\n    def __init__(self, dim, reduction=1, lambda_c=.5, lambda_s=.5):\n        super(FeatureRectifyModule, self).__init__()\n        self.lambda_c = lambda_c\n        self.lambda_s = lambda_s\n     
   self.channel_weights = ChannelWeights(dim=dim, reduction=reduction)\n        self.spatial_weights = SpatialWeights(dim=dim, reduction=reduction)\n    \n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n    \n    def forward(self, x1, x2):\n        channel_weights = self.channel_weights(x1, x2)\n        spatial_weights = self.spatial_weights(x1, x2)\n        out_x1 = x1 + self.lambda_c * channel_weights[1] * x2 + self.lambda_s * spatial_weights[1] * x2\n        out_x2 = x2 + self.lambda_c * channel_weights[0] * x1 + self.lambda_s * spatial_weights[0] * x1\n        return out_x1, out_x2 \n\n\n# Stage 1\nclass CrossAttention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None):\n        super(CrossAttention, self).__init__()\n        assert dim % num_heads == 0, f\"dim {dim} should be divided by num_heads {num_heads}.\"\n\n        self.dim = dim\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n        self.kv1 = nn.Linear(dim, dim * 2, bias=qkv_bias)\n        self.kv2 = nn.Linear(dim, dim * 2, bias=qkv_bias)\n\n    def forward(self, x1, x2):\n        B, N, C = x1.shape\n        q1 = x1.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()\n        q2 = x2.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3).contiguous()\n     
   k1, v1 = self.kv1(x1).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()\n        k2, v2 = self.kv2(x2).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()\n\n        ctx1 = (k1.transpose(-2, -1) @ v1) * self.scale\n        ctx1 = ctx1.softmax(dim=-2)\n        ctx2 = (k2.transpose(-2, -1) @ v2) * self.scale\n        ctx2 = ctx2.softmax(dim=-2)\n\n        x1 = (q1 @ ctx2).permute(0, 2, 1, 3).reshape(B, N, C).contiguous() \n        x2 = (q2 @ ctx1).permute(0, 2, 1, 3).reshape(B, N, C).contiguous() \n\n        return x1, x2\n\n\nclass CrossPath(nn.Module):\n    def __init__(self, dim, reduction=1, num_heads=None, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.channel_proj1 = nn.Linear(dim, dim // reduction * 2)\n        self.channel_proj2 = nn.Linear(dim, dim // reduction * 2)\n        self.act1 = nn.ReLU(inplace=True)\n        self.act2 = nn.ReLU(inplace=True)\n        self.cross_attn = CrossAttention(dim // reduction, num_heads=num_heads)\n        self.end_proj1 = nn.Linear(dim // reduction * 2, dim)\n        self.end_proj2 = nn.Linear(dim // reduction * 2, dim)\n        self.norm1 = norm_layer(dim)\n        self.norm2 = norm_layer(dim)\n\n    def forward(self, x1, x2):\n        y1, u1 = self.act1(self.channel_proj1(x1)).chunk(2, dim=-1)\n        y2, u2 = self.act2(self.channel_proj2(x2)).chunk(2, dim=-1)\n        v1, v2 = self.cross_attn(u1, u2)\n        y1 = torch.cat((y1, v1), dim=-1)\n        y2 = torch.cat((y2, v2), dim=-1)\n        out_x1 = self.norm1(x1 + self.end_proj1(y1))\n        out_x2 = self.norm2(x2 + self.end_proj2(y2))\n        return out_x1, out_x2\n\n\n# Stage 2\nclass ChannelEmbed(nn.Module):\n    def __init__(self, in_channels, out_channels, reduction=1, norm_layer=nn.BatchNorm2d):\n        super(ChannelEmbed, self).__init__()\n        self.out_channels = out_channels\n        self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, 
bias=False)\n        self.channel_embed = nn.Sequential(\n                        nn.Conv2d(in_channels, out_channels//reduction, kernel_size=1, bias=True),\n                        nn.Conv2d(out_channels//reduction, out_channels//reduction, kernel_size=3, stride=1, padding=1, bias=True, groups=out_channels//reduction),\n                        nn.ReLU(inplace=True),\n                        nn.Conv2d(out_channels//reduction, out_channels, kernel_size=1, bias=True),\n                        norm_layer(out_channels) \n                        )\n        self.norm = norm_layer(out_channels)\n        \n    def forward(self, x, H, W):\n        B, N, _C = x.shape\n        x = x.permute(0, 2, 1).reshape(B, _C, H, W).contiguous()\n        residual = self.residual(x)\n        x = self.channel_embed(x)\n        out = self.norm(residual + x)\n        return out\n\n\nclass FeatureFusionModule(nn.Module):\n    def __init__(self, dim, reduction=1, num_heads=None, norm_layer=nn.BatchNorm2d):\n        super().__init__()\n        self.cross = CrossPath(dim=dim, reduction=reduction, num_heads=num_heads)\n        self.channel_emb = ChannelEmbed(in_channels=dim*2, out_channels=dim, reduction=reduction, norm_layer=norm_layer)\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n\n    def forward(self, x1, x2):\n        B, C, H, W = x1.shape\n        x1 = 
x1.flatten(2).transpose(1, 2)\n        x2 = x2.flatten(2).transpose(1, 2)\n        x1, x2 = self.cross(x1, x2) \n        merge = torch.cat((x1, x2), dim=-1)\n        merge = self.channel_emb(merge, H, W)\n        \n        return merge"
  },
  {
    "path": "semseg/models/modules/mspa.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.autograd\nfrom timm.models.layers import DropPath\nfrom fvcore.nn import flop_count_table, FlopCountAnalysis\nfrom mmcv.cnn import build_norm_layer\n\nclass DWConv(nn.Module):\n    \"\"\"3x3 depthwise convolution (groups == dim), stride 1, padding 1.\"\"\"\n    def __init__(self, dim=768):\n        super(DWConv, self).__init__()\n        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)\n\n    def forward(self, x):\n        x = self.dwconv(x)\n        return x\n\nclass Mlp(nn.Module):\n    \"\"\"Convolutional MLP: 1x1 conv -> depthwise 3x3 conv -> activation -> dropout -> 1x1 conv -> dropout.\n\n    hidden_features and out_features fall back to in_features when not given.\n    \"\"\"\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)\n        self.dwconv = DWConv(hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n\n        x = self.dwconv(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n\n        return x\n\nclass MSPoolAttention(nn.Module):\n    \"\"\"Multi-scale pooling attention.\n\n    A 7x7 depthwise conv of the input and its 3/7/11 average pools are summed, passed through a 1x1 conv\n    and a sigmoid, and used to gate the input; the input is then added back as a residual.\n    \"\"\"\n    def __init__(self, dim):\n        super().__init__()\n        pools = [3,7,11]\n        self.conv0 = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)\n        self.pool1 = nn.AvgPool2d(pools[0], stride=1, padding=pools[0]//2, count_include_pad=False)\n        self.pool2 = nn.AvgPool2d(pools[1], stride=1, padding=pools[1]//2, count_include_pad=False)\n        self.pool3 = nn.AvgPool2d(pools[2], stride=1, padding=pools[2]//2, count_include_pad=False)\n        self.conv4 = nn.Conv2d(dim, dim, 1)\n        self.sigmoid = nn.Sigmoid()\n\n    def forward(self, x):\n        u = x.clone()\n        x_in = self.conv0(x)\n        x_1 = self.pool1(x_in)\n        x_2 = self.pool2(x_in)\n        x_3 = self.pool3(x_in)\n        x_out = self.sigmoid(self.conv4(x_in + x_1 + x_2 + x_3)) * u\n        return x_out + u\n\nclass MSPABlock(nn.Module):\n    \"\"\"Block with two layer-scaled branches: norm -> MSPoolAttention and norm -> Mlp.\n\n    With is_channel_mix (always True here) the Mlp branch is added to a channel-reweighted copy of x\n    instead of x itself.\n\n    Args:\n        dim: number of channels\n        mlp_ratio: hidden width of the Mlp relative to dim\n        drop: dropout rate inside the Mlp\n        drop_path: stochastic-depth rate (0 disables it)\n        act_layer: activation class used by the Mlp\n        norm_cfg: config handed to mmcv's build_norm_layer\n    \"\"\"\n    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_cfg=dict(type='BN', requires_grad=True)):\n        super().__init__()\n        self.norm1 = build_norm_layer(norm_cfg, dim)[1]\n        self.attn = MSPoolAttention(dim)\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = build_norm_layer(norm_cfg, dim)[1]\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n        layer_scale_init_value = 1e-2\n        self.layer_scale_1 = nn.Parameter(\n            layer_scale_init_value * torch.ones((dim)), requires_grad=True)\n        self.layer_scale_2 = nn.Parameter(\n            layer_scale_init_value * torch.ones((dim)), requires_grad=True)\n\n        self.is_channel_mix = True\n        if self.is_channel_mix:\n            self.avg_pool = nn.AdaptiveAvgPool2d(1)\n            self.c_nets = nn.Sequential(\n                nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False),\n                nn.Sigmoid())\n\n    def forward(self, x):\n        x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))\n\n        if self.is_channel_mix:\n            # per-channel gate: global average pool, then a 1-D conv across the channel axis and a sigmoid\n            x_c = self.avg_pool(x)\n            x_c = self.c_nets(x_c.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)\n            x_c = x_c.expand_as(x)\n            x_c_mix = x_c * x\n            x_mlp = self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))\n            x = x_c_mix + x_mlp\n        else:\n            x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))\n        return x\n\n\nif __name__ == '__main__':\n    # MSPABlock is convolutional, so the demo input must be [B, C, H, W] with C == dim\n    x = torch.zeros(2, 64, 224, 224)\n    c1 = MSPABlock(64)\n    outs = c1(x)\n    print(outs.shape)\n    print(c1)\n    print(flop_count_table(FlopCountAnalysis(c1, x)))\n"
  },
  {
    "path": "semseg/models/modules/ppm.py",
    "content": "import torch\nfrom torch import nn, Tensor\nfrom torch.nn import functional as F\nfrom semseg.models.layers import ConvModule\n\n\nclass PPM(nn.Module):\n    \"\"\"Pyramid Pooling Module in PSPNet\n    \"\"\"\n    def __init__(self, c1, c2=128, scales=(1, 2, 3, 6)):\n        super().__init__()\n        self.stages = nn.ModuleList([\n            nn.Sequential(\n                nn.AdaptiveAvgPool2d(scale),\n                ConvModule(c1, c2, 1)\n                # ConvModule(c1, c2, 1, p=1)\n            )\n        for scale in scales])\n\n        self.bottleneck = ConvModule(c1 + c2 * len(scales), c2, 3, 1, 1)\n\n    def forward(self, x: Tensor) -> Tensor:\n        outs = []\n        for stage in self.stages:\n            outs.append(F.interpolate(stage(x), size=x.shape[-2:], mode='bilinear', align_corners=True))\n\n        outs = [x] + outs[::-1]\n        out = self.bottleneck(torch.cat(outs, dim=1))\n        return out\n\n\nif __name__ == '__main__':\n    model = PPM(512, 128)\n    x = torch.randn(2, 512, 7, 7)   \n    y = model(x)\n    print(y.shape)  # [2, 128, 7, 7]"
  },
  {
    "path": "semseg/models/modules/psa.py",
    "content": "import torch\nfrom torch import nn, Tensor\nfrom torch.nn import functional as F\n\n\nclass PSAP(nn.Module):\n    def __init__(self, c1, c2):\n        super().__init__()\n        ch = c2 // 2\n        self.conv_q_right = nn.Conv2d(c1, 1, 1, bias=False)\n        self.conv_v_right = nn.Conv2d(c1, ch, 1, bias=False)\n        self.conv_up = nn.Conv2d(ch, c2, 1, bias=False)\n\n        self.conv_q_left = nn.Conv2d(c1, ch, 1, bias=False)\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.conv_v_left = nn.Conv2d(c1, ch, 1, bias=False)\n\n    def spatial_pool(self, x: Tensor) -> Tensor:\n        input_x = self.conv_v_right(x)               # [B, C, H, W]\n        context_mask = self.conv_q_right(x)         # [B, 1, H, W]\n        B, C, _, _ = input_x.shape\n\n        input_x = input_x.view(B, C, -1)\n        context_mask = context_mask.view(B, 1, -1).softmax(dim=2)\n\n        context = input_x @ context_mask.transpose(1, 2)\n        context = self.conv_up(context.unsqueeze(-1)).sigmoid()\n        x *= context\n        return x\n\n    def channel_pool(self, x: Tensor) -> Tensor:\n        g_x = self.conv_q_left(x)\n        B, C, H, W = g_x.shape\n        avg_x = self.avg_pool(g_x).view(B, C, -1).permute(0, 2, 1)\n        theta_x = self.conv_v_left(x).view(B, C, -1)\n\n        context = avg_x @ theta_x\n        context = context.softmax(dim=2).view(B, 1, H, W).sigmoid()\n        x *= context\n        return x\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self.spatial_pool(x) + self.channel_pool(x)\n\n\n\nclass PSAS(nn.Module):\n    def __init__(self, c1, c2):\n        super().__init__()\n        ch = c2 // 2\n        self.conv_q_right = nn.Conv2d(c1, 1, 1, bias=False)\n        self.conv_v_right = nn.Conv2d(c1, ch, 1, bias=False)\n        self.conv_up = nn.Sequential(\n            nn.Conv2d(ch, ch // 4, 1),\n            nn.LayerNorm([ch // 4, 1, 1]),\n            nn.ReLU(),\n            nn.Conv2d(ch // 4, c2, 1)\n        )\n\n        
self.conv_q_left = nn.Conv2d(c1, ch, 1, bias=False)\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.conv_v_left = nn.Conv2d(c1, ch, 1, bias=False)\n\n    def spatial_pool(self, x: Tensor) -> Tensor:\n        input_x = self.conv_v_right(x)               # [B, C, H, W]\n        context_mask = self.conv_q_right(x)     # [B, 1, H, W]\n        B, C, _, _ = input_x.shape\n\n        input_x = input_x.view(B, C, -1)\n        context_mask = context_mask.view(B, 1, -1).softmax(dim=2)\n\n        context = input_x @ context_mask.transpose(1, 2)\n        context = self.conv_up(context.unsqueeze(-1)).sigmoid()\n        x *= context\n        return x\n\n    def channel_pool(self, x: Tensor) -> Tensor:\n        g_x = self.conv_q_left(x)\n        B, C, H, W = g_x.shape\n        avg_x = self.avg_pool(g_x).view(B, C, -1).permute(0, 2, 1)\n        theta_x = self.conv_v_left(x).view(B, C, -1).softmax(dim=2)\n\n        context = avg_x @ theta_x\n        context = context.view(B, 1, H, W).sigmoid()\n        x *= context\n        return x\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self.channel_pool(self.spatial_pool(x))\n\n\n\"\"\"\nPSA Module Usage\n\nclass BasicBlock(nn.Module):\n    # 2 Layer No Expansion Block\n    expansion: int = 1\n    def __init__(self, c1, c2, s=1, downsample= None) -> None:\n        super().__init__()\n        self.conv1 = nn.Conv2d(c1, c2, 3, s, 1, bias=False)\n        self.bn1 = nn.BatchNorm2d(c2)\n        self.deattn = PSAS(c2, c2)\n        self.conv2 = nn.Conv2d(c2, c2, 3, 1, 1, bias=False)\n        self.bn2 = nn.BatchNorm2d(c2)\n        self.downsample = downsample\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = self.deattn(out)\n        out = self.bn2(self.conv2(out))\n        if self.downsample is not None: identity = self.downsample(x)\n        out += identity\n        return F.relu(out)\n\n\nclass Bottleneck(nn.Module):\n    # 3 Layer 4x 
Expansion Block\n    expansion: int = 4\n    def __init__(self, c1, c2, s=1, downsample=None) -> None:\n        super().__init__()\n        self.conv1 = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)\n        self.bn1 = nn.BatchNorm2d(c2)\n        self.conv2 = nn.Conv2d(c2, c2, 3, s, 1, bias=False)\n        self.bn2 = nn.BatchNorm2d(c2)\n        self.deattn = PSAP(c2, c2)\n        self.conv3 = nn.Conv2d(c2, c2 * self.expansion, 1, 1, 0, bias=False)\n        self.bn3 = nn.BatchNorm2d(c2 * self.expansion)\n        self.downsample = downsample\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n        out = F.relu(self.bn1(self.conv1(x)))\n        out = F.relu(self.bn2(self.conv2(out)))\n        out = self.deattn(out)\n        out = self.bn3(self.conv3(out))\n        if self.downsample is not None: identity = self.downsample(x)\n        out += identity\n        return F.relu(out)\n\n\nresnet_settings = {\n    '18': [BasicBlock, [2, 2, 2, 2]],\n    '34': [BasicBlock, [3, 4, 6, 3]],\n    '50': [Bottleneck, [3, 4, 6, 3]],\n    '101': [Bottleneck, [3, 4, 23, 3]],\n    '152': [Bottleneck, [3, 8, 36, 3]]\n}\n\n\nclass ResNet(nn.Module):\n    def __init__(self, model_name: str = '50') -> None:\n        super().__init__()\n        assert model_name in resnet_settings.keys(), f\"ResNet model name should be in {list(resnet_settings.keys())}\"\n        block, depths = resnet_settings[model_name]\n\n        self.inplanes = 64\n        self.conv1 = nn.Conv2d(3, self.inplanes, 7, 2, 3, bias=False)\n        self.bn1 = nn.BatchNorm2d(self.inplanes)\n        self.maxpool = nn.MaxPool2d(3, 2, 1)\n\n        self.layer1 = self._make_layer(block, 64, depths[0], s=1)\n        self.layer2 = self._make_layer(block, 128, depths[1], s=2)\n        self.layer3 = self._make_layer(block, 256, depths[2], s=2)\n        self.layer4 = self._make_layer(block, 512, depths[3], s=2)\n\n\n    def _make_layer(self, block, planes, depth, s=1) -> nn.Sequential:\n        downsample = None\n        
if s != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                nn.Conv2d(self.inplanes, planes * block.expansion, 1, s, bias=False),\n                nn.BatchNorm2d(planes * block.expansion)\n            )\n        layers = nn.Sequential(\n            block(self.inplanes, planes, s, downsample),\n            *[block(planes * block.expansion, planes) for _ in range(1, depth)]\n        )\n        self.inplanes = planes * block.expansion\n        return layers\n\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = self.maxpool(F.relu(self.bn1(self.conv1(x))))   # [1, 64, H/4, W/4]\n        x1 = self.layer1(x)  # [1, 64/256, H/4, W/4]   \n        x2 = self.layer2(x1)  # [1, 128/512, H/8, W/8]\n        x3 = self.layer3(x2)  # [1, 256/1024, H/16, W/16]\n        x4 = self.layer4(x3)  # [1, 512/2048, H/32, W/32]\n        return x1, x2, x3, x4\n\n\nif __name__ == '__main__':\n    model = ResNet('18')\n    x = torch.zeros(2, 3, 224, 224)\n    outs = model(x)\n    for y in outs:\n        print(y.shape) \n\n\"\"\"\n        \n"
  },
  {
    "path": "semseg/optimizers.py",
    "content": "from torch import nn\nfrom torch.optim import AdamW, SGD\n\n\ndef get_optimizer(model: nn.Module, optimizer: str, lr: float, weight_decay: float = 0.01):\n    wd_params, nwd_params = [], []\n    for p in model.parameters():\n        if p.requires_grad:\n            if p.dim() == 1:\n                nwd_params.append(p)\n            else:\n                wd_params.append(p)\n    \n    params = [\n        {\"params\": wd_params},\n        {\"params\": nwd_params, \"weight_decay\": 0}\n    ]\n\n    if optimizer == 'adamw':\n        return AdamW(params, lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=weight_decay)\n    else:\n        return SGD(params, lr, momentum=0.9, weight_decay=weight_decay)"
  },
  {
    "path": "semseg/schedulers.py",
    "content": "import torch\nimport math\nfrom torch.optim.lr_scheduler import _LRScheduler\n\n\nclass PolyLR(_LRScheduler):\n    def __init__(self, optimizer, max_iter, decay_iter=1, power=0.9, last_epoch=-1) -> None:\n        self.decay_iter = decay_iter\n        self.max_iter = max_iter\n        self.power = power\n        super().__init__(optimizer, last_epoch=last_epoch)\n\n    def get_lr(self):\n        if self.last_epoch % self.decay_iter or self.last_epoch % self.max_iter:\n            return self.base_lrs\n        else:\n            factor = (1 - self.last_epoch / float(self.max_iter)) ** self.power\n            return [factor*lr for lr in self.base_lrs]\n\n\nclass WarmupLR(_LRScheduler):\n    def __init__(self, optimizer, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:\n        self.warmup_iter = warmup_iter\n        self.warmup_ratio = warmup_ratio\n        self.warmup = warmup\n        super().__init__(optimizer, last_epoch)\n\n    def get_lr(self):\n        ratio = self.get_lr_ratio()\n        return [ratio * lr for lr in self.base_lrs]\n\n    def get_lr_ratio(self):\n        return self.get_warmup_ratio() if self.last_epoch < self.warmup_iter else self.get_main_ratio()\n\n    def get_main_ratio(self):\n        raise NotImplementedError\n\n    def get_warmup_ratio(self):\n        assert self.warmup in ['linear', 'exp']\n        alpha = self.last_epoch / self.warmup_iter\n\n        return self.warmup_ratio + (1. - self.warmup_ratio) * alpha if self.warmup == 'linear' else self.warmup_ratio ** (1. 
- alpha)\n\n\nclass WarmupPolyLR(WarmupLR):\n    def __init__(self, optimizer, power, max_iter, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:\n        self.power = power\n        self.max_iter = max_iter\n        super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)\n\n    def get_main_ratio(self):\n        real_iter = self.last_epoch - self.warmup_iter\n        real_max_iter = self.max_iter - self.warmup_iter\n        alpha = real_iter / real_max_iter\n\n        return (1 - alpha) ** self.power\n\n\nclass WarmupExpLR(WarmupLR):\n    def __init__(self, optimizer, gamma, interval=1, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:\n        self.gamma = gamma\n        self.interval = interval\n        super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)\n\n    def get_main_ratio(self):\n        real_iter = self.last_epoch - self.warmup_iter\n        return self.gamma ** (real_iter // self.interval)\n\n\nclass WarmupCosineLR(WarmupLR):\n    def __init__(self, optimizer, max_iter, eta_ratio=0, warmup_iter=500, warmup_ratio=5e-4, warmup='exp', last_epoch=-1) -> None:\n        self.eta_ratio = eta_ratio\n        self.max_iter = max_iter\n        super().__init__(optimizer, warmup_iter, warmup_ratio, warmup, last_epoch)\n\n    def get_main_ratio(self):\n        real_iter = self.last_epoch - self.warmup_iter\n        real_max_iter = self.max_iter - self.warmup_iter\n        \n        return self.eta_ratio + (1 - self.eta_ratio) * (1 + math.cos(math.pi * self.last_epoch / real_max_iter)) / 2\n\n\n\n__all__ = ['polylr', 'warmuppolylr', 'warmupcosinelr', 'warmupsteplr']\n\n\ndef get_scheduler(scheduler_name: str, optimizer, max_iter: int, power: int, warmup_iter: int, warmup_ratio: float):\n    assert scheduler_name in __all__, f\"Unavailable scheduler name >> {scheduler_name}.\\nAvailable schedulers: {__all__}\"\n    if scheduler_name == 'warmuppolylr':\n        return 
WarmupPolyLR(optimizer, power, max_iter, warmup_iter, warmup_ratio, warmup='linear')\n    elif scheduler_name == 'warmupcosinelr':\n        return WarmupCosineLR(optimizer, max_iter, warmup_iter=warmup_iter, warmup_ratio=warmup_ratio)\n    return PolyLR(optimizer, max_iter)\n\n\nif __name__ == '__main__':\n    model = torch.nn.Conv2d(3, 16, 3, 1, 1)\n    optim = torch.optim.SGD(model.parameters(), lr=1e-3)\n\n    max_iter = 20000\n    sched = WarmupPolyLR(optim, power=0.9, max_iter=max_iter, warmup_iter=200, warmup_ratio=0.1, warmup='exp', last_epoch=-1)\n\n    lrs = []\n\n    for _ in range(max_iter):\n        lr = sched.get_lr()[0]\n        lrs.append(lr)\n        optim.step()\n        sched.step()\n\n    import matplotlib.pyplot as plt\n    import numpy as np \n\n    plt.plot(np.arange(len(lrs)), np.array(lrs))\n    plt.grid()\n    plt.show()"
  },
  {
    "path": "semseg/utils/__init__.py",
    "content": ""
  },
  {
    "path": "semseg/utils/utils.py",
    "content": "import torch\nimport numpy as np\nimport random\nimport time\nimport os\nimport sys\nimport functools\nfrom pathlib import Path\nfrom torch.backends import cudnn\nfrom torch import nn, Tensor\nfrom torch.autograd import profiler\nfrom typing import Union\nfrom torch import distributed as dist\nfrom tabulate import tabulate\nfrom semseg import models\nimport logging\nfrom fvcore.nn import flop_count_table, FlopCountAnalysis\nimport datetime\n\ndef fix_seeds(seed: int = 3407) -> None:\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n\ndef setup_cudnn() -> None:\n    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html\n    cudnn.benchmark = True\n    cudnn.deterministic = False\n\ndef time_sync() -> float:\n    if torch.cuda.is_available():\n        torch.cuda.synchronize()\n    return time.time()\n\ndef get_model_size(model: Union[nn.Module, torch.jit.ScriptModule]):\n    tmp_model_path = Path('temp.p')\n    if isinstance(model, torch.jit.ScriptModule):\n        torch.jit.save(model, tmp_model_path)\n    else:\n        torch.save(model.state_dict(), tmp_model_path)\n    size = tmp_model_path.stat().st_size\n    os.remove(tmp_model_path)\n    return size / 1e6   # in MB\n\n@torch.no_grad()\ndef test_model_latency(model: nn.Module, inputs: torch.Tensor, use_cuda: bool = False) -> float:\n    with profiler.profile(use_cuda=use_cuda) as prof:\n        _ = model(inputs)\n    return prof.self_cpu_time_total / 1000  # ms\n\ndef count_parameters(model: nn.Module) -> float:\n    return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6      # in M\n\ndef setup_ddp():\n    # print(os.environ.keys())\n    if 'SLURM_PROCID' in os.environ and not 'RANK' in os.environ:\n        # --- multi nodes\n        world_size = int(os.environ['WORLD_SIZE'])\n        rank = int(os.environ[\"SLURM_PROCID\"])\n        gpus_per_node = 
int(os.environ[\"SLURM_GPUS_ON_NODE\"])\n        gpu = rank - gpus_per_node * (rank // gpus_per_node)\n        torch.cuda.set_device(gpu)\n        dist.init_process_group(backend=\"nccl\", world_size=world_size, rank=rank, timeout=datetime.timedelta(seconds=7200))\n    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:\n        rank = int(os.environ['RANK'])\n        world_size = int(os.environ['WORLD_SIZE'])\n        # gpu = int(os.environ(['LOCAL_RANK']))\n        # ---\n        gpu = int(os.environ['LOCAL_RANK'])\n        torch.cuda.set_device(gpu)\n        dist.init_process_group('nccl', init_method=\"env://\",world_size=world_size, rank=rank, timeout=datetime.timedelta(seconds=7200))\n        dist.barrier()\n    else:\n        gpu = 0\n    return gpu\n\ndef cleanup_ddp():\n    if dist.is_initialized():\n        dist.destroy_process_group()\n\ndef reduce_tensor(tensor: Tensor) -> Tensor:\n    rt = tensor.clone()\n    dist.all_reduce(rt, op=dist.ReduceOp.SUM)\n    rt /= dist.get_world_size()\n    return rt\n\n@torch.no_grad()\ndef throughput(dataloader, model: nn.Module, times: int = 30):\n    model.eval()\n    images, _  = next(iter(dataloader))\n    images = images.cuda(non_blocking=True)\n    B = images.shape[0]\n    print(f\"Throughput averaged with {times} times\")\n    start = time_sync()\n    for _ in range(times):\n        model(images)\n    end = time_sync()\n\n    print(f\"Batch Size {B} throughput {times * B / (end - start)} images/s\")\n\n\ndef show_models():\n    model_names = models.__all__\n    model_variants = [list(eval(f'models.{name.lower()}_settings').keys()) for name in model_names]\n\n    print(tabulate({'Model Names': model_names, 'Model Variants': model_variants}, headers='keys'))\n\n\ndef timer(func):\n    @functools.wraps(func)\n    def wrapper_timer(*args, **kwargs):\n        tic = time.perf_counter()\n        value = func(*args, **kwargs)\n        toc = time.perf_counter()\n        elapsed_time = toc - tic\n        
print(f\"Elapsed time: {elapsed_time * 1000:.2f}ms\")\n        return value\n    return wrapper_timer\n\n\n# _default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO')\n# _default_level = logging.getLevelName(_default_level_name.upper())\n\ndef get_logger(log_file=None):\n    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: - %(message)s',datefmt='%Y%m%d %H:%M:%S')\n    logger = logging.getLogger()\n    # logger.setLevel(logging.DEBUG)\n    logger.setLevel(logging.INFO)\n    del logger.handlers[:]\n\n    if log_file:\n        file_handler = logging.FileHandler(log_file, mode='w')\n        # file_handler.setLevel(logging.DEBUG)\n        file_handler.setLevel(logging.INFO)\n        file_handler.setFormatter(formatter)\n        logger.addHandler(file_handler)\n\n    stream_handler = logging.StreamHandler()\n    stream_handler.setFormatter(formatter)\n    # stream_handler.setLevel(logging.DEBUG)\n    stream_handler.setLevel(logging.INFO)\n    logger.addHandler(stream_handler)\n    return logger\n\n\ndef cal_flops(model, modals, logger):\n    x = [torch.zeros(1, 3, 512, 512) for _ in range(len(modals))]\n    # x = [torch.zeros(2, 3, 512, 512) for _ in range(len(modals))] #--- PGSNet\n    # x = [torch.zeros(1, 3, 512, 512) for _ in range(len(modals))] # --- for HRFuser\n    if torch.distributed.is_initialized():\n        if 'HR' in model.module.__class__.__name__:\n            x = [torch.zeros(1, 3, 512, 512) for _ in range(len(modals))] # --- for HorNet\n    else:\n        if 'HR' in model.__class__.__name__:\n            x = [torch.zeros(1, 3, 512, 512) for _ in range(len(modals))] # --- for HorNet\n\n    if torch.cuda.is_available:\n        x = [xi.cuda() for xi in x]\n        model = model.cuda()\n    logger.info(flop_count_table(FlopCountAnalysis(model, x)))        \n\ndef print_iou(epoch, iou, miou, acc, macc, class_names):\n    assert len(iou) == len(class_names)\n    assert len(acc) == len(class_names)\n    lines = 
['\\n%-8s\\t%-8s\\t%-8s' % ('Class', 'IoU', 'Acc')]\n    for i in range(len(iou)):\n        if class_names is None:\n            cls = 'Class %d:' % (i+1)\n        else:\n            cls = '%d %s' % (i+1, class_names[i])\n        lines.append('%-8s\\t%.2f\\t%.2f' % (cls, iou[i], acc[i]))\n    lines.append('== %-8s\\t%d\\t%-8s\\t%.2f\\t%-8s\\t%.2f' % ('Epoch:', epoch, 'mean_IoU', miou, 'mean_Acc',macc))\n    line = \"\\n\".join(lines)\n    return line\n\n\ndef nchw_to_nlc(x):\n    \"\"\"Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.\n\n    Args:\n        x (Tensor): The input tensor of shape [N, C, H, W] before conversion.\n\n    Returns:\n        Tensor: The output tensor of shape [N, L, C] after conversion.\n    \"\"\"\n    assert len(x.shape) == 4\n    return x.flatten(2).transpose(1, 2).contiguous()\n\ndef nlc_to_nchw(x, hw_shape):\n    \"\"\"Convert [N, L, C] shape tensor to [N, C, H, W] shape tensor.\n\n    Args:\n        x (Tensor): The input tensor of shape [N, L, C] before conversion.\n        hw_shape (Sequence[int]): The height and width of output feature map.\n\n    Returns:\n        Tensor: The output tensor of shape [N, C, H, W] after conversion.\n    \"\"\"\n    H, W = hw_shape\n    assert len(x.shape) == 3\n    B, L, C = x.shape\n    assert L == H * W, 'The seq_len does not match H, W'\n    return x.transpose(1, 2).reshape(B, C, H, W).contiguous()\n\ndef nlc2nchw2nlc(module, x, hw_shape, contiguous=False, **kwargs):\n    \"\"\"Convert [N, L, C] shape tensor `x` to [N, C, H, W] shape tensor. 
Use the\n    reshaped tensor as the input of `module`, and convert the output of\n    `module`, whose shape is.\n    [N, C, H, W], to [N, L, C].\n    Args:\n        module (Callable): A callable object the takes a tensor\n            with shape [N, C, H, W] as input.\n        x (Tensor): The input tensor of shape [N, L, C].\n        hw_shape: (Sequence[int]): The height and width of the\n            feature map with shape [N, C, H, W].\n        contiguous (Bool): Whether to make the tensor contiguous\n            after each shape transform.\n    Returns:\n        Tensor: The output tensor of shape [N, L, C].\n    Example:\n        >>> import torch\n        >>> import torch.nn as nn\n        >>> conv = nn.Conv2d(16, 16, 3, 1, 1)\n        >>> feature_map = torch.rand(4, 25, 16)\n        >>> output = nlc2nchw2nlc(conv, feature_map, (5, 5))\n    \"\"\"\n    H, W = hw_shape\n    assert len(x.shape) == 3\n    B, L, C = x.shape\n    assert L == H * W, 'The seq_len doesn\\'t match H, W'\n    if not contiguous:\n        x = x.transpose(1, 2).reshape(B, C, H, W)\n        x = module(x, **kwargs)\n        x = x.flatten(2).transpose(1, 2)\n    else:\n        x = x.transpose(1, 2).reshape(B, C, H, W).contiguous()\n        x = module(x, **kwargs)\n        x = x.flatten(2).transpose(1, 2).contiguous()\n    return x\n"
  },
  {
    "path": "semseg/utils/visualize.py",
    "content": "import torch\nimport random\nimport numpy as np\nimport torch\nimport matplotlib.pyplot as plt\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms as T\nfrom torchvision.utils import make_grid\nfrom semseg.augmentations import Compose, Normalize, RandomResizedCrop\nfrom PIL import Image, ImageDraw, ImageFont\n\n\ndef visualize_dataset_sample(dataset, root, split='val', batch_size=4):\n    transform = Compose([\n        RandomResizedCrop((512, 512), scale=(1.0, 1.0)),\n        Normalize()\n    ])\n\n    dataset = dataset(root, split=split, transform=transform)\n    dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)\n    image, label = next(iter(dataloader))\n\n    print(f\"Image Shape\\t: {image.shape}\")\n    print(f\"Label Shape\\t: {label.shape}\")\n    print(f\"Classes\\t\\t: {label.unique().tolist()}\")\n\n    label[label == -1] = 0\n    label[label == 255] = 0\n    labels = [dataset.PALETTE[lbl.to(int)].permute(2, 0, 1) for lbl in label]\n    labels = torch.stack(labels)\n\n    inv_normalize = T.Normalize(\n        mean=(-0.485/0.229, -0.456/0.224, -0.406/0.225),\n        std=(1/0.229, 1/0.224, 1/0.225)\n    )\n    image = inv_normalize(image)\n    image *= 255\n    images = torch.vstack([image, labels])\n    \n    plt.imshow(make_grid(images, nrow=4).to(torch.uint8).numpy().transpose((1, 2, 0)))\n    plt.show()\n\n\ncolors = [\n    [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],\n    [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],\n    [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],\n    [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],\n    [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], 
[102, 8, 255], [255, 61, 6], [255, 194, 7],\n    [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],\n    [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],\n    [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],\n    [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],\n    [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],\n    [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],\n    [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],\n    [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],\n    [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],\n    [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],\n    [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],\n    [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],\n    [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],\n    [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 0], [92, 0, 255]\n]\n\n\ndef generate_palette(num_classes, background: bool = False):\n    random.shuffle(colors)\n    if background:\n        palette = [[0, 0, 0]]\n        palette += colors[:num_classes-1]\n    else:\n        palette = colors[:num_classes]\n    return 
np.array(palette)\n\n\ndef draw_text(image: torch.Tensor, seg_map: torch.Tensor, labels: list, fontsize: int = 15):\n    image = image.to(torch.uint8)\n    font = ImageFont.truetype(\"Helvetica.ttf\", fontsize)\n    pil_image = Image.fromarray(image.numpy())\n    draw = ImageDraw.Draw(pil_image)\n\n    indices = seg_map.unique().tolist()\n    classes = [labels[index] for index in indices]\n\n    for idx, cls in zip(indices, classes):\n        mask = seg_map == idx\n        mask = mask.squeeze().numpy()\n        center = np.median((mask == 1).nonzero(), axis=1)[::-1]\n        bbox = draw.textbbox(center, cls, font=font)\n        bbox = (bbox[0]-3, bbox[1]-3, bbox[2]+3, bbox[3]+3)\n        draw.rectangle(bbox, fill=(255, 255, 255), width=1)\n        draw.text(center, cls, fill=(0, 0, 0), font=font)\n    return pil_image"
  },
  {
    "path": "tools/infer_mm.py",
    "content": "import torch\nimport argparse\nimport yaml\nimport math\nfrom torch import Tensor\nfrom torch.nn import functional as F\nfrom pathlib import Path\nfrom torchvision import io\nfrom torchvision import transforms as T\nimport torchvision.transforms.functional as TF \nfrom semseg.models import *\nfrom semseg.datasets import *\nfrom semseg.utils.utils import timer\nfrom semseg.utils.visualize import draw_text\nimport glob\nimport os\nfrom PIL import Image, ImageDraw, ImageFont\n\n\nclass SemSeg:\n    def __init__(self, cfg) -> None:\n        # inference device cuda or cpu\n        self.device = torch.device(cfg['DEVICE'])\n\n        # get dataset classes' colors and labels\n        self.palette = eval(cfg['DATASET']['NAME']).PALETTE\n        self.labels = eval(cfg['DATASET']['NAME']).CLASSES\n\n        # initialize the model and load weights and send to device\n        self.model = eval(cfg['MODEL']['NAME'])(cfg['MODEL']['BACKBONE'], len(self.palette), cfg['DATASET']['MODALS'])\n        msg = self.model.load_state_dict(torch.load(cfg['EVAL']['MODEL_PATH'], map_location='cpu'))\n        print(msg)\n        self.model = self.model.to(self.device)\n        self.model.eval()\n\n        # preprocess parameters and transformation pipeline\n        self.size = cfg['TEST']['IMAGE_SIZE']\n        self.tf_pipeline_img = T.Compose([\n            T.Resize(self.size),\n            T.Lambda(lambda x: x / 255),\n            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),\n            T.Lambda(lambda x: x.unsqueeze(0))\n        ])\n        self.tf_pipeline_modal = T.Compose([\n            T.Resize(self.size),\n            T.Lambda(lambda x: x / 255),\n            T.Lambda(lambda x: x.unsqueeze(0))\n        ])\n\n    def postprocess(self, orig_img: Tensor, seg_map: Tensor, overlay: bool) -> Tensor:\n        seg_map = seg_map.softmax(dim=1).argmax(dim=1).cpu().to(int)\n\n        seg_image = self.palette[seg_map].squeeze()\n        if overlay: \n            
seg_image = (orig_img.permute(1, 2, 0) * 0.4) + (seg_image * 0.6)\n\n        image = seg_image.to(torch.uint8)\n        pil_image = Image.fromarray(image.numpy())\n        return pil_image\n\n    @torch.inference_mode()\n    @timer\n    def model_forward(self, img: Tensor) -> Tensor:\n        return self.model(img)\n    \n    def _open_img(self, file):\n        img = io.read_image(file)\n        C, H, W = img.shape\n        if C == 4:\n            img = img[:3, ...]\n        if C == 1:\n            img = img.repeat(3, 1, 1)\n        return img\n\n    def predict(self, img_fname: str, overlay: bool) -> Tensor:\n        if cfg['DATASET']['NAME'] == 'DELIVER':\n            x1 = img_fname.replace('/img', '/hha').replace('_rgb', '_depth')\n            x2 = img_fname.replace('/img', '/lidar').replace('_rgb', '_lidar')\n            x3 = img_fname.replace('/img', '/event').replace('_rgb', '_event')\n            lbl_path = img_fname.replace('/img', '/semantic').replace('_rgb', '_semantic')\n        elif cfg['DATASET']['NAME'] == 'KITTI360':\n            x1 = os.path.join(img_fname.replace('data_2d_raw', 'data_2d_hha'))\n            x2 = os.path.join(img_fname.replace('data_2d_raw', 'data_2d_lidar'))\n            x2 = x2.replace('.png', '_color.png')\n            x3 = os.path.join(img_fname.replace('data_2d_raw', 'data_2d_event'))\n            x3 = x3.replace('/image_00/data_rect/', '/').replace('.png', '_event_image.png')\n            lbl_path = os.path.join(*[img_fname.replace('data_2d_raw', 'data_2d_semantics/train').replace('data_rect', 'semantic')])\n\n        image = io.read_image(img_fname)[:3, ...]\n        img = self.tf_pipeline_img(image).to(self.device)\n        # --- modals\n        x1 = self._open_img(x1)\n        x1 = self.tf_pipeline_modal(x1).to(self.device)\n        x2 = self._open_img(x2)\n        x2 = self.tf_pipeline_modal(x2).to(self.device)\n        x3 = self._open_img(x3)\n        x3 = self.tf_pipeline_modal(x3).to(self.device)\n        label = 
io.read_image(lbl_path)[0,...].unsqueeze(0)\n        label[label==255] = 0\n        label -= 1\n\n        sample = [img, x1, x2, x3][:len(modals)]\n\n        seg_map = self.model_forward(sample)\n        seg_map = self.postprocess(image, seg_map, overlay)\n        return seg_map\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--cfg', type=str, default='configs/DELIVER.yaml')\n    args = parser.parse_args()\n    with open(args.cfg) as f:\n        cfg = yaml.load(f, Loader=yaml.SafeLoader)\n\n    # cases = ['cloud', 'fog', 'night', 'rain', 'sun', 'motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres', None]\n    cases = ['lidarjitter']\n\n    modals = cfg['DATASET']['MODALS']\n\n    test_file = Path(cfg['TEST']['FILE'])\n    if not test_file.exists():\n        raise FileNotFoundError(test_file)\n\n    # print(f\"Model {cfg['MODEL']['NAME']} {cfg['MODEL']['BACKBONE']}\")\n    # print(f\"Model {cfg['DATASET']['NAME']}\")\n\n    modals_name = ''.join([m[0] for m in cfg['DATASET']['MODALS']])\n    save_dir = Path(cfg['SAVE_DIR']) / 'test_results' / (cfg['DATASET']['NAME']+'_'+cfg['MODEL']['BACKBONE']+'_'+modals_name)\n    \n    semseg = SemSeg(cfg)\n\n    if test_file.is_file():\n        segmap = semseg.predict(str(test_file), cfg['TEST']['OVERLAY'])\n        segmap.save(save_dir / f\"{str(test_file.stem)}.png\")\n    else:\n        if cfg['DATASET']['NAME'] == 'DELIVER':\n            files = sorted(glob.glob(os.path.join(*[str(test_file), 'img', '*', 'val', '*', '*.png']))) # --- Deliver\n        elif cfg['DATASET']['NAME'] == 'KITTI360':\n            source = os.path.join(test_file, 'val.txt')\n            files = []\n            with open(source) as f:\n                files_ = f.readlines()\n            for item in files_:\n                file_name = item.strip()\n                if ' ' in file_name:\n                    # --- KITTI-360\n                    file_name = 
os.path.join(*[str(test_file), file_name.split(' ')[0]])\n                files.append(file_name)\n        else:\n            raise NotImplementedError()\n\n        for file in files:\n            print(file)\n            if not '2013_05_28_drive_0000_sync' in file:\n                continue\n            segmap = semseg.predict(file, cfg['TEST']['OVERLAY'])\n            save_path = os.path.join(str(save_dir),file)\n            os.makedirs(os.path.dirname(save_path), exist_ok=True)\n            segmap.save(save_path)\n"
  },
  {
    "path": "tools/train_mm.py",
    "content": "import os\nimport torch \nimport argparse\nimport yaml\nimport time\nimport multiprocessing as mp\nfrom tabulate import tabulate\nfrom tqdm import tqdm\nfrom torch.utils.data import DataLoader\nfrom pathlib import Path\nfrom torch.utils.tensorboard import SummaryWriter\nfrom torch.cuda.amp import GradScaler, autocast\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.utils.data import DistributedSampler, RandomSampler\nfrom torch import distributed as dist\nfrom semseg.models import *\nfrom semseg.datasets import * \nfrom semseg.augmentations_mm import get_train_augmentation, get_val_augmentation\nfrom semseg.losses import get_loss\nfrom semseg.schedulers import get_scheduler\nfrom semseg.optimizers import get_optimizer\nfrom semseg.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp, get_logger, cal_flops, print_iou\nfrom val_mm import evaluate\n\n\ndef main(cfg, gpu, save_dir):\n    start = time.time()\n    best_mIoU = 0.0\n    best_epoch = 0\n    num_workers = 8\n    device = torch.device(cfg['DEVICE'])\n    train_cfg, eval_cfg = cfg['TRAIN'], cfg['EVAL']\n    dataset_cfg, model_cfg = cfg['DATASET'], cfg['MODEL']\n    loss_cfg, optim_cfg, sched_cfg = cfg['LOSS'], cfg['OPTIMIZER'], cfg['SCHEDULER']\n    epochs, lr = train_cfg['EPOCHS'], optim_cfg['LR']\n    resume_path = cfg['MODEL']['RESUME']\n    gpus = int(os.environ['WORLD_SIZE'])\n\n    traintransform = get_train_augmentation(train_cfg['IMAGE_SIZE'], seg_fill=dataset_cfg['IGNORE_LABEL'])\n    valtransform = get_val_augmentation(eval_cfg['IMAGE_SIZE'])\n\n    trainset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'train', traintransform, dataset_cfg['MODALS'])\n    valset = eval(dataset_cfg['NAME'])(dataset_cfg['ROOT'], 'val', valtransform, dataset_cfg['MODALS'])\n    class_names = trainset.CLASSES\n\n    model = eval(model_cfg['NAME'])(model_cfg['BACKBONE'], trainset.n_classes, dataset_cfg['MODALS'])\n    resume_checkpoint = None\n    if 
os.path.isfile(resume_path):\n        resume_checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))\n        msg = model.load_state_dict(resume_checkpoint['model_state_dict'])\n        # print(msg)\n        logger.info(msg)\n    else:\n        model.init_pretrained(model_cfg['PRETRAINED'])\n    model = model.to(device)\n    \n    iters_per_epoch = len(trainset) // train_cfg['BATCH_SIZE'] // gpus\n    loss_fn = get_loss(loss_cfg['NAME'], trainset.ignore_label, None)\n    start_epoch = 0\n    optimizer = get_optimizer(model, optim_cfg['NAME'], lr, optim_cfg['WEIGHT_DECAY'])\n    scheduler = get_scheduler(sched_cfg['NAME'], optimizer, int((epochs+1)*iters_per_epoch), sched_cfg['POWER'], iters_per_epoch * sched_cfg['WARMUP'], sched_cfg['WARMUP_RATIO'])\n\n    if train_cfg['DDP']: \n        sampler = DistributedSampler(trainset, dist.get_world_size(), dist.get_rank(), shuffle=True)\n        sampler_val = None\n        model = DDP(model, device_ids=[gpu], output_device=0, find_unused_parameters=True)\n    else:\n        sampler = RandomSampler(trainset)\n        sampler_val = None\n    \n    if resume_checkpoint:\n        start_epoch = resume_checkpoint['epoch'] - 1\n        optimizer.load_state_dict(resume_checkpoint['optimizer_state_dict'])\n        scheduler.load_state_dict(resume_checkpoint['scheduler_state_dict'])\n        loss = resume_checkpoint['loss']        \n        best_mIoU = resume_checkpoint['best_miou']\n           \n    trainloader = DataLoader(trainset, batch_size=train_cfg['BATCH_SIZE'], num_workers=num_workers, drop_last=True, pin_memory=False, sampler=sampler)\n    valloader = DataLoader(valset, batch_size=eval_cfg['BATCH_SIZE'], num_workers=num_workers, pin_memory=False, sampler=sampler_val)\n\n    scaler = GradScaler(enabled=train_cfg['AMP'])\n    if (train_cfg['DDP'] and torch.distributed.get_rank() == 0) or (not train_cfg['DDP']):\n        writer = SummaryWriter(str(save_dir))\n        logger.info('================== model 
complexity =====================')\n        cal_flops(model, dataset_cfg['MODALS'], logger)\n        logger.info('================== model structure =====================')\n        logger.info(model)\n        logger.info('================== training config =====================')\n        logger.info(cfg)\n\n    for epoch in range(start_epoch, epochs):\n        model.train()\n        if train_cfg['DDP']: sampler.set_epoch(epoch)\n\n        train_loss = 0.0        \n        lr = scheduler.get_lr()\n        lr = sum(lr) / len(lr)\n        pbar = tqdm(enumerate(trainloader), total=iters_per_epoch, desc=f\"Epoch: [{epoch+1}/{epochs}] Iter: [{0}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss:.8f}\")\n\n        for iter, (sample, lbl) in pbar:\n            optimizer.zero_grad(set_to_none=True)\n            sample = [x.to(device) for x in sample]\n            lbl = lbl.to(device)\n            \n            with autocast(enabled=train_cfg['AMP']):\n                logits = model(sample)\n                loss = loss_fn(logits, lbl)\n\n            scaler.scale(loss).backward()\n            scaler.step(optimizer)\n            scaler.update()\n            scheduler.step()\n            torch.cuda.synchronize()\n\n            lr = scheduler.get_lr()\n            lr = sum(lr) / len(lr)\n            if lr <= 1e-8:\n                lr = 1e-8 # minimum of lr\n            train_loss += loss.item()\n\n            pbar.set_description(f\"Epoch: [{epoch+1}/{epochs}] Iter: [{iter+1}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss / (iter+1):.8f}\")\n        \n        train_loss /= iter+1\n        if (train_cfg['DDP'] and torch.distributed.get_rank() == 0) or (not train_cfg['DDP']):\n            writer.add_scalar('train/loss', train_loss, epoch)\n        torch.cuda.empty_cache()\n\n        if ((epoch+1) % train_cfg['EVAL_INTERVAL'] == 0 and (epoch+1)>train_cfg['EVAL_START']) or (epoch+1) == epochs:\n            if (train_cfg['DDP'] and torch.distributed.get_rank() == 0) or (not 
train_cfg['DDP']):\n                acc, macc, _, _, ious, miou = evaluate(model, valloader, device)\n                writer.add_scalar('val/mIoU', miou, epoch)\n\n                if miou > best_mIoU:\n                    prev_best_ckp = save_dir / f\"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}_checkpoint.pth\"\n                    prev_best = save_dir / f\"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}.pth\"\n                    if os.path.isfile(prev_best): os.remove(prev_best)\n                    if os.path.isfile(prev_best_ckp): os.remove(prev_best_ckp)\n                    best_mIoU = miou\n                    best_epoch = epoch+1\n                    cur_best_ckp = save_dir / f\"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}_checkpoint.pth\"\n                    cur_best = save_dir / f\"{model_cfg['NAME']}_{model_cfg['BACKBONE']}_{dataset_cfg['NAME']}_epoch{best_epoch}_{best_mIoU}.pth\"\n                    torch.save(model.module.state_dict() if train_cfg['DDP'] else model.state_dict(), cur_best)\n                    # --- \n                    torch.save({'epoch': best_epoch,\n                                'model_state_dict': model.module.state_dict() if train_cfg['DDP'] else model.state_dict(),\n                                'optimizer_state_dict': optimizer.state_dict(),\n                                'loss': train_loss,\n                                'scheduler_state_dict': scheduler.state_dict(),\n                                'best_miou': best_mIoU,\n                                }, cur_best_ckp)\n                    logger.info(print_iou(epoch, ious, miou, acc, macc, class_names))\n                logger.info(f\"Current epoch:{epoch} mIoU: {miou} Best mIoU: {best_mIoU}\")\n\n    if (train_cfg['DDP'] and torch.distributed.get_rank() == 0) or (not train_cfg['DDP']):\n        writer.close()\n   
 pbar.close()\n    end = time.gmtime(time.time() - start)\n\n    table = [\n        ['Best mIoU', f\"{best_mIoU:.2f}\"],\n        ['Total Training Time', time.strftime(\"%H:%M:%S\", end)]\n    ]\n    logger.info(tabulate(table, numalign='right'))\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--cfg', type=str, default='configs/deliver_rgbdel.yaml', help='Configuration file to use')\n    args = parser.parse_args()\n\n    with open(args.cfg) as f:\n        cfg = yaml.load(f, Loader=yaml.SafeLoader)\n\n    fix_seeds(3407)\n    setup_cudnn()\n    gpu = setup_ddp()\n    modals = ''.join([m[0] for m in cfg['DATASET']['MODALS']])\n    model = cfg['MODEL']['BACKBONE']\n    exp_name = '_'.join([cfg['DATASET']['NAME'], model, modals])\n    save_dir = Path(cfg['SAVE_DIR'], exp_name)\n    if os.path.isfile(cfg['MODEL']['RESUME']):\n        save_dir =  Path(os.path.dirname(cfg['MODEL']['RESUME']))\n    os.makedirs(save_dir, exist_ok=True)\n    logger = get_logger(save_dir / 'train.log')\n    main(cfg, gpu, save_dir)\n    cleanup_ddp()"
  },
  {
    "path": "tools/val_mm.py",
    "content": "import torch\nimport argparse\nimport yaml\nimport math\nimport os\nimport time\nfrom pathlib import Path\nfrom tqdm import tqdm\nfrom tabulate import tabulate\nfrom torch.utils.data import DataLoader\nfrom torch.nn import functional as F\nfrom semseg.models import *\nfrom semseg.datasets import *\nfrom semseg.augmentations_mm import get_val_augmentation\nfrom semseg.metrics import Metrics\nfrom semseg.utils.utils import setup_cudnn\nfrom math import ceil\nimport numpy as np\nfrom torch.utils.data import DistributedSampler, RandomSampler\nfrom torch import distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom semseg.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp, get_logger, cal_flops, print_iou\n\ndef pad_image(img, target_size):\n    rows_to_pad = max(target_size[0] - img.shape[2], 0)\n    cols_to_pad = max(target_size[1] - img.shape[3], 0)\n    padded_img = F.pad(img, (0, cols_to_pad, 0, rows_to_pad), \"constant\", 0)\n    return padded_img\n\n@torch.no_grad()\ndef sliding_predict(model, image, num_classes, flip=True):\n    image_size = image[0].shape\n    tile_size = (int(ceil(image_size[2]*1)), int(ceil(image_size[3]*1)))\n    overlap = 1/3\n\n    stride = ceil(tile_size[0] * (1 - overlap))\n    \n    num_rows = int(ceil((image_size[2] - tile_size[0]) / stride) + 1)\n    num_cols = int(ceil((image_size[3] - tile_size[1]) / stride) + 1)\n    total_predictions = torch.zeros((num_classes, image_size[2], image_size[3]), device=torch.device('cuda'))\n    count_predictions = torch.zeros((image_size[2], image_size[3]), device=torch.device('cuda'))\n    tile_counter = 0\n\n    for row in range(num_rows):\n        for col in range(num_cols):\n            x_min, y_min = int(col * stride), int(row * stride)\n            x_max = min(x_min + tile_size[1], image_size[3])\n            y_max = min(y_min + tile_size[0], image_size[2])\n\n            img = [modal[:, :, y_min:y_max, x_min:x_max] for modal in 
image]\n            padded_img = [pad_image(modal, tile_size) for modal in img]\n            tile_counter += 1\n            padded_prediction = model(padded_img)\n            if flip:\n                fliped_img = [padded_modal.flip(-1) for padded_modal in padded_img]\n                fliped_predictions = model(fliped_img)\n                padded_prediction += fliped_predictions.flip(-1)\n            predictions = padded_prediction[:, :, :img[0].shape[2], :img[0].shape[3]]\n            count_predictions[y_min:y_max, x_min:x_max] += 1\n            total_predictions[:, y_min:y_max, x_min:x_max] += predictions.squeeze(0)\n\n    return total_predictions.unsqueeze(0)\n\n@torch.no_grad()\ndef evaluate(model, dataloader, device):\n    print('Evaluating...')\n    model.eval()\n    n_classes = dataloader.dataset.n_classes\n    metrics = Metrics(n_classes, dataloader.dataset.ignore_label, device)\n    sliding = False\n    for images, labels in tqdm(dataloader):\n        images = [x.to(device) for x in images]\n        labels = labels.to(device)\n        if sliding:\n            preds = sliding_predict(model, images, num_classes=n_classes).softmax(dim=1)\n        else:\n            preds = model(images).softmax(dim=1)\n        metrics.update(preds, labels)\n    \n    ious, miou = metrics.compute_iou()\n    acc, macc = metrics.compute_pixel_acc()\n    f1, mf1 = metrics.compute_f1()\n    \n    return acc, macc, f1, mf1, ious, miou\n\n\n@torch.no_grad()\ndef evaluate_msf(model, dataloader, device, scales, flip):\n    model.eval()\n\n    n_classes = dataloader.dataset.n_classes\n    metrics = Metrics(n_classes, dataloader.dataset.ignore_label, device)\n\n    for images, labels in tqdm(dataloader):\n        labels = labels.to(device)\n        B, H, W = labels.shape\n        scaled_logits = torch.zeros(B, n_classes, H, W).to(device)\n\n        for scale in scales:\n            new_H, new_W = int(scale * H), int(scale * W)\n            new_H, new_W = int(math.ceil(new_H / 32)) * 32, 
int(math.ceil(new_W / 32)) * 32\n            scaled_images = [F.interpolate(img, size=(new_H, new_W), mode='bilinear', align_corners=True) for img in images]\n            scaled_images = [scaled_img.to(device) for scaled_img in scaled_images]\n            logits = model(scaled_images)\n            logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True)\n            scaled_logits += logits.softmax(dim=1)\n\n            if flip:\n                scaled_images = [torch.flip(scaled_img, dims=(3,)) for scaled_img in scaled_images]\n                logits = model(scaled_images)\n                logits = torch.flip(logits, dims=(3,))\n                logits = F.interpolate(logits, size=(H, W), mode='bilinear', align_corners=True)\n                scaled_logits += logits.softmax(dim=1)\n\n        metrics.update(scaled_logits, labels)\n    \n    acc, macc = metrics.compute_pixel_acc()\n    f1, mf1 = metrics.compute_f1()\n    ious, miou = metrics.compute_iou()\n    return acc, macc, f1, mf1, ious, miou\n\n\ndef main(cfg):\n    device = torch.device(cfg['DEVICE'])\n\n    eval_cfg = cfg['EVAL']\n    transform = get_val_augmentation(eval_cfg['IMAGE_SIZE'])\n    # cases = ['cloud', 'fog', 'night', 'rain', 'sun']\n    # cases = ['motionblur', 'overexposure', 'underexposure', 'lidarjitter', 'eventlowres']\n    cases = [None] # all\n    \n    model_path = Path(eval_cfg['MODEL_PATH'])\n    if not model_path.exists(): \n        raise FileNotFoundError\n    print(f\"Evaluating {model_path}...\")\n\n    exp_time = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n    eval_path = os.path.join(os.path.dirname(eval_cfg['MODEL_PATH']), 'eval_{}.txt'.format(exp_time))\n\n    for case in cases:\n        dataset = eval(cfg['DATASET']['NAME'])(cfg['DATASET']['ROOT'], 'val', transform, cfg['DATASET']['MODALS'], case)\n        # --- test set\n        # dataset = eval(cfg['DATASET']['NAME'])(cfg['DATASET']['ROOT'], 'test', transform, cfg['DATASET']['MODALS'], case)\n\n    
    model = eval(cfg['MODEL']['NAME'])(cfg['MODEL']['BACKBONE'], dataset.n_classes, cfg['DATASET']['MODALS'])\n        msg = model.load_state_dict(torch.load(str(model_path), map_location='cpu'))\n        print(msg)\n        model = model.to(device)\n        sampler_val = None\n        dataloader = DataLoader(dataset, batch_size=eval_cfg['BATCH_SIZE'], num_workers=eval_cfg['BATCH_SIZE'], pin_memory=False, sampler=sampler_val)\n        if True:\n            if eval_cfg['MSF']['ENABLE']:\n                acc, macc, f1, mf1, ious, miou = evaluate_msf(model, dataloader, device, eval_cfg['MSF']['SCALES'], eval_cfg['MSF']['FLIP'])\n            else:\n                acc, macc, f1, mf1, ious, miou = evaluate(model, dataloader, device)\n\n            table = {\n                'Class': list(dataset.CLASSES) + ['Mean'],\n                'IoU': ious + [miou],\n                'F1': f1 + [mf1],\n                'Acc': acc + [macc]\n            }\n            print(\"mIoU : {}\".format(miou))\n            print(\"Results saved in {}\".format(eval_cfg['MODEL_PATH']))\n\n        with open(eval_path, 'a+') as f:\n            f.writelines(eval_cfg['MODEL_PATH'])\n            f.write(\"\\n============== Eval on {} {} images =================\\n\".format(case, len(dataset)))\n            f.write(\"\\n\")\n            print(tabulate(table, headers='keys'), file=f)\n\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--cfg', type=str, default='configs/DELIVER.yaml')\n    args = parser.parse_args()\n\n    with open(args.cfg) as f:\n        cfg = yaml.load(f, Loader=yaml.SafeLoader)\n\n    setup_cudnn()\n    # gpu = setup_ddp()\n    # main(cfg, gpu)\n    main(cfg)"
  }
]