Repository: MCG-NJU/SparseOcc
Branch: main
Commit: af4d9df83bfd
Files: 58
Total size: 409.0 KB
Directory structure:
gitextract_8j2gi4im/
├── .gitignore
├── LICENSE
├── README.md
├── configs/
│ ├── r50_nuimg_704x256_8f.py
│ ├── r50_nuimg_704x256_8f_60e.py
│ ├── r50_nuimg_704x256_8f_openocc.py
│ └── r50_nuimg_704x256_8f_pano.py
├── gen_instance_info.py
├── gen_sweep_info.py
├── lib/
│ └── dvr/
│ ├── dvr.cpp
│ └── dvr.cu
├── loaders/
│ ├── __init__.py
│ ├── builder.py
│ ├── ego_pose_dataset.py
│ ├── nuscenes_dataset.py
│ ├── nuscenes_occ_dataset.py
│ ├── old_metrics.py
│ ├── pipelines/
│ │ ├── __init__.py
│ │ ├── loading.py
│ │ └── transforms.py
│ ├── ray_metrics.py
│ └── ray_pq.py
├── models/
│ ├── __init__.py
│ ├── backbones/
│ │ ├── __init__.py
│ │ └── vovnet.py
│ ├── bbox/
│ │ ├── __init__.py
│ │ ├── assigners/
│ │ │ ├── __init__.py
│ │ │ └── hungarian_assigner_3d.py
│ │ ├── coders/
│ │ │ ├── __init__.py
│ │ │ └── nms_free_coder.py
│ │ ├── match_costs/
│ │ │ ├── __init__.py
│ │ │ └── match_cost.py
│ │ └── utils.py
│ ├── checkpoint.py
│ ├── csrc/
│ │ ├── __init__.py
│ │ ├── msmv_sampling/
│ │ │ ├── msmv_sampling.cpp
│ │ │ ├── msmv_sampling.h
│ │ │ ├── msmv_sampling_backward.cu
│ │ │ └── msmv_sampling_forward.cu
│ │ ├── setup.py
│ │ └── wrapper.py
│ ├── loss_utils.py
│ ├── matcher.py
│ ├── sparse_voxel_decoder.py
│ ├── sparsebev_head.py
│ ├── sparsebev_sampling.py
│ ├── sparsebev_transformer.py
│ ├── sparseocc.py
│ ├── sparseocc_head.py
│ ├── sparseocc_transformer.py
│ └── utils.py
├── old_metrics.py
├── ray_metrics.py
├── timing.py
├── train.py
├── utils.py
├── val.py
└── viz_prediction.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# OS generated files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Compiled source
build
debug
Debug
release
Release
x64
*.so
*.whl
# VS project files
*.sln
*.vcxproj
*.vcxproj.filters
*.vcxproj.user
*.rc
.vs
# Byte-compiled / optimized / DLL files
*__pycache__*
*.py[cod]
*$py.class
# Distribution / packaging
.Python
build
develop-eggs
dist
downloads
# IDE
.idea
.vscode
pyrightconfig.json
# Custom
data
outputs
prediction
submission
checkpoints
pretrain
ckpts
occ_result
wandb
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# SparseOcc
This is the official PyTorch implementation for our paper:
> [**Fully Sparse 3D Panoptic Occupancy Prediction**](https://arxiv.org/abs/2312.17118)
> :school: Presented by Nanjing University and Shanghai AI Lab
> :email: Primary contact: Haisong Liu (afterthat97@gmail.com)
> :trophy: [CVPR 2024 Autonomous Driving Challenge - Occupancy and Flow](https://opendrivelab.com/challenge2024/#occupancy_and_flow)
> :book: 中文解读(官方):https://zhuanlan.zhihu.com/p/709576252
> :book: 中文解读(第三方): [AIming](https://zhuanlan.zhihu.com/p/691549750), [自动驾驶之心](https://zhuanlan.zhihu.com/p/675811281)
## :warning: Important Notes
There is another concurrent project titled *''SparseOcc: Rethinking sparse latent representation for vision-based semantic occupancy prediction''* by Tang et al., which shares the same name SparseOcc with ours. However, this repository is **unrelated** to the aforementioned paper.
If you cite our research, please ensure that you reference the correct version (arXiv **2312.17118**, authored by **Liu et al.**):
```
@article{liu2023fully,
title={Fully sparse 3d panoptic occupancy prediction},
author={Liu, Haisong and Wang, Haiguang and Chen, Yang and Yang, Zetong and Zeng, Jia and Chen, Li and Wang, Limin},
journal={arXiv preprint arXiv:2312.17118},
year={2023}
}
```
> In arXiv 2312.17118v3, we removed the word "panoptic" from the title. However, Google Scholar's database has not been updated and still shows the old one. Therefore, we still recommend citing the old title - "Fully sparse 3d panoptic occupancy prediction" - so that Google Scholar can index it correctly. Thank you all.
## News
* **2024-07-19**: We released an updated version of SparseOcc on [arXiv](https://arxiv.org/abs/2312.17118). All charts and colors have been carefully adjusted. Delete the old version and download the new one!
* **2024-07-01**: SparseOcc is accepted to ECCV 2024.
* **2024-06-27**: SparseOcc v1.1 is released. In this change, we introduce BEV data augmentation (BDA) and Lovasz-Softmax loss to further enhance the performance. Compared with [v1.0](https://github.com/MCG-NJU/SparseOcc/tree/v1.0) (35.0 RayIoU with 48 epochs), SparseOcc v1.1 can achieve 36.8 RayIoU with 24 epochs!
* **2024-05-29**: We add support for [OpenOcc v2](configs/r50_nuimg_704x256_8f_openocc.py) dataset (without occupancy flow).
* **2024-04-11**: The panoptic version of SparseOcc ([configs/r50_nuimg_704x256_8f_pano.py](configs/r50_nuimg_704x256_8f_pano.py)) is released.
* **2024-04-09**: An updated arXiv version [https://arxiv.org/abs/2312.17118v3](https://arxiv.org/abs/2312.17118v3) has been released.
* **2024-03-31**: We release the code and pretrained weights.
* **2023-12-30**: We release the paper.
## Highlights
**New model**:1st_place_medal:: SparseOcc initially reconstructs a sparse 3D representation from visual inputs and subsequently predicts semantic/instance occupancy from the 3D sparse representation by sparse queries.

**New evaluation metric**:chart_with_upwards_trend:: We design a thoughtful ray-based evaluation metric, namely RayIoU, to solve the inconsistency penalty along depths raised in traditional voxel-level mIoU criteria.

Some FAQs from the community about the evaluation metrics:
1. **Why does training with visible masks result in significant improvements in the old mIoU metric, but not in the new RayIoU metric?** As mentioned in the paper, when using the visible mask during training, the area behind the surface won't be supervised, so the model tends to fill this area with duplicated predictions, leading to a thicker surface. The old metric inconsistently penalizes along the depth axis when the prediction has a thick surface. Thus, this ''improvement'' is mainly due to the vulnerability of the old metric.
2. **Why SparseOcc cannot exploit the vulnerability of the old metrics?** As SparseOcc employs a fully sparse architecture, it always predicts a thin surface. Thus, there are two ways for a fair comparison: (a) use the old metric, but all methods must predict a thin surface, which implies they cannot use the visible mask during training; (b) use RayIoU, as it is more reasonable and can fairly compare thick or thin surface. Our method achieves SOTA performance on both cases.
3. **Does RayIoU overlook interior reconstruction?** Firstly, we are unable to obtain the interior occupancy ground-truth. This is because the ground-truth is derived from voxelizing LiDAR point clouds, and LiDARs are only capable of scanning the thin surface of an object. Secondly, the query ray in RayIoU can originate from any position within the scene (see the figure above). This allows it to evaluate the overall reconstruction performance, unlike depth estimation. We would like to emphasize that the evaluation logic of RayIoU aligns with the process of ground-truth generation.
If you have other questions, feel free to contact me (Haisong Liu, afterthat97@gmail.com).
## Model Zoo
These results are from our latest version, v1.1, which outperforms the results reported in the paper. Additionally, our implementation differs slightly from the original paper. If you wish to reproduce the paper exactly, please refer to the [v1.0](https://github.com/MCG-NJU/SparseOcc/tree/v1.0) tag.
| Setting | Epochs | Training Cost | RayIoU | RayPQ | FPS | Weights |
|----------|:--------:|:-------------:|:------:|:-----:|:---:|:-------:|
| [r50_nuimg_704x256_8f](configs/r50_nuimg_704x256_8f.py) | 24 | 15h, ~12GB | 36.8 | - | 17.3 | [github](https://github.com/MCG-NJU/SparseOcc/releases/download/v1.1/sparseocc_r50_nuimg_704x256_8f_24e_v1.1.pth) |
| [r50_nuimg_704x256_8f_60e](configs/r50_nuimg_704x256_8f_60e.py) | 60 | 37h, ~12GB | 37.7 | - | 17.3 | [github](https://github.com/MCG-NJU/SparseOcc/releases/download/v1.1/sparseocc_r50_nuimg_704x256_8f_60e_v1.1.pth) |
| [r50_nuimg_704x256_8f_pano](configs/r50_nuimg_704x256_8f_pano.py) | 24 | 15h, ~12GB | 35.9 | 14.0 | 17.3 | [github](https://github.com/MCG-NJU/SparseOcc/releases/download/v1.1/sparseocc_r50_nuimg_704x256_8f_pano_24e_v1.1.pth) |
* The backbone is pretrained on [nuImages](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth). Download the weights to `pretrain/xxx.pth` before you start training.
* FPS is measured with Intel(R) Xeon(R) Platinum 8369B CPU and NVIDIA A100-SXM4-80GB GPU (PyTorch `fp32` backend, including data loading).
* We will release more settings in the future.
## Environment
> The requirements are the same as those of [SparseBEV](https://github.com/MCG-NJU/SparseBEV).
Install PyTorch 2.0 + CUDA 11.8:
```
conda create -n sparseocc python=3.8
conda activate sparseocc
conda install pytorch==2.0.0 torchvision==0.15.0 pytorch-cuda=11.8 -c pytorch -c nvidia
```
Install other dependencies:
```
pip install openmim
mim install mmcv-full==1.6.0
mim install mmdet==2.28.2
mim install mmsegmentation==0.30.0
mim install mmdet3d==1.0.0rc6
pip install setuptools==59.5.0
pip install numpy==1.23.5
```
Install turbojpeg and pillow-simd to speed up data loading (optional but important):
```
sudo apt-get update
sudo apt-get install -y libturbojpeg
pip install pyturbojpeg
pip uninstall pillow
pip install pillow-simd==9.0.0.post1
```
Compile CUDA extensions:
```
cd models/csrc
python setup.py build_ext --inplace
```
## Prepare Dataset
> The first two steps are the same as those of [SparseBEV](https://github.com/MCG-NJU/SparseBEV).
1. Download nuScenes from [https://www.nuscenes.org/nuscenes](https://www.nuscenes.org/nuscenes), put it to `data/nuscenes` and preprocess it with [mmdetection3d](https://github.com/open-mmlab/mmdetection3d/tree/v1.0.0rc6).
2. Download the generated info file from [gdrive](https://drive.google.com/file/d/1uyoUuSRIVScrm_CUpge6V_UzwDT61ODO/view?usp=sharing) and unzip it. These `*.pkl` files can also be generated with our script: `gen_sweep_info.py`.
3. Download Occ3D-nuScenes occupancy GT from [gdrive](https://drive.google.com/file/d/1kiXVNSEi3UrNERPMz_CfiJXKkgts_5dY/view?usp=drive_link), unzip it, and save it to `data/nuscenes/occ3d`.
4. Folder structure:
```
data/nuscenes
├── maps
├── nuscenes_infos_test_sweep.pkl
├── nuscenes_infos_train_sweep.pkl
├── nuscenes_infos_val_sweep.pkl
├── samples
├── sweeps
├── v1.0-test
└── v1.0-trainval
└── occ3d
├── scene-0001
│ ├── 0037a705a2e04559b1bba6c01beca1cf
│ │ └── labels.npz
│ ├── 026155aa1c554e2f87914ec9ba80acae
│ │ └── labels.npz
...
```
5. (Optional) Generate the panoptic occupancy ground truth with `gen_instance_info.py`. The panoptic version of Occ3D will be saved to `data/nuscenes/occ3d_panoptic`.
## Training
Train SparseOcc with 8 GPUs:
```
torchrun --nproc_per_node 8 train.py --config configs/r50_nuimg_704x256_8f.py
```
Train SparseOcc with 4 GPUs (i.e. the last four GPUs):
```
export CUDA_VISIBLE_DEVICES=4,5,6,7
torchrun --nproc_per_node 4 train.py --config configs/r50_nuimg_704x256_8f.py
```
The batch size for each GPU will be scaled automatically. So there is no need to modify the `batch_size` in config files.
## Evaluation
Single-GPU evaluation:
```
export CUDA_VISIBLE_DEVICES=0
python val.py --config configs/r50_nuimg_704x256_8f.py --weights checkpoints/sparseocc_r50_nuimg_704x256_8f.pth
```
Multi-GPU evaluation:
```
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
torchrun --nproc_per_node 8 val.py --config configs/r50_nuimg_704x256_8f.py --weights checkpoints/sparseocc_r50_nuimg_704x256_8f.pth
```
## Standalone Evaluation
If you want to evaluate your own model using RayIoU, please follow the steps below:
1. Save the predictions (shape=`[200x200x16]`, dtype=`np.uint8`) with the compressed `npz` format. For example:
```
save_path = os.path.join(save_dir, sample_token + '.npz')
np.savez_compressed(save_path, pred=sem_pred)
```
2. The filename for each sample is `sample_token.npz`, for example:
```
prediction/your_model
├── 000681a060c04755a1537cf83b53ba57.npz
├── 000868a72138448191b4092f75ed7776.npz
├── 0017c2623c914571a1ff2a37f034ffd7.npz
├── ...
```
3. Run `ray_metrics.py` to evaluate on the RayIoU:
```
python ray_metrics.py --pred-dir prediction/your_model
```
## Timing
FPS is measured with a single GPU:
```
export CUDA_VISIBLE_DEVICES=0
python timing.py --config configs/r50_nuimg_704x256_8f.py --weights checkpoints/sparseocc_r50_nuimg_704x256_8f.pth
```
## Acknowledgements
Many thanks to these excellent open-source projects:
* [MaskFormer](https://github.com/facebookresearch/MaskFormer)
* [NeuralRecon](https://github.com/zju3dv/NeuralRecon)
* [4D-Occ](https://github.com/tarashakhurana/4d-occ-forecasting)
* [MMDetection3D](https://github.com/open-mmlab/mmdetection3d)
================================================
FILE: configs/r50_nuimg_704x256_8f.py
================================================
# SparseOcc base config: ResNet-50 backbone (nuImages-pretrained), 704x256
# input, 8 temporal frames, 24-epoch schedule, Occ3D-nuScenes ground truth.

dataset_type = 'NuSceneOcc'
dataset_root = 'data/nuscenes/'
occ_gt_root = 'data/nuscenes/occ3d'  # Occ3D-nuScenes occupancy GT (see README "Prepare Dataset")

# If point cloud range is changed, the models should also change their point
# cloud range accordingly
point_cloud_range = [-40, -40, -1.0, 40, 40, 5.4]  # [x_min, y_min, z_min, x_max, y_max, z_max] in meters
occ_size = [200, 200, 16]  # occupancy grid resolution; 80m/200 = 6.4m/16 = 0.4 m per voxel

# Image normalization applied on GPU (see model.data_aug below).
img_norm_cfg = dict(
    mean=[123.675, 116.280, 103.530],
    std=[58.395, 57.120, 57.375],
    to_rgb=True
)

# For nuScenes we usually do 10-class detection
det_class_names = [
    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]

# Occ3D occupancy classes: 17 semantic categories + 'free'.
occ_class_names = [
    'others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',
    'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',
    'driveable_surface', 'other_flat', 'sidewalk',
    'terrain', 'manmade', 'vegetation', 'free'
]

# Camera-only setup; no lidar/radar/map inputs are consumed.
input_modality = dict(
    use_lidar=False,
    use_camera=True,
    use_radar=False,
    use_map=False,
    use_external=False
)

# Shared hyper-parameters (referenced multiple times below).
_dim_ = 256        # embedding dimension
_num_points_ = 4   # sampling points per query
_num_groups_ = 4
_num_layers_ = 2
_num_frames_ = 8   # temporal frames: current sample + (_num_frames_ - 1) sweeps
_num_queries_ = 100
# Top-k voxels retained at each coarse-to-fine stage; looser budget during
# training than testing. Consumed by SparseOccTransformer.
_topk_training_ = [4000, 16000, 64000]
_topk_testing_ = [2000, 8000, 32000]

model = dict(
    type='SparseOcc',
    data_aug=dict(
        img_color_aug=True,  # Move some augmentations to GPU
        img_norm_cfg=img_norm_cfg,
        img_pad_cfg=dict(size_divisor=32)),  # pad H/W to multiples of 32 for the FPN
    use_mask_camera=False,
    img_backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),  # feed all four stages to the FPN
        frozen_stages=1,
        norm_cfg=dict(type='BN2d', requires_grad=True),
        norm_eval=True,
        style='pytorch',
        with_cp=True),  # gradient checkpointing to reduce memory
    img_neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=_dim_,
        num_outs=4),
    pts_bbox_head=dict(
        type='SparseOccHead',
        class_names=occ_class_names,
        embed_dims=_dim_,
        occ_size=occ_size,
        pc_range=point_cloud_range,
        transformer=dict(
            type='SparseOccTransformer',
            embed_dims=_dim_,
            num_layers=_num_layers_,
            num_frames=_num_frames_,
            num_points=_num_points_,
            num_groups=_num_groups_,
            num_queries=_num_queries_,
            num_levels=4,  # matches the 4 FPN outputs
            num_classes=len(occ_class_names),
            pc_range=point_cloud_range,
            occ_size=occ_size,
            topk_training=_topk_training_,
            topk_testing=_topk_testing_),
        loss_cfgs=dict(
            # Mask2Former-style set prediction loss (class + mask + dice).
            loss_mask2former=dict(
                type='Mask2FormerLoss',
                num_classes=len(occ_class_names),
                no_class_weight=0.1,  # down-weight the "no object" class
                loss_cls_weight=2.0,
                loss_mask_weight=5.0,
                loss_dice_weight=5.0,
            ),
            loss_geo_scal=dict(
                type='GeoScalLoss',
                num_classes=len(occ_class_names),
                loss_weight=1.0
            ),
            loss_sem_scal=dict(
                type='SemScalLoss',
                num_classes=len(occ_class_names),
                loss_weight=1.0
            )
        ),
    ),
)

# Image-space augmentation: resize from 1600x900 to roughly 704x256, random flip.
ida_aug_conf = {
    'resize_lim': (0.38, 0.55),
    'final_dim': (256, 704),
    'bot_pct_lim': (0.0, 0.0),
    'rot_lim': (0.0, 0.0),
    'H': 900, 'W': 1600,
    'rand_flip': True,
}

# BEV-space augmentation (BDA): random rotation and flips, no scaling.
bda_aug_conf = dict(
    rot_lim=(-22.5, 22.5),
    scale_lim=(1., 1.),
    flip_dx_ratio=0.5,
    flip_dy_ratio=0.5
)

train_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1),
    dict(type='BEVAug', bda_aug_conf=bda_aug_conf, classes=det_class_names, is_train=True),
    dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names)),
    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=True),
    dict(type='DefaultFormatBundle3D', class_names=det_class_names),
    dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'],  # other keys: 'mask_camera'
         meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar'))
]

# Same pipeline without training-time randomness (BEVAug/RandomTransformImage
# run with is_train/training=False).
test_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1, test_mode=True),
    dict(type='BEVAug', bda_aug_conf=bda_aug_conf, classes=det_class_names, is_train=False),
    dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names)),
    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False),
    dict(type='DefaultFormatBundle3D', class_names=det_class_names),
    dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'],
         meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar'))
]

data = dict(
    workers_per_gpu=8,
    train=dict(
        type=dataset_type,
        data_root=dataset_root,
        occ_gt_root=occ_gt_root,
        ann_file=dataset_root + 'nuscenes_infos_train_sweep.pkl',
        pipeline=train_pipeline,
        classes=det_class_names,
        modality=input_modality,
        test_mode=False
    ),
    val=dict(
        type=dataset_type,
        data_root=dataset_root,
        occ_gt_root=occ_gt_root,
        ann_file=dataset_root + 'nuscenes_infos_val_sweep.pkl',
        pipeline=test_pipeline,
        classes=det_class_names,
        modality=input_modality,
        test_mode=True
    ),
    test=dict(
        type=dataset_type,
        data_root=dataset_root,
        occ_gt_root=occ_gt_root,
        ann_file=dataset_root + 'nuscenes_infos_test_sweep.pkl',
        pipeline=test_pipeline,
        classes=det_class_names,
        modality=input_modality,
        test_mode=True
    ),
)

optimizer = dict(
    type='AdamW',
    lr=5e-4,
    paramwise_cfg=dict(
        custom_keys={
            # Lower LR for the pretrained backbone and the sampling offsets.
            'img_backbone': dict(lr_mult=0.1),
            'sampling_offset': dict(lr_mult=0.1),
        }),
    weight_decay=0.01
)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))

# Step LR schedule: linear warmup for 500 iters, then decay by 0.2 at epochs 22 and 24.
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    by_epoch=True,
    step=[22, 24],
    gamma=0.2
)
total_epochs = 24
batch_size = 8  # total batch size; per-GPU share is derived automatically (see README)

# load pretrained weights
load_from = 'pretrain/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth'
# Rename 'backbone.*' keys in the checkpoint to 'img_backbone.*' when loading.
revise_keys = [('backbone', 'img_backbone')]

# resume the last training
resume_from = None

# checkpointing
checkpoint_config = dict(interval=1, max_keep_ckpts=1)

# logging
log_config = dict(
    interval=1,
    hooks=[
        dict(type='MyTextLoggerHook', interval=1, reset_flag=True),
        dict(type='MyTensorboardLoggerHook', interval=500, reset_flag=True)
    ]
)

# evaluation
eval_config = dict(interval=total_epochs)  # evaluate once, at the end of training

# other flags
debug = False
================================================
FILE: configs/r50_nuimg_704x256_8f_60e.py
================================================
# 60-epoch variant of the base config: only the LR schedule and the number of
# epochs change; everything else is inherited from r50_nuimg_704x256_8f.py.
_base_ = ['./r50_nuimg_704x256_8f.py']

# Step LR schedule stretched to 60 epochs (decay by 0.2 at epochs 48 and 60).
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3,
    by_epoch=True,
    step=[48, 60],
    gamma=0.2
)
total_epochs = 60

# evaluation
eval_config = dict(interval=total_epochs)  # evaluate once, at the end of training
================================================
FILE: configs/r50_nuimg_704x256_8f_openocc.py
================================================
# OpenOcc v2 variant of the base config: same model, but trained/evaluated on
# the OpenOcc v2 occupancy GT (without occupancy flow). Note that the pipelines
# here omit the BEVAug step used in the base config, and random flip is off.
_base_ = ['./r50_nuimg_704x256_8f.py']

occ_gt_root = 'data/nuscenes/openocc_v2'  # overrides the Occ3D path from the base config

det_class_names = [
    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]

# OpenOcc v2 classes: 16 semantic categories + 'free'. The ordering differs
# from Occ3D (no 'others' class, vehicles first).
occ_class_names = [
    'car', 'truck', 'trailer', 'bus', 'construction_vehicle',
    'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone', 'barrier',
    'driveable_surface', 'other_flat', 'sidewalk',
    'terrain', 'manmade', 'vegetation', 'free'
]

_num_frames_ = 8

# Only the class count changes in the model; the rest is inherited and merged.
model = dict(
    pts_bbox_head=dict(
        class_names=occ_class_names,
        transformer=dict(
            num_classes=len(occ_class_names)),
        loss_cfgs=dict(
            loss_mask2former=dict(
                num_classes=len(occ_class_names)
            ),
        ),
    ),
)

# Image augmentation; unlike the base config, 'rand_flip' is disabled here.
ida_aug_conf = {
    'resize_lim': (0.38, 0.55),
    'final_dim': (256, 704),
    'bot_pct_lim': (0.0, 0.0),
    'rot_lim': (0.0, 0.0),
    'H': 900, 'W': 1600,
    'rand_flip': False,
}

# Pipelines redefined without the BEVAug step from the base config.
train_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1),
    dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names)),
    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=True),
    dict(type='DefaultFormatBundle3D', class_names=det_class_names),
    dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'],  # other keys: 'mask_camera'
         meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar'))
]
test_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1, test_mode=True),
    dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names)),
    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False),
    dict(type='DefaultFormatBundle3D', class_names=det_class_names),
    dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'],
         meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar'))
]

data = dict(
    workers_per_gpu=8,
    train=dict(
        pipeline=train_pipeline,
        occ_gt_root=occ_gt_root
    ),
    val=dict(
        pipeline=test_pipeline,
        occ_gt_root=occ_gt_root
    ),
    test=dict(
        pipeline=test_pipeline,
        occ_gt_root=occ_gt_root
    ),
)
================================================
FILE: configs/r50_nuimg_704x256_8f_pano.py
================================================
# Panoptic variant of the base config: enables instance prediction in the head
# and loads the panoptic Occ3D GT generated by gen_instance_info.py.
_base_ = ['./r50_nuimg_704x256_8f.py']

occ_gt_root = 'data/nuscenes/occ3d_panoptic'  # produced by gen_instance_info.py

# For nuScenes we usually do 10-class detection
det_class_names = [
    'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
    'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
]

occ_class_names = [
    'others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',
    'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',
    'driveable_surface', 'other_flat', 'sidewalk',
    'terrain', 'manmade', 'vegetation', 'free'
]

_num_frames_ = 8

model = dict(
    pts_bbox_head=dict(
        panoptic=True  # switch the head to panoptic (instance-aware) output
    )
)

ida_aug_conf = {
    'resize_lim': (0.38, 0.55),
    'final_dim': (256, 704),
    'bot_pct_lim': (0.0, 0.0),
    'rot_lim': (0.0, 0.0),
    'H': 900, 'W': 1600,
    'rand_flip': True,
}

bda_aug_conf = dict(
    rot_lim=(-22.5, 22.5),
    scale_lim=(1., 1.),
    flip_dx_ratio=0.5,
    flip_dy_ratio=0.5
)

# inst_class_ids: indices into occ_class_names treated as "thing" classes
# (2=bicycle, 3=bus, 4=car, 5=construction_vehicle, 6=motorcycle,
#  7=pedestrian, 9=trailer, 10=truck).
train_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1),
    dict(type='BEVAug', bda_aug_conf=bda_aug_conf, classes=det_class_names, is_train=True),
    dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names), inst_class_ids=[2, 3, 4, 5, 6, 7, 9, 10]),
    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=True),
    dict(type='DefaultFormatBundle3D', class_names=det_class_names),
    dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'],  # other keys: 'mask_camera'
         meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar'))
]
test_pipeline = [
    dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
    dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=_num_frames_ - 1, test_mode=True),
    dict(type='BEVAug', bda_aug_conf=bda_aug_conf, classes=det_class_names, is_train=False),
    dict(type='LoadOccGTFromFile', num_classes=len(occ_class_names), inst_class_ids=[2, 3, 4, 5, 6, 7, 9, 10]),
    dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False),
    dict(type='DefaultFormatBundle3D', class_names=det_class_names),
    dict(type='Collect3D', keys=['img', 'voxel_semantics', 'voxel_instances', 'instance_class_ids'],
         meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp', 'ego2lidar'))
]

data = dict(
    workers_per_gpu=8,
    train=dict(
        pipeline=train_pipeline,
        occ_gt_root=occ_gt_root
    ),
    val=dict(
        pipeline=test_pipeline,
        occ_gt_root=occ_gt_root
    ),
    test=dict(
        pipeline=test_pipeline,
        occ_gt_root=occ_gt_root
    ),
)
================================================
FILE: gen_instance_info.py
================================================
import os
import tqdm
import glob
import pickle
import argparse
import numpy as np
import torch
import multiprocessing
from pyquaternion import Quaternion
from nuscenes.utils.data_classes import Box
from nuscenes.utils.geometry_utils import points_in_box
# Command-line arguments: dataset roots, output directory, and nuScenes version.
parser = argparse.ArgumentParser()
parser.add_argument('--nusc-root', default='data/nuscenes')
parser.add_argument('--occ3d-root', default='data/nuscenes/occ3d')
parser.add_argument('--output-dir', default='data/nuscenes/occ3d_panoptic')
parser.add_argument('--version', default='v1.0-trainval')
args = parser.parse_args()

# Index every Occ3D ground-truth .npz file by its sample token; the token is
# the parent directory name of each labels file.
# NOTE(review): splitting on '/' assumes POSIX paths — use os.sep for Windows.
token2path = {}
for gt_path in glob.glob(os.path.join(args.occ3d_root, '*/*/*.npz')):
    token = gt_path.split('/')[-2]
    token2path[token] = gt_path

# 18 Occ3D occupancy classes; the last index ('free', id 17) marks empty voxels.
occ_class_names = [
    'others', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',
    'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',
    'driveable_surface', 'other_flat', 'sidewalk',
    'terrain', 'manmade', 'vegetation', 'free'
]

# The 10 nuScenes detection ("thing") classes that get per-object instances.
det_class_names = [
    'car', 'truck', 'trailer', 'bus', 'construction_vehicle',
    'bicycle', 'motorcycle', 'pedestrian', 'traffic_cone', 'barrier'
]
def convert_to_nusc_box(bboxes, lift_center=False, wlh_margin=0.0):
    """Convert an (N, 7+) array of [x, y, z, w, l, h, yaw, ...] rows into
    nuscenes ``Box`` objects.

    Args:
        bboxes: numpy array of box parameters, one row per box.
        lift_center: if True, shift z from the box bottom to its center.
        wlh_margin: padding added to every size dimension; 0.8 in pc range
            is roughly 2 voxels in the occ grid, so a positive margin pulls
            in voxels lying on the box edge.
    """
    boxes = []
    for row in bboxes:
        values = row.copy()
        if lift_center:
            # move the reference point from the bottom face to the box center
            values[2] += values[5] * 0.5
        yaw = -values[6] - np.pi / 2
        rot = Quaternion(axis=[0, 0, 1], radians=yaw).inverse
        boxes.append(Box(
            center=[values[0], values[1], values[2]],
            size=[values[3] + wlh_margin, values[4] + wlh_margin, values[5] + wlh_margin],
            orientation=rot,
        ))
    return boxes
def meshgrid3d(occ_size, pc_range):  # points in ego coord
    """Return the (W, H, D, 3) xyz coordinates of all voxel centers.

    Args:
        occ_size: grid resolution as (W, H, D).
        pc_range: metric bounds [x0, y0, z0, x1, y1, z1] of the grid.
    """
    W, H, D = occ_size
    # normalized voxel-center coordinates in (0, 1)
    grid_x = (torch.arange(W, dtype=torch.float32) + 0.5).view(W, 1, 1).expand(W, H, D) / W
    grid_y = (torch.arange(H, dtype=torch.float32) + 0.5).view(1, H, 1).expand(W, H, D) / H
    grid_z = (torch.arange(D, dtype=torch.float32) + 0.5).view(1, 1, D).expand(W, H, D) / D
    # rescale into the metric point-cloud range
    grid_x = grid_x * (pc_range[3] - pc_range[0]) + pc_range[0]
    grid_y = grid_y * (pc_range[4] - pc_range[1]) + pc_range[1]
    grid_z = grid_z * (pc_range[5] - pc_range[2]) + pc_range[2]
    return torch.stack((grid_x, grid_y, grid_z), dim=-1)
def process_add_instance_info(sample):
    """Build a panoptic instance grid for one sample and save it alongside
    the semantic occupancy GT.

    Assigns instance IDs in three passes:
      1. voxels inside each GT detection box (same class, first box wins);
      2. each remaining "stuff" class becomes a single instance;
      3. leftover thing-class voxels are attached to the nearest enlarged box
         of the same class.
    The result is written to args.output_dir mirroring the occ3d layout.

    NOTE(review): relies on the module-level globals token2path, args,
    occ_class_names and det_class_names.
    """
    point_cloud_range = [-40, -40, -1.0, 40, 40, 5.4]
    occ_size = [200, 200, 16]
    num_classes = 18

    occ_gt_path = token2path[sample['token']]
    occ_labels = np.load(occ_gt_path)
    occ_gt = occ_labels['semantics']

    gt_boxes = sample['gt_boxes']
    gt_names = sample['gt_names']
    bboxes = convert_to_nusc_box(gt_boxes)

    # 0 = "not assigned yet"; real instance IDs start from 1
    instance_gt = np.zeros(occ_gt.shape).astype(np.uint8)
    instance_id = 1
    pts = meshgrid3d(occ_size, point_cloud_range).numpy()

    # filter out free voxels to accelerate (free is the last class id)
    valid_idx = np.where(occ_gt < num_classes - 1)
    flatten_occ_gt = occ_gt[valid_idx]
    flatten_inst_gt = instance_gt[valid_idx]
    flatten_pts = pts[valid_idx]

    instance_boxes = []
    instance_class_ids = []

    # Pass 1: assign instance IDs from GT detection boxes.
    for i in range(len(gt_names)):
        if gt_names[i] not in occ_class_names:
            continue
        occ_tag_id = occ_class_names.index(gt_names[i])
        # Move box from lidar to ego vehicle coord system
        bbox = bboxes[i]
        bbox.rotate(Quaternion(sample['lidar2ego_rotation']))
        bbox.translate(np.array(sample['lidar2ego_translation']))
        mask = points_in_box(bbox, flatten_pts.transpose(1, 0))
        # ignore voxels not belonging to this class
        mask[mask] = (flatten_occ_gt[mask] == occ_tag_id)
        # ignore voxels already claimed by an earlier box
        mask[mask] = (flatten_inst_gt[mask] == 0)
        # only instances with at least 1 voxel are recorded
        if mask.sum() > 0:
            flatten_inst_gt[mask] = instance_id
            instance_id += 1
            # enlarge the box (0.8 ~ 2 voxels) to catch edge voxels in pass 3
            new_box = bbox.copy()
            new_box.wlh = new_box.wlh + 0.8
            instance_boxes.append(new_box)
            instance_class_ids.append(occ_tag_id)

    # Pass 2: each "stuff" class present in the grid is one single instance.
    all_class_ids_unique = np.unique(occ_gt)
    for i, class_name in enumerate(occ_class_names):
        if class_name in det_class_names or class_name == 'free' or i not in all_class_ids_unique:
            continue
        flatten_inst_gt[flatten_occ_gt == i] = instance_id
        instance_id += 1

    # Pass 3: attach still-uncovered non-free voxels to the nearest enlarged
    # box of the same class.
    uncover_idx = np.where(flatten_inst_gt == 0)
    uncover_pts = flatten_pts[uncover_idx]
    uncover_inst_gt = np.zeros_like(uncover_pts[..., 0]).astype(np.uint8)
    unconver_occ_gt = flatten_occ_gt[uncover_idx]
    # distance from each voxel to its currently-nearest box center
    uncover_inst_dist = np.ones_like(uncover_pts[..., 0]) * 1e8
    for i, box in enumerate(instance_boxes):
        # important: non-background instance IDs start from 1
        inst_id = i + 1
        class_id = instance_class_ids[i]
        mask = points_in_box(box, uncover_pts.transpose(1, 0))
        # drop voxels of a different class
        mask[unconver_occ_gt != class_id] = False
        dist = np.sum((box.center - uncover_pts) ** 2, axis=-1)
        # drop voxels already assigned to a closer box; this only flips True
        # entries to False, so it is equivalent to
        # mask[mask & (dist >= uncover_inst_dist)] = False
        mask[dist >= uncover_inst_dist] = False
        # update nearest distance / instance only for the surviving voxels
        uncover_inst_dist[mask] = dist[mask]
        uncover_inst_gt[mask] = inst_id
    flatten_inst_gt[uncover_idx] = uncover_inst_gt

    # scatter the flat assignments back into the dense grid
    instance_gt[valid_idx] = flatten_inst_gt

    # Save next to the original layout (scene/token/labels.npz).
    data_split = occ_gt_path.split(os.path.sep)[-3:]
    data_path = os.path.sep.join(data_split)
    # NOTE(review): using the global `args` inside a multiprocessing worker is
    # fragile; prefer passing output_dir explicitly.
    save_path = os.path.join(args.output_dir, data_path)
    save_dir = os.path.split(save_path)[0]
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    if np.unique(instance_gt).shape[0] != instance_gt.max()+1:
        print('warning: some instance masks are covered by following ones %s'%(save_dir))
    # only semantic and mask information needs to be retained
    retain_keys = ['semantics', 'mask_lidar', 'mask_camera']
    new_occ_labels = {k: occ_labels[k] for k in retain_keys}
    new_occ_labels['instances'] = instance_gt
    np.savez_compressed(save_path, **new_occ_labels)
def add_instance_info(sample_infos):
    """Run process_add_instance_info over every sample in parallel, showing
    a progress bar; blocks until all workers have finished."""
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    infos = sample_infos['infos']
    # one worker per available CPU core
    pool = multiprocessing.Pool(multiprocessing.cpu_count())
    progress = tqdm.tqdm(total=len(infos))
    for _ in pool.imap(process_add_instance_info, infos):
        progress.update(1)
    progress.close()
    pool.close()
    pool.join()
if __name__ == '__main__':
    # Map the requested nuScenes version to the sweep-info files that need
    # instance annotations added.
    if args.version == 'v1.0-trainval':
        info_names = ['nuscenes_infos_train_sweep.pkl', 'nuscenes_infos_val_sweep.pkl']
    elif args.version == 'v1.0-test':
        info_names = ['nuscenes_infos_test_sweep.pkl']
    else:
        # original raised a bare ValueError; include the offending value
        raise ValueError('unsupported version: %s' % args.version)
    for name in info_names:
        # use a context manager so the pickle file handle is closed promptly
        # (the original left `open(...)` handles dangling)
        with open(os.path.join(args.nusc_root, name), 'rb') as f:
            sample_infos = pickle.load(f)
        add_instance_info(sample_infos)
================================================
FILE: gen_sweep_info.py
================================================
# Generate info files manually
import os
import mmcv
import tqdm
import pickle
import argparse
import numpy as np
from nuscenes import NuScenes
from pyquaternion import Quaternion
# Command-line arguments: nuScenes data root and dataset version to process.
parser = argparse.ArgumentParser()
parser.add_argument('--data-root', default='data/nuscenes')
parser.add_argument('--version', default='v1.0-trainval')
args = parser.parse_args()
def get_cam_info(nusc, sample_data):
    """Collect path, sensor-to-global transform, intrinsics and timestamp for
    one camera ``sample_data`` record.

    Returns a dict with keys: data_path, sensor2global_rotation,
    sensor2global_translation, cam_intrinsic, timestamp.
    """
    pose_record = nusc.get('ego_pose', sample_data['ego_pose_token'])
    cs_record = nusc.get('calibrated_sensor', sample_data['calibrated_sensor_token'])
    sensor2ego_translation = cs_record['translation']
    ego2global_translation = pose_record['translation']
    sensor2ego_rotation = Quaternion(cs_record['rotation']).rotation_matrix
    ego2global_rotation = Quaternion(pose_record['rotation']).rotation_matrix
    cam_intrinsic = np.array(cs_record['camera_intrinsic'])
    # NOTE(review): R_se^T @ R_eg^T == (R_eg @ R_se)^T, i.e. the TRANSPOSE of
    # the sensor->global rotation — presumably consumed in a row-vector
    # (p @ R) convention; confirm against the loader that reads this key.
    sensor2global_rotation = sensor2ego_rotation.T @ ego2global_rotation.T
    # t_se @ R_eg^T + t_eg == R_eg @ t_se + t_eg, the sensor origin in global coords
    sensor2global_translation = sensor2ego_translation @ ego2global_rotation.T + ego2global_translation
    return {
        'data_path': os.path.join(args.data_root, sample_data['filename']),
        'sensor2global_rotation': sensor2global_rotation,
        'sensor2global_translation': sensor2global_translation,
        'cam_intrinsic': cam_intrinsic,
        'timestamp': sample_data['timestamp'],
    }
def add_sweep_info(nusc, sample_infos):
    """Augment each keyframe info with updated camera info and up to 5
    intermediate camera sweeps toward the previous keyframe.

    Mutates sample_infos['infos'][i]['cams'] in place and sets a new
    'sweeps' list (ordered most-recent first); returns sample_infos.
    """
    for curr_id in tqdm.tqdm(range(len(sample_infos['infos']))):
        sample = nusc.get('sample', sample_infos['infos'][curr_id]['token'])
        cam_types = [
            'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_RIGHT',
            'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_FRONT_LEFT'
        ]
        # current sample_data record per camera; walked backwards below
        curr_cams = dict()
        for cam in cam_types:
            curr_cams[cam] = nusc.get('sample_data', sample['data'][cam])
        # refresh the keyframe camera info with global-frame poses
        for cam in cam_types:
            sample_data = nusc.get('sample_data', sample['data'][cam])
            sweep_cam = get_cam_info(nusc, sample_data)
            sample_infos['infos'][curr_id]['cams'][cam].update(sweep_cam)
        # remove per-camera fields made redundant by sensor2global_*
        for cam in cam_types:
            del sample_infos['infos'][curr_id]['cams'][cam]['sensor2ego_translation']
            del sample_infos['infos'][curr_id]['cams'][cam]['sensor2ego_rotation']
            del sample_infos['infos'][curr_id]['cams'][cam]['ego2global_translation']
            del sample_infos['infos'][curr_id]['cams'][cam]['ego2global_rotation']
        sweep_infos = []
        if sample['prev'] != '': # add sweep frames between two keyframes
            for _ in range(5):
                sweep_info = dict()
                for cam in cam_types:
                    if curr_cams[cam]['prev'] == '':
                        # no earlier sweep for this camera: reuse the previous
                        # sweep entry wholesale.
                        # NOTE(review): if this triggers on the FIRST iteration
                        # sweep_infos is empty and this raises IndexError —
                        # presumably cannot happen when sample['prev'] exists;
                        # confirm against the dataset.
                        sweep_info = sweep_infos[-1]
                        break
                    sample_data = nusc.get('sample_data', curr_cams[cam]['prev'])
                    sweep_cam = get_cam_info(nusc, sample_data)
                    curr_cams[cam] = sample_data
                    sweep_info[cam] = sweep_cam
                sweep_infos.append(sweep_info)
        sample_infos['infos'][curr_id]['sweeps'] = sweep_infos
    return sample_infos
if __name__ == '__main__':
    nusc = NuScenes(args.version, args.data_root)
    # Map dataset version to the base info files to augment with sweep data.
    if args.version == 'v1.0-trainval':
        info_names = ['nuscenes_infos_train', 'nuscenes_infos_val']
    elif args.version == 'v1.0-test':
        info_names = ['nuscenes_infos_test']
    else:
        # original raised a bare ValueError; include the offending value
        raise ValueError('unsupported version: %s' % args.version)
    for name in info_names:
        in_path = os.path.join(args.data_root, name + '.pkl')
        out_path = os.path.join(args.data_root, name + '_sweep.pkl')
        # use a context manager so the pickle file handle is closed promptly
        # (the original left `open(...)` handles dangling)
        with open(in_path, 'rb') as f:
            sample_infos = pickle.load(f)
        sample_infos = add_sweep_info(nusc, sample_infos)
        mmcv.dump(sample_infos, out_path)
================================================
FILE: lib/dvr/dvr.cpp
================================================
// Acknowledgments: https://github.com/tarashakhurana/4d-occ-forecasting
// Modified by Haisong Liu
#include
#include
#include
/*
 * CUDA forward declarations (implemented in dvr.cu).
 * NOTE(review): template arguments appear to have been stripped during
 * extraction (likely std::vector<torch::Tensor> return types and
 * std::vector<int> grid parameters); restore before compiling.
 */
std::vector render_forward_cuda(torch::Tensor sigma,
                                torch::Tensor origin,
                                torch::Tensor points,
                                torch::Tensor tindex,
                                const std::vector grid,
                                std::string phase_name);

std::vector
render_cuda(torch::Tensor sigma, torch::Tensor origin, torch::Tensor points,
            torch::Tensor tindex, std::string loss_name);

torch::Tensor init_cuda(torch::Tensor points, torch::Tensor tindex,
                        const std::vector grid);
/*
 * C++ interface
 */
// Input-validation helpers: every tensor handed to the CUDA kernels must
// live on the GPU and be contiguous in memory.
#define CHECK_CUDA(x) \
    TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
    TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x); \
    CHECK_CONTIGUOUS(x)
// Forward-only rendering entry point: validates inputs then dispatches to
// the CUDA implementation. Returns {pred_dist, gt_dist, coord_index}.
// NOTE(review): 'std::vector' lost its template arguments in extraction.
std::vector
render_forward(torch::Tensor sigma, torch::Tensor origin, torch::Tensor points,
               torch::Tensor tindex, const std::vector grid,
               std::string phase_name) {
    CHECK_INPUT(sigma);
    CHECK_INPUT(origin);
    CHECK_INPUT(points);
    CHECK_INPUT(tindex);
    return render_forward_cuda(sigma, origin, points, tindex, grid, phase_name);
}
// Differentiable rendering entry point (forward + gradient w.r.t. sigma):
// validates inputs then dispatches to the CUDA implementation.
// Returns {pred_dist, gt_dist, grad_sigma}.
std::vector render(torch::Tensor sigma, torch::Tensor origin,
                   torch::Tensor points, torch::Tensor tindex,
                   std::string loss_name) {
    CHECK_INPUT(sigma);
    CHECK_INPUT(origin);
    CHECK_INPUT(points);
    CHECK_INPUT(tindex);
    return render_cuda(sigma, origin, points, tindex, loss_name);
}
// Build an initial occupancy grid: a voxel is marked occupied when at least
// one input point falls inside it. Validates inputs, dispatches to CUDA.
torch::Tensor init(torch::Tensor points, torch::Tensor tindex,
                   const std::vector grid) {
    CHECK_INPUT(points);
    CHECK_INPUT(tindex);
    return init_cuda(points, tindex, grid);
}
// Python bindings: expose init / render / render_forward as a torch extension.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("init", &init, "Initialize");
    m.def("render", &render, "Render");
    m.def("render_forward", &render_forward, "Render (forward pass only)");
}
================================================
FILE: lib/dvr/dvr.cu
================================================
// Acknowledgments: https://github.com/tarashakhurana/4d-occ-forecasting
// Modified by Haisong Liu
#include
#include
#include
#include
#include
#include
#include
// Upper bound on voxels a single ray can traverse inside the grid
// (700 + 700 + 45 + 1 — the grid's extents plus one).
#define MAX_D 1446
// Hard cap on total traversal iterations, guaranteeing loop termination.
#define MAX_STEP 1000

// Distance losses supported by the backward pass in render_cuda_kernel.
enum LossType {L1, L2, ABSREL};
// Rendering phase: in TRAIN the GT distance is clamped to the grid boundary.
enum PhaseName {TEST, TRAIN};
// Mark as occupied every voxel that contains at least one input point.
// One thread per ray/point; grid y-dimension indexes the batch.
// NOTE(review): the template parameter list and PackedTensorAccessor32
// template arguments were stripped during extraction; restore before compiling.
template
__global__ void init_cuda_kernel(
    const torch::PackedTensorAccessor32 points,
    const torch::PackedTensorAccessor32 tindex,
    torch::PackedTensorAccessor32 occupancy) {

    // batch index
    const auto n = blockIdx.y;
    // ray index
    const auto c = blockIdx.x * blockDim.x + threadIdx.x;
    // num of rays
    const auto M = points.size(1);
    const auto T = occupancy.size(1);

    // we allocated more threads than num_rays
    if (c < M) {
        // time index of this point's frame
        const auto t = tindex[n][c];

        // invalid points
        assert(T == 1 || t < T);

        // if t < 0, it is a padded point
        if (t < 0) return;

        // time index for the occupancy volume; when T = 1 it is static
        const auto ts = (T == 1) ? 0 : t;

        // grid shape
        const int vzsize = occupancy.size(2);
        const int vysize = occupancy.size(3);
        const int vxsize = occupancy.size(4);

        // end point: points are already in voxel coordinates, truncate to cell
        const int vx = int(points[n][c][0]);
        const int vy = int(points[n][c][1]);
        const int vz = int(points[n][c][2]);

        // write only if the point lands inside the grid
        if (0 <= vx && vx < vxsize &&
            0 <= vy && vy < vysize &&
            0 <= vz && vz < vzsize) {
            occupancy[n][ts][vz][vy][vx] = 1;
        }
    }
}
// Forward-only ray rendering: march each ray through the sigma volume with a
// voxel-traversal walk (Amanatides–Woo style tMax/tDelta stepping) and report
// the distance at which the ray first hits a voxel with sigma > 0.5, the
// (possibly clamped) ground-truth distance, and the hit voxel's coordinates.
// One thread per ray; grid y-dimension indexes the batch.
// NOTE(review): the template parameter list, accessor template arguments and
// the <<<blocks, threads>>> launch syntax were stripped during extraction.
template
__global__ void render_forward_cuda_kernel(
    const torch::PackedTensorAccessor32 sigma,
    const torch::PackedTensorAccessor32 origin,
    const torch::PackedTensorAccessor32 points,
    const torch::PackedTensorAccessor32 tindex,
    torch::PackedTensorAccessor32 pred_dist,
    torch::PackedTensorAccessor32 gt_dist,
    torch::PackedTensorAccessor32 coord_index,
    PhaseName train_phase) {

    // batch index
    const auto n = blockIdx.y;
    // ray index
    const auto c = blockIdx.x * blockDim.x + threadIdx.x;
    // num of rays
    const auto M = points.size(1);
    const auto T = sigma.size(1);

    // we allocated more threads than num_rays
    if (c < M) {
        // time index of this ray's frame
        const auto t = tindex[n][c];

        // invalid points
        assert(T == 1 || t < T);

        // time index for sigma; when T = 1 we have a static sigma
        const auto ts = (T == 1) ? 0 : t;

        // if t < 0, it is a padded point
        if (t < 0) return;

        // grid shape
        const int vzsize = sigma.size(2);
        const int vysize = sigma.size(3);
        const int vxsize = sigma.size(4);

        // ray origin (voxel coordinates)
        const double xo = origin[n][t][0];
        const double yo = origin[n][t][1];
        const double zo = origin[n][t][2];

        // ray end point (voxel coordinates)
        const double xe = points[n][c][0];
        const double ye = points[n][c][1];
        const double ze = points[n][c][2];

        // voxel where the origin resides
        const int vxo = int(xo);
        const int vyo = int(yo);
        const int vzo = int(zo);

        // voxel where the end point resides (kept for reference)
        const int vxe = int(xe);
        const int vye = int(ye);
        const int vze = int(ze);

        // current voxel, starting at the origin
        int vx = vxo;
        int vy = vyo;
        int vz = vzo;

        // origin-to-end vector; gt_d is the measured ray length
        const double rx = xe - xo;
        const double ry = ye - yo;
        const double rz = ze - zo;
        double gt_d = sqrt(rx * rx + ry * ry + rz * rz);

        // unit direction
        const double dx = rx / gt_d;
        const double dy = ry / gt_d;
        const double dz = rz / gt_d;

        // in which direction each voxel index is incremented
        const int stepX = (dx >= 0) ? 1 : -1;
        const int stepY = (dy >= 0) ? 1 : -1;
        const int stepZ = (dz >= 0) ? 1 : -1;

        // coordinates of the next voxel border in each axis
        const double next_voxel_boundary_x = vx + (stepX < 0 ? 0 : 1);
        const double next_voxel_boundary_y = vy + (stepY < 0 ? 0 : 1);
        const double next_voxel_boundary_z = vz + (stepZ < 0 ? 0 : 1);

        // tMax*: ray parameter at the first crossing of each axis boundary
        double tMaxX = (dx!=0) ? (next_voxel_boundary_x - xo)/dx : DBL_MAX;
        double tMaxY = (dy!=0) ? (next_voxel_boundary_y - yo)/dy : DBL_MAX;
        double tMaxZ = (dz!=0) ? (next_voxel_boundary_z - zo)/dz : DBL_MAX;

        // tDelta*: ray distance needed to cross one whole voxel per axis
        // (DBL_MAX when the ray never moves along that axis)
        const double tDeltaX = (dx!=0) ? stepX/dx : DBL_MAX;
        const double tDeltaY = (dy!=0) ? stepY/dy : DBL_MAX;
        const double tDeltaZ = (dz!=0) ? stepZ/dz : DBL_MAX;

        int3 path[MAX_D];   // voxels visited inside the grid
        double csd[MAX_D];  // cumulative sum of sigma times delta
        double p[MAX_D];    // alpha (per-voxel termination probability)
        double d[MAX_D];    // exit distance per visited voxel

        // forward raymarching with voxel traversal
        int step = 0;        // total number of voxels traversed
        int count = 0;       // voxels traversed inside the voxel grid
        double last_d = 0.0; // distance traveled before the current voxel

        bool was_inside = false;
        while (true) {
            bool inside = (0 <= vx && vx < vxsize) &&
                (0 <= vy && vy < vysize) &&
                (0 <= vz && vz < vzsize);
            if (inside) {
                was_inside = true;
                path[count] = make_int3(vx, vy, vz);
            } else if (was_inside) { // was but no longer inside
                // we know we are not coming back so terminate
                break;
            }
            // (an early-exit branch for rays that never enter the grid was
            // removed upstream; termination relies on MAX_STEP below)

            // _d: distance traveled when escaping the current voxel
            double _d = 0.0;
            // step into the axis with the nearest boundary crossing
            if (tMaxX < tMaxY) {
                if (tMaxX < tMaxZ) {
                    _d = tMaxX;
                    vx += stepX;
                    tMaxX += tDeltaX;
                } else {
                    _d = tMaxZ;
                    vz += stepZ;
                    tMaxZ += tDeltaZ;
                }
            } else {
                if (tMaxY < tMaxZ) {
                    _d = tMaxY;
                    vy += stepY;
                    tMaxY += tDeltaY;
                } else {
                    _d = tMaxZ;
                    vz += stepZ;
                    tMaxZ += tDeltaZ;
                }
            }

            if (inside) {
                // accumulate transmittance terms for the voxel just escaped
                const int3 &v = path[count]; // use the recorded index
                const double _sigma = sigma[n][ts][v.z][v.y][v.x];
                const double _delta = max(0.0, _d - last_d); // THIS TURNS OUT IMPORTANT
                const double sd = _sigma * _delta;
                if (count == 0) { // the first voxel inside
                    csd[count] = sd;
                    p[count] = 1 - exp(-sd);
                } else {
                    csd[count] = csd[count-1] + sd;
                    p[count] = exp(-csd[count-1]) - exp(-csd[count]);
                }
                // record the traveled distance
                d[count] = _d;
                // count the number of voxels we have escaped
                count ++;
            }
            last_d = _d;
            step ++;
            if (step > MAX_STEP) {
                break;
            }
        }

        // the total number of voxels visited should not exceed this number
        assert(count <= MAX_D);

        if (count > 0) {
            // default prediction: the ray exits at the last voxel traversed
            double exp_d = d[count-1];
            const int3 &v_init = path[count-1];
            int x = v_init.x;
            int y = v_init.y;
            int z = v_init.z;
            // find the FIRST voxel along the ray whose sigma exceeds 0.5
            // (treated as occupied) and report its exit distance and coords
            for (int i = 0; i < count; i++) {
                const int3 &v = path[i];
                const double occ = sigma[n][ts][v.z][v.y][v.x];
                if (occ > 0.5) {
                    exp_d = d[i];
                    x = v.x;
                    y = v.y;
                    z = v.z;
                    break;
                }
            }

            // p_out: probability the ray escapes the grid; max_d: grid exit distance
            double p_out = exp(-csd[count-1]);
            double max_d = d[count-1];
            // during training, clamp the GT distance to the grid boundary
            if (train_phase == 1) {
                gt_d = min(gt_d, max_d);
            }
            // write outputs for this ray
            pred_dist[n][c] = exp_d;
            gt_dist[n][c] = gt_d;
            coord_index[n][c][0] = double(x);
            coord_index[n][c][1] = double(y);
            coord_index[n][c][2] = double(z);
        }
    }
}
/*
 * Host launcher for the forward-only renderer.
 * input shape
 *   sigma  : N x T x H x L x W
 *   origin : N x T x 3
 *   points : N x M x 4
 * output shape
 *   dist   : N x M
 * Returns {pred_dist, gt_dist, coord_index}; distances default to -1 for
 * rays that never intersect the grid.
 * NOTE(review): template arguments and the <<<blocks, threads>>> launch
 * configuration were stripped during extraction; restore before compiling.
 */
std::vector render_forward_cuda(
    torch::Tensor sigma,
    torch::Tensor origin,
    torch::Tensor points,
    torch::Tensor tindex,
    const std::vector grid,
    std::string phase_name) {

    const auto N = points.size(0); // batch size
    const auto M = points.size(1); // num of rays
    const auto T = grid[0];
    const auto H = grid[1];
    const auto L = grid[2];
    const auto W = grid[3];

    const auto device = sigma.device();
    // one thread per ray, one block row per batch element
    const int threads = 1024;
    const dim3 blocks((M + threads - 1) / threads, N);

    // -1 marks rays that never hit the grid
    auto gt_dist = -torch::ones({N, M}, device);
    auto pred_dist = -torch::ones({N, M}, device);
    auto coord_index = torch::zeros({N, M, 3}, device);

    PhaseName train_phase;
    if (phase_name.compare("test") == 0) {
        train_phase = TEST;
    } else if (phase_name.compare("train") == 0){
        train_phase = TRAIN;
    } else {
        std::cout << "UNKNOWN PHASE NAME: " << phase_name << std::endl;
        exit(1);
    }

    AT_DISPATCH_FLOATING_TYPES(sigma.type(), "render_forward_cuda", ([&] {
        render_forward_cuda_kernel<<>>(
            sigma.packed_accessor32(),
            origin.packed_accessor32(),
            points.packed_accessor32(),
            tindex.packed_accessor32(),
            pred_dist.packed_accessor32(),
            gt_dist.packed_accessor32(),
            coord_index.packed_accessor32(),
            train_phase);
    }));

    // wait for all rays before returning results to the caller
    cudaDeviceSynchronize();

    return {pred_dist, gt_dist, coord_index};
}
// Differentiable ray rendering: march each ray through sigma (same voxel
// traversal as the forward kernel), compute the expected termination
// distance, and accumulate the analytic gradient of the chosen distance
// loss w.r.t. sigma directly into grad_sigma.
// One thread per ray; grid y-dimension indexes the batch.
// NOTE(review): the template parameter list and accessor template arguments
// were stripped during extraction; restore before compiling.
template
__global__ void render_cuda_kernel(
    const torch::PackedTensorAccessor32 sigma,
    const torch::PackedTensorAccessor32 origin,
    const torch::PackedTensorAccessor32 points,
    const torch::PackedTensorAccessor32 tindex,
    torch::PackedTensorAccessor32 pred_dist,
    torch::PackedTensorAccessor32 gt_dist,
    torch::PackedTensorAccessor32 grad_sigma,
    LossType loss_type) {

    // batch index
    const auto n = blockIdx.y;
    // ray index
    const auto c = blockIdx.x * blockDim.x + threadIdx.x;
    // num of rays
    const auto M = points.size(1);
    const auto T = sigma.size(1);

    // we allocated more threads than num_rays
    if (c < M) {
        // time index of this ray's frame
        const auto t = tindex[n][c];

        // invalid points
        assert(T == 1 || t < T);

        // time index for sigma; when T = 1 we have a static sigma
        const auto ts = (T == 1) ? 0 : t;

        // if t < 0, it is a padded point
        if (t < 0) return;

        // grid shape
        const int vzsize = sigma.size(2);
        const int vysize = sigma.size(3);
        const int vxsize = sigma.size(4);

        // ray origin (voxel coordinates)
        const double xo = origin[n][t][0];
        const double yo = origin[n][t][1];
        const double zo = origin[n][t][2];

        // ray end point (voxel coordinates)
        const double xe = points[n][c][0];
        const double ye = points[n][c][1];
        const double ze = points[n][c][2];

        // voxel where the origin resides
        const int vxo = int(xo);
        const int vyo = int(yo);
        const int vzo = int(zo);

        // voxel where the end point resides (kept for reference)
        const int vxe = int(xe);
        const int vye = int(ye);
        const int vze = int(ze);

        // current voxel, starting at the origin
        int vx = vxo;
        int vy = vyo;
        int vz = vzo;

        // origin-to-end vector; gt_d is the measured ray length
        const double rx = xe - xo;
        const double ry = ye - yo;
        const double rz = ze - zo;
        double gt_d = sqrt(rx * rx + ry * ry + rz * rz);

        // unit direction
        const double dx = rx / gt_d;
        const double dy = ry / gt_d;
        const double dz = rz / gt_d;

        // in which direction each voxel index is incremented
        const int stepX = (dx >= 0) ? 1 : -1;
        const int stepY = (dy >= 0) ? 1 : -1;
        const int stepZ = (dz >= 0) ? 1 : -1;

        // coordinates of the next voxel border in each axis
        const double next_voxel_boundary_x = vx + (stepX < 0 ? 0 : 1);
        const double next_voxel_boundary_y = vy + (stepY < 0 ? 0 : 1);
        const double next_voxel_boundary_z = vz + (stepZ < 0 ? 0 : 1);

        // tMax*: ray parameter at the first crossing of each axis boundary
        double tMaxX = (dx!=0) ? (next_voxel_boundary_x - xo)/dx : DBL_MAX;
        double tMaxY = (dy!=0) ? (next_voxel_boundary_y - yo)/dy : DBL_MAX;
        double tMaxZ = (dz!=0) ? (next_voxel_boundary_z - zo)/dz : DBL_MAX;

        // tDelta*: ray distance needed to cross one whole voxel per axis
        const double tDeltaX = (dx!=0) ? stepX/dx : DBL_MAX;
        const double tDeltaY = (dy!=0) ? stepY/dy : DBL_MAX;
        const double tDeltaZ = (dz!=0) ? stepZ/dz : DBL_MAX;

        int3 path[MAX_D];   // voxels visited inside the grid
        double csd[MAX_D];  // cumulative sum of sigma times delta
        double p[MAX_D];    // alpha (per-voxel termination probability)
        double d[MAX_D];    // exit distance per visited voxel
        double dt[MAX_D];   // path length inside each visited voxel

        // forward raymarching with voxel traversal
        int step = 0;        // total number of voxels traversed
        int count = 0;       // voxels traversed inside the voxel grid
        double last_d = 0.0; // distance traveled before the current voxel

        bool was_inside = false;
        while (true) {
            bool inside = (0 <= vx && vx < vxsize) &&
                (0 <= vy && vy < vysize) &&
                (0 <= vz && vz < vzsize);
            if (inside) { // now inside
                was_inside = true;
                path[count] = make_int3(vx, vy, vz);
            } else if (was_inside) { // was inside but no longer
                // we know we are not coming back so terminate
                break;
            } else if (last_d > gt_d) {
                // still outside and already past the target: give up
                break;
            }

            // _d: distance traveled when escaping the current voxel
            double _d = 0.0;
            // step into the axis with the nearest boundary crossing
            if (tMaxX < tMaxY) {
                if (tMaxX < tMaxZ) {
                    _d = tMaxX;
                    vx += stepX;
                    tMaxX += tDeltaX;
                } else {
                    _d = tMaxZ;
                    vz += stepZ;
                    tMaxZ += tDeltaZ;
                }
            } else {
                if (tMaxY < tMaxZ) {
                    _d = tMaxY;
                    vy += stepY;
                    tMaxY += tDeltaY;
                } else {
                    _d = tMaxZ;
                    vz += stepZ;
                    tMaxZ += tDeltaZ;
                }
            }

            if (inside) {
                // accumulate transmittance terms for the voxel just escaped
                const int3 &v = path[count]; // use the recorded index
                const double _sigma = sigma[n][ts][v.z][v.y][v.x];
                const double _delta = max(0.0, _d - last_d); // THIS TURNS OUT IMPORTANT
                const double sd = _sigma * _delta;
                if (count == 0) { // the first voxel inside
                    csd[count] = sd;
                    p[count] = 1 - exp(-sd);
                } else {
                    csd[count] = csd[count-1] + sd;
                    p[count] = exp(-csd[count-1]) - exp(-csd[count]);
                }
                // record the traveled distance and per-voxel path length
                d[count] = _d;
                dt[count] = _delta;
                // count the number of voxels we have escaped
                count ++;
            }
            last_d = _d;
            step ++;
            if (step > MAX_STEP) {
                break;
            }
        }

        // the total number of voxels visited should not exceed this number
        assert(count <= MAX_D);

        // WHEN THERE IS AN INTERSECTION BETWEEN THE RAY AND THE VOXEL GRID
        if (count > 0) {
            // expected ray distance: sum of per-voxel alpha times exit distance
            double exp_d = 0.0;
            for (int i = 0; i < count; i ++)
                exp_d += p[i] * d[i];

            // imaginary sample at the grid boundary for rays that escape;
            // p_out is the probability the ray leaves the grid unabsorbed
            double p_out = exp(-csd[count-1]);
            double max_d = d[count-1];
            exp_d += (p_out * max_d);
            // clamp the GT distance to the grid boundary
            gt_d = min(gt_d, max_d);

            // write the rendered and target distances
            pred_dist[n][c] = exp_d;
            gt_dist[n][c] = gt_d;

            /* backward raymarching: d(exp_d)/d(sigma_i) for each voxel i */
            double dd_dsigma[MAX_D];
            for (int i = count - 1; i >= 0; i --) {
                // NOTE: probably need to double check again
                if (i == count - 1)
                    dd_dsigma[i] = p_out * max_d;
                else
                    dd_dsigma[i] = dd_dsigma[i+1] - exp(-csd[i]) * (d[i+1] - d[i]);
            }

            // scale by each voxel's path length (chain rule through sd = sigma*delta)
            for (int i = count - 1; i >= 0; i --)
                dd_dsigma[i] *= dt[i];

            // option 2: cap at the boundary
            for (int i = count - 1; i >= 0; i --)
                dd_dsigma[i] -= dt[i] * p_out * max_d;

            // derivative of the selected loss w.r.t. the predicted distance
            double dl_dd = 1.0;
            if (loss_type == L1)
                dl_dd = (exp_d >= gt_d) ? 1 : -1;
            else if (loss_type == L2)
                dl_dd = (exp_d - gt_d);
            else if (loss_type == ABSREL)
                dl_dd = (exp_d >= gt_d) ? (1.0/gt_d) : -(1.0/gt_d);

            // apply chain rule
            for (int i = 0; i < count; i ++) {
                const int3 &v = path[i];
                // NOTE: potential race conditions when writing gradients
                grad_sigma[n][ts][v.z][v.y][v.x] += dl_dd * dd_dsigma[i];
            }
        }
    }
}
/*
 * Host launcher for the differentiable renderer.
 * input shape
 *   sigma  : N x T x H x L x W
 *   origin : N x T x 3
 *   points : N x M x 4
 * output shape
 *   dist       : N x M
 *   loss       : N x M
 *   grad_sigma : N x T x H x L x W
 * NOTE(review): template arguments and the <<<blocks, threads>>> launch
 * configuration were stripped during extraction; restore before compiling.
 */
std::vector render_cuda(
    torch::Tensor sigma,
    torch::Tensor origin,
    torch::Tensor points,
    torch::Tensor tindex,
    std::string loss_name) {

    const auto N = points.size(0); // batch size
    const auto M = points.size(1); // num of rays

    const auto device = sigma.device();
    // one thread per ray, one block row per batch element
    const int threads = 1024;
    const dim3 blocks((M + threads - 1) / threads, N);

    // -1 marks rays that never hit the grid
    auto gt_dist = -torch::ones({N, M}, device);
    auto pred_dist = -torch::ones({N, M}, device);
    auto grad_sigma = torch::zeros_like(sigma);

    LossType loss_type;
    if (loss_name.compare("l1") == 0) {
        loss_type = L1;
    } else if (loss_name.compare("l2") == 0) {
        loss_type = L2;
    } else if (loss_name.compare("absrel") == 0) {
        loss_type = ABSREL;
    } else if (loss_name.compare("bce") == 0){
        // "bce" falls back to the L1 gradient
        loss_type = L1;
    } else {
        std::cout << "UNKNOWN LOSS TYPE: " << loss_name << std::endl;
        exit(1);
    }

    AT_DISPATCH_FLOATING_TYPES(sigma.type(), "render_cuda", ([&] {
        render_cuda_kernel<<>>(
            sigma.packed_accessor32(),
            origin.packed_accessor32(),
            points.packed_accessor32(),
            tindex.packed_accessor32(),
            pred_dist.packed_accessor32(),
            gt_dist.packed_accessor32(),
            grad_sigma.packed_accessor32(),
            loss_type);
    }));

    // wait for all rays before returning results to the caller
    cudaDeviceSynchronize();

    return {pred_dist, gt_dist, grad_sigma};
}
/*
 * Host launcher for occupancy initialization.
 * input shape
 *   origin : N x T x 3
 *   points : N x M x 3
 *   tindex : N x M
 * output shape
 *   occupancy: N x T x H x L x W
 * NOTE(review): template arguments and the <<<blocks, threads>>> launch
 * configuration were stripped during extraction; restore before compiling.
 */
torch::Tensor init_cuda(
    torch::Tensor points,
    torch::Tensor tindex,
    const std::vector grid) {

    const auto N = points.size(0); // batch size
    const auto M = points.size(1); // num of rays

    const auto T = grid[0];
    const auto H = grid[1];
    const auto L = grid[2];
    const auto W = grid[3];

    const auto dtype = points.dtype();
    const auto device = points.device();
    const auto options = torch::TensorOptions().dtype(dtype).device(device).requires_grad(false);
    auto occupancy = torch::zeros({N, T, H, L, W}, options);

    // one thread per point, one block row per batch element
    const int threads = 1024;
    const dim3 blocks((M + threads - 1) / threads, N);

    // initialize occupancy such that every voxel with one or more points is occupied
    AT_DISPATCH_FLOATING_TYPES(points.type(), "init_cuda", ([&] {
        init_cuda_kernel<<>>(
            points.packed_accessor32(),
            tindex.packed_accessor32(),
            occupancy.packed_accessor32());
    }));

    // synchronize before handing the tensor back to the caller
    cudaDeviceSynchronize();

    return occupancy;
}
================================================
FILE: loaders/__init__.py
================================================
from .pipelines import __all__
from .nuscenes_dataset import CustomNuScenesDataset
from .nuscenes_occ_dataset import NuSceneOcc
__all__ = [
'CustomNuScenesDataset', 'NuSceneOcc'
]
================================================
FILE: loaders/builder.py
================================================
from functools import partial
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from torch.utils.data import DataLoader
from mmdet.datasets.builder import worker_init_fn
from mmdet.datasets.samplers import DistributedGroupSampler, DistributedSampler, GroupSampler
def build_dataloader(dataset,
                     samples_per_gpu,
                     workers_per_gpu,
                     num_gpus=1,
                     dist=True,
                     shuffle=True,
                     seed=None,
                     **kwargs):
    """Build a PyTorch DataLoader with the sampler/worker setup used by mmdet.

    In distributed mode every process gets its own loader handling
    `samples_per_gpu` samples with `workers_per_gpu` workers; otherwise a
    single loader serves all `num_gpus` GPUs.
    """
    rank, world_size = get_dist_info()

    if dist:
        # DistributedGroupSampler will definitely shuffle the data to satisfy
        # that images on each GPU are in the same group
        if shuffle:
            sampler = DistributedGroupSampler(
                dataset, samples_per_gpu, world_size, rank, seed=seed)
        else:
            sampler = DistributedSampler(
                dataset, world_size, rank, shuffle=False, seed=seed)
        batch_size, num_workers = samples_per_gpu, workers_per_gpu
    else:
        # non-distributed: one loader feeds every GPU
        sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    # seed each worker deterministically when a seed is given
    if seed is not None:
        init_fn = partial(
            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
    else:
        init_fn = None

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        pin_memory=False,
        worker_init_fn=init_fn,
        **kwargs)
================================================
FILE: loaders/ego_pose_dataset.py
================================================
import torch
import numpy as np
from pyquaternion import Quaternion
from torch.utils.data import Dataset
np.set_printoptions(precision=3, suppress=True)
def trans_matrix(T, R):
    """Compose a 4x4 homogeneous transform from a translation and a rotation.

    Args:
        T: length-3 translation vector.
        R: rotation object exposing a `.rotation_matrix` attribute
           (e.g. pyquaternion.Quaternion).

    Returns:
        np.ndarray: 4x4 homogeneous transformation matrix.
    """
    mat = np.eye(4)
    mat[:3, :3] = R.rotation_matrix
    mat[:3, 3] = T
    return mat
# A helper dataset for RayIoU. It is NOT used during training.
class EgoPoseDataset(Dataset):
    """Per sample, yields the lidar origins of up to 8 frames of the same
    scene expressed in the reference sample's ego frame.

    Used only at evaluation time to provide ray origins for the RayIoU /
    RayPQ metrics.
    """

    def __init__(self, data_infos):
        super(EgoPoseDataset, self).__init__()

        self.data_infos = data_infos

        # group sample infos by scene so origins can be gathered scene-wise
        self.scene_frames = {}
        for info in data_infos:
            scene_name = info['scene_name']
            if scene_name not in self.scene_frames:
                self.scene_frames[scene_name] = []
            self.scene_frames[scene_name].append(info)

    def __len__(self):
        return len(self.data_infos)

    def get_ego_from_lidar(self, info):
        # 4x4 lidar -> ego transform for the given sample
        ego_from_lidar = trans_matrix(
            np.array(info['lidar2ego_translation']),
            Quaternion(info['lidar2ego_rotation']))
        return ego_from_lidar

    def get_global_pose(self, info, inverse=False):
        # 4x4 lidar -> global transform (or global -> lidar when inverse=True)
        global_from_ego = trans_matrix(
            np.array(info['ego2global_translation']),
            Quaternion(info['ego2global_rotation']))
        ego_from_lidar = trans_matrix(
            np.array(info['lidar2ego_translation']),
            Quaternion(info['lidar2ego_rotation']))
        pose = global_from_ego.dot(ego_from_lidar)
        if inverse:
            pose = np.linalg.inv(pose)
        return pose

    def __getitem__(self, idx):
        """Return (sample_token, origins) where origins is a [T, 3] float
        tensor of lidar origins in the reference sample's ego frame."""
        info = self.data_infos[idx]

        ref_sample_token = info['token']
        ref_lidar_from_global = self.get_global_pose(info, inverse=True)
        ref_ego_from_lidar = self.get_ego_from_lidar(info)

        scene_frame = self.scene_frames[info['scene_name']]
        ref_index = scene_frame.index(info)

        # NOTE: getting output frames
        output_origin_list = []
        for curr_index in range(len(scene_frame)):
            # if this exists a valid target
            if curr_index == ref_index:
                # the reference frame's own origin is the lidar origin itself
                origin_tf = np.array([0.0, 0.0, 0.0], dtype=np.float32)
            else:
                # transform from the current lidar frame to global and then to the reference lidar frame
                global_from_curr = self.get_global_pose(scene_frame[curr_index], inverse=False)
                ref_from_curr = ref_lidar_from_global.dot(global_from_curr)
                origin_tf = np.array(ref_from_curr[:3, 3], dtype=np.float32)

            # lift the origin from the reference lidar frame into the ego frame
            origin_tf_pad = np.ones([4])
            origin_tf_pad[:3] = origin_tf  # pad to [4]
            origin_tf = np.dot(ref_ego_from_lidar[:3], origin_tf_pad.T).T  # [3]

            # origin
            # keep only origins inside the evaluated range (|x|, |y| < 39 m)
            if np.abs(origin_tf[0]) < 39 and np.abs(origin_tf[1]) < 39:
                output_origin_list.append(origin_tf)

        # select 8 origins
        # evenly subsample down to at most 8 origins
        if len(output_origin_list) > 8:
            select_idx = np.round(np.linspace(0, len(output_origin_list) - 1, 8)).astype(np.int64)
            output_origin_list = [output_origin_list[i] for i in select_idx]

        output_origin_tensor = torch.from_numpy(np.stack(output_origin_list))  # [T, 3]

        return (ref_sample_token, output_origin_tensor)
================================================
FILE: loaders/nuscenes_dataset.py
================================================
import os
import numpy as np
from mmdet.datasets import DATASETS
from mmdet3d.datasets import NuScenesDataset
from pyquaternion import Quaternion
@DATASETS.register_module()
class CustomNuScenesDataset(NuScenesDataset):
    """NuScenesDataset variant that additionally collects multi-sweep camera
    entries and per-camera lidar-to-image projection matrices."""

    def collect_sweeps(self, index, into_past=60, into_future=0):
        """Gather camera sweeps around sample `index`.

        Walks backwards (and optionally forwards) through `self.data_infos`,
        collecting each sample's intermediate sweeps plus the neighbouring
        keyframe cams, until roughly `into_past` / `into_future` entries are
        accumulated.

        Returns:
            tuple(list, list): (past sweeps, future sweeps).
        """
        all_sweeps_prev = []
        curr_index = index
        while len(all_sweeps_prev) < into_past:
            curr_sweeps = self.data_infos[curr_index]['sweeps']
            if len(curr_sweeps) == 0:
                break
            all_sweeps_prev.extend(curr_sweeps)
            # NOTE(review): at curr_index == 0 this reads data_infos[-1] (the
            # last sample) and there is no explicit scene-boundary check —
            # presumably the first sample of a scene has an empty sweep list
            # so the loop breaks above; confirm.
            all_sweeps_prev.append(self.data_infos[curr_index - 1]['cams'])
            curr_index = curr_index - 1

        all_sweeps_next = []
        curr_index = index + 1
        while len(all_sweeps_next) < into_future:
            if curr_index >= len(self.data_infos):
                break
            curr_sweeps = self.data_infos[curr_index]['sweeps']
            all_sweeps_next.extend(curr_sweeps[::-1])
            all_sweeps_next.append(self.data_infos[curr_index]['cams'])
            curr_index = curr_index + 1

        return all_sweeps_prev, all_sweeps_next

    def get_data_info(self, index):
        """Build the input dict for one sample: poses, timestamp, sweeps and
        (when cameras are enabled) image paths plus lidar-to-image matrices."""
        info = self.data_infos[index]
        sweeps_prev, sweeps_next = self.collect_sweeps(index)

        ego2global_translation = info['ego2global_translation']
        ego2global_rotation = info['ego2global_rotation']
        lidar2ego_translation = info['lidar2ego_translation']
        lidar2ego_rotation = info['lidar2ego_rotation']
        # convert quaternions to 3x3 rotation matrices
        ego2global_rotation = Quaternion(ego2global_rotation).rotation_matrix
        lidar2ego_rotation = Quaternion(lidar2ego_rotation).rotation_matrix

        input_dict = dict(
            sample_idx=info['token'],
            sweeps={'prev': sweeps_prev, 'next': sweeps_next},
            timestamp=info['timestamp'] / 1e6,
            ego2global_translation=ego2global_translation,
            ego2global_rotation=ego2global_rotation,
            lidar2ego_translation=lidar2ego_translation,
            lidar2ego_rotation=lidar2ego_rotation,
        )

        if self.modality['use_camera']:
            img_paths = []
            img_timestamps = []
            lidar2img_rts = []

            for _, cam_info in info['cams'].items():
                img_paths.append(os.path.relpath(cam_info['data_path']))
                img_timestamps.append(cam_info['timestamp'] / 1e6)

                # obtain lidar to image transformation matrix
                lidar2cam_r = np.linalg.inv(cam_info['sensor2lidar_rotation'])
                lidar2cam_t = cam_info['sensor2lidar_translation'] @ lidar2cam_r.T
                lidar2cam_rt = np.eye(4)
                lidar2cam_rt[:3, :3] = lidar2cam_r.T
                lidar2cam_rt[3, :3] = -lidar2cam_t

                # pad the 3x3 intrinsics to 4x4 and compose with the extrinsic
                intrinsic = cam_info['cam_intrinsic']
                viewpad = np.eye(4)
                viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic
                lidar2img_rt = (viewpad @ lidar2cam_rt.T)
                lidar2img_rts.append(lidar2img_rt)

            input_dict.update(dict(
                img_filename=img_paths,
                img_timestamp=img_timestamps,
                lidar2img=lidar2img_rts,
            ))

        if not self.test_mode:
            annos = self.get_ann_info(index)
            input_dict['ann_info'] = annos

        return input_dict
================================================
FILE: loaders/nuscenes_occ_dataset.py
================================================
import os
import mmcv
import glob
import torch
import numpy as np
from tqdm import tqdm
from mmdet.datasets import DATASETS
from mmdet3d.datasets import NuScenesDataset
from nuscenes.eval.common.utils import Quaternion
from nuscenes.utils.geometry_utils import transform_matrix
from torch.utils.data import DataLoader
from models.utils import sparse2dense
from .ray_metrics import main_rayiou, main_raypq
from .ego_pose_dataset import EgoPoseDataset
from configs.r50_nuimg_704x256_8f import occ_class_names as occ3d_class_names
from configs.r50_nuimg_704x256_8f_openocc import occ_class_names as openocc_class_names
@DATASETS.register_module()
class NuSceneOcc(NuScenesDataset):
    """NuScenes occupancy dataset: attaches occupancy-GT paths to each sample
    and evaluates sparse predictions with RayIoU (plus RayPQ when panoptic
    predictions are present)."""

    def __init__(self, occ_gt_root, *args, **kwargs):
        super().__init__(filter_empty_gt=False, *args, **kwargs)
        self.occ_gt_root = occ_gt_root
        self.data_infos = self.load_annotations(self.ann_file)

        # map each sample token to its scene name by scanning the GT layout
        # <occ_gt_root>/<scene_name>/<token>/labels.npz
        self.token2scene = {}
        for gt_path in glob.glob(os.path.join(self.occ_gt_root, '*/*/*.npz')):
            token = gt_path.split('/')[-2]
            scene_name = gt_path.split('/')[-3]
            self.token2scene[token] = scene_name

        for i in range(len(self.data_infos)):
            scene_name = self.token2scene[self.data_infos[i]['token']]
            self.data_infos[i]['scene_name'] = scene_name

    def collect_sweeps(self, index, into_past=150, into_future=0):
        """Gather camera sweeps around sample `index` (see
        CustomNuScenesDataset.collect_sweeps; same logic, deeper past).

        Returns:
            tuple(list, list): (past sweeps, future sweeps).
        """
        all_sweeps_prev = []
        curr_index = index
        while len(all_sweeps_prev) < into_past:
            curr_sweeps = self.data_infos[curr_index]['sweeps']
            if len(curr_sweeps) == 0:
                break
            all_sweeps_prev.extend(curr_sweeps)
            # NOTE(review): at curr_index == 0 this reads data_infos[-1];
            # presumably the first sample of a scene has no sweeps so the
            # loop breaks above — confirm.
            all_sweeps_prev.append(self.data_infos[curr_index - 1]['cams'])
            curr_index = curr_index - 1

        all_sweeps_next = []
        curr_index = index + 1
        while len(all_sweeps_next) < into_future:
            if curr_index >= len(self.data_infos):
                break
            curr_sweeps = self.data_infos[curr_index]['sweeps']
            all_sweeps_next.extend(curr_sweeps[::-1])
            all_sweeps_next.append(self.data_infos[curr_index]['cams'])
            curr_index = curr_index + 1

        return all_sweeps_prev, all_sweeps_next

    def get_data_info(self, index):
        """Build the input dict for one sample: poses, sweeps, occupancy GT
        path and (when cameras are enabled) lidar-to-image matrices."""
        info = self.data_infos[index]
        sweeps_prev, sweeps_next = self.collect_sweeps(index)

        ego2global_translation = info['ego2global_translation']
        ego2global_rotation = info['ego2global_rotation']
        lidar2ego_translation = info['lidar2ego_translation']
        lidar2ego_rotation = info['lidar2ego_rotation']
        # convert quaternions to 3x3 rotation matrices
        ego2global_rotation_mat = Quaternion(ego2global_rotation).rotation_matrix
        lidar2ego_rotation_mat = Quaternion(lidar2ego_rotation).rotation_matrix

        input_dict = dict(
            sample_idx=info['token'],
            sweeps={'prev': sweeps_prev, 'next': sweeps_next},
            timestamp=info['timestamp'] / 1e6,
            ego2global_translation=ego2global_translation,
            ego2global_rotation=ego2global_rotation_mat,
            lidar2ego_translation=lidar2ego_translation,
            lidar2ego_rotation=lidar2ego_rotation_mat,
        )

        # ego -> lidar transform, replicated once per camera view
        ego2lidar = transform_matrix(lidar2ego_translation, Quaternion(lidar2ego_rotation), inverse=True)
        input_dict['ego2lidar'] = [ego2lidar for _ in range(6)]
        input_dict['occ_path'] = os.path.join(self.occ_gt_root, info['scene_name'], info['token'], 'labels.npz')

        if self.modality['use_camera']:
            img_paths = []
            img_timestamps = []
            lidar2img_rts = []

            for _, cam_info in info['cams'].items():
                img_paths.append(os.path.relpath(cam_info['data_path']))
                img_timestamps.append(cam_info['timestamp'] / 1e6)

                # obtain lidar to image transformation matrix
                lidar2cam_r = np.linalg.inv(cam_info['sensor2lidar_rotation'])
                lidar2cam_t = cam_info['sensor2lidar_translation'] @ lidar2cam_r.T
                lidar2cam_rt = np.eye(4)
                lidar2cam_rt[:3, :3] = lidar2cam_r.T
                lidar2cam_rt[3, :3] = -lidar2cam_t

                # pad the 3x3 intrinsics to 4x4 and compose with the extrinsic
                intrinsic = cam_info['cam_intrinsic']
                viewpad = np.eye(4)
                viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic
                lidar2img_rt = (viewpad @ lidar2cam_rt.T)
                lidar2img_rts.append(lidar2img_rt)

            input_dict.update(dict(
                img_filename=img_paths,
                img_timestamp=img_timestamps,
                lidar2img=lidar2img_rts,
            ))

        if not self.test_mode:
            annos = self.get_ann_info(index)
            input_dict['ann_info'] = annos

        return input_dict

    def evaluate(self, occ_results, runner=None, show_dir=None, **eval_kwargs):
        """Evaluate predictions with RayIoU; additionally computes RayPQ when
        panoptic outputs ('pano_sem'/'pano_inst') are provided.

        Args:
            occ_results (list[dict]): per-sample predictions with sparse
                'sem_pred' and 'occ_loc', optionally 'pano_sem'/'pano_inst'.

        Returns:
            dict: metric name -> value.
        """
        occ_gts, occ_preds, inst_gts, inst_preds, lidar_origins = [], [], [], [], []

        print('\nStarting Evaluation...')

        sample_tokens = [info['token'] for info in self.data_infos]

        # EgoPoseDataset yields (token, ray origins); iterate in its order
        # and look the corresponding prediction up by token.
        for batch in DataLoader(EgoPoseDataset(self.data_infos), num_workers=8):
            token = batch[0][0]
            output_origin = batch[1]

            data_id = sample_tokens.index(token)
            info = self.data_infos[data_id]

            # load the dense semantic ground truth
            occ_path = os.path.join(self.occ_gt_root, info['scene_name'], info['token'], 'labels.npz')
            occ_gt = np.load(occ_path, allow_pickle=True)
            gt_semantics = occ_gt['semantics']

            occ_pred = occ_results[data_id]
            sem_pred = torch.from_numpy(occ_pred['sem_pred'])  # [B, N]
            occ_loc = torch.from_numpy(occ_pred['occ_loc'].astype(np.int64))  # [B, N, 3]

            # pick the class list matching the GT flavour (folder name)
            data_type = self.occ_gt_root.split('/')[-1]
            if data_type == 'occ3d' or data_type == 'occ3d_panoptic':
                occ_class_names = occ3d_class_names
            elif data_type == 'openocc_v2':
                occ_class_names = openocc_class_names
            else:
                raise ValueError
            free_id = len(occ_class_names) - 1  # last class is free space

            # densify the sparse prediction; unlisted voxels become free
            occ_size = list(gt_semantics.shape)
            sem_pred, _ = sparse2dense(occ_loc, sem_pred, dense_shape=occ_size, empty_value=free_id)
            sem_pred = sem_pred.squeeze(0).numpy()

            if 'pano_inst' in occ_pred.keys():
                # panoptic branch: densify instance / semantic maps as well
                pano_inst = torch.from_numpy(occ_pred['pano_inst'])
                pano_sem = torch.from_numpy(occ_pred['pano_sem'])

                pano_inst, _ = sparse2dense(occ_loc, pano_inst, dense_shape=occ_size, empty_value=0)
                pano_sem, _ = sparse2dense(occ_loc, pano_sem, dense_shape=occ_size, empty_value=free_id)
                pano_inst = pano_inst.squeeze(0).numpy()
                pano_sem = pano_sem.squeeze(0).numpy()
                # panoptic semantics replace the plain semantic prediction
                sem_pred = pano_sem

                gt_instances = occ_gt['instances']
                inst_gts.append(gt_instances)
                inst_preds.append(pano_inst)

            lidar_origins.append(output_origin)
            occ_gts.append(gt_semantics)
            occ_preds.append(sem_pred)

        if len(inst_preds) > 0:
            results = main_raypq(occ_preds, occ_gts, inst_preds, inst_gts, lidar_origins, occ_class_names=occ_class_names)
            results.update(main_rayiou(occ_preds, occ_gts, lidar_origins, occ_class_names=occ_class_names))
            return results
        else:
            return main_rayiou(occ_preds, occ_gts, lidar_origins, occ_class_names=occ_class_names)

    def format_results(self, occ_results, submission_prefix, **kwargs):
        """Dump per-sample predictions as <token>.npz files under
        `submission_prefix` for benchmark submission."""
        if submission_prefix is not None:
            mmcv.mkdir_or_exist(submission_prefix)

        for index, occ_pred in enumerate(tqdm(occ_results)):
            info = self.data_infos[index]
            sample_token = info['token']
            save_path = os.path.join(submission_prefix, '{}.npz'.format(sample_token))
            np.savez_compressed(save_path, occ_pred.astype(np.uint8))
        print('\nFinished.')
================================================
FILE: loaders/old_metrics.py
================================================
import os
import numpy as np
from sklearn.neighbors import KDTree
from termcolor import colored
from functools import reduce
from typing import Iterable
np.seterr(divide='ignore', invalid='ignore')
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
def pcolor(string, color, on_color=None, attrs=None):
    """Wrap a string in terminal colour codes.

    Parameters
    ----------
    string : str
        String that will be colored
    color : str
        Color to use
    on_color : str
        Background color to use
    attrs : list of str
        Different attributes for the string

    Returns
    -------
    string: str
        Colored string
    """
    return colored(string, color, on_color, attrs)
def getCellCoordinates(points, voxelSize):
    """Map continuous point coordinates to integer voxel-cell indices.

    Args:
        points (np.ndarray): point coordinates (any shape).
        voxelSize (float): edge length of a cubic voxel.

    Returns:
        np.ndarray: integer cell index per coordinate, truncated toward zero.
    """
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the exact
    # alias it used to resolve to, so behaviour is unchanged.
    return (points / voxelSize).astype(int)
def getNumUniqueCells(cells):
    """Count distinct voxel cells among integer (x, y, z) index triples."""
    base = cells.max() + 1
    # collapse each triple into a single scalar key (base-`base` encoding),
    # then count the unique keys
    keys = cells[:, 0] + base * cells[:, 1] + base ** 2 * cells[:, 2]
    return np.unique(keys).shape[0]
class Metric_mIoU():
    """Voxel-wise mean IoU accumulated over samples via a confusion matrix.

    `add_batch` folds each (pred, gt) pair into `self.hist`; `count_miou`
    prints per-class IoU and returns the mean over all classes except the
    final "free" class.
    """

    def __init__(self,
                 save_dir='.',
                 num_classes=18,
                 use_lidar_mask=False,
                 use_image_mask=False,
                 ):
        # Occ3D-nuScenes label set; index 17 is free space
        if num_classes == 18:
            self.class_names = [
                'others','barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',
                'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',
                'driveable_surface', 'other_flat', 'sidewalk',
                'terrain', 'manmade', 'vegetation','free'
            ]
        elif num_classes == 2:
            # binary occupancy: occupied vs free
            self.class_names = ['non-free', 'free']
        # NOTE(review): any other num_classes leaves self.class_names unset,
        # so count_miou would raise — confirm only 18 / 2 are used.
        self.save_dir = save_dir
        self.use_lidar_mask = use_lidar_mask
        self.use_image_mask = use_image_mask
        self.num_classes = num_classes

        # Occ3D grid definition: 80 m x 80 m x 6.4 m at 0.4 m resolution
        self.point_cloud_range = [-40.0, -40.0, -1.0, 40.0, 40.0, 5.4]
        self.occupancy_size = [0.4, 0.4, 0.4]
        self.voxel_size = 0.4
        self.occ_xdim = int((self.point_cloud_range[3] - self.point_cloud_range[0]) / self.occupancy_size[0])
        self.occ_ydim = int((self.point_cloud_range[4] - self.point_cloud_range[1]) / self.occupancy_size[1])
        self.occ_zdim = int((self.point_cloud_range[5] - self.point_cloud_range[2]) / self.occupancy_size[2])
        self.voxel_num = self.occ_xdim * self.occ_ydim * self.occ_zdim
        self.hist = np.zeros((self.num_classes, self.num_classes))  # running confusion matrix
        self.cnt = 0  # number of samples folded in

    def hist_info(self, n_cl, pred, gt):
        """
        build confusion matrix
        # empty classes:0
        non-empty class: 0-16
        free voxel class: 17

        Args:
            n_cl (int): num_classes_occupancy
            pred (1-d array): pred_occupancy_label
            gt (1-d array): gt_occupancu_label

        Returns:
            tuple:(hist, correctly number_predicted_labels, num_labelled_sample)
        """
        assert pred.shape == gt.shape
        k = (gt >= 0) & (gt < n_cl)  # exclude 255
        labeled = np.sum(k)
        correct = np.sum((pred[k] == gt[k]))

        return (
            np.bincount(
                n_cl * gt[k].astype(int) + pred[k].astype(int), minlength=n_cl ** 2
            ).reshape(n_cl, n_cl),
            correct,
            labeled,
        )

    def per_class_iu(self, hist):
        # IoU = TP / (TP + FP + FN); classes absent from the GT become NaN
        #return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
        result = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
        result[hist.sum(1) == 0] = float('nan')
        return result

    def compute_mIoU(self, pred, label, n_classes):
        """Return (mIoU in percent, confusion matrix) for a single sample."""
        hist = np.zeros((n_classes, n_classes))
        new_hist, correct, labeled = self.hist_info(n_classes, pred.flatten(), label.flatten())
        hist += new_hist
        mIoUs = self.per_class_iu(hist)
        # for ind_class in range(n_classes):
        #     print(str(round(mIoUs[ind_class] * 100, 2)))
        # print('===> mIoU: ' + str(round(np.nanmean(mIoUs) * 100, 2)))
        return round(np.nanmean(mIoUs) * 100, 2), hist

    def add_batch(self,semantics_pred,semantics_gt,mask_lidar,mask_camera):
        """Fold one sample into the running confusion matrix, applying the
        configured visibility mask first."""
        self.cnt += 1
        if self.use_image_mask:
            masked_semantics_gt = semantics_gt[mask_camera]
            masked_semantics_pred = semantics_pred[mask_camera]
        elif self.use_lidar_mask:
            masked_semantics_gt = semantics_gt[mask_lidar]
            masked_semantics_pred = semantics_pred[mask_lidar]
        else:
            masked_semantics_gt = semantics_gt
            masked_semantics_pred = semantics_pred

        if self.num_classes == 2:
            # collapse the 18-class labels to binary: occupied(0) / free(1)
            masked_semantics_pred = np.copy(masked_semantics_pred)
            masked_semantics_gt = np.copy(masked_semantics_gt)
            masked_semantics_pred[masked_semantics_pred < 17] = 0
            masked_semantics_pred[masked_semantics_pred == 17] = 1
            masked_semantics_gt[masked_semantics_gt < 17] = 0
            masked_semantics_gt[masked_semantics_gt == 17] = 1

        _, _hist = self.compute_mIoU(masked_semantics_pred, masked_semantics_gt, self.num_classes)
        self.hist += _hist

    def count_miou(self):
        """Print per-class IoU and return mean IoU (percent) over all classes
        except the final "free" class."""
        mIoU = self.per_class_iu(self.hist)
        # assert cnt == num_samples, 'some samples are not included in the miou calculation'
        print(f'===> per class IoU of {self.cnt} samples:')
        for ind_class in range(self.num_classes-1):
            print(f'===> {self.class_names[ind_class]} - IoU = ' + str(round(mIoU[ind_class] * 100, 2)))

        print(f'===> mIoU of {self.cnt} samples: ' + str(round(np.nanmean(mIoU[:self.num_classes-1]) * 100, 2)))
        # print(f'===> sample-wise averaged mIoU of {cnt} samples: ' + str(round(np.nanmean(mIoU_avg), 2)))

        return round(np.nanmean(mIoU[:self.num_classes-1]) * 100, 2)
class Metric_FScore():
    """F-score between predicted and GT occupied voxels.

    Both grids are converted into point clouds of occupied-voxel centres;
    precision ("accuracy") and recall ("completeness") come from
    nearest-neighbour distances (KDTree) under the given thresholds and are
    combined via the harmonic mean.
    """

    def __init__(self,
                 leaf_size=10,
                 threshold_acc=0.6,
                 threshold_complete=0.6,
                 voxel_size=[0.4, 0.4, 0.4],
                 range=[-40, -40, -1, 40, 40, 5.4],
                 void=[17, 255],
                 use_lidar_mask=False,
                 use_image_mask=False, ) -> None:
        # NOTE(review): parameter `range` shadows the builtin; kept as-is for
        # API compatibility.
        self.leaf_size = leaf_size
        self.threshold_acc = threshold_acc
        self.threshold_complete = threshold_complete
        self.voxel_size = voxel_size
        self.range = range
        self.void = void  # label ids treated as empty (free / ignore)
        self.use_lidar_mask = use_lidar_mask
        self.use_image_mask = use_image_mask
        self.cnt=0
        self.tot_acc = 0.
        self.tot_cmpl = 0.
        self.tot_f1_mean = 0.
        self.eps = 1e-8  # guards the harmonic mean against division by zero

    def voxel2points(self, voxel):
        """Convert a dense label grid into an (N, 3) array of metric centre
        coordinates of every voxel whose label is not in `self.void`."""
        # occIdx = torch.where(torch.logical_and(voxel != FREE, voxel != NOT_OBSERVED))
        # if isinstance(voxel, np.ndarray): voxel = torch.from_numpy(voxel)
        mask = np.logical_not(reduce(np.logical_or, [voxel == self.void[i] for i in range(len(self.void))]))
        occIdx = np.where(mask)

        # voxel index -> metric coordinate of the voxel centre
        points = np.concatenate((occIdx[0][:, None] * self.voxel_size[0] + self.voxel_size[0] / 2 + self.range[0], \
                                 occIdx[1][:, None] * self.voxel_size[1] + self.voxel_size[1] / 2 + self.range[1], \
                                 occIdx[2][:, None] * self.voxel_size[2] + self.voxel_size[2] / 2 + self.range[2]),
                                axis=1)
        return points

    def add_batch(self,semantics_pred,semantics_gt,mask_lidar,mask_camera ):
        """Accumulate accuracy / completeness / F-score for one sample.

        NOTE(review): when a visibility mask is active, the input arrays are
        modified in place (masked-out voxels set to 255) — confirm callers do
        not reuse them afterwards.
        """
        # for scene_token in tqdm(preds_dict.keys()):
        self.cnt += 1

        if self.use_image_mask:
            semantics_gt[mask_camera == False] = 255
            semantics_pred[mask_camera == False] = 255
        elif self.use_lidar_mask:
            semantics_gt[mask_lidar == False] = 255
            semantics_pred[mask_lidar == False] = 255
        else:
            pass

        ground_truth = self.voxel2points(semantics_gt)
        prediction = self.voxel2points(semantics_pred)

        if prediction.shape[0] == 0:
            # no occupied voxels predicted: everything scores zero
            accuracy=0
            completeness=0
            fmean=0
        else:
            prediction_tree = KDTree(prediction, leaf_size=self.leaf_size)
            ground_truth_tree = KDTree(ground_truth, leaf_size=self.leaf_size)
            complete_distance, _ = prediction_tree.query(ground_truth)
            complete_distance = complete_distance.flatten()

            accuracy_distance, _ = ground_truth_tree.query(prediction)
            accuracy_distance = accuracy_distance.flatten()

            # evaluate completeness
            complete_mask = complete_distance < self.threshold_complete
            completeness = complete_mask.mean()

            # evalute accuracy
            accuracy_mask = accuracy_distance < self.threshold_acc
            accuracy = accuracy_mask.mean()

            # harmonic mean of accuracy and completeness
            fmean = 2.0 / (1 / (accuracy+self.eps) + 1 / (completeness+self.eps))

        self.tot_acc += accuracy
        self.tot_cmpl += completeness
        self.tot_f1_mean += fmean

    def count_fscore(self,):
        """Print and return the mean F-score over all accumulated samples."""
        base_color, attrs = 'red', ['bold', 'dark']
        print(pcolor('\n######## F score: {} #######'.format(self.tot_f1_mean / self.cnt), base_color, attrs=attrs))
        return self.tot_f1_mean / self.cnt
class Metric_mRecall():
    """Per-class recall of a (typically binary) occupancy prediction against
    multi-class GT, accumulated as an (num_classes x pred_classes) matrix."""

    def __init__(self,
                 save_dir='.',
                 num_classes=18,
                 pred_classes=2,
                 use_lidar_mask=False,
                 use_image_mask=False,
                 ):
        # Occ3D-nuScenes label set; index 17 is free space
        if num_classes == 18:
            self.class_names = [
                'others','barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',
                'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',
                'driveable_surface', 'other_flat', 'sidewalk',
                'terrain', 'manmade', 'vegetation','free'
            ]
        elif num_classes == 2:
            self.class_names = ['non-free', 'free']
        self.pred_classes = pred_classes
        self.save_dir = save_dir
        self.use_lidar_mask = use_lidar_mask
        self.use_image_mask = use_image_mask
        self.num_classes = num_classes

        # Occ3D grid definition: 80 m x 80 m x 6.4 m at 0.4 m resolution
        self.point_cloud_range = [-40.0, -40.0, -1.0, 40.0, 40.0, 5.4]
        self.occupancy_size = [0.4, 0.4, 0.4]
        self.voxel_size = 0.4
        self.occ_xdim = int((self.point_cloud_range[3] - self.point_cloud_range[0]) / self.occupancy_size[0])
        self.occ_ydim = int((self.point_cloud_range[4] - self.point_cloud_range[1]) / self.occupancy_size[1])
        self.occ_zdim = int((self.point_cloud_range[5] - self.point_cloud_range[2]) / self.occupancy_size[2])
        self.voxel_num = self.occ_xdim * self.occ_ydim * self.occ_zdim
        self.hist = np.zeros((self.num_classes, self.pred_classes))  # n_cl, p_cl
        self.cnt = 0  # number of samples folded in

    def hist_info(self, n_cl, p_cl, pred, gt):
        """
        build confusion matrix
        # empty classes:0
        non-empty class: 0-16
        free voxel class: 17

        Args:
            n_cl (int): num_classes_occupancy
            pred (1-d array): pred_occupancy_label
            gt (1-d array): gt_occupancu_label

        Returns:
            tuple:(hist, correctly number_predicted_labels, num_labelled_sample)
        """
        assert pred.shape == gt.shape
        k = (gt >= 0) & (gt < n_cl)  # exclude 255
        labeled = np.sum(k)
        correct = np.sum((pred[k] == gt[k]))

        return (
            np.bincount(
                p_cl * gt[k].astype(int) + pred[k].astype(int), minlength=n_cl * p_cl
            ).reshape(n_cl, p_cl),  # 18, 2
            correct,
            labeled,
        )

    def per_class_recall(self, hist):
        # fraction of each GT class predicted as label 1 ("occupied")
        return hist[:, 1] / hist.sum(1)  ## recall

    def compute_mRecall(self, pred, label, n_classes, p_classes):
        """Return (mean recall in percent, confusion matrix) for one sample."""
        hist = np.zeros((n_classes, p_classes))
        new_hist, correct, labeled = self.hist_info(n_classes, p_classes, pred.flatten(), label.flatten())
        hist += new_hist
        mRecalls = self.per_class_recall(hist)
        # for ind_class in range(n_classes):
        #     print(str(round(mIoUs[ind_class] * 100, 2)))
        # print('===> mIoU: ' + str(round(np.nanmean(mIoUs) * 100, 2)))
        return round(np.nanmean(mRecalls) * 100, 2), hist

    def add_batch(self,semantics_pred,semantics_gt,mask_lidar,mask_camera):
        """Fold one sample into the running matrix; predictions are binarized
        (occupied=1, free=0) when pred_classes == 2 while the GT keeps its
        full label set."""
        self.cnt += 1
        if self.use_image_mask:
            masked_semantics_gt = semantics_gt[mask_camera]
            masked_semantics_pred = semantics_pred[mask_camera]
        elif self.use_lidar_mask:
            masked_semantics_gt = semantics_gt[mask_lidar]
            masked_semantics_pred = semantics_pred[mask_lidar]
        else:
            masked_semantics_gt = semantics_gt
            masked_semantics_pred = semantics_pred

        if self.pred_classes == 2:
            masked_semantics_pred = np.copy(masked_semantics_pred)
            masked_semantics_gt = np.copy(masked_semantics_gt)
            masked_semantics_pred[masked_semantics_pred < 17] = 1
            masked_semantics_pred[masked_semantics_pred == 17] = 0  # 0 is free

        _, _hist = self.compute_mRecall(masked_semantics_pred, masked_semantics_gt, self.num_classes, self.pred_classes)
        self.hist += _hist

    def count_mrecall(self):
        """Print per-class recall and return mean recall (percent) over all
        classes except the final "free" class."""
        mRecall = self.per_class_recall(self.hist)
        # assert cnt == num_samples, 'some samples are not included in the miou calculation'
        print(f'===> per class Recall of {self.cnt} samples:')
        for ind_class in range(self.num_classes-1):
            print(f'===> {self.class_names[ind_class]} - Recall = ' + str(round(mRecall[ind_class] * 100, 2)))
        print(f'===> mRecall of {self.cnt} samples: ' + str(round(np.nanmean(mRecall[:self.num_classes-1]) * 100, 2)))
        return round(np.nanmean(mRecall[:self.num_classes-1]) * 100, 2)
# modified from https://github.com/open-mmlab/mmdetection3d/blob/main/mmdet3d/evaluation/functional/panoptic_seg_eval.py#L10
class Metric_Panoptic():
    """Panoptic quality (PQ / SQ / RQ) over occupancy voxels.

    GT instance ids are first re-assigned per class ("stuff" classes collapse
    to one instance each); predictions and GT segments are then matched per
    class at IoU > 0.5, accumulating TP / FP / FN and matched-IoU sums.
    """

    def __init__(self,
                 save_dir='.',
                 num_classes=18,
                 use_lidar_mask=False,
                 use_image_mask=False,
                 ignore_index: Iterable[int]=[],
                 ):
        """
        Args:
            ignore_index (llist): Class ids that not be considered in pq counting.
        """
        # Occ3D-nuScenes label set; index 17 is free space
        if num_classes == 18:
            self.class_names = [
                'others','barrier', 'bicycle', 'bus', 'car', 'construction_vehicle',
                'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck',
                'driveable_surface', 'other_flat', 'sidewalk',
                'terrain', 'manmade', 'vegetation','free'
            ]
        else:
            raise ValueError
        self.save_dir = save_dir
        self.num_classes = num_classes
        self.use_lidar_mask = use_lidar_mask
        self.use_image_mask = use_image_mask
        self.ignore_index = ignore_index

        self.id_offset = 2 ** 16  # packs (gt_id, pred_id) pairs into one integer
        self.eps = 1e-5
        self.min_num_points = 20  # unmatched segments smaller than this are not counted as FP/FN
        # class ids taken into account (all but "free" and the ignored ones)
        self.include = np.array(
            [n for n in range(self.num_classes - 1) if n not in self.ignore_index],
            dtype=int)

        self.cnt = 0
        # panoptic stuff
        self.pan_tp = np.zeros(self.num_classes, dtype=int)
        self.pan_iou = np.zeros(self.num_classes, dtype=np.double)
        self.pan_fp = np.zeros(self.num_classes, dtype=int)
        self.pan_fn = np.zeros(self.num_classes, dtype=int)

    def add_batch(self,semantics_pred,semantics_gt,instances_pred,instances_gt,mask_lidar,mask_camera):
        """Apply the configured visibility mask, then accumulate one sample."""
        self.cnt += 1
        if self.use_image_mask:
            masked_semantics_gt = semantics_gt[mask_camera]
            masked_semantics_pred = semantics_pred[mask_camera]
            masked_instances_gt = instances_gt[mask_camera]
            masked_instances_pred = instances_pred[mask_camera]
        elif self.use_lidar_mask:
            masked_semantics_gt = semantics_gt[mask_lidar]
            masked_semantics_pred = semantics_pred[mask_lidar]
            masked_instances_gt = instances_gt[mask_lidar]
            masked_instances_pred = instances_pred[mask_lidar]
        else:
            masked_semantics_gt = semantics_gt
            masked_semantics_pred = semantics_pred
            masked_instances_gt = instances_gt
            masked_instances_pred = instances_pred

        self.add_panoptic_sample(masked_semantics_pred, masked_semantics_gt, masked_instances_pred, masked_instances_gt)

    def add_panoptic_sample(self, semantics_pred, semantics_gt, instances_pred, instances_gt):
        """Add one sample of panoptic predictions and ground truths for
        evaluation.

        Args:
            semantics_pred (np.ndarray): Semantic predictions.
            semantics_gt (np.ndarray): Semantic ground truths.
            instances_pred (np.ndarray): Instance predictions.
            instances_gt (np.ndarray): Instance ground truths.
        """
        # get instance_class_id from instance_gt
        instance_class_ids = [self.num_classes - 1]
        for i in range(1, instances_gt.max() + 1):
            class_id = np.unique(semantics_gt[instances_gt == i])
            # assert class_id.shape[0] == 1, "each instance must belong to only one class"
            if class_id.shape[0] == 1:
                instance_class_ids.append(class_id[0])
            else:
                # ambiguous instance (spans several classes): map it to "free"
                instance_class_ids.append(self.num_classes - 1)
        instance_class_ids = np.array(instance_class_ids)

        # re-number GT instances: each "thing" segment and each present
        # "stuff" class gets a fresh consecutive id starting at 1
        instance_count = 1
        final_instance_class_ids = []
        final_instances = np.zeros_like(instances_gt)  # empty space has instance id "0"

        for class_id in range(self.num_classes - 1):
            if np.sum(semantics_gt == class_id) == 0:
                continue

            if self.class_names[class_id] in ['car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'motorcycle', 'bicycle', 'pedestrian']:
                # treat as instances
                for instance_id in range(len(instance_class_ids)):
                    if instance_class_ids[instance_id] != class_id:
                        continue
                    final_instances[instances_gt == instance_id] = instance_count
                    instance_count += 1
                    final_instance_class_ids.append(class_id)
            else:
                # treat as semantics
                final_instances[semantics_gt == class_id] = instance_count
                instance_count += 1
                final_instance_class_ids.append(class_id)

        instances_gt = final_instances

        # avoid zero (ignored label)
        instances_pred = instances_pred + 1
        instances_gt = instances_gt + 1

        for cl in self.ignore_index:
            # make a mask for this class
            gt_not_in_excl_mask = semantics_gt != cl
            # remove all other points
            semantics_pred = semantics_pred[gt_not_in_excl_mask]
            semantics_gt = semantics_gt[gt_not_in_excl_mask]
            instances_pred = instances_pred[gt_not_in_excl_mask]
            instances_gt = instances_gt[gt_not_in_excl_mask]

        # for each class (except the ignored ones)
        for cl in self.include:
            # get a class mask
            pred_inst_in_cl_mask = semantics_pred == cl
            gt_inst_in_cl_mask = semantics_gt == cl

            # get instance points in class (makes outside stuff 0)
            pred_inst_in_cl = instances_pred * pred_inst_in_cl_mask.astype(int)
            gt_inst_in_cl = instances_gt * gt_inst_in_cl_mask.astype(int)

            # generate the areas for each unique instance prediction
            unique_pred, counts_pred = np.unique(
                pred_inst_in_cl[pred_inst_in_cl > 0], return_counts=True)
            id2idx_pred = {id: idx for idx, id in enumerate(unique_pred)}
            matched_pred = np.array([False] * unique_pred.shape[0])

            # generate the areas for each unique instance gt_np
            unique_gt, counts_gt = np.unique(
                gt_inst_in_cl[gt_inst_in_cl > 0], return_counts=True)
            id2idx_gt = {id: idx for idx, id in enumerate(unique_gt)}
            matched_gt = np.array([False] * unique_gt.shape[0])

            # generate intersection using offset
            valid_combos = np.logical_and(pred_inst_in_cl > 0,
                                          gt_inst_in_cl > 0)
            id_offset_combo = pred_inst_in_cl[
                valid_combos] + self.id_offset * gt_inst_in_cl[valid_combos]
            unique_combo, counts_combo = np.unique(
                id_offset_combo, return_counts=True)

            # generate an intersection map
            # count the intersections with over 0.5 IoU as TP
            gt_labels = unique_combo // self.id_offset
            pred_labels = unique_combo % self.id_offset
            gt_areas = np.array([counts_gt[id2idx_gt[id]] for id in gt_labels])
            pred_areas = np.array(
                [counts_pred[id2idx_pred[id]] for id in pred_labels])
            intersections = counts_combo
            unions = gt_areas + pred_areas - intersections
            ious = intersections.astype(float) / unions.astype(float)

            tp_indexes = ious > 0.5
            self.pan_tp[cl] += np.sum(tp_indexes)
            self.pan_iou[cl] += np.sum(ious[tp_indexes])

            matched_gt[[id2idx_gt[id] for id in gt_labels[tp_indexes]]] = True
            matched_pred[[id2idx_pred[id]
                          for id in pred_labels[tp_indexes]]] = True

            # count the FN
            if len(counts_gt) > 0:
                self.pan_fn[cl] += np.sum(
                    np.logical_and(counts_gt >= self.min_num_points,
                                   ~matched_gt))

            # count the FP
            if len(matched_pred) > 0:
                self.pan_fp[cl] += np.sum(
                    np.logical_and(counts_pred >= self.min_num_points,
                                   ~matched_pred))

    def count_pq(self, ):
        """Print per-class SQ / RQ / PQ and return their means (percent) over
        the included classes as (pq, sq, rq)."""
        # SQ: mean IoU over matched segments; RQ: F1-style recognition quality
        sq_all = self.pan_iou.astype(np.double) / np.maximum(
            self.pan_tp.astype(np.double), self.eps)
        rq_all = self.pan_tp.astype(np.double) / np.maximum(
            self.pan_tp.astype(np.double) + 0.5 * self.pan_fp.astype(np.double)
            + 0.5 * self.pan_fn.astype(np.double), self.eps)
        pq_all = sq_all * rq_all

        # mask classes not occurring in dataset
        mask = (self.pan_tp + self.pan_fp + self.pan_fn) > 0
        sq_all[~mask] = float('nan')
        rq_all[~mask] = float('nan')
        pq_all[~mask] = float('nan')

        # then do the REAL mean (no ignored classes)
        sq = round(np.nanmean(sq_all[self.include]) * 100, 2)
        rq = round(np.nanmean(rq_all[self.include]) * 100, 2)
        pq = round(np.nanmean(pq_all[self.include]) * 100, 2)

        print(f'===> per class sq, rq, pq of {self.cnt} samples:')
        for ind_class in self.include:
            print(f'===> {self.class_names[ind_class]} -' + \
                  f' sq = {round(sq_all[ind_class] * 100, 2)},' + \
                  f' rq = {round(rq_all[ind_class] * 100, 2)},' + \
                  f' pq = {round(pq_all[ind_class] * 100, 2)}')
        print(f'===> sq of {self.cnt} samples: ' + str(sq))
        print(f'===> rq of {self.cnt} samples: ' + str(rq))
        print(f'===> pq of {self.cnt} samples: ' + str(pq))

        return (pq, sq, rq)
================================================
FILE: loaders/pipelines/__init__.py
================================================
from .loading import LoadMultiViewImageFromMultiSweeps, LoadOccGTFromFile
from .transforms import PadMultiViewImage, NormalizeMultiviewImage, PhotoMetricDistortionMultiViewImage
__all__ = [
'LoadMultiViewImageFromMultiSweeps', 'PadMultiViewImage', 'NormalizeMultiviewImage',
'PhotoMetricDistortionMultiViewImage', 'LoadOccGTFromFile'
]
================================================
FILE: loaders/pipelines/loading.py
================================================
import os
import mmcv
import torch
import numpy as np
from mmdet.datasets.builder import PIPELINES
from numpy.linalg import inv
from mmcv.runner import get_dist_info
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.pipelines import to_tensor
from torchvision.transforms.functional import rotate
def compose_lidar2img(ego2global_translation_curr,
ego2global_rotation_curr,
lidar2ego_translation_curr,
lidar2ego_rotation_curr,
sensor2global_translation_past,
sensor2global_rotation_past,
cam_intrinsic_past):
R = sensor2global_rotation_past @ (inv(ego2global_rotation_curr).T @ inv(lidar2ego_rotation_curr).T)
T = sensor2global_translation_past @ (inv(ego2global_rotation_curr).T @ inv(lidar2ego_rotation_curr).T)
T -= ego2global_translation_curr @ (inv(ego2global_rotation_curr).T @ inv(lidar2ego_rotation_curr).T) + lidar2ego_translation_curr @ inv(lidar2ego_rotation_curr).T
lidar2cam_r = inv(R.T)
lidar2cam_t = T @ lidar2cam_r.T
lidar2cam_rt = np.eye(4)
lidar2cam_rt[:3, :3] = lidar2cam_r.T
lidar2cam_rt[3, :3] = -lidar2cam_t
viewpad = np.eye(4)
viewpad[:cam_intrinsic_past.shape[0], :cam_intrinsic_past.shape[1]] = cam_intrinsic_past
lidar2img = (viewpad @ lidar2cam_rt.T).astype(np.float32)
return lidar2img
@PIPELINES.register_module()
class LoadMultiViewImageFromMultiSweeps(object):
    """Append multi-view images from previous sweeps to the current sample.

    For each selected past sweep, the six camera entries (image, timestamp,
    filename, lidar2img projection) are appended to the corresponding lists
    in ``results``. When the sample has no previous sweeps, the current
    frame is duplicated ``sweeps_num`` times instead.
    """

    def __init__(self,
                 sweeps_num=5,
                 color_type='color',
                 test_mode=False):
        # sweeps_num: number of past sweeps appended per sample
        # color_type: color flag forwarded to mmcv.imread
        # test_mode: use the fixed test interval instead of a random one
        self.sweeps_num = sweeps_num
        self.color_type = color_type
        self.test_mode = test_mode
        self.train_interval = [4, 8]  # min/max sweep gap sampled at train time
        self.test_interval = 6  # fixed sweep gap at test time

        # prefer the faster turbojpeg decoder when it is installed
        try:
            mmcv.use_backend('turbojpeg')
        except ImportError:
            mmcv.use_backend('cv2')

    def load_offline(self, results):
        """Append past-sweep images (pixels included) to ``results``."""
        cam_types = [
            'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
            'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT'
        ]

        if len(results['sweeps']['prev']) == 0:
            # no history available: repeat the current frame sweeps_num times
            for _ in range(self.sweeps_num):
                for j in range(len(cam_types)):
                    results['img'].append(results['img'][j])
                    results['img_timestamp'].append(results['img_timestamp'][j])
                    results['filename'].append(results['filename'][j])
                    results['lidar2img'].append(np.copy(results['lidar2img'][j]))
                    if 'ego2lidar' in results:
                        results['ego2lidar'].append(results['ego2lidar'][0])
        else:
            if self.test_mode:
                # deterministic, evenly spaced sweep indices
                interval = self.test_interval
                choices = [(k + 1) * interval - 1 for k in range(self.sweeps_num)]
            elif len(results['sweeps']['prev']) <= self.sweeps_num:
                # fewer sweeps than requested: take all, pad with the oldest
                pad_len = self.sweeps_num - len(results['sweeps']['prev'])
                choices = list(range(len(results['sweeps']['prev']))) + [len(results['sweeps']['prev']) - 1] * pad_len
            else:
                # random interval inside train_interval, capped so that all
                # sweeps_num picks stay within the available history
                max_interval = len(results['sweeps']['prev']) // self.sweeps_num
                max_interval = min(max_interval, self.train_interval[1])
                min_interval = min(max_interval, self.train_interval[0])
                interval = np.random.randint(min_interval, max_interval + 1)
                choices = [(k + 1) * interval - 1 for k in range(self.sweeps_num)]

            for idx in sorted(list(choices)):
                sweep_idx = min(idx, len(results['sweeps']['prev']) - 1)
                sweep = results['sweeps']['prev'][sweep_idx]
                # fall back to the previous sweep when cameras are missing
                if len(sweep.keys()) < len(cam_types):
                    sweep = results['sweeps']['prev'][sweep_idx - 1]

                for sensor in cam_types:
                    results['img'].append(mmcv.imread(sweep[sensor]['data_path'], self.color_type))
                    results['img_timestamp'].append(sweep[sensor]['timestamp'] / 1e6)
                    results['filename'].append(os.path.relpath(sweep[sensor]['data_path']))
                    results['lidar2img'].append(compose_lidar2img(
                        results['ego2global_translation'],
                        results['ego2global_rotation'],
                        results['lidar2ego_translation'],
                        results['lidar2ego_rotation'],
                        sweep[sensor]['sensor2global_translation'],
                        sweep[sensor]['sensor2global_rotation'],
                        sweep[sensor]['cam_intrinsic'],
                    ))
                    if 'ego2lidar' in results:
                        results['ego2lidar'].append(results['ego2lidar'][0])

        return results

    def load_online(self, results):
        """Append past-sweep metadata only, skipping image decoding.

        Mirrors :meth:`load_offline` but never reads pixels from disk.
        """
        # only used when measuring FPS
        assert self.test_mode
        assert self.test_interval % 6 == 0

        cam_types = [
            'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
            'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT'
        ]

        if len(results['sweeps']['prev']) == 0:
            # no history available: repeat the current frame's metadata
            for _ in range(self.sweeps_num):
                for j in range(len(cam_types)):
                    results['img_timestamp'].append(results['img_timestamp'][j])
                    results['filename'].append(results['filename'][j])
                    results['lidar2img'].append(np.copy(results['lidar2img'][j]))
                    if 'ego2lidar' in results:
                        results['ego2lidar'].append(results['ego2lidar'][0])
        else:
            interval = self.test_interval
            choices = [(k + 1) * interval - 1 for k in range(self.sweeps_num)]

            for idx in sorted(list(choices)):
                sweep_idx = min(idx, len(results['sweeps']['prev']) - 1)
                sweep = results['sweeps']['prev'][sweep_idx]
                # fall back to the previous sweep when cameras are missing
                if len(sweep.keys()) < len(cam_types):
                    sweep = results['sweeps']['prev'][sweep_idx - 1]

                for sensor in cam_types:
                    # skip loading history frames
                    results['img_timestamp'].append(sweep[sensor]['timestamp'] / 1e6)
                    results['filename'].append(os.path.relpath(sweep[sensor]['data_path']))
                    results['lidar2img'].append(compose_lidar2img(
                        results['ego2global_translation'],
                        results['ego2global_rotation'],
                        results['lidar2ego_translation'],
                        results['lidar2ego_rotation'],
                        sweep[sensor]['sensor2global_translation'],
                        sweep[sensor]['sensor2global_rotation'],
                        sweep[sensor]['cam_intrinsic'],
                    ))
                    if 'ego2lidar' in results:
                        results['ego2lidar'].append(results['ego2lidar'][0])

        return results

    def __call__(self, results):
        """Dispatch to online (metadata-only) or offline (full) loading."""
        if self.sweeps_num == 0:
            return results

        world_size = get_dist_info()[1]
        if world_size == 1 and self.test_mode:
            # single-process test run: used for FPS measurement only
            return self.load_online(results)
        else:
            return self.load_offline(results)
@PIPELINES.register_module()
class LoadOccGTFromFile(object):
    """Load occupancy ground truth (semantics + optional instances) from npz.

    Adds to ``results``:
        voxel_semantics: [200, 200, 16] semantic class ids.
        voxel_instances: [200, 200, 16] densely remapped instance ids,
            with 255 marking empty space.
        instance_class_ids: class id of each remapped instance.

    Classes listed in ``inst_class_ids`` keep their per-object instance ids
    ("things"); every other class collapses into one instance ("stuff").
    """

    def __init__(self, num_classes=18, inst_class_ids=[]):
        # num_classes: total classes including the trailing free class
        # inst_class_ids: class ids treated as per-object instances
        self.num_classes = num_classes
        self.inst_class_ids = inst_class_ids

    def __call__(self, results):
        occ_labels = np.load(results['occ_path'])
        semantics = occ_labels['semantics']  # [200, 200, 16]
        # mask_lidar = occ_labels['mask_lidar'].astype(np.bool_)  # [200, 200, 16]
        # mask_camera = occ_labels['mask_camera'].astype(np.bool_)  # [200, 200, 16]
        # results['mask_lidar'] = mask_lidar
        # results['mask_camera'] = mask_camera

        # instance GT: recover the semantic class of each raw instance id
        if 'instances' in occ_labels.keys():
            instances = occ_labels['instances']
            instance_class_ids = [self.num_classes - 1]  # the 0-th class is always free class
            for i in range(1, instances.max() + 1):
                class_id = np.unique(semantics[instances == i])
                assert class_id.shape[0] == 1, "each instance must belong to only one class"
                instance_class_ids.append(class_id[0])
            instance_class_ids = np.array(instance_class_ids)
        else:
            instances = None
            instance_class_ids = None

        # remap raw ids to a dense 0..K range: "thing" classes keep one id
        # per object, "stuff" classes get a single id each
        instance_count = 0
        final_instance_class_ids = []
        final_instances = np.ones_like(semantics) * 255  # empty space has instance id "255"

        for class_id in range(self.num_classes - 1):
            if np.sum(semantics == class_id) == 0:
                continue

            if class_id in self.inst_class_ids:
                assert instances is not None, 'instance annotation not found'
                # treat as instances
                for instance_id in range(len(instance_class_ids)):
                    if instance_class_ids[instance_id] != class_id:
                        continue
                    final_instances[instances == instance_id] = instance_count
                    instance_count += 1
                    final_instance_class_ids.append(class_id)
            else:
                # treat as semantics
                final_instances[semantics == class_id] = instance_count
                instance_count += 1
                final_instance_class_ids.append(class_id)

        results['voxel_semantics'] = semantics
        results['voxel_instances'] = final_instances
        results['instance_class_ids'] = DC(to_tensor(final_instance_class_ids))

        # mirror the BEV augmentation flags (set earlier, e.g. by BEVAug)
        # on the GT volumes; 255 fills the voxels rotated in from outside
        if results.get('rotate_bda', False):
            semantics = torch.from_numpy(semantics).permute(2, 0, 1)  # [16, 200, 200]
            semantics = rotate(semantics, results['rotate_bda'], fill=255).permute(1, 2, 0)  # [200, 200, 16]
            results['voxel_semantics'] = semantics.numpy()
            final_instances = torch.from_numpy(final_instances).permute(2, 0, 1)  # [16, 200, 200]
            final_instances = rotate(final_instances, results['rotate_bda'], fill=255).permute(1, 2, 0)  # [200, 200, 16]
            results['voxel_instances'] = final_instances.numpy()

        if results.get('flip_dx', False):
            results['voxel_semantics'] = results['voxel_semantics'][::-1, ...].copy()
            results['voxel_instances'] = results['voxel_instances'][::-1, ...].copy()

        if results.get('flip_dy', False):
            results['voxel_semantics'] = results['voxel_semantics'][:, ::-1, ...].copy()
            results['voxel_instances'] = results['voxel_instances'][:, ::-1, ...].copy()

        return results
# https://github.com/HuangJunJie2017/BEVDet/blob/58c2587a8f89a1927926f0bdb6cde2917c91a9a5/mmdet3d/datasets/pipelines/loading.py#L1177
@PIPELINES.register_module()
class BEVAug(object):
    """BEV-space augmentation: random rotation, scaling and axis flips.

    One parameter set is sampled per sample, recorded in ``results`` (so
    later pipeline steps can warp the occupancy GT identically) and the
    inverse transform is folded into every ``ego2lidar`` matrix.
    """

    def __init__(self, bda_aug_conf, classes, is_train=True):
        self.bda_aug_conf = bda_aug_conf
        self.is_train = is_train
        self.classes = classes

    def sample_bda_augmentation(self):
        """Draw (rotate, scale, flip_dx, flip_dy); identity when not training."""
        if not self.is_train:
            return 0, 1.0, False, False
        conf = self.bda_aug_conf
        rotate_bda = np.random.uniform(*conf['rot_lim'])
        scale_bda = np.random.uniform(*conf['scale_lim'])
        flip_dx = np.random.uniform() < conf['flip_dx_ratio']
        flip_dy = np.random.uniform() < conf['flip_dy_ratio']
        return rotate_bda, scale_bda, flip_dx, flip_dy

    def bev_transform(self, rotate_angle, scale_ratio, flip_dx, flip_dy):
        """Compose rotation, scaling and flips into one (3, 3) matrix.

        Returns:
            torch.Tensor: rot_mat of shape (3, 3).
        """
        theta = torch.tensor(rotate_angle / 180 * np.pi)
        sin_t = torch.sin(theta)
        cos_t = torch.cos(theta)
        rot_mat = torch.Tensor([[cos_t, -sin_t, 0], [sin_t, cos_t, 0],
                                [0, 0, 1]])
        scale_mat = torch.Tensor([[scale_ratio, 0, 0], [0, scale_ratio, 0],
                                  [0, 0, scale_ratio]])
        flip_mat = torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
        if flip_dx:
            flip_mat = flip_mat @ torch.Tensor([[-1, 0, 0], [0, 1, 0],
                                                [0, 0, 1]])
        if flip_dy:
            flip_mat = flip_mat @ torch.Tensor([[1, 0, 0], [0, -1, 0],
                                                [0, 0, 1]])
        return flip_mat @ (scale_mat @ rot_mat)

    def __call__(self, results):
        rotate_bda, scale_bda, flip_dx, flip_dy = self.sample_bda_augmentation()

        # pack the (3, 3) BEV transform into a homogeneous (4, 4) matrix
        bda_mat = torch.zeros(4, 4)
        bda_mat[3, 3] = 1
        bda_mat[:3, :3] = self.bev_transform(rotate_bda, scale_bda, flip_dx, flip_dy)

        results['bda_mat'] = bda_mat
        results['flip_dx'] = flip_dx
        results['flip_dy'] = flip_dy
        results['rotate_bda'] = rotate_bda
        results['scale_bda'] = scale_bda

        # fold the inverse augmentation into every ego2lidar matrix
        inv_bda = torch.inverse(bda_mat).numpy()
        for i in range(len(results['ego2lidar'])):
            results['ego2lidar'][i] = results['ego2lidar'][i] @ inv_bda  # [4, 4] @ [4, 4]

        return results
================================================
FILE: loaders/pipelines/transforms.py
================================================
import mmcv
import torch
import numpy as np
from PIL import Image
from numpy import random
from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module()
class PadMultiViewImage(object):
    """Pad every view of a multi-view image set.

    Two padding modes: (1) pad to a fixed size, (2) pad to the minimum
    size divisible by some number. Exactly one of the two must be given.

    Added keys are "ori_shape", "img_shape", "pad_shape", "pad_fixed_size"
    and "pad_size_divisor".

    Args:
        size (tuple, optional): Fixed (h, w) padding target.
        size_divisor (int, optional): The divisor of the padded size.
        pad_val (float, optional): Padding value, 0 by default.
    """

    def __init__(self, size=None, size_divisor=None, pad_val=0):
        self.size = size
        self.size_divisor = size_divisor
        self.pad_val = pad_val
        # only one of size and size_divisor should be valid
        assert size is not None or size_divisor is not None
        assert size is None or size_divisor is None

    def _pad_img(self, img):
        # choose the target extent: round H/W up to a multiple of the
        # divisor, or use the fixed size directly
        if self.size_divisor is None:
            target_h, target_w = self.size
        else:
            div = self.size_divisor
            target_h = int(np.ceil(img.shape[0] / div)) * div
            target_w = int(np.ceil(img.shape[1] / div)) * div
        widths = ((0, target_h - img.shape[0]), (0, target_w - img.shape[1]), (0, 0))
        return np.pad(img, widths, constant_values=self.pad_val)

    def _pad_imgs(self, results):
        originals = results['img']
        padded = [self._pad_img(im) for im in originals]
        results['ori_shape'] = [im.shape for im in originals]
        results['img'] = padded
        results['img_shape'] = [im.shape for im in padded]
        results['pad_shape'] = [im.shape for im in padded]
        results['pad_fixed_size'] = self.size
        results['pad_size_divisor'] = self.size_divisor

    def __call__(self, results):
        """Pad all images in ``results`` in place.

        Args:
            results (dict): Result dict from the loading pipeline.

        Returns:
            dict: Updated result dict.
        """
        self._pad_imgs(results)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(size={self.size}, '
        repr_str += f'size_divisor={self.size_divisor}, '
        repr_str += f'pad_val={self.pad_val})'
        return repr_str
@PIPELINES.register_module()
class NormalizeMultiviewImage(object):
    """Normalize every view of a multi-view image set.

    Added key is "img_norm_cfg".

    Args:
        mean (sequence): Mean values of 3 channels.
        std (sequence): Std values of 3 channels (stored as reciprocals so
            the per-pixel work is a multiply instead of a divide).
        to_rgb (bool): Whether to convert the image from BGR to RGB,
            default is true.
    """

    def __init__(self, mean, std, to_rgb=True):
        self.mean = np.array(mean, dtype=np.float32).reshape(-1)
        # NOTE: self.std holds 1/std, not std itself
        self.std = 1 / np.array(std, dtype=np.float32).reshape(-1)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Normalize ``results['img']`` in place and add "img_norm_cfg".

        Args:
            results (dict): Result dict from the loading pipeline.

        Returns:
            dict: Normalized results.
        """
        processed = []
        for raw in results['img']:
            view = raw.astype(np.float32)
            if self.to_rgb:
                view = view[..., ::-1]  # BGR -> RGB channel swap
            processed.append((view - self.mean) * self.std)
        results['img'] = processed
        # NOTE: 'std' here is the stored reciprocal (matches historic behavior)
        results['img_norm_cfg'] = dict(
            mean=self.mean,
            std=self.std,
            to_rgb=self.to_rgb
        )
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'
        return repr_str
@PIPELINES.register_module()
class PhotoMetricDistortionMultiViewImage:
    """Apply photometric distortion to each image sequentially.

    Every transformation is applied with a probability of 0.5. The position
    of random contrast is in second or second to last.

    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)
    8. randomly swap channels

    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def __call__(self, results):
        """Apply photometric distortion to every image in ``results``.

        Args:
            results (dict): Result dict from the loading pipeline.

        Returns:
            dict: Result dict with images distorted.
        """
        imgs = results['img']
        new_imgs = []
        for img in imgs:
            # work in float32; the original dtype is restored at the end
            ori_dtype = img.dtype
            img = img.astype(np.float32)
            # random brightness
            if random.randint(2):
                delta = random.uniform(-self.brightness_delta,
                                       self.brightness_delta)
                img += delta

            # mode == 0 --> do random contrast first
            # mode == 1 --> do random contrast last
            mode = random.randint(2)
            if mode == 1:
                if random.randint(2):
                    alpha = random.uniform(self.contrast_lower,
                                           self.contrast_upper)
                    img *= alpha

            # convert color from BGR to HSV
            img = mmcv.bgr2hsv(img)

            # random saturation
            if random.randint(2):
                img[..., 1] *= random.uniform(self.saturation_lower,
                                              self.saturation_upper)

            # random hue; wrap the hue channel back into [0, 360)
            if random.randint(2):
                img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
                img[..., 0][img[..., 0] > 360] -= 360
                img[..., 0][img[..., 0] < 0] += 360

            # convert color from HSV to BGR
            img = mmcv.hsv2bgr(img)

            # random contrast
            if mode == 0:
                if random.randint(2):
                    alpha = random.uniform(self.contrast_lower,
                                           self.contrast_upper)
                    img *= alpha

            # randomly swap channels
            if random.randint(2):
                img = img[..., random.permutation(3)]
            new_imgs.append(img.astype(ori_dtype))
        results['img'] = new_imgs
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(\nbrightness_delta={self.brightness_delta},\n'
        repr_str += 'contrast_range='
        repr_str += f'{(self.contrast_lower, self.contrast_upper)},\n'
        repr_str += 'saturation_range='
        repr_str += f'{(self.saturation_lower, self.saturation_upper)},\n'
        repr_str += f'hue_delta={self.hue_delta})'
        return repr_str
@PIPELINES.register_module()
class RandomTransformImage(object):
    """Resize / crop / flip / rotate every view with one shared augmentation.

    A single augmentation is sampled per call and applied to all images;
    the matching 4x4 image-domain-augmentation (IDA) matrix is folded
    into the ``lidar2img`` projection matrices.
    """

    def __init__(self, ida_aug_conf=None, training=True):
        self.ida_aug_conf = ida_aug_conf
        self.training = training

    def __call__(self, results):
        resize, resize_dims, crop, flip, rotate = self.sample_augmentation()

        def warp(raw_img):
            # apply the sampled augmentation to one image and return the
            # warped pixels together with the IDA matrix
            pil = Image.fromarray(np.uint8(raw_img))
            pil, mat = self.img_transform(
                pil,
                resize=resize,
                resize_dims=resize_dims,
                crop=crop,
                flip=flip,
                rotate=rotate,
            )
            return np.array(pil).astype(np.uint8), mat

        num_imgs = len(results['img'])
        if len(results['lidar2img']) == num_imgs:
            # one projection per image: update them pairwise
            for i in range(num_imgs):
                results['img'][i], ida_mat = warp(results['img'][i])
                results['lidar2img'][i] = ida_mat @ results['lidar2img'][i]
        elif num_imgs == 6:
            # single-frame case with extra projections: the IDA matrix is
            # identical for every view, so apply it to all projections
            for i in range(num_imgs):
                results['img'][i], ida_mat = warp(results['img'][i])
            for i in range(len(results['lidar2img'])):
                results['lidar2img'][i] = ida_mat @ results['lidar2img'][i]
        else:
            raise ValueError()

        results['ori_shape'] = [img.shape for img in results['img']]
        results['img_shape'] = [img.shape for img in results['img']]
        results['pad_shape'] = [img.shape for img in results['img']]

        return results

    def img_transform(self, img, resize, resize_dims, crop, flip, rotate):
        """Warp one PIL image and return the matching 4x4 IDA matrix.

        https://github.com/Megvii-BaseDetection/BEVStereo/blob/master/dataset/nusc_mv_det_dataset.py#L48
        """
        def rot2d(rad):
            return torch.Tensor([
                [np.cos(rad), np.sin(rad)],
                [-np.sin(rad), np.cos(rad)],
            ])

        post_rot = torch.eye(2)
        post_tran = torch.zeros(2)

        # adjust image
        img = img.resize(resize_dims)
        img = img.crop(crop)
        if flip:
            img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
        img = img.rotate(rotate)

        # track the same operations as a 2D affine (post-homography) transform
        post_rot *= resize
        post_tran -= torch.Tensor(crop[:2])
        if flip:
            A = torch.Tensor([[-1, 0], [0, 1]])
            b = torch.Tensor([crop[2] - crop[0], 0])
            post_rot = A.matmul(post_rot)
            post_tran = A.matmul(post_tran) + b
        A = rot2d(rotate / 180 * np.pi)
        b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2
        b = A.matmul(-b) + b
        post_rot = A.matmul(post_rot)
        post_tran = A.matmul(post_tran) + b

        # embed the 2D affine into a homogeneous 4x4 matrix
        ida_mat = torch.eye(4)
        ida_mat[:2, :2] = post_rot
        ida_mat[:2, 2] = post_tran

        return img, ida_mat.numpy()

    def sample_augmentation(self):
        """Sample (resize, resize_dims, crop, flip, rotate).

        https://github.com/Megvii-BaseDetection/BEVStereo/blob/master/dataset/nusc_mv_det_dataset.py#L247
        """
        H, W = self.ida_aug_conf['H'], self.ida_aug_conf['W']
        fH, fW = self.ida_aug_conf['final_dim']
        if not self.training:
            # deterministic center crop at the minimal covering scale
            resize = max(fH / H, fW / W)
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            crop_h = int((1 - np.mean(self.ida_aug_conf['bot_pct_lim'])) * newH) - fH
            crop_w = int(max(0, newW - fW) / 2)
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            return resize, resize_dims, crop, False, 0
        resize = np.random.uniform(*self.ida_aug_conf['resize_lim'])
        resize_dims = (int(W * resize), int(H * resize))
        newW, newH = resize_dims
        crop_h = int((1 - np.random.uniform(*self.ida_aug_conf['bot_pct_lim'])) * newH) - fH
        crop_w = int(np.random.uniform(0, max(0, newW - fW)))
        crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
        # rand_flip short-circuits before np.random.choice to keep the
        # random-call sequence identical to the historic implementation
        flip = bool(self.ida_aug_conf['rand_flip'] and np.random.choice([0, 1]))
        rotate = np.random.uniform(*self.ida_aug_conf['rot_lim'])
        return resize, resize_dims, crop, flip, rotate
@PIPELINES.register_module()
class GlobalRotScaleTransImage(object):
    """Global 3D rotation/scaling augmentation.

    Rotates and scales the ground-truth boxes while folding the inverse
    transform into every ``lidar2img`` projection, keeping the image
    evidence consistent with the augmented geometry. Translation is not
    supported yet.
    """

    def __init__(self,
                 rot_range=[-0.3925, 0.3925],
                 scale_ratio_range=[0.95, 1.05],
                 translation_std=[0, 0, 0]):
        self.rot_range = rot_range
        self.scale_ratio_range = scale_ratio_range
        self.translation_std = translation_std

    def __call__(self, results):
        # random rotation about the z axis
        rot_angle = np.random.uniform(*self.rot_range)
        self.rotate_z(results, rot_angle)
        results["gt_bboxes_3d"].rotate(np.array(rot_angle))

        # random isotropic scaling
        scale_ratio = np.random.uniform(*self.scale_ratio_range)
        self.scale_xyz(results, scale_ratio)
        results["gt_bboxes_3d"].scale(scale_ratio)

        # TODO: support translation
        return results

    def rotate_z(self, results, rot_angle):
        """Fold the inverse z-rotation into every lidar2img matrix."""
        cos_a = torch.cos(torch.tensor(rot_angle))
        sin_a = torch.sin(torch.tensor(rot_angle))
        rot_mat = torch.tensor([
            [cos_a, -sin_a, 0, 0],
            [sin_a, cos_a, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ])
        self._apply_inverse(results, rot_mat)

    def scale_xyz(self, results, scale_ratio):
        """Fold the inverse isotropic scaling into every lidar2img matrix."""
        scale_mat = torch.tensor([
            [scale_ratio, 0, 0, 0],
            [0, scale_ratio, 0, 0],
            [0, 0, scale_ratio, 0],
            [0, 0, 0, 1],
        ])
        self._apply_inverse(results, scale_mat)

    def _apply_inverse(self, results, mat):
        # right-multiply each projection by inv(mat)
        inv_mat = torch.inverse(mat)
        for view in range(len(results['lidar2img'])):
            results['lidar2img'][view] = (
                torch.tensor(results['lidar2img'][view]).float() @ inv_mat).numpy()
================================================
FILE: loaders/ray_metrics.py
================================================
# Acknowledgments: https://github.com/tarashakhurana/4d-occ-forecasting
# Modified by Haisong Liu
import math
import copy
import numpy as np
import torch
from torch.utils.cpp_extension import load
from tqdm import tqdm
from prettytable import PrettyTable
from .ray_pq import Metric_RayPQ
# JIT-compile the differentiable volume renderer CUDA extension at import time
dvr = load("dvr", sources=["lib/dvr/dvr.cpp", "lib/dvr/dvr.cu"], verbose=True, extra_cuda_cflags=['-allow-unsupported-compiler'])

# occupancy grid extent [x_min, y_min, z_min, x_max, y_max, z_max] in meters
_pc_range = [-40, -40, -1.0, 40, 40, 5.4]
# meters per voxel: yields the 200 x 200 x 16 grid used by the renderer below
_voxel_size = 0.4
# https://github.com/tarashakhurana/4d-occ-forecasting/blob/ff986082cd6ea10e67ab7839bf0e654736b3f4e2/test_fgbg.py#L29C1-L46C16
def get_rendered_pcds(origin, points, tindex, pred_dist):
    """Re-project rendered ray depths into 3D point clouds, one per timestep.

    The ground-truth lidar points define the ray directions; each rendered
    distance is pushed along its unit ray from the sensor origin.

    Args:
        origin: [T, 3] sensor origins.
        points: [N, >=3] ray endpoints (only xyz used).
        tindex: [N] timestep index of each ray.
        pred_dist: [N] rendered distance per ray.

    Returns:
        list[torch.Tensor]: one [N_t, 3] point cloud per non-empty timestep.
    """
    pcds = []
    for t, orig_t in enumerate(origin):
        sel = (tindex == t)
        if not sel.any():
            continue  # skip the ones with no data
        rays = points[sel, :3] - orig_t[None, :]
        # normalize to unit direction vectors
        unit = rays / np.sqrt((rays ** 2).sum(axis=1, keepdims=True))
        rendered = orig_t[None, :] + unit * pred_dist[sel][:, None]
        pcds.append(torch.from_numpy(rendered))
    return pcds
def meshgrid3d(occ_size, pc_range):
    """Return the metric center coordinates of every voxel in the grid.

    Args:
        occ_size: (W, H, D) voxel counts per axis.
        pc_range: [x_min, y_min, z_min, x_max, y_max, z_max] in meters.

    Returns:
        torch.Tensor: [W, H, D, 3] voxel-center (x, y, z) coordinates.
    """
    W, H, D = occ_size
    # normalized voxel centers in (0, 1), broadcast over the full grid
    norm_x = (torch.linspace(0.5, W - 0.5, W) / W).view(W, 1, 1).expand(W, H, D)
    norm_y = (torch.linspace(0.5, H - 0.5, H) / H).view(1, H, 1).expand(W, H, D)
    norm_z = (torch.linspace(0.5, D - 0.5, D) / D).view(1, 1, D).expand(W, H, D)
    # map normalized coordinates into the metric point-cloud range
    x_min, y_min, z_min, x_max, y_max, z_max = pc_range
    xs = norm_x * (x_max - x_min) + x_min
    ys = norm_y * (y_max - y_min) + y_min
    zs = norm_z * (z_max - z_min) + z_min
    return torch.stack((xs, ys, zs), -1)
def generate_lidar_rays():
    """Build a fixed bundle of simulated lidar ray directions.

    Pitch angles start with ten rays below the horizon at
    ``-(pi/2 - atan(k + 1))`` and are then extrapolated upwards with the
    last spacing until the nuScenes upper fov bound (~0.21 rad) is
    reached. Each pitch is swept over 360 one-degree azimuth steps.

    Returns:
        np.ndarray: float32 [num_rays, 3] unit direction vectors.
    """
    pitch_angles = [-(math.pi / 2 - math.atan(k + 1)) for k in range(10)]
    # nuscenes lidar fov: [0.2107773983152201, -0.5439104895672159] (rad)
    while pitch_angles[-1] < 0.21:
        step = pitch_angles[-1] - pitch_angles[-2]
        pitch_angles.append(pitch_angles[-1] + step)

    rays = []
    for pitch in pitch_angles:
        # hoist the pitch trig out of the azimuth sweep
        cos_p, sin_p = np.cos(pitch), np.sin(pitch)
        for azimuth_deg in np.arange(0, 360, 1):
            azimuth = np.deg2rad(azimuth_deg)
            rays.append((cos_p * np.cos(azimuth), cos_p * np.sin(azimuth), sin_p))

    return np.array(rays, dtype=np.float32)
def process_one_sample(sem_pred, lidar_rays, output_origin, instance_pred=None, occ_class_names=None):
    """Render one occupancy volume into per-ray (label[, instance], depth).

    Casts the fixed lidar ray bundle from each origin through a binarized
    occupancy grid with the CUDA ``dvr`` renderer, then reads the semantic
    (and optionally instance) label at the voxel where each ray terminates.

    Args:
        sem_pred: [200, 200, 16] semantic voxel labels (torch tensor).
        lidar_rays: [N, 3] unit ray directions from generate_lidar_rays().
        output_origin: [1, T, 3] lidar origins, one per timestep.
        instance_pred: optional [200, 200, 16] instance ids.
        occ_class_names: class list; the last entry is the free class.

    Returns:
        np.ndarray: [T*N, 2] rows of (label, depth), or [T*N, 3] rows of
        (label, instance, depth) when ``instance_pred`` is given.
    """
    # lidar origin in ego coordinate
    # lidar_origin = torch.tensor([[[0.9858, 0.0000, 1.8402]]])
    T = output_origin.shape[1]
    pred_pcds_t = []

    free_id = len(occ_class_names) - 1
    occ_pred = copy.deepcopy(sem_pred)
    # binarize the volume: 1 = occupied, 0 = free
    occ_pred[sem_pred < free_id] = 1
    occ_pred[sem_pred == free_id] = 0
    occ_pred = occ_pred.permute(2, 1, 0)  # reorder axes to the renderer's layout
    occ_pred = occ_pred[None, None, :].contiguous().float()

    # constants converting metric coordinates to voxel coordinates
    offset = torch.Tensor(_pc_range[:3])[None, None, :]
    scaler = torch.Tensor([_voxel_size] * 3)[None, None, :]

    lidar_tindex = torch.zeros([1, lidar_rays.shape[0]])

    for t in range(T):
        lidar_origin = output_origin[:, t:t+1, :]  # [1, 1, 3]
        lidar_endpts = lidar_rays[None] + lidar_origin  # [1, 15840, 3]

        # express origin and endpoints in voxel coordinates
        output_origin_render = ((lidar_origin - offset) / scaler).float()  # [1, 1, 3]
        output_points_render = ((lidar_endpts - offset) / scaler).float()  # [1, N, 3]
        output_tindex_render = lidar_tindex  # [1, N], all zeros

        with torch.no_grad():
            pred_dist, _, coord_index = dvr.render_forward(
                occ_pred.cuda(),
                output_origin_render.cuda(),
                output_points_render.cuda(),
                output_tindex_render.cuda(),
                [1, 16, 200, 200],
                "test"
            )
            pred_dist *= _voxel_size  # voxel units -> meters

        # NOTE: a previous revision also called get_rendered_pcds() here,
        # but its result was immediately overwritten below, so the dead
        # (and per-timestep expensive) call has been removed.

        # look up the label of the voxel each ray terminated in
        coord_index = coord_index[0, :, :].int().cpu()  # [N, 3]

        pred_label = sem_pred[coord_index[:, 0], coord_index[:, 1], coord_index[:, 2]][:, None]  # [N, 1]
        pred_dist = pred_dist[0, :, None].cpu()

        if instance_pred is not None:
            pred_instance = instance_pred[coord_index[:, 0], coord_index[:, 1], coord_index[:, 2]][:, None]  # [N, 1]
            pred_pcds = torch.cat([pred_label.float(), pred_instance.float(), pred_dist], dim=-1)
        else:
            pred_pcds = torch.cat([pred_label.float(), pred_dist], dim=-1)

        pred_pcds_t.append(pred_pcds)

    pred_pcds_t = torch.cat(pred_pcds_t, dim=0)

    return pred_pcds_t.numpy()
def calc_rayiou(pcd_pred_list, pcd_gt_list, occ_class_names):
    """Compute per-class RayIoU at depth thresholds of 1, 2 and 4 meters.

    Each element of the pcd lists is an [N, 2] array of (class_id, depth)
    per ray. A ray is a true positive for its class when prediction and
    ground truth agree on the class AND the depth L1 error is below the
    threshold.

    Args:
        pcd_pred_list: per-sample predicted (label, depth) ray arrays.
        pcd_gt_list: per-sample ground-truth (label, depth) ray arrays.
        occ_class_names: class names; the last entry is the free class.

    Returns:
        list[np.ndarray]: for each threshold, per-class IoU over all
        classes except the trailing free class.
    """
    thresholds = [1, 2, 4]

    gt_cnt = np.zeros([len(occ_class_names)])
    pred_cnt = np.zeros([len(occ_class_names)])
    tp_cnt = np.zeros([len(thresholds), len(occ_class_names)])

    for pcd_pred, pcd_gt in zip(pcd_pred_list, pcd_gt_list):
        # the depth L1 error is threshold-independent: compute it once
        # per sample instead of once per threshold
        l1_error = np.abs(pcd_pred[:, 1] - pcd_gt[:, 1])

        # iterate classes by index; this replaces the redundant O(n)
        # occ_class_names.index(cls) lookup, which always equals i for
        # unique class names
        for i in range(len(occ_class_names)):
            cls_mask_pred = (pcd_pred[:, 0] == i)
            cls_mask_gt = (pcd_gt[:, 0] == i)

            gt_cnt[i] += cls_mask_gt.sum()
            pred_cnt[i] += cls_mask_pred.sum()

            tp_cls = cls_mask_gt & cls_mask_pred  # [N]
            for j, threshold in enumerate(thresholds):
                tp_cnt[j][i] += np.logical_and(tp_cls, l1_error < threshold).sum()

    iou_list = []
    for j, threshold in enumerate(thresholds):
        # IoU = TP / (GT + Pred - TP); the trailing free class is dropped
        iou_list.append((tp_cnt[j] / (gt_cnt + pred_cnt - tp_cnt[j]))[:-1])

    return iou_list
def main_rayiou(sem_pred_list, sem_gt_list, lidar_origin_list, occ_class_names):
    """Render predicted/GT occupancy into rays and report RayIoU metrics.

    Args:
        sem_pred_list: per-sample predicted semantics, reshapeable to [200, 200, 16].
        sem_gt_list: per-sample ground-truth semantics, same layout.
        lidar_origin_list: per-sample lidar origins used for ray casting.
        occ_class_names: class names; the last entry is the free class.

    Returns:
        dict: mean RayIoU and the per-threshold RayIoU@1/2/4 values.
    """
    torch.cuda.empty_cache()

    # generate lidar rays
    rays = torch.from_numpy(generate_lidar_rays())

    free_id = len(occ_class_names) - 1
    pcd_pred_list, pcd_gt_list = [], []
    samples = zip(sem_pred_list, sem_gt_list, lidar_origin_list)
    for sem_pred, sem_gt, lidar_origins in tqdm(samples, ncols=50):
        sem_pred = torch.from_numpy(np.reshape(sem_pred, [200, 200, 16]))
        sem_gt = torch.from_numpy(np.reshape(sem_gt, [200, 200, 16]))

        pcd_pred = process_one_sample(sem_pred, rays, lidar_origins, occ_class_names=occ_class_names)
        pcd_gt = process_one_sample(sem_gt, rays, lidar_origins, occ_class_names=occ_class_names)

        # evaluate on non-free rays only
        keep = (pcd_gt[:, 0].astype(np.int32) != free_id)
        pcd_pred = pcd_pred[keep]
        pcd_gt = pcd_gt[keep]
        assert pcd_pred.shape == pcd_gt.shape

        pcd_pred_list.append(pcd_pred)
        pcd_gt_list.append(pcd_gt)

    iou_list = calc_rayiou(pcd_pred_list, pcd_gt_list, occ_class_names)
    per_thr = [np.nanmean(iou) for iou in iou_list]
    rayiou = np.nanmean(iou_list)

    table = PrettyTable([
        'Class Names',
        'RayIoU@1', 'RayIoU@2', 'RayIoU@4'
    ])
    table.float_format = '.3'

    for i in range(len(occ_class_names) - 1):
        table.add_row([
            occ_class_names[i],
            iou_list[0][i], iou_list[1][i], iou_list[2][i]
        ], divider=(i == len(occ_class_names) - 2))
    table.add_row(['MEAN', per_thr[0], per_thr[1], per_thr[2]])
    print(table)

    torch.cuda.empty_cache()

    return {
        'RayIoU': rayiou,
        'RayIoU@1': per_thr[0],
        'RayIoU@2': per_thr[1],
        'RayIoU@4': per_thr[2],
    }
def main_raypq(sem_pred_list, sem_gt_list, inst_pred_list, inst_gt_list, lidar_origin_list, occ_class_names):
    """Render panoptic occupancy into rays and report RayPQ metrics.

    Same ray-casting protocol as ``main_rayiou``, but each rendered ray
    additionally carries an instance id so panoptic quality can be
    accumulated by ``Metric_RayPQ``.

    Returns:
        The summary produced by ``Metric_RayPQ.count_pq()``.
    """
    torch.cuda.empty_cache()

    metric = Metric_RayPQ(
        occ_class_names=occ_class_names,
        num_classes=len(occ_class_names),
        thresholds=[1, 2, 4]
    )

    # generate lidar rays
    rays = torch.from_numpy(generate_lidar_rays())
    free_id = len(occ_class_names) - 1

    samples = zip(sem_pred_list, sem_gt_list, inst_pred_list, inst_gt_list, lidar_origin_list)
    for sem_pred, sem_gt, inst_pred, inst_gt, lidar_origins in tqdm(samples, ncols=50):
        sem_pred = torch.from_numpy(np.reshape(sem_pred, [200, 200, 16]))
        sem_gt = torch.from_numpy(np.reshape(sem_gt, [200, 200, 16]))
        inst_pred = torch.from_numpy(np.reshape(inst_pred, [200, 200, 16]))
        inst_gt = torch.from_numpy(np.reshape(inst_gt, [200, 200, 16]))

        pcd_pred = process_one_sample(sem_pred, rays, lidar_origins, instance_pred=inst_pred, occ_class_names=occ_class_names)
        pcd_gt = process_one_sample(sem_gt, rays, lidar_origins, instance_pred=inst_gt, occ_class_names=occ_class_names)

        # evaluate on non-free rays only
        keep = (pcd_gt[:, 0].astype(np.int32) != free_id)
        pcd_pred = pcd_pred[keep]
        pcd_gt = pcd_gt[keep]
        assert pcd_pred.shape == pcd_gt.shape

        # ray columns are (class_id, instance_id, depth)
        metric.add_batch(
            pcd_pred[:, 0].astype(np.int32),
            pcd_gt[:, 0].astype(np.int32),
            pcd_pred[:, 1].astype(np.int32),
            pcd_gt[:, 1].astype(np.int32),
            np.abs(pcd_pred[:, 2] - pcd_gt[:, 2]),
        )

    torch.cuda.empty_cache()

    return metric.count_pq()
================================================
FILE: loaders/ray_pq.py
================================================
import numpy as np
from prettytable import PrettyTable
class Metric_RayPQ:
def __init__(self,
occ_class_names,
num_classes=18,
thresholds=[1, 2, 4]):
"""
Args:
ignore_index (llist): Class ids that not be considered in pq counting.
"""
if num_classes == 18 or num_classes == 17:
self.class_names = occ_class_names
else:
raise ValueError
self.num_classes = num_classes
self.id_offset = 2 ** 16
self.eps = 1e-5
self.thresholds = thresholds
self.min_num_points = 10
self.include = np.array(
[n for n in range(self.num_classes - 1)],
dtype=int)
self.cnt = 0
# panoptic stuff
self.pan_tp = np.zeros([len(self.thresholds), num_classes], dtype=int)
self.pan_iou = np.zeros([len(self.thresholds), num_classes], dtype=np.double)
self.pan_fp = np.zeros([len(self.thresholds), num_classes], dtype=int)
self.pan_fn = np.zeros([len(self.thresholds), num_classes], dtype=int)
def add_batch(self,semantics_pred,semantics_gt,instances_pred,instances_gt, l1_error):
    """Accumulate panoptic statistics for one sample.

    Args:
        semantics_pred (np.ndarray): per-ray predicted class ids.
        semantics_gt (np.ndarray): per-ray ground-truth class ids.
        instances_pred (np.ndarray): per-ray predicted instance ids.
        instances_gt (np.ndarray): per-ray ground-truth instance ids.
        l1_error (np.ndarray): per-ray depth L1 error in meters.
    """
    self.cnt += 1  # samples seen so far; used only for reporting
    self.add_panoptic_sample(semantics_pred, semantics_gt, instances_pred, instances_gt, l1_error)
def add_panoptic_sample(self, semantics_pred, semantics_gt, instances_pred, instances_gt, l1_error):
"""Add one sample of panoptic predictions and ground truths for
evaluation.
Args:
semantics_pred (np.ndarray): Semantic predictions.
semantics_gt (np.ndarray): Semantic ground truths.
instances_pred (np.ndarray): Instance predictions.
instances_gt (np.ndarray): Instance ground truths.
"""
# get instance_class_id from instance_gt
instance_class_ids = [self.num_classes - 1]
for i in range(1, instances_gt.max() + 1):
class_id = np.unique(semantics_gt[instances_gt == i])
# assert class_id.shape[0] == 1, "each instance must belong to only one class"
if class_id.shape[0] == 1:
instance_class_ids.append(class_id[0])
else:
instance_class_ids.append(self.num_classes - 1)
instance_class_ids = np.array(instance_class_ids)
instance_count = 1
final_instance_class_ids = []
final_instances = np.zeros_like(instances_gt) # empty space has instance id "0"
for class_id in range(self.num_classes - 1):
if np.sum(semantics_gt == class_id) == 0:
continue
if self.class_names[class_id] in ['car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'motorcycle', 'bicycle', 'pedestrian']:
# treat as instances
for instance_id in range(len(instance_class_ids)):
if instance_class_ids[instance_id] != class_id:
continue
final_instances[instances_gt == instance_id] = instance_count
instance_count += 1
final_instance_class_ids.append(class_id)
else:
# treat as semantics
final_instances[semantics_gt == class_id] = instance_count
instance_count += 1
final_instance_class_ids.append(class_id)
instances_gt = final_instances
# avoid zero (ignored label)
instances_pred = instances_pred + 1
instances_gt = instances_gt + 1
for j, threshold in enumerate(self.thresholds):
tp_dist_mask = l1_error < threshold
# for each class (except the ignored ones)
for cl in self.include:
# get a class mask
pred_inst_in_cl_mask = semantics_pred == cl
gt_inst_in_cl_mask = semantics_gt == cl
# get instance points in class (makes outside stuff 0)
pred_inst_in_cl = instances_pred * pred_inst_in_cl_mask.astype(int)
gt_inst_in_cl = instances_gt * gt_inst_in_cl_mask.astype(int)
# generate the areas for each unique instance prediction
unique_pred, counts_pred = np.unique(
pred_inst_in_cl[pred_inst_in_cl > 0], return_counts=True)
id2idx_pred = {id: idx for idx, id in enumerate(unique_pred)}
matched_pred = np.array([False] * unique_pred.shape[0])
# generate the areas for each unique instance gt_np
unique_gt, counts_gt = np.unique(
gt_inst_in_cl[gt_inst_in_cl > 0], return_counts=True)
id2idx_gt = {id: idx for idx, id in enumerate(unique_gt)}
matched_gt = np.array([False] * unique_gt.shape[0])
# generate intersection using offset
valid_combos = np.logical_and(pred_inst_in_cl > 0,
gt_inst_in_cl > 0)
# add dist_mask
valid_combos = np.logical_and(valid_combos, tp_dist_mask)
id_offset_combo = pred_inst_in_cl[
valid_combos] + self.id_offset * gt_inst_in_cl[valid_combos]
unique_combo, counts_combo = np.unique(
id_offset_combo, return_counts=True)
# generate an intersection map
# count the intersections with over 0.5 IoU as TP
gt_labels = unique_combo // self.id_offset
pred_labels = unique_combo % self.id_offset
gt_areas = np.array([counts_gt[id2idx_gt[id]] for id in gt_labels])
pred_areas = np.array(
[counts_pred[id2idx_pred[id]] for id in pred_labels])
intersections = counts_combo
unions = gt_areas + pred_areas - intersections
ious = intersections.astype(float) / unions.astype(float)
tp_indexes = ious > 0.5
self.pan_tp[j][cl] += np.sum(tp_indexes)
self.pan_iou[j][cl] += np.sum(ious[tp_indexes])
matched_gt[[id2idx_gt[id] for id in gt_labels[tp_indexes]]] = True
matched_pred[[id2idx_pred[id]
for id in pred_labels[tp_indexes]]] = True
# count the FN
if len(counts_gt) > 0:
self.pan_fn[j][cl] += np.sum(
np.logical_and(counts_gt >= self.min_num_points,
~matched_gt))
# count the FP
if len(matched_pred) > 0:
self.pan_fp[j][cl] += np.sum(
np.logical_and(counts_pred >= self.min_num_points,
~matched_pred))
def count_pq(self):
sq_all = self.pan_iou.astype(np.double) / np.maximum(
self.pan_tp.astype(np.double), self.eps)
rq_all = self.pan_tp.astype(np.double) / np.maximum(
self.pan_tp.astype(np.double) + 0.5 * self.pan_fp.astype(np.double)
+ 0.5 * self.pan_fn.astype(np.double), self.eps)
pq_all = sq_all * rq_all
# mask classes not occurring in dataset
mask = (self.pan_tp + self.pan_fp + self.pan_fn) > 0
pq_all[~mask] = float('nan')
table = PrettyTable([
'Class Names',
'RayPQ@%d' % self.thresholds[0],
'RayPQ@%d' % self.thresholds[1],
'RayPQ@%d' % self.thresholds[2]
])
table.float_format = '.3'
for i in range(len(self.class_names) - 1):
table.add_row([
self.class_names[i],
pq_all[0][i], pq_all[1][i], pq_all[2][i],
], divider=(i == len(self.class_names) - 2))
table.add_row([
'MEAN',
np.nanmean(pq_all[0]), np.nanmean(pq_all[1]), np.nanmean(pq_all[2])
])
print(table)
return {
'RayPQ': np.nanmean(pq_all),
'RayPQ@1': np.nanmean(pq_all[0]),
'RayPQ@2': np.nanmean(pq_all[1]),
'RayPQ@4': np.nanmean(pq_all[2]),
}
================================================
FILE: models/__init__.py
================================================
from .backbones import __all__
from .bbox import __all__
from .sparseocc import SparseOcc
from .sparseocc_head import SparseOccHead
from .sparseocc_transformer import SparseOccTransformer
from .loss_utils import *
__all__ = []
================================================
FILE: models/backbones/__init__.py
================================================
from .vovnet import VoVNet
__all__ = ['VoVNet']
================================================
FILE: models/backbones/vovnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
import torch.utils.checkpoint as cp
from collections import OrderedDict
from mmcv.runner import BaseModule
from mmdet.models.builder import BACKBONES
from torch.nn.modules.batchnorm import _BatchNorm
VoVNet19_slim_dw_eSE = {
'stem': [64, 64, 64],
'stage_conv_ch': [64, 80, 96, 112],
'stage_out_ch': [112, 256, 384, 512],
"layer_per_block": 3,
"block_per_stage": [1, 1, 1, 1],
"eSE": True,
"dw": True
}
VoVNet19_dw_eSE = {
'stem': [64, 64, 64],
"stage_conv_ch": [128, 160, 192, 224],
"stage_out_ch": [256, 512, 768, 1024],
"layer_per_block": 3,
"block_per_stage": [1, 1, 1, 1],
"eSE": True,
"dw": True
}
VoVNet19_slim_eSE = {
'stem': [64, 64, 128],
'stage_conv_ch': [64, 80, 96, 112],
'stage_out_ch': [112, 256, 384, 512],
'layer_per_block': 3,
'block_per_stage': [1, 1, 1, 1],
'eSE': True,
"dw": False
}
VoVNet19_eSE = {
'stem': [64, 64, 128],
"stage_conv_ch": [128, 160, 192, 224],
"stage_out_ch": [256, 512, 768, 1024],
"layer_per_block": 3,
"block_per_stage": [1, 1, 1, 1],
"eSE": True,
"dw": False
}
VoVNet39_eSE = {
'stem': [64, 64, 128],
"stage_conv_ch": [128, 160, 192, 224],
"stage_out_ch": [256, 512, 768, 1024],
"layer_per_block": 5,
"block_per_stage": [1, 1, 2, 2],
"eSE": True,
"dw": False
}
VoVNet57_eSE = {
'stem': [64, 64, 128],
"stage_conv_ch": [128, 160, 192, 224],
"stage_out_ch": [256, 512, 768, 1024],
"layer_per_block": 5,
"block_per_stage": [1, 1, 4, 3],
"eSE": True,
"dw": False
}
VoVNet99_eSE = {
'stem': [64, 64, 128],
"stage_conv_ch": [128, 160, 192, 224],
"stage_out_ch": [256, 512, 768, 1024],
"layer_per_block": 5,
"block_per_stage": [1, 3, 9, 3],
"eSE": True,
"dw": False
}
_STAGE_SPECS = {
"V-19-slim-dw-eSE": VoVNet19_slim_dw_eSE,
"V-19-dw-eSE": VoVNet19_dw_eSE,
"V-19-slim-eSE": VoVNet19_slim_eSE,
"V-19-eSE": VoVNet19_eSE,
"V-39-eSE": VoVNet39_eSE,
"V-57-eSE": VoVNet57_eSE,
"V-99-eSE": VoVNet99_eSE,
}
def dw_conv3x3(in_channels, out_channels, module_name, postfix, stride=1, kernel_size=3, padding=1):
    """Depthwise-separable 3x3 convolution: depthwise 3x3 conv followed by a
    pointwise 1x1 conv, BatchNorm and ReLU.

    Note: the depthwise conv uses ``groups=out_channels``, so callers must
    pass ``in_channels == out_channels`` (the OSA module reduces channels
    beforehand to guarantee this).

    Returns:
        list[tuple[str, nn.Module]]: named layers suitable for ``OrderedDict``.
    """
    prefix = f"{module_name}_{postfix}"
    return [
        (
            f"{prefix}/dw_conv3x3",
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=out_channels,  # one filter per channel -> depthwise
                bias=False
            )
        ),
        (
            f"{prefix}/pw_conv1x1",
            # pointwise 1x1 mixes channels after the depthwise spatial conv
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, groups=1, bias=False)
        ),
        (f"{prefix}/pw_norm", nn.BatchNorm2d(out_channels)),
        (f"{prefix}/pw_relu", nn.ReLU(inplace=True)),
    ]
def conv3x3(in_channels, out_channels, module_name, postfix, stride=1, groups=1, kernel_size=3, padding=1):
    """Named (3x3 conv -> BatchNorm -> ReLU) sequence for an ``OrderedDict``."""
    prefix = f"{module_name}_{postfix}"
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        groups=groups,
        bias=False,  # BatchNorm follows, so the conv bias is redundant
    )
    return [
        (f"{prefix}/conv", conv),
        (f"{prefix}/norm", nn.BatchNorm2d(out_channels)),
        (f"{prefix}/relu", nn.ReLU(inplace=True)),
    ]
def conv1x1(in_channels, out_channels, module_name, postfix, stride=1, groups=1, kernel_size=1, padding=0):
    """Named (1x1 conv -> BatchNorm -> ReLU) sequence for an ``OrderedDict``."""
    prefix = f"{module_name}_{postfix}"
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        groups=groups,
        bias=False,  # BatchNorm follows, so the conv bias is redundant
    )
    return [
        (f"{prefix}/conv", conv),
        (f"{prefix}/norm", nn.BatchNorm2d(out_channels)),
        (f"{prefix}/relu", nn.ReLU(inplace=True)),
    ]
class Hsigmoid(nn.Module):
    """Hard sigmoid: ``relu6(x + 3) / 6``, a cheap piecewise-linear sigmoid."""

    def __init__(self, inplace=True):
        super(Hsigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        shifted = F.relu6(x + 3.0, inplace=self.inplace)
        return shifted / 6.0
class eSEModule(nn.Module):
    """Effective Squeeze-and-Excitation (eSE): global average pooling, a
    single full-width 1x1 conv, and a hard sigmoid used to reweight channels."""

    def __init__(self, channel, reduction=4):
        super(eSEModule, self).__init__()
        # `reduction` is accepted for API compatibility but the single FC
        # layer keeps the full channel width
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channel, channel, kernel_size=1, padding=0)
        self.hsigmoid = Hsigmoid()

    def forward(self, x):
        gate = self.avg_pool(x)      # (B, C, 1, 1) channel descriptor
        gate = self.fc(gate)
        gate = self.hsigmoid(gate)   # gate values in [0, 1]
        return x * gate              # channel-wise reweighting
class _OSA_module(nn.Module):
    """One-Shot Aggregation (OSA) block: a chain of 3x3 convs whose input and
    every intermediate output are concatenated once, projected by a 1x1 conv,
    and refined with eSE channel attention.

    NOTE(review): the ``SE`` flag is accepted but never read in this class;
    eSE attention is always applied.
    """

    def __init__(self, in_ch, stage_ch, concat_ch, layer_per_block, module_name, SE=False, identity=False, depthwise=False, with_cp=False):
        super(_OSA_module, self).__init__()
        self.with_cp = with_cp        # gradient checkpointing of the body
        self.identity = identity      # add residual connection in forward
        self.depthwise = depthwise
        self.isReduced = False
        self.layers = nn.ModuleList()
        in_channel = in_ch
        # depthwise convs keep the channel count, so reduce to stage_ch first
        if self.depthwise and in_channel != stage_ch:
            self.isReduced = True
            self.conv_reduction = nn.Sequential(
                OrderedDict(conv1x1(in_channel, stage_ch, "{}_reduction".format(module_name), "0"))
            )
        for i in range(layer_per_block):
            if self.depthwise:
                self.layers.append(nn.Sequential(OrderedDict(dw_conv3x3(stage_ch, stage_ch, module_name, i))))
            else:
                self.layers.append(nn.Sequential(OrderedDict(conv3x3(in_channel, stage_ch, module_name, i))))
            in_channel = stage_ch
        # feature aggregation: original input plus every conv output
        in_channel = in_ch + layer_per_block * stage_ch
        self.concat = nn.Sequential(OrderedDict(conv1x1(in_channel, concat_ch, module_name, "concat")))
        self.ese = eSEModule(concat_ch)

    def _forward(self, x):
        # Collect the input and all intermediate features for one-shot concat.
        identity_feat = x
        output = []
        output.append(x)
        if self.depthwise and self.isReduced:
            x = self.conv_reduction(x)
        for layer in self.layers:
            x = layer(x)
            output.append(x)
        x = torch.cat(output, dim=1)
        xt = self.concat(x)
        xt = self.ese(xt)
        if self.identity:
            xt = xt + identity_feat
        return xt

    def forward(self, x):
        # Checkpoint the body during training to trade compute for memory.
        if self.with_cp and self.training and x.requires_grad:
            return cp.checkpoint(self._forward, x)
        else:
            return self._forward(x)
class _OSA_stage(nn.Sequential):
    """A stage of stacked OSA modules, preceded (for stages 3-5) by a 2x
    max-pool downsampling layer."""

    def __init__(self, in_ch, stage_ch, concat_ch, block_per_stage, layer_per_block, stage_num, SE=False, depthwise=False, with_cp=False):
        super(_OSA_stage, self).__init__()
        # stage 2 follows the stem (already stride 4); later stages downsample
        if not stage_num == 2:
            self.add_module("Pooling", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True))
        # with multiple blocks, disable SE for the non-last blocks
        if block_per_stage != 1:
            SE = False
        module_name = f"OSA{stage_num}_1"
        self.add_module(
            module_name, _OSA_module(in_ch, stage_ch, concat_ch, layer_per_block, module_name, SE, depthwise=depthwise, with_cp=with_cp)
        )
        for i in range(block_per_stage - 1):
            # NOTE(review): intent is "SE only on the last block", but SE was
            # already forced False above when block_per_stage != 1 and is
            # never re-enabled here, so SE stays False for all these blocks.
            if i != block_per_stage - 2:  # last block
                SE = False
            module_name = f"OSA{stage_num}_{i + 2}"
            self.add_module(
                module_name,
                _OSA_module(
                    concat_ch,
                    stage_ch,
                    concat_ch,
                    layer_per_block,
                    module_name,
                    SE,
                    identity=True,  # residual connection for non-first blocks
                    depthwise=depthwise,
                    with_cp=with_cp
                ),
            )
@BACKBONES.register_module()
class VoVNet(BaseModule):
    def __init__(self, spec_name, input_ch=3, out_features=None, frozen_stages=-1, norm_eval=True, with_cp=False, pretrained=None, init_cfg=None):
        """VoVNet backbone built from One-Shot Aggregation stages.

        Args:
            spec_name (str): key into ``_STAGE_SPECS`` selecting the variant.
            input_ch(int) : the number of input channel
            out_features (list[str]): name of the layers whose outputs should
                be returned in forward. Can be anything in "stem", "stage2" ...
            frozen_stages (int): number of OSA stages to freeze; the stem is
                frozen whenever this is >= 0, and -1 disables freezing.
            norm_eval (bool): keep BatchNorm layers in eval mode in training.
            with_cp (bool): enable gradient checkpointing in OSA modules.
            pretrained (str): deprecated; use ``init_cfg`` instead.
        """
        super(VoVNet, self).__init__(init_cfg)
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        stage_specs = _STAGE_SPECS[spec_name]
        stem_ch = stage_specs["stem"]
        config_stage_ch = stage_specs["stage_conv_ch"]
        config_concat_ch = stage_specs["stage_out_ch"]
        block_per_stage = stage_specs["block_per_stage"]
        layer_per_block = stage_specs["layer_per_block"]
        SE = stage_specs["eSE"]
        depthwise = stage_specs["dw"]
        self._out_features = out_features
        # Stem module: three convs with stride 2 on the first and third,
        # giving an overall stride of 4
        conv_type = dw_conv3x3 if depthwise else conv3x3
        stem = conv3x3(input_ch, stem_ch[0], "stem", "1", 2)
        stem += conv_type(stem_ch[0], stem_ch[1], "stem", "2", 1)
        stem += conv_type(stem_ch[1], stem_ch[2], "stem", "3", 2)
        self.add_module("stem", nn.Sequential((OrderedDict(stem))))
        # NOTE(review): "stirde" is a long-standing typo for "stride";
        # kept unchanged in this documentation-only pass
        current_stirde = 4
        self._out_feature_strides = {"stem": current_stirde, "stage2": current_stirde}
        self._out_feature_channels = {"stem": stem_ch[2]}
        stem_out_ch = [stem_ch[2]]
        # input channels of each stage = output channels of the previous one
        in_ch_list = stem_out_ch + config_concat_ch[:-1]
        # OSA stages
        self.stage_names = []
        for i in range(4):  # num_stages
            name = "stage%d" % (i + 2)  # stage 2 ... stage 5
            self.stage_names.append(name)
            self.add_module(
                name,
                _OSA_stage(
                    in_ch_list[i],
                    config_stage_ch[i],
                    config_concat_ch[i],
                    block_per_stage[i],
                    layer_per_block,
                    i + 2,
                    SE,
                    depthwise,
                    with_cp=with_cp
                ),
            )
            self._out_feature_channels[name] = config_concat_ch[i]
            if not i == 0:
                # stages 3-5 halve the resolution via their entry max-pool
                self._out_feature_strides[name] = current_stirde = int(current_stirde * 2)
        # initialize weights
        # self._initialize_weights()

    def _initialize_weights(self):
        # Kaiming init for convs; unused by default (see commented call above).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)

    def forward(self, x):
        """Run stem and stages; return a dict of the requested feature maps."""
        outputs = {}
        x = self.stem(x)
        if "stem" in self._out_features:
            outputs["stem"] = x
        for name in self.stage_names:
            x = getattr(self, name)(x)
            if name in self._out_features:
                outputs[name] = x
        return outputs

    def _freeze_stages(self):
        # Freeze the stem, then the first `frozen_stages` OSA stages.
        if self.frozen_stages >= 0:
            m = getattr(self, 'stem')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'stage{i+1}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Convert the model into training mode while keep normalization layer
        freezed."""
        super(VoVNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
================================================
FILE: models/bbox/__init__.py
================================================
from .assigners import __all__
from .coders import __all__
from .match_costs import __all__
================================================
FILE: models/bbox/assigners/__init__.py
================================================
from .hungarian_assigner_3d import HungarianAssigner3D
__all__ = ['HungarianAssigner3D']
================================================
FILE: models/bbox/assigners/hungarian_assigner_3d.py
================================================
import torch
from mmdet.core.bbox.builder import BBOX_ASSIGNERS
from mmdet.core.bbox.assigners import AssignResult
from mmdet.core.bbox.assigners import BaseAssigner
from mmdet.core.bbox.match_costs import build_match_cost
from ..utils import normalize_bbox
try:
from scipy.optimize import linear_sum_assignment
except ImportError:
linear_sum_assignment = None
@BBOX_ASSIGNERS.register_module()
class HungarianAssigner3D(BaseAssigner):
    """One-to-one assigner for 3D detection: Hungarian matching over a
    classification cost plus an L1 regression cost.

    NOTE(review): ``iou_cost`` and ``pc_range`` are built/stored but not used
    in ``assign`` below.
    """

    def __init__(self,
                 cls_cost=dict(type='ClassificationCost', weight=1.),
                 reg_cost=dict(type='BBoxL1Cost', weight=1.0),
                 iou_cost=dict(type='IoUCost', weight=0.0),
                 pc_range=None):
        self.cls_cost = build_match_cost(cls_cost)
        self.reg_cost = build_match_cost(reg_cost)
        self.iou_cost = build_match_cost(iou_cost)
        self.pc_range = pc_range

    def assign(self,
               bbox_pred,
               cls_pred,
               gt_bboxes,
               gt_labels,
               gt_bboxes_ignore=None,
               code_weights=None,
               with_velo=False):
        """Match each ground truth to exactly one prediction.

        Args:
            bbox_pred (Tensor): [num_query, code_size] predicted boxes in the
                normalized regression format.
            cls_pred (Tensor): [num_query, num_cls] classification scores.
            gt_bboxes (Tensor): [num_gt, box_dim] ground-truth boxes.
            gt_labels (Tensor): [num_gt] ground-truth class labels.
            gt_bboxes_ignore: must be None (unsupported).
            code_weights (Tensor, optional): per-dimension weights applied to
                both predictions and normalized GTs before the L1 cost.
            with_velo (bool): if True include all dims in the regression
                cost; otherwise only the first 8 (drop velocity).

        Returns:
            AssignResult: gt index per query (0 = background, i > 0 = gt i-1).
        """
        assert gt_bboxes_ignore is None, \
            'Only case when gt_bboxes_ignore is None is supported.'
        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
        # 1. assign -1 by default
        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
                                              -1,
                                              dtype=torch.long)
        assigned_labels = bbox_pred.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            if num_gts == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            return AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)
        # 2. compute the weighted costs
        # classification and bboxcost.
        cls_cost = self.cls_cost(cls_pred, gt_labels)
        # regression L1 cost
        normalized_gt_bboxes = normalize_bbox(gt_bboxes)
        if code_weights is not None:
            bbox_pred = bbox_pred * code_weights
            normalized_gt_bboxes = normalized_gt_bboxes * code_weights
        if with_velo:
            reg_cost = self.reg_cost(bbox_pred, normalized_gt_bboxes)
        else:
            reg_cost = self.reg_cost(bbox_pred[:, :8], normalized_gt_bboxes[:, :8])
        # weighted sum of above two costs
        cost = cls_cost + reg_cost
        # 3. do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()
        # guard against NaN/inf so scipy's solver does not fail
        cost = torch.nan_to_num(cost, nan=100.0, posinf=100.0, neginf=-100.0)
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" '
                              'to install scipy first.')
        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        matched_row_inds = torch.from_numpy(matched_row_inds).to(
            bbox_pred.device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(
            bbox_pred.device)
        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
        return AssignResult(
            num_gts, assigned_gt_inds, None, labels=assigned_labels)
================================================
FILE: models/bbox/coders/__init__.py
================================================
from .nms_free_coder import NMSFreeCoder
__all__ = ['NMSFreeCoder']
================================================
FILE: models/bbox/coders/nms_free_coder.py
================================================
import torch
from mmdet.core.bbox import BaseBBoxCoder
from mmdet.core.bbox.builder import BBOX_CODERS
from ..utils import denormalize_bbox
@BBOX_CODERS.register_module()
class NMSFreeCoder(BaseBBoxCoder):
    """Bbox coder for NMS-free detector.

    Args:
        pc_range (list[float]): Range of point cloud.
        voxel_size (list[float], optional): Stored but unused here.
        post_center_range (list[float]): Limit of the center.
            Default: None.
        max_num (int): Max number to be kept. Default: 100.
        score_threshold (float): Threshold to filter boxes based on score.
            Default: None.
        num_classes (int): Number of classes. Default: 10.
    """

    def __init__(self,
                 pc_range,
                 voxel_size=None,
                 post_center_range=None,
                 max_num=100,
                 score_threshold=None,
                 num_classes=10):
        self.pc_range = pc_range
        self.voxel_size = voxel_size
        self.post_center_range = post_center_range
        self.max_num = max_num
        self.score_threshold = score_threshold
        self.num_classes = num_classes

    def encode(self):
        # Encoding is not needed for this coder.
        pass

    def decode_single(self, cls_scores, bbox_preds):
        """Decode bboxes.

        Args:
            cls_scores (Tensor): Outputs from the classification head, \
                shape [num_query, cls_out_channels]. Note \
                cls_out_channels should includes background.
            bbox_preds (Tensor): Outputs from the regression \
                head with normalized coordinate format (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy). \
                Shape [num_query, 9].

        Returns:
            dict: Decoded boxes with keys 'bboxes', 'scores', 'labels'.
        """
        max_num = self.max_num
        cls_scores = cls_scores.sigmoid()
        # top-k over the flattened (query, class) score matrix
        scores, indexs = cls_scores.view(-1).topk(max_num)
        labels = indexs % self.num_classes
        bbox_index = torch.div(indexs, self.num_classes, rounding_mode='trunc')
        bbox_preds = bbox_preds[bbox_index]
        final_box_preds = denormalize_bbox(bbox_preds)
        final_scores = scores
        final_preds = labels
        # use score threshold
        if self.score_threshold is not None:
            thresh_mask = final_scores > self.score_threshold
        if self.post_center_range is not None:
            # keep only boxes whose center lies inside the post range
            limit = torch.tensor(self.post_center_range, device=scores.device)
            mask = (final_box_preds[..., :3] >= limit[:3]).all(1)
            mask &= (final_box_preds[..., :3] <= limit[3:]).all(1)
            # NOTE(review): truthiness check, so a threshold of 0.0 is
            # treated the same as None here
            if self.score_threshold:
                mask &= thresh_mask
            boxes3d = final_box_preds[mask]
            scores = final_scores[mask]
            labels = final_preds[mask]
            predictions_dict = {
                'bboxes': boxes3d,
                'scores': scores,
                'labels': labels
            }
        else:
            raise NotImplementedError(
                'Need to reorganize output as a batch, only '
                'support post_center_range is not None for now!'
            )
        return predictions_dict

    def decode(self, preds_dicts):
        """Decode bboxes.

        Args:
            preds_dicts (dict): with 'all_cls_scores' (Tensor) \
                [nb_dec, bs, num_query, cls_out_channels] and \
                'all_bbox_preds' (Tensor) [nb_dec, bs, num_query, 9]; only \
                the last decoder layer's outputs are decoded.

        Returns:
            list[dict]: Decoded boxes, one dict per batch sample.
        """
        all_cls_scores = preds_dicts['all_cls_scores'][-1]
        all_bbox_preds = preds_dicts['all_bbox_preds'][-1]
        batch_size = all_cls_scores.size()[0]
        predictions_list = []
        for i in range(batch_size):
            predictions_list.append(self.decode_single(all_cls_scores[i], all_bbox_preds[i]))
        return predictions_list
================================================
FILE: models/bbox/match_costs/__init__.py
================================================
from .match_cost import BBox3DL1Cost
__all__ = ['BBox3DL1Cost']
================================================
FILE: models/bbox/match_costs/match_cost.py
================================================
import torch
from mmdet.core.bbox.match_costs.builder import MATCH_COST
@MATCH_COST.register_module()
class BBox3DL1Cost(object):
    """Pairwise L1 matching cost between predicted and ground-truth 3D boxes.

    Args:
        weight (int | float, optional): loss_weight
    """

    def __init__(self, weight=1.0):
        self.weight = weight

    def __call__(self, bbox_pred, gt_bboxes):
        """Compute the weighted pairwise L1 distance matrix.

        Args:
            bbox_pred (Tensor): Predicted boxes, shape [num_query, dim].
            gt_bboxes (Tensor): Ground truth boxes, shape [num_gt, dim].

        Returns:
            torch.Tensor: [num_query, num_gt] cost matrix scaled by weight.
        """
        pairwise_l1 = torch.cdist(bbox_pred, gt_bboxes, p=1)
        return self.weight * pairwise_l1
@MATCH_COST.register_module()
class BBoxBEVL1Cost(object):
    """L1 matching cost on BEV (x, y) box centers, normalized into [0, 1]
    using the point-cloud range."""

    def __init__(self, weight, pc_range):
        self.weight = weight
        self.pc_range = pc_range

    def __call__(self, bboxes, gt_bboxes):
        origin = bboxes.new(self.pc_range[0:2])
        extent = bboxes.new(self.pc_range[3:5]) - bboxes.new(self.pc_range[0:2])
        # normalize the box center to [0, 1]
        pred_xy = (bboxes[:, :2] - origin) / extent
        gt_xy = (gt_bboxes[:, :2] - origin) / extent
        cost = torch.cdist(pred_xy, gt_xy, p=1)
        return cost * self.weight
@MATCH_COST.register_module()
class IoU3DCost(object):
    """Negated-IoU cost: higher overlap yields lower matching cost."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, iou):
        # caller supplies precomputed IoU values
        return self.weight * (-iou)
================================================
FILE: models/bbox/utils.py
================================================
import torch
def normalize_bbox(bboxes):
    """Convert boxes (cx, cy, cz, w, l, h, rot[, vx, vy]) into the regression
    target layout (cx, cy, log w, log l, cz, log h, sin rot, cos rot[, vx, vy]).
    """
    cx = bboxes[..., 0:1]
    cy = bboxes[..., 1:2]
    cz = bboxes[..., 2:3]
    # sizes are regressed in log space
    w = bboxes[..., 3:4].log()
    l = bboxes[..., 4:5].log()
    h = bboxes[..., 5:6].log()
    rot = bboxes[..., 6:7]
    parts = [cx, cy, w, l, cz, h, rot.sin(), rot.cos()]
    if bboxes.size(-1) > 7:
        # velocity components pass through unchanged
        parts.extend([bboxes[..., 7:8], bboxes[..., 8:9]])
    return torch.cat(parts, dim=-1)
def denormalize_bbox(normalized_bboxes):
    """Inverse of normalize_bbox: recover (cx, cy, cz, w, l, h, rot[, vx, vy])
    from (cx, cy, log w, log l, cz, log h, sin rot, cos rot[, vx, vy])."""
    # yaw from its sine/cosine encoding
    yaw = torch.atan2(normalized_bboxes[..., 6:7], normalized_bboxes[..., 7:8])
    # center: cx, cy live at 0:2 and cz at 4:5
    center = torch.cat([normalized_bboxes[..., 0:2], normalized_bboxes[..., 4:5]], dim=-1)
    # sizes come back from log space: w, l at 2:4 and h at 5:6
    size = torch.cat([normalized_bboxes[..., 2:4], normalized_bboxes[..., 5:6]], dim=-1).exp()
    if normalized_bboxes.size(-1) > 8:
        vel = normalized_bboxes[..., 8:10]
        return torch.cat([center, size, yaw, vel], dim=-1)
    return torch.cat([center, size, yaw], dim=-1)
def encode_bbox(bboxes, pc_range=None):
    """Encode boxes (x, y, z, w, l, h, rot[, vx, vy]) as
    (xyz[, normalized], log-wlh, sin rot, cos rot[, vel])."""
    xyz = bboxes[..., 0:3].clone()
    wlh = bboxes[..., 3:6].log()
    rot = bboxes[..., 6:7]
    if pc_range is not None:
        # normalize each center coordinate into [0, 1] over the pc range
        for axis in range(3):
            low, high = pc_range[axis], pc_range[axis + 3]
            xyz[..., axis] = (xyz[..., axis] - low) / (high - low)
    parts = [xyz, wlh, rot.sin(), rot.cos()]
    if bboxes.shape[-1] > 7:
        parts.append(bboxes[..., 7:9].clone())
    return torch.cat(parts, dim=-1)
def decode_bbox(bboxes, pc_range=None):
    """Inverse of encode_bbox: recover (x, y, z, w, l, h, rot[, vx, vy])."""
    xyz = bboxes[..., 0:3].clone()
    wlh = bboxes[..., 3:6].exp()
    rot = torch.atan2(bboxes[..., 6:7], bboxes[..., 7:8])
    if pc_range is not None:
        # map normalized [0, 1] centers back to metric coordinates
        for axis in range(3):
            low, high = pc_range[axis], pc_range[axis + 3]
            xyz[..., axis] = xyz[..., axis] * (high - low) + low
    if bboxes.shape[-1] > 8:
        return torch.cat([xyz, wlh, rot, bboxes[..., 8:10].clone()], dim=-1)
    return torch.cat([xyz, wlh, rot], dim=-1)
def bbox2occrange(bboxes, occ_size, query_cube_size=None):
    """Convert normalized boxes to integer voxel index ranges.

    Args:
        bboxes (Tensor): [..., 6+] boxes with xyz in [0, 1] and wlh in [0, 1].
        occ_size (sequence): occupancy grid size per axis (x, y, z).
        query_cube_size (sequence, optional): if given, overrides wlh with a
            fixed cube size expressed directly in voxels.

    Returns:
        Tensor (long): [..., 6] (x_lo, y_lo, z_lo, x_hi, y_hi, z_hi).
    """
    xyz = bboxes[..., 0:3].clone()
    if query_cube_size is not None:
        wlh = torch.zeros_like(xyz)
        wlh[..., 0] = query_cube_size[0]
        wlh[..., 1] = query_cube_size[1]
        wlh[..., 2] = query_cube_size[2]
    else:
        # Bug fix: the original scaled a *view* of `bboxes` in place here,
        # silently mutating the caller's tensor; clone before scaling.
        wlh = bboxes[..., 3:6].clone()
        wlh[..., 0] = wlh[..., 0] * occ_size[0]
        wlh[..., 1] = wlh[..., 1] * occ_size[1]
        wlh[..., 2] = wlh[..., 2] * occ_size[2]
    # scale normalized centers into voxel coordinates
    xyz[..., 0] = xyz[..., 0] * occ_size[0]
    xyz[..., 1] = xyz[..., 1] * occ_size[1]
    xyz[..., 2] = xyz[..., 2] * occ_size[2]
    xyz = torch.round(xyz)
    low_bound = torch.round(xyz - wlh/2)
    high_bound = torch.round(xyz + wlh/2)
    return torch.cat((low_bound, high_bound), dim=-1).long()
def occrange2bbox(occ_range, occ_size, pc_range):
    """Convert integer voxel ranges back into boxes.

    Return: xyz in [0, 1], wlh in [0, pc_range_size)
    """
    low = occ_range[..., :3].to(torch.float32)
    high = occ_range[..., 3:].to(torch.float32)
    xyz = (low + high) / 2        # voxel-space center
    wlh = high - low              # extent in voxels
    for axis in range(3):
        # center normalized by grid size; extent scaled to metric units
        xyz[..., axis] = xyz[..., axis] / occ_size[axis]
        wlh[..., axis] = wlh[..., axis] * (pc_range[axis + 3] - pc_range[axis]) / occ_size[axis]
    return torch.cat((xyz, wlh), dim=-1)
================================================
FILE: models/checkpoint.py
================================================
# This page is completely copied from https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint
# If you are using torch 2.0 or higher, you can safely delete this page and import the related functions from official PyTorch
import torch
import warnings
import weakref
from typing import Any, Iterable, List, Tuple
__all__ = [
"checkpoint", "checkpoint_sequential", "CheckpointFunction",
"check_backward_validity", "detach_variable", "get_device_states",
"set_device_states",
]
def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    """Return a tuple mirroring `inputs` where every tensor is detached from
    the graph but keeps its requires_grad flag; non-tensors pass through."""
    if not isinstance(inputs, tuple):
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)
    detached = []
    for item in inputs:
        if isinstance(item, torch.Tensor):
            copy = item.detach()
            # preserve the flag so recomputation builds grads where needed
            copy.requires_grad = item.requires_grad
            detached.append(copy)
        else:
            detached.append(item)
    return tuple(detached)
def check_backward_validity(inputs: Iterable[Any]) -> None:
    """Warn when no tensor input requires grad: checkpointing would then
    produce only None gradients."""
    tensor_args = (inp for inp in inputs if isinstance(inp, torch.Tensor))
    if not any(t.requires_grad for t in tensor_args):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases. Compromise: Stash the RNG state for
# the device of all Tensor args.
#
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    """Snapshot the CUDA RNG state of every GPU holding one of *args*.

    CPU tensors and non-tensor args are skipped (the conditionals
    short-circuit, so this never errors on them).
    """
    fwd_gpu_devices = list({arg.get_device() for arg in args
                            if isinstance(arg, torch.Tensor) and arg.is_cuda})
    fwd_gpu_states = []
    for dev in fwd_gpu_devices:
        with torch.cuda.device(dev):
            fwd_gpu_states.append(torch.cuda.get_rng_state())
    return fwd_gpu_devices, fwd_gpu_states
def set_device_states(devices, states) -> None:
    """Restore the per-GPU CUDA RNG states captured by get_device_states."""
    for dev, rng_state in zip(devices, states):
        with torch.cuda.device(dev):
            torch.cuda.set_rng_state(rng_state)
def _get_autocast_kwargs():
    """Capture the currently-active GPU and CPU autocast settings.

    Returns:
        tuple[dict, dict]: kwargs for ``torch.cuda.amp.autocast`` and
        ``torch.cpu.amp.autocast`` so that backward recomputation can run
        under the same autocast regime as the original forward pass.
    """
    gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                           "dtype": torch.get_autocast_gpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}
    cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
                           "dtype": torch.get_autocast_cpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}
    return gpu_autocast_kwargs, cpu_autocast_kwargs
class CheckpointFunction(torch.autograd.Function):
    """Reentrant activation checkpointing (copied from PyTorch).

    Forward runs `run_function` under no_grad (storing no activations);
    backward replays it under the stashed RNG and autocast state, then
    backpropagates through the recomputed subgraph.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        """Run `run_function(*args)` without building a graph, stashing
        everything needed to recompute it in backward."""
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)
        ctx.save_for_backward(*tensor_inputs)
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Recompute the forward under the stashed RNG/autocast state and
        backpropagate the incoming grads through the recomputed graph."""
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors
        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]
        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(tuple(inputs))
            # replay forward with grad enabled, under the original autocast state
            with torch.enable_grad(), \
                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)
        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        # run backward() with only tensor that requires grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        # grads for (run_function, preserve_rng_state) are None
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)
        return (None, None) + grads
def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
    r"""Checkpoint a model or part of the model.

    Trades compute for memory: ``function`` runs under :func:`torch.no_grad`
    in the forward pass (intermediate activations are not stored) and is
    re-executed during the backward pass to recompute them before the
    gradients are calculated. Non-Tensor outputs are passed through but do
    not participate in autograd, and Tensors nested inside custom
    structures are not tracked either.

    .. warning::
        If :attr:`function` behaves differently on replay (e.g. due to a
        global variable), the checkpointed result will silently diverge.

    .. warning::
        With ``use_reentrant=True``: the segment must not contain tensors
        detached via ``detach()``/``torch.no_grad()``; at least one input
        and one output must have ``requires_grad=True``; and only
        :func:`torch.autograd.backward` (without its ``inputs`` argument)
        is supported — not :func:`torch.autograd.grad`. None of these
        restrictions apply with ``use_reentrant=False``.

    Args:
        function: what to run in the forward pass; receives ``*args``
            exactly as given (e.g. for an LSTM passed ``(activation,
            hidden)``, the first argument is ``activation``, the second
            ``hidden``).
        preserve_rng_state(bool, optional): stash and restore the RNG state
            during each checkpoint. Default: ``True``.
        use_reentrant(bool, optional): use the implementation that requires
            re-entrant autograd. ``use_reentrant=False`` selects a
            non-re-entrant implementation supporting extra functionality
            (``torch.autograd.grad``, keyword arguments for ``function``);
            future PyTorch versions will default to ``False``.
            Default: ``True``.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`

    Raises:
        ValueError: if extra keyword arguments (other than
            ``preserve_rng_state``) are given while ``use_reentrant=True``.
    """
    # `preserve_rng_state` rides in **kwargs (python 2.7-compliant keyword-only hack).
    preserve = kwargs.pop('preserve_rng_state', True)
    if use_reentrant:
        # The re-entrant path accepts no other keyword arguments.
        if kwargs:
            raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
        return CheckpointFunction.apply(function, preserve, *args)
    # Non-re-entrant path: remaining kwargs are forwarded to `function`.
    return _checkpoint_without_reentrant(
        function,
        preserve,
        *args,
        **kwargs,
    )
def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs):
    r"""Checkpoint a sequential model in ``segments`` chunks.

    Divides ``functions`` into ``segments`` consecutive chunks. Every chunk
    except the last is run through :func:`~torch.utils.checkpoint.checkpoint`
    (i.e. in :func:`torch.no_grad` manner, saving only each chunk's input for
    recomputation in backward); the last chunk runs normally so its
    activations are recorded.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed;
        :func:`torch.autograd.grad` is not supported.

    .. warning:
        At least one of the inputs needs :code:`requires_grad=True` if grads
        are needed for model inputs. Since PyTorch 1.4, only one Tensor is
        allowed as input and intermediate outputs, just like
        :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or a list of
            modules/functions to run sequentially.
        segments: Number of chunks to create in the model.
        input: A Tensor that is input to :attr:`functions`.
        preserve_rng_state(bool, optional): stash and restore the RNG state
            during each checkpoint. Default: ``True``.
        use_reentrant(bool, optional): use the re-entrant-autograd
            checkpoint implementation; ``False`` selects the variant that
            additionally supports ``torch.autograd.grad`` and keyword
            arguments. Default: ``True``.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`input`.

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Keyword-only parameter handled via **kwargs (python 2.7-compliant hack).
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        # Compose functions[start..end] (inclusive) into a single callable.
        def forward(x):
            for idx in range(start, end + 1):
                x = functions[idx](x)
            return x
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # The last chunk has to be non-volatile (executed with grad recording).
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            run_function(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve,
        )
    return run_function(end + 1, len(functions) - 1, functions)(input)
def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args, **kwargs):
    """Checkpointing without re-entrant autograd.

    Implemented with ``torch.autograd.graph.saved_tensors_hooks``: during the
    first forward, every tensor autograd would save is replaced by a cheap
    ``Holder`` placeholder; the first time backward asks for one of them, the
    whole ``function`` is re-run (with RNG/autocast state restored) and the
    recomputed tensors are stored into ``storage`` keyed by their holders.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): Omit stashing and restoring
            the RNG state during each checkpoint.
            Default: ``True``
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.

    Returns:
        The output of ``function(*args, **kwargs)``.
    """
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()
    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    # Custom class to be able to take weak references
    class Holder():
        pass

    # The Holder object for each of the saved object is saved directly on the
    # SavedVariable and is cleared when reset_data() is called on it. We MUST make
    # sure that this is the only object having an owning reference to ensure that
    # the Tensor stored in storage is deleted as soon as the corresponding SavedVariable
    # data is cleared.
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []

    def pack(x):
        # Called by autograd when it would save tensor `x`; returns a
        # placeholder instead so `x` itself is not kept alive.
        # TODO(varal7): Instead of returning abstract object, we can return things metadata (such as
        # size, device, ...) to catch certain cases of undeterministic behavior of the forward
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        # Called by autograd during backward; `x` is a Holder from pack().
        # On the first request, replay the whole forward to repopulate storage.
        unpack_counter = 0
        if len(storage) == 0:

            def inner_pack(inner):
                # Matches recomputed saved-tensors to the original holders
                # by position (autograd saves in the same order on replay).
                nonlocal unpack_counter
                unpack_counter += 1
                # If the holder went out of scope, the SavedVariable is dead and so
                # the value will never be read from the storage. Skip filling it.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return
                # Use detach here to ensure we don't keep the temporary autograd
                # graph created during the second forward
                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
                return

            def inner_unpack(packed):
                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")

            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward. Restore the surrounding state
            # when we're done.
            rng_devices = []
            if preserve_rng_state and had_cuda_in_fwd:
                rng_devices = fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
                if preserve_rng_state:
                    torch.set_rng_state(fwd_cpu_state)
                    if had_cuda_in_fwd:
                        set_device_states(fwd_gpu_devices, fwd_gpu_states)
                with torch.enable_grad(), \
                     torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
                     torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                    _unused = function(*args, **kwargs)

        if x not in storage:
            raise RuntimeError(
                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                " recomputation being triggered in between, this is not currently supported. Please"
                " open an issue with details on your use case so that we can prioritize adding this."
            )

        return storage[x]

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args, **kwargs)
        if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
            # Cuda was not initialized before running the forward, so we didn't
            # stash the CUDA state.
            raise RuntimeError(
                "PyTorch's CUDA state was initialized in the forward pass "
                "of a Checkpoint, which is not allowed. Please open an issue "
                "if you need this feature.")

    return output
================================================
FILE: models/csrc/__init__.py
================================================
================================================
FILE: models/csrc/msmv_sampling/msmv_sampling.cpp
================================================
#include "msmv_sampling.h"
#define MAX_POINT 32
// Forward launcher (defined in msmv_sampling_forward.cu): bilinear sampling
// of a 4-level pyramid (C2..C5), accumulated with per-level attention weights
// into data_col of shape [B, Q, C, P].
void ms_deformable_im2col_cuda_c2345(
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float* data_sampling_loc,   // [B, Q, P, 3] normalized (w, h, view)
    const float* data_attn_weight,    // [B, Q, P, 4] one weight per level
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col                   // output [B, Q, C, P]
);

// Forward launcher for the 5-level variant (C2..C6).
void ms_deformable_im2col_cuda_c23456(
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const float* feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float* data_sampling_loc,
    const float* data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col
);

// Backward launcher (defined in msmv_sampling_backward.cu): scatters grad_col
// back into per-level feature gradients plus sampling-location and
// attention-weight gradients.
void ms_deformable_col2im_cuda_c2345(
    const float* grad_col,
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float* data_sampling_loc,
    const float* data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* grad_value_c2,
    float* grad_value_c3,
    float* grad_value_c4,
    float* grad_value_c5,
    float* grad_sampling_loc,
    float* grad_attn_weight
);

// Backward launcher for the 5-level variant (C2..C6).
void ms_deformable_col2im_cuda_c23456(
    const float *grad_col,
    const float *feat_c2,
    const float *feat_c3,
    const float *feat_c4,
    const float *feat_c5,
    const float *feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float *data_sampling_loc,
    const float *data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float *grad_value_c2,
    float *grad_value_c3,
    float *grad_value_c4,
    float *grad_value_c5,
    float *grad_value_c6,
    float *grad_sampling_loc,
    float *grad_attn_weight
);
// Forward pass of multi-scale multi-view deformable sampling over a 4-level
// (C2..C5) feature pyramid. Returns [B, Q, C, P]: each query point is sampled
// bilinearly from every level and accumulated with its attention weight.
// Fix: restore the `<float>` template arguments on data_ptr() — the untyped
// overload returns void*, which cannot be passed to the float* launcher.
at::Tensor ms_deform_attn_cuda_c2345_forward(
    const at::Tensor& feat_c2,      // [B, N, H, W, C]
    const at::Tensor& feat_c3,      // [B, N, H, W, C]
    const at::Tensor& feat_c4,      // [B, N, H, W, C]
    const at::Tensor& feat_c5,      // [B, N, H, W, C]
    const at::Tensor& sampling_loc, // [B, Q, P, 3]
    const at::Tensor& attn_weight   // [B, Q, P, 4]
) {
    // The CUDA launcher indexes raw pointers, so inputs must be contiguous
    // and resident on the GPU.
    AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
    AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");

    const int batch_size = feat_c2.size(0);
    const int num_views = feat_c2.size(1);
    const int channels = feat_c2.size(4);
    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(2);
    // The kernels hold per-point state in fixed-size arrays of MAX_POINT.
    AT_ASSERTM(num_point <= MAX_POINT, "num_point exceed limits");

    const int h_c2 = feat_c2.size(2);
    const int w_c2 = feat_c2.size(3);
    const int h_c3 = feat_c3.size(2);
    const int w_c3 = feat_c3.size(3);
    const int h_c4 = feat_c4.size(2);
    const int w_c4 = feat_c4.size(3);
    const int h_c5 = feat_c5.size(2);
    const int w_c5 = feat_c5.size(3);

    auto output = at::zeros({ batch_size, num_query, channels, num_point }, feat_c2.options());
    ms_deformable_im2col_cuda_c2345(
        feat_c2.data_ptr<float>(),
        feat_c3.data_ptr<float>(),
        feat_c4.data_ptr<float>(),
        feat_c5.data_ptr<float>(),
        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,
        sampling_loc.data_ptr<float>(),
        attn_weight.data_ptr<float>(),
        batch_size, channels, num_views, num_query, num_point,
        output.data_ptr<float>()
    );
    return output;
}
// Forward pass of the 5-level (C2..C6) variant; see the c2345 overload.
// Fix: restore the `<float>` template arguments on data_ptr() — the untyped
// overload returns void*, which cannot be passed to the float* launcher.
at::Tensor ms_deform_attn_cuda_c23456_forward(
    const at::Tensor& feat_c2,      // [B, N, H, W, C]
    const at::Tensor& feat_c3,      // [B, N, H, W, C]
    const at::Tensor& feat_c4,      // [B, N, H, W, C]
    const at::Tensor& feat_c5,      // [B, N, H, W, C]
    const at::Tensor& feat_c6,      // [B, N, H, W, C]
    const at::Tensor& sampling_loc, // [B, Q, P, 3]
    const at::Tensor& attn_weight   // [B, Q, P, 4]
) {
    // Raw-pointer kernels require contiguous CUDA inputs.
    AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(feat_c6.is_contiguous(), "value tensor has to be contiguous");
    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
    AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(feat_c6.is_cuda(), "value must be a CUDA tensor");
    AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
    AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");

    const int batch_size = feat_c2.size(0);
    const int num_views = feat_c2.size(1);
    const int channels = feat_c2.size(4);
    const int num_query = sampling_loc.size(1);
    const int num_point = sampling_loc.size(2);
    // The kernels hold per-point state in fixed-size arrays of MAX_POINT.
    AT_ASSERTM(num_point <= MAX_POINT, "num_point exceed limits");

    const int h_c2 = feat_c2.size(2);
    const int w_c2 = feat_c2.size(3);
    const int h_c3 = feat_c3.size(2);
    const int w_c3 = feat_c3.size(3);
    const int h_c4 = feat_c4.size(2);
    const int w_c4 = feat_c4.size(3);
    const int h_c5 = feat_c5.size(2);
    const int w_c5 = feat_c5.size(3);
    const int h_c6 = feat_c6.size(2);
    const int w_c6 = feat_c6.size(3);

    auto output = at::zeros({ batch_size, num_query, channels, num_point }, feat_c2.options());
    ms_deformable_im2col_cuda_c23456(
        feat_c2.data_ptr<float>(),
        feat_c3.data_ptr<float>(),
        feat_c4.data_ptr<float>(),
        feat_c5.data_ptr<float>(),
        feat_c6.data_ptr<float>(),
        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
        sampling_loc.data_ptr<float>(),
        attn_weight.data_ptr<float>(),
        batch_size, channels, num_views, num_query, num_point,
        output.data_ptr<float>()
    );
    return output;
}
std::vector ms_deform_attn_cuda_c2345_backward(
const at::Tensor& grad_output,
const at::Tensor& feat_c2, // [B, N, H, W, C]
const at::Tensor& feat_c3, // [B, N, H, W, C]
const at::Tensor& feat_c4, // [B, N, H, W, C]
const at::Tensor& feat_c5, // [B, N, H, W, C]
const at::Tensor& sampling_loc, // [B, Q, P, 3]
const at::Tensor& attn_weight // [B, Q, P, 4]
) {
AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
const int batch_size = feat_c2.size(0);
const int num_views = feat_c2.size(1);
const int channels = feat_c2.size(4);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(2);
AT_ASSERTM(num_point <= MAX_POINT, "num_point exceed limits");
auto grad_value_c2 = at::zeros_like(feat_c2);
auto grad_value_c3 = at::zeros_like(feat_c3);
auto grad_value_c4 = at::zeros_like(feat_c4);
auto grad_value_c5 = at::zeros_like(feat_c5);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int h_c2 = feat_c2.size(2);
const int w_c2 = feat_c2.size(3);
const int h_c3 = feat_c3.size(2);
const int w_c3 = feat_c3.size(3);
const int h_c4 = feat_c4.size(2);
const int w_c4 = feat_c4.size(3);
const int h_c5 = feat_c5.size(2);
const int w_c5 = feat_c5.size(3);
ms_deformable_col2im_cuda_c2345(
grad_output.data_ptr(),
feat_c2.data_ptr(),
feat_c3.data_ptr(),
feat_c4.data_ptr(),
feat_c5.data_ptr(),
h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,
sampling_loc.data_ptr(),
attn_weight.data_ptr(),
batch_size, channels, num_views, num_query, num_point,
grad_value_c2.data_ptr(),
grad_value_c3.data_ptr(),
grad_value_c4.data_ptr(),
grad_value_c5.data_ptr(),
grad_sampling_loc.data_ptr(),
grad_attn_weight.data_ptr()
);
return {
grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight
};
}
std::vector ms_deform_attn_cuda_c23456_backward(
const at::Tensor& grad_output,
const at::Tensor& feat_c2, // [B, N, H, W, C]
const at::Tensor& feat_c3, // [B, N, H, W, C]
const at::Tensor& feat_c4, // [B, N, H, W, C]
const at::Tensor& feat_c5, // [B, N, H, W, C]
const at::Tensor& feat_c6, // [B, N, H, W, C]
const at::Tensor& sampling_loc, // [B, Q, P, 3]
const at::Tensor& attn_weight // [B, Q, P, 4]
) {
AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(feat_c6.is_contiguous(), "value tensor has to be contiguous");
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(feat_c6.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
const int batch_size = feat_c2.size(0);
const int num_views = feat_c2.size(1);
const int channels = feat_c2.size(4);
const int num_query = sampling_loc.size(1);
const int num_point = sampling_loc.size(2);
AT_ASSERTM(num_point <= MAX_POINT, "num_point exceed limits");
auto grad_value_c2 = at::zeros_like(feat_c2);
auto grad_value_c3 = at::zeros_like(feat_c3);
auto grad_value_c4 = at::zeros_like(feat_c4);
auto grad_value_c5 = at::zeros_like(feat_c5);
auto grad_value_c6 = at::zeros_like(feat_c6);
auto grad_sampling_loc = at::zeros_like(sampling_loc);
auto grad_attn_weight = at::zeros_like(attn_weight);
const int h_c2 = feat_c2.size(2);
const int w_c2 = feat_c2.size(3);
const int h_c3 = feat_c3.size(2);
const int w_c3 = feat_c3.size(3);
const int h_c4 = feat_c4.size(2);
const int w_c4 = feat_c4.size(3);
const int h_c5 = feat_c5.size(2);
const int w_c5 = feat_c5.size(3);
const int h_c6 = feat_c6.size(2);
const int w_c6 = feat_c6.size(3);
ms_deformable_col2im_cuda_c23456(
grad_output.data_ptr(),
feat_c2.data_ptr(),
feat_c3.data_ptr(),
feat_c4.data_ptr(),
feat_c5.data_ptr(),
feat_c6.data_ptr(),
h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
sampling_loc.data_ptr(),
attn_weight.data_ptr(),
batch_size, channels, num_views, num_query, num_point,
grad_value_c2.data_ptr(),
grad_value_c3.data_ptr(),
grad_value_c4.data_ptr(),
grad_value_c5.data_ptr(),
grad_value_c6.data_ptr(),
grad_sampling_loc.data_ptr(),
grad_attn_weight.data_ptr()
);
return {
grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight
};
}
// Python bindings. TORCH_EXTENSION_NAME is defined by torch's extension
// build system; the guard lets this file also be compiled without it.
// The leading underscore marks these as raw bindings wrapped on the
// Python side (models/csrc/wrapper.py).
#ifdef TORCH_EXTENSION_NAME
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("_ms_deform_attn_cuda_c2345_forward", &ms_deform_attn_cuda_c2345_forward, "pass");
    m.def("_ms_deform_attn_cuda_c2345_backward", &ms_deform_attn_cuda_c2345_backward, "pass");
    m.def("_ms_deform_attn_cuda_c23456_forward", &ms_deform_attn_cuda_c23456_forward, "pass");
    m.def("_ms_deform_attn_cuda_c23456_backward", &ms_deform_attn_cuda_c23456_backward, "pass");
}
#endif
================================================
FILE: models/csrc/msmv_sampling/msmv_sampling.h
================================================
#pragma once
#include
at::Tensor ms_deform_attn_cuda_c2345_forward(
const at::Tensor& feat_c2, // [B, N, H, W, C]
const at::Tensor& feat_c3, // [B, N, H, W, C]
const at::Tensor& feat_c4, // [B, N, H, W, C]
const at::Tensor& feat_c5, // [B, N, H, W, C]
const at::Tensor& sampling_loc, // [B, Q, P, 3]
const at::Tensor& attn_weight // [B, Q, P, 4]
);
std::vector ms_deform_attn_cuda_c2345_backward(
const at::Tensor& feat_c2, // [B, N, H, W, C]
const at::Tensor& feat_c3, // [B, N, H, W, C]
const at::Tensor& feat_c4, // [B, N, H, W, C]
const at::Tensor& feat_c5, // [B, N, H, W, C]
const at::Tensor& sampling_loc, // [B, Q, P, 3]
const at::Tensor& attn_weight, // [B, Q, P, 4]
const at::Tensor& grad_output
);
at::Tensor ms_deform_attn_cuda_c23456_forward(
const at::Tensor& feat_c2, // [B, N, H, W, C]
const at::Tensor& feat_c3, // [B, N, H, W, C]
const at::Tensor& feat_c4, // [B, N, H, W, C]
const at::Tensor& feat_c5, // [B, N, H, W, C]
const at::Tensor& feat_c6, // [B, N, H, W, C]
const at::Tensor& sampling_loc, // [B, Q, P, 3]
const at::Tensor& attn_weight // [B, Q, P, 4]
);
std::vector ms_deform_attn_cuda_c23456_backward(
const at::Tensor& grad_output,
const at::Tensor& feat_c2, // [B, N, H, W, C]
const at::Tensor& feat_c3, // [B, N, H, W, C]
const at::Tensor& feat_c4, // [B, N, H, W, C]
const at::Tensor& feat_c5, // [B, N, H, W, C]
const at::Tensor& feat_c6, // [B, N, H, W, C]
const at::Tensor& sampling_loc, // [B, Q, P, 3]
const at::Tensor& attn_weight // [B, Q, P, 4]
);
================================================
FILE: models/csrc/msmv_sampling/msmv_sampling_backward.cu
================================================
/*!
* Modified from Deformable DETR
*/
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
#define CUDA_NUM_THREADS 512
#define MAX_POINT 32
// Ceiling division: number of thread blocks of `num_threads` needed to
// cover N work items.
inline int GET_BLOCKS(const int N, const int num_threads)
{
    const int blocks = (N + num_threads - 1) / num_threads;
    return blocks;
}
// Backward of one bilinear sample: scatters `top_grad` into the value
// gradient (4 corner taps, via atomicAdd), and accumulates gradients for the
// attention weight and the (w, h) sampling location. `grad_value` is passed
// as const float*& like the forward data pointer, so atomicAdd targets are
// obtained by const_cast.
// Fix: restore the `<float *>` target type stripped from the const_cast
// expressions (a bare `const_cast(...)` does not compile).
__device__ void ms_deform_attn_col2im_bilinear(const float *&bottom_data,
                                               const int &height, const int &width, const int &channels,
                                               const float &h, const float &w, const int &c,
                                               const float &top_grad,
                                               const float &attn_weight,
                                               const float *&grad_value,
                                               float *&grad_sampling_loc,
                                               float *&grad_attn_weight)
{
    // Integer corners around the continuous sample point and the
    // fractional interpolation weights.
    const int h_low = floor(h);
    const int w_low = floor(w);
    const int h_high = h_low + 1;
    const int w_high = w_low + 1;
    const float lh = h - h_low;
    const float lw = w - w_low;
    const float hh = 1 - lh, hw = 1 - lw;

    // Memory layout is [H, W, C]: channel is the innermost stride.
    const int w_stride = channels;
    const int h_stride = width * w_stride;
    const int h_low_ptr_offset = h_low * h_stride;
    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
    const int w_low_ptr_offset = w_low * w_stride;
    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;

    const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
    const float top_grad_value = top_grad * attn_weight;
    float grad_h_weight = 0, grad_w_weight = 0;
    float *grad_ptr;

    // Each in-bounds corner contributes to the value gradient and to the
    // derivative of the interpolation weights w.r.t. (h, w).
    float v1 = 0;
    if (h_low >= 0 && w_low >= 0)
    {
        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + c;
        grad_ptr = const_cast<float *>(grad_value + ptr1);
        v1 = bottom_data[ptr1];
        grad_h_weight -= hw * v1;
        grad_w_weight -= hh * v1;
        atomicAdd(grad_ptr, w1 * top_grad_value);
    }
    float v2 = 0;
    if (h_low >= 0 && w_high <= width - 1)
    {
        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + c;
        grad_ptr = const_cast<float *>(grad_value + ptr2);
        v2 = bottom_data[ptr2];
        grad_h_weight -= lw * v2;
        grad_w_weight += hh * v2;
        atomicAdd(grad_ptr, w2 * top_grad_value);
    }
    float v3 = 0;
    if (h_high <= height - 1 && w_low >= 0)
    {
        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + c;
        grad_ptr = const_cast<float *>(grad_value + ptr3);
        v3 = bottom_data[ptr3];
        grad_h_weight += hw * v3;
        grad_w_weight -= lh * v3;
        atomicAdd(grad_ptr, w3 * top_grad_value);
    }
    float v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1)
    {
        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + c;
        grad_ptr = const_cast<float *>(grad_value + ptr4);
        v4 = bottom_data[ptr4];
        grad_h_weight += lw * v4;
        grad_w_weight += lh * v4;
        atomicAdd(grad_ptr, w4 * top_grad_value);
    }

    // d(output)/d(attn_weight) is the interpolated value itself.
    const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    atomicAdd(grad_attn_weight, top_grad * val);
    // Chain rule through the align_corners=True unnormalization:
    // pixel coord = loc * (size - 1).
    atomicAdd(grad_sampling_loc, (width - 1) * grad_w_weight * top_grad_value);
    atomicAdd(grad_sampling_loc + 1, (height - 1) * grad_h_weight * top_grad_value);
}
// global_memory_way
// Backward ("col2im") kernel for multi-scale multi-view deformable sampling
// over a 4-level feature pyramid (C2..C5). One thread handles one
// (batch, query, channel, point) tuple and scatters gradients — via the
// atomicAdds inside ms_deform_attn_col2im_bilinear — into:
//   * each level's feature map gradient (grad_value_c2..c5),
//   * the sampling-location gradient    (grad_sampling_loc, layout [B*Q, P, 3]),
//   * the per-level attention weights   (grad_attn_weight,  layout [B*Q, P, 4]).
// Feature maps are laid out [B, V, H, W, C]; the view index gets no gradient
// (it is snapped to the nearest view with round()).
__global__ void ms_deformable_col2im_gpu_kernel_gm_c2345(
    const float *grad_col,           // upstream gradient, [B, Q, C, P]
    const float *feat_c2,
    const float *feat_c3,
    const float *feat_c4,
    const float *feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float *data_sampling_loc,  // [B*Q, P, 3] -> (w, h, view), all in [0, 1]
    const float *data_attn_weight,   // [B*Q, P, 4] per-level blending weights
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float *grad_value_c2,
    float *grad_value_c3,
    float *grad_value_c4,
    float *grad_value_c5,
    float *grad_sampling_loc,
    float *grad_attn_weight)
{
    CUDA_KERNEL_LOOP(index, batch_size * num_query * channels * num_point)
    { // one iteration per (b, q, c, p); index = ((b*Q + q)*C + c)*P + p
        // Decode the flattened index.
        int _temp = index;
        const int p_col = _temp % num_point;
        _temp /= num_point;
        const int c_col = _temp % channels;
        _temp /= channels;
        const int sampling_index = _temp;  // b*Q + q: row into loc/weight arrays
        _temp /= num_query;
        const int b_col = _temp;
        const float top_grad = grad_col[index];
        // Sampling location in range [0, 1]
        int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
        const float loc_w = data_sampling_loc[data_loc_ptr];
        const float loc_h = data_sampling_loc[data_loc_ptr + 1];
        // Nearest camera view from the normalized third coordinate (no gradient).
        const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));
        // Attn weights
        int data_weight_ptr = sampling_index * num_point * 4 + p_col * 4;
        const float weight_c2 = data_attn_weight[data_weight_ptr];
        const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
        const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
        const float weight_c5 = data_attn_weight[data_weight_ptr + 3];
        // const float h_im = loc_h * spatial_h - 0.5; // align_corners = False
        // const float w_im = loc_w * spatial_w - 0.5;
        // C2 Feature
        float h_im = loc_h * (h_c2 - 1); // align_corners = True
        float w_im = loc_w * (w_c2 - 1);
        float *grad_location_ptr = grad_sampling_loc + data_loc_ptr;
        float *grad_weights_ptr = grad_attn_weight + data_weight_ptr;
        if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2)
        {
            // Offset to this (batch, view) slice before scattering gradients.
            const float *feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
            const float *grad_c2_ptr = grad_value_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
            ms_deform_attn_col2im_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col,
                                           top_grad, weight_c2,
                                           grad_c2_ptr, grad_location_ptr, grad_weights_ptr);
        }
        grad_weights_ptr += 1;  // advance to this point's C3 weight slot
        // C3 Feature
        h_im = loc_h * (h_c3 - 1); // align_corners = True
        w_im = loc_w * (w_c3 - 1);
        if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3)
        {
            const float *feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
            const float *grad_c3_ptr = grad_value_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
            ms_deform_attn_col2im_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col,
                                           top_grad, weight_c3,
                                           grad_c3_ptr, grad_location_ptr, grad_weights_ptr);
        }
        grad_weights_ptr += 1;
        // C4 Feature
        h_im = loc_h * (h_c4 - 1); // align_corners = True
        w_im = loc_w * (w_c4 - 1);
        if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4)
        {
            const float *feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
            const float *grad_c4_ptr = grad_value_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
            ms_deform_attn_col2im_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col,
                                           top_grad, weight_c4,
                                           grad_c4_ptr, grad_location_ptr, grad_weights_ptr);
        }
        grad_weights_ptr += 1;
        // C5 Feature
        h_im = loc_h * (h_c5 - 1); // align_corners = True
        w_im = loc_w * (w_c5 - 1);
        if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5)
        {
            const float *feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
            const float *grad_c5_ptr = grad_value_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
            ms_deform_attn_col2im_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col,
                                           top_grad, weight_c5,
                                           grad_c5_ptr, grad_location_ptr, grad_weights_ptr);
        }
    }
}
// Backward ("col2im") kernel for multi-scale multi-view deformable sampling
// over a 5-level feature pyramid (C2..C6). Identical structure to the
// 4-level kernel above, with a fifth level and a 5-wide attention-weight
// row ([B*Q, P, 5]); gradients are scattered with the atomicAdds inside
// ms_deform_attn_col2im_bilinear.
__global__ void ms_deformable_col2im_gpu_kernel_gm_c23456(
    const float *grad_col,           // upstream gradient, [B, Q, C, P]
    const float *feat_c2,
    const float *feat_c3,
    const float *feat_c4,
    const float *feat_c5,
    const float *feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float *data_sampling_loc,  // [B*Q, P, 3] -> (w, h, view), all in [0, 1]
    const float *data_attn_weight,   // [B*Q, P, 5] per-level blending weights
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float *grad_value_c2,
    float *grad_value_c3,
    float *grad_value_c4,
    float *grad_value_c5,
    float *grad_value_c6,
    float *grad_sampling_loc,
    float *grad_attn_weight)
{
    CUDA_KERNEL_LOOP(index, batch_size * num_query * channels * num_point)
    { // one iteration per (b, q, c, p); index = ((b*Q + q)*C + c)*P + p
        // Decode the flattened index.
        int _temp = index;
        const int p_col = _temp % num_point;
        _temp /= num_point;
        const int c_col = _temp % channels;
        _temp /= channels;
        const int sampling_index = _temp;  // b*Q + q
        _temp /= num_query;
        const int b_col = _temp;
        const float top_grad = grad_col[index];
        // Sampling location in range [0, 1]
        int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
        const float loc_w = data_sampling_loc[data_loc_ptr];
        const float loc_h = data_sampling_loc[data_loc_ptr + 1];
        // Nearest camera view (no gradient flows to the view coordinate).
        const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));
        // Attn weights
        int data_weight_ptr = sampling_index * num_point * 5 + p_col * 5;
        const float weight_c2 = data_attn_weight[data_weight_ptr];
        const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
        const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
        const float weight_c5 = data_attn_weight[data_weight_ptr + 3];
        const float weight_c6 = data_attn_weight[data_weight_ptr + 4];
        // const float h_im = loc_h * spatial_h - 0.5; // align_corners = False
        // const float w_im = loc_w * spatial_w - 0.5;
        // C2 Feature
        float h_im = loc_h * (h_c2 - 1); // align_corners = True
        float w_im = loc_w * (w_c2 - 1);
        float *grad_location_ptr = grad_sampling_loc + data_loc_ptr;
        float *grad_weights_ptr = grad_attn_weight + data_weight_ptr;
        if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2)
        {
            // Offset to this (batch, view) slice before scattering gradients.
            const float *feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
            const float *grad_c2_ptr = grad_value_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
            ms_deform_attn_col2im_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col,
                                           top_grad, weight_c2,
                                           grad_c2_ptr, grad_location_ptr, grad_weights_ptr);
        }
        grad_weights_ptr += 1;  // advance to this point's next-level weight slot
        // C3 Feature
        h_im = loc_h * (h_c3 - 1); // align_corners = True
        w_im = loc_w * (w_c3 - 1);
        if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3)
        {
            const float *feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
            const float *grad_c3_ptr = grad_value_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
            ms_deform_attn_col2im_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col,
                                           top_grad, weight_c3,
                                           grad_c3_ptr, grad_location_ptr, grad_weights_ptr);
        }
        grad_weights_ptr += 1;
        // C4 Feature
        h_im = loc_h * (h_c4 - 1); // align_corners = True
        w_im = loc_w * (w_c4 - 1);
        if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4)
        {
            const float *feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
            const float *grad_c4_ptr = grad_value_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
            ms_deform_attn_col2im_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col,
                                           top_grad, weight_c4,
                                           grad_c4_ptr, grad_location_ptr, grad_weights_ptr);
        }
        grad_weights_ptr += 1;
        // C5 Feature
        h_im = loc_h * (h_c5 - 1); // align_corners = True
        w_im = loc_w * (w_c5 - 1);
        if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5)
        {
            const float *feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
            const float *grad_c5_ptr = grad_value_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
            ms_deform_attn_col2im_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col,
                                           top_grad, weight_c5,
                                           grad_c5_ptr, grad_location_ptr, grad_weights_ptr);
        }
        grad_weights_ptr += 1;
        // C6 Feature
        h_im = loc_h * (h_c6 - 1); // align_corners = True
        w_im = loc_w * (w_c6 - 1);
        if (h_im > -1 && w_im > -1 && h_im < h_c6 && w_im < w_c6)
        {
            const float *feat_c6_ptr = feat_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;
            const float *grad_c6_ptr = grad_value_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;
            ms_deform_attn_col2im_bilinear(feat_c6_ptr, h_c6, w_c6, channels, h_im, w_im, c_col,
                                           top_grad, weight_c6,
                                           grad_c6_ptr, grad_location_ptr, grad_weights_ptr);
        }
    }
}
// Host-side launcher for the 4-level (C2..C5) backward kernel.
// Launches one thread per (batch, query, channel, point) tuple; the block
// size is channels*num_point capped at CUDA_NUM_THREADS.
void ms_deformable_col2im_cuda_c2345(
    const float *grad_col,
    const float *feat_c2,
    const float *feat_c3,
    const float *feat_c4,
    const float *feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float *data_sampling_loc,
    const float *data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float *grad_value_c2,
    float *grad_value_c3,
    float *grad_value_c4,
    float *grad_value_c5,
    float *grad_sampling_loc,
    float *grad_attn_weight)
{
    const int num_kernels = batch_size * num_query * channels * num_point;
    const int num_threads = (channels * num_point > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels * num_point;
    // fix: the execution configuration had been stripped to "<<>>", which does
    // not compile; restored the standard 1-D <<<grid, block>>> configuration.
    ms_deformable_col2im_gpu_kernel_gm_c2345
        <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>>(
        grad_col, feat_c2, feat_c3, feat_c4, feat_c5,
        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,
        data_sampling_loc, data_attn_weight,
        batch_size, channels, num_views, num_query, num_point,
        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5,
        grad_sampling_loc, grad_attn_weight);
    // Surface asynchronous launch errors immediately.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("error in ms_deformable_col2im_cuda_c2345: %s\n", cudaGetErrorString(err));
    }
}
// Host-side launcher for the 5-level (C2..C6) backward kernel.
// Launches one thread per (batch, query, channel, point) tuple; the block
// size is channels*num_point capped at CUDA_NUM_THREADS.
void ms_deformable_col2im_cuda_c23456(
    const float *grad_col,
    const float *feat_c2,
    const float *feat_c3,
    const float *feat_c4,
    const float *feat_c5,
    const float *feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float *data_sampling_loc,
    const float *data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float *grad_value_c2,
    float *grad_value_c3,
    float *grad_value_c4,
    float *grad_value_c5,
    float *grad_value_c6,
    float *grad_sampling_loc,
    float *grad_attn_weight)
{
    const int num_kernels = batch_size * num_query * channels * num_point;
    const int num_threads = (channels * num_point > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels * num_point;
    // fix: the execution configuration had been stripped to "<<>>", which does
    // not compile; restored the standard 1-D <<<grid, block>>> configuration.
    ms_deformable_col2im_gpu_kernel_gm_c23456
        <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>>(
        grad_col, feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,
        h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
        data_sampling_loc, data_attn_weight,
        batch_size, channels, num_views, num_query, num_point,
        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6,
        grad_sampling_loc, grad_attn_weight);
    // Surface asynchronous launch errors immediately.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
    {
        printf("error in ms_deformable_col2im_cuda_c23456: %s\n", cudaGetErrorString(err));
    }
}
================================================
FILE: models/csrc/msmv_sampling/msmv_sampling_forward.cu
================================================
/*!
* Modified from Deformable DETR
*/
#include
#include
#include
#include
#include
#include
#include
#include
#include
// Grid-stride loop: each thread starts at its global id and strides by the
// total number of launched threads until all `n` work items are covered.
#define CUDA_KERNEL_LOOP(i, n) \
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
         i < (n); \
         i += blockDim.x * gridDim.x)
// Default number of threads per block for the kernels in this file.
#define CUDA_NUM_THREADS 512
// Upper bound on sampling points per query (sizes the per-thread accumulator).
#define MAX_POINT 32
// Number of thread blocks needed to cover N work items with num_threads
// threads per block (ceiling division).
inline int GET_BLOCKS(const int N, const int num_threads) {
    const int full_blocks = N / num_threads;
    return (N % num_threads == 0) ? full_blocks : full_blocks + 1;
}
// Bilinearly interpolate channel `c` of one [H, W, C]-laid-out feature slice
// at the fractional position (h, w). Corners that fall outside the map
// contribute 0, i.e. zero padding.
__device__ float ms_deform_attn_im2col_bilinear(
    const float*& bottom_data,
    const int& height, const int& width, const int& channels,
    const float& h, const float& w, const int& c) {
    // Integer corners surrounding (h, w) and the fractional offsets.
    const int h_low = floor(h);
    const int w_low = floor(w);
    const int h_high = h_low + 1;
    const int w_high = w_low + 1;
    const float lh = h - h_low;
    const float lw = w - w_low;
    const float hh = 1 - lh, hw = 1 - lw;
    // Strides for the [H, W, C] layout.
    const int w_stride = channels;
    const int h_stride = width * w_stride;
    const int h_low_ptr_offset = h_low * h_stride;
    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
    const int w_low_ptr_offset = w_low * w_stride;
    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
    // Gather the four corners, skipping (zero-padding) out-of-range ones.
    float v1 = 0;
    if (h_low >= 0 && w_low >= 0) {
        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + c;
        v1 = bottom_data[ptr1];
    }
    float v2 = 0;
    if (h_low >= 0 && w_high <= width - 1) {
        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + c;
        v2 = bottom_data[ptr2];
    }
    float v3 = 0;
    if (h_high <= height - 1 && w_low >= 0) {
        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + c;
        v3 = bottom_data[ptr3];
    }
    float v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1) {
        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + c;
        v4 = bottom_data[ptr4];
    }
    // Standard bilinear blend of the four corner values.
    const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
    const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    return val;
}
// Forward ("im2col") kernel for multi-scale multi-view deformable sampling
// over a 4-level pyramid (C2..C5). Each thread handles one (batch, query,
// channel) tuple: for every sampling point it bilinearly samples all four
// levels in the selected camera view (align_corners=True, zero padding) and
// blends them with the per-level attention weights.
// Output layout: data_col[((b*Q + q)*C + c)*P + p].
__global__ void ms_deformable_im2col_gpu_kernel_c2345(
    const float* feat_c2,             // [B, V, H2, W2, C]
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float* data_sampling_loc,   // [B*Q, P, 3] -> (w, h, view), all in [0, 1]
    const float* data_attn_weight,    // [B*Q, P, 4] per-level blending weights
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col) {
    float res[MAX_POINT];  // per-point accumulators; assumes num_point <= MAX_POINT
    CUDA_KERNEL_LOOP(index, batch_size * num_query * channels) { // n: bs x query x channels
        // Decode index = (b*Q + q)*C + c.
        int _temp = index;
        const int c_col = _temp % channels;
        _temp /= channels;
        const int sampling_index = _temp;  // b*Q + q
        _temp /= num_query;
        const int b_col = _temp;
        for (int p_col = 0; p_col < num_point; ++p_col) { res[p_col] = 0; }
        for (int p_col = 0; p_col < num_point; ++p_col) {
            // Sampling location in range [0, 1]
            int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
            const float loc_w = data_sampling_loc[data_loc_ptr];
            const float loc_h = data_sampling_loc[data_loc_ptr + 1];
            // Nearest camera view from the normalized third coordinate.
            const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));
            // Attn weights
            int data_weight_ptr = sampling_index * num_point * 4 + p_col * 4;
            const float weight_c2 = data_attn_weight[data_weight_ptr];
            const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
            const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
            const float weight_c5 = data_attn_weight[data_weight_ptr + 3];
            //const float h_im = loc_h * spatial_h - 0.5; // align_corners = False
            //const float w_im = loc_w * spatial_w - 0.5;
            // C2 Feature
            float h_im = loc_h * (h_c2 - 1); // align_corners = True
            float w_im = loc_w * (w_c2 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2) {
                // Offset to this (batch, view) slice before sampling.
                const float* feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col) * weight_c2;
            }
            // C3 Feature
            h_im = loc_h * (h_c3 - 1); // align_corners = True
            w_im = loc_w * (w_c3 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3) {
                const float* feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col) * weight_c3;
            }
            // C4 Feature
            h_im = loc_h * (h_c4 - 1); // align_corners = True
            w_im = loc_w * (w_c4 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4) {
                const float* feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col) * weight_c4;
            }
            // C5 Feature
            h_im = loc_h * (h_c5 - 1); // align_corners = True
            w_im = loc_w * (w_c5 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5) {
                const float* feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col) * weight_c5;
            }
        }
        // Write the P accumulated samples for this (b, q, c).
        for (int p_col = 0; p_col < num_point; ++p_col) {
            float* data_col_ptr = data_col + index * num_point + p_col;
            *data_col_ptr = res[p_col];
        }
    }
}
// Forward ("im2col") kernel for multi-scale multi-view deformable sampling
// over a 5-level pyramid (C2..C6). Same structure as the 4-level kernel,
// with a fifth level and 5-wide attention-weight rows ([B*Q, P, 5]).
__global__ void ms_deformable_im2col_gpu_kernel_c23456(
    const float* feat_c2,             // [B, V, H2, W2, C]
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const float* feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float* data_sampling_loc,   // [B*Q, P, 3] -> (w, h, view), all in [0, 1]
    const float* data_attn_weight,    // [B*Q, P, 5] per-level blending weights
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col) {
    float res[MAX_POINT];  // per-point accumulators; assumes num_point <= MAX_POINT
    CUDA_KERNEL_LOOP(index, batch_size * num_query * channels) { // n: bs x query x channels
        // Decode index = (b*Q + q)*C + c.
        int _temp = index;
        const int c_col = _temp % channels;
        _temp /= channels;
        const int sampling_index = _temp;  // b*Q + q
        _temp /= num_query;
        const int b_col = _temp;
        for (int p_col = 0; p_col < num_point; ++p_col) { res[p_col] = 0; }
        for (int p_col = 0; p_col < num_point; ++p_col) {
            // Sampling location in range [0, 1]
            int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
            const float loc_w = data_sampling_loc[data_loc_ptr];
            const float loc_h = data_sampling_loc[data_loc_ptr + 1];
            // Nearest camera view from the normalized third coordinate.
            const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));
            // Attn weights
            int data_weight_ptr = sampling_index * num_point * 5 + p_col * 5;
            const float weight_c2 = data_attn_weight[data_weight_ptr];
            const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
            const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
            const float weight_c5 = data_attn_weight[data_weight_ptr + 3];
            const float weight_c6 = data_attn_weight[data_weight_ptr + 4];
            //const float h_im = loc_h * spatial_h - 0.5; // align_corners = False
            //const float w_im = loc_w * spatial_w - 0.5;
            // C2 Feature
            float h_im = loc_h * (h_c2 - 1); // align_corners = True
            float w_im = loc_w * (w_c2 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2) {
                // Offset to this (batch, view) slice before sampling.
                const float* feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col) * weight_c2;
            }
            // C3 Feature
            h_im = loc_h * (h_c3 - 1); // align_corners = True
            w_im = loc_w * (w_c3 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3) {
                const float* feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col) * weight_c3;
            }
            // C4 Feature
            h_im = loc_h * (h_c4 - 1); // align_corners = True
            w_im = loc_w * (w_c4 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4) {
                const float* feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col) * weight_c4;
            }
            // C5 Feature
            h_im = loc_h * (h_c5 - 1); // align_corners = True
            w_im = loc_w * (w_c5 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5) {
                const float* feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col) * weight_c5;
            }
            // C6 Feature
            h_im = loc_h * (h_c6 - 1); // align_corners = True
            w_im = loc_w * (w_c6 - 1);
            if (h_im > -1 && w_im > -1 && h_im < h_c6 && w_im < w_c6) {
                const float* feat_c6_ptr = feat_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;
                res[p_col] += ms_deform_attn_im2col_bilinear(feat_c6_ptr, h_c6, w_c6, channels, h_im, w_im, c_col) * weight_c6;
            }
        }
        // Write the P accumulated samples for this (b, q, c).
        for (int p_col = 0; p_col < num_point; ++p_col) {
            float* data_col_ptr = data_col + index * num_point + p_col;
            *data_col_ptr = res[p_col];
        }
    }
}
// Host-side launcher for the 4-level (C2..C5) forward sampling kernel.
// One thread per (batch, query, channel) tuple; each thread iterates over
// its num_point sampling points internally.
void ms_deformable_im2col_cuda_c2345(
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float* data_sampling_loc,
    const float* data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col) {
    const int num_kernels = batch_size * num_query * channels;
    const int num_threads = CUDA_NUM_THREADS;
    // fix: the execution configuration had been stripped to "<<>>", which does
    // not compile; restored the standard 1-D <<<grid, block>>> configuration.
    ms_deformable_im2col_gpu_kernel_c2345
        <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>> (
        feat_c2, feat_c3, feat_c4, feat_c5, h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,
        data_sampling_loc, data_attn_weight, batch_size, channels, num_views, num_query, num_point, data_col
    );
    // Surface asynchronous launch errors immediately.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("error in ms_deformable_im2col_cuda_c2345: %s\n", cudaGetErrorString(err));
    }
}
// Host-side launcher for the 5-level (C2..C6) forward sampling kernel.
// One thread per (batch, query, channel) tuple; each thread iterates over
// its num_point sampling points internally.
void ms_deformable_im2col_cuda_c23456(
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const float* feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float* data_sampling_loc,
    const float* data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col) {
    const int num_kernels = batch_size * num_query * channels;
    const int num_threads = CUDA_NUM_THREADS;
    // fix: the execution configuration had been stripped to "<<>>", which does
    // not compile; restored the standard 1-D <<<grid, block>>> configuration.
    ms_deformable_im2col_gpu_kernel_c23456
        <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>> (
        feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
        data_sampling_loc, data_attn_weight, batch_size, channels, num_views, num_query, num_point, data_col
    );
    // Surface asynchronous launch errors immediately.
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("error in ms_deformable_im2col_cuda_c23456: %s\n", cudaGetErrorString(err));
    }
}
================================================
FILE: models/csrc/setup.py
================================================
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
def get_ext_modules():
    """Return the CUDA extension modules built by this setup script.

    The single extension compiles the msmv_sampling binding (.cpp) together
    with its forward/backward CUDA kernels into `_msmv_sampling_cuda`.
    """
    kernel_sources = [
        'msmv_sampling/msmv_sampling.cpp',
        'msmv_sampling/msmv_sampling_forward.cu',
        'msmv_sampling/msmv_sampling_backward.cu',
    ]
    extension = CUDAExtension(
        name='_msmv_sampling_cuda',
        sources=kernel_sources,
        include_dirs=['msmv_sampling'],
    )
    return [extension]
# Build/installation entry point: compiles the msmv_sampling CUDA extension.
# BuildExtension supplies the nvcc-aware build step for mixed .cpp/.cu sources.
setup(
    name='csrc',
    ext_modules=get_ext_modules(),
    cmdclass={'build_ext': BuildExtension}
)
================================================
FILE: models/csrc/wrapper.py
================================================
import torch
import torch.nn.functional as F
from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c2345_forward, _ms_deform_attn_cuda_c2345_backward
from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c23456_forward, _ms_deform_attn_cuda_c23456_backward
def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
    """Pure-PyTorch fallback for multi-scale multi-view sampling.

    Args:
        mlvl_feats: list of per-level feature maps, each [B, N, H, W, C]
            (N = number of camera views).
        sampling_locations: [B, Q, P, 3] normalized (w, h, view) coordinates
            in [0, 1].
        scale_weights: [B, Q, P, L] per-level blending weights,
            L == len(mlvl_feats).

    Returns:
        Blended samples of shape [B, Q, C, P].
    """
    assert scale_weights.shape[-1] == len(mlvl_feats)
    batch, _, _, _, channels = mlvl_feats[0].shape
    _, num_query, num_point, _ = sampling_locations.shape

    # grid_sample wants coordinates in [-1, 1]; the singleton axis makes the
    # grid 5-D so that trilinear sampling over (view, h, w) applies.
    grid = (sampling_locations * 2 - 1).unsqueeze(3)  # [B, Q, P, 1, 3]

    accum = torch.zeros([batch, channels, num_query, num_point],
                        device=mlvl_feats[0].device)
    for lvl, feat in enumerate(mlvl_feats):
        sampled = F.grid_sample(
            feat.permute(0, 4, 1, 2, 3),  # [B, C, N, H, W]
            grid, mode='bilinear',
            padding_mode='zeros', align_corners=True,
        )[..., 0]  # [B, C, Q, P]
        lvl_weight = scale_weights[..., lvl].reshape(batch, 1, num_query, num_point)
        accum = accum + sampled * lvl_weight
    return accum.permute(0, 2, 1, 3)
class MSMVSamplingC2345(torch.autograd.Function):
    """Autograd wrapper around the custom CUDA multi-scale multi-view
    sampling op for a 4-level (C2..C5) feature pyramid."""

    @staticmethod
    def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights):
        """Run the CUDA forward kernel; all inputs are kept for backward."""
        ctx.save_for_backward(feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights)
        assert callable(_ms_deform_attn_cuda_c2345_forward)
        return _ms_deform_attn_cuda_c2345_forward(
            feat_c2, feat_c3, feat_c4, feat_c5,
            sampling_locations, scale_weights)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients in the same order as forward's inputs
        (ctx excluded): the four feature levels, then the sampling
        locations, then the scale weights."""
        feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights = ctx.saved_tensors
        assert callable(_ms_deform_attn_cuda_c2345_backward)
        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight = _ms_deform_attn_cuda_c2345_backward(grad_output.contiguous(),
            feat_c2, feat_c3, feat_c4, feat_c5,
            sampling_locations, scale_weights
        )
        return grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight
class MSMVSamplingC23456(torch.autograd.Function):
    """Autograd wrapper around the custom CUDA multi-scale multi-view
    sampling op for a 5-level (C2..C6) feature pyramid."""

    @staticmethod
    def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights):
        """Run the CUDA forward kernel; all inputs are kept for backward."""
        ctx.save_for_backward(feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights)
        assert callable(_ms_deform_attn_cuda_c23456_forward)
        return _ms_deform_attn_cuda_c23456_forward(
            feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,
            sampling_locations, scale_weights)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients in the same order as forward's inputs
        (ctx excluded): the five feature levels, then the sampling
        locations, then the scale weights."""
        feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights = ctx.saved_tensors
        assert callable(_ms_deform_attn_cuda_c23456_backward)
        grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight = _ms_deform_attn_cuda_c23456_backward(grad_output.contiguous(),
            feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,
            sampling_locations, scale_weights
        )
        return grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight
def msmv_sampling(mlvl_feats, sampling_locations, scale_weights):
    """Dispatch multi-scale multi-view sampling.

    Uses the fused CUDA op when the pyramid has exactly 4 (C2-C5) or
    5 (C2-C6) levels; any other level count falls back to the pure
    PyTorch implementation.
    """
    sampling_locations = sampling_locations.contiguous()
    scale_weights = scale_weights.contiguous()
    num_levels = len(mlvl_feats)
    if num_levels == 4:
        cuda_op = MSMVSamplingC2345
    elif num_levels == 5:
        cuda_op = MSMVSamplingC23456
    else:
        return msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights)
    return cuda_op.apply(*mlvl_feats, sampling_locations, scale_weights)
================================================
FILE: models/loss_utils.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models.builder import LOSSES, build_loss
from mmdet.core import reduce_mean
from .utils import sparse2dense
from torch.cuda.amp import autocast
from torch.autograd import Variable
def get_voxel_decoder_loss_input(voxel_semantics, occ_loc_i, seg_pred_i, scale, num_classes=18):
    """Align a coarse sparse semantic prediction with full-resolution GT.

    Densifies the sparse prediction onto a (200/scale, 200/scale, 16/scale)
    grid, upsamples both the prediction and its occupancy mask back to the
    full 200x200x16 resolution, and gathers the voxels covered by the mask.

    Args:
        voxel_semantics: [1, 200, 200, 16] integer GT labels (batch size 1).
        occ_loc_i: sparse voxel coordinates at the coarse scale.
        seg_pred_i: [K, num_classes] sparse class logits, or None.
        scale: integer downsampling factor of this decoder level.
        num_classes: number of semantic classes.

    Returns:
        (pred_sparse [K', num_classes], gt_sparse [K'], sparse_mask) when
        seg_pred_i is given; None otherwise (mirrors the original flow).
    """
    assert voxel_semantics.shape[0] == 1  # bs = 1
    voxel_semantics = voxel_semantics.long()
    if seg_pred_i is None:
        return None
    assert seg_pred_i.shape[-1] == num_classes
    coarse = 200 // scale
    dense_pred, sparse_mask = sparse2dense(
        occ_loc_i, seg_pred_i,
        dense_shape=[coarse, coarse, 16 // scale, num_classes],
        empty_value=torch.zeros((num_classes)).to(seg_pred_i)
    )
    # Upsample the occupancy mask to full resolution (nearest via float interp).
    sparse_mask = F.interpolate(sparse_mask[:, None].float(), scale_factor=scale)[:, 0].bool()
    # Move class dim to channel position, upsample, and move it back.
    dense_pred = dense_pred.permute(0, 4, 1, 2, 3)        # [B, CLS, W, H, D]
    dense_pred = F.interpolate(dense_pred, scale_factor=scale)
    dense_pred = dense_pred.permute(0, 2, 3, 4, 1)        # [B, W, H, D, CLS]
    pred_sparse = dense_pred[sparse_mask]                 # [K', CLS]
    gt_sparse = voxel_semantics[sparse_mask]              # [K']
    return pred_sparse, gt_sparse, sparse_mask
def compute_scal_loss(pred, gt, class_id, reverse=False, ignore_index=255):
    """Per-class scene-class affinity loss terms.

    For the binary problem "voxel belongs to `class_id`", builds BCE terms on
    soft precision, recall and specificity, computed independently for each
    batch element. Used by GeoScalLoss / SemScalLoss below.

    Args:
        pred: [B, C, K] class probabilities (callers softmax the logits first).
        gt: [B, K] integer labels.
        class_id: the class to score.
        reverse: if True, score "not class_id" instead; labels equal to
            `ignore_index` are excluded from the reversed target.
        ignore_index: void label, only used when `reverse` is True.

    Returns:
        (loss, mask): `loss` is a [B] tensor of summed affinity terms,
        `mask` a [B] bool tensor marking batch elements where the (possibly
        reversed) target class is present; elements outside it contribute 0.
    """
    p = pred[:, class_id, :]
    completion_target = (gt == class_id).long()
    loss = torch.zeros(pred.shape[0], device=pred.device)
    if reverse:
        # Score the complement class; void labels never count as positives.
        p = 1 - p
        completion_target = ((gt != class_id) & (gt != ignore_index)).long()
    # Keep only batch elements where the target class actually occurs.
    target_sum = completion_target.sum(dim=(1))
    mask = (target_sum > 0)
    p = p[torch.where(mask)]
    completion_target = completion_target[torch.where(mask)]
    # Soft true-positive mass per batch element.
    nominator = torch.sum(p * completion_target, dim=(1))
    # Precision term: only where the prediction assigns any mass at all.
    p_mask = torch.where(torch.sum(p, dim=(1)) > 0)
    if p_mask[0].shape[0] > 0:
        precision = nominator[p_mask] / torch.sum(p[p_mask], dim=(1))
        loss_precision = F.binary_cross_entropy(
            precision, torch.ones_like(precision),
            reduction='none'
        )
        # torch.where(mask)[0][p_mask] maps rows of the filtered tensors
        # back to their original batch indices.
        loss[torch.where(mask)[0][p_mask]] += loss_precision
    # Recall term: where the target class has any positives (always true
    # after the mask above, kept for symmetry/safety).
    t_mask = torch.where(torch.sum(completion_target, dim=(1)) > 0)
    if t_mask[0].shape[0] > 0:
        recall = nominator[t_mask] / torch.sum(completion_target[t_mask], dim=(1))
        loss_recall = F.binary_cross_entropy(
            recall, torch.ones_like(recall),
            reduction='none'
        )
        loss[torch.where(mask)[0][t_mask]] += loss_recall
    # Specificity term: where negatives of the target class exist.
    ct_mask = torch.where(torch.sum(1 - completion_target, dim=(1)) > 0)
    if ct_mask[0].shape[0] > 0:
        specificity = torch.sum((1 - p[ct_mask]) * (1 - completion_target[ct_mask]), dim=(1)) / (
            torch.sum(1 - completion_target[ct_mask], dim=(1))
        )
        loss_ct = F.binary_cross_entropy(
            specificity, torch.ones_like(specificity),
            reduction='none'
        )
        loss[torch.where(mask)[0][ct_mask]] += loss_ct
    return loss, mask
@LOSSES.register_module()
class GeoScalLoss(nn.Module):
    """Geometric scene-class affinity loss.

    Supervises the binary occupied/empty geometry by applying the affinity
    terms of `compute_scal_loss` in reverse mode to the last class index
    (treated as the "empty" class).
    """

    def __init__(self,
                 num_classes,
                 loss_weight=1.0):
        """
        Args:
            num_classes: total number of classes; index num_classes - 1 is
                scored as the empty/free class.
            loss_weight: scalar multiplier on the returned loss.
        """
        super().__init__()
        self.num_classes = num_classes
        self.loss_weight = loss_weight

    def forward(self, pred, gt):
        """Compute the loss from [B, C, K] logits and [B, K] labels."""
        pred = F.softmax(pred, dim=1)
        # fix: removed a dead `loss = torch.tensor(0, ...)` assignment that
        # was immediately overwritten by the line below.
        loss, _ = compute_scal_loss(pred, gt, self.num_classes - 1, reverse=True)
        return self.loss_weight * torch.mean(loss)
@LOSSES.register_module()
class SemScalLoss(nn.Module):
    """Semantic scene-class affinity loss.

    Averages the per-class affinity terms of `compute_scal_loss` over all
    classes, weighting each class and normalizing each batch element by the
    number of classes actually present in its ground truth.
    """

    def __init__(self,
                 num_classes,
                 class_weights=None,
                 loss_weight=1.0):
        super().__init__()
        self.num_classes = num_classes
        if class_weights is None:
            # Uniform weighting when none is configured.
            self.class_weights = [1.0 for _ in range(self.num_classes)]
        else:
            assert len(class_weights) == self.num_classes, "number of class weights must equal to class number"
            self.class_weights = class_weights
        self.loss_weight = loss_weight

    def forward(self, pred, gt):
        """Compute the loss from [B, C, K] logits and [B, K] labels."""
        pred = F.softmax(pred, dim=1)
        batch_size = pred.shape[0]
        total = torch.zeros(batch_size, device=pred.device)
        present = torch.zeros(batch_size, device=pred.device)
        for cls_id in range(self.num_classes):
            cls_loss, cls_mask = compute_scal_loss(pred, gt, cls_id)
            present += cls_mask.long()
            total += cls_loss * self.class_weights[cls_id]
        return self.loss_weight * (total / present).mean()
# borrowed from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py#L21
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    mask_camera: torch.Tensor
):
    """
    Compute the DICE loss (a soft IoU) over predicted masks.

    Args:
        inputs: raw mask logits, one mask per row (presumably [M, 1, K] —
            `flatten(1)`/`squeeze(1)` below expect a singleton middle axis).
        targets: binary labels with the same shape as inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: normalizer for the summed per-mask losses.
        mask_camera: optional index over the last axis restricting the loss
            to camera-visible voxels; pass None to use every voxel.
    """
    if mask_camera is not None:
        inputs = inputs[:, :, mask_camera]
        targets = targets[:, :, mask_camera]
    probs = inputs.sigmoid()
    probs = probs.flatten(1)
    labels = targets.squeeze(1)
    overlap = 2 * (probs * labels).sum(-1)
    total = probs.sum(-1) + labels.sum(-1)
    # +1 smoothing keeps the ratio defined for empty masks.
    per_mask = 1 - (overlap + 1) / (total + 1)
    return per_mask.sum() / num_masks
# TorchScript-compiled variant; call sites use dice_loss_jit in hot paths.
dice_loss_jit = torch.jit.script(
    dice_loss
)  # type: torch.jit.ScriptModule
# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py#L48
def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    mask_camera: torch.Tensor
):
    """
    Per-mask binary cross-entropy on mask logits.

    Args:
        inputs: raw logits, [M, 1, K].
        targets: binary labels with the same shape as inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: normalizer for the summed per-mask losses.
        mask_camera: optional [K] visibility mask; when given it is used as a
            0/1 BCE weight so invisible voxels contribute nothing.

    Returns:
        Loss tensor (scalar).
    """
    # [M, 1, K]
    if mask_camera is not None:
        weight = mask_camera.to(torch.int32)
        weight = weight[None, None, ...].expand(targets.shape[0], 1, weight.shape[-1])
        loss = F.binary_cross_entropy_with_logits(inputs, targets, weight, reduction="none")
    else:
        loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    return loss.mean(2).mean(1).sum() / num_masks
# TorchScript-compiled variant; call sites use sigmoid_ce_loss_jit in hot paths.
sigmoid_ce_loss_jit = torch.jit.script(
    sigmoid_ce_loss
)  # type: torch.jit.ScriptModule
def CE_ssc_loss(pred, target, class_weights=None, ignore_index=255):
    """Cross-entropy loss for semantic scene completion.

    :param pred: the predicted tensor, must be [BS, C, ...]
    :param target: integer labels broadcastable against pred's spatial dims
    :param class_weights: optional per-class weight tensor
    :param ignore_index: label value excluded from the loss
    """
    # Force full precision inside the loss even under AMP.
    with autocast(False):
        return F.cross_entropy(
            pred, target.long(),
            weight=class_weights,
            ignore_index=ignore_index,
            reduction="mean",
        )
# https://github.com/NVlabs/FB-BEV/blob/832bd81866823a913a4c69552e1ca61ae34ac211/mmdet3d/models/fbbev/modules/occ_loss_utils/lovasz_softmax.py#L22
def lovasz_grad(gt_sorted):
    """Gradient of the Lovasz extension w.r.t. sorted errors.

    See Alg. 1 of the Lovasz-Softmax paper. `gt_sorted` is a 1-D 0/1 tensor
    of ground-truth indicators sorted by descending error.
    """
    length = len(gt_sorted)
    positives = gt_sorted.sum()
    intersection = positives - gt_sorted.float().cumsum(0)
    union = positives + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if length > 1:  # cover 1-pixel case
        # First differences turn cumulative IoU into per-position gradients.
        jaccard[1:length] = jaccard[1:length] - jaccard[0:-1]
    return jaccard
# https://github.com/NVlabs/FB-BEV/blob/832bd81866823a913a4c69552e1ca61ae34ac211/mmdet3d/models/fbbev/modules/occ_loss_utils/lovasz_softmax.py#L157
def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss.
    probas: [B, C, H, W] class probabilities (in [0, 1]); a [B, H, W] tensor is
            interpreted as binary (sigmoid) output.
    labels: [B, H, W] ground-truth labels in [0, C - 1].
    classes: 'all', 'present' (only classes seen in labels), or a list of class ids.
    per_image: average the loss per image instead of over the whole batch.
    ignore: void label value excluded from the loss.
    """
    if not per_image:
        # fp32 for numerical stability of the sort-based extension
        with autocast(False):
            return lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
    return mean(
        lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
        for prob, lab in zip(probas, labels)
    )
# https://github.com/NVlabs/FB-BEV/blob/832bd81866823a913a4c69552e1ca61ae34ac211/mmdet3d/models/fbbev/modules/occ_loss_utils/lovasz_softmax.py#L176
def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Flat multi-class Lovasz-Softmax loss.
    probas: [P, C] class probabilities at each prediction (in [0, 1]).
    labels: [P] ground-truth labels in [0, C - 1].
    classes: 'all', 'present' (only classes seen in labels), or a list of class ids.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.
    num_classes = probas.size(1)
    selected = list(range(num_classes)) if classes in ['all', 'present'] else classes
    losses = []
    for c in selected:
        fg = (labels == c).float()  # foreground indicator for class c
        if classes == 'present' and fg.sum() == 0:
            continue
        if num_classes == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (Variable(fg) - class_pred).abs()
        # sort errors descending; the Lovasz gradient weights each rank
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        fg_sorted = fg[perm.data]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)
# https://github.com/NVlabs/FB-BEV/blob/832bd81866823a913a4c69552e1ca61ae34ac211/mmdet3d/models/fbbev/modules/occ_loss_utils/lovasz_softmax.py#L207
def flatten_probas(probas, labels, ignore=None):
    """Flatten batched predictions/labels to [P, C] / [P], dropping void entries.

    probas: [P, C] (already flat), [B, H, W] (sigmoid output),
            [B, C, H, W], or [B, C, L, H, W] (3D segmentation).
    labels: matching integer label tensor.
    ignore: optional void label value; entries with this label are removed.
    """
    if probas.dim() == 2:
        # already [P, C]; just filter void entries
        if ignore is not None:
            valid = (labels != ignore)
            probas = probas[valid]
            labels = labels[valid]
        return probas, labels
    elif probas.dim() == 3:
        # assumes output of a sigmoid layer: [B, H, W] -> [B, 1, H, W]
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    elif probas.dim() == 5:
        # 3D segmentation: fold the depth axis into the spatial one
        B, C, L, H, W = probas.size()
        probas = probas.contiguous().view(B, C, L, H * W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    # boolean indexing keeps a [V, C] shape even when V == 1; the previous
    # `probas[valid.nonzero().squeeze()]` collapsed that case to a 1-D [C]
    vprobas = probas[valid]
    vlabels = labels[valid]
    return vprobas, vlabels
# https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py#L90
@LOSSES.register_module()
class Mask2FormerLoss(nn.Module):
    """Mask2Former-style criterion: BCE mask loss + dice loss on the
    Hungarian-matched (query, gt instance) pairs, plus a focal classification
    loss over all queries (unmatched queries target the 'no class' label).
    """

    def __init__(self,
                 num_classes,
                 loss_cls_weight=1.0,
                 loss_mask_weight=1.0,
                 loss_dice_weight=1.0,
                 no_class_weight=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.loss_cls_weight = loss_cls_weight
        self.loss_mask_weight = loss_mask_weight
        self.loss_dice_weight = loss_dice_weight
        self.no_class_weight = no_class_weight
        # class-weight vector down-weighting the last ('no class') entry;
        # NOTE(review): passed into loss_labels but not used there -- confirm
        self.empty_weight = torch.ones(self.num_classes)
        self.empty_weight[-1] = self.no_class_weight
        # focal classification loss from mmdet (same setup as SparseBEV)
        self.loss_cls = build_loss(dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=2.0
        ))

    def forward(self, mask_pred, class_pred, mask_gt, class_gt, indices, mask_camera):
        """Return (loss_masks, loss_dices, loss_classes), accumulated over the batch.

        mask_pred: [B, Q, N] per-query mask logits over N voxels.
        class_pred: [B, Q, CLS] per-query class logits.
        mask_gt: [B, N] instance-id map (value i marks gt instance i).
        class_gt: list of per-sample [num_obj] class-id tensors.
        indices: per-sample (src_idx, tgt_idx) pairs from the Hungarian matcher.
        mask_camera: optional [B, N] camera-visibility mask.
        """
        bs = mask_pred.shape[0]
        loss_masks = torch.tensor(0).to(mask_pred)
        loss_dices = torch.tensor(0).to(mask_pred)
        loss_classes = torch.tensor(0).to(mask_pred)
        # normalize by the gt-instance count, averaged across GPUs
        num_total_pos = sum([tc.numel() for tc in class_gt])
        avg_factor = torch.clamp(reduce_mean(class_pred.new_tensor([num_total_pos * 1.0])), min=1).item()
        for b in range(bs):
            mask_camera_b = mask_camera[b] if mask_camera is not None else None  # [N]
            tgt_mask = mask_gt[b]
            num_instances = class_gt[b].shape[0]
            tgt_class = class_gt[b]
            # instance-id map -> per-instance binary masks, [num_instances, N]
            tgt_mask = (tgt_mask.unsqueeze(-1) == torch.arange(num_instances).to(mask_gt.device))
            tgt_mask = tgt_mask.permute(1, 0)
            src_idx, tgt_idx = indices[b]
            src_mask = mask_pred[b][src_idx]  # [M, N], M is number of gt instances, N is number of remaining voxels
            tgt_mask = tgt_mask[tgt_idx]  # [M, N]
            src_class = class_pred[b]  # [Q, CLS]
            # pad non-aligned queries' tgt classes with 'no class'
            pad_tgt_class = torch.full(
                (src_class.shape[0], ), self.num_classes - 1, dtype=torch.int64, device=class_pred.device
            )  # [Q]
            pad_tgt_class[src_idx] = tgt_class[tgt_idx]
            # only calculates loss mask for aligned pairs
            loss_mask, loss_dice = self.loss_masks(src_mask, tgt_mask, avg_factor=avg_factor, mask_camera=mask_camera_b)
            # calculates loss class for all queries
            loss_class = self.loss_labels(src_class, pad_tgt_class, self.empty_weight.to(src_class.device), avg_factor=avg_factor)
            loss_masks += loss_mask * self.loss_mask_weight
            loss_dices += loss_dice * self.loss_dice_weight
            loss_classes += loss_class * self.loss_cls_weight
        return loss_masks, loss_dices, loss_classes

    # mask2former use point sampling to calculate loss of fewer important points
    # we omit point sampling as we have limited number of points
    def loss_masks(self, src_mask, tgt_mask, avg_factor=None, mask_camera=None):
        """Compute the losses related to the masks: the focal loss and the dice loss.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]

        src_mask / tgt_mask: [M, N] matched mask logits and binary gt masks.
        Returns (loss_mask, loss_dice), each normalized by avg_factor.
        """
        # No need to upsample predictions as we are using normalized coordinates :)
        # N x 1 x H x W
        num_masks = tgt_mask.shape[0]
        src_mask = src_mask.view(num_masks, 1, -1)
        tgt_mask = tgt_mask.view(num_masks, 1, -1)
        if avg_factor is None:
            avg_factor = num_masks
        loss_dice = dice_loss(src_mask, tgt_mask, avg_factor, mask_camera)
        loss_mask = sigmoid_ce_loss(src_mask, tgt_mask.float(), avg_factor, mask_camera)
        return loss_mask, loss_dice

    def loss_labels(self, src_class, tgt_class, empty_weight=None, avg_factor=None):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]

        NOTE(review): `empty_weight` is accepted but never used; the focal loss
        gets uniform per-sample weights instead -- confirm this is intended.
        """
        return self.loss_cls(
            src_class, tgt_class, torch.ones_like(tgt_class), avg_factor=avg_factor
        ).mean()
# --------------------------- HELPER FUNCTIONS ---------------------------
def mean(l, empty=0):
    """
    nanmean compatible with generators.
    Returns `empty` for an empty iterable (or raises ValueError when
    empty == 'raise').
    """
    it = iter(l)
    try:
        acc = next(it)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    count = 1
    for count, item in enumerate(it, 2):
        acc += item
    # a single element is returned as-is (no division)
    if count == 1:
        return acc
    return acc / count
================================================
FILE: models/matcher.py
================================================
"""
Modified from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py
"""
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from scipy.optimize import linear_sum_assignment
from mmcv.runner import BaseModule
from mmdet.core.bbox.match_costs import build_match_cost
def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor, mask_camera: torch.Tensor):
    """
    Pairwise DICE cost (generalized IoU for masks) between every prediction
    and every target mask.
    inputs: [N, K] mask logits; targets: [M, K] binary masks.
    mask_camera: optional [K] boolean visibility mask; invisible voxels are dropped.
    Returns an [N, M] cost matrix.
    """
    if mask_camera is not None:
        inputs = inputs[:, mask_camera]
        targets = targets[:, mask_camera]
    probs = inputs.sigmoid().flatten(1)
    # 2 * |X ∩ Y| for every (prediction, target) pair
    numerator = 2 * torch.einsum("nc,mc->nm", probs, targets)
    denominator = probs.sum(-1)[:, None] + targets.sum(-1)[None, :]
    return 1 - (numerator + 1) / (denominator + 1)
# TorchScript-compiled variant of batch_dice_loss. NOTE(review):
# HungarianMatcher.forward calls the eager batch_dice_loss, so this scripted
# copy appears unused; scripting also forbids mask_camera=None -- confirm.
batch_dice_loss_jit = torch.jit.script(
    batch_dice_loss
)  # type: torch.jit.ScriptModule
def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor, mask_camera: torch.Tensor):
    """
    Pairwise BCE-with-logits cost between every prediction and every target.
    inputs: [N, K] mask logits; targets: [M, K] binary masks.
    mask_camera: optional [K] visibility mask applied as a per-element BCE weight.
    Returns an [N, M] cost matrix, averaged over the K voxels.
    """
    hw = inputs.shape[1]
    ones = torch.ones_like(inputs)
    zeros = torch.zeros_like(inputs)
    if mask_camera is None:
        pos = F.binary_cross_entropy_with_logits(inputs, ones, reduction="none")
        neg = F.binary_cross_entropy_with_logits(inputs, zeros, reduction="none")
    else:
        # broadcast the visibility mask to [N, K] and use it as the weight
        weight = mask_camera.to(torch.int32)
        weight = weight[None].expand(inputs.shape[0], weight.shape[-1])
        pos = F.binary_cross_entropy_with_logits(inputs, ones, weight, reduction="none")
        neg = F.binary_cross_entropy_with_logits(inputs, zeros, weight, reduction="none")
    # cost(n, m) = sum_k [ t_mk * pos_nk + (1 - t_mk) * neg_nk ] / K
    loss = torch.einsum("nc,mc->nm", pos, targets) \
         + torch.einsum("nc,mc->nm", neg, (1 - targets))
    return loss / hw
# TorchScript-compiled variant of batch_sigmoid_ce_loss. NOTE(review):
# HungarianMatcher.forward calls the eager function, so this scripted copy
# appears unused; scripting also forbids mask_camera=None -- confirm.
batch_sigmoid_ce_loss_jit = torch.jit.script(
    batch_sigmoid_ce_loss
)  # type: torch.jit.ScriptModule
# modified from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py#L70
class HungarianMatcher(BaseModule):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
            cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice
        # focal-loss classification cost from mmdet (as SparseBEV does)
        self.loss_focal = build_match_cost(dict(type='FocalLossCost', weight=2.0))
        assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs cant be 0"

    @torch.no_grad()
    def forward(self, mask_pred, class_pred, mask_gt, class_gt, mask_camera):
        """
        Args:
            mask_pred: [bs, num_query, num_voxel (65536)]
            class_pred: [bs, num_query, 17]
            mask_gt: [bs, num_voxel], value in range [0, num_obj - 1]
            class_gt: [[bs0_num_obj], [bs1_num_obj], ...], value in range [0, num_cls - 1]
        Returns:
            list (one entry per sample) of (pred_idx, gt_idx) int64 tensors
            from the Hungarian assignment.
        """
        bs, num_queries = class_pred.shape[:2]
        indices = []
        # Iterate through batch size
        for b in range(bs):
            mask_camera_b = mask_camera[b] if mask_camera is not None else None
            tgt_ids = class_gt[b]
            num_instances = tgt_ids.shape[0] # must be here, cause num of instances may change after masking
            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be ommitted.
            '''out_prob = class_pred[b].softmax(-1) # [num_queries, num_classes]
            cost_class = -out_prob[:, tgt_ids.long()].squeeze(1)'''
            # Compute the classification cost. We use focal loss provided by mmdet as sparsebev does
            out_prob = class_pred[b] # TODO
            cost_class = self.loss_focal(out_prob, tgt_ids.long())
            out_mask = mask_pred[b] # [num_queries, H_pred, W_pred]
            # gt masks are already padded when preparing target
            tgt_mask = mask_gt[b]
            # instance-id map -> per-instance binary masks
            tgt_mask = (tgt_mask.unsqueeze(-1) == torch.arange(num_instances).to(mask_gt.device))
            tgt_mask = tgt_mask.permute(1, 0) # [Q, N]
            # all masks share the same set of points for efficient matching!
            tgt_mask = tgt_mask.view(tgt_mask.shape[0], -1)
            out_mask = out_mask.view(out_mask.shape[0], -1)
            # compute the mask costs in fp32 for a stable assignment
            with autocast(enabled=False):
                out_mask = out_mask.float()
                tgt_mask = tgt_mask.float()
                # Compute the focal loss between masks
                cost_mask = batch_sigmoid_ce_loss(out_mask, tgt_mask, mask_camera_b)
                # Compute the dice loss betwen masks
                cost_dice = batch_dice_loss(out_mask, tgt_mask, mask_camera_b)
            # Final cost matrix
            C = (
                self.cost_mask * cost_mask
                + self.cost_class * cost_class
                + self.cost_dice * cost_dice
            )
            # scipy's Hungarian algorithm runs on CPU
            C = C.reshape(num_queries, -1).cpu()
            indices.append(linear_sum_assignment(C))
        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]
================================================
FILE: models/sparse_voxel_decoder.py
================================================
import torch
import torch.nn as nn
from mmcv.runner import BaseModule
from mmcv.cnn.bricks.transformer import FFN
from .sparsebev_transformer import SparseBEVSelfAttention, SparseBEVSampling, AdaptiveMixing
from .utils import DUMP, generate_grid, batch_indexing
from .bbox.utils import encode_bbox
import torch.nn.functional as F
def index2point(coords, pc_range, voxel_size):
    """
    Convert integer voxel indices to metric coordinates.
    coords: [B, N, 3] voxel indices.
    pc_range: point-cloud range, e.g. [-40, -40, -1.0, 40, 40, 5.4].
    voxel_size: edge length of one voxel (float).
    """
    origin = torch.tensor(pc_range[:3], device=coords.device)
    return coords * voxel_size + origin
def point2bbox(coords, box_size):
    """
    Attach a cubic extent to each center point.
    coords: [B, N, 3] box centers (float).
    box_size: edge length shared by w/l/h (float).
    Returns [B, N, 6] boxes (x, y, z, w, l, h).
    """
    extent = torch.full_like(coords.float(), box_size)
    return torch.cat([coords, extent], dim=-1)  # [B, N, 6]
def upsample(pre_feat, pre_coords, interval):
    '''
    Split every coarse voxel into its 8 children (2x upsampling).
    :param pre_feat: (Tensor), features from last level, (B, N, C)
    :param pre_coords: (Tensor), coordinates from last level, (B, N, 3) (3: x, y, z)
    :param interval: coordinate offset between sibling children
    :return: up_feat : upsampled features, (B, N*8, C//8)
    :return: up_coords: upsampled coordinates, (B, N*8, 3)
    '''
    # axes offset for children 1..7 (child 0 keeps the parent coordinate)
    offset_axes = [0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2]]
    bs, num_query, num_channels = pre_feat.shape
    # each child inherits a C/8 slice of the parent feature
    up_feat = pre_feat.reshape(bs, num_query, 8, num_channels // 8)
    up_coords = pre_coords.unsqueeze(2).repeat(1, 1, 8, 1).contiguous()  # [B, N, 8, 3]
    for child, axes in enumerate(offset_axes, start=1):
        up_coords[:, :, child, axes] += interval
    return (up_feat.reshape(bs, -1, num_channels // 8),
            up_coords.reshape(bs, -1, 3))
class SparseVoxelDecoder(BaseModule):
    """Coarse-to-fine sparse voxel decoder.

    Starting from a coarse voxel grid, each layer: runs a transformer layer,
    lifts every voxel feature to its 8 children (2x upsample), predicts
    per-voxel semantics, and keeps only the top-k voxels most likely to be
    occupied before moving to the next (finer) level.
    """

    def __init__(self,
                 embed_dims=None,
                 num_layers=None,
                 num_frames=None,
                 num_points=None,
                 num_groups=None,
                 num_levels=None,
                 num_classes=None,
                 semantic=False,
                 topk_training=None,
                 topk_testing=None,
                 pc_range=None):
        super().__init__()
        self.embed_dims = embed_dims
        self.num_frames = num_frames
        self.num_layers = num_layers
        self.pc_range = pc_range
        self.semantic = semantic
        # full-resolution occupancy grid size (x, y, z)
        self.voxel_dim = [200, 200, 16]
        # per-layer top-k voxel budgets (train vs. eval)
        self.topk_training = topk_training
        self.topk_testing = topk_testing
        self.decoder_layers = nn.ModuleList()
        self.lift_feat_heads = nn.ModuleList()
        #self.occ_pred_heads = nn.ModuleList()
        if semantic:
            self.seg_pred_heads = nn.ModuleList()
        for i in range(num_layers):
            self.decoder_layers.append(SparseVoxelDecoderLayer(
                embed_dims=embed_dims,
                num_frames=num_frames,
                # halve the number of sampling points at each finer level
                num_points=num_points // (2 ** i),
                num_groups=num_groups,
                num_levels=num_levels,
                pc_range=pc_range,
                # self-attention only in the two coarsest layers
                self_attn=i in [0, 1]
            ))
            # expands each feature 8x so it can be split among the 8 children
            self.lift_feat_heads.append(nn.Sequential(
                nn.Linear(embed_dims, embed_dims * 8),
                nn.ReLU(inplace=True)
            ))
            #self.occ_pred_heads.append(nn.Linear(embed_dims, 1))
            if semantic:
                self.seg_pred_heads.append(nn.Linear(embed_dims, num_classes))

    @torch.no_grad()
    def init_weights(self):
        # delegate to each decoder layer's own init
        for i in range(len(self.decoder_layers)):
            self.decoder_layers[i].init_weights()

    def forward(self, mlvl_feats, img_metas):
        """Return one (coords, None, seg_pred, feat, scale) tuple per layer."""
        occ_preds = []
        topk = self.topk_training if self.training else self.topk_testing
        B = len(img_metas)
        # init query coords
        interval = 2 ** self.num_layers
        query_coord = generate_grid(self.voxel_dim, interval).expand(B, -1, -1)  # [B, N, 3]
        query_feat = torch.zeros([B, query_coord.shape[1], self.embed_dims], device=query_coord.device)  # [B, N, C]
        for i, layer in enumerate(self.decoder_layers):
            DUMP.stage_count = i
            interval = 2 ** (self.num_layers - i)  # 8 4 2 1
            # bbox from coords
            query_bbox = index2point(query_coord, self.pc_range, voxel_size=0.4)  # [B, N, 3]
            query_bbox = point2bbox(query_bbox, box_size=0.4 * interval)  # [B, N, 6]
            query_bbox = encode_bbox(query_bbox, pc_range=self.pc_range)  # [B, N, 6]
            # transformer layer
            query_feat = layer(query_feat, query_bbox, mlvl_feats, img_metas)  # [B, N, C]
            # upsample 2x
            query_feat = self.lift_feat_heads[i](query_feat)  # [B, N, 8C]
            query_feat_2x, query_coord_2x = upsample(query_feat, query_coord, interval // 2)
            if self.semantic:
                seg_pred_2x = self.seg_pred_heads[i](query_feat_2x)  # [B, K, CLS]
            else:
                seg_pred_2x = None
            # sparsify after seg_pred
            # NOTE(review): the softmax below dereferences seg_pred_2x, so
            # semantic=False would crash here -- confirm semantic is always True
            # probability of being non-free = 1 - P(last class, i.e. 'free')
            non_free_prob = 1 - F.softmax(seg_pred_2x, dim=-1)[..., -1]  # [B, K]
            indices = torch.topk(non_free_prob, k=topk[i], dim=1)[1]  # [B, K]
            query_coord_2x = batch_indexing(query_coord_2x, indices, layout='channel_last')  # [B, K, 3]
            query_feat_2x = batch_indexing(query_feat_2x, indices, layout='channel_last')  # [B, K, C]
            seg_pred_2x = batch_indexing(seg_pred_2x, indices, layout='channel_last')  # [B, K, CLS]
            occ_preds.append((
                torch.div(query_coord_2x, interval // 2, rounding_mode='trunc').long(),
                None,
                seg_pred_2x,
                query_feat_2x,
                interval // 2)
            )
            # detach so gradients don't flow across pyramid levels
            query_coord = query_coord_2x.detach()
            query_feat = query_feat_2x.detach()
        return occ_preds
class SparseVoxelDecoderLayer(BaseModule):
    """One decoder layer: bbox positional encoding, optional self-attention,
    multi-frame/multi-view feature sampling, adaptive mixing, and an FFN,
    each attention/mixing/FFN step followed by LayerNorm."""

    def __init__(self,
                 embed_dims=None,
                 num_frames=None,
                 num_points=None,
                 num_groups=None,
                 num_levels=None,
                 pc_range=None,
                 self_attn=True):
        super().__init__()
        # MLP positional encoding of the (x, y, z) box center
        self.position_encoder = nn.Sequential(
            nn.Linear(3, embed_dims),
            nn.LayerNorm(embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dims, embed_dims),
            nn.LayerNorm(embed_dims),
            nn.ReLU(inplace=True),
        )
        if self_attn:
            self.self_attn = SparseBEVSelfAttention(embed_dims, num_heads=8, dropout=0.1, pc_range=pc_range, scale_adaptive=True)
            self.norm1 = nn.LayerNorm(embed_dims)
        else:
            self.self_attn = None
        # samples image features at projected 3D points across frames/views
        self.sampling = SparseBEVSampling(
            embed_dims=embed_dims,
            num_frames=num_frames,
            num_groups=num_groups,
            num_points=num_points,
            num_levels=num_levels,
            pc_range=pc_range
        )
        self.mixing = AdaptiveMixing(
            in_dim=embed_dims,
            in_points=num_points * num_frames,
            n_groups=num_groups,
            out_points=num_points * num_frames * num_groups
        )
        self.ffn = FFN(embed_dims, feedforward_channels=embed_dims * 2, ffn_drop=0.1)
        self.norm2 = nn.LayerNorm(embed_dims)
        self.norm3 = nn.LayerNorm(embed_dims)

    @torch.no_grad()
    def init_weights(self):
        if self.self_attn is not None:
            self.self_attn.init_weights()
        self.sampling.init_weights()
        self.mixing.init_weights()
        self.ffn.init_weights()

    def forward(self, query_feat, query_bbox, mlvl_feats, img_metas):
        # add positional encoding of the box center to the query features
        query_pos = self.position_encoder(query_bbox[..., :3])
        query_feat = query_feat + query_pos
        if self.self_attn is not None:
            query_feat = self.norm1(self.self_attn(query_bbox, query_feat))
        sampled_feat = self.sampling(query_bbox, query_feat, mlvl_feats, img_metas)
        query_feat = self.norm2(self.mixing(sampled_feat, query_feat))
        query_feat = self.norm3(self.ffn(query_feat))
        return query_feat
================================================
FILE: models/sparsebev_head.py
================================================
import math
import torch
import torch.nn as nn
from mmcv.runner import force_fp32
from mmdet.core import multi_apply, reduce_mean
from mmdet.models import HEADS
from mmdet.models.dense_heads import DETRHead
from mmdet3d.core.bbox.coders import build_bbox_coder
from mmdet3d.core.bbox.structures.lidar_box3d import LiDARInstance3DBoxes
from .bbox.utils import normalize_bbox, encode_bbox
@HEADS.register_module()
class SparseBEVHead(DETRHead):
    def __init__(self,
                 *args,
                 num_classes,
                 in_channels,
                 query_denoising=True,
                 query_denoising_groups=10,
                 bbox_coder=None,
                 code_size=10,
                 code_weights=[1.0] * 10,
                 train_cfg=dict(),
                 test_cfg=dict(max_per_img=100),
                 **kwargs):
        # these attributes must be set BEFORE super().__init__, which calls
        # _init_layers() and reads them
        self.code_size = code_size
        self.code_weights = code_weights
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.fp16_enabled = False
        self.embed_dims = in_channels
        super(SparseBEVHead, self).__init__(num_classes, in_channels, train_cfg=train_cfg, test_cfg=test_cfg, **kwargs)
        # fixed (non-learnable) per-dimension weights for the bbox regression loss
        self.code_weights = nn.Parameter(torch.tensor(self.code_weights), requires_grad=False)
        self.bbox_coder = build_bbox_coder(bbox_coder)
        self.pc_range = self.bbox_coder.pc_range
        # query-denoising (DN-DETR style) hyper-parameters
        self.dn_enabled = query_denoising
        self.dn_group_num = query_denoising_groups
        self.dn_weight = 1.0
        self.dn_bbox_noise_scale = 0.5
        self.dn_label_noise_scale = 0.5
def _init_layers(self):
self.init_query_bbox = nn.Embedding(self.num_query, 10) # (x, y, z, w, l, h, sin, cos, vx, vy)
self.label_enc = nn.Embedding(self.num_classes + 1, self.embed_dims - 1) # DAB-DETR
nn.init.zeros_(self.init_query_bbox.weight[:, 2:3])
nn.init.zeros_(self.init_query_bbox.weight[:, 8:10])
nn.init.constant_(self.init_query_bbox.weight[:, 5:6], 1.5)
grid_size = int(math.sqrt(self.num_query))
assert grid_size * grid_size == self.num_query
x = y = torch.arange(grid_size)
xx, yy = torch.meshgrid(x, y, indexing='ij') # [0, grid_size - 1]
xy = torch.cat([xx[..., None], yy[..., None]], dim=-1)
xy = (xy + 0.5) / grid_size # [0.5, grid_size - 0.5] / grid_size ~= (0, 1)
with torch.no_grad():
self.init_query_bbox.weight[:, :2] = xy.reshape(-1, 2) # [Q, 2]
    def init_weights(self):
        # all weight initialization is delegated to the transformer sub-module
        self.transformer.init_weights()
    def forward(self, mlvl_feats, img_metas):
        """Predict per-decoder-layer class scores and decoded 3D boxes.

        mlvl_feats: multi-level image feature maps.
        img_metas: per-sample metadata dicts.
        Returns a dict with 'all_cls_scores' / 'all_bbox_preds'; when query
        denoising is active, DN predictions are split out into 'dn_mask_dict'.
        """
        query_bbox = self.init_query_bbox.weight.clone()  # [Q, 10]
        #query_bbox[..., :3] = query_bbox[..., :3].sigmoid()
        B = mlvl_feats[0].shape[0]
        # prepend denoising queries (training only) and build the attention mask
        query_bbox, query_feat, attn_mask, mask_dict = self.prepare_for_dn_input(B, query_bbox, self.label_enc, img_metas)
        cls_scores, bbox_preds = self.transformer(
            query_bbox,
            query_feat,
            mlvl_feats,
            attn_mask=attn_mask,
            img_metas=img_metas,
        )
        # de-normalize centers from [0, 1] back to metric pc_range coordinates
        bbox_preds[..., 0] = bbox_preds[..., 0] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]
        bbox_preds[..., 1] = bbox_preds[..., 1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]
        bbox_preds[..., 2] = bbox_preds[..., 2] * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2]
        # reorder to the coder's expected layout
        bbox_preds = torch.cat([
            bbox_preds[..., 0:2],
            bbox_preds[..., 3:5],
            bbox_preds[..., 2:3],
            bbox_preds[..., 5:10],
        ], dim=-1)  # [cx, cy, w, l, cz, h, sin, cos, vx, vy]
        if mask_dict is not None and mask_dict['pad_size'] > 0:
            # the first pad_size queries are the DN queries; split them off
            output_known_cls_scores = cls_scores[:, :, :mask_dict['pad_size'], :]
            output_known_bbox_preds = bbox_preds[:, :, :mask_dict['pad_size'], :]
            output_cls_scores = cls_scores[:, :, mask_dict['pad_size']:, :]
            output_bbox_preds = bbox_preds[:, :, mask_dict['pad_size']:, :]
            mask_dict['output_known_lbs_bboxes'] = (output_known_cls_scores, output_known_bbox_preds)
            outs = {
                'all_cls_scores': output_cls_scores,
                'all_bbox_preds': output_bbox_preds,
                'enc_cls_scores': None,
                'enc_bbox_preds': None,
                'dn_mask_dict': mask_dict,
            }
        else:
            outs = {
                'all_cls_scores': cls_scores,
                'all_bbox_preds': bbox_preds,
                'enc_cls_scores': None,
                'enc_bbox_preds': None,
            }
        return outs
    def prepare_for_dn_input(self, batch_size, init_query_bbox, label_enc, img_metas):
        """Build the transformer's query set, prepending noised ground-truth
        (denoising, DN-DETR style) queries during training.

        Returns (query_bbox, query_feat, attn_mask, mask_dict); outside
        training (or with DN disabled) attn_mask and mask_dict are None.
        """
        # https://github.com/IDEA-Research/DN-DETR
        device = init_query_bbox.device
        # indicator bit 0 marks ordinary matching queries
        indicator0 = torch.zeros([self.num_query, 1], device=device)
        init_query_feat = label_enc.weight[self.num_classes].repeat(self.num_query, 1)
        init_query_feat = torch.cat([init_query_feat, indicator0], dim=1)
        if self.training and self.dn_enabled:
            targets = [{
                'bboxes': torch.cat([m['gt_bboxes_3d'].gravity_center,
                                     m['gt_bboxes_3d'].tensor[:, 3:]], dim=1).cuda(),
                'labels': m['gt_labels_3d'].cuda().long()
            } for m in img_metas]
            known = [torch.ones_like(t['labels'], device=device) for t in targets]
            known_num = [sum(k) for k in known]
            # can be modified to selectively denosie some label or boxes; also known label prediction
            unmask_bbox = unmask_label = torch.cat(known)
            labels = torch.cat([t['labels'] for t in targets]).clone()
            bboxes = torch.cat([t['bboxes'] for t in targets]).clone()
            batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])
            known_indice = torch.nonzero(unmask_label + unmask_bbox)
            known_indice = known_indice.view(-1)
            # add noise: replicate every gt once per DN group
            known_indice = known_indice.repeat(self.dn_group_num, 1).view(-1)
            known_labels = labels.repeat(self.dn_group_num, 1).view(-1)
            known_bid = batch_idx.repeat(self.dn_group_num, 1).view(-1)
            known_bboxs = bboxes.repeat(self.dn_group_num, 1)  # 9
            known_labels_expand = known_labels.clone()
            known_bbox_expand = known_bboxs.clone()
            # noise on the box
            if self.dn_bbox_noise_scale > 0:
                wlh = known_bbox_expand[..., 3:6].clone()
                rand_prob = torch.rand_like(known_bbox_expand) * 2 - 1.0
                # jitter the center by up to scale * (box extent / 2) per axis
                known_bbox_expand[..., 0:3] += torch.mul(rand_prob[..., 0:3], wlh / 2) * self.dn_bbox_noise_scale
                # known_bbox_expand[..., 3:6] += torch.mul(rand_prob[..., 3:6], wlh) * self.dn_bbox_noise_scale
                # known_bbox_expand[..., 6:7] += torch.mul(rand_prob[..., 6:7], 3.14159) * self.dn_bbox_noise_scale
            known_bbox_expand = encode_bbox(known_bbox_expand, self.pc_range)
            known_bbox_expand[..., 0:3].clamp_(min=0.0, max=1.0)
            # nn.init.constant(known_bbox_expand[..., 8:10], 0.0)
            # noise on the label
            if self.dn_label_noise_scale > 0:
                p = torch.rand_like(known_labels_expand.float())
                chosen_indice = torch.nonzero(p < self.dn_label_noise_scale).view(-1)  # usually half of bbox noise
                new_label = torch.randint_like(chosen_indice, 0, self.num_classes)  # randomly put a new one here
                known_labels_expand.scatter_(0, chosen_indice, new_label)
            known_feat_expand = label_enc(known_labels_expand)
            indicator1 = torch.ones([known_feat_expand.shape[0], 1], device=device)  # add dn part indicator
            known_feat_expand = torch.cat([known_feat_expand, indicator1], dim=1)
            # construct final query: each group is padded to the max gt count
            dn_single_pad = int(max(known_num))
            dn_pad_size = int(dn_single_pad * self.dn_group_num)
            dn_query_bbox = torch.zeros([dn_pad_size, init_query_bbox.shape[-1]], device=device)
            dn_query_feat = torch.zeros([dn_pad_size, self.embed_dims], device=device)
            input_query_bbox = torch.cat([dn_query_bbox, init_query_bbox], dim=0).repeat(batch_size, 1, 1)
            input_query_feat = torch.cat([dn_query_feat, init_query_feat], dim=0).repeat(batch_size, 1, 1)
            if len(known_num):
                # scatter each noised gt into its (batch, group-slot) position
                map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]
                map_known_indice = torch.cat([map_known_indice + dn_single_pad * i for i in range(self.dn_group_num)]).long()
            if len(known_bid):
                input_query_bbox[known_bid.long(), map_known_indice] = known_bbox_expand
                input_query_feat[(known_bid.long(), map_known_indice)] = known_feat_expand
            # attention mask: True = blocked
            total_size = dn_pad_size + self.num_query
            attn_mask = torch.ones([total_size, total_size], device=device) < 0
            # match query cannot see the reconstruct
            attn_mask[dn_pad_size:, :dn_pad_size] = True
            # each DN group may only see itself (not other groups or the future)
            for i in range(self.dn_group_num):
                if i == 0:
                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), dn_single_pad * (i + 1):dn_pad_size] = True
                if i == self.dn_group_num - 1:
                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), :dn_single_pad * i] = True
                else:
                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), dn_single_pad * (i + 1):dn_pad_size] = True
                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), :dn_single_pad * i] = True
            mask_dict = {
                'known_indice': torch.as_tensor(known_indice).long(),
                'batch_idx': torch.as_tensor(batch_idx).long(),
                'map_known_indice': torch.as_tensor(map_known_indice).long(),
                'known_lbs_bboxes': (known_labels, known_bboxs),
                'pad_size': dn_pad_size
            }
        else:
            input_query_bbox = init_query_bbox.repeat(batch_size, 1, 1)
            input_query_feat = init_query_feat.repeat(batch_size, 1, 1)
            attn_mask = None
            mask_dict = None
        return input_query_bbox, input_query_feat, attn_mask, mask_dict
def prepare_for_dn_loss(self, mask_dict):
cls_scores, bbox_preds = mask_dict['output_known_lbs_bboxes']
known_labels, known_bboxs = mask_dict['known_lbs_bboxes']
map_known_indice = mask_dict['map_known_indice'].long()
known_indice = mask_dict['known_indice'].long()
batch_idx = mask_dict['batch_idx'].long()
bid = batch_idx[known_indice]
num_tgt = known_indice.numel()
if len(cls_scores) > 0:
cls_scores = cls_scores.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)
bbox_preds = bbox_preds.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)
return known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt
    def dn_loss_single(self,
                       cls_scores,
                       bbox_preds,
                       known_bboxs,
                       known_labels,
                       num_total_pos=None):
        """Denoising loss for one decoder layer: classification + weighted L1
        bbox regression against the (un-noised) known gt, scaled by dn_weight."""
        # Compute the average number of gt boxes accross all gpus
        num_total_pos = cls_scores.new_tensor([num_total_pos])
        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1.0).item()
        # cls loss
        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
        bbox_weights = torch.ones_like(bbox_preds)
        label_weights = torch.ones_like(known_labels)
        loss_cls = self.loss_cls(
            cls_scores,
            known_labels.long(),
            label_weights,
            avg_factor=num_total_pos
        )
        # regression L1 loss
        bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
        normalized_bbox_targets = normalize_bbox(known_bboxs)
        # drop targets containing NaN/Inf (e.g. undefined velocity)
        isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
        bbox_weights = bbox_weights * self.code_weights
        loss_bbox = self.loss_bbox(
            bbox_preds[isnotnan, :10],
            normalized_bbox_targets[isnotnan, :10],
            bbox_weights[isnotnan, :10],
            avg_factor=num_total_pos
        )
        loss_cls = self.dn_weight * torch.nan_to_num(loss_cls)
        loss_bbox = self.dn_weight * torch.nan_to_num(loss_bbox)
        return loss_cls, loss_bbox
    @force_fp32(apply_to=('preds_dicts'))
    def calc_dn_loss(self, loss_dict, preds_dicts, num_dec_layers):
        """Add per-decoder-layer denoising losses to `loss_dict` and return it.

        The last decoder layer's losses are stored under 'loss_*_dn'; earlier
        layers get a 'd{i}.' prefix, mirroring the main loss naming.
        """
        known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt = \
            self.prepare_for_dn_loss(preds_dicts['dn_mask_dict'])
        # the same gt set is shared by every decoder layer
        all_known_bboxs_list = [known_bboxs for _ in range(num_dec_layers)]
        all_known_labels_list = [known_labels for _ in range(num_dec_layers)]
        all_num_tgts_list = [num_tgt for _ in range(num_dec_layers)]
        dn_losses_cls, dn_losses_bbox = multi_apply(
            self.dn_loss_single, cls_scores, bbox_preds,
            all_known_bboxs_list, all_known_labels_list, all_num_tgts_list)
        loss_dict['loss_cls_dn'] = dn_losses_cls[-1]
        loss_dict['loss_bbox_dn'] = dn_losses_bbox[-1]
        num_dec_layer = 0
        for loss_cls_i, loss_bbox_i in zip(dn_losses_cls[:-1], dn_losses_bbox[:-1]):
            loss_dict[f'd{num_dec_layer}.loss_cls_dn'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_bbox_dn'] = loss_bbox_i
            num_dec_layer += 1
        return loss_dict
    def _get_target_single(self,
                           cls_score,
                           bbox_pred,
                           gt_labels,
                           gt_bboxes,
                           gt_bboxes_ignore=None):
        """Assign gt boxes to the predictions of a single image and build
        per-query classification / regression targets and weights."""
        num_bboxes = bbox_pred.size(0)
        # assigner and sampler
        assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes, gt_labels, gt_bboxes_ignore, self.code_weights, True)
        sampling_result = self.sampler.sample(assign_result, bbox_pred, gt_bboxes)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds
        # label targets: default to the background class, overwrite positives
        labels = gt_bboxes.new_full((num_bboxes, ), self.num_classes, dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_bboxes.new_ones(num_bboxes)
        # bbox targets: only positive queries regress (weight 1), rest weight 0
        bbox_targets = torch.zeros_like(bbox_pred)[..., :9]
        bbox_weights = torch.zeros_like(bbox_pred)
        bbox_weights[pos_inds] = 1.0
        # DETR
        bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes
        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds)
def get_targets(self,
cls_scores_list,
bbox_preds_list,
gt_bboxes_list,
gt_labels_list,
gt_bboxes_ignore_list=None):
assert gt_bboxes_ignore_list is None, \
'Only supports for gt_bboxes_ignore setting to None.'
num_imgs = len(cls_scores_list)
gt_bboxes_ignore_list = [gt_bboxes_ignore_list for _ in range(num_imgs)]
(labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply(
self._get_target_single, cls_scores_list, bbox_preds_list,
gt_labels_list, gt_bboxes_list, gt_bboxes_ignore_list)
num_total_pos = sum((inds.numel() for inds in pos_inds_list))
num_total_neg = sum((inds.numel() for inds in neg_inds_list))
return (labels_list, label_weights_list, bbox_targets_list,
bbox_weights_list, num_total_pos, num_total_neg)
def loss_single(self,
                cls_scores,
                bbox_preds,
                gt_bboxes_list,
                gt_labels_list,
                gt_bboxes_ignore_list=None):
    """Compute classification and bbox losses for one decoder layer.

    Args:
        cls_scores (Tensor): [bs, num_query, cls_out_channels].
        bbox_preds (Tensor): [bs, num_query, code_size].
        gt_bboxes_list (list[Tensor]): per-image ground-truth boxes.
        gt_labels_list (list[Tensor]): per-image ground-truth labels.
        gt_bboxes_ignore_list: must be None (unsupported downstream).

    Returns:
        tuple[Tensor, Tensor]: (loss_cls, loss_bbox).
    """
    num_imgs = cls_scores.size(0)
    cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
    bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
    cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
                                       gt_bboxes_list, gt_labels_list, gt_bboxes_ignore_list)
    (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
     num_total_pos, num_total_neg) = cls_reg_targets
    # Flatten per-image targets into a single (bs * num_query) batch.
    labels = torch.cat(labels_list, 0)
    label_weights = torch.cat(label_weights_list, 0)
    bbox_targets = torch.cat(bbox_targets_list, 0)
    bbox_weights = torch.cat(bbox_weights_list, 0)

    # classification loss
    cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
    # construct weighted avg_factor to match with the official DETR repo
    cls_avg_factor = num_total_pos * 1.0 + \
        num_total_neg * self.bg_cls_weight
    if self.sync_cls_avg_factor:
        # Average the normalization factor across GPUs.
        cls_avg_factor = reduce_mean(
            cls_scores.new_tensor([cls_avg_factor]))
    cls_avg_factor = max(cls_avg_factor, 1)
    loss_cls = self.loss_cls(
        cls_scores, labels, label_weights, avg_factor=cls_avg_factor)

    # Compute the average number of gt boxes across all gpus, for
    # normalization purposes
    num_total_pos = loss_cls.new_tensor([num_total_pos])
    num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()

    # regression L1 loss
    bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
    normalized_bbox_targets = normalize_bbox(bbox_targets)
    # Mask out targets containing NaN/Inf (e.g. degenerate boxes).
    isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
    bbox_weights = bbox_weights * self.code_weights
    loss_bbox = self.loss_bbox(
        bbox_preds[isnotnan, :10],
        normalized_bbox_targets[isnotnan, :10],
        bbox_weights[isnotnan, :10],
        avg_factor=num_total_pos
    )
    # Guard against NaN losses destabilizing training.
    loss_cls = torch.nan_to_num(loss_cls)
    loss_bbox = torch.nan_to_num(loss_bbox)
    return loss_cls, loss_bbox
@force_fp32(apply_to=('preds_dicts'))
def loss(self,
         gt_bboxes_list,
         gt_labels_list,
         preds_dicts,
         gt_bboxes_ignore=None):
    """Compute losses over all decoder layers, plus the optional encoder
    proposal head and denoising (DN) branch.

    Args:
        gt_bboxes_list (list): per-image GT boxes (mmdet3d box instances).
        gt_labels_list (list[Tensor]): per-image GT labels.
        preds_dicts (dict): head outputs with 'all_cls_scores',
            'all_bbox_preds', 'enc_cls_scores', 'enc_bbox_preds' and
            optionally 'dn_mask_dict'.
        gt_bboxes_ignore: must be None (unsupported).

    Returns:
        dict[str, Tensor]: loss components keyed per decoder layer.
    """
    assert gt_bboxes_ignore is None, \
        f'{self.__class__.__name__} only supports ' \
        f'for gt_bboxes_ignore setting to None.'
    all_cls_scores = preds_dicts['all_cls_scores']
    all_bbox_preds = preds_dicts['all_bbox_preds']
    enc_cls_scores = preds_dicts['enc_cls_scores']
    enc_bbox_preds = preds_dicts['enc_bbox_preds']
    num_dec_layers = len(all_cls_scores)
    device = gt_labels_list[0].device

    # Convert boxes to tensors with the gravity center as the box center.
    gt_bboxes_list = [torch.cat(
        (gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]),
        dim=1).to(device) for gt_bboxes in gt_bboxes_list]

    # Replicate GT for every decoder layer (deep supervision).
    all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
    all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
    all_gt_bboxes_ignore_list = [gt_bboxes_ignore for _ in range(num_dec_layers)]

    losses_cls, losses_bbox = multi_apply(
        self.loss_single, all_cls_scores, all_bbox_preds,
        all_gt_bboxes_list, all_gt_labels_list,
        all_gt_bboxes_ignore_list)

    loss_dict = dict()

    # loss of proposal generated from encode feature map
    if enc_cls_scores is not None:
        # BUGFIX: iterate over the batch (len(gt_bboxes_list)), not over the
        # decoder layers. The previous range(len(all_gt_labels_list)) indexed
        # the per-image gt_labels_list with a layer index, which reads out of
        # bounds whenever num_dec_layers > batch size.
        binary_labels_list = [
            torch.zeros_like(gt_labels_list[i])
            for i in range(len(gt_bboxes_list))
        ]
        enc_loss_cls, enc_losses_bbox = \
            self.loss_single(enc_cls_scores, enc_bbox_preds,
                             gt_bboxes_list, binary_labels_list, gt_bboxes_ignore)
        loss_dict['enc_loss_cls'] = enc_loss_cls
        loss_dict['enc_loss_bbox'] = enc_losses_bbox

    if 'dn_mask_dict' in preds_dicts and preds_dicts['dn_mask_dict'] is not None:
        loss_dict = self.calc_dn_loss(loss_dict, preds_dicts, num_dec_layers)

    # loss from the last decoder layer
    loss_dict['loss_cls'] = losses_cls[-1]
    loss_dict['loss_bbox'] = losses_bbox[-1]

    # loss from other decoder layers
    num_dec_layer = 0
    for loss_cls_i, loss_bbox_i in zip(losses_cls[:-1], losses_bbox[:-1]):
        loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
        loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
        num_dec_layer += 1
    return loss_dict
@force_fp32(apply_to=('preds_dicts'))
def get_bboxes(self, preds_dicts, img_metas, rescale=False):
    """Decode raw head outputs into per-sample (bboxes, scores, labels)."""
    preds_dicts = self.bbox_coder.decode(preds_dicts)
    ret_list = []
    for preds in preds_dicts:
        boxes = preds['bboxes']
        # Shift z from gravity center down to the bottom center, as
        # expected by LiDARInstance3DBoxes.
        boxes[:, 2] = boxes[:, 2] - boxes[:, 5] * 0.5
        ret_list.append([LiDARInstance3DBoxes(boxes, 9),
                         preds['scores'], preds['labels']])
    return ret_list
================================================
FILE: models/sparsebev_sampling.py
================================================
import torch
from .bbox.utils import decode_bbox
from .utils import rotation_3d_in_axis, DUMP
from .csrc.wrapper import msmv_sampling
def make_sample_points_from_bbox(query_bbox, offset, pc_range):
    '''
    Generate 3D sampling points inside (and around) each query box.

    query_bbox: [B, Q, 10] encoded boxes
    offset: [B, Q, num_points, 3(+)] offsets, normalized by the box size
    pc_range: point-cloud range used by decode_bbox

    Return: [B, Q, P, 3] sampling points in world coordinates.
    '''
    query_bbox = decode_bbox(query_bbox, pc_range)  # [B, Q, 9]
    wlh = query_bbox[..., 3:6]  # [B, Q, 3]

    # NOTE: different from SparseBEV
    # BUGFIX: use an out-of-place add. The original `xyz += wlh / 2` mutated
    # a view of the decoded bbox tensor in place, silently rewriting
    # query_bbox[..., 0:3] (an autograd hazard); the returned value is the
    # same either way since only columns 6:7 are read afterwards.
    xyz = query_bbox[..., 0:3] + wlh / 2  # convert corner to center

    delta_xyz = offset[..., 0:3]  # [B, Q, P, 3]
    delta_xyz = wlh[:, :, None, :] * delta_xyz  # scale offsets by box size

    if query_bbox.shape[-1] > 6:
        ang = query_bbox[..., 6:7]  # [B, Q, 1] yaw
        delta_xyz = rotation_3d_in_axis(delta_xyz, ang)  # rotate into box frame

    sample_xyz = xyz[:, :, None, :] + delta_xyz  # [B, Q, P, 3]

    return sample_xyz  # [B, Q, P, 3]
def make_sample_points_from_mask(valid_map, pc_range, occ_size, num_points, occ_loc=None, offset=None):
    '''
    Randomly sample `num_points` valid voxels per query.

    valid_map: [B, Q, W, H, D] (dense) or [B, Q, N] (sparse) validity map
    occ_loc: [B, N, 3] voxel coordinates when valid_map is sparse
    offset: optional jitter added in voxel units
    Return: [B, Q, num_points, 3] points in pc_range coordinates
    '''
    B, Q = valid_map.shape[:2]
    occ_size = torch.tensor(occ_size).to(valid_map.device)
    sampling_pts = []

    for b in range(B):
        indices = torch.where(valid_map[b])
        if indices[0].shape[0] == 0:
            # No valid voxel anywhere: fall back to uniform random points.
            # BUGFIX: scale by occ_size (voxel coordinates) so the later
            # normalization maps them over the full pc_range, consistent
            # with the per-query fallback below. Previously the raw [0, 1)
            # samples were divided by occ_size again, collapsing all points
            # near the pc_range minimum corner.
            pts = torch.rand((Q, num_points, 3)).to(valid_map.device) * occ_size
        else:
            # Number of valid voxels per query.
            if len(valid_map.shape) == 5:
                bin_count = valid_map[b].sum(dim=(1, 2, 3))
            else:
                bin_count = valid_map[b].sum(dim=1)
            # Draw num_points uniform indices within each query's own bin
            # of the flattened, query-sorted `indices` arrays.
            sampling_rand = torch.rand((Q, num_points)).to(bin_count.device)
            sampling_index = (sampling_rand * bin_count[:, None]).floor().long()
            low_bound = torch.cumsum(bin_count, dim=0) - bin_count
            sampling_index = sampling_index + low_bound[:, None]
            sampling_index[sampling_index >= indices[0].shape[0]] = indices[0].shape[0] - 1  # this can happen when zeros appear in the tail
            sampling_index = sampling_index.to(valid_map.device)
            if occ_loc is None:  # dense occ points: voxel coords from where()
                pts = torch.stack((indices[1][sampling_index], indices[2][sampling_index], indices[3][sampling_index]))
                pts = pts.permute(1, 2, 0)
            else:  # sparse: look up voxel coords in occ_loc
                occ_idx = indices[1][sampling_index]
                pts = occ_loc[b][occ_idx]
            # pad queries with no valid occ with random voxel coordinates
            pts = pts.float()
            rand_sampling_points = torch.rand(((bin_count == 0).sum(), num_points, 3)).to(pts.device) * occ_size
            pts[bin_count == 0] = rand_sampling_points

        sampling_pts.append(pts)

    sampling_pts = torch.stack(sampling_pts)
    if offset is not None:
        sampling_pts = sampling_pts + offset

    # Voxel coordinates -> [0, 1] -> pc_range coordinates.
    sampling_pts = sampling_pts / occ_size
    sampling_pts[..., 0] = sampling_pts[..., 0] * (pc_range[3] - pc_range[0]) + pc_range[0]
    sampling_pts[..., 1] = sampling_pts[..., 1] * (pc_range[4] - pc_range[1]) + pc_range[1]
    sampling_pts[..., 2] = sampling_pts[..., 2] * (pc_range[5] - pc_range[2]) + pc_range[2]

    return sampling_pts
def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, image_w, eps=1e-5):
    """Project 3D sampling points into every frame's camera views and gather
    multi-scale image features via the ``msmv_sampling`` CUDA kernel.

    Args:
        sample_points: [B, Q, T, G, P, 3] points in world coordinates.
        mlvl_feats: list of feature maps, each already reshaped to
            [B*T*G, N, H, W, C] (see the decoder's pre-processing loop).
        scale_weights: [B, Q, G, T, P, L] softmax weights over L levels.
        lidar2img: [B, T*N, 4, 4] projection matrices.
        image_h, image_w: input image size used to normalize pixel coords.
        eps: minimum depth to avoid dividing by zero behind the camera.

    Returns:
        Tensor: [B, Q, G, T*P, C] sampled features.

    NOTE(review): assumes exactly N = 6 camera views; each point takes its
    feature from a single view selected via argmax of the validity mask.
    """
    B, Q, T, G, P, _ = sample_points.shape  # [B, Q, T, G, P, 3]
    N = 6

    sample_points = sample_points.reshape(B, Q, T, G * P, 3)

    if DUMP.enabled:
        torch.save(sample_points,
                   '{}/sample_points_3d_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))

    # get the projection matrix, broadcast per query and per point
    lidar2img = lidar2img[:, :(T*N), None, None, :, :]  # [B, TN, 1, 1, 4, 4]
    lidar2img = lidar2img.expand(B, T*N, Q, G * P, 4, 4)
    lidar2img = lidar2img.reshape(B, T, N, Q, G*P, 4, 4)

    # expand the points to homogeneous coordinates and per-view copies
    ones = torch.ones_like(sample_points[..., :1])
    sample_points = torch.cat([sample_points, ones], dim=-1)  # [B, Q, T, GP, 4]
    sample_points = sample_points[:, :, None, ..., None]
    sample_points = sample_points.expand(B, Q, N, T, G * P, 4, 1)
    sample_points = sample_points.transpose(1, 3)  # [B, T, N, Q, GP, 4, 1]

    # project 3d sampling points to image
    sample_points_cam = torch.matmul(lidar2img, sample_points).squeeze(-1)  # [B, T, N, Q, GP, 4]

    # homo coord -> pixel coord (depth clamped to eps)
    homo = sample_points_cam[..., 2:3]
    homo_nonzero = torch.maximum(homo, torch.zeros_like(homo) + eps)
    sample_points_cam = sample_points_cam[..., 0:2] / homo_nonzero  # [B, T, N, Q, GP, 2]

    # normalize pixel coordinates to [0, 1]
    sample_points_cam[..., 0] /= image_w
    sample_points_cam[..., 1] /= image_h

    # check if in front of the camera and inside the image
    valid_mask = ((homo > eps) \
        & (sample_points_cam[..., 1:2] > 0.0)
        & (sample_points_cam[..., 1:2] < 1.0)
        & (sample_points_cam[..., 0:1] > 0.0)
        & (sample_points_cam[..., 0:1] < 1.0)
    ).squeeze(-1).float()  # [B, T, N, Q, GP]

    if DUMP.enabled:
        torch.save(torch.cat([sample_points_cam, homo_nonzero], dim=-1),
                   '{}/sample_points_cam_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
        torch.save(valid_mask,
                   '{}/sample_points_cam_valid_mask_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))

    valid_mask = valid_mask.permute(0, 1, 3, 4, 2)  # [B, T, Q, GP, N]
    sample_points_cam = sample_points_cam.permute(0, 1, 3, 4, 2, 5)  # [B, T, Q, GP, N, 2]

    # Build advanced-indexing grids to pick one view per point.
    i_batch = torch.arange(B, dtype=torch.long, device=sample_points.device)
    i_query = torch.arange(Q, dtype=torch.long, device=sample_points.device)
    i_time = torch.arange(T, dtype=torch.long, device=sample_points.device)
    i_point = torch.arange(G * P, dtype=torch.long, device=sample_points.device)
    i_batch = i_batch.view(B, 1, 1, 1, 1).expand(B, T, Q, G * P, 1)
    i_time = i_time.view(1, T, 1, 1, 1).expand(B, T, Q, G * P, 1)
    i_query = i_query.view(1, 1, Q, 1, 1).expand(B, T, Q, G * P, 1)
    i_point = i_point.view(1, 1, 1, G * P, 1).expand(B, T, Q, G * P, 1)
    # argmax over the 0/1 mask selects a view in which the point is visible.
    i_view = torch.argmax(valid_mask, dim=-1)[..., None]  # [B, T, Q, GP, 1]

    sample_points_cam = sample_points_cam[i_batch, i_time, i_query, i_point, i_view, :]  # [B, T, Q, GP, 1, 2]
    valid_mask = valid_mask[i_batch, i_time, i_query, i_point, i_view]  # [B, T, Q, GP, 1]

    # Encode the chosen view index as the third coordinate (i_view / 5 in [0, 1]).
    sample_points_cam = torch.cat([sample_points_cam, i_view[..., None].float() / 5], dim=-1)

    sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 1, 3)
    sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5, 6)  # [B, T, G, Q, P, 1, 3]
    sample_points_cam = sample_points_cam.reshape(B*T*G, Q, P, 3)

    # Reorder scale weights to match the [B*G*T, ...] kernel layout.
    scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)
    scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)
    scale_weights = scale_weights.reshape(B*G*T, Q, P, -1)

    # CUDA multi-scale multi-view sampling kernel.
    final = msmv_sampling(mlvl_feats, sample_points_cam, scale_weights)

    C = final.shape[2]  # [BTG, Q, C, P]
    final = final.reshape(B, T, G, Q, C, P)
    final = final.permute(0, 3, 2, 1, 5, 4)
    final = final.flatten(3, 4)  # [B, Q, G, FP, C] (F = T frames)

    return final
================================================
FILE: models/sparsebev_transformer.py
================================================
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from mmcv.runner import BaseModule
from mmcv.cnn import bias_init_with_prob
from mmcv.cnn.bricks.transformer import MultiheadAttention, FFN
from mmdet.models.utils.builder import TRANSFORMER
from .bbox.utils import decode_bbox
from .utils import inverse_sigmoid, DUMP
from .sparsebev_sampling import sampling_4d, make_sample_points_from_bbox
from .checkpoint import checkpoint as cp
@TRANSFORMER.register_module()
class SparseBEVTransformer(BaseModule):
    """Top-level SparseBEV transformer: delegates all work to the iterative
    decoder and sanitizes its outputs (NaN/Inf replaced by finite values)."""

    def __init__(self, embed_dims, num_frames=8, num_points=4, num_layers=6, num_levels=4, num_classes=10, code_size=10, pc_range=[], init_cfg=None):
        assert init_cfg is None, 'To prevent abnormal initialization ' \
            'behavior, init_cfg is not allowed to be set'
        super(SparseBEVTransformer, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.pc_range = pc_range
        self.decoder = SparseBEVTransformerDecoder(
            embed_dims, num_frames, num_points, num_layers,
            num_levels, num_classes, code_size, pc_range=pc_range)

    @torch.no_grad()
    def init_weights(self):
        self.decoder.init_weights()

    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
        scores, boxes = self.decoder(query_bbox, query_feat, mlvl_feats, attn_mask, img_metas)
        # Replace NaN/Inf so a single bad value cannot poison the loss.
        return torch.nan_to_num(scores), torch.nan_to_num(boxes)
class SparseBEVTransformerDecoder(BaseModule):
    """Iterative decoder: a single ``SparseBEVTransformerDecoderLayer`` is
    applied ``num_layers`` times (parameters shared across iterations),
    refining the query boxes after each pass.
    """

    def __init__(self, embed_dims, num_frames=8, num_points=4, num_layers=6, num_levels=4, num_classes=10, code_size=10, pc_range=[], init_cfg=None):
        super(SparseBEVTransformerDecoder, self).__init__(init_cfg)
        self.num_layers = num_layers
        self.pc_range = pc_range
        # One layer instance reused every iteration (weight sharing).
        self.decoder_layer = SparseBEVTransformerDecoderLayer(
            embed_dims, num_frames, num_points, num_levels, num_classes, code_size, pc_range=pc_range
        )

    @torch.no_grad()
    def init_weights(self):
        self.decoder_layer.init_weights()

    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
        """
        query_bbox: [B, Q, code_size]; query_feat: [B, Q, C].
        Returns stacked per-iteration cls scores and bbox predictions.
        """
        cls_scores, bbox_preds = [], []

        # Per-frame timestamp deltas relative to the current frame, averaged
        # over the 6 cameras; cached on img_metas[0] for downstream modules.
        timestamps = np.array([m['img_timestamp'] for m in img_metas], dtype=np.float64)
        timestamps = np.reshape(timestamps, [query_bbox.shape[0], -1, 6])
        time_diff = timestamps[:, :1, :] - timestamps
        time_diff = np.mean(time_diff, axis=-1).astype(np.float32)  # [B, F]
        time_diff = torch.from_numpy(time_diff).to(query_bbox.device)  # [B, F]
        img_metas[0]['time_diff'] = time_diff

        # Projection matrices, likewise cached on img_metas[0].
        lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)
        lidar2img = torch.from_numpy(lidar2img).to(query_bbox.device)  # [B, N, 4, 4]
        img_metas[0]['lidar2img'] = lidar2img

        # Regroup features for grouped multi-view sampling.
        # NOTE: modifies mlvl_feats in place; hard-codes 6 cameras, 4 groups.
        for lvl, feat in enumerate(mlvl_feats):
            B, TN, GC, H, W = feat.shape  # [B, TN, GC, H, W]
            N, T, G, C = 6, TN // 6, 4, GC // 4
            feat = feat.reshape(B, T, N, G, C, H, W)
            feat = feat.permute(0, 1, 3, 2, 5, 6, 4)  # [B, T, G, N, H, W, C]
            feat = feat.reshape(B*T*G, N, H, W, C)  # [BTG, N, H, W, C]
            mlvl_feats[lvl] = feat.contiguous()

        for i in range(self.num_layers):
            DUMP.stage_count = i

            query_feat, cls_score, bbox_pred = self.decoder_layer(
                query_bbox, query_feat, mlvl_feats, attn_mask, img_metas
            )
            # Detach the refined boxes: each iteration is supervised
            # independently; gradients do not flow through box refinement.
            query_bbox = bbox_pred.clone().detach()
            cls_scores.append(cls_score)
            bbox_preds.append(bbox_pred)

        cls_scores = torch.stack(cls_scores)
        bbox_preds = torch.stack(bbox_preds)

        return cls_scores, bbox_preds
class SparseBEVTransformerDecoderLayer(BaseModule):
    """One SparseBEV decoding iteration: scale-adaptive self-attention,
    adaptive 4D feature sampling, adaptive mixing, FFN, then cls/reg heads
    that refine the input query boxes.
    """

    def __init__(self, embed_dims, num_frames=8, num_points=4, num_levels=4, num_classes=10, code_size=10, num_cls_fcs=2, num_reg_fcs=2, pc_range=[], init_cfg=None):
        super(SparseBEVTransformerDecoderLayer, self).__init__(init_cfg)

        self.embed_dims = embed_dims
        self.num_classes = num_classes
        self.code_size = code_size
        self.pc_range = pc_range

        # Embeds the normalized (x, y, z) query center into the feature dim.
        self.position_encoder = nn.Sequential(
            nn.Linear(3, self.embed_dims),
            nn.LayerNorm(self.embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims, self.embed_dims),
            nn.LayerNorm(self.embed_dims),
            nn.ReLU(inplace=True),
        )

        self.self_attn = SparseBEVSelfAttention(embed_dims, num_heads=8, dropout=0.1, pc_range=pc_range)
        self.sampling = SparseBEVSampling(embed_dims, num_frames=num_frames, num_groups=4, num_points=num_points, num_levels=num_levels, pc_range=pc_range)
        self.mixing = AdaptiveMixing(in_dim=embed_dims, in_points=num_points * num_frames, n_groups=4, out_points=128)
        self.ffn = FFN(embed_dims, feedforward_channels=512, ffn_drop=0.1)

        self.norm1 = nn.LayerNorm(embed_dims)
        self.norm2 = nn.LayerNorm(embed_dims)
        self.norm3 = nn.LayerNorm(embed_dims)

        cls_branch = []
        for _ in range(num_cls_fcs):
            cls_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
            cls_branch.append(nn.LayerNorm(self.embed_dims))
            cls_branch.append(nn.ReLU(inplace=True))
        cls_branch.append(nn.Linear(self.embed_dims, self.num_classes))
        self.cls_branch = nn.Sequential(*cls_branch)

        reg_branch = []
        for _ in range(num_reg_fcs):
            reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
            reg_branch.append(nn.ReLU(inplace=True))
        reg_branch.append(nn.Linear(self.embed_dims, self.code_size))
        self.reg_branch = nn.Sequential(*reg_branch)

    @torch.no_grad()
    def init_weights(self):
        self.self_attn.init_weights()
        self.sampling.init_weights()
        self.mixing.init_weights()

        # Focal-loss style prior for the classification bias.
        bias_init = bias_init_with_prob(0.01)
        nn.init.constant_(self.cls_branch[-1].bias, bias_init)

    def refine_bbox(self, bbox_proposal, bbox_delta):
        # Refine the center residually in inverse-sigmoid space; all other
        # box parameters are taken directly from the new prediction.
        xyz = inverse_sigmoid(bbox_proposal[..., 0:3])
        xyz_delta = bbox_delta[..., 0:3]
        xyz_new = torch.sigmoid(xyz_delta + xyz)

        return torch.cat([xyz_new, bbox_delta[..., 3:]], dim=-1)

    def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
        """
        query_bbox: [B, Q, 10] [cx, cy, cz, w, h, d, rot.sin, rot.cos, vx, vy]
        """
        query_pos = self.position_encoder(query_bbox[..., :3])
        query_feat = query_feat + query_pos

        query_feat = self.norm1(self.self_attn(query_bbox, query_feat, attn_mask))
        sampled_feat = self.sampling(query_bbox, query_feat, mlvl_feats, img_metas)
        query_feat = self.norm2(self.mixing(sampled_feat, query_feat))
        query_feat = self.norm3(self.ffn(query_feat))

        cls_score = self.cls_branch(query_feat)  # [B, Q, num_classes]
        bbox_pred = self.reg_branch(query_feat)  # [B, Q, code_size]
        bbox_pred = self.refine_bbox(query_bbox, bbox_pred)

        # Convert the predicted displacement to velocity by dividing by the
        # time delta to the previous frame (tiny deltas treated as 1.0).
        time_diff = img_metas[0]['time_diff']  # [B, F]
        if time_diff.shape[1] > 1:
            time_diff = time_diff.clone()
            time_diff[time_diff < 1e-5] = 1.0
            bbox_pred[..., 8:] = bbox_pred[..., 8:] / time_diff[:, 1:2, None]

        # Optional debug dumps of decoded boxes and scores per stage.
        if DUMP.enabled:
            query_bbox_dec = decode_bbox(query_bbox, self.pc_range)
            bbox_pred_dec = decode_bbox(bbox_pred, self.pc_range)
            cls_score_sig = torch.sigmoid(cls_score)
            torch.save(query_bbox_dec, '{}/query_bbox_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
            torch.save(bbox_pred_dec, '{}/bbox_pred_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
            torch.save(cls_score_sig, '{}/cls_score_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))

        return query_feat, cls_score, bbox_pred
class SparseBEVSelfAttention(BaseModule):
    """Multi-head self-attention with an optional scale-adaptive spatial
    prior: attention logits are biased by the (negated) pairwise BEV-center
    distance, scaled per head and per query by a learned temperature."""

    def __init__(self, embed_dims=256, num_heads=8, dropout=0.1, pc_range=[], scale_adaptive=True):
        super().__init__()
        self.pc_range = pc_range
        self.attention = MultiheadAttention(embed_dims, num_heads, dropout, batch_first=True)
        # gen_tau produces one temperature per head; None disables the prior.
        self.gen_tau = nn.Linear(embed_dims, num_heads) if scale_adaptive else None

    @torch.no_grad()
    def init_weights(self):
        if self.gen_tau is not None:
            nn.init.zeros_(self.gen_tau.weight)
            nn.init.uniform_(self.gen_tau.bias, 0.0, 2.0)

    def inner_forward(self, query_bbox, query_feat, pre_attn_mask=None):
        """
        query_bbox: [B, Q, 10]
        query_feat: [B, Q, C]
        """
        if self.gen_tau is None:
            return self.attention(query_feat, attn_mask=None)

        dist = self.calc_bbox_dists(query_bbox)  # [B, Q, Q] (negative distances)
        tau = self.gen_tau(query_feat)           # [B, Q, H]

        if DUMP.enabled:
            torch.save(tau, '{}/sasa_tau_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))

        tau = tau.permute(0, 2, 1)                        # [B, H, Q]
        attn_mask = dist[:, None, :, :] * tau[..., None]  # [B, H, Q, Q]

        if pre_attn_mask is not None:  # for DN-style query masking
            attn_mask[:, :, pre_attn_mask] = float('-inf')

        attn_mask = attn_mask.flatten(0, 1)  # [B*H, Q, Q]
        return self.attention(query_feat, attn_mask=attn_mask)

    def forward(self, query_bbox, query_feat, pre_attn_mask=None):
        # Gradient checkpointing during training to save memory.
        if self.training and query_feat.requires_grad:
            return cp(self.inner_forward, query_bbox, query_feat, pre_attn_mask, use_reentrant=False)
        else:
            return self.inner_forward(query_bbox, query_feat, pre_attn_mask)

    @torch.no_grad()
    def calc_bbox_dists(self, bboxes):
        """Negated pairwise BEV-center distance, [B, Q, Q]."""
        centers = decode_bbox(bboxes, self.pc_range)[..., :2]  # [B, Q, 2]
        per_batch = [
            torch.norm(c[:, None, :] - c[None, :, :], dim=-1)
            for c in centers
        ]
        return -torch.stack(per_batch, dim=0)
class SparseBEVSampling(BaseModule):
    """Generates box-conditioned 3D sampling points — warped across frames
    by the predicted velocity — and gathers multi-scale image features.
    """

    def __init__(self, embed_dims=256, num_frames=4, num_groups=4, num_points=8, num_levels=4, pc_range=[], init_cfg=None):
        super().__init__(init_cfg)

        self.num_frames = num_frames
        self.num_points = num_points
        self.num_groups = num_groups
        self.num_levels = num_levels
        self.pc_range = pc_range

        # Per-query 3D offsets and per-level mixing weights.
        self.sampling_offset = nn.Linear(embed_dims, num_groups * num_points * 3)
        self.scale_weights = nn.Linear(embed_dims, num_groups * num_points * num_levels)

    def init_weights(self):
        # Zero the weight so initial offsets come only from the uniform bias.
        bias = self.sampling_offset.bias.data.view(self.num_groups * self.num_points, 3)
        nn.init.zeros_(self.sampling_offset.weight)
        nn.init.uniform_(bias[:, 0:3], -0.5, 0.5)

    def inner_forward(self, query_bbox, query_feat, mlvl_feats, img_metas):
        '''
        query_bbox: [B, Q, 10]
        query_feat: [B, Q, C]
        '''
        B, Q = query_bbox.shape[:2]
        image_h, image_w, _ = img_metas[0]['img_shape'][0]

        # sampling offset of all frames
        sampling_offset = self.sampling_offset(query_feat)
        sampling_offset = sampling_offset.view(B, Q, self.num_groups * self.num_points, 3)
        sampling_points = make_sample_points_from_bbox(query_bbox, sampling_offset, self.pc_range)  # [B, Q, GP, 3]
        sampling_points = sampling_points.reshape(B, Q, 1, self.num_groups, self.num_points, 3)
        # Broadcast the same points to every frame before the velocity warp.
        sampling_points = sampling_points.expand(B, Q, self.num_frames, self.num_groups, self.num_points, 3)

        # warp sample points based on velocity (ego-motion-relative xy shift)
        if query_bbox.shape[-1] > 8:
            time_diff = img_metas[0]['time_diff']  # [B, F]
            time_diff = time_diff[:, None, :, None]  # [B, 1, F, 1]
            vel = query_bbox[..., 8:].detach()  # [B, Q, 2]
            vel = vel[:, :, None, :]  # [B, Q, 1, 2]
            dist = vel * time_diff  # [B, Q, F, 2]
            dist = dist[:, :, :, None, None, :]  # [B, Q, F, 1, 1, 2]
            sampling_points = torch.cat([
                sampling_points[..., 0:2] - dist,
                sampling_points[..., 2:3]
            ], dim=-1)

        # scale weights, softmax-normalized over the feature levels
        scale_weights = self.scale_weights(query_feat).view(B, Q, self.num_groups, 1, self.num_points, self.num_levels)
        scale_weights = torch.softmax(scale_weights, dim=-1)
        scale_weights = scale_weights.expand(B, Q, self.num_groups, self.num_frames, self.num_points, self.num_levels)

        # sampling
        sampled_feats = sampling_4d(
            sampling_points,
            mlvl_feats,
            scale_weights,
            img_metas[0]['lidar2img'],
            image_h, image_w
        )  # [B, Q, G, FP, C]

        return sampled_feats

    def forward(self, query_bbox, query_feat, mlvl_feats, img_metas):
        # Gradient checkpointing during training to save memory.
        if self.training and query_feat.requires_grad:
            return cp(self.inner_forward, query_bbox, query_feat, mlvl_feats, img_metas, use_reentrant=False)
        else:
            return self.inner_forward(query_bbox, query_feat, mlvl_feats, img_metas)
class AdaptiveMixing(nn.Module):
    """Adaptive mixing (AdaMixer-style): per-query dynamically generated
    channel mixing followed by dynamic point mixing, with a residual
    projection back to the query dimension."""

    def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):
        super(AdaptiveMixing, self).__init__()
        out_dim = in_dim if out_dim is None else out_dim
        out_points = in_points if out_points is None else out_points
        query_dim = in_dim if query_dim is None else query_dim

        self.query_dim = query_dim
        self.in_dim = in_dim
        self.in_points = in_points
        self.n_groups = n_groups
        self.out_dim = out_dim
        self.out_points = out_points
        # Per-group effective channel widths.
        self.eff_in_dim = in_dim // n_groups
        self.eff_out_dim = out_dim // n_groups
        # Sizes of the dynamic channel (M) and point (S) mixing matrices.
        self.m_parameters = self.eff_in_dim * self.eff_out_dim
        self.s_parameters = self.in_points * self.out_points
        self.total_parameters = self.m_parameters + self.s_parameters

        self.parameter_generator = nn.Linear(self.query_dim, self.n_groups * self.total_parameters)
        self.out_proj = nn.Linear(self.eff_out_dim * self.out_points * self.n_groups, self.query_dim)
        self.act = nn.ReLU(inplace=True)

    @torch.no_grad()
    def init_weights(self):
        # Zero init: mixing parameters start from the (random) bias only.
        nn.init.zeros_(self.parameter_generator.weight)

    def inner_forward(self, x, query):
        """x: [B, Q, G, P, C]; query: [B, Q, query_dim] -> [B, Q, query_dim]."""
        bs, nq, groups, pts, ch = x.shape
        assert groups == self.n_groups
        assert pts == self.in_points
        assert ch == self.eff_in_dim

        # Generate per-query mixing parameters and split into M and S.
        dyn = self.parameter_generator(query).reshape(bs * nq, groups, -1)
        channel_mix, point_mix = dyn.split([self.m_parameters, self.s_parameters], 2)
        channel_mix = channel_mix.reshape(bs * nq, groups, self.eff_in_dim, self.eff_out_dim)
        point_mix = point_mix.reshape(bs * nq, groups, self.out_points, self.in_points)

        feats = x.reshape(bs * nq, groups, pts, ch)

        # Adaptive channel mixing.
        feats = torch.matmul(feats, channel_mix)
        feats = self.act(F.layer_norm(feats, feats.shape[-2:]))

        # Adaptive point mixing (left-multiplication mixes the point axis).
        feats = torch.matmul(point_mix, feats)
        feats = self.act(F.layer_norm(feats, feats.shape[-2:]))

        # Project back to the query dimension and add the residual.
        feats = self.out_proj(feats.reshape(bs, nq, -1))
        return query + feats

    def forward(self, x, query):
        # Gradient checkpointing during training to save memory.
        if self.training and x.requires_grad:
            return cp(self.inner_forward, x, query, use_reentrant=False)
        else:
            return self.inner_forward(x, query)
class AdaptiveMixingPointOnly(nn.Module):
    """Variant of ``AdaptiveMixing`` that performs only the dynamic point
    (spatial) mixing step — no channel mixing matrix is generated."""

    def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):
        super(AdaptiveMixingPointOnly, self).__init__()
        out_dim = in_dim if out_dim is None else out_dim
        out_points = in_points if out_points is None else out_points
        query_dim = in_dim if query_dim is None else query_dim

        self.query_dim = query_dim
        self.in_dim = in_dim
        self.in_points = in_points
        self.n_groups = n_groups
        self.out_dim = out_dim
        self.out_points = out_points
        self.eff_in_dim = in_dim // n_groups
        self.eff_out_dim = out_dim // n_groups

        # Only the point-mixing matrix S is generated dynamically.
        self.s_parameters = self.in_points * self.out_points
        self.total_parameters = self.s_parameters

        self.parameter_generator = nn.Linear(self.query_dim, self.n_groups * self.total_parameters)
        self.out_proj = nn.Linear(self.eff_out_dim * self.out_points * self.n_groups, self.query_dim)
        self.act = nn.ReLU(inplace=True)

    @torch.no_grad()
    def init_weights(self):
        # Zero init: mixing parameters start from the (random) bias only.
        nn.init.zeros_(self.parameter_generator.weight)

    def inner_forward(self, x, query):
        """x: [B, Q, G, P, C]; query: [B, Q, query_dim] -> [B, Q, query_dim]."""
        bs, nq, groups, pts, ch = x.shape
        assert groups == self.n_groups
        assert pts == self.in_points
        assert ch == self.eff_in_dim

        # Generate the per-query point-mixing matrix S.
        point_mix = self.parameter_generator(query)
        point_mix = point_mix.reshape(bs * nq, groups, self.out_points, self.in_points)
        feats = x.reshape(bs * nq, groups, pts, ch)

        # Adaptive spatial mixing (left-multiplication mixes the point axis).
        feats = torch.matmul(point_mix, feats)
        feats = self.act(F.layer_norm(feats, feats.shape[-2:]))

        # Project back to the query dimension and add the residual.
        feats = self.out_proj(feats.reshape(bs, nq, -1))
        return query + feats

    def forward(self, x, query):
        # Gradient checkpointing during training to save memory.
        if self.training and x.requires_grad:
            return cp(self.inner_forward, x, query, use_reentrant=False)
        else:
            return self.inner_forward(x, query)
class DeformAggregation(nn.Module):
    """Attention-weighted aggregation over sampled points: a softmax over
    the point axis per group, followed by a residual output projection."""

    def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):
        super(DeformAggregation, self).__init__()
        out_dim = in_dim if out_dim is None else out_dim
        out_points = in_points if out_points is None else out_points
        query_dim = in_dim if query_dim is None else query_dim

        self.query_dim = query_dim
        self.in_dim = in_dim
        self.in_points = in_points
        self.n_groups = n_groups
        self.out_dim = out_dim
        self.out_points = out_points
        self.eff_in_dim = in_dim // n_groups
        self.eff_out_dim = out_dim // n_groups

        self.attn_weights = nn.Linear(query_dim, n_groups * in_points)
        self.out_proj = nn.Linear(self.eff_in_dim * n_groups, self.query_dim)
        self.act = nn.ReLU(inplace=True)

    @torch.no_grad()
    def init_weights(self):
        # No custom init required; default Linear init is used as-is.
        pass

    def inner_forward(self, x, query):
        """x: [B, Q, G, P, C]; query: [B, Q, query_dim] -> [B, Q, query_dim]."""
        bs, nq, groups, pts, ch = x.shape
        assert groups == self.n_groups
        assert pts == self.in_points
        assert ch == self.eff_in_dim

        # Per-point attention logits, normalized over the point axis.
        logits = self.attn_weights(query).reshape(bs, nq, groups, pts, 1)
        weights = logits.softmax(dim=-2)

        # Weighted aggregation over points, then flatten the groups.
        agg = (x * weights).sum(dim=-2)  # [B, Q, G, C]
        agg = self.out_proj(agg.reshape(bs, nq, -1))
        return query + agg

    def forward(self, x, query):
        # Gradient checkpointing during training to save memory.
        if self.training and x.requires_grad:
            return cp(self.inner_forward, x, query, use_reentrant=False)
        else:
            return self.inner_forward(x, query)
================================================
FILE: models/sparseocc.py
================================================
import torch
import queue
import numpy as np
from mmcv.runner import get_dist_info
from mmcv.runner.fp16_utils import cast_tensor_type
from mmcv.runner import force_fp32, auto_fp16
from mmdet.models import DETECTORS
from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
from .utils import pad_multiple, GpuPhotoMetricDistortion
@DETECTORS.register_module()
class SparseOcc(MVXTwoStageDetector):
def __init__(self,
             pts_voxel_layer=None,
             pts_voxel_encoder=None,
             pts_middle_encoder=None,
             pts_fusion_layer=None,
             img_backbone=None,
             pts_backbone=None,
             img_neck=None,
             pts_neck=None,
             pts_bbox_head=None,
             img_roi_head=None,
             img_rpn_head=None,
             train_cfg=None,
             test_cfg=None,
             pretrained=None,
             data_aug=None,
             use_mask_camera=False,
             **kwargs):
    """Camera-based occupancy detector built on MVXTwoStageDetector.

    Extra args beyond the base class:
        data_aug (dict | None): GPU-side augmentation / normalization config
            consumed by ``extract_feat``.
        use_mask_camera (bool): stored flag for camera-visibility masking.
    """
    super(SparseOcc, self).__init__(
        pts_voxel_layer, pts_voxel_encoder, pts_middle_encoder,
        pts_fusion_layer, img_backbone, pts_backbone, img_neck, pts_neck,
        pts_bbox_head, img_roi_head, img_rpn_head,
        train_cfg, test_cfg, pretrained)

    self.use_mask_camera = use_mask_camera
    self.fp16_enabled = False
    self.data_aug = data_aug
    self.color_aug = GpuPhotoMetricDistortion()
    # Feature cache for online inference: filename -> features,
    # evicted FIFO via self.queue (see simple_test_online).
    self.memory = {}
    self.queue = queue.Queue()
@auto_fp16(apply_to=('img'), out_fp32=True)
def extract_img_feat(self, img):
    """Run the image backbone (and neck, if configured) on stacked images."""
    feats = self.img_backbone(img)
    if isinstance(feats, dict):
        feats = list(feats.values())
    return self.img_neck(feats) if self.with_img_neck else feats
@auto_fp16(apply_to=('img'))
def extract_feat(self, img, img_metas=None):
    """Extract per-frame image features.

    Flattens frames into the camera dimension, optionally applies GPU-side
    color augmentation / normalization / padding, updates img_metas shapes
    in place, and returns features reshaped back to [B, N, C, H, W].
    """
    if len(img.shape) == 6:
        img = img.flatten(1, 2)  # [B, TN, C, H, W]

    B, N, C, H, W = img.size()
    img = img.view(B * N, C, H, W)
    img = img.float()

    if self.data_aug is not None:
        # GPU-side photometric augmentation (training only).
        if 'img_color_aug' in self.data_aug and self.data_aug['img_color_aug'] and self.training:
            img = self.color_aug(img)

        # Normalization done on the GPU instead of in the data pipeline.
        if 'img_norm_cfg' in self.data_aug:
            img_norm_cfg = self.data_aug['img_norm_cfg']

            norm_mean = torch.tensor(img_norm_cfg['mean'], device=img.device)
            norm_std = torch.tensor(img_norm_cfg['std'], device=img.device)

            if img_norm_cfg['to_rgb']:
                img = img[:, [2, 1, 0], :, :]  # BGR to RGB

            img = img - norm_mean.reshape(1, 3, 1, 1)
            img = img / norm_std.reshape(1, 3, 1, 1)

        # Record (H, W, C) per camera; mutates the caller's img_metas.
        for b in range(B):
            img_shape = (img.shape[2], img.shape[3], img.shape[1])
            img_metas[b]['img_shape'] = [img_shape for _ in range(N)]
            img_metas[b]['ori_shape'] = [img_shape for _ in range(N)]

        # Pad so H/W are divisible by size_divisor (backbone stride).
        if 'img_pad_cfg' in self.data_aug:
            img_pad_cfg = self.data_aug['img_pad_cfg']
            img = pad_multiple(img, img_metas, size_divisor=img_pad_cfg['size_divisor'])
            H, W = img.shape[-2:]

    input_shape = img.shape[-2:]
    # update real input shape of each single img
    for img_meta in img_metas:
        img_meta.update(input_shape=input_shape)

    img_feats = self.extract_img_feat(img)

    # Restore the batch dimension: [B*N, C', H', W'] -> [B, N, C', H', W'].
    img_feats_reshaped = []
    for img_feat in img_feats:
        BN, C, H, W = img_feat.size()
        img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))

    return img_feats_reshaped
def forward_pts_train(self, mlvl_feats, voxel_semantics, voxel_instances, instance_class_ids, mask_camera, img_metas):
    """
    voxel_semantics: [bs, 200, 200, 16], value in range [0, num_cls - 1]
    voxel_instances: [bs, 200, 200, 16], value in range [0, num_obj - 1]
    instance_class_ids: [[bs0_num_obj], [bs1_num_obj], ...], value in range [0, num_cls - 1]
    """
    # NOTE(review): mask_camera is accepted but not forwarded to the head here.
    outs = self.pts_bbox_head(mlvl_feats, img_metas)
    return self.pts_bbox_head.loss(voxel_semantics, voxel_instances, instance_class_ids, outs)
def forward(self, return_loss=True, **kwargs):
    """Dispatch to the train or test branch (mmdet calling convention)."""
    if not return_loss:
        return self.forward_test(**kwargs)
    return self.forward_train(**kwargs)
@force_fp32(apply_to=('img'))
def forward_train(self, img_metas=None, img=None, voxel_semantics=None, voxel_instances=None, instance_class_ids=None, mask_camera=None, **kwargs):
    """Extract image features, then compute the occupancy losses."""
    feats = self.extract_feat(img=img, img_metas=img_metas)
    return self.forward_pts_train(feats, voxel_semantics, voxel_instances,
                                  instance_class_ids, mask_camera, img_metas)
def forward_test(self, img_metas, img=None, **kwargs):
    """Run inference and split batched outputs into per-sample dicts.

    Returns:
        list[dict]: one dict of numpy arrays per sample, with 'sem_pred',
        'occ_loc' and, when the panoptic head produced them, 'pano_inst'
        and 'pano_sem'.
    """
    output = self.simple_test(img_metas, img)

    sem_pred = output['sem_pred'].cpu().numpy().astype(np.uint8)
    occ_loc = output['occ_loc'].cpu().numpy().astype(np.uint8)
    batch_size = sem_pred.shape[0]

    # BUGFIX: the original condition `'pano_inst' and 'pano_sem' in output`
    # only tested 'pano_sem' — the string literal 'pano_inst' is a truthy
    # constant due to operator precedence. Check both keys explicitly.
    if 'pano_inst' in output and 'pano_sem' in output:
        # important: uint8 is not enough for pano_pred
        pano_inst = output['pano_inst'].cpu().numpy().astype(np.int16)
        pano_sem = output['pano_sem'].cpu().numpy().astype(np.uint8)
        return [{
            'sem_pred': sem_pred[b:b+1],
            'pano_inst': pano_inst[b:b+1],
            'pano_sem': pano_sem[b:b+1],
            'occ_loc': occ_loc[b:b+1]
        } for b in range(batch_size)]

    return [{
        'sem_pred': sem_pred[b:b+1],
        'occ_loc': occ_loc[b:b+1]
    } for b in range(batch_size)]
def simple_test_pts(self, x, img_metas, rescale=False):
    """Run the occupancy head and merge its raw outputs into final predictions."""
    head_outs = self.pts_bbox_head(x, img_metas)
    return self.pts_bbox_head.merge_occ_pred(head_outs)
def simple_test(self, img_metas, img=None, rescale=False):
    """Pick the test path: feature-cached online mode on a single GPU,
    plain offline mode when running distributed."""
    _, world_size = get_dist_info()
    if world_size > 1:  # offline
        return self.simple_test_offline(img_metas, img, rescale)
    return self.simple_test_online(img_metas, img, rescale)  # online
def simple_test_offline(self, img_metas, img=None, rescale=False):
    """Offline test path: extract features for all frames at once, then run the head."""
    feats = self.extract_feat(img=img, img_metas=img_metas)
    return self.simple_test_pts(feats, img_metas, rescale=rescale)
def simple_test_online(self, img_metas, img=None, rescale=False):
    """Single-GPU streaming inference with a per-frame feature cache.

    The input stacks `num_frames` groups of 6 cameras. Features for each frame
    are computed once and cached in ``self.memory`` (keyed by the first
    camera's filename, FIFO-evicted via ``self.queue``) so that consecutive
    samples reuse their overlapping temporal history.
    """
    self.fp16_enabled = False
    assert len(img_metas) == 1  # batch_size = 1

    B, N, C, H, W = img.shape
    # split the camera axis into (frame, 6-camera) groups
    img = img.reshape(B, N//6, 6, C, H, W)

    img_filenames = img_metas[0]['filename']
    num_frames = len(img_filenames) // 6
    # assert num_frames == img.shape[1]

    img_shape = (H, W, C)
    img_metas[0]['img_shape'] = [img_shape for _ in range(len(img_filenames))]
    img_metas[0]['ori_shape'] = [img_shape for _ in range(len(img_filenames))]
    img_metas[0]['pad_shape'] = [img_shape for _ in range(len(img_filenames))]

    img_feats_list, img_metas_list = [], []

    # extract feature frame by frame
    for i in range(num_frames):
        img_indices = list(np.arange(i * 6, (i + 1) * 6))

        # slice this frame's 6 entries out of every per-camera meta list
        img_metas_curr = [{}]
        for k in img_metas[0].keys():
            if isinstance(img_metas[0][k], list):
                img_metas_curr[0][k] = [img_metas[0][k][i] for i in img_indices]

        if img_filenames[img_indices[0]] in self.memory:
            # found in memory
            img_feats_curr = self.memory[img_filenames[img_indices[0]]]
        else:
            # extract feature and put into memory
            img_feats_curr = self.extract_feat(img[:, i], img_metas_curr)
            self.memory[img_filenames[img_indices[0]]] = img_feats_curr
            self.queue.put(img_filenames[img_indices[0]])
            # FIFO-evict oldest cached frames to bound memory usage
            while self.queue.qsize() > 16:  # avoid OOM
                pop_key = self.queue.get()
                self.memory.pop(pop_key)

        img_feats_list.append(img_feats_curr)
        img_metas_list.append(img_metas_curr)

    # reorganize: concatenate per-frame features back into one multi-frame batch
    feat_levels = len(img_feats_list[0])
    img_feats_reorganized = []
    for j in range(feat_levels):
        feat_l = torch.cat([img_feats_list[i][j] for i in range(len(img_feats_list))], dim=0)
        feat_l = feat_l.flatten(0, 1)[None, ...]
        img_feats_reorganized.append(feat_l)

    # merge the per-frame meta lists back into a single meta dict
    img_metas_reorganized = img_metas_list[0]
    for i in range(1, len(img_metas_list)):
        for k, v in img_metas_list[i][0].items():
            if isinstance(v, list):
                img_metas_reorganized[0][k].extend(v)

    img_feats = img_feats_reorganized
    img_metas = img_metas_reorganized
    img_feats = cast_tensor_type(img_feats, torch.half, torch.float32)

    # run detector
    return self.simple_test_pts(img_feats, img_metas, rescale=rescale)
================================================
FILE: models/sparseocc_head.py
================================================
import numpy as np
import torch
import torch.nn as nn
from mmdet.models import HEADS
from mmcv.runner import force_fp32, auto_fp16
from mmdet.models.builder import build_loss
from mmdet.models.utils import build_transformer
from .matcher import HungarianMatcher
from .loss_utils import CE_ssc_loss, lovasz_softmax, get_voxel_decoder_loss_input
# Per-class voxel counts on the nuScenes occupancy training set (18 classes);
# used below to build inverse-log class weights for the semantic CE loss.
NUSC_CLASS_FREQ = np.array([
    944004, 1897170, 152386, 2391677, 16957802, 724139, 189027, 2074468, 413451, 2384460,
    5916653, 175883646, 4275424, 51393615, 61411620, 105975596, 116424404, 1892500630
])
@HEADS.register_module()
class SparseOccHead(nn.Module):
    """Occupancy prediction head for SparseOcc.

    Wraps the SparseOcc transformer (sparse voxel decoder + MaskFormer-style
    decoder, built from the `transformer` config) and computes the training
    losses: per-scale semantic losses on the sparse voxel predictions plus
    Mask2Former-style mask/dice/class losses on the query mask predictions.
    At test time it fuses query outputs into per-voxel semantic (and,
    optionally, panoptic) labels.
    """

    def __init__(self,
                 transformer=None,
                 class_names=None,
                 embed_dims=None,
                 occ_size=None,
                 pc_range=None,
                 loss_cfgs=None,
                 panoptic=False,
                 **kwargs):
        # transformer: config dict for the SparseOcc transformer
        # class_names: list of semantic class names (defines num_classes)
        # occ_size / pc_range: voxel grid resolution and metric range
        # loss_cfgs: dict mapping loss name -> mmdet loss config
        # panoptic: also produce instance+semantic outputs at test time
        super(SparseOccHead, self).__init__()
        self.num_classes = len(class_names)
        self.class_names = class_names
        self.pc_range = pc_range
        self.occ_size = occ_size
        self.embed_dims = embed_dims
        self.score_threshold = 0.3    # min class score for a query to count at inference
        self.overlap_threshold = 0.8  # min surviving-area ratio to keep a panoptic segment
        self.panoptic = panoptic
        self.transformer = build_transformer(transformer)
        self.criterions = {k: build_loss(loss_cfg) for k, loss_cfg in loss_cfgs.items()}
        self.matcher = HungarianMatcher(cost_class=2.0, cost_mask=5.0, cost_dice=5.0)
        # inverse-log class-frequency weighting for the semantic CE loss
        self.class_weights = torch.from_numpy(1 / np.log(NUSC_CLASS_FREQ + 0.001))

    def init_weights(self):
        self.transformer.init_weights()

    @auto_fp16(apply_to=('mlvl_feats'))
    def forward(self, mlvl_feats, img_metas):
        """Run the transformer and pack its three output streams into a dict."""
        occ_preds, mask_preds, class_preds = self.transformer(mlvl_feats, img_metas=img_metas)
        return {
            'occ_preds': occ_preds,       # per-scale sparse voxel predictions
            'mask_preds': mask_preds,     # per-layer query mask logits
            'class_preds': class_preds    # per-layer query class logits
        }

    @force_fp32(apply_to=('preds_dicts'))
    def loss(self, voxel_semantics, voxel_instances, instance_class_ids, preds_dicts, mask_camera=None):
        """Compute all training losses; see `loss_single` for details."""
        return self.loss_single(voxel_semantics, voxel_instances, instance_class_ids, preds_dicts, mask_camera)

    def loss_single(self, voxel_semantics, voxel_instances, instance_class_ids, preds_dicts, mask_camera=None):
        """Compute semantic losses for every voxel-decoder scale and
        Mask2Former losses for every mask-decoder layer.

        voxel_semantics: [bs, X, Y, Z] GT semantics; 255 marks empty/ignored voxels
        voxel_instances: [bs, X, Y, Z] GT instance ids; 255 marks empty space
        instance_class_ids: per-sample list of class ids, one per GT instance
        preds_dicts: output dict of `forward`
        """
        loss_dict = {}
        B = voxel_instances.shape[0]

        if mask_camera is not None:
            assert mask_camera.shape == voxel_semantics.shape
            assert mask_camera.dtype == torch.bool

        # ---- semantic losses on every voxel-decoder scale ----
        for i, (occ_loc_i, _, seg_pred_i, _, scale) in enumerate(preds_dicts['occ_preds']):
            loss_dict_i = {}
            # computed per sample, then averaged over the batch (the / B below)
            for b in range(B):
                loss_dict_i_b = {}
                seg_pred_i_sparse, voxel_semantics_sparse, sparse_mask = get_voxel_decoder_loss_input(
                    voxel_semantics[b:b + 1],
                    occ_loc_i[b:b + 1],
                    seg_pred_i[b:b + 1] if seg_pred_i is not None else None,
                    scale,
                    self.num_classes
                )
                loss_dict_i_b['loss_sem_lovasz'] = lovasz_softmax(torch.softmax(seg_pred_i_sparse, dim=1), voxel_semantics_sparse)
                # 255 marks ignored voxels; drop them for the remaining losses
                valid_mask = (voxel_semantics_sparse < 255)
                seg_pred_i_sparse = seg_pred_i_sparse[valid_mask].transpose(0, 1).unsqueeze(0)  # [K, CLS] -> [B, CLS, K]
                voxel_semantics_sparse = voxel_semantics_sparse[valid_mask].unsqueeze(0)  # [K] -> [B, K]
                if 'loss_geo_scal' in self.criterions.keys():
                    loss_dict_i_b['loss_geo_scal'] = self.criterions['loss_geo_scal'](seg_pred_i_sparse, voxel_semantics_sparse)
                if 'loss_sem_scal' in self.criterions.keys():
                    loss_dict_i_b['loss_sem_scal'] = self.criterions['loss_sem_scal'](seg_pred_i_sparse, voxel_semantics_sparse)
                loss_dict_i_b['loss_sem_ce'] = CE_ssc_loss(seg_pred_i_sparse, voxel_semantics_sparse, self.class_weights.type_as(seg_pred_i_sparse))
                for loss_key in loss_dict_i_b.keys():
                    loss_dict_i[loss_key] = loss_dict_i.get(loss_key, 0) + loss_dict_i_b[loss_key] / B
            for k, v in loss_dict_i.items():
                loss_dict['%s_%d' % (k, i)] = v

        # ---- gather GT instance ids at the finest predicted voxel locations ----
        occ_loc = preds_dicts['occ_preds'][-1][0]
        batch_idx = torch.arange(B)[:, None, None].expand(B, occ_loc.shape[1], 1).to(occ_loc.device)
        occ_loc = occ_loc.reshape(-1, 3)
        voxel_instances = voxel_instances[batch_idx.reshape(-1), occ_loc[..., 0], occ_loc[..., 1], occ_loc[..., 2]]
        voxel_instances = voxel_instances.reshape(B, -1)  # [B, N]

        if mask_camera is not None:
            mask_camera = mask_camera[batch_idx.reshape(-1), occ_loc[..., 0], occ_loc[..., 1], occ_loc[..., 2]]
            mask_camera = mask_camera.reshape(B, -1)  # [B, N]

        # drop instances if it has no positive voxels
        for b in range(B):
            instance_count = instance_class_ids[b].shape[0]
            instance_voxel_counts = torch.bincount(voxel_instances[b].long())  # [255]
            # remap surviving instance ids to a compact 0..K-1 range
            id_map = torch.cumsum(instance_voxel_counts > 0, dim=0) - 1
            # NOTE(review): indexing position 255 assumes voxel_instances[b]
            # always contains the value 255 (empty space) so bincount has
            # length 256 — TODO confirm this invariant of the GT.
            id_map[255] = 255  # empty space still has an id of 255
            voxel_instances[b] = id_map[voxel_instances[b].long()]
            instance_class_ids[b] = instance_class_ids[b][instance_voxel_counts[:instance_count] > 0]

        # ---- Mask2Former losses for every mask-decoder layer ----
        for i, pred in enumerate(preds_dicts['mask_preds']):
            # Hungarian matching between queries and GT instances
            indices = self.matcher(pred, preds_dicts['class_preds'][i], voxel_instances, instance_class_ids, mask_camera)
            loss_mask, loss_dice, loss_class = self.criterions['loss_mask2former'](
                pred, preds_dicts['class_preds'][i], voxel_instances, instance_class_ids, indices, mask_camera)
            loss_dict['loss_mask_{:d}'.format(i)] = loss_mask
            loss_dict['loss_dice_mask_{:d}'.format(i)] = loss_dice
            loss_dict['loss_class_{:d}'.format(i)] = loss_class

        return loss_dict

    def merge_occ_pred(self, outs):
        """Turn the last layer's query outputs into per-voxel predictions."""
        mask_cls = outs['class_preds'][-1].sigmoid()
        mask_pred = outs['mask_preds'][-1].sigmoid()
        occ_indices = outs['occ_preds'][-1][0]

        sem_pred = self.merge_semseg(mask_cls, mask_pred)  # [B, C, N]
        outs['sem_pred'] = sem_pred
        outs['occ_loc'] = occ_indices

        if self.panoptic:
            pano_inst, pano_sem = self.merge_panoseg(mask_cls, mask_pred)  # [B, C, N]
            outs['pano_inst'] = pano_inst
            outs['pano_sem'] = pano_sem

        return outs

    # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/mask_former_model.py#L242
    def merge_semseg(self, mask_cls, mask_pred):
        """Fuse per-query class scores and masks into per-voxel semantic labels."""
        # zero out queries whose best class score is below threshold
        # (note: modifies mask_cls in place; callers pass a fresh .sigmoid() result)
        valid_mask = mask_cls.max(dim=-1).values > self.score_threshold
        mask_cls[~valid_mask] = 0.0
        semseg = torch.einsum("bqc,bqn->bcn", mask_cls, mask_pred)
        # NOTE(review): presumably the trailing channel is the empty class,
        # dropped here before argmax — confirm against the head config.
        if semseg.shape[1] == self.num_classes:
            semseg = semseg[:, :-1]
        cls_score, cls_id = torch.max(semseg, dim=1)
        # voxels with near-zero evidence fall back to the last class
        cls_id[cls_score < 0.01] = self.num_classes - 1
        return cls_id  # [B, N]

    def merge_panoseg(self, mask_cls, mask_pred):
        """Per-sample panoptic fusion; concatenates results along batch dim."""
        pano_inst, pano_sem = [], []
        for b in range(mask_cls.shape[0]):
            pano_inst_b, pano_sem_b = self.merge_panoseg_single(
                mask_cls[b:b + 1],
                mask_pred[b:b + 1]
            )
            pano_inst.append(pano_inst_b)
            pano_sem.append(pano_sem_b)
        pano_inst = torch.cat(pano_inst, dim=0)
        pano_sem = torch.cat(pano_sem, dim=0)
        return pano_inst, pano_sem

    # https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/maskformer_model.py#L286
    def merge_panoseg_single(self, mask_cls, mask_pred):
        """Panoptic fusion for a single sample (leading batch dim must be 1)."""
        assert mask_cls.shape[0] == 1, "bs != 1"
        scores, labels = mask_cls.max(-1)
        # filter out low score and background instances
        keep = labels.ne(self.num_classes - 1) & (scores > self.score_threshold)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]
        # score-weighted masks: each voxel is claimed by its strongest query below
        cur_prob_masks = cur_scores.view(-1, 1) * cur_masks

        N = cur_masks.shape[-1]
        instance_seg = torch.zeros((N), dtype=torch.int32, device=cur_masks.device)
        # default every voxel to the last class
        semantic_seg = torch.ones((N), dtype=torch.int32, device=cur_masks.device) * (self.num_classes - 1)

        current_segment_id = 0
        stuff_memory_list = {self.num_classes - 1: 0}  # stuff class id -> segment id

        # skip all process if no mask is detected
        if cur_masks.shape[0] != 0:
            # take argmax
            cur_mask_ids = cur_prob_masks.argmax(0)  # [N]
            for k in range(cur_classes.shape[0]):
                pred_class = cur_classes[k].item()
                # moving objects are treated as instances
                is_thing = self.class_names[pred_class] in [
                    'car', 'truck', 'construction_vehicle', 'bus',
                    'trailer', 'motorcycle', 'bicycle', 'pedestrian'
                ]
                mask_area = (cur_mask_ids == k).sum().item()
                original_area = (cur_masks[k] >= 0.5).sum().item()
                mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5)

                if mask_area > 0 and original_area > 0 and mask.sum().item() > 0:
                    # drop queries that lost most of their area to other queries
                    if mask_area / original_area < self.overlap_threshold:
                        continue
                    # merge stuff regions
                    if not is_thing:
                        if int(pred_class) in stuff_memory_list.keys():
                            instance_seg[mask] = stuff_memory_list[int(pred_class)]
                            continue
                        else:
                            stuff_memory_list[int(pred_class)] = current_segment_id + 1

                    current_segment_id += 1
                    instance_seg[mask] = current_segment_id
                    semantic_seg[mask] = pred_class

        instance_seg = instance_seg.unsqueeze(0)
        semantic_seg = semantic_seg.unsqueeze(0)
        return instance_seg, semantic_seg  # [B, N]
================================================
FILE: models/sparseocc_transformer.py
================================================
import copy
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import BaseModule
from mmdet.models.utils.builder import TRANSFORMER
from mmcv.cnn.bricks.transformer import FFN
from .sparsebev_transformer import AdaptiveMixing
from .utils import DUMP
from .checkpoint import checkpoint as cp
from .sparsebev_sampling import sampling_4d, make_sample_points_from_mask
from .sparse_voxel_decoder import SparseVoxelDecoder
@TRANSFORMER.register_module()
class SparseOccTransformer(BaseModule):
    """SparseOcc transformer: a coarse-to-fine sparse voxel decoder followed by
    a MaskFormer-style decoder that predicts query masks over the sparse voxels."""

    def __init__(self,
                 embed_dims=None,
                 num_layers=None,
                 num_queries=None,
                 num_frames=None,
                 num_points=None,
                 num_groups=None,
                 num_levels=None,
                 num_classes=None,
                 pc_range=None,
                 occ_size=None,
                 topk_training=None,
                 topk_testing=None):
        super().__init__()
        self.num_frames = num_frames
        # coarse-to-fine sparse occupancy decoder (3 internal layers)
        self.voxel_decoder = SparseVoxelDecoder(
            embed_dims=embed_dims,
            num_layers=3,
            num_frames=num_frames,
            num_points=num_points,
            num_groups=num_groups,
            num_levels=num_levels,
            num_classes=num_classes,
            pc_range=pc_range,
            semantic=True,
            topk_training=topk_training,
            topk_testing=topk_testing
        )
        # mask/class prediction decoder on top of the sparse voxels
        self.decoder = MaskFormerOccDecoder(
            embed_dims=embed_dims,
            num_layers=num_layers,
            num_frames=num_frames,
            num_queries=num_queries,
            num_points=num_points,
            num_groups=num_groups,
            num_levels=num_levels,
            num_classes=num_classes,
            pc_range=pc_range,
            occ_size=occ_size,
        )

    @torch.no_grad()
    def init_weights(self):
        self.voxel_decoder.init_weights()
        self.decoder.init_weights()

    def forward(self, mlvl_feats, img_metas):
        # Regroup each feature level from [B, T*N, G*C, H, W] into group-major
        # [B*T*G, N, H, W, C] (N=6 cameras, G=4 groups), the layout expected by
        # the sampling ops. Note: mlvl_feats is modified in place.
        for lvl, feat in enumerate(mlvl_feats):
            B, TN, GC, H, W = feat.shape  # [B, TN, GC, H, W]
            N, T, G, C = 6, TN // 6, 4, GC // 4
            feat = feat.reshape(B, T, N, G, C, H, W)
            feat = feat.permute(0, 1, 3, 2, 5, 6, 4)  # [B, T, G, N, H, W, C]
            feat = feat.reshape(B*T*G, N, H, W, C)  # [BTG, N, H, W, C]
            mlvl_feats[lvl] = feat.contiguous()

        lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)
        lidar2img = torch.from_numpy(lidar2img).to(feat.device)  # [B, N, 4, 4]
        ego2lidar = np.asarray([m['ego2lidar'] for m in img_metas]).astype(np.float32)
        ego2lidar = torch.from_numpy(ego2lidar).to(feat.device)  # [B, N, 4, 4]

        # compose the ego -> image projection; deepcopy so the caller's metas
        # are not polluted with tensors
        img_metas = copy.deepcopy(img_metas)
        # NOTE(review): the batched matrix is stored on img_metas[0] only —
        # downstream sampling reads img_metas[0]['lidar2img'] for the whole batch.
        img_metas[0]['lidar2img'] = torch.matmul(lidar2img, ego2lidar)

        occ_preds = self.voxel_decoder(mlvl_feats, img_metas=img_metas)
        mask_preds, class_preds = self.decoder(occ_preds, mlvl_feats, img_metas)

        return occ_preds, mask_preds, class_preds
class MaskFormerOccDecoder(BaseModule):
    """Stack of weight-shared MaskFormer refinement iterations over the sparse
    voxel features produced by the voxel decoder."""

    def __init__(self,
                 embed_dims=None,
                 num_layers=None,
                 num_frames=None,
                 num_queries=None,
                 num_points=None,
                 num_groups=None,
                 num_levels=None,
                 num_classes=None,
                 pc_range=None,
                 occ_size=None):
        super().__init__()
        self.num_layers = num_layers
        self.num_queries = num_queries
        self.num_frames = num_frames
        # a single layer instance is reused for every iteration (weight sharing)
        self.decoder_layer = MaskFormerOccDecoderLayer(
            embed_dims=embed_dims,
            mask_dim=embed_dims,
            num_frames=num_frames,
            num_points=num_points,
            num_groups=num_groups,
            num_levels=num_levels,
            num_classes=num_classes,
            pc_range=pc_range,
            occ_size=occ_size,
        )
        self.query_feat = nn.Embedding(num_queries, embed_dims)
        self.query_pos = nn.Embedding(num_queries, embed_dims)

    @torch.no_grad()
    def init_weights(self):
        self.decoder_layer.init_weights()

    def forward(self, occ_preds, mlvl_feats, img_metas):
        """Return mask/class predictions from the initial learned queries and
        after each of the `num_layers` refinement iterations."""
        _, _, _, mask_feat, _ = occ_preds[-1]
        batch_size = mask_feat.shape[0]

        queries = self.query_feat.weight[None].repeat(batch_size, 1, 1)
        pos_embed = self.query_pos.weight[None].repeat(batch_size, 1, 1)

        # predictions straight from the learned queries (pre-refinement)
        valid_map, mask_pred, class_pred = self.decoder_layer.pred_segmentation(queries, mask_feat)
        mask_preds, class_preds = [mask_pred], [class_pred]

        for layer_idx in range(self.num_layers):
            DUMP.stage_count = layer_idx
            queries, valid_map, mask_pred, class_pred = self.decoder_layer(
                queries, valid_map, mask_pred, occ_preds, mlvl_feats, pos_embed, img_metas
            )
            mask_preds.append(mask_pred)
            class_preds.append(class_pred)

        return mask_preds, class_preds
class MaskFormerOccDecoderLayer(BaseModule):
    """One MaskFormer refinement layer: query self-attention, mask-guided image
    feature sampling, adaptive mixing, FFN, then fresh mask/class predictions."""

    def __init__(self,
                 embed_dims=None,
                 mask_dim=None,
                 num_frames=None,
                 num_queries=None,  # accepted for config symmetry; unused here
                 num_points=None,
                 num_groups=None,
                 num_levels=None,
                 num_classes=None,
                 pc_range=None,
                 occ_size=None):
        super().__init__()
        self.pc_range = pc_range
        self.occ_size = occ_size
        self.self_attn = MaskFormerSelfAttention(embed_dims, num_heads=8)
        self.sampling = MaskFormerSampling(embed_dims, num_frames, num_groups, num_points, num_levels, pc_range=pc_range, occ_size=occ_size)
        self.mixing = AdaptiveMixing(in_dim=embed_dims, in_points=num_points * num_frames, n_groups=num_groups, out_points=128)
        self.ffn = FFN(embed_dims, feedforward_channels=512, ffn_drop=0.1)
        self.mask_proj = nn.Linear(embed_dims, mask_dim)
        # NOTE(review): the classifier outputs num_classes - 1 logits —
        # presumably the empty class is excluded; confirm against the head.
        self.classifier = nn.Linear(embed_dims, num_classes - 1)
        self.norm1 = nn.LayerNorm(embed_dims)
        self.norm2 = nn.LayerNorm(embed_dims)
        self.norm3 = nn.LayerNorm(embed_dims)

    @torch.no_grad()
    def init_weights(self):
        self.self_attn.init_weights()
        self.sampling.init_weights()
        self.mixing.init_weights()
        self.ffn.init_weights()

    def forward(self, query_feat, valid_map, mask_pred, occ_preds, mlvl_feats, query_pos, img_metas):
        """
        query_feat: [bs, num_query, embed_dim]
        valid_map: [bs, num_query, num_voxel]
        mask_pred: [bs, num_query, num_voxel]
        occ_preds: list(occ_loc, occ_pred, _, mask_feat, scale), all voxel decoder's outputs
        mask_feat: [bs, num_voxel, embed_dim]
        occ_pred: [bs, num_voxel]
        occ_loc: [bs, num_voxel, 3]
        """
        # only the finest (last) voxel-decoder output is used here
        occ_loc, occ_pred, _, mask_feat, _ = occ_preds[-1]
        # self-attention -> mask-guided sampling -> adaptive mixing -> FFN,
        # each followed by LayerNorm
        query_feat = self.norm1(self.self_attn(query_feat, query_pos=query_pos))
        sampled_feat = self.sampling(query_feat, valid_map, occ_loc, mlvl_feats, img_metas)
        query_feat = self.norm2(self.mixing(sampled_feat, query_feat))
        query_feat = self.norm3(self.ffn(query_feat))
        valid_map, mask_pred, class_pred = self.pred_segmentation(query_feat, mask_feat)
        return query_feat, valid_map, mask_pred, class_pred

    def pred_segmentation(self, query_feat, mask_feat):
        # gradient checkpointing during training to save memory
        if self.training and query_feat.requires_grad:
            return cp(self.inner_pred_segmentation, query_feat, mask_feat, use_reentrant=False)
        else:
            return self.inner_pred_segmentation(query_feat, mask_feat)

    def inner_pred_segmentation(self, query_feat, mask_feat):
        class_pred = self.classifier(query_feat)
        feat_proj = self.mask_proj(query_feat)
        # mask logits: dot product between projected queries and voxel features
        mask_pred = torch.einsum("bqc,bnc->bqn", feat_proj, mask_feat)
        # boolean foreground map, later used to guide sampling
        valid_map = (mask_pred > 0.0)
        return valid_map, mask_pred, class_pred
class MaskFormerSelfAttention(BaseModule):
    """Query self-attention block with residual connection, optionally run
    under activation checkpointing during training."""

    def __init__(self, embed_dims, num_heads, dropout=0.0):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dims, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ReLU(inplace=True)

    def init_weights(self):
        # Xavier init for every weight matrix; 1-D params keep their defaults
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def with_pos_embed(self, tensor, pos=None):
        # add the positional embedding when one is given
        if pos is None:
            return tensor
        return tensor + pos

    def inner_forward(self, query, mask=None, key_padding_mask=None, query_pos=None):
        qk = self.with_pos_embed(query, query_pos)
        attn_out, _ = self.self_attn(qk, qk, value=query, attn_mask=mask, key_padding_mask=key_padding_mask)
        return query + self.dropout(attn_out)

    def forward(self, query, mask=None, key_padding_mask=None, query_pos=None):
        # checkpoint the forward during training to save activation memory
        if self.training and query.requires_grad:
            return cp(self.inner_forward, query, mask, key_padding_mask, query_pos, use_reentrant=False)
        return self.inner_forward(query, mask, key_padding_mask, query_pos)
class MaskFormerSampling(BaseModule):
    """Samples multi-frame, multi-level image features at 3D points derived
    from each query's current foreground mask."""

    def __init__(self, embed_dims=256, num_frames=4, num_groups=4, num_points=8, num_levels=4, pc_range=[], occ_size=[], init_cfg=None):
        super().__init__(init_cfg)
        self.num_frames = num_frames
        self.num_points = num_points
        self.num_groups = num_groups
        self.num_levels = num_levels
        self.pc_range = pc_range
        self.occ_size = occ_size
        # per-query 3D offsets for the sampling points
        self.offset = nn.Linear(embed_dims, num_groups * num_points * 3)
        # per-point softmax weights over the feature pyramid levels
        self.scale_weights = nn.Linear(embed_dims, num_groups * num_points * num_levels)

    def init_weights(self, ):
        # offsets start at zero (sample exactly at the mask-derived points);
        # scale_weights keeps its default initialization
        nn.init.zeros_(self.offset.weight)
        nn.init.zeros_(self.offset.bias)

    def inner_forward(self, query_feat, valid_map, occ_loc, mlvl_feats, img_metas):
        '''
        valid_map: [B, Q, W, H, D]
        query_feat: [B, Q, C]
        '''
        B, Q = query_feat.shape[:2]
        image_h, image_w, _ = img_metas[0]['img_shape'][0]

        # sampling offset of all frames
        offset = self.offset(query_feat).view(B, Q, self.num_groups * self.num_points, 3)  # [B, Q, GP, 3]
        sampling_points = make_sample_points_from_mask(valid_map, self.pc_range, self.occ_size, self.num_groups*self.num_points, occ_loc, offset)
        # the same 3D points are reused for every temporal frame
        sampling_points = sampling_points.reshape(B, Q, 1, self.num_groups, self.num_points, 3)
        sampling_points = sampling_points.expand(B, Q, self.num_frames, self.num_groups, self.num_points, 3)

        # scale weights
        scale_weights = self.scale_weights(query_feat).view(B, Q, self.num_groups, 1, self.num_points, self.num_levels)
        scale_weights = torch.softmax(scale_weights, dim=-1)
        scale_weights = scale_weights.expand(B, Q, self.num_groups, self.num_frames, self.num_points, self.num_levels)

        # sampling
        sampled_feats = sampling_4d(
            sampling_points,
            mlvl_feats,
            scale_weights,
            img_metas[0]['lidar2img'],
            image_h, image_w
        )  # [B, Q, G, FP, C]

        return sampled_feats

    def forward(self, query_feat, valid_map, occ_loc, mlvl_feats, img_metas):
        # checkpoint during training to save activation memory
        if self.training and query_feat.requires_grad:
            return cp(self.inner_forward, query_feat, valid_map, occ_loc, mlvl_feats, img_metas, use_reentrant=False)
        else:
            return self.inner_forward(query_feat, valid_map, occ_loc, mlvl_feats, img_metas)
================================================
FILE: models/utils.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import random
from mmcv.cnn.bricks import ConvTranspose3d, Conv3d
def conv3d_gn_relu(in_channels, out_channels, kernel_size=1, stride=1):
    """3D conv (no bias) -> GroupNorm with 16 groups -> in-place ReLU."""
    conv = Conv3d(in_channels, out_channels, kernel_size, stride, bias=False)
    norm = nn.GroupNorm(16, out_channels)
    return nn.Sequential(conv, norm, nn.ReLU(inplace=True))
def deconv3d_gn_relu(in_channels, out_channels, kernel_size=2, stride=2):
    """3D transposed conv (no bias) -> GroupNorm with 16 groups -> in-place ReLU."""
    deconv = ConvTranspose3d(in_channels, out_channels, kernel_size, stride, bias=False)
    norm = nn.GroupNorm(16, out_channels)
    return nn.Sequential(deconv, norm, nn.ReLU(inplace=True))
def sparse2dense(indices, value, dense_shape, empty_value=0):
    """Scatter sparse per-voxel values into a dense grid.

    Args:
        indices: [B, N, 3] integer voxel coordinates.
        value: [B, N, ...] values to scatter; trailing dims must match
            dense_shape[3:].
        dense_shape: grid dims as a list, e.g. [X, Y, Z] or [X, Y, Z, C].
        empty_value: fill value for voxels not covered by `indices`.

    Returns:
        dense: [B] + dense_shape tensor holding `value` at `indices`.
        mask:  [B] + dense_shape[:3] bool tensor, True where a voxel was written.
    """
    B, N = indices.shape[:2]  # [B, N, 3]
    # BUGFIX: build the batch index on the same device as the data — the
    # original created it on the CPU, which breaks advanced indexing when
    # `value`/`indices` live on the GPU.
    batch_index = torch.arange(B, device=value.device).unsqueeze(1).expand(B, N)
    dense = torch.ones([B] + dense_shape, device=value.device, dtype=value.dtype) * empty_value
    dense[batch_index, indices[..., 0], indices[..., 1], indices[..., 2]] = value
    mask = torch.zeros([B] + dense_shape[:3], dtype=torch.bool, device=value.device)
    mask[batch_index, indices[..., 0], indices[..., 1], indices[..., 2]] = 1
    return dense, mask
@torch.no_grad()
def generate_grid(n_vox, interval):
    """Build a [1, N, 3] CUDA tensor of voxel coordinates, one every `interval`
    steps along each of the three axes of an `n_vox`-sized grid."""
    axes = [torch.arange(0, n_vox[d], interval) for d in range(3)]
    coords = torch.stack(torch.meshgrid(axes[0], axes[1], axes[2], indexing='ij'))  # [3, dx, dy, dz]
    coords = coords.cuda().view(3, -1).permute(1, 0)  # [N, 3]
    return coords.unsqueeze(0)  # [1, N, 3]
def batch_indexing(batched_data: torch.Tensor, batched_indices: torch.Tensor, layout='channel_first'):
    """Gather points from a batched tensor by per-batch indices.

    channel_first: data [B, C, N], indices [B, I1..Im] -> [B, C, I1..Im]
    channel_last:  data [B, N, C] (or [B, N]), indices [B, I1..Im] -> [B, I1..Im, C]
    """
    assert batched_data.shape[0] == batched_indices.shape[0]

    if layout == 'channel_first':
        B, C = batched_data.shape[:2]
        idx_shape = list(batched_indices.shape[1:])
        # flatten the index dims, broadcast over channels, then gather
        flat_idx = batched_indices.reshape(B, 1, -1)
        flat_idx = flat_idx.expand(B, C, flat_idx.shape[2]).to(torch.int64)
        gathered = torch.gather(batched_data, dim=2, index=flat_idx)
        return gathered.view([B, C] + idx_shape)

    if layout == 'channel_last':
        B = batched_data.shape[0]
        view_shape = [B] + [1] * (batched_indices.dim() - 1)
        expand_shape = [B] + list(batched_indices.shape[1:])
        # per-element batch ids aligned with batched_indices
        batch_ids = torch.arange(B, dtype=torch.long, device=batched_data.device)
        batch_ids = batch_ids.view(view_shape).expand(expand_shape)
        if batched_data.dim() == 2:
            return batched_data[batch_ids, batched_indices.to(torch.long)]
        return batched_data[batch_ids, batched_indices.to(torch.long), :]

    raise ValueError
def rotation_3d_in_axis(points, angles):
    """Rotate point sets around the z axis.

    Args:
        points: [..., n_points, 3] point coordinates.
        angles: [..., 1] rotation angles in radians, one per point set.

    Returns:
        Rotated points with the same shape as the input.
    """
    assert points.shape[-1] == 3
    assert angles.shape[-1] == 1

    angles = angles[..., 0]
    n_points = points.shape[-2]
    batch_dims = angles.shape
    flatten = len(batch_dims) > 1
    if flatten:
        # collapse leading batch dims so bmm sees [B, n_points, 3]
        points = points.reshape(-1, n_points, 3)
        angles = angles.reshape(-1)

    sin_a, cos_a = torch.sin(angles), torch.cos(angles)
    one, zero = torch.ones_like(cos_a), torch.zeros_like(cos_a)
    # transposed z-rotation matrices, applied as points @ R^T
    rot = torch.stack([
        cos_a, sin_a, zero,
        -sin_a, cos_a, zero,
        zero, zero, one,
    ]).transpose(0, 1).reshape(-1, 3, 3)
    points = torch.bmm(points, rot)

    if flatten:
        points = points.reshape(*batch_dims, n_points, 3)
    return points
def inverse_sigmoid(x, eps=1e-5):
    """Numerically-stable logit: the inverse of ``torch.sigmoid``.

    Args:
        x (Tensor): values, clamped into [0, 1] internally.
        eps (float): lower clamp that avoids log(0). Defaults to 1e-5.

    Returns:
        Tensor: log(x / (1 - x)) with the same shape as `x`.
    """
    x = x.clamp(min=0, max=1)
    numerator = x.clamp(min=eps)
    denominator = (1 - x).clamp(min=eps)
    return torch.log(numerator / denominator)
def pad_multiple(inputs, img_metas, size_divisor=32):
    """Zero-pad images (bottom/right) so H and W are multiples of `size_divisor`.

    Also rewrites 'img_shape' and 'pad_shape' in every img_meta to the padded
    size; returns the input unchanged when no padding is needed.
    """
    _, _, img_h, img_w = inputs.shape
    pad_h = (size_divisor - img_h % size_divisor) % size_divisor
    pad_w = (size_divisor - img_w % size_divisor) % size_divisor

    padded_shape = (img_h + pad_h, img_w + pad_w, 3)
    N = len(img_metas[0]['ori_shape'])
    for meta in img_metas:
        meta['img_shape'] = [padded_shape for _ in range(N)]
        meta['pad_shape'] = [padded_shape for _ in range(N)]

    if pad_h == 0 and pad_w == 0:
        return inputs
    return F.pad(inputs, [0, pad_w, 0, pad_h], value=0)
def rgb_to_hsv(image: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    r"""Convert an image from RGB to HSV.

    Appears adapted from kornia's ``rgb_to_hsv``, but with different ranges:
    pixel values are expected in (0, 255), and the output H channel is in
    degrees (0..360), S in (0, 1) and V in (0, 255) — matching the companion
    ``hsv_to_rgb`` below.

    Args:
        image: RGB image to be converted to HSV with shape of :math:`(*, 3, H, W)`.
        eps: scalar to enforce numerical stability.

    Returns:
        HSV version of the image with shape of :math:`(*, 3, H, W)`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_hsv(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # normalize to (0, 1) for the computation below
    image = image / 255.0

    max_rgb, argmax_rgb = image.max(-3)
    min_rgb, argmin_rgb = image.min(-3)
    deltac = max_rgb - min_rgb

    v = max_rgb
    s = deltac / (max_rgb + eps)

    # avoid division by zero for gray pixels (hue is arbitrary there)
    deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac)
    rc, gc, bc = torch.unbind((max_rgb.unsqueeze(-3) - image), dim=-3)

    # hue formula depends on which channel holds the maximum
    h1 = bc - gc
    h2 = (rc - bc) + 2.0 * deltac
    h3 = (gc - rc) + 4.0 * deltac

    h = torch.stack((h1, h2, h3), dim=-3) / deltac.unsqueeze(-3)
    h = torch.gather(h, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3)
    h = (h / 6.0) % 1.0
    h = h * 360.0   # hue in degrees
    v = v * 255.0   # value back to (0, 255)
    return torch.stack((h, s, v), dim=-3)
def hsv_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an image from HSV to RGB.

    Inverse of the ``rgb_to_hsv`` above: the H channel is expected in degrees
    (0..360), S in (0, 1) and V in (0, 255); the returned RGB values are in
    (0, 255).

    Args:
        image: HSV image to be converted to RGB with shape of :math:`(*, 3, H, W)`.

    Returns:
        RGB version of the image with shape of :math:`(*, 3, H, W)`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = hsv_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
    if len(image.shape) < 3 or image.shape[-3] != 3:
        raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # normalize H to (0, 1) and V to (0, 1)
    h: torch.Tensor = image[..., 0, :, :] / 360.0
    s: torch.Tensor = image[..., 1, :, :]
    v: torch.Tensor = image[..., 2, :, :] / 255.0

    hi: torch.Tensor = torch.floor(h * 6) % 6   # hue sextant index 0..5
    f: torch.Tensor = ((h * 6) % 6) - hi        # fractional position within the sextant
    one: torch.Tensor = torch.tensor(1.0, device=image.device, dtype=image.dtype)
    p: torch.Tensor = v * (one - s)
    q: torch.Tensor = v * (one - f * s)
    t: torch.Tensor = v * (one - (one - f) * s)

    hi = hi.long()
    # pick the (r, g, b) combination for this sextant from the 18-row table below
    indices: torch.Tensor = torch.stack([hi, hi + 6, hi + 12], dim=-3)
    out = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-3)
    out = torch.gather(out, -3, indices)
    # scale back to (0, 255)
    out = out * 255.0
    return out
class GpuPhotoMetricDistortion:
    """Apply photometric distortion to image sequentially, every transformation
    is applied with a probability of 0.5. The position of random contrast is in
    second or second to last.

    Operates on a float image batch of shape [N, 3, H, W] in BGR channel
    order, with pixel values in (0, 255); most operations modify the batch
    in place.

    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)
    8. randomly swap channels

    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue (degrees).
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def __call__(self, imgs):
        """Perform photometric distortion on a batch of images.

        Args:
            imgs (Tensor): [N, 3, H, W] float BGR images in (0, 255).

        Returns:
            Tensor: the distorted images, same shape, BGR order.
        """
        imgs = imgs[:, [2, 1, 0], :, :]  # BGR to RGB

        # decide per-image whether contrast is applied before or after the
        # HSV-space adjustments
        contrast_modes = []
        for _ in range(imgs.shape[0]):
            # mode == 0 --> do random contrast first
            # mode == 1 --> do random contrast last
            contrast_modes.append(random.randint(2))

        for idx in range(imgs.shape[0]):
            # random brightness
            if random.randint(2):
                delta = random.uniform(-self.brightness_delta, self.brightness_delta)
                imgs[idx] += delta

            if contrast_modes[idx] == 0:
                if random.randint(2):
                    alpha = random.uniform(self.contrast_lower, self.contrast_upper)
                    imgs[idx] *= alpha

        # convert color from BGR to HSV
        imgs = rgb_to_hsv(imgs)

        for idx in range(imgs.shape[0]):
            # random saturation
            if random.randint(2):
                imgs[idx, 1] *= random.uniform(self.saturation_lower, self.saturation_upper)

            # random hue (wrap around so values stay in 0..360 degrees)
            if random.randint(2):
                imgs[idx, 0] += random.uniform(-self.hue_delta, self.hue_delta)
                imgs[:, 0][imgs[:, 0] > 360] -= 360
                imgs[:, 0][imgs[:, 0] < 0] += 360

        # convert color from HSV to BGR
        imgs = hsv_to_rgb(imgs)

        for idx in range(imgs.shape[0]):
            # random contrast
            if contrast_modes[idx] == 1:
                if random.randint(2):
                    alpha = random.uniform(self.contrast_lower, self.contrast_upper)
                    imgs[idx] *= alpha

            # randomly swap channels
            if random.randint(2):
                imgs[idx] = imgs[idx, random.permutation(3)]

        imgs = imgs[:, [2, 1, 0], :, :]  # RGB to BGR
        return imgs
class DumpConfig:
    """Mutable switches controlling debug dumps of intermediate results."""

    def __init__(self):
        self.enabled = False      # master on/off switch
        self.out_dir = 'outputs'  # destination directory for dumps
        self.stage_count = 0      # decoder stage currently being run
        self.frame_count = 0      # running frame counter

# module-level singleton shared across the model code
DUMP = DumpConfig()
================================================
FILE: old_metrics.py
================================================
import os
import glob
import torch
import argparse
import numpy as np
from tqdm import tqdm
from loaders.old_metrics import Metric_mIoU
def main(args):
    """Evaluate saved semantic-occupancy predictions against Occ3D ground truth
    with the voxel-level mIoU metric (camera-visible voxels only)."""
    pred_filepaths = sorted(glob.glob(os.path.join(args.pred_dir, '*.npz')))
    gt_filepaths = sorted(glob.glob(os.path.join(args.data_root, 'occ3d', '*/*/*.npz')))

    eval_metrics_miou = Metric_mIoU(
        num_classes=18,
        use_lidar_mask=False,
        use_image_mask=True)

    for pred_filepath in tqdm(pred_filepaths):
        # the sample token is the prediction file name without its extension
        sample_token = os.path.basename(pred_filepath).split('.')[0]
        # NOTE(review): linear scan over all GT paths per prediction (O(P*G))
        # and no break after a hit — relies on each token matching exactly one
        # GT path; a token->path dict would be faster. TODO confirm uniqueness.
        for gt_filepath in gt_filepaths:
            if sample_token in gt_filepath:
                sem_pred = np.load(pred_filepath, allow_pickle=True)['pred']
                sem_pred = np.reshape(sem_pred, [200, 200, 16])

                occ_gt = np.load(gt_filepath, allow_pickle=True)
                gt_semantics = occ_gt['semantics']
                mask_lidar = occ_gt['mask_lidar'].astype(bool)
                mask_camera = occ_gt['mask_camera'].astype(bool)

                eval_metrics_miou.add_batch(sem_pred, gt_semantics, mask_lidar, mask_camera)

    eval_metrics_miou.count_miou()
if __name__ == "__main__":
    # CLI: --data-root points at the nuScenes root, --pred-dir at the .npz predictions
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--data-root", type=str, default='data/nuscenes')
    arg_parser.add_argument("--pred-dir", type=str)
    cli_args = arg_parser.parse_args()

    # fixed seeds for reproducibility
    torch.random.manual_seed(0)
    np.random.seed(0)

    main(cli_args)
================================================
FILE: ray_metrics.py
================================================
import os
import glob
import mmcv
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
from loaders.ray_metrics import main_rayiou
from loaders.ego_pose_dataset import EgoPoseDataset
from configs.r50_nuimg_704x256_8f import occ_class_names as occ3d_class_names
from configs.r50_nuimg_704x256_8f_openocc import occ_class_names as openocc_class_names
def main(args):
    """Compute RayIoU of dumped predictions against occupancy ground truth.

    Args:
        args: parsed CLI namespace with ``data_root``, ``pred_dir`` and
            ``data_type`` (either ``occ3d`` or ``openocc_v2``).
    """
    # FIX: resolve the class-name table up front so an invalid --data-type
    # fails immediately (and with a message) instead of after loading
    # the whole validation set
    if args.data_type == 'occ3d':
        occ_class_names = occ3d_class_names
    elif args.data_type == 'openocc_v2':
        occ_class_names = openocc_class_names
    else:
        raise ValueError('unsupported data_type: %s' % args.data_type)

    data_infos = mmcv.load(os.path.join(args.data_root, 'nuscenes_infos_val.pkl'))['infos']
    gt_filepaths = sorted(glob.glob(os.path.join(args.data_root, args.data_type, '*/*/*.npz')))

    # GT layout is <data_type>/<scene_name>/<token>/labels.npz; build a
    # token -> scene_name lookup so each info can be tagged with its scene
    token2scene = {}
    for gt_path in gt_filepaths:
        token = gt_path.split('/')[-2]
        scene_name = gt_path.split('/')[-3]
        token2scene[token] = scene_name

    for info in data_infos:
        info['scene_name'] = token2scene[info['token']]

    lidar_origins = []
    occ_gts = []
    occ_preds = []

    # EgoPoseDataset iterates in data_infos order, so idx indexes the same sample
    for idx, batch in enumerate(DataLoader(EgoPoseDataset(data_infos), num_workers=8)):
        output_origin = batch[1]
        info = data_infos[idx]

        occ_path = os.path.join(args.data_root, args.data_type, info['scene_name'], info['token'], 'labels.npz')
        occ_gt = np.load(occ_path, allow_pickle=True)['semantics']
        occ_gt = np.reshape(occ_gt, [200, 200, 16]).astype(np.uint8)

        occ_path = os.path.join(args.pred_dir, info['token'] + '.npz')
        occ_pred = np.load(occ_path, allow_pickle=True)['pred']
        occ_pred = np.reshape(occ_pred, [200, 200, 16]).astype(np.uint8)

        lidar_origins.append(output_origin)
        occ_gts.append(occ_gt)
        occ_preds.append(occ_pred)

    print(main_rayiou(occ_preds, occ_gts, lidar_origins, occ_class_names=occ_class_names))
if __name__ == "__main__":
    # command-line entry point
    cli = argparse.ArgumentParser()
    cli.add_argument("--data-root", type=str, default='data/nuscenes')
    cli.add_argument("--pred-dir", type=str)
    cli.add_argument("--data-type", type=str, choices=['occ3d', 'openocc_v2'], default='occ3d')
    parsed = cli.parse_args()

    # fix RNG state so repeated evaluations are reproducible
    torch.random.manual_seed(0)
    np.random.seed(0)

    main(parsed)
================================================
FILE: timing.py
================================================
import time
import utils
import logging
import argparse
import importlib
import torch
import torch.distributed
import torch.backends.cudnn as cudnn
from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint
from mmdet.apis import set_random_seed
from mmdet3d.datasets import build_dataset, build_dataloader
from mmdet3d.models import build_model
def main():
    """Benchmark a detector's end-to-end inference speed on the validation set.

    Loads the config/checkpoint given on the command line, runs single-GPU
    inference over the validation loader and reports samples-per-second,
    discarding the first ``--num_warmup`` iterations.
    """
    parser = argparse.ArgumentParser(description='Validate a detector')
    parser.add_argument('--config', required=True)
    parser.add_argument('--weights', required=True)
    # FIX: these values are compared/divided as numbers below; without
    # type=int a user-supplied value would arrive as a string and crash
    parser.add_argument('--num_warmup', type=int, default=10)
    parser.add_argument('--samples', type=int, default=200)
    parser.add_argument('--log-interval', type=int, default=50, help='interval of logging')
    parser.add_argument('--override', nargs='+', action=DictAction)
    args = parser.parse_args()

    # parse configs
    cfgs = Config.fromfile(args.config)
    if args.override is not None:
        cfgs.merge_from_dict(args.override)

    # register custom module
    importlib.import_module('models')
    importlib.import_module('loaders')

    # MMCV, please shut up
    from mmcv.utils.logging import logger_initialized
    logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)

    utils.init_logging(None, cfgs.debug)

    # timing assumes exactly one GPU
    assert torch.cuda.is_available() and torch.cuda.device_count() == 1
    logging.info('Using GPU: %s' % torch.cuda.get_device_name(0))
    torch.cuda.set_device(0)

    logging.info('Setting random seed: 0')
    set_random_seed(0, deterministic=True)
    cudnn.benchmark = True

    logging.info('Loading validation set from %s' % cfgs.data.val.data_root)
    val_dataset = build_dataset(cfgs.data.val)
    val_loader = build_dataloader(
        val_dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfgs.data.workers_per_gpu,
        num_gpus=1,
        dist=False,
        shuffle=False,
        seed=0,
    )

    logging.info('Creating model: %s' % cfgs.model.type)
    model = build_model(cfgs.model)
    model.cuda()
    assert torch.cuda.device_count() == 1
    model = MMDataParallel(model, [0])

    logging.info('Loading checkpoint from %s' % args.weights)
    load_checkpoint(
        model, args.weights, map_location='cuda', strict=False,
        logger=logging.Logger(__name__, logging.ERROR)
    )
    model.eval()

    print('Timing w/ data loading:')
    pure_inf_time = 0
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            # synchronize around the forward pass so wall-clock time covers
            # the full GPU work for this sample
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            model(return_loss=False, rescale=True, **data)
            torch.cuda.synchronize()
            elapsed = time.perf_counter() - start_time

            # only count iterations after the warm-up phase
            if i >= args.num_warmup:
                pure_inf_time += elapsed
                if (i + 1) % args.log_interval == 0:
                    fps = (i + 1 - args.num_warmup) / pure_inf_time
                    print(f'Done sample [{i + 1:<3}/ {args.samples}], '
                          f'fps: {fps:.1f} sample / s')

            if (i + 1) == args.samples:
                break
# script entry point
if __name__ == '__main__':
    main()
================================================
FILE: train.py
================================================
import os
import utils
import shutil
import logging
import argparse
import importlib
import torch
import torch.distributed as dist
from datetime import datetime
from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import EpochBasedRunner, build_optimizer, load_checkpoint
from mmdet.apis import set_random_seed
from mmdet.core import DistEvalHook, EvalHook
from mmdet3d.datasets import build_dataset
from mmdet3d.models import build_model
from loaders.builder import build_dataloader
def main():
    """Training entry point.

    Parses the CLI/config, prepares the run directory (rank 0 only), builds
    (optionally distributed) data loaders and the model, and drives training
    through mmcv's ``EpochBasedRunner``.
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('--config', required=True)
    parser.add_argument('--run_name', required=False, default='')
    parser.add_argument('--override', nargs='+', action=DictAction)
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    args = parser.parse_args()

    # parse configs
    cfgs = Config.fromfile(args.config)
    if args.override is not None:
        cfgs.merge_from_dict(args.override)

    # register custom module
    importlib.import_module('models')
    importlib.import_module('loaders')

    # MMCV, please shut up
    from mmcv.utils.logging import logger_initialized
    logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmdet3d'] = logging.Logger(__name__, logging.WARNING)

    # you need GPUs
    assert torch.cuda.is_available()

    # determine local_rank and world_size
    # (distributed launchers typically export these env vars; the CLI flags
    # act as a single-process fallback)
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if 'WORLD_SIZE' not in os.environ:
        os.environ['WORLD_SIZE'] = str(args.world_size)

    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    if local_rank == 0:
        # resume or start a new run
        if cfgs.resume_from is not None:
            assert os.path.isfile(cfgs.resume_from)
            work_dir = os.path.dirname(cfgs.resume_from)
        else:
            run_name = args.run_name
            if not cfgs.debug and run_name == '':
                run_name = input('Name your run (leave blank for default): ')
            if run_name == '':
                run_name = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
            work_dir = os.path.join('outputs', cfgs.model.type, run_name)
            if os.path.exists(work_dir):  # must be an empty dir
                if input('Path "%s" already exists, overwrite it? [Y/n] ' % work_dir) == 'n':
                    print('Bye.')
                    exit(0)
                shutil.rmtree(work_dir)
            os.makedirs(work_dir, exist_ok=False)

        # init logging, backup code
        utils.init_logging(os.path.join(work_dir, 'train.log'), cfgs.debug)
        utils.backup_code(work_dir)
        logging.info('Logs will be saved to %s' % work_dir)
    else:
        # disable logging on other workers
        logging.root.disabled = True
        work_dir = '/tmp'

    logging.info('Using GPU: %s' % torch.cuda.get_device_name(local_rank))
    torch.cuda.set_device(local_rank)

    if world_size > 1:
        logging.info('Initializing DDP with %d GPUs...' % world_size)
        dist.init_process_group('nccl', init_method='env://')

    logging.info('Setting random seed: 0')
    set_random_seed(0, deterministic=True)

    logging.info('Loading training set from %s' % cfgs.dataset_root)
    train_dataset = build_dataset(cfgs.data.train)
    # the configured total batch size is split evenly across workers
    train_loader = build_dataloader(
        train_dataset,
        samples_per_gpu=cfgs.batch_size // world_size,
        workers_per_gpu=cfgs.data.workers_per_gpu,
        num_gpus=world_size,
        dist=world_size > 1,
        shuffle=True,
        seed=0,
    )

    logging.info('Loading validation set from %s' % cfgs.dataset_root)
    val_dataset = build_dataset(cfgs.data.val)
    val_loader = build_dataloader(
        val_dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfgs.data.workers_per_gpu,
        num_gpus=world_size,
        dist=world_size > 1,
        shuffle=False
    )

    logging.info('Creating model: %s' % cfgs.model.type)
    model = build_model(cfgs.model)
    model.init_weights()
    model.cuda()
    model.train()

    n_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
    logging.info('Trainable parameters: %d (%.1fM)' % (n_params, n_params / 1e6))
    logging.info('Batch size per GPU: %d' % (cfgs.batch_size // world_size))

    if world_size > 1:
        model = MMDistributedDataParallel(model, [local_rank], broadcast_buffers=False)
    else:
        model = MMDataParallel(model, [0])

    logging.info('Creating optimizer: %s' % cfgs.optimizer.type)
    optimizer = build_optimizer(model, cfgs.optimizer)

    runner = EpochBasedRunner(
        model,
        optimizer=optimizer,
        work_dir=work_dir,
        logger=logging.root,
        max_epochs=cfgs.total_epochs,
        meta=dict(),
    )

    # standard mmcv training hooks, all driven by the config file
    runner.register_lr_hook(cfgs.lr_config)
    runner.register_optimizer_hook(cfgs.optimizer_config)
    runner.register_checkpoint_hook(cfgs.checkpoint_config)
    runner.register_logger_hooks(cfgs.log_config)
    runner.register_timer_hook(dict(type='IterTimerHook'))
    runner.register_custom_hooks(dict(type='DistSamplerSeedHook'))

    # periodic evaluation on the validation set (disabled when interval <= 0)
    if cfgs.eval_config['interval'] > 0:
        if world_size > 1:
            runner.register_hook(DistEvalHook(val_loader, interval=cfgs.eval_config['interval'], gpu_collect=True))
        else:
            runner.register_hook(EvalHook(val_loader, interval=cfgs.eval_config['interval']))

    # resuming (restores optimizer/epoch state) takes precedence over
    # plain weight loading
    if cfgs.resume_from is not None:
        logging.info('Resuming from %s' % cfgs.resume_from)
        runner.resume(cfgs.resume_from)
    elif cfgs.load_from is not None:
        logging.info('Loading checkpoint from %s' % cfgs.load_from)
        if cfgs.revise_keys is not None:
            load_checkpoint(
                model, cfgs.load_from, map_location='cpu',
                revise_keys=cfgs.revise_keys
            )
        else:
            load_checkpoint(
                model, cfgs.load_from, map_location='cpu',
            )

    runner.run([train_loader], [('train', 1)])
# script entry point
if __name__ == '__main__':
    main()
================================================
FILE: utils.py
================================================
import os
import sys
import glob
import torch
import shutil
import logging
import datetime
import socket
import wandb
from mmcv.runner.hooks import HOOKS
from mmcv.runner.hooks.logger import LoggerHook, TextLoggerHook
from mmcv.runner.dist_utils import master_only
from torch.utils.tensorboard import SummaryWriter
def init_logging(filename=None, debug=False):
    """Replace the root logger and attach a stdout (and optional file) handler.

    Args:
        filename: if given, log records are additionally written to this file.
        debug: emit DEBUG-level records instead of INFO.
    """
    level = 'DEBUG' if debug else 'INFO'
    logging.root = logging.RootLogger(level)

    fmt = logging.Formatter('[%(asctime)s][%(levelname)s] - %(message)s')

    console = logging.StreamHandler(sys.stdout)
    console.setFormatter(fmt)
    logging.root.addHandler(console)

    if filename is not None:
        to_file = logging.FileHandler(filename)
        to_file.setFormatter(fmt)
        logging.root.addHandler(to_file)
def backup_code(work_dir, verbose=False):
    """Copy this project's source files into ``<work_dir>/backup``.

    FIX: the glob patterns were previously resolved against the *current
    working directory* while the copy source was joined onto the directory of
    this file; when the script was launched from another directory the two
    disagreed.  Both now use the directory containing this file.

    Args:
        work_dir: destination run directory; files land under
            ``work_dir/backup`` with their relative layout preserved.
        verbose: log each copied file.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    for pattern in ['*.py', 'configs/*.py', 'models/*.py', 'loaders/*.py', 'loaders/pipelines/*.py']:
        for src in glob.glob(os.path.join(base_dir, pattern)):
            rel = os.path.relpath(src, base_dir)
            dst = os.path.join(work_dir, 'backup', os.path.dirname(rel))
            if verbose:
                logging.info('Copying %s -> %s' % (os.path.relpath(src), os.path.relpath(dst)))
            os.makedirs(dst, exist_ok=True)
            shutil.copy2(src, dst)
@HOOKS.register_module()
class MyTextLoggerHook(TextLoggerHook):
    """``TextLoggerHook`` variant with a compact log line.

    Differences from the stock hook (marked ``MOD`` below): the first
    iteration is excluded from the ETA average, json dump files are disabled,
    and evaluation results are printed after each training epoch.
    """

    def _log_info(self, log_dict, runner):
        # print exp name for users to distinguish experiments
        # at every ``interval_exp_name`` iterations and the end of each epoch
        if runner.meta is not None and 'exp_name' in runner.meta:
            if (self.every_n_iters(runner, self.interval_exp_name)) or (
                    self.by_epoch and self.end_of_epoch(runner)):
                exp_info = f'Exp name: {runner.meta["exp_name"]}'
                runner.logger.info(exp_info)

        # by epoch: Epoch [4][100/1000]
        # by iter: Iter [100/100000]
        if self.by_epoch:
            log_str = f'Epoch [{log_dict["epoch"]}/{runner.max_epochs}]' \
                      f'[{log_dict["iter"]}/{len(runner.data_loader)}] '
        else:
            log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}] '

        log_str += 'loss: %.2f, ' % log_dict['loss']

        if 'time' in log_dict.keys():
            # MOD: skip the first iteration since it's not accurate
            if runner.iter == self.start_iter:
                time_sec_avg = log_dict['time']
            else:
                self.time_sec_tot += (log_dict['time'] * self.interval)
                time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter)
            # remaining-time estimate from the smoothed per-iter time
            eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
            log_str += f'eta: {eta_str}, '
            log_str += f'time: {log_dict["time"]:.2f}s, ' \
                       f'data: {log_dict["data_time"] * 1000:.0f}ms, '
            # statistic memory
            if torch.cuda.is_available():
                log_str += f'mem: {log_dict["memory"]}M'

        runner.logger.info(log_str)

    def log(self, runner):
        # assemble the dict of loggable values for the current step
        if 'eval_iter_num' in runner.log_buffer.output:
            # this doesn't modify runner.iter and is regardless of by_epoch
            cur_iter = runner.log_buffer.output.pop('eval_iter_num')
        else:
            cur_iter = self.get_iter(runner, inner_iter=True)

        log_dict = {
            'mode': self.get_mode(runner),
            'epoch': self.get_epoch(runner),
            'iter': cur_iter
        }

        # only record lr of the first param group
        cur_lr = runner.current_lr()
        if isinstance(cur_lr, list):
            log_dict['lr'] = cur_lr[0]
        else:
            assert isinstance(cur_lr, dict)
            log_dict['lr'] = {}
            for k, lr_ in cur_lr.items():
                assert isinstance(lr_, list)
                log_dict['lr'].update({k: lr_[0]})

        if 'time' in runner.log_buffer.output:
            # statistic memory
            if torch.cuda.is_available():
                log_dict['memory'] = self._get_max_memory(runner)

        log_dict = dict(log_dict, **runner.log_buffer.output)

        # MOD: disable writing to files
        # self._dump_log(log_dict, runner)
        self._log_info(log_dict, runner)

        return log_dict

    def after_train_epoch(self, runner):
        # print the evaluation result (RayIoU) accumulated in the log buffer
        if 'eval_iter_num' in runner.log_buffer.output:
            runner.log_buffer.output.pop('eval_iter_num')
        if runner.log_buffer.ready:
            metrics = self.get_loggable_tags(runner)
            runner.logger.info('--- Evaluation Results ---')
            runner.logger.info('RayIoU: %.4f' % metrics['val/RayIoU'])
@HOOKS.register_module()
class MyTensorboardLoggerHook(LoggerHook):
    """``LoggerHook`` writing scalars to TensorBoard.

    Train-mode values are logged by global iteration and val-mode values by
    epoch; momentum and per-decoder intermediate losses are skipped
    (see the ``MOD`` comments below).
    """

    def __init__(self, log_dir=None, interval=10, ignore_last=True, reset_flag=False, by_epoch=True):
        super(MyTensorboardLoggerHook, self).__init__(
            interval, ignore_last, reset_flag, by_epoch)
        # event-file directory; falls back to runner.work_dir in before_run
        self.log_dir = log_dir

    @master_only
    def before_run(self, runner):
        super(MyTensorboardLoggerHook, self).before_run(runner)
        if self.log_dir is None:
            self.log_dir = runner.work_dir
        self.writer = SummaryWriter(self.log_dir)

    @master_only
    def log(self, runner):
        tags = self.get_loggable_tags(runner)
        for key, value in tags.items():
            # MOD: merge into the 'train' group
            if key == 'learning_rate':
                key = 'train/learning_rate'

            # MOD: skip momentum
            ignore = False
            if key == 'momentum':
                ignore = True

            # MOD: skip intermediate losses
            # (keys of the form 'train/d0.loss' ... 'train/d4.loss')
            for i in range(5):
                if key[:13] == 'train/d%d.loss' % i:
                    ignore = True

            # only keep tags that belong to the current mode's prefix
            if self.get_mode(runner) == 'train' and key[:5] != 'train':
                ignore = True
            if self.get_mode(runner) != 'train' and key[:3] != 'val':
                ignore = True

            if ignore:
                continue

            # train curves advance with the iteration, val curves with the epoch
            if key[:5] == 'train':
                self.writer.add_scalar(key, value, self.get_iter(runner))
            elif key[:3] == 'val':
                self.writer.add_scalar(key, value, self.get_epoch(runner))

    @master_only
    def after_run(self, runner):
        self.writer.close()
# modified from mmcv.runner.hooks.logger.wandb
@HOOKS.register_module()
class MyWandbLoggerHook(LoggerHook):
    """Class to log metrics with wandb.
    It requires `wandb`_ to be installed.
    Args:
        log_dir (str): directory for saving logs
            Default None.
        project_name (str): name for your project (mainly used to specify saving path on wandb server)
            Default None.
        team_name (str): name for your team (mainly used to specify saving path on wandb server)
            Default None.
        experiment_name (str): name for your run, if not specified, use the last part of log_dir
            Default None.
        interval (int): Logging interval (every k iterations).
            Default 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`.
            Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: False.
        commit (bool): Save the metrics dict to the wandb server and increment
            the step. If false ``wandb.log`` just updates the current metrics
            dict with the row argument and metrics won't be saved until
            ``wandb.log`` is called with ``commit=True``.
            Default: True.
        by_epoch (bool): Whether EpochBasedRunner is used.
            Default: True.
        with_step (bool): If True, the step will be logged from
            ``self.get_iters``. Otherwise, step will not be logged.
            Default: True.
        out_suffix (str or tuple[str], optional): Those filenames ending with
            ``out_suffix`` will be uploaded to wandb.
            Default: ('.log.json', '.log', '.py').
            `New in version 1.4.3.`
    .. _wandb:
        https://docs.wandb.ai
    """

    def __init__(self, log_dir=None, project_name=None, team_name=None, experiment_name=None,
                 interval=10, ignore_last=True, reset_flag=False, by_epoch=True, commit=True,
                 with_step=True, out_suffix = ('.log.json', '.log', '.py')):
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        self.import_wandb()
        self.commit = commit
        self.with_step = with_step
        self.out_suffix = out_suffix
        self.log_dir = log_dir
        self.project_name = project_name
        self.team_name = team_name
        self.experiment_name = experiment_name
        # switch the wandb CLI between online/offline sync to match `commit`
        if commit:
            os.system('wandb online')
        else:
            os.system('wandb offline')

    def import_wandb(self) -> None:
        # imported lazily so a missing wandb install fails with a clear message
        try:
            import wandb
        except ImportError:
            raise ImportError(
                'Please run "pip install wandb" to install wandb')
        self.wandb = wandb

    @master_only
    def before_run(self, runner) -> None:
        super().before_run(runner)
        if self.log_dir is None:
            self.log_dir = runner.work_dir
        if self.experiment_name is None:
            # default run name: last path component of the log directory
            self.experiment_name = os.path.basename(self.log_dir)
        init_kwargs = dict(
            project=self.project_name,
            entity=self.team_name,
            notes=socket.gethostname(),
            name=self.experiment_name,
            dir=self.log_dir,
            reinit=True
        )
        if self.wandb is None:
            self.import_wandb()
        if init_kwargs:
            self.wandb.init(**init_kwargs)  # type: ignore
        else:
            self.wandb.init()  # type: ignore

    @master_only
    def log(self, runner) -> None:
        tags = self.get_loggable_tags(runner)
        mode = self.get_mode(runner)
        if not tags:
            return
        # group the learning rate under 'train/' and drop momentum
        if 'learning_rate' in tags.keys():
            tags['train/learning_rate'] = tags['learning_rate']
            del tags['learning_rate']
        if 'momentum' in tags.keys():
            del tags['momentum']
        # only forward tags belonging to the current mode ('train'/'val')
        tags = {k: v for k, v in tags.items() if k.startswith(mode)}
        if self.with_step:
            self.wandb.log(
                tags, step=self.get_iter(runner), commit=self.commit)
        else:
            tags['global_step'] = self.get_iter(runner)
            self.wandb.log(tags, commit=self.commit)

    @master_only
    def after_run(self, runner) -> None:
        self.wandb.join()
================================================
FILE: val.py
================================================
import os
import utils
import logging
import argparse
import importlib
import torch
import torch.distributed
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from mmcv import Config
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import load_checkpoint
from mmdet.apis import set_random_seed, multi_gpu_test, single_gpu_test
from mmdet3d.datasets import build_dataset, build_dataloader
from mmdet3d.models import build_model
def evaluate(dataset, results):
    """Run the dataset's own evaluation on ``results``.

    Logs every metric at INFO level and returns the metric dict unchanged.
    """
    metrics = dataset.evaluate(results, jsonfile_prefix=None)

    logging.info('--- Evaluation Results ---')
    for name, score in metrics.items():
        logging.info('%s: %.4f' % (name, score))

    return metrics
def main():
    """Validation entry point.

    Builds the validation loader, runs (optionally multi-GPU) inference with
    the given checkpoint and reports metrics on rank 0.
    """
    parser = argparse.ArgumentParser(description='Validate a detector')
    parser.add_argument('--config', required=True)
    parser.add_argument('--weights', required=True)
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=1)
    args = parser.parse_args()

    # parse configs
    cfgs = Config.fromfile(args.config)

    # register custom module
    importlib.import_module('models')
    importlib.import_module('loaders')

    # MMCV, please shut up
    from mmcv.utils.logging import logger_initialized
    logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)

    # you need GPUs
    assert torch.cuda.is_available()

    # determine local_rank and world_size
    # (distributed launchers typically export these env vars; the CLI flags
    # act as a single-process fallback)
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if 'WORLD_SIZE' not in os.environ:
        os.environ['WORLD_SIZE'] = str(args.world_size)

    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])

    if local_rank == 0:
        utils.init_logging(None, cfgs.debug)
    else:
        # disable logging on other workers
        logging.root.disabled = True

    logging.info('Using GPU: %s' % torch.cuda.get_device_name(local_rank))
    torch.cuda.set_device(local_rank)

    if world_size > 1:
        logging.info('Initializing DDP with %d GPUs...' % world_size)
        dist.init_process_group('nccl', init_method='env://')

    logging.info('Setting random seed: 0')
    set_random_seed(0, deterministic=True)
    cudnn.benchmark = True

    logging.info('Loading validation set from %s' % cfgs.data.val.data_root)
    val_dataset = build_dataset(cfgs.data.val)
    val_loader = build_dataloader(
        val_dataset,
        samples_per_gpu=args.batch_size,
        workers_per_gpu=cfgs.data.workers_per_gpu,
        num_gpus=world_size,
        dist=world_size > 1,
        shuffle=False,
        seed=0,
    )

    logging.info('Creating model: %s' % cfgs.model.type)
    model = build_model(cfgs.model)
    model.cuda()

    if world_size > 1:
        model = MMDistributedDataParallel(model, [local_rank], broadcast_buffers=False)
    else:
        model = MMDataParallel(model, [0])

    # NOTE(review): checkpoint loading is skipped silently when the weights
    # file is missing — presumably intentional (random-weight smoke runs);
    # confirm before relying on reported metrics
    if os.path.isfile(args.weights):
        logging.info('Loading checkpoint from %s' % args.weights)
        load_checkpoint(
            model, args.weights, map_location='cuda', strict=True,
            logger=logging.Logger(__name__, logging.ERROR)
        )

    if world_size > 1:
        results = multi_gpu_test(model, val_loader, gpu_collect=True)
    else:
        results = single_gpu_test(model, val_loader)

    # metrics are computed and printed on rank 0 only
    if local_rank == 0:
        evaluate(val_dataset, results)
# script entry point
if __name__ == '__main__':
    main()
================================================
FILE: viz_prediction.py
================================================
import os
import cv2
import utils
import logging
import argparse
import importlib
import torch
import numpy as np
from tqdm import tqdm
from mmcv import Config, DictAction
from mmdet.apis import set_random_seed
from mmdet3d.datasets import build_dataset, build_dataloader
from configs.r50_nuimg_704x256_8f import point_cloud_range as pc_range
from configs.r50_nuimg_704x256_8f import occ_size
from configs.r50_nuimg_704x256_8f import occ_class_names
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint
from mmdet3d.models import build_model
# RGBA palette (uint8) indexed by semantic class id; the alpha channel is a
# constant 255 and is dropped before rendering
color_map = np.array([
    [0, 0, 0, 255],       # others
    [255, 120, 50, 255],  # barrier              orangey
    [255, 192, 203, 255], # bicycle              pink
    [255, 255, 0, 255],   # bus                  yellow
    [0, 150, 245, 255],   # car                  blue
    [0, 255, 255, 255],   # construction_vehicle cyan
    [200, 180, 0, 255],   # motorcycle           dark orange
    [255, 0, 0, 255],     # pedestrian           red
    [255, 240, 150, 255], # traffic_cone         light yellow
    [135, 60, 0, 255],    # trailer              brown
    [160, 32, 240, 255],  # truck                purple
    [255, 0, 255, 255],   # driveable_surface    dark pink
    [175, 0, 75, 255],    # other_flat           dark red
    [75, 0, 75, 255],     # sidewalk             dard purple
    [150, 240, 80, 255],  # terrain              light green
    [230, 230, 250, 255], # manmade              white
    [0, 175, 0, 255],     # vegetation           green
    [255, 255, 255, 255], # free                 white
], dtype=np.uint8)
def occ2img(semantics):
    """Project a dense semantic voxel grid [H, W, D] to a top-down color image.

    Non-free voxels from higher z-slices overwrite lower ones, the class
    palette is applied, and the result is resized to 800x800 (RGB order).
    """
    H, W, D = semantics.shape
    free_id = len(occ_class_names) - 1

    # start from an all-free canvas and paint each z-slice bottom-to-top
    flat = np.full([H, W], free_id, dtype=np.int32)
    for z in range(D):
        layer = semantics[..., z]
        occupied = (layer != free_id)
        flat[occupied] = layer[occupied]

    colored = color_map[flat][..., :3]  # drop the alpha channel
    return cv2.resize(colored, dsize=(800, 800))
def main():
    """Render top-down visualizations of occupancy predictions.

    Runs single-GPU inference over the validation set, densifies the sparse
    predictions and writes one JPEG per sample into ``--viz-dir``.
    """
    parser = argparse.ArgumentParser(description='Validate a detector')
    parser.add_argument('--config', required=True)
    parser.add_argument('--weights', required=True)
    parser.add_argument('--viz-dir', required=True)
    parser.add_argument('--override', nargs='+', action=DictAction)
    args = parser.parse_args()

    # parse configs
    cfgs = Config.fromfile(args.config)
    if args.override is not None:
        cfgs.merge_from_dict(args.override)

    # use val-mini for visualization
    #cfgs.data.val.ann_file = cfgs.data.val.ann_file.replace('val', 'val_mini')

    # register custom module
    importlib.import_module('models')
    importlib.import_module('loaders')

    # MMCV, please shut up
    from mmcv.utils.logging import logger_initialized
    logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)
    logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)

    # you need one GPU
    assert torch.cuda.is_available()
    assert torch.cuda.device_count() == 1

    # logging
    utils.init_logging(None, cfgs.debug)
    logging.info('Using GPU: %s' % torch.cuda.get_device_name(0))

    # random seed
    logging.info('Setting random seed: 0')
    set_random_seed(0, deterministic=True)

    logging.info('Loading validation set from %s' % cfgs.data.val.data_root)
    val_dataset = build_dataset(cfgs.data.val)
    val_loader = build_dataloader(
        val_dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfgs.data.workers_per_gpu,
        num_gpus=1,
        dist=False,
        shuffle=False,
        seed=0,
    )

    logging.info('Creating model: %s' % cfgs.model.type)
    model = build_model(cfgs.model)
    model.cuda()
    model = MMDataParallel(model, [0])
    model.eval()

    logging.info('Loading checkpoint from %s' % args.weights)
    load_checkpoint(
        model, args.weights, map_location='cuda', strict=True,
        logger=logging.Logger(__name__, logging.ERROR)
    )

    # FIX: cv2.imwrite fails silently when the target directory does not
    # exist, so make sure it is present before the loop
    os.makedirs(args.viz_dir, exist_ok=True)

    for i, data in tqdm(enumerate(val_loader)):
        with torch.no_grad():
            occ_pred = model(return_loss=False, rescale=True, **data)[0]

        sem_pred = torch.from_numpy(occ_pred['sem_pred'])[0]  # [N]
        occ_loc = torch.from_numpy(occ_pred['occ_loc'].astype(np.int64))[0]  # [N, 3]

        # sparse to dense: scatter predicted labels into an all-free grid
        free_id = len(occ_class_names) - 1
        dense_pred = torch.ones(occ_size, device=sem_pred.device, dtype=sem_pred.dtype) * free_id  # [200, 200, 16]
        dense_pred[occ_loc[..., 0], occ_loc[..., 1], occ_loc[..., 2]] = sem_pred
        sem_pred = dense_pred.numpy()

        # occ2img returns RGB; flip to BGR for cv2.imwrite
        cv2.imwrite(os.path.join(args.viz_dir, 'sem_%04d.jpg' % i), occ2img(sem_pred)[..., ::-1])
# script entry point
if __name__ == '__main__':
    main()