[
  {
    "path": ".gitignore",
    "content": "# OS generated files\n.DS_Store\n.DS_Store?\n._*\n.Spotlight-V100\n.Trashes\nehthumbs.db\nThumbs.db\n\n# Compiled source\nbuild\ndebug\nDebug\nrelease\nRelease\nx64\n*.so\n*.whl\n\n# VS project files\n*.sln\n*.vcxproj\n*.vcxproj.filters\n*.vcxproj.user\n*.rc\n.vs\n\n# Byte-compiled / optimized / DLL files\n*__pycache__*\n*.py[cod]\n*$py.class\n\n# Distribution / packaging\n.Python\nbuild\ndevelop-eggs\ndist\ndownloads\n\n# IDE\n.idea\n.vscode\npyrightconfig.json\n\n# Custom\ndata\noutputs\nprediction\nsubmission\ncheckpoints\npretrain\nckpts\nocc_result\nwandb"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# SparseOcc\n\nThis is the official PyTorch implementation for our paper:\n\n> [**Fully Sparse 3D Panoptic Occupancy Prediction**](https://arxiv.org/abs/2312.17118)<br>\n> :school: Presented by Nanjing University and Shanghai AI Lab<br>\n> :email: Primary contact: Haisong Liu (afterthat97@gmail.com)<br>\n> :trophy: [CVPR 2024 Autonomous Driving Challenge - Occupancy and Flow](https://opendrivelab.com/challenge2024/#occupancy_and_flow)<br>\n> :book: 中文解读（官方）：https://zhuanlan.zhihu.com/p/709576252<br>\n> :book: 中文解读（第三方）: [AIming](https://zhuanlan.zhihu.com/p/691549750), [自动驾驶之心](https://zhuanlan.zhihu.com/p/675811281)\n\n## :warning: Important Notes\n\nThere is another concurrent project titled *''SparseOcc: Rethinking sparse latent representation for vision-based semantic occupancy prediction''* by Tang et al., which shares the same name SparseOcc with ours. However, this repository is **unrelated** to the aforementioned paper.\n\nIf you cite our research, please ensure that you reference the correct version (arXiv **2312.17118**, authored by **Liu et al.**):\n\n```\n@article{liu2023fully,\n  title={Fully sparse 3d panoptic occupancy prediction},\n  author={Liu, Haisong and Wang, Haiguang and Chen, Yang and Yang, Zetong and Zeng, Jia and Chen, Li and Wang, Limin},\n  journal={arXiv preprint arXiv:2312.17118},\n  year={2023}\n}\n```\n\n> In arXiv 2312.17118v3, we removed the word \"panoptic\" from the title. However, Google Scholar's database has not been updated and still shows the old one. Therefore, we still recommend citing the old title - \"Fully sparse 3d panoptic occupancy prediction\" - so that Google Scholar can index it correctly. Thank you all.\n\n## News\n\n* **2024-07-19**: We released an updated version of SparseOcc on [arXiv](https://arxiv.org/abs/2312.17118). All charts and colors have been carefully adjusted. 
Delete the old version and download the new one!\n* **2024-07-01**: SparseOcc is accepted to ECCV 2024.\n* **2024-06-27**: SparseOcc v1.1 is released. In this change, we introduce BEV data augmentation (BDA) and Lovasz-Softmax loss to further enhance the performance. Compared with [v1.0](https://github.com/MCG-NJU/SparseOcc/tree/v1.0) (35.0 RayIoU with 48 epochs), SparseOcc v1.1 can achieve 36.8 RayIoU with 24 epochs!\n* **2024-05-29**: We add support for [OpenOcc v2](configs/r50_nuimg_704x256_8f_openocc.py) dataset (without occupancy flow).\n* **2024-04-11**: The panoptic version of SparseOcc ([configs/r50_nuimg_704x256_8f_pano.py](configs/r50_nuimg_704x256_8f_pano.py)) is released.\n* **2024-04-09**: An updated arXiv version [https://arxiv.org/abs/2312.17118v3](https://arxiv.org/abs/2312.17118v3) has been released.\n* **2024-03-31**: We release the code and pretrained weights.\n* **2023-12-30**: We release the paper.\n\n## Highlights\n\n**New model**:1st_place_medal:: SparseOcc initially reconstructs a sparse 3D representation from visual inputs and subsequently predicts semantic/instance occupancy from the 3D sparse representation by sparse queries.\n\n![](asserts/arch.jpg)\n\n**New evaluation metric**:chart_with_upwards_trend:: We design a thoughtful ray-based evaluation metric, namely RayIoU, to solve the inconsistency penalty along depths raised in traditional voxel-level mIoU criteria.\n\n![](asserts/rayiou.jpg)\n\nSome FAQs from the community about the evaluation metrics:\n\n1. **Why does training with visible masks result in significant improvements in the old mIoU metric, but not in the new RayIoU metric?** As mentioned in the paper, when using the visible mask during training, the area behind the surface won't be supervised, so the model tends to fill this area with duplicated predictions, leading to a thicker surface. The old metric inconsistently penalizes along the depth axis when the prediction has a thick surface. 
Thus, this ''improvement'' is mainly due to the vulnerability of the old metric.\n2. **Why can't SparseOcc exploit the vulnerability of the old metrics?** As SparseOcc employs a fully sparse architecture, it always predicts a thin surface. Thus, there are two ways for a fair comparison: (a) use the old metric, but all methods must predict a thin surface, which implies they cannot use the visible mask during training; (b) use RayIoU, as it is more reasonable and can fairly compare thick or thin surfaces. Our method achieves SOTA performance in both cases.\n3. **Does RayIoU overlook interior reconstruction?** Firstly, we are unable to obtain the interior occupancy ground-truth. This is because the ground-truth is derived from voxelizing LiDAR point clouds, and LiDARs are only capable of scanning the thin surface of an object. Secondly, the query ray in RayIoU can originate from any position within the scene (see the figure above). This allows it to evaluate the overall reconstruction performance, unlike depth estimation. We would like to emphasize that the evaluation logic of RayIoU aligns with the process of ground-truth generation.\n\nIf you have other questions, feel free to contact me (Haisong Liu, afterthat97@gmail.com).\n\n## Model Zoo\n\nThese results are from our latest version, v1.1, which outperforms the results reported in the paper. Additionally, our implementation differs slightly from the original paper. 
If you wish to reproduce the paper exactly, please refer to the [v1.0](https://github.com/MCG-NJU/SparseOcc/tree/v1.0) tag.\n\n| Setting  | Epochs | Training Cost | RayIoU | RayPQ | FPS | Weights |\n|----------|:--------:|:-------------:|:------:|:-----:|:---:|:-------:|\n| [r50_nuimg_704x256_8f](configs/r50_nuimg_704x256_8f.py) | 24 | 15h, ~12GB | 36.8 | - | 17.3 | [github](https://github.com/MCG-NJU/SparseOcc/releases/download/v1.1/sparseocc_r50_nuimg_704x256_8f_24e_v1.1.pth) |\n| [r50_nuimg_704x256_8f_60e](configs/r50_nuimg_704x256_8f_60e.py) | 60 | 37h, ~12GB | 37.7 | - | 17.3 | [github](https://github.com/MCG-NJU/SparseOcc/releases/download/v1.1/sparseocc_r50_nuimg_704x256_8f_60e_v1.1.pth) |\n| [r50_nuimg_704x256_8f_pano](configs/r50_nuimg_704x256_8f_pano.py) | 24 | 15h, ~12GB | 35.9 | 14.0 | 17.3 | [github](https://github.com/MCG-NJU/SparseOcc/releases/download/v1.1/sparseocc_r50_nuimg_704x256_8f_pano_24e_v1.1.pth) |\n\n* The backbone is pretrained on [nuImages](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth). 
Download the weights to `pretrain/xxx.pth` before you start training.\n* FPS is measured with Intel(R) Xeon(R) Platinum 8369B CPU and NVIDIA A100-SXM4-80GB GPU (PyTorch `fp32` backend, including data loading).\n* We will release more settings in the future.\n\n## Environment\n\n> The requirements are the same as those of [SparseBEV](https://github.com/MCG-NJU/SparseBEV).\n\nInstall PyTorch 2.0 + CUDA 11.8:\n\n```\nconda create -n sparseocc python=3.8\nconda activate sparseocc\nconda install pytorch==2.0.0 torchvision==0.15.0 pytorch-cuda=11.8 -c pytorch -c nvidia\n```\n\nInstall other dependencies:\n\n```\npip install openmim\nmim install mmcv-full==1.6.0\nmim install mmdet==2.28.2\nmim install mmsegmentation==0.30.0\nmim install mmdet3d==1.0.0rc6\npip install setuptools==59.5.0\npip install numpy==1.23.5\n```\n\nInstall turbojpeg and pillow-simd to speed up data loading (optional but important):\n\n```\nsudo apt-get update\nsudo apt-get install -y libturbojpeg\npip install pyturbojpeg\npip uninstall pillow\npip install pillow-simd==9.0.0.post1\n```\n\nCompile CUDA extensions:\n\n```\ncd models/csrc\npython setup.py build_ext --inplace\n```\n\n## Prepare Dataset\n\n> The first two steps are the same as those of [SparseBEV](https://github.com/MCG-NJU/SparseBEV).\n\n1. Download nuScenes from [https://www.nuscenes.org/nuscenes](https://www.nuscenes.org/nuscenes), put it to `data/nuscenes` and preprocess it with [mmdetection3d](https://github.com/open-mmlab/mmdetection3d/tree/v1.0.0rc6).\n\n2. Download the generated info file from [gdrive](https://drive.google.com/file/d/1uyoUuSRIVScrm_CUpge6V_UzwDT61ODO/view?usp=sharing) and unzip it. These `*.pkl` files can also be generated with our script: `gen_sweep_info.py`.\n\n3. Download Occ3D-nuScenes occupancy GT from [gdrive](https://drive.google.com/file/d/1kiXVNSEi3UrNERPMz_CfiJXKkgts_5dY/view?usp=drive_link), unzip it, and save it to `data/nuscenes/occ3d`.\n\n4. 
Folder structure:\n\n```\ndata/nuscenes\n├── maps\n├── nuscenes_infos_test_sweep.pkl\n├── nuscenes_infos_train_sweep.pkl\n├── nuscenes_infos_val_sweep.pkl\n├── samples\n├── sweeps\n├── v1.0-test\n├── v1.0-trainval\n└── occ3d\n    ├── scene-0001\n    │   ├── 0037a705a2e04559b1bba6c01beca1cf\n    │   │   └── labels.npz\n    │   ├── 026155aa1c554e2f87914ec9ba80acae\n    │   │   └── labels.npz\n    ...\n```\n\n5. (Optional) Generate the panoptic occupancy ground truth with `gen_instance_info.py`. The panoptic version of Occ3D will be saved to `data/nuscenes/occ3d_panoptic`.\n\n## Training\n\nTrain SparseOcc with 8 GPUs:\n\n```\ntorchrun --nproc_per_node 8 train.py --config configs/sparseocc_r50_nuimg_704x256_8f.py\n```\n\nTrain SparseOcc with 4 GPUs (i.e., the last four GPUs):\n\n```\nexport CUDA_VISIBLE_DEVICES=4,5,6,7\ntorchrun --nproc_per_node 4 train.py --config configs/sparseocc_r50_nuimg_704x256_8f.py\n```\n\nThe batch size for each GPU will be scaled automatically. So there is no need to modify the `batch_size` in config files.\n\n## Evaluation\n\nSingle-GPU evaluation:\n\n```\nexport CUDA_VISIBLE_DEVICES=0\npython val.py --config configs/sparseocc_r50_nuimg_704x256_8f.py --weights checkpoints/sparseocc_r50_nuimg_704x256_8f.pth\n```\n\nMulti-GPU evaluation:\n\n```\nexport CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7\ntorchrun --nproc_per_node 8 val.py --config configs/sparseocc_r50_nuimg_704x256_8f.py --weights checkpoints/sparseocc_r50_nuimg_704x256_8f.pth\n```\n\n## Standalone Evaluation\n\nIf you want to evaluate your own model using RayIoU, please follow the steps below:\n\n1. Save the predictions (shape=`[200x200x16]`, dtype=`np.uint8`) with the compressed `npz` format. For example:\n\n```\nsave_path = os.path.join(save_dir, sample_token + '.npz')\nnp.savez_compressed(save_path, pred=sem_pred)\n```\n\n2. 
The filename for each sample is `sample_token.npz`,  for example:\n\n```\nprediction/your_model\n├── 000681a060c04755a1537cf83b53ba57.npz\n├── 000868a72138448191b4092f75ed7776.npz\n├── 0017c2623c914571a1ff2a37f034ffd7.npz\n├── ...\n```\n\n3. Run `ray_metrics.py` to evaluate on the RayIoU:\n\n```\npython ray_metrics.py --pred-dir prediction/your_model\n```\n\n## Timing\n\nFPS is measured with a single GPU:\n\n```\nexport CUDA_VISIBLE_DEVICES=0\npython timing.py --config configs/sparseocc_r50_nuimg_704x256_8f.py --weights checkpoints/sparseocc_r50_nuimg_704x256_8f.pth\n```\n\n## Acknowledgements\n\nMany thanks to these excellent open-source projects:\n\n* [MaskFormer](https://github.com/facebookresearch/MaskFormer)\n* [NeuralRecon](https://github.com/zju3dv/NeuralRecon)\n* [4D-Occ](https://github.com/tarashakhurana/4d-occ-forecasting)\n* [MMDetection3D](https://github.com/open-mmlab/mmdetection3d)\n"
  },
  {
    "path": "configs/r50_nuimg_704x256_8f.py",
    "content": "dataset_type = 'NuSceneOcc'\ndataset_root = 'data/nuscenes/'\nocc_gt_root = 'data/nuscenes/occ3d'\n\n# If point cloud range is changed, the models should also change their point\n# cloud range accordingly\npoint_cloud_range = [-40, -40, -1.0, 40, 40, 5.4]\nocc_size = [200, 200, 16]\n\nimg_norm_cfg = dict(\n    mean=[123.675, 116.280, 103.530],\n    std=[58.395, 57.120, 57.375],\n    to_rgb=True\n)\n\n# For nuScenes we usually do 10-class detection\ndet_class_names = [\n    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',\n    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'\n]\n\nocc_class_names = [\n    'others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',\n    'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',\n    'driveable_surface', 'other_flat', 'sidewalk',\n    'terrain', 'manmade', 'vegetation', 'free'\n]\n\ninput_modality = dict(\n    use_lidar=False,\n    use_camera=True,\n    use_radar=False,\n    use_map=False,\n    use_external=False\n)\n\n_dim_ = 256\n_num_points_ = 4\n_num_groups_ = 4\n_num_layers_ = 2\n_num_frames_ = 8\n_num_queries_ = 100\n_topk_training_ = [4000, 16000, 64000]\n_topk_testing_ = [2000, 8000, 32000]\n\nmodel = dict(\n    type='SparseOcc',\n    data_aug=dict(\n        img_color_aug=True,  # Move some augmentations to GPU\n        img_norm_cfg=img_norm_cfg,\n        img_pad_cfg=dict(size_divisor=32)),\n    use_mask_camera=False,\n    img_backbone=dict(\n        type='ResNet',\n        depth=50,\n        num_stages=4,\n        out_indices=(0, 1, 2, 3),\n        frozen_stages=1,\n        norm_cfg=dict(type='BN2d', requires_grad=True),\n        norm_eval=True,\n        style='pytorch',\n        with_cp=True),\n    img_neck=dict(\n        type='FPN',\n        in_channels=[256, 512, 1024, 2048],\n        out_channels=_dim_,\n        num_outs=4),\n    pts_bbox_head=dict(\n        type='SparseOccHead',\n        class_names=occ_class_names,\n        embed_dims=_dim_,\n      
  occ_size=occ_size,\n        pc_range=point_cloud_range,\n        transformer=dict(\n            type='SparseOccTransformer',\n            embed_dims=_dim_,\n            num_layers=_num_layers_,\n            num_frames=_num_frames_,\n            num_points=_num_points_,\n            num_groups=_num_groups_,\n            num_queries=_num_queries_,\n            num_levels=4,\n            num_classes=len(occ_class_names),\n            pc_range=point_cloud_range,\n            occ_size=occ_size,\n            topk_training=_topk_training_,\n            topk_testing=_topk_testing_),\n        loss_cfgs=dict(\n            loss_mask2former=dict(\n                type='Mask2FormerLoss',\n                num_classes=len(occ_class_names),\n                no_class_weight=0.1,\n                loss_cls_weight=2.0,\n                loss_mask_weight=5.0,\n                loss_dice_weight=5.0,\n            ),\n            loss_geo_scal=dict(\n                type='GeoScalLoss',\n                num_classes=len(occ_class_names),\n                loss_weight=1.0\n            ),\n            loss_sem_scal=dict(\n                type='SemScalLoss',\n                num_classes=len(occ_class_names),\n                loss_weight=1.0\n            )\n        ),\n    ),\n)\n\nida_aug_conf = {\n    'resize_lim': (0.38, 0.55),\n    'final_dim': (256, 704),\n    'bot_pct_lim': (0.0, 0.0),\n    'rot_lim': (0.0, 0.0),\n    'H': 900, 'W': 1600,\n    'rand_flip': True,\n}\n\nbda_aug_conf = dict(\n    rot_lim=(-22.5, 22.5),\n    scale_lim=(1., 1.),\n    flip_dx_ratio=0.5,\n    flip_dy_ratio=0.5\n)\n\ntrain_pipeline = [\n    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),\n    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1),\n    dict(type='BEVAug', bda_aug_conf=bda_aug_conf, classes=det_class_names, is_train=True),\n    dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names)),\n    dict(type='RandomTransformImage', 
ida_aug_conf=ida_aug_conf, training=True),\n    dict(type='DefaultFormatBundle3D', class_names=det_class_names),\n    dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'],  # other keys: 'mask_camera'\n         meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar'))\n]\n\ntest_pipeline = [\n    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),\n    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1, test_mode=True),\n    dict(type='BEVAug', bda_aug_conf=bda_aug_conf, classes=det_class_names, is_train=False),\n    dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names)),\n    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False),\n    dict(type='DefaultFormatBundle3D', class_names=det_class_names),\n    dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'],\n         meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar'))\n]\n\ndata = dict(\n    workers_per_gpu=8,\n    train=dict(\n        type=dataset_type,\n        data_root=dataset_root,\n        occ_gt_root=occ_gt_root,\n        ann_file=dataset_root + 'nuscenes_infos_train_sweep.pkl',\n        pipeline=train_pipeline,\n        classes=det_class_names,\n        modality=input_modality,\n        test_mode=False\n    ),\n    val=dict(\n        type=dataset_type,\n        data_root=dataset_root,\n        occ_gt_root=occ_gt_root,\n        ann_file=dataset_root + 'nuscenes_infos_val_sweep.pkl',\n        pipeline=test_pipeline,\n        classes=det_class_names,\n        modality=input_modality,\n        test_mode=True\n    ),\n    test=dict(\n        type=dataset_type,\n        data_root=dataset_root,\n        occ_gt_root=occ_gt_root,\n        ann_file=dataset_root + 'nuscenes_infos_test_sweep.pkl',\n        pipeline=test_pipeline,\n        
classes=det_class_names,\n        modality=input_modality,\n        test_mode=True\n    ),\n)\n\noptimizer = dict(\n    type='AdamW',\n    lr=5e-4,\n    paramwise_cfg=dict(\n        custom_keys={\n            'img_backbone': dict(lr_mult=0.1),\n            'sampling_offset': dict(lr_mult=0.1),\n        }),\n    weight_decay=0.01\n)\noptimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))\n\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=500,\n    warmup_ratio=1.0 / 3,\n    by_epoch=True,\n    step=[22, 24],\n    gamma=0.2\n)\ntotal_epochs = 24\nbatch_size = 8\n\n# load pretrained weights\nload_from = 'pretrain/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth'\nrevise_keys = [('backbone', 'img_backbone')]\n\n# resume the last training\nresume_from = None\n\n# checkpointing\ncheckpoint_config = dict(interval=1, max_keep_ckpts=1)\n\n# logging\nlog_config = dict(\n    interval=1,\n    hooks=[\n        dict(type='MyTextLoggerHook', interval=1, reset_flag=True),\n        dict(type='MyTensorboardLoggerHook', interval=500, reset_flag=True)\n    ]\n)\n\n# evaluation\neval_config = dict(interval=total_epochs)\n\n# other flags\ndebug = False"
  },
  {
    "path": "configs/r50_nuimg_704x256_8f_60e.py",
    "content": "_base_ = ['./r50_nuimg_704x256_8f.py']\n\nlr_config = dict(\n    policy='step',\n    warmup='linear',\n    warmup_iters=500,\n    warmup_ratio=1.0 / 3,\n    by_epoch=True,\n    step=[48, 60],\n    gamma=0.2\n)\ntotal_epochs = 60\n\n# evaluation\neval_config = dict(interval=total_epochs)"
  },
  {
    "path": "configs/r50_nuimg_704x256_8f_openocc.py",
    "content": "_base_ = ['./r50_nuimg_704x256_8f.py']\n\nocc_gt_root = 'data/nuscenes/openocc_v2'\n\ndet_class_names = [\n    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',\n    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'\n]\n\nocc_class_names = [\n    'car', 'truck', 'trailer', 'bus', 'construction_vehicle',\n    'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone', 'barrier',\n    'driveable_surface', 'other_flat', 'sidewalk',\n    'terrain', 'manmade', 'vegetation', 'free'\n]\n\n_num_frames_ = 8\n\nmodel = dict(\n    pts_bbox_head=dict(\n        class_names=occ_class_names,\n        transformer=dict(\n            num_classes=len(occ_class_names)),\n        loss_cfgs=dict(\n            loss_mask2former=dict(\n                num_classes=len(occ_class_names)\n            ),\n        ),\n    ),\n)\n\nida_aug_conf = {\n    'resize_lim': (0.38, 0.55),\n    'final_dim': (256, 704),\n    'bot_pct_lim': (0.0, 0.0),\n    'rot_lim': (0.0, 0.0),\n    'H': 900, 'W': 1600,\n    'rand_flip': False,\n}\n\ntrain_pipeline = [\n    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),\n    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1),\n    dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names)),\n    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=True),\n    dict(type='DefaultFormatBundle3D', class_names=det_class_names),\n    dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'],  # other keys: 'mask_camera'\n         meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar'))\n]\n\ntest_pipeline = [\n    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),\n    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1, test_mode=True),\n    dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names)),\n    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False),\n    dict(type='DefaultFormatBundle3D', class_names=det_class_names),\n    dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'],\n         meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar'))\n]\n\ndata = dict(\n    workers_per_gpu=8,\n    train=dict(\n        pipeline=train_pipeline,\n        occ_gt_root=occ_gt_root\n    ),\n    val=dict(\n        pipeline=test_pipeline,\n        occ_gt_root=occ_gt_root\n    ),\n    test=dict(\n        pipeline=test_pipeline,\n        occ_gt_root=occ_gt_root\n    ),\n)"
  },
  {
    "path": "configs/r50_nuimg_704x256_8f_pano.py",
    "content": "_base_ = ['./r50_nuimg_704x256_8f.py']\n\nocc_gt_root = 'data/nuscenes/occ3d_panoptic'\n\n# For nuScenes we usually do 10-class detection\ndet_class_names = [\n    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',\n    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'\n]\n\nocc_class_names = [\n    'others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',\n    'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',\n    'driveable_surface', 'other_flat', 'sidewalk',\n    'terrain', 'manmade', 'vegetation', 'free'\n]\n\n_num_frames_ = 8\n\nmodel = dict(\n    pts_bbox_head=dict(\n        panoptic=True\n    )\n)\n\nida_aug_conf = {\n    'resize_lim': (0.38, 0.55),\n    'final_dim': (256, 704),\n    'bot_pct_lim': (0.0, 0.0),\n    'rot_lim': (0.0, 0.0),\n    'H': 900, 'W': 1600,\n    'rand_flip': True,\n}\n\nbda_aug_conf = dict(\n    rot_lim=(-22.5, 22.5),\n    scale_lim=(1., 1.),\n    flip_dx_ratio=0.5,\n    flip_dy_ratio=0.5\n)\n\ntrain_pipeline = [\n    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),\n    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1),\n    dict(type='BEVAug', bda_aug_conf=bda_aug_conf, classes=det_class_names, is_train=True),\n    dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names), inst_class_ids=[2, 3, 4, 5, 6, 7, 9, 10]),\n    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=True),\n    dict(type='DefaultFormatBundle3D', class_names=det_class_names),\n    dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'],  # other keys: 'mask_camera'\n         meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar'))\n]\n\ntest_pipeline = [\n    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),\n    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1, 
test_mode=True),\n    dict(type='BEVAug', bda_aug_conf=bda_aug_conf, classes=det_class_names, is_train=False),\n    dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names), inst_class_ids=[2, 3, 4, 5, 6, 7, 9, 10]),\n    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False),\n    dict(type='DefaultFormatBundle3D', class_names=det_class_names),\n    dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'],\n         meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar'))\n]\n\ndata = dict(\n    workers_per_gpu=8,\n    train=dict(\n        pipeline=train_pipeline,\n        occ_gt_root=occ_gt_root\n    ),\n    val=dict(\n        pipeline=test_pipeline,\n        occ_gt_root=occ_gt_root\n    ),\n    test=dict(\n        pipeline=test_pipeline,\n        occ_gt_root=occ_gt_root\n    ),\n)\n"
  },
  {
    "path": "gen_instance_info.py",
    "content": "import os\nimport tqdm\nimport glob\nimport pickle\nimport argparse\nimport numpy as np\nimport torch\nimport multiprocessing\nfrom pyquaternion import Quaternion\nfrom nuscenes.utils.data_classes import Box\nfrom nuscenes.utils.geometry_utils import points_in_box\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--nusc-root', default='data/nuscenes')\nparser.add_argument('--occ3d-root', default='data/nuscenes/occ3d')\nparser.add_argument('--output-dir', default='data/nuscenes/occ3d_panoptic')\nparser.add_argument('--version', default='v1.0-trainval')\nargs = parser.parse_args()\n\ntoken2path = {}\nfor gt_path in glob.glob(os.path.join(args.occ3d_root, '*/*/*.npz')):\n    token = gt_path.split('/')[-2]\n    token2path[token] = gt_path\n\nocc_class_names = [\n    'others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',\n    'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',\n    'driveable_surface', 'other_flat', 'sidewalk',\n    'terrain', 'manmade', 'vegetation', 'free'\n]\n\ndet_class_names = [\n    'car', 'truck', 'trailer', 'bus', 'construction_vehicle',\n    'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'\n]\n\n\ndef convert_to_nusc_box(bboxes, lift_center=False, wlh_margin=0.0):\n    results = []\n    for q in range(bboxes.shape[0]):\n\n        bbox = bboxes[q].copy()\n        if lift_center:\n            bbox[2] += bbox[5] * 0.5\n\n        bbox_yaw = -bbox[6] - np.pi / 2\n        orientation = Quaternion(axis=[0, 0, 1], radians=bbox_yaw).inverse\n\n        box = Box(\n            center=[bbox[0], bbox[1], bbox[2]],\n            # 0.8 in pc range is roungly 2 voxels in occ grid\n            # enlarge bbox to include voxels on the edge\n            size=[bbox[3]+wlh_margin, bbox[4]+wlh_margin, bbox[5]+wlh_margin],\n            orientation=orientation,\n        )\n\n        results.append(box)\n\n    return results\n\n\ndef meshgrid3d(occ_size, pc_range):  # points in ego coord\n    W, H, D 
= occ_size\n    \n    xs = torch.linspace(0.5, W - 0.5, W).view(W, 1, 1).expand(W, H, D) / W\n    ys = torch.linspace(0.5, H - 0.5, H).view(1, H, 1).expand(W, H, D) / H\n    zs = torch.linspace(0.5, D - 0.5, D).view(1, 1, D).expand(W, H, D) / D\n    xs = xs * (pc_range[3] - pc_range[0]) + pc_range[0]\n    ys = ys * (pc_range[4] - pc_range[1]) + pc_range[1]\n    zs = zs * (pc_range[5] - pc_range[2]) + pc_range[2]\n    xyz = torch.stack((xs, ys, zs), -1)\n\n    return xyz\n\n\ndef process_add_instance_info(sample):\n    point_cloud_range = [-40, -40, -1.0, 40, 40, 5.4]\n    occ_size = [200, 200, 16]\n    num_classes = 18\n    \n    occ_gt_path = token2path[sample['token']]\n    occ_labels = np.load(occ_gt_path)\n    \n    occ_gt = occ_labels['semantics']\n    gt_boxes = sample['gt_boxes']\n    gt_names = sample['gt_names']\n    \n    bboxes = convert_to_nusc_box(gt_boxes)\n    \n    instance_gt = np.zeros(occ_gt.shape).astype(np.uint8)\n    instance_id = 1\n    \n    pts = meshgrid3d(occ_size, point_cloud_range).numpy()\n    \n    # filter out free voxels to accelerate\n    valid_idx = np.where(occ_gt < num_classes - 1)\n    flatten_occ_gt = occ_gt[valid_idx]\n    flatten_inst_gt = instance_gt[valid_idx]\n    flatten_pts = pts[valid_idx]\n    \n    instance_boxes = []\n    instance_class_ids = []\n    \n    for i in range(len(gt_names)):\n        if gt_names[i] not in occ_class_names:\n            continue\n        occ_tag_id = occ_class_names.index(gt_names[i])\n            \n        # Move box to ego vehicle coord system\n        bbox = bboxes[i]\n        bbox.rotate(Quaternion(sample['lidar2ego_rotation']))\n        bbox.translate(np.array(sample['lidar2ego_translation']))\n        \n        mask = points_in_box(bbox, flatten_pts.transpose(1, 0))\n        \n        # ignore voxels not belonging to this class\n        mask[mask] = (flatten_occ_gt[mask] == occ_tag_id)\n        # ignore voxels already occupied\n        mask[mask] = (flatten_inst_gt[mask] == 0)\n      
  \n        # only instance with at least 1 voxel will be recorded\n        if mask.sum() > 0:\n            flatten_inst_gt[mask] = instance_id\n            instance_id += 1\n            \n            # enlarge boxes to include voxels on the edge\n            new_box = bbox.copy()\n            new_box.wlh = new_box.wlh + 0.8\n            \n            instance_boxes.append(new_box)\n            instance_class_ids.append(occ_tag_id)\n    \n    # classes that should be viewed as one instance\n    all_class_ids_unique = np.unique(occ_gt)\n    for i, class_name in enumerate(occ_class_names):\n        if class_name in det_class_names or class_name == 'free' or i not in all_class_ids_unique:\n            continue\n        flatten_inst_gt[flatten_occ_gt == i] = instance_id\n        instance_id += 1\n    \n    # post process unconvered non-occupied voxels\n    uncover_idx = np.where(flatten_inst_gt == 0)\n    uncover_pts = flatten_pts[uncover_idx]\n    uncover_inst_gt = np.zeros_like(uncover_pts[..., 0]).astype(np.uint8)\n    unconver_occ_gt = flatten_occ_gt[uncover_idx]\n    \n    # uncover_inst_dist records the dist between each voxel and its current nearest bbox's center\n    uncover_inst_dist = np.ones_like(uncover_pts[..., 0]) * 1e8\n    for i, box in enumerate(instance_boxes):\n        # important, non-background inst id starts from 1\n        inst_id = i + 1\n        class_id = instance_class_ids[i]\n        mask = points_in_box(box, uncover_pts.transpose(1, 0))\n        # mask voxels not belonging to this class\n        mask[unconver_occ_gt != class_id] = False\n        dist = np.sum((box.center - uncover_pts) ** 2, axis=-1)\n        # voxels that have already been assigned to a closer box's instance should be ignored\n        # voxels that not inside the box should be ignored\n        # `mask[(dist >= uncover_inst_dist)]=False` is right, as it only transforms True masks into False without converting False into True\n        # to give readers a more clear 
understanding, the most standard writing is `mask[mask & (dist >= uncover_inst_dist)]=False`\n        mask[dist >= uncover_inst_dist] = False\n        # mask[mask & (dist >= uncover_inst_dist)]=False\n        \n        # important: only voxels inside the box (mask = True) and having no closer identical-class box need to update dist\n        uncover_inst_dist[mask] = dist[mask]\n        uncover_inst_gt[mask] = inst_id\n        \n    flatten_inst_gt[uncover_idx] = uncover_inst_gt\n    \n    instance_gt[valid_idx] = flatten_inst_gt\n    # not using this checking function yet\n    # assert (instance_gt == 0).sum() - (occ_gt == num_classes-1).sum() < 100, \"too many non-free voxels are not assigned to any instance in %s\"%(occ_gt_path)\n    # global max_margin\n    # if max_margin < (instance_gt == 0).sum() - (occ_gt == num_classes-1).sum():\n    #     print(\"###### new max margin: \", max(max_margin, (instance_gt == 0).sum() - (occ_gt == num_classes-1).sum()))\n    # max_margin = max(max_margin, (instance_gt == 0).sum() - (occ_gt == num_classes-1).sum())\n    \n    # save to original path\n    data_split = occ_gt_path.split(os.path.sep)[-3:]\n    data_path = os.path.sep.join(data_split)\n    \n    ##### Warning: Using args.xxx (global variable) here is strongly unrecommended\n    save_path = os.path.join(args.output_dir, data_path)\n    \n    save_dir = os.path.split(save_path)[0]\n    if not os.path.exists(save_dir):\n        os.makedirs(save_dir)\n    \n    if np.unique(instance_gt).shape[0] != instance_gt.max()+1:\n        print('warning: some instance masks are covered by following ones %s'%(save_dir))\n    \n    # only semantic and mask information is needed to be reserved\n    retain_keys = ['semantics', 'mask_lidar', 'mask_camera']   \n    new_occ_labels = {k: occ_labels[k] for k in retain_keys}\n    new_occ_labels['instances'] = instance_gt\n    np.savez_compressed(save_path, **new_occ_labels)\n\n\ndef add_instance_info(sample_infos):\n    if not 
os.path.exists(args.output_dir):\n        os.makedirs(args.output_dir)\n    \n    # all cpus participate in multi processing\n    pool = multiprocessing.Pool(multiprocessing.cpu_count())\n    with tqdm.tqdm(total=len(sample_infos['infos'])) as pbar:\n        for _ in pool.imap(process_add_instance_info, sample_infos['infos']):\n            pbar.update(1)\n    \n    pool.close()\n    pool.join()\n\n\nif __name__ == '__main__':\n    if args.version == 'v1.0-trainval':\n        sample_infos = pickle.load(open(os.path.join(args.nusc_root, 'nuscenes_infos_train_sweep.pkl'), 'rb'))\n        add_instance_info(sample_infos)\n\n        sample_infos = pickle.load(open(os.path.join(args.nusc_root, 'nuscenes_infos_val_sweep.pkl'), 'rb'))\n        add_instance_info(sample_infos)\n\n    elif args.version == 'v1.0-test':\n        sample_infos = pickle.load(open(os.path.join(args.nusc_root, 'nuscenes_infos_test_sweep.pkl'), 'rb'))\n        add_instance_info(sample_infos)\n\n    else:\n        raise ValueError\n"
  },
  {
    "path": "gen_sweep_info.py",
    "content": "# Generate info files manually\nimport os\nimport mmcv\nimport tqdm\nimport pickle\nimport argparse\nimport numpy as np\nfrom nuscenes import NuScenes\nfrom pyquaternion import Quaternion\n\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--data-root', default='data/nuscenes')\nparser.add_argument('--version', default='v1.0-trainval')\nargs = parser.parse_args()\n\n\ndef get_cam_info(nusc, sample_data):\n    pose_record = nusc.get('ego_pose', sample_data['ego_pose_token'])\n    cs_record = nusc.get('calibrated_sensor', sample_data['calibrated_sensor_token'])\n    \n    sensor2ego_translation = cs_record['translation']\n    ego2global_translation = pose_record['translation']\n    sensor2ego_rotation = Quaternion(cs_record['rotation']).rotation_matrix\n    ego2global_rotation = Quaternion(pose_record['rotation']).rotation_matrix\n    cam_intrinsic = np.array(cs_record['camera_intrinsic'])\n\n    sensor2global_rotation = sensor2ego_rotation.T @ ego2global_rotation.T\n    sensor2global_translation = sensor2ego_translation @ ego2global_rotation.T + ego2global_translation\n\n    return {\n        'data_path': os.path.join(args.data_root, sample_data['filename']),\n        'sensor2global_rotation': sensor2global_rotation,\n        'sensor2global_translation': sensor2global_translation,\n        'cam_intrinsic': cam_intrinsic,\n        'timestamp': sample_data['timestamp'],\n    }\n\n\ndef add_sweep_info(nusc, sample_infos):\n    for curr_id in tqdm.tqdm(range(len(sample_infos['infos']))):\n        sample = nusc.get('sample', sample_infos['infos'][curr_id]['token'])\n\n        cam_types = [\n            'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_RIGHT',\n            'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_FRONT_LEFT'\n        ]\n\n        curr_cams = dict()\n        for cam in cam_types:\n            curr_cams[cam] = nusc.get('sample_data', sample['data'][cam])\n\n        for cam in cam_types:\n            sample_data = nusc.get('sample_data', 
sample['data'][cam])\n            sweep_cam = get_cam_info(nusc, sample_data)\n            sample_infos['infos'][curr_id]['cams'][cam].update(sweep_cam)\n\n        # remove unnecessary\n        for cam in cam_types:\n            del sample_infos['infos'][curr_id]['cams'][cam]['sensor2ego_translation']\n            del sample_infos['infos'][curr_id]['cams'][cam]['sensor2ego_rotation']\n            del sample_infos['infos'][curr_id]['cams'][cam]['ego2global_translation']\n            del sample_infos['infos'][curr_id]['cams'][cam]['ego2global_rotation']\n\n        sweep_infos = []\n        if sample['prev'] != '':  # add sweep frame between two key frame\n            for _ in range(5):\n                sweep_info = dict()\n                for cam in cam_types: \n                    if curr_cams[cam]['prev'] == '':    \n                        sweep_info = sweep_infos[-1] \n                        break\n                    sample_data = nusc.get('sample_data', curr_cams[cam]['prev'])\n                    sweep_cam = get_cam_info(nusc, sample_data)\n                    curr_cams[cam] = sample_data\n                    sweep_info[cam] = sweep_cam\n                sweep_infos.append(sweep_info)\n\n        sample_infos['infos'][curr_id]['sweeps'] = sweep_infos\n\n    return sample_infos\n\n\nif __name__ == '__main__':\n    nusc = NuScenes(args.version, args.data_root)\n\n    if args.version == 'v1.0-trainval':\n        sample_infos = pickle.load(open(os.path.join(args.data_root, 'nuscenes_infos_train.pkl'), 'rb'))\n        sample_infos = add_sweep_info(nusc, sample_infos)\n        mmcv.dump(sample_infos, os.path.join(args.data_root, 'nuscenes_infos_train_sweep.pkl'))\n\n        sample_infos = pickle.load(open(os.path.join(args.data_root, 'nuscenes_infos_val.pkl'), 'rb'))\n        sample_infos = add_sweep_info(nusc, sample_infos)\n        mmcv.dump(sample_infos, os.path.join(args.data_root, 'nuscenes_infos_val_sweep.pkl'))\n\n    elif args.version == 'v1.0-test':\n        
sample_infos = pickle.load(open(os.path.join(args.data_root, 'nuscenes_infos_test.pkl'), 'rb'))\n        sample_infos = add_sweep_info(nusc, sample_infos)\n        mmcv.dump(sample_infos, os.path.join(args.data_root, 'nuscenes_infos_test_sweep.pkl'))\n\n    else:\n        raise ValueError\n"
  },
  {
    "path": "lib/dvr/dvr.cpp",
    "content": "// Acknowledgments: https://github.com/tarashakhurana/4d-occ-forecasting\n// Modified by Haisong Liu\n\n#include <string>\n#include <torch/extension.h>\n#include <vector>\n\n/*\n * CUDA forward declarations\n */\n\nstd::vector<torch::Tensor> render_forward_cuda(torch::Tensor sigma,\n                                               torch::Tensor origin,\n                                               torch::Tensor points,\n                                               torch::Tensor tindex,\n                                               const std::vector<int> grid,\n                                               std::string phase_name);\n\nstd::vector<torch::Tensor>\nrender_cuda(torch::Tensor sigma, torch::Tensor origin, torch::Tensor points,\n            torch::Tensor tindex, std::string loss_name);\n\ntorch::Tensor init_cuda(torch::Tensor points, torch::Tensor tindex,\n                        const std::vector<int> grid);\n\n\n/*\n * C++ interface\n */\n\n#define CHECK_CUDA(x)                                                          \\\n  TORCH_CHECK(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x)                                                    \\\n  TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x)                                                         \\\n  CHECK_CUDA(x);                                                               \\\n  CHECK_CONTIGUOUS(x)\n\nstd::vector<torch::Tensor>\nrender_forward(torch::Tensor sigma, torch::Tensor origin, torch::Tensor points,\n               torch::Tensor tindex, const std::vector<int> grid,\n               std::string phase_name) {\n  CHECK_INPUT(sigma);\n  CHECK_INPUT(origin);\n  CHECK_INPUT(points);\n  CHECK_INPUT(tindex);\n  return render_forward_cuda(sigma, origin, points, tindex, grid, phase_name);\n}\n\n\nstd::vector<torch::Tensor> render(torch::Tensor sigma, torch::Tensor origin,\n                                  torch::Tensor 
points, torch::Tensor tindex,\n                                  std::string loss_name) {\n  CHECK_INPUT(sigma);\n  CHECK_INPUT(origin);\n  CHECK_INPUT(points);\n  CHECK_INPUT(tindex);\n  return render_cuda(sigma, origin, points, tindex, loss_name);\n}\n\ntorch::Tensor init(torch::Tensor points, torch::Tensor tindex,\n                   const std::vector<int> grid) {\n  CHECK_INPUT(points);\n  CHECK_INPUT(tindex);\n  return init_cuda(points, tindex, grid);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"init\", &init, \"Initialize\");\n  m.def(\"render\", &render, \"Render\");\n  m.def(\"render_forward\", &render_forward, \"Render (forward pass only)\");\n}\n"
  },
  {
    "path": "lib/dvr/dvr.cu",
    "content": "// Acknowledgments: https://github.com/tarashakhurana/4d-occ-forecasting\n// Modified by Haisong Liu\n\n#include <torch/extension.h>\n#include <stdio.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <vector>\n#include <string>\n#include <iostream>\n\n#define MAX_D 1446 // 700 + 700 + 45 + 1\n#define MAX_STEP 1000\n\nenum LossType {L1, L2, ABSREL};\nenum PhaseName {TEST, TRAIN};\n\ntemplate <typename scalar_t>\n__global__ void init_cuda_kernel(\n    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> points,\n    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> tindex,\n    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> occupancy) {\n\n    // batch index\n    const auto n = blockIdx.y;\n\n    // ray index\n    const auto c = blockIdx.x * blockDim.x + threadIdx.x;\n\n    // num of rays\n    const auto M = points.size(1);\n    const auto T = occupancy.size(1);\n\n    // we allocated more threads than num_rays\n    if (c < M) {\n        // ray end point\n        const auto t = tindex[n][c];\n\n        // invalid points\n        assert(T == 1 || t < T);\n\n        // if t < 0, it is a padded point\n        if (t < 0) return;\n\n        // time index for sigma\n        // when T = 1, we have a static sigma\n        const auto ts = (T == 1) ? 
0 : t;\n\n        // grid shape\n        const int vzsize = occupancy.size(2);\n        const int vysize = occupancy.size(3);\n        const int vxsize = occupancy.size(4);\n        // assert(vzsize + vysize + vxsize <= MAX_D);\n\n        // end point\n        const int vx = int(points[n][c][0]);\n        const int vy = int(points[n][c][1]);\n        const int vz = int(points[n][c][2]);\n\n        //\n        if (0 <= vx && vx < vxsize &&\n            0 <= vy && vy < vysize &&\n            0 <= vz && vz < vzsize) {\n            occupancy[n][ts][vz][vy][vx] = 1;\n        }\n    }\n}\n\ntemplate <typename scalar_t>\n__global__ void render_forward_cuda_kernel(\n    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> sigma,\n    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> origin,\n    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> points,\n    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> tindex,\n    // torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> pog,\n    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> pred_dist,\n    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> gt_dist,\n    torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> coord_index,\n    PhaseName train_phase) {\n\n    // batch index\n    const auto n = blockIdx.y;\n\n    // ray index\n    const auto c = blockIdx.x * blockDim.x + threadIdx.x;\n\n    // num of rays\n    const auto M = points.size(1);\n    const auto T = sigma.size(1);\n\n    // we allocated more threads than num_rays\n    if (c < M) {\n        // ray end point\n        const auto t = tindex[n][c];\n\n        // invalid points\n        // assert(t < T);\n        assert(T == 1 || t < T);\n\n        // time index for sigma\n        // when T = 1, we have a static sigma\n        const auto ts = (T == 1) ? 
0 : t;\n\n        // if t < 0, it is a padded point\n        if (t < 0) return;\n\n        // grid shape\n        const int vzsize = sigma.size(2);\n        const int vysize = sigma.size(3);\n        const int vxsize = sigma.size(4);\n        // assert(vzsize + vysize + vxsize <= MAX_D);\n\n        // origin\n        const double xo = origin[n][t][0];\n        const double yo = origin[n][t][1];\n        const double zo = origin[n][t][2];\n\n        // end point\n        const double xe = points[n][c][0];\n        const double ye = points[n][c][1];\n        const double ze = points[n][c][2];\n\n        // locate the voxel where the origin resides\n        const int vxo = int(xo);\n        const int vyo = int(yo);\n        const int vzo = int(zo);\n\n        const int vxe = int(xe);\n        const int vye = int(ye);\n        const int vze = int(ze);\n\n        // NOTE: new\n        int vx = vxo;\n        int vy = vyo;\n        int vz = vzo;\n\n        // origin to end\n        const double rx = xe - xo;\n        const double ry = ye - yo;\n        const double rz = ze - zo;\n        double gt_d = sqrt(rx * rx + ry * ry + rz * rz);\n\n        // directional vector\n        const double dx = rx / gt_d;\n        const double dy = ry / gt_d;\n        const double dz = rz / gt_d;\n\n        // In which direction the voxel ids are incremented.\n        const int stepX = (dx >= 0) ? 1 : -1;\n        const int stepY = (dy >= 0) ? 1 : -1;\n        const int stepZ = (dz >= 0) ? 1 : -1;\n\n        // Distance along the ray to the next voxel border from the current position (tMaxX, tMaxY, tMaxZ).\n        const double next_voxel_boundary_x = vx + (stepX < 0 ? 0 : 1);\n        const double next_voxel_boundary_y = vy + (stepY < 0 ? 0 : 1);\n        const double next_voxel_boundary_z = vz + (stepZ < 0 ? 
0 : 1);\n\n        // tMaxX, tMaxY, tMaxZ -- distance until next intersection with voxel-border\n        // the value of t at which the ray crosses the first vertical voxel boundary\n        double tMaxX = (dx!=0) ? (next_voxel_boundary_x - xo)/dx : DBL_MAX; //\n        double tMaxY = (dy!=0) ? (next_voxel_boundary_y - yo)/dy : DBL_MAX; //\n        double tMaxZ = (dz!=0) ? (next_voxel_boundary_z - zo)/dz : DBL_MAX; //\n\n        // tDeltaX, tDeltaY, tDeltaZ --\n        // how far along the ray we must move for the horizontal component to equal the width of a voxel\n        // the direction in which we traverse the grid\n        // can only be FLT_MAX if we never go in that direction\n        const double tDeltaX = (dx!=0) ? stepX/dx : DBL_MAX;\n        const double tDeltaY = (dy!=0) ? stepY/dy : DBL_MAX;\n        const double tDeltaZ = (dz!=0) ? stepZ/dz : DBL_MAX;\n\n        int3 path[MAX_D];\n        double csd[MAX_D];  // cumulative sum of sigma times delta\n        double p[MAX_D];  // alpha\n        double d[MAX_D];\n\n        // forward raymarching with voxel traversal\n        int step = 0;  // total number of voxels traversed\n        int count = 0;  // number of voxels traversed inside the voxel grid\n        double last_d = 0.0;  // correct initialization\n\n        // voxel traversal raycasting\n        bool was_inside = false;\n        while (true) {\n            bool inside = (0 <= vx && vx < vxsize) &&\n                (0 <= vy && vy < vysize) &&\n                (0 <= vz && vz < vzsize);\n            if (inside) {\n                was_inside = true;\n                path[count] = make_int3(vx, vy, vz);\n            } else if (was_inside) { // was but no longer inside\n                // we know we are not coming back so terminate\n                break;\n            } /*else if (last_d > gt_d) {\n                break;\n            } */\n            /*else { // has not gone inside yet\n                // assert(count == 0);\n                // (1) 
when we have hit the destination but haven't gone inside the voxel grid\n                // (2) when we have traveled MAX_D voxels but haven't found one valid voxel\n                //     handle intersection corner cases in case of infinite loop\n                bool hit = (vx == vxe && vy == vye && vz == vze);  // this test seems brittle with corner cases\n                if (hit || step >= MAX_D)\n                    break;\n                //if (last_d >= gt_d || step >= MAX_D) break;\n            } */\n            // _d represents the ray distance has traveled before escaping the current voxel cell\n            double _d = 0.0;\n            // voxel traversal\n            if (tMaxX < tMaxY) {\n                if (tMaxX < tMaxZ) {\n                    _d = tMaxX;\n                    vx += stepX;\n                    tMaxX += tDeltaX;\n                } else {\n                    _d = tMaxZ;\n                    vz += stepZ;\n                    tMaxZ += tDeltaZ;\n                }\n            } else {\n                if (tMaxY < tMaxZ) {\n                    _d = tMaxY;\n                    vy += stepY;\n                    tMaxY += tDeltaY;\n                } else {\n                    _d = tMaxZ;\n                    vz += stepZ;\n                    tMaxZ += tDeltaZ;\n                }\n            }\n            if (inside) {\n                // get sigma at the current voxel\n                const int3 &v = path[count];  // use the recorded index\n                const double _sigma = sigma[n][ts][v.z][v.y][v.x];\n                const double _delta = max(0.0, _d - last_d);  // THIS TURNS OUT IMPORTANT\n                const double sd = _sigma * _delta;\n                if (count == 0) { // the first voxel inside\n                    csd[count] = sd;\n                    p[count] = 1 - exp(-sd);\n                } else {\n                    csd[count] = csd[count-1] + sd;\n                    p[count] = exp(-csd[count-1]) - exp(-csd[count]);\n        
        }\n                // record the traveled distance\n                d[count] = _d;\n                // count the number of voxels we have escaped\n                count ++;\n            }\n            last_d = _d;\n            step ++;\n\n            if (step > MAX_STEP) {\n                break;\n            }\n        }\n\n        // the total number of voxels visited should not exceed this number\n        assert(count <= MAX_D);\n        \n        if (count > 0) {\n            // compute the expected ray distance\n            //double exp_d = 0.0;\n            double exp_d = d[count-1];\n            \n            const int3 &v_init = path[count-1];\n            int x = v_init.x;\n            int y = v_init.y;\n            int z = v_init.z;\n\n            for (int i = 0; i < count; i++) {\n                //printf(\"%f\\t%f\\n\",p[i], d[i]);\n                //exp_d += p[i] * d[i];\n                const int3 &v = path[i];\n                const double occ = sigma[n][ts][v.z][v.y][v.x];\n                if (occ > 0.5) {\n                    exp_d = d[i];\n                    \n                    x = v.x;\n                    y = v.y;\n                    z = v.z;\n                \n                    break;\n                }\n\n            }\n            //printf(\"%f\\n\",exp_d);\n\n            // add an imaginary sample at the end point should gt_d exceeds max_d\n            double p_out = exp(-csd[count-1]);\n            double max_d = d[count-1];\n\n            // if (gt_d > max_d)\n            //   exp_d += (p_out * gt_d);\n\n            // p_out is the probability the ray escapes the voxel grid\n            //exp_d += (p_out * max_d);\n            if (train_phase == 1) {\n                gt_d = min(gt_d, max_d);\n            }\n\n            // write the rendered ray distance (max_d)\n            pred_dist[n][c] = exp_d;\n            gt_dist[n][c] = gt_d;\n          \n            coord_index[n][c][0] = double(x);\n            coord_index[n][c][1] 
= double(y);\n            coord_index[n][c][2] = double(z);\n\n            // // write occupancy\n            // for (int i = 0; i < count; i ++) {\n            //     const int3 &v = path[i];\n            //     auto & occ = pog[n][t][v.z][v.y][v.x];\n            //     if (p[i] >= occ) {\n            //         occ = p[i];\n            //     }\n            // }\n        }\n    }\n}\n\n/*\n * input shape\n *   sigma      : N x T x H x L x W\n *   origin   : N x T x 3\n *   points   : N x M x 4\n * output shape\n *   dist     : N x M\n */\nstd::vector<torch::Tensor> render_forward_cuda(\n    torch::Tensor sigma,\n    torch::Tensor origin,\n    torch::Tensor points,\n    torch::Tensor tindex,\n    const std::vector<int> grid,\n    std::string phase_name) {\n\n    const auto N = points.size(0); // batch size\n    const auto M = points.size(1); // num of rays\n\n    const auto T = grid[0];\n    const auto H = grid[1];\n    const auto L = grid[2];\n    const auto W = grid[3];\n\n    const auto device = sigma.device();\n\n    const int threads = 1024;\n    const dim3 blocks((M + threads - 1) / threads, N);\n\n    //\n    // const auto dtype = points.dtype();\n    // const auto options = torch::TensorOptions().dtype(dtype).device(device).requires_grad(false);\n    // auto pog = torch::zeros({N, T, H, L, W}, options);\n\n    // perform rendering\n    auto gt_dist = -torch::ones({N, M}, device);\n    auto pred_dist = -torch::ones({N, M}, device);\n\n    auto coord_index = torch::zeros({N, M, 3}, device);\n\n    PhaseName train_phase;\n    if (phase_name.compare(\"test\") == 0) {\n        train_phase = TEST;\n    } else if (phase_name.compare(\"train\") == 0){\n        train_phase = TRAIN;\n    } else {\n        std::cout << \"UNKNOWN PHASE NAME: \" << phase_name << std::endl;\n        exit(1);\n    }\n\n    AT_DISPATCH_FLOATING_TYPES(sigma.type(), \"render_forward_cuda\", ([&] {\n                render_forward_cuda_kernel<scalar_t><<<blocks, threads>>>(\n                  
  sigma.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),\n                    origin.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),\n                    points.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),\n                    tindex.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),\n                    // pog.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),\n                    pred_dist.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),\n                    gt_dist.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),\n                    coord_index.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),\n                    train_phase);\n            }));\n\n    cudaDeviceSynchronize();\n\n    // return {pog, pred_dist, gt_dist};\n    return {pred_dist, gt_dist, coord_index};\n}\n\ntemplate <typename scalar_t>\n__global__ void render_cuda_kernel(\n    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> sigma,\n    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> origin,\n    const torch::PackedTensorAccessor32<scalar_t,3,torch::RestrictPtrTraits> points,\n    const torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> tindex,\n    // const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> occupancy,\n    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> pred_dist,\n    torch::PackedTensorAccessor32<scalar_t,2,torch::RestrictPtrTraits> gt_dist,\n    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> grad_sigma,\n    // torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> grad_sigma_count,\n    LossType loss_type) {\n\n    // batch index\n    const auto n = blockIdx.y;\n\n    // ray index\n    const auto c = blockIdx.x * blockDim.x + threadIdx.x;\n\n    // num of rays\n    const auto M = points.size(1);\n    const auto T = sigma.size(1);\n\n    // we allocated more threads than 
num_rays\n    if (c < M) {\n        // ray end point\n        const auto t = tindex[n][c];\n\n        // invalid points\n        // assert(t < T);\n        assert(T == 1 || t < T);\n\n        // time index for sigma\n        // when T = 1, we have a static sigma\n        const auto ts = (T == 1) ? 0 : t;\n\n        // if t < 0, it is a padded point\n        if (t < 0) return;\n\n        // grid shape\n        const int vzsize = sigma.size(2);\n        const int vysize = sigma.size(3);\n        const int vxsize = sigma.size(4);\n        // assert(vzsize + vysize + vxsize <= MAX_D);\n\n        // origin\n        const double xo = origin[n][t][0];\n        const double yo = origin[n][t][1];\n        const double zo = origin[n][t][2];\n\n        // end point\n        const double xe = points[n][c][0];\n        const double ye = points[n][c][1];\n        const double ze = points[n][c][2];\n\n        // locate the voxel where the origin resides\n        const int vxo = int(xo);\n        const int vyo = int(yo);\n        const int vzo = int(zo);\n\n        //\n        const int vxe = int(xe);\n        const int vye = int(ye);\n        const int vze = int(ze);\n\n        // NOTE: new\n        int vx = vxo;\n        int vy = vyo;\n        int vz = vzo;\n\n        // origin to end\n        const double rx = xe - xo;\n        const double ry = ye - yo;\n        const double rz = ze - zo;\n        double gt_d = sqrt(rx * rx + ry * ry + rz * rz);\n\n        // directional vector\n        const double dx = rx / gt_d;\n        const double dy = ry / gt_d;\n        const double dz = rz / gt_d;\n\n        // In which direction the voxel ids are incremented.\n        const int stepX = (dx >= 0) ? 1 : -1;\n        const int stepY = (dy >= 0) ? 1 : -1;\n        const int stepZ = (dz >= 0) ? 1 : -1;\n\n        // Distance along the ray to the next voxel border from the current position (tMaxX, tMaxY, tMaxZ).\n        const double next_voxel_boundary_x = vx + (stepX < 0 ? 
0 : 1);\n        const double next_voxel_boundary_y = vy + (stepY < 0 ? 0 : 1);\n        const double next_voxel_boundary_z = vz + (stepZ < 0 ? 0 : 1);\n\n        // tMaxX, tMaxY, tMaxZ -- distance until next intersection with voxel-border\n        // the value of t at which the ray crosses the first vertical voxel boundary\n        double tMaxX = (dx!=0) ? (next_voxel_boundary_x - xo)/dx : DBL_MAX; //\n        double tMaxY = (dy!=0) ? (next_voxel_boundary_y - yo)/dy : DBL_MAX; //\n        double tMaxZ = (dz!=0) ? (next_voxel_boundary_z - zo)/dz : DBL_MAX; //\n\n        // tDeltaX, tDeltaY, tDeltaZ --\n        // how far along the ray we must move for the horizontal component to equal the width of a voxel\n        // the direction in which we traverse the grid\n        // can only be FLT_MAX if we never go in that direction\n        const double tDeltaX = (dx!=0) ? stepX/dx : DBL_MAX;\n        const double tDeltaY = (dy!=0) ? stepY/dy : DBL_MAX;\n        const double tDeltaZ = (dz!=0) ? 
stepZ/dz : DBL_MAX;\n\n        int3 path[MAX_D];\n        double csd[MAX_D];  // cumulative sum of sigma times delta\n        double p[MAX_D];  // alpha\n        double d[MAX_D];\n        double dt[MAX_D];\n\n        // forward raymarching with voxel traversal\n        int step = 0;  // total number of voxels traversed\n        int count = 0;  // number of voxels traversed inside the voxel grid\n        double last_d = 0.0;  // correct initialization\n\n        // voxel traversal raycasting\n        bool was_inside = false;\n        while (true) {\n            bool inside = (0 <= vx && vx < vxsize) &&\n                (0 <= vy && vy < vysize) &&\n                (0 <= vz && vz < vzsize);\n            if (inside) { // now inside\n                was_inside = true;\n                path[count] = make_int3(vx, vy, vz);\n            } else if (was_inside) { // was inside but no longer\n                // we know we are not coming back so terminate\n                break;\n            } else if (last_d > gt_d) {\n                break;\n            } /* else { // has not gone inside yet\n                // assert(count == 0);\n                // (1) when we have hit the destination but haven't gone inside the voxel grid\n                // (2) when we have traveled MAX_D voxels but haven't found one valid voxel\n                //     handle intersection corner cases in case of infinite loop\n                // bool hit = (vx == vxe && vy == vye && vz == vze);\n                // if (hit || step >= MAX_D)\n                //     break;\n                if (last_d >= gt_d || step >= MAX_D) break;\n            } */\n            // _d represents the ray distance has traveled before escaping the current voxel cell\n            double _d = 0.0;\n            // voxel traversal\n            if (tMaxX < tMaxY) {\n                if (tMaxX < tMaxZ) {\n                    _d = tMaxX;\n                    vx += stepX;\n                    tMaxX += tDeltaX;\n                } else 
{\n                    _d = tMaxZ;\n                    vz += stepZ;\n                    tMaxZ += tDeltaZ;\n                }\n            } else {\n                if (tMaxY < tMaxZ) {\n                    _d = tMaxY;\n                    vy += stepY;\n                    tMaxY += tDeltaY;\n                } else {\n                    _d = tMaxZ;\n                    vz += stepZ;\n                    tMaxZ += tDeltaZ;\n                }\n            }\n            if (inside) {\n                // get sigma at the current voxel\n                const int3 &v = path[count];  // use the recorded index\n                const double _sigma = sigma[n][ts][v.z][v.y][v.x];\n                const double _delta = max(0.0, _d - last_d);  // THIS TURNS OUT IMPORTANT\n                const double sd = _sigma * _delta;\n                if (count == 0) { // the first voxel inside\n                    csd[count] = sd;\n                    p[count] = 1 - exp(-sd);\n                } else {\n                    csd[count] = csd[count-1] + sd;\n                    p[count] = exp(-csd[count-1]) - exp(-csd[count]);\n                }\n                // record the traveled distance\n                d[count] = _d;\n                dt[count] = _delta;\n                // count the number of voxels we have escaped\n                count ++;\n            }\n            last_d = _d;\n            step ++;\n\n            if (step > MAX_STEP) {\n                break;\n            }\n        }\n\n        // the total number of voxels visited should not exceed this number\n        assert(count <= MAX_D);\n\n        // WHEN THERE IS AN INTERSECTION BETWEEN THE RAY AND THE VOXEL GRID\n        if (count > 0) {\n            // compute the expected ray distance\n            double exp_d = 0.0;\n            for (int i = 0; i < count; i ++)\n                exp_d += p[i] * d[i];\n\n            // add an imaginary sample at the end point should gt_d exceeds max_d\n            double p_out = 
exp(-csd[count-1]);\n            double max_d = d[count-1];\n\n            exp_d += (p_out * max_d);\n            gt_d = min(gt_d, max_d);\n\n            // write the rendered ray distance (max_d)\n            pred_dist[n][c] = exp_d;\n            gt_dist[n][c] = gt_d;\n\n            /* backward raymarching */\n            double dd_dsigma[MAX_D];\n            for (int i = count - 1; i >= 0; i --) {\n                // NOTE: probably need to double check again\n                if (i == count - 1)\n                    dd_dsigma[i] = p_out * max_d;\n                else\n                    dd_dsigma[i] = dd_dsigma[i+1] - exp(-csd[i]) * (d[i+1] - d[i]);\n            }\n\n            for (int i = count - 1; i >= 0; i --)\n                dd_dsigma[i] *= dt[i];\n\n            // option 2: cap at the boundary\n            for (int i = count - 1; i >= 0; i --)\n                dd_dsigma[i] -= dt[i] * p_out * max_d;\n\n            double dl_dd = 1.0;\n            if (loss_type == L1)\n                dl_dd = (exp_d >= gt_d) ? 1 : -1;\n            else if (loss_type == L2)\n                dl_dd = (exp_d - gt_d);\n            else if (loss_type == ABSREL)\n                dl_dd = (exp_d >= gt_d) ? 
(1.0/gt_d) : -(1.0/gt_d);\n\n            // apply chain rule\n            for (int i = 0; i < count; i ++) {\n                const int3 &v = path[i];\n                // NOTE: potential race conditions when writing gradients\n                grad_sigma[n][ts][v.z][v.y][v.x] += dl_dd * dd_dsigma[i];\n                // grad_sigma_count[n][ts][v.z][v.y][v.x] += 1;\n            }\n        }\n    }\n}\n\n/*\n * input shape\n *   sigma      : N x T x H x L x W\n *   origin   : N x T x 3\n *   points   : N x M x 4\n * output shape\n *   dist     : N x M\n *   loss     : N x M\n *   grad_sigma : N x T x H x L x W\n */\nstd::vector<torch::Tensor> render_cuda(\n    torch::Tensor sigma,\n    torch::Tensor origin,\n    torch::Tensor points,\n    torch::Tensor tindex,\n    std::string loss_name) {\n\n    const auto N = points.size(0); // batch size\n    const auto M = points.size(1); // num of rays\n\n    const auto device = sigma.device();\n\n    const int threads = 1024;\n    const dim3 blocks((M + threads - 1) / threads, N);\n\n    // perform rendering\n    auto gt_dist = -torch::ones({N, M}, device);\n    auto pred_dist = -torch::ones({N, M}, device);\n    auto grad_sigma = torch::zeros_like(sigma);\n    // auto grad_sigma_count = torch::zeros_like(sigma);\n\n    LossType loss_type;\n    if (loss_name.compare(\"l1\") == 0) {\n        loss_type = L1;\n    } else if (loss_name.compare(\"l2\") == 0) {\n        loss_type = L2;\n    } else if (loss_name.compare(\"absrel\") == 0) {\n        loss_type = ABSREL;\n    } else if (loss_name.compare(\"bce\") == 0){\n        loss_type = L1;\n    } else {\n        std::cout << \"UNKNOWN LOSS TYPE: \" << loss_name << std::endl;\n        exit(1);\n    }\n\n    AT_DISPATCH_FLOATING_TYPES(sigma.type(), \"render_cuda\", ([&] {\n                render_cuda_kernel<scalar_t><<<blocks, threads>>>(\n                    sigma.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),\n                    
origin.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),\n                    points.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),\n                    tindex.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),\n                    // occupancy.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),\n                    pred_dist.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),\n                    gt_dist.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),\n                    grad_sigma.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),\n                    // grad_sigma_count.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),\n                    loss_type);\n            }));\n\n    cudaDeviceSynchronize();\n\n    // grad_sigma_count += (grad_sigma_count == 0);\n    // grad_sigma /= grad_sigma_count;\n\n    return {pred_dist, gt_dist, grad_sigma};\n}\n\n\n/*\n * input shape\n *   origin   : N x T x 3\n *   points   : N x M x 3\n *   tindex   : N x M\n * output shape\n *   occupancy: N x T x H x L x W\n */\ntorch::Tensor init_cuda(\n    torch::Tensor points,\n    torch::Tensor tindex,\n    const std::vector<int> grid) {\n\n    const auto N = points.size(0); // batch size\n    const auto M = points.size(1); // num of rays\n\n    const auto T = grid[0];\n    const auto H = grid[1];\n    const auto L = grid[2];\n    const auto W = grid[3];\n\n    const auto dtype = points.dtype();\n    const auto device = points.device();\n    const auto options = torch::TensorOptions().dtype(dtype).device(device).requires_grad(false);\n    auto occupancy = torch::zeros({N, T, H, L, W}, options);\n\n    const int threads = 1024;\n    const dim3 blocks((M + threads - 1) / threads, N);\n\n    // initialize occupancy such that every voxel with one or more points is occupied\n    AT_DISPATCH_FLOATING_TYPES(points.type(), \"init_cuda\", ([&] {\n                init_cuda_kernel<scalar_t><<<blocks, threads>>>(\n                    
points.packed_accessor32<scalar_t,3,torch::RestrictPtrTraits>(),\n                    tindex.packed_accessor32<scalar_t,2,torch::RestrictPtrTraits>(),\n                    occupancy.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>());\n            }));\n\n    // synchronize\n    cudaDeviceSynchronize();\n\n    return occupancy;\n}"
  },
  {
    "path": "loaders/__init__.py",
    "content": "from .pipelines import __all__\nfrom .nuscenes_dataset import CustomNuScenesDataset\nfrom .nuscenes_occ_dataset import NuSceneOcc\n\n__all__ = [\n    'CustomNuScenesDataset', 'NuSceneOcc'\n]\n"
  },
  {
    "path": "loaders/builder.py",
    "content": "from functools import partial\nfrom mmcv.parallel import collate\nfrom mmcv.runner import get_dist_info\nfrom torch.utils.data import DataLoader\nfrom mmdet.datasets.builder import worker_init_fn\nfrom mmdet.datasets.samplers import DistributedGroupSampler, DistributedSampler, GroupSampler\n\n\ndef build_dataloader(dataset,\n                     samples_per_gpu,\n                     workers_per_gpu,\n                     num_gpus=1,\n                     dist=True,\n                     shuffle=True,\n                     seed=None,\n                     **kwargs):\n\n    rank, world_size = get_dist_info()\n    if dist:\n        # DistributedGroupSampler will definitely shuffle the data to satisfy\n        # that images on each GPU are in the same group\n        if shuffle:\n            sampler = DistributedGroupSampler(\n                dataset, samples_per_gpu, world_size, rank, seed=seed)\n        else:\n            sampler = DistributedSampler(\n                dataset, world_size, rank, shuffle=False, seed=seed)\n        batch_size = samples_per_gpu\n        num_workers = workers_per_gpu\n    else:\n        sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None\n        batch_size = num_gpus * samples_per_gpu\n        num_workers = num_gpus * workers_per_gpu\n\n    init_fn = partial(\n        worker_init_fn, num_workers=num_workers, rank=rank,\n        seed=seed) if seed is not None else None\n\n    data_loader = DataLoader(\n        dataset,\n        batch_size=batch_size,\n        sampler=sampler,\n        num_workers=num_workers,\n        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),\n        pin_memory=False,\n        worker_init_fn=init_fn,\n        **kwargs)\n\n    return data_loader\n"
  },
  {
    "path": "loaders/ego_pose_dataset.py",
    "content": "import torch\nimport numpy as np\nfrom pyquaternion import Quaternion\nfrom torch.utils.data import Dataset\nnp.set_printoptions(precision=3, suppress=True)\n\n\ndef trans_matrix(T, R):\n    tm = np.eye(4)\n    tm[:3, :3] = R.rotation_matrix\n    tm[:3, 3] = T\n    return tm\n\n\n# A helper dataset for RayIoU. It is NOT used during training.\nclass EgoPoseDataset(Dataset):\n    def __init__(self, data_infos):\n        super(EgoPoseDataset, self).__init__()\n\n        self.data_infos = data_infos\n        self.scene_frames = {}\n\n        for info in data_infos:\n            scene_name = info['scene_name']\n            if scene_name not in self.scene_frames:\n                self.scene_frames[scene_name] = []\n            self.scene_frames[scene_name].append(info)\n\n    def __len__(self):\n        return len(self.data_infos)\n\n    def get_ego_from_lidar(self, info):\n        ego_from_lidar = trans_matrix(\n            np.array(info['lidar2ego_translation']), \n            Quaternion(info['lidar2ego_rotation']))\n        return ego_from_lidar\n\n    def get_global_pose(self, info, inverse=False):\n        global_from_ego = trans_matrix(\n            np.array(info['ego2global_translation']), \n            Quaternion(info['ego2global_rotation']))\n        ego_from_lidar = trans_matrix(\n            np.array(info['lidar2ego_translation']), \n            Quaternion(info['lidar2ego_rotation']))\n        pose = global_from_ego.dot(ego_from_lidar)\n        if inverse:\n            pose = np.linalg.inv(pose)\n        return pose\n\n    def __getitem__(self, idx):\n        info = self.data_infos[idx]\n\n        ref_sample_token = info['token']\n        ref_lidar_from_global = self.get_global_pose(info, inverse=True)\n        ref_ego_from_lidar = self.get_ego_from_lidar(info)\n\n        scene_frame = self.scene_frames[info['scene_name']]\n        ref_index = scene_frame.index(info)\n\n        # NOTE: getting output frames\n        output_origin_list = []\n    
    for curr_index in range(len(scene_frame)):\n            # if this exists a valid target\n            if curr_index == ref_index:\n                origin_tf = np.array([0.0, 0.0, 0.0], dtype=np.float32)\n            else:\n                # transform from the current lidar frame to global and then to the reference lidar frame\n                global_from_curr = self.get_global_pose(scene_frame[curr_index], inverse=False)\n                ref_from_curr = ref_lidar_from_global.dot(global_from_curr)\n                origin_tf = np.array(ref_from_curr[:3, 3], dtype=np.float32)\n\n            origin_tf_pad = np.ones([4])\n            origin_tf_pad[:3] = origin_tf  # pad to [4]\n            origin_tf = np.dot(ref_ego_from_lidar[:3], origin_tf_pad.T).T  # [3]\n\n            # origin\n            if np.abs(origin_tf[0]) < 39 and np.abs(origin_tf[1]) < 39:\n                output_origin_list.append(origin_tf)\n        \n        # select 8 origins\n        if len(output_origin_list) > 8:\n            select_idx = np.round(np.linspace(0, len(output_origin_list) - 1, 8)).astype(np.int64)\n            output_origin_list = [output_origin_list[i] for i in select_idx]\n\n        output_origin_tensor = torch.from_numpy(np.stack(output_origin_list))  # [T, 3]\n\n        return (ref_sample_token, output_origin_tensor)\n"
  },
  {
    "path": "loaders/nuscenes_dataset.py",
    "content": "import os\nimport numpy as np\nfrom mmdet.datasets import DATASETS\nfrom mmdet3d.datasets import NuScenesDataset\nfrom pyquaternion import Quaternion\n\n\n@DATASETS.register_module()\nclass CustomNuScenesDataset(NuScenesDataset):\n\n    def collect_sweeps(self, index, into_past=60, into_future=0):\n        all_sweeps_prev = []\n        curr_index = index\n        while len(all_sweeps_prev) < into_past:\n            curr_sweeps = self.data_infos[curr_index]['sweeps']\n            if len(curr_sweeps) == 0:\n                break\n            all_sweeps_prev.extend(curr_sweeps)\n            all_sweeps_prev.append(self.data_infos[curr_index - 1]['cams'])\n            curr_index = curr_index - 1\n        \n        all_sweeps_next = []\n        curr_index = index + 1\n        while len(all_sweeps_next) < into_future:\n            if curr_index >= len(self.data_infos):\n                break\n            curr_sweeps = self.data_infos[curr_index]['sweeps']\n            all_sweeps_next.extend(curr_sweeps[::-1])\n            all_sweeps_next.append(self.data_infos[curr_index]['cams'])\n            curr_index = curr_index + 1\n\n        return all_sweeps_prev, all_sweeps_next\n\n    def get_data_info(self, index):\n        info = self.data_infos[index]\n        sweeps_prev, sweeps_next = self.collect_sweeps(index)\n\n        ego2global_translation = info['ego2global_translation']\n        ego2global_rotation = info['ego2global_rotation']\n        lidar2ego_translation = info['lidar2ego_translation']\n        lidar2ego_rotation = info['lidar2ego_rotation']\n        ego2global_rotation = Quaternion(ego2global_rotation).rotation_matrix\n        lidar2ego_rotation = Quaternion(lidar2ego_rotation).rotation_matrix\n\n        input_dict = dict(\n            sample_idx=info['token'],\n            sweeps={'prev': sweeps_prev, 'next': sweeps_next},\n            timestamp=info['timestamp'] / 1e6,\n            ego2global_translation=ego2global_translation,\n            
ego2global_rotation=ego2global_rotation,\n            lidar2ego_translation=lidar2ego_translation,\n            lidar2ego_rotation=lidar2ego_rotation,\n        )\n\n        if self.modality['use_camera']:\n            img_paths = []\n            img_timestamps = []\n            lidar2img_rts = []\n\n            for _, cam_info in info['cams'].items():\n                img_paths.append(os.path.relpath(cam_info['data_path']))\n                img_timestamps.append(cam_info['timestamp'] / 1e6)\n\n                # obtain lidar to image transformation matrix\n                lidar2cam_r = np.linalg.inv(cam_info['sensor2lidar_rotation'])\n                lidar2cam_t = cam_info['sensor2lidar_translation'] @ lidar2cam_r.T\n\n                lidar2cam_rt = np.eye(4)\n                lidar2cam_rt[:3, :3] = lidar2cam_r.T\n                lidar2cam_rt[3, :3] = -lidar2cam_t\n                \n                intrinsic = cam_info['cam_intrinsic']\n                viewpad = np.eye(4)\n                viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic\n                lidar2img_rt = (viewpad @ lidar2cam_rt.T)\n                lidar2img_rts.append(lidar2img_rt)\n\n            input_dict.update(dict(\n                img_filename=img_paths,\n                img_timestamp=img_timestamps,\n                lidar2img=lidar2img_rts,\n            ))\n\n        if not self.test_mode:\n            annos = self.get_ann_info(index)\n            input_dict['ann_info'] = annos\n\n        return input_dict"
  },
  {
    "path": "loaders/nuscenes_occ_dataset.py",
    "content": "import os\nimport mmcv\nimport glob\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom mmdet.datasets import DATASETS\nfrom mmdet3d.datasets import NuScenesDataset\nfrom nuscenes.eval.common.utils import Quaternion\nfrom nuscenes.utils.geometry_utils import transform_matrix\nfrom torch.utils.data import DataLoader\nfrom models.utils import sparse2dense\nfrom .ray_metrics import main_rayiou, main_raypq\nfrom .ego_pose_dataset import EgoPoseDataset\nfrom configs.r50_nuimg_704x256_8f import occ_class_names as occ3d_class_names\nfrom configs.r50_nuimg_704x256_8f_openocc import occ_class_names as openocc_class_names\n\n@DATASETS.register_module()\nclass NuSceneOcc(NuScenesDataset):    \n    def __init__(self, occ_gt_root, *args, **kwargs):\n        super().__init__(filter_empty_gt=False, *args, **kwargs)\n        self.occ_gt_root = occ_gt_root\n        self.data_infos = self.load_annotations(self.ann_file)\n\n        self.token2scene = {}\n        for gt_path in glob.glob(os.path.join(self.occ_gt_root, '*/*/*.npz')):\n            token = gt_path.split('/')[-2]\n            scene_name = gt_path.split('/')[-3]\n            self.token2scene[token] = scene_name\n\n        for i in range(len(self.data_infos)):\n            scene_name = self.token2scene[self.data_infos[i]['token']]\n            self.data_infos[i]['scene_name'] = scene_name\n\n    def collect_sweeps(self, index, into_past=150, into_future=0):\n        all_sweeps_prev = []\n        curr_index = index\n        while len(all_sweeps_prev) < into_past:\n            curr_sweeps = self.data_infos[curr_index]['sweeps']\n            if len(curr_sweeps) == 0:\n                break\n            all_sweeps_prev.extend(curr_sweeps)\n            all_sweeps_prev.append(self.data_infos[curr_index - 1]['cams'])\n            curr_index = curr_index - 1\n        \n        all_sweeps_next = []\n        curr_index = index + 1\n        while len(all_sweeps_next) < into_future:\n            if curr_index 
>= len(self.data_infos):\n                break\n            curr_sweeps = self.data_infos[curr_index]['sweeps']\n            all_sweeps_next.extend(curr_sweeps[::-1])\n            all_sweeps_next.append(self.data_infos[curr_index]['cams'])\n            curr_index = curr_index + 1\n\n        return all_sweeps_prev, all_sweeps_next\n\n    def get_data_info(self, index):\n        info = self.data_infos[index]\n        sweeps_prev, sweeps_next = self.collect_sweeps(index)\n\n        ego2global_translation = info['ego2global_translation']\n        ego2global_rotation = info['ego2global_rotation']\n        lidar2ego_translation = info['lidar2ego_translation']\n        lidar2ego_rotation = info['lidar2ego_rotation']\n        ego2global_rotation_mat = Quaternion(ego2global_rotation).rotation_matrix\n        lidar2ego_rotation_mat = Quaternion(lidar2ego_rotation).rotation_matrix\n\n        input_dict = dict(\n            sample_idx=info['token'],\n            sweeps={'prev': sweeps_prev, 'next': sweeps_next},\n            timestamp=info['timestamp'] / 1e6,\n            ego2global_translation=ego2global_translation,\n            ego2global_rotation=ego2global_rotation_mat,\n            lidar2ego_translation=lidar2ego_translation,\n            lidar2ego_rotation=lidar2ego_rotation_mat,\n        )\n\n        ego2lidar = transform_matrix(lidar2ego_translation, Quaternion(lidar2ego_rotation), inverse=True)\n        input_dict['ego2lidar'] = [ego2lidar for _ in range(6)]\n        input_dict['occ_path'] = os.path.join(self.occ_gt_root, info['scene_name'], info['token'], 'labels.npz')\n\n        if self.modality['use_camera']:\n            img_paths = []\n            img_timestamps = []\n            lidar2img_rts = []\n\n            for _, cam_info in info['cams'].items():\n                img_paths.append(os.path.relpath(cam_info['data_path']))\n                img_timestamps.append(cam_info['timestamp'] / 1e6)\n\n                # obtain lidar to image transformation matrix\n    
            lidar2cam_r = np.linalg.inv(cam_info['sensor2lidar_rotation'])\n                lidar2cam_t = cam_info['sensor2lidar_translation'] @ lidar2cam_r.T\n\n                lidar2cam_rt = np.eye(4)\n                lidar2cam_rt[:3, :3] = lidar2cam_r.T\n                lidar2cam_rt[3, :3] = -lidar2cam_t\n                \n                intrinsic = cam_info['cam_intrinsic']\n                viewpad = np.eye(4)\n                viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic\n                lidar2img_rt = (viewpad @ lidar2cam_rt.T)\n                lidar2img_rts.append(lidar2img_rt)\n\n            input_dict.update(dict(\n                img_filename=img_paths,\n                img_timestamp=img_timestamps,\n                lidar2img=lidar2img_rts,\n            ))\n\n        if not self.test_mode:\n            annos = self.get_ann_info(index)\n            input_dict['ann_info'] = annos\n\n        return input_dict\n\n    def evaluate(self, occ_results, runner=None, show_dir=None, **eval_kwargs):\n        occ_gts, occ_preds, inst_gts, inst_preds, lidar_origins = [], [], [], [], []\n        print('\\nStarting Evaluation...')\n\n        sample_tokens = [info['token'] for info in self.data_infos]\n\n        for batch in DataLoader(EgoPoseDataset(self.data_infos), num_workers=8):\n            token = batch[0][0]\n            output_origin = batch[1]\n            \n            data_id = sample_tokens.index(token)\n            info = self.data_infos[data_id]\n\n            occ_path = os.path.join(self.occ_gt_root, info['scene_name'], info['token'], 'labels.npz')\n            occ_gt = np.load(occ_path, allow_pickle=True)\n            gt_semantics = occ_gt['semantics']\n\n            occ_pred = occ_results[data_id]\n            sem_pred = torch.from_numpy(occ_pred['sem_pred'])  # [B, N]\n            occ_loc = torch.from_numpy(occ_pred['occ_loc'].astype(np.int64))  # [B, N, 3]\n            \n            data_type = self.occ_gt_root.split('/')[-1]\n         
   if data_type == 'occ3d' or data_type == 'occ3d_panoptic':\n                occ_class_names = occ3d_class_names\n            elif data_type == 'openocc_v2':\n                occ_class_names = openocc_class_names\n            else:\n                raise ValueError\n            free_id = len(occ_class_names) - 1\n            \n            occ_size = list(gt_semantics.shape)\n            sem_pred, _ = sparse2dense(occ_loc, sem_pred, dense_shape=occ_size, empty_value=free_id)\n            sem_pred = sem_pred.squeeze(0).numpy()\n\n            if 'pano_inst' in occ_pred.keys():\n                pano_inst = torch.from_numpy(occ_pred['pano_inst'])\n                pano_sem = torch.from_numpy(occ_pred['pano_sem'])\n\n                pano_inst, _ = sparse2dense(occ_loc, pano_inst, dense_shape=occ_size, empty_value=0)\n                pano_sem, _ = sparse2dense(occ_loc, pano_sem, dense_shape=occ_size, empty_value=free_id)\n                pano_inst = pano_inst.squeeze(0).numpy()\n                pano_sem = pano_sem.squeeze(0).numpy()\n                sem_pred = pano_sem\n\n                gt_instances = occ_gt['instances']\n                inst_gts.append(gt_instances)\n                inst_preds.append(pano_inst)\n\n            lidar_origins.append(output_origin)\n            occ_gts.append(gt_semantics)\n            occ_preds.append(sem_pred)\n        \n        if len(inst_preds) > 0:\n            results = main_raypq(occ_preds, occ_gts, inst_preds, inst_gts, lidar_origins, occ_class_names=occ_class_names)\n            results.update(main_rayiou(occ_preds, occ_gts, lidar_origins, occ_class_names=occ_class_names))\n            return results\n        else:\n            return main_rayiou(occ_preds, occ_gts, lidar_origins, occ_class_names=occ_class_names)\n\n    def format_results(self, occ_results, submission_prefix, **kwargs):\n        if submission_prefix is not None:\n            mmcv.mkdir_or_exist(submission_prefix)\n\n        for index, occ_pred in 
enumerate(tqdm(occ_results)):\n            info = self.data_infos[index]\n            sample_token = info['token']\n            save_path = os.path.join(submission_prefix, '{}.npz'.format(sample_token))\n            np.savez_compressed(save_path, occ_pred.astype(np.uint8))\n        \n        print('\\nFinished.')\n"
  },
  {
    "path": "loaders/old_metrics.py",
    "content": "import os\nimport numpy as np\nfrom sklearn.neighbors import KDTree\nfrom termcolor import colored\nfrom functools import reduce\nfrom typing import Iterable\n\nnp.seterr(divide='ignore', invalid='ignore')\nos.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n\n\ndef pcolor(string, color, on_color=None, attrs=None):\n    \"\"\"\n    Produces a colored string for printing\n\n    Parameters\n    ----------\n    string : str\n        String that will be colored\n    color : str\n        Color to use\n    on_color : str\n        Background color to use\n    attrs : list of str\n        Different attributes for the string\n\n    Returns\n    -------\n    string: str\n        Colored string\n    \"\"\"\n    return colored(string, color, on_color, attrs)\n\n\ndef getCellCoordinates(points, voxelSize):\n    return (points / voxelSize).astype(np.int)\n\n\ndef getNumUniqueCells(cells):\n    M = cells.max() + 1\n    return np.unique(cells[:, 0] + M * cells[:, 1] + M ** 2 * cells[:, 2]).shape[0]\n\n\nclass Metric_mIoU():\n    def __init__(self,\n                 save_dir='.',\n                 num_classes=18,\n                 use_lidar_mask=False,\n                 use_image_mask=False,\n                 ):\n        if num_classes == 18:\n            self.class_names = [\n                'others','barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',\n                'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',\n                'driveable_surface', 'other_flat', 'sidewalk',\n                'terrain', 'manmade', 'vegetation','free'\n            ]\n        elif num_classes == 2:\n            self.class_names = ['non-free', 'free']\n        \n        self.save_dir = save_dir\n        self.use_lidar_mask = use_lidar_mask\n        self.use_image_mask = use_image_mask\n        self.num_classes = num_classes\n\n        self.point_cloud_range = [-40.0, -40.0, -1.0, 40.0, 40.0, 5.4]\n        self.occupancy_size = [0.4, 0.4, 0.4]\n        
self.voxel_size = 0.4\n        self.occ_xdim = int((self.point_cloud_range[3] - self.point_cloud_range[0]) / self.occupancy_size[0])\n        self.occ_ydim = int((self.point_cloud_range[4] - self.point_cloud_range[1]) / self.occupancy_size[1])\n        self.occ_zdim = int((self.point_cloud_range[5] - self.point_cloud_range[2]) / self.occupancy_size[2])\n        self.voxel_num = self.occ_xdim * self.occ_ydim * self.occ_zdim\n        self.hist = np.zeros((self.num_classes, self.num_classes))\n        self.cnt = 0\n\n    def hist_info(self, n_cl, pred, gt):\n        \"\"\"\n        build confusion matrix\n        # empty classes:0\n        non-empty class: 0-16\n        free voxel class: 17\n\n        Args:\n            n_cl (int): num_classes_occupancy\n            pred (1-d array): pred_occupancy_label\n            gt (1-d array): gt_occupancy_label\n\n        Returns:\n            tuple:(hist, number of correctly predicted labels, num_labelled_sample)\n        \"\"\"\n        assert pred.shape == gt.shape\n        k = (gt >= 0) & (gt < n_cl)  # exclude 255\n        labeled = np.sum(k)\n        correct = np.sum((pred[k] == gt[k]))\n\n        return (\n            np.bincount(\n                n_cl * gt[k].astype(int) + pred[k].astype(int), minlength=n_cl ** 2\n            ).reshape(n_cl, n_cl),\n            correct,\n            labeled,\n        )\n\n    def per_class_iu(self, hist):\n        #return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))\n        result = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))\n        result[hist.sum(1) == 0] = float('nan')\n        return result\n\n    def compute_mIoU(self, pred, label, n_classes):\n        hist = np.zeros((n_classes, n_classes))\n        new_hist, correct, labeled = self.hist_info(n_classes, pred.flatten(), label.flatten())\n        hist += new_hist\n        mIoUs = self.per_class_iu(hist)\n        # for ind_class in range(n_classes):\n        #     print(str(round(mIoUs[ind_class] * 100, 
2)))\n        # print('===> mIoU: ' + str(round(np.nanmean(mIoUs) * 100, 2)))\n        return round(np.nanmean(mIoUs) * 100, 2), hist\n\n    def add_batch(self,semantics_pred,semantics_gt,mask_lidar,mask_camera):\n        self.cnt += 1\n        if self.use_image_mask:\n            masked_semantics_gt = semantics_gt[mask_camera]\n            masked_semantics_pred = semantics_pred[mask_camera]\n        elif self.use_lidar_mask:\n            masked_semantics_gt = semantics_gt[mask_lidar]\n            masked_semantics_pred = semantics_pred[mask_lidar]\n        else:\n            masked_semantics_gt = semantics_gt\n            masked_semantics_pred = semantics_pred\n\n        if self.num_classes == 2:\n            masked_semantics_pred = np.copy(masked_semantics_pred)\n            masked_semantics_gt = np.copy(masked_semantics_gt)\n            masked_semantics_pred[masked_semantics_pred < 17] = 0\n            masked_semantics_pred[masked_semantics_pred == 17] = 1\n            masked_semantics_gt[masked_semantics_gt < 17] = 0\n            masked_semantics_gt[masked_semantics_gt == 17] = 1\n        \n        _, _hist = self.compute_mIoU(masked_semantics_pred, masked_semantics_gt, self.num_classes)\n        self.hist += _hist\n\n    def count_miou(self):\n        mIoU = self.per_class_iu(self.hist)\n        # assert cnt == num_samples, 'some samples are not included in the miou calculation'\n        print(f'===> per class IoU of {self.cnt} samples:')\n        for ind_class in range(self.num_classes-1):\n            print(f'===> {self.class_names[ind_class]} - IoU = ' + str(round(mIoU[ind_class] * 100, 2)))\n\n        print(f'===> mIoU of {self.cnt} samples: ' + str(round(np.nanmean(mIoU[:self.num_classes-1]) * 100, 2)))\n        # print(f'===> sample-wise averaged mIoU of {cnt} samples: ' + str(round(np.nanmean(mIoU_avg), 2)))\n\n        return round(np.nanmean(mIoU[:self.num_classes-1]) * 100, 2)\n\n\nclass Metric_FScore():\n    def __init__(self,\n                 
leaf_size=10,\n                 threshold_acc=0.6,\n                 threshold_complete=0.6,\n                 voxel_size=[0.4, 0.4, 0.4],\n                 range=[-40, -40, -1, 40, 40, 5.4],\n                 void=[17, 255],\n                 use_lidar_mask=False,\n                 use_image_mask=False, ) -> None:\n\n        self.leaf_size = leaf_size\n        self.threshold_acc = threshold_acc\n        self.threshold_complete = threshold_complete\n        self.voxel_size = voxel_size\n        self.range = range\n        self.void = void\n        self.use_lidar_mask = use_lidar_mask\n        self.use_image_mask = use_image_mask\n        self.cnt=0\n        self.tot_acc = 0.\n        self.tot_cmpl = 0.\n        self.tot_f1_mean = 0.\n        self.eps = 1e-8\n\n    def voxel2points(self, voxel):\n        # occIdx = torch.where(torch.logical_and(voxel != FREE, voxel != NOT_OBSERVED))\n        # if isinstance(voxel, np.ndarray): voxel = torch.from_numpy(voxel)\n        mask = np.logical_not(reduce(np.logical_or, [voxel == self.void[i] for i in range(len(self.void))]))\n        occIdx = np.where(mask)\n\n        points = np.concatenate((occIdx[0][:, None] * self.voxel_size[0] + self.voxel_size[0] / 2 + self.range[0], \\\n                                 occIdx[1][:, None] * self.voxel_size[1] + self.voxel_size[1] / 2 + self.range[1], \\\n                                 occIdx[2][:, None] * self.voxel_size[2] + self.voxel_size[2] / 2 + self.range[2]),\n                                axis=1)\n        return points\n\n    def add_batch(self,semantics_pred,semantics_gt,mask_lidar,mask_camera ):\n        # for scene_token in tqdm(preds_dict.keys()):\n        self.cnt += 1\n\n        if self.use_image_mask:\n\n            semantics_gt[mask_camera == False] = 255\n            semantics_pred[mask_camera == False] = 255\n        elif self.use_lidar_mask:\n            semantics_gt[mask_lidar == False] = 255\n            semantics_pred[mask_lidar == False] = 255\n        
else:\n            pass\n\n        ground_truth = self.voxel2points(semantics_gt)\n        prediction = self.voxel2points(semantics_pred)\n        if prediction.shape[0] == 0:\n            accuracy=0\n            completeness=0\n            fmean=0\n\n        else:\n            prediction_tree = KDTree(prediction, leaf_size=self.leaf_size)\n            ground_truth_tree = KDTree(ground_truth, leaf_size=self.leaf_size)\n            complete_distance, _ = prediction_tree.query(ground_truth)\n            complete_distance = complete_distance.flatten()\n\n            accuracy_distance, _ = ground_truth_tree.query(prediction)\n            accuracy_distance = accuracy_distance.flatten()\n\n            # evaluate completeness\n            complete_mask = complete_distance < self.threshold_complete\n            completeness = complete_mask.mean()\n\n            # evaluate accuracy\n            accuracy_mask = accuracy_distance < self.threshold_acc\n            accuracy = accuracy_mask.mean()\n\n            fmean = 2.0 / (1 / (accuracy+self.eps) + 1 / (completeness+self.eps))\n\n        self.tot_acc += accuracy\n        self.tot_cmpl += completeness\n        self.tot_f1_mean += fmean\n\n    def count_fscore(self,):\n        base_color, attrs = 'red', ['bold', 'dark']\n        print(pcolor('\\n######## F score: {} #######'.format(self.tot_f1_mean / self.cnt), base_color, attrs=attrs))\n        return self.tot_f1_mean / self.cnt\n\nclass Metric_mRecall():\n    def __init__(self,\n                 save_dir='.',\n                 num_classes=18,\n                 pred_classes=2,\n                 use_lidar_mask=False,\n                 use_image_mask=False,\n                 ):\n        if num_classes == 18:\n            self.class_names = [\n                'others','barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',\n                'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',\n                'driveable_surface', 'other_flat', 'sidewalk',\n          
      'terrain', 'manmade', 'vegetation','free'\n            ]\n        elif num_classes == 2:\n            self.class_names = ['non-free', 'free']\n        \n        self.pred_classes = pred_classes\n        self.save_dir = save_dir\n        self.use_lidar_mask = use_lidar_mask\n        self.use_image_mask = use_image_mask\n        self.num_classes = num_classes\n\n        self.point_cloud_range = [-40.0, -40.0, -1.0, 40.0, 40.0, 5.4]\n        self.occupancy_size = [0.4, 0.4, 0.4]\n        self.voxel_size = 0.4\n        self.occ_xdim = int((self.point_cloud_range[3] - self.point_cloud_range[0]) / self.occupancy_size[0])\n        self.occ_ydim = int((self.point_cloud_range[4] - self.point_cloud_range[1]) / self.occupancy_size[1])\n        self.occ_zdim = int((self.point_cloud_range[5] - self.point_cloud_range[2]) / self.occupancy_size[2])\n        self.voxel_num = self.occ_xdim * self.occ_ydim * self.occ_zdim\n        self.hist = np.zeros((self.num_classes, self.pred_classes))   # n_cl, p_cl\n        self.cnt = 0\n\n    def hist_info(self, n_cl, p_cl, pred, gt):\n        \"\"\"\n        build confusion matrix\n        # empty classes:0\n        non-empty class: 0-16\n        free voxel class: 17\n\n        Args:\n            n_cl (int): num_classes_occupancy\n            pred (1-d array): pred_occupancy_label\n            gt (1-d array): gt_occupancy_label\n\n        Returns:\n            tuple:(hist, number of correctly predicted labels, num_labelled_sample)\n        \"\"\"\n        assert pred.shape == gt.shape\n        k = (gt >= 0) & (gt < n_cl)  # exclude 255\n        labeled = np.sum(k)\n        correct = np.sum((pred[k] == gt[k]))\n\n        return (\n            np.bincount(\n                p_cl * gt[k].astype(int) + pred[k].astype(int), minlength=n_cl * p_cl\n            ).reshape(n_cl, p_cl),   # 18, 2\n            correct,\n            labeled,\n        )\n\n    def per_class_recall(self, hist):\n        return hist[:, 1] / hist.sum(1)   ## recall \n\n    
def compute_mRecall(self, pred, label, n_classes, p_classes):\n        hist = np.zeros((n_classes, p_classes))\n        new_hist, correct, labeled = self.hist_info(n_classes, p_classes, pred.flatten(), label.flatten())\n        hist += new_hist\n        mRecalls = self.per_class_recall(hist)\n        # for ind_class in range(n_classes):\n        #     print(str(round(mIoUs[ind_class] * 100, 2)))\n        # print('===> mIoU: ' + str(round(np.nanmean(mIoUs) * 100, 2)))\n        return round(np.nanmean(mRecalls) * 100, 2), hist\n\n    def add_batch(self,semantics_pred,semantics_gt,mask_lidar,mask_camera):\n        self.cnt += 1\n        if self.use_image_mask:\n            masked_semantics_gt = semantics_gt[mask_camera]\n            masked_semantics_pred = semantics_pred[mask_camera]\n        elif self.use_lidar_mask:\n            masked_semantics_gt = semantics_gt[mask_lidar]\n            masked_semantics_pred = semantics_pred[mask_lidar]\n        else:\n            masked_semantics_gt = semantics_gt\n            masked_semantics_pred = semantics_pred\n\n        if self.pred_classes == 2:\n            masked_semantics_pred = np.copy(masked_semantics_pred)\n            masked_semantics_gt = np.copy(masked_semantics_gt)\n            masked_semantics_pred[masked_semantics_pred < 17] = 1  \n            masked_semantics_pred[masked_semantics_pred == 17] = 0 # 0 is free\n\n        _, _hist = self.compute_mRecall(masked_semantics_pred, masked_semantics_gt, self.num_classes, self.pred_classes)\n        self.hist += _hist\n\n    def count_mrecall(self):\n        mRecall = self.per_class_recall(self.hist)\n        # assert cnt == num_samples, 'some samples are not included in the miou calculation'\n        print(f'===> per class Recall of {self.cnt} samples:')\n        for ind_class in range(self.num_classes-1):\n            print(f'===> {self.class_names[ind_class]} - Recall = ' + str(round(mRecall[ind_class] * 100, 2)))\n\n        print(f'===> mRecall of {self.cnt} samples: 
' + str(round(np.nanmean(mRecall[:self.num_classes-1]) * 100, 2)))\n\n        return round(np.nanmean(mRecall[:self.num_classes-1]) * 100, 2)\n\n\n# modified from https://github.com/open-mmlab/mmdetection3d/blob/main/mmdet3d/evaluation/functional/panoptic_seg_eval.py#L10\nclass Metric_Panoptic():\n    def __init__(self, \n                 save_dir='.',\n                 num_classes=18,\n                 use_lidar_mask=False,\n                 use_image_mask=False,\n                 ignore_index: Iterable[int]=[],\n                 ):\n        \"\"\"\n        Args:\n            ignore_index (list): Class ids that will not be considered in pq counting.\n        \"\"\"\n        if num_classes == 18:\n            self.class_names = [\n                'others','barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',\n                'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',\n                'driveable_surface', 'other_flat', 'sidewalk',\n                'terrain', 'manmade', 'vegetation','free'\n            ]\n        else:\n            raise ValueError\n        \n        self.save_dir = save_dir\n        self.num_classes = num_classes\n        self.use_lidar_mask = use_lidar_mask\n        self.use_image_mask = use_image_mask\n        self.ignore_index = ignore_index\n        self.id_offset = 2 ** 16\n        self.eps = 1e-5\n        \n        self.min_num_points = 20\n        self.include = np.array(\n            [n for n in range(self.num_classes - 1) if n not in self.ignore_index],\n            dtype=int)\n        self.cnt = 0\n        \n        # panoptic stuff\n        self.pan_tp = np.zeros(self.num_classes, dtype=int)\n        self.pan_iou = np.zeros(self.num_classes, dtype=np.double)\n        self.pan_fp = np.zeros(self.num_classes, dtype=int)\n        self.pan_fn = np.zeros(self.num_classes, dtype=int)\n        \n    def add_batch(self,semantics_pred,semantics_gt,instances_pred,instances_gt,mask_lidar,mask_camera):\n        self.cnt += 1\n  
      if self.use_image_mask:\n            masked_semantics_gt = semantics_gt[mask_camera]\n            masked_semantics_pred = semantics_pred[mask_camera]\n            masked_instances_gt = instances_gt[mask_camera]\n            masked_instances_pred = instances_pred[mask_camera]\n        elif self.use_lidar_mask:\n            masked_semantics_gt = semantics_gt[mask_lidar]\n            masked_semantics_pred = semantics_pred[mask_lidar]\n            masked_instances_gt = instances_gt[mask_lidar]\n            masked_instances_pred = instances_pred[mask_lidar]\n        else:\n            masked_semantics_gt = semantics_gt\n            masked_semantics_pred = semantics_pred\n            masked_instances_gt = instances_gt\n            masked_instances_pred = instances_pred\n        self.add_panoptic_sample(masked_semantics_pred, masked_semantics_gt, masked_instances_pred, masked_instances_gt) \n    \n    def add_panoptic_sample(self, semantics_pred, semantics_gt, instances_pred, instances_gt):\n        \"\"\"Add one sample of panoptic predictions and ground truths for\n        evaluation.\n\n        Args:\n            semantics_pred (np.ndarray): Semantic predictions.\n            semantics_gt (np.ndarray): Semantic ground truths.\n            instances_pred (np.ndarray): Instance predictions.\n            instances_gt (np.ndarray): Instance ground truths.\n        \"\"\"\n        # get instance_class_id from instance_gt\n        instance_class_ids = [self.num_classes - 1]\n        for i in range(1, instances_gt.max() + 1):\n            class_id = np.unique(semantics_gt[instances_gt == i])\n            # assert class_id.shape[0] == 1, \"each instance must belong to only one class\"\n            if class_id.shape[0] == 1:\n                instance_class_ids.append(class_id[0])\n            else:\n                instance_class_ids.append(self.num_classes - 1)\n        instance_class_ids = np.array(instance_class_ids)\n\n        instance_count = 1\n        
final_instance_class_ids = []\n        final_instances = np.zeros_like(instances_gt)  # empty space has instance id \"0\"\n\n        for class_id in range(self.num_classes - 1):\n            if np.sum(semantics_gt == class_id) == 0:\n                continue\n\n            if self.class_names[class_id] in ['car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'motorcycle', 'bicycle', 'pedestrian']:\n                # treat as instances\n                for instance_id in range(len(instance_class_ids)):\n                    if instance_class_ids[instance_id] != class_id:\n                        continue\n                    final_instances[instances_gt == instance_id] = instance_count\n                    instance_count += 1\n                    final_instance_class_ids.append(class_id)\n            else:\n                # treat as semantics\n                final_instances[semantics_gt == class_id] = instance_count\n                instance_count += 1\n                final_instance_class_ids.append(class_id)\n                \n        instances_gt = final_instances\n        \n        # avoid zero (ignored label)\n        instances_pred = instances_pred + 1\n        instances_gt = instances_gt + 1\n        \n        for cl in self.ignore_index:\n            # make a mask for this class\n            gt_not_in_excl_mask = semantics_gt != cl\n            # remove all other points\n            semantics_pred = semantics_pred[gt_not_in_excl_mask]\n            semantics_gt = semantics_gt[gt_not_in_excl_mask]\n            instances_pred = instances_pred[gt_not_in_excl_mask]\n            instances_gt = instances_gt[gt_not_in_excl_mask]\n        \n        # for each class (except the ignored ones)\n        for cl in self.include:\n            # get a class mask\n            pred_inst_in_cl_mask = semantics_pred == cl\n            gt_inst_in_cl_mask = semantics_gt == cl\n\n            # get instance points in class (makes outside stuff 0)\n            pred_inst_in_cl = 
instances_pred * pred_inst_in_cl_mask.astype(int)\n            gt_inst_in_cl = instances_gt * gt_inst_in_cl_mask.astype(int)\n\n            # generate the areas for each unique instance prediction\n            unique_pred, counts_pred = np.unique(\n                pred_inst_in_cl[pred_inst_in_cl > 0], return_counts=True)\n            id2idx_pred = {id: idx for idx, id in enumerate(unique_pred)}\n            matched_pred = np.array([False] * unique_pred.shape[0])\n\n            # generate the areas for each unique instance gt_np\n            unique_gt, counts_gt = np.unique(\n                gt_inst_in_cl[gt_inst_in_cl > 0], return_counts=True)\n            id2idx_gt = {id: idx for idx, id in enumerate(unique_gt)}\n            matched_gt = np.array([False] * unique_gt.shape[0])\n\n            # generate intersection using offset\n            valid_combos = np.logical_and(pred_inst_in_cl > 0,\n                                          gt_inst_in_cl > 0)\n            id_offset_combo = pred_inst_in_cl[\n                valid_combos] + self.id_offset * gt_inst_in_cl[valid_combos]\n            unique_combo, counts_combo = np.unique(\n                id_offset_combo, return_counts=True)\n\n            # generate an intersection map\n            # count the intersections with over 0.5 IoU as TP\n            gt_labels = unique_combo // self.id_offset\n            pred_labels = unique_combo % self.id_offset\n            gt_areas = np.array([counts_gt[id2idx_gt[id]] for id in gt_labels])\n            pred_areas = np.array(\n                [counts_pred[id2idx_pred[id]] for id in pred_labels])\n            intersections = counts_combo\n            unions = gt_areas + pred_areas - intersections\n            ious = intersections.astype(float) / unions.astype(float)\n\n            tp_indexes = ious > 0.5\n            self.pan_tp[cl] += np.sum(tp_indexes)\n            self.pan_iou[cl] += np.sum(ious[tp_indexes])\n\n            matched_gt[[id2idx_gt[id] for id in 
gt_labels[tp_indexes]]] = True\n            matched_pred[[id2idx_pred[id]\n                          for id in pred_labels[tp_indexes]]] = True\n\n            # count the FN\n            if len(counts_gt) > 0:\n                self.pan_fn[cl] += np.sum(\n                    np.logical_and(counts_gt >= self.min_num_points,\n                                   ~matched_gt))\n\n            # count the FP\n            if len(matched_pred) > 0:\n                self.pan_fp[cl] += np.sum(\n                    np.logical_and(counts_pred >= self.min_num_points,\n                                   ~matched_pred))\n    \n    def count_pq(self, ):\n        sq_all = self.pan_iou.astype(np.double) / np.maximum(\n            self.pan_tp.astype(np.double), self.eps)\n        rq_all = self.pan_tp.astype(np.double) / np.maximum(\n            self.pan_tp.astype(np.double) + 0.5 * self.pan_fp.astype(np.double)\n            + 0.5 * self.pan_fn.astype(np.double), self.eps)\n        pq_all = sq_all * rq_all\n        \n        # mask classes not occurring in dataset\n        mask = (self.pan_tp + self.pan_fp + self.pan_fn) > 0\n        sq_all[~mask] = float('nan')\n        rq_all[~mask] = float('nan')\n        pq_all[~mask] = float('nan')\n        \n        # then do the REAL mean (no ignored classes)\n        sq = round(np.nanmean(sq_all[self.include]) * 100, 2)\n        rq = round(np.nanmean(rq_all[self.include]) * 100, 2)\n        pq = round(np.nanmean(pq_all[self.include]) * 100, 2)\n        \n        print(f'===> per class sq, rq, pq of {self.cnt} samples:')\n        for ind_class in self.include:\n            print(f'===> {self.class_names[ind_class]} -' + \\\n                  f' sq = {round(sq_all[ind_class] * 100, 2)},' + \\\n                  f' rq = {round(rq_all[ind_class] * 100, 2)},' + \\\n                  f' pq = {round(pq_all[ind_class] * 100, 2)}')\n        \n        print(f'===> sq of {self.cnt} samples: ' + str(sq))\n        print(f'===> rq of {self.cnt} samples: ' + 
str(rq))\n        print(f'===> pq of {self.cnt} samples: ' + str(pq))\n\n        return (pq, sq, rq)"
  },
  {
    "path": "loaders/pipelines/__init__.py",
    "content": "from .loading import LoadMultiViewImageFromMultiSweeps, LoadOccGTFromFile\nfrom .transforms import PadMultiViewImage, NormalizeMultiviewImage, PhotoMetricDistortionMultiViewImage\n\n__all__ = [\n    'LoadMultiViewImageFromMultiSweeps', 'PadMultiViewImage', 'NormalizeMultiviewImage', \n    'PhotoMetricDistortionMultiViewImage', 'LoadOccGTFromFile'\n]"
  },
  {
    "path": "loaders/pipelines/loading.py",
    "content": "import os\nimport mmcv\nimport torch\nimport numpy as np\nfrom mmdet.datasets.builder import PIPELINES\nfrom numpy.linalg import inv\nfrom mmcv.runner import get_dist_info\nfrom mmcv.parallel import DataContainer as DC\nfrom mmdet.datasets.pipelines import to_tensor\nfrom torchvision.transforms.functional import rotate\n\n\ndef compose_lidar2img(ego2global_translation_curr,\n                      ego2global_rotation_curr,\n                      lidar2ego_translation_curr,\n                      lidar2ego_rotation_curr,\n                      sensor2global_translation_past,\n                      sensor2global_rotation_past,\n                      cam_intrinsic_past):\n    \n    R = sensor2global_rotation_past @ (inv(ego2global_rotation_curr).T @ inv(lidar2ego_rotation_curr).T)\n    T = sensor2global_translation_past @ (inv(ego2global_rotation_curr).T @ inv(lidar2ego_rotation_curr).T)\n    T -= ego2global_translation_curr @ (inv(ego2global_rotation_curr).T @ inv(lidar2ego_rotation_curr).T) + lidar2ego_translation_curr @ inv(lidar2ego_rotation_curr).T\n\n    lidar2cam_r = inv(R.T)\n    lidar2cam_t = T @ lidar2cam_r.T\n\n    lidar2cam_rt = np.eye(4)\n    lidar2cam_rt[:3, :3] = lidar2cam_r.T\n    lidar2cam_rt[3, :3] = -lidar2cam_t\n\n    viewpad = np.eye(4)\n    viewpad[:cam_intrinsic_past.shape[0], :cam_intrinsic_past.shape[1]] = cam_intrinsic_past\n    lidar2img = (viewpad @ lidar2cam_rt.T).astype(np.float32)\n\n    return lidar2img\n\n\n@PIPELINES.register_module()\nclass LoadMultiViewImageFromMultiSweeps(object):\n    def __init__(self,\n                 sweeps_num=5,\n                 color_type='color',\n                 test_mode=False):\n        self.sweeps_num = sweeps_num\n        self.color_type = color_type\n        self.test_mode = test_mode\n\n        self.train_interval = [4, 8]\n        self.test_interval = 6\n\n        try:\n            mmcv.use_backend('turbojpeg')\n        except ImportError:\n            mmcv.use_backend('cv2')\n\n   
 def load_offline(self, results):\n        cam_types = [\n            'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',\n            'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT'\n        ]\n\n        if len(results['sweeps']['prev']) == 0:\n            for _ in range(self.sweeps_num):\n                for j in range(len(cam_types)):\n                    results['img'].append(results['img'][j])\n                    results['img_timestamp'].append(results['img_timestamp'][j])\n                    results['filename'].append(results['filename'][j])\n                    results['lidar2img'].append(np.copy(results['lidar2img'][j]))\n                    if 'ego2lidar' in results:\n                        results['ego2lidar'].append(results['ego2lidar'][0])\n        else:\n            if self.test_mode:\n                interval = self.test_interval\n                choices = [(k + 1) * interval - 1 for k in range(self.sweeps_num)]\n            elif len(results['sweeps']['prev']) <= self.sweeps_num:\n                pad_len = self.sweeps_num - len(results['sweeps']['prev'])\n                choices = list(range(len(results['sweeps']['prev']))) + [len(results['sweeps']['prev']) - 1] * pad_len\n            else:\n                max_interval = len(results['sweeps']['prev']) // self.sweeps_num\n                max_interval = min(max_interval, self.train_interval[1])\n                min_interval = min(max_interval, self.train_interval[0])\n                interval = np.random.randint(min_interval, max_interval + 1)\n                choices = [(k + 1) * interval - 1 for k in range(self.sweeps_num)]\n\n            for idx in sorted(list(choices)):\n                sweep_idx = min(idx, len(results['sweeps']['prev']) - 1)\n                sweep = results['sweeps']['prev'][sweep_idx]\n\n                if len(sweep.keys()) < len(cam_types):\n                    sweep = results['sweeps']['prev'][sweep_idx - 1]\n\n                for sensor in cam_types:\n                    
results['img'].append(mmcv.imread(sweep[sensor]['data_path'], self.color_type))\n                    results['img_timestamp'].append(sweep[sensor]['timestamp'] / 1e6)\n                    results['filename'].append(os.path.relpath(sweep[sensor]['data_path']))\n                    results['lidar2img'].append(compose_lidar2img(\n                        results['ego2global_translation'],\n                        results['ego2global_rotation'],\n                        results['lidar2ego_translation'],\n                        results['lidar2ego_rotation'],\n                        sweep[sensor]['sensor2global_translation'],\n                        sweep[sensor]['sensor2global_rotation'],\n                        sweep[sensor]['cam_intrinsic'],\n                    ))\n                    if 'ego2lidar' in results:\n                        results['ego2lidar'].append(results['ego2lidar'][0])\n\n        return results\n\n    def load_online(self, results):\n        # only used when measuring FPS\n        assert self.test_mode\n        assert self.test_interval % 6 == 0\n\n        cam_types = [\n            'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',\n            'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT'\n        ]\n\n        if len(results['sweeps']['prev']) == 0:\n            for _ in range(self.sweeps_num):\n                for j in range(len(cam_types)):\n                    results['img_timestamp'].append(results['img_timestamp'][j])\n                    results['filename'].append(results['filename'][j])\n                    results['lidar2img'].append(np.copy(results['lidar2img'][j]))\n                    if 'ego2lidar' in results:\n                        results['ego2lidar'].append(results['ego2lidar'][0])\n        else:\n            interval = self.test_interval\n            choices = [(k + 1) * interval - 1 for k in range(self.sweeps_num)]\n\n            for idx in sorted(list(choices)):\n                sweep_idx = min(idx, 
len(results['sweeps']['prev']) - 1)\n                sweep = results['sweeps']['prev'][sweep_idx]\n\n                if len(sweep.keys()) < len(cam_types):\n                    sweep = results['sweeps']['prev'][sweep_idx - 1]\n\n                for sensor in cam_types:\n                    # skip loading history frames\n                    results['img_timestamp'].append(sweep[sensor]['timestamp'] / 1e6)\n                    results['filename'].append(os.path.relpath(sweep[sensor]['data_path']))\n                    results['lidar2img'].append(compose_lidar2img(\n                        results['ego2global_translation'],\n                        results['ego2global_rotation'],\n                        results['lidar2ego_translation'],\n                        results['lidar2ego_rotation'],\n                        sweep[sensor]['sensor2global_translation'],\n                        sweep[sensor]['sensor2global_rotation'],\n                        sweep[sensor]['cam_intrinsic'],\n                    ))\n                    if 'ego2lidar' in results:\n                        results['ego2lidar'].append(results['ego2lidar'][0])\n\n        return results\n\n    def __call__(self, results):\n        if self.sweeps_num == 0:\n            return results\n\n        world_size = get_dist_info()[1]\n        if world_size == 1 and self.test_mode:\n            return self.load_online(results)\n        else:\n            return self.load_offline(results)\n\n\n@PIPELINES.register_module()\nclass LoadOccGTFromFile(object):\n    def __init__(self, num_classes=18, inst_class_ids=[]):\n        self.num_classes = num_classes\n        self.inst_class_ids = inst_class_ids\n    \n    def __call__(self, results):\n        occ_labels = np.load(results['occ_path'])\n        semantics = occ_labels['semantics']  # [200, 200, 16]\n        # mask_lidar = occ_labels['mask_lidar'].astype(np.bool_)  # [200, 200, 16]\n        # mask_camera = occ_labels['mask_camera'].astype(np.bool_)  # [200, 200, 
16]\n\n        # results['mask_lidar'] = mask_lidar\n        # results['mask_camera'] = mask_camera\n  \n        # instance GT\n        if 'instances' in occ_labels.keys():\n            instances = occ_labels['instances']\n            instance_class_ids = [self.num_classes - 1]  # the 0-th class is always free class\n            for i in range(1, instances.max() + 1):\n                class_id = np.unique(semantics[instances == i])\n                assert class_id.shape[0] == 1, \"each instance must belong to only one class\"\n                instance_class_ids.append(class_id[0])\n            instance_class_ids = np.array(instance_class_ids)\n        else:\n            instances = None\n            instance_class_ids = None\n\n        instance_count = 0\n        final_instance_class_ids = []\n        final_instances = np.ones_like(semantics) * 255  # empty space has instance id \"255\"\n\n        for class_id in range(self.num_classes - 1):\n            if np.sum(semantics == class_id) == 0:\n                continue\n\n            if class_id in self.inst_class_ids:\n                assert instances is not None, 'instance annotation not found'\n                # treat as instances\n                for instance_id in range(len(instance_class_ids)):\n                    if instance_class_ids[instance_id] != class_id:\n                        continue\n                    final_instances[instances == instance_id] = instance_count\n                    instance_count += 1\n                    final_instance_class_ids.append(class_id)\n            else:\n                # treat as semantics\n                final_instances[semantics == class_id] = instance_count\n                instance_count += 1\n                final_instance_class_ids.append(class_id)\n\n        results['voxel_semantics'] = semantics\n        results['voxel_instances'] = final_instances\n        results['instance_class_ids'] = DC(to_tensor(final_instance_class_ids))\n\n        if 
results.get('rotate_bda', False):\n            semantics = torch.from_numpy(semantics).permute(2, 0, 1)  # [16, 200, 200]\n            semantics = rotate(semantics, results['rotate_bda'], fill=255).permute(1, 2, 0)  # [200, 200, 16]\n            results['voxel_semantics'] = semantics.numpy()\n\n            final_instances = torch.from_numpy(final_instances).permute(2, 0, 1)  # [16, 200, 200]\n            final_instances = rotate(final_instances, results['rotate_bda'], fill=255).permute(1, 2, 0)  # [200, 200, 16]\n            results['voxel_instances'] = final_instances.numpy()\n\n        if results.get('flip_dx', False):\n            results['voxel_semantics'] = results['voxel_semantics'][::-1, ...].copy()\n            results['voxel_instances'] = results['voxel_instances'][::-1, ...].copy()\n            \n        if results.get('flip_dy', False):\n            results['voxel_semantics'] = results['voxel_semantics'][:, ::-1, ...].copy()\n            results['voxel_instances'] = results['voxel_instances'][:, ::-1, ...].copy()\n\n        return results\n\n\n# https://github.com/HuangJunJie2017/BEVDet/blob/58c2587a8f89a1927926f0bdb6cde2917c91a9a5/mmdet3d/datasets/pipelines/loading.py#L1177\n@PIPELINES.register_module()\nclass BEVAug(object):\n    def __init__(self, bda_aug_conf, classes, is_train=True):\n        self.bda_aug_conf = bda_aug_conf\n        self.is_train = is_train\n        self.classes = classes\n\n    def sample_bda_augmentation(self):\n        \"\"\"Generate bda augmentation values based on bda_config.\"\"\"\n        if self.is_train:\n            rotate_bda = np.random.uniform(*self.bda_aug_conf['rot_lim'])\n            scale_bda = np.random.uniform(*self.bda_aug_conf['scale_lim'])\n            flip_dx = np.random.uniform() < self.bda_aug_conf['flip_dx_ratio']\n            flip_dy = np.random.uniform() < self.bda_aug_conf['flip_dy_ratio']\n        else:\n            rotate_bda = 0\n            scale_bda = 1.0\n            flip_dx = False\n            
flip_dy = False\n        return rotate_bda, scale_bda, flip_dx, flip_dy\n\n    def bev_transform(self, rotate_angle, scale_ratio, flip_dx, flip_dy):\n        \"\"\"\n        Returns:\n            rot_mat: (3, 3)\n        \"\"\"\n        rotate_angle = torch.tensor(rotate_angle / 180 * np.pi)\n        rot_sin = torch.sin(rotate_angle)\n        rot_cos = torch.cos(rotate_angle)\n        rot_mat = torch.Tensor([[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0],\n                                [0, 0, 1]])\n        scale_mat = torch.Tensor([[scale_ratio, 0, 0], [0, scale_ratio, 0],\n                                  [0, 0, scale_ratio]])\n        flip_mat = torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])\n\n        if flip_dx:\n            flip_mat = flip_mat @ torch.Tensor([[-1, 0, 0], [0, 1, 0],\n                                                [0, 0, 1]])\n        if flip_dy:\n            flip_mat = flip_mat @ torch.Tensor([[1, 0, 0], [0, -1, 0],\n                                                [0, 0, 1]])\n        rot_mat = flip_mat @ (scale_mat @ rot_mat)\n        \n        return rot_mat\n\n    def __call__(self, results):\n        rotate_bda, scale_bda, flip_dx, flip_dy = self.sample_bda_augmentation()\n\n        bda_mat = torch.zeros(4, 4)\n        bda_mat[3, 3] = 1\n\n        # bda_rot: (3, 3)\n        bda_rot = self.bev_transform(rotate_bda, scale_bda, flip_dx, flip_dy)\n        bda_mat[:3, :3] = bda_rot\n\n        results['bda_mat'] = bda_mat\n        results['flip_dx'] = flip_dx\n        results['flip_dy'] = flip_dy\n        results['rotate_bda'] = rotate_bda\n        results['scale_bda'] = scale_bda\n\n        for i in range(len(results['ego2lidar'])):\n            results['ego2lidar'][i] = results['ego2lidar'][i] @ torch.inverse(bda_mat).numpy()  # [4, 4] @ [4, 4]\n\n        return results\n"
  },
  {
    "path": "loaders/pipelines/transforms.py",
    "content": "import mmcv\nimport torch\nimport numpy as np\nfrom PIL import Image\nfrom numpy import random\nfrom mmdet.datasets.builder import PIPELINES\n\n\n@PIPELINES.register_module()\nclass PadMultiViewImage(object):\n    \"\"\"Pad the multi-view image.\n    There are two padding modes: (1) pad to a fixed size and (2) pad to the\n    minimum size that is divisible by some number.\n    Added keys are \"pad_shape\", \"pad_fixed_size\", \"pad_size_divisor\",\n    Args:\n        size (tuple, optional): Fixed padding size.\n        size_divisor (int, optional): The divisor of padded size.\n        pad_val (float, optional): Padding value, 0 by default.\n    \"\"\"\n\n    def __init__(self, size=None, size_divisor=None, pad_val=0):\n        self.size = size\n        self.size_divisor = size_divisor\n        self.pad_val = pad_val\n        # only one of size and size_divisor should be valid\n        assert size is not None or size_divisor is not None\n        assert size is None or size_divisor is None\n\n    def _pad_img(self, img):\n        if self.size_divisor is not None:\n            pad_h = int(np.ceil(img.shape[0] / self.size_divisor)) * self.size_divisor\n            pad_w = int(np.ceil(img.shape[1] / self.size_divisor)) * self.size_divisor\n        else:\n            pad_h, pad_w = self.size\n\n        pad_width = ((0, pad_h - img.shape[0]), (0, pad_w - img.shape[1]), (0, 0))\n        img = np.pad(img, pad_width, constant_values=self.pad_val)\n        return img\n\n    def _pad_imgs(self, results):\n        padded_img = [self._pad_img(img) for img in results['img']]\n        \n        results['ori_shape'] = [img.shape for img in results['img']]\n        results['img'] = padded_img\n        results['img_shape'] = [img.shape for img in padded_img]\n        results['pad_shape'] = [img.shape for img in padded_img]\n        results['pad_fixed_size'] = self.size\n        results['pad_size_divisor'] = self.size_divisor\n\n    def __call__(self, results):\n       
 \"\"\"Call function to pad images, masks, semantic segmentation maps.\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Updated result dict.\n        \"\"\"\n        self._pad_imgs(results)\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(size={self.size}, '\n        repr_str += f'size_divisor={self.size_divisor}, '\n        repr_str += f'pad_val={self.pad_val})'\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass NormalizeMultiviewImage(object):\n    \"\"\"Normalize the image.\n    Added key is \"img_norm_cfg\".\n    Args:\n        mean (sequence): Mean values of 3 channels.\n        std (sequence): Std values of 3 channels.\n        to_rgb (bool): Whether to convert the image from BGR to RGB,\n            default is true.\n    \"\"\"\n\n    def __init__(self, mean, std, to_rgb=True):\n        self.mean = np.array(mean, dtype=np.float32).reshape(-1)\n        self.std = 1 / np.array(std, dtype=np.float32).reshape(-1)\n        self.to_rgb = to_rgb\n\n    def __call__(self, results):\n        \"\"\"Call function to normalize images.\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Normalized results, 'img_norm_cfg' key is added into\n                result dict.\n        \"\"\"\n        normalized_imgs = []\n\n        for img in results['img']:\n            img = img.astype(np.float32)\n            if self.to_rgb:\n                img = img[..., ::-1]\n            img = img - self.mean\n            img = img * self.std\n            normalized_imgs.append(img)\n\n        results['img'] = normalized_imgs\n        results['img_norm_cfg'] = dict(\n            mean=self.mean,\n            std=self.std,\n            to_rgb=self.to_rgb\n        )\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += 
f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass PhotoMetricDistortionMultiViewImage:\n    \"\"\"Apply photometric distortion to image sequentially, every transformation\n    is applied with a probability of 0.5. The position of random contrast is in\n    second or second to last.\n    1. random brightness\n    2. random contrast (mode 0)\n    3. convert color from BGR to HSV\n    4. random saturation\n    5. random hue\n    6. convert color from HSV to BGR\n    7. random contrast (mode 1)\n    8. randomly swap channels\n    Args:\n        brightness_delta (int): delta of brightness.\n        contrast_range (tuple): range of contrast.\n        saturation_range (tuple): range of saturation.\n        hue_delta (int): delta of hue.\n    \"\"\"\n\n    def __init__(self,\n                 brightness_delta=32,\n                 contrast_range=(0.5, 1.5),\n                 saturation_range=(0.5, 1.5),\n                 hue_delta=18):\n        self.brightness_delta = brightness_delta\n        self.contrast_lower, self.contrast_upper = contrast_range\n        self.saturation_lower, self.saturation_upper = saturation_range\n        self.hue_delta = hue_delta\n\n    def __call__(self, results):\n        \"\"\"Call function to perform photometric distortion on images.\n        Args:\n            results (dict): Result dict from loading pipeline.\n        Returns:\n            dict: Result dict with images distorted.\n        \"\"\"\n        imgs = results['img']\n        new_imgs = []\n        for img in imgs:\n            ori_dtype = img.dtype\n            img = img.astype(np.float32)\n\n            # random brightness\n            if random.randint(2):\n                delta = random.uniform(-self.brightness_delta,\n                                    self.brightness_delta)\n                img += delta\n\n            # mode == 0 --> do random contrast first\n            # mode == 1 --> do random 
contrast last\n            mode = random.randint(2)\n            if mode == 1:\n                if random.randint(2):\n                    alpha = random.uniform(self.contrast_lower,\n                                        self.contrast_upper)\n                    img *= alpha\n\n            # convert color from BGR to HSV\n            img = mmcv.bgr2hsv(img)\n\n            # random saturation\n            if random.randint(2):\n                img[..., 1] *= random.uniform(self.saturation_lower,\n                                            self.saturation_upper)\n\n            # random hue\n            if random.randint(2):\n                img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)\n                img[..., 0][img[..., 0] > 360] -= 360\n                img[..., 0][img[..., 0] < 0] += 360\n\n            # convert color from HSV to BGR\n            img = mmcv.hsv2bgr(img)\n\n            # random contrast\n            if mode == 0:\n                if random.randint(2):\n                    alpha = random.uniform(self.contrast_lower,\n                                        self.contrast_upper)\n                    img *= alpha\n\n            # randomly swap channels\n            if random.randint(2):\n                img = img[..., random.permutation(3)]\n\n            new_imgs.append(img.astype(ori_dtype))\n\n        results['img'] = new_imgs\n        return results\n\n    def __repr__(self):\n        repr_str = self.__class__.__name__\n        repr_str += f'(\\nbrightness_delta={self.brightness_delta},\\n'\n        repr_str += 'contrast_range='\n        repr_str += f'{(self.contrast_lower, self.contrast_upper)},\\n'\n        repr_str += 'saturation_range='\n        repr_str += f'{(self.saturation_lower, self.saturation_upper)},\\n'\n        repr_str += f'hue_delta={self.hue_delta})'\n        return repr_str\n\n\n@PIPELINES.register_module()\nclass RandomTransformImage(object):\n    def __init__(self, ida_aug_conf=None, training=True):\n       
 self.ida_aug_conf = ida_aug_conf\n        self.training = training\n\n    def __call__(self, results):\n        resize, resize_dims, crop, flip, rotate = self.sample_augmentation()\n        \n        if len(results['lidar2img']) == len(results['img']):\n            for i in range(len(results['img'])):\n                img = Image.fromarray(np.uint8(results['img'][i]))\n                \n                # resize, resize_dims, crop, flip, rotate = self._sample_augmentation()\n                img, ida_mat = self.img_transform(\n                    img,\n                    resize=resize,\n                    resize_dims=resize_dims,\n                    crop=crop,\n                    flip=flip,\n                    rotate=rotate,\n                )\n                results['img'][i] = np.array(img).astype(np.uint8)\n                results['lidar2img'][i] = ida_mat @ results['lidar2img'][i]\n\n        elif len(results['img']) == 6:\n            for i in range(len(results['img'])):\n                img = Image.fromarray(np.uint8(results['img'][i]))\n                \n                # resize, resize_dims, crop, flip, rotate = self._sample_augmentation()\n                img, ida_mat = self.img_transform(\n                    img,\n                    resize=resize,\n                    resize_dims=resize_dims,\n                    crop=crop,\n                    flip=flip,\n                    rotate=rotate,\n                )\n                results['img'][i] = np.array(img).astype(np.uint8)\n\n            for i in range(len(results['lidar2img'])):\n                results['lidar2img'][i] = ida_mat @ results['lidar2img'][i]\n\n        else:\n            raise ValueError()\n\n        results['ori_shape'] = [img.shape for img in results['img']]\n        results['img_shape'] = [img.shape for img in results['img']]\n        results['pad_shape'] = [img.shape for img in results['img']]\n\n        return results\n\n    def img_transform(self, img, resize, resize_dims, 
crop, flip, rotate):\n        \"\"\"\n        https://github.com/Megvii-BaseDetection/BEVStereo/blob/master/dataset/nusc_mv_det_dataset.py#L48\n        \"\"\"\n        def get_rot(h):\n            return torch.Tensor([\n                [np.cos(h), np.sin(h)],\n                [-np.sin(h), np.cos(h)],\n            ])\n\n        ida_rot = torch.eye(2)\n        ida_tran = torch.zeros(2)\n\n        # adjust image\n        img = img.resize(resize_dims)\n        img = img.crop(crop)\n        if flip:\n            img = img.transpose(method=Image.FLIP_LEFT_RIGHT)\n        img = img.rotate(rotate)\n\n        # post-homography transformation\n        ida_rot *= resize\n        ida_tran -= torch.Tensor(crop[:2])\n        \n        if flip:\n            A = torch.Tensor([[-1, 0], [0, 1]])\n            b = torch.Tensor([crop[2] - crop[0], 0])\n            ida_rot = A.matmul(ida_rot)\n            ida_tran = A.matmul(ida_tran) + b\n        \n        A = get_rot(rotate / 180 * np.pi)\n        b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2\n        b = A.matmul(-b) + b\n\n        ida_rot = A.matmul(ida_rot)\n        ida_tran = A.matmul(ida_tran) + b\n\n        ida_mat = torch.eye(4)\n        ida_mat[:2, :2] = ida_rot\n        ida_mat[:2, 2] = ida_tran\n\n        return img, ida_mat.numpy()\n\n    def sample_augmentation(self):\n        \"\"\"\n        https://github.com/Megvii-BaseDetection/BEVStereo/blob/master/dataset/nusc_mv_det_dataset.py#L247\n        \"\"\"\n        H, W = self.ida_aug_conf['H'], self.ida_aug_conf['W']\n        fH, fW = self.ida_aug_conf['final_dim']\n\n        if self.training:\n            resize = np.random.uniform(*self.ida_aug_conf['resize_lim'])\n            resize_dims = (int(W * resize), int(H * resize))\n            newW, newH = resize_dims\n            crop_h = int((1 - np.random.uniform(*self.ida_aug_conf['bot_pct_lim'])) * newH) - fH\n            crop_w = int(np.random.uniform(0, max(0, newW - fW)))\n            crop = (crop_w, 
crop_h, crop_w + fW, crop_h + fH)\n            flip = False\n            if self.ida_aug_conf['rand_flip'] and np.random.choice([0, 1]):\n                flip = True\n            rotate = np.random.uniform(*self.ida_aug_conf['rot_lim'])\n        else:\n            resize = max(fH / H, fW / W)\n            resize_dims = (int(W * resize), int(H * resize))\n            newW, newH = resize_dims\n            crop_h = int((1 - np.mean(self.ida_aug_conf['bot_pct_lim'])) * newH) - fH\n            crop_w = int(max(0, newW - fW) / 2)\n            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)\n            flip = False\n            rotate = 0\n\n        return resize, resize_dims, crop, flip, rotate\n\n\n@PIPELINES.register_module()\nclass GlobalRotScaleTransImage(object):\n    def __init__(self,\n                 rot_range=[-0.3925, 0.3925],\n                 scale_ratio_range=[0.95, 1.05],\n                 translation_std=[0, 0, 0]):\n        self.rot_range = rot_range\n        self.scale_ratio_range = scale_ratio_range\n        self.translation_std = translation_std\n\n    def __call__(self, results):\n        # random rotate\n        rot_angle = np.random.uniform(*self.rot_range)\n        self.rotate_z(results, rot_angle)\n        results[\"gt_bboxes_3d\"].rotate(np.array(rot_angle))\n\n        # random scale\n        scale_ratio = np.random.uniform(*self.scale_ratio_range)\n        self.scale_xyz(results, scale_ratio)\n        results[\"gt_bboxes_3d\"].scale(scale_ratio)\n\n        # TODO: support translation\n\n        return results\n\n    def rotate_z(self, results, rot_angle):\n        rot_cos = torch.cos(torch.tensor(rot_angle))\n        rot_sin = torch.sin(torch.tensor(rot_angle))\n\n        rot_mat = torch.tensor([\n            [rot_cos, -rot_sin, 0, 0],\n            [rot_sin, rot_cos, 0, 0],\n            [0, 0, 1, 0],\n            [0, 0, 0, 1],\n        ])\n        rot_mat_inv = torch.inverse(rot_mat)\n\n        for view in range(len(results['lidar2img'])):\n 
           results['lidar2img'][view] = (torch.tensor(results['lidar2img'][view]).float() @ rot_mat_inv).numpy()\n\n    def scale_xyz(self, results, scale_ratio):\n        scale_mat = torch.tensor([\n            [scale_ratio, 0, 0, 0],\n            [0, scale_ratio, 0, 0],\n            [0, 0, scale_ratio, 0],\n            [0, 0, 0, 1],\n        ])\n        scale_mat_inv = torch.inverse(scale_mat)\n\n        for view in range(len(results['lidar2img'])):\n            results['lidar2img'][view] = (torch.tensor(results['lidar2img'][view]).float() @ scale_mat_inv).numpy()\n"
  },
  {
    "path": "loaders/ray_metrics.py",
    "content": "# Acknowledgments: https://github.com/tarashakhurana/4d-occ-forecasting\n# Modified by Haisong Liu\nimport math\nimport copy\nimport numpy as np\nimport torch\nfrom torch.utils.cpp_extension import load\nfrom tqdm import tqdm\nfrom prettytable import PrettyTable\nfrom .ray_pq import Metric_RayPQ\n\n\ndvr = load(\"dvr\", sources=[\"lib/dvr/dvr.cpp\", \"lib/dvr/dvr.cu\"], verbose=True, extra_cuda_cflags=['-allow-unsupported-compiler'])\n_pc_range = [-40, -40, -1.0, 40, 40, 5.4]\n_voxel_size = 0.4\n\n\n# https://github.com/tarashakhurana/4d-occ-forecasting/blob/ff986082cd6ea10e67ab7839bf0e654736b3f4e2/test_fgbg.py#L29C1-L46C16\ndef get_rendered_pcds(origin, points, tindex, pred_dist):\n    pcds = []\n    \n    for t in range(len(origin)):\n        mask = (tindex == t)\n        # skip the ones with no data\n        if not mask.any():\n            continue\n        _pts = points[mask, :3]\n        # use ground truth lidar points for the raycasting direction\n        v = _pts - origin[t][None, :]\n        d = v / np.sqrt((v ** 2).sum(axis=1, keepdims=True))\n        pred_pts = origin[t][None, :] + d * pred_dist[mask][:, None]\n        pcds.append(torch.from_numpy(pred_pts))\n        \n    return pcds\n\n\ndef meshgrid3d(occ_size, pc_range):\n    W, H, D = occ_size\n    \n    xs = torch.linspace(0.5, W - 0.5, W).view(W, 1, 1).expand(W, H, D) / W\n    ys = torch.linspace(0.5, H - 0.5, H).view(1, H, 1).expand(W, H, D) / H\n    zs = torch.linspace(0.5, D - 0.5, D).view(1, 1, D).expand(W, H, D) / D\n    xs = xs * (pc_range[3] - pc_range[0]) + pc_range[0]\n    ys = ys * (pc_range[4] - pc_range[1]) + pc_range[1]\n    zs = zs * (pc_range[5] - pc_range[2]) + pc_range[2]\n    xyz = torch.stack((xs, ys, zs), -1)\n\n    return xyz\n\n\ndef generate_lidar_rays():\n    # prepare lidar ray angles\n    pitch_angles = []\n    for k in range(10):\n        angle = math.pi / 2 - math.atan(k + 1)\n        pitch_angles.append(-angle)\n    \n    # nuscenes lidar fov: 
[0.2107773983152201, -0.5439104895672159] (rad)\n    while pitch_angles[-1] < 0.21:\n        delta = pitch_angles[-1] - pitch_angles[-2]\n        pitch_angles.append(pitch_angles[-1] + delta)\n\n    lidar_rays = []\n    for pitch_angle in pitch_angles:\n        for azimuth_angle in np.arange(0, 360, 1):\n            azimuth_angle = np.deg2rad(azimuth_angle)\n\n            x = np.cos(pitch_angle) * np.cos(azimuth_angle)\n            y = np.cos(pitch_angle) * np.sin(azimuth_angle)\n            z = np.sin(pitch_angle)\n\n            lidar_rays.append((x, y, z))\n\n    return np.array(lidar_rays, dtype=np.float32)\n\n\ndef process_one_sample(sem_pred, lidar_rays, output_origin, instance_pred=None, occ_class_names=None):\n    # lidar origin in ego coordinate\n    # lidar_origin = torch.tensor([[[0.9858, 0.0000, 1.8402]]])\n    T = output_origin.shape[1]\n    pred_pcds_t = []\n\n    free_id = len(occ_class_names) - 1 \n    occ_pred = copy.deepcopy(sem_pred)\n    occ_pred[sem_pred < free_id] = 1\n    occ_pred[sem_pred == free_id] = 0\n    occ_pred = occ_pred.permute(2, 1, 0)\n    occ_pred = occ_pred[None, None, :].contiguous().float()\n\n    offset = torch.Tensor(_pc_range[:3])[None, None, :]\n    scaler = torch.Tensor([_voxel_size] * 3)[None, None, :]\n\n    lidar_tindex = torch.zeros([1, lidar_rays.shape[0]])\n    \n    for t in range(T): \n        lidar_origin = output_origin[:, t:t+1, :]  # [1, 1, 3]\n        lidar_endpts = lidar_rays[None] + lidar_origin  # [1, 15840, 3]\n\n        output_origin_render = ((lidar_origin - offset) / scaler).float()  # [1, 1, 3]\n        output_points_render = ((lidar_endpts - offset) / scaler).float()  # [1, N, 3]\n        output_tindex_render = lidar_tindex  # [1, N], all zeros\n\n        with torch.no_grad():\n            pred_dist, _, coord_index = dvr.render_forward(\n                occ_pred.cuda(),\n                output_origin_render.cuda(),\n                output_points_render.cuda(),\n                
output_tindex_render.cuda(),\n                [1, 16, 200, 200],\n                \"test\"\n            )\n            pred_dist *= _voxel_size\n\n        pred_pcds = get_rendered_pcds(\n            lidar_origin[0].cpu().numpy(),\n            lidar_endpts[0].cpu().numpy(),\n            lidar_tindex[0].cpu().numpy(),\n            pred_dist[0].cpu().numpy()\n        )\n        coord_index = coord_index[0, :, :].int().cpu()  # [N, 3]\n\n        pred_label = sem_pred[coord_index[:, 0], coord_index[:, 1], coord_index[:, 2]][:, None]  # [N, 1]        \n        pred_dist = pred_dist[0, :, None].cpu()\n\n        if instance_pred is not None:\n            pred_instance = instance_pred[coord_index[:, 0], coord_index[:, 1], coord_index[:, 2]][:, None]  # [N, 1]\n            pred_pcds = torch.cat([pred_label.float(), pred_instance.float(), pred_dist], dim=-1)\n        else:\n            pred_pcds = torch.cat([pred_label.float(), pred_dist], dim=-1)\n\n        pred_pcds_t.append(pred_pcds)\n\n    pred_pcds_t = torch.cat(pred_pcds_t, dim=0)\n   \n    return pred_pcds_t.numpy()\n\n\ndef calc_rayiou(pcd_pred_list, pcd_gt_list, occ_class_names):\n    thresholds = [1, 2, 4]\n\n    gt_cnt = np.zeros([len(occ_class_names)])\n    pred_cnt = np.zeros([len(occ_class_names)])\n    tp_cnt = np.zeros([len(thresholds), len(occ_class_names)])\n\n    for pcd_pred, pcd_gt in zip(pcd_pred_list, pcd_gt_list):\n        for j, threshold in enumerate(thresholds):\n            # L1\n            depth_pred = pcd_pred[:, 1]\n            depth_gt = pcd_gt[:, 1]\n            l1_error = np.abs(depth_pred - depth_gt)\n            tp_dist_mask = (l1_error < threshold)\n            \n            for i, cls in enumerate(occ_class_names):\n                cls_id = occ_class_names.index(cls)\n                cls_mask_pred = (pcd_pred[:, 0] == cls_id)\n                cls_mask_gt = (pcd_gt[:, 0] == cls_id)\n\n                gt_cnt_i = cls_mask_gt.sum()\n                pred_cnt_i = cls_mask_pred.sum()\n         
       if j == 0:\n                    gt_cnt[i] += gt_cnt_i\n                    pred_cnt[i] += pred_cnt_i\n\n                tp_cls = cls_mask_gt & cls_mask_pred  # [N]\n                tp_mask = np.logical_and(tp_cls, tp_dist_mask)\n                tp_cnt[j][i] += tp_mask.sum()\n    \n    iou_list = []\n    for j, threshold in enumerate(thresholds):\n        iou_list.append((tp_cnt[j] / (gt_cnt + pred_cnt - tp_cnt[j]))[:-1])\n\n    return iou_list\n\n\ndef main_rayiou(sem_pred_list, sem_gt_list, lidar_origin_list, occ_class_names):\n    torch.cuda.empty_cache()\n\n    # generate lidar rays\n    lidar_rays = generate_lidar_rays()\n    lidar_rays = torch.from_numpy(lidar_rays)\n\n    pcd_pred_list, pcd_gt_list = [], []\n    for sem_pred, sem_gt, lidar_origins in tqdm(zip(sem_pred_list, sem_gt_list, lidar_origin_list), ncols=50):\n        sem_pred = torch.from_numpy(np.reshape(sem_pred, [200, 200, 16]))\n        sem_gt = torch.from_numpy(np.reshape(sem_gt, [200, 200, 16]))\n\n        pcd_pred = process_one_sample(sem_pred, lidar_rays, lidar_origins, occ_class_names=occ_class_names)\n        pcd_gt = process_one_sample(sem_gt, lidar_rays, lidar_origins, occ_class_names=occ_class_names)\n\n        # evalute on non-free rays\n        valid_mask = (pcd_gt[:, 0].astype(np.int32) != len(occ_class_names) - 1)\n        pcd_pred = pcd_pred[valid_mask]\n        pcd_gt = pcd_gt[valid_mask]\n\n        assert pcd_pred.shape == pcd_gt.shape\n        pcd_pred_list.append(pcd_pred)\n        pcd_gt_list.append(pcd_gt)\n\n    iou_list = calc_rayiou(pcd_pred_list, pcd_gt_list, occ_class_names)\n    rayiou = np.nanmean(iou_list)\n    rayiou_0 = np.nanmean(iou_list[0])\n    rayiou_1 = np.nanmean(iou_list[1])\n    rayiou_2 = np.nanmean(iou_list[2])\n    \n    table = PrettyTable([\n        'Class Names',\n        'RayIoU@1', 'RayIoU@2', 'RayIoU@4'\n    ])\n    table.float_format = '.3'\n\n    for i in range(len(occ_class_names) - 1):\n        table.add_row([\n            
occ_class_names[i],\n            iou_list[0][i], iou_list[1][i], iou_list[2][i]\n        ], divider=(i == len(occ_class_names) - 2))\n    \n    table.add_row(['MEAN', rayiou_0, rayiou_1, rayiou_2])\n\n    print(table)\n\n    torch.cuda.empty_cache()\n\n    return {\n        'RayIoU': rayiou,\n        'RayIoU@1': rayiou_0,\n        'RayIoU@2': rayiou_1,\n        'RayIoU@4': rayiou_2,\n    }\n\n\ndef main_raypq(sem_pred_list, sem_gt_list, inst_pred_list, inst_gt_list, lidar_origin_list, occ_class_names):\n    torch.cuda.empty_cache()\n\n    eval_metrics_pq = Metric_RayPQ(\n        occ_class_names=occ_class_names,\n        num_classes=len(occ_class_names),\n        thresholds=[1, 2, 4]\n    )\n\n    # generate lidar rays\n    lidar_rays = generate_lidar_rays()\n    lidar_rays = torch.from_numpy(lidar_rays)\n\n    for sem_pred, sem_gt, inst_pred, inst_gt, lidar_origins in \\\n        tqdm(zip(sem_pred_list, sem_gt_list, inst_pred_list, inst_gt_list, lidar_origin_list), ncols=50):\n        sem_pred = torch.from_numpy(np.reshape(sem_pred, [200, 200, 16]))\n        sem_gt = torch.from_numpy(np.reshape(sem_gt, [200, 200, 16]))\n\n        inst_pred = torch.from_numpy(np.reshape(inst_pred, [200, 200, 16]))\n        inst_gt = torch.from_numpy(np.reshape(inst_gt, [200, 200, 16]))\n\n        pcd_pred = process_one_sample(sem_pred, lidar_rays, lidar_origins, instance_pred=inst_pred, occ_class_names=occ_class_names)\n        pcd_gt = process_one_sample(sem_gt, lidar_rays, lidar_origins, instance_pred=inst_gt, occ_class_names=occ_class_names)\n\n        # evalute on non-free rays\n        valid_mask = (pcd_gt[:, 0].astype(np.int32) != len(occ_class_names) - 1)\n        pcd_pred = pcd_pred[valid_mask]\n        pcd_gt = pcd_gt[valid_mask]\n\n        assert pcd_pred.shape == pcd_gt.shape\n        \n        sem_gt = pcd_gt[:, 0].astype(np.int32)\n        sem_pred = pcd_pred[:, 0].astype(np.int32)\n\n        instances_gt = pcd_gt[:, 1].astype(np.int32)\n        instances_pred = 
pcd_pred[:, 1].astype(np.int32)\n\n        # L1\n        depth_gt = pcd_gt[:, 2]\n        depth_pred = pcd_pred[:, 2]\n        l1_error = np.abs(depth_pred - depth_gt)\n\n        eval_metrics_pq.add_batch(sem_pred, sem_gt, instances_pred, instances_gt, l1_error)\n\n    torch.cuda.empty_cache()\n\n    return eval_metrics_pq.count_pq()\n"
  },
  {
    "path": "loaders/ray_pq.py",
    "content": "import numpy as np\nfrom prettytable import PrettyTable\n\n\nclass Metric_RayPQ:\n    def __init__(self,\n                 occ_class_names, \n                 num_classes=18,\n                 thresholds=[1, 2, 4]):\n        \"\"\"\n        Args:\n            ignore_index (llist): Class ids that not be considered in pq counting.\n        \"\"\"\n        if num_classes == 18 or num_classes == 17:\n            self.class_names = occ_class_names\n        else:\n            raise ValueError\n        \n        self.num_classes = num_classes\n        self.id_offset = 2 ** 16\n        self.eps = 1e-5\n        self.thresholds = thresholds\n        \n        self.min_num_points = 10\n        self.include = np.array(\n            [n for n in range(self.num_classes - 1)],\n            dtype=int)\n        self.cnt = 0\n        \n        # panoptic stuff\n        self.pan_tp = np.zeros([len(self.thresholds), num_classes], dtype=int)\n        self.pan_iou = np.zeros([len(self.thresholds), num_classes], dtype=np.double)\n        self.pan_fp = np.zeros([len(self.thresholds), num_classes], dtype=int)\n        self.pan_fn = np.zeros([len(self.thresholds), num_classes], dtype=int)\n        \n    def add_batch(self,semantics_pred,semantics_gt,instances_pred,instances_gt, l1_error):\n        self.cnt += 1\n        self.add_panoptic_sample(semantics_pred, semantics_gt, instances_pred, instances_gt, l1_error) \n    \n    def add_panoptic_sample(self, semantics_pred, semantics_gt, instances_pred, instances_gt, l1_error):\n        \"\"\"Add one sample of panoptic predictions and ground truths for\n        evaluation.\n\n        Args:\n            semantics_pred (np.ndarray): Semantic predictions.\n            semantics_gt (np.ndarray): Semantic ground truths.\n            instances_pred (np.ndarray): Instance predictions.\n            instances_gt (np.ndarray): Instance ground truths.\n        \"\"\"\n        # get instance_class_id from instance_gt\n        
instance_class_ids = [self.num_classes - 1]\n        for i in range(1, instances_gt.max() + 1):\n            class_id = np.unique(semantics_gt[instances_gt == i])\n            # assert class_id.shape[0] == 1, \"each instance must belong to only one class\"\n            if class_id.shape[0] == 1:\n                instance_class_ids.append(class_id[0])\n            else:\n                instance_class_ids.append(self.num_classes - 1)\n        instance_class_ids = np.array(instance_class_ids)\n\n        instance_count = 1\n        final_instance_class_ids = []\n        final_instances = np.zeros_like(instances_gt)  # empty space has instance id \"0\"\n\n        for class_id in range(self.num_classes - 1):\n            if np.sum(semantics_gt == class_id) == 0:\n                continue\n\n            if self.class_names[class_id] in ['car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'motorcycle', 'bicycle', 'pedestrian']:\n                # treat as instances\n                for instance_id in range(len(instance_class_ids)):\n                    if instance_class_ids[instance_id] != class_id:\n                        continue\n                    final_instances[instances_gt == instance_id] = instance_count\n                    instance_count += 1\n                    final_instance_class_ids.append(class_id)\n            else:\n                # treat as semantics\n                final_instances[semantics_gt == class_id] = instance_count\n                instance_count += 1\n                final_instance_class_ids.append(class_id)\n                \n        instances_gt = final_instances\n        \n        # avoid zero (ignored label)\n        instances_pred = instances_pred + 1\n        instances_gt = instances_gt + 1\n        \n        for j, threshold in enumerate(self.thresholds):\n            tp_dist_mask = l1_error < threshold\n            # for each class (except the ignored ones)\n            for cl in self.include:\n                # get a class 
mask\n                pred_inst_in_cl_mask = semantics_pred == cl\n                gt_inst_in_cl_mask = semantics_gt == cl\n\n                # get instance points in class (makes outside stuff 0)\n                pred_inst_in_cl = instances_pred * pred_inst_in_cl_mask.astype(int)\n                gt_inst_in_cl = instances_gt * gt_inst_in_cl_mask.astype(int)\n\n                # generate the areas for each unique instance prediction\n                unique_pred, counts_pred = np.unique(\n                    pred_inst_in_cl[pred_inst_in_cl > 0], return_counts=True)\n                id2idx_pred = {id: idx for idx, id in enumerate(unique_pred)}\n                matched_pred = np.array([False] * unique_pred.shape[0])\n\n                # generate the areas for each unique instance gt_np\n                unique_gt, counts_gt = np.unique(\n                    gt_inst_in_cl[gt_inst_in_cl > 0], return_counts=True)\n                id2idx_gt = {id: idx for idx, id in enumerate(unique_gt)}\n                matched_gt = np.array([False] * unique_gt.shape[0])\n\n                # generate intersection using offset\n                valid_combos = np.logical_and(pred_inst_in_cl > 0,\n                                            gt_inst_in_cl > 0)\n                # add dist_mask\n                valid_combos = np.logical_and(valid_combos, tp_dist_mask)\n\n                id_offset_combo = pred_inst_in_cl[\n                    valid_combos] + self.id_offset * gt_inst_in_cl[valid_combos]\n                unique_combo, counts_combo = np.unique(\n                    id_offset_combo, return_counts=True)\n\n                # generate an intersection map\n                # count the intersections with over 0.5 IoU as TP\n                gt_labels = unique_combo // self.id_offset\n                pred_labels = unique_combo % self.id_offset\n                gt_areas = np.array([counts_gt[id2idx_gt[id]] for id in gt_labels])\n                pred_areas = np.array(\n                    
[counts_pred[id2idx_pred[id]] for id in pred_labels])\n                intersections = counts_combo\n                unions = gt_areas + pred_areas - intersections\n                ious = intersections.astype(float) / unions.astype(float)\n\n                tp_indexes = ious > 0.5\n                self.pan_tp[j][cl] += np.sum(tp_indexes)\n                self.pan_iou[j][cl] += np.sum(ious[tp_indexes])\n\n                matched_gt[[id2idx_gt[id] for id in gt_labels[tp_indexes]]] = True\n                matched_pred[[id2idx_pred[id]\n                            for id in pred_labels[tp_indexes]]] = True\n\n                # count the FN\n                if len(counts_gt) > 0:\n                    self.pan_fn[j][cl] += np.sum(\n                        np.logical_and(counts_gt >= self.min_num_points,\n                                    ~matched_gt))\n\n                # count the FP\n                if len(matched_pred) > 0:\n                    self.pan_fp[j][cl] += np.sum(\n                        np.logical_and(counts_pred >= self.min_num_points,\n                                    ~matched_pred))\n    \n    def count_pq(self):\n        sq_all = self.pan_iou.astype(np.double) / np.maximum(\n            self.pan_tp.astype(np.double), self.eps)\n        rq_all = self.pan_tp.astype(np.double) / np.maximum(\n            self.pan_tp.astype(np.double) + 0.5 * self.pan_fp.astype(np.double)\n            + 0.5 * self.pan_fn.astype(np.double), self.eps)\n        pq_all = sq_all * rq_all\n        \n        # mask classes not occurring in dataset\n        mask = (self.pan_tp + self.pan_fp + self.pan_fn) > 0\n        pq_all[~mask] = float('nan')\n\n        table = PrettyTable([\n            'Class Names',\n            'RayPQ@%d' % self.thresholds[0],\n            'RayPQ@%d' % self.thresholds[1],\n            'RayPQ@%d' % self.thresholds[2]\n        ])\n        table.float_format = '.3'\n\n        for i in range(len(self.class_names) - 1):\n            table.add_row([\n        
        self.class_names[i],\n                pq_all[0][i], pq_all[1][i], pq_all[2][i],\n            ], divider=(i == len(self.class_names) - 2))\n        \n        table.add_row([\n            'MEAN',\n            np.nanmean(pq_all[0]), np.nanmean(pq_all[1]), np.nanmean(pq_all[2])\n        ])\n\n        print(table)\n\n        return {\n            'RayPQ': np.nanmean(pq_all),\n            'RayPQ@1': np.nanmean(pq_all[0]),\n            'RayPQ@2': np.nanmean(pq_all[1]),\n            'RayPQ@4': np.nanmean(pq_all[2]),\n        }\n"
  },
  {
    "path": "models/__init__.py",
    "content": "from .backbones import __all__\r\nfrom .bbox import __all__\r\nfrom .sparseocc import SparseOcc\r\nfrom .sparseocc_head import SparseOccHead\r\nfrom .sparseocc_transformer import SparseOccTransformer\r\nfrom .loss_utils import *\r\n\r\n__all__ = []\r\n"
  },
  {
    "path": "models/backbones/__init__.py",
    "content": "from .vovnet import VoVNet\n\n__all__ = ['VoVNet']\n"
  },
  {
    "path": "models/backbones/vovnet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport warnings\nimport torch.utils.checkpoint as cp\nfrom collections import OrderedDict\nfrom mmcv.runner import BaseModule\nfrom mmdet.models.builder import BACKBONES\nfrom torch.nn.modules.batchnorm import _BatchNorm\n\n\nVoVNet19_slim_dw_eSE = {\n    'stem': [64, 64, 64],\n    'stage_conv_ch': [64, 80, 96, 112],\n    'stage_out_ch': [112, 256, 384, 512],\n    \"layer_per_block\": 3,\n    \"block_per_stage\": [1, 1, 1, 1],\n    \"eSE\": True,\n    \"dw\": True\n}\n\nVoVNet19_dw_eSE = {\n    'stem': [64, 64, 64],\n    \"stage_conv_ch\": [128, 160, 192, 224],\n    \"stage_out_ch\": [256, 512, 768, 1024],\n    \"layer_per_block\": 3,\n    \"block_per_stage\": [1, 1, 1, 1],\n    \"eSE\": True,\n    \"dw\": True\n}\n\nVoVNet19_slim_eSE = {\n    'stem': [64, 64, 128],\n    'stage_conv_ch': [64, 80, 96, 112],\n    'stage_out_ch': [112, 256, 384, 512],\n    'layer_per_block': 3,\n    'block_per_stage': [1, 1, 1, 1],\n    'eSE': True,\n    \"dw\": False\n}\n\nVoVNet19_eSE = {\n    'stem': [64, 64, 128],\n    \"stage_conv_ch\": [128, 160, 192, 224],\n    \"stage_out_ch\": [256, 512, 768, 1024],\n    \"layer_per_block\": 3,\n    \"block_per_stage\": [1, 1, 1, 1],\n    \"eSE\": True,\n    \"dw\": False\n}\n\nVoVNet39_eSE = {\n    'stem': [64, 64, 128],\n    \"stage_conv_ch\": [128, 160, 192, 224],\n    \"stage_out_ch\": [256, 512, 768, 1024],\n    \"layer_per_block\": 5,\n    \"block_per_stage\": [1, 1, 2, 2],\n    \"eSE\": True,\n    \"dw\": False\n}\n\nVoVNet57_eSE = {\n    'stem': [64, 64, 128],\n    \"stage_conv_ch\": [128, 160, 192, 224],\n    \"stage_out_ch\": [256, 512, 768, 1024],\n    \"layer_per_block\": 5,\n    \"block_per_stage\": [1, 1, 4, 3],\n    \"eSE\": True,\n    \"dw\": False\n}\n\nVoVNet99_eSE = {\n    'stem': [64, 64, 128],\n    \"stage_conv_ch\": [128, 160, 192, 224],\n    \"stage_out_ch\": [256, 512, 768, 1024],\n    \"layer_per_block\": 5,\n    \"block_per_stage\": 
[1, 3, 9, 3],\n    \"eSE\": True,\n    \"dw\": False\n}\n\n_STAGE_SPECS = {\n    \"V-19-slim-dw-eSE\": VoVNet19_slim_dw_eSE,\n    \"V-19-dw-eSE\": VoVNet19_dw_eSE,\n    \"V-19-slim-eSE\": VoVNet19_slim_eSE,\n    \"V-19-eSE\": VoVNet19_eSE,\n    \"V-39-eSE\": VoVNet39_eSE,\n    \"V-57-eSE\": VoVNet57_eSE,\n    \"V-99-eSE\": VoVNet99_eSE,\n}\n\n\ndef dw_conv3x3(in_channels, out_channels, module_name, postfix, stride=1, kernel_size=3, padding=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return [\n        (\n            '{}_{}/dw_conv3x3'.format(module_name, postfix),\n            nn.Conv2d(\n                in_channels,\n                out_channels,\n                kernel_size=kernel_size,\n                stride=stride,\n                padding=padding,\n                groups=out_channels,\n                bias=False\n            )\n        ),\n        (\n            '{}_{}/pw_conv1x1'.format(module_name, postfix),\n            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, groups=1, bias=False)\n        ),\n        ('{}_{}/pw_norm'.format(module_name, postfix), nn.BatchNorm2d(out_channels)),\n        ('{}_{}/pw_relu'.format(module_name, postfix), nn.ReLU(inplace=True)),\n    ]\n\n\ndef conv3x3(in_channels, out_channels, module_name, postfix, stride=1, groups=1, kernel_size=3, padding=1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return [\n        (\n            f\"{module_name}_{postfix}/conv\",\n            nn.Conv2d(\n                in_channels,\n                out_channels,\n                kernel_size=kernel_size,\n                stride=stride,\n                padding=padding,\n                groups=groups,\n                bias=False,\n            ),\n        ),\n        (f\"{module_name}_{postfix}/norm\", nn.BatchNorm2d(out_channels)),\n        (f\"{module_name}_{postfix}/relu\", nn.ReLU(inplace=True)),\n    ]\n\n\ndef conv1x1(in_channels, out_channels, module_name, postfix, stride=1, groups=1, 
kernel_size=1, padding=0):\n    \"\"\"1x1 convolution with padding\"\"\"\n    return [\n        (\n            f\"{module_name}_{postfix}/conv\",\n            nn.Conv2d(\n                in_channels,\n                out_channels,\n                kernel_size=kernel_size,\n                stride=stride,\n                padding=padding,\n                groups=groups,\n                bias=False,\n            ),\n        ),\n        (f\"{module_name}_{postfix}/norm\", nn.BatchNorm2d(out_channels)),\n        (f\"{module_name}_{postfix}/relu\", nn.ReLU(inplace=True)),\n    ]\n\n\nclass Hsigmoid(nn.Module):\n    def __init__(self, inplace=True):\n        super(Hsigmoid, self).__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        return F.relu6(x + 3.0, inplace=self.inplace) / 6.0\n\n\nclass eSEModule(nn.Module):\n    def __init__(self, channel, reduction=4):\n        super(eSEModule, self).__init__()\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        self.fc = nn.Conv2d(channel, channel, kernel_size=1, padding=0)\n        self.hsigmoid = Hsigmoid()\n\n    def forward(self, x):\n        inputs = x\n        x = self.avg_pool(x)\n        x = self.fc(x)\n        x = self.hsigmoid(x)\n        return inputs * x\n\n\nclass _OSA_module(nn.Module):\n    def __init__(self, in_ch, stage_ch, concat_ch, layer_per_block, module_name, SE=False, identity=False, depthwise=False, with_cp=False):\n        super(_OSA_module, self).__init__()\n        self.with_cp = with_cp\n\n        self.identity = identity\n        self.depthwise = depthwise\n        self.isReduced = False\n        self.layers = nn.ModuleList()\n        in_channel = in_ch\n\n        if self.depthwise and in_channel != stage_ch:\n            self.isReduced = True\n            self.conv_reduction = nn.Sequential(\n                OrderedDict(conv1x1(in_channel, stage_ch, \"{}_reduction\".format(module_name), \"0\"))\n            )\n\n        for i in range(layer_per_block):\n            
if self.depthwise:\n                self.layers.append(nn.Sequential(OrderedDict(dw_conv3x3(stage_ch, stage_ch, module_name, i))))\n            else:\n                self.layers.append(nn.Sequential(OrderedDict(conv3x3(in_channel, stage_ch, module_name, i))))\n            in_channel = stage_ch\n\n        # feature aggregation\n        in_channel = in_ch + layer_per_block * stage_ch\n        self.concat = nn.Sequential(OrderedDict(conv1x1(in_channel, concat_ch, module_name, \"concat\")))\n\n        self.ese = eSEModule(concat_ch)\n\n    def _forward(self, x):\n        identity_feat = x\n\n        output = []\n        output.append(x)\n\n        if self.depthwise and self.isReduced:\n            x = self.conv_reduction(x)\n\n        for layer in self.layers:\n            x = layer(x)\n            output.append(x)\n\n        x = torch.cat(output, dim=1)\n        xt = self.concat(x)\n\n        xt = self.ese(xt)\n\n        if self.identity:\n            xt = xt + identity_feat\n\n        return xt\n\n    def forward(self, x):\n        if self.with_cp and self.training and x.requires_grad:\n            return cp.checkpoint(self._forward, x)\n        else:\n            return self._forward(x)\n\n\nclass _OSA_stage(nn.Sequential):\n    def __init__(self, in_ch, stage_ch, concat_ch, block_per_stage, layer_per_block, stage_num, SE=False, depthwise=False, with_cp=False):\n        super(_OSA_stage, self).__init__()\n        if not stage_num == 2:\n            self.add_module(\"Pooling\", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True))\n\n        if block_per_stage != 1:\n            SE = False\n\n        module_name = f\"OSA{stage_num}_1\"\n        self.add_module(\n            module_name, _OSA_module(in_ch, stage_ch, concat_ch, layer_per_block, module_name, SE, depthwise=depthwise, with_cp=with_cp)\n        )\n\n        for i in range(block_per_stage - 1):\n            if i != block_per_stage - 2:  # last block\n                SE = False\n            module_name = 
f\"OSA{stage_num}_{i + 2}\"\n            self.add_module(\n                module_name,\n                _OSA_module(\n                    concat_ch,\n                    stage_ch,\n                    concat_ch,\n                    layer_per_block,\n                    module_name,\n                    SE,\n                    identity=True,\n                    depthwise=depthwise,\n                    with_cp=with_cp\n                ),\n            )\n\n\n@BACKBONES.register_module()\nclass VoVNet(BaseModule):\n    def __init__(self, spec_name, input_ch=3, out_features=None, frozen_stages=-1, norm_eval=True, with_cp=False, pretrained=None, init_cfg=None):\n        \"\"\"\n        Args:\n            input_ch(int) : the number of input channel\n            out_features (list[str]): name of the layers whose outputs should\n                be returned in forward. Can be anything in \"stem\", \"stage2\" ...\n        \"\"\"\n        super(VoVNet, self).__init__(init_cfg)\n        self.frozen_stages = frozen_stages\n        self.norm_eval = norm_eval\n\n        if isinstance(pretrained, str):\n            warnings.warn('DeprecationWarning: pretrained is deprecated, '\n                          'please use \"init_cfg\" instead')\n            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)\n        stage_specs = _STAGE_SPECS[spec_name]\n\n        stem_ch = stage_specs[\"stem\"]\n        config_stage_ch = stage_specs[\"stage_conv_ch\"]\n        config_concat_ch = stage_specs[\"stage_out_ch\"]\n        block_per_stage = stage_specs[\"block_per_stage\"]\n        layer_per_block = stage_specs[\"layer_per_block\"]\n        SE = stage_specs[\"eSE\"]\n        depthwise = stage_specs[\"dw\"]\n\n        self._out_features = out_features\n\n        # Stem module\n        conv_type = dw_conv3x3 if depthwise else conv3x3\n        stem = conv3x3(input_ch, stem_ch[0], \"stem\", \"1\", 2)\n        stem += conv_type(stem_ch[0], stem_ch[1], \"stem\", \"2\", 1)\n        
stem += conv_type(stem_ch[1], stem_ch[2], \"stem\", \"3\", 2)\n        self.add_module(\"stem\", nn.Sequential((OrderedDict(stem))))\n        current_stirde = 4\n        self._out_feature_strides = {\"stem\": current_stirde, \"stage2\": current_stirde}\n        self._out_feature_channels = {\"stem\": stem_ch[2]}\n\n        stem_out_ch = [stem_ch[2]]\n        in_ch_list = stem_out_ch + config_concat_ch[:-1]\n\n        # OSA stages\n        self.stage_names = []\n        for i in range(4):  # num_stages\n            name = \"stage%d\" % (i + 2)  # stage 2 ... stage 5\n            self.stage_names.append(name)\n            self.add_module(\n                name,\n                _OSA_stage(\n                    in_ch_list[i],\n                    config_stage_ch[i],\n                    config_concat_ch[i],\n                    block_per_stage[i],\n                    layer_per_block,\n                    i + 2,\n                    SE,\n                    depthwise,\n                    with_cp=with_cp\n                ),\n            )\n\n            self._out_feature_channels[name] = config_concat_ch[i]\n            if not i == 0:\n                self._out_feature_strides[name] = current_stirde = int(current_stirde * 2)\n\n        # initialize weights\n        # self._initialize_weights()\n\n    def _initialize_weights(self):\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight)\n\n    def forward(self, x):\n        outputs = {}\n        x = self.stem(x)\n        if \"stem\" in self._out_features:\n            outputs[\"stem\"] = x\n        for name in self.stage_names:\n            x = getattr(self, name)(x)\n            if name in self._out_features:\n                outputs[name] = x\n\n        return outputs\n\n    def _freeze_stages(self):\n        if self.frozen_stages >= 0:\n            m = getattr(self, 'stem')\n            m.eval()\n            for param in m.parameters():\n         
       param.requires_grad = False\n\n        for i in range(1, self.frozen_stages + 1):\n            m = getattr(self, f'stage{i+1}')\n            m.eval()\n            for param in m.parameters():\n                param.requires_grad = False\n\n    def train(self, mode=True):\n        \"\"\"Convert the model into training mode while keep normalization layer\n        freezed.\"\"\"\n        super(VoVNet, self).train(mode)\n        self._freeze_stages()\n        if mode and self.norm_eval:\n            for m in self.modules():\n                # trick: eval have effect on BatchNorm only\n                if isinstance(m, _BatchNorm):\n                    m.eval()\n"
  },
  {
    "path": "models/bbox/__init__.py",
    "content": "from .assigners import __all__\nfrom .coders import __all__\nfrom .match_costs import __all__"
  },
  {
    "path": "models/bbox/assigners/__init__.py",
    "content": "from .hungarian_assigner_3d import HungarianAssigner3D\n\n__all__ = ['HungarianAssigner3D']\n"
  },
  {
    "path": "models/bbox/assigners/hungarian_assigner_3d.py",
    "content": "import torch\n\nfrom mmdet.core.bbox.builder import BBOX_ASSIGNERS\nfrom mmdet.core.bbox.assigners import AssignResult\nfrom mmdet.core.bbox.assigners import BaseAssigner\nfrom mmdet.core.bbox.match_costs import build_match_cost\nfrom ..utils import normalize_bbox\n\ntry:\n    from scipy.optimize import linear_sum_assignment\nexcept ImportError:\n    linear_sum_assignment = None\n\n\n@BBOX_ASSIGNERS.register_module()\nclass HungarianAssigner3D(BaseAssigner):\n    def __init__(self,\n                 cls_cost=dict(type='ClassificationCost', weight=1.),\n                 reg_cost=dict(type='BBoxL1Cost', weight=1.0),\n                 iou_cost=dict(type='IoUCost', weight=0.0),\n                 pc_range=None):\n        self.cls_cost = build_match_cost(cls_cost)\n        self.reg_cost = build_match_cost(reg_cost)\n        self.iou_cost = build_match_cost(iou_cost)\n        self.pc_range = pc_range\n\n    def assign(self,\n               bbox_pred,\n               cls_pred,\n               gt_bboxes,\n               gt_labels,\n               gt_bboxes_ignore=None,\n               code_weights=None,\n               with_velo=False):\n        assert gt_bboxes_ignore is None, \\\n            'Only case when gt_bboxes_ignore is None is supported.'\n        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)\n\n        # 1. 
assign -1 by default\n        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),\n                                              -1,\n                                              dtype=torch.long)\n        assigned_labels = bbox_pred.new_full((num_bboxes, ),\n                                             -1,\n                                             dtype=torch.long)\n        if num_gts == 0 or num_bboxes == 0:\n            # No ground truth or boxes, return empty assignment\n            if num_gts == 0:\n                # No ground truth, assign all to background\n                assigned_gt_inds[:] = 0\n            return AssignResult(\n                num_gts, assigned_gt_inds, None, labels=assigned_labels)\n        \n        # 2. compute the weighted costs\n        # classification and bboxcost.\n        cls_cost = self.cls_cost(cls_pred, gt_labels)\n        # regression L1 cost\n        normalized_gt_bboxes = normalize_bbox(gt_bboxes)\n        \n        if code_weights is not None:\n            bbox_pred = bbox_pred * code_weights\n            normalized_gt_bboxes = normalized_gt_bboxes * code_weights\n        \n        if with_velo:\n            reg_cost = self.reg_cost(bbox_pred, normalized_gt_bboxes)\n        else:\n            reg_cost = self.reg_cost(bbox_pred[:, :8], normalized_gt_bboxes[:, :8])\n        \n        # weighted sum of above two costs\n        cost = cls_cost + reg_cost\n        \n        # 3. 
do Hungarian matching on CPU using linear_sum_assignment\n        cost = cost.detach().cpu()\n        cost = torch.nan_to_num(cost, nan=100.0, posinf=100.0, neginf=-100.0)\n        \n        if linear_sum_assignment is None:\n            raise ImportError('Please run \"pip install scipy\" '\n                              'to install scipy first.')\n        \n        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)\n        matched_row_inds = torch.from_numpy(matched_row_inds).to(\n            bbox_pred.device)\n        matched_col_inds = torch.from_numpy(matched_col_inds).to(\n            bbox_pred.device)\n\n        # 4. assign backgrounds and foregrounds\n        # assign all indices to backgrounds first\n        assigned_gt_inds[:] = 0\n        # assign foregrounds based on matching results\n        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1\n        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]\n        return AssignResult(\n            num_gts, assigned_gt_inds, None, labels=assigned_labels)\n"
  },
  {
    "path": "models/bbox/coders/__init__.py",
    "content": "from .nms_free_coder import NMSFreeCoder\n\n__all__ = ['NMSFreeCoder']\n"
  },
  {
    "path": "models/bbox/coders/nms_free_coder.py",
    "content": "import torch\n\nfrom mmdet.core.bbox import BaseBBoxCoder\nfrom mmdet.core.bbox.builder import BBOX_CODERS\nfrom ..utils import denormalize_bbox\n\n\n@BBOX_CODERS.register_module()\nclass NMSFreeCoder(BaseBBoxCoder):\n    \"\"\"Bbox coder for NMS-free detector.\n    Args:\n        pc_range (list[float]): Range of point cloud.\n        post_center_range (list[float]): Limit of the center.\n            Default: None.\n        max_num (int): Max number to be kept. Default: 100.\n        score_threshold (float): Threshold to filter boxes based on score.\n            Default: None.\n        code_size (int): Code size of bboxes. Default: 9\n    \"\"\"\n    def __init__(self,\n                 pc_range,\n                 voxel_size=None,\n                 post_center_range=None,\n                 max_num=100,\n                 score_threshold=None,\n                 num_classes=10):\n        self.pc_range = pc_range\n        self.voxel_size = voxel_size\n        self.post_center_range = post_center_range\n        self.max_num = max_num\n        self.score_threshold = score_threshold\n        self.num_classes = num_classes\n\n    def encode(self):\n        pass\n\n    def decode_single(self, cls_scores, bbox_preds):\n        \"\"\"Decode bboxes.\n        Args:\n            cls_scores (Tensor): Outputs from the classification head, \\\n                shape [num_query, cls_out_channels]. Note \\\n                cls_out_channels should includes background.\n            bbox_preds (Tensor): Outputs from the regression \\\n                head with normalized coordinate format (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy). 
\\\n                Shape [num_query, 9].\n        Returns:\n            list[dict]: Decoded boxes.\n        \"\"\"\n        max_num = self.max_num\n\n        cls_scores = cls_scores.sigmoid()\n        scores, indexs = cls_scores.view(-1).topk(max_num)\n        labels = indexs % self.num_classes\n        bbox_index = torch.div(indexs, self.num_classes, rounding_mode='trunc')\n        bbox_preds = bbox_preds[bbox_index]\n\n        final_box_preds = denormalize_bbox(bbox_preds)   \n        final_scores = scores \n        final_preds = labels \n\n        # use score threshold\n        if self.score_threshold is not None:\n            thresh_mask = final_scores > self.score_threshold\n\n        if self.post_center_range is not None:\n            limit = torch.tensor(self.post_center_range, device=scores.device)\n            mask = (final_box_preds[..., :3] >= limit[:3]).all(1)\n            mask &= (final_box_preds[..., :3] <= limit[3:]).all(1)\n\n            if self.score_threshold:\n                mask &= thresh_mask\n\n            boxes3d = final_box_preds[mask]\n            scores = final_scores[mask]\n            labels = final_preds[mask]\n            predictions_dict = {\n                'bboxes': boxes3d,\n                'scores': scores,\n                'labels': labels\n            }\n\n        else:\n            raise NotImplementedError(\n                'Need to reorganize output as a batch, only '\n                'support post_center_range is not None for now!'\n            )\n\n        return predictions_dict\n\n    def decode(self, preds_dicts):\n        \"\"\"Decode bboxes.\n        Args:\n            all_cls_scores (Tensor): Outputs from the classification head, \\\n                shape [nb_dec, bs, num_query, cls_out_channels]. 
Note \\\n                cls_out_channels should includes background.\n            all_bbox_preds (Tensor): Sigmoid outputs from the regression \\\n                head with normalized coordinate format (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy). \\\n                Shape [nb_dec, bs, num_query, 9].\n        Returns:\n            list[dict]: Decoded boxes.\n        \"\"\"\n        all_cls_scores = preds_dicts['all_cls_scores'][-1]\n        all_bbox_preds = preds_dicts['all_bbox_preds'][-1]\n        \n        batch_size = all_cls_scores.size()[0]\n        predictions_list = []\n        for i in range(batch_size):\n            predictions_list.append(self.decode_single(all_cls_scores[i], all_bbox_preds[i]))\n\n        return predictions_list\n"
  },
  {
    "path": "models/bbox/match_costs/__init__.py",
    "content": "from .match_cost import BBox3DL1Cost\n\n__all__ = ['BBox3DL1Cost']"
  },
  {
    "path": "models/bbox/match_costs/match_cost.py",
    "content": "import torch\nfrom mmdet.core.bbox.match_costs.builder import MATCH_COST\n\n\n@MATCH_COST.register_module()\nclass BBox3DL1Cost(object):\n    \"\"\"BBox3DL1Cost.\n     Args:\n         weight (int | float, optional): loss_weight\n    \"\"\"\n\n    def __init__(self, weight=1.0):\n        self.weight = weight\n\n    def __call__(self, bbox_pred, gt_bboxes):\n        \"\"\"\n        Args:\n            bbox_pred (Tensor): Predicted boxes with normalized coordinates\n                (cx, cy, w, h), which are all in range [0, 1]. Shape\n                [num_query, 4].\n            gt_bboxes (Tensor): Ground truth boxes with normalized\n                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].\n        Returns:\n            torch.Tensor: bbox_cost value with weight\n        \"\"\"\n        bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)\n        return bbox_cost * self.weight\n\n\n@MATCH_COST.register_module()\nclass BBoxBEVL1Cost(object):\n    def __init__(self, weight, pc_range):\n        self.weight = weight\n        self.pc_range = pc_range\n\n    def __call__(self, bboxes, gt_bboxes):\n        pc_start = bboxes.new(self.pc_range[0:2])\n        pc_range = bboxes.new(self.pc_range[3:5]) - bboxes.new(self.pc_range[0:2])\n        # normalize the box center to [0, 1]\n        normalized_bboxes_xy = (bboxes[:, :2] - pc_start) / pc_range\n        normalized_gt_bboxes_xy = (gt_bboxes[:, :2] - pc_start) / pc_range\n        reg_cost = torch.cdist(normalized_bboxes_xy, normalized_gt_bboxes_xy, p=1)\n        return reg_cost * self.weight\n\n\n@MATCH_COST.register_module()\nclass IoU3DCost(object):\n    def __init__(self, weight):\n        self.weight = weight\n\n    def __call__(self, iou):\n        iou_cost = - iou\n        return iou_cost * self.weight\n"
  },
  {
    "path": "models/bbox/utils.py",
    "content": "import torch \n\n\ndef normalize_bbox(bboxes):\n    cx = bboxes[..., 0:1]\n    cy = bboxes[..., 1:2]\n    cz = bboxes[..., 2:3]\n    w = bboxes[..., 3:4].log()\n    l = bboxes[..., 4:5].log()\n    h = bboxes[..., 5:6].log()\n    rot = bboxes[..., 6:7]\n\n    if bboxes.size(-1) > 7:\n        vx = bboxes[..., 7:8]\n        vy = bboxes[..., 8:9]\n        out = torch.cat([cx, cy, w, l, cz, h, rot.sin(), rot.cos(), vx, vy], dim=-1)\n    else:\n        out = torch.cat([cx, cy, w, l, cz, h, rot.sin(), rot.cos()], dim=-1)\n\n    return out\n\n\ndef denormalize_bbox(normalized_bboxes):\n    rot_sin = normalized_bboxes[..., 6:7]\n    rot_cos = normalized_bboxes[..., 7:8]\n    rot = torch.atan2(rot_sin, rot_cos)\n\n    cx = normalized_bboxes[..., 0:1]\n    cy = normalized_bboxes[..., 1:2]\n    cz = normalized_bboxes[..., 4:5]\n\n    w = normalized_bboxes[..., 2:3].exp()\n    l = normalized_bboxes[..., 3:4].exp()\n    h = normalized_bboxes[..., 5:6].exp()\n\n    if normalized_bboxes.size(-1) > 8:\n        vx = normalized_bboxes[..., 8:9]\n        vy = normalized_bboxes[..., 9:10]\n        out = torch.cat([cx, cy, cz, w, l, h, rot, vx, vy], dim=-1)\n    else:\n        out = torch.cat([cx, cy, cz, w, l, h, rot], dim=-1)\n\n    return out\n\n\ndef encode_bbox(bboxes, pc_range=None):\n    xyz = bboxes[..., 0:3].clone()\n    wlh = bboxes[..., 3:6].log()\n    rot = bboxes[..., 6:7]\n\n    if pc_range is not None:\n        xyz[..., 0] = (xyz[..., 0] - pc_range[0]) / (pc_range[3] - pc_range[0])\n        xyz[..., 1] = (xyz[..., 1] - pc_range[1]) / (pc_range[4] - pc_range[1])\n        xyz[..., 2] = (xyz[..., 2] - pc_range[2]) / (pc_range[5] - pc_range[2])\n\n    if bboxes.shape[-1] > 7:\n        vel = bboxes[..., 7:9].clone()\n        return torch.cat([xyz, wlh, rot.sin(), rot.cos(), vel], dim=-1)\n    else:\n        return torch.cat([xyz, wlh, rot.sin(), rot.cos()], dim=-1)\n\n\ndef decode_bbox(bboxes, pc_range=None):\n    xyz = bboxes[..., 0:3].clone()\n    wlh = 
bboxes[..., 3:6].exp()\n    rot = torch.atan2(bboxes[..., 6:7], bboxes[..., 7:8])\n\n    if pc_range is not None:\n        xyz[..., 0] = xyz[..., 0] * (pc_range[3] - pc_range[0]) + pc_range[0]\n        xyz[..., 1] = xyz[..., 1] * (pc_range[4] - pc_range[1]) + pc_range[1]\n        xyz[..., 2] = xyz[..., 2] * (pc_range[5] - pc_range[2]) + pc_range[2]\n\n    if bboxes.shape[-1] > 8:\n        vel = bboxes[..., 8:10].clone()\n        return torch.cat([xyz, wlh, rot, vel], dim=-1)\n    else:\n        return torch.cat([xyz, wlh, rot], dim=-1)\n\ndef bbox2occrange(bboxes, occ_size, query_cube_size=None):\n    \"\"\"\n    xyz in [0, 1]\n    wlh in [0, 1]\n    \"\"\"\n    xyz = bboxes[..., 0:3].clone()\n    if query_cube_size is not None:\n        wlh = torch.zeros_like(xyz)\n        wlh[..., 0] = query_cube_size[0]\n        wlh[..., 1] = query_cube_size[1]\n        wlh[..., 2] = query_cube_size[2]\n    else:\n        wlh = bboxes[..., 3:6]\n        wlh[..., 0] = wlh[..., 0] * occ_size[0]\n        wlh[..., 1] = wlh[..., 1] * occ_size[1]\n        wlh[..., 2] = wlh[..., 2] * occ_size[2]\n    \n    xyz[..., 0] = xyz[..., 0] * occ_size[0]\n    xyz[..., 1] = xyz[..., 1] * occ_size[1]\n    xyz[..., 2] = xyz[..., 2] * occ_size[2]\n    \n    xyz = torch.round(xyz)\n        \n    low_bound = torch.round(xyz - wlh/2)\n    high_bound = torch.round(xyz + wlh/2)\n    \n    return torch.cat((low_bound, high_bound), dim=-1).long()\n\ndef occrange2bbox(occ_range, occ_size, pc_range):\n    \"\"\"\n    Return: xyz in [0, 1], wlh in [0, pc_range_size)\n    \"\"\"\n    xyz = (occ_range[..., :3] + occ_range[..., 3:]).to(torch.float32) / 2\n    xyz[..., 0] /= occ_size[0]\n    xyz[..., 1] /= occ_size[1]\n    xyz[..., 2] /= occ_size[2]\n    wlh = (occ_range[..., 3:] - occ_range[..., :3]).to(torch.float32)\n    wlh[..., 0] *= (pc_range[3] - pc_range[0]) / occ_size[0]\n    wlh[..., 1] *= (pc_range[4] - pc_range[1]) / occ_size[1]\n    wlh[..., 2] *= (pc_range[5] - pc_range[2]) / occ_size[2]\n    
return torch.cat((xyz, wlh), dim=-1)"
  },
  {
    "path": "models/checkpoint.py",
    "content": "# This page is completely copied from https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint\n# If you are using torch 2.0 or higher, you can safely delete this page and import the related functions from official PyTorch\n\nimport torch\nimport warnings\nimport weakref\nfrom typing import Any, Iterable, List, Tuple\n\n__all__ = [\n    \"checkpoint\", \"checkpoint_sequential\", \"CheckpointFunction\",\n    \"check_backward_validity\", \"detach_variable\", \"get_device_states\",\n    \"set_device_states\",\n]\n\ndef detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:\n    if isinstance(inputs, tuple):\n        out = []\n        for inp in inputs:\n            if not isinstance(inp, torch.Tensor):\n                out.append(inp)\n                continue\n\n            x = inp.detach()\n            x.requires_grad = inp.requires_grad\n            out.append(x)\n        return tuple(out)\n    else:\n        raise RuntimeError(\n            \"Only tuple of tensors is supported. Got Unsupported input type: \", type(inputs).__name__)\n\n\ndef check_backward_validity(inputs: Iterable[Any]) -> None:\n    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):\n        warnings.warn(\"None of the inputs have requires_grad=True. Gradients will be None\")\n\n\n# We can't know if the run_fn will internally move some args to different devices,\n# which would require logic to preserve rng states for those devices as well.\n# We could paranoically stash and restore ALL the rng states for all visible devices,\n# but that seems very wasteful for most cases.  
Compromise:  Stash the RNG state for\n# the device of all Tensor args.\n#\n# To consider:  maybe get_device_states and set_device_states should reside in torch/random.py?\ndef get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:\n    # This will not error out if \"arg\" is a CPU tensor or a non-tensor type because\n    # the conditionals short-circuit.\n    fwd_gpu_devices = list({arg.get_device() for arg in args\n                            if isinstance(arg, torch.Tensor) and arg.is_cuda})\n\n    fwd_gpu_states = []\n    for device in fwd_gpu_devices:\n        with torch.cuda.device(device):\n            fwd_gpu_states.append(torch.cuda.get_rng_state())\n\n    return fwd_gpu_devices, fwd_gpu_states\n\n\ndef set_device_states(devices, states) -> None:\n    for device, state in zip(devices, states):\n        with torch.cuda.device(device):\n            torch.cuda.set_rng_state(state)\n\ndef _get_autocast_kwargs():\n    gpu_autocast_kwargs = {\"enabled\": torch.is_autocast_enabled(),\n                           \"dtype\": torch.get_autocast_gpu_dtype(),\n                           \"cache_enabled\": torch.is_autocast_cache_enabled()}\n\n    cpu_autocast_kwargs = {\"enabled\": torch.is_autocast_cpu_enabled(),\n                           \"dtype\": torch.get_autocast_cpu_dtype(),\n                           \"cache_enabled\": torch.is_autocast_cache_enabled()}\n\n    return gpu_autocast_kwargs, cpu_autocast_kwargs\n\nclass CheckpointFunction(torch.autograd.Function):\n\n    @staticmethod\n    def forward(ctx, run_function, preserve_rng_state, *args):\n        check_backward_validity(args)\n        ctx.run_function = run_function\n        ctx.preserve_rng_state = preserve_rng_state\n        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.\n        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()\n        if preserve_rng_state:\n            ctx.fwd_cpu_state = torch.get_rng_state()\n            # 
Don't eagerly initialize the cuda context by accident.\n            # (If the user intends that the context is initialized later, within their\n            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,\n            # we have no way to anticipate this will happen before we run the function.)\n            ctx.had_cuda_in_fwd = False\n            if torch.cuda._initialized:\n                ctx.had_cuda_in_fwd = True\n                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)\n\n        # Save non-tensor inputs in ctx, keep a placeholder None for tensors\n        # to be filled out during the backward.\n        ctx.inputs = []\n        ctx.tensor_indices = []\n        tensor_inputs = []\n        for i, arg in enumerate(args):\n            if torch.is_tensor(arg):\n                tensor_inputs.append(arg)\n                ctx.tensor_indices.append(i)\n                ctx.inputs.append(None)\n            else:\n                ctx.inputs.append(arg)\n\n        ctx.save_for_backward(*tensor_inputs)\n\n        with torch.no_grad():\n            outputs = run_function(*args)\n        return outputs\n\n    @staticmethod\n    def backward(ctx, *args):\n        if not torch.autograd._is_checkpoint_valid():\n            raise RuntimeError(\n                \"Checkpointing is not compatible with .grad() or when an `inputs` parameter\"\n                \" is passed to .backward(). Please use .backward() and do not pass its `inputs`\"\n                \" argument.\")\n        # Copy the list to avoid modifying original list.\n        inputs = list(ctx.inputs)\n        tensor_indices = ctx.tensor_indices\n        tensors = ctx.saved_tensors\n\n        # Fill in inputs with appropriate saved tensors.\n        for i, idx in enumerate(tensor_indices):\n            inputs[idx] = tensors[i]\n\n        # Stash the surrounding rng state, and mimic the state that was\n        # present at this time during forward.  
Restore the surrounding state\n        # when we're done.\n        rng_devices = []\n        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:\n            rng_devices = ctx.fwd_gpu_devices\n        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):\n            if ctx.preserve_rng_state:\n                torch.set_rng_state(ctx.fwd_cpu_state)\n                if ctx.had_cuda_in_fwd:\n                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)\n            detached_inputs = detach_variable(tuple(inputs))\n            with torch.enable_grad(), \\\n                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \\\n                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):\n                outputs = ctx.run_function(*detached_inputs)\n\n        if isinstance(outputs, torch.Tensor):\n            outputs = (outputs,)\n\n        # run backward() with only tensor that requires grad\n        outputs_with_grad = []\n        args_with_grad = []\n        for i in range(len(outputs)):\n            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:\n                outputs_with_grad.append(outputs[i])\n                args_with_grad.append(args[i])\n        if len(outputs_with_grad) == 0:\n            raise RuntimeError(\n                \"none of output has requires_grad=True,\"\n                \" this checkpoint() is not necessary\")\n        torch.autograd.backward(outputs_with_grad, args_with_grad)\n        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None\n                      for inp in detached_inputs)\n\n        return (None, None) + grads\n\n\ndef checkpoint(function, *args, use_reentrant: bool = True, **kwargs):\n    r\"\"\"Checkpoint a model or part of the model\n\n    Checkpointing works by trading compute for memory. 
Rather than storing all\n    intermediate activations of the entire computation graph for computing\n    backward, the checkpointed part does **not** save intermediate activations,\n    and instead recomputes them in backward pass. It can be applied on any part\n    of a model.\n\n    Specifically, in the forward pass, :attr:`function` will run in\n    :func:`torch.no_grad` manner, i.e., not storing the intermediate\n    activations. Instead, the forward pass saves the inputs tuple and the\n    :attr:`function` parameter. In the backwards pass, the saved inputs and\n    :attr:`function` is retrieved, and the forward pass is computed on\n    :attr:`function` again, now tracking the intermediate activations, and then\n    the gradients are calculated using these activation values.\n\n    The output of :attr:`function` can contain non-Tensor values and gradient\n    recording is only performed for the Tensor values. Note that if the output\n    consists of nested structures (ex: custom objects, lists, dicts etc.)\n    consisting of Tensors, these Tensors nested in custom structures will not\n    be considered as part of autograd.\n\n\n    .. warning::\n        If :attr:`function` invocation during backward does anything different\n        than the one during forward, e.g., due to some global variable, the\n        checkpointed version won't be equivalent, and unfortunately it can't be\n        detected.\n\n    .. warning::\n        If ``use_reentrant=True`` is specified, then if the checkpointed segment\n        contains tensors detached from the computational graph by `detach()` or\n        `torch.no_grad()`, the backward pass will raise an error. This is\n        because `checkpoint` makes all the outputs require gradients which\n        causes issues when a tensor is defined to have no gradient in the model.\n        To circumvent this, detach the tensors outside of the `checkpoint`\n        function. 
Note that the checkpointed segment can contain tensors\n        detached from the computational graph if ``use_reentrant=False`` is\n        specified.\n\n    .. warning::\n        If ``use_reentrant=True`` is specified, at least one of the inputs needs\n        to have :code:`requires_grad=True` if grads are needed for model inputs,\n        otherwise the checkpointed part of the model won't have gradients. At\n        least one of the outputs needs to have :code:`requires_grad=True` as\n        well. Note that this does not apply if ``use_reentrant=False`` is\n        specified.\n\n    .. warning::\n        If ``use_reentrant=True`` is specified, checkpointing currently only\n        supports :func:`torch.autograd.backward` and only if its `inputs`\n        argument is not passed. :func:`torch.autograd.grad`\n        is not supported. If ``use_reentrant=False`` is specified, checkpointing\n        will work with :func:`torch.autograd.grad`.\n\n    Args:\n        function: describes what to run in the forward pass of the model or\n            part of the model. It should also know how to handle the inputs\n            passed as the tuple. For example, in LSTM, if user passes\n            ``(activation, hidden)``, :attr:`function` should correctly use the\n            first input as ``activation`` and the second input as ``hidden``\n        preserve_rng_state(bool, optional):  Omit stashing and restoring\n            the RNG state during each checkpoint.\n            Default: ``True``\n        use_reentrant(bool, optional): Use checkpointing\n            implementation that requires re-entrant autograd.\n            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an\n            implementation that does not require re-entrant autograd. 
This\n            allows ``checkpoint`` to support additional functionality, such as\n            working as expected with ``torch.autograd.grad`` and support for\n            keyword arguments input into the checkpointed function. Note that future\n            versions of PyTorch will default to ``use_reentrant=False``.\n            Default: ``True``\n        args: tuple containing inputs to the :attr:`function`\n\n    Returns:\n        Output of running :attr:`function` on :attr:`*args`\n    \"\"\"\n    # Hack to mix *args with **kwargs in a python 2.7-compliant way\n    preserve = kwargs.pop('preserve_rng_state', True)\n    if kwargs and use_reentrant:\n        raise ValueError(\"Unexpected keyword arguments: \" + \",\".join(arg for arg in kwargs))\n\n    if use_reentrant:\n        return CheckpointFunction.apply(function, preserve, *args)\n    else:\n        return _checkpoint_without_reentrant(\n            function,\n            preserve,\n            *args,\n            **kwargs,\n        )\n\n\ndef checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs):\n    r\"\"\"A helper function for checkpointing sequential models.\n\n    Sequential models execute a list of modules/functions in order\n    (sequentially). Therefore, we can divide such a model in various segments\n    and checkpoint each segment. All segments except the last will run in\n    :func:`torch.no_grad` manner, i.e., not storing the intermediate\n    activations. The inputs of each checkpointed segment will be saved for\n    re-running the segment in the backward pass.\n\n    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.\n\n    .. warning::\n        Checkpointing currently only supports :func:`torch.autograd.backward`\n        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`\n        is not supported.\n\n    .. 
warning:\n        At least one of the inputs needs to have :code:`requires_grad=True` if\n        grads are needed for model inputs, otherwise the checkpointed part of the\n        model won't have gradients.\n\n    .. warning:\n        Since PyTorch 1.4, it allows only one Tensor as the input and\n        intermediate outputs, just like :class:`torch.nn.Sequential`.\n\n    Args:\n        functions: A :class:`torch.nn.Sequential` or the list of modules or\n            functions (comprising the model) to run sequentially.\n        segments: Number of chunks to create in the model\n        input: A Tensor that is input to :attr:`functions`\n        preserve_rng_state(bool, optional):  Omit stashing and restoring\n            the RNG state during each checkpoint.\n            Default: ``True``\n        use_reentrant(bool, optional): Use checkpointing\n            implementation that requires re-entrant autograd.\n            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an\n            implementation that does not require re-entrant autograd. 
This\n            allows ``checkpoint`` to support additional functionality, such as\n            working as expected with ``torch.autograd.grad`` and support for\n            keyword arguments input into the checkpointed function.\n            Default: ``True``\n\n    Returns:\n        Output of running :attr:`functions` sequentially on :attr:`*inputs`\n\n    Example:\n        >>> # xdoctest: +SKIP(\"stub\")\n        >>> model = nn.Sequential(...)\n        >>> input_var = checkpoint_sequential(model, chunks, input_var)\n    \"\"\"\n    # Hack for keyword-only parameter in a python 2.7-compliant way\n    preserve = kwargs.pop('preserve_rng_state', True)\n    if kwargs:\n        raise ValueError(\"Unexpected keyword arguments: \" + \",\".join(arg for arg in kwargs))\n\n    def run_function(start, end, functions):\n        def forward(input):\n            for j in range(start, end + 1):\n                input = functions[j](input)\n            return input\n        return forward\n\n    if isinstance(functions, torch.nn.Sequential):\n        functions = list(functions.children())\n\n    segment_size = len(functions) // segments\n    # the last chunk has to be non-volatile\n    end = -1\n    for start in range(0, segment_size * (segments - 1), segment_size):\n        end = start + segment_size - 1\n        input = checkpoint(\n            run_function(start, end, functions),\n            input,\n            use_reentrant=use_reentrant,\n            preserve_rng_state=preserve\n        )\n    return run_function(end + 1, len(functions) - 1, functions)(input)\n\n\ndef _checkpoint_without_reentrant(function, preserve_rng_state=True, *args, **kwargs):\n    \"\"\"Checkpointining without re-entrant autograd\n    Args:\n        function: describes what to run in the forward pass of the model or\n            part of the model. It should also know how to handle the inputs\n            passed as the tuple. 
For example, in LSTM, if user passes\n            ``(activation, hidden)``, :attr:`function` should correctly use the\n            first input as ``activation`` and the second input as ``hidden``\n        preserve_rng_state(bool, optional):  Omit stashing and restoring\n            the RNG state during each checkpoint.\n            Default: ``True``\n        *args: Arguments to pass in to the given ``function``.\n        **kwargs: Keyword arguments to pass into the given ``function``.\n    \"\"\"\n    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.\n    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()\n\n    if preserve_rng_state:\n        fwd_cpu_state = torch.get_rng_state()\n        # Don't eagerly initialize the cuda context by accident.\n        # (If the user intends that the context is initialized later, within their\n        # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,\n        # we have no way to anticipate this will happen before we run the function.\n        # If they do so, we raise an error.)\n        had_cuda_in_fwd = False\n        if torch.cuda._initialized:\n            had_cuda_in_fwd = True\n            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)\n\n    # Custom class to be able to take weak references\n    class Holder():\n        pass\n    # The Holder object for each of the saved object is saved directly on the\n    # SavedVariable and is cleared when reset_data() is called on it. We MUST make\n    # sure that this is the only object having an owning reference to ensure that\n    # the Tensor stored in storage is deleted as soon as the corresponding SavedVariable\n    # data is cleared.\n    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()\n    weak_holder_list = []\n\n    def pack(x):\n        # TODO(varal7): Instead of returning abstract object, we can return things metadata (such as\n        # size, device, ...) 
to catch certain cases of undeterministic behavior of the forward\n        res = Holder()\n        weak_holder_list.append(weakref.ref(res))\n        return res\n\n\n    def unpack(x):\n        unpack_counter = 0\n        if len(storage) == 0:\n            def inner_pack(inner):\n                nonlocal unpack_counter\n                unpack_counter += 1\n                # If the holder went out of scope, the SavedVariable is dead and so\n                # the value will never be read from the storage. Skip filling it.\n                if weak_holder_list[unpack_counter - 1]() is None:\n                    return\n                # Use detach here to ensure we don't keep the temporary autograd\n                # graph created during the second forward\n                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()\n                return\n\n            def inner_unpack(packed):\n                raise RuntimeError(\"You are calling backwards on a tensor that is never exposed. Please open an issue.\")\n\n            # Stash the surrounding rng state, and mimic the state that was\n            # present at this time during forward.  
Restore the surrounding state\n            # when we're done.\n            rng_devices = []\n            if preserve_rng_state and had_cuda_in_fwd:\n                rng_devices = fwd_gpu_devices\n            with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):\n                if preserve_rng_state:\n                    torch.set_rng_state(fwd_cpu_state)\n                    if had_cuda_in_fwd:\n                        set_device_states(fwd_gpu_devices, fwd_gpu_states)\n\n                with torch.enable_grad(), \\\n                     torch.cuda.amp.autocast(**gpu_autocast_kwargs), \\\n                     torch.cpu.amp.autocast(**cpu_autocast_kwargs), \\\n                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):\n                    _unused = function(*args, **kwargs)\n\n        if x not in storage:\n            raise RuntimeError(\n                \"Attempt to retrieve a tensor saved by autograd multiple times without checkpoint\"\n                \" recomputation being triggered in between, this is not currently supported. Please\"\n                \" open an issue with details on your use case so that we can prioritize adding this.\"\n            )\n\n        return storage[x]\n\n    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):\n        output = function(*args, **kwargs)\n        if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:\n            # Cuda was not initialized before running the forward, so we didn't\n            # stash the CUDA state.\n            raise RuntimeError(\n                \"PyTorch's CUDA state was initialized in the forward pass \"\n                \"of a Checkpoint, which is not allowed. Please open an issue \"\n                \"if you need this feature.\")\n\n    return output"
  },
  {
    "path": "models/csrc/__init__.py",
    "content": ""
  },
  {
    "path": "models/csrc/msmv_sampling/msmv_sampling.cpp",
    "content": "#include \"msmv_sampling.h\"\n\n#define MAX_POINT 32\n\nvoid ms_deformable_im2col_cuda_c2345(\n    const float* feat_c2,\n    const float* feat_c3,\n    const float* feat_c4,\n    const float* feat_c5,\n    const int h_c2, const int w_c2,\n    const int h_c3, const int w_c3,\n    const int h_c4, const int w_c4,\n    const int h_c5, const int w_c5,\n    const float* data_sampling_loc,\n    const float* data_attn_weight,\n    const int batch_size,\n    const int channels,\n    const int num_views,\n    const int num_query,\n    const int num_point,\n    float* data_col\n);\n\nvoid ms_deformable_im2col_cuda_c23456(\n    const float* feat_c2,\n    const float* feat_c3,\n    const float* feat_c4,\n    const float* feat_c5,\n    const float* feat_c6,\n    const int h_c2, const int w_c2,\n    const int h_c3, const int w_c3,\n    const int h_c4, const int w_c4,\n    const int h_c5, const int w_c5,\n    const int h_c6, const int w_c6,\n    const float* data_sampling_loc,\n    const float* data_attn_weight,\n    const int batch_size,\n    const int channels,\n    const int num_views,\n    const int num_query,\n    const int num_point,\n    float* data_col\n);\n\nvoid ms_deformable_col2im_cuda_c2345(\n    const float* grad_col,\n    const float* feat_c2,\n    const float* feat_c3,\n    const float* feat_c4,\n    const float* feat_c5,\n    const int h_c2, const int w_c2,\n    const int h_c3, const int w_c3,\n    const int h_c4, const int w_c4,\n    const int h_c5, const int w_c5,\n    const float* data_sampling_loc,\n    const float* data_attn_weight,\n    const int batch_size,\n    const int channels,\n    const int num_views,\n    const int num_query,\n    const int num_point,\n    float* grad_value_c2,\n    float* grad_value_c3,\n    float* grad_value_c4,\n    float* grad_value_c5,\n    float* grad_sampling_loc,\n    float* grad_attn_weight\n);\n\nvoid ms_deformable_col2im_cuda_c23456(\n    const float *grad_col,\n    const float *feat_c2,\n    const float 
*feat_c3,\n    const float *feat_c4,\n    const float *feat_c5,\n    const float *feat_c6,\n    const int h_c2, const int w_c2,\n    const int h_c3, const int w_c3,\n    const int h_c4, const int w_c4,\n    const int h_c5, const int w_c5,\n    const int h_c6, const int w_c6,\n    const float *data_sampling_loc,\n    const float *data_attn_weight,\n    const int batch_size,\n    const int channels,\n    const int num_views,\n    const int num_query,\n    const int num_point,\n    float *grad_value_c2,\n    float *grad_value_c3,\n    float *grad_value_c4,\n    float *grad_value_c5,\n    float *grad_value_c6,\n    float *grad_sampling_loc,\n    float *grad_attn_weight\n);\n\nat::Tensor ms_deform_attn_cuda_c2345_forward(\n    const at::Tensor& feat_c2,  // [B, N, H, W, C]\n    const at::Tensor& feat_c3,  // [B, N, H, W, C]\n    const at::Tensor& feat_c4,  // [B, N, H, W, C]\n    const at::Tensor& feat_c5,  // [B, N, H, W, C]\n    const at::Tensor& sampling_loc,  // [B, Q, P, 3]\n    const at::Tensor& attn_weight  // [B, Q, P, 4]\n    ) {\n    AT_ASSERTM(feat_c2.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(feat_c3.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(feat_c4.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(feat_c5.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(sampling_loc.is_contiguous(), \"sampling_loc tensor has to be contiguous\");\n    AT_ASSERTM(attn_weight.is_contiguous(), \"attn_weight tensor has to be contiguous\");\n\n    AT_ASSERTM(feat_c2.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(feat_c3.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(feat_c4.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(feat_c5.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(sampling_loc.is_cuda(), \"sampling_loc must be a CUDA tensor\");\n    AT_ASSERTM(attn_weight.is_cuda(), \"attn_weight must be a CUDA tensor\");\n\n 
   const int batch_size = feat_c2.size(0);\n    const int num_views = feat_c2.size(1);\n    const int channels = feat_c2.size(4);\n    const int num_query = sampling_loc.size(1);\n    const int num_point = sampling_loc.size(2);\n    AT_ASSERTM(num_point <= MAX_POINT, \"num_point exceed limits\");\n\n    const int h_c2 = feat_c2.size(2);\n    const int w_c2 = feat_c2.size(3);\n    const int h_c3 = feat_c3.size(2);\n    const int w_c3 = feat_c3.size(3);\n    const int h_c4 = feat_c4.size(2);\n    const int w_c4 = feat_c4.size(3);\n    const int h_c5 = feat_c5.size(2);\n    const int w_c5 = feat_c5.size(3);\n\n    auto output = at::zeros({ batch_size, num_query, channels, num_point }, feat_c2.options());\n    ms_deformable_im2col_cuda_c2345(\n        feat_c2.data_ptr<float>(),\n        feat_c3.data_ptr<float>(),\n        feat_c4.data_ptr<float>(),\n        feat_c5.data_ptr<float>(),\n        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,\n        sampling_loc.data_ptr<float>(),\n        attn_weight.data_ptr<float>(),\n        batch_size, channels, num_views, num_query, num_point,\n        output.data_ptr<float>()\n    );\n\n    return output;\n}\n\nat::Tensor ms_deform_attn_cuda_c23456_forward(\n    const at::Tensor& feat_c2,  // [B, N, H, W, C]\n    const at::Tensor& feat_c3,  // [B, N, H, W, C]\n    const at::Tensor& feat_c4,  // [B, N, H, W, C]\n    const at::Tensor& feat_c5,  // [B, N, H, W, C]\n    const at::Tensor& feat_c6,  // [B, N, H, W, C]\n    const at::Tensor& sampling_loc,  // [B, Q, P, 3]\n    const at::Tensor& attn_weight  // [B, Q, P, 4]\n    ) {\n    AT_ASSERTM(feat_c2.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(feat_c3.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(feat_c4.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(feat_c5.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(feat_c6.is_contiguous(), \"value tensor has to be contiguous\");\n    
AT_ASSERTM(sampling_loc.is_contiguous(), \"sampling_loc tensor has to be contiguous\");\n    AT_ASSERTM(attn_weight.is_contiguous(), \"attn_weight tensor has to be contiguous\");\n\n    AT_ASSERTM(feat_c2.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(feat_c3.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(feat_c4.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(feat_c5.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(feat_c6.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(sampling_loc.is_cuda(), \"sampling_loc must be a CUDA tensor\");\n    AT_ASSERTM(attn_weight.is_cuda(), \"attn_weight must be a CUDA tensor\");\n\n    const int batch_size = feat_c2.size(0);\n    const int num_views = feat_c2.size(1);\n    const int channels = feat_c2.size(4);\n    const int num_query = sampling_loc.size(1);\n    const int num_point = sampling_loc.size(2);\n    AT_ASSERTM(num_point <= MAX_POINT, \"num_point exceed limits\");\n\n    const int h_c2 = feat_c2.size(2);\n    const int w_c2 = feat_c2.size(3);\n    const int h_c3 = feat_c3.size(2);\n    const int w_c3 = feat_c3.size(3);\n    const int h_c4 = feat_c4.size(2);\n    const int w_c4 = feat_c4.size(3);\n    const int h_c5 = feat_c5.size(2);\n    const int w_c5 = feat_c5.size(3);\n    const int h_c6 = feat_c6.size(2);\n    const int w_c6 = feat_c6.size(3);\n\n    auto output = at::zeros({ batch_size, num_query, channels, num_point }, feat_c2.options());\n    ms_deformable_im2col_cuda_c23456(\n        feat_c2.data_ptr<float>(),\n        feat_c3.data_ptr<float>(),\n        feat_c4.data_ptr<float>(),\n        feat_c5.data_ptr<float>(),\n        feat_c6.data_ptr<float>(),\n        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,\n        sampling_loc.data_ptr<float>(),\n        attn_weight.data_ptr<float>(),\n        batch_size, channels, num_views, num_query, num_point,\n        output.data_ptr<float>()\n    );\n\n    return 
output;\n}\n\nstd::vector<at::Tensor> ms_deform_attn_cuda_c2345_backward(\n    const at::Tensor& grad_output,\n    const at::Tensor& feat_c2,  // [B, N, H, W, C]\n    const at::Tensor& feat_c3,  // [B, N, H, W, C]\n    const at::Tensor& feat_c4,  // [B, N, H, W, C]\n    const at::Tensor& feat_c5,  // [B, N, H, W, C]\n    const at::Tensor& sampling_loc,  // [B, Q, P, 3]\n    const at::Tensor& attn_weight  // [B, Q, P, 4]\n    ) {\n    AT_ASSERTM(feat_c2.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(feat_c3.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(feat_c4.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(feat_c5.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(sampling_loc.is_contiguous(), \"sampling_loc tensor has to be contiguous\");\n    AT_ASSERTM(attn_weight.is_contiguous(), \"attn_weight tensor has to be contiguous\");\n    AT_ASSERTM(grad_output.is_contiguous(), \"grad_output tensor has to be contiguous\");\n\n    AT_ASSERTM(feat_c2.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(feat_c3.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(feat_c4.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(feat_c5.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(sampling_loc.is_cuda(), \"sampling_loc must be a CUDA tensor\");\n    AT_ASSERTM(attn_weight.is_cuda(), \"attn_weight must be a CUDA tensor\");\n    AT_ASSERTM(grad_output.is_cuda(), \"grad_output must be a CUDA tensor\");\n\n    const int batch_size = feat_c2.size(0);\n    const int num_views = feat_c2.size(1);\n    const int channels = feat_c2.size(4);\n    const int num_query = sampling_loc.size(1);\n    const int num_point = sampling_loc.size(2);\n    AT_ASSERTM(num_point <= MAX_POINT, \"num_point exceed limits\");\n\n    auto grad_value_c2 = at::zeros_like(feat_c2);\n    auto grad_value_c3 = at::zeros_like(feat_c3);\n    auto grad_value_c4 = 
at::zeros_like(feat_c4);\n    auto grad_value_c5 = at::zeros_like(feat_c5);\n    auto grad_sampling_loc = at::zeros_like(sampling_loc);\n    auto grad_attn_weight = at::zeros_like(attn_weight);\n\n    const int h_c2 = feat_c2.size(2);\n    const int w_c2 = feat_c2.size(3);\n    const int h_c3 = feat_c3.size(2);\n    const int w_c3 = feat_c3.size(3);\n    const int h_c4 = feat_c4.size(2);\n    const int w_c4 = feat_c4.size(3);\n    const int h_c5 = feat_c5.size(2);\n    const int w_c5 = feat_c5.size(3);\n\n    ms_deformable_col2im_cuda_c2345(\n        grad_output.data_ptr<float>(),\n        feat_c2.data_ptr<float>(),\n        feat_c3.data_ptr<float>(),\n        feat_c4.data_ptr<float>(),\n        feat_c5.data_ptr<float>(),\n        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,\n        sampling_loc.data_ptr<float>(),\n        attn_weight.data_ptr<float>(),\n        batch_size, channels, num_views, num_query, num_point,\n        grad_value_c2.data_ptr<float>(),\n        grad_value_c3.data_ptr<float>(),\n        grad_value_c4.data_ptr<float>(),\n        grad_value_c5.data_ptr<float>(),\n        grad_sampling_loc.data_ptr<float>(),\n        grad_attn_weight.data_ptr<float>()\n    );\n\n    return {\n        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight\n    };\n}\n\nstd::vector<at::Tensor> ms_deform_attn_cuda_c23456_backward(\n    const at::Tensor& grad_output,\n    const at::Tensor& feat_c2,  // [B, N, H, W, C]\n    const at::Tensor& feat_c3,  // [B, N, H, W, C]\n    const at::Tensor& feat_c4,  // [B, N, H, W, C]\n    const at::Tensor& feat_c5,  // [B, N, H, W, C]\n    const at::Tensor& feat_c6,  // [B, N, H, W, C]\n    const at::Tensor& sampling_loc,  // [B, Q, P, 3]\n    const at::Tensor& attn_weight  // [B, Q, P, 4]\n    ) {\n    AT_ASSERTM(feat_c2.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(feat_c3.is_contiguous(), \"value tensor has to be contiguous\");\n    
AT_ASSERTM(feat_c4.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(feat_c5.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(feat_c6.is_contiguous(), \"value tensor has to be contiguous\");\n    AT_ASSERTM(sampling_loc.is_contiguous(), \"sampling_loc tensor has to be contiguous\");\n    AT_ASSERTM(attn_weight.is_contiguous(), \"attn_weight tensor has to be contiguous\");\n    AT_ASSERTM(grad_output.is_contiguous(), \"grad_output tensor has to be contiguous\");\n\n    AT_ASSERTM(feat_c2.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(feat_c3.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(feat_c4.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(feat_c5.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(feat_c6.is_cuda(), \"value must be a CUDA tensor\");\n    AT_ASSERTM(sampling_loc.is_cuda(), \"sampling_loc must be a CUDA tensor\");\n    AT_ASSERTM(attn_weight.is_cuda(), \"attn_weight must be a CUDA tensor\");\n    AT_ASSERTM(grad_output.is_cuda(), \"grad_output must be a CUDA tensor\");\n\n    const int batch_size = feat_c2.size(0);\n    const int num_views = feat_c2.size(1);\n    const int channels = feat_c2.size(4);\n    const int num_query = sampling_loc.size(1);\n    const int num_point = sampling_loc.size(2);\n    AT_ASSERTM(num_point <= MAX_POINT, \"num_point exceed limits\");\n\n    auto grad_value_c2 = at::zeros_like(feat_c2);\n    auto grad_value_c3 = at::zeros_like(feat_c3);\n    auto grad_value_c4 = at::zeros_like(feat_c4);\n    auto grad_value_c5 = at::zeros_like(feat_c5);\n    auto grad_value_c6 = at::zeros_like(feat_c6);\n    auto grad_sampling_loc = at::zeros_like(sampling_loc);\n    auto grad_attn_weight = at::zeros_like(attn_weight);\n\n    const int h_c2 = feat_c2.size(2);\n    const int w_c2 = feat_c2.size(3);\n    const int h_c3 = feat_c3.size(2);\n    const int w_c3 = feat_c3.size(3);\n    const int h_c4 = feat_c4.size(2);\n    const int w_c4 = 
feat_c4.size(3);\n    const int h_c5 = feat_c5.size(2);\n    const int w_c5 = feat_c5.size(3);\n    const int h_c6 = feat_c6.size(2);\n    const int w_c6 = feat_c6.size(3);\n\n    ms_deformable_col2im_cuda_c23456(\n        grad_output.data_ptr<float>(),\n        feat_c2.data_ptr<float>(),\n        feat_c3.data_ptr<float>(),\n        feat_c4.data_ptr<float>(),\n        feat_c5.data_ptr<float>(),\n        feat_c6.data_ptr<float>(),\n        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,\n        sampling_loc.data_ptr<float>(),\n        attn_weight.data_ptr<float>(),\n        batch_size, channels, num_views, num_query, num_point,\n        grad_value_c2.data_ptr<float>(),\n        grad_value_c3.data_ptr<float>(),\n        grad_value_c4.data_ptr<float>(),\n        grad_value_c5.data_ptr<float>(),\n        grad_value_c6.data_ptr<float>(),\n        grad_sampling_loc.data_ptr<float>(),\n        grad_attn_weight.data_ptr<float>()\n    );\n\n    return {\n        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight\n    };\n}\n\n#ifdef TORCH_EXTENSION_NAME\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"_ms_deform_attn_cuda_c2345_forward\", &ms_deform_attn_cuda_c2345_forward, \"pass\");\n    m.def(\"_ms_deform_attn_cuda_c2345_backward\", &ms_deform_attn_cuda_c2345_backward, \"pass\");\n    m.def(\"_ms_deform_attn_cuda_c23456_forward\", &ms_deform_attn_cuda_c23456_forward, \"pass\");\n    m.def(\"_ms_deform_attn_cuda_c23456_backward\", &ms_deform_attn_cuda_c23456_backward, \"pass\");\n}\n#endif"
  },
  {
    "path": "models/csrc/msmv_sampling/msmv_sampling.h",
    "content": "#pragma once\n\n#include <torch/extension.h>\n\nat::Tensor ms_deform_attn_cuda_c2345_forward(\n    const at::Tensor& feat_c2,  // [B, N, H, W, C]\n    const at::Tensor& feat_c3,  // [B, N, H, W, C]\n    const at::Tensor& feat_c4,  // [B, N, H, W, C]\n    const at::Tensor& feat_c5,  // [B, N, H, W, C]\n    const at::Tensor& sampling_loc,  // [B, Q, P, 3]\n    const at::Tensor& attn_weight  // [B, Q, P, 4]\n);\n\n// NOTE: grad_output comes FIRST, matching the definition in the .cu file.\n// The previous declaration listed grad_output last; because every parameter\n// is `const at::Tensor&`, the mismatch compiled and linked silently but made\n// callers that followed this header pass tensors into the wrong slots.\nstd::vector<at::Tensor> ms_deform_attn_cuda_c2345_backward(\n    const at::Tensor& grad_output,\n    const at::Tensor& feat_c2,  // [B, N, H, W, C]\n    const at::Tensor& feat_c3,  // [B, N, H, W, C]\n    const at::Tensor& feat_c4,  // [B, N, H, W, C]\n    const at::Tensor& feat_c5,  // [B, N, H, W, C]\n    const at::Tensor& sampling_loc,  // [B, Q, P, 3]\n    const at::Tensor& attn_weight  // [B, Q, P, 4]\n);\n\nat::Tensor ms_deform_attn_cuda_c23456_forward(\n    const at::Tensor& feat_c2,  // [B, N, H, W, C]\n    const at::Tensor& feat_c3,  // [B, N, H, W, C]\n    const at::Tensor& feat_c4,  // [B, N, H, W, C]\n    const at::Tensor& feat_c5,  // [B, N, H, W, C]\n    const at::Tensor& feat_c6,  // [B, N, H, W, C]\n    const at::Tensor& sampling_loc,  // [B, Q, P, 3]\n    const at::Tensor& attn_weight  // [B, Q, P, 5] (one weight per level c2..c6)\n);\n\nstd::vector<at::Tensor> ms_deform_attn_cuda_c23456_backward(\n    const at::Tensor& grad_output,\n    const at::Tensor& feat_c2,  // [B, N, H, W, C]\n    const at::Tensor& feat_c3,  // [B, N, H, W, C]\n    const at::Tensor& feat_c4,  // [B, N, H, W, C]\n    const at::Tensor& feat_c5,  // [B, N, H, W, C]\n    const at::Tensor& feat_c6,  // [B, N, H, W, C]\n    const at::Tensor& sampling_loc,  // [B, Q, P, 3]\n    const at::Tensor& attn_weight  // [B, Q, P, 5] (one weight per level c2..c6)\n);"
  },
  {
    "path": "models/csrc/msmv_sampling/msmv_sampling_backward.cu",
    "content": "/*!\n * Modified from Deformable DETR\n */\n\n#include <cstdio>\n#include <iostream>\n#include <algorithm>\n#include <cstring>\n#include <cuda_runtime.h>\n#include <device_launch_parameters.h>\n#include <torch/extension.h>\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <THC/THCAtomics.cuh>\n\n#define CUDA_KERNEL_LOOP(i, n)                          \\\n    for (int i = blockIdx.x * blockDim.x + threadIdx.x; \\\n         i < (n);                                       \\\n         i += blockDim.x * gridDim.x)\n\n#define CUDA_NUM_THREADS 512\n#define MAX_POINT 32\n\ninline int GET_BLOCKS(const int N, const int num_threads)\n{\n    return (N + num_threads - 1) / num_threads;\n}\n\n// Backward of one bilinear sample for channel c at continuous position (h, w):\n// accumulates (via atomicAdd) the gradients w.r.t. the value grid\n// (grad_value), the 2D sampling location (grad_sampling_loc) and the\n// attention weight (grad_attn_weight).\n__device__ void ms_deform_attn_col2im_bilinear(const float *&bottom_data,\n                                               const int &height, const int &width, const int &channels,\n                                               const float &h, const float &w, const int &c,\n                                               const float &top_grad,\n                                               const float &attn_weight,\n                                               const float *&grad_value,\n                                               float *&grad_sampling_loc,\n                                               float *&grad_attn_weight)\n{\n    const int h_low = floor(h);\n    const int w_low = floor(w);\n    const int h_high = h_low + 1;\n    const int w_high = w_low + 1;\n\n    const float lh = h - h_low;\n    const float lw = w - w_low;\n    const float hh = 1 - lh, hw = 1 - lw;\n\n    const int w_stride = channels;\n    const int h_stride = width * w_stride;\n    const int h_low_ptr_offset = h_low * h_stride;\n    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;\n    const int w_low_ptr_offset = w_low * w_stride;\n    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;\n\n    const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n    const float top_grad_value = top_grad * attn_weight;\n    float grad_h_weight = 0, grad_w_weight = 0;\n\n    float *grad_ptr;\n\n    float v1 = 0;\n    if (h_low >= 0 && w_low >= 0)\n    {\n        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + c;\n        grad_ptr = const_cast<float *>(grad_value + ptr1);\n        v1 = bottom_data[ptr1];\n        grad_h_weight -= hw * v1;\n        grad_w_weight -= hh * v1;\n        atomicAdd(grad_ptr, w1 * top_grad_value);\n    }\n    float v2 = 0;\n    if (h_low >= 0 && w_high <= width - 1)\n    {\n        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + c;\n        grad_ptr = const_cast<float *>(grad_value + ptr2);\n        v2 = bottom_data[ptr2];\n        grad_h_weight -= lw * v2;\n        grad_w_weight += hh * v2;\n        atomicAdd(grad_ptr, w2 * top_grad_value);\n    }\n    float v3 = 0;\n    if (h_high <= height - 1 && w_low >= 0)\n    {\n        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + c;\n        grad_ptr = const_cast<float *>(grad_value + ptr3);\n        v3 = bottom_data[ptr3];\n        grad_h_weight += hw * v3;\n        grad_w_weight -= lh * v3;\n        atomicAdd(grad_ptr, w3 * top_grad_value);\n    }\n    float v4 = 0;\n    if (h_high <= height - 1 && w_high <= width - 1)\n    {\n        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + c;\n        grad_ptr = const_cast<float *>(grad_value + ptr4);\n        v4 = bottom_data[ptr4];\n        grad_h_weight += lw * v4;\n        grad_w_weight += lh * v4;\n        atomicAdd(grad_ptr, w4 * top_grad_value);\n    }\n\n    const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n    atomicAdd(grad_attn_weight, top_grad * val);\n    // (size - 1) factor maps the gradient back to the normalized [0, 1]\n    // coordinates used by the caller (align_corners = True sampling).\n    atomicAdd(grad_sampling_loc, (width - 1) * grad_w_weight * top_grad_value);\n    atomicAdd(grad_sampling_loc + 1, (height - 1) * grad_h_weight * top_grad_value);\n}\n\n// global_memory_way: one thread per (batch, query, channel, point); all\n// gradient writes go straight to global memory with atomicAdd.\n__global__ void ms_deformable_col2im_gpu_kernel_gm_c2345(\n    const float *grad_col,\n    const float *feat_c2,\n    const float *feat_c3,\n    const float *feat_c4,\n    const float *feat_c5,\n    const int h_c2, const int w_c2,\n    const int h_c3, const int w_c3,\n    const int h_c4, const int w_c4,\n    const int h_c5, const int w_c5,\n    const float *data_sampling_loc,\n    const float *data_attn_weight,\n    const int batch_size,\n    const int channels,\n    const int num_views,\n    const int num_query,\n    const int num_point,\n    float *grad_value_c2,\n    float *grad_value_c3,\n    float *grad_value_c4,\n    float *grad_value_c5,\n    float *grad_sampling_loc,\n    float *grad_attn_weight)\n{\n    CUDA_KERNEL_LOOP(index, batch_size * num_query * channels * num_point)\n    { // n: bs x query x channels\n\n        int _temp = index;\n        const int p_col = _temp % num_point;\n        _temp /= num_point;\n        const int c_col = _temp % channels;\n        _temp /= channels;\n        const int sampling_index = _temp;\n        _temp /= num_query;\n        const int b_col = _temp;\n\n        const float top_grad = grad_col[index];\n\n        // Sampling location in range [0, 1]\n        int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;\n        const float loc_w = data_sampling_loc[data_loc_ptr];\n        const float loc_h = data_sampling_loc[data_loc_ptr + 1];\n        const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));\n\n        // Attn weights\n        int data_weight_ptr = sampling_index * num_point * 4 + p_col * 4;\n\n        const float weight_c2 = data_attn_weight[data_weight_ptr];\n        const float weight_c3 = data_attn_weight[data_weight_ptr + 1];\n        const float weight_c4 = data_attn_weight[data_weight_ptr + 2];\n        const float weight_c5 = data_attn_weight[data_weight_ptr + 3];\n\n        // const float h_im = loc_h * spatial_h - 0.5;  // align_corners = False\n        // const float w_im = loc_w * spatial_w - 0.5;\n\n        // C2 Feature\n        float h_im = loc_h * (h_c2 - 1); // align_corners = True\n        float w_im = loc_w * (w_c2 - 1);\n\n        float *grad_location_ptr = grad_sampling_loc + data_loc_ptr;\n        float *grad_weights_ptr = grad_attn_weight + data_weight_ptr;\n\n        if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2)\n        {\n            const float *feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;\n            const float *grad_c2_ptr = grad_value_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;\n            ms_deform_attn_col2im_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col,\n                                           top_grad, weight_c2,\n                                           grad_c2_ptr, grad_location_ptr, grad_weights_ptr);\n        }\n\n        grad_weights_ptr += 1;\n\n        // C3 Feature\n        h_im = loc_h * (h_c3 - 1); // align_corners = True\n        w_im = loc_w * (w_c3 - 1);\n\n        if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3)\n        {\n            const float *feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;\n            const float *grad_c3_ptr = grad_value_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;\n            ms_deform_attn_col2im_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col,\n                                           top_grad, weight_c3,\n                                           grad_c3_ptr, grad_location_ptr, grad_weights_ptr);\n        }\n\n        grad_weights_ptr += 1;\n\n        // C4 Feature\n        h_im = loc_h * (h_c4 - 1); // align_corners = True\n        w_im = loc_w * (w_c4 - 1);\n\n        if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4)\n        {\n            const float *feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;\n            const float *grad_c4_ptr = grad_value_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;\n            ms_deform_attn_col2im_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col,\n                                           top_grad, weight_c4,\n                                           grad_c4_ptr, grad_location_ptr, grad_weights_ptr);\n        }\n\n        grad_weights_ptr += 1;\n\n        // C5 Feature\n        h_im = loc_h * (h_c5 - 1); // align_corners = True\n        w_im = loc_w * (w_c5 - 1);\n\n        if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5)\n        {\n            const float *feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;\n            const float *grad_c5_ptr = grad_value_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;\n            ms_deform_attn_col2im_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col,\n                                           top_grad, weight_c5,\n                                           grad_c5_ptr, grad_location_ptr, grad_weights_ptr);\n        }\n    }\n}\n\n// Same as the c2345 kernel but with five levels (c2..c6); attention weights\n// therefore have stride 5 per point.\n__global__ void ms_deformable_col2im_gpu_kernel_gm_c23456(\n    const float *grad_col,\n    const float *feat_c2,\n    const float *feat_c3,\n    const float *feat_c4,\n    const float *feat_c5,\n    const float *feat_c6,\n    const int h_c2, const int w_c2,\n    const int h_c3, const int w_c3,\n    const int h_c4, const int w_c4,\n    const int h_c5, const int w_c5,\n    const int h_c6, const int w_c6,\n    const float *data_sampling_loc,\n    const float *data_attn_weight,\n    const int batch_size,\n    const int channels,\n    const int num_views,\n    const int num_query,\n    const int num_point,\n    float *grad_value_c2,\n    float *grad_value_c3,\n    float *grad_value_c4,\n    float *grad_value_c5,\n    float *grad_value_c6,\n    float *grad_sampling_loc,\n    float *grad_attn_weight)\n{\n    CUDA_KERNEL_LOOP(index, batch_size * num_query * channels * num_point)\n    { // n: bs x query x channels\n\n        int _temp = index;\n        const int p_col = _temp % num_point;\n        _temp /= num_point;\n        const int c_col = _temp % channels;\n        _temp /= channels;\n        const int sampling_index = _temp;\n        _temp /= num_query;\n        const int b_col = _temp;\n\n        const float top_grad = grad_col[index];\n\n        // Sampling location in range [0, 1]\n        int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;\n        const float loc_w = data_sampling_loc[data_loc_ptr];\n        const float loc_h = data_sampling_loc[data_loc_ptr + 1];\n        const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));\n\n        // Attn weights\n        int data_weight_ptr = sampling_index * num_point * 5 + p_col * 5;\n\n        const float weight_c2 = data_attn_weight[data_weight_ptr];\n        const float weight_c3 = data_attn_weight[data_weight_ptr + 1];\n        const float weight_c4 = data_attn_weight[data_weight_ptr + 2];\n        const float weight_c5 = data_attn_weight[data_weight_ptr + 3];\n        const float weight_c6 = data_attn_weight[data_weight_ptr + 4];\n\n        // const float h_im = loc_h * spatial_h - 0.5;  // align_corners = False\n        // const float w_im = loc_w * spatial_w - 0.5;\n\n        // C2 Feature\n        float h_im = loc_h * (h_c2 - 1); // align_corners = True\n        float w_im = loc_w * (w_c2 - 1);\n\n        float *grad_location_ptr = grad_sampling_loc + data_loc_ptr;\n        float *grad_weights_ptr = grad_attn_weight + data_weight_ptr;\n\n        if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2)\n        {\n            const float *feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;\n            const float *grad_c2_ptr = grad_value_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;\n            ms_deform_attn_col2im_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col,\n                                           top_grad, weight_c2,\n                                           grad_c2_ptr, grad_location_ptr, grad_weights_ptr);\n        }\n\n        grad_weights_ptr += 1;\n\n        // C3 Feature\n        h_im = loc_h * (h_c3 - 1); // align_corners = True\n        w_im = loc_w * (w_c3 - 1);\n\n        if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3)\n        {\n            const float *feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;\n            const float *grad_c3_ptr = grad_value_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;\n            ms_deform_attn_col2im_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col,\n                                           top_grad, weight_c3,\n                                           grad_c3_ptr, grad_location_ptr, grad_weights_ptr);\n        }\n\n        grad_weights_ptr += 1;\n\n        // C4 Feature\n        h_im = loc_h * (h_c4 - 1); // align_corners = True\n        w_im = loc_w * (w_c4 - 1);\n\n        if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4)\n        {\n            const float *feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;\n            const float *grad_c4_ptr = grad_value_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;\n            ms_deform_attn_col2im_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col,\n                                           top_grad, weight_c4,\n                                           grad_c4_ptr, grad_location_ptr, grad_weights_ptr);\n        }\n\n        grad_weights_ptr += 1;\n\n        // C5 Feature\n        h_im = loc_h * (h_c5 - 1); // align_corners = True\n        w_im = loc_w * (w_c5 - 1);\n\n        if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5)\n        {\n            const float *feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;\n            const float *grad_c5_ptr = grad_value_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;\n            ms_deform_attn_col2im_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col,\n                                           top_grad, weight_c5,\n                                           grad_c5_ptr, grad_location_ptr, grad_weights_ptr);\n        }\n\n        grad_weights_ptr += 1;\n\n        // C6 Feature\n        h_im = loc_h * (h_c6 - 1); // align_corners = True\n        w_im = loc_w * (w_c6 - 1);\n\n        if (h_im > -1 && w_im > -1 && h_im < h_c6 && w_im < w_c6)\n        {\n            const float *feat_c6_ptr = feat_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;\n            const float *grad_c6_ptr = grad_value_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;\n            ms_deform_attn_col2im_bilinear(feat_c6_ptr, h_c6, w_c6, channels, h_im, w_im, c_col,\n                                           top_grad, weight_c6,\n                                           grad_c6_ptr, grad_location_ptr, grad_weights_ptr);\n        }\n    }\n}\n\n// Host launcher for the 4-level backward kernel.\nvoid ms_deformable_col2im_cuda_c2345(\n    const float *grad_col,\n    const float *feat_c2,\n    const float *feat_c3,\n    const float *feat_c4,\n    const float *feat_c5,\n    const int h_c2, const int w_c2,\n    const int h_c3, const int w_c3,\n    const int h_c4, const int w_c4,\n    const int h_c5, const int w_c5,\n    const float *data_sampling_loc,\n    const float *data_attn_weight,\n    const int batch_size,\n    const int channels,\n    const int num_views,\n    const int num_query,\n    const int num_point,\n    float *grad_value_c2,\n    float *grad_value_c3,\n    float *grad_value_c4,\n    float *grad_value_c5,\n    float *grad_sampling_loc,\n    float *grad_attn_weight)\n{\n    const int num_kernels = batch_size * num_query * channels * num_point;\n    const int num_threads = (channels * num_point > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels * num_point;\n\n    // Launch on the current ATen CUDA stream (not the legacy default stream)\n    // so the kernel is ordered correctly w.r.t. surrounding torch ops.\n    ms_deformable_col2im_gpu_kernel_gm_c2345 <<<GET_BLOCKS(num_kernels, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(\n        grad_col, feat_c2, feat_c3, feat_c4, feat_c5,\n        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,\n        data_sampling_loc, data_attn_weight,\n        batch_size, channels, num_views, num_query, num_point,\n        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5,\n        grad_sampling_loc, grad_attn_weight);\n\n    cudaError_t err = cudaGetLastError();\n    if (err != cudaSuccess)\n    {\n        printf(\"error in ms_deformable_col2im_cuda_c2345: %s\\n\", cudaGetErrorString(err));\n    }\n}\n\n// Host launcher for the 5-level backward kernel.\nvoid ms_deformable_col2im_cuda_c23456(\n    const float *grad_col,\n    const float *feat_c2,\n    const float *feat_c3,\n    const float *feat_c4,\n    const float *feat_c5,\n    const float *feat_c6,\n    const int h_c2, const int w_c2,\n    const int h_c3, const int w_c3,\n    const int h_c4, const int w_c4,\n    const int h_c5, const int w_c5,\n    const int h_c6, const int w_c6,\n    const float *data_sampling_loc,\n    const float *data_attn_weight,\n    const int batch_size,\n    const int channels,\n    const int num_views,\n    const int num_query,\n    const int num_point,\n    float *grad_value_c2,\n    float *grad_value_c3,\n    float *grad_value_c4,\n    float *grad_value_c5,\n    float *grad_value_c6,\n    float *grad_sampling_loc,\n    float *grad_attn_weight)\n{\n    const int num_kernels = batch_size * num_query * channels * num_point;\n    const int num_threads = (channels * num_point > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels * num_point;\n\n    // Launch on the current ATen CUDA stream (not the legacy default stream)\n    // so the kernel is ordered correctly w.r.t. surrounding torch ops.\n    ms_deformable_col2im_gpu_kernel_gm_c23456 <<<GET_BLOCKS(num_kernels, num_threads), num_threads, 0, at::cuda::getCurrentCUDAStream()>>>(\n        grad_col, feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,\n        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,\n        data_sampling_loc, data_attn_weight,\n        batch_size, channels, num_views, num_query, num_point,\n        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6,\n        grad_sampling_loc, grad_attn_weight);\n\n    cudaError_t err = cudaGetLastError();\n    if (err != cudaSuccess)\n    {\n        printf(\"error in ms_deformable_col2im_cuda_c23456: %s\\n\", cudaGetErrorString(err));\n    }\n}\n"
  },
  {
    "path": "models/csrc/msmv_sampling/msmv_sampling_forward.cu",
    "content": "/*!\n* Modified from Deformable DETR\n*/\n\n#include <cstdio>\n#include <algorithm>\n#include <cstring>\n#include <cuda_runtime.h>\n#include <device_launch_parameters.h>\n#include <torch/extension.h>\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <THC/THCAtomics.cuh>\n\n#define CUDA_KERNEL_LOOP(i, n)                          \\\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x;   \\\n      i < (n);                                          \\\n      i += blockDim.x * gridDim.x)\n\n#define CUDA_NUM_THREADS 512\n#define MAX_POINT 32\n\ninline int GET_BLOCKS(const int N, const int num_threads) {\n    return (N + num_threads - 1) / num_threads;\n}\n\n__device__ float ms_deform_attn_im2col_bilinear(\n    const float*& bottom_data,\n    const int& height, const int& width, const int& channels,\n    const float& h, const float& w, const int& c) {\n\n    const int h_low = floor(h);\n    const int w_low = floor(w);\n    const int h_high = h_low + 1;\n    const int w_high = w_low + 1;\n\n    const float lh = h - h_low;\n    const float lw = w - w_low;\n    const float hh = 1 - lh, hw = 1 - lw;\n\n    const int w_stride = channels;\n    const int h_stride = width * w_stride;\n    const int h_low_ptr_offset = h_low * h_stride;\n    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;\n    const int w_low_ptr_offset = w_low * w_stride;\n    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;\n\n    float v1 = 0;\n    if (h_low >= 0 && w_low >= 0) {\n        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + c;\n        v1 = bottom_data[ptr1];\n    }\n    float v2 = 0;\n    if (h_low >= 0 && w_high <= width - 1) {\n        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + c;\n        v2 = bottom_data[ptr2];\n    }\n    float v3 = 0;\n    if (h_high <= height - 1 && w_low >= 0) {\n        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + c;\n        v3 = bottom_data[ptr3];\n    }\n    float v4 = 0;\n    if 
(h_high <= height - 1 && w_high <= width - 1) {\n        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + c;\n        v4 = bottom_data[ptr4];\n    }\n\n    const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n    const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n\n    return val;\n}\n\n__global__ void ms_deformable_im2col_gpu_kernel_c2345(\n    const float* feat_c2,\n    const float* feat_c3,\n    const float* feat_c4,\n    const float* feat_c5,\n    const int h_c2, const int w_c2,\n    const int h_c3, const int w_c3,\n    const int h_c4, const int w_c4,\n    const int h_c5, const int w_c5,\n    const float* data_sampling_loc,\n    const float* data_attn_weight,\n    const int batch_size,\n    const int channels,\n    const int num_views,\n    const int num_query,\n    const int num_point,\n    float* data_col) {\n\n    float res[MAX_POINT];\n\n    CUDA_KERNEL_LOOP(index, batch_size * num_query * channels) {  // n: bs x query x channels\n        int _temp = index;\n        const int c_col = _temp % channels;\n        _temp /= channels;\n        const int sampling_index = _temp;\n        _temp /= num_query;\n        const int b_col = _temp;\n\n        for (int p_col = 0; p_col < num_point; ++p_col) { res[p_col] = 0; }\n\n        for (int p_col = 0; p_col < num_point; ++p_col) {\n            // Sampling location in range [0, 1]\n            int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;\n            const float loc_w = data_sampling_loc[data_loc_ptr];\n            const float loc_h = data_sampling_loc[data_loc_ptr + 1];\n            const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));\n\n            // Attn weights\n            int data_weight_ptr = sampling_index * num_point * 4 + p_col * 4;\n            const float weight_c2 = data_attn_weight[data_weight_ptr];\n            const float weight_c3 = data_attn_weight[data_weight_ptr + 1];\n            const float weight_c4 = 
data_attn_weight[data_weight_ptr + 2];\n            const float weight_c5 = data_attn_weight[data_weight_ptr + 3];\n\n            //const float h_im = loc_h * spatial_h - 0.5;  // align_corners = False\n            //const float w_im = loc_w * spatial_w - 0.5;\n\n            // C2 Feature\n            float h_im = loc_h * (h_c2 - 1);  // align_corners = True\n            float w_im = loc_w * (w_c2 - 1);\n\n            if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2) {\n                const float* feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;\n                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col) * weight_c2;\n            }\n\n            // C3 Feature\n            h_im = loc_h * (h_c3 - 1);  // align_corners = True\n            w_im = loc_w * (w_c3 - 1);\n\n            if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3) {\n                const float* feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;\n                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col) * weight_c3;\n            }\n\n            // C4 Feature\n            h_im = loc_h * (h_c4 - 1);  // align_corners = True\n            w_im = loc_w * (w_c4 - 1);\n\n            if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4) {\n                const float* feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;\n                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col) * weight_c4;\n            }\n\n            // C5 Feature\n            h_im = loc_h * (h_c5 - 1);  // align_corners = True\n            w_im = loc_w * (w_c5 - 1);\n\n            if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5) {\n                const float* feat_c5_ptr = feat_c5 + b_col * 
num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;\n                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col) * weight_c5;\n            }\n        }\n\n        for (int p_col = 0; p_col < num_point; ++p_col) {\n            float* data_col_ptr = data_col + index * num_point + p_col;\n            *data_col_ptr = res[p_col];\n        }\n    }\n}\n\n__global__ void ms_deformable_im2col_gpu_kernel_c23456(\n    const float* feat_c2,\n    const float* feat_c3,\n    const float* feat_c4,\n    const float* feat_c5,\n    const float* feat_c6,\n    const int h_c2, const int w_c2,\n    const int h_c3, const int w_c3,\n    const int h_c4, const int w_c4,\n    const int h_c5, const int w_c5,\n    const int h_c6, const int w_c6,\n    const float* data_sampling_loc,\n    const float* data_attn_weight,\n    const int batch_size,\n    const int channels,\n    const int num_views,\n    const int num_query,\n    const int num_point,\n    float* data_col) {\n\n    float res[MAX_POINT];\n\n    CUDA_KERNEL_LOOP(index, batch_size * num_query * channels) {  // n: bs x query x channels\n        int _temp = index;\n        const int c_col = _temp % channels;\n        _temp /= channels;\n        const int sampling_index = _temp;\n        _temp /= num_query;\n        const int b_col = _temp;\n\n        for (int p_col = 0; p_col < num_point; ++p_col) { res[p_col] = 0; }\n\n        for (int p_col = 0; p_col < num_point; ++p_col) {\n            // Sampling location in range [0, 1]\n            int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;\n            const float loc_w = data_sampling_loc[data_loc_ptr];\n            const float loc_h = data_sampling_loc[data_loc_ptr + 1];\n            const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));\n\n            // Attn weights\n            int data_weight_ptr = sampling_index * num_point * 5 + p_col * 5;\n            const float weight_c2 
= data_attn_weight[data_weight_ptr];\n            const float weight_c3 = data_attn_weight[data_weight_ptr + 1];\n            const float weight_c4 = data_attn_weight[data_weight_ptr + 2];\n            const float weight_c5 = data_attn_weight[data_weight_ptr + 3];\n            const float weight_c6 = data_attn_weight[data_weight_ptr + 4];\n\n            //const float h_im = loc_h * spatial_h - 0.5;  // align_corners = False\n            //const float w_im = loc_w * spatial_w - 0.5;\n\n            // C2 Feature\n            float h_im = loc_h * (h_c2 - 1);  // align_corners = True\n            float w_im = loc_w * (w_c2 - 1);\n\n            if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2) {\n                const float* feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;\n                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col) * weight_c2;\n            }\n\n            // C3 Feature\n            h_im = loc_h * (h_c3 - 1);  // align_corners = True\n            w_im = loc_w * (w_c3 - 1);\n\n            if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3) {\n                const float* feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;\n                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col) * weight_c3;\n            }\n\n            // C4 Feature\n            h_im = loc_h * (h_c4 - 1);  // align_corners = True\n            w_im = loc_w * (w_c4 - 1);\n\n            if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4) {\n                const float* feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;\n                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col) * weight_c4;\n            }\n\n            // C5 Feature\n            h_im = 
loc_h * (h_c5 - 1);  // align_corners = True\n            w_im = loc_w * (w_c5 - 1);\n\n            if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5) {\n                const float* feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;\n                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col) * weight_c5;\n            }\n\n            // C6 Feature\n            h_im = loc_h * (h_c6 - 1);  // align_corners = True\n            w_im = loc_w * (w_c6 - 1);\n\n            if (h_im > -1 && w_im > -1 && h_im < h_c6 && w_im < w_c6) {\n                const float* feat_c6_ptr = feat_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;\n                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c6_ptr, h_c6, w_c6, channels, h_im, w_im, c_col) * weight_c6;\n            }\n        }\n\n        for (int p_col = 0; p_col < num_point; ++p_col) {\n            float* data_col_ptr = data_col + index * num_point + p_col;\n            *data_col_ptr = res[p_col];\n        }\n    }\n}\n\nvoid ms_deformable_im2col_cuda_c2345(\n    const float* feat_c2,\n    const float* feat_c3,\n    const float* feat_c4,\n    const float* feat_c5,\n    const int h_c2, const int w_c2,\n    const int h_c3, const int w_c3,\n    const int h_c4, const int w_c4,\n    const int h_c5, const int w_c5,\n    const float* data_sampling_loc,\n    const float* data_attn_weight,\n    const int batch_size,\n    const int channels,\n    const int num_views,\n    const int num_query,\n    const int num_point,\n    float* data_col) {\n\n    const int num_kernels = batch_size * num_query * channels;\n    const int num_threads = CUDA_NUM_THREADS;\n\n    ms_deformable_im2col_gpu_kernel_c2345 <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>> (\n        feat_c2, feat_c3, feat_c4, feat_c5, h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,\n        data_sampling_loc, 
data_attn_weight, batch_size, channels, num_views, num_query, num_point, data_col\n    );\n\n    cudaError_t err = cudaGetLastError();\n    if (err != cudaSuccess) {\n        printf(\"error in ms_deformable_im2col_cuda_c2345: %s\\n\", cudaGetErrorString(err));\n    }\n}\n\nvoid ms_deformable_im2col_cuda_c23456(\n    const float* feat_c2,\n    const float* feat_c3,\n    const float* feat_c4,\n    const float* feat_c5,\n    const float* feat_c6,\n    const int h_c2, const int w_c2,\n    const int h_c3, const int w_c3,\n    const int h_c4, const int w_c4,\n    const int h_c5, const int w_c5,\n    const int h_c6, const int w_c6,\n    const float* data_sampling_loc,\n    const float* data_attn_weight,\n    const int batch_size,\n    const int channels,\n    const int num_views,\n    const int num_query,\n    const int num_point,\n    float* data_col) {\n\n    const int num_kernels = batch_size * num_query * channels;\n    const int num_threads = CUDA_NUM_THREADS;\n\n    ms_deformable_im2col_gpu_kernel_c23456 <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>> (\n        feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,\n        data_sampling_loc, data_attn_weight, batch_size, channels, num_views, num_query, num_point, data_col\n    );\n\n    cudaError_t err = cudaGetLastError();\n    if (err != cudaSuccess) {\n        printf(\"error in ms_deformable_im2col_cuda_c23456: %s\\n\", cudaGetErrorString(err));\n    }\n}\n"
  },
  {
    "path": "models/csrc/setup.py",
    "content": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n\ndef get_ext_modules():\n    return [\n        CUDAExtension(\n            name='_msmv_sampling_cuda',\n            sources=[\n                'msmv_sampling/msmv_sampling.cpp',\n                'msmv_sampling/msmv_sampling_forward.cu',\n                'msmv_sampling/msmv_sampling_backward.cu'\n            ],\n            include_dirs=['msmv_sampling']\n        )\n    ]\n\n\nsetup(\n    name='csrc',\n    ext_modules=get_ext_modules(),\n    cmdclass={'build_ext': BuildExtension}\n)\n\n"
  },
  {
    "path": "models/csrc/wrapper.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom ._msmv_sampling_cuda import _ms_deform_attn_cuda_c2345_forward, _ms_deform_attn_cuda_c2345_backward\nfrom ._msmv_sampling_cuda import _ms_deform_attn_cuda_c23456_forward, _ms_deform_attn_cuda_c23456_backward\n\n\ndef msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):\n    \"\"\"\n    value: [B, N, H1W1 + H2W2..., C]\n    sampling_locations: [B, Q, P, 3]\n    scale_weights: [B, Q, P, 4]\n    \"\"\"\n    assert scale_weights.shape[-1] == len(mlvl_feats)\n\n    B, _, _, _, C = mlvl_feats[0].shape\n    _, Q, P, _ = sampling_locations.shape\n\n    sampling_locations = sampling_locations * 2 - 1\n    sampling_locations = sampling_locations[:, :, :, None, :]  # [B, Q, P, 1, 3]\n\n    final = torch.zeros([B, C, Q, P], device=mlvl_feats[0].device)\n\n    for lvl, feat in enumerate(mlvl_feats):\n        feat = feat.permute(0, 4, 1, 2, 3)\n        out = F.grid_sample(\n            feat, sampling_locations, mode='bilinear',\n            padding_mode='zeros', align_corners=True,\n        )[..., 0]  # [B, C, Q, P]\n        out = out * scale_weights[..., lvl].reshape(B, 1, Q, P)\n        final += out\n\n    return final.permute(0, 2, 1, 3)\n\n\nclass MSMVSamplingC2345(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights):\n        ctx.save_for_backward(feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights)\n        \n        assert callable(_ms_deform_attn_cuda_c2345_forward)\n        return _ms_deform_attn_cuda_c2345_forward(\n            feat_c2, feat_c3, feat_c4, feat_c5,\n            sampling_locations, scale_weights)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights = ctx.saved_tensors\n\n        assert callable(_ms_deform_attn_cuda_c2345_backward)\n        grad_value_c2, grad_value_c3, grad_value_c4, 
grad_value_c5, grad_sampling_loc, grad_attn_weight = _ms_deform_attn_cuda_c2345_backward(grad_output.contiguous(), \n            feat_c2, feat_c3, feat_c4, feat_c5,\n            sampling_locations, scale_weights\n        )\n        \n        return grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight\n\n\nclass MSMVSamplingC23456(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights):\n        ctx.save_for_backward(feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights)\n        \n        assert callable(_ms_deform_attn_cuda_c23456_forward)\n        return _ms_deform_attn_cuda_c23456_forward(\n            feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,\n            sampling_locations, scale_weights)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights = ctx.saved_tensors\n\n        assert callable(_ms_deform_attn_cuda_c23456_backward)\n        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight = _ms_deform_attn_cuda_c23456_backward(grad_output.contiguous(), \n            feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,\n            sampling_locations, scale_weights\n        )\n        \n        return grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight\n\n\ndef msmv_sampling(mlvl_feats, sampling_locations, scale_weights):\n    sampling_locations = sampling_locations.contiguous()\n    scale_weights = scale_weights.contiguous()\n    if len(mlvl_feats) == 4:\n        return MSMVSamplingC2345.apply(*mlvl_feats, sampling_locations, scale_weights)\n    elif len(mlvl_feats) == 5:\n        return MSMVSamplingC23456.apply(*mlvl_feats, sampling_locations, scale_weights)\n    else:\n        return 
msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights)\n"
  },
  {
    "path": "models/loss_utils.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom mmdet.models.builder import LOSSES, build_loss\nfrom mmdet.core import reduce_mean\nfrom .utils import sparse2dense\nfrom torch.cuda.amp import autocast\nfrom torch.autograd import Variable\n\n\ndef get_voxel_decoder_loss_input(voxel_semantics, occ_loc_i, seg_pred_i, scale, num_classes=18):\n    assert voxel_semantics.shape[0] == 1  # bs = 1\n    voxel_semantics = voxel_semantics.long()\n\n    if seg_pred_i is not None:  # semantic prediction\n        assert seg_pred_i.shape[-1] == num_classes\n        \n        seg_pred_dense, sparse_mask = sparse2dense(\n            occ_loc_i, seg_pred_i,\n            dense_shape=[200 // scale, 200 // scale, 16 // scale, num_classes],\n            empty_value=torch.zeros((num_classes)).to(seg_pred_i)\n        )\n        sparse_mask = F.interpolate(sparse_mask[:, None].float(), scale_factor=scale)[:, 0].bool()\n        seg_pred_dense = seg_pred_dense.permute(0, 4, 1, 2, 3)   # [B, CLS, W, H, D]\n        seg_pred_dense = F.interpolate(seg_pred_dense, scale_factor=scale)\n        seg_pred_dense = seg_pred_dense.permute(0, 2, 3, 4, 1)   # [B, W, H, D, CLS]\n\n        seg_pred_i_sparse = seg_pred_dense[sparse_mask]  # [K, CLS]\n        voxel_semantics_sparse = voxel_semantics[sparse_mask]  # [K]\n\n    return seg_pred_i_sparse, voxel_semantics_sparse, sparse_mask\n\n\ndef compute_scal_loss(pred, gt, class_id, reverse=False, ignore_index=255):\n    p = pred[:, class_id, :]\n    completion_target = (gt == class_id).long()\n    \n    loss = torch.zeros(pred.shape[0], device=pred.device)\n    \n    if reverse:\n        p = 1 - p\n        completion_target = ((gt != class_id) & (gt != ignore_index)).long()\n    \n    target_sum = completion_target.sum(dim=(1))\n    mask = (target_sum > 0)\n    \n    p = p[torch.where(mask)]\n    completion_target = completion_target[torch.where(mask)]\n    \n    nominator = torch.sum(p * completion_target, dim=(1))\n 
   \n    p_mask = torch.where(torch.sum(p, dim=(1)) > 0)\n    if p_mask[0].shape[0] > 0:\n        precision = nominator[p_mask] / torch.sum(p[p_mask], dim=(1))\n        loss_precision = F.binary_cross_entropy(\n            precision, torch.ones_like(precision),\n            reduction='none'\n        )\n        loss[torch.where(mask)[0][p_mask]] += loss_precision\n        \n    t_mask = torch.where(torch.sum(completion_target, dim=(1)) > 0)\n    if t_mask[0].shape[0] > 0:\n        recall = nominator[t_mask] / torch.sum(completion_target[t_mask], dim=(1))\n        loss_recall = F.binary_cross_entropy(\n            recall, torch.ones_like(recall),\n            reduction='none'\n        )\n        loss[torch.where(mask)[0][t_mask]] += loss_recall\n        \n    ct_mask = torch.where(torch.sum(1 - completion_target, dim=(1)) > 0)\n    if ct_mask[0].shape[0] > 0:\n        specificity = torch.sum((1 - p[ct_mask]) * (1 - completion_target[ct_mask]), dim=(1)) / (\n            torch.sum(1 - completion_target[ct_mask], dim=(1))\n        )\n        loss_ct = F.binary_cross_entropy(\n            specificity, torch.ones_like(specificity),\n            reduction='none'\n        )\n        loss[torch.where(mask)[0][ct_mask]] += loss_ct\n        \n    return loss, mask\n\n\n@LOSSES.register_module()\nclass GeoScalLoss(nn.Module):\n    def __init__(self, \n                 num_classes,\n                 loss_weight=1.0):\n        super().__init__()\n        self.num_classes = num_classes\n        self.loss_weight = loss_weight\n        \n    def forward(self, pred, gt):\n        loss = torch.tensor(0, device=pred.device, dtype=pred.dtype)\n        pred = F.softmax(pred, dim=1)\n        \n        loss, _ = compute_scal_loss(pred, gt, self.num_classes - 1, reverse=True)\n        return self.loss_weight * torch.mean(loss)\n\n\n@LOSSES.register_module()\nclass SemScalLoss(nn.Module):\n    def __init__(self, \n                 num_classes,\n                 class_weights=None,\n          
       loss_weight=1.0):\n        super().__init__()\n        self.num_classes = num_classes\n        self.class_weights = class_weights\n        if self.class_weights is not None:\n            assert len(self.class_weights) == self.num_classes, \"number of class weights must equal to class number\"\n        else:\n            self.class_weights = [1.0 for _ in range(self.num_classes)]\n        self.loss_weight = loss_weight\n        \n    def forward(self, pred, gt):\n        pred = F.softmax(pred, dim=1)\n        batch_size = pred.shape[0]\n        loss = torch.zeros(batch_size, device=pred.device)\n        count = torch.zeros(batch_size, device=pred.device)\n        for i in range(self.num_classes):\n            loss_cls, mask_cls = compute_scal_loss(pred, gt, i)\n            count += mask_cls.long()\n            loss += loss_cls * self.class_weights[i]\n        \n        return self.loss_weight * (loss / count).mean()\n\n\n# borrowed from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py#L21\ndef dice_loss(\n        inputs: torch.Tensor,\n        targets: torch.Tensor,\n        num_masks: float,\n        mask_camera: torch.Tensor\n    ):\n    \"\"\"\n    Compute the DICE loss, similar to generalized IOU for masks\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. 
Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n    \"\"\"\n    if mask_camera is not None:\n        inputs = inputs[:, :, mask_camera]\n        targets = targets[:, :, mask_camera]\n    \n    inputs = inputs.sigmoid()\n    inputs = inputs.flatten(1)\n    targets = targets.squeeze(1)\n    numerator = 2 * (inputs * targets).sum(-1)\n    denominator = inputs.sum(-1) + targets.sum(-1)\n    loss = 1 - (numerator + 1) / (denominator + 1)\n    return loss.sum() / num_masks\n\n\ndice_loss_jit = torch.jit.script(\n    dice_loss\n)  # type: torch.jit.ScriptModule\n\n\n# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py#L48\ndef sigmoid_ce_loss(\n        inputs: torch.Tensor,\n        targets: torch.Tensor,\n        num_masks: float,\n        mask_camera: torch.Tensor\n    ):\n    \"\"\"\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. 
Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n    Returns:\n        Loss tensor\n    \"\"\"\n    # [M, 1, K]\n    if mask_camera is not None:\n        mask_camera = mask_camera.to(torch.int32)\n        mask_camera = mask_camera[None, None, ...].expand(targets.shape[0], 1, mask_camera.shape[-1])\n        loss = F.binary_cross_entropy_with_logits(inputs, targets, mask_camera, reduction=\"none\")\n    else:\n        loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=\"none\")\n    \n    return loss.mean(2).mean(1).sum() / num_masks\n\n\nsigmoid_ce_loss_jit = torch.jit.script(\n    sigmoid_ce_loss\n)  # type: torch.jit.ScriptModule\n\n\ndef CE_ssc_loss(pred, target, class_weights=None, ignore_index=255):\n    \"\"\"\n    :param: prediction: the predicted tensor, must be [BS, C, ...]\n    \"\"\"\n\n    criterion = nn.CrossEntropyLoss(\n        weight=class_weights, ignore_index=ignore_index, reduction=\"mean\"\n    )\n    with autocast(False):\n        loss = criterion(pred, target.long())\n\n    return loss\n\n\n# https://github.com/NVlabs/FB-BEV/blob/832bd81866823a913a4c69552e1ca61ae34ac211/mmdet3d/models/fbbev/modules/occ_loss_utils/lovasz_softmax.py#L22\ndef lovasz_grad(gt_sorted):\n    \"\"\"\n    Computes gradient of the Lovasz extension w.r.t sorted errors\n    See Alg. 1 in paper\n    \"\"\"\n    p = len(gt_sorted)\n    gts = gt_sorted.sum()\n    intersection = gts - gt_sorted.float().cumsum(0)\n    union = gts + (1 - gt_sorted).float().cumsum(0)\n    jaccard = 1. 
- intersection / union\n    if p > 1: # cover 1-pixel case\n        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]\n    return jaccard\n\n\n# https://github.com/NVlabs/FB-BEV/blob/832bd81866823a913a4c69552e1ca61ae34ac211/mmdet3d/models/fbbev/modules/occ_loss_utils/lovasz_softmax.py#L157\ndef lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):\n    \"\"\"\n    Multi-class Lovasz-Softmax loss\n      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).\n              Interpreted as binary (sigmoid) output with outputs of size [B, H, W].\n      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)\n      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.\n      per_image: compute the loss per image instead of per batch\n      ignore: void class labels\n    \"\"\"\n    if per_image:\n        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)\n                          for prob, lab in zip(probas, labels))\n    else:\n        with autocast(False):\n            loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)\n    return loss\n\n\n# https://github.com/NVlabs/FB-BEV/blob/832bd81866823a913a4c69552e1ca61ae34ac211/mmdet3d/models/fbbev/modules/occ_loss_utils/lovasz_softmax.py#L176\ndef lovasz_softmax_flat(probas, labels, classes='present'):\n    \"\"\"\n    Multi-class Lovasz-Softmax loss\n      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)\n      labels: [P] Tensor, ground truth labels (between 0 and C - 1)\n      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.\n    \"\"\"\n    if probas.numel() == 0:\n        # only void pixels, the gradients should be 0\n        return probas * 0.\n    C = probas.size(1)\n    losses = []\n    class_to_sum = list(range(C)) if classes in 
['all', 'present'] else classes\n    for c in class_to_sum:\n        fg = (labels == c).float() # foreground for class c\n        if (classes == 'present' and fg.sum() == 0):\n            continue\n        if C == 1:\n            if len(classes) > 1:\n                raise ValueError('Sigmoid output possible only with 1 class')\n            class_pred = probas[:, 0]\n        else:\n            class_pred = probas[:, c]\n        errors = (Variable(fg) - class_pred).abs()\n        errors_sorted, perm = torch.sort(errors, 0, descending=True)\n        perm = perm.data\n        fg_sorted = fg[perm]\n        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))\n    return mean(losses)\n\n\n# https://github.com/NVlabs/FB-BEV/blob/832bd81866823a913a4c69552e1ca61ae34ac211/mmdet3d/models/fbbev/modules/occ_loss_utils/lovasz_softmax.py#L207\ndef flatten_probas(probas, labels, ignore=None):\n    \"\"\"\n    Flattens predictions in the batch\n    \"\"\"\n    if probas.dim() == 2:\n        if ignore is not None:\n            valid = (labels != ignore)\n            probas = probas[valid]\n            labels = labels[valid]\n        return probas, labels\n\n    elif probas.dim() == 3:\n        # assumes output of a sigmoid layer\n        B, H, W = probas.size()\n        probas = probas.view(B, 1, H, W)\n    elif probas.dim() == 5:\n        #3D segmentation\n        B, C, L, H, W = probas.size()\n        probas = probas.contiguous().view(B, C, L, H*W)\n    B, C, H, W = probas.size()\n    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C\n    labels = labels.view(-1)\n    if ignore is None:\n        return probas, labels\n    valid = (labels != ignore)\n    vprobas = probas[valid.nonzero().squeeze()]\n    vlabels = labels[valid]\n    return vprobas, vlabels\n\n\n# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py#L90\n@LOSSES.register_module()\nclass Mask2FormerLoss(nn.Module):\n    def 
__init__(self, \n                 num_classes,\n                 loss_cls_weight=1.0, \n                 loss_mask_weight=1.0, \n                 loss_dice_weight=1.0, \n                 no_class_weight=0.1):\n        super().__init__()\n        self.num_classes = num_classes\n        self.loss_cls_weight = loss_cls_weight\n        self.loss_mask_weight = loss_mask_weight\n        self.loss_dice_weight = loss_dice_weight\n        self.no_class_weight = no_class_weight\n        self.empty_weight = torch.ones(self.num_classes)\n        self.empty_weight[-1] = self.no_class_weight\n        self.loss_cls = build_loss(dict(\n            type='FocalLoss',\n            use_sigmoid=True,\n            gamma=2.0,\n            alpha=0.25,\n            loss_weight=2.0\n        ))\n        \n    def forward(self, mask_pred, class_pred, mask_gt, class_gt, indices, mask_camera):\n        bs = mask_pred.shape[0]\n        loss_masks = torch.tensor(0).to(mask_pred)\n        loss_dices = torch.tensor(0).to(mask_pred)\n        loss_classes = torch.tensor(0).to(mask_pred)\n\n        num_total_pos = sum([tc.numel() for tc in class_gt])\n        avg_factor = torch.clamp(reduce_mean(class_pred.new_tensor([num_total_pos * 1.0])), min=1).item()\n        \n        for b in range(bs):\n            mask_camera_b = mask_camera[b] if mask_camera is not None else None# N\n            tgt_mask = mask_gt[b]\n            num_instances = class_gt[b].shape[0]\n\n            tgt_class = class_gt[b]\n            tgt_mask = (tgt_mask.unsqueeze(-1) == torch.arange(num_instances).to(mask_gt.device))\n            tgt_mask = tgt_mask.permute(1, 0)\n            \n            src_idx, tgt_idx = indices[b]\n            src_mask = mask_pred[b][src_idx]   # [M, N], M is number of gt instances, N is number of remaining voxels\n            tgt_mask = tgt_mask[tgt_idx]   # [M, N]\n            src_class = class_pred[b]   # [Q, CLS]\n            \n            # pad non-aligned queries' tgt classes with 'no class'\n    
        pad_tgt_class = torch.full(\n                (src_class.shape[0], ), self.num_classes - 1, dtype=torch.int64, device=class_pred.device\n            )   # [Q]\n            pad_tgt_class[src_idx] = tgt_class[tgt_idx]\n            \n            # only calculates loss mask for aligned pairs\n            loss_mask, loss_dice = self.loss_masks(src_mask, tgt_mask, avg_factor=avg_factor, mask_camera=mask_camera_b)\n            # calculates loss class for all queries\n            loss_class = self.loss_labels(src_class, pad_tgt_class, self.empty_weight.to(src_class.device), avg_factor=avg_factor)\n            \n            loss_masks += loss_mask * self.loss_mask_weight\n            loss_dices += loss_dice * self.loss_dice_weight\n            loss_classes += loss_class * self.loss_cls_weight\n            \n        return loss_masks, loss_dices, loss_classes\n    \n    # mask2former use point sampling to calculate loss of fewer important points\n    # we omit point sampling as we have limited number of points\n    def loss_masks(self, src_mask, tgt_mask, avg_factor=None, mask_camera=None):\n        \"\"\"Compute the losses related to the masks: the focal loss and the dice loss.\n        targets dicts must contain the key \"masks\" containing a tensor of dim [nb_target_boxes, h, w]\n        \"\"\"\n        # No need to upsample predictions as we are using normalized coordinates :)\n        # N x 1 x H x W\n        num_masks = tgt_mask.shape[0]\n        src_mask = src_mask.view(num_masks, 1, -1)\n        tgt_mask = tgt_mask.view(num_masks, 1, -1)\n        \n        if avg_factor is None:\n            avg_factor = num_masks\n\n        loss_dice = dice_loss(src_mask, tgt_mask, avg_factor, mask_camera)\n        loss_mask = sigmoid_ce_loss(src_mask, tgt_mask.float(), avg_factor, mask_camera)\n        \n        return loss_mask, loss_dice\n        \n    def loss_labels(self, src_class, tgt_class, empty_weight=None, avg_factor=None):\n        \"\"\"Classification loss 
(NLL)\n        targets dicts must contain the key \"labels\" containing a tensor of dim [nb_target_boxes]\n        \"\"\"\n        return self.loss_cls(\n            src_class, tgt_class, torch.ones_like(tgt_class), avg_factor=avg_factor\n        ).mean()\n\n# --------------------------- HELPER FUNCTIONS ---------------------------\ndef mean(l, empty=0):\n    \"\"\"\n    nanmean compatible with generators.\n    \"\"\"\n    l = iter(l)\n    \n    try:\n        n = 1\n        acc = next(l)\n    except StopIteration:\n        if empty == 'raise':\n            raise ValueError('Empty mean')\n        return empty\n    \n    for n, v in enumerate(l, 2):\n        acc += v\n    \n    if n == 1:\n        return acc\n    \n    return acc / n"
  },
  {
    "path": "models/matcher.py",
    "content": "\"\"\"\nModified from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py\n\"\"\"\nimport torch\nimport torch.nn.functional as F\nfrom torch.cuda.amp import autocast\nfrom scipy.optimize import linear_sum_assignment\nfrom mmcv.runner import BaseModule\nfrom mmdet.core.bbox.match_costs import build_match_cost\n\n\ndef batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor, mask_camera: torch.Tensor):\n    \"\"\"\n    Compute the DICE loss, similar to generalized IOU for masks\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n    \"\"\"\n    if mask_camera is not None:\n        inputs = inputs[:, mask_camera]\n        targets = targets[:, mask_camera]\n    \n    inputs = inputs.sigmoid()\n    inputs = inputs.flatten(1)\n    numerator = 2 * torch.einsum(\"nc,mc->nm\", inputs, targets)\n    denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]\n    loss = 1 - (numerator + 1) / (denominator + 1)\n    return loss\n\n\nbatch_dice_loss_jit = torch.jit.script(\n    batch_dice_loss\n)  # type: torch.jit.ScriptModule\n\n\ndef batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor, mask_camera: torch.Tensor):\n    \"\"\"\n    Args:\n        inputs: A float tensor of arbitrary shape.\n                The predictions for each example.\n        targets: A float tensor with the same shape as inputs. 
Stores the binary\n                 classification label for each element in inputs\n                (0 for the negative class and 1 for the positive class).\n    Returns:\n        Loss tensor\n    \"\"\"\n    hw = inputs.shape[1]\n    \n    if mask_camera is not None:\n        mask_camera = mask_camera.to(torch.int32)\n        mask_camera = mask_camera[None].expand(inputs.shape[0], mask_camera.shape[-1])\n        \n        pos = F.binary_cross_entropy_with_logits(\n            inputs, torch.ones_like(inputs), mask_camera, reduction=\"none\"\n        )\n        neg = F.binary_cross_entropy_with_logits(\n            inputs, torch.zeros_like(inputs), mask_camera, reduction=\"none\"\n        )\n    else:\n        pos = F.binary_cross_entropy_with_logits(\n            inputs, torch.ones_like(inputs), reduction=\"none\"\n        )\n        neg = F.binary_cross_entropy_with_logits(\n            inputs, torch.zeros_like(inputs), reduction=\"none\"\n        )\n\n\n    loss = torch.einsum(\"nc,mc->nm\", pos, targets) + torch.einsum(\n        \"nc,mc->nm\", neg, (1 - targets)\n    )\n\n    return loss / hw\n\n\nbatch_sigmoid_ce_loss_jit = torch.jit.script(\n    batch_sigmoid_ce_loss\n)  # type: torch.jit.ScriptModule\n\n\n# modified from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py#L70\nclass HungarianMatcher(BaseModule):\n    \"\"\"This class computes an assignment between the targets and the predictions of the network\n\n    For efficiency reasons, the targets don't include the no_object. Because of this, in general,\n    there are more predictions than targets. 
In this case, we do a 1-to-1 matching of the best predictions,\n    while the others are un-matched (and thus treated as non-objects).\n    \"\"\"\n\n    def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1):\n        \"\"\"Creates the matcher\n\n        Params:\n            cost_class: This is the relative weight of the classification error in the matching cost\n            cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost\n            cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost\n        \"\"\"\n        super().__init__()\n        self.cost_class = cost_class\n        self.cost_mask = cost_mask\n        self.cost_dice = cost_dice\n\n        self.loss_focal = build_match_cost(dict(type='FocalLossCost', weight=2.0))\n\n        assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, \"all costs cant be 0\"\n\n    @torch.no_grad()\n    def forward(self, mask_pred, class_pred, mask_gt, class_gt, mask_camera):\n        \"\"\"\n        Args:\n            mask_pred: [bs, num_query, num_voxel (65536)]\n            class_pred: [bs, num_query, 17]\n            mask_gt: [bs, num_voxel], value in range [0, num_obj - 1]\n            class_gt: [[bs0_num_obj], [bs1_num_obj], ...], value in range [0, num_cls - 1]\n        \"\"\"\n        bs, num_queries = class_pred.shape[:2]\n\n        indices = []\n\n        # Iterate through batch size\n        for b in range(bs):\n            mask_camera_b = mask_camera[b] if mask_camera is not None else None\n            tgt_ids = class_gt[b]\n            num_instances = tgt_ids.shape[0]  # must be here, cause num of instances may change after masking\n\n            # Compute the classification cost. 
Contrary to the loss, we don't use the NLL,\n            # but approximate it in 1 - proba[target class].\n            # The 1 is a constant that doesn't change the matching, it can be omitted.\n            '''out_prob = class_pred[b].softmax(-1)  # [num_queries, num_classes]\n            cost_class = -out_prob[:, tgt_ids.long()].squeeze(1)'''\n\n            # Compute the classification cost. We use focal loss provided by mmdet as sparsebev does\n            out_prob = class_pred[b]  # TODO\n            cost_class = self.loss_focal(out_prob, tgt_ids.long())\n\n            out_mask = mask_pred[b]  # [num_queries, H_pred, W_pred]\n            # gt masks are already padded when preparing target\n            tgt_mask = mask_gt[b]\n            \n            tgt_mask = (tgt_mask.unsqueeze(-1) == torch.arange(num_instances).to(mask_gt.device))\n            tgt_mask = tgt_mask.permute(1, 0) # [Q, N]\n\n            # all masks share the same set of points for efficient matching!\n            tgt_mask = tgt_mask.view(tgt_mask.shape[0], -1)\n            out_mask = out_mask.view(out_mask.shape[0], -1)\n\n            with autocast(enabled=False):\n                out_mask = out_mask.float()\n                tgt_mask = tgt_mask.float()\n                # Compute the focal loss between masks\n                cost_mask = batch_sigmoid_ce_loss(out_mask, tgt_mask, mask_camera_b)\n\n                # Compute the dice loss between masks\n                cost_dice = batch_dice_loss(out_mask, tgt_mask, mask_camera_b)\n            \n            # Final cost matrix\n            C = (\n                self.cost_mask * cost_mask\n                + self.cost_class * cost_class\n                + self.cost_dice * cost_dice\n            )\n            C = C.reshape(num_queries, -1).cpu()\n            \n            indices.append(linear_sum_assignment(C))\n            \n        return [\n            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))\n            for 
i, j in indices\n        ]\n"
  },
  {
    "path": "models/sparse_voxel_decoder.py",
    "content": "import torch\nimport torch.nn as nn\nfrom mmcv.runner import BaseModule\nfrom mmcv.cnn.bricks.transformer import FFN\nfrom .sparsebev_transformer import SparseBEVSelfAttention, SparseBEVSampling, AdaptiveMixing\nfrom .utils import DUMP, generate_grid, batch_indexing\nfrom .bbox.utils import encode_bbox\nimport torch.nn.functional as F\n\n\ndef index2point(coords, pc_range, voxel_size):\n    \"\"\"\n    coords: [B, N, 3], int\n    pc_range: [-40, -40, -1.0, 40, 40, 5.4]\n    voxel_size: float\n    \"\"\"\n    coords = coords * voxel_size\n    coords = coords + torch.tensor(pc_range[:3], device=coords.device)\n    return coords\n\n\ndef point2bbox(coords, box_size):\n    \"\"\"\n    coords: [B, N, 3], float\n    box_size: float\n    \"\"\"\n    wlh = torch.ones_like(coords.float()) * box_size\n    bboxes = torch.cat([coords, wlh], dim=-1)  # [B, N, 6]\n    return bboxes\n\n\ndef upsample(pre_feat, pre_coords, interval):\n    '''\n    :param pre_feat: (Tensor), features from last level, (B, N, C)\n    :param pre_coords: (Tensor), coordinates from last level, (B, N, 3) (3: x, y, z)\n    :param interval: interval of voxels, interval = scale ** 2\n    :param num: 1 -> 8\n    :return: up_feat : upsampled features, (B, N*8, C//8)\n    :return: up_coords: upsampled coordinates, (B, N*8, 3)\n    '''\n    pos_list = [0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2]]\n    bs, num_query, num_channels = pre_feat.shape\n    \n    up_feat = pre_feat.reshape(bs, num_query, 8, num_channels // 8)  # [B, N, 8, C/8]\n    up_coords = pre_coords.unsqueeze(2).repeat(1, 1, 8, 1).contiguous()  # [B, N, 8, 3]\n    for i in range(len(pos_list)):\n        up_coords[:, :, i + 1, pos_list[i]] += interval\n\n    up_feat = up_feat.reshape(bs, -1, num_channels // 8)\n    up_coords = up_coords.reshape(bs, -1, 3)\n\n    return up_feat, up_coords\n\n\nclass SparseVoxelDecoder(BaseModule):\n    def __init__(self,\n                 embed_dims=None,\n                 num_layers=None,\n          
       num_frames=None,\n                 num_points=None,\n                 num_groups=None,\n                 num_levels=None,\n                 num_classes=None,\n                 semantic=False,\n                 topk_training=None,\n                 topk_testing=None,\n                 pc_range=None):\n        super().__init__()\n\n        self.embed_dims = embed_dims\n        self.num_frames = num_frames\n        self.num_layers = num_layers\n        self.pc_range = pc_range\n        self.semantic = semantic\n        self.voxel_dim = [200, 200, 16]\n        self.topk_training = topk_training\n        self.topk_testing = topk_testing\n\n        self.decoder_layers = nn.ModuleList()\n        self.lift_feat_heads = nn.ModuleList()\n        #self.occ_pred_heads = nn.ModuleList()\n        \n        if semantic:\n            self.seg_pred_heads = nn.ModuleList()\n\n        for i in range(num_layers):\n            self.decoder_layers.append(SparseVoxelDecoderLayer(\n                 embed_dims=embed_dims,\n                 num_frames=num_frames,\n                 num_points=num_points // (2 ** i),\n                 num_groups=num_groups,\n                 num_levels=num_levels,\n                 pc_range=pc_range,\n                 self_attn=i in [0, 1]\n            ))\n            self.lift_feat_heads.append(nn.Sequential(\n                nn.Linear(embed_dims, embed_dims * 8),\n                nn.ReLU(inplace=True)\n            ))\n            #self.occ_pred_heads.append(nn.Linear(embed_dims, 1))\n\n            if semantic:\n                self.seg_pred_heads.append(nn.Linear(embed_dims, num_classes))\n\n    @torch.no_grad()\n    def init_weights(self):\n        for i in range(len(self.decoder_layers)):\n            self.decoder_layers[i].init_weights()\n\n    def forward(self, mlvl_feats, img_metas):\n        occ_preds = []\n        \n        topk = self.topk_training if self.training else self.topk_testing\n        \n        B = len(img_metas)\n        # init 
query coords\n        interval = 2 ** self.num_layers\n        query_coord = generate_grid(self.voxel_dim, interval).expand(B, -1, -1)  # [B, N, 3]\n        query_feat = torch.zeros([B, query_coord.shape[1], self.embed_dims], device=query_coord.device)  # [B, N, C]\n\n        for i, layer in enumerate(self.decoder_layers):\n            DUMP.stage_count = i\n            \n            interval = 2 ** (self.num_layers - i)  # 8 4 2 1\n\n            # bbox from coords\n            query_bbox = index2point(query_coord, self.pc_range, voxel_size=0.4)  # [B, N, 3]\n            query_bbox = point2bbox(query_bbox, box_size=0.4 * interval)  # [B, N, 6]\n            query_bbox = encode_bbox(query_bbox, pc_range=self.pc_range)  # [B, N, 6]\n\n            # transformer layer\n            query_feat = layer(query_feat, query_bbox, mlvl_feats, img_metas)  # [B, N, C]\n            \n            # upsample 2x\n            query_feat = self.lift_feat_heads[i](query_feat)  # [B, N, 8C]\n            query_feat_2x, query_coord_2x = upsample(query_feat, query_coord, interval // 2)\n\n            if self.semantic:\n                seg_pred_2x = self.seg_pred_heads[i](query_feat_2x)  # [B, K, CLS]\n            else:\n                seg_pred_2x = None\n\n            # sparsify after seg_pred\n            non_free_prob = 1 - F.softmax(seg_pred_2x, dim=-1)[..., -1]  # [B, K]\n            indices = torch.topk(non_free_prob, k=topk[i], dim=1)[1]  # [B, K]\n\n            query_coord_2x = batch_indexing(query_coord_2x, indices, layout='channel_last')  # [B, K, 3]\n            query_feat_2x = batch_indexing(query_feat_2x, indices, layout='channel_last')  # [B, K, C]\n            seg_pred_2x = batch_indexing(seg_pred_2x, indices, layout='channel_last')  # [B, K, CLS]\n\n            occ_preds.append((\n                torch.div(query_coord_2x, interval // 2, rounding_mode='trunc').long(),\n                None,\n                seg_pred_2x,\n                query_feat_2x,\n                interval 
// 2)\n            )\n\n            query_coord = query_coord_2x.detach()\n            query_feat = query_feat_2x.detach()\n\n        return occ_preds\n\n\nclass SparseVoxelDecoderLayer(BaseModule):\n    def __init__(self,\n                 embed_dims=None,\n                 num_frames=None,\n                 num_points=None,\n                 num_groups=None,\n                 num_levels=None,\n                 pc_range=None,\n                 self_attn=True):\n        super().__init__()\n\n        self.position_encoder = nn.Sequential(\n            nn.Linear(3, embed_dims), \n            nn.LayerNorm(embed_dims),\n            nn.ReLU(inplace=True),\n            nn.Linear(embed_dims, embed_dims),\n            nn.LayerNorm(embed_dims),\n            nn.ReLU(inplace=True),\n        )\n\n        if self_attn:\n            self.self_attn = SparseBEVSelfAttention(embed_dims, num_heads=8, dropout=0.1, pc_range=pc_range, scale_adaptive=True)\n            self.norm1 = nn.LayerNorm(embed_dims)\n        else:\n            self.self_attn = None\n        \n        self.sampling = SparseBEVSampling(\n            embed_dims=embed_dims,\n            num_frames=num_frames,\n            num_groups=num_groups,\n            num_points=num_points,\n            num_levels=num_levels,\n            pc_range=pc_range\n        )\n        self.mixing = AdaptiveMixing(\n            in_dim=embed_dims,\n            in_points=num_points * num_frames,\n            n_groups=num_groups,\n            out_points=num_points * num_frames * num_groups\n        )\n        self.ffn = FFN(embed_dims, feedforward_channels=embed_dims * 2, ffn_drop=0.1)\n        \n        self.norm2 = nn.LayerNorm(embed_dims)\n        self.norm3 = nn.LayerNorm(embed_dims)\n\n    @torch.no_grad()\n    def init_weights(self):\n        if self.self_attn is not None:\n            self.self_attn.init_weights()\n        self.sampling.init_weights()\n        self.mixing.init_weights()\n        self.ffn.init_weights()\n\n    def 
forward(self, query_feat, query_bbox, mlvl_feats, img_metas):\n        query_pos = self.position_encoder(query_bbox[..., :3])\n        query_feat = query_feat + query_pos\n\n        if self.self_attn is not None:\n            query_feat = self.norm1(self.self_attn(query_bbox, query_feat))\n        sampled_feat = self.sampling(query_bbox, query_feat, mlvl_feats, img_metas)\n        query_feat = self.norm2(self.mixing(sampled_feat, query_feat))\n        query_feat = self.norm3(self.ffn(query_feat))\n\n        return query_feat\n"
  },
  {
    "path": "models/sparsebev_head.py",
    "content": "import math\nimport torch\nimport torch.nn as nn\nfrom mmcv.runner import force_fp32\nfrom mmdet.core import multi_apply, reduce_mean\nfrom mmdet.models import HEADS\nfrom mmdet.models.dense_heads import DETRHead\nfrom mmdet3d.core.bbox.coders import build_bbox_coder\nfrom mmdet3d.core.bbox.structures.lidar_box3d import LiDARInstance3DBoxes\nfrom .bbox.utils import normalize_bbox, encode_bbox\n\n\n@HEADS.register_module()\nclass SparseBEVHead(DETRHead):\n    def __init__(self,\n                 *args,\n                 num_classes,\n                 in_channels,\n                 query_denoising=True,\n                 query_denoising_groups=10,\n                 bbox_coder=None,\n                 code_size=10,\n                 code_weights=[1.0] * 10,\n                 train_cfg=dict(),\n                 test_cfg=dict(max_per_img=100),\n                 **kwargs):\n        self.code_size = code_size\n        self.code_weights = code_weights\n        self.num_classes = num_classes\n        self.in_channels = in_channels\n        self.train_cfg = train_cfg\n        self.test_cfg = test_cfg\n        self.fp16_enabled = False\n        self.embed_dims = in_channels\n\n        super(SparseBEVHead, self).__init__(num_classes, in_channels, train_cfg=train_cfg, test_cfg=test_cfg, **kwargs)\n\n        self.code_weights = nn.Parameter(torch.tensor(self.code_weights), requires_grad=False)\n        self.bbox_coder = build_bbox_coder(bbox_coder)\n        self.pc_range = self.bbox_coder.pc_range\n\n        self.dn_enabled = query_denoising\n        self.dn_group_num = query_denoising_groups\n        self.dn_weight = 1.0\n        self.dn_bbox_noise_scale = 0.5\n        self.dn_label_noise_scale = 0.5\n\n    def _init_layers(self):\n        self.init_query_bbox = nn.Embedding(self.num_query, 10)  # (x, y, z, w, l, h, sin, cos, vx, vy)\n        self.label_enc = nn.Embedding(self.num_classes + 1, self.embed_dims - 1)  # DAB-DETR\n\n        
nn.init.zeros_(self.init_query_bbox.weight[:, 2:3])\n        nn.init.zeros_(self.init_query_bbox.weight[:, 8:10])\n        nn.init.constant_(self.init_query_bbox.weight[:, 5:6], 1.5)\n\n        grid_size = int(math.sqrt(self.num_query))\n        assert grid_size * grid_size == self.num_query\n        x = y = torch.arange(grid_size)\n        xx, yy = torch.meshgrid(x, y, indexing='ij')  # [0, grid_size - 1]\n        xy = torch.cat([xx[..., None], yy[..., None]], dim=-1)\n        xy = (xy + 0.5) / grid_size  # [0.5, grid_size - 0.5] / grid_size ~= (0, 1)\n        with torch.no_grad():\n            self.init_query_bbox.weight[:, :2] = xy.reshape(-1, 2)  # [Q, 2]\n\n    def init_weights(self):\n        self.transformer.init_weights()\n\n    def forward(self, mlvl_feats, img_metas):\n        query_bbox = self.init_query_bbox.weight.clone()  # [Q, 10]\n        #query_bbox[..., :3] = query_bbox[..., :3].sigmoid()\n\n        B = mlvl_feats[0].shape[0]\n        query_bbox, query_feat, attn_mask, mask_dict = self.prepare_for_dn_input(B, query_bbox, self.label_enc, img_metas)\n\n        cls_scores, bbox_preds = self.transformer(\n            query_bbox,\n            query_feat,\n            mlvl_feats,\n            attn_mask=attn_mask,\n            img_metas=img_metas,\n        )\n\n        bbox_preds[..., 0] = bbox_preds[..., 0] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]\n        bbox_preds[..., 1] = bbox_preds[..., 1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]\n        bbox_preds[..., 2] = bbox_preds[..., 2] * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2]\n\n        bbox_preds = torch.cat([\n            bbox_preds[..., 0:2],\n            bbox_preds[..., 3:5],\n            bbox_preds[..., 2:3],\n            bbox_preds[..., 5:10],\n        ], dim=-1)  # [cx, cy, w, l, cz, h, sin, cos, vx, vy]\n\n        if mask_dict is not None and mask_dict['pad_size'] > 0:\n            output_known_cls_scores = cls_scores[:, :, 
:mask_dict['pad_size'], :]\n            output_known_bbox_preds = bbox_preds[:, :, :mask_dict['pad_size'], :]\n            output_cls_scores = cls_scores[:, :, mask_dict['pad_size']:, :]\n            output_bbox_preds = bbox_preds[:, :, mask_dict['pad_size']:, :]\n            mask_dict['output_known_lbs_bboxes'] = (output_known_cls_scores, output_known_bbox_preds)\n            outs = {\n                'all_cls_scores': output_cls_scores,\n                'all_bbox_preds': output_bbox_preds,\n                'enc_cls_scores': None,\n                'enc_bbox_preds': None, \n                'dn_mask_dict': mask_dict,\n            }\n        else:\n            outs = {\n                'all_cls_scores': cls_scores,\n                'all_bbox_preds': bbox_preds,\n                'enc_cls_scores': None,\n                'enc_bbox_preds': None, \n            }\n\n        return outs\n\n    def prepare_for_dn_input(self, batch_size, init_query_bbox, label_enc, img_metas):\n        device = init_query_bbox.device\n        indicator0 = torch.zeros([self.num_query, 1], device=device)\n        init_query_feat = label_enc.weight[self.num_classes].repeat(self.num_query, 1)\n        init_query_feat = torch.cat([init_query_feat, indicator0], dim=1)\n\n        if self.training and self.dn_enabled:\n            targets = [{\n                'bboxes': torch.cat([m['gt_bboxes_3d'].gravity_center,\n                                     m['gt_bboxes_3d'].tensor[:, 3:]], dim=1).cuda(),\n                'labels': m['gt_labels_3d'].cuda().long()\n            } for m in img_metas]\n\n            known = [torch.ones_like(t['labels'], device=device) for t in targets]\n            known_num = [sum(k) for k in known]\n\n            # can be modified to selectively denosie some label or boxes; also known label prediction\n            unmask_bbox = unmask_label = torch.cat(known)\n            labels = torch.cat([t['labels'] for t in targets]).clone()\n            bboxes = torch.cat([t['bboxes'] 
for t in targets]).clone()\n            batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])\n\n            known_indice = torch.nonzero(unmask_label + unmask_bbox)\n            known_indice = known_indice.view(-1)\n\n            # add noise\n            known_indice = known_indice.repeat(self.dn_group_num, 1).view(-1)\n            known_labels = labels.repeat(self.dn_group_num, 1).view(-1)\n            known_bid = batch_idx.repeat(self.dn_group_num, 1).view(-1)\n            known_bboxs = bboxes.repeat(self.dn_group_num, 1) # 9\n            known_labels_expand = known_labels.clone()\n            known_bbox_expand = known_bboxs.clone()\n\n            # noise on the box\n            if self.dn_bbox_noise_scale > 0:\n                wlh = known_bbox_expand[..., 3:6].clone()\n                rand_prob = torch.rand_like(known_bbox_expand) * 2 - 1.0\n                known_bbox_expand[..., 0:3] += torch.mul(rand_prob[..., 0:3], wlh / 2) * self.dn_bbox_noise_scale\n                # known_bbox_expand[..., 3:6] += torch.mul(rand_prob[..., 3:6], wlh) * self.dn_bbox_noise_scale\n                # known_bbox_expand[..., 6:7] += torch.mul(rand_prob[..., 6:7], 3.14159) * self.dn_bbox_noise_scale\n\n            known_bbox_expand = encode_bbox(known_bbox_expand, self.pc_range)\n            known_bbox_expand[..., 0:3].clamp_(min=0.0, max=1.0)\n            # nn.init.constant(known_bbox_expand[..., 8:10], 0.0)\n\n            # noise on the label\n            if self.dn_label_noise_scale > 0:\n                p = torch.rand_like(known_labels_expand.float())\n                chosen_indice = torch.nonzero(p < self.dn_label_noise_scale).view(-1)  # usually half of bbox noise\n                new_label = torch.randint_like(chosen_indice, 0, self.num_classes)  # randomly put a new one here\n                known_labels_expand.scatter_(0, chosen_indice, new_label)\n\n            known_feat_expand = label_enc(known_labels_expand)\n            
indicator1 = torch.ones([known_feat_expand.shape[0], 1], device=device)  # add dn part indicator\n            known_feat_expand = torch.cat([known_feat_expand, indicator1], dim=1)\n\n            # construct final query\n            dn_single_pad = int(max(known_num))\n            dn_pad_size = int(dn_single_pad * self.dn_group_num)\n            dn_query_bbox = torch.zeros([dn_pad_size, init_query_bbox.shape[-1]], device=device)\n            dn_query_feat = torch.zeros([dn_pad_size, self.embed_dims], device=device)\n            input_query_bbox = torch.cat([dn_query_bbox, init_query_bbox], dim=0).repeat(batch_size, 1, 1)\n            input_query_feat = torch.cat([dn_query_feat, init_query_feat], dim=0).repeat(batch_size, 1, 1)\n\n            if len(known_num):\n                map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]\n                map_known_indice = torch.cat([map_known_indice + dn_single_pad * i for i in range(self.dn_group_num)]).long()\n\n            if len(known_bid):\n                input_query_bbox[known_bid.long(), map_known_indice] = known_bbox_expand\n                input_query_feat[(known_bid.long(), map_known_indice)] = known_feat_expand\n\n            total_size = dn_pad_size + self.num_query\n            attn_mask = torch.ones([total_size, total_size], device=device) < 0\n\n            # match query cannot see the reconstruct\n            attn_mask[dn_pad_size:, :dn_pad_size] = True\n            for i in range(self.dn_group_num):\n                if i == 0:\n                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), dn_single_pad * (i + 1):dn_pad_size] = True\n                if i == self.dn_group_num - 1:\n                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), :dn_single_pad * i] = True\n                else:\n                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), dn_single_pad * (i + 1):dn_pad_size] = True\n                    
attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), :dn_single_pad * i] = True\n\n            mask_dict = {\n                'known_indice': torch.as_tensor(known_indice).long(),\n                'batch_idx': torch.as_tensor(batch_idx).long(),\n                'map_known_indice': torch.as_tensor(map_known_indice).long(),\n                'known_lbs_bboxes': (known_labels, known_bboxs),\n                'pad_size': dn_pad_size\n            }\n        else:\n            input_query_bbox = init_query_bbox.repeat(batch_size, 1, 1)\n            input_query_feat = init_query_feat.repeat(batch_size, 1, 1)\n            attn_mask = None\n            mask_dict = None\n\n        return input_query_bbox, input_query_feat, attn_mask, mask_dict\n\n    def prepare_for_dn_loss(self, mask_dict):\n        cls_scores, bbox_preds = mask_dict['output_known_lbs_bboxes']\n        known_labels, known_bboxs = mask_dict['known_lbs_bboxes']\n        map_known_indice = mask_dict['map_known_indice'].long()\n        known_indice = mask_dict['known_indice'].long()\n        batch_idx = mask_dict['batch_idx'].long()\n        bid = batch_idx[known_indice]\n        num_tgt = known_indice.numel()\n\n        if len(cls_scores) > 0:\n            cls_scores = cls_scores.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)\n            bbox_preds = bbox_preds.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)\n\n        return known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt\n\n    def dn_loss_single(self,\n                       cls_scores,\n                       bbox_preds,\n                       known_bboxs,\n                       known_labels,\n                       num_total_pos=None):        \n        # Compute the average number of gt boxes accross all gpus\n        num_total_pos = cls_scores.new_tensor([num_total_pos])\n        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1.0).item()\n\n        # cls loss\n        cls_scores = 
cls_scores.reshape(-1, self.cls_out_channels)\n        bbox_weights = torch.ones_like(bbox_preds)\n        label_weights = torch.ones_like(known_labels)\n        loss_cls = self.loss_cls(\n            cls_scores,\n            known_labels.long(),\n            label_weights,\n            avg_factor=num_total_pos\n        )\n\n        # regression L1 loss\n        bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))\n        normalized_bbox_targets = normalize_bbox(known_bboxs)\n        isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)\n        bbox_weights = bbox_weights * self.code_weights\n        loss_bbox = self.loss_bbox(\n            bbox_preds[isnotnan, :10],\n            normalized_bbox_targets[isnotnan, :10],\n            bbox_weights[isnotnan, :10],\n            avg_factor=num_total_pos\n        )\n\n        loss_cls = self.dn_weight * torch.nan_to_num(loss_cls)\n        loss_bbox = self.dn_weight * torch.nan_to_num(loss_bbox)\n\n        return loss_cls, loss_bbox\n\n    @force_fp32(apply_to=('preds_dicts'))\n    def calc_dn_loss(self, loss_dict, preds_dicts, num_dec_layers):\n        known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt = \\\n            self.prepare_for_dn_loss(preds_dicts['dn_mask_dict'])\n\n        all_known_bboxs_list = [known_bboxs for _ in range(num_dec_layers)]\n        all_known_labels_list = [known_labels for _ in range(num_dec_layers)]\n        all_num_tgts_list = [num_tgt for _ in range(num_dec_layers)]\n\n        dn_losses_cls, dn_losses_bbox = multi_apply(\n            self.dn_loss_single, cls_scores, bbox_preds,\n            all_known_bboxs_list, all_known_labels_list, all_num_tgts_list)\n\n        loss_dict['loss_cls_dn'] = dn_losses_cls[-1]\n        loss_dict['loss_bbox_dn'] = dn_losses_bbox[-1]\n\n        num_dec_layer = 0\n        for loss_cls_i, loss_bbox_i in zip(dn_losses_cls[:-1], dn_losses_bbox[:-1]):\n            loss_dict[f'd{num_dec_layer}.loss_cls_dn'] = loss_cls_i\n            
loss_dict[f'd{num_dec_layer}.loss_bbox_dn'] = loss_bbox_i\n            num_dec_layer += 1\n\n        return loss_dict\n\n    def _get_target_single(self,\n                           cls_score,\n                           bbox_pred,\n                           gt_labels,\n                           gt_bboxes,\n                           gt_bboxes_ignore=None):\n        num_bboxes = bbox_pred.size(0)\n\n        # assigner and sampler\n        assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes, gt_labels, gt_bboxes_ignore, self.code_weights, True)\n        sampling_result = self.sampler.sample(assign_result, bbox_pred, gt_bboxes)\n        pos_inds = sampling_result.pos_inds\n        neg_inds = sampling_result.neg_inds\n\n        # label targets\n        labels = gt_bboxes.new_full((num_bboxes, ), self.num_classes, dtype=torch.long)\n        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]\n        label_weights = gt_bboxes.new_ones(num_bboxes)\n\n        # bbox targets\n        bbox_targets = torch.zeros_like(bbox_pred)[..., :9]\n        bbox_weights = torch.zeros_like(bbox_pred)\n        bbox_weights[pos_inds] = 1.0\n        \n        # DETR\n        bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes\n        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds)\n\n    def get_targets(self,\n                    cls_scores_list,\n                    bbox_preds_list,\n                    gt_bboxes_list,\n                    gt_labels_list,\n                    gt_bboxes_ignore_list=None):\n        assert gt_bboxes_ignore_list is None, \\\n            'Only supports for gt_bboxes_ignore setting to None.'\n        num_imgs = len(cls_scores_list)\n        gt_bboxes_ignore_list = [gt_bboxes_ignore_list for _ in range(num_imgs)]\n\n        (labels_list, label_weights_list, bbox_targets_list,\n         bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply(\n                self._get_target_single, 
cls_scores_list, bbox_preds_list,\n             gt_labels_list, gt_bboxes_list, gt_bboxes_ignore_list)\n        num_total_pos = sum((inds.numel() for inds in pos_inds_list))\n        num_total_neg = sum((inds.numel() for inds in neg_inds_list))\n        return (labels_list, label_weights_list, bbox_targets_list,\n                bbox_weights_list, num_total_pos, num_total_neg)\n\n    def loss_single(self,\n                    cls_scores,\n                    bbox_preds,\n                    gt_bboxes_list,\n                    gt_labels_list,\n                    gt_bboxes_ignore_list=None):\n        num_imgs = cls_scores.size(0)\n        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]\n        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]\n        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,\n                gt_bboxes_list, gt_labels_list, gt_bboxes_ignore_list)\n        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,\n         num_total_pos, num_total_neg) = cls_reg_targets\n\n        labels = torch.cat(labels_list, 0)\n        label_weights = torch.cat(label_weights_list, 0)\n        bbox_targets = torch.cat(bbox_targets_list, 0)\n        bbox_weights = torch.cat(bbox_weights_list, 0)\n\n        # classification loss\n        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)\n        # construct weighted avg_factor to match with the official DETR repo\n        cls_avg_factor = num_total_pos * 1.0 + \\\n            num_total_neg * self.bg_cls_weight\n        if self.sync_cls_avg_factor:\n            cls_avg_factor = reduce_mean(\n                cls_scores.new_tensor([cls_avg_factor]))\n\n        cls_avg_factor = max(cls_avg_factor, 1)\n        loss_cls = self.loss_cls(\n            cls_scores, labels, label_weights, avg_factor=cls_avg_factor)\n\n        # Compute the average number of gt boxes accross all gpus, for\n        # normalization purposes\n        num_total_pos = 
loss_cls.new_tensor([num_total_pos])\n        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()\n\n        # regression L1 loss\n        bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))\n        normalized_bbox_targets = normalize_bbox(bbox_targets)\n        isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)\n        bbox_weights = bbox_weights * self.code_weights\n\n        loss_bbox = self.loss_bbox(\n            bbox_preds[isnotnan, :10],\n            normalized_bbox_targets[isnotnan, :10],\n            bbox_weights[isnotnan, :10],\n            avg_factor=num_total_pos\n        )\n\n        loss_cls = torch.nan_to_num(loss_cls)\n        loss_bbox = torch.nan_to_num(loss_bbox)\n        \n        return loss_cls, loss_bbox\n\n    @force_fp32(apply_to=('preds_dicts'))\n    def loss(self,\n             gt_bboxes_list,\n             gt_labels_list,\n             preds_dicts,\n             gt_bboxes_ignore=None):\n        assert gt_bboxes_ignore is None, \\\n            f'{self.__class__.__name__} only supports ' \\\n            f'for gt_bboxes_ignore setting to None.'\n\n        all_cls_scores = preds_dicts['all_cls_scores']\n        all_bbox_preds = preds_dicts['all_bbox_preds']\n        enc_cls_scores = preds_dicts['enc_cls_scores']\n        enc_bbox_preds = preds_dicts['enc_bbox_preds']\n\n        num_dec_layers = len(all_cls_scores)\n        device = gt_labels_list[0].device\n        gt_bboxes_list = [torch.cat(\n            (gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]),\n            dim=1).to(device) for gt_bboxes in gt_bboxes_list]\n\n        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]\n        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]\n        all_gt_bboxes_ignore_list = [gt_bboxes_ignore for _ in range(num_dec_layers)]\n\n        losses_cls, losses_bbox = multi_apply(\n            self.loss_single, all_cls_scores, all_bbox_preds,\n            all_gt_bboxes_list, 
all_gt_labels_list, \n            all_gt_bboxes_ignore_list)\n\n        loss_dict = dict()\n        # loss of proposal generated from encode feature map\n        if enc_cls_scores is not None:\n            binary_labels_list = [\n                torch.zeros_like(gt_labels_list[i])\n                for i in range(len(all_gt_labels_list))\n            ]\n            enc_loss_cls, enc_losses_bbox = \\\n                self.loss_single(enc_cls_scores, enc_bbox_preds,\n                                 gt_bboxes_list, binary_labels_list, gt_bboxes_ignore)\n            loss_dict['enc_loss_cls'] = enc_loss_cls\n            loss_dict['enc_loss_bbox'] = enc_losses_bbox\n\n        if 'dn_mask_dict' in preds_dicts and preds_dicts['dn_mask_dict'] is not None:\n            loss_dict = self.calc_dn_loss(loss_dict, preds_dicts, num_dec_layers)\n\n        # loss from the last decoder layer\n        loss_dict['loss_cls'] = losses_cls[-1]\n        loss_dict['loss_bbox'] = losses_bbox[-1]\n\n        # loss from other decoder layers\n        num_dec_layer = 0\n        for loss_cls_i, loss_bbox_i in zip(losses_cls[:-1], losses_bbox[:-1]):\n            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i\n            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i\n            num_dec_layer += 1\n        return loss_dict\n\n    @force_fp32(apply_to=('preds_dicts'))\n    def get_bboxes(self, preds_dicts, img_metas, rescale=False):\n        preds_dicts = self.bbox_coder.decode(preds_dicts)\n        num_samples = len(preds_dicts)\n        ret_list = []\n        for i in range(num_samples):\n            preds = preds_dicts[i]\n            bboxes = preds['bboxes']\n            bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5\n            bboxes = LiDARInstance3DBoxes(bboxes, 9)\n            scores = preds['scores']\n            labels = preds['labels']\n            ret_list.append([bboxes, scores, labels])\n        return ret_list\n"
  },
  {
    "path": "models/sparsebev_sampling.py",
    "content": "import torch\nfrom .bbox.utils import decode_bbox\nfrom .utils import rotation_3d_in_axis, DUMP\nfrom .csrc.wrapper import msmv_sampling\n\n\ndef make_sample_points_from_bbox(query_bbox, offset, pc_range):\n    '''\n    query_bbox: [B, Q, 10]\n    offset: [B, Q, num_points, 4], normalized by stride\n    '''\n    query_bbox = decode_bbox(query_bbox, pc_range)  # [B, Q, 9]\n\n    xyz = query_bbox[..., 0:3]  # [B, Q, 3]\n    wlh = query_bbox[..., 3:6]  # [B, Q, 3]\n\n    # NOTE: different from SparseBEV\n    xyz += wlh / 2  # conver to center\n    \n    delta_xyz = offset[..., 0:3]  # [B, Q, P, 3]\n    delta_xyz = wlh[:, :, None, :] * delta_xyz  # [B, Q, P, 3]\n\n    if query_bbox.shape[-1] > 6:\n        ang = query_bbox[..., 6:7]  # [B, Q, 1]\n        delta_xyz = rotation_3d_in_axis(delta_xyz, ang)  # [B, Q, P, 3]\n    \n    sample_xyz = xyz[:, :, None, :] + delta_xyz  # [B, Q, P, 3]\n\n    return sample_xyz  # [B, Q, P, 3]\n\n\ndef make_sample_points_from_mask(valid_map, pc_range, occ_size, num_points, occ_loc=None, offset=None):\n    '''\n    valid_map: [B, Q, W, H, D] or [B, Q, N]\n    occ_loc: [B, N, 3] if valid map is sparse\n    Return: [B, Q, GP, 3] in pc_range\n    '''\n    B, Q = valid_map.shape[:2]\n    occ_size = torch.tensor(occ_size).to(valid_map.device)\n    \n    sampling_pts = []\n    for b in range(B):\n        indices = torch.where(valid_map[b])\n        if indices[0].shape[0] == 0:\n            pts = torch.rand((Q, num_points, 3)).to(valid_map.device)\n        else:\n            if len(valid_map.shape) == 5:\n                bin_count = valid_map[b].sum(dim=(1,2,3))\n            else:\n                bin_count = valid_map[b].sum(dim=1)\n            sampling_rand = torch.rand((Q, num_points)).to(bin_count.device)\n            sampling_index = (sampling_rand * bin_count[:, None]).floor().long()\n            low_bound = torch.cumsum(bin_count, dim=0) - bin_count\n            sampling_index = sampling_index + low_bound[:, None]\n       
     sampling_index[sampling_index >= indices[0].shape[0]] = indices[0].shape[0] -1  # this can happen when zeros appear in the tail\n            sampling_index = sampling_index.to(valid_map.device)\n            \n            if occ_loc is None: # dense occ points\n                pts = torch.stack((indices[1][sampling_index], indices[2][sampling_index], indices[3][sampling_index]))\n                pts = pts.permute(1, 2, 0)\n            else:\n                occ_idx = indices[1][sampling_index]\n                pts = occ_loc[b][occ_idx]\n        \n            # pad queries with no valid occ\n            pts = pts.float()\n            rand_sampling_points = torch.rand(((bin_count==0).sum(), num_points, 3)).to(pts.device) * occ_size\n            pts[bin_count==0] = rand_sampling_points\n        sampling_pts.append(pts)\n        \n    sampling_pts = torch.stack(sampling_pts)\n    if offset is not None:\n        sampling_pts = sampling_pts + offset\n    \n    sampling_pts = sampling_pts / occ_size\n    sampling_pts[..., 0] = sampling_pts[..., 0] * (pc_range[3] - pc_range[0]) + pc_range[0]\n    sampling_pts[..., 1] = sampling_pts[..., 1] * (pc_range[4] - pc_range[1]) + pc_range[1]\n    sampling_pts[..., 2] = sampling_pts[..., 2] * (pc_range[5] - pc_range[2]) + pc_range[2]\n\n    return sampling_pts\n\n\ndef sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, image_w, eps=1e-5):\n    B, Q, T, G, P, _ = sample_points.shape  # [B, Q, T, G, P, 4]\n    N = 6\n\n    sample_points = sample_points.reshape(B, Q, T, G * P, 3)\n    \n    if DUMP.enabled:\n        torch.save(sample_points,\n                   '{}/sample_points_3d_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))\n\n    # get the projection matrix\n    lidar2img = lidar2img[:, :(T*N), None, None, :, :]  # [B, TN, 1, 1, 4, 4]\n    lidar2img = lidar2img.expand(B, T*N, Q, G * P, 4, 4)\n    lidar2img = lidar2img.reshape(B, T, N, Q, G*P, 4, 4)\n\n    # expand the points\n    ones = 
torch.ones_like(sample_points[..., :1])\n    sample_points = torch.cat([sample_points, ones], dim=-1)  # [B, Q, GP, 4]\n    sample_points = sample_points[:, :, None, ..., None]     # [B, Q, T, GP, 4]\n    sample_points = sample_points.expand(B, Q, N, T, G * P, 4, 1)\n    sample_points = sample_points.transpose(1, 3)   # [B, T, N, Q, GP, 4, 1]\n\n    # project 3d sampling points to image\n    sample_points_cam = torch.matmul(lidar2img, sample_points).squeeze(-1)  # [B, T, N, Q, GP, 4]\n\n    # homo coord -> pixel coord\n    homo = sample_points_cam[..., 2:3]\n    homo_nonzero = torch.maximum(homo, torch.zeros_like(homo) + eps)\n    sample_points_cam = sample_points_cam[..., 0:2] / homo_nonzero  # [B, T, N, Q, GP, 2]\n\n    # normalize\n    sample_points_cam[..., 0] /= image_w\n    sample_points_cam[..., 1] /= image_h\n\n    # check if out of image\n    valid_mask = ((homo > eps) \\\n        & (sample_points_cam[..., 1:2] > 0.0)\n        & (sample_points_cam[..., 1:2] < 1.0)\n        & (sample_points_cam[..., 0:1] > 0.0)\n        & (sample_points_cam[..., 0:1] < 1.0)\n    ).squeeze(-1).float()  # [B, T, N, Q, GP]\n\n    if DUMP.enabled:\n        torch.save(torch.cat([sample_points_cam, homo_nonzero], dim=-1),\n                   '{}/sample_points_cam_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))\n        torch.save(valid_mask,\n                   '{}/sample_points_cam_valid_mask_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))\n\n    valid_mask = valid_mask.permute(0, 1, 3, 4, 2)  # [B, T, Q, GP, N]\n    sample_points_cam = sample_points_cam.permute(0, 1, 3, 4, 2, 5)  # [B, T, Q, GP, N, 2]\n\n    i_batch = torch.arange(B, dtype=torch.long, device=sample_points.device)\n    i_query = torch.arange(Q, dtype=torch.long, device=sample_points.device)\n    i_time = torch.arange(T, dtype=torch.long, device=sample_points.device)\n    i_point = torch.arange(G * P, dtype=torch.long, device=sample_points.device)\n    i_batch = i_batch.view(B, 1, 1, 1, 1).expand(B, T, 
Q, G * P, 1)\n    i_time = i_time.view(1, T, 1, 1, 1).expand(B, T, Q, G * P, 1)\n    i_query = i_query.view(1, 1, Q, 1, 1).expand(B, T, Q, G * P, 1)\n    i_point = i_point.view(1, 1, 1, G * P, 1).expand(B, T, Q, G * P, 1)\n    i_view = torch.argmax(valid_mask, dim=-1)[..., None]  # [B, T, Q, GP, 1]\n\n    sample_points_cam = sample_points_cam[i_batch, i_time, i_query, i_point, i_view, :]  # [B, Q, GP, 1, 2]\n    valid_mask = valid_mask[i_batch, i_time, i_query, i_point, i_view]  # [B, Q, GP, 1]\n\n    sample_points_cam = torch.cat([sample_points_cam, i_view[..., None].float() / 5], dim=-1)\n    sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 1, 3)\n    sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5, 6)  # [B, T, G, Q, P, 1, 3]\n    sample_points_cam = sample_points_cam.reshape(B*T*G, Q, P, 3)\n\n    scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)\n    scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)\n    scale_weights = scale_weights.reshape(B*G*T, Q, P, -1)\n\n    final = msmv_sampling(mlvl_feats, sample_points_cam, scale_weights)\n    C = final.shape[2]  # [BTG, Q, C, P]\n    final = final.reshape(B, T, G, Q, C, P)\n    final = final.permute(0, 3, 2, 1, 5, 4)\n    final = final.flatten(3, 4)  # [B, Q, G, FP, C]\n\n    return final\n"
  },
  {
    "path": "models/sparsebev_transformer.py",
    "content": "import torch\nimport torch.nn as nn\nimport numpy as np\nimport torch.nn.functional as F\nfrom mmcv.runner import BaseModule\nfrom mmcv.cnn import bias_init_with_prob\nfrom mmcv.cnn.bricks.transformer import MultiheadAttention, FFN\nfrom mmdet.models.utils.builder import TRANSFORMER\nfrom .bbox.utils import decode_bbox\nfrom .utils import inverse_sigmoid, DUMP\nfrom .sparsebev_sampling import sampling_4d, make_sample_points_from_bbox\nfrom .checkpoint import checkpoint as cp\n\n\n@TRANSFORMER.register_module()\nclass SparseBEVTransformer(BaseModule):\n    def __init__(self, embed_dims, num_frames=8, num_points=4, num_layers=6, num_levels=4, num_classes=10, code_size=10, pc_range=[], init_cfg=None):\n        assert init_cfg is None, 'To prevent abnormal initialization ' \\\n                            'behavior, init_cfg is not allowed to be set'\n        super(SparseBEVTransformer, self).__init__(init_cfg=init_cfg)\n\n        self.embed_dims = embed_dims\n        self.pc_range = pc_range\n\n        self.decoder = SparseBEVTransformerDecoder(embed_dims, num_frames, num_points, num_layers, num_levels, num_classes, code_size, pc_range=pc_range)\n\n    @torch.no_grad()\n    def init_weights(self):\n        self.decoder.init_weights()\n\n    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):\n        cls_scores, bbox_preds = self.decoder(query_bbox, query_feat, mlvl_feats, attn_mask, img_metas)\n\n        cls_scores = torch.nan_to_num(cls_scores)\n        bbox_preds = torch.nan_to_num(bbox_preds)\n\n        return cls_scores, bbox_preds\n\n\nclass SparseBEVTransformerDecoder(BaseModule):\n    def __init__(self, embed_dims, num_frames=8, num_points=4, num_layers=6, num_levels=4, num_classes=10, code_size=10, pc_range=[], init_cfg=None):\n        super(SparseBEVTransformerDecoder, self).__init__(init_cfg)\n        self.num_layers = num_layers\n        self.pc_range = pc_range\n\n        self.decoder_layer = 
SparseBEVTransformerDecoderLayer(\n            embed_dims, num_frames, num_points, num_levels, num_classes, code_size, pc_range=pc_range\n        )\n\n    @torch.no_grad()\n    def init_weights(self):\n        self.decoder_layer.init_weights()\n\n    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):\n        cls_scores, bbox_preds = [], []\n\n        timestamps = np.array([m['img_timestamp'] for m in img_metas], dtype=np.float64)\n        timestamps = np.reshape(timestamps, [query_bbox.shape[0], -1, 6])\n        time_diff = timestamps[:, :1, :] - timestamps\n        time_diff = np.mean(time_diff, axis=-1).astype(np.float32)  # [B, F]\n        time_diff = torch.from_numpy(time_diff).to(query_bbox.device)  # [B, F]\n        img_metas[0]['time_diff'] = time_diff\n\n        lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)\n        lidar2img = torch.from_numpy(lidar2img).to(query_bbox.device)  # [B, N, 4, 4]\n        img_metas[0]['lidar2img'] = lidar2img\n\n        for lvl, feat in enumerate(mlvl_feats):\n            B, TN, GC, H, W = feat.shape  # [B, TN, GC, H, W]\n            N, T, G, C = 6, TN // 6, 4, GC // 4\n            feat = feat.reshape(B, T, N, G, C, H, W)\n            feat = feat.permute(0, 1, 3, 2, 5, 6, 4)  # [B, T, G, N, H, W, C]\n            feat = feat.reshape(B*T*G, N, H, W, C)  # [BTG, C, N, H, W]\n            mlvl_feats[lvl] = feat.contiguous()\n\n        for i in range(self.num_layers):\n            DUMP.stage_count = i\n\n            query_feat, cls_score, bbox_pred = self.decoder_layer(\n                query_bbox, query_feat, mlvl_feats, attn_mask, img_metas\n            )\n            query_bbox = bbox_pred.clone().detach()\n\n            cls_scores.append(cls_score)\n            bbox_preds.append(bbox_pred)\n\n        cls_scores = torch.stack(cls_scores)\n        bbox_preds = torch.stack(bbox_preds)\n\n        return cls_scores, bbox_preds\n\n\nclass 
SparseBEVTransformerDecoderLayer(BaseModule):\n    def __init__(self, embed_dims, num_frames=8, num_points=4, num_levels=4, num_classes=10, code_size=10, num_cls_fcs=2, num_reg_fcs=2, pc_range=[], init_cfg=None):\n        super(SparseBEVTransformerDecoderLayer, self).__init__(init_cfg)\n\n        self.embed_dims = embed_dims\n        self.num_classes = num_classes\n        self.code_size = code_size\n        self.pc_range = pc_range\n\n        self.position_encoder = nn.Sequential(\n            nn.Linear(3, self.embed_dims), \n            nn.LayerNorm(self.embed_dims),\n            nn.ReLU(inplace=True),\n            nn.Linear(self.embed_dims, self.embed_dims),\n            nn.LayerNorm(self.embed_dims),\n            nn.ReLU(inplace=True),\n        )\n\n        self.self_attn = SparseBEVSelfAttention(embed_dims, num_heads=8, dropout=0.1, pc_range=pc_range)\n        self.sampling = SparseBEVSampling(embed_dims, num_frames=num_frames, num_groups=4, num_points=num_points, num_levels=num_levels, pc_range=pc_range)\n        self.mixing = AdaptiveMixing(in_dim=embed_dims, in_points=num_points * num_frames, n_groups=4, out_points=128)\n        self.ffn = FFN(embed_dims, feedforward_channels=512, ffn_drop=0.1)\n\n        self.norm1 = nn.LayerNorm(embed_dims)\n        self.norm2 = nn.LayerNorm(embed_dims)\n        self.norm3 = nn.LayerNorm(embed_dims)\n\n        cls_branch = []\n        for _ in range(num_cls_fcs):\n            cls_branch.append(nn.Linear(self.embed_dims, self.embed_dims))\n            cls_branch.append(nn.LayerNorm(self.embed_dims))\n            cls_branch.append(nn.ReLU(inplace=True))\n        cls_branch.append(nn.Linear(self.embed_dims, self.num_classes))\n        self.cls_branch = nn.Sequential(*cls_branch)\n\n        reg_branch = []\n        for _ in range(num_reg_fcs):\n            reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims))\n            reg_branch.append(nn.ReLU(inplace=True))\n        reg_branch.append(nn.Linear(self.embed_dims, 
self.code_size))\n        self.reg_branch = nn.Sequential(*reg_branch)\n\n    @torch.no_grad()\n    def init_weights(self):\n        self.self_attn.init_weights()\n        self.sampling.init_weights()\n        self.mixing.init_weights()\n\n        bias_init = bias_init_with_prob(0.01)\n        nn.init.constant_(self.cls_branch[-1].bias, bias_init)\n\n    def refine_bbox(self, bbox_proposal, bbox_delta):\n        xyz = inverse_sigmoid(bbox_proposal[..., 0:3])\n        xyz_delta = bbox_delta[..., 0:3]\n        xyz_new = torch.sigmoid(xyz_delta + xyz)\n\n        return torch.cat([xyz_new, bbox_delta[..., 3:]], dim=-1)\n\n    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):\n        \"\"\"\n        query_bbox: [B, Q, 10] [cx, cy, cz, w, h, d, rot.sin, rot.cos, vx, vy]\n        \"\"\"\n        query_pos = self.position_encoder(query_bbox[..., :3])\n        query_feat = query_feat + query_pos\n\n        query_feat = self.norm1(self.self_attn(query_bbox, query_feat, attn_mask))\n        sampled_feat = self.sampling(query_bbox, query_feat, mlvl_feats, img_metas)\n        query_feat = self.norm2(self.mixing(sampled_feat, query_feat))\n        query_feat = self.norm3(self.ffn(query_feat))\n\n        cls_score = self.cls_branch(query_feat)  # [B, Q, num_classes]\n        bbox_pred = self.reg_branch(query_feat)  # [B, Q, code_size]\n        bbox_pred = self.refine_bbox(query_bbox, bbox_pred)\n\n        time_diff = img_metas[0]['time_diff']  # [B, F]\n        if time_diff.shape[1] > 1:\n            time_diff = time_diff.clone()\n            time_diff[time_diff < 1e-5] = 1.0\n            bbox_pred[..., 8:] = bbox_pred[..., 8:] / time_diff[:, 1:2, None]\n\n        if DUMP.enabled:\n            query_bbox_dec = decode_bbox(query_bbox, self.pc_range)\n            bbox_pred_dec = decode_bbox(bbox_pred, self.pc_range)\n            cls_score_sig = torch.sigmoid(cls_score)\n            torch.save(query_bbox_dec, 
'{}/query_bbox_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))\n            torch.save(bbox_pred_dec, '{}/bbox_pred_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))\n            torch.save(cls_score_sig, '{}/cls_score_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))\n\n        return query_feat, cls_score, bbox_pred\n\n\nclass SparseBEVSelfAttention(BaseModule):\n    def __init__(self, embed_dims=256, num_heads=8, dropout=0.1, pc_range=[], scale_adaptive=True):\n        super().__init__()\n        self.pc_range = pc_range\n\n        self.attention = MultiheadAttention(embed_dims, num_heads, dropout, batch_first=True)\n\n        if scale_adaptive:\n            self.gen_tau = nn.Linear(embed_dims, num_heads)\n        else:\n            self.gen_tau = None\n\n    @torch.no_grad()\n    def init_weights(self):\n        if self.gen_tau is not None:\n            nn.init.zeros_(self.gen_tau.weight)\n            nn.init.uniform_(self.gen_tau.bias, 0.0, 2.0)\n\n    def inner_forward(self, query_bbox, query_feat, pre_attn_mask=None):\n        \"\"\"\n        query_bbox: [B, Q, 10]\n        query_feat: [B, Q, C]\n        \"\"\"\n        if self.gen_tau is not None:\n            dist = self.calc_bbox_dists(query_bbox)\n            tau = self.gen_tau(query_feat)  # [B, Q, 8]\n\n            if DUMP.enabled:\n                torch.save(tau, '{}/sasa_tau_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))\n\n            tau = tau.permute(0, 2, 1)  # [B, 8, Q]\n            attn_mask = dist[:, None, :, :] * tau[..., None]  # [B, 8, Q, Q]\n            if pre_attn_mask is not None:\n                attn_mask[:, :, pre_attn_mask] = float('-inf')\n            attn_mask = attn_mask.flatten(0, 1)  # [Bx8, Q, Q]\n        else:\n            attn_mask = None\n        \n        return self.attention(query_feat, attn_mask=attn_mask)\n\n    def forward(self, query_bbox, query_feat, pre_attn_mask=None):\n        if self.training and query_feat.requires_grad:\n            return 
cp(self.inner_forward, query_bbox, query_feat, pre_attn_mask, use_reentrant=False)\n        else:\n            return self.inner_forward(query_bbox, query_feat, pre_attn_mask)\n\n    @torch.no_grad()\n    def calc_bbox_dists(self, bboxes):\n        centers = decode_bbox(bboxes, self.pc_range)[..., :2]  # [B, Q, 2]\n\n        dist = []\n        for b in range(centers.shape[0]):\n            dist_b = torch.norm(centers[b].reshape(-1, 1, 2) - centers[b].reshape(1, -1, 2), dim=-1)\n            dist.append(dist_b[None, ...])\n\n        dist = torch.cat(dist, dim=0)  # [B, Q, Q]\n        dist = -dist\n\n        return dist\n\n\nclass SparseBEVSampling(BaseModule):\n    def __init__(self, embed_dims=256, num_frames=4, num_groups=4, num_points=8, num_levels=4, pc_range=[], init_cfg=None):\n        super().__init__(init_cfg)\n\n        self.num_frames = num_frames\n        self.num_points = num_points\n        self.num_groups = num_groups\n        self.num_levels = num_levels\n        self.pc_range = pc_range\n\n        self.sampling_offset = nn.Linear(embed_dims, num_groups * num_points * 3)\n        self.scale_weights = nn.Linear(embed_dims, num_groups * num_points * num_levels)\n\n    def init_weights(self):\n        bias = self.sampling_offset.bias.data.view(self.num_groups * self.num_points, 3)\n        nn.init.zeros_(self.sampling_offset.weight)\n        nn.init.uniform_(bias[:, 0:3], -0.5, 0.5)\n\n    def inner_forward(self, query_bbox, query_feat, mlvl_feats, img_metas):\n        '''\n        query_bbox: [B, Q, 10]\n        query_feat: [B, Q, C]\n        '''\n        B, Q = query_bbox.shape[:2]\n        image_h, image_w, _ = img_metas[0]['img_shape'][0]\n\n        # sampling offset of all frames\n        sampling_offset = self.sampling_offset(query_feat)\n        sampling_offset = sampling_offset.view(B, Q, self.num_groups * self.num_points, 3)\n        sampling_points = make_sample_points_from_bbox(query_bbox, sampling_offset, self.pc_range)  # [B, Q, GP, 3]\n      
  sampling_points = sampling_points.reshape(B, Q, 1, self.num_groups, self.num_points, 3)\n        sampling_points = sampling_points.expand(B, Q, self.num_frames, self.num_groups, self.num_points, 3)\n\n        # warp sample points based on velocity\n        if query_bbox.shape[-1] > 8:\n            time_diff = img_metas[0]['time_diff']  # [B, F]\n            time_diff = time_diff[:, None, :, None]  # [B, 1, F, 1]\n            vel = query_bbox[..., 8:].detach()  # [B, Q, 2]\n            vel = vel[:, :, None, :]  # [B, Q, 1, 2]\n            dist = vel * time_diff  # [B, Q, F, 2]\n            dist = dist[:, :, :, None, None, :]  # [B, Q, F, 1, 1, 2]\n            sampling_points = torch.cat([\n                sampling_points[..., 0:2] - dist,\n                sampling_points[..., 2:3]\n            ], dim=-1)\n\n        # scale weights\n        scale_weights = self.scale_weights(query_feat).view(B, Q, self.num_groups, 1, self.num_points, self.num_levels)\n        scale_weights = torch.softmax(scale_weights, dim=-1)\n        scale_weights = scale_weights.expand(B, Q, self.num_groups, self.num_frames, self.num_points, self.num_levels)\n\n        # sampling\n        sampled_feats = sampling_4d(\n            sampling_points,\n            mlvl_feats,\n            scale_weights,\n            img_metas[0]['lidar2img'],\n            image_h, image_w\n        )  # [B, Q, G, FP, C]\n\n        return sampled_feats\n\n    def forward(self, query_bbox, query_feat, mlvl_feats, img_metas):\n        if self.training and query_feat.requires_grad:\n            return cp(self.inner_forward, query_bbox, query_feat, mlvl_feats, img_metas, use_reentrant=False)\n        else:\n            return self.inner_forward(query_bbox, query_feat, mlvl_feats, img_metas)\n\n\nclass AdaptiveMixing(nn.Module):\n    def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):\n        super(AdaptiveMixing, self).__init__()\n\n        out_dim = out_dim if out_dim is not 
None else in_dim\n        out_points = out_points if out_points is not None else in_points\n        query_dim = query_dim if query_dim is not None else in_dim\n\n        self.query_dim = query_dim\n        self.in_dim = in_dim\n        self.in_points = in_points\n        self.n_groups = n_groups\n        self.out_dim = out_dim\n        self.out_points = out_points\n\n        self.eff_in_dim = in_dim // n_groups\n        self.eff_out_dim = out_dim // n_groups\n\n        self.m_parameters = self.eff_in_dim * self.eff_out_dim\n        self.s_parameters = self.in_points * self.out_points\n        self.total_parameters = self.m_parameters + self.s_parameters\n\n        self.parameter_generator = nn.Linear(self.query_dim, self.n_groups * self.total_parameters)\n        self.out_proj = nn.Linear(self.eff_out_dim * self.out_points * self.n_groups, self.query_dim)\n        self.act = nn.ReLU(inplace=True)\n\n    @torch.no_grad()\n    def init_weights(self):\n        nn.init.zeros_(self.parameter_generator.weight)\n\n    def inner_forward(self, x, query):\n        B, Q, G, P, C = x.shape\n        assert G == self.n_groups\n        assert P == self.in_points\n        assert C == self.eff_in_dim\n\n        '''generate mixing parameters'''\n        params = self.parameter_generator(query)\n        params = params.reshape(B*Q, G, -1)\n        out = x.reshape(B*Q, G, P, C)\n\n        M, S = params.split([self.m_parameters, self.s_parameters], 2)\n        M = M.reshape(B*Q, G, self.eff_in_dim, self.eff_out_dim)\n        S = S.reshape(B*Q, G, self.out_points, self.in_points)\n\n        '''adaptive channel mixing'''\n        out = torch.matmul(out, M)\n        out = F.layer_norm(out, [out.size(-2), out.size(-1)])\n        out = self.act(out)\n\n        '''adaptive point mixing'''\n        out = torch.matmul(S, out)  # implicitly transpose and matmul\n        out = F.layer_norm(out, [out.size(-2), out.size(-1)])\n        out = self.act(out)\n\n        '''linear transfomation to query 
dim'''\n        out = out.reshape(B, Q, -1)\n        out = self.out_proj(out)\n        out = query + out\n\n        return out\n\n    def forward(self, x, query):\n        if self.training and x.requires_grad:\n            return cp(self.inner_forward, x, query, use_reentrant=False)\n        else:\n            return self.inner_forward(x, query)\n\n\nclass AdaptiveMixingPointOnly(nn.Module):\n    def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):\n        super(AdaptiveMixingPointOnly, self).__init__()\n\n        out_dim = out_dim if out_dim is not None else in_dim\n        out_points = out_points if out_points is not None else in_points\n        query_dim = query_dim if query_dim is not None else in_dim\n\n        self.query_dim = query_dim\n        self.in_dim = in_dim\n        self.in_points = in_points\n        self.n_groups = n_groups\n        self.out_dim = out_dim\n        self.out_points = out_points\n\n        self.eff_in_dim = in_dim // n_groups\n        self.eff_out_dim = out_dim // n_groups\n\n        self.s_parameters = self.in_points * self.out_points\n        self.total_parameters = self.s_parameters\n\n        self.parameter_generator = nn.Linear(self.query_dim, self.n_groups * self.total_parameters)\n        self.out_proj = nn.Linear(self.eff_out_dim * self.out_points * self.n_groups, self.query_dim)\n        self.act = nn.ReLU(inplace=True)\n\n    @torch.no_grad()\n    def init_weights(self):\n        nn.init.zeros_(self.parameter_generator.weight)\n\n    def inner_forward(self, x, query):\n        B, Q, G, P, C = x.shape\n        assert G == self.n_groups\n        assert P == self.in_points\n        assert C == self.eff_in_dim\n\n        '''generate mixing parameters'''\n        params = self.parameter_generator(query)\n        params = params.reshape(B*Q, G, -1)\n        out = x.reshape(B*Q, G, P, C)\n\n        S = params.reshape(B*Q, G, self.out_points, self.in_points)\n\n        '''adaptive 
spatial mixing'''\n        out = torch.matmul(S, out)  # implicitly transpose and matmul\n        out = F.layer_norm(out, [out.size(-2), out.size(-1)])\n        out = self.act(out)\n\n        '''linear transfomation to query dim'''\n        out = out.reshape(B, Q, -1)\n        out = self.out_proj(out)\n        out = query + out\n\n        return out\n\n    def forward(self, x, query):\n        if self.training and x.requires_grad:\n            return cp(self.inner_forward, x, query, use_reentrant=False)\n        else:\n            return self.inner_forward(x, query)\n\n\nclass DeformAggregation(nn.Module):\n    def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):\n        super(DeformAggregation, self).__init__()\n\n        out_dim = out_dim if out_dim is not None else in_dim\n        out_points = out_points if out_points is not None else in_points\n        query_dim = query_dim if query_dim is not None else in_dim\n\n        self.query_dim = query_dim\n        self.in_dim = in_dim\n        self.in_points = in_points\n        self.n_groups = n_groups\n        self.out_dim = out_dim\n        self.out_points = out_points\n\n        self.eff_in_dim = in_dim // n_groups\n        self.eff_out_dim = out_dim // n_groups\n\n        self.attn_weights = nn.Linear(query_dim, n_groups * in_points)\n        self.out_proj = nn.Linear(self.eff_in_dim * n_groups, self.query_dim)\n        self.act = nn.ReLU(inplace=True)\n\n    @torch.no_grad()\n    def init_weights(self):\n        pass\n\n    def inner_forward(self, x, query):\n        B, Q, G, P, C = x.shape\n        assert G == self.n_groups\n        assert P == self.in_points\n        assert C == self.eff_in_dim\n        out = x.reshape(B, Q, G, P, C)\n\n        attn_weights = self.attn_weights(query)  # [B, Q, GP]\n        attn_weights = attn_weights.reshape(B, Q, self.n_groups, self.in_points, 1)  # [B, Q, G, P, 1]\n        attn_weights = attn_weights.softmax(dim=-2)\n\n        out 
= torch.sum(out * attn_weights, dim=-2)  # [B, Q, G, C]\n        out = out.reshape(B, Q, -1)\n        out = self.out_proj(out)\n        out = query + out\n\n        return out\n\n    def forward(self, x, query):\n        if self.training and x.requires_grad:\n            return cp(self.inner_forward, x, query, use_reentrant=False)\n        else:\n            return self.inner_forward(x, query)\n"
  },
  {
    "path": "models/sparseocc.py",
    "content": "import torch\nimport queue\nimport numpy as np\nfrom mmcv.runner import get_dist_info\nfrom mmcv.runner.fp16_utils import cast_tensor_type\nfrom mmcv.runner import force_fp32, auto_fp16\nfrom mmdet.models import DETECTORS\nfrom mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector\nfrom .utils import pad_multiple, GpuPhotoMetricDistortion\n\n\n@DETECTORS.register_module()\nclass SparseOcc(MVXTwoStageDetector):\n    def __init__(self,\n                 pts_voxel_layer=None,\n                 pts_voxel_encoder=None,\n                 pts_middle_encoder=None,\n                 pts_fusion_layer=None,\n                 img_backbone=None,\n                 pts_backbone=None,\n                 img_neck=None,\n                 pts_neck=None,\n                 pts_bbox_head=None,\n                 img_roi_head=None,\n                 img_rpn_head=None,\n                 train_cfg=None,\n                 test_cfg=None,\n                 pretrained=None,\n                 data_aug=None,\n                 use_mask_camera=False,\n                 **kwargs):\n\n        super(SparseOcc, self).__init__(pts_voxel_layer, pts_voxel_encoder,\n                             pts_middle_encoder, pts_fusion_layer,\n                             img_backbone, pts_backbone, img_neck, pts_neck,\n                             pts_bbox_head, img_roi_head, img_rpn_head,\n                             train_cfg, test_cfg, pretrained)\n\n        self.use_mask_camera = use_mask_camera\n        self.fp16_enabled = False\n        self.data_aug = data_aug\n        self.color_aug = GpuPhotoMetricDistortion()\n\n        self.memory = {}\n        self.queue = queue.Queue()\n\n    @auto_fp16(apply_to=('img'), out_fp32=True)\n    def extract_img_feat(self, img):\n        img_feats = self.img_backbone(img)\n\n        if isinstance(img_feats, dict):\n            img_feats = list(img_feats.values())\n\n        if self.with_img_neck:\n            img_feats = self.img_neck(img_feats)\n\n 
       return img_feats\n\n    @auto_fp16(apply_to=('img'))\n    def extract_feat(self, img, img_metas=None):\n        \"\"\"Extract features from images and points.\"\"\"\n        if len(img.shape) == 6:\n            img = img.flatten(1, 2)  # [B, TN, C, H, W]\n\n        B, N, C, H, W = img.size()\n        img = img.view(B * N, C, H, W)\n        img = img.float()\n\n        if self.data_aug is not None:\n            if 'img_color_aug' in self.data_aug and self.data_aug['img_color_aug'] and self.training:\n                img = self.color_aug(img)\n\n            if 'img_norm_cfg' in self.data_aug:\n                img_norm_cfg = self.data_aug['img_norm_cfg']\n\n                norm_mean = torch.tensor(img_norm_cfg['mean'], device=img.device)\n                norm_std = torch.tensor(img_norm_cfg['std'], device=img.device)\n\n                if img_norm_cfg['to_rgb']:\n                    img = img[:, [2, 1, 0], :, :]  # BGR to RGB\n\n                img = img - norm_mean.reshape(1, 3, 1, 1)\n                img = img / norm_std.reshape(1, 3, 1, 1)\n\n            for b in range(B):\n                img_shape = (img.shape[2], img.shape[3], img.shape[1])\n                img_metas[b]['img_shape'] = [img_shape for _ in range(N)]\n                img_metas[b]['ori_shape'] = [img_shape for _ in range(N)]\n\n            if 'img_pad_cfg' in self.data_aug:\n                img_pad_cfg = self.data_aug['img_pad_cfg']\n                img = pad_multiple(img, img_metas, size_divisor=img_pad_cfg['size_divisor'])\n                H, W = img.shape[-2:]\n\n        input_shape = img.shape[-2:]\n        # update real input shape of each single img\n        for img_meta in img_metas:\n            img_meta.update(input_shape=input_shape)\n\n        img_feats = self.extract_img_feat(img)\n\n        img_feats_reshaped = []\n        for img_feat in img_feats:\n            BN, C, H, W = img_feat.size()\n            img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))\n\n        
return img_feats_reshaped\n\n    def forward_pts_train(self, mlvl_feats, voxel_semantics, voxel_instances, instance_class_ids, mask_camera, img_metas):\n        \"\"\"\n        voxel_semantics: [bs, 200, 200, 16], value in range [0, num_cls - 1]\n        voxel_instances: [bs, 200, 200, 16], value in range [0, num_obj - 1]\n        instance_class_ids: [[bs0_num_obj], [bs1_num_obj], ...], value in range [0, num_cls - 1]\n        \"\"\"\n        outs = self.pts_bbox_head(mlvl_feats, img_metas)\n        loss_inputs = [voxel_semantics, voxel_instances, instance_class_ids, outs]\n        return self.pts_bbox_head.loss(*loss_inputs)\n\n    def forward(self, return_loss=True, **kwargs):\n        if return_loss:\n            return self.forward_train(**kwargs)\n        else:\n            return self.forward_test(**kwargs)\n\n    @force_fp32(apply_to=('img'))\n    def forward_train(self, img_metas=None, img=None, voxel_semantics=None, voxel_instances=None, instance_class_ids=None, mask_camera=None, **kwargs):\n        img_feats = self.extract_feat(img=img, img_metas=img_metas)\n        return self.forward_pts_train(img_feats, voxel_semantics, voxel_instances, instance_class_ids, mask_camera, img_metas)\n\n    def forward_test(self, img_metas, img=None, **kwargs):\n        output = self.simple_test(img_metas, img)\n\n        sem_pred = output['sem_pred'].cpu().numpy().astype(np.uint8)\n        occ_loc = output['occ_loc'].cpu().numpy().astype(np.uint8)\n\n        batch_size = sem_pred.shape[0]\n\n        if 'pano_inst' and 'pano_sem' in output:\n            # important: uint8 is not enough for pano_pred\n            pano_inst = output['pano_inst'].cpu().numpy().astype(np.int16)\n            pano_sem = output['pano_sem'].cpu().numpy().astype(np.uint8)\n            \n            return [{\n                'sem_pred': sem_pred[b:b+1],\n                'pano_inst': pano_inst[b:b+1],\n                'pano_sem': pano_sem[b:b+1],\n                'occ_loc': occ_loc[b:b+1]\n          
  } for b in range(batch_size)]\n        else:\n            return [{\n                'sem_pred': sem_pred[b:b+1],\n                'occ_loc': occ_loc[b:b+1]\n            } for b in range(batch_size)]\n\n    def simple_test_pts(self, x, img_metas, rescale=False):\n        outs = self.pts_bbox_head(x, img_metas)\n        outs = self.pts_bbox_head.merge_occ_pred(outs)\n        return outs\n\n    def simple_test(self, img_metas, img=None, rescale=False):\n        world_size = get_dist_info()[1]\n        if world_size == 1:  # online\n            return self.simple_test_online(img_metas, img, rescale)\n        else:  # offline\n            return self.simple_test_offline(img_metas, img, rescale)\n\n    def simple_test_offline(self, img_metas, img=None, rescale=False):\n        img_feats = self.extract_feat(img=img, img_metas=img_metas)\n        return self.simple_test_pts(img_feats, img_metas, rescale=rescale)\n\n    def simple_test_online(self, img_metas, img=None, rescale=False):\n        self.fp16_enabled = False\n        assert len(img_metas) == 1  # batch_size = 1\n\n        B, N, C, H, W = img.shape\n        img = img.reshape(B, N//6, 6, C, H, W)\n\n        img_filenames = img_metas[0]['filename']\n        num_frames = len(img_filenames) // 6\n        # assert num_frames == img.shape[1]\n\n        img_shape = (H, W, C)\n        img_metas[0]['img_shape'] = [img_shape for _ in range(len(img_filenames))]\n        img_metas[0]['ori_shape'] = [img_shape for _ in range(len(img_filenames))]\n        img_metas[0]['pad_shape'] = [img_shape for _ in range(len(img_filenames))]\n\n        img_feats_list, img_metas_list = [], []\n\n        # extract feature frame by frame\n        for i in range(num_frames):\n            img_indices = list(np.arange(i * 6, (i + 1) * 6))\n\n            img_metas_curr = [{}]\n            for k in img_metas[0].keys():\n                if isinstance(img_metas[0][k], list):\n                    img_metas_curr[0][k] = [img_metas[0][k][i] for i in 
img_indices]\n\n            if img_filenames[img_indices[0]] in self.memory:\n                # found in memory\n                img_feats_curr = self.memory[img_filenames[img_indices[0]]]\n            else:\n                # extract feature and put into memory\n                img_feats_curr = self.extract_feat(img[:, i], img_metas_curr)\n                self.memory[img_filenames[img_indices[0]]] = img_feats_curr\n                self.queue.put(img_filenames[img_indices[0]])\n                while self.queue.qsize() > 16:  # avoid OOM\n                    pop_key = self.queue.get()\n                    self.memory.pop(pop_key)\n\n            img_feats_list.append(img_feats_curr)\n            img_metas_list.append(img_metas_curr)\n\n        # reorganize\n        feat_levels = len(img_feats_list[0])\n        img_feats_reorganized = []\n        for j in range(feat_levels):\n            feat_l = torch.cat([img_feats_list[i][j] for i in range(len(img_feats_list))], dim=0)\n            feat_l = feat_l.flatten(0, 1)[None, ...]\n            img_feats_reorganized.append(feat_l)\n\n        img_metas_reorganized = img_metas_list[0]\n        for i in range(1, len(img_metas_list)):\n            for k, v in img_metas_list[i][0].items():\n                if isinstance(v, list):\n                    img_metas_reorganized[0][k].extend(v)\n\n        img_feats = img_feats_reorganized\n        img_metas = img_metas_reorganized\n        img_feats = cast_tensor_type(img_feats, torch.half, torch.float32)\n\n        # run detector\n        return self.simple_test_pts(img_feats, img_metas, rescale=rescale)\n"
  },
  {
    "path": "models/sparseocc_head.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom mmdet.models import HEADS\nfrom mmcv.runner import force_fp32, auto_fp16\nfrom mmdet.models.builder import build_loss\nfrom mmdet.models.utils import build_transformer\nfrom .matcher import HungarianMatcher\nfrom .loss_utils import CE_ssc_loss, lovasz_softmax, get_voxel_decoder_loss_input\n\n\nNUSC_CLASS_FREQ = np.array([\n    944004, 1897170, 152386, 2391677, 16957802, 724139, 189027, 2074468, 413451, 2384460,\n    5916653, 175883646, 4275424, 51393615, 61411620, 105975596, 116424404, 1892500630\n])\n\n\n@HEADS.register_module()\nclass SparseOccHead(nn.Module): \n    def __init__(self,\n                 transformer=None,\n                 class_names=None,\n                 embed_dims=None,\n                 occ_size=None,\n                 pc_range=None,\n                 loss_cfgs=None,\n                 panoptic=False,\n                 **kwargs):\n        super(SparseOccHead, self).__init__()\n        self.num_classes = len(class_names)\n        self.class_names = class_names\n        self.pc_range = pc_range\n        self.occ_size = occ_size\n        self.embed_dims = embed_dims\n        self.score_threshold = 0.3\n        self.overlap_threshold = 0.8\n        self.panoptic = panoptic\n\n        self.transformer = build_transformer(transformer)\n        self.criterions = {k: build_loss(loss_cfg) for k, loss_cfg in loss_cfgs.items()}\n        self.matcher = HungarianMatcher(cost_class=2.0, cost_mask=5.0, cost_dice=5.0)\n\n        self.class_weights = torch.from_numpy(1 / np.log(NUSC_CLASS_FREQ + 0.001))\n\n    def init_weights(self):\n        self.transformer.init_weights()\n\n    @auto_fp16(apply_to=('mlvl_feats'))\n    def forward(self, mlvl_feats, img_metas):\n        occ_preds, mask_preds, class_preds = self.transformer(mlvl_feats, img_metas=img_metas)\n        \n        return {\n            'occ_preds': occ_preds, \n            'mask_preds': mask_preds, \n            
'class_preds': class_preds\n        }\n\n    @force_fp32(apply_to=('preds_dicts'))\n    def loss(self, voxel_semantics, voxel_instances, instance_class_ids, preds_dicts, mask_camera=None):\n        return self.loss_single(voxel_semantics, voxel_instances, instance_class_ids, preds_dicts, mask_camera)\n\n    def loss_single(self, voxel_semantics, voxel_instances, instance_class_ids, preds_dicts, mask_camera=None):\n        loss_dict = {}\n        B = voxel_instances.shape[0]\n\n        if mask_camera is not None:\n            assert mask_camera.shape == voxel_semantics.shape\n            assert mask_camera.dtype == torch.bool\n        \n        for i, (occ_loc_i, _, seg_pred_i, _, scale) in enumerate(preds_dicts['occ_preds']):\n            loss_dict_i = {}\n            for b in range(B):\n                loss_dict_i_b = {}\n                seg_pred_i_sparse, voxel_semantics_sparse, sparse_mask = get_voxel_decoder_loss_input(\n                    voxel_semantics[b:b + 1],\n                    occ_loc_i[b:b + 1],\n                    seg_pred_i[b:b + 1] if seg_pred_i is not None else None,\n                    scale,\n                    self.num_classes\n                )\n\n                loss_dict_i_b['loss_sem_lovasz'] = lovasz_softmax(torch.softmax(seg_pred_i_sparse, dim=1), voxel_semantics_sparse)\n\n                valid_mask = (voxel_semantics_sparse < 255)\n                seg_pred_i_sparse = seg_pred_i_sparse[valid_mask].transpose(0, 1).unsqueeze(0)  # [K, CLS] -> [B, CLS, K]\n                voxel_semantics_sparse = voxel_semantics_sparse[valid_mask].unsqueeze(0)  # [K] -> [B, K]\n\n                if 'loss_geo_scal' in self.criterions.keys():\n                    loss_dict_i_b['loss_geo_scal'] = self.criterions['loss_geo_scal'](seg_pred_i_sparse, voxel_semantics_sparse)  \n                if 'loss_sem_scal' in self.criterions.keys():\n                    loss_dict_i_b['loss_sem_scal'] = self.criterions['loss_sem_scal'](seg_pred_i_sparse, 
voxel_semantics_sparse)\n\n                loss_dict_i_b['loss_sem_ce'] = CE_ssc_loss(seg_pred_i_sparse, voxel_semantics_sparse, self.class_weights.type_as(seg_pred_i_sparse))\n\n                for loss_key in loss_dict_i_b.keys():\n                    loss_dict_i[loss_key] = loss_dict_i.get(loss_key, 0) + loss_dict_i_b[loss_key] / B\n\n            for k, v in loss_dict_i.items():\n                loss_dict['%s_%d' % (k, i)] = v\n\n        occ_loc = preds_dicts['occ_preds'][-1][0]\n        \n        batch_idx = torch.arange(B)[:, None, None].expand(B, occ_loc.shape[1], 1).to(occ_loc.device)\n        occ_loc = occ_loc.reshape(-1, 3)\n        voxel_instances = voxel_instances[batch_idx.reshape(-1), occ_loc[..., 0], occ_loc[..., 1], occ_loc[..., 2]]\n        voxel_instances = voxel_instances.reshape(B, -1)  # [B, N]\n\n        if mask_camera is not None:\n            mask_camera = mask_camera[batch_idx.reshape(-1), occ_loc[..., 0], occ_loc[..., 1], occ_loc[..., 2]]\n            mask_camera = mask_camera.reshape(B, -1)  # [B, N]\n        \n        # drop instances if it has no positive voxels\n        for b in range(B):\n            instance_count = instance_class_ids[b].shape[0]\n            instance_voxel_counts = torch.bincount(voxel_instances[b].long())  # [255]\n            id_map = torch.cumsum(instance_voxel_counts > 0, dim=0) - 1\n            id_map[255] = 255  # empty space still has an id of 255\n            voxel_instances[b] = id_map[voxel_instances[b].long()]\n            instance_class_ids[b] = instance_class_ids[b][instance_voxel_counts[:instance_count] > 0]\n\n        for i, pred in enumerate(preds_dicts['mask_preds']):\n            indices = self.matcher(pred, preds_dicts['class_preds'][i], voxel_instances, instance_class_ids, mask_camera)\n            loss_mask, loss_dice, loss_class = self.criterions['loss_mask2former'](\n                pred, preds_dicts['class_preds'][i], voxel_instances, instance_class_ids, indices, mask_camera)\n            
loss_dict['loss_mask_{:d}'.format(i)] = loss_mask\n            loss_dict['loss_dice_mask_{:d}'.format(i)] = loss_dice\n            loss_dict['loss_class_{:d}'.format(i)] = loss_class\n\n        return loss_dict\n    \n    def merge_occ_pred(self, outs):\n        mask_cls = outs['class_preds'][-1].sigmoid()\n        mask_pred = outs['mask_preds'][-1].sigmoid()\n        occ_indices = outs['occ_preds'][-1][0]\n        \n        sem_pred = self.merge_semseg(mask_cls, mask_pred)  # [B, C, N]\n        outs['sem_pred'] = sem_pred\n        outs['occ_loc'] = occ_indices\n\n        if self.panoptic:\n            pano_inst, pano_sem = self.merge_panoseg(mask_cls, mask_pred)  # [B, C, N]\n            outs['pano_inst'] = pano_inst\n            outs['pano_sem'] = pano_sem\n        \n        return outs\n    \n    # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/mask_former_model.py#L242\n    def merge_semseg(self, mask_cls, mask_pred):\n        valid_mask = mask_cls.max(dim=-1).values > self.score_threshold\n        mask_cls[~valid_mask] = 0.0\n\n        semseg = torch.einsum(\"bqc,bqn->bcn\", mask_cls, mask_pred)\n        if semseg.shape[1] == self.num_classes:\n            semseg = semseg[:, :-1]\n        \n        cls_score, cls_id = torch.max(semseg, dim=1)\n        cls_id[cls_score < 0.01] = self.num_classes - 1\n        return cls_id  # [B, N]\n    \n    def merge_panoseg(self, mask_cls, mask_pred):\n        pano_inst, pano_sem = [], []\n        for b in range(mask_cls.shape[0]):\n            pano_inst_b, pano_sem_b = self.merge_panoseg_single(\n                mask_cls[b:b + 1],\n                mask_pred[b:b + 1]\n            )\n            pano_inst.append(pano_inst_b)\n            pano_sem.append(pano_sem_b)\n        \n        pano_inst = torch.cat(pano_inst, dim=0)\n        pano_sem = torch.cat(pano_sem, dim=0)\n        \n        return pano_inst, pano_sem\n\n    # 
https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/maskformer_model.py#L286\n    def merge_panoseg_single(self, mask_cls, mask_pred):\n        assert mask_cls.shape[0] == 1, \"bs != 1\"\n        scores, labels = mask_cls.max(-1)\n        \n        # filter out low score and background instances\n        keep = labels.ne(self.num_classes - 1) & (scores > self.score_threshold)\n        cur_scores = scores[keep]\n        cur_classes = labels[keep]\n        cur_masks = mask_pred[keep]\n\n        cur_prob_masks = cur_scores.view(-1, 1) * cur_masks\n\n        N = cur_masks.shape[-1]\n        instance_seg = torch.zeros((N), dtype=torch.int32, device=cur_masks.device)\n        semantic_seg = torch.ones((N), dtype=torch.int32, device=cur_masks.device) * (self.num_classes - 1)\n        \n        current_segment_id = 0\n        stuff_memory_list = {self.num_classes - 1: 0}\n\n        # skip all process if no mask is detected\n        if cur_masks.shape[0] != 0:\n            # take argmax\n            cur_mask_ids = cur_prob_masks.argmax(0)  # [N]\n            for k in range(cur_classes.shape[0]):\n                pred_class = cur_classes[k].item()\n\n                # moving objects are treated as instances\n                is_thing = self.class_names[pred_class] in [\n                    'car', 'truck', 'construction_vehicle', 'bus',\n                    'trailer', 'motorcycle', 'bicycle', 'pedestrian'\n                ]\n\n                mask_area = (cur_mask_ids == k).sum().item()\n                original_area = (cur_masks[k] >= 0.5).sum().item()\n                mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)\n\n                if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:\n                    if mask_area / original_area < self.overlap_threshold:\n                        continue\n\n                    # merge stuff regions\n                    if not is_thing:\n                        if int(pred_class) in 
stuff_memory_list.keys():\n                            instance_seg[mask] = stuff_memory_list[int(pred_class)]\n                            continue\n                        else:\n                            stuff_memory_list[int(pred_class)] = current_segment_id + 1\n\n                    current_segment_id += 1\n                    instance_seg[mask] = current_segment_id\n                    semantic_seg[mask] = pred_class\n        \n        instance_seg = instance_seg.unsqueeze(0)\n        semantic_seg = semantic_seg.unsqueeze(0)\n        \n        return instance_seg, semantic_seg  # [B, N]\n"
  },
  {
    "path": "models/sparseocc_transformer.py",
    "content": "import copy\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom mmcv.runner import BaseModule\nfrom mmdet.models.utils.builder import TRANSFORMER\nfrom mmcv.cnn.bricks.transformer import FFN\nfrom .sparsebev_transformer import AdaptiveMixing\nfrom .utils import DUMP\nfrom .checkpoint import checkpoint as cp\nfrom .sparsebev_sampling import sampling_4d, make_sample_points_from_mask\nfrom .sparse_voxel_decoder import SparseVoxelDecoder\n\n\n@TRANSFORMER.register_module()\nclass SparseOccTransformer(BaseModule):\n    def __init__(self,\n                 embed_dims=None,\n                 num_layers=None,\n                 num_queries=None,\n                 num_frames=None,\n                 num_points=None,\n                 num_groups=None,\n                 num_levels=None,\n                 num_classes=None,\n                 pc_range=None,\n                 occ_size=None,\n                 topk_training=None,\n                 topk_testing=None):\n        super().__init__()\n        self.num_frames = num_frames\n        \n        self.voxel_decoder = SparseVoxelDecoder(\n            embed_dims=embed_dims,\n            num_layers=3,\n            num_frames=num_frames,\n            num_points=num_points,\n            num_groups=num_groups,\n            num_levels=num_levels,\n            num_classes=num_classes,\n            pc_range=pc_range,\n            semantic=True,\n            topk_training=topk_training,\n            topk_testing=topk_testing\n        )\n        self.decoder = MaskFormerOccDecoder(\n            embed_dims=embed_dims,\n            num_layers=num_layers,\n            num_frames=num_frames,\n            num_queries=num_queries,\n            num_points=num_points,\n            num_groups=num_groups,\n            num_levels=num_levels,\n            num_classes=num_classes,\n            pc_range=pc_range,\n            occ_size=occ_size,\n        )\n        \n    @torch.no_grad()\n    def init_weights(self):\n        
self.voxel_decoder.init_weights()\n        self.decoder.init_weights()\n\n    def forward(self, mlvl_feats, img_metas):\n        for lvl, feat in enumerate(mlvl_feats):\n            B, TN, GC, H, W = feat.shape  # [B, TN, GC, H, W]\n            N, T, G, C = 6, TN // 6, 4, GC // 4\n            feat = feat.reshape(B, T, N, G, C, H, W)\n            feat = feat.permute(0, 1, 3, 2, 5, 6, 4)  # [B, T, G, N, H, W, C]\n            feat = feat.reshape(B*T*G, N, H, W, C)  # [BTG, N, H, W, C]\n            mlvl_feats[lvl] = feat.contiguous()\n        \n        lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)\n        lidar2img = torch.from_numpy(lidar2img).to(feat.device)  # [B, N, 4, 4]\n        ego2lidar = np.asarray([m['ego2lidar'] for m in img_metas]).astype(np.float32)\n        ego2lidar = torch.from_numpy(ego2lidar).to(feat.device)  # [B, N, 4, 4]\n        \n        img_metas = copy.deepcopy(img_metas)\n        img_metas[0]['lidar2img'] = torch.matmul(lidar2img, ego2lidar)\n\n        occ_preds = self.voxel_decoder(mlvl_feats, img_metas=img_metas)\n        mask_preds, class_preds = self.decoder(occ_preds, mlvl_feats, img_metas)\n        \n        return occ_preds, mask_preds, class_preds\n\n\nclass MaskFormerOccDecoder(BaseModule):\n    def __init__(self,\n                 embed_dims=None,\n                 num_layers=None,\n                 num_frames=None,\n                 num_queries=None,\n                 num_points=None,\n                 num_groups=None,\n                 num_levels=None,\n                 num_classes=None,\n                 pc_range=None,\n                 occ_size=None):\n        super().__init__()\n\n        self.num_layers = num_layers\n        self.num_queries = num_queries\n        self.num_frames = num_frames\n\n        self.decoder_layer = MaskFormerOccDecoderLayer(\n            embed_dims=embed_dims,\n            mask_dim=embed_dims,\n            num_frames=num_frames,\n            num_points=num_points,\n    
        num_groups=num_groups,\n            num_levels=num_levels,\n            num_classes=num_classes,\n            pc_range=pc_range,\n            occ_size=occ_size,\n        )\n        \n        self.query_feat = nn.Embedding(num_queries, embed_dims)\n        self.query_pos = nn.Embedding(num_queries, embed_dims)\n        \n    @torch.no_grad()\n    def init_weights(self):\n        self.decoder_layer.init_weights()\n        \n    def forward(self, occ_preds, mlvl_feats, img_metas):\n        occ_loc, occ_pred, _, mask_feat, _ = occ_preds[-1]\n        bs = mask_feat.shape[0]\n        query_feat = self.query_feat.weight[None].repeat(bs, 1, 1)\n        query_pos = self.query_pos.weight[None].repeat(bs, 1, 1)\n        \n        valid_map, mask_pred, class_pred = self.decoder_layer.pred_segmentation(query_feat, mask_feat)\n        \n        class_preds = [class_pred]\n        mask_preds = [mask_pred]\n\n        for i in range(self.num_layers):\n            DUMP.stage_count = i\n            query_feat, valid_map, mask_pred, class_pred = self.decoder_layer(\n                query_feat, valid_map, mask_pred, occ_preds, mlvl_feats, query_pos, img_metas\n            )\n            mask_preds.append(mask_pred)\n            class_preds.append(class_pred)\n\n        return mask_preds, class_preds\n\n\nclass MaskFormerOccDecoderLayer(BaseModule):\n    def __init__(self,\n                 embed_dims=None,\n                 mask_dim=None,\n                 num_frames=None,\n                 num_queries=None,\n                 num_points=None,\n                 num_groups=None,\n                 num_levels=None,\n                 num_classes=None,\n                 pc_range=None,\n                 occ_size=None):\n        super().__init__()\n        \n        self.pc_range = pc_range\n        self.occ_size = occ_size\n        \n        self.self_attn = MaskFormerSelfAttention(embed_dims, num_heads=8)\n        self.sampling = MaskFormerSampling(embed_dims, num_frames, num_groups, 
num_points, num_levels, pc_range=pc_range, occ_size=occ_size)\n        self.mixing = AdaptiveMixing(in_dim=embed_dims, in_points=num_points * num_frames, n_groups=num_groups, out_points=128)\n        self.ffn = FFN(embed_dims, feedforward_channels=512, ffn_drop=0.1)\n        self.mask_proj = nn.Linear(embed_dims, mask_dim)\n        self.classifier = nn.Linear(embed_dims, num_classes - 1)\n        \n        self.norm1 = nn.LayerNorm(embed_dims)\n        self.norm2 = nn.LayerNorm(embed_dims)\n        self.norm3 = nn.LayerNorm(embed_dims)\n\n    @torch.no_grad()\n    def init_weights(self):\n        self.self_attn.init_weights()\n        self.sampling.init_weights()\n        self.mixing.init_weights()\n        self.ffn.init_weights()\n        \n    def forward(self, query_feat, valid_map, mask_pred, occ_preds, mlvl_feats, query_pos, img_metas):\n        \"\"\"\n        query_feat: [bs, num_query, embed_dim]\n        valid_map: [bs, num_query, num_voxel]\n        mask_pred: [bs, num_query, num_voxel]\n        occ_preds: list(occ_loc, occ_pred, _, mask_feat, scale), all voxel decoder's outputs\n            mask_feat: [bs, num_voxel, embed_dim]\n            occ_pred: [bs, num_voxel]\n            occ_loc: [bs, num_voxel, 3]\n        \"\"\"\n        occ_loc, occ_pred, _, mask_feat, _ = occ_preds[-1]\n        query_feat = self.norm1(self.self_attn(query_feat, query_pos=query_pos))\n\n        sampled_feat = self.sampling(query_feat, valid_map, occ_loc, mlvl_feats, img_metas)\n        query_feat = self.norm2(self.mixing(sampled_feat, query_feat))\n        \n        query_feat = self.norm3(self.ffn(query_feat))\n        \n        valid_map, mask_pred, class_pred = self.pred_segmentation(query_feat, mask_feat)\n        return query_feat, valid_map, mask_pred, class_pred\n    \n    def pred_segmentation(self, query_feat, mask_feat):\n        if self.training and query_feat.requires_grad:\n            return cp(self.inner_pred_segmentation, query_feat, mask_feat, 
use_reentrant=False)\n        else:\n            return self.inner_pred_segmentation(query_feat, mask_feat)\n    \n    def inner_pred_segmentation(self, query_feat, mask_feat):\n        class_pred = self.classifier(query_feat)\n        feat_proj = self.mask_proj(query_feat)\n        mask_pred = torch.einsum(\"bqc,bnc->bqn\", feat_proj, mask_feat)\n        valid_map = (mask_pred > 0.0)\n        \n        return valid_map, mask_pred, class_pred\n\n\nclass MaskFormerSelfAttention(BaseModule):\n    def __init__(self, embed_dims, num_heads, dropout=0.0):\n        super().__init__()\n        self.self_attn = nn.MultiheadAttention(embed_dims, num_heads, dropout=dropout, batch_first=True)\n        self.dropout = nn.Dropout(dropout)\n        self.activation = nn.ReLU(inplace=True)\n    \n    def init_weights(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                nn.init.xavier_uniform_(p)\n    \n    def with_pos_embed(self, tensor, pos=None):\n        return tensor if pos is None else tensor + pos\n                \n    def inner_forward(self, query, mask = None, key_padding_mask = None,query_pos= None):\n        q = k = self.with_pos_embed(query, query_pos)\n        tgt = self.self_attn(q, k, value=query, attn_mask=mask, key_padding_mask=key_padding_mask)[0]\n        query = query + self.dropout(tgt)\n        return query\n\n    def forward(self, query, mask = None, key_padding_mask = None,query_pos= None):\n        if self.training and query.requires_grad:\n            return cp(self.inner_forward, query, mask, key_padding_mask, query_pos, use_reentrant=False)\n        else:\n            return self.inner_forward(query, mask, key_padding_mask, query_pos)\n\n\nclass MaskFormerSampling(BaseModule):\n    def __init__(self, embed_dims=256, num_frames=4, num_groups=4, num_points=8, num_levels=4, pc_range=[], occ_size=[], init_cfg=None):\n        super().__init__(init_cfg)\n\n        self.num_frames = num_frames\n        self.num_points = 
num_points\n        self.num_groups = num_groups\n        self.num_levels = num_levels\n        self.pc_range = pc_range\n        self.occ_size = occ_size\n\n        self.offset = nn.Linear(embed_dims, num_groups * num_points * 3)\n        self.scale_weights = nn.Linear(embed_dims, num_groups * num_points * num_levels)\n        \n    def init_weights(self, ):\n        nn.init.zeros_(self.offset.weight)\n        nn.init.zeros_(self.offset.bias)\n\n    def inner_forward(self, query_feat, valid_map, occ_loc, mlvl_feats, img_metas):\n        '''\n        valid_map: [B, Q, W, H, D]\n        query_feat: [B, Q, C]\n        '''\n        B, Q = query_feat.shape[:2]\n        image_h, image_w, _ = img_metas[0]['img_shape'][0]\n\n        # sampling offset of all frames\n        offset = self.offset(query_feat).view(B, Q, self.num_groups * self.num_points, 3)  # [B, Q, GP, 3]\n        sampling_points = make_sample_points_from_mask(valid_map, self.pc_range, self.occ_size, self.num_groups*self.num_points, occ_loc, offset)\n        sampling_points = sampling_points.reshape(B, Q, 1, self.num_groups, self.num_points, 3)\n        sampling_points = sampling_points.expand(B, Q, self.num_frames, self.num_groups, self.num_points, 3)\n\n        # scale weights\n        scale_weights = self.scale_weights(query_feat).view(B, Q, self.num_groups, 1, self.num_points, self.num_levels)\n        scale_weights = torch.softmax(scale_weights, dim=-1)\n        scale_weights = scale_weights.expand(B, Q, self.num_groups, self.num_frames, self.num_points, self.num_levels)\n\n        # sampling\n        sampled_feats = sampling_4d(\n            sampling_points,\n            mlvl_feats,\n            scale_weights,\n            img_metas[0]['lidar2img'],\n            image_h, image_w\n        )  # [B, Q, G, FP, C]\n\n        return sampled_feats\n\n    def forward(self, query_feat, valid_map, occ_loc,  mlvl_feats, img_metas):\n        if self.training and query_feat.requires_grad:\n            return 
cp(self.inner_forward, query_feat, valid_map, occ_loc, mlvl_feats, img_metas, use_reentrant=False)\n        else:\n            return self.inner_forward(query_feat, valid_map, occ_loc, mlvl_feats, img_metas)\n"
  },
  {
    "path": "models/utils.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom numpy import random\nfrom mmcv.cnn.bricks import ConvTranspose3d, Conv3d\n\n\ndef conv3d_gn_relu(in_channels, out_channels, kernel_size=1, stride=1):\n    return nn.Sequential(\n        Conv3d(in_channels, out_channels, kernel_size, stride, bias=False),\n        nn.GroupNorm(16, out_channels),\n        nn.ReLU(inplace=True),\n    )\n\n\ndef deconv3d_gn_relu(in_channels, out_channels, kernel_size=2, stride=2):\n    return nn.Sequential(\n        ConvTranspose3d(in_channels, out_channels, kernel_size, stride, bias=False),\n        nn.GroupNorm(16, out_channels),\n        nn.ReLU(inplace=True),\n    )\n\n\ndef sparse2dense(indices, value, dense_shape, empty_value=0):\n    B, N = indices.shape[:2]  # [B, N, 3]\n\n    batch_index = torch.arange(B).unsqueeze(1).expand(B, N)\n    dense = torch.ones([B] + dense_shape, device=value.device, dtype=value.dtype) * empty_value\n    dense[batch_index, indices[..., 0], indices[..., 1], indices[..., 2]] = value\n    \n    mask = torch.zeros([B] + dense_shape[:3], dtype=torch.bool, device=value.device)\n    mask[batch_index, indices[..., 0], indices[..., 1], indices[..., 2]] = 1\n\n    return dense, mask\n\n\n@torch.no_grad()\ndef generate_grid(n_vox, interval):\n    # Create voxel grid\n    grid_range = [torch.arange(0, n_vox[axis], interval) for axis in range(3)]\n    grid = torch.stack(torch.meshgrid(grid_range[0], grid_range[1], grid_range[2], indexing='ij'))  # 3 dx dy dz\n    grid = grid.cuda().view(3, -1).permute(1, 0)  # N, 3\n    return grid[None]  # 1, N, 3\n\n\ndef batch_indexing(batched_data: torch.Tensor, batched_indices: torch.Tensor, layout='channel_first'):\n    def batch_indexing_channel_first(batched_data: torch.Tensor, batched_indices: torch.Tensor):\n        \"\"\"\n        :param batched_data: [batch_size, C, N]\n        :param batched_indices: [batch_size, I1, I2, ..., Im]\n        :return: indexed data: [batch_size, C, 
I1, I2, ..., Im]\n        \"\"\"\n        def product(arr):\n            p = 1\n            for i in arr:\n                p *= i\n            return p\n        assert batched_data.shape[0] == batched_indices.shape[0]\n        batch_size, n_channels = batched_data.shape[:2]\n        indices_shape = list(batched_indices.shape[1:])\n        batched_indices = batched_indices.reshape([batch_size, 1, -1])\n        batched_indices = batched_indices.expand([batch_size, n_channels, product(indices_shape)])\n        result = torch.gather(batched_data, dim=2, index=batched_indices.to(torch.int64))\n        result = result.view([batch_size, n_channels] + indices_shape)\n        return result\n\n    def batch_indexing_channel_last(batched_data: torch.Tensor, batched_indices: torch.Tensor):\n        \"\"\"\n        :param batched_data: [batch_size, N, C]\n        :param batched_indices: [batch_size, I1, I2, ..., Im]\n        :return: indexed data: [batch_size, I1, I2, ..., Im, C]\n        \"\"\"\n        assert batched_data.shape[0] == batched_indices.shape[0]\n        batch_size = batched_data.shape[0]\n        view_shape = [batch_size] + [1] * (len(batched_indices.shape) - 1)\n        expand_shape = [batch_size] + list(batched_indices.shape)[1:]\n        indices_of_batch = torch.arange(batch_size, dtype=torch.long, device=batched_data.device)\n        indices_of_batch = indices_of_batch.view(view_shape).expand(expand_shape)  # [bs, I1, I2, ..., Im]\n        if len(batched_data.shape) == 2:\n            return batched_data[indices_of_batch, batched_indices.to(torch.long)]\n        else:\n            return batched_data[indices_of_batch, batched_indices.to(torch.long), :]\n\n    if layout == 'channel_first':\n        return batch_indexing_channel_first(batched_data, batched_indices)\n    elif layout == 'channel_last':\n        return batch_indexing_channel_last(batched_data, batched_indices)\n    else:\n        raise ValueError\n\n\ndef rotation_3d_in_axis(points, angles):\n    
assert points.shape[-1] == 3\n    assert angles.shape[-1] == 1\n    angles = angles[..., 0]\n\n    n_points = points.shape[-2]\n    input_dims = angles.shape\n\n    if len(input_dims) > 1:\n        points = points.reshape(-1, n_points, 3)\n        angles = angles.reshape(-1)\n\n    rot_sin = torch.sin(angles)\n    rot_cos = torch.cos(angles)\n    ones = torch.ones_like(rot_cos)\n    zeros = torch.zeros_like(rot_cos)\n\n    rot_mat_T = torch.stack([\n        rot_cos, rot_sin, zeros,\n        -rot_sin, rot_cos, zeros,\n        zeros, zeros, ones,\n    ]).transpose(0, 1).reshape(-1, 3, 3)\n\n    points = torch.bmm(points, rot_mat_T)\n\n    if len(input_dims) > 1:\n        points = points.reshape(*input_dims, n_points, 3)\n    \n    return points\n\n\ndef inverse_sigmoid(x, eps=1e-5):\n    \"\"\"Inverse function of sigmoid.\n    Args:\n        x (Tensor): The tensor to do the\n            inverse.\n        eps (float): EPS avoid numerical\n            overflow. Defaults 1e-5.\n    Returns:\n        Tensor: The x has passed the inverse\n            function of sigmoid, has same\n            shape with input.\n    \"\"\"\n    x = x.clamp(min=0, max=1)\n    x1 = x.clamp(min=eps)\n    x2 = (1 - x).clamp(min=eps)\n    return torch.log(x1 / x2)\n\n\ndef pad_multiple(inputs, img_metas, size_divisor=32):\n    _, _, img_h, img_w = inputs.shape\n\n    pad_h = 0 if img_h % size_divisor == 0 else size_divisor - (img_h % size_divisor)\n    pad_w = 0 if img_w % size_divisor == 0 else size_divisor - (img_w % size_divisor)\n\n    B = len(img_metas)\n    N = len(img_metas[0]['ori_shape'])\n\n    for b in range(B):\n        img_metas[b]['img_shape'] = [(img_h + pad_h, img_w + pad_w, 3) for _ in range(N)]\n        img_metas[b]['pad_shape'] = [(img_h + pad_h, img_w + pad_w, 3) for _ in range(N)]\n\n    if pad_h == 0 and pad_w == 0:\n        return inputs\n    else:\n        return F.pad(inputs, [0, pad_w, 0, pad_h], value=0)\n\n\ndef rgb_to_hsv(image: torch.Tensor, eps: float = 1e-8) -> 
torch.Tensor:\n    r\"\"\"Convert an image from RGB to HSV.\n\n    .. image:: _static/img/rgb_to_hsv.png\n\n    The image data is assumed to be in the range of (0, 1).\n\n    Args:\n        image: RGB Image to be converted to HSV with shape of :math:`(*, 3, H, W)`.\n        eps: scalar to enforce numarical stability.\n\n    Returns:\n        HSV version of the image with shape of :math:`(*, 3, H, W)`.\n        The H channel values are in the range 0..2pi. S and V are in the range 0..1.\n\n    .. note::\n       See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/\n       color_conversions.html>`__.\n\n    Example:\n        >>> input = torch.rand(2, 3, 4, 5)\n        >>> output = rgb_to_hsv(input)  # 2x3x4x5\n    \"\"\"\n    if not isinstance(image, torch.Tensor):\n        raise TypeError(f\"Input type is not a torch.Tensor. Got {type(image)}\")\n\n    if len(image.shape) < 3 or image.shape[-3] != 3:\n        raise ValueError(f\"Input size must have a shape of (*, 3, H, W). Got {image.shape}\")\n\n    image = image / 255.0\n\n    max_rgb, argmax_rgb = image.max(-3)\n    min_rgb, argmin_rgb = image.min(-3)\n    deltac = max_rgb - min_rgb\n\n    v = max_rgb\n    s = deltac / (max_rgb + eps)\n\n    deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac)\n    rc, gc, bc = torch.unbind((max_rgb.unsqueeze(-3) - image), dim=-3)\n\n    h1 = bc - gc\n    h2 = (rc - bc) + 2.0 * deltac\n    h3 = (gc - rc) + 4.0 * deltac\n\n    h = torch.stack((h1, h2, h3), dim=-3) / deltac.unsqueeze(-3)\n    h = torch.gather(h, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3)\n    h = (h / 6.0) % 1.0\n\n    h = h * 360.0\n    v = v * 255.0\n\n    return torch.stack((h, s, v), dim=-3)\n\n\ndef hsv_to_rgb(image: torch.Tensor) -> torch.Tensor:\n    r\"\"\"Convert an image from HSV to RGB.\n\n    The H channel values are assumed to be in the range 0..2pi. 
def hsv_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an image from HSV to RGB.

    The H channel values are assumed to be in degrees, range 0..360;
    S is in 0..1 and V is in 0..255 (matching the output of
    ``rgb_to_hsv`` in this module).

    Args:
        image: HSV image to be converted to RGB with shape :math:`(*, 3, H, W)`.

    Returns:
        RGB version of the image with shape :math:`(*, 3, H, W)`, values
        in 0..255.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = hsv_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")

    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # Normalize H to 0..1 and V to 0..1 for the sextant computation.
    h: torch.Tensor = image[..., 0, :, :] / 360.0
    s: torch.Tensor = image[..., 1, :, :]
    v: torch.Tensor = image[..., 2, :, :] / 255.0

    hi: torch.Tensor = torch.floor(h * 6) % 6
    f: torch.Tensor = ((h * 6) % 6) - hi
    one: torch.Tensor = torch.tensor(1.0, device=image.device, dtype=image.dtype)
    p: torch.Tensor = v * (one - s)
    q: torch.Tensor = v * (one - f * s)
    t: torch.Tensor = v * (one - (one - f) * s)

    # Per-sextant channel lookup: rows 0-5 are R, 6-11 are G, 12-17 are B.
    hi = hi.long()
    indices: torch.Tensor = torch.stack([hi, hi + 6, hi + 12], dim=-3)
    out = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-3)
    out = torch.gather(out, -3, indices)
    out = out * 255.0

    return out
class GpuPhotoMetricDistortion:
    """Apply photometric distortion to image sequentially, every transformation
    is applied with a probability of 0.5. The position of random contrast is in
    second or second to last.
    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)
    8. randomly swap channels
    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    # NOTE(review): `random` here is used as `random.randint(2)` /
    # `random.permutation(3)`, which matches numpy.random rather than the
    # stdlib random module — confirm the module-level import (not visible
    # in this chunk).

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def __call__(self, imgs):
        """Apply the distortion pipeline, mostly in place, to a batch of images.

        Args:
            imgs (Tensor): batch of BGR images; indexed as (N, C, H, W)
                throughout this method. Values are treated as 0..255
                (they are fed to rgb_to_hsv, which divides by 255).
        Returns:
            Tensor: distorted images in BGR channel order.
        """
        imgs = imgs[:, [2, 1, 0], :, :]  # BGR to RGB

        # Decide per-sample whether random contrast happens before or after
        # the HSV round-trip.
        contrast_modes = []
        for _ in range(imgs.shape[0]):
            # mode == 0 --> do random contrast first
            # mode == 1 --> do random contrast last
            contrast_modes.append(random.randint(2))

        for idx in range(imgs.shape[0]):
            # random brightness (additive, applied in place)
            if random.randint(2):
                delta = random.uniform(-self.brightness_delta, self.brightness_delta)
                imgs[idx] += delta

            if contrast_modes[idx] == 0:
                if random.randint(2):
                    alpha = random.uniform(self.contrast_lower, self.contrast_upper)
                    imgs[idx] *= alpha

        # convert color from BGR to HSV
        # (imgs is RGB at this point; channel 0 becomes hue in degrees)
        imgs = rgb_to_hsv(imgs)

        for idx in range(imgs.shape[0]):
            # random saturation (channel 1 of HSV)
            if random.randint(2):
                imgs[idx, 1] *= random.uniform(self.saturation_lower, self.saturation_upper)

            # random hue (channel 0, degrees)
            if random.randint(2):
                imgs[idx, 0] += random.uniform(-self.hue_delta, self.hue_delta)

        # Single-pass wrap-around is sufficient because hue_delta is far
        # below 360, so hue can overshoot by at most one period.
        imgs[:, 0][imgs[:, 0] > 360] -= 360
        imgs[:, 0][imgs[:, 0] < 0] += 360

        # convert color from HSV to BGR
        imgs = hsv_to_rgb(imgs)

        for idx in range(imgs.shape[0]):
            # random contrast (for samples deferred to mode 1)
            if contrast_modes[idx] == 1:
                if random.randint(2):
                    alpha = random.uniform(self.contrast_lower, self.contrast_upper)
                    imgs[idx] *= alpha

            # randomly swap channels
            if random.randint(2):
                imgs[idx] = imgs[idx, random.permutation(3)]

        imgs = imgs[:, [2, 1, 0], :, :]  # RGB to BGR

        return imgs
class DumpConfig:
    """Mutable holder for debug-dump settings."""

    def __init__(self):
        # Dumping is off unless a caller flips this flag.
        self.enabled = False
        # Directory where dumps are written.
        self.out_dir = 'outputs'
        # Counters advanced by the code that performs the dumping.
        self.stage_count = 0
        self.frame_count = 0


# Module-level singleton: importers share and mutate this one instance.
DUMP = DumpConfig()
  },
  {
    "path": "old_metrics.py",
    "content": "import os\nimport glob\nimport torch\nimport argparse\nimport numpy as np\nfrom tqdm import tqdm\nfrom loaders.old_metrics import Metric_mIoU\n\n\ndef main(args):\n    pred_filepaths = sorted(glob.glob(os.path.join(args.pred_dir, '*.npz')))\n    gt_filepaths = sorted(glob.glob(os.path.join(args.data_root, 'occ3d', '*/*/*.npz')))\n\n    eval_metrics_miou = Metric_mIoU(\n        num_classes=18,\n        use_lidar_mask=False,\n        use_image_mask=True)\n\n    for pred_filepath in tqdm(pred_filepaths):\n        sample_token = os.path.basename(pred_filepath).split('.')[0]\n        for gt_filepath in gt_filepaths:\n            if sample_token in gt_filepath:\n                sem_pred = np.load(pred_filepath, allow_pickle=True)['pred']\n                sem_pred = np.reshape(sem_pred, [200, 200, 16])\n                occ_gt = np.load(gt_filepath, allow_pickle=True)\n\n                gt_semantics = occ_gt['semantics']\n                mask_lidar = occ_gt['mask_lidar'].astype(bool)\n                mask_camera = occ_gt['mask_camera'].astype(bool)\n                \n                eval_metrics_miou.add_batch(sem_pred, gt_semantics, mask_lidar, mask_camera)\n\n    eval_metrics_miou.count_miou()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-root\", type=str, default='data/nuscenes')\n    parser.add_argument(\"--pred-dir\", type=str)\n    args = parser.parse_args()\n\n    torch.random.manual_seed(0)\n    np.random.seed(0)\n\n    main(args)\n"
  },
  {
    "path": "ray_metrics.py",
    "content": "import os\nimport glob\nimport mmcv\nimport argparse\nimport numpy as np\nimport torch\nfrom torch.utils.data import DataLoader\nfrom loaders.ray_metrics import main_rayiou\nfrom loaders.ego_pose_dataset import EgoPoseDataset\nfrom configs.r50_nuimg_704x256_8f import occ_class_names as occ3d_class_names\nfrom configs.r50_nuimg_704x256_8f_openocc import occ_class_names as openocc_class_names\n\n\ndef main(args):\n    data_infos = mmcv.load(os.path.join(args.data_root, 'nuscenes_infos_val.pkl'))['infos']\n    gt_filepaths = sorted(glob.glob(os.path.join(args.data_root, args.data_type, '*/*/*.npz')))\n\n    # retrieve scene_name\n    token2scene = {}\n    for gt_path in gt_filepaths:\n        token = gt_path.split('/')[-2]\n        scene_name = gt_path.split('/')[-3]\n        token2scene[token] = scene_name\n\n    for i in range(len(data_infos)):\n        scene_name = token2scene[data_infos[i]['token']]\n        data_infos[i]['scene_name'] = scene_name\n\n    lidar_origins = []\n    occ_gts = []\n    occ_preds = []\n\n    for idx, batch in enumerate(DataLoader(EgoPoseDataset(data_infos), num_workers=8)):\n        output_origin = batch[1]\n        info = data_infos[idx]\n\n        occ_path = os.path.join(args.data_root, args.data_type, info['scene_name'], info['token'], 'labels.npz')\n        occ_gt = np.load(occ_path, allow_pickle=True)['semantics']\n        occ_gt = np.reshape(occ_gt, [200, 200, 16]).astype(np.uint8)\n\n        occ_path = os.path.join(args.pred_dir, info['token'] + '.npz')\n        occ_pred = np.load(occ_path, allow_pickle=True)['pred']\n        occ_pred = np.reshape(occ_pred, [200, 200, 16]).astype(np.uint8)\n        \n        lidar_origins.append(output_origin)\n        occ_gts.append(occ_gt)\n        occ_preds.append(occ_pred)\n    \n    if args.data_type == 'occ3d':\n        occ_class_names = occ3d_class_names\n    elif args.data_type == 'openocc_v2':\n        occ_class_names = openocc_class_names\n    else:\n        raise 
ValueError\n    \n    print(main_rayiou(occ_preds, occ_gts, lidar_origins, occ_class_names=occ_class_names))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--data-root\", type=str, default='data/nuscenes')\n    parser.add_argument(\"--pred-dir\", type=str)\n    parser.add_argument(\"--data-type\", type=str, choices=['occ3d', 'openocc_v2'], default='occ3d')\n    args = parser.parse_args()\n\n    torch.random.manual_seed(0)\n    np.random.seed(0)\n\n    main(args)\n"
  },
  {
    "path": "timing.py",
    "content": "import time\nimport utils\nimport logging\nimport argparse\nimport importlib\nimport torch\nimport torch.distributed\nimport torch.backends.cudnn as cudnn\nfrom mmcv import Config, DictAction\nfrom mmcv.parallel import MMDataParallel\nfrom mmcv.runner import load_checkpoint\nfrom mmdet.apis import set_random_seed\nfrom mmdet3d.datasets import build_dataset, build_dataloader\nfrom mmdet3d.models import build_model\n\n\ndef main():\n    parser = argparse.ArgumentParser(description='Validate a detector')\n    parser.add_argument('--config', required=True)\n    parser.add_argument('--weights', required=True)\n    parser.add_argument('--num_warmup', default=10)\n    parser.add_argument('--samples', default=200)\n    parser.add_argument('--log-interval', default=50, help='interval of logging')\n    parser.add_argument('--override', nargs='+', action=DictAction)\n    args = parser.parse_args()\n\n    # parse configs\n    cfgs = Config.fromfile(args.config)\n    if args.override is not None:\n        cfgs.merge_from_dict(args.override)\n\n    # register custom module\n    importlib.import_module('models')\n    importlib.import_module('loaders')\n\n    # MMCV, please shut up\n    from mmcv.utils.logging import logger_initialized\n    logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)\n    logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)\n    utils.init_logging(None, cfgs.debug)\n\n    # you need GPUs\n    assert torch.cuda.is_available() and torch.cuda.device_count() == 1\n    logging.info('Using GPU: %s' % torch.cuda.get_device_name(0))\n    torch.cuda.set_device(0)\n\n    logging.info('Setting random seed: 0')\n    set_random_seed(0, deterministic=True)\n    cudnn.benchmark = True\n\n    logging.info('Loading validation set from %s' % cfgs.data.val.data_root)\n    val_dataset = build_dataset(cfgs.data.val)\n    val_loader = build_dataloader(\n        val_dataset,\n        samples_per_gpu=1,\n        
workers_per_gpu=cfgs.data.workers_per_gpu,\n        num_gpus=1,\n        dist=False,\n        shuffle=False,\n        seed=0,\n    )\n\n    logging.info('Creating model: %s' % cfgs.model.type)\n    model = build_model(cfgs.model)\n    model.cuda()\n\n    assert torch.cuda.device_count() == 1\n    model = MMDataParallel(model, [0])\n\n    logging.info('Loading checkpoint from %s' % args.weights)\n    load_checkpoint(\n        model, args.weights, map_location='cuda', strict=False,\n        logger=logging.Logger(__name__, logging.ERROR)\n    )\n    model.eval()\n\n    print('Timing w/ data loading:')\n    pure_inf_time = 0\n    with torch.no_grad():\n        for i, data in enumerate(val_loader):\n            torch.cuda.synchronize()\n            start_time = time.perf_counter()\n\n            model(return_loss=False, rescale=True, **data)\n\n            torch.cuda.synchronize()\n            elapsed = time.perf_counter() - start_time\n\n            if i >= args.num_warmup:\n                pure_inf_time += elapsed\n                if (i + 1) % args.log_interval == 0:\n                    fps = (i + 1 - args.num_warmup) / pure_inf_time\n                    print(f'Done sample [{i + 1:<3}/ {args.samples}], '\n                        f'fps: {fps:.1f} sample / s')\n\n            if (i + 1) == args.samples:\n                break\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "train.py",
import os
import utils
import shutil
import logging
import argparse
import importlib
import torch
import torch.distributed as dist
from datetime import datetime
from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import EpochBasedRunner, build_optimizer, load_checkpoint
from mmdet.apis import set_random_seed
from mmdet.core import DistEvalHook, EvalHook
from mmdet3d.datasets import build_dataset
from mmdet3d.models import build_model
from loaders.builder import build_dataloader


def main():
    """Single entry point for (optionally distributed) training.

    Parses CLI args and the mmcv config, sets up the work dir and logging
    on rank 0, builds datasets/model/optimizer, registers runner hooks,
    and launches an EpochBasedRunner.
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('--config', required=True)
    parser.add_argument('--run_name', required=False, default='')
    parser.add_argument('--override', nargs='+', action=DictAction)
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    args = parser.parse_args()

    # parse configs
    cfgs = Config.fromfile(args.config)
    if args.override is not None:
        cfgs.merge_from_dict(args.override)

    # register custom module
    importlib.import_module('models')
    importlib.import_module('loaders')

    # MMCV, please shut up
    from mmcv.utils.logging import logger_initialized
    logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmdet3d'] = logging.Logger(__name__, logging.WARNING)

    # you need GPUs
    assert torch.cuda.is_available()

    # determine local_rank and world_size
    # (CLI values are fallbacks; env vars set by the launcher take priority)
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    
    if 'WORLD_SIZE' not in os.environ:
        os.environ['WORLD_SIZE'] = str(args.world_size)

    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    if local_rank == 0:
        # resume or start a new run
        if cfgs.resume_from is not None:
            assert os.path.isfile(cfgs.resume_from)
            # Reuse the directory the checkpoint lives in as work_dir.
            work_dir = os.path.dirname(cfgs.resume_from)
        else:
            run_name = args.run_name
            if not cfgs.debug and run_name == '':
                run_name = input('Name your run (leave blank for default): ')
            if run_name == '':
                run_name = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")

            work_dir = os.path.join('outputs', cfgs.model.type, run_name)
            if os.path.exists(work_dir):  # must be an empty dir
                if input('Path "%s" already exists, overwrite it? [Y/n] ' % work_dir) == 'n':
                    print('Bye.')
                    exit(0)
                shutil.rmtree(work_dir)

            os.makedirs(work_dir, exist_ok=False)

        # init logging, backup code
        utils.init_logging(os.path.join(work_dir, 'train.log'), cfgs.debug)
        utils.backup_code(work_dir)
        logging.info('Logs will be saved to %s' % work_dir)

    else:
        # disable logging on other workers
        logging.root.disabled = True
        work_dir = '/tmp'

    logging.info('Using GPU: %s' % torch.cuda.get_device_name(local_rank))
    torch.cuda.set_device(local_rank)

    if world_size > 1:
        logging.info('Initializing DDP with %d GPUs...' % world_size)
        dist.init_process_group('nccl', init_method='env://')

    logging.info('Setting random seed: 0')
    set_random_seed(0, deterministic=True)

    logging.info('Loading training set from %s' % cfgs.dataset_root)
    train_dataset = build_dataset(cfgs.data.train)
    # Global batch size in the config is divided evenly across GPUs.
    train_loader = build_dataloader(
        train_dataset,
        samples_per_gpu=cfgs.batch_size // world_size,
        workers_per_gpu=cfgs.data.workers_per_gpu,
        num_gpus=world_size,
        dist=world_size > 1,
        shuffle=True,
        seed=0,
    )

    logging.info('Loading validation set from %s' % cfgs.dataset_root)
    val_dataset = build_dataset(cfgs.data.val)
    val_loader = build_dataloader(
        val_dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfgs.data.workers_per_gpu,
        num_gpus=world_size,
        dist=world_size > 1,
        shuffle=False
    )

    logging.info('Creating model: %s' % cfgs.model.type)
    model = build_model(cfgs.model)
    model.init_weights()
    model.cuda()
    model.train()

    n_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
    logging.info('Trainable parameters: %d (%.1fM)' % (n_params, n_params / 1e6))
    logging.info('Batch size per GPU: %d' % (cfgs.batch_size // world_size))

    if world_size > 1:
        model = MMDistributedDataParallel(model, [local_rank], broadcast_buffers=False)
    else:
        model = MMDataParallel(model, [0])

    logging.info('Creating optimizer: %s' % cfgs.optimizer.type)
    optimizer = build_optimizer(model, cfgs.optimizer)

    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=work_dir,
        logger=logging.root,
        max_epochs=cfgs.total_epochs,
        meta=dict(),
    )

    # Hook registration order follows mmcv's conventional setup.
    runner.register_lr_hook(cfgs.lr_config)
    runner.register_optimizer_hook(cfgs.optimizer_config)
    runner.register_checkpoint_hook(cfgs.checkpoint_config)
    runner.register_logger_hooks(cfgs.log_config)
    runner.register_timer_hook(dict(type='IterTimerHook'))
    runner.register_custom_hooks(dict(type='DistSamplerSeedHook'))

    # interval <= 0 disables evaluation entirely.
    if cfgs.eval_config['interval'] > 0:
        if world_size > 1:
            runner.register_hook(DistEvalHook(val_loader, interval=cfgs.eval_config['interval'], gpu_collect=True))
        else:
            runner.register_hook(EvalHook(val_loader, interval=cfgs.eval_config['interval']))

    if cfgs.resume_from is not None:
        logging.info('Resuming from %s' % cfgs.resume_from)
        runner.resume(cfgs.resume_from)

    elif cfgs.load_from is not None:
        logging.info('Loading checkpoint from %s' % cfgs.load_from)
        if cfgs.revise_keys is not None:
            load_checkpoint(
                model, cfgs.load_from, map_location='cpu',
                revise_keys=cfgs.revise_keys
            )
        else:
            load_checkpoint(
                model, cfgs.load_from, map_location='cpu',
            )

    runner.run([train_loader], [('train', 1)])


if __name__ == '__main__':
    main()
  },
  {
    "path": "utils.py",
    "content": "import os\nimport sys\nimport glob\nimport torch\nimport shutil\nimport logging\nimport datetime\nimport socket\nimport wandb\nfrom mmcv.runner.hooks import HOOKS\nfrom mmcv.runner.hooks.logger import LoggerHook, TextLoggerHook\nfrom mmcv.runner.dist_utils import master_only\nfrom torch.utils.tensorboard import SummaryWriter\n\n\ndef init_logging(filename=None, debug=False):\n    logging.root = logging.RootLogger('DEBUG' if debug else 'INFO')\n    formatter = logging.Formatter('[%(asctime)s][%(levelname)s] - %(message)s')\n\n    stream_handler = logging.StreamHandler(sys.stdout)\n    stream_handler.setFormatter(formatter)\n    logging.root.addHandler(stream_handler)\n\n    if filename is not None:\n        file_handler = logging.FileHandler(filename)\n        file_handler.setFormatter(formatter)\n        logging.root.addHandler(file_handler)\n\n\ndef backup_code(work_dir, verbose=False):\n    base_dir = os.path.dirname(os.path.abspath(__file__))\n    for pattern in ['*.py', 'configs/*.py', 'models/*.py', 'loaders/*.py', 'loaders/pipelines/*.py']:\n        for file in glob.glob(pattern):\n            src = os.path.join(base_dir, file)\n            dst = os.path.join(work_dir, 'backup', os.path.dirname(file))\n\n            if verbose:\n                logging.info('Copying %s -> %s' % (os.path.relpath(src), os.path.relpath(dst)))\n            \n            os.makedirs(dst, exist_ok=True)\n            shutil.copy2(src, dst)\n\n\n@HOOKS.register_module()\nclass MyTextLoggerHook(TextLoggerHook):\n    def _log_info(self, log_dict, runner):\n        # print exp name for users to distinguish experiments\n        # at every ``interval_exp_name`` iterations and the end of each epoch\n        if runner.meta is not None and 'exp_name' in runner.meta:\n            if (self.every_n_iters(runner, self.interval_exp_name)) or (\n                    self.by_epoch and self.end_of_epoch(runner)):\n                exp_info = f'Exp name: {runner.meta[\"exp_name\"]}'\n   
@HOOKS.register_module()
class MyTextLoggerHook(TextLoggerHook):
    """TextLoggerHook variant with compact console output.

    Differences from the stock hook (marked MOD below): the first
    iteration is excluded from the ETA average, JSON log-file dumping is
    disabled, and evaluation results are printed after each train epoch.
    """

    def _log_info(self, log_dict, runner):
        """Format one progress line (epoch/iter, loss, eta, time, mem) and log it."""
        # print exp name for users to distinguish experiments
        # at every ``interval_exp_name`` iterations and the end of each epoch
        if runner.meta is not None and 'exp_name' in runner.meta:
            if (self.every_n_iters(runner, self.interval_exp_name)) or (
                    self.by_epoch and self.end_of_epoch(runner)):
                exp_info = f'Exp name: {runner.meta["exp_name"]}'
                runner.logger.info(exp_info)

        # by epoch: Epoch [4][100/1000]
        # by iter:  Iter [100/100000]
        if self.by_epoch:
            log_str = f'Epoch [{log_dict["epoch"]}/{runner.max_epochs}]' \
                        f'[{log_dict["iter"]}/{len(runner.data_loader)}] '
        else:
            log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}] '

        log_str += 'loss: %.2f, ' % log_dict['loss']

        if 'time' in log_dict.keys():
            # MOD: skip the first iteration since it's not accurate
            if runner.iter == self.start_iter:
                time_sec_avg = log_dict['time']
            else:
                self.time_sec_tot += (log_dict['time'] * self.interval)
                time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter)

            # Remaining-iterations estimate from the running time average.
            eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
            log_str += f'eta: {eta_str}, '
            log_str += f'time: {log_dict["time"]:.2f}s, ' \
                        f'data: {log_dict["data_time"] * 1000:.0f}ms, '
            # statistic memory
            if torch.cuda.is_available():
                log_str += f'mem: {log_dict["memory"]}M'

        runner.logger.info(log_str)

    def log(self, runner):
        """Collect loggable state into a dict and emit it via _log_info."""
        if 'eval_iter_num' in runner.log_buffer.output:
            # this doesn't modify runner.iter and is regardless of by_epoch
            cur_iter = runner.log_buffer.output.pop('eval_iter_num')
        else:
            cur_iter = self.get_iter(runner, inner_iter=True)

        log_dict = {
            'mode': self.get_mode(runner),
            'epoch': self.get_epoch(runner),
            'iter': cur_iter
        }

        # only record lr of the first param group
        cur_lr = runner.current_lr()
        if isinstance(cur_lr, list):
            log_dict['lr'] = cur_lr[0]
        else:
            assert isinstance(cur_lr, dict)
            log_dict['lr'] = {}
            for k, lr_ in cur_lr.items():
                assert isinstance(lr_, list)
                log_dict['lr'].update({k: lr_[0]})

        if 'time' in runner.log_buffer.output:
            # statistic memory
            if torch.cuda.is_available():
                log_dict['memory'] = self._get_max_memory(runner)

        log_dict = dict(log_dict, **runner.log_buffer.output)

        # MOD: disable writing to files
        # self._dump_log(log_dict, runner)
        self._log_info(log_dict, runner)

        return log_dict

    def after_train_epoch(self, runner):
        """Print evaluation results gathered during the epoch, if any."""
        if 'eval_iter_num' in runner.log_buffer.output:
            runner.log_buffer.output.pop('eval_iter_num')

        if runner.log_buffer.ready:
            metrics = self.get_loggable_tags(runner)
            runner.logger.info('--- Evaluation Results ---')
            # NOTE(review): assumes the evaluator always produced a
            # 'val/RayIoU' tag; raises KeyError otherwise — confirm.
            runner.logger.info('RayIoU: %.4f' % metrics['val/RayIoU'])
@HOOKS.register_module()
class MyTensorboardLoggerHook(LoggerHook):
    """LoggerHook that writes a filtered set of scalars to TensorBoard.

    Compared to mmcv's stock TensorBoard hook: the learning rate is
    folded into the 'train' group, momentum and the per-decoder
    intermediate losses ('train/d0.loss'..'train/d4.loss' prefixes) are
    dropped, and train/val scalars are stepped by iteration and epoch
    respectively.
    """

    def __init__(self, log_dir=None, interval=10, ignore_last=True, reset_flag=False, by_epoch=True):
        super(MyTensorboardLoggerHook, self).__init__(
            interval, ignore_last, reset_flag, by_epoch)
        self.log_dir = log_dir

    @master_only
    def before_run(self, runner):
        """Open the SummaryWriter, defaulting to the runner's work dir."""
        super(MyTensorboardLoggerHook, self).before_run(runner)
        if self.log_dir is None:
            self.log_dir = runner.work_dir
        self.writer = SummaryWriter(self.log_dir)

    @master_only
    def log(self, runner):
        """Write the current step's loggable tags, skipping filtered keys."""
        mode = self.get_mode(runner)
        # Prefixes of per-decoder intermediate losses that we never plot.
        blocked = tuple('train/d%d.loss' % i for i in range(5))

        for key, value in self.get_loggable_tags(runner).items():
            if key == 'learning_rate':
                # merge into the 'train' group
                key = 'train/learning_rate'

            skip = (
                key == 'momentum'
                or key.startswith(blocked)
                or (mode == 'train' and not key.startswith('train'))
                or (mode != 'train' and not key.startswith('val'))
            )
            if skip:
                continue

            if key.startswith('train'):
                self.writer.add_scalar(key, value, self.get_iter(runner))
            elif key.startswith('val'):
                self.writer.add_scalar(key, value, self.get_epoch(runner))

    @master_only
    def after_run(self, runner):
        self.writer.close()
# modified from mmcv.runner.hooks.logger.wandb
@HOOKS.register_module()
class MyWandbLoggerHook(LoggerHook):
    """Class to log metrics with wandb.

    It requires `wandb`_ to be installed.


    Args:
        log_dir (str): directory for saving logs
            Default None.
        project_name (str): name for your project (mainly used to specify saving path on wandb server)
            Default None.
        team_name (str): name for your team (mainly used to specify saving path on wandb server)
            Default None.
        experiment_name (str): name for your run, if not specified, use the last part of log_dir
            Default None.
        interval (int): Logging interval (every k iterations).
            Default 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`.
            Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: False.
        commit (bool): Save the metrics dict to the wandb server and increment
            the step. If false ``wandb.log`` just updates the current metrics
            dict with the row argument and metrics won't be saved until
            ``wandb.log`` is called with ``commit=True``.
            Default: True.
        by_epoch (bool): Whether EpochBasedRunner is used.
            Default: True.
        with_step (bool): If True, the step will be logged from
            ``self.get_iters``. Otherwise, step will not be logged.
            Default: True.
        out_suffix (str or tuple[str], optional): Those filenames ending with
            ``out_suffix`` will be uploaded to wandb.
            Default: ('.log.json', '.log', '.py').
            `New in version 1.4.3.`

    .. _wandb:
        https://docs.wandb.ai
    """
    def __init__(self, log_dir=None, project_name=None, team_name=None, experiment_name=None, 
                 interval=10, ignore_last=True, reset_flag=False, by_epoch=True, commit=True, 
                 with_step=True, out_suffix = ('.log.json', '.log', '.py')):
        
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        self.import_wandb()
        self.commit = commit
        self.with_step = with_step
        self.out_suffix = out_suffix
        
        self.log_dir = log_dir
        self.project_name = project_name
        self.team_name = team_name
        self.experiment_name = experiment_name
        # Toggle wandb's global online/offline mode via its CLI.
        if commit:
            os.system('wandb online')
        else:
            os.system('wandb offline')
            
    def import_wandb(self) -> None:
        # Lazily import wandb and keep a handle on the instance.
        try:
            import wandb
        except ImportError:
            raise ImportError(
                'Please run "pip install wandb" to install wandb')
        self.wandb = wandb
        
    @master_only
    def before_run(self, runner) -> None:
        """Initialize the wandb run (rank 0 only)."""
        super().before_run(runner)
        if self.log_dir is None:
            self.log_dir = runner.work_dir
        if self.experiment_name is None:
            self.experiment_name = os.path.basename(self.log_dir)
        init_kwargs = dict(
            project=self.project_name,
            entity=self.team_name,
            notes=socket.gethostname(),
            name=self.experiment_name,
            dir=self.log_dir,
            reinit=True
        )
            
        if self.wandb is None:
            self.import_wandb()
        # NOTE(review): init_kwargs is a non-empty dict literal, so this
        # condition is always true and the else branch is unreachable.
        if init_kwargs:
            self.wandb.init(**init_kwargs)  # type: ignore
        else:
            self.wandb.init()  # type: ignore
    
    @master_only
    def log(self, runner) -> None:
        """Send the current step's tags for the active mode to wandb."""
        tags = self.get_loggable_tags(runner)
        mode = self.get_mode(runner)
        if not tags:
            return
        # Fold the learning rate into the 'train' group and drop momentum.
        if 'learning_rate' in tags.keys():
            tags['train/learning_rate'] = tags['learning_rate']
            del tags['learning_rate']
        if 'momentum' in tags.keys():
            del tags['momentum']
        # Only forward tags belonging to the current mode ('train'/'val').
        tags = {k: v for k, v in tags.items() if k.startswith(mode)}
        
        if self.with_step:
            self.wandb.log(
                tags, step=self.get_iter(runner), commit=self.commit)
        else:
            tags['global_step'] = self.get_iter(runner)
            self.wandb.log(tags, commit=self.commit)

    @master_only
    def after_run(self, runner) -> None:
        # NOTE(review): wandb.join() is deprecated in recent wandb in
        # favor of wandb.finish(); left as-is for compatibility.
        self.wandb.join()
  },
  {
    "path": "val.py",
    "content": "import os\nimport utils\nimport logging\nimport argparse\nimport importlib\nimport torch\nimport torch.distributed\nimport torch.distributed as dist\nimport torch.backends.cudnn as cudnn\nfrom mmcv import Config\nfrom mmcv.parallel import MMDataParallel, MMDistributedDataParallel\nfrom mmcv.runner import load_checkpoint\nfrom mmdet.apis import set_random_seed, multi_gpu_test, single_gpu_test\nfrom mmdet3d.datasets import build_dataset, build_dataloader\nfrom mmdet3d.models import build_model\n\n\ndef evaluate(dataset, results):\n    metrics = dataset.evaluate(results, jsonfile_prefix=None)\n\n    logging.info('--- Evaluation Results ---')\n    for k, v in metrics.items():\n        logging.info('%s: %.4f' % (k, v))\n\n    return metrics\n\n\ndef main():\n    parser = argparse.ArgumentParser(description='Validate a detector')\n    parser.add_argument('--config', required=True)\n    parser.add_argument('--weights', required=True)\n    parser.add_argument('--local_rank', type=int, default=0)\n    parser.add_argument('--world_size', type=int, default=1)\n    parser.add_argument('--batch_size', type=int, default=1)\n    args = parser.parse_args()\n\n    # parse configs\n    cfgs = Config.fromfile(args.config)\n\n    # register custom module\n    importlib.import_module('models')\n    importlib.import_module('loaders')\n\n    # MMCV, please shut up\n    from mmcv.utils.logging import logger_initialized\n    logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)\n    logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)\n\n    # you need GPUs\n    assert torch.cuda.is_available()\n\n    # determine local_rank and world_size\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n    \n    if 'WORLD_SIZE' not in os.environ:\n        os.environ['WORLD_SIZE'] = str(args.world_size)\n\n    local_rank = int(os.environ['LOCAL_RANK'])\n    world_size = int(os.environ['WORLD_SIZE'])\n\n    
if local_rank == 0:\n        utils.init_logging(None, cfgs.debug)\n    else:\n        logging.root.disabled = True\n\n    logging.info('Using GPU: %s' % torch.cuda.get_device_name(local_rank))\n    torch.cuda.set_device(local_rank)\n\n    if world_size > 1:\n        logging.info('Initializing DDP with %d GPUs...' % world_size)\n        dist.init_process_group('nccl', init_method='env://')\n\n    logging.info('Setting random seed: 0')\n    set_random_seed(0, deterministic=True)\n    cudnn.benchmark = True\n\n    logging.info('Loading validation set from %s' % cfgs.data.val.data_root)\n    val_dataset = build_dataset(cfgs.data.val)\n    val_loader = build_dataloader(\n        val_dataset,\n        samples_per_gpu=args.batch_size,\n        workers_per_gpu=cfgs.data.workers_per_gpu,\n        num_gpus=world_size,\n        dist=world_size > 1,\n        shuffle=False,\n        seed=0,\n    )\n\n    logging.info('Creating model: %s' % cfgs.model.type)\n    model = build_model(cfgs.model)\n    model.cuda()\n\n    if world_size > 1:\n        model = MMDistributedDataParallel(model, [local_rank], broadcast_buffers=False)\n    else:\n        model = MMDataParallel(model, [0])\n\n    if os.path.isfile(args.weights):\n        logging.info('Loading checkpoint from %s' % args.weights)\n        load_checkpoint(\n            model, args.weights, map_location='cuda', strict=True,\n            logger=logging.Logger(__name__, logging.ERROR)\n        )\n\n    if world_size > 1:\n        results = multi_gpu_test(model, val_loader, gpu_collect=True)\n    else:\n        results = single_gpu_test(model, val_loader)\n\n    if local_rank == 0:\n        evaluate(val_dataset, results)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "viz_prediction.py",
    "content": "import os\nimport cv2\nimport utils\nimport logging\nimport argparse\nimport importlib\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom mmcv import Config, DictAction\nfrom mmdet.apis import set_random_seed\nfrom mmdet3d.datasets import build_dataset, build_dataloader\nfrom configs.r50_nuimg_704x256_8f import point_cloud_range as pc_range\nfrom configs.r50_nuimg_704x256_8f import occ_size\nfrom configs.r50_nuimg_704x256_8f import occ_class_names\nfrom mmcv.parallel import MMDataParallel\nfrom mmcv.runner import load_checkpoint\nfrom mmdet3d.models import build_model\n\n\ncolor_map = np.array([\n    [0, 0, 0, 255],    # others\n    [255, 120, 50, 255],  # barrier              orangey\n    [255, 192, 203, 255],  # bicycle              pink\n    [255, 255, 0, 255],  # bus                  yellow\n    [0, 150, 245, 255],  # car                  blue\n    [0, 255, 255, 255],  # construction_vehicle cyan\n    [200, 180, 0, 255],  # motorcycle           dark orange\n    [255, 0, 0, 255],  # pedestrian           red\n    [255, 240, 150, 255],  # traffic_cone         light yellow\n    [135, 60, 0, 255],  # trailer              brown\n    [160, 32, 240, 255],  # truck                purple\n    [255, 0, 255, 255],  # driveable_surface    dark pink\n    [175,   0,  75, 255],       # other_flat           dark red\n    [75, 0, 75, 255],  # sidewalk             dard purple\n    [150, 240, 80, 255],  # terrain              light green\n    [230, 230, 250, 255],  # manmade              white\n    [0, 175, 0, 255],  # vegetation           green\n    [255, 255, 255, 255],  # free             white\n], dtype=np.uint8)\n\ndef occ2img(semantics):\n    H, W, D = semantics.shape\n\n    free_id = len(occ_class_names) - 1\n    semantics_2d = np.ones([H, W], dtype=np.int32) * free_id\n\n    for i in range(D):\n        semantics_i = semantics[..., i]\n        non_free_mask = (semantics_i != free_id)\n        semantics_2d[non_free_mask] = 
semantics_i[non_free_mask]\n\n    viz = color_map[semantics_2d]\n    viz = viz[..., :3]\n    viz = cv2.resize(viz, dsize=(800, 800))\n\n    return viz\n\ndef main():\n    parser = argparse.ArgumentParser(description='Validate a detector')\n    parser.add_argument('--config', required=True)\n    parser.add_argument('--weights', required=True)\n    parser.add_argument('--viz-dir', required=True)\n    parser.add_argument('--override', nargs='+', action=DictAction)\n    args = parser.parse_args()\n\n    # parse configs\n    cfgs = Config.fromfile(args.config)\n    if args.override is not None:\n        cfgs.merge_from_dict(args.override)\n\n    # use val-mini for visualization\n    #cfgs.data.val.ann_file = cfgs.data.val.ann_file.replace('val', 'val_mini')\n\n    # register custom module\n    importlib.import_module('models')\n    importlib.import_module('loaders')\n\n    # MMCV, please shut up\n    from mmcv.utils.logging import logger_initialized\n    logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)\n    logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)\n\n    # you need one GPU\n    assert torch.cuda.is_available()\n    assert torch.cuda.device_count() == 1\n\n    # logging\n    utils.init_logging(None, cfgs.debug)\n    logging.info('Using GPU: %s' % torch.cuda.get_device_name(0))\n\n    # random seed\n    logging.info('Setting random seed: 0')\n    set_random_seed(0, deterministic=True)\n\n    logging.info('Loading validation set from %s' % cfgs.data.val.data_root)\n    val_dataset = build_dataset(cfgs.data.val)\n    val_loader = build_dataloader(\n        val_dataset,\n        samples_per_gpu=1,\n        workers_per_gpu=cfgs.data.workers_per_gpu,\n        num_gpus=1,\n        dist=False,\n        shuffle=False,\n        seed=0,\n    )\n\n    logging.info('Creating model: %s' % cfgs.model.type)\n    model = build_model(cfgs.model)\n    model.cuda()\n    model = MMDataParallel(model, [0])\n    model.eval()\n\n    
logging.info('Loading checkpoint from %s' % args.weights)\n    load_checkpoint(\n        model, args.weights, map_location='cuda', strict=True,\n        logger=logging.Logger(__name__, logging.ERROR)\n    )\n\n    for i, data in tqdm(enumerate(val_loader)):\n\n        #print(data['img_metas'].data[0][0]['filename'][:6])\n\n        with torch.no_grad():\n            occ_pred = model(return_loss=False, rescale=True, **data)[0]\n\n            sem_pred = torch.from_numpy(occ_pred['sem_pred'])[0]  # [N]\n            occ_loc = torch.from_numpy(occ_pred['occ_loc'].astype(np.int64))[0]  # [N, 3]\n            \n            # sparse to dense\n            free_id = len(occ_class_names) - 1\n            dense_pred = torch.ones(occ_size, device=sem_pred.device, dtype=sem_pred.dtype) * free_id  # [200, 200, 16]\n            dense_pred[occ_loc[..., 0], occ_loc[..., 1], occ_loc[..., 2]] = sem_pred\n            \n            sem_pred = dense_pred.numpy()\n\n            cv2.imwrite(os.path.join(args.viz_dir, 'sem_%04d.jpg' % i), occ2img(sem_pred)[..., ::-1])\n\n\nif __name__ == '__main__':\n    main()\n"
  }
]