[
  {
    "path": ".gitignore",
    "content": "# compilation and distribution\n__pycache__\n_ext\n*.pyc\n*.so\nbuild/\ndist/\n*.egg-info/\n\n# Pycharm editor settings\n.idea\n"
  },
  {
    "path": "LICENSE",
    "content": "   xMUDA\n\n   Copyright 2020 Valeo\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       https://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\n\n\n                                    Apache License\n                           Version 2.0, January 2004\n                        https://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n"
  },
  {
    "path": "README.md",
    "content": "## [Updated code](https://github.com/valeoai/xmuda_journal) from our TPAMI paper.\n\n# xMUDA: Cross-Modal Unsupervised Domain Adaptation for 3D Semantic Segmentation\n\nOfficial code for the paper.\n\n## Paper\n![](./teaser.png)\n\n[xMUDA: Cross-Modal Unsupervised Domain Adaptation for 3D Semantic Segmentation](https://arxiv.org/abs/1911.12676)  \n [Maximilian Jaritz](https://team.inria.fr/rits/membres/maximilian-jaritz/), [Tuan-Hung Vu](https://tuanhungvu.github.io/), [Raoul de Charette](https://team.inria.fr/rits/membres/raoul-de-charette/),  Émilie Wirbel, [Patrick Pérez](https://ptrckprz.github.io/)  \n Inria, valeo.ai\n CVPR 2020\n\nIf you find this code useful for your research, please cite our [paper](https://arxiv.org/abs/1911.12676):\n\n```\n@inproceedings{jaritz2019xmuda,\n\ttitle={{xMUDA}: Cross-Modal Unsupervised Domain Adaptation for {3D} Semantic Segmentation},\n\tauthor={Jaritz, Maximilian and Vu, Tuan-Hung and de Charette, Raoul and Wirbel, Emilie and P{\\'e}rez, Patrick},\n\tbooktitle={CVPR},\n\tyear={2020}\n}\n```\n## Preparation\n### Prerequisites\nTested with\n* PyTorch 1.4\n* CUDA 10.0\n* Python 3.8\n* [SparseConvNet](https://github.com/facebookresearch/SparseConvNet)\n* [nuscenes-devkit](https://github.com/nutonomy/nuscenes-devkit)\n\n### Installation\nAs 3D network we use SparseConvNet. It requires to use CUDA 10.0 (it did not work with 10.1 when we tried).\nWe advise to create a new conda environment for installation. PyTorch and CUDA can be installed, and SparseConvNet\ninstalled/compiled as follows:\n```\n$ conda install pytorch torchvision cudatoolkit=10.0 -c pytorch\n$ pip install --upgrade git+https://github.com/facebookresearch/SparseConvNet.git\n```\n\nClone this repository and install it with pip. It will automatically install the nuscenes-devkit as a dependency.\n```\n$ git clone https://github.com/valeoai/xmuda.git\n$ cd xmuda\n$ pip install -ve .\n```\nThe `-e` option means that you can edit the code on the fly.\n\n### Datasets\n#### NuScenes\nPlease download the Full dataset (v1.0) from the [NuScenes website](https://www.nuscenes.org) and extract it.\n\nYou need to perform preprocessing to generate the data for xMUDA first.\nThe preprocessing subsamples the 360° LiDAR point cloud to only keep the points that project into\nthe front camera image. It also generates the point-wise segmentation labels using\nthe 3D objects by checking which points lie inside the 3D boxes. 
\n#### A2D2\nPlease download the Semantic Segmentation dataset and Sensor Configuration from the\n[Audi website](https://www.a2d2.audi/a2d2/en/download.html) or directly use `wget` with\nthe following links, then extract the archives.\n```\n$ wget https://aev-autonomous-driving-dataset.s3.eu-central-1.amazonaws.com/camera_lidar_semantic.tar\n$ wget https://aev-autonomous-driving-dataset.s3.eu-central-1.amazonaws.com/cams_lidars.json\n```\n\nThe dataset directory should have this basic structure:\n```\na2d2                                   % A2D2 dataset root\n ├── 20180807_145028\n ├── 20180810_142822\n ├── ...\n ├── cams_lidars.json\n └── class_list.json\n```\nFor preprocessing, we undistort the images and store them separately as .png files.\nAs in the NuScenes preprocessing, we save all points that project into the front camera image, as well\nas the segmentation labels, to a pickle file.\n\nPlease edit the script `xmuda/data/a2d2/preprocess.py` as follows and then run it.\n* `root_dir` should point to the root directory of the A2D2 dataset\n* `out_dir` should point to the desired output directory to store the undistorted images and pickle files.\nIt should be different from `root_dir` to prevent overwriting the original images.\n\n#### SemanticKITTI\nPlease download the files from the [SemanticKITTI website](http://semantic-kitti.org/dataset.html) and\nadditionally the [color data](http://www.cvlibs.net/download.php?file=data_odometry_color.zip)\nfrom the [Kitti Odometry website](http://www.cvlibs.net/datasets/kitti/eval_odometry.php). Extract\neverything into the same folder.\n\nAs in the NuScenes preprocessing, we save all points that project into the front camera image, as well\nas the segmentation labels, to a pickle file.\n\nPlease edit the script `xmuda/data/semantic_kitti/preprocess.py` as follows and then run it.\n* `root_dir` should point to the root directory of the SemanticKITTI dataset\n* `out_dir` should point to the desired output directory to store the pickle files\n\n## Training\n### xMUDA\nYou can run the training with\n```\n$ cd <root dir of this repo>\n$ python xmuda/train_xmuda.py --cfg=configs/nuscenes/usa_singapore/xmuda.yaml\n```\n\nThe output will be written to `/home/<user>/workspace/outputs/xmuda/<config_path>` by\ndefault. The `OUTPUT_DIR` can be modified in the config file\n(e.g. `configs/nuscenes/usa_singapore/xmuda.yaml`) or optionally at run time on the\ncommand line (the command line takes precedence over the config file). Note that `@` in the following example will be\nautomatically replaced with the config path, i.e. with `nuscenes/usa_singapore/xmuda`.\n```\n$ python xmuda/train_xmuda.py --cfg=configs/nuscenes/usa_singapore/xmuda.yaml OUTPUT_DIR path/to/output/directory/@\n```\n
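\nTo make the substitution concrete, the resulting output directories would be (illustration only; `~` is the user home):\n```\nOUTPUT_DIR path/to/output/directory/@   ->   path/to/output/directory/nuscenes/usa_singapore/xmuda\ndefault OUTPUT_DIR ~/workspace/outputs/xmuda/@   ->   ~/workspace/outputs/xmuda/nuscenes/usa_singapore/xmuda\n```\n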
\nYou can start training on the other UDA scenarios (Day/Night and A2D2/SemanticKITTI) analogously:\n```\n$ python xmuda/train_xmuda.py --cfg=configs/nuscenes/day_night/xmuda.yaml\n$ python xmuda/train_xmuda.py --cfg=configs/a2d2_semantic_kitti/xmuda.yaml\n```\n\n### xMUDA<sub>PL</sub>\nAfter training the xMUDA model, generate the pseudo-labels as follows:\n```\n$ python xmuda/test.py --cfg=configs/nuscenes/usa_singapore/xmuda.yaml --pselab @/model_2d_100000.pth @/model_3d_100000.pth DATASET_TARGET.TEST \"('train_singapore',)\"\n```\nNote that we use the last model at 100,000 steps, rather than the best-performing checkpoint, to avoid the indirect supervision from the validation set that picking the best weights would introduce. The pseudo-labels and maximum probabilities are saved as a `.npy` file.\n\nPlease edit the `pselab_paths` in the config file, e.g. `configs/nuscenes/usa_singapore/xmuda_pl.yaml`,\nto match the path of your generated pseudo-labels.\n
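\nFor example, in `configs/nuscenes/usa_singapore/xmuda_pl.yaml` the entry would look as follows (the path is a placeholder for your actual output directory):\n```\nDATASET_TARGET:\n  NuScenesSCN:\n    pselab_paths: (\"/path/to/outputs/xmuda/nuscenes/usa_singapore/xmuda/pselab_data/train_singapore.npy\",)\n```\n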
\nThen start the training. The pseudo-label refinement (discarding the less confident pseudo-labels) is done\nwhen the dataloader is initialized.\n```\n$ python xmuda/train_xmuda.py --cfg=configs/nuscenes/usa_singapore/xmuda_pl.yaml\n```\n\nYou can start training on the other UDA scenarios (Day/Night and A2D2/SemanticKITTI) analogously:\n```\n$ python xmuda/test.py --cfg=configs/nuscenes/day_night/xmuda.yaml --pselab @/model_2d_100000.pth @/model_3d_100000.pth DATASET_TARGET.TEST \"('train_night',)\"\n$ python xmuda/train_xmuda.py --cfg=configs/nuscenes/day_night/xmuda_pl.yaml\n\n# use batch size 1 because of the varying image sizes in SemanticKITTI\n$ python xmuda/test.py --cfg=configs/a2d2_semantic_kitti/xmuda.yaml --pselab @/model_2d_100000.pth @/model_3d_100000.pth DATASET_TARGET.TEST \"('train',)\" VAL.BATCH_SIZE 1\n$ python xmuda/train_xmuda.py --cfg=configs/a2d2_semantic_kitti/xmuda_pl.yaml\n```\n\n### Baseline\nTrain the baselines (source only) with:\n```\n$ python xmuda/train_baseline.py --cfg=configs/nuscenes/usa_singapore/baseline.yaml\n$ python xmuda/train_baseline.py --cfg=configs/nuscenes/day_night/baseline.yaml\n$ python xmuda/train_baseline.py --cfg=configs/a2d2_semantic_kitti/baseline.yaml\n```\n\n## Testing\nYou can specify which checkpoints to use for testing. We used the ones\nthat performed best on the validation set during training (the best validation iteration for 2D and 3D is\nshown at the end of each training). Note that `@` will be replaced\nby the output directory for that config file. For example:\n```\n$ cd <root dir of this repo>\n$ python xmuda/test.py --cfg=configs/nuscenes/usa_singapore/xmuda.yaml @/model_2d_065000.pth @/model_3d_095000.pth\n```\nYou can also provide an absolute path without `@`.\n\n## Model Zoo\n\nYou can download the models with the scores below from\n[this Google drive folder](https://drive.google.com/drive/folders/16MTKz4LOIwqQc3Vo6LAGrpiIC72hvggc?usp=sharing).\n\n| Method | USA/Singapore 2D | USA/Singapore 3D | Day/Night 2D | Day/Night 3D | A2D2/Sem.KITTI 2D | A2D2/Sem.KITTI 3D |\n| --- | --- | --- | --- | --- | --- | --- |\n| Baseline (source only) | 53.4 | 46.5 | 42.2 | 41.2 | 34.2<sup>*</sup> | 35.9<sup>*</sup> |\n| xMUDA | 59.3 | 52.0 | 46.2 | 44.2 | 38.3<sup>*</sup> | 46.0<sup>*</sup> |\n| xMUDA<sub>PL</sub> | 61.1 | 54.1 | 47.1 | 46.7 | 41.2<sup>*</sup> | 49.8<sup>*</sup> |\n\n<sup>*</sup> Slight differences from the paper on A2D2/Sem.KITTI: we now use class weights computed on the source domain.\nIn the paper, we mistakenly computed the class weights on the target domain.\n\n## Acknowledgements\nNote that this code borrows from the [MVPNet](https://github.com/maxjaritz/mvpnet) repo.\n\n## License\nxMUDA is released under the [Apache 2.0 license](./LICENSE).\n"
  },
  {
    "path": "configs/a2d2_semantic_kitti/baseline.yaml",
    "content": "MODEL_2D:\n  TYPE: \"UNetResNet34\"\n  NUM_CLASSES: 10\nMODEL_3D:\n  TYPE: \"SCN\"\n  NUM_CLASSES: 10\nDATASET_SOURCE:\n  TYPE: \"A2D2SCN\"\n  TRAIN: (\"train\",)\n  A2D2SCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/a2d2_preprocess\"\nDATASET_TARGET:\n  TYPE: \"SemanticKITTISCN\"\n  VAL: (\"val\",)\n  TEST: (\"test\",)\n  SemanticKITTISCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/semantic_kitti_preprocess/preprocess\"\n    semantic_kitti_dir: \"/datasets_local/datasets_mjaritz/semantic_kitti_preprocess\"\nDATALOADER:\n  NUM_WORKERS: 4\nOPTIMIZER:\n  TYPE: \"Adam\"\n  BASE_LR: 0.001\nSCHEDULER:\n  TYPE: \"MultiStepLR\"\n  MultiStepLR:\n    gamma: 0.1\n    milestones: (80000, 90000)\n  MAX_ITERATION: 100000\nTRAIN:\n  BATCH_SIZE: 8\n  SUMMARY_PERIOD: 50\n  CHECKPOINT_PERIOD: 5000\n  CLASS_WEIGHTS: [1.89090012, 2.0585112, 3.1970535, 3.1111633, 1., 2.93751704, 1.92053733,\n                  1.47886874, 1.04654198, 1.78266561]\nVAL:\n  BATCH_SIZE: 8\n  PERIOD: 5000\n#OUTPUT_DIR: \"path/to/output/directory/@\"  #  @ will be replaced with config path, e.g. nuscenes/usa_singapore/xmuda"
  },
  {
    "path": "configs/a2d2_semantic_kitti/xmuda.yaml",
    "content": "MODEL_2D:\n  TYPE: \"UNetResNet34\"\n  DUAL_HEAD: True\n  NUM_CLASSES: 10\nMODEL_3D:\n  TYPE: \"SCN\"\n  DUAL_HEAD: True\n  NUM_CLASSES: 10\nDATASET_SOURCE:\n  TYPE: \"A2D2SCN\"\n  TRAIN: (\"train\",)\n  A2D2SCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/a2d2_preprocess\"\nDATASET_TARGET:\n  TYPE: \"SemanticKITTISCN\"\n  TRAIN: (\"train\",)\n  VAL: (\"val\",)\n  TEST: (\"test\",)\n  SemanticKITTISCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/semantic_kitti_preprocess/preprocess\"\n    semantic_kitti_dir: \"/datasets_local/datasets_mjaritz/semantic_kitti_preprocess\"\nDATALOADER:\n  NUM_WORKERS: 4\nOPTIMIZER:\n  TYPE: \"Adam\"\n  BASE_LR: 0.001\nSCHEDULER:\n  TYPE: \"MultiStepLR\"\n  MultiStepLR:\n    gamma: 0.1\n    milestones: (80000, 90000)\n  MAX_ITERATION: 100000\nTRAIN:\n  BATCH_SIZE: 8\n  SUMMARY_PERIOD: 50\n  CHECKPOINT_PERIOD: 5000\n  CLASS_WEIGHTS: [1.89090012, 2.0585112, 3.1970535, 3.1111633, 1., 2.93751704, 1.92053733,\n                  1.47886874, 1.04654198, 1.78266561]\n  XMUDA:\n    lambda_xm_src: 0.1\n    lambda_xm_trg: 0.01\nVAL:\n  BATCH_SIZE: 2\n  PERIOD: 5000\n#OUTPUT_DIR: \"path/to/output/directory/@\"  #  @ will be replaced with config path, e.g. nuscenes/usa_singapore/xmuda"
  },
  {
    "path": "configs/a2d2_semantic_kitti/xmuda_pl.yaml",
    "content": "MODEL_2D:\n  TYPE: \"UNetResNet34\"\n  DUAL_HEAD: True\n  NUM_CLASSES: 10\nMODEL_3D:\n  TYPE: \"SCN\"\n  DUAL_HEAD: True\n  NUM_CLASSES: 10\nDATASET_SOURCE:\n  TYPE: \"A2D2SCN\"\n  TRAIN: (\"train\",)\n  A2D2SCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/a2d2_preprocess\"\nDATASET_TARGET:\n  TYPE: \"SemanticKITTISCN\"\n  TRAIN: (\"train\",)\n  VAL: (\"val\",)\n  TEST: (\"test\",)\n  SemanticKITTISCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/semantic_kitti_preprocess/preprocess\"\n    semantic_kitti_dir: \"/datasets_local/datasets_mjaritz/semantic_kitti_preprocess\"\n    pselab_paths: (\"/home/docker_user/workspace/outputs/xmuda/a2d2_semantic_kitti/xmuda/pselab_data/train.npy\",)\nDATALOADER:\n  NUM_WORKERS: 4\nOPTIMIZER:\n  TYPE: \"Adam\"\n  BASE_LR: 0.001\nSCHEDULER:\n  TYPE: \"MultiStepLR\"\n  MultiStepLR:\n    gamma: 0.1\n    milestones: (80000, 90000)\n  MAX_ITERATION: 100000\nTRAIN:\n  BATCH_SIZE: 8\n  SUMMARY_PERIOD: 50\n  CHECKPOINT_PERIOD: 5000\n  CLASS_WEIGHTS: [1.89090012, 2.0585112, 3.1970535, 3.1111633, 1., 2.93751704, 1.92053733,\n                  1.47886874, 1.04654198, 1.78266561]\n  XMUDA:\n    lambda_xm_src: 0.1\n    lambda_xm_trg: 0.01\n    lambda_pl: 1.0\nVAL:\n  BATCH_SIZE: 2\n  PERIOD: 5000\n#OUTPUT_DIR: \"path/to/output/directory/@\"  #  @ will be replaced with config path, e.g. nuscenes/usa_singapore/xmuda"
  },
  {
    "path": "configs/nuscenes/day_night/baseline.yaml",
    "content": "MODEL_2D:\n  TYPE: \"UNetResNet34\"\nMODEL_3D:\n  TYPE: \"SCN\"\nDATASET_SOURCE:\n  TYPE: \"NuScenesSCN\"\n  TRAIN: (\"train_day\",)\n  NuScenesSCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess/preprocess\"\n    nuscenes_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess\"  # only front cam images are needed\nDATASET_TARGET:\n  TYPE: \"NuScenesSCN\"\n  VAL: (\"val_night\",)\n  TEST: (\"test_night\",)\n  NuScenesSCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess/preprocess\"\n    nuscenes_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess\"  # only front cam images are needed\nDATALOADER:\n  NUM_WORKERS: 4\nOPTIMIZER:\n  TYPE: \"Adam\"\n  BASE_LR: 0.001\nSCHEDULER:\n  TYPE: \"MultiStepLR\"\n  MultiStepLR:\n    gamma: 0.1\n    milestones: (80000, 90000)\n  MAX_ITERATION: 100000\nTRAIN:\n  BATCH_SIZE: 8\n  SUMMARY_PERIOD: 50\n  CHECKPOINT_PERIOD: 5000\n  CLASS_WEIGHTS: [2.68678412, 4.36182969, 5.47896839, 3.89026883, 1.]\nVAL:\n  BATCH_SIZE: 32\n  PERIOD: 5000\n#OUTPUT_DIR: \"path/to/output/directory/@\"  #  @ will be replaced with config path, e.g. nuscenes/usa_singapore/xmuda"
  },
  {
    "path": "configs/nuscenes/day_night/xmuda.yaml",
    "content": "MODEL_2D:\n  TYPE: \"UNetResNet34\"\n  DUAL_HEAD: True\nMODEL_3D:\n  TYPE: \"SCN\"\n  DUAL_HEAD: True\nDATASET_SOURCE:\n  TYPE: \"NuScenesSCN\"\n  TRAIN: (\"train_day\",)\n  NuScenesSCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess/preprocess\"\n    nuscenes_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess\"  # only front cam images are needed\nDATASET_TARGET:\n  TYPE: \"NuScenesSCN\"\n  TRAIN: (\"train_night\",)\n  VAL: (\"val_night\",)\n  TEST: (\"test_night\",)\n  NuScenesSCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess/preprocess\"\n    nuscenes_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess\"  # only front cam images are needed\nDATALOADER:\n  NUM_WORKERS: 4\nOPTIMIZER:\n  TYPE: \"Adam\"\n  BASE_LR: 0.001\nSCHEDULER:\n  TYPE: \"MultiStepLR\"\n  MultiStepLR:\n    gamma: 0.1\n    milestones: (80000, 90000)\n  MAX_ITERATION: 100000\nTRAIN:\n  BATCH_SIZE: 8\n  SUMMARY_PERIOD: 50\n  CHECKPOINT_PERIOD: 5000\n  CLASS_WEIGHTS: [2.68678412, 4.36182969, 5.47896839, 3.89026883, 1.]\n  XMUDA:\n    lambda_xm_src: 1.0\n    lambda_xm_trg: 0.1\nVAL:\n  BATCH_SIZE: 32\n  PERIOD: 5000\n#OUTPUT_DIR: \"path/to/output/directory/@\"  #  @ will be replaced with config path, e.g. nuscenes/usa_singapore/xmuda"
  },
  {
    "path": "configs/nuscenes/day_night/xmuda_pl.yaml",
    "content": "MODEL_2D:\n  TYPE: \"UNetResNet34\"\n  DUAL_HEAD: True\nMODEL_3D:\n  TYPE: \"SCN\"\n  DUAL_HEAD: True\nDATASET_SOURCE:\n  TYPE: \"NuScenesSCN\"\n  TRAIN: (\"train_day\",)\n  NuScenesSCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess/preprocess\"\n    nuscenes_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess\"  # only front cam images are needed\nDATASET_TARGET:\n  TYPE: \"NuScenesSCN\"\n  TRAIN: (\"train_night\",)\n  VAL: (\"val_night\",)\n  TEST: (\"test_night\",)\n  NuScenesSCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess/preprocess\"\n    nuscenes_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess\"  # only front cam images are needed\n    pselab_paths: (\"/home/docker_user/workspace/outputs/xmuda/nuscenes/day_night/xmuda/pselab_data/train_night.npy\",)\nDATALOADER:\n  NUM_WORKERS: 4\nOPTIMIZER:\n  TYPE: \"Adam\"\n  BASE_LR: 0.001\nSCHEDULER:\n  TYPE: \"MultiStepLR\"\n  MultiStepLR:\n    gamma: 0.1\n    milestones: (80000, 90000)\n  MAX_ITERATION: 100000\nTRAIN:\n  BATCH_SIZE: 8\n  SUMMARY_PERIOD: 50\n  CHECKPOINT_PERIOD: 5000\n  CLASS_WEIGHTS: [2.68678412, 4.36182969, 5.47896839, 3.89026883, 1.]\n  XMUDA:\n    lambda_xm_src: 1.0\n    lambda_xm_trg: 0.1\n    lambda_pl: 1.0\nVAL:\n  BATCH_SIZE: 32\n  PERIOD: 5000\n#OUTPUT_DIR: \"path/to/output/directory/@\"  #  @ will be replaced with config path, e.g. nuscenes/usa_singapore/xmuda"
  },
  {
    "path": "configs/nuscenes/usa_singapore/baseline.yaml",
    "content": "MODEL_2D:\n  TYPE: \"UNetResNet34\"\nMODEL_3D:\n  TYPE: \"SCN\"\nDATASET_SOURCE:\n  TYPE: \"NuScenesSCN\"\n  TRAIN: (\"train_usa\",)\n  NuScenesSCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess/preprocess\"\n    nuscenes_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess\"  # only front cam images are needed\nDATASET_TARGET:\n  TYPE: \"NuScenesSCN\"\n  VAL: (\"val_singapore\",)\n  TEST: (\"test_singapore\",)\n  NuScenesSCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess/preprocess\"\n    nuscenes_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess\"  # only front cam images are needed\nDATALOADER:\n  NUM_WORKERS: 4\nOPTIMIZER:\n  TYPE: \"Adam\"\n  BASE_LR: 0.001\nSCHEDULER:\n  TYPE: \"MultiStepLR\"\n  MultiStepLR:\n    gamma: 0.1\n    milestones: (80000, 90000)\n  MAX_ITERATION: 100000\nTRAIN:\n  BATCH_SIZE: 8\n  SUMMARY_PERIOD: 50\n  CHECKPOINT_PERIOD: 5000\n  CLASS_WEIGHTS: [2.47956584, 4.26788384, 5.71114131, 3.80241668, 1.]\nVAL:\n  BATCH_SIZE: 32\n  PERIOD: 5000\n#OUTPUT_DIR: \"path/to/output/directory/@\"  #  @ will be replaced with config path, e.g. nuscenes/usa_singapore/xmuda"
  },
  {
    "path": "configs/nuscenes/usa_singapore/xmuda.yaml",
    "content": "MODEL_2D:\n  TYPE: \"UNetResNet34\"\n  DUAL_HEAD: True\nMODEL_3D:\n  TYPE: \"SCN\"\n  DUAL_HEAD: True\nDATASET_SOURCE:\n  TYPE: \"NuScenesSCN\"\n  TRAIN: (\"train_usa\",)\n  NuScenesSCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess/preprocess\"\n    nuscenes_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess\"  # only front cam images are needed\nDATASET_TARGET:\n  TYPE: \"NuScenesSCN\"\n  TRAIN: (\"train_singapore\",)\n  VAL: (\"val_singapore\",)\n  TEST: (\"test_singapore\",)\n  NuScenesSCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess/preprocess\"\n    nuscenes_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess\"  # only front cam images are needed\nDATALOADER:\n  NUM_WORKERS: 4\nOPTIMIZER:\n  TYPE: \"Adam\"\n  BASE_LR: 0.001\nSCHEDULER:\n  TYPE: \"MultiStepLR\"\n  MultiStepLR:\n    gamma: 0.1\n    milestones: (80000, 90000)\n  MAX_ITERATION: 100000\nTRAIN:\n  BATCH_SIZE: 8\n  SUMMARY_PERIOD: 50\n  CHECKPOINT_PERIOD: 5000\n  CLASS_WEIGHTS: [2.47956584, 4.26788384, 5.71114131, 3.80241668, 1.]\n  XMUDA:\n    lambda_xm_src: 1.0\n    lambda_xm_trg: 0.1\nVAL:\n  BATCH_SIZE: 32\n  PERIOD: 5000\n#OUTPUT_DIR: \"path/to/output/directory/@\"  #  @ will be replaced with config path, e.g. nuscenes/usa_singapore/xmuda"
  },
  {
    "path": "configs/nuscenes/usa_singapore/xmuda_pl.yaml",
    "content": "MODEL_2D:\n  TYPE: \"UNetResNet34\"\n  DUAL_HEAD: True\nMODEL_3D:\n  TYPE: \"SCN\"\n  DUAL_HEAD: True\nDATASET_SOURCE:\n  TYPE: \"NuScenesSCN\"\n  TRAIN: (\"train_usa\",)\n  NuScenesSCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess/preprocess\"\n    nuscenes_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess\"  # only front cam images are needed\nDATASET_TARGET:\n  TYPE: \"NuScenesSCN\"\n  TRAIN: (\"train_singapore\",)\n  VAL: (\"val_singapore\",)\n  TEST: (\"test_singapore\",)\n  NuScenesSCN:\n    preprocess_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess/preprocess\"\n    nuscenes_dir: \"/datasets_local/datasets_mjaritz/nuscenes_preprocess\"  # only front cam images are needed\n    pselab_paths: (\"/home/docker_user/workspace/outputs/xmuda/nuscenes/usa_singapore/xmuda/pselab_data/train_singapore.npy\",)\nDATALOADER:\n  NUM_WORKERS: 4\nOPTIMIZER:\n  TYPE: \"Adam\"\n  BASE_LR: 0.001\nSCHEDULER:\n  TYPE: \"MultiStepLR\"\n  MultiStepLR:\n    gamma: 0.1\n    milestones: (80000, 90000)\n  MAX_ITERATION: 100000\nTRAIN:\n  BATCH_SIZE: 8\n  SUMMARY_PERIOD: 50\n  CHECKPOINT_PERIOD: 5000\n  CLASS_WEIGHTS: [2.47956584, 4.26788384, 5.71114131, 3.80241668, 1.]\n  XMUDA:\n    lambda_xm_src: 1.0\n    lambda_xm_trg: 0.1\n    lambda_pl: 1.0\nVAL:\n  BATCH_SIZE: 32\n  PERIOD: 5000\n#OUTPUT_DIR: \"path/to/output/directory/@\"  #  @ will be replaced with config path, e.g. nuscenes/usa_singapore/xmuda"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup\nfrom setuptools import find_packages\n\nexclude_dirs = (\"configs\",)\n\n# for install, do: pip install -ve .\n\nsetup(\n    name='xmuda',\n    version=\"0.0.1\",\n    url=\"https://github.com/maxjaritz/xmuda\",\n    description=\"xMUDA: Cross-Modal Unsupervised Domain Adaptation for 3D Semantic Segmentation\",\n    install_requires=['yacs', 'nuscenes-devkit', 'tabulate'],\n    packages=find_packages(exclude=exclude_dirs),\n)"
  },
  {
    "path": "xmuda/common/config/__init__.py",
    "content": "from yacs.config import CfgNode\n\n\ndef purge_cfg(cfg: CfgNode):\n    \"\"\"Purge configuration for clean logs and logical check.\n    If a CfgNode has 'TYPE' attribute, its CfgNode children the key of which do not contain 'TYPE' will be removed.\n    \"\"\"\n    target_key = cfg.get('TYPE', None)\n    removed_keys = []\n    for k, v in cfg.items():\n        if isinstance(v, CfgNode):\n            if target_key is not None and (k != target_key):\n                removed_keys.append(k)\n            else:\n                purge_cfg(v)\n    for k in removed_keys:\n        del cfg[k]\n"
  },
  {
    "path": "xmuda/common/config/base.py",
    "content": "\"\"\"Basic experiments configuration\nFor different tasks, a specific configuration might be created by importing this basic config.\n\"\"\"\n\nfrom yacs.config import CfgNode as CN\n\n# ---------------------------------------------------------------------------- #\n# Config definition\n# ---------------------------------------------------------------------------- #\n_C = CN()\n\n# ---------------------------------------------------------------------------- #\n# Resume\n# ---------------------------------------------------------------------------- #\n# Automatically resume weights from last checkpoints\n_C.AUTO_RESUME = True\n# Whether to resume the optimizer and the scheduler\n_C.RESUME_STATES = True\n# Path of weights to resume\n_C.RESUME_PATH = ''\n\n# ---------------------------------------------------------------------------- #\n# Model\n# ---------------------------------------------------------------------------- #\n_C.MODEL = CN()\n_C.MODEL.TYPE = ''\n\n# ---------------------------------------------------------------------------- #\n# DataLoader\n# ---------------------------------------------------------------------------- #\n_C.DATALOADER = CN()\n# Number of data loading threads\n_C.DATALOADER.NUM_WORKERS = 0\n# Whether to drop last\n_C.DATALOADER.DROP_LAST = True\n\n# ---------------------------------------------------------------------------- #\n# Optimizer\n# ---------------------------------------------------------------------------- #\n_C.OPTIMIZER = CN()\n_C.OPTIMIZER.TYPE = ''\n\n# Basic parameters of the optimizer\n# Note that the learning rate should be changed according to batch size\n_C.OPTIMIZER.BASE_LR = 0.001\n_C.OPTIMIZER.WEIGHT_DECAY = 0.0\n\n# Specific parameters of optimizers\n_C.OPTIMIZER.SGD = CN()\n_C.OPTIMIZER.SGD.momentum = 0.9\n_C.OPTIMIZER.SGD.dampening = 0.0\n\n_C.OPTIMIZER.Adam = CN()\n_C.OPTIMIZER.Adam.betas = (0.9, 0.999)\n\n# ---------------------------------------------------------------------------- #\n# Scheduler (learning rate schedule)\n# ---------------------------------------------------------------------------- #\n_C.SCHEDULER = CN()\n_C.SCHEDULER.TYPE = ''\n\n_C.SCHEDULER.MAX_ITERATION = 1\n# Minimum learning rate. 0.0 for disable.\n_C.SCHEDULER.CLIP_LR = 0.0\n\n# Specific parameters of schedulers\n_C.SCHEDULER.StepLR = CN()\n_C.SCHEDULER.StepLR.step_size = 0\n_C.SCHEDULER.StepLR.gamma = 0.1\n\n_C.SCHEDULER.MultiStepLR = CN()\n_C.SCHEDULER.MultiStepLR.milestones = ()\n_C.SCHEDULER.MultiStepLR.gamma = 0.1\n\n# ---------------------------------------------------------------------------- #\n# Specific train options\n# ---------------------------------------------------------------------------- #\n_C.TRAIN = CN()\n\n# Batch size\n_C.TRAIN.BATCH_SIZE = 1\n# Period to save checkpoints. 0 for disable\n_C.TRAIN.CHECKPOINT_PERIOD = 0\n# Period to log training status. 0 for disable\n_C.TRAIN.LOG_PERIOD = 50\n# Period to summary training status. 0 for disable\n_C.TRAIN.SUMMARY_PERIOD = 0\n# Max number of checkpoints to keep\n_C.TRAIN.MAX_TO_KEEP = 100\n\n# Regex patterns of modules and/or parameters to freeze\n_C.TRAIN.FROZEN_PATTERNS = ()\n\n# ---------------------------------------------------------------------------- #\n# Specific validation options\n# ---------------------------------------------------------------------------- #\n_C.VAL = CN()\n\n# Batch size\n_C.VAL.BATCH_SIZE = 1\n# Period to validate. 0 for disable\n_C.VAL.PERIOD = 0\n# Period to log validation status. 
0 for disable\n_C.VAL.LOG_PERIOD = 20\n# The metric for best validation performance\n_C.VAL.METRIC = ''\n\n# ---------------------------------------------------------------------------- #\n# Misc options\n# ---------------------------------------------------------------------------- #\n# If set to '@', it will be replaced with the config path (e.g. nuscenes/usa_singapore/xmuda)\n_C.OUTPUT_DIR = '@'\n\n# For reproducibility...but not really because modern fast GPU libraries use\n# non-deterministic op implementations\n# -1 means use time seed.\n_C.RNG_SEED = 1\n"
  },
  {
    "path": "xmuda/common/solver/__init__.py",
    "content": ""
  },
  {
    "path": "xmuda/common/solver/build.py",
    "content": "\"\"\"Build optimizers and schedulers\"\"\"\nimport warnings\nimport torch\nfrom .lr_scheduler import ClipLR\n\n\ndef build_optimizer(cfg, model):\n    name = cfg.OPTIMIZER.TYPE\n    if name == '':\n        warnings.warn('No optimizer is built.')\n        return None\n    elif hasattr(torch.optim, name):\n        return getattr(torch.optim, name)(\n            model.parameters(),\n            lr=cfg.OPTIMIZER.BASE_LR,\n            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY,\n            **cfg.OPTIMIZER.get(name, dict()),\n        )\n    else:\n        raise ValueError('Unsupported type of optimizer.')\n\n\ndef build_scheduler(cfg, optimizer):\n    name = cfg.SCHEDULER.TYPE\n    if name == '':\n        warnings.warn('No scheduler is built.')\n        return None\n    elif hasattr(torch.optim.lr_scheduler, name):\n        scheduler = getattr(torch.optim.lr_scheduler, name)(\n            optimizer,\n            **cfg.SCHEDULER.get(name, dict()),\n        )\n    else:\n        raise ValueError('Unsupported type of scheduler.')\n\n    # clip learning rate\n    if cfg.SCHEDULER.CLIP_LR > 0.0:\n        print('Learning rate is clipped to {}'.format(cfg.SCHEDULER.CLIP_LR))\n        scheduler = ClipLR(scheduler, min_lr=cfg.SCHEDULER.CLIP_LR)\n\n    return scheduler\n"
  },
  {
    "path": "xmuda/common/solver/lr_scheduler.py",
    "content": "from __future__ import division\nfrom bisect import bisect_right\nfrom torch.optim.lr_scheduler import _LRScheduler, MultiStepLR\n\n\nclass WarmupMultiStepLR(_LRScheduler):\n    \"\"\"https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/solver/lr_scheduler.py\"\"\"\n\n    def __init__(\n        self,\n        optimizer,\n        milestones,\n        gamma=0.1,\n        warmup_factor=0.1,\n        warmup_steps=1,\n        warmup_method=\"linear\",\n        last_epoch=-1,\n    ):\n        if not list(milestones) == sorted(milestones):\n            raise ValueError(\n                \"Milestones should be a list of\" \" increasing integers. Got {}\",\n                milestones,\n            )\n\n        if warmup_method not in (\"constant\", \"linear\"):\n            raise ValueError(\n                \"Only 'constant' or 'linear' warmup_method accepted\"\n                \"got {}\".format(warmup_method)\n            )\n        self.milestones = milestones\n        self.gamma = gamma\n        self.warmup_factor = warmup_factor\n        self.warmup_steps = warmup_steps\n        self.warmup_method = warmup_method\n        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)\n\n    def get_lr(self):\n        warmup_factor = 1\n        if self.last_epoch < self.warmup_steps:\n            if self.warmup_method == \"constant\":\n                warmup_factor = self.warmup_factor\n            elif self.warmup_method == \"linear\":\n                alpha = float(self.last_epoch) / self.warmup_steps\n                warmup_factor = self.warmup_factor * (1 - alpha) + alpha\n        return [\n            base_lr\n            * warmup_factor\n            * self.gamma ** bisect_right(self.milestones, self.last_epoch)\n            for base_lr in self.base_lrs\n        ]\n\n\nclass ClipLR(object):\n    \"\"\"Clip the learning rate of a given scheduler.\n    Same interfaces of _LRScheduler should be implemented.\n\n    Args:\n        scheduler (_LRScheduler): an instance of _LRScheduler.\n        min_lr (float): minimum learning rate.\n\n    \"\"\"\n\n    def __init__(self, scheduler, min_lr=1e-5):\n        assert isinstance(scheduler, _LRScheduler)\n        self.scheduler = scheduler\n        self.min_lr = min_lr\n\n    def get_lr(self):\n        return [max(self.min_lr, lr) for lr in self.scheduler.get_lr()]\n\n    def __getattr__(self, item):\n        if hasattr(self.scheduler, item):\n            return getattr(self.scheduler, item)\n        else:\n            return getattr(self, item)\n"
  },
  {
    "path": "xmuda/common/utils/checkpoint.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n# Modified by Jiayuan Gu\nimport os\nimport logging\n\nimport torch\nfrom torch.nn.parallel import DataParallel, DistributedDataParallel\n\nfrom .io import get_md5\n\n\nclass Checkpointer(object):\n    \"\"\"Checkpoint the model and relevant states.\n\n    Supported features:\n    1. Resume optimizer and scheduler\n    2. Automatically deal with DataParallel, DistributedDataParallel\n    3. Resume last saved checkpoint\n\n    \"\"\"\n\n    def __init__(self,\n                 model,\n                 optimizer=None,\n                 scheduler=None,\n                 save_dir='',\n                 logger=None,\n                 postfix=''\n                 ):\n        self.model = model\n        self.optimizer = optimizer\n        self.scheduler = scheduler\n        self.save_dir = save_dir\n        # logging\n        self.logger = logger\n        self._print = logger.info if logger else print\n        self.postfix = postfix\n\n    def save(self, name, tag=True, **kwargs):\n        if not self.save_dir:\n            return\n\n        data = dict()\n        if isinstance(self.model, (DataParallel, DistributedDataParallel)):\n            data['model'] = self.model.module.state_dict()\n        else:\n            data['model'] = self.model.state_dict()\n        if self.optimizer is not None:\n            data['optimizer'] = self.optimizer.state_dict()\n        if self.scheduler is not None:\n            data['scheduler'] = self.scheduler.state_dict()\n        data.update(kwargs)\n\n        save_file = os.path.join(self.save_dir, '{}.pth'.format(name))\n        self._print('Saving checkpoint to {}'.format(os.path.abspath(save_file)))\n        torch.save(data, save_file)\n        if tag:\n            self.tag_last_checkpoint(save_file)\n\n    def load(self, path=None, resume=True, resume_states=True):\n        if resume and self.has_checkpoint():\n            # override argument with existing checkpoint\n            path = self.get_checkpoint_file()\n        if not path:\n            # no checkpoint could be found\n            self._print('No checkpoint found. 
Initializing model from scratch')\n            return {}\n\n        self._print('Loading checkpoint from {}, MD5: {}'.format(path, get_md5(path)))\n        checkpoint = self._load_file(path)\n\n        if isinstance(self.model, (DataParallel, DistributedDataParallel)):\n            self.model.module.load_state_dict(checkpoint.pop('model'))\n        else:\n            self.model.load_state_dict(checkpoint.pop('model'))\n        if resume_states:\n            if 'optimizer' in checkpoint and self.optimizer:\n                self._print('Loading optimizer from {}'.format(path))\n                self.optimizer.load_state_dict(checkpoint.pop('optimizer'))\n            if 'scheduler' in checkpoint and self.scheduler:\n                self._print('Loading scheduler from {}'.format(path))\n                self.scheduler.load_state_dict(checkpoint.pop('scheduler'))\n        else:\n            checkpoint = {}\n\n        # return any further checkpoint data\n        return checkpoint\n\n    def has_checkpoint(self):\n        save_file = os.path.join(self.save_dir, 'last_checkpoint' + self.postfix)\n        return os.path.exists(save_file)\n\n    def get_checkpoint_file(self):\n        save_file = os.path.join(self.save_dir, 'last_checkpoint' + self.postfix)\n        try:\n            with open(save_file, 'r') as f:\n                last_saved = f.read()\n            # If not absolute path, add save_dir as prefix\n            if not os.path.isabs(last_saved):\n                last_saved = os.path.join(self.save_dir, last_saved)\n        except IOError:\n            # If file doesn't exist, maybe because it has just been\n            # deleted by a separate process\n            last_saved = ''\n        return last_saved\n\n    def tag_last_checkpoint(self, last_filename):\n        save_file = os.path.join(self.save_dir, 'last_checkpoint' + self.postfix)\n        # If not absolute path, only save basename\n        if not os.path.isabs(last_filename):\n            last_filename = os.path.basename(last_filename)\n        with open(save_file, 'w') as f:\n            f.write(last_filename)\n\n    def _load_file(self, path):\n        return torch.load(path, map_location=torch.device('cpu'))\n\n\nclass CheckpointerV2(Checkpointer):\n    \"\"\"Support max_to_keep like tf.Saver\"\"\"\n\n    def __init__(self, *args, max_to_keep=5, **kwargs):\n        super(CheckpointerV2, self).__init__(*args, **kwargs)\n        self.max_to_keep = max_to_keep\n        self._last_checkpoints = []\n\n    def get_checkpoint_file(self):\n        save_file = os.path.join(self.save_dir, 'last_checkpoint' + self.postfix)\n        try:\n            self._last_checkpoints = self._load_last_checkpoints(save_file)\n            last_saved = self._last_checkpoints[-1]\n        except (IOError, IndexError):\n            # If file doesn't exist, maybe because it has just been\n            # deleted by a separate process\n            last_saved = ''\n        return last_saved\n\n    def tag_last_checkpoint(self, last_filename):\n        save_file = os.path.join(self.save_dir, 'last_checkpoint' + self.postfix)\n        # Remove first from list if the same name was used before.\n        for path in self._last_checkpoints:\n            if last_filename == path:\n                self._last_checkpoints.remove(path)\n        # Append new path to list\n        self._last_checkpoints.append(last_filename)\n        # If more than max_to_keep, remove the oldest.\n        self._delete_old_checkpoint()\n        # Dump last checkpoints to a file\n
        self._save_checkpoint_file(save_file)\n\n    def _delete_old_checkpoint(self):\n        if len(self._last_checkpoints) > self.max_to_keep:\n            path = self._last_checkpoints.pop(0)\n            try:\n                os.remove(path)\n            except Exception as e:\n                logging.warning(\"Ignoring: %s\", str(e))\n\n    def _save_checkpoint_file(self, path):\n        with open(path, 'w') as f:\n            lines = []\n            for p in self._last_checkpoints:\n                if not os.path.isabs(p):\n                    # If not absolute path, only save basename\n                    p = os.path.basename(p)\n                lines.append(p)\n            f.write('\\n'.join(lines))\n\n    def _load_last_checkpoints(self, path):\n        last_checkpoints = []\n        with open(path, 'r') as f:\n            # splitlines() drops the trailing newlines that readlines() would\n            # leave in the stored paths\n            for p in f.read().splitlines():\n                if not os.path.isabs(p):\n                    # If not absolute path, add save_dir as prefix\n                    p = os.path.join(self.save_dir, p)\n                last_checkpoints.append(p)\n        return last_checkpoints\n"
  },
  {
    "path": "xmuda/common/utils/io.py",
    "content": "import hashlib\n\n\ndef get_md5(filename):\n    hash_obj = hashlib.md5()\n    with open(filename, 'rb') as f:\n        hash_obj.update(f.read())\n    return hash_obj.hexdigest()\n"
  },
  {
    "path": "xmuda/common/utils/logger.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n# Modified by Jiayuan Gu\nimport logging\nimport os\nimport sys\n\n\ndef setup_logger(name, save_dir, comment=''):\n    logger = logging.getLogger(name)\n    logger.setLevel(logging.DEBUG)\n    ch = logging.StreamHandler(stream=sys.stdout)\n    ch.setLevel(logging.DEBUG)\n    formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s: %(message)s')\n    ch.setFormatter(formatter)\n    logger.addHandler(ch)\n\n    if save_dir:\n        filename = 'log'\n        if comment:\n            filename += '.' + comment\n        log_file = os.path.join(save_dir, filename + '.txt')\n        fh = logging.FileHandler(log_file)\n        fh.setLevel(logging.DEBUG)\n        fh.setFormatter(formatter)\n        logger.addHandler(fh)\n\n    return logger\n"
  },
  {
    "path": "xmuda/common/utils/metric_logger.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n# Modified by Jiayuan Gu\nfrom __future__ import division\nfrom collections import defaultdict\nfrom collections import deque\n\nimport numpy as np\nimport torch\n\n\nclass AverageMeter(object):\n    \"\"\"Track a series of values and provide access to smoothed values over a\n    window or the global series average.\n    \"\"\"\n    default_fmt = '{avg:.4f} ({global_avg:.4f})'\n    default_summary_fmt = '{global_avg:.4f}'\n\n    def __init__(self, window_size=20, fmt=None, summary_fmt=None):\n        self.values = deque(maxlen=window_size)\n        self.counts = deque(maxlen=window_size)\n        self.sum = 0.0\n        self.count = 0\n        self.fmt = fmt or self.default_fmt\n        self.summary_fmt = summary_fmt or self.default_summary_fmt\n\n    def update(self, value, count=1):\n        self.values.append(value)\n        self.counts.append(count)\n        self.sum += value\n        self.count += count\n\n    @property\n    def avg(self):\n        return np.sum(self.values) / np.sum(self.counts)\n\n    @property\n    def global_avg(self):\n        return self.sum / self.count if self.count != 0 else float('nan')\n\n    def reset(self):\n        self.values.clear()\n        self.counts.clear()\n        self.sum = 0.0\n        self.count = 0\n\n    def __str__(self):\n        return self.fmt.format(avg=self.avg, global_avg=self.global_avg)\n\n    @property\n    def summary_str(self):\n        return self.summary_fmt.format(global_avg=self.global_avg)\n\n\nclass MetricLogger(object):\n    \"\"\"Metric logger.\n    All the meters should implement following methods:\n        __str__, summary_str, reset\n    \"\"\"\n\n    def __init__(self, delimiter='\\t'):\n        self.meters = defaultdict(AverageMeter)\n        self.delimiter = delimiter\n\n    def update(self, **kwargs):\n        for k, v in kwargs.items():\n            if isinstance(v, torch.Tensor):\n                count = v.numel()\n                value = v.item() if count == 1 else v.sum().item()\n            elif isinstance(v, np.ndarray):\n                count = v.size\n                value = v.item() if count == 1 else v.sum().item()\n            else:\n                assert isinstance(v, (float, int))\n                value = v\n                count = 1\n            self.meters[k].update(value, count)\n\n    def add_meter(self, name, meter):\n        self.meters[name] = meter\n\n    def add_meters(self, meters):\n        if not isinstance(meters, (list, tuple)):\n            meters = [meters]\n        for meter in meters:\n            self.add_meter(meter.name, meter)\n\n    def __getattr__(self, attr):\n        if attr in self.meters:\n            return self.meters[attr]\n        return getattr(self, attr)\n\n    def __str__(self):\n        metric_str = []\n        for name, meter in self.meters.items():\n            metric_str.append('{}: {}'.format(name, str(meter)))\n        return self.delimiter.join(metric_str)\n\n    @property\n    def summary_str(self):\n        metric_str = []\n        for name, meter in self.meters.items():\n            metric_str.append('{}: {}'.format(name, meter.summary_str))\n        return self.delimiter.join(metric_str)\n\n    def reset(self):\n        for meter in self.meters.values():\n            meter.reset()\n"
  },
  {
    "path": "xmuda/common/utils/sampler.py",
    "content": "from torch.utils.data.sampler import Sampler\n\n\nclass IterationBasedBatchSampler(Sampler):\n    \"\"\"\n    Wraps a BatchSampler, resampling from it until a specified number of iterations have been sampled\n\n    References:\n        https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/data/samplers/iteration_based_batch_sampler.py\n    \"\"\"\n\n    def __init__(self, batch_sampler, num_iterations, start_iter=0):\n        self.batch_sampler = batch_sampler\n        self.num_iterations = num_iterations\n        self.start_iter = start_iter\n\n    def __iter__(self):\n        iteration = self.start_iter\n        while iteration < self.num_iterations:\n            # if the underlying sampler has a set_epoch method, like\n            # DistributedSampler, used for making each process see\n            # a different split of the dataset, then set it\n            if hasattr(self.batch_sampler.sampler, \"set_epoch\"):\n                self.batch_sampler.sampler.set_epoch(iteration)\n            for batch in self.batch_sampler:\n                yield batch\n                iteration += 1\n                if iteration >= self.num_iterations:\n                    break\n\n    def __len__(self):\n        return self.num_iterations - self.start_iter\n\n\ndef test_IterationBasedBatchSampler():\n    from torch.utils.data.sampler import SequentialSampler, RandomSampler, BatchSampler\n    sampler = RandomSampler([i for i in range(9)])\n    batch_sampler = BatchSampler(sampler, batch_size=2, drop_last=True)\n    batch_sampler = IterationBasedBatchSampler(batch_sampler, 6, start_iter=0)\n\n    # check __len__\n    # assert len(batch_sampler) == 5\n    for i, index in enumerate(batch_sampler):\n        print(i, index)\n        # assert [i * 2, i * 2 + 1] == index\n\n    # # check start iter\n    # batch_sampler.start_iter = 2\n    # assert len(batch_sampler) == 3\n\n\nif __name__ == '__main__':\n    test_IterationBasedBatchSampler()\n"
  },
  {
    "path": "xmuda/common/utils/torch_util.py",
    "content": "import random\nimport numpy as np\nimport torch\n\n\ndef set_random_seed(seed):\n    if seed < 0:\n        return\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    # torch.cuda.manual_seed_all(seed)\n\n\ndef worker_init_fn(worker_id):\n    \"\"\"The function is designed for pytorch multi-process dataloader.\n    Note that we use the pytorch random generator to generate a base_seed.\n    Please try to be consistent.\n\n    References:\n        https://pytorch.org/docs/stable/notes/faq.html#dataloader-workers-random-seed\n\n    \"\"\"\n    base_seed = torch.IntTensor(1).random_().item()\n    # print(worker_id, base_seed)\n    np.random.seed(base_seed + worker_id)\n"
  },
  {
    "path": "xmuda/config/xmuda.py",
    "content": "\"\"\"xMUDA experiments configuration\"\"\"\nimport os.path as osp\n\nfrom xmuda.common.config.base import CN, _C\n\n# public alias\ncfg = _C\n_C.VAL.METRIC = 'seg_iou'\n\n# ---------------------------------------------------------------------------- #\n# Specific train options\n# ---------------------------------------------------------------------------- #\n_C.TRAIN.CLASS_WEIGHTS = []\n\n# ---------------------------------------------------------------------------- #\n# xMUDA options\n# ---------------------------------------------------------------------------- #\n_C.TRAIN.XMUDA = CN()\n_C.TRAIN.XMUDA.lambda_xm_src = 0.0\n_C.TRAIN.XMUDA.lambda_xm_trg = 0.0\n_C.TRAIN.XMUDA.lambda_pl = 0.0\n_C.TRAIN.XMUDA.lambda_minent = 0.0\n_C.TRAIN.XMUDA.lambda_logcoral = 0.0\n\n# ---------------------------------------------------------------------------- #\n# Datasets\n# ---------------------------------------------------------------------------- #\n_C.DATASET_SOURCE = CN()\n_C.DATASET_SOURCE.TYPE = ''\n_C.DATASET_SOURCE.TRAIN = tuple()\n\n_C.DATASET_TARGET = CN()\n_C.DATASET_TARGET.TYPE = ''\n_C.DATASET_TARGET.TRAIN = tuple()\n_C.DATASET_TARGET.VAL = tuple()\n_C.DATASET_TARGET.TEST = tuple()\n\n# NuScenesSCN\n_C.DATASET_SOURCE.NuScenesSCN = CN()\n_C.DATASET_SOURCE.NuScenesSCN.preprocess_dir = ''\n_C.DATASET_SOURCE.NuScenesSCN.nuscenes_dir = ''\n_C.DATASET_SOURCE.NuScenesSCN.merge_classes = True\n# 3D\n_C.DATASET_SOURCE.NuScenesSCN.scale = 20\n_C.DATASET_SOURCE.NuScenesSCN.full_scale = 4096\n# 2D\n_C.DATASET_SOURCE.NuScenesSCN.use_image = True\n_C.DATASET_SOURCE.NuScenesSCN.resize = (400, 225)\n_C.DATASET_SOURCE.NuScenesSCN.image_normalizer = ()\n# 3D augmentation\n_C.DATASET_SOURCE.NuScenesSCN.augmentation = CN()\n_C.DATASET_SOURCE.NuScenesSCN.augmentation.noisy_rot = 0.1\n_C.DATASET_SOURCE.NuScenesSCN.augmentation.flip_x = 0.5\n_C.DATASET_SOURCE.NuScenesSCN.augmentation.rot_z = 6.2831  # 2 * pi\n_C.DATASET_SOURCE.NuScenesSCN.augmentation.transl = True\n# 2D augmentation\n_C.DATASET_SOURCE.NuScenesSCN.augmentation.fliplr = 0.5\n_C.DATASET_SOURCE.NuScenesSCN.augmentation.color_jitter = (0.4, 0.4, 0.4)\n# copy over the same arguments to target dataset settings\n_C.DATASET_TARGET.NuScenesSCN = CN(_C.DATASET_SOURCE.NuScenesSCN)\n_C.DATASET_TARGET.NuScenesSCN.pselab_paths = tuple()\n\n# A2D2SCN\n_C.DATASET_SOURCE.A2D2SCN = CN()\n_C.DATASET_SOURCE.A2D2SCN.preprocess_dir = ''\n_C.DATASET_SOURCE.A2D2SCN.merge_classes = True\n# 3D\n_C.DATASET_SOURCE.A2D2SCN.scale = 20\n_C.DATASET_SOURCE.A2D2SCN.full_scale = 4096\n# 2D\n_C.DATASET_SOURCE.A2D2SCN.use_image = True\n_C.DATASET_SOURCE.A2D2SCN.resize = (480, 302)\n_C.DATASET_SOURCE.A2D2SCN.image_normalizer = ()\n# 3D augmentation\n_C.DATASET_SOURCE.A2D2SCN.augmentation = CN()\n_C.DATASET_SOURCE.A2D2SCN.augmentation.noisy_rot = 0.1\n_C.DATASET_SOURCE.A2D2SCN.augmentation.flip_y = 0.5\n_C.DATASET_SOURCE.A2D2SCN.augmentation.rot_z = 6.2831  # 2 * pi\n_C.DATASET_SOURCE.A2D2SCN.augmentation.transl = True\n# 2D augmentation\n_C.DATASET_SOURCE.A2D2SCN.augmentation.fliplr = 0.5\n_C.DATASET_SOURCE.A2D2SCN.augmentation.color_jitter = (0.4, 0.4, 0.4)\n\n# SemanticKITTISCN\n_C.DATASET_SOURCE.SemanticKITTISCN = CN()\n_C.DATASET_SOURCE.SemanticKITTISCN.preprocess_dir = ''\n_C.DATASET_SOURCE.SemanticKITTISCN.semantic_kitti_dir = ''\n_C.DATASET_SOURCE.SemanticKITTISCN.merge_classes = True\n# 3D\n_C.DATASET_SOURCE.SemanticKITTISCN.scale = 20\n_C.DATASET_SOURCE.SemanticKITTISCN.full_scale = 4096\n# 2D\n_C.DATASET_SOURCE.SemanticKITTISCN.image_normalizer = ()\n# 3D 
augmentation\n_C.DATASET_SOURCE.SemanticKITTISCN.augmentation = CN()\n_C.DATASET_SOURCE.SemanticKITTISCN.augmentation.noisy_rot = 0.1\n_C.DATASET_SOURCE.SemanticKITTISCN.augmentation.flip_y = 0.5\n_C.DATASET_SOURCE.SemanticKITTISCN.augmentation.rot_z = 6.2831  # 2 * pi\n_C.DATASET_SOURCE.SemanticKITTISCN.augmentation.transl = True\n# 2D augmentation\n_C.DATASET_SOURCE.SemanticKITTISCN.augmentation.bottom_crop = (480, 302)\n_C.DATASET_SOURCE.SemanticKITTISCN.augmentation.fliplr = 0.5\n_C.DATASET_SOURCE.SemanticKITTISCN.augmentation.color_jitter = (0.4, 0.4, 0.4)\n# copy over the same arguments to target dataset settings\n_C.DATASET_TARGET.SemanticKITTISCN = CN(_C.DATASET_SOURCE.SemanticKITTISCN)\n_C.DATASET_TARGET.SemanticKITTISCN.pselab_paths = tuple()\n\n# ---------------------------------------------------------------------------- #\n# Model 2D\n# ---------------------------------------------------------------------------- #\n_C.MODEL_2D = CN()\n_C.MODEL_2D.TYPE = ''\n_C.MODEL_2D.CKPT_PATH = ''\n_C.MODEL_2D.NUM_CLASSES = 5\n_C.MODEL_2D.DUAL_HEAD = False\n# ---------------------------------------------------------------------------- #\n# UNetResNet34 options\n# ---------------------------------------------------------------------------- #\n_C.MODEL_2D.UNetResNet34 = CN()\n_C.MODEL_2D.UNetResNet34.pretrained = True\n\n# ---------------------------------------------------------------------------- #\n# Model 3D\n# ---------------------------------------------------------------------------- #\n_C.MODEL_3D = CN()\n_C.MODEL_3D.TYPE = ''\n_C.MODEL_3D.CKPT_PATH = ''\n_C.MODEL_3D.NUM_CLASSES = 5\n_C.MODEL_3D.DUAL_HEAD = False\n# ----------------------------------------------------------------------------- #\n# SCN options\n# ----------------------------------------------------------------------------- #\n_C.MODEL_3D.SCN = CN()\n_C.MODEL_3D.SCN.in_channels = 1\n_C.MODEL_3D.SCN.m = 16  # number of unet features (multiplied in each layer)\n_C.MODEL_3D.SCN.block_reps = 1  # block repetitions\n_C.MODEL_3D.SCN.residual_blocks = False  # ResNet style basic blocks\n_C.MODEL_3D.SCN.full_scale = 4096\n_C.MODEL_3D.SCN.num_planes = 7\n\n# ---------------------------------------------------------------------------- #\n# Misc options\n# ---------------------------------------------------------------------------- #\n# @ will be replaced by config path\n_C.OUTPUT_DIR = osp.expanduser('~/workspace/outputs/xmuda/@')"
  },
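  {
    "path": "examples/config_usage_sketch.py",
    "content": "\"\"\"Illustrative usage sketch, not part of the original xMUDA code: how the defaults in xmuda/config/xmuda.py are typically consumed. cfg is a yacs-style CfgNode, so merge_from_file, merge_from_list and freeze are the standard yacs API; the YAML path below is hypothetical, and the '@' substitution mirrors the comment in xmuda.py (the actual substitution happens elsewhere in the repo).\"\"\"\nimport os.path as osp\n\nfrom xmuda.config.xmuda import cfg\n\nconfig_file = 'configs/nuscenes/usa_singapore/xmuda.yaml'  # hypothetical experiment config\ncfg.merge_from_file(config_file)  # override the defaults with the experiment YAML\ncfg.merge_from_list(['TRAIN.XMUDA.lambda_xm_src', 1.0])  # CLI-style overrides\ncfg.freeze()  # make the config immutable for the rest of the run\n\n# '@' in OUTPUT_DIR is replaced by the config path\noutput_dir = cfg.OUTPUT_DIR.replace('@', osp.splitext(config_file)[0])\nprint(output_dir)\n"
  },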
  {
    "path": "xmuda/data/a2d2/a2d2_dataloader.py",
    "content": "import os.path as osp\nimport pickle\nfrom PIL import Image\nimport numpy as np\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms as T\nimport json\n\nfrom xmuda.data.utils.augmentation_3d import augment_and_scale_3d\n\n\nclass A2D2Base(Dataset):\n    \"\"\"A2D2 dataset\"\"\"\n\n    class_names = [\n        'Car 1',\n        'Car 2',\n        'Car 3',\n        'Car 4',\n        'Bicycle 1',\n        'Bicycle 2',\n        'Bicycle 3',\n        'Bicycle 4',\n        'Pedestrian 1',\n        'Pedestrian 2',\n        'Pedestrian 3',\n        'Truck 1',\n        'Truck 2',\n        'Truck 3',\n        'Small vehicles 1',\n        'Small vehicles 2',\n        'Small vehicles 3',\n        'Traffic signal 1',\n        'Traffic signal 2',\n        'Traffic signal 3',\n        'Traffic sign 1',\n        'Traffic sign 2',\n        'Traffic sign 3',\n        'Utility vehicle 1',\n        'Utility vehicle 2',\n        'Sidebars',\n        'Speed bumper',\n        'Curbstone',\n        'Solid line',\n        'Irrelevant signs',\n        'Road blocks',\n        'Tractor',\n        'Non-drivable street',\n        'Zebra crossing',\n        'Obstacles / trash',\n        'Poles',\n        'RD restricted area',\n        'Animals',\n        'Grid structure',\n        'Signal corpus',\n        'Drivable cobblestone',\n        'Electronic traffic',\n        'Slow drive area',\n        'Nature object',\n        'Parking area',\n        'Sidewalk',\n        'Ego car',\n        'Painted driv. instr.',\n        'Traffic guide obj.',\n        'Dashed line',\n        'RD normal street',\n        'Sky',\n        'Buildings',\n        'Blurred area',\n        'Rain dirt'\n    ]\n\n    # use those categories if merge_classes == True\n    categories = {\n        'car': ['Car 1', 'Car 2', 'Car 3', 'Car 4', 'Ego car'],\n        'truck': ['Truck 1', 'Truck 2', 'Truck 3'],\n        'bike': ['Bicycle 1', 'Bicycle 2', 'Bicycle 3', 'Bicycle 4', 'Small vehicles 1', 'Small vehicles 2',\n                 'Small vehicles 3'],  # small vehicles are \"usually\" motorcycles\n        'person': ['Pedestrian 1', 'Pedestrian 2', 'Pedestrian 3'],\n        'road': ['RD normal street', 'Zebra crossing', 'Solid line', 'RD restricted area', 'Slow drive area',\n                 'Drivable cobblestone', 'Dashed line', 'Painted driv. 
instr.'],\n        'parking': ['Parking area'],\n        'sidewalk': ['Sidewalk', 'Curbstone'],\n        'building': ['Buildings'],\n        'nature': ['Nature object'],\n        'other-objects': ['Poles', 'Traffic signal 1', 'Traffic signal 2', 'Traffic signal 3', 'Traffic sign 1',\n                          'Traffic sign 2', 'Traffic sign 3', 'Sidebars', 'Speed bumper', 'Irrelevant signs',\n                          'Road blocks', 'Obstacles / trash', 'Animals', 'Signal corpus', 'Electronic traffic',\n                          'Traffic guide obj.', 'Grid structure'],\n        # 'ignore': ['Sky', 'Utility vehicle 1', 'Utility vehicle 2', 'Tractor', 'Non-drivable street',\n        #            'Blurred area', 'Rain dirt'],\n    }\n\n    def __init__(self,\n                 split,\n                 preprocess_dir,\n                 merge_classes=False\n                 ):\n\n        self.split = split\n        self.preprocess_dir = preprocess_dir\n\n        print(\"Initialize A2D2 dataloader\")\n\n        with open(osp.join(self.preprocess_dir, 'cams_lidars.json'), 'r') as f:\n            self.config = json.load(f)\n\n        assert isinstance(split, tuple)\n        print('Load', split)\n        self.data = []\n        for curr_split in split:\n            with open(osp.join(self.preprocess_dir, 'preprocess', curr_split + '.pkl'), 'rb') as f:\n                self.data.extend(pickle.load(f))\n\n        with open(osp.join(self.preprocess_dir, 'class_list.json'), 'r') as f:\n            class_list = json.load(f)\n            self.rgb_to_class = {}\n            self.rgb_to_cls_idx = {}\n            count = 0\n            for k, v in class_list.items():\n                # hex to rgb\n                rgb_value = tuple(int(k.lstrip('#')[i:i + 2], 16) for i in (0, 2, 4))\n                self.rgb_to_class[rgb_value] = v\n                self.rgb_to_cls_idx[rgb_value] = count\n                count += 1\n\n        assert self.class_names == list(self.rgb_to_class.values())\n        if merge_classes:\n            self.label_mapping = -100 * np.ones(len(self.rgb_to_class) + 1, dtype=int)\n            for cat_idx, cat_list in enumerate(self.categories.values()):\n                for class_name in cat_list:\n                    self.label_mapping[self.class_names.index(class_name)] = cat_idx\n            self.class_names = list(self.categories.keys())\n        else:\n            self.label_mapping = None\n\n    def __getitem__(self, index):\n        raise NotImplementedError\n\n    def __len__(self):\n        return len(self.data)\n\n\nclass A2D2SCN(A2D2Base):\n    def __init__(self,\n                 split,\n                 preprocess_dir,\n                 merge_classes=False,\n                 scale=20,\n                 full_scale=4096,\n                 use_image=False,\n                 resize=(480, 302),\n                 image_normalizer=None,\n                 noisy_rot=0.0,  # 3D augmentation\n                 flip_y=0.0,  # 3D augmentation\n                 rot_z=0.0,  # 3D augmentation\n                 transl=False,  # 3D augmentation\n                 fliplr=0.0,  # 2D augmentation\n                 color_jitter=None,  # 2D augmentation\n                 ):\n        super().__init__(split,\n                         preprocess_dir,\n                         merge_classes=merge_classes)\n\n        # point cloud parameters\n        self.scale = scale\n        self.full_scale = full_scale\n        # 3D augmentation\n        self.noisy_rot = noisy_rot\n        self.flip_y = flip_y\n        
self.rot_z = rot_z\n        self.transl = transl\n\n        # image parameters\n        self.use_image = use_image\n        if self.use_image:\n            self.resize = resize\n            self.image_normalizer = image_normalizer\n\n            # data augmentation\n            self.fliplr = fliplr\n            self.color_jitter = T.ColorJitter(*color_jitter) if color_jitter else None\n\n    def __getitem__(self, index):\n        data_dict = self.data[index]\n\n        points = data_dict['points'].copy()\n        seg_label = data_dict['seg_labels'].astype(np.int64)\n\n        if self.label_mapping is not None:\n            seg_label = self.label_mapping[seg_label]\n\n        out_dict = {}\n\n        if self.use_image:\n            points_img = data_dict['points_img'].copy()\n            img_path = osp.join(self.preprocess_dir, data_dict['camera_path'])\n            image = Image.open(img_path)\n\n            if self.resize:\n                if image.size != self.resize:\n                    # check that we do not enlarge downsized images\n                    assert image.size[0] > self.resize[0]\n\n                    # scale image points\n                    points_img[:, 0] = float(self.resize[1]) / image.size[1] * np.floor(points_img[:, 0])\n                    points_img[:, 1] = float(self.resize[0]) / image.size[0] * np.floor(points_img[:, 1])\n\n                    # resize image\n                    image = image.resize(self.resize, Image.BILINEAR)\n\n            img_indices = points_img.astype(np.int64)\n\n            assert np.all(img_indices[:, 0] >= 0)\n            assert np.all(img_indices[:, 1] >= 0)\n            assert np.all(img_indices[:, 0] < image.size[1])\n            assert np.all(img_indices[:, 1] < image.size[0])\n\n            # 2D augmentation\n            if self.color_jitter is not None:\n                image = self.color_jitter(image)\n            # PIL to numpy\n            image = np.array(image, dtype=np.float32, copy=False) / 255.\n            # 2D augmentation\n            if np.random.rand() < self.fliplr:\n                image = np.ascontiguousarray(np.fliplr(image))\n                img_indices[:, 1] = image.shape[1] - 1 - img_indices[:, 1]\n\n            # normalize image\n            if self.image_normalizer:\n                mean, std = self.image_normalizer\n                mean = np.asarray(mean, dtype=np.float32)\n                std = np.asarray(std, dtype=np.float32)\n                image = (image - mean) / std\n\n            out_dict['img'] = np.moveaxis(image, -1, 0)\n            out_dict['img_indices'] = img_indices\n\n        # 3D data augmentation and scaling from points to voxel indices\n        # A2D2 lidar coordinates (same as Kitti): x (front), y (left), z (up)\n        coords = augment_and_scale_3d(points, self.scale, self.full_scale, noisy_rot=self.noisy_rot,\n                                      flip_y=self.flip_y, rot_z=self.rot_z, transl=self.transl)\n\n        # cast to integer\n        coords = coords.astype(np.int64)\n\n        # only use voxels inside receptive field\n        idxs = (coords.min(1) >= 0) * (coords.max(1) < self.full_scale)\n\n        out_dict['coords'] = coords[idxs]\n        out_dict['feats'] = np.ones([idxs.sum(), 1], np.float32)  # simply use 1 as feature for each kept voxel\n        out_dict['seg_label'] = seg_label[idxs]\n\n        if self.use_image:\n            out_dict['img_indices'] = out_dict['img_indices'][idxs]\n\n        return out_dict\n\n\ndef test_A2D2SCN():\n    from xmuda.data.utils.visualize import draw_points_image_labels, draw_bird_eye_view\n    preprocess_dir = '/datasets_local/datasets_mjaritz/a2d2_preprocess'\n    split = ('test',)\n    dataset = A2D2SCN(split=split,\n                      preprocess_dir=preprocess_dir,\n                      merge_classes=True,\n                      use_image=True,\n                      noisy_rot=0.1,\n                      flip_y=0.5,\n                      rot_z=2*np.pi,\n                      transl=True,\n                      fliplr=0.5,\n                      color_jitter=(0.4, 0.4, 0.4)\n                      )\n    for i in [10, 20, 30, 40, 50, 60]:\n        data = dataset[i]\n        coords = data['coords']\n        seg_label = data['seg_label']\n        img = np.moveaxis(data['img'], 0, 2)\n        img_indices = data['img_indices']\n        draw_points_image_labels(img, img_indices, seg_label, color_palette_type='SemanticKITTI', point_size=3)\n        draw_bird_eye_view(coords)\n\n\ndef compute_class_weights():\n    preprocess_dir = '/datasets_local/datasets_mjaritz/a2d2_preprocess'\n    split = ('train', 'test')\n    dataset = A2D2Base(split,\n                       preprocess_dir,\n                       merge_classes=True\n                       )\n    # compute points per class over whole dataset\n    num_classes = len(dataset.class_names)\n    points_per_class = np.zeros(num_classes, int)\n    for i, data in enumerate(dataset.data):\n        print('{}/{}'.format(i, len(dataset)))\n        labels = dataset.label_mapping[data['seg_labels']]\n        points_per_class += np.bincount(labels[labels != -100], minlength=num_classes)\n\n    # compute log smoothed class weights\n    class_weights = np.log(5 * points_per_class.sum() / points_per_class)\n    print('log smoothed class weights: ', class_weights / class_weights.min())\n\n\nif __name__ == '__main__':\n    test_A2D2SCN()\n    # compute_class_weights()\n"
  },
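  {
    "path": "examples/a2d2_label_mapping_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original xMUDA code: the label remapping used by A2D2Base in xmuda/data/a2d2/a2d2_dataloader.py. A2D2 stores classes as hex colors, which the loader converts to RGB tuples with running indices; with merge_classes, a lookup array maps raw indices to merged categories, with -100 marking ignored classes. The two-entry class list below is a made-up stand-in for class_list.json.\"\"\"\nimport numpy as np\n\nclass_list = {'#ff0000': 'Car 1', '#b97a57': 'Sidewalk'}  # hypothetical subset\nrgb_to_cls_idx = {}\nfor count, k in enumerate(class_list):\n    rgb = tuple(int(k.lstrip('#')[i:i + 2], 16) for i in (0, 2, 4))  # hex -> (r, g, b)\n    rgb_to_cls_idx[rgb] = count\n\n# merge raw classes into categories, as A2D2Base does with its categories dict\ncategories = {'car': ['Car 1'], 'sidewalk': ['Sidewalk']}\nclass_names = list(class_list.values())\nlabel_mapping = -100 * np.ones(len(class_names) + 1, dtype=int)  # +1 slot for unmapped colors\nfor cat_idx, cat_list in enumerate(categories.values()):\n    for name in cat_list:\n        label_mapping[class_names.index(name)] = cat_idx\n\nseg_label = np.array([0, 1, 2])  # raw indices; 2 = a color that is not in class_list\nprint(label_mapping[seg_label])  # -> [0 1 -100]; -100 is ignored by the loss\n"
  },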
  {
    "path": "xmuda/data/a2d2/preprocess.py",
    "content": "import os\nimport os.path as osp\nimport shutil\nimport numpy as np\nimport pickle\nimport json\nfrom PIL import Image\nimport cv2\nimport glob\nimport torch\nfrom torch.utils.data import Dataset\nfrom torch.utils.data.dataloader import DataLoader\n\n\nfrom xmuda.data.a2d2 import splits\n\nfrom xmuda.data.a2d2.a2d2_dataloader import A2D2Base\n\n# prevent \"RuntimeError: received 0 items of ancdata\"\ntorch.multiprocessing.set_sharing_strategy('file_system')\n\n\nclass_names_to_id = dict(zip(A2D2Base.class_names, range(len(A2D2Base.class_names))))\n\n\ndef undistort_image(config, image, cam_name):\n    \"\"\"copied from https://www.a2d2.audi/a2d2/en/tutorial.html\"\"\"\n    if cam_name in ['front_left', 'front_center',\n                    'front_right', 'side_left',\n                    'side_right', 'rear_center']:\n        # get parameters from config file\n        intr_mat_undist = np.asarray(config['cameras'][cam_name]['CamMatrix'])\n        intr_mat_dist = np.asarray(config['cameras'][cam_name]['CamMatrixOriginal'])\n        dist_parms = np.asarray(config['cameras'][cam_name]['Distortion'])\n        lens = config['cameras'][cam_name]['Lens']\n\n        if lens == 'Fisheye':\n            return cv2.fisheye.undistortImage(image, intr_mat_dist, D=dist_parms, Knew=intr_mat_undist)\n        elif lens == 'Telecam':\n            return cv2.undistort(image, intr_mat_dist, distCoeffs=dist_parms, newCameraMatrix=intr_mat_undist)\n        else:\n            return image\n    else:\n        return image\n\n\nclass DummyDataset(Dataset):\n    \"\"\"Use torch dataloader for multiprocessing\"\"\"\n    def __init__(self, root_dir, scenes):\n        self.class_names = A2D2Base.class_names.copy()\n        self.categories = A2D2Base.categories.copy()\n        self.root_dir = root_dir\n        self.data = []\n        self.glob_frames(scenes)\n\n        # load config\n        with open(osp.join(root_dir, 'cams_lidars.json'), 'r') as f:\n            self.config = json.load(f)\n\n        # load color to class mapping\n        with open(osp.join(root_dir, 'class_list.json'), 'r') as f:\n            class_list = json.load(f)\n            self.rgb_to_class = {}\n            self.rgb_to_cls_idx = {}\n            count = 0\n            for k, v in class_list.items():\n                # hex to rgb\n                rgb_value = tuple(int(k.lstrip('#')[i:i + 2], 16) for i in (0, 2, 4))\n                self.rgb_to_class[rgb_value] = v\n                self.rgb_to_cls_idx[rgb_value] = count\n                count += 1\n\n        assert list(class_names_to_id.keys()) == list(self.rgb_to_class.values())\n\n    def glob_frames(self, scenes):\n        for scene in scenes:\n            cam_paths = sorted(glob.glob(osp.join(self.root_dir, scene, 'camera', 'cam_front_center', '*.png')))\n            for cam_path in cam_paths:\n                basename = osp.basename(cam_path)\n                datetime = basename[:14]\n                assert datetime.isdigit()\n                frame_id = basename[-13:-4]\n                assert frame_id.isdigit()\n                data = {\n                    'camera_path': cam_path,\n                    'lidar_path': osp.join(self.root_dir, scene, 'lidar', 'cam_front_center',\n                                           datetime + '_lidar_frontcenter_' + frame_id + '.npz'),\n                    'label_path': osp.join(self.root_dir, scene, 'label', 'cam_front_center',\n                                           datetime + '_label_frontcenter_' + frame_id + '.png'),\n           
     }\n                for k, v in data.items():\n                    if not osp.exists(v):\n                        raise IOError('File not found {}'.format(v))\n                self.data.append(data)\n\n    def __getitem__(self, index):\n        data_dict = self.data[index].copy()\n        lidar_front_center = np.load(data_dict['lidar_path'])\n        points = lidar_front_center['points']\n        if 'row' not in lidar_front_center.keys():\n            print('row not in lidar dict, return empty dict, {}'.format(data_dict['lidar_path']))\n            return {}\n        rows = lidar_front_center['row'].astype(np.int64)\n        cols = lidar_front_center['col'].astype(np.int64)\n\n        # extract 3D labels from 2D\n        label_img = np.array(Image.open(data_dict['label_path']))\n        label_img = undistort_image(self.config, label_img, 'front_center')\n        label_pc = label_img[rows, cols, :]\n        seg_label = np.full(label_pc.shape[0], fill_value=len(self.rgb_to_cls_idx), dtype=np.int64)\n        # map RGB label code to index\n        for rgb_values, cls_idx in self.rgb_to_cls_idx.items():\n            idx = (rgb_values == label_pc).all(1)\n            if idx.any():\n                seg_label[idx] = cls_idx\n\n        # load image\n        image = Image.open(data_dict['camera_path'])\n        image_size = image.size\n        assert image_size == (1920, 1208)\n        # undistort\n        image = undistort_image(self.config, np.array(image), 'front_center')\n        # scale image points\n        points_img = np.stack([lidar_front_center['row'], lidar_front_center['col']], 1).astype(np.float32)\n        # check if conversion from float64 to float32 has led to image points outside of image\n        assert np.all(points_img[:, 0] < image_size[1])\n        assert np.all(points_img[:, 1] < image_size[0])\n\n        data_dict['seg_label'] = seg_label.astype(np.uint8)\n        data_dict['points'] = points.astype(np.float32)\n        data_dict['points_img'] = points_img  # row, col format, shape: (num_points, 2)\n        data_dict['img'] = image\n\n        return data_dict\n\n    def __len__(self):\n        return len(self.data)\n\n\ndef preprocess(split_name, root_dir, out_dir):\n    pkl_data = []\n    split = getattr(splits, split_name)\n\n    dataloader = DataLoader(DummyDataset(root_dir, split), num_workers=8)\n\n    num_skips = 0\n    for i, data_dict in enumerate(dataloader):\n        # data error leads to returning empty dict\n        if not data_dict:\n            print('empty dict, continue')\n            num_skips += 1\n            continue\n        for k, v in data_dict.items():\n            data_dict[k] = v[0]\n        print('{}/{} {}'.format(i, len(dataloader), data_dict['lidar_path']))\n\n        # convert to relative path\n        lidar_path = data_dict['lidar_path'].replace(root_dir + '/', '')\n        cam_path = data_dict['camera_path'].replace(root_dir + '/', '')\n\n        # save undistorted image\n        new_cam_path = osp.join(out_dir, cam_path)\n        os.makedirs(osp.dirname(new_cam_path), exist_ok=True)\n        image = Image.fromarray(data_dict['img'].numpy())\n        image.save(new_cam_path)\n\n        # append data\n        out_dict = {\n            'points': data_dict['points'].numpy(),\n            'seg_labels': data_dict['seg_label'].numpy(),\n            'points_img': data_dict['points_img'].numpy(),  # row, col format, shape: (num_points, 2)\n            'lidar_path': lidar_path,\n            'camera_path': cam_path,\n        }\n        pkl_data.append(out_dict)\n\n    print('Skipped {} files'.format(num_skips))\n\n    # save to pickle file\n    save_dir = osp.join(out_dir, 'preprocess')\n    os.makedirs(save_dir, exist_ok=True)\n    save_path = osp.join(save_dir, '{}.pkl'.format(split_name))\n    with open(save_path, 'wb') as f:\n        pickle.dump(pkl_data, f)\n        print('Wrote preprocessed data to ' + save_path)\n\n\nif __name__ == '__main__':\n    root_dir = '/datasets_master/a2d2'\n    out_dir = '/datasets_local/datasets_mjaritz/a2d2_preprocess'\n    preprocess('test', root_dir, out_dir)\n    # split into train1 and train2 to prevent segmentation fault in torch dataloader\n    preprocess('train1', root_dir, out_dir)\n    preprocess('train2', root_dir, out_dir)\n    # merge train1 and train2\n    data = []\n    for curr_split in ['train1', 'train2']:\n        with open(osp.join(out_dir, 'preprocess', curr_split + '.pkl'), 'rb') as f:\n            data.extend(pickle.load(f))\n    save_path = osp.join(out_dir, 'preprocess', 'train.pkl')\n    with open(save_path, 'wb') as f:\n        pickle.dump(data, f)\n        print('Wrote preprocessed data to ' + save_path)\n    for curr_split in ['train1', 'train2']:\n        os.remove(osp.join(out_dir, 'preprocess', curr_split + '.pkl'))\n\n    # copy cams_lidars.json and class_list.json to out_dir\n    for filename in ['cams_lidars.json', 'class_list.json']:\n        shutil.copyfile(osp.join(root_dir, filename), osp.join(out_dir, filename))"
  },
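  {
    "path": "examples/a2d2_preprocess_output_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original xMUDA code: reading one frame back from the pickle files written by xmuda/data/a2d2/preprocess.py. The preprocess_dir below is hypothetical; the dict keys match the out_dict assembled in preprocess().\"\"\"\nimport os.path as osp\nimport pickle\n\npreprocess_dir = '/path/to/a2d2_preprocess'  # hypothetical\nwith open(osp.join(preprocess_dir, 'preprocess', 'train.pkl'), 'rb') as f:\n    data = pickle.load(f)\n\nframe = data[0]\nprint(frame['points'].shape)      # (num_points, 3) float32 lidar points\nprint(frame['seg_labels'].shape)  # (num_points,) uint8 class indices\nprint(frame['points_img'].shape)  # (num_points, 2) row, col image coordinates\nprint(frame['camera_path'])       # relative path to the undistorted camera image\n"
  },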
  {
    "path": "xmuda/data/a2d2/splits.py",
    "content": "train = [\n    '20180810_142822',\n    '20180925_101535',\n    '20180925_112730',\n    '20180925_124435',\n    '20180925_135056',\n    '20181008_095521',\n    '20181016_082154',\n    '20181016_125231',\n    '20181107_132300',\n    '20181107_132730',\n    '20181107_133258',\n    '20181107_133445',\n    '20181108_084007',\n    '20181108_091945',\n    '20181108_103155',\n    '20181108_123750',\n    '20181108_141609',\n    '20181204_135952',\n    '20181204_154421',\n    '20181204_170238',\n]\n\n\ntrain1 = [\n    '20180810_142822',\n    '20180925_101535',\n    '20180925_112730',\n    '20180925_124435',\n    '20180925_135056',\n    '20181008_095521',\n    '20181016_082154',\n    '20181016_125231',\n    '20181107_132300',\n    '20181107_132730',\n]\ntrain2 = [\n    '20181107_133258',\n    '20181107_133445',\n    '20181108_084007',\n    '20181108_091945',\n    '20181108_103155',\n    '20181108_123750',\n    '20181108_141609',\n    '20181204_135952',\n    '20181204_154421',\n    '20181204_170238',\n]\n\ntest = [\n    '20180807_145028'\n]\n\nall = [\n    '20180807_145028',\n    '20180810_142822',\n    '20180925_101535',\n    '20180925_112730',\n    '20180925_124435',\n    '20180925_135056',\n    '20181008_095521',\n    '20181016_082154',\n    # '20181016_095036',  # no lidar\n    '20181016_125231',\n    '20181107_132300',\n    '20181107_132730',\n    '20181107_133258',\n    '20181107_133445',\n    '20181108_084007',\n    '20181108_091945',\n    '20181108_103155',\n    '20181108_123750',\n    '20181108_141609',\n    '20181204_135952',\n    '20181204_154421',\n    '20181204_170238',\n    # '20181204_191844',  # no lidar\n]\n"
  },
  {
    "path": "xmuda/data/build.py",
    "content": "from torch.utils.data.sampler import RandomSampler, BatchSampler\nfrom torch.utils.data.dataloader import DataLoader, default_collate\nfrom yacs.config import CfgNode as CN\n\nfrom xmuda.common.utils.torch_util import worker_init_fn\nfrom xmuda.data.collate import get_collate_scn\nfrom xmuda.common.utils.sampler import IterationBasedBatchSampler\nfrom xmuda.data.nuscenes.nuscenes_dataloader import NuScenesSCN\nfrom xmuda.data.a2d2.a2d2_dataloader import A2D2SCN\nfrom xmuda.data.semantic_kitti.semantic_kitti_dataloader import SemanticKITTISCN\n\n\ndef build_dataloader(cfg, mode='train', domain='source', start_iteration=0, halve_batch_size=False):\n    assert mode in ['train', 'val', 'test', 'train_labeled', 'train_unlabeled']\n    dataset_cfg = cfg.get('DATASET_' + domain.upper())\n    split = dataset_cfg[mode.upper()]\n    is_train = 'train' in mode\n    batch_size = cfg['TRAIN'].BATCH_SIZE if is_train else cfg['VAL'].BATCH_SIZE\n    if halve_batch_size:\n        batch_size = batch_size // 2\n\n    # build dataset\n    # Make a copy of dataset_kwargs so that we can pop augmentation afterwards without destroying the cfg.\n    # Note that the build_dataloader fn is called twice for train and val.\n    dataset_kwargs = CN(dataset_cfg.get(dataset_cfg.TYPE, dict()))\n    if 'SCN' in cfg.MODEL_3D.keys():\n        assert dataset_kwargs.full_scale == cfg.MODEL_3D.SCN.full_scale\n    augmentation = dataset_kwargs.pop('augmentation')\n    augmentation = augmentation if is_train else dict()\n    # use pselab_paths only when training on target\n    if domain == 'target' and not is_train:\n        dataset_kwargs.pop('pselab_paths')\n    if dataset_cfg.TYPE == 'NuScenesSCN':\n        dataset = NuScenesSCN(split=split,\n                              output_orig=not is_train,\n                              **dataset_kwargs,\n                              **augmentation)\n    elif dataset_cfg.TYPE == 'A2D2SCN':\n        dataset = A2D2SCN(split=split,\n                          **dataset_kwargs,\n                          **augmentation)\n    elif dataset_cfg.TYPE == 'SemanticKITTISCN':\n        dataset = SemanticKITTISCN(split=split,\n                                   output_orig=not is_train,\n                                   **dataset_kwargs,\n                                   **augmentation)\n    else:\n        raise ValueError('Unsupported type of dataset: {}.'.format(dataset_cfg.TYPE))\n\n    if 'SCN' in dataset_cfg.TYPE:\n        collate_fn = get_collate_scn(is_train)\n    else:\n        collate_fn = default_collate\n\n    if is_train:\n        sampler = RandomSampler(dataset)\n        batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=cfg.DATALOADER.DROP_LAST)\n        batch_sampler = IterationBasedBatchSampler(batch_sampler, cfg.SCHEDULER.MAX_ITERATION, start_iteration)\n        dataloader = DataLoader(\n            dataset,\n            batch_sampler=batch_sampler,\n            num_workers=cfg.DATALOADER.NUM_WORKERS,\n            worker_init_fn=worker_init_fn,\n            collate_fn=collate_fn\n        )\n    else:\n        dataloader = DataLoader(\n            dataset,\n            batch_size=batch_size,\n            drop_last=False,\n            num_workers=cfg.DATALOADER.NUM_WORKERS,\n            worker_init_fn=worker_init_fn,\n            collate_fn=collate_fn\n        )\n\n    return dataloader\n"
  },
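  {
    "path": "examples/build_dataloader_sketch.py",
    "content": "\"\"\"Illustrative usage sketch, not part of the original xMUDA code: how xmuda.data.build.build_dataloader is typically called for UDA training, once per domain. Assumes the config has been merged from an experiment YAML that sets DATASET_SOURCE/DATASET_TARGET; the YAML path is hypothetical, and the training loop is only indicated.\"\"\"\nfrom xmuda.config.xmuda import cfg\nfrom xmuda.data.build import build_dataloader\n\ncfg.merge_from_file('configs/nuscenes/usa_singapore/xmuda.yaml')  # hypothetical path\ncfg.freeze()\n\n# two training streams (labeled source, unlabeled target) plus target validation\ntrain_src = build_dataloader(cfg, mode='train', domain='source', start_iteration=0)\ntrain_trg = build_dataloader(cfg, mode='train', domain='target', start_iteration=0)\nval_trg = build_dataloader(cfg, mode='val', domain='target')\n\n# the train loaders are iteration-based (IterationBasedBatchSampler), so one\n# zipped loop runs until cfg.SCHEDULER.MAX_ITERATION is reached\nfor batch_src, batch_trg in zip(train_src, train_trg):\n    pass  # forward/backward on both domains goes here\n"
  },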
  {
    "path": "xmuda/data/collate.py",
    "content": "import torch\nfrom functools import partial\n\n\ndef collate_scn_base(input_dict_list, output_orig, output_image=True):\n    \"\"\"\n    Custom collate function for SCN. The batch size is always 1,\n    but the batch indices are appended to the locations.\n    :param input_dict_list: a list of dicts from the dataloader\n    :param output_orig: whether to output original point cloud/labels/indices\n    :param output_image: whether to output images\n    :return: Collated data batch as dict\n    \"\"\"\n    locs=[]\n    feats=[]\n    labels=[]\n\n    if output_image:\n        imgs = []\n        img_idxs = []\n\n    if output_orig:\n        orig_seg_label = []\n        orig_points_idx = []\n\n    output_pselab = 'pseudo_label_2d' in input_dict_list[0].keys()\n    if output_pselab:\n        pseudo_label_2d = []\n        pseudo_label_3d = []\n\n    for idx, input_dict in enumerate(input_dict_list):\n        coords = torch.from_numpy(input_dict['coords'])\n        batch_idxs = torch.LongTensor(coords.shape[0], 1).fill_(idx)\n        locs.append(torch.cat([coords, batch_idxs], 1))\n        feats.append(torch.from_numpy(input_dict['feats']))\n        if 'seg_label' in input_dict.keys():\n            labels.append(torch.from_numpy(input_dict['seg_label']))\n        if output_image:\n            imgs.append(torch.from_numpy(input_dict['img']))\n            img_idxs.append(input_dict['img_indices'])\n        if output_orig:\n            orig_seg_label.append(input_dict['orig_seg_label'])\n            orig_points_idx.append(input_dict['orig_points_idx'])\n        if output_pselab:\n            pseudo_label_2d.append(torch.from_numpy(input_dict['pseudo_label_2d']))\n            if input_dict['pseudo_label_3d'] is not None:\n                pseudo_label_3d.append(torch.from_numpy(input_dict['pseudo_label_3d']))\n\n    locs = torch.cat(locs, 0)\n    feats = torch.cat(feats, 0)\n    out_dict = {'x': [locs, feats]}\n    if labels:\n        labels = torch.cat(labels, 0)\n        out_dict['seg_label'] = labels\n    if output_image:\n        out_dict['img'] = torch.stack(imgs)\n        out_dict['img_indices'] = img_idxs\n    if output_orig:\n        out_dict['orig_seg_label'] = orig_seg_label\n        out_dict['orig_points_idx'] = orig_points_idx\n    if output_pselab:\n        out_dict['pseudo_label_2d'] = torch.cat(pseudo_label_2d, 0)\n        out_dict['pseudo_label_3d'] = torch.cat(pseudo_label_3d, 0) if pseudo_label_3d else pseudo_label_3d\n    return out_dict\n\n\ndef get_collate_scn(is_train):\n    return partial(collate_scn_base,\n                   output_orig=not is_train,\n                   )\n"
  },
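  {
    "path": "examples/collate_scn_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original xMUDA code: what collate_scn_base in xmuda/data/collate.py produces for a batch of two tiny point clouds. All coordinates are concatenated, and a fourth column holds the index of the sample within the batch.\"\"\"\nfrom functools import partial\n\nimport numpy as np\n\nfrom xmuda.data.collate import collate_scn_base\n\nbatch = [\n    {'coords': np.array([[0, 0, 0], [1, 2, 3]]), 'feats': np.ones((2, 1), np.float32), 'seg_label': np.array([0, 1])},\n    {'coords': np.array([[4, 5, 6]]), 'feats': np.ones((1, 1), np.float32), 'seg_label': np.array([2])},\n]\ncollate_fn = partial(collate_scn_base, output_orig=False, output_image=False)\nout = collate_fn(batch)\nlocs, feats = out['x']\nprint(locs)  # [[0, 0, 0, 0], [1, 2, 3, 0], [4, 5, 6, 1]] -> last column is the batch index\nprint(feats.shape, out['seg_label'])  # torch.Size([3, 1]) tensor([0, 1, 2])\n"
  },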
  {
    "path": "xmuda/data/nuscenes/nuscenes_dataloader.py",
    "content": "import os.path as osp\nimport pickle\nfrom PIL import Image\nimport numpy as np\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms as T\n\nfrom xmuda.data.utils.refine_pseudo_labels import refine_pseudo_labels\nfrom xmuda.data.utils.augmentation_3d import augment_and_scale_3d\n\n\nclass NuScenesBase(Dataset):\n    \"\"\"NuScenes dataset\"\"\"\n\n    class_names = [\n        \"car\",\n        \"truck\",\n        \"bus\",\n        \"trailer\",\n        \"construction_vehicle\",\n        \"pedestrian\",\n        \"motorcycle\",\n        \"bicycle\",\n        \"traffic_cone\",\n        \"barrier\",\n        \"background\",\n    ]\n\n    # use those categories if merge_classes == True\n    categories = {\n        \"vehicle\": [\"car\", \"truck\", \"bus\", \"trailer\", \"construction_vehicle\"],\n        \"pedestrian\": [\"pedestrian\"],\n        \"bike\": [\"motorcycle\", \"bicycle\"],\n        \"traffic_boundary\": [\"traffic_cone\", \"barrier\"],\n        \"background\": [\"background\"]\n    }\n\n    def __init__(self,\n                 split,\n                 preprocess_dir,\n                 merge_classes=False,\n                 pselab_paths=None\n                 ):\n\n        self.split = split\n        self.preprocess_dir = preprocess_dir\n\n        print(\"Initialize Nuscenes dataloader\")\n\n        assert isinstance(split, tuple)\n        print('Load', split)\n        self.data = []\n        for curr_split in split:\n            with open(osp.join(self.preprocess_dir, curr_split + '.pkl'), 'rb') as f:\n                self.data.extend(pickle.load(f))\n\n        self.pselab_data = None\n        if pselab_paths:\n            assert isinstance(pselab_paths, tuple)\n            print('Load pseudo label data ', pselab_paths)\n            self.pselab_data = []\n            for curr_split in pselab_paths:\n                self.pselab_data.extend(np.load(curr_split, allow_pickle=True))\n\n            # check consistency of data and pseudo labels\n            assert len(self.pselab_data) == len(self.data)\n            for i in range(len(self.pselab_data)):\n                assert len(self.pselab_data[i]['pseudo_label_2d']) == len(self.data[i]['seg_labels'])\n\n            # refine 2d pseudo labels\n            probs2d = np.concatenate([data['probs_2d'] for data in self.pselab_data])\n            pseudo_label_2d = np.concatenate([data['pseudo_label_2d'] for data in self.pselab_data]).astype(np.int)\n            pseudo_label_2d = refine_pseudo_labels(probs2d, pseudo_label_2d)\n\n            # refine 3d pseudo labels\n            # fusion model has only one final prediction saved in probs_2d\n            if 'probs_3d' in self.pselab_data[0].keys():\n                probs3d = np.concatenate([data['probs_3d'] for data in self.pselab_data])\n                pseudo_label_3d = np.concatenate([data['pseudo_label_3d'] for data in self.pselab_data]).astype(np.int)\n                pseudo_label_3d = refine_pseudo_labels(probs3d, pseudo_label_3d)\n            else:\n                pseudo_label_3d = None\n\n            # undo concat\n            left_idx = 0\n            for data_idx in range(len(self.pselab_data)):\n                right_idx = left_idx + len(self.pselab_data[data_idx]['probs_2d'])\n                self.pselab_data[data_idx]['pseudo_label_2d'] = pseudo_label_2d[left_idx:right_idx]\n                if pseudo_label_3d is not None:\n                    self.pselab_data[data_idx]['pseudo_label_3d'] = pseudo_label_3d[left_idx:right_idx]\n                
else:\n                    self.pselab_data[data_idx]['pseudo_label_3d'] = None\n                left_idx = right_idx\n\n        if merge_classes:\n            self.label_mapping = -100 * np.ones(len(self.class_names), dtype=int)\n            for cat_idx, cat_list in enumerate(self.categories.values()):\n                for class_name in cat_list:\n                    self.label_mapping[self.class_names.index(class_name)] = cat_idx\n            self.class_names = list(self.categories.keys())\n        else:\n            self.label_mapping = None\n\n    def __getitem__(self, index):\n        raise NotImplementedError\n\n    def __len__(self):\n        return len(self.data)\n\n\nclass NuScenesSCN(NuScenesBase):\n    def __init__(self,\n                 split,\n                 preprocess_dir,\n                 nuscenes_dir='',\n                 pselab_paths=None,\n                 merge_classes=False,\n                 scale=20,\n                 full_scale=4096,\n                 use_image=False,\n                 resize=(400, 225),\n                 image_normalizer=None,\n                 noisy_rot=0.0,  # 3D augmentation\n                 flip_x=0.0,  # 3D augmentation\n                 rot_z=0.0,  # 3D augmentation\n                 transl=False,  # 3D augmentation\n                 fliplr=0.0,  # 2D augmentation\n                 color_jitter=None,  # 2D augmentation\n                 output_orig=False\n                 ):\n        super().__init__(split,\n                         preprocess_dir,\n                         merge_classes=merge_classes,\n                         pselab_paths=pselab_paths)\n\n        self.nuscenes_dir = nuscenes_dir\n        self.output_orig = output_orig\n\n        # point cloud parameters\n        self.scale = scale\n        self.full_scale = full_scale\n        # 3D augmentation\n        self.noisy_rot = noisy_rot\n        self.flip_x = flip_x\n        self.rot_z = rot_z\n        self.transl = transl\n\n        # image parameters\n        self.use_image = use_image\n        if self.use_image:\n            self.resize = resize\n            self.image_normalizer = image_normalizer\n\n            # data augmentation\n            self.fliplr = fliplr\n            self.color_jitter = T.ColorJitter(*color_jitter) if color_jitter else None\n\n    def __getitem__(self, index):\n        data_dict = self.data[index]\n\n        points = data_dict['points'].copy()\n        seg_label = data_dict['seg_labels'].astype(np.int64)\n\n        if self.label_mapping is not None:\n            seg_label = self.label_mapping[seg_label]\n\n        out_dict = {}\n\n        keep_idx = np.ones(len(points), dtype=np.bool)\n        if self.use_image:\n            points_img = data_dict['points_img'].copy()\n            img_path = osp.join(self.nuscenes_dir, data_dict['camera_path'])\n            image = Image.open(img_path)\n\n            if self.resize:\n                if not image.size == self.resize:\n                    # check if we do not enlarge downsized images\n                    assert image.size[0] > self.resize[0]\n\n                    # scale image points\n                    points_img[:, 0] = float(self.resize[1]) / image.size[1] * np.floor(points_img[:, 0])\n                    points_img[:, 1] = float(self.resize[0]) / image.size[0] * np.floor(points_img[:, 1])\n\n                    # resize image\n                    image = image.resize(self.resize, Image.BILINEAR)\n\n            img_indices = points_img.astype(np.int64)\n\n            assert 
np.all(img_indices[:, 0] >= 0)\n            assert np.all(img_indices[:, 1] >= 0)\n            assert np.all(img_indices[:, 0] < image.size[1])\n            assert np.all(img_indices[:, 1] < image.size[0])\n\n            # 2D augmentation\n            if self.color_jitter is not None:\n                image = self.color_jitter(image)\n            # PIL to numpy\n            image = np.array(image, dtype=np.float32, copy=False) / 255.\n            # 2D augmentation\n            if np.random.rand() < self.fliplr:\n                image = np.ascontiguousarray(np.fliplr(image))\n                img_indices[:, 1] = image.shape[1] - 1 - img_indices[:, 1]\n\n            # normalize image\n            if self.image_normalizer:\n                mean, std = self.image_normalizer\n                mean = np.asarray(mean, dtype=np.float32)\n                std = np.asarray(std, dtype=np.float32)\n                image = (image - mean) / std\n\n            out_dict['img'] = np.moveaxis(image, -1, 0)\n            out_dict['img_indices'] = img_indices\n\n        # 3D data augmentation and scaling from points to voxel indices\n        # nuscenes lidar coordinates: x (right), y (front), z (up)\n        coords = augment_and_scale_3d(points, self.scale, self.full_scale, noisy_rot=self.noisy_rot,\n                                      flip_x=self.flip_x, rot_z=self.rot_z, transl=self.transl)\n\n        # cast to integer\n        coords = coords.astype(np.int64)\n\n        # only use voxels inside receptive field\n        idxs = (coords.min(1) >= 0) * (coords.max(1) < self.full_scale)\n\n        out_dict['coords'] = coords[idxs]\n        out_dict['feats'] = np.ones([len(idxs), 1], np.float32)  # simply use 1 as feature\n        out_dict['seg_label'] = seg_label[idxs]\n\n        if self.use_image:\n            out_dict['img_indices'] = out_dict['img_indices'][idxs]\n\n        if self.pselab_data is not None:\n            out_dict.update({\n                'pseudo_label_2d': self.pselab_data[index]['pseudo_label_2d'][keep_idx][idxs],\n                'pseudo_label_3d': self.pselab_data[index]['pseudo_label_3d'][keep_idx][idxs]\n            })\n\n        if self.output_orig:\n            out_dict.update({\n                'orig_seg_label': seg_label,\n                'orig_points_idx': idxs,\n            })\n\n        return out_dict\n\n\ndef test_NuScenesSCN():\n    from xmuda.data.utils.visualize import draw_points_image_labels, draw_points_image_depth, draw_bird_eye_view\n    preprocess_dir = '/datasets_local/datasets_mjaritz/nuscenes_preprocess/preprocess'\n    nuscenes_dir = '/datasets_local/datasets_mjaritz/nuscenes_preprocess'\n    # split = ('train_singapore',)\n    # pselab_paths = ('/home/docker_user/workspace/outputs/xmuda/nuscenes/usa_singapore/xmuda/pselab_data/train_singapore.npy',)\n    split = ('train_night',)\n    # pselab_paths = ('/home/docker_user/workspace/outputs/xmuda/nuscenes/day_night/xmuda/pselab_data/train_night.npy',)\n    dataset = NuScenesSCN(split=split,\n                          preprocess_dir=preprocess_dir,\n                          nuscenes_dir=nuscenes_dir,\n                          # pselab_paths=pselab_paths,\n                          merge_classes=True,\n                          use_image=True,\n                          noisy_rot=0.1,\n                          flip_x=0.5,\n                          rot_z=2*np.pi,\n                          transl=True,\n                          fliplr=0.5,\n                          color_jitter=(0.4, 0.4, 0.4)\n                       
   )\n    for i in [10, 20, 30, 40, 50, 60]:\n        data = dataset[i]\n        coords = data['coords']\n        seg_label = data['seg_label']\n        img = np.moveaxis(data['img'], 0, 2)\n        img_indices = data['img_indices']\n        draw_points_image_labels(img, img_indices, seg_label, color_palette_type='NuScenes', point_size=3)\n        # pseudo_label_2d = data['pseudo_label_2d']\n        # draw_points_image_labels(img, img_indices, pseudo_label_2d, color_palette_type='NuScenes', point_size=3)\n        draw_bird_eye_view(coords)\n        print('Number of points:', len(coords))\n\n\ndef compute_class_weights():\n    preprocess_dir = '/datasets_local/datasets_mjaritz/nuscenes_preprocess/preprocess'\n    # split = ('train_usa', 'test_usa')\n    split = ('train_day', 'test_day')\n    dataset = NuScenesBase(split,\n                           preprocess_dir,\n                           merge_classes=True\n                           )\n    # compute points per class over whole dataset\n    num_classes = len(dataset.class_names)\n    points_per_class = np.zeros(num_classes, int)\n    for i, data in enumerate(dataset.data):\n        print('{}/{}'.format(i, len(dataset)))\n        points_per_class += np.bincount(dataset.label_mapping[data['seg_labels']], minlength=num_classes)\n\n    # compute log smoothed class weights\n    class_weights = np.log(5 * points_per_class.sum() / points_per_class)\n    print('log smoothed class weights: ', class_weights / class_weights.min())\n\n\nif __name__ == '__main__':\n    test_NuScenesSCN()\n    # compute_class_weights()\n"
  },
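  {
    "path": "examples/pseudo_label_refinement_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original xMUDA code: the concatenate / refine / split-back pattern that NuScenesBase applies to pseudo labels. The refinement needs statistics over the whole split, so the per-scan arrays are concatenated, refined once, and sliced back using the per-scan lengths. dummy_refine is a made-up stand-in for xmuda.data.utils.refine_pseudo_labels.refine_pseudo_labels.\"\"\"\nimport numpy as np\n\n\ndef dummy_refine(probs, labels):\n    # stand-in: discard (set to -100) predictions below a fixed confidence\n    labels = labels.copy()\n    labels[probs < 0.5] = -100\n    return labels\n\n\npselab_data = [\n    {'probs_2d': np.array([0.9, 0.2]), 'pseudo_label_2d': np.array([1, 3])},\n    {'probs_2d': np.array([0.7]), 'pseudo_label_2d': np.array([0])},\n]\nprobs = np.concatenate([d['probs_2d'] for d in pselab_data])\nlabels = np.concatenate([d['pseudo_label_2d'] for d in pselab_data])\nlabels = dummy_refine(probs, labels)\n\n# undo the concatenation, as in NuScenesBase.__init__\nleft_idx = 0\nfor d in pselab_data:\n    right_idx = left_idx + len(d['probs_2d'])\n    d['pseudo_label_2d'] = labels[left_idx:right_idx]\n    left_idx = right_idx\nprint([d['pseudo_label_2d'] for d in pselab_data])  # [array([1, -100]), array([0])]\n"
  },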
  {
    "path": "xmuda/data/nuscenes/preprocess.py",
    "content": "import os\nimport os.path as osp\nimport numpy as np\nimport pickle\n\nfrom nuscenes.nuscenes import NuScenes\nfrom nuscenes.utils.geometry_utils import points_in_box\nfrom nuscenes.eval.detection.utils import category_to_detection_name\n\nfrom xmuda.data.nuscenes.nuscenes_dataloader import NuScenesBase\nfrom xmuda.data.nuscenes.projection import map_pointcloud_to_image\nfrom xmuda.data.nuscenes import splits\n\nclass_names_to_id = dict(zip(NuScenesBase.class_names, range(len(NuScenesBase.class_names))))\nif 'background' in class_names_to_id:\n    del class_names_to_id['background']\n\n\ndef preprocess(nusc, split_names, root_dir, out_dir,\n               keyword=None, keyword_action=None, subset_name=None,\n               location=None):\n    # cannot process day/night and location at the same time\n    assert not (bool(keyword) and bool(location))\n    if keyword:\n        assert keyword_action in ['filter', 'exclude']\n\n    # init dict to save\n    pkl_dict = {}\n    for split_name in split_names:\n         pkl_dict[split_name] = []\n\n    for i, sample in enumerate(nusc.sample):\n        curr_scene_name = nusc.get('scene', sample['scene_token'])['name']\n\n        # get if the current scene is in train, val or test\n        curr_split = None\n        for split_name in split_names:\n            if curr_scene_name in getattr(splits, split_name):\n                curr_split = split_name\n                break\n        if curr_split is None:\n            continue\n\n        if subset_name == 'night':\n            if curr_split == 'train':\n                if curr_scene_name in splits.val_night:\n                    curr_split = 'val'\n        if subset_name == 'singapore':\n            if curr_split == 'train':\n                if curr_scene_name in splits.val_singapore:\n                    curr_split = 'val'\n\n        # filter for day/night\n        if keyword:\n            scene_description = nusc.get(\"scene\", sample[\"scene_token\"])[\"description\"]\n            if keyword.lower() in scene_description.lower():\n                if keyword_action == 'exclude':\n                    # skip sample\n                    continue\n            else:\n                if keyword_action == 'filter':\n                    # skip sample\n                    continue\n\n        if location:\n            scene = nusc.get(\"scene\", sample[\"scene_token\"])\n            if location not in nusc.get(\"log\", scene['log_token'])['location']:\n                continue\n\n        lidar_token = sample[\"data\"][\"LIDAR_TOP\"]\n        cam_front_token = sample[\"data\"][\"CAM_FRONT\"]\n        lidar_path, boxes_lidar, _ = nusc.get_sample_data(lidar_token)\n        cam_path, boxes_front_cam, cam_intrinsic = nusc.get_sample_data(cam_front_token)\n\n        print('{}/{} {} {}'.format(i + 1, len(nusc.sample), curr_scene_name, lidar_path))\n\n        sd_rec_lidar = nusc.get('sample_data', sample['data'][\"LIDAR_TOP\"])\n        cs_record_lidar = nusc.get('calibrated_sensor',\n                             sd_rec_lidar['calibrated_sensor_token'])\n        pose_record_lidar = nusc.get('ego_pose', sd_rec_lidar['ego_pose_token'])\n        sd_rec_cam = nusc.get('sample_data', sample['data'][\"CAM_FRONT\"])\n        cs_record_cam = nusc.get('calibrated_sensor',\n                             sd_rec_cam['calibrated_sensor_token'])\n        pose_record_cam = nusc.get('ego_pose', sd_rec_cam['ego_pose_token'])\n\n        calib_infos = {\n            \"lidar2ego_translation\": 
cs_record_lidar['translation'],\n            \"lidar2ego_rotation\": cs_record_lidar['rotation'],\n            \"ego2global_translation_lidar\": pose_record_lidar['translation'],\n            \"ego2global_rotation_lidar\": pose_record_lidar['rotation'],\n            \"ego2global_translation_cam\": pose_record_cam['translation'],\n            \"ego2global_rotation_cam\": pose_record_cam['rotation'],\n            \"cam2ego_translation\": cs_record_cam['translation'],\n            \"cam2ego_rotation\": cs_record_cam['rotation'],\n            \"cam_intrinsic\": cam_intrinsic,\n        }\n\n        # load lidar points\n        pts = np.fromfile(lidar_path, dtype=np.float32, count=-1).reshape([-1, 5])[:, :3].T\n\n        # map point cloud into front camera image\n        pts_valid_flag, pts_cam_coord, pts_img = map_pointcloud_to_image(pts, (900, 1600, 3), calib_infos)\n        # fliplr so that indexing is row, col and not col, row\n        pts_img = np.ascontiguousarray(np.fliplr(pts_img))\n\n        # only use lidar points in the front camera image\n        pts = pts[:, pts_valid_flag]\n\n        num_pts = pts.shape[1]\n        seg_labels = np.full(num_pts, fill_value=len(class_names_to_id), dtype=np.uint8)\n        # only use boxes that are visible in camera\n        valid_box_tokens = [box.token for box in boxes_front_cam]\n        boxes = [box for box in boxes_lidar if box.token in valid_box_tokens]\n        for box in boxes:\n            # get points that lie inside of the box\n            fg_mask = points_in_box(box, pts)\n            det_class = category_to_detection_name(box.name)\n            if det_class is not None:\n                seg_labels[fg_mask] = class_names_to_id[det_class]\n\n        # convert to relative path\n        lidar_path = lidar_path.replace(root_dir + '/', '')\n        cam_path = cam_path.replace(root_dir + '/', '')\n\n        # transpose to yield shape (num_points, 3)\n        pts = pts.T\n\n        # append data to train, val or test list in pkl_dict\n        data_dict = {\n            'points': pts,\n            'seg_labels': seg_labels,\n            'points_img': pts_img,  # row, col format, shape: (num_points, 2)\n            'lidar_path': lidar_path,\n            'camera_path': cam_path,\n            'boxes': boxes_lidar,\n            \"sample_token\": sample[\"token\"],\n            \"scene_name\": curr_scene_name,\n            \"calib\": calib_infos\n        }\n        pkl_dict[curr_split].append(data_dict)\n\n    # save to pickle file\n    save_dir = osp.join(out_dir, 'preprocess')\n    os.makedirs(save_dir, exist_ok=True)\n    for split_name in split_names:\n        save_path = osp.join(save_dir, '{}{}.pkl'.format(split_name, '_' + subset_name if subset_name else ''))\n        with open(save_path, 'wb') as f:\n            pickle.dump(pkl_dict[split_name], f)\n            print('Wrote preprocessed data to ' + save_path)\n\n\nif __name__ == '__main__':\n    root_dir = '/datasets_master/nuscenes'\n    out_dir = '/datasets_local/datasets_mjaritz/nuscenes_preprocess'\n    nusc = NuScenes(version='v1.0-trainval', dataroot=root_dir, verbose=True)\n    # for faster debugging, the script can be run using the mini dataset\n    # nusc = NuScenes(version='v1.0-mini', dataroot=root_dir, verbose=True)\n    # We construct the splits by using the meta data of NuScenes:\n    # USA/Singapore: We check if the location is Boston or Singapore.\n    # Day/Night: We detect if \"night\" occurs in the scene description string.\n    preprocess(nusc, ['train', 'test'], root_dir, 
out_dir, location='boston', subset_name='usa')\n    preprocess(nusc, ['train', 'val', 'test'], root_dir, out_dir, location='singapore', subset_name='singapore')\n    preprocess(nusc, ['train', 'test'], root_dir, out_dir, keyword='night', keyword_action='exclude', subset_name='day')\n    preprocess(nusc, ['train', 'val', 'test'], root_dir, out_dir, keyword='night', keyword_action='filter', subset_name='night')\n"
  },
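  {
    "path": "examples/nuscenes_box_labels_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original xMUDA code: how xmuda/data/nuscenes/preprocess.py derives point-wise labels. nuScenes provides no lidar segmentation ground truth here, so every point inside an annotated 3D box inherits the box's detection class, and all remaining points become 'background' (index len(class_names_to_id)). The box and points below are made up.\"\"\"\nimport numpy as np\nfrom pyquaternion import Quaternion\nfrom nuscenes.utils.data_classes import Box\nfrom nuscenes.utils.geometry_utils import points_in_box\n\npts = np.array([[0.0, 5.0], [0.0, 0.0], [0.0, 0.0]])  # (3, num_points)\ncar_box = Box(center=[5.0, 0.0, 0.0], size=[2.0, 4.0, 1.5],\n              orientation=Quaternion(axis=[0, 0, 1], angle=0.0))\n\nbackground_id = 10  # len(class_names_to_id) after 'background' was removed\nseg_labels = np.full(pts.shape[1], background_id, dtype=np.uint8)\nseg_labels[points_in_box(car_box, pts)] = 0  # 0 = 'car' in class_names_to_id\nprint(seg_labels)  # -> [10  0]\n"
  },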
  {
    "path": "xmuda/data/nuscenes/projection.py",
    "content": "import numpy as np\nfrom pyquaternion import Quaternion\nfrom nuscenes.utils.geometry_utils import view_points\n\nimport matplotlib.pyplot as plt\n\n\n# modified from https://github.com/nutonomy/nuscenes-devkit/blob/master/python-sdk/nuscenes/nuscenes.py\ndef map_pointcloud_to_image(pc, im_shape, info, im=None):\n    \"\"\"\n    Maps the lidar point cloud to the image.\n    :param pc: (3, N)\n    :param im_shape: image to check size and debug\n    :param info: dict with calibration infos\n    :param im: image, only for visualization\n    :return:\n    \"\"\"\n    pc = pc.copy()\n\n    # Points live in the point sensor frame. So they need to be transformed via global to the image plane.\n    # First step: transform the point-cloud to the ego vehicle frame for the timestamp of the sweep.\n    pc = Quaternion(info['lidar2ego_rotation']).rotation_matrix @ pc\n    pc = pc + np.array(info['lidar2ego_translation'])[:, np.newaxis]\n\n    # Second step: transform to the global frame.\n    pc = Quaternion(info['ego2global_rotation_lidar']).rotation_matrix @ pc\n    pc = pc + np.array(info['ego2global_translation_lidar'])[:, np.newaxis]\n\n    # Third step: transform into the ego vehicle frame for the timestamp of the image.\n    pc = pc - np.array(info['ego2global_translation_cam'])[:, np.newaxis]\n    pc = Quaternion(info['ego2global_rotation_cam']).rotation_matrix.T @ pc\n\n    # Fourth step: transform into the camera.\n    pc = pc - np.array(info['cam2ego_translation'])[:, np.newaxis]\n    pc = Quaternion(info['cam2ego_rotation']).rotation_matrix.T @ pc\n\n    # Fifth step: actually take a \"picture\" of the point cloud.\n    # Grab the depths (camera frame z axis points away from the camera).\n    depths = pc[2, :]\n\n    # Take the actual picture (matrix multiplication with camera-matrix + renormalization).\n    points = view_points(pc, np.array(info['cam_intrinsic']), normalize=True)\n\n    # Cast to float32 to prevent later rounding errors\n    points = points.astype(np.float32)\n\n    # Remove points that are either outside or behind the camera.\n    mask = np.ones(depths.shape[0], dtype=bool)\n    mask = np.logical_and(mask, depths > 0)\n    mask = np.logical_and(mask, points[0, :] > 0)\n    mask = np.logical_and(mask, points[0, :] < im_shape[1])\n    mask = np.logical_and(mask, points[1, :] > 0)\n    mask = np.logical_and(mask, points[1, :] < im_shape[0])\n    points = points[:, mask]\n\n    # debug\n    if im is not None:\n        # Retrieve the color from the depth.\n        coloring = depths\n        coloring = coloring[mask]\n\n        plt.figure(figsize=(9, 16))\n        plt.imshow(im)\n        plt.scatter(points[0, :], points[1, :], c=coloring, s=2)\n        plt.axis('off')\n\n        # plt.show()\n\n    return mask, pc.T, points.T[:, :2]\n"
  },
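  {
    "path": "examples/projection_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original xMUDA code: calling xmuda.data.nuscenes.projection.map_pointcloud_to_image with identity extrinsics and a simple pinhole intrinsic, so the lidar and camera frames coincide and the projection can be checked by hand. All numbers are made up.\"\"\"\nimport numpy as np\n\nfrom xmuda.data.nuscenes.projection import map_pointcloud_to_image\n\nidentity_calib = {\n    'lidar2ego_rotation': [1, 0, 0, 0], 'lidar2ego_translation': [0, 0, 0],\n    'ego2global_rotation_lidar': [1, 0, 0, 0], 'ego2global_translation_lidar': [0, 0, 0],\n    'ego2global_rotation_cam': [1, 0, 0, 0], 'ego2global_translation_cam': [0, 0, 0],\n    'cam2ego_rotation': [1, 0, 0, 0], 'cam2ego_translation': [0, 0, 0],\n    'cam_intrinsic': [[500, 0, 800], [0, 500, 450], [0, 0, 1]],\n}\n# one point 10 m in front of the camera (camera frame: z forward)\npc = np.array([[1.0], [0.5], [10.0]])\nmask, pts_cam, pts_img = map_pointcloud_to_image(pc, (900, 1600, 3), identity_calib)\n# pts_img is in (col, row) order here; the preprocess script flips it to row, col\nprint(mask, pts_img)  # [ True] [[850. 475.]] = (800 + 500*1/10, 450 + 500*0.5/10)\n"
  },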
  {
    "path": "xmuda/data/nuscenes/splits.py",
    "content": "# Official training set in NuScenes. We split scenes either into USA/Singapore or Day/Night.\ntrain = \\\n    ['scene-0001', 'scene-0002', 'scene-0004', 'scene-0005', 'scene-0006', 'scene-0007', 'scene-0008', 'scene-0009',\n     'scene-0010', 'scene-0011', 'scene-0019', 'scene-0020', 'scene-0021', 'scene-0022', 'scene-0023', 'scene-0024',\n     'scene-0025', 'scene-0026', 'scene-0027', 'scene-0028', 'scene-0029', 'scene-0030', 'scene-0031', 'scene-0032',\n     'scene-0033', 'scene-0034', 'scene-0041', 'scene-0042', 'scene-0043', 'scene-0044', 'scene-0045', 'scene-0046',\n     'scene-0047', 'scene-0048', 'scene-0049', 'scene-0050', 'scene-0051', 'scene-0052', 'scene-0053', 'scene-0054',\n     'scene-0055', 'scene-0056', 'scene-0057', 'scene-0058', 'scene-0059', 'scene-0060', 'scene-0061', 'scene-0062',\n     'scene-0063', 'scene-0064', 'scene-0065', 'scene-0066', 'scene-0067', 'scene-0068', 'scene-0069', 'scene-0070',\n     'scene-0071', 'scene-0072', 'scene-0073', 'scene-0074', 'scene-0075', 'scene-0076', 'scene-0120', 'scene-0121',\n     'scene-0122', 'scene-0123', 'scene-0124', 'scene-0125', 'scene-0126', 'scene-0127', 'scene-0128', 'scene-0129',\n     'scene-0130', 'scene-0131', 'scene-0132', 'scene-0133', 'scene-0134', 'scene-0135', 'scene-0138', 'scene-0139',\n     'scene-0149', 'scene-0150', 'scene-0151', 'scene-0152', 'scene-0154', 'scene-0155', 'scene-0157', 'scene-0158',\n     'scene-0159', 'scene-0160', 'scene-0161', 'scene-0162', 'scene-0163', 'scene-0164', 'scene-0165', 'scene-0166',\n     'scene-0167', 'scene-0168', 'scene-0170', 'scene-0171', 'scene-0172', 'scene-0173', 'scene-0174', 'scene-0175',\n     'scene-0176', 'scene-0177', 'scene-0178', 'scene-0179', 'scene-0180', 'scene-0181', 'scene-0182', 'scene-0183',\n     'scene-0184', 'scene-0185', 'scene-0187', 'scene-0188', 'scene-0190', 'scene-0191', 'scene-0192', 'scene-0193',\n     'scene-0194', 'scene-0195', 'scene-0196', 'scene-0199', 'scene-0200', 'scene-0202', 'scene-0203', 'scene-0204',\n     'scene-0206', 'scene-0207', 'scene-0208', 'scene-0209', 'scene-0210', 'scene-0211', 'scene-0212', 'scene-0213',\n     'scene-0214', 'scene-0218', 'scene-0219', 'scene-0220', 'scene-0222', 'scene-0224', 'scene-0225', 'scene-0226',\n     'scene-0227', 'scene-0228', 'scene-0229', 'scene-0230', 'scene-0231', 'scene-0232', 'scene-0233', 'scene-0234',\n     'scene-0235', 'scene-0236', 'scene-0237', 'scene-0238', 'scene-0239', 'scene-0240', 'scene-0241', 'scene-0242',\n     'scene-0243', 'scene-0244', 'scene-0245', 'scene-0246', 'scene-0247', 'scene-0248', 'scene-0249', 'scene-0250',\n     'scene-0251', 'scene-0252', 'scene-0253', 'scene-0254', 'scene-0255', 'scene-0256', 'scene-0257', 'scene-0258',\n     'scene-0259', 'scene-0260', 'scene-0261', 'scene-0262', 'scene-0263', 'scene-0264', 'scene-0283', 'scene-0284',\n     'scene-0285', 'scene-0286', 'scene-0287', 'scene-0288', 'scene-0289', 'scene-0290', 'scene-0291', 'scene-0292',\n     'scene-0293', 'scene-0294', 'scene-0295', 'scene-0296', 'scene-0297', 'scene-0298', 'scene-0299', 'scene-0300',\n     'scene-0301', 'scene-0302', 'scene-0303', 'scene-0304', 'scene-0305', 'scene-0306', 'scene-0315', 'scene-0316',\n     'scene-0317', 'scene-0318', 'scene-0321', 'scene-0323', 'scene-0324', 'scene-0328', 'scene-0347', 'scene-0348',\n     'scene-0349', 'scene-0350', 'scene-0351', 'scene-0352', 'scene-0353', 'scene-0354', 'scene-0355', 'scene-0356',\n     'scene-0357', 'scene-0358', 'scene-0359', 'scene-0360', 'scene-0361', 'scene-0362', 'scene-0363', 'scene-0364',\n     
'scene-0365', 'scene-0366', 'scene-0367', 'scene-0368', 'scene-0369', 'scene-0370', 'scene-0371', 'scene-0372',\n     'scene-0373', 'scene-0374', 'scene-0375', 'scene-0376', 'scene-0377', 'scene-0378', 'scene-0379', 'scene-0380',\n     'scene-0381', 'scene-0382', 'scene-0383', 'scene-0384', 'scene-0385', 'scene-0386', 'scene-0388', 'scene-0389',\n     'scene-0390', 'scene-0391', 'scene-0392', 'scene-0393', 'scene-0394', 'scene-0395', 'scene-0396', 'scene-0397',\n     'scene-0398', 'scene-0399', 'scene-0400', 'scene-0401', 'scene-0402', 'scene-0403', 'scene-0405', 'scene-0406',\n     'scene-0407', 'scene-0408', 'scene-0410', 'scene-0411', 'scene-0412', 'scene-0413', 'scene-0414', 'scene-0415',\n     'scene-0416', 'scene-0417', 'scene-0418', 'scene-0419', 'scene-0420', 'scene-0421', 'scene-0422', 'scene-0423',\n     'scene-0424', 'scene-0425', 'scene-0426', 'scene-0427', 'scene-0428', 'scene-0429', 'scene-0430', 'scene-0431',\n     'scene-0432', 'scene-0433', 'scene-0434', 'scene-0435', 'scene-0436', 'scene-0437', 'scene-0438', 'scene-0439',\n     'scene-0440', 'scene-0441', 'scene-0442', 'scene-0443', 'scene-0444', 'scene-0445', 'scene-0446', 'scene-0447',\n     'scene-0448', 'scene-0449', 'scene-0450', 'scene-0451', 'scene-0452', 'scene-0453', 'scene-0454', 'scene-0455',\n     'scene-0456', 'scene-0457', 'scene-0458', 'scene-0459', 'scene-0461', 'scene-0462', 'scene-0463', 'scene-0464',\n     'scene-0465', 'scene-0467', 'scene-0468', 'scene-0469', 'scene-0471', 'scene-0472', 'scene-0474', 'scene-0475',\n     'scene-0476', 'scene-0477', 'scene-0478', 'scene-0479', 'scene-0480', 'scene-0499', 'scene-0500', 'scene-0501',\n     'scene-0502', 'scene-0504', 'scene-0505', 'scene-0506', 'scene-0507', 'scene-0508', 'scene-0509', 'scene-0510',\n     'scene-0511', 'scene-0512', 'scene-0513', 'scene-0514', 'scene-0515', 'scene-0517', 'scene-0518', 'scene-0525',\n     'scene-0526', 'scene-0527', 'scene-0528', 'scene-0529', 'scene-0530', 'scene-0531', 'scene-0532', 'scene-0533',\n     'scene-0534', 'scene-0535', 'scene-0536', 'scene-0537', 'scene-0538', 'scene-0539', 'scene-0541', 'scene-0542',\n     'scene-0543', 'scene-0544', 'scene-0545', 'scene-0546', 'scene-0566', 'scene-0568', 'scene-0570', 'scene-0571',\n     'scene-0572', 'scene-0573', 'scene-0574', 'scene-0575', 'scene-0576', 'scene-0577', 'scene-0578', 'scene-0580',\n     'scene-0582', 'scene-0583', 'scene-0584', 'scene-0585', 'scene-0586', 'scene-0587', 'scene-0588', 'scene-0589',\n     'scene-0590', 'scene-0591', 'scene-0592', 'scene-0593', 'scene-0594', 'scene-0595', 'scene-0596', 'scene-0597',\n     'scene-0598', 'scene-0599', 'scene-0600', 'scene-0639', 'scene-0640', 'scene-0641', 'scene-0642', 'scene-0643',\n     'scene-0644', 'scene-0645', 'scene-0646', 'scene-0647', 'scene-0648', 'scene-0649', 'scene-0650', 'scene-0651',\n     'scene-0652', 'scene-0653', 'scene-0654', 'scene-0655', 'scene-0656', 'scene-0657', 'scene-0658', 'scene-0659',\n     'scene-0660', 'scene-0661', 'scene-0662', 'scene-0663', 'scene-0664', 'scene-0665', 'scene-0666', 'scene-0667',\n     'scene-0668', 'scene-0669', 'scene-0670', 'scene-0671', 'scene-0672', 'scene-0673', 'scene-0674', 'scene-0675',\n     'scene-0676', 'scene-0677', 'scene-0678', 'scene-0679', 'scene-0681', 'scene-0683', 'scene-0684', 'scene-0685',\n     'scene-0686', 'scene-0687', 'scene-0688', 'scene-0689', 'scene-0695', 'scene-0696', 'scene-0697', 'scene-0698',\n     'scene-0700', 'scene-0701', 'scene-0703', 'scene-0704', 'scene-0705', 'scene-0706', 'scene-0707', 'scene-0708',\n     'scene-0709', 
'scene-0710', 'scene-0711', 'scene-0712', 'scene-0713', 'scene-0714', 'scene-0715', 'scene-0716',\n     'scene-0717', 'scene-0718', 'scene-0719', 'scene-0726', 'scene-0727', 'scene-0728', 'scene-0730', 'scene-0731',\n     'scene-0733', 'scene-0734', 'scene-0735', 'scene-0736', 'scene-0737', 'scene-0738', 'scene-0739', 'scene-0740',\n     'scene-0741', 'scene-0744', 'scene-0746', 'scene-0747', 'scene-0749', 'scene-0750', 'scene-0751', 'scene-0752',\n     'scene-0757', 'scene-0758', 'scene-0759', 'scene-0760', 'scene-0761', 'scene-0762', 'scene-0763', 'scene-0764',\n     'scene-0765', 'scene-0767', 'scene-0768', 'scene-0769', 'scene-0786', 'scene-0787', 'scene-0789', 'scene-0790',\n     'scene-0791', 'scene-0792', 'scene-0803', 'scene-0804', 'scene-0805', 'scene-0806', 'scene-0808', 'scene-0809',\n     'scene-0810', 'scene-0811', 'scene-0812', 'scene-0813', 'scene-0815', 'scene-0816', 'scene-0817', 'scene-0819',\n     'scene-0820', 'scene-0821', 'scene-0822', 'scene-0847', 'scene-0848', 'scene-0849', 'scene-0850', 'scene-0851',\n     'scene-0852', 'scene-0853', 'scene-0854', 'scene-0855', 'scene-0856', 'scene-0858', 'scene-0860', 'scene-0861',\n     'scene-0862', 'scene-0863', 'scene-0864', 'scene-0865', 'scene-0866', 'scene-0868', 'scene-0869', 'scene-0870',\n     'scene-0871', 'scene-0872', 'scene-0873', 'scene-0875', 'scene-0876', 'scene-0877', 'scene-0878', 'scene-0880',\n     'scene-0882', 'scene-0883', 'scene-0884', 'scene-0885', 'scene-0886', 'scene-0887', 'scene-0888', 'scene-0889',\n     'scene-0890', 'scene-0891', 'scene-0892', 'scene-0893', 'scene-0894', 'scene-0895', 'scene-0896', 'scene-0897',\n     'scene-0898', 'scene-0899', 'scene-0900', 'scene-0901', 'scene-0902', 'scene-0903', 'scene-0945', 'scene-0947',\n     'scene-0949', 'scene-0952', 'scene-0953', 'scene-0955', 'scene-0956', 'scene-0957', 'scene-0958', 'scene-0959',\n     'scene-0960', 'scene-0961', 'scene-0975', 'scene-0976', 'scene-0977', 'scene-0978', 'scene-0979', 'scene-0980',\n     'scene-0981', 'scene-0982', 'scene-0983', 'scene-0984', 'scene-0988', 'scene-0989', 'scene-0990', 'scene-0991',\n     'scene-0992', 'scene-0994', 'scene-0995', 'scene-0996', 'scene-0997', 'scene-0998', 'scene-0999', 'scene-1000',\n     'scene-1001', 'scene-1002', 'scene-1003', 'scene-1004', 'scene-1005', 'scene-1006', 'scene-1007', 'scene-1008',\n     'scene-1009', 'scene-1010', 'scene-1011', 'scene-1012', 'scene-1013', 'scene-1014', 'scene-1015', 'scene-1016',\n     'scene-1017', 'scene-1018', 'scene-1019', 'scene-1020', 'scene-1021', 'scene-1022', 'scene-1023', 'scene-1024',\n     'scene-1025', 'scene-1044', 'scene-1045', 'scene-1046', 'scene-1047', 'scene-1048', 'scene-1049', 'scene-1050',\n     'scene-1051', 'scene-1052', 'scene-1053', 'scene-1054', 'scene-1055', 'scene-1056', 'scene-1057', 'scene-1058',\n     'scene-1074', 'scene-1075', 'scene-1076', 'scene-1077', 'scene-1078', 'scene-1079', 'scene-1080', 'scene-1081',\n     'scene-1082', 'scene-1083', 'scene-1084', 'scene-1085', 'scene-1086', 'scene-1087', 'scene-1088', 'scene-1089',\n     'scene-1090', 'scene-1091', 'scene-1092', 'scene-1093', 'scene-1094', 'scene-1095', 'scene-1096', 'scene-1097',\n     'scene-1098', 'scene-1099', 'scene-1100', 'scene-1101', 'scene-1102', 'scene-1104', 'scene-1105', 'scene-1106',\n     'scene-1107', 'scene-1108', 'scene-1109', 'scene-1110']\n\n# We use the official validation set as test set. 
We split scenes either into USA/Singapore or Day/Night.\nval = []\ntest = \\\n    ['scene-0003', 'scene-0012', 'scene-0013', 'scene-0014', 'scene-0015', 'scene-0016', 'scene-0017', 'scene-0018',\n     'scene-0035', 'scene-0036', 'scene-0038', 'scene-0039', 'scene-0092', 'scene-0093', 'scene-0094', 'scene-0095',\n     'scene-0096', 'scene-0097', 'scene-0098', 'scene-0099', 'scene-0100', 'scene-0101', 'scene-0102', 'scene-0103',\n     'scene-0104', 'scene-0105', 'scene-0106', 'scene-0107', 'scene-0108', 'scene-0109', 'scene-0110', 'scene-0221',\n     'scene-0268', 'scene-0269', 'scene-0270', 'scene-0271', 'scene-0272', 'scene-0273', 'scene-0274', 'scene-0275',\n     'scene-0276', 'scene-0277', 'scene-0278', 'scene-0329', 'scene-0330', 'scene-0331', 'scene-0332', 'scene-0344',\n     'scene-0345', 'scene-0346', 'scene-0519', 'scene-0520', 'scene-0521', 'scene-0522', 'scene-0523', 'scene-0524',\n     'scene-0552', 'scene-0553', 'scene-0554', 'scene-0555', 'scene-0556', 'scene-0557', 'scene-0558', 'scene-0559',\n     'scene-0560', 'scene-0561', 'scene-0562', 'scene-0563', 'scene-0564', 'scene-0565', 'scene-0625', 'scene-0626',\n     'scene-0627', 'scene-0629', 'scene-0630', 'scene-0632', 'scene-0633', 'scene-0634', 'scene-0635', 'scene-0636',\n     'scene-0637', 'scene-0638', 'scene-0770', 'scene-0771', 'scene-0775', 'scene-0777', 'scene-0778', 'scene-0780',\n     'scene-0781', 'scene-0782', 'scene-0783', 'scene-0784', 'scene-0794', 'scene-0795', 'scene-0796', 'scene-0797',\n     'scene-0798', 'scene-0799', 'scene-0800', 'scene-0802', 'scene-0904', 'scene-0905', 'scene-0906', 'scene-0907',\n     'scene-0908', 'scene-0909', 'scene-0910', 'scene-0911', 'scene-0912', 'scene-0913', 'scene-0914', 'scene-0915',\n     'scene-0916', 'scene-0917', 'scene-0919', 'scene-0920', 'scene-0921', 'scene-0922', 'scene-0923', 'scene-0924',\n     'scene-0925', 'scene-0926', 'scene-0927', 'scene-0928', 'scene-0929', 'scene-0930', 'scene-0931', 'scene-0962',\n     'scene-0963', 'scene-0966', 'scene-0967', 'scene-0968', 'scene-0969', 'scene-0971', 'scene-0972', 'scene-1059',\n     'scene-1060', 'scene-1061', 'scene-1062', 'scene-1063', 'scene-1064', 'scene-1065', 'scene-1066', 'scene-1067',\n     'scene-1068', 'scene-1069', 'scene-1070', 'scene-1071', 'scene-1072', 'scene-1073']\n\n# Exclude some scenes from the training set to use for validation.  
Depends on split (Day/Night, USA/Singapore).\n# Note that, we do not produce a validation set on the source datasets (Day, USA), as we validate on target\n# (Night, Singapore) during training.\nval_night = [\n 'scene-1044',\n 'scene-1045',\n 'scene-1046',\n 'scene-1047',\n 'scene-1048',\n 'scene-1049',\n 'scene-1050',\n 'scene-1051',\n 'scene-1052',\n 'scene-1053',\n 'scene-1054',\n 'scene-1055',\n 'scene-1056',\n 'scene-1057',\n 'scene-1058'\n]\n\nval_singapore = [\n 'scene-0004',\n 'scene-0005',\n 'scene-0006',\n 'scene-0007',\n 'scene-0008',\n 'scene-0009',\n 'scene-0010',\n 'scene-0011',\n 'scene-0045',\n 'scene-0046',\n 'scene-0047',\n 'scene-0048',\n 'scene-0049',\n 'scene-0050',\n 'scene-0051',\n 'scene-0052',\n 'scene-0053',\n 'scene-0054',\n 'scene-0347',\n 'scene-0348',\n 'scene-0349',\n 'scene-0356',\n 'scene-0357',\n 'scene-0358',\n 'scene-0359',\n 'scene-0786',\n 'scene-0787',\n 'scene-0789',\n 'scene-0790',\n 'scene-0791',\n 'scene-0792',\n 'scene-0847',\n 'scene-0848',\n 'scene-0849',\n 'scene-0850',\n 'scene-0851',\n 'scene-0852',\n 'scene-0853',\n 'scene-0854',\n 'scene-0855',\n 'scene-0856',\n 'scene-0858',\n 'scene-0860',\n 'scene-0861',\n 'scene-0862',\n 'scene-0863',\n 'scene-0864',\n 'scene-0865',\n 'scene-0866',\n 'scene-0975',\n 'scene-0976',\n 'scene-0977',\n 'scene-0978',\n 'scene-0979',\n 'scene-0980',\n 'scene-0981',\n 'scene-0982',\n 'scene-0983',\n 'scene-0984',\n 'scene-0988',\n 'scene-0989',\n 'scene-0990',\n 'scene-0991',\n 'scene-1044',\n 'scene-1106',\n 'scene-1107',\n 'scene-1108',\n 'scene-1109',\n 'scene-1110',\n]"
  },
  {
    "path": "xmuda/data/semantic_kitti/preprocess.py",
    "content": "import os\nimport os.path as osp\nimport numpy as np\nimport pickle\nfrom PIL import Image\nimport glob\nimport torch\nfrom torch.utils.data import Dataset\nfrom torch.utils.data.dataloader import DataLoader\n\nfrom xmuda.data.semantic_kitti import splits\n\n# prevent \"RuntimeError: received 0 items of ancdata\"\ntorch.multiprocessing.set_sharing_strategy('file_system')\n\n\nclass DummyDataset(Dataset):\n    \"\"\"Use torch dataloader for multiprocessing\"\"\"\n    def __init__(self, root_dir, scenes):\n        self.root_dir = root_dir\n        self.data = []\n        self.glob_frames(scenes)\n\n    def glob_frames(self, scenes):\n        for scene in scenes:\n            glob_path = osp.join(self.root_dir, 'dataset', 'sequences', scene, 'image_2', '*.png')\n            cam_paths = sorted(glob.glob(glob_path))\n            # load calibration\n            calib = self.read_calib(osp.join(self.root_dir, 'dataset', 'sequences', scene, 'calib.txt'))\n            proj_matrix = calib['P2'] @ calib['Tr']\n            proj_matrix = proj_matrix.astype(np.float32)\n\n            for cam_path in cam_paths:\n                basename = osp.basename(cam_path)\n                frame_id = osp.splitext(basename)[0]\n                assert frame_id.isdigit()\n                data = {\n                    'camera_path': cam_path,\n                    'lidar_path': osp.join(self.root_dir, 'dataset', 'sequences', scene, 'velodyne',\n                                           frame_id + '.bin'),\n                    'label_path': osp.join(self.root_dir, 'dataset', 'sequences', scene, 'labels',\n                                           frame_id + '.label'),\n                    'proj_matrix': proj_matrix\n                }\n                for k, v in data.items():\n                    if isinstance(v, str):\n                        if not osp.exists(v):\n                            raise IOError('File not found {}'.format(v))\n                self.data.append(data)\n\n    @staticmethod\n    def read_calib(calib_path):\n        \"\"\"\n        :param calib_path: Path to a calibration text file.\n        :return: dict with calibration matrices.\n        \"\"\"\n        calib_all = {}\n        with open(calib_path, 'r') as f:\n            for line in f.readlines():\n                if line == '\\n':\n                    break\n                key, value = line.split(':', 1)\n                calib_all[key] = np.array([float(x) for x in value.split()])\n\n        # reshape matrices\n        calib_out = {}\n        calib_out['P2'] = calib_all['P2'].reshape(3, 4)  # 3x4 projection matrix for left camera\n        calib_out['Tr'] = np.identity(4)  # 4x4 matrix\n        calib_out['Tr'][:3, :4] = calib_all['Tr'].reshape(3, 4)\n        return calib_out\n\n    @staticmethod\n    def select_points_in_frustum(points_2d, x1, y1, x2, y2):\n        \"\"\"\n        Select points in a 2D frustum parametrized by x1, y1, x2, y2 in image coordinates\n        :param points_2d: point cloud projected into 2D\n        :param points_3d: point cloud\n        :param x1: left bound\n        :param y1: upper bound\n        :param x2: right bound\n        :param y2: lower bound\n        :return: points (2D and 3D) that are in the frustum\n        \"\"\"\n        keep_ind = (points_2d[:, 0] > x1) * \\\n                   (points_2d[:, 1] > y1) * \\\n                   (points_2d[:, 0] < x2) * \\\n                   (points_2d[:, 1] < y2)\n\n        return keep_ind\n\n    def __getitem__(self, index):\n        data_dict = 
self.data[index].copy()\n        scan = np.fromfile(data_dict['lidar_path'], dtype=np.float32)\n        scan = scan.reshape((-1, 4))\n        points = scan[:, :3]\n        label = np.fromfile(data_dict['label_path'], dtype=np.uint32)\n        label = label.reshape((-1))\n        label = label & 0xFFFF  # get lower half for semantics\n\n        # load image\n        image = Image.open(data_dict['camera_path'])\n        image_size = image.size\n\n        # project points into image\n        keep_idx = points[:, 0] > 0  # only keep points in front of the vehicle\n        points_hcoords = np.concatenate([points[keep_idx], np.ones([keep_idx.sum(), 1], dtype=np.float32)], axis=1)\n        img_points = (data_dict['proj_matrix'] @ points_hcoords.T).T\n        img_points = img_points[:, :2] / np.expand_dims(img_points[:, 2], axis=1)  # scale 2D points\n        keep_idx_img_pts = self.select_points_in_frustum(img_points, 0, 0, *image_size)\n        keep_idx[keep_idx] = keep_idx_img_pts\n        # fliplr so that indexing is row, col and not col, row\n        img_points = np.fliplr(img_points)\n        # debug\n        # from xmuda.data.utils.visualize import draw_points_image, draw_bird_eye_view\n        # draw_points_image(np.array(image), img_points[keep_idx_img_pts].astype(int), label[keep_idx],\n        #                   color_palette_type='SemanticKITTI_long')\n\n        data_dict['seg_label'] = label[keep_idx].astype(np.int16)\n        data_dict['points'] = points[keep_idx]\n        data_dict['points_img'] = img_points[keep_idx_img_pts]\n        data_dict['image_size'] = np.array(image_size)\n\n        return data_dict\n\n    def __len__(self):\n        return len(self.data)\n\n\ndef preprocess(split_name, root_dir, out_dir):\n    pkl_data = []\n    split = getattr(splits, split_name)\n\n    dataloader = DataLoader(DummyDataset(root_dir, split), num_workers=8)\n\n    num_skips = 0\n    for i, data_dict in enumerate(dataloader):\n        # data error leads to returning empty dict\n        if not data_dict:\n            print('empty dict, continue')\n            num_skips += 1\n            continue\n        for k, v in data_dict.items():\n            data_dict[k] = v[0]\n        print('{}/{} {}'.format(i, len(dataloader), data_dict['lidar_path']))\n\n        # convert to relative path\n        lidar_path = data_dict['lidar_path'].replace(root_dir + '/', '')\n        cam_path = data_dict['camera_path'].replace(root_dir + '/', '')\n\n        # append data\n        out_dict = {\n            'points': data_dict['points'].numpy(),\n            'seg_labels': data_dict['seg_label'].numpy(),\n            'points_img': data_dict['points_img'].numpy(),  # row, col format, shape: (num_points, 2)\n            'lidar_path': lidar_path,\n            'camera_path': cam_path,\n            'image_size': tuple(data_dict['image_size'].numpy())\n        }\n        pkl_data.append(out_dict)\n\n    print('Skipped {} files'.format(num_skips))\n\n    # save to pickle file\n    save_dir = osp.join(out_dir, 'preprocess')\n    os.makedirs(save_dir, exist_ok=True)\n    save_path = osp.join(save_dir, '{}.pkl'.format(split_name))\n    with open(save_path, 'wb') as f:\n        pickle.dump(pkl_data, f)\n        print('Wrote preprocessed data to ' + save_path)\n\n\nif __name__ == '__main__':\n    root_dir = '/datasets_master/semantic_kitti'\n    out_dir = '/datasets_local/datasets_mjaritz/semantic_kitti_preprocess'\n    preprocess('val', root_dir, out_dir)\n    preprocess('train', root_dir, out_dir)\n    preprocess('test', root_dir, out_dir)
  },
  {
    "path": "xmuda/data/semantic_kitti/semantic_kitti_dataloader.py",
    "content": "import os.path as osp\nimport pickle\nfrom PIL import Image\nimport numpy as np\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms as T\n\nfrom xmuda.data.utils.refine_pseudo_labels import refine_pseudo_labels\nfrom xmuda.data.utils.augmentation_3d import augment_and_scale_3d\n\n\nclass SemanticKITTIBase(Dataset):\n    \"\"\"SemanticKITTI dataset\"\"\"\n\n    # https://github.com/PRBonn/semantic-kitti-api/blob/master/config/semantic-kitti.yaml\n    id_to_class_name = {\n        0: \"unlabeled\",\n        1: \"outlier\",\n        10: \"car\",\n        11: \"bicycle\",\n        13: \"bus\",\n        15: \"motorcycle\",\n        16: \"on-rails\",\n        18: \"truck\",\n        20: \"other-vehicle\",\n        30: \"person\",\n        31: \"bicyclist\",\n        32: \"motorcyclist\",\n        40: \"road\",\n        44: \"parking\",\n        48: \"sidewalk\",\n        49: \"other-ground\",\n        50: \"building\",\n        51: \"fence\",\n        52: \"other-structure\",\n        60: \"lane-marking\",\n        70: \"vegetation\",\n        71: \"trunk\",\n        72: \"terrain\",\n        80: \"pole\",\n        81: \"traffic-sign\",\n        99: \"other-object\",\n        252: \"moving-car\",\n        253: \"moving-bicyclist\",\n        254: \"moving-person\",\n        255: \"moving-motorcyclist\",\n        256: \"moving-on-rails\",\n        257: \"moving-bus\",\n        258: \"moving-truck\",\n        259: \"moving-other-vehicle\",\n    }\n\n    class_name_to_id = {v: k for k, v in id_to_class_name.items()}\n\n    # use those categories if merge_classes == True (common with A2D2)\n    categories = {\n        'car': ['car', 'moving-car'],\n        'truck': ['truck', 'moving-truck'],\n        'bike': ['bicycle', 'motorcycle', 'bicyclist', 'motorcyclist',\n                 'moving-bicyclist', 'moving-motorcyclist'],  # riders are labeled as bikes in Audi dataset\n        'person': ['person', 'moving-person'],\n        'road': ['road', 'lane-marking'],\n        'parking': ['parking'],\n        'sidewalk': ['sidewalk'],\n        'building': ['building'],\n        'nature': ['vegetation', 'trunk', 'terrain'],\n        'other-objects': ['fence', 'pole', 'traffic-sign', 'other-object'],\n    }\n\n    def __init__(self,\n                 split,\n                 preprocess_dir,\n                 merge_classes=False,\n                 pselab_paths=None\n                 ):\n\n        self.split = split\n        self.preprocess_dir = preprocess_dir\n\n        print(\"Initialize SemanticKITTI dataloader\")\n\n        assert isinstance(split, tuple)\n        print('Load', split)\n        self.data = []\n        for curr_split in split:\n            with open(osp.join(self.preprocess_dir, curr_split + '.pkl'), 'rb') as f:\n                self.data.extend(pickle.load(f))\n\n        self.pselab_data = None\n        if pselab_paths:\n            assert isinstance(pselab_paths, tuple)\n            print('Load pseudo label data ', pselab_paths)\n            self.pselab_data = []\n            for curr_split in pselab_paths:\n                self.pselab_data.extend(np.load(curr_split, allow_pickle=True))\n\n            # check consistency of data and pseudo labels\n            assert len(self.pselab_data) == len(self.data)\n            for i in range(len(self.pselab_data)):\n                assert len(self.pselab_data[i]['pseudo_label_2d']) == len(self.data[i]['seg_labels'])\n\n            # refine 2d pseudo labels\n            probs2d = np.concatenate([data['probs_2d'] 
for data in self.pselab_data])\n            pseudo_label_2d = np.concatenate([data['pseudo_label_2d'] for data in self.pselab_data]).astype(np.int64)\n            pseudo_label_2d = refine_pseudo_labels(probs2d, pseudo_label_2d)\n\n            # refine 3d pseudo labels\n            # fusion model has only one final prediction saved in probs_2d\n            if 'probs_3d' in self.pselab_data[0].keys():\n                probs3d = np.concatenate([data['probs_3d'] for data in self.pselab_data])\n                pseudo_label_3d = np.concatenate([data['pseudo_label_3d'] for data in self.pselab_data]).astype(np.int64)\n                pseudo_label_3d = refine_pseudo_labels(probs3d, pseudo_label_3d)\n            else:\n                pseudo_label_3d = None\n\n            # undo concat\n            left_idx = 0\n            for data_idx in range(len(self.pselab_data)):\n                right_idx = left_idx + len(self.pselab_data[data_idx]['probs_2d'])\n                self.pselab_data[data_idx]['pseudo_label_2d'] = pseudo_label_2d[left_idx:right_idx]\n                if pseudo_label_3d is not None:\n                    self.pselab_data[data_idx]['pseudo_label_3d'] = pseudo_label_3d[left_idx:right_idx]\n                else:\n                    self.pselab_data[data_idx]['pseudo_label_3d'] = None\n                left_idx = right_idx\n\n        if merge_classes:\n            highest_id = list(self.id_to_class_name.keys())[-1]\n            self.label_mapping = -100 * np.ones(highest_id + 2, dtype=int)\n            for cat_idx, cat_list in enumerate(self.categories.values()):\n                for class_name in cat_list:\n                    self.label_mapping[self.class_name_to_id[class_name]] = cat_idx\n            self.class_names = list(self.categories.keys())\n        else:\n            self.label_mapping = None\n\n    def __getitem__(self, index):\n        raise NotImplementedError\n\n    def __len__(self):\n        return len(self.data)\n\n\nclass SemanticKITTISCN(SemanticKITTIBase):\n    def __init__(self,\n                 split,\n                 preprocess_dir,\n                 semantic_kitti_dir='',\n                 pselab_paths=None,\n                 merge_classes=False,\n                 scale=20,\n                 full_scale=4096,\n                 image_normalizer=None,\n                 noisy_rot=0.0,  # 3D augmentation\n                 flip_y=0.0,  # 3D augmentation\n                 rot_z=0.0,  # 3D augmentation\n                 transl=False,  # 3D augmentation\n                 bottom_crop=tuple(),  # 2D augmentation (also affects 3D)\n                 fliplr=0.0,  # 2D augmentation\n                 color_jitter=None,  # 2D augmentation\n                 output_orig=False\n                 ):\n        super().__init__(split,\n                         preprocess_dir,\n                         merge_classes=merge_classes,\n                         pselab_paths=pselab_paths)\n\n        self.semantic_kitti_dir = semantic_kitti_dir\n        self.output_orig = output_orig\n\n        # point cloud parameters\n        self.scale = scale\n        self.full_scale = full_scale\n        # 3D augmentation\n        self.noisy_rot = noisy_rot\n        self.flip_y = flip_y\n        self.rot_z = rot_z\n        self.transl = transl\n\n        # image parameters\n        self.image_normalizer = image_normalizer\n        # 2D augmentation\n        self.bottom_crop = bottom_crop\n        self.fliplr = fliplr\n        self.color_jitter = T.ColorJitter(*color_jitter) if color_jitter else None\n\n    def 
__getitem__(self, index):\n        data_dict = self.data[index]\n\n        points = data_dict['points'].copy()\n        seg_label = data_dict['seg_labels'].astype(np.int64)\n\n        if self.label_mapping is not None:\n            seg_label = self.label_mapping[seg_label]\n\n        out_dict = {}\n\n        keep_idx = np.ones(len(points), dtype=bool)\n        points_img = data_dict['points_img'].copy()\n        img_path = osp.join(self.semantic_kitti_dir, data_dict['camera_path'])\n        image = Image.open(img_path)\n\n        if self.bottom_crop:\n            # self.bottom_crop is a tuple (crop_width, crop_height)\n            left = int(np.random.rand() * (image.size[0] + 1 - self.bottom_crop[0]))\n            right = left + self.bottom_crop[0]\n            top = image.size[1] - self.bottom_crop[1]\n            bottom = image.size[1]\n\n            # update image points\n            keep_idx = points_img[:, 0] >= top\n            keep_idx = np.logical_and(keep_idx, points_img[:, 0] < bottom)\n            keep_idx = np.logical_and(keep_idx, points_img[:, 1] >= left)\n            keep_idx = np.logical_and(keep_idx, points_img[:, 1] < right)\n\n            # crop image\n            image = image.crop((left, top, right, bottom))\n            points_img = points_img[keep_idx]\n            points_img[:, 0] -= top\n            points_img[:, 1] -= left\n\n            # update point cloud\n            points = points[keep_idx]\n            seg_label = seg_label[keep_idx]\n\n        img_indices = points_img.astype(np.int64)\n\n        # 2D augmentation\n        if self.color_jitter is not None:\n            image = self.color_jitter(image)\n        # PIL to numpy\n        image = np.array(image, dtype=np.float32, copy=False) / 255.\n        # 2D augmentation\n        if np.random.rand() < self.fliplr:\n            image = np.ascontiguousarray(np.fliplr(image))\n            img_indices[:, 1] = image.shape[1] - 1 - img_indices[:, 1]\n\n        # normalize image\n        if self.image_normalizer:\n            mean, std = self.image_normalizer\n            mean = np.asarray(mean, dtype=np.float32)\n            std = np.asarray(std, dtype=np.float32)\n            image = (image - mean) / std\n\n        out_dict['img'] = np.moveaxis(image, -1, 0)\n        out_dict['img_indices'] = img_indices\n\n        # 3D data augmentation and scaling from points to voxel indices\n        # Kitti lidar coordinates: x (front), y (left), z (up)\n        coords = augment_and_scale_3d(points, self.scale, self.full_scale, noisy_rot=self.noisy_rot,\n                                      flip_y=self.flip_y, rot_z=self.rot_z, transl=self.transl)\n\n        # cast to integer\n        coords = coords.astype(np.int64)\n\n        # only use voxels inside receptive field\n        idxs = (coords.min(1) >= 0) * (coords.max(1) < self.full_scale)\n\n        out_dict['coords'] = coords[idxs]\n        out_dict['feats'] = np.ones([len(out_dict['coords']), 1], np.float32)  # simply use 1 as feature, one per kept voxel\n        out_dict['seg_label'] = seg_label[idxs]\n        out_dict['img_indices'] = out_dict['img_indices'][idxs]\n\n        if self.pselab_data is not None:\n            out_dict.update({\n                'pseudo_label_2d': self.pselab_data[index]['pseudo_label_2d'][keep_idx][idxs],\n                'pseudo_label_3d': self.pselab_data[index]['pseudo_label_3d'][keep_idx][idxs]\n            })\n\n        if self.output_orig:\n            out_dict.update({\n                'orig_seg_label': seg_label,\n                'orig_points_idx': idxs,\n          
  })\n\n        return out_dict\n\n\ndef test_SemanticKITTISCN():\n    from xmuda.data.utils.visualize import draw_points_image_labels, draw_bird_eye_view\n    preprocess_dir = '/datasets_local/datasets_mjaritz/semantic_kitti_preprocess/preprocess'\n    semantic_kitti_dir = '/datasets_local/datasets_mjaritz/semantic_kitti_preprocess'\n    # pselab_paths = (\"/home/docker_user/workspace/outputs/xmuda/a2d2_semantic_kitti/xmuda_crop_resize/pselab_data/train.npy\",)\n    # split = ('train',)\n    split = ('val',)\n    dataset = SemanticKITTISCN(split=split,\n                               preprocess_dir=preprocess_dir,\n                               semantic_kitti_dir=semantic_kitti_dir,\n                               # pselab_paths=pselab_paths,\n                               merge_classes=True,\n                               noisy_rot=0.1,\n                               flip_y=0.5,\n                               rot_z=2*np.pi,\n                               transl=True,\n                               bottom_crop=(480, 302),\n                               fliplr=0.5,\n                               color_jitter=(0.4, 0.4, 0.4)\n                               )\n    for i in [10, 20, 30, 40, 50, 60]:\n        data = dataset[i]\n        coords = data['coords']\n        seg_label = data['seg_label']\n        img = np.moveaxis(data['img'], 0, 2)\n        img_indices = data['img_indices']\n        # pseudo_label_2d = data['pseudo_label_2d']\n        draw_points_image_labels(img, img_indices, seg_label, color_palette_type='SemanticKITTI', point_size=1)\n        # draw_points_image_labels(img, img_indices, pseudo_label_2d, color_palette_type='SemanticKITTI', point_size=1)\n        # assert len(pseudo_label_2d) == len(seg_label)\n        draw_bird_eye_view(coords)\n\n\ndef compute_class_weights():\n    preprocess_dir = '/datasets_local/datasets_mjaritz/semantic_kitti_preprocess/preprocess'\n    split = ('train',)\n    dataset = SemanticKITTIBase(split,\n                                preprocess_dir,\n                                merge_classes=True\n                                )\n    # compute points per class over whole dataset\n    num_classes = len(dataset.class_names)\n    points_per_class = np.zeros(num_classes, int)\n    for i, data in enumerate(dataset.data):\n        print('{}/{}'.format(i, len(dataset)))\n        labels = dataset.label_mapping[data['seg_labels']]\n        points_per_class += np.bincount(labels[labels != -100], minlength=num_classes)\n\n    # compute log smoothed class weights\n    class_weights = np.log(5 * points_per_class.sum() / points_per_class)\n    print('log smoothed class weights: ', class_weights / class_weights.min())\n\n\nif __name__ == '__main__':\n    test_SemanticKITTISCN()\n    # compute_class_weights()\n"
  },
  {
    "path": "xmuda/data/semantic_kitti/splits.py",
    "content": "# official split defined in https://github.com/PRBonn/semantic-kitti-api/blob/master/config/semantic-kitti.yaml\n\ntrain = [\n    '00',\n    '01',\n    '02',\n    '03',\n    '04',\n    '05',\n    '06',\n    '09',\n    '10',\n]\n\nval = [\n    '07'\n]\n\ntest = [\n    '08'\n]\n\n# not used\nhidden_test = [\n    '11',\n    '12',\n    '13',\n    '14',\n    '15',\n    '16',\n    '17',\n    '18',\n    '19',\n    '20',\n    '21',\n]\n"
  },
  {
    "path": "xmuda/data/utils/augmentation_3d.py",
    "content": "import numpy as np\n\n\ndef augment_and_scale_3d(points, scale, full_scale,\n                         noisy_rot=0.0,\n                         flip_x=0.0,\n                         flip_y=0.0,\n                         rot_z=0.0,\n                         transl=False):\n    \"\"\"\n    3D point cloud augmentation and scaling from points (in meters) to voxels\n    :param points: 3D points in meters\n    :param scale: voxel scale in 1 / m, e.g. 20 corresponds to 5cm voxels\n    :param full_scale: size of the receptive field of SparseConvNet\n    :param noisy_rot: scale of random noise added to all elements of a rotation matrix\n    :param flip_x: probability of flipping the x-axis (left-right in nuScenes LiDAR coordinate system)\n    :param flip_y: probability of flipping the y-axis (left-right in Kitti LiDAR coordinate system)\n    :param rot_z: angle in rad around the z-axis (up-axis)\n    :param transl: True or False, random translation inside the receptive field of the SCN, defined by full_scale\n    :return coords: the coordinates that are given as input to SparseConvNet\n    \"\"\"\n    if noisy_rot > 0 or flip_x > 0 or flip_y > 0 or rot_z > 0:\n        rot_matrix = np.eye(3, dtype=np.float32)\n        if noisy_rot > 0:\n            # add noise to rotation matrix\n            rot_matrix += np.random.randn(3, 3) * noisy_rot\n        if flip_x > 0:\n            # flip x axis: multiply element at (0, 0) with 1 or -1\n            rot_matrix[0][0] *= np.random.randint(0, 2) * 2 - 1\n        if flip_y > 0:\n            # flip y axis: multiply element at (1, 1) with 1 or -1\n            rot_matrix[1][1] *= np.random.randint(0, 2) * 2 - 1\n        if rot_z > 0:\n            # rotate around z-axis (up-axis)\n            theta = np.random.rand() * rot_z\n            z_rot_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],\n                                     [np.sin(theta), np.cos(theta), 0],\n                                     [0, 0, 1]], dtype=np.float32)\n            rot_matrix = rot_matrix.dot(z_rot_matrix)\n        points = points.dot(rot_matrix)\n\n    # scale with inverse voxel size (e.g. 20 corresponds to 5cm)\n    coords = points * scale\n    # translate points to positive octant (receptive field of SCN in x, y, z coords is in interval [0, full_scale])\n    coords -= coords.min(0)\n\n    if transl:\n        # random translation inside receptive field of SCN\n        offset = np.clip(full_scale - coords.max(0) - 0.001, a_min=0, a_max=None) * np.random.rand(3)\n        coords += offset\n\n    return coords\n"
  },
  {
    "path": "xmuda/data/utils/evaluate.py",
    "content": "import numpy as np\nfrom sklearn.metrics import confusion_matrix as CM\n\nclass Evaluator(object):\n    def __init__(self, class_names, labels=None):\n        self.class_names = tuple(class_names)\n        self.num_classes = len(class_names)\n        self.labels = np.arange(self.num_classes) if labels is None else np.array(labels)\n        assert self.labels.shape[0] == self.num_classes\n        self.confusion_matrix = np.zeros((self.num_classes, self.num_classes))\n\n    def update(self, pred_label, gt_label):\n        \"\"\"Update per instance\n\n        Args:\n            pred_label (np.ndarray): (num_points)\n            gt_label (np.ndarray): (num_points,)\n\n        \"\"\"\n        # convert ignore_label to num_classes\n        # refer to sklearn.metrics.confusion_matrix\n        gt_label[gt_label == -100] = self.num_classes\n        confusion_matrix = CM(gt_label.flatten(),\n                              pred_label.flatten(),\n                              labels=self.labels)\n        self.confusion_matrix += confusion_matrix\n\n    def batch_update(self, pred_labels, gt_labels):\n        assert len(pred_labels) == len(gt_labels)\n        for pred_label, gt_label in zip(pred_labels, gt_labels):\n            self.update(pred_label, gt_label)\n\n    @property\n    def overall_acc(self):\n        return np.sum(np.diag(self.confusion_matrix)) / np.sum(self.confusion_matrix)\n\n    @property\n    def overall_iou(self):\n        class_iou = np.array(self.class_iou.copy())\n        class_iou[np.isnan(class_iou)] = 0\n        return np.mean(class_iou)\n\n    @property\n    def class_seg_acc(self):\n        return [self.confusion_matrix[i, i] / np.sum(self.confusion_matrix[i])\n                for i in range(self.num_classes)]\n\n    @property\n    def class_iou(self):\n        iou_list = []\n        for i in range(self.num_classes):\n            tp = self.confusion_matrix[i, i]\n            p = self.confusion_matrix[:, i].sum()\n            g = self.confusion_matrix[i, :].sum()\n            union = p + g - tp\n            if union == 0:\n                iou = float('nan')\n            else:\n                iou = tp / union\n            iou_list.append(iou)\n        return iou_list\n\n    def print_table(self):\n        from tabulate import tabulate\n        header = ['Class', 'Accuracy', 'IOU', 'Total']\n        seg_acc_per_class = self.class_seg_acc\n        iou_per_class = self.class_iou\n        table = []\n        for ind, class_name in enumerate(self.class_names):\n            table.append([class_name,\n                          seg_acc_per_class[ind] * 100,\n                          iou_per_class[ind] * 100,\n                          int(self.confusion_matrix[ind].sum()),\n                          ])\n        return tabulate(table, headers=header, tablefmt='psql', floatfmt='.2f')\n\n    def save_table(self, filename):\n        from tabulate import tabulate\n        header = ('overall acc', 'overall iou') + self.class_names\n        table = [[self.overall_acc, self.overall_iou] + self.class_iou]\n        with open(filename, 'w') as f:\n            # In order to unify format, remove all the alignments.\n            f.write(tabulate(table, headers=header, tablefmt='tsv', floatfmt='.5f',\n                             numalign=None, stralign=None))\n"
  },
  {
    "path": "xmuda/data/utils/refine_pseudo_labels.py",
    "content": "import torch\n\n\ndef refine_pseudo_labels(probs, pseudo_label, ignore_label=-100):\n    \"\"\"\n    Reference: https://github.com/liyunsheng13/BDL/blob/master/SSL.py\n    Per class, set the less confident half of labels to ignore label.\n    :param probs: maximum probabilities (N,), where N is the number of 3D points\n    :param pseudo_label: predicted label which had maximum probability (N,)\n    :param ignore_label:\n    :return:\n    \"\"\"\n    probs, pseudo_label = torch.tensor(probs), torch.tensor(pseudo_label)\n    for cls_idx in pseudo_label.unique():\n        curr_idx = pseudo_label == cls_idx\n        curr_idx = curr_idx.nonzero().squeeze(1)\n        thresh = probs[curr_idx].median()\n        thresh = min(thresh, 0.9)\n        ignore_idx = curr_idx[probs[curr_idx] < thresh]\n        pseudo_label[ignore_idx] = ignore_label\n    return pseudo_label.numpy()\n"
  },
  {
    "path": "xmuda/data/utils/turbo_cmap.py",
    "content": "# Reference: https://gist.github.com/mikhailov-work/ee72ba4191942acecc03fe6da94fc73f\n\n# Copyright 2019 Google LLC.\n# SPDX-License-Identifier: Apache-2.0\n\n# Author: Anton Mikhailov\n\nturbo_colormap_data = [[0.18995,0.07176,0.23217],[0.19483,0.08339,0.26149],[0.19956,0.09498,0.29024],[0.20415,0.10652,0.31844],[0.20860,0.11802,0.34607],[0.21291,0.12947,0.37314],[0.21708,0.14087,0.39964],[0.22111,0.15223,0.42558],[0.22500,0.16354,0.45096],[0.22875,0.17481,0.47578],[0.23236,0.18603,0.50004],[0.23582,0.19720,0.52373],[0.23915,0.20833,0.54686],[0.24234,0.21941,0.56942],[0.24539,0.23044,0.59142],[0.24830,0.24143,0.61286],[0.25107,0.25237,0.63374],[0.25369,0.26327,0.65406],[0.25618,0.27412,0.67381],[0.25853,0.28492,0.69300],[0.26074,0.29568,0.71162],[0.26280,0.30639,0.72968],[0.26473,0.31706,0.74718],[0.26652,0.32768,0.76412],[0.26816,0.33825,0.78050],[0.26967,0.34878,0.79631],[0.27103,0.35926,0.81156],[0.27226,0.36970,0.82624],[0.27334,0.38008,0.84037],[0.27429,0.39043,0.85393],[0.27509,0.40072,0.86692],[0.27576,0.41097,0.87936],[0.27628,0.42118,0.89123],[0.27667,0.43134,0.90254],[0.27691,0.44145,0.91328],[0.27701,0.45152,0.92347],[0.27698,0.46153,0.93309],[0.27680,0.47151,0.94214],[0.27648,0.48144,0.95064],[0.27603,0.49132,0.95857],[0.27543,0.50115,0.96594],[0.27469,0.51094,0.97275],[0.27381,0.52069,0.97899],[0.27273,0.53040,0.98461],[0.27106,0.54015,0.98930],[0.26878,0.54995,0.99303],[0.26592,0.55979,0.99583],[0.26252,0.56967,0.99773],[0.25862,0.57958,0.99876],[0.25425,0.58950,0.99896],[0.24946,0.59943,0.99835],[0.24427,0.60937,0.99697],[0.23874,0.61931,0.99485],[0.23288,0.62923,0.99202],[0.22676,0.63913,0.98851],[0.22039,0.64901,0.98436],[0.21382,0.65886,0.97959],[0.20708,0.66866,0.97423],[0.20021,0.67842,0.96833],[0.19326,0.68812,0.96190],[0.18625,0.69775,0.95498],[0.17923,0.70732,0.94761],[0.17223,0.71680,0.93981],[0.16529,0.72620,0.93161],[0.15844,0.73551,0.92305],[0.15173,0.74472,0.91416],[0.14519,0.75381,0.90496],[0.13886,0.76279,0.89550],[0.13278,0.77165,0.88580],[0.12698,0.78037,0.87590],[0.12151,0.78896,0.86581],[0.11639,0.79740,0.85559],[0.11167,0.80569,0.84525],[0.10738,0.81381,0.83484],[0.10357,0.82177,0.82437],[0.10026,0.82955,0.81389],[0.09750,0.83714,0.80342],[0.09532,0.84455,0.79299],[0.09377,0.85175,0.78264],[0.09287,0.85875,0.77240],[0.09267,0.86554,0.76230],[0.09320,0.87211,0.75237],[0.09451,0.87844,0.74265],[0.09662,0.88454,0.73316],[0.09958,0.89040,0.72393],[0.10342,0.89600,0.71500],[0.10815,0.90142,0.70599],[0.11374,0.90673,0.69651],[0.12014,0.91193,0.68660],[0.12733,0.91701,0.67627],[0.13526,0.92197,0.66556],[0.14391,0.92680,0.65448],[0.15323,0.93151,0.64308],[0.16319,0.93609,0.63137],[0.17377,0.94053,0.61938],[0.18491,0.94484,0.60713],[0.19659,0.94901,0.59466],[0.20877,0.95304,0.58199],[0.22142,0.95692,0.56914],[0.23449,0.96065,0.55614],[0.24797,0.96423,0.54303],[0.26180,0.96765,0.52981],[0.27597,0.97092,0.51653],[0.29042,0.97403,0.50321],[0.30513,0.97697,0.48987],[0.32006,0.97974,0.47654],[0.33517,0.98234,0.46325],[0.35043,0.98477,0.45002],[0.36581,0.98702,0.43688],[0.38127,0.98909,0.42386],[0.39678,0.99098,0.41098],[0.41229,0.99268,0.39826],[0.42778,0.99419,0.38575],[0.44321,0.99551,0.37345],[0.45854,0.99663,0.36140],[0.47375,0.99755,0.34963],[0.48879,0.99828,0.33816],[0.50362,0.99879,0.32701],[0.51822,0.99910,0.31622],[0.53255,0.99919,0.30581],[0.54658,0.99907,0.29581],[0.56026,0.99873,0.28623],[0.57357,0.99817,0.27712],[0.58646,0.99739,0.26849],[0.59891,0.99638,0.26038],[0.61088,0.99514,0.25280],[0.62233,0.99366,0.24579],[0.63323,0.99195,0.23937],
[0.64362,0.98999,0.23356],[0.65394,0.98775,0.22835],[0.66428,0.98524,0.22370],[0.67462,0.98246,0.21960],[0.68494,0.97941,0.21602],[0.69525,0.97610,0.21294],[0.70553,0.97255,0.21032],[0.71577,0.96875,0.20815],[0.72596,0.96470,0.20640],[0.73610,0.96043,0.20504],[0.74617,0.95593,0.20406],[0.75617,0.95121,0.20343],[0.76608,0.94627,0.20311],[0.77591,0.94113,0.20310],[0.78563,0.93579,0.20336],[0.79524,0.93025,0.20386],[0.80473,0.92452,0.20459],[0.81410,0.91861,0.20552],[0.82333,0.91253,0.20663],[0.83241,0.90627,0.20788],[0.84133,0.89986,0.20926],[0.85010,0.89328,0.21074],[0.85868,0.88655,0.21230],[0.86709,0.87968,0.21391],[0.87530,0.87267,0.21555],[0.88331,0.86553,0.21719],[0.89112,0.85826,0.21880],[0.89870,0.85087,0.22038],[0.90605,0.84337,0.22188],[0.91317,0.83576,0.22328],[0.92004,0.82806,0.22456],[0.92666,0.82025,0.22570],[0.93301,0.81236,0.22667],[0.93909,0.80439,0.22744],[0.94489,0.79634,0.22800],[0.95039,0.78823,0.22831],[0.95560,0.78005,0.22836],[0.96049,0.77181,0.22811],[0.96507,0.76352,0.22754],[0.96931,0.75519,0.22663],[0.97323,0.74682,0.22536],[0.97679,0.73842,0.22369],[0.98000,0.73000,0.22161],[0.98289,0.72140,0.21918],[0.98549,0.71250,0.21650],[0.98781,0.70330,0.21358],[0.98986,0.69382,0.21043],[0.99163,0.68408,0.20706],[0.99314,0.67408,0.20348],[0.99438,0.66386,0.19971],[0.99535,0.65341,0.19577],[0.99607,0.64277,0.19165],[0.99654,0.63193,0.18738],[0.99675,0.62093,0.18297],[0.99672,0.60977,0.17842],[0.99644,0.59846,0.17376],[0.99593,0.58703,0.16899],[0.99517,0.57549,0.16412],[0.99419,0.56386,0.15918],[0.99297,0.55214,0.15417],[0.99153,0.54036,0.14910],[0.98987,0.52854,0.14398],[0.98799,0.51667,0.13883],[0.98590,0.50479,0.13367],[0.98360,0.49291,0.12849],[0.98108,0.48104,0.12332],[0.97837,0.46920,0.11817],[0.97545,0.45740,0.11305],[0.97234,0.44565,0.10797],[0.96904,0.43399,0.10294],[0.96555,0.42241,0.09798],[0.96187,0.41093,0.09310],[0.95801,0.39958,0.08831],[0.95398,0.38836,0.08362],[0.94977,0.37729,0.07905],[0.94538,0.36638,0.07461],[0.94084,0.35566,0.07031],[0.93612,0.34513,0.06616],[0.93125,0.33482,0.06218],[0.92623,0.32473,0.05837],[0.92105,0.31489,0.05475],[0.91572,0.30530,0.05134],[0.91024,0.29599,0.04814],[0.90463,0.28696,0.04516],[0.89888,0.27824,0.04243],[0.89298,0.26981,0.03993],[0.88691,0.26152,0.03753],[0.88066,0.25334,0.03521],[0.87422,0.24526,0.03297],[0.86760,0.23730,0.03082],[0.86079,0.22945,0.02875],[0.85380,0.22170,0.02677],[0.84662,0.21407,0.02487],[0.83926,0.20654,0.02305],[0.83172,0.19912,0.02131],[0.82399,0.19182,0.01966],[0.81608,0.18462,0.01809],[0.80799,0.17753,0.01660],[0.79971,0.17055,0.01520],[0.79125,0.16368,0.01387],[0.78260,0.15693,0.01264],[0.77377,0.15028,0.01148],[0.76476,0.14374,0.01041],[0.75556,0.13731,0.00942],[0.74617,0.13098,0.00851],[0.73661,0.12477,0.00769],[0.72686,0.11867,0.00695],[0.71692,0.11268,0.00629],[0.70680,0.10680,0.00571],[0.69650,0.10102,0.00522],[0.68602,0.09536,0.00481],[0.67535,0.08980,0.00449],[0.66449,0.08436,0.00424],[0.65345,0.07902,0.00408],[0.64223,0.07380,0.00401],[0.63082,0.06868,0.00401],[0.61923,0.06367,0.00410],[0.60746,0.05878,0.00427],[0.59550,0.05399,0.00453],[0.58336,0.04931,0.00486],[0.57103,0.04474,0.00529],[0.55852,0.04028,0.00579],[0.54583,0.03593,0.00638],[0.53295,0.03169,0.00705],[0.51989,0.02756,0.00780],[0.50664,0.02354,0.00863],[0.49321,0.01963,0.00955],[0.47960,0.01583,0.01055]]\n\n# The look-up table contains 256 entries. 
Each entry is a floating-point sRGB triplet.\n# To use it with matplotlib, pass cmap=ListedColormap(turbo_colormap_data) as an arg to imshow() (don't forget \"from matplotlib.colors import ListedColormap\").\n# If you have a typical 8-bit greyscale image, you can use the 8-bit value to index into this LUT directly.\n# The floating-point color values can be converted to 8-bit sRGB via multiplying by 255 and casting/flooring to an integer. Saturation should not be required for IEEE-754 compliant arithmetic.\n# If you have a floating-point value in the range [0,1], you can use interpolate() to linearly interpolate between the entries.\n# If you have 16-bit or 32-bit integer values, convert them to floating-point values on the [0,1] range and then use interpolate(). Doing the interpolation in floating point will reduce banding.\n# If some of your values may lie outside the [0,1] range, use interpolate_or_clip() to highlight them.\n\ndef interpolate(colormap, x):\n  x = max(0.0, min(1.0, x))\n  a = int(x*255.0)\n  b = min(255, a + 1)\n  f = x*255.0 - a\n  return [colormap[a][0] + (colormap[b][0] - colormap[a][0]) * f,\n          colormap[a][1] + (colormap[b][1] - colormap[a][1]) * f,\n          colormap[a][2] + (colormap[b][2] - colormap[a][2]) * f]\n\ndef interpolate_or_clip(colormap, x):\n  if   x < 0.0: return [0.0, 0.0, 0.0]\n  elif x > 1.0: return [1.0, 1.0, 1.0]\n  else: return interpolate(colormap, x)\n
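\n# Editor-added minimal sketch (not part of the original gist): map a normalized scalar,\n# e.g. a depth value scaled to [0,1], to a turbo RGB triplet.\nif __name__ == '__main__':\n  print(interpolate_or_clip(turbo_colormap_data, 0.5))   # mid-range value\n  print(interpolate_or_clip(turbo_colormap_data, 1.5))   # above range -> clipped to white\n"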
  },
  {
    "path": "xmuda/data/utils/validate.py",
    "content": "import numpy as np\nimport logging\nimport time\n\nimport torch\nimport torch.nn.functional as F\n\nfrom xmuda.data.utils.evaluate import Evaluator\n\n\ndef validate(cfg,\n             model_2d,\n             model_3d,\n             dataloader,\n             val_metric_logger,\n             pselab_path=None):\n    logger = logging.getLogger('xmuda.validate')\n    logger.info('Validation')\n\n    # evaluator\n    class_names = dataloader.dataset.class_names\n    evaluator_2d = Evaluator(class_names)\n    evaluator_3d = Evaluator(class_names) if model_3d else None\n    evaluator_ensemble = Evaluator(class_names) if model_3d else None\n\n    pselab_data_list = []\n\n    end = time.time()\n    with torch.no_grad():\n        for iteration, data_batch in enumerate(dataloader):\n            data_time = time.time() - end\n            # copy data from cpu to gpu\n            if 'SCN' in cfg.DATASET_TARGET.TYPE:\n                data_batch['x'][1] = data_batch['x'][1].cuda()\n                data_batch['seg_label'] = data_batch['seg_label'].cuda()\n                data_batch['img'] = data_batch['img'].cuda()\n            else:\n                raise NotImplementedError\n\n            # predict\n            preds_2d = model_2d(data_batch)\n            preds_3d = model_3d(data_batch) if model_3d else None\n\n            pred_label_voxel_2d = preds_2d['seg_logit'].argmax(1).cpu().numpy()\n            pred_label_voxel_3d = preds_3d['seg_logit'].argmax(1).cpu().numpy() if model_3d else None\n\n            # softmax average (ensembling)\n            probs_2d = F.softmax(preds_2d['seg_logit'], dim=1)\n            probs_3d = F.softmax(preds_3d['seg_logit'], dim=1) if model_3d else None\n            pred_label_voxel_ensemble = (probs_2d + probs_3d).argmax(1).cpu().numpy() if model_3d else None\n\n            # get original point cloud from before voxelization\n            seg_label = data_batch['orig_seg_label']\n            points_idx = data_batch['orig_points_idx']\n            # loop over batch\n            left_idx = 0\n            for batch_ind in range(len(seg_label)):\n                curr_points_idx = points_idx[batch_ind]\n                # check if all points have predictions (= all voxels inside receptive field)\n                assert np.all(curr_points_idx)\n\n                curr_seg_label = seg_label[batch_ind]\n                right_idx = left_idx + curr_points_idx.sum()\n                pred_label_2d = pred_label_voxel_2d[left_idx:right_idx]\n                pred_label_3d = pred_label_voxel_3d[left_idx:right_idx] if model_3d else None\n                pred_label_ensemble = pred_label_voxel_ensemble[left_idx:right_idx] if model_3d else None\n\n                # evaluate\n                evaluator_2d.update(pred_label_2d, curr_seg_label)\n                if model_3d:\n                    evaluator_3d.update(pred_label_3d, curr_seg_label)\n                    evaluator_ensemble.update(pred_label_ensemble, curr_seg_label)\n\n                if pselab_path is not None:\n                    assert np.all(pred_label_2d >= 0)\n                    curr_probs_2d = probs_2d[left_idx:right_idx]\n                    curr_probs_3d = probs_3d[left_idx:right_idx] if model_3d else None\n                    pselab_data_list.append({\n                        'probs_2d': curr_probs_2d[range(len(pred_label_2d)), pred_label_2d].cpu().numpy(),\n                        'pseudo_label_2d': pred_label_2d.astype(np.uint8),\n                        'probs_3d': curr_probs_3d[range(len(pred_label_3d)), 
pred_label_3d].cpu().numpy() if model_3d else None,\n                        'pseudo_label_3d': pred_label_3d.astype(np.uint8) if model_3d else None\n                    })\n\n                left_idx = right_idx\n\n            seg_loss_2d = F.cross_entropy(preds_2d['seg_logit'], data_batch['seg_label'])\n            seg_loss_3d = F.cross_entropy(preds_3d['seg_logit'], data_batch['seg_label']) if model_3d else None\n            val_metric_logger.update(seg_loss_2d=seg_loss_2d)\n            if seg_loss_3d is not None:\n                val_metric_logger.update(seg_loss_3d=seg_loss_3d)\n\n            batch_time = time.time() - end\n            val_metric_logger.update(time=batch_time, data=data_time)\n            end = time.time()\n\n            # log\n            cur_iter = iteration + 1\n            if cur_iter == 1 or (cfg.VAL.LOG_PERIOD > 0 and cur_iter % cfg.VAL.LOG_PERIOD == 0):\n                logger.info(\n                    val_metric_logger.delimiter.join(\n                        [\n                            'iter: {iter}/{total_iter}',\n                            '{meters}',\n                            'max mem: {memory:.0f}',\n                        ]\n                    ).format(\n                        iter=cur_iter,\n                        total_iter=len(dataloader),\n                        meters=str(val_metric_logger),\n                        memory=torch.cuda.max_memory_allocated() / (1024.0 ** 2),\n                    )\n                )\n\n        val_metric_logger.update(seg_iou_2d=evaluator_2d.overall_iou)\n        if evaluator_3d is not None:\n            val_metric_logger.update(seg_iou_3d=evaluator_3d.overall_iou)\n        eval_list = [('2D', evaluator_2d)]\n        if model_3d:\n            eval_list.extend([('3D', evaluator_3d), ('2D+3D', evaluator_ensemble)])\n        for modality, evaluator in eval_list:\n            logger.info('{} overall accuracy={:.2f}%'.format(modality, 100.0 * evaluator.overall_acc))\n            logger.info('{} overall IoU={:.2f}%'.format(modality, 100.0 * evaluator.overall_iou))\n            logger.info('{} class-wise segmentation accuracy and IoU.\\n{}'.format(modality, evaluator.print_table()))\n\n        if pselab_path is not None:\n            np.save(pselab_path, pselab_data_list)\n            logger.info('Saved pseudo label data to {}'.format(pselab_path))\n
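\n\n# Editor-added note (not part of the original file): the .npy file saved above is what the\n# dataloaders consume through their pselab_paths argument. A minimal, hypothetical loading sketch:\n#\n#     pselab_data = np.load('/path/to/pselab_data/train.npy', allow_pickle=True)\n#     print(pselab_data[0].keys())  # probs_2d, pseudo_label_2d, probs_3d, pseudo_label_3d\n"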
  },
  {
    "path": "xmuda/data/utils/visualize.py",
    "content": "import matplotlib.pyplot as plt\nimport numpy as np\nfrom xmuda.data.utils.turbo_cmap import interpolate_or_clip, turbo_colormap_data\n\n\n# all classes\nNUSCENES_COLOR_PALETTE = [\n    (255, 158, 0),  # car\n    (255, 158, 0),  # truck\n    (255, 158, 0),  # bus\n    (255, 158, 0),  # trailer\n    (255, 158, 0),  # construction_vehicle\n    (0, 0, 230),  # pedestrian\n    (255, 61, 99),  # motorcycle\n    (255, 61, 99),  # bicycle\n    (0, 0, 0),  # traffic_cone\n    (0, 0, 0),  # barrier\n    (200, 200, 200),  # background\n]\n\n# classes after merging (as used in xMUDA)\nNUSCENES_COLOR_PALETTE_SHORT = [\n    (255, 158, 0),  # vehicle\n    (0, 0, 230),  # pedestrian\n    (255, 61, 99),  # bike\n    (0, 0, 0),  # traffic boundary\n    (200, 200, 200),  # background\n]\n\n# all classes\nA2D2_COLOR_PALETTE_SHORT = [\n    (255, 0, 0),  # car\n    (255, 128, 0),  # truck\n    (182, 89, 6),  # bike\n    (204, 153, 255),  # person\n    (255, 0, 255),  # road\n    (150, 150, 200),  # parking\n    (180, 150, 200),  # sidewalk\n    (241, 230, 255),  # building\n    (147, 253, 194),  # nature\n    (255, 246, 143),  # other-objects\n    (0, 0, 0)  # ignore\n]\n\n# colors as defined in https://github.com/PRBonn/semantic-kitti-api/blob/master/config/semantic-kitti.yaml\nSEMANTIC_KITTI_ID_TO_BGR = {  # bgr\n  0: [0, 0, 0],\n  1: [0, 0, 255],\n  10: [245, 150, 100],\n  11: [245, 230, 100],\n  13: [250, 80, 100],\n  15: [150, 60, 30],\n  16: [255, 0, 0],\n  18: [180, 30, 80],\n  20: [255, 0, 0],\n  30: [30, 30, 255],\n  31: [200, 40, 255],\n  32: [90, 30, 150],\n  40: [255, 0, 255],\n  44: [255, 150, 255],\n  48: [75, 0, 75],\n  49: [75, 0, 175],\n  50: [0, 200, 255],\n  51: [50, 120, 255],\n  52: [0, 150, 255],\n  60: [170, 255, 150],\n  70: [0, 175, 0],\n  71: [0, 60, 135],\n  72: [80, 240, 150],\n  80: [150, 240, 255],\n  81: [0, 0, 255],\n  99: [255, 255, 50],\n  252: [245, 150, 100],\n  256: [255, 0, 0],\n  253: [200, 40, 255],\n  254: [30, 30, 255],\n  255: [90, 30, 150],\n  257: [250, 80, 100],\n  258: [180, 30, 80],\n  259: [255, 0, 0],\n}\nSEMANTIC_KITTI_COLOR_PALETTE = [SEMANTIC_KITTI_ID_TO_BGR[id] if id in SEMANTIC_KITTI_ID_TO_BGR.keys() else [0, 0, 0]\n                                for id in range(list(SEMANTIC_KITTI_ID_TO_BGR.keys())[-1] + 1)]\n\n\n# classes after merging (as used in xMUDA)\nSEMANTIC_KITTI_COLOR_PALETTE_SHORT_BGR = [\n    [245, 150, 100],  # car\n    [180, 30, 80],  # truck\n    [150, 60, 30],  # bike\n    [30, 30, 255],  # person\n    [255, 0, 255],  # road\n    [255, 150, 255],  # parking\n    [75, 0, 75],  # sidewalk\n    [0, 200, 255],  # building\n    [0, 175, 0],  # nature\n    [255, 255, 50],  # other-objects\n    [0, 0, 0],  # ignore\n]\nSEMANTIC_KITTI_COLOR_PALETTE_SHORT = [(c[2], c[1], c[0]) for c in SEMANTIC_KITTI_COLOR_PALETTE_SHORT_BGR]\n\n\ndef draw_points_image_labels(img, img_indices, seg_labels, show=True, color_palette_type='NuScenes', point_size=0.5):\n    if color_palette_type == 'NuScenes':\n        color_palette = NUSCENES_COLOR_PALETTE_SHORT\n    elif color_palette_type == 'A2D2':\n        color_palette = A2D2_COLOR_PALETTE_SHORT\n    elif color_palette_type == 'SemanticKITTI':\n        color_palette = SEMANTIC_KITTI_COLOR_PALETTE_SHORT\n    elif color_palette_type == 'SemanticKITTI_long':\n        color_palette = SEMANTIC_KITTI_COLOR_PALETTE\n    else:\n        raise NotImplementedError('Color palette type not supported')\n    color_palette = np.array(color_palette) / 255.\n    seg_labels[seg_labels == -100] = len(color_palette) - 
1\n    colors = color_palette[seg_labels]\n\n    plt.imshow(img)\n    plt.scatter(img_indices[:, 1], img_indices[:, 0], c=colors, alpha=0.5, s=point_size)\n\n    plt.axis('off')\n\n    if show:\n        plt.show()\n\n\ndef normalize_depth(depth, d_min, d_max):\n    # normalize linearly between d_min and d_max\n    data = np.clip(depth, d_min, d_max)\n    return (data - d_min) / (d_max - d_min)\n\n\ndef draw_points_image_depth(img, img_indices, depth, show=True, point_size=0.5):\n    # depth = normalize_depth(depth, d_min=3., d_max=50.)\n    depth = normalize_depth(depth, d_min=depth.min(), d_max=depth.max())\n    colors = []\n    for depth_val in depth:\n        colors.append(interpolate_or_clip(colormap=turbo_colormap_data, x=depth_val))\n    # ax5.imshow(np.full_like(img, 255))\n    plt.imshow(img)\n    plt.scatter(img_indices[:, 1], img_indices[:, 0], c=colors, alpha=0.5, s=point_size)\n\n    plt.axis('off')\n\n    if show:\n        plt.show()\n\n\ndef draw_bird_eye_view(coords, full_scale=4096):\n    plt.scatter(coords[:, 0], coords[:, 1], s=0.1)\n    plt.xlim([0, full_scale])\n    plt.ylim([0, full_scale])\n    plt.gca().set_aspect('equal', adjustable='box')\n    plt.show()\n"
  },
  {
    "path": "xmuda/models/build.py",
    "content": "from xmuda.models.xmuda_arch import Net2DSeg, Net3DSeg\nfrom xmuda.models.metric import SegIoU\n\n\ndef build_model_2d(cfg):\n    model = Net2DSeg(num_classes=cfg.MODEL_2D.NUM_CLASSES,\n                     backbone_2d=cfg.MODEL_2D.TYPE,\n                     backbone_2d_kwargs=cfg.MODEL_2D[cfg.MODEL_2D.TYPE],\n                     dual_head=cfg.MODEL_2D.DUAL_HEAD\n                     )\n    train_metric = SegIoU(cfg.MODEL_2D.NUM_CLASSES, name='seg_iou_2d')\n    return model, train_metric\n\n\ndef build_model_3d(cfg):\n    model = Net3DSeg(num_classes=cfg.MODEL_3D.NUM_CLASSES,\n                     backbone_3d=cfg.MODEL_3D.TYPE,\n                     backbone_3d_kwargs=cfg.MODEL_3D[cfg.MODEL_3D.TYPE],\n                     dual_head=cfg.MODEL_3D.DUAL_HEAD\n                     )\n    train_metric = SegIoU(cfg.MODEL_3D.NUM_CLASSES, name='seg_iou_3d')\n    return model, train_metric\n"
  },
  {
    "path": "xmuda/models/losses.py",
    "content": "import numpy as np\nimport torch\nimport logging\n\n\ndef entropy_loss(v):\n    \"\"\"\n        Entropy loss for probabilistic prediction vectors\n        input: batch_size x classes x points\n        output: batch_size x 1 x points\n    \"\"\"\n    # (num points, num classes)\n    if v.dim() == 2:\n        v = v.transpose(0, 1)\n        v = v.unsqueeze(0)\n    # (1, num_classes, num_points)\n    assert v.dim() == 3\n    n, c, p = v.size()\n    return -torch.sum(torch.mul(v, torch.log2(v + 1e-30))) / (n * p * np.log2(c))\n\n\ndef logcoral_loss(x_src, x_trg):\n    \"\"\"\n    Geodesic loss (log coral loss), reference:\n    https://github.com/pmorerio/minimal-entropy-correlation-alignment/blob/master/svhn2mnist/model.py\n    :param x_src: source features of size (N, ..., F), where N is the batch size and F is the feature size\n    :param x_trg: target features of size (N, ..., F), where N is the batch size and F is the feature size\n    :return: geodesic distance between the x_src and x_trg\n    \"\"\"\n    # check if the feature size is the same, so that the covariance matrices will have the same dimensions\n    assert x_src.shape[-1] == x_trg.shape[-1]\n    assert x_src.dim() >= 2\n    batch_size = x_src.shape[0]\n    if x_src.dim() > 2:\n        # reshape from (N1, N2, ..., NM, F) to (N1 * N2 * ... * NM, F)\n        x_src = x_src.flatten(end_dim=-2)\n        x_trg = x_trg.flatten(end_dim=-2)\n\n    # subtract the mean over the batch\n    x_src = x_src - torch.mean(x_src, 0)\n    x_trg = x_trg - torch.mean(x_trg, 0)\n\n    # compute covariance\n    factor = 1. / (batch_size - 1)\n\n    cov_src = factor * torch.mm(x_src.t(), x_src)\n    cov_trg = factor * torch.mm(x_trg.t(), x_trg)\n\n    # dirty workaround to prevent GPU memory error due to MAGMA (used in SVD)\n    # this implementation achieves loss of zero without creating a fork in the computation graph\n    # if there is a nan or big number in the cov matrix, use where (not if!) to set cov matrix to identity matrix\n    condition = (cov_src > 1e30).any() or (cov_trg > 1e30).any() or torch.isnan(cov_src).any() or torch.isnan(cov_trg).any()\n    cov_src = torch.where(torch.full_like(cov_src, condition, dtype=torch.uint8), torch.eye(cov_src.shape[0], device=cov_src.device), cov_src)\n    cov_trg = torch.where(torch.full_like(cov_trg, condition, dtype=torch.uint8), torch.eye(cov_trg.shape[0], device=cov_trg.device), cov_trg)\n\n    if condition:\n        logger = logging.getLogger('xmuda.train')\n        logger.info('Big number > 1e30 or nan in covariance matrix, return loss of 0 to prevent error in SVD decomposition.')\n\n    _, e_src, v_src = cov_src.svd()\n    _, e_trg, v_trg = cov_trg.svd()\n\n    # nan can occur when taking log of a value near 0 (problem occurs if the cov matrix is of low rank)\n    log_cov_src = torch.mm(v_src, torch.mm(torch.diag(torch.log(e_src)), v_src.t()))\n    log_cov_trg = torch.mm(v_trg, torch.mm(torch.diag(torch.log(e_trg)), v_trg.t()))\n\n    # Frobenius norm\n    return torch.mean((log_cov_src - log_cov_trg) ** 2)\n"
  },
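  {
    "path": "examples/losses_demo.py",
    "content": "\"\"\"Illustrative usage sketch for xmuda/models/losses.py.\n\nNOTE: this file is an editorial addition for demonstration purposes and is\nnot part of the original xMUDA release.\n\"\"\"\nimport torch\nimport torch.nn.functional as F\n\nfrom xmuda.models.losses import entropy_loss, logcoral_loss\n\n\ndef demo():\n    torch.manual_seed(0)\n\n    # entropy_loss expects class probabilities (e.g. softmax outputs) of shape\n    # (num_points, num_classes) or (batch_size, num_classes, num_points)\n    # and returns a scalar normalized to [0, 1].\n    logits = torch.randn(100, 5)  # 100 points, 5 classes\n    probs = F.softmax(logits, dim=1)\n    print('entropy:', entropy_loss(probs).item())\n\n    # logcoral_loss aligns second-order statistics of two feature batches;\n    # the trailing feature dimensions must match.\n    feats_src = torch.randn(64, 16)\n    feats_trg = torch.randn(64, 16) + 0.5\n    print('log-coral:', logcoral_loss(feats_src, feats_trg).item())\n\n\nif __name__ == '__main__':\n    demo()\n"
  },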
  {
    "path": "xmuda/models/metric.py",
    "content": "import torch\nfrom xmuda.common.utils.metric_logger import AverageMeter\n\n\nclass SegAccuracy(AverageMeter):\n    \"\"\"Segmentation accuracy\"\"\"\n    name = 'seg_acc'\n    \n    def __init__(self, ignore_index=-100):\n        super(SegAccuracy, self).__init__()\n        self.ignore_index = ignore_index\n\n    def update_dict(self, preds, labels):\n        seg_logit = preds['seg_logit']  # (b, c, n)\n        seg_label = labels['seg_label']  # (b, n)\n        pred_label = seg_logit.argmax(1)\n\n        mask = (seg_label != self.ignore_index)\n        seg_label = seg_label[mask]\n        pred_label = pred_label[mask]\n\n        tp_mask = pred_label.eq(seg_label)  # (b, n)\n        self.update(tp_mask.sum().item(), tp_mask.numel())\n\n\nclass SegIoU(object):\n    \"\"\"Segmentation IoU\n    References: https://github.com/pytorch/vision/blob/master/references/segmentation/utils.py\n    \"\"\"\n\n    def __init__(self, num_classes, ignore_index=-100, name='seg_iou'):\n        self.num_classes = num_classes\n        self.ignore_index = ignore_index\n        self.mat = None\n        self.name = name\n\n    def update_dict(self, preds, labels):\n        seg_logit = preds['seg_logit']  # (batch_size, num_classes, num_points)\n        seg_label = labels['seg_label']  # (batch_size, num_points)\n        pred_label = seg_logit.argmax(1)\n\n        mask = (seg_label != self.ignore_index)\n        seg_label = seg_label[mask]\n        pred_label = pred_label[mask]\n\n        # Update confusion matrix\n        # TODO: Compare the speed between torch.histogram and torch.bincount after pytorch v1.1.0\n        n = self.num_classes\n        with torch.no_grad():\n            if self.mat is None:\n                self.mat = seg_label.new_zeros((n, n))\n            inds = n * seg_label + pred_label\n            self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)\n\n    def reset(self):\n        self.mat = None\n\n    @property\n    def iou(self):\n        h = self.mat.float()\n        iou = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))\n        return iou\n\n    @property\n    def global_avg(self):\n        return self.iou.mean().item()\n\n    @property\n    def avg(self):\n        return self.global_avg\n\n    def __str__(self):\n        return '{iou:.4f}'.format(iou=self.iou.mean().item())\n\n    @property\n    def summary_str(self):\n        return str(self)\n"
  },
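  {
    "path": "examples/metric_demo.py",
    "content": "\"\"\"Illustrative usage sketch for the SegIoU metric in xmuda/models/metric.py.\n\nNOTE: this file is an editorial addition for demonstration purposes and is\nnot part of the original xMUDA release.\n\"\"\"\nimport torch\n\nfrom xmuda.models.metric import SegIoU\n\n\ndef demo():\n    torch.manual_seed(0)\n    num_classes = 3\n    metric = SegIoU(num_classes, name='seg_iou_demo')\n\n    # update_dict consumes dicts holding 'seg_logit' of shape\n    # (batch_size, num_classes, num_points) and 'seg_label' of shape\n    # (batch_size, num_points); the confusion matrix accumulates over calls.\n    preds = {'seg_logit': torch.randn(2, num_classes, 50)}\n    labels = {'seg_label': torch.randint(num_classes, (2, 50))}\n    metric.update_dict(preds, labels)\n\n    print('per-class IoU:', metric.iou)\n    print('mIoU:', metric.global_avg)\n\n\nif __name__ == '__main__':\n    demo()\n"
  },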
  {
    "path": "xmuda/models/resnet34_unet.py",
    "content": "\"\"\"UNet based on ResNet34\"\"\"\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torchvision.models.resnet import resnet34\n\n\nclass UNetResNet34(nn.Module):\n    def __init__(self, pretrained=True):\n        super(UNetResNet34, self).__init__()\n\n        # ----------------------------------------------------------------------------- #\n        # Encoder\n        # ----------------------------------------------------------------------------- #\n        net = resnet34(pretrained)\n        # Note that we do not downsample for conv1\n        # self.conv1 = net.conv1\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3, bias=False)\n        self.conv1.weight.data = net.conv1.weight.data\n        self.bn1 = net.bn1\n        self.relu = net.relu\n        self.maxpool = net.maxpool\n        self.layer1 = net.layer1\n        self.layer2 = net.layer2\n        self.layer3 = net.layer3\n        self.layer4 = net.layer4\n\n        # ----------------------------------------------------------------------------- #\n        # Decoder\n        # ----------------------------------------------------------------------------- #\n        _, self.dec_t_conv_stage5 = self.dec_stage(self.layer4, num_concat=1)\n        self.dec_conv_stage4, self.dec_t_conv_stage4 = self.dec_stage(self.layer3, num_concat=2)\n        self.dec_conv_stage3, self.dec_t_conv_stage3 = self.dec_stage(self.layer2, num_concat=2)\n        self.dec_conv_stage2, self.dec_t_conv_stage2 = self.dec_stage(self.layer1, num_concat=2)\n        self.dec_conv_stage1 = nn.Conv2d(2 * 64, 64, kernel_size=3, padding=1)\n\n        # dropout\n        self.dropout = nn.Dropout(p=0.4)\n\n    @staticmethod\n    def dec_stage(enc_stage, num_concat):\n        in_channels = enc_stage[0].conv1.in_channels\n        out_channels = enc_stage[-1].conv2.out_channels\n        conv = nn.Sequential(\n            nn.Conv2d(num_concat * out_channels, out_channels, kernel_size=3, padding=1),\n            nn.BatchNorm2d(out_channels),\n            nn.ReLU(inplace=True),\n        )\n        t_conv = nn.Sequential(\n            nn.ConvTranspose2d(out_channels, in_channels, kernel_size=2, stride=2),\n            nn.BatchNorm2d(in_channels),\n            nn.ReLU(inplace=True)\n        )\n        return conv, t_conv\n\n    def forward(self, x):\n        # pad input to be divisible by 16 = 2 ** 4\n        h, w = x.shape[2], x.shape[3]\n        min_size = 16\n        pad_h = int((h + min_size - 1) / min_size) * min_size - h\n        pad_w = int((w + min_size - 1) / min_size) * min_size - w\n        if pad_h > 0 or pad_w > 0:\n            x = F.pad(x, [0, pad_w, 0, pad_h])\n\n        # ----------------------------------------------------------------------------- #\n        # Encoder\n        # ----------------------------------------------------------------------------- #\n        inter_features = []\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        inter_features.append(x)\n        x = self.maxpool(x)  # downsample\n        x = self.layer1(x)\n        inter_features.append(x)\n        x = self.layer2(x)  # downsample\n        inter_features.append(x)\n        x = self.layer3(x)  # downsample\n        x = self.dropout(x)\n        inter_features.append(x)\n        x = self.layer4(x)  # downsample\n        x = self.dropout(x)\n\n        # ----------------------------------------------------------------------------- #\n        # Decoder\n        # 
----------------------------------------------------------------------------- #\n        # upsample\n        x = self.dec_t_conv_stage5(x)\n        x = torch.cat([inter_features[3], x], dim=1)\n        x = self.dec_conv_stage4(x)\n\n        # upsample\n        x = self.dec_t_conv_stage4(x)\n        x = torch.cat([inter_features[2], x], dim=1)\n        x = self.dec_conv_stage3(x)\n\n        # upsample\n        x = self.dec_t_conv_stage3(x)\n        x = torch.cat([inter_features[1], x], dim=1)\n        x = self.dec_conv_stage2(x)\n\n        # upsample\n        x = self.dec_t_conv_stage2(x)\n        x = torch.cat([inter_features[0], x], dim=1)\n        x = self.dec_conv_stage1(x)\n\n        # crop padding\n        if pad_h > 0 or pad_w > 0:\n            x = x[:, :, 0:h, 0:w]\n\n        return x\n\n\ndef test():\n    b, h, w = 2, 120, 160\n    image = torch.randn(b, 3, h, w).cuda()\n    net = UNetResNet34(pretrained=True)\n    net.cuda()\n    feats = net(image)\n    print('feats', feats.shape)  # expected: (b, 64, h, w)\n\n\nif __name__ == '__main__':\n    test()\n"
  },
  {
    "path": "xmuda/models/scn_unet.py",
    "content": "import torch\nimport torch.nn as nn\n\nimport sparseconvnet as scn\n\nDIMENSION = 3\n\n\nclass UNetSCN(nn.Module):\n    def __init__(self,\n                 in_channels,\n                 m=16,  # number of unet features (multiplied in each layer)\n                 block_reps=1,  # depth\n                 residual_blocks=False,  # ResNet style basic blocks\n                 full_scale=4096,\n                 num_planes=7\n                 ):\n        super(UNetSCN, self).__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = m\n        n_planes = [(n + 1) * m for n in range(num_planes)]\n\n        self.sparseModel = scn.Sequential().add(\n            scn.InputLayer(DIMENSION, full_scale, mode=4)).add(\n            scn.SubmanifoldConvolution(DIMENSION, in_channels, m, 3, False)).add(\n            scn.UNet(DIMENSION, block_reps, n_planes, residual_blocks)).add(\n            scn.BatchNormReLU(m)).add(\n            scn.OutputLayer(DIMENSION))\n\n    def forward(self, x):\n        x = self.sparseModel(x)\n        return x\n\n\ndef test():\n    b, n = 2, 100\n    coords = torch.randint(4096, [b, n, DIMENSION])\n    batch_idxs = torch.arange(b).reshape(b, 1, 1).repeat(1, n, 1)\n    coords = torch.cat([coords, batch_idxs], 2).reshape(-1, DIMENSION + 1)\n\n    in_channels = 3\n    feats = torch.rand(b * n, in_channels)\n\n    x = [coords, feats.cuda()]\n\n    net = UNetSCN(in_channels).cuda()\n    out_feats = net(x)\n\n    print('out_feats', out_feats.shape)\n\n\nif __name__ == '__main__':\n    test()\n"
  },
  {
    "path": "xmuda/models/xmuda_arch.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom xmuda.models.resnet34_unet import UNetResNet34\nfrom xmuda.models.scn_unet import UNetSCN\n\n\nclass Net2DSeg(nn.Module):\n    def __init__(self,\n                 num_classes,\n                 dual_head,\n                 backbone_2d,\n                 backbone_2d_kwargs\n                 ):\n        super(Net2DSeg, self).__init__()\n\n        # 2D image network\n        if backbone_2d == 'UNetResNet34':\n            self.net_2d = UNetResNet34(**backbone_2d_kwargs)\n            feat_channels = 64\n        else:\n            raise NotImplementedError('2D backbone {} not supported'.format(backbone_2d))\n\n        # segmentation head\n        self.linear = nn.Linear(feat_channels, num_classes)\n\n        # 2nd segmentation head\n        self.dual_head = dual_head\n        if dual_head:\n            self.linear2 = nn.Linear(feat_channels, num_classes)\n\n    def forward(self, data_batch):\n        # (batch_size, 3, H, W)\n        img = data_batch['img']\n        img_indices = data_batch['img_indices']\n\n        # 2D network\n        x = self.net_2d(img)\n\n        # 2D-3D feature lifting\n        img_feats = []\n        for i in range(x.shape[0]):\n            img_feats.append(x.permute(0, 2, 3, 1)[i][img_indices[i][:, 0], img_indices[i][:, 1]])\n        img_feats = torch.cat(img_feats, 0)\n\n        # linear\n        x = self.linear(img_feats)\n\n        preds = {\n            'feats': img_feats,\n            'seg_logit': x,\n        }\n\n        if self.dual_head:\n            preds['seg_logit2'] = self.linear2(img_feats)\n\n        return preds\n\n\nclass Net3DSeg(nn.Module):\n    def __init__(self,\n                 num_classes,\n                 dual_head,\n                 backbone_3d,\n                 backbone_3d_kwargs,\n                 ):\n        super(Net3DSeg, self).__init__()\n\n        # 3D network\n        if backbone_3d == 'SCN':\n            self.net_3d = UNetSCN(**backbone_3d_kwargs)\n        else:\n            raise NotImplementedError('3D backbone {} not supported'.format(backbone_3d))\n\n        # segmentation head\n        self.linear = nn.Linear(self.net_3d.out_channels, num_classes)\n\n        # 2nd segmentation head\n        self.dual_head = dual_head\n        if dual_head:\n            self.linear2 = nn.Linear(self.net_3d.out_channels, num_classes)\n\n    def forward(self, data_batch):\n        feats = self.net_3d(data_batch['x'])\n        x = self.linear(feats)\n\n        preds = {\n            'feats': feats,\n            'seg_logit': x,\n        }\n\n        if self.dual_head:\n            preds['seg_logit2'] = self.linear2(feats)\n\n        return preds\n\n\ndef test_Net2DSeg():\n    # 2D\n    batch_size = 2\n    img_width = 400\n    img_height = 225\n\n    # 3D\n    num_coords = 2000\n    num_classes = 11\n\n    # 2D\n    img = torch.rand(batch_size, 3, img_height, img_width)\n    u = torch.randint(high=img_height, size=(batch_size, num_coords // batch_size, 1))\n    v = torch.randint(high=img_width, size=(batch_size, num_coords // batch_size, 1))\n    img_indices = torch.cat([u, v], 2)\n\n    # to cuda\n    img = img.cuda()\n    img_indices = img_indices.cuda()\n\n    net_2d = Net2DSeg(num_classes,\n                      backbone_2d='UNetResNet34',\n                      backbone_2d_kwargs={},\n                      dual_head=True)\n\n    net_2d.cuda()\n    out_dict = net_2d({\n        'img': img,\n        'img_indices': img_indices,\n    })\n    for k, v in out_dict.items():\n        
print('Net2DSeg:', k, v.shape)\n\n\ndef test_Net3DSeg():\n    in_channels = 1\n    num_coords = 2000\n    full_scale = 4096\n    num_seg_classes = 11\n\n    coords = torch.randint(high=full_scale, size=(num_coords, 3))\n    feats = torch.rand(num_coords, in_channels)\n\n    feats = feats.cuda()\n\n    net_3d = Net3DSeg(num_seg_classes,\n                      dual_head=True,\n                      backbone_3d='SCN',\n                      backbone_3d_kwargs={'in_channels': in_channels})\n\n    net_3d.cuda()\n    out_dict = net_3d({\n        'x': [coords, feats],\n    })\n    for k, v in out_dict.items():\n        print('Net3DSeg:', k, v.shape)\n\n\nif __name__ == '__main__':\n    test_Net2DSeg()\n    test_Net3DSeg()\n"
  },
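  {
    "path": "examples/feature_lifting_demo.py",
    "content": "\"\"\"Illustrative sketch of the 2D-to-3D feature lifting used in Net2DSeg\n(xmuda/models/xmuda_arch.py).\n\nNOTE: this file is an editorial addition for demonstration purposes and is\nnot part of the original xMUDA release.\n\"\"\"\nimport torch\n\n\ndef lift_2d_features(feature_map, img_indices):\n    # feature_map: (batch_size, C, H, W) output of the 2D network.\n    # img_indices: list of (N_i, 2) tensors with (row, col) pixel coordinates\n    # of the projected 3D points for each sample in the batch.\n    fmap = feature_map.permute(0, 2, 3, 1)  # (batch_size, H, W, C)\n    feats = []\n    for i in range(fmap.shape[0]):\n        feats.append(fmap[i][img_indices[i][:, 0], img_indices[i][:, 1]])\n    return torch.cat(feats, 0)  # (sum_i N_i, C)\n\n\ndef demo():\n    torch.manual_seed(0)\n    b, c, h, w = 2, 64, 8, 10\n    fmap = torch.randn(b, c, h, w)\n    img_indices = [\n        torch.stack([torch.randint(h, (5,)), torch.randint(w, (5,))], dim=1),\n        torch.stack([torch.randint(h, (7,)), torch.randint(w, (7,))], dim=1),\n    ]\n    feats = lift_2d_features(fmap, img_indices)\n    print('lifted features:', feats.shape)  # expected: torch.Size([12, 64])\n\n\nif __name__ == '__main__':\n    demo()\n"
  },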
  {
    "path": "xmuda/test.py",
    "content": "#!/usr/bin/env python\nimport os\nimport os.path as osp\nimport argparse\nimport logging\nimport time\nimport socket\nimport warnings\n\nimport torch\n\nfrom xmuda.common.utils.checkpoint import CheckpointerV2\nfrom xmuda.common.utils.logger import setup_logger\nfrom xmuda.common.utils.metric_logger import MetricLogger\nfrom xmuda.common.utils.torch_util import set_random_seed\nfrom xmuda.models.build import build_model_2d, build_model_3d\nfrom xmuda.data.build import build_dataloader\nfrom xmuda.data.utils.validate import validate\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='xMUDA test')\n    parser.add_argument(\n        '--cfg',\n        dest='config_file',\n        default='',\n        metavar='FILE',\n        help='path to config file',\n        type=str,\n    )\n    parser.add_argument('ckpt2d', type=str, help='path to checkpoint file of the 2D model')\n    parser.add_argument('ckpt3d', type=str, help='path to checkpoint file of the 3D model')\n    parser.add_argument('--pselab', action='store_true', help='generate pseudo-labels')\n    parser.add_argument(\n        'opts',\n        help='Modify config options using the command-line',\n        default=None,\n        nargs=argparse.REMAINDER,\n    )\n    args = parser.parse_args()\n    return args\n\n\ndef test(cfg, args, output_dir=''):\n    logger = logging.getLogger('xmuda.test')\n\n    # build 2d model\n    model_2d = build_model_2d(cfg)[0]\n\n    # build 3d model\n    model_3d = build_model_3d(cfg)[0]\n\n    model_2d = model_2d.cuda()\n    model_3d = model_3d.cuda()\n\n    # build checkpointer\n    checkpointer_2d = CheckpointerV2(model_2d, save_dir=output_dir, logger=logger)\n    if args.ckpt2d:\n        # load weight if specified\n        weight_path = args.ckpt2d.replace('@', output_dir)\n        checkpointer_2d.load(weight_path, resume=False)\n    else:\n        # load last checkpoint\n        checkpointer_2d.load(None, resume=True)\n    checkpointer_3d = CheckpointerV2(model_3d, save_dir=output_dir, logger=logger)\n    if args.ckpt3d:\n        # load weight if specified\n        weight_path = args.ckpt3d.replace('@', output_dir)\n        checkpointer_3d.load(weight_path, resume=False)\n    else:\n        # load last checkpoint\n        checkpointer_3d.load(None, resume=True)\n\n    # build dataset\n    test_dataloader = build_dataloader(cfg, mode='test', domain='target')\n\n    pselab_path = None\n    if args.pselab:\n        pselab_dir = osp.join(output_dir, 'pselab_data')\n        os.makedirs(pselab_dir, exist_ok=True)\n        assert len(cfg.DATASET_TARGET.TEST) == 1\n        pselab_path = osp.join(pselab_dir, cfg.DATASET_TARGET.TEST[0] + '.npy')\n\n    # ---------------------------------------------------------------------------- #\n    # Test\n    # ---------------------------------------------------------------------------- #\n\n    set_random_seed(cfg.RNG_SEED)\n    test_metric_logger = MetricLogger(delimiter='  ')\n    model_2d.eval()\n    model_3d.eval()\n\n    validate(cfg, model_2d, model_3d, test_dataloader, test_metric_logger, pselab_path=pselab_path)\n\n\ndef main():\n    args = parse_args()\n\n    # load the configuration\n    # import on-the-fly to avoid overwriting cfg\n    from xmuda.common.config import purge_cfg\n    from xmuda.config.xmuda import cfg\n    cfg.merge_from_file(args.config_file)\n    cfg.merge_from_list(args.opts)\n    purge_cfg(cfg)\n    cfg.freeze()\n\n    output_dir = cfg.OUTPUT_DIR\n    # replace '@' with config path\n    if output_dir:\n      
  config_path = osp.splitext(args.config_file)[0]\n        output_dir = output_dir.replace('@', config_path.replace('configs/', ''))\n        if not osp.isdir(output_dir):\n            warnings.warn('Making a new directory: {}'.format(output_dir))\n            os.makedirs(output_dir)\n\n    # run name\n    timestamp = time.strftime('%m-%d_%H-%M-%S')\n    hostname = socket.gethostname()\n    run_name = '{:s}.{:s}'.format(timestamp, hostname)\n\n    logger = setup_logger('xmuda', output_dir, comment='test.{:s}'.format(run_name))\n    logger.info('{:d} GPUs available'.format(torch.cuda.device_count()))\n    logger.info(args)\n\n    logger.info('Loaded configuration file {:s}'.format(args.config_file))\n    logger.info('Running with config:\\n{}'.format(cfg))\n\n    # check that 2D and 3D model use either both single head or both dual head\n    assert cfg.MODEL_2D.DUAL_HEAD == cfg.MODEL_3D.DUAL_HEAD\n    test(cfg, args, output_dir)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "xmuda/train_baseline.py",
    "content": "#!/usr/bin/env python\nimport os\nimport os.path as osp\nimport argparse\nimport logging\nimport time\nimport socket\nimport warnings\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom xmuda.common.solver.build import build_optimizer, build_scheduler\nfrom xmuda.common.utils.checkpoint import CheckpointerV2\nfrom xmuda.common.utils.logger import setup_logger\nfrom xmuda.common.utils.metric_logger import MetricLogger\nfrom xmuda.common.utils.torch_util import set_random_seed\nfrom xmuda.models.build import build_model_2d, build_model_3d\nfrom xmuda.data.build import build_dataloader\nfrom xmuda.data.utils.validate import validate\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='xMUDA training')\n    parser.add_argument(\n        '--cfg',\n        dest='config_file',\n        default='',\n        metavar='FILE',\n        help='path to config file',\n        type=str,\n    )\n    parser.add_argument(\n        'opts',\n        help='Modify config options using the command-line',\n        default=None,\n        nargs=argparse.REMAINDER,\n    )\n    args = parser.parse_args()\n    return args\n\n\ndef init_metric_logger(metric_list):\n    new_metric_list = []\n    for metric in metric_list:\n        if isinstance(metric, (list, tuple)):\n            new_metric_list.extend(metric)\n        else:\n            new_metric_list.append(metric)\n    metric_logger = MetricLogger(delimiter='  ')\n    metric_logger.add_meters(new_metric_list)\n    return metric_logger\n\n\ndef train(cfg, output_dir='', run_name=''):\n    # ---------------------------------------------------------------------------- #\n    # Build models, optimizer, scheduler, checkpointer, etc.\n    # ---------------------------------------------------------------------------- #\n    logger = logging.getLogger('xmuda.train')\n\n    set_random_seed(cfg.RNG_SEED)\n\n    # build 2d model\n    model_2d, train_metric_2d = build_model_2d(cfg)\n    logger.info('Build 2D model:\\n{}'.format(str(model_2d)))\n    num_params = sum(param.numel() for param in model_2d.parameters())\n    print('#Parameters: {:.2e}'.format(num_params))\n\n    # build 3d model\n    model_3d, train_metric_3d = build_model_3d(cfg)\n    logger.info('Build 3D model:\\n{}'.format(str(model_3d)))\n    num_params = sum(param.numel() for param in model_3d.parameters())\n    print('#Parameters: {:.2e}'.format(num_params))\n\n    model_2d = model_2d.cuda()\n    model_3d = model_3d.cuda()\n\n    # build optimizer\n    optimizer_2d = build_optimizer(cfg, model_2d)\n    optimizer_3d = build_optimizer(cfg, model_3d)\n\n    # build lr scheduler\n    scheduler_2d = build_scheduler(cfg, optimizer_2d)\n    scheduler_3d = build_scheduler(cfg, optimizer_3d)\n\n    # build checkpointer\n    # Note that checkpointer will load state_dict of model, optimizer and scheduler.\n    checkpointer_2d = CheckpointerV2(model_2d,\n                                     optimizer=optimizer_2d,\n                                     scheduler=scheduler_2d,\n                                     save_dir=output_dir,\n                                     logger=logger,\n                                     postfix='_2d',\n                                     max_to_keep=cfg.TRAIN.MAX_TO_KEEP)\n    checkpoint_data_2d = checkpointer_2d.load(cfg.RESUME_PATH, resume=cfg.AUTO_RESUME, resume_states=cfg.RESUME_STATES)\n    checkpointer_3d = CheckpointerV2(model_3d,\n                                     
optimizer=optimizer_3d,\n                                     scheduler=scheduler_3d,\n                                     save_dir=output_dir,\n                                     logger=logger,\n                                     postfix='_3d',\n                                     max_to_keep=cfg.TRAIN.MAX_TO_KEEP)\n    checkpoint_data_3d = checkpointer_3d.load(cfg.RESUME_PATH, resume=cfg.AUTO_RESUME, resume_states=cfg.RESUME_STATES)\n    ckpt_period = cfg.TRAIN.CHECKPOINT_PERIOD\n\n    # build tensorboard logger (optionally by comment)\n    if output_dir:\n        tb_dir = osp.join(output_dir, 'tb.{:s}'.format(run_name))\n        summary_writer = SummaryWriter(tb_dir)\n    else:\n        summary_writer = None\n\n    # ---------------------------------------------------------------------------- #\n    # Train\n    # ---------------------------------------------------------------------------- #\n    max_iteration = cfg.SCHEDULER.MAX_ITERATION\n    start_iteration = checkpoint_data_2d.get('iteration', 0)\n\n    # build data loader\n    # Reset the random seed again in case the initialization of models changes the random state.\n    set_random_seed(cfg.RNG_SEED)\n    train_dataloader_src = build_dataloader(cfg, mode='train', domain='source', start_iteration=start_iteration)\n    val_period = cfg.VAL.PERIOD\n    val_dataloader = build_dataloader(cfg, mode='val', domain='target') if val_period > 0 else None\n\n    best_metric_name = 'best_{}'.format(cfg.VAL.METRIC)\n    best_metric = {\n        '2d': checkpoint_data_2d.get(best_metric_name, None),\n        '3d': checkpoint_data_3d.get(best_metric_name, None)\n    }\n    best_metric_iter = {'2d': -1, '3d': -1}\n    logger.info('Start training from iteration {}'.format(start_iteration))\n\n    # add metrics\n    train_metric_logger = init_metric_logger([train_metric_2d, train_metric_3d])\n    val_metric_logger = MetricLogger(delimiter='  ')\n\n    def setup_train():\n        # set training mode\n        model_2d.train()\n        model_3d.train()\n        # reset metric\n        train_metric_logger.reset()\n\n    def setup_validate():\n        # set evaluate mode\n        model_2d.eval()\n        model_3d.eval()\n        # reset metric\n        val_metric_logger.reset()\n\n    if cfg.TRAIN.CLASS_WEIGHTS:\n        class_weights = torch.tensor(cfg.TRAIN.CLASS_WEIGHTS).cuda()\n    else:\n        class_weights = None\n\n    setup_train()\n    end = time.time()\n    train_iter_src = enumerate(train_dataloader_src)\n    for iteration in range(start_iteration, max_iteration):\n        # fetch data_batches for source & target\n        _, data_batch_src = train_iter_src.__next__()\n        data_time = time.time() - end\n        # copy data from cpu to gpu\n        if 'SCN' in cfg.DATASET_SOURCE.TYPE and 'SCN' in cfg.DATASET_TARGET.TYPE:\n            # source\n            data_batch_src['x'][1] = data_batch_src['x'][1].cuda()\n            data_batch_src['seg_label'] = data_batch_src['seg_label'].cuda()\n            data_batch_src['img'] = data_batch_src['img'].cuda()\n        else:\n            raise NotImplementedError('Only SCN is supported for now.')\n\n        optimizer_2d.zero_grad()\n        optimizer_3d.zero_grad()\n\n        # ---------------------------------------------------------------------------- #\n        # Train on source\n        # ---------------------------------------------------------------------------- #\n\n        preds_2d = model_2d(data_batch_src)\n        preds_3d = model_3d(data_batch_src)\n\n        # segmentation loss: 
cross entropy\n        seg_loss_src_2d = F.cross_entropy(preds_2d['seg_logit'], data_batch_src['seg_label'], weight=class_weights)\n        seg_loss_src_3d = F.cross_entropy(preds_3d['seg_logit'], data_batch_src['seg_label'], weight=class_weights)\n        train_metric_logger.update(seg_loss_src_2d=seg_loss_src_2d, seg_loss_src_3d=seg_loss_src_3d)\n        loss_2d = seg_loss_src_2d\n        loss_3d = seg_loss_src_3d\n\n        if cfg.TRAIN.XMUDA.lambda_xm_src > 0:\n            # cross-modal loss: KL divergence\n            seg_logit_2d = preds_2d['seg_logit2'] if cfg.MODEL_2D.DUAL_HEAD else preds_2d['seg_logit']\n            seg_logit_3d = preds_3d['seg_logit2'] if cfg.MODEL_3D.DUAL_HEAD else preds_3d['seg_logit']\n            xm_loss_src_2d = F.kl_div(F.log_softmax(seg_logit_2d, dim=1),\n                                      F.softmax(preds_3d['seg_logit'].detach(), dim=1),\n                                      reduction='none').sum(1).mean()\n            xm_loss_src_3d = F.kl_div(F.log_softmax(seg_logit_3d, dim=1),\n                                      F.softmax(preds_2d['seg_logit'].detach(), dim=1),\n                                      reduction='none').sum(1).mean()\n            train_metric_logger.update(xm_loss_src_2d=xm_loss_src_2d,\n                                       xm_loss_src_3d=xm_loss_src_3d)\n            loss_2d += cfg.TRAIN.XMUDA.lambda_xm_src * xm_loss_src_2d\n            loss_3d += cfg.TRAIN.XMUDA.lambda_xm_src * xm_loss_src_3d\n\n        # update metric (e.g. IoU)\n        with torch.no_grad():\n            train_metric_2d.update_dict(preds_2d, data_batch_src)\n            train_metric_3d.update_dict(preds_3d, data_batch_src)\n\n        # backward\n        loss_2d.backward()\n        loss_3d.backward()\n\n        optimizer_2d.step()\n        optimizer_3d.step()\n\n        batch_time = time.time() - end\n        train_metric_logger.update(time=batch_time, data=data_time)\n\n        # log\n        cur_iter = iteration + 1\n        if cur_iter == 1 or (cfg.TRAIN.LOG_PERIOD > 0 and cur_iter % cfg.TRAIN.LOG_PERIOD == 0):\n            logger.info(\n                train_metric_logger.delimiter.join(\n                    [\n                        'iter: {iter:4d}',\n                        '{meters}',\n                        'lr: {lr:.2e}',\n                        'max mem: {memory:.0f}',\n                    ]\n                ).format(\n                    iter=cur_iter,\n                    meters=str(train_metric_logger),\n                    lr=optimizer_2d.param_groups[0]['lr'],\n                    memory=torch.cuda.max_memory_allocated() / (1024.0 ** 2),\n                )\n            )\n\n        # summary\n        if summary_writer is not None and cfg.TRAIN.SUMMARY_PERIOD > 0 and cur_iter % cfg.TRAIN.SUMMARY_PERIOD == 0:\n            keywords = ('loss', 'acc', 'iou')\n            for name, meter in train_metric_logger.meters.items():\n                if all(k not in name for k in keywords):\n                    continue\n                summary_writer.add_scalar('train/' + name, meter.avg, global_step=cur_iter)\n\n        # checkpoint\n        if (ckpt_period > 0 and cur_iter % ckpt_period == 0) or cur_iter == max_iteration:\n            checkpoint_data_2d['iteration'] = cur_iter\n            checkpoint_data_2d[best_metric_name] = best_metric['2d']\n            checkpointer_2d.save('model_2d_{:06d}'.format(cur_iter), **checkpoint_data_2d)\n            checkpoint_data_3d['iteration'] = cur_iter\n            checkpoint_data_3d[best_metric_name] = 
best_metric['3d']\n            checkpointer_3d.save('model_3d_{:06d}'.format(cur_iter), **checkpoint_data_3d)\n\n        # ---------------------------------------------------------------------------- #\n        # validate for one epoch\n        # ---------------------------------------------------------------------------- #\n        if val_period > 0 and (cur_iter % val_period == 0 or cur_iter == max_iteration):\n            start_time_val = time.time()\n            setup_validate()\n\n            validate(cfg,\n                     model_2d,\n                     model_3d,\n                     val_dataloader,\n                     val_metric_logger)\n\n            epoch_time_val = time.time() - start_time_val\n            logger.info('Iteration[{}]-Val {}  total_time: {:.2f}s'.format(\n                cur_iter, val_metric_logger.summary_str, epoch_time_val))\n\n            # summary\n            if summary_writer is not None:\n                keywords = ('loss', 'acc', 'iou')\n                for name, meter in val_metric_logger.meters.items():\n                    if all(k not in name for k in keywords):\n                        continue\n                    summary_writer.add_scalar('val/' + name, meter.avg, global_step=cur_iter)\n\n            # best validation\n            for modality in ['2d', '3d']:\n                cur_metric_name = cfg.VAL.METRIC + '_' + modality\n                if cur_metric_name in val_metric_logger.meters:\n                    cur_metric = val_metric_logger.meters[cur_metric_name].global_avg\n                    if best_metric[modality] is None or best_metric[modality] < cur_metric:\n                        best_metric[modality] = cur_metric\n                        best_metric_iter[modality] = cur_iter\n\n            # restore training\n            setup_train()\n\n        scheduler_2d.step()\n        scheduler_3d.step()\n        end = time.time()\n\n    for modality in ['2d', '3d']:\n        logger.info('Best val-{}-{} = {:.2f} at iteration {}'.format(modality.upper(),\n                                                                     cfg.VAL.METRIC,\n                                                                     best_metric[modality] * 100,\n                                                                     best_metric_iter[modality]))\n\n\ndef main():\n    args = parse_args()\n\n    # load the configuration\n    # import on-the-fly to avoid overwriting cfg\n    from xmuda.common.config import purge_cfg\n    from xmuda.config.xmuda import cfg\n    cfg.merge_from_file(args.config_file)\n    cfg.merge_from_list(args.opts)\n    purge_cfg(cfg)\n    cfg.freeze()\n\n    output_dir = cfg.OUTPUT_DIR\n    # replace '@' with config path\n    if output_dir:\n        config_path = osp.splitext(args.config_file)[0]\n        output_dir = output_dir.replace('@', config_path.replace('configs/', ''))\n        if osp.isdir(output_dir):\n            warnings.warn('Output directory exists.')\n        os.makedirs(output_dir, exist_ok=True)\n\n    # run name\n    timestamp = time.strftime('%m-%d_%H-%M-%S')\n    hostname = socket.gethostname()\n    run_name = '{:s}.{:s}'.format(timestamp, hostname)\n\n    logger = setup_logger('xmuda', output_dir, comment='train.{:s}'.format(run_name))\n    logger.info('{:d} GPUs available'.format(torch.cuda.device_count()))\n    logger.info(args)\n\n    logger.info('Loaded configuration file {:s}'.format(args.config_file))\n    logger.info('Running with config:\\n{}'.format(cfg))\n\n    # check that 2D and 3D model use either both 
single head or both dual head\n    assert cfg.MODEL_2D.DUAL_HEAD == cfg.MODEL_3D.DUAL_HEAD\n    # check that no loss is applied on the target set (the baseline trains on source data only)\n    assert cfg.TRAIN.XMUDA.lambda_xm_trg == 0 and cfg.TRAIN.XMUDA.lambda_pl == 0\n    train(cfg, output_dir, run_name)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
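  {
    "path": "examples/cross_modal_loss_demo.py",
    "content": "\"\"\"Illustrative sketch of the cross-modal (xM) KL loss used in\ntrain_baseline.py and train_xmuda.py.\n\nNOTE: this file is an editorial addition for demonstration purposes and is\nnot part of the original xMUDA release. It isolates the KL term that makes\none modality mimic the (detached) prediction of the other.\n\"\"\"\nimport torch\nimport torch.nn.functional as F\n\n\ndef cross_modal_kl(seg_logit_student, seg_logit_teacher):\n    # Per-point KL divergence between the student's log-probabilities and the\n    # teacher's probabilities, averaged over points; the teacher branch is\n    # detached, so gradients only flow into the student branch.\n    return F.kl_div(F.log_softmax(seg_logit_student, dim=1),\n                    F.softmax(seg_logit_teacher.detach(), dim=1),\n                    reduction='none').sum(1).mean()\n\n\ndef demo():\n    torch.manual_seed(0)\n    num_points, num_classes = 200, 5\n    logit_2d = torch.randn(num_points, num_classes, requires_grad=True)\n    logit_3d = torch.randn(num_points, num_classes, requires_grad=True)\n\n    xm_loss_2d = cross_modal_kl(logit_2d, logit_3d)  # trains the 2D branch\n    xm_loss_3d = cross_modal_kl(logit_3d, logit_2d)  # trains the 3D branch\n    print('xm_loss_2d: {:.4f}  xm_loss_3d: {:.4f}'.format(xm_loss_2d.item(), xm_loss_3d.item()))\n\n\nif __name__ == '__main__':\n    demo()\n"
  },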
  {
    "path": "xmuda/train_xmuda.py",
    "content": "#!/usr/bin/env python\nimport os\nimport os.path as osp\nimport argparse\nimport logging\nimport time\nimport socket\nimport warnings\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom xmuda.common.solver.build import build_optimizer, build_scheduler\nfrom xmuda.common.utils.checkpoint import CheckpointerV2\nfrom xmuda.common.utils.logger import setup_logger\nfrom xmuda.common.utils.metric_logger import MetricLogger\nfrom xmuda.common.utils.torch_util import set_random_seed\nfrom xmuda.models.build import build_model_2d, build_model_3d\nfrom xmuda.data.build import build_dataloader\nfrom xmuda.data.utils.validate import validate\nfrom xmuda.models.losses import entropy_loss\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='xMUDA training')\n    parser.add_argument(\n        '--cfg',\n        dest='config_file',\n        default='',\n        metavar='FILE',\n        help='path to config file',\n        type=str,\n    )\n    parser.add_argument(\n        'opts',\n        help='Modify config options using the command-line',\n        default=None,\n        nargs=argparse.REMAINDER,\n    )\n    args = parser.parse_args()\n    return args\n\n\ndef init_metric_logger(metric_list):\n    new_metric_list = []\n    for metric in metric_list:\n        if isinstance(metric, (list, tuple)):\n            new_metric_list.extend(metric)\n        else:\n            new_metric_list.append(metric)\n    metric_logger = MetricLogger(delimiter='  ')\n    metric_logger.add_meters(new_metric_list)\n    return metric_logger\n\n\ndef train(cfg, output_dir='', run_name=''):\n    # ---------------------------------------------------------------------------- #\n    # Build models, optimizer, scheduler, checkpointer, etc.\n    # ---------------------------------------------------------------------------- #\n    logger = logging.getLogger('xmuda.train')\n\n    set_random_seed(cfg.RNG_SEED)\n\n    # build 2d model\n    model_2d, train_metric_2d = build_model_2d(cfg)\n    logger.info('Build 2D model:\\n{}'.format(str(model_2d)))\n    num_params = sum(param.numel() for param in model_2d.parameters())\n    print('#Parameters: {:.2e}'.format(num_params))\n\n    # build 3d model\n    model_3d, train_metric_3d = build_model_3d(cfg)\n    logger.info('Build 3D model:\\n{}'.format(str(model_3d)))\n    num_params = sum(param.numel() for param in model_3d.parameters())\n    print('#Parameters: {:.2e}'.format(num_params))\n\n    model_2d = model_2d.cuda()\n    model_3d = model_3d.cuda()\n\n    # build optimizer\n    optimizer_2d = build_optimizer(cfg, model_2d)\n    optimizer_3d = build_optimizer(cfg, model_3d)\n\n    # build lr scheduler\n    scheduler_2d = build_scheduler(cfg, optimizer_2d)\n    scheduler_3d = build_scheduler(cfg, optimizer_3d)\n\n    # build checkpointer\n    # Note that checkpointer will load state_dict of model, optimizer and scheduler.\n    checkpointer_2d = CheckpointerV2(model_2d,\n                                     optimizer=optimizer_2d,\n                                     scheduler=scheduler_2d,\n                                     save_dir=output_dir,\n                                     logger=logger,\n                                     postfix='_2d',\n                                     max_to_keep=cfg.TRAIN.MAX_TO_KEEP)\n    checkpoint_data_2d = checkpointer_2d.load(cfg.RESUME_PATH, resume=cfg.AUTO_RESUME, resume_states=cfg.RESUME_STATES)\n    checkpointer_3d = CheckpointerV2(model_3d,\n            
                         optimizer=optimizer_3d,\n                                     scheduler=scheduler_3d,\n                                     save_dir=output_dir,\n                                     logger=logger,\n                                     postfix='_3d',\n                                     max_to_keep=cfg.TRAIN.MAX_TO_KEEP)\n    checkpoint_data_3d = checkpointer_3d.load(cfg.RESUME_PATH, resume=cfg.AUTO_RESUME, resume_states=cfg.RESUME_STATES)\n    ckpt_period = cfg.TRAIN.CHECKPOINT_PERIOD\n\n    # build tensorboard logger (optionally by comment)\n    if output_dir:\n        tb_dir = osp.join(output_dir, 'tb.{:s}'.format(run_name))\n        summary_writer = SummaryWriter(tb_dir)\n    else:\n        summary_writer = None\n\n    # ---------------------------------------------------------------------------- #\n    # Train\n    # ---------------------------------------------------------------------------- #\n    max_iteration = cfg.SCHEDULER.MAX_ITERATION\n    start_iteration = checkpoint_data_2d.get('iteration', 0)\n\n    # build data loader\n    # Reset the random seed again in case the initialization of models changes the random state.\n    set_random_seed(cfg.RNG_SEED)\n    train_dataloader_src = build_dataloader(cfg, mode='train', domain='source', start_iteration=start_iteration)\n    train_dataloader_trg = build_dataloader(cfg, mode='train', domain='target', start_iteration=start_iteration)\n    val_period = cfg.VAL.PERIOD\n    val_dataloader = build_dataloader(cfg, mode='val', domain='target') if val_period > 0 else None\n\n    best_metric_name = 'best_{}'.format(cfg.VAL.METRIC)\n    best_metric = {\n        '2d': checkpoint_data_2d.get(best_metric_name, None),\n        '3d': checkpoint_data_3d.get(best_metric_name, None)\n    }\n    best_metric_iter = {'2d': -1, '3d': -1}\n    logger.info('Start training from iteration {}'.format(start_iteration))\n\n    # add metrics\n    train_metric_logger = init_metric_logger([train_metric_2d, train_metric_3d])\n    val_metric_logger = MetricLogger(delimiter='  ')\n\n    def setup_train():\n        # set training mode\n        model_2d.train()\n        model_3d.train()\n        # reset metric\n        train_metric_logger.reset()\n\n    def setup_validate():\n        # set evaluate mode\n        model_2d.eval()\n        model_3d.eval()\n        # reset metric\n        val_metric_logger.reset()\n\n    if cfg.TRAIN.CLASS_WEIGHTS:\n        class_weights = torch.tensor(cfg.TRAIN.CLASS_WEIGHTS).cuda()\n    else:\n        class_weights = None\n\n    setup_train()\n    end = time.time()\n    train_iter_src = enumerate(train_dataloader_src)\n    train_iter_trg = enumerate(train_dataloader_trg)\n    for iteration in range(start_iteration, max_iteration):\n        # fetch data_batches for source & target\n        _, data_batch_src = train_iter_src.__next__()\n        _, data_batch_trg = train_iter_trg.__next__()\n        data_time = time.time() - end\n        # copy data from cpu to gpu\n        if 'SCN' in cfg.DATASET_SOURCE.TYPE and 'SCN' in cfg.DATASET_TARGET.TYPE:\n            # source\n            data_batch_src['x'][1] = data_batch_src['x'][1].cuda()\n            data_batch_src['seg_label'] = data_batch_src['seg_label'].cuda()\n            data_batch_src['img'] = data_batch_src['img'].cuda()\n            # target\n            data_batch_trg['x'][1] = data_batch_trg['x'][1].cuda()\n            data_batch_trg['seg_label'] = data_batch_trg['seg_label'].cuda()\n            data_batch_trg['img'] = data_batch_trg['img'].cuda()\n          
  if cfg.TRAIN.XMUDA.lambda_pl > 0:\n                data_batch_trg['pseudo_label_2d'] = data_batch_trg['pseudo_label_2d'].cuda()\n                data_batch_trg['pseudo_label_3d'] = data_batch_trg['pseudo_label_3d'].cuda()\n        else:\n            raise NotImplementedError('Only SCN is supported for now.')\n\n        optimizer_2d.zero_grad()\n        optimizer_3d.zero_grad()\n\n        # ---------------------------------------------------------------------------- #\n        # Train on source\n        # ---------------------------------------------------------------------------- #\n\n        preds_2d = model_2d(data_batch_src)\n        preds_3d = model_3d(data_batch_src)\n\n        # segmentation loss: cross entropy\n        seg_loss_src_2d = F.cross_entropy(preds_2d['seg_logit'], data_batch_src['seg_label'], weight=class_weights)\n        seg_loss_src_3d = F.cross_entropy(preds_3d['seg_logit'], data_batch_src['seg_label'], weight=class_weights)\n        train_metric_logger.update(seg_loss_src_2d=seg_loss_src_2d, seg_loss_src_3d=seg_loss_src_3d)\n        loss_2d = seg_loss_src_2d\n        loss_3d = seg_loss_src_3d\n\n        if cfg.TRAIN.XMUDA.lambda_xm_src > 0:\n            # cross-modal loss: KL divergence\n            seg_logit_2d = preds_2d['seg_logit2'] if cfg.MODEL_2D.DUAL_HEAD else preds_2d['seg_logit']\n            seg_logit_3d = preds_3d['seg_logit2'] if cfg.MODEL_3D.DUAL_HEAD else preds_3d['seg_logit']\n            xm_loss_src_2d = F.kl_div(F.log_softmax(seg_logit_2d, dim=1),\n                                      F.softmax(preds_3d['seg_logit'].detach(), dim=1),\n                                      reduction='none').sum(1).mean()\n            xm_loss_src_3d = F.kl_div(F.log_softmax(seg_logit_3d, dim=1),\n                                      F.softmax(preds_2d['seg_logit'].detach(), dim=1),\n                                      reduction='none').sum(1).mean()\n            train_metric_logger.update(xm_loss_src_2d=xm_loss_src_2d,\n                                       xm_loss_src_3d=xm_loss_src_3d)\n            loss_2d += cfg.TRAIN.XMUDA.lambda_xm_src * xm_loss_src_2d\n            loss_3d += cfg.TRAIN.XMUDA.lambda_xm_src * xm_loss_src_3d\n\n        # update metric (e.g. 
IoU)\n        with torch.no_grad():\n            train_metric_2d.update_dict(preds_2d, data_batch_src)\n            train_metric_3d.update_dict(preds_3d, data_batch_src)\n\n        # backward\n        loss_2d.backward()\n        loss_3d.backward()\n\n        # ---------------------------------------------------------------------------- #\n        # Train on target\n        # ---------------------------------------------------------------------------- #\n\n        preds_2d = model_2d(data_batch_trg)\n        preds_3d = model_3d(data_batch_trg)\n\n        loss_2d = []\n        loss_3d = []\n        if cfg.TRAIN.XMUDA.lambda_xm_trg > 0:\n            # cross-modal loss: KL divergence\n            seg_logit_2d = preds_2d['seg_logit2'] if cfg.MODEL_2D.DUAL_HEAD else preds_2d['seg_logit']\n            seg_logit_3d = preds_3d['seg_logit2'] if cfg.MODEL_3D.DUAL_HEAD else preds_3d['seg_logit']\n            xm_loss_trg_2d = F.kl_div(F.log_softmax(seg_logit_2d, dim=1),\n                                      F.softmax(preds_3d['seg_logit'].detach(), dim=1),\n                                      reduction='none').sum(1).mean()\n            xm_loss_trg_3d = F.kl_div(F.log_softmax(seg_logit_3d, dim=1),\n                                      F.softmax(preds_2d['seg_logit'].detach(), dim=1),\n                                      reduction='none').sum(1).mean()\n            train_metric_logger.update(xm_loss_trg_2d=xm_loss_trg_2d,\n                                       xm_loss_trg_3d=xm_loss_trg_3d)\n            loss_2d.append(cfg.TRAIN.XMUDA.lambda_xm_trg * xm_loss_trg_2d)\n            loss_3d.append(cfg.TRAIN.XMUDA.lambda_xm_trg * xm_loss_trg_3d)\n        if cfg.TRAIN.XMUDA.lambda_pl > 0:\n            # uni-modal self-training loss with pseudo labels\n            pl_loss_trg_2d = F.cross_entropy(preds_2d['seg_logit'], data_batch_trg['pseudo_label_2d'])\n            pl_loss_trg_3d = F.cross_entropy(preds_3d['seg_logit'], data_batch_trg['pseudo_label_3d'])\n            train_metric_logger.update(pl_loss_trg_2d=pl_loss_trg_2d,\n                                       pl_loss_trg_3d=pl_loss_trg_3d)\n            loss_2d.append(cfg.TRAIN.XMUDA.lambda_pl * pl_loss_trg_2d)\n            loss_3d.append(cfg.TRAIN.XMUDA.lambda_pl * pl_loss_trg_3d)\n        if cfg.TRAIN.XMUDA.lambda_minent > 0:\n            # MinEnt\n            minent_loss_trg_2d = entropy_loss(F.softmax(preds_2d['seg_logit'], dim=1))\n            minent_loss_trg_3d = entropy_loss(F.softmax(preds_3d['seg_logit'], dim=1))\n            train_metric_logger.update(minent_loss_trg_2d=minent_loss_trg_2d,\n                                       minent_loss_trg_3d=minent_loss_trg_3d)\n            loss_2d.append(cfg.TRAIN.XMUDA.lambda_minent * minent_loss_trg_2d)\n            loss_3d.append(cfg.TRAIN.XMUDA.lambda_minent * minent_loss_trg_3d)\n\n        sum(loss_2d).backward()\n        sum(loss_3d).backward()\n\n        optimizer_2d.step()\n        optimizer_3d.step()\n\n        batch_time = time.time() - end\n        train_metric_logger.update(time=batch_time, data=data_time)\n\n        # log\n        cur_iter = iteration + 1\n        if cur_iter == 1 or (cfg.TRAIN.LOG_PERIOD > 0 and cur_iter % cfg.TRAIN.LOG_PERIOD == 0):\n            logger.info(\n                train_metric_logger.delimiter.join(\n                    [\n                        'iter: {iter:4d}',\n                        '{meters}',\n                        'lr: {lr:.2e}',\n                        'max mem: {memory:.0f}',\n                    ]\n                ).format(\n                  
  iter=cur_iter,\n                    meters=str(train_metric_logger),\n                    lr=optimizer_2d.param_groups[0]['lr'],\n                    memory=torch.cuda.max_memory_allocated() / (1024.0 ** 2),\n                )\n            )\n\n        # summary\n        if summary_writer is not None and cfg.TRAIN.SUMMARY_PERIOD > 0 and cur_iter % cfg.TRAIN.SUMMARY_PERIOD == 0:\n            keywords = ('loss', 'acc', 'iou')\n            for name, meter in train_metric_logger.meters.items():\n                if all(k not in name for k in keywords):\n                    continue\n                summary_writer.add_scalar('train/' + name, meter.avg, global_step=cur_iter)\n\n        # checkpoint\n        if (ckpt_period > 0 and cur_iter % ckpt_period == 0) or cur_iter == max_iteration:\n            checkpoint_data_2d['iteration'] = cur_iter\n            checkpoint_data_2d[best_metric_name] = best_metric['2d']\n            checkpointer_2d.save('model_2d_{:06d}'.format(cur_iter), **checkpoint_data_2d)\n            checkpoint_data_3d['iteration'] = cur_iter\n            checkpoint_data_3d[best_metric_name] = best_metric['3d']\n            checkpointer_3d.save('model_3d_{:06d}'.format(cur_iter), **checkpoint_data_3d)\n\n        # ---------------------------------------------------------------------------- #\n        # validate for one epoch\n        # ---------------------------------------------------------------------------- #\n        if val_period > 0 and (cur_iter % val_period == 0 or cur_iter == max_iteration):\n            start_time_val = time.time()\n            setup_validate()\n\n            validate(cfg,\n                     model_2d,\n                     model_3d,\n                     val_dataloader,\n                     val_metric_logger)\n\n            epoch_time_val = time.time() - start_time_val\n            logger.info('Iteration[{}]-Val {}  total_time: {:.2f}s'.format(\n                cur_iter, val_metric_logger.summary_str, epoch_time_val))\n\n            # summary\n            if summary_writer is not None:\n                keywords = ('loss', 'acc', 'iou')\n                for name, meter in val_metric_logger.meters.items():\n                    if all(k not in name for k in keywords):\n                        continue\n                    summary_writer.add_scalar('val/' + name, meter.avg, global_step=cur_iter)\n\n            # best validation\n            for modality in ['2d', '3d']:\n                cur_metric_name = cfg.VAL.METRIC + '_' + modality\n                if cur_metric_name in val_metric_logger.meters:\n                    cur_metric = val_metric_logger.meters[cur_metric_name].global_avg\n                    if best_metric[modality] is None or best_metric[modality] < cur_metric:\n                        best_metric[modality] = cur_metric\n                        best_metric_iter[modality] = cur_iter\n\n            # restore training\n            setup_train()\n\n        scheduler_2d.step()\n        scheduler_3d.step()\n        end = time.time()\n\n    for modality in ['2d', '3d']:\n        logger.info('Best val-{}-{} = {:.2f} at iteration {}'.format(modality.upper(),\n                                                                     cfg.VAL.METRIC,\n                                                                     best_metric[modality] * 100,\n                                                                     best_metric_iter[modality]))\n\n\ndef main():\n    args = parse_args()\n\n    # load the configuration\n    # import on-the-fly to avoid 
overwriting cfg\n    from xmuda.common.config import purge_cfg\n    from xmuda.config.xmuda import cfg\n    cfg.merge_from_file(args.config_file)\n    cfg.merge_from_list(args.opts)\n    purge_cfg(cfg)\n    cfg.freeze()\n\n    output_dir = cfg.OUTPUT_DIR\n    # replace '@' with config path\n    if output_dir:\n        config_path = osp.splitext(args.config_file)[0]\n        output_dir = output_dir.replace('@', config_path.replace('configs/', ''))\n        if osp.isdir(output_dir):\n            warnings.warn('Output directory exists.')\n        os.makedirs(output_dir, exist_ok=True)\n\n    # run name\n    timestamp = time.strftime('%m-%d_%H-%M-%S')\n    hostname = socket.gethostname()\n    run_name = '{:s}.{:s}'.format(timestamp, hostname)\n\n    logger = setup_logger('xmuda', output_dir, comment='train.{:s}'.format(run_name))\n    logger.info('{:d} GPUs available'.format(torch.cuda.device_count()))\n    logger.info(args)\n\n    logger.info('Loaded configuration file {:s}'.format(args.config_file))\n    logger.info('Running with config:\\n{}'.format(cfg))\n\n    # check that 2D and 3D model use either both single head or both dual head\n    assert cfg.MODEL_2D.DUAL_HEAD == cfg.MODEL_3D.DUAL_HEAD\n    # check that at least one loss is applied on the target set (otherwise the\n    # target backward pass would have nothing to optimize)\n    assert cfg.TRAIN.XMUDA.lambda_xm_trg > 0 or cfg.TRAIN.XMUDA.lambda_pl > 0 or \\\n           cfg.TRAIN.XMUDA.lambda_minent > 0\n    train(cfg, output_dir, run_name)\n\n\nif __name__ == '__main__':\n    main()\n"
  }
]