[
  {
    "path": ".gitignore",
    "content": "__pycache__\n*.egg-info\nbuild/\ndist/\nlogs/\n.vscode/\nresults/\ntemp.sh\n\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 JunyuanDeng\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "Readme.md",
    "content": "# NeRF-LOAM: Neural Implicit Representation for Large-Scale Incremental LiDAR Odometry and Mapping\n\nThis repository contains the implementation of our paper:\n> **NeRF-LOAM: Neural Implicit Representation for Large-Scale Incremental LiDAR Odometry and Mapping** ([PDF](https://arxiv.org/pdf/2303.10709))\\\n> [Junyuan Deng](https://github.com/JunyuanDeng), [Qi Wu](https://github.com/Gatsby23), [Xieyuanli Chen](https://github.com/Chen-Xieyuanli), Songpengcheng Xia, Zhen Sun, Guoqing Liu, Wenxian Yu and Ling Pei\\\n> If you use our code in your work, please star our repo and cite our paper.\n\n```\n@inproceedings{deng2023nerfloam,\n      title={NeRF-LOAM: Neural Implicit Representation for Large-Scale Incremental LiDAR Odometry and Mapping}, \n      author={Junyuan Deng and Qi Wu and Xieyuanli Chen and Songpengcheng Xia and Zhen Sun and Guoqing Liu and Wenxian Yu and Ling Pei},\n      booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},\n      year={2023}\n\n}\n```\n\n<div align=center>\n<img src=\"./docs/NeRFLOAM.gif\"> \n</div>\n\n- *Our incrementally simultaneous odometry and mapping results on the Newer College dataset and the KITTI dataset sequence 00.*\n- *The maps are dense with a form of mesh, the red line indicates the odometry results.*\n- *We use the same network without training to prove the ability of generalization of our design.*\n\n\n## Overview\n\n![pipeline](./docs/pipeline.png)\n\n**Overview of our method.** Our method is based on our neural SDF and composed of three main components:\n- Neural odometry takes the pre-processed scan and optimizes the pose via back projecting the queried neural SDF; \n- Neural mapping jointly optimizes the voxel embeddings map and pose while selecting the key-scans; \n- Key-scans refined map returns SDF value and the final mesh is reconstructed by marching cube.\n\n## Quatitative results\n\n**The reconstructed maps**\n![odomap_kitti](./docs/odomap_kitti.png)\n*The 
qualitative results of our odometry and mapping on the KITTI dataset. From upper left to lower right, we list the results of sequences 00, 01, 03, 04, 05, 09, 10.*\n\n**The odometry results**\n![odo_qual](./docs/odo_qual.png)\n*The qualitative results of our odometry on the KITTI dataset. From left to right, we list the results of sequences 00, 01, 03, 04, 05, 07, 09, 10. The dashed line corresponds to the ground truth and the blue line to our odometry method.*\n\n## Data\n\n1. Newer College real-world LiDAR dataset: [website](https://ori-drs.github.io/newer-college-dataset/download/).\n\n2. MaiCity synthetic LiDAR dataset: [website](https://www.ipb.uni-bonn.de/data/mai-city-dataset/).\n\n3. KITTI dataset: [website](https://www.cvlibs.net/datasets/kitti/).\n\n## Environment Setup\n\nTo run the code, a GPU with large memory is preferred. We tested the code with an RTX 3090 and a GTX TITAN.\n\nWe use Conda to create a virtual environment and install dependencies:\n\n- Python environment: we tested our code with Python 3.8.13.\n\n- [PyTorch](https://pytorch.org/get-started/locally/): the version we tested is 1.10 with CUDA 10.2 (and CUDA 11.1).\n\n- Other dependencies are specified in requirements.txt. You can install them all using `pip` or `conda`:\n```\npip3 install -r requirements.txt\n```\n\n- After you have installed all third-party libraries, run the following script to build the extra PyTorch modules used in this project:\n\n```bash\nsh install.sh\n```\n\n- Replace the filename in mapping.py with the path to the built library:\n```python\ntorch.classes.load_library(\"third_party/sparse_octree/build/lib.xxx/svo.xxx.so\")\n```\n\n- Install [patchwork-plusplus](https://github.com/url-kaist/patchwork-plusplus) to separate the ground from LiDAR points.\n\n- Replace the path in src/dataset/*.py with your built library:\n```python\npatchwork_module_path = \"/xxx/patchwork-plusplus/build/python_wrapper\"\n```\n\n## Demo\n\n- The full datasets can be downloaded from the websites listed above. 
You can also download an example part of the dataset using these [scripts](https://github.com/PRBonn/SHINE_mapping/tree/master/scripts).\n- Take the MaiCity seq. 01 dataset as an example: modify `configs/maicity/maicity_01.yaml` so that the `data_path` entry points to the real dataset path. Now you are all set to run the code:\n```\npython demo/run.py configs/maicity/maicity_01.yaml\n```\n\n## Note\n\n- For the KITTI dataset, if you want to process it faster, you can switch to the `subscene` branch:\n```\ngit checkout subscene\n```\n- Then run with `python demo/run.py configs/kitti/kitti_00.yaml`.\n- This branch cuts the full scene into subscenes to speed up processing and then concatenates them back together. This will introduce some map inconsistency and degrade tracking accuracy.\n\n## Evaluation\n\n- We follow the evaluation proposed [here](https://github.com/PRBonn/SHINE_mapping/tree/master/eval), but we did not use `crop_intersection.py`.\n\n## Acknowledgement\n\nSome of our code is adapted from [Vox-Fusion](https://github.com/zju3dv/Vox-Fusion).\n\n## Contact\n\nAny questions or suggestions are welcome!\n\nJunyuan Deng: d.juney@sjtu.edu.cn and Xieyuanli Chen: xieyuanli.chen@nudt.edu.cn\n\n## License\n\nThis project is free software made available under the MIT License. For details see the LICENSE file.\n"
  },
  {
    "path": "configs/kitti/kitti.yaml",
    "content": "log_dir: './logs'\ndecoder: lidar\ndataset: kitti\n\ncriteria:\n  sdf_weight: 10000.0\n  fs_weight: 1\n  eiko_weight: 0.1\n  sdf_truncation: 0.30\n\ndecoder_specs:\n  depth: 2\n  width: 256\n  in_dim: 16\n  skips: []\n  embedder: none\n  multires: 0\n\ntracker_specs:\n  N_rays: 2048\n  learning_rate: 0.06\n  step_size: 0.2\n  max_voxel_hit: 20\n  num_iterations: 25\n\nmapper_specs:\n  N_rays_each: 2048\n  use_local_coord: False\n  voxel_size: 0.3\n  step_size: 0.5\n  window_size: 4\n  num_iterations: 25\n  max_voxel_hit: 20\n  final_iter: True\n  mesh_res: 2\n  learning_rate_emb: 0.01\n  learning_rate_decorder: 0.005\n  learning_rate_pose: 0.001\n  freeze_frame: 5\n  keyframe_gap: 8\n  remove_back: False\n  key_distance: 12\n\ndebug_args:\n  verbose: False\n  mesh_freq: 100\n"
  },
  {
    "path": "configs/kitti/kitti_00.yaml",
    "content": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence00\n\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/00'\n  use_gt: False\n  max_depth: 40\n  min_depth: 5\n\ntracker_specs:\n  start_frame: 0\n  end_frame: -1\n  read_offset: 1"
  },
  {
    "path": "configs/kitti/kitti_01.yaml",
    "content": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence01\n\n\n\ndata_specs:\n  data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/01/'\n  use_gt: False\n  max_depth: 30\n  min_depth: 5\n\ntracker_specs:\n  start_frame: 0\n  end_frame: 1101\n  read_offset: 1"
  },
  {
    "path": "configs/kitti/kitti_03.yaml",
    "content": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence03\n\n\n\ndata_specs:\n  data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/03/'\n  use_gt: False\n  max_depth: 30\n  min_depth: 5\n\ntracker_specs:\n  start_frame: 0\n  end_frame: 1101\n  read_offset: 1"
  },
  {
    "path": "configs/kitti/kitti_04.yaml",
    "content": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence04\n\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/04'\n  use_gt: False\n  max_depth: 50\n  min_depth: 2.75\n\ntracker_specs:\n  start_frame: 0\n  end_frame: 270\n  read_offset: 1"
  },
  {
    "path": "configs/kitti/kitti_05.yaml",
    "content": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence05\n\n\n\ndata_specs:\n  data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/05/'\n  use_gt: False\n  max_depth: 50\n  min_depth: 5\n\ntracker_specs:\n  start_frame: 2299\n  end_frame: 2760\n  read_offset: 1"
  },
  {
    "path": "configs/kitti/kitti_06.yaml",
    "content": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence06\n\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/06'\n  use_gt: False\n  max_depth: 40\n  min_depth: 5\n\ntracker_specs:\n  start_frame: 0\n  end_frame: -1\n  read_offset: 1"
  },
  {
    "path": "configs/kitti/kitti_07.yaml",
    "content": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence07\n\n\n\ndata_specs:\n  data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/07'\n  use_gt: True\n  max_depth: 25\n  min_depth: 5\n\ntracker_specs:\n  start_frame: 0\n  end_frame: 1100\n  read_offset: 1"
  },
  {
    "path": "configs/kitti/kitti_08.yaml",
    "content": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence08\n\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/08'\n  use_gt: False\n  max_depth: 40\n  min_depth: 5\n\ntracker_specs:\n  start_frame: 0\n  end_frame: -1\n  read_offset: 1"
  },
  {
    "path": "configs/kitti/kitti_09.yaml",
    "content": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence09\n\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/09'\n  use_gt: False\n  max_depth: 40\n  min_depth: 5\n\ntracker_specs:\n  start_frame: 0\n  end_frame: -1\n  read_offset: 1"
  },
  {
    "path": "configs/kitti/kitti_10.yaml",
    "content": "base_config: configs/kitti/kitti_base10.yaml\n\nexp_name: kitti/sqeuence10\n\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/10'\n  use_gt: False\n  max_depth: 70\n  min_depth: 2.75\n\ntracker_specs:\n  start_frame: 0\n  end_frame: 1200\n  read_offset: 1"
  },
  {
    "path": "configs/kitti/kitti_base06.yaml",
    "content": "log_dir: './logs'\ndecoder: lidar\ndataset: kitti\n\ncriteria:\n  sdf_weight: 10000.0\n  fs_weight: 1\n  eiko_weight: 0.1\n  sdf_truncation: 0.30\n\ndecoder_specs:\n  depth: 2\n  width: 256\n  in_dim: 16\n  skips: []\n  embedder: none\n  multires: 0\n\ntracker_specs:\n  N_rays: 2048\n  learning_rate: 0.06\n  step_size: 0.2\n  max_voxel_hit: 20\n  num_iterations: 25\n\nmapper_specs:\n  N_rays_each: 2048\n  use_local_coord: False\n  voxel_size: 0.3\n  step_size: 0.5\n  window_size: 4\n  num_iterations: 25\n  max_voxel_hit: 20\n  final_iter: True\n  mesh_res: 2\n  learning_rate_emb: 0.01\n  learning_rate_decorder: 0.005\n  learning_rate_pose: 0.001\n  freeze_frame: 5\n  keyframe_gap: 8\n  remove_back: False\n  key_distance: 12\n\ndebug_args:\n  verbose: False\n  mesh_freq: 100"
  },
  {
    "path": "configs/kitti/kitti_base10.yaml",
    "content": "log_dir: '/home/evsjtu2/disk1/dengjunyuan/running_logs/'\ndecoder: lidar\ndataset: kitti\n\ncriteria:\n  depth_weight: 0\n  sdf_weight: 12000.0\n  fs_weight: 1\n  eiko_weight: 0\n  sdf_truncation: 0.50\n\ndecoder_specs:\n  depth: 2\n  width: 256\n  in_dim: 16\n  skips: []\n  embedder: none\n  multires: 0\n\ntracker_specs:\n  N_rays: 2048\n  learning_rate: 0.1\n  start_frame: 0\n  end_frame: -1\n  step_size: 0.2\n  show_imgs: False\n  max_voxel_hit: 20\n  keyframe_freq: 10\n  num_iterations: 40\n\nmapper_specs:\n  N_rays_each: 2048\n  num_embeddings: 20000000\n  use_local_coord: False\n  voxel_size: 0.2\n  step_size: 0.2\n  window_size: 4\n  num_iterations: 20\n  max_voxel_hit: 20\n  final_iter: True\n  mesh_res: 2\n  overlap_th: 0.8\n  learning_rate_emb: 0.03\n  learning_rate_decorder: 0.005\n  learning_rate_pose: 0.001\n  #max_depth_first: 20\n  freeze_frame: 10\n  keyframe_gap: 7\n  remove_back: True\n  key_distance: 7\n\ndebug_args:\n  verbose: False\n  mesh_freq: 100"
  },
  {
    "path": "configs/maicity/maicity.yaml",
    "content": "log_dir: './logs'\ndecoder: lidar\ndataset: maicity\n\ncriteria:\n  sdf_weight: 10000.0\n  fs_weight: 1\n  eiko_weight: 0.1\n  sdf_truncation: 0.30\n\ndecoder_specs:\n  depth: 2\n  width: 256\n  in_dim: 16\n  skips: []\n  embedder: none\n  multires: 0\n\ntracker_specs:\n  N_rays: 2048\n  learning_rate: 0.005\n  step_size: 0.2\n  max_voxel_hit: 20\n  num_iterations: 20\n\nmapper_specs:\n  N_rays_each: 2048\n  use_local_coord: False\n  voxel_size: 0.2\n  step_size: 0.5\n  window_size: 4\n  num_iterations: 20\n  max_voxel_hit: 20\n  final_iter: True\n  mesh_res: 2\n  learning_rate_emb: 0.03\n  learning_rate_decorder: 0.005\n  learning_rate_pose: 0.001\n  freeze_frame: 5\n  keyframe_gap: 8\n  remove_back: False\n  key_distance: 12\n\ndebug_args:\n  verbose: False\n  mesh_freq: 100\n"
  },
  {
    "path": "configs/maicity/maicity_00.yaml",
    "content": "base_config: configs/maicity/maicity.yaml\n\nexp_name: maicity/sqeuence00\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/mai_city/bin/sequences/00'\n  use_gt: False\n  max_depth: 50.0\n  min_depth: 1.5\n\ntracker_specs:\n  start_frame: 0\n  end_frame: 699\n  read_offset: 1"
  },
  {
    "path": "configs/maicity/maicity_01.yaml",
    "content": "base_config: configs/maicity/maicity.yaml\n\nexp_name: maicity/sqeuence01\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/mai_city/bin/sequences/01'\n  use_gt: False\n  max_depth: 50.0\n  min_depth: 1.5\n\ntracker_specs:\n  start_frame: 0\n  end_frame: 99\n  read_offset: 1"
  },
  {
    "path": "configs/ncd/ncd.yaml",
    "content": "log_dir: './logs'\ndecoder: lidar\ndataset: ncd\n\ncriteria:\n  sdf_weight: 10000.0\n  fs_weight: 1\n  eiko_weight: 1.0\n  sdf_truncation: 0.30\n\ndecoder_specs:\n  depth: 2\n  width: 256\n  in_dim: 16\n  skips: []\n  embedder: none\n  multires: 0\n\ntracker_specs:\n  N_rays: 2048\n  learning_rate: 0.04\n  step_size: 0.1\n  max_voxel_hit: 20\n  num_iterations: 30\n\nmapper_specs:\n  N_rays_each: 2048\n  use_local_coord: False\n  voxel_size: 0.2\n  step_size: 0.2\n  window_size: 5\n  num_iterations: 15\n  max_voxel_hit: 20\n  final_iter: True\n  mesh_res: 2\n  learning_rate_emb: 0.002\n  learning_rate_decorder: 0.005\n  learning_rate_pose: 0.001\n  freeze_frame: 20\n  keyframe_gap: 8\n  remove_back: False\n  key_distance: 20\n\ndebug_args:\n  verbose: False\n  mesh_freq: 500\n"
  },
  {
    "path": "configs/ncd/ncd_quad.yaml",
    "content": "base_config: configs/ncd/ncd.yaml\n\nexp_name: ncd/quad\n\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/ncd_example/quad'\n  use_gt: False\n  max_depth: 50\n  min_depth: 1.5\n\ntracker_specs:\n  start_frame: 0\n  end_frame: -1\n  read_offset: 5\n\n"
  },
  {
    "path": "demo/parser.py",
    "content": "import yaml\nimport argparse\n\nclass ArgumentParserX(argparse.ArgumentParser):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        self.add_argument(\"config\", type=str)\n\n    def parse_args(self, args=None, namespace=None):\n        _args = self.parse_known_args(args, namespace)[0]\n        file_args = argparse.Namespace()\n        file_args = self.parse_config_yaml(_args.config, file_args)\n        file_args = self.convert_to_namespace(file_args)\n        for ckey, cvalue in file_args.__dict__.items():\n            try:\n                self.add_argument('--' + ckey, type=type(cvalue),\n                                  default=cvalue, required=False)\n            except argparse.ArgumentError:\n                continue\n        _args = super().parse_args(args, namespace)\n        return _args\n\n    def parse_config_yaml(self, yaml_path, args=None):\n\n        with open(yaml_path, 'r') as f:\n            configs = yaml.load(f, Loader=yaml.FullLoader)\n\n        if configs is not None:\n            base_config = configs.get('base_config')\n            if base_config is not None:\n                base_config = self.parse_config_yaml(configs[\"base_config\"])\n                if base_config is not None:\n                    configs = self.update_recursive(base_config, configs)\n                else:\n                    raise FileNotFoundError(\"base_config specified but not found!\")\n\n        return configs\n\n    def convert_to_namespace(self, dict_in, args=None):\n        if args is None:\n            args = argparse.Namespace()\n        for ckey, cvalue in dict_in.items():\n            if ckey not in args.__dict__.keys():\n                args.__dict__[ckey] = cvalue\n\n        return args\n\n    def update_recursive(self, dict1, dict2):\n        for k, v in dict2.items():\n            if k not in dict1:\n                dict1[k] = dict()\n            if isinstance(v, dict):\n                
self.update_recursive(dict1[k], v)\n            else:\n                dict1[k] = v\n        return dict1\n\ndef get_parser():\n    parser = ArgumentParserX()\n    parser.add_argument(\"--resume\", default=None, type=str)\n    parser.add_argument(\"--debug\", action='store_true')\n    return parser\n\nif __name__ == '__main__':\n    parser = ArgumentParserX()\n    print(parser.parse_args())\n"
  },
  {
    "path": "demo/run.py",
    "content": "import os  # noqa\nimport sys  # noqa\nsys.path.insert(0, os.path.abspath('src')) # noqa\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\nimport random\nfrom parser import get_parser\nimport numpy as np\nimport torch\nfrom nerfloam import nerfloam\nimport os\n#os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"max_split_size_mb:512\"\ndef setup_seed(seed):\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    np.random.seed(seed)\n    random.seed(seed)\n\nif __name__ == '__main__':\n    args = get_parser().parse_args()\n    if hasattr(args, 'seeding'):\n        setup_seed(args.seeding)\n    else:\n        setup_seed(777)\n\n    slam = nerfloam(args)\n    slam.start()\n    slam.wait_child_processes()"
  },
  {
    "path": "install.sh",
    "content": "#!/bin/bash\n\ncd third_party/marching_cubes\npython setup.py install\n\ncd ../sparse_octree\npython setup.py install\n\ncd ../sparse_voxels\npython setup.py install"
  },
  {
    "path": "requirements.txt",
    "content": "matplotlib\nopen3d\nopencv-python\nPyYAML\nscikit-image\ntqdm\ntrimesh\npyyaml"
  },
  {
    "path": "src/criterion.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.autograd import grad\n\n\nclass Criterion(nn.Module):\n    def __init__(self, args) -> None:\n        super().__init__()\n        self.args = args\n        self.eiko_weight = args.criteria[\"eiko_weight\"]\n        self.sdf_weight = args.criteria[\"sdf_weight\"]\n        self.fs_weight = args.criteria[\"fs_weight\"]\n        self.truncation = args.criteria[\"sdf_truncation\"]\n        self.max_dpeth = args.data_specs[\"max_depth\"]\n\n    def forward(self, outputs, obs, pointsCos, use_color_loss=True,\n                use_depth_loss=True, compute_sdf_loss=True,\n                weight_depth_loss=False, compute_eikonal_loss=False):\n\n        points = obs\n        loss = 0\n        loss_dict = {}\n\n        # pred_depth = outputs[\"depth\"]\n        pred_sdf = outputs[\"sdf\"]\n        z_vals = outputs[\"z_vals\"]\n        ray_mask = outputs[\"ray_mask\"]\n        valid_mask = outputs[\"valid_mask\"]\n        sampled_xyz = outputs[\"sampled_xyz\"]\n        gt_points = points[ray_mask]\n        pointsCos = pointsCos[ray_mask]\n        gt_distance = torch.norm(gt_points, 2, -1)\n\n        gt_distance = gt_distance * pointsCos.view(-1)\n        z_vals = z_vals * pointsCos.view(-1, 1)\n\n        if compute_sdf_loss:\n            fs_loss, sdf_loss, eikonal_loss = self.get_sdf_loss(\n                z_vals, gt_distance, pred_sdf,\n                truncation=self.truncation,\n                loss_type='l2',\n                valid_mask=valid_mask,\n                compute_eikonal_loss=compute_eikonal_loss,\n                points=sampled_xyz if compute_eikonal_loss else None\n            )\n            loss += self.fs_weight * fs_loss\n            loss += self.sdf_weight * sdf_loss\n            # loss += self.bs_weight * back_loss\n            loss_dict[\"fs_loss\"] = fs_loss.item()\n            # loss_dict[\"bs_loss\"] = back_loss.item()\n            loss_dict[\"sdf_loss\"] = sdf_loss.item()\n            if 
compute_eikonal_loss:\n                loss += self.eiko_weight * eikonal_loss\n                loss_dict[\"eiko_loss\"] = eikonal_loss.item()\n        loss_dict[\"loss\"] = loss.item()\n        # print(loss_dict)\n        return loss, loss_dict\n\n    def compute_loss(self, x, y, mask=None, loss_type=\"l2\"):\n        if mask is None:\n            mask = torch.ones_like(x).bool()\n        if loss_type == \"l1\":\n            return torch.mean(torch.abs(x - y)[mask])\n        elif loss_type == \"l2\":\n            return torch.mean(torch.square(x - y)[mask])\n\n    def get_masks(self, z_vals, depth, epsilon):\n        front_mask = torch.where(\n            z_vals < (depth - epsilon),\n            torch.ones_like(z_vals),\n            torch.zeros_like(z_vals),\n        )\n        back_mask = torch.where(\n            z_vals > (depth + epsilon),\n            torch.ones_like(z_vals),\n            torch.zeros_like(z_vals),\n        )\n        depth_mask = torch.where(\n            (depth > 0.0) & (depth < self.max_dpeth), torch.ones_like(\n                depth), torch.zeros_like(depth)\n        )\n        sdf_mask = (1.0 - front_mask) * (1.0 - back_mask) * depth_mask\n\n        num_fs_samples = torch.count_nonzero(front_mask).float()\n        num_sdf_samples = torch.count_nonzero(sdf_mask).float()\n        num_samples = num_sdf_samples + num_fs_samples\n        fs_weight = 1.0 - num_fs_samples / num_samples\n        sdf_weight = 1.0 - num_sdf_samples / num_samples\n\n        return front_mask, sdf_mask, fs_weight, sdf_weight\n\n    def get_sdf_loss(self, z_vals, depth, predicted_sdf, truncation, valid_mask, loss_type=\"l2\", compute_eikonal_loss=False, points=None):\n\n        front_mask, sdf_mask, fs_weight, sdf_weight = self.get_masks(\n            z_vals, depth.unsqueeze(-1).expand(*z_vals.shape), truncation\n        )\n        fs_loss = (self.compute_loss(predicted_sdf * front_mask * valid_mask, torch.ones_like(\n            predicted_sdf) * front_mask, 
loss_type=loss_type,) * fs_weight)\n        sdf_loss = (self.compute_loss((z_vals + predicted_sdf * truncation) * sdf_mask * valid_mask,\n                    depth.unsqueeze(-1).expand(*z_vals.shape) * sdf_mask, loss_type=loss_type,) * sdf_weight)\n        # back_loss = (self.compute_loss(predicted_sdf * back_mask, -torch.ones_like(\n        #     predicted_sdf) * back_mask, loss_type=loss_type,) * back_weight)\n        eikonal_loss = None\n        if compute_eikonal_loss:\n            sdf = (predicted_sdf*sdf_mask*truncation)\n            sdf = sdf[valid_mask]\n            d_points = torch.ones_like(sdf, requires_grad=False, device=sdf.device)\n            sdf_grad = grad(outputs=sdf,\n                            inputs=points,\n                            grad_outputs=d_points,\n                            retain_graph=True,\n                            only_inputs=True)[0]\n            eikonal_loss = self.compute_loss(sdf_grad[0].norm(2, -1), 1.0, loss_type=loss_type,)\n\n        return fs_loss, sdf_loss, eikonal_loss\n"
  },
  {
    "path": "src/dataset/kitti.py",
    "content": "import os.path as osp\n\nimport numpy as np\nimport torch\nfrom glob import glob\nfrom torch.utils.data import Dataset\nimport sys\nfrom scipy.spatial import cKDTree\n\npatchwork_module_path =\"/home/pl21n4/Programmes/patchwork-plusplus/build/python_wrapper\"\nsys.path.insert(0, patchwork_module_path)\nimport pypatchworkpp\nparams = pypatchworkpp.Parameters()\n# params.verbose = True\n\nPatchworkPLUSPLUS = pypatchworkpp.patchworkpp(params)\n\n\nclass DataLoader(Dataset):\n    def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1) -> None:\n        self.data_path = data_path\n        self.num_bin = len(glob(osp.join(self.data_path, \"velodyne/*.bin\")))\n        self.use_gt = use_gt\n        self.max_depth = max_depth\n        self.min_depth = min_depth\n        self.gt_pose = self.load_gt_pose() if use_gt else None\n\n    def get_init_pose(self, frame):\n        if self.gt_pose is not None:\n            return np.concatenate((self.gt_pose[frame], [0, 0, 0, 1])\n                                  ).reshape(4, 4)\n        else:\n            return np.eye(4)\n\n    def load_gt_pose(self):\n        gt_file = osp.join(self.data_path, \"poses_lidar.txt\")\n        gt_pose = np.loadtxt(gt_file)\n        return gt_pose\n\n    def load_points(self, index):\n        remove_abnormal_z = True\n        path = osp.join(self.data_path, \"velodyne/{:06d}.bin\".format(index))\n        points = np.fromfile(path, dtype=np.float32, count=-1).reshape([-1, 4])\n        if remove_abnormal_z:\n            points = points[points[:, 2] > -3.0]\n        points_norm = np.linalg.norm(points[:, :3], axis=-1)\n        point_mask = True\n        if self.max_depth != -1:\n            point_mask = (points_norm < self.max_depth) & point_mask\n        if self.min_depth != -1:\n            point_mask = (points_norm > self.min_depth) & point_mask\n\n        if isinstance(point_mask, np.ndarray):\n            points = points[point_mask]\n\n        
PatchworkPLUSPLUS.estimateGround(points)\n        ground = PatchworkPLUSPLUS.getGround()\n        nonground = PatchworkPLUSPLUS.getNonground()\n        Patchcenters = PatchworkPLUSPLUS.getCenters()\n        normals = PatchworkPLUSPLUS.getNormals()\n        T = cKDTree(Patchcenters)\n        _, index = T.query(ground)\n        if True:\n            groundcos = np.abs(np.sum(normals[index] * ground, axis=-1)/np.linalg.norm(ground, axis=-1))\n        else:\n            groundcos = np.ones(ground.shape[0])\n        points = np.concatenate((ground, nonground), axis=0)\n        pointcos = np.concatenate((groundcos, np.ones(nonground.shape[0])), axis=0)\n\n        return points, pointcos\n\n    def __len__(self):\n        return self.num_bin\n\n    def __getitem__(self, index):\n        points, pointcos = self.load_points(index)\n        points = torch.from_numpy(points).float()\n        pointcos = torch.from_numpy(pointcos).float()\n        pose = np.concatenate((self.gt_pose[index], [0, 0, 0, 1])\n                              ).reshape(4, 4) if self.use_gt else None\n        return index, points, pointcos, pose\n\n\nif __name__ == \"__main__\":\n    path = \"/home/pl21n4/dataset/kitti/dataset/sequences/00/\"\n    loader = DataLoader(path)\n    for data in loader:\n        index, points, pointcos, pose = data\n        print(\"current index \", index)\n        print(\"first 10 points:\\n\", points[:10])\n        if index > 10:\n            break\n"
  },
  {
    "path": "src/dataset/maicity.py",
    "content": "import os.path as osp\n\nimport numpy as np\nimport torch\nfrom glob import glob\nfrom torch.utils.data import Dataset\nimport sys\nfrom scipy.spatial import cKDTree\n\npatchwork_module_path =\"/home/pl21n4/Programmes/patchwork-plusplus/build/python_wrapper\"\nsys.path.insert(0, patchwork_module_path)\nimport pypatchworkpp\n\nparams = pypatchworkpp.Parameters()\n# params.verbose = True\n\nPatchworkPLUSPLUS = pypatchworkpp.patchworkpp(params)\n\n\nclass DataLoader(Dataset):\n    def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1) -> None:\n        self.data_path = data_path\n        self.num_bin = len(glob(osp.join(self.data_path, \"velodyne/*.bin\")))\n        self.use_gt = use_gt\n        self.max_depth = max_depth\n        self.min_depth = min_depth\n        self.gt_pose = self.load_gt_pose() if use_gt else None\n\n    def get_init_pose(self, frame):\n        if self.gt_pose is not None:\n            return np.concatenate((self.gt_pose[frame], [0, 0, 0, 1])\n                                  ).reshape(4, 4)\n        else:\n            return np.eye(4)\n\n    def load_gt_pose(self):\n        gt_file = osp.join(self.data_path, \"poses.txt\")\n        gt_pose = np.loadtxt(gt_file)\n        return gt_pose\n\n    def load_points(self, index):\n        path = osp.join(self.data_path, \"velodyne/{:05d}.bin\".format(index))\n        points = np.fromfile(path, dtype=np.float32, count=-1).reshape([-1, 4])\n        # points = points[:,:3]\n        # points = np.delete(points, -1, axis=1)\n        points_norm = np.linalg.norm(points[:, :3], axis=-1)\n        point_mask = True\n        if self.max_depth != -1:\n            point_mask = (points_norm < self.max_depth) & point_mask\n        if self.min_depth != -1:\n            point_mask = (points_norm > self.min_depth) & point_mask\n        if isinstance(point_mask, np.ndarray):\n            points = points[point_mask]\n\n        PatchworkPLUSPLUS.estimateGround(points)\n        ground = 
PatchworkPLUSPLUS.getGround()\n        nonground = PatchworkPLUSPLUS.getNonground()\n        Patchcenters = PatchworkPLUSPLUS.getCenters()\n        normals = PatchworkPLUSPLUS.getNormals()\n        T = cKDTree(Patchcenters)\n        _, index = T.query(ground)\n        if True:\n            groundcos = np.abs(np.sum(normals[index] * ground, axis=-1)/np.linalg.norm(ground, axis=-1))\n            # groundnorm = np.linalg.norm(ground, axis=-1)\n            # groundcos = np.where(groundnorm > 10.0, np.ones(ground.shape[0]), groundcos)\n        else:\n            groundcos = np.ones(ground.shape[0])\n        points = np.concatenate((ground, nonground), axis=0)\n        pointcos = np.concatenate((groundcos, np.ones(nonground.shape[0])), axis=0)\n\n        return points, pointcos\n\n    def __len__(self):\n        return self.num_bin\n\n    def __getitem__(self, index):\n        points, pointcos = self.load_points(index)\n        points = torch.from_numpy(points).float()\n        pointcos = torch.from_numpy(pointcos).float()\n        pose = np.concatenate((self.gt_pose[index], [0, 0, 0, 1])\n                              ).reshape(4, 4) if self.use_gt else None\n        return index, points, pointcos, pose\n\n\nif __name__ == \"__main__\":\n    path = \"/home/pl21n4/dataset/kitti/dataset/sequences/00/\"\n    loader = DataLoader(path)\n    for data in loader:\n        index, points, pointcos, pose = data\n        print(\"current index \", index)\n        print(\"first 10 points:\\n\", points[:10])\n        if index > 10:\n            break\n"
  },
  {
    "path": "src/dataset/ncd.py",
    "content": "import os.path as osp\n\nimport numpy as np\nimport torch\nimport open3d as o3d\nfrom glob import glob\nfrom torch.utils.data import Dataset\nimport sys\nfrom scipy.spatial import cKDTree\n\npatchwork_module_path =\"/home/pl21n4/Programmes/patchwork-plusplus/build/python_wrapper\"\nsys.path.insert(0, patchwork_module_path)\nimport pypatchworkpp\nparams = pypatchworkpp.Parameters()\nparams.enable_RNR = False\n# params.verbose = True\n\nPatchworkPLUSPLUS = pypatchworkpp.patchworkpp(params)\n\n\nclass DataLoader(Dataset):\n    def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1) -> None:\n        self.data_path = data_path\n        self.num_bin = len(glob(osp.join(self.data_path, \"pcd/*.pcd\")))\n        self.use_gt = use_gt\n        self.max_depth = max_depth\n        self.min_depth = min_depth\n        self.gt_pose = self.load_gt_pose() if use_gt else None\n\n    def get_init_pose(self, frame):\n        if self.gt_pose is not None:\n            return np.concatenate((self.gt_pose[frame], [0, 0, 0, 1])\n                                  ).reshape(4, 4)\n        else:\n            return np.array([[5.925493285036220747e-01, -8.038419275143061649e-01, 5.218676416200035417e-02, -2.422443415414985424e-01],\n                            [8.017167514002809803e-01, 5.948020209102693467e-01, 5.882863457495644127e-02,  3.667865561670570873e+00],\n                            [-7.832971094540422397e-02, 6.980134849334420320e-03, 9.969030746023688216e-01, 6.809443654823238434e-01]])\n\n    def load_gt_pose(self):\n        gt_file = osp.join(self.data_path, \"poses.txt\")\n        gt_pose = np.loadtxt(gt_file)\n        # with open(gt_file, mode='r', encoding=\"utf-8\") as g:\n        #     line = g.readline()\n        #     while line:\n        #         # TODO:write transfomation of kitti and pose matrix\n        #         pose = np.zeros(16)\n        return gt_pose\n\n    def load_points(self, index):\n        path = osp.join(self.data_path, 
\"pcd/{:05d}.pcd\".format(index+500))\n        pc_load = o3d.io.read_point_cloud(path)\n        points = np.asarray(pc_load.points)\n\n        points_norm = np.linalg.norm(points, axis=-1)\n        point_mask = True\n        if self.max_depth != -1:\n            point_mask = (points_norm < self.max_depth) & point_mask\n        if self.min_depth != -1:\n            point_mask = (points_norm > self.min_depth) & point_mask\n        if isinstance(point_mask, np.ndarray):\n            points = points[point_mask]\n\n        # Split the scan into ground / non-ground points with Patchwork++.\n        PatchworkPLUSPLUS.estimateGround(points)\n        ground = PatchworkPLUSPLUS.getGround()\n        nonground = PatchworkPLUSPLUS.getNonground()\n        Patchcenters = PatchworkPLUSPLUS.getCenters()\n        normals = PatchworkPLUSPLUS.getNormals()\n        # Weight each ground point by |cos| between its ray and the normal of\n        # the nearest patch; non-ground points get weight 1.\n        T = cKDTree(Patchcenters)\n        _, nearest_idx = T.query(ground)\n        groundcos = np.abs(np.sum(normals[nearest_idx] * ground, axis=-1)\n                           / np.linalg.norm(ground, axis=-1))\n        points = np.concatenate((ground, nonground), axis=0)\n        pointcos = np.concatenate((groundcos, np.ones(nonground.shape[0])), axis=0)\n\n        return points, pointcos\n\n    def __len__(self):\n        return self.num_bin\n\n    def __getitem__(self, index):\n        points, pointcos = self.load_points(index)\n        points = torch.from_numpy(points).float()\n        pointcos = torch.from_numpy(pointcos).float()\n        pose = np.concatenate((self.gt_pose[index], [0, 0, 0, 1])\n                              ).reshape(4, 4) if self.use_gt else None\n        return index, points, pointcos, pose\n\n\nif __name__ == \"__main__\":\n    path = \"/home/pl21n4/dataset/kitti/dataset/sequences/00/\"\n    loader = DataLoader(path)\n    for data in loader:\n        index, points, pointcos, pose = data\n        print(\"current index \", index)\n        print(\"first 10 points:\\n\", points[:10])\n        if index > 10:\n            break\n"
  },
  {
    "path": "src/lidarFrame.py",
    "content": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom se3pose import OptimizablePose\nfrom utils.sample_util import *\nimport random\n\n\nclass LidarFrame(nn.Module):\n    def __init__(self, index, points, pointsCos, pose=None, new_keyframe=False) -> None:\n        super().__init__()\n        self.index = index\n        self.num_point = len(points)\n        self.points = points\n        self.pointsCos = pointsCos\n        if (not new_keyframe) and (pose is not None):\n            # TODO: fix this offset\n            pose[:3, 3] += 2000\n            pose = torch.tensor(pose, requires_grad=True, dtype=torch.float32)\n            self.pose = OptimizablePose.from_matrix(pose)\n        elif new_keyframe:\n            self.pose = pose\n        self.rays_d = self.get_rays()\n        self.rel_pose = None\n\n    def get_pose(self):\n        return self.pose.matrix()\n\n    def get_translation(self):\n        return self.pose.translation()\n\n    def get_rotation(self):\n        return self.pose.rotation()\n\n    def get_points(self):\n        return self.points\n\n    def get_pointsCos(self):\n        return self.pointsCos\n\n    def set_rel_pose(self, rel_pose):\n        self.rel_pose = rel_pose\n\n    def get_rel_pose(self):\n        return self.rel_pose\n\n    @torch.no_grad()\n    def get_rays(self):\n        self.rays_norm = (torch.norm(self.points, 2, -1, keepdim=True)+1e-8)\n        rays_d = self.points / self.rays_norm\n        # TODO: kept for consistency, adds one dim, though not strictly needed\n        return rays_d.unsqueeze(1).float()\n\n    @torch.no_grad()\n    def sample_rays(self, N_rays, track=False):\n        self.sample_mask = sample_rays(\n            torch.ones((self.num_point, 1))[None, ...], N_rays)[0, ...]\n"
  },
  {
    "path": "src/loggers.py",
    "content": "import os\nimport os.path as osp\nimport pickle\nfrom datetime import datetime\n\nimport cv2\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport open3d as o3d\nimport torch\nimport yaml\n\n\nclass BasicLogger:\n    def __init__(self, args) -> None:\n        self.args = args\n        self.log_dir = osp.join(\n            args.log_dir, args.exp_name, self.get_random_time_str())\n        self.img_dir = osp.join(self.log_dir, \"imgs\")\n        self.mesh_dir = osp.join(self.log_dir, \"mesh\")\n        self.ckpt_dir = osp.join(self.log_dir, \"ckpt\")\n        self.backup_dir = osp.join(self.log_dir, \"bak\")\n        self.misc_dir = osp.join(self.log_dir, \"misc\")\n\n        os.makedirs(self.img_dir)\n        os.makedirs(self.ckpt_dir)\n        os.makedirs(self.mesh_dir)\n        os.makedirs(self.misc_dir)\n        os.makedirs(self.backup_dir)\n\n        self.log_config(args)\n\n    def get_random_time_str(self):\n        return datetime.strftime(datetime.now(), \"%Y-%m-%d-%H-%M-%S\")\n\n    def log_ckpt(self, mapper):\n        print(\"******* saving *******\")\n        decoder_state = {f: v.cpu()\n                         for f, v in mapper.decoder.state_dict().items()}\n        map_state = {f: v.cpu() for f, v in mapper.map_states.items()}\n        embeddings = mapper.dynamic_embeddings.cpu()\n        svo = mapper.svo\n        torch.save({\n            \"decoder_state\": decoder_state,\n            # \"map_state\": map_state,\n            \"embeddings\": embeddings,\n            \"svo\": svo\n        },\n            os.path.join(self.ckpt_dir, \"final_ckpt.pth\"))\n        print(\"******* finish saving *******\")\n\n    def log_config(self, config):\n        out_path = osp.join(self.backup_dir, \"config.yaml\")\n        yaml.dump(config, open(out_path, 'w'))\n\n    def log_mesh(self, mesh, name=\"final_mesh.ply\"):\n        out_path = osp.join(self.mesh_dir, name)\n        o3d.io.write_triangle_mesh(out_path, mesh)\n\n    def 
log_point_cloud(self, pcd, name=\"final_points.ply\"):\n        out_path = osp.join(self.mesh_dir, name)\n        o3d.io.write_point_cloud(out_path, pcd)\n\n    def log_numpy_data(self, data, name, ind=None):\n        if isinstance(data, torch.Tensor):\n            data = data.detach().cpu().numpy()\n        if ind is not None:\n            np.save(osp.join(self.misc_dir, \"{}-{:05d}.npy\".format(name, ind)), data)\n        else:\n            np.save(osp.join(self.misc_dir, f\"{name}.npy\"), data)\n            self.npy2txt(osp.join(self.misc_dir, f\"{name}.npy\"), osp.join(self.misc_dir, f\"{name}.txt\"))\n\n    def log_debug_data(self, data, idx):\n        with open(os.path.join(self.misc_dir, f\"scene_data_{idx}.pkl\"), 'wb') as f:\n            pickle.dump(data, f)\n\n    def log_raw_image(self, ind, rgb, depth):\n        if isinstance(rgb, torch.Tensor):\n            rgb = rgb.detach().cpu().numpy()\n        if isinstance(depth, torch.Tensor):\n            depth = depth.detach().cpu().numpy()\n        rgb = cv2.cvtColor(rgb*255, cv2.COLOR_RGB2BGR)\n        cv2.imwrite(osp.join(self.img_dir, \"{:05d}.jpg\".format(\n            ind)), (rgb).astype(np.uint8))\n        cv2.imwrite(osp.join(self.img_dir, \"{:05d}.png\".format(\n            ind)), (depth*5000).astype(np.uint16))\n\n    def log_images(self, ind, gt_rgb, gt_depth, rgb, depth):\n        gt_depth_np = gt_depth.detach().cpu().numpy()\n        gt_color_np = gt_rgb.detach().cpu().numpy()\n        depth_np = depth.squeeze().detach().cpu().numpy()\n        color_np = rgb.detach().cpu().numpy()\n\n        h, w = depth_np.shape\n        gt_depth_np = cv2.resize(\n            gt_depth_np, (w, h), interpolation=cv2.INTER_NEAREST)\n        gt_color_np = cv2.resize(\n            gt_color_np, (w, h), interpolation=cv2.INTER_AREA)\n\n        depth_residual = np.abs(gt_depth_np - depth_np)\n        depth_residual[gt_depth_np == 0.0] = 0.0\n        color_residual = np.abs(gt_color_np - color_np)\n        
color_residual[gt_depth_np == 0.0] = 0.0\n\n        fig, axs = plt.subplots(2, 3)\n        fig.tight_layout()\n        max_depth = np.max(gt_depth_np)\n        axs[0, 0].imshow(gt_depth_np, cmap=\"plasma\",\n                         vmin=0, vmax=max_depth)\n        axs[0, 0].set_title('Input Depth')\n        axs[0, 0].set_xticks([])\n        axs[0, 0].set_yticks([])\n        axs[0, 1].imshow(depth_np, cmap=\"plasma\",\n                         vmin=0, vmax=max_depth)\n        axs[0, 1].set_title('Generated Depth')\n        axs[0, 1].set_xticks([])\n        axs[0, 1].set_yticks([])\n        axs[0, 2].imshow(depth_residual, cmap=\"plasma\",\n                         vmin=0, vmax=max_depth)\n        axs[0, 2].set_title('Depth Residual')\n        axs[0, 2].set_xticks([])\n        axs[0, 2].set_yticks([])\n        gt_color_np = np.clip(gt_color_np, 0, 1)\n        color_np = np.clip(color_np, 0, 1)\n        color_residual = np.clip(color_residual, 0, 1)\n        axs[1, 0].imshow(gt_color_np, cmap=\"plasma\")\n        axs[1, 0].set_title('Input RGB')\n        axs[1, 0].set_xticks([])\n        axs[1, 0].set_yticks([])\n        axs[1, 1].imshow(color_np, cmap=\"plasma\")\n        axs[1, 1].set_title('Generated RGB')\n        axs[1, 1].set_xticks([])\n        axs[1, 1].set_yticks([])\n        axs[1, 2].imshow(color_residual, cmap=\"plasma\")\n        axs[1, 2].set_title('RGB Residual')\n        axs[1, 2].set_xticks([])\n        axs[1, 2].set_yticks([])\n        plt.subplots_adjust(wspace=0, hspace=0)\n        plt.savefig(osp.join(self.img_dir, \"{:05d}.jpg\".format(\n            ind)), bbox_inches='tight', pad_inches=0.2)\n        plt.clf()\n        plt.close()\n\n    def npy2txt(self, input_path, output_path):\n        poses = np.load(input_path)\n        with open(output_path, mode='w') as w:\n            shape = poses.shape\n            print(shape)\n            for i in range(shape[0]):\n                one_pose = str()\n                for j in range(shape[1]):\n        
            if j == (shape[1]-1):\n                        continue\n                    for k in range(shape[2]):\n                        if j == (shape[1]-2) and k == (shape[2]-1):\n                            one_pose += (str(poses[i][j][k])+\"\\n\")\n                        else:\n                            one_pose += (str(poses[i][j][k])+\" \")\n                w.write(one_pose)\n"
  },
  {
    "path": "src/mapping.py",
    "content": "from copy import deepcopy\nimport random\nfrom time import sleep\nimport numpy as np\nfrom tqdm import tqdm\nimport torch\n\nfrom criterion import Criterion\nfrom loggers import BasicLogger\nfrom utils.import_util import get_decoder, get_property\nfrom variations.render_helpers import bundle_adjust_frames\nfrom utils.mesh_util import MeshExtractor\nfrom utils.profile_util import Profiler\nfrom lidarFrame import LidarFrame\nimport torch.nn.functional as F\nfrom pathlib import Path\nimport open3d as o3d\n\ntorch.classes.load_library(\n    \"/home/pl21n4/Programmes/Vox-Fusion/third_party/sparse_octree/build/lib.linux-x86_64-cpython-38/svo.cpython-38-x86_64-linux-gnu.so\")\n\n\ndef get_network_size(net):\n    size = 0\n    for param in net.parameters():\n        size += param.element_size() * param.numel()\n    return size / 1024 / 1024\n\n\nclass Mapping:\n    def __init__(self, args, logger: BasicLogger):\n        super().__init__()\n        self.args = args\n        self.logger = logger\n        self.decoder = get_decoder(args).cuda()\n        print(self.decoder)\n        self.loss_criteria = Criterion(args)\n        self.keyframe_graph = []\n        self.initialized = False\n\n        mapper_specs = args.mapper_specs\n\n        # optional args\n        self.ckpt_freq = get_property(args, \"ckpt_freq\", -1)\n        self.final_iter = get_property(mapper_specs, \"final_iter\", False)\n        self.mesh_res = get_property(mapper_specs, \"mesh_res\", 8)\n        self.save_data_freq = get_property(\n            args.debug_args, \"save_data_freq\", 0)\n\n        # required args\n        self.voxel_size = mapper_specs[\"voxel_size\"]\n        self.window_size = mapper_specs[\"window_size\"]\n        self.num_iterations = mapper_specs[\"num_iterations\"]\n        self.n_rays = mapper_specs[\"N_rays_each\"]\n        self.sdf_truncation = args.criteria[\"sdf_truncation\"]\n        self.max_voxel_hit = mapper_specs[\"max_voxel_hit\"]\n        self.step_size = 
mapper_specs[\"step_size\"]\n        self.learning_rate_emb = mapper_specs[\"learning_rate_emb\"]\n        self.learning_rate_decorder = mapper_specs[\"learning_rate_decorder\"]\n        self.learning_rate_pose = mapper_specs[\"learning_rate_pose\"]\n        self.step_size = self.step_size * self.voxel_size\n        self.max_distance = args.data_specs[\"max_depth\"]\n        self.freeze_frame = mapper_specs[\"freeze_frame\"]\n        self.keyframe_gap = mapper_specs[\"keyframe_gap\"]\n        self.remove_back = mapper_specs[\"remove_back\"]\n        self.key_distance = mapper_specs[\"key_distance\"]\n        \n        embed_dim = args.decoder_specs[\"in_dim\"]\n        use_local_coord = mapper_specs[\"use_local_coord\"]\n        self.embed_dim = embed_dim - 3 if use_local_coord else embed_dim\n        #num_embeddings = mapper_specs[\"num_embeddings\"]\n        self.mesh_freq = args.debug_args[\"mesh_freq\"]\n        self.mesher = MeshExtractor(args)\n\n\n        self.voxel_id2embedding_id = -torch.ones((int(2e9), 1), dtype=torch.int)\n        self.embeds_exist_search = dict()\n        self.current_num_embeds = 0\n        self.dynamic_embeddings = None\n\n        self.svo = torch.classes.svo.Octree()\n        self.svo.init(256*256*4, embed_dim, self.voxel_size)\n\n        self.frame_poses = []\n        self.depth_maps = []\n        self.last_tracked_frame_id = 0\n        self.final_poses=[]\n        \n        verbose = get_property(args.debug_args, \"verbose\", False)\n        self.profiler = Profiler(verbose=verbose)\n        self.profiler.enable()\n        \n    def spin(self, share_data, kf_buffer):\n        print(\"mapping process started!!!!!!!!!\")\n        while True:\n            torch.cuda.empty_cache()\n            if not kf_buffer.empty():\n                tracked_frame = kf_buffer.get()\n                # self.create_voxels(tracked_frame)\n                if not self.initialized:\n                    self.first_frame_id = tracked_frame.index\n            
        if self.mesher is not None:\n                        self.mesher.rays_d = tracked_frame.get_rays()\n                    self.create_voxels(tracked_frame)\n                    self.insert_keyframe(tracked_frame)\n                    while kf_buffer.empty():\n                        self.do_mapping(share_data, tracked_frame, selection_method='current')\n                    self.initialized = True\n                else:\n                    if self.remove_back:\n                        tracked_frame = self.remove_back_points(tracked_frame)\n                    self.do_mapping(share_data, tracked_frame)\n                    self.create_voxels(tracked_frame)\n                    if (torch.norm(tracked_frame.pose.translation().cpu()\n                        - self.current_keyframe.pose.translation().cpu())) > self.keyframe_gap:\n                        self.insert_keyframe(tracked_frame)\n                        print(\n                            f\"********** current num kfs: { len(self.keyframe_graph) } **********\")\n\n                # self.create_voxels(tracked_frame)\n                tracked_pose = tracked_frame.get_pose().detach()\n                ref_pose = self.current_keyframe.get_pose().detach()\n                rel_pose = torch.linalg.inv(ref_pose) @ tracked_pose\n                self.frame_poses += [(len(self.keyframe_graph) -\n                                      1, rel_pose.cpu())]\n\n                if self.mesh_freq > 0 and (tracked_frame.index) % self.mesh_freq == 0:\n                    if self.final_iter and len(self.keyframe_graph) > 20:\n                        print(f\"********** post-processing steps **********\")\n                        #self.num_iterations = 1\n                        final_num_iter = len(self.keyframe_graph) + 1\n                        progress_bar = tqdm(\n                            range(0, final_num_iter), position=0)\n                        progress_bar.set_description(\" post-processing steps\")\n             
           for iter in progress_bar:\n                            #tracked_frame=self.keyframe_graph[iter//self.window_size]\n                            self.do_mapping(share_data, tracked_frame=None,\n                                            update_pose=False, update_decoder=False, selection_method='random')\n\n\n                    self.logger.log_mesh(self.extract_mesh(res=self.mesh_res, clean_mesh=False), name=f\"mesh_{tracked_frame.index:05d}.ply\")\n                    pose = self.get_updated_poses()\n                    self.logger.log_numpy_data(np.asarray(pose), f\"frame_poses_{tracked_frame.index:05d}\")\n\n                    if self.final_iter and len(self.keyframe_graph) > 20:\n                        self.keyframe_graph = []\n                        self.keyframe_graph += [self.current_keyframe]\n                if self.save_data_freq > 0 and (tracked_frame.index + 1) % self.save_data_freq == 0:\n                    self.save_debug_data(tracked_frame)\n            elif share_data.stop_mapping:\n                break\n        print(\"******* extracting mesh without replay *******\")\n        self.logger.log_mesh(self.extract_mesh(res=self.mesh_res, clean_mesh=False), name=\"final_mesh_noreplay.ply\")\n        if self.final_iter:\n            print(\"********** post-processing steps **********\")\n            #self.num_iterations = 1\n            final_num_iter = len(self.keyframe_graph) + 1\n            progress_bar = tqdm(\n                range(0, final_num_iter), position=0)\n            progress_bar.set_description(\" post-processing steps\")\n            for iter in progress_bar:\n                tracked_frame = self.keyframe_graph[iter//self.window_size]\n                self.do_mapping(share_data, tracked_frame=None,\n                                update_pose=False, update_decoder=False, selection_method='random')\n\n        print(\"******* extracting final mesh *******\")\n        pose = self.get_updated_poses()\n        
self.logger.log_numpy_data(np.asarray(pose), \"frame_poses\")\n        self.logger.log_mesh(self.extract_mesh(res=self.mesh_res, clean_mesh=False))\n        print(\"******* mapping process died *******\")\n\n    def do_mapping(self, share_data, tracked_frame=None,\n                   update_pose=True, update_decoder=True, selection_method='current'):\n        self.profiler.tick(\"do_mapping\")\n        self.decoder.train()\n        optimize_targets = self.select_optimize_targets(tracked_frame, selection_method=selection_method)\n        torch.cuda.empty_cache()\n        self.profiler.tick(\"bundle_adjust_frames\")\n        bundle_adjust_frames(\n            optimize_targets,\n            self.dynamic_embeddings,\n            self.map_states,\n            self.decoder,\n            self.loss_criteria,\n            self.voxel_size,\n            self.step_size,\n            self.n_rays * 2 if selection_method == 'random' else self.n_rays,\n            self.num_iterations,\n            self.sdf_truncation,\n            self.max_voxel_hit,\n            self.max_distance,\n            learning_rate=[self.learning_rate_emb,\n                           self.learning_rate_decorder,\n                           self.learning_rate_pose],\n            update_pose=update_pose,\n            update_decoder=update_decoder if tracked_frame is None or (tracked_frame.index - self.first_frame_id) < self.freeze_frame else False,\n            profiler=self.profiler\n        )\n        self.profiler.tok(\"bundle_adjust_frames\")\n        # optimize_targets = [f.cpu() for f in optimize_targets]\n        self.update_share_data(share_data)\n        self.profiler.tok(\"do_mapping\")\n        # sleep(0.01)\n\n    def select_optimize_targets(self, tracked_frame=None, selection_method='previous'):\n        # TODO: better ways\n        targets = []\n        if selection_method == 'current':\n            if tracked_frame is None:\n                raise ValueError('select one track frame')\n      
      else:\n                return [tracked_frame]\n        if len(self.keyframe_graph) <= self.window_size:\n            targets = self.keyframe_graph[:]\n        elif selection_method == 'random':\n            targets = random.sample(self.keyframe_graph, self.window_size)\n        elif selection_method == 'previous':\n            targets = self.keyframe_graph[-self.window_size:]\n        elif selection_method == 'overlap':\n            raise NotImplementedError(\n                f\"selection method {selection_method} not implemented\")\n\n        if tracked_frame is not None and tracked_frame != self.current_keyframe:\n            targets += [tracked_frame]\n        return targets\n\n    def update_share_data(self, share_data, frameid=None):\n        share_data.decoder = deepcopy(self.decoder)\n        tmp_states = {}\n        for k, v in self.map_states.items():\n            tmp_states[k] = v.detach().cpu()\n        share_data.states = tmp_states\n        # self.last_tracked_frame_id = frameid\n\n    def remove_back_points(self, frame):\n        rel_pose = frame.get_rel_pose()\n        points = frame.get_points()\n        points_norm = torch.norm(points, 2, -1)\n        points_xy = points[:, :2]\n        if rel_pose is None:\n            x = 1\n            y = 0\n        else:\n            x = rel_pose[0, 3]\n            y = rel_pose[1, 3]\n        rel_xy = torch.ones((1, 2))\n        rel_xy[0, 0] = x\n        rel_xy[0, 1] = y\n        point_cos = torch.sum(-points_xy * rel_xy, dim=-1)/(\n            torch.norm(points_xy, 2, -1)*(torch.norm(rel_xy, 2, -1)))\n        remove_index = ((point_cos >= 0.7) & (points_norm > self.key_distance))\n        new_points = frame.points[~remove_index]\n        new_cos = frame.get_pointsCos()[~remove_index]\n        return LidarFrame(frame.index, new_points, new_cos,\n                          frame.pose, new_keyframe=True)\n\n    def frame_maxdistance_change(self, frame, distance):\n        # kf check\n        valid_distance = distance 
+ 0.5\n        new_keyframe_rays_norm = frame.rays_norm.reshape(-1)\n        new_keyframe_points = frame.points[new_keyframe_rays_norm <= valid_distance]\n        new_keyframe_pointsCos = frame.get_pointsCos()[new_keyframe_rays_norm <= valid_distance]\n        return LidarFrame(frame.index, new_keyframe_points, new_keyframe_pointsCos,\n                          frame.pose, new_keyframe=True)\n\n    def insert_keyframe(self, frame, valid_distance=-1):\n        # kf check\n        print(\"insert keyframe\")\n        valid_distance = self.key_distance + 0.01\n        new_keyframe_rays_norm = frame.rays_norm.reshape(-1)\n        mask = (torch.abs(frame.points[:, 0]) < valid_distance) & (torch.abs(frame.points[:, 1])\n                                                                   < valid_distance) & (torch.abs(frame.points[:, 2]) < valid_distance)\n        new_keyframe_points = frame.points[mask]\n        new_keyframe_pointsCos = frame.get_pointsCos()[mask]\n        new_keyframe = LidarFrame(frame.index, new_keyframe_points, new_keyframe_pointsCos,\n                                  frame.pose, new_keyframe=True)\n        if new_keyframe_points.shape[0] < 2*self.n_rays:\n            raise ValueError('valid_distance too small')\n        self.current_keyframe = new_keyframe\n        self.keyframe_graph += [new_keyframe]\n        # self.update_grid_features()\n\n    def create_voxels(self, frame):\n        points = frame.get_points().cuda()\n        pose = frame.get_pose().cuda()\n        print(\"frame id\", frame.index+1)\n        print(\"trans \", pose[:3, 3]-2000)\n        points = points@pose[:3, :3].transpose(-1, -2) + pose[:3, 3]\n        voxels = torch.div(points, self.voxel_size, rounding_mode='floor')\n        self.svo.insert(voxels.cpu().int())\n        self.update_grid_features()\n\n    @torch.enable_grad()\n    def get_embeddings(self, points_idx):\n\n        flatten_idx = points_idx.reshape(-1).long()\n        valid_flatten_idx = 
flatten_idx[flatten_idx.ne(-1)]\n        existence = F.embedding(valid_flatten_idx, self.voxel_id2embedding_id)\n        torch_add_idx = existence.eq(-1).view(-1)\n        torch_add = valid_flatten_idx[torch_add_idx]\n        if torch_add.shape[0] == 0:\n            return\n        start_num = self.current_num_embeds\n        end_num = start_num + torch_add.shape[0]\n        embeddings_add = torch.zeros((end_num-start_num, self.embed_dim),\n                                     dtype=torch.bfloat16)\n        # torch.nn.init.normal_(embeddings_add, std=0.01)\n\n        if self.dynamic_embeddings == None:\n            embeddings = [embeddings_add]\n        else:\n            embeddings = [self.dynamic_embeddings.detach().cpu(), embeddings_add]\n        embeddings = torch.cat(embeddings, dim=0)\n        self.dynamic_embeddings = embeddings.cuda().requires_grad_()\n\n        self.current_num_embeds = end_num\n        self.voxel_id2embedding_id[torch_add] = torch.arange(start_num, end_num, dtype=torch.int).view(-1, 1)\n\n    @torch.enable_grad()\n    def update_grid_features(self):\n        voxels, children, features = self.svo.get_centres_and_children()\n        centres = (voxels[:, :3] + voxels[:, -1:] / 2) * self.voxel_size\n        children = torch.cat([children, voxels[:, -1:]], -1)\n\n        centres = centres.float()\n        children = children.int()\n\n        map_states = {}\n        map_states[\"voxel_vertex_idx\"] = features\n        centres.requires_grad_()\n        map_states[\"voxel_center_xyz\"] = centres\n        map_states[\"voxel_structure\"] = children\n        self.profiler.tick(\"Creating embedding\")\n        self.get_embeddings(map_states[\"voxel_vertex_idx\"])\n        self.profiler.tok(\"Creating embedding\")\n        map_states[\"voxel_vertex_emb\"] = self.dynamic_embeddings\n        map_states[\"voxel_id2embedding_id\"] = self.voxel_id2embedding_id\n\n        self.map_states = map_states\n\n    @torch.no_grad()\n    def get_updated_poses(self, 
offset=-2000):\n        for i in range(len(self.frame_poses)):\n            ref_frame_ind, rel_pose = self.frame_poses[i]\n            ref_frame = self.keyframe_graph[ref_frame_ind]\n            ref_pose = ref_frame.get_pose().detach().cpu()\n            pose = ref_pose @ rel_pose\n            pose[:3, 3] += offset\n            self.final_poses += [pose.detach().cpu().numpy()]\n        self.frame_poses = []\n        return self.final_poses\n\n    @torch.no_grad()\n    def extract_mesh(self, res=8, clean_mesh=False):\n        sdf_network = self.decoder\n        sdf_network.eval()\n\n        voxels, _, features = self.svo.get_centres_and_children()\n        index = features.eq(-1).any(-1)\n        voxels = voxels[~index, :]\n        features = features[~index, :]\n        centres = (voxels[:, :3] + voxels[:, -1:] / 2) * self.voxel_size\n\n        encoder_states = {}\n        encoder_states[\"voxel_vertex_idx\"] = features\n        encoder_states[\"voxel_center_xyz\"] = centres\n        self.profiler.tick(\"Creating embedding\")\n        self.get_embeddings(encoder_states[\"voxel_vertex_idx\"])\n        self.profiler.tok(\"Creating embedding\")\n        encoder_states[\"voxel_vertex_emb\"] = self.dynamic_embeddings\n        encoder_states[\"voxel_id2embedding_id\"] = self.voxel_id2embedding_id\n\n        frame_poses = self.get_updated_poses()\n        mesh = self.mesher.create_mesh(\n            self.decoder, encoder_states, self.voxel_size, voxels,\n            frame_poses=None, depth_maps=None,\n            clean_mseh=clean_mesh, require_color=False, offset=-2000, res=res)\n        return mesh\n\n    @torch.no_grad()\n    def extract_voxels(self, offset=-10):\n        voxels, _, features = self.svo.get_centres_and_children()\n        index = features.eq(-1).any(-1)\n        voxels = voxels[~index, :]\n        features = features[~index, :]\n        voxels = (voxels[:, :3] + voxels[:, -1:] / 2) * \\\n            self.voxel_size + offset\n        # 
print(torch.max(features)-torch.count_nonzero(index))\n        return voxels\n\n    @torch.no_grad()\n    def save_debug_data(self, tracked_frame, offset=-10):\n        \"\"\"\n        save per-frame voxel, mesh and pose\n        \"\"\"\n        pose = tracked_frame.get_pose().detach().cpu().numpy()\n        pose[:3, 3] += offset\n        frame_poses = self.get_updated_poses()\n        mesh = self.extract_mesh(res=8, clean_mesh=True)\n        voxels = self.extract_voxels().detach().cpu().numpy()\n        keyframe_poses = [p.get_pose().detach().cpu().numpy()\n                          for p in self.keyframe_graph]\n\n        for f in frame_poses:\n            f[:3, 3] += offset\n        for kf in keyframe_poses:\n            kf[:3, 3] += offset\n\n        verts = np.asarray(mesh.vertices)\n        faces = np.asarray(mesh.triangles)\n        color = np.asarray(mesh.vertex_colors)\n\n        self.logger.log_debug_data({\n            \"pose\": pose,\n            \"updated_poses\": frame_poses,\n            \"mesh\": {\"verts\": verts, \"faces\": faces, \"color\": color},\n            \"voxels\": voxels,\n            \"voxel_size\": self.voxel_size,\n            \"keyframes\": keyframe_poses,\n            \"is_keyframe\": (tracked_frame == self.current_keyframe)\n        }, tracked_frame.index)\n"
  },
  {
    "path": "src/nerfloam.py",
    "content": "from multiprocessing.managers import BaseManager\nfrom time import sleep\n\nimport torch\nimport torch.multiprocessing as mp\n\nfrom loggers import BasicLogger\nfrom mapping import Mapping\nfrom share import ShareData, ShareDataProxy\nfrom tracking import Tracking\nfrom utils.import_util import get_dataset\n\n\n\nclass nerfloam:\n    def __init__(self, args):\n        self.args = args\n\n        # logger (optional)\n        self.logger = BasicLogger(args)\n\n        # shared data\n        mp.set_start_method('spawn', force=True)\n        BaseManager.register('ShareData', ShareData, ShareDataProxy)\n        manager = BaseManager()\n        manager.start()\n        self.share_data = manager.ShareData()\n        # keyframe buffer\n        self.kf_buffer = mp.Queue(maxsize=1)\n        # data stream\n        self.data_stream = get_dataset(args)\n        # tracker\n        self.tracker = Tracking(args, self.data_stream, self.logger)\n        # mapper\n        self.mapper = Mapping(args, self.logger)\n        # initialize map with first frame\n        self.tracker.process_first_frame(self.kf_buffer)\n        self.processes = []\n\n    def start(self):\n        mapping_process = mp.Process(\n            target=self.mapper.spin, args=(self.share_data, self.kf_buffer))\n        mapping_process.start()\n        print(\"initializing the first frame ...\")\n        sleep(20)\n        # self.share_data.stop_mapping=True\n        tracking_process = mp.Process(\n            target=self.tracker.spin, args=(self.share_data, self.kf_buffer))\n        tracking_process.start()\n\n        self.processes = [mapping_process, tracking_process]\n\n\n\n    def wait_child_processes(self):\n        for p in self.processes:\n            p.join()\n\n    @torch.no_grad()\n    def get_raw_trajectory(self):\n        return self.share_data.tracking_trajectory\n\n    @torch.no_grad()\n    def get_keyframe_poses(self):\n        keyframe_graph = self.mapper.keyframe_graph\n        poses 
= []\n        for keyframe in keyframe_graph:\n            poses.append(keyframe.get_pose().detach().cpu().numpy())\n        return poses\n"
  },
  {
    "path": "src/se3pose.py",
    "content": "\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom copy import deepcopy\n\n\nclass OptimizablePose(nn.Module):\n    def __init__(self, init_pose):\n        super().__init__()\n        assert (isinstance(init_pose, torch.FloatTensor))\n        self.register_parameter('data', nn.Parameter(init_pose))\n        self.data.requires_grad = True\n\n    def copy_from(self, pose):\n        self.data = deepcopy(pose.data)\n\n    def matrix(self):\n        Rt = torch.eye(4)\n        Rt[:3, :3] = self.rotation()\n        Rt[:3, 3] = self.translation()\n        return Rt\n\n    def rotation(self):\n        w = self.data[3:]\n        wx = self.skew_symmetric(w)\n        theta = w.norm(dim=-1)[..., None, None]\n        I = torch.eye(3, device=w.device, dtype=torch.float32)\n        A = self.taylor_A(theta)\n        B = self.taylor_B(theta)\n        R = I+A*wx+B*wx@wx\n        return R\n\n    def translation(self):\n        return self.data[:3]\n\n    @classmethod\n    def log(cls, R, eps=1e-7):  # [...,3,3]\n        trace = R[..., 0, 0]+R[..., 1, 1]+R[..., 2, 2]\n        # ln(R) will explode if theta==pi\n        theta = ((trace-1)/2).clamp(-1+eps, 1-eps).acos_()[..., None, None] % np.pi\n        lnR = 1/(2*cls.taylor_A(theta)+1e-8) * (R-R.transpose(-2, -1))  # FIXME: wei-chiu finds it weird\n        w0, w1, w2 = lnR[..., 2, 1], lnR[..., 0, 2], lnR[..., 1, 0]\n        w = torch.stack([w0, w1, w2], dim=-1)\n        return w\n\n    @classmethod\n    def from_matrix(cls, Rt, eps=1e-8):  # [...,3,4]\n        R, u = Rt[:3, :3], Rt[:3, 3]\n        w = cls.log(R)\n        return OptimizablePose(torch.cat([u, w], dim=-1))\n\n    @classmethod\n    def skew_symmetric(cls, w):\n        w0, w1, w2 = w.unbind(dim=-1)\n        O = torch.zeros_like(w0)\n        wx = torch.stack([\n            torch.stack([O, -w2, w1], dim=-1),\n            torch.stack([w2, O, -w0], dim=-1),\n            torch.stack([-w1, w0, O], dim=-1)], dim=-2)\n        return wx\n\n    
@classmethod\n    def taylor_A(cls, x, nth=10):\n        # Taylor expansion of sin(x)/x\n        ans = torch.zeros_like(x)\n        denom = 1.\n        for i in range(nth+1):\n            if i > 0:\n                denom *= (2*i)*(2*i+1)\n            ans = ans+(-1)**i*x**(2*i)/denom\n        return ans\n\n    @classmethod\n    def taylor_B(cls, x, nth=10):\n        # Taylor expansion of (1-cos(x))/x**2\n        ans = torch.zeros_like(x)\n        denom = 1.\n        for i in range(nth+1):\n            denom *= (2*i+1)*(2*i+2)\n            ans = ans+(-1)**i*x**(2*i)/denom\n        return ans\n\n    @classmethod\n    def taylor_C(cls, x, nth=10):\n        # Taylor expansion of (x-sin(x))/x**3\n        ans = torch.zeros_like(x)\n        denom = 1.\n        for i in range(nth+1):\n            denom *= (2*i+2)*(2*i+3)\n            ans = ans+(-1)**i*x**(2*i)/denom\n        return ans\n\n\nif __name__ == '__main__':\n    before = torch.tensor([[-0.955421, 0.119616, - 0.269932, 2.655830],\n                           [0.295248, 0.388339, - 0.872939, 2.981598],\n                           [0.000408, - 0.913720, - 0.406343, 1.368648],\n                           [0.000000, 0.000000, 0.000000, 1.000000]])\n    pose = OptimizablePose.from_matrix(before)\n    print(pose.rotation())\n    print(pose.translation())\n    after = pose.matrix()\n    print(after)\n    print(torch.abs((before-after)[:3, 3]))\n"
  },
  {
    "path": "src/share.py",
    "content": "from multiprocessing.managers import BaseManager, NamespaceProxy\nfrom copy import deepcopy\nimport torch.multiprocessing as mp\nfrom time import sleep\nimport sys\n\n\nclass ShareDataProxy(NamespaceProxy):\n    _exposed_ = ('__getattribute__', '__setattr__')\n\n\nclass ShareData:\n    global lock\n    lock = mp.RLock()\n\n    def __init__(self):\n        self.__stop_mapping = False\n        self.__stop_tracking = False\n\n        self.__decoder = None\n        self.__voxels = None\n        self.__octree = None\n        self.__states = None\n        self.__tracking_trajectory = []\n\n    @property\n    def decoder(self):\n        with lock:\n            return deepcopy(self.__decoder)\n            print(\"========== decoder get ==========\")\n            sys.stdout.flush()\n\n    @decoder.setter\n    def decoder(self, decoder):\n        with lock:\n            self.__decoder = deepcopy(decoder)\n            # print(\"========== decoder set ==========\")\n            sys.stdout.flush()\n\n    @property\n    def voxels(self):\n        with lock:\n            return deepcopy(self.__voxels)\n            print(\"========== voxels get ==========\")\n            sys.stdout.flush()\n\n    @voxels.setter\n    def voxels(self, voxels):\n        with lock:\n            self.__voxels = deepcopy(voxels)\n            print(\"========== voxels set ==========\")\n            sys.stdout.flush()\n\n    @property\n    def octree(self):\n        with lock:\n            return deepcopy(self.__octree)\n            print(\"========== octree get ==========\")\n            sys.stdout.flush()\n\n    @octree.setter\n    def octree(self, octree):\n        with lock:\n            self.__octree = deepcopy(octree)\n            print(\"========== octree set ==========\")\n            sys.stdout.flush()\n\n    @property\n    def states(self):\n        with lock:\n            return self.__states\n            print(\"========== states get ==========\")\n            
sys.stdout.flush()\n\n    @states.setter\n    def states(self, states):\n        with lock:\n            self.__states = states\n            # print(\"========== states set ==========\")\n            sys.stdout.flush()\n\n    @property\n    def stop_mapping(self):\n        with lock:\n            return self.__stop_mapping\n            print(\"========== stop_mapping get ==========\")\n            sys.stdout.flush()\n\n    @stop_mapping.setter\n    def stop_mapping(self, stop_mapping):\n        with lock:\n            self.__stop_mapping = stop_mapping\n            print(\"========== stop_mapping set ==========\")\n            sys.stdout.flush()\n\n    @property\n    def stop_tracking(self):\n        with lock:\n            return self.__stop_tracking\n            print(\"========== stop_tracking get ==========\")\n            sys.stdout.flush()\n\n    @stop_tracking.setter\n    def stop_tracking(self, stop_tracking):\n        with lock:\n            self.__stop_tracking = stop_tracking\n            print(\"========== stop_tracking set ==========\")\n            sys.stdout.flush()\n\n    @property\n    def tracking_trajectory(self):\n        with lock:\n            return deepcopy(self.__tracking_trajectory)\n            print(\"========== tracking_trajectory get ==========\")\n            sys.stdout.flush()\n\n    def push_pose(self, pose):\n        with lock:\n            self.__tracking_trajectory.append(deepcopy(pose))\n            # print(\"========== push_pose ==========\")\n            sys.stdout.flush()\n"
  },
  {
    "path": "src/tracking.py",
    "content": "import torch\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom criterion import Criterion\nfrom lidarFrame import LidarFrame\nfrom utils.import_util import get_property\nfrom utils.profile_util import Profiler\nfrom variations.render_helpers import fill_in, render_rays, track_frame\nfrom se3pose import OptimizablePose\nfrom time import sleep\nfrom copy import deepcopy\n\n\nclass Tracking:\n    def __init__(self, args, data_stream, logger):\n        self.args = args\n        self.last_frame_id = 0\n        self.last_frame = None\n\n        self.data_stream = data_stream\n        self.logger = logger\n        self.loss_criteria = Criterion(args)\n\n        self.voxel_size = args.mapper_specs[\"voxel_size\"]\n        self.N_rays = args.tracker_specs[\"N_rays\"]\n        self.num_iterations = args.tracker_specs[\"num_iterations\"]\n        self.sdf_truncation = args.criteria[\"sdf_truncation\"]\n        self.learning_rate = args.tracker_specs[\"learning_rate\"]\n        self.start_frame = args.tracker_specs[\"start_frame\"]\n        self.end_frame = args.tracker_specs[\"end_frame\"]\n        self.step_size = args.tracker_specs[\"step_size\"]\n        # self.keyframe_freq = args.tracker_specs[\"keyframe_freq\"]\n        self.max_voxel_hit = args.tracker_specs[\"max_voxel_hit\"]\n        self.max_distance = args.data_specs[\"max_depth\"]\n        self.step_size = self.step_size * self.voxel_size\n        self.read_offset = args.tracker_specs[\"read_offset\"]\n        self.mesh_freq = args.debug_args[\"mesh_freq\"]\n        if self.end_frame <= 0:\n            self.end_frame = len(self.data_stream)-1\n\n        # sanity check on the lower/upper bounds\n        self.start_frame = min(self.start_frame, len(self.data_stream))\n        self.end_frame = min(self.end_frame, len(self.data_stream))\n        self.rel_pose = None\n        # profiler\n        verbose = get_property(args.debug_args, \"verbose\", False)\n        self.profiler = 
Profiler(verbose=verbose)\n        self.profiler.enable()\n\n    def process_first_frame(self, kf_buffer):\n        init_pose = self.data_stream.get_init_pose(self.start_frame)\n        index, points, pointcos, _ = self.data_stream[self.start_frame]\n        first_frame = LidarFrame(index, points, pointcos, init_pose)\n        first_frame.pose.requires_grad_(False)\n        first_frame.points.requires_grad_(False)\n\n        print(\"******* initializing first_frame:\", first_frame.index)\n        kf_buffer.put(first_frame, block=True)\n        self.last_frame = first_frame\n        self.start_frame += 1\n\n    def spin(self, share_data, kf_buffer):\n        print(\"******* tracking process started! *******\")\n        progress_bar = tqdm(\n            range(self.start_frame, self.end_frame+1), position=0)\n        progress_bar.set_description(\"tracking frame\")\n        for frame_id in progress_bar:\n            if frame_id % self.read_offset != 0:\n                continue\n            if share_data.stop_tracking:\n                break\n\n            data_in = self.data_stream[frame_id]\n\n            current_frame = LidarFrame(*data_in)\n            if isinstance(data_in[3], np.ndarray):\n                self.last_frame = current_frame\n                self.check_keyframe(current_frame, kf_buffer)\n            else:\n                self.do_tracking(share_data, current_frame, kf_buffer)\n\n        share_data.stop_mapping = True\n        print(\"******* tracking process died *******\")\n\n        sleep(60)\n        while not kf_buffer.empty():\n            sleep(60)\n\n    def check_keyframe(self, check_frame, kf_buffer):\n        try:\n            kf_buffer.put(check_frame, block=True)\n        except Exception:\n            pass\n\n    def do_tracking(self, share_data, current_frame, kf_buffer):\n        self.profiler.tick(\"before track1111\")\n        decoder = share_data.decoder.cuda()\n        self.profiler.tok(\"before track1111\")\n\n        
self.profiler.tick(\"before track2222\")\n        map_states = share_data.states\n        map_states[\"voxel_vertex_emb\"] = map_states[\"voxel_vertex_emb\"].cuda()\n        self.profiler.tok(\"before track2222\")\n\n        constant_move_pose = self.last_frame.get_pose().detach()\n        input_pose = deepcopy(self.last_frame.pose)\n        input_pose.requires_grad_(False)\n        if self.rel_pose != None:\n            constant_move_pose[:3, 3] = (constant_move_pose @ (self.rel_pose))[:3, 3]\n            input_pose.data[:3] = constant_move_pose[:3, 3].T\n        torch.cuda.empty_cache()\n        self.profiler.tick(\"track frame\")\n\n        frame_pose, hit_mask = track_frame(\n            input_pose,\n            current_frame,\n            map_states,\n            decoder,\n            self.loss_criteria,\n            self.voxel_size,\n            self.N_rays,\n            self.step_size,\n            self.num_iterations if self.rel_pose != None else self.num_iterations*5,\n            self.sdf_truncation,\n            self.learning_rate if self.rel_pose != None else self.learning_rate,\n            self.max_voxel_hit,\n            self.max_distance,\n            profiler=self.profiler,\n            depth_variance=True\n        )\n        self.profiler.tok(\"track frame\")\n        if hit_mask == None:\n            current_frame.pose = OptimizablePose.from_matrix(constant_move_pose)\n        else:\n            current_frame.pose = frame_pose\n            current_frame.hit_ratio = hit_mask.sum() / self.N_rays\n\n        self.rel_pose = torch.linalg.inv(self.last_frame.get_pose().detach()) @ current_frame.get_pose().detach()\n        current_frame.set_rel_pose(self.rel_pose)\n        self.last_frame = current_frame\n\n        self.profiler.tick(\"transport frame\")\n        self.check_keyframe(current_frame, kf_buffer)\n        self.profiler.tok(\"transport frame\")\n"
  },
  {
    "path": "src/utils/__init__.py",
    "content": ""
  },
  {
    "path": "src/utils/import_util.py",
    "content": "from importlib import import_module\nimport argparse\n\ndef get_dataset(args):\n    Dataset = import_module(\"dataset.\"+args.dataset)\n    return Dataset.DataLoader(**args.data_specs)\n\ndef get_decoder(args):\n    Decoder = import_module(\"variations.\"+args.decoder)\n    return Decoder.Decoder(**args.decoder_specs)\n\ndef get_property(args, name, default):\n    if isinstance(args, dict):\n        return args.get(name, default)\n    elif isinstance(args, argparse.Namespace):\n        if hasattr(args, name):\n            return vars(args)[name]\n        else:\n            return default\n    else:\n        raise ValueError(f\"unkown dict/namespace type: {type(args)}\")"
  },
  {
    "path": "src/utils/mesh_util.py",
    "content": "import math\nimport torch\n\nimport numpy as np\nimport open3d as o3d\nfrom scipy.spatial import cKDTree\nfrom skimage.measure import marching_cubes\nfrom variations.render_helpers import get_scores, eval_points\n\n\nclass MeshExtractor:\n    def __init__(self, args):\n        self.voxel_size = args.mapper_specs[\"voxel_size\"]\n        self.rays_d = None\n        self.depth_points = None\n\n    @ torch.no_grad()\n    def linearize_id(self, xyz, n_xyz):\n        return xyz[:, 2] + n_xyz[-1] * xyz[:, 1] + (n_xyz[-1] * n_xyz[-2]) * xyz[:, 0]\n\n    @torch.no_grad()\n    def downsample_points(self, points, voxel_size=0.01):\n        pcd = o3d.geometry.PointCloud()\n        pcd.points = o3d.utility.Vector3dVector(points)\n        pcd = pcd.voxel_down_sample(voxel_size)\n        return np.asarray(pcd.points)\n\n    @torch.no_grad()\n    def get_rays(self, w=None, h=None, K=None):\n        w = self.w if w == None else w\n        h = self.h if h == None else h\n        if K is None:\n            K = np.eye(3)\n            K[0, 0] = self.K[0, 0] * w / self.w\n            K[1, 1] = self.K[1, 1] * h / self.h\n            K[0, 2] = self.K[0, 2] * w / self.w\n            K[1, 2] = self.K[1, 2] * h / self.h\n        ix, iy = torch.meshgrid(\n            torch.arange(w), torch.arange(h), indexing='xy')\n        rays_d = torch.stack(\n            [(ix-K[0, 2]) / K[0, 0],\n             (iy-K[1, 2]) / K[1, 1],\n             torch.ones_like(ix)], -1).float()\n        return rays_d\n\n    @torch.no_grad()\n    def get_valid_points(self, frame_poses, depth_maps):\n        if isinstance(frame_poses, list):\n            all_points = []\n            print(\"extracting all points\")\n            for i in range(0, len(frame_poses), 5):\n                pose = frame_poses[i]\n                depth = depth_maps[i]\n                points = self.rays_d * depth.unsqueeze(-1)\n                points = points.reshape(-1, 3)\n                points = points @ pose[:3, 
:3].transpose(-1, -2) + pose[:3, 3]\n                if len(all_points) == 0:\n                    all_points = points.detach().cpu().numpy()\n                else:\n                    all_points = np.concatenate(\n                        [all_points, points.detach().cpu().numpy()], 0)\n            print(\"downsample all points\")\n            all_points = self.downsample_points(all_points)\n            return all_points\n        else:\n            pose = frame_poses\n            depth = depth_maps\n            points = self.rays_d * depth.unsqueeze(-1)\n            points = points.reshape(-1, 3)\n            points = points @ pose[:3, :3].transpose(-1, -2) + pose[:3, 3]\n            if self.depth_points is None:\n                self.depth_points = points.detach().cpu().numpy()\n            else:\n                self.depth_points = np.concatenate(\n                    [self.depth_points, points.detach().cpu().numpy()], 0)\n            self.depth_points = self.downsample_points(self.depth_points)\n        return self.depth_points\n\n    @ torch.no_grad()\n    def create_mesh(self, decoder, map_states, voxel_size, voxels,\n                    frame_poses=None, depth_maps=None, clean_mseh=False,\n                    require_color=False, offset=-80, res=8):\n\n        sdf_grid = get_scores(decoder, map_states, voxel_size, bits=res)\n        sdf_grid = sdf_grid.reshape(-1, res, res, res, 1)\n\n        voxel_centres = map_states[\"voxel_center_xyz\"]\n        verts, faces = self.marching_cubes(voxel_centres, sdf_grid)\n\n        if clean_mseh:\n            print(\"********** get points from frames **********\")\n            all_points = self.get_valid_points(frame_poses, depth_maps)\n            print(\"********** construct kdtree **********\")\n            kdtree = cKDTree(all_points)\n            print(\"********** query kdtree **********\")\n            point_mask = kdtree.query_ball_point(\n                verts, voxel_size * 0.5, workers=12, return_length=True)\n            
print(\"********** finished querying kdtree **********\")\n            point_mask = point_mask > 0\n            face_mask = point_mask[faces.reshape(-1)].reshape(-1, 3).any(-1)\n\n            faces = faces[face_mask]\n\n        if require_color:\n            print(\"********** get color from network **********\")\n            verts_torch = torch.from_numpy(verts).float().cuda()\n            batch_points = torch.split(verts_torch, 1000)\n            colors = []\n            for points in batch_points:\n                voxel_pos = points // self.voxel_size\n                batch_voxels = voxels[:, :3].cuda()\n                batch_voxels = batch_voxels.unsqueeze(\n                    0).repeat(voxel_pos.shape[0], 1, 1)\n\n                # filter outliers\n                nonzeros = (batch_voxels == voxel_pos.unsqueeze(1)).all(-1)\n                nonzeros = torch.where(nonzeros, torch.ones_like(\n                    nonzeros).int(), -torch.ones_like(nonzeros).int())\n                sorted, index = torch.sort(nonzeros, dim=-1, descending=True)\n                sorted = sorted[:, 0]\n                index = index[:, 0]\n                valid = (sorted != -1)\n                color_empty = torch.zeros_like(points)\n                points = points[valid, :]\n                index = index[valid]\n\n                # get color\n                if len(points) > 0:\n                    color = eval_points(decoder, map_states,\n                                        points, index, voxel_size).cuda()\n                    color_empty[valid] = color\n                colors += [color_empty]\n            colors = torch.cat(colors, 0)\n\n        mesh = o3d.geometry.TriangleMesh()\n        mesh.vertices = o3d.utility.Vector3dVector(verts+offset)\n        mesh.triangles = o3d.utility.Vector3iVector(faces)\n        if require_color:\n            mesh.vertex_colors = o3d.utility.Vector3dVector(\n                colors.detach().cpu().numpy())\n        mesh.compute_vertex_normals()\n  
      return mesh\n\n    @ torch.no_grad()\n    def marching_cubes(self, voxels, sdf):\n        voxels = voxels[:, :3]\n        sdf = sdf[..., 0]\n        res = 1.0 / (sdf.shape[1] - 1)\n        spacing = [res, res, res]\n\n        num_verts = 0\n        total_verts = []\n        total_faces = []\n        for i in range(len(voxels)):\n            sdf_volume = sdf[i].detach().cpu().numpy()\n            if np.min(sdf_volume) > 0 or np.max(sdf_volume) < 0:\n                continue\n            verts, faces, _, _ = marching_cubes(sdf_volume, 0, spacing=spacing)\n            verts -= 0.5\n            verts *= self.voxel_size\n            verts += voxels[i].detach().cpu().numpy()\n            faces += num_verts\n            num_verts += verts.shape[0]\n\n            total_verts += [verts]\n            total_faces += [faces]\n        total_verts = np.concatenate(total_verts)\n        total_faces = np.concatenate(total_faces)\n        return total_verts, total_faces\n"
  },
  {
    "path": "src/utils/profile_util.py",
    "content": "from time import time\nimport torch\n\n\nclass Profiler(object):\n    def __init__(self, verbose=False) -> None:\n        self.timer = dict()\n        self.time_log = dict()\n        self.enabled = False\n        self.verbose = verbose\n\n    def enable(self):\n        self.enabled = True\n\n    def disable(self):\n        self.enabled = False\n\n    def tick(self, name):\n        if not self.enabled:\n            return\n        self.timer[name] = time()\n        if name not in self.time_log:\n            self.time_log[name] = list()\n\n    def tok(self, name):\n        if not self.enabled:\n            return\n        if name not in self.timer:\n            return\n        torch.cuda.synchronize()\n        elapsed = time() - self.timer[name]\n        if self.verbose:\n            print(f\"{name}: {elapsed*1000:.2f} ms\")\n        else:\n            self.time_log[name].append(elapsed * 1000)\n"
  },
  {
    "path": "src/utils/sample_util.py",
    "content": "import torch\n\n\ndef sampling_without_replacement(logp, k):\n    def gumbel_like(u):\n        return -torch.log(-torch.log(torch.rand_like(u) + 1e-7) + 1e-7)\n\n    scores = logp + gumbel_like(logp)\n    return scores.topk(k, dim=-1)[1]\n\n\ndef sample_rays(mask, num_samples):\n    B, H, W = mask.shape\n    probs = mask / (mask.sum() + 1e-9)\n    flatten_probs = probs.reshape(B, -1)\n    sampled_index = sampling_without_replacement(\n        torch.log(flatten_probs + 1e-9), num_samples)\n    sampled_masks = (torch.zeros_like(\n        flatten_probs).scatter_(-1, sampled_index, 1).reshape(B, H, W) > 0)\n    return sampled_masks"
  },
  {
    "path": "src/variations/decode_morton.py",
    "content": "import numpy as np\n\n\ndef compact(value):\n    x = value & 0x1249249249249249\n    x = (x | x >> 2) & 0x10c30c30c30c30c3\n    x = (x | x >> 4) & 0x100f00f00f00f00f\n    x = (x | x >> 8) & 0x1f0000ff0000ff\n    x = (x | x >> 16) & 0x1f00000000ffff\n    x = (x | x >> 32) & 0x1fffff\n    return x\n\n\ndef decode(code):\n    return compact(code >> 0), compact(code >> 1), compact(code >> 2)\n\n\nfor i in range(10):\n    x, y, z = decode(samples_valid['sampled_point_voxel_idx'][i])\n    print(x, y, z)\n    print(torch.sqrt((x-80)**2+(y-80)**2+(z-80)**2))\n    print(samples_valid['sampled_point_depth'][i])\n"
  },
  {
    "path": "src/variations/lidar.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass GaussianFourierFeatureTransform(torch.nn.Module):\n    \"\"\"\n    Modified based on the implementation of Gaussian Fourier feature mapping.\n\n    \"Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains\":\n       https://arxiv.org/abs/2006.10739\n       https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html\n\n    \"\"\"\n\n    def __init__(self, num_input_channels, mapping_size=93, scale=25, learnable=True):\n        super().__init__()\n\n        if learnable:\n            self._B = nn.Parameter(torch.randn(\n                (num_input_channels, mapping_size)) * scale)\n        else:\n            self._B = torch.randn((num_input_channels, mapping_size)) * scale\n        self.embedding_size = mapping_size\n\n    def forward(self, x):\n        # x = x.squeeze(0)\n        assert x.dim() == 2, 'Expected 2D input (got {}D input)'.format(x.dim())\n        x = x @ self._B.to(x.device)\n        return torch.sin(x)\n\n\nclass Nerf_positional_embedding(torch.nn.Module):\n    \"\"\"\n    Nerf positional embedding.\n\n    \"\"\"\n\n    def __init__(self, in_dim, multires, log_sampling=True):\n        super().__init__()\n        self.log_sampling = log_sampling\n        self.include_input = True\n        self.periodic_fns = [torch.sin, torch.cos]\n        self.max_freq_log2 = multires-1\n        self.num_freqs = multires\n        self.max_freq = self.max_freq_log2\n        self.N_freqs = self.num_freqs\n        self.embedding_size = multires*in_dim*2 + in_dim\n\n    def forward(self, x):\n        # x = x.squeeze(0)\n        assert x.dim() == 2, 'Expected 2D input (got {}D input)'.format(\n            x.dim())\n\n        if self.log_sampling:\n            freq_bands = 2.**torch.linspace(0.,\n                                            self.max_freq, steps=self.N_freqs)\n        else:\n            freq_bands = torch.linspace(\n                
2.**0., 2.**self.max_freq, steps=self.N_freqs)\n        output = []\n        if self.include_input:\n            output.append(x)\n        for freq in freq_bands:\n            for p_fn in self.periodic_fns:\n                output.append(p_fn(x * freq))\n        ret = torch.cat(output, dim=1)\n        return ret\n\n\nclass Same(nn.Module):\n    def __init__(self, in_dim) -> None:\n        super().__init__()\n        self.embedding_size = in_dim\n\n    def forward(self, x):\n        return x\n\n\nclass Decoder(nn.Module):\n    def __init__(self,\n                 depth=8,\n                 width=258,\n                 in_dim=3,\n                 sdf_dim=128,\n                 skips=[4],\n                 multires=6,\n                 embedder='none',\n                 point_dim=3,\n                 local_coord=False,\n                 **kwargs) -> None:\n        super().__init__()\n        self.D = depth\n        self.W = width\n        self.skips = skips\n        self.point_dim = point_dim\n        if embedder == 'nerf':\n            self.pe = Nerf_positional_embedding(in_dim, multires)\n        elif embedder == 'none':\n            self.pe = Same(in_dim)\n        elif embedder == 'gaussian':\n            self.pe = GaussianFourierFeatureTransform(in_dim)\n        else:\n            raise NotImplementedError(\"unknown positional encoder\")\n        self.pts_linears = nn.ModuleList(\n            [nn.Linear(self.pe.embedding_size, width)] + [nn.Linear(width, width) if i not in self.skips else nn.Linear(width + self.pe.embedding_size, width) for i in range(depth-1)])\n        self.sdf_out = nn.Linear(width, 1)\n\n    def get_values(self, input):\n        x = self.pe(input)\n        # point = input[:, -3:]\n        h = x\n        for i, l in enumerate(self.pts_linears):\n            h = self.pts_linears[i](h)\n            h = F.relu(h)\n            if i in self.skips:\n                h = torch.cat([x, h], -1)\n\n        # outputs = self.output_linear(h)\n        # 
outputs[:, :3] = torch.sigmoid(outputs[:, :3])\n        sdf_out = self.sdf_out(h)\n\n        return sdf_out\n\n    def forward(self, inputs):\n        outputs = self.get_values(inputs)\n\n        return {\n            'sdf': outputs,\n            # 'depth': outputs[:, 1]\n        }\n"
  },
  {
    "path": "src/variations/render_helpers.py",
    "content": "from copy import deepcopy\nimport torch\nimport torch.nn.functional as F\n\nfrom .voxel_helpers import ray_intersect, ray_sample\nfrom torch.autograd import grad\n\n\ndef ray(ray_start, ray_dir, depths):\n    return ray_start + ray_dir * depths\n\n\ndef fill_in(shape, mask, input, initial=1.0):\n    if isinstance(initial, torch.Tensor):\n        output = initial.expand(*shape)\n    else:\n        output = input.new_ones(*shape) * initial\n    return output.masked_scatter(mask.unsqueeze(-1).expand(*shape), input)\n\n\ndef masked_scatter(mask, x):\n    B, K = mask.size()\n    if x.dim() == 1:\n        return x.new_zeros(B, K).masked_scatter(mask, x)\n    return x.new_zeros(B, K, x.size(-1)).masked_scatter(\n        mask.unsqueeze(-1).expand(B, K, x.size(-1)), x\n    )\n\n\ndef masked_scatter_ones(mask, x):\n    B, K = mask.size()\n    if x.dim() == 1:\n        return x.new_ones(B, K).masked_scatter(mask, x)\n    return x.new_ones(B, K, x.size(-1)).masked_scatter(\n        mask.unsqueeze(-1).expand(B, K, x.size(-1)), x\n    )\n\n\n@torch.enable_grad()\ndef trilinear_interp(p, q, point_feats):\n    weights = (p * q + (1 - p) * (1 - q)).prod(dim=-1, keepdim=True)\n    if point_feats.dim() == 2:\n        point_feats = point_feats.view(point_feats.size(0), 8, -1)\n\n    point_feats = (weights * point_feats).sum(1)\n    return point_feats\n\n\ndef offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2):\n    c = torch.arange(1, 2 * bits, 2, device=point_xyz.device)\n    ox, oy, oz = torch.meshgrid([c, c, c], indexing='ij')\n    offset = (torch.cat([\n        ox.reshape(-1, 1),\n        oy.reshape(-1, 1),\n        oz.reshape(-1, 1)], 1).type_as(point_xyz) - bits) / float(bits - 1)\n    if not offset_only:\n        return (\n            point_xyz.unsqueeze(1) + offset.unsqueeze(0).type_as(point_xyz) * quarter_voxel)\n    return offset.type_as(point_xyz) * quarter_voxel\n\n\n@torch.enable_grad()\ndef get_embeddings(sampled_xyz, point_xyz, 
point_feats, voxel_size):\n    # tri-linear interpolation\n    p = ((sampled_xyz - point_xyz) / voxel_size + 0.5).unsqueeze(1)\n    q = offset_points(p, 0.5, offset_only=True).unsqueeze(0) + 0.5\n    feats = trilinear_interp(p, q, point_feats).float()\n    # if self.args.local_coord:\n    # feats = torch.cat([(p-.5).squeeze(1).float(), feats], dim=-1)\n    return feats\n\n\n@torch.enable_grad()\ndef get_features(samples, map_states, voxel_size):\n    # encoder states\n    point_idx = map_states[\"voxel_vertex_idx\"].cuda()\n    point_xyz = map_states[\"voxel_center_xyz\"].cuda()\n    values = map_states[\"voxel_vertex_emb\"]\n    point_id2embedid = map_states[\"voxel_id2embedding_id\"]\n    # ray point samples\n    sampled_idx = samples[\"sampled_point_voxel_idx\"].long()\n    sampled_xyz = samples[\"sampled_point_xyz\"]\n    sampled_dis = samples[\"sampled_point_distance\"]\n\n    point_xyz = F.embedding(sampled_idx, point_xyz).requires_grad_()\n    selected_points_idx = F.embedding(sampled_idx, point_idx)\n    flatten_selected_points_idx = selected_points_idx.view(-1)\n    embed_idx = F.embedding(flatten_selected_points_idx.cpu(), point_id2embedid).squeeze(-1)\n    point_feats = F.embedding(embed_idx.cuda(), values).view(point_xyz.size(0), -1)\n\n    feats = get_embeddings(sampled_xyz, point_xyz, point_feats, voxel_size)\n    inputs = {\"xyz\": point_xyz, \"dists\": sampled_dis, \"emb\": feats.cuda()}\n    return inputs\n\n\n@torch.no_grad()\ndef get_scores(sdf_network, map_states, voxel_size, bits=8):\n    feats = map_states[\"voxel_vertex_idx\"]\n    points = map_states[\"voxel_center_xyz\"]\n    values = map_states[\"voxel_vertex_emb\"]\n    point_id2embedid = map_states[\"voxel_id2embedding_id\"]\n\n    chunk_size = 10000\n    res = bits  # -1\n\n    @torch.no_grad()\n    def get_scores_once(feats, points, values, point_id2embedid):\n        torch.cuda.empty_cache()\n        # sample points inside voxels\n        start = -0.5\n        end = 0.5  # - 
1./bits\n\n        x = y = z = torch.linspace(start, end, res)\n        # z = torch.linspace(1, 1, res)\n        xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')\n        sampled_xyz = torch.stack([xx, yy, zz], dim=-1).float().cuda()\n\n        sampled_xyz *= voxel_size\n        sampled_xyz = sampled_xyz.reshape(1, -1, 3) + points.unsqueeze(1)\n\n        sampled_idx = torch.arange(points.size(0), device=points.device)\n        sampled_idx = sampled_idx[:, None].expand(*sampled_xyz.size()[:2])\n        sampled_idx = sampled_idx.reshape(-1)\n        sampled_xyz = sampled_xyz.reshape(-1, 3)\n\n        if sampled_xyz.shape[0] == 0:\n            return\n\n        field_inputs = get_features(\n            {\n                \"sampled_point_xyz\": sampled_xyz,\n                \"sampled_point_voxel_idx\": sampled_idx,\n                \"sampled_point_ray_direction\": None,\n                \"sampled_point_distance\": None,\n            },\n            {\n                \"voxel_vertex_idx\": feats,\n                \"voxel_center_xyz\": points,\n                \"voxel_vertex_emb\": values,\n                \"voxel_id2embedding_id\": point_id2embedid\n            },\n            voxel_size\n        )\n        field_inputs = field_inputs[\"emb\"]\n\n        # evaluation with density\n        sdf_values = sdf_network.get_values(field_inputs.float().cuda())\n        return sdf_values.reshape(-1, res ** 3, 1).detach().cpu()\n\n    return torch.cat([\n        get_scores_once(feats[i: i + chunk_size],\n                        points[i: i + chunk_size].cuda(), values, point_id2embedid)\n        for i in range(0, points.size(0), chunk_size)], 0).view(-1, res, res, res, 1)\n\n\n@torch.no_grad()\ndef eval_points(sdf_network, map_states, sampled_xyz, sampled_idx, voxel_size):\n    feats = map_states[\"voxel_vertex_idx\"]\n    points = map_states[\"voxel_center_xyz\"]\n    values = map_states[\"voxel_vertex_emb\"]\n\n    # sampled_xyz = sampled_xyz.reshape(1, 3) + points.unsqueeze(1)\n    # 
sampled_idx = sampled_idx[None, :].expand(*sampled_xyz.size()[:2])\n    sampled_idx = sampled_idx.reshape(-1)\n    sampled_xyz = sampled_xyz.reshape(-1, 3)\n\n    if sampled_xyz.shape[0] == 0:\n        return\n\n    field_inputs = get_features(\n        {\n            \"sampled_point_xyz\": sampled_xyz,\n            \"sampled_point_voxel_idx\": sampled_idx,\n            \"sampled_point_ray_direction\": None,\n            \"sampled_point_distance\": None,\n        },\n        {\n            \"voxel_vertex_idx\": feats,\n            \"voxel_center_xyz\": points,\n            \"voxel_vertex_emb\": values,\n        },\n        voxel_size\n    )\n\n    # evaluation with density\n    sdf_values = sdf_network.get_values(field_inputs['emb'].float().cuda())\n    return sdf_values.reshape(-1, 4)[:, :3].detach().cpu()\n\n\ndef render_rays(\n        rays_o,\n        rays_d,\n        map_states,\n        sdf_network,\n        step_size,\n        voxel_size,\n        truncation,\n        max_voxel_hit,\n        max_distance,\n        chunk_size=10000,\n        profiler=None,\n        return_raw=False\n):\n    torch.cuda.empty_cache()\n    centres = map_states[\"voxel_center_xyz\"].cuda()\n    childrens = map_states[\"voxel_structure\"].cuda()\n\n    if profiler is not None:\n        profiler.tick(\"ray_intersect\")\n    # print(\"Center\", rays_o[0][0])\n    intersections, hits = ray_intersect(\n        rays_o, rays_d, centres,\n        childrens, voxel_size, max_voxel_hit, max_distance)\n    if profiler is not None:\n        profiler.tok(\"ray_intersect\")\n    if hits.sum() <= 0:\n        return\n\n    ray_mask = hits.view(1, -1)\n\n    intersections = {\n        name: outs[ray_mask].reshape(-1, outs.size(-1))\n        for name, outs in intersections.items()\n    }\n\n    rays_o = rays_o[ray_mask].reshape(-1, 3)\n    rays_d = rays_d[ray_mask].reshape(-1, 3)\n\n    if profiler is not None:\n        profiler.tick(\"ray_sample\")\n    samples = ray_sample(intersections, 
step_size=step_size)\n    if samples is None:\n        return\n    if profiler is not None:\n        profiler.tok(\"ray_sample\")\n\n    sampled_depth = samples['sampled_point_depth']\n    sampled_idx = samples['sampled_point_voxel_idx'].long()\n\n    # only compute when the ray hits\n\n    sample_mask = sampled_idx.ne(-1)\n    if sample_mask.sum() == 0:  # missed everything; skip\n        return None\n\n    sampled_xyz = ray(rays_o.unsqueeze(\n        1), rays_d.unsqueeze(1), sampled_depth.unsqueeze(2))\n    sampled_dir = rays_d.unsqueeze(1).expand(\n        *sampled_depth.size(), rays_d.size()[-1])\n    sampled_dir = sampled_dir / \\\n        (torch.norm(sampled_dir, 2, -1, keepdim=True) + 1e-8)\n    samples['sampled_point_xyz'] = sampled_xyz\n    samples['sampled_point_ray_direction'] = sampled_dir\n    # apply mask\n    samples_valid = {name: s[sample_mask] for name, s in samples.items()}\n    # print(\"samples_valid_xyz\", samples[\"sampled_point_xyz\"].shape)\n    num_points = samples_valid['sampled_point_depth'].shape[0]\n    field_outputs = []\n    if chunk_size < 0:\n        chunk_size = num_points\n    final_xyz = []\n    xyz = 0\n    for i in range(0, num_points, chunk_size):\n        torch.cuda.empty_cache()\n        chunk_samples = {name: s[i:i+chunk_size]\n                         for name, s in samples_valid.items()}\n\n        # get encoder features as inputs\n        if profiler is not None:\n            profiler.tick(\"get_features\")\n\n        chunk_inputs = get_features(chunk_samples, map_states, voxel_size)\n        xyz = chunk_inputs[\"xyz\"]\n        if profiler is not None:\n            profiler.tok(\"get_features\")\n        # add coordinate information\n        chunk_inputs = chunk_inputs[\"emb\"]\n\n        # forward implicit fields\n        if profiler is not None:\n            profiler.tick(\"render_core\")\n\n        chunk_outputs = sdf_network(chunk_inputs)\n        if profiler is not None:\n            
profiler.tok(\"render_core\")\n        final_xyz.append(xyz)\n        field_outputs.append(chunk_outputs)\n\n    field_outputs = {name: torch.cat(\n        [r[name] for r in field_outputs], dim=0) for name in field_outputs[0]}\n    final_xyz = torch.cat(final_xyz, 0)\n    outputs = field_outputs['sdf']\n    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)\n    sdf_grad = grad(outputs=outputs,\n                    inputs=xyz,\n                    grad_outputs=d_points,\n                    retain_graph=True,)\n\n    outputs = {'sample_mask': sample_mask}\n\n    sdf = masked_scatter_ones(sample_mask, field_outputs['sdf']).squeeze(-1)\n    # depth = masked_scatter(sample_mask, field_outputs['depth'])\n\n    # colour = torch.sigmoid(colour)\n    sample_mask = outputs['sample_mask']\n\n    valid_mask = torch.where(\n        sample_mask, torch.ones_like(\n            sample_mask), torch.zeros_like(sample_mask)\n    )\n\n    return {\n        \"z_vals\": samples[\"sampled_point_depth\"],\n        \"sdf\": sdf,\n        \"ray_mask\": ray_mask,\n        \"valid_mask\": valid_mask,\n        \"sampled_xyz\": xyz,\n    }\n\n\ndef bundle_adjust_frames(\n    keyframe_graph,\n    embeddings,\n    map_states,\n    sdf_network,\n    loss_criteria,\n    voxel_size,\n    step_size,\n    N_rays=512,\n    num_iterations=10,\n    truncation=0.1,\n    max_voxel_hit=10,\n    max_distance=10,\n    learning_rate=[1e-2, 1e-2, 5e-3],\n    update_pose=True,\n    update_decoder=True,\n    profiler=None\n):\n    if profiler is not None:\n        profiler.tick(\"mapping_add_optim\")\n    optimize_params = [{'params': embeddings, 'lr': learning_rate[0]}]\n    if update_decoder:\n        optimize_params += [{'params': sdf_network.parameters(),\n                             'lr': learning_rate[1]}]\n\n    for keyframe in keyframe_graph:\n        if keyframe.index != 0 and update_pose:\n            keyframe.pose.requires_grad_(True)\n            optimize_params += [{\n  
              'params': keyframe.pose.parameters(), 'lr': learning_rate[2]\n            }]\n\n    optim = torch.optim.Adam(optimize_params)\n    if profiler is not None:\n        profiler.tok(\"mapping_add_optim\")\n    for iter in range(num_iterations):\n        torch.cuda.empty_cache()\n        rays_o = []\n        rays_d = []\n        rgb_samples = []\n        depth_samples = []\n        points_samples = []\n        pointsCos_samples = []\n        if iter == 0 and profiler is not None:\n            profiler.tick(\"mapping sample_rays\")\n        for frame in keyframe_graph:\n            torch.cuda.empty_cache()\n            pose = frame.get_pose().cuda()\n            frame.sample_rays(N_rays)\n\n            sample_mask = frame.sample_mask.cuda()\n            sampled_rays_d = frame.rays_d[sample_mask].cuda()\n            # print(sampled_rays_d)\n            R = pose[: 3, : 3].transpose(-1, -2)\n            sampled_rays_d = sampled_rays_d@R\n            sampled_rays_o = pose[: 3, 3].reshape(1, -1).expand_as(sampled_rays_d)\n\n            rays_d += [sampled_rays_d]\n            rays_o += [sampled_rays_o]\n            points_samples += [frame.points.unsqueeze(1).cuda()[sample_mask]]\n            pointsCos_samples += [frame.pointsCos.unsqueeze(1).cuda()[sample_mask]]\n            # rgb_samples += [frame.rgb.cuda()[sample_mask]]\n            # depth_samples += [frame.depth.cuda()[sample_mask]]\n\n        rays_d = torch.cat(rays_d, dim=0).unsqueeze(0)\n        rays_o = torch.cat(rays_o, dim=0).unsqueeze(0)\n        points_samples = torch.cat(points_samples, dim=0).unsqueeze(0)\n        pointsCos_samples = torch.cat(pointsCos_samples, dim=0).unsqueeze(0)\n\n        if iter == 0 and profiler is not None:\n            profiler.tok(\"mapping sample_rays\")\n        if iter == 0 and profiler is not None:\n            profiler.tick(\"mapping rendering\")\n        final_outputs = render_rays(\n            rays_o,\n            rays_d,\n            map_states,\n            
sdf_network,\n            step_size,\n            voxel_size,\n            truncation,\n            max_voxel_hit,\n            max_distance,\n            chunk_size=-1,\n            profiler=profiler if iter == 0 else None\n        )\n        if final_outputs is None:\n            print(\"Encountered a bug while mapping (not fixed yet); continuing!\")\n            hit_mask = None\n            continue\n        if iter == 0 and profiler is not None:\n            profiler.tok(\"mapping rendering\")\n        if iter == 0 and profiler is not None:\n            profiler.tick(\"mapping back proj\")\n        torch.cuda.empty_cache()\n        loss, _ = loss_criteria(\n            final_outputs, points_samples, pointsCos_samples)\n\n        optim.zero_grad()\n        loss.backward()\n        optim.step()\n        if iter == 0 and profiler is not None:\n            profiler.tok(\"mapping back proj\")\n\n\ndef track_frame(\n    frame_pose,\n    curr_frame,\n    map_states,\n    sdf_network,\n    loss_criteria,\n    voxel_size,\n    N_rays=512,\n    step_size=0.05,\n    num_iterations=10,\n    truncation=0.1,\n    learning_rate=1e-3,\n    max_voxel_hit=10,\n    max_distance=10,\n    profiler=None,\n    depth_variance=False\n):\n    torch.cuda.empty_cache()\n    init_pose = deepcopy(frame_pose).cuda()\n    init_pose.requires_grad_(True)\n    optim = torch.optim.Adam(init_pose.parameters(),\n                             lr=learning_rate*2 if curr_frame.index < 2\n                             else learning_rate/3)\n\n    for iter in range(num_iterations):\n        torch.cuda.empty_cache()\n        if iter == 0 and profiler is not None:\n            profiler.tick(\"track sample_rays\")\n        curr_frame.sample_rays(N_rays, track=True)\n        if iter == 0 and profiler is not None:\n            profiler.tok(\"track sample_rays\")\n\n        sample_mask = curr_frame.sample_mask\n        ray_dirs = 
curr_frame.rays_d[sample_mask].unsqueeze(0).cuda()\n\n        points_samples = curr_frame.points.unsqueeze(1).cuda()[sample_mask]\n        pointsCos_samples = curr_frame.pointsCos.unsqueeze(1).cuda()[sample_mask]\n\n        ray_dirs_iter = ray_dirs.squeeze(\n            0) @ init_pose.rotation().transpose(-1, -2)\n        ray_dirs_iter = ray_dirs_iter.unsqueeze(0)\n        ray_start_iter = init_pose.translation().reshape(\n            1, 1, -1).expand_as(ray_dirs_iter).cuda().contiguous()\n\n        if iter == 0 and profiler is not None:\n            profiler.tick(\"track render_rays\")\n        final_outputs = render_rays(\n            ray_start_iter,\n            ray_dirs_iter,\n            map_states,\n            sdf_network,\n            step_size,\n            voxel_size,\n            truncation,\n            max_voxel_hit,\n            max_distance,\n            chunk_size=-2,\n            profiler=profiler if iter == 0 else None\n        )\n        if final_outputs is None:\n            print(\"Encountered a bug while tracking (not fixed yet); restarting!\")\n            hit_mask = None\n            break\n\n        torch.cuda.empty_cache()\n        if iter == 0 and profiler is not None:\n            profiler.tok(\"track render_rays\")\n\n        hit_mask = final_outputs[\"ray_mask\"].view(N_rays)\n        final_outputs[\"ray_mask\"] = hit_mask\n\n        if iter == 0 and profiler is not None:\n            profiler.tick(\"track loss_criteria\")\n        loss, _ = loss_criteria(\n            final_outputs, points_samples, pointsCos_samples, weight_depth_loss=depth_variance)\n        if iter == 0 and profiler is not None:\n            profiler.tok(\"track loss_criteria\")\n        if iter == 0 and profiler is not None:\n            profiler.tick(\"track backward step\")\n        optim.zero_grad()\n        loss.backward()\n        optim.step()\n        if iter == 0 and profiler is not None:\n            profiler.tok(\"track backward step\")\n\n    
return init_pose, hit_mask\n"
  },
  {
    "path": "src/variations/voxel_helpers.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MIT license found in the\n# LICENSE file in the root directory of this source tree.\n\n\"\"\" Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch \"\"\"\nfrom __future__ import (\n    division,\n    absolute_import,\n    with_statement,\n    print_function,\n    unicode_literals,\n)\nimport os\nimport sys\nimport torch\nimport torch.nn.functional as F\nfrom torch.autograd import Function\nimport torch.nn as nn\nimport sys\nimport numpy as np\nimport grid as _ext\n\nMAX_DEPTH = 80\n\n\nclass BallRayIntersect(Function):\n    @staticmethod\n    def forward(ctx, radius, n_max, points, ray_start, ray_dir):\n        inds, min_depth, max_depth = _ext.ball_intersect(\n            ray_start.float(), ray_dir.float(), points.float(), radius, n_max\n        )\n        min_depth = min_depth.type_as(ray_start)\n        max_depth = max_depth.type_as(ray_start)\n\n        ctx.mark_non_differentiable(inds)\n        ctx.mark_non_differentiable(min_depth)\n        ctx.mark_non_differentiable(max_depth)\n        return inds, min_depth, max_depth\n\n    @staticmethod\n    def backward(ctx, a, b, c):\n        return None, None, None, None, None\n\n\nball_ray_intersect = BallRayIntersect.apply\n\n\nclass AABBRayIntersect(Function):\n    @staticmethod\n    def forward(ctx, voxelsize, n_max, points, ray_start, ray_dir):\n        # HACK: speed-up ray-voxel intersection by batching...\n        # HACK: avoid out-of-memory\n        G = min(2048, int(2 * 10 ** 9 / points.numel()))\n        S, N = ray_start.shape[:2]\n        K = int(np.ceil(N / G))\n        H = K * G\n        if H > N:\n            ray_start = torch.cat([ray_start, ray_start[:, : H - N]], 1)\n            ray_dir = torch.cat([ray_dir, ray_dir[:, : H - N]], 1)\n        ray_start = ray_start.reshape(S * G, K, 3)\n        ray_dir = ray_dir.reshape(S * G, K, 3)\n        points = points.expand(S * G, 
*points.size()[1:]).contiguous()\n\n        inds, min_depth, max_depth = _ext.aabb_intersect(\n            ray_start.float(), ray_dir.float(), points.float(), voxelsize, n_max\n        )\n        min_depth = min_depth.type_as(ray_start)\n        max_depth = max_depth.type_as(ray_start)\n\n        inds = inds.reshape(S, H, -1)\n        min_depth = min_depth.reshape(S, H, -1)\n        max_depth = max_depth.reshape(S, H, -1)\n        if H > N:\n            inds = inds[:, :N]\n            min_depth = min_depth[:, :N]\n            max_depth = max_depth[:, :N]\n\n        ctx.mark_non_differentiable(inds)\n        ctx.mark_non_differentiable(min_depth)\n        ctx.mark_non_differentiable(max_depth)\n        return inds, min_depth, max_depth\n\n    @staticmethod\n    def backward(ctx, a, b, c):\n        return None, None, None, None, None\n\n\naabb_ray_intersect = AABBRayIntersect.apply\n\n\nclass SparseVoxelOctreeRayIntersect(Function):\n    @staticmethod\n    def forward(ctx, voxelsize, n_max, points, children, ray_start, ray_dir):\n        # HACK: avoid out-of-memory\n        torch.cuda.empty_cache()\n        G = min(256, int(2 * 10 ** 9 / (points.numel() + children.numel())))\n        S, N = ray_start.shape[:2]\n        K = int(np.ceil(N / G))\n        H = K * G\n        if H > N:\n            ray_start = torch.cat([ray_start, ray_start[:, : H - N]], 1)\n            ray_dir = torch.cat([ray_dir, ray_dir[:, : H - N]], 1)\n        ray_start = ray_start.reshape(S * G, K, 3)\n        ray_dir = ray_dir.reshape(S * G, K, 3)\n        points = points.expand(S * G, *points.size()).contiguous()\n        torch.cuda.empty_cache()\n        children = children.expand(S * G, *children.size()).contiguous()\n        torch.cuda.empty_cache()\n        inds, min_depth, max_depth = _ext.svo_intersect(\n            ray_start.float(),\n            ray_dir.float(),\n            points.float(),\n            children.int(),\n            voxelsize,\n            n_max,\n        )\n        
torch.cuda.empty_cache()\n        min_depth = min_depth.type_as(ray_start)\n        max_depth = max_depth.type_as(ray_start)\n\n        inds = inds.reshape(S, H, -1)\n        min_depth = min_depth.reshape(S, H, -1)\n        max_depth = max_depth.reshape(S, H, -1)\n        if H > N:\n            inds = inds[:, :N]\n            min_depth = min_depth[:, :N]\n            max_depth = max_depth[:, :N]\n\n        ctx.mark_non_differentiable(inds)\n        ctx.mark_non_differentiable(min_depth)\n        ctx.mark_non_differentiable(max_depth)\n        return inds, min_depth, max_depth\n\n    @staticmethod\n    def backward(ctx, a, b, c):\n        return None, None, None, None, None\n\n\nsvo_ray_intersect = SparseVoxelOctreeRayIntersect.apply\n\n\nclass TriangleRayIntersect(Function):\n    @staticmethod\n    def forward(ctx, cagesize, blur_ratio, n_max, points, faces, ray_start, ray_dir):\n        # HACK: speed-up ray-voxel intersection by batching...\n        # HACK: avoid out-of-memory\n        G = min(2048, int(2 * 10 ** 9 / (3 * faces.numel())))\n        S, N = ray_start.shape[:2]\n        K = int(np.ceil(N / G))\n        H = K * G\n        if H > N:\n            ray_start = torch.cat([ray_start, ray_start[:, : H - N]], 1)\n            ray_dir = torch.cat([ray_dir, ray_dir[:, : H - N]], 1)\n        ray_start = ray_start.reshape(S * G, K, 3)\n        ray_dir = ray_dir.reshape(S * G, K, 3)\n        face_points = F.embedding(faces.reshape(-1, 3), points.reshape(-1, 3))\n        face_points = (\n            face_points.unsqueeze(0).expand(\n                S * G, *face_points.size()).contiguous()\n        )\n        inds, depth, uv = _ext.triangle_intersect(\n            ray_start.float(),\n            ray_dir.float(),\n            face_points.float(),\n            cagesize,\n            blur_ratio,\n            n_max,\n        )\n        depth = depth.type_as(ray_start)\n        uv = uv.type_as(ray_start)\n\n        inds = inds.reshape(S, H, -1)\n        depth = 
depth.reshape(S, H, -1, 3)\n        uv = uv.reshape(S, H, -1)\n        if H > N:\n            inds = inds[:, :N]\n            depth = depth[:, :N]\n            uv = uv[:, :N]\n\n        ctx.mark_non_differentiable(inds)\n        ctx.mark_non_differentiable(depth)\n        ctx.mark_non_differentiable(uv)\n        return inds, depth, uv\n\n    @staticmethod\n    def backward(ctx, a, b, c):\n        return None, None, None, None, None, None\n\n\ntriangle_ray_intersect = TriangleRayIntersect.apply\n\n\nclass UniformRaySampling(Function):\n    @staticmethod\n    def forward(\n        ctx,\n        pts_idx,\n        min_depth,\n        max_depth,\n        step_size,\n        max_ray_length,\n        deterministic=False,\n    ):\n        G, N, P = 256, pts_idx.size(0), pts_idx.size(1)\n        H = int(np.ceil(N / G)) * G\n        if H > N:\n            pts_idx = torch.cat([pts_idx, pts_idx[: H - N]], 0)\n            min_depth = torch.cat([min_depth, min_depth[: H - N]], 0)\n            max_depth = torch.cat([max_depth, max_depth[: H - N]], 0)\n        pts_idx = pts_idx.reshape(G, -1, P)\n        min_depth = min_depth.reshape(G, -1, P)\n        max_depth = max_depth.reshape(G, -1, P)\n\n        # pre-generate noise\n        max_steps = int(max_ray_length / step_size)\n        max_steps = max_steps + min_depth.size(-1) * 2\n        noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps)\n        if deterministic:\n            noise += 0.5\n        else:\n            noise = noise.uniform_()\n\n        # call cuda function\n        sampled_idx, sampled_depth, sampled_dists = _ext.uniform_ray_sampling(\n            pts_idx,\n            min_depth.float(),\n            max_depth.float(),\n            noise.float(),\n            step_size,\n            max_steps,\n        )\n        sampled_depth = sampled_depth.type_as(min_depth)\n        sampled_dists = sampled_dists.type_as(min_depth)\n\n        sampled_idx = sampled_idx.reshape(H, -1)\n        sampled_depth = 
sampled_depth.reshape(H, -1)\n        sampled_dists = sampled_dists.reshape(H, -1)\n        if H > N:\n            sampled_idx = sampled_idx[:N]\n            sampled_depth = sampled_depth[:N]\n            sampled_dists = sampled_dists[:N]\n\n        max_len = sampled_idx.ne(-1).sum(-1).max()\n        sampled_idx = sampled_idx[:, :max_len]\n        sampled_depth = sampled_depth[:, :max_len]\n        sampled_dists = sampled_dists[:, :max_len]\n\n        ctx.mark_non_differentiable(sampled_idx)\n        ctx.mark_non_differentiable(sampled_depth)\n        ctx.mark_non_differentiable(sampled_dists)\n        return sampled_idx, sampled_depth, sampled_dists\n\n    @staticmethod\n    def backward(ctx, a, b, c):\n        return None, None, None, None, None, None\n\n\nuniform_ray_sampling = UniformRaySampling.apply\n\n\nclass InverseCDFRaySampling(Function):\n    @staticmethod\n    def forward(\n        ctx,\n        pts_idx,\n        min_depth,\n        max_depth,\n        probs,\n        steps,\n        fixed_step_size=-1,\n        deterministic=False,\n    ):\n        G, N, P = 200, pts_idx.size(0), pts_idx.size(1)\n        H = int(np.ceil(N / G)) * G\n\n        if H > N:\n            pts_idx = torch.cat([pts_idx, pts_idx[:1].expand(H - N, P)], 0)\n            min_depth = torch.cat(\n                [min_depth, min_depth[:1].expand(H - N, P)], 0)\n            max_depth = torch.cat(\n                [max_depth, max_depth[:1].expand(H - N, P)], 0)\n            probs = torch.cat([probs, probs[:1].expand(H - N, P)], 0)\n            steps = torch.cat([steps, steps[:1].expand(H - N)], 0)\n\n        # print(G, P, np.ceil(N / G), N, H, pts_idx.shape, min_depth.device)\n        pts_idx = pts_idx.reshape(G, -1, P)\n        min_depth = min_depth.reshape(G, -1, P)\n        max_depth = max_depth.reshape(G, -1, P)\n        probs = probs.reshape(G, -1, P)\n        steps = steps.reshape(G, -1)\n\n        # pre-generate noise\n        max_steps = steps.ceil().long().max() + P\n        # 
print(max_steps)\n        # print(*min_depth.size()[:-1],\" \", max_steps)\n        noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps)\n        if deterministic:\n            noise += 0.5\n        else:\n            noise = noise.uniform_().clamp(min=0.001, max=0.999)  # in case\n\n        # call cuda function\n        chunk_size = 4 * G  # to avoid oom?\n        results = [\n            _ext.inverse_cdf_sampling(\n                pts_idx[:, i: i + chunk_size].contiguous(),\n                min_depth.float()[:, i: i + chunk_size].contiguous(),\n                max_depth.float()[:, i: i + chunk_size].contiguous(),\n                noise.float()[:, i: i + chunk_size].contiguous(),\n                probs.float()[:, i: i + chunk_size].contiguous(),\n                steps.float()[:, i: i + chunk_size].contiguous(),\n                fixed_step_size,\n            )\n            for i in range(0, min_depth.size(1), chunk_size)\n        ]\n\n        sampled_idx, sampled_depth, sampled_dists = [\n            torch.cat([r[i] for r in results], 1) for i in range(3)\n        ]\n        sampled_depth = sampled_depth.type_as(min_depth)\n        sampled_dists = sampled_dists.type_as(min_depth)\n\n        sampled_idx = sampled_idx.reshape(H, -1)\n        sampled_depth = sampled_depth.reshape(H, -1)\n        sampled_dists = sampled_dists.reshape(H, -1)\n        if H > N:\n            sampled_idx = sampled_idx[:N]\n            sampled_depth = sampled_depth[:N]\n            sampled_dists = sampled_dists[:N]\n\n        max_len = sampled_idx.ne(-1).sum(-1).max()\n        sampled_idx = sampled_idx[:, :max_len]\n        sampled_depth = sampled_depth[:, :max_len]\n        sampled_dists = sampled_dists[:, :max_len]\n\n        ctx.mark_non_differentiable(sampled_idx)\n        ctx.mark_non_differentiable(sampled_depth)\n        ctx.mark_non_differentiable(sampled_dists)\n        return sampled_idx, sampled_depth, sampled_dists\n\n    @staticmethod\n    def backward(ctx, a, b, 
c):\n        return None, None, None, None, None, None, None\n\n\ninverse_cdf_sampling = InverseCDFRaySampling.apply\n\n\n# back-up for ray point sampling\n@torch.no_grad()\ndef _parallel_ray_sampling(\n    MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False\n):\n    # uniform sampling\n    _min_depth = min_depth.min(1)[0]\n    _max_depth = max_depth.masked_fill(max_depth.eq(MAX_DEPTH), 0).max(1)[0]\n    max_ray_length = (_max_depth - _min_depth).max()\n\n    delta = torch.arange(\n        int(max_ray_length / MARCH_SIZE), device=min_depth.device, dtype=min_depth.dtype\n    )\n    delta = delta[None, :].expand(min_depth.size(0), delta.size(-1))\n    if deterministic:\n        delta = delta + 0.5\n    else:\n        delta = delta + delta.clone().uniform_().clamp(min=0.01, max=0.99)\n    delta = delta * MARCH_SIZE\n    sampled_depth = min_depth[:, :1] + delta\n    sampled_idx = (sampled_depth[:, :, None] >=\n                   min_depth[:, None, :]).sum(-1) - 1\n    sampled_idx = pts_idx.gather(1, sampled_idx)\n\n    # include all boundary points\n    sampled_depth = torch.cat([min_depth, max_depth, sampled_depth], -1)\n    sampled_idx = torch.cat([pts_idx, pts_idx, sampled_idx], -1)\n\n    # reorder\n    sampled_depth, ordered_index = sampled_depth.sort(-1)\n    sampled_idx = sampled_idx.gather(1, ordered_index)\n    sampled_dists = sampled_depth[:, 1:] - sampled_depth[:, :-1]  # distances\n    sampled_depth = 0.5 * \\\n        (sampled_depth[:, 1:] + sampled_depth[:, :-1])  # mid-points\n\n    # remove all invalid depths\n    min_ids = (sampled_depth[:, :, None] >= min_depth[:, None, :]).sum(-1) - 1\n    max_ids = (sampled_depth[:, :, None] >= max_depth[:, None, :]).sum(-1)\n\n    sampled_depth.masked_fill_(\n        (max_ids.ne(min_ids))\n        | (sampled_depth > _max_depth[:, None])\n        | (sampled_dists == 0.0),\n        MAX_DEPTH,\n    )\n    sampled_depth, ordered_index = sampled_depth.sort(-1)  # sort again\n    sampled_masks = 
sampled_depth.eq(MAX_DEPTH)\n    num_max_steps = (~sampled_masks).sum(-1).max()\n\n    sampled_depth = sampled_depth[:, :num_max_steps]\n    sampled_dists = sampled_dists.gather(1, ordered_index).masked_fill_(\n        sampled_masks, 0.0\n    )[:, :num_max_steps]\n    sampled_idx = sampled_idx.gather(1, ordered_index).masked_fill_(sampled_masks, -1)[\n        :, :num_max_steps\n    ]\n\n    return sampled_idx, sampled_depth, sampled_dists\n\n\n@torch.no_grad()\ndef parallel_ray_sampling(\n    MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False\n):\n    chunk_size = 4096\n    full_size = min_depth.shape[0]\n    if full_size <= chunk_size:\n        return _parallel_ray_sampling(\n            MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=deterministic\n        )\n\n    outputs = zip(\n        *[\n            _parallel_ray_sampling(\n                MARCH_SIZE,\n                pts_idx[i: i + chunk_size],\n                min_depth[i: i + chunk_size],\n                max_depth[i: i + chunk_size],\n                deterministic=deterministic,\n            )\n            for i in range(0, full_size, chunk_size)\n        ]\n    )\n    sampled_idx, sampled_depth, sampled_dists = outputs\n\n    def padding_points(xs, pad):\n        if len(xs) == 1:\n            return xs[0]\n\n        maxlen = max([x.size(1) for x in xs])\n        full_size = sum([x.size(0) for x in xs])\n        xt = xs[0].new_ones(full_size, maxlen).fill_(pad)\n        st = 0\n        for i in range(len(xs)):\n            xt[st: st + xs[i].size(0), : xs[i].size(1)] = xs[i]\n            st += xs[i].size(0)\n        return xt\n\n    sampled_idx = padding_points(sampled_idx, -1)\n    sampled_depth = padding_points(sampled_depth, MAX_DEPTH)\n    sampled_dists = padding_points(sampled_dists, 0.0)\n    return sampled_idx, sampled_depth, sampled_dists\n\n\ndef discretize_points(voxel_points, voxel_size):\n    # this function turns voxel centers/corners into integer indices\n    # we 
assume all points are already put as voxels (real numbers)\n    minimal_voxel_point = voxel_points.min(dim=0, keepdim=True)[0]\n    voxel_indices = (\n        ((voxel_points - minimal_voxel_point) / voxel_size).round_().long()\n    )\n    residual = (voxel_points - voxel_indices.type_as(voxel_points) * voxel_size).mean(\n        0, keepdim=True\n    )\n    return voxel_indices, residual\n\n\ndef build_easy_octree(points, half_voxel):\n    coords, residual = discretize_points(points, half_voxel)\n    ranges = coords.max(0)[0] - coords.min(0)[0]\n    depths = torch.log2(ranges.max().float()).ceil_().long() - 1\n    center = (coords.max(0)[0] + coords.min(0)[0]) / 2\n    centers, children = _ext.build_octree(center, coords, int(depths))\n    centers = centers.float() * half_voxel + residual  # transform back to float\n    return centers, children\n\n\n@torch.enable_grad()\ndef trilinear_interp(p, q, point_feats):\n    weights = (p * q + (1 - p) * (1 - q)).prod(dim=-1, keepdim=True)\n    if point_feats.dim() == 2:\n        point_feats = point_feats.view(point_feats.size(0), 8, -1)\n\n    point_feats = (weights * point_feats).sum(1)\n    return point_feats\n\n\ndef offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2):\n    c = torch.arange(1, 2 * bits, 2, device=point_xyz.device)\n    ox, oy, oz = torch.meshgrid([c, c, c], indexing='ij')\n    offset = (torch.cat([ox.reshape(-1, 1),\n                         oy.reshape(-1, 1),\n                         oz.reshape(-1, 1)], 1).type_as(point_xyz) - bits) / float(bits - 1)\n    if not offset_only:\n        return (\n            point_xyz.unsqueeze(1) + offset.unsqueeze(0).type_as(point_xyz) * quarter_voxel)\n    return offset.type_as(point_xyz) * quarter_voxel\n\n\ndef splitting_points(point_xyz, point_feats, values, half_voxel):\n    # generate new centers\n    quarter_voxel = half_voxel * 0.5\n    new_points = offset_points(point_xyz, quarter_voxel).reshape(-1, 3)\n    old_coords = 
discretize_points(point_xyz, quarter_voxel)[0]\n    new_coords = offset_points(old_coords).reshape(-1, 3)\n    new_keys0 = offset_points(new_coords).reshape(-1, 3)\n\n    # get unique keys and inverse indices (for original key0, where it maps to in keys)\n    new_keys, new_feats = torch.unique(\n        new_keys0, dim=0, sorted=True, return_inverse=True)\n    new_keys_idx = new_feats.new_zeros(new_keys.size(0)).scatter_(\n        0, new_feats, torch.arange(new_keys0.size(0), device=new_feats.device) // 64)\n\n    # recompute key vectors using trilinear interpolation\n    new_feats = new_feats.reshape(-1, 8)\n\n    if values is not None:\n        # (1/4 voxel size)\n        p = (new_keys - old_coords[new_keys_idx]\n             ).type_as(point_xyz).unsqueeze(1) * 0.25 + 0.5\n        q = offset_points(p, 0.5, offset_only=True).unsqueeze(0) + 0.5  # BUG?\n        point_feats = point_feats[new_keys_idx]\n        point_feats = F.embedding(point_feats, values).view(\n            point_feats.size(0), -1)\n        new_values = trilinear_interp(p, q, point_feats)\n    else:\n        new_values = None\n    return new_points, new_feats, new_values, new_keys\n\n\n@torch.no_grad()\ndef ray_intersect(ray_start, ray_dir, flatten_centers, flatten_children, voxel_size, max_hits, max_distance=MAX_DEPTH):\n    # ray-voxel intersection\n    max_hits_temp = 20\n    pts_idx, min_depth, max_depth = svo_ray_intersect(\n        voxel_size,\n        max_hits_temp,\n        flatten_centers,\n        flatten_children,\n        ray_start,\n        ray_dir)\n    torch.cuda.empty_cache()\n    # sort the depths\n    min_depth.masked_fill_(pts_idx.eq(-1), max_distance)\n    max_depth.masked_fill_(pts_idx.eq(-1), max_distance)\n\n    min_depth, sorted_idx = min_depth.sort(dim=-1)\n    max_depth = max_depth.gather(-1, sorted_idx)\n    pts_idx = pts_idx.gather(-1, sorted_idx)\n    # print(max_depth.max())\n    pts_idx[max_depth > 2*max_distance] = -1\n    pts_idx[min_depth > max_distance] = -1\n    
min_depth.masked_fill_(pts_idx.eq(-1), max_distance)\n    max_depth.masked_fill_(pts_idx.eq(-1), max_distance)\n    # remove all points that completely miss the object\n    max_hits = torch.max(pts_idx.ne(-1).sum(-1))\n    min_depth = min_depth[..., :max_hits]\n    max_depth = max_depth[..., :max_hits]\n    pts_idx = pts_idx[..., :max_hits]\n\n    hits = pts_idx.ne(-1).any(-1)\n\n    intersection_outputs = {\n        \"min_depth\": min_depth,\n        \"max_depth\": max_depth,\n        \"intersected_voxel_idx\": pts_idx,\n    }\n    return intersection_outputs, hits\n\n\n@torch.no_grad()\ndef ray_sample(intersection_outputs, step_size=0.01, fixed=False):\n    dists = (\n        intersection_outputs[\"max_depth\"] -\n        intersection_outputs[\"min_depth\"]\n    ).masked_fill(intersection_outputs[\"intersected_voxel_idx\"].eq(-1), 0)\n    intersection_outputs[\"probs\"] = dists / dists.sum(dim=-1, keepdim=True)\n    intersection_outputs[\"steps\"] = dists.sum(-1) / step_size\n    # TODO:A serious BUG need to fix!\n    if dists.sum(-1).max() > 10 * MAX_DEPTH:\n        return\n    # sample points and use middle point approximation\n    sampled_idx, sampled_depth, sampled_dists = inverse_cdf_sampling(\n        intersection_outputs[\"intersected_voxel_idx\"],\n        intersection_outputs[\"min_depth\"],\n        intersection_outputs[\"max_depth\"],\n        intersection_outputs[\"probs\"],\n        intersection_outputs[\"steps\"], -1, fixed)\n\n    sampled_dists = sampled_dists.clamp(min=0.0)\n    sampled_depth.masked_fill_(sampled_idx.eq(-1), MAX_DEPTH)\n    sampled_dists.masked_fill_(sampled_idx.eq(-1), 0.0)\n\n    samples = {\n        \"sampled_point_depth\": sampled_depth,\n        \"sampled_point_distance\": sampled_dists,\n        \"sampled_point_voxel_idx\": sampled_idx,\n    }\n    return samples\n"
  },
  {
    "path": "third_party/marching_cubes/setup.py",
    "content": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\nimport glob\n\n_ext_sources = glob.glob(\"src/*.cpp\") + glob.glob(\"src/*.cu\")\n\nsetup(\n    name='marching_cubes',\n    ext_modules=[\n        CUDAExtension(\n            name='marching_cubes',\n            sources=_ext_sources,\n            extra_compile_args={\n                \"cxx\": [\"-O2\", \"-I./include\"],\n                \"nvcc\": [\"-I./include\"]\n            },\n        )\n    ],\n    cmdclass={\n        'build_ext': BuildExtension\n    }\n)"
  },
  {
    "path": "third_party/marching_cubes/src/mc.cpp",
    "content": "#include <torch/extension.h>\n\nstd::vector<torch::Tensor> marching_cubes_sparse(\n    torch::Tensor indexer,           // (nx, ny, nz) -> data_id\n    torch::Tensor valid_blocks,      // (K, )\n    torch::Tensor vec_batch_mapping, //\n    torch::Tensor cube_sdf,          // (M, rx, ry, rz)\n    torch::Tensor cube_std,          // (M, rx, ry, rz)\n    const std::vector<int> &n_xyz,   // [nx, ny, nz]\n    float max_std,                   // Prune all vertices\n    int max_n_triangles              // Maximum number of triangle buffer.\n);\n\nstd::vector<torch::Tensor> marching_cubes_sparse_colour(\n    torch::Tensor indexer,           // (nx, ny, nz) -> data_id\n    torch::Tensor valid_blocks,      // (K, )\n    torch::Tensor vec_batch_mapping, //\n    torch::Tensor cube_sdf,          // (M, rx, ry, rz, 4)\n    torch::Tensor cube_colour,       // (M, rx, ry, rz)\n    const std::vector<int> &n_xyz,   // [nx, ny, nz]\n    float max_std,                   // Prune all vertices\n    int max_n_triangles              // Maximum number of triangle buffer.\n);\n\nstd::vector<torch::Tensor> marching_cubes_sparse_interp_cuda(\n    torch::Tensor indexer,           // (nx, ny, nz) -> data_id\n    torch::Tensor valid_blocks,      // (K, )\n    torch::Tensor vec_batch_mapping, //\n    torch::Tensor cube_sdf,          // (M, rx, ry, rz)\n    torch::Tensor cube_std,          // (M, rx, ry, rz)\n    const std::vector<int> &n_xyz,   // [nx, ny, nz]\n    float max_std,                   // Prune all vertices\n    int max_n_triangles              // Maximum number of triangle buffer.\n);\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m)\n{\n    m.def(\"marching_cubes_sparse\", &marching_cubes_sparse, \"Marching Cubes without Interpolation (CUDA)\");\n    m.def(\"marching_cubes_sparse_colour\", &marching_cubes_sparse_colour, \"Marching Cubes without Interpolation (CUDA)\");\n    m.def(\"marching_cubes_sparse_interp\", &marching_cubes_sparse_interp_cuda, \"Marching Cubes with 
Interpolation (CUDA)\");\n}"
  },
  {
    "path": "third_party/marching_cubes/src/mc_data.cuh",
    "content": "#include <torch/extension.h>\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <thrust/device_vector.h>\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\n#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)\n\nusing IndexerAccessor = torch::PackedTensorAccessor32<int64_t, 3, torch::RestrictPtrTraits>;\nusing ValidBlocksAccessor = torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits>;\nusing BackwardMappingAccessor = torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits>;\nusing CubeSDFAccessor = torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits>;\nusing CubeSDFRGBAccessor = torch::PackedTensorAccessor32<float, 5, torch::RestrictPtrTraits>;\nusing TrianglesAccessor = torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits>;\nusing TriangleStdAccessor = torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits>;\nusing TriangleVecIdAccessor = torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits>;\n\n\n__inline__ __device__ float4 make_float4(const float3& xyz, float w) {\n    return make_float4(xyz.x, xyz.y, xyz.z, w);\n}\n\ninline __host__ __device__ void operator+=(float2 &a, const float2& b) {\n    a.x += b.x; a.y += b.y;\n}\n\ninline __host__ __device__ float2 operator*(const float2& a, float b) {\n    return make_float2(a.x * b, a.y * b);\n}\n\ninline __host__ __device__ float2 operator/(const float2& a, float b) {\n    return make_float2(a.x / b, a.y / b);\n}\n\ninline __host__ __device__ float2 operator/(const float2& a, const float2& b) {\n    return make_float2(a.x / b.x, a.y / b.y);\n}\n\n__constant__ int edgeTable[256] = { 0x0, 0x109, 0x203, 0x30a, 0x406, 0x50f, 0x605, 0x70c, 0x80c, 0x905, 0xa0f, 0xb06, 0xc0a, 0xd03, 0xe09, 0xf00,\n\t0x190, 0x99, 0x393, 0x29a, 0x596, 0x49f, 0x795, 0x69c, 0x99c, 0x895, 0xb9f, 0xa96, 0xd9a, 0xc93, 0xf99, 0xe90, 
0x230, 0x339, 0x33, 0x13a,\n\t0x636, 0x73f, 0x435, 0x53c, 0xa3c, 0xb35, 0x83f, 0x936, 0xe3a, 0xf33, 0xc39, 0xd30, 0x3a0, 0x2a9, 0x1a3, 0xaa, 0x7a6, 0x6af, 0x5a5, 0x4ac,\n\t0xbac, 0xaa5, 0x9af, 0x8a6, 0xfaa, 0xea3, 0xda9, 0xca0, 0x460, 0x569, 0x663, 0x76a, 0x66, 0x16f, 0x265, 0x36c, 0xc6c, 0xd65, 0xe6f, 0xf66,\n\t0x86a, 0x963, 0xa69, 0xb60, 0x5f0, 0x4f9, 0x7f3, 0x6fa, 0x1f6, 0xff, 0x3f5, 0x2fc, 0xdfc, 0xcf5, 0xfff, 0xef6, 0x9fa, 0x8f3, 0xbf9, 0xaf0,\n\t0x650, 0x759, 0x453, 0x55a, 0x256, 0x35f, 0x55, 0x15c, 0xe5c, 0xf55, 0xc5f, 0xd56, 0xa5a, 0xb53, 0x859, 0x950, 0x7c0, 0x6c9, 0x5c3, 0x4ca,\n\t0x3c6, 0x2cf, 0x1c5, 0xcc, 0xfcc, 0xec5, 0xdcf, 0xcc6, 0xbca, 0xac3, 0x9c9, 0x8c0, 0x8c0, 0x9c9, 0xac3, 0xbca, 0xcc6, 0xdcf, 0xec5, 0xfcc,\n\t0xcc, 0x1c5, 0x2cf, 0x3c6, 0x4ca, 0x5c3, 0x6c9, 0x7c0, 0x950, 0x859, 0xb53, 0xa5a, 0xd56, 0xc5f, 0xf55, 0xe5c, 0x15c, 0x55, 0x35f, 0x256,\n\t0x55a, 0x453, 0x759, 0x650, 0xaf0, 0xbf9, 0x8f3, 0x9fa, 0xef6, 0xfff, 0xcf5, 0xdfc, 0x2fc, 0x3f5, 0xff, 0x1f6, 0x6fa, 0x7f3, 0x4f9, 0x5f0,\n\t0xb60, 0xa69, 0x963, 0x86a, 0xf66, 0xe6f, 0xd65, 0xc6c, 0x36c, 0x265, 0x16f, 0x66, 0x76a, 0x663, 0x569, 0x460, 0xca0, 0xda9, 0xea3, 0xfaa,\n\t0x8a6, 0x9af, 0xaa5, 0xbac, 0x4ac, 0x5a5, 0x6af, 0x7a6, 0xaa, 0x1a3, 0x2a9, 0x3a0, 0xd30, 0xc39, 0xf33, 0xe3a, 0x936, 0x83f, 0xb35, 0xa3c,\n\t0x53c, 0x435, 0x73f, 0x636, 0x13a, 0x33, 0x339, 0x230, 0xe90, 0xf99, 0xc93, 0xd9a, 0xa96, 0xb9f, 0x895, 0x99c, 0x69c, 0x795, 0x49f, 0x596,\n\t0x29a, 0x393, 0x99, 0x190, 0xf00, 0xe09, 0xd03, 0xc0a, 0xb06, 0xa0f, 0x905, 0x80c, 0x70c, 0x605, 0x50f, 0x406, 0x30a, 0x203, 0x109, 0x0 };\n\n__constant__ int triangleTable[256][16] = { { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 1, 8, 3, 9, 8, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 8, 3, 1, 2, 10, -1, -1, -1, 
-1, -1, -1, -1, -1, -1, -1 }, { 9, 2, 10, 0, 2, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 2, 8, 3, 2, 10, 8, 10, 9, 8, -1, -1, -1, -1, -1, -1, -1 }, { 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 11, 2, 8, 11, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 9, 0, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 1, 11, 2, 1, 9, 11, 9, 8, 11, -1, -1, -1, -1, -1, -1, -1 }, { 3, 10, 1, 11, 10, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 10, 1, 0, 8, 10, 8, 11, 10, -1, -1, -1, -1, -1, -1, -1 }, { 3, 9, 0, 3, 11, 9, 11, 10, 9, -1, -1, -1, -1, -1, -1, -1 },\n{ 9, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 4, 3, 0, 7, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 4, 1, 9, 4, 7, 1, 7, 3, 1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 3, 4, 7, 3, 0, 4, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1 }, { 9, 2, 10, 9, 0, 2, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1 },\n{ 2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4, -1, -1, -1, -1 }, { 8, 4, 7, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 11, 4, 7, 11, 2, 4, 2, 0, 4, -1, -1, -1, -1, -1, -1, -1 }, { 9, 0, 1, 8, 4, 7, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1 },\n{ 4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1, -1, -1, -1, -1 }, { 3, 10, 1, 3, 11, 10, 7, 8, 4, -1, -1, -1, -1, -1, -1, -1 },\n{ 1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4, -1, -1, -1, -1 }, { 4, 7, 8, 9, 0, 11, 9, 11, 10, 11, 0, 3, -1, -1, -1, -1 },\n{ 4, 7, 11, 4, 11, 9, 9, 11, 10, -1, -1, -1, -1, -1, -1, -1 }, { 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 9, 5, 4, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 5, 4, 1, 5, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 8, 5, 4, 8, 3, 5, 3, 1, 5, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 3, 0, 8, 1, 2, 10, 4, 9, 5, 
-1, -1, -1, -1, -1, -1, -1 }, { 5, 2, 10, 5, 4, 2, 4, 0, 2, -1, -1, -1, -1, -1, -1, -1 },\n{ 2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8, -1, -1, -1, -1 }, { 9, 5, 4, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 11, 2, 0, 8, 11, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1 }, { 0, 5, 4, 0, 1, 5, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1 },\n{ 2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5, -1, -1, -1, -1 }, { 10, 3, 11, 10, 1, 3, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1 },\n{ 4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10, -1, -1, -1, -1 }, { 5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3, -1, -1, -1, -1 },\n{ 5, 4, 8, 5, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1 }, { 9, 7, 8, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 9, 3, 0, 9, 5, 3, 5, 7, 3, -1, -1, -1, -1, -1, -1, -1 }, { 0, 7, 8, 0, 1, 7, 1, 5, 7, -1, -1, -1, -1, -1, -1, -1 },\n{ 1, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 7, 8, 9, 5, 7, 10, 1, 2, -1, -1, -1, -1, -1, -1, -1 },\n{ 10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3, -1, -1, -1, -1 }, { 8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2, -1, -1, -1, -1 },\n{ 2, 10, 5, 2, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1 }, { 7, 9, 5, 7, 8, 9, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1 },\n{ 9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11, -1, -1, -1, -1 }, { 2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7, -1, -1, -1, -1 },\n{ 11, 2, 1, 11, 1, 7, 7, 1, 5, -1, -1, -1, -1, -1, -1, -1 }, { 9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11, -1, -1, -1, -1 },\n{ 5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0, -1 }, { 11, 10, 0, 11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0, -1 },\n{ 11, 10, 5, 7, 11, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 8, 3, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 0, 1, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 1, 8, 3, 1, 9, 8, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1 }, { 1, 6, 5, 2, 6, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 1, 6, 5, 1, 2, 6, 3, 0, 8, -1, -1, -1, -1, -1, -1, -1 }, { 9, 6, 5, 9, 0, 6, 0, 2, 6, 
-1, -1, -1, -1, -1, -1, -1 },\n{ 5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8, -1, -1, -1, -1 }, { 2, 3, 11, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 11, 0, 8, 11, 2, 0, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1 },\n{ 5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11, -1, -1, -1, -1 }, { 6, 3, 11, 6, 5, 3, 5, 1, 3, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6, -1, -1, -1, -1 }, { 3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9, -1, -1, -1, -1 },\n{ 6, 5, 9, 6, 9, 11, 11, 9, 8, -1, -1, -1, -1, -1, -1, -1 }, { 5, 10, 6, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 4, 3, 0, 4, 7, 3, 6, 5, 10, -1, -1, -1, -1, -1, -1, -1 }, { 1, 9, 0, 5, 10, 6, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1 },\n{ 10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4, -1, -1, -1, -1 }, { 6, 1, 2, 6, 5, 1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1 },\n{ 1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7, -1, -1, -1, -1 }, { 8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6, -1, -1, -1, -1 },\n{ 7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9, -1 }, { 3, 11, 2, 7, 8, 4, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1 },\n{ 5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11, -1, -1, -1, -1 }, { 0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1 },\n{ 9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6, -1 }, { 8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6, -1, -1, -1, -1 },\n{ 5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11, -1 }, { 0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7, -1 },\n{ 6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9, -1, -1, -1, -1 }, { 10, 4, 9, 6, 4, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 4, 10, 6, 4, 9, 10, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1 }, { 10, 0, 1, 10, 6, 0, 6, 4, 0, -1, -1, -1, -1, -1, -1, -1 },\n{ 8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10, -1, -1, -1, -1 }, { 1, 4, 9, 1, 2, 4, 2, 6, 4, -1, -1, -1, -1, -1, -1, -1 },\n{ 3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4, -1, -1, -1, -1 }, { 0, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 8, 3, 2, 8, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1 }, { 10, 4, 
9, 10, 6, 4, 11, 2, 3, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6, -1, -1, -1, -1 }, { 3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10, -1, -1, -1, -1 },\n{ 6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1, -1 }, { 9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3, -1, -1, -1, -1 },\n{ 8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1, -1 }, { 3, 11, 6, 3, 6, 0, 0, 6, 4, -1, -1, -1, -1, -1, -1, -1 },\n{ 6, 4, 8, 11, 6, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 7, 10, 6, 7, 8, 10, 8, 9, 10, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10, -1, -1, -1, -1 }, { 10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0, -1, -1, -1, -1 },\n{ 10, 6, 7, 10, 7, 1, 1, 7, 3, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7, -1, -1, -1, -1 },\n{ 2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9, -1 }, { 7, 8, 0, 7, 0, 6, 6, 0, 2, -1, -1, -1, -1, -1, -1, -1 },\n{ 7, 3, 2, 6, 7, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7, -1, -1, -1, -1 },\n{ 2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7, -1 }, { 1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11, -1 },\n{ 11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1, -1, -1, -1, -1 }, { 8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6, -1 },\n{ 0, 9, 1, 11, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0, -1, -1, -1, -1 },\n{ 7, 11, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 3, 0, 8, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 8, 1, 9, 8, 3, 1, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1 }, { 10, 1, 2, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 1, 2, 10, 3, 0, 8, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1 }, { 2, 9, 0, 2, 10, 9, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1 },\n{ 6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8, -1, -1, -1, -1 }, { 7, 2, 3, 6, 2, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 7, 0, 8, 7, 6, 0, 
6, 2, 0, -1, -1, -1, -1, -1, -1, -1 }, { 2, 7, 6, 2, 3, 7, 0, 1, 9, -1, -1, -1, -1, -1, -1, -1 },\n{ 1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6, -1, -1, -1, -1 }, { 10, 7, 6, 10, 1, 7, 1, 3, 7, -1, -1, -1, -1, -1, -1, -1 },\n{ 10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8, -1, -1, -1, -1 }, { 0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7, -1, -1, -1, -1 },\n{ 7, 6, 10, 7, 10, 8, 8, 10, 9, -1, -1, -1, -1, -1, -1, -1 }, { 6, 8, 4, 11, 8, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 3, 6, 11, 3, 0, 6, 0, 4, 6, -1, -1, -1, -1, -1, -1, -1 }, { 8, 6, 11, 8, 4, 6, 9, 0, 1, -1, -1, -1, -1, -1, -1, -1 },\n{ 9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6, -1, -1, -1, -1 }, { 6, 8, 4, 6, 11, 8, 2, 10, 1, -1, -1, -1, -1, -1, -1, -1 },\n{ 1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6, -1, -1, -1, -1 }, { 4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9, -1, -1, -1, -1 },\n{ 10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3, -1 }, { 8, 2, 3, 8, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8, -1, -1, -1, -1 },\n{ 1, 9, 4, 1, 4, 2, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1 }, { 8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1, -1, -1, -1, -1 },\n{ 10, 1, 0, 10, 0, 6, 6, 0, 4, -1, -1, -1, -1, -1, -1, -1 }, { 4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3, -1 },\n{ 10, 9, 4, 6, 10, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 9, 5, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 8, 3, 4, 9, 5, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1 }, { 5, 0, 1, 5, 4, 0, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1 },\n{ 11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5, -1, -1, -1, -1 }, { 9, 5, 4, 10, 1, 2, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1 },\n{ 6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5, -1, -1, -1, -1 }, { 7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2, -1, -1, -1, -1 },\n{ 3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6, -1 }, { 7, 2, 3, 7, 6, 2, 5, 4, 9, -1, -1, -1, -1, -1, -1, -1 },\n{ 9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7, -1, -1, -1, -1 }, { 3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0, -1, -1, -1, -1 
},\n{ 6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8, -1 }, { 9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7, -1, -1, -1, -1 },\n{ 1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4, -1 }, { 4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10, -1 },\n{ 7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10, -1, -1, -1, -1 }, { 6, 9, 5, 6, 11, 9, 11, 8, 9, -1, -1, -1, -1, -1, -1, -1 },\n{ 3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5, -1, -1, -1, -1 }, { 0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11, -1, -1, -1, -1 },\n{ 6, 11, 3, 6, 3, 5, 5, 3, 1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6, -1, -1, -1, -1 },\n{ 0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10, -1 }, { 11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5, -1 },\n{ 6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3, -1, -1, -1, -1 }, { 5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2, -1, -1, -1, -1 },\n{ 9, 5, 6, 9, 6, 0, 0, 6, 2, -1, -1, -1, -1, -1, -1, -1 }, { 1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8, -1 },\n{ 1, 5, 6, 2, 1, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6, -1 },\n{ 10, 1, 0, 10, 0, 6, 9, 5, 0, 5, 6, 0, -1, -1, -1, -1 }, { 0, 3, 8, 5, 6, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 10, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 11, 5, 10, 7, 5, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 11, 5, 10, 11, 7, 5, 8, 3, 0, -1, -1, -1, -1, -1, -1, -1 }, { 5, 11, 7, 5, 10, 11, 1, 9, 0, -1, -1, -1, -1, -1, -1, -1 },\n{ 10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1, -1, -1, -1, -1 }, { 11, 1, 2, 11, 7, 1, 7, 5, 1, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11, -1, -1, -1, -1 }, { 9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7, -1, -1, -1, -1 },\n{ 7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2, -1 }, { 2, 5, 10, 2, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1 },\n{ 8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5, -1, -1, -1, -1 }, { 9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2, -1, -1, -1, -1 },\n{ 9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2, -1 }, { 1, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1, -1, 
-1, -1 },\n{ 0, 8, 7, 0, 7, 1, 1, 7, 5, -1, -1, -1, -1, -1, -1, -1 }, { 9, 0, 3, 9, 3, 5, 5, 3, 7, -1, -1, -1, -1, -1, -1, -1 },\n{ 9, 8, 7, 5, 9, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 5, 8, 4, 5, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1 },\n{ 5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0, -1, -1, -1, -1 }, { 0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5, -1, -1, -1, -1 },\n{ 10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4, -1 }, { 2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8, -1, -1, -1, -1 },\n{ 0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11, -1 }, { 0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5, -1 },\n{ 9, 4, 5, 2, 11, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4, -1, -1, -1, -1 },\n{ 5, 10, 2, 5, 2, 4, 4, 2, 0, -1, -1, -1, -1, -1, -1, -1 }, { 3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9, -1 },\n{ 5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2, -1, -1, -1, -1 }, { 8, 4, 5, 8, 5, 3, 3, 5, 1, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 4, 5, 1, 0, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5, -1, -1, -1, -1 },\n{ 9, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 11, 7, 4, 9, 11, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11, -1, -1, -1, -1 }, { 1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11, -1, -1, -1, -1 },\n{ 3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4, -1 }, { 4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2, -1, -1, -1, -1 },\n{ 9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3, -1 }, { 11, 7, 4, 11, 4, 2, 2, 4, 0, -1, -1, -1, -1, -1, -1, -1 },\n{ 11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4, -1, -1, -1, -1 }, { 2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9, -1, -1, -1, -1 },\n{ 9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7, -1 }, { 3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10, -1 },\n{ 1, 10, 2, 8, 7, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 9, 1, 4, 1, 7, 7, 1, 3, -1, -1, -1, -1, -1, -1, -1 },\n{ 4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1, -1, -1, -1, -1 }, { 4, 0, 3, 7, 4, 3, -1, -1, 
-1, -1, -1, -1, -1, -1, -1, -1 },\n{ 4, 8, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 3, 0, 9, 3, 9, 11, 11, 9, 10, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 10, 0, 10, 8, 8, 10, 11, -1, -1, -1, -1, -1, -1, -1 },\n{ 3, 1, 10, 11, 3, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 11, 1, 11, 9, 9, 11, 8, -1, -1, -1, -1, -1, -1, -1 },\n{ 3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9, -1, -1, -1, -1 }, { 0, 2, 11, 8, 0, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 3, 8, 2, 8, 10, 10, 8, 9, -1, -1, -1, -1, -1, -1, -1 },\n{ 9, 10, 2, 0, 9, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8, -1, -1, -1, -1 },\n{ 1, 10, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 3, 8, 9, 1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ 0, 9, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 3, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },\n{ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 } };\n"
  },
  {
    "path": "third_party/marching_cubes/src/mc_interp_kernel.cu",
    "content": "#include \"mc_data.cuh\"\n\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <ATen/cuda/CUDAContext.h>\n\n__device__ static inline float2 query_sdf_raw(uint bx, uint by, uint bz, uint arx, uint ary, uint arz,\n                                              const uint max_vec_num,\n                                              const IndexerAccessor indexer,\n                                              const CubeSDFAccessor cube_sdf,\n                                              const CubeSDFAccessor cube_std,\n                                              const BackwardMappingAccessor vec_batch_mapping)\n{\n    if (bx >= indexer.size(0) || by >= indexer.size(1) || bz >= indexer.size(2))\n    {\n        return make_float2(NAN, NAN);\n    }\n    //    printf(\"B-Getting: %d %d %d --> %d, %d, %d\\n\", bx, by, bz, indexer.size(0), indexer.size(1), indexer.size(2));\n    long long vec_ind = indexer[bx][by][bz];\n    if (vec_ind == -1 || vec_ind >= max_vec_num)\n    {\n        return make_float2(NAN, NAN);\n    }\n    int batch_ind = vec_batch_mapping[vec_ind];\n    if (batch_ind == -1)\n    {\n        return make_float2(NAN, NAN);\n    }\n    //    printf(\"Getting: %d %d %d %d --> %d %d\\n\", batch_ind, arx, ary, arz, cube_sdf.size(0), cube_sdf.size(1));\n    float sdf = cube_sdf[batch_ind][arx][ary][arz];\n    float std = cube_std[batch_ind][arx][ary][arz];\n    return make_float2(sdf, std);\n}\n\n// Use stddev to weight sdf value.\n// #define STD_W_SDF\n\n__device__ static inline float2 get_sdf(const uint3 bsize, const uint r, uint3 bpos, uint3 rpos, const uint max_vec_num,\n                                        const IndexerAccessor indexer,\n                                        const CubeSDFAccessor cube_sdf,\n                                        const CubeSDFAccessor cube_std,\n                                        const BackwardMappingAccessor vec_batch_mapping)\n{\n    if (bpos.x >= bsize.x)\n    {\n        bpos.x = bsize.x 
- 1;\n        rpos.x = r - 1;\n    }\n    if (bpos.y >= bsize.y)\n    {\n        bpos.y = bsize.y - 1;\n        rpos.y = r - 1;\n    }\n    if (bpos.z >= bsize.z)\n    {\n        bpos.z = bsize.z - 1;\n        rpos.z = r - 1;\n    }\n\n    uint rbound = (r - 1) / 2;\n    uint rstart = r / 2;\n    float rmid = r / 2.0f;\n\n    float w_xm, w_xp;\n    int bxm, rxm, bxp, rxp;\n    int zero_x;\n    if (rpos.x <= rbound)\n    {\n        bxm = -1;\n        rxm = r;\n        bxp = 0;\n        rxp = 0;\n        w_xp = (float)rpos.x + rmid;\n        w_xm = rmid - (float)rpos.x;\n        zero_x = 1;\n    }\n    else\n    {\n        bxm = 0;\n        rxm = 0;\n        bxp = 1;\n        rxp = -r;\n        w_xp = (float)rpos.x - rmid;\n        w_xm = rmid + r - (float)rpos.x;\n        zero_x = 0;\n    }\n    w_xm /= r;\n    w_xp /= r;\n\n    float w_ym, w_yp;\n    int bym, rym, byp, ryp;\n    int zero_y;\n    if (rpos.y <= rbound)\n    {\n        bym = -1;\n        rym = r;\n        byp = 0;\n        ryp = 0;\n        w_yp = (float)rpos.y + rmid;\n        w_ym = rmid - (float)rpos.y;\n        zero_y = 1;\n    }\n    else\n    {\n        bym = 0;\n        rym = 0;\n        byp = 1;\n        ryp = -r;\n        w_yp = (float)rpos.y - rmid;\n        w_ym = rmid + r - (float)rpos.y;\n        zero_y = 0;\n    }\n    w_ym /= r;\n    w_yp /= r;\n\n    float w_zm, w_zp;\n    int bzm, rzm, bzp, rzp;\n    int zero_z;\n    if (rpos.z <= rbound)\n    {\n        bzm = -1;\n        rzm = r;\n        bzp = 0;\n        rzp = 0;\n        w_zp = (float)rpos.z + rmid;\n        w_zm = rmid - (float)rpos.z;\n        zero_z = 1;\n    }\n    else\n    {\n        bzm = 0;\n        rzm = 0;\n        bzp = 1;\n        rzp = -r;\n        w_zp = (float)rpos.z - rmid;\n        w_zm = rmid + r - (float)rpos.z;\n        zero_z = 0;\n    }\n    w_zm /= r;\n    w_zp /= r;\n\n    rpos.x += rstart;\n    rpos.y += rstart;\n    rpos.z += rstart;\n\n    // printf(\"%u %u %u %d %d %d %d %d %d\\n\", rpos.x, rpos.y, 
rpos.z, rxm, rxp, rym, ryp, rzm, rzp);\n\n    // Tri-linear interpolation of SDF values.\n#ifndef STD_W_SDF\n    float total_weight = 0.0;\n#else\n    float2 total_weight{0.0, 0.0};\n#endif\n    float2 total_sdf{0.0, 0.0};\n\n    int zero_det = zero_x * 4 + zero_y * 2 + zero_z;\n\n    float2 sdfmmm = query_sdf_raw(bpos.x + bxm, bpos.y + bym, bpos.z + bzm, rpos.x + rxm, rpos.y + rym, rpos.z + rzm,\n                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    float wmmm = w_xm * w_ym * w_zm;\n#ifndef STD_W_SDF\n    if (!isnan(sdfmmm.x))\n    {\n        total_sdf += sdfmmm * wmmm;\n        total_weight += wmmm;\n    }\n#else\n    if (!isnan(sdfmmm.x))\n    {\n        total_sdf.x += sdfmmm.x * wmmm * sdfmmm.y;\n        total_weight.x += wmmm * sdfmmm.y;\n        total_sdf.y += wmmm * sdfmmm.y;\n        total_weight.y += wmmm;\n    }\n#endif\n    else if (zero_det == 0)\n    {\n        return make_float2(NAN, NAN);\n    }\n\n    float2 sdfmmp = query_sdf_raw(bpos.x + bxm, bpos.y + bym, bpos.z + bzp, rpos.x + rxm, rpos.y + rym, rpos.z + rzp,\n                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    float wmmp = w_xm * w_ym * w_zp;\n#ifndef STD_W_SDF\n    if (!isnan(sdfmmp.x))\n    {\n        total_sdf += sdfmmp * wmmp;\n        total_weight += wmmp;\n    }\n#else\n    if (!isnan(sdfmmp.x))\n    {\n        total_sdf.x += sdfmmp.x * wmmp * sdfmmp.y;\n        total_weight.x += wmmp * sdfmmp.y;\n        total_sdf.y += wmmp * sdfmmp.y;\n        total_weight.y += wmmp;\n    }\n#endif\n    else if (zero_det == 1)\n    {\n        return make_float2(NAN, NAN);\n    }\n\n    float2 sdfmpm = query_sdf_raw(bpos.x + bxm, bpos.y + byp, bpos.z + bzm, rpos.x + rxm, rpos.y + ryp, rpos.z + rzm,\n                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    float wmpm = w_xm * w_yp * w_zm;\n#ifndef STD_W_SDF\n    if (!isnan(sdfmpm.x))\n    {\n        total_sdf += 
sdfmpm * wmpm;\n        total_weight += wmpm;\n    }\n#else\n    if (!isnan(sdfmpm.x))\n    {\n        total_sdf.x += sdfmpm.x * wmpm * sdfmpm.y;\n        total_weight.x += wmpm * sdfmpm.y;\n        total_sdf.y += wmpm * sdfmpm.y;\n        total_weight.y += wmpm;\n    }\n#endif\n    else if (zero_det == 2)\n    {\n        return make_float2(NAN, NAN);\n    }\n\n    float2 sdfmpp = query_sdf_raw(bpos.x + bxm, bpos.y + byp, bpos.z + bzp, rpos.x + rxm, rpos.y + ryp, rpos.z + rzp,\n                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    float wmpp = w_xm * w_yp * w_zp;\n#ifndef STD_W_SDF\n    if (!isnan(sdfmpp.x))\n    {\n        total_sdf += sdfmpp * wmpp;\n        total_weight += wmpp;\n    }\n#else\n    if (!isnan(sdfmpp.x))\n    {\n        total_sdf.x += sdfmpp.x * wmpp * sdfmpp.y;\n        total_weight.x += wmpp * sdfmpp.y;\n        total_sdf.y += wmpp * sdfmpp.y;\n        total_weight.y += wmpp;\n    }\n#endif\n    else if (zero_det == 3)\n    {\n        return make_float2(NAN, NAN);\n    }\n\n    float2 sdfpmm = query_sdf_raw(bpos.x + bxp, bpos.y + bym, bpos.z + bzm, rpos.x + rxp, rpos.y + rym, rpos.z + rzm,\n                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    float wpmm = w_xp * w_ym * w_zm;\n#ifndef STD_W_SDF\n    if (!isnan(sdfpmm.x))\n    {\n        total_sdf += sdfpmm * wpmm;\n        total_weight += wpmm;\n    }\n#else\n    if (!isnan(sdfpmm.x))\n    {\n        total_sdf.x += sdfpmm.x * wpmm * sdfpmm.y;\n        total_weight.x += wpmm * sdfpmm.y;\n        total_sdf.y += wpmm * sdfpmm.y;\n        total_weight.y += wpmm;\n    }\n#endif\n    else if (zero_det == 4)\n    {\n        return make_float2(NAN, NAN);\n    }\n\n    float2 sdfpmp = query_sdf_raw(bpos.x + bxp, bpos.y + bym, bpos.z + bzp, rpos.x + rxp, rpos.y + rym, rpos.z + rzp,\n                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    float wpmp = w_xp * w_ym * 
w_zp;\n#ifndef STD_W_SDF\n    if (!isnan(sdfpmp.x))\n    {\n        total_sdf += sdfpmp * wpmp;\n        total_weight += wpmp;\n    }\n#else\n    if (!isnan(sdfpmp.x))\n    {\n        total_sdf.x += sdfpmp.x * wpmp * sdfpmp.y;\n        total_weight.x += wpmp * sdfpmp.y;\n        total_sdf.y += wpmp * sdfpmp.y;\n        total_weight.y += wpmp;\n    }\n#endif\n    else if (zero_det == 5)\n    {\n        return make_float2(NAN, NAN);\n    }\n\n    float2 sdfppm = query_sdf_raw(bpos.x + bxp, bpos.y + byp, bpos.z + bzm, rpos.x + rxp, rpos.y + ryp, rpos.z + rzm,\n                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    float wppm = w_xp * w_yp * w_zm;\n#ifndef STD_W_SDF\n    if (!isnan(sdfppm.x))\n    {\n        total_sdf += sdfppm * wppm;\n        total_weight += wppm;\n    }\n#else\n    if (!isnan(sdfppm.x))\n    {\n        total_sdf.x += sdfppm.x * wppm * sdfppm.y;\n        total_weight.x += wppm * sdfppm.y;\n        total_sdf.y += wppm * sdfppm.y;\n        total_weight.y += wppm;\n    }\n#endif\n    else if (zero_det == 6)\n    {\n        return make_float2(NAN, NAN);\n    }\n\n    float2 sdfppp = query_sdf_raw(bpos.x + bxp, bpos.y + byp, bpos.z + bzp, rpos.x + rxp, rpos.y + ryp, rpos.z + rzp,\n                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    float wppp = w_xp * w_yp * w_zp;\n#ifndef STD_W_SDF\n    if (!isnan(sdfppp.x))\n    {\n        total_sdf += sdfppp * wppp;\n        total_weight += wppp;\n    }\n#else\n    if (!isnan(sdfppp.x))\n    {\n        total_sdf.x += sdfppp.x * wppp * sdfppp.y;\n        total_weight.x += wppp * sdfppp.y;\n        total_sdf.y += wppp * sdfppp.y;\n        total_weight.y += wppp;\n    }\n#endif\n    else if (zero_det == 7)\n    {\n        return make_float2(NAN, NAN);\n    }\n\n    // If NAN, will also be handled.\n    return total_sdf / total_weight;\n}\n\n__device__ static inline float4 sdf_interp(const float3 p1, const float3 p2, const float 
stdp1, const float stdp2,\n                                           float valp1, float valp2)\n{\n    if (fabs(0.0f - valp1) < 1.0e-5f)\n        return make_float4(p1, stdp1);\n    if (fabs(0.0f - valp2) < 1.0e-5f)\n        return make_float4(p2, stdp2);\n    if (fabs(valp1 - valp2) < 1.0e-5f)\n        return make_float4(p1, stdp1);\n\n    float w2 = (0.0f - valp1) / (valp2 - valp1);\n    float w1 = 1 - w2;\n\n    return make_float4(p1.x * w1 + p2.x * w2,\n                       p1.y * w1 + p2.y * w2,\n                       p1.z * w1 + p2.z * w2,\n                       stdp1 * w1 + stdp2 * w2);\n}\n\n__global__ static void meshing_cube(const IndexerAccessor indexer,\n                                    const ValidBlocksAccessor valid_blocks,\n                                    const BackwardMappingAccessor vec_batch_mapping,\n                                    const CubeSDFAccessor cube_sdf,\n                                    const CubeSDFAccessor cube_std,\n                                    TrianglesAccessor triangles,\n                                    TriangleStdAccessor triangle_std,\n                                    TriangleVecIdAccessor triangle_flatten_id,\n                                    int *__restrict__ triangles_count,\n                                    int max_triangles_count,\n                                    const uint max_vec_num,\n                                    int nx, int ny, int nz,\n                                    float max_std)\n{\n    const uint r = cube_sdf.size(1) / 2;\n    const uint r3 = r * r * r;\n    const uint num_lif = valid_blocks.size(0);\n    const float sbs = 1.0f / r; // sub-block-size\n\n    const uint lif_id = blockIdx.x * blockDim.x + threadIdx.x;\n    const uint sub_id = blockIdx.y * blockDim.y + threadIdx.y;\n\n    if (lif_id >= num_lif || sub_id >= r3)\n    {\n        return;\n    }\n\n    const uint3 bpos = make_uint3(\n        (valid_blocks[lif_id] / (ny * nz)) % nx,\n        
(valid_blocks[lif_id] / nz) % ny,\n        valid_blocks[lif_id] % nz);\n    const uint3 bsize = make_uint3(indexer.size(0), indexer.size(1), indexer.size(2));\n    const uint rx = sub_id / (r * r);\n    const uint ry = (sub_id / r) % r;\n    const uint rz = sub_id % r;\n\n    // Find all 8 neighbours\n    float3 points[8];\n    float2 sdf_vals[8];\n\n    sdf_vals[0] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[0].x))\n        return;\n    points[0] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);\n\n    sdf_vals[1] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[1].x))\n        return;\n    points[1] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);\n\n    sdf_vals[2] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[2].x))\n        return;\n    points[2] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);\n\n    sdf_vals[3] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[3].x))\n        return;\n    points[3] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);\n\n    sdf_vals[4] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[4].x))\n        return;\n    points[4] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);\n\n    sdf_vals[5] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[5].x))\n        return;\n    points[5] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + 
(rz + 1) * sbs);\n\n    sdf_vals[6] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[6].x))\n        return;\n    points[6] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);\n\n    sdf_vals[7] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[7].x))\n        return;\n    points[7] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);\n\n    // Find triangle config.\n    int cube_type = 0;\n    if (sdf_vals[0].x < 0)\n        cube_type |= 1;\n    if (sdf_vals[1].x < 0)\n        cube_type |= 2;\n    if (sdf_vals[2].x < 0)\n        cube_type |= 4;\n    if (sdf_vals[3].x < 0)\n        cube_type |= 8;\n    if (sdf_vals[4].x < 0)\n        cube_type |= 16;\n    if (sdf_vals[5].x < 0)\n        cube_type |= 32;\n    if (sdf_vals[6].x < 0)\n        cube_type |= 64;\n    if (sdf_vals[7].x < 0)\n        cube_type |= 128;\n\n    // Find vertex position on each edge (weighted by sdf value)\n    int edge_config = edgeTable[cube_type];\n    float4 vert_list[12];\n\n    if (edge_config == 0)\n        return;\n    if (edge_config & 1)\n        vert_list[0] = sdf_interp(points[0], points[1], sdf_vals[0].y, sdf_vals[1].y, sdf_vals[0].x, sdf_vals[1].x);\n    if (edge_config & 2)\n        vert_list[1] = sdf_interp(points[1], points[2], sdf_vals[1].y, sdf_vals[2].y, sdf_vals[1].x, sdf_vals[2].x);\n    if (edge_config & 4)\n        vert_list[2] = sdf_interp(points[2], points[3], sdf_vals[2].y, sdf_vals[3].y, sdf_vals[2].x, sdf_vals[3].x);\n    if (edge_config & 8)\n        vert_list[3] = sdf_interp(points[3], points[0], sdf_vals[3].y, sdf_vals[0].y, sdf_vals[3].x, sdf_vals[0].x);\n    if (edge_config & 16)\n        vert_list[4] = sdf_interp(points[4], points[5], sdf_vals[4].y, sdf_vals[5].y, sdf_vals[4].x, sdf_vals[5].x);\n    if 
(edge_config & 32)\n        vert_list[5] = sdf_interp(points[5], points[6], sdf_vals[5].y, sdf_vals[6].y, sdf_vals[5].x, sdf_vals[6].x);\n    if (edge_config & 64)\n        vert_list[6] = sdf_interp(points[6], points[7], sdf_vals[6].y, sdf_vals[7].y, sdf_vals[6].x, sdf_vals[7].x);\n    if (edge_config & 128)\n        vert_list[7] = sdf_interp(points[7], points[4], sdf_vals[7].y, sdf_vals[4].y, sdf_vals[7].x, sdf_vals[4].x);\n    if (edge_config & 256)\n        vert_list[8] = sdf_interp(points[0], points[4], sdf_vals[0].y, sdf_vals[4].y, sdf_vals[0].x, sdf_vals[4].x);\n    if (edge_config & 512)\n        vert_list[9] = sdf_interp(points[1], points[5], sdf_vals[1].y, sdf_vals[5].y, sdf_vals[1].x, sdf_vals[5].x);\n    if (edge_config & 1024)\n        vert_list[10] = sdf_interp(points[2], points[6], sdf_vals[2].y, sdf_vals[6].y, sdf_vals[2].x, sdf_vals[6].x);\n    if (edge_config & 2048)\n        vert_list[11] = sdf_interp(points[3], points[7], sdf_vals[3].y, sdf_vals[7].y, sdf_vals[3].x, sdf_vals[7].x);\n\n    // Write triangles to array.\n    float4 vp[3];\n    for (int i = 0; triangleTable[cube_type][i] != -1; i += 3)\n    {\n#pragma unroll\n        for (int vi = 0; vi < 3; ++vi)\n        {\n            vp[vi] = vert_list[triangleTable[cube_type][i + vi]];\n        }\n        if (vp[0].w > max_std || vp[1].w > max_std || vp[2].w > max_std)\n        {\n            continue;\n        }\n        int triangle_id = atomicAdd(triangles_count, 1);\n        if (triangle_id < max_triangles_count)\n        {\n#pragma unroll\n            for (int vi = 0; vi < 3; ++vi)\n            {\n                triangles[triangle_id][vi][0] = vp[vi].x;\n                triangles[triangle_id][vi][1] = vp[vi].y;\n                triangles[triangle_id][vi][2] = vp[vi].z;\n                triangle_std[triangle_id][vi] = vp[vi].w;\n            }\n            triangle_flatten_id[triangle_id] = valid_blocks[lif_id];\n        }\n    }\n}\n\nstd::vector<torch::Tensor> 
marching_cubes_sparse_interp_cuda(\n    torch::Tensor indexer,           // (nx, ny, nz) -> data_id\n    torch::Tensor valid_blocks,      // (K, )\n    torch::Tensor vec_batch_mapping, //\n    torch::Tensor cube_sdf,          // (M, rx, ry, rz)\n    torch::Tensor cube_std,          // (M, rx, ry, rz)\n    const std::vector<int> &n_xyz,   // [nx, ny, nz]\n    float max_std,                   // Prune all vertices\n    int max_n_triangles              // Maximum number of triangle buffer\n)\n{\n    CHECK_INPUT(indexer);\n    CHECK_INPUT(valid_blocks);\n    CHECK_INPUT(cube_sdf);\n    CHECK_INPUT(cube_std);\n    CHECK_INPUT(vec_batch_mapping);\n    assert(max_n_triangles > 0);\n\n    const int r = cube_sdf.size(1) / 2;\n    const int r3 = r * r * r;\n    const int num_lif = valid_blocks.size(0);\n    const uint max_vec_num = vec_batch_mapping.size(0);\n\n    torch::Tensor triangles = torch::empty({max_n_triangles, 3, 3},\n                                           torch::dtype(torch::kFloat32).device(torch::kCUDA));\n    torch::Tensor triangle_flatten_id = torch::empty({max_n_triangles}, torch::dtype(torch::kLong).device(torch::kCUDA));\n    torch::Tensor triangle_std = torch::empty({max_n_triangles, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA));\n\n    dim3 dimBlock = dim3(16, 16);\n    uint xBlocks = (num_lif + dimBlock.x - 1) / dimBlock.x;\n    uint yBlocks = (r3 + dimBlock.y - 1) / dimBlock.y;\n    dim3 dimGrid = dim3(xBlocks, yBlocks);\n\n    thrust::device_vector<int> n_output(1, 0);\n    meshing_cube<<<dimGrid, dimBlock, 0, at::cuda::getCurrentCUDAStream()>>>(\n        indexer.packed_accessor32<int64_t, 3, torch::RestrictPtrTraits>(),\n        valid_blocks.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),\n        vec_batch_mapping.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),\n        cube_sdf.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),\n        cube_std.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),\n        
triangles.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),\n        triangle_std.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),\n        triangle_flatten_id.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),\n        n_output.data().get(), max_n_triangles, max_vec_num,\n        n_xyz[0], n_xyz[1], n_xyz[2], max_std);\n    cudaStreamSynchronize(at::cuda::getCurrentCUDAStream());\n\n    int output_n_triangles = n_output[0];\n    if (output_n_triangles < max_n_triangles)\n    {\n        // Trim the output tensors when the buffer was not filled.\n        triangles = triangles.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});\n        triangle_flatten_id = triangle_flatten_id.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});\n        triangle_std = triangle_std.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});\n    }\n    else\n    {\n        // Otherwise warn the user: triangles beyond the buffer capacity were dropped.\n        std::cerr << \"Warning from marching cubes: triangle buffer too small, generated \" << output_n_triangles << \" vs capacity \" << max_n_triangles << std::endl;\n    }\n\n    return {triangles, triangle_flatten_id, triangle_std};\n}\n"
  },
  {
    "path": "third_party/marching_cubes/src/mc_kernel.cu",
    "content": "#include \"mc_data.cuh\"\n\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <ATen/cuda/CUDAContext.h>\n\n__device__ static inline float2 query_sdf_raw(uint bx, uint by, uint bz, uint arx, uint ary, uint arz,\n                                              const uint max_vec_num,\n                                              const IndexerAccessor indexer,\n                                              const CubeSDFAccessor cube_sdf,\n                                              const CubeSDFAccessor cube_std,\n                                              const BackwardMappingAccessor vec_batch_mapping)\n{\n    if (bx >= indexer.size(0) || by >= indexer.size(1) || bz >= indexer.size(2))\n    {\n        return make_float2(NAN, NAN);\n    }\n    //    printf(\"B-Getting: %d %d %d --> %d, %d, %d\\n\", bx, by, bz, indexer.size(0), indexer.size(1), indexer.size(2));\n    long long vec_ind = indexer[bx][by][bz];\n    if (vec_ind == -1 || vec_ind >= max_vec_num)\n    {\n        return make_float2(NAN, NAN);\n    }\n    int batch_ind = vec_batch_mapping[vec_ind];\n    if (batch_ind == -1)\n    {\n        return make_float2(NAN, NAN);\n    }\n    //    printf(\"Getting: %d %d %d %d --> %d %d\\n\", batch_ind, arx, ary, arz, cube_sdf.size(0), cube_sdf.size(1));\n    float sdf = cube_sdf[batch_ind][arx][ary][arz];\n    float std = cube_std[batch_ind][arx][ary][arz];\n    return make_float2(sdf, std);\n}\n\n// Use stddev to weight sdf value.\n// #define STD_W_SDF\n\n__device__ static inline float2 get_sdf(const uint3 bsize, const uint r, uint3 bpos, uint3 rpos, const uint max_vec_num,\n                                        const IndexerAccessor indexer,\n                                        const CubeSDFAccessor cube_sdf,\n                                        const CubeSDFAccessor cube_std,\n                                        const BackwardMappingAccessor vec_batch_mapping)\n{\n    if (bpos.x >= bsize.x)\n    {\n        bpos.x = bsize.x 
- 1;\n        rpos.x = r - 1;\n    }\n    if (bpos.y >= bsize.y)\n    {\n        bpos.y = bsize.y - 1;\n        rpos.y = r - 1;\n    }\n    if (bpos.z >= bsize.z)\n    {\n        bpos.z = bsize.z - 1;\n        rpos.z = r - 1;\n    }\n\n    if (rpos.x == r)\n    {\n        bpos.x += 1;\n        rpos.x = 0;\n    }\n\n    if (rpos.y == r)\n    {\n        bpos.y += 1;\n        rpos.y = 0;\n    }\n\n    if (rpos.z == r)\n    {\n        bpos.z += 1;\n        rpos.z = 0;\n    }\n\n    float2 total_sdf = query_sdf_raw(bpos.x, bpos.y, bpos.z, rpos.x, rpos.y, rpos.z, max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n\n    // If NAN, will also be handled.\n    return total_sdf;\n}\n\n__device__ static inline float4 sdf_interp(const float3 p1, const float3 p2, const float stdp1, const float stdp2,\n                                           float valp1, float valp2)\n{\n    if (fabs(0.0f - valp1) < 1.0e-5f)\n        return make_float4(p1, stdp1);\n    if (fabs(0.0f - valp2) < 1.0e-5f)\n        return make_float4(p2, stdp2);\n    if (fabs(valp1 - valp2) < 1.0e-5f)\n        return make_float4(p1, stdp1);\n\n    float w2 = (0.0f - valp1) / (valp2 - valp1);\n    float w1 = 1 - w2;\n\n    return make_float4(p1.x * w1 + p2.x * w2,\n                       p1.y * w1 + p2.y * w2,\n                       p1.z * w1 + p2.z * w2,\n                       stdp1 * w1 + stdp2 * w2);\n}\n\n__global__ static void meshing_cube(const IndexerAccessor indexer,\n                                    const ValidBlocksAccessor valid_blocks,\n                                    const BackwardMappingAccessor vec_batch_mapping,\n                                    const CubeSDFAccessor cube_sdf,\n                                    const CubeSDFAccessor cube_std,\n                                    TrianglesAccessor triangles,\n                                    TriangleStdAccessor triangle_std,\n                                    TriangleVecIdAccessor triangle_flatten_id,\n                  
                  int *__restrict__ triangles_count,\n                                    int max_triangles_count,\n                                    const uint max_vec_num,\n                                    int nx, int ny, int nz,\n                                    float max_std)\n{\n    const uint r = cube_sdf.size(1);\n    const uint r3 = r * r * r;\n    const uint num_lif = valid_blocks.size(0);\n    const float sbs = 1.0f / r; // sub-block-size\n\n    const uint lif_id = blockIdx.x * blockDim.x + threadIdx.x;\n    const uint sub_id = blockIdx.y * blockDim.y + threadIdx.y;\n\n    if (lif_id >= num_lif || sub_id >= r3)\n    {\n        return;\n    }\n\n    const uint3 bpos = make_uint3(\n        (valid_blocks[lif_id] / (ny * nz)) % nx,\n        (valid_blocks[lif_id] / nz) % ny,\n        valid_blocks[lif_id] % nz);\n    const uint3 bsize = make_uint3(indexer.size(0), indexer.size(1), indexer.size(2));\n    const uint rx = sub_id / (r * r);\n    const uint ry = (sub_id / r) % r;\n    const uint rz = sub_id % r;\n\n    // Find all 8 neighbours\n    float3 points[8];\n    float2 sdf_vals[8];\n\n    sdf_vals[0] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[0].x))\n        return;\n    points[0] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);\n\n    sdf_vals[1] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[1].x))\n        return;\n    points[1] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);\n\n    sdf_vals[2] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[2].x))\n        return;\n    points[2] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);\n\n    sdf_vals[3] = get_sdf(bsize, r, bpos, 
make_uint3(rx, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[3].x))\n        return;\n    points[3] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);\n\n    sdf_vals[4] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[4].x))\n        return;\n    points[4] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);\n\n    sdf_vals[5] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[5].x))\n        return;\n    points[5] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);\n\n    sdf_vals[6] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[6].x))\n        return;\n    points[6] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);\n\n    sdf_vals[7] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_vals[7].x))\n        return;\n    points[7] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);\n\n    // Find triangle config.\n    int cube_type = 0;\n    if (sdf_vals[0].x < 0)\n        cube_type |= 1;\n    if (sdf_vals[1].x < 0)\n        cube_type |= 2;\n    if (sdf_vals[2].x < 0)\n        cube_type |= 4;\n    if (sdf_vals[3].x < 0)\n        cube_type |= 8;\n    if (sdf_vals[4].x < 0)\n        cube_type |= 16;\n    if (sdf_vals[5].x < 0)\n        cube_type |= 32;\n    if (sdf_vals[6].x < 0)\n        cube_type |= 64;\n    if (sdf_vals[7].x < 0)\n        cube_type |= 128;\n\n    // Find vertex position on each edge (weighted by sdf value)\n    int edge_config = edgeTable[cube_type];\n    float4 
vert_list[12];\n\n    if (edge_config == 0)\n        return;\n    if (edge_config & 1)\n        vert_list[0] = sdf_interp(points[0], points[1], sdf_vals[0].y, sdf_vals[1].y, sdf_vals[0].x, sdf_vals[1].x);\n    if (edge_config & 2)\n        vert_list[1] = sdf_interp(points[1], points[2], sdf_vals[1].y, sdf_vals[2].y, sdf_vals[1].x, sdf_vals[2].x);\n    if (edge_config & 4)\n        vert_list[2] = sdf_interp(points[2], points[3], sdf_vals[2].y, sdf_vals[3].y, sdf_vals[2].x, sdf_vals[3].x);\n    if (edge_config & 8)\n        vert_list[3] = sdf_interp(points[3], points[0], sdf_vals[3].y, sdf_vals[0].y, sdf_vals[3].x, sdf_vals[0].x);\n    if (edge_config & 16)\n        vert_list[4] = sdf_interp(points[4], points[5], sdf_vals[4].y, sdf_vals[5].y, sdf_vals[4].x, sdf_vals[5].x);\n    if (edge_config & 32)\n        vert_list[5] = sdf_interp(points[5], points[6], sdf_vals[5].y, sdf_vals[6].y, sdf_vals[5].x, sdf_vals[6].x);\n    if (edge_config & 64)\n        vert_list[6] = sdf_interp(points[6], points[7], sdf_vals[6].y, sdf_vals[7].y, sdf_vals[6].x, sdf_vals[7].x);\n    if (edge_config & 128)\n        vert_list[7] = sdf_interp(points[7], points[4], sdf_vals[7].y, sdf_vals[4].y, sdf_vals[7].x, sdf_vals[4].x);\n    if (edge_config & 256)\n        vert_list[8] = sdf_interp(points[0], points[4], sdf_vals[0].y, sdf_vals[4].y, sdf_vals[0].x, sdf_vals[4].x);\n    if (edge_config & 512)\n        vert_list[9] = sdf_interp(points[1], points[5], sdf_vals[1].y, sdf_vals[5].y, sdf_vals[1].x, sdf_vals[5].x);\n    if (edge_config & 1024)\n        vert_list[10] = sdf_interp(points[2], points[6], sdf_vals[2].y, sdf_vals[6].y, sdf_vals[2].x, sdf_vals[6].x);\n    if (edge_config & 2048)\n        vert_list[11] = sdf_interp(points[3], points[7], sdf_vals[3].y, sdf_vals[7].y, sdf_vals[3].x, sdf_vals[7].x);\n\n    // Write triangles to array.\n    float4 vp[3];\n    for (int i = 0; triangleTable[cube_type][i] != -1; i += 3)\n    {\n#pragma unroll\n        for (int vi = 0; vi < 3; ++vi)\n        
{\n            vp[vi] = vert_list[triangleTable[cube_type][i + vi]];\n        }\n\n        int triangle_id = atomicAdd(triangles_count, 1);\n        if (triangle_id < max_triangles_count)\n        {\n#pragma unroll\n            for (int vi = 0; vi < 3; ++vi)\n            {\n                triangles[triangle_id][vi][0] = vp[vi].x;\n                triangles[triangle_id][vi][1] = vp[vi].y;\n                triangles[triangle_id][vi][2] = vp[vi].z;\n                triangle_std[triangle_id][vi] = vp[vi].w;\n            }\n            triangle_flatten_id[triangle_id] = valid_blocks[lif_id];\n        }\n    }\n}\n\nstd::vector<torch::Tensor> marching_cubes_sparse(\n    torch::Tensor indexer,           // (nx, ny, nz) -> data_id\n    torch::Tensor valid_blocks,      // (K, )\n    torch::Tensor vec_batch_mapping, //\n    torch::Tensor cube_sdf,          // (M, rx, ry, rz)\n    torch::Tensor cube_std,          // (M, rx, ry, rz)\n    const std::vector<int> &n_xyz,   // [nx, ny, nz]\n    float max_std,                   // Prune all vertices\n    int max_n_triangles              // Maximum number of triangle buffer\n)\n{\n    CHECK_INPUT(indexer);\n    CHECK_INPUT(valid_blocks);\n    CHECK_INPUT(cube_sdf);\n    CHECK_INPUT(cube_std);\n    CHECK_INPUT(vec_batch_mapping);\n    assert(max_n_triangles > 0);\n\n    const int r = cube_sdf.size(1);\n    const int r3 = r * r * r;\n    const int num_lif = valid_blocks.size(0);\n    const uint max_vec_num = vec_batch_mapping.size(0);\n\n    torch::Tensor triangles = torch::empty({max_n_triangles, 3, 3},\n                                           torch::dtype(torch::kFloat32).device(torch::kCUDA));\n    torch::Tensor triangle_flatten_id = torch::empty({max_n_triangles}, torch::dtype(torch::kLong).device(torch::kCUDA));\n    torch::Tensor triangle_std = torch::empty({max_n_triangles, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA));\n\n    dim3 dimBlock = dim3(16, 16);\n    uint xBlocks = (num_lif + dimBlock.x - 1) / 
dimBlock.x;\n    uint yBlocks = (r3 + dimBlock.y - 1) / dimBlock.y;\n    dim3 dimGrid = dim3(xBlocks, yBlocks);\n\n    thrust::device_vector<int> n_output(1, 0);\n    meshing_cube<<<dimGrid, dimBlock, 0, at::cuda::getCurrentCUDAStream()>>>(\n        indexer.packed_accessor32<int64_t, 3, torch::RestrictPtrTraits>(),\n        valid_blocks.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),\n        vec_batch_mapping.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),\n        cube_sdf.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),\n        cube_std.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),\n        triangles.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),\n        triangle_std.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),\n        triangle_flatten_id.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),\n        n_output.data().get(), max_n_triangles, max_vec_num,\n        n_xyz[0], n_xyz[1], n_xyz[2], max_std);\n    cudaStreamSynchronize(at::cuda::getCurrentCUDAStream());\n\n    int output_n_triangles = n_output[0];\n    if (output_n_triangles < max_n_triangles)\n    {\n        // Trim the output tensors when the buffer was not filled.\n        triangles = triangles.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});\n        triangle_flatten_id = triangle_flatten_id.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});\n        triangle_std = triangle_std.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});\n    }\n    else\n    {\n        // Otherwise warn the user: triangles beyond the buffer capacity were dropped.\n        std::cerr << \"Warning from marching cubes: triangle buffer too small, generated \" << output_n_triangles << \" vs capacity \" << max_n_triangles << std::endl;\n    }\n\n    return {triangles, triangle_flatten_id, triangle_std};\n}\n"
  },
  {
    "path": "third_party/marching_cubes/src/mc_kernel_colour.cu",
    "content": "#include \"mc_data.cuh\"\n\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <ATen/cuda/CUDAContext.h>\n\n__device__ static inline float4 query_sdf_raw(uint bx, uint by, uint bz, uint arx, uint ary, uint arz,\n                                              const uint max_vec_num,\n                                              const IndexerAccessor indexer,\n                                              const CubeSDFRGBAccessor cube_sdf,\n                                              const CubeSDFAccessor cube_std,\n                                              const BackwardMappingAccessor vec_batch_mapping)\n{\n    if (bx >= indexer.size(0) || by >= indexer.size(1) || bz >= indexer.size(2))\n    {\n        return make_float4(NAN, NAN, NAN, NAN);\n    }\n    //    printf(\"B-Getting: %d %d %d --> %d, %d, %d\\n\", bx, by, bz, indexer.size(0), indexer.size(1), indexer.size(2));\n    long long vec_ind = indexer[bx][by][bz];\n    if (vec_ind == -1 || vec_ind >= max_vec_num)\n    {\n        return make_float4(NAN, NAN, NAN, NAN);\n    }\n    int batch_ind = vec_batch_mapping[vec_ind];\n    if (batch_ind == -1)\n    {\n        return make_float4(NAN, NAN, NAN, NAN);\n    }\n    //    printf(\"Getting: %d %d %d %d --> %d %d\\n\", batch_ind, arx, ary, arz, cube_sdf.size(0), cube_sdf.size(1));\n    return make_float4(cube_sdf[batch_ind][arx][ary][arz][3],\n                       cube_sdf[batch_ind][arx][ary][arz][0],\n                       cube_sdf[batch_ind][arx][ary][arz][1],\n                       cube_sdf[batch_ind][arx][ary][arz][2]);\n}\n\n// Use stddev to weight sdf value.\n// #define STD_W_SDF\n\n__device__ static inline float4 get_sdf(const uint3 bsize, const uint r, uint3 bpos, uint3 rpos, const uint max_vec_num,\n                                        const IndexerAccessor indexer,\n                                        const CubeSDFRGBAccessor cube_sdf,\n                                        const CubeSDFAccessor cube_std,\n      
                                  const BackwardMappingAccessor vec_batch_mapping)\n{\n    if (bpos.x >= bsize.x)\n    {\n        bpos.x = bsize.x - 1;\n        rpos.x = r - 1;\n    }\n    if (bpos.y >= bsize.y)\n    {\n        bpos.y = bsize.y - 1;\n        rpos.y = r - 1;\n    }\n    if (bpos.z >= bsize.z)\n    {\n        bpos.z = bsize.z - 1;\n        rpos.z = r - 1;\n    }\n\n    if (rpos.x == r)\n    {\n        bpos.x += 1;\n        rpos.x = 0;\n    }\n\n    if (rpos.y == r)\n    {\n        bpos.y += 1;\n        rpos.y = 0;\n    }\n\n    if (rpos.z == r)\n    {\n        bpos.z += 1;\n        rpos.z = 0;\n    }\n\n    float4 total_sdf = query_sdf_raw(bpos.x, bpos.y, bpos.z, rpos.x, rpos.y, rpos.z, max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n\n    // If NAN, will also be handled.\n    return total_sdf;\n}\n\n__device__ static inline float3 sdf_interp(const float3 p1, const float3 p2,\n                                           float valp1, float valp2)\n{\n    if (fabs(0.0f - valp1) < 1.0e-5f)\n        return p1;\n    if (fabs(0.0f - valp2) < 1.0e-5f)\n        return p2;\n    if (fabs(valp1 - valp2) < 1.0e-5f)\n        return p1;\n\n    float w2 = (0.0f - valp1) / (valp2 - valp1);\n    float w1 = 1 - w2;\n\n    return make_float3(p1.x * w1 + p2.x * w2,\n                       p1.y * w1 + p2.y * w2,\n                       p1.z * w1 + p2.z * w2);\n}\n\n__global__ static void meshing_cube_colour(const IndexerAccessor indexer,\n                                           const ValidBlocksAccessor valid_blocks,\n                                           const BackwardMappingAccessor vec_batch_mapping,\n                                           const CubeSDFRGBAccessor cube_sdf,\n                                           const CubeSDFAccessor cube_std,\n                                           TrianglesAccessor triangles,\n                                           TrianglesAccessor vertex_colours,\n                                           
TriangleVecIdAccessor triangle_flatten_id,\n                                           int *__restrict__ triangles_count,\n                                           int max_triangles_count,\n                                           const uint max_vec_num,\n                                           int nx, int ny, int nz,\n                                           float max_std)\n{\n    const uint r = cube_sdf.size(1);\n    const uint r3 = r * r * r;\n    const uint num_lif = valid_blocks.size(0);\n    const float sbs = 1.0f / r; // sub-block-size\n\n    const uint lif_id = blockIdx.x * blockDim.x + threadIdx.x;\n    const uint sub_id = blockIdx.y * blockDim.y + threadIdx.y;\n\n    if (lif_id >= num_lif || sub_id >= r3)\n    {\n        return;\n    }\n\n    const uint3 bpos = make_uint3(\n        (valid_blocks[lif_id] / (ny * nz)) % nx,\n        (valid_blocks[lif_id] / nz) % ny,\n        valid_blocks[lif_id] % nz);\n    const uint3 bsize = make_uint3(indexer.size(0), indexer.size(1), indexer.size(2));\n    const uint rx = sub_id / (r * r);\n    const uint ry = (sub_id / r) % r;\n    const uint rz = sub_id % r;\n\n    // Find all 8 neighbours\n    float3 points[8];\n    float3 colours[8];\n    float sdf_vals[8];\n\n    float4 sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_val.x))\n        return;\n    sdf_vals[0] = sdf_val.x;\n    points[0] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);\n    colours[0] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);\n\n    sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_val.x))\n        return;\n    sdf_vals[1] = sdf_val.x;\n    points[1] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);\n    colours[1] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);\n\n    sdf_val = get_sdf(bsize, r, bpos, 
make_uint3(rx + 1, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_val.x))\n        return;\n    sdf_vals[2] = sdf_val.x;\n    points[2] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);\n    colours[2] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);\n\n    sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_val.x))\n        return;\n    sdf_vals[3] = sdf_val.x;\n    points[3] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);\n    colours[3] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);\n\n    sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_val.x))\n        return;\n    sdf_vals[4] = sdf_val.x;\n    points[4] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);\n    colours[4] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);\n\n    sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_val.x))\n        return;\n    sdf_vals[5] = sdf_val.x;\n    points[5] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);\n    colours[5] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);\n\n    sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_val.x))\n        return;\n    sdf_vals[6] = sdf_val.x;\n    points[6] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);\n    colours[6] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);\n\n    sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);\n    if (isnan(sdf_val.x))\n        return;\n    
sdf_vals[7] = sdf_val.x;\n    points[7] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);\n    colours[7] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);\n\n    // Find triangle config.\n    int cube_type = 0;\n    if (sdf_vals[0] < 0)\n        cube_type |= 1;\n    if (sdf_vals[1] < 0)\n        cube_type |= 2;\n    if (sdf_vals[2] < 0)\n        cube_type |= 4;\n    if (sdf_vals[3] < 0)\n        cube_type |= 8;\n    if (sdf_vals[4] < 0)\n        cube_type |= 16;\n    if (sdf_vals[5] < 0)\n        cube_type |= 32;\n    if (sdf_vals[6] < 0)\n        cube_type |= 64;\n    if (sdf_vals[7] < 0)\n        cube_type |= 128;\n\n    // Find vertex position on each edge (weighted by sdf value)\n    int edge_config = edgeTable[cube_type];\n    float3 vert_list[12];\n    float3 rgb_list[12];\n\n    if (edge_config == 0)\n        return;\n    if (edge_config & 1)\n    {\n        vert_list[0] = sdf_interp(points[0], points[1], sdf_vals[0], sdf_vals[1]);\n        rgb_list[0] = sdf_interp(colours[0], colours[1], sdf_vals[0], sdf_vals[1]);\n    }\n    if (edge_config & 2)\n    {\n        vert_list[1] = sdf_interp(points[1], points[2], sdf_vals[1], sdf_vals[2]);\n        rgb_list[1] = sdf_interp(colours[1], colours[2], sdf_vals[1], sdf_vals[2]);\n    }\n    if (edge_config & 4)\n    {\n        vert_list[2] = sdf_interp(points[2], points[3], sdf_vals[2], sdf_vals[3]);\n        rgb_list[2] = sdf_interp(colours[2], colours[3], sdf_vals[2], sdf_vals[3]);\n    }\n    if (edge_config & 8)\n    {\n        vert_list[3] = sdf_interp(points[3], points[0], sdf_vals[3], sdf_vals[0]);\n        rgb_list[3] = sdf_interp(colours[3], colours[0], sdf_vals[3], sdf_vals[0]);\n    }\n    if (edge_config & 16)\n    {\n        vert_list[4] = sdf_interp(points[4], points[5], sdf_vals[4], sdf_vals[5]);\n        rgb_list[4] = sdf_interp(colours[4], colours[5], sdf_vals[4], sdf_vals[5]);\n    }\n    if (edge_config & 32)\n    {\n        vert_list[5] = sdf_interp(points[5], 
points[6], sdf_vals[5], sdf_vals[6]);\n        rgb_list[5] = sdf_interp(colours[5], colours[6], sdf_vals[5], sdf_vals[6]);\n    }\n    if (edge_config & 64)\n    {\n        vert_list[6] = sdf_interp(points[6], points[7], sdf_vals[6], sdf_vals[7]);\n        rgb_list[6] = sdf_interp(colours[6], colours[7], sdf_vals[6], sdf_vals[7]);\n    }\n    if (edge_config & 128)\n    {\n        vert_list[7] = sdf_interp(points[7], points[4], sdf_vals[7], sdf_vals[4]);\n        rgb_list[7] = sdf_interp(colours[7], colours[4], sdf_vals[7], sdf_vals[4]);\n    }\n    if (edge_config & 256)\n    {\n        vert_list[8] = sdf_interp(points[0], points[4], sdf_vals[0], sdf_vals[4]);\n        rgb_list[8] = sdf_interp(colours[0], colours[4], sdf_vals[0], sdf_vals[4]);\n    }\n    if (edge_config & 512)\n    {\n        vert_list[9] = sdf_interp(points[1], points[5], sdf_vals[1], sdf_vals[5]);\n        rgb_list[9] = sdf_interp(colours[1], colours[5], sdf_vals[1], sdf_vals[5]);\n    }\n    if (edge_config & 1024)\n    {\n        vert_list[10] = sdf_interp(points[2], points[6], sdf_vals[2], sdf_vals[6]);\n        rgb_list[10] = sdf_interp(colours[2], colours[6], sdf_vals[2], sdf_vals[6]);\n    }\n    if (edge_config & 2048)\n    {\n        vert_list[11] = sdf_interp(points[3], points[7], sdf_vals[3], sdf_vals[7]);\n        rgb_list[11] = sdf_interp(colours[3], colours[7], sdf_vals[3], sdf_vals[7]);\n    }\n\n    // Write triangles to array.\n    float3 vp[3];\n    float3 vc[3];\n    for (int i = 0; triangleTable[cube_type][i] != -1; i += 3)\n    {\n#pragma unroll\n        for (int vi = 0; vi < 3; ++vi)\n        {\n            vp[vi] = vert_list[triangleTable[cube_type][i + vi]];\n            vc[vi] = rgb_list[triangleTable[cube_type][i + vi]];\n        }\n\n        int triangle_id = atomicAdd(triangles_count, 1);\n        if (triangle_id < max_triangles_count)\n        {\n#pragma unroll\n            for (int vi = 0; vi < 3; ++vi)\n            {\n                triangles[triangle_id][vi][0] = 
vp[vi].x;\n                triangles[triangle_id][vi][1] = vp[vi].y;\n                triangles[triangle_id][vi][2] = vp[vi].z;\n                vertex_colours[triangle_id][vi][0] = vc[vi].x;\n                vertex_colours[triangle_id][vi][1] = vc[vi].y;\n                vertex_colours[triangle_id][vi][2] = vc[vi].z;\n            }\n            triangle_flatten_id[triangle_id] = valid_blocks[lif_id];\n        }\n    }\n}\n\nstd::vector<torch::Tensor> marching_cubes_sparse_colour(\n    torch::Tensor indexer,           // (nx, ny, nz) -> data_id\n    torch::Tensor valid_blocks,      // (K, )\n    torch::Tensor vec_batch_mapping, //\n    torch::Tensor cube_rgb_sdf,      // (M, rx, ry, rz, 4)\n    torch::Tensor cube_std,          // (M, rx, ry, rz)\n    const std::vector<int> &n_xyz,   // [nx, ny, nz]\n    float max_std,                   // Prune all vertices\n    int max_n_triangles              // Maximum number of triangle buffer\n)\n{\n    CHECK_INPUT(indexer);\n    CHECK_INPUT(valid_blocks);\n    CHECK_INPUT(cube_rgb_sdf);\n    CHECK_INPUT(cube_std);\n    CHECK_INPUT(vec_batch_mapping);\n    assert(max_n_triangles > 0);\n\n    const int r = cube_rgb_sdf.size(1);\n    const int r3 = r * r * r;\n    const int num_lif = valid_blocks.size(0);\n    const uint max_vec_num = vec_batch_mapping.size(0);\n\n    torch::Tensor triangles = torch::empty({max_n_triangles, 3, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA));\n    torch::Tensor vertex_colours = torch::empty({max_n_triangles, 3, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA));\n    torch::Tensor triangle_flatten_id = torch::empty({max_n_triangles}, torch::dtype(torch::kLong).device(torch::kCUDA));\n    torch::Tensor triangle_std = torch::empty({max_n_triangles, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA));\n\n    dim3 dimBlock = dim3(16, 16);\n    uint xBlocks = (num_lif + dimBlock.x - 1) / dimBlock.x;\n    uint yBlocks = (r3 + dimBlock.y - 1) / dimBlock.y;\n    dim3 dimGrid = 
dim3(xBlocks, yBlocks);\n\n    thrust::device_vector<int> n_output(1, 0);\n    meshing_cube_colour<<<dimGrid, dimBlock, 0, at::cuda::getCurrentCUDAStream()>>>(\n        indexer.packed_accessor32<int64_t, 3, torch::RestrictPtrTraits>(),\n        valid_blocks.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),\n        vec_batch_mapping.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),\n        cube_rgb_sdf.packed_accessor32<float, 5, torch::RestrictPtrTraits>(),\n        cube_std.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),\n        triangles.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),\n        vertex_colours.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),\n        triangle_flatten_id.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),\n        n_output.data().get(), max_n_triangles, max_vec_num,\n        n_xyz[0], n_xyz[1], n_xyz[2], max_std);\n    cudaStreamSynchronize(at::cuda::getCurrentCUDAStream());\n\n    int output_n_triangles = n_output[0];\n    if (output_n_triangles < max_n_triangles)\n    {\n        // Trim output tensor if it is not full.\n        triangles = triangles.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});\n        vertex_colours = vertex_colours.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});\n        triangle_flatten_id = triangle_flatten_id.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});\n    }\n    else\n    {\n        // Otherwise spawn a warning.\n        std::cerr << \"Warning from marching cube: the max triangle number is too small \" << output_n_triangles << \" vs \" << max_n_triangles << std::endl;\n    }\n\n    return {triangles, vertex_colours, triangle_flatten_id};\n}\n"
  },
  {
    "path": "third_party/sparse_octree/include/octree.h",
    "content": "#include <memory>\n#include <torch/script.h>\n#include <torch/custom_class.h>\n\nenum OcType\n{\n    NONLEAF = -1,\n    SURFACE = 0,\n    FEATURE = 1\n};\n\nclass Octant : public torch::CustomClassHolder\n{\npublic:\n    inline Octant()\n    {\n        code_ = 0;\n        side_ = 0;\n        index_ = next_index_++;\n        depth_ = -1;\n        is_leaf_ = false;\n        children_mask_ = 0;\n        type_ = NONLEAF;\n        for (unsigned int i = 0; i < 8; i++)\n        {\n            child_ptr_[i] = nullptr;\n            // feature_index_[i] = -1;\n        }\n    }\n    ~Octant() {}\n\n    // std::shared_ptr<Octant> &child(const int x, const int y, const int z)\n    // {\n    //     return child_ptr_[x + y * 2 + z * 4];\n    // };\n\n    // std::shared_ptr<Octant> &child(const int offset)\n    // {\n    //     return child_ptr_[offset];\n    // }\n    Octant *&child(const int x, const int y, const int z)\n    {\n        return child_ptr_[x + y * 2 + z * 4];\n    };\n\n    Octant *&child(const int offset)\n    {\n        return child_ptr_[offset];\n    }\n\n    uint64_t code_;\n    bool is_leaf_;\n    unsigned int side_;\n    unsigned char children_mask_;\n    // std::shared_ptr<Octant> child_ptr_[8];\n    // int feature_index_[8];\n    int index_;\n    int depth_;\n    int type_;\n    // int feat_index_;\n    Octant *child_ptr_[8];\n    static int next_index_;\n};\n\nclass Octree : public torch::CustomClassHolder\n{\npublic:\n    Octree();\n    // temporal solution\n    Octree(int64_t grid_dim, int64_t feat_dim, double voxel_size, std::vector<torch::Tensor> all_pts);\n    ~Octree();\n    void init(int64_t grid_dim, int64_t feat_dim, double voxel_size);\n\n    // allocate voxels\n    void insert(torch::Tensor vox);\n    double try_insert(torch::Tensor pts);\n\n    // find a particular octant\n    Octant *find_octant(std::vector<float> coord);\n\n    // test intersections\n    bool has_voxel(torch::Tensor pose);\n\n    // query features\n    
torch::Tensor get_features(torch::Tensor pts);\n\n    // get all voxels\n    torch::Tensor get_voxels();\n    std::vector<float> get_voxel_recursive(Octant *n);\n\n    // get leaf voxels\n    torch::Tensor get_leaf_voxels();\n    std::vector<float> get_leaf_voxel_recursive(Octant *n);\n\n    // count nodes\n    int64_t count_nodes();\n    int64_t count_recursive(Octant *n);\n\n    // count leaf nodes\n    int64_t count_leaf_nodes();\n    // int64_t leaves_count_recursive(std::shared_ptr<Octant> n);\n    int64_t leaves_count_recursive(Octant *n);\n\n    // get voxel centres and childrens\n    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> get_centres_and_children();\n\npublic:\n    int size_;\n    int feat_dim_;\n    int max_level_;\n\n    // temporal solution\n    double voxel_size_;\n    std::vector<torch::Tensor> all_pts;\n\nprivate:\n    std::set<uint64_t> all_keys;\n\n\n    // std::shared_ptr<Octant> root_;\n    Octant *root_;\n    // static int feature_index;\n\n    // internal count function\n    std::pair<int64_t, int64_t> count_nodes_internal();\n    std::pair<int64_t, int64_t> count_recursive_internal(Octant *n);\n\n\n};"
  },
  {
    "path": "third_party/sparse_octree/include/test.h",
    "content": "#pragma once\n#include <iostream>\n\n#define MAX_BITS 21\n// #define SCALE_MASK ((uint64_t)0x1FF)\n#define SCALE_MASK ((uint64_t)0x1)\n\n/*\n * Mask generated with:\n   MASK[0] = 0x7000000000000000,\n   for(int i = 1; i < 21; ++i) {\n   MASK[i] = MASK[i-1] | (MASK[0] >> (i*3));\n   std::bitset<64> b(MASK[i]);\n   std::cout << std::hex << b.to_ullong() << std::endl;\n   }\n *\n*/\nconstexpr uint64_t MASK[] = {\n    0x7000000000000000,\n    0x7e00000000000000,\n    0x7fc0000000000000,\n    0x7ff8000000000000,\n    0x7fff000000000000,\n    0x7fffe00000000000,\n    0x7ffffc0000000000,\n    0x7fffff8000000000,\n    0x7ffffff000000000,\n    0x7ffffffe00000000,\n    0x7fffffffc0000000,\n    0x7ffffffff8000000,\n    0x7fffffffff000000,\n    0x7fffffffffe00000,\n    0x7ffffffffffc0000,\n    0x7fffffffffff8000,\n    0x7ffffffffffff000,\n    0x7ffffffffffffe00,\n    0x7fffffffffffffc0,\n    0x7ffffffffffffff8,\n    0x7fffffffffffffff};\n\ninline int64_t expand(int64_t value)\n{\n    int64_t x = value & 0x1fffff;\n    x = (x | x << 32) & 0x1f00000000ffff;\n    x = (x | x << 16) & 0x1f0000ff0000ff;\n    x = (x | x << 8) & 0x100f00f00f00f00f;\n    x = (x | x << 4) & 0x10c30c30c30c30c3;\n    x = (x | x << 2) & 0x1249249249249249;\n    return x;\n}\n\ninline uint64_t compact(uint64_t value)\n{\n    uint64_t x = value & 0x1249249249249249;\n    x = (x | x >> 2) & 0x10c30c30c30c30c3;\n    x = (x | x >> 4) & 0x100f00f00f00f00f;\n    x = (x | x >> 8) & 0x1f0000ff0000ff;\n    x = (x | x >> 16) & 0x1f00000000ffff;\n    x = (x | x >> 32) & 0x1fffff;\n    return x;\n}\n\ninline int64_t compute_morton(int64_t x, int64_t y, int64_t z)\n{\n    int64_t code = 0;\n\n    x = expand(x);\n    y = expand(y) << 1;\n    z = expand(z) << 2;\n\n    code = x | y | z;\n    return code;\n}\n\ninline torch::Tensor encode_torch(torch::Tensor coords)\n{\n    torch::Tensor outs = torch::zeros({coords.size(0), 1}, dtype(torch::kInt64));\n    for (int i = 0; i < coords.size(0); ++i)\n    {\n    
    int64_t x = coords.data_ptr<int64_t>()[i * 3];\n        int64_t y = coords.data_ptr<int64_t>()[i * 3 + 1];\n        int64_t z = coords.data_ptr<int64_t>()[i * 3 + 2];\n        outs.data_ptr<int64_t>()[i] = (compute_morton(x, y, z) & MASK[MAX_BITS - 1]);\n    }\n    return outs;\n}\n"
  },
  {
    "path": "third_party/sparse_octree/include/utils.h",
    "content": "#pragma once\n#include <iostream>\n#include <eigen3/Eigen/Dense>\n\n#define MAX_BITS 21\n// #define SCALE_MASK ((uint64_t)0x1FF)\n#define SCALE_MASK ((uint64_t)0x1)\n\ntemplate <class T>\nstruct Vector3\n{\n    Vector3() : x(0), y(0), z(0) {}\n    Vector3(T x_, T y_, T z_) : x(x_), y(y_), z(z_) {}\n\n    Vector3<T> operator+(const Vector3<T> &b)\n    {\n        return Vector3<T>(x + b.x, y + b.y, z + b.z);\n    }\n\n    Vector3<T> operator-(const Vector3<T> &b)\n    {\n        return Vector3<T>(x - b.x, y - b.y, z - b.z);\n    }\n\n    T x, y, z;\n};\n\ntypedef Vector3<int> Vector3i;\ntypedef Vector3<float> Vector3f;\n\n/*\n * Mask generated with:\n   MASK[0] = 0x7000000000000000,\n   for(int i = 1; i < 21; ++i) {\n   MASK[i] = MASK[i-1] | (MASK[0] >> (i*3));\n   std::bitset<64> b(MASK[i]);\n   std::cout << std::hex << b.to_ullong() << std::endl;\n   }\n *\n*/\nconstexpr uint64_t MASK[] = {\n    0x7000000000000000,\n    0x7e00000000000000,\n    0x7fc0000000000000,\n    0x7ff8000000000000,\n    0x7fff000000000000,\n    0x7fffe00000000000,\n    0x7ffffc0000000000,\n    0x7fffff8000000000,\n    0x7ffffff000000000,\n    0x7ffffffe00000000,\n    0x7fffffffc0000000,\n    0x7ffffffff8000000,\n    0x7fffffffff000000,\n    0x7fffffffffe00000,\n    0x7ffffffffffc0000,\n    0x7fffffffffff8000,\n    0x7ffffffffffff000,\n    0x7ffffffffffffe00,\n    0x7fffffffffffffc0,\n    0x7ffffffffffffff8,\n    0x7fffffffffffffff};\n\ninline uint64_t expand(unsigned long long value)\n{\n    uint64_t x = value & 0x1fffff;\n    x = (x | x << 32) & 0x1f00000000ffff;\n    x = (x | x << 16) & 0x1f0000ff0000ff;\n    x = (x | x << 8) & 0x100f00f00f00f00f;\n    x = (x | x << 4) & 0x10c30c30c30c30c3;\n    x = (x | x << 2) & 0x1249249249249249;\n    return x;\n}\n\ninline uint64_t compact(uint64_t value)\n{\n    uint64_t x = value & 0x1249249249249249;\n    x = (x | x >> 2) & 0x10c30c30c30c30c3;\n    x = (x | x >> 4) & 0x100f00f00f00f00f;\n    x = (x | x >> 8) & 0x1f0000ff0000ff;\n    
x = (x | x >> 16) & 0x1f00000000ffff;\n    x = (x | x >> 32) & 0x1fffff;\n    return x;\n}\n\ninline uint64_t compute_morton(uint64_t x, uint64_t y, uint64_t z)\n{\n    uint64_t code = 0;\n\n    x = expand(x);\n    y = expand(y) << 1;\n    z = expand(z) << 2;\n\n    code = x | y | z;\n    return code;\n}\n\ninline Eigen::Vector3i decode(const uint64_t code)\n{\n    // explicit casts avoid narrowing-conversion errors in list-initialization\n    return {\n        static_cast<int>(compact(code >> 0ull)),\n        static_cast<int>(compact(code >> 1ull)),\n        static_cast<int>(compact(code >> 2ull))};\n}\n\ninline uint64_t encode(const int x, const int y, const int z)\n{\n    return (compute_morton(x, y, z) & MASK[MAX_BITS - 1]);\n}"
  },
  {
    "path": "third_party/sparse_octree/setup.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MIT license found in the\n# LICENSE file in the root directory of this source tree.\n\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CppExtension\nimport glob\n\n_ext_sources = glob.glob(\"src/*.cpp\")\n\nsetup(\n    name='svo',\n    ext_modules=[\n        CppExtension(\n            name='svo',\n            sources=_ext_sources,\n            include_dirs=[\"./include\"],\n            extra_compile_args={\n                \"cxx\": [\"-O2\", \"-I./include\"]\n            },\n        )\n    ],\n    cmdclass={\n        'build_ext': BuildExtension\n    }\n)\n"
  },
  {
    "path": "third_party/sparse_octree/src/bindings.cpp",
    "content": "#include \"../include/octree.h\"\n#include \"../include/test.h\"\n\nTORCH_LIBRARY(svo, m)\n{\n    m.def(\"encode\", &encode_torch);\n\n    m.class_<Octant>(\"Octant\")\n        .def(torch::init<>());\n\n    m.class_<Octree>(\"Octree\")\n        .def(torch::init<>())\n        .def(\"init\", &Octree::init)\n        .def(\"insert\", &Octree::insert)\n        .def(\"try_insert\", &Octree::try_insert)\n        .def(\"get_voxels\", &Octree::get_voxels)\n        .def(\"get_leaf_voxels\", &Octree::get_leaf_voxels)\n        .def(\"get_features\", &Octree::get_features)\n        .def(\"count_nodes\", &Octree::count_nodes)\n        .def(\"count_leaf_nodes\", &Octree::count_leaf_nodes)\n        .def(\"has_voxel\", &Octree::has_voxel)\n        .def(\"get_centres_and_children\", &Octree::get_centres_and_children)\n        .def_pickle(\n        // __getstate__\n        [](const c10::intrusive_ptr<Octree>& self) -> std::tuple<int64_t, int64_t, double, std::vector<torch::Tensor>> {\n            return std::make_tuple(self->size_, self->feat_dim_, self->voxel_size_, self->all_pts);\n        },\n        // __setstate__\n        [](std::tuple<int64_t, int64_t, double, std::vector<torch::Tensor>> state) { \n            return c10::make_intrusive<Octree>(std::get<0>(state), std::get<1>(state), std::get<2>(state), std::get<3>(state));\n        });\n}"
  },
  {
    "path": "third_party/sparse_octree/src/octree.cpp",
    "content": "#include \"../include/octree.h\"\n#include \"../include/utils.h\"\n#include <queue>\n#include <iostream>\n\n// #define MAX_HIT_VOXELS 10\n// #define MAX_NUM_VOXELS 10000\n\nint Octant::next_index_ = 0;\n// int Octree::feature_index = 0;\n\nint incr_x[8] = {0, 0, 0, 0, 1, 1, 1, 1};\nint incr_y[8] = {0, 0, 1, 1, 0, 0, 1, 1};\nint incr_z[8] = {0, 1, 0, 1, 0, 1, 0, 1};\n\nOctree::Octree()\n{\n}\n\nOctree::Octree(int64_t grid_dim, int64_t feat_dim, double voxel_size, std::vector<torch::Tensor> all_pts)\n{\n    Octant::next_index_ = 0;\n    init(grid_dim, feat_dim, voxel_size);\n    for (auto &pt : all_pts)\n    {\n        insert(pt);\n    }\n}\n\nOctree::~Octree()\n{\n}\n\nvoid Octree::init(int64_t grid_dim, int64_t feat_dim, double voxel_size)\n{\n    size_ = grid_dim;\n    feat_dim_ = feat_dim;\n    voxel_size_ = voxel_size;\n    max_level_ = log2(size_);\n    // root_ = std::make_shared<Octant>();\n    root_ = new Octant();\n    root_->side_ = size_;\n    // root_->depth_ = 0;\n    root_->is_leaf_ = false;\n\n    // feats_allocated_ = 0;\n    // auto options = torch::TensorOptions().requires_grad(true);\n    // feats_array_ = torch::randn({MAX_NUM_VOXELS, feat_dim}, options) * 0.01;\n}\n\nvoid Octree::insert(torch::Tensor pts)\n{\n    // temporal solution\n    all_pts.push_back(pts);\n\n    if (root_ == nullptr)\n    {\n        std::cout << \"Octree not initialized!\" << std::endl;\n    }\n\n    auto points = pts.accessor<int, 2>();\n    if (points.size(1) != 3)\n    {\n        std::cout << \"Point dimensions mismatch: inputs are \" << points.size(1) << \" expect 3\" << std::endl;\n        return;\n    }\n\n    for (int i = 0; i < points.size(0); ++i)\n    {\n        for (int j = 0; j < 8; ++j)\n        {\n            int x = points[i][0] + incr_x[j];\n            int y = points[i][1] + incr_y[j];\n            int z = points[i][2] + incr_z[j];\n            uint64_t key = encode(x, y, z);\n\n            all_keys.insert(key);\n\n            const 
unsigned int shift = MAX_BITS - max_level_ - 1;\n\n            auto n = root_;\n            unsigned edge = size_ / 2;\n            for (int d = 1; d <= max_level_; edge /= 2, ++d)\n            {\n                const int childid = ((x & edge) > 0) + 2 * ((y & edge) > 0) + 4 * ((z & edge) > 0);\n                // std::cout << \"Level: \" << d << \" ChildID: \" << childid << std::endl;\n                auto tmp = n->child(childid);\n                if (!tmp)\n                {\n                    const uint64_t code = key & MASK[d + shift];\n                    const bool is_leaf = (d == max_level_);\n                    // tmp = std::make_shared<Octant>();\n                    tmp = new Octant();\n                    tmp->code_ = code;\n                    tmp->side_ = edge;\n                    tmp->is_leaf_ = is_leaf;\n                    tmp->type_ = is_leaf ? (j == 0 ? SURFACE : FEATURE) : NONLEAF;\n\n                    n->children_mask_ = n->children_mask_ | (1 << childid);\n                    n->child(childid) = tmp;\n                }\n                else\n                {\n                    if (tmp->type_ == FEATURE && j == 0)\n                        tmp->type_ = SURFACE;\n                }\n                n = tmp;\n            }\n        }\n    }\n}\n\ndouble Octree::try_insert(torch::Tensor pts)\n{\n    if (root_ == nullptr)\n    {\n        std::cout << \"Octree not initialized!\" << std::endl;\n    }\n\n    auto points = pts.accessor<int, 2>();\n    if (points.size(1) != 3)\n    {\n        std::cout << \"Point dimensions mismatch: inputs are \" << points.size(1) << \" expect 3\" << std::endl;\n        return -1.0;\n    }\n\n    std::set<uint64_t> tmp_keys;\n\n    for (int i = 0; i < points.size(0); ++i)\n    {\n        for (int j = 0; j < 8; ++j)\n        {\n            int x = points[i][0] + incr_x[j];\n            int y = points[i][1] + incr_y[j];\n            int z = points[i][2] + incr_z[j];\n            uint64_t key = encode(x, y, z);\n\n 
           tmp_keys.insert(key);\n        }\n    }\n\n    // keys are 64-bit Morton codes; a std::set<int> would truncate them\n    std::set<uint64_t> result;\n    std::set_intersection(all_keys.begin(), all_keys.end(),\n                          tmp_keys.begin(), tmp_keys.end(),\n                          std::inserter(result, result.end()));\n\n    double overlap_ratio = 1.0 * result.size() / tmp_keys.size();\n    return overlap_ratio;\n}\n\nOctant *Octree::find_octant(std::vector<float> coord)\n{\n    int x = int(coord[0]);\n    int y = int(coord[1]);\n    int z = int(coord[2]);\n    // uint64_t key = encode(x, y, z);\n    // const unsigned int shift = MAX_BITS - max_level_ - 1;\n\n    auto n = root_;\n    unsigned edge = size_ / 2;\n    for (int d = 1; d <= max_level_; edge /= 2, ++d)\n    {\n        const int childid = ((x & edge) > 0) + 2 * ((y & edge) > 0) + 4 * ((z & edge) > 0);\n        auto tmp = n->child(childid);\n        if (!tmp)\n            return nullptr;\n\n        n = tmp;\n    }\n    return n;\n}\n\nbool Octree::has_voxel(torch::Tensor pts)\n{\n    if (root_ == nullptr)\n    {\n        std::cout << \"Octree not initialized!\" << std::endl;\n        return false;\n    }\n\n    auto points = pts.accessor<int, 1>();\n    if (points.size(0) != 3)\n    {\n        return false;\n    }\n\n    int x = int(points[0]);\n    int y = int(points[1]);\n    int z = int(points[2]);\n\n    auto n = root_;\n    unsigned edge = size_ / 2;\n    for (int d = 1; d <= max_level_; edge /= 2, ++d)\n    {\n        const int childid = ((x & edge) > 0) + 2 * ((y & edge) > 0) + 4 * ((z & edge) > 0);\n        auto tmp = n->child(childid);\n        if (!tmp)\n            return false;\n\n        n = tmp;\n    }\n\n    return n != nullptr;\n}\n\ntorch::Tensor Octree::get_features(torch::Tensor pts)\n{\n    // not implemented yet; return an empty tensor instead of falling off\n    // the end of a non-void function (undefined behaviour)\n    return torch::Tensor();\n}\n\ntorch::Tensor Octree::get_leaf_voxels()\n{\n    std::vector<float> voxel_coords = get_leaf_voxel_recursive(root_);\n\n    int N = voxel_coords.size() / 3;\n    torch::Tensor voxels = torch::from_blob(voxel_coords.data(), {N, 3});\n  
  return voxels.clone();\n}\n\nstd::vector<float> Octree::get_leaf_voxel_recursive(Octant *n)\n{\n    if (!n)\n        return std::vector<float>();\n\n    if (n->is_leaf_ && n->type_ == SURFACE)\n    {\n        auto xyz = decode(n->code_);\n        return {xyz[0], xyz[1], xyz[2]};\n    }\n\n    std::vector<float> coords;\n    for (int i = 0; i < 8; i++)\n    {\n        auto temp = get_leaf_voxel_recursive(n->child(i));\n        coords.insert(coords.end(), temp.begin(), temp.end());\n    }\n\n    return coords;\n}\n\ntorch::Tensor Octree::get_voxels()\n{\n    std::vector<float> voxel_coords = get_voxel_recursive(root_);\n    int N = voxel_coords.size() / 4;\n    auto options = torch::TensorOptions().dtype(torch::kFloat32);\n    torch::Tensor voxels = torch::from_blob(voxel_coords.data(), {N, 4}, options);\n    return voxels.clone();\n}\n\nstd::vector<float> Octree::get_voxel_recursive(Octant *n)\n{\n    if (!n)\n        return std::vector<float>();\n\n    auto xyz = decode(n->code_);\n    std::vector<float> coords = {xyz[0], xyz[1], xyz[2], float(n->side_)};\n    for (int i = 0; i < 8; i++)\n    {\n        auto temp = get_voxel_recursive(n->child(i));\n        coords.insert(coords.end(), temp.begin(), temp.end());\n    }\n\n    return coords;\n}\n\nstd::pair<int64_t, int64_t> Octree::count_nodes_internal()\n{\n    return count_recursive_internal(root_);\n}\n\n// int64_t Octree::leaves_count_recursive(std::shared_ptr<Octant> n)\nstd::pair<int64_t, int64_t> Octree::count_recursive_internal(Octant *n)\n{\n    if (!n)\n        return std::make_pair<int64_t, int64_t>(0, 0);\n\n    if (n->is_leaf_)\n        return std::make_pair<int64_t, int64_t>(1, 1);\n\n    auto sum = std::make_pair<int64_t, int64_t>(1, 0);\n\n    for (int i = 0; i < 8; i++)\n    {\n        auto temp = count_recursive_internal(n->child(i));\n        sum.first += temp.first;\n        sum.second += temp.second;\n    }\n\n    return sum;\n}\n\nstd::tuple<torch::Tensor, torch::Tensor, torch::Tensor> 
