Repository: princeton-computational-imaging/Diffusion-SDF
Branch: main
Commit: 7eb2d6786b88
Files: 81
Total size: 38.4 MB
Directory structure:
gitextract_hv9jwie8/
├── .gitignore
├── LICENSE
├── README.md
├── config/
│ ├── stage1_sdf/
│ │ └── specs.json
│ ├── stage2_diff_cond/
│ │ └── specs.json
│ ├── stage2_diff_uncond/
│ │ └── specs.json
│ ├── stage3_cond/
│ │ └── specs.json
│ └── stage3_uncond/
│ └── specs.json
├── data/
│ ├── acronym/
│ │ └── Couch/
│ │ └── 37cfcafe606611d81246538126da07a8/
│ │ └── sdf_data.csv
│ ├── grid_data/
│ │ └── acronym/
│ │ └── Couch/
│ │ └── 37cfcafe606611d81246538126da07a8/
│ │ └── grid_gt.csv
│ └── splits/
│ └── couch_all.json
├── dataloader/
│ ├── .base.py.swp
│ ├── __init__.py
│ ├── base.py
│ ├── modulation_loader.py
│ ├── pc_loader.py
│ └── sdf_loader.py
├── diff_utils/
│ ├── helpers.py
│ ├── model_utils.py
│ ├── pointnet/
│ │ ├── __init__.py
│ │ ├── conv_pointnet.py
│ │ ├── dgcnn.py
│ │ ├── dgcnn_bn32.pt
│ │ ├── pointnet_base.py
│ │ ├── pointnet_classifier.py
│ │ └── transformer.py
│ └── sdf_utils.py
├── environment.yml
├── metrics/
│ ├── StructuralLosses/
│ │ ├── __init__.py
│ │ ├── match_cost.py
│ │ └── nn_distance.py
│ ├── __init__.py
│ ├── evaluation_metrics.py
│ └── pytorch_structural_losses/
│ ├── .gitignore
│ ├── Makefile
│ ├── StructuralLosses/
│ │ ├── __init__.py
│ │ ├── match_cost.py
│ │ └── nn_distance.py
│ ├── __init__.py
│ ├── match_cost.py
│ ├── nn_distance.py
│ ├── pybind/
│ │ ├── bind.cpp
│ │ └── extern.hpp
│ ├── setup.py
│ └── src/
│ ├── approxmatch.cu
│ ├── approxmatch.cuh
│ ├── nndistance.cu
│ ├── nndistance.cuh
│ ├── structural_loss.cpp
│ └── utils.hpp
├── models/
│ ├── __init__.py
│ ├── archs/
│ │ ├── __init__.py
│ │ ├── diffusion_arch.py
│ │ ├── encoders/
│ │ │ ├── __init__.py
│ │ │ ├── auto_decoder.py
│ │ │ ├── conv_pointnet.py
│ │ │ ├── dgcnn.py
│ │ │ ├── rbf.py
│ │ │ ├── sal_pointnet.py
│ │ │ └── vanilla_pointnet.py
│ │ ├── modulated_sdf.py
│ │ ├── resnet_block.py
│ │ ├── sdf_decoder.py
│ │ └── unet.py
│ ├── autoencoder.py
│ ├── combined_model.py
│ ├── diff_np_if_torch_error.py
│ ├── diffusion.py
│ └── sdf_model.py
├── test.py
├── train.py
└── utils/
├── __init__.py
├── chamfer.py
├── evaluate.py
├── mesh.py
├── mmd.py
├── reconstruct.py
├── renderer.py
├── tmd.py
├── uhd.py
└── visualize.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.DS_Store
__pycache__/
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2022 Gene Chou
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Diffusion-SDF: Conditional Generative Modeling of Signed Distance Functions
[**Paper**](https://arxiv.org/abs/2211.13757) | [**Supplement**](https://light.princeton.edu/wp-content/uploads/2023/03/diffusionsdf_supp.pdf) | [**Project Page**](https://light.princeton.edu/publication/diffusion-sdf/)
This repository contains the official implementation of
**[ICCV 2023] Diffusion-SDF: Conditional Generative Modeling of Signed Distance Functions**
[Gene Chou](https://genechou.com), [Yuval Bahat](https://sites.google.com/view/yuval-bahat/home), [Felix Heide](https://www.cs.princeton.edu/~fheide/)
If you find our code or paper useful, please consider citing
```bibtex
@inproceedings{chou2022diffusionsdf,
title={Diffusion-SDF: Conditional Generative Modeling of Signed Distance Functions},
author={Gene Chou and Yuval Bahat and Felix Heide},
booktitle={The IEEE International Conference on Computer Vision (ICCV)},
year={2023}
}
```
```cpp
root directory
├── config
│ └── // folders for checkpoints and training configs
├── data
│ └── // folders for data (in csv format) and train test splits (json)
├── models
│ ├── // models and lightning modules; main model is 'combined_model.py'
│ └── archs
│ └── // architectures such as PointNets, SDF MLPs, diffusion network..etc
├── dataloader
│ └── // dataloaders for different stages of training and generation
├── utils
│ └── // reconstruction and evaluation
├── metrics
│ └── // evaluation metrics and structural losses (CD, EMD)
├── diff_utils
│ └── // helper functions for diffusion
├── environment.yml // package requirements
├── train.py // script for training, specify the stage of training in the config files
├── test.py // script for testing, specify the stage of testing in the config files
└── tensorboard_logs // created when running any training script
```
## Installation
We recommend creating an [anaconda](https://www.anaconda.com/) environment using our provided `environment.yml`:
```
conda env create -f environment.yml
conda activate diffusionsdf
```
## Dataset
For training, we preprocess all meshes and store query coordinates and signed distance values in csv files. Each csv file corresponds to one object, and each line represents a coordinate followed by its signed distance value. See `data/acronym` for examples. Modify the dataloader according to your file format.
When sampling query points, make sure to also **sample uniformly within the 3D grid space** (i.e. from (-1,-1,-1) to (1,1,1)) rather than only sampling near the surface to avoid artifacts. For each training batch, we take 70% of query points sampled near the object surface and 30% sampled uniformly in the grid. `grid_source` in our dataloader and config file refers to the latter.
## Training
As described in our [paper](https://arxiv.org/abs/2211.13757), there are three stages of training. All corresponding config files can be found in the `config` folders. Logs are created in a `tensorboard_logs` folder in the root directory. We recommend tuning the `"kld_weight"` when training the joint SDF-VAE model as it enforces the continuity of the latent space. A higher value (e.g. 0.1) will result in better interpolation and generalization but sometimes more artifacts. A lower value (e.g. 0.00001) will result in worse interpolation but higher quality of generations.
1. Training SDF modulations
```
python train.py -e config/stage1_sdf/ -b 32 -w 8 # -b for batch size, -w for workers, -r to resume training
```
Training notes: For Acronym / ShapeNet datasets, the loss should go down to $6 \sim 8 \times 10^{-4}$. Run testing to visualize whether the quality of reconstructed shapes is sufficient. The quality of reconstructions will carry over to the quality of generations. Note that the dimension of the VAE latent vectors will be 3 times `"latent_dim"` in `"SdfModelSpecs"` listed in the config file.
2. Training the diffusion model using the modulations extracted from the first stage
```
# extract the modulations / latent vectors, which will be saved in a "modulations" folder in the config directory
# the folder needs to correspond to "data_path" in the diffusion config files
python test.py -e config/stage1_sdf/ -r last
# unconditional
python train.py -e config/stage2_diff_uncond/ -b 32 -w 8
# conditional
python train.py -e config/stage2_diff_cond/ -b 32 -w 8
```
Training notes: When extracting modulations, we recommend filtering based on the chamfer distance. See `test_modulations()` in `test.py` for details. Some notes on the conditional config file: `"perturb_pc":"partial"`, `"crop_percent":0.5`, and `"sample_pc_size":128` refers to cropping 50% of a point cloud with 128 points to use as condition. `dim` in `diffusion_model_specs` needs to be the dimension of the latent vector, which is 3 times `"latent_dim"` in `"SdfModelSpecs"`.
3. End-to-end training using the saved models from above
```
# unconditional
python train.py -e config/stage3_uncond/ -b 32 -w 8 -r finetune # training from the saved models of first two stages
python train.py -e config/stage3_uncond/ -b 32 -w 8 -r last # resuming training if third stage has been trained
# conditional
python train.py -e config/stage3_cond/ -b 32 -w 8 -r finetune # training from the saved models of first two stages
python train.py -e config/stage3_cond/ -b 32 -w 8 -r last # resuming training if third stage has been trained
```
Training notes: The config file needs to contain the saved checkpoints for the previous two stages of training. The sdf loss (not generated sdf loss) should approach $6 \sim 8 \times 10^{-4}$.
## Testing
1. Testing SDF reconstructions and saving modulations
After the first stage of training, visualize / test reconstructions and save modulations:
```
# extract the modulations / latent vectors, which will be saved in a "modulations" folder in the config directory
# the folder needs to correspond to "data_path" in the diffusion config files
python test.py -e config/stage1_sdf/ -r last
```
A `recon` folder in the config directory will contain the `.ply` reconstructions and a `cd.csv` file that logs Chamfer Distance (CD). A `modulations` folder will contain `latent.txt` files for each SDF. The `modulations` folder will be the data path to the second stage of training.
2. Generations
Meshes can be generated after the second or third stage of training.
```
python test.py -e config/stage3_uncond/ -r finetune # generation after second stage
python test.py -e config/stage3_uncond/ -r last # after third stage
```
A `recon` folder in the config directory will contain the `.ply` reconstructions. `max_batch` arguments in `test.py` are used for running marching cubes; change it to the max value your GPU memory can hold.
## References
We adapt code from
GenSDF https://github.com/princeton-computational-imaging/gensdf
DALLE2-pytorch https://github.com/lucidrains/DALLE2-pytorch
Convolutional Occupancy Networks https://github.com/autonomousvision/convolutional_occupancy_networks (for PointNet encoder)
Multimodal Shape Completion via cGANs https://github.com/ChrisWu1997/Multimodal-Shape-Completion (for conditional metrics)
PointFlow https://github.com/stevenygd/PointFlow (for unconditional metrics)
================================================
FILE: config/stage1_sdf/specs.json
================================================
{
"Description" : "training joint SDF-VAE model for modulating SDFs on the couch dataset",
"DataSource" : "data",
"GridSource" : "data/grid_data",
"TrainSplit" : "data/splits/couch_all.json",
"TestSplit" : "data/splits/couch_all.json",
"training_task": "modulation",
"SdfModelSpecs" : {
"hidden_dim" : 512,
"latent_dim" : 256,
"pn_hidden_dim" : 128,
"num_layers" : 9
},
"SampPerMesh" : 16000,
"PCsize" : 1024,
"num_epochs" : 100001,
"log_freq" : 5000,
"kld_weight" : 1e-5,
"latent_std" : 0.25,
"sdf_lr" : 1e-4
}
================================================
FILE: config/stage2_diff_cond/specs.json
================================================
{
"Description" : "diffusion training (conditional) on couch dataset",
"pc_path" : "data",
"total_pc_size" : 10000,
"TrainSplit" : "data/splits/couch_all.json",
"TestSplit" : "data/splits/couch_all.json",
"data_path" : "config/stage1_sdf/modulations",
"training_task": "diffusion",
"num_epochs" : 50001,
"log_freq" : 5000,
"diff_lr" : 1e-5,
"diffusion_specs" : {
"timesteps" : 1000,
"objective" : "pred_x0",
"loss_type" : "l2",
"perturb_pc" : "partial",
"crop_percent": 0.5,
"sample_pc_size" : 128
},
"diffusion_model_specs": {
"dim" : 768,
"depth" : 4,
"ff_dropout" : 0.3,
"cond" : true,
"cross_attn" : true,
"cond_dropout":true,
"point_feature_dim" : 128
}
}
================================================
FILE: config/stage2_diff_uncond/specs.json
================================================
{
"Description" : "diffusion training (unconditional) on couch dataset",
"TrainSplit" : "data/splits/couch_all.json",
"TestSplit" : "data/splits/couch_all.json",
"data_path" : "config/stage1_sdf/modulations",
"training_task": "diffusion",
"num_epochs" : 50001,
"log_freq" : 5000,
"diff_lr" : 1e-5,
"diffusion_specs" : {
"timesteps" : 1000,
"objective" : "pred_x0",
"loss_type" : "l2"
},
"diffusion_model_specs": {
"dim" : 768,
"dim_in_out" : 768,
"depth" : 4,
"ff_dropout" : 0.3,
"cond" : false
}
}
================================================
FILE: config/stage3_cond/specs.json
================================================
{
"Description" : "end-to-end training (conditional) on couch dataset",
"DataSource" : "data",
"GridSource" : "grid_data",
"TrainSplit" : "data/splits/couch_all.json",
"TestSplit" : "data/splits/couch_all.json",
"modulation_path" : "config/stage1_sdf/modulations",
"modulation_ckpt_path" : "config/stage1_sdf/last.ckpt",
"diffusion_ckpt_path" : "config/stage2_cond/last.ckpt",
"training_task": "combined",
"num_epochs" : 100001,
"log_freq" : 5000,
"kld_weight" : 1e-5,
"latent_std" : 0.25,
"sdf_lr" : 1e-4,
"diff_lr" : 1e-5,
"SdfModelSpecs" : {
"hidden_dim" : 512,
"latent_dim" : 256,
"pn_hidden_dim" : 128,
"num_layers" : 9
},
"SampPerMesh" : 16000,
"PCsize" : 1024,
"diffusion_specs" : {
"timesteps" : 1000,
"objective" : "pred_x0",
"loss_type" : "l2",
"perturb_pc" : "partial",
"crop_percent": 0.5,
"sample_pc_size" : 128
},
"diffusion_model_specs": {
"dim" : 768,
"depth" : 4,
"ff_dropout" : 0.3,
"cond" : true,
"cross_attn" : true,
"cond_dropout":true,
"point_feature_dim" : 128
}
}
================================================
FILE: config/stage3_uncond/specs.json
================================================
{
"Description" : "end-to-end training (unconditional) on couch dataset",
"DataSource" : "data",
"GridSource" : "grid_data",
"TrainSplit" : "data/splits/couch_all.json",
"TestSplit" : "data/splits/couch_all.json",
"modulation_path" : "config/stage1_sdf/modulations",
"modulation_ckpt_path" : "config/stage1_sdf/last.ckpt",
"diffusion_ckpt_path" : "config/stage2_uncond/last.ckpt",
"training_task": "combined",
"num_epochs" : 100001,
"log_freq" : 5000,
"kld_weight" : 1e-5,
"latent_std" : 0.25,
"sdf_lr" : 1e-4,
"diff_lr" : 1e-5,
"SdfModelSpecs" : {
"hidden_dim" : 512,
"latent_dim" : 256,
"pn_hidden_dim" : 128,
"num_layers" : 9
},
"SampPerMesh" : 16000,
"PCsize" : 1024,
"diffusion_specs" : {
"timesteps" : 1000,
"objective" : "pred_x0",
"loss_type" : "l2"
},
"diffusion_model_specs": {
"dim" : 768,
"dim_in_out" : 768,
"depth" : 4,
"ff_dropout" : 0.3,
"cond" : false
}
}
================================================
FILE: data/acronym/Couch/37cfcafe606611d81246538126da07a8/sdf_data.csv
================================================
[File too large to display: 21.4 MB]
================================================
FILE: data/grid_data/acronym/Couch/37cfcafe606611d81246538126da07a8/grid_gt.csv
================================================
[File too large to display: 16.7 MB]
================================================
FILE: data/splits/couch_all.json
================================================
{
"acronym": {
"Couch": [
"37cfcafe606611d81246538126da07a8",
"fd124209f0bc1222f34e56d1d9ad8c65",
"9bbf6977b3f0f9cb11e76965808086c8",
"b6ac23d327248d72627bb9f102840372",
"133f2d360f9b8691f5ee22e800bb9145",
"fa1e1a91e66faf411de55fee5ac2c5c2",
"a10a157254f74b0835836c728d324152",
"9152b02103bdc98633b8015b1af14a5f",
"9cddb828b936db93c341afa383659322",
"69c1e004f9e42b71e7e684d25d4dcaf0",
"694af585378fef692d0fc9d5a4650a3b",
"9156988a1a8645e727eb00c151c6f711",
"e549ab0c75d7fd08cbf65787367134d6",
"354c37c168778a0bd4830313df3656b",
"c32dc92e660bf12b7b6fd5468f603b31",
"a613610e1b9eda601e20309da4bdcbd0",
"8701046c07327e3063f9008a349ae40b",
"95277cbbe4e4a566c5ee765adc4c8ef3",
"c29c56fe616c0986e7e684d25d4dcaf0",
"252be483777007c22e7955415f58545",
"8030e04cf9c19768623a7f06c75e0539",
"21f76612b56d67edf54efb4962ed3879",
"cfa489d103b4f937411053a770408f0d",
"83fd038152fca4ba1c5fa3821fb8a849",
"a38ab75ff03f4befe3f7a74e12a274ef",
"ac376ac6f62d2bd535836c728d324152",
"af796bbeb2990931a1ce49849c98d31c",
"3f90f35a391d5e273b04279a45e22431",
"cfb40e7b9990da99c2f927df125f5ce4",
"26205c6a0a25d5f884099151cc96de84",
"199da405357516a37540a07f0d62945",
"f261a5d15d8dd1cee16219238f4e984c",
"4c5d2b857e9b60a884b236d07141c1a6",
"9b818d54d601b026b161f36d4e309050",
"40060ca118b6dd904d3df84a5eca1a73",
"a17f6d35f289433b5112295517941cf7",
"6475924415a3534111e85af13fa8bb9b",
"87f103e24f91af8d4343db7d677fae7b",
"656925921cfaf668639c9ca4528de8f2",
"b38978ed5bb13266f2b67ae827b02632",
"ce0a19a131121305d2bc6fb367627d3e",
"f645a78ce2bd743ed0c05eb40b42c942",
"2b8d1c67d17d3911d9cff7df5347abca",
"4ef1c8ad9c72263cdc6d207f01b16dd0",
"537438b5f127b55e851f4ba6aaedaaa8",
"2b158fc055b09892e3f7a74e12a274ef",
"abb7fb6cc7c5b2632829384b8c505bb2",
"e68758f3775f700df1783a44a88d6274",
"4dd15f5b09c1494e8058cf23f6382c1",
"6c4eece6ed9f26f1a1741c84cc821e55",
"ae7e1fe11d9243bc5c7a30510dbe4e9f",
"63d0dd6566dfec32ea65c47b660136e7",
"7c6491c7fe65449cd9bdef3310c0e5a6",
"de14f6786ddc7e11bd83c80a89cb8f94",
"a1b0f606bee9dd8e2e7ed79f4f48ff79",
"8efa91e2f3e2eaf7bdc82a7932cd806",
"fa1e84f906d8b61973b1be0dc36d8eee",
"3f6cbfc955e3a82e2f3bfb1446a70fbe",
"a627a11941892ada3707a503164f0d91",
"436a96f58ef9a6fdb039d8689a74349",
"c644689138389daa749ddac355b8e63d",
"ee5f19266a3ed535a6491c91cd8d7770",
"62852049a9f049cca67a896eea47bc81",
"2efd479c13f4cf72af1a85b857ec9afc",
"32b6c232154f2f35550644042dd119da",
"5fffafa504c7ee2a8c4f202fffc87396",
"b714755fb21264c73f89a402a2f2d0aa",
"7e68c12da163627dd2856b4dc0e7a482",
"ce273dabcdd9943b663191fd557d3a61",
"a207c52399b289be9657063f91d78d19",
"f0a02b63e1c84a3fbf7df791578d3fb5",
"34156abf2ce71a60e7e684d25d4dcaf0",
"80982ab523bd3ad35c7200259c323750",
"82460a80f6e9cd8bb0851ce87f32a267",
"823e0ffaf2b47fb27f45370489ca3156",
"e6d92067a01231cb3f7e27638e63d848",
"35b4c449b5ae845013aebd62ca4dcb58",
"8a1a39223639f16e833c6c72c4b62a4d",
"f87682301c61288c25a12159dbc477b5",
"76bef187c092b6335ff61a3a2a0e2484",
"9473a8f3e2182d90d810b14a81e12eca",
"ec9d8a450d5d30935951e132a7bb8604",
"e94c453523b1dabda4669f677ccd56a9",
"d8e1f6781276eb1af34e56d1d9ad8c65",
"1fd45c57ab27cb6cea65c47b660136e7",
"b55d24ed786a8279ad2d3ca7f660ddd",
"30497cce015f143c1191e01810565e0c",
"9b73921863beef532d1fcd5297483645",
"e49c0df0a42bdbecc4b4c7225ff8487e",
"bac90cda9aeb6f5d348e0a0702972359",
"264d40f99914b97e577df49fb73cc67c",
"b1b5d123613e75c97d1a9cfc3f714b6d",
"3b109c1cd75648b71a9a361595e863ab",
"44708e57c3c0140f9cf1b86170c5d320",
"66f193b8f45cd9525582003645829560",
"9c4c4fbaa6e9284fabd9c31d4e727bfe",
"f20e7a860fca9179d57c8a5f8e280cfb",
"8872b51e43d2839b2bbb2dbe46302370",
"323e278c6d7942691836774006aed7e2",
"4106b3d95b6dc95caacb7aee27fab780",
"601bf25b4512502145c6cb69e0968783",
"cb57531ee7104145b0892e337a99f42d",
"b81749d11da67cef57a847db7547c1f3",
"3aab3cf29ab92ce7c29432ec481a60b1",
"e0b897af3eee5ec8d8ac5d7ad0953104",
"eda881a6ea96da8a46874ce99dea28d5",
"5bf5096583e15c0080741efeb2454ffb",
"b765b2c997c459fa83fb3a64ac774b17",
"ca3e6e011ed7ecb32dc86554a873c4fe",
"eb74c7d047809a8c859e01b847403b9e",
"3cce48090f5fabe4e356f23093e95c53",
"dc079a42bd90dcd9593ebeeedbff73b",
"2c12a9ba6d4d25df8af30108ea9ccb6c",
"97f506b27434613d51c4deb11af7079e",
"98e6845fd0c59edef1a8499ee563cb43",
"addd6a0ef4f55a66d810b14a81e12eca",
"21c61d015c6a9591e3f7a74e12a274ef",
"9de72746c0f00317dd73da65dc0e6a7",
"ac8a2b54ab0d2f94493a4a2a112261d3",
"f31648796796cff297c8d78b9aede742",
"c69ce34e38f6218b2f809039658ca52",
"31b5cb5dfaa253b3df85db41e3677c28",
"d8fa31c19a952efb293968bf1f72ae90",
"ab347f12652bda8eab7f9d2da6fc61cf",
"ccd0c5e06744ad9a5ca7e476d2d4a26e",
"daf0e2193c5d1096540b442e651b7ad2",
"3aa613c06675d2a4dd94d3cfe79da065",
"6c3cee36477c444e2e64cef1881421b",
"9a32ae6c8a6afa1e7cfc6b949bbd80bf",
"a85046f089363914e500b815ce2d83a5",
"be2fe14cb518d6346c115ab7398115c6",
"7961d0f612add0cee08bb071746122b9",
"b95a9b552791e21035836c728d324152",
"d2014ddbb4e91d7ecb1a776b5576b46b",
"9c0c2110a58e15febc48810968735439",
"950b680f7b13e72956433d91451f5f9e",
"d01ce0f02e25fc2b42e1bb4fe264125f",
"5fabccb7eaa2adb19a037b4abf810691",
"ad5ef1b493028c0bd810b14a81e12eca",
"a2bdc3a6cb13bf197a969536c4ba7f8",
"3c56ceef171fa142126c0b0ea3ea0a2c",
"a8ae390fc6de647e1191e01810565e0c",
"10507ae95f984daccd8f3fe9ca2145e1",
"ebf1982ccb77cf7d4c37b9ce3a3de242",
"fbd0055862daa31a2d8ad3188383fcc8",
"1d6250cafc410fdfe8058cf23f6382c1",
"7833d94635b755793adc3470b30138f3",
"ee4ec19df3a6b81317fd40ad196436c",
"82d25519070e3d5d6f1ad7def14e2855",
"1ed2e9056a9e94a693e60794f9200b7",
"980d28e46333db492878c1c13b7f1ba6",
"a84ff0a6802b0661f345fb470303964a",
"b58291a2e96691189d1eb836604648db",
"f653f39d443ac6af15c0ed29be4328d5",
"4f20a1fe4c85f35aa8891e678de4fe35",
"90e1a4fb7fa4f4fea5df41522f4e2c86",
"e5f793180e610d329d6f7fa9d8096b18",
"4760c46c66461a79dd3adf3090c701f7",
"3069cb9d632611dcdb43ca77043c03b9",
"6f0f6571f173bd90f9883d2fd957d60f",
"11d5e99e8faa10ff3564590844406360",
"36980a6d57dac873d206493eaec9688a",
"80b3fb6bae282ec24ca35f9816a36449",
"7bfef1d5b096f81967ec3d19eeb10ea5",
"6bfe4dac77d317de1181122615f0a10",
"2b4de06792ec0eba94141819f1b9662c",
"20c030e0d055b7aeb0892e337a99f42d",
"bb53613789a3e9158d736fde63f29661",
"4ca3850427226dd18d4d8cbfa622f635",
"a4d269f9299e845a36e3b2fa8d1eb4eb",
"d13a2ccdbb7740ea83a0857b7b9398b1",
"a98956209f6723a2dedecd2df7bf25e3",
"80c8fcbc5f53c402162cb7009da70820",
"32909910636369b24500047017815f5f",
"b3d686456bd951d42ea98d69e91ba870",
"c0d544b1b483556e4d70389e06680f96",
"d08fc6f10d07dfd8c05575120a46cd3b",
"e2e25d63c3bc14c23459b03b1c80294",
"3d95d6237ea6db97afa2904116693357",
"a7e4616a2a315dfac5ddc26ef5560e77",
"7a92d499a6f3460a4dc2316a7e66d36",
"22fd961578d9b100e858db1dc3499392",
"9136172b17108d1ba7d0cc9b15400f65",
"62e50e8b0d1e3207e047a3592e8436e5",
"7fb8fdbc8c32dc2aad6d7e3d5c0179fe",
"1488c4dbdcd022e99de4fa6f98acbba8",
"9cff96bae9963ceab3c971099efd7fb9",
"5cea034b028af000c2843529921f9ad7",
"8a8c67e553f1c85c1829bffea9d18abb",
"90959c7526608ecdfb54b6f00b1687e2",
"8043ff469c9bed4d48c575435d16be7c",
"b031d0f66e8d047bee8bf38bd6d16329",
"700562e70c89e9d36bd9ce012b65eebe",
"b9994f0252c890cbb0892e337a99f42d",
"7201558f92e55496493a4a2a112261d3",
"e0f5780dbee33329caa60d250862f15f",
"1aa55867200ea789465e08d496c0420f",
"956c5ffbe1ad52976e3c8a33c4ddf2ef",
"3415f252bd71495649920492438878e5",
"d64025df908177ece7e684d25d4dcaf0",
"65309e7601a0c65d82608d43718c2b7",
"f8ef9668c3aba7c735836c728d324152",
"6c4e0987896fc5df30c7f4adc2c33ad6",
"23d1a49a9a29c776ab9281d3b84673cf",
"aa5fe2c92f01ce0a30cbbda41991e4d0",
"2fca00b377269eebdb039d8689a74349",
"fe56059777b240bb833c6c72c4b62a4d",
"82a168f7c5b8899a79368d1198f406e7",
"1ebc20758e0b61184bf4a6644c670a1e",
"f1ce06c5259f771dc24182d0db4c6889",
"a1b935c6288595267795cc6fb87dd247",
"b16913335a26e380d1a4117555fc7e56",
"87bdac1e34f3e6146db2ac45db35c175",
"f693bb3178c80d7f1783a44a88d6274",
"f39cb99f7c30a4b8310cd758d9b7cf",
"91238bfd357c2d87abe6f34906958ac9",
"801418d073795539ddb3d1341171fe71",
"b2e9bf58ab2458b5b057261be64dfc5",
"7669de4a8474b6f4b53857b83094d3a9",
"9d8c5c62020fcaf4b822d48a43773c62",
"a680830f8b76c1bbe929777b2f481029",
"6d0cd48b18471a8bf1444eeb21e761c6",
"baa8760ca5fbbc4840b559ef47048b86",
"377fceb1500e6452d9651cd1d591d64d",
"440e3ad55b603cb1b071d266df0a3bf5",
"21bf3888008b7aced6d2e576c4ef3bde",
"151b2df36e08a13fafeeb1a322c90696",
"24f5497f13a1841adb039d8689a74349",
"a4c8e8816a1c5f54e6e3ac4cbdf2e092",
"4cd19483d852712a120031fb55ce8c1c",
"69257080fd87015369fb37a80cd44134",
"22c68a7a2c8142f027eb00c151c6f711",
"82efd3fc01b0c6f34500047017815f5f",
"fcff900ce37820983f7e27638e63d848",
"cdb7efad9942e4f9695ff9c22a7938c6",
"107637b6bdf8129d4904d89e9169817b",
"f2e7ed2b973570f1a54b9afa882a89ed",
"b58f27bd7e1ace90d2afe8d5254a0d04",
"4820b629990b6a20860f0fe00407fa79",
"abfd8f1db83d6fd24f5ae69aba92b12c",
"4ed802a4aa4b8a86b161f36d4e309050",
"47ad0af4207beedb296baeb5500afa1a",
"f6f5fa7760fde38114aa582256ea1399",
"60fc7123d6360e6d620ef1b4a95dca08",
"c53ec0141303c1eb4508add1163b4513",
"2ebb84f64f8f0f565db77ed1f5c8b93",
"6b0a09fda777eef2a0398757e5dcd13c",
"2c1ecb41c0f0d2cd07c7bf20dae278a",
"e03c28dbfe1f2d9638bb8355830240f9",
"1b29de0b8be4b18733d25da891be74b8",
"38a4328fc3e65ba54c37b9ce3a3de242",
"f81a05c0b43e9bf18c9441777325d0fd",
"4cb7347f6a91da8554f7af0f5a657663",
"5328231a28719ed240a92729068b6b39",
"9239ed83d2274ac975aa7f24a9b6003a",
"71579e02d0b80bcfccebba9929a10b5e",
"8647bc3a58eb9d04493a4a2a112261d3",
"f2c96ae9904baaf220021a69db826002",
"923a0885b7ca9c55d1007f4863f3bba6",
"2f28d32781abb0527f2a6b0d07c10212",
"e68e91ef2652cd1c36e3b2fa8d1eb4eb",
"9451b957aa36883d6e6c0340d644e56e",
"72b7ad431f8c4aa2f5520a8b6a5d82e0",
"54b420da7102d792b36717b39afb6ad8",
"16fd88a99f7d4c857e484225f3bb4a8",
"8bad639b9a650908de650492e45fb14f",
"24129a2a06b35e27c1a6b3575377c8d3",
"fc553c244c85909c8ed898bae5915f59",
"e3ce79fd03b7a78d98661b9abac3e1f9",
"3fcb0aaa346bd46f11e76965808086c8",
"df8374d8f3563be8f1783a44a88d6274",
"7eb94d259b560d24f1783a44a88d6274",
"48205c79d48ca2049f433921788191f3",
"9294163aacac61f1ad5c4e4076e069c",
"f2458aaf4ab3e0585d7543afa4b9b4e8",
"f09a3e98938bc1a990e678878ceea8a4",
"2e2f34305ef8cbc1533ccec14d70360b",
"1a477f7b2c1799e1b728e6e715c3f8cf",
"bcd2418fb0d9727c563fcc2752ece39",
"ace76562ee9d7c3a913c66b05d18ae8",
"1b5ae67e2ffb387341fbc1e2da054acb",
"f563b39b846922f22ea98d69e91ba870",
"fce717669ca521884e1a9fae5403e01f",
"d0563b5c26d096ea7b270a7ae3fc6351",
"43656f715bb28776bff15b656f256f05",
"3006b66ed9a9fd55a4ccffb47ce1f60d",
"f5dc5957a6b9712835836c728d324152",
"3d87710d90c8627dd2afe8d5254a0d04",
"21b22c30f1c6ddb9952d5d6c0ee49300",
"e077c6640dab26ecc6e735f548844733",
"bdfcf2086fafb0fec8a04932b17782af",
"fb883c27f9f0156e1e44b635c3d74c0b",
"161f92c3330d111ffc2dd24b885b9c05",
"b23dc14d788e954b3adc3470b30138f3",
"a930d381392ff51140b559ef47048b86",
"7c92e64a328f1b968f6cc6fefa15515a",
"8affea22019b77a1f1783a44a88d6274",
"1dc08faeb376c6e0296baeb5500afa1a",
"fadd7d8c94893136e4b1c2efb094888b",
"6bd7475753c3d1d62764cfba57a5de73",
"5d9f1c6f9ce9333994c6d0877753424f",
"c3e248c5f88e72e9ea65c47b660136e7",
"87c7911287c10ab8407b28046774089c",
"4fbc051d0905bb5bdaeb838d0771f3b5",
"b097e092f951f1d68bc0997abbde7f8",
"1aec9ac7e1487b7bc75516c7217fb218",
"2f9a502dafbaa50769cd744177574ad3",
"9df9d1c02e9013e7ef10d8e00e9d279c",
"f0f42d26c4a0be52a53016a50348cce3",
"f080807207cc4859b2403dba7fd079eb",
"ffcc57ea3101d18ece3df8a7477638c0",
"f01821fb48425e3a493a4a2a112261d3",
"8cc805e63f357a69ae336243a6891b91",
"468e627a950f9340d2bc6fb367627d3e",
"499edbd7de3e7423bb865c00ef25280",
"686c5c4a39de31445c31d9fe12bfc9ab",
"1470c6423b877a0882fcee0c19bca00a",
"bcc267b5e694e60c3adc3470b30138f3",
"e7b9fef210fa47505615f13c9cba61e5",
"ddc8bdf6ac3786f493b55327c66a07aa",
"27359b0203dc2f345c6cb69e0968783",
"1e355dc9bfd5da837a8c23d2d40f51b8",
"b2cfba8ee63abd118fac6a8030e15671",
"f1a956705451d66ee95eb670e9df38e4",
"6fc69edce1f6d0d5e7e684d25d4dcaf0",
"2d3f8abd0567f1601191e01810565e0c",
"52d307203aefd6bf366971e8a2cdf120",
"be129d18d202650f6d3e11439c6c22c8",
"2d4adec4323f8a7da6491c91cd8d7770",
"9695e057d7a4e992f2b67ae827b02632",
"21addfde981f0e3445c6cb69e0968783",
"f8a5ab9872d0dc5e1855a135fe295583",
"3f8f1d7023ae4a1c73ffdca541f4749c",
"1e4a7cb88d31716cc9c93fe51d670e21",
"2cf89f5b70817c3bab2844e3834b9ca1",
"e410469475df6525493a4a2a112261d3",
"81f1ac4d9c29025a36f739eb863924d1",
"c955e564c9a73650f78bdf37d618e97e",
"20b6d398c5b93a253adc3470b30138f3",
"ed4ac116e03ebb8d663191fd557d3a61",
"4735568bd188aefcb8e1b99345a5afd4",
"6b3b1d804255188993680c5a9a367b4a",
"5ea73e05bb96272b444ac0ff78e79b7",
"2eed4a1647d05552413240bb8f2e4eea",
"6e960c704961590e7e0d7ca07704f74f",
"dccb04ee585d9d10cf065d4b58dc3084",
"1eb4bc0cea4ba47c8186c25526ebdaa6",
"f144cda19f9457fef9b7ca92584b5271",
"c971b2adf8f071f49afca5ea91159877",
"1e96607159093f7bbdb7b8a99c2de3d8",
"a29f6cce8bf52a99182529900d54df69",
"f260588ef2f4c8e62a6f69f20bf5a185",
"72b7e2a9bdef8b37f491193d3fde480b",
"b5fe8de26eac454acdead18f26cf2ece",
"6caa1713e795c8a2f0478431b5ad57db",
"1914d0e6b9f0445b40e80a2d9f005aa7",
"c5380b779c689a919201f2703b45dd7",
"735122a1019fd6529dac46bde4c69ef2",
"660df170c4337cda35836c728d324152",
"770bf9d88c7f039227eb00c151c6f711",
"2b73510e4eb3d8ca87b66c61c1f7b8e4",
"3e499689bae22f3ab89ca298cd9a646",
"ae1fd69e00c00a0253116daf5a10d9c",
"823219c03b02a423c1a85f2b9754d96f",
"f87c2d4d5d622293a3f3891e682aaded",
"ee2ea7d9f26f57886580dafdf181b2d7"
]}
}
================================================
FILE: dataloader/__init__.py
================================================
================================================
FILE: dataloader/base.py
================================================
#!/usr/bin/env python3
import numpy as np
import time
import logging
import os
import random
import torch
import torch.utils.data
import pandas as pd
import csv
class Dataset(torch.utils.data.Dataset):
    """Base dataset for SDF samples stored as per-object CSV files.

    Each CSV row is a query coordinate (x, y, z) followed by its signed
    distance value.  Rows whose SDF value is exactly 0 are on-surface points
    and double as the ground-truth point cloud.

    Example layout (with gt_filename="sdf_data.csv"):
        data_source: "data"
        instance file: "data/acronym/Couch/<meshname>/sdf_data.csv"
    """

    def __init__(
        self,
        data_source,   # root directory holding <dataset>/<class>/<instance>/<gt_filename>
        split_file,    # json filepath which contains train/test classes and meshes
        subsample,     # number of SDF query points drawn per mesh
        gt_filename,   # per-instance csv filename, e.g. "sdf_data.csv"
    ):
        self.data_source = data_source
        self.subsample = subsample
        self.split_file = split_file
        self.gt_filename = gt_filename

    def __len__(self):
        # Subclasses must implement.  `raise` (the original `return
        # NotImplementedError` handed the exception *class* back to the
        # caller instead of failing loudly).
        raise NotImplementedError

    def __getitem__(self, idx):
        # Subclasses must implement; see __len__ for why this raises.
        raise NotImplementedError

    def sample_pointcloud(self, csvfile, pc_size):
        """Sample `pc_size` surface points (rows with sdf == 0) from a csv file.

        Samples with replacement when fewer than `pc_size` surface points
        exist.  Returns a float tensor of shape (pc_size, 3).
        """
        f = pd.read_csv(csvfile, sep=',', header=None).values
        f = f[f[:, -1] == 0][:, :3]  # keep xyz of on-surface rows only
        # Replacement only when the pool is too small to sample without it.
        replace = f.shape[0] < pc_size
        pc_idx = np.random.choice(f.shape[0], pc_size, replace=replace)
        return torch.from_numpy(f[pc_idx]).float()

    def labeled_sampling(self, f, subsample, pc_size=1024, load_from_path=True):
        """Sample a sign-balanced batch of SDF queries plus a surface point cloud.

        Args:
            f: csv filepath (when `load_from_path`) or an (N, 4) numpy array
               of xyz + sdf rows.
            subsample: total number of query points; half positive, half negative.
            pc_size: number of surface (sdf == 0) points to return.
            load_from_path: read `f` with pandas when True.

        Returns:
            (pc, xyz, sdv): point cloud (pc_size, 3), query coordinates
            (subsample, 3) and their signed distance values (subsample,),
            all squeezed float tensors.
        """
        if load_from_path:
            f = pd.read_csv(f, sep=',', header=None).values
        f = torch.from_numpy(f)

        half = int(subsample / 2)
        neg_tensor = f[f[:, -1] < 0]
        pos_tensor = f[f[:, -1] > 0]

        # Draw `half` indices per sign: sample with replacement when a side is
        # undersized, and fall back to the opposite sign when a side is empty
        # so degenerate meshes still yield a full batch.  (The original only
        # guarded the empty-negative case; empty-positive crashed.)
        if pos_tensor.shape[0] == 0:
            pos_idx = torch.randperm(neg_tensor.shape[0])[:half]
        elif pos_tensor.shape[0] < half:
            pos_idx = torch.randint(0, pos_tensor.shape[0], (half,))
        else:
            pos_idx = torch.randperm(pos_tensor.shape[0])[:half]

        if neg_tensor.shape[0] == 0:
            neg_idx = torch.randperm(pos_tensor.shape[0])[:half]
        elif neg_tensor.shape[0] < half:
            neg_idx = torch.randint(0, neg_tensor.shape[0], (half,))
        else:
            neg_idx = torch.randperm(neg_tensor.shape[0])[:half]

        pos_sample = neg_tensor[pos_idx] if pos_tensor.shape[0] == 0 else pos_tensor[pos_idx]
        neg_sample = pos_tensor[neg_idx] if neg_tensor.shape[0] == 0 else neg_tensor[neg_idx]

        # On-surface rows (sdf == 0) serve as the ground-truth point cloud.
        pc = f[f[:, -1] == 0][:, :3]
        pc = pc[torch.randperm(pc.shape[0])[:pc_size]]

        samples = torch.cat([pos_sample, neg_sample], 0)
        # Columns 0-2 are xyz, column 3 is the signed distance value.
        return pc.float().squeeze(), samples[:, :3].float().squeeze(), samples[:, 3].float().squeeze()

    def get_instance_filenames(self, data_source, split, gt_filename="sdf_data.csv", filter_modulation_path=None):
        """Expand a split dict into existing csv filepaths.

        Args:
            data_source: root data directory.
            split: nested dict {dataset: {class_name: [instance_name, ...]}}.
            gt_filename: csv filename inside each instance directory.
            filter_modulation_path: when given, keep only instances whose
                trained modulation ("latent.txt") exists under this path.

        Returns:
            List of csv filepaths; missing files are skipped with a warning.
        """
        do_filter = filter_modulation_path is not None
        csvfiles = []
        for dataset in split:  # e.g. "acronym", "shapenet"
            for class_name in split[dataset]:
                for instance_name in split[dataset][class_name]:
                    instance_filename = os.path.join(data_source, dataset, class_name, instance_name, gt_filename)
                    if do_filter:
                        mod_file = os.path.join(filter_modulation_path, class_name, instance_name, "latent.txt")
                        # do not load if the modulation does not exist; i.e. was not trained by diffusion model
                        if not os.path.isfile(mod_file):
                            continue
                    if not os.path.isfile(instance_filename):
                        logging.warning("Requested non-existent file '{}'".format(instance_filename))
                        continue
                    csvfiles.append(instance_filename)
        return csvfiles
================================================
FILE: dataloader/modulation_loader.py
================================================
#!/usr/bin/env python3
import time
import logging
import os
import random
import torch
import torch.utils.data
from diff_utils.helpers import *
import pandas as pd
import numpy as np
import csv, json
from tqdm import tqdm
class ModulationLoader(torch.utils.data.Dataset):
    """Loads latent modulation vectors and, in conditional mode, their
    ground-truth point clouds sampled from the matching sdf csv files."""

    def __init__(self, data_path, pc_path=None, split_file=None, pc_size=None):
        super().__init__()
        # conditional mode pairs each latent with a point cloud from pc_path
        self.conditional = pc_path is not None

        if self.conditional:
            self.modulations, pc_paths = self.load_modulations(data_path, pc_path, split_file)
        else:
            self.modulations = self.unconditional_load_modulations(data_path, split_file)

        print("data shape, dataset len: ", self.modulations[0].shape, len(self.modulations))

        if self.conditional:
            print("loading ground truth point clouds...")
            lst = []
            with tqdm(pc_paths) as pbar:
                # FIX: iterate the tqdm wrapper (not the raw list) so the
                # progress bar actually advances; the original iterated
                # pc_paths directly, leaving the bar frozen at 0.
                for i, f in enumerate(pbar):
                    pbar.set_description("Point clouds loaded: {}/{}".format(i, len(pc_paths)))
                    lst.append(sample_pc(f, pc_size))
            self.point_clouds = lst
            assert len(self.point_clouds) == len(self.modulations)

    def __len__(self):
        return len(self.modulations)

    def __getitem__(self, index):
        # unconditional datasets carry no point cloud; False is the sentinel
        pc = self.point_clouds[index] if self.conditional else False
        return {
            "point_cloud" : pc,
            "latent" : self.modulations[index]
        }

    def load_modulations(self, data_source, pc_source, split, f_name="latent.txt", add_flip_augment=False, return_filepaths=True):
        """Load latent vectors listed in `split`; also collect the matching
        sdf csv paths so point clouds can be loaded for conditioning.

        Missing latents are skipped (they were filtered out of diffusion
        training). With add_flip_augment, four flip-augmented latents
        (latent_0..latent_3) are loaded per instance.
        """
        files = []
        filepaths = []  # return filepaths for loading pcs
        for dataset in split:  # dataset = "acronym"
            for class_name in split[dataset]:
                for instance_name in split[dataset][class_name]:
                    if add_flip_augment:
                        names = ["latent_{}.txt".format(idx) for idx in range(4)]
                    else:
                        names = [f_name]
                    for name in names:
                        instance_filename = os.path.join(data_source, class_name, instance_name, name)
                        if not os.path.isfile(instance_filename):
                            # flip-augmented files are expected to exist, so report them;
                            # plain latents are filtered silently (matches original behavior)
                            if add_flip_augment:
                                print("Requested non-existent file '{}'".format(instance_filename))
                            continue
                        files.append(torch.from_numpy(np.loadtxt(instance_filename)).float())
                        filepaths.append(os.path.join(pc_source, dataset, class_name, instance_name, "sdf_data.csv"))
        if return_filepaths:
            return files, filepaths
        return files

    def unconditional_load_modulations(self, data_source, split, f_name="latent.txt", add_flip_augment=False):
        """Load latent vectors listed in `split` without tracking csv paths."""
        files = []
        for dataset in split:  # dataset = "acronym"
            for class_name in split[dataset]:
                for instance_name in split[dataset][class_name]:
                    if add_flip_augment:
                        names = ["latent_{}.txt".format(idx) for idx in range(4)]
                    else:
                        names = [f_name]
                    for name in names:
                        instance_filename = os.path.join(data_source, class_name, instance_name, name)
                        if not os.path.isfile(instance_filename):
                            if add_flip_augment:
                                print("Requested non-existent file '{}'".format(instance_filename))
                            continue
                        files.append(torch.from_numpy(np.loadtxt(instance_filename)).float())
        return files
================================================
FILE: dataloader/pc_loader.py
================================================
#!/usr/bin/env python3
import time
import logging
import os
import random
import torch
import torch.utils.data
from . import base
from tqdm import tqdm
import pandas as pd
import csv
class PCloader(base.Dataset):
    """Preloads one surface point cloud per mesh listed in the split file."""

    def __init__(
        self,
        data_source,
        split_file,        # json dict which contains train/test classes and meshes
        pc_size=1024,
        return_filename=False
    ):
        self.pc_size = pc_size
        self.return_filename = return_filename

        # FIX: resolve the split once; the original scanned the filesystem twice
        # and then truncated the path list with a leftover debug slice `[:5]`,
        # which silently limited the dataset to 5 point clouds.
        self.pc_paths = self.get_instance_filenames(data_source, split_file)
        self.gt_files = self.pc_paths  # kept for backward compatibility

        print("loading {} point clouds into memory...".format(len(self.pc_paths)))
        lst = []
        with tqdm(self.pc_paths) as pbar:
            for i, f in enumerate(pbar):
                pbar.set_description("Files loaded: {}/{}".format(i, len(self.pc_paths)))
                lst.append(self.sample_pc(f, pc_size))
        self.point_clouds = lst

    def get_all_files(self):
        """Return the preloaded point clouds and their source csv paths."""
        return self.point_clouds, self.pc_paths

    def __getitem__(self, idx):
        if self.return_filename:
            return self.point_clouds[idx], self.pc_paths[idx]
        return self.point_clouds[idx]

    def __len__(self):
        return len(self.point_clouds)

    def sample_pc(self, f, samp=1024):
        '''
        f: path to csv file; rows whose last column is 0 are surface points.
        Returns up to `samp` randomly chosen surface points, shape (samp, 3).
        '''
        data = torch.from_numpy(pd.read_csv(f, sep=',', header=None).values).float()
        pc = data[data[:, -1] == 0][:, :3]
        pc_idx = torch.randperm(pc.shape[0])[:samp]
        return pc[pc_idx]
================================================
FILE: dataloader/sdf_loader.py
================================================
#!/usr/bin/env python3
import time
import logging
import os
import random
import torch
import torch.utils.data
from . import base
import pandas as pd
import numpy as np
import csv, json
from tqdm import tqdm
class SdfLoader(base.Dataset):
    """Preloads near-surface SDF samples (and optionally uniform grid samples)
    for every mesh in the split, serving one mesh per __getitem__."""

    def __init__(
        self,
        data_source,            # path to points sampled around surface
        split_file,             # json dict which contains train/test classes and meshes
        grid_source=None,       # path to grid points sampled throughout the unit cube;
                                # necessary for preventing artifacts in empty space
        samples_per_mesh=16000,
        pc_size=1024,
        modulation_path=None    # third training stage: skip meshes whose modulation was filtered out
    ):
        self.samples_per_mesh = samples_per_mesh
        self.pc_size = pc_size
        self.grid_source = grid_source

        # NOTE: the original computed `subsample = len(self.gt_files)` and then
        # sliced both lists with it — a no-op debug leftover, removed here.
        gt_paths = self.get_instance_filenames(data_source, split_file, filter_modulation_path=modulation_path)

        if grid_source:
            grid_paths = self.get_instance_filenames(grid_source, split_file, gt_filename="grid_gt.csv", filter_modulation_path=modulation_path)
            lst = []
            with tqdm(grid_paths) as pbar:
                for i, f in enumerate(pbar):
                    pbar.set_description("Grid files loaded: {}/{}".format(i, len(grid_paths)))
                    lst.append(torch.from_numpy(pd.read_csv(f, sep=',', header=None).values))
            # despite the name, grid_files holds loaded tensors (kept for compatibility)
            self.grid_files = lst
            assert len(self.grid_files) == len(gt_paths)

        # load all csv files into memory up front
        print("loading all {} files into memory...".format(len(gt_paths)))
        lst = []
        with tqdm(gt_paths) as pbar:
            for i, f in enumerate(pbar):
                pbar.set_description("Files loaded: {}/{}".format(i, len(gt_paths)))
                lst.append(torch.from_numpy(pd.read_csv(f, sep=',', header=None).values))
        # despite the name, gt_files holds loaded tensors (kept for compatibility)
        self.gt_files = lst

    def __getitem__(self, idx):
        # with a grid source, 70% of samples are near-surface and 30% uniform grid
        # (e.g. 16000 per batch -> 11200 near surface, 4800 grid)
        near_surface_count = int(self.samples_per_mesh * 0.7) if self.grid_source else self.samples_per_mesh
        pc, sdf_xyz, sdf_gt = self.labeled_sampling(self.gt_files[idx], near_surface_count, self.pc_size, load_from_path=False)

        if self.grid_source is not None:
            grid_count = self.samples_per_mesh - near_surface_count
            _, grid_xyz, grid_gt = self.labeled_sampling(self.grid_files[idx], grid_count, pc_size=0, load_from_path=False)
            # each item is one batch, so tensors are (N, 3) / (N,) with no batch dim
            sdf_xyz = torch.cat((sdf_xyz, grid_xyz))
            sdf_gt = torch.cat((sdf_gt, grid_gt))

        return {
            "xyz": sdf_xyz.float().squeeze(),
            "gt_sdf": sdf_gt.float().squeeze(),
            "point_cloud": pc.float().squeeze(),
        }

    def __len__(self):
        return len(self.gt_files)
================================================
FILE: diff_utils/helpers.py
================================================
import math
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import random
from inspect import isfunction
import os
import json
#import open3d as o3d
def get_split_filenames(data_source, split_file, f_name="sdf_data.csv"):
    """Read the json split file and return the existing csv paths it lists.

    Missing files are reported and skipped.
    """
    with open(split_file) as fh:
        split = json.load(fh)
    found = []
    for dataset_name, classes in split.items():  # e.g. "acronym" "shapenet"
        for class_name, instances in classes.items():
            for instance_name in instances:
                candidate = os.path.join(data_source, dataset_name, class_name, instance_name, f_name)
                if os.path.isfile(candidate):
                    found.append(candidate)
                else:
                    print("Requested non-existent file '{}'".format(candidate))
    return found
def sample_pc(f, samp=1024, add_flip_augment=False):
    '''
    f: path to csv file whose last column labels surface points with 0.
    Returns up to `samp` surface points, shape (samp, 3); with
    add_flip_augment, returns a list of 4 axis-flipped copies instead.
    '''
    raw = pd.read_csv(f, sep=',', header=None).values
    points = torch.from_numpy(raw).float()
    surface = points[points[:, -1] == 0][:, :3]
    choice = torch.randperm(surface.shape[0])[:samp]
    surface = surface[choice]
    if add_flip_augment:
        flip_axes = torch.tensor([[1,1,1],[-1,1,1],[1,-1,1],[1,1,-1]], device=surface.device)
        return [surface * axis for axis in flip_axes]
    return surface
def perturb_point_cloud(pc, perturb, pc_size=None, crop_percent=0.25):
    '''
    Apply the requested perturbation ("partial" crop, "noisy" jitter, or
    None for plain subsampling).
    if pc_size is None, return entire pc; else return with shape of pc_size
    '''
    assert perturb in [None, "partial", "noisy"]
    if perturb == "partial":
        return crop_pc(pc, crop_percent, pc_size)
    if perturb == "noisy":
        return jitter_pc(pc, pc_size)
    # perturb is None: random permutation / subsample along the point axis
    keep = torch.randperm(pc.shape[1])[:pc_size]
    return pc[:, keep]
def fps(data, number):
    '''
    Farthest-point-sample `number` points from each cloud.
    data B N 3
    number int

    NOTE(review): `pointnet2_utils` is not imported anywhere in this file, so
    calling fps() raises NameError as written — confirm the intended
    dependency (e.g. pointnet2_ops) before use.
    '''
    fps_idx = pointnet2_utils.furthest_point_sample(data, number)
    fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
    return fps_data
def crop_pc(xyz, crop, pc_size=None, fixed_points = None, padding_zeros = False):
    '''
    crop the point cloud given a randomly selected view
    input point cloud: xyz, with shape (B, N, 3)
    crop: float, percentage of points to crop out (e.g. 0.25 means keep 75% of points)
    pc_size: integer value, how many points to return; None if return all (all meaning xyz size * crop)
    fixed_points: optional view center(s); a random unit vector is drawn when None
    padding_zeros: zero out cropped points instead of removing them
    '''
    # optionally subsample first so the crop fraction applies to pc_size points
    if pc_size is not None:
        xyz = xyz[:, torch.randperm(xyz.shape[1])[:pc_size] ]
    _,n,c = xyz.shape
    device = xyz.device
    crop = int(xyz.shape[1]*crop)  # fraction -> absolute number of points to remove
    assert c == 3
    if crop == n:
        return xyz # , None
    INPUT = []
    CROP = []
    # process each cloud in the batch independently (each gets its own view center)
    for points in xyz:
        if isinstance(crop,list):
            num_crop = random.randint(crop[0],crop[1])
        else:
            num_crop = crop
        points = points.unsqueeze(0)
        if fixed_points is None:
            # random unit direction used as the "camera" center
            center = F.normalize(torch.randn(1,1,3, device=device),p=2,dim=-1)
        else:
            if isinstance(fixed_points,list):
                fixed_point = random.sample(fixed_points,1)[0]
            else:
                fixed_point = fixed_points
            center = fixed_point.reshape(1,1,3).to(device)
        # sort points by distance to the view center; nearest num_crop are removed
        distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1) # 1 1 2048
        idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048
        if padding_zeros:
            # keep the tensor shape, zeroing the cropped points in place
            input_data = points.clone()
            input_data[0, idx[:num_crop]] = input_data[0,idx[:num_crop]] * 0
        else:
            input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3
        crop_data = points.clone()[0, idx[:num_crop]].unsqueeze(0)
        if isinstance(crop,list):
            # NOTE(review): this branch calls fps(), which depends on the
            # missing pointnet2_utils import above — confirm before using
            # list-valued crop.
            INPUT.append(fps(input_data,2048))
            CROP.append(fps(crop_data,2048))
        else:
            INPUT.append(input_data)
            CROP.append(crop_data)
    input_data = torch.cat(INPUT,dim=0)# B N 3
    crop_data = torch.cat(CROP,dim=0)# B M 3
    return input_data.contiguous() #, crop_data.contiguous()
def visualize_pc(pc):
    """Write `pc` (anything reshapeable to (-1, 3)) to ./pc.ply.

    NOTE(review): requires open3d (`o3d`), whose import is commented out at the
    top of this file — calling this raises NameError unless open3d is imported.
    """
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(pc.reshape(-1,3))
    o3d.io.write_point_cloud("./pc.ply", pcd)
    #o3d.visualization.draw_geometries([pcd])
def jitter_pc(pc, pc_size=None, sigma=0.05, clip=0.1):
    """Add clipped Gaussian noise to `pc` IN PLACE, then optionally subsample.

    Accepts (B, N, 3) or (N, 3) tensors; returns pc_size points along the
    point axis when pc_size is given.
    """
    noise = torch.clamp(sigma * torch.randn(*pc.shape, device=pc.device), -1 * clip, clip)
    pc += noise  # in-place, mirroring the original mutation semantics
    if pc_size is not None:
        if len(pc.shape) == 3:  # batched: B, N, 3
            pc = pc[:, torch.randperm(pc.shape[1])[:pc_size]]
        else:  # single cloud: N, 3
            pc = pc[torch.randperm(pc.shape[0])[:pc_size]]
    return pc
def normalize_pc(pc):
    """Center `pc` at the origin and scale it into the unit sphere, IN PLACE."""
    centroid = torch.mean(pc, axis=0)
    pc -= centroid
    # radius of the farthest point after centering
    radius = torch.max(torch.sqrt(torch.sum(pc ** 2, axis=1)))
    pc /= radius
    return pc
def save_model(iters, model, optimizer, loss, path):
    """Serialize iteration count, model/optimizer state dicts, and loss to `path`."""
    checkpoint = {
        'iters': iters,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, path)
def load_model(model, optimizer, path):
    """Restore model (and optionally optimizer) state from a checkpoint.

    Returns (iters, model, optimizer, loss); optimizer stays None when not given.
    """
    checkpoint = torch.load(path)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    model.load_state_dict(checkpoint['model_state_dict'])
    iters, loss = checkpoint['iters'], checkpoint['loss']
    print("loading from iter {}...".format(iters))
    return iters, model, optimizer, loss
def save_code_to_conf(conf_dir):
    """Snapshot the project's source folders into <conf_dir>/code for reproducibility.

    NOTE(review): uses `os.system("cp ...")`, so this is POSIX-only, silently
    ignores copy failures, and interpolates conf_dir into a shell command —
    do not pass untrusted paths.
    """
    path = os.path.join(conf_dir, "code")
    os.makedirs(path, exist_ok=True)
    # copy each tracked source folder into the snapshot directory
    for folder in ["utils", "models", "diff_utils", "dataloader", "metrics"]:
        os.makedirs(os.path.join(path, folder), exist_ok=True)
        os.system("""cp -r ./{0}/* "{1}" """.format(folder, os.path.join(path, folder)))
    # other files
    os.system("""cp *.py "{}" """.format(path))
class ScheduledOpt:
    """Optimizer wrapper implementing linear warmup followed by inverse-sqrt decay.

    Usage:
        optimizer = ScheduledOpt(4000, torch.optim.Adam(model.parameters(), lr=0))
    """

    def __init__(self, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self._rate = 0
        # FIX: precompute the warmup ramp once. The original rebuilt this
        # warmup-length tensor via torch.linspace on EVERY rate() call,
        # i.e. every optimizer step.
        self._warm_schedule = torch.linspace(0, 3e-4, warmup, dtype=torch.float64)

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def rate(self, step=None):
        """Return the learning rate for `step` (defaults to the internal counter):
        linear ramp 0 -> 3e-4 over `warmup` steps, then 3e-4 / sqrt(steps past warmup)."""
        if step is None:
            step = self._step
        if step < self.warmup:
            return self._warm_schedule[step]
        return 3e-4 / (math.sqrt(step - self.warmup + 1))
def exists(x):
    """True iff *x* is not None."""
    return x is not None
def default(val, d):
    """Return val when present; otherwise the fallback d (called if it is a function)."""
    if exists(val):
        return val
    return d() if isfunction(d) else d
def cycle(dl):
    """Yield items from dl forever, restarting each time it is exhausted."""
    while True:
        yield from dl
def has_int_squareroot(num):
    """True when num is a perfect square (float-based check, as in the original)."""
    root = math.sqrt(num)
    return (root ** 2) == num
def num_to_groups(num, divisor):
    """Split num into full groups of `divisor` plus one final remainder group."""
    groups, remainder = divmod(num, divisor)
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr
def convert_image_to(img_type, image):
    """Convert `image` to mode `img_type` unless it already matches."""
    if image.mode == img_type:
        return image
    return image.convert(img_type)
# normalization functions
#from 0,1 to -1,1
def normalize_to_neg_one_to_one(img):
    """Map values from [0, 1] to [-1, 1]."""
    return 2 * img - 1
# from -1,1 to 0,1
def unnormalize_to_zero_to_one(t):
    """Map values from [-1, 1] back to [0, 1]."""
    return 0.5 * (t + 1)
# from any batch to [0,1]
# f should have shape (batch, -1)
def normalize_to_zero_to_one(f):
    """Min-max normalize each row of f (shape (batch, -1)) to [0, 1], IN PLACE."""
    row_min = f.min(1, keepdim=True)[0]
    f -= row_min
    # max is taken AFTER the shift, matching the original two-pass behavior
    row_max = f.max(1, keepdim=True)[0]
    f /= row_max
    return f
# extract the appropriate t index for a batch of indices
def extract(a, t, x_shape):
    """Gather a[t] per batch element and reshape for broadcasting against x_shape."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    # append singleton dims so the result broadcasts over x's trailing axes
    return gathered.reshape(batch, *((1,) * (len(x_shape) - 1)))
def linear_beta_schedule(timesteps):
    """Linear beta schedule, scaled so any step count matches the 1000-step DDPM defaults."""
    scale = 1000 / timesteps
    return torch.linspace(scale * 0.0001, scale * 0.02, timesteps, dtype=torch.float64)
def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    grid = torch.linspace(0, timesteps, steps, dtype=torch.float64)
    angles = ((grid / timesteps) + s) / (1 + s) * math.pi * 0.5
    # cosine evaluated through numpy, exactly as the original did
    alphas_cumprod = torch.from_numpy(np.cos(angles.numpy()) ** 2)
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
================================================
FILE: diff_utils/model_utils.py
================================================
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from .pointnet.pointnet_classifier import PointNetClassifier
from .pointnet.conv_pointnet import ConvPointnet
from .pointnet.dgcnn import DGCNN
from .helpers import *
class LayerNorm(nn.Module):
    """Layer norm over the last dim with a learned gain only (no bias);
    optional `stable` mode divides by the detached per-row max first."""

    def __init__(self, dim, eps=1e-5, stable=False):
        super().__init__()
        self.eps = eps
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim=-1, keepdim=True).detach()
        mu = torch.mean(x, dim=-1, keepdim=True)
        sigma2 = torch.var(x, dim=-1, unbiased=False, keepdim=True)
        return (x - mu) * (sigma2 + self.eps).rsqrt() * self.g
# mlp
class MLP(nn.Module):
    """SiLU MLP: dim_in -> hidden (x depth) -> dim_out, with optional LayerNorm
    after each hidden activation. Input is cast to float32 in forward."""

    def __init__(
        self,
        dim_in,
        dim_out,
        *,
        expansion_factor = 2.,
        depth = 2,
        norm = False,
    ):
        super().__init__()
        hidden_dim = int(expansion_factor * dim_out)

        def make_norm():
            return nn.LayerNorm(hidden_dim) if norm else nn.Identity()

        blocks = [nn.Sequential(nn.Linear(dim_in, hidden_dim), nn.SiLU(), make_norm())]
        blocks.extend(
            nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.SiLU(), make_norm())
            for _ in range(depth - 1)
        )
        blocks.append(nn.Linear(hidden_dim, dim_out))
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        # cast to float so double/half inputs are accepted
        return self.net(x.float())
# relative positional bias for causal transformer
class RelPosBias(nn.Module):
    """T5-style relative position bias: buckets query/key offsets and maps each
    bucket to a learned per-head scalar added to the attention logits."""
    def __init__(
        self,
        heads = 8,
        num_buckets = 32,
        max_distance = 128,
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        num_buckets = 32,
        max_distance = 128
    ):
        # positions where the key is at/after the query are clamped to bucket 0
        n = -relative_position
        n = torch.max(n, torch.zeros_like(n))

        # first half of the buckets maps small distances exactly
        max_exact = num_buckets // 2
        is_small = n < max_exact

        # remaining buckets cover (max_exact, max_distance] on a log scale,
        # capped at the last bucket index
        val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
        return torch.where(is_small, n, val_if_large)

    def forward(self, i, j, *, device):
        # pairwise offsets between j key positions and i query positions
        q_pos = torch.arange(i, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        # (heads, i, j) bias, ready to add to attention logits
        return rearrange(values, 'i j h -> h i j')
# feedforward
class SwiGLU(nn.Module):
    """ used successfully in https://arxiv.org/abs/2204.0231 """
    def forward(self, x):
        # split the last dim in half: first half is the value, second the gate
        value, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * value
def FeedForward(
    dim,
    out_dim = None,
    mult = 4,
    dropout = 0.,
    post_activation_norm = False
):
    """ post-activation norm https://arxiv.org/abs/2110.09456 """
    out_dim = default(out_dim, dim)
    inner_dim = int(mult * dim)
    middle = LayerNorm(inner_dim) if post_activation_norm else nn.Identity()
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, inner_dim * 2, bias=False),  # doubled width for the SwiGLU gate
        SwiGLU(),
        middle,
        nn.Dropout(dropout),
        nn.Linear(inner_dim, out_dim, bias=False),
    )
class SinusoidalPosEmb(nn.Module):
    """Standard transformer sinusoidal embedding of a 1-D input, output size `dim`."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        # geometric frequency ladder from 1 down to 1/10000
        freq_scale = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(-freq_scale * torch.arange(half_dim, device=x.device))
        args = x[:, None] * freqs[None, :]
        return torch.cat((args.sin(), args.cos()), dim=-1)
def exists(x):
    """True iff *x* is not None."""
    return x is not None
def default(val, d):
    """Return val when present; otherwise the fallback d (called if it is a function)."""
    if exists(val):
        return val
    return d() if isfunction(d) else d
class Attention(nn.Module):
    """Multi-head attention with a single shared key/value projection
    (to_kv emits one dim_head-sized k and v broadcast across all heads),
    a learned null key/value pair, optional causal masking, optional rotary
    embeddings, and PB-relax logit scaling for numerical stability."""
    def __init__(
        self,
        dim,
        kv_dim=None,
        *,
        out_dim = None,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        causal = False,
        rotary_emb = None,
        pb_relax_alpha = 128
    ):
        super().__init__()
        self.pb_relax_alpha = pb_relax_alpha
        # PB-relax: pre-divide logits by alpha, re-multiply after the max
        # subtraction in forward() to avoid fp overflow in the softmax
        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)
        self.heads = heads
        inner_dim = dim_head * heads

        kv_dim = default(kv_dim, dim)

        self.causal = causal
        self.norm = LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

        # learned null key/value, prepended so attention can "attend to nothing"
        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # note: keys/values are dim_head-sized (not inner_dim) — shared across heads
        self.to_kv = nn.Linear(kv_dim, dim_head * 2, bias = False)

        self.rotary_emb = rotary_emb

        out_dim = default(out_dim, dim)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, out_dim, bias = False),
            LayerNorm(out_dim)
        )

    def forward(self, x, context=None, mask = None, attn_bias = None):
        b, n, device = *x.shape[:2], x.device

        context = default(context, x) #self attention if context is None

        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        q = q * self.scale

        # rotary embeddings
        if exists(self.rotary_emb):
            q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))

        # add null key / value for classifier free guidance in prior net
        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # calculate query / key similarities; k has no head axis (shared kv)
        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # relative positional encoding (T5 style)
        if exists(attn_bias):
            sim = sim + attn_bias

        # masking
        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            # pad the mask to keep the prepended null kv always attendable
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, max_neg_value)

        # attention: subtract the row max, then undo the alpha pre-division
        # (PB-relax) before the softmax
        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        sim = sim * self.pb_relax_alpha
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # aggregate values
        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
================================================
FILE: diff_utils/pointnet/__init__.py
================================================
================================================
FILE: diff_utils/pointnet/conv_pointnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch_scatter import scatter_mean, scatter_max
class ConvPointnet(nn.Module):
    ''' PointNet-based encoder network with ResNet blocks for each point.
        Number of input points are fixed.

    Args:
        c_dim (int): dimension of latent code c
        dim (int): input points dimension
        hidden_dim (int): hidden dimension of the network
        scatter_type (str): feature aggregation when doing local pooling
        unet (bool): weather to use U-Net
        unet_kwargs (str): U-Net parameters
        plane_resolution (int): defined resolution for plane feature
        plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
        padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
        n_blocks (int): number of blocks ResNetBlockFC layers
    '''

    def __init__(self, c_dim=512, dim=3, hidden_dim=128, scatter_type='max',
                 unet=True, unet_kwargs={"depth": 4, "merge_mode": "concat", "start_filts": 32},
                 plane_resolution=64, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
        super().__init__()
        self.c_dim = c_dim

        self.fc_pos = nn.Linear(dim, 2*hidden_dim)
        # NOTE(review): ResnetBlockFC has no visible import in this file —
        # confirm it is brought into scope elsewhere before running.
        self.blocks = nn.ModuleList([
            ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
        ])
        self.fc_c = nn.Linear(hidden_dim, c_dim)

        self.actvn = nn.ReLU()
        self.hidden_dim = hidden_dim

        if unet:
            self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
        else:
            self.unet = None

        self.reso_plane = plane_resolution
        self.plane_type = plane_type
        self.padding = padding

        # torch_scatter reduction used when pooling point features per plane cell
        if scatter_type == 'max':
            self.scatter = scatter_max
        elif scatter_type == 'mean':
            self.scatter = scatter_mean

    # takes in "p": point cloud and "query": sdf_xyz
    # sample plane features for unlabeled_query as well
    def forward(self, p, query):
        """Encode point cloud `p` (B, T, 3) into triplane features and return the
        summed per-plane features interpolated at `query` (B, N, 3)."""
        batch_size, T, D = p.size()

        # acquire the index for each point: per-plane normalized coords and cell ids
        coord = {}
        index = {}
        if 'xz' in self.plane_type:
            coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
            index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
        if 'xy' in self.plane_type:
            coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
            index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
        if 'yz' in self.plane_type:
            coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
            index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)

        # per-point features with local pooling between ResNet blocks
        net = self.fc_pos(p)
        net = self.blocks[0](net)
        for block in self.blocks[1:]:
            pooled = self.pool_local(coord, index, net)
            net = torch.cat([net, pooled], dim=2)
            net = block(net)

        c = self.fc_c(net)

        fea = {}
        plane_feat_sum = 0
        if 'xz' in self.plane_type:
            fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
            plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')
        if 'xy' in self.plane_type:
            fea['xy'] = self.generate_plane_features(p, c, plane='xy')
            plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')
        if 'yz' in self.plane_type:
            fea['yz'] = self.generate_plane_features(p, c, plane='yz')
            plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')

        return plane_feat_sum.transpose(2,1)

    def normalize_coordinate(self, p, padding=0.1, plane='xz'):
        ''' Normalize coordinate to [0, 1] for unit cube experiments
        Args:
            p (tensor): point
            padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
            plane (str): plane feature type, ['xz', 'xy', 'yz']
        '''
        # project onto the two axes of the requested plane
        if plane == 'xz':
            xy = p[:, :, [0, 2]]
        elif plane =='xy':
            xy = p[:, :, [0, 1]]
        else:
            xy = p[:, :, [1, 2]]

        xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
        xy_new = xy_new + 0.5 # range (0, 1)

        # f there are outliers out of the range
        if xy_new.max() >= 1:
            xy_new[xy_new >= 1] = 1 - 10e-6
        if xy_new.min() < 0:
            xy_new[xy_new < 0] = 0.0
        return xy_new

    def coordinate2index(self, x, reso):
        ''' Normalize coordinate to [0, 1] for unit cube experiments.
            Corresponds to our 3D model
        Args:
            x (tensor): coordinate
            reso (int): defined resolution
            coord_type (str): coordinate type
        '''
        # flatten 2-D cell coordinates into a single plane-cell index
        x = (x * reso).long()
        index = x[:, :, 0] + reso * x[:, :, 1]
        index = index[:, None, :]
        return index

    # xy is the normalized coordinates of the point cloud of each plane
    # I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input
    def pool_local(self, xy, index, c):
        """Scatter per-point features into plane cells, then gather each cell's
        pooled feature back to its points; summed over all planes."""
        bs, fea_dim = c.size(0), c.size(2)
        keys = xy.keys()

        c_out = 0
        for key in keys:
            # scatter plane features from points
            fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
            if self.scatter == scatter_max:
                fea = fea[0]  # scatter_max returns (values, argmax)
            # gather feature back to points
            fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
            c_out += fea
        return c_out.permute(0, 2, 1)

    def generate_plane_features(self, p, c, plane='xz'):
        """Average per-point features `c` into a (B, c_dim, reso, reso) plane,
        optionally refined by the U-Net."""
        # acquire indices of features in plane
        xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
        index = self.coordinate2index(xy, self.reso_plane)

        # scatter plane features from points
        fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
        c = c.permute(0, 2, 1) # B x 512 x T
        fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
        fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparce matrix (B x 512 x reso x reso)

        # process the plane features with UNet
        if self.unet is not None:
            fea_plane = self.unet(fea_plane)

        return fea_plane

    # sample_plane_feature function copied from /src/conv_onet/models/decoder.py
    # uses values from plane_feature and pixel locations from vgrid to interpolate feature
    def sample_plane_feature(self, query, plane_feature, plane):
        """Bilinearly interpolate the plane feature map at the query locations."""
        xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)
        xy = xy[:, :, None].float()
        vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
        sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)
        return sampled_feat
def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """3x3 Conv2d; the padding=1 default preserves spatial size at stride 1."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=padding, bias=bias, groups=groups)
def upconv2x2(in_channels, out_channels, mode='transpose'):
    """2x spatial upsampling: a transposed conv, or bilinear upsample + 1x1 conv."""
    if mode == 'transpose':
        return nn.ConvTranspose2d(in_channels, out_channels,
                                  kernel_size=2, stride=2)
    # out_channels is always going to be the same as in_channels
    return nn.Sequential(
        nn.Upsample(mode='bilinear', scale_factor=2),
        conv1x1(in_channels, out_channels))
def conv1x1(in_channels, out_channels, groups=1):
    """1x1 Conv2d used for channel mixing without touching spatial dims."""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, groups=groups, stride=1)
class DownConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 MaxPool.
    A ReLU activation follows each convolution.
    Returns (pooled_output, pre_pool_features) for the skip connection.
    """
    def __init__(self, in_channels, out_channels, pooling=True):
        super(DownConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(self.in_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)
        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        skip = x  # kept for the decoder's skip connection
        if self.pooling:
            x = self.pool(x)
        return x, skip
class UpConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 UpConvolution.
    A ReLU activation follows each convolution.
    Merges the upsampled decoder tensor with the matching encoder tensor,
    either by channel concatenation or by addition.
    """
    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', up_mode='transpose'):
        super(UpConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(self.in_channels, self.out_channels,
                                mode=self.up_mode)

        # concat doubles the channel count fed to the first conv
        first_in = 2 * self.out_channels if self.merge_mode == 'concat' else self.out_channels
        self.conv1 = conv3x3(first_in, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

    def forward(self, from_down, from_up):
        """ Forward pass
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: upconv'd tensor from the decoder pathway
        """
        from_up = self.upconv(from_up)
        if self.merge_mode == 'concat':
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x
class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597
    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).
    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by upmode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens with the convolution in
        the tranpose convolution (specified by upmode='transpose')
    """

    def __init__(self, num_classes, in_channels=3, depth=5,
                 start_filts=64, up_mode='transpose',
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            num_classes: int, number of output channels of the final 1x1 conv.
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of MaxPools in the U-Net.
            start_filts: int, number of convolutional filters for the
                first conv.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for nearest neighbour
                upsampling.
            merge_mode: string, how skip connections are merged into the
                decoder: 'concat' or 'add'.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))

        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            # BUG FIX: this message previously interpolated `up_mode`,
            # producing a misleading error for an invalid merge_mode.
            raise ValueError("\"{}\" is not a valid mode for"
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(merge_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts * (2 ** i)
            pooling = i < depth - 1  # no pooling on the bottleneck stage
            self.down_convs.append(DownConv(ins, outs, pooling=pooling))

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth - 1):
            ins = outs
            outs = ins // 2
            self.up_convs.append(UpConv(ins, outs, up_mode=up_mode,
                                        merge_mode=merge_mode))

        # add the list of modules to current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

        self.conv_final = conv1x1(outs, self.num_classes)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        # Xavier-normal weights and zero biases for every Conv2d.
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            init.constant_(m.bias, 0)

    def reset_params(self):
        """Re-initialize all conv weights in the network."""
        for i, m in enumerate(self.modules()):
            self.weight_init(m)

    def forward(self, x):
        encoder_outs = []
        # encoder pathway, save outputs for merging
        for module in self.down_convs:
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        for i, module in enumerate(self.up_convs):
            # -(i+2): skip the bottleneck output and walk back up the encoder.
            before_pool = encoder_outs[-(i + 2)]
            x = module(before_pool, x)

        # No softmax is used. This means you need to use
        # nn.CrossEntropyLoss in your training script,
        # as this module includes a softmax already.
        x = self.conv_final(x)
        return x
# Resnet Blocks
class ResnetBlockFC(nn.Module):
    ''' Fully connected ResNet Block class.

    Args:
        size_in (int): input dimension
        size_out (int): output dimension
        size_h (int): hidden dimension
    '''

    def __init__(self, size_in, size_out=None, size_h=None):
        super().__init__()
        # Fall back to sensible defaults for the output / hidden widths.
        size_out = size_in if size_out is None else size_out
        size_h = min(size_in, size_out) if size_h is None else size_h
        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out

        # Residual branch: ReLU -> fc_0 -> ReLU -> fc_1.
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)
        self.actvn = nn.ReLU()

        # A linear projection is only needed when the widths differ.
        if size_in == size_out:
            self.shortcut = None
        else:
            self.shortcut = nn.Linear(size_in, size_out, bias=False)

        # Zero the last weight so the block starts as (almost) an identity:
        # the residual branch initially contributes only fc_1's bias.
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x):
        hidden = self.fc_0(self.actvn(x))
        residual = self.fc_1(self.actvn(hidden))
        skip = x if self.shortcut is None else self.shortcut(x)
        return skip + residual
================================================
FILE: diff_utils/pointnet/dgcnn.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
def knn(x, k):
    """Indices of the k nearest neighbours of every point.

    x: (batch, dims, num_points) tensor.
    Returns an index tensor of shape (batch, num_points, k).
    """
    # -||xi - xj||^2 = -||xi||^2 + 2*xi.xj - ||xj||^2; taking topk of the
    # negated squared distances selects the nearest points.
    dots = torch.matmul(x.transpose(2, 1).contiguous(), x)
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)
    neg_dist = 2 * dots - sq_norm - sq_norm.transpose(2, 1).contiguous()
    return neg_dist.topk(k=k, dim=-1)[1]
def get_graph_feature(x, k=20):
    """Build edge features [x_j, x_i] for the k-NN graph of a point cloud.

    x: (batch, num_dims, num_points) tensor.
    Returns a (batch, 2*num_dims, num_points, k) tensor where each entry
    concatenates a neighbour's features with the centre point's own.
    """
    idx = knn(x, k=k)  # (batch, num_points, k), on x's device
    batch_size, num_points, _ = idx.size()

    # Offset per-batch indices so they address the flattened point list.
    # BUG FIX: use x's own device — the original forced CUDA whenever it was
    # available, which crashed for CPU inputs on a machine with a GPU.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    _, num_dims, _ = x.size()
    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    # -> (batch_size*num_points, num_dims) gathered by flattened index
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    # Concatenate neighbour and centre features, channels-first.
    return torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)
class DGCNN(nn.Module):
    """DGCNN point-cloud network (edge convolutions + global pooling).

    forward() returns per-cloud class logits; get_global_feature() returns
    the pooled (max || avg) embedding of size emb_dims*2 that the classifier
    head consumes. forward() previously duplicated the entire feature
    pipeline of get_global_feature(); it now reuses it.
    """

    def __init__(
        self,
        emb_dims=512,
        use_bn=False,
        output_channels=100  # number of categories to predict
    ):
        super().__init__()
        # Edge-conv stack: 6 -> 64 -> 64 -> 128 -> 256, then a 512 -> emb_dims
        # fusion conv over the concatenated per-stage max responses.
        if use_bn:
            print("using batch norm")
            self.bn1 = nn.BatchNorm2d(64)
            self.bn2 = nn.BatchNorm2d(64)
            self.bn3 = nn.BatchNorm2d(128)
            self.bn4 = nn.BatchNorm2d(256)
            self.bn5 = nn.BatchNorm2d(emb_dims)
            self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), self.bn1, nn.LeakyReLU(negative_slope=0.2))
            self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), self.bn2, nn.LeakyReLU(negative_slope=0.2))
            self.conv3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False), self.bn3, nn.LeakyReLU(negative_slope=0.2))
            self.conv4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, bias=False), self.bn4, nn.LeakyReLU(negative_slope=0.2))
            self.conv5 = nn.Sequential(nn.Conv2d(512, emb_dims, kernel_size=1, bias=False), self.bn5, nn.LeakyReLU(negative_slope=0.2))
        else:
            self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))
            self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))
            self.conv3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))
            self.conv4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))
            self.conv5 = nn.Sequential(nn.Conv2d(512, emb_dims, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))

        # Classifier head: emb_dims*2 -> 512 -> 256 -> output_channels.
        self.linear1 = nn.Linear(emb_dims*2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=0.5)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=0.5)
        self.linear3 = nn.Linear(256, output_channels)

    def forward(self, x):
        """x: (batch, 3, num_points) -> (batch, output_channels) logits."""
        x = self.get_global_feature(x)  # (batch_size, emb_dims*2)
        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2)  # -> (batch_size, 512)
        x = self.dp1(x)
        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2)  # -> (batch_size, 256)
        x = self.dp2(x)
        x = self.linear3(x)  # -> (batch_size, output_channels)
        return x

    def get_global_feature(self, x):
        """x: (batch, 3, num_points) -> (batch, emb_dims*2) pooled embedding."""
        batch_size, num_dims, num_points = x.size()
        x = get_graph_feature(x)  # (batch, 6, num_points, k)
        x1 = self.conv1(x)  # (batch, 64, num_points, k)
        x1_max = x1.max(dim=-1, keepdim=True)[0]
        x2 = self.conv2(x1)  # (batch, 64, num_points, k)
        x2_max = x2.max(dim=-1, keepdim=True)[0]
        x3 = self.conv3(x2)  # (batch, 128, num_points, k)
        x3_max = x3.max(dim=-1, keepdim=True)[0]
        x4 = self.conv4(x3)  # (batch, 256, num_points, k)
        x4_max = x4.max(dim=-1, keepdim=True)[0]
        # Concatenate per-stage max responses: (batch, 512, num_points, 1).
        x_max = torch.cat((x1_max, x2_max, x3_max, x4_max), dim=1)
        point_feat = torch.squeeze(self.conv5(x_max), dim=3)  # (batch, emb_dims, num_points)
        # Global max- and avg-pools over the points, concatenated.
        gmax = F.adaptive_max_pool1d(point_feat, 1).view(batch_size, -1)
        gavg = F.adaptive_avg_pool1d(point_feat, 1).view(batch_size, -1)
        return torch.cat((gmax, gavg), 1)  # (batch_size, emb_dims*2)
================================================
FILE: diff_utils/pointnet/pointnet_base.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as grad
from .transformer import Transformer
##-----------------------------------------------------------------------------
# Class for PointNetBase. Subclasses PyTorch's own "nn" module
#
# Computes the local embeddings and global features for an input set of points
##
class PointNetBase(nn.Module):
    """PointNet backbone: local 64-dim embeddings plus a 1024-dim global feature."""

    def __init__(self, num_points=2000, K=3):
        super(PointNetBase, self).__init__()

        # Spatial transformers: one for the raw K-dim input (K is 3 for XYZ
        # but can be larger if normals/colors are appended) and one for the
        # 64-dim point embeddings.
        self.input_transformer = Transformer(num_points, K)
        self.embedding_transformer = Transformer(num_points, 64)

        # Shared-weight MLPs expressed as 1x1 Conv1d stacks: each filter maps
        # one point's feature vector independently of the others.
        self.mlp1 = nn.Sequential(
            nn.Conv1d(K, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU())

        self.mlp2 = nn.Sequential(
            nn.Conv1d(64, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU())

    def forward(self, x):
        """x: (B, K, N) -> (global_feature (B, 1024),
        local_embedding (B, 64, N), T2 (B, 64, 64))."""
        num_pts = x.shape[2]

        # Canonicalize the input with the learned KxK transform.
        x = torch.bmm(self.input_transformer(x), x)

        # First embedding stage, then align the embeddings with the learned
        # 64x64 transform — this gives the "local embedding" of the paper.
        embedded = self.mlp1(x)
        T2 = self.embedding_transformer(embedded)
        local_embedding = torch.bmm(T2, embedded)

        # Deeper per-point features (B x 1024 x N) max-pooled over the points
        # into a single global descriptor per cloud.
        global_feature = F.max_pool1d(self.mlp2(local_embedding), num_pts).squeeze(2)
        return global_feature, local_embedding, T2
================================================
FILE: diff_utils/pointnet/pointnet_classifier.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as grad
from .pointnet_base import PointNetBase
##-----------------------------------------------------------------------------
# Class for PointNetClassifier. Subclasses PyTorch's own "nn" module
#
# Computes the local embeddings and global features for an input set of points
##
class PointNetClassifier(nn.Module):
    """PointNet feature extractor built on PointNetBase.

    Despite the name, forward() currently returns per-point features — the
    global descriptor concatenated onto every local embedding
    (B x 1088 x N) — while the 40-way ShapeNet classifier head is kept but
    not called.
    """

    def __init__(self, num_points=2000, K=3):
        super(PointNetClassifier, self).__init__()

        # Local and global feature extractor for PointNet.
        self.base = PointNetBase(num_points, K)

        # ShapeNet classification head (unused by forward() at present).
        self.classifier = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.7),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.7),
            nn.Linear(256, 40))

    def forward(self, x):
        """x: (B, K, N) -> (B, 1088, N) per-point features (global || local)."""
        global_feature, local_embedding, _ = self.base(x)

        # Broadcast the B x 1024 global descriptor to every point and
        # concatenate it with the B x 64 x N local embeddings
        # (segmentation-network style); downstream models own the MLP that
        # consumes the 1088-dim result.
        n_points = local_embedding.shape[-1]
        tiled_global = global_feature.unsqueeze(-1).repeat(1, 1, n_points)
        return torch.cat((tiled_global, local_embedding), dim=1)
================================================
FILE: diff_utils/pointnet/transformer.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as grad
##-----------------------------------------------------------------------------
# Class for Transformer. Subclasses PyTorch's own "nn" module
#
# Computes a KxK affine transform from the input data to transform inputs
# to a "canonical view"
##
class Transformer(nn.Module):
    """Regresses a KxK alignment transform for a B x K x N point batch.

    The transform is predicted as a residual on top of the identity matrix,
    so an untrained network starts near the identity mapping.
    """

    def __init__(self, num_points=2000, K=3):
        super(Transformer, self).__init__()

        # Number of dimensions of the data
        self.K = K
        # Size of input (points per cloud; fixed, used by the pooling below)
        self.N = num_points

        # Identity offset for the regressed transform.
        # BUG FIX: the original created this with `.double().cuda()`, which
        # crashed on CPU-only machines and mixed float64 into a float32
        # forward pass. A non-persistent buffer follows the module's device
        # and dtype via .to()/.cuda()/.double() and keeps the state_dict
        # layout unchanged, so existing checkpoints still load.
        self.register_buffer('identity', torch.eye(K).view(-1), persistent=False)

        # Three shared-weight embedding blocks (1x1 convolutions).
        self.block1 = nn.Sequential(
            nn.Conv1d(K, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU())
        self.block2 = nn.Sequential(
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU())
        self.block3 = nn.Sequential(
            nn.Conv1d(128, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU())

        # MLP that maps the pooled feature to the K*K transform entries.
        self.mlp = nn.Sequential(
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, K * K))

    def forward(self, x):
        """x: (B, K, N) -> (B, K, K) affine transform matrices."""
        # Per-point embeddings; output is B x 1024 x N.
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        # Pool over the points: B x 1024 x 1 -> B x 1024.
        x = F.max_pool1d(x, self.N).squeeze(2)

        # Regress the K*K entries and add the identity offset
        # (out-of-place add so dtype/broadcasting behave predictably).
        x = self.mlp(x)
        x = x + self.identity

        # Reshape into B x K x K matrices.
        return x.view(-1, self.K, self.K)
================================================
FILE: diff_utils/sdf_utils.py
================================================
import math
import torch
import json
import torch.nn.functional as F
from torch import nn, einsum
import os
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from sdf_model.model import *
# Compute on the GPU when available; all module-level tensors below follow it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# load base network
# NOTE(review): these paths are relative to the working directory, so this
# module only imports cleanly when run from the repository root.
specs_path = "sdf_model/config/siren/specs.json"
specs = json.load(open(specs_path))
model = MetaSDF(specs).to(device)
checkpoint = torch.load("sdf_model/config/siren/last.ckpt", map_location=device)
model.load_state_dict(checkpoint['state_dict'])
# Freeze the pretrained SDF network: the functions below use it only as a
# fixed evaluator, never as a trainable module.
for p in model.parameters():
    p.requires_grad=False
def pred_sdf_loss(x0, idx):
    """Sample SDF supervision for shape `idx` and score the modulations x0.

    x0: batch of modulation vectors; only its batch size is used for sampling.
    idx: index into `gt_files` selecting the ground-truth SDF samples.

    NOTE(review): `gt_files` is not defined anywhere in this module — as
    written this raises NameError unless a caller injects it as a global.
    """
    # Draw 16000 query points (+ ground-truth SDF values) per batch element.
    pc, xyz, gt = sdf_sampling(gt_files[idx], 16000, batch=x0.shape[0])
    xyz = xyz.to(device)
    gt = gt.to(device)
    sdf_loss = functional_sdf_model(x0, xyz, gt)
    return sdf_loss
def functional_sdf_model(modulation, xyz, gt):
    '''Evaluate the frozen module-level SDF model and return its L1 loss.

    modulation: modulation vector with shape 512
    xyz: query points, input to the model, dim= 16000x3
    gt: ground truth; calculate l1 loss with prediction, dim= 16000x1
    return: sdf_loss
    '''
    predicted = model(modulation, xyz)
    return F.l1_loss(predicted.squeeze(), gt.squeeze())
def sdf_sampling(f, subsample, pc_size=1024, batch=1):
    """Sample SDF query points and a surface point cloud from SDF data `f`.

    f: (M, 4) tensor of rows [x, y, z, sdf].
    subsample: number of query points per batch element — half with positive
        SDF, half with negative (should be even for the counts to add up).
    pc_size: number of zero-SDF (surface) points per batch element.
    batch: number of independent samplings to draw.

    Returns (pcs, xyz, gt) with shapes (batch, pc_size, 3),
    (batch, subsample, 3) and (batch, subsample); in each gt row the first
    half is positive and the second half negative.
    """
    half = int(subsample / 2)

    # The sign split and the surface subset depend only on `f`, not on the
    # batch index — compute them once instead of once per iteration.
    neg_tensor = f[f[:, -1] < 0]
    pos_tensor = f[f[:, -1] > 0]
    surface = f[f[:, -1] == 0][:, :3]

    pcs = torch.empty(batch, pc_size, 3)
    xyz = torch.empty(batch, subsample, 3)
    gt = torch.empty(batch, subsample)

    for i in range(batch):
        # Sample `half` points per sign: with replacement when a side has
        # fewer than `half` candidates, without replacement otherwise.
        if pos_tensor.shape[0] < half:
            pos_idx = torch.randint(pos_tensor.shape[0], (half,))
        else:
            pos_idx = torch.randperm(pos_tensor.shape[0])[:half]
        if neg_tensor.shape[0] < half:
            neg_idx = torch.randint(neg_tensor.shape[0], (half,))
        else:
            neg_idx = torch.randperm(neg_tensor.shape[0])[:half]

        samples = torch.cat([pos_tensor[pos_idx], neg_tensor[neg_idx]], 0)

        # Random subset of the surface points for the conditioning cloud.
        pc_idx = torch.randperm(surface.shape[0])[:pc_size]

        pcs[i] = surface[pc_idx].float()
        xyz[i] = samples[:, :3].float()
        gt[i] = samples[:, 3].float()

    return pcs, xyz, gt
def apply_to_sdf(f, x):
    """Apply callable `f` to every element of `x` in place and return `x`."""
    for position, element in enumerate(x):
        x[position] = f(element)
    return x
================================================
FILE: environment.yml
================================================
name: diffusionsdf
channels:
- pytorch
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- brotlipy=0.7.0=py39h27cfd23_1003
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2022.4.26=h06a4308_0
- certifi=2022.5.18.1=py39h06a4308_0
- cffi=1.15.0=py39hd667e15_1
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- cryptography=37.0.1=py39h9ce1e76_0
- cudatoolkit=11.3.1=h2bc3f7f_2
- ffmpeg=4.3=hf484d3e_0
- freetype=2.11.0=h70c0345_0
- giflib=5.2.1=h7b6447c_0
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
- idna=3.3=pyhd3eb1b0_0
- intel-openmp=2021.4.0=h06a4308_3561
- jpeg=9e=h7f8727e_0
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.3=he6710b0_2
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libiconv=1.16=h7f8727e_2
- libidn2=2.3.2=h7f8727e_0
- libpng=1.6.37=hbc83047_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.2.0=h2818925_1
- libunistring=0.9.10=h27cfd23_0
- libuv=1.40.0=h7b6447c_0
- libwebp=1.2.2=h55f646e_0
- libwebp-base=1.2.2=h7f8727e_0
- lz4-c=1.9.3=h295c915_1
- mkl=2021.4.0=h06a4308_640
- mkl-service=2.4.0=py39h7f8727e_0
- mkl_fft=1.3.1=py39hd3c417c_0
- mkl_random=1.2.2=py39h51133e4_0
- ncurses=6.3=h7f8727e_2
- nettle=3.7.3=hbbd107a_1
- numpy=1.22.3=py39he7a7128_0
- numpy-base=1.22.3=py39hf524024_0
- openh264=2.1.1=h4ff587b_0
- openssl=1.1.1o=h7f8727e_0
- pillow=9.0.1=py39h22f2fdc_0
- pip=21.2.4=py39h06a4308_0
- pycparser=2.21=pyhd3eb1b0_0
- pyopenssl=22.0.0=pyhd3eb1b0_0
- pysocks=1.7.1=py39h06a4308_0
- python=3.9.12=h12debd9_1
- pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0
- pytorch-mutex=1.0=cuda
- readline=8.1.2=h7f8727e_1
- requests=2.27.1=pyhd3eb1b0_0
- setuptools=61.2.0=py39h06a4308_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.38.3=hc218d9a_0
- tk=8.6.12=h1ccaba5_0
- torchaudio=0.11.0=py39_cu113
- torchvision=0.12.0=py39_cu113
- typing_extensions=4.1.1=pyh06a4308_0
- tzdata=2022a=hda174b7_0
- urllib3=1.26.9=py39h06a4308_0
- wheel=0.37.1=pyhd3eb1b0_0
- xz=5.2.5=h7f8727e_1
- zlib=1.2.12=h7f8727e_2
- zstd=1.5.2=ha4553b6_0
- pip:
- absl-py==1.1.0
- addict==2.4.0
- aiohttp==3.8.1
- aiosignal==1.2.0
- asttokens==2.2.1
- async-timeout==4.0.2
- attrs==21.4.0
- backcall==0.2.0
- cachetools==5.2.0
- click==8.1.3
- comm==0.1.2
- configargparse==1.5.3
- contourpy==1.0.7
- cycler==0.11.0
- dash==2.8.1
- dash-core-components==2.0.0
- dash-html-components==2.0.0
- dash-table==5.0.0
- debugpy==1.6.6
- decorator==5.1.1
- einops==0.6.0
- einops-exts==0.0.4
- executing==1.2.0
- fastjsonschema==2.16.2
- flask==2.2.2
- fonttools==4.38.0
- frozenlist==1.3.0
- fsspec==2022.5.0
- google-auth==2.6.6
- google-auth-oauthlib==0.4.6
- grpcio==1.46.3
- imageio==2.19.3
- importlib-metadata==4.11.4
- ipykernel==6.21.1
- ipython==8.10.0
- ipywidgets==8.0.4
- itsdangerous==2.1.2
- jedi==0.18.2
- jinja2==3.1.2
- joblib==1.2.0
- jsonschema==4.17.3
- jupyter-client==8.0.2
- jupyter-core==5.2.0
- jupyterlab-widgets==3.0.5
- kiwisolver==1.4.4
- markdown==3.3.7
- markupsafe==2.1.2
- matplotlib==3.6.3
- matplotlib-inline==0.1.6
- multidict==6.0.2
- nbformat==5.5.0
- nest-asyncio==1.5.6
- networkx==2.8.2
- oauthlib==3.2.0
- open3d==0.16.0
- packaging==21.3
- pandas==1.4.2
- parso==0.8.3
- pexpect==4.8.0
- pickleshare==0.7.5
- platformdirs==3.0.0
- plotly==5.13.0
- plyfile==0.7.4
- prompt-toolkit==3.0.36
- protobuf==3.20.1
- psutil==5.9.4
- ptyprocess==0.7.0
- pure-eval==0.2.2
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pydeprecate==0.3.2
- pygments==2.14.0
- pyparsing==3.0.9
- pyquaternion==0.9.9
- pyrsistent==0.19.3
- python-dateutil==2.8.2
- pytorch-lightning==1.6.4
- pytz==2022.1
- pywavelets==1.3.0
- pyyaml==6.0
- pyzmq==25.0.0
- requests-oauthlib==1.3.1
- rotary-embedding-torch==0.2.1
- rsa==4.8
- scikit-image==0.19.2
- scikit-learn==1.2.1
- scipy==1.8.1
- stack-data==0.6.2
- tenacity==8.2.1
- tensorboard==2.9.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
- threadpoolctl==3.1.0
- tifffile==2022.5.4
- torch-scatter==2.0.9
- torchmetrics==0.9.0
- tornado==6.2
- tqdm==4.64.0
- traitlets==5.9.0
- trimesh==3.12.5
- wcwidth==0.2.6
- werkzeug==2.2.2
- widgetsnbextension==4.0.5
- yarl==1.7.2
- zipp==3.8.0
prefix: /home/gchou/.conda/envs/diffusion
================================================
FILE: metrics/StructuralLosses/__init__.py
================================================
#import torch
#from MakePytorchBackend import AddGPU, Foo, ApproxMatch
#from Add import add_gpu, approx_match
================================================
FILE: metrics/StructuralLosses/match_cost.py
================================================
import torch
from torch.autograd import Function
from metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad
# Inherit from Function
class MatchCostFunction(Function):
    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, seta, setb):
        #print("Match Cost Forward")
        ctx.save_for_backward(seta, setb)
        '''
        input:
        set1 : batch_size * #dataset_points * 3
        set2 : batch_size * #query_points * 3
        returns:
        match : batch_size * #query_points * #dataset_points
        '''
        # ApproxMatch (CUDA extension) computes a soft matching between the
        # two point sets; `temp` is scratch output from the kernel. The
        # matching is an output, not an input, so it is stashed directly on
        # ctx rather than via save_for_backward.
        match, temp = ApproxMatch(seta, setb)
        ctx.match = match
        # MatchCost reduces the matching to a per-batch transport cost.
        cost = MatchCost(seta, setb, match)
        return cost

    # Reference TensorFlow gradient that the backward pass below mirrors:
    """
    grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match)
    return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None]
    """

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        #print("Match Cost Backward")
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        seta, setb = ctx.saved_tensors
        #grad_input = grad_weight = grad_bias = None
        grada, gradb = MatchCostGrad(seta, setb, ctx.match)
        # grad_output is per-batch; expand so it broadcasts over points/dims.
        grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2)
        return grada*grad_output_expand, gradb*grad_output_expand

# Autograd-friendly entry point: match_cost(seta, setb) -> per-batch costs.
match_cost = MatchCostFunction.apply
================================================
FILE: metrics/StructuralLosses/nn_distance.py
================================================
import torch
from torch.autograd import Function
# from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
from metrics.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
# Inherit from Function
class NNDistanceFunction(Function):
    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, seta, setb):
        #print("Match Cost Forward")
        ctx.save_for_backward(seta, setb)
        '''
        input:
        set1 : batch_size * #dataset_points * 3
        set2 : batch_size * #query_points * 3
        returns:
        dist1, idx1, dist2, idx2
        '''
        # NNDistance (CUDA extension) returns, for each point of one set, the
        # distance to (and index of) its nearest neighbour in the other set.
        dist1, idx1, dist2, idx2 = NNDistance(seta, setb)
        # The index tensors are outputs, not inputs, so they are stashed
        # directly on ctx rather than via save_for_backward.
        ctx.idx1 = idx1
        ctx.idx2 = idx2
        return dist1, dist2

    # forward returned two tensors, so backward receives two gradients,
    # one per output distance tensor.
    @staticmethod
    def backward(ctx, grad_dist1, grad_dist2):
        #print("Match Cost Backward")
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        seta, setb = ctx.saved_tensors
        idx1 = ctx.idx1
        idx2 = ctx.idx2
        # Route the incoming distance gradients back to the point coordinates
        # through the fixed nearest-neighbour assignment.
        grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2)
        return grada, gradb

# Autograd-friendly entry point: nn_distance(seta, setb) -> (dist1, dist2).
nn_distance = NNDistanceFunction.apply
================================================
FILE: metrics/__init__.py
================================================
================================================
FILE: metrics/evaluation_metrics.py
================================================
import torch
import numpy as np
import warnings
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from numpy.linalg import norm
from scipy.optimize import linear_sum_assignment
from tqdm import tqdm
# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
def distChamfer(a, b):
    """Squared-distance Chamfer terms between two equal-length point clouds.

    a, b: (bs, num_points, 3) tensors (both clouds must have the same
    num_points, since one diagonal index is reused for both Gram matrices).
    Returns (per-b nearest squared distance to a, per-a nearest squared
    distance to b), each of shape (bs, num_points).
    """
    bs, num_points, _ = a.size()
    # ||a_i - b_j||^2 = ||a_i||^2 + ||b_j||^2 - 2 a_i.b_j, assembled from
    # batched Gram matrices.
    gram_aa = torch.bmm(a, a.transpose(2, 1))
    gram_bb = torch.bmm(b, b.transpose(2, 1))
    gram_ab = torch.bmm(a, b.transpose(2, 1))
    diag = torch.arange(0, num_points).to(a).long()
    norm_a = gram_aa[:, diag, diag].unsqueeze(1).expand_as(gram_aa)
    norm_b = gram_bb[:, diag, diag].unsqueeze(1).expand_as(gram_bb)
    pair_sq = norm_a.transpose(2, 1) + norm_b - 2 * gram_ab
    return pair_sq.min(1)[0], pair_sq.min(2)[0]
# Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/
try:
    # BUG FIX: the import is what can fail (the compiled extension may be
    # missing), so it must live inside the try block — previously it sat
    # above the try and an import failure crashed the whole module.
    from metrics.StructuralLosses.nn_distance import nn_distance

    def distChamferCUDA(x, y):
        """CUDA-accelerated Chamfer distance via the structural-losses extension."""
        return nn_distance(x, y)
except ImportError:
    print("distChamferCUDA not available; fall back to slower version.")
    # Leave a sentinel so callers (e.g. _pairwise_EMD_CD_ tests
    # `distChamferCUDA is not None`) can detect the missing extension.
    distChamferCUDA = None
def emd_approx(x, y):
    """Exact per-batch EMD via Hungarian assignment (CPU, slow).

    x: (bs, npts, dim), y: (bs, mpts, dim) with npts == mpts.
    Returns a (bs,) tensor of mean matched point distances, on x's
    device/dtype.
    """
    bs, npts, mpts = x.size(0), x.size(1), y.size(1)
    assert npts == mpts, "EMD only works if two point clouds are equal size"
    dim = x.shape[-1]
    # Pairwise distance matrices, one (npts, mpts) block per batch element.
    diffs = x.reshape(bs, npts, 1, dim) - y.reshape(bs, 1, mpts, dim)
    dist_np = diffs.norm(dim=-1, keepdim=False).cpu().detach().numpy()
    per_batch = []
    for b in range(bs):
        # Optimal one-to-one matching minimizing total distance.
        rows, cols = linear_sum_assignment(dist_np[b])
        per_batch.append(dist_np[b][rows, cols].mean())
    return torch.from_numpy(np.stack(per_batch).reshape(-1)).to(x)
try:
    from metrics.StructuralLosses.match_cost import match_cost

    def emd_approx_cuda(sample, ref):
        """CUDA-accelerated approximate EMD, normalized by point count."""
        B, N, N_ref = sample.size(0), sample.size(1), ref.size(1)
        assert N == N_ref, "Not sure what would EMD do in this case"
        emd = match_cost(sample, ref)  # (B,)
        return emd / float(N)  # (B,)
except ImportError:
    print("emd_approx_cuda not available. Fall back to slower version.")

    # BUG FIX: previously the except branch only printed, leaving
    # emd_approx_cuda undefined, so EMD_CD(accelerated_emd=True) raised
    # NameError. Define the slow Hungarian EMD as the fallback instead.
    def emd_approx_cuda(sample, ref):
        return emd_approx(sample, ref)
def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True,
           accelerated_emd=False):
    """Paired Chamfer and EMD between sample_pcs[i] and ref_pcs[i].

    Both inputs are (N, num_points, 3) with the same N. Returns a dict with
    'MMD-CD' and 'MMD-EMD', each a scalar when `reduced` else an (N,) tensor.
    """
    n_sample = sample_pcs.shape[0]
    n_ref = ref_pcs.shape[0]
    assert n_sample == n_ref, "REF:%d SMP:%d" % (n_ref, n_sample)

    cd_parts = []
    emd_parts = []
    for start in range(0, n_sample, batch_size):
        end = min(n_sample, start + batch_size)
        smp = sample_pcs[start:end]
        ref = ref_pcs[start:end]

        # Chamfer: mean nearest-neighbour distance in both directions.
        chamfer = distChamferCUDA if accelerated_cd else distChamfer
        dl, dr = chamfer(smp, ref)
        cd_parts.append(dl.mean(dim=1) + dr.mean(dim=1))

        emd_fn = emd_approx_cuda if accelerated_emd else emd_approx
        emd_parts.append(emd_fn(smp, ref))

    cd = torch.cat(cd_parts)
    emd = torch.cat(emd_parts)
    if reduced:
        cd = cd.mean()
        emd = emd.mean()

    return {
        'MMD-CD': cd,
        'MMD-EMD': emd,
    }
def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size=None, accelerated_cd=True,
                      accelerated_emd=True):
    """All-pairs Chamfer distances between samples and references.

    Returns (all_cd, 0): an (N_sample, N_ref) matrix of Chamfer distances.
    The EMD half of the computation is currently disabled, hence the 0
    placeholder in the second slot.
    """
    if batch_size is None:
        batch_size = sample_pcs.shape[0]
    n_sample = sample_pcs.shape[0]
    n_ref = ref_pcs.shape[0]

    rows = []
    with tqdm(range(n_sample)) as progress:
        for s_idx in progress:
            progress.set_description("Files evaluated: {}/{}".format(s_idx, n_sample))
            one_sample = sample_pcs[s_idx]

            row_parts = []
            for r_start in range(0, n_ref, batch_size):
                r_end = min(n_ref, r_start + batch_size)
                ref_chunk = ref_pcs[r_start:r_end]
                n_chunk = ref_chunk.size(0)

                # Tile the single sample cloud against the reference chunk.
                tiled = one_sample.view(1, -1, 3).expand(n_chunk, -1, -1)
                tiled = tiled.contiguous()

                if accelerated_cd and distChamferCUDA is not None:
                    dl, dr = distChamferCUDA(tiled, ref_chunk)
                else:
                    dl, dr = distChamfer(tiled, ref_chunk)
                row_parts.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))

            rows.append(torch.cat(row_parts, dim=1))

    all_cd = torch.cat(rows, dim=0)  # (N_sample, N_ref)
    return all_cd, 0
def emd_tmd_from_pcs(gen_pcs):
    """Accumulate approximate EMD over all unordered pairs of generated clouds.

    NOTE(review): the final divisor is ``len(gen_pcs) - 1`` rather than the
    number of pairs ``n*(n-1)/2`` — this matches the original code, but confirm
    it is the intended TMD normalization. Also assumes ``len(gen_pcs) >= 2``
    (division by zero otherwise).
    """
    total = 0
    count = len(gen_pcs)
    for first in range(count):
        for second in range(first + 1, count, 1):
            chamfer_dist = emd_approx_cuda(gen_pcs[first], gen_pcs[second])
            print("emd dist: ", chamfer_dist)
            total += chamfer_dist
    return total * 2 / (count - 1)
# Adapted from https://github.com/xuqiantong/GAN-Metrics/blob/master/framework/metric.py
def knn(Mxx, Mxy, Myy, k, sqrt=False):
    """1-NN two-sample classification statistics from precomputed distances.

    Args:
        Mxx: (n0, n0) distances within the first set.
        Mxy: (n0, n1) distances between the two sets.
        Myy: (n1, n1) distances within the second set.
        k: number of neighbours to vote with.
        sqrt: take sqrt of |distances| before voting.

    Returns:
        Dict of tensors: tp/fp/fn/tn counts plus precision, recall,
        acc_t, acc_f and overall accuracy of the k-NN classifier.
    """
    n0, n1 = Mxx.size(0), Myy.size(0)
    # Labels: 1 for the first set, 0 for the second, in block order.
    label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx)
    upper = torch.cat((Mxx, Mxy), 1)
    lower = torch.cat((Mxy.transpose(0, 1), Myy), 1)
    M = torch.cat((upper, lower), 0)
    if sqrt:
        M = M.abs().sqrt()
    INFINITY = float('inf')
    # Mask the diagonal so a point never counts itself as a neighbour.
    masked = M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))
    val, idx = masked.topk(k, 0, False)
    # Sum neighbour labels per column (same as accumulating index_select per row).
    count = label[idx].sum(dim=0)
    pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float()
    s = {
        'tp': (pred * label).sum(),
        'fp': (pred * (1 - label)).sum(),
        'fn': ((1 - pred) * label).sum(),
        'tn': ((1 - pred) * (1 - label)).sum(),
    }
    s.update({
        'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10),
        'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
        'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
        'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10),
        'acc': torch.eq(label, pred).float().mean(),
    })
    return s
def lgan_mmd_cov(all_dist):
# all dist shape = [number of generated pcs, number of ref pcs]
N_sample, N_ref = all_dist.size(0), all_dist.size(1)
min_val_fromsmp, min_idx = torch.min(all_dist, dim=1)
min_val, _ = torch.min(all_dist, dim=0)
mmd = min_val.mean()
#mmd_smp = min_val_fromsmp.mean()
cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
cov = torch.tensor(cov).to(all_dist)
return {
'lgan_mmd': mmd,
'lgan_cov': cov,
#'lgan_mmd_smp': mmd_smp,
}
def compute_mmd(sample_pcs, ref_pcs, batch_size=None, accelerated_cd=True):
    """MMD/COV between generated and reference clouds (Chamfer only).

    Returns:
        (res_cd, 0): the Chamfer MMD/COV dict and an integer 0 placeholder —
        the EMD branch is disabled in _pairwise_EMD_CD_.
    """
    if batch_size is None:
        batch_size = sample_pcs.shape[0]
    # Rows index reference clouds, columns index samples; transpose for lgan_mmd_cov.
    dist_rs, _ = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)
    res_cd = lgan_mmd_cov(dist_rs.t())
    return res_cd, 0
def compute_cd(sample_pcs, ref_pcs):
    """Per-pair Chamfer distance between aligned batches of point clouds.

    Requires the CUDA Chamfer backend; returns a 1-D tensor of length batch.
    """
    dl, dr = distChamferCUDA(sample_pcs, ref_pcs)
    left = dl.mean(dim=1)
    right = dr.mean(dim=1)
    print("dl, dr shapes: ", dl.shape, dr.shape, left.shape)
    return left + right
def compute_all_metrics(sample_pcs, ref_pcs, batch_size=None, accelerated_cd=True):
    """Compute MMD/COV and 1-NN accuracy for CD and EMD between two cloud sets.

    Returns (res_cd, res_emd, one_nn_cd_res, one_nn_emd_res).

    NOTE(review): _pairwise_EMD_CD_ currently returns the integer 0 in place of
    the EMD matrix, so every `M_*_emd.t()` / knn(...) call on the EMD path below
    would raise at runtime — confirm before relying on the EMD outputs.
    """
    results = {}
    batch_size = sample_pcs.shape[0] if batch_size is None else batch_size
    M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)
    res_cd = lgan_mmd_cov(M_rs_cd.t())
    results.update({
        "%s-CD" % k: v for k, v in res_cd.items()
    })
    res_emd = lgan_mmd_cov(M_rs_emd.t())
    results.update({
        "%s-EMD" % k: v for k, v in res_emd.items()
    })
    # Within-set distance matrices for the 1-NN two-sample test.
    M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size, accelerated_cd=accelerated_cd)
    M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)
    # 1-NN results
    one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)
    results.update({
        "1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k
    })
    one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)
    results.update({
        "1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k
    })
    return res_cd, res_emd, one_nn_cd_res, one_nn_emd_res
#######################################################
# JSD : from https://github.com/optas/latent_3d_points
#######################################################
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
    """Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
    that is placed in the unit-cube.
    If clip_sphere it True it drops the "corner" cells that lie outside the unit-sphere.
    """
    spacing = 1.0 / float(resolution - 1)
    # Per-axis coordinates i * spacing - 0.5, broadcast onto the 3 axes.
    axis_coords = np.arange(resolution) * spacing - 0.5
    grid = np.empty((resolution, resolution, resolution, 3), np.float32)
    grid[..., 0] = axis_coords[:, None, None]
    grid[..., 1] = axis_coords[None, :, None]
    grid[..., 2] = axis_coords[None, None, :]
    if clip_sphere:
        grid = grid.reshape(-1, 3)
        grid = grid[np.linalg.norm(grid, axis=1) <= 0.5]
    return grid, spacing
def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
    """Computes the JSD between two sets of point-clouds, as introduced in the paper
    ```Learning Representations And Generative Models For 3D Point Clouds```.
    Args:
        sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points.
        ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
        resolution: (int) grid-resolution. Affects granularity of measurements.
    """
    in_unit_sphere = True
    # Only the occupancy counters (second return value) feed the divergence.
    _, sample_occupancy = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)
    _, ref_occupancy = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)
    return jensen_shannon_divergence(sample_occupancy, ref_occupancy)
def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose=False):
    """Given a collection of point-clouds, estimate the entropy of the random variables
    corresponding to occupancy-grid activation patterns.
    Inputs:
        pclouds: (numpy array) #point-clouds x points per point-cloud x 3
        grid_resolution (int) size of occupancy grid that will be used.
    """
    epsilon = 10e-4
    bound = 0.5 + epsilon
    if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
        if verbose:
            warnings.warn('Point-clouds are not in unit cube.')
    if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
        if verbose:
            warnings.warn('Point-clouds are not in unit sphere.')
    cell_centers, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
    cell_centers = cell_centers.reshape(-1, 3)
    n_cells = len(cell_centers)
    # Total point hits per cell, and per-cell "activated by cloud" counts.
    hit_counters = np.zeros(n_cells)
    bernoulli_counts = np.zeros(n_cells)
    nn = NearestNeighbors(n_neighbors=1).fit(cell_centers)
    for pc in pclouds:
        _, nearest = nn.kneighbors(pc)
        nearest = np.squeeze(nearest)
        for cell in nearest:
            hit_counters[cell] += 1
        for cell in np.unique(nearest):
            bernoulli_counts[cell] += 1
    acc_entropy = 0.0
    n_clouds = float(len(pclouds))
    for activations in bernoulli_counts:
        if activations > 0:
            p = float(activations) / n_clouds
            acc_entropy += entropy([p, 1.0 - p])
    return acc_entropy / n_cells, hit_counters
def jensen_shannon_divergence(P, Q):
    """Base-2 Jensen-Shannon divergence between two histograms.

    Inputs are normalized to probabilities; the result is cross-checked
    against the independent _jsdiv implementation.
    """
    if np.any(P < 0) or np.any(Q < 0):
        raise ValueError('Negative values.')
    if len(P) != len(Q):
        raise ValueError('Non equal size.')
    P_ = P / np.sum(P)  # Ensure probabilities.
    Q_ = Q / np.sum(Q)
    mixture_entropy = entropy((P_ + Q_) / 2.0, base=2)
    res = mixture_entropy - ((entropy(P_, base=2) + entropy(Q_, base=2)) / 2.0)
    if not np.allclose(res, _jsdiv(P_, Q_), atol=10e-5, rtol=0):
        warnings.warn('Numerical values of two JSD methods don\'t agree.')
    return res
def _jsdiv(P, Q):
"""another way of computing JSD"""
def _kldiv(A, B):
a = A.copy()
b = B.copy()
idx = np.logical_and(a > 0, b > 0)
a = a[idx]
b = b[idx]
return np.sum([v for v in a * np.log2(a / b)])
P_ = P / np.sum(P)
Q_ = Q / np.sum(Q)
M = 0.5 * (P_ + Q_)
return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
if __name__ == "__main__":
    #from pytorch_structural_losses.nn_distance import nn_distance # need to make
    #from metrics.StructuralLosses.nn_distance import nn_distance
    # Smoke test: mean Chamfer distance between two random point-cloud batches.
    B, N = 100, 2048
    x = torch.rand(B, N, 3)
    y = torch.rand(B, N, 3)
    #get_dist = nn_distance # if cuda available
    get_dist = distChamfer # cuda not available
    #distChamfer = distChamferCUDA
    # NOTE(review): tensors are moved to the GPU despite the "cuda not
    # available" comment above — confirm the intended device.
    min_l, min_r = get_dist(x.cuda(), y.cuda())
    print(min_l.shape)
    print(min_r.shape)
    l_dist = min_l.mean().cpu().detach().item()
    r_dist = min_r.mean().cpu().detach().item()
    print(l_dist, r_dist)
================================================
FILE: metrics/pytorch_structural_losses/.gitignore
================================================
PyTorchStructuralLosses.egg-info/
================================================
FILE: metrics/pytorch_structural_losses/Makefile
================================================
###############################################################################
# Uncomment for debugging
# DEBUG := 1
# Pretty build
# Q ?= @
CXX := g++
PYTHON := python
NVCC := /usr/local/cuda/bin/nvcc
# PYTHON Header path
PYTHON_HEADER_DIR := $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())')
PYTORCH_INCLUDES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import include_paths; [print(p) for p in include_paths()]')
PYTORCH_LIBRARIES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import library_paths; [print(p) for p in library_paths()]')
# CUDA ROOT DIR that contains bin/ lib64/ and include/
# CUDA_DIR := /usr/local/cuda
CUDA_DIR := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import _find_cuda_home; print(_find_cuda_home())')
INCLUDE_DIRS := ./ $(CUDA_DIR)/include
INCLUDE_DIRS += $(PYTHON_HEADER_DIR)
INCLUDE_DIRS += $(PYTORCH_INCLUDES)
# Custom (MKL/ATLAS/OpenBLAS) include and lib directories.
# Leave commented to accept the defaults for your choice of BLAS
# (which should work)!
# BLAS_INCLUDE := /path/to/your/blas
# BLAS_LIB := /path/to/your/blas
###############################################################################
SRC_DIR := ./src
OBJ_DIR := ./objs
CPP_SRCS := $(wildcard $(SRC_DIR)/*.cpp)
CU_SRCS := $(wildcard $(SRC_DIR)/*.cu)
OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(CPP_SRCS))
CU_OBJS := $(patsubst $(SRC_DIR)/%.cu,$(OBJ_DIR)/cuda/%.o,$(CU_SRCS))
STATIC_LIB := $(OBJ_DIR)/libmake_pytorch.a
# CUDA architecture setting: going with all of them.
# For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility.
# For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility.
CUDA_ARCH := -gencode arch=compute_61,code=sm_61 \
-gencode arch=compute_61,code=compute_61 \
-gencode arch=compute_52,code=sm_52
# We will also explicitly add stdc++ to the link target.
LIBRARIES += stdc++ cudart c10 caffe2 torch torch_python caffe2_gpu
# Debugging
ifeq ($(DEBUG), 1)
COMMON_FLAGS += -DDEBUG -g -O0
# https://gcoe-dresden.de/reaching-the-shore-with-a-fog-warning-my-eurohack-day-4-morning-session/
NVCCFLAGS += -g -G # -rdc true
else
COMMON_FLAGS += -DNDEBUG -O3
endif
WARNINGS := -Wall -Wno-sign-compare -Wcomment
INCLUDE_DIRS += $(BLAS_INCLUDE)
# Automatic dependency generation (nvcc is handled separately)
CXXFLAGS += -std=c++14 -MMD -MP
# Complete build flags.
COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \
-DTORCH_API_INCLUDE_EXTENSION_H -D_GLIBCXX_USE_CXX11_ABI=0
CXXFLAGS += -pthread -fPIC -fwrapv -std=c++14 $(COMMON_FLAGS) $(WARNINGS)
NVCCFLAGS += -std=c++14 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
# NOTE(review): the mv targets below hardcode a cpython-39 build directory
# name; update it for other Python versions (see the commented 3.6 lines).
all: $(STATIC_LIB)
$(PYTHON) setup.py build
@ mv build/lib.linux-x86_64-cpython-39/StructuralLosses ..
#@ mv build/lib.linux-x86_64-3.6/StructuralLosses ..
@ mv build/lib.linux-x86_64-cpython-39/*.so ../StructuralLosses/
#@ mv build/lib.linux-x86_64-3.6/*.so ../StructuralLosses/
@- $(RM) -rf $(OBJ_DIR) build objs
$(OBJ_DIR):
@ mkdir -p $@
@ mkdir -p $@/cuda
$(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp | $(OBJ_DIR)
@ echo CXX $<
$(Q)$(CXX) $< $(CXXFLAGS) -c -o $@
$(OBJ_DIR)/cuda/%.o: $(SRC_DIR)/%.cu | $(OBJ_DIR)
@ echo NVCC $<
$(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \
-odir $(@D)
$(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@
$(STATIC_LIB): $(OBJS) $(CU_OBJS) | $(OBJ_DIR)
$(RM) -f $(STATIC_LIB)
$(RM) -rf build dist
@ echo LD -o $@
ar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS)
clean:
@- $(RM) -rf $(OBJ_DIR) build dist ../StructuralLosses
================================================
FILE: metrics/pytorch_structural_losses/StructuralLosses/__init__.py
================================================
#import torch
#from MakePytorchBackend import AddGPU, Foo, ApproxMatch
#from Add import add_gpu, approx_match
================================================
FILE: metrics/pytorch_structural_losses/StructuralLosses/match_cost.py
================================================
import torch
from torch.autograd import Function
from metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad
# Inherit from Function
class MatchCostFunction(Function):
    """Autograd Function exposing the compiled approximate-match (EMD) cost.

    Wraps ApproxMatch / MatchCost / MatchCostGrad from the
    StructuralLossesBackend CUDA extension so the cost is differentiable
    w.r.t. both input point sets.
    """
    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, seta, setb):
        #print("Match Cost Forward")
        ctx.save_for_backward(seta, setb)
        '''
        input:
        set1 : batch_size * #dataset_points * 3
        set2 : batch_size * #query_points * 3
        returns:
        match : batch_size * #query_points * #dataset_points
        '''
        # ApproxMatch returns the soft assignment plus a scratch buffer (temp).
        match, temp = ApproxMatch(seta, setb)
        # NOTE(review): match is stashed on ctx directly rather than via
        # save_for_backward; it stays alive until backward — confirm intended.
        ctx.match = match
        cost = MatchCost(seta, setb, match)
        return cost
    """
    grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match)
    return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None]
    """
    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        #print("Match Cost Backward")
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        seta, setb = ctx.saved_tensors
        #grad_input = grad_weight = grad_bias = None
        grada, gradb = MatchCostGrad(seta, setb, ctx.match)
        # Chain rule: broadcast the per-batch scalar gradient over points/coords.
        grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2)
        return grada*grad_output_expand, gradb*grad_output_expand
# Functional alias used by callers.
match_cost = MatchCostFunction.apply
================================================
FILE: metrics/pytorch_structural_losses/StructuralLosses/nn_distance.py
================================================
import torch
from torch.autograd import Function
# from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
from metrics.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
# Inherit from Function
class NNDistanceFunction(Function):
    """Autograd Function for bidirectional nearest-neighbour (Chamfer)
    distances computed by the StructuralLossesBackend CUDA extension.
    """
    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, seta, setb):
        #print("Match Cost Forward")
        ctx.save_for_backward(seta, setb)
        '''
        input:
        set1 : batch_size * #dataset_points * 3
        set2 : batch_size * #query_points * 3
        returns:
        dist1, idx1, dist2, idx2
        '''
        dist1, idx1, dist2, idx2 = NNDistance(seta, setb)
        # Index tensors are integral (no gradient); stash for backward.
        ctx.idx1 = idx1
        ctx.idx2 = idx2
        # Only the two distance tensors are differentiable outputs.
        return dist1, dist2
    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_dist1, grad_dist2):
        #print("Match Cost Backward")
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        seta, setb = ctx.saved_tensors
        idx1 = ctx.idx1
        idx2 = ctx.idx2
        grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2)
        return grada, gradb
# Functional alias used by callers.
nn_distance = NNDistanceFunction.apply
================================================
FILE: metrics/pytorch_structural_losses/__init__.py
================================================
#import torch
#from MakePytorchBackend import AddGPU, Foo, ApproxMatch
#from Add import add_gpu, approx_match
================================================
FILE: metrics/pytorch_structural_losses/match_cost.py
================================================
import torch
from torch.autograd import Function
from metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad
# Inherit from Function
class MatchCostFunction(Function):
    """Autograd Function exposing the compiled approximate-match (EMD) cost.

    Duplicate of metrics/pytorch_structural_losses/StructuralLosses/match_cost.py;
    wraps ApproxMatch / MatchCost / MatchCostGrad from the CUDA backend.
    """
    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, seta, setb):
        #print("Match Cost Forward")
        ctx.save_for_backward(seta, setb)
        '''
        input:
        set1 : batch_size * #dataset_points * 3
        set2 : batch_size * #query_points * 3
        returns:
        match : batch_size * #query_points * #dataset_points
        '''
        # ApproxMatch returns the soft assignment plus a scratch buffer (temp).
        match, temp = ApproxMatch(seta, setb)
        # NOTE(review): match is stashed on ctx directly rather than via
        # save_for_backward; it stays alive until backward — confirm intended.
        ctx.match = match
        cost = MatchCost(seta, setb, match)
        return cost
    """
    grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match)
    return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None]
    """
    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        #print("Match Cost Backward")
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        seta, setb = ctx.saved_tensors
        #grad_input = grad_weight = grad_bias = None
        grada, gradb = MatchCostGrad(seta, setb, ctx.match)
        # Chain rule: broadcast the per-batch scalar gradient over points/coords.
        grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2)
        return grada*grad_output_expand, gradb*grad_output_expand
# Functional alias used by callers.
match_cost = MatchCostFunction.apply
================================================
FILE: metrics/pytorch_structural_losses/nn_distance.py
================================================
import torch
from torch.autograd import Function
# from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
from .StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
# Inherit from Function
class NNDistanceFunction(Function):
    """Autograd Function for bidirectional nearest-neighbour (Chamfer)
    distances. Duplicate of StructuralLosses/nn_distance.py, importing the
    backend via a relative path.
    """
    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, seta, setb):
        #print("Match Cost Forward")
        ctx.save_for_backward(seta, setb)
        '''
        input:
        set1 : batch_size * #dataset_points * 3
        set2 : batch_size * #query_points * 3
        returns:
        dist1, idx1, dist2, idx2
        '''
        dist1, idx1, dist2, idx2 = NNDistance(seta, setb)
        # Index tensors are integral (no gradient); stash for backward.
        ctx.idx1 = idx1
        ctx.idx2 = idx2
        # Only the two distance tensors are differentiable outputs.
        return dist1, dist2
    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_dist1, grad_dist2):
        #print("Match Cost Backward")
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        seta, setb = ctx.saved_tensors
        idx1 = ctx.idx1
        idx2 = ctx.idx2
        grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2)
        return grada, gradb
# Functional alias used by callers.
nn_distance = NNDistanceFunction.apply
================================================
FILE: metrics/pytorch_structural_losses/pybind/bind.cpp
================================================
// Pybind bindings for the StructuralLosses CUDA backend.
// NOTE(review): the two #include targets were stripped during extraction;
// restored to the headers this module needs — confirm against upstream.
#include <torch/extension.h>
#include <pybind11/pybind11.h>
#include "pybind/extern.hpp"
namespace py = pybind11;
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
  m.def("ApproxMatch", &ApproxMatch);
  m.def("MatchCost", &MatchCost);
  m.def("MatchCostGrad", &MatchCostGrad);
  m.def("NNDistance", &NNDistance);
  m.def("NNDistanceGrad", &NNDistanceGrad);
}
================================================
FILE: metrics/pytorch_structural_losses/pybind/extern.hpp
================================================
std::vector ApproxMatch(at::Tensor in_a, at::Tensor in_b);
at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match);
std::vector MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match);
std::vector NNDistance(at::Tensor set_d, at::Tensor set_q);
std::vector NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2);
================================================
FILE: metrics/pytorch_structural_losses/setup.py
================================================
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
# Python interface
# Python interface: links the prebuilt static library (objs/libmake_pytorch.a,
# produced by the Makefile) into a torch CUDA extension.
setup(
    name='PyTorchStructuralLosses',
    version='0.1.0',
    install_requires=['torch'],
    packages=['StructuralLosses'],
    package_dir={'StructuralLosses': './'},
    ext_modules=[
        CUDAExtension(
            name='StructuralLossesBackend',
            include_dirs=['./'],
            sources=[
                'pybind/bind.cpp',
            ],
            # Kernel code is compiled separately by the Makefile into
            # objs/libmake_pytorch.a and linked here.
            libraries=['make_pytorch'],
            library_dirs=['objs'],
            # extra_compile_args=['-g']
        )
    ],
    cmdclass={'build_ext': BuildExtension},
    author='Christopher B. Choy',
    author_email='chrischoy@ai.stanford.edu',
    description='Tutorial for Pytorch C++ Extension with a Makefile',
    keywords='Pytorch C++ Extension',
    url='https://github.com/chrischoy/MakePytorchPlusPlus',
    zip_safe=False,
)
================================================
FILE: metrics/pytorch_structural_losses/src/approxmatch.cu
================================================
#include "utils.hpp"
// NOTE(review): the kernel bodies in this file were mangled during extraction
// (multiple source lines fused together, e.g. the loop headers below). Do not
// treat the statements inside the kernels as compilable CUDA — recover the
// originals from the upstream repository before editing.
__global__ void approxmatchkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){
float * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
float multiL,multiR;
if (n>=m){
multiL=1;
multiR=n/m;
}else{
multiL=m/n;
multiR=1;
}
const int Block=1024;
__shared__ float buf[Block*4];
for (int i=blockIdx.x;i=-2;j--){
for (int j=7;j>-2;j--){
float level=-powf(4.0f,j);
if (j==-2){
level=0;
}
for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,out);
//}
// NOTE(review): this kernel body is also mangled — see file-level note above.
__global__ void matchcostgrad2kernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad2){
__shared__ float sum_grad[256*3];
for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad2);
//}
/*void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,
cudaStream_t stream)*/
// temp: TensorShape{b,(n+m)*2}
// Host-side launcher: runs the approximate matching kernel on the given stream
// and surfaces launch failures as C++ exceptions.
void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream){
approxmatchkernel
<<<32, 512, 0, stream>>>(b,n,m,xyz1,xyz2,match,temp);
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
throw std::runtime_error(Formatter()
<< "CUDA kernel failed : " << std::to_string(err));
}
// Host-side launcher for the matching-cost reduction kernel.
void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream){
matchcostkernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,out);
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
throw std::runtime_error(Formatter()
<< "CUDA kernel failed : " << std::to_string(err));
}
// Host-side launcher for both gradient kernels.
// NOTE(review): the second launch's <<<...>>> configuration was lost in
// extraction — restore it from upstream.
void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream){
matchcostgrad1kernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,grad1);
matchcostgrad2kernel<<>>(b,n,m,xyz1,xyz2,match,grad2);
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err)
throw std::runtime_error(Formatter()
<< "CUDA kernel failed : " << std::to_string(err));
}
================================================
FILE: metrics/pytorch_structural_losses/src/approxmatch.cuh
================================================
/*
template
void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,
cudaStream_t stream);
*/
// Host-side launcher declarations; implemented in src/approxmatch.cu.
// b = batch, n = #dataset points, m = #query points.
void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream);
void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream);
void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream);
================================================
FILE: metrics/pytorch_structural_losses/src/nndistance.cu
================================================
// NOTE(review): the kernel bodies in this file were mangled during extraction
// (multiple source lines fused together). Do not treat the statements inside
// the kernels as compilable CUDA — recover the originals from upstream.
__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
const int batch=512;
__shared__ float buf[batch*3];
for (int i=blockIdx.x;ibest){
result[(i*n+j)]=best;
result_i[(i*n+j)]=best_i;
}
}
__syncthreads();
}
}
}
// Host launcher: runs the one-sided NN kernel in both directions to produce
// the symmetric Chamfer distances and argmin indices.
void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
NmDistanceKernel<<>>(b,n,xyz,m,xyz2,result,result_i);
NmDistanceKernel<<>>(b,m,xyz2,n,xyz,result2,result2_i);
}
// NOTE(review): gradient kernel body and the launcher below are also fused by
// extraction; the launch configurations were lost.
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2);
NmDistanceGradKernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1);
}
================================================
FILE: metrics/pytorch_structural_losses/src/nndistance.cuh
================================================
// Host-side launcher declarations; implemented in src/nndistance.cu.
// b = batch, n = #points in xyz, m = #points in xyz2; result/result_i hold
// per-point squared distances and argmin indices in each direction.
void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream);
void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream);
================================================
FILE: metrics/pytorch_structural_losses/src/structural_loss.cpp
================================================
// NOTE(review): the #include targets on this file's first lines were stripped
// during extraction; restored to the headers this translation unit actually
// uses (torch::empty / at::cuda::getCurrentCUDAStream / std::vector /
// std::cout) — confirm against the upstream repository.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "src/approxmatch.cuh"
#include "src/nndistance.cuh"
#include <vector>
#include <iostream>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
/*
input:
set1 : batch_size * #dataset_points * 3
set2 : batch_size * #query_points * 3
returns:
match : batch_size * #query_points * #dataset_points
*/
// temp: TensorShape{b,(n+m)*2}
std::vector ApproxMatch(at::Tensor set_d, at::Tensor set_q) {
//std::cout << "[ApproxMatch] Called." << std::endl;
int64_t batch_size = set_d.size(0);
int64_t n_dataset_points = set_d.size(1); // n
int64_t n_query_points = set_q.size(1); // m
//std::cout << "[ApproxMatch] batch_size:" << batch_size << std::endl;
at::Tensor match = torch::empty({batch_size, n_query_points, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
at::Tensor temp = torch::empty({batch_size, (n_query_points+n_dataset_points)*2}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
CHECK_INPUT(set_d);
CHECK_INPUT(set_q);
CHECK_INPUT(match);
CHECK_INPUT(temp);
approxmatch(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),temp.data(), at::cuda::getCurrentCUDAStream());
return {match, temp};
}
// Total matching cost per batch element given a precomputed assignment.
// Returns a float32 tensor of shape {batch_size}.
// NOTE(review): extraction stripped the Tensor::data<T>() template arguments;
// restored as data<float>() — confirm against upstream.
at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match) {
  //std::cout << "[MatchCost] Called." << std::endl;
  int64_t batch_size = set_d.size(0);
  int64_t n_dataset_points = set_d.size(1); // n
  int64_t n_query_points = set_q.size(1); // m
  //std::cout << "[MatchCost] batch_size:" << batch_size << std::endl;
  at::Tensor out = torch::empty({batch_size}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
  CHECK_INPUT(set_d);
  CHECK_INPUT(set_q);
  CHECK_INPUT(match);
  CHECK_INPUT(out);
  matchcost(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),out.data<float>(),at::cuda::getCurrentCUDAStream());
  return out;
}
std::vector MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match) {
//std::cout << "[MatchCostGrad] Called." << std::endl;
int64_t batch_size = set_d.size(0);
int64_t n_dataset_points = set_d.size(1); // n
int64_t n_query_points = set_q.size(1); // m
//std::cout << "[MatchCostGrad] batch_size:" << batch_size << std::endl;
at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
CHECK_INPUT(set_d);
CHECK_INPUT(set_q);
CHECK_INPUT(match);
CHECK_INPUT(grad1);
CHECK_INPUT(grad2);
matchcostgrad(batch_size,n_dataset_points,n_query_points,set_d.data(),set_q.data(),match.data(),grad1.data(),grad2.data(),at::cuda::getCurrentCUDAStream());
return {grad1, grad2};
}
/*
input:
set_d : batch_size * #dataset_points * 3
set_q : batch_size * #query_points * 3
returns:
dist1, idx1 : batch_size * #dataset_points
dist2, idx2 : batch_size * #query_points
*/
std::vector NNDistance(at::Tensor set_d, at::Tensor set_q) {
//std::cout << "[NNDistance] Called." << std::endl;
int64_t batch_size = set_d.size(0);
int64_t n_dataset_points = set_d.size(1); // n
int64_t n_query_points = set_q.size(1); // m
//std::cout << "[NNDistance] batch_size:" << batch_size << std::endl;
at::Tensor dist1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
at::Tensor idx1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device()));
at::Tensor dist2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
at::Tensor idx2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device()));
CHECK_INPUT(set_d);
CHECK_INPUT(set_q);
CHECK_INPUT(dist1);
CHECK_INPUT(idx1);
CHECK_INPUT(dist2);
CHECK_INPUT(idx2);
// void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream);
nndistance(batch_size,n_dataset_points,set_d.data(),n_query_points,set_q.data(),dist1.data(),idx1.data(),dist2.data(),idx2.data(), at::cuda::getCurrentCUDAStream());
return {dist1, idx1, dist2, idx2};
}
std::vector NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2) {
//std::cout << "[NNDistanceGrad] Called." << std::endl;
int64_t batch_size = set_d.size(0);
int64_t n_dataset_points = set_d.size(1); // n
int64_t n_query_points = set_q.size(1); // m
//std::cout << "[NNDistanceGrad] batch_size:" << batch_size << std::endl;
at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
CHECK_INPUT(set_d);
CHECK_INPUT(set_q);
CHECK_INPUT(idx1);
CHECK_INPUT(idx2);
CHECK_INPUT(grad_dist1);
CHECK_INPUT(grad_dist2);
CHECK_INPUT(grad1);
CHECK_INPUT(grad2);
//void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream);
nndistancegrad(batch_size,n_dataset_points,set_d.data(),n_query_points,set_q.data(),
grad_dist1.data(),idx1.data(),
grad_dist2.data(),idx2.data(),
grad1.data(),grad2.data(),
at::cuda::getCurrentCUDAStream());
return {grad1, grad2};
}
================================================
FILE: metrics/pytorch_structural_losses/src/utils.hpp
================================================
// NOTE(review): the #include targets were stripped during extraction;
// restored to the headers Formatter needs (std::string / std::stringstream
// plus stream output) — confirm against the upstream repository.
#include <iostream>
#include <sstream>
#include <string>
// Small stream-style message builder: chain values with operator<< and
// convert the accumulated text to std::string (used for CUDA error messages).
class Formatter {
public:
  Formatter() {}
  ~Formatter() {}
  // Append any streamable value; returns *this so calls can be chained.
  template Formatter &operator<<(const Type &value) {
    stream_ << value;
    return *this;
  }
  std::string str() const { return stream_.str(); }
  // Implicit conversion so a Formatter can be passed where a string is expected.
  operator std::string() const { return stream_.str(); }
  enum ConvertToString { to_str };
  // Explicit conversion form: fmt >> Formatter::to_str.
  std::string operator>>(ConvertToString) { return stream_.str(); }
private:
  std::stringstream stream_;
  // Pre-C++11 noncopyable idiom: copy operations declared private, undefined.
  Formatter(const Formatter &);
  Formatter &operator=(Formatter &);
};
================================================
FILE: models/__init__.py
================================================
#!/usr/bin/python3
# Package-level re-exports for the models package.
# Fix: the duplicate `from models.sdf_model import SdfModel` (it appeared both
# first and second-to-last) has been removed; the import set is unchanged.
from models.sdf_model import SdfModel
from models.autoencoder import BetaVAE
from models.archs.encoders.conv_pointnet import UNet
from models.diffusion import *
from models.archs.diffusion_arch import *
#from diffusion import *
from models.combined_model import CombinedModel
================================================
FILE: models/archs/__init__.py
================================================
#!/usr/bin/python3
================================================
FILE: models/archs/diffusion_arch.py
================================================
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from rotary_embedding_torch import RotaryEmbedding
from diff_utils.model_utils import *
from random import sample
class CausalTransformer(nn.Module):
    """Causal transformer trunk used by the diffusion networks.

    Builds depth + 2 layers: the first and last layers translate between the
    token width ``dim_in_out`` and the internal width ``dim``; when those
    widths differ they are applied WITHOUT residual connections (see
    ``use_same_dims`` in forward), while the ``depth`` middle layers are
    standard residual blocks at width ``dim``.  With ``cross_attn=True``
    every layer additionally cross-attends to a ``context`` sequence of
    width ``point_feature_dim``.  Attention / FeedForward / LayerNorm /
    RelPosBias come from diff_utils.model_utils.
    """
    def __init__(
        self,
        dim,
        depth,
        dim_in_out=None,
        cross_attn=False,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        norm_in = False,
        norm_out = True,
        attn_dropout = 0.,
        ff_dropout = 0.,
        final_proj = True,
        normformer = False,
        rotary_emb = True,
        **kwargs
    ):
        super().__init__()
        self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM
        # Relative positional bias, shared across layers (fed in as attn_bias).
        self.rel_pos_bias = RelPosBias(heads = heads)

        # `rotary_emb` is rebound from a bool flag to a module (or None); the
        # cross variant conditions on the rebound value — same truthiness.
        rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
        rotary_emb_cross = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None

        self.layers = nn.ModuleList([])
        dim_in_out = default(dim_in_out, dim)
        # NOTE(review): after default(...) above, dim_in_out is presumably
        # never None (assuming default() substitutes dim for None — confirm in
        # diff_utils.model_utils), so this reduces to (dim_in_out == dim).
        self.use_same_dims = (dim_in_out is None) or (dim_in_out==dim)

        point_feature_dim = kwargs.get('point_feature_dim', dim)

        if cross_attn:
            # Entry layer: self-attn maps dim_in_out -> dim; cross-attn reads
            # the context sequence (width point_feature_dim).
            self.layers.append(nn.ModuleList([
                Attention(dim = dim_in_out, out_dim=dim, causal = True, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb),
                Attention(dim = dim, kv_dim=point_feature_dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb_cross),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
            ]))
            # Middle layers operate entirely at width dim.
            for _ in range(depth):
                self.layers.append(nn.ModuleList([
                    Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb),
                    Attention(dim = dim, kv_dim=point_feature_dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb_cross),
                    FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
                ]))
            # Exit layer: the cross-attn maps back dim -> dim_in_out.
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, out_dim=dim, causal = True, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb),
                Attention(dim = dim, kv_dim=point_feature_dim, out_dim=dim_in_out, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb_cross),
                FeedForward(dim = dim_in_out, out_dim=dim_in_out, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
            ]))
        else:
            # Entry layer: attention maps dim_in_out -> dim.
            self.layers.append(nn.ModuleList([
                Attention(dim = dim_in_out, out_dim=dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                FeedForward(dim = dim, out_dim=dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
            ]))
            for _ in range(depth):
                self.layers.append(nn.ModuleList([
                    Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                    FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
                ]))
            # Exit layer: attention maps dim -> dim_in_out.
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, out_dim=dim_in_out, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                FeedForward(dim = dim_in_out, out_dim=dim_in_out, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
            ]))

        self.norm = LayerNorm(dim_in_out, stable = True) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
        self.project_out = nn.Linear(dim_in_out, dim_in_out, bias = False) if final_proj else nn.Identity()
        self.cross_attn = cross_attn

    def forward(self, x, time_emb=None, context=None):
        # NOTE(review): `time_emb` is accepted but never used in this method —
        # timestep information is expected to arrive already concatenated
        # into the token sequence `x`; confirm with DiffusionNet.forward.
        n, device = x.shape[1], x.device
        x = self.init_norm(x)

        # n queries, n + 1 key slots (extra slot convention of RelPosBias).
        attn_bias = self.rel_pos_bias(n, n + 1, device = device)

        if self.cross_attn:
            #assert context is not None
            for idx, (self_attn, cross_attn, ff) in enumerate(self.layers):
                if (idx==0 or idx==len(self.layers)-1) and not self.use_same_dims:
                    # Width-changing entry/exit layers: no residual possible,
                    # since input and output widths differ.
                    x = self_attn(x, attn_bias = attn_bias)
                    x = cross_attn(x, context=context) # removing attn_bias for now
                else:
                    x = self_attn(x, attn_bias = attn_bias) + x
                    x = cross_attn(x, context=context) + x # removing attn_bias for now
                x = ff(x) + x
        else:
            for idx, (attn, ff) in enumerate(self.layers):
                if (idx==0 or idx==len(self.layers)-1) and not self.use_same_dims:
                    # Width-changing entry/exit layers: no residual possible.
                    x = attn(x, attn_bias = attn_bias)
                else:
                    x = attn(x, attn_bias = attn_bias) + x
                x = ff(x) + x

        out = self.norm(x)
        return self.project_out(out)
class DiffusionNet(nn.Module):
    """Denoising network for latent diffusion.

    Given a noisy latent vector and a diffusion timestep (and optionally a
    conditioning point cloud), predicts the denoised latent: the inputs are
    packed into a short token sequence [cond?, time, data, learned_query]
    and run through a CausalTransformer; the learned query's output token is
    the prediction.

    Fix vs. original: in the classifier-free-guidance branch of ``forward``,
    an explicit ``pass_cond`` of 0/1 now takes precedence over the random
    dropout draw.  Previously ``prob < percentage or pass_cond==0`` could
    zero the condition even when ``pass_cond==1`` was requested at sampling
    time.  Training behaviour (``pass_cond=-1``) is unchanged.
    """

    def __init__(
        self,
        dim,
        dim_in_out=None,
        num_timesteps = None,
        num_time_embeds = 1,
        cond = None,
        **kwargs
    ):
        super().__init__()
        self.num_time_embeds = num_time_embeds
        self.dim = dim
        self.cond = cond
        self.cross_attn = kwargs.get('cross_attn', False)
        self.cond_dropout = kwargs.get('cond_dropout', False)
        self.point_feature_dim = kwargs.get('point_feature_dim', dim)
        self.dim_in_out = default(dim_in_out, dim)

        # Timestep embedding: learned table when timesteps are discrete,
        # otherwise sinusoidal embedding followed by a 2-layer MLP.
        self.to_time_embeds = nn.Sequential(
            nn.Embedding(num_timesteps, self.dim_in_out * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(self.dim_in_out), MLP(self.dim_in_out, self.dim_in_out * num_time_embeds)),
            Rearrange('b (n d) -> b n d', n = num_time_embeds)
        )

        # Final token appended to the sequence; its transformer output is
        # read out as the denoised prediction.
        self.learned_query = nn.Parameter(torch.randn(self.dim_in_out))
        self.causal_transformer = CausalTransformer(dim = dim, dim_in_out=self.dim_in_out, **kwargs)

        if cond:
            # Point-cloud encoder; output width must match what the
            # transformer expects for conditioning tokens.
            self.pointnet = ConvPointnet(c_dim=self.point_feature_dim)

    def forward(
        self,
        data,
        diffusion_timesteps,
        pass_cond=-1, # default -1, depends on prob; but pass as argument during sampling
    ):
        """Predict the denoised latent.

        Args:
            data: noisy latent (B, dim_in_out), or a tuple
                (latent, point_cloud) when ``self.cond`` is set.
            diffusion_timesteps: timestep tensor, one per batch element.
            pass_cond: -1 -> training behaviour (random conditioning dropout
                when ``cond_dropout`` is set); 0 -> force unconditional;
                1 -> force conditional (used for classifier-free guidance
                at sampling time).
        """
        if self.cond:
            assert type(data) is tuple
            data, cond = data # noise is added to the condition upstream in diffusion.py

            if self.cond_dropout:
                prob = torch.randint(low=0, high=10, size=(1,))
                percentage = 8
                # FIX: explicit pass_cond overrides the random draw.
                # NOTE(review): the original comment said "20% unconditional",
                # but prob < 8 drops conditioning 80% of the time — the split
                # is preserved as written; confirm which ratio was intended.
                if pass_cond == 0:
                    use_cond = False
                elif pass_cond == 1:
                    use_cond = True
                else:
                    use_cond = bool(prob >= percentage)

                if use_cond:
                    cond_feature = self.pointnet(cond, cond)
                else:
                    # Unconditional pass: replace the condition with zeros.
                    cond_feature = torch.zeros( (cond.shape[0], cond.shape[1], self.point_feature_dim), device=data.device )
            else:
                cond_feature = self.pointnet(cond, cond)

        batch, dim, device, dtype = *data.shape, data.device, data.dtype
        num_time_embeds = self.num_time_embeds

        time_embed = self.to_time_embeds(diffusion_timesteps)
        data = data.unsqueeze(1)
        learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)

        model_inputs = [time_embed, data, learned_queries]
        if self.cond and not self.cross_attn:
            # Without cross-attention, the condition tokens are simply
            # prepended to the sequence.
            model_inputs.insert(0, cond_feature)
        tokens = torch.cat(model_inputs, dim = 1) # (b, 3/4, d); batch and d same across model_inputs

        if self.cross_attn:
            # Conditional expression short-circuits, so cond_feature is only
            # touched when self.cond is set (and hence defined above).
            cond_feature = None if not self.cond else cond_feature
            tokens = self.causal_transformer(tokens, context=cond_feature)
        else:
            tokens = self.causal_transformer(tokens)

        # The last token is the learned query's output: the prediction.
        pred = tokens[..., -1, :]
        return pred
================================================
FILE: models/archs/encoders/__init__.py
================================================
#!/usr/bin/python3
================================================
FILE: models/archs/encoders/auto_decoder.py
================================================
#!/usr/bin/env python3
import torch.nn as nn
import torch
import torch.nn.functional as F
import json
import math
class AutoDecoder():
    """Factory for the table of per-scene latent codes used by auto-decoding.

    num_scenes is the total mesh count across all classes (len of the
    dataset); latent_size is the dimensionality of each code.
    """
    def __init__(self, num_scenes, latent_size):
        self.num_scenes = num_scenes
        self.latent_size = latent_size

    def build_model(self):
        """Create the embedding table: one row per scene, rows renormalized
        to max-norm 1, initialized N(0, 1/sqrt(latent_size))."""
        table = nn.Embedding(self.num_scenes, self.latent_size, max_norm=1.0)
        init_std = 1.0 / math.sqrt(self.latent_size)
        nn.init.normal_(table.weight.data, 0.0, init_std)
        return table
================================================
FILE: models/archs/encoders/conv_pointnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch_scatter import scatter_mean, scatter_max
class ConvPointnet(nn.Module):
    ''' PointNet-based encoder network with ResNet blocks for each point.
        Number of input points are fixed.

    Per-point features are pooled onto three axis-aligned feature planes
    (optionally refined by a 2-D U-Net); query points are answered by
    bilinearly sampling those planes and summing the three results.

    Args:
        c_dim (int): dimension of latent code c
        dim (int): input points dimension
        hidden_dim (int): hidden dimension of the network
        scatter_type (str): feature aggregation when doing local pooling
        unet (bool): whether to use U-Net
        unet_kwargs (dict): U-Net parameters
        plane_resolution (int): defined resolution for plane feature
        plane_type (list): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
        padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
        n_blocks (int): number of ResnetBlockFC layers
        inject_noise (bool): accepted but unused in this implementation
    '''

    # NOTE(review): unet_kwargs and plane_type are mutable default arguments —
    # safe only as long as callers never mutate them in place.
    def __init__(self, c_dim=512, dim=3, hidden_dim=128, scatter_type='max',
                 unet=True, unet_kwargs={"depth": 4, "merge_mode": "concat", "start_filts": 32},
                 plane_resolution=64, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5,
                 inject_noise=False):
        super().__init__()
        self.c_dim = c_dim

        # Per-point feature extractor: lift coordinates to 2*hidden_dim, then
        # n_blocks ResNet blocks; blocks after the first consume the
        # concatenation [local, pooled] built in forward().
        self.fc_pos = nn.Linear(dim, 2*hidden_dim)
        self.blocks = nn.ModuleList([
            ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
        ])
        self.fc_c = nn.Linear(hidden_dim, c_dim)

        self.actvn = nn.ReLU()
        self.hidden_dim = hidden_dim

        # Optional 2-D U-Net that smooths/propagates features on each plane.
        if unet:
            self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
        else:
            self.unet = None

        self.reso_plane = plane_resolution
        self.plane_type = plane_type
        self.padding = padding

        if scatter_type == 'max':
            self.scatter = scatter_max
        elif scatter_type == 'mean':
            self.scatter = scatter_mean
        # NOTE(review): any other scatter_type silently leaves self.scatter
        # unset, which only fails later inside pool_local.

    def generate_plane_features(self, p, c, plane='xz'):
        """Scatter per-point features c onto a reso x reso grid of `plane`
        (mean aggregation), then optionally refine with the U-Net.

        p: (B, T, 3) points; c: (B, T, c_dim) features.
        Returns (B, c_dim, reso, reso).
        """
        # acquire indices of features in plane
        xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
        index = self.coordinate2index(xy, self.reso_plane)

        # scatter plane features from points
        fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
        c = c.permute(0, 2, 1) # B x c_dim x T
        fea_plane = scatter_mean(c, index, out=fea_plane) # B x c_dim x reso^2
        fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparse matrix (B x c_dim x reso x reso)

        # process the plane features with UNet
        if self.unet is not None:
            fea_plane = self.unet(fea_plane)

        return fea_plane

    # takes in "p": point cloud and "query": sdf_xyz
    # sample plane features for unlabeled_query as well
    def forward(self, p, query):
        """Encode point cloud p and, for each query point, return the summed
        bilinearly-sampled features from every configured plane.

        p: (B, T, 3); query: (B, Q, 3).  Returns (B, Q, c_dim).
        """
        batch_size, T, D = p.size()

        # acquire the index for each point on each configured plane
        coord = {}
        index = {}
        if 'xz' in self.plane_type:
            coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
            index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
        if 'xy' in self.plane_type:
            coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
            index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
        if 'yz' in self.plane_type:
            coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
            index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)

        net = self.fc_pos(p)
        net = self.blocks[0](net)
        # Interleave local features with plane-pooled context before each
        # remaining block (hence the 2*hidden_dim block input size).
        for block in self.blocks[1:]:
            pooled = self.pool_local(coord, index, net)
            net = torch.cat([net, pooled], dim=2)
            net = block(net)
        c = self.fc_c(net)

        fea = {}
        plane_feat_sum = 0
        #denoise_loss = 0
        if 'xz' in self.plane_type:
            fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
            plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')
        if 'xy' in self.plane_type:
            fea['xy'] = self.generate_plane_features(p, c, plane='xy')
            plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')
        if 'yz' in self.plane_type:
            fea['yz'] = self.generate_plane_features(p, c, plane='yz')
            plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')

        return plane_feat_sum.transpose(2,1)

    # given plane features with dimensions (3*dim, 64, 64)
    # first reshape into the three planes, then generate query features from it
    def forward_with_plane_features(self, plane_features, query):
        """Answer queries from precomputed channel-stacked plane features.

        plane_features: (B, 3*c_dim, reso, reso), channels ordered xz|xy|yz.
        Returns (B, Q, c_dim).
        """
        # plane features shape: batch, dim*3, 64, 64
        idx = int(plane_features.shape[1] / 3)
        fea = {}
        fea['xz'], fea['xy'], fea['yz'] = plane_features[:,0:idx,...], plane_features[:,idx:idx*2,...], plane_features[:,idx*2:,...]
        plane_feat_sum = 0
        plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')
        plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')
        plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')

        return plane_feat_sum.transpose(2,1)

    # c is point cloud features
    # p is point cloud (coordinates)
    def forward_with_pc_features(self, c, p, query):
        """Like forward(), but starting from precomputed per-point features c
        instead of running the point encoder."""
        fea = {}
        fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
        fea['xy'] = self.generate_plane_features(p, c, plane='xy')
        fea['yz'] = self.generate_plane_features(p, c, plane='yz')

        plane_feat_sum = 0
        plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')
        plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')
        plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')

        return plane_feat_sum.transpose(2,1)

    def get_point_cloud_features(self, p):
        """Run only the point-feature encoder stage (duplicates the encoder
        portion of forward()).  p: (B, T, 3) -> (B, T, c_dim)."""
        batch_size, T, D = p.size()

        # acquire the index for each point
        coord = {}
        index = {}
        if 'xz' in self.plane_type:
            coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
            index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
        if 'xy' in self.plane_type:
            coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
            index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
        if 'yz' in self.plane_type:
            coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
            index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)

        net = self.fc_pos(p)
        net = self.blocks[0](net)
        for block in self.blocks[1:]:
            pooled = self.pool_local(coord, index, net)
            net = torch.cat([net, pooled], dim=2)
            net = block(net)
        c = self.fc_c(net)
        return c

    def get_plane_features(self, p):
        """Return the three plane feature maps for point cloud p.

        NOTE(review): unconditionally indexes all three planes, so this
        raises KeyError unless plane_type contains 'xz', 'xy' and 'yz'.
        """
        c = self.get_point_cloud_features(p)
        fea = {}
        if 'xz' in self.plane_type:
            fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
        if 'xy' in self.plane_type:
            fea['xy'] = self.generate_plane_features(p, c, plane='xy')
        if 'yz' in self.plane_type:
            fea['yz'] = self.generate_plane_features(p, c, plane='yz')
        return fea['xz'], fea['xy'], fea['yz']

    def normalize_coordinate(self, p, padding=0.1, plane='xz'):
        ''' Normalize coordinate to [0, 1] for unit cube experiments

        Args:
            p (tensor): point
            padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
            plane (str): plane feature type, ['xz', 'xy', 'yz']
        '''
        # Select the two coordinates spanning the requested plane.
        if plane == 'xz':
            xy = p[:, :, [0, 2]]
        elif plane =='xy':
            xy = p[:, :, [0, 1]]
        else:
            xy = p[:, :, [1, 2]]

        # NOTE(review): 10e-6 is 1e-5, not 1e-6 — preserved as written.
        xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
        xy_new = xy_new + 0.5 # range (0, 1)

        # if there are outliers out of the range, clamp them just inside [0, 1)
        if xy_new.max() >= 1:
            xy_new[xy_new >= 1] = 1 - 10e-6
        if xy_new.min() < 0:
            xy_new[xy_new < 0] = 0.0
        return xy_new

    def coordinate2index(self, x, reso):
        ''' Flatten normalized 2-D plane coordinates in [0, 1) into linear
        cell indices on a reso x reso grid.

        Args:
            x (tensor): (B, T, 2) normalized coordinates
            reso (int): defined resolution

        Returns (B, 1, T) long tensor of indices in [0, reso**2).
        '''
        x = (x * reso).long()
        index = x[:, :, 0] + reso * x[:, :, 1]
        index = index[:, None, :]
        return index

    # xy is the normalized coordinates of the point cloud of each plane
    # I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input
    def pool_local(self, xy, index, c):
        """Aggregate per-point features into plane cells, gather each point's
        cell feature back, and sum over all planes.  c: (B, T, fea_dim)."""
        bs, fea_dim = c.size(0), c.size(2)
        keys = xy.keys()

        c_out = 0
        for key in keys:
            # scatter plane features from points
            fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
            if self.scatter == scatter_max:
                # scatter_max returns (values, argmax); keep values only.
                fea = fea[0]
            # gather feature back to points
            fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
            c_out += fea
        return c_out.permute(0, 2, 1)

    # sample_plane_feature function copied from /src/conv_onet/models/decoder.py
    # uses values from plane_feature and pixel locations from vgrid to interpolate feature
    def sample_plane_feature(self, query, plane_feature, plane):
        """Bilinearly sample plane_feature (B, c_dim, reso, reso) at the
        query points projected onto `plane`.  Returns (B, c_dim, Q)."""
        xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)
        xy = xy[:, :, None].float()
        # grid_sample expects coordinates in (-1, 1).
        vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
        sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)
        return sampled_feat
def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """3x3 2-D convolution; default padding=1 preserves spatial size."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=padding, bias=bias, groups=groups)
def upconv2x2(in_channels, out_channels, mode='transpose'):
    """2x spatial upsampling: a learned 2x2 transpose conv, or bilinear
    upsampling followed by a 1x1 conv to adjust the channel count."""
    if mode == 'transpose':
        return nn.ConvTranspose2d(in_channels, out_channels,
                                  kernel_size=2, stride=2)
    # Non-parametric path: upsample first, then mix channels.
    return nn.Sequential(
        nn.Upsample(mode='bilinear', scale_factor=2),
        conv1x1(in_channels, out_channels))
