Repository: CGuangyan-BIT/PointGPT
Branch: V1.2
Commit: 9ba3d40e3aaa
Files: 102
Total size: 341.7 KB
Directory structure:
gitextract_eruklvu8/
├── DATASET.md
├── LICENSE
├── README.md
├── cfgs/
│ ├── PointGPT-B/
│ │ ├── fewshot.yaml
│ │ ├── finetune_modelnet.yaml
│ │ ├── finetune_modelnet_8k.yaml
│ │ ├── finetune_scan_hardest.yaml
│ │ ├── finetune_scan_objbg.yaml
│ │ ├── finetune_scan_objonly.yaml
│ │ ├── post_pretrain.yaml
│ │ └── pretrain.yaml
│ ├── PointGPT-L/
│ │ ├── fewshot.yaml
│ │ ├── finetune_modelnet.yaml
│ │ ├── finetune_modelnet_8k.yaml
│ │ ├── finetune_scan_hardest.yaml
│ │ ├── finetune_scan_objbg.yaml
│ │ ├── finetune_scan_objonly.yaml
│ │ ├── post_pretrain.yaml
│ │ └── pretrain.yaml
│ ├── PointGPT-S/
│ │ ├── fewshot.yaml
│ │ ├── finetune_modelnet.yaml
│ │ ├── finetune_modelnet_8k.yaml
│ │ ├── finetune_scan_hardest.yaml
│ │ ├── finetune_scan_objbg.yaml
│ │ ├── finetune_scan_objonly.yaml
│ │ └── pretrain.yaml
│ └── dataset_configs/
│ ├── LabeledHybrid.yaml
│ ├── ModelNet40.yaml
│ ├── ModelNet40FewShot.yaml
│ ├── ScanObjectNN_hardest.yaml
│ ├── ScanObjectNN_objectbg.yaml
│ ├── ScanObjectNN_objectonly.yaml
│ ├── ShapeNet-55.yaml
│ └── UnlabeledHybrid.yaml
├── datasets/
│ ├── LabeledHybrid.py
│ ├── ModelNetDataset.py
│ ├── ModelNetDatasetFewShot.py
│ ├── ScanObjectNNDataset.py
│ ├── ShapeNet55Dataset.py
│ ├── UnlabeledHybrid.py
│ ├── __init__.py
│ ├── build.py
│ ├── data_transforms.py
│ ├── generate_few_shot_data.py
│ └── io.py
├── extensions/
│ ├── chamfer_dist/
│ │ ├── __init__.py
│ │ ├── chamfer.cu
│ │ ├── chamfer_cuda.cpp
│ │ ├── setup.py
│ │ └── test.py
│ └── emd/
│ ├── README.md
│ ├── __init__.py
│ ├── cuda/
│ │ ├── emd.cpp
│ │ └── emd_kernel.cu
│ ├── emd.py
│ ├── setup.py
│ └── test_emd_loss.py
├── figures/
│ └── a
├── main.py
├── main_vis.py
├── models/
│ ├── GPT.py
│ ├── PointGPT.py
│ ├── __init__.py
│ ├── build.py
│ └── z_order.py
├── requirements.txt
├── segmentation/
│ ├── __init__.py
│ ├── dataset.py
│ ├── extensions/
│ │ ├── chamfer_dist/
│ │ │ ├── __init__.py
│ │ │ ├── chamfer.cu
│ │ │ ├── chamfer_cuda.cpp
│ │ │ ├── setup.py
│ │ │ └── test.py
│ │ └── emd/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── cuda/
│ │ │ ├── emd.cpp
│ │ │ └── emd_kernel.cu
│ │ ├── emd.py
│ │ ├── setup.py
│ │ └── test_emd_loss.py
│ ├── logger.py
│ ├── main.py
│ ├── misc.py
│ ├── models/
│ │ ├── gpt2_seg.py
│ │ ├── pointnet2_utils.py
│ │ ├── pt.py
│ │ └── z_order.py
│ ├── pointnet_util.py
│ └── provider.py
├── tools/
│ ├── __init__.py
│ ├── builder.py
│ ├── runner.py
│ ├── runner_finetune.py
│ └── runner_pretrain.py
└── utils/
├── AverageMeter.py
├── checkpoint.py
├── config.py
├── dist_utils.py
├── logger.py
├── misc.py
├── parser.py
└── registry.py
================================================
FILE CONTENTS
================================================
================================================
FILE: DATASET.md
================================================
## Dataset
The overall directory structure should be:
```
│PointGPT/
├──cfgs/
├──data/
│ ├──ModelNet/
│ ├──ModelNetFewshot/
│ ├──ScanObjectNN/
│ ├──ShapeNet55-34/
│ ├──shapenetcore_partanno_segmentation_benchmark_v0_normal/
├──datasets/
├──.......
```
### ModelNet40 Dataset:
```
│ModelNet/
├──modelnet40_normal_resampled/
│ ├── modelnet40_shape_names.txt
│ ├── modelnet40_train.txt
│ ├── modelnet40_test.txt
│ ├── modelnet40_train_8192pts_fps.dat
│ ├── modelnet40_test_8192pts_fps.dat
```
Download: You can download the processed data from the [Point-BERT repo](https://github.com/lulutang0608/Point-BERT/blob/49e2c7407d351ce8fe65764bbddd5d9c0e0a4c52/DATASET.md), or download it from the [official website](https://modelnet.cs.princeton.edu/#) and process it yourself.
### ModelNet Few-shot Dataset:
```
│ModelNetFewshot/
├──5way10shot/
│ ├── 0.pkl
│ ├── ...
│ ├── 9.pkl
├──5way20shot/
│ ├── ...
├──10way10shot/
│ ├── ...
├──10way20shot/
│ ├── ...
```
Download: Please download the data from [Point-BERT repo](https://github.com/lulutang0608/Point-BERT/blob/49e2c7407d351ce8fe65764bbddd5d9c0e0a4c52/DATASET.md). We use the same data split as theirs.
### ScanObjectNN Dataset:
```
│ScanObjectNN/
├──main_split/
│ ├── training_objectdataset_augmentedrot_scale75.h5
│ ├── test_objectdataset_augmentedrot_scale75.h5
│ ├── training_objectdataset.h5
│ ├── test_objectdataset.h5
├──main_split_nobg/
│ ├── training_objectdataset.h5
│ ├── test_objectdataset.h5
```
Download: Please download the data from the [official website](https://hkust-vgd.github.io/scanobjectnn/).
### ShapeNet55/34 Dataset:
```
│ShapeNet55-34/
├──shapenet_pc/
│ ├── 02691156-1a04e3eab45ca15dd86060f189eb133.npy
│ ├── 02691156-1a6ad7a24bb89733f412783097373bdc.npy
│ ├── .......
├──ShapeNet-55/
│ ├── train.txt
│ └── test.txt
```
Download: Please download the data from [Point-BERT repo](https://github.com/lulutang0608/Point-BERT/blob/49e2c7407d351ce8fe65764bbddd5d9c0e0a4c52/DATASET.md).
### ShapeNetPart Dataset:
```
|shapenetcore_partanno_segmentation_benchmark_v0_normal/
├──02691156/
│ ├── 1a04e3eab45ca15dd86060f189eb133.txt
│ ├── .......
│── .......
│──train_test_split/
│──synsetoffset2category.txt
```
Download: Please download the data from [here](https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip).
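Once the downloads are in place, a quick check that the layout above exists can save a failed training run. The helper below is an illustrative sketch, not part of this repo; the `missing_datasets` name and the `data/` root are our own, mirroring the directory tree at the top of this document:

```python
import os

# Expected dataset directories, mirroring the layout described above
EXPECTED = [
    "data/ModelNet/modelnet40_normal_resampled",
    "data/ModelNetFewshot/5way10shot",
    "data/ScanObjectNN/main_split",
    "data/ShapeNet55-34/shapenet_pc",
    "data/ShapeNet55-34/ShapeNet-55",
    "data/shapenetcore_partanno_segmentation_benchmark_v0_normal",
]

def missing_datasets(root="."):
    """Return the expected dataset directories that are absent under `root`."""
    return [p for p in EXPECTED if not os.path.isdir(os.path.join(root, p))]

if __name__ == "__main__":
    for p in missing_datasets():
        print("missing:", p)
```

Run it from the repository root; an empty result means every expected folder is present (it does not verify the individual `.h5`/`.dat` files).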
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2022 PANG-Yatian, YUAN-Li
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# PointGPT
## PointGPT: Auto-regressively Generative Pre-training from Point Clouds [ArXiv](https://arxiv.org/abs/2305.11487)
In this work, we present PointGPT, a novel approach that extends the concept of GPT to point clouds, utilizing a point cloud auto-regressive generation task for pre-training transformer models. In object classification tasks, our PointGPT achieves 94.9% accuracy on the ModelNet40 dataset and 93.4% accuracy on the ScanObjectNN dataset, outperforming all other transformer models. In few-shot learning tasks, our method also attains new SOTA performance on all four benchmarks.
## News
[2023.09.22] PointGPT has been accepted by NeurIPS 2023!
[2023.09.08] Unlabeled hybrid dataset and labeled hybrid dataset have been released!
[2023.08.19] Code has been updated; PointGPT-B and PointGPT-L models have been released!
[2023.06.20] Code and the PointGPT-S models have been released!
## 1. Requirements
- PyTorch >= 1.7.0
- python >= 3.7
- CUDA >= 9.0
- GCC >= 4.9
- torchvision
```
pip install -r requirements.txt
```
```
# Chamfer Distance & EMD
cd ./extensions/chamfer_dist
python setup.py install --user
cd ./extensions/emd
python setup.py install --user
# PointNet++
pip install "git+https://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib"
# GPU kNN
pip install --upgrade https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl
```
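Before building the CUDA extensions, it can help to confirm that the interpreter and PyTorch meet the minimums above. This is an optional sanity check of our own; the `parse_version` helper is illustrative and not part of the repo:

```python
import sys

def parse_version(v):
    """Turn a dotted version string like '1.13.1+cu117' into a comparable tuple."""
    core = v.split("+")[0]  # drop local build tags such as '+cu117'
    return tuple(int(x) for x in core.split(".")[:3])

assert sys.version_info >= (3, 7), "python >= 3.7 is required"

try:
    import torch
    assert parse_version(torch.__version__) >= (1, 7, 0), "PyTorch >= 1.7.0 is required"
    print("CUDA available:", torch.cuda.is_available())
except ImportError:
    print("PyTorch not installed yet; run `pip install -r requirements.txt` first")
```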
## 2. Datasets
Our training data for the PointGPT-S model encompasses ShapeNet, ScanObjectNN, ModelNet40, and ShapeNetPart datasets. For detailed information, please refer to [DATASET.md](./DATASET.md).
To pretrain the PointGPT-B and PointGPT-L models, we employ both the unlabeled hybrid dataset and the labeled hybrid dataset, available for download [here](https://drive.google.com/file/d/1TWgd3eJX1HDruFfU9JrGnBfcVhzJIXqT/view?usp=sharing).
## 3. PointGPT Models
### PointGPT-S Models
| Task | Dataset | Config | Acc. | Download |
| ----------------- | -------------- | --------------------------------------------------------------- | ---------- | --------------------------------------------------------------------------------------------- |
| Pre-training | ShapeNet | [pretrain.yaml](./cfgs/PointGPT-S/pretrain.yaml) | N.A. | [here](https://drive.google.com/file/d/1gTFI327kXVDFQ90JfYX0zIS4opM1EkqX/view?usp=drive_link) |
| Classification | ScanObjectNN | [finetune_scan_hardest.yaml](./cfgs/PointGPT-S/finetune_scan_hardest.yaml) | 86.9% | [here](https://drive.google.com/file/d/12Tj2OFKsEPT5zd5nQQ2VNEZlCKHncdGh/view?usp=drive_link) |
| Classification | ScanObjectNN | [finetune_scan_objbg.yaml](./cfgs/PointGPT-S/finetune_scan_objbg.yaml) | 91.6% | [here](https://drive.google.com/file/d/1s4RrBkfwVr8r0H2FxwiHULcyMe_EAJ9D/view?usp=drive_link) |
| Classification | ScanObjectNN | [finetune_scan_objonly.yaml](./cfgs/PointGPT-S/finetune_scan_objonly.yaml) | 90.0% | [here](https://drive.google.com/file/d/173yfDAlqqed-oRHaogX6DC4Uj1b8Rvxt/view?usp=drive_link) |
| Classification | ModelNet40(1k) | [finetune_modelnet.yaml](./cfgs/PointGPT-S/finetune_modelnet.yaml) | 94.0% | [here](https://drive.google.com/file/d/17uoJchAzwapTNHVxOWNH4HLNZz9kbGoo/view?usp=drive_link) |
| Classification | ModelNet40(8k) | [finetune_modelnet_8k.yaml](./cfgs/PointGPT-S/finetune_modelnet_8k.yaml) | 94.2% | [here](https://drive.google.com/file/d/1XocTFSsKZgKHx2cLqZJi2rcF74hQ-1nx/view?usp=drive_link) |
| Part segmentation | ShapeNetPart | [segmentation](./segmentation) | 86.2% mIoU | [here](https://drive.google.com/file/d/1WVMTtIq4vPQOOnlDsymVA5541lNL-hm3/view?usp=drive_link) |
| Task | Dataset | Config | 5w10s Acc. (%) | 5w20s Acc. (%) | 10w10s Acc. (%) | 10w20s Acc. (%) |
| ----------------- | ---------- | ----------------------------------- | -------------- | -------------- | --------------- | --------------- |
| Few-shot learning | ModelNet40 | [fewshot.yaml](./cfgs/PointGPT-S/fewshot.yaml) | 96.8 ± 2.0 | 98.6 ± 1.1 | 92.6 ± 4.6 | 95.2 ± 3.4 |
### PointGPT-B Models
| Task | Dataset | Config | Acc. | Download |
| ----------------- | -------------- | --------------------------------------------------------------- | ---------- | --------------------------------------------------------------------------------------------- |
| Pre-training | UnlabeledHybrid | [pretrain.yaml](./cfgs/PointGPT-B/pretrain.yaml) | N.A. | [here](https://drive.google.com/file/d/1Gyf9ZR8MCPg1XOCALjJR9VJepV7iAi5S/view?usp=sharing) |
| Post-pre-training | LabeledHybrid | [post_pretrain.yaml](./cfgs/PointGPT-B/post_pretrain.yaml) | N.A. | [here](https://drive.google.com/file/d/1Gc7thuU-D1Sq4NIMTV6-U1LhVN0E2z9l/view?usp=sharing) |
| Classification | ScanObjectNN | [finetune_scan_hardest.yaml](./cfgs/PointGPT-B/finetune_scan_hardest.yaml) | 91.9% | [here](https://drive.google.com/file/d/1tHi7W935DxVttXHG0Mgb0HSfYWUqXLwB/view?usp=sharing) |
| Classification | ScanObjectNN | [finetune_scan_objbg.yaml](./cfgs/PointGPT-B/finetune_scan_objbg.yaml) | 95.8% | [here](https://drive.google.com/file/d/1te8DuC_-cOzt4JayyaNWvxHcRztjDlGF/view?usp=sharing) |
| Classification | ScanObjectNN | [finetune_scan_objonly.yaml](./cfgs/PointGPT-B/finetune_scan_objonly.yaml) | 95.2% | [here](https://drive.google.com/file/d/17c8KvDrAuY0GgcO7SGE-4zlMArjzkjLX/view?usp=sharing) |
| Classification | ModelNet40(1k) | [finetune_modelnet.yaml](./cfgs/PointGPT-B/finetune_modelnet.yaml) | 94.4% | [here](https://drive.google.com/file/d/1l5zhy52erSp5gigbhYaT0nyMrV_lbh-C/view?usp=sharing) |
| Classification | ModelNet40(8k) | [finetune_modelnet_8k.yaml](./cfgs/PointGPT-B/finetune_modelnet_8k.yaml) | 94.6% | [here](https://drive.google.com/file/d/1FzM7ULPUAOk_J0BRHFvv0nS_Xd65oWbV/view?usp=sharing) |
| Part segmentation | ShapeNetPart | [segmentation](./segmentation) | 86.5% mIoU | [here](https://drive.google.com/file/d/1P6hELhX6Yr-rN04q6N71wZfvW2HnLhqD/view?usp=sharing) |
| Task | Dataset | Config | 5w10s Acc. (%) | 5w20s Acc. (%) | 10w10s Acc. (%) | 10w20s Acc. (%) |
| ----------------- | ---------- | ----------------------------------- | -------------- | -------------- | --------------- | --------------- |
| Few-shot learning | ModelNet40 | [fewshot.yaml](./cfgs/PointGPT-B/fewshot.yaml) | 97.5 ± 2.0 | 98.8 ± 1.0 | 93.5 ± 4.0 | 95.8 ± 3.0 |
### PointGPT-L Models
| Task | Dataset | Config | Acc. | Download |
| ----------------- | -------------- | --------------------------------------------------------------- | ---------- | --------------------------------------------------------------------------------------------- |
| Pre-training | UnlabeledHybrid | [pretrain.yaml](./cfgs/PointGPT-L/pretrain.yaml) | N.A. | [here](https://drive.google.com/file/d/1nzCwriFbC2QoDbRpGhWvf_DbFIkFU6zV/view?usp=sharing) |
| Post-pre-training | LabeledHybrid | [post_pretrain.yaml](./cfgs/PointGPT-L/post_pretrain.yaml) | N.A. | [here](https://drive.google.com/file/d/1Kh6f6gFR12Y86FAeBtMU9NbNpB5vZnpu/view?usp=sharing) |
| Classification | ScanObjectNN | [finetune_scan_hardest.yaml](./cfgs/PointGPT-L/finetune_scan_hardest.yaml) | 93.4% | [here](https://drive.google.com/file/d/1e_qIfZCqQmq0eRpYhf9xrIxl6TkzsaZ9/view?usp=sharing) |
| Classification | ScanObjectNN | [finetune_scan_objbg.yaml](./cfgs/PointGPT-L/finetune_scan_objbg.yaml) | 97.2% | [here](https://drive.google.com/file/d/1gd8gn0ffK0zfWv7AAUbygzIPSeeRU8fD/view?usp=sharing) |
| Classification | ScanObjectNN | [finetune_scan_objonly.yaml](./cfgs/PointGPT-L/finetune_scan_objonly.yaml) | 96.6% | [here](https://drive.google.com/file/d/1F2MnPmQGKnYUgmS5uz3PNInU23jWsNj1/view?usp=sharing) |
| Classification | ModelNet40(1k) | [finetune_modelnet.yaml](./cfgs/PointGPT-L/finetune_modelnet.yaml) | 94.7% | [here](https://drive.google.com/file/d/1ntWwZCvD_Tqykq9F7QrDKXH7aL-dcCsQ/view?usp=sharing) |
| Classification | ModelNet40(8k) | [finetune_modelnet_8k.yaml](./cfgs/PointGPT-L/finetune_modelnet_8k.yaml) | 94.9% | [here](https://drive.google.com/file/d/1gKgdbtIuRinJY-NElSHwrKAL5OhBjrGD/view?usp=sharing) |
| Part segmentation | ShapeNetPart | [segmentation](./segmentation) | 86.6% mIoU | [here](https://drive.google.com/file/d/1d3fXLBkXvzl9YjX5DDMdm7rUtCvfwgUL/view?usp=sharing) |
| Task | Dataset | Config | 5w10s Acc. (%) | 5w20s Acc. (%) | 10w10s Acc. (%) | 10w20s Acc. (%) |
| ----------------- | ---------- | ----------------------------------- | -------------- | -------------- | --------------- | --------------- |
| Few-shot learning | ModelNet40 | [fewshot.yaml](./cfgs/PointGPT-L/fewshot.yaml) | 98.0 ± 1.9 | 99.0 ± 1.0 | 94.1 ± 3.3 | 96.1 ± 2.8 |
## 4. PointGPT Pre-training
To pretrain PointGPT, run the following command.
```
CUDA_VISIBLE_DEVICES=<GPUs> python main.py --config cfgs/<MODEL_NAME>/pretrain.yaml --exp_name <output_file_name>
```
To post-pretrain PointGPT, run the following command.
```
CUDA_VISIBLE_DEVICES=<GPUs> python main.py --config cfgs/<MODEL_NAME>/post_pretrain.yaml --exp_name <output_file_name> --finetune_model
```
## 5. PointGPT Fine-tuning
To fine-tune on ScanObjectNN, run the following command:
```
CUDA_VISIBLE_DEVICES=<GPUs> python main.py --config cfgs/<MODEL_NAME>/finetune_scan_hardest.yaml \
--finetune_model --exp_name <output_file_name> --ckpts <path/to/pre-trained/model>
```
To fine-tune on ModelNet40, run the following command:
```
CUDA_VISIBLE_DEVICES=<GPUs> python main.py --config cfgs/<MODEL_NAME>/finetune_modelnet.yaml \
--finetune_model --exp_name <output_file_name> --ckpts <path/to/pre-trained/model>
```
To evaluate with voting on ModelNet40, run the following command:
```
CUDA_VISIBLE_DEVICES=<GPUs> python main.py --test --config cfgs/<MODEL_NAME>/finetune_modelnet.yaml \
--exp_name <output_file_name> --ckpts <path/to/fine-tuned/model>
```
For few-shot learning, run the following command:
```
CUDA_VISIBLE_DEVICES=<GPUs> python main.py --config cfgs/<MODEL_NAME>/fewshot.yaml --finetune_model \
--ckpts <path/to/pre-trained/model> --exp_name <output_file_name> --way <5 or 10> --shot <10 or 20> --fold <0-9>
```
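The few-shot results in the tables above are reported as mean ± standard deviation over the 10 folds of each way/shot setting. After running all folds, the summary number can be computed as in this small sketch (the per-fold accuracies below are placeholders, not real results):

```python
import statistics

# Hypothetical per-fold accuracies for one few-shot setting (10 folds)
fold_accs = [96.0, 97.0, 98.0, 95.0, 96.5, 97.5, 96.0, 98.5, 95.5, 97.0]

mean = statistics.mean(fold_accs)
std = statistics.stdev(fold_accs)  # sample standard deviation over folds
print(f"{mean:.1f} ± {std:.1f}")  # prints 96.7 ± 1.1
```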
For part segmentation on ShapeNetPart, run the following command:
```
cd segmentation
python main.py --ckpts <path/to/pre-trained/model> --root path/to/data --learning_rate 0.0002 --epoch 300 --model_name <output_file_name>
```
## 6. Visualization
To visualize the pre-trained model on the validation set, run:
```
python main_vis.py --test --ckpts <path/to/pre-trained/model> --config cfgs/<MODEL_NAME>/pretrain.yaml --exp_name <output_file_name>
```
## 7. Ablation studies on the post-pre-training stage
OBJ_BG, OBJ_ONLY, and PB_T50_RS are ScanObjectNN accuracies (%); 1k P and 8k P are ModelNet40 accuracies (%); Cls.mIoU and Inst.mIoU are ShapeNetPart segmentation results.

| Methods | OBJ_BG | OBJ_ONLY | PB_T50_RS | 1k P | 8k P | Cls.mIoU | Inst.mIoU |
| ------- | ------ | -------- | --------- | ---- | ---- | -------- | --------- |
| **without post-pre-training** | | | | | | | |
| PointGPT-B | 93.6 | 92.5 | 89.6 | 94.2 | 94.4 | 84.5 | 86.4 |
| PointGPT-L | 95.7 | 94.1 | 91.1 | 94.5 | 94.7 | 84.7 | 86.5 |
| **with post-pre-training** | | | | | | | |
| PointGPT-B | 95.8 (+2.2) | 95.2 (+2.7) | 91.9 (+2.3) | 94.4 (+0.2) | 94.6 (+0.2) | 84.5 (+0.0) | 86.5 (+0.1) |
| PointGPT-L | 97.2 (+1.5) | 96.6 (+2.5) | 93.4 (+2.3) | 94.7 (+0.2) | 94.9 (+0.2) | 84.8 (+0.1) | 86.6 (+0.1) |
## Acknowledgements
Our code is built upon [Point-MAE](https://github.com/Pang-Yatian/Point-MAE), [Point-BERT](https://github.com/lulutang0608/Point-BERT), [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch), and [Pointnet_Pointnet2_pytorch](https://github.com/yanx27/Pointnet_Pointnet2_pytorch).
The unlabeled and labeled hybrid datasets are built upon [ModelNet40](https://3dshapenets.cs.princeton.edu/), [PartNet](https://partnet.cs.stanford.edu/), [ShapeNet](http://www.shapenet.org), [S3DIS](http://buildingparser.stanford.edu/), [ScanObjectNN](https://hkust-vgd.github.io/scanobjectnn/), [SUN RGB-D](https://rgbd.cs.princeton.edu/), and [Semantic3D](http://semantic3d.net/).
## Reference
```
@article{chen2024pointgpt,
title={Pointgpt: Auto-regressively generative pre-training from point clouds},
author={Chen, Guangyan and Wang, Meiling and Yang, Yi and Yu, Kai and Yuan, Li and Yue, Yufeng},
journal={Advances in Neural Information Processing Systems},
volume={36},
year={2024}
}
```
If you use the unlabeled or labeled hybrid dataset, please also cite the following works.
```
@inproceedings{wu20153d,
title={3d shapenets: A deep representation for volumetric shapes},
author={Wu, Zhirong and Song, Shuran and Khosla, Aditya and Yu, Fisher and Zhang, Linguang and Tang, Xiaoou and Xiao, Jianxiong},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={1912--1920},
year={2015}
}
@inproceedings{mo2019partnet,
title={Partnet: A large-scale benchmark for fine-grained and hierarchical part-level 3d object understanding},
author={Mo, Kaichun and Zhu, Shilin and Chang, Angel X and Yi, Li and Tripathi, Subarna and Guibas, Leonidas J and Su, Hao},
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
pages={909--918},
year={2019}
}
@article{chang2015shapenet,
title={Shapenet: An information-rich 3d model repository},
author={Chang, Angel X and Funkhouser, Thomas and Guibas, Leonidas and Hanrahan, Pat and Huang, Qixing and Li, Zimo and Savarese, Silvio and Savva, Manolis and Song, Shuran and Su, Hao and others},
journal={arXiv preprint arXiv:1512.03012},
year={2015}
}
@inproceedings{armeni20163d,
title={3d semantic parsing of large-scale indoor spaces},
author={Armeni, Iro and Sener, Ozan and Zamir, Amir R and Jiang, Helen and Brilakis, Ioannis and Fischer, Martin and Savarese, Silvio},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={1534--1543},
year={2016}
}
@inproceedings{uy-scanobjectnn-iccv19,
title = {Revisiting Point Cloud Classification: A New Benchmark Dataset and Classification Model on Real-World Data},
author = {Mikaela Angelina Uy and Quang-Hieu Pham and Binh-Son Hua and Duc Thanh Nguyen and Sai-Kit Yeung},
booktitle = {International Conference on Computer Vision (ICCV)},
year = {2019}
}
@inproceedings{song2015sun,
title={Sun rgb-d: A rgb-d scene understanding benchmark suite},
author={Song, Shuran and Lichtenberg, Samuel P and Xiao, Jianxiong},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={567--576},
year={2015}
}
@article{hackel2017semantic3d,
title={Semantic3d. net: A new large-scale point cloud classification benchmark},
author={Hackel, Timo and Savinov, Nikolay and Ladicky, Lubor and Wegner, Jan D and Schindler, Konrad and Pollefeys, Marc},
journal={arXiv preprint arXiv:1704.03847},
year={2017}
}
```
================================================
FILE: cfgs/PointGPT-B/fewshot.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0005, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 30 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ModelNet40FewShot.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ModelNet40FewShot.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 768,
depth: 12,
drop_path_rate: 0.1,
cls_dim: 40,
num_heads: 12,
group_size: 32,
num_group: 64,
encoder_dims: 768,
decoder_depth: 4,
}
npoints: 1024
total_bs: 32
step_per_update: 1
max_epoch: 300
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-B/finetune_modelnet.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 768,
depth: 12,
drop_path_rate: 0.2,
cls_dim: 40,
num_heads: 12,
group_size: 32,
num_group: 64,
encoder_dims: 768,
decoder_depth: 4,
loss: cdl2,
weight_center: 1,
}
npoints: 1024
total_bs: 128
step_per_update: 1
max_epoch: 50
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-B/finetune_modelnet_8k.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.00005, weight_decay: 0.005 } }
scheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 768,
depth: 12,
drop_path_rate: 0.2,
cls_dim: 40,
num_heads: 12,
group_size: 32,
num_group: 512,
encoder_dims: 768,
decoder_depth: 4,
loss: cdl2,
weight_center: 1,
}
npoints: 8192
total_bs: 32
step_per_update: 1
max_epoch: 50
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-B/finetune_scan_hardest.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 30, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 768,
depth: 12,
drop_path_rate: 0.2,
cls_dim: 15,
num_heads: 12,
group_size: 32,
num_group: 128,
encoder_dims: 768,
decoder_depth: 4,
}
npoints: 2048
total_bs: 64
step_per_update: 1
max_epoch: 30
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-B/finetune_scan_objbg.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 30, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 768,
depth: 12,
drop_path_rate: 0.2,
cls_dim: 15,
num_heads: 12,
group_size: 32,
num_group: 128,
encoder_dims: 768,
decoder_depth: 4,
}
npoints: 2048
total_bs: 64
step_per_update: 1
max_epoch: 30
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-B/finetune_scan_objonly.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 768,
depth: 12,
drop_path_rate: 0.2,
cls_dim: 15,
num_heads: 12,
group_size: 32,
num_group: 128,
encoder_dims: 768,
decoder_depth: 4,
}
npoints: 2048
total_bs: 64
step_per_update: 1
max_epoch: 50
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-B/post_pretrain.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 100, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/LabeledHybrid.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/LabeledHybrid.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/LabeledHybrid.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 768,
depth: 12,
drop_path_rate: 0.2,
cls_dim: 87,
num_heads: 12,
group_size: 32,
num_group: 64,
encoder_dims: 768,
decoder_depth: 4,
loss: cdl2,
weight_center: 1,
}
npoints: 1024
total_bs: 256
step_per_update: 1
max_epoch: 100
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-B/pretrain.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/UnlabeledHybrid.yaml,
others: { subset: "train", npoints: 1024 },
},
val:
{
_base_: cfgs/dataset_configs/UnlabeledHybrid.yaml,
others: { subset: "test", npoints: 1024 },
},
test:
{
_base_: cfgs/dataset_configs/UnlabeledHybrid.yaml,
others: { subset: "test", npoints: 1024 },
},
}
model:
{
NAME: PointGPT,
cls_dim: 40,
group_size: 32,
num_group: 64,
loss: cdl12,
weight_center: 1,
transformer_config:
{
mask_ratio: 0.7,
mask_type: "rand",
trans_dim: 768,
encoder_dims: 768,
depth: 12,
drop_path_rate: 0.1,
num_heads: 12,
decoder_depth: 4,
decoder_num_heads: 12,
},
}
npoints: 1024
total_bs: 128
step_per_update: 1
max_epoch: 300
================================================
FILE: cfgs/PointGPT-L/fewshot.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0005, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 30 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ModelNet40FewShot.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ModelNet40FewShot.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 768,
depth: 12,
drop_path_rate: 0.1,
cls_dim: 40,
num_heads: 12,
group_size: 32,
num_group: 64,
encoder_dims: 768,
decoder_depth: 4,
}
npoints: 1024
total_bs: 32
step_per_update: 1
max_epoch: 300
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-L/finetune_modelnet.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 1024,
depth: 24,
drop_path_rate: 0.2,
cls_dim: 40,
num_heads: 16,
group_size: 32,
num_group: 64,
encoder_dims: 1024,
decoder_depth: 4,
loss: cdl2,
weight_center: 1,
}
npoints: 1024
total_bs: 128
step_per_update: 1
max_epoch: 50
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-L/finetune_modelnet_8k.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.00005, weight_decay: 0.005 } }
scheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 1024,
depth: 24,
drop_path_rate: 0.2,
cls_dim: 40,
num_heads: 16,
group_size: 32,
num_group: 512,
encoder_dims: 1024,
decoder_depth: 4,
loss: cdl2,
weight_center: 1,
}
npoints: 8192
total_bs: 32
step_per_update: 1
max_epoch: 50
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-L/finetune_scan_hardest.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 1024,
depth: 24,
drop_path_rate: 0.2,
cls_dim: 15,
num_heads: 16,
group_size: 32,
num_group: 128,
encoder_dims: 1024,
decoder_depth: 4,
}
npoints: 2048
total_bs: 64
step_per_update: 1
max_epoch: 50
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-L/finetune_scan_objbg.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 1024,
depth: 24,
drop_path_rate: 0.2,
cls_dim: 15,
num_heads: 16,
group_size: 32,
num_group: 128,
encoder_dims: 1024,
decoder_depth: 4,
}
npoints: 2048
total_bs: 64
step_per_update: 1
max_epoch: 50
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-L/finetune_scan_objonly.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 1024,
depth: 24,
drop_path_rate: 0.2,
cls_dim: 15,
num_heads: 16,
group_size: 32,
num_group: 128,
encoder_dims: 1024,
decoder_depth: 4,
}
npoints: 2048
total_bs: 64
step_per_update: 1
max_epoch: 50
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-L/post_pretrain.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 100, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/LabeledHybrid.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/LabeledHybrid.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/LabeledHybrid.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 1024,
depth: 24,
drop_path_rate: 0.2,
cls_dim: 87,
num_heads: 16,
group_size: 32,
num_group: 64,
encoder_dims: 1024,
decoder_depth: 4,
loss: cdl2,
weight_center: 1,
}
npoints: 1024
total_bs: 256
step_per_update: 1
max_epoch: 100
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-L/pretrain.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.00006, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 600, initial_epochs: 80 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/UnlabeledHybrid.yaml,
others: { subset: "train", npoints: 1024 },
},
val:
{
_base_: cfgs/dataset_configs/UnlabeledHybrid.yaml,
others: { subset: "test", npoints: 1024 },
},
test:
{
_base_: cfgs/dataset_configs/UnlabeledHybrid.yaml,
others: { subset: "test", npoints: 1024 },
},
}
model:
{
NAME: PointGPT,
cls_dim: 40,
group_size: 32,
num_group: 64,
loss: cdl12,
weight_center: 1,
transformer_config:
{
mask_ratio: 0.7,
mask_type: "rand",
trans_dim: 1024,
encoder_dims: 1024,
depth: 24,
drop_path_rate: 0.1,
num_heads: 16,
decoder_depth: 4,
decoder_num_heads: 12,
},
}
npoints: 1024
total_bs: 128
step_per_update: 1
max_epoch: 350
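The pretrain configs pair a `CosLR` scheduler (`epochs`, `initial_epochs`) with AdamW. The repository's exact scheduler implementation is not shown here; a minimal sketch of a warmup-then-cosine schedule consistent with these fields (the function name, `min_lr` floor, and decay formula are illustrative assumptions):

```python
import math

def cos_lr(epoch, base_lr=6e-5, warmup_epochs=80, total_epochs=600, min_lr=1e-6):
    """Illustrative sketch: linear warmup over `warmup_epochs`, then cosine
    decay from `base_lr` down to `min_lr` at `total_epochs`."""
    if epoch < warmup_epochs:
        # Linear ramp: reaches base_lr on the last warmup epoch
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
```

Note that this config trains for `max_epoch: 350` while the schedule spans 600 epochs, so under a schedule like the above, training would stop partway down the cosine curve.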
================================================
FILE: cfgs/PointGPT-S/fewshot.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0005, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 30 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ModelNet40FewShot.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ModelNet40FewShot.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 384,
depth: 12,
drop_path_rate: 0.1,
cls_dim: 40,
num_heads: 6,
group_size: 32,
num_group: 64,
encoder_dims: 384,
decoder_depth: 4,
}
npoints: 1024
total_bs: 32
step_per_update: 1
max_epoch: 300
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-S/finetune_modelnet.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 384,
depth: 12,
drop_path_rate: 0.1,
cls_dim: 40,
num_heads: 6,
group_size: 32,
num_group: 64,
encoder_dims: 384,
decoder_depth: 4,
loss: cdl2,
weight_center: 1,
}
npoints: 1024
total_bs: 128
step_per_update: 1
max_epoch: 300
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-S/finetune_modelnet_8k.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.005 } }
scheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ModelNet40.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 384,
depth: 12,
drop_path_rate: 0.1,
cls_dim: 40,
num_heads: 6,
group_size: 32,
num_group: 512,
encoder_dims: 384,
decoder_depth: 4,
loss: cdl2,
weight_center: 1,
}
npoints: 8192
total_bs: 32
step_per_update: 1
max_epoch: 300
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-S/finetune_scan_hardest.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 30 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 384,
depth: 12,
drop_path_rate: 0.1,
cls_dim: 15,
num_heads: 6,
group_size: 32,
num_group: 128,
encoder_dims: 384,
decoder_depth: 4,
}
npoints: 2048
total_bs: 64
step_per_update: 1
max_epoch: 300
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-S/finetune_scan_objbg.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 384,
depth: 12,
drop_path_rate: 0.1,
cls_dim: 15,
num_heads: 6,
group_size: 32,
num_group: 128,
encoder_dims: 384,
decoder_depth: 4,
}
npoints: 2048
total_bs: 32
step_per_update: 1
max_epoch: 300
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-S/finetune_scan_objonly.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,
others: { subset: "train" },
},
val:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,
others: { subset: "test" },
},
test:
{
_base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,
others: { subset: "test" },
},
}
model:
{
NAME: PointTransformer,
trans_dim: 384,
depth: 12,
drop_path_rate: 0.1,
cls_dim: 15,
num_heads: 6,
group_size: 32,
num_group: 128,
encoder_dims: 384,
decoder_depth: 4,
}
npoints: 2048
total_bs: 32
step_per_update: 1
max_epoch: 300
grad_norm_clip: 10
================================================
FILE: cfgs/PointGPT-S/pretrain.yaml
================================================
optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }
scheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 10 } }
dataset:
{
train:
{
_base_: cfgs/dataset_configs/ShapeNet-55.yaml,
others: { subset: "train", npoints: 1024 },
},
val:
{
_base_: cfgs/dataset_configs/ShapeNet-55.yaml,
others: { subset: "test", npoints: 1024 },
},
test:
{
_base_: cfgs/dataset_configs/ShapeNet-55.yaml,
others: { subset: "test", npoints: 1024 },
},
}
model:
{
NAME: PointGPT,
cls_dim: 40,
group_size: 32,
num_group: 64,
loss: cdl12,
weight_center: 1,
transformer_config:
{
mask_ratio: 0.7,
mask_type: "rand",
trans_dim: 384,
encoder_dims: 384,
depth: 12,
drop_path_rate: 0.1,
num_heads: 6,
decoder_depth: 4,
decoder_num_heads: 6,
},
}
npoints: 1024
total_bs: 64
step_per_update: 1
max_epoch: 300
================================================
FILE: cfgs/dataset_configs/LabeledHybrid.yaml
================================================
NAME: LabeledHybrid
DATA_PATH: data/HybridDatasets/post_pretrain
N_POINTS: 2048
PC_PATH: data/HybridDatasets
npoints: 1024
NUM_CATEGORY: 87
================================================
FILE: cfgs/dataset_configs/ModelNet40.yaml
================================================
NAME: ModelNet
DATA_PATH: data/ModelNet/modelnet40_normal_resampled
N_POINTS: 8192
NUM_CATEGORY: 40
USE_NORMALS: FALSE
================================================
FILE: cfgs/dataset_configs/ModelNet40FewShot.yaml
================================================
NAME: ModelNetFewShot
DATA_PATH: data/ModelNetFewshot
N_POINTS: 8192
NUM_CATEGORY: 40
USE_NORMALS: FALSE
================================================
FILE: cfgs/dataset_configs/ScanObjectNN_hardest.yaml
================================================
NAME: ScanObjectNN_hardest
ROOT: data/ScanObjectNN/h5_files/main_split
================================================
FILE: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml
================================================
NAME: ScanObjectNN
ROOT: data/ScanObjectNN/h5_files/main_split
================================================
FILE: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml
================================================
NAME: ScanObjectNN
ROOT: data/ScanObjectNN/h5_files/main_split_nobg
================================================
FILE: cfgs/dataset_configs/ShapeNet-55.yaml
================================================
NAME: ShapeNet
DATA_PATH: data/ShapeNet55-34/ShapeNet-55
N_POINTS: 8192
PC_PATH: data/ShapeNet55-34/shapenet_pc
================================================
FILE: cfgs/dataset_configs/UnlabeledHybrid.yaml
================================================
NAME: UnlabeledHybrid
DATA_PATH: data/HybridDatasets/pretrain
N_POINTS: 2048
PC_PATH: data/HybridDatasets
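Each experiment config above points at one of these dataset configs through a `_base_` key, with local keys (e.g. `others`) layered on top. A minimal sketch of that resolution step (the `resolve_config` helper and its `load_base` callback are hypothetical, not the repository's loader):

```python
def resolve_config(cfg, load_base):
    """Hypothetical sketch: expand a `_base_` reference into the loaded base
    mapping, letting keys defined locally override inherited ones."""
    if '_base_' not in cfg:
        return dict(cfg)
    base = dict(load_base(cfg['_base_']))          # e.g. parse the base YAML file
    base.update({k: v for k, v in cfg.items() if k != '_base_'})
    return base
```

With `load_base` returning the parsed `ShapeNet-55.yaml` mapping, a `train` entry would resolve to the base fields plus its own `others: {subset: train, npoints: 1024}`.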
================================================
FILE: datasets/LabeledHybrid.py
================================================
import os
import torch
import numpy as np
import torch.utils.data as data
from .io import IO
from .build import DATASETS
from utils.logger import *


@DATASETS.register_module()
class LabeledHybrid(data.Dataset):
    def __init__(self, config):
        self.data_root = config.DATA_PATH
        self.pc_path = config.PC_PATH
        self.subset = config.subset
        self.npoints = config.N_POINTS
        self.data_list_file = os.path.join(self.data_root, f'{self.subset}.txt')
        self.label_list_file = os.path.join(self.data_root, f'{self.subset}_num.txt')
        self.sample_points_num = config.npoints

        print_log(f'[DATASET] sample out {self.sample_points_num} points', logger='LabeledHybrid')
        print_log(f'[DATASET] Open file {self.data_list_file}', logger='LabeledHybrid')
        with open(self.data_list_file, 'r') as f:
            lines = f.readlines()
        print_log(f'[DATASET] Open file {self.label_list_file}', logger='LabeledHybrid')
        with open(self.label_list_file, 'r') as f:
            lines_label = f.readlines()

        self.file_list = []
        for line in lines:
            self.file_list.append(line.strip())
        print_log(f'[DATASET] {len(self.file_list)} instances were loaded', logger='LabeledHybrid')

        self.label_list = []
        for line_label in lines_label:
            self.label_list.append(np.array(int(line_label.strip())))
        print_log(f'[DATASET] {len(self.label_list)} labels were loaded', logger='LabeledHybrid')

    def pc_norm(self, pc):
        """ pc: NxC, return NxC """
        centroid = np.mean(pc, axis=0)
        pc = pc - centroid
        m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
        pc = pc / m
        return pc

    def random_sample(self, pc, num):
        permutation = np.arange(pc.shape[0])
        np.random.shuffle(permutation)
        pc = pc[permutation[:num]]
        return pc

    def __getitem__(self, idx):
        sample = self.file_list[idx]
        label = self.label_list[idx]
        data = IO.get(os.path.join(self.pc_path, sample)).astype(np.float32)
        data = self.random_sample(data, self.sample_points_num)
        data = self.pc_norm(data)
        data = torch.from_numpy(data).float()
        return 'LabeledHybrid', 'sample', (data, label)

    def __len__(self):
        return len(self.file_list)
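The `pc_norm` method above is repeated verbatim across several dataset classes. A standalone copy makes its invariant easy to check: the output is zero-centered and the farthest point lies exactly on the unit sphere.

```python
import numpy as np

def pc_norm(pc):
    # Center the cloud, then scale so the farthest point sits on the unit sphere
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    return pc / m
```

Normalizing every cloud this way removes absolute position and scale, which is why the same routine can serve heterogeneous sources (ShapeNet, the Hybrid sets) with one model.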
================================================
FILE: datasets/ModelNetDataset.py
================================================
'''
@author: Xu Yan
@file: ModelNet.py
@time: 2021/3/19 15:51
'''
import os
import numpy as np
import warnings
import pickle

from tqdm import tqdm
from torch.utils.data import Dataset
from .build import DATASETS
from utils.logger import *
import torch

warnings.filterwarnings('ignore')


def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc


def farthest_point_sample(point, npoint):
    """
    Input:
        point: pointcloud data, [N, D]
        npoint: number of samples
    Return:
        point: sampled pointcloud, [npoint, D]
    """
    N, D = point.shape
    xyz = point[:, :3]
    centroids = np.zeros((npoint,))
    distance = np.ones((N,)) * 1e10
    farthest = np.random.randint(0, N)
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)
    point = point[centroids.astype(np.int32)]
    return point


@DATASETS.register_module()
class ModelNet(Dataset):
    def __init__(self, config):
        self.root = config.DATA_PATH
        self.npoints = config.N_POINTS
        self.use_normals = config.USE_NORMALS
        self.num_category = config.NUM_CATEGORY
        self.process_data = True
        self.uniform = True
        split = config.subset
        self.subset = config.subset

        if self.num_category == 10:
            self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt')
        else:
            self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')

        self.cat = [line.rstrip() for line in open(self.catfile)]
        self.classes = dict(zip(self.cat, range(len(self.cat))))

        shape_ids = {}
        if self.num_category == 10:
            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))]
            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))]
        else:
            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]

        assert (split == 'train' or split == 'test')
        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
        self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt')
                         for i in range(len(shape_ids[split]))]
        print_log('The size of %s data is %d' % (split, len(self.datapath)), logger='ModelNet')

        if self.uniform:
            self.save_path = os.path.join(self.root, 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints))
        else:
            self.save_path = os.path.join(self.root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints))

        if self.process_data:
            if not os.path.exists(self.save_path):
                print_log('Processing data %s (only runs the first time)...' % self.save_path, logger='ModelNet')
                self.list_of_points = [None] * len(self.datapath)
                self.list_of_labels = [None] * len(self.datapath)
                for index in tqdm(range(len(self.datapath)), total=len(self.datapath)):
                    fn = self.datapath[index]
                    cls = self.classes[self.datapath[index][0]]
                    cls = np.array([cls]).astype(np.int32)
                    point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
                    if self.uniform:
                        point_set = farthest_point_sample(point_set, self.npoints)
                    else:
                        point_set = point_set[0:self.npoints, :]
                    self.list_of_points[index] = point_set
                    self.list_of_labels[index] = cls
                with open(self.save_path, 'wb') as f:
                    pickle.dump([self.list_of_points, self.list_of_labels], f)
            else:
                print_log('Load processed data from %s...' % self.save_path, logger='ModelNet')
                with open(self.save_path, 'rb') as f:
                    self.list_of_points, self.list_of_labels = pickle.load(f)

    def __len__(self):
        return len(self.datapath)

    def _get_item(self, index):
        if self.process_data:
            point_set, label = self.list_of_points[index], self.list_of_labels[index]
        else:
            fn = self.datapath[index]
            cls = self.classes[self.datapath[index][0]]
            label = np.array([cls]).astype(np.int32)
            point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
            if self.uniform:
                point_set = farthest_point_sample(point_set, self.npoints)
            else:
                point_set = point_set[0:self.npoints, :]
        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
        if not self.use_normals:
            point_set = point_set[:, 0:3]
        return point_set, label[0]

    def __getitem__(self, index):
        points, label = self._get_item(index)
        pt_idxs = np.arange(0, points.shape[0])
        if self.subset == 'train':
            np.random.shuffle(pt_idxs)
        current_points = points[pt_idxs].copy()
        current_points = torch.from_numpy(current_points).float()
        return 'ModelNet', 'sample', (current_points, label)
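`farthest_point_sample` above is the standard greedy FPS: keep, for every point, its squared distance to the nearest already-selected point, and repeatedly select the point maximizing that distance. The same logic restated compactly, with integer indices throughout:

```python
import numpy as np

def farthest_point_sample(point, npoint):
    # Greedy FPS: repeatedly pick the point farthest from all picks so far
    N = point.shape[0]
    xyz = point[:, :3]
    centroids = np.zeros((npoint,), dtype=np.int64)
    distance = np.full((N,), 1e10)          # distance to nearest selected point
    farthest = np.random.randint(0, N)      # random seed point
    for i in range(npoint):
        centroids[i] = farthest
        dist = np.sum((xyz - xyz[farthest]) ** 2, axis=-1)
        distance = np.minimum(distance, dist)
        farthest = int(np.argmax(distance))
    return point[centroids]
```

Because a selected point's stored distance drops to zero, it can never be selected twice, so the result is `npoint` distinct points with good spatial coverage.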
================================================
FILE: datasets/ModelNetDatasetFewShot.py
================================================
'''
@author: Xu Yan
@file: ModelNet.py
@time: 2021/3/19 15:51
'''
import os
import numpy as np
import warnings
import pickle

from tqdm import tqdm
from torch.utils.data import Dataset
from .build import DATASETS
from utils.logger import *
import torch
import random

warnings.filterwarnings('ignore')


def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc


@DATASETS.register_module()
class ModelNetFewShot(Dataset):
    def __init__(self, config):
        self.root = config.DATA_PATH
        self.npoints = config.N_POINTS
        self.use_normals = config.USE_NORMALS
        self.num_category = config.NUM_CATEGORY
        self.process_data = True
        self.uniform = True
        split = config.subset
        self.subset = config.subset

        self.way = config.way
        self.shot = config.shot
        self.fold = config.fold
        if self.way == -1 or self.shot == -1 or self.fold == -1:
            raise RuntimeError('way, shot, and fold must be set for few-shot runs')

        self.pickle_path = os.path.join(self.root, f'{self.way}way_{self.shot}shot', f'{self.fold}.pkl')
        print_log('Load processed data from %s...' % self.pickle_path, logger='ModelNetFewShot')
        with open(self.pickle_path, 'rb') as f:
            self.dataset = pickle.load(f)[self.subset]
        print_log('The size of %s data is %d' % (split, len(self.dataset)), logger='ModelNetFewShot')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        points, label, _ = self.dataset[index]
        points[:, 0:3] = pc_normalize(points[:, 0:3])
        if not self.use_normals:
            points = points[:, 0:3]
        pt_idxs = np.arange(0, points.shape[0])
        if self.subset == 'train':
            np.random.shuffle(pt_idxs)
        current_points = points[pt_idxs].copy()
        current_points = torch.from_numpy(current_points).float()
        return 'ModelNet', 'sample', (current_points, label)
================================================
FILE: datasets/ScanObjectNNDataset.py
================================================
import numpy as np
import os, sys, h5py
from torch.utils.data import Dataset
import torch
from .build import DATASETS
from utils.logger import *

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)


@DATASETS.register_module()
class ScanObjectNN(Dataset):
    def __init__(self, config, **kwargs):
        super().__init__()
        self.subset = config.subset
        self.root = config.ROOT

        if self.subset == 'train':
            h5 = h5py.File(os.path.join(self.root, 'training_objectdataset.h5'), 'r')
            self.points = np.array(h5['data']).astype(np.float32)
            self.labels = np.array(h5['label']).astype(int)
            h5.close()
        elif self.subset == 'test':
            h5 = h5py.File(os.path.join(self.root, 'test_objectdataset.h5'), 'r')
            self.points = np.array(h5['data']).astype(np.float32)
            self.labels = np.array(h5['label']).astype(int)
            h5.close()
        else:
            raise NotImplementedError()

        print(f'Successfully loaded ScanObjectNN with shape {self.points.shape}')

    def __getitem__(self, idx):
        pt_idxs = np.arange(0, self.points.shape[1])  # 2048 points per cloud
        if self.subset == 'train':
            np.random.shuffle(pt_idxs)
        current_points = self.points[idx, pt_idxs].copy()
        current_points = torch.from_numpy(current_points).float()
        label = self.labels[idx]
        return 'ScanObjectNN', 'sample', (current_points, label)

    def __len__(self):
        return self.points.shape[0]


@DATASETS.register_module()
class ScanObjectNN_hardest(Dataset):
    def __init__(self, config, **kwargs):
        super().__init__()
        self.subset = config.subset
        self.root = config.ROOT

        if self.subset == 'train':
            h5 = h5py.File(os.path.join(self.root, 'training_objectdataset_augmentedrot_scale75.h5'), 'r')
            self.points = np.array(h5['data']).astype(np.float32)
            self.labels = np.array(h5['label']).astype(int)
            h5.close()
        elif self.subset == 'test':
            h5 = h5py.File(os.path.join(self.root, 'test_objectdataset_augmentedrot_scale75.h5'), 'r')
            self.points = np.array(h5['data']).astype(np.float32)
            self.labels = np.array(h5['label']).astype(int)
            h5.close()
        else:
            raise NotImplementedError()

        print(f'Successfully loaded ScanObjectNN (hardest) with shape {self.points.shape}')

    def __getitem__(self, idx):
        pt_idxs = np.arange(0, self.points.shape[1])  # 2048 points per cloud
        if self.subset == 'train':
            np.random.shuffle(pt_idxs)
        current_points = self.points[idx, pt_idxs].copy()
        current_points = torch.from_numpy(current_points).float()
        label = self.labels[idx]
        return 'ScanObjectNN', 'sample', (current_points, label)

    def __len__(self):
        return self.points.shape[0]
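The only train-time augmentation both ScanObjectNN classes apply in `__getitem__` is a random permutation of point order. Point order carries no geometric information, so the operation is label-preserving; isolated as a plain function:

```python
import numpy as np

def shuffle_points(points, train=True):
    # Permute point order at train time (as in ScanObjectNN.__getitem__);
    # at test time the order is left untouched for determinism.
    idx = np.arange(points.shape[0])
    if train:
        np.random.shuffle(idx)
    return points[idx].copy()
```

The `.copy()` matters: fancy indexing already returns a new array, but copying defensively ensures later in-place augmentations never write back into the cached dataset arrays.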
================================================
FILE: datasets/ShapeNet55Dataset.py
================================================
import os
import torch
import numpy as np
import torch.utils.data as data
from .io import IO
from .build import DATASETS
from utils.logger import *


@DATASETS.register_module()
class ShapeNet(data.Dataset):
    def __init__(self, config):
        self.data_root = config.DATA_PATH
        self.pc_path = config.PC_PATH
        self.subset = config.subset
        self.npoints = config.N_POINTS
        self.data_list_file = os.path.join(self.data_root, f'{self.subset}.txt')
        test_data_list_file = os.path.join(self.data_root, 'test.txt')
        self.sample_points_num = config.npoints
        self.whole = config.get('whole')

        print_log(f'[DATASET] sample out {self.sample_points_num} points', logger='ShapeNet-55')
        print_log(f'[DATASET] Open file {self.data_list_file}', logger='ShapeNet-55')
        with open(self.data_list_file, 'r') as f:
            lines = f.readlines()
        if self.whole:
            with open(test_data_list_file, 'r') as f:
                test_lines = f.readlines()
            print_log(f'[DATASET] Open file {test_data_list_file}', logger='ShapeNet-55')
            lines = test_lines + lines

        self.file_list = []
        for line in lines:
            line = line.strip()
            taxonomy_id = line.split('-')[0]
            model_id = line.split('-')[1].split('.')[0]
            self.file_list.append({
                'taxonomy_id': taxonomy_id,
                'model_id': model_id,
                'file_path': line
            })
        print_log(f'[DATASET] {len(self.file_list)} instances were loaded', logger='ShapeNet-55')

        self.permutation = np.arange(self.npoints)

    def pc_norm(self, pc):
        """ pc: NxC, return NxC """
        centroid = np.mean(pc, axis=0)
        pc = pc - centroid
        m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
        pc = pc / m
        return pc

    def random_sample(self, pc, num):
        np.random.shuffle(self.permutation)
        pc = pc[self.permutation[:num]]
        return pc

    def __getitem__(self, idx):
        sample = self.file_list[idx]
        data = IO.get(os.path.join(self.pc_path, sample['file_path'])).astype(np.float32)
        data = self.random_sample(data, self.sample_points_num)
        data = self.pc_norm(data)
        data = torch.from_numpy(data).float()
        return sample['taxonomy_id'], sample['model_id'], data

    def __len__(self):
        return len(self.file_list)
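A subtlety: `ShapeNet.random_sample` shuffles a persistent index array sized to `N_POINTS` (8192), which assumes every stored cloud has exactly that many points, whereas the Hybrid datasets build the permutation per call from `pc.shape[0]`. A sketch of the per-call variant, which is safe for variable-size clouds:

```python
import numpy as np

def random_sample(pc, num):
    # Build the permutation from the actual cloud size, so any cloud with
    # at least `num` points can be subsampled without an out-of-range index.
    perm = np.random.permutation(pc.shape[0])
    return pc[perm[:num]]
```

The persistent-array version avoids reallocating the index buffer on every call, which is a reasonable micro-optimization when, as here, all inputs share one fixed size.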
================================================
FILE: datasets/UnlabeledHybrid.py
================================================
import os
import torch
import numpy as np
import torch.utils.data as data
from .io import IO
from .build import DATASETS
from utils.logger import *


@DATASETS.register_module()
class UnlabeledHybrid(data.Dataset):
    def __init__(self, config):
        self.data_root = config.DATA_PATH
        self.pc_path = config.PC_PATH
        self.subset = config.subset
        self.npoints = config.N_POINTS
        self.data_list_file = os.path.join(self.data_root, f'{self.subset}.txt')
        test_data_list_file = os.path.join(self.data_root, 'test.txt')
        self.sample_points_num = config.npoints
        self.whole = config.get('whole')

        print_log(f'[DATASET] sample out {self.sample_points_num} points', logger='UnlabeledHybrid')
        print_log(f'[DATASET] Open file {self.data_list_file}', logger='UnlabeledHybrid')
        with open(self.data_list_file, 'r') as f:
            lines = f.readlines()
        if self.whole:
            with open(test_data_list_file, 'r') as f:
                test_lines = f.readlines()
            print_log(f'[DATASET] Open file {test_data_list_file}', logger='UnlabeledHybrid')
            lines = test_lines + lines

        self.file_list = []
        for line in lines:
            line = line.strip()
            self.file_list.append({
                'taxonomy_id': '',
                'model_id': '',
                'file_path': line
            })
        print_log(f'[DATASET] {len(self.file_list)} instances were loaded', logger='UnlabeledHybrid')

        self.permutation = np.arange(self.npoints)

    def pc_norm(self, pc):
        """ pc: NxC, return NxC """
        centroid = np.mean(pc, axis=0)
        pc = pc - centroid
        m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
        pc = pc / m
        return pc

    def random_sample(self, pc, num):
        permutation = np.arange(pc.shape[0])
        np.random.shuffle(permutation)
        pc = pc[permutation[:num]]
        return pc

    def __getitem__(self, idx):
        sample = self.file_list[idx]
        data = IO.get(os.path.join(self.pc_path, sample['file_path'])).astype(np.float32)
        data = self.random_sample(data, self.sample_points_num)
        data = self.pc_norm(data)
        data = torch.from_numpy(data).float()
        # taxonomy_id and model_id are empty placeholders, kept only to match
        # the ShapeNet dataset's return signature; they are not used downstream
        return sample['taxonomy_id'], sample['model_id'], data

    def __len__(self):
        return len(self.file_list)
================================================
FILE: datasets/__init__.py
================================================
from .build import build_dataset_from_cfg
import datasets.ShapeNet55Dataset
import datasets.ModelNetDataset
import datasets.ModelNetDatasetFewShot
import datasets.ScanObjectNNDataset
import datasets.LabeledHybrid
import datasets.UnlabeledHybrid
================================================
FILE: datasets/build.py
================================================
from utils import registry

DATASETS = registry.Registry('dataset')


def build_dataset_from_cfg(cfg, default_args=None):
    """
    Build a dataset, defined by `dataset_name`.
    Args:
        cfg (eDICT):
    Returns:
        Dataset: a constructed dataset specified by dataset_name.
    """
    return DATASETS.build(cfg, default_args=default_args)
================================================
FILE: datasets/data_transforms.py
================================================
import numpy as np
import torch
import random


class PointcloudRotate(object):
    def __call__(self, pc):
        bsize = pc.size()[0]
        for i in range(bsize):
            rotation_angle = np.random.uniform() * 2 * np.pi
            cosval = np.cos(rotation_angle)
            sinval = np.sin(rotation_angle)
            # Rotation about the y (up) axis
            rotation_matrix = np.array([[cosval, 0, sinval],
                                        [0, 1, 0],
                                        [-sinval, 0, cosval]])
            R = torch.from_numpy(rotation_matrix.astype(np.float32)).to(pc.device)
            pc[i, :, :] = torch.matmul(pc[i], R)
        return pc


class PointcloudScaleAndTranslate(object):
    def __init__(self, scale_low=2. / 3., scale_high=3. / 2., translate_range=0.2):
        self.scale_low = scale_low
        self.scale_high = scale_high
        self.translate_range = translate_range

    def __call__(self, pc):
        bsize = pc.size()[0]
        for i in range(bsize):
            xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3])
            xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3])
            pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda()) + torch.from_numpy(xyz2).float().cuda()
        return pc


class PointcloudJitter(object):
    def __init__(self, std=0.01, clip=0.05):
        self.std, self.clip = std, clip

    def __call__(self, pc):
        bsize = pc.size()[0]
        for i in range(bsize):
            jittered_data = pc.new(pc.size(1), 3).normal_(
                mean=0.0, std=self.std
            ).clamp_(-self.clip, self.clip)
            pc[i, :, 0:3] += jittered_data
        return pc


class PointcloudScale(object):
    def __init__(self, scale_low=2. / 3., scale_high=3. / 2.):
        self.scale_low = scale_low
        self.scale_high = scale_high

    def __call__(self, pc):
        bsize = pc.size()[0]
        for i in range(bsize):
            xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3])
            pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda())
        return pc


class PointcloudTranslate(object):
    def __init__(self, translate_range=0.2):
        self.translate_range = translate_range

    def __call__(self, pc):
        bsize = pc.size()[0]
        for i in range(bsize):
            xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3])
            pc[i, :, 0:3] = pc[i, :, 0:3] + torch.from_numpy(xyz2).float().cuda()
        return pc


class PointcloudRandomInputDropout(object):
    def __init__(self, max_dropout_ratio=0.5):
        assert 0 <= max_dropout_ratio < 1
        self.max_dropout_ratio = max_dropout_ratio

    def __call__(self, pc):
        bsize = pc.size()[0]
        for i in range(bsize):
            dropout_ratio = np.random.random() * self.max_dropout_ratio  # in [0, max_dropout_ratio)
            drop_idx = np.where(np.random.random((pc.size()[1])) <= dropout_ratio)[0]
            if len(drop_idx) > 0:
                cur_pc = pc[i, :, :]
                cur_pc[drop_idx.tolist(), 0:3] = cur_pc[0, 0:3].repeat(len(drop_idx), 1)  # set dropped points to the first point
                pc[i, :, :] = cur_pc
        return pc


class RandomHorizontalFlip(object):
    def __init__(self, upright_axis='z', is_temporal=False):
        """
        upright_axis: axis index among x, y, z, i.e. 2 for z
        """
        self.is_temporal = is_temporal
        self.D = 4 if is_temporal else 3
        self.upright_axis = {'x': 0, 'y': 1, 'z': 2}[upright_axis.lower()]
        # Use the remaining axes for flipping.
        self.horz_axes = set(range(self.D)) - set([self.upright_axis])

    def __call__(self, coords):
        bsize = coords.size()[0]
        for i in range(bsize):
            if random.random() < 0.95:
                for curr_ax in self.horz_axes:
                    if random.random() < 0.5:
                        coord_max = torch.max(coords[i, :, curr_ax])
                        coords[i, :, curr_ax] = coord_max - coords[i, :, curr_ax]
        return coords
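Several of these transforms (`PointcloudScaleAndTranslate`, `PointcloudScale`, `PointcloudTranslate`) call `.cuda()` directly and therefore require a GPU. A NumPy restatement of the scale-and-translate augmentation on a single `(N, 3)` cloud, useful for checking behavior on CPU (function name and `rng` parameter are our additions):

```python
import numpy as np

def scale_and_translate(pc, scale_low=2/3, scale_high=3/2, translate_range=0.2, rng=None):
    # Per-cloud anisotropic scale plus random translation, mirroring
    # PointcloudScaleAndTranslate but in NumPy on one (N, 3) cloud.
    rng = np.random.default_rng() if rng is None else rng
    scale = rng.uniform(scale_low, scale_high, size=3)   # one scale factor per axis
    shift = rng.uniform(-translate_range, translate_range, size=3)
    return pc * scale + shift
```

Because the scale and shift are drawn once per cloud (not per point), the transform is affine and rigid structure within the cloud is preserved up to anisotropic stretching.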
================================================
FILE: datasets/generate_few_shot_data.py
================================================
import pickle
import numpy as np
import random
import os
root = '../data/ModelNet/modelnet40_normal_resampled'
target = '../data/ModelNetFewshot'
train_data_path = os.path.join(root, 'modelnet40_train_8192pts_fps.dat')
test_data_path = os.path.join(root, 'modelnet40_test_8192pts_fps.dat')
# train
with open(train_data_path, 'rb') as f:
train_list_of_points, train_list_of_labels = pickle.load(f)
with open(test_data_path, 'rb') as f:
test_list_of_points, test_list_of_labels = pickle.load(f)
# list_of_points = train_list_of_points + test_list_of_points
# list_of_labels = train_list_of_labels + test_list_of_labels
def generate_fewshot_data(way, shot, prefix_ind, eval_sample=20):
train_cls_dataset = {}
test_cls_dataset = {}
train_dataset = []
test_dataset = []
# build a dict containing different class
for point, label in zip(train_list_of_points, train_list_of_labels):
label = label[0]
if train_cls_dataset.get(label) is None:
train_cls_dataset[label] = []
train_cls_dataset[label].append(point)
# build a dict containing different class
for point, label in zip(test_list_of_points, test_list_of_labels):
label = label[0]
if test_cls_dataset.get(label) is None:
test_cls_dataset[label] = []
test_cls_dataset[label].append(point)
print(sum([train_cls_dataset[i].__len__() for i in range(40)]))
print(sum([test_cls_dataset[i].__len__() for i in range(40)]))
# import pdb; pdb.set_trace()
keys = list(train_cls_dataset.keys())
random.shuffle(keys)
for i, key in enumerate(keys[:way]):
train_data_list = train_cls_dataset[key]
random.shuffle(train_data_list)
assert len(train_data_list) > shot
for data in train_data_list[:shot]:
train_dataset.append((data, i, key))
test_data_list = test_cls_dataset[key]
random.shuffle(test_data_list)
# import pdb; pdb.set_trace()
assert len(test_data_list) >= eval_sample
for data in test_data_list[:eval_sample]:
test_dataset.append((data, i, key))
random.shuffle(train_dataset)
random.shuffle(test_dataset)
dataset = {
'train': train_dataset,
'test' : test_dataset
}
save_path = os.path.join(target, f'{way}way_{shot}shot')
if not os.path.exists(save_path):
os.makedirs(save_path)
with open(os.path.join(save_path, f'{prefix_ind}.pkl'), 'wb') as f:
pickle.dump(dataset, f)
if __name__ == '__main__':
ways = [5, 10]
shots = [10, 20]
for way in ways:
for shot in shots:
for i in range(10):
generate_fewshot_data(way = way, shot = shot, prefix_ind = i)
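Each generated split is a pickled dict with `'train'`/`'test'` lists of `(points, class_index, original_label)` tuples. A minimal round-trip sketch of that on-disk layout (self-contained: it uses a tiny synthetic entry and a temporary directory instead of the real `ModelNetFewshot` paths):

```python
import os
import pickle
import tempfile

import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    # One synthetic (points, class_index, original_label) entry per split.
    dataset = {
        'train': [(np.zeros((8192, 3)), 0, 12)],
        'test': [(np.ones((8192, 3)), 0, 12)],
    }
    save_path = os.path.join(tmp, '5way_10shot')
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, '0.pkl'), 'wb') as f:
        pickle.dump(dataset, f)

    # Loading mirrors what a few-shot dataset class would do.
    with open(os.path.join(save_path, '0.pkl'), 'rb') as f:
        loaded = pickle.load(f)
    points, cls_idx, label = loaded['train'][0]
    print(points.shape, cls_idx, label)  # (8192, 3) 0 12
```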
================================================
FILE: datasets/io.py
================================================
import h5py
import numpy as np
# import open3d
import os
class IO:
@classmethod
def get(cls, file_path):
_, file_extension = os.path.splitext(file_path)
if file_extension in ['.npy']:
return cls._read_npy(file_path)
# elif file_extension in ['.pcd']:
# return cls._read_pcd(file_path)
elif file_extension in ['.h5']:
return cls._read_h5(file_path)
elif file_extension in ['.txt']:
return cls._read_txt(file_path)
else:
raise Exception('Unsupported file extension: %s' % file_extension)
# References: https://github.com/numpy/numpy/blob/master/numpy/lib/format.py
@classmethod
def _read_npy(cls, file_path):
return np.load(file_path)
# References: https://github.com/dimatura/pypcd/blob/master/pypcd/pypcd.py#L275
# Support PCD files without compression ONLY!
# @classmethod
# def _read_pcd(cls, file_path):
# pc = open3d.io.read_point_cloud(file_path)
# ptcloud = np.array(pc.points)
# return ptcloud
@classmethod
def _read_txt(cls, file_path):
return np.loadtxt(file_path)
    @classmethod
    def _read_h5(cls, file_path):
        # Use a context manager so the file handle is closed after reading.
        with h5py.File(file_path, 'r') as f:
            return f['data'][()]
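The `.npy` branch of `IO.get` is the simplest to exercise; a self-contained sketch of the same extension dispatch (a stand-in function rather than the class above, so it avoids the h5py dependency and uses a temporary file instead of dataset paths):

```python
import os
import tempfile

import numpy as np

def read_points(file_path):
    """Minimal stand-in for IO.get covering only the .npy branch."""
    _, ext = os.path.splitext(file_path)
    if ext == '.npy':
        return np.load(file_path)
    raise Exception('Unsupported file extension: %s' % ext)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'cloud.npy')
    np.save(path, np.random.rand(1024, 3))
    pts = read_points(path)
    print(pts.shape)  # (1024, 3)
```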
================================================
FILE: extensions/chamfer_dist/__init__.py
================================================
# -*- coding: utf-8 -*-
# @Author: Thibault GROUEIX
# @Date: 2019-08-07 20:54:24
# @Last Modified by: Haozhe Xie
# @Last Modified time: 2019-12-18 15:06:25
# @Email: cshzxie@gmail.com
import torch
import chamfer
class ChamferFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
dist1, dist2, idx1, idx2 = chamfer.forward(xyz1, xyz2)
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
return dist1, dist2
@staticmethod
def backward(ctx, grad_dist1, grad_dist2):
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
grad_xyz1, grad_xyz2 = chamfer.backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2)
return grad_xyz1, grad_xyz2
class ChamferDistanceL2(torch.nn.Module):
    ''' Chamfer Distance L2
    '''
def __init__(self, ignore_zeros=False):
super().__init__()
self.ignore_zeros = ignore_zeros
def forward(self, xyz1, xyz2):
batch_size = xyz1.size(0)
if batch_size == 1 and self.ignore_zeros:
non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
return torch.mean(dist1) + torch.mean(dist2)
class ChamferDistanceL2_split(torch.nn.Module):
    ''' Chamfer Distance L2 (returns the two directional means separately)
    '''
def __init__(self, ignore_zeros=False):
super().__init__()
self.ignore_zeros = ignore_zeros
def forward(self, xyz1, xyz2):
batch_size = xyz1.size(0)
if batch_size == 1 and self.ignore_zeros:
non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
return torch.mean(dist1), torch.mean(dist2)
class ChamferDistanceL1(torch.nn.Module):
    ''' Chamfer Distance L1
    '''
def __init__(self, ignore_zeros=False):
super().__init__()
self.ignore_zeros = ignore_zeros
def forward(self, xyz1, xyz2):
batch_size = xyz1.size(0)
if batch_size == 1 and self.ignore_zeros:
non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
# import pdb
# pdb.set_trace()
dist1 = torch.sqrt(dist1)
dist2 = torch.sqrt(dist2)
return (torch.mean(dist1) + torch.mean(dist2))/2
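For sanity-checking the compiled extension, the same L2 Chamfer distance (mean squared nearest-neighbor distance in each direction, summed as in `ChamferDistanceL2.forward`) can be written as a slow O(n*m) NumPy reference; this is a CPU sketch for single clouds, not the batched CUDA path:

```python
import numpy as np

def chamfer_l2(xyz1, xyz2):
    """Reference Chamfer-L2 for single clouds xyz1 (n, 3), xyz2 (m, 3)."""
    # Pairwise squared distances, shape (n, m).
    d2 = ((xyz1[:, None, :] - xyz2[None, :, :]) ** 2).sum(-1)
    # Mean nearest-neighbor squared distance in each direction, summed.
    return d2.min(axis=1).mean() + d2.min(axis=0).mean()

x = np.array([[0., 0., 0.], [1., 0., 0.]])
y = np.array([[0., 0., 0.]])
print(chamfer_l2(x, y))  # 0.5
```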
================================================
FILE: extensions/chamfer_dist/chamfer.cu
================================================
/*
* @Author: Haozhe Xie
* @Date: 2019-08-07 20:54:24
* @Last Modified by: Haozhe Xie
* @Last Modified time: 2020-06-17 14:58:55
* @Email: cshzxie@gmail.com
*/
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
__global__ void chamfer_dist_kernel(int batch_size,
int n,
const float* xyz1,
int m,
const float* xyz2,
float* dist,
int* indexes) {
const int batch = 512;
__shared__ float buf[batch * 3];
for (int i = blockIdx.x; i < batch_size; i += gridDim.x) {
for (int k2 = 0; k2 < m; k2 += batch) {
int end_k = min(m, k2 + batch) - k2;
for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x) {
buf[j] = xyz2[(i * m + k2) * 3 + j];
}
__syncthreads();
for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n;
j += blockDim.x * gridDim.y) {
float x1 = xyz1[(i * n + j) * 3 + 0];
float y1 = xyz1[(i * n + j) * 3 + 1];
float z1 = xyz1[(i * n + j) * 3 + 2];
float best_dist = 0;
int best_dist_index = 0;
int end_ka = end_k - (end_k & 3);
if (end_ka == batch) {
for (int k = 0; k < batch; k += 4) {
{
float x2 = buf[k * 3 + 0] - x1;
float y2 = buf[k * 3 + 1] - y1;
float z2 = buf[k * 3 + 2] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (k == 0 || dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2;
}
}
{
float x2 = buf[k * 3 + 3] - x1;
float y2 = buf[k * 3 + 4] - y1;
float z2 = buf[k * 3 + 5] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2 + 1;
}
}
{
float x2 = buf[k * 3 + 6] - x1;
float y2 = buf[k * 3 + 7] - y1;
float z2 = buf[k * 3 + 8] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2 + 2;
}
}
{
float x2 = buf[k * 3 + 9] - x1;
float y2 = buf[k * 3 + 10] - y1;
float z2 = buf[k * 3 + 11] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2 + 3;
}
}
}
} else {
for (int k = 0; k < end_ka; k += 4) {
{
float x2 = buf[k * 3 + 0] - x1;
float y2 = buf[k * 3 + 1] - y1;
float z2 = buf[k * 3 + 2] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (k == 0 || dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2;
}
}
{
float x2 = buf[k * 3 + 3] - x1;
float y2 = buf[k * 3 + 4] - y1;
float z2 = buf[k * 3 + 5] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2 + 1;
}
}
{
float x2 = buf[k * 3 + 6] - x1;
float y2 = buf[k * 3 + 7] - y1;
float z2 = buf[k * 3 + 8] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2 + 2;
}
}
{
float x2 = buf[k * 3 + 9] - x1;
float y2 = buf[k * 3 + 10] - y1;
float z2 = buf[k * 3 + 11] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2 + 3;
}
}
}
}
for (int k = end_ka; k < end_k; k++) {
float x2 = buf[k * 3 + 0] - x1;
float y2 = buf[k * 3 + 1] - y1;
float z2 = buf[k * 3 + 2] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (k == 0 || dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2;
}
}
if (k2 == 0 || dist[(i * n + j)] > best_dist) {
dist[(i * n + j)] = best_dist;
indexes[(i * n + j)] = best_dist_index;
}
}
__syncthreads();
}
}
}
std::vector<torch::Tensor> chamfer_cuda_forward(torch::Tensor xyz1,
                                                torch::Tensor xyz2) {
const int batch_size = xyz1.size(0);
const int n = xyz1.size(1); // num_points point cloud A
const int m = xyz2.size(1); // num_points point cloud B
torch::Tensor dist1 =
torch::zeros({batch_size, n}, torch::CUDA(torch::kFloat));
torch::Tensor dist2 =
torch::zeros({batch_size, m}, torch::CUDA(torch::kFloat));
torch::Tensor idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt));
torch::Tensor idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt));
  chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(
      batch_size, n, xyz1.data_ptr<float>(), m, xyz2.data_ptr<float>(),
      dist1.data_ptr<float>(), idx1.data_ptr<int>());
  chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(
      batch_size, m, xyz2.data_ptr<float>(), n, xyz1.data_ptr<float>(),
      dist2.data_ptr<float>(), idx2.data_ptr<int>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("Error in chamfer_cuda_forward: %s\n", cudaGetErrorString(err));
}
return {dist1, dist2, idx1, idx2};
}
__global__ void chamfer_dist_grad_kernel(int b,
int n,
const float* xyz1,
int m,
const float* xyz2,
const float* grad_dist1,
const int* idx1,
float* grad_xyz1,
float* grad_xyz2) {
for (int i = blockIdx.x; i < b; i += gridDim.x) {
for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n;
j += blockDim.x * gridDim.y) {
float x1 = xyz1[(i * n + j) * 3 + 0];
float y1 = xyz1[(i * n + j) * 3 + 1];
float z1 = xyz1[(i * n + j) * 3 + 2];
int j2 = idx1[i * n + j];
float x2 = xyz2[(i * m + j2) * 3 + 0];
float y2 = xyz2[(i * m + j2) * 3 + 1];
float z2 = xyz2[(i * m + j2) * 3 + 2];
float g = grad_dist1[i * n + j] * 2;
atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2));
atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2));
atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2));
atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2)));
atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2)));
atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2)));
}
}
}
std::vector<torch::Tensor> chamfer_cuda_backward(torch::Tensor xyz1,
                                                 torch::Tensor xyz2,
                                                 torch::Tensor idx1,
                                                 torch::Tensor idx2,
                                                 torch::Tensor grad_dist1,
                                                 torch::Tensor grad_dist2) {
const int batch_size = xyz1.size(0);
const int n = xyz1.size(1); // num_points point cloud A
const int m = xyz2.size(1); // num_points point cloud B
torch::Tensor grad_xyz1 = torch::zeros_like(xyz1, torch::CUDA(torch::kFloat));
torch::Tensor grad_xyz2 = torch::zeros_like(xyz2, torch::CUDA(torch::kFloat));
  chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>(
      batch_size, n, xyz1.data_ptr<float>(), m, xyz2.data_ptr<float>(),
      grad_dist1.data_ptr<float>(), idx1.data_ptr<int>(),
      grad_xyz1.data_ptr<float>(), grad_xyz2.data_ptr<float>());
  chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>(
      batch_size, m, xyz2.data_ptr<float>(), n, xyz1.data_ptr<float>(),
      grad_dist2.data_ptr<float>(), idx2.data_ptr<int>(),
      grad_xyz2.data_ptr<float>(), grad_xyz1.data_ptr<float>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("Error in chamfer_cuda_backward: %s\n", cudaGetErrorString(err));
}
return {grad_xyz1, grad_xyz2};
}
================================================
FILE: extensions/chamfer_dist/chamfer_cuda.cpp
================================================
/*
* @Author: Haozhe Xie
* @Date: 2019-08-07 20:54:24
* @Last Modified by: Haozhe Xie
* @Last Modified time: 2019-12-10 10:33:50
* @Email: cshzxie@gmail.com
*/
#include <torch/extension.h>
#include <vector>
std::vector<torch::Tensor> chamfer_cuda_forward(torch::Tensor xyz1,
                                                torch::Tensor xyz2);
std::vector<torch::Tensor> chamfer_cuda_backward(torch::Tensor xyz1,
                                                 torch::Tensor xyz2,
                                                 torch::Tensor idx1,
                                                 torch::Tensor idx2,
                                                 torch::Tensor grad_dist1,
                                                 torch::Tensor grad_dist2);
std::vector<torch::Tensor> chamfer_forward(torch::Tensor xyz1,
                                           torch::Tensor xyz2) {
return chamfer_cuda_forward(xyz1, xyz2);
}
std::vector<torch::Tensor> chamfer_backward(torch::Tensor xyz1,
                                            torch::Tensor xyz2,
                                            torch::Tensor idx1,
                                            torch::Tensor idx2,
                                            torch::Tensor grad_dist1,
                                            torch::Tensor grad_dist2) {
return chamfer_cuda_backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &chamfer_forward, "Chamfer forward (CUDA)");
m.def("backward", &chamfer_backward, "Chamfer backward (CUDA)");
}
================================================
FILE: extensions/chamfer_dist/setup.py
================================================
# -*- coding: utf-8 -*-
# @Author: Haozhe Xie
# @Date: 2019-08-07 20:54:24
# @Last Modified by: Haozhe Xie
# @Last Modified time: 2019-12-10 10:04:25
# @Email: cshzxie@gmail.com
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(name='chamfer',
version='2.0.0',
ext_modules=[
CUDAExtension('chamfer', [
'chamfer_cuda.cpp',
'chamfer.cu',
]),
],
cmdclass={'build_ext': BuildExtension})
================================================
FILE: extensions/chamfer_dist/test.py
================================================
# -*- coding: utf-8 -*-
# @Author: Haozhe Xie
# @Date: 2019-12-10 10:38:01
# @Last Modified by: Haozhe Xie
# @Last Modified time: 2019-12-26 14:21:36
# @Email: cshzxie@gmail.com
#
# Note:
# - Replace float -> double, kFloat -> kDouble in chamfer.cu
import os
import sys
import torch
import unittest
from torch.autograd import gradcheck
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)))
from extensions.chamfer_dist import ChamferFunction
class ChamferDistanceTestCase(unittest.TestCase):
def test_chamfer_dist(self):
x = torch.rand(4, 64, 3).double()
y = torch.rand(4, 128, 3).double()
x.requires_grad = True
y.requires_grad = True
print(gradcheck(ChamferFunction.apply, [x.cuda(), y.cuda()]))
if __name__ == '__main__':
    unittest.main()
================================================
FILE: extensions/emd/README.md
================================================
# PyTorch Wrapper for Point-cloud Earth-Mover-Distance (EMD)
## Dependency
The code has been tested on Ubuntu 16.04, PyTorch 1.1.0, CUDA 9.0.
## Usage
First compile using
python setup.py install
Then, copy the lib file out to the main directory,
cp build/lib.linux-x86_64-3.6/emd_cuda.cpython-36m-x86_64-linux-gnu.so .
Then, you can use it by simply

    from emd import earth_mover_distance
    d = earth_mover_distance()(p1, p2, transpose=False)  # p1: B x N1 x 3, p2: B x N2 x 3

Note that in this repository `earth_mover_distance` is an `nn.Module`, so it is instantiated before being called.
Check `test_emd_loss.py` for example.
## Author
The cuda code is originally written by Haoqiang Fan. The PyTorch wrapper is written by Kaichun Mo. Also, Jiayuan Gu provided helps.
## License
MIT
================================================
FILE: extensions/emd/__init__.py
================================================
from .emd import earth_mover_distance as emd
__all__ = ['emd']
================================================
FILE: extensions/emd/cuda/emd.cpp
================================================
#ifndef _EMD
#define _EMD
#include <torch/extension.h>
#include <vector>
//CUDA declarations
at::Tensor ApproxMatchForward(
const at::Tensor xyz1,
const at::Tensor xyz2);
at::Tensor MatchCostForward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match);
std::vector<at::Tensor> MatchCostBackward(
    const at::Tensor grad_cost,
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor match);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("approxmatch_forward", &ApproxMatchForward,"ApproxMatch forward (CUDA)");
m.def("matchcost_forward", &MatchCostForward,"MatchCost forward (CUDA)");
m.def("matchcost_backward", &MatchCostBackward,"MatchCost backward (CUDA)");
}
#endif
================================================
FILE: extensions/emd/cuda/emd_kernel.cu
================================================
/**********************************
* Original Author: Haoqiang Fan
* Modified by: Kaichun Mo
*********************************/
#ifndef _EMD_KERNEL
#define _EMD_KERNEL
#include <cstdio>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>  // at::cuda::getApplyGrid
// #include <THC/THC.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
/********************************
* Forward kernel for approxmatch
*********************************/
template <typename scalar_t>
__global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,scalar_t * __restrict__ match,scalar_t * temp){
scalar_t * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
scalar_t multiL,multiR;
if (n>=m){
multiL=1;
multiR=n/m;
}else{
multiL=m/n;
multiR=1;
}
const int Block=1024;
__shared__ scalar_t buf[Block*4];
for (int i=blockIdx.x;i=-2;j--){
scalar_t level=-powf(4.0f,j);
if (j==-2){
level=0;
}
for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,temp);
//}
/* ApproxMatch forward interface
Input:
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
Output:
match: (B, N2, N1)
*/
at::Tensor ApproxMatchForward(
const at::Tensor xyz1,
const at::Tensor xyz2){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto match = at::zeros({b, m, n}, xyz1.type());
auto temp = at::zeros({b, (n+m)*2}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
    approxmatch<<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
}));
AT_CUDA_CHECK(cudaGetLastError());
return match;
}
/********************************
* Forward kernel for matchcost
*********************************/
template <typename scalar_t>
__global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ out){
__shared__ scalar_t allsum[512];
const int Block=1024;
__shared__ scalar_t buf[Block*3];
for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,out);
//}
/* MatchCost forward interface
Input:
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
match: (B, N2, N1)
Output:
cost: (B)
*/
at::Tensor MatchCostForward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto cost = at::zeros({b}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
    matchcost<<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
}));
AT_CUDA_CHECK(cudaGetLastError());
return cost;
}
/********************************
* matchcostgrad2 kernel
*********************************/
template <typename scalar_t>
__global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad2){
__shared__ scalar_t sum_grad[256*3];
for (int i=blockIdx.x;i
__global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad1){
for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad1);
// matchcostgrad2<<>>(b,n,m,xyz1,xyz2,match,grad2);
//}
/* MatchCost backward interface
Input:
grad_cost: (B) # gradients on cost
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
match: (B, N2, N1)
Output:
grad1: (B, N1, 3)
grad2: (B, N2, 3)
*/
std::vector<at::Tensor> MatchCostBackward(
    const at::Tensor grad_cost,
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor match){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto grad1 = at::zeros({b, n, 3}, xyz1.type());
auto grad2 = at::zeros({b, m, 3}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&] {
    matchcostgrad1<<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
    matchcostgrad2<<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
}));
AT_CUDA_CHECK(cudaGetLastError());
  return std::vector<at::Tensor>({grad1, grad2});
}
#endif
================================================
FILE: extensions/emd/emd.py
================================================
import torch
import emd_cuda
class EarthMoverDistanceFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
xyz1 = xyz1.contiguous()
xyz2 = xyz2.contiguous()
assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
match = emd_cuda.approxmatch_forward(xyz1, xyz2)
cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
ctx.save_for_backward(xyz1, xyz2, match)
return cost
@staticmethod
def backward(ctx, grad_cost):
xyz1, xyz2, match = ctx.saved_tensors
grad_cost = grad_cost.contiguous()
grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)
return grad_xyz1, grad_xyz2
class earth_mover_distance(torch.nn.Module):
    ''' Earth Mover's Distance (approximate), wrapped as a module.
    '''
def __init__(self):
super().__init__()
    def forward(self, xyz1, xyz2, transpose=False):
        """Earth Mover Distance (Approx)
        Args:
            xyz1 (torch.Tensor): (b, n1, 3)
            xyz2 (torch.Tensor): (b, n2, 3)
            transpose (bool): whether to transpose inputs as it might be BCN format.
                Extensions only support BNC format.
        Returns:
            cost (torch.Tensor): scalar, per-batch cost normalized by n1, then averaged
        """
        if transpose:
            xyz1 = xyz1.transpose(1, 2)
            xyz2 = xyz2.transpose(1, 2)
        cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
        cost = cost / xyz1.size(1)
        return cost.mean()
# def earth_mover_distance(xyz1, xyz2, transpose=True):
# """Earth Mover Distance (Approx)
# Args:
# xyz1 (torch.Tensor): (b, 3, n1)
# xyz2 (torch.Tensor): (b, 3, n1)
# transpose (bool): whether to transpose inputs as it might be BCN format.
# Extensions only support BNC format.
# Returns:
# cost (torch.Tensor): (b)
# """
# if xyz1.dim() == 2:
# xyz1 = xyz1.unsqueeze(0)
# if xyz2.dim() == 2:
# xyz2 = xyz2.unsqueeze(0)
# if transpose:
# xyz1 = xyz1.transpose(1, 2)
# xyz2 = xyz2.transpose(1, 2)
# cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
# return cost
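The CUDA kernel computes an approximate matching; for tiny clouds, the exact minimum-cost matching it approximates can be brute-forced by enumerating permutations. A NumPy reference sketch (assumptions: equal-sized clouds and squared Euclidean ground cost, as in the `matchcost` kernel), using the same point pairs as `test_emd_loss.py`:

```python
import itertools

import numpy as np

def exact_match_cost(xyz1, xyz2):
    """Exact minimum-cost perfect matching for equal-sized clouds (n, 3),
    with squared Euclidean ground cost. O(n!) -- only for tiny n."""
    n = len(xyz1)
    best = float('inf')
    for perm in itertools.permutations(range(n)):
        cost = sum(((xyz1[i] - xyz2[perm[i]]) ** 2).sum() for i in range(n))
        best = min(best, cost)
    return best

p1 = np.array([[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]])
p2 = np.array([[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]])
# The optimal matching here is the cross pairing, as in test_emd_loss.py.
print(exact_match_cost(p1, p2))
```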
================================================
FILE: extensions/emd/setup.py
================================================
"""Setup extension
Notes:
If extra_compile_args is provided, you need to provide different instances for different extensions.
Refer to https://github.com/pytorch/pytorch/issues/20169
"""
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='emd_ext',
ext_modules=[
CUDAExtension(
name='emd_cuda',
sources=[
'cuda/emd.cpp',
'cuda/emd_kernel.cu',
],
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
),
],
cmdclass={
'build_ext': BuildExtension
})
================================================
FILE: extensions/emd/test_emd_loss.py
================================================
import torch
import numpy as np
import time
from emd import earth_mover_distance
# gt
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
p1 = p1.repeat(3, 1, 1)
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
p2 = p2.repeat(3, 1, 1)
print(p1)
print(p2)
print(p1.shape)
p1.requires_grad = True
p2.requires_grad = True
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
print('gt_dist: ', gt_dist)
gt_dist.backward()
print(p1.grad)
print(p2.grad)
# emd
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
p1 = p1.repeat(3, 1, 1)
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
p2 = p2.repeat(3, 1, 1)
print(p1)
print(p2)
p1.requires_grad = True
p2.requires_grad = True
d = earth_mover_distance(p1, p2, transpose=False)
print(d)
loss = d[0] / 2 + d[1] * 2 + d[2] / 3
print(loss)
loss.backward()
print(p1.grad)
print(p2.grad)
================================================
FILE: figures/a
================================================
================================================
FILE: main.py
================================================
from tools import pretrain_run_net as pretrain
from tools import finetune_run_net as finetune
from tools import test_run_net as test_net
from utils import parser, dist_utils, misc
from utils.logger import *
from utils.config import *
import time
import os
import torch
from tensorboardX import SummaryWriter
from torchstat import stat
def main():
# args
args = parser.get_args()
# CUDA
args.use_gpu = torch.cuda.is_available()
if args.use_gpu:
torch.backends.cudnn.benchmark = True
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
args.distributed = False
else:
args.distributed = True
dist_utils.init_dist(args.launcher)
# re-set gpu_ids with distributed training mode
_, world_size = dist_utils.get_dist_info()
args.world_size = world_size
# logger
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = os.path.join(args.experiment_path, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, name=args.log_name)
# define the tensorboard writer
if not args.test:
if args.local_rank == 0:
train_writer = SummaryWriter(
os.path.join(args.tfboard_path, 'train'))
val_writer = SummaryWriter(os.path.join(args.tfboard_path, 'test'))
else:
train_writer = None
val_writer = None
# config
config = get_config(args, logger=logger)
# batch size
if args.distributed:
assert config.total_bs % world_size == 0
config.dataset.train.others.bs = config.total_bs // world_size
if config.dataset.get('extra_train'):
config.dataset.extra_train.others.bs = config.total_bs // world_size * 2
config.dataset.val.others.bs = config.total_bs // world_size * 2
if config.dataset.get('test'):
config.dataset.test.others.bs = config.total_bs // world_size
else:
config.dataset.train.others.bs = config.total_bs
if config.dataset.get('extra_train'):
config.dataset.extra_train.others.bs = config.total_bs * 2
config.dataset.val.others.bs = config.total_bs * 2
if config.dataset.get('test'):
config.dataset.test.others.bs = config.total_bs
# log
log_args_to_file(args, 'args', logger=logger)
log_config_to_file(config, 'config', logger=logger)
# exit()
logger.info(f'Distributed training: {args.distributed}')
# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, '
f'deterministic: {args.deterministic}')
# seed + rank, for augmentation
misc.set_random_seed(args.seed + args.local_rank,
deterministic=args.deterministic)
if args.distributed:
assert args.local_rank == torch.distributed.get_rank()
if args.shot != -1:
config.dataset.train.others.shot = args.shot
config.dataset.train.others.way = args.way
config.dataset.train.others.fold = args.fold
config.dataset.val.others.shot = args.shot
config.dataset.val.others.way = args.way
config.dataset.val.others.fold = args.fold
# run
if args.test:
test_net(args, config)
else:
if args.finetune_model or args.scratch_model:
finetune(args, config, train_writer, val_writer)
else:
pretrain(args, config, train_writer, val_writer)
if __name__ == '__main__':
main()
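The batch-size bookkeeping in `main()` splits `total_bs` evenly across ranks, with validation (and `extra_train`) running at twice the per-rank batch. That arithmetic can be sketched in isolation (a minimal stand-in mirroring the branches above, not the real config object):

```python
def split_batch_size(total_bs, world_size, distributed):
    """Mirror of main.py's per-rank batch-size logic."""
    if distributed:
        # total batch must divide evenly across ranks
        assert total_bs % world_size == 0
        train_bs = total_bs // world_size
    else:
        train_bs = total_bs
    val_bs = train_bs * 2  # val / extra_train use a doubled batch
    return train_bs, val_bs

print(split_batch_size(128, 4, True))   # (32, 64)
print(split_batch_size(128, 1, False))  # (128, 256)
```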
================================================
FILE: main_vis.py
================================================
# from tools import run_net
from tools import test_net
from utils import parser, dist_utils, misc
from utils.logger import *
from utils.config import *
import time
import os
import torch
from tensorboardX import SummaryWriter
def main():
# args
args = parser.get_args()
# CUDA
args.use_gpu = torch.cuda.is_available()
if args.use_gpu:
torch.backends.cudnn.benchmark = True
# init distributed env first, since logger depends on the dist info.
if args.launcher == 'none':
args.distributed = False
else:
args.distributed = True
dist_utils.init_dist(args.launcher)
# re-set gpu_ids with distributed training mode
_, world_size = dist_utils.get_dist_info()
args.world_size = world_size
# logger
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = os.path.join(args.experiment_path, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, name=args.log_name)
# define the tensorboard writer
if not args.test:
if args.local_rank == 0:
train_writer = SummaryWriter(os.path.join(args.tfboard_path, 'train'))
val_writer = SummaryWriter(os.path.join(args.tfboard_path, 'test'))
else:
train_writer = None
val_writer = None
# config
config = get_config(args, logger = logger)
# batch size
if args.distributed:
assert config.total_bs % world_size == 0
config.dataset.train.others.bs = config.total_bs // world_size
config.dataset.val.others.bs = 1
config.dataset.test.others.bs = 1
else:
config.dataset.train.others.bs = config.total_bs
config.dataset.val.others.bs = 1
config.dataset.test.others.bs = 1
# log
log_args_to_file(args, 'args', logger = logger)
log_config_to_file(config, 'config', logger = logger)
# exit()
logger.info(f'Distributed training: {args.distributed}')
# set random seeds
if args.seed is not None:
logger.info(f'Set random seed to {args.seed}, '
f'deterministic: {args.deterministic}')
misc.set_random_seed(args.seed + args.local_rank, deterministic=args.deterministic) # seed + rank, for augmentation
if args.distributed:
assert args.local_rank == torch.distributed.get_rank()
# run
if args.test:
test_net(args, config)
else:
# run_net(args, config, train_writer, val_writer)
raise NotImplementedError
if __name__ == '__main__':
main()
================================================
FILE: models/GPT.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
class Block(nn.Module):
def __init__(self, embed_dim, num_heads):
super(Block, self).__init__()
self.ln_1 = nn.LayerNorm(embed_dim)
self.ln_2 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, embed_dim * 4),
nn.GELU(),
nn.Linear(embed_dim * 4, embed_dim),
)
def forward(self, x, attn_mask):
x = self.ln_1(x)
a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
x = x + a
m = self.mlp(self.ln_2(x))
x = x + m
return x
class GPT_extractor(nn.Module):
def __init__(
self, embed_dim, num_heads, num_layers, num_classes, trans_dim, group_size, pretrained=False
):
super(GPT_extractor, self).__init__()
self.embed_dim = embed_dim
self.trans_dim = trans_dim
self.group_size = group_size
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(Block(embed_dim, num_heads))
self.ln_f = nn.LayerNorm(embed_dim)
# prediction head
self.increase_dim = nn.Sequential(
nn.Conv1d(self.trans_dim, 3*(self.group_size), 1)
)
        if not pretrained:
self.cls_head_finetune = nn.Sequential(
nn.Linear(self.trans_dim * 2, 256),
nn.BatchNorm1d(256),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(256, 256),
nn.BatchNorm1d(256),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(256, num_classes)
)
self.cls_norm = nn.LayerNorm(self.trans_dim)
def forward(self, h, pos, attn_mask, classify=False):
"""
Expect input as shape [sequence len, batch]
If classify, return classification logits
"""
batch, length, C = h.shape
h = h.transpose(0, 1)
pos = pos.transpose(0, 1)
# prepend sos token
sos = torch.ones(1, batch, self.embed_dim, device=h.device) * self.sos
if not classify:
h = torch.cat([sos, h[:-1, :, :]], axis=0)
else:
h = torch.cat([sos, h], axis=0)
# transformer
for layer in self.layers:
h = layer(h + pos, attn_mask)
h = self.ln_f(h)
encoded_points = h.transpose(0, 1)
if not classify:
return encoded_points
h = h.transpose(0, 1)
h = self.cls_norm(h)
concat_f = torch.cat([h[:, 1], h[:, 2:].max(1)[0]], dim=-1)
ret = self.cls_head_finetune(concat_f)
return ret, encoded_points
class GPT_generator(nn.Module):
def __init__(
self, embed_dim, num_heads, num_layers, trans_dim, group_size
):
super(GPT_generator, self).__init__()
self.embed_dim = embed_dim
self.trans_dim = trans_dim
self.group_size = group_size
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(Block(embed_dim, num_heads))
self.ln_f = nn.LayerNorm(embed_dim)
self.increase_dim = nn.Sequential(
nn.Conv1d(self.trans_dim, 3*(self.group_size), 1)
)
def forward(self, h, pos, attn_mask):
"""
Expect input as shape [sequence len, batch]
If classify, return classification logits
"""
batch, length, C = h.shape
h = h.transpose(0, 1)
pos = pos.transpose(0, 1)
# transformer
for layer in self.layers:
h = layer(h + pos, attn_mask)
h = self.ln_f(h)
rebuild_points = self.increase_dim(h.transpose(1, 2)).transpose(
1, 2).transpose(0, 1).reshape(batch * length, -1, 3)
return rebuild_points
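The extractor and generator above both rely on a causal (upper-triangular) attention mask, and the extractor additionally shifts the sequence right by one after prepending the SOS token so each position predicts the next patch. A minimal NumPy sketch of those two ideas (function names are illustrative, not from this repo):

```python
import numpy as np

def causal_mask(seq_len):
    # True marks positions that must NOT be attended to:
    # everything strictly above the diagonal, i.e. future tokens.
    return np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)

def shift_right(tokens, sos):
    # Prepend the SOS token and drop the last token, so position i
    # is asked to predict token i from tokens < i (teacher forcing).
    return np.concatenate([sos[None, :], tokens[:-1]], axis=0)
```

This mirrors `torch.triu(attn_mask, diagonal=1)` and the `torch.cat([sos, h[:-1]])` shift in `GPT_extractor.forward`.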
================================================
FILE: models/PointGPT.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from timm.models.layers import DropPath, trunc_normal_
import numpy as np
from .build import MODELS
from utils import misc
from utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from utils.logger import *
import random
from knn_cuda import KNN
from extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2
from models.GPT import GPT_extractor, GPT_generator
import math
from models.z_order import *
class Encoder_large(nn.Module): # Embedding module
def __init__(self, encoder_channel):
super().__init__()
self.encoder_channel = encoder_channel
self.first_conv = nn.Sequential(
nn.Conv1d(3, 256, 1),
nn.BatchNorm1d(256),
nn.ReLU(inplace=True),
nn.Conv1d(256, 512, 1),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Conv1d(512, 1024, 1)
)
self.second_conv = nn.Sequential(
nn.Conv1d(2048, 2048, 1),
nn.BatchNorm1d(2048),
nn.ReLU(inplace=True),
nn.Conv1d(2048, self.encoder_channel, 1)
)
def forward(self, point_groups):
'''
point_groups : B G N 3
-----------------
feature_global : B G C
'''
bs, g, n, _ = point_groups.shape
point_groups = point_groups.reshape(bs * g, n, 3)
# encoder
feature = self.first_conv(point_groups.transpose(2, 1)) # BG 256 n
feature_global = torch.max(feature, dim=2, keepdim=True)[0] # BG 256 1
feature = torch.cat(
[feature_global.expand(-1, -1, n), feature], dim=1) # BG 512 n
feature = self.second_conv(feature) # BG 1024 n
feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024
return feature_global.reshape(bs, g, self.encoder_channel)
class Encoder_small(nn.Module): # Embedding module
def __init__(self, encoder_channel):
super().__init__()
self.encoder_channel = encoder_channel
self.first_conv = nn.Sequential(
nn.Conv1d(3, 128, 1),
nn.BatchNorm1d(128),
nn.ReLU(inplace=True),
nn.Conv1d(128, 256, 1)
)
self.second_conv = nn.Sequential(
nn.Conv1d(512, 512, 1),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Conv1d(512, self.encoder_channel, 1)
)
def forward(self, point_groups):
'''
point_groups : B G N 3
-----------------
feature_global : B G C
'''
bs, g, n, _ = point_groups.shape
point_groups = point_groups.reshape(bs * g, n, 3)
# encoder
feature = self.first_conv(point_groups.transpose(2, 1))
feature_global = torch.max(feature, dim=2, keepdim=True)[0]
feature = torch.cat(
[feature_global.expand(-1, -1, n), feature], dim=1)
feature = self.second_conv(feature)
feature_global = torch.max(feature, dim=2, keepdim=False)[0]
return feature_global.reshape(bs, g, self.encoder_channel)
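`Encoder_large` and `Encoder_small` implement the same two-stage PointNet-style pattern: a per-point MLP, a max-pool to a global feature, concatenation of that global feature back onto every point, a second MLP, and a final max-pool. A NumPy sketch of that skeleton for a single group (plain linear maps stand in for the conv+BN+ReLU stacks; all names are illustrative):

```python
import numpy as np

def pointnet_embed(points, w1, w2):
    # points: (n, 3); w1: (3, d); w2: (2*d, c)
    feat = np.maximum(points @ w1, 0.0)            # per-point MLP -> (n, d)
    glob = feat.max(axis=0, keepdims=True)         # max-pool -> (1, d)
    feat = np.concatenate(                         # broadcast global feature
        [np.repeat(glob, len(points), axis=0), feat], axis=1)   # (n, 2d)
    feat = np.maximum(feat @ w2, 0.0)              # second MLP -> (n, c)
    return feat.max(axis=0)                        # group embedding -> (c,)
```

Because only symmetric max-pools aggregate across points, the embedding is invariant to the ordering of points within a group.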
class Group(nn.Module):
def __init__(self, num_group, group_size):
super().__init__()
self.num_group = num_group
self.group_size = group_size
self.knn = KNN(k=self.group_size, transpose_mode=True)
self.knn_2 = KNN(k=1, transpose_mode=True)
def simplied_morton_sorting(self, xyz, center):
'''
        Simplified alternative to Morton-code sorting: iteratively take the patch nearest to the previously selected patch as the next patch. We found this to be more efficient.
'''
batch_size, num_points, _ = xyz.shape
distances_batch = torch.cdist(center, center)
        distances_batch[:, torch.eye(self.num_group, device=xyz.device).bool()] = float("inf")
idx_base = torch.arange(
0, batch_size, device=xyz.device) * self.num_group
sorted_indices_list = []
sorted_indices_list.append(idx_base)
distances_batch = distances_batch.view(batch_size, self.num_group, self.num_group).transpose(
1, 2).contiguous().view(batch_size * self.num_group, self.num_group)
distances_batch[idx_base] = float("inf")
distances_batch = distances_batch.view(
batch_size, self.num_group, self.num_group).transpose(1, 2).contiguous()
for i in range(self.num_group - 1):
distances_batch = distances_batch.view(
batch_size * self.num_group, self.num_group)
distances_to_last_batch = distances_batch[sorted_indices_list[-1]]
closest_point_idx = torch.argmin(distances_to_last_batch, dim=-1)
closest_point_idx = closest_point_idx + idx_base
sorted_indices_list.append(closest_point_idx)
distances_batch = distances_batch.view(batch_size, self.num_group, self.num_group).transpose(
1, 2).contiguous().view(batch_size * self.num_group, self.num_group)
distances_batch[closest_point_idx] = float("inf")
distances_batch = distances_batch.view(
batch_size, self.num_group, self.num_group).transpose(1, 2).contiguous()
sorted_indices = torch.stack(sorted_indices_list, dim=-1)
sorted_indices = sorted_indices.view(-1)
return sorted_indices
def morton_sorting(self, xyz, center):
batch_size, num_points, _ = xyz.shape
all_indices = []
for index in range(batch_size):
points = center[index]
z = get_z_values(points.cpu().numpy())
idxs = np.zeros((self.num_group), dtype=np.int32)
temp = np.arange(self.num_group)
z_ind = np.argsort(z[temp])
idxs = temp[z_ind]
all_indices.append(idxs)
all_indices = torch.tensor(all_indices, device=xyz.device)
idx_base = torch.arange(
0, batch_size, device=xyz.device).view(-1, 1) * self.num_group
        sorted_indices = all_indices + idx_base
        sorted_indices = sorted_indices.view(-1)
        return sorted_indices
def forward(self, xyz):
'''
input: B N 3
---------------------------
output: B G M 3
center : B G 3
'''
batch_size, num_points, _ = xyz.shape
# fps the centers out
center = misc.fps(xyz, self.num_group) # B G 3
# knn to get the neighborhood
_, idx = self.knn(xyz, center) # B G M
assert idx.size(1) == self.num_group
assert idx.size(2) == self.group_size
idx_base = torch.arange(
0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
idx = idx + idx_base
idx = idx.view(-1)
neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]
neighborhood = neighborhood.view(
batch_size, self.num_group, self.group_size, 3).contiguous()
# normalize
neighborhood = neighborhood - center.unsqueeze(2)
        # Morton-code ordering can be used instead by calling morton_sorting here
sorted_indices = self.simplied_morton_sorting(xyz, center)
neighborhood = neighborhood.view(
batch_size * self.num_group, self.group_size, 3)[sorted_indices, :, :]
neighborhood = neighborhood.view(
batch_size, self.num_group, self.group_size, 3).contiguous()
center = center.view(
batch_size * self.num_group, 3)[sorted_indices, :]
center = center.view(
batch_size, self.num_group, 3).contiguous()
return neighborhood, center
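`simplied_morton_sorting` greedily orders patch centers: start from the first center, then repeatedly jump to the nearest not-yet-visited center. A single-cloud NumPy sketch of the same greedy rule (the batched version in the class vectorizes this over the batch; names are illustrative):

```python
import numpy as np

def greedy_order(centers):
    # centers: (g, 3) -> visiting order as an index array of length g
    d = np.linalg.norm(centers[:, None, :] - centers[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)       # a patch is never its own neighbor
    order = [0]                       # start from the first center
    d[:, 0] = np.inf                  # mark it visited
    for _ in range(len(centers) - 1):
        nxt = int(np.argmin(d[order[-1]]))  # nearest unvisited center
        order.append(nxt)
        d[:, nxt] = np.inf            # mark it visited
    return np.array(order)
```

On three collinear centers at x = 0, 3, 1 the walk goes 0 → 2 → 1, hopping to the closest remaining patch each step.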
# Transformers
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
self.num_heads).permute(2, 0, 3, 1, 4)
# make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
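The `Attention` module above is standard multi-head scaled dot-product attention. The core single-head computation, sketched in NumPy (illustrative only; the qkv projection, head splitting, and dropout are omitted):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # numerically stable
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v, mask=None):
    # q, k, v: (n, d); mask: (n, n) bool, True = blocked position
    scores = (q @ k.T) / np.sqrt(q.shape[-1])   # scaled dot products
    if mask is not None:
        scores = np.where(mask, -np.inf, scores)
    return softmax(scores) @ v                  # convex combination of values
```

With a causal mask, the first token can attend only to itself, so its output equals its own value vector.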
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PositionEmbeddingCoordsSine(nn.Module):
"""Similar to transformer's position encoding, but generalizes it to
arbitrary dimensions and continuous coordinates.
Args:
n_dim: Number of input dimensions, e.g. 2 for image coordinates.
d_model: Number of dimensions to encode into
temperature:
scale:
"""
def __init__(self, n_dim: int = 1, d_model: int = 256, temperature=10000, scale=None):
super().__init__()
self.n_dim = n_dim
self.num_pos_feats = d_model // n_dim // 2 * 2
self.temperature = temperature
self.padding = d_model - self.num_pos_feats * self.n_dim
if scale is None:
scale = 1.0
self.scale = scale * 2 * math.pi
def forward(self, xyz: torch.Tensor) -> torch.Tensor:
"""
Args:
xyz: Point positions (*, d_in)
Returns:
pos_emb (*, d_out)
"""
assert xyz.shape[-1] == self.n_dim
dim_t = torch.arange(self.num_pos_feats,
dtype=torch.float32, device=xyz.device)
dim_t = self.temperature ** (2 * torch.div(dim_t,
2, rounding_mode='trunc') / self.num_pos_feats)
xyz = xyz * self.scale
pos_divided = xyz.unsqueeze(-1) / dim_t
pos_sin = pos_divided[..., 0::2].sin()
pos_cos = pos_divided[..., 1::2].cos()
        pos_emb = torch.stack([pos_sin, pos_cos], dim=-1).reshape(*xyz.shape[:-1], -1)
# Pad unused dimensions with zeros
pos_emb = F.pad(pos_emb, (0, self.padding))
return pos_emb
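`PositionEmbeddingCoordsSine` interleaves sines and cosines over a geometric range of frequencies, per input coordinate dimension, then zero-pads up to `d_model`. A compact NumPy sketch for a single coordinate dimension (names are illustrative; `num_feats` is assumed even, matching the `// 2 * 2` rounding above):

```python
import numpy as np

def sine_pos_embed(coords, num_feats, temperature=10000.0):
    # coords: (n,) scaled positions -> (n, num_feats) embedding
    i = np.arange(num_feats)
    freqs = temperature ** (2 * (i // 2) / num_feats)  # paired frequencies
    ang = coords[:, None] / freqs[None, :]
    emb = np.empty((len(coords), num_feats))
    emb[:, 0::2] = np.sin(ang[:, 0::2])  # even slots: sine
    emb[:, 1::2] = np.cos(ang[:, 1::2])  # odd slots: cosine
    return emb
```

At coordinate 0 the even slots are sin(0) = 0 and the odd slots are cos(0) = 1, which is a quick sanity check on the interleave.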
class GPT_Transformer(nn.Module):
def __init__(self, config, **kwargs):
super().__init__()
self.config = config
# define the transformer argparse
self.mask_ratio = config.transformer_config.mask_ratio
self.trans_dim = config.transformer_config.trans_dim
self.depth = config.transformer_config.depth
self.decoder_depth = config.transformer_config.decoder_depth
self.drop_path_rate = config.transformer_config.drop_path_rate
self.num_heads = config.transformer_config.num_heads
self.group_size = config.group_size
print_log(f'[args] {config.transformer_config}', logger='Transformer')
self.encoder_dims = config.transformer_config.encoder_dims
assert self.encoder_dims in [384, 768, 1024]
if self.encoder_dims == 384:
self.encoder = Encoder_small(encoder_channel=self.encoder_dims)
else:
self.encoder = Encoder_large(encoder_channel=self.encoder_dims)
self.pos_embed = PositionEmbeddingCoordsSine(3, self.encoder_dims, 1.0)
self.blocks = GPT_extractor(
embed_dim=self.encoder_dims,
num_heads=self.num_heads,
num_layers=self.depth,
num_classes=config.cls_dim,
trans_dim=self.trans_dim,
group_size=self.group_size,
pretrained=True,
)
self.generator_blocks = GPT_generator(
embed_dim=self.encoder_dims,
num_heads=self.num_heads,
num_layers=self.decoder_depth,
trans_dim=self.trans_dim,
group_size=self.group_size
)
# do not perform additional mask on the first (self.keep_attend) tokens
self.keep_attend = 10
self.num_groups = config.num_group
self.num_mask = int(
(self.num_groups - self.keep_attend) * self.mask_ratio)
self.sos_pos = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
self.norm = nn.LayerNorm(self.trans_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv1d):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, neighborhood, center, noaug=False, classify=False):
# generate mask
group_input_tokens = self.encoder(neighborhood) # B G C
batch_size, seq_len, C = group_input_tokens.size()
relative_position = center[:, 1:, :] - center[:, :-1, :]
relative_norm = torch.norm(relative_position, dim=-1, keepdim=True)
relative_direction = relative_position / relative_norm
position = torch.cat(
[center[:, 0, :].unsqueeze(1), relative_direction], dim=1)
pos_relative = self.pos_embed(position)
sos_pos = self.sos_pos.expand(group_input_tokens.size(0), -1, -1)
pos_absolute = self.pos_embed(center[:, :-1, :])
pos_absolute = torch.cat([sos_pos, pos_absolute], dim=1)
attn_mask = torch.full(
(seq_len, seq_len), -float("Inf"), device=group_input_tokens.device, dtype=group_input_tokens.dtype
).to(torch.bool)
with torch.no_grad():
attn_mask = torch.triu(attn_mask, diagonal=1)
# point wise
# overall_mask = np.zeros([self.num_groups, self.num_groups])
# for i in range(self.num_groups):
# mask = np.hstack([
# np.zeros(self.num_groups-self.num_mask),
# np.ones(self.num_mask),
# ])
# np.random.shuffle(mask)
# overall_mask[i, :] = mask
# overall_mask = torch.from_numpy(
# overall_mask).to(torch.bool).to('cuda')
# column wise
overall_mask = np.hstack([
np.zeros(self.num_groups-self.keep_attend-self.num_mask),
np.ones(self.num_mask),
])
np.random.shuffle(overall_mask)
overall_mask = np.hstack([
np.zeros(self.keep_attend),
overall_mask,
])
        overall_mask = torch.from_numpy(
            overall_mask).to(torch.bool).to(group_input_tokens.device)
        eye_mask = torch.eye(self.num_groups, device=group_input_tokens.device).to(torch.bool)
        attn_mask = attn_mask | (overall_mask.unsqueeze(0) & ~eye_mask)
# transformer
        if not classify:
encoded_features = self.blocks(
group_input_tokens, pos_absolute, attn_mask, classify=classify)
generated_points = self.generator_blocks(
encoded_features, pos_relative, attn_mask)
return generated_points
else:
            print('----error---- this code path is unused ----error----')
logits, generated_points = self.blocks(
group_input_tokens, pos_absolute, classify=classify)
return logits, generated_points
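The forward pass above composes two masks: a causal (upper-triangular) mask, and a random "column-wise" mask that blocks a fixed fraction of token columns beyond the first `keep_attend` positions for every query, except on the diagonal so each token can still see itself. A NumPy sketch of that composition (illustrative; `rng` is a `numpy.random.Generator`):

```python
import numpy as np

def combined_mask(num_groups, keep_attend, num_mask, rng):
    # causal part: block strictly-future positions
    causal = np.triu(np.ones((num_groups, num_groups), dtype=bool), k=1)
    # column-wise part: randomly mark num_mask columns past keep_attend
    cols = np.concatenate([
        np.zeros(num_groups - keep_attend - num_mask),
        np.ones(num_mask),
    ])
    rng.shuffle(cols)
    cols = np.concatenate([np.zeros(keep_attend), cols]).astype(bool)
    eye = np.eye(num_groups, dtype=bool)
    # blocked columns apply everywhere except on the diagonal
    return causal | (cols[None, :] & ~eye)
```

The diagonal stays unmasked, so even a "masked" token attends to itself; the first `keep_attend` columns are only ever blocked by the causal part.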
@MODELS.register_module()
class PointGPT(nn.Module):
def __init__(self, config):
super().__init__()
        print_log('[PointGPT]', logger='PointGPT')
self.config = config
self.trans_dim = config.transformer_config.trans_dim
self.GPT_Transformer = GPT_Transformer(config)
self.group_size = config.group_size
self.num_group = config.num_group
self.drop_path_rate = config.transformer_config.drop_path_rate
self.weight_center = config.weight_center
print_log(
f'[PointGPT] divide point cloud into G{self.num_group} x S{self.group_size} points ...', logger='PointGPT')
self.group_divider = Group(
num_group=self.num_group, group_size=self.group_size)
self.loss = config.loss
self.build_loss_func(self.loss)
def build_loss_func(self, loss_type):
if loss_type == "cdl1":
self.loss_func_p = ChamferDistanceL1().cuda()
elif loss_type == 'cdl2':
self.loss_func_p = ChamferDistanceL2().cuda()
elif loss_type == 'cdl12':
self.loss_func_p1 = ChamferDistanceL1().cuda()
self.loss_func_p2 = ChamferDistanceL2().cuda()
else:
raise NotImplementedError
self.loss_func_c = nn.MSELoss().cuda()
def forward(self, pts, vis=False, **kwargs):
neighborhood, center = self.group_divider(pts)
B = neighborhood.shape[0]
generated_points = self.GPT_Transformer(
neighborhood, center)
gt_points = neighborhood.reshape(
B*(self.num_group), self.group_size, 3)
loss1 = self.loss_func_p1(generated_points, gt_points)
loss2 = self.loss_func_p2(generated_points, gt_points)
if vis: # visualization
gt_points = gt_points.reshape(
B, self.num_group, self.group_size, 3)
gt_points = (gt_points + center.unsqueeze(-2)
).reshape(-1, 3).unsqueeze(0)
generated_points = generated_points.reshape(
B, self.num_group, self.group_size, 3) + center.unsqueeze(-2)
generated_points = generated_points.reshape(-1, 3).unsqueeze(0)
return generated_points, gt_points, center
return loss1 + loss2
@MODELS.register_module()
class PointTransformer(nn.Module):
def __init__(self, config, **kwargs):
super().__init__()
self.config = config
self.trans_dim = config.trans_dim
self.depth = config.depth
self.decoder_depth = config.decoder_depth
self.drop_path_rate = config.drop_path_rate
self.cls_dim = config.cls_dim
self.num_heads = config.num_heads
self.group_size = config.group_size
self.num_group = config.num_group
self.encoder_dims = config.encoder_dims
self.group_divider = Group(
num_group=self.num_group, group_size=self.group_size)
assert self.encoder_dims in [384, 768, 1024]
if self.encoder_dims == 384:
self.encoder = Encoder_small(encoder_channel=self.encoder_dims)
else:
self.encoder = Encoder_large(encoder_channel=self.encoder_dims)
self.pos_embed = PositionEmbeddingCoordsSine(3, self.encoder_dims, 1.0)
self.blocks = GPT_extractor(
embed_dim=self.encoder_dims,
num_heads=self.num_heads,
num_layers=self.depth,
num_classes=config.cls_dim,
trans_dim=self.trans_dim,
group_size=self.group_size
)
self.generator_blocks = GPT_generator(
embed_dim=self.encoder_dims,
num_heads=self.num_heads,
num_layers=self.decoder_depth,
trans_dim=self.trans_dim,
group_size=self.group_size
)
        self.norm = nn.LayerNorm(self.trans_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
        self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
        self.sos_pos = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
self.build_loss_func()
trunc_normal_(self.cls_token, std=.02)
trunc_normal_(self.cls_pos, std=.02)
    def build_loss_func(self, loss_type='cdl12'):
if loss_type == "cdl1":
self.loss_func_p = ChamferDistanceL1().cuda()
elif loss_type == 'cdl2':
self.loss_func_p = ChamferDistanceL2().cuda()
elif loss_type == 'cdl12':
self.loss_func_p1 = ChamferDistanceL1().cuda()
self.loss_func_p2 = ChamferDistanceL2().cuda()
else:
raise NotImplementedError
self.loss_ce = nn.CrossEntropyLoss()
def get_loss_acc(self, ret, gt):
loss = self.loss_ce(ret, gt.long())
pred = ret.argmax(-1)
acc = (pred == gt).sum() / float(gt.size(0))
return loss, acc * 100
def load_model_from_ckpt(self, bert_ckpt_path):
if bert_ckpt_path is not None:
ckpt = torch.load(bert_ckpt_path)
base_ckpt = {k.replace("module.", ""): v for k,
v in ckpt['base_model'].items()}
for k in list(base_ckpt.keys()):
if k.startswith('GPT_Transformer'):
base_ckpt[k[len('GPT_Transformer.'):]] = base_ckpt[k]
del base_ckpt[k]
elif k.startswith('base_model'):
base_ckpt[k[len('base_model.'):]] = base_ckpt[k]
del base_ckpt[k]
if 'cls_head_finetune' in k:
del base_ckpt[k]
incompatible = self.load_state_dict(base_ckpt, strict=False)
if incompatible.missing_keys:
print_log('missing_keys', logger='Transformer')
print_log(
get_missing_parameters_message(incompatible.missing_keys),
logger='Transformer'
)
if incompatible.unexpected_keys:
print_log('unexpected_keys', logger='Transformer')
print_log(
get_unexpected_parameters_message(
incompatible.unexpected_keys),
logger='Transformer'
)
print_log(
f'[Transformer] Successful Loading the ckpt from {bert_ckpt_path}', logger='Transformer')
else:
print_log('Training from scratch!!!', logger='Transformer')
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv1d):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, pts):
neighborhood, center = self.group_divider(pts)
group_input_tokens = self.encoder(neighborhood) # B G N
B, L, _ = group_input_tokens.shape
cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
pos = self.pos_embed(center)
sos_pos = self.sos_pos.expand(group_input_tokens.size(0), -1, -1)
pos = torch.cat([sos_pos, pos], dim=1)
relative_position = center[:, 1:, :] - center[:, :-1, :]
relative_norm = torch.norm(relative_position, dim=-1, keepdim=True)
relative_direction = relative_position / relative_norm
position = torch.cat(
[center[:, 0, :].unsqueeze(1), relative_direction], dim=1)
pos_relative = self.pos_embed(position)
x = torch.cat((cls_tokens, group_input_tokens), dim=1)
pos = torch.cat((cls_pos, pos), dim=1)
attn_mask = torch.full(
(L+2, L+2), -float("Inf"), device=group_input_tokens.device, dtype=group_input_tokens.dtype
).to(torch.bool)
attn_mask = torch.triu(attn_mask, diagonal=1)
# transformer
ret, encoded_features = self.blocks(x, pos, attn_mask, classify=True)
encoded_features = torch.cat(
[encoded_features[:, 0, :].unsqueeze(1), encoded_features[:, 2:-1, :]], dim=1)
attn_mask = torch.full(
(L, L), -float("Inf"), device=group_input_tokens.device, dtype=group_input_tokens.dtype
).to(torch.bool)
attn_mask = torch.triu(attn_mask, diagonal=1)
generated_points = self.generator_blocks(
encoded_features, pos_relative, attn_mask)
neighborhood = neighborhood + center.unsqueeze(2)
gt_points = neighborhood.reshape(
B*(self.num_group), self.group_size, 3)
loss1 = self.loss_func_p1(generated_points, gt_points)
loss2 = self.loss_func_p2(generated_points, gt_points)
return ret, loss1 + loss2
================================================
FILE: models/__init__.py
================================================
from .build import build_model_from_cfg
import models.PointGPT
================================================
FILE: models/build.py
================================================
from utils import registry
MODELS = registry.Registry('models')
def build_model_from_cfg(cfg, **kwargs):
"""
Build a dataset, defined by `dataset_name`.
Args:
cfg (eDICT):
Returns:
Dataset: a constructed dataset specified by dataset_name.
"""
return MODELS.build(cfg, **kwargs)
================================================
FILE: models/z_order.py
================================================
import numpy as np
def round_to_int_32(data):
"""
Takes a Numpy array of float values between
-1 and 1, and rounds them to significant
32-bit integer values, to be used in the
morton code computation
:param data: multidimensional numpy array
:return: same as data but in 32-bit int format
"""
# first we rescale points to 0-512
min_data = np.abs(np.min(data)-0.5)
data = 256*(data + min_data)
# now convert to int
data = np.round(2 ** 21 - data).astype(dtype=np.int32)
return data
def split_by_3(x):
"""
Method to separate bits of a 32-bit integer
by 3 positions apart, using the magic bits
https://www.forceflow.be/2013/10/07/morton-encodingdecoding-through-bit-interleaving-implementations/
:param x: 32-bit integer
:return: x with bits separated
"""
# we only look at 21 bits, since we want to generate
# a 64-bit code eventually (3 x 21 bits = 63 bits, which
# is the maximum we can fit in a 64-bit code)
x &= 0x1fffff # only take first 21 bits
# shift left 32 bits, OR with self, and 00011111000000000000000000000000000000001111111111111111
x = (x | (x << 32)) & 0x1f00000000ffff
# shift left 16 bits, OR with self, and 00011111000000000000000011111111000000000000000011111111
x = (x | (x << 16)) & 0x1f0000ff0000ff
# shift left 8 bits, OR with self, and 0001000000001111000000001111000000001111000000001111000000000000
x = (x | (x << 8)) & 0x100f00f00f00f00f
# shift left 4 bits, OR with self, and 0001000011000011000011000011000011000011000011000011000100000000
x = (x | (x << 4)) & 0x10c30c30c30c30c3
# shift left 2 bits, OR with self, and 0001001001001001001001001001001001001001001001001001001001001001
x = (x | (x << 2)) & 0x1249249249249249
return x
def get_z_order(x, y, z):
"""
Given 3 arrays of corresponding x, y, z
coordinates, compute the morton (or z) code for
each point and return an index array
We compute the Morton order as follows:
1- Split all coordinates by 3 (add 2 zeros between bits)
2- Shift bits left by 1 for y and 2 for z
3- Interleave x, shifted y, and shifted z
    The Morton order is the final interleaved bit sequence
:param x: x coordinates
:param y: y coordinates
:param z: z coordinates
:return: index array with morton code
"""
res = 0
res |= split_by_3(x) | split_by_3(y) << 1 | split_by_3(z) << 2
return res
def get_z_values(data):
"""
Computes the z values for a point array
:param data: Nx3 array of x, y, and z location
:return: Nx1 array of z values
"""
points_round = round_to_int_32(data) # convert to int
z = get_z_order(points_round[:, 0], points_round[:, 1], points_round[:, 2])
return z
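The magic-bits implementation above can be checked against a naive version: `split_by_3` spreads the low 21 bits of an integer two zero-bits apart, and `get_z_order` interleaves x, y, z. A small pure-Python sketch of the same interleave for verification (illustrative, not the repo's vectorized version):

```python
def interleave3(x, y, z, bits=21):
    # Place bit i of x at position 3i, of y at 3i+1, of z at 3i+2:
    # the resulting integer is the Morton (z-order) code.
    code = 0
    for i in range(bits):
        code |= ((x >> i) & 1) << (3 * i)
        code |= ((y >> i) & 1) << (3 * i + 1)
        code |= ((z >> i) & 1) << (3 * i + 2)
    return code
```

For example, a set bit in x alone yields codes 1, 8, 64, ... (powers of 8), while y and z shift those by one and two bit positions respectively.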
================================================
FILE: requirements.txt
================================================
easydict
h5py
matplotlib
numpy
open3d==0.9
opencv-python
pyyaml
scipy
tensorboardX
timm==0.4.5
tqdm
transforms3d
termcolor
================================================
FILE: segmentation/__init__.py
================================================
================================================
FILE: segmentation/dataset.py
================================================
import numpy as np
import os
from torch.utils.data import Dataset
import torch
from pointnet_util import farthest_point_sample, pc_normalize
import json
class ModelNetDataLoader(Dataset):
def __init__(self, root, npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000):
self.root = root
self.npoints = npoint
self.uniform = uniform
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat))))
self.normal_channel = normal_channel
shape_ids = {}
shape_ids['train'] = [line.rstrip() for line in open(
os.path.join(self.root, 'modelnet40_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(
os.path.join(self.root, 'modelnet40_test.txt'))]
assert (split == 'train' or split == 'test')
shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
# list of (shape_name, shape_txt_file_path) tuple
self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
in range(len(shape_ids[split]))]
print('The size of %s data is %d' % (split, len(self.datapath)))
self.cache_size = cache_size # how many data points to cache in memory
self.cache = {} # from index to (point_set, cls) tuple
def __len__(self):
return len(self.datapath)
def _get_item(self, index):
if index in self.cache:
point_set, cls = self.cache[index]
else:
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
if self.uniform:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints, :]
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
if not self.normal_channel:
point_set = point_set[:, 0:3]
if len(self.cache) < self.cache_size:
self.cache[index] = (point_set, cls)
return point_set, cls
def __getitem__(self, index):
return self._get_item(index)
class PartNormalDataset(Dataset):
def __init__(self, root='/data/cgy/ShapenetPart/shapenetcore_partanno_segmentation_benchmark_v0_normal', npoints=2500, split='train', class_choice=None, normal_channel=False):
self.npoints = npoints
self.root = root
self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
self.cat = {}
self.normal_channel = normal_channel
with open(self.catfile, 'r') as f:
for line in f:
ls = line.strip().split()
self.cat[ls[0]] = ls[1]
self.cat = {k: v for k, v in self.cat.items()}
self.classes_original = dict(zip(self.cat, range(len(self.cat))))
        if class_choice is not None:
self.cat = {k: v for k, v in self.cat.items() if k in class_choice}
# print(self.cat)
self.meta = {}
with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
for item in self.cat:
# print('category', item)
self.meta[item] = []
dir_point = os.path.join(self.root, self.cat[item])
fns = sorted(os.listdir(dir_point))
# print(fns[0][0:-4])
if split == 'trainval':
fns = [fn for fn in fns if (
(fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
elif split == 'train':
fns = [fn for fn in fns if fn[0:-4] in train_ids]
elif split == 'val':
fns = [fn for fn in fns if fn[0:-4] in val_ids]
elif split == 'test':
fns = [fn for fn in fns if fn[0:-4] in test_ids]
else:
print('Unknown split: %s. Exiting..' % (split))
exit(-1)
# print(os.path.basename(fns))
for fn in fns:
token = (os.path.splitext(os.path.basename(fn))[0])
self.meta[item].append(os.path.join(dir_point, token + '.txt'))
self.datapath = []
for item in self.cat:
for fn in self.meta[item]:
self.datapath.append((item, fn))
self.classes = {}
for i in self.cat.keys():
self.classes[i] = self.classes_original[i]
# Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels
self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
# for cat in sorted(self.seg_classes.keys()):
# print(cat, self.seg_classes[cat])
self.cache = {} # from index to (point_set, cls, seg) tuple
self.cache_size = 20000
def __getitem__(self, index):
if index in self.cache:
point_set, cls, seg = self.cache[index]
else:
fn = self.datapath[index]
cat = self.datapath[index][0]
cls = self.classes[cat]
cls = np.array([cls]).astype(np.int32)
data = np.loadtxt(fn[1]).astype(np.float32)
if not self.normal_channel:
point_set = data[:, 0:3]
else:
point_set = data[:, 0:6]
seg = data[:, -1].astype(np.int32)
if len(self.cache) < self.cache_size:
self.cache[index] = (point_set, cls, seg)
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
choice = np.random.choice(len(seg), self.npoints, replace=True)
# resample
point_set = point_set[choice, :]
seg = seg[choice]
return point_set, cls, seg
def __len__(self):
return len(self.datapath)
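Both dataset classes call `pc_normalize`, imported from `pointnet_util`, to center a cloud and scale it into the unit sphere. A sketch of the conventional PointNet utility (not copied from this repo, but the standard behavior the datasets assume):

```python
import numpy as np

def pc_normalize(pc):
    # pc: (n, 3) -> centered at the origin with max radius 1
    centroid = pc.mean(axis=0)
    pc = pc - centroid
    scale = np.max(np.sqrt((pc ** 2).sum(axis=1)))
    return pc / scale
```

After normalization the centroid is at the origin and the farthest point lies exactly on the unit sphere, which keeps scales comparable across shapes.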
if __name__ == '__main__':
data = ModelNetDataLoader('modelnet40_normal_resampled/',
split='train', uniform=False, normal_channel=True)
DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)
for point, label in DataLoader:
print(point.shape)
print(label.shape)
================================================
FILE: segmentation/extensions/chamfer_dist/__init__.py
================================================
# -*- coding: utf-8 -*-
# @Author: Thibault GROUEIX
# @Date: 2019-08-07 20:54:24
# @Last Modified by: Haozhe Xie
# @Last Modified time: 2019-12-18 15:06:25
# @Email: cshzxie@gmail.com
import torch
import chamfer
class ChamferFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz1, xyz2):
dist1, dist2, idx1, idx2 = chamfer.forward(xyz1, xyz2)
ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
return dist1, dist2
@staticmethod
def backward(ctx, grad_dist1, grad_dist2):
xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
grad_xyz1, grad_xyz2 = chamfer.backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2)
return grad_xyz1, grad_xyz2
class ChamferDistanceL2(torch.nn.Module):
    ''' Chamfer Distance L2
    '''
def __init__(self, ignore_zeros=False):
super().__init__()
self.ignore_zeros = ignore_zeros
def forward(self, xyz1, xyz2):
batch_size = xyz1.size(0)
if batch_size == 1 and self.ignore_zeros:
non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
return torch.mean(dist1) + torch.mean(dist2)
class ChamferDistanceL2_split(torch.nn.Module):
    ''' Chamfer Distance L2, returning both directions separately '''

    def __init__(self, ignore_zeros=False):
        super().__init__()
        self.ignore_zeros = ignore_zeros

    def forward(self, xyz1, xyz2):
        batch_size = xyz1.size(0)
        if batch_size == 1 and self.ignore_zeros:
            non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
            non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
        return torch.mean(dist1), torch.mean(dist2)
class ChamferDistanceL1(torch.nn.Module):
    ''' Chamfer Distance L1 '''

    def __init__(self, ignore_zeros=False):
        super().__init__()
        self.ignore_zeros = ignore_zeros

    def forward(self, xyz1, xyz2):
        batch_size = xyz1.size(0)
        if batch_size == 1 and self.ignore_zeros:
            non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
            non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
        # take the square root so each term is a Euclidean (L1-style) distance
        dist1 = torch.sqrt(dist1)
        dist2 = torch.sqrt(dist2)
        return (torch.mean(dist1) + torch.mean(dist2)) / 2
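When debugging the compiled `chamfer` extension, it can help to compare against a plain NumPy reference of the same L2 Chamfer value (`chamfer_l2_ref` is an illustrative name, not part of the extension):

```python
import numpy as np

def chamfer_l2_ref(xyz1, xyz2):
    # Pairwise squared distances between the two clouds: shape (n, m).
    d = np.sum((xyz1[:, None, :] - xyz2[None, :, :]) ** 2, axis=-1)
    # Mean nearest-neighbour distance in both directions,
    # matching ChamferDistanceL2's dist1.mean() + dist2.mean().
    return d.min(axis=1).mean() + d.min(axis=0).mean()

a = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
b = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 0.0]])
print(chamfer_l2_ref(a, b))  # 1.0
```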
================================================
FILE: segmentation/extensions/chamfer_dist/chamfer.cu
================================================
/*
* @Author: Haozhe Xie
* @Date: 2019-08-07 20:54:24
* @Last Modified by: Haozhe Xie
* @Last Modified time: 2020-06-17 14:58:55
* @Email: cshzxie@gmail.com
*/
#include <cstdio>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <vector>
__global__ void chamfer_dist_kernel(int batch_size,
int n,
const float* xyz1,
int m,
const float* xyz2,
float* dist,
int* indexes) {
const int batch = 512;
__shared__ float buf[batch * 3];
for (int i = blockIdx.x; i < batch_size; i += gridDim.x) {
for (int k2 = 0; k2 < m; k2 += batch) {
int end_k = min(m, k2 + batch) - k2;
for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x) {
buf[j] = xyz2[(i * m + k2) * 3 + j];
}
__syncthreads();
for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n;
j += blockDim.x * gridDim.y) {
float x1 = xyz1[(i * n + j) * 3 + 0];
float y1 = xyz1[(i * n + j) * 3 + 1];
float z1 = xyz1[(i * n + j) * 3 + 2];
float best_dist = 0;
int best_dist_index = 0;
int end_ka = end_k - (end_k & 3);
if (end_ka == batch) {
for (int k = 0; k < batch; k += 4) {
{
float x2 = buf[k * 3 + 0] - x1;
float y2 = buf[k * 3 + 1] - y1;
float z2 = buf[k * 3 + 2] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (k == 0 || dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2;
}
}
{
float x2 = buf[k * 3 + 3] - x1;
float y2 = buf[k * 3 + 4] - y1;
float z2 = buf[k * 3 + 5] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2 + 1;
}
}
{
float x2 = buf[k * 3 + 6] - x1;
float y2 = buf[k * 3 + 7] - y1;
float z2 = buf[k * 3 + 8] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2 + 2;
}
}
{
float x2 = buf[k * 3 + 9] - x1;
float y2 = buf[k * 3 + 10] - y1;
float z2 = buf[k * 3 + 11] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2 + 3;
}
}
}
} else {
for (int k = 0; k < end_ka; k += 4) {
{
float x2 = buf[k * 3 + 0] - x1;
float y2 = buf[k * 3 + 1] - y1;
float z2 = buf[k * 3 + 2] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (k == 0 || dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2;
}
}
{
float x2 = buf[k * 3 + 3] - x1;
float y2 = buf[k * 3 + 4] - y1;
float z2 = buf[k * 3 + 5] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2 + 1;
}
}
{
float x2 = buf[k * 3 + 6] - x1;
float y2 = buf[k * 3 + 7] - y1;
float z2 = buf[k * 3 + 8] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2 + 2;
}
}
{
float x2 = buf[k * 3 + 9] - x1;
float y2 = buf[k * 3 + 10] - y1;
float z2 = buf[k * 3 + 11] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2 + 3;
}
}
}
}
for (int k = end_ka; k < end_k; k++) {
float x2 = buf[k * 3 + 0] - x1;
float y2 = buf[k * 3 + 1] - y1;
float z2 = buf[k * 3 + 2] - z1;
float dist = x2 * x2 + y2 * y2 + z2 * z2;
if (k == 0 || dist < best_dist) {
best_dist = dist;
best_dist_index = k + k2;
}
}
if (k2 == 0 || dist[(i * n + j)] > best_dist) {
dist[(i * n + j)] = best_dist;
indexes[(i * n + j)] = best_dist_index;
}
}
__syncthreads();
}
}
}
std::vector<torch::Tensor> chamfer_cuda_forward(torch::Tensor xyz1,
torch::Tensor xyz2) {
const int batch_size = xyz1.size(0);
const int n = xyz1.size(1); // num_points point cloud A
const int m = xyz2.size(1); // num_points point cloud B
torch::Tensor dist1 =
torch::zeros({batch_size, n}, torch::CUDA(torch::kFloat));
torch::Tensor dist2 =
torch::zeros({batch_size, m}, torch::CUDA(torch::kFloat));
torch::Tensor idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt));
torch::Tensor idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt));
// launch over a 32x16 grid of 512-thread blocks
chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(
batch_size, n, xyz1.data_ptr<float>(), m, xyz2.data_ptr<float>(),
dist1.data_ptr<float>(), idx1.data_ptr<int>());
chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(
batch_size, m, xyz2.data_ptr<float>(), n, xyz1.data_ptr<float>(),
dist2.data_ptr<float>(), idx2.data_ptr<int>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("Error in chamfer_cuda_forward: %s\n", cudaGetErrorString(err));
}
return {dist1, dist2, idx1, idx2};
}
__global__ void chamfer_dist_grad_kernel(int b,
int n,
const float* xyz1,
int m,
const float* xyz2,
const float* grad_dist1,
const int* idx1,
float* grad_xyz1,
float* grad_xyz2) {
for (int i = blockIdx.x; i < b; i += gridDim.x) {
for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n;
j += blockDim.x * gridDim.y) {
float x1 = xyz1[(i * n + j) * 3 + 0];
float y1 = xyz1[(i * n + j) * 3 + 1];
float z1 = xyz1[(i * n + j) * 3 + 2];
int j2 = idx1[i * n + j];
float x2 = xyz2[(i * m + j2) * 3 + 0];
float y2 = xyz2[(i * m + j2) * 3 + 1];
float z2 = xyz2[(i * m + j2) * 3 + 2];
float g = grad_dist1[i * n + j] * 2;
atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2));
atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2));
atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2));
atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2)));
atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2)));
atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2)));
}
}
}
std::vector<torch::Tensor> chamfer_cuda_backward(torch::Tensor xyz1,
torch::Tensor xyz2,
torch::Tensor idx1,
torch::Tensor idx2,
torch::Tensor grad_dist1,
torch::Tensor grad_dist2) {
const int batch_size = xyz1.size(0);
const int n = xyz1.size(1); // num_points point cloud A
const int m = xyz2.size(1); // num_points point cloud B
torch::Tensor grad_xyz1 = torch::zeros_like(xyz1, torch::CUDA(torch::kFloat));
torch::Tensor grad_xyz2 = torch::zeros_like(xyz2, torch::CUDA(torch::kFloat));
chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>(
batch_size, n, xyz1.data_ptr<float>(), m, xyz2.data_ptr<float>(),
grad_dist1.data_ptr<float>(), idx1.data_ptr<int>(),
grad_xyz1.data_ptr<float>(), grad_xyz2.data_ptr<float>());
chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>(
batch_size, m, xyz2.data_ptr<float>(), n, xyz1.data_ptr<float>(),
grad_dist2.data_ptr<float>(), idx2.data_ptr<int>(),
grad_xyz2.data_ptr<float>(), grad_xyz1.data_ptr<float>());
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
printf("Error in chamfer_cuda_backward: %s\n", cudaGetErrorString(err));
}
return {grad_xyz1, grad_xyz2};
}
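The backward kernel above accumulates `g * (x1 - x2)` with `g = 2 * grad_dist1`, i.e. the analytic gradient of the squared nearest-neighbour distance. A small finite-difference check of that identity (an illustrative sketch, independent of the CUDA code):

```python
import numpy as np

x1 = np.array([0.3, -0.2, 0.7])
x2 = np.array([1.0, 0.4, -0.1])

def sqdist(p, q):
    # Squared Euclidean distance, the quantity the forward kernel minimizes.
    return float(np.sum((p - q) ** 2))

analytic = 2.0 * (x1 - x2)   # d/dx1 ||x1 - x2||^2, as accumulated by the kernel
eps = 1e-6
numeric = np.zeros(3)
for k in range(3):
    e = np.zeros(3)
    e[k] = eps
    # Central finite difference along coordinate k.
    numeric[k] = (sqdist(x1 + e, x2) - sqdist(x1 - e, x2)) / (2 * eps)
print(np.max(np.abs(analytic - numeric)))  # close to 0
```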
================================================
FILE: segmentation/extensions/chamfer_dist/chamfer_cuda.cpp
================================================
/*
* @Author: Haozhe Xie
* @Date: 2019-08-07 20:54:24
* @Last Modified by: Haozhe Xie
* @Last Modified time: 2019-12-10 10:33:50
* @Email: cshzxie@gmail.com
*/
#include <torch/extension.h>
#include <vector>
std::vector<torch::Tensor> chamfer_cuda_forward(torch::Tensor xyz1,
torch::Tensor xyz2);
std::vector<torch::Tensor> chamfer_cuda_backward(torch::Tensor xyz1,
torch::Tensor xyz2,
torch::Tensor idx1,
torch::Tensor idx2,
torch::Tensor grad_dist1,
torch::Tensor grad_dist2);
std::vector<torch::Tensor> chamfer_forward(torch::Tensor xyz1,
torch::Tensor xyz2) {
return chamfer_cuda_forward(xyz1, xyz2);
}
std::vector<torch::Tensor> chamfer_backward(torch::Tensor xyz1,
torch::Tensor xyz2,
torch::Tensor idx1,
torch::Tensor idx2,
torch::Tensor grad_dist1,
torch::Tensor grad_dist2) {
return chamfer_cuda_backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &chamfer_forward, "Chamfer forward (CUDA)");
m.def("backward", &chamfer_backward, "Chamfer backward (CUDA)");
}
================================================
FILE: segmentation/extensions/chamfer_dist/setup.py
================================================
# -*- coding: utf-8 -*-
# @Author: Haozhe Xie
# @Date: 2019-08-07 20:54:24
# @Last Modified by: Haozhe Xie
# @Last Modified time: 2019-12-10 10:04:25
# @Email: cshzxie@gmail.com
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(name='chamfer',
version='2.0.0',
ext_modules=[
CUDAExtension('chamfer', [
'chamfer_cuda.cpp',
'chamfer.cu',
]),
],
cmdclass={'build_ext': BuildExtension})
================================================
FILE: segmentation/extensions/chamfer_dist/test.py
================================================
# -*- coding: utf-8 -*-
# @Author: Haozhe Xie
# @Date: 2019-12-10 10:38:01
# @Last Modified by: Haozhe Xie
# @Last Modified time: 2019-12-26 14:21:36
# @Email: cshzxie@gmail.com
#
# Note:
# - Replace float -> double, kFloat -> kDouble in chamfer.cu
import os
import sys
import torch
import unittest
from torch.autograd import gradcheck
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)))
from extensions.chamfer_dist import ChamferFunction
class ChamferDistanceTestCase(unittest.TestCase):
    def test_chamfer_dist(self):
        x = torch.rand(4, 64, 3).double()
        y = torch.rand(4, 128, 3).double()
        x.requires_grad = True
        y.requires_grad = True
        print(gradcheck(ChamferFunction.apply, [x.cuda(), y.cuda()]))


if __name__ == '__main__':
    unittest.main()
================================================
FILE: segmentation/extensions/emd/README.md
================================================
# PyTorch Wrapper for Point-cloud Earth-Mover-Distance (EMD)
## Dependencies
The code has been tested on Ubuntu 16.04 with PyTorch 1.1.0 and CUDA 9.0.
## Usage
First, compile the extension:

    python setup.py install

Then, copy the built library out to the main directory:

    cp build/lib.linux-x86_64-3.6/emd_cuda.cpython-36m-x86_64-linux-gnu.so .

Then you can use it by simply:

    from emd import earth_mover_distance
    d = earth_mover_distance(p1, p2, transpose=False)  # p1: B x N1 x 3, p2: B x N2 x 3

Check `test_emd_loss.py` for an example.
## Author
The CUDA code was originally written by Haoqiang Fan. The PyTorch wrapper was written by Kaichun Mo, with help from Jiayuan Gu.
## License
MIT
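For very small, equal-sized clouds, the approximate EMD can be sanity-checked against an exact brute-force matching over permutations. A standalone sketch using the squared Euclidean cost (function and variable names here are illustrative, not part of the package):

```python
import itertools
import math

def emd_bruteforce(p1, p2):
    # Exact minimum-cost one-to-one matching (squared Euclidean cost).
    # Only feasible for very small, equal-sized clouds: O(n!) pairings.
    n = len(p1)
    best = math.inf
    for perm in itertools.permutations(range(n)):
        cost = sum(
            sum((a - b) ** 2 for a, b in zip(p1[i], p2[j]))
            for i, j in zip(range(n), perm)
        )
        best = min(best, cost)
    return best

# The two clouds used as ground truth in test_emd_loss.py.
p1 = [(1.7, -0.1, 0.1), (0.1, 1.2, 0.3)]
p2 = [(0.3, 1.8, 0.2), (1.2, -0.2, 0.3)]
print(round(emd_bruteforce(p1, p2), 6))  # 0.71: pairing p1[0]-p2[1], p1[1]-p2[0]
```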
================================================
FILE: segmentation/extensions/emd/__init__.py
================================================
from .emd import earth_mover_distance as emd
__all__ = ['emd']
================================================
FILE: segmentation/extensions/emd/cuda/emd.cpp
================================================
#ifndef _EMD
#define _EMD
#include <torch/extension.h>
#include <vector>
//CUDA declarations
at::Tensor ApproxMatchForward(
const at::Tensor xyz1,
const at::Tensor xyz2);
at::Tensor MatchCostForward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match);
std::vector<at::Tensor> MatchCostBackward(
const at::Tensor grad_cost,
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("approxmatch_forward", &ApproxMatchForward,"ApproxMatch forward (CUDA)");
m.def("matchcost_forward", &MatchCostForward,"MatchCost forward (CUDA)");
m.def("matchcost_backward", &MatchCostBackward,"MatchCost backward (CUDA)");
}
#endif
================================================
FILE: segmentation/extensions/emd/cuda/emd_kernel.cu
================================================
/**********************************
* Original Author: Haoqiang Fan
* Modified by: Kaichun Mo
*********************************/
#ifndef _EMD_KERNEL
#define _EMD_KERNEL
#include <cstdio>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh> // at::cuda::getApplyGrid
// #include
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
/********************************
* Forward kernel for approxmatch
*********************************/
template <typename scalar_t>
__global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,scalar_t * __restrict__ match,scalar_t * temp){
scalar_t * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
scalar_t multiL,multiR;
if (n>=m){
multiL=1;
multiR=n/m;
}else{
multiL=m/n;
multiR=1;
}
const int Block=1024;
__shared__ scalar_t buf[Block*4];
for (int i=blockIdx.x;i=-2;j--){
scalar_t level=-powf(4.0f,j);
if (j==-2){
level=0;
}
for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,temp);
//}
/* ApproxMatch forward interface
Input:
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
Output:
match: (B, N2, N1)
*/
at::Tensor ApproxMatchForward(
const at::Tensor xyz1,
const at::Tensor xyz2){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto match = at::zeros({b, m, n}, xyz1.type());
auto temp = at::zeros({b, (n+m)*2}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "ApproxMatchForward", ([&] {
approxmatch<<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());
}));
AT_CUDA_CHECK(cudaGetLastError());
return match;
}
/********************************
* Forward kernel for matchcost
*********************************/
template <typename scalar_t>
__global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ out){
__shared__ scalar_t allsum[512];
const int Block=1024;
__shared__ scalar_t buf[Block*3];
for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,out);
//}
/* MatchCost forward interface
Input:
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
match: (B, N2, N1)
Output:
cost: (B)
*/
at::Tensor MatchCostForward(
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto cost = at::zeros({b}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostForward", ([&] {
matchcost<<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());
}));
AT_CUDA_CHECK(cudaGetLastError());
return cost;
}
/********************************
* matchcostgrad2 kernel
*********************************/
template <typename scalar_t>
__global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad2){
__shared__ scalar_t sum_grad[256*3];
for (int i=blockIdx.x;i
template <typename scalar_t>
__global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad1){
for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad1);
// matchcostgrad2<<>>(b,n,m,xyz1,xyz2,match,grad2);
//}
/* MatchCost backward interface
Input:
grad_cost: (B) # gradients on cost
xyz1: (B, N1, 3) # dataset_points
xyz2: (B, N2, 3) # query_points
match: (B, N2, N1)
Output:
grad1: (B, N1, 3)
grad2: (B, N2, 3)
*/
std::vector<at::Tensor> MatchCostBackward(
const at::Tensor grad_cost,
const at::Tensor xyz1,
const at::Tensor xyz2,
const at::Tensor match){
const auto b = xyz1.size(0);
const auto n = xyz1.size(1);
const auto m = xyz2.size(1);
CHECK_EQ(xyz2.size(0), b);
CHECK_EQ(xyz1.size(2), 3);
CHECK_EQ(xyz2.size(2), 3);
CHECK_INPUT(xyz1);
CHECK_INPUT(xyz2);
auto grad1 = at::zeros({b, n, 3}, xyz1.type());
auto grad2 = at::zeros({b, m, 3}, xyz1.type());
AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), "MatchCostBackward", ([&] {
matchcostgrad1<<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());
matchcostgrad2<<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());
}));
AT_CUDA_CHECK(cudaGetLastError());
return std::vector<at::Tensor>({grad1, grad2});
}
#endif
================================================
FILE: segmentation/extensions/emd/emd.py
================================================
import torch
import emd_cuda
class EarthMoverDistanceFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xyz1, xyz2):
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()
        assert xyz1.is_cuda and xyz2.is_cuda, "Only support cuda currently."
        match = emd_cuda.approxmatch_forward(xyz1, xyz2)
        cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)
        ctx.save_for_backward(xyz1, xyz2, match)
        return cost

    @staticmethod
    def backward(ctx, grad_cost):
        xyz1, xyz2, match = ctx.saved_tensors
        grad_cost = grad_cost.contiguous()
        grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(
            grad_cost, xyz1, xyz2, match)
        return grad_xyz1, grad_xyz2


class earth_mover_distance(torch.nn.Module):
    ''' Earth Mover's Distance (approximate) '''

    def __init__(self):
        super().__init__()

    def forward(self, xyz1, xyz2, transpose=False):
        """Earth Mover Distance (Approx)

        Args:
            xyz1 (torch.Tensor): (b, n1, 3)
            xyz2 (torch.Tensor): (b, n2, 3)
            transpose (bool): whether to transpose inputs as it might be BCN
                format. Extensions only support BNC format.

        Returns:
            cost (torch.Tensor): scalar, matching cost averaged over the batch
        """
        if transpose:
            xyz1 = xyz1.transpose(1, 2)
            xyz2 = xyz2.transpose(1, 2)
        cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
        cost = cost / xyz1.size(1)
        return cost.mean()
# def earth_mover_distance(xyz1, xyz2, transpose=True):
# """Earth Mover Distance (Approx)
# Args:
# xyz1 (torch.Tensor): (b, 3, n1)
# xyz2 (torch.Tensor): (b, 3, n1)
# transpose (bool): whether to transpose inputs as it might be BCN format.
# Extensions only support BNC format.
# Returns:
# cost (torch.Tensor): (b)
# """
# if xyz1.dim() == 2:
# xyz1 = xyz1.unsqueeze(0)
# if xyz2.dim() == 2:
# xyz2 = xyz2.unsqueeze(0)
# if transpose:
# xyz1 = xyz1.transpose(1, 2)
# xyz2 = xyz2.transpose(1, 2)
# cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)
# return cost
================================================
FILE: segmentation/extensions/emd/setup.py
================================================
"""Setup extension
Notes:
If extra_compile_args is provided, you need to provide different instances for different extensions.
Refer to https://github.com/pytorch/pytorch/issues/20169
"""
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='emd_ext',
ext_modules=[
CUDAExtension(
name='emd_cuda',
sources=[
'cuda/emd.cpp',
'cuda/emd_kernel.cu',
],
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
),
],
cmdclass={
'build_ext': BuildExtension
})
================================================
FILE: segmentation/extensions/emd/test_emd_loss.py
================================================
import torch
import numpy as np
import time
from emd import earth_mover_distance
# gt
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
p1 = p1.repeat(3, 1, 1)
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
p2 = p2.repeat(3, 1, 1)
print(p1)
print(p2)
print(p1.shape)
p1.requires_grad = True
p2.requires_grad = True
gt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 + \
(((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \
(((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3
print('gt_dist: ', gt_dist)
gt_dist.backward()
print(p1.grad)
print(p2.grad)
# emd
p1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()
p1 = p1.repeat(3, 1, 1)
p2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()
p2 = p2.repeat(3, 1, 1)
print(p1)
print(p2)
p1.requires_grad = True
p2.requires_grad = True
d = earth_mover_distance(p1, p2, transpose=False)
print(d)
loss = d[0] / 2 + d[1] * 2 + d[2] / 3
print(loss)
loss.backward()
print(p1.grad)
print(p2.grad)
================================================
FILE: segmentation/logger.py
================================================
import logging
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Tuple

import torch.distributed as dist
import torch.nn as nn
from termcolor import colored

logger_initialized = {}
def get_root_logger(log_file=None, log_level=logging.INFO, name='main'):
    """Get the root logger and add a keyword filter to it.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added.

    Args:
        log_file (str, optional): File path of log. Defaults to None.
        log_level (int, optional): The level of logger.
            Defaults to logging.INFO.
        name (str, optional): The name of the root logger, also used as a
            filter keyword. Defaults to 'main'.

    Returns:
        :obj:`logging.Logger`: The obtained logger
    """
    logger = get_logger(name=name, log_file=log_file, log_level=log_level)
    # add a logging filter (match on the record's logger name)
    logging_filter = logging.Filter(name)
    logging_filter.filter = lambda record: record.name.find(name) != -1
    return logger
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified and the process rank is 0, a FileHandler
    will also be added.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
        file_mode (str): The file mode used in opening log file.
            Defaults to 'w'.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger
    # handle duplicate logs to the console
    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
    # to the root logger. As logger.propagate is True by default, this root
    # level handler causes logging messages from rank>0 processes to
    # unexpectedly show up on the console, creating much unwanted clutter.
    # To fix this issue, we set the root logger's StreamHandler, if any, to log
    # at the ERROR level.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)
    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0
    # only rank 0 will add a FileHandler
    if rank == 0 and log_file is not None:
        # Here, the default behaviour of the official logger is 'a'. Thus, we
        # provide an interface to change the file mode to the default
        # behaviour.
        file_handler = logging.FileHandler(log_file, file_mode)
        handlers.append(file_handler)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)
    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)
    logger_initialized[name] = True
    return logger
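The handler wiring in `get_logger` can be exercised with the standard library alone; a minimal sketch of the same pattern (stream plus file handler sharing one formatter; logger name and paths are illustrative):

```python
import logging
import os
import tempfile

log_file = os.path.join(tempfile.mkdtemp(), 'run.log')
logger = logging.getLogger('demo')  # illustrative logger name
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
    '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# One StreamHandler and one FileHandler, both at INFO, as in get_logger.
for handler in (logging.StreamHandler(), logging.FileHandler(log_file, 'w')):
    handler.setFormatter(formatter)
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)

logger.info('training started')
for h in logger.handlers:
    h.flush()
with open(log_file) as f:
    print('training started' in f.read())  # True
```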
def print_log(msg, logger=None, level=logging.INFO):
    """Print a log message.

    Args:
        msg (str): The message to be logged.
        logger (logging.Logger | str | None): The logger to be used.
            Some special loggers are:
            - "silent": no message will be printed.
            - other str: the logger obtained with `get_root_logger(logger)`.
            - None: The `print()` method will be used to print log messages.
        level (int): Logging level. Only available when `logger` is a Logger
            object or "root".
    """
    if logger is None:
        print(msg)
    elif isinstance(logger, logging.Logger):
        logger.log(level, msg)
    elif logger == 'silent':
        pass
    elif isinstance(logger, str):
        _logger = get_logger(logger)
        _logger.log(level, msg)
    else:
        raise TypeError(
            'logger should be either a logging.Logger object, str, '
            f'"silent" or None, but got {type(logger)}')
def get_missing_parameters_message(keys: List[str]) -> str:
    """
    Get a logging-friendly message to report parameter names (keys) that are in
    the model but not found in a checkpoint.

    Args:
        keys (list[str]): List of keys that were not found in the checkpoint.

    Returns:
        str: message.
    """
    groups = _group_checkpoint_keys(keys)
    msg = "Some model parameters or buffers are not found in the checkpoint:\n"
    msg += "\n".join(
        "  " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
    )
    return msg


def get_unexpected_parameters_message(keys: List[str]) -> str:
    """
    Get a logging-friendly message to report parameter names (keys) that are in
    the checkpoint but not found in the model.

    Args:
        keys (list[str]): List of keys that were not found in the model.

    Returns:
        str: message.
    """
    groups = _group_checkpoint_keys(keys)
    msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
    msg += "\n".join(
        "  " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
    )
    return msg
def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
    """
    Strip the prefix from all keys of a state dict, if every key carries it.

    Args:
        state_dict (OrderedDict): a state-dict to be loaded to the model.
        prefix (str): prefix.
    """
    keys = sorted(state_dict.keys())
    if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
        return
    for key in keys:
        newkey = key[len(prefix):]
        state_dict[newkey] = state_dict.pop(key)
    # also strip the prefix in metadata, if any
    try:
        metadata = state_dict._metadata  # pyre-ignore
    except AttributeError:
        pass
    else:
        for key in list(metadata.keys()):
            # for the metadata dict, the key can be:
            # '': for the DDP module, which we want to remove.
            # 'module': for the actual model.
            # 'module.xx.xx': for the rest.
            if len(key) == 0:
                continue
            newkey = key[len(prefix):]
            metadata[newkey] = metadata.pop(key)
def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
    """
    Group keys based on common prefixes. A prefix is the string up to the final
    "." in each key.

    Args:
        keys (list[str]): list of parameter names, i.e. keys in the model
            checkpoint dict.

    Returns:
        dict[list]: keys with common prefixes are grouped into lists.
    """
    groups = defaultdict(list)
    for key in keys:
        pos = key.rfind(".")
        if pos >= 0:
            head, tail = key[:pos], [key[pos + 1:]]
        else:
            head, tail = key, []
        groups[head].extend(tail)
    return groups
def _group_to_str(group: List[str]) -> str:
    """
    Format a group of parameter name suffixes into a loggable string.

    Args:
        group (list[str]): list of parameter name suffixes.

    Returns:
        str: formatted string.
    """
    if len(group) == 0:
        return ""
    if len(group) == 1:
        return "." + group[0]
    return ".{" + ", ".join(group) + "}"
def _named_modules_with_dup(
    model: nn.Module, prefix: str = ""
) -> Iterable[Tuple[str, nn.Module]]:
    """
    The same as `model.named_modules()`, except that it includes
    duplicated modules that have more than one name.
    """
    yield prefix, model
    for name, module in model._modules.items():  # pyre-ignore
        if module is None:
            continue
        submodule_prefix = prefix + ("." if prefix else "") + name
        yield from _named_modules_with_dup(module, submodule_prefix)
================================================
FILE: segmentation/main.py
================================================
"""
Author: Benny
Date: Nov 2019
"""
import argparse
import os
import torch
import datetime
import logging
import sys
import importlib
import shutil
import provider
import numpy as np
import torch.optim as optim
from timm.scheduler import CosineLRScheduler
from pathlib import Path
from tqdm import tqdm
from dataset import PartNormalDataset
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'models'))
seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
               'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37],
               'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49],
               'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
seg_label_to_cat = {}  # {0: Airplane, 1: Airplane, ... 49: Table}
for cat in seg_classes.keys():
    for label in seg_classes[cat]:
        seg_label_to_cat[label] = cat
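The loop above inverts `seg_classes` into a flat part-label to category lookup. A minimal check of that inversion on a subset of the table:

```python
# Minimal check of the label-to-category inversion built above,
# using two entries from the seg_classes table.
seg_classes = {'Airplane': [0, 1, 2, 3], 'Table': [47, 48, 49]}
seg_label_to_cat = {}
for cat in seg_classes.keys():
    for label in seg_classes[cat]:
        seg_label_to_cat[label] = cat
print(seg_label_to_cat[0], seg_label_to_cat[49])  # Airplane Table
```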
def inplace_relu(m):
    classname = m.__class__.__name__
    if classname.find('ReLU') != -1:
        m.inplace = True


def to_categorical(y, num_classes):
    """ 1-hot encodes a tensor """
    new_y = torch.eye(num_classes)[y.cpu().data.numpy(), ]
    if (y.is_cuda):
        return new_y.cuda()
    return new_y
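`to_categorical` one-hot encodes labels by indexing rows of an identity matrix; the same trick works in plain NumPy (a torch-free sketch with an illustrative name):

```python
import numpy as np

def to_categorical_np(y, num_classes):
    # Row y[i] of the identity matrix is the one-hot vector for class y[i].
    return np.eye(num_classes)[y]

labels = np.array([0, 2, 1])
onehot = to_categorical_np(labels, 3)
print(onehot.shape)  # (3, 3)
```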
def parse_args():
    parser = argparse.ArgumentParser('Model')
    parser.add_argument('--model', type=str, default='pt')
    parser.add_argument('--model_name', type=str, default='PointGPT_S',
                        choices=['PointGPT_S', 'PointGPT_B', 'PointGPT_L'])
    parser.add_argument('--batch_size', type=int, default=32,
                        help='batch size during training')
    parser.add_argument('--epoch', default=300, type=int, help='epochs to run')
    parser.add_argument('--warmup_epoch', default=30,
                        type=int, help='warmup epochs')
    parser.add_argument('--learning_rate', default=0.0002,
                        type=float, help='initial learning rate')
    parser.add_argument('--gpu', type=str, default='1',
                        help='specify GPU devices')
    # parser.add_argument('--optimizer', type=str, default='AdamW', help='Adam or SGD')
    parser.add_argument('--log_dir', type=str,
                        default='./exp', help='log path')
    # parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--npoint', type=int,
                        default=2048, help='number of points')
    parser.add_argument('--normal', action='store_true',
                        default=False, help='use normals')
    # parser.add_argument('--step_size', type=int, default=20, help='decay step for lr decay')
    # parser.add_argument('--lr_decay', type=float, default=0.5, help='decay rate for lr decay')
    parser.add_argument(
        '--ckpts', type=str, default='../best/pretrain/m0.6R_1_pretrain300.pth', help='ckpts')
    parser.add_argument(
        '--root', type=str, default='data/ShapenetPart/shapenetcore_partanno_segmentation_benchmark_v0_normal/',
        help='data root')
    return parser.parse_args()
def get_model_loss(MODEL, args, num_part):
    if args.model_name == 'PointGPT_S':
        classifier = MODEL.get_model(num_part, trans_dim=384, depth=12, drop_path_rate=0.1, num_heads=6, decoder_depth=4,
                                     group_size=32, num_group=128, prop_dim=1024, label_dim1=512, label_dim2=256, encoder_dims=384)
    elif args.model_name == 'PointGPT_B':
        classifier = MODEL.get_model(num_part, trans_dim=768, depth=12, drop_path_rate=0.1, num_heads=12, decoder_depth=4,
                                     group_size=32, num_group=128, prop_dim=2048, label_dim1=1024, label_dim2=512, encoder_dims=768)
    elif args.model_name == 'PointGPT_L':
        classifier = MODEL.get_model(num_part, trans_dim=1024, depth=24, drop_path_rate=0.1, num_heads=16, decoder_depth=4,
                                     group_size=32, num_group=128, prop_dim=2048, label_dim1=1024, label_dim2=512, encoder_dims=1024)
    classifier = classifier.cuda()
    criterion = MODEL.get_loss().cuda()
    classifier.apply(inplace_relu)
    return classifier, criterion
def main(args):
def log_string(msg):
logger.info(msg)
print(msg)
'''HYPER PARAMETER'''
# os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
'''CREATE DIR'''
timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
exp_dir = Path('./log/')
exp_dir.mkdir(exist_ok=True)
exp_dir = exp_dir.joinpath('part_seg')
exp_dir.mkdir(exist_ok=True)
if args.log_dir is None:
exp_dir = exp_dir.joinpath(timestr)
else:
exp_dir = exp_dir.joinpath(args.log_dir)
exp_dir.mkdir(exist_ok=True)
checkpoints_dir = exp_dir.joinpath('checkpoints/')
checkpoints_dir.mkdir(exist_ok=True)
log_dir = exp_dir.joinpath('logs/')
log_dir.mkdir(exist_ok=True)
'''LOG'''
logger = logging.getLogger("Model")
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
log_string('PARAMETER ...')
log_string(args)
root = args.root
TRAIN_DATASET = PartNormalDataset(
root=root, npoints=args.npoint, split='trainval', normal_channel=args.normal)
trainDataLoader = torch.utils.data.DataLoader(
TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)
TEST_DATASET = PartNormalDataset(
root=root, npoints=args.npoint, split='test', normal_channel=args.normal)
testDataLoader = torch.utils.data.DataLoader(
TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10)
log_string("The number of training data is: %d" % len(TRAIN_DATASET))
log_string("The number of test data is: %d" % len(TEST_DATASET))
num_classes = 16
num_part = 50
'''MODEL LOADING'''
MODEL = importlib.import_module(args.model)
shutil.copy('models/%s.py' % args.model, str(exp_dir))
# shutil.copy('models/pointnet2_utils.py', str(exp_dir))
classifier, criterion = get_model_loss(MODEL, args, num_part)
print('# generator parameters:', sum(param.numel()
for param in classifier.parameters()))
start_epoch = 0
if args.ckpts is not None:
classifier.load_model_from_ckpt(args.ckpts)
# we use adamw and cosine scheduler
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
decay = []
no_decay = []
for name, param in model.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith(".bias") or 'token' in name or name in skip_list:
# print(name)
no_decay.append(param)
else:
decay.append(param)
return [
{'params': no_decay, 'weight_decay': 0.},
{'params': decay, 'weight_decay': weight_decay}]
param_groups = add_weight_decay(classifier, weight_decay=0.05)
optimizer = optim.AdamW(
param_groups, lr=args.learning_rate, weight_decay=0.05)
scheduler = CosineLRScheduler(optimizer,
t_initial=args.epoch,
# t_mul=1,
lr_min=1e-6,
cycle_decay=0.1,
warmup_lr_init=1e-6,
warmup_t=args.warmup_epoch,
cycle_limit=1,
t_in_epochs=True)
best_acc = 0
global_epoch = 0
best_class_avg_iou = 0
best_inctance_avg_iou = 0
classifier.zero_grad()
for epoch in range(start_epoch, args.epoch):
mean_correct = []
log_string('Epoch %d (%d/%s):' %
(global_epoch + 1, epoch + 1, args.epoch))
'''Adjust learning rate and BN momentum'''
classifier = classifier.train()
loss_batch = []
num_iter = 0
'''learning one epoch'''
for i, (points, label, target) in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
num_iter += 1
points = points.data.numpy()
points[:, :, 0:3] = provider.random_scale_point_cloud(
points[:, :, 0:3])
points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])
points = torch.Tensor(points)
points, label, target = points.float().cuda(
), label.long().cuda(), target.long().cuda()
points = points.transpose(2, 1)
seg_pred = classifier(points, to_categorical(label, num_classes))
seg_pred = seg_pred.contiguous().view(-1, num_part)
target = target.view(-1, 1)[:, 0]
pred_choice = seg_pred.data.max(1)[1]
correct = pred_choice.eq(target.data).cpu().sum()
mean_correct.append(
correct.item() / (args.batch_size * args.npoint))
loss = criterion(seg_pred, target)
loss.backward()
optimizer.step()
loss_batch.append(loss.detach().cpu())
if num_iter == 1:
torch.nn.utils.clip_grad_norm_(
classifier.parameters(), 10, norm_type=2)
num_iter = 0
optimizer.step()
classifier.zero_grad()
if isinstance(scheduler, list):
for item in scheduler:
item.step(epoch)
else:
scheduler.step(epoch)
train_instance_acc = np.mean(mean_correct)
loss1 = np.mean(loss_batch)
log_string('Train accuracy is: %.5f' % train_instance_acc)
log_string('Train loss: %.5f' % loss1)
log_string('lr: %.6f' % optimizer.param_groups[0]['lr'])
with torch.no_grad():
test_metrics = {}
total_correct = 0
total_seen = 0
total_seen_class = [0 for _ in range(num_part)]
total_correct_class = [0 for _ in range(num_part)]
shape_ious = {cat: [] for cat in seg_classes.keys()}
seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table}
for cat in seg_classes.keys():
for label in seg_classes[cat]:
seg_label_to_cat[label] = cat
classifier = classifier.eval()
for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
cur_batch_size, NUM_POINT, _ = points.size()
points, label, target = points.float().cuda(
), label.long().cuda(), target.long().cuda()
points = points.transpose(2, 1)
seg_pred = classifier(
points, to_categorical(label, num_classes))
cur_pred_val = seg_pred.cpu().data.numpy()
cur_pred_val_logits = cur_pred_val
cur_pred_val = np.zeros(
(cur_batch_size, NUM_POINT)).astype(np.int32)
target = target.cpu().data.numpy()
for i in range(cur_batch_size):
cat = seg_label_to_cat[target[i, 0]]
logits = cur_pred_val_logits[i, :, :]
cur_pred_val[i, :] = np.argmax(
logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]
correct = np.sum(cur_pred_val == target)
total_correct += correct
total_seen += (cur_batch_size * NUM_POINT)
for l in range(num_part):
total_seen_class[l] += np.sum(target == l)
total_correct_class[l] += (
np.sum((cur_pred_val == l) & (target == l)))
for i in range(cur_batch_size):
segp = cur_pred_val[i, :]
segl = target[i, :]
cat = seg_label_to_cat[segl[0]]
part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
for l in seg_classes[cat]:
if (np.sum(segl == l) == 0) and (
np.sum(segp == l) == 0): # part is not present, no prediction as well
part_ious[l - seg_classes[cat][0]] = 1.0
else:
part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
np.sum((segl == l) | (segp == l)))
shape_ious[cat].append(np.mean(part_ious))
all_shape_ious = []
for cat in shape_ious.keys():
for iou in shape_ious[cat]:
all_shape_ious.append(iou)
shape_ious[cat] = np.mean(shape_ious[cat])
mean_shape_ious = np.mean(list(shape_ious.values()))
test_metrics['accuracy'] = total_correct / float(total_seen)
test_metrics['class_avg_accuracy'] = np.mean(
np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64))
for cat in sorted(shape_ious.keys()):
log_string('eval mIoU of %s %f' %
(cat + ' ' * (14 - len(cat)), shape_ious[cat]))
test_metrics['class_avg_iou'] = mean_shape_ious
test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious)
log_string('Epoch %d test Accuracy: %f Class avg mIOU: %f Instance avg mIOU: %f' % (
epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou']))
if (test_metrics['inctance_avg_iou'] >= best_inctance_avg_iou):
logger.info('Save model...')
savepath = str(checkpoints_dir) + '/best_model.pth'
log_string('Saving at %s' % savepath)
state = {
'epoch': epoch,
'train_acc': train_instance_acc,
'test_acc': test_metrics['accuracy'],
'class_avg_iou': test_metrics['class_avg_iou'],
'inctance_avg_iou': test_metrics['inctance_avg_iou'],
'model_state_dict': classifier.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(state, savepath)
log_string('Saving model....')
if test_metrics['accuracy'] > best_acc:
best_acc = test_metrics['accuracy']
if test_metrics['class_avg_iou'] > best_class_avg_iou:
best_class_avg_iou = test_metrics['class_avg_iou']
if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou:
best_inctance_avg_iou = test_metrics['inctance_avg_iou']
log_string('Best accuracy is: %.5f' % best_acc)
log_string('Best class avg mIOU is: %.5f' % best_class_avg_iou)
log_string('Best instance avg mIOU is: %.5f' % best_inctance_avg_iou)
global_epoch += 1
if __name__ == '__main__':
args = parse_args()
main(args)
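The per-shape part IoU computed in the evaluation loop above can be illustrated in isolation. A minimal numpy sketch (`shape_part_iou` and `part_labels` are hypothetical names; `part_labels` stands in for `seg_classes[cat]`):

```python
import numpy as np

def shape_part_iou(segp, segl, part_labels):
    # Mirror of the evaluation logic: one IoU per part of the category,
    # counting a part absent from both prediction and label as IoU = 1.0.
    ious = []
    for l in part_labels:
        pred_l, gt_l = (segp == l), (segl == l)
        if not gt_l.any() and not pred_l.any():
            ious.append(1.0)
        else:
            ious.append(np.sum(pred_l & gt_l) / float(np.sum(pred_l | gt_l)))
    return float(np.mean(ious))

segl = np.array([0, 0, 1, 1])
segp = np.array([0, 1, 1, 1])
# part 0: inter 1, union 2 -> 0.5; part 1: inter 2, union 3 -> 2/3
print(shape_part_iou(segp, segl, [0, 1]))
```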
================================================
FILE: segmentation/misc.py
================================================
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from collections import abc
from pointnet2_ops import pointnet2_utils
def fps(data, number):
'''
data B N 3
number int
'''
fps_idx = pointnet2_utils.furthest_point_sample(data, number)
fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()
return fps_data
def worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)
def build_lambda_sche(opti, config):
if config.get('decay_step') is not None:
lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay)
scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd)
else:
raise NotImplementedError()
return scheduler
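The decay multiplier produced by `build_lambda_sche` is easy to trace by hand. A sketch with assumed config values (`lr_decay=0.5`, `decay_step=20`, `lowest_decay=0.02` are illustrative, not the repo's defaults):

```python
lr_decay, decay_step, lowest_decay = 0.5, 20, 0.02

# Same lambda as build_lambda_sche: multiplier halves every decay_step
# epochs and is floored at lowest_decay.
lr_lbmd = lambda e: max(lr_decay ** (e / decay_step), lowest_decay)

print(lr_lbmd(0))    # 1.0
print(lr_lbmd(20))   # 0.5
print(lr_lbmd(200))  # floor reached: 0.02
```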
def build_lambda_bnsche(model, config):
if config.get('decay_step') is not None:
bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay)
bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd)
else:
raise NotImplementedError()
return bnm_scheduler
def set_random_seed(seed, deterministic=False):
"""Set random seed.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Default: False.
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
if deterministic: # slower, more reproducible
cudnn.deterministic = True
cudnn.benchmark = False
else: # faster, less reproducible
cudnn.deterministic = False
cudnn.benchmark = True
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def is_seq_of(seq, expected_type, seq_type=None):
"""Check whether it is a sequence of some type.
Args:
seq (Sequence): The sequence to be checked.
expected_type (type): Expected type of sequence items.
seq_type (type, optional): Expected sequence type.
Returns:
bool: Whether the sequence is valid.
"""
if seq_type is None:
exp_seq_type = abc.Sequence
else:
assert isinstance(seq_type, type)
exp_seq_type = seq_type
if not isinstance(seq, exp_seq_type):
return False
for item in seq:
if not isinstance(item, expected_type):
return False
return True
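Typical uses of `is_seq_of`, reproduced in compact form so the sketch is self-contained:

```python
from collections import abc

def is_seq_of(seq, expected_type, seq_type=None):
    # Same logic as the helper above: check the container type first,
    # then every element's type.
    exp_seq_type = seq_type if seq_type is not None else abc.Sequence
    if not isinstance(seq, exp_seq_type):
        return False
    return all(isinstance(item, expected_type) for item in seq)

print(is_seq_of([1, 2, 3], int))               # True
print(is_seq_of((1, 'a'), int))                # False: mixed element types
print(is_seq_of([1, 2], int, seq_type=tuple))  # False: wrong container type
```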
def set_bn_momentum_default(bn_momentum):
def fn(m):
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
m.momentum = bn_momentum
return fn
class BNMomentumScheduler(object):
def __init__(
self, model, bn_lambda, last_epoch=-1,
setter=set_bn_momentum_default
):
if not isinstance(model, nn.Module):
raise RuntimeError(
"Class '{}' is not a PyTorch nn Module".format(
type(model).__name__
)
)
self.model = model
self.setter = setter
self.lmbd = bn_lambda
self.step(last_epoch + 1)
self.last_epoch = last_epoch
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.last_epoch = epoch
self.model.apply(self.setter(self.lmbd(epoch)))
def get_momentum(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
return self.lmbd(epoch)
def seprate_point_cloud(xyz, num_points, crop, fixed_points=None, padding_zeros=False):
'''
Separate a point cloud: generate an incomplete point cloud with a specified number of cropped points.
'''
_, n, c = xyz.shape
assert n == num_points
assert c == 3
if crop == num_points:
return xyz, None
INPUT = []
CROP = []
for points in xyz:
if isinstance(crop, list):
num_crop = random.randint(crop[0], crop[1])
else:
num_crop = crop
points = points.unsqueeze(0)
if fixed_points is None:
center = F.normalize(torch.randn(1, 1, 3), p=2, dim=-1).cuda()
else:
if isinstance(fixed_points, list):
fixed_point = random.sample(fixed_points, 1)[0]
else:
fixed_point = fixed_points
center = fixed_point.reshape(1, 1, 3).cuda()
distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p=2, dim=-1) # 1 1 2048
idx = torch.argsort(distance_matrix, dim=-1, descending=False)[0, 0] # 2048
if padding_zeros:
input_data = points.clone()
input_data[0, idx[:num_crop]] = input_data[0, idx[:num_crop]] * 0
else:
input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3
crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0)
if isinstance(crop, list):
INPUT.append(fps(input_data, 2048))
CROP.append(fps(crop_data, 2048))
else:
INPUT.append(input_data)
CROP.append(crop_data)
input_data = torch.cat(INPUT, dim=0) # B N 3
crop_data = torch.cat(CROP, dim=0) # B M 3
return input_data.contiguous(), crop_data.contiguous()
def get_ptcloud_img(ptcloud):
fig = plt.figure(figsize=(8, 8))
x, z, y = ptcloud.transpose(1, 0)
ax = fig.add_subplot(projection='3d', adjustable='box')  # fig.gca(projection=...) was removed in Matplotlib >= 3.6
ax.axis('off')
# ax.axis('scaled')
ax.view_init(90, 45)
pmax, pmin = np.max(ptcloud), np.min(ptcloud)  # avoid shadowing the max/min builtins
ax.set_xbound(pmin, pmax)
ax.set_ybound(pmin, pmax)
ax.set_zbound(pmin, pmax)
ax.scatter(x, y, z, zdir='z', c=y, cmap='jet')
fig.canvas.draw()
img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated for binary data
img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
return img
def visualize_KITTI(path, data_list, titles=['input', 'pred'], cmap=['bwr', 'autumn'], zdir='y',
xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1)):
fig = plt.figure(figsize=(6 * len(data_list), 6))
cmax = data_list[-1][:, 0].max()
for i in range(len(data_list)):
data = data_list[i][:-2048] if i == 1 else data_list[i]
color = data[:, 0] / cmax
ax = fig.add_subplot(1, len(data_list), i + 1, projection='3d')
ax.view_init(30, -120)
b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], zdir=zdir, c=color, vmin=-1, vmax=1, cmap=cmap[0], s=4,
linewidth=0.05, edgecolors='black')
ax.set_title(titles[i])
ax.set_axis_off()
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_zlim(zlim)
plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0)
if not os.path.exists(path):
os.makedirs(path)
pic_path = path + '.png'
fig.savefig(pic_path)
np.save(os.path.join(path, 'input.npy'), data_list[0].numpy())
np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy())
plt.close(fig)
def random_dropping(pc, e):
up_num = max(64, 768 // (e // 50 + 1))
random_num = torch.randint(1, up_num, (1, 1))[0, 0]
pc = fps(pc, random_num)
padding = torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device)
pc = torch.cat([pc, padding], dim=1)
return pc
def random_scale(partial, scale_range=[0.8, 1.2]):
scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0]
return partial * scale
================================================
FILE: segmentation/models/gpt2_seg.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, trunc_normal_
class Block(nn.Module):
def __init__(self, embed_dim, num_heads, drop_path):
super(Block, self).__init__()
self.ln_1 = nn.LayerNorm(embed_dim)
self.ln_2 = nn.LayerNorm(embed_dim)
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, embed_dim * 4),
nn.GELU(),
nn.Linear(embed_dim * 4, embed_dim),
)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
attn_mask = torch.full(
(len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype
)
attn_mask = torch.triu(attn_mask, diagonal=1)
x = self.ln_1(x)
a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)
x = x + self.drop_path(a)
m = self.drop_path(self.mlp(self.ln_2(x)))
x = x + m
return x
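The additive causal mask built in `Block.forward` blocks attention to future tokens. A torch-free numpy illustration of the same `triu` pattern:

```python
import numpy as np

n = 4
# -inf strictly above the diagonal, as in torch.triu(..., diagonal=1)
mask = np.triu(np.full((n, n), -np.inf), k=1)

# After softmax, -inf entries become exactly zero attention weight.
logits = np.zeros((n, n)) + mask
weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(weights[0])   # first token attends only to itself
print(weights[-1])  # last token attends uniformly to all tokens
```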
class GPT_extractor(nn.Module):
def __init__(
self, embed_dim, num_heads, num_layers, trans_dim, group_size, drop_path_rate
):
super(GPT_extractor, self).__init__()
self.embed_dim = embed_dim
self.trans_dim = trans_dim
self.group_size = group_size
self.drop_path_rate = drop_path_rate
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
dpr = [x.item() for x in torch.linspace(
0, self.drop_path_rate, num_layers)]
self.layers = nn.ModuleList()
for i in range(num_layers):
self.layers.append(Block(embed_dim, num_heads, dpr[i]))
self.ln_f = nn.LayerNorm(embed_dim)
# prediction head
self.increase_dim = nn.Sequential(
nn.Conv1d(self.trans_dim, 3*(self.group_size), 1)
)
def forward(self, h, pos, classify=False):
"""
Expect input as shape [sequence len, batch]
If classify, return classification logits
"""
batch, length, C = h.shape
h = h.transpose(0, 1)
pos = pos.transpose(0, 1)
# prepend sos token
sos = torch.ones(1, batch, self.embed_dim, device=h.device) * self.sos
if not classify:
h = torch.cat([sos, h[:-1, :, :]], axis=0)
else:
h = torch.cat([sos, h], axis=0)
feature_list = []
fetch_idx = [3, 7, 11]
# transformer
for i, layer in enumerate(self.layers):
h = layer(h + pos)
if i in fetch_idx:
feature_list.append(h.transpose(0, 1)[:, 2:])
h = self.ln_f(h)
encoded_points = h.transpose(0, 1)
return encoded_points, feature_list
class GPT_generator(nn.Module):
def __init__(
self, embed_dim, num_heads, num_layers, trans_dim, group_size, drop_path_rate
):
super(GPT_generator, self).__init__()
self.embed_dim = embed_dim
self.trans_dim = trans_dim
self.group_size = group_size
# start of sequence token
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
nn.init.normal_(self.sos)
self.drop_path_rate = drop_path_rate
dpr = [x.item() for x in torch.linspace(
0, self.drop_path_rate, num_layers)]
self.layers = nn.ModuleList()
for i in range(num_layers):
self.layers.append(Block(embed_dim, num_heads, dpr[i]))
self.ln_f = nn.LayerNorm(embed_dim)
# prediction head
self.increase_dim = nn.Sequential(
nn.Conv1d(self.trans_dim, 3*(self.group_size), 1)
)
def forward(self, h, pos):
"""
Expect input as shape [sequence len, batch]
If classify, return classification logits
"""
batch, length, C = h.shape
h = h.transpose(0, 1)
pos = pos.transpose(0, 1)
# transformer
for layer in self.layers:
h = layer(h + pos)
h = self.ln_f(h)
rebuild_points = self.increase_dim(h.transpose(1, 2)).transpose(
1, 2).transpose(0, 1).reshape(batch * length, -1, 3)
return rebuild_points
================================================
FILE: segmentation/models/pointnet2_utils.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np
def timeit(tag, t):
print("{}: {}s".format(tag, time() - t))
return time()
def pc_normalize(pc):
l = pc.shape[0]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
return pc
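`pc_normalize` centers a cloud and scales it into the unit sphere; a quick self-contained check:

```python
import numpy as np

def pc_normalize(pc):
    pc = pc - np.mean(pc, axis=0)                 # center at the origin
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))  # furthest-point radius
    return pc / m                                 # furthest point lands at radius 1

pts = np.array([[0., 0., 0.], [2., 0., 0.], [0., 4., 0.]])
out = pc_normalize(pts)
print(np.mean(out, axis=0))                 # ~[0, 0, 0]
print(np.max(np.linalg.norm(out, axis=1)))  # 1.0
```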
def square_distance(src, dst):
"""
Calculate squared Euclidean distance between each pair of points.
src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
Input:
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M]
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
return dist
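The expansion in the docstring can be verified numerically. A numpy sketch of the same computation (`square_distance_np` is an illustrative name), checked against the brute-force pairwise distances:

```python
import numpy as np

def square_distance_np(src, dst):
    # ||s - d||^2 = ||s||^2 + ||d||^2 - 2 s.d, batched as in the torch version.
    dist = -2 * src @ dst.transpose(0, 2, 1)
    dist += np.sum(src ** 2, -1)[:, :, None]
    dist += np.sum(dst ** 2, -1)[:, None, :]
    return dist

rng = np.random.default_rng(0)
src = rng.normal(size=(2, 5, 3))
dst = rng.normal(size=(2, 4, 3))
brute = np.sum((src[:, :, None, :] - dst[:, None, :, :]) ** 2, -1)
print(np.allclose(square_distance_np(src, dst), brute))  # True
```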
def index_points(points, idx):
"""
Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S]
Return:
new_points:, indexed points data, [B, S, C]
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
new_points = points[batch_indices, idx, :]
return new_points
def farthest_point_sample(xyz, npoint):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
distance = torch.ones(B, N).to(device) * 1e10
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
batch_indices = torch.arange(B, dtype=torch.long).to(device)
for i in range(npoint):
centroids[:, i] = farthest
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
dist = torch.sum((xyz - centroid) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = torch.max(distance, -1)[1]
return centroids
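A torch-free sketch of the same farthest-point-sampling loop (`fps_np` is a hypothetical helper; single cloud, not batched), handy for sanity-checking on CPU:

```python
import numpy as np

def fps_np(xyz, npoint, seed=0):
    # Greedy FPS: repeatedly pick the point farthest from the chosen set.
    n = xyz.shape[0]
    rng = np.random.default_rng(seed)
    centroids = np.zeros(npoint, dtype=np.int64)
    distance = np.full(n, np.inf)
    farthest = rng.integers(n)
    for i in range(npoint):
        centroids[i] = farthest
        d = np.sum((xyz - xyz[farthest]) ** 2, -1)
        distance = np.minimum(distance, d)   # distance to the chosen set
        farthest = int(np.argmax(distance))  # next pick: farthest remaining point
    return centroids

# Two tight clusters: FPS should pick one point from each first.
pts = np.vstack([np.zeros((10, 3)), np.ones((10, 3)) * 5])
idx = fps_np(pts, 2)
print(sorted(i // 10 for i in idx))  # one index from each cluster
```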
def query_ball_point(radius, nsample, xyz, new_xyz):
"""
Input:
radius: local region radius
nsample: max sample number in local region
xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]
Return:
group_idx: grouped points index, [B, S, nsample]
"""
device = xyz.device
B, N, C = xyz.shape
_, S, _ = new_xyz.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(new_xyz, xyz)
group_idx[sqrdists > radius ** 2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N
group_idx[mask] = group_first[mask]
return group_idx
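`query_ball_point`'s fallback (slots whose distance exceeds the radius are filled with the first in-range neighbor) can be shown with a single-query numpy sketch (`ball_query_np` is an illustrative name, no batching):

```python
import numpy as np

def ball_query_np(radius, nsample, xyz, query):
    # Out-of-range indices are marked with the sentinel N, sorted to the
    # end, truncated to nsample, then replaced by the first valid index.
    n = xyz.shape[0]
    d2 = np.sum((xyz - query) ** 2, -1)
    idx = np.arange(n)
    idx = np.where(d2 > radius ** 2, n, idx)
    idx = np.sort(idx)[:nsample]
    idx[idx == n] = idx[0]
    return idx

pts = np.array([[0., 0., 0.], [0.1, 0., 0.], [9., 9., 9.]])
print(ball_query_np(0.5, 3, pts, np.array([0., 0., 0.])))  # [0 1 0]
```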
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
"""
Input:
npoint:
radius:
nsample:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, npoint, nsample, 3]
new_points: sampled points data, [B, npoint, nsample, 3+D]
"""
B, N, C = xyz.shape
S = npoint
fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
new_xyz = index_points(xyz, fps_idx)
idx = query_ball_point(radius, nsample, xyz, new_xyz)
grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, idx)
new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
else:
new_points = grouped_xyz_norm
if returnfps:
return new_xyz, new_points, grouped_xyz, fps_idx
else:
return new_xyz, new_points
def sample_and_group_all(xyz, points):
"""
Input:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, 1, 3]
new_points: sampled points data, [B, 1, N, 3+D]
"""
device = xyz.device
B, N, C = xyz.shape
new_xyz = torch.zeros(B, 1, C).to(device)
grouped_xyz = xyz.view(B, 1, N, C)
if points is not None:
new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
else:
new_points = grouped_xyz
return new_xyz, new_points
class PointNetSetAbstraction(nn.Module):
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
super(PointNetSetAbstraction, self).__init__()
self.npoint = npoint
self.radius = radius
self.nsample = nsample
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.group_all = group_all
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
if self.group_all:
new_xyz, new_points = sample_and_group_all(xyz, points)
else:
new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
# new_xyz: sampled points position data, [B, npoint, C]
# new_points: sampled points data, [B, npoint, nsample, C+D]
new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
new_points = torch.max(new_points, 2)[0]
new_xyz = new_xyz.permute(0, 2, 1)
return new_xyz, new_points
class PointNetSetAbstractionMsg(nn.Module):
def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
super(PointNetSetAbstractionMsg, self).__init__()
self.npoint = npoint
self.radius_list = radius_list
self.nsample_list = nsample_list
self.conv_blocks = nn.ModuleList()
self.bn_blocks = nn.ModuleList()
for i in range(len(mlp_list)):
convs = nn.ModuleList()
bns = nn.ModuleList()
last_channel = in_channel + 3
for out_channel in mlp_list[i]:
convs.append(nn.Conv2d(last_channel, out_channel, 1))
bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.conv_blocks.append(convs)
self.bn_blocks.append(bns)
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
B, N, C = xyz.shape
S = self.npoint
new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
new_points_list = []
for i, radius in enumerate(self.radius_list):
K = self.nsample_list[i]
group_idx = query_ball_point(radius, K, xyz, new_xyz)
grouped_xyz = index_points(xyz, group_idx)
grouped_xyz -= new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, group_idx)
grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
else:
grouped_points = grouped_xyz
grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
for j in range(len(self.conv_blocks[i])):
conv = self.conv_blocks[i][j]
bn = self.bn_blocks[i][j]
grouped_points = F.relu(bn(conv(grouped_points)))
new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
new_points_list.append(new_points)
new_xyz = new_xyz.permute(0, 2, 1)
new_points_concat = torch.cat(new_points_list, dim=1)
return new_xyz, new_points_concat
class PointNetFeaturePropagation(nn.Module):
def __init__(self, in_channel, mlp):
super(PointNetFeaturePropagation, self).__init__()
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm1d(out_channel))
last_channel = out_channel
def forward(self, xyz1, xyz2, points1, points2):
"""
Input:
xyz1: input points position data, [B, C, N]
xyz2: sampled input points position data, [B, C, S]
points1: input points data, [B, D, N]
points2: input points data, [B, D, S]
Return:
new_points: upsampled points data, [B, D', N]
"""
xyz1 = xyz1.permute(0, 2, 1)
xyz2 = xyz2.permute(0, 2, 1)
points2 = points2.permute(0, 2, 1)
B, N, C = xyz1.shape
_, S, _ = xyz2.shape
if S == 1:
interpolated_points = points2.repeat(1, N, 1)
else:
dists = square_distance(xyz1, xyz2)
dists, idx = dists.sort(dim=-1)
dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
dist_recip = 1.0 / (dists + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
if points1 is not None:
points1 = points1.permute(0, 2, 1)
new_points = torch.cat([points1, interpolated_points], dim=-1)
else:
new_points = interpolated_points
new_points = new_points.permute(0, 2, 1)
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
return new_points
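The three-nearest-neighbor inverse-distance weighting used in `PointNetFeaturePropagation` can be isolated; a numpy sketch for one dense point with illustrative toy distances and features:

```python
import numpy as np

# Distances from one dense point to its 3 nearest sparse neighbors.
dists = np.array([0.1, 0.2, 0.4])
feats = np.array([[1.0], [2.0], [4.0]])  # one feature channel per neighbor

dist_recip = 1.0 / (dists + 1e-8)          # closer neighbors get larger weight
weight = dist_recip / dist_recip.sum()     # normalize so weights sum to 1
interpolated = (feats[:, 0] * weight).sum()
print(weight.round(4))
print(round(float(interpolated), 4))  # ~12/7 ~= 1.7143
```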
================================================
FILE: segmentation/models/pt.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, trunc_normal_
from logger import get_missing_parameters_message, get_unexpected_parameters_message
from pointnet2_ops import pointnet2_utils
from knn_cuda import KNN
from pointnet2_utils import PointNetFeaturePropagation
from gpt2_seg import GPT_extractor, GPT_generator
import math
from extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2
import numpy as np
from z_order import *
def fps(data, number):
'''
data B N 3
number int
'''
fps_idx = pointnet2_utils.furthest_point_sample(data, number)
fps_data = pointnet2_utils.gather_operation(data.transpose(
1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()
return fps_data
class Group(nn.Module):
def __init__(self, num_group, group_size):
super().__init__()
self.num_group = num_group
self.group_size = group_size
self.knn = KNN(k=self.group_size, transpose_mode=True)
self.knn_2 = KNN(k=1, transpose_mode=True)
def simplied_morton_sorting(self, xyz, center):
batch_size, num_points, _ = xyz.shape
distances_batch = torch.cdist(center, center)
distances_batch[:, torch.eye(self.num_group).bool()] = float("inf")
idx_base = torch.arange(
0, batch_size, device=xyz.device) * self.num_group
sorted_indices_list = []
sorted_indices_list.append(idx_base)
distances_batch = distances_batch.view(batch_size, self.num_group, self.num_group).transpose(
1, 2).contiguous().view(batch_size * self.num_group, self.num_group)
distances_batch[idx_base] = float("inf")
distances_batch = distances_batch.view(
batch_size, self.num_group, self.num_group).transpose(1, 2).contiguous()
for i in range(self.num_group - 1):
distances_batch = distances_batch.view(
batch_size * self.num_group, self.num_group)
distances_to_last_batch = distances_batch[sorted_indices_list[-1]]
closest_point_idx = torch.argmin(distances_to_last_batch, dim=-1)
closest_point_idx = closest_point_idx + idx_base
sorted_indices_list.append(closest_point_idx)
distances_batch = distances_batch.view(batch_size, self.num_group, self.num_group).transpose(
1, 2).contiguous().view(batch_size * self.num_group, self.num_group)
distances_batch[closest_point_idx] = float("inf")
distances_batch = distances_batch.view(
batch_size, self.num_group, self.num_group).transpose(1, 2).contiguous()
sorted_indices = torch.stack(sorted_indices_list, dim=-1)
sorted_indices = sorted_indices.view(-1)
return sorted_indices
def morton_sorting(self, xyz, center):
batch_size, num_points, _ = xyz.shape
all_indices = []
for index in range(batch_size):
points = center[index]
z = get_z_values(points.cpu().numpy())
idxs = np.argsort(z)
all_indices.append(idxs)
all_indices = torch.tensor(all_indices, device=xyz.device)
idx_base = torch.arange(
0, batch_size, device=xyz.device).view(-1, 1) * self.num_group
sorted_indices = all_indices + idx_base
sorted_indices = sorted_indices.view(-1)
return sorted_indices
def forward(self, xyz):
'''
input: B N 3
---------------------------
output: B G M 3
center : B G 3
'''
batch_size, num_points, _ = xyz.shape
# fps the centers out
center = fps(xyz, self.num_group) # B G 3
# knn to get the neighborhood
_, idx = self.knn(xyz, center) # B G M
assert idx.size(1) == self.num_group
assert idx.size(2) == self.group_size
idx_base = torch.arange(
0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
idx = idx + idx_base
idx = idx.view(-1)
neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]
neighborhood = neighborhood.view(
batch_size, self.num_group, self.group_size, 3).contiguous()
# normalize
neighborhood = neighborhood - center.unsqueeze(2)
        # simplified Morton sorting; swap in self.morton_sorting for exact z-order sorting
sorted_indices = self.simplied_morton_sorting(xyz, center)
neighborhood = neighborhood.view(
batch_size * self.num_group, self.group_size, 3)[sorted_indices, :, :]
neighborhood = neighborhood.view(
batch_size, self.num_group, self.group_size, 3).contiguous()
center = center.view(
batch_size * self.num_group, 3)[sorted_indices, :]
center = center.view(
batch_size, self.num_group, 3).contiguous()
return neighborhood, center
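The `simplied_morton_sorting` used above orders the patch centers greedily: start from the first FPS center and repeatedly jump to the nearest not-yet-visited center. A minimal numpy sketch of that idea (the function name and single-cloud shapes are mine for illustration; the real method batches this with index arithmetic on the GPU):

```python
import numpy as np

def greedy_nn_order(centers):
    """Greedy nearest-neighbor ordering of (G, 3) centers,
    starting from index 0, mirroring simplied_morton_sorting."""
    G = centers.shape[0]
    dists = np.linalg.norm(centers[:, None] - centers[None], axis=-1)  # (G, G)
    np.fill_diagonal(dists, np.inf)
    order = [0]
    dists[:, 0] = np.inf  # mark index 0 as visited
    for _ in range(G - 1):
        nxt = int(np.argmin(dists[order[-1]]))  # nearest unvisited center
        order.append(nxt)
        dists[:, nxt] = np.inf
    return order
```

Unlike a true Morton sort, this produces a path that depends on the start point, but it avoids the CPU round-trip through `get_z_values`.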
class Encoder_small(nn.Module):
def __init__(self, encoder_channel):
super().__init__()
self.encoder_channel = encoder_channel
self.first_conv = nn.Sequential(
nn.Conv1d(3, 128, 1),
nn.BatchNorm1d(128),
nn.ReLU(inplace=True),
nn.Conv1d(128, 256, 1)
)
self.second_conv = nn.Sequential(
nn.Conv1d(512, 512, 1),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Conv1d(512, self.encoder_channel, 1)
)
def forward(self, point_groups):
'''
point_groups : B G N 3
-----------------
feature_global : B G C
'''
bs, g, n, _ = point_groups.shape
point_groups = point_groups.reshape(bs * g, n, 3)
# encoder
feature = self.first_conv(point_groups.transpose(2, 1))
feature_global = torch.max(feature, dim=2, keepdim=True)[0]
feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1)
feature = self.second_conv(feature)
feature_global = torch.max(feature, dim=2, keepdim=False)[0]
return feature_global.reshape(bs, g, self.encoder_channel)
class Encoder_large(nn.Module): # Embedding module
def __init__(self, encoder_channel):
super().__init__()
self.encoder_channel = encoder_channel
self.first_conv = nn.Sequential(
nn.Conv1d(3, 256, 1),
nn.BatchNorm1d(256),
nn.ReLU(inplace=True),
nn.Conv1d(256, 512, 1),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Conv1d(512, 1024, 1)
)
self.second_conv = nn.Sequential(
nn.Conv1d(2048, 2048, 1),
nn.BatchNorm1d(2048),
nn.ReLU(inplace=True),
nn.Conv1d(2048, self.encoder_channel, 1)
)
def forward(self, point_groups):
'''
point_groups : B G N 3
-----------------
feature_global : B G C
'''
bs, g, n, _ = point_groups.shape
point_groups = point_groups.reshape(bs * g, n, 3)
# encoder
feature = self.first_conv(point_groups.transpose(2, 1)) # BG 256 n
feature_global = torch.max(feature, dim=2, keepdim=True)[0] # BG 256 1
feature = torch.cat(
[feature_global.expand(-1, -1, n), feature], dim=1) # BG 512 n
feature = self.second_conv(feature) # BG 1024 n
feature_global = torch.max(feature, dim=2, keepdim=False)[0] # BG 1024
return feature_global.reshape(bs, g, self.encoder_channel)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
self.num_heads).permute(2, 0, 3, 1, 4)
# make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q * self.scale) @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
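`Attention` above is standard multi-head self-attention. The core computation for one head can be sketched in numpy (function names and shapes are mine for illustration; the module fuses all heads through one `qkv` projection):

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def single_head_attention(q, k, v):
    """q, k, v: (N, d) arrays for one head; returns (N, d)."""
    scale = q.shape[-1] ** -0.5          # same head_dim ** -0.5 scaling as above
    attn = softmax((q * scale) @ k.T)    # (N, N) row-stochastic weights
    return attn @ v
```

Because each attention row sums to 1, the output is always a convex combination of the value vectors.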
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class TransformerEncoder(nn.Module):
""" Transformer Encoder without hierarchical structure
"""
def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
super().__init__()
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=drop_path_rate[i] if isinstance(
drop_path_rate, list) else drop_path_rate
)
for i in range(depth)])
def forward(self, x, pos):
feature_list = []
        fetch_idx = [7, 15, 23]  # intermediate blocks to return (assumes depth >= 24)
for i, block in enumerate(self.blocks):
x = block(x + pos)
if i in fetch_idx:
feature_list.append(x)
return feature_list
class PositionEmbeddingCoordsSine(nn.Module):
"""Similar to transformer's position encoding, but generalizes it to
arbitrary dimensions and continuous coordinates.
Args:
n_dim: Number of input dimensions, e.g. 2 for image coordinates.
d_model: Number of dimensions to encode into
temperature:
scale:
"""
def __init__(self, n_dim: int = 1, d_model: int = 256, temperature=10000, scale=None):
super().__init__()
self.n_dim = n_dim
self.num_pos_feats = d_model // n_dim // 2 * 2
self.temperature = temperature
self.padding = d_model - self.num_pos_feats * self.n_dim
if scale is None:
scale = 1.0
self.scale = scale * 2 * math.pi
def forward(self, xyz: torch.Tensor) -> torch.Tensor:
"""
Args:
xyz: Point positions (*, d_in)
Returns:
pos_emb (*, d_out)
"""
assert xyz.shape[-1] == self.n_dim
dim_t = torch.arange(self.num_pos_feats,
dtype=torch.float32, device=xyz.device)
dim_t = self.temperature ** (2 * torch.div(dim_t,
2, rounding_mode='trunc') / self.num_pos_feats)
xyz = xyz * self.scale
pos_divided = xyz.unsqueeze(-1) / dim_t
pos_sin = pos_divided[..., 0::2].sin()
pos_cos = pos_divided[..., 1::2].cos()
pos_emb = torch.stack([pos_sin, pos_cos], dim=-
1).reshape(*xyz.shape[:-1], -1)
# Pad unused dimensions with zeros
pos_emb = F.pad(pos_emb, (0, self.padding))
return pos_emb
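`PositionEmbeddingCoordsSine` generalizes the transformer's sinusoidal position encoding to continuous coordinates: each input dimension gets its own bank of sin/cos pairs. A simplified numpy version of the forward pass (function name and defaults are mine; padding and truncated division mirror the module above):

```python
import numpy as np

def sine_coord_embed(xyz, d_model=6, temperature=10000.0, scale=2 * np.pi):
    """xyz: (N, n_dim) coords -> (N, d_model) embedding.
    Each input dim gets d_model // n_dim // 2 sin/cos pairs."""
    n_dim = xyz.shape[-1]
    num_feats = d_model // n_dim // 2 * 2
    dim_t = np.arange(num_feats, dtype=np.float64)
    dim_t = temperature ** (2 * (dim_t // 2) / num_feats)
    pos = (xyz * scale)[..., None] / dim_t              # (N, n_dim, num_feats)
    emb = np.stack([np.sin(pos[..., 0::2]),
                    np.cos(pos[..., 1::2])], axis=-1)   # interleave sin/cos
    emb = emb.reshape(xyz.shape[0], -1)
    pad = d_model - num_feats * n_dim                   # zero-pad leftover dims
    return np.pad(emb, ((0, 0), (0, pad)))
```

At the origin every sin term is 0 and every cos term is 1, which gives a quick sanity check on the layout.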
class get_model(nn.Module):
def __init__(self, cls_dim, trans_dim=384, depth=12, drop_path_rate=0.1, num_heads=6, decoder_depth=4, group_size=32, num_group=128, prop_dim=1024, label_dim1=512, label_dim2=256, encoder_dims=384):
super().__init__()
self.trans_dim = trans_dim
self.depth = depth
self.drop_path_rate = drop_path_rate
self.cls_dim = cls_dim
self.num_heads = num_heads
self.decoder_depth = decoder_depth
self.group_size = group_size
self.num_group = num_group
self.prop_dim = prop_dim
self.label_dim1 = label_dim1
self.label_dim2 = label_dim2
# grouper
self.group_divider = Group(
num_group=self.num_group, group_size=self.group_size)
# define the encoder
self.encoder_dims = encoder_dims
assert encoder_dims in [384, 768, 1024]
if encoder_dims == 384:
self.encoder = Encoder_small(encoder_channel=self.encoder_dims)
else:
self.encoder = Encoder_large(encoder_channel=self.encoder_dims)
# bridge encoder and transformer
self.pos_embed = PositionEmbeddingCoordsSine(3, self.encoder_dims, 1.0)
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
self.sos_pos = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
self.blocks = GPT_extractor(
embed_dim=self.encoder_dims,
num_heads=self.num_heads,
num_layers=self.depth,
trans_dim=self.trans_dim,
group_size=self.group_size,
drop_path_rate=self.drop_path_rate
)
self.generator_blocks = GPT_generator(
embed_dim=self.encoder_dims,
num_heads=self.num_heads,
num_layers=self.decoder_depth,
trans_dim=self.trans_dim,
group_size=self.group_size,
drop_path_rate=self.drop_path_rate
)
self.norm = nn.LayerNorm(self.trans_dim)
self.label_conv = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False),
nn.BatchNorm1d(64),
nn.LeakyReLU(0.2))
self.propagation_0 = PointNetFeaturePropagation(in_channel=3 * self.encoder_dims + 3,
mlp=[self.trans_dim * 4, self.prop_dim])
self.convs1 = nn.Conv1d(6*self.encoder_dims +
64 + self.prop_dim, self.label_dim1, 1)
self.dp1 = nn.Dropout(0.5)
self.convs2 = nn.Conv1d(self.label_dim1, self.label_dim2, 1)
self.convs3 = nn.Conv1d(self.label_dim2, self.cls_dim, 1)
self.bns1 = nn.BatchNorm1d(self.label_dim1)
self.bns2 = nn.BatchNorm1d(self.label_dim2)
self.relu = nn.ReLU()
self.loss_func_p1 = ChamferDistanceL1().cuda()
self.loss_func_p2 = ChamferDistanceL2().cuda()
        # cross-entropy loss used by get_loss_acc below
        self.loss_ce = nn.CrossEntropyLoss()
def get_loss_acc(self, ret, gt):
loss = self.loss_ce(ret, gt.long())
pred = ret.argmax(-1)
acc = (pred == gt).sum() / float(gt.size(0))
return loss, acc * 100
def load_model_from_ckpt(self, bert_ckpt_path):
if bert_ckpt_path is not None:
ckpt = torch.load(bert_ckpt_path)
base_ckpt = {k.replace("module.", ""): v for k,
v in ckpt['base_model'].items()}
for k in list(base_ckpt.keys()):
if k.startswith('GPT_Transformer'):
base_ckpt[k[len('GPT_Transformer.'):]] = base_ckpt[k]
del base_ckpt[k]
elif k.startswith('base_model'):
base_ckpt[k[len('base_model.'):]] = base_ckpt[k]
del base_ckpt[k]
incompatible = self.load_state_dict(base_ckpt, strict=False)
if incompatible.missing_keys:
print('missing_keys')
print(
get_missing_parameters_message(incompatible.missing_keys)
)
if incompatible.unexpected_keys:
print('unexpected_keys')
print(
get_unexpected_parameters_message(
incompatible.unexpected_keys)
)
print(
f'[Transformer] Successful Loading the ckpt from {bert_ckpt_path}')
def forward(self, pts, cls_label):
B, C, N = pts.shape
pts = pts.transpose(-1, -2) # B N 3
neighborhood, center = self.group_divider(pts)
        group_input_tokens = self.encoder(neighborhood)  # B G C
B = group_input_tokens.shape[0]
cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
pos = self.pos_embed(center)
sos_pos = self.sos_pos.expand(group_input_tokens.size(0), -1, -1)
pos = torch.cat([sos_pos, pos], dim=1)
relative_position = center[:, 1:, :] - center[:, :-1, :]
relative_norm = torch.norm(relative_position, dim=-1, keepdim=True)
relative_direction = relative_position / relative_norm
position = torch.cat(
[center[:, 0, :].unsqueeze(1), relative_direction], dim=1)
pos_relative = self.pos_embed(position)
x = torch.cat((cls_tokens, group_input_tokens), dim=1)
pos = torch.cat((cls_pos, pos), dim=1)
# transformer
encoded_features, feature_list = self.blocks(x, pos, classify=True)
encoded_features = torch.cat(
[encoded_features[:, 0, :].unsqueeze(1), encoded_features[:, 2:-1, :]], dim=1)
rebuild_points = self.generator_blocks(
encoded_features, pos_relative)
neighborhood = neighborhood + center.unsqueeze(2)
gt_points = neighborhood.reshape(
B*(self.num_group), self.group_size, 3)
loss1 = self.loss_func_p1(rebuild_points, gt_points)
loss2 = self.loss_func_p2(rebuild_points, gt_points)
feature_list = [self.norm(x).transpose(-1, -2).contiguous()
for x in feature_list]
x = torch.cat(
(feature_list[0], feature_list[1], feature_list[2]), dim=1) # 1152
x_max = torch.max(x, 2)[0]
x_avg = torch.mean(x, 2)
x_max_feature = x_max.view(B, -1).unsqueeze(-1).repeat(1, 1, N)
x_avg_feature = x_avg.view(B, -1).unsqueeze(-1).repeat(1, 1, N)
cls_label_one_hot = cls_label.view(B, 16, 1)
cls_label_feature = self.label_conv(cls_label_one_hot).repeat(1, 1, N)
x_global_feature = torch.cat(
(x_max_feature, x_avg_feature, cls_label_feature), 1) # 1152*2 + 64
f_level_0 = self.propagation_0(
pts.transpose(-1, -2), center.transpose(-1, -2), pts.transpose(-1, -2), x)
x = torch.cat((f_level_0, x_global_feature), 1)
x = self.relu(self.bns1(self.convs1(x)))
x = self.dp1(x)
x = self.relu(self.bns2(self.convs2(x)))
x = self.convs3(x)
x = F.log_softmax(x, dim=1)
x = x.permute(0, 2, 1)
return x
class get_loss(nn.Module):
def __init__(self):
super(get_loss, self).__init__()
def forward(self, pred, target):
total_loss = F.nll_loss(pred, target)
return total_loss
================================================
FILE: segmentation/models/z_order.py
================================================
import numpy as np
def round_to_int_32(data):
"""
Takes a Numpy array of float values between
-1 and 1, and rounds them to significant
32-bit integer values, to be used in the
morton code computation
:param data: multidimensional numpy array
:return: same as data but in 32-bit int format
"""
    # shift the points to be positive, then scale them up (input assumed in [-1, 1])
min_data = np.abs(np.min(data)-0.5)
data = 256*(data + min_data)
# now convert to int
data = np.round(2 ** 21 - data).astype(dtype=np.int32)
return data
def split_by_3(x):
"""
Method to separate bits of a 32-bit integer
by 3 positions apart, using the magic bits
https://www.forceflow.be/2013/10/07/morton-encodingdecoding-through-bit-interleaving-implementations/
:param x: 32-bit integer
:return: x with bits separated
"""
# we only look at 21 bits, since we want to generate
# a 64-bit code eventually (3 x 21 bits = 63 bits, which
# is the maximum we can fit in a 64-bit code)
x &= 0x1fffff # only take first 21 bits
# shift left 32 bits, OR with self, and 00011111000000000000000000000000000000001111111111111111
x = (x | (x << 32)) & 0x1f00000000ffff
# shift left 16 bits, OR with self, and 00011111000000000000000011111111000000000000000011111111
x = (x | (x << 16)) & 0x1f0000ff0000ff
# shift left 8 bits, OR with self, and 0001000000001111000000001111000000001111000000001111000000000000
x = (x | (x << 8)) & 0x100f00f00f00f00f
# shift left 4 bits, OR with self, and 0001000011000011000011000011000011000011000011000011000100000000
x = (x | (x << 4)) & 0x10c30c30c30c30c3
# shift left 2 bits, OR with self, and 0001001001001001001001001001001001001001001001001001001001001001
x = (x | (x << 2)) & 0x1249249249249249
return x
def get_z_order(x, y, z):
"""
Given 3 arrays of corresponding x, y, z
coordinates, compute the morton (or z) code for
each point and return an index array
We compute the Morton order as follows:
1- Split all coordinates by 3 (add 2 zeros between bits)
2- Shift bits left by 1 for y and 2 for z
3- Interleave x, shifted y, and shifted z
    The Morton order is the final interleaved bit sequence
:param x: x coordinates
:param y: y coordinates
:param z: z coordinates
:return: index array with morton code
"""
res = 0
res |= split_by_3(x) | split_by_3(y) << 1 | split_by_3(z) << 2
return res
def get_z_values(data):
"""
Computes the z values for a point array
:param data: Nx3 array of x, y, and z location
:return: Nx1 array of z values
"""
points_round = round_to_int_32(data) # convert to int
z = get_z_order(points_round[:, 0], points_round[:, 1], points_round[:, 2])
return z
================================================
FILE: segmentation/pointnet_util.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np
# reference https://github.com/yanx27/Pointnet_Pointnet2_pytorch, modified by Yang You
def timeit(tag, t):
print("{}: {}s".format(tag, time() - t))
return time()
def pc_normalize(pc):
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
return pc
def square_distance(src, dst):
"""
Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
Input:
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M]
"""
return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1)
def index_points(points, idx):
"""
Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S, [K]]
Return:
new_points:, indexed points data, [B, S, [K], C]
"""
raw_size = idx.size()
idx = idx.reshape(raw_size[0], -1)
res = torch.gather(points, 1, idx[..., None].expand(-1, -1, points.size(-1)))
return res.reshape(*raw_size, -1)
def farthest_point_sample(xyz, npoint):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
distance = torch.ones(B, N).to(device) * 1e10
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
batch_indices = torch.arange(B, dtype=torch.long).to(device)
for i in range(npoint):
centroids[:, i] = farthest
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
dist = torch.sum((xyz - centroid) ** 2, -1)
distance = torch.min(distance, dist)
farthest = torch.max(distance, -1)[1]
return centroids
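`farthest_point_sample` above is the classic iterative FPS. A deterministic single-cloud numpy sketch (my function, seeded at index 0 instead of a random start) makes the update rule explicit: track each point's distance to the selected set and greedily pick the point maximizing it:

```python
import numpy as np

def fps_numpy(xyz, npoint, start=0):
    """xyz: (N, 3); returns npoint indices, greedily maximizing
    the minimum squared distance to the already-selected set."""
    N = xyz.shape[0]
    dist = np.full(N, np.inf)            # min sq. distance to selected set
    idx = np.empty(npoint, dtype=np.int64)
    farthest = start
    for i in range(npoint):
        idx[i] = farthest
        d = np.sum((xyz - xyz[farthest]) ** 2, axis=-1)
        dist = np.minimum(dist, d)       # update running minimum
        farthest = int(np.argmax(dist))  # next pick: farthest from the set
    return idx
```

The torch version carries the same `distance` buffer per batch element and seeds `farthest` randomly.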
def query_ball_point(radius, nsample, xyz, new_xyz):
"""
Input:
radius: local region radius
nsample: max sample number in local region
xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]
Return:
group_idx: grouped points index, [B, S, nsample]
"""
device = xyz.device
B, N, C = xyz.shape
_, S, _ = new_xyz.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(new_xyz, xyz)
group_idx[sqrdists > radius ** 2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N
group_idx[mask] = group_first[mask]
return group_idx
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, knn=False):
"""
Input:
npoint:
radius:
nsample:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, npoint, nsample, 3]
new_points: sampled points data, [B, npoint, nsample, 3+D]
"""
B, N, C = xyz.shape
S = npoint
fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]
torch.cuda.empty_cache()
new_xyz = index_points(xyz, fps_idx)
torch.cuda.empty_cache()
if knn:
dists = square_distance(new_xyz, xyz) # B x npoint x N
idx = dists.argsort()[:, :, :nsample] # B x npoint x K
else:
idx = query_ball_point(radius, nsample, xyz, new_xyz)
torch.cuda.empty_cache()
grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
torch.cuda.empty_cache()
grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
torch.cuda.empty_cache()
if points is not None:
grouped_points = index_points(points, idx)
new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
else:
new_points = grouped_xyz_norm
if returnfps:
return new_xyz, new_points, grouped_xyz, fps_idx
else:
return new_xyz, new_points
def sample_and_group_all(xyz, points):
"""
Input:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, 1, 3]
new_points: sampled points data, [B, 1, N, 3+D]
"""
device = xyz.device
B, N, C = xyz.shape
new_xyz = torch.zeros(B, 1, C).to(device)
grouped_xyz = xyz.view(B, 1, N, C)
if points is not None:
new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
else:
new_points = grouped_xyz
return new_xyz, new_points
class PointNetSetAbstraction(nn.Module):
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all, knn=False):
super(PointNetSetAbstraction, self).__init__()
self.npoint = npoint
self.radius = radius
self.nsample = nsample
self.knn = knn
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.group_all = group_all
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, N, C]
points: input points data, [B, N, C]
Return:
new_xyz: sampled points position data, [B, S, C]
new_points_concat: sample points feature data, [B, S, D']
"""
if self.group_all:
new_xyz, new_points = sample_and_group_all(xyz, points)
else:
new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points, knn=self.knn)
# new_xyz: sampled points position data, [B, npoint, C]
# new_points: sampled points data, [B, npoint, nsample, C+D]
new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
new_points = torch.max(new_points, 2)[0].transpose(1, 2)
return new_xyz, new_points
class PointNetSetAbstractionMsg(nn.Module):
def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list, knn=False):
super(PointNetSetAbstractionMsg, self).__init__()
self.npoint = npoint
self.radius_list = radius_list
self.nsample_list = nsample_list
self.knn = knn
self.conv_blocks = nn.ModuleList()
self.bn_blocks = nn.ModuleList()
for i in range(len(mlp_list)):
convs = nn.ModuleList()
bns = nn.ModuleList()
last_channel = in_channel + 3
for out_channel in mlp_list[i]:
convs.append(nn.Conv2d(last_channel, out_channel, 1))
bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.conv_blocks.append(convs)
self.bn_blocks.append(bns)
def forward(self, xyz, points, seed_idx=None):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
B, N, C = xyz.shape
S = self.npoint
new_xyz = index_points(xyz, farthest_point_sample(xyz, S) if seed_idx is None else seed_idx)
new_points_list = []
for i, radius in enumerate(self.radius_list):
K = self.nsample_list[i]
if self.knn:
dists = square_distance(new_xyz, xyz) # B x npoint x N
group_idx = dists.argsort()[:, :, :K] # B x npoint x K
else:
group_idx = query_ball_point(radius, K, xyz, new_xyz)
grouped_xyz = index_points(xyz, group_idx)
grouped_xyz -= new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, group_idx)
grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
else:
grouped_points = grouped_xyz
grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
for j in range(len(self.conv_blocks[i])):
conv = self.conv_blocks[i][j]
bn = self.bn_blocks[i][j]
grouped_points = F.relu(bn(conv(grouped_points)))
new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
new_points_list.append(new_points)
new_points_concat = torch.cat(new_points_list, dim=1).transpose(1, 2)
return new_xyz, new_points_concat
# Note: this function swaps N and C
class PointNetFeaturePropagation(nn.Module):
def __init__(self, in_channel, mlp):
super(PointNetFeaturePropagation, self).__init__()
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm1d(out_channel))
last_channel = out_channel
def forward(self, xyz1, xyz2, points1, points2):
"""
Input:
xyz1: input points position data, [B, C, N]
xyz2: sampled input points position data, [B, C, S]
points1: input points data, [B, D, N]
points2: input points data, [B, D, S]
Return:
new_points: upsampled points data, [B, D', N]
"""
xyz1 = xyz1.permute(0, 2, 1)
xyz2 = xyz2.permute(0, 2, 1)
points2 = points2.permute(0, 2, 1)
B, N, C = xyz1.shape
_, S, _ = xyz2.shape
if S == 1:
interpolated_points = points2.repeat(1, N, 1)
else:
dists = square_distance(xyz1, xyz2)
dists, idx = dists.sort(dim=-1)
dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
dist_recip = 1.0 / (dists + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
if points1 is not None:
points1 = points1.permute(0, 2, 1)
new_points = torch.cat([points1, interpolated_points], dim=-1)
else:
new_points = interpolated_points
new_points = new_points.permute(0, 2, 1)
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
return new_points
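The upsampling step in `PointNetFeaturePropagation.forward` interpolates features from the sampled points using inverse-distance weights over the 3 nearest neighbors. The weighting alone, in numpy (function name, `k` parameter, and shapes are mine for illustration):

```python
import numpy as np

def idw_interpolate(query, support, feats, k=3, eps=1e-8):
    """query: (N, 3), support: (S, 3), feats: (S, D).
    Returns (N, D): inverse-distance weighted average of the
    k nearest support features for each query point."""
    d2 = np.sum((query[:, None] - support[None]) ** 2, axis=-1)  # (N, S)
    idx = np.argsort(d2, axis=-1)[:, :k]                         # (N, k) nearest
    d2k = np.take_along_axis(d2, idx, axis=-1)
    w = 1.0 / (d2k + eps)                                        # dist_recip
    w = w / w.sum(axis=-1, keepdims=True)                        # normalize
    return np.einsum('nk,nkd->nd', w, feats[idx])
```

When a query coincides with a support point, the `eps` term makes its weight dominate, so the feature passes through nearly unchanged, matching the `1.0 / (dists + 1e-8)` trick above.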
================================================
FILE: segmentation/provider.py
================================================
import numpy as np
def normalize_data(batch_data):
    """ Normalize the batch data: center each cloud at the origin and scale it to the unit sphere.
Input:
BxNxC array
Output:
BxNxC array
"""
B, N, C = batch_data.shape
normal_data = np.zeros((B, N, C))
for b in range(B):
pc = batch_data[b]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
pc = pc / m
normal_data[b] = pc
return normal_data
def shuffle_data(data, labels):
""" Shuffle data and labels.
Input:
data: B,N,... numpy array
label: B,... numpy array
Return:
shuffled data, label and shuffle indices
"""
idx = np.arange(len(labels))
np.random.shuffle(idx)
return data[idx, ...], labels[idx], idx
def shuffle_points(batch_data):
""" Shuffle orders of points in each point cloud -- changes FPS behavior.
Use the same shuffling idx for the entire batch.
Input:
BxNxC array
Output:
BxNxC array
"""
idx = np.arange(batch_data.shape[1])
np.random.shuffle(idx)
return batch_data[:,idx,:]
def rotate_point_cloud(batch_data):
    """ Randomly rotate the point clouds to augment the dataset.
        Rotation is per shape, about the up (y) axis.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
return rotated_data
def rotate_point_cloud_z(batch_data):
    """ Randomly rotate the point clouds to augment the dataset.
        Rotation is per shape, about the z axis.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, sinval, 0],
[-sinval, cosval, 0],
[0, 0, 1]])
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
return rotated_data
def rotate_point_cloud_with_normal(batch_xyz_normal):
''' Randomly rotate XYZ, normal point cloud.
Input:
            batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 are the normal
Output:
B,N,6, rotated XYZ, normal point cloud
'''
for k in range(batch_xyz_normal.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_xyz_normal[k,:,0:3]
shape_normal = batch_xyz_normal[k,:,3:6]
batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)
return batch_xyz_normal
def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):
""" Randomly perturb the point clouds by small rotations
Input:
BxNx6 array, original batch of point clouds and point normals
Return:
          BxNx6 array, rotated batch of point clouds with normals
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
Rx = np.array([[1,0,0],
[0,np.cos(angles[0]),-np.sin(angles[0])],
[0,np.sin(angles[0]),np.cos(angles[0])]])
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
[0,1,0],
[-np.sin(angles[1]),0,np.cos(angles[1])]])
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
[np.sin(angles[2]),np.cos(angles[2]),0],
[0,0,1]])
R = np.dot(Rz, np.dot(Ry,Rx))
shape_pc = batch_data[k,:,0:3]
shape_normal = batch_data[k,:,3:6]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R)
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R)
return rotated_data
def rotate_point_cloud_by_angle(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k,:,0:3]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
return rotated_data
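Note the convention in the rotation helpers above: points are row vectors, so each cloud is multiplied on the right by the matrix (p' = p @ R). A standalone restatement of the y-axis matrix from `rotate_point_cloud_by_angle` makes this easy to check (the function name is mine):

```python
import numpy as np

def rotate_about_y(points, angle):
    """points: (N, 3) row vectors; same matrix layout as
    rotate_point_cloud_by_angle above, applied as points @ R."""
    c, s = np.cos(angle), np.sin(angle)
    R = np.array([[c, 0, s],
                  [0, 1, 0],
                  [-s, 0, c]])
    return points @ R
```

With this convention a 90-degree rotation sends the x axis to the z axis, while the up (y) axis is left unchanged.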
def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle.
Input:
BxNx6 array, original batch of point clouds with normal
scalar, angle of rotation
Return:
          BxNx6 array, rotated batch of point clouds with normal
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k,:,0:3]
shape_normal = batch_data[k,:,3:6]
rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix)
return rotated_data
def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
""" Randomly perturb the point clouds by small rotations
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in range(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
Rx = np.array([[1,0,0],
[0,np.cos(angles[0]),-np.sin(angles[0])],
[0,np.sin(angles[0]),np.cos(angles[0])]])
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
[0,1,0],
[-np.sin(angles[1]),0,np.cos(angles[1])]])
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
[np.sin(angles[2]),np.cos(angles[2]),0],
[0,0,1]])
R = np.dot(Rz, np.dot(Ry,Rx))
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
return rotated_data
def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
""" Randomly jitter points. jittering is per point.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, jittered batch of point clouds
"""
B, N, C = batch_data.shape
assert(clip > 0)
jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
jittered_data += batch_data
return jittered_data
def shift_point_cloud(batch_data, shift_range=0.1):
""" Randomly shift point cloud. Shift is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, shifted batch of point clouds
"""
B, N, C = batch_data.shape
shifts = np.random.uniform(-shift_range, shift_range, (B,3))
for batch_index in range(B):
batch_data[batch_index,:,:] += shifts[batch_index,:]
return batch_data
def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
""" Randomly scale the point cloud. Scale is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, scaled batch of point clouds
"""
B, N, C = batch_data.shape
scales = np.random.uniform(scale_low, scale_high, B)
for batch_index in range(B):
batch_data[batch_index,:,:] *= scales[batch_index]
return batch_data
def random_point_dropout(batch_pc, max_dropout_ratio=0.875):
''' batch_pc: BxNx3 '''
for b in range(batch_pc.shape[0]):
dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875
drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0]
if len(drop_idx)>0:
batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point
return batch_pc
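# A minimal, self-contained sketch of how these per-cloud augmentations compose,
# re-implemented in vectorized numpy so it runs standalone; the loop-based
# jitter/scale/shift functions above are equivalent. `augment` is a hypothetical
# helper, not part of this repo.

```python
import numpy as np

def augment(batch, sigma=0.01, clip=0.05, shift_range=0.1,
            scale_low=0.8, scale_high=1.25, rng=None):
    """Jitter per point, then scale and shift per cloud (vectorized sketch)."""
    rng = np.random.default_rng(rng)
    B, N, C = batch.shape
    out = batch + np.clip(sigma * rng.standard_normal((B, N, C)), -clip, clip)
    out = out * rng.uniform(scale_low, scale_high, (B, 1, 1))      # per-cloud scale
    out = out + rng.uniform(-shift_range, shift_range, (B, 1, 3))  # per-cloud shift
    return out.astype(np.float32)

batch = np.zeros((4, 128, 3), dtype=np.float32)
aug = augment(batch, rng=0)
assert aug.shape == (4, 128, 3)
# jitter is clipped to 0.05, scale <= 1.25, shift <= 0.1, so from the origin
# every coordinate stays within 0.05 * 1.25 + 0.1
assert np.abs(aug).max() <= 0.05 * 1.25 + 0.1 + 1e-6
```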
================================================
FILE: tools/__init__.py
================================================
# from .runner import run_net
from .runner import test_net
from .runner_pretrain import run_net as pretrain_run_net
from .runner_finetune import run_net as finetune_run_net
from .runner_finetune import test_net as test_run_net
================================================
FILE: tools/builder.py
================================================
import os
import sys
# online package
import torch
# optimizer
import torch.optim as optim
# dataloader
from datasets import build_dataset_from_cfg
from models import build_model_from_cfg
# utils
from utils.logger import *
from utils.misc import *
from timm.scheduler import CosineLRScheduler
def dataset_builder(args, config):
dataset = build_dataset_from_cfg(config._base_, config.others)
shuffle = config.others.subset == 'train'
if args.distributed:
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, shuffle=shuffle)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.others.bs,
num_workers=int(
args.num_workers),
drop_last=config.others.subset == 'train',
worker_init_fn=worker_init_fn,
sampler=sampler)
else:
sampler = None
dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.others.bs,
shuffle=shuffle,
drop_last=config.others.subset == 'train',
num_workers=int(
args.num_workers),
worker_init_fn=worker_init_fn)
return sampler, dataloader
def model_builder(config):
model = build_model_from_cfg(config)
return model
def build_opti_sche(base_model, config):
opti_config = config.optimizer
if opti_config.type == 'AdamW':
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
decay = []
no_decay = []
for name, param in model.module.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith(".bias") or 'token' in name or name in skip_list:
# print(name)
no_decay.append(param)
else:
decay.append(param)
return [
{'params': no_decay, 'weight_decay': 0.},
{'params': decay, 'weight_decay': weight_decay}]
param_groups = add_weight_decay(
base_model, weight_decay=opti_config.kwargs.weight_decay)
optimizer = optim.AdamW(param_groups, **opti_config.kwargs)
elif opti_config.type == 'Adam':
optimizer = optim.Adam(base_model.parameters(), **opti_config.kwargs)
elif opti_config.type == 'SGD':
optimizer = optim.SGD(base_model.parameters(),
nesterov=True, **opti_config.kwargs)
else:
raise NotImplementedError()
sche_config = config.scheduler
if sche_config.type == 'LambdaLR':
scheduler = build_lambda_sche(optimizer, sche_config.kwargs) # misc.py
elif sche_config.type == 'CosLR':
scheduler = CosineLRScheduler(optimizer,
t_initial=sche_config.kwargs.epochs,
# t_mul=1,
lr_min=1e-6,
cycle_decay=0.1, # decay_rate
warmup_lr_init=1e-6,
warmup_t=sche_config.kwargs.initial_epochs,
cycle_limit=1,
t_in_epochs=True)
elif sche_config.type == 'StepLR':
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer, **sche_config.kwargs)
elif sche_config.type == 'function':
scheduler = None
else:
raise NotImplementedError()
if config.get('bnmscheduler') is not None:
bnsche_config = config.bnmscheduler
if bnsche_config.type == 'Lambda':
bnscheduler = build_lambda_bnsche(
base_model, bnsche_config.kwargs) # misc.py
scheduler = [scheduler, bnscheduler]
return optimizer, scheduler
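# The AdamW branch above exempts 1-D parameters, biases, and anything whose name
# contains 'token' from weight decay. The grouping rule can be sketched without
# torch by operating on name -> shape pairs; the parameter names below are
# hypothetical stand-ins for model.named_parameters().

```python
def split_decay_groups(named_shapes, skip_list=()):
    """Mirror add_weight_decay above: return (decay, no_decay) parameter names."""
    decay, no_decay = [], []
    for name, shape in named_shapes.items():
        if len(shape) == 1 or name.endswith(".bias") or 'token' in name or name in skip_list:
            no_decay.append(name)
        else:
            decay.append(name)
    return decay, no_decay

params = {
    "blocks.0.attn.qkv.weight": (384, 384),  # 2-D weight: decayed
    "blocks.0.attn.qkv.bias": (384,),        # bias: not decayed
    "blocks.0.norm1.weight": (384,),         # 1-D (LayerNorm): not decayed
    "cls_token": (1, 1, 384),                # contains 'token': not decayed
}
decay, no_decay = split_decay_groups(params)
assert decay == ["blocks.0.attn.qkv.weight"]
assert set(no_decay) == {"blocks.0.attn.qkv.bias", "blocks.0.norm1.weight", "cls_token"}
```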
def resume_model(base_model, args, logger=None):
ckpt_path = os.path.join(args.experiment_path, 'ckpt-last.pth')
if not os.path.exists(ckpt_path):
print_log(
f'[RESUME INFO] no checkpoint file from path {ckpt_path}...', logger=logger)
return 0, 0
print_log(
f'[RESUME INFO] Loading model weights from {ckpt_path}...', logger=logger)
# load state dict
map_location = {'cuda:%d' % 0: 'cuda:%d' % args.local_rank}
state_dict = torch.load(ckpt_path, map_location=map_location)
# parameter resume of base model
# if args.local_rank == 0:
base_ckpt = {k.replace("module.", ""): v for k,
v in state_dict['base_model'].items()}
base_model.load_state_dict(base_ckpt, strict=True)
# parameter
start_epoch = state_dict['epoch'] + 1
best_metrics = state_dict['best_metrics']
if not isinstance(best_metrics, dict):
best_metrics = best_metrics.state_dict()
# print(best_metrics)
print_log(
f'[RESUME INFO] resume ckpts @ {start_epoch - 1} epoch( best_metrics = {str(best_metrics):s})', logger=logger)
return start_epoch, best_metrics
def resume_optimizer(optimizer, args, logger=None):
ckpt_path = os.path.join(args.experiment_path, 'ckpt-last.pth')
if not os.path.exists(ckpt_path):
print_log(
f'[RESUME INFO] no checkpoint file from path {ckpt_path}...', logger=logger)
return 0, 0, 0
print_log(
f'[RESUME INFO] Loading optimizer from {ckpt_path}...', logger=logger)
# load state dict
state_dict = torch.load(ckpt_path, map_location='cpu')
# optimizer
optimizer.load_state_dict(state_dict['optimizer'])
def save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, prefix, args, logger=None):
if args.local_rank == 0:
torch.save({
'base_model': base_model.module.state_dict() if args.distributed else base_model.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'metrics': metrics.state_dict() if metrics is not None else dict(),
'best_metrics': best_metrics.state_dict() if best_metrics is not None else dict(),
}, os.path.join(args.experiment_path, prefix + '.pth'))
print_log(
f"Save checkpoint at {os.path.join(args.experiment_path, prefix + '.pth')}", logger=logger)
def load_model(base_model, ckpt_path, logger=None):
if not os.path.exists(ckpt_path):
raise NotImplementedError(
'no checkpoint file from path %s...' % ckpt_path)
print_log(f'Loading weights from {ckpt_path}...', logger=logger)
# load state dict
state_dict = torch.load(ckpt_path, map_location='cpu')
# parameter resume of base model
if state_dict.get('model') is not None:
base_ckpt = {k.replace("module.", ""): v for k,
v in state_dict['model'].items()}
elif state_dict.get('base_model') is not None:
base_ckpt = {k.replace("module.", ""): v for k,
v in state_dict['base_model'].items()}
else:
raise RuntimeError('mismatch of ckpt weight')
base_model.load_state_dict(base_ckpt, strict=True)
epoch = -1
if state_dict.get('epoch') is not None:
epoch = state_dict['epoch']
if state_dict.get('metrics') is not None:
metrics = state_dict['metrics']
if not isinstance(metrics, dict):
metrics = metrics.state_dict()
else:
metrics = 'No Metrics'
print_log(
f'ckpts @ {epoch} epoch( performance = {str(metrics):s})', logger=logger)
return
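# Both resume_model and load_model strip the "module." prefix that
# DistributedDataParallel prepends to parameter names before loading into a bare
# model. A standalone sketch of that key rewrite (the keys below are hypothetical):

```python
# Hypothetical DDP-saved state-dict keys.
ddp_state = {
    "module.encoder.weight": 1,
    "module.encoder.bias": 2,
    "head.weight": 3,  # already-unprefixed keys pass through unchanged
}

# Mirrors: {k.replace("module.", ""): v for k, v in state_dict['base_model'].items()}
base_ckpt = {k.replace("module.", ""): v for k, v in ddp_state.items()}
assert set(base_ckpt) == {"encoder.weight", "encoder.bias", "head.weight"}
```

Note that str.replace also rewrites a "module." substring appearing mid-key; that is harmless for the usual leading-prefix case but worth knowing.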
================================================
FILE: tools/runner.py
================================================
import torch
import torch.nn as nn
import os
import json
from tools import builder
from utils import misc, dist_utils
import time
from utils.logger import *
import cv2
import numpy as np
def test_net(args, config):
logger = get_logger(args.log_name)
print_log('Tester start ... ', logger=logger)
_, test_dataloader = builder.dataset_builder(args, config.dataset.test)
base_model = builder.model_builder(config.model)
# base_model.load_model_from_ckpt(args.ckpts)
builder.load_model(base_model, args.ckpts, logger=logger)
if args.use_gpu:
base_model.to(args.local_rank)
# DDP
if args.distributed:
raise NotImplementedError()
test(base_model, test_dataloader, args, config, logger=logger)
# visualization
def test(base_model, test_dataloader, args, config, logger=None):
base_model.eval() # set model to eval mode
target = './vis'
useful_cate = [
"02691156", # plane
"04379243", # table
"03790512", # motorbike
"03948459", # pistol
"03642806", # laptop
"03467517", # guitar
"03261776", # earphone
"03001627", # chair
"02958343", # car
"04090263", # rifle
"03759954", # microphone
]
with torch.no_grad():
for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):
# import pdb; pdb.set_trace()
if taxonomy_ids[0] not in useful_cate:
continue
if taxonomy_ids[0] == "02691156":
a, b = 90, 135
elif taxonomy_ids[0] == "04379243":
a, b = 30, 30
elif taxonomy_ids[0] == "03642806":
a, b = 30, -45
elif taxonomy_ids[0] == "03467517":
a, b = 0, 90
elif taxonomy_ids[0] == "03261776":
a, b = 0, 75
elif taxonomy_ids[0] == "03001627":
a, b = 30, -45
else:
a, b = 0, 0
dataset_name = config.dataset.test._base_.NAME
if dataset_name == 'ShapeNet':
points = data.cuda()
else:
                raise NotImplementedError(
                    f'Train phase does not support {dataset_name}')
# dense_points, vis_points = base_model(points, vis=True)
dense_points, vis_points, centers = base_model(points, vis=True)
final_image = []
data_path = f'./vis/{taxonomy_ids[0]}_{idx}'
if not os.path.exists(data_path):
os.makedirs(data_path)
points = points.squeeze().detach().cpu().numpy()
np.savetxt(os.path.join(data_path, 'gt.txt'),
points, delimiter=';')
points = misc.get_ptcloud_img(points, a, b)
final_image.append(points[150:650, 150:675, :])
# centers = centers.squeeze().detach().cpu().numpy()
# np.savetxt(os.path.join(data_path,'center.txt'), centers, delimiter=';')
# centers = misc.get_ptcloud_img(centers)
# final_image.append(centers)
vis_points = vis_points.squeeze().detach().cpu().numpy()
np.savetxt(os.path.join(data_path, 'vis.txt'),
vis_points, delimiter=';')
vis_points = misc.get_ptcloud_img(vis_points, a, b)
final_image.append(vis_points[150:650, 150:675, :])
dense_points = dense_points.squeeze().detach().cpu().numpy()
np.savetxt(os.path.join(data_path, 'dense_points.txt'),
dense_points, delimiter=';')
dense_points = misc.get_ptcloud_img(dense_points, a, b)
final_image.append(dense_points[150:650, 150:675, :])
img = np.concatenate(final_image, axis=1)
            img_path = os.path.join(data_path, 'plot.jpg')
cv2.imwrite(img_path, img)
if idx > 1500:
break
return
================================================
FILE: tools/runner_finetune.py
================================================
import torch
import torch.nn as nn
from tools import builder
from utils import misc, dist_utils
import time
from utils.logger import *
from utils.AverageMeter import AverageMeter
import numpy as np
from datasets import data_transforms
from pointnet2_ops import pointnet2_utils
from torchvision import transforms
train_transforms = transforms.Compose(
[
# data_transforms.PointcloudScale(),
# data_transforms.PointcloudRotate(),
# data_transforms.PointcloudTranslate(),
# data_transforms.PointcloudJitter(),
# data_transforms.PointcloudRandomInputDropout(),
# data_transforms.RandomHorizontalFlip(),
data_transforms.PointcloudScaleAndTranslate(),
]
)
test_transforms = transforms.Compose(
[
# data_transforms.PointcloudScale(),
# data_transforms.PointcloudRotate(),
# data_transforms.PointcloudTranslate(),
data_transforms.PointcloudScaleAndTranslate(),
]
)
class Acc_Metric:
def __init__(self, acc=0.):
if type(acc).__name__ == 'dict':
self.acc = acc['acc']
elif type(acc).__name__ == 'Acc_Metric':
self.acc = acc.acc
else:
self.acc = acc
def better_than(self, other):
if self.acc > other.acc:
return True
else:
return False
def state_dict(self):
_dict = dict()
_dict['acc'] = self.acc
return _dict
def run_net(args, config, train_writer=None, val_writer=None):
logger = get_logger(args.log_name)
# build dataset
(train_sampler, train_dataloader), (_, test_dataloader), = builder.dataset_builder(args, config.dataset.train), \
builder.dataset_builder(args, config.dataset.val)
# build model
base_model = builder.model_builder(config.model)
# parameter setting
start_epoch = 0
best_metrics = Acc_Metric(0.)
best_metrics_vote = Acc_Metric(0.)
metrics = Acc_Metric(0.)
# resume ckpts
if args.resume:
start_epoch, best_metric = builder.resume_model(
base_model, args, logger=logger)
        best_metrics = Acc_Metric(best_metric)
else:
if args.ckpts is not None:
base_model.load_model_from_ckpt(args.ckpts)
else:
print_log('Training from scratch', logger=logger)
if args.use_gpu:
base_model.to(args.local_rank)
# DDP
if args.distributed:
# Sync BN
if args.sync_bn:
base_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
base_model)
print_log('Using Synchronized BatchNorm ...', logger=logger)
base_model = nn.parallel.DistributedDataParallel(
base_model, device_ids=[args.local_rank % torch.cuda.device_count()])
print_log('Using Distributed Data parallel ...', logger=logger)
else:
print_log('Using Data parallel ...', logger=logger)
base_model = nn.DataParallel(base_model).cuda()
# optimizer & scheduler
optimizer, scheduler = builder.build_opti_sche(base_model, config)
if args.resume:
builder.resume_optimizer(optimizer, args, logger=logger)
# trainval
# training
base_model.zero_grad()
for epoch in range(start_epoch, config.max_epoch + 1):
if args.distributed:
train_sampler.set_epoch(epoch)
base_model.train()
epoch_start_time = time.time()
batch_start_time = time.time()
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter(['loss', 'loss_r', 'acc'])
num_iter = 0
base_model.train() # set model to training mode
n_batches = len(train_dataloader)
npoints = config.npoints
for idx, (taxonomy_ids, model_ids, data) in enumerate(train_dataloader):
num_iter += 1
n_itr = epoch * n_batches + idx
data_time.update(time.time() - batch_start_time)
points = data[0].cuda()
label = data[1].cuda()
if npoints == 1024:
point_all = 1200
elif npoints == 2048:
point_all = 2400
elif npoints == 4096:
point_all = 4800
elif npoints == 8192:
point_all = 8192
else:
raise NotImplementedError()
if points.size(1) < point_all:
point_all = points.size(1)
fps_idx = pointnet2_utils.furthest_point_sample(
points, point_all) # (B, npoint)
fps_idx = fps_idx[:, np.random.choice(point_all, npoints, False)]
points = pointnet2_utils.gather_operation(points.transpose(
1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous() # (B, N, 3)
# import pdb; pdb.set_trace()
points = train_transforms(points)
ret, loss1 = base_model(points)
loss, acc = base_model.module.get_loss_acc(ret, label)
_loss = loss + 3 * loss1
            try:
                _loss.backward()
            except RuntimeError:
                # nn.DataParallel returns a per-GPU loss vector; reduce to a scalar first
                _loss = _loss.mean()
                _loss.backward()
# forward
if num_iter == config.step_per_update:
if config.get('grad_norm_clip') is not None:
torch.nn.utils.clip_grad_norm_(
base_model.parameters(), config.grad_norm_clip, norm_type=2)
num_iter = 0
optimizer.step()
base_model.zero_grad()
if args.distributed:
loss = dist_utils.reduce_tensor(loss, args)
acc = dist_utils.reduce_tensor(acc, args)
losses.update([loss.item(), loss1.item(), acc.item()])
            else:
                try:
                    losses.update([loss.item(), loss1.item(), acc.item()])
                except (RuntimeError, ValueError):
                    # multi-GPU: each tensor may hold one value per device
                    losses.update([loss.mean().item(), loss1.mean().item(), acc.mean().item()])
if args.distributed:
torch.cuda.synchronize()
if train_writer is not None:
train_writer.add_scalar('Loss/Batch/Loss', loss.item(), n_itr)
train_writer.add_scalar(
'Loss/Batch/TrainAcc', acc.item(), n_itr)
train_writer.add_scalar(
'Loss/Batch/LR', optimizer.param_groups[0]['lr'], n_itr)
batch_time.update(time.time() - batch_start_time)
batch_start_time = time.time()
# if idx % 10 == 0:
# print_log('[Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Loss+Acc = %s lr = %.6f' %
# (epoch, config.max_epoch, idx + 1, n_batches, batch_time.val(), data_time.val(),
# ['%.4f' % l for l in losses.val()], optimizer.param_groups[0]['lr']), logger = logger)
if isinstance(scheduler, list):
for item in scheduler:
item.step(epoch)
else:
scheduler.step(epoch)
epoch_end_time = time.time()
if train_writer is not None:
train_writer.add_scalar('Loss/Epoch/Loss', losses.avg(0), epoch)
print_log('[Training] EPOCH: %d EpochTime = %.3f (s) Losses = %s lr = %.6f' %
(epoch, epoch_end_time - epoch_start_time, ['%.4f' % l for l in losses.avg()], optimizer.param_groups[0]['lr']), logger=logger)
if epoch % args.val_freq == 0 and epoch != 0:
# Validate the current model
metrics = validate(base_model, test_dataloader,
epoch, val_writer, args, config, logger=logger)
better = metrics.better_than(best_metrics)
            # Save checkpoints
if better:
best_metrics = metrics
builder.save_checkpoint(
base_model, optimizer, epoch, metrics, best_metrics, 'ckpt-best', args, logger=logger)
print_log(
"--------------------------------------------------------------------------------------------", logger=logger)
if args.vote:
if metrics.acc > 92.1 or (better and metrics.acc > 91):
metrics_vote = validate_vote(
base_model, test_dataloader, epoch, val_writer, args, config, logger=logger)
if metrics_vote.better_than(best_metrics_vote):
best_metrics_vote = metrics_vote
print_log(
"****************************************************************************************",
logger=logger)
builder.save_checkpoint(
base_model, optimizer, epoch, metrics, best_metrics_vote, 'ckpt-best_vote', args, logger=logger)
builder.save_checkpoint(base_model, optimizer, epoch,
metrics, best_metrics, 'ckpt-last', args, logger=logger)
# if (config.max_epoch - epoch) < 10:
# builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, f'ckpt-epoch-{epoch:03d}', args, logger = logger)
if train_writer is not None:
train_writer.close()
if val_writer is not None:
val_writer.close()
def validate(base_model, test_dataloader, epoch, val_writer, args, config, logger=None):
# print_log(f"[VALIDATION] Start validating epoch {epoch}", logger = logger)
base_model.eval() # set model to eval mode
test_pred = []
test_label = []
npoints = config.npoints
with torch.no_grad():
for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):
points = data[0].cuda()
label = data[1].cuda()
points = misc.fps(points, npoints)
logits, loss1 = base_model(points)
target = label.view(-1)
pred = logits.argmax(-1).view(-1)
test_pred.append(pred.detach())
test_label.append(target.detach())
test_pred = torch.cat(test_pred, dim=0)
test_label = torch.cat(test_label, dim=0)
if args.distributed:
test_pred = dist_utils.gather_tensor(test_pred, args)
test_label = dist_utils.gather_tensor(test_label, args)
acc = (test_pred == test_label).sum() / \
float(test_label.size(0)) * 100.
print_log('[Validation] EPOCH: %d acc = %.4f' %
(epoch, acc), logger=logger)
if args.distributed:
torch.cuda.synchronize()
# Add testing results to TensorBoard
if val_writer is not None:
val_writer.add_scalar('Metric/ACC', acc, epoch)
return Acc_Metric(acc)
def validate_vote(base_model, test_dataloader, epoch, val_writer, args, config, logger=None, times=10):
print_log(f"[VALIDATION_VOTE] epoch {epoch}", logger=logger)
base_model.eval() # set model to eval mode
test_pred = []
test_label = []
npoints = config.npoints
with torch.no_grad():
for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):
points_raw = data[0].cuda()
label = data[1].cuda()
if npoints == 1024:
point_all = 1200
elif npoints == 4096:
point_all = 4800
elif npoints == 8192:
point_all = 8192
else:
raise NotImplementedError()
if points_raw.size(1) < point_all:
point_all = points_raw.size(1)
fps_idx_raw = pointnet2_utils.furthest_point_sample(
points_raw, point_all) # (B, npoint)
local_pred = []
for kk in range(times):
fps_idx = fps_idx_raw[:, np.random.choice(
point_all, npoints, False)]
points = pointnet2_utils.gather_operation(points_raw.transpose(1, 2).contiguous(),
fps_idx).transpose(1, 2).contiguous() # (B, N, 3)
points = test_transforms(points)
logits, loss1 = base_model(points)
target = label.view(-1)
local_pred.append(logits.detach().unsqueeze(0))
pred = torch.cat(local_pred, dim=0).mean(0)
_, pred_choice = torch.max(pred, -1)
test_pred.append(pred_choice)
test_label.append(target.detach())
test_pred = torch.cat(test_pred, dim=0)
test_label = torch.cat(test_label, dim=0)
if args.distributed:
test_pred = dist_utils.gather_tensor(test_pred, args)
test_label = dist_utils.gather_tensor(test_label, args)
acc = (test_pred == test_label).sum() / \
float(test_label.size(0)) * 100.
print_log('[Validation_vote] EPOCH: %d acc_vote = %.4f' %
(epoch, acc), logger=logger)
if args.distributed:
torch.cuda.synchronize()
# Add testing results to TensorBoard
if val_writer is not None:
val_writer.add_scalar('Metric/ACC_vote', acc, epoch)
return Acc_Metric(acc)
def test_net(args, config):
logger = get_logger(args.log_name)
print_log('Tester start ... ', logger=logger)
_, test_dataloader = builder.dataset_builder(args, config.dataset.test)
base_model = builder.model_builder(config.model)
# load checkpoints
# for finetuned transformer
builder.load_model(base_model, args.ckpts, logger=logger)
# base_model.load_model_from_ckpt(args.ckpts) # for BERT
if args.use_gpu:
base_model.to(args.local_rank)
# DDP
if args.distributed:
raise NotImplementedError()
test(base_model, test_dataloader, args, config, logger=logger)
def test(base_model, test_dataloader, args, config, logger=None):
base_model.eval() # set model to eval mode
test_pred = []
test_label = []
npoints = config.npoints
with torch.no_grad():
for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):
points = data[0].cuda()
label = data[1].cuda()
points = misc.fps(points, npoints)
logits, loss1 = base_model(points)
target = label.view(-1)
pred = logits.argmax(-1).view(-1)
test_pred.append(pred.detach())
test_label.append(target.detach())
test_pred = torch.cat(test_pred, dim=0)
test_label = torch.cat(test_label, dim=0)
if args.distributed:
test_pred = dist_utils.gather_tensor(test_pred, args)
test_label = dist_utils.gather_tensor(test_label, args)
acc = (test_pred == test_label).sum() / \
float(test_label.size(0)) * 100.
print_log('[TEST] acc = %.4f' % acc, logger=logger)
if args.distributed:
torch.cuda.synchronize()
    print_log("[TEST_VOTE]", logger=logger)
acc = 0.
    for vote_round in range(1, 300):  # avoid shadowing the imported `time` module
        this_acc = test_vote(base_model, test_dataloader,
                             1, None, args, config, logger=logger, times=5)
        if acc < this_acc:
            acc = this_acc
        print_log('[TEST_VOTE_time %d] acc = %.4f, best acc = %.4f' %
                  (vote_round, this_acc, acc), logger=logger)
print_log('[TEST_VOTE] acc = %.4f' % acc, logger=logger)
def test_vote(base_model, test_dataloader, epoch, val_writer, args, config, logger=None, times=10):
base_model.eval() # set model to eval mode
test_pred = []
test_label = []
npoints = config.npoints
with torch.no_grad():
for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):
points_raw = data[0].cuda()
label = data[1].cuda()
if npoints == 1024:
point_all = 1024
elif npoints == 2048:
point_all = 2048
elif npoints == 4096:
point_all = 4096
elif npoints == 8192:
point_all = 8192
else:
raise NotImplementedError()
if points_raw.size(1) < point_all:
point_all = points_raw.size(1)
fps_idx_raw = pointnet2_utils.furthest_point_sample(
points_raw, point_all) # (B, npoint)
local_pred = []
for kk in range(times):
fps_idx = fps_idx_raw[:, np.random.choice(
point_all, npoints, False)]
points = pointnet2_utils.gather_operation(points_raw.transpose(1, 2).contiguous(),
fps_idx).transpose(1, 2).contiguous() # (B, N, 3)
points = test_transforms(points)
logits, loss1 = base_model(points)
target = label.view(-1)
local_pred.append(logits.detach().unsqueeze(0))
# softmax = torch.softmax
pred = torch.cat(local_pred, dim=0).mean(0)
# print('pred', pred.shape)
# pred = softmax(1000*pred, dim=-1).mean(0)
_, pred_choice = torch.max(pred, -1)
test_pred.append(pred_choice)
test_label.append(target.detach())
test_pred = torch.cat(test_pred, dim=0)
test_label = torch.cat(test_label, dim=0)
if args.distributed:
test_pred = dist_utils.gather_tensor(test_pred, args)
test_label = dist_utils.gather_tensor(test_label, args)
acc = (test_pred == test_label).sum() / \
float(test_label.size(0)) * 100.
if args.distributed:
torch.cuda.synchronize()
# Add testing results to TensorBoard
if val_writer is not None:
val_writer.add_scalar('Metric/ACC_vote', acc, epoch)
# print_log('[TEST] acc = %.4f' % acc, logger=logger)
return acc
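# The voting scheme in validate_vote/test_vote averages logits over several
# resampled, re-augmented passes and only then takes the argmax. A toy numpy
# sketch of that reduction (the logit values are made up for illustration):

```python
import numpy as np

# Hypothetical logits for 3 voting passes over a batch of 2 shapes, 4 classes;
# shape (times, B, num_classes).
passes = np.array([
    [[2.0, 1.0, 0.0, 0.0], [0.0, 1.0, 2.0, 0.0]],
    [[1.5, 2.0, 0.0, 0.0], [0.0, 0.5, 2.5, 0.0]],
    [[2.5, 0.5, 0.0, 0.0], [0.0, 2.0, 1.5, 0.0]],
])

# Mirrors: pred = torch.cat(local_pred, dim=0).mean(0); _, pred_choice = torch.max(pred, -1)
mean_logits = passes.mean(axis=0)
pred_choice = mean_logits.argmax(axis=-1)
assert pred_choice.tolist() == [0, 2]
```

Averaging logits before the argmax lets passes with confident, consistent predictions outweigh single noisy passes, which is why the second shape resolves to class 2 despite one pass preferring class 1.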
================================================
FILE: tools/runner_pretrain.py
================================================
import torch
import torch.nn as nn
import os
import json
from tools import builder
from utils import misc, dist_utils
import time
from utils.logger import *
from utils.AverageMeter import AverageMeter
from sklearn.svm import LinearSVC
import numpy as np
from torchvision import transforms
from datasets import data_transforms
from pointnet2_ops import pointnet2_utils
from torchstat import stat
train_transforms = transforms.Compose(
[
# data_transforms.PointcloudScale(),
# data_transforms.PointcloudRotate(),
# data_transforms.PointcloudRotatePerturbation(),
# data_transforms.PointcloudTranslate(),
# data_transforms.PointcloudJitter(),
# data_transforms.PointcloudRandomInputDropout(),
data_transforms.PointcloudScaleAndTranslate(),
]
)
class Acc_Metric:
def __init__(self, acc=0.):
if type(acc).__name__ == 'dict':
self.acc = acc['acc']
else:
self.acc = acc
def better_than(self, other):
if self.acc > other.acc:
return True
else:
return False
def state_dict(self):
_dict = dict()
_dict['acc'] = self.acc
return _dict
def evaluate_svm(train_features, train_labels, test_features, test_labels):
clf = LinearSVC()
clf.fit(train_features, train_labels)
pred = clf.predict(test_features)
return np.sum(test_labels == pred) * 1. / pred.shape[0]
def run_net(args, config, train_writer=None, val_writer=None):
logger = get_logger(args.log_name)
# build dataset
(train_sampler, train_dataloader), (_, test_dataloader), = builder.dataset_builder(args, config.dataset.train), \
builder.dataset_builder(args, config.dataset.val)
(_, extra_train_dataloader) = builder.dataset_builder(
args, config.dataset.extra_train) if config.dataset.get('extra_train') else (None, None)
# build model
base_model = builder.model_builder(config.model)
if args.use_gpu:
base_model.to(args.local_rank)
# from IPython import embed; embed()
# parameter setting
start_epoch = 0
best_metrics = Acc_Metric(0.)
metrics = Acc_Metric(0.)
# resume ckpts
if args.resume:
start_epoch, best_metric = builder.resume_model(
base_model, args, logger=logger)
best_metrics = Acc_Metric(best_metric)
elif args.start_ckpts is not None:
builder.load_model(base_model, args.start_ckpts, logger=logger)
# DDP
if args.distributed:
# Sync BN
if args.sync_bn:
base_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
base_model)
print_log('Using Synchronized BatchNorm ...', logger=logger)
base_model = nn.parallel.DistributedDataParallel(base_model, device_ids=[
args.local_rank % torch.cuda.device_count()], find_unused_parameters=True)
print_log('Using Distributed Data parallel ...', logger=logger)
else:
print_log('Using Data parallel ...', logger=logger)
base_model = nn.DataParallel(base_model).cuda()
# optimizer & scheduler
optimizer, scheduler = builder.build_opti_sche(base_model, config)
if args.resume:
builder.resume_optimizer(optimizer, args, logger=logger)
# trainval
# training
base_model.zero_grad()
for epoch in range(start_epoch, config.max_epoch + 1):
if args.distributed:
train_sampler.set_epoch(epoch)
base_model.train()
epoch_start_time = time.time()
batch_start_time = time.time()
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter(['Loss'])
num_iter = 0
base_model.train() # set model to training mode
n_batches = len(train_dataloader)
for idx, (taxonomy_ids, model_ids, data) in enumerate(train_dataloader):
num_iter += 1
n_itr = epoch * n_batches + idx
data_time.update(time.time() - batch_start_time)
npoints = config.dataset.train.others.npoints
dataset_name = config.dataset.train._base_.NAME
if dataset_name == 'ShapeNet' or dataset_name == 'UnlabeledHybrid':
points = data.cuda()
elif dataset_name == 'ModelNet':
points = data[0].cuda()
points = misc.fps(points, npoints)
else:
                raise NotImplementedError(
                    f'Train phase does not support {dataset_name}')
assert points.size(1) == npoints
points = train_transforms(points)
loss = base_model(points)
            try:
                loss.backward()  # single GPU: loss is already a scalar
            except RuntimeError:
                # nn.DataParallel returns a per-GPU loss vector; reduce to a scalar first
                loss = loss.mean()
                loss.backward()
# forward
if num_iter == config.step_per_update:
num_iter = 0
optimizer.step()
base_model.zero_grad()
if args.distributed:
loss = dist_utils.reduce_tensor(loss, args)
losses.update([loss.item()*1000])
else:
losses.update([loss.item()*1000])
if args.distributed:
torch.cuda.synchronize()
if train_writer is not None:
train_writer.add_scalar('Loss/Batch/Loss', loss.item(), n_itr)
train_writer.add_scalar(
'Loss/Batch/LR', optimizer.param_groups[0]['lr'], n_itr)
batch_time.update(time.time() - batch_start_time)
batch_start_time = time.time()
if idx % 20 == 0:
print_log('[Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Losses = %s lr = %.6f' %
(epoch, config.max_epoch, idx + 1, n_batches, batch_time.val(), data_time.val(),
['%.4f' % l for l in losses.val()], optimizer.param_groups[0]['lr']), logger=logger)
if isinstance(scheduler, list):
for item in scheduler:
item.step(epoch)
else:
scheduler.step(epoch)
epoch_end_time = time.time()
if train_writer is not None:
train_writer.add_scalar('Loss/Epoch/Loss_1', losses.avg(0), epoch)
print_log('[Training] EPOCH: %d EpochTime = %.3f (s) Losses = %s lr = %.6f' %
(epoch, epoch_end_time - epoch_start_time, ['%.4f' % l for l in losses.avg()],
optimizer.param_groups[0]['lr']), logger=logger)
# if epoch % args.val_freq == 0 and epoch != 0:
# # Validate the current model
# metrics = validate(base_model, extra_train_dataloader, test_dataloader, epoch, val_writer, args, config, logger=logger)
#
# # Save ckeckpoints
# if metrics.better_than(best_metrics):
# best_metrics = metrics
# builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, 'ckpt-best', args, logger = logger)
builder.save_checkpoint(base_model, optimizer, epoch,
metrics, best_metrics, 'ckpt-last', args, logger=logger)
if epoch % 25 == 0 and epoch >= 250:
builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, f'ckpt-epoch-{epoch:03d}', args,
logger=logger)
# if (config.max_epoch - epoch) < 10:
# builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, f'ckpt-epoch-{epoch:03d}', args, logger = logger)
if train_writer is not None:
train_writer.close()
if val_writer is not None:
val_writer.close()
def validate(base_model, extra_train_dataloader, test_dataloader, epoch, val_writer, args, config, logger=None):
print_log(f"[VALIDATION] Start validating epoch {epoch}", logger=logger)
base_model.eval() # set model to eval mode
test_features = []
test_label = []
train_features = []
train_label = []
npoints = config.dataset.train.others.npoints
with torch.no_grad():
for idx, (taxonomy_ids, model_ids, data) in enumerate(extra_train_dataloader):
points = data[0].cuda()
label = data[1].cuda()
points = misc.fps(points, npoints)
assert points.size(1) == npoints
feature = base_model(points, noaug=True)
target = label.view(-1)
train_features.append(feature.detach())
train_label.append(target.detach())
for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):
points = data[0].cuda()
label = data[1].cuda()
points = misc.fps(points, npoints)
assert points.size(1) == npoints
feature = base_model(points, noaug=True)
target = label.view(-1)
test_features.append(feature.detach())
test_label.append(target.detach())
train_features = torch.cat(train_features, dim=0)
train_label = torch.cat(train_label, dim=0)
test_features = torch.cat(test_features, dim=0)
test_label = torch.cat(test_label, dim=0)
if args.distributed:
train_features = dist_utils.gather_tensor(train_features, args)
train_label = dist_utils.gather_tensor(train_label, args)
test_features = dist_utils.gather_tensor(test_features, args)
test_label = dist_utils.gather_tensor(test_label, args)
svm_acc = evaluate_svm(train_features.data.cpu().numpy(), train_label.data.cpu(
).numpy(), test_features.data.cpu().numpy(), test_label.data.cpu().numpy())
print_log('[Validation] EPOCH: %d acc = %.4f' %
(epoch, svm_acc), logger=logger)
if args.distributed:
torch.cuda.synchronize()
# Add testing results to TensorBoard
if val_writer is not None:
val_writer.add_scalar('Metric/ACC', svm_acc, epoch)
return Acc_Metric(svm_acc)
def test_net():
pass
================================================
FILE: utils/AverageMeter.py
================================================
class AverageMeter(object):
def __init__(self, items=None):
self.items = items
self.n_items = 1 if items is None else len(items)
self.reset()
def reset(self):
self._val = [0] * self.n_items
self._sum = [0] * self.n_items
self._count = [0] * self.n_items
def update(self, values):
if isinstance(values, list):
for idx, v in enumerate(values):
self._val[idx] = v
self._sum[idx] += v
self._count[idx] += 1
else:
self._val[0] = values
self._sum[0] += values
self._count[0] += 1
def val(self, idx=None):
if idx is None:
return self._val[0] if self.items is None else [self._val[i] for i in range(self.n_items)]
else:
return self._val[idx]
def count(self, idx=None):
if idx is None:
return self._count[0] if self.items is None else [self._count[i] for i in range(self.n_items)]
else:
return self._count[idx]
def avg(self, idx=None):
if idx is None:
return self._sum[0] / self._count[0] if self.items is None else [
self._sum[i] / self._count[i] for i in range(self.n_items)
]
else:
return self._sum[idx] / self._count[idx]
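A standalone usage sketch of the running-average pattern above; `Meter` is a compact hypothetical stand-in for `AverageMeter`, tracking two items:

```python
class Meter:
    """Compact hypothetical stand-in for AverageMeter above."""

    def __init__(self, n):
        self.sum = [0.0] * n  # running sums per tracked item
        self.cnt = [0] * n    # update counts per tracked item

    def update(self, values):
        for i, v in enumerate(values):
            self.sum[i] += v
            self.cnt[i] += 1

    def avg(self):
        return [s / c for s, c in zip(self.sum, self.cnt)]


m = Meter(2)            # e.g. track ['loss', 'acc']
m.update([1.0, 3.0])
m.update([3.0, 5.0])
print(m.avg())          # [2.0, 4.0]
```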
================================================
FILE: utils/checkpoint.py
================================================
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import copy
import logging
import os
from collections import defaultdict
import torch
import torch.nn as nn
from typing import Any
from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable
from termcolor import colored
def get_missing_parameters_message(keys: List[str]) -> str:
"""
Get a logging-friendly message to report parameter names (keys) that are in
the model but not found in a checkpoint.
Args:
keys (list[str]): List of keys that were not found in the checkpoint.
Returns:
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "Some model parameters or buffers are not found in the checkpoint:\n"
msg += "\n".join(
" " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
)
return msg
def get_unexpected_parameters_message(keys: List[str]) -> str:
"""
Get a logging-friendly message to report parameter names (keys) that are in
the checkpoint but not found in the model.
Args:
keys (list[str]): List of keys that were not found in the model.
Returns:
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
msg += "\n".join(
" " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
)
return msg
def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
"""
Strip the prefix in metadata, if any.
Args:
state_dict (OrderedDict): a state-dict to be loaded to the model.
prefix (str): prefix.
"""
keys = sorted(state_dict.keys())
if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
return
for key in keys:
newkey = key[len(prefix):]
state_dict[newkey] = state_dict.pop(key)
# also strip the prefix in metadata, if any..
try:
metadata = state_dict._metadata # pyre-ignore
except AttributeError:
pass
else:
for key in list(metadata.keys()):
# for the metadata dict, the key can be:
# '': for the DDP module, which we want to remove.
# 'module': for the actual model.
# 'module.xx.xx': for the rest.
if len(key) == 0:
continue
newkey = key[len(prefix):]
metadata[newkey] = metadata.pop(key)
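The prefix stripping above can be sketched standalone (the `state_dict._metadata` handling is omitted); the typical use is removing a DataParallel `'module.'` prefix:

```python
def strip_prefix(state_dict, prefix):
    # Standalone sketch of _strip_prefix_if_present above, minus the
    # _metadata bookkeeping: strip only if every key carries the prefix.
    keys = sorted(state_dict.keys())
    if not all(len(k) == 0 or k.startswith(prefix) for k in keys):
        return  # mixed prefixes: leave the dict untouched
    for k in keys:
        state_dict[k[len(prefix):]] = state_dict.pop(k)


sd = {'module.fc.weight': 1, 'module.fc.bias': 2}
strip_prefix(sd, 'module.')
print(sd)  # keys become 'fc.weight' / 'fc.bias'
```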
def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
"""
Group keys based on common prefixes. A prefix is the string up to the final
"." in each key.
Args:
keys (list[str]): list of parameter names, i.e. keys in the model
checkpoint dict.
Returns:
dict[list]: keys with common prefixes are grouped into lists.
"""
groups = defaultdict(list)
for key in keys:
pos = key.rfind(".")
if pos >= 0:
head, tail = key[:pos], [key[pos + 1:]]
else:
head, tail = key, []
groups[head].extend(tail)
return groups
def _group_to_str(group: List[str]) -> str:
"""
Format a group of parameter name suffixes into a loggable string.
Args:
group (list[str]): list of parameter name suffixes.
Returns:
str: formated string.
"""
if len(group) == 0:
return ""
if len(group) == 1:
return "." + group[0]
return ".{" + ", ".join(group) + "}"
def _named_modules_with_dup(
model: nn.Module, prefix: str = ""
) -> Iterable[Tuple[str, nn.Module]]:
"""
The same as `model.named_modules()`, except that it includes
duplicated modules that have more than one name.
"""
yield prefix, model
for name, module in model._modules.items(): # pyre-ignore
if module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
yield from _named_modules_with_dup(module, submodule_prefix)
================================================
FILE: utils/config.py
================================================
import yaml
from easydict import EasyDict
import os
from .logger import print_log
def log_args_to_file(args, pre='args', logger=None):
for key, val in args.__dict__.items():
print_log(f'{pre}.{key} : {val}', logger=logger)
def log_config_to_file(cfg, pre='cfg', logger=None):
for key, val in cfg.items():
if isinstance(cfg[key], EasyDict):
print_log(f'{pre}.{key} = edict()', logger=logger)
log_config_to_file(cfg[key], pre=pre + '.' + key, logger=logger)
continue
print_log(f'{pre}.{key} : {val}', logger=logger)
def merge_new_config(config, new_config):
for key, val in new_config.items():
if not isinstance(val, dict):
if key == '_base_':
with open(new_config['_base_'], 'r') as f:
try:
val = yaml.load(f, Loader=yaml.FullLoader)
except AttributeError:  # PyYAML < 5.1 has no FullLoader
val = yaml.load(f)
config[key] = EasyDict()
merge_new_config(config[key], val)
else:
config[key] = val
continue
if key not in config:
config[key] = EasyDict()
merge_new_config(config[key], val)
return config
def cfg_from_yaml_file(cfg_file):
config = EasyDict()
with open(cfg_file, 'r') as f:
try:
new_config = yaml.load(f, Loader=yaml.FullLoader)
except AttributeError:  # PyYAML < 5.1 has no FullLoader
new_config = yaml.load(f)
merge_new_config(config=config, new_config=new_config)
return config
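`merge_new_config` recursively overlays a config onto a base; a standalone sketch of that semantics (minus the `_base_` YAML-loading branch), with hypothetical keys:

```python
def merge(config, new):
    # Recursive overlay mirroring merge_new_config above, without the
    # `_base_` file-loading special case: dicts recurse, scalars overwrite.
    for key, val in new.items():
        if isinstance(val, dict):
            config.setdefault(key, {})
            merge(config[key], val)
        else:
            config[key] = val
    return config


base = {'optimizer': {'type': 'AdamW', 'lr': 1e-3}}
override = {'optimizer': {'lr': 5e-4}, 'max_epoch': 300}
merged = merge(base, override)
print(merged)  # {'optimizer': {'type': 'AdamW', 'lr': 0.0005}, 'max_epoch': 300}
```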
def get_config(args, logger=None):
if args.resume:
cfg_path = os.path.join(args.experiment_path, 'config.yaml')
if not os.path.exists(cfg_path):
print_log("Failed to resume", logger=logger)
raise FileNotFoundError()
print_log(f'Resume yaml from {cfg_path}', logger=logger)
args.config = cfg_path
config = cfg_from_yaml_file(args.config)
if not args.resume and args.local_rank == 0:
save_experiment_config(args, config, logger)
return config
def save_experiment_config(args, config, logger=None):
config_path = os.path.join(args.experiment_path, 'config.yaml')
os.system('cp %s %s' % (args.config, config_path))
print_log(
f'Copy the Config file from {args.config} to {config_path}', logger=logger)
================================================
FILE: utils/dist_utils.py
================================================
import os
import torch
import torch.multiprocessing as mp
from torch import distributed as dist
def init_dist(launcher, backend='nccl', **kwargs):
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
if launcher == 'pytorch':
_init_dist_pytorch(backend, **kwargs)
else:
raise ValueError(f'Invalid launcher type: {launcher}')
def _init_dist_pytorch(backend, **kwargs):
# TODO: use local_rank instead of rank % num_gpus
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
print(f'init distributed in rank {torch.distributed.get_rank()}')
def get_dist_info():
if dist.is_available():
initialized = dist.is_initialized()
else:
initialized = False
if initialized:
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1
return rank, world_size
def reduce_tensor(tensor, args):
'''
All-reduce a tensor across GPUs and return the element-wise mean
(e.g. for accuracy-style metrics).
'''
rt = tensor.clone()
torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
rt /= args.world_size
return rt
def gather_tensor(tensor, args):
output_tensors = [tensor.clone() for _ in range(args.world_size)]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
return concat
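What `reduce_tensor` and `gather_tensor` compute can be simulated in a single process, with plain lists standing in for per-rank tensors (a sketch of the semantics, not the collective calls themselves):

```python
per_rank = [[1.0, 2.0], [3.0, 4.0]]   # pretend values held by 2 ranks
world_size = len(per_rank)

# reduce_tensor: all_reduce(SUM) then divide -> element-wise mean on every rank
reduced = [sum(vals) / world_size for vals in zip(*per_rank)]

# gather_tensor: all_gather then cat along dim 0 -> concatenation in rank order
gathered = [x for rank_vals in per_rank for x in rank_vals]

print(reduced, gathered)  # [2.0, 3.0] [1.0, 2.0, 3.0, 4.0]
```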
================================================
FILE: utils/logger.py
================================================
import logging
import torch.distributed as dist
logger_initialized = {}
def get_root_logger(log_file=None, log_level=logging.INFO, name='main'):
"""Get root logger and add a keyword filter to it.
The logger will be initialized if it has not been initialized. By default a
StreamHandler will be added. If `log_file` is specified, a FileHandler will
also be added. The name of the root logger is the top-level package name,
e.g., "mmdet3d".
Args:
log_file (str, optional): File path of log. Defaults to None.
log_level (int, optional): The level of logger.
Defaults to logging.INFO.
name (str, optional): The name of the root logger, also used as a
filter keyword. Defaults to 'mmdet3d'.
Returns:
:obj:`logging.Logger`: The obtained logger
"""
logger = get_logger(name=name, log_file=log_file, log_level=log_level)
# add a logging filter (LogRecord has no find(); match on record.name)
logging_filter = logging.Filter(name)
logging_filter.filter = lambda record: record.name.find(name) != -1
logger.addFilter(logging_filter)
return logger
def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified and the process rank is 0, a FileHandler
will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
file_mode (str): The file mode used in opening log file.
Defaults to 'w'.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
# handle hierarchical names
# e.g., logger "a" is initialized, then logger "a.b" will skip the
# initialization since it is a child of "a".
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
# handle duplicate logs to the console
# Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
# to the root logger. As logger.propagate is True by default, this root
# level handler causes logging messages from rank>0 processes to
# unexpectedly show up on the console, creating much unwanted clutter.
# To fix this issue, we set the root logger's StreamHandler, if any, to log
# at the ERROR level.
for handler in logger.root.handlers:
if type(handler) is logging.StreamHandler:
handler.setLevel(logging.ERROR)
stream_handler = logging.StreamHandler()
handlers = [stream_handler]
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0
# only rank 0 will add a FileHandler
if rank == 0 and log_file is not None:
# The official FileHandler defaults to mode 'a' (append); file_mode
# exposes this so the default here can be 'w' (overwrite) instead.
file_handler = logging.FileHandler(log_file, file_mode)
handlers.append(file_handler)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
logger.addHandler(handler)
if rank == 0:
logger.setLevel(log_level)
else:
logger.setLevel(logging.ERROR)
logger_initialized[name] = True
return logger
def print_log(msg, logger=None, level=logging.INFO):
"""Print a log message.
Args:
msg (str): The message to be logged.
logger (logging.Logger | str | None): The logger to be used.
Some special loggers are:
- "silent": no message will be printed.
- other str: the logger obtained with `get_root_logger(logger)`.
- None: The `print()` method will be used to print log messages.
level (int): Logging level. Only available when `logger` is a Logger
object or "root".
"""
if logger is None:
print(msg)
elif isinstance(logger, logging.Logger):
logger.log(level, msg)
elif logger == 'silent':
pass
elif isinstance(logger, str):
_logger = get_logger(logger)
_logger.log(level, msg)
else:
raise TypeError(
'logger should be either a logging.Logger object, str, '
f'"silent" or None, but got {type(logger)}')
================================================
FILE: utils/misc.py
================================================
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from collections import abc
from pointnet2_ops import pointnet2_utils
def fps(data, number):
'''
data B N 3
number int
'''
fps_idx = pointnet2_utils.furthest_point_sample(data, number)
fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
return fps_data
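`fps` above calls the CUDA kernel from `pointnet2_ops`; the algorithm itself (furthest point sampling) can be sketched in pure Python for a single `(N, 3)` cloud:

```python
def fps_sketch(points, number):
    # Greedy furthest point sampling: start at index 0, then repeatedly
    # pick the point furthest from the already-selected set.
    dist = [float('inf')] * len(points)
    farthest, idx = 0, []
    for _ in range(number):
        idx.append(farthest)
        fx, fy, fz = points[farthest]
        for i, (x, y, z) in enumerate(points):
            d = (x - fx) ** 2 + (y - fy) ** 2 + (z - fz) ** 2
            if d < dist[i]:
                dist[i] = d  # distance to the nearest selected point
        farthest = max(range(len(points)), key=dist.__getitem__)
    return [points[i] for i in idx]


cloud = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (10.0, 10.0, 10.0)]
print(fps_sketch(cloud, 2))  # [(0.0, 0.0, 0.0), (10.0, 10.0, 10.0)]
```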
def worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)
def build_lambda_sche(opti, config):
if config.get('decay_step') is not None:
lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay)
scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd)
else:
raise NotImplementedError()
return scheduler
def build_lambda_bnsche(model, config):
if config.get('decay_step') is not None:
bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay)
bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd)
else:
raise NotImplementedError()
return bnm_scheduler
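Both builders above wrap an exponential decay clamped at a floor; with hypothetical config values (`lr_decay`, `decay_step`, etc. are assumptions, not the repo's defaults) the lambdas behave like this:

```python
# Hypothetical config values for illustration only.
lr_decay, decay_step, lowest_decay = 0.9, 10, 0.05
bn_momentum, bn_decay, lowest_bn = 0.9, 0.5, 0.01

# Same shapes as the lambdas passed to LambdaLR / BNMomentumScheduler above.
lr_lbmd = lambda e: max(lr_decay ** (e / decay_step), lowest_decay)
bnm_lmbd = lambda e: max(bn_momentum * bn_decay ** (e / decay_step), lowest_bn)

print(lr_lbmd(0), bnm_lmbd(10), bnm_lmbd(1000))  # 1.0 0.45 0.01
```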
def set_random_seed(seed, deterministic=False):
"""Set random seed.
Args:
seed (int): Seed to be used.
deterministic (bool): Whether to set the deterministic option for
CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
to True and `torch.backends.cudnn.benchmark` to False.
Default: False.
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
if cuda_deterministic: # slower, more reproducible
cudnn.deterministic = True
cudnn.benchmark = False
else: # faster, less reproducible
cudnn.deterministic = False
cudnn.benchmark = True
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def is_seq_of(seq, expected_type, seq_type=None):
"""Check whether it is a sequence of some type.
Args:
seq (Sequence): The sequence to be checked.
expected_type (type): Expected type of sequence items.
seq_type (type, optional): Expected sequence type.
Returns:
bool: Whether the sequence is valid.
"""
if seq_type is None:
exp_seq_type = abc.Sequence
else:
assert isinstance(seq_type, type)
exp_seq_type = seq_type
if not isinstance(seq, exp_seq_type):
return False
for item in seq:
if not isinstance(item, expected_type):
return False
return True
def set_bn_momentum_default(bn_momentum):
def fn(m):
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
m.momentum = bn_momentum
return fn
class BNMomentumScheduler(object):
def __init__(
self, model, bn_lambda, last_epoch=-1,
setter=set_bn_momentum_default
):
if not isinstance(model, nn.Module):
raise RuntimeError(
"Class '{}' is not a PyTorch nn Module".format(
type(model).__name__
)
)
self.model = model
self.setter = setter
self.lmbd = bn_lambda
self.step(last_epoch + 1)
self.last_epoch = last_epoch
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.last_epoch = epoch
self.model.apply(self.setter(self.lmbd(epoch)))
def get_momentum(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
return self.lmbd(epoch)
def seprate_point_cloud(xyz, num_points, crop, fixed_points = None, padding_zeros = False):
'''
Separate a point cloud: generate an incomplete point cloud (and the
cropped-off part) with a specified number of points.
'''
_,n,c = xyz.shape
assert n == num_points
assert c == 3
if crop == num_points:
return xyz, None
INPUT = []
CROP = []
for points in xyz:
if isinstance(crop,list):
num_crop = random.randint(crop[0],crop[1])
else:
num_crop = crop
points = points.unsqueeze(0)
if fixed_points is None:
center = F.normalize(torch.randn(1,1,3),p=2,dim=-1).cuda()
else:
if isinstance(fixed_points,list):
fixed_point = random.sample(fixed_points,1)[0]
else:
fixed_point = fixed_points
center = fixed_point.reshape(1,1,3).cuda()
distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1) # 1 1 2048
idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048
if padding_zeros:
input_data = points.clone()
input_data[0, idx[:num_crop]] = input_data[0,idx[:num_crop]] * 0
else:
input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3
crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0)
if isinstance(crop,list):
INPUT.append(fps(input_data,2048))
CROP.append(fps(crop_data,2048))
else:
INPUT.append(input_data)
CROP.append(crop_data)
input_data = torch.cat(INPUT,dim=0)# B N 3
crop_data = torch.cat(CROP,dim=0)# B M 3
return input_data.contiguous(), crop_data.contiguous()
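The core of `seprate_point_cloud` is a distance sort around a view center: the `num_crop` nearest points become the removed region and the rest form the incomplete cloud. A pure-Python sketch (the center and sizes here are hypothetical):

```python
import random

random.seed(0)
points = [(random.random(), random.random(), random.random()) for _ in range(64)]
center = (1.0, 0.0, 0.0)  # hypothetical view-direction point

sq_dist = lambda p: sum((a - b) ** 2 for a, b in zip(p, center))
order = sorted(range(len(points)), key=lambda i: sq_dist(points[i]))

num_crop = 16
crop = [points[i] for i in order[:num_crop]]   # removed region nearest the center
keep = [points[i] for i in order[num_crop:]]   # remaining incomplete cloud
```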
def get_ptcloud_img(ptcloud,roll,pitch):
fig = plt.figure(figsize=(8, 8))
x, z, y = ptcloud.transpose(1, 0)
ax = fig.add_subplot(projection=Axes3D.name)  # fig.gca(projection=...) was removed in Matplotlib 3.6
ax.set_adjustable('box')
ax.axis('off')
# ax.axis('scaled')
ax.view_init(roll,pitch)
pmax, pmin = np.max(ptcloud), np.min(ptcloud)  # avoid shadowing builtins max/min
ax.set_xbound(pmin, pmax)
ax.set_ybound(pmin, pmax)
ax.set_zbound(pmin, pmax)
ax.scatter(x, y, z, zdir='z', c=y, cmap='jet')
fig.canvas.draw()
img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # np.fromstring is deprecated
img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, ))
return img
def visualize_KITTI(path, data_list, titles = ['input','pred'], cmap=['bwr','autumn'], zdir='y',
xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1) ):
fig = plt.figure(figsize=(6*len(data_list),6))
cmax = data_list[-1][:,0].max()
for i in range(len(data_list)):
data = data_list[i][:-2048] if i == 1 else data_list[i]
color = data[:,0] /cmax
ax = fig.add_subplot(1, len(data_list) , i + 1, projection='3d')
ax.view_init(30, -120)
b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], zdir=zdir, c=color,vmin=-1,vmax=1 ,cmap = cmap[0],s=4,linewidth=0.05, edgecolors = 'black')
ax.set_title(titles[i])
ax.set_axis_off()
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_zlim(zlim)
plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0)
if not os.path.exists(path):
os.makedirs(path)
pic_path = path + '.png'
fig.savefig(pic_path)
np.save(os.path.join(path, 'input.npy'), data_list[0].numpy())
np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy())
plt.close(fig)
def random_dropping(pc, e):
up_num = max(64, 768 // (e//50 + 1))
random_num = torch.randint(1, up_num, (1,1))[0,0]
pc = fps(pc, random_num)
padding = torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device)
pc = torch.cat([pc, padding], dim = 1)
return pc
def random_scale(partial, scale_range=[0.8, 1.2]):
scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0]
return partial * scale
================================================
FILE: utils/parser.py
================================================
import os
import argparse
from pathlib import Path
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
'--config',
type = str,
help = 'yaml config file')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--num_workers', type=int, default=8)
# seed
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
# bn
parser.add_argument(
'--sync_bn',
action='store_true',
default=False,
help='whether to use sync bn')
# some args
parser.add_argument('--exp_name', type = str, default='default', help = 'experiment name')
parser.add_argument('--loss', type=str, default='cd1', help='loss name')
parser.add_argument('--start_ckpts', type = str, default=None, help = 'reload used ckpt path')
parser.add_argument('--ckpts', type = str, default=None, help = 'test used ckpt path')
parser.add_argument('--val_freq', type = int, default=1, help = 'test freq')
parser.add_argument(
'--vote',
action='store_true',
default=False,
help = 'vote acc')
parser.add_argument(
'--resume',
action='store_true',
default=False,
help = 'autoresume training (interrupted by accident)')
parser.add_argument(
'--test',
action='store_true',
default=False,
help = 'test mode for certain ckpt')
parser.add_argument(
'--finetune_model',
action='store_true',
default=False,
help = 'finetune modelnet with pretrained weight')
parser.add_argument(
'--scratch_model',
action='store_true',
default=False,
help = 'training modelnet from scratch')
parser.add_argument(
'--mode',
choices=['easy', 'median', 'hard', None],
default=None,
help = 'difficulty mode for shapenet')
parser.add_argument(
'--way', type=int, default=-1)
parser.add_argument(
'--shot', type=int, default=-1)
parser.add_argument(
'--fold', type=int, default=-1)
args = parser.parse_args()
if args.test and args.resume:
raise ValueError(
'--test and --resume cannot both be active')
if args.resume and args.start_ckpts is not None:
raise ValueError(
'--resume and --start_ckpts cannot both be active')
if args.test and args.ckpts is None:
raise ValueError(
'--ckpts must be provided in test mode')
if args.finetune_model and args.ckpts is None:
print(
'training from scratch')
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
if args.test:
args.exp_name = 'test_' + args.exp_name
if args.mode is not None:
args.exp_name = args.exp_name + '_' +args.mode
args.experiment_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem, args.exp_name)
args.tfboard_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem,'TFBoard' ,args.exp_name)
args.log_name = Path(args.config).stem
create_experiment_dir(args)
return args
def create_experiment_dir(args):
if not os.path.exists(args.experiment_path):
os.makedirs(args.experiment_path)
print('Create experiment path successfully at %s' % args.experiment_path)
if not os.path.exists(args.tfboard_path):
os.makedirs(args.tfboard_path)
print('Create TFBoard path successfully at %s' % args.tfboard_path)
================================================
FILE: utils/registry.py
================================================
import inspect
import warnings
from functools import partial
from utils import config, misc  # misc.is_seq_of is used in register_module
class Registry:
"""A registry to map strings to classes.
Registered object could be built from registry.
Example:
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
>>> resnet = MODELS.build(dict(NAME='ResNet'))
Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for
advanced usage.
Args:
name (str): Registry name.
build_func(func, optional): Build function to construct instance from
Registry, func:`build_from_cfg` is used if neither ``parent`` or
``build_func`` is specified. If ``parent`` is specified and
``build_func`` is not given, ``build_func`` will be inherited
from ``parent``. Default: None.
parent (Registry, optional): Parent registry. The class registered in
children registry could be built from parent. Default: None.
scope (str, optional): The scope of registry. It is the key to search
for children registry. If not specified, scope will be the name of
the package where class is defined, e.g. mmdet, mmcls, mmseg.
Default: None.
"""
def __init__(self, name, build_func=None, parent=None, scope=None):
self._name = name
self._module_dict = dict()
self._children = dict()
self._scope = self.infer_scope() if scope is None else scope
# self.build_func will be set with the following priority:
# 1. build_func
# 2. parent.build_func
# 3. build_from_cfg
if build_func is None:
if parent is not None:
self.build_func = parent.build_func
else:
self.build_func = build_from_cfg
else:
self.build_func = build_func
if parent is not None:
assert isinstance(parent, Registry)
parent._add_children(self)
self.parent = parent
else:
self.parent = None
def __len__(self):
return len(self._module_dict)
def __contains__(self, key):
return self.get(key) is not None
def __repr__(self):
format_str = self.__class__.__name__ + \
f'(name={self._name}, ' \
f'items={self._module_dict})'
return format_str
@staticmethod
def infer_scope():
"""Infer the scope of registry.
The name of the package where registry is defined will be returned.
Example:
# in mmdet/models/backbone/resnet.py
>>> MODELS = Registry('models')
>>> @MODELS.register_module()
>>> class ResNet:
>>> pass
The scope of ``ResNet`` will be ``mmdet``.
Returns:
scope (str): The inferred scope name.
"""
# inspect.stack() trace where this function is called, the index-2
# indicates the frame where `infer_scope()` is called
filename = inspect.getmodule(inspect.stack()[2][0]).__name__
split_filename = filename.split('.')
return split_filename[0]
@staticmethod
def split_scope_key(key):
"""Split scope and key.
The first scope will be split from key.
Examples:
>>> Registry.split_scope_key('mmdet.ResNet')
'mmdet', 'ResNet'
>>> Registry.split_scope_key('ResNet')
None, 'ResNet'
Return:
scope (str, None): The first scope.
key (str): The remaining key.
"""
split_index = key.find('.')
if split_index != -1:
return key[:split_index], key[split_index + 1:]
else:
return None, key
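`split_scope_key` only splits on the first dot; a standalone copy shows the two cases from the docstring plus nested keys:

```python
def split_scope_key(key):
    # Standalone copy of Registry.split_scope_key above.
    i = key.find('.')
    return (key[:i], key[i + 1:]) if i != -1 else (None, key)


print(split_scope_key('mmdet.ResNet'))  # ('mmdet', 'ResNet')
print(split_scope_key('ResNet'))        # (None, 'ResNet')
print(split_scope_key('a.b.c'))         # ('a', 'b.c') -- first dot only
```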
@property
def name(self):
return self._name
@property
def scope(self):
return self._scope
@property
def module_dict(self):
return self._module_dict
@property
def children(self):
return self._children
def get(self, key):
"""Get the registry record.
Args:
key (str): The class name in string format.
Returns:
class: The corresponding class.
"""
scope, real_key = self.split_scope_key(key)
if scope is None or scope == self._scope:
# get from self
if real_key in self._module_dict:
return self._module_dict[real_key]
else:
# get from self._children
if scope in self._children:
return self._children[scope].get(real_key)
else:
# goto root
parent = self.parent
while parent.parent is not None:
parent = parent.parent
return parent.get(key)
def build(self, *args, **kwargs):
return self.build_func(*args, **kwargs, registry=self)
def _add_children(self, registry):
"""Add children for a registry.
The ``registry`` will be added as children based on its scope.
The parent registry could build objects from children registry.
Example:
>>> models = Registry('models')
>>> mmdet_models = Registry('models', parent=models)
>>> @mmdet_models.register_module()
>>> class ResNet:
>>> pass
>>> resnet = models.build(dict(NAME='mmdet.ResNet'))
"""
assert isinstance(registry, Registry)
assert registry.scope is not None
assert registry.scope not in self.children, \
f'scope {registry.scope} exists in {self.name} registry'
self.children[registry.scope] = registry
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
if module_name is None:
module_name = module_class.__name__
if isinstance(module_name, str):
module_name = [module_name]
for name in module_name:
if not force and name in self._module_dict:
raise KeyError(f'{name} is already registered '
f'in {self.name}')
self._module_dict[name] = module_class
def deprecated_register_module(self, cls=None, force=False):
warnings.warn(
'The old API of register_module(module, force=False) '
'is deprecated and will be removed, please use the new API '
'register_module(name=None, force=False, module=None) instead.')
if cls is None:
return partial(self.deprecated_register_module, force=force)
self._register_module(cls, force=force)
return cls
def register_module(self, name=None, force=False, module=None):
"""Register a module.
A record will be added to `self._module_dict`, whose key is the class
name or the specified name, and value is the class itself.
It can be used as a decorator or a normal function.
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module()
>>> class ResNet:
>>> pass
>>> backbones = Registry('backbone')
>>> @backbones.register_module(name='mnet')
>>> class MobileNet:
>>> pass
>>> backbones = Registry('backbone')
>>> class ResNet:
>>> pass
>>> backbones.register_module(ResNet)
Args:
name (str | None): The module name to be registered. If not
specified, the class name will be used.
force (bool, optional): Whether to override an existing class with
the same name. Default: False.
module (type): Module class to be registered.
"""
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
# NOTE: This is a workaround to stay compatible with the old API,
# but it may introduce unexpected bugs.
if isinstance(name, type):
return self.deprecated_register_module(name, force=force)
# raise the error ahead of time
if not (name is None or isinstance(name, str) or misc.is_seq_of(name, str)):
raise TypeError(
'name must be either of None, an instance of str or a sequence'
f' of str, but got {type(name)}')
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(
module_class=module, module_name=name, force=force)
return module
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return cls
return _register
def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.
Args:
cfg (edict): Config dict. It should at least contain the key "NAME".
registry (:obj:`Registry`): The registry to search the type from.
Returns:
object: The constructed object.
"""
if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
if 'NAME' not in cfg:
if default_args is None or 'NAME' not in default_args:
raise KeyError(
'`cfg` or `default_args` must contain the key "NAME", '
f'but got {cfg}\n{default_args}')
if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}')
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')
if default_args is not None:
cfg = config.merge_new_config(cfg, default_args)
obj_type = cfg.get('NAME')
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
try:
return obj_cls(cfg)
except Exception as e:
# Normal TypeError does not print class name.
raise type(e)(f'{obj_cls.__name__}: {e}')
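End to end, the `Registry` + `build_from_cfg` pattern boils down to decorator registration plus a `NAME` lookup; a minimal standalone sketch (no scopes, parents, or error handling; `MiniRegistry` and `PointGPTStub` are hypothetical names):

```python
class MiniRegistry:
    def __init__(self, name):
        self.name, self._modules = name, {}

    def register_module(self, name=None):
        # Decorator form only, mirroring @MODELS.register_module() usage.
        def _register(cls):
            self._modules[name or cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg):
        # Like build_from_cfg above: look NAME up, construct with the cfg dict.
        return self._modules[cfg['NAME']](cfg)


MODELS = MiniRegistry('models')


@MODELS.register_module()
class PointGPTStub:
    def __init__(self, cfg):
        self.depth = cfg.get('depth', 12)


model = MODELS.build({'NAME': 'PointGPTStub', 'depth': 24})
print(type(model).__name__, model.depth)  # PointGPTStub 24
```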