Octree::get_centres_and_children()\n{\n    auto node_count = count_nodes_internal();\n    auto total_count = node_count.first;\n    auto leaf_count = node_count.second;\n\n    auto all_voxels = torch::zeros({total_count, 4}, dtype(torch::kFloat32));\n    auto all_children = -torch::ones({total_count, 8}, dtype(torch::kFloat32));\n    auto all_features = -torch::ones({total_count, 8}, dtype(torch::kInt32));\n\n    std::queue<Octant *> all_nodes;\n    all_nodes.push(root_);\n\n    while (!all_nodes.empty())\n    {\n        auto node_ptr = all_nodes.front();\n        all_nodes.pop();\n\n        auto xyz = decode(node_ptr->code_);\n        std::vector<float> coords = {xyz[0], xyz[1], xyz[2], float(node_ptr->side_)};\n        auto voxel = torch::from_blob(coords.data(), {4}, dtype(torch::kFloat32));\n        all_voxels[node_ptr->index_] = voxel;\n\n        if (node_ptr->type_ == SURFACE)\n        {\n            for (int i = 0; i < 8; ++i)\n            {\n                std::vector<float> vcoords = coords;\n                vcoords[0] += incr_x[i];\n                vcoords[1] += incr_y[i];\n                vcoords[2] += incr_z[i];\n                auto voxel = find_octant(vcoords);\n                if (voxel)\n                    all_features.data_ptr<int>()[node_ptr->index_ * 8 + i] = voxel->index_;\n            }\n        }\n\n        for (int i = 0; i < 8; i++)\n        {\n            auto child_ptr = node_ptr->child(i);\n            if (child_ptr && child_ptr->type_ != FEATURE)\n            {\n                all_nodes.push(child_ptr);\n                all_children[node_ptr->index_][i] = float(child_ptr->index_);\n            }\n        }\n    }\n\n    return std::make_tuple(all_voxels, all_children, all_features);\n}\n\nint64_t Octree::count_nodes()\n{\n    return count_recursive(root_);\n}\n\n// int64_t Octree::leaves_count_recursive(std::shared_ptr<Octant> n)\nint64_t Octree::count_recursive(Octant *n)\n{\n    if (!n)\n        return 0;\n\n    int64_t sum = 1;\n\n 
   for (int i = 0; i < 8; i++)\n    {\n        sum += count_recursive(n->child(i));\n    }\n\n    return sum;\n}\n\nint64_t Octree::count_leaf_nodes()\n{\n    return leaves_count_recursive(root_);\n}\n\n// int64_t Octree::leaves_count_recursive(std::shared_ptr<Octant> n)\nint64_t Octree::leaves_count_recursive(Octant *n)\n{\n    if (!n)\n        return 0;\n\n    if (n->type_ == SURFACE)\n    {\n        return 1;\n    }\n\n    int64_t sum = 0;\n\n    for (int i = 0; i < 8; i++)\n    {\n        sum += leaves_count_recursive(n->child(i));\n    }\n\n    return sum;\n}\n"
  },
  {
    "path": "third_party/sparse_voxels/include/cuda_utils.h",
    "content": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in the\n// LICENSE file in the root directory of this source tree.\n\n#ifndef _CUDA_UTILS_H\n#define _CUDA_UTILS_H\n\n#include <ATen/ATen.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <cmath>\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n\n#include <vector>\n\n#define TOTAL_THREADS 512\n\ninline int opt_n_threads(int work_size)\n{\n  const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);\n\n  return max(min(1 << pow_2, TOTAL_THREADS), 1);\n}\n\ninline dim3 opt_block_config(int x, int y)\n{\n  const int x_threads = opt_n_threads(x);\n  const int y_threads =\n      max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);\n  dim3 block_config(x_threads, y_threads, 1);\n\n  return block_config;\n}\n\n#define CUDA_CHECK_ERRORS()                                           \\\n  do                                                                  \\\n  {                                                                   \\\n    cudaError_t err = cudaGetLastError();                             \\\n    if (cudaSuccess != err)                                           \\\n    {                                                                 \\\n      fprintf(stderr, \"CUDA kernel failed : %s\\n%s at L:%d in %s\\n\",  \\\n              cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \\\n              __FILE__);                                              \\\n      exit(-1);                                                       \\\n    }                                                                 \\\n  } while (0)\n\n#endif\n"
  },
  {
    "path": "third_party/sparse_voxels/include/cutil_math.h",
    "content": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in the\n// LICENSE file in the root directory of this source tree.\n\n/*\n * Copyright 1993-2009 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all intellectual property and \n * proprietary rights in and to this software and related documentation and \n * any modifications thereto.  Any use, reproduction, disclosure, or distribution \n * of this software and related documentation without an express license \n * agreement from NVIDIA Corporation is strictly prohibited.\n * \n */\n\n/*\n    This file implements common mathematical operations on vector types\n    (float3, float4 etc.) since these are not provided as standard by CUDA.\n\n    The syntax is modelled on the Cg standard library.\n*/\n\n#ifndef CUTIL_MATH_H\n#define CUTIL_MATH_H\n\n#include \"cuda_runtime.h\"\n\n////////////////////////////////////////////////////////////////////////////////\ntypedef unsigned int uint;\ntypedef unsigned short ushort;\n\n#ifndef __CUDACC__\n#include <math.h>\n\ninline float fminf(float a, float b)\n{\n    return a < b ? a : b;\n}\n\ninline float fmaxf(float a, float b)\n{\n    return a > b ? a : b;\n}\n\ninline int max(int a, int b)\n{\n    return a > b ? a : b;\n}\n\ninline int min(int a, int b)\n{\n    return a < b ? 
a : b;\n}\n\ninline float rsqrtf(float x)\n{\n    return 1.0f / sqrtf(x);\n}\n\n#endif\n\n// float functions\n////////////////////////////////////////////////////////////////////////////////\n\n// lerp\ninline __device__ __host__ float lerp(float a, float b, float t)\n{\n    return a + t * (b - a);\n}\n\n// clamp\ninline __device__ __host__ float clamp(float f, float a, float b)\n{\n    return fmaxf(a, fminf(f, b));\n}\n\ninline __device__ __host__ void swap(float &a, float &b)\n{\n    float c = a;\n    a = b;\n    b = c;\n}\n\ninline __device__ __host__ void swap(int &a, int &b)\n{\n    int c = a;\n    a = b;\n    b = c;\n}\n\n// int2 functions\n////////////////////////////////////////////////////////////////////////////////\n\n// negate\ninline __host__ __device__ int2 operator-(int2 &a)\n{\n    return make_int2(-a.x, -a.y);\n}\n\n// addition\ninline __host__ __device__ int2 operator+(int2 a, int2 b)\n{\n    return make_int2(a.x + b.x, a.y + b.y);\n}\ninline __host__ __device__ void operator+=(int2 &a, int2 b)\n{\n    a.x += b.x;\n    a.y += b.y;\n}\n\n// subtract\ninline __host__ __device__ int2 operator-(int2 a, int2 b)\n{\n    return make_int2(a.x - b.x, a.y - b.y);\n}\ninline __host__ __device__ void operator-=(int2 &a, int2 b)\n{\n    a.x -= b.x;\n    a.y -= b.y;\n}\n\n// multiply\ninline __host__ __device__ int2 operator*(int2 a, int2 b)\n{\n    return make_int2(a.x * b.x, a.y * b.y);\n}\ninline __host__ __device__ int2 operator*(int2 a, int s)\n{\n    return make_int2(a.x * s, a.y * s);\n}\ninline __host__ __device__ int2 operator*(int s, int2 a)\n{\n    return make_int2(a.x * s, a.y * s);\n}\ninline __host__ __device__ void operator*=(int2 &a, int s)\n{\n    a.x *= s;\n    a.y *= s;\n}\n\n// float2 functions\n////////////////////////////////////////////////////////////////////////////////\n\n// additional constructors\ninline __host__ __device__ float2 make_float2(float s)\n{\n    return make_float2(s, s);\n}\ninline __host__ __device__ float2 make_float2(int2 a)\n{\n    return make_float2(float(a.x), float(a.y));\n}\n\n// negate\ninline __host__ __device__ float2 operator-(float2 &a)\n{\n    return make_float2(-a.x, -a.y);\n}\n\n// addition\ninline __host__ __device__ float2 operator+(float2 a, float2 b)\n{\n    return make_float2(a.x + b.x, a.y + b.y);\n}\ninline __host__ __device__ void operator+=(float2 &a, float2 b)\n{\n    a.x += b.x;\n    a.y += b.y;\n}\n\n// subtract\ninline __host__ __device__ float2 operator-(float2 a, float2 b)\n{\n    return make_float2(a.x - b.x, a.y - b.y);\n}\ninline __host__ __device__ void operator-=(float2 &a, float2 b)\n{\n    a.x -= b.x;\n    a.y -= b.y;\n}\n\n// multiply\ninline __host__ __device__ float2 operator*(float2 a, float2 b)\n{\n    return make_float2(a.x * b.x, a.y * b.y);\n}\ninline __host__ __device__ float2 operator*(float2 a, float s)\n{\n    return make_float2(a.x * s, a.y * s);\n}\ninline __host__ __device__ float2 operator*(float s, float2 a)\n{\n    return make_float2(a.x * s, a.y * s);\n}\ninline __host__ __device__ void operator*=(float2 &a, float s)\n{\n    a.x *= s;\n    a.y *= s;\n}\n\n// divide\ninline __host__ __device__ float2 operator/(float2 a, float2 b)\n{\n    return make_float2(a.x / b.x, a.y / b.y);\n}\ninline __host__ __device__ float2 operator/(float2 a, float s)\n{\n    float inv = 1.0f / s;\n    return a * inv;\n}\ninline __host__ __device__ float2 operator/(float s, float2 a)\n{\n    // component-wise scalar-over-vector division\n    return make_float2(s / a.x, s / a.y);\n}\ninline __host__ __device__ void operator/=(float2 &a, float s)\n{\n    float inv = 1.0f / s;\n    a *= inv;\n}\n\n// lerp\ninline __device__ __host__ float2 lerp(float2 a, float2 b, float t)\n{\n    return a + t * (b - a);\n}\n\n// clamp\ninline __device__ __host__ float2 clamp(float2 v, float a, float b)\n{\n    return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));\n}\n\ninline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)\n{\n    return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, 
a.y, b.y));\n}\n\n// dot product\ninline __host__ __device__ float dot(float2 a, float2 b)\n{\n    return a.x * b.x + a.y * b.y;\n}\n\n// length\ninline __host__ __device__ float length(float2 v)\n{\n    return sqrtf(dot(v, v));\n}\n\n// normalize\ninline __host__ __device__ float2 normalize(float2 v)\n{\n    float invLen = rsqrtf(dot(v, v));\n    return v * invLen;\n}\n\n// floor\ninline __host__ __device__ float2 floor(const float2 v)\n{\n    return make_float2(floor(v.x), floor(v.y));\n}\n\n// reflect\ninline __host__ __device__ float2 reflect(float2 i, float2 n)\n{\n    return i - 2.0f * n * dot(n, i);\n}\n\n// absolute value\ninline __host__ __device__ float2 fabs(float2 v)\n{\n    return make_float2(fabs(v.x), fabs(v.y));\n}\n\n// float3 functions\n////////////////////////////////////////////////////////////////////////////////\n\n// additional constructors\ninline __host__ __device__ float3 make_float3(float s)\n{\n    return make_float3(s, s, s);\n}\ninline __host__ __device__ float3 make_float3(float2 a)\n{\n    return make_float3(a.x, a.y, 0.0f);\n}\ninline __host__ __device__ float3 make_float3(float2 a, float s)\n{\n    return make_float3(a.x, a.y, s);\n}\ninline __host__ __device__ float3 make_float3(float4 a)\n{\n    return make_float3(a.x, a.y, a.z); // discards w\n}\ninline __host__ __device__ float3 make_float3(int3 a)\n{\n    return make_float3(float(a.x), float(a.y), float(a.z));\n}\n\n// negate\ninline __host__ __device__ float3 operator-(float3 &a)\n{\n    return make_float3(-a.x, -a.y, -a.z);\n}\n\n// min\nstatic __inline__ __host__ __device__ float3 fminf(float3 a, float3 b)\n{\n    return make_float3(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z));\n}\n\n// max\nstatic __inline__ __host__ __device__ float3 fmaxf(float3 a, float3 b)\n{\n    return make_float3(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z));\n}\n\n// addition\ninline __host__ __device__ float3 operator+(float3 a, float3 b)\n{\n    return make_float3(a.x + b.x, a.y + b.y, 
a.z + b.z);\n}\ninline __host__ __device__ float3 operator+(float3 a, float b)\n{\n    return make_float3(a.x + b, a.y + b, a.z + b);\n}\ninline __host__ __device__ void operator+=(float3 &a, float3 b)\n{\n    a.x += b.x;\n    a.y += b.y;\n    a.z += b.z;\n}\n\n// subtract\ninline __host__ __device__ float3 operator-(float3 a, float3 b)\n{\n    return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);\n}\ninline __host__ __device__ float3 operator-(float3 a, float b)\n{\n    return make_float3(a.x - b, a.y - b, a.z - b);\n}\ninline __host__ __device__ void operator-=(float3 &a, float3 b)\n{\n    a.x -= b.x;\n    a.y -= b.y;\n    a.z -= b.z;\n}\n\n// multiply\ninline __host__ __device__ float3 operator*(float3 a, float3 b)\n{\n    return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);\n}\ninline __host__ __device__ float3 operator*(float3 a, float s)\n{\n    return make_float3(a.x * s, a.y * s, a.z * s);\n}\ninline __host__ __device__ float3 operator*(float s, float3 a)\n{\n    return make_float3(a.x * s, a.y * s, a.z * s);\n}\ninline __host__ __device__ void operator*=(float3 &a, float s)\n{\n    a.x *= s;\n    a.y *= s;\n    a.z *= s;\n}\ninline __host__ __device__ void operator*=(float3 &a, float3 b)\n{\n    a.x *= b.x;\n    a.y *= b.y;\n    a.z *= b.z;\n}\n\n// divide\ninline __host__ __device__ float3 operator/(float3 a, float3 b)\n{\n    return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);\n}\ninline __host__ __device__ float3 operator/(float3 a, float s)\n{\n    float inv = 1.0f / s;\n    return a * inv;\n}\ninline __host__ __device__ float3 operator/(float s, float3 a)\n{\n    float inv = 1.0f / s;\n    return a * inv;\n}\ninline __host__ __device__ void operator/=(float3 &a, float s)\n{\n    float inv = 1.0f / s;\n    a *= inv;\n}\n\n// lerp\ninline __device__ __host__ float3 lerp(float3 a, float3 b, float t)\n{\n    return a + t * (b - a);\n}\n\n// clamp\ninline __device__ __host__ float3 clamp(float3 v, float a, float b)\n{\n    return 
make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));\n}\n\ninline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)\n{\n    return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));\n}\n\n// dot product\ninline __host__ __device__ float dot(float3 a, float3 b)\n{\n    return a.x * b.x + a.y * b.y + a.z * b.z;\n}\n\n// cross product\ninline __host__ __device__ float3 cross(float3 a, float3 b)\n{\n    return make_float3(a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x);\n}\n\n// length\ninline __host__ __device__ float length(float3 v)\n{\n    return sqrtf(dot(v, v));\n}\n\n// normalize\ninline __host__ __device__ float3 normalize(float3 v)\n{\n    float invLen = rsqrtf(dot(v, v));\n    return v * invLen;\n}\n\n// floor\ninline __host__ __device__ float3 floor(const float3 v)\n{\n    return make_float3(floor(v.x), floor(v.y), floor(v.z));\n}\n\n// reflect\ninline __host__ __device__ float3 reflect(float3 i, float3 n)\n{\n    return i - 2.0f * n * dot(n, i);\n}\n\n// absolute value\ninline __host__ __device__ float3 fabs(float3 v)\n{\n    return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));\n}\n\n// float4 functions\n////////////////////////////////////////////////////////////////////////////////\n\n// additional constructors\ninline __host__ __device__ float4 make_float4(float s)\n{\n    return make_float4(s, s, s, s);\n}\ninline __host__ __device__ float4 make_float4(float3 a)\n{\n    return make_float4(a.x, a.y, a.z, 0.0f);\n}\ninline __host__ __device__ float4 make_float4(float3 a, float w)\n{\n    return make_float4(a.x, a.y, a.z, w);\n}\ninline __host__ __device__ float4 make_float4(int4 a)\n{\n    return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));\n}\n\n// negate\ninline __host__ __device__ float4 operator-(float4 &a)\n{\n    return make_float4(-a.x, -a.y, -a.z, -a.w);\n}\n\n// min\nstatic __inline__ __host__ __device__ float4 fminf(float4 a, float4 b)\n{\n    return 
make_float4(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z), fminf(a.w, b.w));\n}\n\n// max\nstatic __inline__ __host__ __device__ float4 fmaxf(float4 a, float4 b)\n{\n    return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z), fmaxf(a.w, b.w));\n}\n\n// addition\ninline __host__ __device__ float4 operator+(float4 a, float4 b)\n{\n    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);\n}\ninline __host__ __device__ void operator+=(float4 &a, float4 b)\n{\n    a.x += b.x;\n    a.y += b.y;\n    a.z += b.z;\n    a.w += b.w;\n}\n\n// subtract\ninline __host__ __device__ float4 operator-(float4 a, float4 b)\n{\n    return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);\n}\ninline __host__ __device__ void operator-=(float4 &a, float4 b)\n{\n    a.x -= b.x;\n    a.y -= b.y;\n    a.z -= b.z;\n    a.w -= b.w;\n}\n\n// multiply\ninline __host__ __device__ float4 operator*(float4 a, float s)\n{\n    return make_float4(a.x * s, a.y * s, a.z * s, a.w * s);\n}\ninline __host__ __device__ float4 operator*(float s, float4 a)\n{\n    return make_float4(a.x * s, a.y * s, a.z * s, a.w * s);\n}\ninline __host__ __device__ void operator*=(float4 &a, float s)\n{\n    a.x *= s;\n    a.y *= s;\n    a.z *= s;\n    a.w *= s;\n}\n\n// divide\ninline __host__ __device__ float4 operator/(float4 a, float4 b)\n{\n    return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);\n}\ninline __host__ __device__ float4 operator/(float4 a, float s)\n{\n    float inv = 1.0f / s;\n    return a * inv;\n}\ninline __host__ __device__ float4 operator/(float s, float4 a)\n{\n    float inv = 1.0f / s;\n    return a * inv;\n}\ninline __host__ __device__ void operator/=(float4 &a, float s)\n{\n    float inv = 1.0f / s;\n    a *= inv;\n}\n\n// lerp\ninline __device__ __host__ float4 lerp(float4 a, float4 b, float t)\n{\n    return a + t * (b - a);\n}\n\n// clamp\ninline __device__ __host__ float4 clamp(float4 v, float a, float b)\n{\n    return make_float4(clamp(v.x, a, 
b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));\n}\n\ninline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)\n{\n    return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));\n}\n\n// dot product\ninline __host__ __device__ float dot(float4 a, float4 b)\n{\n    return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;\n}\n\n// length\ninline __host__ __device__ float length(float4 r)\n{\n    return sqrtf(dot(r, r));\n}\n\n// normalize\ninline __host__ __device__ float4 normalize(float4 v)\n{\n    float invLen = rsqrtf(dot(v, v));\n    return v * invLen;\n}\n\n// floor\ninline __host__ __device__ float4 floor(const float4 v)\n{\n    return make_float4(floor(v.x), floor(v.y), floor(v.z), floor(v.w));\n}\n\n// absolute value\ninline __host__ __device__ float4 fabs(float4 v)\n{\n    return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));\n}\n\n// int3 functions\n////////////////////////////////////////////////////////////////////////////////\n\n// additional constructors\ninline __host__ __device__ int3 make_int3(int s)\n{\n    return make_int3(s, s, s);\n}\ninline __host__ __device__ int3 make_int3(float3 a)\n{\n    return make_int3(int(a.x), int(a.y), int(a.z));\n}\n\n// negate\ninline __host__ __device__ int3 operator-(int3 &a)\n{\n    return make_int3(-a.x, -a.y, -a.z);\n}\n\n// min\ninline __host__ __device__ int3 min(int3 a, int3 b)\n{\n    return make_int3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z));\n}\n\n// max\ninline __host__ __device__ int3 max(int3 a, int3 b)\n{\n    return make_int3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z));\n}\n\n// addition\ninline __host__ __device__ int3 operator+(int3 a, int3 b)\n{\n    return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);\n}\ninline __host__ __device__ void operator+=(int3 &a, int3 b)\n{\n    a.x += b.x;\n    a.y += b.y;\n    a.z += b.z;\n}\n\n// subtract\ninline __host__ __device__ int3 operator-(int3 a, int3 b)\n{\n    return 
make_int3(a.x - b.x, a.y - b.y, a.z - b.z);\n}\n\ninline __host__ __device__ void operator-=(int3 &a, int3 b)\n{\n    a.x -= b.x;\n    a.y -= b.y;\n    a.z -= b.z;\n}\n\n// multiply\ninline __host__ __device__ int3 operator*(int3 a, int3 b)\n{\n    return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);\n}\ninline __host__ __device__ int3 operator*(int3 a, int s)\n{\n    return make_int3(a.x * s, a.y * s, a.z * s);\n}\ninline __host__ __device__ int3 operator*(int s, int3 a)\n{\n    return make_int3(a.x * s, a.y * s, a.z * s);\n}\ninline __host__ __device__ void operator*=(int3 &a, int s)\n{\n    a.x *= s;\n    a.y *= s;\n    a.z *= s;\n}\n\n// divide\ninline __host__ __device__ int3 operator/(int3 a, int3 b)\n{\n    return make_int3(a.x / b.x, a.y / b.y, a.z / b.z);\n}\ninline __host__ __device__ int3 operator/(int3 a, int s)\n{\n    return make_int3(a.x / s, a.y / s, a.z / s);\n}\ninline __host__ __device__ int3 operator/(int s, int3 a)\n{\n    return make_int3(a.x / s, a.y / s, a.z / s);\n}\ninline __host__ __device__ void operator/=(int3 &a, int s)\n{\n    a.x /= s;\n    a.y /= s;\n    a.z /= s;\n}\n\n// clamp\ninline __device__ __host__ int clamp(int f, int a, int b)\n{\n    return max(a, min(f, b));\n}\n\ninline __device__ __host__ int3 clamp(int3 v, int a, int b)\n{\n    return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));\n}\n\ninline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)\n{\n    return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));\n}\n\n// uint3 functions\n////////////////////////////////////////////////////////////////////////////////\n\n// additional constructors\ninline __host__ __device__ uint3 make_uint3(uint s)\n{\n    return make_uint3(s, s, s);\n}\ninline __host__ __device__ uint3 make_uint3(float3 a)\n{\n    return make_uint3(uint(a.x), uint(a.y), uint(a.z));\n}\n\n// min\ninline __host__ __device__ uint3 min(uint3 a, uint3 b)\n{\n    return make_uint3(min(a.x, b.x), min(a.y, b.y), 
min(a.z, b.z));\n}\n\n// max\ninline __host__ __device__ uint3 max(uint3 a, uint3 b)\n{\n    return make_uint3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z));\n}\n\n// addition\ninline __host__ __device__ uint3 operator+(uint3 a, uint3 b)\n{\n    return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);\n}\ninline __host__ __device__ void operator+=(uint3 &a, uint3 b)\n{\n    a.x += b.x;\n    a.y += b.y;\n    a.z += b.z;\n}\n\n// subtract\ninline __host__ __device__ uint3 operator-(uint3 a, uint3 b)\n{\n    return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);\n}\n\ninline __host__ __device__ void operator-=(uint3 &a, uint3 b)\n{\n    a.x -= b.x;\n    a.y -= b.y;\n    a.z -= b.z;\n}\n\n// multiply\ninline __host__ __device__ uint3 operator*(uint3 a, uint3 b)\n{\n    return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);\n}\ninline __host__ __device__ uint3 operator*(uint3 a, uint s)\n{\n    return make_uint3(a.x * s, a.y * s, a.z * s);\n}\ninline __host__ __device__ uint3 operator*(uint s, uint3 a)\n{\n    return make_uint3(a.x * s, a.y * s, a.z * s);\n}\ninline __host__ __device__ void operator*=(uint3 &a, uint s)\n{\n    a.x *= s;\n    a.y *= s;\n    a.z *= s;\n}\n\n// divide\ninline __host__ __device__ uint3 operator/(uint3 a, uint3 b)\n{\n    return make_uint3(a.x / b.x, a.y / b.y, a.z / b.z);\n}\ninline __host__ __device__ uint3 operator/(uint3 a, uint s)\n{\n    return make_uint3(a.x / s, a.y / s, a.z / s);\n}\ninline __host__ __device__ uint3 operator/(uint s, uint3 a)\n{\n    return make_uint3(a.x / s, a.y / s, a.z / s);\n}\ninline __host__ __device__ void operator/=(uint3 &a, uint s)\n{\n    a.x /= s;\n    a.y /= s;\n    a.z /= s;\n}\n\n// clamp\ninline __device__ __host__ uint clamp(uint f, uint a, uint b)\n{\n    return max(a, min(f, b));\n}\n\ninline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)\n{\n    return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));\n}\n\ninline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)\n{\n    
return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));\n}\n\n#endif"
  },
  {
    "path": "third_party/sparse_voxels/include/intersect.h",
    "content": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in the\n// LICENSE file in the root directory of this source tree.\n\n#pragma once\n#include <torch/extension.h>\n#include <utility>\n\nstd::tuple<at::Tensor, at::Tensor, at::Tensor> ball_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points,\n                                                              const float radius, const int n_max);\nstd::tuple<at::Tensor, at::Tensor, at::Tensor> aabb_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points,\n                                                              const float voxelsize, const int n_max);\nstd::tuple<at::Tensor, at::Tensor, at::Tensor> svo_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, at::Tensor children,\n                                                             const float voxelsize, const int n_max);\nstd::tuple<at::Tensor, at::Tensor, at::Tensor> triangle_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor face_points,\n                                                                  const float cagesize, const float blur, const int n_max);\n"
  },
  {
    "path": "third_party/sparse_voxels/include/octree.h",
    "content": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in the\n// LICENSE file in the root directory of this source tree.\n\n#pragma once\n#include <torch/extension.h>\n#include <utility>\n\nstd::tuple<at::Tensor, at::Tensor> build_octree(at::Tensor center, at::Tensor points, int depth);"
  },
  {
    "path": "third_party/sparse_voxels/include/sample.h",
    "content": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in the\n// LICENSE file in the root directory of this source tree.\n\n#pragma once\n#include <torch/extension.h>\n#include <utility>\n\nstd::tuple<at::Tensor, at::Tensor, at::Tensor> uniform_ray_sampling(\n    at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise,\n    const float step_size, const int max_steps);\nstd::tuple<at::Tensor, at::Tensor, at::Tensor> inverse_cdf_sampling(\n    at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise,\n    at::Tensor probs, at::Tensor steps, float fixed_step_size);"
  },
  {
    "path": "third_party/sparse_voxels/include/utils.h",
    "content": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in the\n// LICENSE file in the root directory of this source tree.\n\n#pragma once\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/extension.h>\n\n#define CHECK_CUDA(x)                                             \\\n  do                                                              \\\n  {                                                               \\\n    TORCH_CHECK(x.type().is_cuda(), #x \" must be a CUDA tensor\"); \\\n  } while (0)\n\n#define CHECK_CONTIGUOUS(x)                                            \\\n  do                                                                   \\\n  {                                                                    \\\n    TORCH_CHECK(x.is_contiguous(), #x \" must be a contiguous tensor\"); \\\n  } while (0)\n\n#define CHECK_IS_INT(x)                                 \\\n  do                                                    \\\n  {                                                     \\\n    TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \\\n                #x \" must be an int tensor\");           \\\n  } while (0)\n\n#define CHECK_IS_FLOAT(x)                                 \\\n  do                                                      \\\n  {                                                       \\\n    TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \\\n                #x \" must be a float tensor\");            \\\n  } while (0)\n"
  },
  {
    "path": "third_party/sparse_voxels/setup.py",
    "content": "# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MIT license found in the\n# LICENSE file in the root directory of this source tree.\n\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\nimport glob\n\n_ext_sources = glob.glob(\"src/*.cpp\") + glob.glob(\"src/*.cu\")\n\nsetup(\n    name='grid',\n    ext_modules=[\n        CUDAExtension(\n            name='grid',\n            sources=_ext_sources,\n            include_dirs=[\"./include\"],\n            extra_compile_args={\n                \"cxx\": [\"-O2\", \"-I./include\"],\n                \"nvcc\": [\"-O2\", \"-I./include\"],\n            },\n        )\n    ],\n    cmdclass={\n        'build_ext': BuildExtension\n    }\n)\n"
  },
  {
    "path": "third_party/sparse_voxels/src/binding.cpp",
    "content": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in the\n// LICENSE file in the root directory of this source tree.\n\n#include \"../include/intersect.h\"\n#include \"../include/octree.h\"\n#include \"../include/sample.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m)\n{\n  m.def(\"ball_intersect\", &ball_intersect);\n  m.def(\"aabb_intersect\", &aabb_intersect);\n  m.def(\"svo_intersect\", &svo_intersect);\n  m.def(\"triangle_intersect\", &triangle_intersect);\n\n  m.def(\"uniform_ray_sampling\", &uniform_ray_sampling);\n  m.def(\"inverse_cdf_sampling\", &inverse_cdf_sampling);\n\n  m.def(\"build_octree\", &build_octree);\n}"
  },
  {
    "path": "third_party/sparse_voxels/src/intersect.cpp",
    "content": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in the\n// LICENSE file in the root directory of this source tree.\n\n#include \"../include/intersect.h\"\n#include \"../include/utils.h\"\n#include <utility>\n\nvoid ball_intersect_point_kernel_wrapper(\n    int b, int n, int m, float radius, int n_max,\n    const float *ray_start, const float *ray_dir, const float *points,\n    int *idx, float *min_depth, float *max_depth);\n\nstd::tuple<at::Tensor, at::Tensor, at::Tensor> ball_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points,\n                                                              const float radius, const int n_max)\n{\n  CHECK_CONTIGUOUS(ray_start);\n  CHECK_CONTIGUOUS(ray_dir);\n  CHECK_CONTIGUOUS(points);\n  CHECK_IS_FLOAT(ray_start);\n  CHECK_IS_FLOAT(ray_dir);\n  CHECK_IS_FLOAT(points);\n  CHECK_CUDA(ray_start);\n  CHECK_CUDA(ray_dir);\n  CHECK_CUDA(points);\n\n  at::Tensor idx =\n      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},\n                   at::device(ray_start.device()).dtype(at::ScalarType::Int));\n  at::Tensor min_depth =\n      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},\n                   at::device(ray_start.device()).dtype(at::ScalarType::Float));\n  at::Tensor max_depth =\n      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},\n                   at::device(ray_start.device()).dtype(at::ScalarType::Float));\n  ball_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1),\n                                      radius, n_max,\n                                      ray_start.data_ptr<float>(), ray_dir.data_ptr<float>(), points.data_ptr<float>(),\n                                      idx.data_ptr<int>(), min_depth.data_ptr<float>(), max_depth.data_ptr<float>());\n  return std::make_tuple(idx, min_depth, max_depth);\n}\n\nvoid aabb_intersect_point_kernel_wrapper(\n    
int b, int n, int m, float voxelsize, int n_max,\n    const float *ray_start, const float *ray_dir, const float *points,\n    int *idx, float *min_depth, float *max_depth);\n\nstd::tuple<at::Tensor, at::Tensor, at::Tensor> aabb_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points,\n                                                              const float voxelsize, const int n_max)\n{\n  CHECK_CONTIGUOUS(ray_start);\n  CHECK_CONTIGUOUS(ray_dir);\n  CHECK_CONTIGUOUS(points);\n  CHECK_IS_FLOAT(ray_start);\n  CHECK_IS_FLOAT(ray_dir);\n  CHECK_IS_FLOAT(points);\n  CHECK_CUDA(ray_start);\n  CHECK_CUDA(ray_dir);\n  CHECK_CUDA(points);\n\n  at::Tensor idx =\n      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},\n                   at::device(ray_start.device()).dtype(at::ScalarType::Int));\n  at::Tensor min_depth =\n      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},\n                   at::device(ray_start.device()).dtype(at::ScalarType::Float));\n  at::Tensor max_depth =\n      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},\n                   at::device(ray_start.device()).dtype(at::ScalarType::Float));\n  aabb_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1),\n                                      voxelsize, n_max,\n                                      ray_start.data_ptr<float>(), ray_dir.data_ptr<float>(), points.data_ptr<float>(),\n                                      idx.data_ptr<int>(), min_depth.data_ptr<float>(), max_depth.data_ptr<float>());\n  return std::make_tuple(idx, min_depth, max_depth);\n}\n\nvoid svo_intersect_point_kernel_wrapper(\n    int b, int n, int m, float voxelsize, int n_max,\n    const float *ray_start, const float *ray_dir, const float *points, const int *children,\n    int *idx, float *min_depth, float *max_depth);\n\nstd::tuple<at::Tensor, at::Tensor, at::Tensor> svo_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points,\n              
                                               at::Tensor children, const float voxelsize, const int n_max)\n{\n  CHECK_CONTIGUOUS(ray_start);\n  CHECK_CONTIGUOUS(ray_dir);\n  CHECK_CONTIGUOUS(points);\n  CHECK_CONTIGUOUS(children);\n  CHECK_IS_FLOAT(ray_start);\n  CHECK_IS_FLOAT(ray_dir);\n  CHECK_IS_FLOAT(points);\n  CHECK_CUDA(ray_start);\n  CHECK_CUDA(ray_dir);\n  CHECK_CUDA(points);\n  CHECK_CUDA(children);\n\n  at::Tensor idx =\n      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},\n                   at::device(ray_start.device()).dtype(at::ScalarType::Int));\n  at::Tensor min_depth =\n      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},\n                   at::device(ray_start.device()).dtype(at::ScalarType::Float));\n  at::Tensor max_depth =\n      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},\n                   at::device(ray_start.device()).dtype(at::ScalarType::Float));\n  svo_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1),\n                                     voxelsize, n_max,\n                                     ray_start.data_ptr<float>(), ray_dir.data_ptr<float>(), points.data_ptr<float>(),\n                                     children.data_ptr<int>(), idx.data_ptr<int>(), min_depth.data_ptr<float>(), max_depth.data_ptr<float>());\n  return std::make_tuple(idx, min_depth, max_depth);\n}\n\nvoid triangle_intersect_point_kernel_wrapper(\n    int b, int n, int m, float cagesize, float blur, int n_max,\n    const float *ray_start, const float *ray_dir, const float *face_points,\n    int *idx, float *depth, float *uv);\n\nstd::tuple<at::Tensor, at::Tensor, at::Tensor> triangle_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor face_points,\n                                                                  const float cagesize, const float blur, const int n_max)\n{\n  CHECK_CONTIGUOUS(ray_start);\n  CHECK_CONTIGUOUS(ray_dir);\n  CHECK_CONTIGUOUS(face_points);\n  
CHECK_IS_FLOAT(ray_start);\n  CHECK_IS_FLOAT(ray_dir);\n  CHECK_IS_FLOAT(face_points);\n  CHECK_CUDA(ray_start);\n  CHECK_CUDA(ray_dir);\n  CHECK_CUDA(face_points);\n\n  at::Tensor idx =\n      torch::zeros({ray_start.size(0), ray_start.size(1), n_max},\n                   at::device(ray_start.device()).dtype(at::ScalarType::Int));\n  at::Tensor depth =\n      torch::zeros({ray_start.size(0), ray_start.size(1), n_max * 3},\n                   at::device(ray_start.device()).dtype(at::ScalarType::Float));\n  at::Tensor uv =\n      torch::zeros({ray_start.size(0), ray_start.size(1), n_max * 2},\n                   at::device(ray_start.device()).dtype(at::ScalarType::Float));\n  triangle_intersect_point_kernel_wrapper(face_points.size(0), face_points.size(1), ray_start.size(1),\n                                          cagesize, blur, n_max,\n                                          ray_start.data_ptr<float>(), ray_dir.data_ptr<float>(), face_points.data_ptr<float>(),\n                                          idx.data_ptr<int>(), depth.data_ptr<float>(), uv.data_ptr<float>());\n  return std::make_tuple(idx, depth, uv);\n}\n"
  },
  {
    "path": "third_party/sparse_voxels/src/intersect_gpu.cu",
    "content": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in the\n// LICENSE file in the root directory of this source tree.\n\n\n#include <math.h>\n#include <stdio.h>\n#include <stdlib.h>\n\n#include \"../include/cuda_utils.h\"\n#include \"../include/cutil_math.h\" // required for float3 vector math\n\n\n__global__ void ball_intersect_point_kernel(\n    int b, int n, int m, float radius,\n    int n_max,\n    const float *__restrict__ ray_start,\n    const float *__restrict__ ray_dir,\n    const float *__restrict__ points,\n    int *__restrict__ idx,\n    float *__restrict__ min_depth,\n    float *__restrict__ max_depth)\n{\n\n  int batch_index = blockIdx.x;\n  points += batch_index * n * 3;\n  ray_start += batch_index * m * 3;\n  ray_dir += batch_index * m * 3;\n  idx += batch_index * m * n_max;\n  min_depth += batch_index * m * n_max;\n  max_depth += batch_index * m * n_max;\n\n  int index = threadIdx.x;\n  int stride = blockDim.x;\n  float radius2 = radius * radius;\n\n  for (int j = index; j < m; j += stride)\n  {\n\n    float x0 = ray_start[j * 3 + 0];\n    float y0 = ray_start[j * 3 + 1];\n    float z0 = ray_start[j * 3 + 2];\n    float xw = ray_dir[j * 3 + 0];\n    float yw = ray_dir[j * 3 + 1];\n    float zw = ray_dir[j * 3 + 2];\n\n    for (int l = 0; l < n_max; ++l)\n    {\n      idx[j * n_max + l] = -1;\n    }\n\n    for (int k = 0, cnt = 0; k < n && cnt < n_max; ++k)\n    {\n      float x = points[k * 3 + 0] - x0;\n      float y = points[k * 3 + 1] - y0;\n      float z = points[k * 3 + 2] - z0;\n      float d2 = x * x + y * y + z * z;\n      float d2_proj = pow(x * xw + y * yw + z * zw, 2);\n      float r2 = d2 - d2_proj;\n\n      if (r2 < radius2)\n      {\n        idx[j * n_max + cnt] = k;\n\n        float depth = sqrt(d2_proj);\n        float depth_blur = sqrt(radius2 - r2);\n\n        min_depth[j * n_max + cnt] = depth - depth_blur;\n        max_depth[j * n_max + cnt] = depth 
+ depth_blur;\n        ++cnt;\n      }\n    }\n  }\n}\n\n__device__ float2 RayAABBIntersection(\n    const float3 &ori,\n    const float3 &dir,\n    const float3 &center,\n    float half_voxel)\n{\n\n  float f_low = 0;\n  float f_high = 100000.;\n  float f_dim_low, f_dim_high, temp, inv_ray_dir, start, aabb;\n\n  for (int d = 0; d < 3; ++d)\n  {\n    switch (d)\n    {\n    case 0:\n      inv_ray_dir = __fdividef(1.0f, dir.x);\n      start = ori.x;\n      aabb = center.x;\n      break;\n    case 1:\n      inv_ray_dir = __fdividef(1.0f, dir.y);\n      start = ori.y;\n      aabb = center.y;\n      break;\n    case 2:\n      inv_ray_dir = __fdividef(1.0f, dir.z);\n      start = ori.z;\n      aabb = center.z;\n      break;\n    }\n\n    f_dim_low = (aabb - half_voxel - start) * inv_ray_dir;\n    f_dim_high = (aabb + half_voxel - start) * inv_ray_dir;\n\n    // Make sure low is less than high\n    if (f_dim_high < f_dim_low)\n    {\n      temp = f_dim_low;\n      f_dim_low = f_dim_high;\n      f_dim_high = temp;\n    }\n\n    // If this dimension's high is less than the low we got then we definitely missed.\n    if (f_dim_high < f_low)\n    {\n      return make_float2(-1.0f, -1.0f);\n    }\n\n    // Likewise, if this dimension's low is greater than the high we got, we missed.\n    if (f_dim_low > f_high)\n    {\n      return make_float2(-1.0f, -1.0f);\n    }\n\n    // Add the clip from this dimension to the previous results\n    f_low = (f_dim_low > f_low) ? f_dim_low : f_low;\n    f_high = (f_dim_high < f_high) ? 
f_dim_high : f_high;\n\n    if (f_low > f_high)\n    {\n      return make_float2(-1.0f, -1.0f);\n    }\n  }\n  return make_float2(f_low, f_high);\n}\n\n__global__ void aabb_intersect_point_kernel(\n    int b, int n, int m, float voxelsize,\n    int n_max,\n    const float *__restrict__ ray_start,\n    const float *__restrict__ ray_dir,\n    const float *__restrict__ points,\n    int *__restrict__ idx,\n    float *__restrict__ min_depth,\n    float *__restrict__ max_depth)\n{\n\n  int batch_index = blockIdx.x;\n  points += batch_index * n * 3;\n  ray_start += batch_index * m * 3;\n  ray_dir += batch_index * m * 3;\n  idx += batch_index * m * n_max;\n  min_depth += batch_index * m * n_max;\n  max_depth += batch_index * m * n_max;\n\n  int index = threadIdx.x;\n  int stride = blockDim.x;\n  float half_voxel = voxelsize * 0.5;\n\n  for (int j = index; j < m; j += stride)\n  {\n    for (int l = 0; l < n_max; ++l)\n    {\n      idx[j * n_max + l] = -1;\n    }\n\n    for (int k = 0, cnt = 0; k < n && cnt < n_max; ++k)\n    {\n      float2 depths = RayAABBIntersection(\n          make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]),\n          make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]),\n          make_float3(points[k * 3 + 0], points[k * 3 + 1], points[k * 3 + 2]),\n          half_voxel);\n\n      if (depths.x > -1.0f)\n      {\n        idx[j * n_max + cnt] = k;\n        min_depth[j * n_max + cnt] = depths.x;\n        max_depth[j * n_max + cnt] = depths.y;\n        ++cnt;\n      }\n    }\n  }\n}\n\n__global__ void svo_intersect_point_kernel(\n    int b, int n, int m, float voxelsize,\n    int n_max,\n    const float *__restrict__ ray_start,\n    const float *__restrict__ ray_dir,\n    const float *__restrict__ points,\n    const int *__restrict__ children,\n    int *__restrict__ idx,\n    float *__restrict__ min_depth,\n    float *__restrict__ max_depth)\n{\n  /*\n  TODO: this is an inefficient implementation of the \n 
       naive Ray -- Sparse Voxel Octree Intersection. \n        It can be further improved using:\n        \n        Revelles, Jorge, Carlos Urena, and Miguel Lastra. \n        \"An efficient parametric algorithm for octree traversal.\" (2000).\n  */\n  int batch_index = blockIdx.x;\n  points += batch_index * n * 3;\n  children += batch_index * n * 9;\n  ray_start += batch_index * m * 3;\n  ray_dir += batch_index * m * 3;\n  idx += batch_index * m * n_max;\n  min_depth += batch_index * m * n_max;\n  max_depth += batch_index * m * n_max;\n\n  int index = threadIdx.x;\n  int stride = blockDim.x;\n  float half_voxel = voxelsize * 0.5;\n\n  for (int j = index; j < m; j += stride)\n  {\n    for (int l = 0; l < n_max; ++l)\n    {\n      idx[j * n_max + l] = -1;\n    }\n    int stack[256] = {-1}; // DFS, initialize the stack\n    int ptr = 0, cnt = 0, k = -1;\n    // stack[ptr] = n - 1; // ROOT node is always the last\n    stack[ptr] = 0;\n    while (ptr > -1 && cnt < n_max)\n    {\n      assert((ptr < 256));\n\n      // evaluate the current node\n      k = stack[ptr];\n      float2 depths = RayAABBIntersection(\n          make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]),\n          make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]),\n          make_float3(points[k * 3 + 0], points[k * 3 + 1], points[k * 3 + 2]),\n          half_voxel * float(children[k * 9 + 8]));\n      stack[ptr] = -1;\n      ptr--;\n\n      if (depths.x > -1.0f)\n      { // ray did not miss the voxel\n        // TODO: here it should be possible to know which children are hit, to further optimize the code\n        if (children[k * 9 + 8] == 1)\n        { // this is a terminal node\n          idx[j * n_max + cnt] = k;\n          min_depth[j * n_max + cnt] = depths.x;\n          max_depth[j * n_max + cnt] = depths.y;\n          ++cnt;\n          continue;\n        }\n\n        for (int u = 0; u < 8; u++)\n        {\n          if (children[k * 9 + u] > -1)\n        
  {\n            ptr++;\n            stack[ptr] = children[k * 9 + u]; // push child to the stack\n          }\n        }\n      }\n    }\n  }\n}\n\n__device__ float3 RayTriangleIntersection(\n    const float3 &ori,\n    const float3 &dir,\n    const float3 &v0,\n    const float3 &v1,\n    const float3 &v2,\n    float blur)\n{\n\n  float3 v0v1 = v1 - v0;\n  float3 v0v2 = v2 - v0;\n  float3 v0O = ori - v0;\n  float3 dir_crs_v0v2 = cross(dir, v0v2);\n\n  float det = dot(v0v1, dir_crs_v0v2);\n  det = __fdividef(1.0f, det); // CUDA intrinsic function\n\n  float u = dot(v0O, dir_crs_v0v2) * det;\n  if ((u < 0.0f - blur) || (u > 1.0f + blur))\n    return make_float3(-1.0f, 0.0f, 0.0f);\n\n  float3 v0O_crs_v0v1 = cross(v0O, v0v1);\n  float v = dot(dir, v0O_crs_v0v1) * det;\n  if ((v < 0.0f - blur) || (v > 1.0f + blur))\n    return make_float3(-1.0f, 0.0f, 0.0f);\n\n  if (((u + v) < 0.0f - blur) || ((u + v) > 1.0f + blur))\n    return make_float3(-1.0f, 0.0f, 0.0f);\n\n  float t = dot(v0v2, v0O_crs_v0v1) * det;\n  return make_float3(t, u, v);\n}\n\n__global__ void triangle_intersect_point_kernel(\n    int b, int n, int m, float cagesize,\n    float blur, int n_max,\n    const float *__restrict__ ray_start,\n    const float *__restrict__ ray_dir,\n    const float *__restrict__ face_points,\n    int *__restrict__ idx,\n    float *__restrict__ depth,\n    float *__restrict__ uv)\n{\n\n  int batch_index = blockIdx.x;\n  face_points += batch_index * n * 9;\n  ray_start += batch_index * m * 3;\n  ray_dir += batch_index * m * 3;\n  idx += batch_index * m * n_max;\n  depth += batch_index * m * n_max * 3;\n  uv += batch_index * m * n_max * 2;\n\n  int index = threadIdx.x;\n  int stride = blockDim.x;\n  for (int j = index; j < m; j += stride)\n  {\n    // go over rays\n    for (int l = 0; l < n_max; ++l)\n    {\n      idx[j * n_max + l] = -1;\n    }\n\n    int cnt = 0;\n    for (int k = 0; k < n && cnt < n_max; ++k)\n    {\n      // go over triangles\n      float3 tuv = 
RayTriangleIntersection(\n          make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]),\n          make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]),\n          make_float3(face_points[k * 9 + 0], face_points[k * 9 + 1], face_points[k * 9 + 2]),\n          make_float3(face_points[k * 9 + 3], face_points[k * 9 + 4], face_points[k * 9 + 5]),\n          make_float3(face_points[k * 9 + 6], face_points[k * 9 + 7], face_points[k * 9 + 8]),\n          blur);\n\n      if (tuv.x > 0)\n      {\n        int ki = k;\n        float d = tuv.x, u = tuv.y, v = tuv.z;\n\n        // sort\n        for (int l = 0; l < cnt; l++)\n        {\n          if (d < depth[j * n_max * 3 + l * 3])\n          {\n            swap(ki, idx[j * n_max + l]);\n            swap(d, depth[j * n_max * 3 + l * 3]);\n            swap(u, uv[j * n_max * 2 + l * 2]);\n            swap(v, uv[j * n_max * 2 + l * 2 + 1]);\n          }\n        }\n        idx[j * n_max + cnt] = ki;\n        depth[j * n_max * 3 + cnt * 3] = d;\n        uv[j * n_max * 2 + cnt * 2] = u;\n        uv[j * n_max * 2 + cnt * 2 + 1] = v;\n        cnt++;\n      }\n    }\n\n    for (int l = 0; l < cnt; l++)\n    {\n      // compute min_depth\n      if (l == 0)\n        depth[j * n_max * 3 + l * 3 + 1] = -cagesize;\n      else\n        depth[j * n_max * 3 + l * 3 + 1] = -fminf(cagesize,\n                                                  .5 * (depth[j * n_max * 3 + l * 3] - depth[j * n_max * 3 + l * 3 - 3]));\n\n      // compute max_depth\n      if (l == cnt - 1)\n        depth[j * n_max * 3 + l * 3 + 2] = cagesize;\n      else\n        depth[j * n_max * 3 + l * 3 + 2] = fminf(cagesize,\n                                                 .5 * (depth[j * n_max * 3 + l * 3 + 3] - depth[j * n_max * 3 + l * 3]));\n    }\n  }\n}\n\nvoid ball_intersect_point_kernel_wrapper(\n    int b, int n, int m, float radius, int n_max,\n    const float *ray_start, const float *ray_dir, const float *points,\n    int 
*idx, float *min_depth, float *max_depth)\n{\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  ball_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(\n      b, n, m, radius, n_max, ray_start, ray_dir, points, idx, min_depth, max_depth);\n\n  CUDA_CHECK_ERRORS();\n}\n\nvoid aabb_intersect_point_kernel_wrapper(\n    int b, int n, int m, float voxelsize, int n_max,\n    const float *ray_start, const float *ray_dir, const float *points,\n    int *idx, float *min_depth, float *max_depth)\n{\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  aabb_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(\n      b, n, m, voxelsize, n_max, ray_start, ray_dir, points, idx, min_depth, max_depth);\n\n  CUDA_CHECK_ERRORS();\n}\n\nvoid svo_intersect_point_kernel_wrapper(\n    int b, int n, int m, float voxelsize, int n_max,\n    const float *ray_start, const float *ray_dir, const float *points, const int *children,\n    int *idx, float *min_depth, float *max_depth)\n{\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  svo_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(\n      b, n, m, voxelsize, n_max, ray_start, ray_dir, points, children, idx, min_depth, max_depth);\n\n  CUDA_CHECK_ERRORS();\n}\n\nvoid triangle_intersect_point_kernel_wrapper(\n    int b, int n, int m, float cagesize, float blur, int n_max,\n    const float *ray_start, const float *ray_dir, const float *face_points,\n    int *idx, float *depth, float *uv)\n{\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  triangle_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(\n      b, n, m, cagesize, blur, n_max, ray_start, ray_dir, face_points, idx, depth, uv);\n\n  CUDA_CHECK_ERRORS();\n}\n"
  },
  {
    "path": "third_party/sparse_voxels/src/octree.cpp",
    "content": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in the\n// LICENSE file in the root directory of this source tree.\n\n#include \"../include/octree.h\"\n#include \"../include/utils.h\"\n#include <utility>\n#include <chrono>\nusing namespace std::chrono;\n\ntypedef struct OcTree\n{\n    int depth;\n    int index;\n    at::Tensor center;\n    struct OcTree *children[8];\n    void init(at::Tensor center, int d, int i)\n    {\n        this->center = center;\n        this->depth = d;\n        this->index = i;\n        for (int i = 0; i < 8; i++)\n            this->children[i] = nullptr;\n    }\n} OcTree;\n\nclass EasyOctree\n{\npublic:\n    OcTree *root;\n    int total;\n    int terminal;\n\n    at::Tensor all_centers;\n    at::Tensor all_children;\n\n    EasyOctree(at::Tensor center, int depth)\n    {\n        root = new OcTree;\n        root->init(center, depth, -1);\n        total = -1;\n        terminal = -1;\n    }\n    ~EasyOctree()\n    {\n        OcTree *p = root;\n        destroy(p);\n    }\n    void destroy(OcTree *&p);\n    void insert(OcTree *&p, at::Tensor point, int index);\n    void finalize();\n    std::pair<int, int> count(OcTree *&p);\n};\n\nvoid EasyOctree::destroy(OcTree *&p)\n{\n    if (p != nullptr)\n    {\n        for (int i = 0; i < 8; i++)\n        {\n            if (p->children[i] != nullptr)\n                destroy(p->children[i]);\n        }\n        delete p;\n        p = nullptr;\n    }\n}\n\nvoid EasyOctree::insert(OcTree *&p, at::Tensor point, int index)\n{\n    at::Tensor diff = (point > p->center).to(at::kInt);\n    int idx = diff[0].item<int>() + 2 * diff[1].item<int>() + 4 * diff[2].item<int>();\n    if (p->depth == 0)\n    {\n        p->children[idx] = new OcTree;\n        p->children[idx]->init(point, -1, index);\n    }\n    else\n    {\n        if (p->children[idx] == nullptr)\n        {\n            int length = 1 << (p->depth - 1);\n            
at::Tensor new_center = p->center + (2 * diff - 1) * length;\n            p->children[idx] = new OcTree;\n            p->children[idx]->init(new_center, p->depth - 1, -1);\n        }\n        insert(p->children[idx], point, index);\n    }\n}\n\nstd::pair<int, int> EasyOctree::count(OcTree *&p)\n{\n    int total = 0, terminal = 0;\n    for (int i = 0; i < 8; i++)\n    {\n        if (p->children[i] != nullptr)\n        {\n            std::pair<int, int> sub = count(p->children[i]);\n            total += sub.first;\n            terminal += sub.second;\n        }\n    }\n    total += 1;\n    if (p->depth == -1)\n        terminal += 1;\n    return std::make_pair(total, terminal);\n}\n\nvoid EasyOctree::finalize()\n{\n    std::pair<int, int> outs = count(root);\n    total = outs.first;\n    terminal = outs.second;\n\n    all_centers =\n        torch::zeros({outs.first, 3}, at::device(root->center.device()).dtype(at::ScalarType::Int));\n    all_children =\n        -torch::ones({outs.first, 9}, at::device(root->center.device()).dtype(at::ScalarType::Int));\n\n    int node_idx = outs.first - 1;\n    root->index = node_idx;\n\n    std::queue<OcTree *> all_leaves;\n    all_leaves.push(root);\n    while (!all_leaves.empty())\n    {\n        OcTree *node_ptr = all_leaves.front();\n        all_leaves.pop();\n        for (int i = 0; i < 8; i++)\n        {\n            if (node_ptr->children[i] != nullptr)\n            {\n                if (node_ptr->children[i]->depth > -1)\n                {\n                    node_idx--;\n                    node_ptr->children[i]->index = node_idx;\n                }\n                all_leaves.push(node_ptr->children[i]);\n                all_children[node_ptr->index][i] = node_ptr->children[i]->index;\n            }\n        }\n        all_children[node_ptr->index][8] = 1 << (node_ptr->depth + 1);\n        all_centers[node_ptr->index] = node_ptr->center;\n    }\n    assert(node_idx == outs.second);\n};\n\nstd::tuple<at::Tensor, at::Tensor> 
build_octree(at::Tensor center, at::Tensor points, int depth)\n{\n    auto start = high_resolution_clock::now();\n    EasyOctree tree(center, depth);\n    for (int k = 0; k < points.size(0); k++)\n        tree.insert(tree.root, points[k], k);\n    tree.finalize();\n    auto stop = high_resolution_clock::now();\n    auto duration = duration_cast<microseconds>(stop - start);\n    printf(\"Building EasyOctree done. total #nodes = %d, terminal #nodes = %d (time taken %f s)\\n\",\n           tree.total, tree.terminal, float(duration.count()) / 1000000.);\n    return std::make_tuple(tree.all_centers, tree.all_children);\n}"
  },
  {
    "path": "third_party/sparse_voxels/src/sample.cpp",
    "content": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in the\n// LICENSE file in the root directory of this source tree.\n\n#include \"../include/sample.h\"\n#include \"../include/utils.h\"\n#include <utility>\n\nvoid uniform_ray_sampling_kernel_wrapper(\n    int b, int num_rays, int max_hits, int max_steps, float step_size,\n    const int *pts_idx, const float *min_depth, const float *max_depth, const float *uniform_noise,\n    int *sampled_idx, float *sampled_depth, float *sampled_dists);\n\nvoid inverse_cdf_sampling_kernel_wrapper(\n    int b, int num_rays, int max_hits, int max_steps, float fixed_step_size,\n    const int *pts_idx, const float *min_depth, const float *max_depth,\n    const float *uniform_noise, const float *probs, const float *steps,\n    int *sampled_idx, float *sampled_depth, float *sampled_dists);\n\nstd::tuple<at::Tensor, at::Tensor, at::Tensor> uniform_ray_sampling(\n    at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise,\n    const float step_size, const int max_steps)\n{\n\n  CHECK_CONTIGUOUS(pts_idx);\n  CHECK_CONTIGUOUS(min_depth);\n  CHECK_CONTIGUOUS(max_depth);\n  CHECK_CONTIGUOUS(uniform_noise);\n  CHECK_IS_FLOAT(min_depth);\n  CHECK_IS_FLOAT(max_depth);\n  CHECK_IS_FLOAT(uniform_noise);\n  CHECK_IS_INT(pts_idx);\n  CHECK_CUDA(pts_idx);\n  CHECK_CUDA(min_depth);\n  CHECK_CUDA(max_depth);\n  CHECK_CUDA(uniform_noise);\n\n  at::Tensor sampled_idx =\n      -torch::ones({pts_idx.size(0), pts_idx.size(1), max_steps},\n                   at::device(pts_idx.device()).dtype(at::ScalarType::Int));\n  at::Tensor sampled_depth =\n      torch::zeros({min_depth.size(0), min_depth.size(1), max_steps},\n                   at::device(min_depth.device()).dtype(at::ScalarType::Float));\n  at::Tensor sampled_dists =\n      torch::zeros({min_depth.size(0), min_depth.size(1), max_steps},\n                   
at::device(min_depth.device()).dtype(at::ScalarType::Float));\n  uniform_ray_sampling_kernel_wrapper(min_depth.size(0), min_depth.size(1), min_depth.size(2), sampled_depth.size(2),\n                                      step_size,\n                                      pts_idx.data_ptr<int>(), min_depth.data_ptr<float>(), max_depth.data_ptr<float>(),\n                                      uniform_noise.data_ptr<float>(), sampled_idx.data_ptr<int>(),\n                                      sampled_depth.data_ptr<float>(), sampled_dists.data_ptr<float>());\n  return std::make_tuple(sampled_idx, sampled_depth, sampled_dists);\n}\n\nstd::tuple<at::Tensor, at::Tensor, at::Tensor> inverse_cdf_sampling(\n    at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise,\n    at::Tensor probs, at::Tensor steps, float fixed_step_size)\n{\n\n  CHECK_CONTIGUOUS(pts_idx);\n  CHECK_CONTIGUOUS(min_depth);\n  CHECK_CONTIGUOUS(max_depth);\n  CHECK_CONTIGUOUS(probs);\n  CHECK_CONTIGUOUS(steps);\n  CHECK_CONTIGUOUS(uniform_noise);\n  CHECK_IS_FLOAT(min_depth);\n  CHECK_IS_FLOAT(max_depth);\n  CHECK_IS_FLOAT(uniform_noise);\n  CHECK_IS_FLOAT(probs);\n  CHECK_IS_FLOAT(steps);\n  CHECK_IS_INT(pts_idx);\n  CHECK_CUDA(pts_idx);\n  CHECK_CUDA(min_depth);\n  CHECK_CUDA(max_depth);\n  CHECK_CUDA(uniform_noise);\n  CHECK_CUDA(probs);\n  CHECK_CUDA(steps);\n\n  int max_steps = uniform_noise.size(-1);\n  at::Tensor sampled_idx =\n      -torch::ones({pts_idx.size(0), pts_idx.size(1), max_steps},\n                   at::device(pts_idx.device()).dtype(at::ScalarType::Int));\n  at::Tensor sampled_depth =\n      torch::zeros({min_depth.size(0), min_depth.size(1), max_steps},\n                   at::device(min_depth.device()).dtype(at::ScalarType::Float));\n  at::Tensor sampled_dists =\n      torch::zeros({min_depth.size(0), min_depth.size(1), max_steps},\n                   at::device(min_depth.device()).dtype(at::ScalarType::Float));\n  
inverse_cdf_sampling_kernel_wrapper(min_depth.size(0), min_depth.size(1), min_depth.size(2), sampled_depth.size(2), fixed_step_size,\n                                      pts_idx.data_ptr<int>(), min_depth.data_ptr<float>(), max_depth.data_ptr<float>(),\n                                      uniform_noise.data_ptr<float>(), probs.data_ptr<float>(), steps.data_ptr<float>(),\n                                      sampled_idx.data_ptr<int>(), sampled_depth.data_ptr<float>(), sampled_dists.data_ptr<float>());\n  return std::make_tuple(sampled_idx, sampled_depth, sampled_dists);\n}"
  },
  {
    "path": "third_party/sparse_voxels/src/sample_gpu.cu",
    "content": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in the\n// LICENSE file in the root directory of this source tree.\n\n#include <math.h>\n#include <stdio.h>\n#include <stdlib.h>\n\n#include \"../include/cuda_utils.h\"\n#include \"../include/cutil_math.h\" // required for float3 vector math\n\n__global__ void uniform_ray_sampling_kernel(\n    int b, int num_rays,\n    int max_hits,\n    int max_steps,\n    float step_size,\n    const int *__restrict__ pts_idx,\n    const float *__restrict__ min_depth,\n    const float *__restrict__ max_depth,\n    const float *__restrict__ uniform_noise,\n    int *__restrict__ sampled_idx,\n    float *__restrict__ sampled_depth,\n    float *__restrict__ sampled_dists)\n{\n\n  int batch_index = blockIdx.x;\n  int index = threadIdx.x;\n  int stride = blockDim.x;\n\n  pts_idx += batch_index * num_rays * max_hits;\n  min_depth += batch_index * num_rays * max_hits;\n  max_depth += batch_index * num_rays * max_hits;\n\n  uniform_noise += batch_index * num_rays * max_steps;\n  sampled_idx += batch_index * num_rays * max_steps;\n  sampled_depth += batch_index * num_rays * max_steps;\n  sampled_dists += batch_index * num_rays * max_steps;\n\n  // loop over all rays\n  for (int j = index; j < num_rays; j += stride)\n  {\n    int H = j * max_hits, K = j * max_steps;\n    int s = 0, ucur = 0, umin = 0, umax = 0;\n    float last_min_depth, last_max_depth, curr_depth;\n\n    // sort all depths\n    while (true)\n    {\n      if ((umax == max_hits) || (ucur == max_steps) || (pts_idx[H + umax] == -1))\n      {\n        break; // reach the maximum\n      }\n      if (umin < max_hits)\n      {\n        last_min_depth = min_depth[H + umin];\n      }\n      else\n      {\n        last_min_depth = 10000.0;\n      }\n      if (umax < max_hits)\n      {\n        last_max_depth = max_depth[H + umax];\n      }\n      else\n      {\n        last_max_depth = 10000.0;\n      
}\n      if (ucur < max_steps)\n      {\n        curr_depth = min_depth[H] + (float(ucur) + uniform_noise[K + ucur]) * step_size;\n      }\n\n      if ((last_max_depth <= curr_depth) && (last_max_depth <= last_min_depth))\n      {\n        sampled_depth[K + s] = last_max_depth;\n        sampled_idx[K + s] = pts_idx[H + umax];\n        umax++;\n        s++;\n        continue;\n      }\n      if ((curr_depth <= last_min_depth) && (curr_depth <= last_max_depth))\n      {\n        sampled_depth[K + s] = curr_depth;\n        sampled_idx[K + s] = pts_idx[H + umin - 1];\n        ucur++;\n        s++;\n        continue;\n      }\n      if ((last_min_depth <= curr_depth) && (last_min_depth <= last_max_depth))\n      {\n        sampled_depth[K + s] = last_min_depth;\n        sampled_idx[K + s] = pts_idx[H + umin];\n        umin++;\n        s++;\n        continue;\n      }\n    }\n\n    float l_depth, r_depth;\n    int step = 0;\n    for (ucur = 0, umin = 0, umax = 0; ucur < max_steps - 1; ucur++)\n    {\n      if (sampled_idx[K + ucur + 1] == -1)\n        break;\n      l_depth = sampled_depth[K + ucur];\n      r_depth = sampled_depth[K + ucur + 1];\n      sampled_depth[K + ucur] = (l_depth + r_depth) * .5;\n      sampled_dists[K + ucur] = (r_depth - l_depth);\n      if ((umin < max_hits) && (sampled_depth[K + ucur] >= min_depth[H + umin]) && (pts_idx[H + umin] > -1))\n        umin++;\n      if ((umax < max_hits) && (sampled_depth[K + ucur] >= max_depth[H + umax]) && (pts_idx[H + umax] > -1))\n        umax++;\n      if ((umax == max_hits) || (pts_idx[H + umax] == -1))\n        break;\n      if ((umin - 1 == umax) && (sampled_dists[K + ucur] > 0))\n      {\n        sampled_depth[K + step] = sampled_depth[K + ucur];\n        sampled_dists[K + step] = sampled_dists[K + ucur];\n        sampled_idx[K + step] = sampled_idx[K + ucur];\n        step++;\n      }\n    }\n\n    for (int s = step; s < max_steps; s++)\n    {\n      sampled_idx[K + s] = -1;\n    }\n  }\n}\n\n__global__ 
void inverse_cdf_sampling_kernel(\n    int b, int num_rays,\n    int max_hits,\n    int max_steps,\n    float fixed_step_size,\n    const int *__restrict__ pts_idx,\n    const float *__restrict__ min_depth,\n    const float *__restrict__ max_depth,\n    const float *__restrict__ uniform_noise,\n    const float *__restrict__ probs,\n    const float *__restrict__ steps,\n    int *__restrict__ sampled_idx,\n    float *__restrict__ sampled_depth,\n    float *__restrict__ sampled_dists)\n{\n\n  int batch_index = blockIdx.x;\n  int index = threadIdx.x;\n  int stride = blockDim.x;\n\n  pts_idx += batch_index * num_rays * max_hits;\n  min_depth += batch_index * num_rays * max_hits;\n  max_depth += batch_index * num_rays * max_hits;\n  probs += batch_index * num_rays * max_hits;\n  steps += batch_index * num_rays;\n\n  uniform_noise += batch_index * num_rays * max_steps;\n  sampled_idx += batch_index * num_rays * max_steps;\n  sampled_depth += batch_index * num_rays * max_steps;\n  sampled_dists += batch_index * num_rays * max_steps;\n\n  // loop over all rays\n  for (int j = index; j < num_rays; j += stride)\n  {\n    int H = j * max_hits, K = j * max_steps;\n    int curr_bin = 0, s = 0; // current index (bin)\n\n    float curr_min_depth = min_depth[H]; // lower depth\n    float curr_max_depth = max_depth[H]; // upper depth\n    float curr_min_cdf = 0;\n    float curr_max_cdf = probs[H];\n    float step_size = 1.0 / steps[j];\n    float z_low = curr_min_depth;\n    int total_steps = int(ceil(steps[j]));\n    bool done = false;\n\n    // optional use a fixed step size\n    if (fixed_step_size > 0.0)\n      step_size = fixed_step_size;\n\n    // sample points\n    for (int curr_step = 0; curr_step < total_steps; curr_step++)\n    {\n      float curr_cdf = (float(curr_step) + uniform_noise[K + curr_step]) * step_size;\n      // printf(\"curr_cdf: %f\\n\", curr_cdf);\n      while (curr_cdf > curr_max_cdf)\n      {\n        // first include max cdf\n        sampled_idx[K + s] = 
pts_idx[H + curr_bin];\n        sampled_dists[K + s] = (curr_max_depth - z_low);\n        sampled_depth[K + s] = (curr_max_depth + z_low) * .5;\n\n        // move to next cdf\n        curr_bin++;\n        s++;\n        if ((curr_bin >= max_hits) || (pts_idx[H + curr_bin] == -1))\n        {\n          done = true;\n          break;\n        }\n        curr_min_depth = min_depth[H + curr_bin];\n        curr_max_depth = max_depth[H + curr_bin];\n        curr_min_cdf = curr_max_cdf;\n        curr_max_cdf = curr_max_cdf + probs[H + curr_bin];\n        z_low = curr_min_depth;\n      }\n      if (done)\n        break;\n\n      // if the sampled cdf is inside the bin\n      float u = (curr_cdf - curr_min_cdf) / (curr_max_cdf - curr_min_cdf);\n      float z = curr_min_depth + u * (curr_max_depth - curr_min_depth);\n      sampled_idx[K + s] = pts_idx[H + curr_bin];\n\n      sampled_dists[K + s] = (z - z_low);\n      sampled_depth[K + s] = (z + z_low) * .5;\n      z_low = z;\n      s++;\n    }\n\n    // if there are still bins remaining\n    while ((z_low < curr_max_depth) && (!done) && (num_rays > (H + curr_bin)))\n    {\n      sampled_idx[K + s] = pts_idx[H + curr_bin];\n      sampled_dists[K + s] = (curr_max_depth - z_low);\n      sampled_depth[K + s] = (curr_max_depth + z_low) * .5;\n      curr_bin++;\n      s++;\n      if ((curr_bin >= max_hits) || (pts_idx[H + curr_bin] == -1))\n        break;\n\n      curr_min_depth = min_depth[H + curr_bin];\n      curr_max_depth = max_depth[H + curr_bin];\n      z_low = curr_min_depth;\n    }\n  }\n}\n\nvoid uniform_ray_sampling_kernel_wrapper(\n    int b, int num_rays, int max_hits, int max_steps, float step_size,\n    const int *pts_idx, const float *min_depth, const float *max_depth, const float *uniform_noise,\n    int *sampled_idx, float *sampled_depth, float *sampled_dists)\n{\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  uniform_ray_sampling_kernel<<<b, opt_n_threads(num_rays), 0, stream>>>(\n      b, num_rays, 
max_hits, max_steps, step_size, pts_idx,\n      min_depth, max_depth, uniform_noise, sampled_idx, sampled_depth, sampled_dists);\n\n  CUDA_CHECK_ERRORS();\n}\n\nvoid inverse_cdf_sampling_kernel_wrapper(\n    int b, int num_rays, int max_hits, int max_steps, float fixed_step_size,\n    const int *pts_idx, const float *min_depth, const float *max_depth,\n    const float *uniform_noise, const float *probs, const float *steps,\n    int *sampled_idx, float *sampled_depth, float *sampled_dists)\n{\n\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\n  inverse_cdf_sampling_kernel<<<b, opt_n_threads(num_rays), 0, stream>>>(\n      b, num_rays, max_hits, max_steps, fixed_step_size,\n      pts_idx, min_depth, max_depth, uniform_noise, probs, steps,\n      sampled_idx, sampled_depth, sampled_dists);\n\n  CUDA_CHECK_ERRORS();\n}\n"
  }
]