[
  {
    "path": "DATASET.md",
    "content": "## Dataset\n\nThe overall directory structure should be:\n```\n│Point-MAE/\n├──cfgs/\n├──data/\n│   ├──ModelNet/\n│   ├──ModelNetFewshot/\n│   ├──ScanObjectNN/\n│   ├──ShapeNet55-34/\n│   ├──shapenetcore_partanno_segmentation_benchmark_v0_normal/\n├──datasets/\n├──.......\n```\n\n### ModelNet40 Dataset: \n\n```\n│ModelNet/\n├──modelnet40_normal_resampled/\n│  ├── modelnet40_shape_names.txt\n│  ├── modelnet40_train.txt\n│  ├── modelnet40_test.txt\n│  ├── modelnet40_train_8192pts_fps.dat\n│  ├── modelnet40_test_8192pts_fps.dat\n```\nDownload: You can download the processed data from [Point-BERT repo](https://github.com/lulutang0608/Point-BERT/blob/49e2c7407d351ce8fe65764bbddd5d9c0e0a4c52/DATASET.md), or download from the [official website](https://modelnet.cs.princeton.edu/#) and process it by yourself.\n\n### ModelNet Few-shot Dataset:\n```\n│ModelNetFewshot/\n├──5way10shot/\n│  ├── 0.pkl\n│  ├── ...\n│  ├── 9.pkl\n├──5way20shot/\n│  ├── ...\n├──10way10shot/\n│  ├── ...\n├──10way20shot/\n│  ├── ...\n```\n\nDownload: Please download the data from [Point-BERT repo](https://github.com/lulutang0608/Point-BERT/blob/49e2c7407d351ce8fe65764bbddd5d9c0e0a4c52/DATASET.md). We use the same data split as theirs.\n\n### ScanObjectNN Dataset:\n```\n│ScanObjectNN/\n├──main_split/\n│  ├── training_objectdataset_augmentedrot_scale75.h5\n│  ├── test_objectdataset_augmentedrot_scale75.h5\n│  ├── training_objectdataset.h5\n│  ├── test_objectdataset.h5\n├──main_split_nobg/\n│  ├── training_objectdataset.h5\n│  ├── test_objectdataset.h5\n```\nDownload: Please download the data from the [official website](https://hkust-vgd.github.io/scanobjectnn/).\n\n### ShapeNet55/34 Dataset:\n\n```\n│ShapeNet55-34/\n├──shapenet_pc/\n│  ├── 02691156-1a04e3eab45ca15dd86060f189eb133.npy\n│  ├── 02691156-1a6ad7a24bb89733f412783097373bdc.npy\n│  ├── .......\n├──ShapeNet-55/\n│  ├── train.txt\n│  └── test.txt\n```\n\nDownload: Please download the data from [Point-BERT repo](https://github.com/lulutang0608/Point-BERT/blob/49e2c7407d351ce8fe65764bbddd5d9c0e0a4c52/DATASET.md).\n\n### ShapeNetPart Dataset:\n\n```\n|shapenetcore_partanno_segmentation_benchmark_v0_normal/\n├──02691156/\n│  ├── 1a04e3eab45ca15dd86060f189eb133.txt\n│  ├── .......\n│── .......\n│──train_test_split/\n│──synsetoffset2category.txt\n```\n\nDownload: Please download the data from [here](https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip). \n"
  },
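  {
    "path": "examples/check_dataset_layout.py",
    "content": "'''Illustrative helper, not part of the original release: checks that the\ndirectory layout described in DATASET.md is in place before training. The\npaths below follow DATASET.md and the files under cfgs/dataset_configs/;\nadjust them if your data lives elsewhere.'''\nimport os\n\nEXPECTED_DIRS = [\n    'data/ModelNet/modelnet40_normal_resampled',\n    'data/ModelNetFewshot',\n    'data/ScanObjectNN/h5_files/main_split',\n    'data/ScanObjectNN/h5_files/main_split_nobg',\n    'data/ShapeNet55-34/shapenet_pc',\n    'data/ShapeNet55-34/ShapeNet-55',\n    'data/shapenetcore_partanno_segmentation_benchmark_v0_normal',\n]\n\nif __name__ == '__main__':\n    # report any directory that is missing relative to the repo root\n    missing = [d for d in EXPECTED_DIRS if not os.path.isdir(d)]\n    for d in missing:\n        print('missing:', d)\n    if not missing:\n        print('all dataset directories found')\n"
  },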
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 PANG-Yatian, YUAN-Li\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# PointGPT\n\n## PointGPT: Auto-regressively Generative Pre-training from Point Clouds [ArXiv](https://arxiv.org/abs/2305.11487)\n\nIn this work, we present PointGPT, a novel approach that extends the concept of GPT to point clouds, utilizing a point cloud auto-regressive generation task for pre-training transformer models. In object classification tasks, our PointGPT achieves 94.9% accuracy on the ModelNet40 dataset and 93.4% accuracy on the ScanObjectNN dataset, outperforming all other transformer models. In few-shot learning tasks, our method also attains new SOTA performance on all four benchmarks.\n\n<div  align=\"center\">    \n <img src=\"./figures/net.png\" width = \"666\"  align=center />\n</div>\n\n## News\n\n[2023.09.22] PointGPT has been accepted by NeurIPS 2023!\n\n[2023.09.08] Unlabeled hybrid dataset and labeled hybrid dataset have been released!\n\n[2023.08.19] Code has been updated; PointGPT-B and PointGPT-L models have been released!\n\n[2023.06.20] Code and the PointGPT-S models have been released!\n\n\n## 1. Requirements\n\nPyTorch >= 1.7.0;\npython >= 3.7;\nCUDA >= 9.0;\nGCC >= 4.9;\ntorchvision;\n\n```\npip install -r requirements.txt\n```\n\n```\n# Chamfer Distance & emd\ncd ./extensions/chamfer_dist\npython setup.py install --user\ncd ./extensions/emd\npython setup.py install --user\n# PointNet++\npip install \"git+https://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib\"\n# GPU kNN\npip install --upgrade https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl\n```\n\n## 2. Datasets\n\nOur training data for the PointGPT-S model encompasses ShapeNet, ScanObjectNN, ModelNet40, and ShapeNetPart datasets. For detailed information, please refer to [DATASET.md](./DATASET.md).\n\nTo pretrain the PointGPT-B and PointGPT-L models, we employ both unlabeled hybrid dataset and labeled hybrid dataset, available for download [here](https://drive.google.com/file/d/1TWgd3eJX1HDruFfU9JrGnBfcVhzJIXqT/view?usp=sharing).\n\n\n\n## 3. PointGPT Models\n### PointGPT-S Models\n| Task              | Dataset        | Config                                                          | Acc.       | Download                                                                                      |\n| ----------------- | -------------- | --------------------------------------------------------------- | ---------- | --------------------------------------------------------------------------------------------- |\n| Pre-training      | ShapeNet       | [pretrain.yaml](./cfgs/PointGPT-S/pretrain.yaml)                           | N.A.       
| [here](https://drive.google.com/file/d/1gTFI327kXVDFQ90JfYX0zIS4opM1EkqX/view?usp=drive_link) |\n| Classification    | ScanObjectNN   | [finetune_scan_hardest.yaml](./cfgs/PointGPT-S/finetune_scan_hardest.yaml) | 86.9%      | [here](https://drive.google.com/file/d/12Tj2OFKsEPT5zd5nQQ2VNEZlCKHncdGh/view?usp=drive_link) |\n| Classification    | ScanObjectNN   | [finetune_scan_objbg.yaml](./cfgs/PointGPT-S/finetune_scan_objbg.yaml)     | 91.6%      | [here](https://drive.google.com/file/d/1s4RrBkfwVr8r0H2FxwiHULcyMe_EAJ9D/view?usp=drive_link) |\n| Classification    | ScanObjectNN   | [finetune_scan_objonly.yaml](./cfgs/PointGPT-S/finetune_scan_objonly.yaml) | 90.0%      | [here](https://drive.google.com/file/d/173yfDAlqqed-oRHaogX6DC4Uj1b8Rvxt/view?usp=drive_link) |\n| Classification    | ModelNet40(1k) | [finetune_modelnet.yaml](./cfgs/PointGPT-S/finetune_modelnet.yaml)         | 94.0%      | [here](https://drive.google.com/file/d/17uoJchAzwapTNHVxOWNH4HLNZz9kbGoo/view?usp=drive_link) |\n| Classification    | ModelNet40(8k) | [finetune_modelnet_8k.yaml](./cfgs/PointGPT-S/finetune_modelnet_8k.yaml)   | 94.2%      | [here](https://drive.google.com/file/d/1XocTFSsKZgKHx2cLqZJi2rcF74hQ-1nx/view?usp=drive_link) |\n| Part segmentation | ShapeNetPart   | [segmentation](./segmentation)                                  | 86.2% mIoU | [here](https://drive.google.com/file/d/1WVMTtIq4vPQOOnlDsymVA5541lNL-hm3/view?usp=drive_link) |\n\n| Task              | Dataset    | Config                              | 5w10s Acc. (%) | 5w20s Acc. (%) | 10w10s Acc. (%) | 10w20s Acc. (%) |\n| ----------------- | ---------- | ----------------------------------- | -------------- | -------------- | --------------- | --------------- |\n| Few-shot learning | ModelNet40 | [fewshot.yaml](./cfgs/fewshot.yaml) | 96.8 ± 2.0     | 98.6 ± 1.1     | 92.6 ± 4.6      | 95.2 ± 3.4      |\n\n### PointGPT-B Models\n| Task              | Dataset        | Config                                                          | Acc.       | Download                                                                                      |\n| ----------------- | -------------- | --------------------------------------------------------------- | ---------- | --------------------------------------------------------------------------------------------- |\n| Pre-training      | UnlabeledHybrid       | [pretrain.yaml](./cfgs/PointGPT-B/pretrain.yaml)                           | N.A.       | [here](https://drive.google.com/file/d/1Gyf9ZR8MCPg1XOCALjJR9VJepV7iAi5S/view?usp=sharing) |\n| Post-pre-training | LabeledHybrid       | [post_pretrain.yaml](./cfgs/PointGPT-B/post_pretrain.yaml)                        | N.A.       
| [here](https://drive.google.com/file/d/1Gc7thuU-D1Sq4NIMTV6-U1LhVN0E2z9l/view?usp=sharing) |\n| Classification    | ScanObjectNN   | [finetune_scan_hardest.yaml](./cfgs/PointGPT-B/finetune_scan_hardest.yaml) | 91.9%      | [here](https://drive.google.com/file/d/1tHi7W935DxVttXHG0Mgb0HSfYWUqXLwB/view?usp=sharing) |\n| Classification    | ScanObjectNN   | [finetune_scan_objbg.yaml](./cfgs/PointGPT-B/finetune_scan_objbg.yaml)     | 95.8%      | [here](https://drive.google.com/file/d/1te8DuC_-cOzt4JayyaNWvxHcRztjDlGF/view?usp=sharing) |\n| Classification    | ScanObjectNN   | [finetune_scan_objonly.yaml](./cfgs/PointGPT-B/finetune_scan_objonly.yaml) | 95.2%      | [here](https://drive.google.com/file/d/17c8KvDrAuY0GgcO7SGE-4zlMArjzkjLX/view?usp=sharing) |\n| Classification    | ModelNet40(1k) | [finetune_modelnet.yaml](./cfgs/PointGPT-B/finetune_modelnet.yaml)         | 94.4%      | [here](https://drive.google.com/file/d/1l5zhy52erSp5gigbhYaT0nyMrV_lbh-C/view?usp=sharing) |\n| Classification    | ModelNet40(8k) | [finetune_modelnet_8k.yaml](./cfgs/PointGPT-B/finetune_modelnet_8k.yaml)   | 94.6%      | [here](https://drive.google.com/file/d/1FzM7ULPUAOk_J0BRHFvv0nS_Xd65oWbV/view?usp=sharing) |\n| Part segmentation | ShapeNetPart   | [segmentation](./segmentation)                                  | 86.5% mIoU | [here](https://drive.google.com/file/d/1P6hELhX6Yr-rN04q6N71wZfvW2HnLhqD/view?usp=sharing) |\n\n| Task              | Dataset    | Config                              | 5w10s Acc. (%) | 5w20s Acc. (%) | 10w10s Acc. (%) | 10w20s Acc. (%) |\n| ----------------- | ---------- | ----------------------------------- | -------------- | -------------- | --------------- | --------------- |\n| Few-shot learning | ModelNet40 | [fewshot.yaml](./cfgs/PointGPT-B/fewshot.yaml) | 97.5 ± 2.0     | 98.8 ± 1.0     | 93.5 ± 4.0      | 95.8 ± 3.0      |\n\n### PointGPT-L Models\n| Task              | Dataset        | Config                                                          | Acc.       | Download                                                                                      |\n| ----------------- | -------------- | --------------------------------------------------------------- | ---------- | --------------------------------------------------------------------------------------------- |\n| Pre-training      | UnlabeledHybrid       | [pretrain.yaml](./cfgs/PointGPT-L/pretrain.yaml)                           | N.A.       | [here](https://drive.google.com/file/d/1nzCwriFbC2QoDbRpGhWvf_DbFIkFU6zV/view?usp=sharing) |\n| Post-pre-training | LabeledHybrid       | [post_pretrain.yaml](./cfgs/PointGPT-L/post_pretrain.yaml)                        | N.A.       
| [here](https://drive.google.com/file/d/1Kh6f6gFR12Y86FAeBtMU9NbNpB5vZnpu/view?usp=sharing) |\n| Classification    | ScanObjectNN   | [finetune_scan_hardest.yaml](./cfgs/PointGPT-L/finetune_scan_hardest.yaml) | 93.4%      | [here](https://drive.google.com/file/d/1e_qIfZCqQmq0eRpYhf9xrIxl6TkzsaZ9/view?usp=sharing) |\n| Classification    | ScanObjectNN   | [finetune_scan_objbg.yaml](./cfgs/PointGPT-L/finetune_scan_objbg.yaml)     | 97.2%      | [here](https://drive.google.com/file/d/1gd8gn0ffK0zfWv7AAUbygzIPSeeRU8fD/view?usp=sharing) |\n| Classification    | ScanObjectNN   | [finetune_scan_objonly.yaml](./cfgs/PointGPT-L/finetune_scan_objonly.yaml) | 96.6%      | [here](https://drive.google.com/file/d/1F2MnPmQGKnYUgmS5uz3PNInU23jWsNj1/view?usp=sharing) |\n| Classification    | ModelNet40(1k) | [finetune_modelnet.yaml](./cfgs/PointGPT-L/finetune_modelnet.yaml)         | 94.7%      | [here](https://drive.google.com/file/d/1ntWwZCvD_Tqykq9F7QrDKXH7aL-dcCsQ/view?usp=sharing) |\n| Classification    | ModelNet40(8k) | [finetune_modelnet_8k.yaml](./cfgs/PointGPT-L/finetune_modelnet_8k.yaml)   | 94.9%      | [here](https://drive.google.com/file/d/1gKgdbtIuRinJY-NElSHwrKAL5OhBjrGD/view?usp=sharing) |\n| Part segmentation | ShapeNetPart   | [segmentation](./segmentation)                                  | 86.6% mIoU | [here](https://drive.google.com/file/d/1d3fXLBkXvzl9YjX5DDMdm7rUtCvfwgUL/view?usp=sharing) |\n\n| Task              | Dataset    | Config                              | 5w10s Acc. (%) | 5w20s Acc. (%) | 10w10s Acc. (%) | 10w20s Acc. (%) |\n| ----------------- | ---------- | ----------------------------------- | -------------- | -------------- | --------------- | --------------- |\n| Few-shot learning | ModelNet40 | [fewshot.yaml](./cfgs/PointGPT-L/fewshot.yaml) | 98.0 ± 1.9     | 99.0 ± 1.0     | 94.1 ± 3.3      | 96.1 ± 2.8      |\n\n## 4. PointGPT Pre-training\n\nTo pretrain PointGPT, run the following command. \n\n```\nCUDA_VISIBLE_DEVICES=<GPU> python main.py --config cfgs/<MODEL_NAME>/pretrain.yaml --exp_name <output_file_name>\n```\n\nTo post-pretrain PointGPT, run the following command. \n\n```\nCUDA_VISIBLE_DEVICES=<GPU> python main.py --config cfgs/<MODEL_NAME>/post_pretrain.yaml --exp_name <output_file_name> --finetune_model\n```\n\n## 5. PointGPT Fine-tuning\n\nFine-tuning on ScanObjectNN, run the following command:\n```\nCUDA_VISIBLE_DEVICES=<GPUs> python main.py --config cfgs/<MODEL_NAME>/finetune_scan_hardest.yaml \\\n--finetune_model --exp_name <output_file_name> --ckpts <path/to/pre-trained/model>\n```\nFine-tuning on ModelNet40, run the following command:\n```\nCUDA_VISIBLE_DEVICES=<GPUs> python main.py --config cfgs/<MODEL_NAME>/finetune_modelnet.yaml \\\n--finetune_model --exp_name <output_file_name> --ckpts <path/to/pre-trained/model>\n```\nVoting on ModelNet40, run the following command:\n```\nCUDA_VISIBLE_DEVICES=<GPUs> python main.py --test --config cfgs/<MODEL_NAME>/finetune_modelnet.yaml \\\n--exp_name <output_file_name> --ckpts <path/to/best/fine-tuned/model>\n```\nFew-shot learning, run the following command:\n```\nCUDA_VISIBLE_DEVICES=<GPUs> python main.py --config cfgs/<MODEL_NAME>/fewshot.yaml --finetune_model \\\n--ckpts <path/to/pre-trained/model> --exp_name <output_file_name> --way <5 or 10> --shot <10 or 20> --fold <0-9>\n```\nPart segmentation on ShapeNetPart, run the following command:\n```\ncd segmentation\npython main.py --ckpts <path/to/pre-trained/model> --root path/to/data --learning_rate 0.0002 --epoch 300 --model_name <MODEL_NAME>\n```\n\n## 6. 
Visualization\n\nVisulization of pre-trained model on validation set, run:\n\n```\npython main_vis.py --test --ckpts <path/to/pre-trained/model> --config cfgs/<MODEL_NAME>/pretrain.yaml --exp_name <name>\n```\n\n<div  align=\"center\">    \n <img src=\"./figures/vis.png\" width = \"900\"  align=center />\n</div>\n\n## 7. Ablation studies on post-pre-training stage \n<table>\n  <thead>\n    <tr>\n      <th rowspan=\"2\">Methods</th>\n      <th colspan=\"3\"><u>ScanObjectNN</u></th>\n      <th colspan=\"2\"><u>ModelNet40</u></th>\n      <th colspan=\"2\">ShapeNetPart</th>\n    </tr>\n    <tr>\n      <th>OBJ_BG</th>\n      <th>OBJ_ONLY</th>\n      <th>PB_T50_RS</th>\n      <th>1k P</th>\n      <th>8k P</th>\n      <th>Cls.mIoU</th>\n      <th>Inst.mIoU</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th colspan=\"8\">without post-pre-training</th>\n    </tr>\n    <tr>\n      <td><i>PointGPT-B</i></td>\n      <td>93.6</td>\n      <td>92.5</td>\n      <td>89.6</td>\n      <td>94.2</td>\n      <td>94.4</td>\n      <td>84.5</td>\n      <td>86.4</td>\n    </tr>\n    <tr>\n      <td><i>PointGPT-L</i></td>\n      <td>95.7</td>\n      <td>94.1</td>\n      <td>91.1</td>\n      <td>94.5</td>\n      <td>94.7</td>\n      <td>84.7</td>\n      <td>86.5</td>\n    </tr>\n    <tr>\n      <th colspan=\"8\">with post-pre-training</th>\n    </tr>\n    <tr>\n      <td><i>PointGPT-B</i></td>\n      <td>95.8 <span style=\"color:green\">(+2.2)</span></td>\n      <td>95.2 <span style=\"color:green\">(+2.7)</span></td>\n      <td>91.9 <span style=\"color:green\">(+2.3)</span></td>\n      <td>94.4 <span style=\"color:green\">(+0.2)</span></td>\n      <td>94.6 <span style=\"color:green\">(+0.2)</span></td>\n      <td>84.5 <span style=\"color:green\">(+0.0)</span></td>\n      <td>86.5 <span style=\"color:green\">(+0.1)</span></td>\n    </tr>\n    <tr>\n      <td><i>PointGPT-L</i></td>\n      <td>97.2 <span style=\"color:green\">(+1.5)</span></td>\n      <td>96.6 <span style=\"color:green\">(+2.5)</span></td>\n      <td>93.4 <span style=\"color:green\">(+2.3)</span></td>\n      <td>94.7 <span style=\"color:green\">(+0.2)</span></td>\n      <td>94.9 <span style=\"color:green\">(+0.2)</span></td>\n      <td>84.8 <span style=\"color:green\">(+0.1)</span></td>\n      <td>86.6 <span style=\"color:green\">(+0.1)</span></td>\n    </tr>\n  </tbody>\n</table>\n\n\n## Acknowledgements\n\nOur codes are built upon [Point-MAE](https://github.com/Pang-Yatian/Point-MAE), [Point-BERT](https://github.com/lulutang0608/Point-BERT), [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch) and [Pointnet_Pointnet2_pytorch](https://github.com/yanx27/Pointnet_Pointnet2_pytorch)\n\nThe unlabeled hybrid dataset and labeled hybrid dataset are built upon [ModelNet40](https://3dshapenets.cs.princeton.edu/), [PartNet](https://partnet.cs.stanford.edu/), [ShapeNet](http://www.shapenet.org), [S3DIS](http://buildingparser.stanford.edu/), [ScanObjectNN](https://hkust-vgd.github.io/scanobjectnn/), [SUN RGB-D](https://rgbd.cs.princeton.edu/), and [Semantic3D](http://semantic3d.net/)\n\n\n## Reference\n\n```\n@article{chen2024pointgpt,\n  title={Pointgpt: Auto-regressively generative pre-training from point clouds},\n  author={Chen, Guangyan and Wang, Meiling and Yang, Yi and Yu, Kai and Yuan, Li and Yue, Yufeng},\n  journal={Advances in Neural Information Processing Systems},\n  volume={36},\n  year={2024}\n}\n```\n\nFor unlabeled hybrid dataset or labeled hybrid dataset, please also cite the following 
work.\n\n```\n@inproceedings{wu20153d,\n  title={3d shapenets: A deep representation for volumetric shapes},\n  author={Wu, Zhirong and Song, Shuran and Khosla, Aditya and Yu, Fisher and Zhang, Linguang and Tang, Xiaoou and Xiao, Jianxiong},\n  booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},\n  pages={1912--1920},\n  year={2015}\n}\n\n@inproceedings{mo2019partnet,\n  title={Partnet: A large-scale benchmark for fine-grained and hierarchical part-level 3d object understanding},\n  author={Mo, Kaichun and Zhu, Shilin and Chang, Angel X and Yi, Li and Tripathi, Subarna and Guibas, Leonidas J and Su, Hao},\n  booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},\n  pages={909--918},\n  year={2019}\n}\n\n@article{chang2015shapenet,\n  title={Shapenet: An information-rich 3d model repository},\n  author={Chang, Angel X and Funkhouser, Thomas and Guibas, Leonidas and Hanrahan, Pat and Huang, Qixing and Li, Zimo and Savarese, Silvio and Savva, Manolis and Song, Shuran and Su, Hao and others},\n  journal={arXiv preprint arXiv:1512.03012},\n  year={2015}\n}\n\n@inproceedings{armeni20163d,\n  title={3d semantic parsing of large-scale indoor spaces},\n  author={Armeni, Iro and Sener, Ozan and Zamir, Amir R and Jiang, Helen and Brilakis, Ioannis and Fischer, Martin and Savarese, Silvio},\n  booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},\n  pages={1534--1543},\n  year={2016}\n}\n\n@inproceedings{uy-scanobjectnn-iccv19,\n  title = {Revisiting Point Cloud Classification: A New Benchmark Dataset and Classification Model on Real-World Data},\n  author = {Mikaela Angelina Uy and Quang-Hieu Pham and Binh-Son Hua and Duc Thanh Nguyen and Sai-Kit Yeung},\n  booktitle = {International Conference on Computer Vision (ICCV)},\n  year = {2019}\n}\n\n@inproceedings{song2015sun,\n  title={Sun rgb-d: A rgb-d scene understanding benchmark suite},\n  author={Song, Shuran and Lichtenberg, Samuel P and Xiao, Jianxiong},\n  booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},\n  pages={567--576},\n  year={2015}\n}\n\n@article{hackel2017semantic3d,\n  title={Semantic3d. net: A new large-scale point cloud classification benchmark},\n  author={Hackel, Timo and Savinov, Nikolay and Ladicky, Lubor and Wegner, Jan D and Schindler, Konrad and Pollefeys, Marc},\n  journal={arXiv preprint arXiv:1704.03847},\n  year={2017}\n}\n```\n"
  },
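  {
    "path": "examples/chamfer_reference.py",
    "content": "'''Minimal NumPy reference, illustrative only, for the Chamfer-distance\nlosses named in the cfgs (loss: cdl2, cdl12). It assumes cdl2 is the\nsymmetric mean of squared nearest-neighbour distances and cdl1 the unsquared\nvariant; the CUDA extension in ./extensions/chamfer_dist is the\nauthoritative implementation used for training.'''\nimport numpy as np\n\n\ndef chamfer_l2(p, q):\n    # p: (N, 3), q: (M, 3). Pairwise squared distances, then nearest neighbours.\n    d = ((p[:, None, :] - q[None, :, :]) ** 2).sum(-1)\n    return d.min(axis=1).mean() + d.min(axis=0).mean()\n\n\ndef chamfer_l1(p, q):\n    # Same structure with unsquared Euclidean distances.\n    d = np.sqrt(((p[:, None, :] - q[None, :, :]) ** 2).sum(-1))\n    return d.min(axis=1).mean() + d.min(axis=0).mean()\n\n\nif __name__ == '__main__':\n    rng = np.random.default_rng(0)\n    a = rng.normal(size=(1024, 3)).astype(np.float32)\n    b = rng.normal(size=(1024, 3)).astype(np.float32)\n    print('cdl2:', chamfer_l2(a, b), 'cdl1:', chamfer_l1(a, b))\n"
  },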
  {
    "path": "cfgs/PointGPT-B/fewshot.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0005, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 30 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40FewShot.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40FewShot.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 768,\n    depth: 12,\n    drop_path_rate: 0.1,\n    cls_dim: 40,\n    num_heads: 12,\n    group_size: 32,\n    num_group: 64,\n    encoder_dims: 768,\n    decoder_depth: 4,\n  }\n\nnpoints: 1024\ntotal_bs: 32\nstep_per_update: 1\nmax_epoch: 300\ngrad_norm_clip: 10\n"
  },
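  {
    "path": "examples/resolve_cfg_sketch.py",
    "content": "'''Sketch, under stated assumptions, of how the _base_ references in these\ncfgs resolve: each train/val/test dataset entry loads the YAML file named by\n_base_ and then overlays the keys under others. The repo's actual config\nparser is not included in this listing and may differ in detail.'''\nimport yaml\n\n\ndef resolve_dataset(entry):\n    # entry is one of the train/val/test dicts from a cfg file\n    with open(entry['_base_']) as f:\n        cfg = yaml.safe_load(f)\n    # keys under `others` (e.g. subset, npoints) override the base values\n    cfg.update(entry.get('others', {}))\n    return cfg\n\n\nif __name__ == '__main__':\n    entry = {'_base_': 'cfgs/dataset_configs/ModelNet40FewShot.yaml',\n             'others': {'subset': 'train'}}\n    print(resolve_dataset(entry))\n"
  },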
  {
    "path": "cfgs/PointGPT-B/finetune_modelnet.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 768,\n    depth: 12,\n    drop_path_rate: 0.2,\n    cls_dim: 40,\n    num_heads: 12,\n    group_size: 32,\n    num_group: 64,\n    encoder_dims: 768,\n    decoder_depth: 4,\n    loss: cdl2,\n    weight_center: 1,\n  }\n\nnpoints: 1024\ntotal_bs: 128\nstep_per_update: 1\nmax_epoch: 50\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-B/finetune_modelnet_8k.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.00005, weight_decay: 0.005 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 768,\n    depth: 12,\n    drop_path_rate: 0.2,\n    cls_dim: 40,\n    num_heads: 12,\n    group_size: 32,\n    num_group: 512,\n    encoder_dims: 768,\n    decoder_depth: 4,\n    loss: cdl2,\n    weight_center: 1,\n  }\n\nnpoints: 8192\ntotal_bs: 32\nstep_per_update: 1\nmax_epoch: 50\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-B/finetune_scan_hardest.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 30, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 768,\n    depth: 12,\n    drop_path_rate: 0.2,\n    cls_dim: 15,\n    num_heads: 12,\n    group_size: 32,\n    num_group: 128,\n    encoder_dims: 768,\n    decoder_depth: 4,\n  }\n\nnpoints: 2048\ntotal_bs: 64\nstep_per_update: 1\nmax_epoch: 30\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-B/finetune_scan_objbg.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 30, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 768,\n    depth: 12,\n    drop_path_rate: 0.2,\n    cls_dim: 15,\n    num_heads: 12,\n    group_size: 32,\n    num_group: 128,\n    encoder_dims: 768,\n    decoder_depth: 4,\n  }\n\nnpoints: 2048\ntotal_bs: 64\nstep_per_update: 1\nmax_epoch: 30\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-B/finetune_scan_objonly.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 768,\n    depth: 12,\n    drop_path_rate: 0.2,\n    cls_dim: 15,\n    num_heads: 12,\n    group_size: 32,\n    num_group: 128,\n    encoder_dims: 768,\n    decoder_depth: 4,\n  }\n\nnpoints: 2048\ntotal_bs: 64\nstep_per_update: 1\nmax_epoch: 50\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-B/post_pretrain.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 100, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/LabeledHybrid.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/LabeledHybrid.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/LabeledHybrid.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 768,\n    depth: 12,\n    drop_path_rate: 0.2,\n    cls_dim: 87,\n    num_heads: 12,\n    group_size: 32,\n    num_group: 64,\n    encoder_dims: 768,\n    decoder_depth: 4,\n    loss: cdl2,\n    weight_center: 1,\n  }\n\nnpoints: 1024\ntotal_bs: 256\nstep_per_update: 1\nmax_epoch: 100\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-B/pretrain.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/UnlabeledHybrid.yaml,\n        others: { subset: \"train\", npoints: 1024 },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/UnlabeledHybrid.yaml,\n        others: { subset: \"test\", npoints: 1024 },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/UnlabeledHybrid.yaml,\n        others: { subset: \"test\", npoints: 1024 },\n      },\n  }\n\nmodel:\n  {\n    NAME: PointGPT,\n    cls_dim: 40,\n    group_size: 32,\n    num_group: 64,\n    loss: cdl12,\n    weight_center: 1,\n    transformer_config:\n      {\n        mask_ratio: 0.7,\n        mask_type: \"rand\",\n        trans_dim: 768,\n        encoder_dims: 768,\n        depth: 12,\n        drop_path_rate: 0.1,\n        num_heads: 12,\n        decoder_depth: 4,\n        decoder_num_heads: 12,\n      },\n  }\n\nnpoints: 1024\ntotal_bs: 128\nstep_per_update: 1\nmax_epoch: 300\n"
  },
  {
    "path": "cfgs/PointGPT-L/fewshot.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0005, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 30 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40FewShot.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40FewShot.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 768,\n    depth: 12,\n    drop_path_rate: 0.1,\n    cls_dim: 40,\n    num_heads: 12,\n    group_size: 32,\n    num_group: 64,\n    encoder_dims: 768,\n    decoder_depth: 4,\n  }\n\nnpoints: 1024\ntotal_bs: 32\nstep_per_update: 1\nmax_epoch: 300\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-L/finetune_modelnet.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 1024,\n    depth: 24,\n    drop_path_rate: 0.2,\n    cls_dim: 40,\n    num_heads: 16,\n    group_size: 32,\n    num_group: 64,\n    encoder_dims: 1024,\n    decoder_depth: 4,\n    loss: cdl2,\n    weight_center: 1,\n  }\n\nnpoints: 1024\ntotal_bs: 128\nstep_per_update: 1\nmax_epoch: 50\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-L/finetune_modelnet_8k.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.00005, weight_decay: 0.005 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 1024,\n    depth: 24,\n    drop_path_rate: 0.2,\n    cls_dim: 40,\n    num_heads: 16,\n    group_size: 32,\n    num_group: 512,\n    encoder_dims: 1024,\n    decoder_depth: 4,\n    loss: cdl2,\n    weight_center: 1,\n  }\n\nnpoints: 8192\ntotal_bs: 32\nstep_per_update: 1\nmax_epoch: 50\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-L/finetune_scan_hardest.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 1024,\n    depth: 24,\n    drop_path_rate: 0.2,\n    cls_dim: 15,\n    num_heads: 16,\n    group_size: 32,\n    num_group: 128,\n    encoder_dims: 1024,\n    decoder_depth: 4,\n  }\n\nnpoints: 2048\ntotal_bs: 64\nstep_per_update: 1\nmax_epoch: 50\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-L/finetune_scan_objbg.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 1024,\n    depth: 24,\n    drop_path_rate: 0.2,\n    cls_dim: 15,\n    num_heads: 16,\n    group_size: 32,\n    num_group: 128,\n    encoder_dims: 1024,\n    decoder_depth: 4,\n  }\n\nnpoints: 2048\ntotal_bs: 64\nstep_per_update: 1\nmax_epoch: 50\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-L/finetune_scan_objonly.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 50, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 1024,\n    depth: 24,\n    drop_path_rate: 0.2,\n    cls_dim: 15,\n    num_heads: 16,\n    group_size: 32,\n    num_group: 128,\n    encoder_dims: 1024,\n    decoder_depth: 4,\n  }\n\nnpoints: 2048\ntotal_bs: 64\nstep_per_update: 1\nmax_epoch: 50\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-L/post_pretrain.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 100, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/LabeledHybrid.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/LabeledHybrid.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/LabeledHybrid.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 1024,\n    depth: 24,\n    drop_path_rate: 0.2,\n    cls_dim: 87,\n    num_heads: 16,\n    group_size: 32,\n    num_group: 64,\n    encoder_dims: 1024,\n    decoder_depth: 4,\n    loss: cdl2,\n    weight_center: 1,\n  }\n\nnpoints: 1024\ntotal_bs: 256\nstep_per_update: 1\nmax_epoch: 100\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-L/pretrain.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.00006, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 600, initial_epochs: 80 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/UnlabeledHybrid.yaml,\n        others: { subset: \"train\", npoints: 1024 },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/UnlabeledHybrid.yaml,\n        others: { subset: \"test\", npoints: 1024 },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/UnlabeledHybrid.yaml,\n        others: { subset: \"test\", npoints: 1024 },\n      },\n  }\n\nmodel:\n  {\n    NAME: PointGPT,\n    cls_dim: 40,\n    group_size: 32,\n    num_group: 64,\n    loss: cdl12,\n    weight_center: 1,\n    transformer_config:\n      {\n        mask_ratio: 0.7,\n        mask_type: \"rand\",\n        trans_dim: 1024,\n        encoder_dims: 1024,\n        depth: 24,\n        drop_path_rate: 0.1,\n        num_heads: 16,\n        decoder_depth: 4,\n        decoder_num_heads: 12,\n      },\n  }\n\nnpoints: 1024\ntotal_bs: 128\nstep_per_update: 1\nmax_epoch: 350\n"
  },
  {
    "path": "cfgs/PointGPT-S/fewshot.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0005, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 30 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40FewShot.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40FewShot.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 384,\n    depth: 12,\n    drop_path_rate: 0.1,\n    cls_dim: 40,\n    num_heads: 6,\n    group_size: 32,\n    num_group: 64,\n    encoder_dims: 384,\n    decoder_depth: 4,\n  }\n\nnpoints: 1024\ntotal_bs: 32\nstep_per_update: 1\nmax_epoch: 300\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-S/finetune_modelnet.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 384,\n    depth: 12,\n    drop_path_rate: 0.1,\n    cls_dim: 40,\n    num_heads: 6,\n    group_size: 32,\n    num_group: 64,\n    encoder_dims: 384,\n    decoder_depth: 4,\n    loss: cdl2,\n    weight_center: 1,\n  }\n\nnpoints: 1024\ntotal_bs: 128\nstep_per_update: 1\nmax_epoch: 300\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-S/finetune_modelnet_8k.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.005 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ModelNet40.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 384,\n    depth: 12,\n    drop_path_rate: 0.1,\n    cls_dim: 40,\n    num_heads: 6,\n    group_size: 32,\n    num_group: 512,\n    encoder_dims: 384,\n    decoder_depth: 4,\n    loss: cdl2,\n    weight_center: 1,\n  }\n\nnpoints: 8192\ntotal_bs: 32\nstep_per_update: 1\nmax_epoch: 300\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-S/finetune_scan_hardest.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 30 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_hardest.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 384,\n    depth: 12,\n    drop_path_rate: 0.1,\n    cls_dim: 15,\n    num_heads: 6,\n    group_size: 32,\n    num_group: 128,\n    encoder_dims: 384,\n    decoder_depth: 4,\n  }\n\nnpoints: 2048\ntotal_bs: 64\nstep_per_update: 1\nmax_epoch: 300\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-S/finetune_scan_objbg.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectbg.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 384,\n    depth: 12,\n    drop_path_rate: 0.1,\n    cls_dim: 15,\n    num_heads: 6,\n    group_size: 32,\n    num_group: 128,\n    encoder_dims: 384,\n    decoder_depth: 4,\n  }\n\nnpoints: 2048\ntotal_bs: 32\nstep_per_update: 1\nmax_epoch: 300\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-S/finetune_scan_objonly.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,\n        others: { subset: \"train\" },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,\n        others: { subset: \"test\" },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ScanObjectNN_objectonly.yaml,\n        others: { subset: \"test\" },\n      },\n  }\nmodel:\n  {\n    NAME: PointTransformer,\n    trans_dim: 384,\n    depth: 12,\n    drop_path_rate: 0.1,\n    cls_dim: 15,\n    num_heads: 6,\n    group_size: 32,\n    num_group: 128,\n    encoder_dims: 384,\n    decoder_depth: 4,\n  }\n\nnpoints: 2048\ntotal_bs: 32\nstep_per_update: 1\nmax_epoch: 300\ngrad_norm_clip: 10\n"
  },
  {
    "path": "cfgs/PointGPT-S/pretrain.yaml",
    "content": "optimizer: { type: AdamW, kwargs: { lr: 0.0001, weight_decay: 0.05 } }\n\nscheduler: { type: CosLR, kwargs: { epochs: 300, initial_epochs: 10 } }\n\ndataset:\n  {\n    train:\n      {\n        _base_: cfgs/dataset_configs/ShapeNet-55.yaml,\n        others: { subset: \"train\", npoints: 1024 },\n      },\n    val:\n      {\n        _base_: cfgs/dataset_configs/ShapeNet-55.yaml,\n        others: { subset: \"test\", npoints: 1024 },\n      },\n    test:\n      {\n        _base_: cfgs/dataset_configs/ShapeNet-55.yaml,\n        others: { subset: \"test\", npoints: 1024 },\n      },\n  }\n\nmodel:\n  {\n    NAME: PointGPT,\n    cls_dim: 40,\n    group_size: 32,\n    num_group: 64,\n    loss: cdl12,\n    weight_center: 1,\n    transformer_config:\n      {\n        mask_ratio: 0.7,\n        mask_type: \"rand\",\n        trans_dim: 384,\n        encoder_dims: 384,\n        depth: 12,\n        drop_path_rate: 0.1,\n        num_heads: 6,\n        decoder_depth: 4,\n        decoder_num_heads: 6,\n      },\n  }\n\nnpoints: 1024\ntotal_bs: 64\nstep_per_update: 1\nmax_epoch: 300\n"
  },
  {
    "path": "cfgs/dataset_configs/LabeledHybrid.yaml",
    "content": "NAME: LabeledHybrid\nDATA_PATH: data/HybridDatasets/post_pretrain\nN_POINTS: 2048\nPC_PATH: data/HybridDatasets\nnpoints: 1024\nNUM_CATEGORY: 87\n"
  },
  {
    "path": "cfgs/dataset_configs/ModelNet40.yaml",
    "content": "NAME: ModelNet\nDATA_PATH: data/ModelNet/modelnet40_normal_resampled\nN_POINTS: 8192\nNUM_CATEGORY: 40\nUSE_NORMALS: FALSE"
  },
  {
    "path": "cfgs/dataset_configs/ModelNet40FewShot.yaml",
    "content": "NAME: ModelNetFewShot\nDATA_PATH: data/ModelNetFewshot\nN_POINTS: 8192\nNUM_CATEGORY: 40\nUSE_NORMALS: FALSE"
  },
  {
    "path": "cfgs/dataset_configs/ScanObjectNN_hardest.yaml",
    "content": "NAME: ScanObjectNN_hardest\nROOT: data/ScanObjectNN/h5_files/main_split"
  },
  {
    "path": "cfgs/dataset_configs/ScanObjectNN_objectbg.yaml",
    "content": "NAME: ScanObjectNN\nROOT: data/ScanObjectNN/h5_files/main_split"
  },
  {
    "path": "cfgs/dataset_configs/ScanObjectNN_objectonly.yaml",
    "content": "NAME: ScanObjectNN\nROOT: data/ScanObjectNN/h5_files/main_split_nobg"
  },
  {
    "path": "cfgs/dataset_configs/ShapeNet-55.yaml",
    "content": "NAME: ShapeNet\nDATA_PATH: data/ShapeNet55-34/ShapeNet-55\nN_POINTS: 8192\nPC_PATH: data/ShapeNet55-34/shapenet_pc\n"
  },
  {
    "path": "cfgs/dataset_configs/UnlabeledHybrid.yaml",
    "content": "NAME: UnlabeledHybrid\nDATA_PATH: data/HybridDatasets/pretrain\nN_POINTS: 2048\nPC_PATH: data/HybridDatasets\n"
  },
  {
    "path": "datasets/LabeledHybrid.py",
    "content": "import os\nimport torch\nimport numpy as np\nimport torch.utils.data as data\nfrom .io import IO\nfrom .build import DATASETS\nfrom utils.logger import *\n\n@DATASETS.register_module()\nclass LabeledHybrid(data.Dataset):\n    def __init__(self, config):\n        self.data_root = config.DATA_PATH\n        self.pc_path = config.PC_PATH\n        self.subset = config.subset\n        self.npoints = config.N_POINTS\n\n        self.data_list_file = os.path.join(self.data_root, f'{self.subset}.txt')\n        self.label_list_file = os.path.join(self.data_root, f'{self.subset}_num.txt')\n\n        self.sample_points_num = config.npoints\n\n        print_log(f'[DATASET] sample out {self.sample_points_num} points', logger = 'LabeledHybrid')\n        print_log(f'[DATASET] Open file {self.data_list_file}', logger = 'LabeledHybrid')\n        with open(self.data_list_file, 'r') as f:\n            lines = f.readlines()\n        print_log(f'[DATASET] Open file {self.label_list_file}', logger = 'LabeledHybrid')\n        with open(self.label_list_file, 'r') as f:\n            lines_label = f.readlines()\n\n        self.file_list = []\n        for line in lines:\n            self.file_list.append(line.strip())\n        print_log(f'[DATASET] {len(self.file_list)} instances were loaded', logger = 'LabeledHybrid')\n        self.label_list = []\n        for line_label in lines_label:\n            self.label_list.append(np.array(int(line_label.strip())))\n        print_log(f'[DATASET] {len(self.label_list)} labels were loaded', logger = 'LabeledHybrid')\n\n\n    def pc_norm(self, pc):\n        \"\"\" pc: NxC, return NxC \"\"\"\n        centroid = np.mean(pc, axis=0)\n        pc = pc - centroid\n        m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n        pc = pc / m\n        return pc\n        \n\n    def random_sample(self, pc, num):\n        permutation = np.arange(pc.shape[0])\n        np.random.shuffle(permutation)\n        pc = pc[permutation[:num]]\n        return pc\n        \n    def __getitem__(self, idx):\n        sample = self.file_list[idx]\n        label = self.label_list[idx]\n\n        data = IO.get(os.path.join(self.pc_path, sample)).astype(np.float32)\n\n        data = self.random_sample(data, self.sample_points_num)\n        data = self.pc_norm(data)\n        data = torch.from_numpy(data).float()\n        return 'LabeledHybrid', 'sample', (data, label)\n\n    def __len__(self):\n        return len(self.file_list)"
  },
  {
    "path": "datasets/ModelNetDataset.py",
    "content": "'''\n@author: Xu Yan\n@file: ModelNet.py\n@time: 2021/3/19 15:51\n'''\nimport os\nimport numpy as np\nimport warnings\nimport pickle\n\nfrom tqdm import tqdm\nfrom torch.utils.data import Dataset\nfrom .build import DATASETS\nfrom utils.logger import *\nimport torch\n\nwarnings.filterwarnings('ignore')\n\n\ndef pc_normalize(pc):\n    centroid = np.mean(pc, axis=0)\n    pc = pc - centroid\n    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n    pc = pc / m\n    return pc\n\n\n\ndef farthest_point_sample(point, npoint):\n    \"\"\"\n    Input:\n        xyz: pointcloud data, [N, D]\n        npoint: number of samples\n    Return:\n        centroids: sampled pointcloud index, [npoint, D]\n    \"\"\"\n    N, D = point.shape\n    xyz = point[:,:3]\n    centroids = np.zeros((npoint,))\n    distance = np.ones((N,)) * 1e10\n    farthest = np.random.randint(0, N)\n    for i in range(npoint):\n        centroids[i] = farthest\n        centroid = xyz[farthest, :]\n        dist = np.sum((xyz - centroid) ** 2, -1)\n        mask = dist < distance\n        distance[mask] = dist[mask]\n        farthest = np.argmax(distance, -1)\n    point = point[centroids.astype(np.int32)]\n    return point\n\n@DATASETS.register_module()\nclass ModelNet(Dataset):\n    def __init__(self, config):\n        self.root = config.DATA_PATH\n        self.npoints = config.N_POINTS\n        self.use_normals = config.USE_NORMALS\n        self.num_category = config.NUM_CATEGORY\n        self.process_data = True\n        self.uniform = True\n        split = config.subset\n        self.subset = config.subset\n\n        if self.num_category == 10:\n            self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt')\n        else:\n            self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')\n\n        self.cat = [line.rstrip() for line in open(self.catfile)]\n        self.classes = dict(zip(self.cat, range(len(self.cat))))\n\n        shape_ids = {}\n        if self.num_category == 10:\n            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))]\n            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))]\n        else:\n            shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]\n            shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]\n\n        assert (split == 'train' or split == 'test')\n        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]\n        self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i\n                         in range(len(shape_ids[split]))]\n        print_log('The size of %s data is %d' % (split, len(self.datapath)), logger = 'ModelNet')\n\n        if self.uniform:\n            self.save_path = os.path.join(self.root, 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints))\n        else:\n            self.save_path = os.path.join(self.root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints))\n\n        if self.process_data:\n            if not os.path.exists(self.save_path):\n                print_log('Processing data %s (only running in the first time)...' 
% self.save_path, logger = 'ModelNet')\n                self.list_of_points = [None] * len(self.datapath)\n                self.list_of_labels = [None] * len(self.datapath)\n\n                for index in tqdm(range(len(self.datapath)), total=len(self.datapath)):\n                    fn = self.datapath[index]\n                    cls = self.classes[self.datapath[index][0]]\n                    cls = np.array([cls]).astype(np.int32)\n                    point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)\n\n                    if self.uniform:\n                        point_set = farthest_point_sample(point_set, self.npoints)\n                    else:\n                        point_set = point_set[0:self.npoints, :]\n\n                    self.list_of_points[index] = point_set\n                    self.list_of_labels[index] = cls\n\n                with open(self.save_path, 'wb') as f:\n                    pickle.dump([self.list_of_points, self.list_of_labels], f)\n            else:\n                print_log('Load processed data from %s...' % self.save_path, logger = 'ModelNet')\n                with open(self.save_path, 'rb') as f:\n                    self.list_of_points, self.list_of_labels = pickle.load(f)\n\n    def __len__(self):\n        return len(self.datapath)\n\n    def _get_item(self, index):\n        if self.process_data:\n            point_set, label = self.list_of_points[index], self.list_of_labels[index]\n        else:\n            fn = self.datapath[index]\n            cls = self.classes[self.datapath[index][0]]\n            label = np.array([cls]).astype(np.int32)\n            point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)\n\n            if self.uniform:\n                point_set = farthest_point_sample(point_set, self.npoints)\n            else:\n                point_set = point_set[0:self.npoints, :]\n                \n        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])\n        if not self.use_normals:\n            point_set = point_set[:, 0:3]\n\n        return point_set, label[0]\n\n\n    def __getitem__(self, index):\n        points, label = self._get_item(index)\n        pt_idxs = np.arange(0, points.shape[0])   # 2048\n        if self.subset == 'train':\n            np.random.shuffle(pt_idxs)\n        current_points = points[pt_idxs].copy()\n        current_points = torch.from_numpy(current_points).float()\n        return 'ModelNet', 'sample', (current_points, label)\n"
  },
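  {
    "path": "examples/pc_norm_check.py",
    "content": "'''Illustrative check, not part of the original release, of the\nnormalisation shared by the dataset loaders: pc_normalize/pc_norm centre a\ncloud at the origin and scale it into the unit sphere.'''\nimport numpy as np\n\n\ndef pc_normalize(pc):\n    # same arithmetic as in datasets/ModelNetDataset.py\n    centroid = np.mean(pc, axis=0)\n    pc = pc - centroid\n    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))\n    return pc / m\n\n\nif __name__ == '__main__':\n    # an off-centre, over-scaled cloud should come back centred with radius 1\n    pc = np.random.default_rng(0).normal(size=(2048, 3)) * 5.0 + 7.0\n    out = pc_normalize(pc)\n    assert np.allclose(out.mean(axis=0), 0.0, atol=1e-5)\n    assert abs(np.linalg.norm(out, axis=1).max() - 1.0) < 1e-6\n    print('centred at origin, max radius 1.0')\n"
  },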
  {
    "path": "datasets/ModelNetDatasetFewShot.py",
    "content": "'''\n@author: Xu Yan\n@file: ModelNet.py\n@time: 2021/3/19 15:51\n'''\nimport os\nimport numpy as np\nimport warnings\nimport pickle\n\nfrom tqdm import tqdm\nfrom torch.utils.data import Dataset\nfrom .build import DATASETS\nfrom utils.logger import *\nimport torch\nimport random\n\nwarnings.filterwarnings('ignore')\n\n\ndef pc_normalize(pc):\n    centroid = np.mean(pc, axis=0)\n    pc = pc - centroid\n    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n    pc = pc / m\n    return pc\n\n@DATASETS.register_module()\nclass ModelNetFewShot(Dataset):\n    def __init__(self, config):\n        self.root = config.DATA_PATH\n        self.npoints = config.N_POINTS\n        self.use_normals = config.USE_NORMALS\n        self.num_category = config.NUM_CATEGORY\n        self.process_data = True\n        self.uniform = True\n        split = config.subset\n        self.subset = config.subset\n\n        self.way = config.way\n        self.shot = config.shot\n        self.fold = config.fold\n        if self.way == -1 or self.shot == -1 or self.fold == -1:\n            raise RuntimeError()\n\n        self.pickle_path = os.path.join(self.root, f'{self.way}way_{self.shot}shot', f'{self.fold}.pkl')\n\n\n        print_log('Load processed data from %s...' % self.pickle_path, logger = 'ModelNetFewShot')\n\n        with open(self.pickle_path, 'rb') as f:\n            self.dataset = pickle.load(f)[self.subset]\n\n        print_log('The size of %s data is %d' % (split, len(self.dataset)), logger = 'ModelNetFewShot')\n\n    def __len__(self):\n        return len(self.dataset)\n\n    def __getitem__(self, index):\n        points, label, _ = self.dataset[index]\n\n        points[:, 0:3] = pc_normalize(points[:, 0:3])\n        if not self.use_normals:\n            points = points[:, 0:3]\n\n        pt_idxs = np.arange(0, points.shape[0])   # 2048\n        if self.subset == 'train':\n            np.random.shuffle(pt_idxs)\n        current_points = points[pt_idxs].copy()\n        current_points = torch.from_numpy(current_points).float()\n        return 'ModelNet', 'sample', (current_points, label)"
  },
  {
    "path": "datasets/ScanObjectNNDataset.py",
    "content": "import numpy as np\nimport os, sys, h5py\nfrom torch.utils.data import Dataset\nimport torch\nfrom .build import DATASETS\nfrom utils.logger import *\n\nBASE_DIR = os.path.dirname(os.path.abspath(__file__))\nsys.path.append(BASE_DIR)\n\n@DATASETS.register_module()\nclass ScanObjectNN(Dataset):\n    def __init__(self, config, **kwargs):\n        super().__init__()\n        self.subset = config.subset\n        self.root = config.ROOT\n        \n        if self.subset == 'train':\n            h5 = h5py.File(os.path.join(self.root, 'training_objectdataset.h5'), 'r')\n            self.points = np.array(h5['data']).astype(np.float32)\n            self.labels = np.array(h5['label']).astype(int)\n            h5.close()\n        elif self.subset == 'test':\n            h5 = h5py.File(os.path.join(self.root, 'test_objectdataset.h5'), 'r')\n            self.points = np.array(h5['data']).astype(np.float32)\n            self.labels = np.array(h5['label']).astype(int)\n            h5.close()\n        else:\n            raise NotImplementedError()\n\n        print(f'Successfully load ScanObjectNN shape of {self.points.shape}')\n\n    def __getitem__(self, idx):\n        pt_idxs = np.arange(0, self.points.shape[1])   # 2048\n        if self.subset == 'train':\n            np.random.shuffle(pt_idxs)\n        \n        current_points = self.points[idx, pt_idxs].copy()\n        \n\n        current_points = torch.from_numpy(current_points).float()\n        label = self.labels[idx]\n        \n        return 'ScanObjectNN', 'sample', (current_points, label)\n\n    def __len__(self):\n        return self.points.shape[0]\n\n\n\n@DATASETS.register_module()\nclass ScanObjectNN_hardest(Dataset):\n    def __init__(self, config, **kwargs):\n        super().__init__()\n        self.subset = config.subset\n        self.root = config.ROOT\n        \n        if self.subset == 'train':\n            h5 = h5py.File(os.path.join(self.root, 'training_objectdataset_augmentedrot_scale75.h5'), 'r')\n            self.points = np.array(h5['data']).astype(np.float32)\n            self.labels = np.array(h5['label']).astype(int)\n            h5.close()\n        elif self.subset == 'test':\n            h5 = h5py.File(os.path.join(self.root, 'test_objectdataset_augmentedrot_scale75.h5'), 'r')\n            self.points = np.array(h5['data']).astype(np.float32)\n            self.labels = np.array(h5['label']).astype(int)\n            h5.close()\n        else:\n            raise NotImplementedError()\n\n        print(f'Successfully load ScanObjectNN shape of {self.points.shape}')\n\n    def __getitem__(self, idx):\n        pt_idxs = np.arange(0, self.points.shape[1])   # 2048\n        if self.subset == 'train':\n            np.random.shuffle(pt_idxs)\n        \n        current_points = self.points[idx, pt_idxs].copy()\n        \n\n        current_points = torch.from_numpy(current_points).float()\n        label = self.labels[idx]\n        \n        return 'ScanObjectNN', 'sample', (current_points, label)\n\n    def __len__(self):\n        return self.points.shape[0]"
  },
  {
    "path": "datasets/ShapeNet55Dataset.py",
    "content": "import os\nimport torch\nimport numpy as np\nimport torch.utils.data as data\nfrom .io import IO\nfrom .build import DATASETS\nfrom utils.logger import *\n\n\n@DATASETS.register_module()\nclass ShapeNet(data.Dataset):\n    def __init__(self, config):\n        self.data_root = config.DATA_PATH\n        self.pc_path = config.PC_PATH\n        self.subset = config.subset\n        self.npoints = config.N_POINTS\n\n        self.data_list_file = os.path.join(\n            self.data_root, f'{self.subset}.txt')\n        test_data_list_file = os.path.join(self.data_root, 'test.txt')\n\n        self.sample_points_num = config.npoints\n        self.whole = config.get('whole')\n\n        print_log(\n            f'[DATASET] sample out {self.sample_points_num} points', logger='ShapeNet-55')\n        print_log(\n            f'[DATASET] Open file {self.data_list_file}', logger='ShapeNet-55')\n        with open(self.data_list_file, 'r') as f:\n            lines = f.readlines()\n        if self.whole:\n            with open(test_data_list_file, 'r') as f:\n                test_lines = f.readlines()\n            print_log(\n                f'[DATASET] Open file {test_data_list_file}', logger='ShapeNet-55')\n            lines = test_lines + lines\n        self.file_list = []\n        for line in lines:\n            line = line.strip()\n            taxonomy_id = line.split('-')[0]\n            model_id = line.split('-')[1].split('.')[0]\n            self.file_list.append({\n                'taxonomy_id': taxonomy_id,\n                'model_id': model_id,\n                'file_path': line\n            })\n        print_log(\n            f'[DATASET] {len(self.file_list)} instances were loaded', logger='ShapeNet-55')\n\n        self.permutation = np.arange(self.npoints)\n\n    def pc_norm(self, pc):\n        \"\"\" pc: NxC, return NxC \"\"\"\n        centroid = np.mean(pc, axis=0)\n        pc = pc - centroid\n        m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n        pc = pc / m\n        return pc\n\n    def random_sample(self, pc, num):\n        np.random.shuffle(self.permutation)\n        pc = pc[self.permutation[:num]]\n        return pc\n\n    def __getitem__(self, idx):\n        sample = self.file_list[idx]\n\n        data = IO.get(os.path.join(\n            self.pc_path, sample['file_path'])).astype(np.float32)\n\n        data = self.random_sample(data, self.sample_points_num)\n        data = self.pc_norm(data)\n        data = torch.from_numpy(data).float()\n        return sample['taxonomy_id'], sample['model_id'], data\n\n    def __len__(self):\n        return len(self.file_list)\n"
  },
  {
    "path": "datasets/UnlabeledHybrid.py",
    "content": "import os\nimport torch\nimport numpy as np\nimport torch.utils.data as data\nfrom .io import IO\nfrom .build import DATASETS\nfrom utils.logger import *\n\n\n@DATASETS.register_module()\nclass UnlabeledHybrid(data.Dataset):\n    def __init__(self, config):\n        self.data_root = config.DATA_PATH\n        self.pc_path = config.PC_PATH\n        self.subset = config.subset\n        self.npoints = config.N_POINTS\n\n        self.data_list_file = os.path.join(\n            self.data_root, f'{self.subset}.txt')\n        test_data_list_file = os.path.join(self.data_root, 'test.txt')\n\n        self.sample_points_num = config.npoints\n        self.whole = config.get('whole')\n\n        print_log(\n            f'[DATASET] sample out {self.sample_points_num} points', logger='UnlabeledHybrid')\n        print_log(\n            f'[DATASET] Open file {self.data_list_file}', logger='UnlabeledHybrid')\n        with open(self.data_list_file, 'r') as f:\n            lines = f.readlines()\n        if self.whole:\n            with open(test_data_list_file, 'r') as f:\n                test_lines = f.readlines()\n            print_log(\n                f'[DATASET] Open file {test_data_list_file}', logger='UnlabeledHybrid')\n            lines = test_lines + lines\n        self.file_list = []\n        for line in lines:\n            line = line.strip()\n            taxonomy_id = ''    \n            model_id = ''\n            self.file_list.append({\n                'taxonomy_id': taxonomy_id,\n                'model_id': model_id,\n                'file_path': line\n            })\n        print_log(\n            f'[DATASET] {len(self.file_list)} instances were loaded', logger='UnlabeledHybrid')\n\n        self.permutation = np.arange(self.npoints)\n\n    def pc_norm(self, pc):\n        \"\"\" pc: NxC, return NxC \"\"\"\n        centroid = np.mean(pc, axis=0)\n        pc = pc - centroid\n        m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n        pc = pc / m\n        return pc\n\n    def random_sample(self, pc, num):\n        permutation = np.arange(pc.shape[0])\n        np.random.shuffle(permutation)\n        pc = pc[permutation[:num]]\n        return pc\n\n    def __getitem__(self, idx):\n        sample = self.file_list[idx]\n\n        data = IO.get(os.path.join(\n            self.pc_path, sample['file_path'])).astype(np.float32)\n\n        data = self.random_sample(data, self.sample_points_num)\n        data = self.pc_norm(data)\n        data = torch.from_numpy(data).float()\n        # sample['taxonomy_id'] and sample['model_id'] are not utilized\n        return sample['taxonomy_id'], sample['model_id'], data\n\n    def __len__(self):\n        return len(self.file_list)\n"
  },
  {
    "path": "datasets/__init__.py",
    "content": "from .build import build_dataset_from_cfg\nimport datasets.ShapeNet55Dataset\nimport datasets.ModelNetDataset\nimport datasets.ModelNetDatasetFewShot\nimport datasets.ScanObjectNNDataset\nimport datasets.LabeledHybrid\nimport datasets.UnlabeledHybrid"
  },
  {
    "path": "datasets/build.py",
    "content": "from utils import registry\n\n\nDATASETS = registry.Registry('dataset')\n\n\ndef build_dataset_from_cfg(cfg, default_args = None):\n    \"\"\"\n    Build a dataset, defined by `dataset_name`.\n    Args:\n        cfg (eDICT): \n    Returns:\n        Dataset: a constructed dataset specified by dataset_name.\n    \"\"\"\n    return DATASETS.build(cfg, default_args = default_args)\n\n\n"
  },
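  {
    "path": "datasets/build_example.py",
    "content": "# Illustrative sketch, not an original repo file: how a new dataset plugs into\n# the DATASETS registry that build_dataset_from_cfg dispatches on. ToyPoints is\n# a hypothetical class; the real datasets in this repo follow the same pattern\n# and return a ('name', 'sample', (points, label)) triple from __getitem__.\nimport torch\nfrom torch.utils.data import Dataset\n\nfrom datasets.build import DATASETS\n\n\n@DATASETS.register_module()\nclass ToyPoints(Dataset):\n    def __init__(self, config):\n        self.n_items = 8\n        self.npoints = config.N_POINTS\n\n    def __len__(self):\n        return self.n_items\n\n    def __getitem__(self, idx):\n        points = torch.rand(self.npoints, 3)  # random stand-in point cloud\n        return 'ToyPoints', 'sample', (points, 0)\n"
  },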
  {
    "path": "datasets/data_transforms.py",
    "content": "import numpy as np\nimport torch\nimport random\n\n\nclass PointcloudRotate(object):\n    def __call__(self, pc):\n        bsize = pc.size()[0]\n        for i in range(bsize):\n            rotation_angle = np.random.uniform() * 2 * np.pi\n            cosval = np.cos(rotation_angle)\n            sinval = np.sin(rotation_angle)\n            rotation_matrix = np.array([[cosval, 0, sinval],\n                                        [0, 1, 0],\n                                        [-sinval, 0, cosval]])\n            R = torch.from_numpy(rotation_matrix.astype(np.float32)).to(pc.device)\n            pc[i, :, :] = torch.matmul(pc[i], R)\n        return pc\n\nclass PointcloudScaleAndTranslate(object):\n    def __init__(self, scale_low=2. / 3., scale_high=3. / 2., translate_range=0.2):\n        self.scale_low = scale_low\n        self.scale_high = scale_high\n        self.translate_range = translate_range\n\n    def __call__(self, pc):\n        bsize = pc.size()[0]\n        for i in range(bsize):\n            xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3])\n            xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3])\n            \n            pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda()) + torch.from_numpy(xyz2).float().cuda()\n            \n        return pc\n\nclass PointcloudJitter(object):\n    def __init__(self, std=0.01, clip=0.05):\n        self.std, self.clip = std, clip\n\n    def __call__(self, pc):\n        bsize = pc.size()[0]\n        for i in range(bsize):\n            jittered_data = pc.new(pc.size(1), 3).normal_(\n                mean=0.0, std=self.std\n            ).clamp_(-self.clip, self.clip)\n            pc[i, :, 0:3] += jittered_data\n            \n        return pc\n\nclass PointcloudScale(object):\n    def __init__(self, scale_low=2. / 3., scale_high=3. 
/ 2.):\n        self.scale_low = scale_low\n        self.scale_high = scale_high\n\n    def __call__(self, pc):\n        bsize = pc.size()[0]\n        for i in range(bsize):\n            xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3])\n            \n            pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda())\n            \n        return pc\n\nclass PointcloudTranslate(object):\n    def __init__(self, translate_range=0.2):\n        self.translate_range = translate_range\n\n    def __call__(self, pc):\n        bsize = pc.size()[0]\n        for i in range(bsize):\n            xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3])\n            \n            pc[i, :, 0:3] = pc[i, :, 0:3] + torch.from_numpy(xyz2).float().cuda()\n            \n        return pc\n\n\nclass PointcloudRandomInputDropout(object):\n    def __init__(self, max_dropout_ratio=0.5):\n        assert max_dropout_ratio >= 0 and max_dropout_ratio < 1\n        self.max_dropout_ratio = max_dropout_ratio\n\n    def __call__(self, pc):\n        bsize = pc.size()[0]\n        for i in range(bsize):\n            dropout_ratio = np.random.random() * self.max_dropout_ratio  # 0~0.875\n            drop_idx = np.where(np.random.random((pc.size()[1])) <= dropout_ratio)[0]\n            if len(drop_idx) > 0:\n                cur_pc = pc[i, :, :]\n                cur_pc[drop_idx.tolist(), 0:3] = cur_pc[0, 0:3].repeat(len(drop_idx), 1)  # set to the first point\n                pc[i, :, :] = cur_pc\n\n        return pc\n\nclass RandomHorizontalFlip(object):\n\n\n  def __init__(self, upright_axis = 'z', is_temporal=False):\n    \"\"\"\n    upright_axis: axis index among x,y,z, i.e. 2 for z\n    \"\"\"\n    self.is_temporal = is_temporal\n    self.D = 4 if is_temporal else 3\n    self.upright_axis = {'x': 0, 'y': 1, 'z': 2}[upright_axis.lower()]\n    # Use the rest of axes for flipping.\n    self.horz_axes = set(range(self.D)) - set([self.upright_axis])\n\n\n  def __call__(self, coords):\n    bsize = coords.size()[0]\n    for i in range(bsize):\n        if random.random() < 0.95:\n            for curr_ax in self.horz_axes:\n                if random.random() < 0.5:\n                    coord_max = torch.max(coords[i, :, curr_ax])\n                    coords[i, :, curr_ax] = coord_max - coords[i, :, curr_ax]\n    return coords"
  },
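  {
    "path": "datasets/data_transforms_example.py",
    "content": "# Illustrative sketch, not an original repo file: the transforms in\n# datasets/data_transforms.py mutate a batched (B, N, 3) tensor in place, and\n# the scale/translate variants move their random factors with .cuda(), so the\n# batch is assumed to already live on the GPU.\nimport torch\n\nfrom datasets.data_transforms import PointcloudRotate, PointcloudScaleAndTranslate\n\nif torch.cuda.is_available():\n    batch = torch.rand(4, 1024, 3).cuda()       # B x N x 3 point batch\n    batch = PointcloudRotate()(batch)            # random rotation about the y axis\n    batch = PointcloudScaleAndTranslate()(batch)  # per-axis scale + shift\n    print(batch.shape)                           # still (4, 1024, 3)\n"
  },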
  {
    "path": "datasets/generate_few_shot_data.py",
    "content": "import pickle\nimport numpy as np\nimport random\nimport os\n\nroot = '../data/ModelNet/modelnet40_normal_resampled'\ntarget = '../data/ModelNetFewshot'\n\ntrain_data_path = os.path.join(root, 'modelnet40_train_8192pts_fps.dat')\ntest_data_path = os.path.join(root, 'modelnet40_test_8192pts_fps.dat')\n# train\nwith open(train_data_path, 'rb') as f:\n    train_list_of_points, train_list_of_labels = pickle.load(f)\nwith open(test_data_path, 'rb') as f:\n    test_list_of_points, test_list_of_labels = pickle.load(f)\n\n# list_of_points = train_list_of_points + test_list_of_points  \n# list_of_labels = train_list_of_labels + test_list_of_labels\n\ndef generate_fewshot_data(way, shot, prefix_ind, eval_sample=20):\n    train_cls_dataset = {}\n    test_cls_dataset = {}\n    train_dataset = []\n    test_dataset = []\n    # build a dict containing different class\n    for point, label in zip(train_list_of_points, train_list_of_labels):\n        label = label[0]\n        if train_cls_dataset.get(label) is None:\n            train_cls_dataset[label] = []\n        train_cls_dataset[label].append(point)\n    # build a dict containing different class\n    for point, label in zip(test_list_of_points, test_list_of_labels):\n        label = label[0]\n        if test_cls_dataset.get(label) is None:\n            test_cls_dataset[label] = []\n        test_cls_dataset[label].append(point)\n    print(sum([train_cls_dataset[i].__len__() for i in range(40)]))\n    print(sum([test_cls_dataset[i].__len__() for i in range(40)]))\n    # import pdb; pdb.set_trace()\n    keys = list(train_cls_dataset.keys())\n    random.shuffle(keys)\n\n    for i, key in enumerate(keys[:way]):\n        train_data_list = train_cls_dataset[key]\n        random.shuffle(train_data_list)\n        assert len(train_data_list) > shot\n        for data in train_data_list[:shot]:\n            train_dataset.append((data, i, key))\n\n        test_data_list = test_cls_dataset[key]\n        random.shuffle(test_data_list)\n        # import pdb; pdb.set_trace()\n        assert len(test_data_list) >= eval_sample\n        for data in test_data_list[:eval_sample]:\n            test_dataset.append((data, i, key))\n\n    random.shuffle(train_dataset)\n    random.shuffle(test_dataset)\n    dataset = {\n        'train': train_dataset,\n        'test' : test_dataset\n    }\n    save_path = os.path.join(target, f'{way}way_{shot}shot')\n    if not os.path.exists(save_path):\n        os.makedirs(save_path)\n    with open(os.path.join(save_path, f'{prefix_ind}.pkl'), 'wb') as f:\n        pickle.dump(dataset, f)\n    \n\nif __name__ == '__main__':\n    ways = [5, 10]\n    shots = [10, 20]\n    for way in ways:\n        for shot in shots:\n            for i in range(10):\n                generate_fewshot_data(way = way, shot = shot, prefix_ind = i)"
  },
  {
    "path": "datasets/io.py",
    "content": "import h5py\nimport numpy as np\n# import open3d\nimport os\n\nclass IO:\n    @classmethod\n    def get(cls, file_path):\n        _, file_extension = os.path.splitext(file_path)\n\n        if file_extension in ['.npy']:\n            return cls._read_npy(file_path)\n        # elif file_extension in ['.pcd']:\n        #     return cls._read_pcd(file_path)\n        elif file_extension in ['.h5']:\n            return cls._read_h5(file_path)\n        elif file_extension in ['.txt']:\n            return cls._read_txt(file_path)\n        else:\n            raise Exception('Unsupported file extension: %s' % file_extension)\n\n    # References: https://github.com/numpy/numpy/blob/master/numpy/lib/format.py\n    @classmethod\n    def _read_npy(cls, file_path):\n        return np.load(file_path)\n       \n    # References: https://github.com/dimatura/pypcd/blob/master/pypcd/pypcd.py#L275\n    # Support PCD files without compression ONLY!\n    # @classmethod\n    # def _read_pcd(cls, file_path):\n    #     pc = open3d.io.read_point_cloud(file_path)\n    #     ptcloud = np.array(pc.points)\n    #     return ptcloud\n\n    @classmethod\n    def _read_txt(cls, file_path):\n        return np.loadtxt(file_path)\n\n    @classmethod\n    def _read_h5(cls, file_path):\n        f = h5py.File(file_path, 'r')\n        return f['data'][()]"
  },
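  {
    "path": "datasets/io_example.py",
    "content": "# Illustrative sketch, not an original repo file: IO.get dispatches on the\n# file extension (.npy, .h5, .txt) and returns a numpy array. The path below\n# is a placeholder for any point-cloud file in one of those formats.\nfrom datasets.io import IO\n\npc = IO.get('path/to/point_cloud.npy')\nprint(pc.shape, pc.dtype)\n"
  },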
  {
    "path": "extensions/chamfer_dist/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Thibault GROUEIX\n# @Date:   2019-08-07 20:54:24\n# @Last Modified by:   Haozhe Xie\n# @Last Modified time: 2019-12-18 15:06:25\n# @Email:  cshzxie@gmail.com\n\nimport torch\n\nimport chamfer\n\n\nclass ChamferFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, xyz1, xyz2):\n        dist1, dist2, idx1, idx2 = chamfer.forward(xyz1, xyz2)\n        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)\n\n        return dist1, dist2\n\n    @staticmethod\n    def backward(ctx, grad_dist1, grad_dist2):\n        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors\n        grad_xyz1, grad_xyz2 = chamfer.backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2)\n        return grad_xyz1, grad_xyz2\n\n\nclass ChamferDistanceL2(torch.nn.Module):\n    f''' Chamder Distance L2\n    '''\n    def __init__(self, ignore_zeros=False):\n        super().__init__()\n        self.ignore_zeros = ignore_zeros\n\n    def forward(self, xyz1, xyz2):\n        batch_size = xyz1.size(0)\n        if batch_size == 1 and self.ignore_zeros:\n            non_zeros1 = torch.sum(xyz1, dim=2).ne(0)\n            non_zeros2 = torch.sum(xyz2, dim=2).ne(0)\n            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)\n            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)\n\n        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)\n        return torch.mean(dist1) + torch.mean(dist2)\n\nclass ChamferDistanceL2_split(torch.nn.Module):\n    f''' Chamder Distance L2\n    '''\n    def __init__(self, ignore_zeros=False):\n        super().__init__()\n        self.ignore_zeros = ignore_zeros\n\n    def forward(self, xyz1, xyz2):\n        batch_size = xyz1.size(0)\n        if batch_size == 1 and self.ignore_zeros:\n            non_zeros1 = torch.sum(xyz1, dim=2).ne(0)\n            non_zeros2 = torch.sum(xyz2, dim=2).ne(0)\n            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)\n            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)\n\n        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)\n        return torch.mean(dist1), torch.mean(dist2)\n\nclass ChamferDistanceL1(torch.nn.Module):\n    f''' Chamder Distance L1\n    '''\n    def __init__(self, ignore_zeros=False):\n        super().__init__()\n        self.ignore_zeros = ignore_zeros\n\n    def forward(self, xyz1, xyz2):\n        batch_size = xyz1.size(0)\n        if batch_size == 1 and self.ignore_zeros:\n            non_zeros1 = torch.sum(xyz1, dim=2).ne(0)\n            non_zeros2 = torch.sum(xyz2, dim=2).ne(0)\n            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)\n            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)\n\n        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)\n        # import pdb\n        # pdb.set_trace()\n        dist1 = torch.sqrt(dist1)\n        dist2 = torch.sqrt(dist2)\n        return (torch.mean(dist1) + torch.mean(dist2))/2\n\n"
  },
  {
    "path": "extensions/chamfer_dist/chamfer.cu",
    "content": "/*\n * @Author: Haozhe Xie\n * @Date:   2019-08-07 20:54:24\n * @Last Modified by:   Haozhe Xie\n * @Last Modified time: 2020-06-17 14:58:55\n * @Email:  cshzxie@gmail.com\n */\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <torch/extension.h>\n\n#include <vector>\n\n__global__ void chamfer_dist_kernel(int batch_size,\n                                    int n,\n                                    const float* xyz1,\n                                    int m,\n                                    const float* xyz2,\n                                    float* dist,\n                                    int* indexes) {\n  const int batch = 512;\n  __shared__ float buf[batch * 3];\n  for (int i = blockIdx.x; i < batch_size; i += gridDim.x) {\n    for (int k2 = 0; k2 < m; k2 += batch) {\n      int end_k = min(m, k2 + batch) - k2;\n      for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x) {\n        buf[j] = xyz2[(i * m + k2) * 3 + j];\n      }\n      __syncthreads();\n      for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n;\n           j += blockDim.x * gridDim.y) {\n        float x1            = xyz1[(i * n + j) * 3 + 0];\n        float y1            = xyz1[(i * n + j) * 3 + 1];\n        float z1            = xyz1[(i * n + j) * 3 + 2];\n        float best_dist     = 0;\n        int best_dist_index = 0;\n        int end_ka          = end_k - (end_k & 3);\n        if (end_ka == batch) {\n          for (int k = 0; k < batch; k += 4) {\n            {\n              float x2   = buf[k * 3 + 0] - x1;\n              float y2   = buf[k * 3 + 1] - y1;\n              float z2   = buf[k * 3 + 2] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n\n              if (k == 0 || dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2;\n              }\n            }\n            {\n              float x2   = buf[k * 3 + 3] - x1;\n              float y2   = buf[k * 3 + 4] - y1;\n              float z2   = buf[k * 3 + 5] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n              if (dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2 + 1;\n              }\n            }\n            {\n              float x2   = buf[k * 3 + 6] - x1;\n              float y2   = buf[k * 3 + 7] - y1;\n              float z2   = buf[k * 3 + 8] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n              if (dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2 + 2;\n              }\n            }\n            {\n              float x2   = buf[k * 3 + 9] - x1;\n              float y2   = buf[k * 3 + 10] - y1;\n              float z2   = buf[k * 3 + 11] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n              if (dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2 + 3;\n              }\n            }\n          }\n        } else {\n          for (int k = 0; k < end_ka; k += 4) {\n            {\n              float x2   = buf[k * 3 + 0] - x1;\n              float y2   = buf[k * 3 + 1] - y1;\n              float z2   = buf[k * 3 + 2] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n              if (k == 0 || dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2;\n              }\n            }\n            {\n              float x2   = buf[k * 3 + 3] - x1;\n            
  float y2   = buf[k * 3 + 4] - y1;\n              float z2   = buf[k * 3 + 5] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n              if (dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2 + 1;\n              }\n            }\n            {\n              float x2   = buf[k * 3 + 6] - x1;\n              float y2   = buf[k * 3 + 7] - y1;\n              float z2   = buf[k * 3 + 8] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n              if (dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2 + 2;\n              }\n            }\n            {\n              float x2   = buf[k * 3 + 9] - x1;\n              float y2   = buf[k * 3 + 10] - y1;\n              float z2   = buf[k * 3 + 11] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n              if (dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2 + 3;\n              }\n            }\n          }\n        }\n        for (int k = end_ka; k < end_k; k++) {\n          float x2   = buf[k * 3 + 0] - x1;\n          float y2   = buf[k * 3 + 1] - y1;\n          float z2   = buf[k * 3 + 2] - z1;\n          float dist = x2 * x2 + y2 * y2 + z2 * z2;\n          if (k == 0 || dist < best_dist) {\n            best_dist       = dist;\n            best_dist_index = k + k2;\n          }\n        }\n        if (k2 == 0 || dist[(i * n + j)] > best_dist) {\n          dist[(i * n + j)]    = best_dist;\n          indexes[(i * n + j)] = best_dist_index;\n        }\n      }\n      __syncthreads();\n    }\n  }\n}\n\nstd::vector<torch::Tensor> chamfer_cuda_forward(torch::Tensor xyz1,\n                                                torch::Tensor xyz2) {\n  const int batch_size = xyz1.size(0);\n  const int n          = xyz1.size(1);  // num_points point cloud A\n  const int m          = xyz2.size(1);  // num_points point cloud B\n  torch::Tensor dist1 =\n    torch::zeros({batch_size, n}, torch::CUDA(torch::kFloat));\n  torch::Tensor dist2 =\n    torch::zeros({batch_size, m}, torch::CUDA(torch::kFloat));\n  torch::Tensor idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt));\n  torch::Tensor idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt));\n\n  chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(\n    batch_size, n, xyz1.data_ptr<float>(), m, xyz2.data_ptr<float>(),\n    dist1.data_ptr<float>(), idx1.data_ptr<int>());\n  chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(\n    batch_size, m, xyz2.data_ptr<float>(), n, xyz1.data_ptr<float>(),\n    dist2.data_ptr<float>(), idx2.data_ptr<int>());\n\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess) {\n    printf(\"Error in chamfer_cuda_forward: %s\\n\", cudaGetErrorString(err));\n  }\n  return {dist1, dist2, idx1, idx2};\n}\n\n__global__ void chamfer_dist_grad_kernel(int b,\n                                         int n,\n                                         const float* xyz1,\n                                         int m,\n                                         const float* xyz2,\n                                         const float* grad_dist1,\n                                         const int* idx1,\n                                         float* grad_xyz1,\n                                         float* grad_xyz2) {\n  for (int i = blockIdx.x; i < b; i += gridDim.x) {\n    for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n;\n         j += blockDim.x * 
gridDim.y) {\n      float x1 = xyz1[(i * n + j) * 3 + 0];\n      float y1 = xyz1[(i * n + j) * 3 + 1];\n      float z1 = xyz1[(i * n + j) * 3 + 2];\n      int j2   = idx1[i * n + j];\n      float x2 = xyz2[(i * m + j2) * 3 + 0];\n      float y2 = xyz2[(i * m + j2) * 3 + 1];\n      float z2 = xyz2[(i * m + j2) * 3 + 2];\n      float g  = grad_dist1[i * n + j] * 2;\n      atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2));\n      atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2));\n      atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2));\n      atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2)));\n      atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2)));\n      atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2)));\n    }\n  }\n}\n\nstd::vector<torch::Tensor> chamfer_cuda_backward(torch::Tensor xyz1,\n                                                 torch::Tensor xyz2,\n                                                 torch::Tensor idx1,\n                                                 torch::Tensor idx2,\n                                                 torch::Tensor grad_dist1,\n                                                 torch::Tensor grad_dist2) {\n  const int batch_size    = xyz1.size(0);\n  const int n             = xyz1.size(1);  // num_points point cloud A\n  const int m             = xyz2.size(1);  // num_points point cloud B\n  torch::Tensor grad_xyz1 = torch::zeros_like(xyz1, torch::CUDA(torch::kFloat));\n  torch::Tensor grad_xyz2 = torch::zeros_like(xyz2, torch::CUDA(torch::kFloat));\n\n  chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>(\n    batch_size, n, xyz1.data_ptr<float>(), m, xyz2.data_ptr<float>(),\n    grad_dist1.data_ptr<float>(), idx1.data_ptr<int>(),\n    grad_xyz1.data_ptr<float>(), grad_xyz2.data_ptr<float>());\n  chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>(\n    batch_size, m, xyz2.data_ptr<float>(), n, xyz1.data_ptr<float>(),\n    grad_dist2.data_ptr<float>(), idx2.data_ptr<int>(),\n    grad_xyz2.data_ptr<float>(), grad_xyz1.data_ptr<float>());\n\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess) {\n    printf(\"Error in chamfer_cuda_backward: %s\\n\", cudaGetErrorString(err));\n  }\n  return {grad_xyz1, grad_xyz2};\n}\n"
  },
  {
    "path": "extensions/chamfer_dist/chamfer_cuda.cpp",
    "content": "/*\n * @Author: Haozhe Xie\n * @Date:   2019-08-07 20:54:24\n * @Last Modified by:   Haozhe Xie\n * @Last Modified time: 2019-12-10 10:33:50\n * @Email:  cshzxie@gmail.com\n */\n\n#include <torch/extension.h>\n#include <vector>\n\nstd::vector<torch::Tensor> chamfer_cuda_forward(torch::Tensor xyz1,\n                                                torch::Tensor xyz2);\n\nstd::vector<torch::Tensor> chamfer_cuda_backward(torch::Tensor xyz1,\n                                                 torch::Tensor xyz2,\n                                                 torch::Tensor idx1,\n                                                 torch::Tensor idx2,\n                                                 torch::Tensor grad_dist1,\n                                                 torch::Tensor grad_dist2);\n\nstd::vector<torch::Tensor> chamfer_forward(torch::Tensor xyz1,\n                                           torch::Tensor xyz2) {\n  return chamfer_cuda_forward(xyz1, xyz2);\n}\n\nstd::vector<torch::Tensor> chamfer_backward(torch::Tensor xyz1,\n                                            torch::Tensor xyz2,\n                                            torch::Tensor idx1,\n                                            torch::Tensor idx2,\n                                            torch::Tensor grad_dist1,\n                                            torch::Tensor grad_dist2) {\n  return chamfer_cuda_backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &chamfer_forward, \"Chamfer forward (CUDA)\");\n  m.def(\"backward\", &chamfer_backward, \"Chamfer backward (CUDA)\");\n}\n"
  },
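  {
    "path": "extensions/chamfer_dist/chamfer_example.py",
    "content": "# Illustrative sketch, not an original repo file: using the Chamfer losses\n# after the CUDA extension has been built (python setup.py install inside\n# extensions/chamfer_dist). Batch and point counts below are assumptions.\nimport torch\n\nfrom extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2\n\nxyz1 = torch.rand(2, 1024, 3).cuda()\nxyz2 = torch.rand(2, 2048, 3).cuda()\nprint(ChamferDistanceL2()(xyz1, xyz2))  # mean squared nearest-neighbour distance, both directions\nprint(ChamferDistanceL1()(xyz1, xyz2))  # sqrt of the distances, averaged over both directions\n"
  },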
  {
    "path": "extensions/chamfer_dist/setup.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Haozhe Xie\n# @Date:   2019-08-07 20:54:24\n# @Last Modified by:   Haozhe Xie\n# @Last Modified time: 2019-12-10 10:04:25\n# @Email:  cshzxie@gmail.com\n\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nsetup(name='chamfer',\n      version='2.0.0',\n      ext_modules=[\n          CUDAExtension('chamfer', [\n              'chamfer_cuda.cpp',\n              'chamfer.cu',\n          ]),\n      ],\n      cmdclass={'build_ext': BuildExtension})\n"
  },
  {
    "path": "extensions/chamfer_dist/test.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Haozhe Xie\n# @Date:   2019-12-10 10:38:01\n# @Last Modified by:   Haozhe Xie\n# @Last Modified time: 2019-12-26 14:21:36\n# @Email:  cshzxie@gmail.com\n#\n# Note:\n# - Replace float -> double, kFloat -> kDouble in chamfer.cu\n\nimport os\nimport sys\nimport torch\nimport unittest\n\n\nfrom torch.autograd import gradcheck\n\nsys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)))\nfrom extensions.chamfer_dist import ChamferFunction\n\n\nclass ChamferDistanceTestCase(unittest.TestCase):\n    def test_chamfer_dist(self):\n        x = torch.rand(4, 64, 3).double()\n        y = torch.rand(4, 128, 3).double()\n        x.requires_grad = True\n        y.requires_grad = True\n        print(gradcheck(ChamferFunction.apply, [x.cuda(), y.cuda()]))\n\n\n\nif __name__ == '__main__':\n    # unittest.main()\n    import pdb\n    x = torch.rand(32,128,3)\n    y = torch.rand(32,128,3)\n    pdb.set_trace()\n"
  },
  {
    "path": "extensions/emd/README.md",
    "content": "# PyTorch Wrapper for Point-cloud Earth-Mover-Distance (EMD)\n\n## Dependency\n\nThe code has been tested on Ubuntu 16.04, PyTorch 1.1.0, CUDA 9.0.\n\n## Usage\n\nFirst compile using\n        \n        python setup.py install\n\nThen, copy the lib file out to the main directory,\n\n        cp build/lib.linux-x86_64-3.6/emd_cuda.cpython-36m-x86_64-linux-gnu.so .\n\nThen, you can use it by simply\n\n        from emd import earth_mover_distance\n        d = earth_mover_distance(p1, p2, transpose=False)  # p1: B x N1 x 3, p2: B x N2 x 3\n\nCheck `test_emd_loss.py` for example.\n\n## Author\n\nThe cuda code is originally written by Haoqiang Fan. The PyTorch wrapper is written by Kaichun Mo. Also, Jiayuan Gu provided helps.\n\n## License\n\nMIT\n\n"
  },
  {
    "path": "extensions/emd/__init__.py",
    "content": "from .emd import earth_mover_distance as emd\n\n__all__ = ['emd']"
  },
  {
    "path": "extensions/emd/cuda/emd.cpp",
    "content": "#ifndef _EMD\n#define _EMD\n\n#include <vector>\n#include <torch/extension.h>\n\n//CUDA declarations\nat::Tensor ApproxMatchForward(\n    const at::Tensor xyz1,\n    const at::Tensor xyz2);\n\nat::Tensor MatchCostForward(\n    const at::Tensor xyz1,\n    const at::Tensor xyz2,\n    const at::Tensor match);\n\nstd::vector<at::Tensor> MatchCostBackward(\n    const at::Tensor grad_cost,\n    const at::Tensor xyz1,\n    const at::Tensor xyz2,\n    const at::Tensor match);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"approxmatch_forward\", &ApproxMatchForward,\"ApproxMatch forward (CUDA)\");\n  m.def(\"matchcost_forward\", &MatchCostForward,\"MatchCost forward (CUDA)\");\n  m.def(\"matchcost_backward\", &MatchCostBackward,\"MatchCost backward (CUDA)\");\n}\n\n#endif\n"
  },
  {
    "path": "extensions/emd/cuda/emd_kernel.cu",
    "content": "/**********************************\n * Original Author: Haoqiang Fan\n * Modified by: Kaichun Mo\n *********************************/\n\n#ifndef _EMD_KERNEL\n#define _EMD_KERNEL\n\n#include <cmath>\n#include <vector>\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAApplyUtils.cuh>  // at::cuda::getApplyGrid\n// #include <THC/THC.h>\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\n\n/********************************\n* Forward kernel for approxmatch\n*********************************/\n\ntemplate<typename scalar_t>\n__global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,scalar_t * __restrict__ match,scalar_t * temp){\n\tscalar_t * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;\n\tscalar_t multiL,multiR;\n\tif (n>=m){\n\t\tmultiL=1;\n\t\tmultiR=n/m;\n\t}else{\n\t\tmultiL=m/n;\n\t\tmultiR=1;\n\t}\n\tconst int Block=1024;\n\t__shared__ scalar_t buf[Block*4];\n\tfor (int i=blockIdx.x;i<b;i+=gridDim.x){\n\t\tfor (int j=threadIdx.x;j<n*m;j+=blockDim.x)\n\t\t\tmatch[i*n*m+j]=0;\n\t\tfor (int j=threadIdx.x;j<n;j+=blockDim.x)\n\t\t\tremainL[j]=multiL;\n\t\tfor (int j=threadIdx.x;j<m;j+=blockDim.x)\n\t\t\tremainR[j]=multiR;\n\t\t__syncthreads();\n\t\tfor (int j=7;j>=-2;j--){\n\t\t\tscalar_t level=-powf(4.0f,j);\n\t\t\tif (j==-2){\n\t\t\t\tlevel=0;\n\t\t\t}\n\t\t\tfor (int k0=0;k0<n;k0+=blockDim.x){\n\t\t\t\tint k=k0+threadIdx.x;\n\t\t\t\tscalar_t x1=0,y1=0,z1=0;\n\t\t\t\tif (k<n){\n\t\t\t\t\tx1=xyz1[i*n*3+k*3+0];\n\t\t\t\t\ty1=xyz1[i*n*3+k*3+1];\n\t\t\t\t\tz1=xyz1[i*n*3+k*3+2];\n\t\t\t\t}\n\t\t\t\tscalar_t suml=1e-9f;\n\t\t\t\tfor (int l0=0;l0<m;l0+=Block){\n\t\t\t\t\tint lend=min(m,l0+Block)-l0;\n\t\t\t\t\tfor (int l=threadIdx.x;l<lend;l+=blockDim.x){\n\t\t\t\t\t\tscalar_t x2=xyz2[i*m*3+l0*3+l*3+0];\n\t\t\t\t\t\tscalar_t y2=xyz2[i*m*3+l0*3+l*3+1];\n\t\t\t\t\t\tscalar_t z2=xyz2[i*m*3+l0*3+l*3+2];\n\t\t\t\t\t\tbuf[l*4+0]=x2;\n\t\t\t\t\t\tbuf[l*4+1]=y2;\n\t\t\t\t\t\tbuf[l*4+2]=z2;\n\t\t\t\t\t\tbuf[l*4+3]=remainR[l0+l];\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t\tfor (int l=0;l<lend;l++){\n\t\t\t\t\t\tscalar_t x2=buf[l*4+0];\n\t\t\t\t\t\tscalar_t y2=buf[l*4+1];\n\t\t\t\t\t\tscalar_t z2=buf[l*4+2];\n\t\t\t\t\t\tscalar_t d=level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));\n\t\t\t\t\t\tscalar_t w=__expf(d)*buf[l*4+3];\n\t\t\t\t\t\tsuml+=w;\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t}\n\t\t\t\tif (k<n)\n\t\t\t\t\tratioL[k]=remainL[k]/suml;\n\t\t\t}\n\t\t\t__syncthreads();\n\t\t\tfor (int l0=0;l0<m;l0+=blockDim.x){\n\t\t\t\tint l=l0+threadIdx.x;\n\t\t\t\tscalar_t x2=0,y2=0,z2=0;\n\t\t\t\tif (l<m){\n\t\t\t\t\tx2=xyz2[i*m*3+l*3+0];\n\t\t\t\t\ty2=xyz2[i*m*3+l*3+1];\n\t\t\t\t\tz2=xyz2[i*m*3+l*3+2];\n\t\t\t\t}\n\t\t\t\tscalar_t sumr=0;\n\t\t\t\tfor (int k0=0;k0<n;k0+=Block){\n\t\t\t\t\tint kend=min(n,k0+Block)-k0;\n\t\t\t\t\tfor (int k=threadIdx.x;k<kend;k+=blockDim.x){\n\t\t\t\t\t\tbuf[k*4+0]=xyz1[i*n*3+k0*3+k*3+0];\n\t\t\t\t\t\tbuf[k*4+1]=xyz1[i*n*3+k0*3+k*3+1];\n\t\t\t\t\t\tbuf[k*4+2]=xyz1[i*n*3+k0*3+k*3+2];\n\t\t\t\t\t\tbuf[k*4+3]=ratioL[k0+k];\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t\tfor (int k=0;k<kend;k++){\n\t\t\t\t\t\tscalar_t x1=buf[k*4+0];\n\t\t\t\t\t\tscalar_t y1=buf[k*4+1];\n\t\t\t\t\t\tscalar_t 
z1=buf[k*4+2];\n\t\t\t\t\t\tscalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*buf[k*4+3];\n\t\t\t\t\t\tsumr+=w;\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t}\n\t\t\t\tif (l<m){\n\t\t\t\t\tsumr*=remainR[l];\n\t\t\t\t\tscalar_t consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);\n\t\t\t\t\tratioR[l]=consumption*remainR[l];\n\t\t\t\t\tremainR[l]=fmaxf(0.0f,remainR[l]-sumr);\n\t\t\t\t}\n\t\t\t}\n\t\t\t__syncthreads();\n\t\t\tfor (int k0=0;k0<n;k0+=blockDim.x){\n\t\t\t\tint k=k0+threadIdx.x;\n\t\t\t\tscalar_t x1=0,y1=0,z1=0;\n\t\t\t\tif (k<n){\n\t\t\t\t\tx1=xyz1[i*n*3+k*3+0];\n\t\t\t\t\ty1=xyz1[i*n*3+k*3+1];\n\t\t\t\t\tz1=xyz1[i*n*3+k*3+2];\n\t\t\t\t}\n\t\t\t\tscalar_t suml=0;\n\t\t\t\tfor (int l0=0;l0<m;l0+=Block){\n\t\t\t\t\tint lend=min(m,l0+Block)-l0;\n\t\t\t\t\tfor (int l=threadIdx.x;l<lend;l+=blockDim.x){\n\t\t\t\t\t\tbuf[l*4+0]=xyz2[i*m*3+l0*3+l*3+0];\n\t\t\t\t\t\tbuf[l*4+1]=xyz2[i*m*3+l0*3+l*3+1];\n\t\t\t\t\t\tbuf[l*4+2]=xyz2[i*m*3+l0*3+l*3+2];\n\t\t\t\t\t\tbuf[l*4+3]=ratioR[l0+l];\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t\tscalar_t rl=ratioL[k];\n\t\t\t\t\tif (k<n){\n\t\t\t\t\t\tfor (int l=0;l<lend;l++){\n\t\t\t\t\t\t\tscalar_t x2=buf[l*4+0];\n\t\t\t\t\t\t\tscalar_t y2=buf[l*4+1];\n\t\t\t\t\t\t\tscalar_t z2=buf[l*4+2];\n\t\t\t\t\t\t\tscalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*rl*buf[l*4+3];\n\t\t\t\t\t\t\tmatch[i*n*m+(l0+l)*n+k]+=w;\n\t\t\t\t\t\t\tsuml+=w;\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t}\n\t\t\t\tif (k<n)\n\t\t\t\t\tremainL[k]=fmaxf(0.0f,remainL[k]-suml);\n\t\t\t}\n\t\t\t__syncthreads();\n\t\t}\n\t}\n}\n\n//void approxmatchLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,scalar_t * match,scalar_t * temp){\n//\tapproxmatch<<<32,512>>>(b,n,m,xyz1,xyz2,match,temp);\n//}\n\n/* ApproxMatch forward interface\nInput:\n  xyz1: (B, N1, 3)  # dataset_points\n  xyz2: (B, N2, 3)  # query_points\nOutput:\n  match: (B, N2, N1)\n*/\nat::Tensor ApproxMatchForward(\n    const at::Tensor xyz1,\n    const at::Tensor xyz2){\n  const auto b = xyz1.size(0);\n  const auto n = xyz1.size(1);\n  const auto m = xyz2.size(1);\n\n  CHECK_EQ(xyz2.size(0), b);\n  CHECK_EQ(xyz1.size(2), 3);\n  CHECK_EQ(xyz2.size(2), 3);\n  CHECK_INPUT(xyz1);\n  CHECK_INPUT(xyz2);\n\n  auto match = at::zeros({b, m, n}, xyz1.type());\n  auto temp = at::zeros({b, (n+m)*2}, xyz1.type());\n\n  AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), \"ApproxMatchForward\", ([&] {\n        approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());\n  }));\n  AT_CUDA_CHECK(cudaGetLastError());\n\n  return match;\n}\n\n\n/********************************\n* Forward kernel for matchcost\n*********************************/\n\ntemplate<typename scalar_t>\n__global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ out){\n\t__shared__ scalar_t allsum[512];\n\tconst int Block=1024;\n\t__shared__ scalar_t buf[Block*3];\n\tfor (int i=blockIdx.x;i<b;i+=gridDim.x){\n\t\tscalar_t subsum=0;\n\t\tfor (int k0=0;k0<n;k0+=blockDim.x){\n\t\t\tint k=k0+threadIdx.x;\n\t\t\tscalar_t x1=0,y1=0,z1=0;\n\t\t\tif (k<n){\n\t\t\t\tx1=xyz1[i*n*3+k*3+0];\n\t\t\t\ty1=xyz1[i*n*3+k*3+1];\n\t\t\t\tz1=xyz1[i*n*3+k*3+2];\n\t\t\t}\n\t\t\tfor (int l0=0;l0<m;l0+=Block){\n\t\t\t\tint lend=min(m,l0+Block)-l0;\n\t\t\t\tfor (int 
l=threadIdx.x;l<lend*3;l+=blockDim.x)\n\t\t\t\t\tbuf[l]=xyz2[i*m*3+l0*3+l];\n\t\t\t\t__syncthreads();\n\t\t\t\tif (k<n){\n\t\t\t\t\tfor (int l=0;l<lend;l++){\n\t\t\t\t\t\tscalar_t x2=buf[l*3+0];\n\t\t\t\t\t\tscalar_t y2=buf[l*3+1];\n\t\t\t\t\t\tscalar_t z2=buf[l*3+2];\n\t\t\t\t\t\tscalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);\n\t\t\t\t\t\tsubsum+=d*match[i*n*m+(l0+l)*n+k];\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\t__syncthreads();\n\t\t\t}\n\t\t}\n\t\tallsum[threadIdx.x]=subsum;\n\t\tfor (int j=1;j<blockDim.x;j<<=1){\n\t\t\t__syncthreads();\n\t\t\tif ((threadIdx.x&j)==0 && threadIdx.x+j<blockDim.x){\n\t\t\t\tallsum[threadIdx.x]+=allsum[threadIdx.x+j];\n\t\t\t}\n\t\t}\n\t\tif (threadIdx.x==0)\n\t\t\tout[i]=allsum[0];\n\t\t__syncthreads();\n\t}\n}\n\n//void matchcostLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * out){\n//\tmatchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);\n//}\n\n/* MatchCost forward interface\nInput:\n  xyz1: (B, N1, 3)  # dataset_points\n  xyz2: (B, N2, 3)  # query_points\n  match: (B, N2, N1)\nOutput:\n  cost: (B)\n*/\nat::Tensor MatchCostForward(\n    const at::Tensor xyz1,\n    const at::Tensor xyz2,\n    const at::Tensor match){\n  const auto b = xyz1.size(0);\n  const auto n = xyz1.size(1);\n  const auto m = xyz2.size(1);\n\n  CHECK_EQ(xyz2.size(0), b);\n  CHECK_EQ(xyz1.size(2), 3);\n  CHECK_EQ(xyz2.size(2), 3);\n  CHECK_INPUT(xyz1);\n  CHECK_INPUT(xyz2);\n\n  auto cost = at::zeros({b}, xyz1.type());\n\n  AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), \"MatchCostForward\", ([&] {\n        matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());\n  }));\n  AT_CUDA_CHECK(cudaGetLastError());\n\n  return cost;\n}\n\n\n/********************************\n* matchcostgrad2 kernel\n*********************************/\n\ntemplate<typename scalar_t>\n__global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad2){\n\t__shared__ scalar_t sum_grad[256*3];\n\tfor (int i=blockIdx.x;i<b;i+=gridDim.x){\n\t\tint kbeg=m*blockIdx.y/gridDim.y;\n\t\tint kend=m*(blockIdx.y+1)/gridDim.y;\n\t\tfor (int k=kbeg;k<kend;k++){\n\t\t\tscalar_t x2=xyz2[(i*m+k)*3+0];\n\t\t\tscalar_t y2=xyz2[(i*m+k)*3+1];\n\t\t\tscalar_t z2=xyz2[(i*m+k)*3+2];\n\t\t\tscalar_t subsumx=0,subsumy=0,subsumz=0;\n\t\t\tfor (int j=threadIdx.x;j<n;j+=blockDim.x){\n\t\t\t\tscalar_t x1=x2-xyz1[(i*n+j)*3+0];\n\t\t\t\tscalar_t y1=y2-xyz1[(i*n+j)*3+1];\n\t\t\t\tscalar_t z1=z2-xyz1[(i*n+j)*3+2];\n\t\t\t\tscalar_t d=match[i*n*m+k*n+j]*2;\n\t\t\t\tsubsumx+=x1*d;\n\t\t\t\tsubsumy+=y1*d;\n\t\t\t\tsubsumz+=z1*d;\n\t\t\t}\n\t\t\tsum_grad[threadIdx.x*3+0]=subsumx;\n\t\t\tsum_grad[threadIdx.x*3+1]=subsumy;\n\t\t\tsum_grad[threadIdx.x*3+2]=subsumz;\n\t\t\tfor (int j=1;j<blockDim.x;j<<=1){\n\t\t\t\t__syncthreads();\n\t\t\t\tint j1=threadIdx.x;\n\t\t\t\tint j2=threadIdx.x+j;\n\t\t\t\tif ((j1&j)==0 && j2<blockDim.x){\n\t\t\t\t\tsum_grad[j1*3+0]+=sum_grad[j2*3+0];\n\t\t\t\t\tsum_grad[j1*3+1]+=sum_grad[j2*3+1];\n\t\t\t\t\tsum_grad[j1*3+2]+=sum_grad[j2*3+2];\n\t\t\t\t}\n\t\t\t}\n\t\t\tif (threadIdx.x==0){\n\t\t\t\tgrad2[(i*m+k)*3+0]=sum_grad[0]*grad_cost[i];\n\t\t\t\tgrad2[(i*m+k)*3+1]=sum_grad[1]*grad_cost[i];\n\t\t\t\tgrad2[(i*m+k)*3+2]=sum_grad[2]*grad_cost[i];\n\t\t\t}\n\t\t\t__syncthreads();\n\t\t}\n\t}\n}\n\n/********************************\n* 
matchcostgrad1 kernel\n*********************************/\n\ntemplate<typename scalar_t>\n__global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad1){\n\tfor (int i=blockIdx.x;i<b;i+=gridDim.x){\n\t\tfor (int l=threadIdx.x;l<n;l+=blockDim.x){\n\t\t\tscalar_t x1=xyz1[i*n*3+l*3+0];\n\t\t\tscalar_t y1=xyz1[i*n*3+l*3+1];\n\t\t\tscalar_t z1=xyz1[i*n*3+l*3+2];\n\t\t\tscalar_t dx=0,dy=0,dz=0;\n\t\t\tfor (int k=0;k<m;k++){\n\t\t\t\tscalar_t x2=xyz2[i*m*3+k*3+0];\n\t\t\t\tscalar_t y2=xyz2[i*m*3+k*3+1];\n\t\t\t\tscalar_t z2=xyz2[i*m*3+k*3+2];\n\t\t\t\tscalar_t d=match[i*n*m+k*n+l]*2;\n\t\t\t\tdx+=(x1-x2)*d;\n\t\t\t\tdy+=(y1-y2)*d;\n\t\t\t\tdz+=(z1-z2)*d;\n\t\t\t}\n\t\t\tgrad1[i*n*3+l*3+0]=dx*grad_cost[i];\n\t\t\tgrad1[i*n*3+l*3+1]=dy*grad_cost[i];\n\t\t\tgrad1[i*n*3+l*3+2]=dz*grad_cost[i];\n\t\t}\n\t}\n}\n\n//void matchcostgradLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * grad1,scalar_t * grad2){\n//\tmatchcostgrad1<<<32,512>>>(b,n,m,xyz1,xyz2,match,grad1);\n//\tmatchcostgrad2<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);\n//}\n\n\n/* MatchCost backward interface\nInput:\n  grad_cost: (B)    # gradients on cost\n  xyz1: (B, N1, 3)  # dataset_points\n  xyz2: (B, N2, 3)  # query_points\n  match: (B, N2, N1)\nOutput:\n  grad1: (B, N1, 3)\n  grad2: (B, N2, 3)\n*/\nstd::vector<at::Tensor> MatchCostBackward(\n    const at::Tensor grad_cost,\n    const at::Tensor xyz1,\n    const at::Tensor xyz2,\n    const at::Tensor match){\n  const auto b = xyz1.size(0);\n  const auto n = xyz1.size(1);\n  const auto m = xyz2.size(1);\n\n  CHECK_EQ(xyz2.size(0), b);\n  CHECK_EQ(xyz1.size(2), 3);\n  CHECK_EQ(xyz2.size(2), 3);\n  CHECK_INPUT(xyz1);\n  CHECK_INPUT(xyz2);\n\n  auto grad1 = at::zeros({b, n, 3}, xyz1.type());\n  auto grad2 = at::zeros({b, m, 3}, xyz1.type());\n\n  AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), \"MatchCostBackward\", ([&] {\n        matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());\n        matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());\n  }));\n  AT_CUDA_CHECK(cudaGetLastError());\n\n  return std::vector<at::Tensor>({grad1, grad2});\n}\n\n#endif\n"
  },
  {
    "path": "extensions/emd/emd.py",
    "content": "import torch\nimport emd_cuda\n\n\nclass EarthMoverDistanceFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, xyz1, xyz2):\n        xyz1 = xyz1.contiguous()\n        xyz2 = xyz2.contiguous()\n        assert xyz1.is_cuda and xyz2.is_cuda, \"Only support cuda currently.\"\n        match = emd_cuda.approxmatch_forward(xyz1, xyz2)\n        cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)\n        ctx.save_for_backward(xyz1, xyz2, match)\n        return cost\n\n    @staticmethod\n    def backward(ctx, grad_cost):\n        xyz1, xyz2, match = ctx.saved_tensors\n        grad_cost = grad_cost.contiguous()\n        grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)\n        return grad_xyz1, grad_xyz2\n\n\n\n\nclass earth_mover_distance(torch.nn.Module):\n    f''' emd\n    '''\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, xyz1, xyz2, transpose=False):\n        \"\"\"Earth Mover Distance (Approx)\n\n        Args:\n            xyz1 (torch.Tensor): (b, n1, 3)\n            xyz2 (torch.Tensor): (b, n2, 3)\n            transpose (bool): whether to transpose inputs as it might be BCN format.\n                Extensions only support BNC format.\n\n        Returns:\n            cost (torch.Tensor): (b)\n\n        \"\"\"\n\n        cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)\n        cost = cost / xyz1.size(1)\n        \n        return cost.mean()\n# def earth_mover_distance(xyz1, xyz2, transpose=True):\n#     \"\"\"Earth Mover Distance (Approx)\n\n#     Args:\n#         xyz1 (torch.Tensor): (b, 3, n1)\n#         xyz2 (torch.Tensor): (b, 3, n1)\n#         transpose (bool): whether to transpose inputs as it might be BCN format.\n#             Extensions only support BNC format.\n\n#     Returns:\n#         cost (torch.Tensor): (b)\n\n#     \"\"\"\n#     if xyz1.dim() == 2:\n#         xyz1 = xyz1.unsqueeze(0)\n#     if xyz2.dim() == 2:\n#         xyz2 = xyz2.unsqueeze(0)\n#     if transpose:\n#         xyz1 = xyz1.transpose(1, 2)\n#         xyz2 = xyz2.transpose(1, 2)\n#     cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)\n#     return cost\n\n"
  },
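  {
    "path": "extensions/emd/emd_example.py",
    "content": "# Illustrative sketch, not an original repo file: the module-style wrapper in\n# emd.py returns a scalar (per-point matching cost averaged over the batch)\n# and expects CUDA inputs in B x N x C layout when transpose=False. The sizes\n# below are assumptions, and the emd_cuda extension must already be built.\nimport torch\n\nfrom extensions.emd.emd import earth_mover_distance\n\np1 = torch.rand(4, 1024, 3, device='cuda', requires_grad=True)\np2 = torch.rand(4, 1024, 3, device='cuda')\nloss = earth_mover_distance()(p1, p2, transpose=False)\nloss.backward()\nprint(loss.item(), p1.grad.shape)\n"
  },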
  {
    "path": "extensions/emd/setup.py",
    "content": "\"\"\"Setup extension\n\nNotes:\n    If extra_compile_args is provided, you need to provide different instances for different extensions.\n    Refer to https://github.com/pytorch/pytorch/issues/20169\n\n\"\"\"\n\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n\nsetup(\n    name='emd_ext',\n    ext_modules=[\n        CUDAExtension(\n            name='emd_cuda',\n            sources=[\n                'cuda/emd.cpp',\n                'cuda/emd_kernel.cu',\n            ],\n            extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}\n        ),\n    ],\n    cmdclass={\n        'build_ext': BuildExtension\n    })\n"
  },
  {
    "path": "extensions/emd/test_emd_loss.py",
    "content": "import torch\nimport numpy as np\nimport time\nfrom emd import earth_mover_distance\n\n# gt\np1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()\np1 = p1.repeat(3, 1, 1)\np2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()\np2 = p2.repeat(3, 1, 1)\nprint(p1)\nprint(p2)\nprint(p1.shape)\np1.requires_grad = True\np2.requires_grad = True\n\ngt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 +  \\\n         (((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \\\n         (((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3\nprint('gt_dist: ', gt_dist)\n\ngt_dist.backward()\nprint(p1.grad)\nprint(p2.grad)\n\n# emd\np1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()\np1 = p1.repeat(3, 1, 1)\np2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()\np2 = p2.repeat(3, 1, 1)\nprint(p1)\nprint(p2)\np1.requires_grad = True\np2.requires_grad = True\n\nd = earth_mover_distance(p1, p2, transpose=False)\nprint(d)\n\nloss = d[0] / 2 + d[1] * 2 + d[2] / 3\nprint(loss)\n\nloss.backward()\nprint(p1.grad)\nprint(p2.grad)\n\n"
  },
  {
    "path": "main.py",
    "content": "from tools import pretrain_run_net as pretrain\nfrom tools import finetune_run_net as finetune\nfrom tools import test_run_net as test_net\nfrom utils import parser, dist_utils, misc\nfrom utils.logger import *\nfrom utils.config import *\nimport time\nimport os\nimport torch\nfrom tensorboardX import SummaryWriter\nfrom torchstat import stat\n\n\ndef main():\n    # args\n    args = parser.get_args()\n    # CUDA\n    args.use_gpu = torch.cuda.is_available()\n    if args.use_gpu:\n        torch.backends.cudnn.benchmark = True\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        args.distributed = False\n    else:\n        args.distributed = True\n        dist_utils.init_dist(args.launcher)\n        # re-set gpu_ids with distributed training mode\n        _, world_size = dist_utils.get_dist_info()\n        args.world_size = world_size\n    # logger\n    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n    log_file = os.path.join(args.experiment_path, f'{timestamp}.log')\n    logger = get_root_logger(log_file=log_file, name=args.log_name)\n    # define the tensorboard writer\n    if not args.test:\n        if args.local_rank == 0:\n            train_writer = SummaryWriter(\n                os.path.join(args.tfboard_path, 'train'))\n            val_writer = SummaryWriter(os.path.join(args.tfboard_path, 'test'))\n        else:\n            train_writer = None\n            val_writer = None\n    # config\n    config = get_config(args, logger=logger)\n    # batch size\n    if args.distributed:\n        assert config.total_bs % world_size == 0\n        config.dataset.train.others.bs = config.total_bs // world_size\n        if config.dataset.get('extra_train'):\n            config.dataset.extra_train.others.bs = config.total_bs // world_size * 2\n        config.dataset.val.others.bs = config.total_bs // world_size * 2\n        if config.dataset.get('test'):\n            config.dataset.test.others.bs = config.total_bs // world_size\n    else:\n        config.dataset.train.others.bs = config.total_bs\n        if config.dataset.get('extra_train'):\n            config.dataset.extra_train.others.bs = config.total_bs * 2\n        config.dataset.val.others.bs = config.total_bs * 2\n        if config.dataset.get('test'):\n            config.dataset.test.others.bs = config.total_bs\n    # log\n    log_args_to_file(args, 'args', logger=logger)\n    log_config_to_file(config, 'config', logger=logger)\n    # exit()\n    logger.info(f'Distributed training: {args.distributed}')\n    # set random seeds\n    if args.seed is not None:\n        logger.info(f'Set random seed to {args.seed}, '\n                    f'deterministic: {args.deterministic}')\n        # seed + rank, for augmentation\n        misc.set_random_seed(args.seed + args.local_rank,\n                             deterministic=args.deterministic)\n    if args.distributed:\n        assert args.local_rank == torch.distributed.get_rank()\n\n    if args.shot != -1:\n        config.dataset.train.others.shot = args.shot\n        config.dataset.train.others.way = args.way\n        config.dataset.train.others.fold = args.fold\n        config.dataset.val.others.shot = args.shot\n        config.dataset.val.others.way = args.way\n        config.dataset.val.others.fold = args.fold\n\n    # run\n    if args.test:\n        test_net(args, config)\n    else:\n        if args.finetune_model or args.scratch_model:\n            finetune(args, config, train_writer, val_writer)\n  
      else:\n            pretrain(args, config, train_writer, val_writer)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "main_vis.py",
    "content": "# from tools import run_net\nfrom tools import test_net\nfrom utils import parser, dist_utils, misc\nfrom utils.logger import *\nfrom utils.config import *\nimport time\nimport os\nimport torch\nfrom tensorboardX import SummaryWriter\n\ndef main():\n    # args\n    args = parser.get_args()\n    # CUDA\n    args.use_gpu = torch.cuda.is_available()\n    if args.use_gpu:\n        torch.backends.cudnn.benchmark = True\n    # init distributed env first, since logger depends on the dist info.\n    if args.launcher == 'none':\n        args.distributed = False\n    else:\n        args.distributed = True\n        dist_utils.init_dist(args.launcher)\n        # re-set gpu_ids with distributed training mode\n        _, world_size = dist_utils.get_dist_info()\n        args.world_size = world_size\n    # logger\n    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())\n    log_file = os.path.join(args.experiment_path, f'{timestamp}.log')\n    logger = get_root_logger(log_file=log_file, name=args.log_name)\n    # define the tensorboard writer\n    if not args.test:\n        if args.local_rank == 0:\n            train_writer = SummaryWriter(os.path.join(args.tfboard_path, 'train'))\n            val_writer = SummaryWriter(os.path.join(args.tfboard_path, 'test'))\n        else:\n            train_writer = None\n            val_writer = None\n    # config\n    config = get_config(args, logger = logger)\n    # batch size\n    if args.distributed:\n        assert config.total_bs % world_size == 0\n        config.dataset.train.others.bs = config.total_bs // world_size\n        config.dataset.val.others.bs = 1\n        config.dataset.test.others.bs = 1\n    else:\n        config.dataset.train.others.bs = config.total_bs\n        config.dataset.val.others.bs = 1\n        config.dataset.test.others.bs = 1\n    # log \n    log_args_to_file(args, 'args', logger = logger)\n    log_config_to_file(config, 'config', logger = logger)\n    # exit()\n    logger.info(f'Distributed training: {args.distributed}')\n    # set random seeds\n    if args.seed is not None:\n        logger.info(f'Set random seed to {args.seed}, '\n                    f'deterministic: {args.deterministic}')\n        misc.set_random_seed(args.seed + args.local_rank, deterministic=args.deterministic) # seed + rank, for augmentation\n    if args.distributed:\n        assert args.local_rank == torch.distributed.get_rank() \n\n    # run\n    if args.test:\n        test_net(args, config)\n    else:\n        # run_net(args, config, train_writer, val_writer)\n        raise NotImplementedError\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "models/GPT.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport matplotlib.pyplot as plt\n\n\nclass Block(nn.Module):\n    def __init__(self, embed_dim, num_heads):\n        super(Block, self).__init__()\n        self.ln_1 = nn.LayerNorm(embed_dim)\n        self.ln_2 = nn.LayerNorm(embed_dim)\n        self.attn = nn.MultiheadAttention(embed_dim, num_heads)\n        self.mlp = nn.Sequential(\n            nn.Linear(embed_dim, embed_dim * 4),\n            nn.GELU(),\n            nn.Linear(embed_dim * 4, embed_dim),\n        )\n\n    def forward(self, x, attn_mask):\n\n        x = self.ln_1(x)\n        # a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)\n        a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)\n        x = x + a\n        m = self.mlp(self.ln_2(x))\n        x = x + m\n        return x\n\n\nclass GPT_extractor(nn.Module):\n    def __init__(\n        self, embed_dim, num_heads, num_layers, num_classes, trans_dim, group_size, pretrained=False\n    ):\n        super(GPT_extractor, self).__init__()\n\n        self.embed_dim = embed_dim\n        self.trans_dim = trans_dim\n        self.group_size = group_size\n\n        # start of sequence token\n        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))\n        nn.init.normal_(self.sos)\n\n        self.layers = nn.ModuleList()\n        for _ in range(num_layers):\n            self.layers.append(Block(embed_dim, num_heads))\n\n        self.ln_f = nn.LayerNorm(embed_dim)\n        # prediction head\n        self.increase_dim = nn.Sequential(\n            nn.Conv1d(self.trans_dim, 3*(self.group_size), 1)\n        )\n\n        if pretrained == False:\n            self.cls_head_finetune = nn.Sequential(\n                nn.Linear(self.trans_dim * 2, 256),\n                nn.BatchNorm1d(256),\n                nn.ReLU(inplace=True),\n                nn.Dropout(0.5),\n                nn.Linear(256, 256),\n                nn.BatchNorm1d(256),\n                nn.ReLU(inplace=True),\n                nn.Dropout(0.5),\n                nn.Linear(256, num_classes)\n            )\n\n            self.cls_norm = nn.LayerNorm(self.trans_dim)\n\n    def forward(self, h, pos, attn_mask, classify=False):\n        \"\"\"\n        Expect input as shape [sequence len, batch]\n        If classify, return classification logits\n        \"\"\"\n        batch, length, C = h.shape\n\n        h = h.transpose(0, 1)\n        pos = pos.transpose(0, 1)\n\n        # prepend sos token\n        sos = torch.ones(1, batch, self.embed_dim, device=h.device) * self.sos\n        if not classify:\n            h = torch.cat([sos, h[:-1, :, :]], axis=0)\n        else:\n            h = torch.cat([sos, h], axis=0)\n\n        # transformer\n        for layer in self.layers:\n            h = layer(h + pos, attn_mask)\n\n        h = self.ln_f(h)\n\n        encoded_points = h.transpose(0, 1)\n        if not classify:\n            return encoded_points\n\n        h = h.transpose(0, 1)\n        h = self.cls_norm(h)\n        concat_f = torch.cat([h[:, 1], h[:, 2:].max(1)[0]], dim=-1)\n        ret = self.cls_head_finetune(concat_f)\n        return ret, encoded_points\n\n\nclass GPT_generator(nn.Module):\n    def __init__(\n        self, embed_dim, num_heads, num_layers, trans_dim, group_size\n    ):\n        super(GPT_generator, self).__init__()\n\n        self.embed_dim = embed_dim\n        self.trans_dim = trans_dim\n        self.group_size = group_size\n\n        # start of sequence token\n        
self.sos = torch.nn.Parameter(torch.zeros(embed_dim))\n        nn.init.normal_(self.sos)\n\n        self.layers = nn.ModuleList()\n        for _ in range(num_layers):\n            self.layers.append(Block(embed_dim, num_heads))\n\n        self.ln_f = nn.LayerNorm(embed_dim)\n        self.increase_dim = nn.Sequential(\n            nn.Conv1d(self.trans_dim, 3*(self.group_size), 1)\n        )\n\n    def forward(self, h, pos, attn_mask):\n        \"\"\"\n        Expect input as shape [sequence len, batch]\n        If classify, return classification logits\n        \"\"\"\n        batch, length, C = h.shape\n\n        h = h.transpose(0, 1)\n        pos = pos.transpose(0, 1)\n\n        # transformer\n        for layer in self.layers:\n            h = layer(h + pos, attn_mask)\n\n        h = self.ln_f(h)\n\n        rebuild_points = self.increase_dim(h.transpose(1, 2)).transpose(\n            1, 2).transpose(0, 1).reshape(batch * length, -1, 3)\n\n        return rebuild_points\n"
  },
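  {
    "path": "examples/gpt_blocks_sketch.py",
    "content": "# NOTE: editor-added usage sketch; this file is not part of the original PointGPT release.\n# It shows, under assumed hyper-parameters, how the GPT_extractor / GPT_generator pair in\n# models/GPT.py is driven: batch-first patch embeddings, a positional sequence whose first\n# slot belongs to the sos token, and a causal (upper-triangular) attention mask.\nimport torch\n\nfrom models.GPT import GPT_extractor, GPT_generator\n\nB, G, C = 2, 64, 384      # batch size, number of point patches, embedding dim (assumed)\ngroup_size = 32           # points per patch (assumed)\n\nh = torch.randn(B, G, C)    # patch embeddings from the point encoder\npos = torch.randn(B, G, C)  # positional embeddings; index 0 is the sos position\n\n# causal mask: True entries are blocked, so token i only attends to tokens <= i\nattn_mask = torch.full((G, G), float('-inf')).to(torch.bool)\nattn_mask = torch.triu(attn_mask, diagonal=1)\n\nextractor = GPT_extractor(embed_dim=C, num_heads=6, num_layers=2,\n                          num_classes=40, trans_dim=C, group_size=group_size,\n                          pretrained=True)\ngenerator = GPT_generator(embed_dim=C, num_heads=6, num_layers=2,\n                          trans_dim=C, group_size=group_size)\n\nfeat = extractor(h, pos, attn_mask)        # (B, G, C) autoregressive features\npoints = generator(feat, pos, attn_mask)   # (B * G, group_size, 3) rebuilt patches\nprint(feat.shape, points.shape)\n"
  },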
  {
    "path": "models/PointGPT.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport timm\nfrom timm.models.layers import DropPath, trunc_normal_\nimport numpy as np\nfrom .build import MODELS\nfrom utils import misc\nfrom utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message\nfrom utils.logger import *\nimport random\nfrom knn_cuda import KNN\nfrom extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2\nfrom models.GPT import GPT_extractor, GPT_generator\nimport math\nfrom models.z_order import *\n\nclass Encoder_large(nn.Module):  # Embedding module\n    def __init__(self, encoder_channel):\n        super().__init__()\n        self.encoder_channel = encoder_channel\n        self.first_conv = nn.Sequential(\n            nn.Conv1d(3, 256, 1),\n            nn.BatchNorm1d(256),\n            nn.ReLU(inplace=True),\n            nn.Conv1d(256, 512, 1),\n            nn.BatchNorm1d(512),\n            nn.ReLU(inplace=True),\n            nn.Conv1d(512, 1024, 1)\n        )\n        self.second_conv = nn.Sequential(\n            nn.Conv1d(2048, 2048, 1),\n            nn.BatchNorm1d(2048),\n            nn.ReLU(inplace=True),\n            nn.Conv1d(2048, self.encoder_channel, 1)\n        )\n\n    def forward(self, point_groups):\n        '''\n            point_groups : B G N 3\n            -----------------\n            feature_global : B G C\n        '''\n        bs, g, n, _ = point_groups.shape\n        point_groups = point_groups.reshape(bs * g, n, 3)\n        # encoder\n        feature = self.first_conv(point_groups.transpose(2, 1))  # BG 256 n\n        feature_global = torch.max(feature, dim=2, keepdim=True)[0]  # BG 256 1\n        feature = torch.cat(\n            [feature_global.expand(-1, -1, n), feature], dim=1)  # BG 512 n\n        feature = self.second_conv(feature)  # BG 1024 n\n        feature_global = torch.max(feature, dim=2, keepdim=False)[0]  # BG 1024\n        return feature_global.reshape(bs, g, self.encoder_channel)\n\nclass Encoder_small(nn.Module):  # Embedding module\n    def __init__(self, encoder_channel):\n        super().__init__()\n        self.encoder_channel = encoder_channel\n        self.first_conv = nn.Sequential(\n            nn.Conv1d(3, 128, 1),\n            nn.BatchNorm1d(128),\n            nn.ReLU(inplace=True),\n            nn.Conv1d(128, 256, 1)\n        )\n        self.second_conv = nn.Sequential(\n            nn.Conv1d(512, 512, 1),\n            nn.BatchNorm1d(512),\n            nn.ReLU(inplace=True),\n            nn.Conv1d(512, self.encoder_channel, 1)\n        )\n\n    def forward(self, point_groups):\n        '''\n            point_groups : B G N 3\n            -----------------\n            feature_global : B G C\n        '''\n        bs, g, n, _ = point_groups.shape\n        point_groups = point_groups.reshape(bs * g, n, 3)\n        # encoder\n        feature = self.first_conv(point_groups.transpose(2, 1))\n        feature_global = torch.max(feature, dim=2, keepdim=True)[0]\n        feature = torch.cat(\n            [feature_global.expand(-1, -1, n), feature], dim=1)\n        feature = self.second_conv(feature)\n        feature_global = torch.max(feature, dim=2, keepdim=False)[0]\n        return feature_global.reshape(bs, g, self.encoder_channel)\n\n\nclass Group(nn.Module):\n    def __init__(self, num_group, group_size):\n        super().__init__()\n        self.num_group = num_group\n        self.group_size = group_size\n        self.knn = KNN(k=self.group_size, transpose_mode=True)\n        
self.knn_2 = KNN(k=1, transpose_mode=True)\n\n    def simplied_morton_sorting(self, xyz, center):\n        '''\n        Simplified alternative to Morton-code sorting: starting from the first\n        patch, iteratively pick the nearest unvisited patch as the next one.\n        We found this serialization to be more efficient.\n        '''\n        batch_size, num_points, _ = xyz.shape\n        distances_batch = torch.cdist(center, center)\n        distances_batch[:, torch.eye(self.num_group).bool()] = float(\"inf\")\n        idx_base = torch.arange(\n            0, batch_size, device=xyz.device) * self.num_group\n        sorted_indices_list = []\n        sorted_indices_list.append(idx_base)\n        distances_batch = distances_batch.view(batch_size, self.num_group, self.num_group).transpose(\n            1, 2).contiguous().view(batch_size * self.num_group, self.num_group)\n        distances_batch[idx_base] = float(\"inf\")\n        distances_batch = distances_batch.view(\n            batch_size, self.num_group, self.num_group).transpose(1, 2).contiguous()\n        for i in range(self.num_group - 1):\n            distances_batch = distances_batch.view(\n                batch_size * self.num_group, self.num_group)\n            distances_to_last_batch = distances_batch[sorted_indices_list[-1]]\n            closest_point_idx = torch.argmin(distances_to_last_batch, dim=-1)\n            closest_point_idx = closest_point_idx + idx_base\n            sorted_indices_list.append(closest_point_idx)\n            distances_batch = distances_batch.view(batch_size, self.num_group, self.num_group).transpose(\n                1, 2).contiguous().view(batch_size * self.num_group, self.num_group)\n            distances_batch[closest_point_idx] = float(\"inf\")\n            distances_batch = distances_batch.view(\n                batch_size, self.num_group, self.num_group).transpose(1, 2).contiguous()\n        sorted_indices = torch.stack(sorted_indices_list, dim=-1)\n        sorted_indices = sorted_indices.view(-1)\n        return sorted_indices\n\n    def morton_sorting(self, xyz, center):\n        batch_size, num_points, _ = xyz.shape\n        all_indices = []\n        for index in range(batch_size):\n            points = center[index]\n            z = get_z_values(points.cpu().numpy())\n            idxs = np.argsort(z)\n            all_indices.append(idxs)\n        all_indices = torch.tensor(all_indices, device=xyz.device)\n\n        idx_base = torch.arange(\n            0, batch_size, device=xyz.device).view(-1, 1) * self.num_group\n        sorted_indices = all_indices + idx_base\n        sorted_indices = sorted_indices.view(-1)\n        return sorted_indices\n\n    def forward(self, xyz):\n        '''\n            input: B N 3\n            ---------------------------\n            output: B G M 3\n            center : B G 3\n        '''\n        batch_size, num_points, _ = xyz.shape\n        # fps the centers out\n        center = misc.fps(xyz, self.num_group)  # B G 3\n        # knn to get the neighborhood\n        _, idx = self.knn(xyz, center)  # B G M\n        assert idx.size(1) == self.num_group\n        assert idx.size(2) == self.group_size\n        idx_base = torch.arange(\n            0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points\n        idx = idx + idx_base\n        idx = idx.view(-1)\n        neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]\n        neighborhood = neighborhood.view(\n            batch_size, 
self.num_group, self.group_size, 3).contiguous()\n        # normalize\n        neighborhood = neighborhood - center.unsqueeze(2)\n\n        # Morton-order serialization can be used instead by calling self.morton_sorting\n        sorted_indices = self.simplied_morton_sorting(xyz, center)\n\n        neighborhood = neighborhood.view(\n            batch_size * self.num_group, self.group_size, 3)[sorted_indices, :, :]\n        neighborhood = neighborhood.view(\n            batch_size, self.num_group, self.group_size, 3).contiguous()\n        center = center.view(\n            batch_size * self.num_group, 3)[sorted_indices, :]\n        center = center.view(\n            batch_size, self.num_group, 3).contiguous()\n\n        return neighborhood, center\n\n\n# Transformers\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = qk_scale or head_dim ** -0.5\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //\n                                  self.num_heads).permute(2, 0, 3, 1, 4)\n        # make torchscript happy (cannot use tensor as tuple)\n        q, k, v = qkv[0], qkv[1], qkv[2]\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n\n        self.drop_path = DropPath(\n            drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,\n                       act_layer=act_layer, drop=drop)\n\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n\n    def forward(self, x):\n        x = x + self.drop_path(self.attn(self.norm1(x)))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n\nclass PositionEmbeddingCoordsSine(nn.Module):\n    \"\"\"Similar to transformer's position encoding, but generalizes it to\n    arbitrary dimensions and continuous coordinates.\n\n    Args:\n        n_dim: Number of input dimensions, e.g. 2 for image coordinates.\n        d_model: Number of dimensions to encode into\n        temperature: Base of the sinusoidal frequency progression (default 10000).\n        scale: Multiplier applied to the coordinates before encoding; the\n            effective factor is scale * 2 * pi (scale defaults to 1.0).\n    \"\"\"\n\n    def __init__(self, n_dim: int = 1, d_model: int = 256, temperature=10000, scale=None):\n        super().__init__()\n\n        self.n_dim = n_dim\n        self.num_pos_feats = d_model // n_dim // 2 * 2\n        self.temperature = temperature\n        self.padding = d_model - self.num_pos_feats * self.n_dim\n\n        if scale is None:\n            scale = 1.0\n        self.scale = scale * 2 * math.pi\n\n    def forward(self, xyz: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            xyz: Point positions (*, d_in)\n\n        Returns:\n            pos_emb (*, d_out)\n        \"\"\"\n        assert xyz.shape[-1] == self.n_dim\n\n        dim_t = torch.arange(self.num_pos_feats,\n                             dtype=torch.float32, device=xyz.device)\n        dim_t = self.temperature ** (2 * torch.div(dim_t,\n                                     2, rounding_mode='trunc') / self.num_pos_feats)\n\n        xyz = xyz * self.scale\n        pos_divided = xyz.unsqueeze(-1) / dim_t\n        pos_sin = pos_divided[..., 0::2].sin()\n        pos_cos = pos_divided[..., 1::2].cos()\n        pos_emb = torch.stack([pos_sin, pos_cos], dim=-\n                              1).reshape(*xyz.shape[:-1], -1)\n\n        # Pad unused dimensions with zeros\n        pos_emb = F.pad(pos_emb, (0, self.padding))\n        return pos_emb\n\n\nclass GPT_Transformer(nn.Module):\n    def __init__(self, config, **kwargs):\n        super().__init__()\n        self.config = config\n        # transformer hyper-parameters from the config\n        self.mask_ratio = config.transformer_config.mask_ratio\n        self.trans_dim = config.transformer_config.trans_dim\n        self.depth = config.transformer_config.depth\n        self.decoder_depth = config.transformer_config.decoder_depth\n        self.drop_path_rate = config.transformer_config.drop_path_rate\n        self.num_heads = config.transformer_config.num_heads\n        self.group_size = config.group_size\n        print_log(f'[args] {config.transformer_config}', logger='Transformer')\n\n        self.encoder_dims = config.transformer_config.encoder_dims\n\n        assert self.encoder_dims in [384, 768, 1024]\n        if self.encoder_dims == 384:\n            self.encoder = Encoder_small(encoder_channel=self.encoder_dims)\n        else:\n            self.encoder = Encoder_large(encoder_channel=self.encoder_dims)\n\n        self.pos_embed = PositionEmbeddingCoordsSine(3, self.encoder_dims, 1.0)\n\n        self.blocks = GPT_extractor(\n            embed_dim=self.encoder_dims,\n            num_heads=self.num_heads,\n            num_layers=self.depth,\n            
num_classes=config.cls_dim,\n            trans_dim=self.trans_dim,\n            group_size=self.group_size,\n            pretrained=True,\n        )\n\n        self.generator_blocks = GPT_generator(\n            embed_dim=self.encoder_dims,\n            num_heads=self.num_heads,\n            num_layers=self.decoder_depth,\n            trans_dim=self.trans_dim,\n            group_size=self.group_size\n        )\n\n        # do not perform additional mask on the first (self.keep_attend) tokens\n        self.keep_attend = 10\n        self.num_groups = config.num_group\n        self.num_mask = int(\n            (self.num_groups - self.keep_attend) * self.mask_ratio)\n\n        self.sos_pos = nn.Parameter(torch.zeros(1, 1, self.trans_dim))\n\n        self.norm = nn.LayerNorm(self.trans_dim)\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv1d):\n            trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n\n    def forward(self, neighborhood, center, noaug=False, classify=False):\n        # generate mask\n\n        group_input_tokens = self.encoder(neighborhood)  # B G C\n\n        batch_size, seq_len, C = group_input_tokens.size()\n\n        relative_position = center[:, 1:, :] - center[:, :-1, :]\n        relative_norm = torch.norm(relative_position, dim=-1, keepdim=True)\n        relative_direction = relative_position / relative_norm\n        position = torch.cat(\n            [center[:, 0, :].unsqueeze(1), relative_direction], dim=1)\n        pos_relative = self.pos_embed(position)\n\n        sos_pos = self.sos_pos.expand(group_input_tokens.size(0), -1, -1)\n        pos_absolute = self.pos_embed(center[:, :-1, :])\n        pos_absolute = torch.cat([sos_pos, pos_absolute], dim=1)\n\n        attn_mask = torch.full(\n            (seq_len, seq_len), -float(\"Inf\"), device=group_input_tokens.device, dtype=group_input_tokens.dtype\n        ).to(torch.bool)\n\n        with torch.no_grad():\n            attn_mask = torch.triu(attn_mask, diagonal=1)\n\n            # point wise\n            # overall_mask = np.zeros([self.num_groups, self.num_groups])\n            # for i in range(self.num_groups):\n            #     mask = np.hstack([\n            #         np.zeros(self.num_groups-self.num_mask),\n            #         np.ones(self.num_mask),\n            #     ])\n            #     np.random.shuffle(mask)\n            #     overall_mask[i, :] = mask\n            # overall_mask = torch.from_numpy(\n            #     overall_mask).to(torch.bool).to('cuda')\n\n            # column wise\n            overall_mask = np.hstack([\n                np.zeros(self.num_groups-self.keep_attend-self.num_mask),\n                np.ones(self.num_mask),\n            ])\n            np.random.shuffle(overall_mask)\n            overall_mask = np.hstack([\n                np.zeros(self.keep_attend),\n                overall_mask,\n            ])\n            overall_mask = torch.from_numpy(\n                overall_mask).to(torch.bool).to(group_input_tokens.device)\n\n            eye_mask = torch.eye(self.num_groups).to(\n                torch.bool).to(group_input_tokens.device)\n\n            attn_mask = attn_mask | (
overall_mask.unsqueeze(0) & ~eye_mask)\n\n        # transformer\n        if not classify:\n            encoded_features = self.blocks(\n                group_input_tokens, pos_absolute, attn_mask, classify=classify)\n            generated_points = self.generator_blocks(\n                encoded_features, pos_relative, attn_mask)\n            return generated_points\n        else:\n            print('----error---- this code path is not used ----error----')\n            logits, generated_points = self.blocks(\n                group_input_tokens, pos_absolute, attn_mask, classify=classify)\n            return logits, generated_points\n\n\n@MODELS.register_module()\nclass PointGPT(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n        print_log('[PointGPT]', logger='PointGPT')\n        self.config = config\n        self.trans_dim = config.transformer_config.trans_dim\n        self.GPT_Transformer = GPT_Transformer(config)\n        self.group_size = config.group_size\n        self.num_group = config.num_group\n        self.drop_path_rate = config.transformer_config.drop_path_rate\n        self.weight_center = config.weight_center\n\n        print_log(\n            f'[PointGPT] divide point cloud into G{self.num_group} x S{self.group_size} points ...', logger='PointGPT')\n        self.group_divider = Group(\n            num_group=self.num_group, group_size=self.group_size)\n\n        self.loss = config.loss\n\n        self.build_loss_func(self.loss)\n\n    def build_loss_func(self, loss_type):\n        if loss_type == \"cdl1\":\n            self.loss_func_p = ChamferDistanceL1().cuda()\n        elif loss_type == 'cdl2':\n            self.loss_func_p = ChamferDistanceL2().cuda()\n        elif loss_type == 'cdl12':\n            self.loss_func_p1 = ChamferDistanceL1().cuda()\n            self.loss_func_p2 = ChamferDistanceL2().cuda()\n        else:\n            raise NotImplementedError\n        self.loss_func_c = nn.MSELoss().cuda()\n\n    def forward(self, pts, vis=False, **kwargs):\n        neighborhood, center = self.group_divider(pts)\n\n        B = neighborhood.shape[0]\n\n        generated_points = self.GPT_Transformer(\n            neighborhood, center)\n\n        gt_points = neighborhood.reshape(\n            B*(self.num_group), self.group_size, 3)\n        # the default 'cdl12' setting is assumed here, which defines both\n        # loss_func_p1 and loss_func_p2\n        loss1 = self.loss_func_p1(generated_points, gt_points)\n        loss2 = self.loss_func_p2(generated_points, gt_points)\n\n        if vis:  # visualization\n            gt_points = gt_points.reshape(\n                B, self.num_group, self.group_size, 3)\n            gt_points = (gt_points + center.unsqueeze(-2)\n                         ).reshape(-1, 3).unsqueeze(0)\n            generated_points = generated_points.reshape(\n                B, self.num_group, self.group_size, 3) + center.unsqueeze(-2)\n            generated_points = generated_points.reshape(-1, 3).unsqueeze(0)\n\n            return generated_points, gt_points, center\n\n        return loss1 + loss2\n\n\n@MODELS.register_module()\nclass PointTransformer(nn.Module):\n    def __init__(self, config, **kwargs):\n        super().__init__()\n        self.config = config\n\n        self.trans_dim = config.trans_dim\n        self.depth = config.depth\n        self.decoder_depth = config.decoder_depth\n        self.drop_path_rate = config.drop_path_rate\n        self.cls_dim = config.cls_dim\n        self.num_heads = config.num_heads\n\n        self.group_size = config.group_size\n        self.num_group = config.num_group\n        self.encoder_dims = 
config.encoder_dims\n\n        self.group_divider = Group(\n            num_group=self.num_group, group_size=self.group_size)\n\n        assert self.encoder_dims in [384, 768, 1024]\n        if self.encoder_dims == 384:\n            self.encoder = Encoder_small(encoder_channel=self.encoder_dims)\n        else:\n            self.encoder = Encoder_large(encoder_channel=self.encoder_dims)\n\n        self.pos_embed = PositionEmbeddingCoordsSine(3, self.encoder_dims, 1.0)\n\n        self.blocks = GPT_extractor(\n            embed_dim=self.encoder_dims,\n            num_heads=self.num_heads,\n            num_layers=self.depth,\n            num_classes=config.cls_dim,\n            trans_dim=self.trans_dim,\n            group_size=self.group_size\n        )\n\n        self.generator_blocks = GPT_generator(\n            embed_dim=self.encoder_dims,\n            num_heads=self.num_heads,\n            num_layers=self.decoder_depth,\n            trans_dim=self.trans_dim,\n            group_size=self.group_size\n        )\n\n        self.norm = nn.LayerNorm(self.trans_dim)\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))\n        self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))\n\n        self.sos_pos = nn.Parameter(torch.zeros(1, 1, self.trans_dim))\n\n        self.build_loss_func()\n\n        trunc_normal_(self.cls_token, std=.02)\n        trunc_normal_(self.cls_pos, std=.02)\n\n    def build_loss_func(self, loss_type='cdl12'):\n        self.loss_ce = nn.CrossEntropyLoss()\n        if loss_type == \"cdl1\":\n            self.loss_func_p = ChamferDistanceL1().cuda()\n        elif loss_type == 'cdl2':\n            self.loss_func_p = ChamferDistanceL2().cuda()\n        elif loss_type == 'cdl12':\n            self.loss_func_p1 = ChamferDistanceL1().cuda()\n            self.loss_func_p2 = ChamferDistanceL2().cuda()\n        else:\n            raise NotImplementedError\n\n    def get_loss_acc(self, ret, gt):\n        loss = self.loss_ce(ret, gt.long())\n        pred = ret.argmax(-1)\n        acc = (pred == gt).sum() / float(gt.size(0))\n        return loss, acc * 100\n\n    def load_model_from_ckpt(self, bert_ckpt_path):\n        if bert_ckpt_path is not None:\n            ckpt = torch.load(bert_ckpt_path)\n            base_ckpt = {k.replace(\"module.\", \"\"): v for k,\n                         v in ckpt['base_model'].items()}\n\n            for k in list(base_ckpt.keys()):\n                if k.startswith('GPT_Transformer'):\n                    base_ckpt[k[len('GPT_Transformer.'):]] = base_ckpt[k]\n                    del base_ckpt[k]\n                elif k.startswith('base_model'):\n                    base_ckpt[k[len('base_model.'):]] = base_ckpt[k]\n                    del base_ckpt[k]\n                if 'cls_head_finetune' in k:\n                    del base_ckpt[k]\n\n            incompatible = self.load_state_dict(base_ckpt, strict=False)\n\n            if incompatible.missing_keys:\n                print_log('missing_keys', logger='Transformer')\n                print_log(\n                    get_missing_parameters_message(incompatible.missing_keys),\n                    logger='Transformer'\n                )\n            if incompatible.unexpected_keys:\n                print_log('unexpected_keys', logger='Transformer')\n                print_log(\n                    get_unexpected_parameters_message(\n                        
incompatible.unexpected_keys),\n                    logger='Transformer'\n                )\n\n            print_log(\n                f'[Transformer] Successfully loaded checkpoint from {bert_ckpt_path}', logger='Transformer')\n        else:\n            print_log('Training from scratch!', logger='Transformer')\n            self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv1d):\n            trunc_normal_(m.weight, std=.02)\n            if m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n\n    def forward(self, pts):\n        neighborhood, center = self.group_divider(pts)\n        group_input_tokens = self.encoder(neighborhood)  # B G C\n\n        B, L, _ = group_input_tokens.shape\n\n        cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)\n        cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)\n\n        pos = self.pos_embed(center)\n        sos_pos = self.sos_pos.expand(group_input_tokens.size(0), -1, -1)\n        pos = torch.cat([sos_pos, pos], dim=1)\n\n        relative_position = center[:, 1:, :] - center[:, :-1, :]\n        relative_norm = torch.norm(relative_position, dim=-1, keepdim=True)\n        relative_direction = relative_position / relative_norm\n        position = torch.cat(\n            [center[:, 0, :].unsqueeze(1), relative_direction], dim=1)\n        pos_relative = self.pos_embed(position)\n\n        x = torch.cat((cls_tokens, group_input_tokens), dim=1)\n        pos = torch.cat((cls_pos, pos), dim=1)\n\n        attn_mask = torch.full(\n            (L+2, L+2), -float(\"Inf\"), device=group_input_tokens.device, dtype=group_input_tokens.dtype\n        ).to(torch.bool)\n\n        attn_mask = torch.triu(attn_mask, diagonal=1)\n\n        # transformer\n        ret, encoded_features = self.blocks(x, pos, attn_mask, classify=True)\n\n        encoded_features = torch.cat(\n            [encoded_features[:, 0, :].unsqueeze(1), encoded_features[:, 2:-1, :]], dim=1)\n\n        attn_mask = torch.full(\n            (L, L), -float(\"Inf\"), device=group_input_tokens.device, dtype=group_input_tokens.dtype\n        ).to(torch.bool)\n\n        attn_mask = torch.triu(attn_mask, diagonal=1)\n\n        generated_points = self.generator_blocks(\n            encoded_features, pos_relative, attn_mask)\n\n        neighborhood = neighborhood + center.unsqueeze(2)\n\n        gt_points = neighborhood.reshape(\n            B*(self.num_group), self.group_size, 3)\n\n        loss1 = self.loss_func_p1(generated_points, gt_points)\n        loss2 = self.loss_func_p2(generated_points, gt_points)\n\n        return ret, loss1 + loss2\n"
  },
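  {
    "path": "examples/patch_ordering_sketch.py",
    "content": "# NOTE: editor-added sketch; this file is not part of the original PointGPT release.\n# A per-cloud, loop-based restatement of Group.simplied_morton_sorting from\n# models/PointGPT.py: starting from patch 0, repeatedly step to the nearest\n# unvisited patch centre, which serializes the patches into a 1D sequence.\nimport torch\n\n\ndef greedy_patch_order(center: torch.Tensor) -> torch.Tensor:\n    '''center: (G, 3) patch centres -> (G,) visiting order.'''\n    G = center.shape[0]\n    dist = torch.cdist(center, center)\n    dist.fill_diagonal_(float('inf'))\n    order = [0]\n    dist[:, 0] = float('inf')  # mark patch 0 as visited\n    for _ in range(G - 1):\n        nxt = torch.argmin(dist[order[-1]]).item()  # nearest unvisited patch\n        order.append(nxt)\n        dist[:, nxt] = float('inf')\n    return torch.tensor(order)\n\n\nif __name__ == '__main__':\n    print(greedy_patch_order(torch.rand(16, 3)))\n"
  },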
  {
    "path": "models/__init__.py",
    "content": "from .build import build_model_from_cfg\nimport models.PointGPT\n"
  },
  {
    "path": "models/build.py",
    "content": "from utils import registry\n\n\nMODELS = registry.Registry('models')\n\n\ndef build_model_from_cfg(cfg, **kwargs):\n    \"\"\"\n    Build a dataset, defined by `dataset_name`.\n    Args:\n        cfg (eDICT): \n    Returns:\n        Dataset: a constructed dataset specified by dataset_name.\n    \"\"\"\n    return MODELS.build(cfg, **kwargs)\n\n\n"
  },
  {
    "path": "models/z_order.py",
    "content": "import numpy as np\n\n\ndef round_to_int_32(data):\n    \"\"\"\n    Takes a Numpy array of float values between\n    -1 and 1, and rounds them to significant\n    32-bit integer values, to be used in the\n    morton code computation\n\n    :param data: multidimensional numpy array\n    :return: same as data but in 32-bit int format\n    \"\"\"\n    # first we rescale points to 0-512\n    min_data = np.abs(np.min(data)-0.5)\n    data = 256*(data + min_data)\n    # now convert to int\n    data = np.round(2 ** 21 - data).astype(dtype=np.int32)\n\n    return data\n\n\ndef split_by_3(x):\n    \"\"\"\n    Method to separate bits of a 32-bit integer\n    by 3 positions apart, using the magic bits\n    https://www.forceflow.be/2013/10/07/morton-encodingdecoding-through-bit-interleaving-implementations/\n\n    :param x: 32-bit integer\n    :return: x with bits separated\n    \"\"\"\n    # we only look at 21 bits, since we want to generate\n    # a 64-bit code eventually (3 x 21 bits = 63 bits, which\n    # is the maximum we can fit in a 64-bit code)\n    x &= 0x1fffff  # only take first 21 bits\n    # shift left 32 bits, OR with self, and 00011111000000000000000000000000000000001111111111111111\n    x = (x | (x << 32)) & 0x1f00000000ffff\n    # shift left 16 bits, OR with self, and 00011111000000000000000011111111000000000000000011111111\n    x = (x | (x << 16)) & 0x1f0000ff0000ff\n    # shift left 8 bits, OR with self, and 0001000000001111000000001111000000001111000000001111000000000000\n    x = (x | (x << 8)) & 0x100f00f00f00f00f\n    # shift left 4 bits, OR with self, and 0001000011000011000011000011000011000011000011000011000100000000\n    x = (x | (x << 4)) & 0x10c30c30c30c30c3\n    # shift left 2 bits, OR with self, and 0001001001001001001001001001001001001001001001001001001001001001\n    x = (x | (x << 2)) & 0x1249249249249249\n\n    return x\n\n\ndef get_z_order(x, y, z):\n    \"\"\"\n    Given 3 arrays of corresponding x, y, z\n    coordinates, compute the morton (or z) code for\n    each point and return an index array\n    We compute the Morton order as follows:\n        1- Split all coordinates by 3 (add 2 zeros between bits)\n        2- Shift bits left by 1 for y and 2 for z\n        3- Interleave x, shifted y, and shifted z\n    The mordon order is the final interleaved bit sequence\n\n    :param x: x coordinates\n    :param y: y coordinates\n    :param z: z coordinates\n    :return: index array with morton code\n    \"\"\"\n    res = 0\n    res |= split_by_3(x) | split_by_3(y) << 1 | split_by_3(z) << 2\n\n    return res\n\n\ndef get_z_values(data):\n    \"\"\"\n    Computes the z values for a point array\n    :param data: Nx3 array of x, y, and z location\n\n    :return: Nx1 array of z values\n    \"\"\"\n    points_round = round_to_int_32(data)  # convert to int\n    z = get_z_order(points_round[:, 0], points_round[:, 1], points_round[:, 2])\n\n    return z\n"
  },
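  {
    "path": "examples/morton_code_sketch.py",
    "content": "# NOTE: editor-added sketch; this file is not part of the original PointGPT release.\n# It restates the bit-interleaving from models/z_order.py with plain Python ints\n# (which cannot overflow) to show how three 21-bit coordinates become one Morton key.\n\ndef split_by_3(x: int) -> int:\n    x &= 0x1fffff                            # keep the low 21 bits\n    x = (x | (x << 32)) & 0x1f00000000ffff\n    x = (x | (x << 16)) & 0x1f0000ff0000ff\n    x = (x | (x << 8)) & 0x100f00f00f00f00f\n    x = (x | (x << 4)) & 0x10c30c30c30c30c3\n    x = (x | (x << 2)) & 0x1249249249249249\n    return x\n\n\ndef morton_key(x: int, y: int, z: int) -> int:\n    # interleave the bits: ... z1 y1 x1 z0 y0 x0\n    return split_by_3(x) | (split_by_3(y) << 1) | (split_by_3(z) << 2)\n\n\nif __name__ == '__main__':\n    pts = [(0, 0, 0), (1, 0, 0), (0, 1, 0), (7, 7, 7)]\n    # spatially close points tend to receive close Morton keys\n    print(sorted(range(len(pts)), key=lambda i: morton_key(*pts[i])))\n"
  },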
  {
    "path": "requirements.txt",
    "content": "argparse\neasydict\nh5py\nmatplotlib\nnumpy\nopen3d==0.9\nopencv-python\npyyaml\nscipy\ntensorboardX\ntimm==0.4.5\ntqdm\ntransforms3d\ntermcolor"
  },
  {
    "path": "segmentation/__init__.py",
    "content": ""
  },
  {
    "path": "segmentation/dataset.py",
    "content": "import numpy as np\nimport os\nfrom torch.utils.data import Dataset\nimport torch\nfrom pointnet_util import farthest_point_sample, pc_normalize\nimport json\n\n\nclass ModelNetDataLoader(Dataset):\n    def __init__(self, root, npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000):\n        self.root = root\n        self.npoints = npoint\n        self.uniform = uniform\n        self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')\n\n        self.cat = [line.rstrip() for line in open(self.catfile)]\n        self.classes = dict(zip(self.cat, range(len(self.cat))))\n        self.normal_channel = normal_channel\n\n        shape_ids = {}\n        shape_ids['train'] = [line.rstrip() for line in open(\n            os.path.join(self.root, 'modelnet40_train.txt'))]\n        shape_ids['test'] = [line.rstrip() for line in open(\n            os.path.join(self.root, 'modelnet40_test.txt'))]\n\n        assert (split == 'train' or split == 'test')\n        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]\n        # list of (shape_name, shape_txt_file_path) tuple\n        self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i\n                         in range(len(shape_ids[split]))]\n        print('The size of %s data is %d' % (split, len(self.datapath)))\n\n        self.cache_size = cache_size  # how many data points to cache in memory\n        self.cache = {}  # from index to (point_set, cls) tuple\n\n    def __len__(self):\n        return len(self.datapath)\n\n    def _get_item(self, index):\n        if index in self.cache:\n            point_set, cls = self.cache[index]\n        else:\n            fn = self.datapath[index]\n            cls = self.classes[self.datapath[index][0]]\n            cls = np.array([cls]).astype(np.int32)\n            point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)\n            if self.uniform:\n                point_set = farthest_point_sample(point_set, self.npoints)\n            else:\n                point_set = point_set[0:self.npoints, :]\n\n            point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])\n\n            if not self.normal_channel:\n                point_set = point_set[:, 0:3]\n\n            if len(self.cache) < self.cache_size:\n                self.cache[index] = (point_set, cls)\n\n        return point_set, cls\n\n    def __getitem__(self, index):\n        return self._get_item(index)\n\n\nclass PartNormalDataset(Dataset):\n    def __init__(self, root='/data/cgy/ShapenetPart/shapenetcore_partanno_segmentation_benchmark_v0_normal', npoints=2500, split='train', class_choice=None, normal_channel=False):\n        self.npoints = npoints\n        self.root = root\n        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')\n        self.cat = {}\n        self.normal_channel = normal_channel\n\n        with open(self.catfile, 'r') as f:\n            for line in f:\n                ls = line.strip().split()\n                self.cat[ls[0]] = ls[1]\n        self.cat = {k: v for k, v in self.cat.items()}\n        self.classes_original = dict(zip(self.cat, range(len(self.cat))))\n\n        if not class_choice is None:\n            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}\n        # print(self.cat)\n\n        self.meta = {}\n        with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:\n            train_ids = 
set([str(d.split('/')[2]) for d in json.load(f)])\n        with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:\n            val_ids = set([str(d.split('/')[2]) for d in json.load(f)])\n        with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:\n            test_ids = set([str(d.split('/')[2]) for d in json.load(f)])\n        for item in self.cat:\n            # print('category', item)\n            self.meta[item] = []\n            dir_point = os.path.join(self.root, self.cat[item])\n            fns = sorted(os.listdir(dir_point))\n            # print(fns[0][0:-4])\n            if split == 'trainval':\n                fns = [fn for fn in fns if (\n                    (fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]\n            elif split == 'train':\n                fns = [fn for fn in fns if fn[0:-4] in train_ids]\n            elif split == 'val':\n                fns = [fn for fn in fns if fn[0:-4] in val_ids]\n            elif split == 'test':\n                fns = [fn for fn in fns if fn[0:-4] in test_ids]\n            else:\n                print('Unknown split: %s. Exiting..' % (split))\n                exit(-1)\n\n            # print(os.path.basename(fns))\n            for fn in fns:\n                token = (os.path.splitext(os.path.basename(fn))[0])\n                self.meta[item].append(os.path.join(dir_point, token + '.txt'))\n\n        self.datapath = []\n        for item in self.cat:\n            for fn in self.meta[item]:\n                self.datapath.append((item, fn))\n\n        self.classes = {}\n        for i in self.cat.keys():\n            self.classes[i] = self.classes_original[i]\n\n        # Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels\n        self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],\n                            'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],\n                            'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],\n                            'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],\n                            'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}\n\n        # for cat in sorted(self.seg_classes.keys()):\n        #     print(cat, self.seg_classes[cat])\n\n        self.cache = {}  # from index to (point_set, cls, seg) tuple\n        self.cache_size = 20000\n\n    def __getitem__(self, index):\n        if index in self.cache:\n            point_set, cls, seg = self.cache[index]\n        else:\n            fn = self.datapath[index]\n            cat = self.datapath[index][0]\n            cls = self.classes[cat]\n            cls = np.array([cls]).astype(np.int32)\n            data = np.loadtxt(fn[1]).astype(np.float32)\n            if not self.normal_channel:\n                point_set = data[:, 0:3]\n            else:\n                point_set = data[:, 0:6]\n            seg = data[:, -1].astype(np.int32)\n            if len(self.cache) < self.cache_size:\n                self.cache[index] = (point_set, cls, seg)\n        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])\n\n        choice = np.random.choice(len(seg), self.npoints, replace=True)\n        # resample\n        point_set = point_set[choice, :]\n        seg = seg[choice]\n\n        return point_set, cls, seg\n\n    def __len__(self):\n        return 
len(self.datapath)\n\n\nif __name__ == '__main__':\n    data = ModelNetDataLoader('modelnet40_normal_resampled/',\n                              split='train', uniform=False, normal_channel=True)\n    DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)\n    for point, label in DataLoader:\n        print(point.shape)\n        print(label.shape)\n"
  },
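  {
    "path": "examples/pc_normalize_sketch.py",
    "content": "# NOTE: editor-added sketch; this file is not part of the original PointGPT release.\n# segmentation/dataset.py imports pc_normalize from pointnet_util, which is not shown\n# in this dump; below is the conventional PointNet-style implementation it is assumed\n# to match: centre the cloud at the origin and scale it into the unit sphere.\nimport numpy as np\n\n\ndef pc_normalize(pc: np.ndarray) -> np.ndarray:\n    centroid = np.mean(pc, axis=0)\n    pc = pc - centroid\n    scale = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))\n    return pc / scale\n\n\nif __name__ == '__main__':\n    cloud = np.random.rand(1024, 3).astype(np.float32)\n    out = pc_normalize(cloud)\n    # mean should be ~0 and the largest radius should be 1\n    print(out.mean(axis=0), np.linalg.norm(out, axis=1).max())\n"
  },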
  {
    "path": "segmentation/extensions/chamfer_dist/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Thibault GROUEIX\n# @Date:   2019-08-07 20:54:24\n# @Last Modified by:   Haozhe Xie\n# @Last Modified time: 2019-12-18 15:06:25\n# @Email:  cshzxie@gmail.com\n\nimport torch\n\nimport chamfer\n\n\nclass ChamferFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, xyz1, xyz2):\n        dist1, dist2, idx1, idx2 = chamfer.forward(xyz1, xyz2)\n        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)\n\n        return dist1, dist2\n\n    @staticmethod\n    def backward(ctx, grad_dist1, grad_dist2):\n        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors\n        grad_xyz1, grad_xyz2 = chamfer.backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2)\n        return grad_xyz1, grad_xyz2\n\n\nclass ChamferDistanceL2(torch.nn.Module):\n    f''' Chamder Distance L2\n    '''\n    def __init__(self, ignore_zeros=False):\n        super().__init__()\n        self.ignore_zeros = ignore_zeros\n\n    def forward(self, xyz1, xyz2):\n        batch_size = xyz1.size(0)\n        if batch_size == 1 and self.ignore_zeros:\n            non_zeros1 = torch.sum(xyz1, dim=2).ne(0)\n            non_zeros2 = torch.sum(xyz2, dim=2).ne(0)\n            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)\n            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)\n\n        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)\n        return torch.mean(dist1) + torch.mean(dist2)\n\nclass ChamferDistanceL2_split(torch.nn.Module):\n    f''' Chamder Distance L2\n    '''\n    def __init__(self, ignore_zeros=False):\n        super().__init__()\n        self.ignore_zeros = ignore_zeros\n\n    def forward(self, xyz1, xyz2):\n        batch_size = xyz1.size(0)\n        if batch_size == 1 and self.ignore_zeros:\n            non_zeros1 = torch.sum(xyz1, dim=2).ne(0)\n            non_zeros2 = torch.sum(xyz2, dim=2).ne(0)\n            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)\n            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)\n\n        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)\n        return torch.mean(dist1), torch.mean(dist2)\n\nclass ChamferDistanceL1(torch.nn.Module):\n    f''' Chamder Distance L1\n    '''\n    def __init__(self, ignore_zeros=False):\n        super().__init__()\n        self.ignore_zeros = ignore_zeros\n\n    def forward(self, xyz1, xyz2):\n        batch_size = xyz1.size(0)\n        if batch_size == 1 and self.ignore_zeros:\n            non_zeros1 = torch.sum(xyz1, dim=2).ne(0)\n            non_zeros2 = torch.sum(xyz2, dim=2).ne(0)\n            xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)\n            xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)\n\n        dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)\n        # import pdb\n        # pdb.set_trace()\n        dist1 = torch.sqrt(dist1)\n        dist2 = torch.sqrt(dist2)\n        return (torch.mean(dist1) + torch.mean(dist2))/2\n\n"
  },
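  {
    "path": "examples/chamfer_reference_sketch.py",
    "content": "# NOTE: editor-added sketch; this file is not part of the original PointGPT release.\n# A naive O(N*M) Chamfer distance in pure PyTorch, useful as a CPU reference when\n# sanity-checking the CUDA extension behind ChamferDistanceL1/L2: the kernel stores\n# squared nearest-neighbour distances, and the L1 variant takes their square root.\nimport torch\n\n\ndef chamfer_l2_naive(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n    '''x: (B, N, 3), y: (B, M, 3) -> scalar, matching ChamferDistanceL2.'''\n    d = torch.cdist(x, y) ** 2                       # (B, N, M) squared distances\n    return d.min(dim=2)[0].mean() + d.min(dim=1)[0].mean()\n\n\ndef chamfer_l1_naive(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n    '''Matches ChamferDistanceL1: mean of un-squared distances, halved.'''\n    d = torch.cdist(x, y)                            # (B, N, M) distances\n    return (d.min(dim=2)[0].mean() + d.min(dim=1)[0].mean()) / 2\n\n\nif __name__ == '__main__':\n    a, b = torch.rand(4, 64, 3), torch.rand(4, 128, 3)\n    print(chamfer_l2_naive(a, b).item(), chamfer_l1_naive(a, b).item())\n"
  },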
  {
    "path": "segmentation/extensions/chamfer_dist/chamfer.cu",
    "content": "/*\n * @Author: Haozhe Xie\n * @Date:   2019-08-07 20:54:24\n * @Last Modified by:   Haozhe Xie\n * @Last Modified time: 2020-06-17 14:58:55\n * @Email:  cshzxie@gmail.com\n */\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <torch/extension.h>\n\n#include <vector>\n\n__global__ void chamfer_dist_kernel(int batch_size,\n                                    int n,\n                                    const float* xyz1,\n                                    int m,\n                                    const float* xyz2,\n                                    float* dist,\n                                    int* indexes) {\n  const int batch = 512;\n  __shared__ float buf[batch * 3];\n  for (int i = blockIdx.x; i < batch_size; i += gridDim.x) {\n    for (int k2 = 0; k2 < m; k2 += batch) {\n      int end_k = min(m, k2 + batch) - k2;\n      for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x) {\n        buf[j] = xyz2[(i * m + k2) * 3 + j];\n      }\n      __syncthreads();\n      for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n;\n           j += blockDim.x * gridDim.y) {\n        float x1            = xyz1[(i * n + j) * 3 + 0];\n        float y1            = xyz1[(i * n + j) * 3 + 1];\n        float z1            = xyz1[(i * n + j) * 3 + 2];\n        float best_dist     = 0;\n        int best_dist_index = 0;\n        int end_ka          = end_k - (end_k & 3);\n        if (end_ka == batch) {\n          for (int k = 0; k < batch; k += 4) {\n            {\n              float x2   = buf[k * 3 + 0] - x1;\n              float y2   = buf[k * 3 + 1] - y1;\n              float z2   = buf[k * 3 + 2] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n\n              if (k == 0 || dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2;\n              }\n            }\n            {\n              float x2   = buf[k * 3 + 3] - x1;\n              float y2   = buf[k * 3 + 4] - y1;\n              float z2   = buf[k * 3 + 5] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n              if (dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2 + 1;\n              }\n            }\n            {\n              float x2   = buf[k * 3 + 6] - x1;\n              float y2   = buf[k * 3 + 7] - y1;\n              float z2   = buf[k * 3 + 8] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n              if (dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2 + 2;\n              }\n            }\n            {\n              float x2   = buf[k * 3 + 9] - x1;\n              float y2   = buf[k * 3 + 10] - y1;\n              float z2   = buf[k * 3 + 11] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n              if (dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2 + 3;\n              }\n            }\n          }\n        } else {\n          for (int k = 0; k < end_ka; k += 4) {\n            {\n              float x2   = buf[k * 3 + 0] - x1;\n              float y2   = buf[k * 3 + 1] - y1;\n              float z2   = buf[k * 3 + 2] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n              if (k == 0 || dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2;\n              }\n            }\n            {\n              float x2   = buf[k * 3 + 3] - x1;\n            
  float y2   = buf[k * 3 + 4] - y1;\n              float z2   = buf[k * 3 + 5] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n              if (dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2 + 1;\n              }\n            }\n            {\n              float x2   = buf[k * 3 + 6] - x1;\n              float y2   = buf[k * 3 + 7] - y1;\n              float z2   = buf[k * 3 + 8] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n              if (dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2 + 2;\n              }\n            }\n            {\n              float x2   = buf[k * 3 + 9] - x1;\n              float y2   = buf[k * 3 + 10] - y1;\n              float z2   = buf[k * 3 + 11] - z1;\n              float dist = x2 * x2 + y2 * y2 + z2 * z2;\n              if (dist < best_dist) {\n                best_dist       = dist;\n                best_dist_index = k + k2 + 3;\n              }\n            }\n          }\n        }\n        for (int k = end_ka; k < end_k; k++) {\n          float x2   = buf[k * 3 + 0] - x1;\n          float y2   = buf[k * 3 + 1] - y1;\n          float z2   = buf[k * 3 + 2] - z1;\n          float dist = x2 * x2 + y2 * y2 + z2 * z2;\n          if (k == 0 || dist < best_dist) {\n            best_dist       = dist;\n            best_dist_index = k + k2;\n          }\n        }\n        if (k2 == 0 || dist[(i * n + j)] > best_dist) {\n          dist[(i * n + j)]    = best_dist;\n          indexes[(i * n + j)] = best_dist_index;\n        }\n      }\n      __syncthreads();\n    }\n  }\n}\n\nstd::vector<torch::Tensor> chamfer_cuda_forward(torch::Tensor xyz1,\n                                                torch::Tensor xyz2) {\n  const int batch_size = xyz1.size(0);\n  const int n          = xyz1.size(1);  // num_points point cloud A\n  const int m          = xyz2.size(1);  // num_points point cloud B\n  torch::Tensor dist1 =\n    torch::zeros({batch_size, n}, torch::CUDA(torch::kFloat));\n  torch::Tensor dist2 =\n    torch::zeros({batch_size, m}, torch::CUDA(torch::kFloat));\n  torch::Tensor idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt));\n  torch::Tensor idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt));\n\n  chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(\n    batch_size, n, xyz1.data_ptr<float>(), m, xyz2.data_ptr<float>(),\n    dist1.data_ptr<float>(), idx1.data_ptr<int>());\n  chamfer_dist_kernel<<<dim3(32, 16, 1), 512>>>(\n    batch_size, m, xyz2.data_ptr<float>(), n, xyz1.data_ptr<float>(),\n    dist2.data_ptr<float>(), idx2.data_ptr<int>());\n\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess) {\n    printf(\"Error in chamfer_cuda_forward: %s\\n\", cudaGetErrorString(err));\n  }\n  return {dist1, dist2, idx1, idx2};\n}\n\n__global__ void chamfer_dist_grad_kernel(int b,\n                                         int n,\n                                         const float* xyz1,\n                                         int m,\n                                         const float* xyz2,\n                                         const float* grad_dist1,\n                                         const int* idx1,\n                                         float* grad_xyz1,\n                                         float* grad_xyz2) {\n  for (int i = blockIdx.x; i < b; i += gridDim.x) {\n    for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n;\n         j += blockDim.x * 
gridDim.y) {\n      float x1 = xyz1[(i * n + j) * 3 + 0];\n      float y1 = xyz1[(i * n + j) * 3 + 1];\n      float z1 = xyz1[(i * n + j) * 3 + 2];\n      int j2   = idx1[i * n + j];\n      float x2 = xyz2[(i * m + j2) * 3 + 0];\n      float y2 = xyz2[(i * m + j2) * 3 + 1];\n      float z2 = xyz2[(i * m + j2) * 3 + 2];\n      float g  = grad_dist1[i * n + j] * 2;\n      atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2));\n      atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2));\n      atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2));\n      atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2)));\n      atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2)));\n      atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2)));\n    }\n  }\n}\n\nstd::vector<torch::Tensor> chamfer_cuda_backward(torch::Tensor xyz1,\n                                                 torch::Tensor xyz2,\n                                                 torch::Tensor idx1,\n                                                 torch::Tensor idx2,\n                                                 torch::Tensor grad_dist1,\n                                                 torch::Tensor grad_dist2) {\n  const int batch_size    = xyz1.size(0);\n  const int n             = xyz1.size(1);  // num_points point cloud A\n  const int m             = xyz2.size(1);  // num_points point cloud B\n  torch::Tensor grad_xyz1 = torch::zeros_like(xyz1, torch::CUDA(torch::kFloat));\n  torch::Tensor grad_xyz2 = torch::zeros_like(xyz2, torch::CUDA(torch::kFloat));\n\n  chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>(\n    batch_size, n, xyz1.data_ptr<float>(), m, xyz2.data_ptr<float>(),\n    grad_dist1.data_ptr<float>(), idx1.data_ptr<int>(),\n    grad_xyz1.data_ptr<float>(), grad_xyz2.data_ptr<float>());\n  chamfer_dist_grad_kernel<<<dim3(1, 16, 1), 256>>>(\n    batch_size, m, xyz2.data_ptr<float>(), n, xyz1.data_ptr<float>(),\n    grad_dist2.data_ptr<float>(), idx2.data_ptr<int>(),\n    grad_xyz2.data_ptr<float>(), grad_xyz1.data_ptr<float>());\n\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess) {\n    printf(\"Error in chamfer_cuda_backward: %s\\n\", cudaGetErrorString(err));\n  }\n  return {grad_xyz1, grad_xyz2};\n}\n"
  },
  {
    "path": "segmentation/extensions/chamfer_dist/chamfer_cuda.cpp",
    "content": "/*\n * @Author: Haozhe Xie\n * @Date:   2019-08-07 20:54:24\n * @Last Modified by:   Haozhe Xie\n * @Last Modified time: 2019-12-10 10:33:50\n * @Email:  cshzxie@gmail.com\n */\n\n#include <torch/extension.h>\n#include <vector>\n\nstd::vector<torch::Tensor> chamfer_cuda_forward(torch::Tensor xyz1,\n                                                torch::Tensor xyz2);\n\nstd::vector<torch::Tensor> chamfer_cuda_backward(torch::Tensor xyz1,\n                                                 torch::Tensor xyz2,\n                                                 torch::Tensor idx1,\n                                                 torch::Tensor idx2,\n                                                 torch::Tensor grad_dist1,\n                                                 torch::Tensor grad_dist2);\n\nstd::vector<torch::Tensor> chamfer_forward(torch::Tensor xyz1,\n                                           torch::Tensor xyz2) {\n  return chamfer_cuda_forward(xyz1, xyz2);\n}\n\nstd::vector<torch::Tensor> chamfer_backward(torch::Tensor xyz1,\n                                            torch::Tensor xyz2,\n                                            torch::Tensor idx1,\n                                            torch::Tensor idx2,\n                                            torch::Tensor grad_dist1,\n                                            torch::Tensor grad_dist2) {\n  return chamfer_cuda_backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"forward\", &chamfer_forward, \"Chamfer forward (CUDA)\");\n  m.def(\"backward\", &chamfer_backward, \"Chamfer backward (CUDA)\");\n}\n"
  },
  {
    "path": "segmentation/extensions/chamfer_dist/setup.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Haozhe Xie\n# @Date:   2019-08-07 20:54:24\n# @Last Modified by:   Haozhe Xie\n# @Last Modified time: 2019-12-10 10:04:25\n# @Email:  cshzxie@gmail.com\n\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nsetup(name='chamfer',\n      version='2.0.0',\n      ext_modules=[\n          CUDAExtension('chamfer', [\n              'chamfer_cuda.cpp',\n              'chamfer.cu',\n          ]),\n      ],\n      cmdclass={'build_ext': BuildExtension})\n"
  },
  {
    "path": "segmentation/extensions/chamfer_dist/test.py",
    "content": "# -*- coding: utf-8 -*-\n# @Author: Haozhe Xie\n# @Date:   2019-12-10 10:38:01\n# @Last Modified by:   Haozhe Xie\n# @Last Modified time: 2019-12-26 14:21:36\n# @Email:  cshzxie@gmail.com\n#\n# Note:\n# - Replace float -> double, kFloat -> kDouble in chamfer.cu\n\nimport os\nimport sys\nimport torch\nimport unittest\n\n\nfrom torch.autograd import gradcheck\n\nsys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)))\nfrom extensions.chamfer_dist import ChamferFunction\n\n\nclass ChamferDistanceTestCase(unittest.TestCase):\n    def test_chamfer_dist(self):\n        x = torch.rand(4, 64, 3).double()\n        y = torch.rand(4, 128, 3).double()\n        x.requires_grad = True\n        y.requires_grad = True\n        print(gradcheck(ChamferFunction.apply, [x.cuda(), y.cuda()]))\n\n\n\nif __name__ == '__main__':\n    # unittest.main()\n    import pdb\n    x = torch.rand(32,128,3)\n    y = torch.rand(32,128,3)\n    pdb.set_trace()\n"
  },
  {
    "path": "segmentation/extensions/emd/README.md",
    "content": "# PyTorch Wrapper for Point-cloud Earth-Mover-Distance (EMD)\n\n## Dependency\n\nThe code has been tested on Ubuntu 16.04, PyTorch 1.1.0, CUDA 9.0.\n\n## Usage\n\nFirst compile using\n        \n        python setup.py install\n\nThen, copy the lib file out to the main directory,\n\n        cp build/lib.linux-x86_64-3.6/emd_cuda.cpython-36m-x86_64-linux-gnu.so .\n\nThen, you can use it by simply\n\n        from emd import earth_mover_distance\n        d = earth_mover_distance(p1, p2, transpose=False)  # p1: B x N1 x 3, p2: B x N2 x 3\n\nCheck `test_emd_loss.py` for example.\n\n## Author\n\nThe cuda code is originally written by Haoqiang Fan. The PyTorch wrapper is written by Kaichun Mo. Also, Jiayuan Gu provided helps.\n\n## License\n\nMIT\n\n"
  },
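  {
    "path": "examples/emd_usage_sketch.py",
    "content": "# NOTE: editor-added sketch; this file is not part of the original PointGPT release.\n# Minimal usage of the EMD wrapper exactly as the README above describes; it assumes\n# the extension has been built (python setup.py install) and a CUDA device is available.\nimport torch\n\nfrom emd import earth_mover_distance\n\np1 = torch.rand(8, 1024, 3).cuda()\np2 = torch.rand(8, 1024, 3).cuda()\n\n# transpose=False because the clouds are (B, N, 3); see test_emd_loss.py in this directory\nd = earth_mover_distance(p1, p2, transpose=False)\nloss = d.mean()\nprint(loss.item())\n"
  },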
  {
    "path": "segmentation/extensions/emd/__init__.py",
    "content": "from .emd import earth_mover_distance as emd\n\n__all__ = ['emd']"
  },
  {
    "path": "segmentation/extensions/emd/cuda/emd.cpp",
    "content": "#ifndef _EMD\n#define _EMD\n\n#include <vector>\n#include <torch/extension.h>\n\n//CUDA declarations\nat::Tensor ApproxMatchForward(\n    const at::Tensor xyz1,\n    const at::Tensor xyz2);\n\nat::Tensor MatchCostForward(\n    const at::Tensor xyz1,\n    const at::Tensor xyz2,\n    const at::Tensor match);\n\nstd::vector<at::Tensor> MatchCostBackward(\n    const at::Tensor grad_cost,\n    const at::Tensor xyz1,\n    const at::Tensor xyz2,\n    const at::Tensor match);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"approxmatch_forward\", &ApproxMatchForward,\"ApproxMatch forward (CUDA)\");\n  m.def(\"matchcost_forward\", &MatchCostForward,\"MatchCost forward (CUDA)\");\n  m.def(\"matchcost_backward\", &MatchCostBackward,\"MatchCost backward (CUDA)\");\n}\n\n#endif\n"
  },
  {
    "path": "segmentation/extensions/emd/cuda/emd_kernel.cu",
    "content": "/**********************************\n * Original Author: Haoqiang Fan\n * Modified by: Kaichun Mo\n *********************************/\n\n#ifndef _EMD_KERNEL\n#define _EMD_KERNEL\n\n#include <cmath>\n#include <vector>\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAApplyUtils.cuh>  // at::cuda::getApplyGrid\n// #include <THC/THC.h>\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\n\n/********************************\n* Forward kernel for approxmatch\n*********************************/\n\ntemplate<typename scalar_t>\n__global__ void approxmatch(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,scalar_t * __restrict__ match,scalar_t * temp){\n\tscalar_t * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;\n\tscalar_t multiL,multiR;\n\tif (n>=m){\n\t\tmultiL=1;\n\t\tmultiR=n/m;\n\t}else{\n\t\tmultiL=m/n;\n\t\tmultiR=1;\n\t}\n\tconst int Block=1024;\n\t__shared__ scalar_t buf[Block*4];\n\tfor (int i=blockIdx.x;i<b;i+=gridDim.x){\n\t\tfor (int j=threadIdx.x;j<n*m;j+=blockDim.x)\n\t\t\tmatch[i*n*m+j]=0;\n\t\tfor (int j=threadIdx.x;j<n;j+=blockDim.x)\n\t\t\tremainL[j]=multiL;\n\t\tfor (int j=threadIdx.x;j<m;j+=blockDim.x)\n\t\t\tremainR[j]=multiR;\n\t\t__syncthreads();\n\t\tfor (int j=7;j>=-2;j--){\n\t\t\tscalar_t level=-powf(4.0f,j);\n\t\t\tif (j==-2){\n\t\t\t\tlevel=0;\n\t\t\t}\n\t\t\tfor (int k0=0;k0<n;k0+=blockDim.x){\n\t\t\t\tint k=k0+threadIdx.x;\n\t\t\t\tscalar_t x1=0,y1=0,z1=0;\n\t\t\t\tif (k<n){\n\t\t\t\t\tx1=xyz1[i*n*3+k*3+0];\n\t\t\t\t\ty1=xyz1[i*n*3+k*3+1];\n\t\t\t\t\tz1=xyz1[i*n*3+k*3+2];\n\t\t\t\t}\n\t\t\t\tscalar_t suml=1e-9f;\n\t\t\t\tfor (int l0=0;l0<m;l0+=Block){\n\t\t\t\t\tint lend=min(m,l0+Block)-l0;\n\t\t\t\t\tfor (int l=threadIdx.x;l<lend;l+=blockDim.x){\n\t\t\t\t\t\tscalar_t x2=xyz2[i*m*3+l0*3+l*3+0];\n\t\t\t\t\t\tscalar_t y2=xyz2[i*m*3+l0*3+l*3+1];\n\t\t\t\t\t\tscalar_t z2=xyz2[i*m*3+l0*3+l*3+2];\n\t\t\t\t\t\tbuf[l*4+0]=x2;\n\t\t\t\t\t\tbuf[l*4+1]=y2;\n\t\t\t\t\t\tbuf[l*4+2]=z2;\n\t\t\t\t\t\tbuf[l*4+3]=remainR[l0+l];\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t\tfor (int l=0;l<lend;l++){\n\t\t\t\t\t\tscalar_t x2=buf[l*4+0];\n\t\t\t\t\t\tscalar_t y2=buf[l*4+1];\n\t\t\t\t\t\tscalar_t z2=buf[l*4+2];\n\t\t\t\t\t\tscalar_t d=level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));\n\t\t\t\t\t\tscalar_t w=__expf(d)*buf[l*4+3];\n\t\t\t\t\t\tsuml+=w;\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t}\n\t\t\t\tif (k<n)\n\t\t\t\t\tratioL[k]=remainL[k]/suml;\n\t\t\t}\n\t\t\t__syncthreads();\n\t\t\tfor (int l0=0;l0<m;l0+=blockDim.x){\n\t\t\t\tint l=l0+threadIdx.x;\n\t\t\t\tscalar_t x2=0,y2=0,z2=0;\n\t\t\t\tif (l<m){\n\t\t\t\t\tx2=xyz2[i*m*3+l*3+0];\n\t\t\t\t\ty2=xyz2[i*m*3+l*3+1];\n\t\t\t\t\tz2=xyz2[i*m*3+l*3+2];\n\t\t\t\t}\n\t\t\t\tscalar_t sumr=0;\n\t\t\t\tfor (int k0=0;k0<n;k0+=Block){\n\t\t\t\t\tint kend=min(n,k0+Block)-k0;\n\t\t\t\t\tfor (int k=threadIdx.x;k<kend;k+=blockDim.x){\n\t\t\t\t\t\tbuf[k*4+0]=xyz1[i*n*3+k0*3+k*3+0];\n\t\t\t\t\t\tbuf[k*4+1]=xyz1[i*n*3+k0*3+k*3+1];\n\t\t\t\t\t\tbuf[k*4+2]=xyz1[i*n*3+k0*3+k*3+2];\n\t\t\t\t\t\tbuf[k*4+3]=ratioL[k0+k];\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t\tfor (int k=0;k<kend;k++){\n\t\t\t\t\t\tscalar_t x1=buf[k*4+0];\n\t\t\t\t\t\tscalar_t y1=buf[k*4+1];\n\t\t\t\t\t\tscalar_t 
z1=buf[k*4+2];\n\t\t\t\t\t\tscalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*buf[k*4+3];\n\t\t\t\t\t\tsumr+=w;\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t}\n\t\t\t\tif (l<m){\n\t\t\t\t\tsumr*=remainR[l];\n\t\t\t\t\tscalar_t consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);\n\t\t\t\t\tratioR[l]=consumption*remainR[l];\n\t\t\t\t\tremainR[l]=fmaxf(0.0f,remainR[l]-sumr);\n\t\t\t\t}\n\t\t\t}\n\t\t\t__syncthreads();\n\t\t\tfor (int k0=0;k0<n;k0+=blockDim.x){\n\t\t\t\tint k=k0+threadIdx.x;\n\t\t\t\tscalar_t x1=0,y1=0,z1=0;\n\t\t\t\tif (k<n){\n\t\t\t\t\tx1=xyz1[i*n*3+k*3+0];\n\t\t\t\t\ty1=xyz1[i*n*3+k*3+1];\n\t\t\t\t\tz1=xyz1[i*n*3+k*3+2];\n\t\t\t\t}\n\t\t\t\tscalar_t suml=0;\n\t\t\t\tfor (int l0=0;l0<m;l0+=Block){\n\t\t\t\t\tint lend=min(m,l0+Block)-l0;\n\t\t\t\t\tfor (int l=threadIdx.x;l<lend;l+=blockDim.x){\n\t\t\t\t\t\tbuf[l*4+0]=xyz2[i*m*3+l0*3+l*3+0];\n\t\t\t\t\t\tbuf[l*4+1]=xyz2[i*m*3+l0*3+l*3+1];\n\t\t\t\t\t\tbuf[l*4+2]=xyz2[i*m*3+l0*3+l*3+2];\n\t\t\t\t\t\tbuf[l*4+3]=ratioR[l0+l];\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t\tscalar_t rl=ratioL[k];\n\t\t\t\t\tif (k<n){\n\t\t\t\t\t\tfor (int l=0;l<lend;l++){\n\t\t\t\t\t\t\tscalar_t x2=buf[l*4+0];\n\t\t\t\t\t\t\tscalar_t y2=buf[l*4+1];\n\t\t\t\t\t\t\tscalar_t z2=buf[l*4+2];\n\t\t\t\t\t\t\tscalar_t w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*rl*buf[l*4+3];\n\t\t\t\t\t\t\tmatch[i*n*m+(l0+l)*n+k]+=w;\n\t\t\t\t\t\t\tsuml+=w;\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t}\n\t\t\t\tif (k<n)\n\t\t\t\t\tremainL[k]=fmaxf(0.0f,remainL[k]-suml);\n\t\t\t}\n\t\t\t__syncthreads();\n\t\t}\n\t}\n}\n\n//void approxmatchLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,scalar_t * match,scalar_t * temp){\n//\tapproxmatch<<<32,512>>>(b,n,m,xyz1,xyz2,match,temp);\n//}\n\n/* ApproxMatch forward interface\nInput:\n  xyz1: (B, N1, 3)  # dataset_points\n  xyz2: (B, N2, 3)  # query_points\nOutput:\n  match: (B, N2, N1)\n*/\nat::Tensor ApproxMatchForward(\n    const at::Tensor xyz1,\n    const at::Tensor xyz2){\n  const auto b = xyz1.size(0);\n  const auto n = xyz1.size(1);\n  const auto m = xyz2.size(1);\n\n  CHECK_EQ(xyz2.size(0), b);\n  CHECK_EQ(xyz1.size(2), 3);\n  CHECK_EQ(xyz2.size(2), 3);\n  CHECK_INPUT(xyz1);\n  CHECK_INPUT(xyz2);\n\n  auto match = at::zeros({b, m, n}, xyz1.type());\n  auto temp = at::zeros({b, (n+m)*2}, xyz1.type());\n\n  AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), \"ApproxMatchForward\", ([&] {\n        approxmatch<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), temp.data<scalar_t>());\n  }));\n  AT_CUDA_CHECK(cudaGetLastError());\n\n  return match;\n}\n\n\n/********************************\n* Forward kernel for matchcost\n*********************************/\n\ntemplate<typename scalar_t>\n__global__ void matchcost(int b,int n,int m,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ out){\n\t__shared__ scalar_t allsum[512];\n\tconst int Block=1024;\n\t__shared__ scalar_t buf[Block*3];\n\tfor (int i=blockIdx.x;i<b;i+=gridDim.x){\n\t\tscalar_t subsum=0;\n\t\tfor (int k0=0;k0<n;k0+=blockDim.x){\n\t\t\tint k=k0+threadIdx.x;\n\t\t\tscalar_t x1=0,y1=0,z1=0;\n\t\t\tif (k<n){\n\t\t\t\tx1=xyz1[i*n*3+k*3+0];\n\t\t\t\ty1=xyz1[i*n*3+k*3+1];\n\t\t\t\tz1=xyz1[i*n*3+k*3+2];\n\t\t\t}\n\t\t\tfor (int l0=0;l0<m;l0+=Block){\n\t\t\t\tint lend=min(m,l0+Block)-l0;\n\t\t\t\tfor (int 
l=threadIdx.x;l<lend*3;l+=blockDim.x)\n\t\t\t\t\tbuf[l]=xyz2[i*m*3+l0*3+l];\n\t\t\t\t__syncthreads();\n\t\t\t\tif (k<n){\n\t\t\t\t\tfor (int l=0;l<lend;l++){\n\t\t\t\t\t\tscalar_t x2=buf[l*3+0];\n\t\t\t\t\t\tscalar_t y2=buf[l*3+1];\n\t\t\t\t\t\tscalar_t z2=buf[l*3+2];\n\t\t\t\t\t\tscalar_t d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);\n\t\t\t\t\t\tsubsum+=d*match[i*n*m+(l0+l)*n+k];\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\t__syncthreads();\n\t\t\t}\n\t\t}\n\t\tallsum[threadIdx.x]=subsum;\n\t\tfor (int j=1;j<blockDim.x;j<<=1){\n\t\t\t__syncthreads();\n\t\t\tif ((threadIdx.x&j)==0 && threadIdx.x+j<blockDim.x){\n\t\t\t\tallsum[threadIdx.x]+=allsum[threadIdx.x+j];\n\t\t\t}\n\t\t}\n\t\tif (threadIdx.x==0)\n\t\t\tout[i]=allsum[0];\n\t\t__syncthreads();\n\t}\n}\n\n//void matchcostLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * out){\n//\tmatchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);\n//}\n\n/* MatchCost forward interface\nInput:\n  xyz1: (B, N1, 3)  # dataset_points\n  xyz2: (B, N2, 3)  # query_points\n  match: (B, N2, N1)\nOutput:\n  cost: (B)\n*/\nat::Tensor MatchCostForward(\n    const at::Tensor xyz1,\n    const at::Tensor xyz2,\n    const at::Tensor match){\n  const auto b = xyz1.size(0);\n  const auto n = xyz1.size(1);\n  const auto m = xyz2.size(1);\n\n  CHECK_EQ(xyz2.size(0), b);\n  CHECK_EQ(xyz1.size(2), 3);\n  CHECK_EQ(xyz2.size(2), 3);\n  CHECK_INPUT(xyz1);\n  CHECK_INPUT(xyz2);\n\n  auto cost = at::zeros({b}, xyz1.type());\n\n  AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), \"MatchCostForward\", ([&] {\n        matchcost<scalar_t><<<32,512>>>(b, n, m, xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), cost.data<scalar_t>());\n  }));\n  AT_CUDA_CHECK(cudaGetLastError());\n\n  return cost;\n}\n\n\n/********************************\n* matchcostgrad2 kernel\n*********************************/\n\ntemplate<typename scalar_t>\n__global__ void matchcostgrad2(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad2){\n\t__shared__ scalar_t sum_grad[256*3];\n\tfor (int i=blockIdx.x;i<b;i+=gridDim.x){\n\t\tint kbeg=m*blockIdx.y/gridDim.y;\n\t\tint kend=m*(blockIdx.y+1)/gridDim.y;\n\t\tfor (int k=kbeg;k<kend;k++){\n\t\t\tscalar_t x2=xyz2[(i*m+k)*3+0];\n\t\t\tscalar_t y2=xyz2[(i*m+k)*3+1];\n\t\t\tscalar_t z2=xyz2[(i*m+k)*3+2];\n\t\t\tscalar_t subsumx=0,subsumy=0,subsumz=0;\n\t\t\tfor (int j=threadIdx.x;j<n;j+=blockDim.x){\n\t\t\t\tscalar_t x1=x2-xyz1[(i*n+j)*3+0];\n\t\t\t\tscalar_t y1=y2-xyz1[(i*n+j)*3+1];\n\t\t\t\tscalar_t z1=z2-xyz1[(i*n+j)*3+2];\n\t\t\t\tscalar_t d=match[i*n*m+k*n+j]*2;\n\t\t\t\tsubsumx+=x1*d;\n\t\t\t\tsubsumy+=y1*d;\n\t\t\t\tsubsumz+=z1*d;\n\t\t\t}\n\t\t\tsum_grad[threadIdx.x*3+0]=subsumx;\n\t\t\tsum_grad[threadIdx.x*3+1]=subsumy;\n\t\t\tsum_grad[threadIdx.x*3+2]=subsumz;\n\t\t\tfor (int j=1;j<blockDim.x;j<<=1){\n\t\t\t\t__syncthreads();\n\t\t\t\tint j1=threadIdx.x;\n\t\t\t\tint j2=threadIdx.x+j;\n\t\t\t\tif ((j1&j)==0 && j2<blockDim.x){\n\t\t\t\t\tsum_grad[j1*3+0]+=sum_grad[j2*3+0];\n\t\t\t\t\tsum_grad[j1*3+1]+=sum_grad[j2*3+1];\n\t\t\t\t\tsum_grad[j1*3+2]+=sum_grad[j2*3+2];\n\t\t\t\t}\n\t\t\t}\n\t\t\tif (threadIdx.x==0){\n\t\t\t\tgrad2[(i*m+k)*3+0]=sum_grad[0]*grad_cost[i];\n\t\t\t\tgrad2[(i*m+k)*3+1]=sum_grad[1]*grad_cost[i];\n\t\t\t\tgrad2[(i*m+k)*3+2]=sum_grad[2]*grad_cost[i];\n\t\t\t}\n\t\t\t__syncthreads();\n\t\t}\n\t}\n}\n\n/********************************\n* 
matchcostgrad1 kernel\n*********************************/\n\ntemplate<typename scalar_t>\n__global__ void matchcostgrad1(int b,int n,int m,const scalar_t * __restrict__ grad_cost,const scalar_t * __restrict__ xyz1,const scalar_t * __restrict__ xyz2,const scalar_t * __restrict__ match,scalar_t * __restrict__ grad1){\n\tfor (int i=blockIdx.x;i<b;i+=gridDim.x){\n\t\tfor (int l=threadIdx.x;l<n;l+=blockDim.x){\n\t\t\tscalar_t x1=xyz1[i*n*3+l*3+0];\n\t\t\tscalar_t y1=xyz1[i*n*3+l*3+1];\n\t\t\tscalar_t z1=xyz1[i*n*3+l*3+2];\n\t\t\tscalar_t dx=0,dy=0,dz=0;\n\t\t\tfor (int k=0;k<m;k++){\n\t\t\t\tscalar_t x2=xyz2[i*m*3+k*3+0];\n\t\t\t\tscalar_t y2=xyz2[i*m*3+k*3+1];\n\t\t\t\tscalar_t z2=xyz2[i*m*3+k*3+2];\n\t\t\t\tscalar_t d=match[i*n*m+k*n+l]*2;\n\t\t\t\tdx+=(x1-x2)*d;\n\t\t\t\tdy+=(y1-y2)*d;\n\t\t\t\tdz+=(z1-z2)*d;\n\t\t\t}\n\t\t\tgrad1[i*n*3+l*3+0]=dx*grad_cost[i];\n\t\t\tgrad1[i*n*3+l*3+1]=dy*grad_cost[i];\n\t\t\tgrad1[i*n*3+l*3+2]=dz*grad_cost[i];\n\t\t}\n\t}\n}\n\n//void matchcostgradLauncher(int b,int n,int m,const scalar_t * xyz1,const scalar_t * xyz2,const scalar_t * match,scalar_t * grad1,scalar_t * grad2){\n//\tmatchcostgrad1<<<32,512>>>(b,n,m,xyz1,xyz2,match,grad1);\n//\tmatchcostgrad2<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);\n//}\n\n\n/* MatchCost backward interface\nInput:\n  grad_cost: (B)    # gradients on cost\n  xyz1: (B, N1, 3)  # dataset_points\n  xyz2: (B, N2, 3)  # query_points\n  match: (B, N2, N1)\nOutput:\n  grad1: (B, N1, 3)\n  grad2: (B, N2, 3)\n*/\nstd::vector<at::Tensor> MatchCostBackward(\n    const at::Tensor grad_cost,\n    const at::Tensor xyz1,\n    const at::Tensor xyz2,\n    const at::Tensor match){\n  const auto b = xyz1.size(0);\n  const auto n = xyz1.size(1);\n  const auto m = xyz2.size(1);\n\n  CHECK_EQ(xyz2.size(0), b);\n  CHECK_EQ(xyz1.size(2), 3);\n  CHECK_EQ(xyz2.size(2), 3);\n  CHECK_INPUT(xyz1);\n  CHECK_INPUT(xyz2);\n\n  auto grad1 = at::zeros({b, n, 3}, xyz1.type());\n  auto grad2 = at::zeros({b, m, 3}, xyz1.type());\n\n  AT_DISPATCH_FLOATING_TYPES(xyz1.scalar_type(), \"MatchCostBackward\", ([&] {\n        matchcostgrad1<scalar_t><<<32,512>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad1.data<scalar_t>());\n        matchcostgrad2<scalar_t><<<dim3(32,32),256>>>(b, n, m, grad_cost.data<scalar_t>(), xyz1.data<scalar_t>(), xyz2.data<scalar_t>(), match.data<scalar_t>(), grad2.data<scalar_t>());\n  }));\n  AT_CUDA_CHECK(cudaGetLastError());\n\n  return std::vector<at::Tensor>({grad1, grad2});\n}\n\n#endif\n"
  },
  {
    "path": "segmentation/extensions/emd/emd.py",
    "content": "import torch\nimport emd_cuda\n\n\nclass EarthMoverDistanceFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, xyz1, xyz2):\n        xyz1 = xyz1.contiguous()\n        xyz2 = xyz2.contiguous()\n        assert xyz1.is_cuda and xyz2.is_cuda, \"Only support cuda currently.\"\n        match = emd_cuda.approxmatch_forward(xyz1, xyz2)\n        cost = emd_cuda.matchcost_forward(xyz1, xyz2, match)\n        ctx.save_for_backward(xyz1, xyz2, match)\n        return cost\n\n    @staticmethod\n    def backward(ctx, grad_cost):\n        xyz1, xyz2, match = ctx.saved_tensors\n        grad_cost = grad_cost.contiguous()\n        grad_xyz1, grad_xyz2 = emd_cuda.matchcost_backward(grad_cost, xyz1, xyz2, match)\n        return grad_xyz1, grad_xyz2\n\n\n\n\nclass earth_mover_distance(torch.nn.Module):\n    f''' emd\n    '''\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, xyz1, xyz2, transpose=False):\n        \"\"\"Earth Mover Distance (Approx)\n\n        Args:\n            xyz1 (torch.Tensor): (b, n1, 3)\n            xyz2 (torch.Tensor): (b, n2, 3)\n            transpose (bool): whether to transpose inputs as it might be BCN format.\n                Extensions only support BNC format.\n\n        Returns:\n            cost (torch.Tensor): (b)\n\n        \"\"\"\n\n        cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)\n        cost = cost / xyz1.size(1)\n        \n        return cost.mean()\n# def earth_mover_distance(xyz1, xyz2, transpose=True):\n#     \"\"\"Earth Mover Distance (Approx)\n\n#     Args:\n#         xyz1 (torch.Tensor): (b, 3, n1)\n#         xyz2 (torch.Tensor): (b, 3, n1)\n#         transpose (bool): whether to transpose inputs as it might be BCN format.\n#             Extensions only support BNC format.\n\n#     Returns:\n#         cost (torch.Tensor): (b)\n\n#     \"\"\"\n#     if xyz1.dim() == 2:\n#         xyz1 = xyz1.unsqueeze(0)\n#     if xyz2.dim() == 2:\n#         xyz2 = xyz2.unsqueeze(0)\n#     if transpose:\n#         xyz1 = xyz1.transpose(1, 2)\n#         xyz2 = xyz2.transpose(1, 2)\n#     cost = EarthMoverDistanceFunction.apply(xyz1, xyz2)\n#     return cost\n\n"
  },
  {
    "path": "segmentation/extensions/emd/setup.py",
    "content": "\"\"\"Setup extension\n\nNotes:\n    If extra_compile_args is provided, you need to provide different instances for different extensions.\n    Refer to https://github.com/pytorch/pytorch/issues/20169\n\n\"\"\"\n\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n\nsetup(\n    name='emd_ext',\n    ext_modules=[\n        CUDAExtension(\n            name='emd_cuda',\n            sources=[\n                'cuda/emd.cpp',\n                'cuda/emd_kernel.cu',\n            ],\n            extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}\n        ),\n    ],\n    cmdclass={\n        'build_ext': BuildExtension\n    })\n"
  },
  {
    "path": "segmentation/extensions/emd/test_emd_loss.py",
    "content": "import torch\nimport numpy as np\nimport time\nfrom emd import earth_mover_distance\n\n# gt\np1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()\np1 = p1.repeat(3, 1, 1)\np2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()\np2 = p2.repeat(3, 1, 1)\nprint(p1)\nprint(p2)\nprint(p1.shape)\np1.requires_grad = True\np2.requires_grad = True\n\ngt_dist = (((p1[0, 0] - p2[0, 1])**2).sum() + ((p1[0, 1] - p2[0, 0])**2).sum()) / 2 +  \\\n         (((p1[1, 0] - p2[1, 1])**2).sum() + ((p1[1, 1] - p2[1, 0])**2).sum()) * 2 + \\\n         (((p1[2, 0] - p2[2, 1])**2).sum() + ((p1[2, 1] - p2[2, 0])**2).sum()) / 3\nprint('gt_dist: ', gt_dist)\n\ngt_dist.backward()\nprint(p1.grad)\nprint(p2.grad)\n\n# emd\np1 = torch.from_numpy(np.array([[[1.7, -0.1, 0.1], [0.1, 1.2, 0.3]]], dtype=np.float32)).cuda()\np1 = p1.repeat(3, 1, 1)\np2 = torch.from_numpy(np.array([[[0.3, 1.8, 0.2], [1.2, -0.2, 0.3]]], dtype=np.float32)).cuda()\np2 = p2.repeat(3, 1, 1)\nprint(p1)\nprint(p2)\np1.requires_grad = True\np2.requires_grad = True\n\nd = earth_mover_distance(p1, p2, transpose=False)\nprint(d)\n\nloss = d[0] / 2 + d[1] * 2 + d[2] / 3\nprint(loss)\n\nloss.backward()\nprint(p1.grad)\nprint(p2.grad)\n\n"
  },
  {
    "path": "segmentation/logger.py",
    "content": "import logging\nimport torch.distributed as dist\n\nimport copy\nimport logging\nimport os\nfrom collections import defaultdict\nimport torch\nimport torch.nn as nn\n\nfrom typing import Any\nfrom typing import Optional, List, Dict, NamedTuple, Tuple, Iterable\n\nfrom termcolor import colored\n\nlogger_initialized = {}\n\ndef get_root_logger(log_file=None, log_level=logging.INFO, name='main'):\n    \"\"\"Get root logger and add a keyword filter to it.\n    The logger will be initialized if it has not been initialized. By default a\n    StreamHandler will be added. If `log_file` is specified, a FileHandler will\n    also be added. The name of the root logger is the top-level package name,\n    e.g., \"mmdet3d\".\n    Args:\n        log_file (str, optional): File path of log. Defaults to None.\n        log_level (int, optional): The level of logger.\n            Defaults to logging.INFO.\n        name (str, optional): The name of the root logger, also used as a\n            filter keyword. Defaults to 'mmdet3d'.\n    Returns:\n        :obj:`logging.Logger`: The obtained logger\n    \"\"\"\n    logger = get_logger(name=name, log_file=log_file, log_level=log_level)\n    # add a logging filter\n    logging_filter = logging.Filter(name)\n    logging_filter.filter = lambda record: record.find(name) != -1\n\n    return logger\n\n\ndef get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):\n    \"\"\"Initialize and get a logger by name.\n    If the logger has not been initialized, this method will initialize the\n    logger by adding one or two handlers, otherwise the initialized logger will\n    be directly returned. During initialization, a StreamHandler will always be\n    added. If `log_file` is specified and the process rank is 0, a FileHandler\n    will also be added.\n    Args:\n        name (str): Logger name.\n        log_file (str | None): The log filename. If specified, a FileHandler\n            will be added to the logger.\n        log_level (int): The logger level. Note that only the process of\n            rank 0 is affected, and other processes will set the level to\n            \"Error\" thus be silent most of the time.\n        file_mode (str): The file mode used in opening log file.\n            Defaults to 'w'.\n    Returns:\n        logging.Logger: The expected logger.\n    \"\"\"\n    logger = logging.getLogger(name)\n    if name in logger_initialized:\n        return logger\n    # handle hierarchical names\n    # e.g., logger \"a\" is initialized, then logger \"a.b\" will skip the\n    # initialization since it is a child of \"a\".\n    for logger_name in logger_initialized:\n        if name.startswith(logger_name):\n            return logger\n\n    # handle duplicate logs to the console\n    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)\n    # to the root logger. 
As logger.propagate is True by default, this root\n    # level handler causes logging messages from rank>0 processes to\n    # unexpectedly show up on the console, creating much unwanted clutter.\n    # To fix this issue, we set the root logger's StreamHandler, if any, to log\n    # at the ERROR level.\n    for handler in logger.root.handlers:\n        if type(handler) is logging.StreamHandler:\n            handler.setLevel(logging.ERROR)\n\n    stream_handler = logging.StreamHandler()\n    handlers = [stream_handler]\n\n    if dist.is_available() and dist.is_initialized():\n        rank = dist.get_rank()\n    else:\n        rank = 0\n\n    # only rank 0 will add a FileHandler\n    if rank == 0 and log_file is not None:\n        # Here, the default behaviour of the official logger is 'a'. Thus, we\n        # provide an interface to change the file mode to the default\n        # behaviour.\n        file_handler = logging.FileHandler(log_file, file_mode)\n        handlers.append(file_handler)\n\n    formatter = logging.Formatter(\n        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n    for handler in handlers:\n        handler.setFormatter(formatter)\n        handler.setLevel(log_level)\n        logger.addHandler(handler)\n\n    if rank == 0:\n        logger.setLevel(log_level)\n    else:\n        logger.setLevel(logging.ERROR)\n\n    logger_initialized[name] = True\n\n\n    return logger\n\n\ndef print_log(msg, logger=None, level=logging.INFO):\n    \"\"\"Print a log message.\n    Args:\n        msg (str): The message to be logged.\n        logger (logging.Logger | str | None): The logger to be used.\n            Some special loggers are:\n            - \"silent\": no message will be printed.\n            - other str: the logger obtained with `get_root_logger(logger)`.\n            - None: The `print()` method will be used to print log messages.\n        level (int): Logging level. 
Only available when `logger` is a Logger\n            object or \"root\".\n    \"\"\"\n    if logger is None:\n        print(msg)\n    elif isinstance(logger, logging.Logger):\n        logger.log(level, msg)\n    elif logger == 'silent':\n        pass\n    elif isinstance(logger, str):\n        _logger = get_logger(logger)\n        _logger.log(level, msg)\n    else:\n        raise TypeError(\n            'logger should be either a logging.Logger object, str, '\n            f'\"silent\" or None, but got {type(logger)}')\n\ndef get_missing_parameters_message(keys: List[str]) -> str:\n    \"\"\"\n    Get a logging-friendly message to report parameter names (keys) that are in\n    the model but not found in a checkpoint.\n    Args:\n        keys (list[str]): List of keys that were not found in the checkpoint.\n    Returns:\n        str: message.\n    \"\"\"\n    groups = _group_checkpoint_keys(keys)\n    msg = \"Some model parameters or buffers are not found in the checkpoint:\\n\"\n    msg += \"\\n\".join(\n        \"  \" + colored(k + _group_to_str(v), \"blue\") for k, v in groups.items()\n    )\n    return msg\n\n\ndef get_unexpected_parameters_message(keys: List[str]) -> str:\n    \"\"\"\n    Get a logging-friendly message to report parameter names (keys) that are in\n    the checkpoint but not found in the model.\n    Args:\n        keys (list[str]): List of keys that were not found in the model.\n    Returns:\n        str: message.\n    \"\"\"\n    groups = _group_checkpoint_keys(keys)\n    msg = \"The checkpoint state_dict contains keys that are not used by the model:\\n\"\n    msg += \"\\n\".join(\n        \"  \" + colored(k + _group_to_str(v), \"magenta\") for k, v in groups.items()\n    )\n    return msg\n\n\ndef _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:\n    \"\"\"\n    Strip the prefix in metadata, if any.\n    Args:\n        state_dict (OrderedDict): a state-dict to be loaded to the model.\n        prefix (str): prefix.\n    \"\"\"\n    keys = sorted(state_dict.keys())\n    if not all(len(key) == 0 or key.startswith(prefix) for key in keys):\n        return\n\n    for key in keys:\n        newkey = key[len(prefix):]\n        state_dict[newkey] = state_dict.pop(key)\n\n    # also strip the prefix in metadata, if any..\n    try:\n        metadata = state_dict._metadata  # pyre-ignore\n    except AttributeError:\n        pass\n    else:\n        for key in list(metadata.keys()):\n            # for the metadata dict, the key can be:\n            # '': for the DDP module, which we want to remove.\n            # 'module': for the actual model.\n            # 'module.xx.xx': for the rest.\n\n            if len(key) == 0:\n                continue\n            newkey = key[len(prefix):]\n            metadata[newkey] = metadata.pop(key)\n\n\ndef _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:\n    \"\"\"\n    Group keys based on common prefixes. A prefix is the string up to the final\n    \".\" in each key.\n    Args:\n        keys (list[str]): list of parameter names, i.e. 
keys in the model\n            checkpoint dict.\n    Returns:\n        dict[list]: keys with common prefixes are grouped into lists.\n    \"\"\"\n    groups = defaultdict(list)\n    for key in keys:\n        pos = key.rfind(\".\")\n        if pos >= 0:\n            head, tail = key[:pos], [key[pos + 1:]]\n        else:\n            head, tail = key, []\n        groups[head].extend(tail)\n    return groups\n\n\ndef _group_to_str(group: List[str]) -> str:\n    \"\"\"\n    Format a group of parameter name suffixes into a loggable string.\n    Args:\n        group (list[str]): list of parameter name suffixes.\n    Returns:\n        str: formatted string.\n    \"\"\"\n    if len(group) == 0:\n        return \"\"\n\n    if len(group) == 1:\n        return \".\" + group[0]\n\n    return \".{\" + \", \".join(group) + \"}\"\n\n\ndef _named_modules_with_dup(\n        model: nn.Module, prefix: str = \"\"\n) -> Iterable[Tuple[str, nn.Module]]:\n    \"\"\"\n    The same as `model.named_modules()`, except that it includes\n    duplicated modules that have more than one name.\n    \"\"\"\n    yield prefix, model\n    for name, module in model._modules.items():  # pyre-ignore\n        if module is None:\n            continue\n        submodule_prefix = prefix + (\".\" if prefix else \"\") + name\n        yield from _named_modules_with_dup(module, submodule_prefix)
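\n\n\nif __name__ == '__main__':\n    # Minimal usage sketch (added): build the rank-aware logger and route a\n    # message through print_log.\n    _logger = get_root_logger(log_file=None, name='main')\n    print_log('logger initialized', logger=_logger)\n"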
  },
  {
    "path": "segmentation/main.py",
    "content": "\"\"\"\nAuthor: Benny\nDate: Nov 2019\n\"\"\"\nimport argparse\nimport os\nimport torch\nimport datetime\nimport logging\nimport sys\nimport importlib\nimport shutil\nimport provider\nimport numpy as np\nimport torch.optim as optim\nfrom timm.scheduler import CosineLRScheduler\nfrom pathlib import Path\nfrom tqdm import tqdm\nfrom dataset import PartNormalDataset\n\nBASE_DIR = os.path.dirname(os.path.abspath(__file__))\nROOT_DIR = BASE_DIR\nsys.path.append(os.path.join(ROOT_DIR, 'models'))\n\nseg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],\n               'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37],\n               'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49],\n               'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}\nseg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}\nfor cat in seg_classes.keys():\n    for label in seg_classes[cat]:\n        seg_label_to_cat[label] = cat\n\n\ndef inplace_relu(m):\n    classname = m.__class__.__name__\n    if classname.find('ReLU') != -1:\n        m.inplace = True\n\n\ndef to_categorical(y, num_classes):\n    \"\"\" 1-hot encodes a tensor \"\"\"\n    new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]\n    if (y.is_cuda):\n        return new_y.cuda()\n    return new_y\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser('Model')\n    parser.add_argument('--model', type=str, default='pt')\n    parser.add_argument('--model_name', type=str, default='PointGPT_S', choices=['PointGPT_S', 'PointGPT_B', 'PointGPT_L'])\n    parser.add_argument('--batch_size', type=int, default=32, # 16, 32\n                        help='batch Size during training')\n    parser.add_argument('--epoch', default=300, type=int, help='epoch to run')\n    parser.add_argument('--warmup_epoch', default=30,\n                        type=int, help='warmup epoch')\n    parser.add_argument('--learning_rate', default=0.0002,\n                        type=float, help='initial learning rate')\n    parser.add_argument('--gpu', type=str, default='1',\n                        help='specify GPU devices')\n    # parser.add_argument('--optimizer', type=str, default='AdamW', help='Adam or SGD')\n    parser.add_argument('--log_dir', type=str,\n                        default='./exp', help='log path')\n    # parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay')\n    parser.add_argument('--npoint', type=int,\n                        default=2048, help='point Number')\n    parser.add_argument('--normal', action='store_true',\n                        default=False, help='use normals')\n\n    # parser.add_argument('--step_size', type=int, default=20, help='decay step for lr decay')\n    # parser.add_argument('--lr_decay', type=float, default=0.5, help='decay rate for lr decay')\n    parser.add_argument(\n        '--ckpts', type=str, default='../best/pretrain/m0.6R_1_pretrain300.pth', help='ckpts')\n    parser.add_argument(\n        '--root', type=str, default='data/ShapenetPart/shapenetcore_partanno_segmentation_benchmark_v0_normal/', help='data root')\n    return parser.parse_args()\n\ndef get_model_loss(MODEL, args, num_part):\n    if args.model_name == 'PointGPT_S':\n        classifier = MODEL.get_model(num_part, trans_dim=384, depth=12, drop_path_rate=0.1, num_heads=6, decoder_depth=4, group_size=32, num_group=128, prop_dim=1024, 
label_dim1=512, label_dim2=256, encoder_dims=384)\n        classifier = classifier.cuda()\n        criterion = MODEL.get_loss().cuda()\n        classifier.apply(inplace_relu)  \n    elif args.model_name == 'PointGPT_B':\n        classifier = MODEL.get_model(num_part, trans_dim=768, depth=12, drop_path_rate=0.1, num_heads=12, decoder_depth=4, group_size=32, num_group=128, prop_dim=2048, label_dim1=1024, label_dim2=512, encoder_dims=768)\n        classifier = classifier.cuda()\n        criterion = MODEL.get_loss().cuda()\n        classifier.apply(inplace_relu)  \n    elif args.model_name == 'PointGPT_L':\n        classifier = MODEL.get_model(num_part, trans_dim=1024, depth=24, drop_path_rate=0.1, num_heads=16, decoder_depth=4, group_size=32, num_group=128, prop_dim=2048, label_dim1=1024, label_dim2=512, encoder_dims=1024)\n        classifier = classifier.cuda()\n        criterion = MODEL.get_loss().cuda()\n        classifier.apply(inplace_relu)  \n    return classifier, criterion   \n\ndef main(args):\n    def log_string(str):\n        logger.info(str)\n        print(str)\n\n    '''HYPER PARAMETER'''\n    # os.environ[\"CUDA_VISIBLE_DEVICES\"] = args.gpu\n\n    '''CREATE DIR'''\n    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))\n    exp_dir = Path('./log/')\n    exp_dir.mkdir(exist_ok=True)\n    exp_dir = exp_dir.joinpath('part_seg')\n    exp_dir.mkdir(exist_ok=True)\n    if args.log_dir is None:\n        exp_dir = exp_dir.joinpath(timestr)\n    else:\n        exp_dir = exp_dir.joinpath(args.log_dir)\n    exp_dir.mkdir(exist_ok=True)\n    checkpoints_dir = exp_dir.joinpath('checkpoints/')\n    checkpoints_dir.mkdir(exist_ok=True)\n    log_dir = exp_dir.joinpath('logs/')\n    log_dir.mkdir(exist_ok=True)\n\n    '''LOG'''\n    args = parse_args()\n    logger = logging.getLogger(\"Model\")\n    logger.setLevel(logging.INFO)\n    formatter = logging.Formatter(\n        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))\n    file_handler.setLevel(logging.INFO)\n    file_handler.setFormatter(formatter)\n    logger.addHandler(file_handler)\n    log_string('PARAMETER ...')\n    log_string(args)\n\n    root = args.root\n\n    TRAIN_DATASET = PartNormalDataset(\n        root=root, npoints=args.npoint, split='trainval', normal_channel=args.normal)\n    trainDataLoader = torch.utils.data.DataLoader(\n        TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)\n    TEST_DATASET = PartNormalDataset(\n        root=root, npoints=args.npoint, split='test', normal_channel=args.normal)\n    testDataLoader = torch.utils.data.DataLoader(\n        TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10)\n    log_string(\"The number of training data is: %d\" % len(TRAIN_DATASET))\n    log_string(\"The number of test data is: %d\" % len(TEST_DATASET))\n\n    num_classes = 16\n    num_part = 50\n\n    '''MODEL LOADING'''\n    MODEL = importlib.import_module(args.model)\n    shutil.copy('models/%s.py' % args.model, str(exp_dir))\n    # shutil.copy('models/pointnet2_utils.py', str(exp_dir))\n    \n    classifier, criterion = get_model_loss(MODEL, args, num_part)\n\n    print('# generator parameters:', sum(param.numel()\n          for param in classifier.parameters()))\n    start_epoch = 0\n\n    if args.ckpts is not None:\n        classifier.load_model_from_ckpt(args.ckpts)\n\n# we use adamw and cosine scheduler\n    def add_weight_decay(model, weight_decay=1e-5, 
skip_list=()):\n        decay = []\n        no_decay = []\n        for name, param in model.named_parameters():\n            if not param.requires_grad:\n                continue  # frozen weights\n            if len(param.shape) == 1 or name.endswith(\".bias\") or 'token' in name or name in skip_list:\n                no_decay.append(param)\n            else:\n                decay.append(param)\n        return [\n            {'params': no_decay, 'weight_decay': 0.},\n            {'params': decay, 'weight_decay': weight_decay}]\n\n    param_groups = add_weight_decay(classifier, weight_decay=0.05)\n    optimizer = optim.AdamW(\n        param_groups, lr=args.learning_rate, weight_decay=0.05)\n\n    scheduler = CosineLRScheduler(optimizer,\n                                  t_initial=args.epoch,\n                                  lr_min=1e-6,\n                                  cycle_decay=0.1,\n                                  warmup_lr_init=1e-6,\n                                  warmup_t=args.warmup_epoch,\n                                  cycle_limit=1,\n                                  t_in_epochs=True)\n\n    best_acc = 0\n    global_epoch = 0\n    best_class_avg_iou = 0\n    best_instance_avg_iou = 0\n\n    classifier.zero_grad()\n    for epoch in range(start_epoch, args.epoch):\n        mean_correct = []\n\n        log_string('Epoch %d (%d/%s):' %\n                   (global_epoch + 1, epoch + 1, args.epoch))\n\n        classifier = classifier.train()\n        loss_batch = []\n        num_iter = 0\n        '''learning one epoch'''\n        for i, (points, label, target) in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):\n            num_iter += 1\n            points = points.data.numpy()\n            points[:, :, 0:3] = provider.random_scale_point_cloud(\n                points[:, :, 0:3])\n            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])\n            points = torch.Tensor(points)\n            points, label, target = points.float().cuda(\n            ), label.long().cuda(), target.long().cuda()\n            points = points.transpose(2, 1)\n\n            seg_pred = classifier(points, to_categorical(label, num_classes))\n            seg_pred = seg_pred.contiguous().view(-1, num_part)\n            target = target.view(-1, 1)[:, 0]\n            pred_choice = seg_pred.data.max(1)[1]\n\n            correct = pred_choice.eq(target.data).cpu().sum()\n            mean_correct.append(\n                correct.item() / (args.batch_size * args.npoint))\n            loss = criterion(seg_pred, target)\n            loss.backward()\n            loss_batch.append(loss.detach().cpu())\n\n            # step once per iteration: clip the gradients first, then update\n            # and reset the accumulated gradients\n            if num_iter == 1:\n                torch.nn.utils.clip_grad_norm_(\n                    classifier.parameters(), 10, norm_type=2)\n                num_iter = 0\n                optimizer.step()\n                classifier.zero_grad()\n\n        if isinstance(scheduler, list):\n            for item in scheduler:\n                item.step(epoch)\n        else:\n            scheduler.step(epoch)\n\n        train_instance_acc = np.mean(mean_correct)\n        loss1 = np.mean(loss_batch)\n        log_string('Train accuracy is: %.5f' % train_instance_acc)\n        log_string('Train loss: %.5f' % loss1)\n        log_string('lr: %.6f' % optimizer.param_groups[0]['lr'])\n\n        with torch.no_grad():\n            test_metrics = {}\n            total_correct = 0\n            total_seen = 0\n            total_seen_class = [0 for _ in range(num_part)]\n            total_correct_class = [0 for _ in range(num_part)]\n            shape_ious = {cat: [] for cat in seg_classes.keys()}\n            seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}\n\n            for cat in seg_classes.keys():\n                for label in seg_classes[cat]:\n                    seg_label_to_cat[label] = cat\n\n            classifier = classifier.eval()\n\n            for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):\n                cur_batch_size, NUM_POINT, _ = points.size()\n                points, label, target = points.float().cuda(\n                ), label.long().cuda(), target.long().cuda()\n                points = points.transpose(2, 1)\n                seg_pred = classifier(\n                    points, to_categorical(label, num_classes))\n                cur_pred_val = seg_pred.cpu().data.numpy()\n                cur_pred_val_logits = cur_pred_val\n                cur_pred_val = np.zeros(\n                    (cur_batch_size, NUM_POINT)).astype(np.int32)\n                target = target.cpu().data.numpy()\n\n                for i in range(cur_batch_size):\n                    cat = seg_label_to_cat[target[i, 0]]\n                    logits = cur_pred_val_logits[i, :, :]\n                    cur_pred_val[i, :] = np.argmax(\n                        logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]\n\n                correct = np.sum(cur_pred_val == target)\n                total_correct += correct\n                total_seen += (cur_batch_size * NUM_POINT)\n\n                for l in range(num_part):\n                    total_seen_class[l] += np.sum(target == l)\n                    total_correct_class[l] += (\n                        np.sum((cur_pred_val == l) & (target == l)))\n\n                for i in range(cur_batch_size):\n                    segp = cur_pred_val[i, :]\n                    segl = target[i, :]\n                    cat = seg_label_to_cat[segl[0]]\n                    part_ious = [0.0 for _ in range(len(seg_classes[cat]))]\n                    for l in seg_classes[cat]:\n                        if (np.sum(segl == l) == 0) and (\n                                np.sum(segp == l) == 0):  # part is not present, no prediction as well\n                            part_ious[l - seg_classes[cat][0]] = 1.0\n                        else:\n                            part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(\n                                np.sum((segl == l) | (segp == l)))\n                    shape_ious[cat].append(np.mean(part_ious))\n\n            all_shape_ious = []\n            for cat in shape_ious.keys():\n                for iou in shape_ious[cat]:\n                    all_shape_ious.append(iou)\n                shape_ious[cat] = np.mean(shape_ious[cat])\n            mean_shape_ious = np.mean(list(shape_ious.values()))\n            test_metrics['accuracy'] = total_correct / float(total_seen)\n            test_metrics['class_avg_accuracy'] = np.mean(\n                np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float64))\n            for cat in sorted(shape_ious.keys()):\n                log_string('eval mIoU of %s %f' %\n                           (cat + ' ' * (14 - len(cat)), shape_ious[cat]))\n            test_metrics['class_avg_iou'] = mean_shape_ious\n            test_metrics['instance_avg_iou'] = np.mean(all_shape_ious)\n\n        log_string('Epoch %d test Accuracy: %f  Class avg mIOU: %f   Instance avg mIOU: %f' % (\n            epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['instance_avg_iou']))\n        if (test_metrics['instance_avg_iou'] >= best_instance_avg_iou):\n            logger.info('Save model...')\n            savepath = str(checkpoints_dir) + '/best_model.pth'\n            log_string('Saving at %s' % savepath)\n            state = {\n                'epoch': epoch,\n                'train_acc': train_instance_acc,\n                'test_acc': test_metrics['accuracy'],\n                'class_avg_iou': test_metrics['class_avg_iou'],\n                'instance_avg_iou': test_metrics['instance_avg_iou'],\n                'model_state_dict': classifier.state_dict(),\n                'optimizer_state_dict': optimizer.state_dict(),\n            }\n            torch.save(state, savepath)\n            log_string('Saving model....')\n\n        if test_metrics['accuracy'] > best_acc:\n            best_acc = test_metrics['accuracy']\n        if test_metrics['class_avg_iou'] > best_class_avg_iou:\n            best_class_avg_iou = test_metrics['class_avg_iou']\n        if test_metrics['instance_avg_iou'] > best_instance_avg_iou:\n            best_instance_avg_iou = test_metrics['instance_avg_iou']\n        log_string('Best accuracy is: %.5f' % best_acc)\n        log_string('Best class avg mIOU is: %.5f' % best_class_avg_iou)\n        log_string('Best instance avg mIOU is: %.5f' % best_instance_avg_iou)\n        global_epoch += 1\n\n\nif __name__ == '__main__':\n    args = parse_args()\n    main(args)\n
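\n# Example invocation (sketch; the flags shown are the argparse defaults above,\n# adjust --ckpts and --root to local paths):\n#\n#   python main.py --model pt --model_name PointGPT_S --log_dir ./exp\n"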
  },
  {
    "path": "segmentation/misc.py",
    "content": "import numpy as np\nimport matplotlib.pyplot as plt\nfrom mpl_toolkits.mplot3d import Axes3D\nimport random\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport os\nfrom collections import abc\nfrom pointnet2_ops import pointnet2_utils\n\n\ndef fps(data, number):\n    '''\n        data B N 3\n        number int\n    '''\n    fps_idx = pointnet2_utils.furthest_point_sample(data, number)\n    fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()\n    return fps_data\n\n\ndef worker_init_fn(worker_id):\n    np.random.seed(np.random.get_state()[1][0] + worker_id)\n\n\ndef build_lambda_sche(opti, config):\n    if config.get('decay_step') is not None:\n        lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay)\n        scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd)\n    else:\n        raise NotImplementedError()\n    return scheduler\n\n\ndef build_lambda_bnsche(model, config):\n    if config.get('decay_step') is not None:\n        bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay)\n        bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd)\n    else:\n        raise NotImplementedError()\n    return bnm_scheduler\n\n\ndef set_random_seed(seed, deterministic=False):\n    \"\"\"Set random seed.\n    Args:\n        seed (int): Seed to be used.\n        deterministic (bool): Whether to set the deterministic option for\n            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`\n            to True and `torch.backends.cudnn.benchmark` to False.\n            Default: False.\n\n    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html\n    if cuda_deterministic:  # slower, more reproducible\n        cudnn.deterministic = True\n        cudnn.benchmark = False\n    else:  # faster, less reproducible\n        cudnn.deterministic = False\n        cudnn.benchmark = True\n\n    \"\"\"\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    if deterministic:\n        torch.backends.cudnn.deterministic = True\n        torch.backends.cudnn.benchmark = False\n\n\ndef is_seq_of(seq, expected_type, seq_type=None):\n    \"\"\"Check whether it is a sequence of some type.\n    Args:\n        seq (Sequence): The sequence to be checked.\n        expected_type (type): Expected type of sequence items.\n        seq_type (type, optional): Expected sequence type.\n    Returns:\n        bool: Whether the sequence is valid.\n    \"\"\"\n    if seq_type is None:\n        exp_seq_type = abc.Sequence\n    else:\n        assert isinstance(seq_type, type)\n        exp_seq_type = seq_type\n    if not isinstance(seq, exp_seq_type):\n        return False\n    for item in seq:\n        if not isinstance(item, expected_type):\n            return False\n    return True\n\n\ndef set_bn_momentum_default(bn_momentum):\n    def fn(m):\n        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):\n            m.momentum = bn_momentum\n\n    return fn\n\n\nclass BNMomentumScheduler(object):\n\n    def __init__(\n            self, model, bn_lambda, last_epoch=-1,\n            setter=set_bn_momentum_default\n    ):\n        if not isinstance(model, nn.Module):\n            raise RuntimeError(\n                \"Class '{}' is not a PyTorch nn Module\".format(\n                    type(model).__name__\n     
           )\n            )\n\n        self.model = model\n        self.setter = setter\n        self.lmbd = bn_lambda\n\n        self.step(last_epoch + 1)\n        self.last_epoch = last_epoch\n\n    def step(self, epoch=None):\n        if epoch is None:\n            epoch = self.last_epoch + 1\n\n        self.last_epoch = epoch\n        self.model.apply(self.setter(self.lmbd(epoch)))\n\n    def get_momentum(self, epoch=None):\n        if epoch is None:\n            epoch = self.last_epoch + 1\n        return self.lmbd(epoch)\n\n\ndef seprate_point_cloud(xyz, num_points, crop, fixed_points=None, padding_zeros=False):\n    '''\n     seprate point cloud: usage : using to generate the incomplete point cloud with a setted number.\n    '''\n    _, n, c = xyz.shape\n\n    assert n == num_points\n    assert c == 3\n    if crop == num_points:\n        return xyz, None\n\n    INPUT = []\n    CROP = []\n    for points in xyz:\n        if isinstance(crop, list):\n            num_crop = random.randint(crop[0], crop[1])\n        else:\n            num_crop = crop\n\n        points = points.unsqueeze(0)\n\n        if fixed_points is None:\n            center = F.normalize(torch.randn(1, 1, 3), p=2, dim=-1).cuda()\n        else:\n            if isinstance(fixed_points, list):\n                fixed_point = random.sample(fixed_points, 1)[0]\n            else:\n                fixed_point = fixed_points\n            center = fixed_point.reshape(1, 1, 3).cuda()\n\n        distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p=2, dim=-1)  # 1 1 2048\n\n        idx = torch.argsort(distance_matrix, dim=-1, descending=False)[0, 0]  # 2048\n\n        if padding_zeros:\n            input_data = points.clone()\n            input_data[0, idx[:num_crop]] = input_data[0, idx[:num_crop]] * 0\n\n        else:\n            input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0)  # 1 N 3\n\n        crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0)\n\n        if isinstance(crop, list):\n            INPUT.append(fps(input_data, 2048))\n            CROP.append(fps(crop_data, 2048))\n        else:\n            INPUT.append(input_data)\n            CROP.append(crop_data)\n\n    input_data = torch.cat(INPUT, dim=0)  # B N 3\n    crop_data = torch.cat(CROP, dim=0)  # B M 3\n\n    return input_data.contiguous(), crop_data.contiguous()\n\n\ndef get_ptcloud_img(ptcloud):\n    fig = plt.figure(figsize=(8, 8))\n\n    x, z, y = ptcloud.transpose(1, 0)\n    ax = fig.gca(projection=Axes3D.name, adjustable='box')\n    ax.axis('off')\n    # ax.axis('scaled')\n    ax.view_init(90, 45)\n    max, min = np.max(ptcloud), np.min(ptcloud)\n    ax.set_xbound(min, max)\n    ax.set_ybound(min, max)\n    ax.set_zbound(min, max)\n    ax.scatter(x, y, z, zdir='z', c=y, cmap='jet')\n\n    fig.canvas.draw()\n    img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')\n    img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))\n    return img\n\n\ndef visualize_KITTI(path, data_list, titles=['input', 'pred'], cmap=['bwr', 'autumn'], zdir='y',\n                    xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1)):\n    fig = plt.figure(figsize=(6 * len(data_list), 6))\n    cmax = data_list[-1][:, 0].max()\n\n    for i in range(len(data_list)):\n        data = data_list[i][:-2048] if i == 1 else data_list[i]\n        color = data[:, 0] / cmax\n        ax = fig.add_subplot(1, len(data_list), i + 1, projection='3d')\n        ax.view_init(30, -120)\n        b = ax.scatter(data[:, 0], data[:, 1], data[:, 2], 
zdir=zdir, c=color, vmin=-1, vmax=1, cmap=cmap[0], s=4,\n                       linewidth=0.05, edgecolors='black')\n        ax.set_title(titles[i])\n\n        ax.set_axis_off()\n        ax.set_xlim(xlim)\n        ax.set_ylim(ylim)\n        ax.set_zlim(zlim)\n    plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0)\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n    pic_path = path + '.png'\n    fig.savefig(pic_path)\n\n    np.save(os.path.join(path, 'input.npy'), data_list[0].numpy())\n    np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy())\n    plt.close(fig)\n\n\ndef random_dropping(pc, e):\n    up_num = max(64, 768 // (e // 50 + 1))\n    pc = pc\n    random_num = torch.randint(1, up_num, (1, 1))[0, 0]\n    pc = fps(pc, random_num)\n    padding = torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device)\n    pc = torch.cat([pc, padding], dim=1)\n    return pc\n\n\ndef random_scale(partial, scale_range=[0.8, 1.2]):\n    scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0]\n    return partial * scale\n"
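\n\n\n# Usage sketch (added; assumes a CUDA tensor of shape B x 2048 x 3):\n#\n#   pts = torch.rand(8, 2048, 3).cuda()\n#   partial, cropped = seprate_point_cloud(pts, 2048, crop=[512, 1024])\n#   # partial is the incomplete cloud, cropped is the removed region\n"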
  },
  {
    "path": "segmentation/models/gpt2_seg.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom timm.models.layers import DropPath, trunc_normal_\n\n\nclass Block(nn.Module):\n    def __init__(self, embed_dim, num_heads, drop_path):\n        super(Block, self).__init__()\n        self.ln_1 = nn.LayerNorm(embed_dim)\n        self.ln_2 = nn.LayerNorm(embed_dim)\n        self.attn = nn.MultiheadAttention(embed_dim, num_heads)\n        self.mlp = nn.Sequential(\n            nn.Linear(embed_dim, embed_dim * 4),\n            nn.GELU(),\n            nn.Linear(embed_dim * 4, embed_dim),\n        )\n        self.drop_path = DropPath(\n            drop_path) if drop_path > 0. else nn.Identity()\n\n    def forward(self, x):\n        attn_mask = torch.full(\n            (len(x), len(x)), -float(\"Inf\"), device=x.device, dtype=x.dtype\n        )\n        attn_mask = torch.triu(attn_mask, diagonal=1)\n\n        x = self.ln_1(x)\n        a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False)\n        x = x + self.drop_path(a)\n        m = self.drop_path(self.mlp(self.ln_2(x)))\n        x = x + m\n        return x\n\n\nclass GPT_extractor(nn.Module):\n    def __init__(\n        self, embed_dim, num_heads, num_layers, trans_dim, group_size, drop_path_rate\n    ):\n        super(GPT_extractor, self).__init__()\n\n        self.embed_dim = embed_dim\n        self.trans_dim = trans_dim\n        self.group_size = group_size\n        self.drop_path_rate = drop_path_rate\n\n        # start of sequence token\n        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))\n        nn.init.normal_(self.sos)\n\n        dpr = [x.item() for x in torch.linspace(\n            0, self.drop_path_rate, num_layers)]\n        self.layers = nn.ModuleList()\n        for i in range(num_layers):\n            self.layers.append(Block(embed_dim, num_heads, dpr[i]))\n\n        self.ln_f = nn.LayerNorm(embed_dim)\n        # prediction head\n        self.increase_dim = nn.Sequential(\n            nn.Conv1d(self.trans_dim, 3*(self.group_size), 1)\n        )\n\n    def forward(self, h, pos, classify=False):\n        \"\"\"\n        Expect input as shape [sequence len, batch]\n        If classify, return classification logits\n        \"\"\"\n        batch, length, C = h.shape\n\n        h = h.transpose(0, 1)\n        pos = pos.transpose(0, 1)\n\n        # prepend sos token\n        sos = torch.ones(1, batch, self.embed_dim, device=h.device) * self.sos\n        if not classify:\n            h = torch.cat([sos, h[:-1, :, :]], axis=0)\n        else:\n            h = torch.cat([sos, h], axis=0)\n\n        feature_list = []\n        fetch_idx = [3, 7, 11]\n\n        # transformer\n        for i, layer in enumerate(self.layers):\n            h = layer(h + pos)\n            if i in fetch_idx:\n                feature_list.append(h.transpose(0, 1)[:, 2:])\n\n        h = self.ln_f(h)\n\n        encoded_points = h.transpose(0, 1)\n\n        return encoded_points, feature_list\n\n\nclass GPT_generator(nn.Module):\n    def __init__(\n        self, embed_dim, num_heads, num_layers, trans_dim, group_size, drop_path_rate\n    ):\n        super(GPT_generator, self).__init__()\n\n        self.embed_dim = embed_dim\n        self.trans_dim = trans_dim\n        self.group_size = group_size\n\n        # start of sequence token\n        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))\n        nn.init.normal_(self.sos)\n\n        self.drop_path_rate = drop_path_rate\n\n        dpr = [x.item() for x in torch.linspace(\n            0, 
self.drop_path_rate, num_layers)]\n        self.layers = nn.ModuleList()\n        for i in range(num_layers):\n            self.layers.append(Block(embed_dim, num_heads, dpr[i]))\n\n        self.ln_f = nn.LayerNorm(embed_dim)\n        # prediction head\n        self.increase_dim = nn.Sequential(\n            nn.Conv1d(self.trans_dim, 3*(self.group_size), 1)\n        )\n\n    def forward(self, h, pos):\n        \"\"\"\n        Expect input as shape [sequence len, batch]\n        If classify, return classification logits\n        \"\"\"\n        batch, length, C = h.shape\n\n        h = h.transpose(0, 1)\n        pos = pos.transpose(0, 1)\n\n        # transformer\n        for layer in self.layers:\n            h = layer(h + pos)\n\n        h = self.ln_f(h)\n\n        rebuild_points = self.increase_dim(h.transpose(1, 2)).transpose(\n            1, 2).transpose(0, 1).reshape(batch * length, -1, 3)\n\n        return rebuild_points\n"
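\n\n\n# Shape sketch (added), assuming inputs h and pos of shape [B, L, C]:\n#   GPT_extractor(h, pos) -> encoded_points [B, L, C] (sos prepended, last\n#                            token dropped by default) and intermediate\n#                            features from blocks 3, 7 and 11\n#   GPT_generator(h, pos) -> rebuild_points [B * L, group_size, 3]\n"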
  },
  {
    "path": "segmentation/models/pointnet2_utils.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom time import time\nimport numpy as np\n\ndef timeit(tag, t):\n    print(\"{}: {}s\".format(tag, time() - t))\n    return time()\n\ndef pc_normalize(pc):\n    l = pc.shape[0]\n    centroid = np.mean(pc, axis=0)\n    pc = pc - centroid\n    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n    pc = pc / m\n    return pc\n\ndef square_distance(src, dst):\n    \"\"\"\n    Calculate Euclid distance between each two points.\n    src^T * dst = xn * xm + yn * ym + zn * zm；\n    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;\n    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;\n    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2\n         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst\n    Input:\n        src: source points, [B, N, C]\n        dst: target points, [B, M, C]\n    Output:\n        dist: per-point square distance, [B, N, M]\n    \"\"\"\n    B, N, _ = src.shape\n    _, M, _ = dst.shape\n    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))\n    dist += torch.sum(src ** 2, -1).view(B, N, 1)\n    dist += torch.sum(dst ** 2, -1).view(B, 1, M)\n    return dist\n\n\ndef index_points(points, idx):\n    \"\"\"\n    Input:\n        points: input points data, [B, N, C]\n        idx: sample index data, [B, S]\n    Return:\n        new_points:, indexed points data, [B, S, C]\n    \"\"\"\n    device = points.device\n    B = points.shape[0]\n    view_shape = list(idx.shape)\n    view_shape[1:] = [1] * (len(view_shape) - 1)\n    repeat_shape = list(idx.shape)\n    repeat_shape[0] = 1\n    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)\n    new_points = points[batch_indices, idx, :]\n    return new_points\n\n\ndef farthest_point_sample(xyz, npoint):\n    \"\"\"\n    Input:\n        xyz: pointcloud data, [B, N, 3]\n        npoint: number of samples\n    Return:\n        centroids: sampled pointcloud index, [B, npoint]\n    \"\"\"\n    device = xyz.device\n    B, N, C = xyz.shape\n    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)\n    distance = torch.ones(B, N).to(device) * 1e10\n    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)\n    batch_indices = torch.arange(B, dtype=torch.long).to(device)\n    for i in range(npoint):\n        centroids[:, i] = farthest\n        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)\n        dist = torch.sum((xyz - centroid) ** 2, -1)\n        mask = dist < distance\n        distance[mask] = dist[mask]\n        farthest = torch.max(distance, -1)[1]\n    return centroids\n\n\ndef query_ball_point(radius, nsample, xyz, new_xyz):\n    \"\"\"\n    Input:\n        radius: local region radius\n        nsample: max sample number in local region\n        xyz: all points, [B, N, 3]\n        new_xyz: query points, [B, S, 3]\n    Return:\n        group_idx: grouped points index, [B, S, nsample]\n    \"\"\"\n    device = xyz.device\n    B, N, C = xyz.shape\n    _, S, _ = new_xyz.shape\n    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])\n    sqrdists = square_distance(new_xyz, xyz)\n    group_idx[sqrdists > radius ** 2] = N\n    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]\n    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])\n    mask = group_idx == N\n    group_idx[mask] = group_first[mask]\n    return group_idx\n\n\ndef sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):\n    \"\"\"\n    Input:\n        
npoint:\n        radius:\n        nsample:\n        xyz: input points position data, [B, N, 3]\n        points: input points data, [B, N, D]\n    Return:\n        new_xyz: sampled points position data, [B, npoint, 3]\n        new_points: sampled points data, [B, npoint, nsample, 3+D]\n    \"\"\"\n    B, N, C = xyz.shape\n    S = npoint\n    fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]\n    new_xyz = index_points(xyz, fps_idx)\n    idx = query_ball_point(radius, nsample, xyz, new_xyz)\n    grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]\n    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)\n\n    if points is not None:\n        grouped_points = index_points(points, idx)\n        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]\n    else:\n        new_points = grouped_xyz_norm\n    if returnfps:\n        return new_xyz, new_points, grouped_xyz, fps_idx\n    else:\n        return new_xyz, new_points\n\n\ndef sample_and_group_all(xyz, points):\n    \"\"\"\n    Input:\n        xyz: input points position data, [B, N, 3]\n        points: input points data, [B, N, D]\n    Return:\n        new_xyz: sampled points position data, [B, 1, 3]\n        new_points: sampled points data, [B, 1, N, 3+D]\n    \"\"\"\n    device = xyz.device\n    B, N, C = xyz.shape\n    new_xyz = torch.zeros(B, 1, C).to(device)\n    grouped_xyz = xyz.view(B, 1, N, C)\n    if points is not None:\n        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)\n    else:\n        new_points = grouped_xyz\n    return new_xyz, new_points\n\n\nclass PointNetSetAbstraction(nn.Module):\n    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):\n        super(PointNetSetAbstraction, self).__init__()\n        self.npoint = npoint\n        self.radius = radius\n        self.nsample = nsample\n        self.mlp_convs = nn.ModuleList()\n        self.mlp_bns = nn.ModuleList()\n        last_channel = in_channel\n        for out_channel in mlp:\n            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))\n            self.mlp_bns.append(nn.BatchNorm2d(out_channel))\n            last_channel = out_channel\n        self.group_all = group_all\n\n    def forward(self, xyz, points):\n        \"\"\"\n        Input:\n            xyz: input points position data, [B, C, N]\n            points: input points data, [B, D, N]\n        Return:\n            new_xyz: sampled points position data, [B, C, S]\n            new_points_concat: sample points feature data, [B, D', S]\n        \"\"\"\n        xyz = xyz.permute(0, 2, 1)\n        if points is not None:\n            points = points.permute(0, 2, 1)\n\n        if self.group_all:\n            new_xyz, new_points = sample_and_group_all(xyz, points)\n        else:\n            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)\n        # new_xyz: sampled points position data, [B, npoint, C]\n        # new_points: sampled points data, [B, npoint, nsample, C+D]\n        new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample, npoint]\n        for i, conv in enumerate(self.mlp_convs):\n            bn = self.mlp_bns[i]\n            new_points = F.relu(bn(conv(new_points)))\n\n        new_points = torch.max(new_points, 2)[0]\n        new_xyz = new_xyz.permute(0, 2, 1)\n        return new_xyz, new_points\n\n\nclass PointNetSetAbstractionMsg(nn.Module):\n    def __init__(self, npoint, radius_list, nsample_list, in_channel, 
mlp_list):\n        super(PointNetSetAbstractionMsg, self).__init__()\n        self.npoint = npoint\n        self.radius_list = radius_list\n        self.nsample_list = nsample_list\n        self.conv_blocks = nn.ModuleList()\n        self.bn_blocks = nn.ModuleList()\n        for i in range(len(mlp_list)):\n            convs = nn.ModuleList()\n            bns = nn.ModuleList()\n            last_channel = in_channel + 3\n            for out_channel in mlp_list[i]:\n                convs.append(nn.Conv2d(last_channel, out_channel, 1))\n                bns.append(nn.BatchNorm2d(out_channel))\n                last_channel = out_channel\n            self.conv_blocks.append(convs)\n            self.bn_blocks.append(bns)\n\n    def forward(self, xyz, points):\n        \"\"\"\n        Input:\n            xyz: input points position data, [B, C, N]\n            points: input points data, [B, D, N]\n        Return:\n            new_xyz: sampled points position data, [B, C, S]\n            new_points_concat: sample points feature data, [B, D', S]\n        \"\"\"\n        xyz = xyz.permute(0, 2, 1)\n        if points is not None:\n            points = points.permute(0, 2, 1)\n\n        B, N, C = xyz.shape\n        S = self.npoint\n        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))\n        new_points_list = []\n        for i, radius in enumerate(self.radius_list):\n            K = self.nsample_list[i]\n            group_idx = query_ball_point(radius, K, xyz, new_xyz)\n            grouped_xyz = index_points(xyz, group_idx)\n            grouped_xyz -= new_xyz.view(B, S, 1, C)\n            if points is not None:\n                grouped_points = index_points(points, group_idx)\n                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)\n            else:\n                grouped_points = grouped_xyz\n\n            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]\n            for j in range(len(self.conv_blocks[i])):\n                conv = self.conv_blocks[i][j]\n                bn = self.bn_blocks[i][j]\n                grouped_points =  F.relu(bn(conv(grouped_points)))\n            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]\n            new_points_list.append(new_points)\n\n        new_xyz = new_xyz.permute(0, 2, 1)\n        new_points_concat = torch.cat(new_points_list, dim=1)\n        return new_xyz, new_points_concat\n\n\nclass PointNetFeaturePropagation(nn.Module):\n    def __init__(self, in_channel, mlp):\n        super(PointNetFeaturePropagation, self).__init__()\n        self.mlp_convs = nn.ModuleList()\n        self.mlp_bns = nn.ModuleList()\n        last_channel = in_channel\n        for out_channel in mlp:\n            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))\n            self.mlp_bns.append(nn.BatchNorm1d(out_channel))\n            last_channel = out_channel\n\n    def forward(self, xyz1, xyz2, points1, points2):\n        \"\"\"\n        Input:\n            xyz1: input points position data, [B, C, N]\n            xyz2: sampled input points position data, [B, C, S]\n            points1: input points data, [B, D, N]\n            points2: input points data, [B, D, S]\n        Return:\n            new_points: upsampled points data, [B, D', N]\n        \"\"\"\n        xyz1 = xyz1.permute(0, 2, 1)\n        xyz2 = xyz2.permute(0, 2, 1)\n\n        points2 = points2.permute(0, 2, 1)\n        B, N, C = xyz1.shape\n        _, S, _ = xyz2.shape\n\n        if S == 1:\n            interpolated_points = 
points2.repeat(1, N, 1)\n        else:\n            dists = square_distance(xyz1, xyz2)\n            dists, idx = dists.sort(dim=-1)\n            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]\n\n            dist_recip = 1.0 / (dists + 1e-8)\n            norm = torch.sum(dist_recip, dim=2, keepdim=True)\n            weight = dist_recip / norm\n            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)\n\n        if points1 is not None:\n            points1 = points1.permute(0, 2, 1)\n            new_points = torch.cat([points1, interpolated_points], dim=-1)\n        else:\n            new_points = interpolated_points\n\n        new_points = new_points.permute(0, 2, 1)\n        for i, conv in enumerate(self.mlp_convs):\n            bn = self.mlp_bns[i]\n            new_points = F.relu(bn(conv(new_points)))\n        return new_points"
  },
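A quick shape check for the sampling and grouping helpers in `segmentation/models/pointnet2_utils.py`. This is a minimal sketch, not part of the repo: the batch size, radius, and group counts are illustrative, and it assumes the script is run from `segmentation/models/` so the module imports directly.

```python
# Minimal CPU shape check for the sampling/grouping utilities (toy sizes).
import torch
from pointnet2_utils import farthest_point_sample, index_points, query_ball_point, sample_and_group

xyz = torch.rand(2, 1024, 3)                    # B=2 clouds with N=1024 points
fps_idx = farthest_point_sample(xyz, 128)       # [2, 128] indices of FPS centroids
new_xyz = index_points(xyz, fps_idx)            # [2, 128, 3] centroid coordinates
idx = query_ball_point(0.2, 32, xyz, new_xyz)   # [2, 128, 32] neighbors within radius 0.2
new_xyz, new_points = sample_and_group(128, 0.2, 32, xyz, points=None)
print(new_xyz.shape, new_points.shape)          # [2, 128, 3] and [2, 128, 32, 3]
```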
  {
    "path": "segmentation/models/pt.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom timm.models.layers import DropPath, trunc_normal_\nfrom logger import get_missing_parameters_message, get_unexpected_parameters_message\n\nfrom pointnet2_ops import pointnet2_utils\nfrom knn_cuda import KNN\nfrom pointnet2_utils import PointNetFeaturePropagation\nfrom gpt2_seg import GPT_extractor, GPT_generator\nimport math\nfrom extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2\nimport numpy as np\nfrom z_order import *\n\n\ndef fps(data, number):\n    '''\n        data B N 3\n        number int\n    '''\n    fps_idx = pointnet2_utils.furthest_point_sample(data, number)\n    fps_data = pointnet2_utils.gather_operation(data.transpose(\n        1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()\n    return fps_data\n\n\nclass Group(nn.Module):\n    def __init__(self, num_group, group_size):\n        super().__init__()\n        self.num_group = num_group\n        self.group_size = group_size\n        self.knn = KNN(k=self.group_size, transpose_mode=True)\n        self.knn_2 = KNN(k=1, transpose_mode=True)\n\n    def simplied_morton_sorting(self, xyz, center):\n        batch_size, num_points, _ = xyz.shape\n        distances_batch = torch.cdist(center, center)\n        distances_batch[:, torch.eye(self.num_group).bool()] = float(\"inf\")\n        idx_base = torch.arange(\n            0, batch_size, device=xyz.device) * self.num_group\n        sorted_indices_list = []\n        sorted_indices_list.append(idx_base)\n        distances_batch = distances_batch.view(batch_size, self.num_group, self.num_group).transpose(\n            1, 2).contiguous().view(batch_size * self.num_group, self.num_group)\n        distances_batch[idx_base] = float(\"inf\")\n        distances_batch = distances_batch.view(\n            batch_size, self.num_group, self.num_group).transpose(1, 2).contiguous()\n        for i in range(self.num_group - 1):\n            distances_batch = distances_batch.view(\n                batch_size * self.num_group, self.num_group)\n            distances_to_last_batch = distances_batch[sorted_indices_list[-1]]\n            closest_point_idx = torch.argmin(distances_to_last_batch, dim=-1)\n            closest_point_idx = closest_point_idx + idx_base\n            sorted_indices_list.append(closest_point_idx)\n            distances_batch = distances_batch.view(batch_size, self.num_group, self.num_group).transpose(\n                1, 2).contiguous().view(batch_size * self.num_group, self.num_group)\n            distances_batch[closest_point_idx] = float(\"inf\")\n            distances_batch = distances_batch.view(\n                batch_size, self.num_group, self.num_group).transpose(1, 2).contiguous()\n        sorted_indices = torch.stack(sorted_indices_list, dim=-1)\n        sorted_indices = sorted_indices.view(-1)\n        return sorted_indices\n\n    def morton_sorting(self, xyz, center):\n        batch_size, num_points, _ = xyz.shape\n        all_indices = []\n        for index in range(batch_size):\n            points = center[index]\n            z = get_z_values(points.cpu().numpy())\n            idxs = np.zeros((self.num_group), dtype=np.int32)\n            temp = np.arange(self.num_group)\n            z_ind = np.argsort(z[temp])\n            idxs = temp[z_ind]\n            all_indices.append(idxs)\n        all_indices = torch.tensor(all_indices, device=xyz.device)\n\n        idx_base = torch.arange(\n            0, batch_size, device=xyz.device).view(-1, 1) * 
self.num_group\n        sorted_indices = all_indices + idx_base\n        sorted_indices = sorted_indices.view(-1)\n        return sorted_indices\n\n    def forward(self, xyz):\n        '''\n            input: B N 3\n            ---------------------------\n            output: B G M 3\n            center : B G 3\n        '''\n        batch_size, num_points, _ = xyz.shape\n        # fps the centers out\n        center = fps(xyz, self.num_group)  # B G 3\n        # knn to get the neighborhood\n        _, idx = self.knn(xyz, center)  # B G M\n        assert idx.size(1) == self.num_group\n        assert idx.size(2) == self.group_size\n        idx_base = torch.arange(\n            0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points\n        idx = idx + idx_base\n        idx = idx.view(-1)\n        neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]\n        neighborhood = neighborhood.view(\n            batch_size, self.num_group, self.group_size, 3).contiguous()\n        # normalize\n        neighborhood = neighborhood - center.unsqueeze(2)\n\n        # the exact Morton-code ordering can be used instead by calling self.morton_sorting here\n        sorted_indices = self.simplied_morton_sorting(xyz, center)\n\n        neighborhood = neighborhood.view(\n            batch_size * self.num_group, self.group_size, 3)[sorted_indices, :, :]\n        neighborhood = neighborhood.view(\n            batch_size, self.num_group, self.group_size, 3).contiguous()\n        center = center.view(\n            batch_size * self.num_group, 3)[sorted_indices, :]\n        center = center.view(\n            batch_size, self.num_group, 3).contiguous()\n\n        return neighborhood, center\n\nclass Encoder_small(nn.Module):\n    def __init__(self, encoder_channel):\n        super().__init__()\n        self.encoder_channel = encoder_channel\n        self.first_conv = nn.Sequential(\n            nn.Conv1d(3, 128, 1),\n            nn.BatchNorm1d(128),\n            nn.ReLU(inplace=True),\n            nn.Conv1d(128, 256, 1)\n        )\n        self.second_conv = nn.Sequential(\n            nn.Conv1d(512, 512, 1),\n            nn.BatchNorm1d(512),\n            nn.ReLU(inplace=True),\n            nn.Conv1d(512, self.encoder_channel, 1)\n        )\n\n    def forward(self, point_groups):\n        '''\n            point_groups : B G N 3\n            -----------------\n            feature_global : B G C\n        '''\n        bs, g, n, _ = point_groups.shape\n        point_groups = point_groups.reshape(bs * g, n, 3)\n        # encoder\n        feature = self.first_conv(point_groups.transpose(2, 1))\n        feature_global = torch.max(feature, dim=2, keepdim=True)[0]\n        feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1)\n        feature = self.second_conv(feature)\n        feature_global = torch.max(feature, dim=2, keepdim=False)[0]\n        return feature_global.reshape(bs, g, self.encoder_channel)\n\nclass Encoder_large(nn.Module):  # Embedding module\n    def __init__(self, encoder_channel):\n        super().__init__()\n        self.encoder_channel = encoder_channel\n        self.first_conv = nn.Sequential(\n            nn.Conv1d(3, 256, 1),\n            nn.BatchNorm1d(256),\n            nn.ReLU(inplace=True),\n            nn.Conv1d(256, 512, 1),\n            nn.BatchNorm1d(512),\n            nn.ReLU(inplace=True),\n            nn.Conv1d(512, 1024, 1)\n        )\n        self.second_conv = nn.Sequential(\n            nn.Conv1d(2048, 2048, 1),\n            nn.BatchNorm1d(2048),\n            nn.ReLU(inplace=True),\n            
nn.Conv1d(2048, self.encoder_channel, 1)\n        )\n\n    def forward(self, point_groups):\n        '''\n            point_groups : B G N 3\n            -----------------\n            feature_global : B G C\n        '''\n        bs, g, n, _ = point_groups.shape\n        point_groups = point_groups.reshape(bs * g, n, 3)\n        # encoder\n        feature = self.first_conv(point_groups.transpose(2, 1))  # BG 1024 n\n        feature_global = torch.max(feature, dim=2, keepdim=True)[0]  # BG 1024 1\n        feature = torch.cat(\n            [feature_global.expand(-1, -1, n), feature], dim=1)  # BG 2048 n\n        feature = self.second_conv(feature)  # BG C n\n        feature_global = torch.max(feature, dim=2, keepdim=False)[0]  # BG C\n        return feature_global.reshape(bs, g, self.encoder_channel)\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //\n                                  self.num_heads).permute(2, 0, 3, 1, 4)\n        # make torchscript happy (cannot use tensor as tuple)\n        q, k, v = qkv[0], qkv[1], qkv[2]\n\n        attn = (q * self.scale) @ k.transpose(-2, -1)\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.drop_path = DropPath(\n            drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,\n                       act_layer=act_layer, drop=drop)\n\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n\n    def forward(self, x):\n        x = x + self.drop_path(self.attn(self.norm1(x)))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n\nclass TransformerEncoder(nn.Module):\n    \"\"\" Transformer Encoder without hierarchical structure\n    \"\"\"\n\n    def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):\n        super().__init__()\n\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate,\n                drop_path=drop_path_rate[i] if isinstance(\n                    drop_path_rate, list) else drop_path_rate\n            )\n            for i in range(depth)])\n\n    def forward(self, x, pos):\n        feature_list = []\n        fetch_idx = [7, 15, 23]\n        for i, block in enumerate(self.blocks):\n            x = block(x + pos)\n            if i in fetch_idx:\n                feature_list.append(x)\n        return feature_list\n\n\nclass PositionEmbeddingCoordsSine(nn.Module):\n    \"\"\"Similar to transformer's position encoding, but generalizes it to\n    arbitrary dimensions and continuous coordinates.\n\n    Args:\n        n_dim: Number of input dimensions, e.g. 
2 for image coordinates.\n        d_model: Number of dimensions to encode into\n        temperature:\n        scale:\n    \"\"\"\n\n    def __init__(self, n_dim: int = 1, d_model: int = 256, temperature=10000, scale=None):\n        super().__init__()\n\n        self.n_dim = n_dim\n        self.num_pos_feats = d_model // n_dim // 2 * 2\n        self.temperature = temperature\n        self.padding = d_model - self.num_pos_feats * self.n_dim\n\n        if scale is None:\n            scale = 1.0\n        self.scale = scale * 2 * math.pi\n\n    def forward(self, xyz: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Args:\n            xyz: Point positions (*, d_in)\n\n        Returns:\n            pos_emb (*, d_out)\n        \"\"\"\n        assert xyz.shape[-1] == self.n_dim\n\n        dim_t = torch.arange(self.num_pos_feats,\n                             dtype=torch.float32, device=xyz.device)\n        dim_t = self.temperature ** (2 * torch.div(dim_t,\n                                     2, rounding_mode='trunc') / self.num_pos_feats)\n\n        xyz = xyz * self.scale\n        pos_divided = xyz.unsqueeze(-1) / dim_t\n        pos_sin = pos_divided[..., 0::2].sin()\n        pos_cos = pos_divided[..., 1::2].cos()\n        pos_emb = torch.stack([pos_sin, pos_cos], dim=-\n                              1).reshape(*xyz.shape[:-1], -1)\n\n        # Pad unused dimensions with zeros\n        pos_emb = F.pad(pos_emb, (0, self.padding))\n        return pos_emb\n\n\nclass get_model(nn.Module):\n    def __init__(self, cls_dim, trans_dim=384, depth=12, drop_path_rate=0.1, num_heads=6, decoder_depth=4, group_size=32, num_group=128, prop_dim=1024, label_dim1=512, label_dim2=256, encoder_dims=384):\n        super().__init__()\n\n        self.trans_dim = trans_dim\n        self.depth = depth\n        self.drop_path_rate = drop_path_rate\n        self.cls_dim = cls_dim\n        self.num_heads = num_heads\n\n        self.decoder_depth = decoder_depth\n\n        self.group_size = group_size\n        self.num_group = num_group\n\n        self.prop_dim = prop_dim\n\n        self.label_dim1 = label_dim1\n        self.label_dim2 = label_dim2\n        # grouper\n        self.group_divider = Group(\n            num_group=self.num_group, group_size=self.group_size)\n        # define the encoder\n        self.encoder_dims = encoder_dims\n        assert encoder_dims in [384, 768, 1024]\n        if encoder_dims == 384:\n            self.encoder = Encoder_small(encoder_channel=self.encoder_dims)\n        else:\n            self.encoder = Encoder_large(encoder_channel=self.encoder_dims)\n        # bridge encoder and transformer\n\n        self.pos_embed = PositionEmbeddingCoordsSine(3, self.encoder_dims, 1.0)\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))\n        self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))\n\n        self.sos_pos = nn.Parameter(torch.zeros(1, 1, self.trans_dim))\n\n        self.blocks = GPT_extractor(\n            embed_dim=self.encoder_dims,\n            num_heads=self.num_heads,\n            num_layers=self.depth,\n            trans_dim=self.trans_dim,\n            group_size=self.group_size,\n            drop_path_rate=self.drop_path_rate\n        )\n\n        self.generator_blocks = GPT_generator(\n            embed_dim=self.encoder_dims,\n            num_heads=self.num_heads,\n            num_layers=self.decoder_depth,\n            trans_dim=self.trans_dim,\n            group_size=self.group_size,\n            drop_path_rate=self.drop_path_rate\n     
   )\n\n        self.norm = nn.LayerNorm(self.trans_dim)\n\n        self.label_conv = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False),\n                                        nn.BatchNorm1d(64),\n                                        nn.LeakyReLU(0.2))\n\n        self.propagation_0 = PointNetFeaturePropagation(in_channel=3 * self.encoder_dims + 3,\n                                                        mlp=[self.trans_dim * 4, self.prop_dim])\n\n        self.convs1 = nn.Conv1d(6*self.encoder_dims +\n                                64 + self.prop_dim, self.label_dim1, 1)\n        self.dp1 = nn.Dropout(0.5)\n        self.convs2 = nn.Conv1d(self.label_dim1, self.label_dim2, 1)\n        self.convs3 = nn.Conv1d(self.label_dim2, self.cls_dim, 1)\n        self.bns1 = nn.BatchNorm1d(self.label_dim1)\n        self.bns2 = nn.BatchNorm1d(self.label_dim2)\n\n        self.relu = nn.ReLU()\n\n        self.loss_func_p1 = ChamferDistanceL1().cuda()\n        self.loss_func_p2 = ChamferDistanceL2().cuda()\n        self.loss_ce = nn.CrossEntropyLoss()  # classification loss used by get_loss_acc\n\n    def get_loss_acc(self, ret, gt):\n        loss = self.loss_ce(ret, gt.long())\n        pred = ret.argmax(-1)\n        acc = (pred == gt).sum() / float(gt.size(0))\n        return loss, acc * 100\n\n    def load_model_from_ckpt(self, bert_ckpt_path):\n        if bert_ckpt_path is not None:\n            ckpt = torch.load(bert_ckpt_path)\n            base_ckpt = {k.replace(\"module.\", \"\"): v for k,\n                         v in ckpt['base_model'].items()}\n\n            for k in list(base_ckpt.keys()):\n                if k.startswith('GPT_Transformer'):\n                    base_ckpt[k[len('GPT_Transformer.'):]] = base_ckpt[k]\n                    del base_ckpt[k]\n                elif k.startswith('base_model'):\n                    base_ckpt[k[len('base_model.'):]] = base_ckpt[k]\n                    del base_ckpt[k]\n\n            incompatible = self.load_state_dict(base_ckpt, strict=False)\n\n            if incompatible.missing_keys:\n                print('missing_keys')\n                print(\n                    get_missing_parameters_message(incompatible.missing_keys)\n                )\n            if incompatible.unexpected_keys:\n                print('unexpected_keys')\n                print(\n                    get_unexpected_parameters_message(\n                        incompatible.unexpected_keys)\n                )\n\n            print(\n                f'[Transformer] Successfully loaded the ckpt from {bert_ckpt_path}')\n\n    def forward(self, pts, cls_label):\n        B, C, N = pts.shape\n        pts = pts.transpose(-1, -2)  # B N 3\n\n        neighborhood, center = self.group_divider(pts)\n        group_input_tokens = self.encoder(neighborhood)  # B G C\n\n        B = group_input_tokens.shape[0]\n\n        cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)\n        cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)\n\n        pos = self.pos_embed(center)\n        sos_pos = self.sos_pos.expand(group_input_tokens.size(0), -1, -1)\n        pos = torch.cat([sos_pos, pos], dim=1)\n\n        relative_position = center[:, 1:, :] - center[:, :-1, :]\n        relative_norm = torch.norm(relative_position, dim=-1, keepdim=True)\n        relative_direction = relative_position / relative_norm\n        position = torch.cat(\n            [center[:, 0, :].unsqueeze(1), relative_direction], dim=1)\n        pos_relative = self.pos_embed(position)\n\n        x = torch.cat((cls_tokens, group_input_tokens), dim=1)\n        pos = 
torch.cat((cls_pos, pos), dim=1)\n        # transformer\n        encoded_features, feature_list = self.blocks(x, pos, classify=True)\n\n        encoded_features = torch.cat(\n            [encoded_features[:, 0, :].unsqueeze(1), encoded_features[:, 2:-1, :]], dim=1)\n\n        rebuild_points = self.generator_blocks(\n            encoded_features, pos_relative)\n\n        neighborhood = neighborhood + center.unsqueeze(2)\n\n        gt_points = neighborhood.reshape(\n            B*(self.num_group), self.group_size, 3)\n\n        loss1 = self.loss_func_p1(rebuild_points, gt_points)\n        loss2 = self.loss_func_p2(rebuild_points, gt_points)\n\n        feature_list = [self.norm(x).transpose(-1, -2).contiguous()\n                        for x in feature_list]\n        x = torch.cat(\n            (feature_list[0], feature_list[1], feature_list[2]), dim=1)  # 1152\n        x_max = torch.max(x, 2)[0]\n        x_avg = torch.mean(x, 2)\n        x_max_feature = x_max.view(B, -1).unsqueeze(-1).repeat(1, 1, N)\n        x_avg_feature = x_avg.view(B, -1).unsqueeze(-1).repeat(1, 1, N)\n        cls_label_one_hot = cls_label.view(B, 16, 1)\n        cls_label_feature = self.label_conv(cls_label_one_hot).repeat(1, 1, N)\n        x_global_feature = torch.cat(\n            (x_max_feature, x_avg_feature, cls_label_feature), 1)  # 1152*2 + 64\n\n        f_level_0 = self.propagation_0(\n            pts.transpose(-1, -2), center.transpose(-1, -2), pts.transpose(-1, -2), x)\n\n        x = torch.cat((f_level_0, x_global_feature), 1)\n        x = self.relu(self.bns1(self.convs1(x)))\n        x = self.dp1(x)\n        x = self.relu(self.bns2(self.convs2(x)))\n        x = self.convs3(x)\n        x = F.log_softmax(x, dim=1)\n        x = x.permute(0, 2, 1)\n        return x\n\n\nclass get_loss(nn.Module):\n    def __init__(self):\n        super(get_loss, self).__init__()\n\n    def forward(self, pred, target):\n        total_loss = F.nll_loss(pred, target)\n        return total_loss\n"
  },
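For reference, the Group module above can be exercised on its own; the sketch below only illustrates its input/output contract. It assumes a CUDA machine with the `knn_cuda` and `pointnet2_ops` extensions built (importing `pt` pulls them in), and the tensor sizes are illustrative, not values from the repo's configs.

```python
# Sketch: sequence-ordered patchification with the Group module above.
import torch
from pt import Group  # assumes segmentation/models is on PYTHONPATH

group = Group(num_group=128, group_size=32).cuda()
pts = torch.rand(2, 2048, 3).cuda()    # B N 3 raw point clouds
neighborhood, center = group(pts)      # B G M 3 local patches, B G 3 patch centers
assert neighborhood.shape == (2, 128, 32, 3)
assert center.shape == (2, 128, 3)
# Patches come back ordered by simplied_morton_sorting, a greedy
# nearest-neighbor chain over the centers, so the GPT-style extractor
# sees a spatially coherent token sequence.
```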
  {
    "path": "segmentation/models/z_order.py",
    "content": "import numpy as np\n\n\ndef round_to_int_32(data):\n    \"\"\"\n    Takes a Numpy array of float values between\n    -1 and 1, and rounds them to significant\n    32-bit integer values, to be used in the\n    morton code computation\n\n    :param data: multidimensional numpy array\n    :return: same as data but in 32-bit int format\n    \"\"\"\n    # first we rescale points to 0-512\n    min_data = np.abs(np.min(data)-0.5)\n    data = 256*(data + min_data)\n    # now convert to int\n    data = np.round(2 ** 21 - data).astype(dtype=np.int32)\n\n    return data\n\n\ndef split_by_3(x):\n    \"\"\"\n    Method to separate bits of a 32-bit integer\n    by 3 positions apart, using the magic bits\n    https://www.forceflow.be/2013/10/07/morton-encodingdecoding-through-bit-interleaving-implementations/\n\n    :param x: 32-bit integer\n    :return: x with bits separated\n    \"\"\"\n    # we only look at 21 bits, since we want to generate\n    # a 64-bit code eventually (3 x 21 bits = 63 bits, which\n    # is the maximum we can fit in a 64-bit code)\n    x &= 0x1fffff  # only take first 21 bits\n    # shift left 32 bits, OR with self, and 00011111000000000000000000000000000000001111111111111111\n    x = (x | (x << 32)) & 0x1f00000000ffff\n    # shift left 16 bits, OR with self, and 00011111000000000000000011111111000000000000000011111111\n    x = (x | (x << 16)) & 0x1f0000ff0000ff\n    # shift left 8 bits, OR with self, and 0001000000001111000000001111000000001111000000001111000000000000\n    x = (x | (x << 8)) & 0x100f00f00f00f00f\n    # shift left 4 bits, OR with self, and 0001000011000011000011000011000011000011000011000011000100000000\n    x = (x | (x << 4)) & 0x10c30c30c30c30c3\n    # shift left 2 bits, OR with self, and 0001001001001001001001001001001001001001001001001001001001001001\n    x = (x | (x << 2)) & 0x1249249249249249\n\n    return x\n\n\ndef get_z_order(x, y, z):\n    \"\"\"\n    Given 3 arrays of corresponding x, y, z\n    coordinates, compute the morton (or z) code for\n    each point and return an index array\n    We compute the Morton order as follows:\n        1- Split all coordinates by 3 (add 2 zeros between bits)\n        2- Shift bits left by 1 for y and 2 for z\n        3- Interleave x, shifted y, and shifted z\n    The mordon order is the final interleaved bit sequence\n\n    :param x: x coordinates\n    :param y: y coordinates\n    :param z: z coordinates\n    :return: index array with morton code\n    \"\"\"\n    res = 0\n    res |= split_by_3(x) | split_by_3(y) << 1 | split_by_3(z) << 2\n\n    return res\n\n\ndef get_z_values(data):\n    \"\"\"\n    Computes the z values for a point array\n    :param data: Nx3 array of x, y, and z location\n\n    :return: Nx1 array of z values\n    \"\"\"\n    points_round = round_to_int_32(data)  # convert to int\n    z = get_z_order(points_round[:, 0], points_round[:, 1], points_round[:, 2])\n\n    return z\n"
  },
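The Morton helpers above are plain NumPy; a toy ordering example follows. Note that `split_by_3` interleaves 21 bits per axis into a 63-bit code, so the bit arithmetic needs 64-bit integers; the explicit `int64` cast below is a defensive assumption on my part, not something the module itself enforces.

```python
# Toy example: sort a small point set along a Morton (z-order) curve.
import numpy as np
from z_order import get_z_order, round_to_int_32

pts = np.random.uniform(-1, 1, size=(8, 3))   # coordinates roughly in [-1, 1]
q = round_to_int_32(pts).astype(np.int64)     # quantize, then widen for 63-bit codes
z = get_z_order(q[:, 0], q[:, 1], q[:, 2])    # one interleaved code per point
order = np.argsort(z)                         # spatially coherent traversal order
pts_sorted = pts[order]
```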
  {
    "path": "segmentation/pointnet_util.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom time import time\nimport numpy as np\n\n\n# reference https://github.com/yanx27/Pointnet_Pointnet2_pytorch, modified by Yang You\n\n\ndef timeit(tag, t):\n    print(\"{}: {}s\".format(tag, time() - t))\n    return time()\n\ndef pc_normalize(pc):\n    centroid = np.mean(pc, axis=0)\n    pc = pc - centroid\n    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n    pc = pc / m\n    return pc\n\ndef square_distance(src, dst):\n    \"\"\"\n    Calculate Euclid distance between each two points.\n    src^T * dst = xn * xm + yn * ym + zn * zm；\n    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;\n    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;\n    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2\n         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst\n    Input:\n        src: source points, [B, N, C]\n        dst: target points, [B, M, C]\n    Output:\n        dist: per-point square distance, [B, N, M]\n    \"\"\"\n    return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1)\n\n\ndef index_points(points, idx):\n    \"\"\"\n    Input:\n        points: input points data, [B, N, C]\n        idx: sample index data, [B, S, [K]]\n    Return:\n        new_points:, indexed points data, [B, S, [K], C]\n    \"\"\"\n    raw_size = idx.size()\n    idx = idx.reshape(raw_size[0], -1)\n    res = torch.gather(points, 1, idx[..., None].expand(-1, -1, points.size(-1)))\n    return res.reshape(*raw_size, -1)\n\n\ndef farthest_point_sample(xyz, npoint):\n    \"\"\"\n    Input:\n        xyz: pointcloud data, [B, N, 3]\n        npoint: number of samples\n    Return:\n        centroids: sampled pointcloud index, [B, npoint]\n    \"\"\"\n    device = xyz.device\n    B, N, C = xyz.shape\n    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)\n    distance = torch.ones(B, N).to(device) * 1e10\n    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)\n    batch_indices = torch.arange(B, dtype=torch.long).to(device)\n    for i in range(npoint):\n        centroids[:, i] = farthest\n        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)\n        dist = torch.sum((xyz - centroid) ** 2, -1)\n        distance = torch.min(distance, dist)\n        farthest = torch.max(distance, -1)[1]\n    return centroids\n\n\ndef query_ball_point(radius, nsample, xyz, new_xyz):\n    \"\"\"\n    Input:\n        radius: local region radius\n        nsample: max sample number in local region\n        xyz: all points, [B, N, 3]\n        new_xyz: query points, [B, S, 3]\n    Return:\n        group_idx: grouped points index, [B, S, nsample]\n    \"\"\"\n    device = xyz.device\n    B, N, C = xyz.shape\n    _, S, _ = new_xyz.shape\n    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])\n    sqrdists = square_distance(new_xyz, xyz)\n    group_idx[sqrdists > radius ** 2] = N\n    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]\n    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])\n    mask = group_idx == N\n    group_idx[mask] = group_first[mask]\n    return group_idx\n\n\ndef sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, knn=False):\n    \"\"\"\n    Input:\n        npoint:\n        radius:\n        nsample:\n        xyz: input points position data, [B, N, 3]\n        points: input points data, [B, N, D]\n    Return:\n        new_xyz: sampled points position data, [B, npoint, nsample, 3]\n        new_points: sampled points data, [B, 
npoint, nsample, 3+D]\n    \"\"\"\n    B, N, C = xyz.shape\n    S = npoint\n    fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]\n    torch.cuda.empty_cache()\n    new_xyz = index_points(xyz, fps_idx)\n    torch.cuda.empty_cache()\n    if knn:\n        dists = square_distance(new_xyz, xyz)  # B x npoint x N\n        idx = dists.argsort()[:, :, :nsample]  # B x npoint x K\n    else:\n        idx = query_ball_point(radius, nsample, xyz, new_xyz)\n    torch.cuda.empty_cache()\n    grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]\n    torch.cuda.empty_cache()\n    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)\n    torch.cuda.empty_cache()\n\n    if points is not None:\n        grouped_points = index_points(points, idx)\n        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]\n    else:\n        new_points = grouped_xyz_norm\n    if returnfps:\n        return new_xyz, new_points, grouped_xyz, fps_idx\n    else:\n        return new_xyz, new_points\n\n\ndef sample_and_group_all(xyz, points):\n    \"\"\"\n    Input:\n        xyz: input points position data, [B, N, 3]\n        points: input points data, [B, N, D]\n    Return:\n        new_xyz: sampled points position data, [B, 1, 3]\n        new_points: sampled points data, [B, 1, N, 3+D]\n    \"\"\"\n    device = xyz.device\n    B, N, C = xyz.shape\n    new_xyz = torch.zeros(B, 1, C).to(device)\n    grouped_xyz = xyz.view(B, 1, N, C)\n    if points is not None:\n        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)\n    else:\n        new_points = grouped_xyz\n    return new_xyz, new_points\n\n\nclass PointNetSetAbstraction(nn.Module):\n    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all, knn=False):\n        super(PointNetSetAbstraction, self).__init__()\n        self.npoint = npoint\n        self.radius = radius\n        self.nsample = nsample\n        self.knn = knn\n        self.mlp_convs = nn.ModuleList()\n        self.mlp_bns = nn.ModuleList()\n        last_channel = in_channel\n        for out_channel in mlp:\n            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))\n            self.mlp_bns.append(nn.BatchNorm2d(out_channel))\n            last_channel = out_channel\n        self.group_all = group_all\n\n    def forward(self, xyz, points):\n        \"\"\"\n        Input:\n            xyz: input points position data, [B, N, C]\n            points: input points data, [B, N, C]\n        Return:\n            new_xyz: sampled points position data, [B, S, C]\n            new_points_concat: sample points feature data, [B, S, D']\n        \"\"\"\n        if self.group_all:\n            new_xyz, new_points = sample_and_group_all(xyz, points)\n        else:\n            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points, knn=self.knn)\n        # new_xyz: sampled points position data, [B, npoint, C]\n        # new_points: sampled points data, [B, npoint, nsample, C+D]\n        new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]\n        for i, conv in enumerate(self.mlp_convs):\n            bn = self.mlp_bns[i]\n            new_points =  F.relu(bn(conv(new_points)))\n\n        new_points = torch.max(new_points, 2)[0].transpose(1, 2)\n        return new_xyz, new_points\n\n\nclass PointNetSetAbstractionMsg(nn.Module):\n    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list, knn=False):\n        
super(PointNetSetAbstractionMsg, self).__init__()\n        self.npoint = npoint\n        self.radius_list = radius_list\n        self.nsample_list = nsample_list\n        self.knn = knn\n        self.conv_blocks = nn.ModuleList()\n        self.bn_blocks = nn.ModuleList()\n        for i in range(len(mlp_list)):\n            convs = nn.ModuleList()\n            bns = nn.ModuleList()\n            last_channel = in_channel + 3\n            for out_channel in mlp_list[i]:\n                convs.append(nn.Conv2d(last_channel, out_channel, 1))\n                bns.append(nn.BatchNorm2d(out_channel))\n                last_channel = out_channel\n            self.conv_blocks.append(convs)\n            self.bn_blocks.append(bns)\n\n    def forward(self, xyz, points, seed_idx=None):\n        \"\"\"\n        Input:\n            xyz: input points position data, [B, N, C]\n            points: input points data, [B, N, D]\n        Return:\n            new_xyz: sampled points position data, [B, S, C]\n            new_points_concat: sample points feature data, [B, S, D']\n        \"\"\"\n\n        B, N, C = xyz.shape\n        S = self.npoint\n        new_xyz = index_points(xyz, farthest_point_sample(xyz, S) if seed_idx is None else seed_idx)\n        new_points_list = []\n        for i, radius in enumerate(self.radius_list):\n            K = self.nsample_list[i]\n            if self.knn:\n                dists = square_distance(new_xyz, xyz)  # B x npoint x N\n                group_idx = dists.argsort()[:, :, :K]  # B x npoint x K\n            else:\n                group_idx = query_ball_point(radius, K, xyz, new_xyz)\n            grouped_xyz = index_points(xyz, group_idx)\n            grouped_xyz -= new_xyz.view(B, S, 1, C)\n            if points is not None:\n                grouped_points = index_points(points, group_idx)\n                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)\n            else:\n                grouped_points = grouped_xyz\n\n            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]\n            for j in range(len(self.conv_blocks[i])):\n                conv = self.conv_blocks[i][j]\n                bn = self.bn_blocks[i][j]\n                grouped_points = F.relu(bn(conv(grouped_points)))\n            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]\n            new_points_list.append(new_points)\n\n        new_points_concat = torch.cat(new_points_list, dim=1).transpose(1, 2)\n        return new_xyz, new_points_concat\n\n\n# Note: this function swaps N and C\nclass PointNetFeaturePropagation(nn.Module):\n    def __init__(self, in_channel, mlp):\n        super(PointNetFeaturePropagation, self).__init__()\n        self.mlp_convs = nn.ModuleList()\n        self.mlp_bns = nn.ModuleList()\n        last_channel = in_channel\n        for out_channel in mlp:\n            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))\n            self.mlp_bns.append(nn.BatchNorm1d(out_channel))\n            last_channel = out_channel\n\n    def forward(self, xyz1, xyz2, points1, points2):\n        \"\"\"\n        Input:\n            xyz1: input points position data, [B, C, N]\n            xyz2: sampled input points position data, [B, C, S]\n            points1: input points data, [B, D, N]\n            points2: input points data, [B, D, S]\n        Return:\n            new_points: upsampled points data, [B, D', N]\n        \"\"\"\n        xyz1 = xyz1.permute(0, 2, 1)\n        xyz2 = xyz2.permute(0, 2, 1)\n\n        points2 = 
points2.permute(0, 2, 1)\n        B, N, C = xyz1.shape\n        _, S, _ = xyz2.shape\n\n        if S == 1:\n            interpolated_points = points2.repeat(1, N, 1)\n        else:\n            dists = square_distance(xyz1, xyz2)\n            dists, idx = dists.sort(dim=-1)\n            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]\n\n            dist_recip = 1.0 / (dists + 1e-8)\n            norm = torch.sum(dist_recip, dim=2, keepdim=True)\n            weight = dist_recip / norm\n            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)\n\n        if points1 is not None:\n            points1 = points1.permute(0, 2, 1)\n            new_points = torch.cat([points1, interpolated_points], dim=-1)\n        else:\n            new_points = interpolated_points\n\n        new_points = new_points.permute(0, 2, 1)\n        for i, conv in enumerate(self.mlp_convs):\n            bn = self.mlp_bns[i]\n            new_points = F.relu(bn(conv(new_points)))\n        return new_points\n\n"
  },
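Unlike the copy in `segmentation/models/`, this variant keeps tensors channels-last ([B, N, C]) through set abstraction and adds an optional kNN grouping mode. A small smoke test, with illustrative sizes of my own choosing:

```python
# Smoke test for the channels-last set-abstraction layer with kNN grouping.
import torch
from pointnet_util import PointNetSetAbstraction

sa = PointNetSetAbstraction(npoint=64, radius=0.2, nsample=16,
                            in_channel=3 + 3, mlp=[32, 64],
                            group_all=False, knn=True)
xyz = torch.rand(2, 512, 3)      # positions, [B, N, 3]
feats = torch.rand(2, 512, 3)    # per-point features, [B, N, D]
new_xyz, new_feats = sa(xyz, feats)
print(new_xyz.shape, new_feats.shape)   # [2, 64, 3] and [2, 64, 64]
```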
  {
    "path": "segmentation/provider.py",
    "content": "import numpy as np\n\ndef normalize_data(batch_data):\n    \"\"\" Normalize the batch data, use coordinates of the block centered at origin,\n        Input:\n            BxNxC array\n        Output:\n            BxNxC array\n    \"\"\"\n    B, N, C = batch_data.shape\n    normal_data = np.zeros((B, N, C))\n    for b in range(B):\n        pc = batch_data[b]\n        centroid = np.mean(pc, axis=0)\n        pc = pc - centroid\n        m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))\n        pc = pc / m\n        normal_data[b] = pc\n    return normal_data\n\n\ndef shuffle_data(data, labels):\n    \"\"\" Shuffle data and labels.\n        Input:\n          data: B,N,... numpy array\n          label: B,... numpy array\n        Return:\n          shuffled data, label and shuffle indices\n    \"\"\"\n    idx = np.arange(len(labels))\n    np.random.shuffle(idx)\n    return data[idx, ...], labels[idx], idx\n\ndef shuffle_points(batch_data):\n    \"\"\" Shuffle orders of points in each point cloud -- changes FPS behavior.\n        Use the same shuffling idx for the entire batch.\n        Input:\n            BxNxC array\n        Output:\n            BxNxC array\n    \"\"\"\n    idx = np.arange(batch_data.shape[1])\n    np.random.shuffle(idx)\n    return batch_data[:,idx,:]\n\ndef rotate_point_cloud(batch_data):\n    \"\"\" Randomly rotate the point clouds to augument the dataset\n        rotation is per shape based along up direction\n        Input:\n          BxNx3 array, original batch of point clouds\n        Return:\n          BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array([[cosval, 0, sinval],\n                                    [0, 1, 0],\n                                    [-sinval, 0, cosval]])\n        shape_pc = batch_data[k, ...]\n        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)\n    return rotated_data\n\ndef rotate_point_cloud_z(batch_data):\n    \"\"\" Randomly rotate the point clouds to augument the dataset\n        rotation is per shape based along up direction\n        Input:\n          BxNx3 array, original batch of point clouds\n        Return:\n          BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array([[cosval, sinval, 0],\n                                    [-sinval, cosval, 0],\n                                    [0, 0, 1]])\n        shape_pc = batch_data[k, ...]\n        rotated_data[k, ...] 
= np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)\n    return rotated_data\n\ndef rotate_point_cloud_with_normal(batch_xyz_normal):\n    ''' Randomly rotate XYZ, normal point cloud.\n        Input:\n            batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal\n        Output:\n            B,N,6, rotated XYZ, normal point cloud\n    '''\n    for k in range(batch_xyz_normal.shape[0]):\n        rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array([[cosval, 0, sinval],\n                                    [0, 1, 0],\n                                    [-sinval, 0, cosval]])\n        shape_pc = batch_xyz_normal[k,:,0:3]\n        shape_normal = batch_xyz_normal[k,:,3:6]\n        batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)\n        batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix)\n    return batch_xyz_normal\n\ndef rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18):\n    \"\"\" Randomly perturb the point clouds by small rotations\n        Input:\n          BxNx6 array, original batch of point clouds and point normals\n        Return:\n          BxNx6 array, rotated batch of point clouds with normals\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)\n        Rx = np.array([[1,0,0],\n                       [0,np.cos(angles[0]),-np.sin(angles[0])],\n                       [0,np.sin(angles[0]),np.cos(angles[0])]])\n        Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],\n                       [0,1,0],\n                       [-np.sin(angles[1]),0,np.cos(angles[1])]])\n        Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],\n                       [np.sin(angles[2]),np.cos(angles[2]),0],\n                       [0,0,1]])\n        R = np.dot(Rz, np.dot(Ry,Rx))\n        shape_pc = batch_data[k,:,0:3]\n        shape_normal = batch_data[k,:,3:6]\n        rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R)\n        rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R)\n    return rotated_data\n\n\ndef rotate_point_cloud_by_angle(batch_data, rotation_angle):\n    \"\"\" Rotate the point cloud around the up direction by a given angle.\n        Input:\n          BxNx3 array, original batch of point clouds\n        Return:\n          BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        #rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array([[cosval, 0, sinval],\n                                    [0, 1, 0],\n                                    [-sinval, 0, cosval]])\n        shape_pc = batch_data[k,:,0:3]\n        rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)\n    return rotated_data\n\ndef rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle):\n    \"\"\" Rotate the point cloud around the up direction by a given angle.\n        Input:\n          BxNx6 array, original batch of point clouds with normal\n          scalar, angle of rotation\n        Return:\n          BxNx6 array, rotated batch of point clouds with normal\n    
\"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        #rotation_angle = np.random.uniform() * 2 * np.pi\n        cosval = np.cos(rotation_angle)\n        sinval = np.sin(rotation_angle)\n        rotation_matrix = np.array([[cosval, 0, sinval],\n                                    [0, 1, 0],\n                                    [-sinval, 0, cosval]])\n        shape_pc = batch_data[k,:,0:3]\n        shape_normal = batch_data[k,:,3:6]\n        rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)\n        rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix)\n    return rotated_data\n\n\n\ndef rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):\n    \"\"\" Randomly perturb the point clouds by small rotations\n        Input:\n          BxNx3 array, original batch of point clouds\n        Return:\n          BxNx3 array, rotated batch of point clouds\n    \"\"\"\n    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)\n    for k in range(batch_data.shape[0]):\n        angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)\n        Rx = np.array([[1,0,0],\n                       [0,np.cos(angles[0]),-np.sin(angles[0])],\n                       [0,np.sin(angles[0]),np.cos(angles[0])]])\n        Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],\n                       [0,1,0],\n                       [-np.sin(angles[1]),0,np.cos(angles[1])]])\n        Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],\n                       [np.sin(angles[2]),np.cos(angles[2]),0],\n                       [0,0,1]])\n        R = np.dot(Rz, np.dot(Ry,Rx))\n        shape_pc = batch_data[k, ...]\n        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)\n    return rotated_data\n\n\ndef jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):\n    \"\"\" Randomly jitter points. jittering is per point.\n        Input:\n          BxNx3 array, original batch of point clouds\n        Return:\n          BxNx3 array, jittered batch of point clouds\n    \"\"\"\n    B, N, C = batch_data.shape\n    assert(clip > 0)\n    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)\n    jittered_data += batch_data\n    return jittered_data\n\ndef shift_point_cloud(batch_data, shift_range=0.1):\n    \"\"\" Randomly shift point cloud. Shift is per point cloud.\n        Input:\n          BxNx3 array, original batch of point clouds\n        Return:\n          BxNx3 array, shifted batch of point clouds\n    \"\"\"\n    B, N, C = batch_data.shape\n    shifts = np.random.uniform(-shift_range, shift_range, (B,3))\n    for batch_index in range(B):\n        batch_data[batch_index,:,:] += shifts[batch_index,:]\n    return batch_data\n\n\ndef random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):\n    \"\"\" Randomly scale the point cloud. 
Scale is per point cloud.\n        Input:\n            BxNx3 array, original batch of point clouds\n        Return:\n            BxNx3 array, scaled batch of point clouds\n    \"\"\"\n    B, N, C = batch_data.shape\n    scales = np.random.uniform(scale_low, scale_high, B)\n    for batch_index in range(B):\n        batch_data[batch_index,:,:] *= scales[batch_index]\n    return batch_data\n\ndef random_point_dropout(batch_pc, max_dropout_ratio=0.875):\n    ''' batch_pc: BxNx3 '''\n    for b in range(batch_pc.shape[0]):\n        dropout_ratio =  np.random.random()*max_dropout_ratio # 0~0.875\n        drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0]\n        if len(drop_idx)>0:\n            batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point\n    return batch_pc\n"
  },
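These augmentations are designed to be chained on a NumPy batch during training. A representative pipeline follows; the ordering and default parameters here are a common choice, not something this file prescribes, and the batch itself is toy data.

```python
# Example training-time augmentation chain built from the helpers above.
import numpy as np
import provider

batch = np.random.randn(8, 2048, 3).astype(np.float32)   # toy B x N x 3 batch
batch = provider.random_point_dropout(batch)             # simulate occlusion
batch = provider.random_scale_point_cloud(batch)         # uniform per-cloud rescale
batch = provider.shift_point_cloud(batch)                # per-cloud translation
batch = provider.rotate_perturbation_point_cloud(batch)  # small random rotations
batch = provider.jitter_point_cloud(batch)               # per-point Gaussian noise
```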
  {
    "path": "tools/__init__.py",
    "content": "# from .runner import run_net\nfrom .runner import test_net\nfrom .runner_pretrain import run_net as pretrain_run_net\nfrom .runner_finetune import run_net as finetune_run_net\nfrom .runner_finetune import test_net as test_run_net"
  },
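The re-exports above give a single import surface for the three training modes. Below is a sketch of how a launcher might dispatch them; the `args` flags are assumptions for illustration, and the repo's own main script may wire this differently.

```python
# Hypothetical dispatch over the entry points re-exported by tools/__init__.py.
from tools import pretrain_run_net, finetune_run_net, test_run_net

def dispatch(args, config):
    if args.test:                 # evaluation only
        test_run_net(args, config)
    elif args.finetune_model:     # fine-tune from a pretrained checkpoint
        finetune_run_net(args, config)
    else:                         # self-supervised pretraining
        pretrain_run_net(args, config)
```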
  {
    "path": "tools/builder.py",
    "content": "import os\nimport sys\n# online package\nimport torch\n# optimizer\nimport torch.optim as optim\n# dataloader\nfrom datasets import build_dataset_from_cfg\nfrom models import build_model_from_cfg\n# utils\nfrom utils.logger import *\nfrom utils.misc import *\nfrom timm.scheduler import CosineLRScheduler\n\n\ndef dataset_builder(args, config):\n    dataset = build_dataset_from_cfg(config._base_, config.others)\n    shuffle = config.others.subset == 'train'\n    if args.distributed:\n        sampler = torch.utils.data.distributed.DistributedSampler(\n            dataset, shuffle=shuffle)\n        dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.others.bs,\n                                                 num_workers=int(\n                                                     args.num_workers),\n                                                 drop_last=config.others.subset == 'train',\n                                                 worker_init_fn=worker_init_fn,\n                                                 sampler=sampler)\n    else:\n        sampler = None\n        dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.others.bs,\n                                                 shuffle=shuffle,\n                                                 drop_last=config.others.subset == 'train',\n                                                 num_workers=int(\n                                                     args.num_workers),\n                                                 worker_init_fn=worker_init_fn)\n    return sampler, dataloader\n\n\ndef model_builder(config):\n    model = build_model_from_cfg(config)\n    return model\n\n\ndef build_opti_sche(base_model, config):\n    opti_config = config.optimizer\n    if opti_config.type == 'AdamW':\n        def add_weight_decay(model, weight_decay=1e-5, skip_list=()):\n            decay = []\n            no_decay = []\n            for name, param in model.module.named_parameters():\n                if not param.requires_grad:\n                    continue  # frozen weights\n                if len(param.shape) == 1 or name.endswith(\".bias\") or 'token' in name or name in skip_list:\n                    # print(name)\n                    no_decay.append(param)\n                else:\n                    decay.append(param)\n            return [\n                {'params': no_decay, 'weight_decay': 0.},\n                {'params': decay, 'weight_decay': weight_decay}]\n        param_groups = add_weight_decay(\n            base_model, weight_decay=opti_config.kwargs.weight_decay)\n        optimizer = optim.AdamW(param_groups, **opti_config.kwargs)\n    elif opti_config.type == 'Adam':\n        optimizer = optim.Adam(base_model.parameters(), **opti_config.kwargs)\n    elif opti_config.type == 'SGD':\n        optimizer = optim.SGD(base_model.parameters(),\n                              nesterov=True, **opti_config.kwargs)\n    else:\n        raise NotImplementedError()\n\n    sche_config = config.scheduler\n    if sche_config.type == 'LambdaLR':\n        scheduler = build_lambda_sche(optimizer, sche_config.kwargs)  # misc.py\n    elif sche_config.type == 'CosLR':\n        scheduler = CosineLRScheduler(optimizer,\n                                      t_initial=sche_config.kwargs.epochs,\n                                      # t_mul=1,\n                                      lr_min=1e-6,\n                                      cycle_decay=0.1,  # decay_rate\n                                      
warmup_lr_init=1e-6,\n                                      warmup_t=sche_config.kwargs.initial_epochs,\n                                      cycle_limit=1,\n                                      t_in_epochs=True)\n    elif sche_config.type == 'StepLR':\n        scheduler = torch.optim.lr_scheduler.StepLR(\n            optimizer, **sche_config.kwargs)\n    elif sche_config.type == 'function':\n        scheduler = None\n    else:\n        raise NotImplementedError()\n\n    if config.get('bnmscheduler') is not None:\n        bnsche_config = config.bnmscheduler\n        if bnsche_config.type == 'Lambda':\n            bnscheduler = build_lambda_bnsche(\n                base_model, bnsche_config.kwargs)  # misc.py\n        scheduler = [scheduler, bnscheduler]\n\n    return optimizer, scheduler\n\n\ndef resume_model(base_model, args, logger=None):\n    ckpt_path = os.path.join(args.experiment_path, 'ckpt-last.pth')\n    if not os.path.exists(ckpt_path):\n        print_log(\n            f'[RESUME INFO] no checkpoint file from path {ckpt_path}...', logger=logger)\n        return 0, 0\n    print_log(\n        f'[RESUME INFO] Loading model weights from {ckpt_path}...', logger=logger)\n\n    # load state dict\n    map_location = {'cuda:%d' % 0: 'cuda:%d' % args.local_rank}\n    state_dict = torch.load(ckpt_path, map_location=map_location)\n    # parameter resume of base model\n    # if args.local_rank == 0:\n    base_ckpt = {k.replace(\"module.\", \"\"): v for k,\n                 v in state_dict['base_model'].items()}\n    base_model.load_state_dict(base_ckpt, strict=True)\n\n    # parameter\n    start_epoch = state_dict['epoch'] + 1\n    best_metrics = state_dict['best_metrics']\n    if not isinstance(best_metrics, dict):\n        best_metrics = best_metrics.state_dict()\n    # print(best_metrics)\n\n    print_log(\n        f'[RESUME INFO] resume ckpts @ {start_epoch - 1} epoch( best_metrics = {str(best_metrics):s})', logger=logger)\n    return start_epoch, best_metrics\n\n\ndef resume_optimizer(optimizer, args, logger=None):\n    ckpt_path = os.path.join(args.experiment_path, 'ckpt-last.pth')\n    if not os.path.exists(ckpt_path):\n        print_log(\n            f'[RESUME INFO] no checkpoint file from path {ckpt_path}...', logger=logger)\n        return 0, 0, 0\n    print_log(\n        f'[RESUME INFO] Loading optimizer from {ckpt_path}...', logger=logger)\n    # load state dict\n    state_dict = torch.load(ckpt_path, map_location='cpu')\n    # optimizer\n    optimizer.load_state_dict(state_dict['optimizer'])\n\n\ndef save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, prefix, args, logger=None):\n    if args.local_rank == 0:\n        torch.save({\n            'base_model': base_model.module.state_dict() if args.distributed else base_model.state_dict(),\n            'optimizer': optimizer.state_dict(),\n            'epoch': epoch,\n            'metrics': metrics.state_dict() if metrics is not None else dict(),\n            'best_metrics': best_metrics.state_dict() if best_metrics is not None else dict(),\n        }, os.path.join(args.experiment_path, prefix + '.pth'))\n        print_log(\n            f\"Save checkpoint at {os.path.join(args.experiment_path, prefix + '.pth')}\", logger=logger)\n\n\ndef load_model(base_model, ckpt_path, logger=None):\n    if not os.path.exists(ckpt_path):\n        raise NotImplementedError(\n            'no checkpoint file from path %s...' 
% ckpt_path)\n    print_log(f'Loading weights from {ckpt_path}...', logger=logger)\n\n    # load state dict\n    state_dict = torch.load(ckpt_path, map_location='cpu')\n    # parameter resume of base model\n    if state_dict.get('model') is not None:\n        base_ckpt = {k.replace(\"module.\", \"\"): v for k,\n                     v in state_dict['model'].items()}\n    elif state_dict.get('base_model') is not None:\n        base_ckpt = {k.replace(\"module.\", \"\"): v for k,\n                     v in state_dict['base_model'].items()}\n    else:\n        raise RuntimeError('mismatch of ckpt weight')\n    base_model.load_state_dict(base_ckpt, strict=True)\n\n    epoch = -1\n    if state_dict.get('epoch') is not None:\n        epoch = state_dict['epoch']\n    if state_dict.get('metrics') is not None:\n        metrics = state_dict['metrics']\n        if not isinstance(metrics, dict):\n            metrics = metrics.state_dict()\n    else:\n        metrics = 'No Metrics'\n    print_log(\n        f'ckpts @ {epoch} epoch( performance = {str(metrics):s})', logger=logger)\n    return\n"
  },
  {
    "path": "tools/runner.py",
    "content": "import torch\nimport torch.nn as nn\nimport os\nimport json\nfrom tools import builder\nfrom utils import misc, dist_utils\nimport time\nfrom utils.logger import *\n\nimport cv2\nimport numpy as np\n\n\ndef test_net(args, config):\n    logger = get_logger(args.log_name)\n    print_log('Tester start ... ', logger=logger)\n    _, test_dataloader = builder.dataset_builder(args, config.dataset.test)\n\n    base_model = builder.model_builder(config.model)\n    # base_model.load_model_from_ckpt(args.ckpts)\n    builder.load_model(base_model, args.ckpts, logger=logger)\n\n    if args.use_gpu:\n        base_model.to(args.local_rank)\n\n    #  DDP\n    if args.distributed:\n        raise NotImplementedError()\n\n    test(base_model, test_dataloader, args, config, logger=logger)\n\n\n# visualization\ndef test(base_model, test_dataloader, args, config, logger=None):\n\n    base_model.eval()  # set model to eval mode\n    target = './vis'\n    useful_cate = [\n        \"02691156\",  # plane\n        \"04379243\",  # table\n        \"03790512\",  # motorbike\n        \"03948459\",  # pistol\n        \"03642806\",  # laptop\n        \"03467517\",  # guitar\n        \"03261776\",  # earphone\n        \"03001627\",  # chair\n        \"02958343\",  # car\n        \"04090263\",  # rifle\n        \"03759954\",  # microphone\n    ]\n    with torch.no_grad():\n        for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):\n            # import pdb; pdb.set_trace()\n            if taxonomy_ids[0] not in useful_cate:\n                continue\n            if taxonomy_ids[0] == \"02691156\":\n                a, b = 90, 135\n            elif taxonomy_ids[0] == \"04379243\":\n                a, b = 30, 30\n            elif taxonomy_ids[0] == \"03642806\":\n                a, b = 30, -45\n            elif taxonomy_ids[0] == \"03467517\":\n                a, b = 0, 90\n            elif taxonomy_ids[0] == \"03261776\":\n                a, b = 0, 75\n            elif taxonomy_ids[0] == \"03001627\":\n                a, b = 30, -45\n            else:\n                a, b = 0, 0\n\n            dataset_name = config.dataset.test._base_.NAME\n            if dataset_name == 'ShapeNet':\n                points = data.cuda()\n            else:\n                raise NotImplementedError(\n                    f'Train phase do not support {dataset_name}')\n\n            # dense_points, vis_points = base_model(points, vis=True)\n            dense_points, vis_points, centers = base_model(points, vis=True)\n            final_image = []\n            data_path = f'./vis/{taxonomy_ids[0]}_{idx}'\n            if not os.path.exists(data_path):\n                os.makedirs(data_path)\n\n            points = points.squeeze().detach().cpu().numpy()\n            np.savetxt(os.path.join(data_path, 'gt.txt'),\n                       points, delimiter=';')\n            points = misc.get_ptcloud_img(points, a, b)\n            final_image.append(points[150:650, 150:675, :])\n\n            # centers = centers.squeeze().detach().cpu().numpy()\n            # np.savetxt(os.path.join(data_path,'center.txt'), centers, delimiter=';')\n            # centers = misc.get_ptcloud_img(centers)\n            # final_image.append(centers)\n\n            vis_points = vis_points.squeeze().detach().cpu().numpy()\n            np.savetxt(os.path.join(data_path, 'vis.txt'),\n                       vis_points, delimiter=';')\n            vis_points = misc.get_ptcloud_img(vis_points, a, b)\n\n            
final_image.append(vis_points[150:650, 150:675, :])\n\n            dense_points = dense_points.squeeze().detach().cpu().numpy()\n            np.savetxt(os.path.join(data_path, 'dense_points.txt'),\n                       dense_points, delimiter=';')\n            dense_points = misc.get_ptcloud_img(dense_points, a, b)\n            final_image.append(dense_points[150:650, 150:675, :])\n\n            img = np.concatenate(final_image, axis=1)\n            img_path = os.path.join(data_path, 'plot.jpg')\n            cv2.imwrite(img_path, img)\n\n            if idx > 1500:\n                break\n\n        return\n"
  },
  {
    "path": "tools/runner_finetune.py",
    "content": "import torch\nimport torch.nn as nn\nfrom tools import builder\nfrom utils import misc, dist_utils\nimport time\nfrom utils.logger import *\nfrom utils.AverageMeter import AverageMeter\n\nimport numpy as np\nfrom datasets import data_transforms\nfrom pointnet2_ops import pointnet2_utils\nfrom torchvision import transforms\n\n\ntrain_transforms = transforms.Compose(\n    [\n        # data_transforms.PointcloudScale(),\n        # data_transforms.PointcloudRotate(),\n        # data_transforms.PointcloudTranslate(),\n        # data_transforms.PointcloudJitter(),\n        # data_transforms.PointcloudRandomInputDropout(),\n        # data_transforms.RandomHorizontalFlip(),\n        data_transforms.PointcloudScaleAndTranslate(),\n    ]\n)\n\ntest_transforms = transforms.Compose(\n    [\n        # data_transforms.PointcloudScale(),\n        # data_transforms.PointcloudRotate(),\n        # data_transforms.PointcloudTranslate(),\n        data_transforms.PointcloudScaleAndTranslate(),\n    ]\n)\n\n\nclass Acc_Metric:\n    def __init__(self, acc=0.):\n        if type(acc).__name__ == 'dict':\n            self.acc = acc['acc']\n        elif type(acc).__name__ == 'Acc_Metric':\n            self.acc = acc.acc\n        else:\n            self.acc = acc\n\n    def better_than(self, other):\n        if self.acc > other.acc:\n            return True\n        else:\n            return False\n\n    def state_dict(self):\n        _dict = dict()\n        _dict['acc'] = self.acc\n        return _dict\n\n\ndef run_net(args, config, train_writer=None, val_writer=None):\n    logger = get_logger(args.log_name)\n    # build dataset\n    (train_sampler, train_dataloader), (_, test_dataloader), = builder.dataset_builder(args, config.dataset.train), \\\n        builder.dataset_builder(args, config.dataset.val)\n    # build model\n    base_model = builder.model_builder(config.model)\n\n    # parameter setting\n    start_epoch = 0\n    best_metrics = Acc_Metric(0.)\n    best_metrics_vote = Acc_Metric(0.)\n    metrics = Acc_Metric(0.)\n\n    # resume ckpts\n    if args.resume:\n        start_epoch, best_metric = builder.resume_model(\n            base_model, args, logger=logger)\n        best_metrics = Acc_Metric(best_metrics)\n    else:\n        if args.ckpts is not None:\n            base_model.load_model_from_ckpt(args.ckpts)\n        else:\n            print_log('Training from scratch', logger=logger)\n\n    if args.use_gpu:\n        base_model.to(args.local_rank)\n    # DDP\n    if args.distributed:\n        # Sync BN\n        if args.sync_bn:\n            base_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(\n                base_model)\n            print_log('Using Synchronized BatchNorm ...', logger=logger)\n        base_model = nn.parallel.DistributedDataParallel(\n            base_model, device_ids=[args.local_rank % torch.cuda.device_count()])\n        print_log('Using Distributed Data parallel ...', logger=logger)\n    else:\n        print_log('Using Data parallel ...', logger=logger)\n        base_model = nn.DataParallel(base_model).cuda()\n    # optimizer & scheduler\n    optimizer, scheduler = builder.build_opti_sche(base_model, config)\n\n    if args.resume:\n        builder.resume_optimizer(optimizer, args, logger=logger)\n\n    # trainval\n    # training\n    base_model.zero_grad()\n    for epoch in range(start_epoch, config.max_epoch + 1):\n        if args.distributed:\n            train_sampler.set_epoch(epoch)\n        base_model.train()\n\n        epoch_start_time = time.time()\n   
     batch_start_time = time.time()\n        batch_time = AverageMeter()\n        data_time = AverageMeter()\n        losses = AverageMeter(['loss', 'loss_r', 'acc'])\n        num_iter = 0\n        base_model.train()  # set model to training mode\n        n_batches = len(train_dataloader)\n\n        npoints = config.npoints\n        for idx, (taxonomy_ids, model_ids, data) in enumerate(train_dataloader):\n            num_iter += 1\n            n_itr = epoch * n_batches + idx\n\n            data_time.update(time.time() - batch_start_time)\n\n            points = data[0].cuda()\n            label = data[1].cuda()\n\n            if npoints == 1024:\n                point_all = 1200\n            elif npoints == 2048:\n                point_all = 2400\n            elif npoints == 4096:\n                point_all = 4800\n            elif npoints == 8192:\n                point_all = 8192\n            else:\n                raise NotImplementedError()\n\n            if points.size(1) < point_all:\n                point_all = points.size(1)\n\n            fps_idx = pointnet2_utils.furthest_point_sample(\n                points, point_all)  # (B, npoint)\n            fps_idx = fps_idx[:, np.random.choice(point_all, npoints, False)]\n            points = pointnet2_utils.gather_operation(points.transpose(\n                1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)\n            # import pdb; pdb.set_trace()\n            points = train_transforms(points)\n\n            ret, loss1 = base_model(points)\n\n            loss, acc = base_model.module.get_loss_acc(ret, label)\n\n            _loss = loss + 3 * loss1\n\n            try:\n                _loss.backward()\n            except:\n                _loss = _loss.mean()\n                _loss.backward()\n\n            # forward\n            if num_iter == config.step_per_update:\n                if config.get('grad_norm_clip') is not None:\n                    torch.nn.utils.clip_grad_norm_(\n                        base_model.parameters(), config.grad_norm_clip, norm_type=2)\n                num_iter = 0\n                optimizer.step()\n                base_model.zero_grad()\n\n            if args.distributed:\n                loss = dist_utils.reduce_tensor(loss, args)\n                acc = dist_utils.reduce_tensor(acc, args)\n                losses.update([loss.item(), loss1.item(), acc.item()])\n            else:\n                try:\n                    losses.update([loss.item(), loss1.item(), acc.item()])\n                except:\n                    losses.update([loss.mean().item(), loss1.mean().item(), acc.mean().item()])\n\n            if args.distributed:\n                torch.cuda.synchronize()\n\n            if train_writer is not None:\n                train_writer.add_scalar('Loss/Batch/Loss', loss.item(), n_itr)\n                train_writer.add_scalar(\n                    'Loss/Batch/TrainAcc', acc.item(), n_itr)\n                train_writer.add_scalar(\n                    'Loss/Batch/LR', optimizer.param_groups[0]['lr'], n_itr)\n\n            batch_time.update(time.time() - batch_start_time)\n            batch_start_time = time.time()\n\n            # if idx % 10 == 0:\n            #     print_log('[Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Loss+Acc = %s lr = %.6f' %\n            #                 (epoch, config.max_epoch, idx + 1, n_batches, batch_time.val(), data_time.val(),\n            #                 ['%.4f' % l for l in losses.val()], 
optimizer.param_groups[0]['lr']), logger = logger)\n        if isinstance(scheduler, list):\n            for item in scheduler:\n                item.step(epoch)\n        else:\n            scheduler.step(epoch)\n        epoch_end_time = time.time()\n\n        if train_writer is not None:\n            train_writer.add_scalar('Loss/Epoch/Loss', losses.avg(0), epoch)\n\n        print_log('[Training] EPOCH: %d EpochTime = %.3f (s) Losses = %s lr = %.6f' %\n                  (epoch,  epoch_end_time - epoch_start_time, ['%.4f' % l for l in losses.avg()], optimizer.param_groups[0]['lr']), logger=logger)\n\n        if epoch % args.val_freq == 0 and epoch != 0:\n            # Validate the current model\n            metrics = validate(base_model, test_dataloader,\n                               epoch, val_writer, args, config, logger=logger)\n\n            better = metrics.better_than(best_metrics)\n            # Save ckeckpoints\n            if better:\n                best_metrics = metrics\n                builder.save_checkpoint(\n                    base_model, optimizer, epoch, metrics, best_metrics, 'ckpt-best', args, logger=logger)\n                print_log(\n                    \"--------------------------------------------------------------------------------------------\", logger=logger)\n            if args.vote:\n                if metrics.acc > 92.1 or (better and metrics.acc > 91):\n                    metrics_vote = validate_vote(\n                        base_model, test_dataloader, epoch, val_writer, args, config, logger=logger)\n                    if metrics_vote.better_than(best_metrics_vote):\n                        best_metrics_vote = metrics_vote\n                        print_log(\n                            \"****************************************************************************************\",\n                            logger=logger)\n                        builder.save_checkpoint(\n                            base_model, optimizer, epoch, metrics, best_metrics_vote, 'ckpt-best_vote', args, logger=logger)\n\n        builder.save_checkpoint(base_model, optimizer, epoch,\n                                metrics, best_metrics, 'ckpt-last', args, logger=logger)\n        # if (config.max_epoch - epoch) < 10:\n        #     builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, f'ckpt-epoch-{epoch:03d}', args, logger = logger)\n    if train_writer is not None:\n        train_writer.close()\n    if val_writer is not None:\n        val_writer.close()\n\n\ndef validate(base_model, test_dataloader, epoch, val_writer, args, config, logger=None):\n    # print_log(f\"[VALIDATION] Start validating epoch {epoch}\", logger = logger)\n    base_model.eval()  # set model to eval mode\n\n    test_pred = []\n    test_label = []\n    npoints = config.npoints\n    with torch.no_grad():\n        for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):\n            points = data[0].cuda()\n            label = data[1].cuda()\n\n            points = misc.fps(points, npoints)\n\n            logits, loss1 = base_model(points)\n            target = label.view(-1)\n\n            pred = logits.argmax(-1).view(-1)\n\n            test_pred.append(pred.detach())\n            test_label.append(target.detach())\n\n        test_pred = torch.cat(test_pred, dim=0)\n        test_label = torch.cat(test_label, dim=0)\n\n        if args.distributed:\n            test_pred = dist_utils.gather_tensor(test_pred, args)\n            test_label = 
dist_utils.gather_tensor(test_label, args)\n\n        acc = (test_pred == test_label).sum() / \\\n            float(test_label.size(0)) * 100.\n        print_log('[Validation] EPOCH: %d  acc = %.4f' %\n                  (epoch, acc), logger=logger)\n\n        if args.distributed:\n            torch.cuda.synchronize()\n\n    # Add testing results to TensorBoard\n    if val_writer is not None:\n        val_writer.add_scalar('Metric/ACC', acc, epoch)\n\n    return Acc_Metric(acc)\n\n\ndef validate_vote(base_model, test_dataloader, epoch, val_writer, args, config, logger=None, times=10):\n    print_log(f\"[VALIDATION_VOTE] epoch {epoch}\", logger=logger)\n    base_model.eval()  # set model to eval mode\n\n    test_pred = []\n    test_label = []\n    npoints = config.npoints\n    with torch.no_grad():\n        for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):\n            points_raw = data[0].cuda()\n            label = data[1].cuda()\n            if npoints == 1024:\n                point_all = 1200\n            elif npoints == 4096:\n                point_all = 4800\n            elif npoints == 8192:\n                point_all = 8192\n            else:\n                raise NotImplementedError()\n\n            if points_raw.size(1) < point_all:\n                point_all = points_raw.size(1)\n\n            fps_idx_raw = pointnet2_utils.furthest_point_sample(\n                points_raw, point_all)  # (B, npoint)\n            local_pred = []\n\n            for kk in range(times):\n                fps_idx = fps_idx_raw[:, np.random.choice(\n                    point_all, npoints, False)]\n                points = pointnet2_utils.gather_operation(points_raw.transpose(1, 2).contiguous(),\n                                                          fps_idx).transpose(1, 2).contiguous()  # (B, N, 3)\n\n                points = test_transforms(points)\n\n                logits, loss1 = base_model(points)\n                target = label.view(-1)\n\n                local_pred.append(logits.detach().unsqueeze(0))\n\n            pred = torch.cat(local_pred, dim=0).mean(0)\n            _, pred_choice = torch.max(pred, -1)\n\n            test_pred.append(pred_choice)\n            test_label.append(target.detach())\n\n        test_pred = torch.cat(test_pred, dim=0)\n        test_label = torch.cat(test_label, dim=0)\n\n        if args.distributed:\n            test_pred = dist_utils.gather_tensor(test_pred, args)\n            test_label = dist_utils.gather_tensor(test_label, args)\n\n        acc = (test_pred == test_label).sum() / \\\n            float(test_label.size(0)) * 100.\n        print_log('[Validation_vote] EPOCH: %d  acc_vote = %.4f' %\n                  (epoch, acc), logger=logger)\n\n        if args.distributed:\n            torch.cuda.synchronize()\n\n    # Add testing results to TensorBoard\n    if val_writer is not None:\n        val_writer.add_scalar('Metric/ACC_vote', acc, epoch)\n\n    return Acc_Metric(acc)\n\n\ndef test_net(args, config):\n    logger = get_logger(args.log_name)\n    print_log('Tester start ... 
', logger=logger)\n    _, test_dataloader = builder.dataset_builder(args, config.dataset.test)\n    base_model = builder.model_builder(config.model)\n    # load checkpoints\n    # for finetuned transformer\n    builder.load_model(base_model, args.ckpts, logger=logger)\n    # base_model.load_model_from_ckpt(args.ckpts) # for BERT\n    if args.use_gpu:\n        base_model.to(args.local_rank)\n\n    #  DDP\n    if args.distributed:\n        raise NotImplementedError()\n\n    test(base_model, test_dataloader, args, config, logger=logger)\n\n\ndef test(base_model, test_dataloader, args, config, logger=None):\n\n    base_model.eval()  # set model to eval mode\n\n    test_pred = []\n    test_label = []\n    npoints = config.npoints\n\n    with torch.no_grad():\n        for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):\n            points = data[0].cuda()\n            label = data[1].cuda()\n\n            points = misc.fps(points, npoints)\n\n            logits, loss1 = base_model(points)\n            target = label.view(-1)\n\n            pred = logits.argmax(-1).view(-1)\n\n            test_pred.append(pred.detach())\n            test_label.append(target.detach())\n\n        test_pred = torch.cat(test_pred, dim=0)\n        test_label = torch.cat(test_label, dim=0)\n\n        if args.distributed:\n            test_pred = dist_utils.gather_tensor(test_pred, args)\n            test_label = dist_utils.gather_tensor(test_label, args)\n\n        acc = (test_pred == test_label).sum() / \\\n            float(test_label.size(0)) * 100.\n        print_log('[TEST] acc = %.4f' % acc, logger=logger)\n\n        if args.distributed:\n            torch.cuda.synchronize()\n\n        print_log(f\"[TEST_VOTE]\", logger=logger)\n        acc = 0.\n        for time in range(1, 300):\n            this_acc = test_vote(base_model, test_dataloader,\n                                 1, None, args, config, logger=logger, times=5)\n            if acc < this_acc:\n                acc = this_acc\n            print_log('[TEST_VOTE_time %d]  acc = %.4f, best acc = %.4f' %\n                      (time, this_acc, acc), logger=logger)\n        print_log('[TEST_VOTE] acc = %.4f' % acc, logger=logger)\n\n\ndef test_vote(base_model, test_dataloader, epoch, val_writer, args, config, logger=None, times=10):\n\n    base_model.eval()  # set model to eval mode\n\n    test_pred = []\n    test_label = []\n    npoints = config.npoints\n    with torch.no_grad():\n        for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):\n            points_raw = data[0].cuda()\n            label = data[1].cuda()\n            if npoints == 1024:\n                point_all = 1024\n            elif npoints == 2048:\n                point_all = 2048\n            elif npoints == 4096:\n                point_all = 4096\n            elif npoints == 8192:\n                point_all = 8192\n            else:\n                raise NotImplementedError()\n\n            if points_raw.size(1) < point_all:\n                point_all = points_raw.size(1)\n\n            fps_idx_raw = pointnet2_utils.furthest_point_sample(\n                points_raw, point_all)  # (B, npoint)\n            local_pred = []\n\n            for kk in range(times):\n                fps_idx = fps_idx_raw[:, np.random.choice(\n                    point_all, npoints, False)]\n                points = pointnet2_utils.gather_operation(points_raw.transpose(1, 2).contiguous(),\n                                                          fps_idx).transpose(1, 
2).contiguous()  # (B, N, 3)\n\n                points = test_transforms(points)\n\n                logits, loss1 = base_model(points)\n                target = label.view(-1)\n\n                local_pred.append(logits.detach().unsqueeze(0))\n\n            # softmax = torch.softmax\n\n            pred = torch.cat(local_pred, dim=0).mean(0)\n            # print('pred', pred.shape)\n            # pred = softmax(1000*pred, dim=-1).mean(0)\n            _, pred_choice = torch.max(pred, -1)\n\n            test_pred.append(pred_choice)\n            test_label.append(target.detach())\n\n        test_pred = torch.cat(test_pred, dim=0)\n        test_label = torch.cat(test_label, dim=0)\n\n        if args.distributed:\n            test_pred = dist_utils.gather_tensor(test_pred, args)\n            test_label = dist_utils.gather_tensor(test_label, args)\n\n        acc = (test_pred == test_label).sum() / \\\n            float(test_label.size(0)) * 100.\n\n        if args.distributed:\n            torch.cuda.synchronize()\n\n    # Add testing results to TensorBoard\n    if val_writer is not None:\n        val_writer.add_scalar('Metric/ACC_vote', acc, epoch)\n    # print_log('[TEST] acc = %.4f' % acc, logger=logger)\n\n    return acc\n"
  },
  {
    "path": "tools/runner_pretrain.py",
    "content": "import torch\nimport torch.nn as nn\nimport os\nimport json\nfrom tools import builder\nfrom utils import misc, dist_utils\nimport time\nfrom utils.logger import *\nfrom utils.AverageMeter import AverageMeter\n\nfrom sklearn.svm import LinearSVC\nimport numpy as np\nfrom torchvision import transforms\nfrom datasets import data_transforms\nfrom pointnet2_ops import pointnet2_utils\nfrom torchstat import stat\n\ntrain_transforms = transforms.Compose(\n    [\n        # data_transforms.PointcloudScale(),\n        # data_transforms.PointcloudRotate(),\n        # data_transforms.PointcloudRotatePerturbation(),\n        # data_transforms.PointcloudTranslate(),\n        # data_transforms.PointcloudJitter(),\n        # data_transforms.PointcloudRandomInputDropout(),\n        data_transforms.PointcloudScaleAndTranslate(),\n    ]\n)\n\n\nclass Acc_Metric:\n    def __init__(self, acc=0.):\n        if type(acc).__name__ == 'dict':\n            self.acc = acc['acc']\n        else:\n            self.acc = acc\n\n    def better_than(self, other):\n        if self.acc > other.acc:\n            return True\n        else:\n            return False\n\n    def state_dict(self):\n        _dict = dict()\n        _dict['acc'] = self.acc\n        return _dict\n\n\ndef evaluate_svm(train_features, train_labels, test_features, test_labels):\n    clf = LinearSVC()\n    clf.fit(train_features, train_labels)\n    pred = clf.predict(test_features)\n    return np.sum(test_labels == pred) * 1. / pred.shape[0]\n\n\ndef run_net(args, config, train_writer=None, val_writer=None):\n    logger = get_logger(args.log_name)\n    # build dataset\n    (train_sampler, train_dataloader), (_, test_dataloader), = builder.dataset_builder(args, config.dataset.train), \\\n        builder.dataset_builder(args, config.dataset.val)\n    (_, extra_train_dataloader) = builder.dataset_builder(\n        args, config.dataset.extra_train) if config.dataset.get('extra_train') else (None, None)\n    # build model\n    base_model = builder.model_builder(config.model)\n    if args.use_gpu:\n        base_model.to(args.local_rank)\n\n    # from IPython import embed; embed()\n\n    # parameter setting\n    start_epoch = 0\n    best_metrics = Acc_Metric(0.)\n    metrics = Acc_Metric(0.)\n\n    # resume ckpts\n    if args.resume:\n        start_epoch, best_metric = builder.resume_model(\n            base_model, args, logger=logger)\n        best_metrics = Acc_Metric(best_metric)\n    elif args.start_ckpts is not None:\n        builder.load_model(base_model, args.start_ckpts, logger=logger)\n\n    # DDP\n    if args.distributed:\n        # Sync BN\n        if args.sync_bn:\n            base_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(\n                base_model)\n            print_log('Using Synchronized BatchNorm ...', logger=logger)\n        base_model = nn.parallel.DistributedDataParallel(base_model, device_ids=[\n                                                         args.local_rank % torch.cuda.device_count()], find_unused_parameters=True)\n        print_log('Using Distributed Data parallel ...', logger=logger)\n    else:\n        print_log('Using Data parallel ...', logger=logger)\n        base_model = nn.DataParallel(base_model).cuda()\n    # optimizer & scheduler\n    optimizer, scheduler = builder.build_opti_sche(base_model, config)\n\n    if args.resume:\n        builder.resume_optimizer(optimizer, args, logger=logger)\n\n    # trainval\n    # training\n    base_model.zero_grad()\n    for epoch in range(start_epoch, 
config.max_epoch + 1):\n        if args.distributed:\n            train_sampler.set_epoch(epoch)\n        base_model.train()\n\n        epoch_start_time = time.time()\n        batch_start_time = time.time()\n        batch_time = AverageMeter()\n        data_time = AverageMeter()\n        losses = AverageMeter(['Loss'])\n\n        num_iter = 0\n\n        base_model.train()  # set model to training mode\n        n_batches = len(train_dataloader)\n        for idx, (taxonomy_ids, model_ids, data) in enumerate(train_dataloader):\n            num_iter += 1\n            n_itr = epoch * n_batches + idx\n\n            data_time.update(time.time() - batch_start_time)\n            npoints = config.dataset.train.others.npoints\n            dataset_name = config.dataset.train._base_.NAME\n            if dataset_name == 'ShapeNet' or dataset_name == 'UnlabeledHybrid':\n                points = data.cuda()\n            elif dataset_name == 'ModelNet':\n                points = data[0].cuda()\n                points = misc.fps(points, npoints)\n            else:\n                raise NotImplementedError(\n                    f'Train phase do not support {dataset_name}')\n\n            assert points.size(1) == npoints\n            points = train_transforms(points)\n            loss = base_model(points)\n            try:\n                loss.backward()\n                # print(\"Using one GPU\")\n            except:\n                loss = loss.mean()\n                loss.backward()\n                # print(\"Using multi GPUs\")\n\n            # forward\n            if num_iter == config.step_per_update:\n                num_iter = 0\n                optimizer.step()\n                base_model.zero_grad()\n\n            if args.distributed:\n                loss = dist_utils.reduce_tensor(loss, args)\n                losses.update([loss.item()*1000])\n            else:\n                losses.update([loss.item()*1000])\n\n            if args.distributed:\n                torch.cuda.synchronize()\n\n            if train_writer is not None:\n                train_writer.add_scalar('Loss/Batch/Loss', loss.item(), n_itr)\n                train_writer.add_scalar(\n                    'Loss/Batch/LR', optimizer.param_groups[0]['lr'], n_itr)\n\n            batch_time.update(time.time() - batch_start_time)\n            batch_start_time = time.time()\n\n            if idx % 20 == 0:\n                print_log('[Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) Losses = %s lr = %.6f' %\n                          (epoch, config.max_epoch, idx + 1, n_batches, batch_time.val(), data_time.val(),\n                           ['%.4f' % l for l in losses.val()], optimizer.param_groups[0]['lr']), logger=logger)\n        if isinstance(scheduler, list):\n            for item in scheduler:\n                item.step(epoch)\n        else:\n            scheduler.step(epoch)\n        epoch_end_time = time.time()\n\n        if train_writer is not None:\n            train_writer.add_scalar('Loss/Epoch/Loss_1', losses.avg(0), epoch)\n        print_log('[Training] EPOCH: %d EpochTime = %.3f (s) Losses = %s lr = %.6f' %\n                  (epoch,  epoch_end_time - epoch_start_time, ['%.4f' % l for l in losses.avg()],\n                   optimizer.param_groups[0]['lr']), logger=logger)\n\n        # if epoch % args.val_freq == 0 and epoch != 0:\n        #     # Validate the current model\n        #     metrics = validate(base_model, extra_train_dataloader, test_dataloader, epoch, val_writer, args, config, 
logger=logger)\n        #\n        #     # Save ckeckpoints\n        #     if metrics.better_than(best_metrics):\n        #         best_metrics = metrics\n        #         builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, 'ckpt-best', args, logger = logger)\n        builder.save_checkpoint(base_model, optimizer, epoch,\n                                metrics, best_metrics, 'ckpt-last', args, logger=logger)\n        if epoch % 25 == 0 and epoch >= 250:\n            builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, f'ckpt-epoch-{epoch:03d}', args,\n                                    logger=logger)\n        # if (config.max_epoch - epoch) < 10:\n        #     builder.save_checkpoint(base_model, optimizer, epoch, metrics, best_metrics, f'ckpt-epoch-{epoch:03d}', args, logger = logger)\n    if train_writer is not None:\n        train_writer.close()\n    if val_writer is not None:\n        val_writer.close()\n\n\ndef validate(base_model, extra_train_dataloader, test_dataloader, epoch, val_writer, args, config, logger=None):\n    print_log(f\"[VALIDATION] Start validating epoch {epoch}\", logger=logger)\n    base_model.eval()  # set model to eval mode\n\n    test_features = []\n    test_label = []\n\n    train_features = []\n    train_label = []\n    npoints = config.dataset.train.others.npoints\n    with torch.no_grad():\n        for idx, (taxonomy_ids, model_ids, data) in enumerate(extra_train_dataloader):\n            points = data[0].cuda()\n            label = data[1].cuda()\n\n            points = misc.fps(points, npoints)\n\n            assert points.size(1) == npoints\n            feature = base_model(points, noaug=True)\n            target = label.view(-1)\n\n            train_features.append(feature.detach())\n            train_label.append(target.detach())\n\n        for idx, (taxonomy_ids, model_ids, data) in enumerate(test_dataloader):\n            points = data[0].cuda()\n            label = data[1].cuda()\n\n            points = misc.fps(points, npoints)\n            assert points.size(1) == npoints\n            feature = base_model(points, noaug=True)\n            target = label.view(-1)\n\n            test_features.append(feature.detach())\n            test_label.append(target.detach())\n\n        train_features = torch.cat(train_features, dim=0)\n        train_label = torch.cat(train_label, dim=0)\n        test_features = torch.cat(test_features, dim=0)\n        test_label = torch.cat(test_label, dim=0)\n\n        if args.distributed:\n            train_features = dist_utils.gather_tensor(train_features, args)\n            train_label = dist_utils.gather_tensor(train_label, args)\n            test_features = dist_utils.gather_tensor(test_features, args)\n            test_label = dist_utils.gather_tensor(test_label, args)\n\n        svm_acc = evaluate_svm(train_features.data.cpu().numpy(), train_label.data.cpu(\n        ).numpy(), test_features.data.cpu().numpy(), test_label.data.cpu().numpy())\n\n        print_log('[Validation] EPOCH: %d  acc = %.4f' %\n                  (epoch, svm_acc), logger=logger)\n\n        if args.distributed:\n            torch.cuda.synchronize()\n\n    # Add testing results to TensorBoard\n    if val_writer is not None:\n        val_writer.add_scalar('Metric/ACC', svm_acc, epoch)\n\n    return Acc_Metric(svm_acc)\n\n\ndef test_net():\n    pass\n"
  },
  {
    "path": "utils/AverageMeter.py",
    "content": "\nclass AverageMeter(object):\n    def __init__(self, items=None):\n        self.items = items\n        self.n_items = 1 if items is None else len(items)\n        self.reset()\n\n    def reset(self):\n        self._val = [0] * self.n_items\n        self._sum = [0] * self.n_items\n        self._count = [0] * self.n_items\n\n    def update(self, values):\n        if type(values).__name__ == 'list':\n            for idx, v in enumerate(values):\n                self._val[idx] = v\n                self._sum[idx] += v\n                self._count[idx] += 1\n        else:\n            self._val[0] = values\n            self._sum[0] += values\n            self._count[0] += 1\n\n    def val(self, idx=None):\n        if idx is None:\n            return self._val[0] if self.items is None else [self._val[i] for i in range(self.n_items)]\n        else:\n            return self._val[idx]\n\n    def count(self, idx=None):\n        if idx is None:\n            return self._count[0] if self.items is None else [self._count[i] for i in range(self.n_items)]\n        else:\n            return self._count[idx]\n\n    def avg(self, idx=None):\n        if idx is None:\n            return self._sum[0] / self._count[0] if self.items is None else [\n                self._sum[i] / self._count[i] for i in range(self.n_items)\n            ]\n        else:\n            return self._sum[idx] / self._count[idx]"
  },
  {
    "path": "utils/checkpoint.py",
    "content": "#!/usr/bin/env python3\n# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.\n\nimport copy\nimport logging\nimport os\nfrom collections import defaultdict\nimport torch\nimport torch.nn as nn\n\nfrom typing import Any\nfrom typing import Optional, List, Dict, NamedTuple, Tuple, Iterable\n\nfrom termcolor import colored\n\ndef get_missing_parameters_message(keys: List[str]) -> str:\n    \"\"\"\n    Get a logging-friendly message to report parameter names (keys) that are in\n    the model but not found in a checkpoint.\n    Args:\n        keys (list[str]): List of keys that were not found in the checkpoint.\n    Returns:\n        str: message.\n    \"\"\"\n    groups = _group_checkpoint_keys(keys)\n    msg = \"Some model parameters or buffers are not found in the checkpoint:\\n\"\n    msg += \"\\n\".join(\n        \"  \" + colored(k + _group_to_str(v), \"blue\") for k, v in groups.items()\n    )\n    return msg\n\n\ndef get_unexpected_parameters_message(keys: List[str]) -> str:\n    \"\"\"\n    Get a logging-friendly message to report parameter names (keys) that are in\n    the checkpoint but not found in the model.\n    Args:\n        keys (list[str]): List of keys that were not found in the model.\n    Returns:\n        str: message.\n    \"\"\"\n    groups = _group_checkpoint_keys(keys)\n    msg = \"The checkpoint state_dict contains keys that are not used by the model:\\n\"\n    msg += \"\\n\".join(\n        \"  \" + colored(k + _group_to_str(v), \"magenta\") for k, v in groups.items()\n    )\n    return msg\n\n\ndef _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:\n    \"\"\"\n    Strip the prefix in metadata, if any.\n    Args:\n        state_dict (OrderedDict): a state-dict to be loaded to the model.\n        prefix (str): prefix.\n    \"\"\"\n    keys = sorted(state_dict.keys())\n    if not all(len(key) == 0 or key.startswith(prefix) for key in keys):\n        return\n\n    for key in keys:\n        newkey = key[len(prefix):]\n        state_dict[newkey] = state_dict.pop(key)\n\n    # also strip the prefix in metadata, if any..\n    try:\n        metadata = state_dict._metadata  # pyre-ignore\n    except AttributeError:\n        pass\n    else:\n        for key in list(metadata.keys()):\n            # for the metadata dict, the key can be:\n            # '': for the DDP module, which we want to remove.\n            # 'module': for the actual model.\n            # 'module.xx.xx': for the rest.\n\n            if len(key) == 0:\n                continue\n            newkey = key[len(prefix):]\n            metadata[newkey] = metadata.pop(key)\n\n\ndef _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:\n    \"\"\"\n    Group keys based on common prefixes. A prefix is the string up to the final\n    \".\" in each key.\n    Args:\n        keys (list[str]): list of parameter names, i.e. 
keys in the model\n            checkpoint dict.\n    Returns:\n        dict[list]: keys with common prefixes are grouped into lists.\n    \"\"\"\n    groups = defaultdict(list)\n    for key in keys:\n        pos = key.rfind(\".\")\n        if pos >= 0:\n            head, tail = key[:pos], [key[pos + 1:]]\n        else:\n            head, tail = key, []\n        groups[head].extend(tail)\n    return groups\n\n\ndef _group_to_str(group: List[str]) -> str:\n    \"\"\"\n    Format a group of parameter name suffixes into a loggable string.\n    Args:\n        group (list[str]): list of parameter name suffixes.\n    Returns:\n        str: formated string.\n    \"\"\"\n    if len(group) == 0:\n        return \"\"\n\n    if len(group) == 1:\n        return \".\" + group[0]\n\n    return \".{\" + \", \".join(group) + \"}\"\n\n\ndef _named_modules_with_dup(\n        model: nn.Module, prefix: str = \"\"\n) -> Iterable[Tuple[str, nn.Module]]:\n    \"\"\"\n    The same as `model.named_modules()`, except that it includes\n    duplicated modules that have more than one name.\n    \"\"\"\n    yield prefix, model\n    for name, module in model._modules.items():  # pyre-ignore\n        if module is None:\n            continue\n        submodule_prefix = prefix + (\".\" if prefix else \"\") + name\n        yield from _named_modules_with_dup(module, submodule_prefix)"
  },
  {
    "path": "utils/config.py",
    "content": "import yaml\nfrom easydict import EasyDict\nimport os\nfrom .logger import print_log\n\n\ndef log_args_to_file(args, pre='args', logger=None):\n    for key, val in args.__dict__.items():\n        print_log(f'{pre}.{key} : {val}', logger=logger)\n\n\ndef log_config_to_file(cfg, pre='cfg', logger=None):\n    for key, val in cfg.items():\n        if isinstance(cfg[key], EasyDict):\n            print_log(f'{pre}.{key} = edict()', logger=logger)\n            log_config_to_file(cfg[key], pre=pre + '.' + key, logger=logger)\n            continue\n        print_log(f'{pre}.{key} : {val}', logger=logger)\n\n\ndef merge_new_config(config, new_config):\n    for key, val in new_config.items():\n        if not isinstance(val, dict):\n            if key == '_base_':\n                with open(new_config['_base_'], 'r') as f:\n                    try:\n                        val = yaml.load(f, Loader=yaml.FullLoader)\n                    except:\n                        val = yaml.load(f)\n                config[key] = EasyDict()\n                merge_new_config(config[key], val)\n            else:\n                config[key] = val\n                continue\n        if key not in config:\n            config[key] = EasyDict()\n        merge_new_config(config[key], val)\n    return config\n\n\ndef cfg_from_yaml_file(cfg_file):\n    config = EasyDict()\n    with open(cfg_file, 'r') as f:\n        try:\n            new_config = yaml.load(f, Loader=yaml.FullLoader)\n        except:\n            new_config = yaml.load(f)\n    merge_new_config(config=config, new_config=new_config)\n    return config\n\n\ndef get_config(args, logger=None):\n    if args.resume:\n        cfg_path = os.path.join(args.experiment_path, 'config.yaml')\n        if not os.path.exists(cfg_path):\n            print_log(\"Failed to resume\", logger=logger)\n            raise FileNotFoundError()\n        print_log(f'Resume yaml from {cfg_path}', logger=logger)\n        args.config = cfg_path\n    config = cfg_from_yaml_file(args.config)\n    if not args.resume and args.local_rank == 0:\n        save_experiment_config(args, config, logger)\n    return config\n\n\ndef save_experiment_config(args, config, logger=None):\n    config_path = os.path.join(args.experiment_path, 'config.yaml')\n    os.system('cp %s %s' % (args.config, config_path))\n    print_log(\n        f'Copy the Config file from {args.config} to {config_path}', logger=logger)\n"
  },
  {
    "path": "utils/dist_utils.py",
    "content": "import os\n\nimport torch\nimport torch.multiprocessing as mp\nfrom torch import distributed as dist\n\n\n\ndef init_dist(launcher, backend='nccl', **kwargs):\n    if mp.get_start_method(allow_none=True) is None:\n        mp.set_start_method('spawn')\n    if launcher == 'pytorch':\n        _init_dist_pytorch(backend, **kwargs)\n    else:\n        raise ValueError(f'Invalid launcher type: {launcher}')\n\n\ndef _init_dist_pytorch(backend, **kwargs):\n    # TODO: use local_rank instead of rank % num_gpus\n    rank = int(os.environ['RANK'])\n    num_gpus = torch.cuda.device_count()\n    torch.cuda.set_device(rank % num_gpus)\n    dist.init_process_group(backend=backend, **kwargs)\n    print(f'init distributed in rank {torch.distributed.get_rank()}')\n\n\ndef get_dist_info():\n    if dist.is_available():\n        initialized = dist.is_initialized()\n    else:\n        initialized = False\n    if initialized:\n        rank = dist.get_rank()\n        world_size = dist.get_world_size()\n    else:\n        rank = 0\n        world_size = 1\n    return rank, world_size\n\n\ndef reduce_tensor(tensor, args):\n    '''\n        for acc kind, get the mean in each gpu\n    '''\n    rt = tensor.clone()\n    torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)\n    rt /= args.world_size\n    return rt\n\ndef gather_tensor(tensor, args):\n    output_tensors = [tensor.clone() for _ in range(args.world_size)]\n    torch.distributed.all_gather(output_tensors, tensor)\n    concat = torch.cat(output_tensors, dim=0)\n    return concat\n"
  },
  {
    "path": "utils/logger.py",
    "content": "import logging\nimport torch.distributed as dist\n\nlogger_initialized = {}\n\ndef get_root_logger(log_file=None, log_level=logging.INFO, name='main'):\n    \"\"\"Get root logger and add a keyword filter to it.\n    The logger will be initialized if it has not been initialized. By default a\n    StreamHandler will be added. If `log_file` is specified, a FileHandler will\n    also be added. The name of the root logger is the top-level package name,\n    e.g., \"mmdet3d\".\n    Args:\n        log_file (str, optional): File path of log. Defaults to None.\n        log_level (int, optional): The level of logger.\n            Defaults to logging.INFO.\n        name (str, optional): The name of the root logger, also used as a\n            filter keyword. Defaults to 'mmdet3d'.\n    Returns:\n        :obj:`logging.Logger`: The obtained logger\n    \"\"\"\n    logger = get_logger(name=name, log_file=log_file, log_level=log_level)\n    # add a logging filter\n    logging_filter = logging.Filter(name)\n    logging_filter.filter = lambda record: record.find(name) != -1\n\n    return logger\n\n\ndef get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):\n    \"\"\"Initialize and get a logger by name.\n    If the logger has not been initialized, this method will initialize the\n    logger by adding one or two handlers, otherwise the initialized logger will\n    be directly returned. During initialization, a StreamHandler will always be\n    added. If `log_file` is specified and the process rank is 0, a FileHandler\n    will also be added.\n    Args:\n        name (str): Logger name.\n        log_file (str | None): The log filename. If specified, a FileHandler\n            will be added to the logger.\n        log_level (int): The logger level. Note that only the process of\n            rank 0 is affected, and other processes will set the level to\n            \"Error\" thus be silent most of the time.\n        file_mode (str): The file mode used in opening log file.\n            Defaults to 'w'.\n    Returns:\n        logging.Logger: The expected logger.\n    \"\"\"\n    logger = logging.getLogger(name)\n    if name in logger_initialized:\n        return logger\n    # handle hierarchical names\n    # e.g., logger \"a\" is initialized, then logger \"a.b\" will skip the\n    # initialization since it is a child of \"a\".\n    for logger_name in logger_initialized:\n        if name.startswith(logger_name):\n            return logger\n\n    # handle duplicate logs to the console\n    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)\n    # to the root logger. As logger.propagate is True by default, this root\n    # level handler causes logging messages from rank>0 processes to\n    # unexpectedly show up on the console, creating much unwanted clutter.\n    # To fix this issue, we set the root logger's StreamHandler, if any, to log\n    # at the ERROR level.\n    for handler in logger.root.handlers:\n        if type(handler) is logging.StreamHandler:\n            handler.setLevel(logging.ERROR)\n\n    stream_handler = logging.StreamHandler()\n    handlers = [stream_handler]\n\n    if dist.is_available() and dist.is_initialized():\n        rank = dist.get_rank()\n    else:\n        rank = 0\n\n    # only rank 0 will add a FileHandler\n    if rank == 0 and log_file is not None:\n        # Here, the default behaviour of the official logger is 'a'. 
Thus, we\n        # provide an interface to change the file mode to the default\n        # behaviour.\n        file_handler = logging.FileHandler(log_file, file_mode)\n        handlers.append(file_handler)\n\n    formatter = logging.Formatter(\n        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n    for handler in handlers:\n        handler.setFormatter(formatter)\n        handler.setLevel(log_level)\n        logger.addHandler(handler)\n\n    if rank == 0:\n        logger.setLevel(log_level)\n    else:\n        logger.setLevel(logging.ERROR)\n\n    logger_initialized[name] = True\n\n\n    return logger\n\n\ndef print_log(msg, logger=None, level=logging.INFO):\n    \"\"\"Print a log message.\n    Args:\n        msg (str): The message to be logged.\n        logger (logging.Logger | str | None): The logger to be used.\n            Some special loggers are:\n            - \"silent\": no message will be printed.\n            - other str: the logger obtained with `get_root_logger(logger)`.\n            - None: The `print()` method will be used to print log messages.\n        level (int): Logging level. Only available when `logger` is a Logger\n            object or \"root\".\n    \"\"\"\n    if logger is None:\n        print(msg)\n    elif isinstance(logger, logging.Logger):\n        logger.log(level, msg)\n    elif logger == 'silent':\n        pass\n    elif isinstance(logger, str):\n        _logger = get_logger(logger)\n        _logger.log(level, msg)\n    else:\n        raise TypeError(\n            'logger should be either a logging.Logger object, str, '\n            f'\"silent\" or None, but got {type(logger)}')"
  },
  {
    "path": "utils/misc.py",
    "content": "import numpy as np\nimport matplotlib.pyplot as plt\nfrom mpl_toolkits.mplot3d import Axes3D\nimport random\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport os\nfrom collections import abc\nfrom pointnet2_ops import pointnet2_utils\n\n\ndef fps(data, number):\n    '''\n        data B N 3\n        number int\n    '''\n    fps_idx = pointnet2_utils.furthest_point_sample(data, number) \n    fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()\n    return fps_data\n\n\ndef worker_init_fn(worker_id):\n    np.random.seed(np.random.get_state()[1][0] + worker_id)\n\ndef build_lambda_sche(opti, config):\n    if config.get('decay_step') is not None:\n        lr_lbmd = lambda e: max(config.lr_decay ** (e / config.decay_step), config.lowest_decay)\n        scheduler = torch.optim.lr_scheduler.LambdaLR(opti, lr_lbmd)\n    else:\n        raise NotImplementedError()\n    return scheduler\n\ndef build_lambda_bnsche(model, config):\n    if config.get('decay_step') is not None:\n        bnm_lmbd = lambda e: max(config.bn_momentum * config.bn_decay ** (e / config.decay_step), config.lowest_decay)\n        bnm_scheduler = BNMomentumScheduler(model, bnm_lmbd)\n    else:\n        raise NotImplementedError()\n    return bnm_scheduler\n    \ndef set_random_seed(seed, deterministic=False):\n    \"\"\"Set random seed.\n    Args:\n        seed (int): Seed to be used.\n        deterministic (bool): Whether to set the deterministic option for\n            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`\n            to True and `torch.backends.cudnn.benchmark` to False.\n            Default: False.\n\n    # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html\n    if cuda_deterministic:  # slower, more reproducible\n        cudnn.deterministic = True\n        cudnn.benchmark = False\n    else:  # faster, less reproducible\n        cudnn.deterministic = False\n        cudnn.benchmark = True\n\n    \"\"\"\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    if deterministic:\n        torch.backends.cudnn.deterministic = True\n        torch.backends.cudnn.benchmark = False\n\n\ndef is_seq_of(seq, expected_type, seq_type=None):\n    \"\"\"Check whether it is a sequence of some type.\n    Args:\n        seq (Sequence): The sequence to be checked.\n        expected_type (type): Expected type of sequence items.\n        seq_type (type, optional): Expected sequence type.\n    Returns:\n        bool: Whether the sequence is valid.\n    \"\"\"\n    if seq_type is None:\n        exp_seq_type = abc.Sequence\n    else:\n        assert isinstance(seq_type, type)\n        exp_seq_type = seq_type\n    if not isinstance(seq, exp_seq_type):\n        return False\n    for item in seq:\n        if not isinstance(item, expected_type):\n            return False\n    return True\n\n\ndef set_bn_momentum_default(bn_momentum):\n    def fn(m):\n        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):\n            m.momentum = bn_momentum\n    return fn\n\nclass BNMomentumScheduler(object):\n\n    def __init__(\n            self, model, bn_lambda, last_epoch=-1,\n            setter=set_bn_momentum_default\n    ):\n        if not isinstance(model, nn.Module):\n            raise RuntimeError(\n                \"Class '{}' is not a PyTorch nn Module\".format(\n                    type(model).__name__\n           
     )\n            )\n\n        self.model = model\n        self.setter = setter\n        self.lmbd = bn_lambda\n\n        self.step(last_epoch + 1)\n        self.last_epoch = last_epoch\n\n    def step(self, epoch=None):\n        if epoch is None:\n            epoch = self.last_epoch + 1\n\n        self.last_epoch = epoch\n        self.model.apply(self.setter(self.lmbd(epoch)))\n\n    def get_momentum(self, epoch=None):\n        if epoch is None:\n            epoch = self.last_epoch + 1\n        return self.lmbd(epoch)\n\n\n\ndef seprate_point_cloud(xyz, num_points, crop, fixed_points = None, padding_zeros = False):\n    '''\n     seprate point cloud: usage : using to generate the incomplete point cloud with a setted number.\n    '''\n    _,n,c = xyz.shape\n\n    assert n == num_points\n    assert c == 3\n    if crop == num_points:\n        return xyz, None\n        \n    INPUT = []\n    CROP = []\n    for points in xyz:\n        if isinstance(crop,list):\n            num_crop = random.randint(crop[0],crop[1])\n        else:\n            num_crop = crop\n\n        points = points.unsqueeze(0)\n\n        if fixed_points is None:       \n            center = F.normalize(torch.randn(1,1,3),p=2,dim=-1).cuda()\n        else:\n            if isinstance(fixed_points,list):\n                fixed_point = random.sample(fixed_points,1)[0]\n            else:\n                fixed_point = fixed_points\n            center = fixed_point.reshape(1,1,3).cuda()\n\n        distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1)  # 1 1 2048\n\n        idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048\n\n        if padding_zeros:\n            input_data = points.clone()\n            input_data[0, idx[:num_crop]] =  input_data[0,idx[:num_crop]] * 0\n\n        else:\n            input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3\n\n        crop_data =  points.clone()[0, idx[:num_crop]].unsqueeze(0)\n\n        if isinstance(crop,list):\n            INPUT.append(fps(input_data,2048))\n            CROP.append(fps(crop_data,2048))\n        else:\n            INPUT.append(input_data)\n            CROP.append(crop_data)\n\n    input_data = torch.cat(INPUT,dim=0)# B N 3\n    crop_data = torch.cat(CROP,dim=0)# B M 3\n\n    return input_data.contiguous(), crop_data.contiguous()\n\ndef get_ptcloud_img(ptcloud,roll,pitch):\n    fig = plt.figure(figsize=(8, 8))\n\n    x, z, y = ptcloud.transpose(1, 0)\n    ax = fig.gca(projection=Axes3D.name, adjustable='box')\n    ax.axis('off')\n    # ax.axis('scaled')\n    ax.view_init(roll,pitch)\n    max, min = np.max(ptcloud), np.min(ptcloud)\n    ax.set_xbound(min, max)\n    ax.set_ybound(min, max)\n    ax.set_zbound(min, max)\n    ax.scatter(x, y, z, zdir='z', c=y, cmap='jet')\n\n    fig.canvas.draw()\n    img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')\n    img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, ))\n    return img\n\n\n\ndef visualize_KITTI(path, data_list, titles = ['input','pred'], cmap=['bwr','autumn'], zdir='y', \n                         xlim=(-1, 1), ylim=(-1, 1), zlim=(-1, 1) ):\n    fig = plt.figure(figsize=(6*len(data_list),6))\n    cmax = data_list[-1][:,0].max()\n\n    for i in range(len(data_list)):\n        data = data_list[i][:-2048] if i == 1 else data_list[i]\n        color = data[:,0] /cmax\n        ax = fig.add_subplot(1, len(data_list) , i + 1, projection='3d')\n        ax.view_init(30, -120)\n        b = ax.scatter(data[:, 0], data[:, 1], 
data[:, 2], zdir=zdir, c=color, vmin=-1, vmax=1, cmap=cmap[0], s=4, linewidth=0.05, edgecolors='black')\n        ax.set_title(titles[i])\n\n        ax.set_axis_off()\n        ax.set_xlim(xlim)\n        ax.set_ylim(ylim)\n        ax.set_zlim(zlim)\n    plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0.2, hspace=0)\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n    pic_path = path + '.png'\n    fig.savefig(pic_path)\n\n    np.save(os.path.join(path, 'input.npy'), data_list[0].numpy())\n    np.save(os.path.join(path, 'pred.npy'), data_list[1].numpy())\n    plt.close(fig)\n\n\ndef random_dropping(pc, e):\n    # drop points more aggressively in early epochs, then zero-pad back to 2048\n    up_num = max(64, 768 // (e // 50 + 1))\n    random_num = torch.randint(1, up_num, (1, 1))[0, 0]\n    pc = fps(pc, random_num)\n    padding = torch.zeros(pc.size(0), 2048 - pc.size(1), 3).to(pc.device)\n    pc = torch.cat([pc, padding], dim=1)\n    return pc\n\n\ndef random_scale(partial, scale_range=[0.8, 1.2]):\n    scale = torch.rand(1).cuda() * (scale_range[1] - scale_range[0]) + scale_range[0]\n    return partial * scale\n
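\n\nif __name__ == '__main__':\n    # Minimal usage sketch of the CPU-only helpers above (an illustrative\n    # addition, not part of the original training pipeline). fps() and\n    # seprate_point_cloud() need the CUDA pointnet2_ops extension, so they\n    # are not exercised here.\n    set_random_seed(0, deterministic=True)\n    assert is_seq_of([1, 2, 3], int) and not is_seq_of([1, 'a'], int)\n    model = nn.Sequential(nn.Linear(3, 8), nn.BatchNorm1d(8))\n    # BN momentum decays with the epoch index, floored at 0.01\n    bnm_sched = BNMomentumScheduler(model, lambda e: max(0.9 * 0.5 ** (e / 10), 0.01))\n    bnm_sched.step(5)\n    print('BN momentum at epoch 5:', bnm_sched.get_momentum(5))\n"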
  },
  {
    "path": "utils/parser.py",
    "content": "import os\nimport argparse\nfrom pathlib import Path\n\ndef get_args():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        '--config', \n        type = str, \n        help = 'yaml config file')\n    parser.add_argument(\n        '--launcher',\n        choices=['none', 'pytorch'],\n        default='none',\n        help='job launcher')     \n    parser.add_argument('--local_rank', type=int, default=0)\n    parser.add_argument('--num_workers', type=int, default=8)\n    # seed \n    parser.add_argument('--seed', type=int, default=0, help='random seed')\n    parser.add_argument(\n        '--deterministic',\n        action='store_true',\n        help='whether to set deterministic options for CUDNN backend.')      \n    # bn\n    parser.add_argument(\n        '--sync_bn', \n        action='store_true', \n        default=False, \n        help='whether to use sync bn')\n    # some args\n    parser.add_argument('--exp_name', type = str, default='default', help = 'experiment name')\n    parser.add_argument('--loss', type=str, default='cd1', help='loss name')\n    parser.add_argument('--start_ckpts', type = str, default=None, help = 'reload used ckpt path')\n    parser.add_argument('--ckpts', type = str, default=None, help = 'test used ckpt path')\n    parser.add_argument('--val_freq', type = int, default=1, help = 'test freq')\n    parser.add_argument(\n        '--vote',\n        action='store_true',\n        default=False,\n        help = 'vote acc')\n    parser.add_argument(\n        '--resume', \n        action='store_true', \n        default=False, \n        help = 'autoresume training (interrupted by accident)')\n    parser.add_argument(\n        '--test', \n        action='store_true', \n        default=False, \n        help = 'test mode for certain ckpt')\n    parser.add_argument(\n        '--finetune_model', \n        action='store_true', \n        default=False, \n        help = 'finetune modelnet with pretrained weight')\n    parser.add_argument(\n        '--scratch_model', \n        action='store_true', \n        default=False, \n        help = 'training modelnet from scratch')\n    parser.add_argument(\n        '--mode', \n        choices=['easy', 'median', 'hard', None],\n        default=None,\n        help = 'difficulty mode for shapenet')        \n    parser.add_argument(\n        '--way', type=int, default=-1)\n    parser.add_argument(\n        '--shot', type=int, default=-1)\n    parser.add_argument(\n        '--fold', type=int, default=-1)\n    \n    args = parser.parse_args()\n\n    if args.test and args.resume:\n        raise ValueError(\n            '--test and --resume cannot be both activate')\n\n    if args.resume and args.start_ckpts is not None:\n        raise ValueError(\n            '--resume and --start_ckpts cannot be both activate')\n\n    if args.test and args.ckpts is None:\n        raise ValueError(\n            'ckpts shouldnt be None while test mode')\n\n    if args.finetune_model and args.ckpts is None:\n        print(\n            'training from scratch')\n\n    if 'LOCAL_RANK' not in os.environ:\n        os.environ['LOCAL_RANK'] = str(args.local_rank)\n\n    if args.test:\n        args.exp_name = 'test_' + args.exp_name\n    if args.mode is not None:\n        args.exp_name = args.exp_name + '_' +args.mode\n    args.experiment_path = os.path.join('./experiments', Path(args.config).stem, Path(args.config).parent.stem, args.exp_name)\n    args.tfboard_path = os.path.join('./experiments', Path(args.config).stem, 
Path(args.config).parent.stem, 'TFBoard', args.exp_name)\n    args.log_name = Path(args.config).stem\n    create_experiment_dir(args)\n    return args\n\ndef create_experiment_dir(args):\n    if not os.path.exists(args.experiment_path):\n        os.makedirs(args.experiment_path)\n        print('Created experiment path at %s' % args.experiment_path)\n    if not os.path.exists(args.tfboard_path):\n        os.makedirs(args.tfboard_path)\n        print('Created TFBoard path at %s' % args.tfboard_path)\n
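\n\nif __name__ == '__main__':\n    # Quick smoke test for the parser (an illustrative addition; the config\n    # path below is hypothetical). Note that get_args() creates the\n    # experiment directories as a side effect, e.g.:\n    #   python utils/parser.py --config cfgs/pretrain.yaml --exp_name demo\n    args = get_args()\n    print('experiment path:', args.experiment_path)\n    print('tfboard path:', args.tfboard_path)\n"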
  },
  {
    "path": "utils/registry.py",
    "content": "import inspect\nimport warnings\nfrom functools import partial\nfrom utils import config\n\nclass Registry:\n    \"\"\"A registry to map strings to classes.\n    Registered object could be built from registry.\n    Example:\n        >>> MODELS = Registry('models')\n        >>> @MODELS.register_module()\n        >>> class ResNet:\n        >>>     pass\n        >>> resnet = MODELS.build(dict(NAME='ResNet'))\n    Please refer to https://mmcv.readthedocs.io/en/latest/registry.html for\n    advanced useage.\n    Args:\n        name (str): Registry name.\n        build_func(func, optional): Build function to construct instance from\n            Registry, func:`build_from_cfg` is used if neither ``parent`` or\n            ``build_func`` is specified. If ``parent`` is specified and\n            ``build_func`` is not given,  ``build_func`` will be inherited\n            from ``parent``. Default: None.\n        parent (Registry, optional): Parent registry. The class registered in\n            children registry could be built from parent. Default: None.\n        scope (str, optional): The scope of registry. It is the key to search\n            for children registry. If not specified, scope will be the name of\n            the package where class is defined, e.g. mmdet, mmcls, mmseg.\n            Default: None.\n    \"\"\"\n\n    def __init__(self, name, build_func=None, parent=None, scope=None):\n        self._name = name\n        self._module_dict = dict()\n        self._children = dict()\n        self._scope = self.infer_scope() if scope is None else scope\n\n        # self.build_func will be set with the following priority:\n        # 1. build_func\n        # 2. parent.build_func\n        # 3. build_from_cfg\n        if build_func is None:\n            if parent is not None:\n                self.build_func = parent.build_func\n            else:\n                self.build_func = build_from_cfg\n        else:\n            self.build_func = build_func\n        if parent is not None:\n            assert isinstance(parent, Registry)\n            parent._add_children(self)\n            self.parent = parent\n        else:\n            self.parent = None\n\n    def __len__(self):\n        return len(self._module_dict)\n\n    def __contains__(self, key):\n        return self.get(key) is not None\n\n    def __repr__(self):\n        format_str = self.__class__.__name__ + \\\n                     f'(name={self._name}, ' \\\n                     f'items={self._module_dict})'\n        return format_str\n\n    @staticmethod\n    def infer_scope():\n        \"\"\"Infer the scope of registry.\n        The name of the package where registry is defined will be returned.\n        Example:\n            # in mmdet/models/backbone/resnet.py\n            >>> MODELS = Registry('models')\n            >>> @MODELS.register_module()\n            >>> class ResNet:\n            >>>     pass\n            The scope of ``ResNet`` will be ``mmdet``.\n        Returns:\n            scope (str): The inferred scope name.\n        \"\"\"\n        # inspect.stack() trace where this function is called, the index-2\n        # indicates the frame where `infer_scope()` is called\n        filename = inspect.getmodule(inspect.stack()[2][0]).__name__\n        split_filename = filename.split('.')\n        return split_filename[0]\n\n    @staticmethod\n    def split_scope_key(key):\n        \"\"\"Split scope and key.\n        The first scope will be split from key.\n        Examples:\n            >>> 
Registry.split_scope_key('mmdet.ResNet')\n            ('mmdet', 'ResNet')\n            >>> Registry.split_scope_key('ResNet')\n            (None, 'ResNet')\n        Returns:\n            scope (str | None): The first scope.\n            key (str): The remaining key.\n        \"\"\"\n        split_index = key.find('.')\n        if split_index != -1:\n            return key[:split_index], key[split_index + 1:]\n        else:\n            return None, key\n\n    @property\n    def name(self):\n        return self._name\n\n    @property\n    def scope(self):\n        return self._scope\n\n    @property\n    def module_dict(self):\n        return self._module_dict\n\n    @property\n    def children(self):\n        return self._children\n\n    def get(self, key):\n        \"\"\"Get the registry record.\n        Args:\n            key (str): The class name in string format.\n        Returns:\n            class: The corresponding class, or None if not found.\n        \"\"\"\n        scope, real_key = self.split_scope_key(key)\n        if scope is None or scope == self._scope:\n            # get from self\n            if real_key in self._module_dict:\n                return self._module_dict[real_key]\n        else:\n            # get from self._children\n            if scope in self._children:\n                return self._children[scope].get(real_key)\n            else:\n                # go to the root registry and search from there\n                parent = self.parent\n                while parent.parent is not None:\n                    parent = parent.parent\n                return parent.get(key)\n\n    def build(self, *args, **kwargs):\n        return self.build_func(*args, **kwargs, registry=self)\n\n    def _add_children(self, registry):\n        \"\"\"Add children for a registry.\n        The ``registry`` will be added as children based on its scope.\n        The parent registry could build objects from children registry.\n        Example:\n            >>> models = Registry('models')\n            >>> mmdet_models = Registry('models', parent=models)\n            >>> @mmdet_models.register_module()\n            >>> class ResNet:\n            >>>     pass\n            >>> resnet = models.build(dict(NAME='mmdet.ResNet'))\n        \"\"\"\n\n        assert isinstance(registry, Registry)\n        assert registry.scope is not None\n        assert registry.scope not in self.children, \\\n            f'scope {registry.scope} exists in {self.name} registry'\n        self.children[registry.scope] = registry\n\n    def _register_module(self, module_class, module_name=None, force=False):\n        if not inspect.isclass(module_class):\n            raise TypeError('module must be a class, '\n                            f'but got {type(module_class)}')\n\n        if module_name is None:\n            module_name = module_class.__name__\n        if isinstance(module_name, str):\n            module_name = [module_name]\n        for name in module_name:\n            if not force and name in self._module_dict:\n                raise KeyError(f'{name} is already registered '\n                               f'in {self.name}')\n            self._module_dict[name] = module_class\n\n    def deprecated_register_module(self, cls=None, force=False):\n        warnings.warn(\n            'The old API of register_module(module, force=False) '\n            'is deprecated and will be removed, please use the new API '\n            'register_module(name=None, force=False, module=None) instead.')\n        if cls is None:\n            return partial(self.deprecated_register_module, 
force=force)\n        self._register_module(cls, force=force)\n        return cls\n\n    def register_module(self, name=None, force=False, module=None):\n        \"\"\"Register a module.\n        A record will be added to `self._module_dict`, whose key is the class\n        name or the specified name, and value is the class itself.\n        It can be used as a decorator or a normal function.\n        Example:\n            >>> backbones = Registry('backbone')\n            >>> @backbones.register_module()\n            >>> class ResNet:\n            >>>     pass\n            >>> backbones = Registry('backbone')\n            >>> @backbones.register_module(name='mnet')\n            >>> class MobileNet:\n            >>>     pass\n            >>> backbones = Registry('backbone')\n            >>> class ResNet:\n            >>>     pass\n            >>> backbones.register_module(ResNet)\n        Args:\n            name (str | None): The module name to be registered. If not\n                specified, the class name will be used.\n            force (bool, optional): Whether to override an existing class with\n                the same name. Default: False.\n            module (type): Module class to be registered.\n        \"\"\"\n        if not isinstance(force, bool):\n            raise TypeError(f'force must be a boolean, but got {type(force)}')\n        # NOTE: This is a workaround to stay compatible with the old API,\n        # but it may introduce unexpected bugs.\n        if isinstance(name, type):\n            return self.deprecated_register_module(name, force=force)\n\n        # raise the error ahead of time\n        if not (name is None or isinstance(name, str) or misc.is_seq_of(name, str)):\n            raise TypeError(\n                'name must be None, an instance of str, or a sequence '\n                f'of str, but got {type(name)}')\n\n        # use it as a normal method: x.register_module(module=SomeClass)\n        if module is not None:\n            self._register_module(\n                module_class=module, module_name=name, force=force)\n            return module\n\n        # use it as a decorator: @x.register_module()\n        def _register(cls):\n            self._register_module(\n                module_class=cls, module_name=name, force=force)\n            return cls\n\n        return _register\n\n\ndef build_from_cfg(cfg, registry, default_args=None):\n    \"\"\"Build a module from config dict.\n    Args:\n        cfg (edict): Config dict. 
It should at least contain the key \"NAME\".\n        registry (:obj:`Registry`): The registry to search the type from.\n        default_args (dict, optional): Default arguments merged into ``cfg``.\n    Returns:\n        object: The constructed object.\n    \"\"\"\n    if not isinstance(cfg, dict):\n        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')\n    if 'NAME' not in cfg:\n        if default_args is None or 'NAME' not in default_args:\n            raise KeyError(\n                '`cfg` or `default_args` must contain the key \"NAME\", '\n                f'but got {cfg}\\n{default_args}')\n    if not isinstance(registry, Registry):\n        raise TypeError('registry must be a Registry object, '\n                        f'but got {type(registry)}')\n\n    if not (isinstance(default_args, dict) or default_args is None):\n        raise TypeError('default_args must be a dict or None, '\n                        f'but got {type(default_args)}')\n\n    if default_args is not None:\n        cfg = config.merge_new_config(cfg, default_args)\n\n    obj_type = cfg.get('NAME')\n\n    if isinstance(obj_type, str):\n        obj_cls = registry.get(obj_type)\n        if obj_cls is None:\n            raise KeyError(\n                f'{obj_type} is not in the {registry.name} registry')\n    elif inspect.isclass(obj_type):\n        obj_cls = obj_type\n    else:\n        raise TypeError(\n            f'type must be a str or valid type, but got {type(obj_type)}')\n    try:\n        return obj_cls(cfg)\n    except Exception as e:\n        # Re-raise with the class name prepended; a bare error would not\n        # say which constructor failed.\n        raise type(e)(f'{obj_cls.__name__}: {e}')\n
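\n\nif __name__ == '__main__':\n    # Minimal usage sketch (an illustrative addition; 'DummyModel' is a\n    # hypothetical class, not part of Point-MAE). Run from the repo root,\n    # e.g. `python -m utils.registry`, so the `utils` package resolves.\n    MODELS = Registry('models')\n\n    @MODELS.register_module()\n    class DummyModel:\n        def __init__(self, cfg):\n            self.cfg = cfg\n\n    # build() dispatches to build_from_cfg(), which looks up 'NAME'\n    model = MODELS.build({'NAME': 'DummyModel'})\n    print(type(model).__name__, 'built from', model.cfg)\n"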
  }
]