def conv1x1(in_channels, out_channels, groups=1):
    """1x1 convolution: channel mixing only, spatial size unchanged."""
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=1, groups=groups)
class DownConv(nn.Module):
    """Encoder stage: two 3x3 convs (ReLU after each) plus an optional 2x2
    max-pool.

    forward returns (output, pre_pool); the pre-pool activation is kept for
    the decoder's skip connection.
    """
    def __init__(self, in_channels, out_channels, pooling=True):
        super(DownConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(self.in_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)
        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        skip = out  # saved before pooling for the skip connection
        if self.pooling:
            out = self.pool(out)
        return out, skip
class UpConv(nn.Module):
    """Decoder stage: upsample 2x, merge with the encoder skip tensor
    (concatenation or addition), then two 3x3 convs with ReLU."""
    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', up_mode='transpose'):
        super(UpConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(self.in_channels, self.out_channels,
                                mode=self.up_mode)

        if self.merge_mode == 'concat':
            # Concatenation doubles the channel count entering conv1.
            self.conv1 = conv3x3(2*self.out_channels, self.out_channels)
        else:
            # Additive merge keeps the channel count unchanged.
            self.conv1 = conv3x3(self.out_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

    def forward(self, from_down, from_up):
        """from_down: skip tensor from the encoder pathway;
        from_up: decoder tensor to be upsampled and merged."""
        up = self.upconv(from_up)
        if self.merge_mode == 'concat':
            merged = torch.cat((up, from_down), 1)
        else:
            merged = up + from_down
        merged = F.relu(self.conv1(merged))
        return F.relu(self.conv2(merged))
class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597
    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).
    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by upmode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens with the convolution in
        the tranpose convolution (specified by upmode='transpose')
    """

    def __init__(self, num_classes, in_channels=3, depth=5,
                 start_filts=64, up_mode='transpose', same_channels=False,
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            num_classes: int, number of channels produced by the final
                1x1 convolution.
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of MaxPools in the U-Net.
            start_filts: int, number of convolutional filters for the
                first conv.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for nearest neighbour
                upsampling.
            same_channels: bool, if True the channel count stays equal to
                in_channels through every stage instead of doubling/halving.
            merge_mode: 'concat' or 'add', how skip tensors are merged.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))

        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            # NOTE(review): this message interpolates up_mode rather than
            # merge_mode (runtime string deliberately left unchanged here).
            raise ValueError("\"{}\" is not a valid mode for"
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(up_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list
        # `outs` carries the running channel count between the two loops.
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts*(2**i) if not same_channels else self.in_channels
            # All stages pool except the deepest one.
            pooling = True if i < depth-1 else False

            down_conv = DownConv(ins, outs, pooling=pooling)
            self.down_convs.append(down_conv)

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth-1):
            ins = outs
            outs = ins // 2 if not same_channels else ins
            up_conv = UpConv(ins, outs, up_mode=up_mode,
                             merge_mode=merge_mode)
            self.up_convs.append(up_conv)
            # the final reduction to num_classes happens via self.conv_final

        # add the list of modules to current module (registers parameters)
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

        self.conv_final = conv1x1(outs, self.num_classes)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        # Xavier-init every nn.Conv2d; its bias is zeroed (conv3x3/conv1x1
        # used here default to bias=True, so m.bias is present).
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            init.constant_(m.bias, 0)

    def reset_params(self):
        # Re-initialize every submodule (weight_init ignores non-Conv2d).
        for i, m in enumerate(self.modules()):
            self.weight_init(m)

    def forward(self, x):
        """x: (B, in_channels, H, W) -> (B, num_classes, H, W)."""
        encoder_outs = []

        # encoder pathway, save pre-pool outputs for merging
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        # decoder pathway; the deepest stage's output is x itself, so the
        # first up-conv merges with the skip at index -2.
        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)

        # No softmax is used here: pair this output with
        # nn.CrossEntropyLoss in the training script,
        # since that loss applies the softmax internally.
        x = self.conv_final(x)
        return x

    def generate(self, x):
        # Inference alias for forward().
        return self(x)
# Resnet Blocks
class ResnetBlockFC(nn.Module):
    ''' Fully connected ResNet block: x -> shortcut(x) + fc_1(relu(fc_0(relu(x)))).

    Args:
        size_in (int): input dimension
        size_out (int): output dimension (defaults to size_in)
        size_h (int): hidden dimension (defaults to min(size_in, size_out))
    '''

    def __init__(self, size_in, size_out=None, size_h=None):
        super().__init__()
        # Fill in defaults.
        size_out = size_in if size_out is None else size_out
        size_h = min(size_in, size_out) if size_h is None else size_h

        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out

        # Submodules
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)
        self.actvn = nn.ReLU()

        # Identity shortcut when shapes match, learned projection otherwise.
        if size_in == size_out:
            self.shortcut = None
        else:
            self.shortcut = nn.Linear(size_in, size_out, bias=False)

        # Zero-init the residual branch's weight so the block starts close
        # to the identity mapping.
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x):
        hidden = self.fc_0(self.actvn(x))
        residual = self.fc_1(self.actvn(hidden))
        skip = x if self.shortcut is None else self.shortcut(x)
        return skip + residual
================================================
FILE: models/archs/encoders/dgcnn.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
def knn(x, k):
    """Indices of the k nearest neighbours of every point.

    x: (batch, num_dims, num_points).  Returns (batch, num_points, k).
    """
    xt = x.transpose(2, 1).contiguous()
    cross = -2 * torch.matmul(xt, x)
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)
    # Negative squared Euclidean distance; topk then picks the closest points.
    neg_dist = -sq_norm - cross - sq_norm.transpose(2, 1).contiguous()
    return neg_dist.topk(k=k, dim=-1)[1]
def get_graph_feature(x, k=20):
    """Build DGCNN edge features.

    Args:
        x: (batch, num_dims, num_points) point tensor.
        k: number of nearest neighbours per point.

    Returns:
        (batch, 2*num_dims, num_points, k) tensor holding, for each point and
        each of its k nearest neighbours, [neighbour_feature, centre_feature].
    """
    idx = knn(x, k=k)  # (batch_size, num_points, k)
    batch_size, num_points, _ = idx.size()
    # FIX: take the device from the input tensor rather than a global
    # cuda-if-available guess — the original crashed on CPU tensors when a
    # GPU was present, and could pick the wrong GPU in multi-GPU setups.
    device = x.device

    # Offset each batch's indices so they address the flattened (B*N, C) view.
    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points

    idx = idx + idx_base
    idx = idx.view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    # Repeat each centre point so it can be paired with all k neighbours.
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    # Concatenate neighbour and centre features, then move channels to dim 1.
    feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)

    return feature
class DGCNN(nn.Module):
    """Dynamic Graph CNN encoder: maps a point cloud (B, 3, N) to a global
    feature vector (B, emb_dims).

    Attribute names (conv1..conv5, bn1..bn5) and the Sequential layouts are
    kept identical to the original so pretrained checkpoints still load.
    """
    def __init__(
        self,
        emb_dims=512,
        use_bn=False
    ):
        super().__init__()

        # Factory for one stage: 1x1 conv (+ optional BN) + LeakyReLU,
        # matching the original Sequential ordering exactly.
        def stage(c_in, c_out, bn):
            layers = [nn.Conv2d(c_in, c_out, kernel_size=1, bias=False)]
            if bn is not None:
                layers.append(bn)
            layers.append(nn.LeakyReLU(negative_slope=0.2))
            return nn.Sequential(*layers)

        if use_bn:
            self.bn1 = nn.BatchNorm2d(64)
            self.bn2 = nn.BatchNorm2d(64)
            self.bn3 = nn.BatchNorm2d(128)
            self.bn4 = nn.BatchNorm2d(256)
            self.bn5 = nn.BatchNorm2d(emb_dims)
            norms = (self.bn1, self.bn2, self.bn3, self.bn4, self.bn5)
        else:
            norms = (None, None, None, None, None)

        self.conv1 = stage(6, 64, norms[0])
        self.conv2 = stage(64, 64, norms[1])
        self.conv3 = stage(64, 128, norms[2])
        self.conv4 = stage(128, 256, norms[3])
        # 512 = 64 + 64 + 128 + 256 pooled stage outputs concatenated.
        self.conv5 = stage(512, emb_dims, norms[4])

    def forward(self, x):
        batch_size, num_dims, num_points = x.size()  # x: (B, 3, N)
        feat = get_graph_feature(x)                  # (B, 6, N, 20)

        # Run the four stages, max-pooling each over the neighbourhood axis.
        pooled = []
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feat = conv(feat)
            pooled.append(feat.max(dim=-1, keepdim=True)[0])

        multi_scale = torch.cat(pooled, dim=1)                      # (B, 512, N, 1)
        point_feat = torch.squeeze(self.conv5(multi_scale), dim=3)  # (B, emb_dims, N)
        # Global max over points yields the cloud-level descriptor.
        return point_feat.max(dim=2, keepdim=False)[0]              # (B, emb_dims)
================================================
FILE: models/archs/encoders/rbf.py
================================================
#!/usr/bin/env python3
import torch.nn as nn
import torch
import torch.nn.functional as F
import json
import numpy as np
# https://github.com/vsitzmann/siren/blob/4df34baee3f0f9c8f351630992c1fe1f69114b5f/modules.py#L266
class RBFLayer(nn.Module):
'''Transforms incoming data using a given radial basis function.
- Input: (1, N, in_features) where N is an arbitrary batch size
- Output: (1, N, out_features) where N is an arbitrary batch size'''
def __init__(self, in_features=3, out_features=1024):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.centres = nn.Parameter(torch.Tensor(out_features, in_features))
self.sigmas = nn.Parameter(torch.Tensor(out_features))
self.reset_parameters()
self.freq = nn.Parameter(np.pi * torch.ones((1, self.out_features)))
def reset_parameters(self):
nn.init.uniform_(self.centres, -1, 1)
nn.init.constant_(self.sigmas, 10)
def forward(self, x):
#x = x[0, ...]
size = (x.size(0), self.out_features, self.in_features)
x = x.unsqueeze(1).expand(size)
c = self.centres.unsqueeze(0).expand(size)
distances = (x - c).pow(2).sum(-1) * self.sigmas.unsqueeze(0)
return self.gaussian(distances).unsqueeze(0)
def gaussian(self, alpha):
phi = torch.exp(-1 * alpha.pow(2))
return phi
================================================
FILE: models/archs/encoders/sal_pointnet.py
================================================
#!/usr/bin/env python3
import torch.nn as nn
import torch
import torch.nn.functional as F
import json
import sys
sys.path.append("..")
import utils # actual dir is ../utils
# pointnet from SAL paper: https://github.com/matanatz/SAL/blob/master/code/model/network.py#L14
class SalPointNet(nn.Module):
    ''' PointNet-based encoder network. Based on: https://github.com/autonomousvision/occupancy_networks

    Args:
        c_dim (int): dimension of latent code c
        in_dim (int): input points dimension
        hidden_dim (int): hidden dimension of the network
    '''

    def __init__(self, c_dim=256, in_dim=3, hidden_dim=128):
        super().__init__()
        self.c_dim = c_dim

        self.fc_pos = nn.Linear(in_dim, 2*hidden_dim)
        # Each mixing layer consumes [local, pooled-global] -> 2*hidden_dim in.
        self.fc_0 = nn.Linear(2*hidden_dim, hidden_dim)
        self.fc_1 = nn.Linear(2*hidden_dim, hidden_dim)
        self.fc_2 = nn.Linear(2*hidden_dim, hidden_dim)
        self.fc_3 = nn.Linear(2*hidden_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, c_dim)
        self.fc_std = nn.Linear(hidden_dim, c_dim)
        # Heads start at an exact constant: mean head outputs 0, std head
        # outputs -10 (presumably interpreted in log-space by the caller —
        # confirm against the VAE that consumes this encoder).
        torch.nn.init.constant_(self.fc_mean.weight, 0)
        torch.nn.init.constant_(self.fc_mean.bias, 0)
        torch.nn.init.constant_(self.fc_std.weight, 0)
        torch.nn.init.constant_(self.fc_std.bias, -10)

        self.actvn = nn.ReLU()
        self.pool = self.maxpool

    def forward(self, p):
        """p: (B, N, in_dim) point cloud -> (c_mean, c_std), each (B, c_dim)."""
        net = self.fc_pos(p)
        net = self.fc_0(self.actvn(net))
        # Three rounds: concatenate per-point features with the max-pooled
        # global feature, then mix through a linear layer.
        for mix in (self.fc_1, self.fc_2, self.fc_3):
            pooled = self.pool(net, dim=1, keepdim=True).expand(net.size())
            net = mix(self.actvn(torch.cat([net, pooled], dim=2)))

        # Final pool over the point axis yields one feature per cloud.
        net = self.pool(net, dim=1)

        c_mean = self.fc_mean(self.actvn(net))
        c_std = self.fc_std(self.actvn(net))
        return c_mean, c_std

    def maxpool(self, x, dim=-1, keepdim=False):
        """Max over `dim`, discarding the argmax indices."""
        out, _ = x.max(dim=dim, keepdim=keepdim)
        return out
================================================
FILE: models/archs/encoders/vanilla_pointnet.py
================================================
#!/usr/bin/env python3
import torch.nn as nn
import torch
import torch.nn.functional as F
from torchmeta.modules import (MetaModule, MetaSequential, MetaLinear, MetaConv1d, MetaBatchNorm1d)
import json
import sys
sys.path.append("..")
import utils # actual dir is ../utils
class PointNet(MetaModule):
    """Per-point 1D-conv PointNet encoder producing one latent vector per cloud.

    Five MetaConv1d layers (each followed by MetaBatchNorm1d; ReLU on all but
    the last) map a (B, 3, N) point cloud to (B, latent_size, N) features,
    which are max-pooled over the point dimension into a (B, latent_size) code.
    """

    def __init__(self, latent_size):
        super().__init__()
        self.latent_size = latent_size
        self.conv1 = MetaConv1d(3, 64, kernel_size=1, bias=False)
        self.conv2 = MetaConv1d(64, 64, kernel_size=1, bias=False)
        self.conv3 = MetaConv1d(64, 64, kernel_size=1, bias=False)
        self.conv4 = MetaConv1d(64, 128, kernel_size=1, bias=False)
        self.conv5 = MetaConv1d(128, self.latent_size, kernel_size=1, bias=False)
        self.bn1 = MetaBatchNorm1d(64)
        self.bn2 = MetaBatchNorm1d(64)
        self.bn3 = MetaBatchNorm1d(64)
        self.bn4 = MetaBatchNorm1d(128)
        self.bn5 = MetaBatchNorm1d(self.latent_size)

    def forward(self, x, params=None):
        """Encode (B, 3, N) points; `params` is forwarded to every conv.

        NOTE(review): `params` is passed whole and positionally to each
        MetaConv1d rather than via per-layer get_subdict — confirm intended.
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, bn in stages:
            x = F.relu(bn(conv(x, params)))
        x = self.bn5(self.conv5(x, params))  # no ReLU on the last stage
        return x.max(dim=2, keepdim=False)[0]
================================================
FILE: models/archs/modulated_sdf.py
================================================
#!/usr/bin/env python3
import torch.nn as nn
import torch
import torch.nn.functional as F
import json
import sys
import torch.nn.init as init
import numpy as np
# no dropout or skip connections for now
class Layer(nn.Module):
    """Linear + activation building block with SAL-style geometric initialization.

    geo_init='first': weights ~ N(0, sqrt(2)/sqrt(dim)), zero bias.
    geo_init='last':  weights tightly clustered around 2*sqrt(pi)/sqrt(dim)
    with bias -0.5, so an untrained SDF output layer approximates the signed
    distance to a unit sphere. Dropout is currently disabled (kept commented).
    """

    _ACTIVATIONS = {'relu': nn.ReLU, 'tanh': nn.Tanh}

    def __init__(self, dim_in=512, dim_out=512, dim=512, dropout_prob=0.0, geo_init='first', activation='relu'):
        super().__init__()
        self.linear = nn.Linear(dim_in, dim_out)
        # unknown activation names fall back to identity, as before
        self.activation = self._ACTIVATIONS.get(activation, nn.Identity)()
        #self.dropout = nn.Dropout(p=dropout_prob)
        if geo_init == 'first':
            init.normal_(self.linear.weight, mean=0.0, std=np.sqrt(2) / np.sqrt(dim))
            init.constant_(self.linear.bias, 0.0)
        elif geo_init == 'last':
            init.normal_(self.linear.weight, mean=2 * np.sqrt(np.pi) / np.sqrt(dim), std=0.000001)
            init.constant_(self.linear.bias, -0.5)

    def forward(self, x):
        """Apply the linear map followed by the activation."""
        return self.activation(self.linear(x))
class ModulatedMLP(nn.Module):
    """SDF MLP whose hidden features are modulated by a latent code.

    Two conditioning modes:
      * latent_in=True: the latent is concatenated to the (optionally
        positionally-encoded) query coordinates at the input and again at
        each skip-connection layer.
      * latent_in=False: the latent is treated as per-layer shift vectors
        (hidden_dim entries per hidden layer) added after each linear map.

    Args:
        latent_size: dimension of the conditioning latent.
        hidden_dim: width of every hidden layer.
        num_layers: total layer count (num_layers - 1 hidden layers + output).
        latent_in: concatenate the latent (True) vs. use it as shifts (False).
        skip_connection: indices of layers that re-receive (xyz, latent).
        dropout_prob: kept for interface compatibility; dropout is disabled.
        pos_enc: apply NeRF-style sin/cos positional encoding to xyz
                 (remove the last tanh layer when enabling this).
        pe_num_freq: number of frequency octaves for the positional encoding.
        tanh_act: apply tanh to the final SDF output.
    """

    def __init__(self, latent_size=512, hidden_dim=512, num_layers=9, latent_in=True,
                 skip_connection=[4], dropout_prob=0.0, pos_enc=False, pe_num_freq=5, tanh_act=False
                 ):
        super().__init__()
        self.skip_connection = skip_connection  # indices of layers that get a skip connection
        self.hidden_dim = hidden_dim
        self.pe_num_freq = pe_num_freq
        self.pos_enc = pos_enc
        self.latent_in = latent_in

        first_dim_in = 3
        if pos_enc:
            self.pe_func = self._build_pe()
            first_dim_in = 3 * pe_num_freq * 2

        if latent_in:
            # latent code concatenated to coordinates as input to the model
            num_modulations = hidden_dim
            first_dim_in += latent_size
            mod_act = nn.ReLU()
        else:
            # shift modulation: one hidden_dim-sized shift per hidden layer;
            # the last layer is not modulated
            num_modulations = hidden_dim * (num_layers - 1)
            mod_act = nn.Identity()
        #self.mod_net = nn.Sequential(nn.Linear(latent_size, num_modulations), mod_act)

        layers = []
        for i in range(num_layers - 1):
            if i == 0:
                dim_in = first_dim_in
            elif i in skip_connection:
                # skip layers re-receive raw xyz (3) plus the latent
                dim_in = hidden_dim + 3 + latent_size
            else:
                dim_in = hidden_dim
            layers.append(
                Layer(
                    dim_in=dim_in,
                    dim_out=hidden_dim,
                    activation='relu',
                    geo_init='first',
                    dim=hidden_dim,
                    dropout_prob=dropout_prob
                )
            )
        self.net = nn.Sequential(*layers)
        last_act = 'tanh' if tanh_act else 'identity'
        self.last_layer = Layer(dim_in=hidden_dim, dim_out=1, activation=last_act, geo_init='last', dim=hidden_dim)

    def _build_pe(self):
        """Build sin/cos encoding functions, one pair per frequency octave.

        BUG FIX: the frequency must be read from the bound default `freq=i`
        inside the lambda body. The previous code bound `freq=i` but then used
        `2**i`, a late-binding closure bug that made every octave use the
        final loop value of `i`.
        """
        funcs = []
        for i in range(self.pe_num_freq):
            funcs.append(lambda data, freq=i: torch.sin(data * (2 ** freq)))
            funcs.append(lambda data, freq=i: torch.cos(data * (2 ** freq)))
        return funcs

    def pe_transform(self, data):
        """Concatenate all sin/cos octaves of `data` along the last dimension."""
        return torch.cat([f(data) for f in self.pe_func], dim=-1)

    def forward(self, xyz, latent):
        '''
        xyz: (B, N, 3) query coordinates for predicting SDF values
        latent: conditioning latent; with latent_in=True it must concatenate
                with xyz along the last dim (i.e. shaped (B, N, latent_size))
        returns: (prediction of shape (B, N, 1), the modulation tensor)
        '''
        modulations = latent  # self.mod_net(latent) when the mod net is enabled
        if self.pos_enc:
            xyz = self.pe_transform(xyz)
        x = xyz.clone()
        if self.latent_in:
            x = torch.cat((x, modulations), dim=-1)
        idx = 0
        for i, layer in enumerate(self.net):
            if i in self.skip_connection:
                # re-inject the (encoded) coordinates and the latent
                x = torch.cat((x, torch.cat((xyz, modulations), dim=-1)), dim=-1)
            x = layer.linear(x)
            if not self.latent_in:
                # shift modulation: add this layer's slice of the latent
                shift = modulations[:, idx: idx + self.hidden_dim].unsqueeze(1)
                x = x + shift
                idx += self.hidden_dim
            x = layer.activation(x)
        out = self.last_layer(x)
        return out, modulations
================================================
FILE: models/archs/resnet_block.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
# Resnet Blocks
class ResnetBlockFC(nn.Module):
    """Fully connected ResNet block: x -> shortcut(x) + fc_1(relu(fc_0(relu(x)))).

    fc_1's weights are zero-initialized, so at init the residual branch only
    contributes fc_1's bias. A linear shortcut projection is created only when
    input and output sizes differ; otherwise the identity is used.

    Args:
        size_in (int): input dimension
        size_out (int): output dimension (defaults to size_in)
        size_h (int): hidden dimension (defaults to min(size_in, size_out))
    """

    def __init__(self, size_in, size_out=None, size_h=None):
        super().__init__()
        size_out = size_in if size_out is None else size_out
        size_h = min(size_in, size_out) if size_h is None else size_h
        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out
        # Submodules
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)
        self.actvn = nn.ReLU()
        self.shortcut = None if size_in == size_out else nn.Linear(size_in, size_out, bias=False)
        # zero residual weights: the block starts as (a projection of) the identity
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x):
        residual = self.fc_1(self.actvn(self.fc_0(self.actvn(x))))
        skip = x if self.shortcut is None else self.shortcut(x)
        return skip + residual
