Repository: princeton-computational-imaging/Diffusion-SDF
Branch: main
Commit: 7eb2d6786b88
Files: 81
Total size: 38.4 MB

Directory structure:
gitextract_hv9jwie8/

├── .gitignore
├── LICENSE
├── README.md
├── config/
│   ├── stage1_sdf/
│   │   └── specs.json
│   ├── stage2_diff_cond/
│   │   └── specs.json
│   ├── stage2_diff_uncond/
│   │   └── specs.json
│   ├── stage3_cond/
│   │   └── specs.json
│   └── stage3_uncond/
│       └── specs.json
├── data/
│   ├── acronym/
│   │   └── Couch/
│   │       └── 37cfcafe606611d81246538126da07a8/
│   │           └── sdf_data.csv
│   ├── grid_data/
│   │   └── acronym/
│   │       └── Couch/
│   │           └── 37cfcafe606611d81246538126da07a8/
│   │               └── grid_gt.csv
│   └── splits/
│       └── couch_all.json
├── dataloader/
│   ├── .base.py.swp
│   ├── __init__.py
│   ├── base.py
│   ├── modulation_loader.py
│   ├── pc_loader.py
│   └── sdf_loader.py
├── diff_utils/
│   ├── helpers.py
│   ├── model_utils.py
│   ├── pointnet/
│   │   ├── __init__.py
│   │   ├── conv_pointnet.py
│   │   ├── dgcnn.py
│   │   ├── dgcnn_bn32.pt
│   │   ├── pointnet_base.py
│   │   ├── pointnet_classifier.py
│   │   └── transformer.py
│   └── sdf_utils.py
├── environment.yml
├── metrics/
│   ├── StructuralLosses/
│   │   ├── __init__.py
│   │   ├── match_cost.py
│   │   └── nn_distance.py
│   ├── __init__.py
│   ├── evaluation_metrics.py
│   └── pytorch_structural_losses/
│       ├── .gitignore
│       ├── Makefile
│       ├── StructuralLosses/
│       │   ├── __init__.py
│       │   ├── match_cost.py
│       │   └── nn_distance.py
│       ├── __init__.py
│       ├── match_cost.py
│       ├── nn_distance.py
│       ├── pybind/
│       │   ├── bind.cpp
│       │   └── extern.hpp
│       ├── setup.py
│       └── src/
│           ├── approxmatch.cu
│           ├── approxmatch.cuh
│           ├── nndistance.cu
│           ├── nndistance.cuh
│           ├── structural_loss.cpp
│           └── utils.hpp
├── models/
│   ├── __init__.py
│   ├── archs/
│   │   ├── __init__.py
│   │   ├── diffusion_arch.py
│   │   ├── encoders/
│   │   │   ├── __init__.py
│   │   │   ├── auto_decoder.py
│   │   │   ├── conv_pointnet.py
│   │   │   ├── dgcnn.py
│   │   │   ├── rbf.py
│   │   │   ├── sal_pointnet.py
│   │   │   └── vanilla_pointnet.py
│   │   ├── modulated_sdf.py
│   │   ├── resnet_block.py
│   │   ├── sdf_decoder.py
│   │   └── unet.py
│   ├── autoencoder.py
│   ├── combined_model.py
│   ├── diff_np_if_torch_error.py
│   ├── diffusion.py
│   └── sdf_model.py
├── test.py
├── train.py
└── utils/
    ├── __init__.py
    ├── chamfer.py
    ├── evaluate.py
    ├── mesh.py
    ├── mmd.py
    ├── reconstruct.py
    ├── renderer.py
    ├── tmd.py
    ├── uhd.py
    └── visualize.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.DS_Store
__pycache__/


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2022 Gene Chou

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# Diffusion-SDF: Conditional Generative Modeling of Signed Distance Functions

[**Paper**](https://arxiv.org/abs/2211.13757) | [**Supplement**](https://light.princeton.edu/wp-content/uploads/2023/03/diffusionsdf_supp.pdf) | [**Project Page**](https://light.princeton.edu/publication/diffusion-sdf/) <br>

This repository contains the official implementation of <br> 
**[ICCV 2023] Diffusion-SDF: Conditional Generative Modeling of Signed Distance Functions** <br>
[Gene Chou](https://genechou.com), [Yuval Bahat](https://sites.google.com/view/yuval-bahat/home), [Felix Heide](https://www.cs.princeton.edu/~fheide/) <br>


If you find our code or paper useful, please consider citing
```bibtex
@inproceedings{chou2022diffusionsdf,
  title={Diffusion-SDF: Conditional Generative Modeling of Signed Distance Functions},
  author={Gene Chou and Yuval Bahat and Felix Heide},
  booktitle={The IEEE International Conference on Computer Vision (ICCV)},
  year={2023}
}
```


```cpp
root directory
  ├── config  
  │   └── // folders for checkpoints and training configs
  ├── data  
  │   └── // folders for data (in csv format) and train test splits (json)
  ├── models  
  │   ├── // models and lightning modules; main model is 'combined_model.py'
  │   └── archs
  │       └── // architectures such as PointNets, SDF MLPs, diffusion networks, etc.
  ├── dataloader  
  │   └── // dataloaders for different stages of training and generation
  ├── utils  
  │   └── // helpers for reconstruction, rendering, and evaluation
  ├── metrics  
  │   └── // metrics for evaluating generations
  ├── diff_utils  
  │   └── // helper functions for diffusion
  ├── environment.yml  // package requirements
  ├── train.py  // script for training, specify the stage of training in the config files
  ├── test.py  // script for testing, specify the stage of testing in the config files
  └── tensorboard_logs  // created when running any training script
  
```

## Installation
We recommend creating an [anaconda](https://www.anaconda.com/) environment using our provided `environment.yml`:

```
conda env create -f environment.yml
conda activate diffusionsdf
```

## Dataset
For training, we preprocess all meshes and store query coordinates and signed distance values in csv files. Each csv file corresponds to one object, and each line represents a coordinate followed by its signed distance value. See `data/acronym` for examples. Modify the dataloader according to your file format. <br>

When sampling query points, make sure to also **sample uniformly within the 3D grid space** (i.e., from (-1,-1,-1) to (1,1,1)) rather than only near the surface; this avoids artifacts. For each training batch, we take 70% of query points sampled near the object surface and 30% sampled uniformly in the grid. `grid_source` in our dataloader and config file refers to the latter. <br>
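The 70/30 batch mix described above can be sketched as follows. `sample_queries` and the stand-in arrays are illustrative, not part of the repo's dataloader, which you should adapt to your own CSV format:

```python
import numpy as np

def sample_queries(surface_pts, grid_pts, n_total, surface_frac=0.7, rng=None):
    """Mix near-surface and uniform-grid query points for one training batch.

    surface_pts / grid_pts: (N, 4) arrays of (x, y, z, sdf) rows, e.g. loaded
    from sdf_data.csv and grid_gt.csv with np.loadtxt(..., delimiter=",").
    """
    rng = np.random.default_rng() if rng is None else rng
    n_surf = int(n_total * surface_frac)          # 70% near the surface
    n_grid = n_total - n_surf                     # 30% uniform in the grid
    surf = surface_pts[rng.choice(len(surface_pts), n_surf, replace=False)]
    grid = grid_pts[rng.choice(len(grid_pts), n_grid, replace=False)]
    return np.concatenate([surf, grid], axis=0)   # (n_total, 4)

# Hypothetical stand-in data with the same (x, y, z, sdf) layout:
surface = np.random.uniform(-0.1, 0.1, size=(1000, 4))
grid = np.random.uniform(-1.0, 1.0, size=(1000, 4))
batch = sample_queries(surface, grid, n_total=200)
```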

## Training
As described in our [paper](https://arxiv.org/abs/2211.13757), there are three stages of training. All corresponding config files can be found in the `config` folders, and logs are created in a `tensorboard_logs` folder in the root directory. We recommend tuning `"kld_weight"` when training the joint SDF-VAE model, as it enforces the continuity of the latent space: a higher value (e.g. 0.1) yields better interpolation and generalization but sometimes more artifacts, while a lower value (e.g. 0.00001) yields worse interpolation but higher-quality generations. <br>
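As a sketch of how `"kld_weight"` enters the objective (a hypothetical helper, not the repo's actual loss code): the total loss is the reconstruction term plus `kld_weight` times the KL divergence of the approximate posterior from a standard normal.

```python
import numpy as np

def vae_objective(recon_err, mu, logvar, kld_weight=1e-5):
    """Hypothetical sketch: reconstruction error plus the weighted KL term.

    kld = KL( N(mu, exp(logvar)) || N(0, I) ), averaged over latent dims.
    A larger kld_weight pulls the latent space toward the prior (smoother
    interpolation); a smaller one favors reconstruction fidelity.
    """
    kld = -0.5 * np.mean(1.0 + logvar - mu**2 - np.exp(logvar))
    return recon_err + kld_weight * kld

# With mu = 0 and logvar = 0 the KL term vanishes, leaving only recon_err:
loss = vae_objective(0.5, np.zeros(4), np.zeros(4))
```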

1. Training SDF modulations

```
python train.py -e config/stage1_sdf/ -b 32 -w 8    # -b for batch size, -w for workers, -r to resume training
```
Training notes: For the Acronym / ShapeNet datasets, the loss should go down to $6 \sim 8 \times 10^{-4}$. Run testing to visualize the reconstructed shapes and check whether their quality is sufficient; the quality of reconstructions carries over to the quality of generations. Note that the dimension of the VAE latent vectors will be 3 times `"latent_dim"` in `"SdfModelSpecs"` listed in the config file.

2. Training the diffusion model using the modulations extracted from the first stage 

```
# extract the modulations / latent vectors, which will be saved in a "modulations" folder in the config directory
# the folder needs to correspond to "data_path" in the diffusion config files

python test.py -e config/stage1_sdf/ -r last

# unconditional
python train.py -e config/stage2_diff_uncond/ -b 32 -w 8 

# conditional
python train.py -e config/stage2_diff_cond/ -b 32 -w 8 
```
Training notes: When extracting modulations, we recommend filtering based on the chamfer distance; see `test_modulations()` in `test.py` for details. Some notes on the conditional config file: `"perturb_pc":"partial"`, `"crop_percent":0.5`, and `"sample_pc_size":128` together mean that 50% of a point cloud with 128 points is cropped and used as the condition. `dim` in `diffusion_model_specs` needs to equal the dimension of the latent vector, which is 3 times `"latent_dim"` in `"SdfModelSpecs"`. <br>
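The dimension constraint above can be checked with a small sanity check; `check_diffusion_dim` is a hypothetical helper, with values mirroring `config/stage1_sdf/specs.json` and `config/stage2_diff_cond/specs.json`:

```python
def check_diffusion_dim(sdf_specs, diff_specs):
    """Return True iff the diffusion "dim" equals 3x the SDF-VAE
    "latent_dim", as the training notes require (not repo code)."""
    return diff_specs["dim"] == 3 * sdf_specs["latent_dim"]

# Values copied from the provided specs.json files:
sdf_specs = {"latent_dim": 256}   # "SdfModelSpecs" in stage 1
diff_specs = {"dim": 768}         # "diffusion_model_specs" in stage 2
ok = check_diffusion_dim(sdf_specs, diff_specs)
```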


3. End-to-end training using the saved models from above 

```
# unconditional
python train.py -e config/stage3_uncond/ -b 32 -w 8 -r finetune     # training from the saved models of first two stages
python train.py -e config/stage3_uncond/ -b 32 -w 8 -r last     # resuming training if third stage has been trained 

# conditional
python train.py -e config/stage3_cond/ -b 32 -w 8 -r finetune    # training from the saved models of first two stages
python train.py -e config/stage3_cond/ -b 32 -w 8 -r last     # resuming training if third stage has been trained 
```
Training notes: The config file needs to contain the saved checkpoints for the previous two stages of training. The sdf loss (not generated sdf loss) should approach $6 \sim 8 \times 10^{-4}$.

## Testing
1. Testing SDF reconstructions and saving modulations

After the first stage of training, visualize / test reconstructions and save modulations:
```
# extract the modulations / latent vectors, which will be saved in a "modulations" folder in the config directory
# the folder needs to correspond to "data_path" in the diffusion config files
python test.py -e config/stage1_sdf/ -r last
```
A `recon` folder in the config directory will contain the `.ply` reconstructions and a `cd.csv` file that logs the Chamfer Distance (CD). A `modulations` folder will contain a `latent.txt` file for each SDF; this `modulations` folder serves as the data path for the second stage of training.
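Assuming each `latent.txt` stores a single flat, whitespace-separated float vector (the exact on-disk format is defined by `test.py`), loading a saved modulation might look like this; `load_modulation` and the in-memory stand-in file are illustrative:

```python
import io
import numpy as np

def load_modulation(path_or_buf):
    """Load one saved modulation vector from a latent.txt-style file.

    Assumes the file holds a single flat vector of floats; adjust if
    your saved format differs.
    """
    vec = np.loadtxt(path_or_buf).reshape(-1)
    return vec.astype(np.float32)

# In-memory stand-in for a tiny latent.txt:
fake_file = io.StringIO("0.1 -0.2 0.3 0.05")
latent = load_modulation(fake_file)
```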

2. Generations 

Meshes can be generated after the second or third stage of training.
```
python test.py -e config/stage3_uncond/ -r finetune  # generation after second stage 
python test.py -e config/stage3_uncond/ -r last      # after third stage 
```
A `recon` folder in the config directory will contain the `.ply` reconstructions. The `max_batch` arguments in `test.py` are used when running marching cubes; set them to the largest value your GPU memory can hold.
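The role of `max_batch` can be illustrated with a chunked grid evaluation; `eval_sdf_on_grid` and the sphere SDF below are hypothetical stand-ins, not the repo's actual mesh-extraction code:

```python
import numpy as np

def eval_sdf_on_grid(sdf_fn, resolution=64, max_batch=32 ** 3):
    """Evaluate an SDF callable over a dense [-1, 1]^3 grid in chunks of
    max_batch points, so memory bounds the chunk size rather than the
    full grid (mirroring the max_batch idea in test.py)."""
    axis = np.linspace(-1.0, 1.0, resolution)
    pts = np.stack(np.meshgrid(axis, axis, axis, indexing="ij"), -1).reshape(-1, 3)
    out = np.empty(len(pts), dtype=np.float32)
    for start in range(0, len(pts), max_batch):
        out[start:start + max_batch] = sdf_fn(pts[start:start + max_batch])
    # The dense grid would then be passed to marching cubes
    # (e.g. skimage.measure.marching_cubes) to extract a mesh.
    return out.reshape(resolution, resolution, resolution)

# Example: a unit-sphere SDF of radius 0.5 as the stand-in model.
grid = eval_sdf_on_grid(lambda p: np.linalg.norm(p, axis=1) - 0.5, resolution=16)
```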


## References
We adapt code from <br>
GenSDF https://github.com/princeton-computational-imaging/gensdf <br>
DALLE2-pytorch https://github.com/lucidrains/DALLE2-pytorch <br>
Convolutional Occupancy Networks https://github.com/autonomousvision/convolutional_occupancy_networks (for PointNet encoder) <br>
Multimodal Shape Completion via cGANs https://github.com/ChrisWu1997/Multimodal-Shape-Completion (for conditional metrics) <br>
PointFlow https://github.com/stevenygd/PointFlow (for unconditional metrics)


================================================
FILE: config/stage1_sdf/specs.json
================================================
{
    "Description" : "training joint SDF-VAE model for modulating SDFs on the couch dataset",
    "DataSource" : "data",
    "GridSource" : "data/grid_data",
    "TrainSplit" : "data/splits/couch_all.json",
    "TestSplit" : "data/splits/couch_all.json",
    
    "training_task": "modulation",
  
    "SdfModelSpecs" : {
      "hidden_dim" : 512,
      "latent_dim" : 256,
      "pn_hidden_dim" : 128,
      "num_layers" : 9
    },

    "SampPerMesh" : 16000,
    "PCsize" : 1024,
  
    "num_epochs" : 100001,
    "log_freq" : 5000,
    "kld_weight" : 1e-5,
    "latent_std" : 0.25,
  
    "sdf_lr" : 1e-4
}
  
  
  


================================================
FILE: config/stage2_diff_cond/specs.json
================================================
{
  "Description" : "diffusion training (conditional) on couch dataset",
  "pc_path" : "data",
  "total_pc_size" : 10000,
  "TrainSplit" : "data/splits/couch_all.json",
  "TestSplit" : "data/splits/couch_all.json",
  "data_path" : "config/stage1_sdf/modulations",

  "training_task": "diffusion",

  
  "num_epochs" : 50001,
  "log_freq" : 5000,

  "diff_lr" : 1e-5,

  "diffusion_specs" : {
    "timesteps" : 1000,
    "objective" : "pred_x0",
    "loss_type" : "l2",
    "perturb_pc" : "partial",
    "crop_percent": 0.5,
    "sample_pc_size" : 128
  },
  "diffusion_model_specs": {
    "dim" : 768,
    "depth" : 4,
    "ff_dropout" : 0.3,
    "cond" : true,
    "cross_attn" : true,
    "cond_dropout":true,
    "point_feature_dim" : 128
  }
}




================================================
FILE: config/stage2_diff_uncond/specs.json
================================================
{
  "Description" : "diffusion training (unconditional) on couch dataset",
  "TrainSplit" : "data/splits/couch_all.json",
  "TestSplit" : "data/splits/couch_all.json",
  "data_path" : "config/stage1_sdf/modulations",

  "training_task": "diffusion",

  "num_epochs" : 50001,
  "log_freq" : 5000,

  "diff_lr" : 1e-5,

  "diffusion_specs" : {
    "timesteps" : 1000,
    "objective" : "pred_x0",
    "loss_type" : "l2"
  },
  "diffusion_model_specs": {
    "dim" : 768,
    "dim_in_out" : 768,
    "depth" : 4,
    "ff_dropout" : 0.3,
    "cond" : false
  }
}



================================================
FILE: config/stage3_cond/specs.json
================================================
{
  "Description" : "end-to-end training (conditional) on couch dataset",
  "DataSource" : "data",
  "GridSource" : "grid_data",
  "TrainSplit" : "data/splits/couch_all.json",
  "TestSplit" : "data/splits/couch_all.json",
  "modulation_path" : "config/stage1_sdf/modulations",

  "modulation_ckpt_path" : "config/stage1_sdf/last.ckpt",
  "diffusion_ckpt_path" : "config/stage2_cond/last.ckpt",
  
  "training_task": "combined",

  "num_epochs" : 100001,
  "log_freq" : 5000,

  "kld_weight" : 1e-5,
  "latent_std" : 0.25,
  
  "sdf_lr" : 1e-4,
  "diff_lr" : 1e-5,

  "SdfModelSpecs" : {
    "hidden_dim" : 512,
    "latent_dim" : 256,
    "pn_hidden_dim" : 128,
    "num_layers" : 9
  },
  "SampPerMesh" : 16000,
  "PCsize" : 1024,

  "diffusion_specs" : {
    "timesteps" : 1000,
    "objective" : "pred_x0",
    "loss_type" : "l2",
    "perturb_pc" : "partial",
    "crop_percent": 0.5,
    "sample_pc_size" : 128
  },
  "diffusion_model_specs": {
    "dim" : 768,
    "depth" : 4,
    "ff_dropout" : 0.3,
    "cond" : true,
    "cross_attn" : true,
    "cond_dropout":true,
    "point_feature_dim" : 128
  }
}




================================================
FILE: config/stage3_uncond/specs.json
================================================
{
  "Description" : "end-to-end training (unconditional) on couch dataset",
  "DataSource" : "data",
  "GridSource" : "grid_data",
  "TrainSplit" : "data/splits/couch_all.json",
  "TestSplit" : "data/splits/couch_all.json",
  "modulation_path" : "config/stage1_sdf/modulations",

  "modulation_ckpt_path" : "config/stage1_sdf/last.ckpt",
  "diffusion_ckpt_path" : "config/stage2_uncond/last.ckpt",
  
  "training_task": "combined",

  "num_epochs" : 100001,
  "log_freq" : 5000,

  "kld_weight" : 1e-5,
  "latent_std" : 0.25,
  
  "sdf_lr" : 1e-4,
  "diff_lr" : 1e-5,

  "SdfModelSpecs" : {
    "hidden_dim" : 512,
    "latent_dim" : 256,
    "pn_hidden_dim" : 128,
    "num_layers" : 9
  },

  "SampPerMesh" : 16000,
  "PCsize" : 1024,

  "diffusion_specs" : {
    "timesteps" : 1000,
    "objective" : "pred_x0",
    "loss_type" : "l2"
  },
  "diffusion_model_specs": {
    "dim" : 768,
    "dim_in_out" : 768,
    "depth" : 4,
    "ff_dropout" : 0.3,
    "cond" : false
  }
}




================================================
FILE: data/acronym/Couch/37cfcafe606611d81246538126da07a8/sdf_data.csv
================================================
[File too large to display: 21.4 MB]

================================================
FILE: data/grid_data/acronym/Couch/37cfcafe606611d81246538126da07a8/grid_gt.csv
================================================
[File too large to display: 16.7 MB]

================================================
FILE: data/splits/couch_all.json
================================================
{
    "acronym": {
        "Couch": [
            "37cfcafe606611d81246538126da07a8",
            "fd124209f0bc1222f34e56d1d9ad8c65",
            "9bbf6977b3f0f9cb11e76965808086c8",
            "b6ac23d327248d72627bb9f102840372",
            "133f2d360f9b8691f5ee22e800bb9145",
            "fa1e1a91e66faf411de55fee5ac2c5c2",
            "a10a157254f74b0835836c728d324152",
            "9152b02103bdc98633b8015b1af14a5f",
            "9cddb828b936db93c341afa383659322",
            "69c1e004f9e42b71e7e684d25d4dcaf0",
            "694af585378fef692d0fc9d5a4650a3b",
            "9156988a1a8645e727eb00c151c6f711",
            "e549ab0c75d7fd08cbf65787367134d6",
            "354c37c168778a0bd4830313df3656b",
            "c32dc92e660bf12b7b6fd5468f603b31",
            "a613610e1b9eda601e20309da4bdcbd0",
            "8701046c07327e3063f9008a349ae40b",
            "95277cbbe4e4a566c5ee765adc4c8ef3",
            "c29c56fe616c0986e7e684d25d4dcaf0",
            "252be483777007c22e7955415f58545",
            "8030e04cf9c19768623a7f06c75e0539",
            "21f76612b56d67edf54efb4962ed3879",
            "cfa489d103b4f937411053a770408f0d",
            "83fd038152fca4ba1c5fa3821fb8a849",
            "a38ab75ff03f4befe3f7a74e12a274ef",
            "ac376ac6f62d2bd535836c728d324152",
            "af796bbeb2990931a1ce49849c98d31c",
            "3f90f35a391d5e273b04279a45e22431",
            "cfb40e7b9990da99c2f927df125f5ce4",
            "26205c6a0a25d5f884099151cc96de84",
            "199da405357516a37540a07f0d62945",
            "f261a5d15d8dd1cee16219238f4e984c",
            "4c5d2b857e9b60a884b236d07141c1a6",
            "9b818d54d601b026b161f36d4e309050",
            "40060ca118b6dd904d3df84a5eca1a73",
            "a17f6d35f289433b5112295517941cf7",
            "6475924415a3534111e85af13fa8bb9b",
            "87f103e24f91af8d4343db7d677fae7b",
            "656925921cfaf668639c9ca4528de8f2",
            "b38978ed5bb13266f2b67ae827b02632",
            "ce0a19a131121305d2bc6fb367627d3e",
            "f645a78ce2bd743ed0c05eb40b42c942",
            "2b8d1c67d17d3911d9cff7df5347abca",
            "4ef1c8ad9c72263cdc6d207f01b16dd0",
            "537438b5f127b55e851f4ba6aaedaaa8",
            "2b158fc055b09892e3f7a74e12a274ef",
            "abb7fb6cc7c5b2632829384b8c505bb2",
            "e68758f3775f700df1783a44a88d6274",
            "4dd15f5b09c1494e8058cf23f6382c1",
            "6c4eece6ed9f26f1a1741c84cc821e55",
            "ae7e1fe11d9243bc5c7a30510dbe4e9f",
            "63d0dd6566dfec32ea65c47b660136e7",
            "7c6491c7fe65449cd9bdef3310c0e5a6",
            "de14f6786ddc7e11bd83c80a89cb8f94",
            "a1b0f606bee9dd8e2e7ed79f4f48ff79",
            "8efa91e2f3e2eaf7bdc82a7932cd806",
            "fa1e84f906d8b61973b1be0dc36d8eee",
            "3f6cbfc955e3a82e2f3bfb1446a70fbe",
            "a627a11941892ada3707a503164f0d91",
            "436a96f58ef9a6fdb039d8689a74349",
            "c644689138389daa749ddac355b8e63d",
            "ee5f19266a3ed535a6491c91cd8d7770",
            "62852049a9f049cca67a896eea47bc81",
            "2efd479c13f4cf72af1a85b857ec9afc",
            "32b6c232154f2f35550644042dd119da",
            "5fffafa504c7ee2a8c4f202fffc87396",
            "b714755fb21264c73f89a402a2f2d0aa",
            "7e68c12da163627dd2856b4dc0e7a482",
            "ce273dabcdd9943b663191fd557d3a61",
            "a207c52399b289be9657063f91d78d19",
            "f0a02b63e1c84a3fbf7df791578d3fb5",
            "34156abf2ce71a60e7e684d25d4dcaf0",
            "80982ab523bd3ad35c7200259c323750",
            "82460a80f6e9cd8bb0851ce87f32a267",
            "823e0ffaf2b47fb27f45370489ca3156",
            "e6d92067a01231cb3f7e27638e63d848",
            "35b4c449b5ae845013aebd62ca4dcb58",
            "8a1a39223639f16e833c6c72c4b62a4d",
            "f87682301c61288c25a12159dbc477b5",
            "76bef187c092b6335ff61a3a2a0e2484",
            "9473a8f3e2182d90d810b14a81e12eca",
            "ec9d8a450d5d30935951e132a7bb8604",
            "e94c453523b1dabda4669f677ccd56a9",
            "d8e1f6781276eb1af34e56d1d9ad8c65",
            "1fd45c57ab27cb6cea65c47b660136e7",
            "b55d24ed786a8279ad2d3ca7f660ddd",
            "30497cce015f143c1191e01810565e0c",
            "9b73921863beef532d1fcd5297483645",
            "e49c0df0a42bdbecc4b4c7225ff8487e",
            "bac90cda9aeb6f5d348e0a0702972359",
            "264d40f99914b97e577df49fb73cc67c",
            "b1b5d123613e75c97d1a9cfc3f714b6d",
            "3b109c1cd75648b71a9a361595e863ab",
            "44708e57c3c0140f9cf1b86170c5d320",
            "66f193b8f45cd9525582003645829560",
            "9c4c4fbaa6e9284fabd9c31d4e727bfe",
            "f20e7a860fca9179d57c8a5f8e280cfb",
            "8872b51e43d2839b2bbb2dbe46302370",
            "323e278c6d7942691836774006aed7e2",
            "4106b3d95b6dc95caacb7aee27fab780",
            "601bf25b4512502145c6cb69e0968783",
            "cb57531ee7104145b0892e337a99f42d",
            "b81749d11da67cef57a847db7547c1f3",
            "3aab3cf29ab92ce7c29432ec481a60b1",
            "e0b897af3eee5ec8d8ac5d7ad0953104",
            "eda881a6ea96da8a46874ce99dea28d5",
            "5bf5096583e15c0080741efeb2454ffb",
            "b765b2c997c459fa83fb3a64ac774b17",
            "ca3e6e011ed7ecb32dc86554a873c4fe",
            "eb74c7d047809a8c859e01b847403b9e",
            "3cce48090f5fabe4e356f23093e95c53",
            "dc079a42bd90dcd9593ebeeedbff73b",
            "2c12a9ba6d4d25df8af30108ea9ccb6c",
            "97f506b27434613d51c4deb11af7079e",
            "98e6845fd0c59edef1a8499ee563cb43",
            "addd6a0ef4f55a66d810b14a81e12eca",
            "21c61d015c6a9591e3f7a74e12a274ef",
            "9de72746c0f00317dd73da65dc0e6a7",
            "ac8a2b54ab0d2f94493a4a2a112261d3",
            "f31648796796cff297c8d78b9aede742",
            "c69ce34e38f6218b2f809039658ca52",
            "31b5cb5dfaa253b3df85db41e3677c28",
            "d8fa31c19a952efb293968bf1f72ae90",
            "ab347f12652bda8eab7f9d2da6fc61cf",
            "ccd0c5e06744ad9a5ca7e476d2d4a26e",
            "daf0e2193c5d1096540b442e651b7ad2",
            "3aa613c06675d2a4dd94d3cfe79da065",
            "6c3cee36477c444e2e64cef1881421b",
            "9a32ae6c8a6afa1e7cfc6b949bbd80bf",
            "a85046f089363914e500b815ce2d83a5",
            "be2fe14cb518d6346c115ab7398115c6",
            "7961d0f612add0cee08bb071746122b9",
            "b95a9b552791e21035836c728d324152",
            "d2014ddbb4e91d7ecb1a776b5576b46b",
            "9c0c2110a58e15febc48810968735439",
            "950b680f7b13e72956433d91451f5f9e",
            "d01ce0f02e25fc2b42e1bb4fe264125f",
            "5fabccb7eaa2adb19a037b4abf810691",
            "ad5ef1b493028c0bd810b14a81e12eca",
            "a2bdc3a6cb13bf197a969536c4ba7f8",
            "3c56ceef171fa142126c0b0ea3ea0a2c",
            "a8ae390fc6de647e1191e01810565e0c",
            "10507ae95f984daccd8f3fe9ca2145e1",
            "ebf1982ccb77cf7d4c37b9ce3a3de242",
            "fbd0055862daa31a2d8ad3188383fcc8",
            "1d6250cafc410fdfe8058cf23f6382c1",
            "7833d94635b755793adc3470b30138f3",
            "ee4ec19df3a6b81317fd40ad196436c",
            "82d25519070e3d5d6f1ad7def14e2855",
            "1ed2e9056a9e94a693e60794f9200b7",
            "980d28e46333db492878c1c13b7f1ba6",
            "a84ff0a6802b0661f345fb470303964a",
            "b58291a2e96691189d1eb836604648db",
            "f653f39d443ac6af15c0ed29be4328d5",
            "4f20a1fe4c85f35aa8891e678de4fe35",
            "90e1a4fb7fa4f4fea5df41522f4e2c86",
            "e5f793180e610d329d6f7fa9d8096b18",
            "4760c46c66461a79dd3adf3090c701f7",
            "3069cb9d632611dcdb43ca77043c03b9",
            "6f0f6571f173bd90f9883d2fd957d60f",
            "11d5e99e8faa10ff3564590844406360",
            "36980a6d57dac873d206493eaec9688a",
            "80b3fb6bae282ec24ca35f9816a36449",
            "7bfef1d5b096f81967ec3d19eeb10ea5",
            "6bfe4dac77d317de1181122615f0a10",
            "2b4de06792ec0eba94141819f1b9662c",
            "20c030e0d055b7aeb0892e337a99f42d",
            "bb53613789a3e9158d736fde63f29661",
            "4ca3850427226dd18d4d8cbfa622f635",
            "a4d269f9299e845a36e3b2fa8d1eb4eb",
            "d13a2ccdbb7740ea83a0857b7b9398b1",
            "a98956209f6723a2dedecd2df7bf25e3",
            "80c8fcbc5f53c402162cb7009da70820",
            "32909910636369b24500047017815f5f",
            "b3d686456bd951d42ea98d69e91ba870",
            "c0d544b1b483556e4d70389e06680f96",
            "d08fc6f10d07dfd8c05575120a46cd3b",
            "e2e25d63c3bc14c23459b03b1c80294",
            "3d95d6237ea6db97afa2904116693357",
            "a7e4616a2a315dfac5ddc26ef5560e77",
            "7a92d499a6f3460a4dc2316a7e66d36",
            "22fd961578d9b100e858db1dc3499392",
            "9136172b17108d1ba7d0cc9b15400f65",
            "62e50e8b0d1e3207e047a3592e8436e5",
            "7fb8fdbc8c32dc2aad6d7e3d5c0179fe",
            "1488c4dbdcd022e99de4fa6f98acbba8",
            "9cff96bae9963ceab3c971099efd7fb9",
            "5cea034b028af000c2843529921f9ad7",
            "8a8c67e553f1c85c1829bffea9d18abb",
            "90959c7526608ecdfb54b6f00b1687e2",
            "8043ff469c9bed4d48c575435d16be7c",
            "b031d0f66e8d047bee8bf38bd6d16329",
            "700562e70c89e9d36bd9ce012b65eebe",
            "b9994f0252c890cbb0892e337a99f42d",
            "7201558f92e55496493a4a2a112261d3",
            "e0f5780dbee33329caa60d250862f15f",
            "1aa55867200ea789465e08d496c0420f",
            "956c5ffbe1ad52976e3c8a33c4ddf2ef",
            "3415f252bd71495649920492438878e5",
            "d64025df908177ece7e684d25d4dcaf0",
            "65309e7601a0c65d82608d43718c2b7",
            "f8ef9668c3aba7c735836c728d324152",
            "6c4e0987896fc5df30c7f4adc2c33ad6",
            "23d1a49a9a29c776ab9281d3b84673cf",
            "aa5fe2c92f01ce0a30cbbda41991e4d0",
            "2fca00b377269eebdb039d8689a74349",
            "fe56059777b240bb833c6c72c4b62a4d",
            "82a168f7c5b8899a79368d1198f406e7",
            "1ebc20758e0b61184bf4a6644c670a1e",
            "f1ce06c5259f771dc24182d0db4c6889",
            "a1b935c6288595267795cc6fb87dd247",
            "b16913335a26e380d1a4117555fc7e56",
            "87bdac1e34f3e6146db2ac45db35c175",
            "f693bb3178c80d7f1783a44a88d6274",
            "f39cb99f7c30a4b8310cd758d9b7cf",
            "91238bfd357c2d87abe6f34906958ac9",
            "801418d073795539ddb3d1341171fe71",
            "b2e9bf58ab2458b5b057261be64dfc5",
            "7669de4a8474b6f4b53857b83094d3a9",
            "9d8c5c62020fcaf4b822d48a43773c62",
            "a680830f8b76c1bbe929777b2f481029",
            "6d0cd48b18471a8bf1444eeb21e761c6",
            "baa8760ca5fbbc4840b559ef47048b86",
            "377fceb1500e6452d9651cd1d591d64d",
            "440e3ad55b603cb1b071d266df0a3bf5",
            "21bf3888008b7aced6d2e576c4ef3bde",
            "151b2df36e08a13fafeeb1a322c90696",
            "24f5497f13a1841adb039d8689a74349",
            "a4c8e8816a1c5f54e6e3ac4cbdf2e092",
            "4cd19483d852712a120031fb55ce8c1c",
            "69257080fd87015369fb37a80cd44134",
            "22c68a7a2c8142f027eb00c151c6f711",
            "82efd3fc01b0c6f34500047017815f5f",
            "fcff900ce37820983f7e27638e63d848",
            "cdb7efad9942e4f9695ff9c22a7938c6",
            "107637b6bdf8129d4904d89e9169817b",
            "f2e7ed2b973570f1a54b9afa882a89ed",
            "b58f27bd7e1ace90d2afe8d5254a0d04",
            "4820b629990b6a20860f0fe00407fa79",
            "abfd8f1db83d6fd24f5ae69aba92b12c",
            "4ed802a4aa4b8a86b161f36d4e309050",
            "47ad0af4207beedb296baeb5500afa1a",
            "f6f5fa7760fde38114aa582256ea1399",
            "60fc7123d6360e6d620ef1b4a95dca08",
            "c53ec0141303c1eb4508add1163b4513",
            "2ebb84f64f8f0f565db77ed1f5c8b93",
            "6b0a09fda777eef2a0398757e5dcd13c",
            "2c1ecb41c0f0d2cd07c7bf20dae278a",
            "e03c28dbfe1f2d9638bb8355830240f9",
            "1b29de0b8be4b18733d25da891be74b8",
            "38a4328fc3e65ba54c37b9ce3a3de242",
            "f81a05c0b43e9bf18c9441777325d0fd",
            "4cb7347f6a91da8554f7af0f5a657663",
            "5328231a28719ed240a92729068b6b39",
            "9239ed83d2274ac975aa7f24a9b6003a",
            "71579e02d0b80bcfccebba9929a10b5e",
            "8647bc3a58eb9d04493a4a2a112261d3",
            "f2c96ae9904baaf220021a69db826002",
            "923a0885b7ca9c55d1007f4863f3bba6",
            "2f28d32781abb0527f2a6b0d07c10212",
            "e68e91ef2652cd1c36e3b2fa8d1eb4eb",
            "9451b957aa36883d6e6c0340d644e56e",
            "72b7ad431f8c4aa2f5520a8b6a5d82e0",
            "54b420da7102d792b36717b39afb6ad8",
            "16fd88a99f7d4c857e484225f3bb4a8",
            "8bad639b9a650908de650492e45fb14f",
            "24129a2a06b35e27c1a6b3575377c8d3",
            "fc553c244c85909c8ed898bae5915f59",
            "e3ce79fd03b7a78d98661b9abac3e1f9",
            "3fcb0aaa346bd46f11e76965808086c8",
            "df8374d8f3563be8f1783a44a88d6274",
            "7eb94d259b560d24f1783a44a88d6274",
            "48205c79d48ca2049f433921788191f3",
            "9294163aacac61f1ad5c4e4076e069c",
            "f2458aaf4ab3e0585d7543afa4b9b4e8",
            "f09a3e98938bc1a990e678878ceea8a4",
            "2e2f34305ef8cbc1533ccec14d70360b",
            "1a477f7b2c1799e1b728e6e715c3f8cf",
            "bcd2418fb0d9727c563fcc2752ece39",
            "ace76562ee9d7c3a913c66b05d18ae8",
            "1b5ae67e2ffb387341fbc1e2da054acb",
            "f563b39b846922f22ea98d69e91ba870",
            "fce717669ca521884e1a9fae5403e01f",
            "d0563b5c26d096ea7b270a7ae3fc6351",
            "43656f715bb28776bff15b656f256f05",
            "3006b66ed9a9fd55a4ccffb47ce1f60d",
            "f5dc5957a6b9712835836c728d324152",
            "3d87710d90c8627dd2afe8d5254a0d04",
            "21b22c30f1c6ddb9952d5d6c0ee49300",
            "e077c6640dab26ecc6e735f548844733",
            "bdfcf2086fafb0fec8a04932b17782af",
            "fb883c27f9f0156e1e44b635c3d74c0b",
            "161f92c3330d111ffc2dd24b885b9c05",
            "b23dc14d788e954b3adc3470b30138f3",
            "a930d381392ff51140b559ef47048b86",
            "7c92e64a328f1b968f6cc6fefa15515a",
            "8affea22019b77a1f1783a44a88d6274",
            "1dc08faeb376c6e0296baeb5500afa1a",
            "fadd7d8c94893136e4b1c2efb094888b",
            "6bd7475753c3d1d62764cfba57a5de73",
            "5d9f1c6f9ce9333994c6d0877753424f",
            "c3e248c5f88e72e9ea65c47b660136e7",
            "87c7911287c10ab8407b28046774089c",
            "4fbc051d0905bb5bdaeb838d0771f3b5",
            "b097e092f951f1d68bc0997abbde7f8",
            "1aec9ac7e1487b7bc75516c7217fb218",
            "2f9a502dafbaa50769cd744177574ad3",
            "9df9d1c02e9013e7ef10d8e00e9d279c",
            "f0f42d26c4a0be52a53016a50348cce3",
            "f080807207cc4859b2403dba7fd079eb",
            "ffcc57ea3101d18ece3df8a7477638c0",
            "f01821fb48425e3a493a4a2a112261d3",
            "8cc805e63f357a69ae336243a6891b91",
            "468e627a950f9340d2bc6fb367627d3e",
            "499edbd7de3e7423bb865c00ef25280",
            "686c5c4a39de31445c31d9fe12bfc9ab",
            "1470c6423b877a0882fcee0c19bca00a",
            "bcc267b5e694e60c3adc3470b30138f3",
            "e7b9fef210fa47505615f13c9cba61e5",
            "ddc8bdf6ac3786f493b55327c66a07aa",
            "27359b0203dc2f345c6cb69e0968783",
            "1e355dc9bfd5da837a8c23d2d40f51b8",
            "b2cfba8ee63abd118fac6a8030e15671",
            "f1a956705451d66ee95eb670e9df38e4",
            "6fc69edce1f6d0d5e7e684d25d4dcaf0",
            "2d3f8abd0567f1601191e01810565e0c",
            "52d307203aefd6bf366971e8a2cdf120",
            "be129d18d202650f6d3e11439c6c22c8",
            "2d4adec4323f8a7da6491c91cd8d7770",
            "9695e057d7a4e992f2b67ae827b02632",
            "21addfde981f0e3445c6cb69e0968783",
            "f8a5ab9872d0dc5e1855a135fe295583",
            "3f8f1d7023ae4a1c73ffdca541f4749c",
            "1e4a7cb88d31716cc9c93fe51d670e21",
            "2cf89f5b70817c3bab2844e3834b9ca1",
            "e410469475df6525493a4a2a112261d3",
            "81f1ac4d9c29025a36f739eb863924d1",
            "c955e564c9a73650f78bdf37d618e97e",
            "20b6d398c5b93a253adc3470b30138f3",
            "ed4ac116e03ebb8d663191fd557d3a61",
            "4735568bd188aefcb8e1b99345a5afd4",
            "6b3b1d804255188993680c5a9a367b4a",
            "5ea73e05bb96272b444ac0ff78e79b7",
            "2eed4a1647d05552413240bb8f2e4eea",
            "6e960c704961590e7e0d7ca07704f74f",
            "dccb04ee585d9d10cf065d4b58dc3084",
            "1eb4bc0cea4ba47c8186c25526ebdaa6",
            "f144cda19f9457fef9b7ca92584b5271",
            "c971b2adf8f071f49afca5ea91159877",
            "1e96607159093f7bbdb7b8a99c2de3d8",
            "a29f6cce8bf52a99182529900d54df69",
            "f260588ef2f4c8e62a6f69f20bf5a185",
            "72b7e2a9bdef8b37f491193d3fde480b",
            "b5fe8de26eac454acdead18f26cf2ece",
            "6caa1713e795c8a2f0478431b5ad57db",
            "1914d0e6b9f0445b40e80a2d9f005aa7",
            "c5380b779c689a919201f2703b45dd7",
            "735122a1019fd6529dac46bde4c69ef2",
            "660df170c4337cda35836c728d324152",
            "770bf9d88c7f039227eb00c151c6f711",
            "2b73510e4eb3d8ca87b66c61c1f7b8e4",
            "3e499689bae22f3ab89ca298cd9a646",
            "ae1fd69e00c00a0253116daf5a10d9c",
            "823219c03b02a423c1a85f2b9754d96f",
            "f87c2d4d5d622293a3f3891e682aaded",
            "ee2ea7d9f26f57886580dafdf181b2d7"
        ]}
}

================================================
FILE: dataloader/__init__.py
================================================


================================================
FILE: dataloader/base.py
================================================
#!/usr/bin/env python3

import numpy as np
import time 
import logging
import os
import random
import torch
import torch.utils.data

import pandas as pd 
import csv

class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_source,
        split_file, # json filepath which contains train/test classes and meshes 
        subsample,
        gt_filename,
        #pc_size=1024,
    ):

        self.data_source = data_source 
        self.subsample = subsample
        self.split_file = split_file
        self.gt_filename = gt_filename
        #self.pc_size = pc_size

        # example
        # data_source: "data"
        # ws.sdf_samples_subdir: "SdfSamples"
        # self.gt_files[0]: "acronym/couch/meshname/sdf_data.csv"
            # with gt_filename="sdf_data.csv"

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx):
        raise NotImplementedError

    def sample_pointcloud(self, csvfile, pc_size):
        f=pd.read_csv(csvfile, sep=',',header=None).values

        f = f[f[:,-1]==0][:,:3]

        if f.shape[0] < pc_size:
            pc_idx = np.random.choice(f.shape[0], pc_size)
        else:
            pc_idx = np.random.choice(f.shape[0], pc_size, replace=False)

        return torch.from_numpy(f[pc_idx]).float()

    def labeled_sampling(self, f, subsample, pc_size=1024, load_from_path=True):
        if load_from_path:
            f=pd.read_csv(f, sep=',',header=None).values
            f = torch.from_numpy(f)

        half = int(subsample / 2) 
        neg_tensor = f[f[:,-1]<0]
        pos_tensor = f[f[:,-1]>0]

        if pos_tensor.shape[0] < half:
            pos_idx = torch.randint(0, pos_tensor.shape[0], (half,))
        else:
            pos_idx = torch.randperm(pos_tensor.shape[0])[:half]

        if neg_tensor.shape[0] < half:
            if neg_tensor.shape[0]==0:
                neg_idx = torch.randperm(pos_tensor.shape[0])[:half] # no neg indices, then just fill with positive samples
            else:
                neg_idx = torch.randint(0, neg_tensor.shape[0], (half,))
        else:
            neg_idx = torch.randperm(neg_tensor.shape[0])[:half]

        pos_sample = pos_tensor[pos_idx]

        if neg_tensor.shape[0]==0:
            neg_sample = pos_tensor[neg_idx]
        else:
            neg_sample = neg_tensor[neg_idx]

        pc = f[f[:,-1]==0][:,:3]
        pc_idx = torch.randperm(pc.shape[0])[:pc_size]
        pc = pc[pc_idx]

        samples = torch.cat([pos_sample, neg_sample], 0)

        return pc.float().squeeze(), samples[:,:3].float().squeeze(), samples[:, 3].float().squeeze() # pc, xyz, sdf values


    def get_instance_filenames(self, data_source, split, gt_filename="sdf_data.csv", filter_modulation_path=None):
        do_filter = filter_modulation_path is not None
        csvfiles = []
        for dataset in split: # e.g. "acronym" "shapenet"
            for class_name in split[dataset]:
                for instance_name in split[dataset][class_name]:
                    instance_filename = os.path.join(data_source, dataset, class_name, instance_name, gt_filename)

                    if do_filter:
                        mod_file = os.path.join(filter_modulation_path, class_name, instance_name, "latent.txt")

                        # do not load if the modulation does not exist; i.e. was not trained by diffusion model
                        if not os.path.isfile(mod_file):
                            continue

                    if not os.path.isfile(instance_filename):
                        logging.warning("Requested non-existent file '{}'".format(instance_filename))
                        continue

                    csvfiles.append(instance_filename)
        return csvfiles
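
`labeled_sampling` above splits a (N, 4) tensor of xyz+sdf rows into equal positive/negative halves and takes the exact zero-level rows as a surface point cloud. A minimal standalone sketch of that logic on synthetic data (mirrored here for illustration, not imported from the class):

```python
import torch

def labeled_sampling(f, subsample, pc_size=1024):
    # mirrors base.Dataset.labeled_sampling for a pre-loaded (N, 4) tensor
    half = subsample // 2
    neg, pos = f[f[:, -1] < 0], f[f[:, -1] > 0]
    pos_idx = torch.randperm(pos.shape[0])[:half] if pos.shape[0] >= half \
        else torch.randint(0, pos.shape[0], (half,))
    neg_idx = torch.randperm(neg.shape[0])[:half] if neg.shape[0] >= half \
        else torch.randint(0, neg.shape[0], (half,))
    samples = torch.cat([pos[pos_idx], neg[neg_idx]], 0)
    pc = f[f[:, -1] == 0][:, :3]          # sdf == 0 marks surface points
    pc = pc[torch.randperm(pc.shape[0])[:pc_size]]
    return pc.float(), samples[:, :3].float(), samples[:, -1].float()

# synthetic rows: xyz + sdf, with the first half marked as surface samples
xyz = torch.rand(3000, 3)
sdf = torch.randn(3000, 1)
sdf[:1500] = 0.0
data = torch.cat([xyz, sdf], dim=-1)
pc, query_xyz, query_sdf = labeled_sampling(data, subsample=200, pc_size=128)
```

Each `__getitem__` therefore yields one point cloud plus a balanced set of signed query points.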


================================================
FILE: dataloader/modulation_loader.py
================================================
#!/usr/bin/env python3

import time 
import logging
import os
import random
import torch
import torch.utils.data
from diff_utils.helpers import * 

import pandas as pd 
import numpy as np
import csv, json

from tqdm import tqdm

class ModulationLoader(torch.utils.data.Dataset):
    def __init__(self, data_path, pc_path=None, split_file=None, pc_size=None):
        super().__init__()

        self.conditional = pc_path is not None 

        if self.conditional:
            self.modulations, pc_paths = self.load_modulations(data_path, pc_path, split_file)
        else:
            self.modulations = self.unconditional_load_modulations(data_path, split_file)
        #self.modulations = self.modulations[0:8]
        #pc_paths = pc_paths[0:8]

        print("data shape, dataset len: ", self.modulations[0].shape, len(self.modulations))
        #assert args.batch_size <= len(self.modulations)
        
        if self.conditional:
            print("loading ground truth point clouds...")            
            lst = []
            with tqdm(pc_paths) as pbar:
                for i, f in enumerate(pc_paths):
                    pbar.set_description("Point clouds loaded: {}/{}".format(i, len(pc_paths)))
                    lst.append(sample_pc(f, pc_size))
            self.point_clouds = lst

            assert len(self.point_clouds) == len(self.modulations)
        
        
    def __len__(self):
        return len(self.modulations)

    def __getitem__(self, index):

        pc = self.point_clouds[index] if self.conditional else False
        return {
            "point_cloud" : pc,
            "latent" : self.modulations[index]         
        }
        

    def load_modulations(self, data_source, pc_source, split, f_name="latent.txt", add_flip_augment=False, return_filepaths=True):
        #split = json.load(open(split))
        files = []
        filepaths = [] # return filepaths for loading pcs
        for dataset in split: # dataset = "acronym" 
            for class_name in split[dataset]:
                for instance_name in split[dataset][class_name]:

                    if add_flip_augment:
                        for idx in range(4):
                            instance_filename = os.path.join(data_source, class_name, instance_name, "latent_{}.txt".format(idx))
                            if not os.path.isfile(instance_filename):
                                print("Requested non-existent file '{}'".format(instance_filename))
                                continue
                            files.append( torch.from_numpy(np.loadtxt(instance_filename)).float() )
                        filepaths.append( os.path.join(pc_source, dataset, class_name, instance_name, "sdf_data.csv") )

                    else:
                        instance_filename = os.path.join(data_source, class_name, instance_name, f_name)
                        if not os.path.isfile(instance_filename):
                            #print("Requested non-existent file '{}'".format(instance_filename))
                            continue
                        files.append( torch.from_numpy(np.loadtxt(instance_filename)).float() )
                        filepaths.append( os.path.join(pc_source, dataset, class_name, instance_name, "sdf_data.csv") )
        if return_filepaths:
            return files, filepaths
        return files

    def unconditional_load_modulations(self, data_source, split, f_name="latent.txt", add_flip_augment=False):
        files = []
        for dataset in split: # dataset = "acronym" 
            for class_name in split[dataset]:
                for instance_name in split[dataset][class_name]:

                    if add_flip_augment:
                        for idx in range(4):
                            instance_filename = os.path.join(data_source, class_name, instance_name, "latent_{}.txt".format(idx))
                            if not os.path.isfile(instance_filename):
                                print("Requested non-existent file '{}'".format(instance_filename))
                                continue
                            files.append( torch.from_numpy(np.loadtxt(instance_filename)).float() )

                    else:
                        instance_filename = os.path.join(data_source, class_name, instance_name, f_name)
                        if not os.path.isfile(instance_filename):
                            continue
                        files.append( torch.from_numpy(np.loadtxt(instance_filename)).float() )
        return files

================================================
FILE: dataloader/pc_loader.py
================================================
#!/usr/bin/env python3

import time 
import logging
import os
import random
import torch
import torch.utils.data
from . import base 
from tqdm import tqdm

import pandas as pd 
import csv

class PCloader(base.Dataset):

    def __init__(
        self,
        data_source,
        split_file, # json filepath which contains train/test classes and meshes 
        pc_size=1024,
        return_filename=False
    ):

        self.pc_size = pc_size
        self.return_filename = return_filename

        self.pc_paths = self.get_instance_filenames(data_source, split_file)
        #self.pc_paths = self.pc_paths[:5] # debugging leftover: uncomment to load only a small subset
        print("loading {} point clouds into memory...".format(len(self.pc_paths)))
        lst = []
        with tqdm(self.pc_paths) as pbar:
            for i, f in enumerate(pbar):
                pbar.set_description("Files loaded: {}/{}".format(i, len(self.pc_paths)))
                lst.append(self.sample_pc(f, pc_size))
        self.point_clouds = lst

        #print("each pc shape: ", self.point_clouds[0].shape)

    def get_all_files(self):
        return self.point_clouds, self.pc_paths 
    
    def __getitem__(self, idx): 
        if self.return_filename:
            return self.point_clouds[idx], self.pc_paths[idx]
        else:
            return self.point_clouds[idx]


    def __len__(self):
        return len(self.point_clouds)


    def sample_pc(self, f, samp=1024): 
        '''
        f: path to csv file
        '''
        # data = torch.from_numpy(np.loadtxt(f, delimiter=',')).float()
        data = torch.from_numpy(pd.read_csv(f, sep=',',header=None).values).float()
        pc = data[data[:,-1]==0][:,:3]
        pc_idx = torch.randperm(pc.shape[0])[:samp] 
        pc = pc[pc_idx]
        #print("pc shape, dtype: ", pc.shape, pc.dtype) # [1024,3], torch.float32
        #pc = normalize_pc(pc)
        #print("pc shape: ", pc.shape, pc.max(), pc.min())
        return pc



    


================================================
FILE: dataloader/sdf_loader.py
================================================
#!/usr/bin/env python3

import time 
import logging
import os
import random
import torch
import torch.utils.data
from . import base 

import pandas as pd 
import numpy as np
import csv, json

from tqdm import tqdm

class SdfLoader(base.Dataset):

    def __init__(
        self,
        data_source, # path to points sampled around surface
        split_file, # json filepath which contains train/test classes and meshes 
        grid_source=None, # path to grid points; grid refers to sampling throughout the unit cube instead of only around the surface; necessary for preventing artifacts in empty space
        samples_per_mesh=16000,
        pc_size=1024,
        modulation_path=None # used for third stage of training; needs to be set in config file when some modulation training had been filtered
    ):
 
        self.samples_per_mesh = samples_per_mesh
        self.pc_size = pc_size
        self.gt_files = self.get_instance_filenames(data_source, split_file, filter_modulation_path=modulation_path)

        # keep all meshes by default; lower `subsample` to train on a subset
        subsample = len(self.gt_files) 
        self.gt_files = self.gt_files[0:subsample]

        self.grid_source = grid_source
        #print("grid source: ", grid_source)
    
        if grid_source:
            self.grid_files = self.get_instance_filenames(grid_source, split_file, gt_filename="grid_gt.csv", filter_modulation_path=modulation_path)
            self.grid_files = self.grid_files[0:subsample]
            lst = []
            with tqdm(self.grid_files) as pbar:
                for i, f in enumerate(pbar):
                    pbar.set_description("Grid files loaded: {}/{}".format(i, len(self.grid_files)))
                    lst.append(torch.from_numpy(pd.read_csv(f, sep=',',header=None).values))
            self.grid_files = lst
            
            assert len(self.grid_files) == len(self.gt_files)


        # load all csv files first 
        print("loading all {} files into memory...".format(len(self.gt_files)))
        lst = []
        with tqdm(self.gt_files) as pbar:
            for i, f in enumerate(pbar):
                pbar.set_description("Files loaded: {}/{}".format(i, len(self.gt_files)))
                lst.append(torch.from_numpy(pd.read_csv(f, sep=',',header=None).values))
        self.gt_files = lst


    def __getitem__(self, idx): 

        near_surface_count = int(self.samples_per_mesh*0.7) if self.grid_source else self.samples_per_mesh

        pc, sdf_xyz, sdf_gt =  self.labeled_sampling(self.gt_files[idx], near_surface_count, self.pc_size, load_from_path=False)
        

        if self.grid_source is not None:
            grid_count = self.samples_per_mesh - near_surface_count
            _, grid_xyz, grid_gt = self.labeled_sampling(self.grid_files[idx], grid_count, pc_size=0, load_from_path=False)
            # each getitem is one batch so no batch dimension, only N, 3 for xyz or N for gt 
            # for 16000 points per batch, near surface is 11200, grid is 4800
            #print("shapes: ", pc.shape,  sdf_xyz.shape, sdf_gt.shape, grid_xyz.shape, grid_gt.shape)
            sdf_xyz = torch.cat((sdf_xyz, grid_xyz))
            sdf_gt = torch.cat((sdf_gt, grid_gt))
            #print("shapes after adding grid: ", pc.shape, sdf_xyz.shape, sdf_gt.shape, grid_xyz.shape, grid_gt.shape)

        data_dict = {
                    "xyz":sdf_xyz.float().squeeze(),
                    "gt_sdf":sdf_gt.float().squeeze(), 
                    "point_cloud":pc.float().squeeze(),
                    }

        return data_dict

    def __len__(self):
        return len(self.gt_files)
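
When a grid source is set, `__getitem__` fills 70% of the per-mesh sample budget with near-surface SDF samples and the remainder with grid samples, matching the inline comment above:

```python
samples_per_mesh = 16000
near_surface_count = int(samples_per_mesh * 0.7)    # near-surface SDF samples
grid_count = samples_per_mesh - near_surface_count  # uniform grid samples
```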



    


================================================
FILE: diff_utils/helpers.py
================================================
import math
import torch
import torch.nn.functional as F
import numpy as np 
import pandas as pd 
import random 
from inspect import isfunction
import os
import json 
#import open3d as o3d


def get_split_filenames(data_source, split_file, f_name="sdf_data.csv"):
    split = json.load(open(split_file))
    csvfiles = []
    for dataset in split: # e.g. "acronym" "shapenet"
        for class_name in split[dataset]:
            for instance_name in split[dataset][class_name]:
                instance_filename = os.path.join(data_source, dataset, class_name, instance_name, f_name)
                if not os.path.isfile(instance_filename):
                    print("Requested non-existent file '{}'".format(instance_filename))
                    continue
                csvfiles.append(instance_filename)
    return csvfiles

def sample_pc(f, samp=1024, add_flip_augment=False): 
    '''
    f: path to csv file
    '''
    data = torch.from_numpy(pd.read_csv(f, sep=',',header=None).values).float()
    pc = data[data[:,-1]==0][:,:3]
    pc_idx = torch.randperm(pc.shape[0])[:samp] 
    pc = pc[pc_idx]

    if add_flip_augment:
        pcs = []
        flip_axes = torch.tensor([[1,1,1],[-1,1,1],[1,-1,1],[1,1,-1]], device=pc.device)
        for idx, axis in enumerate(flip_axes):
            pcs.append(pc * axis)
        return pcs


    return pc
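
With `add_flip_augment=True`, `sample_pc` returns four copies of the cloud: the original plus one mirror per axis, produced by element-wise sign flips. A small sketch of that augmentation:

```python
import torch

pc = torch.rand(1024, 3)
flip_axes = torch.tensor([[1, 1, 1], [-1, 1, 1], [1, -1, 1], [1, 1, -1]])
pcs = [pc * axis for axis in flip_axes]  # original plus one mirror per axis
```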

def perturb_point_cloud(pc, perturb, pc_size=None, crop_percent=0.25):
    '''
    if pc_size is None, return entire pc; else return with shape of pc_size
    '''
    assert perturb in [None, "partial", "noisy"]
    if perturb is None:
        pc_idx = torch.randperm(pc.shape[1])[:pc_size] 
        pc = pc[:,pc_idx]   
        #print("pc shape: ", pc.shape)
        return pc
    elif perturb == "partial":
        return crop_pc(pc, crop_percent, pc_size)
    elif perturb == "noisy":
        return jitter_pc(pc, pc_size)

def fps(data, number):
    '''
        data B N 3
        number int
    '''
    # NOTE: depends on the Pointnet2 CUDA ops (e.g. `from pointnet2_ops import pointnet2_utils`),
    # which are not imported at module level; only reached when crop_pc is called with a list crop
    fps_idx = pointnet2_utils.furthest_point_sample(data, number) 
    fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()
    return fps_data

def crop_pc(xyz, crop, pc_size=None, fixed_points = None, padding_zeros = False):
    '''
     crop the point cloud given a randomly selected view
     input point cloud: xyz, with shape (B, N, 3)
     crop: float, percentage of points to crop out (e.g. 0.25 means keep 75% of points)
     pc_size: integer value, how many points to return; None if return all (all meaning xyz size * crop)
    '''
    

    if pc_size is not None:
        xyz = xyz[:, torch.randperm(xyz.shape[1])[:pc_size] ]
    
    _,n,c = xyz.shape
    device = xyz.device
        
    crop = int(xyz.shape[1]*crop)
    #print("pc shape: ", xyz.shape, crop)

    
    assert c == 3
    if crop == n:
        return xyz # , None
        
    INPUT = []
    CROP = []
    for points in xyz:
        if isinstance(crop,list):
            num_crop = random.randint(crop[0],crop[1])
        else:
            num_crop = crop

        points = points.unsqueeze(0)

        if fixed_points is None:       
            center = F.normalize(torch.randn(1,1,3, device=device),p=2,dim=-1)
        else:
            if isinstance(fixed_points,list):
                fixed_point = random.sample(fixed_points,1)[0]
            else:
                fixed_point = fixed_points
            center = fixed_point.reshape(1,1,3).to(device)

        distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1)  # 1 1 2048

        idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048

        if padding_zeros:
            input_data = points.clone()
            input_data[0, idx[:num_crop]] =  input_data[0,idx[:num_crop]] * 0

        else:
            input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3

        crop_data =  points.clone()[0, idx[:num_crop]].unsqueeze(0)

        if isinstance(crop,list):
            INPUT.append(fps(input_data,2048))
            CROP.append(fps(crop_data,2048))
        else:
            INPUT.append(input_data)
            CROP.append(crop_data)

    input_data = torch.cat(INPUT,dim=0)# B N 3
    crop_data = torch.cat(CROP,dim=0)# B M 3

    return input_data.contiguous() #, crop_data.contiguous()
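
The core of `crop_pc` is: pick a random unit view direction, sort points by distance to it, and drop the nearest fraction to simulate a partial scan. A self-contained single-cloud sketch of that idea (`crop_view` is an illustrative helper, not part of the repo):

```python
import torch
import torch.nn.functional as F

def crop_view(points, crop_frac):
    # drop the crop_frac of points nearest a random view direction (cf. crop_pc)
    n = points.shape[0]
    num_crop = int(n * crop_frac)
    center = F.normalize(torch.randn(1, 3), p=2, dim=-1)   # random viewpoint on the unit sphere
    dist = torch.norm(points - center, p=2, dim=-1)
    keep = torch.argsort(dist)[num_crop:]                  # farthest points survive
    return points[keep]

pc = torch.rand(2048, 3)
partial = crop_view(pc, 0.25)   # keeps 75% of the points
```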



def visualize_pc(pc):
    # requires open3d; re-enable the commented `import open3d as o3d` at the top of this file
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(pc.reshape(-1,3))
    o3d.io.write_point_cloud("./pc.ply", pcd)
    #o3d.visualization.draw_geometries([pcd])

def jitter_pc(pc, pc_size=None, sigma=0.05, clip=0.1):
    device = pc.device
    pc += torch.clamp(sigma*torch.randn(*pc.shape, device=device), -1*clip, clip)
    if pc_size is not None:
        if len(pc.shape) == 3: # B, N, 3
            pc = pc[:, torch.randperm(pc.shape[1])[:pc_size] ]
        else: # N, 3
            pc = pc[torch.randperm(pc.shape[0])[:pc_size] ]

    return pc
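
`jitter_pc` adds Gaussian noise clamped to ±clip, so no coordinate moves by more than clip. A sketch of the perturbation on its own:

```python
import torch

def jitter(pc, sigma=0.05, clip=0.1):
    # clamped Gaussian perturbation, as in jitter_pc
    return pc + torch.clamp(sigma * torch.randn_like(pc), -clip, clip)

pc = torch.zeros(2048, 3)
noisy = jitter(pc)   # every coordinate stays within [-0.1, 0.1]
```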


def normalize_pc(pc):
    pc -= torch.mean(pc, axis=0)
    m = torch.max(torch.sqrt(torch.sum(pc**2, axis=1)))
    #bbox_length = torch.sqrt( torch.sum((torch.max(pc, axis=0)[0] - torch.min(pc, axis=0)[0])**2) )
    pc /= m
    return pc
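
`normalize_pc` centers the cloud at its mean and scales so the farthest point lands on the unit sphere. The same two steps, written out:

```python
import torch

pc = torch.rand(500, 3) * 10 + 5   # arbitrary offset and scale
pc = pc - pc.mean(dim=0)           # center at the origin
pc = pc / pc.norm(dim=1).max()     # farthest point -> radius 1
```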

def save_model(iters, model, optimizer, loss, path):
    torch.save({'iters': iters,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss}, 
                path)

def load_model(model, optimizer, path):
    checkpoint = torch.load(path)
    
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        optimizer = None
    
    model.load_state_dict(checkpoint['model_state_dict'])
    loss = checkpoint['loss']
    iters = checkpoint['iters']
    print("loading from iter {}...".format(iters))
    return iters, model, optimizer, loss


def save_code_to_conf(conf_dir):
    path = os.path.join(conf_dir, "code")
    os.makedirs(path, exist_ok=True)
    for folder in ["utils", "models", "diff_utils", "dataloader", "metrics"]: 
        os.makedirs(os.path.join(path, folder), exist_ok=True)
        os.system("""cp -r ./{0}/* "{1}" """.format(folder, os.path.join(path, folder)))

    # other files
    os.system("""cp *.py "{}" """.format(path))

class ScheduledOpt:
    '''
    Optimizer wrapper that implements a warmup learning-rate schedule.
    Usage: optimizer = ScheduledOpt(4000, torch.optim.Adam(model.parameters(), lr=0))
    '''
    def __init__(self, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self._rate = 0
    
    # def state_dict(self):
    #     """Returns the state of the warmup scheduler as a :class:`dict`.
    #     It contains an entry for every variable in self.__dict__ which
    #     is not the optimizer.
    #     """
    #     return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
    
    # def load_state_dict(self, state_dict):
    #     """Loads the warmup scheduler's state.
    #     Arguments:
    #         state_dict (dict): warmup scheduler state. Should be an object returned
    #             from a call to :meth:`state_dict`.
    #     """
    #     self.__dict__.update(state_dict) 
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        #print("rate: ",rate)

    def zero_grad(self):
        self.optimizer.zero_grad()
        
    def rate(self, step = None):
        "Linear warmup from 0 to 3e-4 over `warmup` steps, then inverse-sqrt decay"
        if step is None:
            step = self._step

        warm_schedule = torch.linspace(0, 3e-4, self.warmup, dtype = torch.float64)
        if step < self.warmup:
            return warm_schedule[step]
        else:
            return 3e-4 / (math.sqrt(step-self.warmup+1))
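
The schedule ramps linearly from 0 to 3e-4 over the warmup window, then decays as 1/sqrt(steps past warmup). A standalone mirror of `rate` (a sketch, not the class itself):

```python
import math
import torch

def rate(step, warmup=4000, peak=3e-4):
    # linear warmup to `peak`, then inverse-sqrt decay (mirrors ScheduledOpt.rate)
    if step < warmup:
        return torch.linspace(0, peak, warmup, dtype=torch.float64)[step].item()
    return peak / math.sqrt(step - warmup + 1)
```

The rate is continuous at the boundary: the last warmup step and the first decay step both evaluate to the peak.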

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

def cycle(dl):
    while True:
        for data in dl:
            yield data

def has_int_squareroot(num):
    # exact integer check; avoids float rounding from math.sqrt
    return (math.isqrt(num) ** 2) == num

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

def convert_image_to(img_type, image):
    if image.mode != img_type:
        return image.convert(img_type)
    return image

# normalization functions

#from 0,1 to -1,1
def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

# from -1,1 to 0,1
def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5

# from any batch to [0,1]
# f should have shape (batch, -1)
def normalize_to_zero_to_one(f):
    f -= f.min(1, keepdim=True)[0]
    f /= f.max(1, keepdim=True)[0]
    return f


# extract the appropriate t index for a batch of indices
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
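
`extract` looks up one schedule value per batch element at its timestep `t`, then reshapes to `(b, 1, ..., 1)` so it broadcasts against `x`. For example:

```python
import torch

def extract(a, t, x_shape):
    # per-sample lookup into a 1-D schedule `a`, broadcastable over x
    b = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

betas = torch.linspace(0.1, 0.2, 10)
t = torch.tensor([0, 9])                    # one timestep per batch element
out = extract(betas, t, x_shape=(2, 64, 3)) # shape (2, 1, 1)
```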

def linear_beta_schedule(timesteps):
    #print("using LINEAR schedule")
    scale = 1000 / timesteps
    beta_start = scale * 0.0001 
    beta_end = scale * 0.02 
    return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)

def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps, dtype = torch.float64)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])

    return torch.clip(betas, 0, 0.999)
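
The linear schedule rescales the standard DDPM endpoints (1e-4 to 0.02 at 1000 steps) by 1000/timesteps so shorter chains accumulate a comparable amount of noise. A quick property check on a standalone copy, confirming the cumulative alpha product decays to near zero by the final step:

```python
import torch

def linear_beta_schedule(timesteps):
    # DDPM endpoints 1e-4..0.02, rescaled for chains shorter or longer than 1000 steps
    scale = 1000 / timesteps
    return torch.linspace(scale * 0.0001, scale * 0.02, timesteps, dtype=torch.float64)

betas = linear_beta_schedule(1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # signal retention per step
```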



================================================
FILE: diff_utils/model_utils.py
================================================
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum 

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape

from .pointnet.pointnet_classifier import PointNetClassifier
from .pointnet.conv_pointnet import ConvPointnet
from .pointnet.dgcnn import DGCNN

from .helpers import *

class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5, stable = False):
        super().__init__()
        self.eps = eps
        self.stable = stable
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        if self.stable:
            x = x / x.amax(dim = -1, keepdim = True).detach()

        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = -1, keepdim = True)
        return (x - mean) * (var + self.eps).rsqrt() * self.g

# mlp

class MLP(nn.Module):
    def __init__(
        self,
        dim_in,
        dim_out,
        *,
        expansion_factor = 2.,
        depth = 2,
        norm = False,
    ):
        super().__init__()
        hidden_dim = int(expansion_factor * dim_out)
        norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()

        layers = [nn.Sequential(
            nn.Linear(dim_in, hidden_dim),
            nn.SiLU(),
            norm_fn()
        )]

        for _ in range(depth - 1):
            layers.append(nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                norm_fn()
            ))

        layers.append(nn.Linear(hidden_dim, dim_out))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x.float())

# relative positional bias for causal transformer

class RelPosBias(nn.Module):
    def __init__(
        self,
        heads = 8,
        num_buckets = 32,
        max_distance = 128,
    ):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        num_buckets = 32,
        max_distance = 128
    ):
        n = -relative_position
        n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
        return torch.where(is_small, n, val_if_large)

    def forward(self, i, j, *, device):
        q_pos = torch.arange(i, dtype = torch.long, device = device)
        k_pos = torch.arange(j, dtype = torch.long, device = device)
        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        return rearrange(values, 'i j h -> h i j')

# feedforward

class SwiGLU(nn.Module):
    """ used successfully in https://arxiv.org/abs/2204.0231 """
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return x * F.silu(gate)

def FeedForward(
    dim,
    out_dim = None,
    mult = 4,
    dropout = 0.,
    post_activation_norm = False
):
    """ post-activation norm https://arxiv.org/abs/2110.09456 """

    #print("dropout: ", dropout)
    out_dim = default(out_dim, dim)
    #print("out_dim: ", out_dim)
    inner_dim = int(mult * dim)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, inner_dim * 2, bias = False),
        SwiGLU(),
        LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, out_dim, bias = False)
    )
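
Note the inner `Linear` projects to `inner_dim * 2` because `SwiGLU` splits its input in half and gates one half with SiLU of the other, halving the width again. The gating on its own:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 8)   # (batch, seq, inner_dim * 2)
h, gate = x.chunk(2, dim=-1)
out = h * F.silu(gate)      # SwiGLU: halves the last dimension
```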

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
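The embedding above can be reproduced for a single scalar position in pure Python (`sinusoidal_pos_emb` is a hypothetical standalone sketch of the module's forward pass):

```python
import math

def sinusoidal_pos_emb(t, dim):
    # Half sines, half cosines, with frequencies spaced geometrically
    # from 1 down to 1/10000, mirroring SinusoidalPosEmb.forward.
    half_dim = dim // 2
    scale = math.log(10000) / (half_dim - 1)
    freqs = [math.exp(-scale * i) for i in range(half_dim)]
    args = [t * f for f in freqs]
    return [math.sin(a) for a in args] + [math.cos(a) for a in args]
```

At `t = 0` the sine half is all zeros and the cosine half all ones, giving a distinctive embedding for the first timestep.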

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        kv_dim=None,
        *,
        out_dim = None,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        causal = False,
        rotary_emb = None,
        pb_relax_alpha = 128
    ):
        super().__init__()
        self.pb_relax_alpha = pb_relax_alpha
        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)

        self.heads = heads
        inner_dim = dim_head * heads
        kv_dim = default(kv_dim, dim)

        self.causal = causal

        self.norm = LayerNorm(dim)

        self.dropout = nn.Dropout(dropout)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(kv_dim, dim_head * 2, bias = False)

        self.rotary_emb = rotary_emb

        out_dim = default(out_dim, dim)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, out_dim, bias = False),
            LayerNorm(out_dim)
        )

    def forward(self, x, context=None, mask = None, attn_bias = None):
        b, n, device = *x.shape[:2], x.device

        context = default(context, x) #self attention if context is None 

        x = self.norm(x)
        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        q = q * self.scale

        # rotary embeddings

        if exists(self.rotary_emb):
            q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))

        # add null key / value for classifier free guidance in prior net

        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # calculate query / key similarities

        sim = einsum('b h i d, b j d -> b h i j', q, k)

        # relative positional encoding (T5 style)
        #print("attn bias, sim shapes: ", attn_bias.shape, sim.shape)
        if exists(attn_bias):
            sim = sim + attn_bias

        # masking

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            mask = F.pad(mask, (1, 0), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, max_neg_value)

        # attention

        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
        sim = sim * self.pb_relax_alpha

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        # aggregate values

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
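The PB-Relax steps in `forward` above (queries pre-scaled by `1/alpha`, row max subtracted, scale restored before softmax) amount to a numerically stable softmax over the full-scale logits. A pure-Python sketch of one attention row, assuming `scores` already carry the `1/alpha` query scaling (`pb_relax_softmax` is a hypothetical helper for illustration):

```python
import math

def pb_relax_softmax(scores, alpha=128):
    # Subtract the row max at the reduced scale, then restore the
    # scale before exponentiating; equivalent to softmax(alpha * s).
    m = max(scores)
    shifted = [(s - m) * alpha for s in scores]
    exps = [math.exp(s) for s in shifted]
    total = sum(exps)
    return [e / total for e in exps]
```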



================================================
FILE: diff_utils/pointnet/__init__.py
================================================


================================================
FILE: diff_utils/pointnet/conv_pointnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

from torch_scatter import scatter_mean, scatter_max


class ConvPointnet(nn.Module):
    ''' PointNet-based encoder network with ResNet blocks for each point.
        The number of input points is fixed.
    
    Args:
        c_dim (int): dimension of latent code c
        dim (int): input points dimension
        hidden_dim (int): hidden dimension of the network
        scatter_type (str): feature aggregation when doing local pooling
        unet (bool): whether to use U-Net
        unet_kwargs (dict): U-Net parameters
        plane_resolution (int): defined resolution for plane feature
        plane_type (list): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
        padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
        n_blocks (int): number of ResnetBlockFC blocks
    '''

    def __init__(self, c_dim=512, dim=3, hidden_dim=128, scatter_type='max', 
                 unet=True, unet_kwargs={"depth": 4, "merge_mode": "concat", "start_filts": 32}, 
                 plane_resolution=64, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
        super().__init__()
        self.c_dim = c_dim

        self.fc_pos = nn.Linear(dim, 2*hidden_dim)
        self.blocks = nn.ModuleList([
            ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
        ])
        self.fc_c = nn.Linear(hidden_dim, c_dim)

        self.actvn = nn.ReLU()
        self.hidden_dim = hidden_dim

        if unet:
            self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
        else:
            self.unet = None

        self.reso_plane = plane_resolution
        self.plane_type = plane_type
        self.padding = padding

        if scatter_type == 'max':
            self.scatter = scatter_max
        elif scatter_type == 'mean':
            self.scatter = scatter_mean


    # takes in "p": point cloud and "query": sdf_xyz 
    # sample plane features for unlabeled_query as well 
    def forward(self, p, query):
        batch_size, T, D = p.size()

        # acquire the index for each point
        coord = {}
        index = {}
        if 'xz' in self.plane_type:
            coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
            index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
        if 'xy' in self.plane_type:
            coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
            index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
        if 'yz' in self.plane_type:
            coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
            index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)

        
        net = self.fc_pos(p)

        net = self.blocks[0](net)
        for block in self.blocks[1:]:
            pooled = self.pool_local(coord, index, net)
            net = torch.cat([net, pooled], dim=2)
            net = block(net)

        c = self.fc_c(net)

        fea = {}
        plane_feat_sum = 0
        if 'xz' in self.plane_type:
            fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
            plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')
        if 'xy' in self.plane_type:
            fea['xy'] = self.generate_plane_features(p, c, plane='xy')
            plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')
        if 'yz' in self.plane_type:
            fea['yz'] = self.generate_plane_features(p, c, plane='yz')
            plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')

        return plane_feat_sum.transpose(2,1)


    def normalize_coordinate(self, p, padding=0.1, plane='xz'):
        ''' Normalize coordinate to [0, 1] for unit cube experiments

        Args:
            p (tensor): point
            padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
            plane (str): plane feature type, ['xz', 'xy', 'yz']
        '''
        if plane == 'xz':
            xy = p[:, :, [0, 2]]
        elif plane =='xy':
            xy = p[:, :, [0, 1]]
        else:
            xy = p[:, :, [1, 2]]

        xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
        xy_new = xy_new + 0.5 # range (0, 1)

        # if there are outliers out of the range
        if xy_new.max() >= 1:
            xy_new[xy_new >= 1] = 1 - 10e-6
        if xy_new.min() < 0:
            xy_new[xy_new < 0] = 0.0
        return xy_new
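The normalization above can be sketched for one scalar coordinate: shrink by the padded cube size, shift to (0, 1), and clamp outliers. `normalize_coord` below is a hypothetical standalone helper mirroring the tensor code:

```python
def normalize_coord(v, padding=0.1):
    # v is roughly in [-0.5 - padding/2, 0.5 + padding/2];
    # map it into [0, 1) with the same epsilon as the tensor version.
    v = v / (1 + padding + 1e-5) + 0.5
    return min(max(v, 0.0), 1 - 1e-5)
```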


    def coordinate2index(self, x, reso):
        ''' Convert normalized 2D coordinates in [0, 1) to flattened
            plane cell indices in [0, reso**2)

        Args:
            x (tensor): normalized coordinate
            reso (int): defined plane resolution
        '''
        x = (x * reso).long()
        index = x[:, :, 0] + reso * x[:, :, 1]
        index = index[:, None, :]
        return index
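The flattening above is a row-major 2D-to-1D index conversion. A sketch for a single normalized point (`coordinate2index_scalar` is a hypothetical scalar analogue of the method above):

```python
def coordinate2index_scalar(u, v, reso):
    # Discretize each normalized axis in [0, 1) to the plane
    # resolution, then flatten into a single cell index.
    iu, iv = int(u * reso), int(v * reso)
    return iu + reso * iv
```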


    # xy holds the normalized per-plane coordinates of the point cloud;
    # its keys match those of index, so only the keys are actually used here
    def pool_local(self, xy, index, c):
        bs, fea_dim = c.size(0), c.size(2)
        keys = xy.keys()

        c_out = 0
        for key in keys:
            # scatter plane features from points
            fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
            if self.scatter == scatter_max:
                fea = fea[0]
            # gather feature back to points
            fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
            c_out += fea
        return c_out.permute(0, 2, 1)


    def generate_plane_features(self, p, c, plane='xz'):
        # acquire indices of features in plane
        xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
        index = self.coordinate2index(xy, self.reso_plane)

        # scatter plane features from points
        fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
        c = c.permute(0, 2, 1) # B x 512 x T
        fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
        fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparse matrix (B x 512 x reso x reso)

        # process the plane features with UNet
        if self.unet is not None:
            fea_plane = self.unet(fea_plane)

        return fea_plane


    # sample_plane_feature function copied from /src/conv_onet/models/decoder.py
    # uses values from plane_feature and pixel locations from vgrid to interpolate feature
    def sample_plane_feature(self, query, plane_feature, plane):
        xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)
        xy = xy[:, :, None].float()
        vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
        sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)
        return sampled_feat


def conv3x3(in_channels, out_channels, stride=1, 
            padding=1, bias=True, groups=1):    
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)

def upconv2x2(in_channels, out_channels, mode='transpose'):
    if mode == 'transpose':
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2)
    else:
        # bilinear upsampling keeps the channel count, so a 1x1
        # convolution maps in_channels to out_channels afterwards
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2),
            conv1x1(in_channels, out_channels))

def conv1x1(in_channels, out_channels, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        groups=groups,
        stride=1)


class DownConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 MaxPool.
    A ReLU activation follows each convolution.
    """
    def __init__(self, in_channels, out_channels, pooling=True):
        super(DownConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(self.in_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        before_pool = x
        if self.pooling:
            x = self.pool(x)
        return x, before_pool


class UpConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 UpConvolution.
    A ReLU activation follows each convolution.
    """
    def __init__(self, in_channels, out_channels, 
                 merge_mode='concat', up_mode='transpose'):
        super(UpConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(self.in_channels, self.out_channels, 
            mode=self.up_mode)

        if self.merge_mode == 'concat':
            self.conv1 = conv3x3(
                2*self.out_channels, self.out_channels)
        else:
            # num of input channels to conv2 is same
            self.conv1 = conv3x3(self.out_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)


    def forward(self, from_down, from_up):
        """ Forward pass
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: upconv'd tensor from the decoder pathway
        """
        from_up = self.upconv(from_up)
        if self.merge_mode == 'concat':
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x


class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597

    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).

    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by upmode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        When upmode='transpose', this channel halving is handled by
        the transpose convolution itself.
    """

    def __init__(self, num_classes, in_channels=3, depth=5, 
                 start_filts=64, up_mode='transpose', 
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of MaxPools in the U-Net.
            start_filts: int, number of convolutional filters for the 
                first conv.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for bilinear
                upsampling followed by a 1x1 convolution.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))
    
        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for"
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(up_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts*(2**i)
            pooling = True if i < depth-1 else False

            down_conv = DownConv(ins, outs, pooling=pooling)
            self.down_convs.append(down_conv)

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth-1):
            ins = outs
            outs = ins // 2
            up_conv = UpConv(ins, outs, up_mode=up_mode,
                merge_mode=merge_mode)
            self.up_convs.append(up_conv)

        # add the list of modules to current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

        self.conv_final = conv1x1(outs, self.num_classes)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            init.constant_(m.bias, 0)


    def reset_params(self):
        for i, m in enumerate(self.modules()):
            self.weight_init(m)


    def forward(self, x):
        encoder_outs = []
        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)
        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            x = module(before_pool, x)
        
        # No softmax is applied here, so use nn.CrossEntropyLoss
        # in your training script; that loss already includes
        # the softmax.
        x = self.conv_final(x)
        return x
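The channel bookkeeping in `UNet.__init__` above (encoder doubles filters per level, decoder halves them) can be traced with a small helper. `unet_channels` is a hypothetical sketch, shown with the kwargs `ConvPointnet` passes (`in_channels=c_dim=512`, `depth=4`, `start_filts=32`):

```python
def unet_channels(in_channels, depth, start_filts):
    # Mirror the (ins, outs) progression of the encoder and decoder
    # construction loops in UNet.__init__.
    down, outs = [], in_channels
    for i in range(depth):
        ins, outs = (in_channels if i == 0 else outs), start_filts * 2 ** i
        down.append((ins, outs))
    up = []
    for _ in range(depth - 1):
        ins, outs = outs, outs // 2
        up.append((ins, outs))
    return down, up
```

The decoder ends back at `start_filts` channels, which is what `conv_final` then maps to `num_classes`.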

# Resnet Blocks
class ResnetBlockFC(nn.Module):
    ''' Fully connected ResNet Block class.
    Args:
        size_in (int): input dimension
        size_out (int): output dimension
        size_h (int): hidden dimension
    '''

    def __init__(self, size_in, size_out=None, size_h=None):
        super().__init__()
        # Attributes
        if size_out is None:
            size_out = size_in

        if size_h is None:
            size_h = min(size_in, size_out)

        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out
        # Submodules
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)
        self.actvn = nn.ReLU()

        if size_in == size_out:
            self.shortcut = None
        else:
            self.shortcut = nn.Linear(size_in, size_out, bias=False)
        # Initialization
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x):
        net = self.fc_0(self.actvn(x))
        dx = self.fc_1(self.actvn(net))

        if self.shortcut is not None:
            x_s = self.shortcut(x)
        else:
            x_s = x

        return x_s + dx
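Because `fc_1.weight` is zero-initialized above, the residual branch contributes (almost) nothing at initialization and the block starts near the identity (up to `fc_1`'s bias). A pure-Python sketch of the forward pass for one sample, with hypothetical hand-written weights (`resnet_block_fc` is an illustrative helper, not repo code):

```python
def resnet_block_fc(x, w0, b0, w1, b1, shortcut=None):
    # Two ReLU + linear layers plus a (possibly linear) shortcut,
    # mirroring ResnetBlockFC.forward: out = shortcut(x) + fc_1(relu(fc_0(relu(x)))).
    relu = lambda v: [max(0.0, vi) for vi in v]
    matvec = lambda W, v, b: [sum(wij * vj for wij, vj in zip(row, v)) + bi
                              for row, bi in zip(W, b)]
    net = matvec(w0, relu(x), b0)
    dx = matvec(w1, relu(net), b1)
    x_s = matvec(shortcut, x, [0.0] * len(shortcut)) if shortcut else x
    return [xs + d for xs, d in zip(x_s, dx)]
```

With `w1` and `b1` all zero (the zero-init case with no bias), the block reduces to the identity.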

================================================
FILE: diff_utils/pointnet/dgcnn.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

def knn(x, k):
    inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1).contiguous()

    idx = pairwise_distance.topk(k=k, dim=-1)[1]
    return idx
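`knn` above ranks neighbours via the identity `-||a - b||^2 = 2 a.b - ||a||^2 - ||b||^2`, so `topk` over the (negated) squared distances picks the nearest points. A pure-Python sketch of the same expansion (`pairwise_sq_dist` is a hypothetical helper for illustration):

```python
def pairwise_sq_dist(points):
    # Negated squared distances: larger (less negative) means closer.
    def dot(a, b):
        return sum(ai * bi for ai, bi in zip(a, b))
    return [[2 * dot(a, b) - dot(a, a) - dot(b, b) for b in points]
            for a in points]
```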

def get_graph_feature(x, k=20):
    idx = knn(x, k=k)
    batch_size, num_points, _ = idx.size()
    device = x.device  # use the input tensor's device rather than assuming CUDA

    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points

    idx = idx + idx_base

    idx = idx.view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims)  
                                       # -> (batch_size*num_points, num_dims) 
                                       #   batch_size * num_points * k + range(0, batch_size*num_points)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)
    return feature
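The `idx + idx_base` trick above shifts each batch's per-point neighbour indices by `batch * num_points`, so they address a flattened `(B*N, C)` feature array. A pure-Python sketch of that offsetting (`flatten_knn_indices` is a hypothetical helper for illustration):

```python
def flatten_knn_indices(idx, num_points):
    # idx: nested list of shape (batch, points, k) with per-batch
    # indices in [0, num_points); returns flattened global indices.
    flat = []
    for batch, per_point in enumerate(idx):
        for neighbours in per_point:
            flat.extend(n + batch * num_points for n in neighbours)
    return flat
```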

class DGCNN(nn.Module):

    def __init__(
        self, 
        emb_dims=512,
        use_bn=False,
        output_channels=100 # number of categories to predict 
    ):

        super().__init__()

        if use_bn:
            print("using batch norm")
            self.bn1 = nn.BatchNorm2d(64)
            self.bn2 = nn.BatchNorm2d(64)
            self.bn3 = nn.BatchNorm2d(128)
            self.bn4 = nn.BatchNorm2d(256)
            self.bn5 = nn.BatchNorm2d(emb_dims)

            self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), self.bn1, nn.LeakyReLU(negative_slope=0.2))
            self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), self.bn2, nn.LeakyReLU(negative_slope=0.2))
            self.conv3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False), self.bn3, nn.LeakyReLU(negative_slope=0.2))
            self.conv4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, bias=False), self.bn4, nn.LeakyReLU(negative_slope=0.2))
            self.conv5 = nn.Sequential(nn.Conv2d(512, emb_dims, kernel_size=1, bias=False), self.bn5, nn.LeakyReLU(negative_slope=0.2))

        else:
            self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))
            self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))
            self.conv3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))
            self.conv4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))
            self.conv5 = nn.Sequential(nn.Conv2d(512, emb_dims, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))

        self.linear1 = nn.Linear(emb_dims*2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=0.5)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=0.5)
        self.linear3 = nn.Linear(256, output_channels)


    def forward(self, x):
        batch_size, num_dims, num_points = x.size()                 # x:      batch x   3 x num of points
        x = get_graph_feature(x)                                    # x:      batch x   6 x num of points x 20

        x1     = self.conv1(x)                                      # x1:     batch x  64 x num of points x 20
        x1_max = x1.max(dim=-1, keepdim=True)[0]                    # x1_max: batch x  64 x num of points x 1

        x2     = self.conv2(x1)                                     # x2:     batch x  64 x num of points x 20
        x2_max = x2.max(dim=-1, keepdim=True)[0]                    # x2_max: batch x  64 x num of points x 1

        x3     = self.conv3(x2)                                     # x3:     batch x 128 x num of points x 20
        x3_max = x3.max(dim=-1, keepdim=True)[0]                    # x3_max: batch x 128 x num of points x 1

        x4     = self.conv4(x3)                                     # x4:     batch x 256 x num of points x 20
        x4_max = x4.max(dim=-1, keepdim=True)[0]                    # x4_max: batch x 256 x num of points x 1
 
        x_max  = torch.cat((x1_max, x2_max, x3_max, x4_max), dim=1) # x_max:  batch x 512 x num of points x 1

        point_feat = torch.squeeze(self.conv5(x_max), dim=3)        # point feat:  batch x 512 x num of points

        #global_feat = point_feat.max(dim=2, keepdim=False)[0]       # global feat: batch x 512

        x1 = F.adaptive_max_pool1d(point_feat, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
        x2 = F.adaptive_avg_pool1d(point_feat, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
        x = torch.cat((x1, x2), 1)              # (batch_size, emb_dims*2)

        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2) # (batch_size, emb_dims*2) -> (batch_size, 512)
        x = self.dp1(x)
        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2) # (batch_size, 512) -> (batch_size, 256)
        x = self.dp2(x)
        x = self.linear3(x)                                             # (batch_size, 256) -> (batch_size, output_channels)


        return x

    def get_global_feature(self, x):
        batch_size, num_dims, num_points = x.size()                 # x:      batch x   3 x num of points
        x = get_graph_feature(x)                                    # x:      batch x   6 x num of points x 20

        x1     = self.conv1(x)                                      # x1:     batch x  64 x num of points x 20
        x1_max = x1.max(dim=-1, keepdim=True)[0]                    # x1_max: batch x  64 x num of points x 1

        x2     = self.conv2(x1)                                     # x2:     batch x  64 x num of points x 20
        x2_max = x2.max(dim=-1, keepdim=True)[0]                    # x2_max: batch x  64 x num of points x 1

        x3     = self.conv3(x2)                                     # x3:     batch x 128 x num of points x 20
        x3_max = x3.max(dim=-1, keepdim=True)[0]                    # x3_max: batch x 128 x num of points x 1

        x4     = self.conv4(x3)                                     # x4:     batch x 256 x num of points x 20
        x4_max = x4.max(dim=-1, keepdim=True)[0]                    # x4_max: batch x 256 x num of points x 1
 
        x_max  = torch.cat((x1_max, x2_max, x3_max, x4_max), dim=1) # x_max:  batch x 512 x num of points x 1

        point_feat = torch.squeeze(self.conv5(x_max), dim=3)        # point feat:  batch x 512 x num of points

        #global_feat = point_feat.max(dim=2, keepdim=False)[0]       # global feat: batch x 512

        x1 = F.adaptive_max_pool1d(point_feat, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
        x2 = F.adaptive_avg_pool1d(point_feat, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)
        x = torch.cat((x1, x2), 1)              # (batch_size, emb_dims*2)
        return x

================================================
FILE: diff_utils/pointnet/pointnet_base.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as grad

from .transformer import Transformer


##-----------------------------------------------------------------------------
# Class for PointNetBase. Subclasses PyTorch's own "nn" module
#
# Computes the local embeddings and global features for an input set of points
##
class PointNetBase(nn.Module):

    def __init__(self, num_points=2000, K=3):
        # Call the super constructor
        super(PointNetBase, self).__init__()

        # Input transformer for K-dimensional input
        # K should be 3 for XYZ coordinates, but can be larger if normals, 
        # colors, etc are included
        self.input_transformer = Transformer(num_points, K)

        # Embedding transformer is always going to be 64 dimensional
        self.embedding_transformer = Transformer(num_points, 64)

        # Multilayer perceptrons with shared weights are implemented as 
        # convolutions. This is because we are mapping from K inputs to 64 
        # outputs, so we can just consider each of the 64 K-dim filters as 
        # describing the weight matrix for each point dimension (X,Y,Z,...) to
        # each index of the 64 dimension embeddings
        self.mlp1 = nn.Sequential(
            nn.Conv1d(K, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU())

        self.mlp2 = nn.Sequential(
            nn.Conv1d(64, 64, 1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 1024, 1),
            nn.BatchNorm1d(1024),
            nn.ReLU())


    # Take as input a B x K x N matrix of B batches of N points with K 
    # dimensions
    def forward(self, x):

        # Number of points put into the network
        N = x.shape[2]

        # First compute the input data transform and transform the data
        # T1 is B x K x K and x is B x K x N, so output is B x K x N
        T1 = self.input_transformer(x)
        x = torch.bmm(T1, x)

        # Run the transformed inputs through the first embedding MLP
        # Output is B x 64 x N
        x = self.mlp1(x)

        # Transform the embeddings. This gives us the "local embedding" 
        # referred to in the paper/slides
        # T2 is B x 64 x 64 and x is B x 64 x N, so output is B x 64 x N
        T2 = self.embedding_transformer(x)
        local_embedding = torch.bmm(T2, x)

        # Further embed the "local embeddings"
        # Output is B x 1024 x N
        global_feature = self.mlp2(local_embedding)

        # Pool over the number of points. This results in the "global feature"
        # referred to in the paper/slides
        # Output should be B x 1024 x 1 --> B x 1024 (after squeeze)
        global_feature = F.max_pool1d(global_feature, N).squeeze(2)

        return global_feature, local_embedding, T2



================================================
FILE: diff_utils/pointnet/pointnet_classifier.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as grad

from .pointnet_base import PointNetBase


##-----------------------------------------------------------------------------
# Class for PointNetClassifier. Subclasses PyTorch's own "nn" module
#
# Computes the local embeddings and global features for an input set of points
##
class PointNetClassifier(nn.Module):

	def __init__(self, num_points=2000, K=3):
		# Call the super constructor
		super(PointNetClassifier, self).__init__()

		# Local and global feature extractor for PointNet
		self.base = PointNetBase(num_points, K)

		# Classifier for ShapeNet
		self.classifier = nn.Sequential(
			nn.Linear(1024, 512),
			nn.BatchNorm1d(512),
			nn.ReLU(),
			nn.Dropout(0.7),
			nn.Linear(512, 256),
			nn.BatchNorm1d(256),
			nn.ReLU(),
			nn.Dropout(0.7),
			nn.Linear(256, 40))


	# Take as input a B x K x N matrix of B batches of N points with K 
	# dimensions
	def forward(self, x):

		# Only need to keep the global feature descriptors for classification
		# Output should be B x 1024
		global_feature, local_embedding, T2 = self.base(x)

		# first attempt: only use the global feature
		#return global_feature

		# second attempt: concat local and global feature similar to segmentation network but create the downsampling MLP in the diffusion model
		# local embedding shape: B x 64 x N; global feature shape: B x 1024; concat to get B x 1088 x N 
		num_points = local_embedding.shape[-1]
		global_feature = global_feature.unsqueeze(-1).repeat(1,1,num_points)
		point_features = torch.cat( (global_feature, local_embedding), dim=1 ) # shape is B x 1088 x N 
		return point_features
		
		

		# third option: return the B x 40 classification logits
		#return self.classifier(x), T2
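A quick shape check of the concatenation above, which mirrors the PointNet segmentation head: the pooled B x 1024 global feature is tiled across all N points and stacked on the B x 64 x N local embedding. A minimal numpy sketch (random arrays stand in for real features):

```python
import numpy as np

B, N = 2, 5
local_embedding = np.random.rand(B, 64, N)   # per-point features
global_feature = np.random.rand(B, 1024)     # pooled shape descriptor

# Tile the global feature across the point dimension: B x 1024 -> B x 1024 x N
tiled = np.repeat(global_feature[:, :, None], N, axis=2)

# Concatenate along the channel axis, as in the forward pass above
point_features = np.concatenate([tiled, local_embedding], axis=1)
assert point_features.shape == (B, 1088, N)
```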

================================================
FILE: diff_utils/pointnet/transformer.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as grad



##-----------------------------------------------------------------------------
# Class for Transformer. Subclasses PyTorch's own "nn" module
#
# Computes a KxK affine transform from the input data to transform inputs
# to a "canonical view"
##
class Transformer(nn.Module):

	def __init__(self, num_points=2000, K=3):
		# Call the super constructor
		super(Transformer, self).__init__()

		# Number of dimensions of the data
		self.K = K

		# Size of input
		self.N = num_points

		# Identity matrix used to bias the predicted transform toward the
		# identity at initialization (torch.autograd.Variable is deprecated;
		# a plain tensor created once here is sufficient)
		self.identity = torch.eye(self.K).double().view(-1).cuda()

		# First embedding block
		self.block1 =nn.Sequential(
			nn.Conv1d(K, 64, 1),
			nn.BatchNorm1d(64),
			nn.ReLU())

		# Second embedding block
		self.block2 =nn.Sequential(
			nn.Conv1d(64, 128, 1),
			nn.BatchNorm1d(128),
			nn.ReLU())

		# Third embedding block
		self.block3 =nn.Sequential(
			nn.Conv1d(128, 1024, 1),
			nn.BatchNorm1d(1024),
			nn.ReLU())

		# Multilayer perceptron
		self.mlp = nn.Sequential(
			nn.Linear(1024, 512),
			nn.BatchNorm1d(512),
			nn.ReLU(),
			nn.Linear(512, 256),
			nn.BatchNorm1d(256),
			nn.ReLU(),
			nn.Linear(256, K * K))


	# Take as input a B x K x N matrix of B batches of N points with K 
	# dimensions
	def forward(self, x):

		# Compute the feature extractions
		# Output should ultimately be B x 1024 x N
		x = self.block1(x)
		x = self.block2(x)
		x = self.block3(x)

		# Pool over the number of points
		# Output should be B x 1024 x 1 --> B x 1024 (after squeeze)
		x = F.max_pool1d(x, self.N).squeeze(2)
		
		# Run the pooled features through the multi-layer perceptron
		# Output should be B x K^2
		x = self.mlp(x)

		# Add identity matrix to transform
		# Output is still B x K^2 (broadcasting takes care of batch dimension)
		x += self.identity

		# Reshape the output into B x K x K affine transformation matrices
		x = x.view(-1, self.K, self.K)

		return x
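The identity offset added above is what makes the learned transform start near the identity: a zero-initialized MLP head predicts zero offsets, so the K x K matrix applied to the points is initially a no-op. A small numpy sketch of that residual step:

```python
import numpy as np

K = 3
mlp_out = np.zeros(K * K)          # stand-in for a freshly initialized head
identity = np.eye(K).reshape(-1)   # flattened K x K identity, as in __init__

# Residual step from forward(): predicted offsets plus identity, then reshape
transform = (mlp_out + identity).reshape(K, K)
assert np.allclose(transform, np.eye(K))
```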



================================================
FILE: diff_utils/sdf_utils.py
================================================
import math
import torch
import json 
import torch.nn.functional as F
from torch import nn, einsum 
import os

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange

from sdf_model.model import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# load base network
specs_path = "sdf_model/config/siren/specs.json"
specs = json.load(open(specs_path))
model = MetaSDF(specs).to(device)
checkpoint = torch.load("sdf_model/config/siren/last.ckpt", map_location=device)
model.load_state_dict(checkpoint['state_dict'])

for p in model.parameters():
    p.requires_grad=False

def pred_sdf_loss(x0, idx):

    # NOTE: `gt_files` is expected to be defined at module level (a list of
    # loaded SDF sample tensors); it is not created in this file
    pc, xyz, gt = sdf_sampling(gt_files[idx], 16000, batch=x0.shape[0])
    xyz = xyz.to(device)
    gt = gt.to(device)
    sdf_loss = functional_sdf_model(x0, xyz, gt)

    return sdf_loss


def functional_sdf_model(modulation, xyz, gt):
    '''
    modulation: modulation vector of shape (512,)
    xyz: query points, input to the model, shape (16000, 3)
    gt: ground-truth SDF values for the L1 loss, shape (16000, 1)

    return: sdf_loss
    '''
    #print("shapes: ", modulation.shape, xyz.shape)
    pred_sdf = model(modulation, xyz)
    sdf_loss = F.l1_loss(pred_sdf.squeeze(), gt.squeeze())

    return sdf_loss


def sdf_sampling(f, subsample, pc_size=1024, batch=1):
    # f is expected to already be a tensor of (x, y, z, sdf) rows;
    # the CSV loading is kept for reference:
    # f = pd.read_csv(f, sep=',', header=None).values
    # f = torch.from_numpy(f)

    pcs = torch.empty(batch, pc_size, 3)
    xyz = torch.empty(batch, subsample, 3)
    gt = torch.empty(batch, subsample)

    for i in range(batch):
        half = int(subsample / 2) 
        neg_tensor = f[f[:,-1]<0]
        pos_tensor = f[f[:,-1]>0]

        if pos_tensor.shape[0] < half:
            pos_idx = torch.randint(pos_tensor.shape[0], (half,))
        else:
            pos_idx = torch.randperm(pos_tensor.shape[0])[:half]

        if neg_tensor.shape[0] < half:
            neg_idx = torch.randint(neg_tensor.shape[0], (half,))
        else:
            neg_idx = torch.randperm(neg_tensor.shape[0])[:half]

        pos_sample = pos_tensor[pos_idx]
        neg_sample = neg_tensor[neg_idx]

        pc = f[f[:,-1]==0][:,:3]
        pc_idx = torch.randperm(pc.shape[0])[:pc_size]
        pc = pc[pc_idx]

        samples = torch.cat([pos_sample, neg_sample], 0)


        pcs[i] = pc.float()
        xyz[i] = samples[:,:3].float()
        gt[i] = samples[:, 3].float()


    return pcs, xyz, gt

def apply_to_sdf(f, x):
    for idx, l in enumerate(x):
        x[idx] = f(l)
    return x
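The sign-balanced sampling in `sdf_sampling` above can be sketched in numpy: split the rows by the sign of the SDF column, then draw half the subsample from each side, falling back to sampling with replacement when a side is scarce. The random (x, y, z, sdf) rows below are hypothetical stand-ins for one `sdf_data.csv`:

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical (x, y, z, sdf) rows standing in for one loaded sdf_data.csv
f = np.concatenate([rng.uniform(-1, 1, (100, 3)),
                    rng.uniform(-1, 1, (100, 1))], axis=1)

half = 8
pos, neg = f[f[:, -1] > 0], f[f[:, -1] < 0]

def draw(t, n):
    # Sample with replacement only when the pool is smaller than requested
    idx = (rng.integers(len(t), size=n) if len(t) < n
           else rng.permutation(len(t))[:n])
    return t[idx]

samples = np.concatenate([draw(pos, half), draw(neg, half)], axis=0)
xyz, gt = samples[:, :3], samples[:, 3]
assert xyz.shape == (2 * half, 3) and gt.shape == (2 * half,)
```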

================================================
FILE: environment.yml
================================================
name: diffusionsdf
channels:
  - pytorch
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - brotlipy=0.7.0=py39h27cfd23_1003
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2022.4.26=h06a4308_0
  - certifi=2022.5.18.1=py39h06a4308_0
  - cffi=1.15.0=py39hd667e15_1
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - cryptography=37.0.1=py39h9ce1e76_0
  - cudatoolkit=11.3.1=h2bc3f7f_2
  - ffmpeg=4.3=hf484d3e_0
  - freetype=2.11.0=h70c0345_0
  - giflib=5.2.1=h7b6447c_0
  - gmp=6.2.1=h295c915_3
  - gnutls=3.6.15=he1e5248_0
  - idna=3.3=pyhd3eb1b0_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - jpeg=9e=h7f8727e_0
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.3=he6710b0_2
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.2=h7f8727e_0
  - libpng=1.6.37=hbc83047_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.16.0=h27cfd23_0
  - libtiff=4.2.0=h2818925_1
  - libunistring=0.9.10=h27cfd23_0
  - libuv=1.40.0=h7b6447c_0
  - libwebp=1.2.2=h55f646e_0
  - libwebp-base=1.2.2=h7f8727e_0
  - lz4-c=1.9.3=h295c915_1
  - mkl=2021.4.0=h06a4308_640
  - mkl-service=2.4.0=py39h7f8727e_0
  - mkl_fft=1.3.1=py39hd3c417c_0
  - mkl_random=1.2.2=py39h51133e4_0
  - ncurses=6.3=h7f8727e_2
  - nettle=3.7.3=hbbd107a_1
  - numpy=1.22.3=py39he7a7128_0
  - numpy-base=1.22.3=py39hf524024_0
  - openh264=2.1.1=h4ff587b_0
  - openssl=1.1.1o=h7f8727e_0
  - pillow=9.0.1=py39h22f2fdc_0
  - pip=21.2.4=py39h06a4308_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pyopenssl=22.0.0=pyhd3eb1b0_0
  - pysocks=1.7.1=py39h06a4308_0
  - python=3.9.12=h12debd9_1
  - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0
  - pytorch-mutex=1.0=cuda
  - readline=8.1.2=h7f8727e_1
  - requests=2.27.1=pyhd3eb1b0_0
  - setuptools=61.2.0=py39h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.38.3=hc218d9a_0
  - tk=8.6.12=h1ccaba5_0
  - torchaudio=0.11.0=py39_cu113
  - torchvision=0.12.0=py39_cu113
  - typing_extensions=4.1.1=pyh06a4308_0
  - tzdata=2022a=hda174b7_0
  - urllib3=1.26.9=py39h06a4308_0
  - wheel=0.37.1=pyhd3eb1b0_0
  - xz=5.2.5=h7f8727e_1
  - zlib=1.2.12=h7f8727e_2
  - zstd=1.5.2=ha4553b6_0
  - pip:
    - absl-py==1.1.0
    - addict==2.4.0
    - aiohttp==3.8.1
    - aiosignal==1.2.0
    - asttokens==2.2.1
    - async-timeout==4.0.2
    - attrs==21.4.0
    - backcall==0.2.0
    - cachetools==5.2.0
    - click==8.1.3
    - comm==0.1.2
    - configargparse==1.5.3
    - contourpy==1.0.7
    - cycler==0.11.0
    - dash==2.8.1
    - dash-core-components==2.0.0
    - dash-html-components==2.0.0
    - dash-table==5.0.0
    - debugpy==1.6.6
    - decorator==5.1.1
    - einops==0.6.0
    - einops-exts==0.0.4
    - executing==1.2.0
    - fastjsonschema==2.16.2
    - flask==2.2.2
    - fonttools==4.38.0
    - frozenlist==1.3.0
    - fsspec==2022.5.0
    - google-auth==2.6.6
    - google-auth-oauthlib==0.4.6
    - grpcio==1.46.3
    - imageio==2.19.3
    - importlib-metadata==4.11.4
    - ipykernel==6.21.1
    - ipython==8.10.0
    - ipywidgets==8.0.4
    - itsdangerous==2.1.2
    - jedi==0.18.2
    - jinja2==3.1.2
    - joblib==1.2.0
    - jsonschema==4.17.3
    - jupyter-client==8.0.2
    - jupyter-core==5.2.0
    - jupyterlab-widgets==3.0.5
    - kiwisolver==1.4.4
    - markdown==3.3.7
    - markupsafe==2.1.2
    - matplotlib==3.6.3
    - matplotlib-inline==0.1.6
    - multidict==6.0.2
    - nbformat==5.5.0
    - nest-asyncio==1.5.6
    - networkx==2.8.2
    - oauthlib==3.2.0
    - open3d==0.16.0
    - packaging==21.3
    - pandas==1.4.2
    - parso==0.8.3
    - pexpect==4.8.0
    - pickleshare==0.7.5
    - platformdirs==3.0.0
    - plotly==5.13.0
    - plyfile==0.7.4
    - prompt-toolkit==3.0.36
    - protobuf==3.20.1
    - psutil==5.9.4
    - ptyprocess==0.7.0
    - pure-eval==0.2.2
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pydeprecate==0.3.2
    - pygments==2.14.0
    - pyparsing==3.0.9
    - pyquaternion==0.9.9
    - pyrsistent==0.19.3
    - python-dateutil==2.8.2
    - pytorch-lightning==1.6.4
    - pytz==2022.1
    - pywavelets==1.3.0
    - pyyaml==6.0
    - pyzmq==25.0.0
    - requests-oauthlib==1.3.1
    - rotary-embedding-torch==0.2.1
    - rsa==4.8
    - scikit-image==0.19.2
    - scikit-learn==1.2.1
    - scipy==1.8.1
    - stack-data==0.6.2
    - tenacity==8.2.1
    - tensorboard==2.9.0
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.1
    - threadpoolctl==3.1.0
    - tifffile==2022.5.4
    - torch-scatter==2.0.9
    - torchmetrics==0.9.0
    - tornado==6.2
    - tqdm==4.64.0
    - traitlets==5.9.0
    - trimesh==3.12.5
    - wcwidth==0.2.6
    - werkzeug==2.2.2
    - widgetsnbextension==4.0.5
    - yarl==1.7.2
    - zipp==3.8.0
prefix: /home/gchou/.conda/envs/diffusion


================================================
FILE: metrics/StructuralLosses/__init__.py
================================================
#import torch

#from MakePytorchBackend import AddGPU, Foo, ApproxMatch

#from Add import add_gpu, approx_match



================================================
FILE: metrics/StructuralLosses/match_cost.py
================================================
import torch
from torch.autograd import Function
from metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad

# Inherit from Function
class MatchCostFunction(Function):
    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, seta, setb):
        #print("Match Cost Forward")
        ctx.save_for_backward(seta, setb)
        '''
        input:
	        set1 : batch_size * #dataset_points * 3
	        set2 : batch_size * #query_points * 3
        returns:
	        match : batch_size * #query_points * #dataset_points
        '''
        match, temp = ApproxMatch(seta, setb)
        ctx.match = match
        cost = MatchCost(seta, setb, match)
        return cost

    """
    grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match)
	return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None]
	"""
    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        #print("Match Cost Backward")
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        seta, setb = ctx.saved_tensors
        #grad_input = grad_weight = grad_bias = None
        grada, gradb = MatchCostGrad(seta, setb, ctx.match)
        grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2)
        return grada*grad_output_expand, gradb*grad_output_expand

match_cost = MatchCostFunction.apply



================================================
FILE: metrics/StructuralLosses/nn_distance.py
================================================
import torch
from torch.autograd import Function
# from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
from metrics.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad

# Inherit from Function
class NNDistanceFunction(Function):
    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, seta, setb):
        #print("Match Cost Forward")
        ctx.save_for_backward(seta, setb)
        '''
        input:
	        set1 : batch_size * #dataset_points * 3
	        set2 : batch_size * #query_points * 3
        returns:
	        dist1, idx1, dist2, idx2
        '''
        dist1, idx1, dist2, idx2 = NNDistance(seta, setb)
        ctx.idx1 = idx1
        ctx.idx2 = idx2
        return dist1, dist2

    # The forward pass returns two outputs (dist1, dist2), so backward receives two gradients
    @staticmethod
    def backward(ctx, grad_dist1, grad_dist2):
        #print("Match Cost Backward")
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        seta, setb = ctx.saved_tensors
        idx1 = ctx.idx1
        idx2 = ctx.idx2
        grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2)
        return grada, gradb

nn_distance = NNDistanceFunction.apply



================================================
FILE: metrics/__init__.py
================================================


================================================
FILE: metrics/evaluation_metrics.py
================================================
import torch
import numpy as np
import warnings
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
from numpy.linalg import norm
from scipy.optimize import linear_sum_assignment
from tqdm import tqdm


# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet
def distChamfer(a, b):
    x, y = a, b
    bs, num_points, points_dim = x.size()
    xx = torch.bmm(x, x.transpose(2, 1))
    yy = torch.bmm(y, y.transpose(2, 1))
    zz = torch.bmm(x, y.transpose(2, 1))
    diag_ind = torch.arange(0, num_points).to(a).long()
    rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)
    ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)
    P = (rx.transpose(2, 1) + ry - 2 * zz)
    return P.min(1)[0], P.min(2)[0]
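`distChamfer` builds the full pairwise squared-distance matrix from three batched matrix products, using the expansion ||x - y||^2 = <x,x> + <y,y> - 2<x,y>. A numpy sketch checking that identity against brute-force distances:

```python
import numpy as np

rng = np.random.default_rng(1)
x, y = rng.standard_normal((4, 3)), rng.standard_normal((6, 3))

# Pairwise squared distances via the same expansion used above:
# ||x_i - y_j||^2 = <x_i, x_i> + <y_j, y_j> - 2 <x_i, y_j>
P = (x**2).sum(1)[:, None] + (y**2).sum(1)[None, :] - 2 * x @ y.T

brute = ((x[:, None, :] - y[None, :, :])**2).sum(-1)
assert np.allclose(P, brute)

# Chamfer terms: nearest-neighbor squared distance in each direction
d_xy, d_yx = P.min(axis=1), P.min(axis=0)
assert d_xy.shape == (4,) and d_yx.shape == (6,)
```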

# Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/
try:
    from metrics.StructuralLosses.nn_distance import nn_distance

    def distChamferCUDA(x, y):
        return nn_distance(x, y)
except ImportError:
    print("distChamferCUDA not available; falling back to the slower version.")
    distChamferCUDA = None
    # def distChamferCUDA(x, y):
    #     return distChamfer(x, y)


def emd_approx(x, y):
    bs, npts, mpts, dim = x.size(0), x.size(1), y.size(1), x.size(2)
    assert npts == mpts, "EMD only works if two point clouds are equal size"
    dim = x.shape[-1]
    x = x.reshape(bs, npts, 1, dim)
    y = y.reshape(bs, 1, mpts, dim)
    dist = (x - y).norm(dim=-1, keepdim=False)  # (bs, npts, mpts)

    emd_lst = []
    dist_np = dist.cpu().detach().numpy()
    for i in range(bs):
        d_i = dist_np[i]
        r_idx, c_idx = linear_sum_assignment(d_i)
        emd_i = d_i[r_idx, c_idx].mean()
        emd_lst.append(emd_i)
    emd = np.stack(emd_lst).reshape(-1)
    emd_torch = torch.from_numpy(emd).to(x)
    return emd_torch
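`emd_approx` solves one exact assignment problem per batch element with `linear_sum_assignment`. For tiny clouds the optimum can be found by brute force over permutations, which makes the idea easy to check without scipy (the toy coordinates below are hypothetical):

```python
import itertools
import numpy as np

x = np.array([[0.0, 0.0], [1.0, 0.0]])
y = np.array([[1.1, 0.0], [0.1, 0.0]])

# Pairwise Euclidean distance matrix, as in emd_approx above
dist = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)

# Exhaustive search over assignments stands in for linear_sum_assignment
best = min(dist[range(len(x)), perm].mean()
           for perm in itertools.permutations(range(len(y))))
# Matching (x[0] -> y[1], x[1] -> y[0]) costs (0.1 + 0.1) / 2 = 0.1
assert np.isclose(best, 0.1)
```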


try:
    from metrics.StructuralLosses.match_cost import match_cost
    #print("cuda available")
    def emd_approx_cuda(sample, ref):
        B, N, N_ref = sample.size(0), sample.size(1), ref.size(1)
        assert N == N_ref, "EMD requires point clouds of equal size"
        emd = match_cost(sample, ref)  # (B,)
        emd_norm = emd / float(N)  # (B,)
        return emd_norm
except ImportError:
    print("emd_approx_cuda not available; falling back to the slower version.")
    emd_approx_cuda = None
    # def emd_approx_cuda(sample, ref):
    #     return emd_approx(sample, ref)


def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True,
           accelerated_emd=False):
    N_sample = sample_pcs.shape[0]
    N_ref = ref_pcs.shape[0]
    assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample)

    cd_lst = []
    emd_lst = []
    iterator = range(0, N_sample, batch_size)

    for b_start in iterator:
        b_end = min(N_sample, b_start + batch_size)
        sample_batch = sample_pcs[b_start:b_end]
        ref_batch = ref_pcs[b_start:b_end]

        if accelerated_cd:
            dl, dr = distChamferCUDA(sample_batch, ref_batch)
        else:
            dl, dr = distChamfer(sample_batch, ref_batch)
        cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1))

        if accelerated_emd:
            emd_batch = emd_approx_cuda(sample_batch, ref_batch)
        else:
            emd_batch = emd_approx(sample_batch, ref_batch)
        emd_lst.append(emd_batch)

    if reduced:
        cd = torch.cat(cd_lst).mean()
        emd = torch.cat(emd_lst).mean()
    else:
        cd = torch.cat(cd_lst)
        emd = torch.cat(emd_lst)

    results = {
        'MMD-CD': cd,
        'MMD-EMD': emd,
    }
    return results


def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size=None, accelerated_cd=True,
                      accelerated_emd=True):
    batch_size = sample_pcs.shape[0] if batch_size is None else batch_size
    N_sample = sample_pcs.shape[0]
    N_ref = ref_pcs.shape[0]
    all_cd = []
    all_emd = []
    iterator = range(N_sample)
    with tqdm(iterator) as pbar:
        for sample_b_start in pbar:
            pbar.set_description("Files evaluated: {}/{}".format(sample_b_start, N_sample))
            sample_batch = sample_pcs[sample_b_start]

            cd_lst = []
            emd_lst = []
            for ref_b_start in range(0, N_ref, batch_size):
                ref_b_end = min(N_ref, ref_b_start + batch_size)
                ref_batch = ref_pcs[ref_b_start:ref_b_end]

                batch_size_ref = ref_batch.size(0)
                sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)
                sample_batch_exp = sample_batch_exp.contiguous()

                if accelerated_cd and distChamferCUDA is not None:
                    dl, dr = distChamferCUDA(sample_batch_exp, ref_batch)
                else:
                    dl, dr = distChamfer(sample_batch_exp, ref_batch)
                cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))

                # if accelerated_emd:
                #     emd_batch = emd_approx_cuda(sample_batch_exp, ref_batch)
                # else:
                #     emd_batch = emd_approx(sample_batch_exp, ref_batch)
                # emd_lst.append(emd_batch.view(1, -1))

            cd_lst = torch.cat(cd_lst, dim=1)
            #emd_lst = torch.cat(emd_lst, dim=1)
            all_cd.append(cd_lst)
            #all_emd.append(emd_lst)

    all_cd = torch.cat(all_cd, dim=0)  # N_sample, N_ref
    #all_emd = torch.cat(all_emd, dim=0)  # N_sample, N_ref

    return all_cd, 0  # EMD computation disabled above; placeholder returned

def emd_tmd_from_pcs(gen_pcs):
    sum_dist = 0
    for j in range(len(gen_pcs)):
        for k in range(j + 1, len(gen_pcs), 1):
            pc1 = gen_pcs[j]
            pc2 = gen_pcs[k]
            emd_dist = emd_approx_cuda(pc1, pc2)  # pairwise EMD, not chamfer
            print("emd dist: ", emd_dist)
            sum_dist += emd_dist
    mean_dist = sum_dist * 2 / (len(gen_pcs) - 1)
    return mean_dist


# Adapted from https://github.com/xuqiantong/GAN-Metrics/blob/master/framework/metric.py
def knn(Mxx, Mxy, Myy, k, sqrt=False):
    n0 = Mxx.size(0)
    n1 = Myy.size(0)
    label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx)
    M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0)
    if sqrt:
        M = M.abs().sqrt()
    INFINITY = float('inf')
    val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False)

    count = torch.zeros(n0 + n1).to(Mxx)
    for i in range(0, k):
        count = count + label.index_select(0, idx[i])
    pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float()

    s = {
        'tp': (pred * label).sum(),
        'fp': (pred * (1 - label)).sum(),
        'fn': ((1 - pred) * label).sum(),
        'tn': ((1 - pred) * (1 - label)).sum(),
    }

    s.update({
        'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10),
        'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
        'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10),
        'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10),
        'acc': torch.eq(label, pred).float().mean(),
    })
    return s


def lgan_mmd_cov(all_dist):
    # all dist shape = [number of generated pcs, number of ref pcs]
    N_sample, N_ref = all_dist.size(0), all_dist.size(1)
    min_val_fromsmp, min_idx = torch.min(all_dist, dim=1)
    min_val, _ = torch.min(all_dist, dim=0)
    mmd = min_val.mean()
    #mmd_smp = min_val_fromsmp.mean()
    cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)
    cov = torch.tensor(cov).to(all_dist)
    return {
        'lgan_mmd': mmd,
        'lgan_cov': cov,
        #'lgan_mmd_smp': mmd_smp,
    }
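`lgan_mmd_cov` reads a sample-by-reference distance matrix: MMD averages, over references, the distance to the closest generated sample, while COV counts the fraction of references that are some sample's nearest neighbor. A numpy sketch on a hypothetical 3 x 3 matrix:

```python
import numpy as np

# Hypothetical distance matrix: rows = generated samples, cols = references
all_dist = np.array([[0.1, 0.9, 0.9],
                     [0.2, 0.8, 0.9],
                     [0.9, 0.9, 0.3]])

# MMD: for each reference, distance to its closest generated sample, averaged
mmd = all_dist.min(axis=0).mean()
# COV: fraction of references matched by at least one generated sample
cov = len(set(all_dist.argmin(axis=1))) / all_dist.shape[1]

assert np.isclose(mmd, (0.1 + 0.8 + 0.3) / 3)
assert np.isclose(cov, 2 / 3)  # samples 0 and 1 both match reference 0
```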


def compute_mmd(sample_pcs, ref_pcs, batch_size=None, accelerated_cd=True):

    results = {}
    batch_size = sample_pcs.shape[0] if batch_size is None else batch_size

    M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)

    res_cd = lgan_mmd_cov(M_rs_cd.t())
    # results.update({
    #     "%s-CD" % k: v for k, v in res_cd.items()
    # })

    #res_emd = lgan_mmd_cov(M_rs_emd.t())
    # results.update({
    #     "%s-EMD" % k: v for k, v in res_emd.items()
    # })

    return res_cd,0# res_emd


def compute_cd(sample_pcs, ref_pcs):
    
    dl, dr = distChamferCUDA(sample_pcs, ref_pcs)
    print("dl, dr shapes: ", dl.shape, dr.shape, dl.mean(dim=1).shape)
    res = (dl.mean(dim=1) + dr.mean(dim=1))#.view(1, -1)
    return res


def compute_all_metrics(sample_pcs, ref_pcs, batch_size=None, accelerated_cd=True):
    results = {}

    batch_size = sample_pcs.shape[0] if batch_size is None else batch_size

    M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)

    res_cd = lgan_mmd_cov(M_rs_cd.t())
    results.update({
        "%s-CD" % k: v for k, v in res_cd.items()
    })

    res_emd = lgan_mmd_cov(M_rs_emd.t())
    results.update({
        "%s-EMD" % k: v for k, v in res_emd.items()
    })

    M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size, accelerated_cd=accelerated_cd)

    M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)

    # 1-NN results
    one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)
    results.update({
        "1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k
    })
    one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)
    results.update({
        "1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k
    })

    return res_cd, res_emd, one_nn_cd_res, one_nn_emd_res


#######################################################
# JSD : from https://github.com/optas/latent_3d_points
#######################################################
def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
    """Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,
    that is placed in the unit-cube.
    If clip_sphere is True, it drops the "corner" cells that lie outside the unit-sphere.
    """
    grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)
    spacing = 1.0 / float(resolution - 1)
    for i in range(resolution):
        for j in range(resolution):
            for k in range(resolution):
                grid[i, j, k, 0] = i * spacing - 0.5
                grid[i, j, k, 1] = j * spacing - 0.5
                grid[i, j, k, 2] = k * spacing - 0.5

    if clip_sphere:
        grid = grid.reshape(-1, 3)
        grid = grid[norm(grid, axis=1) <= 0.5]

    return grid, spacing


def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
    """Computes the JSD between two sets of point-clouds, as introduced in the paper
    ```Learning Representations And Generative Models For 3D Point Clouds```.
    Args:
        sample_pcs: (np.ndarray S1xR1x3) S1 point-clouds, each of R1 points.
        ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.
        resolution: (int) grid-resolution. Affects granularity of measurements.
    """
    in_unit_sphere = True
    sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]
    ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]
    return jensen_shannon_divergence(sample_grid_var, ref_grid_var)


def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose=False):
    """Given a collection of point-clouds, estimate the entropy of the random variables
    corresponding to occupancy-grid activation patterns.
    Inputs:
        pclouds: (numpy array) #point-clouds x points per point-cloud x 3
        grid_resolution (int) size of occupancy grid that will be used.
    """
    epsilon = 10e-4
    bound = 0.5 + epsilon
    if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:
        if verbose:
            warnings.warn('Point-clouds are not in unit cube.')

    if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:
        if verbose:
            warnings.warn('Point-clouds are not in unit sphere.')

    grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)
    grid_coordinates = grid_coordinates.reshape(-1, 3)
    grid_counters = np.zeros(len(grid_coordinates))
    grid_bernoulli_rvars = np.zeros(len(grid_coordinates))
    nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)

    for pc in pclouds:
        _, indices = nn.kneighbors(pc)
        indices = np.squeeze(indices)
        for i in indices:
            grid_counters[i] += 1
        indices = np.unique(indices)
        for i in indices:
            grid_bernoulli_rvars[i] += 1

    acc_entropy = 0.0
    n = float(len(pclouds))
    for g in grid_bernoulli_rvars:
        if g > 0:
            p = float(g) / n
            acc_entropy += entropy([p, 1.0 - p])

    return acc_entropy / len(grid_counters), grid_counters


def jensen_shannon_divergence(P, Q):
    if np.any(P < 0) or np.any(Q < 0):
        raise ValueError('Negative values.')
    if len(P) != len(Q):
        raise ValueError('Non equal size.')

    P_ = P / np.sum(P)  # Ensure probabilities.
    Q_ = Q / np.sum(Q)

    e1 = entropy(P_, base=2)
    e2 = entropy(Q_, base=2)
    e_sum = entropy((P_ + Q_) / 2.0, base=2)
    res = e_sum - ((e1 + e2) / 2.0)

    res2 = _jsdiv(P_, Q_)

    if not np.allclose(res, res2, atol=10e-5, rtol=0):
        warnings.warn('Numerical values of two JSD methods don\'t agree.')

    return res


def _jsdiv(P, Q):
    """another way of computing JSD"""

    def _kldiv(A, B):
        a = A.copy()
        b = B.copy()
        idx = np.logical_and(a > 0, b > 0)
        a = a[idx]
        b = b[idx]
        return np.sum([v for v in a * np.log2(a / b)])

    P_ = P / np.sum(P)
    Q_ = Q / np.sum(Q)

    M = 0.5 * (P_ + Q_)

    return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))
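`_jsdiv` is the textbook decomposition JSD(P, Q) = 0.5 KL(P || M) + 0.5 KL(Q || M) with M = (P + Q) / 2. A numpy sketch on a pair of hand-checkable distributions:

```python
import numpy as np

def kl(a, b):
    # KL divergence in bits, restricted to bins where both are positive
    m = (a > 0) & (b > 0)
    return float(np.sum(a[m] * np.log2(a[m] / b[m])))

P = np.array([0.5, 0.5, 0.0])
Q = np.array([0.0, 0.5, 0.5])
M = 0.5 * (P + Q)

jsd = 0.5 * (kl(P, M) + kl(Q, M))
# Disjoint mass on bins 0 and 2 contributes fully: JSD = 0.5 bits
assert np.isclose(jsd, 0.5)
```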


if __name__ == "__main__":
    #from pytorch_structural_losses.nn_distance import nn_distance # need to make
    #from metrics.StructuralLosses.nn_distance import nn_distance
    B, N = 100, 2048
    x = torch.rand(B, N, 3)
    y = torch.rand(B, N, 3)

    #get_dist = nn_distance  # use the CUDA extension when it is built

    get_dist = distChamfer  # pure-PyTorch fallback

    min_l, min_r = get_dist(x.cuda(), y.cuda())
    print(min_l.shape)
    print(min_r.shape)

    l_dist = min_l.mean().cpu().detach().item()
    r_dist = min_r.mean().cpu().detach().item()
    print(l_dist, r_dist)


================================================
FILE: metrics/pytorch_structural_losses/.gitignore
================================================
PyTorchStructuralLosses.egg-info/


================================================
FILE: metrics/pytorch_structural_losses/Makefile
================================================
###############################################################################
# Uncomment for debugging
# DEBUG := 1
# Pretty build
# Q ?= @

CXX := g++
PYTHON := python
NVCC := /usr/local/cuda/bin/nvcc

# PYTHON Header path
PYTHON_HEADER_DIR := $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())')
PYTORCH_INCLUDES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import include_paths; [print(p) for p in include_paths()]')
PYTORCH_LIBRARIES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import library_paths; [print(p) for p in library_paths()]')

# CUDA ROOT DIR that contains bin/ lib64/ and include/
# CUDA_DIR := /usr/local/cuda
CUDA_DIR := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import _find_cuda_home; print(_find_cuda_home())')

INCLUDE_DIRS := ./ $(CUDA_DIR)/include

INCLUDE_DIRS += $(PYTHON_HEADER_DIR)
INCLUDE_DIRS += $(PYTORCH_INCLUDES)

# Custom (MKL/ATLAS/OpenBLAS) include and lib directories.
# Leave commented to accept the defaults for your choice of BLAS
# (which should work)!
# BLAS_INCLUDE := /path/to/your/blas
# BLAS_LIB := /path/to/your/blas

###############################################################################
SRC_DIR := ./src
OBJ_DIR := ./objs
CPP_SRCS := $(wildcard $(SRC_DIR)/*.cpp)
CU_SRCS := $(wildcard $(SRC_DIR)/*.cu)
OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(CPP_SRCS))
CU_OBJS := $(patsubst $(SRC_DIR)/%.cu,$(OBJ_DIR)/cuda/%.o,$(CU_SRCS))
STATIC_LIB := $(OBJ_DIR)/libmake_pytorch.a

# CUDA architecture setting: going with all of them.
# For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility.
# For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility.
CUDA_ARCH :=	-gencode arch=compute_61,code=sm_61 \
		-gencode arch=compute_61,code=compute_61 \
		-gencode arch=compute_52,code=sm_52

# We will also explicitly add stdc++ to the link target.
LIBRARIES += stdc++ cudart c10 caffe2 torch torch_python caffe2_gpu

# Debugging
ifeq ($(DEBUG), 1)
	COMMON_FLAGS += -DDEBUG -g -O0
	# https://gcoe-dresden.de/reaching-the-shore-with-a-fog-warning-my-eurohack-day-4-morning-session/
	NVCCFLAGS += -g -G # -rdc true
else
	COMMON_FLAGS += -DNDEBUG -O3
endif

WARNINGS := -Wall -Wno-sign-compare -Wcomment

INCLUDE_DIRS += $(BLAS_INCLUDE)

# Automatic dependency generation (nvcc is handled separately)
CXXFLAGS += -std=c++14 -MMD -MP

# Complete build flags.
COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \
	     -DTORCH_API_INCLUDE_EXTENSION_H -D_GLIBCXX_USE_CXX11_ABI=0
CXXFLAGS += -pthread -fPIC -fwrapv -std=c++14 $(COMMON_FLAGS) $(WARNINGS)
NVCCFLAGS += -std=c++14 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)

all: $(STATIC_LIB)
	$(PYTHON) setup.py build
	@ mv build/lib.linux-x86_64-cpython-39/StructuralLosses ..
	#@ mv build/lib.linux-x86_64-3.6/StructuralLosses ..
	@ mv build/lib.linux-x86_64-cpython-39/*.so ../StructuralLosses/
	#@ mv build/lib.linux-x86_64-3.6/*.so ../StructuralLosses/
	@- $(RM) -rf $(OBJ_DIR) build objs

$(OBJ_DIR):
	@ mkdir -p $@
	@ mkdir -p $@/cuda

$(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp | $(OBJ_DIR)
	@ echo CXX $<
	$(Q)$(CXX) $< $(CXXFLAGS) -c -o $@

$(OBJ_DIR)/cuda/%.o: $(SRC_DIR)/%.cu | $(OBJ_DIR)
	@ echo NVCC $<
	$(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \
		-odir $(@D)
	$(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@

$(STATIC_LIB): $(OBJS) $(CU_OBJS) | $(OBJ_DIR)
	$(RM) -f $(STATIC_LIB)
	$(RM) -rf build dist
	@ echo LD -o $@
	ar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS)

clean:
	@- $(RM) -rf $(OBJ_DIR) build dist ../StructuralLosses



================================================
FILE: metrics/pytorch_structural_losses/StructuralLosses/__init__.py
================================================
#import torch

#from MakePytorchBackend import AddGPU, Foo, ApproxMatch

#from Add import add_gpu, approx_match



================================================
FILE: metrics/pytorch_structural_losses/StructuralLosses/match_cost.py
================================================
import torch
from torch.autograd import Function
from metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad

# Inherit from Function
class MatchCostFunction(Function):
    # Note that both forward and backward are @staticmethods.
    @staticmethod
    def forward(ctx, seta, setb):
        '''
        input:
            seta : batch_size * #dataset_points * 3
            setb : batch_size * #query_points * 3
        returns:
            cost : batch_size
        The intermediate match matrix has shape
        batch_size * #query_points * #dataset_points.
        '''
        ctx.save_for_backward(seta, setb)
        match, _ = ApproxMatch(seta, setb)
        ctx.match = match
        cost = MatchCost(seta, setb, match)
        return cost

    """
    grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match)
	return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None]
	"""
    # forward has a single output (cost), so backward receives a single gradient.
    @staticmethod
    def backward(ctx, grad_output):
        seta, setb = ctx.saved_tensors
        grada, gradb = MatchCostGrad(seta, setb, ctx.match)
        # Broadcast the per-batch scalar gradient over the point dimensions.
        grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2)
        return grada * grad_output_expand, gradb * grad_output_expand

match_cost = MatchCostFunction.apply
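
# For intuition, the cost returned by this Function is a match-weighted
# transport cost between two point clouds (an approximate Earth Mover's
# Distance). As a sanity reference -- illustrative only, not part of this
# repo, with the hypothetical helper name brute_force_match_cost -- the exact
# quantity for tiny, equal-size sets can be computed by brute force:
#
# import itertools
# import math
#
# def brute_force_match_cost(seta, setb):
#     """Exact minimum-cost one-to-one matching between two equal-size 3D
#     point lists: the quantity MatchCost approximates. O(n!), tiny inputs only."""
#     assert len(seta) == len(setb)
#     dist = lambda p, q: math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))
#     return min(
#         sum(dist(seta[k], setb[perm[k]]) for k in range(len(seta)))
#         for perm in itertools.permutations(range(len(setb)))
#     )
#
# # Matched pairs one unit apart: the optimal assignment pairs nearby points,
# # giving a total cost of 1 + 1 = 2.
# a = [(0.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
# b = [(1.0, 0.0, 0.0), (11.0, 0.0, 0.0)]
# # brute_force_match_cost(a, b) -> 2.0

The sketch above is written as comments so this module stays importable without the compiled backend; the executable form is:

```python
import itertools
import math

def brute_force_match_cost(seta, setb):
    """Exact minimum-cost one-to-one matching between two equal-size 3D
    point lists: the quantity MatchCost approximates. O(n!), tiny inputs only."""
    assert len(seta) == len(setb)
    dist = lambda p, q: math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))
    return min(
        sum(dist(seta[k], setb[perm[k]]) for k in range(len(seta)))
        for perm in itertools.permutations(range(len(setb)))
    )

# Matched pairs one unit apart: the optimal assignment pairs nearby points.
a = [(0.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
b = [(1.0, 0.0, 0.0), (11.0, 0.0, 0.0)]
cost = brute_force_match_cost(a, b)  # 2.0
```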



================================================
FILE: metrics/pytorch_structural_losses/StructuralLosses/nn_distance.py
================================================
import torch
from torch.autograd import Function
# from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
from metrics.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad

# Inherit from Function
class NNDistanceFunction(Function):
    # Note that both forward and backward are @staticmethods.
    @staticmethod
    def forward(ctx, seta, setb):
        '''
        input:
            seta : batch_size * #dataset_points * 3
            setb : batch_size * #query_points * 3
        returns:
            dist1 : batch_size * #dataset_points (squared NN distances)
            dist2 : batch_size * #query_points
        The nearest-neighbor indices idx1/idx2 are stashed on ctx for backward.
        '''
        ctx.save_for_backward(seta, setb)
        dist1, idx1, dist2, idx2 = NNDistance(seta, setb)
        ctx.idx1 = idx1
        ctx.idx2 = idx2
        return dist1, dist2

    # forward returns two outputs (dist1, dist2), so backward receives two gradients.
    @staticmethod
    def backward(ctx, grad_dist1, grad_dist2):
        seta, setb = ctx.saved_tensors
        idx1 = ctx.idx1
        idx2 = ctx.idx2
        grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2)
        return grada, gradb

nn_distance = NNDistanceFunction.apply
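
NNDistance returns squared nearest-neighbor distances in both directions, the building block of the Chamfer distance. A plain-Python reference of the same contract (illustrative only; `nn_distance_reference` is a hypothetical name, not part of this repo):

```python
def nn_distance_reference(seta, setb):
    """Reference for NNDistance on one batch element: for every point in one
    set, the squared distance to (and index of) its nearest neighbor in the
    other set."""
    def one_way(src, dst):
        dists, idxs = [], []
        for p in src:
            d2 = [sum((a - b) ** 2 for a, b in zip(p, q)) for q in dst]
            best = min(range(len(dst)), key=d2.__getitem__)
            dists.append(d2[best])
            idxs.append(best)
        return dists, idxs
    dist1, idx1 = one_way(seta, setb)
    dist2, idx2 = one_way(setb, seta)
    return dist1, idx1, dist2, idx2

a = [(0.0, 0.0, 0.0), (4.0, 0.0, 0.0)]
b = [(1.0, 0.0, 0.0)]
dist1, idx1, dist2, idx2 = nn_distance_reference(a, b)
# dist1 == [1.0, 9.0]; a symmetric Chamfer distance would be
# mean(dist1) + mean(dist2)
```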



================================================
FILE: metrics/pytorch_structural_losses/__init__.py
================================================
#import torch

#from MakePytorchBackend import AddGPU, Foo, ApproxMatch

#from Add import add_gpu, approx_match



================================================
FILE: metrics/pytorch_structural_losses/match_cost.py
================================================
import torch
from torch.autograd import Function
from metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad

# Inherit from Function
class MatchCostFunction(Function):
    # Note that both forward and backward are @staticmethods.
    @staticmethod
    def forward(ctx, seta, setb):
        '''
        input:
            seta : batch_size * #dataset_points * 3
            setb : batch_size * #query_points * 3
        returns:
            cost : batch_size
        The intermediate match matrix has shape
        batch_size * #query_points * #dataset_points.
        '''
        ctx.save_for_backward(seta, setb)
        match, _ = ApproxMatch(seta, setb)
        ctx.match = match
        cost = MatchCost(seta, setb, match)
        return cost

    """
    grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match)
	return [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None]
	"""
    # forward has a single output (cost), so backward receives a single gradient.
    @staticmethod
    def backward(ctx, grad_output):
        seta, setb = ctx.saved_tensors
        grada, gradb = MatchCostGrad(seta, setb, ctx.match)
        # Broadcast the per-batch scalar gradient over the point dimensions.
        grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2)
        return grada * grad_output_expand, gradb * grad_output_expand

match_cost = MatchCostFunction.apply



================================================
FILE: metrics/pytorch_structural_losses/nn_distance.py
================================================
import torch
from torch.autograd import Function
# from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad
from .StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad

# Inherit from Function
class NNDistanceFunction(Function):
    # Note that both forward and backward are @staticmethods.
    @staticmethod
    def forward(ctx, seta, setb):
        '''
        input:
            seta : batch_size * #dataset_points * 3
            setb : batch_size * #query_points * 3
        returns:
            dist1 : batch_size * #dataset_points (squared NN distances)
            dist2 : batch_size * #query_points
        The nearest-neighbor indices idx1/idx2 are stashed on ctx for backward.
        '''
        ctx.save_for_backward(seta, setb)
        dist1, idx1, dist2, idx2 = NNDistance(seta, setb)
        ctx.idx1 = idx1
        ctx.idx2 = idx2
        return dist1, dist2

    # forward returns two outputs (dist1, dist2), so backward receives two gradients.
    @staticmethod
    def backward(ctx, grad_dist1, grad_dist2):
        seta, setb = ctx.saved_tensors
        idx1 = ctx.idx1
        idx2 = ctx.idx2
        grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2)
        return grada, gradb

nn_distance = NNDistanceFunction.apply



================================================
FILE: metrics/pytorch_structural_losses/pybind/bind.cpp
================================================
#include <string>

#include <torch/extension.h>

#include "pybind/extern.hpp"

namespace py = pybind11;

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
  m.def("ApproxMatch", &ApproxMatch);
  m.def("MatchCost", &MatchCost);
  m.def("MatchCostGrad", &MatchCostGrad);
  m.def("NNDistance", &NNDistance);
  m.def("NNDistanceGrad", &NNDistanceGrad);
}


================================================
FILE: metrics/pytorch_structural_losses/pybind/extern.hpp
================================================
std::vector<at::Tensor> ApproxMatch(at::Tensor in_a, at::Tensor in_b);
at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match);
std::vector<at::Tensor> MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match);

std::vector<at::Tensor> NNDistance(at::Tensor set_d, at::Tensor set_q);
std::vector<at::Tensor> NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2);


================================================
FILE: metrics/pytorch_structural_losses/setup.py
================================================
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

# Python interface
setup(
    name='PyTorchStructuralLosses',
    version='0.1.0',
    install_requires=['torch'],
    packages=['StructuralLosses'],
    package_dir={'StructuralLosses': './'},
    ext_modules=[
        CUDAExtension(
            name='StructuralLossesBackend',
            include_dirs=['./'],
            sources=[
                'pybind/bind.cpp',
            ],
            libraries=['make_pytorch'],
            library_dirs=['objs'],
            # extra_compile_args=['-g']
        )
    ],
    cmdclass={'build_ext': BuildExtension},
    author='Christopher B. Choy',
    author_email='chrischoy@ai.stanford.edu',
    description='Tutorial for Pytorch C++ Extension with a Makefile',
    keywords='Pytorch C++ Extension',
    url='https://github.com/chrischoy/MakePytorchPlusPlus',
    zip_safe=False,
)


================================================
FILE: metrics/pytorch_structural_losses/src/approxmatch.cu
================================================
#include "utils.hpp"

__global__ void approxmatchkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){
	float * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;
	float multiL,multiR;
	// Per-point mass so that both sets carry equal total mass
	// (note: n/m and m/n are integer divisions since n and m are ints).
	if (n>=m){
		multiL=1;
		multiR=n/m;
	}else{
		multiL=m/n;
		multiR=1;
	}
	const int Block=1024;
	__shared__ float buf[Block*4];
	for (int i=blockIdx.x;i<b;i+=gridDim.x){
		for (int j=threadIdx.x;j<n*m;j+=blockDim.x)
			match[i*n*m+j]=0;
		for (int j=threadIdx.x;j<n;j+=blockDim.x)
			remainL[j]=multiL;
		for (int j=threadIdx.x;j<m;j+=blockDim.x)
			remainR[j]=multiR;
		__syncthreads();
		//for (int j=7;j>=-2;j--){
		for (int j=7;j>-2;j--){
			float level=-powf(4.0f,j);
			if (j==-2){
				level=0;
			}
			for (int k0=0;k0<n;k0+=blockDim.x){
				int k=k0+threadIdx.x;
				float x1=0,y1=0,z1=0;
				if (k<n){
					x1=xyz1[i*n*3+k*3+0];
					y1=xyz1[i*n*3+k*3+1];
					z1=xyz1[i*n*3+k*3+2];
				}
				float suml=1e-9f;
				for (int l0=0;l0<m;l0+=Block){
					int lend=min(m,l0+Block)-l0;
					for (int l=threadIdx.x;l<lend;l+=blockDim.x){
						float x2=xyz2[i*m*3+l0*3+l*3+0];
						float y2=xyz2[i*m*3+l0*3+l*3+1];
						float z2=xyz2[i*m*3+l0*3+l*3+2];
						buf[l*4+0]=x2;
						buf[l*4+1]=y2;
						buf[l*4+2]=z2;
						buf[l*4+3]=remainR[l0+l];
					}
					__syncthreads();
					for (int l=0;l<lend;l++){
						float x2=buf[l*4+0];
						float y2=buf[l*4+1];
						float z2=buf[l*4+2];
						float d=level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));
						float w=__expf(d)*buf[l*4+3];
						suml+=w;
					}
					__syncthreads();
				}
				if (k<n)
					ratioL[k]=remainL[k]/suml;
			}
			/*for (int k=threadIdx.x;k<n;k+=gridDim.x){
				float x1=xyz1[i*n*3+k*3+0];
				float y1=xyz1[i*n*3+k*3+1];
				float z1=xyz1[i*n*3+k*3+2];
				float suml=1e-9f;
				for (int l=0;l<m;l++){
					float x2=xyz2[i*m*3+l*3+0];
					float y2=xyz2[i*m*3+l*3+1];
					float z2=xyz2[i*m*3+l*3+2];
					float w=expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*remainR[l];
					suml+=w;
				}
				ratioL[k]=remainL[k]/suml;
			}*/
			__syncthreads();
			for (int l0=0;l0<m;l0+=blockDim.x){
				int l=l0+threadIdx.x;
				float x2=0,y2=0,z2=0;
				if (l<m){
					x2=xyz2[i*m*3+l*3+0];
					y2=xyz2[i*m*3+l*3+1];
					z2=xyz2[i*m*3+l*3+2];
				}
				float sumr=0;
				for (int k0=0;k0<n;k0+=Block){
					int kend=min(n,k0+Block)-k0;
					for (int k=threadIdx.x;k<kend;k+=blockDim.x){
						buf[k*4+0]=xyz1[i*n*3+k0*3+k*3+0];
						buf[k*4+1]=xyz1[i*n*3+k0*3+k*3+1];
						buf[k*4+2]=xyz1[i*n*3+k0*3+k*3+2];
						buf[k*4+3]=ratioL[k0+k];
					}
					__syncthreads();
					for (int k=0;k<kend;k++){
						float x1=buf[k*4+0];
						float y1=buf[k*4+1];
						float z1=buf[k*4+2];
						float w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*buf[k*4+3];
						sumr+=w;
					}
					__syncthreads();
				}
				if (l<m){
					sumr*=remainR[l];
					float consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);
					ratioR[l]=consumption*remainR[l];
					remainR[l]=fmaxf(0.0f,remainR[l]-sumr);
				}
			}
			/*for (int l=threadIdx.x;l<m;l+=blockDim.x){
				float x2=xyz2[i*m*3+l*3+0];
				float y2=xyz2[i*m*3+l*3+1];
				float z2=xyz2[i*m*3+l*3+2];
				float sumr=0;
				for (int k=0;k<n;k++){
					float x1=xyz1[i*n*3+k*3+0];
					float y1=xyz1[i*n*3+k*3+1];
					float z1=xyz1[i*n*3+k*3+2];
					float w=expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*ratioL[k];
					sumr+=w;
				}
				sumr*=remainR[l];
				float consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);
				ratioR[l]=consumption*remainR[l];
				remainR[l]=fmaxf(0.0f,remainR[l]-sumr);
			}*/
			__syncthreads();
			for (int k0=0;k0<n;k0+=blockDim.x){
				int k=k0+threadIdx.x;
				float x1=0,y1=0,z1=0;
				if (k<n){
					x1=xyz1[i*n*3+k*3+0];
					y1=xyz1[i*n*3+k*3+1];
					z1=xyz1[i*n*3+k*3+2];
				}
				float suml=0;
				for (int l0=0;l0<m;l0+=Block){
					int lend=min(m,l0+Block)-l0;
					for (int l=threadIdx.x;l<lend;l+=blockDim.x){
						buf[l*4+0]=xyz2[i*m*3+l0*3+l*3+0];
						buf[l*4+1]=xyz2[i*m*3+l0*3+l*3+1];
						buf[l*4+2]=xyz2[i*m*3+l0*3+l*3+2];
						buf[l*4+3]=ratioR[l0+l];
					}
					__syncthreads();
					float rl=ratioL[k];
					if (k<n){
						for (int l=0;l<lend;l++){
							float x2=buf[l*4+0];
							float y2=buf[l*4+1];
							float z2=buf[l*4+2];
							float w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*rl*buf[l*4+3];
							match[i*n*m+(l0+l)*n+k]+=w;
							suml+=w;
						}
					}
					__syncthreads();
				}
				if (k<n)
					remainL[k]=fmaxf(0.0f,remainL[k]-suml);
			}
			/*for (int k=threadIdx.x;k<n;k+=blockDim.x){
				float x1=xyz1[i*n*3+k*3+0];
				float y1=xyz1[i*n*3+k*3+1];
				float z1=xyz1[i*n*3+k*3+2];
				float suml=0;
				for (int l=0;l<m;l++){
					float x2=xyz2[i*m*3+l*3+0];
					float y2=xyz2[i*m*3+l*3+1];
					float z2=xyz2[i*m*3+l*3+2];
					float w=expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*ratioL[k]*ratioR[l];
					match[i*n*m+l*n+k]+=w;
					suml+=w;
				}
				remainL[k]=fmaxf(0.0f,remainL[k]-suml);
			}*/
			__syncthreads();
		}
	}
}
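
// The kernel above runs an annealed soft-assignment: at each of nine
// temperature levels (level = -4^j for j = 7 .. -1; the j == -2 branch is
// unreachable with the current loop bound), every source point distributes
// its remaining mass over targets in proportion to exp(level * d^2), and the
// consumed mass is subtracted on both sides. A plain-Python, single-batch
// sketch of the same loop (hypothetical name approx_match_reference,
// illustrative only, not part of this repo):

```python
import math

def approx_match_reference(xyz1, xyz2):
    """Plain-Python sketch of the annealed soft-assignment loop in
    approxmatchkernel (single batch; variable names mirror the kernel).
    Returns match[l][k]: mass sent from xyz1[k] to xyz2[l]."""
    n, m = len(xyz1), len(xyz2)
    multiL, multiR = (1.0, n / m) if n >= m else (m / n, 1.0)
    remainL = [multiL] * n          # mass left to send from each source point
    remainR = [multiR] * m          # mass left to absorb at each target point
    match = [[0.0] * n for _ in range(m)]
    d2 = lambda p, q: sum((a - b) ** 2 for a, b in zip(p, q))
    for j in range(7, -2, -1):      # anneal: sharper assignments each pass
        level = -4.0 ** j
        # How much each source sends per unit of target "appeal".
        ratioL = [remainL[k] / (1e-9 + sum(
                      math.exp(level * d2(xyz1[k], xyz2[l])) * remainR[l]
                      for l in range(m)))
                  for k in range(n)]
        ratioR = [0.0] * m
        for l in range(m):
            sumr = sum(math.exp(level * d2(xyz1[k], xyz2[l])) * ratioL[k]
                       for k in range(n)) * remainR[l]
            consumption = min(remainR[l] / (sumr + 1e-9), 1.0)
            ratioR[l] = consumption * remainR[l]
            remainR[l] = max(0.0, remainR[l] - sumr)
        for k in range(n):
            suml = 0.0
            for l in range(m):
                w = math.exp(level * d2(xyz1[k], xyz2[l])) * ratioL[k] * ratioR[l]
                match[l][k] += w
                suml += w
            remainL[k] = max(0.0, remainL[k] - suml)
    return match

m = approx_match_reference([(0.0, 0.0, 0.0), (5.0, 0.0, 0.0)],
                           [(0.0, 0.0, 0.0), (5.0, 0.0, 0.0)])
# identical clouds: nearly all mass lands on the diagonal of match
```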

__global__ void matchcostkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ out){
	__shared__ float allsum[512];
	const int Block=256;
	__shared__ float buf[Block*3];
	for (int i=blockIdx.x;i<b;i+=gridDim.x){
		float subsum=0;
		for (int k0=0;k0<m;k0+=Block){
			int endk=min(m,k0+Block);
			for (int k=threadIdx.x;k<(endk-k0)*3;k+=blockDim.x){
				buf[k]=xyz2[i*m*3+k0*3+k];
			}
			__syncthreads();
			for (int j=threadIdx.x;j<n;j+=blockDim.x){
				float x1=xyz1[(i*n+j)*3+0];
				float y1=xyz1[(i*n+j)*3+1];
				float z1=xyz1[(i*n+j)*3+2];
				for (int k=0;k<endk-k0;k++){
					//float x2=xyz2[(i*m+k)*3+0]-x1;
					//float y2=xyz2[(i*m+k)*3+1]-y1;
					//float z2=xyz2[(i*m+k)*3+2]-z1;
					float x2=buf[k*3+0]-x1;
					float y2=buf[k*3+1]-y1;
					float z2=buf[k*3+2]-z1;
					float d=sqrtf(x2*x2+y2*y2+z2*z2);
					subsum+=match[i*n*m+(k0+k)*n+j]*d;
				}
			}
			__syncthreads();
		}
		allsum[threadIdx.x]=subsum;
		for (int j=1;j<blockDim.x;j<<=1){
			__syncthreads();
			if ((threadIdx.x&j)==0 && threadIdx.x+j<blockDim.x){
				allsum[threadIdx.x]+=allsum[threadIdx.x+j];
			}
		}
		if (threadIdx.x==0)
			out[i]=allsum[0];
		__syncthreads();
	}
}
//void matchcostLauncher(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * out){
//	matchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);
//}
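
// Semantically, matchcostkernel reduces each batch element to the
// match-weighted sum of Euclidean distances. A plain-Python reference of
// that reduction (hypothetical name match_cost_reference, illustrative only):

```python
import math

def match_cost_reference(xyz1, xyz2, match):
    """What matchcostkernel computes for one batch element:
    cost = sum_{l,k} match[l][k] * |xyz1[k] - xyz2[l]|."""
    cost = 0.0
    for l, q in enumerate(xyz2):
        for k, p in enumerate(xyz1):
            d = math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))
            cost += match[l][k] * d
    return cost

# Identity match between two single-point clouds 3 units apart -> cost 3.
cost = match_cost_reference([(0.0, 0.0, 0.0)], [(3.0, 0.0, 0.0)], [[1.0]])
```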

__global__ void matchcostgrad2kernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad2){
	__shared__ float sum_grad[256*3];
	for (int i=blockIdx.x;i<b;i+=gridDim.x){
		int kbeg=m*blockIdx.y/gridDim.y;
		int kend=m*(blockIdx.y+1)/gridDim.y;
		for (int k=kbeg;k<kend;k++){
			float x2=xyz2[(i*m+k)*3+0];
			float y2=xyz2[(i*m+k)*3+1];
			float z2=xyz2[(i*m+k)*3+2];
			float subsumx=0,subsumy=0,subsumz=0;
			for (int j=threadIdx.x;j<n;j+=blockDim.x){
				float x1=x2-xyz1[(i*n+j)*3+0];
				float y1=y2-xyz1[(i*n+j)*3+1];
				float z1=z2-xyz1[(i*n+j)*3+2];
				float d=match[i*n*m+k*n+j]*rsqrtf(fmaxf(x1*x1+y1*y1+z1*z1,1e-20f));
				subsumx+=x1*d;
				subsumy+=y1*d;
				subsumz+=z1*d;
			}
			sum_grad[threadIdx.x*3+0]=subsumx;
			sum_grad[threadIdx.x*3+1]=subsumy;
			sum_grad[threadIdx.x*3+2]=subsumz;
			for (int j=1;j<blockDim.x;j<<=1){
				__syncthreads();
				int j1=threadIdx.x;
				int j2=threadIdx.x+j;
				if ((j1&j)==0 && j2<blockDim.x){
					sum_grad[j1*3+0]+=sum_grad[j2*3+0];
					sum_grad[j1*3+1]+=sum_grad[j2*3+1];
					sum_grad[j1*3+2]+=sum_grad[j2*3+2];
				}
			}
			if (threadIdx.x==0){
				grad2[(i*m+k)*3+0]=sum_grad[0];
				grad2[(i*m+k)*3+1]=sum_grad[1];
				grad2[(i*m+k)*3+2]=sum_grad[2];
			}
			__syncthreads();
		}
	}
}
__global__ void matchcostgrad1kernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad1){
	for (int i=blockIdx.x;i<b;i+=gridDim.x){
		for (int l=threadIdx.x;l<n;l+=blockDim.x){
			float x1=xyz1[i*n*3+l*3+0];
			float y1=xyz1[i*n*3+l*3+1];
			float z1=xyz1[i*n*3+l*3+2];
			float dx=0,dy=0,dz=0;
			for (int k=0;k<m;k++){
				float x2=xyz2[i*m*3+k*3+0];
				float y2=xyz2[i*m*3+k*3+1];
				float z2=xyz2[i*m*3+k*3+2];
				float d=match[i*n*m+k*n+l]*rsqrtf(fmaxf((x1-x2)*(x1-x2)+(y1-y2)*(y1-y2)+(z1-z2)*(z1-z2),1e-20f));
				dx+=(x1-x2)*d;
				dy+=(y1-y2)*d;
				dz+=(z1-z2)*d;
			}
			grad1[i*n*3+l*3+0]=dx;
			grad1[i*n*3+l*3+1]=dy;
			grad1[i*n*3+l*3+2]=dz;
		}
	}
}
//void matchcostgradLauncher(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad2){
//	matchcostgrad<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);
//}

/*void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,
                  cudaStream_t stream)*/
// temp: TensorShape{b,(n+m)*2}
void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream){
	approxmatchkernel
      <<<32, 512, 0, stream>>>(b,n,m,xyz1,xyz2,match,temp);
      
  cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err)
    throw std::runtime_error(Formatter()
                             << "CUDA kernel failed : " << std::to_string(err));
}

void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream){
	matchcostkernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,out);
      
  cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err)
    throw std::runtime_error(Formatter()
                             << "CUDA kernel failed : " << std::to_string(err));
}

void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream){
	matchcostgrad1kernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,grad1);
	matchcostgrad2kernel<<<dim3(32,32),256,0,stream>>>(b,n,m,xyz1,xyz2,match,grad2);
	
    cudaError_t err = cudaGetLastError();
    if (cudaSuccess != err)
        throw std::runtime_error(Formatter()
                             << "CUDA kernel failed : " << std::to_string(err));
}


================================================
FILE: metrics/pytorch_structural_losses/src/approxmatch.cuh
================================================
/*
template <typename Dtype>
void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,
                  cudaStream_t stream);
*/
void approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream);
void matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream);
void matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream);


================================================
FILE: metrics/pytorch_structural_losses/src/nndistance.cu
================================================

__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
	const int batch=512;
	__shared__ float buf[batch*3];
	for (int i=blockIdx.x;i<b;i+=gridDim.x){
		for (int k2=0;k2<m;k2+=batch){
			int end_k=min(m,k2+batch)-k2;
			for (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){
				buf[j]=xyz2[(i*m+k2)*3+j];
			}
			__syncthreads();
			for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
				float x1=xyz[(i*n+j)*3+0];
				float y1=xyz[(i*n+j)*3+1];
				float z1=xyz[(i*n+j)*3+2];
				int best_i=0;
				float best=0;
				int end_ka=end_k-(end_k&3);
				if (end_ka==batch){
					for (int k=0;k<batch;k+=4){
						{
							float x2=buf[k*3+0]-x1;
							float y2=buf[k*3+1]-y1;
							float z2=buf[k*3+2]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (k==0 || d<best){
								best=d;
								best_i=k+k2;
							}
						}
						{
							float x2=buf[k*3+3]-x1;
							float y2=buf[k*3+4]-y1;
							float z2=buf[k*3+5]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (d<best){
								best=d;
								best_i=k+k2+1;
							}
						}
						{
							float x2=buf[k*3+6]-x1;
							float y2=buf[k*3+7]-y1;
							float z2=buf[k*3+8]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (d<best){
								best=d;
								best_i=k+k2+2;
							}
						}
						{
							float x2=buf[k*3+9]-x1;
							float y2=buf[k*3+10]-y1;
							float z2=buf[k*3+11]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (d<best){
								best=d;
								best_i=k+k2+3;
							}
						}
					}
				}else{
					for (int k=0;k<end_ka;k+=4){
						{
							float x2=buf[k*3+0]-x1;
							float y2=buf[k*3+1]-y1;
							float z2=buf[k*3+2]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (k==0 || d<best){
								best=d;
								best_i=k+k2;
							}
						}
						{
							float x2=buf[k*3+3]-x1;
							float y2=buf[k*3+4]-y1;
							float z2=buf[k*3+5]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (d<best){
								best=d;
								best_i=k+k2+1;
							}
						}
						{
							float x2=buf[k*3+6]-x1;
							float y2=buf[k*3+7]-y1;
							float z2=buf[k*3+8]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (d<best){
								best=d;
								best_i=k+k2+2;
							}
						}
						{
							float x2=buf[k*3+9]-x1;
							float y2=buf[k*3+10]-y1;
							float z2=buf[k*3+11]-z1;
							float d=x2*x2+y2*y2+z2*z2;
							if (d<best){
								best=d;
								best_i=k+k2+3;
							}
						}
					}
				}
				for (int k=end_ka;k<end_k;k++){
					float x2=buf[k*3+0]-x1;
					float y2=buf[k*3+1]-y1;
					float z2=buf[k*3+2]-z1;
					float d=x2*x2+y2*y2+z2*z2;
					if (k==0 || d<best){
						best=d;
						best_i=k+k2;
					}
				}
				if (k2==0 || result[(i*n+j)]>best){
					result[(i*n+j)]=best;
					result_i[(i*n+j)]=best_i;
				}
			}
			__syncthreads();
		}
	}
}
void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
	NmDistanceKernel<<<dim3(32,16,1),512, 0, stream>>>(b,n,xyz,m,xyz2,result,result_i);
	NmDistanceKernel<<<dim3(32,16,1),512, 0, stream>>>(b,m,xyz2,n,xyz,result2,result2_i);
}
__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
	for (int i=blockIdx.x;i<b;i+=gridDim.x){
		for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
			float x1=xyz1[(i*n+j)*3+0];
			float y1=xyz1[(i*n+j)*3+1];
			float z1=xyz1[(i*n+j)*3+2];
			int j2=idx1[i*n+j];
			float x2=xyz2[(i*m+j2)*3+0];
			float y2=xyz2[(i*m+j2)*3+1];
			float z2=xyz2[(i*m+j2)*3+2];
			float g=grad_dist1[i*n+j]*2;
			atomicAdd(&(grad_xyz1[(i*n+j)*3+0]),g*(x1-x2));
			atomicAdd(&(grad_xyz1[(i*n+j)*3+1]),g*(y1-y2));
			atomicAdd(&(grad_xyz1[(i*n+j)*3+2]),g*(z1-z2));
			atomicAdd(&(grad_xyz2[(i*m+j2)*3+0]),-(g*(x1-x2)));
			atomicAdd(&(grad_xyz2[(i*m+j2)*3+1]),-(g*(y1-y2)));
			atomicAdd(&(grad_xyz2[(i*m+j2)*3+2]),-(g*(z1-z2)));
		}
	}
}
void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){
	cudaMemset(grad_xyz1,0,b*n*3*sizeof(float));
	cudaMemset(grad_xyz2,0,b*m*3*sizeof(float));
	NmDistanceGradKernel<<<dim3(1,16,1),256, 0, stream>>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2);
	NmDistanceGradKernel<<<dim3(1,16,1),256, 0, stream>>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1);
}
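
// NmDistanceGradKernel backpropagates through the squared distance to each
// matched neighbor: d/dp1 |p1 - p2|^2 = 2 (p1 - p2), with the opposite sign
// accumulated into the neighbor. A small Python check of that formula against
// a finite difference (hypothetical name nn_grad_reference, illustrative only):

```python
def nn_grad_reference(p1, p2, grad_dist):
    """Analytic gradient used per matched pair in NmDistanceGradKernel:
    d/dp1 |p1 - p2|^2 = 2 (p1 - p2), scaled by the incoming gradient."""
    g = [2.0 * grad_dist * (a - b) for a, b in zip(p1, p2)]
    return g, [-x for x in g]          # gradient w.r.t. p2 is the negation

# Finite-difference check of the first coordinate.
p1, p2, eps = (1.0, 2.0, 3.0), (0.5, 0.0, -1.0), 1e-6
sq = lambda p, q: sum((a - b) ** 2 for a, b in zip(p, q))
fd = (sq((p1[0] + eps, p1[1], p1[2]), p2)
      - sq((p1[0] - eps, p1[1], p1[2]), p2)) / (2 * eps)
g1, g2 = nn_grad_reference(p1, p2, 1.0)
```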



================================================
FILE: metrics/pytorch_structural_losses/src/nndistance.cuh
================================================
void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream);
void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream);


================================================
FILE: metrics/pytorch_structural_losses/src/structural_loss.cpp
================================================
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "src/approxmatch.cuh"
#include "src/nndistance.cuh"

#include <vector>
#include <iostream>

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

/*
input:
	set1 : batch_size * #dataset_points * 3
	set2 : batch_size * #query_points * 3
returns:
	match : batch_size * #query_points * #dataset_points
*/
//  temp: TensorShape{b,(n+m)*2}
std::vector<at::Tensor> ApproxMatch(at::Tensor set_d, at::Tensor set_q) {
    //std::cout << "[ApproxMatch] Called." << std::endl;
    int64_t batch_size = set_d.size(0);    
    int64_t n_dataset_points = set_d.size(1); // n
    int64_t n_query_points = set_q.size(1);   // m
    //std::cout << "[ApproxMatch] batch_size:" << batch_size << std::endl;
    at::Tensor match = torch::empty({batch_size, n_query_points, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
    at::Tensor temp = torch::empty({batch_size, (n_query_points+n_dataset_points)*2}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
    CHECK_INPUT(set_d);
    CHECK_INPUT(set_q);
    CHECK_INPUT(match);
    CHECK_INPUT(temp);
    
    approxmatch(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),temp.data<float>(), at::cuda::getCurrentCUDAStream());
    return {match, temp};
}

at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match) {
    //std::cout << "[MatchCost] Called." << std::endl;
    int64_t batch_size = set_d.size(0);    
    int64_t n_dataset_points = set_d.size(1); // n
    int64_t n_query_points = set_q.size(1);   // m
    //std::cout << "[MatchCost] batch_size:" << batch_size << std::endl;
    at::Tensor out = torch::empty({batch_size}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
    CHECK_INPUT(set_d);
    CHECK_INPUT(set_q);
    CHECK_INPUT(match);
    CHECK_INPUT(out);
    matchcost(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),out.data<float>(),at::cuda::getCurrentCUDAStream());
    return out;
}

std::vector<at::Tensor> MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match) {
    //std::cout << "[MatchCostGrad] Called." << std::endl;
    int64_t batch_size = set_d.size(0);    
    int64_t n_dataset_points = set_d.size(1); // n
    int64_t n_query_points = set_q.size(1);   // m
    //std::cout << "[MatchCostGrad] batch_size:" << batch_size << std::endl;
    at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
    at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
    CHECK_INPUT(set_d);
    CHECK_INPUT(set_q);
    CHECK_INPUT(match);
    CHECK_INPUT(grad1);
    CHECK_INPUT(grad2);
    matchcostgrad(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),grad1.data<float>(),grad2.data<float>(),at::cuda::getCurrentCUDAStream());
    return {grad1, grad2};
}


/*
input:
	set_d : batch_size * #dataset_points * 3
	set_q : batch_size * #query_points * 3
returns:
	dist1, idx1 : batch_size * #dataset_points
	dist2, idx2 : batch_size * #query_points
*/
std::vector<at::Tensor> NNDistance(at::Tensor set_d, at::Tensor set_q) {
    //std::cout << "[NNDistance] Called." << std::endl;
    int64_t batch_size = set_d.size(0);    
    int64_t n_dataset_points = set_d.size(1); // n
    int64_t n_query_points = set_q.size(1);   // m
    //std::cout << "[NNDistance] batch_size:" << batch_size << std::endl;
    at::Tensor dist1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
    at::Tensor idx1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device()));
    at::Tensor dist2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
    at::Tensor idx2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device()));
    CHECK_INPUT(set_d);
    CHECK_INPUT(set_q);
    CHECK_INPUT(dist1);
    CHECK_INPUT(idx1);
    CHECK_INPUT(dist2);
    CHECK_INPUT(idx2);
    // void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream);
    nndistance(batch_size,n_dataset_points,set_d.data<float>(),n_query_points,set_q.data<float>(),dist1.data<float>(),idx1.data<int>(),dist2.data<float>(),idx2.data<int>(), at::cuda::getCurrentCUDAStream());
    return {dist1, idx1, dist2, idx2};
}

std::vector<at::Tensor> NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2) {
    //std::cout << "[NNDistanceGrad] Called." << std::endl;
    int64_t batch_size = set_d.size(0);    
    int64_t n_dataset_points = set_d.size(1); // n
    int64_t n_query_points = set_q.size(1);   // m
    //std::cout << "[NNDistanceGrad] batch_size:" << batch_size << std::endl;
    at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
    at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));
    CHECK_INPUT(set_d);
    CHECK_INPUT(set_q);
    CHECK_INPUT(idx1);
    CHECK_INPUT(idx2);
    CHECK_INPUT(grad_dist1);
    CHECK_INPUT(grad_dist2);
    CHECK_INPUT(grad1);
    CHECK_INPUT(grad2);
    //void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream);
    nndistancegrad(batch_size,n_dataset_points,set_d.data<float>(),n_query_points,set_q.data<float>(),
        grad_dist1.data<float>(),idx1.data<int>(),
        grad_dist2.data<float>(),idx2.data<int>(),
        grad1.data<float>(),grad2.data<float>(),
        at::cuda::getCurrentCUDAStream());
    return {grad1, grad2};
}
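
The comment block above documents NNDistance's contract: for each point in one set, the squared distance to (and the index of) its nearest neighbor in the other set, in both directions. A minimal pure-Python reference of the same semantics (illustrative only; the CUDA kernel is the actual implementation, and `nn_distance_ref` is an invented name) can be used to sanity-check outputs on tiny inputs:

```python
# Pure-Python reference for the NNDistance contract documented above:
# for each point in one set, the squared distance to (and index of) its
# nearest neighbor in the other set. Illustrative sketch only.

def nn_distance_ref(set_d, set_q):
    """set_d, set_q: lists of (x, y, z) tuples.
    Returns (dist1, idx1, dist2, idx2), mirroring the binding's
    outputs for a single batch element."""
    def sq_dist(a, b):
        return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

    def one_way(src, dst):
        dists, idxs = [], []
        for p in src:
            best = min(range(len(dst)), key=lambda j: sq_dist(p, dst[j]))
            dists.append(sq_dist(p, dst[best]))
            idxs.append(best)
        return dists, idxs

    dist1, idx1 = one_way(set_d, set_q)
    dist2, idx2 = one_way(set_q, set_d)
    return dist1, idx1, dist2, idx2
```

Summing `dist1` and `dist2` gives the (squared) Chamfer distance used by the evaluation metrics.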



================================================
FILE: metrics/pytorch_structural_losses/src/utils.hpp
================================================
#include <iostream>
#include <sstream>
#include <string>

class Formatter {
public:
  Formatter() {}
  ~Formatter() {}

  template <typename Type> Formatter &operator<<(const Type &value) {
    stream_ << value;
    return *this;
  }

  std::string str() const { return stream_.str(); }
  operator std::string() const { return stream_.str(); }

  enum ConvertToString { to_str };

  std::string operator>>(ConvertToString) { return stream_.str(); }

private:
  std::stringstream stream_;
  Formatter(const Formatter &);
  Formatter &operator=(const Formatter &);
};


================================================
FILE: models/__init__.py
================================================
#!/usr/bin/python3

from models.sdf_model import SdfModel
from models.autoencoder import BetaVAE

from models.archs.encoders.conv_pointnet import UNet

from models.diffusion import *
from models.archs.diffusion_arch import * 

from models.combined_model import CombinedModel




================================================
FILE: models/archs/__init__.py
================================================
#!/usr/bin/python3



================================================
FILE: models/archs/diffusion_arch.py
================================================
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum 

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape

from rotary_embedding_torch import RotaryEmbedding

from diff_utils.model_utils import * 

from random import sample

class CausalTransformer(nn.Module):
    def __init__(
        self,
        dim, 
        depth,
        dim_in_out=None,
        cross_attn=False,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        norm_in = False,
        norm_out = True, 
        attn_dropout = 0.,
        ff_dropout = 0.,
        final_proj = True, 
        normformer = False,
        rotary_emb = True, 
        **kwargs
    ):
        super().__init__()
        self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM

        self.rel_pos_bias = RelPosBias(heads = heads)

        rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None
        rotary_emb_cross = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None

        self.layers = nn.ModuleList([])

        dim_in_out = default(dim_in_out, dim)
        self.use_same_dims = (dim_in_out is None) or (dim_in_out==dim)
        point_feature_dim = kwargs.get('point_feature_dim', dim)

        if cross_attn:
            #print("using CROSS ATTN, with dropout {}".format(attn_dropout))
            self.layers.append(nn.ModuleList([
                    Attention(dim = dim_in_out, out_dim=dim, causal = True, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb),
                    Attention(dim = dim, kv_dim=point_feature_dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb_cross),
                    FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
                ]))
            for _ in range(depth):
                self.layers.append(nn.ModuleList([
                    Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb),
                    Attention(dim = dim, kv_dim=point_feature_dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb_cross),
                    FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
                ]))
            self.layers.append(nn.ModuleList([
                    Attention(dim = dim, out_dim=dim, causal = True, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb),
                    Attention(dim = dim, kv_dim=point_feature_dim, out_dim=dim_in_out, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb_cross),
                    FeedForward(dim = dim_in_out, out_dim=dim_in_out, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
                ]))
        else:
            self.layers.append(nn.ModuleList([
                    Attention(dim = dim_in_out, out_dim=dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                    FeedForward(dim = dim, out_dim=dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
                ]))
            for _ in range(depth):
                self.layers.append(nn.ModuleList([
                    Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                    FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
                ]))
            self.layers.append(nn.ModuleList([
                    Attention(dim = dim, out_dim=dim_in_out, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                    FeedForward(dim = dim_in_out, out_dim=dim_in_out, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
                ]))

        self.norm = LayerNorm(dim_in_out, stable = True) if norm_out else nn.Identity()  # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
        self.project_out = nn.Linear(dim_in_out, dim_in_out, bias = False) if final_proj else nn.Identity()

        self.cross_attn = cross_attn

    def forward(self, x, time_emb=None, context=None):
        n, device = x.shape[1], x.device

        x = self.init_norm(x)

        attn_bias = self.rel_pos_bias(n, n + 1, device = device)

        if self.cross_attn:
            #assert context is not None 
            for idx, (self_attn, cross_attn, ff) in enumerate(self.layers):
                #print("x1 shape: ", x.shape)
                if (idx==0 or idx==len(self.layers)-1) and not self.use_same_dims:
                    x = self_attn(x, attn_bias = attn_bias)
                    x = cross_attn(x, context=context) # removing attn_bias for now 
                else:
                    x = self_attn(x, attn_bias = attn_bias) + x 
                    x = cross_attn(x, context=context) + x  # removing attn_bias for now 
                #print("x2 shape, context shape: ", x.shape, context.shape)
                
                #print("x3 shape, context shape: ", x.shape, context.shape)
                x = ff(x) + x
        
        else:
            for idx, (attn, ff) in enumerate(self.layers):
                #print("x1 shape: ", x.shape)
                if (idx==0 or idx==len(self.layers)-1) and not self.use_same_dims:
                    x = attn(x, attn_bias = attn_bias)
                else:
                    x = attn(x, attn_bias = attn_bias) + x
                #print("x2 shape: ", x.shape)
                x = ff(x) + x
                #print("x3 shape: ", x.shape)

        out = self.norm(x)
        return self.project_out(out)

class DiffusionNet(nn.Module):

    def __init__(
        self,
        dim,
        dim_in_out=None,
        num_timesteps = None,
        num_time_embeds = 1,
        cond = None,
        **kwargs
    ):
        super().__init__()
        self.num_time_embeds = num_time_embeds
        self.dim = dim
        self.cond = cond
        self.cross_attn = kwargs.get('cross_attn', False)
        self.cond_dropout = kwargs.get('cond_dropout', False)
        self.point_feature_dim = kwargs.get('point_feature_dim', dim)

        self.dim_in_out = default(dim_in_out, dim)
        #print("dim, in out, point feature dim: ", dim, dim_in_out, self.point_feature_dim)
        #print("cond dropout: ", self.cond_dropout)

        self.to_time_embeds = nn.Sequential(
            nn.Embedding(num_timesteps, self.dim_in_out * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(self.dim_in_out), MLP(self.dim_in_out, self.dim_in_out * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
            Rearrange('b (n d) -> b n d', n = num_time_embeds)
        )

        # last input to the transformer: "a final embedding whose output from the Transformer is used to predict the unnoised CLIP image embedding"
        self.learned_query = nn.Parameter(torch.randn(self.dim_in_out))
        self.causal_transformer = CausalTransformer(dim = dim, dim_in_out=self.dim_in_out, **kwargs)

        if cond:
            # output dim of pointnet needs to match model dim; unless add additional linear layer
            self.pointnet = ConvPointnet(c_dim=self.point_feature_dim) 


    def forward(
        self,
        data, 
        diffusion_timesteps,
        pass_cond=-1, # default -1, depends on prob; but pass as argument during sampling

    ):

        if self.cond:
            assert type(data) is tuple
            data, cond = data # adding noise to cond_feature so doing this in diffusion.py

            #print("data, cond shape: ", data.shape, cond.shape) # B, dim_in_out; B, N, 3
            #print("pass cond: ", pass_cond)
            if self.cond_dropout:
                # classifier-free guidance: randomly zero out the condition features;
                # pass_cond=0/1 forces unconditional/conditional during sampling
                prob = torch.randint(low=0, high=10, size=(1,))
                percentage = 8
                if prob < percentage or pass_cond==0:
                    cond_feature = torch.zeros( (cond.shape[0], cond.shape[1], self.point_feature_dim), device=data.device )
                    #print("zeros shape: ", cond_feature.shape) 
                elif prob >= percentage or pass_cond==1:
                    cond_feature = self.pointnet(cond, cond)
                    #print("cond shape: ", cond_feature.shape)
            else:
                cond_feature = self.pointnet(cond, cond)

            
        batch, dim, device, dtype = *data.shape, data.device, data.dtype

        num_time_embeds = self.num_time_embeds
        time_embed = self.to_time_embeds(diffusion_timesteps)

        data = data.unsqueeze(1)

        learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)

        model_inputs = [time_embed, data, learned_queries]

        if self.cond and not self.cross_attn:
            model_inputs.insert(0, cond_feature) # cond_feature defined in first loop above 
        
        tokens = torch.cat(model_inputs, dim = 1) # (b, 3/4, d); batch and d=512 same across the model_inputs 
        #print("tokens shape: ", tokens.shape)

        if self.cross_attn:
            cond_feature = None if not self.cond else cond_feature
            #print("tokens shape: ", tokens.shape, cond_feature.shape)
            tokens = self.causal_transformer(tokens, context=cond_feature)
        else:
            tokens = self.causal_transformer(tokens)

        # get learned query, which should predict the sdf layer embedding (per DDPM timestep)
        pred = tokens[..., -1, :]

        return pred
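
DiffusionNet.forward assembles a short token sequence (optional condition tokens, then the time embedding, the noisy latent, and the learned query) and reads the prediction off the last position. A hypothetical pure-Python sketch of that layout, with invented names standing in for the actual tensors:

```python
# Sketch of the token layout built in DiffusionNet.forward, using plain
# lists in place of tensors. Names here are illustrative, not the model's.

def build_tokens(time_embed, data, learned_query, cond_feature=None,
                 cross_attn=False):
    """Each argument stands for one token (cond_feature for a list of
    condition tokens). Returns the sequence fed to the transformer."""
    tokens = [time_embed, data, learned_query]
    # without cross-attention, condition tokens are prepended to the sequence
    if cond_feature is not None and not cross_attn:
        tokens = list(cond_feature) + tokens
    return tokens

def read_prediction(transformer_out):
    # the last position corresponds to the learned query token, whose
    # output is taken as the denoised latent prediction
    return transformer_out[-1]
```

With cross-attention enabled, the condition tokens instead enter through the `context` argument of the transformer, so the main sequence keeps its three-token form.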



================================================
FILE: models/archs/encoders/__init__.py
================================================
#!/usr/bin/python3



================================================
FILE: models/archs/encoders/auto_decoder.py
================================================
#!/usr/bin/env python3

import torch.nn as nn
import torch
import torch.nn.functional as F
import json
import math

class AutoDecoder():
    # builds the table of per-scene latent vectors for auto-decoder training
    # requires the total num_scenes (all meshes from all classes), which is given by the len of the dataset
    def __init__(self, num_scenes, latent_size):
        self.num_scenes = num_scenes
        self.latent_size = latent_size

    def build_model(self):
        lat_vecs = nn.Embedding(self.num_scenes, self.latent_size, max_norm=1.0)
        nn.init.normal_(lat_vecs.weight.data, 0.0, 1.0/math.sqrt(self.latent_size))
        return lat_vecs
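
The build above initializes one learnable code per scene with std 1/sqrt(latent_size) and caps each code at unit L2 norm via `max_norm=1.0`. A minimal dependency-free sketch of that initialization (illustrative; `init_latents` is an invented helper, and note that `nn.Embedding` actually enforces `max_norm` lazily at lookup time, whereas this applies it eagerly):

```python
import math
import random

def init_latents(num_scenes, latent_size, seed=0):
    """Draw per-scene codes from N(0, 1/latent_size) and clip each to
    unit L2 norm, mirroring nn.Embedding(..., max_norm=1.0) with the
    normal_(0.0, 1/sqrt(latent_size)) init used above."""
    rng = random.Random(seed)
    std = 1.0 / math.sqrt(latent_size)
    latents = []
    for _ in range(num_scenes):
        v = [rng.gauss(0.0, std) for _ in range(latent_size)]
        norm = math.sqrt(sum(x * x for x in v))
        if norm > 1.0:  # max_norm renormalization
            v = [x / norm for x in v]
        latents.append(v)
    return latents
```

With this std, the expected squared norm of each code is 1, so codes start near the unit sphere and stay inside it.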


================================================
FILE: models/archs/encoders/conv_pointnet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

from torch_scatter import scatter_mean, scatter_max


class ConvPointnet(nn.Module):
    ''' PointNet-based encoder network with ResNet blocks for each point.
        The number of input points is fixed.
    
    Args:
        c_dim (int): dimension of latent code c
        dim (int): input points dimension
        hidden_dim (int): hidden dimension of the network
        scatter_type (str): feature aggregation when doing local pooling
        unet (bool): whether to use U-Net
        unet_kwargs (str): U-Net parameters
        plane_resolution (int): defined resolution for plane feature
        plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
        padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
        n_blocks (int): number of ResnetBlockFC blocks
    '''

    def __init__(self, c_dim=512, dim=3, hidden_dim=128, scatter_type='max', 
                 unet=True, unet_kwargs={"depth": 4, "merge_mode": "concat", "start_filts": 32}, 
                 plane_resolution=64, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5,
                 inject_noise=False):
        super().__init__()
        self.c_dim = c_dim

        self.fc_pos = nn.Linear(dim, 2*hidden_dim)
        self.blocks = nn.ModuleList([
            ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
        ])
        self.fc_c = nn.Linear(hidden_dim, c_dim)

        self.actvn = nn.ReLU()
        self.hidden_dim = hidden_dim

        if unet:
            self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
        else:
            self.unet = None

        self.reso_plane = plane_resolution
        self.plane_type = plane_type
        self.padding = padding

        if scatter_type == 'max':
            self.scatter = scatter_max
        elif scatter_type == 'mean':
            self.scatter = scatter_mean

    def generate_plane_features(self, p, c, plane='xz'):
        # acquire indices of features in plane
        xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
        index = self.coordinate2index(xy, self.reso_plane)

        # scatter plane features from points
        fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
        c = c.permute(0, 2, 1) # B x 512 x T
        fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
        fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparse matrix (B x 512 x reso x reso)

        # process the plane features with UNet
        if self.unet is not None:
            fea_plane = self.unet(fea_plane)

        return fea_plane 

    # takes in "p": point cloud and "query": sdf_xyz 
    # sample plane features for unlabeled_query as well 
    def forward(self, p, query):
        batch_size, T, D = p.size()

        # acquire the index for each point
        coord = {}
        index = {}
        if 'xz' in self.plane_type:
            coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
            index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
        if 'xy' in self.plane_type:
            coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
            index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
        if 'yz' in self.plane_type:
            coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
            index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)

        
        net = self.fc_pos(p)

        net = self.blocks[0](net)
        for block in self.blocks[1:]:
            pooled = self.pool_local(coord, index, net)
            net = torch.cat([net, pooled], dim=2)
            net = block(net)

        c = self.fc_c(net)
        
        fea = {}
        plane_feat_sum = 0
        #denoise_loss = 0
        if 'xz' in self.plane_type:
            fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
            plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')
        if 'xy' in self.plane_type:
            fea['xy'] = self.generate_plane_features(p, c, plane='xy')
            plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')
        if 'yz' in self.plane_type:
            fea['yz'] = self.generate_plane_features(p, c, plane='yz')
            plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')

        return plane_feat_sum.transpose(2,1)

    # given plane features with dimensions (3*dim, 64, 64)
    # first reshape into the three planes, then generate query features from it 
    def forward_with_plane_features(self, plane_features, query):
        # plane features shape: batch, dim*3, 64, 64
        idx = int(plane_features.shape[1] / 3)
        fea = {}
        fea['xz'], fea['xy'], fea['yz'] = plane_features[:,0:idx,...], plane_features[:,idx:idx*2,...], plane_features[:,idx*2:,...]
        #print("shapes: ", fea['xz'].shape, fea['xy'].shape, fea['yz'].shape) #([1, 256, 64, 64])
        plane_feat_sum = 0

        plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')
        plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')
        plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')

        return plane_feat_sum.transpose(2,1)

    # c is point cloud features
    # p is point cloud (coordinates)
    def forward_with_pc_features(self, c, p, query):

        #print("c, p shapes:", c.shape, p.shape)

        fea = {}
        fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
        fea['xy'] = self.generate_plane_features(p, c, plane='xy')
        fea['yz'] = self.generate_plane_features(p, c, plane='yz')

        plane_feat_sum = 0

        plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')
        plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')
        plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')

        return plane_feat_sum.transpose(2,1)


    def get_point_cloud_features(self, p):
        batch_size, T, D = p.size()

        # acquire the index for each point
        coord = {}
        index = {}
        if 'xz' in self.plane_type:
            coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
            index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
        if 'xy' in self.plane_type:
            coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
            index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
        if 'yz' in self.plane_type:
            coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
            index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)

        net = self.fc_pos(p)

        net = self.blocks[0](net)
        for block in self.blocks[1:]:
            pooled = self.pool_local(coord, index, net)
            net = torch.cat([net, pooled], dim=2)
            net = block(net)

        c = self.fc_c(net)

        return c

    def get_plane_features(self, p):

        c = self.get_point_cloud_features(p)
        fea = {}
        if 'xz' in self.plane_type:
            fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
        if 'xy' in self.plane_type:
            fea['xy'] = self.generate_plane_features(p, c, plane='xy')
        if 'yz' in self.plane_type:
            fea['yz'] = self.generate_plane_features(p, c, plane='yz')

        return fea['xz'], fea['xy'], fea['yz']


    def normalize_coordinate(self, p, padding=0.1, plane='xz'):
        ''' Normalize coordinate to [0, 1] for unit cube experiments

        Args:
            p (tensor): point
            padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
            plane (str): plane feature type, ['xz', 'xy', 'yz']
        '''
        if plane == 'xz':
            xy = p[:, :, [0, 2]]
        elif plane =='xy':
            xy = p[:, :, [0, 1]]
        else:
            xy = p[:, :, [1, 2]]

        xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
        xy_new = xy_new + 0.5 # range (0, 1)

        # clamp outliers that fall outside the range
        if xy_new.max() >= 1:
            xy_new[xy_new >= 1] = 1 - 10e-6
        if xy_new.min() < 0:
            xy_new[xy_new < 0] = 0.0
        return xy_new


    def coordinate2index(self, x, reso):
        ''' Convert normalized [0, 1] plane coordinates to flat cell indices.

        Args:
            x (tensor): normalized coordinate
            reso (int): defined resolution of the plane
        '''
        x = (x * reso).long()
        index = x[:, :, 0] + reso * x[:, :, 1]
        index = index[:, None, :]
        return index
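
Together, normalize_coordinate and coordinate2index map a 3D point to a flat cell index on one feature plane. A scalar sketch of that arithmetic (illustrative only; `point_to_plane_index` is an invented helper, assuming the default reso=64 and padding=0.1 and the 'xz' plane):

```python
# Scalar sketch of the plane-index arithmetic above, for a single point.
# Assumes reso=64 and padding=0.1, matching the defaults.

def point_to_plane_index(x, z, reso=64, padding=0.1):
    """Map one (x, z) coordinate in roughly [-0.55, 0.55] to a flat
    cell index in [0, reso**2)."""
    def normalize(c):
        c = c / (1 + padding + 10e-6) + 0.5   # into (0, 1)
        return min(max(c, 0.0), 1 - 10e-6)    # clamp outliers
    u, v = normalize(x), normalize(z)
    iu, iv = int(u * reso), int(v * reso)     # grid cell per axis
    return iu + reso * iv                     # flattened index
```

For example, the origin lands in the central cell (32, 32), i.e. index 32 + 64 * 32.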


    # xy holds the normalized per-plane coordinates of the point cloud;
    # only its keys are used here, and they match the keys of index
    def pool_local(self, xy, index, c):
        bs, fea_dim = c.size(0), c.size(2)
        keys = xy.keys()

        c_out = 0
        for key in keys:
            # scatter plane features from points
            fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
            if self.scatter == scatter_max:
                fea = fea[0]
            # gather feature back to points
            fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
            c_out += fea
        return c_out.permute(0, 2, 1)

    # sample_plane_feature function copied from /src/conv_onet/models/decoder.py
    # uses values from plane_feature and pixel locations from vgrid to interpolate feature
    def sample_plane_feature(self, query, plane_feature, plane):
        xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)
        xy = xy[:, :, None].float()
        vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
        sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)
        return sampled_feat
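
sample_plane_feature remaps the (0, 1) normalized coordinates to grid_sample's (-1, 1) convention; with align_corners=True, a grid value v samples the fractional pixel (v + 1) / 2 * (R - 1). A scalar sketch of that mapping (`to_pixel` is an invented helper, assuming plane resolution R=64):

```python
# Scalar sketch of the coordinate convention used by sample_plane_feature:
# (0, 1) plane coords -> (-1, 1) grid_sample coords -> pixel position.

def to_pixel(xy, reso=64):
    """xy in (0, 1); returns the fractional pixel location that
    F.grid_sample(..., align_corners=True) would bilinearly sample."""
    vgrid = 2.0 * xy - 1.0                    # (0, 1) -> (-1, 1)
    return (vgrid + 1.0) / 2.0 * (reso - 1)   # align_corners=True mapping
```

So the corners 0 and 1 hit pixels 0 and 63 exactly, and `padding_mode='border'` handles anything that falls outside.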


def conv3x3(in_channels, out_channels, stride=1, 
            padding=1, bias=True, groups=1):    
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)

def upconv2x2(in_channels, out_channels, mode='transpose'):
    if mode == 'transpose':
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2)
    else:
        # non-parametric upsampling, followed by a 1x1 conv
        # to adjust the channel dimensionality
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2),
            conv1x1(in_channels, out_channels))

def conv1x1(in_channels, out_channels, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        groups=groups,
        stride=1)


class DownConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 MaxPool.
    A ReLU activation follows each convolution.
    """
    def __init__(self, in_channels, out_channels, pooling=True):
        super(DownConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(self.in_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        before_pool = x
        if self.pooling:
            x = self.pool(x)
        return x, before_pool


class UpConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 UpConvolution.
    A ReLU activation follows each convolution.
    """
    def __init__(self, in_channels, out_channels, 
                 merge_mode='concat', up_mode='transpose'):
        super(UpConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(self.in_channels, self.out_channels, 
            mode=self.up_mode)

        if self.merge_mode == 'concat':
            self.conv1 = conv3x3(
                2*self.out_channels, self.out_channels)
        else:
            # num of input channels to conv2 is same
            self.conv1 = conv3x3(self.out_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)


    def forward(self, from_down, from_up):
        """ Forward pass
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: upconv'd tensor from the decoder pathway
        """
        from_up = self.upconv(from_up)
        if self.merge_mode == 'concat':
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return x


class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597

    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).

    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by upmode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens with the convolution in
        the transpose convolution (specified by upmode='transpose')
    """

    def __init__(self, num_classes, in_channels=3, depth=5, 
                 start_filts=64, up_mode='transpose', same_channels=False,
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of MaxPools in the U-Net.
            start_filts: int, number of convolutional filters for the 
                first conv.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for nearest neighbour
                upsampling.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))
    
        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(merge_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts*(2**i) if not same_channels else self.in_channels
            pooling = True if i < depth-1 else False
            #print("down ins, outs: ", ins, outs)  # [latent dim, 32], [32, 64]...[128, 256]

            down_conv = DownConv(ins, outs, pooling=pooling)
            self.down_convs.append(down_conv)

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth-1):
            ins = outs
            outs = ins // 2 if not same_channels else ins 
            up_conv = UpConv(ins, outs, up_mode=up_mode,
                merge_mode=merge_mode)
            self.up_convs.append(up_conv)
            #print("up ins, outs: ", ins, outs)# [256, 128]...[64, 32]; final 32 to latent is done through self.conv_final 

        # add the list of modules to current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

        self.conv_final = conv1x1(outs, self.num_classes)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            init.constant_(m.bias, 0)


    def reset_params(self):
        for i, m in enumerate(self.modules()):
            self.weight_init(m)


    def forward(self, x):
        encoder_outs = []
        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            #print("down {} x1: ".format(i), x.shape) # increasing channels but decreasing resolution (64x64 -> 8x8)
            x, before_pool = module(x)
            #print("down {} x2: ".format(i), x.shape)
            encoder_outs.append(before_pool)
        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i+2)]
            #print("up {} x1: ".format(i), x.shape)
            x = module(before_pool, x)
            #print("up {} x2: ".format(i), x.shape)
        #exit()
        
        # No softmax is used here, so use nn.CrossEntropyLoss
        # in your training script, as that loss applies
        # log-softmax internally.
        x = self.conv_final(x)
        return x

    def generate(self, x):
        return self(x)
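The `encoder_outs[-(i+2)]` indexing in `forward` above pairs each decoder stage with its mirror encoder stage (the last saved output is the bottleneck input and is never merged). A standalone sketch with plain Python stand-ins for the saved tensors:

```python
# Sketch of the skip-connection indexing in UNet.forward above: with
# depth D there are D saved encoder outputs; decoder step i merges with
# the output two from the end, walking back toward the first stage.
depth = 5
encoder_outs = list(range(depth))          # stand-ins for saved tensors
merged_with = [encoder_outs[-(i + 2)] for i in range(depth - 1)]
assert merged_with == [3, 2, 1, 0]
```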

# Resnet Blocks
class ResnetBlockFC(nn.Module):
    ''' Fully connected ResNet Block class.
    Args:
        size_in (int): input dimension
        size_out (int): output dimension
        size_h (int): hidden dimension
    '''

    def __init__(self, size_in, size_out=None, size_h=None):
        super().__init__()
        # Attributes
        if size_out is None:
            size_out = size_in

        if size_h is None:
            size_h = min(size_in, size_out)

        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out
        # Submodules
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)
        self.actvn = nn.ReLU()

        if size_in == size_out:
            self.shortcut = None
        else:
            self.shortcut = nn.Linear(size_in, size_out, bias=False)
        # Initialization
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x):
        net = self.fc_0(self.actvn(x))
        dx = self.fc_1(self.actvn(net))

        if self.shortcut is not None:
            x_s = self.shortcut(x)
        else:
            x_s = x

        return x_s + dx
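Because `fc_1.weight` is zero-initialized in `ResnetBlockFC` above, the residual branch contributes only `fc_1.bias` at initialization, so the block starts out as (shortcut + constant). A minimal standalone check of that property (re-deriving the residual path outside the module):

```python
import torch
import torch.nn as nn

# With fc_1.weight zeroed, dx = fc_1(actvn(fc_0(actvn(x)))) collapses to
# fc_1.bias for every input, mirroring the init in ResnetBlockFC above.
size = 16
fc_0 = nn.Linear(size, size)
fc_1 = nn.Linear(size, size)
nn.init.zeros_(fc_1.weight)

x = torch.randn(4, size)
dx = fc_1(torch.relu(fc_0(torch.relu(x))))
assert torch.allclose(dx, fc_1.bias.expand_as(dx))
```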

================================================
FILE: models/archs/encoders/dgcnn.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

def knn(x, k):
    inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1).contiguous()

    idx = pairwise_distance.topk(k=k, dim=-1)[1]
    return idx
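The expression in `knn` above is the expanded negated squared distance, `-||x_i - x_j||^2 = -||x_i||^2 + 2 x_i·x_j - ||x_j||^2`, so `topk` over it picks the k nearest neighbours. A minimal standalone check of that identity against a brute-force distance (shapes follow the function: `x` is `(B, C, N)`):

```python
import torch

# Verify the -xx - inner - xx^T expression in knn() above equals the
# negated pairwise squared distance between points (columns of x).
x = torch.randn(2, 3, 10)
inner = -2 * torch.matmul(x.transpose(2, 1), x)
xx = torch.sum(x ** 2, dim=1, keepdim=True)
neg_sq_dist = -xx - inner - xx.transpose(2, 1)              # (B, N, N)

brute = -torch.cdist(x.transpose(2, 1), x.transpose(2, 1)) ** 2
assert torch.allclose(neg_sq_dist, brute, atol=1e-4)
```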

def get_graph_feature(x, k=20):
    idx = knn(x, k=k)
    batch_size, num_points, _ = idx.size()
    device = x.device

    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points

    idx = idx + idx_base

    idx = idx.view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims)  
                                       # -> (batch_size*num_points, num_dims) 
                                       #   batch_size * num_points * k + range(0, batch_size*num_points)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)
    return feature
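The `idx_base` offset in `get_graph_feature` above shifts each batch's neighbour indices (in `[0, N)`) by `b*N`, so one flat gather over the `(B*N, C)` view picks the right rows for every batch. A small standalone sketch of that indexing trick (random stand-in indices instead of the real knn output):

```python
import torch

# Batch-offset gather as in get_graph_feature above: flatten (B, N, C)
# to (B*N, C) and add b*N to batch b's indices before a single gather.
B, N, C, k = 2, 4, 3, 2
x = torch.randn(B, N, C)                    # already (batch, points, dims)
idx = torch.randint(0, N, (B, N, k))        # stand-in for knn output
idx_base = torch.arange(B).view(-1, 1, 1) * N
flat = x.reshape(B * N, C)[(idx + idx_base).view(-1)]
feature = flat.view(B, N, k, C)

# row (b, n, j) is the neighbour idx[b, n, j], taken from batch b
assert torch.equal(feature[1, 0, 0], x[1, idx[1, 0, 0]])
assert feature.shape == (B, N, k, C)
```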

class DGCNN(nn.Module):

    def __init__(
        self, 
        emb_dims=512,
        use_bn=False
    ):

        super().__init__()

        if use_bn:
            self.bn1 = nn.BatchNorm2d(64)
            self.bn2 = nn.BatchNorm2d(64)
            self.bn3 = nn.BatchNorm2d(128)
            self.bn4 = nn.BatchNorm2d(256)
            self.bn5 = nn.BatchNorm2d(emb_dims)

            self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), self.bn1, nn.LeakyReLU(negative_slope=0.2))
            self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), self.bn2, nn.LeakyReLU(negative_slope=0.2))
            self.conv3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False), self.bn3, nn.LeakyReLU(negative_slope=0.2))
            self.conv4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, bias=False), self.bn4, nn.LeakyReLU(negative_slope=0.2))
            self.conv5 = nn.Sequential(nn.Conv2d(512, emb_dims, kernel_size=1, bias=False), self.bn5, nn.LeakyReLU(negative_slope=0.2))

        else:
            self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))
            self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))
            self.conv3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))
            self.conv4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))
            self.conv5 = nn.Sequential(nn.Conv2d(512, emb_dims, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))

    def forward(self, x):
        batch_size, num_dims, num_points = x.size()                 # x:      batch x   3 x num of points
        x = get_graph_feature(x)                                    # x:      batch x   6 x num of points x 20

        x1     = self.conv1(x)                                      # x1:     batch x  64 x num of points x 20
        x1_max = x1.max(dim=-1, keepdim=True)[0]                    # x1_max: batch x  64 x num of points x 1

        x2     = self.conv2(x1)                                     # x2:     batch x  64 x num of points x 20
        x2_max = x2.max(dim=-1, keepdim=True)[0]                    # x2_max: batch x  64 x num of points x 1

        x3     = self.conv3(x2)                                     # x3:     batch x 128 x num of points x 20
        x3_max = x3.max(dim=-1, keepdim=True)[0]                    # x3_max: batch x 128 x num of points x 1

        x4     = self.conv4(x3)                                     # x4:     batch x 256 x num of points x 20
        x4_max = x4.max(dim=-1, keepdim=True)[0]                    # x4_max: batch x 256 x num of points x 1
 
        x_max  = torch.cat((x1_max, x2_max, x3_max, x4_max), dim=1) # x_max:  batch x 512 x num of points x 1

        point_feat = torch.squeeze(self.conv5(x_max), dim=3)        # point feat:  batch x 512 x num of points

        global_feat = point_feat.max(dim=2, keepdim=False)[0]       # global feat: batch x 512

        return global_feat
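In `DGCNN.forward` above, each edge-conv output is max-pooled over the k neighbours and the pooled maps are concatenated channel-wise (64 + 64 + 128 + 256 = 512), matching `conv5`'s input channels. A standalone sketch of that aggregation with random stand-in activations:

```python
import torch

# Multi-scale aggregation as in DGCNN.forward above: pool each stage's
# (B, C_i, N, k) map over the neighbour axis, then concat on channels.
B, N, k = 2, 100, 20
stages = [torch.randn(B, c, N, k) for c in (64, 64, 128, 256)]
pooled = [t.max(dim=-1, keepdim=True)[0] for t in stages]
x_max = torch.cat(pooled, dim=1)
assert x_max.shape == (B, 512, N, 1)        # 64+64+128+256 = 512
```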

================================================
FILE: models/archs/encoders/rbf.py
================================================
#!/usr/bin/env python3

import torch.nn as nn
import torch
import torch.nn.functional as F
import json
import numpy as np

# https://github.com/vsitzmann/siren/blob/4df34baee3f0f9c8f351630992c1fe1f69114b5f/modules.py#L266

class RBFLayer(nn.Module):
    '''Transforms incoming data using a given radial basis function.
        - Input: (1, N, in_features) where N is an arbitrary batch size
        - Output: (1, N, out_features) where N is an arbitrary batch size'''

    def __init__(self, in_features=3, out_features=1024):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.centres = nn.Parameter(torch.Tensor(out_features, in_features))
        self.sigmas = nn.Parameter(torch.Tensor(out_features))
        self.reset_parameters()

        self.freq = nn.Parameter(np.pi * torch.ones((1, self.out_features)))

    def reset_parameters(self):
        nn.init.uniform_(self.centres, -1, 1)
        nn.init.constant_(self.sigmas, 10)

    def forward(self, x):
        #x = x[0, ...]
        size = (x.size(0), self.out_features, self.in_features)
        x = x.unsqueeze(1).expand(size)
        c = self.centres.unsqueeze(0).expand(size)
        distances = (x - c).pow(2).sum(-1) * self.sigmas.unsqueeze(0)
        return self.gaussian(distances).unsqueeze(0)

    def gaussian(self, alpha):
        phi = torch.exp(-1 * alpha.pow(2))
        return phi
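A shape sketch of `RBFLayer.forward` and `gaussian` above, re-derived outside the module with hypothetical sizes: N input points are expanded against the M learned centres, squared distances are scaled by the per-centre sigmas, and the Gaussian is applied:

```python
import torch

# Standalone re-derivation of RBFLayer.forward above (not the module
# itself); N=5 points, in_features=3, out_features=1024 centres.
N, in_features, out_features = 5, 3, 1024
x = torch.randn(N, in_features)
centres = torch.empty(out_features, in_features).uniform_(-1, 1)
sigmas = torch.full((out_features,), 10.0)

size = (N, out_features, in_features)
d = (x.unsqueeze(1).expand(size)
     - centres.unsqueeze(0).expand(size)).pow(2).sum(-1) * sigmas
phi = torch.exp(-d.pow(2)).unsqueeze(0)     # matches gaussian() above
assert phi.shape == (1, N, out_features)
```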


================================================
FILE: models/archs/encoders/sal_pointnet.py
================================================
#!/usr/bin/env python3

import torch.nn as nn
import torch
import torch.nn.functional as F

import json
import sys
sys.path.append("..")
import utils # actual dir is ../utils

# pointnet from SAL paper: https://github.com/matanatz/SAL/blob/master/code/model/network.py#L14
class SalPointNet(nn.Module):
    ''' PointNet-based encoder network. Based on: https://github.com/autonomousvision/occupancy_networks
    Args:
        c_dim (int): dimension of latent code c
        dim (int): input points dimension
        hidden_dim (int): hidden dimension of the network
    '''

    def __init__(self, c_dim=256, in_dim=3, hidden_dim=128):
        super().__init__()
        self.c_dim = c_dim

        self.fc_pos = nn.Linear(in_dim, 2*hidden_dim)
        self.fc_0 = nn.Linear(2*hidden_dim, hidden_dim)
        self.fc_1 = nn.Linear(2*hidden_dim, hidden_dim)
        self.fc_2 = nn.Linear(2*hidden_dim, hidden_dim)
        self.fc_3 = nn.Linear(2*hidden_dim, hidden_dim)
        self.fc_mean = nn.Linear(hidden_dim, c_dim)
        self.fc_std = nn.Linear(hidden_dim, c_dim)
        torch.nn.init.constant_(self.fc_mean.weight,0)
        torch.nn.init.constant_(self.fc_mean.bias, 0)

        torch.nn.init.constant_(self.fc_std.weight, 0)
        torch.nn.init.constant_(self.fc_std.bias, -10)

        self.actvn = nn.ReLU()
        self.pool = self.maxpool

    def forward(self, p):
        net = self.fc_pos(p)
        net = self.fc_0(self.actvn(net))
        pooled = self.pool(net, dim=1, keepdim=True).expand(net.size())
        net = torch.cat([net, pooled], dim=2)

        net = self.fc_1(self.actvn(net))
        pooled = self.pool(net, dim=1, keepdim=True).expand(net.size())
        net = torch.cat([net, pooled], dim=2)

        net = self.fc_2(self.actvn(net))
        pooled = self.pool(net, dim=1, keepdim=True).expand(net.size())
        net = torch.cat([net, pooled], dim=2)

        net = self.fc_3(self.actvn(net))

        net = self.pool(net, dim=1)

        c_mean = self.fc_mean(self.actvn(net))
        c_std = self.fc_std(self.actvn(net))

        return c_mean, c_std

    def maxpool(self, x, dim=-1, keepdim=False):
        out, _ = x.max(dim=dim, keepdim=keepdim)
        return out
SYMBOL INDEX (330 symbols across 48 files)

FILE: dataloader/base.py
  class Dataset (line 14) | class Dataset(torch.utils.data.Dataset):
    method __init__ (line 15) | def __init__(
    method __len__ (line 36) | def __len__(self):
    method __getitem__ (line 39) | def __getitem__(self, idx):
    method sample_pointcloud (line 42) | def sample_pointcloud(self, csvfile, pc_size):
    method labeled_sampling (line 54) | def labeled_sampling(self, f, subsample, pc_size=1024, load_from_path=...
    method get_instance_filenames (line 92) | def get_instance_filenames(self, data_source, split, gt_filename="sdf_...

FILE: dataloader/modulation_loader.py
  class ModulationLoader (line 17) | class ModulationLoader(torch.utils.data.Dataset):
    method __init__ (line 18) | def __init__(self, data_path, pc_path=None, split_file=None, pc_size=N...
    method __len__ (line 45) | def __len__(self):
    method __getitem__ (line 48) | def __getitem__(self, index):
    method load_modulations (line 57) | def load_modulations(self, data_source, pc_source, split, f_name="late...
    method unconditional_load_modulations (line 85) | def unconditional_load_modulations(self, data_source, split, f_name="l...

FILE: dataloader/pc_loader.py
  class PCloader (line 15) | class PCloader(base.Dataset):
    method __init__ (line 17) | def __init__(
    method get_all_files (line 41) | def get_all_files(self):
    method __getitem__ (line 44) | def __getitem__(self, idx):
    method __len__ (line 51) | def __len__(self):
    method sample_pc (line 55) | def sample_pc(self, f, samp=1024):

FILE: dataloader/sdf_loader.py
  class SdfLoader (line 17) | class SdfLoader(base.Dataset):
    method __init__ (line 19) | def __init__(
    method __getitem__ (line 62) | def __getitem__(self, idx):
    method __len__ (line 87) | def __len__(self):

FILE: diff_utils/helpers.py
  function get_split_filenames (line 13) | def get_split_filenames(data_source, split_file, f_name="sdf_data.csv"):
  function sample_pc (line 26) | def sample_pc(f, samp=1024, add_flip_augment=False):
  function perturb_point_cloud (line 45) | def perturb_point_cloud(pc, perturb, pc_size=None, crop_percent=0.25):
  function fps (line 60) | def fps(data, number):
  function crop_pc (line 69) | def crop_pc(xyz, crop, pc_size=None, fixed_points = None, padding_zeros ...
  function visualize_pc (line 138) | def visualize_pc(pc):
  function jitter_pc (line 144) | def jitter_pc(pc, pc_size=None, sigma=0.05, clip=0.1):
  function normalize_pc (line 156) | def normalize_pc(pc):
  function save_model (line 163) | def save_model(iters, model, optimizer, loss, path):
  function load_model (line 170) | def load_model(model, optimizer, path):
  function save_code_to_conf (line 185) | def save_code_to_conf(conf_dir):
  class ScheduledOpt (line 195) | class ScheduledOpt:
    method __init__ (line 200) | def __init__(self, warmup, optimizer):
    method step (line 221) | def step(self):
    method zero_grad (line 231) | def zero_grad(self):
    method rate (line 234) | def rate(self, step = None):
  function exists (line 245) | def exists(x):
  function default (line 248) | def default(val, d):
  function cycle (line 253) | def cycle(dl):
  function has_int_squareroot (line 258) | def has_int_squareroot(num):
  function num_to_groups (line 261) | def num_to_groups(num, divisor):
  function convert_image_to (line 269) | def convert_image_to(img_type, image):
  function normalize_to_neg_one_to_one (line 277) | def normalize_to_neg_one_to_one(img):
  function unnormalize_to_zero_to_one (line 281) | def unnormalize_to_zero_to_one(t):
  function normalize_to_zero_to_one (line 286) | def normalize_to_zero_to_one(f):
  function extract (line 293) | def extract(a, t, x_shape):
  function linear_beta_schedule (line 298) | def linear_beta_schedule(timesteps):
  function cosine_beta_schedule (line 305) | def cosine_beta_schedule(timesteps, s = 0.008):

FILE: diff_utils/model_utils.py
  class LayerNorm (line 16) | class LayerNorm(nn.Module):
    method __init__ (line 17) | def __init__(self, dim, eps = 1e-5, stable = False):
    method forward (line 23) | def forward(self, x):
  class MLP (line 33) | class MLP(nn.Module):
    method __init__ (line 34) | def __init__(
    method forward (line 63) | def forward(self, x):
  class RelPosBias (line 68) | class RelPosBias(nn.Module):
    method __init__ (line 69) | def __init__(
    method _relative_position_bucket (line 81) | def _relative_position_bucket(
    method forward (line 96) | def forward(self, i, j, *, device):
  class SwiGLU (line 106) | class SwiGLU(nn.Module):
    method forward (line 108) | def forward(self, x):
  function FeedForward (line 112) | def FeedForward(
  class SinusoidalPosEmb (line 134) | class SinusoidalPosEmb(nn.Module):
    method __init__ (line 135) | def __init__(self, dim):
    method forward (line 139) | def forward(self, x):
  function exists (line 148) | def exists(x):
  function default (line 151) | def default(val, d):
  class Attention (line 156) | class Attention(nn.Module):
    method __init__ (line 157) | def __init__(
    method forward (line 197) | def forward(self, x, context=None, mask = None, attn_bias = None):

FILE: diff_utils/pointnet/conv_pointnet.py
  class ConvPointnet (line 9) | class ConvPointnet(nn.Module):
    method __init__ (line 26) | def __init__(self, c_dim=512, dim=3, hidden_dim=128, scatter_type='max',
    method forward (line 58) | def forward(self, p, query):
    method normalize_coordinate (line 100) | def normalize_coordinate(self, p, padding=0.1, plane='xz'):
    method coordinate2index (line 126) | def coordinate2index(self, x, reso):
    method pool_local (line 143) | def pool_local(self, xy, index, c):
    method generate_plane_features (line 159) | def generate_plane_features(self, p, c, plane='xz'):
    method sample_plane_feature (line 179) | def sample_plane_feature(self, query, plane_feature, plane):
  function conv3x3 (line 187) | def conv3x3(in_channels, out_channels, stride=1,
  function upconv2x2 (line 198) | def upconv2x2(in_channels, out_channels, mode='transpose'):
  function conv1x1 (line 212) | def conv1x1(in_channels, out_channels, groups=1):
  class DownConv (line 221) | class DownConv(nn.Module):
    method __init__ (line 226) | def __init__(self, in_channels, out_channels, pooling=True):
    method forward (line 239) | def forward(self, x):
  class UpConv (line 248) | class UpConv(nn.Module):
    method __init__ (line 253) | def __init__(self, in_channels, out_channels,
    method forward (line 274) | def forward(self, from_down, from_up):
  class UNet (line 290) | class UNet(nn.Module):
    method __init__ (line 313) | def __init__(self, num_classes, in_channels=3, depth=5,
    method weight_init (line 387) | def weight_init(m):
    method reset_params (line 393) | def reset_params(self):
    method forward (line 398) | def forward(self, x):
  class ResnetBlockFC (line 415) | class ResnetBlockFC(nn.Module):
    method __init__ (line 423) | def __init__(self, size_in, size_out=None, size_h=None):
    method forward (line 447) | def forward(self, x):

FILE: diff_utils/pointnet/dgcnn.py
  function knn (line 5) | def knn(x, k):
  function get_graph_feature (line 13) | def get_graph_feature(x, k=20):
  class DGCNN (line 35) | class DGCNN(nn.Module):
    method __init__ (line 37) | def __init__(
    method forward (line 76) | def forward(self, x):
    method get_global_feature (line 111) | def get_global_feature(self, x):

FILE: diff_utils/pointnet/pointnet_base.py
  class PointNetBase (line 14) | class PointNetBase(nn.Module):
    method __init__ (line 16) | def __init__(self, num_points=2000, K=3):
    method forward (line 55) | def forward(self, x):

FILE: diff_utils/pointnet/pointnet_classifier.py
  class PointNetClassifier (line 14) | class PointNetClassifier(nn.Module):
    method __init__ (line 16) | def __init__(self, num_points=2000, K=3):
    method forward (line 38) | def forward(self, x):

FILE: diff_utils/pointnet/transformer.py
  class Transformer (line 14) | class Transformer(nn.Module):
    method __init__ (line 16) | def __init__(self, num_points=2000, K=3):
    method forward (line 62) | def forward(self, x):

FILE: diff_utils/sdf_utils.py
  function pred_sdf_loss (line 26) | def pred_sdf_loss(x0, idx):
  function functional_sdf_model (line 36) | def functional_sdf_model(modulation, xyz, gt):
  function sdf_sampling (line 51) | def sdf_sampling(f, subsample, pc_size=1024, batch=1):
  function apply_to_sdf (line 91) | def apply_to_sdf(f, x):

FILE: metrics/StructuralLosses/match_cost.py
  class MatchCostFunction (line 6) | class MatchCostFunction(Function):
    method forward (line 10) | def forward(ctx, seta, setb):
    method backward (line 31) | def backward(ctx, grad_output):

FILE: metrics/StructuralLosses/nn_distance.py
  class NNDistanceFunction (line 7) | class NNDistanceFunction(Function):
    method forward (line 11) | def forward(ctx, seta, setb):
    method backward (line 28) | def backward(ctx, grad_dist1, grad_dist2):

FILE: metrics/evaluation_metrics.py
  function distChamfer (line 12) | def distChamfer(a, b):
  function distChamferCUDA (line 30) | def distChamferCUDA(x, y):
  function emd_approx (line 38) | def emd_approx(x, y):
  function emd_approx_cuda (line 61) | def emd_approx_cuda(sample, ref):
  function EMD_CD (line 73) | def EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduce...
  function _pairwise_EMD_CD_ (line 114) | def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size=None, accelerated_...
  function emd_tmd_from_pcs (line 159) | def emd_tmd_from_pcs(gen_pcs):
  function knn (line 173) | def knn(Mxx, Mxy, Myy, k, sqrt=False):
  function lgan_mmd_cov (line 205) | def lgan_mmd_cov(all_dist):
  function compute_mmd (line 221) | def compute_mmd(sample_pcs, ref_pcs, batch_size=None, accelerated_cd=True):
  function compute_cd (line 241) | def compute_cd(sample_pcs, ref_pcs):
  function compute_all_metrics (line 249) | def compute_all_metrics(sample_pcs, ref_pcs, batch_size=None, accelerate...
  function unit_cube_grid_point_cloud (line 286) | def unit_cube_grid_point_cloud(resolution, clip_sphere=False):
  function jsd_between_point_cloud_sets (line 307) | def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):
  function entropy_of_occupancy_grid (line 321) | def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False,...
  function jensen_shannon_divergence (line 363) | def jensen_shannon_divergence(P, Q):
  function _jsdiv (line 385) | def _jsdiv(P, Q):

FILE: metrics/pytorch_structural_losses/StructuralLosses/match_cost.py
  class MatchCostFunction (line 6) | class MatchCostFunction(Function):
    method forward (line 10) | def forward(ctx, seta, setb):
    method backward (line 31) | def backward(ctx, grad_output):

FILE: metrics/pytorch_structural_losses/StructuralLosses/nn_distance.py
  class NNDistanceFunction (line 7) | class NNDistanceFunction(Function):
    method forward (line 11) | def forward(ctx, seta, setb):
    method backward (line 28) | def backward(ctx, grad_dist1, grad_dist2):

FILE: metrics/pytorch_structural_losses/match_cost.py
  class MatchCostFunction (line 6) | class MatchCostFunction(Function):
    method forward (line 10) | def forward(ctx, seta, setb):
    method backward (line 31) | def backward(ctx, grad_output):

FILE: metrics/pytorch_structural_losses/nn_distance.py
  class NNDistanceFunction (line 7) | class NNDistanceFunction(Function):
    method forward (line 11) | def forward(ctx, seta, setb):
    method backward (line 28) | def backward(ctx, grad_dist1, grad_dist2):

FILE: metrics/pytorch_structural_losses/pybind/bind.cpp
  function PYBIND11_MODULE (line 9) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){

FILE: metrics/pytorch_structural_losses/src/structural_loss.cpp
  function ApproxMatch (line 22) | std::vector<at::Tensor> ApproxMatch(at::Tensor set_d, at::Tensor set_q) {
  function MatchCost (line 39) | at::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor matc...
  function MatchCostGrad (line 54) | std::vector<at::Tensor> MatchCostGrad(at::Tensor set_d, at::Tensor set_q...
  function NNDistance (line 80) | std::vector<at::Tensor> NNDistance(at::Tensor set_d, at::Tensor set_q) {
  function NNDistanceGrad (line 101) | std::vector<at::Tensor> NNDistanceGrad(at::Tensor set_d, at::Tensor set_...

FILE: metrics/pytorch_structural_losses/src/utils.hpp
  class Formatter (line 5) | class Formatter {
    method Formatter (line 7) | Formatter() {}
    method Formatter (line 10) | Formatter &operator<<(const Type &value) {
    method str (line 15) | std::string str() const { return stream_.str(); }
    type ConvertToString (line 18) | enum ConvertToString { to_str }

FILE: models/archs/diffusion_arch.py
  class CausalTransformer (line 16) | class CausalTransformer(nn.Module):
    method __init__ (line 17) | def __init__(
    method forward (line 87) | def forward(self, x, time_emb=None, context=None):
  class DiffusionNet (line 123) | class DiffusionNet(nn.Module):
    method __init__ (line 125) | def __init__(
    method forward (line 160) | def forward(

FILE: models/archs/encoders/auto_decoder.py
  class AutoDecoder (line 9) | class AutoDecoder():
    method __init__ (line 12) | def __init__(self, num_scenes, latent_size):
    method build_model (line 16) | def build_model(self):

FILE: models/archs/encoders/conv_pointnet.py
  class ConvPointnet (line 9) | class ConvPointnet(nn.Module):
    method __init__ (line 26) | def __init__(self, c_dim=512, dim=3, hidden_dim=128, scatter_type='max',
    method generate_plane_features (line 56) | def generate_plane_features(self, p, c, plane='xz'):
    method forward (line 75) | def forward(self, p, query):
    method forward_with_plane_features (line 119) | def forward_with_plane_features(self, plane_features, query):
    method forward_with_pc_features (line 135) | def forward_with_pc_features(self, c, p, query):
    method get_point_cloud_features (line 153) | def get_point_cloud_features(self, p):
    method get_plane_features (line 181) | def get_plane_features(self, p):
    method normalize_coordinate (line 195) | def normalize_coordinate(self, p, padding=0.1, plane='xz'):
    method coordinate2index (line 221) | def coordinate2index(self, x, reso):
    method pool_local (line 238) | def pool_local(self, xy, index, c):
    method sample_plane_feature (line 255) | def sample_plane_feature(self, query, plane_feature, plane):
  function conv3x3 (line 263) | def conv3x3(in_channels, out_channels, stride=1,
  function upconv2x2 (line 274) | def upconv2x2(in_channels, out_channels, mode='transpose'):
  function conv1x1 (line 288) | def conv1x1(in_channels, out_channels, groups=1):
  class DownConv (line 297) | class DownConv(nn.Module):
    method __init__ (line 302) | def __init__(self, in_channels, out_channels, pooling=True):
    method forward (line 315) | def forward(self, x):
  class UpConv (line 324) | class UpConv(nn.Module):
    method __init__ (line 329) | def __init__(self, in_channels, out_channels,
    method forward (line 350) | def forward(self, from_down, from_up):
  class UNet (line 366) | class UNet(nn.Module):
    method __init__ (line 389) | def __init__(self, num_classes, in_channels=3, depth=5,
    method weight_init (line 465) | def weight_init(m):
    method reset_params (line 471) | def reset_params(self):
    method forward (line 476) | def forward(self, x):
    method generate (line 497) | def generate(self, x):
  class ResnetBlockFC (line 501) | class ResnetBlockFC(nn.Module):
    method __init__ (line 509) | def __init__(self, size_in, size_out=None, size_h=None):
    method forward (line 533) | def forward(self, x):

FILE: models/archs/encoders/dgcnn.py
  function knn (line 5) | def knn(x, k):
  function get_graph_feature (line 13) | def get_graph_feature(x, k=20):
  class DGCNN (line 35) | class DGCNN(nn.Module):
    method __init__ (line 37) | def __init__(
    method forward (line 65) | def forward(self, x):

FILE: models/archs/encoders/rbf.py
  class RBFLayer (line 11) | class RBFLayer(nn.Module):
    method __init__ (line 16) | def __init__(self, in_features=3, out_features=1024):
    method reset_parameters (line 26) | def reset_parameters(self):
    method forward (line 30) | def forward(self, x):
    method gaussian (line 38) | def gaussian(self, alpha):

FILE: models/archs/encoders/sal_pointnet.py
  class SalPointNet (line 13) | class SalPointNet(nn.Module):
    method __init__ (line 21) | def __init__(self, c_dim=256, in_dim=3, hidden_dim=128):
    method forward (line 41) | def forward(self, p):
    method maxpool (line 64) | def maxpool(self, x, dim=-1, keepdim=False):

FILE: models/archs/encoders/vanilla_pointnet.py
  class PointNet (line 13) | class PointNet(MetaModule):
    method __init__ (line 14) | def __init__(self, latent_size):
    method forward (line 43) | def forward(self, x, params=None):

FILE: models/archs/modulated_sdf.py
  class Layer (line 14) | class Layer(nn.Module):
    method __init__ (line 15) | def __init__(self, dim_in=512, dim_out=512, dim=512, dropout_prob=0.0,...
    method forward (line 36) | def forward(self, x):
  class ModulatedMLP (line 45) | class ModulatedMLP(nn.Module):
    method __init__ (line 46) | def __init__(self, latent_size=512, hidden_dim=512, num_layers=9, late...
    method pe_transform (line 116) | def pe_transform(self, data):
    method forward (line 119) | def forward(self, xyz, latent):

FILE: models/archs/resnet_block.py
  class ResnetBlockFC (line 6) | class ResnetBlockFC(nn.Module):
    method __init__ (line 14) | def __init__(self, size_in, size_out=None, size_h=None):
    method forward (line 38) | def forward(self, x):

FILE: models/archs/sdf_decoder.py
  class SdfDecoder (line 10) | class SdfDecoder(nn.Module):
    method __init__ (line 11) | def __init__(self, latent_size=256, hidden_dim=512,
    method forward (line 65) | def forward(self, x):

FILE: models/archs/unet.py
  function conv3x3 (line 14) | def conv3x3(in_channels, out_channels, stride=1,
  function upconv2x2 (line 25) | def upconv2x2(in_channels, out_channels, mode='transpose'):
  function conv1x1 (line 39) | def conv1x1(in_channels, out_channels, groups=1):
  class DownConv (line 48) | class DownConv(nn.Module):
    method __init__ (line 53) | def __init__(self, in_channels, out_channels, pooling=True):
    method forward (line 66) | def forward(self, x):
  class UpConv (line 75) | class UpConv(nn.Module):
    method __init__ (line 80) | def __init__(self, in_channels, out_channels,
    method forward (line 101) | def forward(self, from_down, from_up):
  class UNet (line 117) | class UNet(nn.Module):
    method __init__ (line 140) | def __init__(self, num_classes, in_channels=3, depth=5,
    method weight_init (line 214) | def weight_init(m):
    method reset_params (line 220) | def reset_params(self):
    method forward (line 225) | def forward(self, x):

FILE: models/autoencoder.py
  class BetaVAE (line 11) | class BetaVAE(nn.Module):
    method __init__ (line 15) | def __init__(self,
    method encode (line 102) | def encode(self, enc_input: Tensor) -> List[Tensor]:
    method decode (line 120) | def decode(self, z: Tensor) -> Tensor:
    method reparameterize (line 131) | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
    method forward (line 143) | def forward(self, data: Tensor, **kwargs) -> Tensor:
    method loss_function (line 149) | def loss_function(self,
    method sample (line 182) | def sample(self,
    method generate (line 199) | def generate(self, x: Tensor, **kwargs) -> Tensor:
    method get_latent (line 208) | def get_latent(self, x):
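`BetaVAE.reparameterize(mu, logvar)` is the standard VAE reparameterization trick: sample z = mu + sigma * eps with eps ~ N(0, I) so the sampling step stays differentiable with respect to mu and logvar. A NumPy sketch of the same math (the repo's version operates on torch tensors):

```python
import numpy as np

def reparameterize(mu, logvar, rng=np.random.default_rng(0)):
    """Draw z ~ N(mu, sigma^2) via z = mu + sigma * eps."""
    std = np.exp(0.5 * logvar)          # sigma = exp(logvar / 2)
    eps = rng.standard_normal(mu.shape)
    return mu + eps * std

mu = np.zeros((2, 4))
logvar = np.full((2, 4), -20.0)         # tiny variance -> z ~= mu
z = reparameterize(mu, logvar)
print(np.allclose(z, mu, atol=1e-3))    # True
```

Parameterizing the log-variance keeps sigma positive without constraints and is numerically stable for very small variances.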

FILE: models/combined_model.py
  class CombinedModel (line 9) | class CombinedModel(pl.LightningModule):
    method __init__ (line 10) | def __init__(self, specs):
    method training_step (line 29) | def training_step(self, x, idx):
    method configure_optimizers (line 39) | def configure_optimizers(self):
    method train_modulation (line 67) | def train_modulation(self, x):
    method train_diffusion (line 101) | def train_diffusion(self, x):
    method train_combined (line 127) | def train_combined(self, x):

FILE: models/diff_np_if_torch_error.py
  class DiffusionModel (line 27) | class DiffusionModel(nn.Module):
    method __init__ (line 28) | def __init__(
    method predict_start_from_noise (line 123) | def predict_start_from_noise(self, x_t, t, noise):
    method predict_noise_from_start (line 129) | def predict_noise_from_start(self, x_t, t, x0):
    method ddim_sample (line 136) | def ddim_sample(self, dim, batch_size, noise=None, clip_denoised = Tru...
    method sample (line 175) | def sample(self, dim, batch_size, noise=None, clip_denoised = True, tr...
    method q_posterior (line 205) | def q_posterior(self, x_start, x_t, t):
    method q_sample (line 216) | def q_sample(self, x_start, t, noise=None):
    method forward (line 227) | def forward(self, x_start, t, ret_pred_x=False, noise = None, cond=None):
    method model_predictions (line 258) | def model_predictions(self, model_input, t):
    method diffusion_model_from_latent (line 281) | def diffusion_model_from_latent(self, x_start, cond=None):
    method generate_from_pc (line 299) | def generate_from_pc(self, pc, load_pc=False, batch=5, save_pc=False, ...
    method generate_unconditional (line 337) | def generate_unconditional(self, num_samples):

FILE: models/diffusion.py
  class DiffusionModel (line 27) | class DiffusionModel(nn.Module):
    method __init__ (line 28) | def __init__(
    method predict_start_from_noise (line 94) | def predict_start_from_noise(self, x_t, t, noise):
    method predict_noise_from_start (line 100) | def predict_noise_from_start(self, x_t, t, x0):
    method ddim_sample (line 107) | def ddim_sample(self, dim, batch_size, noise=None, clip_denoised = Tru...
    method sample (line 146) | def sample(self, dim, batch_size, noise=None, clip_denoised = True, tr...
    method q_posterior (line 176) | def q_posterior(self, x_start, x_t, t):
    method q_sample (line 187) | def q_sample(self, x_start, t, noise=None):
    method forward (line 198) | def forward(self, x_start, t, ret_pred_x=False, noise = None, cond=None):
    method model_predictions (line 229) | def model_predictions(self, model_input, t):
    method diffusion_model_from_latent (line 252) | def diffusion_model_from_latent(self, x_start, cond=None):
    method generate_from_pc (line 270) | def generate_from_pc(self, pc, load_pc=False, batch=5, save_pc=False, ...
    method generate_unconditional (line 308) | def generate_unconditional(self, num_samples):
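`DiffusionModel.q_sample(x_start, t, noise)` applies the closed-form forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A NumPy sketch with a linear beta schedule; the endpoints 1e-4 and 0.02 are the common DDPM defaults, assumed here rather than read from the repo's config:

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)        # linear noise schedule
alphas_cumprod = np.cumprod(1.0 - betas)  # alpha_bar_t, decreasing in t

def q_sample(x_start, t, noise):
    """Diffuse x_0 directly to timestep t (closed-form forward process)."""
    return (np.sqrt(alphas_cumprod[t]) * x_start
            + np.sqrt(1.0 - alphas_cumprod[t]) * noise)

x0 = np.ones(8)
eps = np.zeros(8)
print(q_sample(x0, 0, eps)[0])      # ~1.0: almost no noise at t = 0
print(q_sample(x0, T - 1, eps)[0])  # near 0: signal nearly destroyed
```

Because alpha_bar_t is a cumulative product, any timestep can be sampled in one shot during training instead of iterating the chain.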

FILE: models/sdf_model.py
  class SdfModel (line 21) | class SdfModel(pl.LightningModule):
    method __init__ (line 23) | def __init__(self, specs):
    method configure_optimizers (line 42) | def configure_optimizers(self):
    method training_step (line 48) | def training_step(self, x, idx):
    method forward (line 65) | def forward(self, pc, xyz):
    method forward_with_plane_features (line 70) | def forward_with_plane_features(self, plane_features, xyz):

FILE: test.py
  function test_modulations (line 30) | def test_modulations():
  function test_generation (line 94) | def test_generation():

FILE: train.py
  function train (line 32) | def train():

FILE: utils/chamfer.py
  function compute_trimesh_chamfer (line 5) | def compute_trimesh_chamfer(gt_points, gen_points, offset=0, scale=1):
  function scale_to_unit_sphere (line 31) | def scale_to_unit_sphere(points):
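`compute_trimesh_chamfer` in the repo uses a KD-tree (`scipy.spatial.cKDTree`) for the nearest-neighbor queries; an equivalent brute-force NumPy sketch of the symmetric Chamfer distance, using the mean-of-squared-distances convention (the repo's exact reduction may differ):

```python
import numpy as np

def chamfer_distance(a, b):
    """Symmetric Chamfer distance between point sets a (N, 3) and b (M, 3):
    mean squared nearest-neighbor distance, summed over both directions."""
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)  # (N, M) pairwise
    return d2.min(axis=1).mean() + d2.min(axis=0).mean()

pts = np.random.rand(64, 3)
print(chamfer_distance(pts, pts))            # 0.0 for identical clouds
print(chamfer_distance(pts, pts + 0.1) > 0)  # True
```

The brute-force version is O(N*M) in memory, so the KD-tree approach is preferable for large point clouds.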

FILE: utils/evaluate.py
  function main (line 16) | def main(gt_pc, recon_mesh, out_file, mesh_name, return_value=False, ret...
  function calc_cd (line 53) | def calc_cd(gt_pc, recon_pc):
  function single_eval (line 69) | def single_eval(gt_csv, recon_mesh):

FILE: utils/mesh.py
  function create_mesh (line 15) | def create_mesh(
  function create_cube (line 62) | def create_cube(N):
  function convert_sdf_samples_to_ply (line 89) | def convert_sdf_samples_to_ply(

FILE: utils/reconstruct.py
  function vis_recon (line 23) | def vis_recon(test_dataloader, sdf_model, vae_model, recon_dir, take_mod...
  function filter_threshold (line 137) | def filter_threshold(mesh, gt_pc, threshold): # mesh is path to mesh wit...
  function extract_latents (line 143) | def extract_latents(test_dataloader, sdf_model, vae_model, save_dir):

FILE: utils/renderer.py
  class OnlineObjectRenderer (line 13) | class OnlineObjectRenderer:
    method __init__ (line 14) | def __init__(self, fov=np.pi / 6, caching=True):
    method _init_scene (line 25) | def _init_scene(self, height=480, width=480):
    method add_mesh (line 34) | def add_mesh(self, path, name, rotation=None, translation=None):
    method add_pointcloud (line 58) | def add_pointcloud(self, path, name, colors=None, rotation=None, trans...
    method clear (line 85) | def clear(self):
    method render (line 90) | def render(self):

FILE: utils/tmd.py
  function process_one (line 9) | def process_one(shape_dir):
  function tmd_from_pcs (line 28) | def tmd_from_pcs(gen_pcs):
  function Total_Mutual_Difference (line 40) | def Total_Mutual_Difference(args):
  function main (line 56) | def main():

FILE: utils/uhd.py
  function directed_hausdorff (line 11) | def directed_hausdorff(point_cloud1:torch.Tensor, point_cloud2:torch.Ten...
  function nn_distance (line 38) | def nn_distance(query_points, ref_points):
  function completeness (line 44) | def completeness(query_points, ref_points, thres=0.03):
  function process_one (line 50) | def process_one(shape_dir):
  function uhd_from_pcs (line 86) | def uhd_from_pcs(gen_pcs, partial_pc):
  function func (line 106) | def func(args):
  function main (line 119) | def main():
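`directed_hausdorff(point_cloud1, point_cloud2)` takes, over all points of the first cloud, the maximum of the minimum distance to the second cloud. A NumPy sketch on plain (N, 3) arrays (the repo's torch version works on batched (b, 3, n) tensors):

```python
import numpy as np

def directed_hausdorff(a, b):
    """max over p in a of min over q in b of ||p - q||:
    worst-case coverage of cloud a by cloud b (not symmetric)."""
    d = np.sqrt(((a[:, None, :] - b[None, :, :]) ** 2).sum(-1))  # (N, M)
    return d.min(axis=1).max()

a = np.array([[0.0, 0, 0], [1, 0, 0]])
b = np.array([[0.0, 0, 0]])
print(directed_hausdorff(a, b))  # 1.0: point [1,0,0] is 1 away from b
print(directed_hausdorff(b, a))  # 0.0: every point of b lies in a
```

Unlike Chamfer distance, the directed Hausdorff distance is dominated by the single worst point, which is why the unidirectional variant (UHD) is used to check that a partial input is covered by a completed shape.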

FILE: utils/visualize.py
  function create_point_marker (line 7) | def create_point_marker(center, color):
  function get_color (line 17) | def get_color(labels):
  function vis_pc (line 21) | def vis_pc(obj_pc):
  function main (line 27) | def main(mesh, pc, query):
Condensed preview — 81 files, each showing path, character count, and a content snippet; the full structured content (roughly 321K characters) is not reproduced here.
[
  {
    "path": ".gitignore",
    "chars": 23,
    "preview": ".DS_Store\n__pycache__/\n"
  },
  {
    "path": "LICENSE",
    "chars": 1066,
    "preview": "MIT License\n\nCopyright (c) 2022 Gene Chou\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\n"
  },
  {
    "path": "README.md",
    "chars": 7346,
    "preview": "# Diffusion-SDF: Conditional Generative Modeling of Signed Distance Functions\n\n[**Paper**](https://arxiv.org/abs/2211.13"
  },
  {
    "path": "config/stage1_sdf/specs.json",
    "chars": 620,
    "preview": "{\n    \"Description\" : \"training joint SDF-VAE model for modulating SDFs on the couch dataset\",\n    \"DataSource\" : \"data\""
  },
  {
    "path": "config/stage2_diff_cond/specs.json",
    "chars": 750,
    "preview": "{\n  \"Description\" : \"diffusion training (conditional) on couch dataset\",\n  \"pc_path\" : \"data\",\n  \"total_pc_size\" : 10000"
  },
  {
    "path": "config/stage2_diff_uncond/specs.json",
    "chars": 560,
    "preview": "{\n  \"Description\" : \"diffusion training (unconditional) on couch dataset\",\n  \"TrainSplit\" : \"data/splits/couch_all.json\""
  },
  {
    "path": "config/stage3_cond/specs.json",
    "chars": 1115,
    "preview": "{\n  \"Description\" : \"end-to-end training (conditional) on couch dataset\",\n  \"DataSource\" : \"data\",\n  \"GridSource\" : \"gri"
  },
  {
    "path": "config/stage3_uncond/specs.json",
    "chars": 981,
    "preview": "{\n  \"Description\" : \"end-to-end training (unconditional) on couch dataset\",\n  \"DataSource\" : \"data\",\n  \"GridSource\" : \"g"
  },
  {
    "path": "data/splits/couch_all.json",
    "chars": 17582,
    "preview": "{\n    \"acronym\": {\n        \"Couch\": [\n            \"37cfcafe606611d81246538126da07a8\",\n            \"fd124209f0bc1222f34e5"
  },
  {
    "path": "dataloader/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "dataloader/base.py",
    "chars": 3863,
    "preview": "#!/usr/bin/env python3\n\nimport numpy as np\nimport time \nimport logging\nimport os\nimport random\nimport torch\nimport torch"
  },
  {
    "path": "dataloader/modulation_loader.py",
    "chars": 4548,
    "preview": "#!/usr/bin/env python3\n\nimport time \nimport logging\nimport os\nimport random\nimport torch\nimport torch.utils.data\nfrom di"
  },
  {
    "path": "dataloader/pc_loader.py",
    "chars": 2009,
    "preview": "#!/usr/bin/env python3\n\nimport time \nimport logging\nimport os\nimport random\nimport torch\nimport torch.utils.data\nfrom . "
  },
  {
    "path": "dataloader/sdf_loader.py",
    "chars": 3581,
    "preview": "#!/usr/bin/env python3\n\nimport time \nimport logging\nimport os\nimport random\nimport torch\nimport torch.utils.data\nfrom . "
  },
  {
    "path": "diff_utils/helpers.py",
    "chars": 10029,
    "preview": "import math\nimport torch\nimport torch.nn.functional as F\nimport numpy as np \nimport pandas as pd \nimport random \nfrom in"
  },
  {
    "path": "diff_utils/model_utils.py",
    "chars": 7507,
    "preview": "import math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum \n\nfrom einops import rearrange, re"
  },
  {
    "path": "diff_utils/pointnet/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "diff_utils/pointnet/conv_pointnet.py",
    "chars": 16401,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\n\nfrom torch_scatter import "
  },
  {
    "path": "diff_utils/pointnet/dgcnn.py",
    "chars": 7363,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ndef knn(x, k):\n    inner = -2 * torch.matmul(x.trans"
  },
  {
    "path": "diff_utils/pointnet/pointnet_base.py",
    "chars": 2971,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.autograd as grad\n\nfrom .transformer impo"
  },
  {
    "path": "diff_utils/pointnet/pointnet_classifier.py",
    "chars": 1748,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.autograd as grad\n\nfrom .pointnet_base im"
  },
  {
    "path": "diff_utils/pointnet/transformer.py",
    "chars": 2071,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.autograd as grad\n\n\n\n##------------------"
  },
  {
    "path": "diff_utils/sdf_utils.py",
    "chars": 2520,
    "preview": "import math\nimport torch\nimport json \nimport torch.nn.functional as F\nfrom torch import nn, einsum \nimport os\n\nfrom eino"
  },
  {
    "path": "environment.yml",
    "chars": 4779,
    "preview": "name: diffusionsdf\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - _openmp_mutex=5.1=1_g"
  },
  {
    "path": "metrics/StructuralLosses/__init__.py",
    "chars": 113,
    "preview": "#import torch\n\n#from MakePytorchBackend import AddGPU, Foo, ApproxMatch\n\n#from Add import add_gpu, approx_match\n\n"
  },
  {
    "path": "metrics/StructuralLosses/match_cost.py",
    "chars": 1809,
    "preview": "import torch\nfrom torch.autograd import Function\nfrom metrics.StructuralLosses.StructuralLossesBackend import ApproxMatc"
  },
  {
    "path": "metrics/StructuralLosses/nn_distance.py",
    "chars": 1587,
    "preview": "import torch\nfrom torch.autograd import Function\n# from extensions.StructuralLosses.StructuralLossesBackend import NNDis"
  },
  {
    "path": "metrics/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "metrics/evaluation_metrics.py",
    "chars": 14488,
    "preview": "import torch\nimport numpy as np\nimport warnings\nfrom scipy.stats import entropy\nfrom sklearn.neighbors import NearestNei"
  },
  {
    "path": "metrics/pytorch_structural_losses/.gitignore",
    "chars": 34,
    "preview": "PyTorchStructuralLosses.egg-info/\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/Makefile",
    "chars": 3571,
    "preview": "###############################################################################\n# Uncomment for debugging\n# DEBUG := 1\n#"
  },
  {
    "path": "metrics/pytorch_structural_losses/StructuralLosses/__init__.py",
    "chars": 113,
    "preview": "#import torch\n\n#from MakePytorchBackend import AddGPU, Foo, ApproxMatch\n\n#from Add import add_gpu, approx_match\n\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/StructuralLosses/match_cost.py",
    "chars": 1809,
    "preview": "import torch\nfrom torch.autograd import Function\nfrom metrics.StructuralLosses.StructuralLossesBackend import ApproxMatc"
  },
  {
    "path": "metrics/pytorch_structural_losses/StructuralLosses/nn_distance.py",
    "chars": 1587,
    "preview": "import torch\nfrom torch.autograd import Function\n# from extensions.StructuralLosses.StructuralLossesBackend import NNDis"
  },
  {
    "path": "metrics/pytorch_structural_losses/__init__.py",
    "chars": 113,
    "preview": "#import torch\n\n#from MakePytorchBackend import AddGPU, Foo, ApproxMatch\n\n#from Add import add_gpu, approx_match\n\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/match_cost.py",
    "chars": 1809,
    "preview": "import torch\nfrom torch.autograd import Function\nfrom metrics.StructuralLosses.StructuralLossesBackend import ApproxMatc"
  },
  {
    "path": "metrics/pytorch_structural_losses/nn_distance.py",
    "chars": 1580,
    "preview": "import torch\nfrom torch.autograd import Function\n# from extensions.StructuralLosses.StructuralLossesBackend import NNDis"
  },
  {
    "path": "metrics/pytorch_structural_losses/pybind/bind.cpp",
    "chars": 343,
    "preview": "#include <string>\n\n#include <torch/extension.h>\n\n#include \"pybind/extern.hpp\"\n\nnamespace py = pybind11;\n\nPYBIND11_MODULE"
  },
  {
    "path": "metrics/pytorch_structural_losses/pybind/extern.hpp",
    "chars": 469,
    "preview": "std::vector<at::Tensor> ApproxMatch(at::Tensor in_a, at::Tensor in_b);\nat::Tensor MatchCost(at::Tensor set_d, at::Tensor"
  },
  {
    "path": "metrics/pytorch_structural_losses/setup.py",
    "chars": 928,
    "preview": "from setuptools import setup\nfrom torch.utils.cpp_extension import CUDAExtension, BuildExtension\n\n# Python interface\nset"
  },
  {
    "path": "metrics/pytorch_structural_losses/src/approxmatch.cu",
    "chars": 10328,
    "preview": "#include \"utils.hpp\"\n\n__global__ void approxmatchkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * "
  },
  {
    "path": "metrics/pytorch_structural_losses/src/approxmatch.cuh",
    "chars": 527,
    "preview": "/*\ntemplate <typename Dtype>\nvoid AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,\n                  cudaStre"
  },
  {
    "path": "metrics/pytorch_structural_losses/src/nndistance.cu",
    "chars": 4522,
    "preview": "\n__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){"
  },
  {
    "path": "metrics/pytorch_structural_losses/src/nndistance.cuh",
    "chars": 375,
    "preview": "void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int"
  },
  {
    "path": "metrics/pytorch_structural_losses/src/structural_loss.cpp",
    "chars": 6412,
    "preview": "#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n\n#include \"src/approxmatch.cuh\"\n#include \"src/nndistance"
  },
  {
    "path": "metrics/pytorch_structural_losses/src/utils.hpp",
    "chars": 560,
    "preview": "#include <iostream>\n#include <sstream>\n#include <string>\n\nclass Formatter {\npublic:\n  Formatter() {}\n  ~Formatter() {}\n\n"
  },
  {
    "path": "models/__init__.py",
    "chars": 340,
    "preview": "#!/usr/bin/python3\n\nfrom models.sdf_model import SdfModel\nfrom models.autoencoder import BetaVAE\n\nfrom models.archs.enco"
  },
  {
    "path": "models/archs/__init__.py",
    "chars": 20,
    "preview": "#!/usr/bin/python3\n\n"
  },
  {
    "path": "models/archs/diffusion_arch.py",
    "chars": 10010,
    "preview": "import math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum \n\nfrom einops import rearrange, re"
  },
  {
    "path": "models/archs/encoders/__init__.py",
    "chars": 20,
    "preview": "#!/usr/bin/python3\n\n"
  },
  {
    "path": "models/archs/encoders/auto_decoder.py",
    "chars": 664,
    "preview": "#!/usr/bin/env python3\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport json\nimport math\n\nclas"
  },
  {
    "path": "models/archs/encoders/conv_pointnet.py",
    "chars": 20375,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\n\nfrom torch_scatter import "
  },
  {
    "path": "models/archs/encoders/dgcnn.py",
    "chars": 4339,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ndef knn(x, k):\n    inner = -2 * torch.matmul(x.trans"
  },
  {
    "path": "models/archs/encoders/rbf.py",
    "chars": 1428,
    "preview": "#!/usr/bin/env python3\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport json\nimport numpy as n"
  },
  {
    "path": "models/archs/encoders/sal_pointnet.py",
    "chars": 2211,
    "preview": "#!/usr/bin/env python3\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\n\nimport json\nimport sys\nsys.p"
  },
  {
    "path": "models/archs/encoders/vanilla_pointnet.py",
    "chars": 1790,
    "preview": "#!/usr/bin/env python3\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nfrom torchmeta.modules import"
  },
  {
    "path": "models/archs/modulated_sdf.py",
    "chars": 5430,
    "preview": "#!/usr/bin/env python3\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport json\nimport sys\nimport"
  },
  {
    "path": "models/archs/resnet_block.py",
    "chars": 1233,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n# Resnet Blocks\nclass ResnetBlockFC(nn.Module):\n    "
  },
  {
    "path": "models/archs/sdf_decoder.py",
    "chars": 2590,
    "preview": "#!/usr/bin/env python3\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport torch.nn.init as init\n"
  },
  {
    "path": "models/archs/unet.py",
    "chars": 8695,
    "preview": "'''\nCodes are from:\nhttps://github.com/jaxony/unet-pytorch/blob/master/model.py\n'''\n\nimport torch\nimport torch.nn as nn\n"
  },
  {
    "path": "models/autoencoder.py",
    "chars": 7579,
    "preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom einops import rearrange, reduce\n\nfrom typin"
  },
  {
    "path": "models/combined_model.py",
    "chars": 8535,
    "preview": "import torch\nimport torch.utils.data \nfrom torch.nn import functional as F\nimport pytorch_lightning as pl\n\n# add paths i"
  },
  {
    "path": "models/diff_np_if_torch_error.py",
    "chars": 14446,
    "preview": "import math\nimport copy\nimport torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom inspect import is"
  },
  {
    "path": "models/diffusion.py",
    "chars": 12756,
    "preview": "import math\nimport copy\nimport torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom inspect import is"
  },
  {
    "path": "models/sdf_model.py",
    "chars": 2354,
    "preview": "#!/usr/bin/env python3\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport pytorch_lightning as p"
  },
  {
    "path": "test.py",
    "chars": 10526,
    "preview": "#!/usr/bin/env python3\n\nimport torch\nimport torch.utils.data \nfrom torch.nn import functional as F\nimport pytorch_lightn"
  },
  {
    "path": "train.py",
    "chars": 4458,
    "preview": "#!/usr/bin/env python3\n\nimport torch\nimport torch.utils.data \nfrom torch.nn import functional as F\nimport pytorch_lightn"
  },
  {
    "path": "utils/__init__.py",
    "chars": 24,
    "preview": "#!/usr/bin/env python3\n\n"
  },
  {
    "path": "utils/chamfer.py",
    "chars": 1484,
    "preview": "import numpy as np\nfrom scipy.spatial import cKDTree as KDTree\n\n\ndef compute_trimesh_chamfer(gt_points, gen_points, offs"
  },
  {
    "path": "utils/evaluate.py",
    "chars": 3486,
    "preview": "#!/usr/bin/env python3\n\nimport argparse\nimport logging\nimport json\nimport numpy as np\nimport pandas as pd \nimport os, sy"
  },
  {
    "path": "utils/mesh.py",
    "chars": 4722,
    "preview": "#!/usr/bin/env python3\n\nimport logging\nimport math\nimport numpy as np\nimport plyfile\nimport skimage.measure\nimport time\n"
  },
  {
    "path": "utils/mmd.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "utils/reconstruct.py",
    "chars": 9104,
    "preview": "#!/usr/bin/env python3\n\nimport torch\nimport torch.utils.data \nfrom torch.nn import functional as F\nimport pytorch_lightn"
  },
  {
    "path": "utils/renderer.py",
    "chars": 3990,
    "preview": "import numpy as np\nimport trimesh\nimport trimesh.transformations as tra\nimport pyrender\nimport matplotlib.pyplot as plt "
  },
  {
    "path": "utils/tmd.py",
    "chars": 2246,
    "preview": "import argparse\nimport os\nimport numpy as np\nimport trimesh\nfrom utils.chamfer import compute_trimesh_chamfer\nimport glo"
  },
  {
    "path": "utils/uhd.py",
    "chars": 4359,
    "preview": "import argparse\nimport os\nimport torch\nimport numpy as np\nfrom scipy.spatial import cKDTree as KDTree\nimport trimesh\nimp"
  },
  {
    "path": "utils/visualize.py",
    "chars": 1842,
    "preview": "import numpy as np\nimport trimesh\nimport sys\nimport h5py\nimport open3d as o3d\n\ndef create_point_marker(center, color):\n\n"
  }
]

// ... and 4 more files not included in this preview

About this extraction

This document contains the source code of the princeton-computational-imaging/Diffusion-SDF GitHub repository, extracted and formatted as plain text: 81 files (38.4 MB), approximately 87.3k tokens, and a symbol index of 330 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract, a GitHub-repo-to-text converter built by Nikandr Surkov.