Repository: MCG-NJU/SparseOcc Branch: main Commit: af4d9df83bfd Files: 58 Total size: 409.0 KB Directory structure: gitextract_8j2gi4im/ ├── .gitignore ├── LICENSE ├── README.md ├── configs/ │ ├── r50_nuimg_704x256_8f.py │ ├── r50_nuimg_704x256_8f_60e.py │ ├── r50_nuimg_704x256_8f_openocc.py │ └── r50_nuimg_704x256_8f_pano.py ├── gen_instance_info.py ├── gen_sweep_info.py ├── lib/ │ └── dvr/ │ ├── dvr.cpp │ └── dvr.cu ├── loaders/ │ ├── __init__.py │ ├── builder.py │ ├── ego_pose_dataset.py │ ├── nuscenes_dataset.py │ ├── nuscenes_occ_dataset.py │ ├── old_metrics.py │ ├── pipelines/ │ │ ├── __init__.py │ │ ├── loading.py │ │ └── transforms.py │ ├── ray_metrics.py │ └── ray_pq.py ├── models/ │ ├── __init__.py │ ├── backbones/ │ │ ├── __init__.py │ │ └── vovnet.py │ ├── bbox/ │ │ ├── __init__.py │ │ ├── assigners/ │ │ │ ├── __init__.py │ │ │ └── hungarian_assigner_3d.py │ │ ├── coders/ │ │ │ ├── __init__.py │ │ │ └── nms_free_coder.py │ │ ├── match_costs/ │ │ │ ├── __init__.py │ │ │ └── match_cost.py │ │ └── utils.py │ ├── checkpoint.py │ ├── csrc/ │ │ ├── __init__.py │ │ ├── msmv_sampling/ │ │ │ ├── msmv_sampling.cpp │ │ │ ├── msmv_sampling.h │ │ │ ├── msmv_sampling_backward.cu │ │ │ └── msmv_sampling_forward.cu │ │ ├── setup.py │ │ └── wrapper.py │ ├── loss_utils.py │ ├── matcher.py │ ├── sparse_voxel_decoder.py │ ├── sparsebev_head.py │ ├── sparsebev_sampling.py │ ├── sparsebev_transformer.py │ ├── sparseocc.py │ ├── sparseocc_head.py │ ├── sparseocc_transformer.py │ └── utils.py ├── old_metrics.py ├── ray_metrics.py ├── timing.py ├── train.py ├── utils.py ├── val.py └── viz_prediction.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ # OS generated files .DS_Store .DS_Store? 
._* .Spotlight-V100 .Trashes ehthumbs.db Thumbs.db # Compiled source build debug Debug release Release x64 *.so *.whl # VS project files *.sln *.vcxproj *.vcxproj.filters *.vcxproj.user *.rc .vs # Byte-compiled / optimized / DLL files *__pycache__* *.py[cod] *$py.class # Distribution / packaging .Python build develop-eggs dist downloads # IDE .idea .vscode pyrightconfig.json # Custom data outputs prediction submission checkpoints pretrain ckpts occ_result wandb ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # SparseOcc This is the official PyTorch implementation for our paper: > [**Fully Sparse 3D Panoptic Occupancy Prediction**](https://arxiv.org/abs/2312.17118)
> :school: Presented by Nanjing University and Shanghai AI Lab
> :email: Primary contact: Haisong Liu (afterthat97@gmail.com)
> :trophy: [CVPR 2024 Autonomous Driving Challenge - Occupancy and Flow](https://opendrivelab.com/challenge2024/#occupancy_and_flow)
> :book: 中文解读(官方):https://zhuanlan.zhihu.com/p/709576252
> :book: 中文解读(第三方): [AIming](https://zhuanlan.zhihu.com/p/691549750), [自动驾驶之心](https://zhuanlan.zhihu.com/p/675811281) ## :warning: Important Notes There is another concurrent project titled *''SparseOcc: Rethinking sparse latent representation for vision-based semantic occupancy prediction''* by Tang et al., which shares the same name SparseOcc with ours. However, this repository is **unrelated** to the aforementioned paper. If you cite our research, please ensure that you reference the correct version (arXiv **2312.17118**, authored by **Liu et al.**): ``` @article{liu2023fully, title={Fully sparse 3d panoptic occupancy prediction}, author={Liu, Haisong and Wang, Haiguang and Chen, Yang and Yang, Zetong and Zeng, Jia and Chen, Li and Wang, Limin}, journal={arXiv preprint arXiv:2312.17118}, year={2023} } ``` > In arXiv 2312.17118v3, we removed the word "panoptic" from the title. However, Google Scholar's database has not been updated and still shows the old one. Therefore, we still recommend citing the old title - "Fully sparse 3d panoptic occupancy prediction" - so that Google Scholar can index it correctly. Thank you all. ## News * **2024-07-19**: We released an updated version of SparseOcc on [arXiv](https://arxiv.org/abs/2312.17118). All charts and colors have been carefully adjusted. Delete the old version and download the new one! * **2024-07-01**: SparseOcc is accepted to ECCV 2024. * **2024-06-27**: SparseOcc v1.1 is released. In this change, we introduce BEV data augmentation (BDA) and Lovasz-Softmax loss to further enhance the performance. Compared with [v1.0](https://github.com/MCG-NJU/SparseOcc/tree/v1.0) (35.0 RayIoU with 48 epochs), SparseOcc v1.1 can achieve 36.8 RayIoU with 24 epochs! * **2024-05-29**: We add support for [OpenOcc v2](configs/r50_nuimg_704x256_8f_openocc.py) dataset (without occupancy flow). * **2024-04-11**: The panoptic version of SparseOcc ([configs/r50_nuimg_704x256_8f_pano.py](configs/r50_nuimg_704x256_8f_pano.py)) is released. 
* **2024-04-09**: An updated arXiv version [https://arxiv.org/abs/2312.17118v3](https://arxiv.org/abs/2312.17118v3) has been released. * **2024-03-31**: We release the code and pretrained weights. * **2023-12-30**: We release the paper. ## Highlights **New model**:1st_place_medal:: SparseOcc initially reconstructs a sparse 3D representation from visual inputs and subsequently predicts semantic/instance occupancy from the 3D sparse representation by sparse queries. ![](asserts/arch.jpg) **New evaluation metric**:chart_with_upwards_trend:: We design a thoughtful ray-based evaluation metric, namely RayIoU, to solve the inconsistency penalty along depths raised in traditional voxel-level mIoU criteria. ![](asserts/rayiou.jpg) Some FAQs from the community about the evaluation metrics: 1. **Why does training with visible masks result in significant improvements in the old mIoU metric, but not in the new RayIoU metric?** As mentioned in the paper, when using the visible mask during training, the area behind the surface won't be supervised, so the model tends to fill this area with duplicated predictions, leading to a thicker surface. The old metric inconsistently penalizes along the depth axis when the prediction has a thick surface. Thus, this ''imporovement'' is mainly due to the vulnerability of old metric. 2. **Why SparseOcc cannot exploit the vulnerability of the old metrics?** As SparseOcc employs a fully sparse architecture, it always predicts a thin surface. Thus, there are two ways for a fair comparison: (a) use the old metric, but all methods must predict a thin surface, which implies they cannot use the visible mask during training; (b) use RayIoU, as it is more reasonable and can fairly compare thick or thin surface. Our method achieves SOTA performance on both cases. 3. **Does RayIoU overlook interior reconstruction?** Firstly, we are unable to obtain the interior occupancy ground-truth. 
This is because the ground-truth is derived from voxelizing LiDAR point clouds, and LiDARs are only capable of scanning the thin surface of an object. Secondly, the query ray in RayIoU can originate from any position within the scene (see the figure above). This allows it to evaluate the overall reconstruction performance, unlike depth estimation. We would like to emphasize that the evaluation logic of RayIoU aligns with the process of ground-truth generation. If you have other questions, feel free to contact me (Haisong Liu, afterthat97@gmail.com). ## Model Zoo These results are from our latest version, v1.1, which outperforms the results reported in the paper. Additionally, our implementation differs slightly from the original paper. If you wish to reproduce the paper exactly, please refer to the [v1.0](https://github.com/MCG-NJU/SparseOcc/tree/v1.0) tag. | Setting | Epochs | Training Cost | RayIoU | RayPQ | FPS | Weights | |----------|:--------:|:-------------:|:------:|:-----:|:---:|:-------:| | [r50_nuimg_704x256_8f](configs/r50_nuimg_704x256_8f.py) | 24 | 15h, ~12GB | 36.8 | - | 17.3 | [github](https://github.com/MCG-NJU/SparseOcc/releases/download/v1.1/sparseocc_r50_nuimg_704x256_8f_24e_v1.1.pth) | | [r50_nuimg_704x256_8f_60e](configs/r50_nuimg_704x256_8f_60e.py) | 60 | 37h, ~12GB | 37.7 | - | 17.3 | [github](https://github.com/MCG-NJU/SparseOcc/releases/download/v1.1/sparseocc_r50_nuimg_704x256_8f_60e_v1.1.pth) | | [r50_nuimg_704x256_8f_pano](configs/r50_nuimg_704x256_8f_pano.py) | 24 | 15h, ~12GB | 35.9 | 14.0 | 17.3 | [github](https://github.com/MCG-NJU/SparseOcc/releases/download/v1.1/sparseocc_r50_nuimg_704x256_8f_pano_24e_v1.1.pth) | * The backbone is pretrained on [nuImages](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth). Download the weights to `pretrain/xxx.pth` before you start training. 
* FPS is measured with Intel(R) Xeon(R) Platinum 8369B CPU and NVIDIA A100-SXM4-80GB GPU (PyTorch `fp32` backend, including data loading). * We will release more settings in the future. ## Environment > The requirements are the same as those of [SparseBEV](https://github.com/MCG-NJU/SparseBEV). Install PyTorch 2.0 + CUDA 11.8: ``` conda create -n sparseocc python=3.8 conda activate sparseocc conda install pytorch==2.0.0 torchvision==0.15.0 pytorch-cuda=11.8 -c pytorch -c nvidia ``` Install other dependencies: ``` pip install openmim mim install mmcv-full==1.6.0 mim install mmdet==2.28.2 mim install mmsegmentation==0.30.0 mim install mmdet3d==1.0.0rc6 pip install setuptools==59.5.0 pip install numpy==1.23.5 ``` Install turbojpeg and pillow-simd to speed up data loading (optional but important): ``` sudo apt-get update sudo apt-get install -y libturbojpeg pip install pyturbojpeg pip uninstall pillow pip install pillow-simd==9.0.0.post1 ``` Compile CUDA extensions: ``` cd models/csrc python setup.py build_ext --inplace ``` ## Prepare Dataset > The first two steps are the same as those of [SparseBEV](https://github.com/MCG-NJU/SparseBEV). 1. Download nuScenes from [https://www.nuscenes.org/nuscenes](https://www.nuscenes.org/nuscenes), put it to `data/nuscenes` and preprocess it with [mmdetection3d](https://github.com/open-mmlab/mmdetection3d/tree/v1.0.0rc6). 2. Download the generated info file from [gdrive](https://drive.google.com/file/d/1uyoUuSRIVScrm_CUpge6V_UzwDT61ODO/view?usp=sharing) and unzip it. These `*.pkl` files can also be generated with our script: `gen_sweep_info.py`. 3. Download Occ3D-nuScenes occupancy GT from [gdrive](https://drive.google.com/file/d/1kiXVNSEi3UrNERPMz_CfiJXKkgts_5dY/view?usp=drive_link), unzip it, and save it to `data/nuscenes/occ3d`. 4. 
Folder structure: ``` data/nuscenes ├── maps ├── nuscenes_infos_test_sweep.pkl ├── nuscenes_infos_train_sweep.pkl ├── nuscenes_infos_val_sweep.pkl ├── samples ├── sweeps ├── v1.0-test └── v1.0-trainval └── occ3d ├── scene-0001 │ ├── 0037a705a2e04559b1bba6c01beca1cf │ │ └── labels.npz │ ├── 026155aa1c554e2f87914ec9ba80acae │ │ └── labels.npz ... ``` 5. (Optional) Generate the panoptic occupancy ground truth with `gen_instance_info.py`. The panoptic version of Occ3D will be saved to `data/nuscenes/occ3d_panoptic`. ## Training Train SparseOcc with 8 GPUs: ``` torchrun --nproc_per_node 8 train.py --config configs/sparseocc_r50_nuimg_704x256_8f.py ``` Train SparseOcc with 4 GPUs (i.e the last four GPUs): ``` export CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nproc_per_node 4 train.py --config configs/sparseocc_r50_nuimg_704x256_8f.py ``` The batch size for each GPU will be scaled automatically. So there is no need to modify the `batch_size` in config files. ## Evaluation Single-GPU evaluation: ``` export CUDA_VISIBLE_DEVICES=0 python val.py --config configs/sparseocc_r50_nuimg_704x256_8f.py --weights checkpoints/sparseocc_r50_nuimg_704x256_8f.pth ``` Multi-GPU evaluation: ``` export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node 8 val.py --config configs/sparseocc_r50_nuimg_704x256_8f.py --weights checkpoints/sparseocc_r50_nuimg_704x256_8f.pth ``` ## Standalone Evaluation If you want to evaluate your own model using RayIoU, please follow the steps below: 1. Save the predictions (shape=`[200x200x16]`, dtype=`np.uint8`) with the compressed `npz` format. For example: ``` save_path = os.path.join(save_dir, sample_token + '.npz') np.savez_compressed(save_path, pred=sem_pred) ``` 2. The filename for each sample is `sample_token.npz`, for example: ``` prediction/your_model ├── 000681a060c04755a1537cf83b53ba57.npz ├── 000868a72138448191b4092f75ed7776.npz ├── 0017c2623c914571a1ff2a37f034ffd7.npz ├── ... ``` 3. 
Run `ray_metrics.py` to evaluate on the RayIoU: ``` python ray_metrics.py --pred-dir prediction/your_model ``` ## Timing FPS is measured with a single GPU: ``` export CUDA_VISIBLE_DEVICES=0 python timing.py --config configs/sparseocc_r50_nuimg_704x256_8f.py --weights checkpoints/sparseocc_r50_nuimg_704x256_8f.pth ``` ## Acknowledgements Many thanks to these excellent open-source projects: * [MaskFormer](https://github.com/facebookresearch/MaskFormer) * [NeuralRecon](https://github.com/zju3dv/NeuralRecon) * [4D-Occ](https://github.com/tarashakhurana/4d-occ-forecasting) * [MMDetection3D](https://github.com/open-mmlab/mmdetection3d) ================================================ FILE: configs/r50_nuimg_704x256_8f.py ================================================ dataset_type = 'NuSceneOcc' dataset_root = 'data/nuscenes/' occ_gt_root = 'data/nuscenes/occ3d' # If point cloud range is changed, the models should also change their point # cloud range accordingly point_cloud_range = [-40, -40, -1.0, 40, 40, 5.4] occ_size = [200, 200, 16] img_norm_cfg = dict( mean=[123.675, 116.280, 103.530], std=[58.395, 57.120, 57.375], to_rgb=True ) # For nuScenes we usually do 10-class detection det_class_names = [ 'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone' ] occ_class_names = [ 'others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle', 'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck', 'driveable_surface', 'other_flat', 'sidewalk', 'terrain', 'manmade', 'vegetation', 'free' ] input_modality = dict( use_lidar=False, use_camera=True, use_radar=False, use_map=False, use_external=False ) _dim_ = 256 _num_points_ = 4 _num_groups_ = 4 _num_layers_ = 2 _num_frames_ = 8 _num_queries_ = 100 _topk_training_ = [4000, 16000, 64000] _topk_testing_ = [2000, 8000, 32000] model = dict( type='SparseOcc', data_aug=dict( img_color_aug=True, # Move some augmentations to GPU img_norm_cfg=img_norm_cfg, 
img_pad_cfg=dict(size_divisor=32)), use_mask_camera=False, img_backbone=dict( type='ResNet', depth=50, num_stages=4, out_indices=(0, 1, 2, 3), frozen_stages=1, norm_cfg=dict(type='BN2d', requires_grad=True), norm_eval=True, style='pytorch', with_cp=True), img_neck=dict( type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=_dim_, num_outs=4), pts_bbox_head=dict( type='SparseOccHead', class_names=occ_class_names, embed_dims=_dim_, occ_size=occ_size, pc_range=point_cloud_range, transformer=dict( type='SparseOccTransformer', embed_dims=_dim_, num_layers=_num_layers_, num_frames=_num_frames_, num_points=_num_points_, num_groups=_num_groups_, num_queries=_num_queries_, num_levels=4, num_classes=len(occ_class_names), pc_range=point_cloud_range, occ_size=occ_size, topk_training=_topk_training_, topk_testing=_topk_testing_), loss_cfgs=dict( loss_mask2former=dict( type='Mask2FormerLoss', num_classes=len(occ_class_names), no_class_weight=0.1, loss_cls_weight=2.0, loss_mask_weight=5.0, loss_dice_weight=5.0, ), loss_geo_scal=dict( type='GeoScalLoss', num_classes=len(occ_class_names), loss_weight=1.0 ), loss_sem_scal=dict( type='SemScalLoss', num_classes=len(occ_class_names), loss_weight=1.0 ) ), ), ) ida_aug_conf = { 'resize_lim': (0.38, 0.55), 'final_dim': (256, 704), 'bot_pct_lim': (0.0, 0.0), 'rot_lim': (0.0, 0.0), 'H': 900, 'W': 1600, 'rand_flip': True, } bda_aug_conf = dict( rot_lim=(-22.5, 22.5), scale_lim=(1., 1.), flip_dx_ratio=0.5, flip_dy_ratio=0.5 ) train_pipeline = [ dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'), dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1), dict(type='BEVAug', bda_aug_conf=bda_aug_conf, classes=det_class_names, is_train=True), dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names)), dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=True), dict(type='DefaultFormatBundle3D', class_names=det_class_names), dict(type='Collect3D', keys=['img', 
'voxel_semantics', 'voxel_instances', 'instance_class_ids'], # other keys: 'mask_camera' meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar')) ] test_pipeline = [ dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'), dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1, test_mode=True), dict(type='BEVAug', bda_aug_conf=bda_aug_conf, classes=det_class_names, is_train=False), dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names)), dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False), dict(type='DefaultFormatBundle3D', class_names=det_class_names), dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'], meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar')) ] data = dict( workers_per_gpu=8, train=dict( type=dataset_type, data_root=dataset_root, occ_gt_root=occ_gt_root, ann_file=dataset_root + 'nuscenes_infos_train_sweep.pkl', pipeline=train_pipeline, classes=det_class_names, modality=input_modality, test_mode=False ), val=dict( type=dataset_type, data_root=dataset_root, occ_gt_root=occ_gt_root, ann_file=dataset_root + 'nuscenes_infos_val_sweep.pkl', pipeline=test_pipeline, classes=det_class_names, modality=input_modality, test_mode=True ), test=dict( type=dataset_type, data_root=dataset_root, occ_gt_root=occ_gt_root, ann_file=dataset_root + 'nuscenes_infos_test_sweep.pkl', pipeline=test_pipeline, classes=det_class_names, modality=input_modality, test_mode=True ), ) optimizer = dict( type='AdamW', lr=5e-4, paramwise_cfg=dict( custom_keys={ 'img_backbone': dict(lr_mult=0.1), 'sampling_offset': dict(lr_mult=0.1), }), weight_decay=0.01 ) optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) lr_config = dict( policy='step', warmup='linear', warmup_iters=500, warmup_ratio=1.0 / 3, by_epoch=True, step=[22, 24], gamma=0.2 ) total_epochs = 24 
batch_size = 8 # load pretrained weights load_from = 'pretrain/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth' revise_keys = [('backbone', 'img_backbone')] # resume the last training resume_from = None # checkpointing checkpoint_config = dict(interval=1, max_keep_ckpts=1) # logging log_config = dict( interval=1, hooks=[ dict(type='MyTextLoggerHook', interval=1, reset_flag=True), dict(type='MyTensorboardLoggerHook', interval=500, reset_flag=True) ] ) # evaluation eval_config = dict(interval=total_epochs) # other flags debug = False ================================================ FILE: configs/r50_nuimg_704x256_8f_60e.py ================================================ _base_ = ['./r50_nuimg_704x256_8f.py'] lr_config = dict( policy='step', warmup='linear', warmup_iters=500, warmup_ratio=1.0 / 3, by_epoch=True, step=[48, 60], gamma=0.2 ) total_epochs = 60 # evaluation eval_config = dict(interval=total_epochs) ================================================ FILE: configs/r50_nuimg_704x256_8f_openocc.py ================================================ _base_ = ['./r50_nuimg_704x256_8f.py'] occ_gt_root = 'data/nuscenes/openocc_v2' det_class_names = [ 'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone' ] occ_class_names = [ 'car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone', 'barrier', 'driveable_surface', 'other_flat', 'sidewalk', 'terrain', 'manmade', 'vegetation', 'free' ] _num_frames_ = 8 model = dict( pts_bbox_head=dict( class_names=occ_class_names, transformer=dict( num_classes=len(occ_class_names)), loss_cfgs=dict( loss_mask2former=dict( num_classes=len(occ_class_names) ), ), ), ) ida_aug_conf = { 'resize_lim': (0.38, 0.55), 'final_dim': (256, 704), 'bot_pct_lim': (0.0, 0.0), 'rot_lim': (0.0, 0.0), 'H': 900, 'W': 1600, 'rand_flip': False, } train_pipeline = [ dict(type='LoadMultiViewImageFromFiles', 
to_float32=False, color_type='color'), dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1), dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names)), dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=True), dict(type='DefaultFormatBundle3D', class_names=det_class_names), dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'], # other keys: 'mask_camera' meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar')) ] test_pipeline = [ dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'), dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1, test_mode=True), dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names)), dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False), dict(type='DefaultFormatBundle3D', class_names=det_class_names), dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'], meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar')) ] data = dict( workers_per_gpu=8, train=dict( pipeline=train_pipeline, occ_gt_root=occ_gt_root ), val=dict( pipeline=test_pipeline, occ_gt_root=occ_gt_root ), test=dict( pipeline=test_pipeline, occ_gt_root=occ_gt_root ), ) ================================================ FILE: configs/r50_nuimg_704x256_8f_pano.py ================================================ _base_ = ['./r50_nuimg_704x256_8f.py'] occ_gt_root = 'data/nuscenes/occ3d_panoptic' # For nuScenes we usually do 10-class detection det_class_names = [ 'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone' ] occ_class_names = [ 'others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle', 'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck', 'driveable_surface', 'other_flat', 
'sidewalk', 'terrain', 'manmade', 'vegetation', 'free' ] _num_frames_ = 8 model = dict( pts_bbox_head=dict( panoptic=True ) ) ida_aug_conf = { 'resize_lim': (0.38, 0.55), 'final_dim': (256, 704), 'bot_pct_lim': (0.0, 0.0), 'rot_lim': (0.0, 0.0), 'H': 900, 'W': 1600, 'rand_flip': True, } bda_aug_conf = dict( rot_lim=(-22.5, 22.5), scale_lim=(1., 1.), flip_dx_ratio=0.5, flip_dy_ratio=0.5 ) train_pipeline = [ dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'), dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1), dict(type='BEVAug', bda_aug_conf=bda_aug_conf, classes=det_class_names, is_train=True), dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names), inst_class_ids=[2, 3, 4, 5, 6, 7, 9, 10]), dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=True), dict(type='DefaultFormatBundle3D', class_names=det_class_names), dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'], # other keys: 'mask_camera' meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar')) ] test_pipeline = [ dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'), dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1, test_mode=True), dict(type='BEVAug', bda_aug_conf=bda_aug_conf, classes=det_class_names, is_train=False), dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names), inst_class_ids=[2, 3, 4, 5, 6, 7, 9, 10]), dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False), dict(type='DefaultFormatBundle3D', class_names=det_class_names), dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'], meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar')) ] data = dict( workers_per_gpu=8, train=dict( pipeline=train_pipeline, occ_gt_root=occ_gt_root ), val=dict( 
def meshgrid3d(occ_size, pc_range):
    """Return the (W, H, D, 3) tensor of voxel-center coordinates in the ego frame.

    Each voxel center is first expressed as a normalized offset in (0, 1)
    along its axis, then mapped into the metric range given by ``pc_range``
    (``[x_min, y_min, z_min, x_max, y_max, z_max]``).
    """
    num_x, num_y, num_z = occ_size
    lo = pc_range[:3]
    hi = pc_range[3:]

    axes = []
    per_axis = (
        (num_x, (num_x, 1, 1)),
        (num_y, (1, num_y, 1)),
        (num_z, (1, 1, num_z)),
    )
    for dim, (n, shape) in enumerate(per_axis):
        # normalized voxel-center positions along this axis: (i + 0.5) / n
        centers = torch.linspace(0.5, n - 0.5, n) / n
        # scale into metric coordinates for this axis
        centers = centers * (hi[dim] - lo[dim]) + lo[dim]
        # broadcast to the full grid shape
        axes.append(centers.view(shape).expand(num_x, num_y, num_z))

    return torch.stack(axes, dim=-1)
def process_add_instance_info(sample):
    """Augment one Occ3D ground-truth file with per-voxel instance IDs.

    Loads the semantic occupancy labels for ``sample``, assigns an instance
    ID to every non-free voxel (thing voxels via the sample's 3D boxes,
    stuff classes as one instance each), and writes the result as a new
    compressed ``.npz`` under ``args.output_dir``.

    NOTE(review): relies on the module-level globals ``token2path``,
    ``occ_class_names``, ``det_class_names`` and ``args`` — required because
    this runs as a ``multiprocessing.Pool`` worker.
    """
    # Grid geometry is hard-coded to the Occ3D-nuScenes setting:
    # 200x200x16 voxels over [-40, 40] x [-40, 40] x [-1.0, 5.4] m, 18 classes.
    point_cloud_range = [-40, -40, -1.0, 40, 40, 5.4]
    occ_size = [200, 200, 16]
    num_classes = 18

    occ_gt_path = token2path[sample['token']]
    occ_labels = np.load(occ_gt_path)
    occ_gt = occ_labels['semantics']

    gt_boxes = sample['gt_boxes']
    gt_names = sample['gt_names']
    bboxes = convert_to_nusc_box(gt_boxes)

    # 0 = "no instance yet"; real instance IDs start from 1 below.
    instance_gt = np.zeros(occ_gt.shape).astype(np.uint8)
    instance_id = 1
    # voxel-center coordinates in ego frame, shape (200, 200, 16, 3)
    pts = meshgrid3d(occ_size, point_cloud_range).numpy()

    # filter out free voxels to accelerate (class 17 == 'free')
    valid_idx = np.where(occ_gt < num_classes - 1)
    flatten_occ_gt = occ_gt[valid_idx]
    flatten_inst_gt = instance_gt[valid_idx]
    flatten_pts = pts[valid_idx]

    instance_boxes = []
    instance_class_ids = []
    # Pass 1: assign thing voxels strictly inside each annotated box.
    for i in range(len(gt_names)):
        if gt_names[i] not in occ_class_names:
            continue
        occ_tag_id = occ_class_names.index(gt_names[i])
        # Move box to ego vehicle coord system
        bbox = bboxes[i]
        bbox.rotate(Quaternion(sample['lidar2ego_rotation']))
        bbox.translate(np.array(sample['lidar2ego_translation']))
        mask = points_in_box(bbox, flatten_pts.transpose(1, 0))
        # ignore voxels not belonging to this class
        mask[mask] = (flatten_occ_gt[mask] == occ_tag_id)
        # ignore voxels already occupied by an earlier instance
        mask[mask] = (flatten_inst_gt[mask] == 0)
        # only instances with at least 1 voxel are recorded
        if mask.sum() > 0:
            flatten_inst_gt[mask] = instance_id
            instance_id += 1
            # enlarge boxes to include voxels on the edge
            # (0.8 m in pc range is roughly 2 voxels in the occ grid)
            new_box = bbox.copy()
            new_box.wlh = new_box.wlh + 0.8
            instance_boxes.append(new_box)
            instance_class_ids.append(occ_tag_id)

    # Pass 2: stuff classes (not detectable, not 'free') each become a
    # single instance covering all of their voxels.
    all_class_ids_unique = np.unique(occ_gt)
    for i, class_name in enumerate(occ_class_names):
        if class_name in det_class_names or class_name == 'free' or i not in all_class_ids_unique:
            continue
        flatten_inst_gt[flatten_occ_gt == i] = instance_id
        instance_id += 1

    # Pass 3: post-process still-uncovered (instance == 0) voxels by snapping
    # them to the nearest enlarged box of the same class.
    # NOTE(review): 'unconver'/'uncover' variable spellings are kept as-is.
    uncover_idx = np.where(flatten_inst_gt == 0)
    uncover_pts = flatten_pts[uncover_idx]
    uncover_inst_gt = np.zeros_like(uncover_pts[..., 0]).astype(np.uint8)
    unconver_occ_gt = flatten_occ_gt[uncover_idx]
    # uncover_inst_dist records the dist between each voxel and its current nearest bbox's center
    uncover_inst_dist = np.ones_like(uncover_pts[..., 0]) * 1e8
    for i, box in enumerate(instance_boxes):
        # important: non-background inst id starts from 1
        inst_id = i + 1
        class_id = instance_class_ids[i]
        mask = points_in_box(box, uncover_pts.transpose(1, 0))
        # mask voxels not belonging to this class
        mask[unconver_occ_gt != class_id] = False
        # squared distance from each voxel center to this box's center
        dist = np.sum((box.center - uncover_pts) ** 2, axis=-1)
        # Voxels already assigned to a closer box are ignored. This only
        # turns True entries False (never the reverse), so the equivalent
        # explicit form would be `mask[mask & (dist >= uncover_inst_dist)] = False`.
        mask[dist >= uncover_inst_dist] = False
        # only voxels inside the box (mask True) with no closer same-class
        # box update their best distance / instance assignment
        uncover_inst_dist[mask] = dist[mask]
        uncover_inst_gt[mask] = inst_id
    flatten_inst_gt[uncover_idx] = uncover_inst_gt
    instance_gt[valid_idx] = flatten_inst_gt

    # not using this checking function yet:
    # assert (instance_gt == 0).sum() - (occ_gt == num_classes-1).sum() < 100, "too many non-free voxels are not assigned to any instance in %s"%(occ_gt_path)
    # global max_margin
    # if max_margin < (instance_gt == 0).sum() - (occ_gt == num_classes-1).sum():
    #     print("###### new max margin: ", max(max_margin, (instance_gt == 0).sum() - (occ_gt == num_classes-1).sum()))
    #     max_margin = max(max_margin, (instance_gt == 0).sum() - (occ_gt == num_classes-1).sum())

    # Save under args.output_dir, mirroring the last three path components
    # (presumably <scene>/<token>/<file>.npz, matching the '*/*/*.npz' glob).
    data_split = occ_gt_path.split(os.path.sep)[-3:]
    data_path = os.path.sep.join(data_split)
    ##### Warning: relying on args.xxx (a global) here is strongly discouraged,
    ##### but unavoidable for a Pool worker without refactoring.
    save_path = os.path.join(args.output_dir, data_path)
    save_dir = os.path.split(save_path)[0]
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    # IDs should be contiguous 0..max; a gap means a mask was overwritten.
    if np.unique(instance_gt).shape[0] != instance_gt.max()+1:
        print('warning: some instance masks are covered by following ones %s'%(save_dir))
    # only semantic and mask information needs to be preserved
    retain_keys = ['semantics', 'mask_lidar', 'mask_camera']
    new_occ_labels = {k: occ_labels[k] for k in retain_keys}
    new_occ_labels['instances'] = instance_gt
    np.savez_compressed(save_path, **new_occ_labels)
def get_cam_info(nusc, sample_data):
    """Collect path, pose and intrinsics for one camera ``sample_data`` record.

    Returns a dict with the image path, camera intrinsics, timestamp, and the
    composed sensor-to-global rotation/translation.

    NOTE(review): the stored rotation is ``(ego2global @ sensor2ego).T``,
    i.e. the transpose of the usual column-vector sensor->global matrix.
    Together with the row-product translation below this matches a row-vector
    convention (``p_global = p_sensor @ R + t``) — confirm against the
    consumer in the loading pipeline before changing.
    """
    pose_record = nusc.get('ego_pose', sample_data['ego_pose_token'])
    cs_record = nusc.get('calibrated_sensor', sample_data['calibrated_sensor_token'])

    # translations are kept as the raw lists from the nuScenes records
    sensor2ego_translation = cs_record['translation']
    ego2global_translation = pose_record['translation']
    # quaternions -> 3x3 rotation matrices
    sensor2ego_rotation = Quaternion(cs_record['rotation']).rotation_matrix
    ego2global_rotation = Quaternion(pose_record['rotation']).rotation_matrix
    cam_intrinsic = np.array(cs_record['camera_intrinsic'])

    # composed sensor->global transform (transposed / row-vector form, see note)
    sensor2global_rotation = sensor2ego_rotation.T @ ego2global_rotation.T
    # R_e2g @ t_s2e + t_e2g: the camera origin expressed in the global frame
    sensor2global_translation = sensor2ego_translation @ ego2global_rotation.T + ego2global_translation

    return {
        'data_path': os.path.join(args.data_root, sample_data['filename']),
        'sensor2global_rotation': sensor2global_rotation,
        'sensor2global_translation': sensor2global_translation,
        'cam_intrinsic': cam_intrinsic,
        'timestamp': sample_data['timestamp'],
    }
sample_infos['infos'][curr_id]['cams'][cam]['sensor2ego_rotation'] del sample_infos['infos'][curr_id]['cams'][cam]['ego2global_translation'] del sample_infos['infos'][curr_id]['cams'][cam]['ego2global_rotation'] sweep_infos = [] if sample['prev'] != '': # add sweep frame between two key frame for _ in range(5): sweep_info = dict() for cam in cam_types: if curr_cams[cam]['prev'] == '': sweep_info = sweep_infos[-1] break sample_data = nusc.get('sample_data', curr_cams[cam]['prev']) sweep_cam = get_cam_info(nusc, sample_data) curr_cams[cam] = sample_data sweep_info[cam] = sweep_cam sweep_infos.append(sweep_info) sample_infos['infos'][curr_id]['sweeps'] = sweep_infos return sample_infos if __name__ == '__main__': nusc = NuScenes(args.version, args.data_root) if args.version == 'v1.0-trainval': sample_infos = pickle.load(open(os.path.join(args.data_root, 'nuscenes_infos_train.pkl'), 'rb')) sample_infos = add_sweep_info(nusc, sample_infos) mmcv.dump(sample_infos, os.path.join(args.data_root, 'nuscenes_infos_train_sweep.pkl')) sample_infos = pickle.load(open(os.path.join(args.data_root, 'nuscenes_infos_val.pkl'), 'rb')) sample_infos = add_sweep_info(nusc, sample_infos) mmcv.dump(sample_infos, os.path.join(args.data_root, 'nuscenes_infos_val_sweep.pkl')) elif args.version == 'v1.0-test': sample_infos = pickle.load(open(os.path.join(args.data_root, 'nuscenes_infos_test.pkl'), 'rb')) sample_infos = add_sweep_info(nusc, sample_infos) mmcv.dump(sample_infos, os.path.join(args.data_root, 'nuscenes_infos_test_sweep.pkl')) else: raise ValueError ================================================ FILE: lib/dvr/dvr.cpp ================================================ // Acknowledgments: https://github.com/tarashakhurana/4d-occ-forecasting // Modified by Haisong Liu #include #include #include /* * CUDA forward declarations */ std::vector render_forward_cuda(torch::Tensor sigma, torch::Tensor origin, torch::Tensor points, torch::Tensor tindex, const std::vector grid, std::string 
/*
 * C++ interface: validate tensor placement/layout, then dispatch to the
 * CUDA implementations declared above.
 *
 * NOTE(review): the std::vector return/parameter types in this file appear
 * with their template arguments stripped (text-extraction damage);
 * presumably std::vector<torch::Tensor> / std::vector<int> — restore from
 * the upstream source before compiling.
 */

// Every tensor handed to the kernels must live on the GPU and be contiguous.
#define CHECK_CUDA(x) \
    TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
    TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

// Forward-only rendering: per-ray predicted/GT distances plus hit coordinates.
std::vector render_forward(torch::Tensor sigma, torch::Tensor origin,
                           torch::Tensor points, torch::Tensor tindex,
                           const std::vector grid, std::string phase_name) {
    CHECK_INPUT(sigma);
    CHECK_INPUT(origin);
    CHECK_INPUT(points);
    CHECK_INPUT(tindex);
    return render_forward_cuda(sigma, origin, points, tindex, grid, phase_name);
}

// Rendering with analytic gradients w.r.t. sigma (training path).
std::vector render(torch::Tensor sigma, torch::Tensor origin,
                   torch::Tensor points, torch::Tensor tindex,
                   std::string loss_name) {
    CHECK_INPUT(sigma);
    CHECK_INPUT(origin);
    CHECK_INPUT(points);
    CHECK_INPUT(tindex);
    return render_cuda(sigma, origin, points, tindex, loss_name);
}

// Rasterize points into a binary occupancy grid of the given shape.
torch::Tensor init(torch::Tensor points, torch::Tensor tindex,
                   const std::vector grid) {
    CHECK_INPUT(points);
    CHECK_INPUT(tindex);
    return init_cuda(points, tindex, grid);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("init", &init, "Initialize");
    m.def("render", &render, "Render");
    m.def("render_forward", &render_forward, "Render (forward pass only)");
}
tindex, torch::PackedTensorAccessor32 occupancy) { // batch index const auto n = blockIdx.y; // ray index const auto c = blockIdx.x * blockDim.x + threadIdx.x; // num of rays const auto M = points.size(1); const auto T = occupancy.size(1); // we allocated more threads than num_rays if (c < M) { // ray end point const auto t = tindex[n][c]; // invalid points assert(T == 1 || t < T); // if t < 0, it is a padded point if (t < 0) return; // time index for sigma // when T = 1, we have a static sigma const auto ts = (T == 1) ? 0 : t; // grid shape const int vzsize = occupancy.size(2); const int vysize = occupancy.size(3); const int vxsize = occupancy.size(4); // assert(vzsize + vysize + vxsize <= MAX_D); // end point const int vx = int(points[n][c][0]); const int vy = int(points[n][c][1]); const int vz = int(points[n][c][2]); // if (0 <= vx && vx < vxsize && 0 <= vy && vy < vysize && 0 <= vz && vz < vzsize) { occupancy[n][ts][vz][vy][vx] = 1; } } } template __global__ void render_forward_cuda_kernel( const torch::PackedTensorAccessor32 sigma, const torch::PackedTensorAccessor32 origin, const torch::PackedTensorAccessor32 points, const torch::PackedTensorAccessor32 tindex, // torch::PackedTensorAccessor32 pog, torch::PackedTensorAccessor32 pred_dist, torch::PackedTensorAccessor32 gt_dist, torch::PackedTensorAccessor32 coord_index, PhaseName train_phase) { // batch index const auto n = blockIdx.y; // ray index const auto c = blockIdx.x * blockDim.x + threadIdx.x; // num of rays const auto M = points.size(1); const auto T = sigma.size(1); // we allocated more threads than num_rays if (c < M) { // ray end point const auto t = tindex[n][c]; // invalid points // assert(t < T); assert(T == 1 || t < T); // time index for sigma // when T = 1, we have a static sigma const auto ts = (T == 1) ? 
0 : t; // if t < 0, it is a padded point if (t < 0) return; // grid shape const int vzsize = sigma.size(2); const int vysize = sigma.size(3); const int vxsize = sigma.size(4); // assert(vzsize + vysize + vxsize <= MAX_D); // origin const double xo = origin[n][t][0]; const double yo = origin[n][t][1]; const double zo = origin[n][t][2]; // end point const double xe = points[n][c][0]; const double ye = points[n][c][1]; const double ze = points[n][c][2]; // locate the voxel where the origin resides const int vxo = int(xo); const int vyo = int(yo); const int vzo = int(zo); const int vxe = int(xe); const int vye = int(ye); const int vze = int(ze); // NOTE: new int vx = vxo; int vy = vyo; int vz = vzo; // origin to end const double rx = xe - xo; const double ry = ye - yo; const double rz = ze - zo; double gt_d = sqrt(rx * rx + ry * ry + rz * rz); // directional vector const double dx = rx / gt_d; const double dy = ry / gt_d; const double dz = rz / gt_d; // In which direction the voxel ids are incremented. const int stepX = (dx >= 0) ? 1 : -1; const int stepY = (dy >= 0) ? 1 : -1; const int stepZ = (dz >= 0) ? 1 : -1; // Distance along the ray to the next voxel border from the current position (tMaxX, tMaxY, tMaxZ). const double next_voxel_boundary_x = vx + (stepX < 0 ? 0 : 1); const double next_voxel_boundary_y = vy + (stepY < 0 ? 0 : 1); const double next_voxel_boundary_z = vz + (stepZ < 0 ? 0 : 1); // tMaxX, tMaxY, tMaxZ -- distance until next intersection with voxel-border // the value of t at which the ray crosses the first vertical voxel boundary double tMaxX = (dx!=0) ? (next_voxel_boundary_x - xo)/dx : DBL_MAX; // double tMaxY = (dy!=0) ? (next_voxel_boundary_y - yo)/dy : DBL_MAX; // double tMaxZ = (dz!=0) ? 
(next_voxel_boundary_z - zo)/dz : DBL_MAX; // // tDeltaX, tDeltaY, tDeltaZ -- // how far along the ray we must move for the horizontal component to equal the width of a voxel // the direction in which we traverse the grid // can only be FLT_MAX if we never go in that direction const double tDeltaX = (dx!=0) ? stepX/dx : DBL_MAX; const double tDeltaY = (dy!=0) ? stepY/dy : DBL_MAX; const double tDeltaZ = (dz!=0) ? stepZ/dz : DBL_MAX; int3 path[MAX_D]; double csd[MAX_D]; // cumulative sum of sigma times delta double p[MAX_D]; // alpha double d[MAX_D]; // forward raymarching with voxel traversal int step = 0; // total number of voxels traversed int count = 0; // number of voxels traversed inside the voxel grid double last_d = 0.0; // correct initialization // voxel traversal raycasting bool was_inside = false; while (true) { bool inside = (0 <= vx && vx < vxsize) && (0 <= vy && vy < vysize) && (0 <= vz && vz < vzsize); if (inside) { was_inside = true; path[count] = make_int3(vx, vy, vz); } else if (was_inside) { // was but no longer inside // we know we are not coming back so terminate break; } /*else if (last_d > gt_d) { break; } */ /*else { // has not gone inside yet // assert(count == 0); // (1) when we have hit the destination but haven't gone inside the voxel grid // (2) when we have traveled MAX_D voxels but haven't found one valid voxel // handle intersection corner cases in case of infinite loop bool hit = (vx == vxe && vy == vye && vz == vze); // this test seems brittle with corner cases if (hit || step >= MAX_D) break; //if (last_d >= gt_d || step >= MAX_D) break; } */ // _d represents the ray distance has traveled before escaping the current voxel cell double _d = 0.0; // voxel traversal if (tMaxX < tMaxY) { if (tMaxX < tMaxZ) { _d = tMaxX; vx += stepX; tMaxX += tDeltaX; } else { _d = tMaxZ; vz += stepZ; tMaxZ += tDeltaZ; } } else { if (tMaxY < tMaxZ) { _d = tMaxY; vy += stepY; tMaxY += tDeltaY; } else { _d = tMaxZ; vz += stepZ; tMaxZ += tDeltaZ; } } if 
(inside) { // get sigma at the current voxel const int3 &v = path[count]; // use the recorded index const double _sigma = sigma[n][ts][v.z][v.y][v.x]; const double _delta = max(0.0, _d - last_d); // THIS TURNS OUT IMPORTANT const double sd = _sigma * _delta; if (count == 0) { // the first voxel inside csd[count] = sd; p[count] = 1 - exp(-sd); } else { csd[count] = csd[count-1] + sd; p[count] = exp(-csd[count-1]) - exp(-csd[count]); } // record the traveled distance d[count] = _d; // count the number of voxels we have escaped count ++; } last_d = _d; step ++; if (step > MAX_STEP) { break; } } // the total number of voxels visited should not exceed this number assert(count <= MAX_D); if (count > 0) { // compute the expected ray distance //double exp_d = 0.0; double exp_d = d[count-1]; const int3 &v_init = path[count-1]; int x = v_init.x; int y = v_init.y; int z = v_init.z; for (int i = 0; i < count; i++) { //printf("%f\t%f\n",p[i], d[i]); //exp_d += p[i] * d[i]; const int3 &v = path[i]; const double occ = sigma[n][ts][v.z][v.y][v.x]; if (occ > 0.5) { exp_d = d[i]; x = v.x; y = v.y; z = v.z; break; } } //printf("%f\n",exp_d); // add an imaginary sample at the end point should gt_d exceeds max_d double p_out = exp(-csd[count-1]); double max_d = d[count-1]; // if (gt_d > max_d) // exp_d += (p_out * gt_d); // p_out is the probability the ray escapes the voxel grid //exp_d += (p_out * max_d); if (train_phase == 1) { gt_d = min(gt_d, max_d); } // write the rendered ray distance (max_d) pred_dist[n][c] = exp_d; gt_dist[n][c] = gt_d; coord_index[n][c][0] = double(x); coord_index[n][c][1] = double(y); coord_index[n][c][2] = double(z); // // write occupancy // for (int i = 0; i < count; i ++) { // const int3 &v = path[i]; // auto & occ = pog[n][t][v.z][v.y][v.x]; // if (p[i] >= occ) { // occ = p[i]; // } // } } } } /* * input shape * sigma : N x T x H x L x W * origin : N x T x 3 * points : N x M x 4 * output shape * dist : N x M */ std::vector render_forward_cuda( 
torch::Tensor sigma, torch::Tensor origin, torch::Tensor points, torch::Tensor tindex, const std::vector grid, std::string phase_name) { const auto N = points.size(0); // batch size const auto M = points.size(1); // num of rays const auto T = grid[0]; const auto H = grid[1]; const auto L = grid[2]; const auto W = grid[3]; const auto device = sigma.device(); const int threads = 1024; const dim3 blocks((M + threads - 1) / threads, N); // // const auto dtype = points.dtype(); // const auto options = torch::TensorOptions().dtype(dtype).device(device).requires_grad(false); // auto pog = torch::zeros({N, T, H, L, W}, options); // perform rendering auto gt_dist = -torch::ones({N, M}, device); auto pred_dist = -torch::ones({N, M}, device); auto coord_index = torch::zeros({N, M, 3}, device); PhaseName train_phase; if (phase_name.compare("test") == 0) { train_phase = TEST; } else if (phase_name.compare("train") == 0){ train_phase = TRAIN; } else { std::cout << "UNKNOWN PHASE NAME: " << phase_name << std::endl; exit(1); } AT_DISPATCH_FLOATING_TYPES(sigma.type(), "render_forward_cuda", ([&] { render_forward_cuda_kernel<<>>( sigma.packed_accessor32(), origin.packed_accessor32(), points.packed_accessor32(), tindex.packed_accessor32(), // pog.packed_accessor32(), pred_dist.packed_accessor32(), gt_dist.packed_accessor32(), coord_index.packed_accessor32(), train_phase); })); cudaDeviceSynchronize(); // return {pog, pred_dist, gt_dist}; return {pred_dist, gt_dist, coord_index}; } template __global__ void render_cuda_kernel( const torch::PackedTensorAccessor32 sigma, const torch::PackedTensorAccessor32 origin, const torch::PackedTensorAccessor32 points, const torch::PackedTensorAccessor32 tindex, // const torch::PackedTensorAccessor32 occupancy, torch::PackedTensorAccessor32 pred_dist, torch::PackedTensorAccessor32 gt_dist, torch::PackedTensorAccessor32 grad_sigma, // torch::PackedTensorAccessor32 grad_sigma_count, LossType loss_type) { // batch index const auto n = blockIdx.y; // 
ray index const auto c = blockIdx.x * blockDim.x + threadIdx.x; // num of rays const auto M = points.size(1); const auto T = sigma.size(1); // we allocated more threads than num_rays if (c < M) { // ray end point const auto t = tindex[n][c]; // invalid points // assert(t < T); assert(T == 1 || t < T); // time index for sigma // when T = 1, we have a static sigma const auto ts = (T == 1) ? 0 : t; // if t < 0, it is a padded point if (t < 0) return; // grid shape const int vzsize = sigma.size(2); const int vysize = sigma.size(3); const int vxsize = sigma.size(4); // assert(vzsize + vysize + vxsize <= MAX_D); // origin const double xo = origin[n][t][0]; const double yo = origin[n][t][1]; const double zo = origin[n][t][2]; // end point const double xe = points[n][c][0]; const double ye = points[n][c][1]; const double ze = points[n][c][2]; // locate the voxel where the origin resides const int vxo = int(xo); const int vyo = int(yo); const int vzo = int(zo); // const int vxe = int(xe); const int vye = int(ye); const int vze = int(ze); // NOTE: new int vx = vxo; int vy = vyo; int vz = vzo; // origin to end const double rx = xe - xo; const double ry = ye - yo; const double rz = ze - zo; double gt_d = sqrt(rx * rx + ry * ry + rz * rz); // directional vector const double dx = rx / gt_d; const double dy = ry / gt_d; const double dz = rz / gt_d; // In which direction the voxel ids are incremented. const int stepX = (dx >= 0) ? 1 : -1; const int stepY = (dy >= 0) ? 1 : -1; const int stepZ = (dz >= 0) ? 1 : -1; // Distance along the ray to the next voxel border from the current position (tMaxX, tMaxY, tMaxZ). const double next_voxel_boundary_x = vx + (stepX < 0 ? 0 : 1); const double next_voxel_boundary_y = vy + (stepY < 0 ? 0 : 1); const double next_voxel_boundary_z = vz + (stepZ < 0 ? 0 : 1); // tMaxX, tMaxY, tMaxZ -- distance until next intersection with voxel-border // the value of t at which the ray crosses the first vertical voxel boundary double tMaxX = (dx!=0) ? 
(next_voxel_boundary_x - xo)/dx : DBL_MAX; // double tMaxY = (dy!=0) ? (next_voxel_boundary_y - yo)/dy : DBL_MAX; // double tMaxZ = (dz!=0) ? (next_voxel_boundary_z - zo)/dz : DBL_MAX; // // tDeltaX, tDeltaY, tDeltaZ -- // how far along the ray we must move for the horizontal component to equal the width of a voxel // the direction in which we traverse the grid // can only be FLT_MAX if we never go in that direction const double tDeltaX = (dx!=0) ? stepX/dx : DBL_MAX; const double tDeltaY = (dy!=0) ? stepY/dy : DBL_MAX; const double tDeltaZ = (dz!=0) ? stepZ/dz : DBL_MAX; int3 path[MAX_D]; double csd[MAX_D]; // cumulative sum of sigma times delta double p[MAX_D]; // alpha double d[MAX_D]; double dt[MAX_D]; // forward raymarching with voxel traversal int step = 0; // total number of voxels traversed int count = 0; // number of voxels traversed inside the voxel grid double last_d = 0.0; // correct initialization // voxel traversal raycasting bool was_inside = false; while (true) { bool inside = (0 <= vx && vx < vxsize) && (0 <= vy && vy < vysize) && (0 <= vz && vz < vzsize); if (inside) { // now inside was_inside = true; path[count] = make_int3(vx, vy, vz); } else if (was_inside) { // was inside but no longer // we know we are not coming back so terminate break; } else if (last_d > gt_d) { break; } /* else { // has not gone inside yet // assert(count == 0); // (1) when we have hit the destination but haven't gone inside the voxel grid // (2) when we have traveled MAX_D voxels but haven't found one valid voxel // handle intersection corner cases in case of infinite loop // bool hit = (vx == vxe && vy == vye && vz == vze); // if (hit || step >= MAX_D) // break; if (last_d >= gt_d || step >= MAX_D) break; } */ // _d represents the ray distance has traveled before escaping the current voxel cell double _d = 0.0; // voxel traversal if (tMaxX < tMaxY) { if (tMaxX < tMaxZ) { _d = tMaxX; vx += stepX; tMaxX += tDeltaX; } else { _d = tMaxZ; vz += stepZ; tMaxZ += tDeltaZ; } } 
else { if (tMaxY < tMaxZ) { _d = tMaxY; vy += stepY; tMaxY += tDeltaY; } else { _d = tMaxZ; vz += stepZ; tMaxZ += tDeltaZ; } } if (inside) { // get sigma at the current voxel const int3 &v = path[count]; // use the recorded index const double _sigma = sigma[n][ts][v.z][v.y][v.x]; const double _delta = max(0.0, _d - last_d); // THIS TURNS OUT IMPORTANT const double sd = _sigma * _delta; if (count == 0) { // the first voxel inside csd[count] = sd; p[count] = 1 - exp(-sd); } else { csd[count] = csd[count-1] + sd; p[count] = exp(-csd[count-1]) - exp(-csd[count]); } // record the traveled distance d[count] = _d; dt[count] = _delta; // count the number of voxels we have escaped count ++; } last_d = _d; step ++; if (step > MAX_STEP) { break; } } // the total number of voxels visited should not exceed this number assert(count <= MAX_D); // WHEN THERE IS AN INTERSECTION BETWEEN THE RAY AND THE VOXEL GRID if (count > 0) { // compute the expected ray distance double exp_d = 0.0; for (int i = 0; i < count; i ++) exp_d += p[i] * d[i]; // add an imaginary sample at the end point should gt_d exceeds max_d double p_out = exp(-csd[count-1]); double max_d = d[count-1]; exp_d += (p_out * max_d); gt_d = min(gt_d, max_d); // write the rendered ray distance (max_d) pred_dist[n][c] = exp_d; gt_dist[n][c] = gt_d; /* backward raymarching */ double dd_dsigma[MAX_D]; for (int i = count - 1; i >= 0; i --) { // NOTE: probably need to double check again if (i == count - 1) dd_dsigma[i] = p_out * max_d; else dd_dsigma[i] = dd_dsigma[i+1] - exp(-csd[i]) * (d[i+1] - d[i]); } for (int i = count - 1; i >= 0; i --) dd_dsigma[i] *= dt[i]; // option 2: cap at the boundary for (int i = count - 1; i >= 0; i --) dd_dsigma[i] -= dt[i] * p_out * max_d; double dl_dd = 1.0; if (loss_type == L1) dl_dd = (exp_d >= gt_d) ? 1 : -1; else if (loss_type == L2) dl_dd = (exp_d - gt_d); else if (loss_type == ABSREL) dl_dd = (exp_d >= gt_d) ? 
/*
 * Host-side launcher for the fused render + gradient kernel.
 *
 * input shape
 *   sigma  : N x T x H x L x W
 *   origin : N x T x 3
 *   points : N x M x 4
 * output
 *   pred_dist : N x M
 *   gt_dist   : N x M
 *   grad_sigma: same shape as sigma
 *
 * NOTE(review): the std::vector template arguments, the kernel launch
 * configuration inside <<< >>>, and the packed_accessor32<...> template
 * arguments appear stripped by text extraction; restore from the upstream
 * source before compiling.
 */
std::vector render_cuda(
    torch::Tensor sigma,
    torch::Tensor origin,
    torch::Tensor points,
    torch::Tensor tindex,
    std::string loss_name)
{
    const auto N = points.size(0); // batch size
    const auto M = points.size(1); // num of rays
    const auto device = sigma.device();

    // one thread per ray; grid y-dimension indexes the batch
    const int threads = 1024;
    const dim3 blocks((M + threads - 1) / threads, N);

    // perform rendering; distances default to -1 (ray never entered the grid)
    auto gt_dist = -torch::ones({N, M}, device);
    auto pred_dist = -torch::ones({N, M}, device);
    auto grad_sigma = torch::zeros_like(sigma);
    // auto grad_sigma_count = torch::zeros_like(sigma);

    // Map the loss name to the kernel enum. "bce" falls back to L1 —
    // NOTE(review): confirm this aliasing is deliberate.
    LossType loss_type;
    if (loss_name.compare("l1") == 0) {
        loss_type = L1;
    } else if (loss_name.compare("l2") == 0) {
        loss_type = L2;
    } else if (loss_name.compare("absrel") == 0) {
        loss_type = ABSREL;
    } else if (loss_name.compare("bce") == 0){
        loss_type = L1;
    } else {
        std::cout << "UNKNOWN LOSS TYPE: " << loss_name << std::endl;
        exit(1);
    }

    AT_DISPATCH_FLOATING_TYPES(sigma.type(), "render_cuda", ([&] {
        render_cuda_kernel<<>>(
            sigma.packed_accessor32(),
            origin.packed_accessor32(),
            points.packed_accessor32(),
            tindex.packed_accessor32(),
            // occupancy.packed_accessor32(),
            pred_dist.packed_accessor32(),
            gt_dist.packed_accessor32(),
            grad_sigma.packed_accessor32(),
            // grad_sigma_count.packed_accessor32(),
            loss_type);
    }));

    cudaDeviceSynchronize();

    // grad_sigma_count += (grad_sigma_count == 0);
    // grad_sigma /= grad_sigma_count;
    return {pred_dist, gt_dist, grad_sigma};
}
// Build a dense {0,1} occupancy volume: a voxel is 1 iff at least one point
// falls inside it (the actual marking happens in init_cuda_kernel).
// NOTE(review): template arguments look stripped by the extraction
// (grid is likely std::vector<int>; launch likely <<<blocks, threads>>>).
torch::Tensor init_cuda(
    torch::Tensor points,
    torch::Tensor tindex,
    const std::vector grid) {

    const auto N = points.size(0); // batch size
    const auto M = points.size(1); // num of rays
    const auto T = grid[0];
    const auto H = grid[1];
    const auto L = grid[2];
    const auto W = grid[3];

    const auto dtype = points.dtype();
    const auto device = points.device();
    const auto options = torch::TensorOptions().dtype(dtype).device(device).requires_grad(false);
    auto occupancy = torch::zeros({N, T, H, L, W}, options);

    const int threads = 1024;
    const dim3 blocks((M + threads - 1) / threads, N);

    // initialize occupancy such that every voxel with one or more points is occupied
    AT_DISPATCH_FLOATING_TYPES(points.type(), "init_cuda", ([&] {
        init_cuda_kernel<<>>(
            points.packed_accessor32(),
            tindex.packed_accessor32(),
            occupancy.packed_accessor32());
    }));

    // synchronize
    cudaDeviceSynchronize();

    return occupancy;
}
================================================ FILE: loaders/__init__.py ================================================
# NOTE(review): importing __all__ from .pipelines and then rebinding it below
# means the pipelines' __all__ is discarded; presumably the import is only for
# its registration side effects — confirm.
from .pipelines import __all__
from .nuscenes_dataset import CustomNuScenesDataset
from .nuscenes_occ_dataset import NuSceneOcc

__all__ = [
    'CustomNuScenesDataset', 'NuSceneOcc'
]
================================================ FILE: loaders/builder.py ================================================
from functools import partial
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from torch.utils.data import DataLoader
from mmdet.datasets.builder import worker_init_fn
from mmdet.datasets.samplers import DistributedGroupSampler, DistributedSampler, GroupSampler


def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     seed=None,
                     **kwargs):
    """Build a PyTorch DataLoader with mmdet-style (group-aware) samplers.

    In distributed mode each GPU gets ``samples_per_gpu`` samples and
    ``workers_per_gpu`` workers; otherwise the totals are scaled by
    ``num_gpus``. Extra ``kwargs`` are forwarded to ``DataLoader``.
    """
    rank, world_size = get_dist_info()
    if dist:
        # DistributedGroupSampler will definitely shuffle the data to satisfy
        # that images on each GPU are in the same group
        if shuffle:
            sampler = DistributedGroupSampler(
                dataset, samples_per_gpu, world_size, rank,
                seed=seed)
        else:
            sampler = DistributedSampler(
                dataset, world_size, rank, shuffle=False, seed=seed)
        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    # seed each worker deterministically when a seed is provided
    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed) if seed is not None else None

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        pin_memory=False,
        worker_init_fn=init_fn,
        **kwargs)

    return data_loader
================================================ FILE: loaders/ego_pose_dataset.py ================================================
import torch
import numpy as np
from pyquaternion import Quaternion
from torch.utils.data import Dataset

np.set_printoptions(precision=3, suppress=True)


def trans_matrix(T, R):
    """Compose a 4x4 homogeneous transform from translation T and a
    pyquaternion rotation R."""
    tm = np.eye(4)
    tm[:3, :3] = R.rotation_matrix
    tm[:3, 3] = T
    return tm

# A helper dataset for RayIoU. It is NOT used during training.
class EgoPoseDataset(Dataset):
    """Yields, per sample, the lidar origins of temporally nearby frames of
    the same scene expressed in the reference frame's ego coordinates.
    Consumed by the RayIoU / RayPQ evaluation (see NuSceneOcc.evaluate)."""

    def __init__(self, data_infos):
        super(EgoPoseDataset, self).__init__()
        self.data_infos = data_infos

        # group frames by scene so neighbors can be looked up quickly
        self.scene_frames = {}
        for info in data_infos:
            scene_name = info['scene_name']
            if scene_name not in self.scene_frames:
                self.scene_frames[scene_name] = []
            self.scene_frames[scene_name].append(info)

    def __len__(self):
        return len(self.data_infos)

    def get_ego_from_lidar(self, info):
        # lidar -> ego transform of the given frame
        ego_from_lidar = trans_matrix(
            np.array(info['lidar2ego_translation']),
            Quaternion(info['lidar2ego_rotation']))
        return ego_from_lidar

    def get_global_pose(self, info, inverse=False):
        # lidar -> global (or global -> lidar when inverse=True)
        global_from_ego = trans_matrix(
            np.array(info['ego2global_translation']),
            Quaternion(info['ego2global_rotation']))
        ego_from_lidar = trans_matrix(
            np.array(info['lidar2ego_translation']),
            Quaternion(info['lidar2ego_rotation']))
        pose = global_from_ego.dot(ego_from_lidar)
        if inverse:
            pose = np.linalg.inv(pose)
        return pose

    def __getitem__(self, idx):
        info = self.data_infos[idx]
        ref_sample_token = info['token']
        ref_lidar_from_global = self.get_global_pose(info, inverse=True)
        ref_ego_from_lidar = self.get_ego_from_lidar(info)

        scene_frame = self.scene_frames[info['scene_name']]
        ref_index = scene_frame.index(info)

        # NOTE: getting output frames
        output_origin_list = []
        for curr_index in range(len(scene_frame)):
            # if this exists a valid target
            if curr_index == ref_index:
                origin_tf = np.array([0.0, 0.0, 0.0], dtype=np.float32)
            else:
                # transform from the current lidar frame to global and then to the reference lidar frame
                global_from_curr = self.get_global_pose(scene_frame[curr_index], inverse=False)
                ref_from_curr = ref_lidar_from_global.dot(global_from_curr)
                origin_tf = np.array(ref_from_curr[:3, 3], dtype=np.float32)

            # lift the lidar-frame origin into the reference ego frame
            origin_tf_pad = np.ones([4])
            origin_tf_pad[:3] = origin_tf  # pad to [4]
            origin_tf = np.dot(ref_ego_from_lidar[:3], origin_tf_pad.T).T  # [3]

            # keep only origins inside the evaluated area
            # (39 m presumably matches the 40 m occupancy range minus margin — confirm)
            if np.abs(origin_tf[0]) < 39 and np.abs(origin_tf[1]) < 39:
                output_origin_list.append(origin_tf)

        # select 8 origins, evenly spread along the sequence
        if len(output_origin_list) > 8:
            select_idx = np.round(np.linspace(0, len(output_origin_list) - 1, 8)).astype(np.int64)
            output_origin_list = [output_origin_list[i] for i in select_idx]

        output_origin_tensor = torch.from_numpy(np.stack(output_origin_list))  # [T, 3]

        return (ref_sample_token, output_origin_tensor)
================================================ FILE: loaders/nuscenes_dataset.py ================================================
import os
import numpy as np
from mmdet.datasets import DATASETS
from mmdet3d.datasets import NuScenesDataset
from pyquaternion import Quaternion


@DATASETS.register_module()
class CustomNuScenesDataset(NuScenesDataset):
    """NuScenesDataset variant that also gathers camera sweeps around each
    keyframe (used for multi-frame temporal input)."""

    def collect_sweeps(self, index, into_past=60, into_future=0):
        """Collect up to `into_past`/`into_future` sweep dicts around `index`.

        Past sweeps also interleave the previous keyframe's 'cams' entry;
        future sweeps are appended in reverse chronological order.
        """
        all_sweeps_prev = []
        curr_index = index
        while len(all_sweeps_prev) < into_past:
            curr_sweeps = self.data_infos[curr_index]['sweeps']
            if len(curr_sweeps) == 0:
                break
            all_sweeps_prev.extend(curr_sweeps)
            all_sweeps_prev.append(self.data_infos[curr_index - 1]['cams'])
            curr_index = curr_index - 1

        all_sweeps_next = []
        curr_index = index + 1
        while len(all_sweeps_next) < into_future:
            if curr_index >= len(self.data_infos):
                break
            curr_sweeps = self.data_infos[curr_index]['sweeps']
            all_sweeps_next.extend(curr_sweeps[::-1])
            all_sweeps_next.append(self.data_infos[curr_index]['cams'])
            curr_index = curr_index + 1

        return all_sweeps_prev, all_sweeps_next

    def get_data_info(self, index):
        """Assemble the input dict (poses, image paths, lidar2img mats)."""
        info = self.data_infos[index]
        sweeps_prev, sweeps_next = self.collect_sweeps(index)

        ego2global_translation = info['ego2global_translation']
        ego2global_rotation = info['ego2global_rotation']
        lidar2ego_translation = info['lidar2ego_translation']
        lidar2ego_rotation = info['lidar2ego_rotation']
        ego2global_rotation = Quaternion(ego2global_rotation).rotation_matrix
        lidar2ego_rotation = Quaternion(lidar2ego_rotation).rotation_matrix

        input_dict = dict(
            sample_idx=info['token'],
            sweeps={'prev': sweeps_prev, 'next': sweeps_next},
            timestamp=info['timestamp'] / 1e6,
            ego2global_translation=ego2global_translation,
            ego2global_rotation=ego2global_rotation,
            lidar2ego_translation=lidar2ego_translation,
            lidar2ego_rotation=lidar2ego_rotation,
        )

        if self.modality['use_camera']:
            img_paths = []
            img_timestamps = []
            lidar2img_rts = []
            for _, cam_info in info['cams'].items():
                img_paths.append(os.path.relpath(cam_info['data_path']))
                img_timestamps.append(cam_info['timestamp'] / 1e6)

                # obtain lidar to image transformation matrix
                lidar2cam_r = np.linalg.inv(cam_info['sensor2lidar_rotation'])
                lidar2cam_t = cam_info['sensor2lidar_translation'] @ lidar2cam_r.T
                lidar2cam_rt = np.eye(4)
                lidar2cam_rt[:3, :3] = lidar2cam_r.T
                lidar2cam_rt[3, :3] = -lidar2cam_t

                intrinsic = cam_info['cam_intrinsic']
                viewpad = np.eye(4)
                viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic
                lidar2img_rt = (viewpad @ lidar2cam_rt.T)
                lidar2img_rts.append(lidar2img_rt)

            input_dict.update(dict(
                img_filename=img_paths,
                img_timestamp=img_timestamps,
                lidar2img=lidar2img_rts,
            ))

        if not self.test_mode:
            annos = self.get_ann_info(index)
            input_dict['ann_info'] = annos

        return input_dict
================================================ FILE: loaders/nuscenes_occ_dataset.py ================================================
import os
import mmcv
import glob
import torch
import numpy as np
from tqdm import tqdm
from mmdet.datasets import DATASETS
from mmdet3d.datasets import NuScenesDataset
from nuscenes.eval.common.utils import Quaternion
from nuscenes.utils.geometry_utils import transform_matrix
from torch.utils.data import DataLoader
from models.utils import sparse2dense
from .ray_metrics import main_rayiou, main_raypq
from .ego_pose_dataset import EgoPoseDataset
from configs.r50_nuimg_704x256_8f import occ_class_names as occ3d_class_names
from configs.r50_nuimg_704x256_8f_openocc import occ_class_names as openocc_class_names


@DATASETS.register_module()
class NuSceneOcc(NuScenesDataset):
    """NuScenes dataset with occupancy ground truth and RayIoU/RayPQ eval.

    `occ_gt_root` is expected to contain <scene_name>/<token>/labels.npz files.
    """

    def __init__(self, occ_gt_root, *args, **kwargs):
        super().__init__(filter_empty_gt=False, *args, **kwargs)
        self.occ_gt_root = occ_gt_root
        self.data_infos = self.load_annotations(self.ann_file)

        # recover each sample's scene name from the GT directory layout
        # (<root>/<scene_name>/<token>/*.npz)
        self.token2scene = {}
        for gt_path in glob.glob(os.path.join(self.occ_gt_root, '*/*/*.npz')):
            token = gt_path.split('/')[-2]
            scene_name = gt_path.split('/')[-3]
            self.token2scene[token] = scene_name

        for i in range(len(self.data_infos)):
            scene_name = self.token2scene[self.data_infos[i]['token']]
            self.data_infos[i]['scene_name'] = scene_name

    def collect_sweeps(self, index, into_past=150, into_future=0):
        # Same scheme as CustomNuScenesDataset.collect_sweeps, but reaches
        # further into the past (150 vs 60).
        all_sweeps_prev = []
        curr_index = index
        while len(all_sweeps_prev) < into_past:
            curr_sweeps = self.data_infos[curr_index]['sweeps']
            if len(curr_sweeps) == 0:
                break
            all_sweeps_prev.extend(curr_sweeps)
            all_sweeps_prev.append(self.data_infos[curr_index - 1]['cams'])
            curr_index = curr_index - 1

        all_sweeps_next = []
        curr_index = index + 1
        while len(all_sweeps_next) < into_future:
            if curr_index >= len(self.data_infos):
                break
            curr_sweeps = self.data_infos[curr_index]['sweeps']
            all_sweeps_next.extend(curr_sweeps[::-1])
            all_sweeps_next.append(self.data_infos[curr_index]['cams'])
            curr_index = curr_index + 1

        return all_sweeps_prev, all_sweeps_next

    def get_data_info(self, index):
        """Assemble the input dict, including ego2lidar and occ GT path."""
        info = self.data_infos[index]
        sweeps_prev, sweeps_next = self.collect_sweeps(index)

        ego2global_translation = info['ego2global_translation']
        ego2global_rotation = info['ego2global_rotation']
        lidar2ego_translation = info['lidar2ego_translation']
        lidar2ego_rotation = info['lidar2ego_rotation']
        ego2global_rotation_mat = Quaternion(ego2global_rotation).rotation_matrix
        lidar2ego_rotation_mat = Quaternion(lidar2ego_rotation).rotation_matrix

        input_dict = dict(
            sample_idx=info['token'],
            sweeps={'prev': sweeps_prev, 'next': sweeps_next},
            timestamp=info['timestamp'] / 1e6,
            ego2global_translation=ego2global_translation,
            ego2global_rotation=ego2global_rotation_mat,
            lidar2ego_translation=lidar2ego_translation,
            lidar2ego_rotation=lidar2ego_rotation_mat,
        )

        ego2lidar = transform_matrix(lidar2ego_translation,
            Quaternion(lidar2ego_rotation), inverse=True)
        # replicated once per camera (6 views)
        input_dict['ego2lidar'] = [ego2lidar for _ in range(6)]
        input_dict['occ_path'] = os.path.join(self.occ_gt_root, info['scene_name'], info['token'], 'labels.npz')

        if self.modality['use_camera']:
            img_paths = []
            img_timestamps = []
            lidar2img_rts = []
            for _, cam_info in info['cams'].items():
                img_paths.append(os.path.relpath(cam_info['data_path']))
                img_timestamps.append(cam_info['timestamp'] / 1e6)

                # obtain lidar to image transformation matrix
                lidar2cam_r = np.linalg.inv(cam_info['sensor2lidar_rotation'])
                lidar2cam_t = cam_info['sensor2lidar_translation'] @ lidar2cam_r.T
                lidar2cam_rt = np.eye(4)
                lidar2cam_rt[:3, :3] = lidar2cam_r.T
                lidar2cam_rt[3, :3] = -lidar2cam_t

                intrinsic = cam_info['cam_intrinsic']
                viewpad = np.eye(4)
                viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic
                lidar2img_rt = (viewpad @ lidar2cam_rt.T)
                lidar2img_rts.append(lidar2img_rt)

            input_dict.update(dict(
                img_filename=img_paths,
                img_timestamp=img_timestamps,
                lidar2img=lidar2img_rts,
            ))

        if not self.test_mode:
            annos = self.get_ann_info(index)
            input_dict['ann_info'] = annos

        return input_dict

    def evaluate(self, occ_results, runner=None, show_dir=None, **eval_kwargs):
        """Densify sparse predictions and score them with RayIoU (and RayPQ
        when panoptic outputs are present)."""
        occ_gts, occ_preds, inst_gts, inst_preds, lidar_origins = [], [], [], [], []

        print('\nStarting Evaluation...')

        sample_tokens = [info['token'] for info in self.data_infos]

        for batch in DataLoader(EgoPoseDataset(self.data_infos), num_workers=8):
            token = batch[0][0]
            output_origin = batch[1]
            # NOTE(review): list.index() makes this loop O(n^2) over the
            # dataset; a {token: idx} dict built once would be O(n).
            data_id = sample_tokens.index(token)
            info = self.data_infos[data_id]

            occ_path = os.path.join(self.occ_gt_root, info['scene_name'], info['token'], 'labels.npz')
            occ_gt = np.load(occ_path, allow_pickle=True)
            gt_semantics = occ_gt['semantics']

            occ_pred = occ_results[data_id]
            sem_pred = torch.from_numpy(occ_pred['sem_pred'])  # [B, N]
            occ_loc = torch.from_numpy(occ_pred['occ_loc'].astype(np.int64))  # [B, N, 3]

            # pick the class list from the GT flavor encoded in the dir name
            data_type = self.occ_gt_root.split('/')[-1]
            if data_type == 'occ3d' or data_type == 'occ3d_panoptic':
                occ_class_names = occ3d_class_names
            elif data_type == 'openocc_v2':
                occ_class_names = openocc_class_names
            else:
                raise ValueError
            free_id = len(occ_class_names) - 1

            # scatter the sparse predictions into a dense grid; empty = free
            occ_size = list(gt_semantics.shape)
            sem_pred, _ = sparse2dense(occ_loc, sem_pred, dense_shape=occ_size, empty_value=free_id)
            sem_pred = sem_pred.squeeze(0).numpy()

            if 'pano_inst' in occ_pred.keys():
                # panoptic branch: also densify instance ids and override
                # semantics with the panoptic semantic map
                pano_inst = torch.from_numpy(occ_pred['pano_inst'])
                pano_sem = torch.from_numpy(occ_pred['pano_sem'])
                pano_inst, _ = sparse2dense(occ_loc, pano_inst, dense_shape=occ_size, empty_value=0)
                pano_sem, _ = sparse2dense(occ_loc, pano_sem, dense_shape=occ_size, empty_value=free_id)
                pano_inst = pano_inst.squeeze(0).numpy()
                pano_sem = pano_sem.squeeze(0).numpy()
                sem_pred = pano_sem

                gt_instances = occ_gt['instances']
                inst_gts.append(gt_instances)
                inst_preds.append(pano_inst)

            lidar_origins.append(output_origin)
            occ_gts.append(gt_semantics)
            occ_preds.append(sem_pred)

        if len(inst_preds) > 0:
            results = main_raypq(occ_preds, occ_gts, inst_preds, inst_gts, lidar_origins, occ_class_names=occ_class_names)
            results.update(main_rayiou(occ_preds, occ_gts, lidar_origins, occ_class_names=occ_class_names))
            return results
        else:
            return main_rayiou(occ_preds, occ_gts, lidar_origins, occ_class_names=occ_class_names)

    def format_results(self, occ_results, submission_prefix, **kwargs):
        """Dump per-sample predictions as <token>.npz for submission."""
        if submission_prefix is not None:
            mmcv.mkdir_or_exist(submission_prefix)

        for index, occ_pred in enumerate(tqdm(occ_results)):
            info = self.data_infos[index]
            sample_token = info['token']
            save_path = os.path.join(submission_prefix, '{}.npz'.format(sample_token))
            np.savez_compressed(save_path, occ_pred.astype(np.uint8))
        print('\nFinished.')
================================================ FILE: loaders/old_metrics.py ================================================
import os
import numpy as np
from sklearn.neighbors import KDTree
from termcolor import colored
from functools import reduce
from typing import Iterable

# silence divide-by-zero / NaN warnings from the histogram ratios below
np.seterr(divide='ignore', invalid='ignore')
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def pcolor(string, color, on_color=None, attrs=None):
    """
    Produces a colored string for printing

    Parameters
    ----------
    string : str
        String that will be colored
    color : str
        Color to use
    on_color : str
        Background color to use
    attrs : list of str
        Different attributes for the string

    Returns
    -------
    string: str
        Colored string
    """
    return colored(string, color, on_color, attrs)


def getCellCoordinates(points, voxelSize):
    # Quantize points into integer voxel indices.
    # NOTE(review): np.int was removed in NumPy 1.24 — this raises
    # AttributeError on modern NumPy; should be plain int (or np.int64).
    return (points / voxelSize).astype(np.int)


def getNumUniqueCells(cells):
    # Count distinct voxels by hashing (x, y, z) into a single integer.
    M = cells.max() + 1
    return np.unique(cells[:, 0] + M * cells[:, 1] + M ** 2 * cells[:, 2]).shape[0]


class Metric_mIoU():
    """Accumulates a confusion matrix over samples and reports per-class IoU
    and mIoU (legacy voxel-wise metric)."""

    def __init__(self,
                 save_dir='.',
                 num_classes=18,
                 use_lidar_mask=False,
                 use_image_mask=False,
                 ):
        # NOTE(review): self.class_names is only set for num_classes in
        # {18, 2}; other values would break count_miou — confirm intended.
        if num_classes == 18:
            self.class_names = [
                'others', 'barrier', 'bicycle', 'bus', 'car',
                'construction_vehicle', 'motorcycle', 'pedestrian',
                'traffic_cone', 'trailer', 'truck', 'driveable_surface',
                'other_flat', 'sidewalk', 'terrain', 'manmade',
                'vegetation', 'free'
            ]
        elif num_classes == 2:
            self.class_names = ['non-free', 'free']

        self.save_dir = save_dir
        self.use_lidar_mask = use_lidar_mask
        self.use_image_mask = use_image_mask
        self.num_classes = num_classes

        self.point_cloud_range = [-40.0, -40.0, -1.0, 40.0, 40.0, 5.4]
        self.occupancy_size = [0.4, 0.4, 0.4]
        self.voxel_size = 0.4
        self.occ_xdim = int((self.point_cloud_range[3] - self.point_cloud_range[0]) / self.occupancy_size[0])
        self.occ_ydim = int((self.point_cloud_range[4] - self.point_cloud_range[1]) / self.occupancy_size[1])
        self.occ_zdim = int((self.point_cloud_range[5] - self.point_cloud_range[2]) / self.occupancy_size[2])
        self.voxel_num = self.occ_xdim * self.occ_ydim * self.occ_zdim
        self.hist = np.zeros((self.num_classes, self.num_classes))
        self.cnt = 0

    def hist_info(self, n_cl, pred, gt):
        """
        build confusion matrix
        # empty classes:0
        non-empty class: 0-16
        free voxel class: 17

        Args:
            n_cl (int): num_classes_occupancy
            pred (1-d array): pred_occupancy_label
            gt (1-d array): gt_occupancy_label

        Returns:
            tuple: (hist, correctly number_predicted_labels, num_labelled_sample)
        """
        assert pred.shape == gt.shape
        k = (gt >= 0) & (gt < n_cl)  # exclude 255
        labeled = np.sum(k)
        correct = np.sum((pred[k] == gt[k]))

        return (
            np.bincount(
                n_cl * gt[k].astype(int) + pred[k].astype(int),
                minlength=n_cl ** 2
            ).reshape(n_cl, n_cl),
            correct,
            labeled,
        )

    def per_class_iu(self, hist):
        # IoU = TP / (FN + FP + TP); classes absent from GT become NaN so
        # nanmean skips them.
        # return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
        result = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
        result[hist.sum(1) == 0] = float('nan')
        return result

    def compute_mIoU(self, pred, label, n_classes):
        hist = np.zeros((n_classes, n_classes))
        new_hist, correct, labeled = self.hist_info(n_classes, pred.flatten(), label.flatten())
        hist += new_hist
        mIoUs = self.per_class_iu(hist)
        # for ind_class in range(n_classes):
        #     print(str(round(mIoUs[ind_class] * 100, 2)))
        # print('===> mIoU: ' + str(round(np.nanmean(mIoUs) * 100, 2)))
        return round(np.nanmean(mIoUs) * 100, 2), hist

    def add_batch(self, semantics_pred, semantics_gt, mask_lidar, mask_camera):
        """Accumulate one sample into the running confusion matrix."""
        self.cnt += 1
        if self.use_image_mask:
            masked_semantics_gt = semantics_gt[mask_camera]
            masked_semantics_pred = semantics_pred[mask_camera]
        elif self.use_lidar_mask:
            masked_semantics_gt = semantics_gt[mask_lidar]
            masked_semantics_pred = semantics_pred[mask_lidar]
        else:
            masked_semantics_gt = semantics_gt
            masked_semantics_pred = semantics_pred

        if self.num_classes == 2:
            # collapse to binary free / non-free (17 is the free class)
            masked_semantics_pred = np.copy(masked_semantics_pred)
            masked_semantics_gt = np.copy(masked_semantics_gt)
            masked_semantics_pred[masked_semantics_pred < 17] = 0
            masked_semantics_pred[masked_semantics_pred == 17] = 1
            masked_semantics_gt[masked_semantics_gt < 17] = 0
            masked_semantics_gt[masked_semantics_gt == 17] = 1

        _, _hist = self.compute_mIoU(masked_semantics_pred, masked_semantics_gt, self.num_classes)
        self.hist += _hist

    def count_miou(self):
        """Print per-class IoU and return the mIoU over non-free classes."""
        mIoU = self.per_class_iu(self.hist)
        # assert cnt == num_samples, 'some samples are not included in the miou calculation'
        print(f'===> per class IoU of {self.cnt} samples:')
        for ind_class in range(self.num_classes - 1):
            print(f'===> {self.class_names[ind_class]} - IoU = ' + str(round(mIoU[ind_class] * 100, 2)))
        print(f'===> mIoU of {self.cnt} samples: ' + str(round(np.nanmean(mIoU[:self.num_classes - 1]) * 100, 2)))
        # print(f'===> sample-wise averaged mIoU of {cnt} samples: ' + str(round(np.nanmean(mIoU_avg), 2)))
        return round(np.nanmean(mIoU[:self.num_classes - 1]) * 100, 2)


class Metric_FScore():
    """Chamfer-style F-score between predicted and GT occupied voxels."""

    def __init__(self,
                 leaf_size=10,
                 threshold_acc=0.6,
                 threshold_complete=0.6,
                 voxel_size=[0.4, 0.4, 0.4],
                 range=[-40, -40, -1, 40, 40, 5.4],
                 void=[17, 255],
                 use_lidar_mask=False,
                 use_image_mask=False,
                 ) -> None:
        # NOTE(review): parameters `range` shadows the builtin and the
        # mutable list defaults are shared across instances — worth fixing
        # upstream (kept as-is here).
        self.leaf_size = leaf_size
        self.threshold_acc = threshold_acc
        self.threshold_complete = threshold_complete
        self.voxel_size = voxel_size
        self.range = range
        self.void = void
        self.use_lidar_mask = use_lidar_mask
        self.use_image_mask = use_image_mask
        self.cnt = 0
        self.tot_acc = 0.
        self.tot_cmpl = 0.
        self.tot_f1_mean = 0.
        self.eps = 1e-8

    def voxel2points(self, voxel):
        """Convert non-void voxels to their metric center coordinates [K, 3]."""
        # occIdx = torch.where(torch.logical_and(voxel != FREE, voxel != NOT_OBSERVED))
        # if isinstance(voxel, np.ndarray): voxel = torch.from_numpy(voxel)
        mask = np.logical_not(reduce(np.logical_or, [voxel == self.void[i] for i in range(len(self.void))]))
        occIdx = np.where(mask)

        points = np.concatenate((occIdx[0][:, None] * self.voxel_size[0] + self.voxel_size[0] / 2 + self.range[0], \
                                 occIdx[1][:, None] * self.voxel_size[1] + self.voxel_size[1] / 2 + self.range[1], \
                                 occIdx[2][:, None] * self.voxel_size[2] + self.voxel_size[2] / 2 + self.range[2]),
                                axis=1)
        return points

    def add_batch(self, semantics_pred, semantics_gt, mask_lidar, mask_camera):
        """Accumulate one sample's accuracy/completeness/F-score.

        NOTE(review): when a mask is active this writes 255 into the CALLER'S
        arrays (in-place mutation) — confirm callers do not reuse them.
        """
        # for scene_token in tqdm(preds_dict.keys()):
        self.cnt += 1

        if self.use_image_mask:
            semantics_gt[mask_camera == False] = 255
            semantics_pred[mask_camera == False] = 255
        elif self.use_lidar_mask:
            semantics_gt[mask_lidar == False] = 255
            semantics_pred[mask_lidar == False] = 255
        else:
            pass

        ground_truth = self.voxel2points(semantics_gt)
        prediction = self.voxel2points(semantics_pred)

        if prediction.shape[0] == 0:
            # no predicted foreground voxels: all scores are zero
            accuracy = 0
            completeness = 0
            fmean = 0
        else:
            prediction_tree = KDTree(prediction, leaf_size=self.leaf_size)
            ground_truth_tree = KDTree(ground_truth, leaf_size=self.leaf_size)
            complete_distance, _ = prediction_tree.query(ground_truth)
            complete_distance = complete_distance.flatten()
            accuracy_distance, _ = ground_truth_tree.query(prediction)
            accuracy_distance = accuracy_distance.flatten()

            # evaluate completeness
            complete_mask = complete_distance < self.threshold_complete
            completeness = complete_mask.mean()

            # evalute accuracy
            accuracy_mask = accuracy_distance < self.threshold_acc
            accuracy = accuracy_mask.mean()

            # harmonic mean of accuracy and completeness
            fmean = 2.0 / (1 / (accuracy + self.eps) + 1 / (completeness + self.eps))

        self.tot_acc += accuracy
        self.tot_cmpl += completeness
        self.tot_f1_mean += fmean

    def count_fscore(self,):
        """Print and return the mean F-score over all accumulated samples."""
        base_color, attrs = 'red', ['bold', 'dark']
        print(pcolor('\n######## F score: {} #######'.format(self.tot_f1_mean / self.cnt), base_color, attrs=attrs))
        return self.tot_f1_mean / self.cnt


class Metric_mRecall():
    """Per-class recall of an (18-class) GT against a coarse (e.g. binary)
    prediction; hist has shape (num_classes, pred_classes)."""

    def __init__(self,
                 save_dir='.',
                 num_classes=18,
                 pred_classes=2,
                 use_lidar_mask=False,
                 use_image_mask=False,
                 ):
        if num_classes == 18:
            self.class_names = [
                'others', 'barrier', 'bicycle', 'bus', 'car',
                'construction_vehicle', 'motorcycle', 'pedestrian',
                'traffic_cone', 'trailer', 'truck', 'driveable_surface',
                'other_flat', 'sidewalk', 'terrain', 'manmade',
                'vegetation', 'free'
            ]
        elif num_classes == 2:
            self.class_names = ['non-free', 'free']

        self.pred_classes = pred_classes
        self.save_dir = save_dir
        self.use_lidar_mask = use_lidar_mask
        self.use_image_mask = use_image_mask
        self.num_classes = num_classes

        self.point_cloud_range = [-40.0, -40.0, -1.0, 40.0, 40.0, 5.4]
        self.occupancy_size = [0.4, 0.4, 0.4]
        self.voxel_size = 0.4
        self.occ_xdim = int((self.point_cloud_range[3] - self.point_cloud_range[0]) / self.occupancy_size[0])
        self.occ_ydim = int((self.point_cloud_range[4] - self.point_cloud_range[1]) / self.occupancy_size[1])
        self.occ_zdim = int((self.point_cloud_range[5] - self.point_cloud_range[2]) / self.occupancy_size[2])
        self.voxel_num = self.occ_xdim * self.occ_ydim * self.occ_zdim
        self.hist = np.zeros((self.num_classes, self.pred_classes))  # n_cl, p_cl
        self.cnt = 0

    def hist_info(self, n_cl, p_cl, pred, gt):
        """
        build confusion matrix
        # empty classes:0
        non-empty class: 0-16
        free voxel class: 17

        Args:
            n_cl (int): num_classes_occupancy
            pred (1-d array): pred_occupancy_label
            gt (1-d array): gt_occupancy_label

        Returns:
            tuple: (hist, correctly number_predicted_labels, num_labelled_sample)
        """
        assert pred.shape == gt.shape
        k = (gt >= 0) & (gt < n_cl)  # exclude 255
        labeled = np.sum(k)
        correct = np.sum((pred[k] == gt[k]))

        return (
            np.bincount(
                p_cl * gt[k].astype(int) + pred[k].astype(int),
                minlength=n_cl * p_cl
            ).reshape(n_cl, p_cl),  # 18, 2
            correct,
            labeled,
        )

    def per_class_recall(self, hist):
        # recall per GT class = hits in prediction column 1 / total GT count
        return hist[:, 1] / hist.sum(1)  ## recall

    def compute_mRecall(self, pred, label, n_classes, p_classes):
        hist = np.zeros((n_classes, p_classes))
        new_hist, correct, labeled = self.hist_info(n_classes, p_classes, pred.flatten(), label.flatten())
        hist += new_hist
        mRecalls = self.per_class_recall(hist)
        # for ind_class in range(n_classes):
        #     print(str(round(mIoUs[ind_class] * 100, 2)))
        # print('===> mIoU: ' + str(round(np.nanmean(mIoUs) * 100, 2)))
        return round(np.nanmean(mRecalls) * 100, 2), hist

    def add_batch(self, semantics_pred, semantics_gt, mask_lidar, mask_camera):
        """Accumulate one sample into the running recall histogram."""
        self.cnt += 1
        if self.use_image_mask:
            masked_semantics_gt = semantics_gt[mask_camera]
            masked_semantics_pred = semantics_pred[mask_camera]
        elif self.use_lidar_mask:
            masked_semantics_gt = semantics_gt[mask_lidar]
            masked_semantics_pred = semantics_pred[mask_lidar]
        else:
            masked_semantics_gt = semantics_gt
            masked_semantics_pred = semantics_pred

        if self.pred_classes == 2:
            # binarize the PREDICTION only: occupied -> 1, free (17) -> 0;
            # the GT keeps its full label set (rows of the histogram)
            masked_semantics_pred = np.copy(masked_semantics_pred)
            masked_semantics_gt = np.copy(masked_semantics_gt)
            masked_semantics_pred[masked_semantics_pred < 17] = 1
            masked_semantics_pred[masked_semantics_pred == 17] = 0  # 0 is free

        _, _hist = self.compute_mRecall(masked_semantics_pred, masked_semantics_gt, self.num_classes, self.pred_classes)
        self.hist += _hist

    def count_mrecall(self):
        """Print per-class recall and return the mean over non-free classes."""
        mRecall = self.per_class_recall(self.hist)
        # assert cnt == num_samples, 'some samples are not included in the miou calculation'
        print(f'===> per class Recall of {self.cnt} samples:')
        for ind_class in range(self.num_classes - 1):
            print(f'===> {self.class_names[ind_class]} - Recall = ' + str(round(mRecall[ind_class] * 100, 2)))
        print(f'===> mRecall of {self.cnt} samples: ' + str(round(np.nanmean(mRecall[:self.num_classes - 1]) * 100, 2)))
        return round(np.nanmean(mRecall[:self.num_classes - 1]) * 100, 2)


# modified from https://github.com/open-mmlab/mmdetection3d/blob/main/mmdet3d/evaluation/functional/panoptic_seg_eval.py#L10
class Metric_Panoptic():
    """Voxel-wise panoptic quality (PQ/SQ/RQ) over occupancy grids."""

    def __init__(self,
                 save_dir='.',
                 num_classes=18,
                 use_lidar_mask=False,
                 use_image_mask=False,
                 ignore_index: Iterable[int] = [],
                 ):
        """
        Args:
            ignore_index (list): Class ids that not be considered in pq counting.
        """
        # NOTE(review): ignore_index has a mutable default ([]) — harmless as
        # long as it is never mutated, but a tuple/None default is safer.
        if num_classes == 18:
            self.class_names = [
                'others', 'barrier', 'bicycle', 'bus', 'car',
                'construction_vehicle', 'motorcycle', 'pedestrian',
                'traffic_cone', 'trailer', 'truck', 'driveable_surface',
                'other_flat', 'sidewalk', 'terrain', 'manmade',
                'vegetation', 'free'
            ]
        else:
            raise ValueError

        self.save_dir = save_dir
        self.num_classes = num_classes
        self.use_lidar_mask = use_lidar_mask
        self.use_image_mask = use_image_mask
        self.ignore_index = ignore_index

        self.id_offset = 2 ** 16  # packs (gt_id, pred_id) into one integer
        self.eps = 1e-5
        self.min_num_points = 20  # GT/pred segments smaller than this don't count as FN/FP
        self.include = np.array(
            [n for n in range(self.num_classes - 1) if n not in self.ignore_index],
            dtype=int)
        self.cnt = 0

        # panoptic stuff
        self.pan_tp = np.zeros(self.num_classes, dtype=int)
        self.pan_iou = np.zeros(self.num_classes, dtype=np.double)
        self.pan_fp = np.zeros(self.num_classes, dtype=int)
        self.pan_fn = np.zeros(self.num_classes, dtype=int)

    def add_batch(self, semantics_pred, semantics_gt, instances_pred, instances_gt, mask_lidar, mask_camera):
        """Apply the configured visibility mask, then accumulate one sample."""
        self.cnt += 1
        if self.use_image_mask:
            masked_semantics_gt = semantics_gt[mask_camera]
            masked_semantics_pred = semantics_pred[mask_camera]
            masked_instances_gt = instances_gt[mask_camera]
            masked_instances_pred = instances_pred[mask_camera]
        elif self.use_lidar_mask:
            masked_semantics_gt = semantics_gt[mask_lidar]
            masked_semantics_pred = semantics_pred[mask_lidar]
            masked_instances_gt = instances_gt[mask_lidar]
            masked_instances_pred = instances_pred[mask_lidar]
        else:
            masked_semantics_gt = semantics_gt
            masked_semantics_pred = semantics_pred
            masked_instances_gt = instances_gt
            masked_instances_pred = instances_pred

        self.add_panoptic_sample(masked_semantics_pred, masked_semantics_gt, masked_instances_pred, masked_instances_gt)

    def add_panoptic_sample(self, semantics_pred, semantics_gt, instances_pred, instances_gt):
        """Add one sample of panoptic predictions and ground truths for
        evaluation.

        Args:
            semantics_pred (np.ndarray): Semantic predictions.
            semantics_gt (np.ndarray): Semantic ground truths.
            instances_pred (np.ndarray): Instance predictions.
            instances_gt (np.ndarray): Instance ground truths.
        """
        # get instance_class_id from instance_gt
        instance_class_ids = [self.num_classes - 1]
        for i in range(1, instances_gt.max() + 1):
            class_id = np.unique(semantics_gt[instances_gt == i])
            # assert class_id.shape[0] == 1, "each instance must belong to only one class"
            if class_id.shape[0] == 1:
                instance_class_ids.append(class_id[0])
            else:
                # ambiguous instance: fall back to the free class id
                instance_class_ids.append(self.num_classes - 1)
        instance_class_ids = np.array(instance_class_ids)

        # Re-number instances consecutively; "stuff" classes collapse into a
        # single instance each, "thing" classes keep their GT instances.
        instance_count = 1
        final_instance_class_ids = []
        final_instances = np.zeros_like(instances_gt)  # empty space has instance id "0"

        for class_id in range(self.num_classes - 1):
            if np.sum(semantics_gt == class_id) == 0:
                continue

            if self.class_names[class_id] in ['car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'motorcycle', 'bicycle', 'pedestrian']:
                # treat as instances
                for instance_id in range(len(instance_class_ids)):
                    if instance_class_ids[instance_id] != class_id:
                        continue
                    final_instances[instances_gt == instance_id] = instance_count
                    instance_count += 1
                    final_instance_class_ids.append(class_id)
            else:
                # treat as semantics
                final_instances[semantics_gt == class_id] = instance_count
                instance_count += 1
                final_instance_class_ids.append(class_id)

        instances_gt = final_instances

        # avoid zero (ignored label)
        instances_pred = instances_pred + 1
        instances_gt = instances_gt + 1

        for cl in self.ignore_index:
            # make a mask for this class
            gt_not_in_excl_mask = semantics_gt != cl
            # remove all other points
            semantics_pred = semantics_pred[gt_not_in_excl_mask]
            semantics_gt = semantics_gt[gt_not_in_excl_mask]
            instances_pred = instances_pred[gt_not_in_excl_mask]
            instances_gt = instances_gt[gt_not_in_excl_mask]

        # for each class (except the ignored ones)
        for cl in self.include:
            # get a class mask
            pred_inst_in_cl_mask = semantics_pred == cl
            gt_inst_in_cl_mask = semantics_gt == cl

            # get instance points in class (makes outside stuff 0)
            pred_inst_in_cl = instances_pred * pred_inst_in_cl_mask.astype(int)
            gt_inst_in_cl = instances_gt * gt_inst_in_cl_mask.astype(int)

            # generate the areas for each unique instance prediction
            unique_pred, counts_pred = np.unique(
                pred_inst_in_cl[pred_inst_in_cl > 0], return_counts=True)
            id2idx_pred = {id: idx for idx, id in enumerate(unique_pred)}
            matched_pred = np.array([False] * unique_pred.shape[0])

            # generate the areas for each unique instance gt_np
            unique_gt, counts_gt = np.unique(
                gt_inst_in_cl[gt_inst_in_cl > 0], return_counts=True)
            id2idx_gt = {id: idx for idx, id in enumerate(unique_gt)}
            matched_gt = np.array([False] * unique_gt.shape[0])

            # generate intersection using offset
            valid_combos = np.logical_and(pred_inst_in_cl > 0, gt_inst_in_cl > 0)
            id_offset_combo = pred_inst_in_cl[
                valid_combos] + self.id_offset * gt_inst_in_cl[valid_combos]
            unique_combo, counts_combo = np.unique(
                id_offset_combo, return_counts=True)

            # generate an intersection map
            # count the intersections with over 0.5 IoU as TP
            gt_labels = unique_combo // self.id_offset
            pred_labels = unique_combo % self.id_offset
            gt_areas = np.array([counts_gt[id2idx_gt[id]] for id in gt_labels])
            pred_areas = np.array(
                [counts_pred[id2idx_pred[id]] for id in pred_labels])
            intersections = counts_combo
            unions = gt_areas + pred_areas - intersections
            ious = intersections.astype(float) / unions.astype(float)

            tp_indexes = ious > 0.5
            self.pan_tp[cl] += np.sum(tp_indexes)
            self.pan_iou[cl] += np.sum(ious[tp_indexes])

            matched_gt[[id2idx_gt[id] for id in gt_labels[tp_indexes]]] = True
            matched_pred[[id2idx_pred[id] for id in pred_labels[tp_indexes]]] = True

            # count the FN
            if len(counts_gt) > 0:
                self.pan_fn[cl] += np.sum(
                    np.logical_and(counts_gt >= self.min_num_points, ~matched_gt))

            # count the FP
            if len(matched_pred) > 0:
                self.pan_fp[cl] += np.sum(
                    np.logical_and(counts_pred >= self.min_num_points, ~matched_pred))

    def count_pq(self, ):
        """Compute and print SQ/RQ/PQ; returns (pq, sq, rq) as percentages."""
        sq_all = self.pan_iou.astype(np.double) / np.maximum(
            self.pan_tp.astype(np.double), self.eps)
        rq_all = self.pan_tp.astype(np.double) / np.maximum(
            self.pan_tp.astype(np.double) + 0.5 * self.pan_fp.astype(np.double) + 0.5 * self.pan_fn.astype(np.double),
            self.eps)
        pq_all = sq_all * rq_all

        # mask classes not occurring in dataset
        mask = (self.pan_tp + self.pan_fp + self.pan_fn) > 0
        sq_all[~mask] = float('nan')
        rq_all[~mask] = float('nan')
        pq_all[~mask] = float('nan')

        # then do the REAL mean (no ignored classes)
        sq = round(np.nanmean(sq_all[self.include]) * 100, 2)
        rq = round(np.nanmean(rq_all[self.include]) * 100, 2)
        pq = round(np.nanmean(pq_all[self.include]) * 100, 2)

        print(f'===> per class sq, rq, pq of {self.cnt} samples:')
        for ind_class in self.include:
            print(f'===> {self.class_names[ind_class]} -' + \
                  f' sq = {round(sq_all[ind_class] * 100, 2)},' + \
                  f' rq = {round(rq_all[ind_class] * 100, 2)},' + \
                  f' pq = {round(pq_all[ind_class] * 100, 2)}')
        print(f'===> sq of {self.cnt} samples: ' + str(sq))
        print(f'===> rq of {self.cnt} samples: ' + str(rq))
        print(f'===> pq of {self.cnt} samples: ' + str(pq))

        return (pq, sq, rq)
================================================ FILE: loaders/pipelines/__init__.py ================================================
from .loading import LoadMultiViewImageFromMultiSweeps, LoadOccGTFromFile
from .transforms import PadMultiViewImage, NormalizeMultiviewImage, PhotoMetricDistortionMultiViewImage

__all__ = [
    'LoadMultiViewImageFromMultiSweeps', 'PadMultiViewImage',
    'NormalizeMultiviewImage', 'PhotoMetricDistortionMultiViewImage',
    'LoadOccGTFromFile'
]
================================================ FILE: loaders/pipelines/loading.py ================================================
import os
import mmcv
import torch
import numpy as np
from mmdet.datasets.builder import PIPELINES
from numpy.linalg import inv
from mmcv.runner import get_dist_info
from mmcv.parallel import DataContainer as DC
from
mmdet.datasets.pipelines import to_tensor
from torchvision.transforms.functional import rotate


def compose_lidar2img(ego2global_translation_curr,
                      ego2global_rotation_curr,
                      lidar2ego_translation_curr,
                      lidar2ego_rotation_curr,
                      sensor2global_translation_past,
                      sensor2global_rotation_past,
                      cam_intrinsic_past):
    """Compose the 4x4 lidar->image projection for a past camera frame.

    Chains the current frame's lidar->ego->global transforms with a past
    frame's global->camera extrinsics, then applies that camera's
    intrinsics. Returns a float32 [4, 4] matrix.
    """
    R = sensor2global_rotation_past @ (inv(ego2global_rotation_curr).T @ inv(lidar2ego_rotation_curr).T)
    T = sensor2global_translation_past @ (inv(ego2global_rotation_curr).T @ inv(lidar2ego_rotation_curr).T)
    T -= ego2global_translation_curr @ (inv(ego2global_rotation_curr).T @ inv(lidar2ego_rotation_curr).T) + lidar2ego_translation_curr @ inv(lidar2ego_rotation_curr).T

    lidar2cam_r = inv(R.T)
    lidar2cam_t = T @ lidar2cam_r.T
    lidar2cam_rt = np.eye(4)
    lidar2cam_rt[:3, :3] = lidar2cam_r.T
    lidar2cam_rt[3, :3] = -lidar2cam_t

    # pad the camera intrinsics into a 4x4 "viewpad" matrix
    viewpad = np.eye(4)
    viewpad[:cam_intrinsic_past.shape[0], :cam_intrinsic_past.shape[1]] = cam_intrinsic_past
    lidar2img = (viewpad @ lidar2cam_rt.T).astype(np.float32)
    return lidar2img


@PIPELINES.register_module()
class LoadMultiViewImageFromMultiSweeps(object):
    """Append images and metadata from previous sweeps to the current sample."""

    def __init__(self, sweeps_num=5, color_type='color', test_mode=False):
        self.sweeps_num = sweeps_num
        self.color_type = color_type
        self.test_mode = test_mode
        self.train_interval = [4, 8]  # random sweep-stride range used in training
        self.test_interval = 6        # fixed sweep stride used in testing

        # prefer the faster turbojpeg backend when it is installed
        try:
            mmcv.use_backend('turbojpeg')
        except ImportError:
            mmcv.use_backend('cv2')

    def load_offline(self, results):
        cam_types = [
            'CAM_FRONT',
            'CAM_FRONT_RIGHT',
            'CAM_FRONT_LEFT',
            'CAM_BACK',
            'CAM_BACK_LEFT',
            'CAM_BACK_RIGHT'
        ]

        # no history available: replicate the current frame sweeps_num times
        if len(results['sweeps']['prev']) == 0:
            for _ in range(self.sweeps_num):
                for j in range(len(cam_types)):
                    results['img'].append(results['img'][j])
                    results['img_timestamp'].append(results['img_timestamp'][j])
                    results['filename'].append(results['filename'][j])
                    results['lidar2img'].append(np.copy(results['lidar2img'][j]))
                    # NOTE(review): indentation reconstructed from flattened
                    # source — ego2lidar append assumed per camera view; verify.
                    if 'ego2lidar' in results:
                        results['ego2lidar'].append(results['ego2lidar'][0])
        else:
            if self.test_mode:
                interval = self.test_interval
                choices = [(k + 1) *
interval - 1 for k in range(self.sweeps_num)]
            elif len(results['sweeps']['prev']) <= self.sweeps_num:
                # not enough history: take every frame, then pad with the oldest
                pad_len = self.sweeps_num - len(results['sweeps']['prev'])
                choices = list(range(len(results['sweeps']['prev']))) + [len(results['sweeps']['prev']) - 1] * pad_len
            else:
                # random stride bounded by train_interval and the history length
                max_interval = len(results['sweeps']['prev']) // self.sweeps_num
                max_interval = min(max_interval, self.train_interval[1])
                min_interval = min(max_interval, self.train_interval[0])
                interval = np.random.randint(min_interval, max_interval + 1)
                choices = [(k + 1) * interval - 1 for k in range(self.sweeps_num)]

            for idx in sorted(list(choices)):
                sweep_idx = min(idx, len(results['sweeps']['prev']) - 1)
                sweep = results['sweeps']['prev'][sweep_idx]
                # fall back to the previous sweep if this one misses cameras
                if len(sweep.keys()) < len(cam_types):
                    sweep = results['sweeps']['prev'][sweep_idx - 1]

                for sensor in cam_types:
                    results['img'].append(mmcv.imread(sweep[sensor]['data_path'], self.color_type))
                    results['img_timestamp'].append(sweep[sensor]['timestamp'] / 1e6)
                    results['filename'].append(os.path.relpath(sweep[sensor]['data_path']))
                    results['lidar2img'].append(compose_lidar2img(
                        results['ego2global_translation'],
                        results['ego2global_rotation'],
                        results['lidar2ego_translation'],
                        results['lidar2ego_rotation'],
                        sweep[sensor]['sensor2global_translation'],
                        sweep[sensor]['sensor2global_rotation'],
                        sweep[sensor]['cam_intrinsic'],
                    ))
                    # NOTE(review): indentation reconstructed — ego2lidar append
                    # assumed per camera view, mirroring the empty-history branch.
                    if 'ego2lidar' in results:
                        results['ego2lidar'].append(results['ego2lidar'][0])

        return results

    def load_online(self, results):
        # only used when measuring FPS
        assert self.test_mode
        assert self.test_interval % 6 == 0

        cam_types = [
            'CAM_FRONT',
            'CAM_FRONT_RIGHT',
            'CAM_FRONT_LEFT',
            'CAM_BACK',
            'CAM_BACK_LEFT',
            'CAM_BACK_RIGHT'
        ]

        # same sweep selection as load_offline, but 'img' is never appended:
        # history frames are not decoded, only their metadata is carried.
        if len(results['sweeps']['prev']) == 0:
            for _ in range(self.sweeps_num):
                for j in range(len(cam_types)):
                    results['img_timestamp'].append(results['img_timestamp'][j])
                    results['filename'].append(results['filename'][j])
                    results['lidar2img'].append(np.copy(results['lidar2img'][j]))
                    if 'ego2lidar' in results:
                        results['ego2lidar'].append(results['ego2lidar'][0])
        else:
            interval = self.test_interval
            choices = [(k + 1) * interval - 1 for k in range(self.sweeps_num)]

            for idx in sorted(list(choices)):
                sweep_idx = min(idx, len(results['sweeps']['prev']) - 1)
                sweep = results['sweeps']['prev'][sweep_idx]
                if len(sweep.keys()) < len(cam_types):
                    sweep = results['sweeps']['prev'][sweep_idx - 1]

                for sensor in cam_types:
                    # skip loading history frames
                    results['img_timestamp'].append(sweep[sensor]['timestamp'] / 1e6)
                    results['filename'].append(os.path.relpath(sweep[sensor]['data_path']))
                    results['lidar2img'].append(compose_lidar2img(
                        results['ego2global_translation'],
                        results['ego2global_rotation'],
                        results['lidar2ego_translation'],
                        results['lidar2ego_rotation'],
                        sweep[sensor]['sensor2global_translation'],
                        sweep[sensor]['sensor2global_rotation'],
                        sweep[sensor]['cam_intrinsic'],
                    ))
                    if 'ego2lidar' in results:
                        results['ego2lidar'].append(results['ego2lidar'][0])

        return results

    def __call__(self, results):
        if self.sweeps_num == 0:
            return results
        world_size = get_dist_info()[1]
        # single-GPU test: use the online path that skips decoding history
        if world_size == 1 and self.test_mode:
            return self.load_online(results)
        else:
            return self.load_offline(results)


@PIPELINES.register_module()
class LoadOccGTFromFile(object):
    """Load occupancy semantic (and optional instance) GT from an .npz file."""

    def __init__(self, num_classes=18, inst_class_ids=[]):
        # NOTE(review): mutable default inst_class_ids=[] is only read
        # (membership tests), never mutated, so it is harmless here.
        self.num_classes = num_classes
        self.inst_class_ids = inst_class_ids

    def __call__(self, results):
        occ_labels = np.load(results['occ_path'])
        semantics = occ_labels['semantics']  # [200, 200, 16]
        # mask_lidar = occ_labels['mask_lidar'].astype(np.bool_)  # [200, 200, 16]
        # mask_camera = occ_labels['mask_camera'].astype(np.bool_)  # [200, 200, 16]
        # results['mask_lidar'] = mask_lidar
        # results['mask_camera'] = mask_camera

        # instance GT
        if 'instances' in occ_labels.keys():
            instances = occ_labels['instances']
            instance_class_ids = [self.num_classes - 1]  # the 0-th class is always free class
            for i in range(1, instances.max() + 1):
                class_id = np.unique(semantics[instances == i])
                assert class_id.shape[0] ==
1, "each instance must belong to only one class"
                instance_class_ids.append(class_id[0])
            instance_class_ids = np.array(instance_class_ids)
        else:
            instances = None
            instance_class_ids = None

        # Relabel: give every semantic class (and every instance of the
        # configured "thing" classes) a compact, consecutive instance id.
        instance_count = 0
        final_instance_class_ids = []
        final_instances = np.ones_like(semantics) * 255  # empty space has instance id "255"

        for class_id in range(self.num_classes - 1):
            if np.sum(semantics == class_id) == 0:
                continue

            if class_id in self.inst_class_ids:
                assert instances is not None, 'instance annotation not found'
                # treat as instances
                for instance_id in range(len(instance_class_ids)):
                    if instance_class_ids[instance_id] != class_id:
                        continue
                    final_instances[instances == instance_id] = instance_count
                    instance_count += 1
                    final_instance_class_ids.append(class_id)
            else:
                # treat as semantics
                final_instances[semantics == class_id] = instance_count
                instance_count += 1
                final_instance_class_ids.append(class_id)

        results['voxel_semantics'] = semantics
        results['voxel_instances'] = final_instances
        results['instance_class_ids'] = DC(to_tensor(final_instance_class_ids))

        # Apply the BEV augmentation (set by BEVAug) to the GT volumes too.
        if results.get('rotate_bda', False):
            semantics = torch.from_numpy(semantics).permute(2, 0, 1)  # [16, 200, 200]
            semantics = rotate(semantics, results['rotate_bda'], fill=255).permute(1, 2, 0)  # [200, 200, 16]
            results['voxel_semantics'] = semantics.numpy()

            final_instances = torch.from_numpy(final_instances).permute(2, 0, 1)  # [16, 200, 200]
            final_instances = rotate(final_instances, results['rotate_bda'], fill=255).permute(1, 2, 0)  # [200, 200, 16]
            results['voxel_instances'] = final_instances.numpy()

        if results.get('flip_dx', False):
            results['voxel_semantics'] = results['voxel_semantics'][::-1, ...].copy()
            results['voxel_instances'] = results['voxel_instances'][::-1, ...].copy()

        if results.get('flip_dy', False):
            results['voxel_semantics'] = results['voxel_semantics'][:, ::-1, ...].copy()
            results['voxel_instances'] = results['voxel_instances'][:, ::-1, ...].copy()

        return results


# 
https://github.com/HuangJunJie2017/BEVDet/blob/58c2587a8f89a1927926f0bdb6cde2917c91a9a5/mmdet3d/datasets/pipelines/loading.py#L1177
@PIPELINES.register_module()
class BEVAug(object):
    """BEV-space augmentation (rotate / scale / flip), following BEVDet."""

    def __init__(self, bda_aug_conf, classes, is_train=True):
        self.bda_aug_conf = bda_aug_conf
        self.is_train = is_train
        self.classes = classes

    def sample_bda_augmentation(self):
        """Generate bda augmentation values based on bda_config."""
        if self.is_train:
            rotate_bda = np.random.uniform(*self.bda_aug_conf['rot_lim'])
            scale_bda = np.random.uniform(*self.bda_aug_conf['scale_lim'])
            flip_dx = np.random.uniform() < self.bda_aug_conf['flip_dx_ratio']
            flip_dy = np.random.uniform() < self.bda_aug_conf['flip_dy_ratio']
        else:
            # deterministic identity transform at test time
            rotate_bda = 0
            scale_bda = 1.0
            flip_dx = False
            flip_dy = False
        return rotate_bda, scale_bda, flip_dx, flip_dy

    def bev_transform(self, rotate_angle, scale_ratio, flip_dx, flip_dy):
        """
        Returns:
            rot_mat: (3, 3)
        """
        rotate_angle = torch.tensor(rotate_angle / 180 * np.pi)  # degrees -> radians
        rot_sin = torch.sin(rotate_angle)
        rot_cos = torch.cos(rotate_angle)
        rot_mat = torch.Tensor([[rot_cos, -rot_sin, 0],
                                [rot_sin, rot_cos, 0],
                                [0, 0, 1]])
        scale_mat = torch.Tensor([[scale_ratio, 0, 0],
                                  [0, scale_ratio, 0],
                                  [0, 0, scale_ratio]])
        flip_mat = torch.Tensor([[1, 0, 0],
                                 [0, 1, 0],
                                 [0, 0, 1]])
        if flip_dx:
            flip_mat = flip_mat @ torch.Tensor([[-1, 0, 0],
                                                [0, 1, 0],
                                                [0, 0, 1]])
        if flip_dy:
            flip_mat = flip_mat @ torch.Tensor([[1, 0, 0],
                                                [0, -1, 0],
                                                [0, 0, 1]])
        # composition order: flip after scale after rotation
        rot_mat = flip_mat @ (scale_mat @ rot_mat)
        return rot_mat

    def __call__(self, results):
        rotate_bda, scale_bda, flip_dx, flip_dy = self.sample_bda_augmentation()
        bda_mat = torch.zeros(4, 4)
        bda_mat[3, 3] = 1
        # bda_rot: (3, 3)
        bda_rot = self.bev_transform(rotate_bda, scale_bda, flip_dx, flip_dy)
        bda_mat[:3, :3] = bda_rot

        # stash the sampled parameters so later pipeline stages
        # (e.g. LoadOccGTFromFile) can apply the same augmentation to the GT
        results['bda_mat'] = bda_mat
        results['flip_dx'] = flip_dx
        results['flip_dy'] = flip_dy
        results['rotate_bda'] = rotate_bda
        results['scale_bda'] = scale_bda

        # fold the inverse BEV augmentation into every ego->lidar transform
        for i in range(len(results['ego2lidar'])):
            results['ego2lidar'][i] = results['ego2lidar'][i] @ torch.inverse(bda_mat).numpy()  # [4, 4] @ [4, 4]

        return results


================================================
FILE: loaders/pipelines/transforms.py
================================================
import mmcv
import torch
import numpy as np
from PIL import Image
from numpy import random
from mmdet.datasets.builder import PIPELINES


@PIPELINES.register_module()
class PadMultiViewImage(object):
    """Pad the multi-view image.
    There are two padding modes: (1) pad to a fixed size and (2) pad to the
    minimum size that is divisible by some number.
    Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor",

    Args:
        size (tuple, optional): Fixed padding size.
        size_divisor (int, optional): The divisor of padded size.
        pad_val (float, optional): Padding value, 0 by default.
    """

    def __init__(self, size=None, size_divisor=None, pad_val=0):
        self.size = size
        self.size_divisor = size_divisor
        self.pad_val = pad_val
        # only one of size and size_divisor should be valid
        assert size is not None or size_divisor is not None
        assert size is None or size_divisor is None

    def _pad_img(self, img):
        # pad bottom/right only, so the top-left origin (and therefore the
        # camera intrinsics) remain valid
        if self.size_divisor is not None:
            pad_h = int(np.ceil(img.shape[0] / self.size_divisor)) * self.size_divisor
            pad_w = int(np.ceil(img.shape[1] / self.size_divisor)) * self.size_divisor
        else:
            pad_h, pad_w = self.size
        pad_width = ((0, pad_h - img.shape[0]), (0, pad_w - img.shape[1]), (0, 0))
        img = np.pad(img, pad_width, constant_values=self.pad_val)
        return img

    def _pad_imgs(self, results):
        padded_img = [self._pad_img(img) for img in results['img']]
        results['ori_shape'] = [img.shape for img in results['img']]
        results['img'] = padded_img
        results['img_shape'] = [img.shape for img in padded_img]
        results['pad_shape'] = [img.shape for img in padded_img]
        results['pad_fixed_size'] = self.size
        results['pad_size_divisor'] = self.size_divisor

    def __call__(self, results):
        """Call function to pad images, masks, semantic segmentation maps.
        Args:
            results (dict): Result dict from loading pipeline.
        Returns:
            dict: Updated result dict.
        """
        self._pad_imgs(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(size={self.size}, '
        repr_str += f'size_divisor={self.size_divisor}, '
        repr_str += f'pad_val={self.pad_val})'
        return repr_str


@PIPELINES.register_module()
class NormalizeMultiviewImage(object):
    """Normalize the image.
    Added key is "img_norm_cfg".

    Args:
        mean (sequence): Mean values of 3 channels.
        std (sequence): Std values of 3 channels.
        to_rgb (bool): Whether to convert the image from BGR to RGB,
            default is true.
    """

    def __init__(self, mean, std, to_rgb=True):
        self.mean = np.array(mean, dtype=np.float32).reshape(-1)
        # store the reciprocal so __call__ multiplies instead of divides
        self.std = 1 / np.array(std, dtype=np.float32).reshape(-1)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Call function to normalize images.
        Args:
            results (dict): Result dict from loading pipeline.
        Returns:
            dict: Normalized results, 'img_norm_cfg' key is added into
                result dict.
        """
        normalized_imgs = []
        for img in results['img']:
            img = img.astype(np.float32)
            if self.to_rgb:
                img = img[..., ::-1]  # BGR -> RGB channel flip
            img = img - self.mean
            img = img * self.std
            normalized_imgs.append(img)
        results['img'] = normalized_imgs
        results['img_norm_cfg'] = dict(
            mean=self.mean, std=self.std, to_rgb=self.to_rgb
        )
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'
        return repr_str


@PIPELINES.register_module()
class PhotoMetricDistortionMultiViewImage:
    """Apply photometric distortion to image sequentially, every transformation
    is applied with a probability of 0.5. The position of random contrast is in
    second or second to last.
    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)
    8. randomly swap channels
    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def __call__(self, results):
        """Call function to perform photometric distortion on images.
        Args:
            results (dict): Result dict from loading pipeline.
        Returns:
            dict: Result dict with images distorted.
        """
        imgs = results['img']
        new_imgs = []
        for img in imgs:
            ori_dtype = img.dtype
            img = img.astype(np.float32)
            # random brightness
            if random.randint(2):
                delta = random.uniform(-self.brightness_delta,
                                       self.brightness_delta)
                img += delta
            # mode == 0 --> do random contrast first
            # mode == 1 --> do random contrast last
            mode = random.randint(2)
            if mode == 1:
                if random.randint(2):
                    alpha = random.uniform(self.contrast_lower,
                                           self.contrast_upper)
                    img *= alpha
            # convert color from BGR to HSV
            img = mmcv.bgr2hsv(img)
            # random saturation
            if random.randint(2):
                img[..., 1] *= random.uniform(self.saturation_lower,
                                              self.saturation_upper)
            # random hue
            if random.randint(2):
                img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
                # wrap hue back into the [0, 360] range
                img[..., 0][img[..., 0] > 360] -= 360
                img[..., 0][img[..., 0] < 0] += 360
            # convert color from HSV to BGR
            img = mmcv.hsv2bgr(img)
            # random contrast
            if mode == 0:
                if random.randint(2):
                    alpha = random.uniform(self.contrast_lower,
                                           self.contrast_upper)
                    img *= alpha
            # randomly swap channels
            if random.randint(2):
                img = img[..., random.permutation(3)]
            new_imgs.append(img.astype(ori_dtype))
        results['img'] = new_imgs
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(\nbrightness_delta={self.brightness_delta},\n'
        repr_str += 'contrast_range='
        repr_str += f'{(self.contrast_lower, self.contrast_upper)},\n'
        repr_str += 
'saturation_range='
        repr_str += f'{(self.saturation_lower, self.saturation_upper)},\n'
        repr_str += f'hue_delta={self.hue_delta})'
        return repr_str


@PIPELINES.register_module()
class RandomTransformImage(object):
    """Resize/crop/flip/rotate all views with one shared augmentation,
    updating each lidar2img projection with the image-space homography."""

    def __init__(self, ida_aug_conf=None, training=True):
        self.ida_aug_conf = ida_aug_conf
        self.training = training

    def __call__(self, results):
        # a single augmentation is shared by all views (current + sweeps)
        resize, resize_dims, crop, flip, rotate = self.sample_augmentation()

        if len(results['lidar2img']) == len(results['img']):
            # one projection matrix per image: update them pairwise
            for i in range(len(results['img'])):
                img = Image.fromarray(np.uint8(results['img'][i]))
                # resize, resize_dims, crop, flip, rotate = self._sample_augmentation()
                img, ida_mat = self.img_transform(
                    img,
                    resize=resize,
                    resize_dims=resize_dims,
                    crop=crop,
                    flip=flip,
                    rotate=rotate,
                )
                results['img'][i] = np.array(img).astype(np.uint8)
                results['lidar2img'][i] = ida_mat @ results['lidar2img'][i]
        elif len(results['img']) == 6:
            # 6 current images but extra sweep matrices: transform the images,
            # then apply the same ida_mat to every projection matrix
            for i in range(len(results['img'])):
                img = Image.fromarray(np.uint8(results['img'][i]))
                # resize, resize_dims, crop, flip, rotate = self._sample_augmentation()
                img, ida_mat = self.img_transform(
                    img,
                    resize=resize,
                    resize_dims=resize_dims,
                    crop=crop,
                    flip=flip,
                    rotate=rotate,
                )
                results['img'][i] = np.array(img).astype(np.uint8)
            for i in range(len(results['lidar2img'])):
                results['lidar2img'][i] = ida_mat @ results['lidar2img'][i]
        else:
            raise ValueError()

        results['ori_shape'] = [img.shape for img in results['img']]
        results['img_shape'] = [img.shape for img in results['img']]
        results['pad_shape'] = [img.shape for img in results['img']]

        return results

    def img_transform(self, img, resize, resize_dims, crop, flip, rotate):
        """
        https://github.com/Megvii-BaseDetection/BEVStereo/blob/master/dataset/nusc_mv_det_dataset.py#L48
        """
        def get_rot(h):
            return torch.Tensor([
                [np.cos(h), np.sin(h)],
                [-np.sin(h), np.cos(h)],
            ])

        ida_rot = torch.eye(2)
        ida_tran = torch.zeros(2)

        # adjust image
        img = img.resize(resize_dims)
        img = img.crop(crop)
        if flip:
            img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
        img = img.rotate(rotate)

        # post-homography transformation
        ida_rot *= resize
        ida_tran -= torch.Tensor(crop[:2])
        if flip:
            A = torch.Tensor([[-1, 0], [0, 1]])
            b = torch.Tensor([crop[2] - crop[0], 0])
            ida_rot = A.matmul(ida_rot)
            ida_tran = A.matmul(ida_tran) + b
        # rotation is about the crop centre, hence the -b/+b shift
        A = get_rot(rotate / 180 * np.pi)
        b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2
        b = A.matmul(-b) + b
        ida_rot = A.matmul(ida_rot)
        ida_tran = A.matmul(ida_tran) + b
        ida_mat = torch.eye(4)
        ida_mat[:2, :2] = ida_rot
        ida_mat[:2, 2] = ida_tran

        return img, ida_mat.numpy()

    def sample_augmentation(self):
        """
        https://github.com/Megvii-BaseDetection/BEVStereo/blob/master/dataset/nusc_mv_det_dataset.py#L247
        """
        H, W = self.ida_aug_conf['H'], self.ida_aug_conf['W']
        fH, fW = self.ida_aug_conf['final_dim']

        if self.training:
            resize = np.random.uniform(*self.ida_aug_conf['resize_lim'])
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            crop_h = int((1 - np.random.uniform(*self.ida_aug_conf['bot_pct_lim'])) * newH) - fH
            crop_w = int(np.random.uniform(0, max(0, newW - fW)))
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = False
            if self.ida_aug_conf['rand_flip'] and np.random.choice([0, 1]):
                flip = True
            rotate = np.random.uniform(*self.ida_aug_conf['rot_lim'])
        else:
            # deterministic: scale to fit and centre-crop, no flip/rotation
            resize = max(fH / H, fW / W)
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            crop_h = int((1 - np.mean(self.ida_aug_conf['bot_pct_lim'])) * newH) - fH
            crop_w = int(max(0, newW - fW) / 2)
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = False
            rotate = 0

        return resize, resize_dims, crop, flip, rotate


@PIPELINES.register_module()
class GlobalRotScaleTransImage(object):
    """Global rotation/scaling applied through the lidar2img matrices."""

    def __init__(self,
                 rot_range=[-0.3925, 0.3925],
                 scale_ratio_range=[0.95, 1.05],
                 translation_std=[0, 0, 0]):
        self.rot_range = rot_range
        self.scale_ratio_range = scale_ratio_range
        self.translation_std = translation_std

    def __call__(self, results):
        # random rotate
        rot_angle = np.random.uniform(*self.rot_range)
        self.rotate_z(results,
rot_angle)
        results["gt_bboxes_3d"].rotate(np.array(rot_angle))

        # random scale
        scale_ratio = np.random.uniform(*self.scale_ratio_range)
        self.scale_xyz(results, scale_ratio)
        results["gt_bboxes_3d"].scale(scale_ratio)

        # TODO: support translation

        return results

    def rotate_z(self, results, rot_angle):
        # right-multiply each lidar2img by the inverse z-rotation
        rot_cos = torch.cos(torch.tensor(rot_angle))
        rot_sin = torch.sin(torch.tensor(rot_angle))
        rot_mat = torch.tensor([
            [rot_cos, -rot_sin, 0, 0],
            [rot_sin, rot_cos, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ])
        rot_mat_inv = torch.inverse(rot_mat)
        for view in range(len(results['lidar2img'])):
            results['lidar2img'][view] = (torch.tensor(results['lidar2img'][view]).float() @ rot_mat_inv).numpy()

    def scale_xyz(self, results, scale_ratio):
        # right-multiply each lidar2img by the inverse isotropic scale
        scale_mat = torch.tensor([
            [scale_ratio, 0, 0, 0],
            [0, scale_ratio, 0, 0],
            [0, 0, scale_ratio, 0],
            [0, 0, 0, 1],
        ])
        scale_mat_inv = torch.inverse(scale_mat)
        for view in range(len(results['lidar2img'])):
            results['lidar2img'][view] = (torch.tensor(results['lidar2img'][view]).float() @ scale_mat_inv).numpy()


================================================
FILE: loaders/ray_metrics.py
================================================
# Acknowledgments: https://github.com/tarashakhurana/4d-occ-forecasting
# Modified by Haisong Liu

import math
import copy
import numpy as np
import torch
from torch.utils.cpp_extension import load
from tqdm import tqdm
from prettytable import PrettyTable
from .ray_pq import Metric_RayPQ

# JIT-compile the CUDA differentiable volume renderer at import time
dvr = load("dvr", sources=["lib/dvr/dvr.cpp", "lib/dvr/dvr.cu"], verbose=True,
           extra_cuda_cflags=['-allow-unsupported-compiler'])

_pc_range = [-40, -40, -1.0, 40, 40, 5.4]
_voxel_size = 0.4


# https://github.com/tarashakhurana/4d-occ-forecasting/blob/ff986082cd6ea10e67ab7839bf0e654736b3f4e2/test_fgbg.py#L29C1-L46C16
def get_rendered_pcds(origin, points, tindex, pred_dist):
    """Re-project rendered ray depths into point clouds, one per timestep."""
    pcds = []

    for t in range(len(origin)):
        mask = (tindex == t)
        # skip the ones with no data
        if not mask.any():
            continue
        _pts = points[mask, :3]
        # use ground truth lidar points for the raycasting direction
        v = _pts - origin[t][None, :]
        d = v / np.sqrt((v ** 2).sum(axis=1, keepdims=True))
        pred_pts = origin[t][None, :] + d * pred_dist[mask][:, None]
        pcds.append(torch.from_numpy(pred_pts))

    return pcds


def meshgrid3d(occ_size, pc_range):
    """Voxel-centre coordinates of a (W, H, D) grid spanning pc_range."""
    W, H, D = occ_size
    xs = torch.linspace(0.5, W - 0.5, W).view(W, 1, 1).expand(W, H, D) / W
    ys = torch.linspace(0.5, H - 0.5, H).view(1, H, 1).expand(W, H, D) / H
    zs = torch.linspace(0.5, D - 0.5, D).view(1, 1, D).expand(W, H, D) / D
    xs = xs * (pc_range[3] - pc_range[0]) + pc_range[0]
    ys = ys * (pc_range[4] - pc_range[1]) + pc_range[1]
    zs = zs * (pc_range[5] - pc_range[2]) + pc_range[2]
    xyz = torch.stack((xs, ys, zs), -1)
    return xyz


def generate_lidar_rays():
    """Build a fixed bundle of unit ray directions imitating the nuScenes lidar
    (denser pitch rings near the horizon, 1-degree azimuth steps)."""
    # prepare lidar ray angles
    pitch_angles = []
    for k in range(10):
        angle = math.pi / 2 - math.atan(k + 1)
        pitch_angles.append(-angle)

    # nuscenes lidar fov: [0.2107773983152201, -0.5439104895672159] (rad)
    while pitch_angles[-1] < 0.21:
        # extrapolate upward with the last pitch spacing
        delta = pitch_angles[-1] - pitch_angles[-2]
        pitch_angles.append(pitch_angles[-1] + delta)

    lidar_rays = []
    for pitch_angle in pitch_angles:
        for azimuth_angle in np.arange(0, 360, 1):
            azimuth_angle = np.deg2rad(azimuth_angle)
            x = np.cos(pitch_angle) * np.cos(azimuth_angle)
            y = np.cos(pitch_angle) * np.sin(azimuth_angle)
            z = np.sin(pitch_angle)
            lidar_rays.append((x, y, z))

    return np.array(lidar_rays, dtype=np.float32)


def process_one_sample(sem_pred, lidar_rays, output_origin, instance_pred=None, occ_class_names=None):
    """Render per-ray (label[, instance], depth) rows for one [200,200,16] volume."""
    # lidar origin in ego coordinate
    # lidar_origin = torch.tensor([[[0.9858, 0.0000, 1.8402]]])
    T = output_origin.shape[1]
    pred_pcds_t = []

    # binarize: occupied = 1, free (last class id) = 0
    free_id = len(occ_class_names) - 1
    occ_pred = copy.deepcopy(sem_pred)
    occ_pred[sem_pred < free_id] = 1
    occ_pred[sem_pred == free_id] = 0
    occ_pred = occ_pred.permute(2, 1, 0)
    occ_pred = occ_pred[None, None, :].contiguous().float()

    # metric -> voxel-grid coordinate conversion
    offset = torch.Tensor(_pc_range[:3])[None, None, :]
    scaler = torch.Tensor([_voxel_size] * 3)[None, None, :]

    lidar_tindex = torch.zeros([1, 
lidar_rays.shape[0]]) for t in range(T): lidar_origin = output_origin[:, t:t+1, :] # [1, 1, 3] lidar_endpts = lidar_rays[None] + lidar_origin # [1, 15840, 3] output_origin_render = ((lidar_origin - offset) / scaler).float() # [1, 1, 3] output_points_render = ((lidar_endpts - offset) / scaler).float() # [1, N, 3] output_tindex_render = lidar_tindex # [1, N], all zeros with torch.no_grad(): pred_dist, _, coord_index = dvr.render_forward( occ_pred.cuda(), output_origin_render.cuda(), output_points_render.cuda(), output_tindex_render.cuda(), [1, 16, 200, 200], "test" ) pred_dist *= _voxel_size pred_pcds = get_rendered_pcds( lidar_origin[0].cpu().numpy(), lidar_endpts[0].cpu().numpy(), lidar_tindex[0].cpu().numpy(), pred_dist[0].cpu().numpy() ) coord_index = coord_index[0, :, :].int().cpu() # [N, 3] pred_label = sem_pred[coord_index[:, 0], coord_index[:, 1], coord_index[:, 2]][:, None] # [N, 1] pred_dist = pred_dist[0, :, None].cpu() if instance_pred is not None: pred_instance = instance_pred[coord_index[:, 0], coord_index[:, 1], coord_index[:, 2]][:, None] # [N, 1] pred_pcds = torch.cat([pred_label.float(), pred_instance.float(), pred_dist], dim=-1) else: pred_pcds = torch.cat([pred_label.float(), pred_dist], dim=-1) pred_pcds_t.append(pred_pcds) pred_pcds_t = torch.cat(pred_pcds_t, dim=0) return pred_pcds_t.numpy() def calc_rayiou(pcd_pred_list, pcd_gt_list, occ_class_names): thresholds = [1, 2, 4] gt_cnt = np.zeros([len(occ_class_names)]) pred_cnt = np.zeros([len(occ_class_names)]) tp_cnt = np.zeros([len(thresholds), len(occ_class_names)]) for pcd_pred, pcd_gt in zip(pcd_pred_list, pcd_gt_list): for j, threshold in enumerate(thresholds): # L1 depth_pred = pcd_pred[:, 1] depth_gt = pcd_gt[:, 1] l1_error = np.abs(depth_pred - depth_gt) tp_dist_mask = (l1_error < threshold) for i, cls in enumerate(occ_class_names): cls_id = occ_class_names.index(cls) cls_mask_pred = (pcd_pred[:, 0] == cls_id) cls_mask_gt = (pcd_gt[:, 0] == cls_id) gt_cnt_i = cls_mask_gt.sum() pred_cnt_i 
= cls_mask_pred.sum() if j == 0: gt_cnt[i] += gt_cnt_i pred_cnt[i] += pred_cnt_i tp_cls = cls_mask_gt & cls_mask_pred # [N] tp_mask = np.logical_and(tp_cls, tp_dist_mask) tp_cnt[j][i] += tp_mask.sum() iou_list = [] for j, threshold in enumerate(thresholds): iou_list.append((tp_cnt[j] / (gt_cnt + pred_cnt - tp_cnt[j]))[:-1]) return iou_list def main_rayiou(sem_pred_list, sem_gt_list, lidar_origin_list, occ_class_names): torch.cuda.empty_cache() # generate lidar rays lidar_rays = generate_lidar_rays() lidar_rays = torch.from_numpy(lidar_rays) pcd_pred_list, pcd_gt_list = [], [] for sem_pred, sem_gt, lidar_origins in tqdm(zip(sem_pred_list, sem_gt_list, lidar_origin_list), ncols=50): sem_pred = torch.from_numpy(np.reshape(sem_pred, [200, 200, 16])) sem_gt = torch.from_numpy(np.reshape(sem_gt, [200, 200, 16])) pcd_pred = process_one_sample(sem_pred, lidar_rays, lidar_origins, occ_class_names=occ_class_names) pcd_gt = process_one_sample(sem_gt, lidar_rays, lidar_origins, occ_class_names=occ_class_names) # evalute on non-free rays valid_mask = (pcd_gt[:, 0].astype(np.int32) != len(occ_class_names) - 1) pcd_pred = pcd_pred[valid_mask] pcd_gt = pcd_gt[valid_mask] assert pcd_pred.shape == pcd_gt.shape pcd_pred_list.append(pcd_pred) pcd_gt_list.append(pcd_gt) iou_list = calc_rayiou(pcd_pred_list, pcd_gt_list, occ_class_names) rayiou = np.nanmean(iou_list) rayiou_0 = np.nanmean(iou_list[0]) rayiou_1 = np.nanmean(iou_list[1]) rayiou_2 = np.nanmean(iou_list[2]) table = PrettyTable([ 'Class Names', 'RayIoU@1', 'RayIoU@2', 'RayIoU@4' ]) table.float_format = '.3' for i in range(len(occ_class_names) - 1): table.add_row([ occ_class_names[i], iou_list[0][i], iou_list[1][i], iou_list[2][i] ], divider=(i == len(occ_class_names) - 2)) table.add_row(['MEAN', rayiou_0, rayiou_1, rayiou_2]) print(table) torch.cuda.empty_cache() return { 'RayIoU': rayiou, 'RayIoU@1': rayiou_0, 'RayIoU@2': rayiou_1, 'RayIoU@4': rayiou_2, } def main_raypq(sem_pred_list, sem_gt_list, inst_pred_list, 
inst_gt_list, lidar_origin_list, occ_class_names):
    """Compute RayPQ: render rays for semantics + instances, feed Metric_RayPQ."""
    torch.cuda.empty_cache()

    eval_metrics_pq = Metric_RayPQ(
        occ_class_names=occ_class_names,
        num_classes=len(occ_class_names),
        thresholds=[1, 2, 4]
    )

    # generate lidar rays
    lidar_rays = generate_lidar_rays()
    lidar_rays = torch.from_numpy(lidar_rays)

    for sem_pred, sem_gt, inst_pred, inst_gt, lidar_origins in \
            tqdm(zip(sem_pred_list, sem_gt_list, inst_pred_list, inst_gt_list, lidar_origin_list), ncols=50):
        sem_pred = torch.from_numpy(np.reshape(sem_pred, [200, 200, 16]))
        sem_gt = torch.from_numpy(np.reshape(sem_gt, [200, 200, 16]))
        inst_pred = torch.from_numpy(np.reshape(inst_pred, [200, 200, 16]))
        inst_gt = torch.from_numpy(np.reshape(inst_gt, [200, 200, 16]))

        pcd_pred = process_one_sample(sem_pred, lidar_rays, lidar_origins, instance_pred=inst_pred, occ_class_names=occ_class_names)
        pcd_gt = process_one_sample(sem_gt, lidar_rays, lidar_origins, instance_pred=inst_gt, occ_class_names=occ_class_names)

        # evaluate on non-free rays
        valid_mask = (pcd_gt[:, 0].astype(np.int32) != len(occ_class_names) - 1)
        pcd_pred = pcd_pred[valid_mask]
        pcd_gt = pcd_gt[valid_mask]
        assert pcd_pred.shape == pcd_gt.shape

        # pcd rows are (semantic label, instance id, depth)
        sem_gt = pcd_gt[:, 0].astype(np.int32)
        sem_pred = pcd_pred[:, 0].astype(np.int32)
        instances_gt = pcd_gt[:, 1].astype(np.int32)
        instances_pred = pcd_pred[:, 1].astype(np.int32)

        # L1
        depth_gt = pcd_gt[:, 2]
        depth_pred = pcd_pred[:, 2]
        l1_error = np.abs(depth_pred - depth_gt)

        eval_metrics_pq.add_batch(sem_pred, sem_gt, instances_pred, instances_gt, l1_error)

    torch.cuda.empty_cache()

    return eval_metrics_pq.count_pq()


================================================
FILE: loaders/ray_pq.py
================================================
import numpy as np
from prettytable import PrettyTable


class Metric_RayPQ:
    """Ray-based panoptic quality, accumulated per class at several depth thresholds."""

    def __init__(self,
                 occ_class_names,
                 num_classes=18,
                 thresholds=[1, 2, 4]):
        """
        Args:
            ignore_index (list): Class ids that are not considered in pq counting.
        """
        if num_classes == 18 or num_classes == 17:
            self.class_names = occ_class_names
        else:
            raise ValueError
        self.num_classes = num_classes
        self.id_offset = 2 ** 16  # packs (gt id, pred id) pairs into one integer key
        self.eps = 1e-5
        self.thresholds = thresholds
        self.min_num_points = 10  # instances below this size are ignored for FP/FN
        self.include = np.array(
            [n for n in range(self.num_classes - 1)],  # exclude the last (free) class
            dtype=int)
        self.cnt = 0

        # panoptic stuff: one row of per-class stats per depth threshold
        self.pan_tp = np.zeros([len(self.thresholds), num_classes], dtype=int)
        self.pan_iou = np.zeros([len(self.thresholds), num_classes], dtype=np.double)
        self.pan_fp = np.zeros([len(self.thresholds), num_classes], dtype=int)
        self.pan_fn = np.zeros([len(self.thresholds), num_classes], dtype=int)

    def add_batch(self,semantics_pred,semantics_gt,instances_pred,instances_gt, l1_error):
        # one call = one sample
        self.cnt += 1
        self.add_panoptic_sample(semantics_pred, semantics_gt, instances_pred, instances_gt, l1_error)

    def add_panoptic_sample(self, semantics_pred, semantics_gt, instances_pred, instances_gt, l1_error):
        """Add one sample of panoptic predictions and ground truths for
        evaluation.

        Args:
            semantics_pred (np.ndarray): Semantic predictions.
            semantics_gt (np.ndarray): Semantic ground truths.
            instances_pred (np.ndarray): Instance predictions.
            instances_gt (np.ndarray): Instance ground truths.
        """
        # get instance_class_id from instance_gt
        instance_class_ids = [self.num_classes - 1]
        for i in range(1, instances_gt.max() + 1):
            class_id = np.unique(semantics_gt[instances_gt == i])
            # assert class_id.shape[0] == 1, "each instance must belong to only one class"
            if class_id.shape[0] == 1:
                instance_class_ids.append(class_id[0])
            else:
                # a gt instance whose rays carry several semantic labels is
                # mapped to the free class, i.e. excluded from matching
                instance_class_ids.append(self.num_classes - 1)
        instance_class_ids = np.array(instance_class_ids)

        # Relabel gt into consecutive instance ids: "thing" classes keep one id
        # per instance, "stuff" classes get a single id for the whole class.
        instance_count = 1
        final_instance_class_ids = []
        final_instances = np.zeros_like(instances_gt)  # empty space has instance id "0"

        for class_id in range(self.num_classes - 1):
            if np.sum(semantics_gt == class_id) == 0:
                continue

            if self.class_names[class_id] in ['car', 'truck', 'construction_vehicle', 'bus', 'trailer',
                                              'motorcycle', 'bicycle', 'pedestrian']:
                # treat as instances
                for instance_id in range(len(instance_class_ids)):
                    if instance_class_ids[instance_id] != class_id:
                        continue
                    final_instances[instances_gt == instance_id] = instance_count
                    instance_count += 1
                    final_instance_class_ids.append(class_id)
            else:
                # treat as semantics
                final_instances[semantics_gt == class_id] = instance_count
                instance_count += 1
                final_instance_class_ids.append(class_id)

        instances_gt = final_instances

        # avoid zero (ignored label)
        instances_pred = instances_pred + 1
        instances_gt = instances_gt + 1

        for j, threshold in enumerate(self.thresholds):
            # only rays with depth error under the threshold may intersect
            tp_dist_mask = l1_error < threshold

            # for each class (except the ignored ones)
            for cl in self.include:
                # get a class mask
                pred_inst_in_cl_mask = semantics_pred == cl
                gt_inst_in_cl_mask = semantics_gt == cl

                # get instance points in class (makes outside stuff 0)
                pred_inst_in_cl = instances_pred * pred_inst_in_cl_mask.astype(int)
                gt_inst_in_cl = instances_gt * gt_inst_in_cl_mask.astype(int)

                # generate the areas for each unique instance prediction
                unique_pred, counts_pred = np.unique(
                    pred_inst_in_cl[pred_inst_in_cl > 0], return_counts=True)
                id2idx_pred = {id: idx for idx, id in enumerate(unique_pred)}
                matched_pred = np.array([False] * unique_pred.shape[0])

                # generate the areas for each unique instance gt_np
                unique_gt, counts_gt = np.unique(
                    gt_inst_in_cl[gt_inst_in_cl > 0], return_counts=True)
                id2idx_gt = {id: idx for idx, id in enumerate(unique_gt)}
                matched_gt = np.array([False] * unique_gt.shape[0])

                # generate intersection using offset
                valid_combos = np.logical_and(pred_inst_in_cl > 0, gt_inst_in_cl > 0)
                # add dist_mask
                valid_combos = np.logical_and(valid_combos, tp_dist_mask)
                id_offset_combo = pred_inst_in_cl[
                    valid_combos] + self.id_offset * gt_inst_in_cl[valid_combos]
                unique_combo, counts_combo = np.unique(
                    id_offset_combo, return_counts=True)

                # generate an intersection map
                # count the intersections with over 0.5 IoU as TP
                gt_labels = unique_combo // self.id_offset
                pred_labels = unique_combo % self.id_offset
                gt_areas = np.array([counts_gt[id2idx_gt[id]] for id in gt_labels])
                pred_areas = np.array(
                    [counts_pred[id2idx_pred[id]] for id in pred_labels])
                intersections = counts_combo
                unions = gt_areas + pred_areas - intersections
                ious = intersections.astype(float) / unions.astype(float)

                tp_indexes = ious > 0.5
                self.pan_tp[j][cl] += np.sum(tp_indexes)
                self.pan_iou[j][cl] += np.sum(ious[tp_indexes])

                matched_gt[[id2idx_gt[id] for id in gt_labels[tp_indexes]]] = True
                matched_pred[[id2idx_pred[id] for id in pred_labels[tp_indexes]]] = True

                # count the FN
                if len(counts_gt) > 0:
                    self.pan_fn[j][cl] += np.sum(
                        np.logical_and(counts_gt >= self.min_num_points, ~matched_gt))

                # count the FP
                if len(matched_pred) > 0:
                    self.pan_fp[j][cl] += np.sum(
                        np.logical_and(counts_pred >= self.min_num_points, ~matched_pred))

    def count_pq(self):
        # SQ = matched IoU sum / TP; RQ = TP / (TP + FP/2 + FN/2); PQ = SQ * RQ
        sq_all = self.pan_iou.astype(np.double) / np.maximum(
            self.pan_tp.astype(np.double), self.eps)
        rq_all = self.pan_tp.astype(np.double) / np.maximum(
            self.pan_tp.astype(np.double) + 0.5 * self.pan_fp.astype(np.double)
            + 0.5 * self.pan_fn.astype(np.double), self.eps)
        pq_all = sq_all * rq_all

        # mask classes not occurring in dataset
        mask = (self.pan_tp + self.pan_fp + 
self.pan_fn) > 0 pq_all[~mask] = float('nan') table = PrettyTable([ 'Class Names', 'RayPQ@%d' % self.thresholds[0], 'RayPQ@%d' % self.thresholds[1], 'RayPQ@%d' % self.thresholds[2] ]) table.float_format = '.3' for i in range(len(self.class_names) - 1): table.add_row([ self.class_names[i], pq_all[0][i], pq_all[1][i], pq_all[2][i], ], divider=(i == len(self.class_names) - 2)) table.add_row([ 'MEAN', np.nanmean(pq_all[0]), np.nanmean(pq_all[1]), np.nanmean(pq_all[2]) ]) print(table) return { 'RayPQ': np.nanmean(pq_all), 'RayPQ@1': np.nanmean(pq_all[0]), 'RayPQ@2': np.nanmean(pq_all[1]), 'RayPQ@4': np.nanmean(pq_all[2]), } ================================================ FILE: models/__init__.py ================================================ from .backbones import __all__ from .bbox import __all__ from .sparseocc import SparseOcc from .sparseocc_head import SparseOccHead from .sparseocc_transformer import SparseOccTransformer from .loss_utils import * __all__ = [] ================================================ FILE: models/backbones/__init__.py ================================================ from .vovnet import VoVNet __all__ = ['VoVNet'] ================================================ FILE: models/backbones/vovnet.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import warnings import torch.utils.checkpoint as cp from collections import OrderedDict from mmcv.runner import BaseModule from mmdet.models.builder import BACKBONES from torch.nn.modules.batchnorm import _BatchNorm VoVNet19_slim_dw_eSE = { 'stem': [64, 64, 64], 'stage_conv_ch': [64, 80, 96, 112], 'stage_out_ch': [112, 256, 384, 512], "layer_per_block": 3, "block_per_stage": [1, 1, 1, 1], "eSE": True, "dw": True } VoVNet19_dw_eSE = { 'stem': [64, 64, 64], "stage_conv_ch": [128, 160, 192, 224], "stage_out_ch": [256, 512, 768, 1024], "layer_per_block": 3, "block_per_stage": [1, 1, 1, 1], "eSE": True, "dw": True } VoVNet19_slim_eSE = { 
'stem': [64, 64, 128],
    'stage_conv_ch': [64, 80, 96, 112],
    'stage_out_ch': [112, 256, 384, 512],
    'layer_per_block': 3,
    'block_per_stage': [1, 1, 1, 1],
    'eSE': True,
    "dw": False
}

VoVNet19_eSE = {
    'stem': [64, 64, 128],
    "stage_conv_ch": [128, 160, 192, 224],
    "stage_out_ch": [256, 512, 768, 1024],
    "layer_per_block": 3,
    "block_per_stage": [1, 1, 1, 1],
    "eSE": True,
    "dw": False
}

VoVNet39_eSE = {
    'stem': [64, 64, 128],
    "stage_conv_ch": [128, 160, 192, 224],
    "stage_out_ch": [256, 512, 768, 1024],
    "layer_per_block": 5,
    "block_per_stage": [1, 1, 2, 2],
    "eSE": True,
    "dw": False
}

VoVNet57_eSE = {
    'stem': [64, 64, 128],
    "stage_conv_ch": [128, 160, 192, 224],
    "stage_out_ch": [256, 512, 768, 1024],
    "layer_per_block": 5,
    "block_per_stage": [1, 1, 4, 3],
    "eSE": True,
    "dw": False
}

VoVNet99_eSE = {
    'stem': [64, 64, 128],
    "stage_conv_ch": [128, 160, 192, 224],
    "stage_out_ch": [256, 512, 768, 1024],
    "layer_per_block": 5,
    "block_per_stage": [1, 3, 9, 3],
    "eSE": True,
    "dw": False
}

# Registry mapping the config string (e.g. "V-99-eSE") to its stage spec.
_STAGE_SPECS = {
    "V-19-slim-dw-eSE": VoVNet19_slim_dw_eSE,
    "V-19-dw-eSE": VoVNet19_dw_eSE,
    "V-19-slim-eSE": VoVNet19_slim_eSE,
    "V-19-eSE": VoVNet19_eSE,
    "V-39-eSE": VoVNet39_eSE,
    "V-57-eSE": VoVNet57_eSE,
    "V-99-eSE": VoVNet99_eSE,
}


def dw_conv3x3(in_channels, out_channels, module_name, postfix, stride=1, kernel_size=3, padding=1):
    """3x3 depthwise convolution + 1x1 pointwise conv + BN + ReLU,
    returned as a list of (name, module) pairs for an OrderedDict."""
    return [
        (
            '{}_{}/dw_conv3x3'.format(module_name, postfix),
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=out_channels,  # depthwise: one filter group per channel
                bias=False
            )
        ),
        (
            '{}_{}/pw_conv1x1'.format(module_name, postfix),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, groups=1, bias=False)
        ),
        ('{}_{}/pw_norm'.format(module_name, postfix), nn.BatchNorm2d(out_channels)),
        ('{}_{}/pw_relu'.format(module_name, postfix), nn.ReLU(inplace=True)),
    ]


def conv3x3(in_channels, out_channels, module_name, postfix, stride=1, groups=1, kernel_size=3, padding=1):
    """3x3 convolution with padding, followed by BN + ReLU,
    returned as a list of (name, module) pairs for an OrderedDict."""
    return [
        (
            f"{module_name}_{postfix}/conv",
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                bias=False,
            ),
        ),
        (f"{module_name}_{postfix}/norm", nn.BatchNorm2d(out_channels)),
        (f"{module_name}_{postfix}/relu", nn.ReLU(inplace=True)),
    ]


def conv1x1(in_channels, out_channels, module_name, postfix, stride=1, groups=1, kernel_size=1, padding=0):
    """1x1 convolution with padding, followed by BN + ReLU,
    returned as a list of (name, module) pairs for an OrderedDict."""
    return [
        (
            f"{module_name}_{postfix}/conv",
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                bias=False,
            ),
        ),
        (f"{module_name}_{postfix}/norm", nn.BatchNorm2d(out_channels)),
        (f"{module_name}_{postfix}/relu", nn.ReLU(inplace=True)),
    ]


class Hsigmoid(nn.Module):
    """Hard sigmoid: relu6(x + 3) / 6."""

    def __init__(self, inplace=True):
        super(Hsigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return F.relu6(x + 3.0, inplace=self.inplace) / 6.0


class eSEModule(nn.Module):
    """Effective Squeeze-and-Excitation: global pool -> 1x1 conv ->
    hard sigmoid, used as a channel-wise gate on the input."""

    def __init__(self, channel, reduction=4):
        # NOTE(review): `reduction` is accepted but unused -- the 1x1 conv
        # keeps the full channel count (that is what makes eSE "effective").
        super(eSEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channel, channel, kernel_size=1, padding=0)
        self.hsigmoid = Hsigmoid()

    def forward(self, x):
        inputs = x
        x = self.avg_pool(x)
        x = self.fc(x)
        x = self.hsigmoid(x)
        return inputs * x


class _OSA_module(nn.Module):
    """One-Shot Aggregation block: a chain of 3x3 convs whose outputs (plus
    the input) are concatenated once and fused by a 1x1 conv + eSE gate."""

    def __init__(self, in_ch, stage_ch, concat_ch, layer_per_block, module_name, SE=False, identity=False, depthwise=False, with_cp=False):
        super(_OSA_module, self).__init__()

        self.with_cp = with_cp  # use torch checkpointing to save memory
        self.identity = identity  # add a residual connection
        self.depthwise = depthwise
        self.isReduced = False
        self.layers = nn.ModuleList()
        in_channel = in_ch
        if self.depthwise and in_channel != stage_ch:
            # depthwise convs need matching in/out channels; reduce first
            self.isReduced = True
            self.conv_reduction = nn.Sequential(
                OrderedDict(conv1x1(in_channel, stage_ch, "{}_reduction".format(module_name), "0"))
            )
        for i in range(layer_per_block):
            if self.depthwise:
                self.layers.append(nn.Sequential(OrderedDict(dw_conv3x3(stage_ch, stage_ch, module_name, i))))
            else:
                self.layers.append(nn.Sequential(OrderedDict(conv3x3(in_channel, stage_ch, module_name, i))))
            in_channel = stage_ch

        # feature aggregation
        in_channel = in_ch + layer_per_block * stage_ch
        self.concat = nn.Sequential(OrderedDict(conv1x1(in_channel, concat_ch, module_name, "concat")))

        self.ese = eSEModule(concat_ch)

    def _forward(self, x):
        identity_feat = x

        output = []
        output.append(x)
        if self.depthwise and self.isReduced:
            x = self.conv_reduction(x)
        for layer in self.layers:
            x = layer(x)
            output.append(x)

        # one-shot aggregation: concat input + every intermediate feature
        x = torch.cat(output, dim=1)
        xt = self.concat(x)

        xt = self.ese(xt)

        if self.identity:
            xt = xt + identity_feat

        return xt

    def forward(self, x):
        if self.with_cp and self.training and x.requires_grad:
            # recompute activations in backward to trade compute for memory
            return cp.checkpoint(self._forward, x)
        else:
            return self._forward(x)


class _OSA_stage(nn.Sequential):
    """A backbone stage: optional 2x max-pool downsample followed by one or
    more OSA blocks (all but the first use a residual connection)."""

    def __init__(self, in_ch, stage_ch, concat_ch, block_per_stage, layer_per_block, stage_num, SE=False, depthwise=False, with_cp=False):
        super(_OSA_stage, self).__init__()

        # stage 2 keeps the stem's resolution; later stages downsample
        if not stage_num == 2:
            self.add_module("Pooling", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True))

        # eSE only on the last block of a stage
        if block_per_stage != 1:
            SE = False
        module_name = f"OSA{stage_num}_1"
        self.add_module(
            module_name, _OSA_module(in_ch, stage_ch, concat_ch, layer_per_block, module_name, SE, depthwise=depthwise, with_cp=with_cp)
        )
        for i in range(block_per_stage - 1):
            if i != block_per_stage - 2:  # last block
                SE = False
            module_name = f"OSA{stage_num}_{i + 2}"
            self.add_module(
                module_name,
                _OSA_module(
                    concat_ch, stage_ch, concat_ch, layer_per_block, module_name, SE, identity=True, depthwise=depthwise, with_cp=with_cp
                ),
            )


@BACKBONES.register_module()
class VoVNet(BaseModule):
    def __init__(self, spec_name, input_ch=3, out_features=None, frozen_stages=-1, norm_eval=True, with_cp=False, pretrained=None, init_cfg=None):
        """
        Args:
            spec_name (str): key into _STAGE_SPECS selecting the variant.
            input_ch(int) : the number of input channel
            out_features (list[str]): name of the layers whose outputs should
                be returned in forward. Can be anything in "stem", "stage2" ...
                stage 5
            frozen_stages (int): stages up to this index are frozen in train().
            norm_eval (bool): keep BatchNorm layers in eval mode during train().
            with_cp (bool): enable torch checkpointing inside OSA blocks.
        """
        super(VoVNet, self).__init__(init_cfg)
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval

        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

        stage_specs = _STAGE_SPECS[spec_name]

        stem_ch = stage_specs["stem"]
        config_stage_ch = stage_specs["stage_conv_ch"]
        config_concat_ch = stage_specs["stage_out_ch"]
        block_per_stage = stage_specs["block_per_stage"]
        layer_per_block = stage_specs["layer_per_block"]
        SE = stage_specs["eSE"]
        depthwise = stage_specs["dw"]

        self._out_features = out_features

        # Stem module: three convs, overall stride 4
        conv_type = dw_conv3x3 if depthwise else conv3x3
        stem = conv3x3(input_ch, stem_ch[0], "stem", "1", 2)
        stem += conv_type(stem_ch[0], stem_ch[1], "stem", "2", 1)
        stem += conv_type(stem_ch[1], stem_ch[2], "stem", "3", 2)
        self.add_module("stem", nn.Sequential((OrderedDict(stem))))
        # NOTE(review): "stirde" is a typo for "stride" kept for checkpoint
        # compatibility; the value tracks the cumulative downsampling factor.
        current_stirde = 4
        self._out_feature_strides = {"stem": current_stirde, "stage2": current_stirde}
        self._out_feature_channels = {"stem": stem_ch[2]}

        stem_out_ch = [stem_ch[2]]
        in_ch_list = stem_out_ch + config_concat_ch[:-1]

        # OSA stages
        self.stage_names = []
        for i in range(4):  # num_stages
            name = "stage%d" % (i + 2)  # stage 2 ... stage 5
            self.stage_names.append(name)
            self.add_module(
                name,
                _OSA_stage(
                    in_ch_list[i],
                    config_stage_ch[i],
                    config_concat_ch[i],
                    block_per_stage[i],
                    layer_per_block,
                    i + 2,
                    SE,
                    depthwise,
                    with_cp=with_cp
                ),
            )

            self._out_feature_channels[name] = config_concat_ch[i]
            if not i == 0:
                # stage2 keeps stride 4; every later stage halves resolution
                self._out_feature_strides[name] = current_stirde = int(current_stirde * 2)

        # initialize weights
        # self._initialize_weights()

    def _initialize_weights(self):
        # Kaiming init for all convs (currently unused; see commented call above)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        """Run the stem and all stages, collecting the requested feature maps.

        Returns:
            dict: name -> feature tensor, for each name in `_out_features`.
        """
        outputs = {}
        x = self.stem(x)
        if "stem" in self._out_features:
            outputs["stem"] = x
        for name in self.stage_names:
            x = getattr(self, name)(x)
            if name in self._out_features:
                outputs[name] = x

        return outputs

    def _freeze_stages(self):
        # Put frozen submodules in eval mode and stop their gradients.
        if self.frozen_stages >= 0:
            m = getattr(self, 'stem')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

        # frozen_stages=k freezes stage2 .. stage{k+1}
        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'stage{i+1}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Convert the model into training mode while keep normalization layer
        freezed."""
        super(VoVNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
================================================ FILE: models/bbox/__init__.py ================================================
from .assigners import __all__
from .coders import __all__
from .match_costs import __all__
================================================ FILE: models/bbox/assigners/__init__.py ================================================
from .hungarian_assigner_3d import HungarianAssigner3D

__all__ = ['HungarianAssigner3D']
================================================ FILE: models/bbox/assigners/hungarian_assigner_3d.py ================================================
import torch
from mmdet.core.bbox.builder import BBOX_ASSIGNERS
from mmdet.core.bbox.assigners import AssignResult
from mmdet.core.bbox.assigners import BaseAssigner
from mmdet.core.bbox.match_costs import build_match_cost
from ..utils import normalize_bbox

# scipy is optional at import time; its absence is reported only when
# assignment is actually attempted.
try:
    from scipy.optimize import linear_sum_assignment
except ImportError:
    linear_sum_assignment = None


@BBOX_ASSIGNERS.register_module()
class HungarianAssigner3D(BaseAssigner):
    """One-to-one assignment of predicted 3D boxes to ground truths by
    Hungarian matching over a classification + L1 regression cost."""

    def __init__(self,
                 cls_cost=dict(type='ClassificationCost', weight=1.),
                 reg_cost=dict(type='BBoxL1Cost', weight=1.0),
                 iou_cost=dict(type='IoUCost', weight=0.0),
                 pc_range=None):
        self.cls_cost = build_match_cost(cls_cost)
        self.reg_cost = build_match_cost(reg_cost)
        # NOTE(review): iou_cost is built but never used in assign().
        self.iou_cost = build_match_cost(iou_cost)
        self.pc_range = pc_range

    def assign(self, bbox_pred, cls_pred, gt_bboxes, gt_labels, gt_bboxes_ignore=None, code_weights=None, with_velo=False):
        """Match each prediction to at most one GT box.

        Args:
            bbox_pred (Tensor): predicted boxes, one row per query.
            cls_pred (Tensor): classification logits, one row per query.
            gt_bboxes (Tensor): ground-truth boxes.
            gt_labels (Tensor): ground-truth class indices.
            gt_bboxes_ignore: must be None (unsupported).
            code_weights (Tensor, optional): per-dimension weights applied to
                both predictions and normalized GTs before the L1 cost.
            with_velo (bool): include the velocity dims in the regression
                cost; otherwise only the first 8 dims are compared.

        Returns:
            AssignResult: 0 = background, i+1 = matched to gt_bboxes[i].
        """
        assert gt_bboxes_ignore is None, \
            'Only case when gt_bboxes_ignore is None is supported.'
        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)

        # 1. assign -1 by default
        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ), -1, dtype=torch.long)
        assigned_labels = bbox_pred.new_full((num_bboxes, ), -1, dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            if num_gts == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            return AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)

        # 2. compute the weighted costs
        # classification and bboxcost.
        cls_cost = self.cls_cost(cls_pred, gt_labels)
        # regression L1 cost
        normalized_gt_bboxes = normalize_bbox(gt_bboxes)

        if code_weights is not None:
            bbox_pred = bbox_pred * code_weights
            normalized_gt_bboxes = normalized_gt_bboxes * code_weights

        if with_velo:
            reg_cost = self.reg_cost(bbox_pred, normalized_gt_bboxes)
        else:
            reg_cost = self.reg_cost(bbox_pred[:, :8], normalized_gt_bboxes[:, :8])

        # weighted sum of above two costs
        cost = cls_cost + reg_cost

        # 3. do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()
        # clamp NaN/inf so scipy's solver always receives finite costs
        cost = torch.nan_to_num(cost, nan=100.0, posinf=100.0, neginf=-100.0)
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" '
                              'to install scipy first.')

        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        matched_row_inds = torch.from_numpy(matched_row_inds).to(
            bbox_pred.device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(
            bbox_pred.device)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
        return AssignResult(
            num_gts, assigned_gt_inds, None, labels=assigned_labels)
================================================ FILE: models/bbox/coders/__init__.py ================================================
from .nms_free_coder import NMSFreeCoder

__all__ = ['NMSFreeCoder']
================================================ FILE: models/bbox/coders/nms_free_coder.py ================================================
import torch
from mmdet.core.bbox import BaseBBoxCoder
from mmdet.core.bbox.builder import BBOX_CODERS
from ..utils import denormalize_bbox


@BBOX_CODERS.register_module()
class NMSFreeCoder(BaseBBoxCoder):
    """Bbox coder for NMS-free detector.

    Args:
        pc_range (list[float]): Range of point cloud.
        post_center_range (list[float]): Limit of the center.
            Default: None.
        max_num (int): Max number to be kept. Default: 100.
        score_threshold (float): Threshold to filter boxes based on score.
            Default: None.
        code_size (int): Code size of bboxes.
            Default: 9
    """

    def __init__(self, pc_range, voxel_size=None, post_center_range=None, max_num=100, score_threshold=None, num_classes=10):
        self.pc_range = pc_range
        self.voxel_size = voxel_size
        self.post_center_range = post_center_range
        self.max_num = max_num
        self.score_threshold = score_threshold
        self.num_classes = num_classes

    def encode(self):
        # Encoding is not needed for this coder; only decode() is used.
        pass

    def decode_single(self, cls_scores, bbox_preds):
        """Decode bboxes.
        Args:
            cls_scores (Tensor): Outputs from the classification head, \
                shape [num_query, cls_out_channels]. Note \
                cls_out_channels should includes background.
            bbox_preds (Tensor): Outputs from the regression \
                head with normalized coordinate format (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy). \
                Shape [num_query, 9].
        Returns:
            list[dict]: Decoded boxes.
        """
        max_num = self.max_num

        cls_scores = cls_scores.sigmoid()
        # top-k over the flattened (query, class) scores; recover both indices
        scores, indexs = cls_scores.view(-1).topk(max_num)
        labels = indexs % self.num_classes
        bbox_index = torch.div(indexs, self.num_classes, rounding_mode='trunc')
        bbox_preds = bbox_preds[bbox_index]

        final_box_preds = denormalize_bbox(bbox_preds)
        final_scores = scores
        final_preds = labels

        # use score threshold
        if self.score_threshold is not None:
            thresh_mask = final_scores > self.score_threshold

        if self.post_center_range is not None:
            # keep boxes whose center lies inside the configured range
            limit = torch.tensor(self.post_center_range, device=scores.device)
            mask = (final_box_preds[..., :3] >= limit[:3]).all(1)
            mask &= (final_box_preds[..., :3] <= limit[3:]).all(1)

            if self.score_threshold:
                mask &= thresh_mask

            boxes3d = final_box_preds[mask]
            scores = final_scores[mask]
            labels = final_preds[mask]
            predictions_dict = {
                'bboxes': boxes3d,
                'scores': scores,
                'labels': labels
            }
        else:
            raise NotImplementedError(
                'Need to reorganize output as a batch, only '
                'support post_center_range is not None for now!'
            )
        return predictions_dict

    def decode(self, preds_dicts):
        """Decode bboxes.
        Args:
            all_cls_scores (Tensor): Outputs from the classification head, \
                shape [nb_dec, bs, num_query, cls_out_channels]. Note \
                cls_out_channels should includes background.
            all_bbox_preds (Tensor): Sigmoid outputs from the regression \
                head with normalized coordinate format (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy). \
                Shape [nb_dec, bs, num_query, 9].
        Returns:
            list[dict]: Decoded boxes.
        """
        # only the last decoder layer's predictions are decoded
        all_cls_scores = preds_dicts['all_cls_scores'][-1]
        all_bbox_preds = preds_dicts['all_bbox_preds'][-1]

        batch_size = all_cls_scores.size()[0]
        predictions_list = []
        for i in range(batch_size):
            predictions_list.append(self.decode_single(all_cls_scores[i], all_bbox_preds[i]))
        return predictions_list
================================================ FILE: models/bbox/match_costs/__init__.py ================================================
from .match_cost import BBox3DL1Cost

__all__ = ['BBox3DL1Cost']
================================================ FILE: models/bbox/match_costs/match_cost.py ================================================
import torch
from mmdet.core.bbox.match_costs.builder import MATCH_COST


@MATCH_COST.register_module()
class BBox3DL1Cost(object):
    """BBox3DL1Cost.

    Args:
        weight (int | float, optional): loss_weight
    """

    def __init__(self, weight=1.0):
        self.weight = weight

    def __call__(self, bbox_pred, gt_bboxes):
        """
        Args:
            bbox_pred (Tensor): Predicted boxes with normalized coordinates
                (cx, cy, w, h), which are all in range [0, 1]. Shape
                [num_query, 4].
            gt_bboxes (Tensor): Ground truth boxes with normalized
                coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
        Returns:
            torch.Tensor: bbox_cost value with weight
        """
        # pairwise L1 distance: rows = predictions, cols = ground truths
        bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)
        return bbox_cost * self.weight


@MATCH_COST.register_module()
class BBoxBEVL1Cost(object):
    """Pairwise L1 cost on box centers after normalizing xy into [0, 1]
    using the point-cloud range."""

    def __init__(self, weight, pc_range):
        self.weight = weight
        self.pc_range = pc_range

    def __call__(self, bboxes, gt_bboxes):
        pc_start = bboxes.new(self.pc_range[0:2])
        pc_range = bboxes.new(self.pc_range[3:5]) - bboxes.new(self.pc_range[0:2])
        # normalize the box center to [0, 1]
        normalized_bboxes_xy = (bboxes[:, :2] - pc_start) / pc_range
        normalized_gt_bboxes_xy = (gt_bboxes[:, :2] - pc_start) / pc_range
        reg_cost = torch.cdist(normalized_bboxes_xy, normalized_gt_bboxes_xy, p=1)
        return reg_cost * self.weight


@MATCH_COST.register_module()
class IoU3DCost(object):
    """Negated (pre-computed) IoU as a matching cost: higher IoU -> lower cost."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, iou):
        iou_cost = - iou
        return iou_cost * self.weight
================================================ FILE: models/bbox/utils.py ================================================
import torch


def normalize_bbox(bboxes):
    """Convert (cx, cy, cz, w, l, h, rot[, vx, vy]) boxes into the network's
    regression space: log-sized dims and (sin, cos) rotation, reordered to
    (cx, cy, log w, log l, cz, log h, sin, cos[, vx, vy])."""
    cx = bboxes[..., 0:1]
    cy = bboxes[..., 1:2]
    cz = bboxes[..., 2:3]
    w = bboxes[..., 3:4].log()
    l = bboxes[..., 4:5].log()
    h = bboxes[..., 5:6].log()
    rot = bboxes[..., 6:7]
    if bboxes.size(-1) > 7:
        vx = bboxes[..., 7:8]
        vy = bboxes[..., 8:9]
        out = torch.cat([cx, cy, w, l, cz, h, rot.sin(), rot.cos(), vx, vy], dim=-1)
    else:
        out = torch.cat([cx, cy, w, l, cz, h, rot.sin(), rot.cos()], dim=-1)
    return out


def denormalize_bbox(normalized_bboxes):
    """Inverse of normalize_bbox: exp the sizes, recover the yaw via atan2,
    and reorder back to (cx, cy, cz, w, l, h, rot[, vx, vy])."""
    rot_sin = normalized_bboxes[..., 6:7]
    rot_cos = normalized_bboxes[..., 7:8]
    rot = torch.atan2(rot_sin, rot_cos)
    cx = normalized_bboxes[..., 0:1]
    cy = normalized_bboxes[..., 1:2]
    cz = normalized_bboxes[..., 4:5]
    w = normalized_bboxes[..., 2:3].exp()
    l = normalized_bboxes[..., 3:4].exp()
    h = normalized_bboxes[..., 5:6].exp()
    if normalized_bboxes.size(-1) > 8:
        vx = normalized_bboxes[..., 8:9]
        vy = normalized_bboxes[..., 9:10]
        out = torch.cat([cx, cy, cz, w, l, h, rot, vx, vy], dim=-1)
    else:
        out = torch.cat([cx, cy, cz, w, l, h, rot], dim=-1)
    return out


def encode_bbox(bboxes, pc_range=None):
    """Encode boxes for regression: xyz optionally normalized to [0, 1] by
    pc_range, log-sized wlh, rotation as (sin, cos); order preserved."""
    xyz = bboxes[..., 0:3].clone()
    wlh = bboxes[..., 3:6].log()
    rot = bboxes[..., 6:7]

    if pc_range is not None:
        xyz[..., 0] = (xyz[..., 0] - pc_range[0]) / (pc_range[3] - pc_range[0])
        xyz[..., 1] = (xyz[..., 1] - pc_range[1]) / (pc_range[4] - pc_range[1])
        xyz[..., 2] = (xyz[..., 2] - pc_range[2]) / (pc_range[5] - pc_range[2])

    if bboxes.shape[-1] > 7:
        vel = bboxes[..., 7:9].clone()
        return torch.cat([xyz, wlh, rot.sin(), rot.cos(), vel], dim=-1)
    else:
        return torch.cat([xyz, wlh, rot.sin(), rot.cos()], dim=-1)


def decode_bbox(bboxes, pc_range=None):
    """Inverse of encode_bbox: exp the sizes, atan2 the rotation, and map
    xyz back from [0, 1] into metric coordinates when pc_range is given."""
    xyz = bboxes[..., 0:3].clone()
    wlh = bboxes[..., 3:6].exp()
    rot = torch.atan2(bboxes[..., 6:7], bboxes[..., 7:8])

    if pc_range is not None:
        xyz[..., 0] = xyz[..., 0] * (pc_range[3] - pc_range[0]) + pc_range[0]
        xyz[..., 1] = xyz[..., 1] * (pc_range[4] - pc_range[1]) + pc_range[1]
        xyz[..., 2] = xyz[..., 2] * (pc_range[5] - pc_range[2]) + pc_range[2]

    if bboxes.shape[-1] > 8:
        vel = bboxes[..., 8:10].clone()
        return torch.cat([xyz, wlh, rot, vel], dim=-1)
    else:
        return torch.cat([xyz, wlh, rot], dim=-1)


def bbox2occrange(bboxes, occ_size, query_cube_size=None):
    """
    Convert normalized boxes into integer voxel index ranges.

    xyz in [0, 1]
    wlh in [0, 1]

    Returns:
        LongTensor (..., 6): (low_x, low_y, low_z, high_x, high_y, high_z).
    """
    xyz = bboxes[..., 0:3].clone()
    if query_cube_size is not None:
        # fixed-size query cube, already expressed in voxels
        wlh = torch.zeros_like(xyz)
        wlh[..., 0] = query_cube_size[0]
        wlh[..., 1] = query_cube_size[1]
        wlh[..., 2] = query_cube_size[2]
    else:
        # NOTE(review): this branch scales bboxes[..., 3:6] in place (no clone).
        wlh = bboxes[..., 3:6]
        wlh[..., 0] = wlh[..., 0] * occ_size[0]
        wlh[..., 1] = wlh[..., 1] * occ_size[1]
        wlh[..., 2] = wlh[..., 2] * occ_size[2]
    xyz[..., 0] = xyz[..., 0] * occ_size[0]
    xyz[..., 1] = xyz[..., 1] * occ_size[1]
    xyz[..., 2] = xyz[..., 2] * occ_size[2]
    xyz = torch.round(xyz)
    low_bound = torch.round(xyz - wlh/2)
    high_bound = torch.round(xyz + wlh/2)
    return torch.cat((low_bound, high_bound), dim=-1).long()


def occrange2bbox(occ_range, occ_size, pc_range):
    """
    Convert integer voxel index ranges back to boxes.

    Return: xyz in [0, 1], wlh in [0, pc_range_size)
    """
    # center of the voxel range, normalized per-axis by the grid size
    xyz = (occ_range[..., :3] + occ_range[..., 3:]).to(torch.float32) / 2
    xyz[..., 0] /= occ_size[0]
    xyz[..., 1] /= occ_size[1]
    xyz[..., 2] /= occ_size[2]
    # extent in voxels, scaled to metric units via voxel size
    wlh = (occ_range[..., 3:] - occ_range[..., :3]).to(torch.float32)
    wlh[..., 0] *= (pc_range[3] - pc_range[0]) / occ_size[0]
    wlh[..., 1] *= (pc_range[4] - pc_range[1]) / occ_size[1]
    wlh[..., 2] *= (pc_range[5] - pc_range[2]) / occ_size[2]
    return torch.cat((xyz, wlh), dim=-1)
================================================ FILE: models/checkpoint.py ================================================
# This page is completely copied from https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint
# If you are using torch 2.0 or higher, you can safely delete this page and import the related functions from official PyTorch
import torch
import warnings
import weakref
from typing import Any, Iterable, List, Tuple

__all__ = [
    "checkpoint", "checkpoint_sequential", "CheckpointFunction",
    "check_backward_validity", "detach_variable", "get_device_states",
    "set_device_states",
]


def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    """Detach every tensor in `inputs`, preserving its requires_grad flag;
    non-tensor entries pass through unchanged."""
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)


def check_backward_validity(inputs: Iterable[Any]) -> None:
    """Warn when no checkpoint input requires grad (backward would be a no-op)."""
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases.
Compromise: Stash the RNG state for # the device of all Tensor args. # # To consider: maybe get_device_states and set_device_states should reside in torch/random.py? def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]: # This will not error out if "arg" is a CPU tensor or a non-tensor type because # the conditionals short-circuit. fwd_gpu_devices = list({arg.get_device() for arg in args if isinstance(arg, torch.Tensor) and arg.is_cuda}) fwd_gpu_states = [] for device in fwd_gpu_devices: with torch.cuda.device(device): fwd_gpu_states.append(torch.cuda.get_rng_state()) return fwd_gpu_devices, fwd_gpu_states def set_device_states(devices, states) -> None: for device, state in zip(devices, states): with torch.cuda.device(device): torch.cuda.set_rng_state(state) def _get_autocast_kwargs(): gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), "dtype": torch.get_autocast_gpu_dtype(), "cache_enabled": torch.is_autocast_cache_enabled()} cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(), "dtype": torch.get_autocast_cpu_dtype(), "cache_enabled": torch.is_autocast_cache_enabled()} return gpu_autocast_kwargs, cpu_autocast_kwargs class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): check_backward_validity(args) ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs() if preserve_rng_state: ctx.fwd_cpu_state = torch.get_rng_state() # Don't eagerly initialize the cuda context by accident. # (If the user intends that the context is initialized later, within their # run_function, we SHOULD actually stash the cuda state here. Unfortunately, # we have no way to anticipate this will happen before we run the function.) 
ctx.had_cuda_in_fwd = False if torch.cuda._initialized: ctx.had_cuda_in_fwd = True ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) # Save non-tensor inputs in ctx, keep a placeholder None for tensors # to be filled out during the backward. ctx.inputs = [] ctx.tensor_indices = [] tensor_inputs = [] for i, arg in enumerate(args): if torch.is_tensor(arg): tensor_inputs.append(arg) ctx.tensor_indices.append(i) ctx.inputs.append(None) else: ctx.inputs.append(arg) ctx.save_for_backward(*tensor_inputs) with torch.no_grad(): outputs = run_function(*args) return outputs @staticmethod def backward(ctx, *args): if not torch.autograd._is_checkpoint_valid(): raise RuntimeError( "Checkpointing is not compatible with .grad() or when an `inputs` parameter" " is passed to .backward(). Please use .backward() and do not pass its `inputs`" " argument.") # Copy the list to avoid modifying original list. inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors # Fill in inputs with appropriate saved tensors. for i, idx in enumerate(tensor_indices): inputs[idx] = tensors[i] # Stash the surrounding rng state, and mimic the state that was # present at this time during forward. Restore the surrounding state # when we're done. 
def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
    r"""Checkpoint a model or part of the model.

    Trades compute for memory: ``function`` is executed under
    :func:`torch.no_grad` in the forward pass, so no intermediate activations
    are stored; during backward the saved inputs are used to re-run
    ``function`` with grad tracking enabled and the gradients are computed
    from the recomputed activations.

    Gradient recording only covers Tensor outputs; Tensors nested inside
    custom structures (lists, dicts, objects) are invisible to autograd.

    .. warning::
        ``function`` must be deterministic between the forward and the
        recomputation (no hidden global state), otherwise the checkpointed
        result silently diverges from the non-checkpointed one.

    .. warning::
        With ``use_reentrant=True`` (the default): outputs detached inside the
        segment raise during backward; at least one input and one output must
        have ``requires_grad=True``; only :func:`torch.autograd.backward`
        without its ``inputs`` argument is supported, and no extra keyword
        arguments may be forwarded to ``function``. ``use_reentrant=False``
        lifts all of these restrictions (and works with
        :func:`torch.autograd.grad`).

    Args:
        function: callable to run in the forward pass; receives ``*args``.
        preserve_rng_state(bool, optional): stash/restore the RNG state around
            each checkpointed segment. Default: ``True``
        use_reentrant(bool, optional): select the re-entrant autograd
            implementation (``True``) or the hook-based one (``False``).
            Default: ``True``
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    # 'preserve_rng_state' is accepted by both implementations, so pull it out
    # before validating that nothing else was passed on the reentrant path.
    preserve = kwargs.pop('preserve_rng_state', True)
    if not use_reentrant:
        # The hook-based implementation supports arbitrary keyword arguments.
        return _checkpoint_without_reentrant(function, preserve, *args, **kwargs)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))
    return CheckpointFunction.apply(function, preserve, *args)
def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs):
    r"""Checkpoint a sequential model, one chunk of layers at a time.

    The module list is split into ``segments`` consecutive chunks. Every chunk
    except the last runs inside :func:`~torch.utils.checkpoint.checkpoint`
    (forward under no-grad, inputs saved for recomputation); the final chunk
    runs normally so its activations feed the loss directly.

    .. warning::
        Checkpointing only supports :func:`torch.autograd.backward` without
        its ``inputs`` argument; :func:`torch.autograd.grad` is unsupported.

    .. warning:
        At least one of the inputs needs :code:`requires_grad=True` if grads
        are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    .. warning:
        Since PyTorch 1.4, only a single Tensor is allowed as the input and
        as each intermediate output, just like :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional): stash/restore the RNG state around
            each checkpoint. Default: ``True``
        use_reentrant(bool, optional): passed through to ``checkpoint`` for
            each segment. Default: ``True``

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Keyword-only 'preserve_rng_state'; anything else is a caller error.
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(kwargs))

    def make_segment_runner(first, last, fns):
        # Returns a closure applying fns[first..last] in order to its argument.
        def forward(input):
            for idx in range(first, last + 1):
                input = fns[idx](input)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    end = -1  # index of the last layer covered by a checkpointed segment
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            make_segment_runner(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve,
        )
    # The last chunk (possibly absorbing the remainder) must stay non-volatile.
    return make_segment_runner(end + 1, len(functions) - 1, functions)(input)
def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args, **kwargs):
    """Checkpointing without re-entrant autograd.

    Uses ``torch.autograd.graph.saved_tensors_hooks`` instead of a custom
    autograd Function: the forward saves only lightweight placeholder objects,
    and the first time autograd asks for a saved tensor the whole segment is
    recomputed (under the stashed RNG/autocast state) to repopulate them.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional):  Omit stashing and restoring
            the RNG state during each checkpoint.
            Default: ``True``
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.
    """
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()
    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    # Custom class to be able to take weak references
    class Holder():
        pass

    # The Holder object for each of the saved object is saved directly on the
    # SavedVariable and is cleared when reset_data() is called on it. We MUST make
    # sure that this is the only object having an owning reference to ensure that
    # the Tensor stored in storage is deleted as soon as the corresponding SavedVariable
    # data is cleared.
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []

    def pack(x):
        # TODO(varal7): Instead of returning abstract object, we can return things metadata (such as
        # size, device, ...) to catch certain cases of undeterministic behavior of the forward
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        # Lazily recompute the whole segment the first time any saved tensor
        # is requested; subsequent requests are served from `storage`.
        unpack_counter = 0
        if len(storage) == 0:
            def inner_pack(inner):
                nonlocal unpack_counter
                unpack_counter += 1
                # If the holder went out of scope, the SavedVariable is dead and so
                # the value will never be read from the storage. Skip filling it.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return
                # Use detach here to ensure we don't keep the temporary autograd
                # graph created during the second forward
                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
                return

            def inner_unpack(packed):
                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")

            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward.  Restore the surrounding state
            # when we're done.
            rng_devices = []
            if preserve_rng_state and had_cuda_in_fwd:
                rng_devices = fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
                if preserve_rng_state:
                    torch.set_rng_state(fwd_cpu_state)
                    if had_cuda_in_fwd:
                        set_device_states(fwd_gpu_devices, fwd_gpu_states)
                # Re-run under the same autocast configuration as the original
                # forward so the recomputed dtypes match the saved ones.
                with torch.enable_grad(), \
                     torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
                     torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                    _unused = function(*args, **kwargs)

        if x not in storage:
            raise RuntimeError(
                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                " recomputation being triggered in between, this is not currently supported. Please"
                " open an issue with details on your use case so that we can prioritize adding this."
            )

        return storage[x]

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args, **kwargs)
        if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
            # Cuda was not initialized before running the forward, so we didn't
            # stash the CUDA state.
            raise RuntimeError(
                "PyTorch's CUDA state was initialized in the forward pass "
                "of a Checkpoint, which is not allowed. Please open an issue "
                "if you need this feature.")

    return output
Please open an issue " "if you need this feature.") return output ================================================ FILE: models/csrc/__init__.py ================================================ ================================================ FILE: models/csrc/msmv_sampling/msmv_sampling.cpp ================================================ #include "msmv_sampling.h" #define MAX_POINT 32 void ms_deformable_im2col_cuda_c2345( const float* feat_c2, const float* feat_c3, const float* feat_c4, const float* feat_c5, const int h_c2, const int w_c2, const int h_c3, const int w_c3, const int h_c4, const int w_c4, const int h_c5, const int w_c5, const float* data_sampling_loc, const float* data_attn_weight, const int batch_size, const int channels, const int num_views, const int num_query, const int num_point, float* data_col ); void ms_deformable_im2col_cuda_c23456( const float* feat_c2, const float* feat_c3, const float* feat_c4, const float* feat_c5, const float* feat_c6, const int h_c2, const int w_c2, const int h_c3, const int w_c3, const int h_c4, const int w_c4, const int h_c5, const int w_c5, const int h_c6, const int w_c6, const float* data_sampling_loc, const float* data_attn_weight, const int batch_size, const int channels, const int num_views, const int num_query, const int num_point, float* data_col ); void ms_deformable_col2im_cuda_c2345( const float* grad_col, const float* feat_c2, const float* feat_c3, const float* feat_c4, const float* feat_c5, const int h_c2, const int w_c2, const int h_c3, const int w_c3, const int h_c4, const int w_c4, const int h_c5, const int w_c5, const float* data_sampling_loc, const float* data_attn_weight, const int batch_size, const int channels, const int num_views, const int num_query, const int num_point, float* grad_value_c2, float* grad_value_c3, float* grad_value_c4, float* grad_value_c5, float* grad_sampling_loc, float* grad_attn_weight ); void ms_deformable_col2im_cuda_c23456( const float *grad_col, const float *feat_c2, 
// Forward entry point for multi-scale multi-view (MSMV) deformable sampling
// over a 4-level feature pyramid (C2..C5).
//
// feat_cX:      [B, N, H_x, W_x, C] contiguous CUDA float tensors (N views).
// sampling_loc: [B, Q, P, 3] normalized (w, h, view) sample coordinates.
// attn_weight:  [B, Q, P, 4] per-point weights, one per pyramid level.
// Returns:      [B, Q, C, P] weighted bilinear samples, zero outside bounds.
//
// NOTE(review): the untemplated data_ptr() returns void*, which is not
// implicitly convertible to float* in C++; the <float> template argument
// (apparently stripped by text extraction) is restored below.
at::Tensor ms_deform_attn_cuda_c2345_forward(
    const at::Tensor& feat_c2,       // [B, N, H, W, C]
    const at::Tensor& feat_c3,       // [B, N, H, W, C]
    const at::Tensor& feat_c4,       // [B, N, H, W, C]
    const at::Tensor& feat_c5,       // [B, N, H, W, C]
    const at::Tensor& sampling_loc,  // [B, Q, P, 3]
    const at::Tensor& attn_weight    // [B, Q, P, 4]
)
{
    // The kernel does raw pointer arithmetic: inputs must be contiguous CUDA tensors.
    AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

    AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");

    const int batch_size = feat_c2.size(0);
    const int num_views = feat_c2.size(1);
    const int channels = feat_c2.size(4);
    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(2);
    // The backward kernel buffers per-point state; P is capped at MAX_POINT.
    AT_ASSERTM(num_point <= MAX_POINT, "num_point exceed limits");

    const int h_c2 = feat_c2.size(2);
    const int w_c2 = feat_c2.size(3);
    const int h_c3 = feat_c3.size(2);
    const int w_c3 = feat_c3.size(3);
    const int h_c4 = feat_c4.size(2);
    const int w_c4 = feat_c4.size(3);
    const int h_c5 = feat_c5.size(2);
    const int w_c5 = feat_c5.size(3);

    auto output = at::zeros({batch_size, num_query, channels, num_point}, feat_c2.options());

    ms_deformable_im2col_cuda_c2345(
        feat_c2.data_ptr<float>(),
        feat_c3.data_ptr<float>(),
        feat_c4.data_ptr<float>(),
        feat_c5.data_ptr<float>(),
        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,
        sampling_loc.data_ptr<float>(),
        attn_weight.data_ptr<float>(),
        batch_size, channels, num_views, num_query, num_point,
        output.data_ptr<float>()
    );

    return output;
}
MAX_POINT, "num_point exceed limits"); const int h_c2 = feat_c2.size(2); const int w_c2 = feat_c2.size(3); const int h_c3 = feat_c3.size(2); const int w_c3 = feat_c3.size(3); const int h_c4 = feat_c4.size(2); const int w_c4 = feat_c4.size(3); const int h_c5 = feat_c5.size(2); const int w_c5 = feat_c5.size(3); auto output = at::zeros({ batch_size, num_query, channels, num_point }, feat_c2.options()); ms_deformable_im2col_cuda_c2345( feat_c2.data_ptr(), feat_c3.data_ptr(), feat_c4.data_ptr(), feat_c5.data_ptr(), h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, sampling_loc.data_ptr(), attn_weight.data_ptr(), batch_size, channels, num_views, num_query, num_point, output.data_ptr() ); return output; } at::Tensor ms_deform_attn_cuda_c23456_forward( const at::Tensor& feat_c2, // [B, N, H, W, C] const at::Tensor& feat_c3, // [B, N, H, W, C] const at::Tensor& feat_c4, // [B, N, H, W, C] const at::Tensor& feat_c5, // [B, N, H, W, C] const at::Tensor& feat_c6, // [B, N, H, W, C] const at::Tensor& sampling_loc, // [B, Q, P, 3] const at::Tensor& attn_weight // [B, Q, P, 4] ) { AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(feat_c6.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(feat_c6.is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA 
tensor"); AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor"); const int batch_size = feat_c2.size(0); const int num_views = feat_c2.size(1); const int channels = feat_c2.size(4); const int num_query = sampling_loc.size(1); const int num_point = sampling_loc.size(2); AT_ASSERTM(num_point <= MAX_POINT, "num_point exceed limits"); const int h_c2 = feat_c2.size(2); const int w_c2 = feat_c2.size(3); const int h_c3 = feat_c3.size(2); const int w_c3 = feat_c3.size(3); const int h_c4 = feat_c4.size(2); const int w_c4 = feat_c4.size(3); const int h_c5 = feat_c5.size(2); const int w_c5 = feat_c5.size(3); const int h_c6 = feat_c6.size(2); const int w_c6 = feat_c6.size(3); auto output = at::zeros({ batch_size, num_query, channels, num_point }, feat_c2.options()); ms_deformable_im2col_cuda_c23456( feat_c2.data_ptr(), feat_c3.data_ptr(), feat_c4.data_ptr(), feat_c5.data_ptr(), feat_c6.data_ptr(), h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6, sampling_loc.data_ptr(), attn_weight.data_ptr(), batch_size, channels, num_views, num_query, num_point, output.data_ptr() ); return output; } std::vector ms_deform_attn_cuda_c2345_backward( const at::Tensor& grad_output, const at::Tensor& feat_c2, // [B, N, H, W, C] const at::Tensor& feat_c3, // [B, N, H, W, C] const at::Tensor& feat_c4, // [B, N, H, W, C] const at::Tensor& feat_c5, // [B, N, H, W, C] const at::Tensor& sampling_loc, // [B, Q, P, 3] const at::Tensor& attn_weight // [B, Q, P, 4] ) { AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); AT_ASSERTM(grad_output.is_contiguous(), "grad_output 
tensor has to be contiguous"); AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor"); AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor"); AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor"); const int batch_size = feat_c2.size(0); const int num_views = feat_c2.size(1); const int channels = feat_c2.size(4); const int num_query = sampling_loc.size(1); const int num_point = sampling_loc.size(2); AT_ASSERTM(num_point <= MAX_POINT, "num_point exceed limits"); auto grad_value_c2 = at::zeros_like(feat_c2); auto grad_value_c3 = at::zeros_like(feat_c3); auto grad_value_c4 = at::zeros_like(feat_c4); auto grad_value_c5 = at::zeros_like(feat_c5); auto grad_sampling_loc = at::zeros_like(sampling_loc); auto grad_attn_weight = at::zeros_like(attn_weight); const int h_c2 = feat_c2.size(2); const int w_c2 = feat_c2.size(3); const int h_c3 = feat_c3.size(2); const int w_c3 = feat_c3.size(3); const int h_c4 = feat_c4.size(2); const int w_c4 = feat_c4.size(3); const int h_c5 = feat_c5.size(2); const int w_c5 = feat_c5.size(3); ms_deformable_col2im_cuda_c2345( grad_output.data_ptr(), feat_c2.data_ptr(), feat_c3.data_ptr(), feat_c4.data_ptr(), feat_c5.data_ptr(), h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, sampling_loc.data_ptr(), attn_weight.data_ptr(), batch_size, channels, num_views, num_query, num_point, grad_value_c2.data_ptr(), grad_value_c3.data_ptr(), grad_value_c4.data_ptr(), grad_value_c5.data_ptr(), grad_sampling_loc.data_ptr(), grad_attn_weight.data_ptr() ); return { grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight }; } std::vector ms_deform_attn_cuda_c23456_backward( const at::Tensor& grad_output, const at::Tensor& feat_c2, 
// [B, N, H, W, C] const at::Tensor& feat_c3, // [B, N, H, W, C] const at::Tensor& feat_c4, // [B, N, H, W, C] const at::Tensor& feat_c5, // [B, N, H, W, C] const at::Tensor& feat_c6, // [B, N, H, W, C] const at::Tensor& sampling_loc, // [B, Q, P, 3] const at::Tensor& attn_weight // [B, Q, P, 4] ) { AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(feat_c6.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(feat_c6.is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor"); AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor"); AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor"); const int batch_size = feat_c2.size(0); const int num_views = feat_c2.size(1); const int channels = feat_c2.size(4); const int num_query = sampling_loc.size(1); const int num_point = sampling_loc.size(2); AT_ASSERTM(num_point <= MAX_POINT, "num_point exceed limits"); auto grad_value_c2 = at::zeros_like(feat_c2); auto grad_value_c3 = at::zeros_like(feat_c3); auto grad_value_c4 = at::zeros_like(feat_c4); auto grad_value_c5 = at::zeros_like(feat_c5); auto grad_value_c6 = at::zeros_like(feat_c6); auto grad_sampling_loc = 
at::zeros_like(sampling_loc); auto grad_attn_weight = at::zeros_like(attn_weight); const int h_c2 = feat_c2.size(2); const int w_c2 = feat_c2.size(3); const int h_c3 = feat_c3.size(2); const int w_c3 = feat_c3.size(3); const int h_c4 = feat_c4.size(2); const int w_c4 = feat_c4.size(3); const int h_c5 = feat_c5.size(2); const int w_c5 = feat_c5.size(3); const int h_c6 = feat_c6.size(2); const int w_c6 = feat_c6.size(3); ms_deformable_col2im_cuda_c23456( grad_output.data_ptr(), feat_c2.data_ptr(), feat_c3.data_ptr(), feat_c4.data_ptr(), feat_c5.data_ptr(), feat_c6.data_ptr(), h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6, sampling_loc.data_ptr(), attn_weight.data_ptr(), batch_size, channels, num_views, num_query, num_point, grad_value_c2.data_ptr(), grad_value_c3.data_ptr(), grad_value_c4.data_ptr(), grad_value_c5.data_ptr(), grad_value_c6.data_ptr(), grad_sampling_loc.data_ptr(), grad_attn_weight.data_ptr() ); return { grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight }; } #ifdef TORCH_EXTENSION_NAME PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("_ms_deform_attn_cuda_c2345_forward", &ms_deform_attn_cuda_c2345_forward, "pass"); m.def("_ms_deform_attn_cuda_c2345_backward", &ms_deform_attn_cuda_c2345_backward, "pass"); m.def("_ms_deform_attn_cuda_c23456_forward", &ms_deform_attn_cuda_c23456_forward, "pass"); m.def("_ms_deform_attn_cuda_c23456_backward", &ms_deform_attn_cuda_c23456_backward, "pass"); } #endif ================================================ FILE: models/csrc/msmv_sampling/msmv_sampling.h ================================================ #pragma once #include at::Tensor ms_deform_attn_cuda_c2345_forward( const at::Tensor& feat_c2, // [B, N, H, W, C] const at::Tensor& feat_c3, // [B, N, H, W, C] const at::Tensor& feat_c4, // [B, N, H, W, C] const at::Tensor& feat_c5, // [B, N, H, W, C] const at::Tensor& sampling_loc, // [B, Q, P, 3] const at::Tensor& attn_weight // [B, Q, P, 4] ); 
std::vector ms_deform_attn_cuda_c2345_backward( const at::Tensor& feat_c2, // [B, N, H, W, C] const at::Tensor& feat_c3, // [B, N, H, W, C] const at::Tensor& feat_c4, // [B, N, H, W, C] const at::Tensor& feat_c5, // [B, N, H, W, C] const at::Tensor& sampling_loc, // [B, Q, P, 3] const at::Tensor& attn_weight, // [B, Q, P, 4] const at::Tensor& grad_output ); at::Tensor ms_deform_attn_cuda_c23456_forward( const at::Tensor& feat_c2, // [B, N, H, W, C] const at::Tensor& feat_c3, // [B, N, H, W, C] const at::Tensor& feat_c4, // [B, N, H, W, C] const at::Tensor& feat_c5, // [B, N, H, W, C] const at::Tensor& feat_c6, // [B, N, H, W, C] const at::Tensor& sampling_loc, // [B, Q, P, 3] const at::Tensor& attn_weight // [B, Q, P, 4] ); std::vector ms_deform_attn_cuda_c23456_backward( const at::Tensor& grad_output, const at::Tensor& feat_c2, // [B, N, H, W, C] const at::Tensor& feat_c3, // [B, N, H, W, C] const at::Tensor& feat_c4, // [B, N, H, W, C] const at::Tensor& feat_c5, // [B, N, H, W, C] const at::Tensor& feat_c6, // [B, N, H, W, C] const at::Tensor& sampling_loc, // [B, Q, P, 3] const at::Tensor& attn_weight // [B, Q, P, 4] ); ================================================ FILE: models/csrc/msmv_sampling/msmv_sampling_backward.cu ================================================ /*! 
/*!
 * Modified from Deformable DETR
 */
// NOTE(review): the original include list was destroyed by text extraction
// (ten bare "#include" lines); a conservative reconstruction is used here —
// confirm against the upstream Deformable DETR source.
#include <cstdio>
#include <cstdlib>
#include <cuda.h>
#include <cuda_runtime.h>

// Grid-stride style loop over `n` work items.
#define CUDA_KERNEL_LOOP(i, n)                        \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n);                                       \
       i += blockDim.x * gridDim.x)

#define CUDA_NUM_THREADS 512
#define MAX_POINT 32

inline int GET_BLOCKS(const int N, const int num_threads)
{
  // Ceiling division: enough blocks to cover N items.
  return (N + num_threads - 1) / num_threads;
}

// Backward of a single bilinear sample on a [height, width, channels] plane.
// Atomically accumulates into:
//   grad_value          — gradient w.r.t. the 4 neighboring feature values,
//   grad_attn_weight    — gradient w.r.t. this level's attention weight,
//   grad_sampling_loc   — gradient w.r.t. the normalized (w, h) location
//                         (scaled by (width-1)/(height-1): align_corners=True).
//
// NOTE(review): const_cast in the original lost its <float *> template
// argument during extraction (does not compile as written); restored here.
__device__ void ms_deform_attn_col2im_bilinear(const float *&bottom_data,
                                               const int &height, const int &width, const int &channels,
                                               const float &h, const float &w, const int &c,
                                               const float &top_grad,
                                               const float &attn_weight,
                                               const float *&grad_value,
                                               float *&grad_sampling_loc,
                                               float *&grad_attn_weight)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  // Fractional offsets and their complements (bilinear weights).
  const float lh = h - h_low;
  const float lw = w - w_low;
  const float hh = 1 - lh, hw = 1 - lw;

  // Channel-last layout: advancing one column moves `channels` floats.
  const int w_stride = channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;

  const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const float top_grad_value = top_grad * attn_weight;
  float grad_h_weight = 0, grad_w_weight = 0;
  float *grad_ptr;

  // Each in-bounds corner contributes to the value gradient and to the
  // derivative of the interpolation weights w.r.t. (h, w).
  float v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + c;
    grad_ptr = const_cast<float *>(grad_value + ptr1);
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_ptr, w1 * top_grad_value);
  }
  float v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + c;
    grad_ptr = const_cast<float *>(grad_value + ptr2);
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_ptr, w2 * top_grad_value);
  }
  float v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + c;
    grad_ptr = const_cast<float *>(grad_value + ptr3);
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_ptr, w3 * top_grad_value);
  }
  float v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + c;
    grad_ptr = const_cast<float *>(grad_value + ptr4);
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_ptr, w4 * top_grad_value);
  }

  // Interpolated value feeds the attention-weight gradient; location
  // gradients are rescaled from pixel units back to normalized coordinates.
  const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  atomicAdd(grad_attn_weight, top_grad * val);
  atomicAdd(grad_sampling_loc, (width - 1) * grad_w_weight * top_grad_value);
  atomicAdd(grad_sampling_loc + 1, (height - 1) * grad_h_weight * top_grad_value);
}
// Backward kernel (global-memory accumulation) for 4-level (C2..C5) MSMV
// deformable sampling.  One thread per (batch, query, channel, point)
// element of grad_col; all gradient writes use atomicAdd because multiple
// threads may target the same feature cell / location / weight slot.
// global_memory_way
__global__ void ms_deformable_col2im_gpu_kernel_gm_c2345(
    const float *grad_col,
    const float *feat_c2, const float *feat_c3, const float *feat_c4, const float *feat_c5,
    const int h_c2, const int w_c2, const int h_c3, const int w_c3,
    const int h_c4, const int w_c4, const int h_c5, const int w_c5,
    const float *data_sampling_loc, const float *data_attn_weight,
    const int batch_size, const int channels, const int num_views,
    const int num_query, const int num_point,
    float *grad_value_c2, float *grad_value_c3, float *grad_value_c4, float *grad_value_c5,
    float *grad_sampling_loc, float *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, batch_size * num_query * channels * num_point)
  {
    // n: bs x query x channels
    // Decompose the flat index (innermost-first): point, channel, then
    // sampling_index = b * num_query + q, and finally the batch index.
    int _temp = index;
    const int p_col = _temp % num_point;
    _temp /= num_point;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    _temp /= num_query;
    const int b_col = _temp;

    const float top_grad = grad_col[index];

    // Sampling location in range [0, 1]; the third component selects the
    // camera view by rounding to the nearest view index.
    int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
    const float loc_w = data_sampling_loc[data_loc_ptr];
    const float loc_h = data_sampling_loc[data_loc_ptr + 1];
    const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));

    // Attn weights: one per pyramid level, stride 4 per point.
    int data_weight_ptr = sampling_index * num_point * 4 + p_col * 4;
    const float weight_c2 = data_attn_weight[data_weight_ptr];
    const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
    const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
    const float weight_c5 = data_attn_weight[data_weight_ptr + 3];

    // const float h_im = loc_h * spatial_h - 0.5;  // align_corners = False
    // const float w_im = loc_w * spatial_w - 0.5;

    // C2 Feature: map normalized coords to this level's pixel grid, then
    // accumulate the bilinear backward if the sample falls in-bounds.
    float h_im = loc_h * (h_c2 - 1);  // align_corners = True
    float w_im = loc_w * (w_c2 - 1);
    float *grad_location_ptr = grad_sampling_loc + data_loc_ptr;
    float *grad_weights_ptr = grad_attn_weight + data_weight_ptr;
    if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2)
    {
      const float *feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
      const float *grad_c2_ptr = grad_value_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
      ms_deform_attn_col2im_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col,
                                     top_grad, weight_c2, grad_c2_ptr,
                                     grad_location_ptr, grad_weights_ptr);
    }
    // Advance to the next level's attention-weight gradient slot.
    grad_weights_ptr += 1;

    // C3 Feature
    h_im = loc_h * (h_c3 - 1);  // align_corners = True
    w_im = loc_w * (w_c3 - 1);
    if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3)
    {
      const float *feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
      const float *grad_c3_ptr = grad_value_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
      ms_deform_attn_col2im_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col,
                                     top_grad, weight_c3, grad_c3_ptr,
                                     grad_location_ptr, grad_weights_ptr);
    }
    grad_weights_ptr += 1;

    // C4 Feature
    h_im = loc_h * (h_c4 - 1);  // align_corners = True
    w_im = loc_w * (w_c4 - 1);
    if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4)
    {
      const float *feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
      const float *grad_c4_ptr = grad_value_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
      ms_deform_attn_col2im_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col,
                                     top_grad, weight_c4, grad_c4_ptr,
                                     grad_location_ptr, grad_weights_ptr);
    }
    grad_weights_ptr += 1;

    // C5 Feature
    h_im = loc_h * (h_c5 - 1);  // align_corners = True
    w_im = loc_w * (w_c5 - 1);
    if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5)
    {
      const float *feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
      const float *grad_c5_ptr = grad_value_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
      ms_deform_attn_col2im_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col,
                                     top_grad, weight_c5, grad_c5_ptr,
                                     grad_location_ptr, grad_weights_ptr);
    }
  }
}
// Backward kernel (global-memory accumulation) for the 5-level (C2..C6)
// variant.  Identical structure to the c2345 kernel, but attention weights
// have stride 5 per point and a fifth (C6) level is accumulated.
__global__ void ms_deformable_col2im_gpu_kernel_gm_c23456(
    const float *grad_col,
    const float *feat_c2, const float *feat_c3, const float *feat_c4,
    const float *feat_c5, const float *feat_c6,
    const int h_c2, const int w_c2, const int h_c3, const int w_c3,
    const int h_c4, const int w_c4, const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float *data_sampling_loc, const float *data_attn_weight,
    const int batch_size, const int channels, const int num_views,
    const int num_query, const int num_point,
    float *grad_value_c2, float *grad_value_c3, float *grad_value_c4,
    float *grad_value_c5, float *grad_value_c6,
    float *grad_sampling_loc, float *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, batch_size * num_query * channels * num_point)
  {
    // n: bs x query x channels
    // Flat-index decomposition: point, channel, sampling_index (= b*Q + q),
    // then batch.
    int _temp = index;
    const int p_col = _temp % num_point;
    _temp /= num_point;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    _temp /= num_query;
    const int b_col = _temp;

    const float top_grad = grad_col[index];

    // Sampling location in range [0, 1]; third component = view selector.
    int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
    const float loc_w = data_sampling_loc[data_loc_ptr];
    const float loc_h = data_sampling_loc[data_loc_ptr + 1];
    const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));

    // Attn weights: five per point here (one per level C2..C6).
    int data_weight_ptr = sampling_index * num_point * 5 + p_col * 5;
    const float weight_c2 = data_attn_weight[data_weight_ptr];
    const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
    const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
    const float weight_c5 = data_attn_weight[data_weight_ptr + 3];
    const float weight_c6 = data_attn_weight[data_weight_ptr + 4];

    // const float h_im = loc_h * spatial_h - 0.5;  // align_corners = False
    // const float w_im = loc_w * spatial_w - 0.5;

    // C2 Feature
    float h_im = loc_h * (h_c2 - 1);  // align_corners = True
    float w_im = loc_w * (w_c2 - 1);
    float *grad_location_ptr = grad_sampling_loc + data_loc_ptr;
    float *grad_weights_ptr = grad_attn_weight + data_weight_ptr;
    if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2)
    {
      const float *feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
      const float *grad_c2_ptr = grad_value_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
      ms_deform_attn_col2im_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col,
                                     top_grad, weight_c2, grad_c2_ptr,
                                     grad_location_ptr, grad_weights_ptr);
    }
    // Move to the next level's attention-weight gradient slot.
    grad_weights_ptr += 1;

    // C3 Feature
    h_im = loc_h * (h_c3 - 1);  // align_corners = True
    w_im = loc_w * (w_c3 - 1);
    if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3)
    {
      const float *feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
      const float *grad_c3_ptr = grad_value_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
      ms_deform_attn_col2im_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col,
                                     top_grad, weight_c3, grad_c3_ptr,
                                     grad_location_ptr, grad_weights_ptr);
    }
    grad_weights_ptr += 1;

    // C4 Feature
    h_im = loc_h * (h_c4 - 1);  // align_corners = True
    w_im = loc_w * (w_c4 - 1);
    if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4)
    {
      const float *feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
      const float *grad_c4_ptr = grad_value_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
      ms_deform_attn_col2im_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col,
                                     top_grad, weight_c4, grad_c4_ptr,
                                     grad_location_ptr, grad_weights_ptr);
    }
    grad_weights_ptr += 1;

    // C5 Feature
    h_im = loc_h * (h_c5 - 1);  // align_corners = True
    w_im = loc_w * (w_c5 - 1);
    if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5)
    {
      const float *feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
      const float *grad_c5_ptr = grad_value_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
      ms_deform_attn_col2im_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col,
                                     top_grad, weight_c5, grad_c5_ptr,
                                     grad_location_ptr, grad_weights_ptr);
    }
    grad_weights_ptr += 1;

    // C6 Feature
    h_im = loc_h * (h_c6 - 1);  // align_corners = True
    w_im = loc_w * (w_c6 - 1);
    if (h_im > -1 && w_im > -1 && h_im < h_c6 && w_im < w_c6)
    {
      const float *feat_c6_ptr = feat_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;
      const float *grad_c6_ptr = grad_value_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;
      ms_deform_attn_col2im_bilinear(feat_c6_ptr, h_c6, w_c6, channels, h_im, w_im, c_col,
                                     top_grad, weight_c6, grad_c6_ptr,
                                     grad_location_ptr, grad_weights_ptr);
    }
  }
}
grad_weights_ptr += 1; // C4 Feature h_im = loc_h * (h_c4 - 1); // align_corners = True w_im = loc_w * (w_c4 - 1); if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4) { const float *feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels; const float *grad_c4_ptr = grad_value_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels; ms_deform_attn_col2im_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col, top_grad, weight_c4, grad_c4_ptr, grad_location_ptr, grad_weights_ptr); } grad_weights_ptr += 1; // C5 Feature h_im = loc_h * (h_c5 - 1); // align_corners = True w_im = loc_w * (w_c5 - 1); if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5) { const float *feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels; const float *grad_c5_ptr = grad_value_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels; ms_deform_attn_col2im_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col, top_grad, weight_c5, grad_c5_ptr, grad_location_ptr, grad_weights_ptr); } grad_weights_ptr += 1; // C6 Feature h_im = loc_h * (h_c6 - 1); // align_corners = True w_im = loc_w * (w_c6 - 1); if (h_im > -1 && w_im > -1 && h_im < h_c6 && w_im < w_c6) { const float *feat_c6_ptr = feat_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels; const float *grad_c6_ptr = grad_value_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels; ms_deform_attn_col2im_bilinear(feat_c6_ptr, h_c6, w_c6, channels, h_im, w_im, c_col, top_grad, weight_c6, grad_c6_ptr, grad_location_ptr, grad_weights_ptr); } } } void ms_deformable_col2im_cuda_c2345( const float *grad_col, const float *feat_c2, const float *feat_c3, const float *feat_c4, const float *feat_c5, const int h_c2, const int w_c2, const int h_c3, const int w_c3, const int h_c4, const int w_c4, const int h_c5, const int w_c5, const float 
*data_sampling_loc, const float *data_attn_weight, const int batch_size, const int channels, const int num_views, const int num_query, const int num_point, float *grad_value_c2, float *grad_value_c3, float *grad_value_c4, float *grad_value_c5, float *grad_sampling_loc, float *grad_attn_weight) { const int num_kernels = batch_size * num_query * channels * num_point; const int num_threads = (channels * num_point > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels * num_point; ms_deformable_col2im_gpu_kernel_gm_c2345 <<>>( grad_col, feat_c2, feat_c3, feat_c4, feat_c5, h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, data_sampling_loc, data_attn_weight, batch_size, channels, num_views, num_query, num_point, grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight); cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { printf("error in ms_deformable_col2im_cuda_c2345: %s\n", cudaGetErrorString(err)); } } void ms_deformable_col2im_cuda_c23456( const float *grad_col, const float *feat_c2, const float *feat_c3, const float *feat_c4, const float *feat_c5, const float *feat_c6, const int h_c2, const int w_c2, const int h_c3, const int w_c3, const int h_c4, const int w_c4, const int h_c5, const int w_c5, const int h_c6, const int w_c6, const float *data_sampling_loc, const float *data_attn_weight, const int batch_size, const int channels, const int num_views, const int num_query, const int num_point, float *grad_value_c2, float *grad_value_c3, float *grad_value_c4, float *grad_value_c5, float *grad_value_c6, float *grad_sampling_loc, float *grad_attn_weight) { const int num_kernels = batch_size * num_query * channels * num_point; const int num_threads = (channels * num_point > CUDA_NUM_THREADS) ? 
// ---- continuation of ms_deformable_col2im_cuda_c23456 (head above) ----
// NOTE(review): line breaks were collapsed by extraction; formatting below is
// reconstructed, tokens unchanged.  The empty "<<>>" launch configuration also
// appears to be an extraction artifact -- restore from the original source.
CUDA_NUM_THREADS : channels * num_point;
    ms_deformable_col2im_gpu_kernel_gm_c23456
        <<>>(
            grad_col, feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,
            h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
            data_sampling_loc, data_attn_weight,
            batch_size, channels, num_views, num_query, num_point,
            grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6,
            grad_sampling_loc, grad_attn_weight);
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("error in ms_deformable_col2im_cuda_c23456: %s\n", cudaGetErrorString(err));
    }
}

================================================
FILE: models/csrc/msmv_sampling/msmv_sampling_forward.cu
================================================

/*!
 * Modified from Deformable DETR
 */
// NOTE(review): the header names of these #include directives appear to have
// been stripped during extraction -- restore them from the original file.
#include
#include
#include
#include
#include
#include
#include
#include
#include

// Grid-stride style loop macro: i walks every index in [0, n) across the grid.
#define CUDA_KERNEL_LOOP(i, n)                          \
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
         i < (n);                                       \
         i += blockDim.x * gridDim.x)

#define CUDA_NUM_THREADS 512
#define MAX_POINT 32

// Number of thread blocks needed to cover N items with num_threads each.
inline int GET_BLOCKS(const int N, const int num_threads)
{
    return (N + num_threads - 1) / num_threads;
}

// Bilinear interpolation of channel c at fractional position (h, w) in a
// single-view feature map stored channel-last ([height, width, channels]).
// Out-of-bounds corners contribute 0 (zero padding).
__device__ float ms_deform_attn_im2col_bilinear(
    const float*& bottom_data,
    const int& height, const int& width, const int& channels,
    const float& h, const float& w, const int& c)
{
    const int h_low = floor(h);
    const int w_low = floor(w);
    const int h_high = h_low + 1;
    const int w_high = w_low + 1;

    // Fractional offsets and their complements (bilinear weights).
    const float lh = h - h_low;
    const float lw = w - w_low;
    const float hh = 1 - lh, hw = 1 - lw;

    // Channel-last strides: one step in w moves `channels` floats.
    const int w_stride = channels;
    const int h_stride = width * w_stride;
    const int h_low_ptr_offset = h_low * h_stride;
    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
    const int w_low_ptr_offset = w_low * w_stride;
    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;

    // Gather the four corners, skipping any that fall outside the map.
    float v1 = 0;
    if (h_low >= 0 && w_low >= 0)
    {
        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + c;
        v1 = bottom_data[ptr1];
    }
    float v2 = 0;
    if (h_low >= 0 && w_high <= width - 1)
    {
        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + c;
        v2 = bottom_data[ptr2];
    }
    float v3 = 0;
    if (h_high <= height - 1 && w_low >= 0)
    {
        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + c;
        v3 = bottom_data[ptr3];
    }
    float v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1)
    {
        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + c;
        v4 = bottom_data[ptr4];
    }

    const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
    const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    return val;
}

// Forward (im2col) kernel for 4-level (C2..C5) multi-scale multi-view
// deformable sampling.  One thread handles one (batch, query, channel) tuple
// and accumulates, per sampling point, the attention-weighted bilinear sample
// from each level into res[], then writes res[] to data_col.
__global__ void ms_deformable_im2col_gpu_kernel_c2345(
    const float* feat_c2, const float* feat_c3, const float* feat_c4, const float* feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float* data_sampling_loc, const float* data_attn_weight,
    const int batch_size, const int channels, const int num_views,
    const int num_query, const int num_point,
    float* data_col)
{
    float res[MAX_POINT];

    CUDA_KERNEL_LOOP(index, batch_size * num_query * channels)
    {
        // n: bs x query x channels
        int _temp = index;
        const int c_col = _temp % channels;
        _temp /= channels;
        const int sampling_index = _temp;
        _temp /= num_query;
        const int b_col = _temp;

        for (int p_col = 0; p_col < num_point; ++p_col)
        {
            res[p_col] = 0;
        }

        for (int p_col = 0; p_col < num_point; ++p_col)
        {
            // Sampling location in range [0, 1]
            int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
            const float loc_w = data_sampling_loc[data_loc_ptr];
            const float loc_h = data_sampling_loc[data_loc_ptr + 1];
            // Third coordinate selects the camera view (rounded to nearest).
            const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));

            // Attn weights: one scalar per scale level (4 levels, stride 4).
            int data_weight_ptr = sampling_index * num_point * 4 + p_col * 4;
            const float weight_c2 = data_attn_weight[data_weight_ptr];
            const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
            const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
            const float weight_c5 = data_attn_weight[data_weight_ptr + 3];

            // const float h_im = loc_h * spatial_h - 0.5;  // align_corners = False
            // const float w_im = loc_w * spatial_w - 0.5;

            // C2 Feature
            float h_im = loc_h * (h_c2 - 1);  // align_corners = True
            float w_im = loc_w * (w_c2 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2)
            {
                const float* feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col) * weight_c2;
            }

            // C3 Feature
            h_im = loc_h * (h_c3 - 1);  // align_corners = True
            w_im = loc_w * (w_c3 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3)
            {
                const float* feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col) * weight_c3;
            }

            // C4 Feature
            h_im = loc_h * (h_c4 - 1);  // align_corners = True
            w_im = loc_w * (w_c4 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4)
            {
                const float* feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col) * weight_c4;
            }

            // C5 Feature
            h_im = loc_h * (h_c5 - 1);  // align_corners = True
            w_im = loc_w * (w_c5 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5)
            {
                const float* feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col) * weight_c5;
            }
        }

        // Write the accumulated per-point results back to global memory.
        for (int p_col = 0; p_col < num_point; ++p_col)
        {
            float* data_col_ptr = data_col + index * num_point + p_col;
            *data_col_ptr = res[p_col];
        }
    }
}

// Forward (im2col) kernel for 5-level (C2..C6) sampling.
// ---- head of ms_deformable_im2col_gpu_kernel_c23456 (continues below) ----
__global__ void ms_deformable_im2col_gpu_kernel_c23456(
    const float* feat_c2, const float* feat_c3, const float* feat_c4, const float* feat_c5, const float* feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int
// ---- continuation of ms_deformable_im2col_gpu_kernel_c23456 (head above).
// One thread per (batch, query, channel); accumulates attention-weighted
// bilinear samples of all num_point points over the 5 levels (C2..C6).
// NOTE(review): line breaks were collapsed by extraction; formatting below is
// reconstructed, tokens unchanged.
h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float* data_sampling_loc, const float* data_attn_weight,
    const int batch_size, const int channels, const int num_views,
    const int num_query, const int num_point,
    float* data_col)
{
    float res[MAX_POINT];

    CUDA_KERNEL_LOOP(index, batch_size * num_query * channels)
    {
        // n: bs x query x channels
        int _temp = index;
        const int c_col = _temp % channels;
        _temp /= channels;
        const int sampling_index = _temp;
        _temp /= num_query;
        const int b_col = _temp;

        for (int p_col = 0; p_col < num_point; ++p_col)
        {
            res[p_col] = 0;
        }

        for (int p_col = 0; p_col < num_point; ++p_col)
        {
            // Sampling location in range [0, 1]
            int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
            const float loc_w = data_sampling_loc[data_loc_ptr];
            const float loc_h = data_sampling_loc[data_loc_ptr + 1];
            // Third coordinate selects the camera view (rounded to nearest).
            const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));

            // Attn weights: one scalar per scale level (5 levels, stride 5).
            int data_weight_ptr = sampling_index * num_point * 5 + p_col * 5;
            const float weight_c2 = data_attn_weight[data_weight_ptr];
            const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
            const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
            const float weight_c5 = data_attn_weight[data_weight_ptr + 3];
            const float weight_c6 = data_attn_weight[data_weight_ptr + 4];

            // const float h_im = loc_h * spatial_h - 0.5;  // align_corners = False
            // const float w_im = loc_w * spatial_w - 0.5;

            // C2 Feature
            float h_im = loc_h * (h_c2 - 1);  // align_corners = True
            float w_im = loc_w * (w_c2 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2)
            {
                const float* feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col) * weight_c2;
            }

            // C3 Feature
            h_im = loc_h * (h_c3 - 1);  // align_corners = True
            w_im = loc_w * (w_c3 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3)
            {
                const float* feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col) * weight_c3;
            }

            // C4 Feature
            h_im = loc_h * (h_c4 - 1);  // align_corners = True
            w_im = loc_w * (w_c4 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4)
            {
                const float* feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col) * weight_c4;
            }

            // C5 Feature
            h_im = loc_h * (h_c5 - 1);  // align_corners = True
            w_im = loc_w * (w_c5 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5)
            {
                const float* feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col) * weight_c5;
            }

            // C6 Feature
            h_im = loc_h * (h_c6 - 1);  // align_corners = True
            w_im = loc_w * (w_c6 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c6 && w_im < w_c6)
            {
                const float* feat_c6_ptr = feat_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c6_ptr, h_c6, w_c6, channels, h_im, w_im, c_col) * weight_c6;
            }
        }

        // Write the accumulated per-point results back to global memory.
        for (int p_col = 0; p_col < num_point; ++p_col)
        {
            float* data_col_ptr = data_col + index * num_point + p_col;
            *data_col_ptr = res[p_col];
        }
    }
}

// Host-side launcher for the 4-level (C2..C5) forward kernel.
// NOTE(review): the launch configuration between "<<" and ">>" appears to have
// been stripped during extraction (empty "<<>>"); presumably
// <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>>.
void ms_deformable_im2col_cuda_c2345(
    const float* feat_c2, const float* feat_c3, const float* feat_c4, const float* feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float* data_sampling_loc, const float* data_attn_weight,
    const int batch_size, const int channels, const int num_views,
    const int num_query, const int num_point,
    float* data_col)
{
    const int num_kernels = batch_size * num_query * channels;
    const int num_threads = CUDA_NUM_THREADS;
    ms_deformable_im2col_gpu_kernel_c2345
        <<>> (
            feat_c2, feat_c3, feat_c4, feat_c5,
            h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,
            data_sampling_loc, data_attn_weight,
            batch_size, channels, num_views, num_query, num_point,
            data_col
        );
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("error in ms_deformable_im2col_cuda_c2345: %s\n", cudaGetErrorString(err));
    }
}

// Host-side launcher for the 5-level (C2..C6) forward kernel.
void ms_deformable_im2col_cuda_c23456(
    const float* feat_c2, const float* feat_c3, const float* feat_c4, const float* feat_c5, const float* feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float* data_sampling_loc, const float* data_attn_weight,
    const int batch_size, const int channels, const int num_views,
    const int num_query, const int num_point,
    float* data_col)
{
    const int num_kernels = batch_size * num_query * channels;
    const int num_threads = CUDA_NUM_THREADS;
    ms_deformable_im2col_gpu_kernel_c23456
        <<>> (
            feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,
            h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
            data_sampling_loc, data_attn_weight,
            batch_size, channels, num_views, num_query, num_point,
            data_col
        );
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("error in ms_deformable_im2col_cuda_c23456: %s\n", cudaGetErrorString(err));
    }
}

================================================
FILE: models/csrc/setup.py
================================================

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


def get_ext_modules():
    # One CUDA extension bundling the multi-scale multi-view sampling
    # forward/backward kernels and their C++ binding.
    return [
        CUDAExtension(
            name='_msmv_sampling_cuda',
            sources=[
                'msmv_sampling/msmv_sampling.cpp',
                'msmv_sampling/msmv_sampling_forward.cu',
                'msmv_sampling/msmv_sampling_backward.cu'
            ],
            include_dirs=['msmv_sampling']
        )
    ]


setup(
    name='csrc',
    ext_modules=get_ext_modules(),
    cmdclass={'build_ext':
BuildExtension}
)

================================================
FILE: models/csrc/wrapper.py
================================================

# NOTE(review): line breaks in this file were collapsed by extraction; the
# indentation below is reconstructed, tokens unchanged.
import torch
import torch.nn.functional as F
from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c2345_forward, _ms_deform_attn_cuda_c2345_backward
from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c23456_forward, _ms_deform_attn_cuda_c23456_backward


def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
    """
    Pure-PyTorch fallback using F.grid_sample (used when the number of levels
    has no dedicated CUDA path).
    value: [B, N, H1W1 + H2W2..., C]
    sampling_locations: [B, Q, P, 3]
    scale_weights: [B, Q, P, 4]
    """
    assert scale_weights.shape[-1] == len(mlvl_feats)

    B, _, _, _, C = mlvl_feats[0].shape
    _, Q, P, _ = sampling_locations.shape

    # Map [0, 1] locations to grid_sample's [-1, 1] coordinate range.
    sampling_locations = sampling_locations * 2 - 1
    sampling_locations = sampling_locations[:, :, :, None, :]  # [B, Q, P, 1, 3]

    final = torch.zeros([B, C, Q, P], device=mlvl_feats[0].device)

    for lvl, feat in enumerate(mlvl_feats):
        feat = feat.permute(0, 4, 1, 2, 3)
        out = F.grid_sample(
            feat,
            sampling_locations,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=True,
        )[..., 0]  # [B, C, Q, P]
        # Weight each level's sample by its per-point scale weight.
        out = out * scale_weights[..., lvl].reshape(B, 1, Q, P)
        final += out

    return final.permute(0, 2, 1, 3)


class MSMVSamplingC2345(torch.autograd.Function):
    # Autograd bridge to the 4-level (C2..C5) CUDA forward/backward kernels.

    @staticmethod
    def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights):
        ctx.save_for_backward(feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights)
        assert callable(_ms_deform_attn_cuda_c2345_forward)
        return _ms_deform_attn_cuda_c2345_forward(
            feat_c2, feat_c3, feat_c4, feat_c5,
            sampling_locations, scale_weights)

    @staticmethod
    def backward(ctx, grad_output):
        feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights = ctx.saved_tensors
        assert callable(_ms_deform_attn_cuda_c2345_backward)
        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight = _ms_deform_attn_cuda_c2345_backward(grad_output.contiguous(), feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights)
        return grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight


class MSMVSamplingC23456(torch.autograd.Function):
    # Autograd bridge to the 5-level (C2..C6) CUDA forward/backward kernels.

    @staticmethod
    def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights):
        ctx.save_for_backward(feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights)
        assert callable(_ms_deform_attn_cuda_c23456_forward)
        return _ms_deform_attn_cuda_c23456_forward(
            feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,
            sampling_locations, scale_weights)

    @staticmethod
    def backward(ctx, grad_output):
        feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights = ctx.saved_tensors
        assert callable(_ms_deform_attn_cuda_c23456_backward)
        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight = _ms_deform_attn_cuda_c23456_backward(grad_output.contiguous(), feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights)
        return grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight


def msmv_sampling(mlvl_feats, sampling_locations, scale_weights):
    """Dispatch to the CUDA path for 4 or 5 feature levels, otherwise fall
    back to the grid_sample-based PyTorch implementation."""
    sampling_locations = sampling_locations.contiguous()
    scale_weights = scale_weights.contiguous()

    if len(mlvl_feats) == 4:
        return MSMVSamplingC2345.apply(*mlvl_feats, sampling_locations, scale_weights)
    elif len(mlvl_feats) == 5:
        return MSMVSamplingC23456.apply(*mlvl_feats, sampling_locations, scale_weights)
    else:
        return msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights)

================================================
FILE: models/loss_utils.py
================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.builder import LOSSES, build_loss
from mmdet.core import reduce_mean
from .utils import sparse2dense
from torch.cuda.amp import autocast
from torch.autograd import Variable


def
get_voxel_decoder_loss_input(voxel_semantics, occ_loc_i, seg_pred_i, scale, num_classes=18):
    # Densify sparse per-voxel semantic logits, upsample them to the full
    # occupancy grid, and gather logits/labels at the occupied positions.
    # NOTE(review): line breaks/indentation were lost in extraction and are
    # reconstructed here, tokens unchanged; in particular the placement of the
    # final return relative to the `if` should be confirmed against the
    # original file.
    assert voxel_semantics.shape[0] == 1  # bs = 1
    voxel_semantics = voxel_semantics.long()

    if seg_pred_i is not None:  # semantic prediction
        assert seg_pred_i.shape[-1] == num_classes
        seg_pred_dense, sparse_mask = sparse2dense(
            occ_loc_i, seg_pred_i,
            dense_shape=[200 // scale, 200 // scale, 16 // scale, num_classes],
            empty_value=torch.zeros((num_classes)).to(seg_pred_i)
        )
        sparse_mask = F.interpolate(sparse_mask[:, None].float(), scale_factor=scale)[:, 0].bool()
        seg_pred_dense = seg_pred_dense.permute(0, 4, 1, 2, 3)  # [B, CLS, W, H, D]
        seg_pred_dense = F.interpolate(seg_pred_dense, scale_factor=scale)
        seg_pred_dense = seg_pred_dense.permute(0, 2, 3, 4, 1)  # [B, W, H, D, CLS]
        seg_pred_i_sparse = seg_pred_dense[sparse_mask]  # [K, CLS]
        voxel_semantics_sparse = voxel_semantics[sparse_mask]  # [K]

    return seg_pred_i_sparse, voxel_semantics_sparse, sparse_mask


def compute_scal_loss(pred, gt, class_id, reverse=False, ignore_index=255):
    # Scene-class affinity loss for one class: BCE on precision, recall and
    # specificity, accumulated per batch element; `mask` marks the batch
    # elements where the (possibly reversed) target class is present.
    p = pred[:, class_id, :]
    completion_target = (gt == class_id).long()
    loss = torch.zeros(pred.shape[0], device=pred.device)

    if reverse:
        # Treat "not class_id" (excluding ignore) as the target instead.
        p = 1 - p
        completion_target = ((gt != class_id) & (gt != ignore_index)).long()

    target_sum = completion_target.sum(dim=(1))
    mask = (target_sum > 0)
    p = p[torch.where(mask)]
    completion_target = completion_target[torch.where(mask)]

    nominator = torch.sum(p * completion_target, dim=(1))

    # Precision term (only where predicted mass is non-zero).
    p_mask = torch.where(torch.sum(p, dim=(1)) > 0)
    if p_mask[0].shape[0] > 0:
        precision = nominator[p_mask] / torch.sum(p[p_mask], dim=(1))
        loss_precision = F.binary_cross_entropy(
            precision, torch.ones_like(precision),
            reduction='none'
        )
        loss[torch.where(mask)[0][p_mask]] += loss_precision

    # Recall term (only where target mass is non-zero).
    t_mask = torch.where(torch.sum(completion_target, dim=(1)) > 0)
    if t_mask[0].shape[0] > 0:
        recall = nominator[t_mask] / torch.sum(completion_target[t_mask], dim=(1))
        loss_recall = F.binary_cross_entropy(
            recall, torch.ones_like(recall),
            reduction='none'
        )
        loss[torch.where(mask)[0][t_mask]] += loss_recall

    # Specificity term (only where the complement of the target is non-zero).
    ct_mask = torch.where(torch.sum(1 - completion_target, dim=(1)) > 0)
    if ct_mask[0].shape[0] > 0:
        specificity = torch.sum((1 - p[ct_mask]) * (1 - completion_target[ct_mask]), dim=(1)) / (
            torch.sum(1 - completion_target[ct_mask], dim=(1))
        )
        loss_ct = F.binary_cross_entropy(
            specificity, torch.ones_like(specificity),
            reduction='none'
        )
        loss[torch.where(mask)[0][ct_mask]] += loss_ct

    return loss, mask


@LOSSES.register_module()
class GeoScalLoss(nn.Module):
    # Geometric affinity loss: applies the scene-class affinity loss to the
    # "empty" class (num_classes - 1) in reverse mode (occupied vs. empty).
    def __init__(self, num_classes, loss_weight=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.loss_weight = loss_weight

    def forward(self, pred, gt):
        loss = torch.tensor(0, device=pred.device, dtype=pred.dtype)  # NOTE(review): overwritten below; dead store
        pred = F.softmax(pred, dim=1)
        loss, _ = compute_scal_loss(pred, gt, self.num_classes - 1, reverse=True)
        return self.loss_weight * torch.mean(loss)


@LOSSES.register_module()
class SemScalLoss(nn.Module):
    # Semantic affinity loss: scene-class affinity loss averaged over all
    # classes, optionally weighted per class.
    def __init__(self, num_classes, class_weights=None, loss_weight=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.class_weights = class_weights
        if self.class_weights is not None:
            assert len(self.class_weights) == self.num_classes, "number of class weights must equal to class number"
        else:
            self.class_weights = [1.0 for _ in range(self.num_classes)]
        self.loss_weight = loss_weight

    def forward(self, pred, gt):
        pred = F.softmax(pred, dim=1)
        batch_size = pred.shape[0]
        loss = torch.zeros(batch_size, device=pred.device)
        count = torch.zeros(batch_size, device=pred.device)
        for i in range(self.num_classes):
            loss_cls, mask_cls = compute_scal_loss(pred, gt, i)
            count += mask_cls.long()  # number of classes contributing per sample
            loss += loss_cls * self.class_weights[i]
        return self.loss_weight * (loss / count).mean()


# borrowed from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py#L21
def dice_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        num_masks: float,
        mask_camera: torch.Tensor
    ):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    """
    if mask_camera is not None:
        # Restrict the loss to camera-visible voxels.
        inputs = inputs[:, :, mask_camera]
        targets = targets[:, :, mask_camera]

    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    targets = targets.squeeze(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # +1 in numerator/denominator smooths the ratio for empty masks.
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks


dice_loss_jit = torch.jit.script(
    dice_loss
)  # type: torch.jit.ScriptModule


# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py#L48
def sigmoid_ce_loss(
        inputs: torch.Tensor,
        targets: torch.Tensor,
        num_masks: float,
        mask_camera: torch.Tensor
    ):
    """
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    Returns:
        Loss tensor
    """
    # [M, 1, K]
    if mask_camera is not None:
        # Use the visibility mask as a per-element BCE weight (0/1).
        mask_camera = mask_camera.to(torch.int32)
        mask_camera = mask_camera[None, None, ...].expand(targets.shape[0], 1, mask_camera.shape[-1])
        loss = F.binary_cross_entropy_with_logits(inputs, targets, mask_camera, reduction="none")
    else:
        loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    return loss.mean(2).mean(1).sum() / num_masks


sigmoid_ce_loss_jit = torch.jit.script(
    sigmoid_ce_loss
)  # type: torch.jit.ScriptModule


def CE_ssc_loss(pred, target, class_weights=None, ignore_index=255):
    """
    :param: prediction: the predicted tensor, must be [BS, C, ...]
""" criterion = nn.CrossEntropyLoss( weight=class_weights, ignore_index=ignore_index, reduction="mean" ) with autocast(False): loss = criterion(pred, target.long()) return loss # https://github.com/NVlabs/FB-BEV/blob/832bd81866823a913a4c69552e1ca61ae34ac211/mmdet3d/models/fbbev/modules/occ_loss_utils/lovasz_softmax.py#L22 def lovasz_grad(gt_sorted): """ Computes gradient of the Lovasz extension w.r.t sorted errors See Alg. 1 in paper """ p = len(gt_sorted) gts = gt_sorted.sum() intersection = gts - gt_sorted.float().cumsum(0) union = gts + (1 - gt_sorted).float().cumsum(0) jaccard = 1. - intersection / union if p > 1: # cover 1-pixel case jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] return jaccard # https://github.com/NVlabs/FB-BEV/blob/832bd81866823a913a4c69552e1ca61ae34ac211/mmdet3d/models/fbbev/modules/occ_loss_utils/lovasz_softmax.py#L157 def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): """ Multi-class Lovasz-Softmax loss probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 
per_image: compute the loss per image instead of per batch ignore: void class labels """ if per_image: loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) for prob, lab in zip(probas, labels)) else: with autocast(False): loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) return loss # https://github.com/NVlabs/FB-BEV/blob/832bd81866823a913a4c69552e1ca61ae34ac211/mmdet3d/models/fbbev/modules/occ_loss_utils/lovasz_softmax.py#L176 def lovasz_softmax_flat(probas, labels, classes='present'): """ Multi-class Lovasz-Softmax loss probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) labels: [P] Tensor, ground truth labels (between 0 and C - 1) classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. """ if probas.numel() == 0: # only void pixels, the gradients should be 0 return probas * 0. C = probas.size(1) losses = [] class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes for c in class_to_sum: fg = (labels == c).float() # foreground for class c if (classes == 'present' and fg.sum() == 0): continue if C == 1: if len(classes) > 1: raise ValueError('Sigmoid output possible only with 1 class') class_pred = probas[:, 0] else: class_pred = probas[:, c] errors = (Variable(fg) - class_pred).abs() errors_sorted, perm = torch.sort(errors, 0, descending=True) perm = perm.data fg_sorted = fg[perm] losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) return mean(losses) # https://github.com/NVlabs/FB-BEV/blob/832bd81866823a913a4c69552e1ca61ae34ac211/mmdet3d/models/fbbev/modules/occ_loss_utils/lovasz_softmax.py#L207 def flatten_probas(probas, labels, ignore=None): """ Flattens predictions in the batch """ if probas.dim() == 2: if ignore is not None: valid = (labels != ignore) probas = probas[valid] labels = labels[valid] return probas, labels elif probas.dim() == 3: # assumes output 
# (cont.) of a sigmoid layer
        # NOTE(review): indentation reconstructed after extraction collapsed
        # line breaks; tokens unchanged.
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    elif probas.dim() == 5:
        # 3D segmentation
        B, C, L, H, W = probas.size()
        probas = probas.contiguous().view(B, C, L, H*W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    vprobas = probas[valid.nonzero().squeeze()]
    vlabels = labels[valid]
    return vprobas, vlabels


# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py#L90
@LOSSES.register_module()
class Mask2FormerLoss(nn.Module):
    # Combined mask (BCE), dice, and classification (focal) loss for
    # query-based panoptic occupancy prediction, computed per batch element
    # on Hungarian-matched (query, instance) pairs.
    def __init__(self, num_classes, loss_cls_weight=1.0, loss_mask_weight=1.0, loss_dice_weight=1.0, no_class_weight=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.loss_cls_weight = loss_cls_weight
        self.loss_mask_weight = loss_mask_weight
        self.loss_dice_weight = loss_dice_weight
        self.no_class_weight = no_class_weight
        # Down-weight the trailing "no object" class.
        self.empty_weight = torch.ones(self.num_classes)
        self.empty_weight[-1] = self.no_class_weight
        self.loss_cls = build_loss(dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=2.0
        ))

    def forward(self, mask_pred, class_pred, mask_gt, class_gt, indices, mask_camera):
        bs = mask_pred.shape[0]
        loss_masks = torch.tensor(0).to(mask_pred)
        loss_dices = torch.tensor(0).to(mask_pred)
        loss_classes = torch.tensor(0).to(mask_pred)

        # Normalize by the total number of GT instances, averaged across ranks.
        num_total_pos = sum([tc.numel() for tc in class_gt])
        avg_factor = torch.clamp(reduce_mean(class_pred.new_tensor([num_total_pos * 1.0])), min=1).item()

        for b in range(bs):
            mask_camera_b = mask_camera[b] if mask_camera is not None else None  # N
            tgt_mask = mask_gt[b]
            num_instances = class_gt[b].shape[0]
            tgt_class = class_gt[b]
            # One-hot the instance-id volume into per-instance binary masks.
            tgt_mask = (tgt_mask.unsqueeze(-1) == torch.arange(num_instances).to(mask_gt.device))
            tgt_mask = tgt_mask.permute(1, 0)
            src_idx, tgt_idx = indices[b]
            src_mask = mask_pred[b][src_idx]  # [M, N], M is number of gt instances, N is number of remaining voxels
            tgt_mask = tgt_mask[tgt_idx]  # [M, N]
            src_class = class_pred[b]  # [Q, CLS]

            # pad non-aligned queries' tgt classes with 'no class'
            pad_tgt_class = torch.full(
                (src_class.shape[0], ), self.num_classes - 1, dtype=torch.int64, device=class_pred.device
            )  # [Q]
            pad_tgt_class[src_idx] = tgt_class[tgt_idx]

            # only calculates loss mask for aligned pairs
            loss_mask, loss_dice = self.loss_masks(src_mask, tgt_mask, avg_factor=avg_factor, mask_camera=mask_camera_b)
            # calculates loss class for all queries
            loss_class = self.loss_labels(src_class, pad_tgt_class, self.empty_weight.to(src_class.device), avg_factor=avg_factor)

            loss_masks += loss_mask * self.loss_mask_weight
            loss_dices += loss_dice * self.loss_dice_weight
            loss_classes += loss_class * self.loss_cls_weight

        return loss_masks, loss_dices, loss_classes

    # mask2former use point sampling to calculate loss of fewer important points
    # we omit point sampling as we have limited number of points
    def loss_masks(self, src_mask, tgt_mask, avg_factor=None, mask_camera=None):
        """Compute the losses related to the masks: the focal loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        # No need to upsample predictions as we are using normalized coordinates :)
        # N x 1 x H x W
        num_masks = tgt_mask.shape[0]
        src_mask = src_mask.view(num_masks, 1, -1)
        tgt_mask = tgt_mask.view(num_masks, 1, -1)
        if avg_factor is None:
            avg_factor = num_masks

        loss_dice = dice_loss(src_mask, tgt_mask, avg_factor, mask_camera)
        loss_mask = sigmoid_ce_loss(src_mask, tgt_mask.float(), avg_factor, mask_camera)
        return loss_mask, loss_dice

    def loss_labels(self, src_class, tgt_class, empty_weight=None, avg_factor=None):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        return self.loss_cls(
            src_class,
            tgt_class,
            torch.ones_like(tgt_class),
            avg_factor=avg_factor
        ).mean()


# --------------------------- HELPER FUNCTIONS ---------------------------
def mean(l, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n

================================================
FILE: models/matcher.py
================================================

"""
Modified from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py
"""
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from scipy.optimize import linear_sum_assignment
from mmcv.runner import BaseModule
from mmdet.core.bbox.match_costs import build_match_cost


def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor, mask_camera: torch.Tensor):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs.
def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor, mask_camera: torch.Tensor):
    """Pairwise sigmoid BCE cost between every prediction and every target.

    Args:
        inputs: [N, num_points] mask logits, one row per prediction.
        targets: [M, num_points] binary target masks (0/1 floats).
        mask_camera: optional per-point visibility mask; used as a BCE
            weight so invisible points contribute zero loss.
    Returns:
        [N, M] cost matrix, averaged over the point dimension.
    """
    hw = inputs.shape[1]

    # The two explicit branches (instead of a single `weight=None` call) are
    # kept so the function stays torch.jit.script-able.
    if mask_camera is not None:
        cam_weight = mask_camera.to(torch.int32)
        cam_weight = cam_weight[None].expand(inputs.shape[0], cam_weight.shape[-1])
        pos = F.binary_cross_entropy_with_logits(
            inputs, torch.ones_like(inputs), cam_weight, reduction="none"
        )
        neg = F.binary_cross_entropy_with_logits(
            inputs, torch.zeros_like(inputs), cam_weight, reduction="none"
        )
    else:
        pos = F.binary_cross_entropy_with_logits(
            inputs, torch.ones_like(inputs), reduction="none"
        )
        neg = F.binary_cross_entropy_with_logits(
            inputs, torch.zeros_like(inputs), reduction="none"
        )

    # cost[n, m] = sum_c pos[n, c] * t[m, c] + neg[n, c] * (1 - t[m, c])
    cost = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
        "nc,mc->nm", neg, (1 - targets)
    )
    return cost / hw
    def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1):
        """Creates the matcher

        Params:
            cost_class: relative weight of the classification error in the matching cost
            cost_mask: relative weight of the sigmoid CE loss of the binary mask in the matching cost
            cost_dice: relative weight of the dice loss of the binary mask in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice
        # Classification cost uses mmdet's focal-loss cost (as SparseBEV does)
        # rather than the plain -prob[target] used in Mask2Former.
        self.loss_focal = build_match_cost(dict(type='FocalLossCost', weight=2.0))
        assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs cant be 0"

    @torch.no_grad()
    def forward(self, mask_pred, class_pred, mask_gt, class_gt, mask_camera):
        """Hungarian matching between predicted and GT instance masks.

        Args:
            mask_pred: [bs, num_query, num_voxel (65536)] mask logits
            class_pred: [bs, num_query, 17] classification logits
            mask_gt: [bs, num_voxel], value in range [0, num_obj - 1] (instance id per voxel)
            class_gt: [[bs0_num_obj], [bs1_num_obj], ...], value in range [0, num_cls - 1]
            mask_camera: optional [bs, num_voxel] visibility mask, or None

        Returns:
            list (length bs) of (pred_indices, gt_indices) int64 tensor pairs.
        """
        bs, num_queries = class_pred.shape[:2]

        indices = []

        # Iterate through batch size
        for b in range(bs):
            mask_camera_b = mask_camera[b] if mask_camera is not None else None
            tgt_ids = class_gt[b]
            num_instances = tgt_ids.shape[0]  # must be here, cause num of instances may change after masking

            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be ommitted.
            '''out_prob = class_pred[b].softmax(-1)  # [num_queries, num_classes]
            cost_class = -out_prob[:, tgt_ids.long()].squeeze(1)'''

            # Compute the classification cost. We use focal loss provided by mmdet as sparsebev does
            out_prob = class_pred[b]  # TODO
            cost_class = self.loss_focal(out_prob, tgt_ids.long())

            out_mask = mask_pred[b]  # [num_queries, H_pred, W_pred]
            # gt masks are already padded when preparing target
            tgt_mask = mask_gt[b]
            # one-hot the per-voxel instance ids: [num_voxel, N_inst]
            tgt_mask = (tgt_mask.unsqueeze(-1) == torch.arange(num_instances).to(mask_gt.device))
            tgt_mask = tgt_mask.permute(1, 0)  # [N_inst, num_voxel]

            # all masks share the same set of points for efficient matching!
            tgt_mask = tgt_mask.view(tgt_mask.shape[0], -1)
            out_mask = out_mask.view(out_mask.shape[0], -1)

            # cost computation is numerically sensitive; force fp32
            with autocast(enabled=False):
                out_mask = out_mask.float()
                tgt_mask = tgt_mask.float()
                # Compute the sigmoid CE loss between masks
                cost_mask = batch_sigmoid_ce_loss(out_mask, tgt_mask, mask_camera_b)
                # Compute the dice loss betwen masks
                cost_dice = batch_dice_loss(out_mask, tgt_mask, mask_camera_b)

            # Final cost matrix: weighted sum of the three terms
            C = (
                self.cost_mask * cost_mask
                + self.cost_class * cost_class
                + self.cost_dice * cost_dice
            )
            C = C.reshape(num_queries, -1).cpu()

            indices.append(linear_sum_assignment(C))

        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]
def upsample(pre_feat, pre_coords, interval):
    '''Octree-style 2x upsampling: each voxel spawns 8 children.

    :param pre_feat: (Tensor), features from last level, (B, N, C)
    :param pre_coords: (Tensor), coordinates from last level, (B, N, 3) (3: x, y, z)
    :param interval: child offset in voxel units at the finer level
    :return: up_feat: upsampled features, (B, N*8, C//8) — channels are split
             evenly across the 8 children
    :return: up_coords: upsampled coordinates, (B, N*8, 3)
    '''
    # axes bumped for children 1..7 (child 0 keeps the parent corner)
    child_axes = [0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2]]

    bs, n_query, n_channels = pre_feat.shape
    child_feat = pre_feat.reshape(bs, n_query, 8, n_channels // 8)  # [B, N, 8, C/8]
    child_coords = pre_coords.unsqueeze(2).repeat(1, 1, 8, 1).contiguous()  # [B, N, 8, 3]

    for child, axes in enumerate(child_axes, start=1):
        child_coords[:, :, child, axes] += interval

    up_feat = child_feat.reshape(bs, -1, n_channels // 8)
    up_coords = child_coords.reshape(bs, -1, 3)
    return up_feat, up_coords
    @torch.no_grad()
    def init_weights(self):
        # delegate to each transformer layer's own initializer
        for i in range(len(self.decoder_layers)):
            self.decoder_layers[i].init_weights()

    def forward(self, mlvl_feats, img_metas):
        """Coarse-to-fine sparse voxel decoding.

        Each layer refines query features, upsamples 2x, predicts per-voxel
        semantics, then keeps only the top-k most-likely-occupied voxels.
        Returns a list (one tuple per layer) of
        (voxel_indices, None, seg_logits, voxel_feats, interval).
        """
        occ_preds = []
        topk = self.topk_training if self.training else self.topk_testing
        B = len(img_metas)

        # init query coords: a dense grid at the coarsest resolution
        interval = 2 ** self.num_layers
        query_coord = generate_grid(self.voxel_dim, interval).expand(B, -1, -1)  # [B, N, 3]
        query_feat = torch.zeros([B, query_coord.shape[1], self.embed_dims], device=query_coord.device)  # [B, N, C]

        for i, layer in enumerate(self.decoder_layers):
            DUMP.stage_count = i
            interval = 2 ** (self.num_layers - i)  # 8 4 2 1

            # bbox from coords (0.4m base voxel size, scaled by interval)
            query_bbox = index2point(query_coord, self.pc_range, voxel_size=0.4)  # [B, N, 3]
            query_bbox = point2bbox(query_bbox, box_size=0.4 * interval)  # [B, N, 6]
            query_bbox = encode_bbox(query_bbox, pc_range=self.pc_range)  # [B, N, 6]

            # transformer layer
            query_feat = layer(query_feat, query_bbox, mlvl_feats, img_metas)  # [B, N, C]

            # upsample 2x: lift channels 8x, then split over 8 children
            query_feat = self.lift_feat_heads[i](query_feat)  # [B, N, 8C]
            query_feat_2x, query_coord_2x = upsample(query_feat, query_coord, interval // 2)

            if self.semantic:
                seg_pred_2x = self.seg_pred_heads[i](query_feat_2x)  # [B, K, CLS]
            else:
                seg_pred_2x = None

            # sparsify after seg_pred: keep top-k by P(not free); the last
            # class channel is the "free" class.
            # NOTE(review): when self.semantic is False, seg_pred_2x is None
            # and F.softmax below would fail — this path appears to require
            # semantic=True; confirm before disabling semantics.
            non_free_prob = 1 - F.softmax(seg_pred_2x, dim=-1)[..., -1]  # [B, K]
            indices = torch.topk(non_free_prob, k=topk[i], dim=1)[1]  # [B, K]
            query_coord_2x = batch_indexing(query_coord_2x, indices, layout='channel_last')  # [B, K, 3]
            query_feat_2x = batch_indexing(query_feat_2x, indices, layout='channel_last')  # [B, K, C]
            seg_pred_2x = batch_indexing(seg_pred_2x, indices, layout='channel_last')  # [B, K, CLS]

            occ_preds.append((
                torch.div(query_coord_2x, interval // 2, rounding_mode='trunc').long(),
                None,
                seg_pred_2x,
                query_feat_2x,
                interval // 2)
            )

            # detach so earlier stages are not optimized through later ones
            query_coord = query_coord_2x.detach()
            query_feat = query_feat_2x.detach()

        return occ_preds


class SparseVoxelDecoderLayer(BaseModule):
    """One decoder layer: positional encoding, optional self-attention,
    multi-view sampling, adaptive mixing, and an FFN."""

    def __init__(self,
                 embed_dims=None,
                 num_frames=None,
                 num_points=None,
                 num_groups=None,
                 num_levels=None,
                 pc_range=None,
                 self_attn=True):
        super().__init__()

        self.position_encoder = nn.Sequential(
            nn.Linear(3, embed_dims),
            nn.LayerNorm(embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dims, embed_dims),
            nn.LayerNorm(embed_dims),
            nn.ReLU(inplace=True),
        )

        if self_attn:
            self.self_attn = SparseBEVSelfAttention(embed_dims, num_heads=8, dropout=0.1, pc_range=pc_range, scale_adaptive=True)
            self.norm1 = nn.LayerNorm(embed_dims)
        else:
            self.self_attn = None

        self.sampling = SparseBEVSampling(
            embed_dims=embed_dims,
            num_frames=num_frames,
            num_groups=num_groups,
            num_points=num_points,
            num_levels=num_levels,
            pc_range=pc_range
        )
        self.mixing = AdaptiveMixing(
            in_dim=embed_dims,
            in_points=num_points * num_frames,
            n_groups=num_groups,
            out_points=num_points * num_frames * num_groups
        )
        self.ffn = FFN(embed_dims, feedforward_channels=embed_dims * 2, ffn_drop=0.1)
        self.norm2 = nn.LayerNorm(embed_dims)
        self.norm3 = nn.LayerNorm(embed_dims)

    @torch.no_grad()
    def init_weights(self):
        if self.self_attn is not None:
            self.self_attn.init_weights()
        self.sampling.init_weights()
        self.mixing.init_weights()
        self.ffn.init_weights()

    def forward(self, query_feat, query_bbox, mlvl_feats, img_metas):
        # add positional encoding of the (normalized) box center
        query_pos = self.position_encoder(query_bbox[..., :3])
        query_feat = query_feat + query_pos

        if self.self_attn is not None:
            query_feat = self.norm1(self.self_attn(query_bbox, query_feat))

        sampled_feat = self.sampling(query_bbox, query_feat, mlvl_feats, img_metas)
        query_feat = self.norm2(self.mixing(sampled_feat, query_feat))
        query_feat = self.norm3(self.ffn(query_feat))

        return query_feat
@HEADS.register_module()
class SparseBEVHead(DETRHead):
    """SparseBEV detection head with DN-DETR-style query denoising."""

    def __init__(self,
                 *args,
                 num_classes,
                 in_channels,
                 query_denoising=True,
                 query_denoising_groups=10,
                 bbox_coder=None,
                 code_size=10,
                 code_weights=[1.0] * 10,
                 train_cfg=dict(),
                 test_cfg=dict(max_per_img=100),
                 **kwargs):
        self.code_size = code_size
        self.code_weights = code_weights
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.fp16_enabled = False
        self.embed_dims = in_channels

        super(SparseBEVHead, self).__init__(num_classes, in_channels, train_cfg=train_cfg, test_cfg=test_cfg, **kwargs)

        # fixed (non-trainable) per-dimension regression weights
        self.code_weights = nn.Parameter(torch.tensor(self.code_weights), requires_grad=False)
        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.pc_range = self.bbox_coder.pc_range

        # query-denoising hyper-parameters
        self.dn_enabled = query_denoising
        self.dn_group_num = query_denoising_groups
        self.dn_weight = 1.0
        self.dn_bbox_noise_scale = 0.5
        self.dn_label_noise_scale = 0.5

    def _init_layers(self):
        self.init_query_bbox = nn.Embedding(self.num_query, 10)  # (x, y, z, w, l, h, sin, cos, vx, vy)
        self.label_enc = nn.Embedding(self.num_classes + 1, self.embed_dims - 1)  # DAB-DETR

        nn.init.zeros_(self.init_query_bbox.weight[:, 2:3])
        nn.init.zeros_(self.init_query_bbox.weight[:, 8:10])
        nn.init.constant_(self.init_query_bbox.weight[:, 5:6], 1.5)

        # spread the initial query centers on a uniform 2D grid
        grid_size = int(math.sqrt(self.num_query))
        assert grid_size * grid_size == self.num_query
        x = y = torch.arange(grid_size)
        xx, yy = torch.meshgrid(x, y, indexing='ij')  # [0, grid_size - 1]
        xy = torch.cat([xx[..., None], yy[..., None]], dim=-1)
        xy = (xy + 0.5) / grid_size  # [0.5, grid_size - 0.5] / grid_size ~= (0, 1)
        with torch.no_grad():
            self.init_query_bbox.weight[:, :2] = xy.reshape(-1, 2)  # [Q, 2]

    def init_weights(self):
        self.transformer.init_weights()

    def forward(self, mlvl_feats, img_metas):
        """Run the transformer and split DN queries from matching queries."""
        query_bbox = self.init_query_bbox.weight.clone()  # [Q, 10]
        #query_bbox[..., :3] = query_bbox[..., :3].sigmoid()

        B = mlvl_feats[0].shape[0]
        query_bbox, query_feat, attn_mask, mask_dict = self.prepare_for_dn_input(B, query_bbox, self.label_enc, img_metas)

        cls_scores, bbox_preds = self.transformer(
            query_bbox,
            query_feat,
            mlvl_feats,
            attn_mask=attn_mask,
            img_metas=img_metas,
        )

        # denormalize predicted centers back to world coordinates
        bbox_preds[..., 0] = bbox_preds[..., 0] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]
        bbox_preds[..., 1] = bbox_preds[..., 1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]
        bbox_preds[..., 2] = bbox_preds[..., 2] * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2]

        bbox_preds = torch.cat([
            bbox_preds[..., 0:2],
            bbox_preds[..., 3:5],
            bbox_preds[..., 2:3],
            bbox_preds[..., 5:10],
        ], dim=-1)  # [cx, cy, w, l, cz, h, sin, cos, vx, vy]

        if mask_dict is not None and mask_dict['pad_size'] > 0:
            # the first pad_size queries are denoising queries
            output_known_cls_scores = cls_scores[:, :, :mask_dict['pad_size'], :]
            output_known_bbox_preds = bbox_preds[:, :, :mask_dict['pad_size'], :]
            output_cls_scores = cls_scores[:, :, mask_dict['pad_size']:, :]
            output_bbox_preds = bbox_preds[:, :, mask_dict['pad_size']:, :]
            mask_dict['output_known_lbs_bboxes'] = (output_known_cls_scores, output_known_bbox_preds)
            outs = {
                'all_cls_scores': output_cls_scores,
                'all_bbox_preds': output_bbox_preds,
                'enc_cls_scores': None,
                'enc_bbox_preds': None,
                'dn_mask_dict': mask_dict,
            }
        else:
            outs = {
                'all_cls_scores': cls_scores,
                'all_bbox_preds': bbox_preds,
                'enc_cls_scores': None,
                'enc_bbox_preds': None,
            }

        return outs

    def prepare_for_dn_input(self, batch_size, init_query_bbox, label_enc, img_metas):
        """Build the query set: [noised GT (denoising) queries | matching queries].

        Returns (input_query_bbox, input_query_feat, attn_mask, mask_dict);
        attn_mask/mask_dict are None when denoising is disabled or at test time.
        """
        device = init_query_bbox.device
        indicator0 = torch.zeros([self.num_query, 1], device=device)
        # matching queries all start from the learned "no-class" embedding
        init_query_feat = label_enc.weight[self.num_classes].repeat(self.num_query, 1)
        init_query_feat = torch.cat([init_query_feat, indicator0], dim=1)

        if self.training and self.dn_enabled:
            targets = [{
                'bboxes': torch.cat([m['gt_bboxes_3d'].gravity_center, m['gt_bboxes_3d'].tensor[:, 3:]], dim=1).cuda(),
                'labels': m['gt_labels_3d'].cuda().long()
            } for m in img_metas]

            known = [torch.ones_like(t['labels'], device=device) for t in targets]
            known_num = [sum(k) for k in known]

            # can be modified to selectively denosie some label or boxes; also known label prediction
            unmask_bbox = unmask_label = torch.cat(known)
            labels = torch.cat([t['labels'] for t in targets]).clone()
            bboxes = torch.cat([t['bboxes'] for t in targets]).clone()
            batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])

            known_indice = torch.nonzero(unmask_label + unmask_bbox)
            known_indice = known_indice.view(-1)

            # add noise: replicate every GT once per denoising group
            known_indice = known_indice.repeat(self.dn_group_num, 1).view(-1)
            known_labels = labels.repeat(self.dn_group_num, 1).view(-1)
            known_bid = batch_idx.repeat(self.dn_group_num, 1).view(-1)
            known_bboxs = bboxes.repeat(self.dn_group_num, 1)  # 9
            known_labels_expand = known_labels.clone()
            known_bbox_expand = known_bboxs.clone()

            # noise on the box: jitter centers by up to half the box extent
            if self.dn_bbox_noise_scale > 0:
                wlh = known_bbox_expand[..., 3:6].clone()
                rand_prob = torch.rand_like(known_bbox_expand) * 2 - 1.0
                known_bbox_expand[..., 0:3] += torch.mul(rand_prob[..., 0:3], wlh / 2) * self.dn_bbox_noise_scale
                # known_bbox_expand[..., 3:6] += torch.mul(rand_prob[..., 3:6], wlh) * self.dn_bbox_noise_scale
                # known_bbox_expand[..., 6:7] += torch.mul(rand_prob[..., 6:7], 3.14159) * self.dn_bbox_noise_scale

            known_bbox_expand = encode_bbox(known_bbox_expand, self.pc_range)
            known_bbox_expand[..., 0:3].clamp_(min=0.0, max=1.0)
            # nn.init.constant(known_bbox_expand[..., 8:10], 0.0)

            # noise on the label: randomly flip a fraction of labels
            if self.dn_label_noise_scale > 0:
                p = torch.rand_like(known_labels_expand.float())
                chosen_indice = torch.nonzero(p < self.dn_label_noise_scale).view(-1)  # usually half of bbox noise
                new_label = torch.randint_like(chosen_indice, 0, self.num_classes)  # randomly put a new one here
                known_labels_expand.scatter_(0, chosen_indice, new_label)

            known_feat_expand = label_enc(known_labels_expand)
            indicator1 = torch.ones([known_feat_expand.shape[0], 1], device=device)  # add dn part indicator
            known_feat_expand = torch.cat([known_feat_expand, indicator1], dim=1)

            # construct final query: pad each group to the max GT count in the batch
            dn_single_pad = int(max(known_num))
            dn_pad_size = int(dn_single_pad * self.dn_group_num)
            dn_query_bbox = torch.zeros([dn_pad_size, init_query_bbox.shape[-1]], device=device)
            dn_query_feat = torch.zeros([dn_pad_size, self.embed_dims], device=device)
            input_query_bbox = torch.cat([dn_query_bbox, init_query_bbox], dim=0).repeat(batch_size, 1, 1)
            input_query_feat = torch.cat([dn_query_feat, init_query_feat], dim=0).repeat(batch_size, 1, 1)

            if len(known_num):
                map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]
                map_known_indice = torch.cat([map_known_indice + dn_single_pad * i for i in range(self.dn_group_num)]).long()
            if len(known_bid):
                # scatter the noised GT queries into their padded slots
                input_query_bbox[known_bid.long(), map_known_indice] = known_bbox_expand
                input_query_feat[(known_bid.long(), map_known_indice)] = known_feat_expand

            total_size = dn_pad_size + self.num_query
            attn_mask = torch.ones([total_size, total_size], device=device) < 0  # all False

            # match query cannot see the reconstruct (DN) queries
            attn_mask[dn_pad_size:, :dn_pad_size] = True

            # DN groups cannot see each other
            for i in range(self.dn_group_num):
                if i == 0:
                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), dn_single_pad * (i + 1):dn_pad_size] = True
                if i == self.dn_group_num - 1:
                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), :dn_single_pad * i] = True
                else:
                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), dn_single_pad * (i + 1):dn_pad_size] = True
                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), :dn_single_pad * i] = True

            mask_dict = {
                'known_indice': torch.as_tensor(known_indice).long(),
                'batch_idx': torch.as_tensor(batch_idx).long(),
                'map_known_indice': torch.as_tensor(map_known_indice).long(),
                'known_lbs_bboxes': (known_labels, known_bboxs),
                'pad_size': dn_pad_size
            }
        else:
            input_query_bbox = init_query_bbox.repeat(batch_size, 1, 1)
            input_query_feat = init_query_feat.repeat(batch_size, 1, 1)
            attn_mask = None
            mask_dict = None

        return input_query_bbox, input_query_feat, attn_mask, mask_dict

    def prepare_for_dn_loss(self, mask_dict):
        """Gather predictions at the DN query slots that hold real (noised) GTs."""
        cls_scores, bbox_preds = mask_dict['output_known_lbs_bboxes']
        known_labels, known_bboxs = mask_dict['known_lbs_bboxes']
        map_known_indice = mask_dict['map_known_indice'].long()
        known_indice = mask_dict['known_indice'].long()
        batch_idx = mask_dict['batch_idx'].long()
        bid = batch_idx[known_indice]
        num_tgt = known_indice.numel()

        if len(cls_scores) > 0:
            # [L, B, Q, C] -> [B, Q, L, C] -> gather -> [L, num_known, C]
            cls_scores = cls_scores.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)
            bbox_preds = bbox_preds.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)

        return known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt

    def dn_loss_single(self, cls_scores, bbox_preds, known_bboxs, known_labels, num_total_pos=None):
        """DN classification + L1 regression loss for one decoder layer."""
        # Compute the average number of gt boxes accross all gpus
        num_total_pos = cls_scores.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1.0).item()

        # cls loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        bbox_weights = torch.ones_like(bbox_preds)
        label_weights = torch.ones_like(known_labels)
        loss_cls = self.loss_cls(
            cls_scores, known_labels.long(), label_weights, avg_factor=num_total_pos
        )

        # regression L1 loss; drop non-finite targets
        bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
        normalized_bbox_targets = normalize_bbox(known_bboxs)
        isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
        bbox_weights = bbox_weights * self.code_weights
        loss_bbox = self.loss_bbox(
            bbox_preds[isnotnan, :10],
            normalized_bbox_targets[isnotnan, :10],
            bbox_weights[isnotnan, :10],
            avg_factor=num_total_pos
        )

        loss_cls = self.dn_weight * torch.nan_to_num(loss_cls)
        loss_bbox = self.dn_weight * torch.nan_to_num(loss_bbox)

        return loss_cls, loss_bbox
    @force_fp32(apply_to=('preds_dicts'))
    def calc_dn_loss(self, loss_dict, preds_dicts, num_dec_layers):
        """Add denoising losses (per decoder layer) into loss_dict."""
        known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt = \
            self.prepare_for_dn_loss(preds_dicts['dn_mask_dict'])
        # same targets for every decoder layer
        all_known_bboxs_list = [known_bboxs for _ in range(num_dec_layers)]
        all_known_labels_list = [known_labels for _ in range(num_dec_layers)]
        all_num_tgts_list = [num_tgt for _ in range(num_dec_layers)]
        dn_losses_cls, dn_losses_bbox = multi_apply(
            self.dn_loss_single, cls_scores, bbox_preds,
            all_known_bboxs_list, all_known_labels_list, all_num_tgts_list)

        # last layer gets the canonical names; earlier layers are prefixed
        loss_dict['loss_cls_dn'] = dn_losses_cls[-1]
        loss_dict['loss_bbox_dn'] = dn_losses_bbox[-1]
        num_dec_layer = 0
        for loss_cls_i, loss_bbox_i in zip(dn_losses_cls[:-1], dn_losses_bbox[:-1]):
            loss_dict[f'd{num_dec_layer}.loss_cls_dn'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_bbox_dn'] = loss_bbox_i
            num_dec_layer += 1

        return loss_dict

    def _get_target_single(self, cls_score, bbox_pred, gt_labels, gt_bboxes, gt_bboxes_ignore=None):
        """Assign GTs to predictions for one sample and build targets."""
        num_bboxes = bbox_pred.size(0)

        # assigner and sampler
        assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes, gt_labels,
                                             gt_bboxes_ignore, self.code_weights, True)
        sampling_result = self.sampler.sample(assign_result, bbox_pred, gt_bboxes)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label targets: default to background (= self.num_classes)
        labels = gt_bboxes.new_full((num_bboxes, ), self.num_classes, dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_bboxes.new_ones(num_bboxes)

        # bbox targets: only positive queries regress
        bbox_targets = torch.zeros_like(bbox_pred)[..., :9]
        bbox_weights = torch.zeros_like(bbox_pred)
        bbox_weights[pos_inds] = 1.0  # DETR
        bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes

        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds)

    def get_targets(self, cls_scores_list, bbox_preds_list, gt_bboxes_list, gt_labels_list, gt_bboxes_ignore_list=None):
        """Batched version of _get_target_single; also counts pos/neg queries."""
        assert gt_bboxes_ignore_list is None, \
            'Only supports for gt_bboxes_ignore setting to None.'

        num_imgs = len(cls_scores_list)
        gt_bboxes_ignore_list = [gt_bboxes_ignore_list for _ in range(num_imgs)]

        (labels_list, label_weights_list, bbox_targets_list,
         bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply(
             self._get_target_single, cls_scores_list, bbox_preds_list,
             gt_labels_list, gt_bboxes_list, gt_bboxes_ignore_list)

        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
        num_total_neg = sum((inds.numel() for inds in neg_inds_list))

        return (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
                num_total_pos, num_total_neg)

    def loss_single(self, cls_scores, bbox_preds, gt_bboxes_list, gt_labels_list, gt_bboxes_ignore_list=None):
        """Classification + L1 regression loss for one decoder layer."""
        num_imgs = cls_scores.size(0)
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]

        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
                                           gt_bboxes_list, gt_labels_list, gt_bboxes_ignore_list)
        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
         num_total_pos, num_total_neg) = cls_reg_targets
        labels = torch.cat(labels_list, 0)
        label_weights = torch.cat(label_weights_list, 0)
        bbox_targets = torch.cat(bbox_targets_list, 0)
        bbox_weights = torch.cat(bbox_weights_list, 0)

        # classification loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        # construct weighted avg_factor to match with the official DETR repo
        cls_avg_factor = num_total_pos * 1.0 + \
            num_total_neg * self.bg_cls_weight
        if self.sync_cls_avg_factor:
            cls_avg_factor = reduce_mean(
                cls_scores.new_tensor([cls_avg_factor]))
        cls_avg_factor = max(cls_avg_factor, 1)

        loss_cls = self.loss_cls(
            cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

        # Compute the average number of gt boxes accross all gpus, for
        # normalization purposes
        num_total_pos = loss_cls.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

        # regression L1 loss; drop non-finite targets
        bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
        normalized_bbox_targets = normalize_bbox(bbox_targets)
        isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
        bbox_weights = bbox_weights * self.code_weights

        loss_bbox = self.loss_bbox(
            bbox_preds[isnotnan, :10],
            normalized_bbox_targets[isnotnan, :10],
            bbox_weights[isnotnan, :10],
            avg_factor=num_total_pos
        )

        loss_cls = torch.nan_to_num(loss_cls)
        loss_bbox = torch.nan_to_num(loss_bbox)

        return loss_cls, loss_bbox
num_dec_layers) # loss from the last decoder layer loss_dict['loss_cls'] = losses_cls[-1] loss_dict['loss_bbox'] = losses_bbox[-1] # loss from other decoder layers num_dec_layer = 0 for loss_cls_i, loss_bbox_i in zip(losses_cls[:-1], losses_bbox[:-1]): loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i num_dec_layer += 1 return loss_dict @force_fp32(apply_to=('preds_dicts')) def get_bboxes(self, preds_dicts, img_metas, rescale=False): preds_dicts = self.bbox_coder.decode(preds_dicts) num_samples = len(preds_dicts) ret_list = [] for i in range(num_samples): preds = preds_dicts[i] bboxes = preds['bboxes'] bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5 bboxes = LiDARInstance3DBoxes(bboxes, 9) scores = preds['scores'] labels = preds['labels'] ret_list.append([bboxes, scores, labels]) return ret_list ================================================ FILE: models/sparsebev_sampling.py ================================================ import torch from .bbox.utils import decode_bbox from .utils import rotation_3d_in_axis, DUMP from .csrc.wrapper import msmv_sampling def make_sample_points_from_bbox(query_bbox, offset, pc_range): ''' query_bbox: [B, Q, 10] offset: [B, Q, num_points, 4], normalized by stride ''' query_bbox = decode_bbox(query_bbox, pc_range) # [B, Q, 9] xyz = query_bbox[..., 0:3] # [B, Q, 3] wlh = query_bbox[..., 3:6] # [B, Q, 3] # NOTE: different from SparseBEV xyz += wlh / 2 # conver to center delta_xyz = offset[..., 0:3] # [B, Q, P, 3] delta_xyz = wlh[:, :, None, :] * delta_xyz # [B, Q, P, 3] if query_bbox.shape[-1] > 6: ang = query_bbox[..., 6:7] # [B, Q, 1] delta_xyz = rotation_3d_in_axis(delta_xyz, ang) # [B, Q, P, 3] sample_xyz = xyz[:, :, None, :] + delta_xyz # [B, Q, P, 3] return sample_xyz # [B, Q, P, 3] def make_sample_points_from_mask(valid_map, pc_range, occ_size, num_points, occ_loc=None, offset=None): ''' valid_map: [B, Q, W, H, D] or [B, Q, N] occ_loc: [B, N, 3] if valid map is sparse 
def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, image_w, eps=1e-5):
    """Project 3D sample points into all cameras/frames and gather features.

    sample_points: [B, Q, T, G, P, 3] world-space points
    mlvl_feats: multi-level image features consumed by msmv_sampling
    scale_weights: per-point weights over feature levels
    lidar2img: [B, T*N, 4, 4] projection matrices (N = 6 cameras)
    Returns: [B, Q, G, T*P, C] sampled features.
    """
    B, Q, T, G, P, _ = sample_points.shape  # [B, Q, T, G, P, 4]
    N = 6  # number of cameras

    sample_points = sample_points.reshape(B, Q, T, G * P, 3)

    if DUMP.enabled:
        torch.save(sample_points, '{}/sample_points_3d_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))

    # get the projection matrix
    lidar2img = lidar2img[:, :(T*N), None, None, :, :]  # [B, TN, 1, 1, 4, 4]
    lidar2img = lidar2img.expand(B, T*N, Q, G * P, 4, 4)
    lidar2img = lidar2img.reshape(B, T, N, Q, G*P, 4, 4)

    # expand the points to homogeneous coordinates
    ones = torch.ones_like(sample_points[..., :1])
    sample_points = torch.cat([sample_points, ones], dim=-1)  # [B, Q, T, GP, 4]
    sample_points = sample_points[:, :, None, ..., None]
    sample_points = sample_points.expand(B, Q, N, T, G * P, 4, 1)
    sample_points = sample_points.transpose(1, 3)  # [B, T, N, Q, GP, 4, 1]

    # project 3d sampling points to image
    sample_points_cam = torch.matmul(lidar2img, sample_points).squeeze(-1)  # [B, T, N, Q, GP, 4]

    # homo coord -> pixel coord (clamp depth away from zero to avoid div-by-0)
    homo = sample_points_cam[..., 2:3]
    homo_nonzero = torch.maximum(homo, torch.zeros_like(homo) + eps)
    sample_points_cam = sample_points_cam[..., 0:2] / homo_nonzero  # [B, T, N, Q, GP, 2]

    # normalize to [0, 1] image coordinates
    sample_points_cam[..., 0] /= image_w
    sample_points_cam[..., 1] /= image_h

    # check if out of image (also requires the point to be in front of the camera)
    valid_mask = ((homo > eps)
        & (sample_points_cam[..., 1:2] > 0.0)
        & (sample_points_cam[..., 1:2] < 1.0)
        & (sample_points_cam[..., 0:1] > 0.0)
        & (sample_points_cam[..., 0:1] < 1.0)
    ).squeeze(-1).float()  # [B, T, N, Q, GP]

    if DUMP.enabled:
        torch.save(torch.cat([sample_points_cam, homo_nonzero], dim=-1),
                   '{}/sample_points_cam_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
        torch.save(valid_mask,
                   '{}/sample_points_cam_valid_mask_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))

    valid_mask = valid_mask.permute(0, 1, 3, 4, 2)  # [B, T, Q, GP, N]
    sample_points_cam = sample_points_cam.permute(0, 1, 3, 4, 2, 5)  # [B, T, Q, GP, N, 2]

    # for each point, pick one camera where it is visible (argmax over views)
    i_batch = torch.arange(B, dtype=torch.long, device=sample_points.device)
    i_query = torch.arange(Q, dtype=torch.long, device=sample_points.device)
    i_time = torch.arange(T, dtype=torch.long, device=sample_points.device)
    i_point = torch.arange(G * P, dtype=torch.long, device=sample_points.device)
    i_batch = i_batch.view(B, 1, 1, 1, 1).expand(B, T, Q, G * P, 1)
    i_time = i_time.view(1, T, 1, 1, 1).expand(B, T, Q, G * P, 1)
    i_query = i_query.view(1, 1, Q, 1, 1).expand(B, T, Q, G * P, 1)
    i_point = i_point.view(1, 1, 1, G * P, 1).expand(B, T, Q, G * P, 1)
    i_view = torch.argmax(valid_mask, dim=-1)[..., None]  # [B, T, Q, GP, 1]

    sample_points_cam = sample_points_cam[i_batch, i_time, i_query, i_point, i_view, :]  # [B, T, Q, GP, 1, 2]
    valid_mask = valid_mask[i_batch, i_time, i_query, i_point, i_view]  # [B, T, Q, GP, 1]
    # append the chosen view index (scaled to [0, 1] by /5) as a third coord
    sample_points_cam = torch.cat([sample_points_cam, i_view[..., None].float() / 5], dim=-1)

    sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 1, 3)
    sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5, 6)  # [B, T, G, Q, P, 1, 3]
    sample_points_cam = sample_points_cam.reshape(B*T*G, Q, P, 3)

    # reorder the scale weights to match the flattened (B*G*T) layout
    scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)
    scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)
    scale_weights = scale_weights.reshape(B*G*T, Q, P, -1)

    # multi-scale multi-view sampling (custom CUDA op)
    final = msmv_sampling(mlvl_feats, sample_points_cam, scale_weights)

    C = final.shape[2]  # [BTG, Q, C, P]
    final = final.reshape(B, T, G, Q, C, P)
    final = final.permute(0, 3, 2, 1, 5, 4)
    final = final.flatten(3, 4)  # [B, Q, G, FP, C]

    return final
                 num_points=4,
                 num_layers=6,
                 num_levels=4,
                 num_classes=10,
                 code_size=10,
                 pc_range=[],
                 init_cfg=None):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
            'behavior, init_cfg is not allowed to be set'
        super(SparseBEVTransformer, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.pc_range = pc_range
        self.decoder = SparseBEVTransformerDecoder(embed_dims, num_frames, num_points, num_layers, num_levels, num_classes, code_size, pc_range=pc_range)

    @torch.no_grad()
    def init_weights(self):
        self.decoder.init_weights()

    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
        cls_scores, bbox_preds = self.decoder(query_bbox, query_feat, mlvl_feats, attn_mask, img_metas)

        # replace NaNs so downstream loss/NMS never sees invalid values
        cls_scores = torch.nan_to_num(cls_scores)
        bbox_preds = torch.nan_to_num(bbox_preds)

        return cls_scores, bbox_preds


class SparseBEVTransformerDecoder(BaseModule):
    '''Iteratively refines query boxes with a single shared decoder layer.'''
    def __init__(self, embed_dims, num_frames=8, num_points=4, num_layers=6, num_levels=4, num_classes=10, code_size=10, pc_range=[], init_cfg=None):
        super(SparseBEVTransformerDecoder, self).__init__(init_cfg)
        self.num_layers = num_layers
        self.pc_range = pc_range

        # one layer instance reused for all refinement stages (weight sharing)
        self.decoder_layer = SparseBEVTransformerDecoderLayer(
            embed_dims, num_frames, num_points, num_levels, num_classes, code_size, pc_range=pc_range
        )

    @torch.no_grad()
    def init_weights(self):
        self.decoder_layer.init_weights()

    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
        '''
        query_bbox: [B, Q, code_size] initial query boxes
        Returns stacked per-stage (cls_scores, bbox_preds).
        '''
        cls_scores, bbox_preds = [], []

        # per-frame time offsets relative to the current frame, averaged over cameras
        timestamps = np.array([m['img_timestamp'] for m in img_metas], dtype=np.float64)
        timestamps = np.reshape(timestamps, [query_bbox.shape[0], -1, 6])
        time_diff = timestamps[:, :1, :] - timestamps
        time_diff = np.mean(time_diff, axis=-1).astype(np.float32)  # [B, F]
        time_diff = torch.from_numpy(time_diff).to(query_bbox.device)  # [B, F]
        img_metas[0]['time_diff'] = time_diff

        lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)
        lidar2img = torch.from_numpy(lidar2img).to(query_bbox.device)  # [B, N, 4, 4]
        img_metas[0]['lidar2img'] = lidar2img

        # group the frame/camera/group axes so features match sampling_4d's layout
        for lvl, feat in enumerate(mlvl_feats):
            B, TN, GC, H, W = feat.shape  # [B, TN, GC, H, W]
            N, T, G, C = 6, TN // 6, 4, GC // 4
            feat = feat.reshape(B, T, N, G, C, H, W)
            feat = feat.permute(0, 1, 3, 2, 5, 6, 4)  # [B, T, G, N, H, W, C]
            feat = feat.reshape(B*T*G, N, H, W, C)  # [BTG, C, N, H, W]
            mlvl_feats[lvl] = feat.contiguous()

        for i in range(self.num_layers):
            DUMP.stage_count = i

            query_feat, cls_score, bbox_pred = self.decoder_layer(
                query_bbox, query_feat, mlvl_feats, attn_mask, img_metas
            )
            # next stage starts from the refined boxes, without gradient flow
            query_bbox = bbox_pred.clone().detach()
            cls_scores.append(cls_score)
            bbox_preds.append(bbox_pred)

        cls_scores = torch.stack(cls_scores)
        bbox_preds = torch.stack(bbox_preds)

        return cls_scores, bbox_preds


class SparseBEVTransformerDecoderLayer(BaseModule):
    '''One refinement stage: self-attention -> 4D sampling -> adaptive mixing -> FFN -> heads.'''
    def __init__(self, embed_dims, num_frames=8, num_points=4, num_levels=4, num_classes=10, code_size=10, num_cls_fcs=2, num_reg_fcs=2, pc_range=[], init_cfg=None):
        super(SparseBEVTransformerDecoderLayer, self).__init__(init_cfg)

        self.embed_dims = embed_dims
        self.num_classes = num_classes
        self.code_size = code_size
        self.pc_range = pc_range

        # encodes the normalized query center (x, y, z) into the feature space
        self.position_encoder = nn.Sequential(
            nn.Linear(3, self.embed_dims),
            nn.LayerNorm(self.embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims, self.embed_dims),
            nn.LayerNorm(self.embed_dims),
            nn.ReLU(inplace=True),
        )

        self.self_attn = SparseBEVSelfAttention(embed_dims, num_heads=8, dropout=0.1, pc_range=pc_range)
        self.sampling = SparseBEVSampling(embed_dims, num_frames=num_frames, num_groups=4, num_points=num_points, num_levels=num_levels, pc_range=pc_range)
        self.mixing = AdaptiveMixing(in_dim=embed_dims, in_points=num_points * num_frames, n_groups=4, out_points=128)
        self.ffn = FFN(embed_dims, feedforward_channels=512, ffn_drop=0.1)

        self.norm1 = nn.LayerNorm(embed_dims)
        self.norm2 = nn.LayerNorm(embed_dims)
        self.norm3 = nn.LayerNorm(embed_dims)

        # classification branch: num_cls_fcs hidden layers + final logit layer
        cls_branch = []
        for _ in range(num_cls_fcs):
            cls_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
            cls_branch.append(nn.LayerNorm(self.embed_dims))
            cls_branch.append(nn.ReLU(inplace=True))
        cls_branch.append(nn.Linear(self.embed_dims, self.num_classes))
        self.cls_branch = nn.Sequential(*cls_branch)

        # regression branch predicting code_size box deltas
        reg_branch = []
        for _ in range(num_reg_fcs):
            reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.ReLU(inplace=True))
        reg_branch.append(nn.Linear(self.embed_dims, self.code_size))
        self.reg_branch = nn.Sequential(*reg_branch)

    @torch.no_grad()
    def init_weights(self):
        self.self_attn.init_weights()
        self.sampling.init_weights()
        self.mixing.init_weights()
        # focal-loss style prior so initial class probabilities are ~0.01
        bias_init = bias_init_with_prob(0.01)
        nn.init.constant_(self.cls_branch[-1].bias, bias_init)

    def refine_bbox(self, bbox_proposal, bbox_delta):
        '''Apply the predicted delta to the proposal center in logit space;
        the remaining box parameters are taken directly from the delta.'''
        xyz = inverse_sigmoid(bbox_proposal[..., 0:3])
        xyz_delta = bbox_delta[..., 0:3]
        xyz_new = torch.sigmoid(xyz_delta + xyz)

        return torch.cat([xyz_new, bbox_delta[..., 3:]], dim=-1)

    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
        """
        query_bbox: [B, Q, 10]   [cx, cy, cz, w, h, d, rot.sin, rot.cos, vx, vy]
        """
        query_pos = self.position_encoder(query_bbox[..., :3])
        query_feat = query_feat + query_pos

        query_feat = self.norm1(self.self_attn(query_bbox, query_feat, attn_mask))
        sampled_feat = self.sampling(query_bbox, query_feat, mlvl_feats, img_metas)
        query_feat = self.norm2(self.mixing(sampled_feat, query_feat))
        query_feat = self.norm3(self.ffn(query_feat))

        cls_score = self.cls_branch(query_feat)  # [B, Q, num_classes]
        bbox_pred = self.reg_branch(query_feat)  # [B, Q, code_size]
        bbox_pred = self.refine_bbox(query_bbox, bbox_pred)

        # convert the predicted displacement into velocity by dividing by the
        # time offset of the second frame (guarding near-zero offsets)
        time_diff = img_metas[0]['time_diff']  # [B, F]
        if time_diff.shape[1] > 1:
            time_diff = time_diff.clone()
            time_diff[time_diff < 1e-5] = 1.0
            bbox_pred[..., 8:] = bbox_pred[..., 8:] / time_diff[:, 1:2, None]

        if DUMP.enabled:
            query_bbox_dec = decode_bbox(query_bbox, self.pc_range)
            bbox_pred_dec = decode_bbox(bbox_pred, self.pc_range)
            cls_score_sig = torch.sigmoid(cls_score)
            torch.save(query_bbox_dec, '{}/query_bbox_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
            torch.save(bbox_pred_dec, '{}/bbox_pred_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
            torch.save(cls_score_sig, '{}/cls_score_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))

        return query_feat, cls_score, bbox_pred


class SparseBEVSelfAttention(BaseModule):
    '''Scale-adaptive self-attention: attention logits are biased by pairwise
    query-center distances scaled by a learned, per-head temperature tau.'''
    def __init__(self, embed_dims=256, num_heads=8, dropout=0.1, pc_range=[], scale_adaptive=True):
        super().__init__()
        self.pc_range = pc_range
        self.attention = MultiheadAttention(embed_dims, num_heads, dropout, batch_first=True)
        if scale_adaptive:
            self.gen_tau = nn.Linear(embed_dims, num_heads)
        else:
            self.gen_tau = None

    @torch.no_grad()
    def init_weights(self):
        if self.gen_tau is not None:
            nn.init.zeros_(self.gen_tau.weight)
            nn.init.uniform_(self.gen_tau.bias, 0.0, 2.0)

    def inner_forward(self, query_bbox, query_feat, pre_attn_mask=None):
        """
        query_bbox: [B, Q, 10]
        query_feat: [B, Q, C]
        """
        if self.gen_tau is not None:
            dist = self.calc_bbox_dists(query_bbox)
            tau = self.gen_tau(query_feat)  # [B, Q, 8]

            if DUMP.enabled:
                torch.save(tau, '{}/sasa_tau_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))

            tau = tau.permute(0, 2, 1)  # [B, 8, Q]
            attn_mask = dist[:, None, :, :] * tau[..., None]  # [B, 8, Q, Q]

            if pre_attn_mask is not None:
                # fully block attention for pre-masked query pairs
                attn_mask[:, :, pre_attn_mask] = float('-inf')

            attn_mask = attn_mask.flatten(0, 1)  # [Bx8, Q, Q]
        else:
            attn_mask = None

        return self.attention(query_feat, attn_mask=attn_mask)

    def forward(self, query_bbox, query_feat, pre_attn_mask=None):
        # use gradient checkpointing during training to save memory
        if self.training and query_feat.requires_grad:
            return cp(self.inner_forward, query_bbox, query_feat, pre_attn_mask, use_reentrant=False)
        else:
            return self.inner_forward(query_bbox, query_feat, pre_attn_mask)

    @torch.no_grad()
    def calc_bbox_dists(self, bboxes):
        '''Negative pairwise BEV (x, y) distances between decoded query centers.'''
        centers = decode_bbox(bboxes, self.pc_range)[..., :2]  # [B, Q, 2]

        dist = []
        for b in range(centers.shape[0]):
            dist_b = torch.norm(centers[b].reshape(-1, 1, 2) - centers[b].reshape(1, -1, 2), dim=-1)
            dist.append(dist_b[None, ...])

        dist = torch.cat(dist, dim=0)  # [B, Q, Q]
        dist = -dist

        return dist


class SparseBEVSampling(BaseModule):
    '''Predicts per-query sampling offsets and scale weights, then gathers
    image features for all frames via sampling_4d.'''
    def __init__(self, embed_dims=256, num_frames=4, num_groups=4, num_points=8, num_levels=4, pc_range=[], init_cfg=None):
        super().__init__(init_cfg)
        self.num_frames = num_frames
        self.num_points = num_points
        self.num_groups = num_groups
        self.num_levels = num_levels
        self.pc_range = pc_range

        self.sampling_offset = nn.Linear(embed_dims, num_groups * num_points * 3)
        self.scale_weights = nn.Linear(embed_dims, num_groups * num_points * num_levels)

    def init_weights(self):
        # zero weights + uniform bias spreads initial points inside the box
        bias = self.sampling_offset.bias.data.view(self.num_groups * self.num_points, 3)
        nn.init.zeros_(self.sampling_offset.weight)
        nn.init.uniform_(bias[:, 0:3], -0.5, 0.5)

    def inner_forward(self, query_bbox, query_feat, mlvl_feats, img_metas):
        '''
        query_bbox: [B, Q, 10]
        query_feat: [B, Q, C]
        '''
        B, Q = query_bbox.shape[:2]
        image_h, image_w, _ = img_metas[0]['img_shape'][0]

        # sampling offset of all frames
        sampling_offset = self.sampling_offset(query_feat)
        sampling_offset = sampling_offset.view(B, Q, self.num_groups * self.num_points, 3)
        sampling_points = make_sample_points_from_bbox(query_bbox, sampling_offset, self.pc_range)  # [B, Q, GP, 3]
        sampling_points = sampling_points.reshape(B, Q, 1, self.num_groups, self.num_points, 3)
        sampling_points = sampling_points.expand(B, Q, self.num_frames, self.num_groups, self.num_points, 3)

        # warp sample points based on velocity (only when the box code carries vx, vy)
        if query_bbox.shape[-1] > 8:
            time_diff = img_metas[0]['time_diff']  # [B, F]
            time_diff = time_diff[:, None, :, None]  # [B, 1, F, 1]
            vel = query_bbox[..., 8:].detach()  # [B, Q, 2]
            vel = vel[:, :, None, :]  # [B, Q, 1, 2]
            dist = vel * time_diff  # [B, Q, F, 2]
            dist = dist[:, :, :, None, None, :]  # [B, Q, F, 1, 1, 2]
            sampling_points = torch.cat([
                sampling_points[..., 0:2] - dist,
                sampling_points[..., 2:3]
            ], dim=-1)

        # scale weights
        scale_weights = self.scale_weights(query_feat).view(B, Q, self.num_groups, 1, self.num_points, self.num_levels)
        scale_weights = torch.softmax(scale_weights, dim=-1)
        scale_weights = scale_weights.expand(B, Q, self.num_groups, self.num_frames, self.num_points, self.num_levels)

        # sampling
        sampled_feats = sampling_4d(
            sampling_points,
            mlvl_feats,
            scale_weights,
            img_metas[0]['lidar2img'],
            image_h, image_w
        )  # [B, Q, G, FP, C]

        return sampled_feats

    def forward(self, query_bbox, query_feat, mlvl_feats, img_metas):
        # gradient checkpointing during training to save memory
        if self.training and query_feat.requires_grad:
            return cp(self.inner_forward, query_bbox, query_feat, mlvl_feats, img_metas, use_reentrant=False)
        else:
            return self.inner_forward(query_bbox, query_feat, mlvl_feats, img_metas)


class AdaptiveMixing(nn.Module):
    '''AdaMixer-style adaptive mixing: per-query dynamic channel mixing
    followed by dynamic point mixing, with a residual back to the query.'''
    def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):
        super(AdaptiveMixing, self).__init__()
        out_dim = out_dim if out_dim is not None else in_dim
        out_points = out_points if out_points is not None else in_points
        query_dim = query_dim if query_dim is not None else in_dim

        self.query_dim = query_dim
        self.in_dim = in_dim
        self.in_points = in_points
        self.n_groups = n_groups
        self.out_dim = out_dim
        self.out_points = out_points

        # per-group effective channel widths
        self.eff_in_dim = in_dim // n_groups
        self.eff_out_dim = out_dim // n_groups

        # dynamic parameter counts: M (channel mixing) + S (point mixing)
        self.m_parameters = self.eff_in_dim * self.eff_out_dim
        self.s_parameters = self.in_points * self.out_points
        self.total_parameters = self.m_parameters + self.s_parameters

        self.parameter_generator = nn.Linear(self.query_dim, self.n_groups * self.total_parameters)
        self.out_proj = nn.Linear(self.eff_out_dim * self.out_points * self.n_groups, self.query_dim)
        self.act = nn.ReLU(inplace=True)

    @torch.no_grad()
    def init_weights(self):
        nn.init.zeros_(self.parameter_generator.weight)

    def inner_forward(self, x, query):
        B, Q, G, P, C = x.shape
        assert G == self.n_groups
        assert P == self.in_points
        assert C == self.eff_in_dim

        '''generate mixing parameters'''
        params = self.parameter_generator(query)
        params = params.reshape(B*Q, G, -1)
        out = x.reshape(B*Q, G, P, C)

        M, S = params.split([self.m_parameters, self.s_parameters], 2)
        M = M.reshape(B*Q, G, self.eff_in_dim, self.eff_out_dim)
        S = S.reshape(B*Q, G, self.out_points, self.in_points)

        '''adaptive channel mixing'''
        out = torch.matmul(out, M)
        out = F.layer_norm(out, [out.size(-2), out.size(-1)])
        out = self.act(out)

        '''adaptive point mixing'''
        out = torch.matmul(S, out)  # implicitly transpose and matmul
        out = F.layer_norm(out, [out.size(-2), out.size(-1)])
        out = self.act(out)

        '''linear transfomation to query dim'''
        out = out.reshape(B, Q, -1)
        out = self.out_proj(out)
        out = query + out

        return out

    def forward(self, x, query):
        # gradient checkpointing during training to save memory
        if self.training and x.requires_grad:
            return cp(self.inner_forward, x, query, use_reentrant=False)
        else:
            return self.inner_forward(x, query)


class AdaptiveMixingPointOnly(nn.Module):
    '''Variant of AdaptiveMixing that performs only the point (spatial)
    mixing step — no dynamic channel mixing.'''
    def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):
        super(AdaptiveMixingPointOnly, self).__init__()
        out_dim = out_dim if out_dim is not None else in_dim
        out_points = out_points if out_points is not None else in_points
        query_dim = query_dim if query_dim is not None else in_dim

        self.query_dim = query_dim
        self.in_dim = in_dim
        self.in_points = in_points
        self.n_groups = n_groups
        self.out_dim = out_dim
        self.out_points = out_points

        self.eff_in_dim = in_dim // n_groups
        self.eff_out_dim = out_dim // n_groups

        # only spatial mixing parameters are generated
        self.s_parameters = self.in_points * self.out_points
        self.total_parameters = self.s_parameters

        self.parameter_generator = nn.Linear(self.query_dim, self.n_groups * self.total_parameters)
        self.out_proj = nn.Linear(self.eff_out_dim * self.out_points * self.n_groups, self.query_dim)
        self.act = nn.ReLU(inplace=True)

    @torch.no_grad()
    def init_weights(self):
        nn.init.zeros_(self.parameter_generator.weight)

    def inner_forward(self, x, query):
        B, Q, G, P, C = x.shape
        assert G == self.n_groups
        assert P == self.in_points
        assert C == self.eff_in_dim

        '''generate mixing parameters'''
        params = self.parameter_generator(query)
        params = params.reshape(B*Q, G, -1)
        out = x.reshape(B*Q, G, P, C)

        S = params.reshape(B*Q, G, self.out_points, self.in_points)

        '''adaptive spatial mixing'''
        out = torch.matmul(S, out)  # implicitly transpose and matmul
        out = F.layer_norm(out, [out.size(-2), out.size(-1)])
        out = self.act(out)

        '''linear transfomation to query dim'''
        out = out.reshape(B, Q, -1)
        out = self.out_proj(out)
        out = query + out

        return out

    def forward(self, x, query):
        # gradient checkpointing during training to save memory
        if self.training and x.requires_grad:
            return cp(self.inner_forward, x, query, use_reentrant=False)
        else:
            return self.inner_forward(x, query)


class DeformAggregation(nn.Module):
    '''Aggregates sampled point features with predicted softmax attention
    weights (deformable-attention style), plus a residual to the query.'''
    def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):
        super(DeformAggregation, self).__init__()
        out_dim = out_dim if out_dim is not None else in_dim
        out_points = out_points if out_points is not None else in_points
        query_dim = query_dim if query_dim is not None else in_dim

        self.query_dim = query_dim
        self.in_dim = in_dim
        self.in_points = in_points
        self.n_groups = n_groups
        self.out_dim = out_dim
        self.out_points = out_points

        self.eff_in_dim = in_dim // n_groups
        self.eff_out_dim = out_dim // n_groups

        self.attn_weights = nn.Linear(query_dim, n_groups * in_points)
        self.out_proj = nn.Linear(self.eff_in_dim * n_groups, self.query_dim)
        self.act = nn.ReLU(inplace=True)

    @torch.no_grad()
    def init_weights(self):
        pass

    def inner_forward(self, x, query):
        B, Q, G, P, C = x.shape
        assert G == self.n_groups
        assert P == self.in_points
        assert C == self.eff_in_dim

        out = x.reshape(B, Q, G, P, C)
        attn_weights = self.attn_weights(query)  # [B, Q, GP]
        attn_weights = attn_weights.reshape(B, Q, self.n_groups, self.in_points, 1)  # [B, Q, G, P, 1]
        # softmax over the point axis
        attn_weights = attn_weights.softmax(dim=-2)

        out = torch.sum(out * attn_weights, dim=-2)  # [B, Q, G, C]
        out = out.reshape(B, Q, -1)
        out = self.out_proj(out)
        out = query + out

        return out

    def forward(self, x, query):
        # gradient checkpointing during training to save memory
        if self.training and x.requires_grad:
            return cp(self.inner_forward, x, query, use_reentrant=False)
        else:
            return self.inner_forward(x, query)
================================================ FILE: models/sparseocc.py ================================================
import torch
import queue
import numpy as np
from mmcv.runner import get_dist_info
from mmcv.runner.fp16_utils import cast_tensor_type
from mmcv.runner import force_fp32, auto_fp16
from mmdet.models import DETECTORS
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from .utils import pad_multiple, GpuPhotoMetricDistortion


@DETECTORS.register_module()
class SparseOcc(MVXTwoStageDetector):
    '''Camera-only occupancy detector: image backbone/neck plus a sparse
    occupancy head, with an optional GPU-side data-augmentation pipeline.'''
    def __init__(self,
                 pts_voxel_layer=None,
                 pts_voxel_encoder=None,
                 pts_middle_encoder=None,
                 pts_fusion_layer=None,
                 img_backbone=None,
                 pts_backbone=None,
                 img_neck=None,
                 pts_neck=None,
                 pts_bbox_head=None,
                 img_roi_head=None,
                 img_rpn_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 data_aug=None,
                 use_mask_camera=False,
                 **kwargs):
        super(SparseOcc, self).__init__(pts_voxel_layer, pts_voxel_encoder, pts_middle_encoder, pts_fusion_layer, img_backbone, pts_backbone, img_neck, pts_neck, pts_bbox_head, img_roi_head, img_rpn_head, train_cfg, test_cfg, pretrained)
        self.use_mask_camera = use_mask_camera
        self.fp16_enabled = False
        self.data_aug = data_aug
        self.color_aug = GpuPhotoMetricDistortion()
        # feature cache for the online test path, bounded by a FIFO queue
        self.memory = {}
        self.queue = queue.Queue()

    @auto_fp16(apply_to=('img'), out_fp32=True)
    def extract_img_feat(self, img):
        '''Run backbone (and neck, if configured) on a flat image batch.'''
        img_feats = self.img_backbone(img)

        if isinstance(img_feats, dict):
            img_feats = list(img_feats.values())

        if self.with_img_neck:
            img_feats = self.img_neck(img_feats)

        return img_feats

    @auto_fp16(apply_to=('img'))
    def extract_feat(self, img, img_metas=None):
        """Extract features from images and points."""
        if len(img.shape) == 6:
            img = img.flatten(1, 2)  # [B, TN, C, H, W]

        B, N, C, H, W = img.size()
        img = img.view(B * N, C, H, W)
        img = img.float()

        # GPU-side augmentation / normalization, configured via self.data_aug
        if self.data_aug is not None:
            if 'img_color_aug' in self.data_aug and self.data_aug['img_color_aug'] and self.training:
                img = self.color_aug(img)

            if 'img_norm_cfg' in self.data_aug:
                img_norm_cfg = self.data_aug['img_norm_cfg']

                norm_mean = torch.tensor(img_norm_cfg['mean'], device=img.device)
                norm_std = torch.tensor(img_norm_cfg['std'], device=img.device)

                if img_norm_cfg['to_rgb']:
                    img = img[:, [2, 1, 0], :, :]  # BGR to RGB

                img = img - norm_mean.reshape(1, 3, 1, 1)
                img = img / norm_std.reshape(1, 3, 1, 1)

            for b in range(B):
                img_shape = (img.shape[2], img.shape[3], img.shape[1])
                img_metas[b]['img_shape'] = [img_shape for _ in range(N)]
                img_metas[b]['ori_shape'] = [img_shape for _ in range(N)]

            if 'img_pad_cfg' in self.data_aug:
                img_pad_cfg = self.data_aug['img_pad_cfg']
                img = pad_multiple(img, img_metas, size_divisor=img_pad_cfg['size_divisor'])
                H, W = img.shape[-2:]

        input_shape = img.shape[-2:]
        # update real input shape of each single img
        for img_meta in img_metas:
            img_meta.update(input_shape=input_shape)

        img_feats = self.extract_img_feat(img)

        # restore the camera dimension: [B*N, C, H, W] -> [B, N, C, H, W]
        img_feats_reshaped = []
        for img_feat in img_feats:
            BN, C, H, W = img_feat.size()
            img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))

        return img_feats_reshaped

    def forward_pts_train(self, mlvl_feats, voxel_semantics, voxel_instances, instance_class_ids, mask_camera, img_metas):
        """
        voxel_semantics: [bs, 200, 200, 16], value in range [0, num_cls - 1]
        voxel_instances: [bs, 200, 200, 16], value in range [0, num_obj - 1]
        instance_class_ids: [[bs0_num_obj], [bs1_num_obj], ...], value in range [0, num_cls - 1]
        """
        outs = self.pts_bbox_head(mlvl_feats, img_metas)
        loss_inputs = [voxel_semantics, voxel_instances, instance_class_ids, outs]
        return self.pts_bbox_head.loss(*loss_inputs)

    def forward(self, return_loss=True, **kwargs):
        # mmdet-style dispatch between the train and test entry points
        if return_loss:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    @force_fp32(apply_to=('img'))
    def forward_train(self, img_metas=None, img=None, voxel_semantics=None, voxel_instances=None, instance_class_ids=None, mask_camera=None, **kwargs):
        img_feats = self.extract_feat(img=img, img_metas=img_metas)
        return self.forward_pts_train(img_feats, voxel_semantics,
voxel_instances, instance_class_ids, mask_camera, img_metas) def forward_test(self, img_metas, img=None, **kwargs): output = self.simple_test(img_metas, img) sem_pred = output['sem_pred'].cpu().numpy().astype(np.uint8) occ_loc = output['occ_loc'].cpu().numpy().astype(np.uint8) batch_size = sem_pred.shape[0] if 'pano_inst' and 'pano_sem' in output: # important: uint8 is not enough for pano_pred pano_inst = output['pano_inst'].cpu().numpy().astype(np.int16) pano_sem = output['pano_sem'].cpu().numpy().astype(np.uint8) return [{ 'sem_pred': sem_pred[b:b+1], 'pano_inst': pano_inst[b:b+1], 'pano_sem': pano_sem[b:b+1], 'occ_loc': occ_loc[b:b+1] } for b in range(batch_size)] else: return [{ 'sem_pred': sem_pred[b:b+1], 'occ_loc': occ_loc[b:b+1] } for b in range(batch_size)] def simple_test_pts(self, x, img_metas, rescale=False): outs = self.pts_bbox_head(x, img_metas) outs = self.pts_bbox_head.merge_occ_pred(outs) return outs def simple_test(self, img_metas, img=None, rescale=False): world_size = get_dist_info()[1] if world_size == 1: # online return self.simple_test_online(img_metas, img, rescale) else: # offline return self.simple_test_offline(img_metas, img, rescale) def simple_test_offline(self, img_metas, img=None, rescale=False): img_feats = self.extract_feat(img=img, img_metas=img_metas) return self.simple_test_pts(img_feats, img_metas, rescale=rescale) def simple_test_online(self, img_metas, img=None, rescale=False): self.fp16_enabled = False assert len(img_metas) == 1 # batch_size = 1 B, N, C, H, W = img.shape img = img.reshape(B, N//6, 6, C, H, W) img_filenames = img_metas[0]['filename'] num_frames = len(img_filenames) // 6 # assert num_frames == img.shape[1] img_shape = (H, W, C) img_metas[0]['img_shape'] = [img_shape for _ in range(len(img_filenames))] img_metas[0]['ori_shape'] = [img_shape for _ in range(len(img_filenames))] img_metas[0]['pad_shape'] = [img_shape for _ in range(len(img_filenames))] img_feats_list, img_metas_list = [], [] # extract feature 
frame by frame for i in range(num_frames): img_indices = list(np.arange(i * 6, (i + 1) * 6)) img_metas_curr = [{}] for k in img_metas[0].keys(): if isinstance(img_metas[0][k], list): img_metas_curr[0][k] = [img_metas[0][k][i] for i in img_indices] if img_filenames[img_indices[0]] in self.memory: # found in memory img_feats_curr = self.memory[img_filenames[img_indices[0]]] else: # extract feature and put into memory img_feats_curr = self.extract_feat(img[:, i], img_metas_curr) self.memory[img_filenames[img_indices[0]]] = img_feats_curr self.queue.put(img_filenames[img_indices[0]]) while self.queue.qsize() > 16: # avoid OOM pop_key = self.queue.get() self.memory.pop(pop_key) img_feats_list.append(img_feats_curr) img_metas_list.append(img_metas_curr) # reorganize feat_levels = len(img_feats_list[0]) img_feats_reorganized = [] for j in range(feat_levels): feat_l = torch.cat([img_feats_list[i][j] for i in range(len(img_feats_list))], dim=0) feat_l = feat_l.flatten(0, 1)[None, ...] img_feats_reorganized.append(feat_l) img_metas_reorganized = img_metas_list[0] for i in range(1, len(img_metas_list)): for k, v in img_metas_list[i][0].items(): if isinstance(v, list): img_metas_reorganized[0][k].extend(v) img_feats = img_feats_reorganized img_metas = img_metas_reorganized img_feats = cast_tensor_type(img_feats, torch.half, torch.float32) # run detector return self.simple_test_pts(img_feats, img_metas, rescale=rescale) ================================================ FILE: models/sparseocc_head.py ================================================ import numpy as np import torch import torch.nn as nn from mmdet.models import HEADS from mmcv.runner import force_fp32, auto_fp16 from mmdet.models.builder import build_loss from mmdet.models.utils import build_transformer from .matcher import HungarianMatcher from .loss_utils import CE_ssc_loss, lovasz_softmax, get_voxel_decoder_loss_input NUSC_CLASS_FREQ = np.array([ 944004, 1897170, 152386, 2391677, 16957802, 724139, 189027, 
    2074468, 413451, 2384460, 5916653, 175883646, 4275424,
    51393615, 61411620, 105975596, 116424404, 1892500630
])


@HEADS.register_module()
class SparseOccHead(nn.Module):
    '''Occupancy head: sparse voxel decoder + MaskFormer-style mask decoder,
    with semantic losses per decoder stage and Mask2Former matching losses.'''
    def __init__(self,
                 transformer=None,
                 class_names=None,
                 embed_dims=None,
                 occ_size=None,
                 pc_range=None,
                 loss_cfgs=None,
                 panoptic=False,
                 **kwargs):
        super(SparseOccHead, self).__init__()
        self.num_classes = len(class_names)
        self.class_names = class_names
        self.pc_range = pc_range
        self.occ_size = occ_size
        self.embed_dims = embed_dims
        # thresholds used by the panoptic merging (MaskFormer-style)
        self.score_threshold = 0.3
        self.overlap_threshold = 0.8
        self.panoptic = panoptic
        self.transformer = build_transformer(transformer)
        self.criterions = {k: build_loss(loss_cfg) for k, loss_cfg in loss_cfgs.items()}
        self.matcher = HungarianMatcher(cost_class=2.0, cost_mask=5.0, cost_dice=5.0)
        # rarer classes get larger weights (inverse log frequency)
        self.class_weights = torch.from_numpy(1 / np.log(NUSC_CLASS_FREQ + 0.001))

    def init_weights(self):
        self.transformer.init_weights()

    @auto_fp16(apply_to=('mlvl_feats'))
    def forward(self, mlvl_feats, img_metas):
        occ_preds, mask_preds, class_preds = self.transformer(mlvl_feats, img_metas=img_metas)
        return {
            'occ_preds': occ_preds,
            'mask_preds': mask_preds,
            'class_preds': class_preds
        }

    @force_fp32(apply_to=('preds_dicts'))
    def loss(self, voxel_semantics, voxel_instances, instance_class_ids, preds_dicts, mask_camera=None):
        return self.loss_single(voxel_semantics, voxel_instances, instance_class_ids, preds_dicts, mask_camera)

    def loss_single(self, voxel_semantics, voxel_instances, instance_class_ids, preds_dicts, mask_camera=None):
        '''Compute semantic losses for every voxel-decoder stage and
        mask/dice/class losses for every mask-decoder stage.'''
        loss_dict = {}
        B = voxel_instances.shape[0]

        if mask_camera is not None:
            assert mask_camera.shape == voxel_semantics.shape
            assert mask_camera.dtype == torch.bool

        # semantic losses, one set per voxel-decoder stage, averaged over batch
        for i, (occ_loc_i, _, seg_pred_i, _, scale) in enumerate(preds_dicts['occ_preds']):
            loss_dict_i = {}
            for b in range(B):
                loss_dict_i_b = {}
                seg_pred_i_sparse, voxel_semantics_sparse, sparse_mask = get_voxel_decoder_loss_input(
                    voxel_semantics[b:b + 1], occ_loc_i[b:b + 1],
                    seg_pred_i[b:b + 1] if seg_pred_i is not None else None,
                    scale, self.num_classes
                )
                loss_dict_i_b['loss_sem_lovasz'] = lovasz_softmax(torch.softmax(seg_pred_i_sparse, dim=1), voxel_semantics_sparse)

                # drop ignore-label voxels (255) before the CE / scal losses
                valid_mask = (voxel_semantics_sparse < 255)
                seg_pred_i_sparse = seg_pred_i_sparse[valid_mask].transpose(0, 1).unsqueeze(0)  # [K, CLS] -> [B, CLS, K]
                voxel_semantics_sparse = voxel_semantics_sparse[valid_mask].unsqueeze(0)  # [K] -> [B, K]

                if 'loss_geo_scal' in self.criterions.keys():
                    loss_dict_i_b['loss_geo_scal'] = self.criterions['loss_geo_scal'](seg_pred_i_sparse, voxel_semantics_sparse)
                if 'loss_sem_scal' in self.criterions.keys():
                    loss_dict_i_b['loss_sem_scal'] = self.criterions['loss_sem_scal'](seg_pred_i_sparse, voxel_semantics_sparse)
                loss_dict_i_b['loss_sem_ce'] = CE_ssc_loss(seg_pred_i_sparse, voxel_semantics_sparse, self.class_weights.type_as(seg_pred_i_sparse))

                for loss_key in loss_dict_i_b.keys():
                    loss_dict_i[loss_key] = loss_dict_i.get(loss_key, 0) + loss_dict_i_b[loss_key] / B

            for k, v in loss_dict_i.items():
                loss_dict['%s_%d' % (k, i)] = v

        # gather per-voxel instance ids at the final stage's sparse locations
        occ_loc = preds_dicts['occ_preds'][-1][0]
        batch_idx = torch.arange(B)[:, None, None].expand(B, occ_loc.shape[1], 1).to(occ_loc.device)
        occ_loc = occ_loc.reshape(-1, 3)
        voxel_instances = voxel_instances[batch_idx.reshape(-1), occ_loc[..., 0], occ_loc[..., 1], occ_loc[..., 2]]
        voxel_instances = voxel_instances.reshape(B, -1)  # [B, N]

        if mask_camera is not None:
            mask_camera = mask_camera[batch_idx.reshape(-1), occ_loc[..., 0], occ_loc[..., 1], occ_loc[..., 2]]
            mask_camera = mask_camera.reshape(B, -1)  # [B, N]

        # drop instances if it has no positive voxels, remapping ids compactly
        for b in range(B):
            instance_count = instance_class_ids[b].shape[0]
            instance_voxel_counts = torch.bincount(voxel_instances[b].long())  # [255]
            id_map = torch.cumsum(instance_voxel_counts > 0, dim=0) - 1
            id_map[255] = 255  # empty space still has an id of 255
            voxel_instances[b] = id_map[voxel_instances[b].long()]
            instance_class_ids[b] = instance_class_ids[b][instance_voxel_counts[:instance_count] > 0]

        # Mask2Former losses: Hungarian matching per mask-decoder stage
        for i, pred in enumerate(preds_dicts['mask_preds']):
            indices = self.matcher(pred, preds_dicts['class_preds'][i], voxel_instances, instance_class_ids, mask_camera)
            loss_mask, loss_dice, loss_class = self.criterions['loss_mask2former'](
                pred, preds_dicts['class_preds'][i], voxel_instances, instance_class_ids, indices, mask_camera)
            loss_dict['loss_mask_{:d}'.format(i)] = loss_mask
            loss_dict['loss_dice_mask_{:d}'.format(i)] = loss_dice
            loss_dict['loss_class_{:d}'.format(i)] = loss_class

        return loss_dict

    def merge_occ_pred(self, outs):
        '''Convert the final-stage raw outputs into semantic (and optionally
        panoptic) predictions attached to the output dict.'''
        mask_cls = outs['class_preds'][-1].sigmoid()
        mask_pred = outs['mask_preds'][-1].sigmoid()
        occ_indices = outs['occ_preds'][-1][0]

        sem_pred = self.merge_semseg(mask_cls, mask_pred)  # [B, C, N]
        outs['sem_pred'] = sem_pred
        outs['occ_loc'] = occ_indices

        if self.panoptic:
            pano_inst, pano_sem = self.merge_panoseg(mask_cls, mask_pred)  # [B, C, N]
            outs['pano_inst'] = pano_inst
            outs['pano_sem'] = pano_sem

        return outs

    # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/mask_former_model.py#L242
    def merge_semseg(self, mask_cls, mask_pred):
        '''Weighted vote of query masks into per-voxel class scores.'''
        valid_mask = mask_cls.max(dim=-1).values > self.score_threshold
        mask_cls[~valid_mask] = 0.0
        semseg = torch.einsum("bqc,bqn->bcn", mask_cls, mask_pred)
        if semseg.shape[1] == self.num_classes:
            semseg = semseg[:, :-1]  # drop the background/free channel
        cls_score, cls_id = torch.max(semseg, dim=1)
        cls_id[cls_score < 0.01] = self.num_classes - 1  # low score -> free space
        return cls_id  # [B, N]

    def merge_panoseg(self, mask_cls, mask_pred):
        '''Per-sample wrapper around merge_panoseg_single.'''
        pano_inst, pano_sem = [], []
        for b in range(mask_cls.shape[0]):
            pano_inst_b, pano_sem_b = self.merge_panoseg_single(
                mask_cls[b:b + 1], mask_pred[b:b + 1]
            )
            pano_inst.append(pano_inst_b)
            pano_sem.append(pano_sem_b)
        pano_inst = torch.cat(pano_inst, dim=0)
        pano_sem = torch.cat(pano_sem, dim=0)
        return pano_inst, pano_sem

    # https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/maskformer_model.py#L286
    def merge_panoseg_single(self, mask_cls, mask_pred):
        '''Mask2Former-style panoptic fusion for one sample: keep confident
        foreground queries, assign voxels by argmax, merge stuff segments.'''
        assert mask_cls.shape[0] == 1, "bs != 1"
        scores, labels = mask_cls.max(-1)

        # filter out low score and background instances
        keep = labels.ne(self.num_classes - 1) & (scores > self.score_threshold)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]
        cur_prob_masks = cur_scores.view(-1, 1) * cur_masks

        N = cur_masks.shape[-1]
        instance_seg = torch.zeros((N), dtype=torch.int32, device=cur_masks.device)
        semantic_seg = torch.ones((N), dtype=torch.int32, device=cur_masks.device) * (self.num_classes - 1)

        current_segment_id = 0
        stuff_memory_list = {self.num_classes - 1: 0}

        # skip all process if no mask is detected
        if cur_masks.shape[0] != 0:
            # take argmax: each voxel goes to its highest-probability query
            cur_mask_ids = cur_prob_masks.argmax(0)  # [N]
            for k in range(cur_classes.shape[0]):
                pred_class = cur_classes[k].item()
                # moving objects are treated as instances
                is_thing = self.class_names[pred_class] in [
                    'car', 'truck', 'construction_vehicle', 'bus', 'trailer',
                    'motorcycle', 'bicycle', 'pedestrian'
                ]
                mask_area = (cur_mask_ids == k).sum().item()
                original_area = (cur_masks[k] >= 0.5).sum().item()
                mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)

                if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
                    # drop heavily-occluded fragments
                    if mask_area / original_area < self.overlap_threshold:
                        continue
                    # merge stuff regions into a single segment per class
                    if not is_thing:
                        if int(pred_class) in stuff_memory_list.keys():
                            instance_seg[mask] = stuff_memory_list[int(pred_class)]
                            continue
                        else:
                            stuff_memory_list[int(pred_class)] = current_segment_id + 1

                    current_segment_id += 1
                    instance_seg[mask] = current_segment_id
                    semantic_seg[mask] = pred_class

        instance_seg = instance_seg.unsqueeze(0)
        semantic_seg = semantic_seg.unsqueeze(0)
        return instance_seg, semantic_seg  # [B, N]


================================================ FILE: models/sparseocc_transformer.py ================================================
import copy
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
from mmcv.cnn.bricks.transformer import FFN
from .sparsebev_transformer import \
@TRANSFORMER.register_module()
class SparseOccTransformer(BaseModule):
    """Top-level SparseOcc transformer.

    Runs a sparse voxel decoder over multi-camera, multi-frame image
    features, then a MaskFormer-style query decoder that predicts
    per-query masks and classes over the sparse voxel set.
    """

    def __init__(self,
                 embed_dims=None,
                 num_layers=None,
                 num_queries=None,
                 num_frames=None,
                 num_points=None,
                 num_groups=None,
                 num_levels=None,
                 num_classes=None,
                 pc_range=None,
                 occ_size=None,
                 topk_training=None,
                 topk_testing=None):
        super().__init__()
        self.num_frames = num_frames
        # coarse-to-fine sparse voxel decoder (3 levels, semantic head enabled)
        self.voxel_decoder = SparseVoxelDecoder(
            embed_dims=embed_dims,
            num_layers=3,
            num_frames=num_frames,
            num_points=num_points,
            num_groups=num_groups,
            num_levels=num_levels,
            num_classes=num_classes,
            pc_range=pc_range,
            semantic=True,
            topk_training=topk_training,
            topk_testing=topk_testing
        )
        # MaskFormer-style mask/class decoder on top of the voxel features
        self.decoder = MaskFormerOccDecoder(
            embed_dims=embed_dims,
            num_layers=num_layers,
            num_frames=num_frames,
            num_queries=num_queries,
            num_points=num_points,
            num_groups=num_groups,
            num_levels=num_levels,
            num_classes=num_classes,
            pc_range=pc_range,
            occ_size=occ_size,
        )

    @torch.no_grad()
    def init_weights(self):
        """Delegate weight init to the two sub-decoders."""
        self.voxel_decoder.init_weights()
        self.decoder.init_weights()

    def forward(self, mlvl_feats, img_metas):
        """
        mlvl_feats: list of feature maps, each [B, TN, GC, H, W] where
            TN = num_frames * 6 cameras and GC = 4 groups * C
            (camera/group counts are hard-coded below).
        img_metas: per-sample meta dicts carrying 'lidar2img' / 'ego2lidar'.

        Returns:
            (occ_preds, mask_preds, class_preds)
        """
        for lvl, feat in enumerate(mlvl_feats):
            B, TN, GC, H, W = feat.shape  # [B, TN, GC, H, W]
            # 6 cameras and 4 sampling groups are fixed constants here
            N, T, G, C = 6, TN // 6, 4, GC // 4
            feat = feat.reshape(B, T, N, G, C, H, W)
            feat = feat.permute(0, 1, 3, 2, 5, 6, 4)  # [B, T, G, N, H, W, C]
            feat = feat.reshape(B*T*G, N, H, W, C)  # [BTG, N, H, W, C]
            mlvl_feats[lvl] = feat.contiguous()

        lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)
        lidar2img = torch.from_numpy(lidar2img).to(feat.device)  # [B, N, 4, 4]
        ego2lidar = np.asarray([m['ego2lidar'] for m in img_metas]).astype(np.float32)
        ego2lidar = torch.from_numpy(ego2lidar).to(feat.device)  # [B, N, 4, 4]

        # fold the ego->lidar transform into the projection matrix;
        # NOTE(review): only img_metas[0] is rewritten — presumably bs == 1
        # at this point; confirm against the caller.
        img_metas = copy.deepcopy(img_metas)
        img_metas[0]['lidar2img'] = torch.matmul(lidar2img, ego2lidar)

        occ_preds = self.voxel_decoder(mlvl_feats, img_metas=img_metas)
        mask_preds, class_preds = self.decoder(occ_preds, mlvl_feats, img_metas)

        return occ_preds, mask_preds, class_preds
def forward(self, query_feat, valid_map, mask_pred, occ_preds, mlvl_feats, query_pos, img_metas):
    """One MaskFormer decoder layer: self-attn -> sampling -> mixing -> FFN -> seg heads.

    query_feat: [bs, num_query, embed_dim]
    valid_map: [bs, num_query, num_voxel]
    mask_pred: [bs, num_query, num_voxel]
    occ_preds: list(occ_loc, occ_pred, _, mask_feat, scale), all voxel decoder's outputs
        mask_feat: [bs, num_voxel, embed_dim]
        occ_pred: [bs, num_voxel]
        occ_loc: [bs, num_voxel, 3]

    Returns updated (query_feat, valid_map, mask_pred, class_pred).
    """
    # only the finest (last) voxel-decoder level is used here
    occ_loc, occ_pred, _, mask_feat, _ = occ_preds[-1]

    # query self-attention with additive positional embedding, post-norm
    query_feat = self.norm1(self.self_attn(query_feat, query_pos=query_pos))
    # sample image features at points derived from the current valid map
    sampled_feat = self.sampling(query_feat, valid_map, occ_loc, mlvl_feats, img_metas)
    # adaptive mixing of the sampled features back into the queries
    query_feat = self.norm2(self.mixing(sampled_feat, query_feat))
    query_feat = self.norm3(self.ffn(query_feat))

    # refresh mask/class predictions from the updated queries
    valid_map, mask_pred, class_pred = self.pred_segmentation(query_feat, mask_feat)

    return query_feat, valid_map, mask_pred, class_pred
class MaskFormerSelfAttention(BaseModule):
    """Residual multi-head self-attention over the query set.

    Positional embeddings are added to Q/K only; values stay un-embedded.
    Gradient checkpointing is used during training to save memory.
    """

    def __init__(self, embed_dims, num_heads, dropout=0.0):
        super().__init__()
        # keep the construction order so parameter init under a fixed seed
        # is unchanged: attention -> dropout -> activation
        self.self_attn = nn.MultiheadAttention(embed_dims, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU(inplace=True)

    def init_weights(self):
        # Xavier-init every weight matrix; 1-d params (biases) keep defaults
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def with_pos_embed(self, tensor, pos=None):
        # add the positional embedding only when one is supplied
        if pos is None:
            return tensor
        return tensor + pos

    def inner_forward(self, query, mask=None, key_padding_mask=None, query_pos=None):
        qk = self.with_pos_embed(query, query_pos)
        attn_out, _ = self.self_attn(qk, qk, value=query, attn_mask=mask,
                                     key_padding_mask=key_padding_mask)
        return query + self.dropout(attn_out)

    def forward(self, query, mask=None, key_padding_mask=None, query_pos=None):
        use_checkpoint = self.training and query.requires_grad
        if use_checkpoint:
            return cp(self.inner_forward, query, mask, key_padding_mask, query_pos, use_reentrant=False)
        return self.inner_forward(query, mask, key_padding_mask, query_pos)
def sparse2dense(indices, value, dense_shape, empty_value=0):
    """Scatter per-voxel values into a dense grid.

    Args:
        indices (Tensor): [B, N, 3] integer voxel coordinates.
        value (Tensor): [B, N] (or [B, N, C] with matching trailing dims in
            `dense_shape`) values to scatter.
        dense_shape (list): target spatial shape, first three entries are the
            voxel grid dims.
        empty_value: fill value for voxels not covered by `indices`.

    Returns:
        dense (Tensor): [B, *dense_shape] scattered values.
        mask (Tensor): [B, *dense_shape[:3]] bool, True where a value was written.
    """
    B, N = indices.shape[:2]  # [B, N, 3]
    # BUG FIX: the batch index must live on the same device as `value`;
    # a CPU arange cannot be used to advanced-index a CUDA tensor.
    batch_index = torch.arange(B, device=value.device).unsqueeze(1).expand(B, N)

    # torch.full replaces the `ones * empty_value` idiom (one alloc, no mul)
    dense = torch.full([B] + dense_shape, empty_value, device=value.device, dtype=value.dtype)
    dense[batch_index, indices[..., 0], indices[..., 1], indices[..., 2]] = value

    mask = torch.zeros([B] + dense_shape[:3], dtype=torch.bool, device=value.device)
    mask[batch_index, indices[..., 0], indices[..., 1], indices[..., 2]] = True

    return dense, mask
def rotation_3d_in_axis(points, angles):
    """Rotate point sets around the z axis.

    Args:
        points: [..., n_points, 3] coordinates.
        angles: [..., 1] yaw angles in radians.

    Returns:
        Rotated points with the same shape as the input.
    """
    assert points.shape[-1] == 3
    assert angles.shape[-1] == 1

    yaw = angles[..., 0]
    n_pts = points.shape[-2]
    lead_shape = yaw.shape
    batched = len(lead_shape) > 1
    if batched:
        # flatten the leading dims so bmm sees plain [B, n, 3] / [B] inputs
        points = points.reshape(-1, n_pts, 3)
        yaw = yaw.reshape(-1)

    c, s = torch.cos(yaw), torch.sin(yaw)
    one, zero = torch.ones_like(c), torch.zeros_like(c)
    # transposed z-rotation matrix, applied as points @ R^T
    rot_t = torch.stack([
        torch.stack([c, s, zero], dim=-1),
        torch.stack([-s, c, zero], dim=-1),
        torch.stack([zero, zero, one], dim=-1),
    ], dim=-2)
    points = torch.bmm(points, rot_t)

    if batched:
        points = points.reshape(*lead_shape, n_pts, 3)
    return points
""" x = x.clamp(min=0, max=1) x1 = x.clamp(min=eps) x2 = (1 - x).clamp(min=eps) return torch.log(x1 / x2) def pad_multiple(inputs, img_metas, size_divisor=32): _, _, img_h, img_w = inputs.shape pad_h = 0 if img_h % size_divisor == 0 else size_divisor - (img_h % size_divisor) pad_w = 0 if img_w % size_divisor == 0 else size_divisor - (img_w % size_divisor) B = len(img_metas) N = len(img_metas[0]['ori_shape']) for b in range(B): img_metas[b]['img_shape'] = [(img_h + pad_h, img_w + pad_w, 3) for _ in range(N)] img_metas[b]['pad_shape'] = [(img_h + pad_h, img_w + pad_w, 3) for _ in range(N)] if pad_h == 0 and pad_w == 0: return inputs else: return F.pad(inputs, [0, pad_w, 0, pad_h], value=0) def rgb_to_hsv(image: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: r"""Convert an image from RGB to HSV. .. image:: _static/img/rgb_to_hsv.png The image data is assumed to be in the range of (0, 1). Args: image: RGB Image to be converted to HSV with shape of :math:`(*, 3, H, W)`. eps: scalar to enforce numarical stability. Returns: HSV version of the image with shape of :math:`(*, 3, H, W)`. The H channel values are in the range 0..2pi. S and V are in the range 0..1. .. note:: See a working example `here `__. Example: >>> input = torch.rand(2, 3, 4, 5) >>> output = rgb_to_hsv(input) # 2x3x4x5 """ if not isinstance(image, torch.Tensor): raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}") if len(image.shape) < 3 or image.shape[-3] != 3: raise ValueError(f"Input size must have a shape of (*, 3, H, W). 
Got {image.shape}") image = image / 255.0 max_rgb, argmax_rgb = image.max(-3) min_rgb, argmin_rgb = image.min(-3) deltac = max_rgb - min_rgb v = max_rgb s = deltac / (max_rgb + eps) deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac) rc, gc, bc = torch.unbind((max_rgb.unsqueeze(-3) - image), dim=-3) h1 = bc - gc h2 = (rc - bc) + 2.0 * deltac h3 = (gc - rc) + 4.0 * deltac h = torch.stack((h1, h2, h3), dim=-3) / deltac.unsqueeze(-3) h = torch.gather(h, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3) h = (h / 6.0) % 1.0 h = h * 360.0 v = v * 255.0 return torch.stack((h, s, v), dim=-3) def hsv_to_rgb(image: torch.Tensor) -> torch.Tensor: r"""Convert an image from HSV to RGB. The H channel values are assumed to be in the range 0..2pi. S and V are in the range 0..1. Args: image: HSV Image to be converted to HSV with shape of :math:`(*, 3, H, W)`. Returns: RGB version of the image with shape of :math:`(*, 3, H, W)`. Example: >>> input = torch.rand(2, 3, 4, 5) >>> output = hsv_to_rgb(input) # 2x3x4x5 """ if not isinstance(image, torch.Tensor): raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}") if len(image.shape) < 3 or image.shape[-3] != 3: raise ValueError(f"Input size must have a shape of (*, 3, H, W). 
Got {image.shape}") h: torch.Tensor = image[..., 0, :, :] / 360.0 s: torch.Tensor = image[..., 1, :, :] v: torch.Tensor = image[..., 2, :, :] / 255.0 hi: torch.Tensor = torch.floor(h * 6) % 6 f: torch.Tensor = ((h * 6) % 6) - hi one: torch.Tensor = torch.tensor(1.0, device=image.device, dtype=image.dtype) p: torch.Tensor = v * (one - s) q: torch.Tensor = v * (one - f * s) t: torch.Tensor = v * (one - (one - f) * s) hi = hi.long() indices: torch.Tensor = torch.stack([hi, hi + 6, hi + 12], dim=-3) out = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-3) out = torch.gather(out, -3, indices) out = out * 255.0 return out class GpuPhotoMetricDistortion: """Apply photometric distortion to image sequentially, every transformation is applied with a probability of 0.5. The position of random contrast is in second or second to last. 1. random brightness 2. random contrast (mode 0) 3. convert color from BGR to HSV 4. random saturation 5. random hue 6. convert color from HSV to BGR 7. random contrast (mode 1) 8. randomly swap channels Args: brightness_delta (int): delta of brightness. contrast_range (tuple): range of contrast. saturation_range (tuple): range of saturation. hue_delta (int): delta of hue. """ def __init__(self, brightness_delta=32, contrast_range=(0.5, 1.5), saturation_range=(0.5, 1.5), hue_delta=18): self.brightness_delta = brightness_delta self.contrast_lower, self.contrast_upper = contrast_range self.saturation_lower, self.saturation_upper = saturation_range self.hue_delta = hue_delta def __call__(self, imgs): """Call function to perform photometric distortion on images. Args: results (dict): Result dict from loading pipeline. Returns: dict: Result dict with images distorted. 
""" imgs = imgs[:, [2, 1, 0], :, :] # BGR to RGB contrast_modes = [] for _ in range(imgs.shape[0]): # mode == 0 --> do random contrast first # mode == 1 --> do random contrast last contrast_modes.append(random.randint(2)) for idx in range(imgs.shape[0]): # random brightness if random.randint(2): delta = random.uniform(-self.brightness_delta, self.brightness_delta) imgs[idx] += delta if contrast_modes[idx] == 0: if random.randint(2): alpha = random.uniform(self.contrast_lower, self.contrast_upper) imgs[idx] *= alpha # convert color from BGR to HSV imgs = rgb_to_hsv(imgs) for idx in range(imgs.shape[0]): # random saturation if random.randint(2): imgs[idx, 1] *= random.uniform(self.saturation_lower, self.saturation_upper) # random hue if random.randint(2): imgs[idx, 0] += random.uniform(-self.hue_delta, self.hue_delta) imgs[:, 0][imgs[:, 0] > 360] -= 360 imgs[:, 0][imgs[:, 0] < 0] += 360 # convert color from HSV to BGR imgs = hsv_to_rgb(imgs) for idx in range(imgs.shape[0]): # random contrast if contrast_modes[idx] == 1: if random.randint(2): alpha = random.uniform(self.contrast_lower, self.contrast_upper) imgs[idx] *= alpha # randomly swap channels if random.randint(2): imgs[idx] = imgs[idx, random.permutation(3)] imgs = imgs[:, [2, 1, 0], :, :] # RGB to BGR return imgs class DumpConfig: def __init__(self): self.enabled = False self.out_dir = 'outputs' self.stage_count = 0 self.frame_count = 0 DUMP = DumpConfig() ================================================ FILE: old_metrics.py ================================================ import os import glob import torch import argparse import numpy as np from tqdm import tqdm from loaders.old_metrics import Metric_mIoU def main(args): pred_filepaths = sorted(glob.glob(os.path.join(args.pred_dir, '*.npz'))) gt_filepaths = sorted(glob.glob(os.path.join(args.data_root, 'occ3d', '*/*/*.npz'))) eval_metrics_miou = Metric_mIoU( num_classes=18, use_lidar_mask=False, use_image_mask=True) for pred_filepath in tqdm(pred_filepaths): 
sample_token = os.path.basename(pred_filepath).split('.')[0] for gt_filepath in gt_filepaths: if sample_token in gt_filepath: sem_pred = np.load(pred_filepath, allow_pickle=True)['pred'] sem_pred = np.reshape(sem_pred, [200, 200, 16]) occ_gt = np.load(gt_filepath, allow_pickle=True) gt_semantics = occ_gt['semantics'] mask_lidar = occ_gt['mask_lidar'].astype(bool) mask_camera = occ_gt['mask_camera'].astype(bool) eval_metrics_miou.add_batch(sem_pred, gt_semantics, mask_lidar, mask_camera) eval_metrics_miou.count_miou() if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--data-root", type=str, default='data/nuscenes') parser.add_argument("--pred-dir", type=str) args = parser.parse_args() torch.random.manual_seed(0) np.random.seed(0) main(args) ================================================ FILE: ray_metrics.py ================================================ import os import glob import mmcv import argparse import numpy as np import torch from torch.utils.data import DataLoader from loaders.ray_metrics import main_rayiou from loaders.ego_pose_dataset import EgoPoseDataset from configs.r50_nuimg_704x256_8f import occ_class_names as occ3d_class_names from configs.r50_nuimg_704x256_8f_openocc import occ_class_names as openocc_class_names def main(args): data_infos = mmcv.load(os.path.join(args.data_root, 'nuscenes_infos_val.pkl'))['infos'] gt_filepaths = sorted(glob.glob(os.path.join(args.data_root, args.data_type, '*/*/*.npz'))) # retrieve scene_name token2scene = {} for gt_path in gt_filepaths: token = gt_path.split('/')[-2] scene_name = gt_path.split('/')[-3] token2scene[token] = scene_name for i in range(len(data_infos)): scene_name = token2scene[data_infos[i]['token']] data_infos[i]['scene_name'] = scene_name lidar_origins = [] occ_gts = [] occ_preds = [] for idx, batch in enumerate(DataLoader(EgoPoseDataset(data_infos), num_workers=8)): output_origin = batch[1] info = data_infos[idx] occ_path = os.path.join(args.data_root, 
def main():
    """Benchmark a detector's inference speed (FPS) on the validation set."""
    parser = argparse.ArgumentParser(description='Validate a detector')
    parser.add_argument('--config', required=True)
    parser.add_argument('--weights', required=True)
    # BUG FIX: without type=int these arrive as *strings* when supplied on
    # the command line (only the defaults were ints), so `i >= args.num_warmup`
    # and the modulo / fps arithmetic below raised TypeError.
    parser.add_argument('--num_warmup', type=int, default=10)
    parser.add_argument('--samples', type=int, default=200)
    parser.add_argument('--log-interval', type=int, default=50, help='interval of logging')
    parser.add_argument('--override', nargs='+', action=DictAction)
    args = parser.parse_args()

    # parse configs
    cfgs = Config.fromfile(args.config)
    if args.override is not None:
        cfgs.merge_from_dict(args.override)

    # register custom module
    importlib.import_module('models')
    importlib.import_module('loaders')

    # MMCV, please shut up
    from mmcv.utils.logging import logger_initialized
    logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)

    utils.init_logging(None, cfgs.debug)

    # you need GPUs
    assert torch.cuda.is_available() and torch.cuda.device_count() == 1
    logging.info('Using GPU: %s' % torch.cuda.get_device_name(0))
    torch.cuda.set_device(0)

    logging.info('Setting random seed: 0')
    set_random_seed(0, deterministic=True)
    cudnn.benchmark = True

    logging.info('Loading validation set from %s' % cfgs.data.val.data_root)
    val_dataset = build_dataset(cfgs.data.val)
    val_loader = build_dataloader(
        val_dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfgs.data.workers_per_gpu,
        num_gpus=1,
        dist=False,
        shuffle=False,
        seed=0,
    )

    logging.info('Creating model: %s' % cfgs.model.type)
    model = build_model(cfgs.model)
    model.cuda()

    assert torch.cuda.device_count() == 1
    model = MMDataParallel(model, [0])

    logging.info('Loading checkpoint from %s' % args.weights)
    load_checkpoint(
        model, args.weights, map_location='cuda', strict=False,
        logger=logging.Logger(__name__, logging.ERROR)
    )

    model.eval()

    print('Timing w/ data loading:')
    pure_inf_time = 0
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            # synchronize so perf_counter brackets the actual GPU work
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            model(return_loss=False, rescale=True, **data)
            torch.cuda.synchronize()
            elapsed = time.perf_counter() - start_time

            # skip the warm-up iterations when accumulating timings
            if i >= args.num_warmup:
                pure_inf_time += elapsed
                if (i + 1) % args.log_interval == 0:
                    fps = (i + 1 - args.num_warmup) / pure_inf_time
                    print(f'Done sample [{i + 1:<3}/ {args.samples}], '
                          f'fps: {fps:.1f} sample / s')

            if (i + 1) == args.samples:
                break
def main():
    """Entry point for (optionally distributed) training.

    Parses CLI args and the mmcv config, sets up the work dir and logging on
    rank 0, builds dataloaders/model/optimizer, and drives an mmcv
    EpochBasedRunner with LR/checkpoint/eval hooks.
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('--config', required=True)
    parser.add_argument('--run_name', required=False, default='')
    parser.add_argument('--override', nargs='+', action=DictAction)
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    args = parser.parse_args()

    # parse configs
    cfgs = Config.fromfile(args.config)
    if args.override is not None:
        cfgs.merge_from_dict(args.override)

    # register custom module
    importlib.import_module('models')
    importlib.import_module('loaders')

    # MMCV, please shut up
    from mmcv.utils.logging import logger_initialized
    logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmdet3d'] = logging.Logger(__name__, logging.WARNING)

    # you need GPUs
    assert torch.cuda.is_available()

    # determine local_rank and world_size from the env (torchrun) or the CLI
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    if 'WORLD_SIZE' not in os.environ:
        os.environ['WORLD_SIZE'] = str(args.world_size)
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    if local_rank == 0:
        # resume or start a new run
        if cfgs.resume_from is not None:
            assert os.path.isfile(cfgs.resume_from)
            work_dir = os.path.dirname(cfgs.resume_from)
        else:
            run_name = args.run_name
            # interactively ask for a run name unless debugging
            if not cfgs.debug and run_name == '':
                run_name = input('Name your run (leave blank for default): ')
            if run_name == '':
                run_name = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
            work_dir = os.path.join('outputs', cfgs.model.type, run_name)
            if os.path.exists(work_dir):  # must be an empty dir
                if input('Path "%s" already exists, overwrite it? [Y/n] ' % work_dir) == 'n':
                    print('Bye.')
                    exit(0)
                shutil.rmtree(work_dir)
            os.makedirs(work_dir, exist_ok=False)

        # init logging, backup code
        utils.init_logging(os.path.join(work_dir, 'train.log'), cfgs.debug)
        utils.backup_code(work_dir)
        logging.info('Logs will be saved to %s' % work_dir)
    else:
        # disable logging on other workers
        logging.root.disabled = True
        work_dir = '/tmp'

    logging.info('Using GPU: %s' % torch.cuda.get_device_name(local_rank))
    torch.cuda.set_device(local_rank)

    if world_size > 1:
        logging.info('Initializing DDP with %d GPUs...' % world_size)
        dist.init_process_group('nccl', init_method='env://')

    logging.info('Setting random seed: 0')
    set_random_seed(0, deterministic=True)

    logging.info('Loading training set from %s' % cfgs.dataset_root)
    train_dataset = build_dataset(cfgs.data.train)
    train_loader = build_dataloader(
        train_dataset,
        samples_per_gpu=cfgs.batch_size // world_size,
        workers_per_gpu=cfgs.data.workers_per_gpu,
        num_gpus=world_size,
        dist=world_size > 1,
        shuffle=True,
        seed=0,
    )

    logging.info('Loading validation set from %s' % cfgs.dataset_root)
    val_dataset = build_dataset(cfgs.data.val)
    val_loader = build_dataloader(
        val_dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfgs.data.workers_per_gpu,
        num_gpus=world_size,
        dist=world_size > 1,
        shuffle=False
    )

    logging.info('Creating model: %s' % cfgs.model.type)
    model = build_model(cfgs.model)
    model.init_weights()
    model.cuda()
    model.train()

    n_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
    logging.info('Trainable parameters: %d (%.1fM)' % (n_params, n_params / 1e6))
    logging.info('Batch size per GPU: %d' % (cfgs.batch_size // world_size))

    if world_size > 1:
        model = MMDistributedDataParallel(model, [local_rank], broadcast_buffers=False)
    else:
        model = MMDataParallel(model, [0])

    logging.info('Creating optimizer: %s' % cfgs.optimizer.type)
    optimizer = build_optimizer(model, cfgs.optimizer)

    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=work_dir,
        logger=logging.root,
        max_epochs=cfgs.total_epochs,
        meta=dict(),
    )
    runner.register_lr_hook(cfgs.lr_config)
    runner.register_optimizer_hook(cfgs.optimizer_config)
    runner.register_checkpoint_hook(cfgs.checkpoint_config)
    runner.register_logger_hooks(cfgs.log_config)
    runner.register_timer_hook(dict(type='IterTimerHook'))
    runner.register_custom_hooks(dict(type='DistSamplerSeedHook'))

    # periodic evaluation (interval <= 0 disables it)
    if cfgs.eval_config['interval'] > 0:
        if world_size > 1:
            runner.register_hook(DistEvalHook(val_loader, interval=cfgs.eval_config['interval'], gpu_collect=True))
        else:
            runner.register_hook(EvalHook(val_loader, interval=cfgs.eval_config['interval']))

    if cfgs.resume_from is not None:
        logging.info('Resuming from %s' % cfgs.resume_from)
        runner.resume(cfgs.resume_from)
    elif cfgs.load_from is not None:
        logging.info('Loading checkpoint from %s' % cfgs.load_from)
        if cfgs.revise_keys is not None:
            load_checkpoint(
                model, cfgs.load_from,
                map_location='cpu',
                revise_keys=cfgs.revise_keys
            )
        else:
            load_checkpoint(
                model, cfgs.load_from,
                map_location='cpu',
            )

    runner.run([train_loader], [('train', 1)])
def backup_code(work_dir, verbose=False):
    """Copy the project's source files into ``work_dir``/backup for reproducibility.

    Args:
        work_dir (str): experiment output directory.
        verbose (bool): log each copied file.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    for pattern in ['*.py', 'configs/*.py', 'models/*.py', 'loaders/*.py', 'loaders/pipelines/*.py']:
        # BUG FIX: the glob patterns were evaluated relative to the CWD, so
        # when training was launched from outside the repo root nothing was
        # backed up (or unrelated files were). Resolve against the package dir.
        for src in glob.glob(os.path.join(base_dir, pattern)):
            rel = os.path.relpath(src, base_dir)
            dst = os.path.join(work_dir, 'backup', os.path.dirname(rel))
            if verbose:
                logging.info('Copying %s -> %s' % (os.path.relpath(src), os.path.relpath(dst)))
            os.makedirs(dst, exist_ok=True)
            shutil.copy2(src, dst)
- runner.iter - 1) eta_str = str(datetime.timedelta(seconds=int(eta_sec))) log_str += f'eta: {eta_str}, ' log_str += f'time: {log_dict["time"]:.2f}s, ' \ f'data: {log_dict["data_time"] * 1000:.0f}ms, ' # statistic memory if torch.cuda.is_available(): log_str += f'mem: {log_dict["memory"]}M' runner.logger.info(log_str) def log(self, runner): if 'eval_iter_num' in runner.log_buffer.output: # this doesn't modify runner.iter and is regardless of by_epoch cur_iter = runner.log_buffer.output.pop('eval_iter_num') else: cur_iter = self.get_iter(runner, inner_iter=True) log_dict = { 'mode': self.get_mode(runner), 'epoch': self.get_epoch(runner), 'iter': cur_iter } # only record lr of the first param group cur_lr = runner.current_lr() if isinstance(cur_lr, list): log_dict['lr'] = cur_lr[0] else: assert isinstance(cur_lr, dict) log_dict['lr'] = {} for k, lr_ in cur_lr.items(): assert isinstance(lr_, list) log_dict['lr'].update({k: lr_[0]}) if 'time' in runner.log_buffer.output: # statistic memory if torch.cuda.is_available(): log_dict['memory'] = self._get_max_memory(runner) log_dict = dict(log_dict, **runner.log_buffer.output) # MOD: disable writing to files # self._dump_log(log_dict, runner) self._log_info(log_dict, runner) return log_dict def after_train_epoch(self, runner): if 'eval_iter_num' in runner.log_buffer.output: runner.log_buffer.output.pop('eval_iter_num') if runner.log_buffer.ready: metrics = self.get_loggable_tags(runner) runner.logger.info('--- Evaluation Results ---') runner.logger.info('RayIoU: %.4f' % metrics['val/RayIoU']) @HOOKS.register_module() class MyTensorboardLoggerHook(LoggerHook): def __init__(self, log_dir=None, interval=10, ignore_last=True, reset_flag=False, by_epoch=True): super(MyTensorboardLoggerHook, self).__init__( interval, ignore_last, reset_flag, by_epoch) self.log_dir = log_dir @master_only def before_run(self, runner): super(MyTensorboardLoggerHook, self).before_run(runner) if self.log_dir is None: self.log_dir = runner.work_dir 
self.writer = SummaryWriter(self.log_dir) @master_only def log(self, runner): tags = self.get_loggable_tags(runner) for key, value in tags.items(): # MOD: merge into the 'train' group if key == 'learning_rate': key = 'train/learning_rate' # MOD: skip momentum ignore = False if key == 'momentum': ignore = True # MOD: skip intermediate losses for i in range(5): if key[:13] == 'train/d%d.loss' % i: ignore = True if self.get_mode(runner) == 'train' and key[:5] != 'train': ignore = True if self.get_mode(runner) != 'train' and key[:3] != 'val': ignore = True if ignore: continue if key[:5] == 'train': self.writer.add_scalar(key, value, self.get_iter(runner)) elif key[:3] == 'val': self.writer.add_scalar(key, value, self.get_epoch(runner)) @master_only def after_run(self, runner): self.writer.close() # modified from mmcv.runner.hooks.logger.wandb @HOOKS.register_module() class MyWandbLoggerHook(LoggerHook): """Class to log metrics with wandb. It requires `wandb`_ to be installed. Args: log_dir (str): directory for saving logs Default None. project_name (str): name for your project (mainly used to specify saving path on wandb server) Default None. team_name (str): name for your team (mainly used to specify saving path on wandb server) Default None. experiment_name (str): name for your run, if not specified, use the last part of log_dir Default None. interval (int): Logging interval (every k iterations). Default 10. ignore_last (bool): Ignore the log of last iterations in each epoch if less than `interval`. Default: True. reset_flag (bool): Whether to clear the output buffer after logging. Default: False. commit (bool): Save the metrics dict to the wandb server and increment the step. If false ``wandb.log`` just updates the current metrics dict with the row argument and metrics won't be saved until ``wandb.log`` is called with ``commit=True``. Default: True. by_epoch (bool): Whether EpochBasedRunner is used. Default: True. 
with_step (bool): If True, the step will be logged from ``self.get_iters``. Otherwise, step will not be logged. Default: True. out_suffix (str or tuple[str], optional): Those filenames ending with ``out_suffix`` will be uploaded to wandb. Default: ('.log.json', '.log', '.py'). `New in version 1.4.3.` .. _wandb: https://docs.wandb.ai """ def __init__(self, log_dir=None, project_name=None, team_name=None, experiment_name=None, interval=10, ignore_last=True, reset_flag=False, by_epoch=True, commit=True, with_step=True, out_suffix = ('.log.json', '.log', '.py')): super().__init__(interval, ignore_last, reset_flag, by_epoch) self.import_wandb() self.commit = commit self.with_step = with_step self.out_suffix = out_suffix self.log_dir = log_dir self.project_name = project_name self.team_name = team_name self.experiment_name = experiment_name if commit: os.system('wandb online') else: os.system('wandb offline') def import_wandb(self) -> None: try: import wandb except ImportError: raise ImportError( 'Please run "pip install wandb" to install wandb') self.wandb = wandb @master_only def before_run(self, runner) -> None: super().before_run(runner) if self.log_dir is None: self.log_dir = runner.work_dir if self.experiment_name is None: self.experiment_name = os.path.basename(self.log_dir) init_kwargs = dict( project=self.project_name, entity=self.team_name, notes=socket.gethostname(), name=self.experiment_name, dir=self.log_dir, reinit=True ) if self.wandb is None: self.import_wandb() if init_kwargs: self.wandb.init(**init_kwargs) # type: ignore else: self.wandb.init() # type: ignore @master_only def log(self, runner) -> None: tags = self.get_loggable_tags(runner) mode = self.get_mode(runner) if not tags: return if 'learning_rate' in tags.keys(): tags['train/learning_rate'] = tags['learning_rate'] del tags['learning_rate'] if 'momentum' in tags.keys(): del tags['momentum'] tags = {k: v for k, v in tags.items() if k.startswith(mode)} if self.with_step: self.wandb.log( tags, 
step=self.get_iter(runner), commit=self.commit) else: tags['global_step'] = self.get_iter(runner) self.wandb.log(tags, commit=self.commit) @master_only def after_run(self, runner) -> None: self.wandb.join() ================================================ FILE: val.py ================================================ import os import utils import logging import argparse import importlib import torch import torch.distributed import torch.distributed as dist import torch.backends.cudnn as cudnn from mmcv import Config from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import load_checkpoint from mmdet.apis import set_random_seed, multi_gpu_test, single_gpu_test from mmdet3d.datasets import build_dataset, build_dataloader from mmdet3d.models import build_model def evaluate(dataset, results): metrics = dataset.evaluate(results, jsonfile_prefix=None) logging.info('--- Evaluation Results ---') for k, v in metrics.items(): logging.info('%s: %.4f' % (k, v)) return metrics def main(): parser = argparse.ArgumentParser(description='Validate a detector') parser.add_argument('--config', required=True) parser.add_argument('--weights', required=True) parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--world_size', type=int, default=1) parser.add_argument('--batch_size', type=int, default=1) args = parser.parse_args() # parse configs cfgs = Config.fromfile(args.config) # register custom module importlib.import_module('models') importlib.import_module('loaders') # MMCV, please shut up from mmcv.utils.logging import logger_initialized logger_initialized['root'] = logging.Logger(__name__, logging.WARNING) logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING) # you need GPUs assert torch.cuda.is_available() # determine local_rank and world_size if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) if 'WORLD_SIZE' not in os.environ: os.environ['WORLD_SIZE'] = str(args.world_size) 
local_rank = int(os.environ['LOCAL_RANK']) world_size = int(os.environ['WORLD_SIZE']) if local_rank == 0: utils.init_logging(None, cfgs.debug) else: logging.root.disabled = True logging.info('Using GPU: %s' % torch.cuda.get_device_name(local_rank)) torch.cuda.set_device(local_rank) if world_size > 1: logging.info('Initializing DDP with %d GPUs...' % world_size) dist.init_process_group('nccl', init_method='env://') logging.info('Setting random seed: 0') set_random_seed(0, deterministic=True) cudnn.benchmark = True logging.info('Loading validation set from %s' % cfgs.data.val.data_root) val_dataset = build_dataset(cfgs.data.val) val_loader = build_dataloader( val_dataset, samples_per_gpu=args.batch_size, workers_per_gpu=cfgs.data.workers_per_gpu, num_gpus=world_size, dist=world_size > 1, shuffle=False, seed=0, ) logging.info('Creating model: %s' % cfgs.model.type) model = build_model(cfgs.model) model.cuda() if world_size > 1: model = MMDistributedDataParallel(model, [local_rank], broadcast_buffers=False) else: model = MMDataParallel(model, [0]) if os.path.isfile(args.weights): logging.info('Loading checkpoint from %s' % args.weights) load_checkpoint( model, args.weights, map_location='cuda', strict=True, logger=logging.Logger(__name__, logging.ERROR) ) if world_size > 1: results = multi_gpu_test(model, val_loader, gpu_collect=True) else: results = single_gpu_test(model, val_loader) if local_rank == 0: evaluate(val_dataset, results) if __name__ == '__main__': main() ================================================ FILE: viz_prediction.py ================================================ import os import cv2 import utils import logging import argparse import importlib import torch import numpy as np from tqdm import tqdm from mmcv import Config, DictAction from mmdet.apis import set_random_seed from mmdet3d.datasets import build_dataset, build_dataloader from configs.r50_nuimg_704x256_8f import point_cloud_range as pc_range from configs.r50_nuimg_704x256_8f import 
occ_size from configs.r50_nuimg_704x256_8f import occ_class_names from mmcv.parallel import MMDataParallel from mmcv.runner import load_checkpoint from mmdet3d.models import build_model color_map = np.array([ [0, 0, 0, 255], # others [255, 120, 50, 255], # barrier orangey [255, 192, 203, 255], # bicycle pink [255, 255, 0, 255], # bus yellow [0, 150, 245, 255], # car blue [0, 255, 255, 255], # construction_vehicle cyan [200, 180, 0, 255], # motorcycle dark orange [255, 0, 0, 255], # pedestrian red [255, 240, 150, 255], # traffic_cone light yellow [135, 60, 0, 255], # trailer brown [160, 32, 240, 255], # truck purple [255, 0, 255, 255], # driveable_surface dark pink [175, 0, 75, 255], # other_flat dark red [75, 0, 75, 255], # sidewalk dard purple [150, 240, 80, 255], # terrain light green [230, 230, 250, 255], # manmade white [0, 175, 0, 255], # vegetation green [255, 255, 255, 255], # free white ], dtype=np.uint8) def occ2img(semantics): H, W, D = semantics.shape free_id = len(occ_class_names) - 1 semantics_2d = np.ones([H, W], dtype=np.int32) * free_id for i in range(D): semantics_i = semantics[..., i] non_free_mask = (semantics_i != free_id) semantics_2d[non_free_mask] = semantics_i[non_free_mask] viz = color_map[semantics_2d] viz = viz[..., :3] viz = cv2.resize(viz, dsize=(800, 800)) return viz def main(): parser = argparse.ArgumentParser(description='Validate a detector') parser.add_argument('--config', required=True) parser.add_argument('--weights', required=True) parser.add_argument('--viz-dir', required=True) parser.add_argument('--override', nargs='+', action=DictAction) args = parser.parse_args() # parse configs cfgs = Config.fromfile(args.config) if args.override is not None: cfgs.merge_from_dict(args.override) # use val-mini for visualization #cfgs.data.val.ann_file = cfgs.data.val.ann_file.replace('val', 'val_mini') # register custom module importlib.import_module('models') importlib.import_module('loaders') # MMCV, please shut up from mmcv.utils.logging 
import logger_initialized logger_initialized['root'] = logging.Logger(__name__, logging.WARNING) logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING) # you need one GPU assert torch.cuda.is_available() assert torch.cuda.device_count() == 1 # logging utils.init_logging(None, cfgs.debug) logging.info('Using GPU: %s' % torch.cuda.get_device_name(0)) # random seed logging.info('Setting random seed: 0') set_random_seed(0, deterministic=True) logging.info('Loading validation set from %s' % cfgs.data.val.data_root) val_dataset = build_dataset(cfgs.data.val) val_loader = build_dataloader( val_dataset, samples_per_gpu=1, workers_per_gpu=cfgs.data.workers_per_gpu, num_gpus=1, dist=False, shuffle=False, seed=0, ) logging.info('Creating model: %s' % cfgs.model.type) model = build_model(cfgs.model) model.cuda() model = MMDataParallel(model, [0]) model.eval() logging.info('Loading checkpoint from %s' % args.weights) load_checkpoint( model, args.weights, map_location='cuda', strict=True, logger=logging.Logger(__name__, logging.ERROR) ) for i, data in tqdm(enumerate(val_loader)): #print(data['img_metas'].data[0][0]['filename'][:6]) with torch.no_grad(): occ_pred = model(return_loss=False, rescale=True, **data)[0] sem_pred = torch.from_numpy(occ_pred['sem_pred'])[0] # [N] occ_loc = torch.from_numpy(occ_pred['occ_loc'].astype(np.int64))[0] # [N, 3] # sparse to dense free_id = len(occ_class_names) - 1 dense_pred = torch.ones(occ_size, device=sem_pred.device, dtype=sem_pred.dtype) * free_id # [200, 200, 16] dense_pred[occ_loc[..., 0], occ_loc[..., 1], occ_loc[..., 2]] = sem_pred sem_pred = dense_pred.numpy() cv2.imwrite(os.path.join(args.viz_dir, 'sem_%04d.jpg' % i), occ2img(sem_pred)[..., ::-1]) if __name__ == '__main__': main()