================================================
FILE: models/archs/sdf_decoder.py
================================================
#!/usr/bin/env python3
import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np
class SdfDecoder(nn.Module):
    """Eight-layer MLP SDF head with an optional mid-network skip connection.

    Two four-layer ReLU blocks followed by a scalar output layer. When
    skip_connection is set, the raw input is concatenated onto block1's
    output before entering block2. geo_init applies the SAL geometric
    initialization (output layer biased toward a unit-sphere SDF).

    Args:
        latent_size: size of the shape feature concatenated to xyz.
        hidden_dim: width of all hidden layers.
        skip_connection: concatenate the input onto block1's output.
        tanh_act: squash the output into (-1, 1) with tanh.
        geo_init: use geometric initialization instead of PyTorch defaults.
        input_size: overrides the default input width of latent_size + 3.
    """

    def __init__(self, latent_size=256, hidden_dim=512,
                 skip_connection=True, tanh_act=False,
                 geo_init=True, input_size=None
                 ):
        super().__init__()
        self.latent_size = latent_size
        self.input_size = latent_size + 3 if input_size is None else input_size
        self.skip_connection = skip_connection
        self.tanh_act = tanh_act
        skip_dim = hidden_dim + self.input_size if skip_connection else hidden_dim

        def four_layer_mlp(first_in):
            # four Linear+ReLU pairs; layout matches the original Sequential
            # so state_dict keys (block1.0, block1.2, ...) are unchanged
            widths = [first_in] + [hidden_dim] * 4
            stages = []
            for d_in, d_out in zip(widths[:-1], widths[1:]):
                stages.extend([nn.Linear(d_in, d_out), nn.ReLU()])
            return nn.Sequential(*stages)

        self.block1 = four_layer_mlp(self.input_size)
        self.block2 = four_layer_mlp(skip_dim)
        self.block3 = nn.Linear(hidden_dim, 1)

        if geo_init:
            # output layer: weights tightly around 2*sqrt(pi)/sqrt(dim), bias
            # -0.5 (unit-sphere SDF); hidden layers: N(0, sqrt(2)/sqrt(dim)).
            # Initialization order (block3, block2, block1) preserved.
            init.normal_(self.block3.weight, mean=2 * np.sqrt(np.pi) / np.sqrt(hidden_dim), std=0.000001)
            init.constant_(self.block3.bias, -0.5)
            for block in (self.block2, self.block1):
                for m in block.modules():
                    if isinstance(m, nn.Linear):
                        init.normal_(m.weight, mean=0.0, std=np.sqrt(2) / np.sqrt(hidden_dim))
                        init.constant_(m.bias, 0.0)

    def forward(self, x):
        '''
        x: concatenated xyz and shape features, shape (B, N, latent_size + 3)
        returns: SDF prediction of shape (B, N, 1)
        '''
        hidden = self.block1(x)
        if self.skip_connection:
            hidden = torch.cat([x, hidden], dim=-1)
        out = self.block3(self.block2(hidden))
        return torch.tanh(out) if self.tanh_act else out
