[
  {
    "path": ".gitignore",
    "content": ".DS_Store\n__pycache__/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 Gene Chou\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Diffusion-SDF: Conditional Generative Modeling of Signed Distance Functions\n\n[**Paper**](https://arxiv.org/abs/2211.13757) | [**Supplement**](https://light.princeton.edu/wp-content/uploads/2023/03/diffusionsdf_supp.pdf) | [**Project Page**](https://light.princeton.edu/publication/diffusion-sdf/) <br>\n\nThis repository contains the official implementation of <br> \n**[ICCV 2023] Diffusion-SDF: Conditional Generative Modeling of Signed Distance Functions** <br>\n[Gene Chou](https://genechou.com), [Yuval Bahat](https://sites.google.com/view/yuval-bahat/home), [Felix Heide](https://www.cs.princeton.edu/~fheide/) <br>\n\n\nIf you find our code or paper useful, please consider citing\n```bibtex\n@inproceedings{chou2022diffusionsdf,\ntitle={Diffusion-SDF: Conditional Generative Modeling of Signed Distance Functions},\nauthor={Gene Chou and Yuval Bahat and Felix Heide},\njournal={The IEEE International Conference on Computer Vision (ICCV)},\nyear={2023}\n}\n```\n\n\n```cpp\nroot directory\n  ├── config  \n  │   └── // folders for checkpoints and training configs\n  ├── data  \n  │   └── // folders for data (in csv format) and train test splits (json)\n  ├── models  \n  │   ├── // models and lightning modules; main model is 'combined_model.py'\n  │   └── archs\n  │       └── // architectures such as PointNets, SDF MLPs, diffusion network..etc\n  ├── dataloader  \n  │   └── // dataloaders for different stages of training and generation\n  ├── utils  \n  │   └── // reconstruction and evaluation\n  ├── metrics  \n  │   └── // reconstruction and evaluation\n  ├── diff_utils  \n  │   └── // helper functions for diffusion\n  ├── environment.yml  // package requirements\n  ├── train.py  // script for training, specify the stage of training in the config files\n  ├── test.py  // script for testing, specify the stage of testing in the config files\n  └── tensorboard_logs  // created when running any training script\n  \n```\n\n## Installation\nWe recommend creating an [anaconda](https://www.anaconda.com/) environment using our provided `environment.yml`:\n\n```\nconda env create -f environment.yml\nconda activate diffusionsdf\n```\n\n## Dataset\nFor training, we preprocess all meshes and store query coordinates and signed distance values in csv files. Each csv file corresponds to one object, and each line represents a coordinate followed by its signed distance value. See `data/acronym` for examples. Modify the dataloader according to your file format. <br>\n\nWhen sampling query points, make sure to also **sample uniformly within the 3D grid space** (i.e. from (-1,-1,-1) to (1,1,1)) rather than only sampling near the surface to avoid artifacts. For each training batch, we take 70% of query points sampled near the object surface and 30% sampled uniformly in the grid. `grid_source` in our dataloader and config file refers to the latter. <br>\n\n## Training\nAs described in our [paper](https://arxiv.org/abs/2211.13757), there are three stages of training. All corresponding config files can be found in the `config` folders. Logs are created in a `tensorboard_logs` folder in the root directory. We recommend tuning the `\"kld_weight\"` when training the joint SDF-VAE model as it enforces the continuity of the latent space. A higher value (e.g. 0.1) will result in better interpolation and generalization but sometimes more artifacts. A lower value (e.g. 0.00001) will result in worse interpolation but higher quality of generations. <br>\n\n1. 
## Training\nAs described in our [paper](https://arxiv.org/abs/2211.13757), there are three stages of training. All corresponding config files can be found in the `config` folders. Logs are created in a `tensorboard_logs` folder in the root directory. We recommend tuning the `\"kld_weight\"` when training the joint SDF-VAE model as it enforces the continuity of the latent space. A higher value (e.g. 0.1) will result in better interpolation and generalization but sometimes more artifacts. A lower value (e.g. 0.00001) will result in worse interpolation but higher-quality generations. <br>\n\n1. Training SDF modulations\n\n```\npython train.py -e config/stage1_sdf/ -b 32 -w 8    # -b for batch size, -w for workers, -r to resume training\n```\nTraining notes: For Acronym / ShapeNet datasets, the loss should go down to $6 \\sim 8 \\times 10^{-4}$. Run testing to visualize whether the quality of reconstructed shapes is sufficient. The quality of reconstructions will carry over to the quality of generations. Note that the dimension of the VAE latent vectors will be 3 times `\"latent_dim\"` in `\"SdfModelSpecs\"` listed in the config file.\n\n2. Training the diffusion model using the modulations extracted from the first stage \n\n```\n# extract the modulations / latent vectors, which will be saved in a \"modulations\" folder in the config directory\n# the folder needs to correspond to \"data_path\" in the diffusion config files\n\npython test.py -e config/stage1_sdf/ -r last\n\n# unconditional\npython train.py -e config/stage2_diff_uncond/ -b 32 -w 8 \n\n# conditional\npython train.py -e config/stage2_diff_cond/ -b 32 -w 8 \n```\nTraining notes: When extracting modulations, we recommend filtering based on the chamfer distance. See `test_modulations()` in `test.py` for details. Some notes on the conditional config file: `\"perturb_pc\":\"partial\"`, `\"crop_percent\":0.5`, and `\"sample_pc_size\":128` refer to cropping 50% of a 128-point point cloud to use as the condition. `dim` in `diffusion_model_specs` needs to be the dimension of the latent vector, which is 3 times `\"latent_dim\"` in `\"SdfModelSpecs\"`. <br>\n\n\n3. End-to-end training using the saved models from above \n\n```\n# unconditional\npython train.py -e config/stage3_uncond/ -b 32 -w 8 -r finetune     # training from the saved models of first two stages\npython train.py -e config/stage3_uncond/ -b 32 -w 8 -r last     # resuming training if third stage has been trained \n\n# conditional\npython train.py -e config/stage3_cond/ -b 32 -w 8 -r finetune    # training from the saved models of first two stages\npython train.py -e config/stage3_cond/ -b 32 -w 8 -r last     # resuming training if third stage has been trained \n```\nTraining notes: The config file needs to contain the saved checkpoints for the previous two stages of training. The sdf loss (not the generated sdf loss) should approach $6 \\sim 8 \\times 10^{-4}$.\n\n## Testing\n1. Testing SDF reconstructions and saving modulations\n\nAfter the first stage of training, visualize / test reconstructions and save modulations:\n```\n# extract the modulations / latent vectors, which will be saved in a \"modulations\" folder in the config directory\n# the folder needs to correspond to \"data_path\" in the diffusion config files\npython test.py -e config/stage1_sdf/ -r last\n```\nA `recon` folder in the config directory will contain the `.ply` reconstructions and a `cd.csv` file that logs Chamfer Distance (CD). A `modulations` folder will contain `latent.txt` files for each SDF. The `modulations` folder will be the data path for the second stage of training.\n\n2. Generations \n\nMeshes can be generated after the second or third stage of training.\n```\npython test.py -e config/stage3_uncond/ -r finetune  # generation after second stage \npython test.py -e config/stage3_uncond/ -r last      # after third stage \n```\nA `recon` folder in the config directory will contain the `.ply` reconstructions. The `max_batch` argument in `test.py` is used when running marching cubes; set it to the largest value your GPU memory can hold.\n\n
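As a rough illustration of what this chunked evaluation looks like (a sketch, not our exact reconstruction routine; `sdf_model`, `resolution`, and the device handling here are assumptions):\n\n```python\nimport torch\nfrom skimage import measure  # scikit-image marching cubes\n\n@torch.no_grad()\ndef reconstruct(sdf_model, resolution=128, max_batch=64**3, device=\"cuda\"):\n    # dense grid of query coordinates covering the (-1,1)^3 cube\n    lin = torch.linspace(-1, 1, resolution)\n    xyz = torch.stack(torch.meshgrid(lin, lin, lin, indexing=\"ij\"), dim=-1).reshape(-1, 3)\n\n    # evaluate the SDF network max_batch points at a time to bound GPU memory\n    sdf = torch.cat([sdf_model(chunk.to(device)).cpu() for chunk in torch.split(xyz, max_batch)])\n\n    # extract the zero level set and map vertices back into (-1,1)^3\n    volume = sdf.reshape(resolution, resolution, resolution).numpy()\n    verts, faces, _, _ = measure.marching_cubes(volume, level=0.0, spacing=(2 / (resolution - 1),) * 3)\n    return verts - 1.0, faces\n```\n\n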
## References\nWe adapt code from <br>\nGenSDF https://github.com/princeton-computational-imaging/gensdf <br>\nDALLE2-pytorch https://github.com/lucidrains/DALLE2-pytorch <br>\nConvolutional Occupancy Networks https://github.com/autonomousvision/convolutional_occupancy_networks (for PointNet encoder) <br>\nMultimodal Shape Completion via cGANs https://github.com/ChrisWu1997/Multimodal-Shape-Completion (for conditional metrics) <br>\nPointFlow https://github.com/stevenygd/PointFlow (for unconditional metrics)\n"
  },
  {
    "path": "config/stage1_sdf/specs.json",
    "content": "{\n    \"Description\" : \"training joint SDF-VAE model for modulating SDFs on the couch dataset\",\n    \"DataSource\" : \"data\",\n    \"GridSource\" : \"data/grid_data\",\n    \"TrainSplit\" : \"data/splits/couch_all.json\",\n    \"TestSplit\" : \"data/splits/couch_all.json\",\n    \n    \"training_task\": \"modulation\",\n  \n    \"SdfModelSpecs\" : {\n      \"hidden_dim\" : 512,\n      \"latent_dim\" : 256,\n      \"pn_hidden_dim\" : 128,\n      \"num_layers\" : 9\n    },\n\n    \"SampPerMesh\" : 16000,\n    \"PCsize\" : 1024,\n  \n    \"num_epochs\" : 100001,\n    \"log_freq\" : 5000,\n    \"kld_weight\" : 1e-5,\n    \"latent_std\" : 0.25,\n  \n    \"sdf_lr\" : 1e-4\n}\n  \n  \n  \n"
  },
  {
    "path": "config/stage2_diff_cond/specs.json",
    "content": "{\n  \"Description\" : \"diffusion training (conditional) on couch dataset\",\n  \"pc_path\" : \"data\",\n  \"total_pc_size\" : 10000,\n  \"TrainSplit\" : \"data/splits/couch_all.json\",\n  \"TestSplit\" : \"data/splits/couch_all.json\",\n  \"data_path\" : \"config/stage1_sdf/modulations\",\n\n  \"training_task\": \"diffusion\",\n\n  \n  \"num_epochs\" : 50001,\n  \"log_freq\" : 5000,\n\n  \"diff_lr\" : 1e-5,\n\n  \"diffusion_specs\" : {\n    \"timesteps\" : 1000,\n    \"objective\" : \"pred_x0\",\n    \"loss_type\" : \"l2\",\n    \"perturb_pc\" : \"partial\",\n    \"crop_percent\": 0.5,\n    \"sample_pc_size\" : 128\n  },\n  \"diffusion_model_specs\": {\n    \"dim\" : 768,\n    \"depth\" : 4,\n    \"ff_dropout\" : 0.3,\n    \"cond\" : true,\n    \"cross_attn\" : true,\n    \"cond_dropout\":true,\n    \"point_feature_dim\" : 128\n  }\n}\n\n\n"
  },
  {
    "path": "config/stage2_diff_uncond/specs.json",
    "content": "{\n  \"Description\" : \"diffusion training (unconditional) on couch dataset\",\n  \"TrainSplit\" : \"data/splits/couch_all.json\",\n  \"TestSplit\" : \"data/splits/couch_all.json\",\n  \"data_path\" : \"config/stage1_sdf/modulations\",\n\n  \"training_task\": \"diffusion\",\n\n  \"num_epochs\" : 50001,\n  \"log_freq\" : 5000,\n\n  \"diff_lr\" : 1e-5,\n\n  \"diffusion_specs\" : {\n    \"timesteps\" : 1000,\n    \"objective\" : \"pred_x0\",\n    \"loss_type\" : \"l2\"\n  },\n  \"diffusion_model_specs\": {\n    \"dim\" : 768,\n    \"dim_in_out\" : 768,\n    \"depth\" : 4,\n    \"ff_dropout\" : 0.3,\n    \"cond\" : false\n  }\n}\n\n"
  },
  {
    "path": "config/stage3_cond/specs.json",
    "content": "{\n  \"Description\" : \"end-to-end training (conditional) on couch dataset\",\n  \"DataSource\" : \"data\",\n  \"GridSource\" : \"grid_data\",\n  \"TrainSplit\" : \"data/splits/couch_all.json\",\n  \"TestSplit\" : \"data/splits/couch_all.json\",\n  \"modulation_path\" : \"config/stage1_sdf/modulations\",\n\n  \"modulation_ckpt_path\" : \"config/stage1_sdf/last.ckpt\",\n  \"diffusion_ckpt_path\" : \"config/stage2_cond/last.ckpt\",\n  \n  \"training_task\": \"combined\",\n\n  \"num_epochs\" : 100001,\n  \"log_freq\" : 5000,\n\n  \"kld_weight\" : 1e-5,\n  \"latent_std\" : 0.25,\n  \n  \"sdf_lr\" : 1e-4,\n  \"diff_lr\" : 1e-5,\n\n  \"SdfModelSpecs\" : {\n    \"hidden_dim\" : 512,\n    \"latent_dim\" : 256,\n    \"pn_hidden_dim\" : 128,\n    \"num_layers\" : 9\n  },\n  \"SampPerMesh\" : 16000,\n  \"PCsize\" : 1024,\n\n  \"diffusion_specs\" : {\n    \"timesteps\" : 1000,\n    \"objective\" : \"pred_x0\",\n    \"loss_type\" : \"l2\",\n    \"perturb_pc\" : \"partial\",\n    \"crop_percent\": 0.5,\n    \"sample_pc_size\" : 128\n  },\n  \"diffusion_model_specs\": {\n    \"dim\" : 768,\n    \"depth\" : 4,\n    \"ff_dropout\" : 0.3,\n    \"cond\" : true,\n    \"cross_attn\" : true,\n    \"cond_dropout\":true,\n    \"point_feature_dim\" : 128\n  }\n}\n\n\n"
  },
  {
    "path": "config/stage3_uncond/specs.json",
    "content": "{\n  \"Description\" : \"end-to-end training (unconditional) on couch dataset\",\n  \"DataSource\" : \"data\",\n  \"GridSource\" : \"grid_data\",\n  \"TrainSplit\" : \"data/splits/couch_all.json\",\n  \"TestSplit\" : \"data/splits/couch_all.json\",\n  \"modulation_path\" : \"config/stage1_sdf/modulations\",\n\n  \"modulation_ckpt_path\" : \"config/stage1_sdf/last.ckpt\",\n  \"diffusion_ckpt_path\" : \"config/stage2_uncond/last.ckpt\",\n  \n  \"training_task\": \"combined\",\n\n  \"num_epochs\" : 100001,\n  \"log_freq\" : 5000,\n\n  \"kld_weight\" : 1e-5,\n  \"latent_std\" : 0.25,\n  \n  \"sdf_lr\" : 1e-4,\n  \"diff_lr\" : 1e-5,\n\n  \"SdfModelSpecs\" : {\n    \"hidden_dim\" : 512,\n    \"latent_dim\" : 256,\n    \"pn_hidden_dim\" : 128,\n    \"num_layers\" : 9\n  },\n\n  \"SampPerMesh\" : 16000,\n  \"PCsize\" : 1024,\n\n  \"diffusion_specs\" : {\n    \"timesteps\" : 1000,\n    \"objective\" : \"pred_x0\",\n    \"loss_type\" : \"l2\"\n  },\n  \"diffusion_model_specs\": {\n    \"dim\" : 768,\n    \"dim_in_out\" : 768,\n    \"depth\" : 4,\n    \"ff_dropout\" : 0.3,\n    \"cond\" : false\n  }\n}\n\n\n"
  },
  {
    "path": "data/splits/couch_all.json",
    "content": "{\n    \"acronym\": {\n        \"Couch\": [\n            \"37cfcafe606611d81246538126da07a8\",\n            \"fd124209f0bc1222f34e56d1d9ad8c65\",\n            \"9bbf6977b3f0f9cb11e76965808086c8\",\n            \"b6ac23d327248d72627bb9f102840372\",\n            \"133f2d360f9b8691f5ee22e800bb9145\",\n            \"fa1e1a91e66faf411de55fee5ac2c5c2\",\n            \"a10a157254f74b0835836c728d324152\",\n            \"9152b02103bdc98633b8015b1af14a5f\",\n            \"9cddb828b936db93c341afa383659322\",\n            \"69c1e004f9e42b71e7e684d25d4dcaf0\",\n            \"694af585378fef692d0fc9d5a4650a3b\",\n            \"9156988a1a8645e727eb00c151c6f711\",\n            \"e549ab0c75d7fd08cbf65787367134d6\",\n            \"354c37c168778a0bd4830313df3656b\",\n            \"c32dc92e660bf12b7b6fd5468f603b31\",\n            \"a613610e1b9eda601e20309da4bdcbd0\",\n            \"8701046c07327e3063f9008a349ae40b\",\n            \"95277cbbe4e4a566c5ee765adc4c8ef3\",\n            \"c29c56fe616c0986e7e684d25d4dcaf0\",\n            \"252be483777007c22e7955415f58545\",\n            \"8030e04cf9c19768623a7f06c75e0539\",\n            \"21f76612b56d67edf54efb4962ed3879\",\n            \"cfa489d103b4f937411053a770408f0d\",\n            \"83fd038152fca4ba1c5fa3821fb8a849\",\n            \"a38ab75ff03f4befe3f7a74e12a274ef\",\n            \"ac376ac6f62d2bd535836c728d324152\",\n            \"af796bbeb2990931a1ce49849c98d31c\",\n            \"3f90f35a391d5e273b04279a45e22431\",\n            \"cfb40e7b9990da99c2f927df125f5ce4\",\n            \"26205c6a0a25d5f884099151cc96de84\",\n            \"199da405357516a37540a07f0d62945\",\n            \"f261a5d15d8dd1cee16219238f4e984c\",\n            \"4c5d2b857e9b60a884b236d07141c1a6\",\n            \"9b818d54d601b026b161f36d4e309050\",\n            \"40060ca118b6dd904d3df84a5eca1a73\",\n            \"a17f6d35f289433b5112295517941cf7\",\n            \"6475924415a3534111e85af13fa8bb9b\",\n            \"87f103e24f91af8d4343db7d677fae7b\",\n            \"656925921cfaf668639c9ca4528de8f2\",\n            \"b38978ed5bb13266f2b67ae827b02632\",\n            \"ce0a19a131121305d2bc6fb367627d3e\",\n            \"f645a78ce2bd743ed0c05eb40b42c942\",\n            \"2b8d1c67d17d3911d9cff7df5347abca\",\n            \"4ef1c8ad9c72263cdc6d207f01b16dd0\",\n            \"537438b5f127b55e851f4ba6aaedaaa8\",\n            \"2b158fc055b09892e3f7a74e12a274ef\",\n            \"abb7fb6cc7c5b2632829384b8c505bb2\",\n            \"e68758f3775f700df1783a44a88d6274\",\n            \"4dd15f5b09c1494e8058cf23f6382c1\",\n            \"6c4eece6ed9f26f1a1741c84cc821e55\",\n            \"ae7e1fe11d9243bc5c7a30510dbe4e9f\",\n            \"63d0dd6566dfec32ea65c47b660136e7\",\n            \"7c6491c7fe65449cd9bdef3310c0e5a6\",\n            \"de14f6786ddc7e11bd83c80a89cb8f94\",\n            \"a1b0f606bee9dd8e2e7ed79f4f48ff79\",\n            \"8efa91e2f3e2eaf7bdc82a7932cd806\",\n            \"fa1e84f906d8b61973b1be0dc36d8eee\",\n            \"3f6cbfc955e3a82e2f3bfb1446a70fbe\",\n            \"a627a11941892ada3707a503164f0d91\",\n            \"436a96f58ef9a6fdb039d8689a74349\",\n            \"c644689138389daa749ddac355b8e63d\",\n            \"ee5f19266a3ed535a6491c91cd8d7770\",\n            \"62852049a9f049cca67a896eea47bc81\",\n            \"2efd479c13f4cf72af1a85b857ec9afc\",\n            \"32b6c232154f2f35550644042dd119da\",\n            \"5fffafa504c7ee2a8c4f202fffc87396\",\n            \"b714755fb21264c73f89a402a2f2d0aa\",\n            \"7e68c12da163627dd2856b4dc0e7a482\",\n            
\"ce273dabcdd9943b663191fd557d3a61\",\n            \"a207c52399b289be9657063f91d78d19\",\n            \"f0a02b63e1c84a3fbf7df791578d3fb5\",\n            \"34156abf2ce71a60e7e684d25d4dcaf0\",\n            \"80982ab523bd3ad35c7200259c323750\",\n            \"82460a80f6e9cd8bb0851ce87f32a267\",\n            \"823e0ffaf2b47fb27f45370489ca3156\",\n            \"e6d92067a01231cb3f7e27638e63d848\",\n            \"35b4c449b5ae845013aebd62ca4dcb58\",\n            \"8a1a39223639f16e833c6c72c4b62a4d\",\n            \"f87682301c61288c25a12159dbc477b5\",\n            \"76bef187c092b6335ff61a3a2a0e2484\",\n            \"9473a8f3e2182d90d810b14a81e12eca\",\n            \"ec9d8a450d5d30935951e132a7bb8604\",\n            \"e94c453523b1dabda4669f677ccd56a9\",\n            \"d8e1f6781276eb1af34e56d1d9ad8c65\",\n            \"1fd45c57ab27cb6cea65c47b660136e7\",\n            \"b55d24ed786a8279ad2d3ca7f660ddd\",\n            \"30497cce015f143c1191e01810565e0c\",\n            \"9b73921863beef532d1fcd5297483645\",\n            \"e49c0df0a42bdbecc4b4c7225ff8487e\",\n            \"bac90cda9aeb6f5d348e0a0702972359\",\n            \"264d40f99914b97e577df49fb73cc67c\",\n            \"b1b5d123613e75c97d1a9cfc3f714b6d\",\n            \"3b109c1cd75648b71a9a361595e863ab\",\n            \"44708e57c3c0140f9cf1b86170c5d320\",\n            \"66f193b8f45cd9525582003645829560\",\n            \"9c4c4fbaa6e9284fabd9c31d4e727bfe\",\n            \"f20e7a860fca9179d57c8a5f8e280cfb\",\n            \"8872b51e43d2839b2bbb2dbe46302370\",\n            \"323e278c6d7942691836774006aed7e2\",\n            \"4106b3d95b6dc95caacb7aee27fab780\",\n            \"601bf25b4512502145c6cb69e0968783\",\n            \"cb57531ee7104145b0892e337a99f42d\",\n            \"b81749d11da67cef57a847db7547c1f3\",\n            \"3aab3cf29ab92ce7c29432ec481a60b1\",\n            \"e0b897af3eee5ec8d8ac5d7ad0953104\",\n            \"eda881a6ea96da8a46874ce99dea28d5\",\n            \"5bf5096583e15c0080741efeb2454ffb\",\n            \"b765b2c997c459fa83fb3a64ac774b17\",\n            \"ca3e6e011ed7ecb32dc86554a873c4fe\",\n            \"eb74c7d047809a8c859e01b847403b9e\",\n            \"3cce48090f5fabe4e356f23093e95c53\",\n            \"dc079a42bd90dcd9593ebeeedbff73b\",\n            \"2c12a9ba6d4d25df8af30108ea9ccb6c\",\n            \"97f506b27434613d51c4deb11af7079e\",\n            \"98e6845fd0c59edef1a8499ee563cb43\",\n            \"addd6a0ef4f55a66d810b14a81e12eca\",\n            \"21c61d015c6a9591e3f7a74e12a274ef\",\n            \"9de72746c0f00317dd73da65dc0e6a7\",\n            \"ac8a2b54ab0d2f94493a4a2a112261d3\",\n            \"f31648796796cff297c8d78b9aede742\",\n            \"c69ce34e38f6218b2f809039658ca52\",\n            \"31b5cb5dfaa253b3df85db41e3677c28\",\n            \"d8fa31c19a952efb293968bf1f72ae90\",\n            \"ab347f12652bda8eab7f9d2da6fc61cf\",\n            \"ccd0c5e06744ad9a5ca7e476d2d4a26e\",\n            \"daf0e2193c5d1096540b442e651b7ad2\",\n            \"3aa613c06675d2a4dd94d3cfe79da065\",\n            \"6c3cee36477c444e2e64cef1881421b\",\n            \"9a32ae6c8a6afa1e7cfc6b949bbd80bf\",\n            \"a85046f089363914e500b815ce2d83a5\",\n            \"be2fe14cb518d6346c115ab7398115c6\",\n            \"7961d0f612add0cee08bb071746122b9\",\n            \"b95a9b552791e21035836c728d324152\",\n            \"d2014ddbb4e91d7ecb1a776b5576b46b\",\n            \"9c0c2110a58e15febc48810968735439\",\n            \"950b680f7b13e72956433d91451f5f9e\",\n            \"d01ce0f02e25fc2b42e1bb4fe264125f\",\n            \"5fabccb7eaa2adb19a037b4abf810691\",\n  
          \"ad5ef1b493028c0bd810b14a81e12eca\",\n            \"a2bdc3a6cb13bf197a969536c4ba7f8\",\n            \"3c56ceef171fa142126c0b0ea3ea0a2c\",\n            \"a8ae390fc6de647e1191e01810565e0c\",\n            \"10507ae95f984daccd8f3fe9ca2145e1\",\n            \"ebf1982ccb77cf7d4c37b9ce3a3de242\",\n            \"fbd0055862daa31a2d8ad3188383fcc8\",\n            \"1d6250cafc410fdfe8058cf23f6382c1\",\n            \"7833d94635b755793adc3470b30138f3\",\n            \"ee4ec19df3a6b81317fd40ad196436c\",\n            \"82d25519070e3d5d6f1ad7def14e2855\",\n            \"1ed2e9056a9e94a693e60794f9200b7\",\n            \"980d28e46333db492878c1c13b7f1ba6\",\n            \"a84ff0a6802b0661f345fb470303964a\",\n            \"b58291a2e96691189d1eb836604648db\",\n            \"f653f39d443ac6af15c0ed29be4328d5\",\n            \"4f20a1fe4c85f35aa8891e678de4fe35\",\n            \"90e1a4fb7fa4f4fea5df41522f4e2c86\",\n            \"e5f793180e610d329d6f7fa9d8096b18\",\n            \"4760c46c66461a79dd3adf3090c701f7\",\n            \"3069cb9d632611dcdb43ca77043c03b9\",\n            \"6f0f6571f173bd90f9883d2fd957d60f\",\n            \"11d5e99e8faa10ff3564590844406360\",\n            \"36980a6d57dac873d206493eaec9688a\",\n            \"80b3fb6bae282ec24ca35f9816a36449\",\n            \"7bfef1d5b096f81967ec3d19eeb10ea5\",\n            \"6bfe4dac77d317de1181122615f0a10\",\n            \"2b4de06792ec0eba94141819f1b9662c\",\n            \"20c030e0d055b7aeb0892e337a99f42d\",\n            \"bb53613789a3e9158d736fde63f29661\",\n            \"4ca3850427226dd18d4d8cbfa622f635\",\n            \"a4d269f9299e845a36e3b2fa8d1eb4eb\",\n            \"d13a2ccdbb7740ea83a0857b7b9398b1\",\n            \"a98956209f6723a2dedecd2df7bf25e3\",\n            \"80c8fcbc5f53c402162cb7009da70820\",\n            \"32909910636369b24500047017815f5f\",\n            \"b3d686456bd951d42ea98d69e91ba870\",\n            \"c0d544b1b483556e4d70389e06680f96\",\n            \"d08fc6f10d07dfd8c05575120a46cd3b\",\n            \"e2e25d63c3bc14c23459b03b1c80294\",\n            \"3d95d6237ea6db97afa2904116693357\",\n            \"a7e4616a2a315dfac5ddc26ef5560e77\",\n            \"7a92d499a6f3460a4dc2316a7e66d36\",\n            \"22fd961578d9b100e858db1dc3499392\",\n            \"9136172b17108d1ba7d0cc9b15400f65\",\n            \"62e50e8b0d1e3207e047a3592e8436e5\",\n            \"7fb8fdbc8c32dc2aad6d7e3d5c0179fe\",\n            \"1488c4dbdcd022e99de4fa6f98acbba8\",\n            \"9cff96bae9963ceab3c971099efd7fb9\",\n            \"5cea034b028af000c2843529921f9ad7\",\n            \"8a8c67e553f1c85c1829bffea9d18abb\",\n            \"90959c7526608ecdfb54b6f00b1687e2\",\n            \"8043ff469c9bed4d48c575435d16be7c\",\n            \"b031d0f66e8d047bee8bf38bd6d16329\",\n            \"700562e70c89e9d36bd9ce012b65eebe\",\n            \"b9994f0252c890cbb0892e337a99f42d\",\n            \"7201558f92e55496493a4a2a112261d3\",\n            \"e0f5780dbee33329caa60d250862f15f\",\n            \"1aa55867200ea789465e08d496c0420f\",\n            \"956c5ffbe1ad52976e3c8a33c4ddf2ef\",\n            \"3415f252bd71495649920492438878e5\",\n            \"d64025df908177ece7e684d25d4dcaf0\",\n            \"65309e7601a0c65d82608d43718c2b7\",\n            \"f8ef9668c3aba7c735836c728d324152\",\n            \"6c4e0987896fc5df30c7f4adc2c33ad6\",\n            \"23d1a49a9a29c776ab9281d3b84673cf\",\n            \"aa5fe2c92f01ce0a30cbbda41991e4d0\",\n            \"2fca00b377269eebdb039d8689a74349\",\n            \"fe56059777b240bb833c6c72c4b62a4d\",\n            
\"82a168f7c5b8899a79368d1198f406e7\",\n            \"1ebc20758e0b61184bf4a6644c670a1e\",\n            \"f1ce06c5259f771dc24182d0db4c6889\",\n            \"a1b935c6288595267795cc6fb87dd247\",\n            \"b16913335a26e380d1a4117555fc7e56\",\n            \"87bdac1e34f3e6146db2ac45db35c175\",\n            \"f693bb3178c80d7f1783a44a88d6274\",\n            \"f39cb99f7c30a4b8310cd758d9b7cf\",\n            \"91238bfd357c2d87abe6f34906958ac9\",\n            \"801418d073795539ddb3d1341171fe71\",\n            \"b2e9bf58ab2458b5b057261be64dfc5\",\n            \"7669de4a8474b6f4b53857b83094d3a9\",\n            \"9d8c5c62020fcaf4b822d48a43773c62\",\n            \"a680830f8b76c1bbe929777b2f481029\",\n            \"6d0cd48b18471a8bf1444eeb21e761c6\",\n            \"baa8760ca5fbbc4840b559ef47048b86\",\n            \"377fceb1500e6452d9651cd1d591d64d\",\n            \"440e3ad55b603cb1b071d266df0a3bf5\",\n            \"21bf3888008b7aced6d2e576c4ef3bde\",\n            \"151b2df36e08a13fafeeb1a322c90696\",\n            \"24f5497f13a1841adb039d8689a74349\",\n            \"a4c8e8816a1c5f54e6e3ac4cbdf2e092\",\n            \"4cd19483d852712a120031fb55ce8c1c\",\n            \"69257080fd87015369fb37a80cd44134\",\n            \"22c68a7a2c8142f027eb00c151c6f711\",\n            \"82efd3fc01b0c6f34500047017815f5f\",\n            \"fcff900ce37820983f7e27638e63d848\",\n            \"cdb7efad9942e4f9695ff9c22a7938c6\",\n            \"107637b6bdf8129d4904d89e9169817b\",\n            \"f2e7ed2b973570f1a54b9afa882a89ed\",\n            \"b58f27bd7e1ace90d2afe8d5254a0d04\",\n            \"4820b629990b6a20860f0fe00407fa79\",\n            \"abfd8f1db83d6fd24f5ae69aba92b12c\",\n            \"4ed802a4aa4b8a86b161f36d4e309050\",\n            \"47ad0af4207beedb296baeb5500afa1a\",\n            \"f6f5fa7760fde38114aa582256ea1399\",\n            \"60fc7123d6360e6d620ef1b4a95dca08\",\n            \"c53ec0141303c1eb4508add1163b4513\",\n            \"2ebb84f64f8f0f565db77ed1f5c8b93\",\n            \"6b0a09fda777eef2a0398757e5dcd13c\",\n            \"2c1ecb41c0f0d2cd07c7bf20dae278a\",\n            \"e03c28dbfe1f2d9638bb8355830240f9\",\n            \"1b29de0b8be4b18733d25da891be74b8\",\n            \"38a4328fc3e65ba54c37b9ce3a3de242\",\n            \"f81a05c0b43e9bf18c9441777325d0fd\",\n            \"4cb7347f6a91da8554f7af0f5a657663\",\n            \"5328231a28719ed240a92729068b6b39\",\n            \"9239ed83d2274ac975aa7f24a9b6003a\",\n            \"71579e02d0b80bcfccebba9929a10b5e\",\n            \"8647bc3a58eb9d04493a4a2a112261d3\",\n            \"f2c96ae9904baaf220021a69db826002\",\n            \"923a0885b7ca9c55d1007f4863f3bba6\",\n            \"2f28d32781abb0527f2a6b0d07c10212\",\n            \"e68e91ef2652cd1c36e3b2fa8d1eb4eb\",\n            \"9451b957aa36883d6e6c0340d644e56e\",\n            \"72b7ad431f8c4aa2f5520a8b6a5d82e0\",\n            \"54b420da7102d792b36717b39afb6ad8\",\n            \"16fd88a99f7d4c857e484225f3bb4a8\",\n            \"8bad639b9a650908de650492e45fb14f\",\n            \"24129a2a06b35e27c1a6b3575377c8d3\",\n            \"fc553c244c85909c8ed898bae5915f59\",\n            \"e3ce79fd03b7a78d98661b9abac3e1f9\",\n            \"3fcb0aaa346bd46f11e76965808086c8\",\n            \"df8374d8f3563be8f1783a44a88d6274\",\n            \"7eb94d259b560d24f1783a44a88d6274\",\n            \"48205c79d48ca2049f433921788191f3\",\n            \"9294163aacac61f1ad5c4e4076e069c\",\n            \"f2458aaf4ab3e0585d7543afa4b9b4e8\",\n            \"f09a3e98938bc1a990e678878ceea8a4\",\n            \"2e2f34305ef8cbc1533ccec14d70360b\",\n     
       \"1a477f7b2c1799e1b728e6e715c3f8cf\",\n            \"bcd2418fb0d9727c563fcc2752ece39\",\n            \"ace76562ee9d7c3a913c66b05d18ae8\",\n            \"1b5ae67e2ffb387341fbc1e2da054acb\",\n            \"f563b39b846922f22ea98d69e91ba870\",\n            \"fce717669ca521884e1a9fae5403e01f\",\n            \"d0563b5c26d096ea7b270a7ae3fc6351\",\n            \"43656f715bb28776bff15b656f256f05\",\n            \"3006b66ed9a9fd55a4ccffb47ce1f60d\",\n            \"f5dc5957a6b9712835836c728d324152\",\n            \"3d87710d90c8627dd2afe8d5254a0d04\",\n            \"21b22c30f1c6ddb9952d5d6c0ee49300\",\n            \"e077c6640dab26ecc6e735f548844733\",\n            \"bdfcf2086fafb0fec8a04932b17782af\",\n            \"fb883c27f9f0156e1e44b635c3d74c0b\",\n            \"161f92c3330d111ffc2dd24b885b9c05\",\n            \"b23dc14d788e954b3adc3470b30138f3\",\n            \"a930d381392ff51140b559ef47048b86\",\n            \"7c92e64a328f1b968f6cc6fefa15515a\",\n            \"8affea22019b77a1f1783a44a88d6274\",\n            \"1dc08faeb376c6e0296baeb5500afa1a\",\n            \"fadd7d8c94893136e4b1c2efb094888b\",\n            \"6bd7475753c3d1d62764cfba57a5de73\",\n            \"5d9f1c6f9ce9333994c6d0877753424f\",\n            \"c3e248c5f88e72e9ea65c47b660136e7\",\n            \"87c7911287c10ab8407b28046774089c\",\n            \"4fbc051d0905bb5bdaeb838d0771f3b5\",\n            \"b097e092f951f1d68bc0997abbde7f8\",\n            \"1aec9ac7e1487b7bc75516c7217fb218\",\n            \"2f9a502dafbaa50769cd744177574ad3\",\n            \"9df9d1c02e9013e7ef10d8e00e9d279c\",\n            \"f0f42d26c4a0be52a53016a50348cce3\",\n            \"f080807207cc4859b2403dba7fd079eb\",\n            \"ffcc57ea3101d18ece3df8a7477638c0\",\n            \"f01821fb48425e3a493a4a2a112261d3\",\n            \"8cc805e63f357a69ae336243a6891b91\",\n            \"468e627a950f9340d2bc6fb367627d3e\",\n            \"499edbd7de3e7423bb865c00ef25280\",\n            \"686c5c4a39de31445c31d9fe12bfc9ab\",\n            \"1470c6423b877a0882fcee0c19bca00a\",\n            \"bcc267b5e694e60c3adc3470b30138f3\",\n            \"e7b9fef210fa47505615f13c9cba61e5\",\n            \"ddc8bdf6ac3786f493b55327c66a07aa\",\n            \"27359b0203dc2f345c6cb69e0968783\",\n            \"1e355dc9bfd5da837a8c23d2d40f51b8\",\n            \"b2cfba8ee63abd118fac6a8030e15671\",\n            \"f1a956705451d66ee95eb670e9df38e4\",\n            \"6fc69edce1f6d0d5e7e684d25d4dcaf0\",\n            \"2d3f8abd0567f1601191e01810565e0c\",\n            \"52d307203aefd6bf366971e8a2cdf120\",\n            \"be129d18d202650f6d3e11439c6c22c8\",\n            \"2d4adec4323f8a7da6491c91cd8d7770\",\n            \"9695e057d7a4e992f2b67ae827b02632\",\n            \"21addfde981f0e3445c6cb69e0968783\",\n            \"f8a5ab9872d0dc5e1855a135fe295583\",\n            \"3f8f1d7023ae4a1c73ffdca541f4749c\",\n            \"1e4a7cb88d31716cc9c93fe51d670e21\",\n            \"2cf89f5b70817c3bab2844e3834b9ca1\",\n            \"e410469475df6525493a4a2a112261d3\",\n            \"81f1ac4d9c29025a36f739eb863924d1\",\n            \"c955e564c9a73650f78bdf37d618e97e\",\n            \"20b6d398c5b93a253adc3470b30138f3\",\n            \"ed4ac116e03ebb8d663191fd557d3a61\",\n            \"4735568bd188aefcb8e1b99345a5afd4\",\n            \"6b3b1d804255188993680c5a9a367b4a\",\n            \"5ea73e05bb96272b444ac0ff78e79b7\",\n            \"2eed4a1647d05552413240bb8f2e4eea\",\n            \"6e960c704961590e7e0d7ca07704f74f\",\n            \"dccb04ee585d9d10cf065d4b58dc3084\",\n            
\"1eb4bc0cea4ba47c8186c25526ebdaa6\",\n            \"f144cda19f9457fef9b7ca92584b5271\",\n            \"c971b2adf8f071f49afca5ea91159877\",\n            \"1e96607159093f7bbdb7b8a99c2de3d8\",\n            \"a29f6cce8bf52a99182529900d54df69\",\n            \"f260588ef2f4c8e62a6f69f20bf5a185\",\n            \"72b7e2a9bdef8b37f491193d3fde480b\",\n            \"b5fe8de26eac454acdead18f26cf2ece\",\n            \"6caa1713e795c8a2f0478431b5ad57db\",\n            \"1914d0e6b9f0445b40e80a2d9f005aa7\",\n            \"c5380b779c689a919201f2703b45dd7\",\n            \"735122a1019fd6529dac46bde4c69ef2\",\n            \"660df170c4337cda35836c728d324152\",\n            \"770bf9d88c7f039227eb00c151c6f711\",\n            \"2b73510e4eb3d8ca87b66c61c1f7b8e4\",\n            \"3e499689bae22f3ab89ca298cd9a646\",\n            \"ae1fd69e00c00a0253116daf5a10d9c\",\n            \"823219c03b02a423c1a85f2b9754d96f\",\n            \"f87c2d4d5d622293a3f3891e682aaded\",\n            \"ee2ea7d9f26f57886580dafdf181b2d7\"\n        ]}\n}"
  },
  {
    "path": "dataloader/__init__.py",
    "content": ""
  },
  {
    "path": "dataloader/base.py",
    "content": "#!/usr/bin/env python3\n\nimport numpy as np\nimport time \nimport logging\nimport os\nimport random\nimport torch\nimport torch.utils.data\n\nimport pandas as pd \nimport csv\n\nclass Dataset(torch.utils.data.Dataset):\n    def __init__(\n        self,\n        data_source,\n        split_file, # json filepath which contains train/test classes and meshes \n        subsample,\n        gt_filename,\n        #pc_size=1024,\n    ):\n\n        self.data_source = data_source \n        self.subsample = subsample\n        self.split_file = split_file\n        self.gt_filename = gt_filename\n        #self.pc_size = pc_size\n\n        # example\n        # data_source: \"data\"\n        # ws.sdf_samples_subdir: \"SdfSamples\"\n        # self.gt_files[0]: \"acronym/couch/meshname/sdf_data.csv\"\n            # with gt_filename=\"sdf_data.csv\"\n\n    def __len__(self):\n        return NotImplementedError\n\n    def __getitem__(self, idx):     \n        return NotImplementedError\n\n    def sample_pointcloud(self, csvfile, pc_size):\n        f=pd.read_csv(csvfile, sep=',',header=None).values\n\n        f = f[f[:,-1]==0][:,:3]\n\n        if f.shape[0] < pc_size:\n            pc_idx = np.random.choice(f.shape[0], pc_size)\n        else:\n            pc_idx = np.random.choice(f.shape[0], pc_size, replace=False)\n\n        return torch.from_numpy(f[pc_idx]).float()\n\n    def labeled_sampling(self, f, subsample, pc_size=1024, load_from_path=True):\n        if load_from_path:\n            f=pd.read_csv(f, sep=',',header=None).values\n            f = torch.from_numpy(f)\n\n        half = int(subsample / 2) \n        neg_tensor = f[f[:,-1]<0]\n        pos_tensor = f[f[:,-1]>0]\n\n        if pos_tensor.shape[0] < half:\n            pos_idx = torch.randint(0, pos_tensor.shape[0], (half,))\n        else:\n            pos_idx = torch.randperm(pos_tensor.shape[0])[:half]\n\n        if neg_tensor.shape[0] < half:\n            if neg_tensor.shape[0]==0:\n                neg_idx = torch.randperm(pos_tensor.shape[0])[:half] # no neg indices, then just fill with positive samples\n            else:\n                neg_idx = torch.randint(0, neg_tensor.shape[0], (half,))\n        else:\n            neg_idx = torch.randperm(neg_tensor.shape[0])[:half]\n\n        pos_sample = pos_tensor[pos_idx]\n\n        if neg_tensor.shape[0]==0:\n            neg_sample = pos_tensor[neg_idx]\n        else:\n            neg_sample = neg_tensor[neg_idx]\n\n        pc = f[f[:,-1]==0][:,:3]\n        pc_idx = torch.randperm(pc.shape[0])[:pc_size]\n        pc = pc[pc_idx]\n\n        samples = torch.cat([pos_sample, neg_sample], 0)\n\n        return pc.float().squeeze(), samples[:,:3].float().squeeze(), samples[:, 3].float().squeeze() # pc, xyz, sdv\n\n\n    def get_instance_filenames(self, data_source, split, gt_filename=\"sdf_data.csv\", filter_modulation_path=None):\n            \n            do_filter = filter_modulation_path is not None \n            csvfiles = []\n            for dataset in split: # e.g. \"acronym\" \"shapenet\"\n                for class_name in split[dataset]:\n                    for instance_name in split[dataset][class_name]:\n                        instance_filename = os.path.join(data_source, dataset, class_name, instance_name, gt_filename)\n\n                        if do_filter:\n                            mod_file = os.path.join(filter_modulation_path, class_name, instance_name, \"latent.txt\")\n\n                            # do not load if the modulation does not exist; i.e. 
    def get_instance_filenames(self, data_source, split, gt_filename=\"sdf_data.csv\", filter_modulation_path=None):\n            \n            do_filter = filter_modulation_path is not None \n            csvfiles = []\n            for dataset in split: # e.g. \"acronym\" \"shapenet\"\n                for class_name in split[dataset]:\n                    for instance_name in split[dataset][class_name]:\n                        instance_filename = os.path.join(data_source, dataset, class_name, instance_name, gt_filename)\n\n                        if do_filter:\n                            mod_file = os.path.join(filter_modulation_path, class_name, instance_name, \"latent.txt\")\n\n                            # do not load if the modulation does not exist; i.e. was not trained by diffusion model\n                            if not os.path.isfile(mod_file):\n                                continue\n                        \n                        if not os.path.isfile(instance_filename):\n                            logging.warning(\"Requested non-existent file '{}'\".format(instance_filename))\n                            continue\n\n                        csvfiles.append(instance_filename)\n            return csvfiles\n"
  },
  {
    "path": "dataloader/modulation_loader.py",
    "content": "#!/usr/bin/env python3\n\nimport time \nimport logging\nimport os\nimport random\nimport torch\nimport torch.utils.data\nfrom diff_utils.helpers import * \n\nimport pandas as pd \nimport numpy as np\nimport csv, json\n\nfrom tqdm import tqdm\n\nclass ModulationLoader(torch.utils.data.Dataset):\n    def __init__(self, data_path, pc_path=None, split_file=None, pc_size=None):\n        super().__init__()\n\n        self.conditional = pc_path is not None \n\n        if self.conditional:\n            self.modulations, pc_paths = self.load_modulations(data_path, pc_path, split_file)\n        else:\n            self.modulations = self.unconditional_load_modulations(data_path, split_file)\n        #self.modulations = self.modulations[0:8]\n        #pc_paths = pc_paths[0:8]\n\n        print(\"data shape, dataset len: \", self.modulations[0].shape, len(self.modulations))\n        #assert args.batch_size <= len(self.modulations)\n        \n        if self.conditional:\n            print(\"loading ground truth point clouds...\")            \n            lst = []\n            with tqdm(pc_paths) as pbar:\n                for i, f in enumerate(pc_paths):\n                    pbar.set_description(\"Point clouds loaded: {}/{}\".format(i, len(pc_paths)))\n                    lst.append(sample_pc(f, pc_size))\n            self.point_clouds = lst\n\n            assert len(self.point_clouds) == len(self.modulations)\n        \n        \n    def __len__(self):\n        return len(self.modulations)\n\n    def __getitem__(self, index):\n\n        pc = self.point_clouds[index] if self.conditional else False\n        return {\n            \"point_cloud\" : pc,\n            \"latent\" : self.modulations[index]         \n        }\n        \n\n    def load_modulations(self, data_source, pc_source, split, f_name=\"latent.txt\", add_flip_augment=False, return_filepaths=True):\n        #split = json.load(open(split))\n        files = []\n        filepaths = [] # return filepaths for loading pcs\n        for dataset in split: # dataset = \"acronym\" \n            for class_name in split[dataset]:\n                for instance_name in split[dataset][class_name]:\n\n                    if add_flip_augment:\n                        for idx in range(4):\n                            instance_filename = os.path.join(data_source, class_name, instance_name, \"latent_{}.txt\".format(idx))\n                            if not os.path.isfile(instance_filename):\n                                print(\"Requested non-existent file '{}'\".format(instance_filename))\n                                continue\n                            files.append( torch.from_numpy(np.loadtxt(instance_filename)).float() )\n                        filepaths.append( os.path.join(pc_source, dataset, class_name, instance_name, \"sdf_data.csv\") )\n\n                    else:\n                        instance_filename = os.path.join(data_source, class_name, instance_name, f_name)\n                        if not os.path.isfile(instance_filename):\n                            #print(\"Requested non-existent file '{}'\".format(instance_filename))\n                            continue\n                        files.append( torch.from_numpy(np.loadtxt(instance_filename)).float() )\n                        filepaths.append( os.path.join(pc_source, dataset, class_name, instance_name, \"sdf_data.csv\") )\n        if return_filepaths:\n            return files, filepaths\n        return files\n\n    def unconditional_load_modulations(self, 
def unconditional_load_modulations(self, data_source, split, f_name=\"latent.txt\", add_flip_augment=False):\n        files = []\n        for dataset in split: # dataset = \"acronym\" \n            for class_name in split[dataset]:\n                for instance_name in split[dataset][class_name]:\n\n                    if add_flip_augment:\n                        for idx in range(4):\n                            instance_filename = os.path.join(data_source, class_name, instance_name, \"latent_{}.txt\".format(idx))\n                            if not os.path.isfile(instance_filename):\n                                print(\"Requested non-existent file '{}'\".format(instance_filename))\n                                continue\n                            files.append( torch.from_numpy(np.loadtxt(instance_filename)).float() )\n\n                    else:\n                        instance_filename = os.path.join(data_source, class_name, instance_name, f_name)\n                        if not os.path.isfile(instance_filename):\n                            continue\n                        files.append( torch.from_numpy(np.loadtxt(instance_filename)).float() )\n        return files"
  },
  {
    "path": "dataloader/pc_loader.py",
    "content": "#!/usr/bin/env python3\n\nimport time \nimport logging\nimport os\nimport random\nimport torch\nimport torch.utils.data\nfrom . import base \nfrom tqdm import tqdm\n\nimport pandas as pd \nimport csv\n\nclass PCloader(base.Dataset):\n\n    def __init__(\n        self,\n        data_source,\n        split_file, # json filepath which contains train/test classes and meshes \n        pc_size=1024,\n        return_filename=False\n    ):\n\n        self.pc_size = pc_size\n        self.gt_files = self.get_instance_filenames(data_source, split_file)\n        self.return_filename = return_filename\n\n        self.pc_paths = self.get_instance_filenames(data_source, split_file)\n        self.pc_paths = self.pc_paths[:5] \n        print(\"loading {} point clouds into memory...\".format(len(self.pc_paths)))\n        lst = []\n        with tqdm(self.pc_paths) as pbar:\n            for i, f in enumerate(pbar):\n                pbar.set_description(\"Files loaded: {}/{}\".format(i, len(self.pc_paths)))\n                lst.append(self.sample_pc(f, pc_size))\n        self.point_clouds = lst\n\n        #print(\"each pc shape: \", self.point_clouds[0].shape)\n\n    def get_all_files(self):\n        return self.point_clouds, self.pc_paths \n    \n    def __getitem__(self, idx): \n        if self.return_filename:\n            return self.point_clouds[idx], self.pc_paths[idx]\n        else:\n            return self.point_clouds[idx]\n\n\n    def __len__(self):\n        return len(self.point_clouds)\n\n\n    def sample_pc(self, f, samp=1024): \n        '''\n        f: path to csv file\n        '''\n        # data = torch.from_numpy(np.loadtxt(f, delimiter=',')).float()\n        data = torch.from_numpy(pd.read_csv(f, sep=',',header=None).values).float()\n        pc = data[data[:,-1]==0][:,:3]\n        pc_idx = torch.randperm(pc.shape[0])[:samp] \n        pc = pc[pc_idx]\n        #print(\"pc shape, dtype: \", pc.shape, pc.dtype) # [1024,3], torch.float32\n        #pc = normalize_pc(pc)\n        #print(\"pc shape: \", pc.shape, pc.max(), pc.min())\n        return pc\n\n\n\n    \n"
  },
  {
    "path": "dataloader/sdf_loader.py",
    "content": "#!/usr/bin/env python3\n\nimport time \nimport logging\nimport os\nimport random\nimport torch\nimport torch.utils.data\nfrom . import base \n\nimport pandas as pd \nimport numpy as np\nimport csv, json\n\nfrom tqdm import tqdm\n\nclass SdfLoader(base.Dataset):\n\n    def __init__(\n        self,\n        data_source, # path to points sampled around surface\n        split_file, # json filepath which contains train/test classes and meshes \n        grid_source=None, # path to grid points; grid refers to sampling throughout the unit cube instead of only around the surface; necessary for preventing artifacts in empty space\n        samples_per_mesh=16000,\n        pc_size=1024,\n        modulation_path=None # used for third stage of training; needs to be set in config file when some modulation training had been filtered\n    ):\n \n        self.samples_per_mesh = samples_per_mesh\n        self.pc_size = pc_size\n        self.gt_files = self.get_instance_filenames(data_source, split_file, filter_modulation_path=modulation_path)\n\n        subsample = len(self.gt_files) \n        self.gt_files = self.gt_files[0:subsample]\n\n        self.grid_source = grid_source\n        #print(\"grid source: \", grid_source)\n    \n        if grid_source:\n            self.grid_files = self.get_instance_filenames(grid_source, split_file, gt_filename=\"grid_gt.csv\", filter_modulation_path=modulation_path)\n            self.grid_files = self.grid_files[0:subsample]\n            lst = []\n            with tqdm(self.grid_files) as pbar:\n                for i, f in enumerate(pbar):\n                    pbar.set_description(\"Grid files loaded: {}/{}\".format(i, len(self.grid_files)))\n                    lst.append(torch.from_numpy(pd.read_csv(f, sep=',',header=None).values))\n            self.grid_files = lst\n            \n            assert len(self.grid_files) == len(self.gt_files)\n\n\n        # load all csv files first \n        print(\"loading all {} files into memory...\".format(len(self.gt_files)))\n        lst = []\n        with tqdm(self.gt_files) as pbar:\n            for i, f in enumerate(pbar):\n                pbar.set_description(\"Files loaded: {}/{}\".format(i, len(self.gt_files)))\n                lst.append(torch.from_numpy(pd.read_csv(f, sep=',',header=None).values))\n        self.gt_files = lst\n\n\n    def __getitem__(self, idx): \n\n        near_surface_count = int(self.samples_per_mesh*0.7) if self.grid_source else self.samples_per_mesh\n\n        pc, sdf_xyz, sdf_gt =  self.labeled_sampling(self.gt_files[idx], near_surface_count, self.pc_size, load_from_path=False)\n        \n\n        if self.grid_source is not None:\n            grid_count = self.samples_per_mesh - near_surface_count\n            _, grid_xyz, grid_gt = self.labeled_sampling(self.grid_files[idx], grid_count, pc_size=0, load_from_path=False)\n            # each getitem is one batch so no batch dimension, only N, 3 for xyz or N for gt \n            # for 16000 points per batch, near surface is 11200, grid is 4800\n            #print(\"shapes: \", pc.shape,  sdf_xyz.shape, sdf_gt.shape, grid_xyz.shape, grid_gt.shape)\n            sdf_xyz = torch.cat((sdf_xyz, grid_xyz))\n            sdf_gt = torch.cat((sdf_gt, grid_gt))\n            #print(\"shapes after adding grid: \", pc.shape, sdf_xyz.shape, sdf_gt.shape, grid_xyz.shape, grid_gt.shape)\n\n        data_dict = {\n                    \"xyz\":sdf_xyz.float().squeeze(),\n                    \"gt_sdf\":sdf_gt.float().squeeze(), \n                    
\"point_cloud\":pc.float().squeeze(),\n                    }\n\n        return data_dict\n\n    def __len__(self):\n        return len(self.gt_files)\n\n\n\n    \n"
  },
  {
    "path": "diff_utils/helpers.py",
    "content": "import math\nimport torch\nimport torch.nn.functional as F\nimport numpy as np \nimport pandas as pd \nimport random \nfrom inspect import isfunction\nimport os\nimport json \n#import open3d as o3d\n\n\ndef get_split_filenames(data_source, split_file, f_name=\"sdf_data.csv\"):\n    split = json.load(open(split_file))\n    csvfiles = []\n    for dataset in split: # e.g. \"acronym\" \"shapenet\"\n        for class_name in split[dataset]:\n            for instance_name in split[dataset][class_name]:\n                instance_filename = os.path.join(data_source, dataset, class_name, instance_name, f_name)\n                if not os.path.isfile(instance_filename):\n                    print(\"Requested non-existent file '{}'\".format(instance_filename))\n                    continue\n                csvfiles.append(instance_filename)\n    return csvfiles\n\ndef sample_pc(f, samp=1024, add_flip_augment=False): \n    '''\n    f: path to csv file\n    '''\n    data = torch.from_numpy(pd.read_csv(f, sep=',',header=None).values).float()\n    pc = data[data[:,-1]==0][:,:3]\n    pc_idx = torch.randperm(pc.shape[0])[:samp] \n    pc = pc[pc_idx]\n\n    if add_flip_augment:\n        pcs = []\n        flip_axes = torch.tensor([[1,1,1],[-1,1,1],[1,-1,1],[1,1,-1]], device=pc.device)\n        for idx, axis in enumerate(flip_axes):\n            pcs.append(pc * axis)\n        return pcs\n\n\n    return pc\n\ndef perturb_point_cloud(pc, perturb, pc_size=None, crop_percent=0.25):\n    '''\n    if pc_size is None, return entire pc; else return with shape of pc_size\n    '''\n    assert perturb in [None, \"partial\", \"noisy\"]\n    if perturb is None:\n        pc_idx = torch.randperm(pc.shape[1])[:pc_size] \n        pc = pc[:,pc_idx]   \n        #print(\"pc shape: \", pc.shape)\n        return pc\n    elif perturb == \"partial\":\n        return crop_pc(pc, crop_percent, pc_size)\n    elif perturb == \"noisy\":\n        return jitter_pc(pc, pc_size)\n\ndef fps(data, number):\n    '''\n        data B N 3\n        number int\n    '''\n    fps_idx = pointnet2_utils.furthest_point_sample(data, number) \n    fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1,2).contiguous()\n    return fps_data\n\ndef crop_pc(xyz, crop, pc_size=None, fixed_points = None, padding_zeros = False):\n    '''\n     crop the point cloud given a randomly selected view\n     input point cloud: xyz, with shape (B, N, 3)\n     crop: float, percentage of points to crop out (e.g. 
def crop_pc(xyz, crop, pc_size=None, fixed_points = None, padding_zeros = False):\n    '''\n     crop the point cloud given a randomly selected view\n     input point cloud: xyz, with shape (B, N, 3)\n     crop: float, percentage of points to crop out (e.g. 0.25 means keep 75% of points)\n     pc_size: integer value, how many points to return; None if return all (all meaning xyz size * crop)\n    '''\n    \n\n    if pc_size is not None:\n        xyz = xyz[:, torch.randperm(xyz.shape[1])[:pc_size] ]\n    \n    _,n,c = xyz.shape\n    device = xyz.device\n        \n    crop = int(xyz.shape[1]*crop)\n    #print(\"pc shape: \", xyz.shape, crop)\n\n    \n    assert c == 3\n    if crop == n:\n        return xyz # , None\n        \n    INPUT = []\n    CROP = []\n    for points in xyz:\n        if isinstance(crop,list):\n            num_crop = random.randint(crop[0],crop[1])\n        else:\n            num_crop = crop\n\n        points = points.unsqueeze(0)\n\n        if fixed_points is None:       \n            center = F.normalize(torch.randn(1,1,3, device=device),p=2,dim=-1)\n        else:\n            if isinstance(fixed_points,list):\n                fixed_point = random.sample(fixed_points,1)[0]\n            else:\n                fixed_point = fixed_points\n            center = fixed_point.reshape(1,1,3).to(device)\n\n        distance_matrix = torch.norm(center.unsqueeze(2) - points.unsqueeze(1), p =2 ,dim = -1)  # 1 1 2048\n\n        idx = torch.argsort(distance_matrix,dim=-1, descending=False)[0,0] # 2048\n\n        if padding_zeros:\n            input_data = points.clone()\n            input_data[0, idx[:num_crop]] =  input_data[0,idx[:num_crop]] * 0\n\n        else:\n            input_data = points.clone()[0, idx[num_crop:]].unsqueeze(0) # 1 N 3\n\n        crop_data =  points.clone()[0, idx[:num_crop]].unsqueeze(0)\n\n        if isinstance(crop,list):\n            INPUT.append(fps(input_data,2048))\n            CROP.append(fps(crop_data,2048))\n        else:\n            INPUT.append(input_data)\n            CROP.append(crop_data)\n\n    input_data = torch.cat(INPUT,dim=0)# B N 3\n    crop_data = torch.cat(CROP,dim=0)# B M 3\n\n    return input_data.contiguous() #, crop_data.contiguous()\n\n\n\ndef visualize_pc(pc):\n    import open3d as o3d  # optional dependency, only needed for visualization\n    pcd = o3d.geometry.PointCloud()\n    pcd.points = o3d.utility.Vector3dVector(pc.reshape(-1,3))\n    o3d.io.write_point_cloud(\"./pc.ply\", pcd)\n    #o3d.visualization.draw_geometries([pcd])\n\ndef jitter_pc(pc, pc_size=None, sigma=0.05, clip=0.1):\n    device = pc.device\n    pc += torch.clamp(sigma*torch.randn(*pc.shape, device=device), -1*clip, clip)\n    if pc_size is not None:\n        if len(pc.shape) == 3: # B, N, 3\n            pc = pc[:, torch.randperm(pc.shape[1])[:pc_size] ]\n        else: # N, 3\n            pc = pc[torch.randperm(pc.shape[0])[:pc_size] ]\n\n    return pc\n\n\ndef normalize_pc(pc):\n    pc -= torch.mean(pc, axis=0)\n    m = torch.max(torch.sqrt(torch.sum(pc**2, axis=1)))\n    #bbox_length = torch.sqrt( torch.sum((torch.max(pc, axis=0)[0] - torch.min(pc, axis=0)[0])**2) )\n    pc /= m\n    return pc\n\ndef save_model(iters, model, optimizer, loss, path):\n    torch.save({'iters': iters,\n                'model_state_dict': model.state_dict(),\n                'optimizer_state_dict': optimizer.state_dict(),\n                'loss': loss}, \n                path)\n\ndef load_model(model, optimizer, path):\n    checkpoint = torch.load(path)\n    \n    if optimizer is not None:\n        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n    else:\n        optimizer = None\n    \n    model.load_state_dict(checkpoint['model_state_dict'])\n    loss = checkpoint['loss']\n    iters = checkpoint['iters']\n    print(\"loading from iter {}...\".format(iters))\n    return iters, model, optimizer, 
loss\n\n\ndef save_code_to_conf(conf_dir):\n    path = os.path.join(conf_dir, \"code\")\n    os.makedirs(path, exist_ok=True)\n    for folder in [\"utils\", \"models\", \"diff_utils\", \"dataloader\", \"metrics\"]: \n        os.makedirs(os.path.join(path, folder), exist_ok=True)\n        os.system(\"\"\"cp -r ./{0}/* \"{1}\" \"\"\".format(folder, os.path.join(path, folder)))\n\n    # other files\n    os.system(\"\"\"cp *.py \"{}\" \"\"\".format(path))\n\nclass ScheduledOpt:\n    '''\n    optimizer = ScheduledOpt(4000, torch.optim.Adam(model.parameters(), lr=0))\n    '''\n    \"Optim wrapper that implements rate.\"\n    def __init__(self, warmup, optimizer):\n        self.optimizer = optimizer\n        self._step = 0\n        self.warmup = warmup\n        self._rate = 0\n    \n    # def state_dict(self):\n    #     \"\"\"Returns the state of the warmup scheduler as a :class:`dict`.\n    #     It contains an entry for every variable in self.__dict__ which\n    #     is not the optimizer.\n    #     \"\"\"\n    #     return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}\n    \n    # def load_state_dict(self, state_dict):\n    #     \"\"\"Loads the warmup scheduler's state.\n    #     Arguments:\n    #         state_dict (dict): warmup scheduler state. Should be an object returned\n    #             from a call to :meth:`state_dict`.\n    #     \"\"\"\n    #     self.__dict__.update(state_dict) \n        \n    def step(self):\n        \"Update parameters and rate\"\n        self._step += 1\n        rate = self.rate()\n        for p in self.optimizer.param_groups:\n            p['lr'] = rate\n        self._rate = rate\n        self.optimizer.step()\n        #print(\"rate: \",rate)\n\n    def zero_grad(self):\n        self.optimizer.zero_grad()\n        \n    def rate(self, step = None):\n        \"Implement `lrate` above\"\n        if step is None:\n            step = self._step\n\n        warm_schedule = torch.linspace(0, 3e-4, self.warmup, dtype = torch.float64)\n        if step < self.warmup:\n            return warm_schedule[step]\n        else:\n            return 3e-4 / (math.sqrt(step-self.warmup+1))\n\ndef exists(x):\n    return x is not None\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\ndef cycle(dl):\n    while True:\n        for data in dl:\n            yield data\n\ndef has_int_squareroot(num):\n    return (math.sqrt(num) ** 2) == num\n\ndef num_to_groups(num, divisor):\n    groups = num // divisor\n    remainder = num % divisor\n    arr = [divisor] * groups\n    if remainder > 0:\n        arr.append(remainder)\n    return arr\n\ndef convert_image_to(img_type, image):\n    if image.mode != img_type:\n        return image.convert(img_type)\n    return image\n\n# normalization functions\n\n#from 0,1 to -1,1\ndef normalize_to_neg_one_to_one(img):\n    return img * 2 - 1\n\n# from -1,1 to 0,1\ndef unnormalize_to_zero_to_one(t):\n    return (t + 1) * 0.5\n\n# from any batch to [0,1]\n# f should have shape (batch, -1)\ndef normalize_to_zero_to_one(f):\n    f -= f.min(1, keepdim=True)[0]\n    f /= f.max(1, keepdim=True)[0]\n    return f\n\n\n# extract the appropriate t index for a batch of indices\ndef extract(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n
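# Standard DDPM beta schedules: linear_beta_schedule follows Ho et al. 2020,\n# rescaled by 1000/timesteps so the endpoints match the original 1000-step setup;\n# cosine_beta_schedule follows Nichol & Dhariwal (the openreview link below).\n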
  return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)\n\ndef cosine_beta_schedule(timesteps, s = 0.008):\n    \"\"\"\n    cosine schedule\n    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ\n    \"\"\"\n    #print(\"using COSINE schedule\")\n    steps = timesteps + 1\n    x = torch.linspace(0, timesteps, steps, dtype = torch.float64)\n    cos_in = ((x / timesteps) + s) / (1 + s) * math.pi * 0.5\n    np_in = cos_in.numpy()\n    alphas_cumprod = np.cos(np_in)  ** 2\n    alphas_cumprod = torch.from_numpy(alphas_cumprod)\n    #alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2\n    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])\n\n    return torch.clip(betas, 0, 0.999)\n\n"
  },
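  {
    "path": "diff_utils/examples/helpers_demo.py",
    "content": "'''\nHypothetical usage sketch (not part of the original release): a quick sanity\ncheck of the beta schedules and the warmup optimizer wrapper, assuming the\nhelpers shown above live in diff_utils/helpers.py. The file path and the toy\nmodel below are illustrative assumptions, not code referenced elsewhere.\n'''\nimport torch\n\nfrom diff_utils.helpers import linear_beta_schedule, cosine_beta_schedule, ScheduledOpt\n\nif __name__ == '__main__':\n    # both schedules return `timesteps` betas in (0, 1); the cosine schedule clips at 0.999\n    betas_lin = linear_beta_schedule(1000)\n    betas_cos = cosine_beta_schedule(1000)\n    print('linear:', betas_lin[0].item(), '->', betas_lin[-1].item())\n    print('cosine:', betas_cos[0].item(), '->', betas_cos[-1].item())\n\n    # ScheduledOpt ignores the lr passed to Adam and sets it on every step:\n    # linear warmup to 3e-4 over `warmup` steps, then inverse sqrt decay\n    model = torch.nn.Linear(8, 8)  # stand-in for a real network\n    opt = ScheduledOpt(4000, torch.optim.Adam(model.parameters(), lr=0))\n    loss = model(torch.randn(2, 8)).sum()\n    opt.zero_grad()\n    loss.backward()\n    opt.step()  # updates parameters with the warmup rate for step 1\n"
  },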
  {
    "path": "diff_utils/model_utils.py",
    "content": "import math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum \n\nfrom einops import rearrange, repeat, reduce\nfrom einops.layers.torch import Rearrange\nfrom einops_exts import rearrange_many, repeat_many, check_shape\n\nfrom .pointnet.pointnet_classifier import PointNetClassifier\nfrom .pointnet.conv_pointnet import ConvPointnet\nfrom .pointnet.dgcnn import DGCNN\n\nfrom .helpers import *\n\nclass LayerNorm(nn.Module):\n    def __init__(self, dim, eps = 1e-5, stable = False):\n        super().__init__()\n        self.eps = eps\n        self.stable = stable\n        self.g = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x):\n        if self.stable:\n            x = x / x.amax(dim = -1, keepdim = True).detach()\n\n        var = torch.var(x, dim = -1, unbiased = False, keepdim = True)\n        mean = torch.mean(x, dim = -1, keepdim = True)\n        return (x - mean) * (var + self.eps).rsqrt() * self.g\n\n# mlp\n\nclass MLP(nn.Module):\n    def __init__(\n        self,\n        dim_in,\n        dim_out,\n        *,\n        expansion_factor = 2.,\n        depth = 2,\n        norm = False,\n    ):\n        super().__init__()\n        hidden_dim = int(expansion_factor * dim_out)\n        norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()\n\n        layers = [nn.Sequential(\n            nn.Linear(dim_in, hidden_dim),\n            nn.SiLU(),\n            norm_fn()\n        )]\n\n        for _ in range(depth - 1):\n            layers.append(nn.Sequential(\n                nn.Linear(hidden_dim, hidden_dim),\n                nn.SiLU(),\n                norm_fn()\n            ))\n\n        layers.append(nn.Linear(hidden_dim, dim_out))\n        self.net = nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.net(x.float())\n\n# relative positional bias for causal transformer\n\nclass RelPosBias(nn.Module):\n    def __init__(\n        self,\n        heads = 8,\n        num_buckets = 32,\n        max_distance = 128,\n    ):\n        super().__init__()\n        self.num_buckets = num_buckets\n        self.max_distance = max_distance\n        self.relative_attention_bias = nn.Embedding(num_buckets, heads)\n\n    @staticmethod\n    def _relative_position_bucket(\n        relative_position,\n        num_buckets = 32,\n        max_distance = 128\n    ):\n        n = -relative_position\n        n = torch.max(n, torch.zeros_like(n))\n\n        max_exact = num_buckets // 2\n        is_small = n < max_exact\n\n        val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()\n        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))\n        return torch.where(is_small, n, val_if_large)\n\n    def forward(self, i, j, *, device):\n        q_pos = torch.arange(i, dtype = torch.long, device = device)\n        k_pos = torch.arange(j, dtype = torch.long, device = device)\n        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')\n        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)\n        values = self.relative_attention_bias(rp_bucket)\n        return rearrange(values, 'i j h -> h i j')\n\n# feedforward\n\nclass SwiGLU(nn.Module):\n    \"\"\" used successfully in https://arxiv.org/abs/2204.0231 \"\"\"\n    def forward(self, x):\n        x, gate = x.chunk(2, dim = -1)\n        return x * F.silu(gate)\n\ndef FeedForward(\n  
  dim,\n    out_dim = None,\n    mult = 4,\n    dropout = 0.,\n    post_activation_norm = False\n):\n    \"\"\" post-activation norm https://arxiv.org/abs/2110.09456 \"\"\"\n\n    #print(\"dropout: \", dropout)\n    out_dim = default(out_dim, dim)\n    #print(\"out_dim: \", out_dim)\n    inner_dim = int(mult * dim)\n    return nn.Sequential(\n        LayerNorm(dim),\n        nn.Linear(dim, inner_dim * 2, bias = False),\n        SwiGLU(),\n        LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),\n        nn.Dropout(dropout),\n        nn.Linear(inner_dim, out_dim, bias = False)\n    )\n\nclass SinusoidalPosEmb(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.dim = dim\n\n    def forward(self, x):\n        device = x.device\n        half_dim = self.dim // 2\n        emb = math.log(10000) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)\n        emb = x[:, None] * emb[None, :]\n        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)\n        return emb\n\ndef exists(x):\n    return x is not None\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim,\n        kv_dim=None,\n        *,\n        out_dim = None,\n        dim_head = 64,\n        heads = 8,\n        dropout = 0.,\n        causal = False,\n        rotary_emb = None,\n        pb_relax_alpha = 128\n    ):\n        super().__init__()\n        self.pb_relax_alpha = pb_relax_alpha\n        self.scale = dim_head ** -0.5 * (pb_relax_alpha ** -1)\n\n        self.heads = heads\n        inner_dim = dim_head * heads\n        kv_dim = default(kv_dim, dim)\n\n        self.causal = causal\n\n        self.norm = LayerNorm(dim)\n\n        self.dropout = nn.Dropout(dropout)\n\n        self.null_kv = nn.Parameter(torch.randn(2, dim_head))\n        self.to_q = nn.Linear(dim, inner_dim, bias = False)\n        self.to_kv = nn.Linear(kv_dim, dim_head * 2, bias = False)\n\n        self.rotary_emb = rotary_emb\n\n        out_dim = default(out_dim, dim)\n\n        self.to_out = nn.Sequential(\n            nn.Linear(inner_dim, out_dim, bias = False),\n            LayerNorm(out_dim)\n        )\n\n    def forward(self, x, context=None, mask = None, attn_bias = None):\n        b, n, device = *x.shape[:2], x.device\n\n        context = default(context, x) #self attention if context is None \n\n        x = self.norm(x)\n        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))\n\n        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)\n        q = q * self.scale\n\n        # rotary embeddings\n\n        if exists(self.rotary_emb):\n            q, k = map(self.rotary_emb.rotate_queries_or_keys, (q, k))\n\n        # add null key / value for classifier free guidance in prior net\n\n        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)\n        k = torch.cat((nk, k), dim = -2)\n        v = torch.cat((nv, v), dim = -2)\n\n        # calculate query / key similarities\n\n        sim = einsum('b h i d, b j d -> b h i j', q, k)\n\n        # relative positional encoding (T5 style)\n        #print(\"attn bias, sim shapes: \", attn_bias.shape, sim.shape)\n        if exists(attn_bias):\n            sim = sim + attn_bias\n\n        # masking\n\n        max_neg_value = -torch.finfo(sim.dtype).max\n\n        if exists(mask):\n            mask = F.pad(mask, (1, 0), value = True)\n            mask = rearrange(mask, 
'b j -> b 1 1 j')\n            sim = sim.masked_fill(~mask, max_neg_value)\n\n        if self.causal:\n            i, j = sim.shape[-2:]\n            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)\n            sim = sim.masked_fill(causal_mask, max_neg_value)\n\n        # attention\n\n        sim = sim - sim.amax(dim = -1, keepdim = True).detach()\n        sim = sim * self.pb_relax_alpha\n\n        attn = sim.softmax(dim = -1)\n        attn = self.dropout(attn)\n\n        # aggregate values\n\n        out = einsum('b h i j, b j d -> b h i d', attn, v)\n\n        out = rearrange(out, 'b h n d -> b n (h d)')\n        return self.to_out(out)\n\n"
  },
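  {
    "path": "diff_utils/examples/attention_demo.py",
    "content": "'''\nHypothetical usage sketch (not part of the original release): a shape check\nfor the building blocks in diff_utils/model_utils.py. All tensor sizes below\nare arbitrary assumptions chosen for illustration.\n'''\nimport torch\n\nfrom diff_utils.model_utils import MLP, Attention, SinusoidalPosEmb\n\nif __name__ == '__main__':\n    # MLP maps (batch, dim_in) -> (batch, dim_out)\n    mlp = MLP(512, 256, depth = 2, norm = True)\n    print(mlp(torch.randn(4, 512)).shape)  # torch.Size([4, 256])\n\n    # sinusoidal embedding of diffusion timesteps: (batch,) -> (batch, dim)\n    t_emb = SinusoidalPosEmb(128)(torch.arange(4).float())\n    print(t_emb.shape)  # torch.Size([4, 128])\n\n    # Attention uses single-head keys/values (to_kv outputs dim_head * 2) and\n    # prepends a learned null key/value for classifier-free guidance\n    attn = Attention(256, heads = 4, dim_head = 32)\n    x = torch.randn(2, 10, 256)  # (batch, seq, dim)\n    print(attn(x).shape)  # torch.Size([2, 10, 256]) -- self-attention when context is None\n"
  },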
  {
    "path": "diff_utils/pointnet/__init__.py",
    "content": ""
  },
  {
    "path": "diff_utils/pointnet/conv_pointnet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\n\nfrom torch_scatter import scatter_mean, scatter_max\n\n\nclass ConvPointnet(nn.Module):\n    ''' PointNet-based encoder network with ResNet blocks for each point.\n        Number of input points are fixed.\n    \n    Args:\n        c_dim (int): dimension of latent code c\n        dim (int): input points dimension\n        hidden_dim (int): hidden dimension of the network\n        scatter_type (str): feature aggregation when doing local pooling\n        unet (bool): weather to use U-Net\n        unet_kwargs (str): U-Net parameters\n        plane_resolution (int): defined resolution for plane feature\n        plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume\n        padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]\n        n_blocks (int): number of blocks ResNetBlockFC layers\n    '''\n\n    def __init__(self, c_dim=512, dim=3, hidden_dim=128, scatter_type='max', \n                 unet=True, unet_kwargs={\"depth\": 4, \"merge_mode\": \"concat\", \"start_filts\": 32}, \n                 plane_resolution=64, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):\n        super().__init__()\n        self.c_dim = c_dim\n\n        self.fc_pos = nn.Linear(dim, 2*hidden_dim)\n        self.blocks = nn.ModuleList([\n            ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)\n        ])\n        self.fc_c = nn.Linear(hidden_dim, c_dim)\n\n        self.actvn = nn.ReLU()\n        self.hidden_dim = hidden_dim\n\n        if unet:\n            self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)\n        else:\n            self.unet = None\n\n        self.reso_plane = plane_resolution\n        self.plane_type = plane_type\n        self.padding = padding\n\n        if scatter_type == 'max':\n            self.scatter = scatter_max\n        elif scatter_type == 'mean':\n            self.scatter = scatter_mean\n\n\n    # takes in \"p\": point cloud and \"query\": sdf_xyz \n    # sample plane features for unlabeled_query as well \n    def forward(self, p, query):\n        batch_size, T, D = p.size()\n\n        # acquire the index for each point\n        coord = {}\n        index = {}\n        if 'xz' in self.plane_type:\n            coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)\n            index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)\n        if 'xy' in self.plane_type:\n            coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)\n            index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)\n        if 'yz' in self.plane_type:\n            coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)\n            index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)\n\n        \n        net = self.fc_pos(p)\n\n        net = self.blocks[0](net)\n        for block in self.blocks[1:]:\n            pooled = self.pool_local(coord, index, net)\n            net = torch.cat([net, pooled], dim=2)\n            net = block(net)\n\n        c = self.fc_c(net)\n\n        fea = {}\n        plane_feat_sum = 0\n        if 'xz' in self.plane_type:\n            fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 
16, 256, 64, 64)\n            plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')\n        if 'xy' in self.plane_type:\n            fea['xy'] = self.generate_plane_features(p, c, plane='xy')\n            plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')\n        if 'yz' in self.plane_type:\n            fea['yz'] = self.generate_plane_features(p, c, plane='yz')\n            plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')\n\n        return plane_feat_sum.transpose(2,1)\n\n\n    def normalize_coordinate(self, p, padding=0.1, plane='xz'):\n        ''' Normalize coordinate to [0, 1] for unit cube experiments\n\n        Args:\n            p (tensor): point\n            padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]\n            plane (str): plane feature type, ['xz', 'xy', 'yz']\n        '''\n        if plane == 'xz':\n            xy = p[:, :, [0, 2]]\n        elif plane =='xy':\n            xy = p[:, :, [0, 1]]\n        else:\n            xy = p[:, :, [1, 2]]\n\n        xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)\n        xy_new = xy_new + 0.5 # range (0, 1)\n\n        # if there are outliers out of the range\n        if xy_new.max() >= 1:\n            xy_new[xy_new >= 1] = 1 - 10e-6\n        if xy_new.min() < 0:\n            xy_new[xy_new < 0] = 0.0\n        return xy_new\n\n\n    def coordinate2index(self, x, reso):\n        ''' Convert normalized 2D coordinates in [0, 1] to flat indices on a\n            plane of the given resolution\n\n        Args:\n            x (tensor): coordinate\n            reso (int): defined resolution\n        '''\n        x = (x * reso).long()\n        index = x[:, :, 0] + reso * x[:, :, 1]\n        index = index[:, None, :]\n        return index\n\n\n    # xy is the normalized coordinates of the point cloud of each plane \n    # I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input \n    def pool_local(self, xy, index, c):\n        bs, fea_dim = c.size(0), c.size(2)\n        keys = xy.keys()\n\n        c_out = 0\n        for key in keys:\n            # scatter plane features from points\n            fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)\n            if self.scatter == scatter_max:\n                fea = fea[0]\n            # gather feature back to points\n            fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))\n            c_out += fea\n        return c_out.permute(0, 2, 1)\n\n\n    def generate_plane_features(self, p, c, plane='xz'):\n        # acquire indices of features in plane\n        xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)\n        index = self.coordinate2index(xy, self.reso_plane)\n\n        # scatter plane features from points\n        fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)\n        c = c.permute(0, 2, 1) # B x 512 x T\n        fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2\n        fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparse matrix (B x 512 x reso x reso)\n\n        # process the plane features with UNet\n        if self.unet is not None:\n            fea_plane = self.unet(fea_plane)\n\n        return fea_plane\n\n\n    # sample_plane_feature function copied from /src/conv_onet/models/decoder.py\n    # uses values from 
plane_feature and pixel locations from vgrid to interpolate feature\n    def sample_plane_feature(self, query, plane_feature, plane):\n        xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)\n        xy = xy[:, :, None].float()\n        vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)\n        sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)\n        return sampled_feat\n\n\ndef conv3x3(in_channels, out_channels, stride=1, \n            padding=1, bias=True, groups=1):    \n    return nn.Conv2d(\n        in_channels,\n        out_channels,\n        kernel_size=3,\n        stride=stride,\n        padding=padding,\n        bias=bias,\n        groups=groups)\n\ndef upconv2x2(in_channels, out_channels, mode='transpose'):\n    if mode == 'transpose':\n        return nn.ConvTranspose2d(\n            in_channels,\n            out_channels,\n            kernel_size=2,\n            stride=2)\n    else:\n        # out_channels is always going to be the same\n        # as in_channels\n        return nn.Sequential(\n            nn.Upsample(mode='bilinear', scale_factor=2),\n            conv1x1(in_channels, out_channels))\n\ndef conv1x1(in_channels, out_channels, groups=1):\n    return nn.Conv2d(\n        in_channels,\n        out_channels,\n        kernel_size=1,\n        groups=groups,\n        stride=1)\n\n\nclass DownConv(nn.Module):\n    \"\"\"\n    A helper Module that performs 2 convolutions and 1 MaxPool.\n    A ReLU activation follows each convolution.\n    \"\"\"\n    def __init__(self, in_channels, out_channels, pooling=True):\n        super(DownConv, self).__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.pooling = pooling\n\n        self.conv1 = conv3x3(self.in_channels, self.out_channels)\n        self.conv2 = conv3x3(self.out_channels, self.out_channels)\n\n        if self.pooling:\n            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n\n    def forward(self, x):\n        x = F.relu(self.conv1(x))\n        x = F.relu(self.conv2(x))\n        before_pool = x\n        if self.pooling:\n            x = self.pool(x)\n        return x, before_pool\n\n\nclass UpConv(nn.Module):\n    \"\"\"\n    A helper Module that performs 2 convolutions and 1 UpConvolution.\n    A ReLU activation follows each convolution.\n    \"\"\"\n    def __init__(self, in_channels, out_channels, \n                 merge_mode='concat', up_mode='transpose'):\n        super(UpConv, self).__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.merge_mode = merge_mode\n        self.up_mode = up_mode\n\n        self.upconv = upconv2x2(self.in_channels, self.out_channels, \n            mode=self.up_mode)\n\n        if self.merge_mode == 'concat':\n            self.conv1 = conv3x3(\n                2*self.out_channels, self.out_channels)\n        else:\n            # num of input channels to conv2 is same\n            self.conv1 = conv3x3(self.out_channels, self.out_channels)\n        self.conv2 = conv3x3(self.out_channels, self.out_channels)\n\n\n    def forward(self, from_down, from_up):\n        \"\"\" Forward pass\n        Arguments:\n            from_down: tensor from the encoder pathway\n            from_up: upconv'd tensor from the decoder pathway\n        \"\"\"\n        from_up = self.upconv(from_up)\n        if self.merge_mode == 'concat':\n            x = torch.cat((from_up, from_down), 1)\n       
 else:\n            x = from_up + from_down\n        x = F.relu(self.conv1(x))\n        x = F.relu(self.conv2(x))\n        return x\n\n\nclass UNet(nn.Module):\n    \"\"\" `UNet` class is based on https://arxiv.org/abs/1505.04597\n\n    The U-Net is a convolutional encoder-decoder neural network.\n    Contextual spatial information (from the decoding,\n    expansive pathway) about an input tensor is merged with\n    information representing the localization of details\n    (from the encoding, compressive pathway).\n\n    Modifications to the original paper:\n    (1) padding is used in 3x3 convolutions to prevent loss\n        of border pixels\n    (2) merging outputs does not require cropping due to (1)\n    (3) residual connections can be used by specifying\n        UNet(merge_mode='add')\n    (4) if non-parametric upsampling is used in the decoder\n        pathway (specified by up_mode='upsample'), then an\n        additional 1x1 2d convolution occurs after upsampling\n        to reduce channel dimensionality by a factor of 2.\n        This channel halving happens with the convolution in\n        the transpose convolution (specified by up_mode='transpose')\n    \"\"\"\n\n    def __init__(self, num_classes, in_channels=3, depth=5, \n                 start_filts=64, up_mode='transpose', \n                 merge_mode='concat', **kwargs):\n        \"\"\"\n        Arguments:\n            in_channels: int, number of channels in the input tensor.\n                Default is 3 for RGB images.\n            depth: int, number of MaxPools in the U-Net.\n            start_filts: int, number of convolutional filters for the \n                first conv.\n            up_mode: string, type of upconvolution. Choices: 'transpose'\n                for transpose convolution or 'upsample' for nearest neighbour\n                upsampling.\n        \"\"\"\n        super(UNet, self).__init__()\n\n        if up_mode in ('transpose', 'upsample'):\n            self.up_mode = up_mode\n        else:\n            raise ValueError(\"\\\"{}\\\" is not a valid mode for \"\n                             \"upsampling. Only \\\"transpose\\\" and \"\n                             \"\\\"upsample\\\" are allowed.\".format(up_mode))\n    \n        if merge_mode in ('concat', 'add'):\n            self.merge_mode = merge_mode\n        else:\n            raise ValueError(\"\\\"{}\\\" is not a valid mode for \"\n                             \"merging up and down paths. \"\n                             \"Only \\\"concat\\\" and \"\n                             \"\\\"add\\\" are allowed.\".format(merge_mode))\n\n        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'\n        if self.up_mode == 'upsample' and self.merge_mode == 'add':\n            raise ValueError(\"up_mode \\\"upsample\\\" is incompatible \"\n                             \"with merge_mode \\\"add\\\" at the moment \"\n                             \"because it doesn't make sense to use \"\n                             \"nearest neighbour to reduce \"\n                             \"depth channels (by half).\")\n\n        self.num_classes = num_classes\n        self.in_channels = in_channels\n        self.start_filts = start_filts\n        self.depth = depth\n\n        self.down_convs = []\n        self.up_convs = []\n\n        # create the encoder pathway and add to a list\n        for i in range(depth):\n            ins = self.in_channels if i == 0 else outs\n            outs = self.start_filts*(2**i)\n            pooling = True if i < depth-1 else False\n\n            down_conv = DownConv(ins, outs, pooling=pooling)\n            self.down_convs.append(down_conv)\n\n        # create the decoder pathway and add to a list\n        # - careful! decoding only requires depth-1 blocks\n        for i in range(depth-1):\n            ins = outs\n            outs = ins // 2\n            up_conv = UpConv(ins, outs, up_mode=up_mode,\n                merge_mode=merge_mode)\n            self.up_convs.append(up_conv)\n\n        # add the list of modules to current module\n        self.down_convs = nn.ModuleList(self.down_convs)\n        self.up_convs = nn.ModuleList(self.up_convs)\n\n        self.conv_final = conv1x1(outs, self.num_classes)\n\n        self.reset_params()\n\n    @staticmethod\n    def weight_init(m):\n        if isinstance(m, nn.Conv2d):\n            init.xavier_normal_(m.weight)\n            init.constant_(m.bias, 0)\n\n\n    def reset_params(self):\n        for i, m in enumerate(self.modules()):\n            self.weight_init(m)\n\n\n    def forward(self, x):\n        encoder_outs = []\n        # encoder pathway, save outputs for merging\n        for i, module in enumerate(self.down_convs):\n            x, before_pool = module(x)\n            encoder_outs.append(before_pool)\n        for i, module in enumerate(self.up_convs):\n            before_pool = encoder_outs[-(i+2)]\n            x = module(before_pool, x)\n        \n        # No softmax is used. 
This means you need to use\n        # nn.CrossEntropyLoss in your training script,\n        # as that module includes a softmax already.\n        x = self.conv_final(x)\n        return x\n\n# Resnet Blocks\nclass ResnetBlockFC(nn.Module):\n    ''' Fully connected ResNet Block class.\n    Args:\n        size_in (int): input dimension\n        size_out (int): output dimension\n        size_h (int): hidden dimension\n    '''\n\n    def __init__(self, size_in, size_out=None, size_h=None):\n        super().__init__()\n        # Attributes\n        if size_out is None:\n            size_out = size_in\n\n        if size_h is None:\n            size_h = min(size_in, size_out)\n\n        self.size_in = size_in\n        self.size_h = size_h\n        self.size_out = size_out\n        # Submodules\n        self.fc_0 = nn.Linear(size_in, size_h)\n        self.fc_1 = nn.Linear(size_h, size_out)\n        self.actvn = nn.ReLU()\n\n        if size_in == size_out:\n            self.shortcut = None\n        else:\n            self.shortcut = nn.Linear(size_in, size_out, bias=False)\n        # Initialization\n        nn.init.zeros_(self.fc_1.weight)\n\n    def forward(self, x):\n        net = self.fc_0(self.actvn(x))\n        dx = self.fc_1(self.actvn(net))\n\n        if self.shortcut is not None:\n            x_s = self.shortcut(x)\n        else:\n            x_s = x\n\n        return x_s + dx"
  },
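  {
    "path": "diff_utils/examples/conv_pointnet_demo.py",
    "content": "'''\nHypothetical usage sketch (not part of the original release): runs the\ntriplane ConvPointnet encoder on a random point cloud and random SDF query\npoints. Requires torch_scatter (pinned in environment.yml); all sizes below\nare illustrative assumptions.\n'''\nimport torch\n\nfrom diff_utils.pointnet.conv_pointnet import ConvPointnet\n\nif __name__ == '__main__':\n    encoder = ConvPointnet(c_dim = 256, hidden_dim = 32, plane_resolution = 64)\n    pc = torch.rand(2, 1024, 3) * 2 - 1     # point cloud in [-1, 1]^3\n    query = torch.rand(2, 4096, 3) * 2 - 1  # SDF query coordinates\n    # features from the xz/xy/yz planes are bilinearly sampled at the query\n    # locations and summed, giving one feature vector per query point\n    feat = encoder(pc, query)\n    print(feat.shape)  # torch.Size([2, 4096, 256])\n"
  },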
  {
    "path": "diff_utils/pointnet/dgcnn.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ndef knn(x, k):\n    inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)\n    xx = torch.sum(x ** 2, dim=1, keepdim=True)\n    pairwise_distance = -xx - inner - xx.transpose(2, 1).contiguous()\n\n    idx = pairwise_distance.topk(k=k, dim=-1)[1]\n    return idx\n\ndef get_graph_feature(x, k=20):\n    idx = knn(x, k=k)\n    batch_size, num_points, _ = idx.size()\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points\n\n    idx = idx + idx_base\n\n    idx = idx.view(-1)\n\n    _, num_dims, _ = x.size()\n\n    x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims)  \n                                       # -> (batch_size*num_points, num_dims) \n                                       #   batch_size * num_points * k + range(0, batch_size*num_points)\n    feature = x.view(batch_size * num_points, -1)[idx, :]\n    feature = feature.view(batch_size, num_points, k, num_dims)\n    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)\n    feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)\n    return feature\n\nclass DGCNN(nn.Module):\n\n    def __init__(\n        self, \n        emb_dims=512,\n        use_bn=False,\n        output_channels=100 # number of categories to predict \n    ):\n\n        super().__init__()\n\n        if use_bn:\n            print(\"using batch norm\")\n            self.bn1 = nn.BatchNorm2d(64)\n            self.bn2 = nn.BatchNorm2d(64)\n            self.bn3 = nn.BatchNorm2d(128)\n            self.bn4 = nn.BatchNorm2d(256)\n            self.bn5 = nn.BatchNorm2d(emb_dims)\n\n            self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), self.bn1, nn.LeakyReLU(negative_slope=0.2))\n            self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), self.bn2, nn.LeakyReLU(negative_slope=0.2))\n            self.conv3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False), self.bn3, nn.LeakyReLU(negative_slope=0.2))\n            self.conv4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, bias=False), self.bn4, nn.LeakyReLU(negative_slope=0.2))\n            self.conv5 = nn.Sequential(nn.Conv2d(512, emb_dims, kernel_size=1, bias=False), self.bn5, nn.LeakyReLU(negative_slope=0.2))\n\n        else:\n            self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))\n            self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))\n            self.conv3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))\n            self.conv4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))\n            self.conv5 = nn.Sequential(nn.Conv2d(512, emb_dims, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))\n\n        self.linear1 = nn.Linear(emb_dims*2, 512, bias=False)\n        self.bn6 = nn.BatchNorm1d(512)\n        self.dp1 = nn.Dropout(p=0.5)\n        self.linear2 = nn.Linear(512, 256)\n        self.bn7 = nn.BatchNorm1d(256)\n        self.dp2 = nn.Dropout(p=0.5)\n        self.linear3 = nn.Linear(256, output_channels)\n\n\n    def forward(self, x):\n        batch_size, num_dims, num_points = x.size()                 # x:      batch x   3 x num of points\n        x = get_graph_feature(x)                   
                 # x:      batch x   6 x num of points x 20\n\n        x1     = self.conv1(x)                                      # x1:     batch x  64 x num of points x 20\n        x1_max = x1.max(dim=-1, keepdim=True)[0]                    # x1_max: batch x  64 x num of points x 1\n\n        x2     = self.conv2(x1)                                     # x2:     batch x  64 x num of points x 20\n        x2_max = x2.max(dim=-1, keepdim=True)[0]                    # x2_max: batch x  64 x num of points x 1\n\n        x3     = self.conv3(x2)                                     # x3:     batch x 128 x num of points x 20\n        x3_max = x3.max(dim=-1, keepdim=True)[0]                    # x3_max: batch x 128 x num of points x 1\n\n        x4     = self.conv4(x3)                                     # x4:     batch x 256 x num of points x 20\n        x4_max = x4.max(dim=-1, keepdim=True)[0]                    # x4_max: batch x 256 x num of points x 1\n \n        x_max  = torch.cat((x1_max, x2_max, x3_max, x4_max), dim=1) # x_max:  batch x 512 x num of points x 1\n\n        point_feat = torch.squeeze(self.conv5(x_max), dim=3)        # point feat:  batch x 512 x num of points\n\n        #global_feat = point_feat.max(dim=2, keepdim=False)[0]       # global feat: batch x 512\n\n        x1 = F.adaptive_max_pool1d(point_feat, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)\n        x2 = F.adaptive_avg_pool1d(point_feat, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)\n        x = torch.cat((x1, x2), 1)              # (batch_size, emb_dims*2)\n\n        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2) # (batch_size, emb_dims*2) -> (batch_size, 512)\n        x = self.dp1(x)\n        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2) # (batch_size, 512) -> (batch_size, 256)\n        x = self.dp2(x)\n        x = self.linear3(x)                                             # (batch_size, 256) -> (batch_size, output_channels)\n\n\n        return x\n\n    def get_global_feature(self, x):\n        batch_size, num_dims, num_points = x.size()                 # x:      batch x   3 x num of points\n        x = get_graph_feature(x)                                    # x:      batch x   6 x num of points x 20\n\n        x1     = self.conv1(x)                                      # x1:     batch x  64 x num of points x 20\n        x1_max = x1.max(dim=-1, keepdim=True)[0]                    # x1_max: batch x  64 x num of points x 1\n\n        x2     = self.conv2(x1)                                     # x2:     batch x  64 x num of points x 20\n        x2_max = x2.max(dim=-1, keepdim=True)[0]                    # x2_max: batch x  64 x num of points x 1\n\n        x3     = self.conv3(x2)                                     # x3:     batch x 128 x num of points x 20\n        x3_max = x3.max(dim=-1, keepdim=True)[0]                    # x3_max: batch x 128 x num of points x 1\n\n        x4     = self.conv4(x3)                                     # x4:     batch x 256 x num of points x 20\n        x4_max = x4.max(dim=-1, keepdim=True)[0]                    # x4_max: batch x 256 x num of points x 1\n \n        x_max  = torch.cat((x1_max, x2_max, x3_max, x4_max), dim=1) # x_max:  batch x 512 x num of points x 1\n\n        point_feat = torch.squeeze(self.conv5(x_max), dim=3)        # point feat:  batch x 512 x num of points\n\n        #global_feat = point_feat.max(dim=2, keepdim=False)[0]       # 
global feat: batch x 512\n\n        x1 = F.adaptive_max_pool1d(point_feat, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)\n        x2 = F.adaptive_avg_pool1d(point_feat, 1).view(batch_size, -1)           # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims)\n        x = torch.cat((x1, x2), 1)              # (batch_size, emb_dims*2)\n        return x"
  },
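  {
    "path": "diff_utils/examples/dgcnn_demo.py",
    "content": "'''\nHypothetical usage sketch (not part of the original release): extracts a\nglobal feature from a random point cloud with DGCNN. Note the\n(batch, 3, num_points) input layout and that knn uses k=20, so clouds need\nmore than 20 points. Sizes are illustrative assumptions.\n'''\nimport torch\n\nfrom diff_utils.pointnet.dgcnn import DGCNN\n\nif __name__ == '__main__':\n    net = DGCNN(emb_dims = 512, use_bn = False, output_channels = 100)\n    net.eval()  # good hygiene, though get_global_feature skips dropout/batchnorm\n    pc = torch.rand(2, 3, 1024)  # batch x 3 x num_points\n    with torch.no_grad():\n        feat = net.get_global_feature(pc)  # max- and avg-pooled, concatenated\n    print(feat.shape)  # torch.Size([2, 1024]) = (batch, emb_dims * 2)\n"
  },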
  {
    "path": "diff_utils/pointnet/pointnet_base.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.autograd as grad\n\nfrom .transformer import Transformer\n\n\n##-----------------------------------------------------------------------------\n# Class for PointNetBase. Subclasses PyTorch's own \"nn\" module\n#\n# Computes the local embeddings and global features for an input set of points\n##\nclass PointNetBase(nn.Module):\n\n    def __init__(self, num_points=2000, K=3):\n        # Call the super constructor\n        super(PointNetBase, self).__init__()\n\n        # Input transformer for K-dimensional input\n        # K should be 3 for XYZ coordinates, but can be larger if normals, \n        # colors, etc are included\n        self.input_transformer = Transformer(num_points, K)\n\n        # Embedding transformer is always going to be 64 dimensional\n        self.embedding_transformer = Transformer(num_points, 64)\n\n\t\t# Multilayer perceptrons with shared weights are implemented as \n\t\t# convolutions. This is because we are mapping from K inputs to 64 \n\t\t# outputs, so we can just consider each of the 64 K-dim filters as \n\t\t# describing the weight matrix for each point dimension (X,Y,Z,...) to\n\t\t# each index of the 64 dimension embeddings\n        self.mlp1 = nn.Sequential(\n            nn.Conv1d(K, 64, 1),\n            nn.BatchNorm1d(64),\n            nn.ReLU(),\n            nn.Conv1d(64, 64, 1),\n            nn.BatchNorm1d(64),\n            nn.ReLU())\n\n        self.mlp2 = nn.Sequential(\n            nn.Conv1d(64, 64, 1),\n            nn.BatchNorm1d(64),\n            nn.ReLU(),\n            nn.Conv1d(64, 128, 1),\n            nn.BatchNorm1d(128),\n            nn.ReLU(),\n            nn.Conv1d(128, 1024, 1),\n            nn.BatchNorm1d(1024),\n            nn.ReLU())\n\n\n\t# Take as input a B x K x N matrix of B batches of N points with K \n\t# dimensions\n    def forward(self, x):\n\n        # Number of points put into the network\n        N = x.shape[2]\n\n        # First compute the input data transform and transform the data\n        # T1 is B x K x K and x is B x K x N, so output is B x K x N\n        T1 = self.input_transformer(x)\n        x = torch.bmm(T1, x)\n\n        # Run the transformed inputs through the first embedding MLP\n        # Output is B x 64 x N\n        x = self.mlp1(x)\n\n        # Transform the embeddings. This gives us the \"local embedding\" \n        # referred to in the paper/slides\n        # T2 is B x 64 x 64 and x is B x 64 x N, so output is B x 64 x N\n        T2 = self.embedding_transformer(x)\n        local_embedding = torch.bmm(T2, x)\n\n        # Further embed the \"local embeddings\"\n        # Output is B x 1024 x N\n        global_feature = self.mlp2(local_embedding)\n\n        # Pool over the number of points. This results in the \"global feature\"\n        # referred to in the paper/slides\n        # Output should be B x 1024 x 1 --> B x 1024 (after squeeze)\n        global_feature = F.max_pool1d(global_feature, N).squeeze(2)\n\n        return global_feature, local_embedding, T2\n\n"
  },
  {
    "path": "diff_utils/pointnet/pointnet_classifier.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.autograd as grad\n\nfrom .pointnet_base import PointNetBase\n\n\n##-----------------------------------------------------------------------------\n# Class for PointNetClassifier. Subclasses PyTorch's own \"nn\" module\n#\n# Computes the local embeddings and global features for an input set of points\n##\nclass PointNetClassifier(nn.Module):\n\n\tdef __init__(self, num_points=2000, K=3):\n\t\t# Call the super constructor\n\t\tsuper(PointNetClassifier, self).__init__()\n\n\t\t# Local and global feature extractor for PointNet\n\t\tself.base = PointNetBase(num_points, K)\n\n\t\t# Classifier for ShapeNet\n\t\tself.classifier = nn.Sequential(\n\t\t\tnn.Linear(1024, 512),\n\t\t\tnn.BatchNorm1d(512),\n\t\t\tnn.ReLU(),\n\t\t\tnn.Dropout(0.7),\n\t\t\tnn.Linear(512, 256),\n\t\t\tnn.BatchNorm1d(256),\n\t\t\tnn.ReLU(),\n\t\t\tnn.Dropout(0.7),\n\t\t\tnn.Linear(256, 40))\n\n\n\t# Take as input a B x K x N matrix of B batches of N points with K \n\t# dimensions\n\tdef forward(self, x):\n\n\t\t# Only need to keep the global feature descriptors for classification\n\t\t# Output should be B x 1024\n\t\tglobal_feature, local_embedding, T2 = self.base(x)\n\n\t\t# first attempt: only use the global feature\n\t\t#return global_feature\n\n\t\t# second attempt: concat local and global feature similar to segmentation network but create the downsampling MLP in the diffusion model\n\t\t# local embedding shape: B x 64 x N; global feature shape: B x 1024; concat to get B x 1088 x N \n\t\tnum_points = local_embedding.shape[-1]\n\t\tglobal_feature = global_feature.unsqueeze(-1).repeat(1,1,num_points)\n\t\tpoint_features = torch.cat( (global_feature, local_embedding), dim=1 ) # shape is B x 1088 x N \n\t\treturn point_features\n\t\t\n\t\t\n\n\t\t# Returns a B x 40 \n\t\t#return self.classifier(x), T2"
  },
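  {
    "path": "diff_utils/examples/pointnet_demo.py",
    "content": "'''\nHypothetical usage sketch (not part of the original release): per-point\nfeatures from PointNetClassifier as used for conditioning. Requires a CUDA\ndevice because Transformer builds its identity matrix on the GPU. Sizes are\nillustrative assumptions.\n'''\nimport torch\n\nfrom diff_utils.pointnet.pointnet_classifier import PointNetClassifier\n\nif __name__ == '__main__':\n    assert torch.cuda.is_available(), 'Transformer hardcodes .cuda()'\n    net = PointNetClassifier(num_points = 2000, K = 3).cuda().eval()\n    pc = torch.randn(2, 3, 2000).cuda()  # B x K x N, with N equal to num_points\n    with torch.no_grad():\n        # the global feature (B x 1024) is tiled over N and concatenated with\n        # the 64-dim local embedding, giving B x 1088 x N point features\n        feats = net(pc)\n    print(feats.shape)  # torch.Size([2, 1088, 2000])\n"
  },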
  {
    "path": "diff_utils/pointnet/transformer.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.autograd as grad\n\n\n\n##-----------------------------------------------------------------------------\n# Class for Transformer. Subclasses PyTorch's own \"nn\" module\n#\n# Computes a KxK affine transform from the input data to transform inputs\n# to a \"canonical view\"\n##\nclass Transformer(nn.Module):\n\n\tdef __init__(self, num_points=2000, K=3):\n\t\t# Call the super constructor\n\t\tsuper(Transformer, self).__init__()\n\n\t\t# Number of dimensions of the data\n\t\tself.K = K\n\n\t\t# Size of input\n\t\tself.N = num_points\n\n\t\t# Initialize identity matrix on the GPU (do this here so it only \n\t\t# happens once)\n\t\tself.identity = grad.Variable(\n\t\t\ttorch.eye(self.K).double().view(-1).cuda())\n\n\t\t# First embedding block\n\t\tself.block1 =nn.Sequential(\n\t\t\tnn.Conv1d(K, 64, 1),\n\t\t\tnn.BatchNorm1d(64),\n\t\t\tnn.ReLU())\n\n\t\t# Second embedding block\n\t\tself.block2 =nn.Sequential(\n\t\t\tnn.Conv1d(64, 128, 1),\n\t\t\tnn.BatchNorm1d(128),\n\t\t\tnn.ReLU())\n\n\t\t# Third embedding block\n\t\tself.block3 =nn.Sequential(\n\t\t\tnn.Conv1d(128, 1024, 1),\n\t\t\tnn.BatchNorm1d(1024),\n\t\t\tnn.ReLU())\n\n\t\t# Multilayer perceptron\n\t\tself.mlp = nn.Sequential(\n\t\t\tnn.Linear(1024, 512),\n\t\t\tnn.BatchNorm1d(512),\n\t\t\tnn.ReLU(),\n\t\t\tnn.Linear(512, 256),\n\t\t\tnn.BatchNorm1d(256),\n\t\t\tnn.ReLU(),\n\t\t\tnn.Linear(256, K * K))\n\n\n\t# Take as input a B x K x N matrix of B batches of N points with K \n\t# dimensions\n\tdef forward(self, x):\n\n\t\t# Compute the feature extractions\n\t\t# Output should ultimately be B x 1024 x N\n\t\tx = self.block1(x)\n\t\tx = self.block2(x)\n\t\tx = self.block3(x)\n\n\t\t# Pool over the number of points\n\t\t# Output should be B x 1024 x 1 --> B x 1024 (after squeeze)\n\t\tx = F.max_pool1d(x, self.N).squeeze(2)\n\t\t\n\t\t# Run the pooled features through the multi-layer perceptron\n\t\t# Output should be B x K^2\n\t\tx = self.mlp(x)\n\n\t\t# Add identity matrix to transform\n\t\t# Output is still B x K^2 (broadcasting takes care of batch dimension)\n\t\tx += self.identity\n\n\t\t# Reshape the output into B x K x K affine transformation matrices\n\t\tx = x.view(-1, self.K, self.K)\n\n\t\treturn x\n\n"
  },
  {
    "path": "diff_utils/sdf_utils.py",
    "content": "import math\nimport torch\nimport json \nimport torch.nn.functional as F\nfrom torch import nn, einsum \nimport os\n\nfrom einops import rearrange, repeat, reduce\nfrom einops.layers.torch import Rearrange\n\nfrom sdf_model.model import *\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n\n# load base network\nspecs_path = \"sdf_model/config/siren/specs.json\"\nspecs = json.load(open(specs_path))\nmodel = MetaSDF(specs).to(device)\ncheckpoint = torch.load(\"sdf_model/config/siren/last.ckpt\", map_location=device)\nmodel.load_state_dict(checkpoint['state_dict'])\n\nfor p in model.parameters():\n    p.requires_grad=False\n\ndef pred_sdf_loss(x0, idx):\n\n    pc, xyz, gt = sdf_sampling(gt_files[idx], 16000, batch=x0.shape[0])\n    xyz = xyz.to(device)\n    gt = gt.to(device)\n    sdf_loss = functional_sdf_model(x0, xyz, gt)\n\n    return sdf_loss\n\n\ndef functional_sdf_model(modulation, xyz, gt):\n    '''\n    modulation: modulation vector with shape 512\n    xyz: query points, input to the model, dim= 16000x3\n    gt: ground truth; calculate l1 loss with prediction, dim= 16000x1\n\n    return: sdf_loss \n    '''\n    #print(\"shapes: \", modulation.shape, xyz.shape)\n    pred_sdf = model(modulation, xyz)\n    sdf_loss = F.l1_loss(pred_sdf.squeeze(), gt.squeeze())\n\n    return sdf_loss\n\n\ndef sdf_sampling(f, subsample, pc_size=1024, batch=1):\n    # f=pd.read_csv(f, sep=',',header=None).values\n    # f = torch.from_numpy(f)\n\n    pcs = torch.empty(batch, pc_size, 3)\n    xyz = torch.empty(batch, subsample, 3)\n    gt = torch.empty(batch, subsample)\n\n    for i in range(batch):\n        half = int(subsample / 2) \n        neg_tensor = f[f[:,-1]<0]\n        pos_tensor = f[f[:,-1]>0]\n\n        if pos_tensor.shape[0] < half:\n            pos_idx = torch.randint(pos_tensor.shape[0], (half,))\n        else:\n            pos_idx = torch.randperm(pos_tensor.shape[0])[:half]\n\n        if neg_tensor.shape[0] < half:\n            neg_idx = torch.randint(neg_tensor.shape[0], (half,))\n        else:\n            neg_idx = torch.randperm(neg_tensor.shape[0])[:half]\n\n        pos_sample = pos_tensor[pos_idx]\n        neg_sample = neg_tensor[neg_idx]\n\n        pc = f[f[:,-1]==0][:,:3]\n        pc_idx = torch.randperm(pc.shape[0])[:pc_size]\n        pc = pc[pc_idx]\n\n        samples = torch.cat([pos_sample, neg_sample], 0)\n\n\n        pcs[i] = pc.float()\n        xyz[i] = samples[:,:3].float()\n        gt[i] = samples[:, 3].float()\n\n\n    return pcs, xyz, gt\n\ndef apply_to_sdf(f, x):\n    for idx, l in enumerate(x):\n        x[idx] = f(l)\n    return x"
  },
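  {
    "path": "diff_utils/examples/sdf_sampling_demo.py",
    "content": "'''\nHypothetical usage sketch (not part of the original release): feeds a\nsynthetic (N, 4) tensor of xyz+sdf rows through sdf_sampling from\ndiff_utils/sdf_utils.py. Importing that module also loads the SIREN\ncheckpoint from sdf_model/config/siren/, so this only runs once that model\nexists. All sizes are illustrative assumptions.\n'''\nimport torch\n\nfrom diff_utils.sdf_utils import sdf_sampling\n\nif __name__ == '__main__':\n    n = 20000\n    xyz = torch.rand(n, 3) * 2 - 1\n    sdf = torch.randn(n, 1) * 0.1\n    sdf[:2048] = 0  # rows with sdf == 0 act as the surface point cloud\n    f = torch.cat([xyz, sdf], dim = 1)  # one row per query: x, y, z, sdf\n    pcs, queries, gt = sdf_sampling(f, subsample = 16000, pc_size = 1024, batch = 2)\n    print(pcs.shape, queries.shape, gt.shape)\n    # torch.Size([2, 1024, 3]) torch.Size([2, 16000, 3]) torch.Size([2, 16000])\n"
  },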
  {
    "path": "environment.yml",
    "content": "name: diffusionsdf\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - _openmp_mutex=5.1=1_gnu\n  - blas=1.0=mkl\n  - brotlipy=0.7.0=py39h27cfd23_1003\n  - bzip2=1.0.8=h7b6447c_0\n  - ca-certificates=2022.4.26=h06a4308_0\n  - certifi=2022.5.18.1=py39h06a4308_0\n  - cffi=1.15.0=py39hd667e15_1\n  - charset-normalizer=2.0.4=pyhd3eb1b0_0\n  - cryptography=37.0.1=py39h9ce1e76_0\n  - cudatoolkit=11.3.1=h2bc3f7f_2\n  - ffmpeg=4.3=hf484d3e_0\n  - freetype=2.11.0=h70c0345_0\n  - giflib=5.2.1=h7b6447c_0\n  - gmp=6.2.1=h295c915_3\n  - gnutls=3.6.15=he1e5248_0\n  - idna=3.3=pyhd3eb1b0_0\n  - intel-openmp=2021.4.0=h06a4308_3561\n  - jpeg=9e=h7f8727e_0\n  - lame=3.100=h7b6447c_0\n  - lcms2=2.12=h3be6417_0\n  - ld_impl_linux-64=2.38=h1181459_1\n  - libffi=3.3=he6710b0_2\n  - libgcc-ng=11.2.0=h1234567_1\n  - libgomp=11.2.0=h1234567_1\n  - libiconv=1.16=h7f8727e_2\n  - libidn2=2.3.2=h7f8727e_0\n  - libpng=1.6.37=hbc83047_0\n  - libstdcxx-ng=11.2.0=h1234567_1\n  - libtasn1=4.16.0=h27cfd23_0\n  - libtiff=4.2.0=h2818925_1\n  - libunistring=0.9.10=h27cfd23_0\n  - libuv=1.40.0=h7b6447c_0\n  - libwebp=1.2.2=h55f646e_0\n  - libwebp-base=1.2.2=h7f8727e_0\n  - lz4-c=1.9.3=h295c915_1\n  - mkl=2021.4.0=h06a4308_640\n  - mkl-service=2.4.0=py39h7f8727e_0\n  - mkl_fft=1.3.1=py39hd3c417c_0\n  - mkl_random=1.2.2=py39h51133e4_0\n  - ncurses=6.3=h7f8727e_2\n  - nettle=3.7.3=hbbd107a_1\n  - numpy=1.22.3=py39he7a7128_0\n  - numpy-base=1.22.3=py39hf524024_0\n  - openh264=2.1.1=h4ff587b_0\n  - openssl=1.1.1o=h7f8727e_0\n  - pillow=9.0.1=py39h22f2fdc_0\n  - pip=21.2.4=py39h06a4308_0\n  - pycparser=2.21=pyhd3eb1b0_0\n  - pyopenssl=22.0.0=pyhd3eb1b0_0\n  - pysocks=1.7.1=py39h06a4308_0\n  - python=3.9.12=h12debd9_1\n  - pytorch=1.11.0=py3.9_cuda11.3_cudnn8.2.0_0\n  - pytorch-mutex=1.0=cuda\n  - readline=8.1.2=h7f8727e_1\n  - requests=2.27.1=pyhd3eb1b0_0\n  - setuptools=61.2.0=py39h06a4308_0\n  - six=1.16.0=pyhd3eb1b0_1\n  - sqlite=3.38.3=hc218d9a_0\n  - tk=8.6.12=h1ccaba5_0\n  - torchaudio=0.11.0=py39_cu113\n  - torchvision=0.12.0=py39_cu113\n  - typing_extensions=4.1.1=pyh06a4308_0\n  - tzdata=2022a=hda174b7_0\n  - urllib3=1.26.9=py39h06a4308_0\n  - wheel=0.37.1=pyhd3eb1b0_0\n  - xz=5.2.5=h7f8727e_1\n  - zlib=1.2.12=h7f8727e_2\n  - zstd=1.5.2=ha4553b6_0\n  - pip:\n    - absl-py==1.1.0\n    - addict==2.4.0\n    - aiohttp==3.8.1\n    - aiosignal==1.2.0\n    - asttokens==2.2.1\n    - async-timeout==4.0.2\n    - attrs==21.4.0\n    - backcall==0.2.0\n    - cachetools==5.2.0\n    - click==8.1.3\n    - comm==0.1.2\n    - configargparse==1.5.3\n    - contourpy==1.0.7\n    - cycler==0.11.0\n    - dash==2.8.1\n    - dash-core-components==2.0.0\n    - dash-html-components==2.0.0\n    - dash-table==5.0.0\n    - debugpy==1.6.6\n    - decorator==5.1.1\n    - einops==0.6.0\n    - einops-exts==0.0.4\n    - executing==1.2.0\n    - fastjsonschema==2.16.2\n    - flask==2.2.2\n    - fonttools==4.38.0\n    - frozenlist==1.3.0\n    - fsspec==2022.5.0\n    - google-auth==2.6.6\n    - google-auth-oauthlib==0.4.6\n    - grpcio==1.46.3\n    - imageio==2.19.3\n    - importlib-metadata==4.11.4\n    - ipykernel==6.21.1\n    - ipython==8.10.0\n    - ipywidgets==8.0.4\n    - itsdangerous==2.1.2\n    - jedi==0.18.2\n    - jinja2==3.1.2\n    - joblib==1.2.0\n    - jsonschema==4.17.3\n    - jupyter-client==8.0.2\n    - jupyter-core==5.2.0\n    - jupyterlab-widgets==3.0.5\n    - kiwisolver==1.4.4\n    - markdown==3.3.7\n    - markupsafe==2.1.2\n    - matplotlib==3.6.3\n    - matplotlib-inline==0.1.6\n    - 
multidict==6.0.2\n    - nbformat==5.5.0\n    - nest-asyncio==1.5.6\n    - networkx==2.8.2\n    - oauthlib==3.2.0\n    - open3d==0.16.0\n    - packaging==21.3\n    - pandas==1.4.2\n    - parso==0.8.3\n    - pexpect==4.8.0\n    - pickleshare==0.7.5\n    - platformdirs==3.0.0\n    - plotly==5.13.0\n    - plyfile==0.7.4\n    - prompt-toolkit==3.0.36\n    - protobuf==3.20.1\n    - psutil==5.9.4\n    - ptyprocess==0.7.0\n    - pure-eval==0.2.2\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - pydeprecate==0.3.2\n    - pygments==2.14.0\n    - pyparsing==3.0.9\n    - pyquaternion==0.9.9\n    - pyrsistent==0.19.3\n    - python-dateutil==2.8.2\n    - pytorch-lightning==1.6.4\n    - pytz==2022.1\n    - pywavelets==1.3.0\n    - pyyaml==6.0\n    - pyzmq==25.0.0\n    - requests-oauthlib==1.3.1\n    - rotary-embedding-torch==0.2.1\n    - rsa==4.8\n    - scikit-image==0.19.2\n    - scikit-learn==1.2.1\n    - scipy==1.8.1\n    - stack-data==0.6.2\n    - tenacity==8.2.1\n    - tensorboard==2.9.0\n    - tensorboard-data-server==0.6.1\n    - tensorboard-plugin-wit==1.8.1\n    - threadpoolctl==3.1.0\n    - tifffile==2022.5.4\n    - torch-scatter==2.0.9\n    - torchmetrics==0.9.0\n    - tornado==6.2\n    - tqdm==4.64.0\n    - traitlets==5.9.0\n    - trimesh==3.12.5\n    - wcwidth==0.2.6\n    - werkzeug==2.2.2\n    - widgetsnbextension==4.0.5\n    - yarl==1.7.2\n    - zipp==3.8.0\nprefix: /home/gchou/.conda/envs/diffusion\n"
  },
  {
    "path": "metrics/StructuralLosses/__init__.py",
    "content": "#import torch\n\n#from MakePytorchBackend import AddGPU, Foo, ApproxMatch\n\n#from Add import add_gpu, approx_match\n\n"
  },
  {
    "path": "metrics/StructuralLosses/match_cost.py",
    "content": "import torch\nfrom torch.autograd import Function\nfrom metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad\n\n# Inherit from Function\nclass MatchCostFunction(Function):\n    # Note that both forward and backward are @staticmethods\n    @staticmethod\n    # bias is an optional argument\n    def forward(ctx, seta, setb):\n        #print(\"Match Cost Forward\")\n        ctx.save_for_backward(seta, setb)\n        '''\n        input:\n\t        set1 : batch_size * #dataset_points * 3\n\t        set2 : batch_size * #query_points * 3\n        returns:\n\t        match : batch_size * #query_points * #dataset_points\n        '''\n        match, temp = ApproxMatch(seta, setb)\n        ctx.match = match\n        cost = MatchCost(seta, setb, match)\n        return cost\n\n    \"\"\"\n    grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match)\n\treturn [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None]\n\t\"\"\"\n    # This function has only a single output, so it gets only one gradient\n    @staticmethod\n    def backward(ctx, grad_output):\n        #print(\"Match Cost Backward\")\n        # This is a pattern that is very convenient - at the top of backward\n        # unpack saved_tensors and initialize all gradients w.r.t. inputs to\n        # None. Thanks to the fact that additional trailing Nones are\n        # ignored, the return statement is simple even when the function has\n        # optional inputs.\n        seta, setb = ctx.saved_tensors\n        #grad_input = grad_weight = grad_bias = None\n        grada, gradb = MatchCostGrad(seta, setb, ctx.match)\n        grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2)\n        return grada*grad_output_expand, gradb*grad_output_expand\n\nmatch_cost = MatchCostFunction.apply\n\n"
  },
  {
    "path": "metrics/StructuralLosses/nn_distance.py",
    "content": "import torch\nfrom torch.autograd import Function\n# from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad\nfrom metrics.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad\n\n# Inherit from Function\nclass NNDistanceFunction(Function):\n    # Note that both forward and backward are @staticmethods\n    @staticmethod\n    # bias is an optional argument\n    def forward(ctx, seta, setb):\n        #print(\"Match Cost Forward\")\n        ctx.save_for_backward(seta, setb)\n        '''\n        input:\n\t        set1 : batch_size * #dataset_points * 3\n\t        set2 : batch_size * #query_points * 3\n        returns:\n\t        dist1, idx1, dist2, idx2\n        '''\n        dist1, idx1, dist2, idx2 = NNDistance(seta, setb)\n        ctx.idx1 = idx1\n        ctx.idx2 = idx2\n        return dist1, dist2\n\n    # This function has only a single output, so it gets only one gradient\n    @staticmethod\n    def backward(ctx, grad_dist1, grad_dist2):\n        #print(\"Match Cost Backward\")\n        # This is a pattern that is very convenient - at the top of backward\n        # unpack saved_tensors and initialize all gradients w.r.t. inputs to\n        # None. Thanks to the fact that additional trailing Nones are\n        # ignored, the return statement is simple even when the function has\n        # optional inputs.\n        seta, setb = ctx.saved_tensors\n        idx1 = ctx.idx1\n        idx2 = ctx.idx2\n        grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2)\n        return grada, gradb\n\nnn_distance = NNDistanceFunction.apply\n\n"
  },
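  {
    "path": "metrics/StructuralLosses/nn_distance_demo.py",
    "content": "'''\nHypothetical usage sketch (not part of the original release): Chamfer\ndistance with gradients through the custom NNDistanceFunction. Requires the\ncompiled StructuralLossesBackend CUDA extension, so tensors must live on the\nGPU. Sizes are illustrative assumptions.\n'''\nimport torch\n\nfrom metrics.StructuralLosses.nn_distance import nn_distance\n\nif __name__ == '__main__':\n    a = torch.rand(4, 2048, 3, device = 'cuda', requires_grad = True)\n    b = torch.rand(4, 2048, 3, device = 'cuda')\n    dist1, dist2 = nn_distance(a, b)  # nearest-neighbor distances in both directions\n    cd = (dist1.mean(dim = 1) + dist2.mean(dim = 1)).sum()\n    cd.backward()  # gradients flow via NNDistanceGrad in backward()\n    print(cd.item(), a.grad.shape)\n"
  },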
  {
    "path": "metrics/__init__.py",
    "content": ""
  },
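  {
    "path": "metrics/metrics_demo.py",
    "content": "'''\nHypothetical usage sketch (not part of the original release): evaluates\nrandom point clouds with the pure-PyTorch paths of\nmetrics/evaluation_metrics.py (accelerated flags off, so the compiled CUDA\nextension is not needed). Sizes are illustrative assumptions.\n'''\nimport torch\n\nfrom metrics.evaluation_metrics import EMD_CD\n\nif __name__ == '__main__':\n    sample = torch.rand(8, 256, 3)\n    ref = torch.rand(8, 256, 3)\n    # CD via the AtlasNet distChamfer, EMD via scipy's linear_sum_assignment\n    res = EMD_CD(sample, ref, batch_size = 4, accelerated_cd = False, accelerated_emd = False)\n    print({k: v.item() for k, v in res.items()})\n"
  },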
  {
    "path": "metrics/evaluation_metrics.py",
    "content": "import torch\nimport numpy as np\nimport warnings\nfrom scipy.stats import entropy\nfrom sklearn.neighbors import NearestNeighbors\nfrom numpy.linalg import norm\nfrom scipy.optimize import linear_sum_assignment\nfrom tqdm import tqdm\n\n\n# Borrow from https://github.com/ThibaultGROUEIX/AtlasNet\ndef distChamfer(a, b):\n    x, y = a, b\n    bs, num_points, points_dim = x.size()\n    xx = torch.bmm(x, x.transpose(2, 1))\n    yy = torch.bmm(y, y.transpose(2, 1))\n    zz = torch.bmm(x, y.transpose(2, 1))\n    diag_ind = torch.arange(0, num_points).to(a).long()\n    rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx)\n    ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy)\n    P = (rx.transpose(2, 1) + ry - 2 * zz)\n    return P.min(1)[0], P.min(2)[0]\n\n# Import CUDA version of approximate EMD, from https://github.com/zekunhao1995/pcgan-pytorch/\nfrom metrics.StructuralLosses.nn_distance import nn_distance\ntry:\n    \n    #from pytorch_structural_losses.nn_distance import nn_distance\n    #print(\"cuda available\")\n    def distChamferCUDA(x, y):\n        return nn_distance(x, y)\nexcept:\n    print(\"distChamferCUDA not available; fall back to slower version.\")\n    # def distChamferCUDA(x, y):\n    #     return distChamfer(x, y)\n\n\ndef emd_approx(x, y):\n    bs, npts, mpts, dim = x.size(0), x.size(1), y.size(1), x.size(2)\n    assert npts == mpts, \"EMD only works if two point clouds are equal size\"\n    dim = x.shape[-1]\n    x = x.reshape(bs, npts, 1, dim)\n    y = y.reshape(bs, 1, mpts, dim)\n    dist = (x - y).norm(dim=-1, keepdim=False)  # (bs, npts, mpts)\n\n    emd_lst = []\n    dist_np = dist.cpu().detach().numpy()\n    for i in range(bs):\n        d_i = dist_np[i]\n        r_idx, c_idx = linear_sum_assignment(d_i)\n        emd_i = d_i[r_idx, c_idx].mean()\n        emd_lst.append(emd_i)\n    emd = np.stack(emd_lst).reshape(-1)\n    emd_torch = torch.from_numpy(emd).to(x)\n    return emd_torch\n\n\ntry:\n    from metrics.StructuralLosses.match_cost import match_cost\n    #print(\"cuda available\")\n    def emd_approx_cuda(sample, ref):\n        B, N, N_ref = sample.size(0), sample.size(1), ref.size(1)\n        assert N == N_ref, \"Not sure what would EMD do in this case\"\n        emd = match_cost(sample, ref)  # (B,)\n        emd_norm = emd / float(N)  # (B,)\n        return emd_norm\nexcept:\n    print(\"emd_approx_cuda not available. 
Fall back to slower version.\")\n    # def emd_approx_cuda(sample, ref):\n        #return emd_approx(sample, ref)\n\n\ndef EMD_CD(sample_pcs, ref_pcs, batch_size, accelerated_cd=False, reduced=True,\n           accelerated_emd=False):\n    N_sample = sample_pcs.shape[0]\n    N_ref = ref_pcs.shape[0]\n    assert N_sample == N_ref, \"REF:%d SMP:%d\" % (N_ref, N_sample)\n\n    cd_lst = []\n    emd_lst = []\n    iterator = range(0, N_sample, batch_size)\n\n    for b_start in iterator:\n        b_end = min(N_sample, b_start + batch_size)\n        sample_batch = sample_pcs[b_start:b_end]\n        ref_batch = ref_pcs[b_start:b_end]\n\n        if accelerated_cd:\n            dl, dr = distChamferCUDA(sample_batch, ref_batch)\n        else:\n            dl, dr = distChamfer(sample_batch, ref_batch)\n        cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1))\n\n        if accelerated_emd:\n            emd_batch = emd_approx_cuda(sample_batch, ref_batch)\n        else:\n            emd_batch = emd_approx(sample_batch, ref_batch)\n        emd_lst.append(emd_batch)\n\n    if reduced:\n        cd = torch.cat(cd_lst).mean()\n        emd = torch.cat(emd_lst).mean()\n    else:\n        cd = torch.cat(cd_lst)\n        emd = torch.cat(emd_lst)\n\n    results = {\n        'MMD-CD': cd,\n        'MMD-EMD': emd,\n    }\n    return results\n\n\ndef _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size=None, accelerated_cd=True,\n                      accelerated_emd=True):\n    batch_size = sample_pcs.shape[0] if batch_size is None else batch_size\n    N_sample = sample_pcs.shape[0]\n    N_ref = ref_pcs.shape[0]\n    all_cd = []\n    all_emd = []\n    iterator = range(N_sample)\n    with tqdm(iterator) as pbar:\n        for sample_b_start in pbar:\n            pbar.set_description(\"Files evaluated: {}/{}\".format(sample_b_start, N_sample))\n            sample_batch = sample_pcs[sample_b_start]\n\n            cd_lst = []\n            emd_lst = []\n            for ref_b_start in range(0, N_ref, batch_size):\n                ref_b_end = min(N_ref, ref_b_start + batch_size)\n                ref_batch = ref_pcs[ref_b_start:ref_b_end]\n\n                batch_size_ref = ref_batch.size(0)\n                sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)\n                sample_batch_exp = sample_batch_exp.contiguous()\n\n                if accelerated_cd and distChamferCUDA is not None:\n                    dl, dr = distChamferCUDA(sample_batch_exp, ref_batch)\n                else:\n                    dl, dr = distChamfer(sample_batch_exp, ref_batch)\n                cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))\n\n                # if accelerated_emd:\n                #     emd_batch = emd_approx_cuda(sample_batch_exp, ref_batch)\n                # else:\n                #     emd_batch = emd_approx(sample_batch_exp, ref_batch)\n                # emd_lst.append(emd_batch.view(1, -1))\n\n            cd_lst = torch.cat(cd_lst, dim=1)\n            #emd_lst = torch.cat(emd_lst, dim=1)\n            all_cd.append(cd_lst)\n            #all_emd.append(emd_lst)\n\n    all_cd = torch.cat(all_cd, dim=0)  # N_sample, N_ref\n    #all_emd = torch.cat(all_emd, dim=0)  # N_sample, N_ref\n\n    return all_cd, 0#all_emd\n\ndef emd_tmd_from_pcs(gen_pcs):\n    sum_dist = 0\n    for j in range(len(gen_pcs)):\n        for k in range(j + 1, len(gen_pcs), 1):\n            pc1 = gen_pcs[j]\n            pc2 = gen_pcs[k]\n            chamfer_dist = emd_approx_cuda(pc1, pc2) #compute_trimesh_chamfer(pc1, 
\n\n\ndef _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size=None, accelerated_cd=True,\n                      accelerated_emd=True):\n    batch_size = sample_pcs.shape[0] if batch_size is None else batch_size\n    N_sample = sample_pcs.shape[0]\n    N_ref = ref_pcs.shape[0]\n    all_cd = []\n    all_emd = []\n    iterator = range(N_sample)\n    with tqdm(iterator) as pbar:\n        for sample_b_start in pbar:\n            pbar.set_description(\"Files evaluated: {}/{}\".format(sample_b_start, N_sample))\n            sample_batch = sample_pcs[sample_b_start]\n\n            cd_lst = []\n            emd_lst = []\n            for ref_b_start in range(0, N_ref, batch_size):\n                ref_b_end = min(N_ref, ref_b_start + batch_size)\n                ref_batch = ref_pcs[ref_b_start:ref_b_end]\n\n                batch_size_ref = ref_batch.size(0)\n                sample_batch_exp = sample_batch.view(1, -1, 3).expand(batch_size_ref, -1, -1)\n                sample_batch_exp = sample_batch_exp.contiguous()\n\n                if accelerated_cd:\n                    dl, dr = distChamferCUDA(sample_batch_exp, ref_batch)\n                else:\n                    dl, dr = distChamfer(sample_batch_exp, ref_batch)\n                cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1))\n\n                if accelerated_emd:\n                    emd_batch = emd_approx_cuda(sample_batch_exp, ref_batch)\n                else:\n                    emd_batch = emd_approx(sample_batch_exp, ref_batch)\n                emd_lst.append(emd_batch.view(1, -1))\n\n            cd_lst = torch.cat(cd_lst, dim=1)\n            emd_lst = torch.cat(emd_lst, dim=1)\n            all_cd.append(cd_lst)\n            all_emd.append(emd_lst)\n\n    all_cd = torch.cat(all_cd, dim=0)  # N_sample, N_ref\n    all_emd = torch.cat(all_emd, dim=0)  # N_sample, N_ref\n\n    return all_cd, all_emd\n\n\ndef emd_tmd_from_pcs(gen_pcs):\n    # TMD: sum over shapes of the average EMD to the other k-1 shapes,\n    # i.e. 2 / (k-1) times the sum over all unordered pairs\n    sum_dist = 0\n    for j in range(len(gen_pcs)):\n        for k in range(j + 1, len(gen_pcs), 1):\n            pc1 = gen_pcs[j]\n            pc2 = gen_pcs[k]\n            emd_dist = emd_approx_cuda(pc1, pc2)\n            print(\"emd dist: \", emd_dist)\n            sum_dist += emd_dist\n    mean_dist = sum_dist * 2 / (len(gen_pcs) - 1)\n    return mean_dist\n\n\n# Adapted from https://github.com/xuqiantong/GAN-Metrics/blob/master/framework/metric.py\ndef knn(Mxx, Mxy, Myy, k, sqrt=False):\n    n0 = Mxx.size(0)\n    n1 = Myy.size(0)\n    label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx)\n    M = torch.cat((torch.cat((Mxx, Mxy), 1), torch.cat((Mxy.transpose(0, 1), Myy), 1)), 0)\n    if sqrt:\n        M = M.abs().sqrt()\n    INFINITY = float('inf')\n    val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk(k, 0, False)\n\n    count = torch.zeros(n0 + n1).to(Mxx)\n    for i in range(0, k):\n        count = count + label.index_select(0, idx[i])\n    pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float()\n\n    s = {\n        'tp': (pred * label).sum(),\n        'fp': (pred * (1 - label)).sum(),\n        'fn': ((1 - pred) * label).sum(),\n        'tn': ((1 - pred) * (1 - label)).sum(),\n    }\n\n    s.update({\n        'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10),\n        'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10),\n        'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10),\n        'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10),\n        'acc': torch.eq(label, pred).float().mean(),\n    })\n    return s\n\n\ndef lgan_mmd_cov(all_dist):\n    # all_dist shape = [number of generated pcs, number of ref pcs]\n    N_sample, N_ref = all_dist.size(0), all_dist.size(1)\n    min_val_fromsmp, min_idx = torch.min(all_dist, dim=1)\n    min_val, _ = torch.min(all_dist, dim=0)\n    mmd = min_val.mean()\n    cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref)\n    cov = torch.tensor(cov).to(all_dist)\n    return {\n        'lgan_mmd': mmd,\n        'lgan_cov': cov,\n    }\n\n\ndef compute_mmd(sample_pcs, ref_pcs, batch_size=None, accelerated_cd=True):\n    batch_size = sample_pcs.shape[0] if batch_size is None else batch_size\n\n    M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)\n\n    res_cd = lgan_mmd_cov(M_rs_cd.t())\n    res_emd = lgan_mmd_cov(M_rs_emd.t())\n\n    return res_cd, res_emd\n\n\ndef compute_cd(sample_pcs, ref_pcs):\n    dl, dr = distChamferCUDA(sample_pcs, ref_pcs)\n    res = dl.mean(dim=1) + dr.mean(dim=1)\n    return res\n\n\ndef compute_all_metrics(sample_pcs, ref_pcs, batch_size=None, accelerated_cd=True):\n    results = {}\n\n    batch_size = sample_pcs.shape[0] if batch_size is None else batch_size\n\n    M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)\n\n    res_cd = lgan_mmd_cov(M_rs_cd.t())\n    results.update({\n        \"%s-CD\" % k: v for k, v in res_cd.items()\n    })\n\n    res_emd = lgan_mmd_cov(M_rs_emd.t())\n    results.update({\n        \"%s-EMD\" % k: v for k, v in res_emd.items()\n    })\n\n    M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size, accelerated_cd=accelerated_cd)\n\n    M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size, accelerated_cd=accelerated_cd)\n
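\n    # 1-NN two-sample test (Lopez-Paz & Oquab): a 1-NN classifier is fit on the\n    # union of reference and generated sets; accuracy near 50% means the two sets\n    # are statistically indistinguishable, near 100% means easily separable.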
\n    one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False)\n    results.update({\n        \"1-NN-CD-%s\" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k\n    })\n    one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False)\n    results.update({\n        \"1-NN-EMD-%s\" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k\n    })\n\n    return res_cd, res_emd, one_nn_cd_res, one_nn_emd_res\n\n\n#######################################################\n# JSD : from https://github.com/optas/latent_3d_points\n#######################################################\ndef unit_cube_grid_point_cloud(resolution, clip_sphere=False):\n    \"\"\"Returns the center coordinates of each cell of a 3D grid with resolution^3 cells,\n    that is placed in the unit-cube.\n    If clip_sphere is True it drops the \"corner\" cells that lie outside the unit-sphere.\n    \"\"\"\n    grid = np.ndarray((resolution, resolution, resolution, 3), np.float32)\n    spacing = 1.0 / float(resolution - 1)\n    for i in range(resolution):\n        for j in range(resolution):\n            for k in range(resolution):\n                grid[i, j, k, 0] = i * spacing - 0.5\n                grid[i, j, k, 1] = j * spacing - 0.5\n                grid[i, j, k, 2] = k * spacing - 0.5\n\n    if clip_sphere:\n        grid = grid.reshape(-1, 3)\n        grid = grid[norm(grid, axis=1) <= 0.5]\n\n    return grid, spacing\n
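\n\n# Minimal usage sketch (assumes numpy arrays of shape (S, N, 3) scaled to fit\n# the unit sphere, matching the occupancy grid below):\n#   jsd = jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28)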
\n\n\ndef jsd_between_point_cloud_sets(sample_pcs, ref_pcs, resolution=28):\n    \"\"\"Computes the JSD between two sets of point-clouds, as introduced in the paper\n    ```Learning Representations And Generative Models For 3D Point Clouds```.\n    Args:\n        sample_pcs: (np.ndarray S1xR1x3) S1 point-clouds, each of R1 points.\n        ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points.\n        resolution: (int) grid-resolution. Affects granularity of measurements.\n    \"\"\"\n    in_unit_sphere = True\n    sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1]\n    ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1]\n    return jensen_shannon_divergence(sample_grid_var, ref_grid_var)\n\n\ndef entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False, verbose=False):\n    \"\"\"Given a collection of point-clouds, estimate the entropy of the random variables\n    corresponding to occupancy-grid activation patterns.\n    Inputs:\n        pclouds: (numpy array) #point-clouds x points per point-cloud x 3\n        grid_resolution (int) size of occupancy grid that will be used.\n    \"\"\"\n    epsilon = 10e-4\n    bound = 0.5 + epsilon\n    if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound:\n        if verbose:\n            warnings.warn('Point-clouds are not in unit cube.')\n\n    if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound:\n        if verbose:\n            warnings.warn('Point-clouds are not in unit sphere.')\n\n    grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere)\n    grid_coordinates = grid_coordinates.reshape(-1, 3)\n    grid_counters = np.zeros(len(grid_coordinates))\n    grid_bernoulli_rvars = np.zeros(len(grid_coordinates))\n    nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates)\n\n    for pc in pclouds:\n        _, indices = nn.kneighbors(pc)\n        indices = np.squeeze(indices)\n        for i in indices:\n            grid_counters[i] += 1\n        indices = np.unique(indices)\n        for i in indices:\n            grid_bernoulli_rvars[i] += 1\n\n    acc_entropy = 0.0\n    n = float(len(pclouds))\n    for g in grid_bernoulli_rvars:\n        if g > 0:\n            p = float(g) / n\n            acc_entropy += entropy([p, 1.0 - p])\n\n    return acc_entropy / len(grid_counters), grid_counters\n\n\ndef jensen_shannon_divergence(P, Q):\n    if np.any(P < 0) or np.any(Q < 0):\n        raise ValueError('Negative values.')\n    if len(P) != len(Q):\n        raise ValueError('Non equal size.')\n\n    P_ = P / np.sum(P)  # Ensure probabilities.\n    Q_ = Q / np.sum(Q)\n\n    e1 = entropy(P_, base=2)\n    e2 = entropy(Q_, base=2)\n    e_sum = entropy((P_ + Q_) / 2.0, base=2)\n    res = e_sum - ((e1 + e2) / 2.0)\n\n    res2 = _jsdiv(P_, Q_)\n\n    if not np.allclose(res, res2, atol=10e-5, rtol=0):\n        warnings.warn('Numerical values of two JSD methods don\'t agree.')\n\n    return res\n\n\ndef _jsdiv(P, Q):\n    \"\"\"Another way of computing JSD, kept as a numerical sanity check.\"\"\"\n\n    def _kldiv(A, B):\n        a = A.copy()\n        b = B.copy()\n        idx = np.logical_and(a > 0, b > 0)\n        a = a[idx]\n        b = b[idx]\n        return np.sum(a * np.log2(a / b))\n\n    P_ = P / np.sum(P)\n    Q_ = Q / np.sum(Q)\n\n    M = 0.5 * (P_ + Q_)\n\n    return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M))\n\n\nif __name__ == \"__main__\":\n    B, N = 100, 2048\n    x = torch.rand(B, N, 3)\n    y = torch.rand(B, N, 3)\n\n    # distChamferCUDA falls back to the pure-PyTorch distChamfer when the\n    # compiled extension is unavailable\n    get_dist = distChamferCUDA\n\n    min_l, min_r = get_dist(x.cuda(), y.cuda())\n    print(min_l.shape)\n    print(min_r.shape)\n\n    l_dist = min_l.mean().cpu().detach().item()\n    r_dist = 
min_r.mean().cpu().detach().item()\n    print(l_dist, r_dist)\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/.gitignore",
    "content": "PyTorchStructuralLosses.egg-info/\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/Makefile",
    "content": "###############################################################################\n# Uncomment for debugging\n# DEBUG := 1\n# Pretty build\n# Q ?= @\n\nCXX := g++\nPYTHON := python\nNVCC := /usr/local/cuda/bin/nvcc\n\n# PYTHON Header path\nPYTHON_HEADER_DIR := $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())')\nPYTORCH_INCLUDES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import include_paths; [print(p) for p in include_paths()]')\nPYTORCH_LIBRARIES := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import library_paths; [print(p) for p in library_paths()]')\n\n# CUDA ROOT DIR that contains bin/ lib64/ and include/\n# CUDA_DIR := /usr/local/cuda\nCUDA_DIR := $(shell $(PYTHON) -c 'from torch.utils.cpp_extension import _find_cuda_home; print(_find_cuda_home())')\n\nINCLUDE_DIRS := ./ $(CUDA_DIR)/include\n\nINCLUDE_DIRS += $(PYTHON_HEADER_DIR)\nINCLUDE_DIRS += $(PYTORCH_INCLUDES)\n\n# Custom (MKL/ATLAS/OpenBLAS) include and lib directories.\n# Leave commented to accept the defaults for your choice of BLAS\n# (which should work)!\n# BLAS_INCLUDE := /path/to/your/blas\n# BLAS_LIB := /path/to/your/blas\n\n###############################################################################\nSRC_DIR := ./src\nOBJ_DIR := ./objs\nCPP_SRCS := $(wildcard $(SRC_DIR)/*.cpp)\nCU_SRCS := $(wildcard $(SRC_DIR)/*.cu)\nOBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(CPP_SRCS))\nCU_OBJS := $(patsubst $(SRC_DIR)/%.cu,$(OBJ_DIR)/cuda/%.o,$(CU_SRCS))\nSTATIC_LIB := $(OBJ_DIR)/libmake_pytorch.a\n\n# CUDA architecture setting: going with all of them.\n# For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility.\n# For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility.\nCUDA_ARCH :=\t-gencode arch=compute_61,code=sm_61 \\\n\t\t-gencode arch=compute_61,code=compute_61 \\\n\t\t-gencode arch=compute_52,code=sm_52\n\n# We will also explicitly add stdc++ to the link target.\nLIBRARIES += stdc++ cudart c10 caffe2 torch torch_python caffe2_gpu\n\n# Debugging\nifeq ($(DEBUG), 1)\n\tCOMMON_FLAGS += -DDEBUG -g -O0\n\t# https://gcoe-dresden.de/reaching-the-shore-with-a-fog-warning-my-eurohack-day-4-morning-session/\n\tNVCCFLAGS += -g -G # -rdc true\nelse\n\tCOMMON_FLAGS += -DNDEBUG -O3\nendif\n\nWARNINGS := -Wall -Wno-sign-compare -Wcomment\n\nINCLUDE_DIRS += $(BLAS_INCLUDE)\n\n# Automatic dependency generation (nvcc is handled separately)\nCXXFLAGS += -std=c++14 -MMD -MP\n\n# Complete build flags.\nCOMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \\\n\t     -DTORCH_API_INCLUDE_EXTENSION_H -D_GLIBCXX_USE_CXX11_ABI=0\nCXXFLAGS += -pthread -fPIC -fwrapv -std=c++14 $(COMMON_FLAGS) $(WARNINGS)\nNVCCFLAGS += -std=c++14 -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)\n\nall: $(STATIC_LIB)\n\t$(PYTHON) setup.py build\n\t@ mv build/lib.linux-x86_64-cpython-39/StructuralLosses ..\n\t#@ mv build/lib.linux-x86_64-3.6/StructuralLosses ..\n\t@ mv build/lib.linux-x86_64-cpython-39/*.so ../StructuralLosses/\n\t#@ mv build/lib.linux-x86_64-3.6/*.so ../StructuralLosses/\n\t@- $(RM) -rf $(OBJ_DIR) build objs\n\n$(OBJ_DIR):\n\t@ mkdir -p $@\n\t@ mkdir -p $@/cuda\n\n$(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp | $(OBJ_DIR)\n\t@ echo CXX $<\n\t$(Q)$(CXX) $< $(CXXFLAGS) -c -o $@\n\n$(OBJ_DIR)/cuda/%.o: $(SRC_DIR)/%.cu | $(OBJ_DIR)\n\t@ echo NVCC $<\n\t$(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \\\n\t\t-odir $(@D)\n\t$(Q)$(NVCC) $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@\n\n$(STATIC_LIB): $(OBJS) $(CU_OBJS) | $(OBJ_DIR)\n\t$(RM) -f 
$(STATIC_LIB)\n\t$(RM) -rf build dist\n\t@ echo LD -o $@\n\tar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS)\n\nclean:\n\t@- $(RM) -rf $(OBJ_DIR) build dist ../StructuralLosses\n\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/StructuralLosses/__init__.py",
    "content": "#import torch\n\n#from MakePytorchBackend import AddGPU, Foo, ApproxMatch\n\n#from Add import add_gpu, approx_match\n\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/StructuralLosses/match_cost.py",
    "content": "import torch\nfrom torch.autograd import Function\nfrom metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad\n\n# Inherit from Function\nclass MatchCostFunction(Function):\n    # Note that both forward and backward are @staticmethods\n    @staticmethod\n    # bias is an optional argument\n    def forward(ctx, seta, setb):\n        #print(\"Match Cost Forward\")\n        ctx.save_for_backward(seta, setb)\n        '''\n        input:\n\t        set1 : batch_size * #dataset_points * 3\n\t        set2 : batch_size * #query_points * 3\n        returns:\n\t        match : batch_size * #query_points * #dataset_points\n        '''\n        match, temp = ApproxMatch(seta, setb)\n        ctx.match = match\n        cost = MatchCost(seta, setb, match)\n        return cost\n\n    \"\"\"\n    grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match)\n\treturn [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None]\n\t\"\"\"\n    # This function has only a single output, so it gets only one gradient\n    @staticmethod\n    def backward(ctx, grad_output):\n        #print(\"Match Cost Backward\")\n        # This is a pattern that is very convenient - at the top of backward\n        # unpack saved_tensors and initialize all gradients w.r.t. inputs to\n        # None. Thanks to the fact that additional trailing Nones are\n        # ignored, the return statement is simple even when the function has\n        # optional inputs.\n        seta, setb = ctx.saved_tensors\n        #grad_input = grad_weight = grad_bias = None\n        grada, gradb = MatchCostGrad(seta, setb, ctx.match)\n        grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2)\n        return grada*grad_output_expand, gradb*grad_output_expand\n\nmatch_cost = MatchCostFunction.apply\n\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/StructuralLosses/nn_distance.py",
    "content": "import torch\nfrom torch.autograd import Function\n# from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad\nfrom metrics.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad\n\n# Inherit from Function\nclass NNDistanceFunction(Function):\n    # Note that both forward and backward are @staticmethods\n    @staticmethod\n    # bias is an optional argument\n    def forward(ctx, seta, setb):\n        #print(\"Match Cost Forward\")\n        ctx.save_for_backward(seta, setb)\n        '''\n        input:\n\t        set1 : batch_size * #dataset_points * 3\n\t        set2 : batch_size * #query_points * 3\n        returns:\n\t        dist1, idx1, dist2, idx2\n        '''\n        dist1, idx1, dist2, idx2 = NNDistance(seta, setb)\n        ctx.idx1 = idx1\n        ctx.idx2 = idx2\n        return dist1, dist2\n\n    # This function has only a single output, so it gets only one gradient\n    @staticmethod\n    def backward(ctx, grad_dist1, grad_dist2):\n        #print(\"Match Cost Backward\")\n        # This is a pattern that is very convenient - at the top of backward\n        # unpack saved_tensors and initialize all gradients w.r.t. inputs to\n        # None. Thanks to the fact that additional trailing Nones are\n        # ignored, the return statement is simple even when the function has\n        # optional inputs.\n        seta, setb = ctx.saved_tensors\n        idx1 = ctx.idx1\n        idx2 = ctx.idx2\n        grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2)\n        return grada, gradb\n\nnn_distance = NNDistanceFunction.apply\n\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/__init__.py",
    "content": "#import torch\n\n#from MakePytorchBackend import AddGPU, Foo, ApproxMatch\n\n#from Add import add_gpu, approx_match\n\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/match_cost.py",
    "content": "import torch\nfrom torch.autograd import Function\nfrom metrics.StructuralLosses.StructuralLossesBackend import ApproxMatch, MatchCost, MatchCostGrad\n\n# Inherit from Function\nclass MatchCostFunction(Function):\n    # Note that both forward and backward are @staticmethods\n    @staticmethod\n    # bias is an optional argument\n    def forward(ctx, seta, setb):\n        #print(\"Match Cost Forward\")\n        ctx.save_for_backward(seta, setb)\n        '''\n        input:\n\t        set1 : batch_size * #dataset_points * 3\n\t        set2 : batch_size * #query_points * 3\n        returns:\n\t        match : batch_size * #query_points * #dataset_points\n        '''\n        match, temp = ApproxMatch(seta, setb)\n        ctx.match = match\n        cost = MatchCost(seta, setb, match)\n        return cost\n\n    \"\"\"\n    grad_1,grad_2=approxmatch_module.match_cost_grad(xyz1,xyz2,match)\n\treturn [grad_1*tf.expand_dims(tf.expand_dims(grad_cost,1),2),grad_2*tf.expand_dims(tf.expand_dims(grad_cost,1),2),None]\n\t\"\"\"\n    # This function has only a single output, so it gets only one gradient\n    @staticmethod\n    def backward(ctx, grad_output):\n        #print(\"Match Cost Backward\")\n        # This is a pattern that is very convenient - at the top of backward\n        # unpack saved_tensors and initialize all gradients w.r.t. inputs to\n        # None. Thanks to the fact that additional trailing Nones are\n        # ignored, the return statement is simple even when the function has\n        # optional inputs.\n        seta, setb = ctx.saved_tensors\n        #grad_input = grad_weight = grad_bias = None\n        grada, gradb = MatchCostGrad(seta, setb, ctx.match)\n        grad_output_expand = grad_output.unsqueeze(1).unsqueeze(2)\n        return grada*grad_output_expand, gradb*grad_output_expand\n\nmatch_cost = MatchCostFunction.apply\n\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/nn_distance.py",
    "content": "import torch\nfrom torch.autograd import Function\n# from extensions.StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad\nfrom .StructuralLosses.StructuralLossesBackend import NNDistance, NNDistanceGrad\n\n# Inherit from Function\nclass NNDistanceFunction(Function):\n    # Note that both forward and backward are @staticmethods\n    @staticmethod\n    # bias is an optional argument\n    def forward(ctx, seta, setb):\n        #print(\"Match Cost Forward\")\n        ctx.save_for_backward(seta, setb)\n        '''\n        input:\n\t        set1 : batch_size * #dataset_points * 3\n\t        set2 : batch_size * #query_points * 3\n        returns:\n\t        dist1, idx1, dist2, idx2\n        '''\n        dist1, idx1, dist2, idx2 = NNDistance(seta, setb)\n        ctx.idx1 = idx1\n        ctx.idx2 = idx2\n        return dist1, dist2\n\n    # This function has only a single output, so it gets only one gradient\n    @staticmethod\n    def backward(ctx, grad_dist1, grad_dist2):\n        #print(\"Match Cost Backward\")\n        # This is a pattern that is very convenient - at the top of backward\n        # unpack saved_tensors and initialize all gradients w.r.t. inputs to\n        # None. Thanks to the fact that additional trailing Nones are\n        # ignored, the return statement is simple even when the function has\n        # optional inputs.\n        seta, setb = ctx.saved_tensors\n        idx1 = ctx.idx1\n        idx2 = ctx.idx2\n        grada, gradb = NNDistanceGrad(seta, setb, idx1, idx2, grad_dist1, grad_dist2)\n        return grada, gradb\n\nnn_distance = NNDistanceFunction.apply\n\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/pybind/bind.cpp",
    "content": "#include <string>\n\n#include <torch/extension.h>\n\n#include \"pybind/extern.hpp\"\n\nnamespace py = pybind11;\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m){\n  m.def(\"ApproxMatch\", &ApproxMatch);\n  m.def(\"MatchCost\", &MatchCost);\n  m.def(\"MatchCostGrad\", &MatchCostGrad);\n  m.def(\"NNDistance\", &NNDistance);\n  m.def(\"NNDistanceGrad\", &NNDistanceGrad);\n}\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/pybind/extern.hpp",
    "content": "std::vector<at::Tensor> ApproxMatch(at::Tensor in_a, at::Tensor in_b);\nat::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match);\nstd::vector<at::Tensor> MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match);\n\nstd::vector<at::Tensor> NNDistance(at::Tensor set_d, at::Tensor set_q);\nstd::vector<at::Tensor> NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2);\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/setup.py",
    "content": "from setuptools import setup\nfrom torch.utils.cpp_extension import CUDAExtension, BuildExtension\n\n# Python interface\nsetup(\n    name='PyTorchStructuralLosses',\n    version='0.1.0',\n    install_requires=['torch'],\n    packages=['StructuralLosses'],\n    package_dir={'StructuralLosses': './'},\n    ext_modules=[\n        CUDAExtension(\n            name='StructuralLossesBackend',\n            include_dirs=['./'],\n            sources=[\n                'pybind/bind.cpp',\n            ],\n            libraries=['make_pytorch'],\n            library_dirs=['objs'],\n            # extra_compile_args=['-g']\n        )\n    ],\n    cmdclass={'build_ext': BuildExtension},\n    author='Christopher B. Choy',\n    author_email='chrischoy@ai.stanford.edu',\n    description='Tutorial for Pytorch C++ Extension with a Makefile',\n    keywords='Pytorch C++ Extension',\n    url='https://github.com/chrischoy/MakePytorchPlusPlus',\n    zip_safe=False,\n)\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/src/approxmatch.cu",
    "content": "#include \"utils.hpp\"\n\n__global__ void approxmatchkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){\n\tfloat * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n;\n\tfloat multiL,multiR;\n\tif (n>=m){\n\t\tmultiL=1;\n\t\tmultiR=n/m;\n\t}else{\n\t\tmultiL=m/n;\n\t\tmultiR=1;\n\t}\n\tconst int Block=1024;\n\t__shared__ float buf[Block*4];\n\tfor (int i=blockIdx.x;i<b;i+=gridDim.x){\n\t\tfor (int j=threadIdx.x;j<n*m;j+=blockDim.x)\n\t\t\tmatch[i*n*m+j]=0;\n\t\tfor (int j=threadIdx.x;j<n;j+=blockDim.x)\n\t\t\tremainL[j]=multiL;\n\t\tfor (int j=threadIdx.x;j<m;j+=blockDim.x)\n\t\t\tremainR[j]=multiR;\n\t\t__syncthreads();\n\t\t//for (int j=7;j>=-2;j--){\n\t\tfor (int j=7;j>-2;j--){\n\t\t\tfloat level=-powf(4.0f,j);\n\t\t\tif (j==-2){\n\t\t\t\tlevel=0;\n\t\t\t}\n\t\t\tfor (int k0=0;k0<n;k0+=blockDim.x){\n\t\t\t\tint k=k0+threadIdx.x;\n\t\t\t\tfloat x1=0,y1=0,z1=0;\n\t\t\t\tif (k<n){\n\t\t\t\t\tx1=xyz1[i*n*3+k*3+0];\n\t\t\t\t\ty1=xyz1[i*n*3+k*3+1];\n\t\t\t\t\tz1=xyz1[i*n*3+k*3+2];\n\t\t\t\t}\n\t\t\t\tfloat suml=1e-9f;\n\t\t\t\tfor (int l0=0;l0<m;l0+=Block){\n\t\t\t\t\tint lend=min(m,l0+Block)-l0;\n\t\t\t\t\tfor (int l=threadIdx.x;l<lend;l+=blockDim.x){\n\t\t\t\t\t\tfloat x2=xyz2[i*m*3+l0*3+l*3+0];\n\t\t\t\t\t\tfloat y2=xyz2[i*m*3+l0*3+l*3+1];\n\t\t\t\t\t\tfloat z2=xyz2[i*m*3+l0*3+l*3+2];\n\t\t\t\t\t\tbuf[l*4+0]=x2;\n\t\t\t\t\t\tbuf[l*4+1]=y2;\n\t\t\t\t\t\tbuf[l*4+2]=z2;\n\t\t\t\t\t\tbuf[l*4+3]=remainR[l0+l];\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t\tfor (int l=0;l<lend;l++){\n\t\t\t\t\t\tfloat x2=buf[l*4+0];\n\t\t\t\t\t\tfloat y2=buf[l*4+1];\n\t\t\t\t\t\tfloat z2=buf[l*4+2];\n\t\t\t\t\t\tfloat d=level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1));\n\t\t\t\t\t\tfloat w=__expf(d)*buf[l*4+3];\n\t\t\t\t\t\tsuml+=w;\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t}\n\t\t\t\tif (k<n)\n\t\t\t\t\tratioL[k]=remainL[k]/suml;\n\t\t\t}\n\t\t\t/*for (int k=threadIdx.x;k<n;k+=gridDim.x){\n\t\t\t\tfloat x1=xyz1[i*n*3+k*3+0];\n\t\t\t\tfloat y1=xyz1[i*n*3+k*3+1];\n\t\t\t\tfloat z1=xyz1[i*n*3+k*3+2];\n\t\t\t\tfloat suml=1e-9f;\n\t\t\t\tfor (int l=0;l<m;l++){\n\t\t\t\t\tfloat x2=xyz2[i*m*3+l*3+0];\n\t\t\t\t\tfloat y2=xyz2[i*m*3+l*3+1];\n\t\t\t\t\tfloat z2=xyz2[i*m*3+l*3+2];\n\t\t\t\t\tfloat w=expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*remainR[l];\n\t\t\t\t\tsuml+=w;\n\t\t\t\t}\n\t\t\t\tratioL[k]=remainL[k]/suml;\n\t\t\t}*/\n\t\t\t__syncthreads();\n\t\t\tfor (int l0=0;l0<m;l0+=blockDim.x){\n\t\t\t\tint l=l0+threadIdx.x;\n\t\t\t\tfloat x2=0,y2=0,z2=0;\n\t\t\t\tif (l<m){\n\t\t\t\t\tx2=xyz2[i*m*3+l*3+0];\n\t\t\t\t\ty2=xyz2[i*m*3+l*3+1];\n\t\t\t\t\tz2=xyz2[i*m*3+l*3+2];\n\t\t\t\t}\n\t\t\t\tfloat sumr=0;\n\t\t\t\tfor (int k0=0;k0<n;k0+=Block){\n\t\t\t\t\tint kend=min(n,k0+Block)-k0;\n\t\t\t\t\tfor (int k=threadIdx.x;k<kend;k+=blockDim.x){\n\t\t\t\t\t\tbuf[k*4+0]=xyz1[i*n*3+k0*3+k*3+0];\n\t\t\t\t\t\tbuf[k*4+1]=xyz1[i*n*3+k0*3+k*3+1];\n\t\t\t\t\t\tbuf[k*4+2]=xyz1[i*n*3+k0*3+k*3+2];\n\t\t\t\t\t\tbuf[k*4+3]=ratioL[k0+k];\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t\tfor (int k=0;k<kend;k++){\n\t\t\t\t\t\tfloat x1=buf[k*4+0];\n\t\t\t\t\t\tfloat y1=buf[k*4+1];\n\t\t\t\t\t\tfloat z1=buf[k*4+2];\n\t\t\t\t\t\tfloat w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*buf[k*4+3];\n\t\t\t\t\t\tsumr+=w;\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t}\n\t\t\t\tif 
(l<m){\n\t\t\t\t\tsumr*=remainR[l];\n\t\t\t\t\tfloat consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);\n\t\t\t\t\tratioR[l]=consumption*remainR[l];\n\t\t\t\t\tremainR[l]=fmaxf(0.0f,remainR[l]-sumr);\n\t\t\t\t}\n\t\t\t}\n\t\t\t/*for (int l=threadIdx.x;l<m;l+=blockDim.x){\n\t\t\t\tfloat x2=xyz2[i*m*3+l*3+0];\n\t\t\t\tfloat y2=xyz2[i*m*3+l*3+1];\n\t\t\t\tfloat z2=xyz2[i*m*3+l*3+2];\n\t\t\t\tfloat sumr=0;\n\t\t\t\tfor (int k=0;k<n;k++){\n\t\t\t\t\tfloat x1=xyz1[i*n*3+k*3+0];\n\t\t\t\t\tfloat y1=xyz1[i*n*3+k*3+1];\n\t\t\t\t\tfloat z1=xyz1[i*n*3+k*3+2];\n\t\t\t\t\tfloat w=expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*ratioL[k];\n\t\t\t\t\tsumr+=w;\n\t\t\t\t}\n\t\t\t\tsumr*=remainR[l];\n\t\t\t\tfloat consumption=fminf(remainR[l]/(sumr+1e-9f),1.0f);\n\t\t\t\tratioR[l]=consumption*remainR[l];\n\t\t\t\tremainR[l]=fmaxf(0.0f,remainR[l]-sumr);\n\t\t\t}*/\n\t\t\t__syncthreads();\n\t\t\tfor (int k0=0;k0<n;k0+=blockDim.x){\n\t\t\t\tint k=k0+threadIdx.x;\n\t\t\t\tfloat x1=0,y1=0,z1=0;\n\t\t\t\tif (k<n){\n\t\t\t\t\tx1=xyz1[i*n*3+k*3+0];\n\t\t\t\t\ty1=xyz1[i*n*3+k*3+1];\n\t\t\t\t\tz1=xyz1[i*n*3+k*3+2];\n\t\t\t\t}\n\t\t\t\tfloat suml=0;\n\t\t\t\tfor (int l0=0;l0<m;l0+=Block){\n\t\t\t\t\tint lend=min(m,l0+Block)-l0;\n\t\t\t\t\tfor (int l=threadIdx.x;l<lend;l+=blockDim.x){\n\t\t\t\t\t\tbuf[l*4+0]=xyz2[i*m*3+l0*3+l*3+0];\n\t\t\t\t\t\tbuf[l*4+1]=xyz2[i*m*3+l0*3+l*3+1];\n\t\t\t\t\t\tbuf[l*4+2]=xyz2[i*m*3+l0*3+l*3+2];\n\t\t\t\t\t\tbuf[l*4+3]=ratioR[l0+l];\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t\tfloat rl=ratioL[k];\n\t\t\t\t\tif (k<n){\n\t\t\t\t\t\tfor (int l=0;l<lend;l++){\n\t\t\t\t\t\t\tfloat x2=buf[l*4+0];\n\t\t\t\t\t\t\tfloat y2=buf[l*4+1];\n\t\t\t\t\t\t\tfloat z2=buf[l*4+2];\n\t\t\t\t\t\t\tfloat w=__expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*rl*buf[l*4+3];\n\t\t\t\t\t\t\tmatch[i*n*m+(l0+l)*n+k]+=w;\n\t\t\t\t\t\t\tsuml+=w;\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t\t__syncthreads();\n\t\t\t\t}\n\t\t\t\tif (k<n)\n\t\t\t\t\tremainL[k]=fmaxf(0.0f,remainL[k]-suml);\n\t\t\t}\n\t\t\t/*for (int k=threadIdx.x;k<n;k+=blockDim.x){\n\t\t\t\tfloat x1=xyz1[i*n*3+k*3+0];\n\t\t\t\tfloat y1=xyz1[i*n*3+k*3+1];\n\t\t\t\tfloat z1=xyz1[i*n*3+k*3+2];\n\t\t\t\tfloat suml=0;\n\t\t\t\tfor (int l=0;l<m;l++){\n\t\t\t\t\tfloat x2=xyz2[i*m*3+l*3+0];\n\t\t\t\t\tfloat y2=xyz2[i*m*3+l*3+1];\n\t\t\t\t\tfloat z2=xyz2[i*m*3+l*3+2];\n\t\t\t\t\tfloat w=expf(level*((x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1)))*ratioL[k]*ratioR[l];\n\t\t\t\t\tmatch[i*n*m+l*n+k]+=w;\n\t\t\t\t\tsuml+=w;\n\t\t\t\t}\n\t\t\t\tremainL[k]=fmaxf(0.0f,remainL[k]-suml);\n\t\t\t}*/\n\t\t\t__syncthreads();\n\t\t}\n\t}\n}\n\n__global__ void matchcostkernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ out){\n\t__shared__ float allsum[512];\n\tconst int Block=256;\n\t__shared__ float buf[Block*3];\n\tfor (int i=blockIdx.x;i<b;i+=gridDim.x){\n\t\tfloat subsum=0;\n\t\tfor (int k0=0;k0<m;k0+=Block){\n\t\t\tint endk=min(m,k0+Block);\n\t\t\tfor (int k=threadIdx.x;k<(endk-k0)*3;k+=blockDim.x){\n\t\t\t\tbuf[k]=xyz2[i*m*3+k0*3+k];\n\t\t\t}\n\t\t\t__syncthreads();\n\t\t\tfor (int j=threadIdx.x;j<n;j+=blockDim.x){\n\t\t\t\tfloat x1=xyz1[(i*n+j)*3+0];\n\t\t\t\tfloat y1=xyz1[(i*n+j)*3+1];\n\t\t\t\tfloat z1=xyz1[(i*n+j)*3+2];\n\t\t\t\tfor (int k=0;k<endk-k0;k++){\n\t\t\t\t\t//float x2=xyz2[(i*m+k)*3+0]-x1;\n\t\t\t\t\t//float y2=xyz2[(i*m+k)*3+1]-y1;\n\t\t\t\t\t//float z2=xyz2[(i*m+k)*3+2]-z1;\n\t\t\t\t\tfloat x2=buf[k*3+0]-x1;\n\t\t\t\t\tfloat 
y2=buf[k*3+1]-y1;\n\t\t\t\t\tfloat z2=buf[k*3+2]-z1;\n\t\t\t\t\tfloat d=sqrtf(x2*x2+y2*y2+z2*z2);\n\t\t\t\t\tsubsum+=match[i*n*m+(k0+k)*n+j]*d;\n\t\t\t\t}\n\t\t\t}\n\t\t\t__syncthreads();\n\t\t}\n\t\tallsum[threadIdx.x]=subsum;\n\t\tfor (int j=1;j<blockDim.x;j<<=1){\n\t\t\t__syncthreads();\n\t\t\tif ((threadIdx.x&j)==0 && threadIdx.x+j<blockDim.x){\n\t\t\t\tallsum[threadIdx.x]+=allsum[threadIdx.x+j];\n\t\t\t}\n\t\t}\n\t\tif (threadIdx.x==0)\n\t\t\tout[i]=allsum[0];\n\t\t__syncthreads();\n\t}\n}\n//void matchcostLauncher(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * out){\n//\tmatchcost<<<32,512>>>(b,n,m,xyz1,xyz2,match,out);\n//}\n\n__global__ void matchcostgrad2kernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad2){\n\t__shared__ float sum_grad[256*3];\n\tfor (int i=blockIdx.x;i<b;i+=gridDim.x){\n\t\tint kbeg=m*blockIdx.y/gridDim.y;\n\t\tint kend=m*(blockIdx.y+1)/gridDim.y;\n\t\tfor (int k=kbeg;k<kend;k++){\n\t\t\tfloat x2=xyz2[(i*m+k)*3+0];\n\t\t\tfloat y2=xyz2[(i*m+k)*3+1];\n\t\t\tfloat z2=xyz2[(i*m+k)*3+2];\n\t\t\tfloat subsumx=0,subsumy=0,subsumz=0;\n\t\t\tfor (int j=threadIdx.x;j<n;j+=blockDim.x){\n\t\t\t\tfloat x1=x2-xyz1[(i*n+j)*3+0];\n\t\t\t\tfloat y1=y2-xyz1[(i*n+j)*3+1];\n\t\t\t\tfloat z1=z2-xyz1[(i*n+j)*3+2];\n\t\t\t\tfloat d=match[i*n*m+k*n+j]*rsqrtf(fmaxf(x1*x1+y1*y1+z1*z1,1e-20f));\n\t\t\t\tsubsumx+=x1*d;\n\t\t\t\tsubsumy+=y1*d;\n\t\t\t\tsubsumz+=z1*d;\n\t\t\t}\n\t\t\tsum_grad[threadIdx.x*3+0]=subsumx;\n\t\t\tsum_grad[threadIdx.x*3+1]=subsumy;\n\t\t\tsum_grad[threadIdx.x*3+2]=subsumz;\n\t\t\tfor (int j=1;j<blockDim.x;j<<=1){\n\t\t\t\t__syncthreads();\n\t\t\t\tint j1=threadIdx.x;\n\t\t\t\tint j2=threadIdx.x+j;\n\t\t\t\tif ((j1&j)==0 && j2<blockDim.x){\n\t\t\t\t\tsum_grad[j1*3+0]+=sum_grad[j2*3+0];\n\t\t\t\t\tsum_grad[j1*3+1]+=sum_grad[j2*3+1];\n\t\t\t\t\tsum_grad[j1*3+2]+=sum_grad[j2*3+2];\n\t\t\t\t}\n\t\t\t}\n\t\t\tif (threadIdx.x==0){\n\t\t\t\tgrad2[(i*m+k)*3+0]=sum_grad[0];\n\t\t\t\tgrad2[(i*m+k)*3+1]=sum_grad[1];\n\t\t\t\tgrad2[(i*m+k)*3+2]=sum_grad[2];\n\t\t\t}\n\t\t\t__syncthreads();\n\t\t}\n\t}\n}\n__global__ void matchcostgrad1kernel(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad1){\n\tfor (int i=blockIdx.x;i<b;i+=gridDim.x){\n\t\tfor (int l=threadIdx.x;l<n;l+=blockDim.x){\n\t\t\tfloat x1=xyz1[i*n*3+l*3+0];\n\t\t\tfloat y1=xyz1[i*n*3+l*3+1];\n\t\t\tfloat z1=xyz1[i*n*3+l*3+2];\n\t\t\tfloat dx=0,dy=0,dz=0;\n\t\t\tfor (int k=0;k<m;k++){\n\t\t\t\tfloat x2=xyz2[i*m*3+k*3+0];\n\t\t\t\tfloat y2=xyz2[i*m*3+k*3+1];\n\t\t\t\tfloat z2=xyz2[i*m*3+k*3+2];\n\t\t\t\tfloat d=match[i*n*m+k*n+l]*rsqrtf(fmaxf((x1-x2)*(x1-x2)+(y1-y2)*(y1-y2)+(z1-z2)*(z1-z2),1e-20f));\n\t\t\t\tdx+=(x1-x2)*d;\n\t\t\t\tdy+=(y1-y2)*d;\n\t\t\t\tdz+=(z1-z2)*d;\n\t\t\t}\n\t\t\tgrad1[i*n*3+l*3+0]=dx;\n\t\t\tgrad1[i*n*3+l*3+1]=dy;\n\t\t\tgrad1[i*n*3+l*3+2]=dz;\n\t\t}\n\t}\n}\n//void matchcostgradLauncher(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad2){\n//\tmatchcostgrad<<<dim3(32,32),256>>>(b,n,m,xyz1,xyz2,match,grad2);\n//}\n\n/*void AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,\n                  cudaStream_t stream)*/\n// temp: TensorShape{b,(n+m)*2}\nvoid approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream){\n\tapproxmatchkernel\n      <<<32, 512, 0, 
stream>>>(b,n,m,xyz1,xyz2,match,temp);\n      \n  cudaError_t err = cudaGetLastError();\n  if (cudaSuccess != err)\n    throw std::runtime_error(Formatter()\n                             << \"CUDA kernel failed : \" << std::to_string(err));\n}\n\nvoid matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream){\n\tmatchcostkernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,out);\n      \n  cudaError_t err = cudaGetLastError();\n  if (cudaSuccess != err)\n    throw std::runtime_error(Formatter()\n                             << \"CUDA kernel failed : \" << std::to_string(err));\n}\n\nvoid matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream){\n\tmatchcostgrad1kernel<<<32,512,0,stream>>>(b,n,m,xyz1,xyz2,match,grad1);\n\tmatchcostgrad2kernel<<<dim3(32,32),256,0,stream>>>(b,n,m,xyz1,xyz2,match,grad2);\n\t\n    cudaError_t err = cudaGetLastError();\n    if (cudaSuccess != err)\n        throw std::runtime_error(Formatter()\n                             << \"CUDA kernel failed : \" << std::to_string(err));\n}\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/src/approxmatch.cuh",
    "content": "/*\ntemplate <typename Dtype>\nvoid AddGPUKernel(Dtype *in_a, Dtype *in_b, Dtype *out_c, int N,\n                  cudaStream_t stream);\n*/\nvoid approxmatch(int b,int n,int m,const float * xyz1,const float * xyz2,float * match,float * temp, cudaStream_t stream);\nvoid matchcost(int b,int n,int m,const float * xyz1,const float * xyz2,float * match, float * out, cudaStream_t stream);\nvoid matchcostgrad(int b,int n,int m,const float * xyz1,const float * xyz2,const float * match,float * grad1,float * grad2, cudaStream_t stream);\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/src/nndistance.cu",
    "content": "\n__global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){\n\tconst int batch=512;\n\t__shared__ float buf[batch*3];\n\tfor (int i=blockIdx.x;i<b;i+=gridDim.x){\n\t\tfor (int k2=0;k2<m;k2+=batch){\n\t\t\tint end_k=min(m,k2+batch)-k2;\n\t\t\tfor (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){\n\t\t\t\tbuf[j]=xyz2[(i*m+k2)*3+j];\n\t\t\t}\n\t\t\t__syncthreads();\n\t\t\tfor (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){\n\t\t\t\tfloat x1=xyz[(i*n+j)*3+0];\n\t\t\t\tfloat y1=xyz[(i*n+j)*3+1];\n\t\t\t\tfloat z1=xyz[(i*n+j)*3+2];\n\t\t\t\tint best_i=0;\n\t\t\t\tfloat best=0;\n\t\t\t\tint end_ka=end_k-(end_k&3);\n\t\t\t\tif (end_ka==batch){\n\t\t\t\t\tfor (int k=0;k<batch;k+=4){\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\tfloat x2=buf[k*3+0]-x1;\n\t\t\t\t\t\t\tfloat y2=buf[k*3+1]-y1;\n\t\t\t\t\t\t\tfloat z2=buf[k*3+2]-z1;\n\t\t\t\t\t\t\tfloat d=x2*x2+y2*y2+z2*z2;\n\t\t\t\t\t\t\tif (k==0 || d<best){\n\t\t\t\t\t\t\t\tbest=d;\n\t\t\t\t\t\t\t\tbest_i=k+k2;\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\tfloat x2=buf[k*3+3]-x1;\n\t\t\t\t\t\t\tfloat y2=buf[k*3+4]-y1;\n\t\t\t\t\t\t\tfloat z2=buf[k*3+5]-z1;\n\t\t\t\t\t\t\tfloat d=x2*x2+y2*y2+z2*z2;\n\t\t\t\t\t\t\tif (d<best){\n\t\t\t\t\t\t\t\tbest=d;\n\t\t\t\t\t\t\t\tbest_i=k+k2+1;\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\tfloat x2=buf[k*3+6]-x1;\n\t\t\t\t\t\t\tfloat y2=buf[k*3+7]-y1;\n\t\t\t\t\t\t\tfloat z2=buf[k*3+8]-z1;\n\t\t\t\t\t\t\tfloat d=x2*x2+y2*y2+z2*z2;\n\t\t\t\t\t\t\tif (d<best){\n\t\t\t\t\t\t\t\tbest=d;\n\t\t\t\t\t\t\t\tbest_i=k+k2+2;\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\tfloat x2=buf[k*3+9]-x1;\n\t\t\t\t\t\t\tfloat y2=buf[k*3+10]-y1;\n\t\t\t\t\t\t\tfloat z2=buf[k*3+11]-z1;\n\t\t\t\t\t\t\tfloat d=x2*x2+y2*y2+z2*z2;\n\t\t\t\t\t\t\tif (d<best){\n\t\t\t\t\t\t\t\tbest=d;\n\t\t\t\t\t\t\t\tbest_i=k+k2+3;\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}else{\n\t\t\t\t\tfor (int k=0;k<end_ka;k+=4){\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\tfloat x2=buf[k*3+0]-x1;\n\t\t\t\t\t\t\tfloat y2=buf[k*3+1]-y1;\n\t\t\t\t\t\t\tfloat z2=buf[k*3+2]-z1;\n\t\t\t\t\t\t\tfloat d=x2*x2+y2*y2+z2*z2;\n\t\t\t\t\t\t\tif (k==0 || d<best){\n\t\t\t\t\t\t\t\tbest=d;\n\t\t\t\t\t\t\t\tbest_i=k+k2;\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\tfloat x2=buf[k*3+3]-x1;\n\t\t\t\t\t\t\tfloat y2=buf[k*3+4]-y1;\n\t\t\t\t\t\t\tfloat z2=buf[k*3+5]-z1;\n\t\t\t\t\t\t\tfloat d=x2*x2+y2*y2+z2*z2;\n\t\t\t\t\t\t\tif (d<best){\n\t\t\t\t\t\t\t\tbest=d;\n\t\t\t\t\t\t\t\tbest_i=k+k2+1;\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\tfloat x2=buf[k*3+6]-x1;\n\t\t\t\t\t\t\tfloat y2=buf[k*3+7]-y1;\n\t\t\t\t\t\t\tfloat z2=buf[k*3+8]-z1;\n\t\t\t\t\t\t\tfloat d=x2*x2+y2*y2+z2*z2;\n\t\t\t\t\t\t\tif (d<best){\n\t\t\t\t\t\t\t\tbest=d;\n\t\t\t\t\t\t\t\tbest_i=k+k2+2;\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t\t{\n\t\t\t\t\t\t\tfloat x2=buf[k*3+9]-x1;\n\t\t\t\t\t\t\tfloat y2=buf[k*3+10]-y1;\n\t\t\t\t\t\t\tfloat z2=buf[k*3+11]-z1;\n\t\t\t\t\t\t\tfloat d=x2*x2+y2*y2+z2*z2;\n\t\t\t\t\t\t\tif (d<best){\n\t\t\t\t\t\t\t\tbest=d;\n\t\t\t\t\t\t\t\tbest_i=k+k2+3;\n\t\t\t\t\t\t\t}\n\t\t\t\t\t\t}\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\tfor (int k=end_ka;k<end_k;k++){\n\t\t\t\t\tfloat x2=buf[k*3+0]-x1;\n\t\t\t\t\tfloat y2=buf[k*3+1]-y1;\n\t\t\t\t\tfloat z2=buf[k*3+2]-z1;\n\t\t\t\t\tfloat d=x2*x2+y2*y2+z2*z2;\n\t\t\t\t\tif (k==0 || d<best){\n\t\t\t\t\t\tbest=d;\n\t\t\t\t\t\tbest_i=k+k2;\n\t\t\t\t\t}\n\t\t\t\t}\n\t\t\t\tif (k2==0 || 
result[(i*n+j)]>best){\n\t\t\t\t\tresult[(i*n+j)]=best;\n\t\t\t\t\tresult_i[(i*n+j)]=best_i;\n\t\t\t\t}\n\t\t\t}\n\t\t\t__syncthreads();\n\t\t}\n\t}\n}\nvoid nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){\n\tNmDistanceKernel<<<dim3(32,16,1),512, 0, stream>>>(b,n,xyz,m,xyz2,result,result_i);\n\tNmDistanceKernel<<<dim3(32,16,1),512, 0, stream>>>(b,m,xyz2,n,xyz,result2,result2_i);\n}\n__global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){\n\tfor (int i=blockIdx.x;i<b;i+=gridDim.x){\n\t\tfor (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){\n\t\t\tfloat x1=xyz1[(i*n+j)*3+0];\n\t\t\tfloat y1=xyz1[(i*n+j)*3+1];\n\t\t\tfloat z1=xyz1[(i*n+j)*3+2];\n\t\t\tint j2=idx1[i*n+j];\n\t\t\tfloat x2=xyz2[(i*m+j2)*3+0];\n\t\t\tfloat y2=xyz2[(i*m+j2)*3+1];\n\t\t\tfloat z2=xyz2[(i*m+j2)*3+2];\n\t\t\tfloat g=grad_dist1[i*n+j]*2;\n\t\t\tatomicAdd(&(grad_xyz1[(i*n+j)*3+0]),g*(x1-x2));\n\t\t\tatomicAdd(&(grad_xyz1[(i*n+j)*3+1]),g*(y1-y2));\n\t\t\tatomicAdd(&(grad_xyz1[(i*n+j)*3+2]),g*(z1-z2));\n\t\t\tatomicAdd(&(grad_xyz2[(i*m+j2)*3+0]),-(g*(x1-x2)));\n\t\t\tatomicAdd(&(grad_xyz2[(i*m+j2)*3+1]),-(g*(y1-y2)));\n\t\t\tatomicAdd(&(grad_xyz2[(i*m+j2)*3+2]),-(g*(z1-z2)));\n\t\t}\n\t}\n}\nvoid nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){\n\tcudaMemset(grad_xyz1,0,b*n*3*4);\n\tcudaMemset(grad_xyz2,0,b*m*3*4);\n\tNmDistanceGradKernel<<<dim3(1,16,1),256, 0, stream>>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2);\n\tNmDistanceGradKernel<<<dim3(1,16,1),256, 0, stream>>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1);\n}\n\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/src/nndistance.cuh",
    "content": "void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream);\nvoid nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream);\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/src/structural_loss.cpp",
    "content": "#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n\n#include \"src/approxmatch.cuh\"\n#include \"src/nndistance.cuh\"\n\n#include <vector>\n#include <iostream>\n\n#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\n/*\ninput:\n\tset1 : batch_size * #dataset_points * 3\n\tset2 : batch_size * #query_points * 3\nreturns:\n\tmatch : batch_size * #query_points * #dataset_points\n*/\n//  temp: TensorShape{b,(n+m)*2}\nstd::vector<at::Tensor> ApproxMatch(at::Tensor set_d, at::Tensor set_q) {\n    //std::cout << \"[ApproxMatch] Called.\" << std::endl;\n    int64_t batch_size = set_d.size(0);    \n    int64_t n_dataset_points = set_d.size(1); // n\n    int64_t n_query_points = set_q.size(1);   // m\n    //std::cout << \"[ApproxMatch] batch_size:\" << batch_size << std::endl;\n    at::Tensor match = torch::empty({batch_size, n_query_points, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));\n    at::Tensor temp = torch::empty({batch_size, (n_query_points+n_dataset_points)*2}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));\n    CHECK_INPUT(set_d);\n    CHECK_INPUT(set_q);\n    CHECK_INPUT(match);\n    CHECK_INPUT(temp);\n    \n    approxmatch(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),temp.data<float>(), at::cuda::getCurrentCUDAStream());\n    return {match, temp};\n}\n\nat::Tensor MatchCost(at::Tensor set_d, at::Tensor set_q, at::Tensor match) {\n    //std::cout << \"[MatchCost] Called.\" << std::endl;\n    int64_t batch_size = set_d.size(0);    \n    int64_t n_dataset_points = set_d.size(1); // n\n    int64_t n_query_points = set_q.size(1);   // m\n    //std::cout << \"[MatchCost] batch_size:\" << batch_size << std::endl;\n    at::Tensor out = torch::empty({batch_size}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));\n    CHECK_INPUT(set_d);\n    CHECK_INPUT(set_q);\n    CHECK_INPUT(match);\n    CHECK_INPUT(out);\n    matchcost(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),out.data<float>(),at::cuda::getCurrentCUDAStream());\n    return out;\n}\n\nstd::vector<at::Tensor> MatchCostGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor match) {\n    //std::cout << \"[MatchCostGrad] Called.\" << std::endl;\n    int64_t batch_size = set_d.size(0);    \n    int64_t n_dataset_points = set_d.size(1); // n\n    int64_t n_query_points = set_q.size(1);   // m\n    //std::cout << \"[MatchCostGrad] batch_size:\" << batch_size << std::endl;\n    at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));\n    at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));\n    CHECK_INPUT(set_d);\n    CHECK_INPUT(set_q);\n    CHECK_INPUT(match);\n    CHECK_INPUT(grad1);\n    CHECK_INPUT(grad2);\n    matchcostgrad(batch_size,n_dataset_points,n_query_points,set_d.data<float>(),set_q.data<float>(),match.data<float>(),grad1.data<float>(),grad2.data<float>(),at::cuda::getCurrentCUDAStream());\n    return {grad1, grad2};\n}\n\n\n/*\ninput:\n\tset_d : batch_size * #dataset_points * 3\n\tset_q : batch_size * #query_points * 3\nreturns:\n\tdist1, idx1 : batch_size * 
#dataset_points\n\tdist2, idx2 : batch_size * #query_points\n*/\nstd::vector<at::Tensor> NNDistance(at::Tensor set_d, at::Tensor set_q) {\n    //std::cout << \"[NNDistance] Called.\" << std::endl;\n    int64_t batch_size = set_d.size(0);    \n    int64_t n_dataset_points = set_d.size(1); // n\n    int64_t n_query_points = set_q.size(1);   // m\n    //std::cout << \"[NNDistance] batch_size:\" << batch_size << std::endl;\n    at::Tensor dist1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));\n    at::Tensor idx1 = torch::empty({batch_size, n_dataset_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device()));\n    at::Tensor dist2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));\n    at::Tensor idx2 = torch::empty({batch_size, n_query_points}, torch::TensorOptions().dtype(torch::kInt32).device(set_d.device()));\n    CHECK_INPUT(set_d);\n    CHECK_INPUT(set_q);\n    CHECK_INPUT(dist1);\n    CHECK_INPUT(idx1);\n    CHECK_INPUT(dist2);\n    CHECK_INPUT(idx2);\n    // void nndistance(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream);\n    nndistance(batch_size,n_dataset_points,set_d.data<float>(),n_query_points,set_q.data<float>(),dist1.data<float>(),idx1.data<int>(),dist2.data<float>(),idx2.data<int>(), at::cuda::getCurrentCUDAStream());\n    return {dist1, idx1, dist2, idx2};\n}\n\nstd::vector<at::Tensor> NNDistanceGrad(at::Tensor set_d, at::Tensor set_q, at::Tensor idx1, at::Tensor idx2, at::Tensor grad_dist1, at::Tensor grad_dist2) {\n    //std::cout << \"[NNDistanceGrad] Called.\" << std::endl;\n    int64_t batch_size = set_d.size(0);    \n    int64_t n_dataset_points = set_d.size(1); // n\n    int64_t n_query_points = set_q.size(1);   // m\n    //std::cout << \"[NNDistanceGrad] batch_size:\" << batch_size << std::endl;\n    at::Tensor grad1 = torch::empty({batch_size,n_dataset_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));\n    at::Tensor grad2 = torch::empty({batch_size,n_query_points,3}, torch::TensorOptions().dtype(torch::kFloat32).device(set_d.device()));\n    CHECK_INPUT(set_d);\n    CHECK_INPUT(set_q);\n    CHECK_INPUT(idx1);\n    CHECK_INPUT(idx2);\n    CHECK_INPUT(grad_dist1);\n    CHECK_INPUT(grad_dist2);\n    CHECK_INPUT(grad1);\n    CHECK_INPUT(grad2);\n    //void nndistancegrad(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream);\n    nndistancegrad(batch_size,n_dataset_points,set_d.data<float>(),n_query_points,set_q.data<float>(),\n        grad_dist1.data<float>(),idx1.data<int>(),\n        grad_dist2.data<float>(),idx2.data<int>(),\n        grad1.data<float>(),grad2.data<float>(),\n        at::cuda::getCurrentCUDAStream());\n    return {grad1, grad2};\n}\n\n"
  },
  {
    "path": "metrics/pytorch_structural_losses/src/utils.hpp",
    "content": "#include <iostream>\n#include <sstream>\n#include <string>\n\nclass Formatter {\npublic:\n  Formatter() {}\n  ~Formatter() {}\n\n  template <typename Type> Formatter &operator<<(const Type &value) {\n    stream_ << value;\n    return *this;\n  }\n\n  std::string str() const { return stream_.str(); }\n  operator std::string() const { return stream_.str(); }\n\n  enum ConvertToString { to_str };\n\n  std::string operator>>(ConvertToString) { return stream_.str(); }\n\nprivate:\n  std::stringstream stream_;\n  Formatter(const Formatter &);\n  Formatter &operator=(Formatter &);\n};\n"
  },
  {
    "path": "models/__init__.py",
    "content": "#!/usr/bin/python3\n\nfrom models.sdf_model import SdfModel\nfrom models.autoencoder import BetaVAE\n\nfrom models.archs.encoders.conv_pointnet import UNet\n\nfrom models.diffusion import *\nfrom models.archs.diffusion_arch import * \n#from diffusion import *\nfrom models.sdf_model import SdfModel\n\nfrom models.combined_model import CombinedModel\n\n\n"
  },
  {
    "path": "models/archs/__init__.py",
    "content": "#!/usr/bin/python3\n\n"
  },
  {
    "path": "models/archs/diffusion_arch.py",
    "content": "import math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum \n\nfrom einops import rearrange, repeat, reduce\nfrom einops.layers.torch import Rearrange\nfrom einops_exts import rearrange_many, repeat_many, check_shape\n\nfrom rotary_embedding_torch import RotaryEmbedding\n\nfrom diff_utils.model_utils import * \n\nfrom random import sample\n\nclass CausalTransformer(nn.Module):\n    def __init__(\n        self,\n        dim, \n        depth,\n        dim_in_out=None,\n        cross_attn=False,\n        dim_head = 64,\n        heads = 8,\n        ff_mult = 4,\n        norm_in = False,\n        norm_out = True, \n        attn_dropout = 0.,\n        ff_dropout = 0.,\n        final_proj = True, \n        normformer = False,\n        rotary_emb = True, \n        **kwargs\n    ):\n        super().__init__()\n        self.init_norm = LayerNorm(dim) if norm_in else nn.Identity() # from latest BLOOM model and Yandex's YaLM\n\n        self.rel_pos_bias = RelPosBias(heads = heads)\n\n        rotary_emb = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None\n        rotary_emb_cross = RotaryEmbedding(dim = min(32, dim_head)) if rotary_emb else None\n\n        self.layers = nn.ModuleList([])\n\n        dim_in_out = default(dim_in_out, dim)\n        self.use_same_dims = (dim_in_out is None) or (dim_in_out==dim)\n        point_feature_dim = kwargs.get('point_feature_dim', dim)\n\n        if cross_attn:\n            #print(\"using CROSS ATTN, with dropout {}\".format(attn_dropout))\n            self.layers.append(nn.ModuleList([\n                    Attention(dim = dim_in_out, out_dim=dim, causal = True, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb),\n                    Attention(dim = dim, kv_dim=point_feature_dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb_cross),\n                    FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)\n                ]))\n            for _ in range(depth):\n                self.layers.append(nn.ModuleList([\n                    Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb),\n                    Attention(dim = dim, kv_dim=point_feature_dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb_cross),\n                    FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)\n                ]))\n            self.layers.append(nn.ModuleList([\n                    Attention(dim = dim, out_dim=dim, causal = True, dim_head = dim_head, heads = heads, rotary_emb = rotary_emb),\n                    Attention(dim = dim, kv_dim=point_feature_dim, out_dim=dim_in_out, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb_cross),\n                    FeedForward(dim = dim_in_out, out_dim=dim_in_out, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)\n                ]))\n        else:\n            self.layers.append(nn.ModuleList([\n                    Attention(dim = dim_in_out, out_dim=dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),\n                    FeedForward(dim = dim, out_dim=dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)\n                ]))\n            for _ in range(depth):\n                
self.layers.append(nn.ModuleList([\n                    Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),\n                    FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)\n                ]))\n            self.layers.append(nn.ModuleList([\n                    Attention(dim = dim, out_dim=dim_in_out, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),\n                    FeedForward(dim = dim_in_out, out_dim=dim_in_out, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)\n                ]))\n\n        self.norm = LayerNorm(dim_in_out, stable = True) if norm_out else nn.Identity()  # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options\n        self.project_out = nn.Linear(dim_in_out, dim_in_out, bias = False) if final_proj else nn.Identity()\n\n        self.cross_attn = cross_attn\n\n    def forward(self, x, time_emb=None, context=None):\n        n, device = x.shape[1], x.device\n\n        x = self.init_norm(x)\n\n        attn_bias = self.rel_pos_bias(n, n + 1, device = device)\n\n        if self.cross_attn:\n            #assert context is not None \n            for idx, (self_attn, cross_attn, ff) in enumerate(self.layers):\n                #print(\"x1 shape: \", x.shape)\n                if (idx==0 or idx==len(self.layers)-1) and not self.use_same_dims:\n                    x = self_attn(x, attn_bias = attn_bias)\n                    x = cross_attn(x, context=context) # removing attn_bias for now \n                else:\n                    x = self_attn(x, attn_bias = attn_bias) + x \n                    x = cross_attn(x, context=context) + x  # removing attn_bias for now \n                #print(\"x2 shape, context shape: \", x.shape, context.shape)\n                \n                #print(\"x3 shape, context shape: \", x.shape, context.shape)\n                x = ff(x) + x\n        \n        else:\n            for idx, (attn, ff) in enumerate(self.layers):\n                #print(\"x1 shape: \", x.shape)\n                if (idx==0 or idx==len(self.layers)-1) and not self.use_same_dims:\n                    x = attn(x, attn_bias = attn_bias)\n                else:\n                    x = attn(x, attn_bias = attn_bias) + x\n                #print(\"x2 shape: \", x.shape)\n                x = ff(x) + x\n                #print(\"x3 shape: \", x.shape)\n\n        out = self.norm(x)\n        return self.project_out(out)\n\nclass DiffusionNet(nn.Module):\n\n    def __init__(\n        self,\n        dim,\n        dim_in_out=None,\n        num_timesteps = None,\n        num_time_embeds = 1,\n        cond = None,\n        **kwargs\n    ):\n        super().__init__()\n        self.num_time_embeds = num_time_embeds\n        self.dim = dim\n        self.cond = cond\n        self.cross_attn = kwargs.get('cross_attn', False)\n        self.cond_dropout = kwargs.get('cond_dropout', False)\n        self.point_feature_dim = kwargs.get('point_feature_dim', dim)\n\n        self.dim_in_out = default(dim_in_out, dim)\n        #print(\"dim, in out, point feature dim: \", dim, dim_in_out, self.point_feature_dim)\n        #print(\"cond dropout: \", self.cond_dropout)\n\n        self.to_time_embeds = nn.Sequential(\n            nn.Embedding(num_timesteps, self.dim_in_out * 
num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(self.dim_in_out), MLP(self.dim_in_out, self.dim_in_out * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP\n            Rearrange('b (n d) -> b n d', n = num_time_embeds)\n        )\n\n        # last input to the transformer: \"a final embedding whose output from the Transformer is used to predict the unnoised CLIP image embedding\"\n        self.learned_query = nn.Parameter(torch.randn(self.dim_in_out))\n        self.causal_transformer = CausalTransformer(dim = dim, dim_in_out=self.dim_in_out, **kwargs)\n\n        if cond:\n            # output dim of pointnet needs to match model dim, unless an additional linear layer is added\n            self.pointnet = ConvPointnet(c_dim=self.point_feature_dim)\n\n\n    def forward(\n        self,\n        data,\n        diffusion_timesteps,\n        pass_cond=-1, # default -1 leaves it to the random draw; pass 0 or 1 explicitly during sampling\n    ):\n\n        if self.cond:\n            assert type(data) is tuple\n            data, cond = data # noise is added to cond_feature in diffusion.py\n\n            if self.cond_dropout:\n                # classifier-free guidance: drop the condition for roughly 20% of\n                # training steps; an explicit pass_cond (0 or 1) given during\n                # sampling overrides the random draw\n                prob = torch.randint(low=0, high=10, size=(1,))\n                percentage = 8\n                if pass_cond == 0 or (pass_cond == -1 and prob >= percentage):\n                    cond_feature = torch.zeros( (cond.shape[0], cond.shape[1], self.point_feature_dim), device=data.device )\n                else:\n                    cond_feature = self.pointnet(cond, cond)\n            else:\n                cond_feature = self.pointnet(cond, cond)\n\n        batch, dim, device, dtype = *data.shape, data.device, data.dtype\n\n        num_time_embeds = self.num_time_embeds\n        time_embed = self.to_time_embeds(diffusion_timesteps)\n\n        data = data.unsqueeze(1)\n\n        learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)\n\n        model_inputs = [time_embed, data, learned_queries]\n\n        if self.cond and not self.cross_attn:\n            model_inputs.insert(0, cond_feature) # cond_feature defined above\n\n        tokens = torch.cat(model_inputs, dim = 1) # (b, 3/4, d); batch and d same across the model_inputs\n\n        if self.cross_attn:\n            cond_feature = None if not self.cond else cond_feature\n            tokens = self.causal_transformer(tokens, context=cond_feature)\n        else:\n            tokens = self.causal_transformer(tokens)\n\n        # take the learned query, which predicts the sdf latent embedding (per DDPM timestep)\n        pred = tokens[..., -1, :]\n\n        return pred\n
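\n# Minimal shape-check sketch (dims are illustrative, not the trained config;\n# relies on the helpers imported from diff_utils.model_utils):\n#   net = DiffusionNet(dim=512, dim_in_out=256, depth=4, cond=False)\n#   x_t = torch.randn(4, 256)          # batch of noised latents\n#   t = torch.randint(0, 1000, (4,))   # diffusion timesteps\n#   pred = net(x_t, t)                 # (4, 256)\n"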
  },
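  {
    "path": "examples/cond_dropout_sketch.py",
    "content": "# Hypothetical illustration file; not part of the original repository.\n# A self-contained sketch of the conditioning-dropout idea used in DiffusionNet.forward:\n# during training the point-cloud feature is randomly replaced with zeros so the network\n# also learns an unconditional model (classifier-free guidance); during sampling,\n# pass_cond selects a branch explicitly. Written from scratch under these assumptions;\n# drop_prob is an illustrative value, not the repo's setting.\nimport torch\n\ndef apply_cond_dropout(cond_feature, pass_cond=-1, drop_prob=0.2):\n    # pass_cond == 0 / 1 forces the unconditional / conditional branch (sampling);\n    # pass_cond == -1 leaves the choice to a random draw (training).\n    if pass_cond == 0:\n        return torch.zeros_like(cond_feature)\n    if pass_cond == 1:\n        return cond_feature\n    if torch.rand(()) < drop_prob:\n        return torch.zeros_like(cond_feature)\n    return cond_feature\n"
  },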
  {
    "path": "models/archs/encoders/__init__.py",
    "content": "#!/usr/bin/python3\n\n"
  },
  {
    "path": "models/archs/encoders/auto_decoder.py",
    "content": "#!/usr/bin/env python3\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport json\nimport math\n\nclass AutoDecoder():\n    # specs is a json filepath that contains the specifications for the experiment \n    # also requires total num_scenes (all meshes from all classes), which is given by len of dataset\n    def __init__(self, num_scenes, latent_size):\n        self.num_scenes = num_scenes\n        self.latent_size = latent_size\n\n    def build_model(self):\n        lat_vecs = nn.Embedding(self.num_scenes, self.latent_size, max_norm=1.0)\n        nn.init.normal_(lat_vecs.weight.data, 0.0, 1.0/math.sqrt(self.latent_size))\n        return lat_vecs\n"
  },
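  {
    "path": "examples/auto_decoder_usage.py",
    "content": "# Hypothetical usage sketch; not part of the original repository.\n# AutoDecoder builds a per-scene latent table: one embedding row per mesh, with a\n# DeepSDF-style normal initialization and a max-norm constraint. The sizes below are\n# illustrative values.\nimport torch\nfrom models.archs.encoders.auto_decoder import AutoDecoder\n\nnum_scenes = 100    # len(dataset): total meshes across all classes\nlatent_size = 256\nlat_vecs = AutoDecoder(num_scenes, latent_size).build_model()  # nn.Embedding(100, 256)\n\nscene_ids = torch.tensor([3, 7])\nz = lat_vecs(scene_ids)  # (2, 256) latents, optimized jointly with the decoder\n"
  },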
  {
    "path": "models/archs/encoders/conv_pointnet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import init\n\nfrom torch_scatter import scatter_mean, scatter_max\n\n\nclass ConvPointnet(nn.Module):\n    ''' PointNet-based encoder network with ResNet blocks for each point.\n        Number of input points are fixed.\n    \n    Args:\n        c_dim (int): dimension of latent code c\n        dim (int): input points dimension\n        hidden_dim (int): hidden dimension of the network\n        scatter_type (str): feature aggregation when doing local pooling\n        unet (bool): weather to use U-Net\n        unet_kwargs (str): U-Net parameters\n        plane_resolution (int): defined resolution for plane feature\n        plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume\n        padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]\n        n_blocks (int): number of blocks ResNetBlockFC layers\n    '''\n\n    def __init__(self, c_dim=512, dim=3, hidden_dim=128, scatter_type='max', \n                 unet=True, unet_kwargs={\"depth\": 4, \"merge_mode\": \"concat\", \"start_filts\": 32}, \n                 plane_resolution=64, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5,\n                 inject_noise=False):\n        super().__init__()\n        self.c_dim = c_dim\n\n        self.fc_pos = nn.Linear(dim, 2*hidden_dim)\n        self.blocks = nn.ModuleList([\n            ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)\n        ])\n        self.fc_c = nn.Linear(hidden_dim, c_dim)\n\n        self.actvn = nn.ReLU()\n        self.hidden_dim = hidden_dim\n\n        if unet:\n            self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)\n        else:\n            self.unet = None\n\n        self.reso_plane = plane_resolution\n        self.plane_type = plane_type\n        self.padding = padding\n\n        if scatter_type == 'max':\n            self.scatter = scatter_max\n        elif scatter_type == 'mean':\n            self.scatter = scatter_mean\n\n    def generate_plane_features(self, p, c, plane='xz'):\n        # acquire indices of features in plane\n        xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)\n        index = self.coordinate2index(xy, self.reso_plane)\n\n        # scatter plane features from points\n        fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)\n        c = c.permute(0, 2, 1) # B x 512 x T\n        fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2\n        fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparce matrix (B x 512 x reso x reso)\n\n        # process the plane features with UNet\n        if self.unet is not None:\n            fea_plane = self.unet(fea_plane)\n\n        return fea_plane \n\n    # takes in \"p\": point cloud and \"query\": sdf_xyz \n    # sample plane features for unlabeled_query as well \n    def forward(self, p, query):\n        batch_size, T, D = p.size()\n\n        # acquire the index for each point\n        coord = {}\n        index = {}\n        if 'xz' in self.plane_type:\n            coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)\n            index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)\n        if 'xy' in self.plane_type:\n            coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', 
padding=self.padding)\n            index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)\n        if 'yz' in self.plane_type:\n            coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)\n            index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)\n\n        \n        net = self.fc_pos(p)\n\n        net = self.blocks[0](net)\n        for block in self.blocks[1:]:\n            pooled = self.pool_local(coord, index, net)\n            net = torch.cat([net, pooled], dim=2)\n            net = block(net)\n\n        c = self.fc_c(net)\n        \n        fea = {}\n        plane_feat_sum = 0\n        #denoise_loss = 0\n        if 'xz' in self.plane_type:\n            fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)\n            plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')\n        if 'xy' in self.plane_type:\n            fea['xy'] = self.generate_plane_features(p, c, plane='xy')\n            plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')\n        if 'yz' in self.plane_type:\n            fea['yz'] = self.generate_plane_features(p, c, plane='yz')\n            plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')\n\n        return plane_feat_sum.transpose(2,1)\n\n    # given plane features with dimensions (3*dim, 64, 64)\n    # first reshape into the three planes, then generate query features from it \n    def forward_with_plane_features(self, plane_features, query):\n        # plane features shape: batch, dim*3, 64, 64\n        idx = int(plane_features.shape[1] / 3)\n        fea = {}\n        fea['xz'], fea['xy'], fea['yz'] = plane_features[:,0:idx,...], plane_features[:,idx:idx*2,...], plane_features[:,idx*2:,...]\n        #print(\"shapes: \", fea['xz'].shape, fea['xy'].shape, fea['yz'].shape) #([1, 256, 64, 64])\n        plane_feat_sum = 0\n\n        plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')\n        plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')\n        plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')\n\n        return plane_feat_sum.transpose(2,1)\n\n    # c is point cloud features\n    # p is point cloud (coordinates)\n    def forward_with_pc_features(self, c, p, query):\n\n        #print(\"c, p shapes:\", c.shape, p.shape)\n\n        fea = {}\n        fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 
16, 256, 64, 64)\n        fea['xy'] = self.generate_plane_features(p, c, plane='xy')\n        fea['yz'] = self.generate_plane_features(p, c, plane='yz')\n\n        plane_feat_sum = 0\n\n        plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')\n        plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')\n        plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')\n\n        return plane_feat_sum.transpose(2,1)\n\n\n    def get_point_cloud_features(self, p):\n        batch_size, T, D = p.size()\n\n        # acquire the index for each point\n        coord = {}\n        index = {}\n        if 'xz' in self.plane_type:\n            coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)\n            index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)\n        if 'xy' in self.plane_type:\n            coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)\n            index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)\n        if 'yz' in self.plane_type:\n            coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)\n            index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)\n\n        net = self.fc_pos(p)\n\n        net = self.blocks[0](net)\n        for block in self.blocks[1:]:\n            pooled = self.pool_local(coord, index, net)\n            net = torch.cat([net, pooled], dim=2)\n            net = block(net)\n\n        c = self.fc_c(net)\n\n        return c\n\n    def get_plane_features(self, p):\n\n        c = self.get_point_cloud_features(p)\n        fea = {}\n        if 'xz' in self.plane_type:\n            fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 
16, 256, 64, 64)\n        if 'xy' in self.plane_type:\n            fea['xy'] = self.generate_plane_features(p, c, plane='xy')\n        if 'yz' in self.plane_type:\n            fea['yz'] = self.generate_plane_features(p, c, plane='yz')\n\n        return fea['xz'], fea['xy'], fea['yz']\n\n\n    def normalize_coordinate(self, p, padding=0.1, plane='xz'):\n        ''' Normalize coordinate to [0, 1] for unit cube experiments\n\n        Args:\n            p (tensor): point\n            padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]\n            plane (str): plane feature type, ['xz', 'xy', 'yz']\n        '''\n        if plane == 'xz':\n            xy = p[:, :, [0, 2]]\n        elif plane =='xy':\n            xy = p[:, :, [0, 1]]\n        else:\n            xy = p[:, :, [1, 2]]\n\n        xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)\n        xy_new = xy_new + 0.5 # range (0, 1)\n\n        # f there are outliers out of the range\n        if xy_new.max() >= 1:\n            xy_new[xy_new >= 1] = 1 - 10e-6\n        if xy_new.min() < 0:\n            xy_new[xy_new < 0] = 0.0\n        return xy_new\n\n\n    def coordinate2index(self, x, reso):\n        ''' Normalize coordinate to [0, 1] for unit cube experiments.\n            Corresponds to our 3D model\n\n        Args:\n            x (tensor): coordinate\n            reso (int): defined resolution\n            coord_type (str): coordinate type\n        '''\n        x = (x * reso).long()\n        index = x[:, :, 0] + reso * x[:, :, 1]\n        index = index[:, None, :]\n        return index\n\n\n    # xy is the normalized coordinates of the point cloud of each plane \n    # I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input \n    def pool_local(self, xy, index, c):\n        bs, fea_dim = c.size(0), c.size(2)\n        keys = xy.keys()\n\n        c_out = 0\n        for key in keys:\n            # scatter plane features from points\n            fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)\n            if self.scatter == scatter_max:\n                fea = fea[0]\n            # gather feature back to points\n            fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))\n            c_out += fea\n        return c_out.permute(0, 2, 1)\n\n    # sample_plane_feature function copied from /src/conv_onet/models/decoder.py\n    # uses values from plane_feature and pixel locations from vgrid to interpolate feature\n    def sample_plane_feature(self, query, plane_feature, plane):\n        xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)\n        xy = xy[:, :, None].float()\n        vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)\n        sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)\n        return sampled_feat\n\n\ndef conv3x3(in_channels, out_channels, stride=1, \n            padding=1, bias=True, groups=1):    \n    return nn.Conv2d(\n        in_channels,\n        out_channels,\n        kernel_size=3,\n        stride=stride,\n        padding=padding,\n        bias=bias,\n        groups=groups)\n\ndef upconv2x2(in_channels, out_channels, mode='transpose'):\n    if mode == 'transpose':\n        return nn.ConvTranspose2d(\n            in_channels,\n            out_channels,\n            kernel_size=2,\n            stride=2)\n    else:\n        # out_channels is always going to be the same\n        # as 
in_channels\n        return nn.Sequential(\n            nn.Upsample(mode='bilinear', scale_factor=2),\n            conv1x1(in_channels, out_channels))\n\ndef conv1x1(in_channels, out_channels, groups=1):\n    return nn.Conv2d(\n        in_channels,\n        out_channels,\n        kernel_size=1,\n        groups=groups,\n        stride=1)\n\n\nclass DownConv(nn.Module):\n    \"\"\"\n    A helper Module that performs 2 convolutions and 1 MaxPool.\n    A ReLU activation follows each convolution.\n    \"\"\"\n    def __init__(self, in_channels, out_channels, pooling=True):\n        super(DownConv, self).__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.pooling = pooling\n\n        self.conv1 = conv3x3(self.in_channels, self.out_channels)\n        self.conv2 = conv3x3(self.out_channels, self.out_channels)\n\n        if self.pooling:\n            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n\n    def forward(self, x):\n        x = F.relu(self.conv1(x))\n        x = F.relu(self.conv2(x))\n        before_pool = x\n        if self.pooling:\n            x = self.pool(x)\n        return x, before_pool\n\n\nclass UpConv(nn.Module):\n    \"\"\"\n    A helper Module that performs 2 convolutions and 1 UpConvolution.\n    A ReLU activation follows each convolution.\n    \"\"\"\n    def __init__(self, in_channels, out_channels, \n                 merge_mode='concat', up_mode='transpose'):\n        super(UpConv, self).__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.merge_mode = merge_mode\n        self.up_mode = up_mode\n\n        self.upconv = upconv2x2(self.in_channels, self.out_channels, \n            mode=self.up_mode)\n\n        if self.merge_mode == 'concat':\n            self.conv1 = conv3x3(\n                2*self.out_channels, self.out_channels)\n        else:\n            # num of input channels to conv2 is same\n            self.conv1 = conv3x3(self.out_channels, self.out_channels)\n        self.conv2 = conv3x3(self.out_channels, self.out_channels)\n\n\n    def forward(self, from_down, from_up):\n        \"\"\" Forward pass\n        Arguments:\n            from_down: tensor from the encoder pathway\n            from_up: upconv'd tensor from the decoder pathway\n        \"\"\"\n        from_up = self.upconv(from_up)\n        if self.merge_mode == 'concat':\n            x = torch.cat((from_up, from_down), 1)\n        else:\n            x = from_up + from_down\n        x = F.relu(self.conv1(x))\n        x = F.relu(self.conv2(x))\n        return x\n\n\nclass UNet(nn.Module):\n    \"\"\" `UNet` class is based on https://arxiv.org/abs/1505.04597\n\n    The U-Net is a convolutional encoder-decoder neural network.\n    Contextual spatial information (from the decoding,\n    expansive pathway) about an input tensor is merged with\n    information representing the localization of details\n    (from the encoding, compressive pathway).\n\n    Modifications to the original paper:\n    (1) padding is used in 3x3 convolutions to prevent loss\n        of border pixels\n    (2) merging outputs does not require cropping due to (1)\n    (3) residual connections can be used by specifying\n        UNet(merge_mode='add')\n    (4) if non-parametric upsampling is used in the decoder\n        pathway (specified by upmode='upsample'), then an\n        additional 1x1 2d convolution occurs after upsampling\n        to reduce channel dimensionality by a factor of 2.\n        This channel halving 
happens with the convolution in\n        the tranpose convolution (specified by upmode='transpose')\n    \"\"\"\n\n    def __init__(self, num_classes, in_channels=3, depth=5, \n                 start_filts=64, up_mode='transpose', same_channels=False,\n                 merge_mode='concat', **kwargs):\n        \"\"\"\n        Arguments:\n            in_channels: int, number of channels in the input tensor.\n                Default is 3 for RGB images.\n            depth: int, number of MaxPools in the U-Net.\n            start_filts: int, number of convolutional filters for the \n                first conv.\n            up_mode: string, type of upconvolution. Choices: 'transpose'\n                for transpose convolution or 'upsample' for nearest neighbour\n                upsampling.\n        \"\"\"\n        super(UNet, self).__init__()\n\n        if up_mode in ('transpose', 'upsample'):\n            self.up_mode = up_mode\n        else:\n            raise ValueError(\"\\\"{}\\\" is not a valid mode for \"\n                             \"upsampling. Only \\\"transpose\\\" and \"\n                             \"\\\"upsample\\\" are allowed.\".format(up_mode))\n    \n        if merge_mode in ('concat', 'add'):\n            self.merge_mode = merge_mode\n        else:\n            raise ValueError(\"\\\"{}\\\" is not a valid mode for\"\n                             \"merging up and down paths. \"\n                             \"Only \\\"concat\\\" and \"\n                             \"\\\"add\\\" are allowed.\".format(up_mode))\n\n        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'\n        if self.up_mode == 'upsample' and self.merge_mode == 'add':\n            raise ValueError(\"up_mode \\\"upsample\\\" is incompatible \"\n                             \"with merge_mode \\\"add\\\" at the moment \"\n                             \"because it doesn't make sense to use \"\n                             \"nearest neighbour to reduce \"\n                             \"depth channels (by half).\")\n\n        self.num_classes = num_classes\n        self.in_channels = in_channels\n        self.start_filts = start_filts\n        self.depth = depth\n\n        self.down_convs = []\n        self.up_convs = []\n\n        # create the encoder pathway and add to a list\n        for i in range(depth):\n            ins = self.in_channels if i == 0 else outs\n            outs = self.start_filts*(2**i) if not same_channels else self.in_channels\n            pooling = True if i < depth-1 else False\n            #print(\"down ins, outs: \", ins, outs)  # [latent dim, 32], [32, 64]...[128, 256]\n\n            down_conv = DownConv(ins, outs, pooling=pooling)\n            self.down_convs.append(down_conv)\n\n        # create the decoder pathway and add to a list\n        # - careful! 
decoding only requires depth-1 blocks\n        for i in range(depth-1):\n            ins = outs\n            outs = ins // 2 if not same_channels else ins \n            up_conv = UpConv(ins, outs, up_mode=up_mode,\n                merge_mode=merge_mode)\n            self.up_convs.append(up_conv)\n            #print(\"up ins, outs: \", ins, outs)# [256, 128]...[64, 32]; final 32 to latent is done through self.conv_final \n\n        # add the list of modules to current module\n        self.down_convs = nn.ModuleList(self.down_convs)\n        self.up_convs = nn.ModuleList(self.up_convs)\n\n        self.conv_final = conv1x1(outs, self.num_classes)\n\n        self.reset_params()\n\n    @staticmethod\n    def weight_init(m):\n        if isinstance(m, nn.Conv2d):\n            init.xavier_normal_(m.weight)\n            init.constant_(m.bias, 0)\n\n\n    def reset_params(self):\n        for i, m in enumerate(self.modules()):\n            self.weight_init(m)\n\n\n    def forward(self, x):\n        encoder_outs = []\n        # encoder pathway, save outputs for merging\n        for i, module in enumerate(self.down_convs):\n            #print(\"down {} x1: \".format(i), x.shape) # increasing channels but decreasing resolution (64x64 -> 8x8)\n            x, before_pool = module(x)\n            #print(\"down {} x2: \".format(i), x.shape)\n            encoder_outs.append(before_pool)\n        for i, module in enumerate(self.up_convs):\n            before_pool = encoder_outs[-(i+2)]\n            #print(\"up {} x1: \".format(i), x.shape)\n            x = module(before_pool, x)\n            #print(\"up {} x2: \".format(i), x.shape)\n        #exit()\n        \n        # No softmax is used. This means you need to use\n        # nn.CrossEntropyLoss is your training script,\n        # as this module includes a softmax already.\n        x = self.conv_final(x)\n        return x\n\n    def generate(self, x):\n        return self(x)\n\n# Resnet Blocks\nclass ResnetBlockFC(nn.Module):\n    ''' Fully connected ResNet Block class.\n    Args:\n        size_in (int): input dimension\n        size_out (int): output dimension\n        size_h (int): hidden dimension\n    '''\n\n    def __init__(self, size_in, size_out=None, size_h=None):\n        super().__init__()\n        # Attributes\n        if size_out is None:\n            size_out = size_in\n\n        if size_h is None:\n            size_h = min(size_in, size_out)\n\n        self.size_in = size_in\n        self.size_h = size_h\n        self.size_out = size_out\n        # Submodules\n        self.fc_0 = nn.Linear(size_in, size_h)\n        self.fc_1 = nn.Linear(size_h, size_out)\n        self.actvn = nn.ReLU()\n\n        if size_in == size_out:\n            self.shortcut = None\n        else:\n            self.shortcut = nn.Linear(size_in, size_out, bias=False)\n        # Initialization\n        nn.init.zeros_(self.fc_1.weight)\n\n    def forward(self, x):\n        net = self.fc_0(self.actvn(x))\n        dx = self.fc_1(self.actvn(net))\n\n        if self.shortcut is not None:\n            x_s = self.shortcut(x)\n        else:\n            x_s = x\n\n        return x_s + dx"
  },
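  {
    "path": "examples/conv_pointnet_usage.py",
    "content": "# Hypothetical usage sketch; not part of the original repository.\n# Runs ConvPointnet end to end: a point cloud is encoded into three 2D feature planes\n# and per-query features are bilinearly sampled from them, as done for SDF prediction.\n# Shapes are illustrative; torch_scatter must be installed.\nimport torch\nfrom models.archs.encoders.conv_pointnet import ConvPointnet\n\nenc = ConvPointnet(c_dim=256, plane_resolution=64)\n\npc = torch.rand(2, 1024, 3) - 0.5     # point cloud in the [-0.5, 0.5] unit cube (ONet convention)\nquery = torch.rand(2, 16000, 3) - 0.5  # SDF query coordinates\n\nfeat = enc(pc, query)                       # (2, 16000, 256): summed tri-plane features per query\nfxz, fxy, fyz = enc.get_plane_features(pc)  # three (2, 256, 64, 64) plane feature maps\n"
  },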
  {
    "path": "models/archs/encoders/dgcnn.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ndef knn(x, k):\n    inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)\n    xx = torch.sum(x ** 2, dim=1, keepdim=True)\n    pairwise_distance = -xx - inner - xx.transpose(2, 1).contiguous()\n\n    idx = pairwise_distance.topk(k=k, dim=-1)[1]\n    return idx\n\ndef get_graph_feature(x, k=20):\n    idx = knn(x, k=k)\n    batch_size, num_points, _ = idx.size()\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points\n\n    idx = idx + idx_base\n\n    idx = idx.view(-1)\n\n    _, num_dims, _ = x.size()\n\n    x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims)  \n                                       # -> (batch_size*num_points, num_dims) \n                                       #   batch_size * num_points * k + range(0, batch_size*num_points)\n    feature = x.view(batch_size * num_points, -1)[idx, :]\n    feature = feature.view(batch_size, num_points, k, num_dims)\n    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)\n    feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)\n    return feature\n\nclass DGCNN(nn.Module):\n\n    def __init__(\n        self, \n        emb_dims=512,\n        use_bn=False\n    ):\n\n        super().__init__()\n\n        if use_bn:\n            self.bn1 = nn.BatchNorm2d(64)\n            self.bn2 = nn.BatchNorm2d(64)\n            self.bn3 = nn.BatchNorm2d(128)\n            self.bn4 = nn.BatchNorm2d(256)\n            self.bn5 = nn.BatchNorm2d(emb_dims)\n\n            self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), self.bn1, nn.LeakyReLU(negative_slope=0.2))\n            self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), self.bn2, nn.LeakyReLU(negative_slope=0.2))\n            self.conv3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False), self.bn3, nn.LeakyReLU(negative_slope=0.2))\n            self.conv4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, bias=False), self.bn4, nn.LeakyReLU(negative_slope=0.2))\n            self.conv5 = nn.Sequential(nn.Conv2d(512, emb_dims, kernel_size=1, bias=False), self.bn5, nn.LeakyReLU(negative_slope=0.2))\n\n        else:\n            self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))\n            self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))\n            self.conv3 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))\n            self.conv4 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))\n            self.conv5 = nn.Sequential(nn.Conv2d(512, emb_dims, kernel_size=1, bias=False), nn.LeakyReLU(negative_slope=0.2))\n\n    def forward(self, x):\n        batch_size, num_dims, num_points = x.size()                 # x:      batch x   3 x num of points\n        x = get_graph_feature(x)                                    # x:      batch x   6 x num of points x 20\n\n        x1     = self.conv1(x)                                      # x1:     batch x  64 x num of points x 20\n        x1_max = x1.max(dim=-1, keepdim=True)[0]                    # x1_max: batch x  64 x num of points x 1\n\n        x2     = self.conv2(x1)                                     # x2:     batch x  64 x num of points x 20\n        x2_max = 
x2.max(dim=-1, keepdim=True)[0]                    # x2_max: batch x  64 x num of points x 1\n\n        x3     = self.conv3(x2)                                     # x3:     batch x 128 x num of points x 20\n        x3_max = x3.max(dim=-1, keepdim=True)[0]                    # x3_max: batch x 128 x num of points x 1\n\n        x4     = self.conv4(x3)                                     # x4:     batch x 256 x num of points x 20\n        x4_max = x4.max(dim=-1, keepdim=True)[0]                    # x4_max: batch x 256 x num of points x 1\n \n        x_max  = torch.cat((x1_max, x2_max, x3_max, x4_max), dim=1) # x_max:  batch x 512 x num of points x 1\n\n        point_feat = torch.squeeze(self.conv5(x_max), dim=3)        # point feat:  batch x 512 x num of points\n\n        global_feat = point_feat.max(dim=2, keepdim=False)[0]       # global feat: batch x 512\n\n        return global_feat"
  },
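  {
    "path": "examples/dgcnn_usage.py",
    "content": "# Hypothetical usage sketch; not part of the original repository.\n# DGCNN consumes a channels-first point cloud (batch, 3, num_points) and returns one\n# 512-d global feature per shape. get_graph_feature places its index tensor on CUDA\n# whenever it is available, so inputs should live on that same device.\nimport torch\nfrom models.archs.encoders.dgcnn import DGCNN\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = DGCNN(emb_dims=512).to(device)\n\npc = torch.rand(2, 3, 1024, device=device)  # (batch, xyz, num_points)\nglobal_feat = model(pc)                     # (2, 512)\n"
  },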
  {
    "path": "models/archs/encoders/rbf.py",
    "content": "#!/usr/bin/env python3\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport json\nimport numpy as np\n\n# https://github.com/vsitzmann/siren/blob/4df34baee3f0f9c8f351630992c1fe1f69114b5f/modules.py#L266\n\nclass RBFLayer(nn.Module):\n    '''Transforms incoming data using a given radial basis function.\n        - Input: (1, N, in_features) where N is an arbitrary batch size\n        - Output: (1, N, out_features) where N is an arbitrary batch size'''\n\n    def __init__(self, in_features=3, out_features=1024):\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.centres = nn.Parameter(torch.Tensor(out_features, in_features))\n        self.sigmas = nn.Parameter(torch.Tensor(out_features))\n        self.reset_parameters()\n\n        self.freq = nn.Parameter(np.pi * torch.ones((1, self.out_features)))\n\n    def reset_parameters(self):\n        nn.init.uniform_(self.centres, -1, 1)\n        nn.init.constant_(self.sigmas, 10)\n\n    def forward(self, x):\n        #x = x[0, ...]\n        size = (x.size(0), self.out_features, self.in_features)\n        x = x.unsqueeze(1).expand(size)\n        c = self.centres.unsqueeze(0).expand(size)\n        distances = (x - c).pow(2).sum(-1) * self.sigmas.unsqueeze(0)\n        return self.gaussian(distances).unsqueeze(0)\n\n    def gaussian(self, alpha):\n        phi = torch.exp(-1 * alpha.pow(2))\n        return phi\n"
  },
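  {
    "path": "examples/rbf_usage.py",
    "content": "# Hypothetical usage sketch; not part of the original repository.\n# With the x = x[0, ...] line commented out above, RBFLayer.forward expects a 2D batch\n# of coordinates (N, in_features) and returns Gaussian RBF activations with a leading\n# singleton dimension, (1, N, out_features).\nimport torch\nfrom models.archs.encoders.rbf import RBFLayer\n\nlayer = RBFLayer(in_features=3, out_features=1024)\nxyz = torch.rand(100, 3) * 2 - 1  # 100 query coordinates\nphi = layer(xyz)                  # (1, 100, 1024)\n"
  },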
  {
    "path": "models/archs/encoders/sal_pointnet.py",
    "content": "#!/usr/bin/env python3\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\n\nimport json\nimport sys\nsys.path.append(\"..\")\nimport utils # actual dir is ../utils\n\n# pointnet from SAL paper: https://github.com/matanatz/SAL/blob/master/code/model/network.py#L14\nclass SalPointNet(nn.Module):\n    ''' PointNet-based encoder network. Based on: https://github.com/autonomousvision/occupancy_networks\n    Args:\n        c_dim (int): dimension of latent code c\n        dim (int): input points dimension\n        hidden_dim (int): hidden dimension of the network\n    '''\n\n    def __init__(self, c_dim=256, in_dim=3, hidden_dim=128):\n        super().__init__()\n        self.c_dim = c_dim\n\n        self.fc_pos = nn.Linear(in_dim, 2*hidden_dim)\n        self.fc_0 = nn.Linear(2*hidden_dim, hidden_dim)\n        self.fc_1 = nn.Linear(2*hidden_dim, hidden_dim)\n        self.fc_2 = nn.Linear(2*hidden_dim, hidden_dim)\n        self.fc_3 = nn.Linear(2*hidden_dim, hidden_dim)\n        self.fc_mean = nn.Linear(hidden_dim, c_dim)\n        self.fc_std = nn.Linear(hidden_dim, c_dim)\n        torch.nn.init.constant_(self.fc_mean.weight,0)\n        torch.nn.init.constant_(self.fc_mean.bias, 0)\n\n        torch.nn.init.constant_(self.fc_std.weight, 0)\n        torch.nn.init.constant_(self.fc_std.bias, -10)\n\n        self.actvn = nn.ReLU()\n        self.pool = self.maxpool\n\n    def forward(self, p):\n        net = self.fc_pos(p)\n        net = self.fc_0(self.actvn(net))\n        pooled = self.pool(net, dim=1, keepdim=True).expand(net.size())\n        net = torch.cat([net, pooled], dim=2)\n\n        net = self.fc_1(self.actvn(net))\n        pooled = self.pool(net, dim=1, keepdim=True).expand(net.size())\n        net = torch.cat([net, pooled], dim=2)\n\n        net = self.fc_2(self.actvn(net))\n        pooled = self.pool(net, dim=1, keepdim=True).expand(net.size())\n        net = torch.cat([net, pooled], dim=2)\n\n        net = self.fc_3(self.actvn(net))\n\n        net = self.pool(net, dim=1)\n\n        c_mean = self.fc_mean(self.actvn(net))\n        c_std = self.fc_std(self.actvn(net))\n\n        return c_mean,c_std\n\n    def maxpool(self, x, dim=-1, keepdim=False):\n        out, _ = x.max(dim=dim, keepdim=keepdim)\n        return out\n"
  },
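  {
    "path": "examples/sal_pointnet_usage.py",
    "content": "# Hypothetical usage sketch; not part of the original repository.\n# SalPointNet is a variational PointNet encoder: it max-pools point features and returns\n# per-shape mean and std heads (fc_std starts at a large negative bias). Run from the\n# repository root so the utils import at the top of sal_pointnet.py resolves.\nimport torch\nfrom models.archs.encoders.sal_pointnet import SalPointNet\n\nenc = SalPointNet(c_dim=256)\npc = torch.rand(2, 1024, 3)  # (batch, num_points, xyz)\nc_mean, c_std = enc(pc)      # each (2, 256)\n"
  },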
  {
    "path": "models/archs/encoders/vanilla_pointnet.py",
    "content": "#!/usr/bin/env python3\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nfrom torchmeta.modules import (MetaModule, MetaSequential, MetaLinear, MetaConv1d, MetaBatchNorm1d)\n\nimport json\nimport sys\nsys.path.append(\"..\")\nimport utils # actual dir is ../utils\n\nclass PointNet(MetaModule):\n    def __init__(self, latent_size):\n        super().__init__()\n        self.latent_size = latent_size\n        \n        self.conv1 = MetaConv1d(3, 64, kernel_size=1, bias=False)\n        self.conv2 = MetaConv1d(64, 64, kernel_size=1, bias=False)\n        self.conv3 = MetaConv1d(64, 64, kernel_size=1, bias=False)\n        self.conv4 = MetaConv1d(64, 128, kernel_size=1, bias=False)\n        self.conv5 = MetaConv1d(128, self.latent_size, kernel_size=1, bias=False)\n        self.bn1 = MetaBatchNorm1d(64)\n        self.bn2 = MetaBatchNorm1d(64)\n        self.bn3 = MetaBatchNorm1d(64)\n        self.bn4 = MetaBatchNorm1d(128)\n        self.bn5 = MetaBatchNorm1d(self.latent_size)\n\n        # self.layers = MetaSequential(\n        #     self.conv1, self.bn1, nn.ReLU(),\n        #     self.conv2, self.bn2, nn.ReLU(),\n        #     self.conv3, self.bn3, nn.ReLU(),\n        #     self.conv4, self.bn4, nn.ReLU(),\n        #     self.conv5, self.bn5,\n        #     )\n\n\n    # def forward(self, x, params=None):\n    #     x = self.layers(x, params=self.get_subdict(params, 'pointnet'))\n    #     x = x.max(dim=2, keepdim=False)[0]\n    #     return x\n\n    def forward(self, x, params=None):\n        x = F.relu(self.bn1(self.conv1(x, params)))\n        x = F.relu(self.bn2(self.conv2(x, params)))\n        x = F.relu(self.bn3(self.conv3(x, params)))\n        x = F.relu(self.bn4(self.conv4(x, params)))\n        x = self.bn5(self.conv5(x, params))\n        x = x.max(dim=2, keepdim=False)[0]\n        return x\n"
  },
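  {
    "path": "examples/vanilla_pointnet_usage.py",
    "content": "# Hypothetical usage sketch; not part of the original repository.\n# This PointNet variant is built from torchmeta Meta-modules; called with the default\n# params=None, as here, it behaves like a plain PointNet with a channel-wise max pool.\n# Requires torchmeta to be installed and the repo root on sys.path for its utils import.\nimport torch\nfrom models.archs.encoders.vanilla_pointnet import PointNet\n\nmodel = PointNet(latent_size=256)\nx = torch.rand(4, 3, 1024)  # (batch, xyz, num_points), channels-first for Conv1d\nz = model(x)                # (4, 256) global shape latent\n"
  },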
  {
    "path": "models/archs/modulated_sdf.py",
    "content": "#!/usr/bin/env python3\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport json\nimport sys\nimport torch.nn.init as init\nimport numpy as np\n\n\n# no dropout or skip connections for now \n\nclass Layer(nn.Module):\n    def __init__(self, dim_in=512, dim_out=512, dim=512, dropout_prob=0.0, geo_init='first', activation='relu'):\n        super().__init__()\n\n        self.linear = nn.Linear(dim_in, dim_out)\n        if activation=='relu':\n            self.activation = nn.ReLU() \n        elif activation=='tanh':\n            self.activation = nn.Tanh() \n        else:\n            self.activation = nn.Identity()\n\n        #self.dropout = nn.Dropout(p=dropout_prob)\n\n        if geo_init == 'first':\n            init.normal_(self.linear.weight, mean=0.0, std=np.sqrt(2) / np.sqrt(dim))\n            init.constant_(self.linear.bias, 0.0)\n        elif geo_init == 'last':\n            init.normal_(self.linear.weight, mean=2 * np.sqrt(np.pi) / np.sqrt(dim), std=0.000001)\n            init.constant_(self.linear.bias, -0.5)\n\n    \n    def forward(self, x):\n        out = self.linear(x)\n        out = self.activation(out)\n        #out = self.dropout(out)\n\n        return out \n\n\n\nclass ModulatedMLP(nn.Module):\n    def __init__(self, latent_size=512, hidden_dim=512, num_layers=9, latent_in=True,\n                 skip_connection=[4], dropout_prob=0.0, pos_enc=False, pe_num_freq=5, tanh_act=False\n                 ):\n        super().__init__()\n\n        self.skip_connection = skip_connection # list of the indices of layer to add skip connection\n        self.hidden_dim = hidden_dim\n        self.pe_num_freq = pe_num_freq\n        self.pos_enc = pos_enc\n        self.latent_in = latent_in\n\n        #print(\"tanh act: \", tanh_act)\n        #print(\"latent in, skip layer: \", latent_in, skip_connection)\n\n        first_dim_in = 3\n\n        # must remove last tanh layer if using positional encoding \n        # posititional encoding\n        if pos_enc:\n            pe_func = []\n            for i in range(self.pe_num_freq):\n                pe_func.append(lambda data, freq=i: torch.sin(data * (2**i)))\n                pe_func.append(lambda data, freq=i: torch.cos(data * (2**i)))\n            self.pe_func = pe_func\n            first_dim_in = 3*pe_num_freq*2 \n\n\n        # latent code concatenated to coordinates as input to model\n        if latent_in:\n            num_modulations = hidden_dim\n            #num_modulations = hidden_dim * (num_layers - 1)\n            first_dim_in += latent_size # num_modulations\n            mod_act = nn.ReLU()\n\n        else: # use shifting instead of concatenation\n            # We modulate features at every *hidden* layer of the base network and\n            # therefore have dim_hidden * (num_layers - 1) modulations, since the last layer is not modulated\n            num_modulations = hidden_dim * (num_layers - 1)\n            mod_act = nn.Identity() \n        \n        #self.mod_net = nn.Sequential(nn.Linear(latent_size, num_modulations), mod_act)\n\n        layers = []\n        #print(\"index, dim in: \", end='')\n        for i in range(num_layers-1):\n            if i==0:\n                dim_in = first_dim_in\n            elif i in skip_connection:\n                dim_in = hidden_dim+3+latent_size #num_modulations+3+hidden_dim\n            else:\n                dim_in = hidden_dim\n\n            #print(i, dim_in, end = '; ')\n\n            layers.append(\n                Layer(\n                    
dim_in=dim_in,\n                    dim_out=hidden_dim,\n                    activation='relu',\n                    geo_init='first',\n                    dim=hidden_dim,\n                    dropout_prob=dropout_prob\n                )\n            )\n\n        self.net = nn.Sequential(*layers)\n        last_act = 'tanh' if tanh_act else 'identity'\n        self.last_layer = Layer(dim_in=hidden_dim,dim_out=1,activation=last_act,geo_init='last',dim=hidden_dim)\n\n\n    def pe_transform(self, data):\n        pe_data = torch.cat([f(data) for f in self.pe_func], dim=-1)\n        return pe_data\n    def forward(self, xyz, latent):\n        '''\n        xyz: B, 16000, 3 (query coordinates for predicting)\n        latent: B, 512 (latent vector from 3 gradient steps)\n        '''\n        #print(\"latent: \",latent.shape)\n        modulations = latent#self.mod_net(latent)\n        #print(\"mod size: \",modulations.shape, xyz.shape)\n        #print(\"latent size: \", latent.shape, modulations.shape) # B,512 and B,512\n\n        if self.pos_enc:\n            xyz = self.pe_transform(xyz)\n\n        x = xyz.clone()\n\n        if self.latent_in:\n        #    modulations = modulations.unsqueeze(-2).repeat(1,xyz.shape[1],1) # [B, 16000, 512] or [B, 16000, 256]\n            #print(\"repeated mod shape: \",modulations.shape)\n            x = torch.cat((x, modulations),dim=-1)\n\n        #print(\"input size: \", x.shape, xyz.shape) # [8, 16000, 515], [8, 16000, 3]\n\n        idx = 0\n\n        for i, layer in enumerate(self.net):\n\n            if i in self.skip_connection:\n                x = torch.cat(( x, torch.cat((xyz, modulations),dim=-1)), dim=-1)\n            \n            x = layer.linear(x)\n            if not self.latent_in:\n                shift = modulations[:, idx : idx + self.hidden_dim].unsqueeze(1)\n                x = x + shift\n                idx += self.hidden_dim\n            x = layer.activation(x)\n            #x = layer.dropout(x)\n\n        out = self.last_layer(x)\n\n        return out, modulations\n\n\n    \n    \n\n\n\n\n\n\n\n\n\n"
  },
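  {
    "path": "examples/modulated_sdf_usage.py",
    "content": "# Hypothetical usage sketch; not part of the original repository.\n# With latent_in=True, ModulatedMLP concatenates a latent code onto every query\n# coordinate. The per-point repeat inside forward is commented out above, so the code\n# here expands the latent before the call. Returns an SDF value per query plus the\n# modulations it used.\nimport torch\nfrom models.archs.modulated_sdf import ModulatedMLP\n\nmlp = ModulatedMLP(latent_size=512, hidden_dim=512, num_layers=9, latent_in=True)\n\nxyz = torch.rand(2, 16000, 3) * 2 - 1                      # query coordinates\nlatent = torch.rand(2, 512)                                # one code per shape\nlatent = latent.unsqueeze(1).expand(-1, xyz.shape[1], -1)  # (2, 16000, 512)\n\nsdf, mods = mlp(xyz, latent)  # sdf: (2, 16000, 1)\n"
  },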
  {
    "path": "models/archs/resnet_block.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n# Resnet Blocks\nclass ResnetBlockFC(nn.Module):\n    ''' Fully connected ResNet Block class.\n    Args:\n        size_in (int): input dimension\n        size_out (int): output dimension\n        size_h (int): hidden dimension\n    '''\n\n    def __init__(self, size_in, size_out=None, size_h=None):\n        super().__init__()\n        # Attributes\n        if size_out is None:\n            size_out = size_in\n\n        if size_h is None:\n            size_h = min(size_in, size_out)\n\n        self.size_in = size_in\n        self.size_h = size_h\n        self.size_out = size_out\n        # Submodules\n        self.fc_0 = nn.Linear(size_in, size_h)\n        self.fc_1 = nn.Linear(size_h, size_out)\n        self.actvn = nn.ReLU()\n\n        if size_in == size_out:\n            self.shortcut = None\n        else:\n            self.shortcut = nn.Linear(size_in, size_out, bias=False)\n        # Initialization\n        nn.init.zeros_(self.fc_1.weight)\n\n    def forward(self, x):\n        net = self.fc_0(self.actvn(x))\n        dx = self.fc_1(self.actvn(net))\n\n        if self.shortcut is not None:\n            x_s = self.shortcut(x)\n        else:\n            x_s = x\n\n        return x_s + dx"
  },
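  {
    "path": "examples/resnet_block_usage.py",
    "content": "# Hypothetical usage sketch; not part of the original repository.\n# ResnetBlockFC is a pre-activation fully connected residual block; fc_1 is\n# zero-initialized, so at initialization the block reduces to its shortcut.\nimport torch\nfrom models.archs.resnet_block import ResnetBlockFC\n\nblock = ResnetBlockFC(size_in=128, size_out=256)  # mismatched sizes add a linear shortcut\ny = block(torch.rand(8, 128))                     # (8, 256)\n"
  },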
  {
    "path": "models/archs/sdf_decoder.py",
    "content": "#!/usr/bin/env python3\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport torch.nn.init as init\nimport numpy as np\n\n\nclass SdfDecoder(nn.Module):\n    def __init__(self, latent_size=256, hidden_dim=512,\n                 skip_connection=True, tanh_act=False,\n                 geo_init=True, input_size=None\n                 ):\n        super().__init__()\n        self.latent_size = latent_size\n        self.input_size = latent_size+3 if input_size is None else input_size\n        self.skip_connection = skip_connection\n        self.tanh_act = tanh_act\n\n        skip_dim = hidden_dim+self.input_size if skip_connection else hidden_dim \n\n        self.block1 = nn.Sequential(\n            nn.Linear(self.input_size, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.ReLU(),\n        )\n\n        self.block2 = nn.Sequential(\n            nn.Linear(skip_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.ReLU(),\n            nn.Linear(hidden_dim, hidden_dim),\n            nn.ReLU(),\n        )\n\n\n        self.block3 = nn.Linear(hidden_dim, 1)\n\n        if geo_init:\n            for m in self.block3.modules():\n                if isinstance(m, nn.Linear):\n                    init.normal_(m.weight, mean=2 * np.sqrt(np.pi) / np.sqrt(hidden_dim), std=0.000001)\n                    init.constant_(m.bias, -0.5)\n\n            for m in self.block2.modules():\n                if isinstance(m, nn.Linear):\n                    init.normal_(m.weight, mean=0.0, std=np.sqrt(2) / np.sqrt(hidden_dim))\n                    init.constant_(m.bias, 0.0)\n\n            for m in self.block1.modules():\n                if isinstance(m, nn.Linear):\n                    init.normal_(m.weight, mean=0.0, std=np.sqrt(2) / np.sqrt(hidden_dim))\n                    init.constant_(m.bias, 0.0)\n\n\n    def forward(self, x):\n        '''\n        x: concatenated xyz and shape features, shape: B, N, D+3 \n        '''        \n        block1_out = self.block1(x)\n\n        # skip connection, concat \n        if self.skip_connection:\n            block2_in = torch.cat([x, block1_out], dim=-1) \n        else:\n            block2_in = block1_out\n\n        block2_out = self.block2(block2_in)\n\n        out = self.block3(block2_out)\n\n        if self.tanh_act:\n            out = nn.Tanh()(out)\n\n        return out\n"
  },
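  {
    "path": "examples/sdf_decoder_usage.py",
    "content": "# Hypothetical usage sketch; not part of the original repository.\n# SdfDecoder is two 4-layer ReLU blocks plus a final linear head, with an optional skip\n# connection feeding the raw input into the second block. It consumes per-query shape\n# features concatenated with xyz and predicts one signed distance per query; the random\n# feature tensor below stands in for, e.g., ConvPointnet output.\nimport torch\nfrom models.archs.sdf_decoder import SdfDecoder\n\ndecoder = SdfDecoder(latent_size=256, hidden_dim=512, skip_connection=True)\n\nxyz = torch.rand(2, 16000, 3) * 2 - 1\nfeat = torch.rand(2, 16000, 256)               # per-query shape features\nsdf = decoder(torch.cat([feat, xyz], dim=-1))  # (2, 16000, 1)\n"
  },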
  {
    "path": "models/archs/unet.py",
    "content": "'''\nCodes are from:\nhttps://github.com/jaxony/unet-pytorch/blob/master/model.py\n'''\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nfrom collections import OrderedDict\nfrom torch.nn import init\nimport numpy as np\n\ndef conv3x3(in_channels, out_channels, stride=1, \n            padding=1, bias=True, groups=1):    \n    return nn.Conv2d(\n        in_channels,\n        out_channels,\n        kernel_size=3,\n        stride=stride,\n        padding=padding,\n        bias=bias,\n        groups=groups)\n\ndef upconv2x2(in_channels, out_channels, mode='transpose'):\n    if mode == 'transpose':\n        return nn.ConvTranspose2d(\n            in_channels,\n            out_channels,\n            kernel_size=2,\n            stride=2)\n    else:\n        # out_channels is always going to be the same\n        # as in_channels\n        return nn.Sequential(\n            nn.Upsample(mode='bilinear', scale_factor=2),\n            conv1x1(in_channels, out_channels))\n\ndef conv1x1(in_channels, out_channels, groups=1):\n    return nn.Conv2d(\n        in_channels,\n        out_channels,\n        kernel_size=1,\n        groups=groups,\n        stride=1)\n\n\nclass DownConv(nn.Module):\n    \"\"\"\n    A helper Module that performs 2 convolutions and 1 MaxPool.\n    A ReLU activation follows each convolution.\n    \"\"\"\n    def __init__(self, in_channels, out_channels, pooling=True):\n        super(DownConv, self).__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.pooling = pooling\n\n        self.conv1 = conv3x3(self.in_channels, self.out_channels)\n        self.conv2 = conv3x3(self.out_channels, self.out_channels)\n\n        if self.pooling:\n            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)\n\n    def forward(self, x):\n        x = F.relu(self.conv1(x))\n        x = F.relu(self.conv2(x))\n        before_pool = x\n        if self.pooling:\n            x = self.pool(x)\n        return x, before_pool\n\n\nclass UpConv(nn.Module):\n    \"\"\"\n    A helper Module that performs 2 convolutions and 1 UpConvolution.\n    A ReLU activation follows each convolution.\n    \"\"\"\n    def __init__(self, in_channels, out_channels, \n                 merge_mode='concat', up_mode='transpose'):\n        super(UpConv, self).__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.merge_mode = merge_mode\n        self.up_mode = up_mode\n\n        self.upconv = upconv2x2(self.in_channels, self.out_channels, \n            mode=self.up_mode)\n\n        if self.merge_mode == 'concat':\n            self.conv1 = conv3x3(\n                2*self.out_channels, self.out_channels)\n        else:\n            # num of input channels to conv2 is same\n            self.conv1 = conv3x3(self.out_channels, self.out_channels)\n        self.conv2 = conv3x3(self.out_channels, self.out_channels)\n\n\n    def forward(self, from_down, from_up):\n        \"\"\" Forward pass\n        Arguments:\n            from_down: tensor from the encoder pathway\n            from_up: upconv'd tensor from the decoder pathway\n        \"\"\"\n        from_up = self.upconv(from_up)\n        if self.merge_mode == 'concat':\n            x = torch.cat((from_up, from_down), 1)\n        else:\n            x = from_up + from_down\n        x = F.relu(self.conv1(x))\n        x = F.relu(self.conv2(x))\n        return x\n\n\nclass UNet(nn.Module):\n    \"\"\" `UNet` 
class is based on https://arxiv.org/abs/1505.04597\n\n    The U-Net is a convolutional encoder-decoder neural network.\n    Contextual spatial information (from the decoding,\n    expansive pathway) about an input tensor is merged with\n    information representing the localization of details\n    (from the encoding, compressive pathway).\n\n    Modifications to the original paper:\n    (1) padding is used in 3x3 convolutions to prevent loss\n        of border pixels\n    (2) merging outputs does not require cropping due to (1)\n    (3) residual connections can be used by specifying\n        UNet(merge_mode='add')\n    (4) if non-parametric upsampling is used in the decoder\n        pathway (specified by upmode='upsample'), then an\n        additional 1x1 2d convolution occurs after upsampling\n        to reduce channel dimensionality by a factor of 2.\n        This channel halving happens with the convolution in\n        the tranpose convolution (specified by upmode='transpose')\n    \"\"\"\n\n    def __init__(self, num_classes, in_channels=3, depth=5, \n                 start_filts=64, up_mode='transpose', \n                 merge_mode='concat', **kwargs):\n        \"\"\"\n        Arguments:\n            in_channels: int, number of channels in the input tensor.\n                Default is 3 for RGB images.\n            depth: int, number of MaxPools in the U-Net.\n            start_filts: int, number of convolutional filters for the \n                first conv.\n            up_mode: string, type of upconvolution. Choices: 'transpose'\n                for transpose convolution or 'upsample' for nearest neighbour\n                upsampling.\n        \"\"\"\n        super(UNet, self).__init__()\n\n        if up_mode in ('transpose', 'upsample'):\n            self.up_mode = up_mode\n        else:\n            raise ValueError(\"\\\"{}\\\" is not a valid mode for \"\n                             \"upsampling. Only \\\"transpose\\\" and \"\n                             \"\\\"upsample\\\" are allowed.\".format(up_mode))\n    \n        if merge_mode in ('concat', 'add'):\n            self.merge_mode = merge_mode\n        else:\n            raise ValueError(\"\\\"{}\\\" is not a valid mode for\"\n                             \"merging up and down paths. 
\"\n                             \"Only \\\"concat\\\" and \"\n                             \"\\\"add\\\" are allowed.\".format(up_mode))\n\n        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'\n        if self.up_mode == 'upsample' and self.merge_mode == 'add':\n            raise ValueError(\"up_mode \\\"upsample\\\" is incompatible \"\n                             \"with merge_mode \\\"add\\\" at the moment \"\n                             \"because it doesn't make sense to use \"\n                             \"nearest neighbour to reduce \"\n                             \"depth channels (by half).\")\n\n        self.num_classes = num_classes\n        self.in_channels = in_channels\n        self.start_filts = start_filts\n        self.depth = depth\n\n        self.down_convs = []\n        self.up_convs = []\n\n        # create the encoder pathway and add to a list\n        for i in range(depth):\n            ins = self.in_channels if i == 0 else outs\n            outs = self.start_filts*(2**i)\n            pooling = True if i < depth-1 else False\n\n            down_conv = DownConv(ins, outs, pooling=pooling)\n            self.down_convs.append(down_conv)\n\n        # create the decoder pathway and add to a list\n        # - careful! decoding only requires depth-1 blocks\n        for i in range(depth-1):\n            ins = outs\n            outs = ins // 2\n            up_conv = UpConv(ins, outs, up_mode=up_mode,\n                merge_mode=merge_mode)\n            self.up_convs.append(up_conv)\n\n        # add the list of modules to current module\n        self.down_convs = nn.ModuleList(self.down_convs)\n        self.up_convs = nn.ModuleList(self.up_convs)\n\n        self.conv_final = conv1x1(outs, self.num_classes)\n\n        self.reset_params()\n\n    @staticmethod\n    def weight_init(m):\n        if isinstance(m, nn.Conv2d):\n            init.xavier_normal_(m.weight)\n            init.constant_(m.bias, 0)\n\n\n    def reset_params(self):\n        for i, m in enumerate(self.modules()):\n            self.weight_init(m)\n\n\n    def forward(self, x):\n        encoder_outs = []\n        # encoder pathway, save outputs for merging\n        for i, module in enumerate(self.down_convs):\n            x, before_pool = module(x)\n            encoder_outs.append(before_pool)\n        for i, module in enumerate(self.up_convs):\n            before_pool = encoder_outs[-(i+2)]\n            x = module(before_pool, x)\n        \n        # No softmax is used. This means you need to use\n        # nn.CrossEntropyLoss is your training script,\n        # as this module includes a softmax already.\n        x = self.conv_final(x)\n        return x\n\nif __name__ == \"__main__\":\n    \"\"\"\n    testing\n    \"\"\"\n    model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32)\n    print(model)\n    print(sum(p.numel() for p in model.parameters()))\n\n    reso = 176\n    x = np.zeros((1, 1, reso, reso))\n    x[:,:,int(reso/2-1), int(reso/2-1)] = np.nan\n    x = torch.FloatTensor(x)\n\n    out = model(x)\n    print('%f'%(torch.sum(torch.isnan(out)).detach().cpu().numpy()/(reso*reso)))\n    \n    # loss = torch.sum(out)\n    # loss.backward()"
  },
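  {
    "path": "examples/unet_usage.py",
    "content": "# Hypothetical usage sketch; not part of the original repository.\n# Mirrors how the tri-plane pipeline drives a UNet of this shape: feature maps go in\n# and come out at the same spatial resolution with channels mapped back to num_classes.\n# depth/start_filts below match the unet_kwargs defaults in conv_pointnet.py.\nimport torch\nfrom models.archs.unet import UNet\n\nunet = UNet(num_classes=256, in_channels=256, depth=4, start_filts=32, merge_mode='concat')\nx = torch.rand(2, 256, 64, 64)  # one stacked plane feature map\ny = unet(x)                     # (2, 256, 64, 64)\n"
  },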
  {
    "path": "models/autoencoder.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom einops import rearrange, reduce\n\nfrom typing import List, Callable, Union, Any, TypeVar, Tuple\nTensor = TypeVar('torch.tensor')\n\n\nclass BetaVAE(nn.Module):\n\n    num_iter = 0 # Global static variable to keep track of iterations\n\n    def __init__(self,\n                 in_channels: int,\n                 latent_dim: int,\n                 hidden_dims: List = None,\n                 kl_std=1.0,\n                 beta: int = 4,\n                 gamma:float = 10., \n                 max_capacity: int = 25,\n                 Capacity_max_iter: int = 1e5, # 10000 in default configs\n                 loss_type:str = 'B',\n                 **kwargs) -> None:\n        super(BetaVAE, self).__init__()\n\n        self.latent_dim = latent_dim\n        self.beta = beta\n        self.gamma = gamma\n        self.loss_type = loss_type\n        self.C_max = torch.Tensor([max_capacity])\n        self.C_stop_iter = Capacity_max_iter\n        self.in_channels = in_channels\n\n        self.kl_std = kl_std\n\n        #print(\"kl standard deviation: \", self.kl_std)\n\n        modules = []\n        if hidden_dims is None:\n            #hidden_dims = [32, 64, 128, 256, 512]\n            hidden_dims = [512, 512, 512, 512, 512]\n\n        self.hidden_dims = hidden_dims\n\n        # Build Encoder\n        for h_dim in hidden_dims:\n            modules.append(\n                nn.Sequential(\n                    nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),\n                    nn.BatchNorm2d(h_dim),\n                    nn.LeakyReLU())\n            )\n            in_channels = h_dim\n\n        self.encoder = nn.Sequential(*modules)\n        self.fc_mu = nn.Linear(hidden_dims[-1]*4, latent_dim)  # for plane features resolution 64x64, spatial resolution is 2x2 after the last encoder layer\n        self.fc_var = nn.Linear(hidden_dims[-1]*4, latent_dim) \n\n\n        # Build Decoder\n        modules = []\n\n        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4) \n\n        hidden_dims.reverse()\n\n        for i in range(len(hidden_dims) - 1):\n            modules.append(\n                nn.Sequential(\n                    nn.ConvTranspose2d(hidden_dims[i],\n                                    hidden_dims[i + 1],\n                                    kernel_size=3,\n                                    stride = 2,\n                                    padding=1,\n                                    output_padding=1),\n                    nn.BatchNorm2d(hidden_dims[i + 1]),\n                    nn.LeakyReLU())\n            )\n\n\n\n        self.decoder = nn.Sequential(*modules)\n\n        self.final_layer = nn.Sequential(\n                            nn.ConvTranspose2d(hidden_dims[-1],\n                                               hidden_dims[-1],\n                                               kernel_size=3,\n                                               stride=2,\n                                               padding=1,\n                                               output_padding=1),\n                            nn.BatchNorm2d(hidden_dims[-1]),\n                            nn.LeakyReLU(),\n                            nn.Conv2d(hidden_dims[-1], out_channels= self.in_channels, # changed from 3 to in_channels\n                                      kernel_size= 3, padding= 1),\n                            nn.Tanh())\n\n\n        #print(self)\n\n    def encode(self, 
enc_input: Tensor) -> List[Tensor]:\n        \"\"\"\n        Encodes the input by passing it through the encoder network\n        and returns the latent codes.\n        :param enc_input: (Tensor) Input tensor to encoder [B x D x resolution x resolution]\n        :return: (List[Tensor]) Mean and log-variance of the latent Gaussian\n        \"\"\"\n        result = self.encoder(enc_input)  # [B, D, 2, 2]\n        result = torch.flatten(result, start_dim=1) # ([B, D*4])\n\n        # Split the result into mu and var components\n        # of the latent Gaussian distribution\n        mu = self.fc_mu(result)\n        log_var = self.fc_var(result)\n\n        return [mu, log_var]\n\n    def decode(self, z: Tensor) -> Tensor:\n        '''\n        z: latent vector of shape (B, D); D equals this VAE's latent_dim\n        (three times the pointnet feature dim when built from combined_model.py)\n        '''\n        result = self.decoder_input(z) # ([B, D*4])\n        result = result.view(-1, int(result.shape[-1]/4), 2, 2)  # for plane features resolution 64x64, spatial resolution is 2x2 after the last encoder layer\n        result = self.decoder(result)\n        result = self.final_layer(result) # ([B, D, resolution, resolution])\n        return result\n\n    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:\n        \"\"\"\n        Reparameterization trick: sample z = mu + std * eps with eps ~ N(0, I).\n        :param mu: (Tensor) Mean of the latent Gaussian\n        :param logvar: (Tensor) Log-variance of the latent Gaussian\n        :return: (Tensor) Sampled latent code\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n    def forward(self, data: Tensor, **kwargs) -> Tensor:\n        mu, log_var = self.encode(data)\n        z = self.reparameterize(mu, log_var)\n        return [self.decode(z), data, mu, log_var, z]\n\n    # only using VAE loss\n    def loss_function(self,\n                      *args,\n                      **kwargs) -> dict:\n        self.num_iter += 1\n        recons = args[0]\n        data = args[1]\n        mu = args[2]\n        log_var = args[3]\n        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset\n        #print(\"recon, data shape: \", recons.shape, data.shape)\n        #recons_loss = F.mse_loss(recons, data)\n\n        # kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)\n\n        if self.kl_std == 'zero_mean':\n            latent = self.reparameterize(mu, log_var) \n            #print(\"latent shape: \", latent.shape) # (B, dim)\n            l2_size_loss = torch.sum(torch.norm(latent, dim=-1))\n            kl_loss = l2_size_loss / latent.shape[0]\n\n        else:\n            std = torch.exp(0.5 * log_var)\n            gt_dist = torch.distributions.normal.Normal( torch.zeros_like(mu), torch.ones_like(std)*self.kl_std )\n            sampled_dist = torch.distributions.normal.Normal( mu, std )\n            #gt_dist = normal_dist.sample(log_var.shape)\n            #print(\"gt dist shape: \", gt_dist.shape)\n\n            kl = torch.distributions.kl.kl_divergence(sampled_dist, gt_dist) # KL( N(mu, std) || N(0, kl_std) )\n            kl_loss = reduce(kl, 'b ... -> b (...)', 'mean').mean()\n\n        return kld_weight * kl_loss\n\n    def sample(self,\n               num_samples: int,\n               **kwargs) -> Tensor:\n        \"\"\"\n        Samples latent vectors from the prior and returns the decoded plane features.\n        :param num_samples: (Int) Number of samples\n        :return: (Tensor)\n        \"\"\"\n        z = torch.randn(num_samples, self.latent_dim)\n\n        z = z.cuda()\n\n        samples = self.decode(z)\n        return samples\n\n    def generate(self, x: Tensor, **kwargs) -> Tensor:\n        \"\"\"\n        Given an input x, returns the reconstruction\n        :param x: (Tensor) [B x C x H x W]\n        :return: (Tensor) [B x C x H x W]\n        \"\"\"\n\n        return self.forward(x)[0]\n\n    def get_latent(self, x):\n        '''\n        given input x, return the latent code\n        x:  [B x C x H x W]\n        return: [B x latent_dim]\n        '''\n        mu, log_var = self.encode(x)\n        z = self.reparameterize(mu, log_var)\n        return z "
  },
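  {
    "path": "examples/vae_roundtrip_sketch.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the training pipeline: round-trips random\ntriplane features through the BetaVAE exactly as models/combined_model.py\nconstructs it. feature_dim=256 and the kld_weight of 0.01 are placeholder\nvalues, not released settings.\n\"\"\"\nimport torch\nfrom models import *  # exports BetaVAE, as in models/combined_model.py\n\nfeature_dim = 256                 # assumed pointnet latent dim (set by specs in practice)\nmodulation_dim = feature_dim * 3  # VAE latent dim, mirroring combined_model.py\nvae = BetaVAE(in_channels=feature_dim * 3, latent_dim=modulation_dim,\n              hidden_dims=[modulation_dim] * 5, kl_std=0.25)\n\n# stand-in plane features: (B, D*3, 64, 64), the resolution the encoder expects\nx = torch.randn(2, feature_dim * 3, 64, 64)\nrecon, data, mu, log_var, z = vae(x)\nprint(recon.shape, z.shape)  # torch.Size([2, 768, 64, 64]) torch.Size([2, 768])\n\n# only the (weighted) KL term is used as the VAE loss in this codebase\nkl = vae.loss_function(recon, data, mu, log_var, z, M_N=0.01)\nprint(kl)\n"
  },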
  {
    "path": "models/combined_model.py",
    "content": "import torch\nimport torch.utils.data \nfrom torch.nn import functional as F\nimport pytorch_lightning as pl\n\n# add paths in model/__init__.py for new models\nfrom models import * \n\nclass CombinedModel(pl.LightningModule):\n    def __init__(self, specs):\n        super().__init__()\n        self.specs = specs\n\n        self.task = specs['training_task'] # 'combined' or 'modulation' or 'diffusion'\n\n        if self.task in ('combined', 'modulation'):\n            self.sdf_model = SdfModel(specs=specs) \n\n            feature_dim = specs[\"SdfModelSpecs\"][\"latent_dim\"] # latent dim of pointnet \n            modulation_dim = feature_dim*3 # latent dim of modulation\n            latent_std = specs.get(\"latent_std\", 0.25) # std of target gaussian distribution of latent space\n            hidden_dims = [modulation_dim, modulation_dim, modulation_dim, modulation_dim, modulation_dim]\n            self.vae_model = BetaVAE(in_channels=feature_dim*3, latent_dim=modulation_dim, hidden_dims=hidden_dims, kl_std=latent_std)\n\n        if self.task in ('combined', 'diffusion'):\n            self.diffusion_model = DiffusionModel(model=DiffusionNet(**specs[\"diffusion_model_specs\"]), **specs[\"diffusion_specs\"]) \n \n\n    def training_step(self, x, idx):\n\n        if self.task == 'combined':\n            return self.train_combined(x)\n        elif self.task == 'modulation':\n            return self.train_modulation(x)\n        elif self.task == 'diffusion':\n            return self.train_diffusion(x)\n        \n\n    def configure_optimizers(self):\n\n        if self.task == 'combined':\n            params_list = [\n                    { 'params': list(self.sdf_model.parameters()) + list(self.vae_model.parameters()), 'lr':self.specs['sdf_lr'] },\n                    { 'params': self.diffusion_model.parameters(), 'lr':self.specs['diff_lr'] }\n                ]\n        elif self.task == 'modulation':\n            params_list = [\n                    { 'params': self.parameters(), 'lr':self.specs['sdf_lr'] }\n                ]\n        elif self.task == 'diffusion':\n            params_list = [\n                    { 'params': self.parameters(), 'lr':self.specs['diff_lr'] }\n                ]\n\n        optimizer = torch.optim.Adam(params_list)\n        return {\n                \"optimizer\": optimizer,\n                # \"lr_scheduler\": {\n                # \"scheduler\": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=50000, threshold=0.0002, min_lr=1e-6, verbose=False),\n                # \"monitor\": \"total\"\n                # }\n        }\n\n\n    #-----------different training steps for sdf modulation, diffusion, combined----------\n\n    def train_modulation(self, x):\n\n        xyz = x['xyz'] # (B, N, 3)\n        gt = x['gt_sdf'] # (B, N)\n        pc = x['point_cloud'] # (B, 1024, 3)\n\n        # STEP 1: obtain reconstructed plane feature and latent code \n        plane_features = self.sdf_model.pointnet.get_plane_features(pc)\n        original_features = torch.cat(plane_features, dim=1)\n        out = self.vae_model(original_features) # out = [self.decode(z), input, mu, log_var, z]\n        reconstructed_plane_feature, latent = out[0], out[-1]\n\n        # STEP 2: pass recon back to GenSDF pipeline \n        pred_sdf = self.sdf_model.forward_with_plane_features(reconstructed_plane_feature, xyz)\n        \n        # STEP 3: losses for VAE and SDF\n        # we only use the KL loss for the VAE; no reconstruction loss\n        try:\n         
   vae_loss = self.vae_model.loss_function(*out, M_N=self.specs[\"kld_weight\"] )\n        except:\n            print(\"vae loss is nan at epoch {}...\".format(self.current_epoch))\n            return None # skips this batch\n\n        sdf_loss = F.l1_loss(pred_sdf.squeeze(), gt.squeeze(), reduction='none')\n        sdf_loss = reduce(sdf_loss, 'b ... -> b (...)', 'mean').mean()\n\n        loss = sdf_loss + vae_loss\n\n        loss_dict =  {\"sdf\": sdf_loss, \"vae\": vae_loss}\n        self.log_dict(loss_dict, prog_bar=True, enable_graph=False)\n\n        return loss\n\n\n    def train_diffusion(self, x):\n\n        self.train()\n\n        pc = x['point_cloud'] # (B, 1024, 3) or False if unconditional \n        latent = x['latent'] # (B, D)\n\n        # unconditional training if cond is None \n        cond = pc if self.specs['diffusion_model_specs']['cond'] else None \n\n        # diff_100 and 1000 loss refers to the losses when t<100 and 100<t<1000, respectively \n        # typically diff_100 approaches 0 while diff_1000 can still be relatively high\n        # visualizing loss curves can help with debugging if training is unstable\n        diff_loss, diff_100_loss, diff_1000_loss, pred_latent, perturbed_pc = self.diffusion_model.diffusion_model_from_latent(latent, cond=cond)\n\n        loss_dict =  {\n                        \"total\": diff_loss,\n                        \"diff100\": diff_100_loss, # note that this can appear as nan when the training batch does not have sampled timesteps < 100\n                        \"diff1000\": diff_1000_loss\n                    }\n        self.log_dict(loss_dict, prog_bar=True, enable_graph=False)\n\n        return diff_loss\n\n    # the first half is the same as \"train_sdf_modulation\"\n    # the reconstructed latent is used as input to the diffusion model, rather than loading latents from the dataloader as in \"train_diffusion\"\n    def train_combined(self, x):\n        xyz = x['xyz'] # (B, N, 3)\n        gt = x['gt_sdf'] # (B, N)\n        pc = x['point_cloud'] # (B, 1024, 3)\n\n        # STEP 1: obtain reconstructed plane feature for SDF and latent code for diffusion\n        plane_features = self.sdf_model.pointnet.get_plane_features(pc)\n        original_features = torch.cat(plane_features, dim=1)\n        #print(\"plane feat shape: \", feat.shape)\n        out = self.vae_model(original_features) # out = [self.decode(z), input, mu, log_var, z]\n        reconstructed_plane_feature, latent = out[0], out[-1] # [B, D*3, resolution, resolution], [B, D*3]\n\n        # STEP 2: pass recon back to GenSDF pipeline \n        pred_sdf = self.sdf_model.forward_with_plane_features(reconstructed_plane_feature, xyz)\n        \n        # STEP 3: losses for VAE and SDF \n        try:\n            vae_loss = self.vae_model.loss_function(*out, M_N=self.specs[\"kld_weight\"] )\n        except:\n            print(\"vae loss is nan at epoch {}...\".format(self.current_epoch))\n            return None # skips this batch\n        sdf_loss = F.l1_loss(pred_sdf.squeeze(), gt.squeeze(), reduction='none')\n        sdf_loss = reduce(sdf_loss, 'b ... 
-> b (...)', 'mean').mean()\n\n        # STEP 4: use latent as input to diffusion model\n        cond = pc if self.specs['diffusion_model_specs']['cond'] else None\n        diff_loss, diff_100_loss, diff_1000_loss, pred_latent, perturbed_pc = self.diffusion_model.diffusion_model_from_latent(latent, cond=cond)\n        \n        # STEP 5: use predicted / reconstructed latent to run SDF loss \n        generated_plane_feature = self.vae_model.decode(pred_latent)\n        generated_sdf_pred = self.sdf_model.forward_with_plane_features(generated_plane_feature, xyz)\n        generated_sdf_loss = F.l1_loss(generated_sdf_pred.squeeze(), gt.squeeze())\n\n        # surface weight could prioritize points closer to surface but we did not notice better results when using it \n        #surface_weight = torch.exp(-50 * torch.abs(gt))\n        #generated_sdf_loss = torch.mean( F.l1_loss(generated_sdf_pred, gt, reduction='none') * surface_weight )\n\n        # we did not experiment with using constants/weights for each loss (VAE loss is weighted using value in specs file)\n        # results could potentially improve with a grid search \n        loss = sdf_loss + vae_loss + diff_loss + generated_sdf_loss\n\n        loss_dict =  {\n                        \"total\": loss,\n                        \"sdf\": sdf_loss,\n                        \"vae\": vae_loss,\n                        \"diff\": diff_loss,\n                        # diff_100 and 1000 loss refers to the losses when t<100 and 100<t<1000, respectively \n                        # typically diff_100 approaches 0 while diff_1000 can still be relatively high\n                        # visualizing loss curves can help with debugging if training is unstable\n                        #\"diff100\": diff_100_loss, # note that this can sometimes appear as nan when the training batch does not have sampled timesteps < 100\n                        #\"diff1000\": diff_1000_loss,\n                        \"gensdf\": generated_sdf_loss,\n                    }\n        self.log_dict(loss_dict, prog_bar=True, enable_graph=False)\n\n        return loss"
  },
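  {
    "path": "examples/modulation_step_sketch.py",
    "content": "\"\"\"\nIllustrative sketch: the 'modulation' training step of CombinedModel run\nmanually on random tensors (self.log_dict is skipped since no Trainer is\nattached). All spec values below are placeholders; see the released config\nfolders for real settings.\n\"\"\"\nimport torch\nfrom models import *  # exports CombinedModel\n\nspecs = {\n    'training_task': 'modulation',\n    'sdf_lr': 1e-4,      # placeholder; read by configure_optimizers\n    'kld_weight': 0.01,  # placeholder; weights the VAE KL term\n    'SdfModelSpecs': {'hidden_dim': 512, 'latent_dim': 256},  # placeholder widths\n}\nmodel = CombinedModel(specs)\n\npc = torch.randn(2, 1024, 3)    # input point cloud\nxyz = torch.randn(2, 16000, 3)  # SDF query coordinates\n\n# STEP 1: point cloud -> triplane features -> VAE -> recon features + latent\nplanes = torch.cat(model.sdf_model.pointnet.get_plane_features(pc), dim=1)\nout = model.vae_model(planes)\nreconstructed_plane_feature, latent = out[0], out[-1]\n\n# STEP 2: predict SDF values at the queries from the reconstructed planes\npred_sdf = model.sdf_model.forward_with_plane_features(reconstructed_plane_feature, xyz)\nprint(pred_sdf.shape, latent.shape)  # expected: (2, 16000) and (2, 768)\n"
  },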
  {
    "path": "models/diff_np_if_torch_error.py",
    "content": "import math\nimport copy\nimport torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom inspect import isfunction\nfrom collections import namedtuple\nfrom functools import partial\n\nfrom einops import rearrange, reduce\nfrom einops.layers.torch import Rearrange\n\n#from model.diffusion.model import * \nfrom diff_utils.helpers import * \n\nimport numpy as np\nimport os\nfrom statistics import mean\nfrom tqdm.auto import tqdm\nimport open3d as o3d\n\n\n# constants\nModelPrediction =  namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])\n\n\nclass DiffusionModel(nn.Module):\n    def __init__(\n        self,\n        model,\n        timesteps = 1000, sampling_timesteps = None, beta_schedule = 'cosine',\n        sample_pc_size = 682, perturb_pc = None,  crop_percent=0.25,\n        loss_type = 'l2', objective = 'pred_x0', \n        data_scale = 1.0, data_shift = 0.0,\n        p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended\n        p2_loss_weight_k = 1,\n        ddim_sampling_eta = 1.\n    ):\n        super().__init__()\n\n        self.model = model\n        self.objective = objective\n\n        betas = linear_beta_schedule(timesteps) if beta_schedule == 'linear' else cosine_beta_schedule(timesteps)\n        alphas = 1. - betas\n        alphas_cumprod = torch.cumprod(alphas, axis=0)\n        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)\n\n        timesteps, = betas.shape\n        self.num_timesteps = int(timesteps)\n\n        self.pc_size = sample_pc_size\n        self.perturb_pc = perturb_pc \n        self.crop_percent = crop_percent\n        assert self.perturb_pc in [None, \"partial\", \"noisy\"]\n\n        self.loss_fn = F.l1_loss if loss_type=='l1' else F.mse_loss\n\n        # sampling related parameters\n        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training\n        assert self.sampling_timesteps <= timesteps\n        self.ddim_sampling_eta = ddim_sampling_eta\n        \n        # self.register_buffer('data_scale', torch.tensor(data_scale))\n        # self.register_buffer('data_shift', torch.tensor(data_shift))\n\n        # helper function to register buffer from float64 to float32\n        #register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))\n        register_buffer = lambda name, val: self.register_buffer(name, torch.from_numpy(val).to(torch.float32))\n        \n        np_alphas = alphas_cumprod.numpy()\n        np_alphas_prev = alphas_cumprod_prev.numpy()\n        betas = betas.numpy()\n        alphas = alphas.numpy()\n\n\n        # register_buffer('betas', betas)\n        # register_buffer('alphas_cumprod', alphas_cumprod)\n        # register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)\n        register_buffer('betas', betas)\n        register_buffer('alphas_cumprod', np_alphas)\n        register_buffer('alphas_cumprod_prev', np_alphas_prev)\n\n\n        register_buffer('sqrt_alphas_cumprod', np.sqrt(np_alphas))\n        register_buffer('sqrt_one_minus_alphas_cumprod', np.sqrt(1. - np_alphas))\n        register_buffer('log_one_minus_alphas_cumprod', np.log(1. - np_alphas))\n        register_buffer('sqrt_recip_alphas_cumprod', np.sqrt(1. / np_alphas))\n        register_buffer('sqrt_recipm1_alphas_cumprod', np.sqrt(1. 
/ np_alphas - 1))\n\n        # calculations for posterior q(x_{t-1} | x_t, x_0)\n        posterior_variance = betas * (1. - np_alphas_prev) / (1. - np_alphas)\n        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)\n        register_buffer('posterior_variance', posterior_variance)\n        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n        register_buffer('posterior_log_variance_clipped', np.log(posterior_variance.clip(min =1e-20)))\n        register_buffer('posterior_mean_coef1', betas * np.sqrt(np_alphas_prev) / (1. - np_alphas))\n        register_buffer('posterior_mean_coef2', (1. - np_alphas_prev) * np.sqrt(alphas) / (1. - np_alphas))\n\n\n\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        # register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))\n        # register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))\n        # register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))\n        # register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))\n        # register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))\n\n        # # calculations for posterior q(x_{t-1} | x_t, x_0)\n        # posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)\n        # # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)\n        # register_buffer('posterior_variance', posterior_variance)\n        # # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n        # register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))\n        # register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))\n        # register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. 
- alphas_cumprod))\n        \n        # calculate p2 reweighting\n        # register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)\n        register_buffer('p2_loss_weight', (p2_loss_weight_k + np_alphas / (1 - np_alphas)) ** -p2_loss_weight_gamma)\n\n    def predict_start_from_noise(self, x_t, t, noise):\n        return (\n            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -\n            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise\n        )\n\n    def predict_noise_from_start(self, x_t, t, x0):\n        return (\n            (x0 - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t) / \\\n            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)\n        )\n\n    @torch.no_grad()\n    def ddim_sample(self, dim, batch_size, noise=None, clip_denoised = True, traj=False, cond=None):\n        batch, device, total_timesteps, sampling_timesteps, eta, objective = batch_size, self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective\n        times = torch.linspace(0., total_timesteps, steps = sampling_timesteps + 2)[:-1]\n        times = list(reversed(times.int().tolist()))\n        time_pairs = list(zip(times[:-1], times[1:]))\n\n        traj = []\n\n        x_T = default(noise, torch.randn(batch, dim, device = device)) \n\n        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):\n            alpha = self.alphas_cumprod_prev[time]\n            alpha_next = self.alphas_cumprod_prev[time_next]\n\n            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)\n\n            model_input = (x_T, cond) if cond is not None else x_T\n            pred_noise, x_start, *_ = self.model_predictions(model_input, time_cond)\n\n            if clip_denoised:\n                x_start.clamp_(-1., 1.)\n\n            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()\n            c = ((1 - alpha_next) - sigma ** 2).sqrt()\n\n            noise = torch.randn_like(x_T) if time_next > 0 else 0.\n\n            x_T = x_start * alpha_next.sqrt() + \\\n                  c * pred_noise + \\\n                  sigma * noise\n            \n            traj.append(x_T.clone())\n    \n        if traj:\n            return x_T, traj\n        else:\n            return x_T\n\n    @torch.no_grad()\n    def sample(self, dim, batch_size, noise=None, clip_denoised = True, traj=False, cond=None):\n\n        batch, device, objective = batch_size, self.betas.device, self.objective\n\n        traj = []\n\n        x_T = default(noise, torch.randn(batch, dim, device = device))\n\n        for t in reversed(range(0, self.num_timesteps)):\n            \n            time_cond = torch.full((batch,), t, device = device, dtype = torch.long)\n\n            model_input = (x_T, cond) if cond is not None else x_T\n            pred_noise, x_start, *_ = self.model_predictions(model_input, time_cond)\n            if clip_denoised:\n                x_start.clamp_(-1., 1.)\n\n            model_mean, _, model_log_variance = self.q_posterior(x_start = x_start, x_t = x_T, t = time_cond)\n\n            noise = torch.randn_like(x_T) if t > 0 else 0. 
# no noise if t == 0\n\n            x_T = model_mean + (0.5 * model_log_variance).exp() * noise\n            \n            traj.append(x_T.clone())\n    \n        if traj:\n            return x_T, traj\n        else:\n            return x_T\n\n    def q_posterior(self, x_start, x_t, t):\n        posterior_mean = (\n            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +\n            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t\n        )\n        posterior_variance = extract(self.posterior_variance, t, x_t.shape)\n        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)\n        return posterior_mean, posterior_variance, posterior_log_variance_clipped\n\n\n    # \"nice property\": return x_t given x_0, noise, and timestep\n    def q_sample(self, x_start, t, noise=None):\n        \n        noise = default(noise, lambda: torch.randn_like(x_start))\n        #noise = torch.clamp(noise, min=-6.0, max=6.0)\n\n        return (\n            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +\n            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise\n        )\n\n    # main function for calculating loss\n    def forward(self, x_start, t, ret_pred_x=False, noise = None, cond=None):\n        '''\n        x_start: [B, D]\n        t: [B]\n        '''\n\n        noise = default(noise, lambda: torch.randn_like(x_start)) \n\n        x = self.q_sample(x_start=x_start, t=t, noise=noise) \n\n        model_in = (x, cond) if cond is not None else x\n        model_out = self.model(model_in, t)\n\n        if self.objective == 'pred_noise':\n            target = noise\n        elif self.objective == 'pred_x0':\n            target = x_start\n        else:\n            raise ValueError(f'unknown objective {self.objective}')\n\n        loss = self.loss_fn(model_out, target, reduction = 'none')\n        #loss = reduce(loss, 'b ... 
-> b (...)', 'mean', b = x_start.shape[0]) # only one dim of latent so don't need this line \n        \n        loss = loss * extract(self.p2_loss_weight, t, loss.shape)\n        unreduced_loss = loss.detach().clone().mean(dim=1)\n        \n        if ret_pred_x:\n            return loss.mean(), x, target, model_out, unreduced_loss\n        else:\n            return loss.mean(), unreduced_loss\n\n    def model_predictions(self, model_input, t):\n        \n        #model_output1 = self.model(model_input, t, pass_cond=0)\n        #model_output2 = self.model(model_input, t, pass_cond=1)\n        #model_output = model_output2*5 - model_output1*4\n        model_output = self.model(model_input, t, pass_cond=1)\n\n        x = model_input[0] if type(model_input) is tuple else model_input\n\n        if self.objective == 'pred_noise':\n            pred_noise = model_output\n            x_start = self.predict_start_from_noise(x, t, model_output)\n\n        elif self.objective == 'pred_x0':\n            pred_noise = self.predict_noise_from_start(x, t, model_output)\n            x_start = model_output\n\n        return ModelPrediction(pred_noise, x_start)\n\n\n    \n    # a wrapper function that only takes x_start (clean modulation vector) and condition\n    # does everything including sampling timestep and returns loss, loss_100, loss_1000, prediction\n    def diffusion_model_from_latent(self, x_start, cond=None):\n        #if self.perturb_pc is None and cond is not None:\n        #    print(\"check whether to pass condition!!!\")\n\n        # STEP 1: sample timestep \n        t = torch.randint(0, self.num_timesteps, (x_start.shape[0],), device=x_start.device).long()\n\n        # STEP 2: perturb condition\n        pc = perturb_point_cloud(cond, self.perturb_pc, self.pc_size, self.crop_percent) if cond is not None else None\n\n        # STEP 3: pass to forward function\n        loss, x, target, model_out, unreduced_loss = self(x_start, t, cond=pc, ret_pred_x=True)\n        loss_100 = unreduced_loss[t<100].mean().detach()\n        loss_1000 = unreduced_loss[t>100].mean().detach()\n\n        return loss, loss_100, loss_1000, model_out, pc\n\n\n    def generate_from_pc(self, pc, load_pc=False, batch=5, save_pc=False, return_pc=False, ddim=False, perturb_pc=True):\n        self.eval()\n\n        with torch.no_grad():\n            if load_pc:\n                pc = sample_pc(pc, self.pc_size).cuda().unsqueeze(0)\n\n            if pc is None:\n                input_pc = None\n                save_pc = False\n                full_perturbed_pc = None\n\n            else:\n                if perturb_pc:\n                    full_perturbed_pc = perturb_point_cloud(pc, self.perturb_pc)\n                    perturbed_pc = full_perturbed_pc[:, torch.randperm(full_perturbed_pc.shape[1])[:self.pc_size] ]\n                    input_pc = perturbed_pc.repeat(batch, 1, 1)\n                else:\n                    full_perturbed_pc = pc\n                    perturbed_pc = pc\n                    input_pc = pc.repeat(batch, 1, 1)\n\n            #print(\"shapes: \", pc.shape, self.pc_size, self.perturb_pc, perturbed_pc.shape, full_perturbed_pc.shape)\n            #print(\"pc path: \", pc_path)\n\n            #print(\"pc shape: \", perturbed_pc.shape, input_pc.shape)\n            if save_pc: # save perturbed pc ply file for visualization\n                pcd = o3d.geometry.PointCloud()\n                pcd.points = o3d.utility.Vector3dVector(perturbed_pc.cpu().numpy().squeeze())\n                
o3d.io.write_point_cloud(\"{}/input_pc.ply\".format(save_pc), pcd)\n            \n            sample_fn = self.ddim_sample if ddim else self.sample\n            samp,_ = sample_fn(dim=self.model.dim_in_out, batch_size=batch, traj=False, cond=input_pc)\n\n        if return_pc:\n            return samp, perturbed_pc\n        return samp\n\n    def generate_unconditional(self, num_samples):\n        self.eval()\n        with torch.no_grad():\n            samp,_ = self.sample(dim=self.model.dim_in_out, batch_size=num_samples, traj=False, cond=None)\n\n        return samp\n"
  },
  {
    "path": "models/diffusion.py",
    "content": "import math\nimport copy\nimport torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom inspect import isfunction\nfrom collections import namedtuple\nfrom functools import partial\n\nfrom einops import rearrange, reduce\nfrom einops.layers.torch import Rearrange\n\n#from model.diffusion.model import * \nfrom diff_utils.helpers import * \n\nimport numpy as np\nimport os\nfrom statistics import mean\nfrom tqdm.auto import tqdm\nimport open3d as o3d\n\n\n# constants\nModelPrediction =  namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])\n\n\nclass DiffusionModel(nn.Module):\n    def __init__(\n        self,\n        model,\n        timesteps = 1000, sampling_timesteps = None, beta_schedule = 'cosine',\n        sample_pc_size = 682, perturb_pc = None,  crop_percent=0.25,\n        loss_type = 'l2', objective = 'pred_x0', \n        data_scale = 1.0, data_shift = 0.0,\n        p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended\n        p2_loss_weight_k = 1,\n        ddim_sampling_eta = 1.\n    ):\n        super().__init__()\n\n        self.model = model\n        self.objective = objective\n\n        betas = linear_beta_schedule(timesteps) if beta_schedule == 'linear' else cosine_beta_schedule(timesteps)\n        alphas = 1. - betas\n        alphas_cumprod = torch.cumprod(alphas, axis=0)\n        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)\n\n        timesteps, = betas.shape\n        self.num_timesteps = int(timesteps)\n\n        self.pc_size = sample_pc_size\n        self.perturb_pc = perturb_pc \n        self.crop_percent = crop_percent\n        assert self.perturb_pc in [None, \"partial\", \"noisy\"]\n\n        self.loss_fn = F.l1_loss if loss_type=='l1' else F.mse_loss\n\n        # sampling related parameters\n        self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training\n        assert self.sampling_timesteps <= timesteps\n        self.ddim_sampling_eta = ddim_sampling_eta\n        \n        # self.register_buffer('data_scale', torch.tensor(data_scale))\n        # self.register_buffer('data_shift', torch.tensor(data_shift))\n\n        # helper function to register buffer from float64 to float32\n        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))\n        \n\n        register_buffer('betas', betas)\n        register_buffer('alphas_cumprod', alphas_cumprod)\n        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))\n        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))\n        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))\n        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))\n        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))\n\n        # calculations for posterior q(x_{t-1} | x_t, x_0)\n        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)\n        # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t)\n        register_buffer('posterior_variance', posterior_variance)\n        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))\n        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))\n        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))\n        \n        # calculate p2 reweighting\n        register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma)\n\n    def predict_start_from_noise(self, x_t, t, noise):\n        return (\n            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -\n            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise\n        )\n\n    def predict_noise_from_start(self, x_t, t, x0):\n        return (\n            (x0 - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t) / \\\n            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)\n        )\n\n    @torch.no_grad()\n    def ddim_sample(self, dim, batch_size, noise=None, clip_denoised = True, traj=False, cond=None):\n        batch, device, total_timesteps, sampling_timesteps, eta, objective = batch_size, self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective\n        times = torch.linspace(0., total_timesteps, steps = sampling_timesteps + 2)[:-1]\n        times = list(reversed(times.int().tolist()))\n        time_pairs = list(zip(times[:-1], times[1:]))\n\n        traj = []\n\n        x_T = default(noise, torch.randn(batch, dim, device = device)) \n\n        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):\n            alpha = self.alphas_cumprod_prev[time]\n            alpha_next = self.alphas_cumprod_prev[time_next]\n\n            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)\n\n            model_input = (x_T, cond) if cond is not None else x_T\n            pred_noise, x_start, *_ = self.model_predictions(model_input, time_cond)\n\n            if clip_denoised:\n                x_start.clamp_(-1., 1.)\n\n            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()\n            c = ((1 - alpha_next) - sigma ** 2).sqrt()\n\n            noise = torch.randn_like(x_T) if time_next > 0 else 0.\n\n            x_T = x_start * alpha_next.sqrt() + \\\n                  c * pred_noise + \\\n                  sigma * noise\n            \n            traj.append(x_T.clone())\n    \n        if traj:\n            return x_T, traj\n        else:\n            return x_T\n\n    @torch.no_grad()\n    def sample(self, dim, batch_size, noise=None, clip_denoised = True, traj=False, cond=None):\n\n        batch, device, objective = batch_size, self.betas.device, self.objective\n\n        traj = []\n\n        x_T = default(noise, torch.randn(batch, dim, device = device))\n\n        for t in reversed(range(0, self.num_timesteps)):\n            \n            time_cond = torch.full((batch,), t, device = device, dtype = torch.long)\n\n            model_input = (x_T, cond) if cond is not None else x_T\n            pred_noise, x_start, *_ = self.model_predictions(model_input, time_cond)\n            if clip_denoised:\n                x_start.clamp_(-1., 1.)\n\n            
model_mean, _, model_log_variance = self.q_posterior(x_start = x_start, x_t = x_T, t = time_cond)\n\n            noise = torch.randn_like(x_T) if t > 0 else 0. # no noise if t == 0\n\n            x_T = model_mean + (0.5 * model_log_variance).exp() * noise\n            \n            traj.append(x_T.clone())\n    \n        if traj:\n            return x_T, traj\n        else:\n            return x_T\n\n    def q_posterior(self, x_start, x_t, t):\n        posterior_mean = (\n            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +\n            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t\n        )\n        posterior_variance = extract(self.posterior_variance, t, x_t.shape)\n        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)\n        return posterior_mean, posterior_variance, posterior_log_variance_clipped\n\n\n    # \"nice property\": return x_t given x_0, noise, and timestep\n    def q_sample(self, x_start, t, noise=None):\n        \n        noise = default(noise, lambda: torch.randn_like(x_start))\n        #noise = torch.clamp(noise, min=-6.0, max=6.0)\n\n        return (\n            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +\n            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise\n        )\n\n    # main function for calculating loss\n    def forward(self, x_start, t, ret_pred_x=False, noise = None, cond=None):\n        '''\n        x_start: [B, D]\n        t: [B]\n        '''\n\n        noise = default(noise, lambda: torch.randn_like(x_start)) \n\n        x = self.q_sample(x_start=x_start, t=t, noise=noise) \n\n        model_in = (x, cond) if cond is not None else x\n        model_out = self.model(model_in, t)\n\n        if self.objective == 'pred_noise':\n            target = noise\n        elif self.objective == 'pred_x0':\n            target = x_start\n        else:\n            raise ValueError(f'unknown objective {self.objective}')\n\n        loss = self.loss_fn(model_out, target, reduction = 'none')\n        #loss = reduce(loss, 'b ... 
-> b (...)', 'mean', b = x_start.shape[0]) # only one dim of latent so don't need this line \n        \n        loss = loss * extract(self.p2_loss_weight, t, loss.shape)\n        unreduced_loss = loss.detach().clone().mean(dim=1)\n        \n        if ret_pred_x:\n            return loss.mean(), x, target, model_out, unreduced_loss\n        else:\n            return loss.mean(), unreduced_loss\n\n    def model_predictions(self, model_input, t):\n        \n        #model_output1 = self.model(model_input, t, pass_cond=0)\n        #model_output2 = self.model(model_input, t, pass_cond=1)\n        #model_output = model_output2*5 - model_output1*4\n        model_output = self.model(model_input, t, pass_cond=1)\n\n        x = model_input[0] if type(model_input) is tuple else model_input\n\n        if self.objective == 'pred_noise':\n            pred_noise = model_output\n            x_start = self.predict_start_from_noise(x, t, model_output)\n\n        elif self.objective == 'pred_x0':\n            pred_noise = self.predict_noise_from_start(x, t, model_output)\n            x_start = model_output\n\n        return ModelPrediction(pred_noise, x_start)\n\n\n    \n    # a wrapper function that only takes x_start (clean modulation vector) and condition\n    # does everything including sampling timestep and returns loss, loss_100, loss_1000, prediction\n    def diffusion_model_from_latent(self, x_start, cond=None):\n        #if self.perturb_pc is None and cond is not None:\n        #    print(\"check whether to pass condition!!!\")\n\n        # STEP 1: sample timestep \n        t = torch.randint(0, self.num_timesteps, (x_start.shape[0],), device=x_start.device).long()\n\n        # STEP 2: perturb condition\n        pc = perturb_point_cloud(cond, self.perturb_pc, self.pc_size, self.crop_percent) if cond is not None else None\n\n        # STEP 3: pass to forward function\n        loss, x, target, model_out, unreduced_loss = self(x_start, t, cond=pc, ret_pred_x=True)\n        loss_100 = unreduced_loss[t<100].mean().detach()\n        loss_1000 = unreduced_loss[t>100].mean().detach()\n\n        return loss, loss_100, loss_1000, model_out, pc\n\n\n    def generate_from_pc(self, pc, load_pc=False, batch=5, save_pc=False, return_pc=False, ddim=False, perturb_pc=True):\n        self.eval()\n\n        with torch.no_grad():\n            if load_pc:\n                pc = sample_pc(pc, self.pc_size).cuda().unsqueeze(0)\n\n            if pc is None:\n                input_pc = None\n                save_pc = False\n                full_perturbed_pc = None\n\n            else:\n                if perturb_pc:\n                    full_perturbed_pc = perturb_point_cloud(pc, self.perturb_pc)\n                    perturbed_pc = full_perturbed_pc[:, torch.randperm(full_perturbed_pc.shape[1])[:self.pc_size] ]\n                    input_pc = perturbed_pc.repeat(batch, 1, 1)\n                else:\n                    full_perturbed_pc = pc\n                    perturbed_pc = pc\n                    input_pc = pc.repeat(batch, 1, 1)\n\n            #print(\"shapes: \", pc.shape, self.pc_size, self.perturb_pc, perturbed_pc.shape, full_perturbed_pc.shape)\n            #print(\"pc path: \", pc_path)\n\n            #print(\"pc shape: \", perturbed_pc.shape, input_pc.shape)\n            if save_pc: # save perturbed pc ply file for visualization\n                pcd = o3d.geometry.PointCloud()\n                pcd.points = o3d.utility.Vector3dVector(perturbed_pc.cpu().numpy().squeeze())\n                
o3d.io.write_point_cloud(\"{}/input_pc.ply\".format(save_pc), pcd)\n            \n            sample_fn = self.ddim_sample if ddim else self.sample\n            samp,_ = sample_fn(dim=self.model.dim_in_out, batch_size=batch, traj=False, cond=input_pc)\n\n        if return_pc:\n            return samp, perturbed_pc\n        return samp\n\n    def generate_unconditional(self, num_samples):\n        self.eval()\n        with torch.no_grad():\n            samp,_ = self.sample(dim=self.model.dim_in_out, batch_size=num_samples, traj=False, cond=None)\n\n        return samp"
  },
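  {
    "path": "examples/forward_noising_sketch.py",
    "content": "\"\"\"\nIllustrative sketch: the forward (noising) process q(x_t | x_0) that\nDiffusionModel.q_sample applies to clean modulation vectors, written out\nstandalone. Reuses cosine_beta_schedule from diff_utils.helpers; the latent\nsize of 768 is an assumption.\n\"\"\"\nimport torch\nfrom diff_utils.helpers import cosine_beta_schedule\n\ntimesteps = 1000\nbetas = cosine_beta_schedule(timesteps)\nalphas_cumprod = torch.cumprod(1. - betas, dim=0)\n\nx_start = torch.randn(4, 768)  # a batch of clean modulation vectors\nfor t in (0, 100, 999):\n    noise = torch.randn_like(x_start)\n    # same formula as DiffusionModel.q_sample\n    x_t = alphas_cumprod[t].sqrt() * x_start + (1. - alphas_cumprod[t]).sqrt() * noise\n    # the signal fraction shrinks toward 0 as t grows; with objective\n    # 'pred_x0' the network learns to recover x_start from x_t\n    print(t, alphas_cumprod[t].item())\n"
  },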
  {
    "path": "models/sdf_model.py",
    "content": "#!/usr/bin/env python3\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.functional as F\nimport pytorch_lightning as pl \n\nimport sys\nimport os \nfrom pathlib import Path\nimport numpy as np \nimport math\n\nfrom einops import rearrange, reduce\n\nfrom models.archs.sdf_decoder import * \nfrom models.archs.encoders.conv_pointnet import ConvPointnet\nfrom utils import mesh, evaluate\n\n\nclass SdfModel(pl.LightningModule):\n\n    def __init__(self, specs):\n        super().__init__()\n        \n        self.specs = specs\n        model_specs = self.specs[\"SdfModelSpecs\"]\n        self.hidden_dim = model_specs[\"hidden_dim\"]\n        self.latent_dim = model_specs[\"latent_dim\"]\n        self.skip_connection = model_specs.get(\"skip_connection\", True)\n        self.tanh_act = model_specs.get(\"tanh_act\", False)\n        self.pn_hidden = model_specs.get(\"pn_hidden_dim\", self.latent_dim)\n\n        self.pointnet = ConvPointnet(c_dim=self.latent_dim, hidden_dim=self.pn_hidden, plane_resolution=64)\n        \n        self.model = SdfDecoder(latent_size=self.latent_dim, hidden_dim=self.hidden_dim, skip_connection=self.skip_connection, tanh_act=self.tanh_act)\n        \n        self.model.train()\n        #print(self.model)\n\n\n    def configure_optimizers(self):\n\n        optimizer = torch.optim.Adam(self.parameters(), self.specs[\"sdf_lr\"])\n        return optimizer\n\n \n    def training_step(self, x, idx):\n\n        xyz = x['xyz'] # (B, 16000, 3)\n        gt = x['gt_sdf'] # (B, 16000)\n        pc = x['point_cloud'] # (B, 1024, 3)\n\n        shape_features = self.pointnet(pc, xyz)\n\n        pred_sdf = self.model(xyz, shape_features)\n\n        sdf_loss = F.l1_loss(pred_sdf.squeeze(), gt.squeeze(), reduction = 'none')\n        sdf_loss = reduce(sdf_loss, 'b ... -> b (...)', 'mean').mean()\n    \n        return sdf_loss \n            \n    \n\n    def forward(self, pc, xyz):\n        shape_features = self.pointnet(pc, xyz)\n\n        return self.model(xyz, shape_features).squeeze()\n\n    def forward_with_plane_features(self, plane_features, xyz):\n        '''\n        plane_features: B, D*3, res, res (e.g. B, 768, 64, 64)\n        xyz: B, N, 3\n        '''\n        point_features = self.pointnet.forward_with_plane_features(plane_features, xyz) # point_features: B, N, D\n        pred_sdf = self.model( torch.cat((xyz, point_features),dim=-1) )  \n        return pred_sdf # [B, num_points] \n"
  },
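  {
    "path": "examples/sdf_model_shapes_sketch.py",
    "content": "\"\"\"\nIllustrative sketch: shape flow through SdfModel with placeholder widths.\nThe pointnet lifts a point cloud to three 2D feature planes; the decoder\npredicts signed distances at query coordinates.\n\"\"\"\nimport torch\nfrom models.sdf_model import SdfModel\n\nspecs = {'SdfModelSpecs': {'hidden_dim': 512, 'latent_dim': 256}}  # placeholder values\nsdf_model = SdfModel(specs=specs)\n\npc = torch.randn(2, 1024, 3)    # input point cloud\nxyz = torch.randn(2, 16000, 3)  # query coordinates\npred_sdf = sdf_model(pc, xyz)   # expected: (2, 16000) signed distances\n\n# the three feature planes, each expected to be (B, latent_dim, 64, 64)\nplanes = sdf_model.pointnet.get_plane_features(pc)\nprint(pred_sdf.shape, [p.shape for p in planes])\n"
  },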
  {
    "path": "test.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.utils.data \nfrom torch.nn import functional as F\nimport pytorch_lightning as pl\nfrom pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor\nfrom pytorch_lightning import loggers as pl_loggers\n\nimport os\nimport json, csv\nimport time\nfrom tqdm.auto import tqdm\nfrom einops import rearrange, reduce\nimport numpy as np\nimport trimesh\nimport warnings\n\n# add paths in model/__init__.py for new models\nfrom models import * \nfrom utils import mesh, evaluate\nfrom utils.reconstruct import *\nfrom diff_utils.helpers import * \n#from metrics.evaluation_metrics import *#compute_all_metrics\n#from metrics import evaluation_metrics\n\nfrom dataloader.pc_loader import PCloader\n\n@torch.no_grad()\ndef test_modulations():\n    \n    # load dataset, dataloader, model checkpoint\n    test_split = json.load(open(specs[\"TestSplit\"]))\n    test_dataset = PCloader(specs[\"DataSource\"], test_split, pc_size=specs.get(\"PCsize\",1024), return_filename=True)\n    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, num_workers=0)\n\n    ckpt = \"{}.ckpt\".format(args.resume) if args.resume=='last' else \"epoch={}.ckpt\".format(args.resume)\n    resume = os.path.join(args.exp_dir, ckpt)\n    model = CombinedModel.load_from_checkpoint(resume, specs=specs).cuda().eval()\n\n    # filename for logging chamfer distances of reconstructed meshes\n    cd_file = os.path.join(recon_dir, \"cd.csv\") \n\n    with tqdm(test_dataloader) as pbar:\n        for idx, data in enumerate(pbar):\n            pbar.set_description(\"Files evaluated: {}/{}\".format(idx, len(test_dataloader)))\n\n            point_cloud, filename = data # filename = path to the csv file of sdf data\n            filename = filename[0] # filename is a tuple\n\n            cls_name = filename.split(\"/\")[-3]\n            mesh_name = filename.split(\"/\")[-2]\n            outdir = os.path.join(recon_dir, \"{}/{}\".format(cls_name, mesh_name))\n            os.makedirs(outdir, exist_ok=True)\n            mesh_filename = os.path.join(outdir, \"reconstruct\")\n            \n            # given point cloud, create modulations (e.g. 
1D latent vectors)\n            plane_features = model.sdf_model.pointnet.get_plane_features(point_cloud.cuda())  # tuple, 3 items with ([1, D, resolution, resolution])\n            plane_features = torch.cat(plane_features, dim=1) # ([1, D*3, resolution, resolution])\n            recon = model.vae_model.generate(plane_features) # ([1, D*3, resolution, resolution])\n            #print(\"mesh filename: \", mesh_filename)\n            # N is the grid resolution for marching cubes; set max_batch to largest number gpu can hold\n            mesh.create_mesh(model.sdf_model, recon, mesh_filename, N=256, max_batch=2**21, from_plane_features=True)\n\n            # load the created mesh (mesh_filename), and compare with input point cloud\n            # to calculate and log chamfer distance \n            mesh_log_name = cls_name+\"/\"+mesh_name\n            try:\n                evaluate.main(point_cloud, mesh_filename, cd_file, mesh_log_name)\n            except Exception as e:\n                print(e)\n\n\n            # save modulation vectors for training diffusion model for next stage\n            # filter based on the chamfer distance so that all training data for diffusion model is clean \n            # would recommend visualizing some reconstructed meshes and manually determining what chamfer distance threshold to use\n            try:\n                # skips modulations that have chamfer distance > 0.0018\n                # the filter also weighs gaps / empty space higher\n                if not filter_threshold(mesh_filename, point_cloud, 0.0018): \n                    continue\n                outdir = os.path.join(latent_dir, \"{}/{}\".format(cls_name, mesh_name))\n                os.makedirs(outdir, exist_ok=True)\n                features = model.sdf_model.pointnet.get_plane_features(point_cloud.cuda())\n                features = torch.cat(features, dim=1) # ([1, D*3, resolution, resolution])\n                latent = model.vae_model.get_latent(features) # (1, D*3)\n                np.savetxt(os.path.join(outdir, \"latent.txt\"), latent.cpu().numpy())\n            except Exception as e:\n                print(e)\n\n\n           \n@torch.no_grad()\ndef test_generation():\n\n    # load model \n    if args.resume == 'finetune': # after second stage of training \n        with warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n\n            # loads the sdf and vae models\n            model = CombinedModel.load_from_checkpoint(specs[\"modulation_ckpt_path\"], specs=specs, strict=False) \n\n            # loads the diffusion model; directly calling diffusion_model.load_state_dict to prevent overwriting sdf and vae params\n            ckpt = torch.load(specs[\"diffusion_ckpt_path\"])\n            new_state_dict = {}\n            for k,v in ckpt['state_dict'].items():\n                new_key = k.replace(\"diffusion_model.\", \"\") # remove \"diffusion_model.\" from keys since directly loading into diffusion model\n                new_state_dict[new_key] = v\n            model.diffusion_model.load_state_dict(new_state_dict)\n\n            model = model.cuda().eval()\n    else:\n        ckpt = \"{}.ckpt\".format(args.resume) if args.resume=='last' else \"epoch={}.ckpt\".format(args.resume)\n        resume = os.path.join(args.exp_dir, ckpt)\n        model = CombinedModel.load_from_checkpoint(resume, specs=specs).cuda().eval()\n\n    conditional = specs[\"diffusion_model_specs\"][\"cond\"] \n\n    if not conditional:\n        samples = 
model.diffusion_model.generate_unconditional(args.num_samples)\n        plane_features = model.vae_model.decode(samples)\n        for i in range(len(plane_features)):\n            plane_feature = plane_features[i].unsqueeze(0)\n            mesh.create_mesh(model.sdf_model, plane_feature, recon_dir+\"/{}_recon\".format(i), N=128, max_batch=2**21, from_plane_features=True)\n            \n    else:\n        # load dataset, dataloader, model checkpoint\n        test_split = json.load(open(specs[\"TestSplit\"]))\n        test_dataset = PCloader(specs[\"DataSource\"], test_split, pc_size=specs.get(\"PCsize\",1024), return_filename=True)\n        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, num_workers=0)\n\n        with tqdm(test_dataloader) as pbar:\n            for idx, data in enumerate(pbar):\n                pbar.set_description(\"Files generated: {}/{}\".format(idx, len(test_dataloader)))\n\n                point_cloud, filename = data # filename = path to the csv file of sdf data\n                filename = filename[0] # filename is a tuple\n\n                cls_name = filename.split(\"/\")[-3]\n                mesh_name = filename.split(\"/\")[-2]\n                outdir = os.path.join(recon_dir, \"{}/{}\".format(cls_name, mesh_name))\n                os.makedirs(outdir, exist_ok=True)\n\n                # filter, set threshold manually after a few visualizations\n                if args.filter:\n                    threshold = 0.08\n                    tmp_lst = []\n                    count = 0\n                    while len(tmp_lst)<args.num_samples:\n                        count+=1\n                        samples, perturbed_pc = model.diffusion_model.generate_from_pc(point_cloud.cuda(), batch=args.num_samples, save_pc=outdir, return_pc=True) # batch should be set to max number GPU can hold\n                        plane_features = model.vae_model.decode(samples)\n                        # predicting the sdf values of the point cloud\n                        perturbed_pc_pred = model.sdf_model.forward_with_plane_features(plane_features, perturbed_pc.repeat(args.num_samples, 1, 1))\n                        consistency = F.l1_loss(perturbed_pc_pred, torch.zeros_like(perturbed_pc_pred), reduction='none')\n                        loss = reduce(consistency, 'b ... 
-> b', 'mean', b = consistency.shape[0]) # one value per generated sample \n                        #print(\"consistency shape: \", consistency.shape, loss.shape, consistency[0].mean(), consistency[1].mean(), loss) # cons: [B,N]; loss: [B]\n                        thresh_idx = loss<=threshold\n                        tmp_lst.extend(plane_features[thresh_idx])\n\n                        if count > 5: # repeat this filtering process as needed \n                            break\n                    # skip the point cloud if cannot produce consistent samples or \n                    # just use the samples that are produced if comparing to other methods\n                    if len(tmp_lst)<1: \n                        continue\n                    plane_features = tmp_lst[0:min(10,len(tmp_lst))]\n\n                else:\n                    # for each point cloud, the partial pc and its conditional generations are all saved in the same directory \n                    samples, perturbed_pc = model.diffusion_model.generate_from_pc(point_cloud.cuda(), batch=args.num_samples, save_pc=outdir, return_pc=True)\n                    plane_features = model.vae_model.decode(samples)\n                \n                for i in range(len(plane_features)):\n                    plane_feature = plane_features[i].unsqueeze(0)\n                    mesh.create_mesh(model.sdf_model, plane_feature, outdir+\"/{}_recon\".format(i), N=128, max_batch=2**21, from_plane_features=True)\n            \n\n\n    \nif __name__ == \"__main__\":\n\n    import argparse\n\n    arg_parser = argparse.ArgumentParser()\n    arg_parser.add_argument(\n        \"--exp_dir\", \"-e\", required=True,\n        help=\"This directory should include experiment specifications in 'specs.json,' and logging will be done in this directory as well.\",\n    )\n    arg_parser.add_argument(\n        \"--resume\", \"-r\", default=None,\n        help=\"continue from previous saved logs, integer value, 'last', or 'finetune'\",\n    )\n\n    arg_parser.add_argument(\"--num_samples\", \"-n\", default=5, type=int, help='number of samples to generate and reconstruct')\n\n    arg_parser.add_argument(\"--filter\", default=False, help='whether to filter when sampling conditionally')\n\n    args = arg_parser.parse_args()\n    specs = json.load(open(os.path.join(args.exp_dir, \"specs.json\")))\n    print(specs[\"Description\"])\n\n\n    recon_dir = os.path.join(args.exp_dir, \"recon\")\n    os.makedirs(recon_dir, exist_ok=True)\n    \n    if specs['training_task'] == 'modulation':\n        latent_dir = os.path.join(args.exp_dir, \"modulations\")\n        os.makedirs(latent_dir, exist_ok=True)\n        test_modulations()\n    elif specs['training_task'] == 'combined':\n        test_generation()\n\n  \n"
  },
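  {
    "path": "examples/consistency_filter_sketch.py",
    "content": "\"\"\"\nIllustrative sketch of the consistency filter in test.py's test_generation():\nfor each generated latent, the SDF predicted at the (perturbed) input point\ncloud should be near zero, so samples whose mean |SDF| exceeds a threshold\nare discarded. Random tensors stand in for real predictions here.\n\"\"\"\nimport torch\nfrom torch.nn import functional as F\nfrom einops import reduce\n\nthreshold = 0.08  # set manually after visualizing a few generations\n# stand-in for model.sdf_model.forward_with_plane_features(...) outputs: (B, N)\nperturbed_pc_pred = 0.1 * torch.randn(5, 1024)\n\nconsistency = F.l1_loss(perturbed_pc_pred, torch.zeros_like(perturbed_pc_pred), reduction='none')\nloss = reduce(consistency, 'b ... -> b', 'mean')  # one value per generated sample\nkeep = loss <= threshold\nprint(loss, keep)  # only samples passing the filter are meshed\n"
  },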
  {
    "path": "train.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.utils.data \nfrom torch.nn import functional as F\nimport pytorch_lightning as pl\nfrom pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor\nfrom pytorch_lightning import loggers as pl_loggers\n\nimport os\nimport json, csv\nimport time\nfrom tqdm.auto import tqdm\nfrom einops import rearrange, reduce\nimport numpy as np\nimport trimesh\nimport warnings\n\n# add paths in model/__init__.py for new models\nfrom models import * \nfrom utils import mesh, evaluate\nfrom utils.reconstruct import *\nfrom diff_utils.helpers import * \n#from metrics.evaluation_metrics import *#compute_all_metrics\n#from metrics import evaluation_metrics\n\nfrom dataloader.pc_loader import PCloader\nfrom dataloader.sdf_loader import SdfLoader\nfrom dataloader.modulation_loader import ModulationLoader\n\n\ndef train():\n    \n    # initialize dataset and loader\n    split = json.load(open(specs[\"TrainSplit\"], \"r\"))\n    if specs['training_task'] == 'diffusion':\n        train_dataset = ModulationLoader(specs[\"data_path\"], pc_path=specs.get(\"pc_path\",None), split_file=split, pc_size=specs.get(\"total_pc_size\", None))\n    else:\n        train_dataset = SdfLoader(specs[\"DataSource\"], split, pc_size=specs.get(\"PCsize\",1024), grid_source=specs.get(\"GridSource\", None), modulation_path=specs.get(\"modulation_path\", None))\n    train_dataloader = torch.utils.data.DataLoader(\n            train_dataset,\n            batch_size=args.batch_size, num_workers=args.workers,\n            drop_last=True, shuffle=True, pin_memory=True, persistent_workers=True\n        )\n\n    # creates a copy of current code / files in the config folder\n    save_code_to_conf(args.exp_dir) \n    \n    # pytorch lightning callbacks \n    callback = ModelCheckpoint(dirpath=args.exp_dir, filename='{epoch}', save_top_k=-1, save_last=True, every_n_epochs=specs[\"log_freq\"])\n    lr_monitor = LearningRateMonitor(logging_interval='step')\n    callbacks = [callback, lr_monitor]\n\n    model = CombinedModel(specs)\n\n    # note on loading from checkpoint:\n    # if resuming from training modulation, diffusion, or end-to-end, just load saved checkpoint \n    # however, if fine-tuning end-to-end after training modulation and diffusion separately, will need to load sdf and diffusion checkpoints separately\n    if args.resume == 'finetune':\n        with warnings.catch_warnings():\n            warnings.simplefilter(\"ignore\")\n            model = model.load_from_checkpoint(specs[\"modulation_ckpt_path\"], specs=specs, strict=False)\n            # loads the diffusion model; directly calling diffusion_model.load_state_dict to prevent overwriting sdf and vae params\n            ckpt = torch.load(specs[\"diffusion_ckpt_path\"])\n            new_state_dict = {}\n            for k,v in ckpt['state_dict'].items():\n                new_key = k.replace(\"diffusion_model.\", \"\") # remove \"diffusion_model.\" from keys since directly loading into diffusion model\n                new_state_dict[new_key] = v\n            model.diffusion_model.load_state_dict(new_state_dict)\n        resume = None\n    elif args.resume is not None:\n        ckpt = \"{}.ckpt\".format(args.resume) if args.resume=='last' else \"epoch={}.ckpt\".format(args.resume)\n        resume = os.path.join(args.exp_dir, ckpt)\n    else:\n        resume = None  \n\n    # precision 16 can be unstable (nan loss); recommend using 32\n    trainer = pl.Trainer(accelerator='gpu', devices=-1, 
precision=32, max_epochs=specs[\"num_epochs\"], callbacks=callbacks, log_every_n_steps=1,\n                        default_root_dir=os.path.join(\"tensorboard_logs\", args.exp_dir))\n    trainer.fit(model=model, train_dataloaders=train_dataloader, ckpt_path=resume)\n\n    \n\n    \nif __name__ == \"__main__\":\n\n    import argparse\n\n    arg_parser = argparse.ArgumentParser()\n    arg_parser.add_argument(\n        \"--exp_dir\", \"-e\", required=True,\n        help=\"This directory should include experiment specifications in 'specs.json,' and logging will be done in this directory as well.\",\n    )\n    arg_parser.add_argument(\n        \"--resume\", \"-r\", default=None,\n        help=\"continue from previous saved logs, integer value, 'last', or 'finetune'\",\n    )\n\n    arg_parser.add_argument(\"--batch_size\", \"-b\", default=32, type=int)\n    arg_parser.add_argument( \"--workers\", \"-w\", default=8, type=int)\n\n    args = arg_parser.parse_args()\n    specs = json.load(open(os.path.join(args.exp_dir, \"specs.json\")))\n    print(specs[\"Description\"])\n\n\n    train()"
  },
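  {
    "path": "config/example_specs_modulation.json",
    "content": "{\n    \"Description\": \"Illustrative spec with placeholder values: these are the keys train.py and models/combined_model.py read for the modulation stage; see the released config folders for real settings.\",\n    \"training_task\": \"modulation\",\n    \"TrainSplit\": \"data/splits/train.json\",\n    \"DataSource\": \"data/acronym\",\n    \"PCsize\": 1024,\n    \"SdfModelSpecs\": {\n        \"hidden_dim\": 512,\n        \"latent_dim\": 256\n    },\n    \"latent_std\": 0.25,\n    \"kld_weight\": 0.01,\n    \"sdf_lr\": 1e-4,\n    \"num_epochs\": 300,\n    \"log_freq\": 50\n}\n"
  },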
  {
    "path": "utils/__init__.py",
    "content": "#!/usr/bin/env python3\n\n"
  },
  {
    "path": "utils/chamfer.py",
    "content": "import numpy as np\nfrom scipy.spatial import cKDTree as KDTree\n\n\ndef compute_trimesh_chamfer(gt_points, gen_points, offset=0, scale=1):\n    \"\"\"\n    This function computes a symmetric chamfer distance, i.e. the sum of both chamfers.\n    gt_points: numpy array. trimesh.points.PointCloud of just poins, sampled from the surface (see\n               compute_metrics.ply for more documentation)\n    gen_mesh: numpy array. trimesh.base.Trimesh of output mesh from whichever autoencoding reconstruction\n              method (see compute_metrics.py for more)\n    \"\"\"\n\n    # gen_points_sampled = trimesh.sample.sample_surface(gen_mesh, num_mesh_samples)[0]\n\n    gen_points = gen_points / scale - offset\n\n    # one direction\n    gen_points_kd_tree = KDTree(gen_points)\n    one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points)\n    gt_to_gen_chamfer = np.mean(np.square(one_distances))\n\n    # other direction\n    gt_points_kd_tree = KDTree(gt_points)\n    two_distances, two_vertex_ids = gt_points_kd_tree.query(gen_points)\n    gen_to_gt_chamfer = np.mean(np.square(two_distances))\n\n    return gt_to_gen_chamfer + gen_to_gt_chamfer\n\n\ndef scale_to_unit_sphere(points):\n    \"\"\"\n    scale point clouds into a unit sphere\n    :param points: (n, 3) numpy array\n    :return:\n    \"\"\"\n    midpoints = (np.max(points, axis=0) + np.min(points, axis=0)) / 2\n    points = points - midpoints\n    scale = np.max(np.sqrt(np.sum(points ** 2, axis=1)))\n    points = points / scale\n    return points\n"
  },
  {
    "path": "utils/evaluate.py",
    "content": "#!/usr/bin/env python3\n\nimport argparse\nimport logging\nimport json\nimport numpy as np\nimport pandas as pd \nimport os, sys\nimport trimesh\nfrom scipy.spatial import cKDTree as KDTree\n\nfrom utils import uhd, tmd\n\nimport csv\n\ndef main(gt_pc, recon_mesh, out_file, mesh_name, return_value=False, return_sampled_pc=False, prioritize_cov=False, pc_size=None):\n\n    gt_pc = gt_pc.cpu().detach().numpy().squeeze()\n\n    recon_mesh = trimesh.load(os.path.join(os.getcwd(), recon_mesh)+\".ply\")\n\n    recon_pc, _ = trimesh.sample.sample_surface(recon_mesh, gt_pc.shape[0])\n\n    full_recon_pc = trimesh.sample.sample_surface(recon_mesh, pc_size)[0] if pc_size is not None else recon_pc\n\n    recon_kd_tree = KDTree(recon_pc)\n    one_distances, one_vertex_ids = recon_kd_tree.query(gt_pc)\n    gt_to_recon_chamfer = np.mean(np.square(one_distances))\n\n    # other direction\n    gt_kd_tree = KDTree(gt_pc)\n    two_distances, two_vertex_ids = gt_kd_tree.query(recon_pc)\n    recon_to_gt_chamfer = np.mean(np.square(two_distances))\n    \n    if prioritize_cov: # higher CD for gaps/holes\n        loss_chamfer = gt_to_recon_chamfer * 2.0 + recon_to_gt_chamfer * 0.5\n    else:\n        loss_chamfer = gt_to_recon_chamfer + recon_to_gt_chamfer\n\n    if return_value:\n        return loss_chamfer\n\n    out_file = os.path.join(os.getcwd(), out_file)\n\n    with open(out_file,\"a\",) as f:\n        writer = csv.writer(f)\n        writer.writerow([mesh_name,loss_chamfer])\n\n    if return_sampled_pc:\n        return full_recon_pc, loss_chamfer\n\n\ndef calc_cd(gt_pc, recon_pc):\n\n    gt_pc = gt_pc.cpu().detach().numpy().squeeze()\n\n    recon_kd_tree = KDTree(recon_pc)\n    one_distances, one_vertex_ids = recon_kd_tree.query(gt_pc)\n    gt_to_recon_chamfer = np.mean(np.square(one_distances))\n\n    # other direction\n    gt_kd_tree = KDTree(gt_pc)\n    two_distances, two_vertex_ids = gt_kd_tree.query(recon_pc)\n    recon_to_gt_chamfer = np.mean(np.square(two_distances))\n    \n    return gt_to_recon_chamfer + recon_to_gt_chamfer\n\n\ndef single_eval(gt_csv, recon_mesh):\n    # f=pd.read_csv(gt_csv, sep=',',header=None).values\n    # f = f[f[:,-1]==0][:,:3]\n\n    recon_mesh = trimesh.load( recon_mesh )\n    recon_pc, _ = trimesh.sample.sample_surface(recon_mesh, 30000)\n    print(\"recon pc min max: \", recon_pc.max(), recon_pc.min())\n    # load from SIREN .xyz file \n    f = np.genfromtxt(gt_csv)\n    pc = f[:,:3]\n    coord_max = np.amax(pc, axis=0, keepdims=True)\n    coord_min = np.amin(pc, axis=0, keepdims=True)\n    coords = (pc - coord_min) / (coord_max - coord_min)\n    coords -= 0.5\n    coords *= 2.\n    # pc -= np.mean(pc, axis=0, keepdims=True)\n    # bbox_length = np.sqrt( np.sum((np.max(pc, axis=0) - np.min(pc, axis=0))**2) )\n    # pc /= bbox_length\n    f = coords\n    print(\"f min max: \", f.max(), f.min())\n\n    pc_idx = np.random.choice(f.shape[0], 30000, replace=False)\n    gt_pc = f[pc_idx] \n\n    recon_mesh = trimesh.load( recon_mesh )\n    recon_pc, _ = trimesh.sample.sample_surface(recon_mesh, 30000)\n\n    recon_kd_tree = KDTree(recon_pc)\n    one_distances, one_vertex_ids = recon_kd_tree.query(gt_pc)\n    gt_to_recon_chamfer = np.mean(np.square(one_distances))\n\n    # other direction\n    gt_kd_tree = KDTree(gt_pc)\n    two_distances, two_vertex_ids = gt_kd_tree.query(recon_pc)\n    recon_to_gt_chamfer = np.mean(np.square(two_distances))\n    \n    loss_chamfer = gt_to_recon_chamfer + recon_to_gt_chamfer\n\n    print(\"CD loss: \", 
loss_chamfer)\n\n\nif __name__ == \"__main__\":\n    # usage: python utils/evaluate.py <gt_points.xyz> <recon_mesh.ply>\n    single_eval(sys.argv[1], sys.argv[2])\n
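\n# illustrative in-memory usage of calc_cd (a sketch; gt is a torch tensor, recon a numpy array):\n#   import torch\n#   gt = torch.rand(1, 1024, 3)\n#   print(calc_cd(gt, np.random.rand(1024, 3)))\n"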
  },
  {
    "path": "utils/mesh.py",
    "content": "#!/usr/bin/env python3\n\nimport logging\nimport math\nimport numpy as np\nimport plyfile\nimport skimage.measure\nimport time\nimport torch\n\n\n# N: resolution of grid; 256 is typically sufficient \n# max batch: as large as GPU memory will allow\n# shape_feature is either point cloud, mesh_idx (neuralpull), or generated latent code (deepsdf)\ndef create_mesh(\n    model, shape_feature, filename, N=256, max_batch=1000000, level_set=0.0, occupancy=False, point_cloud=None, from_plane_features=False, from_pc_features=False\n):\n    \n    start_time = time.time()\n    ply_filename = filename\n\n    model.eval()\n\n    # the voxel_origin is the (bottom, left, down) corner, not the middle\n    voxel_origin = [-1, -1, -1]\n    voxel_size = 2.0 / (N - 1)\n    cube = create_cube(N)\n    cube_points = cube.shape[0]\n\n    head = 0\n    while head < cube_points:\n        \n        query = cube[head : min(head + max_batch, cube_points), 0:3].unsqueeze(0)\n        \n        # inference defined in forward function per pytorch lightning convention\n        #print(\"shapes: \", shape_feature.shape, query.shape)\n        if from_plane_features:\n            pred_sdf = model.forward_with_plane_features(shape_feature.cuda(), query.cuda()).detach().cpu()\n        else:\n            pred_sdf = model(shape_feature.cuda(), query.cuda()).detach().cpu()\n\n        cube[head : min(head + max_batch, cube_points), 3] = pred_sdf.squeeze()\n            \n        head += max_batch\n    \n    # for occupancy instead of SDF, subtract 0.5 so the surface boundary becomes 0\n    sdf_values = cube[:, 3] - 0.5 if occupancy else cube[:, 3] \n    sdf_values = sdf_values.reshape(N, N, N) \n\n    #print(\"inference time: {}\".format(time.time() - start_time))\n\n    convert_sdf_samples_to_ply(\n        sdf_values.data,\n        voxel_origin,\n        voxel_size,\n        ply_filename + \".ply\",\n        level_set\n    )\n\n\n# create cube from (-1,-1,-1) to (1,1,1) and uniformly sample points for marching cube\ndef create_cube(N):\n\n    overall_index = torch.arange(0, N ** 3, 1, out=torch.LongTensor())\n    samples = torch.zeros(N ** 3, 4)\n\n    # the voxel_origin is the (bottom, left, down) corner, not the middle\n    voxel_origin = [-1, -1, -1]\n    voxel_size = 2.0 / (N - 1)\n    \n    # transform first 3 columns\n    # to be the x, y, z index\n    samples[:, 2] = overall_index % N\n    samples[:, 1] = (overall_index.long().float() / N) % N\n    samples[:, 0] = ((overall_index.long().float() / N) / N) % N\n\n    # transform first 3 columns\n    # to be the x, y, z coordinate\n    samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]\n    samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]\n    samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]\n\n    samples.requires_grad = False\n\n    return samples\n\n\n\ndef convert_sdf_samples_to_ply(\n    pytorch_3d_sdf_tensor,\n    voxel_grid_origin,\n    voxel_size,\n    ply_filename_out,\n    level_set=0.0\n):\n    \"\"\"\n    Convert sdf samples to .ply\n\n    :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n)\n    :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid\n    :voxel_size: float, the size of the voxels\n    :ply_filename_out: string, path of the filename to save to\n\n    This function adapted from: https://github.com/RobotLocomotion/spartan\n    \"\"\"\n\n    numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()\n\n    # use marching_cubes_lewiner or marching_cubes 
depending on pytorch version \n    try:\n        verts, faces, normals, values = skimage.measure.marching_cubes(\n            numpy_3d_sdf_tensor, level=level_set, spacing=[voxel_size] * 3\n        )\n    except Exception as e:\n        print(\"skipping {}; error: {}\".format(ply_filename_out, e))\n        return\n\n    # transform from voxel coordinates to camera coordinates\n    # note x and y are flipped in the output of marching_cubes\n    mesh_points = np.zeros_like(verts)\n    mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]\n    mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]\n    mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]\n\n    num_verts = verts.shape[0]\n    num_faces = faces.shape[0]\n\n    verts_tuple = np.zeros((num_verts,), dtype=[(\"x\", \"f4\"), (\"y\", \"f4\"), (\"z\", \"f4\")])\n\n    for i in range(0, num_verts):\n        verts_tuple[i] = tuple(mesh_points[i, :])\n\n    faces_building = []\n    for i in range(0, num_faces):\n        faces_building.append(((faces[i, :].tolist(),)))\n    faces_tuple = np.array(faces_building, dtype=[(\"vertex_indices\", \"i4\", (3,))])\n\n    el_verts = plyfile.PlyElement.describe(verts_tuple, \"vertex\")\n    el_faces = plyfile.PlyElement.describe(faces_tuple, \"face\")\n\n    ply_data = plyfile.PlyData([el_verts, el_faces])\n    ply_data.write(ply_filename_out)\n\n\n"
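\n\n# illustrative usage (a sketch, with assumptions: sdf_model is a trained SDF network from\n# models/ exposing forward(shape_feature, query) and forward_with_plane_features, and\n# plane_feat is a (1, 768, 64, 64) triplane feature map):\n#\n#   create_mesh(sdf_model, plane_feat, \"out/reconstruct\", N=256, max_batch=2**20, from_plane_features=True)\n#\n# this writes \"out/reconstruct.ply\" by querying SDF values on a 256^3 grid and running marching cubes\n"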
  },
  {
    "path": "utils/mmd.py",
    "content": ""
  },
  {
    "path": "utils/reconstruct.py",
    "content": "#!/usr/bin/env python3\n\nimport torch\nimport torch.utils.data \nfrom torch.nn import functional as F\nimport pytorch_lightning as pl\nfrom pytorch_lightning.callbacks import ModelCheckpoint, Callback\nfrom pytorch_lightning import loggers as pl_loggers\n\nimport os\nimport json\nimport time\nfrom tqdm.auto import tqdm\nfrom einops import rearrange, reduce\nimport numpy as np\nimport trimesh\n\n# add paths in model/__init__.py for new models\nfrom models import * \nfrom utils import mesh, evaluate, reconstruct\nfrom diff_utils.helpers import * \n\ndef vis_recon(test_dataloader, sdf_model, vae_model, recon_dir, take_mod=False, calc_cd=False):\n    resolution = 64\n    recon_batch = 2**20\n    # run visualization of the reconstructed plane features to confirm that recon loss is low enough \n    with torch.no_grad():\n        if args.evaluate:\n            point_clouds, pc_paths = test_dataloader.get_all_files()\n\n            point_clouds = torch.stack(point_clouds) # stack [(1024,3), (1024,3)...] to (B, 1024, 3)\n\n            recon_meshes = torch.empty(*point_clouds.shape)\n\n            # ***change to MODULATION generated paths later!!!!\n            for idx, path in enumerate(pc_paths):\n                #print(\"path: \", path)\n                cls_name = path.split(\"/\")[-3]\n                mesh_name = path.split(\"/\")[-2]\n                mesh_filename = os.path.join(recon_dir, \"{}/{}/reconstruct\".format(cls_name, mesh_name))\n                recon_mesh = trimesh.load(os.path.join(os.getcwd(), mesh_filename)+\".ply\")\n                recon_pc, _ = trimesh.sample.sample_surface(recon_mesh, point_clouds.shape[1])\n                recon_meshes[idx] = torch.from_numpy(recon_pc)\n\n            print(\"ref, recon shapes: \", recon_meshes.shape, point_clouds.shape) # should both be = B, N, 3\n            results = evaluation_metrics.compute_all_metrics(recon_meshes.float(), point_clouds.float(), accelerated_cd=False)\n            for k,v in results.items():\n                print(k, \": \", v)\n\n        elif args.take_mod and not args.sample:\n\n            lst = []\n            if args.mod_folder:\n                files = os.listdir(args.mod_folder)\n                for f in files:\n                    if os.path.isfile(os.path.join(args.mod_folder, f)) and f[-4:]=='.txt':\n                        lst.append(os.path.join(args.mod_folder, f))\n            else:\n                lst = args.take_mod\n\n            for idx, m in enumerate(lst):\n                latent = torch.from_numpy(np.loadtxt(m)).float().cuda()\n                recon = vae_model.decode(latent) \n                name = args.output_name if args.output_name else \"mod_recon\"\n                name += \"{}\".format(idx)\n                os.makedirs(os.path.join(recon_dir, \"modulation_recon\"), exist_ok=True)\n                mesh_filename = os.path.join(recon_dir, \"modulation_recon\", name)\n                mesh.create_mesh(sdf_model, recon, mesh_filename, resolution, recon_batch, from_plane_features=True)\n        elif args.sample:\n            recon = vae_model.sample(num_samples=1)\n            name = args.output_name if args.output_name else \"mod_recon\"\n            os.makedirs(os.path.join(recon_dir, \"modulation_recon\"), exist_ok=True)\n            mesh_filename = os.path.join(recon_dir, \"modulation_recon\", name)\n            mesh.create_mesh(sdf_model, recon, mesh_filename, resolution, recon_batch, from_plane_features=True)\n        else:\n            for idx, data in 
enumerate(test_dataloader): # test_loader does not shuffle \n                # if idx % 10 != 0:\n                #     continue\n                data, filename = data # filename = path to the csv file of sdf data\n                filename = filename[0] # filename is a tuple for some reason\n\n\n                random_flip = specs.get(\"random_flip\", False)\n                # if random_flip:\n                #     flip_axes = torch.tensor([[1,1,1],[-1,1,1],[1,-1,1],[1,1,-1]], device=data.device)\n                #     prob = torch.randint(low=0, high=4, size=(1,))\n                #     flip_axis = flip_axes[prob] # shape=[1,3]\n                #     data *= flip_axis.unsqueeze(0).repeat(data.shape[0], data.shape[1], 1)\n\n                # if random_flip:\n                #     flip_axes = torch.tensor([[1,1,1],[-1,1,1],[1,-1,1],[1,1,-1]], device=data.device)\n                #     for axis in flip_axes:\n                #         flipped_data = data * axis.unsqueeze(0).repeat(data.shape[0], data.shape[1], 1)\n\n\n                cls_name = filename.split(\"/\")[-3]\n                mesh_name = filename.split(\"/\")[-2]\n                outdir = os.path.join(recon_dir, \"{}/{}\".format(cls_name, mesh_name))\n                os.makedirs(outdir, exist_ok=True)\n                mesh_filename = os.path.join(outdir, \"reconstruct\")\n               \n                plane_features = sdf_model.pointnet.get_plane_features(data.cuda())  # 3 items with ([1, 256, 64, 64])\n                plane_features = torch.cat(plane_features, dim=1) # ([1, 768, 64, 64])\n                recon = vae_model.generate(plane_features) # ([1, 768, 64, 64])\n\n                # create_mesh samples the grid points, then calls sdf_model.forward_with_plane_features, which calls pointnet.forward_with_plane_features\n                #print(\"mesh filename: \", mesh_filename)\n                mesh.create_mesh(sdf_model, recon, mesh_filename, resolution, recon_batch, from_plane_features=True)\n\n                \n                if calc_cd:\n                    evaluate_filename = os.path.join(recon_dir, \"cd.csv\")\n                    mesh_log_name = cls_name+\"/\"+mesh_name\n                    try:\n                        evaluate.main(data, mesh_filename, evaluate_filename, mesh_log_name)\n                    except Exception as e:\n                        print(e)\n\n                # try:\n                #     if not filter_threshold(mesh_filename, data, 0.0018):\n                #         continue\n                #     outdir = os.path.join(latent_dir, \"{}/{}\".format(cls_name, mesh_name))\n                #     os.makedirs(outdir, exist_ok=True)\n                #     features = sdf_model.pointnet.get_plane_features(data.cuda())\n                #     #print(\"features shape: \", features[0].shape) # ([1, 256, 64, 64])\n                #     features = torch.cat(features, dim=1)\n                #     latent = vae_model.get_latent(features)\n\n                #     #print(\"latent shape: \", latent.shape)\n                #     np.savetxt(os.path.join(outdir, \"latent.txt\"), latent.cpu().numpy())\n                # except Exception as e:\n                #     print(e)\n\n\n           \n\ndef filter_threshold(mesh, gt_pc, threshold): # mesh is path to mesh without .ply ext\n    cd = evaluate.main(gt_pc, mesh, None, None, return_value=True, prioritize_cov=True)\n    return cd <= threshold\n\n\n\ndef extract_latents(test_dataloader, sdf_model, vae_model, save_dir):\n    # only extract the latent vectors \n    
latent_dir = os.path.join(save_dir, \"modulations\")\n    os.makedirs(latent_dir, exist_ok=True)\n    with torch.no_grad():\n        for idx, data in enumerate(test_dataloader): # test_dataloader does not shuffle\n\n            data, filename = data\n            filename = filename[0] # the DataLoader collates the path string into a one-element tuple\n            cls_name = filename.split(\"/\")[-3]\n            mesh_name = filename.split(\"/\")[-2]\n\n            # if filtering based on a CD threshold; assumes vis_recon saved its reconstructions under save_dir\n            saved_mesh = os.path.join(save_dir, \"{}/{}/reconstruct\".format(cls_name, mesh_name))\n            gt_pc = data\n            try:\n                if not filter_threshold(saved_mesh, gt_pc, 0.0022):\n                    continue\n\n                outdir = os.path.join(latent_dir, \"{}/{}\".format(cls_name, mesh_name))\n                os.makedirs(outdir, exist_ok=True)\n\n                random_flip = specs.get(\"random_flip\", False)\n                if random_flip:\n                    flip_axes = torch.tensor([[1,1,1],[-1,1,1],[1,-1,1],[1,1,-1]], device=data.device)\n                    for flip_idx, axis in enumerate(flip_axes): # avoid shadowing the outer loop index\n                        flipped_data = data * axis.unsqueeze(0).repeat(data.shape[0], data.shape[1], 1)\n\n                        features = sdf_model.pointnet.get_plane_features(flipped_data.cuda())\n                        features = torch.cat(features, dim=1) # ([1, 768, 64, 64])\n                        latent = vae_model.get_latent(features)\n\n                        np.savetxt(os.path.join(outdir, \"latent_{}.txt\".format(flip_idx)), latent.cpu().numpy())\n\n                else:\n                    features = sdf_model.pointnet.get_plane_features(data.cuda())\n                    features = torch.cat(features, dim=1) # ([1, 768, 64, 64])\n                    latent = vae_model.get_latent(features)\n\n                    np.savetxt(os.path.join(outdir, \"latent.txt\"), latent.cpu().numpy())\n\n            except Exception as e:\n                print(e)\n
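\n\n# expected output layout (an assumption inferred from the dataloaders, not enforced here):\n#   <save_dir>/modulations/<class>/<mesh>/latent.txt\n# each latent.txt holds one flattened modulation vector, which diffusion training can reload with, e.g.\n#   latent = torch.from_numpy(np.loadtxt(path)).float()\n"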
  },
  {
    "path": "utils/renderer.py",
    "content": "import numpy as np\nimport trimesh\nimport trimesh.transformations as tra\nimport pyrender\nimport matplotlib.pyplot as plt \nfrom scipy.spatial.transform import Rotation as R\n\nCOLORS = [\n    np.array([255, 10, 10, 255]), \n    np.array([10, 255, 10, 255]), \n    np.array([10, 234, 255, 255])\n]\nclass OnlineObjectRenderer:\n    def __init__(self, fov=np.pi / 6, caching=True):\n        \"\"\"\n        Args:\n          fov: float, \n        \"\"\"\n        self._fov = fov\n        self._scene = None\n        self._init_scene()\n        self._caching = caching\n        self._nodes = []\n\n    def _init_scene(self, height=480, width=480):\n        self._scene = pyrender.Scene()\n        camera = pyrender.PerspectiveCamera(yfov=self._fov, znear=0.001) # do not change aspect ratio\n        camera_pose = tra.euler_matrix(np.pi, 0, 0)\n        self._scene.add(camera, pose=camera_pose, name='camera')\n        direc_l = pyrender.DirectionalLight(color=np.ones(3), intensity=1.0)\n        self._scene.add(direc_l, pose=camera_pose)\n        self.renderer = pyrender.OffscreenRenderer(height, width)\n        \n    def add_mesh(self, path, name, rotation=None, translation=None):\n        \n        mesh = trimesh.load(path)\n        color = np.tile(COLORS[len(self._nodes) % len(COLORS)], (mesh.vertices.shape[0], 1))\n        # mesh_mean = np.mean(mesh.vertices, 0)\n        # mesh.vertices -= np.expand_dims(mesh_mean, 0)\n        mesh.visual.vertex_colors = color\n        mesh = pyrender.Mesh.from_trimesh(mesh.copy(), smooth=False)\n        if rotation is not None and (isinstance(rotation, list) or isinstance(rotation, tuple)):\n            rotation = np.array(rotation)\n        if rotation is not None and rotation.shape == (3, 3):\n            rotation = R.from_matrix(rotation).as_quat()\n        elif rotation is not None and rotation.shape == (3,):\n            rotation = R.from_euler(\"xyz\", rotation).as_quat()\n        if translation is None:\n            translation = np.array((0, 0, 3))\n        else:\n            translation = translation.copy()\n            translation[2] += 3.\n        node = pyrender.Node(mesh=mesh, rotation=rotation, translation=translation, name=name)\n        self._scene.add_node(node)\n        self._nodes.append(node)\n        return name\n    \n    def add_pointcloud(self, path, name, colors=None, rotation=None, translation=None):\n        pc = np.loadtxt(path, delimiter=\",\")\n        sm = trimesh.creation.uv_sphere(radius=0.008)\n        if colors is None:\n            colors = COLORS[len(self._nodes) % len(COLORS)]\n        sm.visual.vertex_colors = colors\n        tfs = np.tile(np.eye(4), (len(pc), 1, 1))\n        tfs[:,:3,3] = pc\n        if rotation is not None and (isinstance(rotation, list) or isinstance(rotation, tuple)):\n            rotation = np.array(rotation)\n        if rotation is not None and rotation.shape == (3, 3):\n            rotation = R.from_matrix(rotation).as_quat()\n        elif rotation is not None and rotation.shape == (3,):\n            rotation = R.from_euler(\"xyz\", rotation).as_quat()\n        if translation is None:\n            translation = np.array((0, 0, 3))\n        else:\n            translation = translation.copy()\n            translation[2] += 3.\n        # pts = rotation @ pc + translation\n        mesh = pyrender.Mesh.from_trimesh(sm, poses=tfs)\n        node = pyrender.Node(mesh=mesh, rotation=rotation, translation=translation, name=name)\n        self._scene.add_node(node)\n        
self._nodes.append(node)\n        return name\n        \n    \n    def clear(self):\n        for node in self._nodes:\n            self._scene.remove_node(node)\n        self._nodes = []\n    \n    def render(self):\n        return self.renderer.render(self._scene)\n\nif __name__ == \"__main__\":\n    renderer = OnlineObjectRenderer()\n    renderer.add_mesh(\n        \"/path/to/mesh\", \n        \"o1\", # name of output\n        np.array([np.pi/2, -np.pi/2, 0]),\n        np.array([0, .5, 0])\n    )\n    img, dp = renderer.render()\n    plt.axis(\"off\")\n    plt.imshow(img)\n    plt.show()    "
  },
  {
    "path": "utils/tmd.py",
    "content": "import argparse\nimport os\nimport numpy as np\nimport trimesh\nfrom utils.chamfer import compute_trimesh_chamfer\nimport glob\n\n\ndef process_one(shape_dir):\n    pc_paths = glob.glob(os.path.join(shape_dir, \"fake-z*.ply\"))\n    pc_paths = sorted(pc_paths)\n    gen_pcs = []\n    for path in pc_paths:\n        sample_pts = trimesh.load(path)\n        sample_pts = sample_pts.vertices\n        gen_pcs.append(sample_pts)\n\n    sum_dist = 0\n    for j in range(len(gen_pcs)):\n        for k in range(j + 1, len(gen_pcs), 1):\n            pc1 = gen_pcs[j]\n            pc2 = gen_pcs[k]\n            chamfer_dist = compute_trimesh_chamfer(pc1, pc2)\n            sum_dist += chamfer_dist\n    mean_dist = sum_dist * 2 / (len(gen_pcs) - 1)\n    return mean_dist\n\ndef tmd_from_pcs(gen_pcs):\n    sum_dist = 0\n    for j in range(len(gen_pcs)):\n        for k in range(j + 1, len(gen_pcs), 1):\n            pc1 = gen_pcs[j]\n            pc2 = gen_pcs[k]\n            chamfer_dist = compute_trimesh_chamfer(pc1, pc2)\n            sum_dist += chamfer_dist\n    mean_dist = sum_dist * 2 / (len(gen_pcs) - 1)\n    return mean_dist\n\n\ndef Total_Mutual_Difference(args):\n    shape_names = sorted(os.listdir(args.src))\n    res = 0\n    all_shape_dir = [os.path.join(args.src, name) for name in shape_names]\n\n    results = Parallel(n_jobs=args.process, verbose=2)(delayed(process_one)(path) for path in all_shape_dir)\n\n    info_path = args.src + '-record_meandist.txt'\n    with open(info_path, 'w') as fp:\n        for i in range(len(shape_names)):\n            print(\"ID: {} \\t mean_dist: {:.4f}\".format(shape_names[i], results[i]), file=fp)\n    res = np.mean(results)\n\n    return res\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--src\", type=str)\n    parser.add_argument(\"-p\", \"--process\", type=int, default=10)\n    parser.add_argument(\"-o\", \"--output\", type=str)\n    args = parser.parse_args()\n\n    if args.output is None:\n        args.output = args.src + '-eval_TMD.txt'\n\n    res = Total_Mutual_Difference(args)\n    print(\"Avg Total Multual Difference: {}\".format(res))\n\n    with open(args.output, \"w\") as fp:\n        fp.write(\"SRC: {}\\n\".format(args.src))\n        fp.write(\"Total Multual Difference: {}\\n\".format(res))\n\n\nif __name__ == '__main__':\n    main()"
  },
  {
    "path": "utils/uhd.py",
    "content": "import argparse\nimport os\nimport torch\nimport numpy as np\nfrom scipy.spatial import cKDTree as KDTree\nimport trimesh\nimport glob\nfrom joblib import Parallel, delayed\n\n\ndef directed_hausdorff(point_cloud1:torch.Tensor, point_cloud2:torch.Tensor, reduce_mean=True):\n    \"\"\"\n\n    :param point_cloud1: (B, 3, N)\n    :param point_cloud2: (B, 3, M)\n    :return: directed hausdorff distance, A -> B\n    \"\"\"\n    n_pts1 = point_cloud1.shape[2]\n    n_pts2 = point_cloud2.shape[2]\n\n    pc1 = point_cloud1.unsqueeze(3)\n    pc1 = pc1.repeat((1, 1, 1, n_pts2)) # (B, 3, N, M)\n    pc2 = point_cloud2.unsqueeze(2)\n    pc2 = pc2.repeat((1, 1, n_pts1, 1)) # (B, 3, N, M)\n\n    l2_dist = torch.sqrt(torch.sum((pc1 - pc2) ** 2, dim=1)) # (B, N, M)\n\n    shortest_dist, _ = torch.min(l2_dist, dim=2)\n\n    hausdorff_dist, _ = torch.max(shortest_dist, dim=1) # (B, )\n\n    if reduce_mean:\n        hausdorff_dist = torch.mean(hausdorff_dist)\n\n    return hausdorff_dist\n\n\ndef nn_distance(query_points, ref_points):\n    ref_points_kd_tree = KDTree(ref_points)\n    one_distances, one_vertex_ids = ref_points_kd_tree.query(query_points)\n    return one_distances\n\n\ndef completeness(query_points, ref_points, thres=0.03):\n    a2b_nn_distance =  nn_distance(query_points, ref_points)\n    percentage = np.sum(a2b_nn_distance < thres) / len(a2b_nn_distance)\n    return percentage\n\n\ndef process_one(shape_dir):\n    # load generated shape\n    pc_paths = glob.glob(os.path.join(shape_dir, \"fake-z*.ply\"))\n    pc_paths = sorted(pc_paths)\n\n    gen_pcs = []\n    for path in pc_paths:\n        sample_pts = trimesh.load(path)\n        sample_pts = np.asarray(sample_pts.vertices)\n        # sample_pts = torch.tensor(sample_pts.vertices).transpose(1, 0)\n        gen_pcs.append(sample_pts)\n\n    # load partial input\n    partial_path = os.path.join(shape_dir, \"raw.ply\")\n    partial_pc = trimesh.load(partial_path)\n    partial_pc = np.asarray(partial_pc.vertices)\n    # partial_pc = torch.tensor(partial_pc.vertices).transpose(1, 0)\n\n    # completeness percentage\n    gen_comp = 0\n    for sample_pts in gen_pcs:\n        comp = completeness(partial_pc, sample_pts)\n        gen_comp += comp\n    gen_comp = gen_comp / len(gen_pcs)\n\n    # unidirectional hausdorff\n    gen_pcs = [torch.tensor(pc).transpose(1, 0) for pc in gen_pcs]\n    gen_pcs = torch.stack(gen_pcs, dim=0)\n    partial_pc = torch.tensor(partial_pc).transpose(1, 0)\n\n    partial_pc = partial_pc.unsqueeze(0).repeat((gen_pcs.size(0), 1, 1))\n\n    hausdorff = directed_hausdorff(partial_pc, gen_pcs, reduce_mean=True).item()\n\n    return gen_comp, hausdorff\n\ndef uhd_from_pcs(gen_pcs, partial_pc):\n    # completeness percentage\n    gen_comp = 0\n    for sample_pts in gen_pcs:\n        comp = completeness(partial_pc, sample_pts)\n        gen_comp += comp\n    gen_comp = gen_comp / len(gen_pcs)\n\n    # unidirectional hausdorff\n    gen_pcs = [torch.tensor(pc).transpose(1, 0) for pc in gen_pcs]\n    gen_pcs = torch.stack(gen_pcs, dim=0)\n    partial_pc = torch.tensor(partial_pc).transpose(1, 0)\n\n    partial_pc = partial_pc.unsqueeze(0).repeat((gen_pcs.size(0), 1, 1))\n\n    hausdorff = directed_hausdorff(partial_pc, gen_pcs, reduce_mean=True).item()\n\n    return gen_comp, hausdorff\n\n\ndef func(args):\n    shape_names = sorted(os.listdir(args.src))\n    all_shape_dir = [os.path.join(args.src, name) for name in shape_names]\n\n    results = Parallel(n_jobs=args.process, verbose=2)(delayed(process_one)(path) for path 
in all_shape_dir)\n\n    res_comp, res_hausdorff = zip(*results)\n    res_comp = np.mean(res_comp)\n    res_hausdorff = np.mean(res_hausdorff)\n\n    return res_hausdorff, res_comp\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--src\", type=str)\n    parser.add_argument(\"-p\", \"--process\", type=int, default=10)\n    parser.add_argument(\"-o\", \"--output\", type=str)\n    args = parser.parse_args()\n\n    if args.output is None:\n        args.output = args.src + '-eval_UHD.txt'\n\n    res_hausdorff, res_comp = func(args)\n    print(\"Avg Unidirectional Hausdorff Distance: {}\".format(res_hausdorff))\n    print(\"Avg Completeness: {}\".format(res_comp))\n\n    with open(args.output, \"a\") as fp:\n        fp.write(\"SRC: {}\\n\".format(args.src))\n        fp.write(\"Avg Unidirectional Hausdorff Distance: {}\\n\".format(res_hausdorff))\n        fp.write(\"Avg Completeness: {}\\n\".format(res_comp))\n\n\nif __name__ == '__main__':\n    main()"
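\n\n# illustrative shape check for directed_hausdorff (a sketch; values are random):\n#   pc1 = torch.rand(2, 3, 512)    # batch of 2 clouds with 512 points each\n#   pc2 = torch.rand(2, 3, 1024)\n#   d = directed_hausdorff(pc1, pc2)  # scalar tensor (mean over the batch)\n"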
  },
  {
    "path": "utils/visualize.py",
    "content": "import numpy as np\nimport trimesh\nimport sys\nimport h5py\nimport open3d as o3d\n\ndef create_point_marker(center, color):\n\n    # point cloud point info (e.g. radius)\n    point = trimesh.primitives.Sphere(\n        radius=0.002, \n        center=center\n    )\n    point.visual.vertex_colors = color\n    return point\n\ndef get_color(labels):\n    return np.stack([np.ones(labels.shape[0]) - labels, labels, np.zeros(labels.shape[0])], axis=1)\n\n\ndef vis_pc(obj_pc):\n    point_color = get_color(np.array([1.0]))\n    point_markers = [create_point_marker(center=obj_pc[t], color=point_color[0]) for t in range(len(obj_pc))]\n\n    trimesh.Scene(point_markers).show()\n\ndef main(mesh, pc, query):\n\n    lst = []\n\n    if pc:\n        point_color = get_color(np.array([0.5]))\n        point_markers = [create_point_marker(center=point_cloud[t], color=point_color[0]) for t in range(len(point_cloud))]\n        lst.append(trimesh.Scene(point_markers))\n\n    if query: \n        q_color = get_color(np.array([0.0]))\n        q_markers = [create_point_marker(center=queries[t], color=q_color[0]) for t in range(len(queries))]\n        lst.append(trimesh.Scene(q_markers))\n\n    if mesh:\n        lst.append(object_mesh)\n\n    scene = trimesh.scene.scene.append_scenes( lst ).show() \n    #trimesh.exchange.export.export_scene(scene, \"vis.ply\") \n\n\nobject_mesh = trimesh.load(sys.argv[1])\nobject_mesh.apply_scale(0.1)\npc = object_mesh.vertices\n\npoint_cloud = np.loadtxt(sys.argv[2], dtype=float, delimiter=',') # num of points x 3\np_idx = np.random.choice(point_cloud.shape[0], 10000)\npoint_cloud = point_cloud[p_idx][:,0:3]\n\n# for visualizing pc only\n# pcd = o3d.geometry.PointCloud()\n# pcd.points = o3d.utility.Vector3dVector(point_cloud)\n# #o3d.io.write_point_cloud(\"./pc.ply\", pcd)\n# o3d.visualization.draw_geometries([pcd])\n\n\n#     mesh, pc, query\nmain(False, True, False)\n\n"
  }
]