================================================
FILE: models/archs/unet.py
================================================
'''
Codes are from:
https://github.com/jaxony/unet-pytorch/blob/master/model.py
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from torch.nn import init
import numpy as np
def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """3x3 convolution; default padding preserves spatial size at stride 1."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=stride, padding=padding, bias=bias, groups=groups)
def upconv2x2(in_channels, out_channels, mode='transpose'):
    """2x spatial upsampling: learned transpose conv, or bilinear upsample + 1x1 conv."""
    if mode == 'transpose':
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    # non-parametric path; the 1x1 conv handles the channel projection
    # (out_channels is always going to be the same as in_channels here)
    return nn.Sequential(
        nn.Upsample(mode='bilinear', scale_factor=2),
        conv1x1(in_channels, out_channels),
    )
def conv1x1(in_channels, out_channels, groups=1):
    """1x1 convolution used for channel projection."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=1,
                     groups=groups, stride=1)
class DownConv(nn.Module):
    """
    Encoder stage: two 3x3 conv + ReLU layers, optionally followed by 2x2 max-pooling.

    forward returns (output, pre_pool_features); the pre-pool features are kept
    so the decoder can use them as skip connections.
    """

    def __init__(self, in_channels, out_channels, pooling=True):
        super(DownConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling
        self.conv1 = conv3x3(self.in_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)
        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        before_pool = F.relu(self.conv2(F.relu(self.conv1(x))))
        out = self.pool(before_pool) if self.pooling else before_pool
        return out, before_pool
class UpConv(nn.Module):
    """
    Decoder stage: upsample, merge with the encoder skip tensor, then two conv + ReLU.

    merge_mode 'concat' stacks skip and upsampled features along channels;
    'add' sums them (requires matching channel counts).
    """

    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', up_mode='transpose'):
        super(UpConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode
        self.upconv = upconv2x2(self.in_channels, self.out_channels,
                                mode=self.up_mode)
        # 'concat' doubles the channel count entering the first conv
        first_in = 2 * self.out_channels if self.merge_mode == 'concat' else self.out_channels
        self.conv1 = conv3x3(first_in, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

    def forward(self, from_down, from_up):
        """Forward pass.

        Arguments:
            from_down: skip tensor from the encoder pathway.
            from_up: upconv'd tensor from the previous (deeper) decoder stage.
        """
        up = self.upconv(from_up)
        merged = torch.cat((up, from_down), 1) if self.merge_mode == 'concat' else up + from_down
        return F.relu(self.conv2(F.relu(self.conv1(merged))))
class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597

    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).

    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by upmode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens with the convolution in
        the tranpose convolution (specified by upmode='transpose')
    """

    def __init__(self, num_classes, in_channels=3, depth=5,
                 start_filts=64, up_mode='transpose',
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            num_classes: int, number of output channels of the final 1x1 conv.
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of MaxPools in the U-Net.
            start_filts: int, number of convolutional filters for the
                first conv.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for nearest neighbour
                upsampling.
            merge_mode: 'concat' to stack skip features along channels,
                'add' to sum them.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))

        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            # BUG FIX: this message previously interpolated up_mode, so the
            # reported value never matched the invalid merge_mode argument.
            raise ValueError("\"{}\" is not a valid mode for"
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(merge_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list;
        # channel count doubles at every depth level
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts * (2 ** i)
            pooling = True if i < depth - 1 else False

            down_conv = DownConv(ins, outs, pooling=pooling)
            self.down_convs.append(down_conv)

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth - 1):
            ins = outs
            outs = ins // 2
            up_conv = UpConv(ins, outs, up_mode=up_mode,
                             merge_mode=merge_mode)
            self.up_convs.append(up_conv)

        # add the list of modules to current module so parameters register
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

        self.conv_final = conv1x1(outs, self.num_classes)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        # Xavier init for Conv2d layers only; ConvTranspose2d layers keep
        # PyTorch's default initialization (they are not matched here).
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            init.constant_(m.bias, 0)

    def reset_params(self):
        """Apply weight_init to every submodule."""
        for i, m in enumerate(self.modules()):
            self.weight_init(m)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the 1x1 head."""
        encoder_outs = []
        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)
        for i, module in enumerate(self.up_convs):
            # -(i+2): skip the deepest (un-pooled) encoder output
            before_pool = encoder_outs[-(i + 2)]
            x = module(before_pool, x)
        # No softmax is used. This means you need to use
        # nn.CrossEntropyLoss in your training script,
        # as that module includes a softmax already.
        x = self.conv_final(x)
        return x
if __name__ == "__main__":
    # Smoke test: feed an image containing a single NaN at the center and
    # report what fraction of the output becomes NaN — this visualizes how far
    # the network's receptive field propagates a single corrupted pixel.
    model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32)
    print(model)
    print(sum(p.numel() for p in model.parameters()))

    reso = 176
    center = int(reso / 2 - 1)
    x = np.zeros((1, 1, reso, reso))
    x[:, :, center, center] = np.nan
    out = model(torch.FloatTensor(x))
    print('%f' % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso)))
    # loss = torch.sum(out)
    # loss.backward()
================================================
FILE: models/autoencoder.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange, reduce
from typing import List, Callable, Union, Any, TypeVar, Tuple
Tensor = TypeVar('torch.tensor')
class BetaVAE(nn.Module):
    """Beta-VAE over tri-plane feature maps.

    Encoder: five stride-2 Conv/BN/LeakyReLU stages (input spatial resolution
    is assumed to be 64x64, giving 2x2 at the bottleneck — see fc_mu comment),
    followed by linear mu / log-var heads. Decoder mirrors the encoder with
    transposed convolutions and ends in tanh. Only the KL term is used in
    loss_function; reconstruction is supervised elsewhere in the pipeline.
    """
    num_iter = 0 # Global static variable to keep track of iterations

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: List = None,
                 kl_std=1.0,
                 beta: int = 4,
                 gamma: float = 10.,
                 max_capacity: int = 25,
                 Capacity_max_iter: int = 1e5, # 10000 in default configs
                 loss_type: str = 'B',
                 **kwargs) -> None:
        super(BetaVAE, self).__init__()

        self.latent_dim = latent_dim
        self.beta = beta            # kept from the reference Beta-VAE; not read by loss_function below
        self.gamma = gamma          # likewise not read here
        self.loss_type = loss_type
        self.C_max = torch.Tensor([max_capacity])
        self.C_stop_iter = Capacity_max_iter
        self.in_channels = in_channels
        self.kl_std = kl_std        # prior std, or the string 'zero_mean' for an L2 latent penalty
        #print("kl standard deviation: ", self.kl_std)

        modules = []
        if hidden_dims is None:
            #hidden_dims = [32, 64, 128, 256, 512]
            hidden_dims = [512, 512, 512, 512, 512]
        self.hidden_dims = hidden_dims

        # Build Encoder: each stage halves the spatial resolution (stride 2)
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim
        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim) # for plane features resolution 64x64, spatial resolution is 2x2 after the last encoder layer
        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim)

        # Build Decoder (mirror of the encoder)
        modules = []
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
        # NOTE: reverses self.hidden_dims in place (same list object as above)
        hidden_dims.reverse()
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride = 2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )
        self.decoder = nn.Sequential(*modules)

        # final upsampling stage restores the original channel count, tanh-bounded
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels= self.in_channels, # changed from 3 to in_channels
                      kernel_size= 3, padding= 1),
            nn.Tanh())
        #print(self)

    def encode(self, enc_input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :enc_input: (Tensor) Input tensor to encoder [B x D x resolution x resolution]
        :return: [mu, log_var], each of shape [B x latent_dim]
        """
        result = enc_input
        result = self.encoder(enc_input) # [B, D, 2, 2]
        result = torch.flatten(result, start_dim=1) # ([32, D*4])

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        '''
        z: latent vector: B, D (D = latent_dim*3)
        returns reconstructed plane features [B, in_channels, resolution, resolution]
        '''
        result = self.decoder_input(z) # ([32, D*4])
        result = result.view(-1, int(result.shape[-1]/4), 2, 2) # for plane features resolution 64x64, spatial resolution is 2x2 after the last encoder layer
        result = self.decoder(result)
        result = self.final_layer(result) # ([32, D, resolution, resolution])
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Draw z = mu + std * eps with eps ~ N(0, I) (a single sample per input).
        Will a single z be enough to compute the expectation
        for the loss??
        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) log-variance of the latent Gaussian
        :return: (Tensor) sampled latent of the same shape as mu
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, data: Tensor, **kwargs) -> List[Tensor]:
        """Encode, sample, decode; returns [reconstruction, data, mu, log_var, z]."""
        mu, log_var = self.encode(data)
        z = self.reparameterize(mu, log_var)
        return  [self.decode(z), data, mu, log_var, z]

    # only using VAE loss
    def loss_function(self,
                      *args,
                      **kwargs) -> Tensor:
        """Weighted KL term only (no reconstruction loss).

        args: output of forward() — [recons, data, mu, log_var, z].
        kwargs: must contain 'M_N', the KL weight for this minibatch.
        When kl_std == 'zero_mean', an L2 penalty on a sampled latent is used
        instead of the closed-form-style KL against N(0, kl_std^2 I).
        """
        self.num_iter += 1  # NOTE: creates an instance attribute shadowing the class counter
        recons = args[0]    # unused: reconstruction is not penalized here
        data = args[1]      # unused
        mu = args[2]
        log_var = args[3]
        kld_weight = kwargs['M_N'] # Account for the minibatch samples from the dataset

        #print("recon, data shape: ", recons.shape, data.shape)
        #recons_loss = F.mse_loss(recons, data)

        # kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)
        if self.kl_std == 'zero_mean':
            latent = self.reparameterize(mu, log_var)
            #print("latent shape: ", latent.shape) # (B, dim)
            # mean (over the batch) of the latent L2 norms
            l2_size_loss = torch.sum(torch.norm(latent, dim=-1))
            kl_loss = l2_size_loss / latent.shape[0]
        else:
            std = torch.exp(0.5 * log_var)
            # target prior N(0, kl_std^2 I)
            gt_dist = torch.distributions.normal.Normal( torch.zeros_like(mu), torch.ones_like(std)*self.kl_std )
            sampled_dist = torch.distributions.normal.Normal( mu, std )
            #gt_dist = normal_dist.sample(log_var.shape)
            #print("gt dist shape: ", gt_dist.shape)
            kl = torch.distributions.kl.kl_divergence(sampled_dist, gt_dist) # reversed KL
            kl_loss = reduce(kl, 'b ... -> b (...)', 'mean').mean()

        return kld_weight * kl_loss

    def sample(self,
               num_samples:int,
               **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map. NOTE: moves z to CUDA, so a GPU is required.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples, self.latent_dim)
        z = z.cuda()
        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)[0]

    def get_latent(self, x):
        '''
        given input x, return a sampled latent code
        x: [B x C x H x W]
        return: [B x latent_dim]
        '''
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return z
================================================
FILE: models/combined_model.py
================================================
import torch
import torch.utils.data
from torch.nn import functional as F
import pytorch_lightning as pl
# add paths in model/__init__.py for new models
from models import *
class CombinedModel(pl.LightningModule):
def __init__(self, specs):
    """Instantiate the sub-models requested by specs['training_task'].

    'modulation' -> SDF network + VAE; 'diffusion' -> diffusion network;
    'combined' -> all three, trained jointly.
    """
    super().__init__()
    self.specs = specs
    self.task = specs['training_task']  # 'combined' | 'modulation' | 'diffusion'

    if self.task in ('combined', 'modulation'):
        self.sdf_model = SdfModel(specs=specs)

        feature_dim = specs["SdfModelSpecs"]["latent_dim"]  # latent dim of pointnet
        modulation_dim = feature_dim * 3                    # tri-plane features -> 3x latent dim
        # std of the target gaussian distribution of the latent space
        latent_std = specs.get("latent_std", 0.25)
        hidden_dims = [modulation_dim] * 5
        self.vae_model = BetaVAE(in_channels=feature_dim * 3, latent_dim=modulation_dim,
                                 hidden_dims=hidden_dims, kl_std=latent_std)

    if self.task in ('combined', 'diffusion'):
        self.diffusion_model = DiffusionModel(model=DiffusionNet(**specs["diffusion_model_specs"]),
                                              **specs["diffusion_specs"])
def training_step(self, x, idx):
    """Dispatch one Lightning training step to the routine for the configured task."""
    step_fns = {
        'combined': self.train_combined,
        'modulation': self.train_modulation,
        'diffusion': self.train_diffusion,
    }
    step_fn = step_fns.get(self.task)
    # unknown tasks fall through to None, matching the original if/elif chain
    return step_fn(x) if step_fn is not None else None
def configure_optimizers(self):
    """Build the Adam optimizer; 'combined' uses separate LRs for SDF+VAE vs diffusion."""
    if self.task == 'combined':
        params_list = [
            {'params': list(self.sdf_model.parameters()) + list(self.vae_model.parameters()),
             'lr': self.specs['sdf_lr']},
            {'params': self.diffusion_model.parameters(),
             'lr': self.specs['diff_lr']},
        ]
    elif self.task == 'modulation':
        params_list = [
            {'params': self.parameters(), 'lr': self.specs['sdf_lr']},
        ]
    elif self.task == 'diffusion':
        params_list = [
            {'params': self.parameters(), 'lr': self.specs['diff_lr']},
        ]

    return {
        "optimizer": torch.optim.Adam(params_list),
        # "lr_scheduler": {
        #     "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=50000, threshold=0.0002, min_lr=1e-6, verbose=False),
        #     "monitor": "total"
        # }
    }
#-----------different training steps for sdf modulation, diffusion, combined----------
def train_modulation(self, x):
    """One modulation-training step: encode pc -> VAE -> decode -> SDF + KL loss.

    Returns the summed loss, or None to skip the batch when the VAE loss
    cannot be computed (e.g. NaNs in the latent distribution).
    """
    xyz = x['xyz']          # (B, N, 3) query coordinates
    gt = x['gt_sdf']        # (B, N) ground-truth SDF values
    pc = x['point_cloud']   # (B, 1024, 3)

    # STEP 1: obtain reconstructed plane feature and latent code
    plane_features = self.sdf_model.pointnet.get_plane_features(pc)
    original_features = torch.cat(plane_features, dim=1)
    out = self.vae_model(original_features)  # out = [self.decode(z), input, mu, log_var, z]
    reconstructed_plane_feature, latent = out[0], out[-1]

    # STEP 2: pass recon back to GenSDF pipeline
    pred_sdf = self.sdf_model.forward_with_plane_features(reconstructed_plane_feature, xyz)

    # STEP 3: losses for VAE and SDF
    # we only use the KL loss for the VAE; no reconstruction loss
    try:
        vae_loss = self.vae_model.loss_function(*out, M_N=self.specs["kld_weight"])
    except Exception as e:
        # was a bare `except:` — keep the batch-skipping behavior but surface
        # the actual failure instead of swallowing it silently
        print("vae loss failed at epoch {} ({}); skipping batch...".format(self.current_epoch, e))
        return None  # skips this batch

    sdf_loss = F.l1_loss(pred_sdf.squeeze(), gt.squeeze(), reduction='none')
    sdf_loss = reduce(sdf_loss, 'b ... -> b (...)', 'mean').mean()

    loss = sdf_loss + vae_loss
    loss_dict = {"sdf": sdf_loss, "vae": vae_loss}
    self.log_dict(loss_dict, prog_bar=True, enable_graph=False)
    return loss
def train_diffusion(self, x):
self.train()
pc = x['point_cloud'] # (B, 1024, 3) or False if unconditional
latent = x['latent'] # (B, D)
# unconditional training if cond is None
cond = pc if self.specs['diffusion_model_specs']['cond'] else None
# diff_100 and 1000 loss refers to the losses when t<100 and 100 b (...)', 'mean').mean()
# STEP 4: use latent as input to diffusion model
cond = pc if self.specs['diffusion_model_specs']['cond'] else None
diff_loss, diff_100_loss, diff_1000_loss, pred_latent, perturbed_pc = self.diffusion_model.diffusion_model_from_latent(latent, cond=cond)
# STEP 5: use predicted / reconstructed latent to run SDF loss
generated_plane_feature = self.vae_model.decode(pred_latent)
generated_sdf_pred = self.sdf_model.forward_with_plane_features(generated_plane_feature, xyz)
generated_sdf_loss = F.l1_loss(generated_sdf_pred.squeeze(), gt.squeeze())
# surface weight could prioritize points closer to surface but we did not notice better results when using it
#surface_weight = torch.exp(-50 * torch.abs(gt))
#generated_sdf_loss = torch.mean( F.l1_loss(generated_sdf_pred, gt, reduction='none') * surface_weight )
# we did not experiment with using constants/weights for each loss (VAE loss is weighted using value in specs file)
# results could potentially improve with a grid search
loss = sdf_loss + vae_loss + diff_loss + generated_sdf_loss
loss_dict = {
"total": loss,
"sdf": sdf_loss,
"vae": vae_loss,
"diff": diff_loss,
# diff_100 and 1000 loss refers to the losses when t<100 and 100 0 else 0.
x_T = x_start * alpha_next.sqrt() + \
c * pred_noise + \
sigma * noise
traj.append(x_T.clone())
if traj:
return x_T, traj
else:
return x_T
@torch.no_grad()
def sample(self, dim, batch_size, noise=None, clip_denoised = True, traj=False, cond=None):
    """Ancestral (DDPM) sampling loop.

    Args:
        dim: dimensionality of each sampled latent vector.
        batch_size: number of latents to sample.
        noise: optional starting noise; defaults to standard normal.
        clip_denoised: clamp the predicted x_start to [-1, 1] at each step.
        traj: NOTE(review) this flag is immediately shadowed by the local
            list below, so the trajectory is always collected and the
            (sample, trajectory) tuple is always returned — callers unpack
            `samp, _ = self.sample(...)` accordingly.
        cond: optional conditioning passed to the model alongside the sample.
    """
    batch, device, objective = batch_size, self.betas.device, self.objective
    traj = []  # shadows the `traj` parameter (see note above)
    x_T = default(noise, torch.randn(batch, dim, device = device))
    # iterate t = T-1 .. 0, denoising one step at a time
    for t in reversed(range(0, self.num_timesteps)):
        time_cond = torch.full((batch,), t, device = device, dtype = torch.long)
        model_input = (x_T, cond) if cond is not None else x_T
        pred_noise, x_start, *_ = self.model_predictions(model_input, time_cond)
        if clip_denoised:
            x_start.clamp_(-1., 1.)
        model_mean, _, model_log_variance = self.q_posterior(x_start = x_start, x_t = x_T, t = time_cond)
        noise = torch.randn_like(x_T) if t > 0 else 0. # no noise if t == 0
        x_T = model_mean + (0.5 * model_log_variance).exp() * noise
        traj.append(x_T.clone())
    # traj is non-empty after the loop, so this branch is always taken
    if traj:
        return x_T, traj
    else:
        return x_T
def q_posterior(self, x_start, x_t, t):
    """Gaussian posterior q(x_{t-1} | x_t, x_0): returns (mean, variance, clipped log-variance)."""
    shape = x_t.shape
    mean = (
        extract(self.posterior_mean_coef1, t, shape) * x_start
        + extract(self.posterior_mean_coef2, t, shape) * x_t
    )
    variance = extract(self.posterior_variance, t, shape)
    log_variance = extract(self.posterior_log_variance_clipped, t, shape)
    return mean, variance, log_variance
# "nice property": return x_t given x_0, noise, and timestep
def q_sample(self, x_start, t, noise=None):
    """Forward-diffuse x_start to timestep t ("nice property"):
    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
    """
    if noise is None:
        noise = torch.randn_like(x_start)
    #noise = torch.clamp(noise, min=-6.0, max=6.0)
    signal_scale = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
    noise_scale = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return signal_scale * x_start + noise_scale * noise
# main function for calculating loss
def forward(self, x_start, t, ret_pred_x=False, noise = None, cond=None):
'''
x_start: [B, D]
t: [B]
'''
noise = default(noise, lambda: torch.randn_like(x_start))
x = self.q_sample(x_start=x_start, t=t, noise=noise)
model_in = (x, cond) if cond is not None else x
model_out = self.model(model_in, t)
if self.objective == 'pred_noise':
target = noise
elif self.objective == 'pred_x0':
target = x_start
else:
raise ValueError(f'unknown objective {self.objective}')
loss = self.loss_fn(model_out, target, reduction = 'none')
#loss = reduce(loss, 'b ... -> b (...)', 'mean', b = x_start.shape[0]) # only one dim of latent so don't need this line
loss = loss * extract(self.p2_loss_weight, t, loss.shape)
unreduced_loss = loss.detach().clone().mean(dim=1)
if ret_pred_x:
return loss.mean(), x, target, model_out, unreduced_loss
else:
return loss.mean(), unreduced_loss
def model_predictions(self, model_input, t):
    """Run the model and convert its output to (pred_noise, pred_x_start).

    'pred_noise' derives x_start from the predicted noise; 'pred_x0' derives
    the noise from the predicted x_start.
    """
    # classifier-free-guidance variants kept for reference:
    #model_output1 = self.model(model_input, t, pass_cond=0)
    #model_output2 = self.model(model_input, t, pass_cond=1)
    #model_output = model_output2*5 - model_output1*4
    model_output = self.model(model_input, t, pass_cond=1)
    x = model_input[0] if type(model_input) is tuple else model_input

    if self.objective == 'pred_noise':
        pred_noise = model_output
        x_start = self.predict_start_from_noise(x, t, model_output)
    elif self.objective == 'pred_x0':
        pred_noise = self.predict_noise_from_start(x, t, model_output)
        x_start = model_output
    else:
        # previously an unknown objective fell through to an UnboundLocalError;
        # fail loudly and consistently with forward()
        raise ValueError(f'unknown objective {self.objective}')

    return ModelPrediction(pred_noise, x_start)
# a wrapper function that only takes x_start (clean modulation vector) and condition
# does everything including sampling timestep and returns loss, loss_100, loss_1000, prediction
def diffusion_model_from_latent(self, x_start, cond=None):
    """Sample timesteps, perturb the condition, and compute the diffusion loss.

    Wrapper that takes only the clean modulation vectors `x_start` and an
    optional condition. Returns (loss, loss_100, loss_1000, model_out, pc),
    where loss_100 / loss_1000 are detached means of the per-sample losses
    for t < 100 and t > 100 respectively.

    NOTE(review): samples at exactly t == 100 fall in neither logging bucket,
    and an empty bucket yields NaN from mean(); these values are logging-only
    (detached), so training itself is unaffected.
    """
    #if self.perturb_pc is None and cond is not None:
    #    print("check whether to pass condition!!!")
    # STEP 1: sample timestep
    t = torch.randint(0, self.num_timesteps, (x_start.shape[0],), device=x_start.device).long()
    # STEP 2: perturb condition
    pc = perturb_point_cloud(cond, self.perturb_pc, self.pc_size, self.crop_percent) if cond is not None else None
    # STEP 3: pass to forward function
    loss, x, target, model_out, unreduced_loss = self(x_start, t, cond=pc, ret_pred_x=True)
    loss_100 = unreduced_loss[t<100].mean().detach()
    loss_1000 = unreduced_loss[t>100].mean().detach()
    return loss, loss_100, loss_1000, model_out, pc
def generate_from_pc(self, pc, load_pc=False, batch=5, save_pc=False, return_pc=False, ddim=False, perturb_pc=True):
    """Generate `batch` latent samples conditioned on a point cloud.

    Args:
        pc: conditioning point cloud tensor (or a path when load_pc=True;
            None for unconditional sampling with no condition input).
        load_pc: treat `pc` as a path, loading/subsampling via sample_pc
            (moves it to CUDA).
        batch: number of samples to draw.
        save_pc: falsy, or a directory path — when set, the perturbed input
            cloud is written to "<save_pc>/input_pc.ply" for visualization.
        return_pc: also return the perturbed point cloud used as condition.
            NOTE(review): if pc is None and return_pc is True, `perturbed_pc`
            is never assigned and this raises NameError — confirm callers.
        ddim: use DDIM sampling instead of ancestral sampling.
        perturb_pc: perturb and subsample the cloud before conditioning.
    """
    self.eval()
    with torch.no_grad():
        if load_pc:
            pc = sample_pc(pc, self.pc_size).cuda().unsqueeze(0)
        if pc is None:
            input_pc = None
            save_pc = False
            full_perturbed_pc = None
        else:
            if perturb_pc:
                full_perturbed_pc = perturb_point_cloud(pc, self.perturb_pc)
                # random subset of pc_size points, repeated across the batch
                perturbed_pc = full_perturbed_pc[:, torch.randperm(full_perturbed_pc.shape[1])[:self.pc_size] ]
                input_pc = perturbed_pc.repeat(batch, 1, 1)
            else:
                full_perturbed_pc = pc
                perturbed_pc = pc
                input_pc = pc.repeat(batch, 1, 1)
        #print("shapes: ", pc.shape, self.pc_size, self.perturb_pc, perturbed_pc.shape, full_perturbed_pc.shape)
        #print("pc path: ", pc_path)
        #print("pc shape: ", perturbed_pc.shape, input_pc.shape)

        if save_pc: # save perturbed pc ply file for visualization
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(perturbed_pc.cpu().numpy().squeeze())
            o3d.io.write_point_cloud("{}/input_pc.ply".format(save_pc), pcd)

        sample_fn = self.ddim_sample if ddim else self.sample
        samp,_ = sample_fn(dim=self.model.dim_in_out, batch_size=batch, traj=False, cond=input_pc)

        if return_pc:
            return samp, perturbed_pc
        return samp
def generate_unconditional(self, num_samples):
    """Draw `num_samples` latents by running the full reverse chain with no condition."""
    self.eval()
    with torch.no_grad():
        latents, _ = self.sample(dim=self.model.dim_in_out, batch_size=num_samples, traj=False, cond=None)
    return latents
================================================
FILE: models/diffusion.py
================================================
import math
import copy
import torch
from torch import nn, einsum
import torch.nn.functional as F
from inspect import isfunction
from collections import namedtuple
from functools import partial
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
#from model.diffusion.model import *
from diff_utils.helpers import *
import numpy as np
import os
from statistics import mean
from tqdm.auto import tqdm
import open3d as o3d
# constants
# Named return type for DiffusionModel.model_predictions: both the predicted
# noise (epsilon) and the implied clean sample x_0 are always populated.
ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])
class DiffusionModel(nn.Module):
    """Gaussian diffusion (DDPM-style) over 1D latent modulation vectors.

    Wraps a denoising network `model` and precomputes the standard diffusion
    schedule buffers.  Supports ancestral sampling (`sample`) and DDIM
    sampling (`ddim_sample`); an optional point-cloud condition is forwarded
    to the wrapped network as a (x, cond) tuple.
    """

    def __init__(
        self,
        model,
        timesteps = 1000, sampling_timesteps = None, beta_schedule = 'cosine',
        sample_pc_size = 682, perturb_pc = None, crop_percent=0.25,
        loss_type = 'l2', objective = 'pred_x0',
        data_scale = 1.0, data_shift = 0.0,  # NOTE(review): currently unused — buffer registration is commented out below
        p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
        p2_loss_weight_k = 1,
        ddim_sampling_eta = 1.
    ):
        super().__init__()
        self.model = model
        # objective: 'pred_noise' (network predicts epsilon) or 'pred_x0'
        # (network predicts the clean latent); see forward()/model_predictions()
        self.objective = objective

        # beta / alpha schedule; anything other than 'linear' uses the cosine schedule
        betas = linear_beta_schedule(timesteps) if beta_schedule == 'linear' else cosine_beta_schedule(timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.pc_size = sample_pc_size      # number of points fed to the conditioning encoder
        self.perturb_pc = perturb_pc       # how to perturb the condition: None | "partial" | "noisy"
        self.crop_percent = crop_percent   # crop fraction passed to perturb_point_cloud
        assert self.perturb_pc in [None, "partial", "noisy"]

        self.loss_fn = F.l1_loss if loss_type=='l1' else F.mse_loss

        # sampling related parameters
        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
        assert self.sampling_timesteps <= timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # self.register_buffer('data_scale', torch.tensor(data_scale))
        # self.register_buffer('data_shift', torch.tensor(data_shift))

        # helper function to register buffer from float64 to float32
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer('posterior_variance', posterior_variance)
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # calculate p2 reweighting
        register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from x_t and predicted noise (inverse of q_sample)."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise epsilon from x_t and a predicted x_0."""
        return (
            (x0 - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    @torch.no_grad()
    def ddim_sample(self, dim, batch_size, noise=None, clip_denoised = True, traj=False, cond=None):
        """DDIM sampling loop over `self.sampling_timesteps` steps.

        NOTE(review): the `traj` argument is shadowed by the local list below,
        so this always returns a (sample, trajectory) tuple regardless of the
        value passed in — callers rely on `samp, _ = ...` unpacking.
        """
        batch, device, total_timesteps, sampling_timesteps, eta, objective = batch_size, self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

        # NOTE(review): unconventional schedule — linspace with sampling_timesteps + 2
        # points, last dropped, then alphas_cumprod_prev indexed by these times below;
        # confirm against a reference DDIM implementation before modifying.
        times = torch.linspace(0., total_timesteps, steps = sampling_timesteps + 2)[:-1]
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        traj = []
        x_T = default(noise, torch.randn(batch, dim, device = device))

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            alpha = self.alphas_cumprod_prev[time]
            alpha_next = self.alphas_cumprod_prev[time_next]
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)

            # condition is carried alongside x as a tuple when present
            model_input = (x_T, cond) if cond is not None else x_T
            pred_noise, x_start, *_ = self.model_predictions(model_input, time_cond)

            if clip_denoised:
                x_start.clamp_(-1., 1.)

            # DDIM update: deterministic part + eta-scaled stochastic part
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = ((1 - alpha_next) - sigma ** 2).sqrt()
            noise = torch.randn_like(x_T) if time_next > 0 else 0.
            x_T = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise
            traj.append(x_T.clone())

        if traj:
            return x_T, traj
        else:
            return x_T

    @torch.no_grad()
    def sample(self, dim, batch_size, noise=None, clip_denoised = True, traj=False, cond=None):
        """Ancestral (full-chain) sampling over all `num_timesteps` steps.

        NOTE(review): as in ddim_sample, the `traj` argument is shadowed by the
        local list and a (sample, trajectory) tuple is always returned.
        """
        batch, device, objective = batch_size, self.betas.device, self.objective
        traj = []
        x_T = default(noise, torch.randn(batch, dim, device = device))

        for t in reversed(range(0, self.num_timesteps)):
            time_cond = torch.full((batch,), t, device = device, dtype = torch.long)
            model_input = (x_T, cond) if cond is not None else x_T
            pred_noise, x_start, *_ = self.model_predictions(model_input, time_cond)

            if clip_denoised:
                x_start.clamp_(-1., 1.)

            # posterior q(x_{t-1} | x_t, x_0) parameterizes the reverse step
            model_mean, _, model_log_variance = self.q_posterior(x_start = x_start, x_t = x_T, t = time_cond)
            noise = torch.randn_like(x_T) if t > 0 else 0. # no noise if t == 0
            x_T = model_mean + (0.5 * model_log_variance).exp() * noise
            traj.append(x_T.clone())

        if traj:
            return x_T, traj
        else:
            return x_T

    def q_posterior(self, x_start, x_t, t):
        """Mean / variance / clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    # "nice property": return x_t given x_0, noise, and timestep
    def q_sample(self, x_start, t, noise=None):
        """Forward-diffuse x_0 to x_t in closed form."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        #noise = torch.clamp(noise, min=-6.0, max=6.0)
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    # main function for calculating loss
    def forward(self, x_start, t, ret_pred_x=False, noise = None, cond=None):
        '''
        Compute the p2-reweighted training loss at timesteps t.

        x_start: [B, D] clean latent vectors
        t: [B] integer timesteps
        ret_pred_x: if True, also return (x_t, target, model_out)
        cond: optional conditioning tensor, paired with x_t as a tuple
        Returns (loss.mean(), unreduced per-sample loss) or the extended tuple.
        '''
        noise = default(noise, lambda: torch.randn_like(x_start))
        x = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_in = (x, cond) if cond is not None else x
        model_out = self.model(model_in, t)

        # regression target depends on the parameterization
        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = self.loss_fn(model_out, target, reduction = 'none')
        #loss = reduce(loss, 'b ... -> b (...)', 'mean', b = x_start.shape[0]) # only one dim of latent so don't need this line
        loss = loss * extract(self.p2_loss_weight, t, loss.shape)
        # per-sample loss kept for timestep-bucketed logging
        unreduced_loss = loss.detach().clone().mean(dim=1)

        if ret_pred_x:
            return loss.mean(), x, target, model_out, unreduced_loss
        else:
            return loss.mean(), unreduced_loss

    def model_predictions(self, model_input, t):
        """Run the network and return both epsilon and x_0 predictions.

        The commented-out lines are a classifier-free-guidance variant;
        pass_cond=1 always uses the condition here (presumably a flag the
        wrapped network understands — confirm against its forward signature).
        """
        #model_output1 = self.model(model_input, t, pass_cond=0)
        #model_output2 = self.model(model_input, t, pass_cond=1)
        #model_output = model_output2*5 - model_output1*4
        model_output = self.model(model_input, t, pass_cond=1)
        x = model_input[0] if type(model_input) is tuple else model_input

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, model_output)
        elif self.objective == 'pred_x0':
            pred_noise = self.predict_noise_from_start(x, t, model_output)
            x_start = model_output

        return ModelPrediction(pred_noise, x_start)

    # a wrapper function that only takes x_start (clean modulation vector) and condition
    # does everything including sampling timestep and returns loss, loss_100, loss_1000, prediction
    def diffusion_model_from_latent(self, x_start, cond=None):
        #if self.perturb_pc is None and cond is not None:
        #    print("check whether to pass condition!!!")

        # STEP 1: sample timestep
        t = torch.randint(0, self.num_timesteps, (x_start.shape[0],), device=x_start.device).long()

        # STEP 2: perturb condition
        pc = perturb_point_cloud(cond, self.perturb_pc, self.pc_size, self.crop_percent) if cond is not None else None

        # STEP 3: pass to forward function
        loss, x, target, model_out, unreduced_loss = self(x_start, t, cond=pc, ret_pred_x=True)

        # NOTE(review): t == 100 falls into neither bucket (strict < and >)
        loss_100 = unreduced_loss[t<100].mean().detach()
        loss_1000 = unreduced_loss[t>100].mean().detach()

        return loss, loss_100, loss_1000, model_out, pc

    def generate_from_pc(self, pc, load_pc=False, batch=5, save_pc=False, return_pc=False, ddim=False, perturb_pc=True):
        """Generate `batch` latent samples conditioned on a point cloud.

        save_pc doubles as a flag and an output-directory path (used in the
        format string below).
        NOTE(review): when pc is None and return_pc=True, `perturbed_pc` is
        never assigned and the final return raises NameError.
        """
        self.eval()
        with torch.no_grad():
            if load_pc:
                pc = sample_pc(pc, self.pc_size).cuda().unsqueeze(0)

            if pc is None:
                input_pc = None
                save_pc = False
                full_perturbed_pc = None
            else:
                if perturb_pc:
                    full_perturbed_pc = perturb_point_cloud(pc, self.perturb_pc)
                    # random subsample down to the encoder input size
                    perturbed_pc = full_perturbed_pc[:, torch.randperm(full_perturbed_pc.shape[1])[:self.pc_size] ]
                    input_pc = perturbed_pc.repeat(batch, 1, 1)
                else:
                    full_perturbed_pc = pc
                    perturbed_pc = pc
                    input_pc = pc.repeat(batch, 1, 1)

            if save_pc: # save perturbed pc ply file for visualization
                pcd = o3d.geometry.PointCloud()
                pcd.points = o3d.utility.Vector3dVector(perturbed_pc.cpu().numpy().squeeze())
                o3d.io.write_point_cloud("{}/input_pc.ply".format(save_pc), pcd)

            sample_fn = self.ddim_sample if ddim else self.sample
            samp,_ = sample_fn(dim=self.model.dim_in_out, batch_size=batch, traj=False, cond=input_pc)
            if return_pc:
                return samp, perturbed_pc
            return samp

    def generate_unconditional(self, num_samples):
        """Draw `num_samples` latents with no conditioning."""
        self.eval()
        with torch.no_grad():
            samp,_ = self.sample(dim=self.model.dim_in_out, batch_size=num_samples, traj=False, cond=None)
            return samp
================================================
FILE: models/sdf_model.py
================================================
#!/usr/bin/env python3
import torch.nn as nn
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import sys
import os
from pathlib import Path
import numpy as np
import math
from einops import rearrange, reduce
from models.archs.sdf_decoder import *
from models.archs.encoders.conv_pointnet import ConvPointnet
from utils import mesh, evaluate
class SdfModel(pl.LightningModule):
    """Convolutional-PointNet encoder + MLP SDF decoder, trained with L1 loss.

    Reads its architecture hyperparameters from specs["SdfModelSpecs"].
    """

    def __init__(self, specs):
        super().__init__()
        self.specs = specs
        model_specs = self.specs["SdfModelSpecs"]
        self.hidden_dim = model_specs["hidden_dim"]
        self.latent_dim = model_specs["latent_dim"]
        self.skip_connection = model_specs.get("skip_connection", True)
        self.tanh_act = model_specs.get("tanh_act", False)
        # encoder hidden width defaults to the latent width
        self.pn_hidden = model_specs.get("pn_hidden_dim", self.latent_dim)

        self.pointnet = ConvPointnet(c_dim=self.latent_dim, hidden_dim=self.pn_hidden, plane_resolution=64)
        self.model = SdfDecoder(latent_size=self.latent_dim, hidden_dim=self.hidden_dim, skip_connection=self.skip_connection, tanh_act=self.tanh_act)
        self.model.train()

    def configure_optimizers(self):
        """Single Adam optimizer over all parameters at specs["sdf_lr"]."""
        return torch.optim.Adam(self.parameters(), self.specs["sdf_lr"])

    def training_step(self, x, idx):
        """One step: predict SDF at query points and return the mean L1 loss."""
        query_xyz = x['xyz']           # (B, 16000, 3)
        target_sdf = x['gt_sdf']       # (B, 16000)
        surface_pc = x['point_cloud']  # (B, 1024, 3)

        point_features = self.pointnet(surface_pc, query_xyz)
        predicted_sdf = self.model(query_xyz, point_features)

        per_point = F.l1_loss(predicted_sdf.squeeze(), target_sdf.squeeze(), reduction = 'none')
        return reduce(per_point, 'b ... -> b (...)', 'mean').mean()

    def forward(self, pc, xyz):
        """Predict SDF values at `xyz` conditioned on point cloud `pc`."""
        features = self.pointnet(pc, xyz)
        return self.model(xyz, features).squeeze()

    def forward_with_plane_features(self, plane_features, xyz):
        '''
        Predict SDF values from precomputed triplane features.

        plane_features: B, D*3, res, res (e.g. B, 768, 64, 64)
        xyz: B, N, 3
        '''
        point_features = self.pointnet.forward_with_plane_features(plane_features, xyz) # point_features: B, N, D
        return self.model( torch.cat((xyz, point_features),dim=-1) ) # [B, num_points]
================================================
FILE: test.py
================================================
#!/usr/bin/env python3
import torch
import torch.utils.data
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from pytorch_lightning import loggers as pl_loggers
import os
import json, csv
import time
from tqdm.auto import tqdm
from einops import rearrange, reduce
import numpy as np
import trimesh
import warnings
# add paths in model/__init__.py for new models
from models import *
from utils import mesh, evaluate
from utils.reconstruct import *
from diff_utils.helpers import *
#from metrics.evaluation_metrics import *#compute_all_metrics
#from metrics import evaluation_metrics
from dataloader.pc_loader import PCloader
@torch.no_grad()
def test_modulations():
    """Reconstruct meshes for the test split and export filtered latents.

    Uses module-level globals: specs, args, recon_dir, latent_dir.
    For each test point cloud: encode to triplane features, pass through the
    VAE, run marching cubes, log chamfer distance, and (if the reconstruction
    passes the CD filter) save the 1D latent for stage-2 diffusion training.
    """
    # load dataset, dataloader, model checkpoint
    test_split = json.load(open(specs["TestSplit"]))
    test_dataset = PCloader(specs["DataSource"], test_split, pc_size=specs.get("PCsize",1024), return_filename=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, num_workers=0)

    ckpt = "{}.ckpt".format(args.resume) if args.resume=='last' else "epoch={}.ckpt".format(args.resume)
    resume = os.path.join(args.exp_dir, ckpt)
    model = CombinedModel.load_from_checkpoint(resume, specs=specs).cuda().eval()

    # filename for logging chamfer distances of reconstructed meshes
    cd_file = os.path.join(recon_dir, "cd.csv")

    with tqdm(test_dataloader) as pbar:
        for idx, data in enumerate(pbar):
            pbar.set_description("Files evaluated: {}/{}".format(idx, len(test_dataloader)))

            point_cloud, filename = data # filename = path to the csv file of sdf data
            filename = filename[0] # filename is a tuple

            # dataset layout is .../<class>/<mesh>/sdf_data.csv
            cls_name = filename.split("/")[-3]
            mesh_name = filename.split("/")[-2]
            outdir = os.path.join(recon_dir, "{}/{}".format(cls_name, mesh_name))
            os.makedirs(outdir, exist_ok=True)
            mesh_filename = os.path.join(outdir, "reconstruct")

            # given point cloud, create modulations (e.g. 1D latent vectors)
            plane_features = model.sdf_model.pointnet.get_plane_features(point_cloud.cuda())  # tuple, 3 items with ([1, D, resolution, resolution])
            plane_features = torch.cat(plane_features, dim=1) # ([1, D*3, resolution, resolution])
            recon = model.vae_model.generate(plane_features) # ([1, D*3, resolution, resolution])

            # N is the grid resolution for marching cubes; set max_batch to largest number gpu can hold
            mesh.create_mesh(model.sdf_model, recon, mesh_filename, N=256, max_batch=2**21, from_plane_features=True)

            # load the created mesh (mesh_filename), and compare with input point cloud
            # to calculate and log chamfer distance
            mesh_log_name = cls_name+"/"+mesh_name
            try:
                evaluate.main(point_cloud, mesh_filename, cd_file, mesh_log_name)
            except Exception as e:
                # best-effort: a failed evaluation should not abort the whole run
                print(e)

            # save modulation vectors for training diffusion model for next stage
            # filter based on the chamfer distance so that all training data for diffusion model is clean
            # would recommend visualizing some reconstructed meshes and manually determining what chamfer distance threshold to use
            try:
                # skips modulations that have chamfer distance > 0.0018
                # the filter also weighs gaps / empty space higher
                if not filter_threshold(mesh_filename, point_cloud, 0.0018):
                    continue
                outdir = os.path.join(latent_dir, "{}/{}".format(cls_name, mesh_name))
                os.makedirs(outdir, exist_ok=True)
                features = model.sdf_model.pointnet.get_plane_features(point_cloud.cuda())
                features = torch.cat(features, dim=1) # ([1, D*3, resolution, resolution])
                latent = model.vae_model.get_latent(features) # (1, D*3)
                np.savetxt(os.path.join(outdir, "latent.txt"), latent.cpu().numpy())
            except Exception as e:
                print(e)
@torch.no_grad()
def test_generation():
    """Generate shapes with the trained diffusion model (conditional or not).

    Uses module-level globals: specs, args, recon_dir.
    """
    # load model
    if args.resume == 'finetune': # after second stage of training
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # loads the sdf and vae models
            model = CombinedModel.load_from_checkpoint(specs["modulation_ckpt_path"], specs=specs, strict=False)

            # loads the diffusion model; directly calling diffusion_model.load_state_dict to prevent overwriting sdf and vae params
            ckpt = torch.load(specs["diffusion_ckpt_path"])
            new_state_dict = {}
            for k,v in ckpt['state_dict'].items():
                new_key = k.replace("diffusion_model.", "") # remove "diffusion_model." from keys since directly loading into diffusion model
                new_state_dict[new_key] = v
            model.diffusion_model.load_state_dict(new_state_dict)

            model = model.cuda().eval()
    else:
        ckpt = "{}.ckpt".format(args.resume) if args.resume=='last' else "epoch={}.ckpt".format(args.resume)
        resume = os.path.join(args.exp_dir, ckpt)
        model = CombinedModel.load_from_checkpoint(resume, specs=specs).cuda().eval()

    conditional = specs["diffusion_model_specs"]["cond"]

    if not conditional:
        # unconditional: sample latents, decode to triplane features, mesh each
        samples = model.diffusion_model.generate_unconditional(args.num_samples)
        plane_features = model.vae_model.decode(samples)
        for i in range(len(plane_features)):
            plane_feature = plane_features[i].unsqueeze(0)
            mesh.create_mesh(model.sdf_model, plane_feature, recon_dir+"/{}_recon".format(i), N=128, max_batch=2**21, from_plane_features=True)
    else:
        # load dataset, dataloader, model checkpoint
        test_split = json.load(open(specs["TestSplit"]))
        test_dataset = PCloader(specs["DataSource"], test_split, pc_size=specs.get("PCsize",1024), return_filename=True)
        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, num_workers=0)

        with tqdm(test_dataloader) as pbar:
            for idx, data in enumerate(pbar):
                pbar.set_description("Files generated: {}/{}".format(idx, len(test_dataloader)))

                point_cloud, filename = data # filename = path to the csv file of sdf data
                filename = filename[0] # filename is a tuple

                cls_name = filename.split("/")[-3]
                mesh_name = filename.split("/")[-2]
                outdir = os.path.join(recon_dir, "{}/{}".format(cls_name, mesh_name))
                os.makedirs(outdir, exist_ok=True)

                # filter, set threshold manually after a few visualizations
                if args.filter:
                    threshold = 0.08
                    tmp_lst = []
                    count = 0
                    # NOTE(review): the next line was corrupted during extraction —
                    # the original sampling/consistency-computation code inside this
                    # loop is missing, and this statement is not valid Python as-is.
                    # Recover the original body from upstream version control.
                    while len(tmp_lst) b', 'mean', b = consistency.shape[0]) # one value per generated sample
                        #print("consistency shape: ", consistency.shape, loss.shape, consistency[0].mean(), consistency[1].mean(), loss)  # cons: [B,N]; loss: [B]
                        thresh_idx = loss<=threshold
                        tmp_lst.extend(plane_features[thresh_idx])
                        if count > 5: # repeat this filtering process as needed
                            break
                    # skip the point cloud if cannot produce consistent samples or
                    # just use the samples that are produced if comparing to other methods
                    if len(tmp_lst)<1:
                        continue
                    plane_features = tmp_lst[0:min(10,len(tmp_lst))]
                else:
                    # for each point cloud, the partial pc and its conditional generations are all saved in the same directory
                    samples, perturbed_pc = model.diffusion_model.generate_from_pc(point_cloud.cuda(), batch=args.num_samples, save_pc=outdir, return_pc=True)
                    plane_features = model.vae_model.decode(samples)

                for i in range(len(plane_features)):
                    plane_feature = plane_features[i].unsqueeze(0)
                    mesh.create_mesh(model.sdf_model, plane_feature, outdir+"/{}_recon".format(i), N=128, max_batch=2**21, from_plane_features=True)
# CLI entry point. The parsed `args`, `specs`, `recon_dir`, and (for the
# modulation task) `latent_dir` are module-level globals consumed by
# test_modulations() / test_generation() above.
if __name__ == "__main__":
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--exp_dir", "-e", required=True,
        help="This directory should include experiment specifications in 'specs.json,' and logging will be done in this directory as well.",
    )
    arg_parser.add_argument(
        "--resume", "-r", default=None,
        help="continue from previous saved logs, integer value, 'last', or 'finetune'",
    )
    arg_parser.add_argument("--num_samples", "-n", default=5, type=int, help='number of samples to generate and reconstruct')
    arg_parser.add_argument("--filter", default=False, help='whether to filter when sampling conditionally')

    args = arg_parser.parse_args()
    # experiment configuration lives next to the checkpoints in exp_dir
    specs = json.load(open(os.path.join(args.exp_dir, "specs.json")))
    print(specs["Description"])

    recon_dir = os.path.join(args.exp_dir, "recon")
    os.makedirs(recon_dir, exist_ok=True)

    # dispatch on the training task declared in specs.json
    if specs['training_task'] == 'modulation':
        latent_dir = os.path.join(args.exp_dir, "modulations")
        os.makedirs(latent_dir, exist_ok=True)
        test_modulations()
    elif specs['training_task'] == 'combined':
        test_generation()
================================================
FILE: train.py
================================================
#!/usr/bin/env python3
import torch
import torch.utils.data
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from pytorch_lightning import loggers as pl_loggers
import os
import json, csv
import time
from tqdm.auto import tqdm
from einops import rearrange, reduce
import numpy as np
import trimesh
import warnings
# add paths in model/__init__.py for new models
from models import *
from utils import mesh, evaluate
from utils.reconstruct import *
from diff_utils.helpers import *
#from metrics.evaluation_metrics import *#compute_all_metrics
#from metrics import evaluation_metrics
from dataloader.pc_loader import PCloader
from dataloader.sdf_loader import SdfLoader
from dataloader.modulation_loader import ModulationLoader
def train():
    """Train the model for the task declared in specs (modulation / diffusion /
    combined), using PyTorch Lightning.

    Uses module-level globals `args` (CLI) and `specs` (experiment config).
    """
    # initialize dataset and loader
    split = json.load(open(specs["TrainSplit"], "r"))
    if specs['training_task'] == 'diffusion':
        # stage 2: train diffusion over precomputed modulation latents
        train_dataset = ModulationLoader(specs["data_path"], pc_path=specs.get("pc_path",None), split_file=split, pc_size=specs.get("total_pc_size", None))
    else:
        # stage 1 / end-to-end: train from raw SDF samples
        train_dataset = SdfLoader(specs["DataSource"], split, pc_size=specs.get("PCsize",1024), grid_source=specs.get("GridSource", None), modulation_path=specs.get("modulation_path", None))

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size, num_workers=args.workers,
        drop_last=True, shuffle=True, pin_memory=True, persistent_workers=True
    )

    # creates a copy of current code / files in the config folder
    save_code_to_conf(args.exp_dir)

    # pytorch lightning callbacks
    callback = ModelCheckpoint(dirpath=args.exp_dir, filename='{epoch}', save_top_k=-1, save_last=True, every_n_epochs=specs["log_freq"])
    lr_monitor = LearningRateMonitor(logging_interval='step')
    callbacks = [callback, lr_monitor]

    model = CombinedModel(specs)

    # note on loading from checkpoint:
    # if resuming from training modulation, diffusion, or end-to-end, just load saved checkpoint
    # however, if fine-tuning end-to-end after training modulation and diffusion separately, will need to load sdf and diffusion checkpoints separately
    if args.resume == 'finetune':
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # loads the sdf and vae weights (strict=False: diffusion keys absent)
            model = model.load_from_checkpoint(specs["modulation_ckpt_path"], specs=specs, strict=False)

            # loads the diffusion model; directly calling diffusion_model.load_state_dict to prevent overwriting sdf and vae params
            ckpt = torch.load(specs["diffusion_ckpt_path"])
            new_state_dict = {}
            for k,v in ckpt['state_dict'].items():
                new_key = k.replace("diffusion_model.", "") # remove "diffusion_model." from keys since directly loading into diffusion model
                new_state_dict[new_key] = v
            model.diffusion_model.load_state_dict(new_state_dict)
        # weights already loaded above, so do not also pass a ckpt_path to fit()
        resume = None
    elif args.resume is not None:
        ckpt = "{}.ckpt".format(args.resume) if args.resume=='last' else "epoch={}.ckpt".format(args.resume)
        resume = os.path.join(args.exp_dir, ckpt)
    else:
        resume = None

    # precision 16 can be unstable (nan loss); recommend using 32
    trainer = pl.Trainer(accelerator='gpu', devices=-1, precision=32, max_epochs=specs["num_epochs"], callbacks=callbacks, log_every_n_steps=1,
                         default_root_dir=os.path.join("tensorboard_logs", args.exp_dir))
    trainer.fit(model=model, train_dataloaders=train_dataloader, ckpt_path=resume)
# CLI entry point. The parsed `args` and `specs` are module-level globals
# consumed by train() above.
if __name__ == "__main__":
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--exp_dir", "-e", required=True,
        help="This directory should include experiment specifications in 'specs.json,' and logging will be done in this directory as well.",
    )
    arg_parser.add_argument(
        "--resume", "-r", default=None,
        help="continue from previous saved logs, integer value, 'last', or 'finetune'",
    )
    arg_parser.add_argument("--batch_size", "-b", default=32, type=int)
    arg_parser.add_argument( "--workers", "-w", default=8, type=int)

    args = arg_parser.parse_args()
    specs = json.load(open(os.path.join(args.exp_dir, "specs.json")))
    print(specs["Description"])

    train()
================================================
FILE: utils/__init__.py
================================================
#!/usr/bin/env python3
================================================
FILE: utils/chamfer.py
================================================
import numpy as np
from scipy.spatial import cKDTree as KDTree
def compute_trimesh_chamfer(gt_points, gen_points, offset=0, scale=1):
    """
    Symmetric chamfer distance between two point sets: the sum of the mean
    squared nearest-neighbor distances taken in both directions.

    gt_points: (N, 3) numpy array of ground-truth surface samples
    gen_points: (M, 3) numpy array of generated points; it is first mapped
        through gen_points / scale - offset to undo any normalization
    """
    aligned = gen_points / scale - offset

    def one_sided(src, dst):
        # mean squared distance from each src point to its nearest dst point
        distances, _ = KDTree(dst).query(src)
        return np.mean(np.square(distances))

    return one_sided(gt_points, aligned) + one_sided(aligned, gt_points)
def scale_to_unit_sphere(points):
    """
    Center a point cloud on its bounding-box midpoint and scale it so the
    farthest point lies on the unit sphere.

    :param points: (n, 3) numpy array
    :return: (n, 3) numpy array contained in the unit sphere
    """
    bbox_center = (np.max(points, axis=0) + np.min(points, axis=0)) / 2
    centered = points - bbox_center
    # radius of the farthest point from the bbox center
    radius = np.max(np.sqrt(np.sum(centered ** 2, axis=1)))
    return centered / radius
================================================
FILE: utils/evaluate.py
================================================
#!/usr/bin/env python3
import argparse
import logging
import json
import numpy as np
import pandas as pd
import os, sys
import trimesh
from scipy.spatial import cKDTree as KDTree
from utils import uhd, tmd
import csv
def main(gt_pc, recon_mesh, out_file, mesh_name, return_value=False, return_sampled_pc=False, prioritize_cov=False, pc_size=None):
    """Compute the symmetric chamfer distance between a GT point cloud and a
    reconstructed mesh, and append it to a CSV log.

    gt_pc: torch tensor point cloud (moved to cpu/numpy and squeezed below)
    recon_mesh: mesh path relative to cwd, WITHOUT the ".ply" suffix
    out_file: CSV file (appended) receiving [mesh_name, chamfer]
    return_value: if True, return the chamfer without writing the CSV
    return_sampled_pc: if True, also return a surface sampling of the mesh
    prioritize_cov: weigh the gt->recon direction higher (penalizes holes)
    pc_size: optional size for the returned surface sampling
    """
    gt_pc = gt_pc.cpu().detach().numpy().squeeze()
    recon_mesh = trimesh.load(os.path.join(os.getcwd(), recon_mesh)+".ply")
    # sample as many surface points as there are GT points
    recon_pc, _ = trimesh.sample.sample_surface(recon_mesh, gt_pc.shape[0])
    full_recon_pc = trimesh.sample.sample_surface(recon_mesh, pc_size)[0] if pc_size is not None else recon_pc

    # gt -> recon direction
    recon_kd_tree = KDTree(recon_pc)
    one_distances, one_vertex_ids = recon_kd_tree.query(gt_pc)
    gt_to_recon_chamfer = np.mean(np.square(one_distances))

    # other direction
    gt_kd_tree = KDTree(gt_pc)
    two_distances, two_vertex_ids = gt_kd_tree.query(recon_pc)
    recon_to_gt_chamfer = np.mean(np.square(two_distances))

    if prioritize_cov: # higher CD for gaps/holes
        loss_chamfer = gt_to_recon_chamfer * 2.0 + recon_to_gt_chamfer * 0.5
    else:
        loss_chamfer = gt_to_recon_chamfer + recon_to_gt_chamfer

    if return_value:
        return loss_chamfer

    # append one row per mesh to the running CD log
    out_file = os.path.join(os.getcwd(), out_file)
    with open(out_file,"a",) as f:
        writer = csv.writer(f)
        writer.writerow([mesh_name,loss_chamfer])

    if return_sampled_pc:
        return full_recon_pc, loss_chamfer
def calc_cd(gt_pc, recon_pc):
    """Symmetric chamfer distance between a torch point cloud and a numpy one.

    gt_pc: torch tensor, squeezed to (N, 3) after moving to cpu/numpy
    recon_pc: (M, 3) array-like of reconstructed points
    Returns the sum of both one-sided mean squared nearest-neighbor distances.
    """
    reference = gt_pc.cpu().detach().numpy().squeeze()

    def one_way(src, dst):
        # mean squared distance from each src point to its nearest dst point
        dists, _ = KDTree(dst).query(src)
        return np.mean(np.square(dists))

    return one_way(reference, recon_pc) + one_way(recon_pc, reference)
def single_eval(gt_csv, recon_mesh):
    """Compute and print the symmetric chamfer distance between a ground-truth
    point set (SIREN-style .xyz text file) and a reconstructed mesh file.

    gt_csv: path to a whitespace-separated file whose first 3 columns are xyz
    recon_mesh: path to a mesh file loadable by trimesh

    The GT points are normalized per-axis into [-1, 1] (SIREN convention)
    before comparison, then 30000 are randomly subsampled.
    """
    recon_mesh = trimesh.load( recon_mesh )
    recon_pc, _ = trimesh.sample.sample_surface(recon_mesh, 30000)
    print("recon pc min max: ", recon_pc.max(), recon_pc.min())

    # load from SIREN .xyz file
    f = np.genfromtxt(gt_csv)
    pc = f[:,:3]

    # per-axis normalization into [-1, 1]
    coord_max = np.amax(pc, axis=0, keepdims=True)
    coord_min = np.amin(pc, axis=0, keepdims=True)
    coords = (pc - coord_min) / (coord_max - coord_min)
    coords -= 0.5
    coords *= 2.
    f = coords
    print("f min max: ", f.max(), f.min())

    pc_idx = np.random.choice(f.shape[0], 30000, replace=False)
    gt_pc = f[pc_idx]

    # fix: the mesh was previously re-loaded and re-sampled here for no reason;
    # the surface sampling above is reused instead.
    recon_kd_tree = KDTree(recon_pc)
    one_distances, one_vertex_ids = recon_kd_tree.query(gt_pc)
    gt_to_recon_chamfer = np.mean(np.square(one_distances))

    # other direction
    gt_kd_tree = KDTree(gt_pc)
    two_distances, two_vertex_ids = gt_kd_tree.query(recon_pc)
    recon_to_gt_chamfer = np.mean(np.square(two_distances))

    loss_chamfer = gt_to_recon_chamfer + recon_to_gt_chamfer
    print("CD loss: ", loss_chamfer)
# CLI entry point: python evaluate.py <gt_xyz_file> <recon_mesh_path>
if __name__ == "__main__":
    single_eval(sys.argv[1], sys.argv[2])
================================================
FILE: utils/mesh.py
================================================
#!/usr/bin/env python3
import logging
import math
import numpy as np
import plyfile
import skimage.measure
import time
import torch
# N: resolution of grid; 256 is typically sufficient
# max batch: as large as GPU memory will allow
# shape_feature is either point cloud, mesh_idx (neuralpull), or generated latent code (deepsdf)
def create_mesh(
    model, shape_feature, filename, N=256, max_batch=1000000, level_set=0.0, occupancy=False, point_cloud=None, from_plane_features=False, from_pc_features=False
):
    """Run marching cubes over a regular [-1, 1]^3 grid of SDF predictions.

    model: predictor exposing forward / forward_with_plane_features (queried on
        cuda — assumes a GPU is available; TODO confirm)
    shape_feature: conditioning input (point cloud, mesh index, or latent code)
    filename: output path; ".ply" is appended
    N: grid resolution per axis; max_batch: queries per forward pass
    level_set: iso-level passed to marching cubes
    occupancy: if True, treat predictions as occupancies (shifted by 0.5)
    NOTE(review): `point_cloud` and `from_pc_features` are accepted but never
    used in this body.
    """
    start_time = time.time()
    ply_filename = filename
    model.eval()

    # the voxel_origin is the (bottom, left, down) corner, not the middle
    voxel_origin = [-1, -1, -1]
    voxel_size = 2.0 / (N - 1)

    cube = create_cube(N)  # (N^3, 4): xyz coordinates + slot for the SDF value
    cube_points = cube.shape[0]
    head = 0
    # evaluate the SDF in max_batch-sized chunks to bound GPU memory
    while head < cube_points:
        query = cube[head : min(head + max_batch, cube_points), 0:3].unsqueeze(0)

        # inference defined in forward function per pytorch lightning convention
        if from_plane_features:
            pred_sdf = model.forward_with_plane_features(shape_feature.cuda(), query.cuda()).detach().cpu()
        else:
            pred_sdf = model(shape_feature.cuda(), query.cuda()).detach().cpu()

        cube[head : min(head + max_batch, cube_points), 3] = pred_sdf.squeeze()
        head += max_batch

    # for occupancy instead of SDF, subtract 0.5 so the surface boundary becomes 0
    sdf_values = cube[:, 3] - 0.5 if occupancy else cube[:, 3]
    sdf_values = sdf_values.reshape(N, N, N)

    #print("inference time: {}".format(time.time() - start_time))
    convert_sdf_samples_to_ply(
        sdf_values.data,
        voxel_origin,
        voxel_size,
        ply_filename + ".ply",
        level_set
    )
# create cube from (-1,-1,-1) to (1,1,1) and uniformly sample points for marching cube
def create_cube(N):
    """Build an (N^3, 4) tensor of uniform grid samples spanning [-1, 1]^3.

    Columns 0-2 are the x, y, z coordinates; column 3 is zero-initialized
    and later filled with predicted SDF values by create_mesh.

    :param N: grid resolution per axis.
    :return: torch.FloatTensor of shape (N^3, 4) with requires_grad=False.
    """
    overall_index = torch.arange(0, N ** 3, 1, out=torch.LongTensor())
    samples = torch.zeros(N ** 3, 4)
    # the voxel_origin is the (bottom, left, down) corner, not the middle
    voxel_origin = [-1, -1, -1]
    voxel_size = 2.0 / (N - 1)
    # decode the flat index into integer (x, y, z) grid indices.
    # BUGFIX: this must be integer floor division — the previous
    # `overall_index.long().float() / N` leaked the fractional part of the
    # faster-varying axes into the y and x indices, shearing the query grid
    # off the lattice by up to one voxel.
    samples[:, 2] = overall_index % N
    samples[:, 1] = torch.div(overall_index, N, rounding_mode='floor') % N
    samples[:, 0] = torch.div(overall_index, N * N, rounding_mode='floor') % N
    # transform the first 3 columns from grid indices to coordinates in [-1, 1]
    # (origin components are all -1, so the reversed indexing is harmless)
    samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
    samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
    samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]
    samples.requires_grad = False
    return samples
def convert_sdf_samples_to_ply(
    pytorch_3d_sdf_tensor,
    voxel_grid_origin,
    voxel_size,
    ply_filename_out,
    level_set=0.0
):
    """
    Convert sdf samples to .ply
    :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n)
    :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
    :voxel_size: float, the size of the voxels
    :ply_filename_out: string, path of the filename to save to
    This function adapted from: https://github.com/RobotLocomotion/spartan
    """
    sdf_grid = pytorch_3d_sdf_tensor.numpy()
    # marching cubes can fail (e.g. level set outside the value range); skip that shape
    try:
        verts, faces, normals, values = skimage.measure.marching_cubes(
            sdf_grid, level=level_set, spacing=[voxel_size] * 3
        )
    except Exception as e:
        print("skipping {}; error: {}".format(ply_filename_out, e))
        return
    # shift from voxel coordinates to world coordinates (spacing already applied)
    mesh_points = np.zeros_like(verts)
    for axis in range(3):
        mesh_points[:, axis] = voxel_grid_origin[axis] + verts[:, axis]
    # pack vertices and faces into the structured arrays plyfile expects
    verts_tuple = np.zeros((verts.shape[0],), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
    for i, point in enumerate(mesh_points):
        verts_tuple[i] = tuple(point)
    faces_tuple = np.array(
        [(face.tolist(),) for face in faces],
        dtype=[("vertex_indices", "i4", (3,))],
    )
    ply_data = plyfile.PlyData([
        plyfile.PlyElement.describe(verts_tuple, "vertex"),
        plyfile.PlyElement.describe(faces_tuple, "face"),
    ])
    ply_data.write(ply_filename_out)
================================================
FILE: utils/mmd.py
================================================
================================================
FILE: utils/reconstruct.py
================================================
#!/usr/bin/env python3
import torch
import torch.utils.data
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning import loggers as pl_loggers
import os
import json
import time
from tqdm.auto import tqdm
from einops import rearrange, reduce
import numpy as np
import trimesh
# add paths in model/__init__.py for new models
from models import *
from utils import mesh, evaluate, reconstruct
from diff_utils.helpers import *
def vis_recon(test_dataloader, sdf_model, vae_model, recon_dir, take_mod=False, calc_cd=False):
    """Visualize / evaluate VAE reconstructions of triplane features.

    NOTE(review): this function branches on a module-level ``args`` namespace,
    and also reads ``specs`` and ``evaluation_metrics``, none of which are
    defined in this file — they are presumably injected by the calling
    script; confirm before reuse. Also note the branches read ``args.take_mod``
    rather than the ``take_mod`` parameter.

    Branches:
      - args.evaluate: sample surface points from previously reconstructed
        meshes and compute point-cloud metrics against the GT clouds.
      - args.take_mod (and not args.sample): decode saved latent modulation
        .txt files into meshes.
      - args.sample: draw one latent from the VAE prior and decode it.
      - otherwise: reconstruct every shape in test_dataloader, optionally
        logging chamfer distance when calc_cd is set.
    """
    resolution = 64       # marching-cubes grid resolution
    recon_batch = 2**20   # max SDF queries per GPU batch
    # run visualization of the reconstructed plane features to confirm that recon loss is low enough
    with torch.no_grad():
        if args.evaluate:
            point_clouds, pc_paths = test_dataloader.get_all_files()
            point_clouds = torch.stack(point_clouds) # stack [(1024,3), (1024,3)...] to (B, 1024, 3)
            recon_meshes = torch.empty(*point_clouds.shape)
            # ***change to MODULATION generated paths later!!!!
            for idx, path in enumerate(pc_paths):
                #print("path: ", path)
                # paths look like .../<cls_name>/<mesh_name>/<file> — TODO confirm layout
                cls_name = path.split("/")[-3]
                mesh_name = path.split("/")[-2]
                mesh_filename = os.path.join(recon_dir, "{}/{}/reconstruct".format(cls_name, mesh_name))
                recon_mesh = trimesh.load(os.path.join(os.getcwd(), mesh_filename)+".ply")
                recon_pc, _ = trimesh.sample.sample_surface(recon_mesh, point_clouds.shape[1])
                recon_meshes[idx] = torch.from_numpy(recon_pc)
            print("ref, recon shapes: ", recon_meshes.shape, point_clouds.shape) # should both be = B, N, 3
            results = evaluation_metrics.compute_all_metrics(recon_meshes.float(), point_clouds.float(), accelerated_cd=False)
            for k,v in results.items():
                print(k, ": ", v)
        elif args.take_mod and not args.sample:
            # gather modulation files: either every .txt in args.mod_folder,
            # or the explicit list passed as args.take_mod
            lst = []
            if args.mod_folder:
                files = os.listdir(args.mod_folder)
                for f in files:
                    if os.path.isfile(os.path.join(args.mod_folder, f)) and f[-4:]=='.txt':
                        lst.append(os.path.join(args.mod_folder, f))
            else:
                lst = args.take_mod
            for idx, m in enumerate(lst):
                latent = torch.from_numpy(np.loadtxt(m)).float().cuda()
                recon = vae_model.decode(latent)
                name = args.output_name if args.output_name else "mod_recon"
                name += "{}".format(idx)
                os.makedirs(os.path.join(recon_dir, "modulation_recon"), exist_ok=True)
                mesh_filename = os.path.join(recon_dir, "modulation_recon", name)
                mesh.create_mesh(sdf_model, recon, mesh_filename, resolution, recon_batch, from_plane_features=True)
        elif args.sample:
            # unconditional sample from the VAE prior, decoded to a mesh
            recon = vae_model.sample(num_samples=1)
            name = args.output_name if args.output_name else "mod_recon"
            os.makedirs(os.path.join(recon_dir, "modulation_recon"), exist_ok=True)
            mesh_filename = os.path.join(recon_dir, "modulation_recon", name)
            mesh.create_mesh(sdf_model, recon, mesh_filename, resolution, recon_batch, from_plane_features=True)
        else:
            for idx, data in enumerate(test_dataloader): # test_loader does not shuffle
                # if idx % 10 != 0:
                #     continue
                data, filename = data # filename = path to the csv file of sdf data
                filename = filename[0] # filename is a tuple for some reason
                # NOTE(review): random_flip is read but never used in this branch
                random_flip = specs.get("random_flip", False)
                # if random_flip:
                #     flip_axes = torch.tensor([[1,1,1],[-1,1,1],[1,-1,1],[1,1,-1]], device=data.device)
                #     prob = torch.randint(low=0, high=4, size=(1,))
                #     flip_axis = flip_axes[prob] # shape=[1,3]
                #     data *= flip_axis.unsqueeze(0).repeat(data.shape[0], data.shape[1], 1)
                # if random_flip:
                #     flip_axes = torch.tensor([[1,1,1],[-1,1,1],[1,-1,1],[1,1,-1]], device=data.device)
                #     for axis in flip_axes:
                #         flipped_data = data * axis.unsqueeze(0).repeat(data.shape[0], data.shape[1], 1)
                cls_name = filename.split("/")[-3]
                mesh_name = filename.split("/")[-2]
                outdir = os.path.join(recon_dir, "{}/{}".format(cls_name, mesh_name))
                os.makedirs(outdir, exist_ok=True)
                mesh_filename = os.path.join(outdir, "reconstruct")
                plane_features = sdf_model.pointnet.get_plane_features(data.cuda()) # 3 items with ([1, 256, 64, 64])
                plane_features = torch.cat(plane_features, dim=1) # ([1, 768, 64, 64])
                recon = vae_model.generate(plane_features) # ([1, 768, 64, 64])
                # create_mesh samples the grid points, then calls sdf_model.forward_with_plane_features, which calls pointnet.forward_with_plane_features
                #print("mesh filename: ", mesh_filename)
                mesh.create_mesh(sdf_model, recon, mesh_filename, resolution, recon_batch, from_plane_features=True)
                if calc_cd:
                    evaluate_filename = os.path.join(recon_dir, "cd.csv")
                    mesh_log_name = cls_name+"/"+mesh_name
                    try:
                        evaluate.main(data, mesh_filename, evaluate_filename, mesh_log_name)
                    except Exception as e:
                        # best-effort: chamfer logging failures should not stop reconstruction
                        print(e)
                # try:
                #     if not filter_threshold(mesh_filename, data, 0.0018):
                #         continue
                #     outdir = os.path.join(latent_dir, "{}/{}".format(cls_name, mesh_name))
                #     os.makedirs(outdir, exist_ok=True)
                #     features = sdf_model.pointnet.get_plane_features(data.cuda())
                #     #print("features shape: ", features[0].shape) # ([1, 256, 64, 64])
                #     features = torch.cat(features, dim=1)
                #     latent = vae_model.get_latent(features)
                #     #print("latent shape: ", latent.shape)
                #     np.savetxt(os.path.join(outdir, "latent.txt"), latent.cpu().numpy())
                # except Exception as e:
                #     print(e)
def filter_threshold(mesh, gt_pc, threshold): # mesh is path to mesh without .ply ext
    """Return True when the chamfer distance between gt_pc and the saved mesh is within threshold."""
    chamfer = evaluate.main(gt_pc, mesh, None, None, return_value=True, prioritize_cov=True)
    return chamfer <= threshold
def extract_latents(test_dataloader, sdf_model, vae_model, save_dir):
    """Encode each test shape into a latent modulation vector and save it as .txt.

    Shapes whose reconstruction chamfer distance exceeds 0.0022 are skipped.
    When ``specs["random_flip"]`` is set, four axis-flipped variants of each
    shape are encoded and saved separately.

    NOTE(review): ``recon_dir`` (the saved-mesh path below) and ``specs`` are
    not defined in this function or anywhere in this file — unless the caller
    installs them as globals, this raises NameError; confirm before use.
    """
    # only extract the latent vectors
    latent_dir = os.path.join(save_dir, "modulations")
    os.makedirs(latent_dir, exist_ok=True)
    with torch.no_grad():
        for idx, data in enumerate(test_dataloader): # test_loader does not shuffle
            data, filename = data
            filename = filename[0] # filename is a tuple for some reason
            cls_name = filename.split("/")[-3]
            mesh_name = filename.split("/")[-2]
            # if filtering based on CD threshold
            saved_mesh = os.path.join(recon_dir, "{}/{}/reconstruct".format(cls_name, mesh_name))
            gt_pc = data
            try:
                if not filter_threshold(saved_mesh, gt_pc, 0.0022):
                    continue
                outdir = os.path.join(latent_dir, "{}/{}".format(cls_name, mesh_name))
                os.makedirs(outdir, exist_ok=True)
                random_flip = specs.get("random_flip", False)
                if random_flip:
                    flip_axes = torch.tensor([[1,1,1],[-1,1,1],[1,-1,1],[1,1,-1]], device=data.device)
                    # NOTE(review): this inner `idx` shadows the enumerate index above
                    for idx, axis in enumerate(flip_axes):
                        flipped_data = data * axis.unsqueeze(0).repeat(data.shape[0], data.shape[1], 1)
                        features = sdf_model.pointnet.get_plane_features(flipped_data.cuda())
                        #print("features shape: ", features[0].shape) # ([1, 256, 64, 64])
                        features = torch.cat(features, dim=1)
                        latent = vae_model.get_latent(features)
                        #print("latent shape: ", latent.shape)
                        np.savetxt(os.path.join(outdir, "latent_{}.txt".format(idx)), latent.cpu().numpy())
                else:
                    features = sdf_model.pointnet.get_plane_features(data.cuda())
                    #print("features shape: ", features[0].shape) # ([1, 256, 64, 64])
                    features = torch.cat(features, dim=1)
                    latent = vae_model.get_latent(features)
                    #print("latent shape: ", latent.shape)
                    np.savetxt(os.path.join(outdir, "latent.txt"), latent.cpu().numpy())
            except Exception as e:
                # best-effort: skip shapes whose saved mesh is missing or fails to load
                print(e)
================================================
FILE: utils/renderer.py
================================================
import numpy as np
import trimesh
import trimesh.transformations as tra
import pyrender
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation as R
# RGBA palette cycled through as meshes/point clouds are added to the scene
COLORS = [
    np.array([255, 10, 10, 255]),
    np.array([10, 255, 10, 255]),
    np.array([10, 234, 255, 255])
]
class OnlineObjectRenderer:
    """Offscreen pyrender scene for rendering meshes and point clouds."""

    def __init__(self, fov=np.pi / 6, caching=True):
        """
        Args:
            fov: float, vertical field of view of the camera in radians.
            caching: bool, stored on the instance but not used by this class.
        """
        self._fov = fov
        self._scene = None
        self._init_scene()
        self._caching = caching
        self._nodes = []

    def _init_scene(self, height=480, width=480):
        """Create the scene with a perspective camera, a directional light, and an offscreen renderer."""
        self._scene = pyrender.Scene()
        camera = pyrender.PerspectiveCamera(yfov=self._fov, znear=0.001) # do not change aspect ratio
        camera_pose = tra.euler_matrix(np.pi, 0, 0)
        self._scene.add(camera, pose=camera_pose, name='camera')
        direc_l = pyrender.DirectionalLight(color=np.ones(3), intensity=1.0)
        self._scene.add(direc_l, pose=camera_pose)
        self.renderer = pyrender.OffscreenRenderer(height, width)

    @staticmethod
    def _normalize_pose(rotation, translation):
        """Normalize the (rotation, translation) pair shared by add_mesh/add_pointcloud.

        Accepts rotation as None, a quaternion, a 3x3 matrix, or xyz Euler
        angles (list/tuple/ndarray) and returns it as a quaternion (or None).
        Translation defaults to (0, 0, 3); otherwise it is copied and pushed
        3 units along +z so objects sit in front of the camera.
        """
        if rotation is not None and isinstance(rotation, (list, tuple)):
            rotation = np.array(rotation)
        if rotation is not None and rotation.shape == (3, 3):
            rotation = R.from_matrix(rotation).as_quat()
        elif rotation is not None and rotation.shape == (3,):
            rotation = R.from_euler("xyz", rotation).as_quat()
        if translation is None:
            translation = np.array((0, 0, 3))
        else:
            translation = translation.copy()
            translation[2] += 3.
        return rotation, translation

    def add_mesh(self, path, name, rotation=None, translation=None):
        """Load a mesh from `path`, color it from the palette, and add it to the scene under `name`."""
        mesh = trimesh.load(path)
        color = np.tile(COLORS[len(self._nodes) % len(COLORS)], (mesh.vertices.shape[0], 1))
        mesh.visual.vertex_colors = color
        mesh = pyrender.Mesh.from_trimesh(mesh.copy(), smooth=False)
        rotation, translation = self._normalize_pose(rotation, translation)
        node = pyrender.Node(mesh=mesh, rotation=rotation, translation=translation, name=name)
        self._scene.add_node(node)
        self._nodes.append(node)
        return name

    def add_pointcloud(self, path, name, colors=None, rotation=None, translation=None):
        """Load a comma-separated xyz point cloud from `path` and add it as small spheres."""
        pc = np.loadtxt(path, delimiter=",")
        sm = trimesh.creation.uv_sphere(radius=0.008)
        if colors is None:
            colors = COLORS[len(self._nodes) % len(COLORS)]
        sm.visual.vertex_colors = colors
        # one 4x4 transform per point, placing a sphere instance at each location
        tfs = np.tile(np.eye(4), (len(pc), 1, 1))
        tfs[:, :3, 3] = pc
        rotation, translation = self._normalize_pose(rotation, translation)
        mesh = pyrender.Mesh.from_trimesh(sm, poses=tfs)
        node = pyrender.Node(mesh=mesh, rotation=rotation, translation=translation, name=name)
        self._scene.add_node(node)
        self._nodes.append(node)
        return name

    def clear(self):
        """Remove every node added via add_mesh/add_pointcloud."""
        for node in self._nodes:
            self._scene.remove_node(node)
        self._nodes = []

    def render(self):
        """Render the current scene; returns the (color, depth) arrays from pyrender."""
        return self.renderer.render(self._scene)
if __name__ == "__main__":
    # example usage: render one mesh with a fixed rotation/translation and show it
    renderer = OnlineObjectRenderer()
    renderer.add_mesh(
        "/path/to/mesh",
        "o1", # name of output
        np.array([np.pi/2, -np.pi/2, 0]),
        np.array([0, .5, 0])
    )
    img, dp = renderer.render()
    plt.axis("off")
    plt.imshow(img)
    plt.show()
================================================
FILE: utils/tmd.py
================================================
import argparse
import glob
import os

import numpy as np
import trimesh
from joblib import Parallel, delayed

from utils.chamfer import compute_trimesh_chamfer
def process_one(shape_dir):
    """Compute the Total Mutual Difference contribution of one shape directory.

    Loads every generated sample ("fake-z*.ply"), sums chamfer distances over
    all unordered pairs, and normalizes as sum * 2 / (k - 1) — i.e. the sum
    over shapes of each shape's mean distance to the other k-1 samples.
    """
    paths = sorted(glob.glob(os.path.join(shape_dir, "fake-z*.ply")))
    gen_pcs = [trimesh.load(p).vertices for p in paths]
    sum_dist = 0
    for j, pc1 in enumerate(gen_pcs):
        for pc2 in gen_pcs[j + 1:]:
            sum_dist += compute_trimesh_chamfer(pc1, pc2)
    mean_dist = sum_dist * 2 / (len(gen_pcs) - 1)
    return mean_dist
def tmd_from_pcs(gen_pcs):
    """Total Mutual Difference of an already-loaded list of point clouds.

    Same normalization as process_one: pairwise chamfer sum * 2 / (k - 1).
    """
    sum_dist = 0
    for j, pc1 in enumerate(gen_pcs):
        for pc2 in gen_pcs[j + 1:]:
            sum_dist += compute_trimesh_chamfer(pc1, pc2)
    mean_dist = sum_dist * 2 / (len(gen_pcs) - 1)
    return mean_dist
def Total_Mutual_Difference(args):
    """Run process_one over every shape directory under args.src in parallel.

    Writes per-shape mean distances to <src>-record_meandist.txt and returns
    the mean TMD over all shapes.
    """
    shape_names = sorted(os.listdir(args.src))
    shape_dirs = [os.path.join(args.src, name) for name in shape_names]
    results = Parallel(n_jobs=args.process, verbose=2)(delayed(process_one)(d) for d in shape_dirs)
    info_path = args.src + '-record_meandist.txt'
    with open(info_path, 'w') as fp:
        for name, dist in zip(shape_names, results):
            print("ID: {} \t mean_dist: {:.4f}".format(name, dist), file=fp)
    return np.mean(results)
def main():
    """CLI entry point: compute the average Total Mutual Difference over --src.

    Prints the result and writes a summary to --output
    (default: <src>-eval_TMD.txt).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str)
    parser.add_argument("-p", "--process", type=int, default=10)
    parser.add_argument("-o", "--output", type=str)
    args = parser.parse_args()
    if args.output is None:
        args.output = args.src + '-eval_TMD.txt'
    res = Total_Mutual_Difference(args)
    # label spelling fixed: "Multual" -> "Mutual"
    print("Avg Total Mutual Difference: {}".format(res))
    with open(args.output, "w") as fp:
        fp.write("SRC: {}\n".format(args.src))
        fp.write("Total Mutual Difference: {}\n".format(res))
================================================
FILE: utils/uhd.py
================================================
import argparse
import os
import torch
import numpy as np
from scipy.spatial import cKDTree as KDTree
import trimesh
import glob
from joblib import Parallel, delayed
def directed_hausdorff(point_cloud1:torch.Tensor, point_cloud2:torch.Tensor, reduce_mean=True):
    """
    :param point_cloud1: (B, 3, N)
    :param point_cloud2: (B, 3, M)
    :return: directed hausdorff distance, A -> B
    """
    # pairwise L2 distances via broadcasting: (B, 3, N, 1) - (B, 3, 1, M) -> (B, N, M)
    diff = point_cloud1.unsqueeze(3) - point_cloud2.unsqueeze(2)
    l2_dist = torch.sqrt((diff ** 2).sum(dim=1))
    # for every point of A, distance to its nearest neighbour in B, then the worst case
    nearest = l2_dist.min(dim=2).values
    hausdorff_dist = nearest.max(dim=1).values  # (B,)
    if reduce_mean:
        hausdorff_dist = hausdorff_dist.mean()
    return hausdorff_dist
def nn_distance(query_points, ref_points):
    """Distance from each query point to its nearest neighbour in ref_points."""
    tree = KDTree(ref_points)
    distances, _ = tree.query(query_points)
    return distances
def completeness(query_points, ref_points, thres=0.03):
    """Fraction of query points whose nearest ref point is closer than thres."""
    dists = nn_distance(query_points, ref_points)
    hits = np.count_nonzero(dists < thres)
    return hits / len(dists)
def process_one(shape_dir):
    """Load one shape directory and return (completeness, unidirectional Hausdorff).

    Loads the generated completions ("fake-z*.ply") and the partial input
    ("raw.ply"), then delegates the metric computation to uhd_from_pcs so the
    two code paths cannot drift apart (the original duplicated that logic
    inline, identically).
    """
    # load generated shapes
    pc_paths = sorted(glob.glob(os.path.join(shape_dir, "fake-z*.ply")))
    gen_pcs = [np.asarray(trimesh.load(p).vertices) for p in pc_paths]
    # load partial input
    partial_pc = np.asarray(trimesh.load(os.path.join(shape_dir, "raw.ply")).vertices)
    return uhd_from_pcs(gen_pcs, partial_pc)
def uhd_from_pcs(gen_pcs, partial_pc):
    """Compute (completeness, unidirectional Hausdorff) for loaded point clouds.

    :param gen_pcs: list of (N, 3) generated completions.
    :param partial_pc: (M, 3) partial input cloud.
    """
    # completeness percentage, averaged over the generated completions
    comp_scores = [completeness(partial_pc, pts) for pts in gen_pcs]
    gen_comp = sum(comp_scores) / len(comp_scores)
    # unidirectional hausdorff: partial input -> each generated shape
    stacked = torch.stack([torch.tensor(pc).transpose(1, 0) for pc in gen_pcs], dim=0)
    partial = torch.tensor(partial_pc).transpose(1, 0)
    partial = partial.unsqueeze(0).repeat((stacked.size(0), 1, 1))
    hausdorff = directed_hausdorff(partial, stacked, reduce_mean=True).item()
    return gen_comp, hausdorff
def func(args):
    """Evaluate every shape dir under args.src in parallel; return mean (UHD, completeness)."""
    shape_names = sorted(os.listdir(args.src))
    shape_dirs = [os.path.join(args.src, name) for name in shape_names]
    results = Parallel(n_jobs=args.process, verbose=2)(delayed(process_one)(d) for d in shape_dirs)
    comps, hausdorffs = zip(*results)
    return np.mean(hausdorffs), np.mean(comps)
def main():
    """CLI entry: compute UHD/completeness over --src and append a summary to --output."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str)
    parser.add_argument("-p", "--process", type=int, default=10)
    parser.add_argument("-o", "--output", type=str)
    args = parser.parse_args()
    if args.output is None:
        args.output = args.src + '-eval_UHD.txt'
    res_hausdorff, res_comp = func(args)
    print("Avg Unidirectional Hausdorff Distance: {}".format(res_hausdorff))
    print("Avg Completeness: {}".format(res_comp))
    summary = [
        "SRC: {}\n".format(args.src),
        "Avg Unidirectional Hausdorff Distance: {}\n".format(res_hausdorff),
        "Avg Completeness: {}\n".format(res_comp),
    ]
    with open(args.output, "a") as fp:
        fp.writelines(summary)
================================================
FILE: utils/visualize.py
================================================
import numpy as np
import trimesh
import sys
import h5py
import open3d as o3d
def create_point_marker(center, color):
    """Return a small colored sphere (radius 0.002) at `center` for point-cloud display."""
    marker = trimesh.primitives.Sphere(radius=0.002, center=center)
    marker.visual.vertex_colors = color
    return marker
def get_color(labels):
    """Map labels in [0, 1] to RGB rows: 0 -> red, 1 -> green, blue always 0."""
    n = labels.shape[0]
    red = np.ones(n) - labels
    green = labels
    blue = np.zeros(n)
    return np.stack([red, green, blue], axis=1)
def vis_pc(obj_pc):
    """Show `obj_pc` as a trimesh scene of uniformly colored sphere markers."""
    color = get_color(np.array([1.0]))[0]
    markers = [create_point_marker(center=point, color=color) for point in obj_pc]
    trimesh.Scene(markers).show()
def main(mesh, pc, query):
    """Assemble and show a trimesh scene from the module-level globals below.

    Args:
        mesh: bool, include `object_mesh` (loaded at module scope below).
        pc: bool, include sphere markers for `point_cloud`.
        query: bool, include markers for `queries`.
            NOTE(review): `queries` is never defined in this file, so
            query=True raises NameError; only main(False, True, False)
            is exercised below.
    """
    lst = []
    if pc:
        point_color = get_color(np.array([0.5]))
        point_markers = [create_point_marker(center=point_cloud[t], color=point_color[0]) for t in range(len(point_cloud))]
        lst.append(trimesh.Scene(point_markers))
    if query:
        q_color = get_color(np.array([0.0]))
        q_markers = [create_point_marker(center=queries[t], color=q_color[0]) for t in range(len(queries))]
        lst.append(trimesh.Scene(q_markers))
    if mesh:
        lst.append(object_mesh)
    scene = trimesh.scene.scene.append_scenes( lst ).show()
    #trimesh.exchange.export.export_scene(scene, "vis.ply")
# script body: argv[1] = mesh path, argv[2] = comma-delimited point-cloud csv
object_mesh = trimesh.load(sys.argv[1])
object_mesh.apply_scale(0.1)
pc = object_mesh.vertices
point_cloud = np.loadtxt(sys.argv[2], dtype=float, delimiter=',') # num of points x 3
# subsample 10k rows (with replacement, since replace defaults to True) and keep xyz only
p_idx = np.random.choice(point_cloud.shape[0], 10000)
point_cloud = point_cloud[p_idx][:,0:3]
# for visualizing pc only
# pcd = o3d.geometry.PointCloud()
# pcd.points = o3d.utility.Vector3dVector(point_cloud)
# #o3d.io.write_point_cloud("./pc.ply", pcd)
# o3d.visualization.draw_geometries([pcd])
# mesh, pc, query
main(False, True, False)