Full Code of JunyuanDeng/NeRF-LOAM for AI

Repository: JunyuanDeng/NeRF-LOAM
Branch: master
Commit: 2fe4e8d8dd9a
Files: 70
Total size: 271.3 KB

Directory structure:
NeRF-LOAM/

├── .gitignore
├── LICENSE
├── Readme.md
├── configs/
│   ├── kitti/
│   │   ├── kitti.yaml
│   │   ├── kitti_00.yaml
│   │   ├── kitti_01.yaml
│   │   ├── kitti_03.yaml
│   │   ├── kitti_04.yaml
│   │   ├── kitti_05.yaml
│   │   ├── kitti_06.yaml
│   │   ├── kitti_07.yaml
│   │   ├── kitti_08.yaml
│   │   ├── kitti_09.yaml
│   │   ├── kitti_10.yaml
│   │   ├── kitti_base06.yaml
│   │   └── kitti_base10.yaml
│   ├── maicity/
│   │   ├── maicity.yaml
│   │   ├── maicity_00.yaml
│   │   └── maicity_01.yaml
│   └── ncd/
│       ├── ncd.yaml
│       └── ncd_quad.yaml
├── demo/
│   ├── parser.py
│   └── run.py
├── install.sh
├── requirements.txt
├── src/
│   ├── criterion.py
│   ├── dataset/
│   │   ├── kitti.py
│   │   ├── maicity.py
│   │   └── ncd.py
│   ├── lidarFrame.py
│   ├── loggers.py
│   ├── mapping.py
│   ├── nerfloam.py
│   ├── se3pose.py
│   ├── share.py
│   ├── tracking.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── import_util.py
│   │   ├── mesh_util.py
│   │   ├── profile_util.py
│   │   └── sample_util.py
│   └── variations/
│       ├── decode_morton.py
│       ├── lidar.py
│       ├── render_helpers.py
│       └── voxel_helpers.py
└── third_party/
    ├── marching_cubes/
    │   ├── setup.py
    │   └── src/
    │       ├── mc.cpp
    │       ├── mc_data.cuh
    │       ├── mc_interp_kernel.cu
    │       ├── mc_kernel.cu
    │       └── mc_kernel_colour.cu
    ├── sparse_octree/
    │   ├── include/
    │   │   ├── octree.h
    │   │   ├── test.h
    │   │   └── utils.h
    │   ├── setup.py
    │   └── src/
    │       ├── bindings.cpp
    │       └── octree.cpp
    └── sparse_voxels/
        ├── include/
        │   ├── cuda_utils.h
        │   ├── cutil_math.h
        │   ├── intersect.h
        │   ├── octree.h
        │   ├── sample.h
        │   └── utils.h
        ├── setup.py
        └── src/
            ├── binding.cpp
            ├── intersect.cpp
            ├── intersect_gpu.cu
            ├── octree.cpp
            ├── sample.cpp
            └── sample_gpu.cu

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__
*.egg-info
build/
dist/
logs/
.vscode/
results/
temp.sh



================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 JunyuanDeng

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: Readme.md
================================================
# NeRF-LOAM: Neural Implicit Representation for Large-Scale Incremental LiDAR Odometry and Mapping

This repository contains the implementation of our paper:
> **NeRF-LOAM: Neural Implicit Representation for Large-Scale Incremental LiDAR Odometry and Mapping** ([PDF](https://arxiv.org/pdf/2303.10709))\
> [Junyuan Deng](https://github.com/JunyuanDeng), [Qi Wu](https://github.com/Gatsby23), [Xieyuanli Chen](https://github.com/Chen-Xieyuanli), Songpengcheng Xia, Zhen Sun, Guoqing Liu, Wenxian Yu and Ling Pei\
> If you use our code in your work, please star our repo and cite our paper.

```
@inproceedings{deng2023nerfloam,
      title={NeRF-LOAM: Neural Implicit Representation for Large-Scale Incremental LiDAR Odometry and Mapping},
      author={Junyuan Deng and Qi Wu and Xieyuanli Chen and Songpengcheng Xia and Zhen Sun and Guoqing Liu and Wenxian Yu and Ling Pei},
      booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
      year={2023}
}
```

<div align=center>
<img src="./docs/NeRFLOAM.gif"> 
</div>

- *Our incremental, simultaneous odometry and mapping results on the Newer College dataset and KITTI sequence 00.*
- *The maps are dense meshes; the red line indicates the odometry trajectory.*
- *We use the same network without pre-training to demonstrate the generalization ability of our design.*


## Overview

![pipeline](./docs/pipeline.png)

**Overview of our method.** Our method is built on a neural SDF and composed of three main components:
- Neural odometry takes the pre-processed scan and optimizes the pose by back-projecting the queried neural SDF;
- Neural mapping jointly optimizes the voxel embedding map and the pose while selecting key-scans;
- The key-scan-refined map returns SDF values, and the final mesh is reconstructed by marching cubes.

## Qualitative results

**The reconstructed maps**
![odomap_kitti](./docs/odomap_kitti.png)
*The qualitative results of our odometry mapping on the KITTI dataset. From upper left to lower right, we list the results of sequences 00, 01, 03, 04, 05, 09, 10.*

**The odometry results**
![odo_qual](./docs/odo_qual.png)
*The qualitative results of our odometry on the KITTI dataset. From left to right, we list the results of sequences 00, 01, 03, 04, 05, 07, 09, 10. The dashed line corresponds to the ground truth and the blue line to our odometry method.*


## Data

1. Newer College real-world LiDAR dataset: [website](https://ori-drs.github.io/newer-college-dataset/download/). 

2. MaiCity synthetic LiDAR dataset: [website](https://www.ipb.uni-bonn.de/data/mai-city-dataset/).

3. KITTI dataset: [website](https://www.cvlibs.net/datasets/kitti/).

## Environment Setup

To run the code, a GPU with large memory is preferred. We tested the code on an RTX 3090 and a GTX TITAN.

We use Conda to create a virtual environment and install dependencies:

- Python environment: we tested our code with Python 3.8.13

- [Pytorch](https://pytorch.org/get-started/locally/): we tested version 1.10 with CUDA 10.2 (and CUDA 11.1)

- Other dependencies are specified in requirements.txt. You can install them all using `pip` or `conda`:
```
pip3 install -r requirements.txt
```

- After installing all third-party libraries, run the following script to build the extra PyTorch extensions used in this project.

```bash
sh install.sh
```


- Update the library path in `src/mapping.py` to point to the built library (the `xxx` parts depend on your platform and Python version):
```python
torch.classes.load_library("third_party/sparse_octree/build/lib.xxx/svo.xxx.so")
```

- Install [patchwork-plusplus](https://github.com/url-kaist/patchwork-plusplus) to separate ground points from the LiDAR scans.

- Update `patchwork_module_path` in `src/dataset/*.py` to point to the built Python wrapper:
```python
patchwork_module_path ="/xxx/patchwork-plusplus/build/python_wrapper"
```
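The exact filenames of these built libraries depend on your platform and Python version, so hard-coded paths rarely transfer between machines. They can instead be located at runtime with `glob` — a minimal sketch (the `find_built_library` helper and its fallback behavior are our own illustration, not part of the repo):

```python
import glob
import os


def find_built_library(root, pattern):
    """Search a setuptools build/ tree for a compiled extension.

    root    -- package directory, e.g. "third_party/sparse_octree"
    pattern -- shared-object glob, e.g. "svo*.so"
    Returns the first match, or None if the extension has not been built.
    """
    matches = glob.glob(os.path.join(root, "build", "lib.*", pattern))
    return matches[0] if matches else None


svo_path = find_built_library("third_party/sparse_octree", "svo*.so")
if svo_path is not None:
    # Equivalent to the hard-coded call in src/mapping.py:
    # torch.classes.load_library(svo_path)
    print("found:", svo_path)
```

The `lib.*` component matches the platform-specific directory that `python setup.py install` creates (e.g. `lib.linux-x86_64-3.8`).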

## Demo

- The full datasets can be downloaded from the websites above. Alternatively, you can download example subsets using these [scripts](https://github.com/PRBonn/SHINE_mapping/tree/master/scripts).
- Take the MaiCity seq. 01 dataset as an example: modify `configs/maicity/maicity_01.yaml` so that `data_path` points to the actual dataset path. Then you are all set to run the code:
```
python demo/run.py configs/maicity/maicity_01.yaml
```
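Each per-sequence config inherits from a base file through its `base_config` key, with child values recursively overriding the base (this is what `update_recursive` in `demo/parser.py` implements). A minimal sketch of how such a merge resolves, using plain dicts with hypothetical values:

```python
def merge_configs(base, child):
    """Recursively overlay a child config onto a base config (child wins).

    Mirrors the behavior of update_recursive in demo/parser.py;
    note that `base` is mutated in place.
    """
    for key, value in child.items():
        if isinstance(value, dict):
            base[key] = merge_configs(base.get(key, {}), value)
        else:
            base[key] = value
    return base


# Hypothetical base config (as in configs/maicity/maicity.yaml)
base = {
    "dataset": "maicity",
    "data_specs": {"max_depth": 50.0, "min_depth": 1.5},
}
# Hypothetical per-sequence override (as in maicity_01.yaml)
child = {"data_specs": {"data_path": "/path/to/sequences/01"}}

merged = merge_configs(base, child)
# The child adds data_path; the base's depth bounds are preserved.
```

Because the merge is recursive, a per-sequence file only needs to list the keys it changes, not the entire nested section.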

## Note

- For the KITTI dataset, if you want faster processing, you can switch to the `subscene` branch:
```
git checkout subscene
```
- Then run `python demo/run.py configs/kitti/kitti_00.yaml`
- This branch cuts the full scene into subscenes to speed up processing and then concatenates them back together. This inevitably introduces map inconsistency and degrades tracking accuracy.

## Evaluation

- We follow the evaluation protocol proposed [here](https://github.com/PRBonn/SHINE_mapping/tree/master/eval), but we did not use `crop_intersection.py`.


## Acknowledgement

Some of our codes are adapted from [Vox-Fusion](https://github.com/zju3dv/Vox-Fusion).

## Contact

Any questions or suggestions are welcome!

Junyuan Deng: d.juney@sjtu.edu.cn and Xieyuanli Chen: xieyuanli.chen@nudt.edu.cn

## License

This project is free software made available under the MIT License. For details see the LICENSE file.


================================================
FILE: configs/kitti/kitti.yaml
================================================
log_dir: './logs'
decoder: lidar
dataset: kitti

criteria:
  sdf_weight: 10000.0
  fs_weight: 1
  eiko_weight: 0.1
  sdf_truncation: 0.30

decoder_specs:
  depth: 2
  width: 256
  in_dim: 16
  skips: []
  embedder: none
  multires: 0

tracker_specs:
  N_rays: 2048
  learning_rate: 0.06
  step_size: 0.2
  max_voxel_hit: 20
  num_iterations: 25

mapper_specs:
  N_rays_each: 2048
  use_local_coord: False
  voxel_size: 0.3
  step_size: 0.5
  window_size: 4
  num_iterations: 25
  max_voxel_hit: 20
  final_iter: True
  mesh_res: 2
  learning_rate_emb: 0.01
  learning_rate_decorder: 0.005
  learning_rate_pose: 0.001
  freeze_frame: 5
  keyframe_gap: 8
  remove_back: False
  key_distance: 12

debug_args:
  verbose: False
  mesh_freq: 100


================================================
FILE: configs/kitti/kitti_00.yaml
================================================
base_config: configs/kitti/kitti.yaml

exp_name: kitti/sqeuence00



data_specs:
  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/00'
  use_gt: False
  max_depth: 40
  min_depth: 5

tracker_specs:
  start_frame: 0
  end_frame: -1
  read_offset: 1

================================================
FILE: configs/kitti/kitti_01.yaml
================================================
base_config: configs/kitti/kitti.yaml

exp_name: kitti/sqeuence01



data_specs:
  data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/01/'
  use_gt: False
  max_depth: 30
  min_depth: 5

tracker_specs:
  start_frame: 0
  end_frame: 1101
  read_offset: 1

================================================
FILE: configs/kitti/kitti_03.yaml
================================================
base_config: configs/kitti/kitti.yaml

exp_name: kitti/sqeuence03



data_specs:
  data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/03/'
  use_gt: False
  max_depth: 30
  min_depth: 5

tracker_specs:
  start_frame: 0
  end_frame: 1101
  read_offset: 1

================================================
FILE: configs/kitti/kitti_04.yaml
================================================
base_config: configs/kitti/kitti.yaml

exp_name: kitti/sqeuence04



data_specs:
  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/04'
  use_gt: False
  max_depth: 50
  min_depth: 2.75

tracker_specs:
  start_frame: 0
  end_frame: 270
  read_offset: 1

================================================
FILE: configs/kitti/kitti_05.yaml
================================================
base_config: configs/kitti/kitti.yaml

exp_name: kitti/sqeuence05



data_specs:
  data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/05/'
  use_gt: False
  max_depth: 50
  min_depth: 5

tracker_specs:
  start_frame: 2299
  end_frame: 2760
  read_offset: 1

================================================
FILE: configs/kitti/kitti_06.yaml
================================================
base_config: configs/kitti/kitti.yaml

exp_name: kitti/sqeuence06



data_specs:
  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/06'
  use_gt: False
  max_depth: 40
  min_depth: 5

tracker_specs:
  start_frame: 0
  end_frame: -1
  read_offset: 1

================================================
FILE: configs/kitti/kitti_07.yaml
================================================
base_config: configs/kitti/kitti.yaml

exp_name: kitti/sqeuence07



data_specs:
  data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/07'
  use_gt: True
  max_depth: 25
  min_depth: 5

tracker_specs:
  start_frame: 0
  end_frame: 1100
  read_offset: 1

================================================
FILE: configs/kitti/kitti_08.yaml
================================================
base_config: configs/kitti/kitti.yaml

exp_name: kitti/sqeuence08



data_specs:
  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/08'
  use_gt: False
  max_depth: 40
  min_depth: 5

tracker_specs:
  start_frame: 0
  end_frame: -1
  read_offset: 1

================================================
FILE: configs/kitti/kitti_09.yaml
================================================
base_config: configs/kitti/kitti.yaml

exp_name: kitti/sqeuence09



data_specs:
  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/09'
  use_gt: False
  max_depth: 40
  min_depth: 5

tracker_specs:
  start_frame: 0
  end_frame: -1
  read_offset: 1

================================================
FILE: configs/kitti/kitti_10.yaml
================================================
base_config: configs/kitti/kitti_base10.yaml

exp_name: kitti/sqeuence10



data_specs:
  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/10'
  use_gt: False
  max_depth: 70
  min_depth: 2.75

tracker_specs:
  start_frame: 0
  end_frame: 1200
  read_offset: 1

================================================
FILE: configs/kitti/kitti_base06.yaml
================================================
log_dir: './logs'
decoder: lidar
dataset: kitti

criteria:
  sdf_weight: 10000.0
  fs_weight: 1
  eiko_weight: 0.1
  sdf_truncation: 0.30

decoder_specs:
  depth: 2
  width: 256
  in_dim: 16
  skips: []
  embedder: none
  multires: 0

tracker_specs:
  N_rays: 2048
  learning_rate: 0.06
  step_size: 0.2
  max_voxel_hit: 20
  num_iterations: 25

mapper_specs:
  N_rays_each: 2048
  use_local_coord: False
  voxel_size: 0.3
  step_size: 0.5
  window_size: 4
  num_iterations: 25
  max_voxel_hit: 20
  final_iter: True
  mesh_res: 2
  learning_rate_emb: 0.01
  learning_rate_decorder: 0.005
  learning_rate_pose: 0.001
  freeze_frame: 5
  keyframe_gap: 8
  remove_back: False
  key_distance: 12

debug_args:
  verbose: False
  mesh_freq: 100

================================================
FILE: configs/kitti/kitti_base10.yaml
================================================
log_dir: '/home/evsjtu2/disk1/dengjunyuan/running_logs/'
decoder: lidar
dataset: kitti

criteria:
  depth_weight: 0
  sdf_weight: 12000.0
  fs_weight: 1
  eiko_weight: 0
  sdf_truncation: 0.50

decoder_specs:
  depth: 2
  width: 256
  in_dim: 16
  skips: []
  embedder: none
  multires: 0

tracker_specs:
  N_rays: 2048
  learning_rate: 0.1
  start_frame: 0
  end_frame: -1
  step_size: 0.2
  show_imgs: False
  max_voxel_hit: 20
  keyframe_freq: 10
  num_iterations: 40

mapper_specs:
  N_rays_each: 2048
  num_embeddings: 20000000
  use_local_coord: False
  voxel_size: 0.2
  step_size: 0.2
  window_size: 4
  num_iterations: 20
  max_voxel_hit: 20
  final_iter: True
  mesh_res: 2
  overlap_th: 0.8
  learning_rate_emb: 0.03
  learning_rate_decorder: 0.005
  learning_rate_pose: 0.001
  #max_depth_first: 20
  freeze_frame: 10
  keyframe_gap: 7
  remove_back: True
  key_distance: 7

debug_args:
  verbose: False
  mesh_freq: 100

================================================
FILE: configs/maicity/maicity.yaml
================================================
log_dir: './logs'
decoder: lidar
dataset: maicity

criteria:
  sdf_weight: 10000.0
  fs_weight: 1
  eiko_weight: 0.1
  sdf_truncation: 0.30

decoder_specs:
  depth: 2
  width: 256
  in_dim: 16
  skips: []
  embedder: none
  multires: 0

tracker_specs:
  N_rays: 2048
  learning_rate: 0.005
  step_size: 0.2
  max_voxel_hit: 20
  num_iterations: 20

mapper_specs:
  N_rays_each: 2048
  use_local_coord: False
  voxel_size: 0.2
  step_size: 0.5
  window_size: 4
  num_iterations: 20
  max_voxel_hit: 20
  final_iter: True
  mesh_res: 2
  learning_rate_emb: 0.03
  learning_rate_decorder: 0.005
  learning_rate_pose: 0.001
  freeze_frame: 5
  keyframe_gap: 8
  remove_back: False
  key_distance: 12

debug_args:
  verbose: False
  mesh_freq: 100


================================================
FILE: configs/maicity/maicity_00.yaml
================================================
base_config: configs/maicity/maicity.yaml

exp_name: maicity/sqeuence00


data_specs:
  data_path: '/home/pl21n4/dataset/mai_city/bin/sequences/00'
  use_gt: False
  max_depth: 50.0
  min_depth: 1.5

tracker_specs:
  start_frame: 0
  end_frame: 699
  read_offset: 1

================================================
FILE: configs/maicity/maicity_01.yaml
================================================
base_config: configs/maicity/maicity.yaml

exp_name: maicity/sqeuence01


data_specs:
  data_path: '/home/pl21n4/dataset/mai_city/bin/sequences/01'
  use_gt: False
  max_depth: 50.0
  min_depth: 1.5

tracker_specs:
  start_frame: 0
  end_frame: 99
  read_offset: 1

================================================
FILE: configs/ncd/ncd.yaml
================================================
log_dir: './logs'
decoder: lidar
dataset: ncd

criteria:
  sdf_weight: 10000.0
  fs_weight: 1
  eiko_weight: 1.0
  sdf_truncation: 0.30

decoder_specs:
  depth: 2
  width: 256
  in_dim: 16
  skips: []
  embedder: none
  multires: 0

tracker_specs:
  N_rays: 2048
  learning_rate: 0.04
  step_size: 0.1
  max_voxel_hit: 20
  num_iterations: 30

mapper_specs:
  N_rays_each: 2048
  use_local_coord: False
  voxel_size: 0.2
  step_size: 0.2
  window_size: 5
  num_iterations: 15
  max_voxel_hit: 20
  final_iter: True
  mesh_res: 2
  learning_rate_emb: 0.002
  learning_rate_decorder: 0.005
  learning_rate_pose: 0.001
  freeze_frame: 20
  keyframe_gap: 8
  remove_back: False
  key_distance: 20

debug_args:
  verbose: False
  mesh_freq: 500


================================================
FILE: configs/ncd/ncd_quad.yaml
================================================
base_config: configs/ncd/ncd.yaml

exp_name: ncd/quad



data_specs:
  data_path: '/home/pl21n4/dataset/ncd_example/quad'
  use_gt: False
  max_depth: 50
  min_depth: 1.5

tracker_specs:
  start_frame: 0
  end_frame: -1
  read_offset: 5



================================================
FILE: demo/parser.py
================================================
import yaml
import argparse

class ArgumentParserX(argparse.ArgumentParser):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_argument("config", type=str)

    def parse_args(self, args=None, namespace=None):
        _args = self.parse_known_args(args, namespace)[0]
        file_args = argparse.Namespace()
        file_args = self.parse_config_yaml(_args.config, file_args)
        file_args = self.convert_to_namespace(file_args)
        for ckey, cvalue in file_args.__dict__.items():
            try:
                self.add_argument('--' + ckey, type=type(cvalue),
                                  default=cvalue, required=False)
            except argparse.ArgumentError:
                continue
        _args = super().parse_args(args, namespace)
        return _args

    def parse_config_yaml(self, yaml_path, args=None):

        with open(yaml_path, 'r') as f:
            configs = yaml.load(f, Loader=yaml.FullLoader)

        if configs is not None:
            base_config = configs.get('base_config')
            if base_config is not None:
                base_config = self.parse_config_yaml(configs["base_config"])
                if base_config is not None:
                    configs = self.update_recursive(base_config, configs)
                else:
                    raise FileNotFoundError("base_config specified but not found!")

        return configs

    def convert_to_namespace(self, dict_in, args=None):
        if args is None:
            args = argparse.Namespace()
        for ckey, cvalue in dict_in.items():
            if ckey not in args.__dict__.keys():
                args.__dict__[ckey] = cvalue

        return args

    def update_recursive(self, dict1, dict2):
        for k, v in dict2.items():
            if k not in dict1:
                dict1[k] = dict()
            if isinstance(v, dict):
                self.update_recursive(dict1[k], v)
            else:
                dict1[k] = v
        return dict1

def get_parser():
    parser = ArgumentParserX()
    parser.add_argument("--resume", default=None, type=str)
    parser.add_argument("--debug", action='store_true')
    return parser

if __name__ == '__main__':
    args = ArgumentParserX()
    print(args.parse_args())


================================================
FILE: demo/run.py
================================================
import os  # noqa
import sys  # noqa
sys.path.insert(0, os.path.abspath('src')) # noqa
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import random
from parser import get_parser
import numpy as np
import torch
from nerfloam import nerfloam
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

if __name__ == '__main__':
    args = get_parser().parse_args()
    if hasattr(args, 'seeding'):
        setup_seed(args.seeding)
    else:
        setup_seed(777)

    slam = nerfloam(args)
    slam.start()
    slam.wait_child_processes()

================================================
FILE: install.sh
================================================
#!/bin/bash

cd third_party/marching_cubes
python setup.py install

cd ../sparse_octree
python setup.py install

cd ../sparse_voxels
python setup.py install

================================================
FILE: requirements.txt
================================================
matplotlib
open3d
opencv-python
PyYAML
scikit-image
tqdm
trimesh

================================================
FILE: src/criterion.py
================================================
import torch
import torch.nn as nn
from torch.autograd import grad


class Criterion(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.args = args
        self.eiko_weight = args.criteria["eiko_weight"]
        self.sdf_weight = args.criteria["sdf_weight"]
        self.fs_weight = args.criteria["fs_weight"]
        self.truncation = args.criteria["sdf_truncation"]
        self.max_dpeth = args.data_specs["max_depth"]

    def forward(self, outputs, obs, pointsCos, use_color_loss=True,
                use_depth_loss=True, compute_sdf_loss=True,
                weight_depth_loss=False, compute_eikonal_loss=False):

        points = obs
        loss = 0
        loss_dict = {}

        # pred_depth = outputs["depth"]
        pred_sdf = outputs["sdf"]
        z_vals = outputs["z_vals"]
        ray_mask = outputs["ray_mask"]
        valid_mask = outputs["valid_mask"]
        sampled_xyz = outputs["sampled_xyz"]
        gt_points = points[ray_mask]
        pointsCos = pointsCos[ray_mask]
        gt_distance = torch.norm(gt_points, 2, -1)

        gt_distance = gt_distance * pointsCos.view(-1)
        z_vals = z_vals * pointsCos.view(-1, 1)

        if compute_sdf_loss:
            fs_loss, sdf_loss, eikonal_loss = self.get_sdf_loss(
                z_vals, gt_distance, pred_sdf,
                truncation=self.truncation,
                loss_type='l2',
                valid_mask=valid_mask,
                compute_eikonal_loss=compute_eikonal_loss,
                points=sampled_xyz if compute_eikonal_loss else None
            )
            loss += self.fs_weight * fs_loss
            loss += self.sdf_weight * sdf_loss
            # loss += self.bs_weight * back_loss
            loss_dict["fs_loss"] = fs_loss.item()
            # loss_dict["bs_loss"] = back_loss.item()
            loss_dict["sdf_loss"] = sdf_loss.item()
            if compute_eikonal_loss:
                loss += self.eiko_weight * eikonal_loss
                loss_dict["eiko_loss"] = eikonal_loss.item()
        loss_dict["loss"] = loss.item()
        # print(loss_dict)
        return loss, loss_dict

    def compute_loss(self, x, y, mask=None, loss_type="l2"):
        if mask is None:
            mask = torch.ones_like(x).bool()
        if loss_type == "l1":
            return torch.mean(torch.abs(x - y)[mask])
        elif loss_type == "l2":
            return torch.mean(torch.square(x - y)[mask])

    def get_masks(self, z_vals, depth, epsilon):
        front_mask = torch.where(
            z_vals < (depth - epsilon),
            torch.ones_like(z_vals),
            torch.zeros_like(z_vals),
        )
        back_mask = torch.where(
            z_vals > (depth + epsilon),
            torch.ones_like(z_vals),
            torch.zeros_like(z_vals),
        )
        depth_mask = torch.where(
            (depth > 0.0) & (depth < self.max_dpeth), torch.ones_like(
                depth), torch.zeros_like(depth)
        )
        sdf_mask = (1.0 - front_mask) * (1.0 - back_mask) * depth_mask

        num_fs_samples = torch.count_nonzero(front_mask).float()
        num_sdf_samples = torch.count_nonzero(sdf_mask).float()
        num_samples = num_sdf_samples + num_fs_samples
        fs_weight = 1.0 - num_fs_samples / num_samples
        sdf_weight = 1.0 - num_sdf_samples / num_samples

        return front_mask, sdf_mask, fs_weight, sdf_weight

    def get_sdf_loss(self, z_vals, depth, predicted_sdf, truncation, valid_mask, loss_type="l2", compute_eikonal_loss=False, points=None):

        front_mask, sdf_mask, fs_weight, sdf_weight = self.get_masks(
            z_vals, depth.unsqueeze(-1).expand(*z_vals.shape), truncation
        )
        fs_loss = (self.compute_loss(predicted_sdf * front_mask * valid_mask, torch.ones_like(
            predicted_sdf) * front_mask, loss_type=loss_type,) * fs_weight)
        sdf_loss = (self.compute_loss((z_vals + predicted_sdf * truncation) * sdf_mask * valid_mask,
                    depth.unsqueeze(-1).expand(*z_vals.shape) * sdf_mask, loss_type=loss_type,) * sdf_weight)
        # back_loss = (self.compute_loss(predicted_sdf * back_mask, -torch.ones_like(
        #     predicted_sdf) * back_mask, loss_type=loss_type,) * back_weight)
        eikonal_loss = None
        if compute_eikonal_loss:
            sdf = (predicted_sdf*sdf_mask*truncation)
            sdf = sdf[valid_mask]
            d_points = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
            sdf_grad = grad(outputs=sdf,
                            inputs=points,
                            grad_outputs=d_points,
                            retain_graph=True,
                            only_inputs=True)[0]
            eikonal_loss = self.compute_loss(sdf_grad[0].norm(2, -1), 1.0, loss_type=loss_type,)

        return fs_loss, sdf_loss, eikonal_loss


================================================
FILE: src/dataset/kitti.py
================================================
import os.path as osp

import numpy as np
import torch
from glob import glob
from torch.utils.data import Dataset
import sys
from scipy.spatial import cKDTree

patchwork_module_path ="/home/pl21n4/Programmes/patchwork-plusplus/build/python_wrapper"
sys.path.insert(0, patchwork_module_path)
import pypatchworkpp
params = pypatchworkpp.Parameters()
# params.verbose = True

PatchworkPLUSPLUS = pypatchworkpp.patchworkpp(params)


class DataLoader(Dataset):
    def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1) -> None:
        self.data_path = data_path
        self.num_bin = len(glob(osp.join(self.data_path, "velodyne/*.bin")))
        self.use_gt = use_gt
        self.max_depth = max_depth
        self.min_depth = min_depth
        self.gt_pose = self.load_gt_pose() if use_gt else None

    def get_init_pose(self, frame):
        if self.gt_pose is not None:
            return np.concatenate((self.gt_pose[frame], [0, 0, 0, 1])
                                  ).reshape(4, 4)
        else:
            return np.eye(4)

    def load_gt_pose(self):
        gt_file = osp.join(self.data_path, "poses_lidar.txt")
        gt_pose = np.loadtxt(gt_file)
        return gt_pose

    def load_points(self, index):
        remove_abnormal_z = True
        path = osp.join(self.data_path, "velodyne/{:06d}.bin".format(index))
        points = np.fromfile(path, dtype=np.float32, count=-1).reshape([-1, 4])
        if remove_abnormal_z:
            points = points[points[:, 2] > -3.0]
        points_norm = np.linalg.norm(points[:, :3], axis=-1)
        point_mask = True
        if self.max_depth != -1:
            point_mask = (points_norm < self.max_depth) & point_mask
        if self.min_depth != -1:
            point_mask = (points_norm > self.min_depth) & point_mask

        if isinstance(point_mask, np.ndarray):
            points = points[point_mask]

        PatchworkPLUSPLUS.estimateGround(points)
        ground = PatchworkPLUSPLUS.getGround()
        nonground = PatchworkPLUSPLUS.getNonground()
        Patchcenters = PatchworkPLUSPLUS.getCenters()
        normals = PatchworkPLUSPLUS.getNormals()
        T = cKDTree(Patchcenters)
        _, index = T.query(ground)
        if True:
            groundcos = np.abs(np.sum(normals[index] * ground, axis=-1)/np.linalg.norm(ground, axis=-1))
        else:
            groundcos = np.ones(ground.shape[0])
        points = np.concatenate((ground, nonground), axis=0)
        pointcos = np.concatenate((groundcos, np.ones(nonground.shape[0])), axis=0)

        return points, pointcos

    def __len__(self):
        return self.num_bin

    def __getitem__(self, index):
        points, pointcos = self.load_points(index)
        points = torch.from_numpy(points).float()
        pointcos = torch.from_numpy(pointcos).float()
        pose = np.concatenate((self.gt_pose[index], [0, 0, 0, 1])
                              ).reshape(4, 4) if self.use_gt else None
        return index, points, pointcos, pose


if __name__ == "__main__":
    path = "/home/pl21n4/dataset/kitti/dataset/sequences/00/"
    loader = DataLoader(path)
    for data in loader:
        # __getitem__ returns four values: index, points, pointcos, pose
        index, points, pointcos, pose = data
        print("current index ", index)
        print("first 10 points:\n", points[:10])
        if index > 10:
            break


================================================
FILE: src/dataset/maicity.py
================================================
import os.path as osp

import numpy as np
import torch
from glob import glob
from torch.utils.data import Dataset
import sys
from scipy.spatial import cKDTree

patchwork_module_path ="/home/pl21n4/Programmes/patchwork-plusplus/build/python_wrapper"
sys.path.insert(0, patchwork_module_path)
import pypatchworkpp

params = pypatchworkpp.Parameters()
# params.verbose = True

PatchworkPLUSPLUS = pypatchworkpp.patchworkpp(params)


class DataLoader(Dataset):
    def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1) -> None:
        self.data_path = data_path
        self.num_bin = len(glob(osp.join(self.data_path, "velodyne/*.bin")))
        self.use_gt = use_gt
        self.max_depth = max_depth
        self.min_depth = min_depth
        self.gt_pose = self.load_gt_pose() if use_gt else None

    def get_init_pose(self, frame):
        if self.gt_pose is not None:
            return np.concatenate((self.gt_pose[frame], [0, 0, 0, 1])
                                  ).reshape(4, 4)
        else:
            return np.eye(4)

    def load_gt_pose(self):
        gt_file = osp.join(self.data_path, "poses.txt")
        gt_pose = np.loadtxt(gt_file)
        return gt_pose

    def load_points(self, index):
        path = osp.join(self.data_path, "velodyne/{:05d}.bin".format(index))
        points = np.fromfile(path, dtype=np.float32, count=-1).reshape([-1, 4])
        points_norm = np.linalg.norm(points[:, :3], axis=-1)
        point_mask = True
        if self.max_depth != -1:
            point_mask = (points_norm < self.max_depth) & point_mask
        if self.min_depth != -1:
            point_mask = (points_norm > self.min_depth) & point_mask
        if isinstance(point_mask, np.ndarray):
            points = points[point_mask]

        PatchworkPLUSPLUS.estimateGround(points)
        ground = PatchworkPLUSPLUS.getGround()
        nonground = PatchworkPLUSPLUS.getNonground()
        Patchcenters = PatchworkPLUSPLUS.getCenters()
        normals = PatchworkPLUSPLUS.getNormals()
        # nearest Patchwork++ patch centre for each ground point (renamed to
        # avoid shadowing the frame index argument)
        tree = cKDTree(Patchcenters)
        _, nn_idx = tree.query(ground)
        # weight each ground point by |cos| of the angle between its ray and
        # the patch normal; non-ground points keep weight 1.0 below
        groundcos = np.abs(np.sum(normals[nn_idx] * ground, axis=-1)
                           / np.linalg.norm(ground, axis=-1))
        points = np.concatenate((ground, nonground), axis=0)
        pointcos = np.concatenate((groundcos, np.ones(nonground.shape[0])), axis=0)

        return points, pointcos

    def __len__(self):
        return self.num_bin

    def __getitem__(self, index):
        points, pointcos = self.load_points(index)
        points = torch.from_numpy(points).float()
        pointcos = torch.from_numpy(pointcos).float()
        pose = np.concatenate((self.gt_pose[index], [0, 0, 0, 1])
                              ).reshape(4, 4) if self.use_gt else None
        return index, points, pointcos, pose


if __name__ == "__main__":
    path = "/home/pl21n4/dataset/kitti/dataset/sequences/00/"
    loader = DataLoader(path)
    for data in loader:
        # __getitem__ yields four values: (index, points, pointcos, pose)
        index, points, pointcos, pose = data
        print("current index ", index)
        print("first 10 points:\n", points[:10])
        if index > 10:
            break
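# A hedged sketch of the ground-weighting step in load_points() above: each
# ground point is weighted by |cos| of the angle between its ray and the
# normal of the nearest Patchwork++ patch (unit normals assumed).
# `nearest_normal_weights` is an illustrative helper name, not part of the repo.
import numpy as np
from scipy.spatial import cKDTree

def nearest_normal_weights(ground, patch_centers, patch_normals):
    _, nn = cKDTree(patch_centers).query(ground)           # nearest patch per point
    dots = np.sum(patch_normals[nn] * ground, axis=-1)     # n . p
    return np.abs(dots) / np.linalg.norm(ground, axis=-1)  # |cos(incidence)|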


================================================
FILE: src/dataset/ncd.py
================================================
import os.path as osp

import numpy as np
import torch
import open3d as o3d
from glob import glob
from torch.utils.data import Dataset
import sys
from scipy.spatial import cKDTree

patchwork_module_path ="/home/pl21n4/Programmes/patchwork-plusplus/build/python_wrapper"
sys.path.insert(0, patchwork_module_path)
import pypatchworkpp
params = pypatchworkpp.Parameters()
params.enable_RNR = False
# params.verbose = True

PatchworkPLUSPLUS = pypatchworkpp.patchworkpp(params)


class DataLoader(Dataset):
    def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1) -> None:
        self.data_path = data_path
        self.num_bin = len(glob(osp.join(self.data_path, "pcd/*.pcd")))
        self.use_gt = use_gt
        self.max_depth = max_depth
        self.min_depth = min_depth
        self.gt_pose = self.load_gt_pose() if use_gt else None

    def get_init_pose(self, frame):
        if self.gt_pose is not None:
            return np.concatenate((self.gt_pose[frame], [0, 0, 0, 1])
                                  ).reshape(4, 4)
        else:
            # hard-coded initial pose for this NCD sequence; append the
            # homogeneous row so the fallback also returns a full 4x4 matrix,
            # matching the ground-truth branch above
            return np.array([[5.925493285036220747e-01, -8.038419275143061649e-01, 5.218676416200035417e-02, -2.422443415414985424e-01],
                             [8.017167514002809803e-01, 5.948020209102693467e-01, 5.882863457495644127e-02,  3.667865561670570873e+00],
                             [-7.832971094540422397e-02, 6.980134849334420320e-03, 9.969030746023688216e-01, 6.809443654823238434e-01],
                             [0.0, 0.0, 0.0, 1.0]])

    def load_gt_pose(self):
        gt_file = osp.join(self.data_path, "poses.txt")
        gt_pose = np.loadtxt(gt_file)
        return gt_pose

    def load_points(self, index):
        path = osp.join(self.data_path, "pcd/{:05d}.pcd".format(index+500))
        pc_load = o3d.io.read_point_cloud(path)
        points = np.asarray(pc_load.points)

        points_norm = np.linalg.norm(points, axis=-1)
        point_mask = True
        if self.max_depth != -1:
            point_mask = (points_norm < self.max_depth) & point_mask
        if self.min_depth != -1:
            point_mask = (points_norm > self.min_depth) & point_mask
        if isinstance(point_mask, np.ndarray):
            points = points[point_mask]

        PatchworkPLUSPLUS.estimateGround(points)
        ground = PatchworkPLUSPLUS.getGround()
        nonground = PatchworkPLUSPLUS.getNonground()
        Patchcenters = PatchworkPLUSPLUS.getCenters()
        normals = PatchworkPLUSPLUS.getNormals()
        # nearest Patchwork++ patch centre for each ground point (renamed to
        # avoid shadowing the frame index argument)
        tree = cKDTree(Patchcenters)
        _, nn_idx = tree.query(ground)
        # weight each ground point by |cos| of the angle between its ray and
        # the patch normal; non-ground points keep weight 1.0 below
        groundcos = np.abs(np.sum(normals[nn_idx] * ground, axis=-1)
                           / np.linalg.norm(ground, axis=-1))
        points = np.concatenate((ground, nonground), axis=0)
        pointcos = np.concatenate((groundcos, np.ones(nonground.shape[0])), axis=0)

        return points, pointcos

    def __len__(self):
        return self.num_bin

    def __getitem__(self, index):
        points, pointcos = self.load_points(index)
        points = torch.from_numpy(points).float()
        pointcos = torch.from_numpy(pointcos).float()
        pose = np.concatenate((self.gt_pose[index], [0, 0, 0, 1])
                              ).reshape(4, 4) if self.use_gt else None
        return index, points, pointcos, pose


if __name__ == "__main__":
    path = "/home/pl21n4/dataset/kitti/dataset/sequences/00/"
    loader = DataLoader(path)
    for data in loader:
        # __getitem__ yields four values: (index, points, pointcos, pose)
        index, points, pointcos, pose = data
        print("current index ", index)
        print("first 10 points:\n", points[:10])
        if index > 10:
            break
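# An illustrative standalone version (not repo code) of the range filter used
# in load_points() above: keep points whose Euclidean range lies inside
# (min_depth, max_depth), with -1 meaning "no bound" as in these loaders.
import numpy as np

def range_filter(points, min_depth=-1, max_depth=-1):
    norms = np.linalg.norm(points, axis=-1)
    mask = np.ones(len(points), dtype=bool)
    if max_depth != -1:
        mask &= norms < max_depth
    if min_depth != -1:
        mask &= norms > min_depth
    return points[mask]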


================================================
FILE: src/lidarFrame.py
================================================
import torch
import torch.nn as nn
import numpy as np
from se3pose import OptimizablePose
from utils.sample_util import *
import random


class LidarFrame(nn.Module):
    def __init__(self, index, points, pointsCos, pose=None, new_keyframe=False) -> None:
        super().__init__()
        self.index = index
        self.num_point = len(points)
        self.points = points
        self.pointsCos = pointsCos
        if (not new_keyframe) and (pose is not None):
            # TODO: fix this offset
            pose[:3, 3] += 2000
            pose = torch.tensor(pose, requires_grad=True, dtype=torch.float32)
            self.pose = OptimizablePose.from_matrix(pose)
        elif new_keyframe:
            self.pose = pose
        self.rays_d = self.get_rays()
        self.rel_pose = None

    def get_pose(self):
        return self.pose.matrix()

    def get_translation(self):
        return self.pose.translation()

    def get_rotation(self):
        return self.pose.rotation()

    def get_points(self):
        return self.points

    def get_pointsCos(self):
        return self.pointsCos

    def set_rel_pose(self, rel_pose):
        self.rel_pose = rel_pose

    def get_rel_pose(self):
        return self.rel_pose

    @torch.no_grad()
    def get_rays(self):
        self.rays_norm = (torch.norm(self.points, 2, -1, keepdim=True)+1e-8)
        rays_d = self.points / self.rays_norm
        # add a singleton dim to keep shape consistency downstream
        return rays_d.unsqueeze(1).float()

    @torch.no_grad()
    def sample_rays(self, N_rays, track=False):
        self.sample_mask = sample_rays(
            torch.ones((self.num_point, 1))[None, ...], N_rays)[0, ...]
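# A standalone numpy sketch mirroring the torch logic in get_rays() above
# (an assumed equivalent for illustration, not repo code): lidar points are
# normalized to unit ray directions, with a small epsilon guarding against
# zero-range points.
import numpy as np

def points_to_ray_dirs(points):
    norms = np.linalg.norm(points, axis=-1, keepdims=True) + 1e-8
    return points / norms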





================================================
FILE: src/loggers.py
================================================
import os
import os.path as osp
import pickle
from datetime import datetime

import cv2
import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import torch
import yaml


class BasicLogger:
    def __init__(self, args) -> None:
        self.args = args
        self.log_dir = osp.join(
            args.log_dir, args.exp_name, self.get_random_time_str())
        self.img_dir = osp.join(self.log_dir, "imgs")
        self.mesh_dir = osp.join(self.log_dir, "mesh")
        self.ckpt_dir = osp.join(self.log_dir, "ckpt")
        self.backup_dir = osp.join(self.log_dir, "bak")
        self.misc_dir = osp.join(self.log_dir, "misc")

        os.makedirs(self.img_dir)
        os.makedirs(self.ckpt_dir)
        os.makedirs(self.mesh_dir)
        os.makedirs(self.misc_dir)
        os.makedirs(self.backup_dir)

        self.log_config(args)

    def get_random_time_str(self):
        return datetime.strftime(datetime.now(), "%Y-%m-%d-%H-%M-%S")

    def log_ckpt(self, mapper):
        print("******* saving *******")
        decoder_state = {f: v.cpu()
                         for f, v in mapper.decoder.state_dict().items()}
        map_state = {f: v.cpu() for f, v in mapper.map_states.items()}
        embeddings = mapper.dynamic_embeddings.cpu()
        svo = mapper.svo
        torch.save({
            "decoder_state": decoder_state,
            # "map_state": map_state,
            "embeddings": embeddings,
            "svo": svo
        },
            os.path.join(self.ckpt_dir, "final_ckpt.pth"))
        print("******* finish saving *******")

    def log_config(self, config):
        out_path = osp.join(self.backup_dir, "config.yaml")
        yaml.dump(config, open(out_path, 'w'))

    def log_mesh(self, mesh, name="final_mesh.ply"):
        out_path = osp.join(self.mesh_dir, name)
        o3d.io.write_triangle_mesh(out_path, mesh)

    def log_point_cloud(self, pcd, name="final_points.ply"):
        out_path = osp.join(self.mesh_dir, name)
        o3d.io.write_point_cloud(out_path, pcd)

    def log_numpy_data(self, data, name, ind=None):
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        if ind is not None:
            np.save(osp.join(self.misc_dir, "{}-{:05d}.npy".format(name, ind)), data)
        else:
            np.save(osp.join(self.misc_dir, f"{name}.npy"), data)
            self.npy2txt(osp.join(self.misc_dir, f"{name}.npy"), osp.join(self.misc_dir, f"{name}.txt"))

    def log_debug_data(self, data, idx):
        with open(os.path.join(self.misc_dir, f"scene_data_{idx}.pkl"), 'wb') as f:
            pickle.dump(data, f)

    def log_raw_image(self, ind, rgb, depth):
        if isinstance(rgb, torch.Tensor):
            rgb = rgb.detach().cpu().numpy()
        if isinstance(depth, torch.Tensor):
            depth = depth.detach().cpu().numpy()
        rgb = cv2.cvtColor(rgb*255, cv2.COLOR_RGB2BGR)
        cv2.imwrite(osp.join(self.img_dir, "{:05d}.jpg".format(
            ind)), (rgb).astype(np.uint8))
        cv2.imwrite(osp.join(self.img_dir, "{:05d}.png".format(
            ind)), (depth*5000).astype(np.uint16))

    def log_images(self, ind, gt_rgb, gt_depth, rgb, depth):
        gt_depth_np = gt_depth.detach().cpu().numpy()
        gt_color_np = gt_rgb.detach().cpu().numpy()
        depth_np = depth.squeeze().detach().cpu().numpy()
        color_np = rgb.detach().cpu().numpy()

        h, w = depth_np.shape
        gt_depth_np = cv2.resize(
            gt_depth_np, (w, h), interpolation=cv2.INTER_NEAREST)
        gt_color_np = cv2.resize(
            gt_color_np, (w, h), interpolation=cv2.INTER_AREA)

        depth_residual = np.abs(gt_depth_np - depth_np)
        depth_residual[gt_depth_np == 0.0] = 0.0
        color_residual = np.abs(gt_color_np - color_np)
        color_residual[gt_depth_np == 0.0] = 0.0

        fig, axs = plt.subplots(2, 3)
        fig.tight_layout()
        max_depth = np.max(gt_depth_np)
        axs[0, 0].imshow(gt_depth_np, cmap="plasma",
                         vmin=0, vmax=max_depth)
        axs[0, 0].set_title('Input Depth')
        axs[0, 0].set_xticks([])
        axs[0, 0].set_yticks([])
        axs[0, 1].imshow(depth_np, cmap="plasma",
                         vmin=0, vmax=max_depth)
        axs[0, 1].set_title('Generated Depth')
        axs[0, 1].set_xticks([])
        axs[0, 1].set_yticks([])
        axs[0, 2].imshow(depth_residual, cmap="plasma",
                         vmin=0, vmax=max_depth)
        axs[0, 2].set_title('Depth Residual')
        axs[0, 2].set_xticks([])
        axs[0, 2].set_yticks([])
        gt_color_np = np.clip(gt_color_np, 0, 1)
        color_np = np.clip(color_np, 0, 1)
        color_residual = np.clip(color_residual, 0, 1)
        axs[1, 0].imshow(gt_color_np, cmap="plasma")
        axs[1, 0].set_title('Input RGB')
        axs[1, 0].set_xticks([])
        axs[1, 0].set_yticks([])
        axs[1, 1].imshow(color_np, cmap="plasma")
        axs[1, 1].set_title('Generated RGB')
        axs[1, 1].set_xticks([])
        axs[1, 1].set_yticks([])
        axs[1, 2].imshow(color_residual, cmap="plasma")
        axs[1, 2].set_title('RGB Residual')
        axs[1, 2].set_xticks([])
        axs[1, 2].set_yticks([])
        plt.subplots_adjust(wspace=0, hspace=0)
        plt.savefig(osp.join(self.img_dir, "{:05d}.jpg".format(
            ind)), bbox_inches='tight', pad_inches=0.2)
        plt.clf()
        plt.close()

    def npy2txt(self, input_path, output_path):
        # write (N, 4, 4) poses in KITTI trajectory format: drop the
        # homogeneous bottom row and emit the remaining 3x4 = 12 values,
        # space-separated, one pose per line
        poses = np.load(input_path)
        with open(output_path, mode='w') as w:
            for pose in poses:
                values = pose[:3, :].reshape(-1)
                w.write(" ".join(str(v) for v in values) + "\n")
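# A hedged sketch of the trajectory format npy2txt produces: each (4, 4) pose
# contributes one line of the 12 values of its top three rows (KITTI
# convention; the homogeneous bottom row is implicit).
# `poses_to_kitti_lines` is an illustrative helper name, not part of the logger.
import numpy as np

def poses_to_kitti_lines(poses):
    # poses: (N, 4, 4) -> N strings of 12 space-separated values
    return [" ".join(str(v) for v in p[:3, :].reshape(-1)) for p in poses]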


================================================
FILE: src/mapping.py
================================================
from copy import deepcopy
import random
from time import sleep
import numpy as np
from tqdm import tqdm
import torch

from criterion import Criterion
from loggers import BasicLogger
from utils.import_util import get_decoder, get_property
from variations.render_helpers import bundle_adjust_frames
from utils.mesh_util import MeshExtractor
from utils.profile_util import Profiler
from lidarFrame import LidarFrame
import torch.nn.functional as F
from pathlib import Path
import open3d as o3d

torch.classes.load_library(
    "/home/pl21n4/Programmes/Vox-Fusion/third_party/sparse_octree/build/lib.linux-x86_64-cpython-38/svo.cpython-38-x86_64-linux-gnu.so")


def get_network_size(net):
    size = 0
    for param in net.parameters():
        size += param.element_size() * param.numel()
    return size / 1024 / 1024


class Mapping:
    def __init__(self, args, logger: BasicLogger):
        super().__init__()
        self.args = args
        self.logger = logger
        self.decoder = get_decoder(args).cuda()
        print(self.decoder)
        self.loss_criteria = Criterion(args)
        self.keyframe_graph = []
        self.initialized = False

        mapper_specs = args.mapper_specs

        # optional args
        self.ckpt_freq = get_property(args, "ckpt_freq", -1)
        self.final_iter = get_property(mapper_specs, "final_iter", False)
        self.mesh_res = get_property(mapper_specs, "mesh_res", 8)
        self.save_data_freq = get_property(
            args.debug_args, "save_data_freq", 0)

        # required args
        self.voxel_size = mapper_specs["voxel_size"]
        self.window_size = mapper_specs["window_size"]
        self.num_iterations = mapper_specs["num_iterations"]
        self.n_rays = mapper_specs["N_rays_each"]
        self.sdf_truncation = args.criteria["sdf_truncation"]
        self.max_voxel_hit = mapper_specs["max_voxel_hit"]
        self.step_size = mapper_specs["step_size"]
        self.learning_rate_emb = mapper_specs["learning_rate_emb"]
        self.learning_rate_decorder = mapper_specs["learning_rate_decorder"]
        self.learning_rate_pose = mapper_specs["learning_rate_pose"]
        self.step_size = self.step_size * self.voxel_size
        self.max_distance = args.data_specs["max_depth"]
        self.freeze_frame = mapper_specs["freeze_frame"]
        self.keyframe_gap = mapper_specs["keyframe_gap"]
        self.remove_back = mapper_specs["remove_back"]
        self.key_distance = mapper_specs["key_distance"]
        
        embed_dim = args.decoder_specs["in_dim"]
        use_local_coord = mapper_specs["use_local_coord"]
        self.embed_dim = embed_dim - 3 if use_local_coord else embed_dim
        #num_embeddings = mapper_specs["num_embeddings"]
        self.mesh_freq = args.debug_args["mesh_freq"]
        self.mesher = MeshExtractor(args)


        self.voxel_id2embedding_id = -torch.ones((int(2e9), 1), dtype=torch.int)
        self.embeds_exist_search = dict()
        self.current_num_embeds = 0
        self.dynamic_embeddings = None

        self.svo = torch.classes.svo.Octree()
        self.svo.init(256*256*4, embed_dim, self.voxel_size)

        self.frame_poses = []
        self.depth_maps = []
        self.last_tracked_frame_id = 0
        self.final_poses=[]
        
        verbose = get_property(args.debug_args, "verbose", False)
        self.profiler = Profiler(verbose=verbose)
        self.profiler.enable()
        
    def spin(self, share_data, kf_buffer):
        print("mapping process started!!!!!!!!!")
        while True:
            torch.cuda.empty_cache()
            if not kf_buffer.empty():
                tracked_frame = kf_buffer.get()
                # self.create_voxels(tracked_frame)
                if not self.initialized:
                    self.first_frame_id = tracked_frame.index
                    if self.mesher is not None:
                        self.mesher.rays_d = tracked_frame.get_rays()
                    self.create_voxels(tracked_frame)
                    self.insert_keyframe(tracked_frame)
                    while kf_buffer.empty():
                        self.do_mapping(share_data, tracked_frame, selection_method='current')
                    self.initialized = True
                else:
                    if self.remove_back:
                        tracked_frame = self.remove_back_points(tracked_frame)
                    self.do_mapping(share_data, tracked_frame)
                    self.create_voxels(tracked_frame)
                    if (torch.norm(tracked_frame.pose.translation().cpu()
                        - self.current_keyframe.pose.translation().cpu())) > self.keyframe_gap:
                        self.insert_keyframe(tracked_frame)
                        print(
                            f"********** current num kfs: { len(self.keyframe_graph) } **********")

                # self.create_voxels(tracked_frame)
                tracked_pose = tracked_frame.get_pose().detach()
                ref_pose = self.current_keyframe.get_pose().detach()
                rel_pose = torch.linalg.inv(ref_pose) @ tracked_pose
                self.frame_poses += [(len(self.keyframe_graph) -
                                      1, rel_pose.cpu())]

                if self.mesh_freq > 0 and (tracked_frame.index) % self.mesh_freq == 0:
                    if self.final_iter and len(self.keyframe_graph) > 20:
                        print(f"********** post-processing steps **********")
                        #self.num_iterations = 1
                        final_num_iter = len(self.keyframe_graph) + 1
                        progress_bar = tqdm(
                            range(0, final_num_iter), position=0)
                        progress_bar.set_description(" post-processing steps")
                        for iter in progress_bar:
                            #tracked_frame=self.keyframe_graph[iter//self.window_size]
                            self.do_mapping(share_data, tracked_frame=None,
                                            update_pose=False, update_decoder=False, selection_method='random')


                    self.logger.log_mesh(self.extract_mesh(res=self.mesh_res, clean_mesh=False),name=f"mesh_{tracked_frame.index:05d}.ply")
                    pose = self.get_updated_poses()
                    self.logger.log_numpy_data(np.asarray(pose), f"frame_poses_{tracked_frame.index:05d}")

                    if self.final_iter and len(self.keyframe_graph) > 20:
                        self.keyframe_graph = []
                        self.keyframe_graph += [self.current_keyframe]
                # LidarFrame identifies frames by .index (it has no .stamp)
                if self.save_data_freq > 0 and (tracked_frame.index + 1) % self.save_data_freq == 0:
                    self.save_debug_data(tracked_frame)
            elif share_data.stop_mapping:
                break
        print("******* extracting mesh without replay *******")
        self.logger.log_mesh(self.extract_mesh(res=self.mesh_res, clean_mesh=False), name="final_mesh_noreplay.ply")
        if self.final_iter:
            print(f"********** post-processing steps **********")
            #self.num_iterations = 1
            final_num_iter = len(self.keyframe_graph) + 1
            progress_bar = tqdm(
                range(0, final_num_iter), position=0)
            progress_bar.set_description(" post-processing steps")
            for iter in progress_bar:
                tracked_frame=self.keyframe_graph[iter//self.window_size]
                self.do_mapping(share_data, tracked_frame=None,
                                update_pose=False, update_decoder=False, selection_method='random')

        print("******* extracting final mesh *******")
        pose = self.get_updated_poses()
        self.logger.log_numpy_data(np.asarray(pose), "frame_poses")
        self.logger.log_mesh(self.extract_mesh(res=self.mesh_res, clean_mesh=False))
        print("******* mapping process died *******")

    def do_mapping(self, share_data, tracked_frame=None,
                   update_pose=True, update_decoder=True, selection_method = 'current'):
        self.profiler.tick("do_mapping")
        self.decoder.train()
        optimize_targets = self.select_optimize_targets(tracked_frame, selection_method=selection_method)
        torch.cuda.empty_cache()
        self.profiler.tick("bundle_adjust_frames")
        bundle_adjust_frames(
            optimize_targets,
            self.dynamic_embeddings,
            self.map_states,
            self.decoder,
            self.loss_criteria,
            self.voxel_size,
            self.step_size,
            self.n_rays * 2  if selection_method=='random' else self.n_rays,
            self.num_iterations,
            self.sdf_truncation,
            self.max_voxel_hit,
            self.max_distance,
            learning_rate=[self.learning_rate_emb, 
                           self.learning_rate_decorder, 
                           self.learning_rate_pose],
            update_pose=update_pose,
            update_decoder=update_decoder if tracked_frame is None or (tracked_frame.index - self.first_frame_id) < self.freeze_frame else False,
            profiler=self.profiler
        )
        self.profiler.tok("bundle_adjust_frames")
        # optimize_targets = [f.cpu() for f in optimize_targets]
        self.update_share_data(share_data)
        self.profiler.tok("do_mapping")
        # sleep(0.01)

    def select_optimize_targets(self, tracked_frame=None, selection_method='previous'):
        # TODO: better ways
        targets = []
        if selection_method == 'current':
            if tracked_frame is None:
                raise ValueError(
                    "selection_method 'current' requires a tracked frame")
            else:
                return [tracked_frame]
        if len(self.keyframe_graph) <= self.window_size:
            targets = self.keyframe_graph[:]
        elif selection_method == 'random':
            targets = random.sample(self.keyframe_graph, self.window_size)
        elif selection_method == 'previous':
            targets = self.keyframe_graph[-self.window_size:]
        elif selection_method == 'overlap':
            raise NotImplementedError(
                f"selection method {selection_method} not implemented")

        if tracked_frame is not None and tracked_frame != self.current_keyframe:
            targets += [tracked_frame]
        return targets

    def update_share_data(self, share_data, frameid=None):
        share_data.decoder = deepcopy(self.decoder)
        tmp_states = {}
        for k, v in self.map_states.items():
            tmp_states[k] = v.detach().cpu()
        share_data.states = tmp_states
        # self.last_tracked_frame_id = frameid

    def remove_back_points(self, frame):
        rel_pose = frame.get_rel_pose()
        points = frame.get_points()
        points_norm = torch.norm(points, 2, -1)
        points_xy = points[:, :2]
        if rel_pose is None:
            x = 1
            y = 0
        else:
            x = rel_pose[0, 3]
            y = rel_pose[1, 3]
        rel_xy = torch.ones((1, 2))
        rel_xy[0, 0] = x
        rel_xy[0, 1] = y
        point_cos = torch.sum(-points_xy * rel_xy, dim=-1)/(
            torch.norm(points_xy, 2, -1)*(torch.norm(rel_xy, 2, -1)))
        remove_index = ((point_cos >= 0.7) & (points_norm > self.key_distance))
        new_points = frame.points[~remove_index]
        new_cos = frame.get_pointsCos()[~remove_index]
        return LidarFrame(frame.index, new_points, new_cos,
                          frame.pose, new_keyframe=True)

    def frame_maxdistance_change(self, frame, distance):
        # kf check
        valid_distance = distance + 0.5
        new_keyframe_rays_norm = frame.rays_norm.reshape(-1)
        new_keyframe_points = frame.points[new_keyframe_rays_norm <= valid_distance]
        new_keyframe_pointsCos = frame.get_pointsCos()[new_keyframe_rays_norm <= valid_distance]
        return LidarFrame(frame.index, new_keyframe_points, new_keyframe_pointsCos,
                          frame.pose, new_keyframe=True)

    def insert_keyframe(self, frame, valid_distance=-1):
        # kf check
        print("insert keyframe")
        valid_distance = self.key_distance + 0.01
        new_keyframe_rays_norm = frame.rays_norm.reshape(-1)
        mask = (torch.abs(frame.points[:, 0]) < valid_distance) & (torch.abs(frame.points[:, 1])
                                                                   < valid_distance) & (torch.abs(frame.points[:, 2]) < valid_distance)
        new_keyframe_points = frame.points[mask]
        new_keyframe_pointsCos = frame.get_pointsCos()[mask]
        new_keyframe = LidarFrame(frame.index, new_keyframe_points, new_keyframe_pointsCos,
                                  frame.pose, new_keyframe=True)
        if new_keyframe_points.shape[0] < 2*self.n_rays:
            raise ValueError('valid_distance too small')
        self.current_keyframe = new_keyframe
        self.keyframe_graph += [new_keyframe]
        # self.update_grid_features()

    def create_voxels(self, frame):
        points = frame.get_points().cuda()
        pose = frame.get_pose().cuda()
        print("frame id", frame.index+1)
        print("trans ", pose[:3, 3]-2000)
        points = points@pose[:3, :3].transpose(-1, -2) + pose[:3, 3]
        voxels = torch.div(points, self.voxel_size, rounding_mode='floor')
        self.svo.insert(voxels.cpu().int())
        self.update_grid_features()

    @torch.enable_grad()
    def get_embeddings(self, points_idx):

        flatten_idx = points_idx.reshape(-1).long()
        valid_flatten_idx = flatten_idx[flatten_idx.ne(-1)]
        existence = F.embedding(valid_flatten_idx, self.voxel_id2embedding_id)
        torch_add_idx = existence.eq(-1).view(-1)
        torch_add = valid_flatten_idx[torch_add_idx]
        if torch_add.shape[0] == 0:
            return
        start_num = self.current_num_embeds
        end_num = start_num + torch_add.shape[0]
        embeddings_add = torch.zeros((end_num-start_num, self.embed_dim),
                                     dtype=torch.bfloat16)
        # torch.nn.init.normal_(embeddings_add, std=0.01)

        if self.dynamic_embeddings is None:
            embeddings = [embeddings_add]
        else:
            embeddings = [self.dynamic_embeddings.detach().cpu(), embeddings_add]
        embeddings = torch.cat(embeddings, dim=0)
        self.dynamic_embeddings = embeddings.cuda().requires_grad_()

        self.current_num_embeds = end_num
        self.voxel_id2embedding_id[torch_add] = torch.arange(start_num, end_num, dtype=torch.int).view(-1, 1)

    @torch.enable_grad()
    def update_grid_features(self):
        voxels, children, features = self.svo.get_centres_and_children()
        centres = (voxels[:, :3] + voxels[:, -1:] / 2) * self.voxel_size
        children = torch.cat([children, voxels[:, -1:]], -1)

        centres = centres.float()
        children = children.int()

        map_states = {}
        map_states["voxel_vertex_idx"] = features
        centres.requires_grad_()
        map_states["voxel_center_xyz"] = centres
        map_states["voxel_structure"] = children
        self.profiler.tick("Creating embedding")
        self.get_embeddings(map_states["voxel_vertex_idx"])
        self.profiler.tok("Creating embedding")
        map_states["voxel_vertex_emb"] = self.dynamic_embeddings
        map_states["voxel_id2embedding_id"] = self.voxel_id2embedding_id

        self.map_states = map_states

    @torch.no_grad()
    def get_updated_poses(self, offset=-2000):
        for i in range(len(self.frame_poses)):
            ref_frame_ind, rel_pose = self.frame_poses[i]
            ref_frame = self.keyframe_graph[ref_frame_ind]
            ref_pose = ref_frame.get_pose().detach().cpu()
            pose = ref_pose @ rel_pose
            pose[:3, 3] += offset
            self.final_poses += [pose.detach().cpu().numpy()]
        self.frame_poses = []
        return self.final_poses

    @torch.no_grad()
    def extract_mesh(self, res=8, clean_mesh=False):
        sdf_network = self.decoder
        sdf_network.eval()

        voxels, _, features = self.svo.get_centres_and_children()
        index = features.eq(-1).any(-1)
        voxels = voxels[~index, :]
        features = features[~index, :]
        centres = (voxels[:, :3] + voxels[:, -1:] / 2) * self.voxel_size

        encoder_states = {}
        encoder_states["voxel_vertex_idx"] = features
        encoder_states["voxel_center_xyz"] = centres
        self.profiler.tick("Creating embedding")
        self.get_embeddings(encoder_states["voxel_vertex_idx"])
        self.profiler.tok("Creating embedding")
        encoder_states["voxel_vertex_emb"] = self.dynamic_embeddings
        encoder_states["voxel_id2embedding_id"] = self.voxel_id2embedding_id

        frame_poses = self.get_updated_poses()
        mesh = self.mesher.create_mesh(
            self.decoder, encoder_states, self.voxel_size, voxels,
            frame_poses=None, depth_maps=None,
            clean_mseh=clean_mesh, require_color=False, offset=-2000, res=res)
        return mesh

    @torch.no_grad()
    def extract_voxels(self, offset=-10):
        voxels, _, features = self.svo.get_centres_and_children()
        index = features.eq(-1).any(-1)
        voxels = voxels[~index, :]
        features = features[~index, :]
        voxels = (voxels[:, :3] + voxels[:, -1:] / 2) * \
            self.voxel_size + offset
        # print(torch.max(features)-torch.count_nonzero(index))
        return voxels

    @torch.no_grad()
    def save_debug_data(self, tracked_frame, offset=-10):
        """
        save per-frame voxel, mesh and pose 
        """
        pose = tracked_frame.get_pose().detach().cpu().numpy()
        pose[:3, 3] += offset
        frame_poses = self.get_updated_poses()
        mesh = self.extract_mesh(res=8, clean_mesh=True)
        voxels = self.extract_voxels().detach().cpu().numpy()
        keyframe_poses = [p.get_pose().detach().cpu().numpy()
                          for p in self.keyframe_graph]

        for f in frame_poses:
            f[:3, 3] += offset
        for kf in keyframe_poses:
            kf[:3, 3] += offset

        verts = np.asarray(mesh.vertices)
        faces = np.asarray(mesh.triangles)
        color = np.asarray(mesh.vertex_colors)

        self.logger.log_debug_data({
            "pose": pose,
            "updated_poses": frame_poses,
            "mesh": {"verts": verts, "faces": faces, "color": color},
            "voxels": voxels,
            "voxel_size": self.voxel_size,
            "keyframes": keyframe_poses,
            "is_keyframe": (tracked_frame == self.current_keyframe)
        }, tracked_frame.stamp)


================================================
FILE: src/nerfloam.py
================================================
from multiprocessing.managers import BaseManager
from time import sleep

import torch
import torch.multiprocessing as mp

from loggers import BasicLogger
from mapping import Mapping
from share import ShareData, ShareDataProxy
from tracking import Tracking
from utils.import_util import get_dataset



class nerfloam:
    def __init__(self, args):
        self.args = args

        # logger (optional)
        self.logger = BasicLogger(args)

        # shared data
        mp.set_start_method('spawn', force=True)
        BaseManager.register('ShareData', ShareData, ShareDataProxy)
        manager = BaseManager()
        manager.start()
        self.share_data = manager.ShareData()
        # keyframe buffer
        self.kf_buffer = mp.Queue(maxsize=1)
        # data stream
        self.data_stream = get_dataset(args)
        # tracker
        self.tracker = Tracking(args, self.data_stream, self.logger)
        # mapper
        self.mapper = Mapping(args, self.logger)
        # initialize map with first frame
        self.tracker.process_first_frame(self.kf_buffer)
        self.processes = []

    def start(self):
        mapping_process = mp.Process(
            target=self.mapper.spin, args=(self.share_data, self.kf_buffer))
        mapping_process.start()
        print("initializing the first frame ...")
        sleep(20)
        # self.share_data.stop_mapping=True
        tracking_process = mp.Process(
            target=self.tracker.spin, args=(self.share_data, self.kf_buffer))
        tracking_process.start()

        self.processes = [mapping_process, tracking_process]



    def wait_child_processes(self):
        for p in self.processes:
            p.join()

    @torch.no_grad()
    def get_raw_trajectory(self):
        return self.share_data.tracking_trajectory

    @torch.no_grad()
    def get_keyframe_poses(self):
        keyframe_graph = self.mapper.keyframe_graph
        poses = []
        for keyframe in keyframe_graph:
            poses.append(keyframe.get_pose().detach().cpu().numpy())
        return poses


================================================
FILE: src/se3pose.py
================================================

import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy


class OptimizablePose(nn.Module):
    def __init__(self, init_pose):
        super().__init__()
        assert (isinstance(init_pose, torch.FloatTensor))
        self.register_parameter('data', nn.Parameter(init_pose))
        self.data.requires_grad_(True)

    def copy_from(self, pose):
        self.data = deepcopy(pose.data)

    def matrix(self):
        Rt = torch.eye(4)
        Rt[:3, :3] = self.rotation()
        Rt[:3, 3] = self.translation()
        return Rt

    def rotation(self):
        w = self.data[3:]
        wx = self.skew_symmetric(w)
        theta = w.norm(dim=-1)[..., None, None]
        I = torch.eye(3, device=w.device, dtype=torch.float32)
        A = self.taylor_A(theta)
        B = self.taylor_B(theta)
        R = I+A*wx+B*wx@wx
        return R

    def translation(self,):
        return self.data[:3]

    @classmethod
    def log(cls, R, eps=1e-7):  # [...,3,3]
        trace = R[..., 0, 0]+R[..., 1, 1]+R[..., 2, 2]
        # ln(R) will explode if theta==pi
        theta = ((trace-1)/2).clamp(-1+eps, 1-eps).acos_()[..., None, None] % np.pi
        lnR = 1/(2*cls.taylor_A(theta)+1e-8) * (R-R.transpose(-2, -1))  # FIXME: wei-chiu finds it weird
        w0, w1, w2 = lnR[..., 2, 1], lnR[..., 0, 2], lnR[..., 1, 0]
        w = torch.stack([w0, w1, w2], dim=-1)
        return w

    @classmethod
    def from_matrix(cls, Rt, eps=1e-8):  # [...,3,4]
        R, u = Rt[:3, :3], Rt[:3, 3]
        w = cls.log(R)
        return OptimizablePose(torch.cat([u, w], dim=-1))

    @classmethod
    def skew_symmetric(cls, w):
        w0, w1, w2 = w.unbind(dim=-1)
        O = torch.zeros_like(w0)
        wx = torch.stack([
            torch.stack([O, -w2, w1], dim=-1),
            torch.stack([w2, O, -w0], dim=-1),
            torch.stack([-w1, w0, O], dim=-1)], dim=-2)
        return wx

    @classmethod
    def taylor_A(cls, x, nth=10):
        # Taylor expansion of sin(x)/x
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            if i > 0:
                denom *= (2*i)*(2*i+1)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans

    @classmethod
    def taylor_B(cls, x, nth=10):
        # Taylor expansion of (1-cos(x))/x**2
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            denom *= (2*i+1)*(2*i+2)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans

    @classmethod
    def taylor_C(cls, x, nth=10):
        # Taylor expansion of (x-sin(x))/x**3
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            denom *= (2*i+2)*(2*i+3)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans


if __name__ == '__main__':
    before = torch.tensor([[-0.955421, 0.119616, -0.269932, 2.655830],
                           [0.295248, 0.388339, -0.872939, 2.981598],
                           [0.000408, -0.913720, -0.406343, 1.368648],
                           [0.000000, 0.000000, 0.000000, 1.000000]])
    pose = OptimizablePose.from_matrix(before)
    print(pose.rotation())
    print(pose.translation())
    after = pose.matrix()
    print(after)
    print(torch.abs((before-after)[:3, 3]))
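`rotation()` and `log()` above implement the SO(3) exponential and logarithm. A standalone numpy round-trip check under the same Rodrigues formulas (the helper names `skew`, `exp_so3`, `log_so3` are mine, not part of the repo):

```python
import numpy as np

def skew(w):
    # [w]x, same layout as OptimizablePose.skew_symmetric
    return np.array([[0.0, -w[2], w[1]],
                     [w[2], 0.0, -w[0]],
                     [-w[1], w[0], 0.0]])

def exp_so3(w):
    # Rodrigues: R = I + A*[w]x + B*[w]x^2, A = sin(t)/t, B = (1-cos(t))/t^2
    theta = np.linalg.norm(w)
    wx = skew(w)
    A = np.sinc(theta / np.pi)                        # sin(theta)/theta, safe at 0
    B = 0.5 if theta < 1e-8 else (1 - np.cos(theta)) / theta ** 2
    return np.eye(3) + A * wx + B * (wx @ wx)

def log_so3(R, eps=1e-7):
    # inverse map: angle from the trace, axis from the antisymmetric part
    theta = np.arccos(np.clip((np.trace(R) - 1) / 2, -1 + eps, 1 - eps))
    lnR = theta / (2 * np.sin(theta)) * (R - R.T)
    return np.array([lnR[2, 1], lnR[0, 2], lnR[1, 0]])

w = np.array([0.1, -0.2, 0.3])
R = exp_so3(w)
assert np.allclose(R @ R.T, np.eye(3), atol=1e-9)     # R is a valid rotation
assert np.allclose(log_so3(R), w, atol=1e-6)          # round trip recovers w
```

The round trip is exact for rotation angles below pi, which is also where the clamp-and-modulo guard in `log()` above is benign.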


================================================
FILE: src/share.py
================================================
from multiprocessing.managers import BaseManager, NamespaceProxy
from copy import deepcopy
import torch.multiprocessing as mp
from time import sleep
import sys


class ShareDataProxy(NamespaceProxy):
    _exposed_ = ('__getattribute__', '__setattr__')


class ShareData:
    global lock
    lock = mp.RLock()

    def __init__(self):
        self.__stop_mapping = False
        self.__stop_tracking = False

        self.__decoder = None
        self.__voxels = None
        self.__octree = None
        self.__states = None
        self.__tracking_trajectory = []

    @property
    def decoder(self):
        with lock:
            return deepcopy(self.__decoder)
            print("========== decoder get ==========")
            sys.stdout.flush()

    @decoder.setter
    def decoder(self, decoder):
        with lock:
            self.__decoder = deepcopy(decoder)
            # print("========== decoder set ==========")
            sys.stdout.flush()

    @property
    def voxels(self):
        with lock:
            return deepcopy(self.__voxels)
            print("========== voxels get ==========")
            sys.stdout.flush()

    @voxels.setter
    def voxels(self, voxels):
        with lock:
            self.__voxels = deepcopy(voxels)
            print("========== voxels set ==========")
            sys.stdout.flush()

    @property
    def octree(self):
        with lock:
            return deepcopy(self.__octree)
            print("========== octree get ==========")
            sys.stdout.flush()

    @octree.setter
    def octree(self, octree):
        with lock:
            self.__octree = deepcopy(octree)
            print("========== octree set ==========")
            sys.stdout.flush()

    @property
    def states(self):
        with lock:
            return self.__states
            print("========== states get ==========")
            sys.stdout.flush()

    @states.setter
    def states(self, states):
        with lock:
            self.__states = states
            # print("========== states set ==========")
            sys.stdout.flush()

    @property
    def stop_mapping(self):
        with lock:
            return self.__stop_mapping
            print("========== stop_mapping get ==========")
            sys.stdout.flush()

    @stop_mapping.setter
    def stop_mapping(self, stop_mapping):
        with lock:
            self.__stop_mapping = stop_mapping
            print("========== stop_mapping set ==========")
            sys.stdout.flush()

    @property
    def stop_tracking(self):
        with lock:
            return self.__stop_tracking
            print("========== stop_tracking get ==========")
            sys.stdout.flush()

    @stop_tracking.setter
    def stop_tracking(self, stop_tracking):
        with lock:
            self.__stop_tracking = stop_tracking
            print("========== stop_tracking set ==========")
            sys.stdout.flush()

    @property
    def tracking_trajectory(self):
        with lock:
            return deepcopy(self.__tracking_trajectory)
            print("========== tracking_trajectory get ==========")
            sys.stdout.flush()

    def push_pose(self, pose):
        with lock:
            self.__tracking_trajectory.append(deepcopy(pose))
            # print("========== push_pose ==========")
            sys.stdout.flush()
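`ShareData` serializes cross-process access by deep-copying every value under a single re-entrant lock, so each caller works on an isolated snapshot. The same pattern in a single-process threading sketch (illustrative only; `MiniShare` is not part of the repo):

```python
import threading
from copy import deepcopy

class MiniShare:
    # every getter/setter deep-copies under one re-entrant lock, so callers
    # always receive isolated snapshots rather than shared references
    _lock = threading.RLock()

    def __init__(self):
        self._states = None

    @property
    def states(self):
        with MiniShare._lock:
            return deepcopy(self._states)

    @states.setter
    def states(self, value):
        with MiniShare._lock:
            self._states = deepcopy(value)

share = MiniShare()
share.states = {"step": 1}
snapshot = share.states
snapshot["step"] = 99                 # mutate the copy only
assert share.states["step"] == 1      # the shared value is untouched
```

In the real class the lock guards against concurrent access from the tracking and mapping processes via the `BaseManager` proxy; the deep copies are what keep a reader's tensorless state from aliasing the writer's.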


================================================
FILE: src/tracking.py
================================================
import torch
import numpy as np
from tqdm import tqdm

from criterion import Criterion
from lidarFrame import LidarFrame
from utils.import_util import get_property
from utils.profile_util import Profiler
from variations.render_helpers import fill_in, render_rays, track_frame
from se3pose import OptimizablePose
from time import sleep
from copy import deepcopy


class Tracking:
    def __init__(self, args, data_stream, logger):
        self.args = args
        self.last_frame_id = 0
        self.last_frame = None

        self.data_stream = data_stream
        self.logger = logger
        self.loss_criteria = Criterion(args)

        self.voxel_size = args.mapper_specs["voxel_size"]
        self.N_rays = args.tracker_specs["N_rays"]
        self.num_iterations = args.tracker_specs["num_iterations"]
        self.sdf_truncation = args.criteria["sdf_truncation"]
        self.learning_rate = args.tracker_specs["learning_rate"]
        self.start_frame = args.tracker_specs["start_frame"]
        self.end_frame = args.tracker_specs["end_frame"]
        self.step_size = args.tracker_specs["step_size"]
        # self.keyframe_freq = args.tracker_specs["keyframe_freq"]
        self.max_voxel_hit = args.tracker_specs["max_voxel_hit"]
        self.max_distance = args.data_specs["max_depth"]
        self.step_size = self.step_size * self.voxel_size
        self.read_offset = args.tracker_specs["read_offset"]
        self.mesh_freq = args.debug_args["mesh_freq"]
        if self.end_frame <= 0:
            self.end_frame = len(self.data_stream)-1

        # sanity check on the lower/upper bounds
        self.start_frame = min(self.start_frame, len(self.data_stream))
        self.end_frame = min(self.end_frame, len(self.data_stream))
        self.rel_pose = None
        # profiler
        verbose = get_property(args.debug_args, "verbose", False)
        self.profiler = Profiler(verbose=verbose)
        self.profiler.enable()

    def process_first_frame(self, kf_buffer):
        init_pose = self.data_stream.get_init_pose(self.start_frame)
        index, points, pointcos, _ = self.data_stream[self.start_frame]
        first_frame = LidarFrame(index, points, pointcos, init_pose)
        first_frame.pose.requires_grad_(False)
        first_frame.points.requires_grad_(False)

        print("******* initializing first_frame:", first_frame.index)
        kf_buffer.put(first_frame, block=True)
        self.last_frame = first_frame
        self.start_frame += 1


    def spin(self, share_data, kf_buffer):
        print("******* tracking process started! *******")
        progress_bar = tqdm(
            range(self.start_frame, self.end_frame+1), position=0)
        progress_bar.set_description("tracking frame")
        for frame_id in progress_bar:
            if frame_id % self.read_offset != 0:
                continue
            if share_data.stop_tracking:
                break

            data_in = self.data_stream[frame_id]

            current_frame = LidarFrame(*data_in)
            if isinstance(data_in[3], np.ndarray):
                self.last_frame = current_frame
                self.check_keyframe(current_frame, kf_buffer)

            else:
                self.do_tracking(share_data, current_frame, kf_buffer)

        share_data.stop_mapping = True
        print("******* tracking process died *******")

        # wait for the mapper to drain any remaining keyframes before exiting
        sleep(60)
        while not kf_buffer.empty():
            sleep(60)

    def check_keyframe(self, check_frame, kf_buffer):
        try:
            kf_buffer.put(check_frame, block=True)
        except Exception:
            # drop the frame if the keyframe buffer cannot accept it
            pass

    def do_tracking(self, share_data, current_frame, kf_buffer):

        self.profiler.tick("before track1111")
        decoder = share_data.decoder.cuda()
        self.profiler.tok("before track1111")

        self.profiler.tick("before track2222")
        map_states = share_data.states
        map_states["voxel_vertex_emb"] = map_states["voxel_vertex_emb"].cuda()
        self.profiler.tok("before track2222")

        constant_move_pose = self.last_frame.get_pose().detach()
        input_pose = deepcopy(self.last_frame.pose)
        input_pose.requires_grad_(False)
        if self.rel_pose is not None:
            constant_move_pose[:3, 3] = (constant_move_pose @ self.rel_pose)[:3, 3]
            input_pose.data[:3] = constant_move_pose[:3, 3].T
        torch.cuda.empty_cache()
        self.profiler.tick("track frame")

        frame_pose, hit_mask = track_frame(
            input_pose,
            current_frame,
            map_states,
            decoder,
            self.loss_criteria,
            self.voxel_size,
            self.N_rays,
            self.step_size,
            self.num_iterations if self.rel_pose is not None else self.num_iterations*5,
            self.sdf_truncation,
            self.learning_rate,
            self.max_voxel_hit,
            self.max_distance,
            profiler=self.profiler,
            depth_variance=True
        )
        self.profiler.tok("track frame")
        if hit_mask is None:
            current_frame.pose = OptimizablePose.from_matrix(constant_move_pose)
        else:
            current_frame.pose = frame_pose
            current_frame.hit_ratio = hit_mask.sum() / self.N_rays

        self.rel_pose = torch.linalg.inv(self.last_frame.get_pose().detach()) @ current_frame.get_pose().detach()
        current_frame.set_rel_pose(self.rel_pose)
        self.last_frame = current_frame

        self.profiler.tick("transport frame")
        self.check_keyframe(current_frame, kf_buffer)
        self.profiler.tok("transport frame")


================================================
FILE: src/utils/__init__.py
================================================


================================================
FILE: src/utils/import_util.py
================================================
from importlib import import_module
import argparse

def get_dataset(args):
    Dataset = import_module("dataset."+args.dataset)
    return Dataset.DataLoader(**args.data_specs)

def get_decoder(args):
    Decoder = import_module("variations."+args.decoder)
    return Decoder.Decoder(**args.decoder_specs)

def get_property(args, name, default):
    if isinstance(args, dict):
        return args.get(name, default)
    elif isinstance(args, argparse.Namespace):
        if hasattr(args, name):
            return vars(args)[name]
        else:
            return default
    else:
        raise ValueError(f"unkown dict/namespace type: {type(args)}")

================================================
FILE: src/utils/mesh_util.py
================================================
import math
import torch

import numpy as np
import open3d as o3d
from scipy.spatial import cKDTree
from skimage.measure import marching_cubes
from variations.render_helpers import get_scores, eval_points


class MeshExtractor:
    def __init__(self, args):
        self.voxel_size = args.mapper_specs["voxel_size"]
        self.rays_d = None
        self.depth_points = None

    @ torch.no_grad()
    def linearize_id(self, xyz, n_xyz):
        return xyz[:, 2] + n_xyz[-1] * xyz[:, 1] + (n_xyz[-1] * n_xyz[-2]) * xyz[:, 0]

    @torch.no_grad()
    def downsample_points(self, points, voxel_size=0.01):
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        pcd = pcd.voxel_down_sample(voxel_size)
        return np.asarray(pcd.points)

    @torch.no_grad()
    def get_rays(self, w=None, h=None, K=None):
        w = self.w if w is None else w
        h = self.h if h is None else h
        if K is None:
            K = np.eye(3)
            K[0, 0] = self.K[0, 0] * w / self.w
            K[1, 1] = self.K[1, 1] * h / self.h
            K[0, 2] = self.K[0, 2] * w / self.w
            K[1, 2] = self.K[1, 2] * h / self.h
        ix, iy = torch.meshgrid(
            torch.arange(w), torch.arange(h), indexing='xy')
        rays_d = torch.stack(
            [(ix-K[0, 2]) / K[0, 0],
             (iy-K[1, 2]) / K[1, 1],
             torch.ones_like(ix)], -1).float()
        return rays_d

    @torch.no_grad()
    def get_valid_points(self, frame_poses, depth_maps):
        if isinstance(frame_poses, list):
            all_points = []
            print("extracting all points")
            for i in range(0, len(frame_poses), 5):
                pose = frame_poses[i]
                depth = depth_maps[i]
                points = self.rays_d * depth.unsqueeze(-1)
                points = points.reshape(-1, 3)
                points = points @ pose[:3, :3].transpose(-1, -2) + pose[:3, 3]
                if len(all_points) == 0:
                    all_points = points.detach().cpu().numpy()
                else:
                    all_points = np.concatenate(
                        [all_points, points.detach().cpu().numpy()], 0)
            print("downsample all points")
            all_points = self.downsample_points(all_points)
            return all_points
        else:
            pose = frame_poses
            depth = depth_maps
            points = self.rays_d * depth.unsqueeze(-1)
            points = points.reshape(-1, 3)
            points = points @ pose[:3, :3].transpose(-1, -2) + pose[:3, 3]
            if self.depth_points is None:
                self.depth_points = points.detach().cpu().numpy()
            else:
                self.depth_points = np.concatenate(
                    [self.depth_points, points.detach().cpu().numpy()], 0)
            self.depth_points = self.downsample_points(self.depth_points)
        return self.depth_points

    @ torch.no_grad()
    def create_mesh(self, decoder, map_states, voxel_size, voxels,
                    frame_poses=None, depth_maps=None, clean_mseh=False,
                    require_color=False, offset=-80, res=8):

        sdf_grid = get_scores(decoder, map_states, voxel_size, bits=res)
        sdf_grid = sdf_grid.reshape(-1, res, res, res, 1)

        voxel_centres = map_states["voxel_center_xyz"]
        verts, faces = self.marching_cubes(voxel_centres, sdf_grid)

        if clean_mseh:
            print("********** get points from frames **********")
            all_points = self.get_valid_points(frame_poses, depth_maps)
            print("********** construct kdtree **********")
            kdtree = cKDTree(all_points)
            print("********** query kdtree **********")
            point_mask = kdtree.query_ball_point(
                verts, voxel_size * 0.5, workers=12, return_length=True)
            print("********** finished querying kdtree **********")
            point_mask = point_mask > 0
            face_mask = point_mask[faces.reshape(-1)].reshape(-1, 3).any(-1)

            faces = faces[face_mask]

        if require_color:
            print("********** get color from network **********")
            verts_torch = torch.from_numpy(verts).float().cuda()
            batch_points = torch.split(verts_torch, 1000)
            colors = []
            for points in batch_points:
                voxel_pos = points // self.voxel_size
                batch_voxels = voxels[:, :3].cuda()
                batch_voxels = batch_voxels.unsqueeze(
                    0).repeat(voxel_pos.shape[0], 1, 1)

                # filter outliers
                nonzeros = (batch_voxels == voxel_pos.unsqueeze(1)).all(-1)
                nonzeros = torch.where(nonzeros, torch.ones_like(
                    nonzeros).int(), -torch.ones_like(nonzeros).int())
                sorted, index = torch.sort(nonzeros, dim=-1, descending=True)
                sorted = sorted[:, 0]
                index = index[:, 0]
                valid = (sorted != -1)
                color_empty = torch.zeros_like(points)
                points = points[valid, :]
                index = index[valid]

                # get color
                if len(points) > 0:
                    color = eval_points(decoder, map_states,
                                        points, index, voxel_size).cuda()
                    color_empty[valid] = color
                colors += [color_empty]
            colors = torch.cat(colors, 0)

        mesh = o3d.geometry.TriangleMesh()
        mesh.vertices = o3d.utility.Vector3dVector(verts+offset)
        mesh.triangles = o3d.utility.Vector3iVector(faces)
        if require_color:
            mesh.vertex_colors = o3d.utility.Vector3dVector(
                colors.detach().cpu().numpy())
        mesh.compute_vertex_normals()
        return mesh

    @ torch.no_grad()
    def marching_cubes(self, voxels, sdf):
        voxels = voxels[:, :3]
        sdf = sdf[..., 0]
        res = 1.0 / (sdf.shape[1] - 1)
        spacing = [res, res, res]

        num_verts = 0
        total_verts = []
        total_faces = []
        for i in range(len(voxels)):
            sdf_volume = sdf[i].detach().cpu().numpy()
            if np.min(sdf_volume) > 0 or np.max(sdf_volume) < 0:
                continue
            verts, faces, _, _ = marching_cubes(sdf_volume, 0, spacing=spacing)
            verts -= 0.5
            verts *= self.voxel_size
            verts += voxels[i].detach().cpu().numpy()
            faces += num_verts
            num_verts += verts.shape[0]

            total_verts += [verts]
            total_faces += [faces]
        total_verts = np.concatenate(total_verts)
        total_faces = np.concatenate(total_faces)
        return total_verts, total_faces
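`linearize_id` above flattens a 3D grid index with z varying fastest (x-major order). A hypothetical inverse, not part of the repo, makes the convention explicit and checks the round trip:

```python
import numpy as np

def linearize(xyz, n_xyz):
    # same flattening as MeshExtractor.linearize_id: x-major, z fastest
    nx, ny, nz = n_xyz
    return xyz[:, 2] + nz * xyz[:, 1] + nz * ny * xyz[:, 0]

def delinearize(idx, n_xyz):
    # inverse: peel off x, then y, then z
    nx, ny, nz = n_xyz
    x, rem = np.divmod(idx, nz * ny)
    y, z = np.divmod(rem, nz)
    return np.stack([x, y, z], axis=-1)

xyz = np.array([[1, 2, 3], [0, 0, 0], [4, 5, 6]])
n = (8, 8, 8)
assert np.array_equal(delinearize(linearize(xyz, n), n), xyz)
```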


================================================
FILE: src/utils/profile_util.py
================================================
from time import time
import torch


class Profiler(object):
    def __init__(self, verbose=False) -> None:
        self.timer = dict()
        self.time_log = dict()
        self.enabled = False
        self.verbose = verbose

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def tick(self, name):
        if not self.enabled:
            return
        self.timer[name] = time()
        if name not in self.time_log:
            self.time_log[name] = list()

    def tok(self, name):
        if not self.enabled:
            return
        if name not in self.timer:
            return
        torch.cuda.synchronize()
        elapsed = time() - self.timer[name]
        if self.verbose:
            print(f"{name}: {elapsed*1000:.2f} ms")
        else:
            self.time_log[name].append(elapsed * 1000)


================================================
FILE: src/utils/sample_util.py
================================================
import torch


def sampling_without_replacement(logp, k):
    def gumbel_like(u):
        return -torch.log(-torch.log(torch.rand_like(u) + 1e-7) + 1e-7)

    scores = logp + gumbel_like(logp)
    return scores.topk(k, dim=-1)[1]


def sample_rays(mask, num_samples):
    B, H, W = mask.shape
    probs = mask / (mask.sum() + 1e-9)
    flatten_probs = probs.reshape(B, -1)
    sampled_index = sampling_without_replacement(
        torch.log(flatten_probs + 1e-9), num_samples)
    sampled_masks = (torch.zeros_like(
        flatten_probs).scatter_(-1, sampled_index, 1).reshape(B, H, W) > 0)
    return sampled_masks
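`sample_rays` draws `num_samples` pixels without replacement via the Gumbel-top-k trick: perturb log-probabilities with Gumbel noise, then take the top-k indices. A numpy sketch of the same trick (function and variable names here are illustrative, not the module's):

```python
import numpy as np

rng = np.random.default_rng(0)

def gumbel_topk(logp, k):
    # adding Gumbel(0,1) noise to log-probs and taking top-k indices yields
    # a sample without replacement from the categorical distribution
    g = -np.log(-np.log(rng.random(logp.shape) + 1e-7) + 1e-7)
    return np.argsort(-(logp + g), axis=-1)[..., :k]

probs = np.array([0.1, 0.2, 0.3, 0.4])
idx = gumbel_topk(np.log(probs), k=3)
assert len(set(idx.tolist())) == 3           # three distinct pixel indices
assert all(0 <= i < 4 for i in idx)
```

The mask-scatter step in `sample_rays` then converts those flat indices back into a boolean pixel mask.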

================================================
FILE: src/variations/decode_morton.py
================================================
import numpy as np


def compact(value):
    x = value & 0x1249249249249249
    x = (x | x >> 2) & 0x10c30c30c30c30c3
    x = (x | x >> 4) & 0x100f00f00f00f00f
    x = (x | x >> 8) & 0x1f0000ff0000ff
    x = (x | x >> 16) & 0x1f00000000ffff
    x = (x | x >> 32) & 0x1fffff
    return x


def decode(code):
    return compact(code >> 0), compact(code >> 1), compact(code >> 2)


# Debugging snippet (requires torch and a `samples_valid` dict from the
# sampling stage); commented out so the module imports cleanly:
# for i in range(10):
#     x, y, z = decode(samples_valid['sampled_point_voxel_idx'][i])
#     print(x, y, z)
#     print(torch.sqrt((x-80)**2+(y-80)**2+(z-80)**2))
#     print(samples_valid['sampled_point_depth'][i])
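`compact()` is the inverse of the standard "part1by2" bit-spreading used for 3D Morton codes. A round-trip sketch (the encode side below is written here for illustration and is not part of the repo):

```python
def part1by2(v):
    # spread the low 21 bits of v so they land on every 3rd bit
    v &= 0x1fffff
    v = (v | v << 32) & 0x1f00000000ffff
    v = (v | v << 16) & 0x1f0000ff0000ff
    v = (v | v << 8) & 0x100f00f00f00f00f
    v = (v | v << 4) & 0x10c30c30c30c30c3
    v = (v | v << 2) & 0x1249249249249249
    return v

def encode(x, y, z):
    # interleave: x in bit 0, y in bit 1, z in bit 2 of each bit triple
    return part1by2(x) | part1by2(y) << 1 | part1by2(z) << 2

def compact(value):
    # same extraction as decode_morton.compact above
    x = value & 0x1249249249249249
    x = (x | x >> 2) & 0x10c30c30c30c30c3
    x = (x | x >> 4) & 0x100f00f00f00f00f
    x = (x | x >> 8) & 0x1f0000ff0000ff
    x = (x | x >> 16) & 0x1f00000000ffff
    x = (x | x >> 32) & 0x1fffff
    return x

code = encode(80, 81, 82)
assert (compact(code), compact(code >> 1), compact(code >> 2)) == (80, 81, 82)
```

`decode(code)` above performs exactly the three shifted `compact` calls in the final assertion.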


================================================
FILE: src/variations/lidar.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class GaussianFourierFeatureTransform(torch.nn.Module):
    """
    Modified based on the implementation of Gaussian Fourier feature mapping.

    "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains":
       https://arxiv.org/abs/2006.10739
       https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html

    """

    def __init__(self, num_input_channels, mapping_size=93, scale=25, learnable=True):
        super().__init__()

        if learnable:
            self._B = nn.Parameter(torch.randn(
                (num_input_channels, mapping_size)) * scale)
        else:
            self._B = torch.randn((num_input_channels, mapping_size)) * scale
        self.embedding_size = mapping_size

    def forward(self, x):
        # x = x.squeeze(0)
        assert x.dim() == 2, 'Expected 2D input (got {}D input)'.format(x.dim())
        x = x @ self._B.to(x.device)
        return torch.sin(x)


class Nerf_positional_embedding(torch.nn.Module):
    """
    Nerf positional embedding.

    """

    def __init__(self, in_dim, multires, log_sampling=True):
        super().__init__()
        self.log_sampling = log_sampling
        self.include_input = True
        self.periodic_fns = [torch.sin, torch.cos]
        self.max_freq_log2 = multires-1
        self.num_freqs = multires
        self.max_freq = self.max_freq_log2
        self.N_freqs = self.num_freqs
        self.embedding_size = multires*in_dim*2 + in_dim

    def forward(self, x):
        # x = x.squeeze(0)
        assert x.dim() == 2, 'Expected 2D input (got {}D input)'.format(
            x.dim())

        if self.log_sampling:
            freq_bands = 2.**torch.linspace(0.,
                                            self.max_freq, steps=self.N_freqs)
        else:
            freq_bands = torch.linspace(
                2.**0., 2.**self.max_freq, steps=self.N_freqs)
        output = []
        if self.include_input:
            output.append(x)
        for freq in freq_bands:
            for p_fn in self.periodic_fns:
                output.append(p_fn(x * freq))
        ret = torch.cat(output, dim=1)
        return ret
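`Nerf_positional_embedding` maps each input dimension through sin and cos at `multires` log-spaced frequencies and appends the raw input, so the output width is `multires*in_dim*2 + in_dim`. A small standalone numpy sketch (not the module's code) confirming that bookkeeping:

```python
import numpy as np

def nerf_embed(x, multires):
    # log-spaced frequencies 2^0 .. 2^(multires-1); sin and cos per frequency,
    # plus the raw input, matching embedding_size = multires*in_dim*2 + in_dim
    freqs = 2.0 ** np.linspace(0.0, multires - 1, multires)
    out = [x]
    for f in freqs:
        out += [np.sin(x * f), np.cos(x * f)]
    return np.concatenate(out, axis=-1)

x = np.zeros((5, 3))
emb = nerf_embed(x, multires=6)
assert emb.shape == (5, 6 * 3 * 2 + 3)   # 39 channels, as in embedding_size
```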


class Same(nn.Module):
    def __init__(self, in_dim) -> None:
        super().__init__()
        self.embedding_size = in_dim

    def forward(self, x):
        return x


class Decoder(nn.Module):
    def __init__(self,
                 depth=8,
                 width=258,
                 in_dim=3,
                 sdf_dim=128,
                 skips=[4],
                 multires=6,
                 embedder='none',
                 point_dim=3,
                 local_coord=False,
                 **kwargs) -> None:
        super().__init__()
        self.D = depth
        self.W = width
        self.skips = skips
        self.point_dim = point_dim
        if embedder == 'nerf':
            self.pe = Nerf_positional_embedding(in_dim, multires)
        elif embedder == 'none':
            self.pe = Same(in_dim)
        elif embedder == 'gaussian':
            self.pe = GaussianFourierFeatureTransform(in_dim)
        else:
            raise NotImplementedError("unknown positional encoder")
        self.pts_linears = nn.ModuleList(
            [nn.Linear(self.pe.embedding_size, width)] + [nn.Linear(width, width) if i not in self.skips else nn.Linear(width + self.pe.embedding_size, width) for i in range(depth-1)])
        self.sdf_out = nn.Linear(width, 1)

    def get_values(self, input):
        x = self.pe(input)
        # point = input[:, -3:]
        h = x
        for i, l in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = F.relu(h)
            if i in self.skips:
                h = torch.cat([x, h], -1)

        # outputs = self.output_linear(h)
        # outputs[:, :3] = torch.sigmoid(outputs[:, :3])
        sdf_out = self.sdf_out(h)

        return sdf_out

    def forward(self, inputs):
        outputs = self.get_values(inputs)

        return {
            'sdf': outputs,
            # 'depth': outputs[:, 1]
        }


================================================
FILE: src/variations/render_helpers.py
================================================
from copy import deepcopy
import torch
import torch.nn.functional as F

from .voxel_helpers import ray_intersect, ray_sample
from torch.autograd import grad


def ray(ray_start, ray_dir, depths):
    return ray_start + ray_dir * depths


def fill_in(shape, mask, input, initial=1.0):
    if isinstance(initial, torch.Tensor):
        output = initial.expand(*shape)
    else:
        output = input.new_ones(*shape) * initial
    return output.masked_scatter(mask.unsqueeze(-1).expand(*shape), input)


def masked_scatter(mask, x):
    B, K = mask.size()
    if x.dim() == 1:
        return x.new_zeros(B, K).masked_scatter(mask, x)
    return x.new_zeros(B, K, x.size(-1)).masked_scatter(
        mask.unsqueeze(-1).expand(B, K, x.size(-1)), x
    )


def masked_scatter_ones(mask, x):
    B, K = mask.size()
    if x.dim() == 1:
        return x.new_ones(B, K).masked_scatter(mask, x)
    return x.new_ones(B, K, x.size(-1)).masked_scatter(
        mask.unsqueeze(-1).expand(B, K, x.size(-1)), x
    )


@torch.enable_grad()
def trilinear_interp(p, q, point_feats):
    weights = (p * q + (1 - p) * (1 - q)).prod(dim=-1, keepdim=True)
    if point_feats.dim() == 2:
        point_feats = point_feats.view(point_feats.size(0), 8, -1)

    point_feats = (weights * point_feats).sum(1)
    return point_feats
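The weight expression `(p * q + (1 - p) * (1 - q)).prod(dim=-1)` above is the standard trilinear partition of unity: with corner codes q in {0, 1}^3, each axis contributes p where q = 1 and 1 - p where q = 0, so the eight weights always sum to one. A minimal pure-Python sketch (the helper name `trilinear_weights` is mine, for illustration only):

```python
from itertools import product

def trilinear_weights(p):
    # per-corner trilinear weights for a point p in the unit cube;
    # mirrors (p * q + (1 - p) * (1 - q)).prod(-1) with q in {0, 1}^3
    weights = []
    for q in product((0.0, 1.0), repeat=3):
        w = 1.0
        for pi, qi in zip(p, q):
            w *= pi * qi + (1.0 - pi) * (1.0 - qi)
        weights.append(w)
    return weights
```

At any interior point the weights sum to 1, and at a cube corner all the weight collapses onto that corner.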


def offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2):
    c = torch.arange(1, 2 * bits, 2, device=point_xyz.device)
    ox, oy, oz = torch.meshgrid([c, c, c], indexing='ij')
    offset = (torch.cat([
        ox.reshape(-1, 1),
        oy.reshape(-1, 1),
        oz.reshape(-1, 1)], 1).type_as(point_xyz) - bits) / float(bits - 1)
    if not offset_only:
        return (
            point_xyz.unsqueeze(1) + offset.unsqueeze(0).type_as(point_xyz) * quarter_voxel)
    return offset.type_as(point_xyz) * quarter_voxel
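For the default `bits=2`, the arithmetic above reduces to `c = [1, 3]` and per-axis offsets `(c - 2) / 1`, i.e. the eight signed corners {-1, +1}^3 scaled by `quarter_voxel`. A small pure-Python sketch of that enumeration (the helper name is hypothetical):

```python
from itertools import product

def corner_offsets(bits=2, scale=1.0):
    # mirrors offset_points(..., offset_only=True): enumerate the bits**3
    # lattice offsets; for bits=2 these are the corners {-1, +1}^3
    c = range(1, 2 * bits, 2)
    return [tuple((ci - bits) / (bits - 1) * scale for ci in corner)
            for corner in product(c, repeat=3)]
```

With `bits=2, scale=0.5` this returns the eight ±0.5 offsets used for cube-vertex interpolation.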


@torch.enable_grad()
def get_embeddings(sampled_xyz, point_xyz, point_feats, voxel_size):
    # tri-linear interpolation
    p = ((sampled_xyz - point_xyz) / voxel_size + 0.5).unsqueeze(1)
    q = offset_points(p, 0.5, offset_only=True).unsqueeze(0) + 0.5
    feats = trilinear_interp(p, q, point_feats).float()
    # if self.args.local_coord:
    # feats = torch.cat([(p-.5).squeeze(1).float(), feats], dim=-1)
    return feats


@torch.enable_grad()
def get_features(samples, map_states, voxel_size):
    # encoder states
    point_idx = map_states["voxel_vertex_idx"].cuda()
    point_xyz = map_states["voxel_center_xyz"].cuda()
    values = map_states["voxel_vertex_emb"]
    point_id2embedid = map_states["voxel_id2embedding_id"]
    # ray point samples
    sampled_idx = samples["sampled_point_voxel_idx"].long()
    sampled_xyz = samples["sampled_point_xyz"]
    sampled_dis = samples["sampled_point_distance"]

    point_xyz = F.embedding(sampled_idx, point_xyz).requires_grad_()
    selected_points_idx = F.embedding(sampled_idx, point_idx)
    flatten_selected_points_idx = selected_points_idx.view(-1)
    embed_idx = F.embedding(flatten_selected_points_idx.cpu(), point_id2embedid).squeeze(-1)
    point_feats = F.embedding(embed_idx.cuda(), values).view(point_xyz.size(0), -1)

    feats = get_embeddings(sampled_xyz, point_xyz, point_feats, voxel_size)
    inputs = {"xyz": point_xyz, "dists": sampled_dis, "emb": feats.cuda()}
    return inputs


@torch.no_grad()
def get_scores(sdf_network, map_states, voxel_size, bits=8):
    feats = map_states["voxel_vertex_idx"]
    points = map_states["voxel_center_xyz"]
    values = map_states["voxel_vertex_emb"]
    point_id2embedid = map_states["voxel_id2embedding_id"]

    chunk_size = 10000
    res = bits  # -1

    @torch.no_grad()
    def get_scores_once(feats, points, values, point_id2embedid):
        torch.cuda.empty_cache()
        # sample points inside voxels
        start = -0.5
        end = 0.5  # - 1./bits

        x = y = z = torch.linspace(start, end, res)
        # z = torch.linspace(1, 1, res)
        xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')
        sampled_xyz = torch.stack([xx, yy, zz], dim=-1).float().cuda()

        sampled_xyz *= voxel_size
        sampled_xyz = sampled_xyz.reshape(1, -1, 3) + points.unsqueeze(1)

        sampled_idx = torch.arange(points.size(0), device=points.device)
        sampled_idx = sampled_idx[:, None].expand(*sampled_xyz.size()[:2])
        sampled_idx = sampled_idx.reshape(-1)
        sampled_xyz = sampled_xyz.reshape(-1, 3)

        if sampled_xyz.shape[0] == 0:
            return

        field_inputs = get_features(
            {
                "sampled_point_xyz": sampled_xyz,
                "sampled_point_voxel_idx": sampled_idx,
                "sampled_point_ray_direction": None,
                "sampled_point_distance": None,
            },
            {
                "voxel_vertex_idx": feats,
                "voxel_center_xyz": points,
                "voxel_vertex_emb": values,
                "voxel_id2embedding_id": point_id2embedid
            },
            voxel_size
        )
        field_inputs = field_inputs["emb"]

        # evaluation with density
        sdf_values = sdf_network.get_values(field_inputs.float().cuda())
        return sdf_values.reshape(-1, res ** 3, 1).detach().cpu()

    return torch.cat([
        get_scores_once(feats[i: i + chunk_size],
                        points[i: i + chunk_size].cuda(), values, point_id2embedid)
        for i in range(0, points.size(0), chunk_size)], 0).view(-1, res, res, res, 1)
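`get_scores` evaluates a res³ grid of SDF samples per voxel, so it slices the voxel list into fixed-size chunks and concatenates the per-chunk results. The slicing pattern on its own, stripped of the CUDA details (helper name mine):

```python
def chunked(seq, size):
    # same slicing scheme as the feats[i: i + chunk_size] loop above:
    # fixed-size slices, with a shorter final slice if needed
    return [seq[i:i + size] for i in range(0, len(seq), size)]
```

For example, chunking 25 items by 10 yields slices of length 10, 10 and 5 that concatenate back to the original sequence.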


@torch.no_grad()
def eval_points(sdf_network, map_states, sampled_xyz, sampled_idx, voxel_size):
    feats = map_states["voxel_vertex_idx"]
    points = map_states["voxel_center_xyz"]
    values = map_states["voxel_vertex_emb"]
    # get_features also needs the vertex-id-to-embedding lookup table
    point_id2embedid = map_states["voxel_id2embedding_id"]

    # sampled_xyz = sampled_xyz.reshape(1, 3) + points.unsqueeze(1)
    # sampled_idx = sampled_idx[None, :].expand(*sampled_xyz.size()[:2])
    sampled_idx = sampled_idx.reshape(-1)
    sampled_xyz = sampled_xyz.reshape(-1, 3)

    if sampled_xyz.shape[0] == 0:
        return

    field_inputs = get_features(
        {
            "sampled_point_xyz": sampled_xyz,
            "sampled_point_voxel_idx": sampled_idx,
            "sampled_point_ray_direction": None,
            "sampled_point_distance": None,
        },
        {
            "voxel_vertex_idx": feats,
            "voxel_center_xyz": points,
            "voxel_vertex_emb": values,
            "voxel_id2embedding_id": point_id2embedid,
        },
        voxel_size
    )

    # evaluation with density
    sdf_values = sdf_network.get_values(field_inputs['emb'].float().cuda())
    return sdf_values.reshape(-1, 4)[:, :3].detach().cpu()


def render_rays(
        rays_o,
        rays_d,
        map_states,
        sdf_network,
        step_size,
        voxel_size,
        truncation,
        max_voxel_hit,
        max_distance,
        chunk_size=10000,
        profiler=None,
        return_raw=False
):
    torch.cuda.empty_cache()
    centres = map_states["voxel_center_xyz"].cuda()
    childrens = map_states["voxel_structure"].cuda()

    if profiler is not None:
        profiler.tick("ray_intersect")
    # print("Center", rays_o[0][0])
    intersections, hits = ray_intersect(
        rays_o, rays_d, centres,
        childrens, voxel_size, max_voxel_hit, max_distance)
    if profiler is not None:
        profiler.tok("ray_intersect")
    if hits.sum() <= 0:
        return

    ray_mask = hits.view(1, -1)

    intersections = {
        name: outs[ray_mask].reshape(-1, outs.size(-1))
        for name, outs in intersections.items()
    }

    rays_o = rays_o[ray_mask].reshape(-1, 3)
    rays_d = rays_d[ray_mask].reshape(-1, 3)

    if profiler is not None:
        profiler.tick("ray_sample")
    samples = ray_sample(intersections, step_size=step_size)
    if samples is None:
        return
    if profiler is not None:
        profiler.tok("ray_sample")

    sampled_depth = samples['sampled_point_depth']
    sampled_idx = samples['sampled_point_voxel_idx'].long()

    # only compute when the ray hits

    sample_mask = sampled_idx.ne(-1)
    if sample_mask.sum() == 0:  # missed everything; skip
        # return a bare None so the callers' None checks catch this case
        return None

    sampled_xyz = ray(rays_o.unsqueeze(
        1), rays_d.unsqueeze(1), sampled_depth.unsqueeze(2))
    sampled_dir = rays_d.unsqueeze(1).expand(
        *sampled_depth.size(), rays_d.size()[-1])
    sampled_dir = sampled_dir / \
        (torch.norm(sampled_dir, 2, -1, keepdim=True) + 1e-8)
    samples['sampled_point_xyz'] = sampled_xyz
    samples['sampled_point_ray_direction'] = sampled_dir
    # apply mask
    samples_valid = {name: s[sample_mask] for name, s in samples.items()}
    # print("samples_valid_xyz", samples["sampled_point_xyz"].shape)
    num_points = samples_valid['sampled_point_depth'].shape[0]
    field_outputs = []
    if chunk_size < 0:
        chunk_size = num_points
    final_xyz = []
    xyz = 0
    for i in range(0, num_points, chunk_size):
        torch.cuda.empty_cache()
        chunk_samples = {name: s[i:i+chunk_size]
                         for name, s in samples_valid.items()}

        # get encoder features as inputs
        if profiler is not None:
            profiler.tick("get_features")

        chunk_inputs = get_features(chunk_samples, map_states, voxel_size)
        xyz = chunk_inputs["xyz"]
        if profiler is not None:
            profiler.tok("get_features")
        # add coordinate information
        chunk_inputs = chunk_inputs["emb"]

        # forward implicit fields
        if profiler is not None:
            profiler.tick("render_core")

        chunk_outputs = sdf_network(chunk_inputs)
        if profiler is not None:
            profiler.tok("render_core")
        final_xyz.append(xyz)
        field_outputs.append(chunk_outputs)

    field_outputs = {name: torch.cat(
        [r[name] for r in field_outputs], dim=0) for name in field_outputs[0]}
    final_xyz = torch.cat(final_xyz, 0)
    outputs = field_outputs['sdf']
    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
    sdf_grad = grad(outputs=outputs,
                    inputs=xyz,
                    grad_outputs=d_points,
                    retain_graph=True,)

    sdf = masked_scatter_ones(sample_mask, field_outputs['sdf']).squeeze(-1)
    # depth = masked_scatter(sample_mask, field_outputs['depth'])

    valid_mask = torch.where(
        sample_mask, torch.ones_like(
            sample_mask), torch.zeros_like(sample_mask)
    )

    return {
        "z_vals": samples["sampled_point_depth"],
        "sdf": sdf,
        "ray_mask": ray_mask,
        "valid_mask": valid_mask,
        "sampled_xyz": xyz,
    }


def bundle_adjust_frames(
    keyframe_graph,
    embeddings,
    map_states,
    sdf_network,
    loss_criteria,
    voxel_size,
    step_size,
    N_rays=512,
    num_iterations=10,
    truncation=0.1,
    max_voxel_hit=10,
    max_distance=10,
    learning_rate=[1e-2, 1e-2, 5e-3],
    update_pose=True,
    update_decoder=True,
    profiler=None
):
    if profiler is not None:
        profiler.tick("mapping_add_optim")
    optimize_params = [{'params': embeddings, 'lr': learning_rate[0]}]
    if update_decoder:
        optimize_params += [{'params': sdf_network.parameters(),
                             'lr': learning_rate[1]}]

    for keyframe in keyframe_graph:
        if keyframe.index != 0 and update_pose:
            keyframe.pose.requires_grad_(True)
            optimize_params += [{
                'params': keyframe.pose.parameters(), 'lr': learning_rate[2]
            }]

    optim = torch.optim.Adam(optimize_params)
    if profiler is not None:
        profiler.tok("mapping_add_optim")
    for iter in range(num_iterations):
        torch.cuda.empty_cache()
        rays_o = []
        rays_d = []
        rgb_samples = []
        depth_samples = []
        points_samples = []
        pointsCos_samples = []
        if iter == 0 and profiler is not None:
            profiler.tick("mapping sample_rays")
        for frame in keyframe_graph:
            torch.cuda.empty_cache()
            pose = frame.get_pose().cuda()
            frame.sample_rays(N_rays)

            sample_mask = frame.sample_mask.cuda()
            sampled_rays_d = frame.rays_d[sample_mask].cuda()
            # print(sampled_rays_d)
            R = pose[: 3, : 3].transpose(-1, -2)
            sampled_rays_d = sampled_rays_d@R
            sampled_rays_o = pose[: 3, 3].reshape(1, -1).expand_as(sampled_rays_d)

            rays_d += [sampled_rays_d]
            rays_o += [sampled_rays_o]
            points_samples += [frame.points.unsqueeze(1).cuda()[sample_mask]]
            pointsCos_samples += [frame.pointsCos.unsqueeze(1).cuda()[sample_mask]]
            # rgb_samples += [frame.rgb.cuda()[sample_mask]]
            # depth_samples += [frame.depth.cuda()[sample_mask]]

        rays_d = torch.cat(rays_d, dim=0).unsqueeze(0)
        rays_o = torch.cat(rays_o, dim=0).unsqueeze(0)
        points_samples = torch.cat(points_samples, dim=0).unsqueeze(0)
        pointsCos_samples = torch.cat(pointsCos_samples, dim=0).unsqueeze(0)

        if iter == 0 and profiler is not None:
            profiler.tok("mapping sample_rays")
        if iter == 0 and profiler is not None:
            profiler.tick("mapping rendering")
        final_outputs = render_rays(
            rays_o,
            rays_d,
            map_states,
            sdf_network,
            step_size,
            voxel_size,
            truncation,
            max_voxel_hit,
            max_distance,
            chunk_size=-1,
            profiler=profiler if iter == 0 else None
        )
        if final_outputs is None:
            print("Encountered a bug while mapping; not yet fixed. Continuing!")
            hit_mask = None
            continue
        if iter == 0 and profiler is not None:
            profiler.tok("mapping rendering")
        if iter == 0 and profiler is not None:
            profiler.tick("mapping back proj")
        torch.cuda.empty_cache()
        loss, _ = loss_criteria(
            final_outputs, points_samples, pointsCos_samples)

        optim.zero_grad()
        loss.backward()
        optim.step()
        if iter == 0 and profiler is not None:
            profiler.tok("mapping back proj")


def track_frame(
    frame_pose,
    curr_frame,
    map_states,
    sdf_network,
    loss_criteria,
    voxel_size,
    N_rays=512,
    step_size=0.05,
    num_iterations=10,
    truncation=0.1,
    learning_rate=1e-3,
    max_voxel_hit=10,
    max_distance=10,
    profiler=None,
    depth_variance=False
):
    torch.cuda.empty_cache()
    init_pose = deepcopy(frame_pose).cuda()
    init_pose.requires_grad_(True)
    optim = torch.optim.Adam(init_pose.parameters(),
                             lr=learning_rate*2 if curr_frame.index < 2
                             else learning_rate/3)

    for iter in range(num_iterations):
        torch.cuda.empty_cache()
        if iter == 0 and profiler is not None:
            profiler.tick("track sample_rays")
        curr_frame.sample_rays(N_rays, track=True)
        if iter == 0 and profiler is not None:
            profiler.tok("track sample_rays")

        sample_mask = curr_frame.sample_mask
        ray_dirs = curr_frame.rays_d[sample_mask].unsqueeze(0).cuda()

        points_samples = curr_frame.points.unsqueeze(1).cuda()[sample_mask]
        pointsCos_samples = curr_frame.pointsCos.unsqueeze(1).cuda()[sample_mask]


        ray_dirs_iter = ray_dirs.squeeze(
            0) @ init_pose.rotation().transpose(-1, -2)
        ray_dirs_iter = ray_dirs_iter.unsqueeze(0)
        ray_start_iter = init_pose.translation().reshape(
            1, 1, -1).expand_as(ray_dirs_iter).cuda().contiguous()

        if iter == 0 and profiler is not None:
            profiler.tick("track render_rays")
        final_outputs = render_rays(
            ray_start_iter,
            ray_dirs_iter,
            map_states,
            sdf_network,
            step_size,
            voxel_size,
            truncation,
            max_voxel_hit,
            max_distance,
            chunk_size=-2,
            profiler=profiler if iter == 0 else None
        )
        if final_outputs is None:
            print("Encountered a bug while tracking; not yet fixed. Restarting!")
            hit_mask = None
            break

        torch.cuda.empty_cache()
        if iter == 0 and profiler is not None:
            profiler.tok("track render_rays")

        hit_mask = final_outputs["ray_mask"].view(N_rays)
        final_outputs["ray_mask"] = hit_mask

        if iter == 0 and profiler is not None:
            profiler.tick("track loss_criteria")
        loss, _ = loss_criteria(
            final_outputs, points_samples, pointsCos_samples, weight_depth_loss=depth_variance)
        if iter == 0 and profiler is not None:
            profiler.tok("track loss_criteria")
        if iter == 0 and profiler is not None:
            profiler.tick("track backward step")
        optim.zero_grad()
        loss.backward()
        optim.step()
        if iter == 0 and profiler is not None:
            profiler.tok("track backward step")

    return init_pose, hit_mask


================================================
FILE: src/variations/voxel_helpers.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

""" Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch """
from __future__ import (
    division,
    absolute_import,
    with_statement,
    print_function,
    unicode_literals,
)
import os
import sys
import torch
import torch.nn.functional as F
from torch.autograd import Function
import torch.nn as nn
import numpy as np
import grid as _ext

MAX_DEPTH = 80


class BallRayIntersect(Function):
    @staticmethod
    def forward(ctx, radius, n_max, points, ray_start, ray_dir):
        inds, min_depth, max_depth = _ext.ball_intersect(
            ray_start.float(), ray_dir.float(), points.float(), radius, n_max
        )
        min_depth = min_depth.type_as(ray_start)
        max_depth = max_depth.type_as(ray_start)

        ctx.mark_non_differentiable(inds)
        ctx.mark_non_differentiable(min_depth)
        ctx.mark_non_differentiable(max_depth)
        return inds, min_depth, max_depth

    @staticmethod
    def backward(ctx, a, b, c):
        return None, None, None, None, None


ball_ray_intersect = BallRayIntersect.apply


class AABBRayIntersect(Function):
    @staticmethod
    def forward(ctx, voxelsize, n_max, points, ray_start, ray_dir):
        # HACK: speed-up ray-voxel intersection by batching...
        # HACK: avoid out-of-memory
        G = min(2048, int(2 * 10 ** 9 / points.numel()))
        S, N = ray_start.shape[:2]
        K = int(np.ceil(N / G))
        H = K * G
        if H > N:
            ray_start = torch.cat([ray_start, ray_start[:, : H - N]], 1)
            ray_dir = torch.cat([ray_dir, ray_dir[:, : H - N]], 1)
        ray_start = ray_start.reshape(S * G, K, 3)
        ray_dir = ray_dir.reshape(S * G, K, 3)
        points = points.expand(S * G, *points.size()[1:]).contiguous()

        inds, min_depth, max_depth = _ext.aabb_intersect(
            ray_start.float(), ray_dir.float(), points.float(), voxelsize, n_max
        )
        min_depth = min_depth.type_as(ray_start)
        max_depth = max_depth.type_as(ray_start)

        inds = inds.reshape(S, H, -1)
        min_depth = min_depth.reshape(S, H, -1)
        max_depth = max_depth.reshape(S, H, -1)
        if H > N:
            inds = inds[:, :N]
            min_depth = min_depth[:, :N]
            max_depth = max_depth[:, :N]

        ctx.mark_non_differentiable(inds)
        ctx.mark_non_differentiable(min_depth)
        ctx.mark_non_differentiable(max_depth)
        return inds, min_depth, max_depth

    @staticmethod
    def backward(ctx, a, b, c):
        return None, None, None, None, None


aabb_ray_intersect = AABBRayIntersect.apply
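`AABBRayIntersect` illustrates a batching trick reused in several of these Functions: pad the N rays up to H = K * G so they reshape evenly into G groups of K rays each, run the kernel, then keep only the first N results. The index bookkeeping alone, as a pure-Python sketch (function name mine):

```python
import math

def pad_to_groups(n_rays, n_groups):
    # rays per group K, padded length H = K * n_groups, and the number of
    # duplicated rays that must be trimmed from the kernel output
    K = math.ceil(n_rays / n_groups)
    H = K * n_groups
    return K, H, H - n_rays
```

With `n_rays=1000, n_groups=256` this gives K = 4, H = 1024 and 24 padded rays to trim; when N already divides evenly, no padding is needed.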


class SparseVoxelOctreeRayIntersect(Function):
    @staticmethod
    def forward(ctx, voxelsize, n_max, points, children, ray_start, ray_dir):
        # HACK: avoid out-of-memory
        torch.cuda.empty_cache()
        G = min(256, int(2 * 10 ** 9 / (points.numel() + children.numel())))
        S, N = ray_start.shape[:2]
        K = int(np.ceil(N / G))
        H = K * G
        if H > N:
            ray_start = torch.cat([ray_start, ray_start[:, : H - N]], 1)
            ray_dir = torch.cat([ray_dir, ray_dir[:, : H - N]], 1)
        ray_start = ray_start.reshape(S * G, K, 3)
        ray_dir = ray_dir.reshape(S * G, K, 3)
        points = points.expand(S * G, *points.size()).contiguous()
        torch.cuda.empty_cache()
        children = children.expand(S * G, *children.size()).contiguous()
        torch.cuda.empty_cache()
        inds, min_depth, max_depth = _ext.svo_intersect(
            ray_start.float(),
            ray_dir.float(),
            points.float(),
            children.int(),
            voxelsize,
            n_max,
        )
        torch.cuda.empty_cache()
        min_depth = min_depth.type_as(ray_start)
        max_depth = max_depth.type_as(ray_start)

        inds = inds.reshape(S, H, -1)
        min_depth = min_depth.reshape(S, H, -1)
        max_depth = max_depth.reshape(S, H, -1)
        if H > N:
            inds = inds[:, :N]
            min_depth = min_depth[:, :N]
            max_depth = max_depth[:, :N]

        ctx.mark_non_differentiable(inds)
        ctx.mark_non_differentiable(min_depth)
        ctx.mark_non_differentiable(max_depth)
        return inds, min_depth, max_depth

    @staticmethod
    def backward(ctx, a, b, c):
        return None, None, None, None, None


svo_ray_intersect = SparseVoxelOctreeRayIntersect.apply


class TriangleRayIntersect(Function):
    @staticmethod
    def forward(ctx, cagesize, blur_ratio, n_max, points, faces, ray_start, ray_dir):
        # HACK: speed-up ray-voxel intersection by batching...
        # HACK: avoid out-of-memory
        G = min(2048, int(2 * 10 ** 9 / (3 * faces.numel())))
        S, N = ray_start.shape[:2]
        K = int(np.ceil(N / G))
        H = K * G
        if H > N:
            ray_start = torch.cat([ray_start, ray_start[:, : H - N]], 1)
            ray_dir = torch.cat([ray_dir, ray_dir[:, : H - N]], 1)
        ray_start = ray_start.reshape(S * G, K, 3)
        ray_dir = ray_dir.reshape(S * G, K, 3)
        face_points = F.embedding(faces.reshape(-1, 3), points.reshape(-1, 3))
        face_points = (
            face_points.unsqueeze(0).expand(
                S * G, *face_points.size()).contiguous()
        )
        inds, depth, uv = _ext.triangle_intersect(
            ray_start.float(),
            ray_dir.float(),
            face_points.float(),
            cagesize,
            blur_ratio,
            n_max,
        )
        depth = depth.type_as(ray_start)
        uv = uv.type_as(ray_start)

        inds = inds.reshape(S, H, -1)
        depth = depth.reshape(S, H, -1, 3)
        uv = uv.reshape(S, H, -1)
        if H > N:
            inds = inds[:, :N]
            depth = depth[:, :N]
            uv = uv[:, :N]

        ctx.mark_non_differentiable(inds)
        ctx.mark_non_differentiable(depth)
        ctx.mark_non_differentiable(uv)
        return inds, depth, uv

    @staticmethod
    def backward(ctx, a, b, c):
        return None, None, None, None, None, None


triangle_ray_intersect = TriangleRayIntersect.apply


class UniformRaySampling(Function):
    @staticmethod
    def forward(
        ctx,
        pts_idx,
        min_depth,
        max_depth,
        step_size,
        max_ray_length,
        deterministic=False,
    ):
        G, N, P = 256, pts_idx.size(0), pts_idx.size(1)
        H = int(np.ceil(N / G)) * G
        if H > N:
            pts_idx = torch.cat([pts_idx, pts_idx[: H - N]], 0)
            min_depth = torch.cat([min_depth, min_depth[: H - N]], 0)
            max_depth = torch.cat([max_depth, max_depth[: H - N]], 0)
        pts_idx = pts_idx.reshape(G, -1, P)
        min_depth = min_depth.reshape(G, -1, P)
        max_depth = max_depth.reshape(G, -1, P)

        # pre-generate noise
        max_steps = int(max_ray_length / step_size)
        max_steps = max_steps + min_depth.size(-1) * 2
        noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps)
        if deterministic:
            noise += 0.5
        else:
            noise = noise.uniform_()

        # call cuda function
        sampled_idx, sampled_depth, sampled_dists = _ext.uniform_ray_sampling(
            pts_idx,
            min_depth.float(),
            max_depth.float(),
            noise.float(),
            step_size,
            max_steps,
        )
        sampled_depth = sampled_depth.type_as(min_depth)
        sampled_dists = sampled_dists.type_as(min_depth)

        sampled_idx = sampled_idx.reshape(H, -1)
        sampled_depth = sampled_depth.reshape(H, -1)
        sampled_dists = sampled_dists.reshape(H, -1)
        if H > N:
            sampled_idx = sampled_idx[:N]
            sampled_depth = sampled_depth[:N]
            sampled_dists = sampled_dists[:N]

        max_len = sampled_idx.ne(-1).sum(-1).max()
        sampled_idx = sampled_idx[:, :max_len]
        sampled_depth = sampled_depth[:, :max_len]
        sampled_dists = sampled_dists[:, :max_len]

        ctx.mark_non_differentiable(sampled_idx)
        ctx.mark_non_differentiable(sampled_depth)
        ctx.mark_non_differentiable(sampled_dists)
        return sampled_idx, sampled_depth, sampled_dists

    @staticmethod
    def backward(ctx, a, b, c):
        return None, None, None, None, None, None


uniform_ray_sampling = UniformRaySampling.apply


class InverseCDFRaySampling(Function):
    @staticmethod
    def forward(
        ctx,
        pts_idx,
        min_depth,
        max_depth,
        probs,
        steps,
        fixed_step_size=-1,
        deterministic=False,
    ):
        G, N, P = 200, pts_idx.size(0), pts_idx.size(1)
        H = int(np.ceil(N / G)) * G

        if H > N:
            pts_idx = torch.cat([pts_idx, pts_idx[:1].expand(H - N, P)], 0)
            min_depth = torch.cat(
                [min_depth, min_depth[:1].expand(H - N, P)], 0)
            max_depth = torch.cat(
                [max_depth, max_depth[:1].expand(H - N, P)], 0)
            probs = torch.cat([probs, probs[:1].expand(H - N, P)], 0)
            steps = torch.cat([steps, steps[:1].expand(H - N)], 0)

        # print(G, P, np.ceil(N / G), N, H, pts_idx.shape, min_depth.device)
        pts_idx = pts_idx.reshape(G, -1, P)
        min_depth = min_depth.reshape(G, -1, P)
        max_depth = max_depth.reshape(G, -1, P)
        probs = probs.reshape(G, -1, P)
        steps = steps.reshape(G, -1)

        # pre-generate noise
        max_steps = steps.ceil().long().max() + P
        # print(max_steps)
        # print(*min_depth.size()[:-1]," ", max_steps)
        noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps)
        if deterministic:
            noise += 0.5
        else:
            noise = noise.uniform_().clamp(min=0.001, max=0.999)  # in case

        # call cuda function
        chunk_size = 4 * G  # to avoid oom?
        results = [
            _ext.inverse_cdf_sampling(
                pts_idx[:, i: i + chunk_size].contiguous(),
                min_depth.float()[:, i: i + chunk_size].contiguous(),
                max_depth.float()[:, i: i + chunk_size].contiguous(),
                noise.float()[:, i: i + chunk_size].contiguous(),
                probs.float()[:, i: i + chunk_size].contiguous(),
                steps.float()[:, i: i + chunk_size].contiguous(),
                fixed_step_size,
            )
            for i in range(0, min_depth.size(1), chunk_size)
        ]

        sampled_idx, sampled_depth, sampled_dists = [
            torch.cat([r[i] for r in results], 1) for i in range(3)
        ]
        sampled_depth = sampled_depth.type_as(min_depth)
        sampled_dists = sampled_dists.type_as(min_depth)

        sampled_idx = sampled_idx.reshape(H, -1)
        sampled_depth = sampled_depth.reshape(H, -1)
        sampled_dists = sampled_dists.reshape(H, -1)
        if H > N:
            sampled_idx = sampled_idx[:N]
            sampled_depth = sampled_depth[:N]
            sampled_dists = sampled_dists[:N]

        max_len = sampled_idx.ne(-1).sum(-1).max()
        sampled_idx = sampled_idx[:, :max_len]
        sampled_depth = sampled_depth[:, :max_len]
        sampled_dists = sampled_dists[:, :max_len]

        ctx.mark_non_differentiable(sampled_idx)
        ctx.mark_non_differentiable(sampled_depth)
        ctx.mark_non_differentiable(sampled_dists)
        return sampled_idx, sampled_depth, sampled_dists

    @staticmethod
    def backward(ctx, a, b, c):
        return None, None, None, None, None, None, None


inverse_cdf_sampling = InverseCDFRaySampling.apply


# back-up for ray point sampling
@torch.no_grad()
def _parallel_ray_sampling(
    MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False
):
    # uniform sampling
    _min_depth = min_depth.min(1)[0]
    _max_depth = max_depth.masked_fill(max_depth.eq(MAX_DEPTH), 0).max(1)[0]
    max_ray_length = (_max_depth - _min_depth).max()

    delta = torch.arange(
        int(max_ray_length / MARCH_SIZE), device=min_depth.device, dtype=min_depth.dtype
    )
    delta = delta[None, :].expand(min_depth.size(0), delta.size(-1))
    if deterministic:
        delta = delta + 0.5
    else:
        delta = delta + delta.clone().uniform_().clamp(min=0.01, max=0.99)
    delta = delta * MARCH_SIZE
    sampled_depth = min_depth[:, :1] + delta
    sampled_idx = (sampled_depth[:, :, None] >=
                   min_depth[:, None, :]).sum(-1) - 1
    sampled_idx = pts_idx.gather(1, sampled_idx)

    # include all boundary points
    sampled_depth = torch.cat([min_depth, max_depth, sampled_depth], -1)
    sampled_idx = torch.cat([pts_idx, pts_idx, sampled_idx], -1)

    # reorder
    sampled_depth, ordered_index = sampled_depth.sort(-1)
    sampled_idx = sampled_idx.gather(1, ordered_index)
    sampled_dists = sampled_depth[:, 1:] - sampled_depth[:, :-1]  # distances
    sampled_depth = 0.5 * \
        (sampled_depth[:, 1:] + sampled_depth[:, :-1])  # mid-points

    # remove all invalid depths
    min_ids = (sampled_depth[:, :, None] >= min_depth[:, None, :]).sum(-1) - 1
    max_ids = (sampled_depth[:, :, None] >= max_depth[:, None, :]).sum(-1)

    sampled_depth.masked_fill_(
        (max_ids.ne(min_ids))
        | (sampled_depth > _max_depth[:, None])
        | (sampled_dists == 0.0),
        MAX_DEPTH,
    )
    sampled_depth, ordered_index = sampled_depth.sort(-1)  # sort again
    sampled_masks = sampled_depth.eq(MAX_DEPTH)
    num_max_steps = (~sampled_masks).sum(-1).max()

    sampled_depth = sampled_depth[:, :num_max_steps]
    sampled_dists = sampled_dists.gather(1, ordered_index).masked_fill_(
        sampled_masks, 0.0
    )[:, :num_max_steps]
    sampled_idx = sampled_idx.gather(1, ordered_index).masked_fill_(sampled_masks, -1)[
        :, :num_max_steps
    ]

    return sampled_idx, sampled_depth, sampled_dists


@torch.no_grad()
def parallel_ray_sampling(
    MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False
):
    chunk_size = 4096
    full_size = min_depth.shape[0]
    if full_size <= chunk_size:
        return _parallel_ray_sampling(
            MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=deterministic
        )

    outputs = zip(
        *[
            _parallel_ray_sampling(
                MARCH_SIZE,
                pts_idx[i: i + chunk_size],
                min_depth[i: i + chunk_size],
                max_depth[i: i + chunk_size],
                deterministic=deterministic,
            )
            for i in range(0, full_size, chunk_size)
        ]
    )
    sampled_idx, sampled_depth, sampled_dists = outputs

    def padding_points(xs, pad):
        if len(xs) == 1:
            return xs[0]

        maxlen = max([x.size(1) for x in xs])
        full_size = sum([x.size(0) for x in xs])
        xt = xs[0].new_ones(full_size, maxlen).fill_(pad)
        st = 0
        for i in range(len(xs)):
            xt[st: st + xs[i].size(0), : xs[i].size(1)] = xs[i]
            st += xs[i].size(0)
        return xt

    sampled_idx = padding_points(sampled_idx, -1)
    sampled_depth = padding_points(sampled_depth, MAX_DEPTH)
    sampled_dists = padding_points(sampled_dists, 0.0)
    return sampled_idx, sampled_depth, sampled_dists


def discretize_points(voxel_points, voxel_size):
    # this function turns voxel centers/corners into integer indices;
    # we assume all points already lie on the voxel grid (as real numbers)
    minimal_voxel_point = voxel_points.min(dim=0, keepdim=True)[0]
    voxel_indices = (
        ((voxel_points - minimal_voxel_point) / voxel_size).round_().long()
    )
    residual = (voxel_points - voxel_indices.type_as(voxel_points) * voxel_size).mean(
        0, keepdim=True
    )
    return voxel_indices, residual
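`discretize_points` shifts coordinates by the per-axis minimum, divides by the voxel size, and rounds to integer grid indices; the residual records the sub-voxel offset that `build_easy_octree` later adds back. The quantization step in plain Python (helper name mine):

```python
def discretize(points, voxel_size):
    # shift by the per-axis minimum, then round to integer voxel indices
    mins = [min(p[d] for p in points) for d in range(3)]
    return [tuple(round((p[d] - mins[d]) / voxel_size) for d in range(3))
            for p in points]
```

For instance, `discretize([(0.0, 0.0, 0.0), (0.2, 0.4, 0.0)], 0.2)` maps the two points to grid cells (0, 0, 0) and (1, 2, 0).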


def build_easy_octree(points, half_voxel):
    coords, residual = discretize_points(points, half_voxel)
    ranges = coords.max(0)[0] - coords.min(0)[0]
    depths = torch.log2(ranges.max().float()).ceil_().long() - 1
    center = (coords.max(0)[0] + coords.min(0)[0]) / 2
    centers, children = _ext.build_octree(center, coords, int(depths))
    centers = centers.float() * half_voxel + residual  # transform back to float
    return centers, children
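The octree depth chosen by `build_easy_octree` follows directly from the integer coordinate range. A small sketch of that depth expression (assuming, as the code above does, that level `depth` subdivides a cube of side `2**(depth+1)` grid units):

```python
import math

# Mirrors the depth expression in build_easy_octree above:
# depths = ceil(log2(max_range)) - 1, so that 2**(depth+1) covers the
# integer coordinate range of the discretized points.

def octree_depth(max_range):
    return math.ceil(math.log2(max_range)) - 1

d = octree_depth(100)
# → 6, since 2**(6 + 1) = 128 >= 100
```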


@torch.enable_grad()
def trilinear_interp(p, q, point_feats):
    weights = (p * q + (1 - p) * (1 - q)).prod(dim=-1, keepdim=True)
    if point_feats.dim() == 2:
        point_feats = point_feats.view(point_feats.size(0), 8, -1)

    point_feats = (weights * point_feats).sum(1)
    return point_feats
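The weight formula in `trilinear_interp` — `(p * q + (1 - p) * (1 - q)).prod(dim=-1)` — reduces per axis to `p` when the corner flag `q` is 1 and `1 - p` when it is 0, which is standard trilinear interpolation. A self-contained sketch of the per-corner weights:

```python
from itertools import product

# Per-corner trilinear weights as computed in trilinear_interp: for a point
# at fractional position p in [0, 1]^3, corner q in {0, 1}^3 contributes
# prod over axes of (p*q + (1-p)*(1-q)). The eight weights sum to 1.

def trilinear_weights(p):
    weights = []
    for q in product((0, 1), repeat=3):
        w = 1.0
        for pi, qi in zip(p, q):
            w *= pi * qi + (1 - pi) * (1 - qi)
        weights.append(w)
    return weights

w = trilinear_weights((0.5, 0.5, 0.5))
# at the voxel center, all eight corners get weight 0.125
```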


def offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2):
    c = torch.arange(1, 2 * bits, 2, device=point_xyz.device)
    ox, oy, oz = torch.meshgrid([c, c, c], indexing='ij')
    offset = (torch.cat([ox.reshape(-1, 1),
                         oy.reshape(-1, 1),
                         oz.reshape(-1, 1)], 1).type_as(point_xyz) - bits) / float(bits - 1)
    if not offset_only:
        return (
            point_xyz.unsqueeze(1) + offset.unsqueeze(0).type_as(point_xyz) * quarter_voxel)
    return offset.type_as(point_xyz) * quarter_voxel
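With the default `bits=2`, the offset grid built by `offset_points` collapses to the eight cube-corner directions: the odd values 1 and 3 are shifted by `bits` and divided by `bits - 1`, giving ±1 per axis. A pure-Python sketch (hypothetical `corner_offsets` helper):

```python
from itertools import product

# Sketch of the corner offsets built by offset_points with bits=2: odd
# values 1 and 3 map to (v - bits) / (bits - 1) = ±1 per axis, yielding the
# eight offsets (±1, ±1, ±1) that are then scaled by quarter_voxel.

def corner_offsets(bits=2):
    c = range(1, 2 * bits, 2)  # odd values: 1, 3, ...
    return [tuple((v - bits) / (bits - 1) for v in q)
            for q in product(c, repeat=3)]

offsets = corner_offsets()
# → eight tuples of the form (±1.0, ±1.0, ±1.0)
```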


def splitting_points(point_xyz, point_feats, values, half_voxel):
    # generate new centers
    quarter_voxel = half_voxel * 0.5
    new_points = offset_points(point_xyz, quarter_voxel).reshape(-1, 3)
    old_coords = discretize_points(point_xyz, quarter_voxel)[0]
    new_coords = offset_points(old_coords).reshape(-1, 3)
    new_keys0 = offset_points(new_coords).reshape(-1, 3)

    # get unique keys and inverse indices (mapping each row of new_keys0
    # to its position in new_keys)
    new_keys, new_feats = torch.unique(
        new_keys0, dim=0, sorted=True, return_inverse=True)
    new_keys_idx = new_feats.new_zeros(new_keys.size(0)).scatter_(
        0, new_feats, torch.arange(new_keys0.size(0), device=new_feats.device) // 64)

    # recompute key vectors using trilinear interpolation
    new_feats = new_feats.reshape(-1, 8)

    if values is not None:
        # (1/4 voxel size)
        p = (new_keys - old_coords[new_keys_idx]
             ).type_as(point_xyz).unsqueeze(1) * 0.25 + 0.5
        q = offset_points(p, 0.5, offset_only=True).unsqueeze(0) + 0.5  # BUG?
        point_feats = point_feats[new_keys_idx]
        point_feats = F.embedding(point_feats, values).view(
            point_feats.size(0), -1)
        new_values = trilinear_interp(p, q, point_feats)
    else:
        new_values = None
    return new_points, new_feats, new_values, new_keys


@torch.no_grad()
def ray_intersect(ray_start, ray_dir, flatten_centers, flatten_children, voxel_size, max_hits, max_distance=MAX_DEPTH):
    # ray-voxel intersection; the incoming max_hits is only an upper bound
    # and is recomputed below from the actual number of hits
    max_hits_temp = 20
    pts_idx, min_depth, max_depth = svo_ray_intersect(
        voxel_size,
        max_hits_temp,
        flatten_centers,
        flatten_children,
        ray_start,
        ray_dir)
    torch.cuda.empty_cache()
    # sort the depths
    min_depth.masked_fill_(pts_idx.eq(-1), max_distance)
    max_depth.masked_fill_(pts_idx.eq(-1), max_distance)

    min_depth, sorted_idx = min_depth.sort(dim=-1)
    max_depth = max_depth.gather(-1, sorted_idx)
    pts_idx = pts_idx.gather(-1, sorted_idx)
    # print(max_depth.max())
    pts_idx[max_depth > 2*max_distance] = -1
    pts_idx[min_depth > max_distance] = -1
    min_depth.masked_fill_(pts_idx.eq(-1), max_distance)
    max_depth.masked_fill_(pts_idx.eq(-1), max_distance)
    # remove all points that completely miss the object
    max_hits = torch.max(pts_idx.ne(-1).sum(-1))
    min_depth = min_depth[..., :max_hits]
    max_depth = max_depth[..., :max_hits]
    pts_idx = pts_idx[..., :max_hits]

    hits = pts_idx.ne(-1).any(-1)

    intersection_outputs = {
        "min_depth": min_depth,
        "max_depth": max_depth,
        "intersected_voxel_idx": pts_idx,
    }
    return intersection_outputs, hits


@torch.no_grad()
def ray_sample(intersection_outputs, step_size=0.01, fixed=False):
    dists = (
        intersection_outputs["max_depth"] -
        intersection_outputs["min_depth"]
    ).masked_fill(intersection_outputs["intersected_voxel_idx"].eq(-1), 0)
    intersection_outputs["probs"] = dists / dists.sum(dim=-1, keepdim=True)
    intersection_outputs["steps"] = dists.sum(-1) / step_size
    # TODO: a serious bug that still needs fixing! Bail out (returning None)
    # when the accumulated distances blow up.
    if dists.sum(-1).max() > 10 * MAX_DEPTH:
        return None
    # sample points and use middle point approximation
    sampled_idx, sampled_depth, sampled_dists = inverse_cdf_sampling(
        intersection_outputs["intersected_voxel_idx"],
        intersection_outputs["min_depth"],
        intersection_outputs["max_depth"],
        intersection_outputs["probs"],
        intersection_outputs["steps"], -1, fixed)

    sampled_dists = sampled_dists.clamp(min=0.0)
    sampled_depth.masked_fill_(sampled_idx.eq(-1), MAX_DEPTH)
    sampled_dists.masked_fill_(sampled_idx.eq(-1), 0.0)

    samples = {
        "sampled_point_depth": sampled_depth,
        "sampled_point_distance": sampled_dists,
        "sampled_point_voxel_idx": sampled_idx,
    }
    return samples
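The sampling driven by `ray_sample` is inverse-CDF sampling: the lengths of the intersected voxel segments define a CDF along the ray, and uniform draws are mapped through it so longer segments receive proportionally more samples. A simplified 1-D stand-in (names here are illustrative, not the repo's API):

```python
import random

# Hedged sketch of the idea behind ray_sample / inverse_cdf_sampling:
# segment lengths define a CDF; a uniform u in [0, 1) picks the segment
# whose CDF bin contains u, and the residual of u places the depth inside it.

def sample_segments(starts, ends, n, rng=random.random):
    lengths = [e - s for s, e in zip(starts, ends)]
    total = sum(lengths)
    cdf, acc = [], 0.0
    for length in lengths:
        acc += length / total
        cdf.append(acc)
    samples = []
    for _ in range(n):
        u = rng()
        i = next((k for k, c in enumerate(cdf) if u <= c), len(cdf) - 1)
        prev = cdf[i - 1] if i > 0 else 0.0
        t = (u - prev) / (cdf[i] - prev)  # residual mapped into segment i
        samples.append(starts[i] + t * (ends[i] - starts[i]))
    return samples

s = sample_segments([0, 2], [1, 4], 1, rng=lambda: 0.5)
# u = 0.5 lands a quarter of the way into the second segment [2, 4]
```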


================================================
FILE: third_party/marching_cubes/setup.py
================================================
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import glob

_ext_sources = glob.glob("src/*.cpp") + glob.glob("src/*.cu")

setup(
    name='marching_cubes',
    ext_modules=[
        CUDAExtension(
            name='marching_cubes',
            sources=_ext_sources,
            extra_compile_args={
                "cxx": ["-O2", "-I./include"],
                "nvcc": ["-I./include"]
            },
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

================================================
FILE: third_party/marching_cubes/src/mc.cpp
================================================
#include <torch/extension.h>

std::vector<torch::Tensor> marching_cubes_sparse(
    torch::Tensor indexer,           // (nx, ny, nz) -> data_id
    torch::Tensor valid_blocks,      // (K, )
    torch::Tensor vec_batch_mapping, //
    torch::Tensor cube_sdf,          // (M, rx, ry, rz)
    torch::Tensor cube_std,          // (M, rx, ry, rz)
    const std::vector<int> &n_xyz,   // [nx, ny, nz]
    float max_std,                   // Prune vertices whose std exceeds this
    int max_n_triangles              // Maximum size of the triangle buffer.
);

std::vector<torch::Tensor> marching_cubes_sparse_colour(
    torch::Tensor indexer,           // (nx, ny, nz) -> data_id
    torch::Tensor valid_blocks,      // (K, )
    torch::Tensor vec_batch_mapping, //
    torch::Tensor cube_sdf,          // (M, rx, ry, rz, 4)
    torch::Tensor cube_colour,       // (M, rx, ry, rz)
    const std::vector<int> &n_xyz,   // [nx, ny, nz]
    float max_std,                   // Prune vertices whose std exceeds this
    int max_n_triangles              // Maximum size of the triangle buffer.
);

std::vector<torch::Tensor> marching_cubes_sparse_interp_cuda(
    torch::Tensor indexer,           // (nx, ny, nz) -> data_id
    torch::Tensor valid_blocks,      // (K, )
    torch::Tensor vec_batch_mapping, //
    torch::Tensor cube_sdf,          // (M, rx, ry, rz)
    torch::Tensor cube_std,          // (M, rx, ry, rz)
    const std::vector<int> &n_xyz,   // [nx, ny, nz]
    float max_std,                   // Prune vertices whose std exceeds this
    int max_n_triangles              // Maximum size of the triangle buffer.
);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("marching_cubes_sparse", &marching_cubes_sparse, "Marching Cubes without Interpolation (CUDA)");
    m.def("marching_cubes_sparse_colour", &marching_cubes_sparse_colour, "Marching Cubes without Interpolation, with colour (CUDA)");
    m.def("marching_cubes_sparse_interp", &marching_cubes_sparse_interp_cuda, "Marching Cubes with Interpolation (CUDA)");
}

================================================
FILE: third_party/marching_cubes/src/mc_data.cuh
================================================
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <thrust/device_vector.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

using IndexerAccessor = torch::PackedTensorAccessor32<int64_t, 3, torch::RestrictPtrTraits>;
using ValidBlocksAccessor = torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits>;
using BackwardMappingAccessor = torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits>;
using CubeSDFAccessor = torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits>;
using CubeSDFRGBAccessor = torch::PackedTensorAccessor32<float, 5, torch::RestrictPtrTraits>;
using TrianglesAccessor = torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits>;
using TriangleStdAccessor = torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits>;
using TriangleVecIdAccessor = torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits>;


__inline__ __device__ float4 make_float4(const float3& xyz, float w) {
    return make_float4(xyz.x, xyz.y, xyz.z, w);
}

inline __host__ __device__ void operator+=(float2 &a, const float2& b) {
    a.x += b.x; a.y += b.y;
}

inline __host__ __device__ float2 operator*(const float2& a, float b) {
    return make_float2(a.x * b, a.y * b);
}

inline __host__ __device__ float2 operator/(const float2& a, float b) {
    return make_float2(a.x / b, a.y / b);
}

inline __host__ __device__ float2 operator/(const float2& a, const float2& b) {
    return make_float2(a.x / b.x, a.y / b.y);
}

__constant__ int edgeTable[256] = { 0x0, 0x109, 0x203, 0x30a, 0x406, 0x50f, 0x605, 0x70c, 0x80c, 0x905, 0xa0f, 0xb06, 0xc0a, 0xd03, 0xe09, 0xf00,
	0x190, 0x99, 0x393, 0x29a, 0x596, 0x49f, 0x795, 0x69c, 0x99c, 0x895, 0xb9f, 0xa96, 0xd9a, 0xc93, 0xf99, 0xe90, 0x230, 0x339, 0x33, 0x13a,
	0x636, 0x73f, 0x435, 0x53c, 0xa3c, 0xb35, 0x83f, 0x936, 0xe3a, 0xf33, 0xc39, 0xd30, 0x3a0, 0x2a9, 0x1a3, 0xaa, 0x7a6, 0x6af, 0x5a5, 0x4ac,
	0xbac, 0xaa5, 0x9af, 0x8a6, 0xfaa, 0xea3, 0xda9, 0xca0, 0x460, 0x569, 0x663, 0x76a, 0x66, 0x16f, 0x265, 0x36c, 0xc6c, 0xd65, 0xe6f, 0xf66,
	0x86a, 0x963, 0xa69, 0xb60, 0x5f0, 0x4f9, 0x7f3, 0x6fa, 0x1f6, 0xff, 0x3f5, 0x2fc, 0xdfc, 0xcf5, 0xfff, 0xef6, 0x9fa, 0x8f3, 0xbf9, 0xaf0,
	0x650, 0x759, 0x453, 0x55a, 0x256, 0x35f, 0x55, 0x15c, 0xe5c, 0xf55, 0xc5f, 0xd56, 0xa5a, 0xb53, 0x859, 0x950, 0x7c0, 0x6c9, 0x5c3, 0x4ca,
	0x3c6, 0x2cf, 0x1c5, 0xcc, 0xfcc, 0xec5, 0xdcf, 0xcc6, 0xbca, 0xac3, 0x9c9, 0x8c0, 0x8c0, 0x9c9, 0xac3, 0xbca, 0xcc6, 0xdcf, 0xec5, 0xfcc,
	0xcc, 0x1c5, 0x2cf, 0x3c6, 0x4ca, 0x5c3, 0x6c9, 0x7c0, 0x950, 0x859, 0xb53, 0xa5a, 0xd56, 0xc5f, 0xf55, 0xe5c, 0x15c, 0x55, 0x35f, 0x256,
	0x55a, 0x453, 0x759, 0x650, 0xaf0, 0xbf9, 0x8f3, 0x9fa, 0xef6, 0xfff, 0xcf5, 0xdfc, 0x2fc, 0x3f5, 0xff, 0x1f6, 0x6fa, 0x7f3, 0x4f9, 0x5f0,
	0xb60, 0xa69, 0x963, 0x86a, 0xf66, 0xe6f, 0xd65, 0xc6c, 0x36c, 0x265, 0x16f, 0x66, 0x76a, 0x663, 0x569, 0x460, 0xca0, 0xda9, 0xea3, 0xfaa,
	0x8a6, 0x9af, 0xaa5, 0xbac, 0x4ac, 0x5a5, 0x6af, 0x7a6, 0xaa, 0x1a3, 0x2a9, 0x3a0, 0xd30, 0xc39, 0xf33, 0xe3a, 0x936, 0x83f, 0xb35, 0xa3c,
	0x53c, 0x435, 0x73f, 0x636, 0x13a, 0x33, 0x339, 0x230, 0xe90, 0xf99, 0xc93, 0xd9a, 0xa96, 0xb9f, 0x895, 0x99c, 0x69c, 0x795, 0x49f, 0x596,
	0x29a, 0x393, 0x99, 0x190, 0xf00, 0xe09, 0xd03, 0xc0a, 0xb06, 0xa0f, 0x905, 0x80c, 0x70c, 0x605, 0x50f, 0x406, 0x30a, 0x203, 0x109, 0x0 };

__constant__ int triangleTable[256][16] = { { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 8, 3, 9, 8, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 2, 10, 0, 2, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 2, 8, 3, 2, 10, 8, 10, 9, 8, -1, -1, -1, -1, -1, -1, -1 }, { 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 11, 2, 8, 11, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 9, 0, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 11, 2, 1, 9, 11, 9, 8, 11, -1, -1, -1, -1, -1, -1, -1 }, { 3, 10, 1, 11, 10, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 10, 1, 0, 8, 10, 8, 11, 10, -1, -1, -1, -1, -1, -1, -1 }, { 3, 9, 0, 3, 11, 9, 11, 10, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 3, 0, 7, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 1, 9, 4, 7, 1, 7, 3, 1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 4, 7, 3, 0, 4, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1 }, { 9, 2, 10, 9, 0, 2, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4, -1, -1, -1, -1 }, { 8, 4, 7, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 4, 7, 11, 2, 4, 2, 0, 4, -1, -1, -1, -1, -1, -1, -1 }, { 9, 0, 1, 8, 4, 7, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1, -1, -1, -1, -1 }, { 3, 10, 1, 3, 11, 10, 7, 8, 4, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4, -1, -1, -1, -1 }, { 4, 7, 8, 9, 0, 11, 9, 11, 10, 11, 0, 3, -1, -1, -1, -1 },
{ 4, 7, 11, 4, 11, 9, 9, 11, 10, -1, -1, -1, -1, -1, -1, -1 }, { 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 5, 4, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 5, 4, 1, 5, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 5, 4, 8, 3, 5, 3, 1, 5, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 8, 1, 2, 10, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1 }, { 5, 2, 10, 5, 4, 2, 4, 0, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8, -1, -1, -1, -1 }, { 9, 5, 4, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 11, 2, 0, 8, 11, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1 }, { 0, 5, 4, 0, 1, 5, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5, -1, -1, -1, -1 }, { 10, 3, 11, 10, 1, 3, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10, -1, -1, -1, -1 }, { 5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3, -1, -1, -1, -1 },
{ 5, 4, 8, 5, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1 }, { 9, 7, 8, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 3, 0, 9, 5, 3, 5, 7, 3, -1, -1, -1, -1, -1, -1, -1 }, { 0, 7, 8, 0, 1, 7, 1, 5, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 7, 8, 9, 5, 7, 10, 1, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3, -1, -1, -1, -1 }, { 8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2, -1, -1, -1, -1 },
{ 2, 10, 5, 2, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1 }, { 7, 9, 5, 7, 8, 9, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11, -1, -1, -1, -1 }, { 2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7, -1, -1, -1, -1 },
{ 11, 2, 1, 11, 1, 7, 7, 1, 5, -1, -1, -1, -1, -1, -1, -1 }, { 9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11, -1, -1, -1, -1 },
{ 5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0, -1 }, { 11, 10, 0, 11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0, -1 },
{ 11, 10, 5, 7, 11, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 0, 1, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 8, 3, 1, 9, 8, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1 }, { 1, 6, 5, 2, 6, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 6, 5, 1, 2, 6, 3, 0, 8, -1, -1, -1, -1, -1, -1, -1 }, { 9, 6, 5, 9, 0, 6, 0, 2, 6, -1, -1, -1, -1, -1, -1, -1 },
{ 5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8, -1, -1, -1, -1 }, { 2, 3, 11, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 0, 8, 11, 2, 0, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1 },
{ 5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11, -1, -1, -1, -1 }, { 6, 3, 11, 6, 5, 3, 5, 1, 3, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6, -1, -1, -1, -1 }, { 3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9, -1, -1, -1, -1 },
{ 6, 5, 9, 6, 9, 11, 11, 9, 8, -1, -1, -1, -1, -1, -1, -1 }, { 5, 10, 6, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 3, 0, 4, 7, 3, 6, 5, 10, -1, -1, -1, -1, -1, -1, -1 }, { 1, 9, 0, 5, 10, 6, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4, -1, -1, -1, -1 }, { 6, 1, 2, 6, 5, 1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7, -1, -1, -1, -1 }, { 8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6, -1, -1, -1, -1 },
{ 7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9, -1 }, { 3, 11, 2, 7, 8, 4, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1 },
{ 5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11, -1, -1, -1, -1 }, { 0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1 },
{ 9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6, -1 }, { 8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6, -1, -1, -1, -1 },
{ 5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11, -1 }, { 0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7, -1 },
{ 6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9, -1, -1, -1, -1 }, { 10, 4, 9, 6, 4, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 10, 6, 4, 9, 10, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1 }, { 10, 0, 1, 10, 6, 0, 6, 4, 0, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10, -1, -1, -1, -1 }, { 1, 4, 9, 1, 2, 4, 2, 6, 4, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4, -1, -1, -1, -1 }, { 0, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 3, 2, 8, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1 }, { 10, 4, 9, 10, 6, 4, 11, 2, 3, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6, -1, -1, -1, -1 }, { 3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10, -1, -1, -1, -1 },
{ 6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1, -1 }, { 9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3, -1, -1, -1, -1 },
{ 8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1, -1 }, { 3, 11, 6, 3, 6, 0, 0, 6, 4, -1, -1, -1, -1, -1, -1, -1 },
{ 6, 4, 8, 11, 6, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 7, 10, 6, 7, 8, 10, 8, 9, 10, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10, -1, -1, -1, -1 }, { 10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0, -1, -1, -1, -1 },
{ 10, 6, 7, 10, 7, 1, 1, 7, 3, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7, -1, -1, -1, -1 },
{ 2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9, -1 }, { 7, 8, 0, 7, 0, 6, 6, 0, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 7, 3, 2, 6, 7, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7, -1, -1, -1, -1 },
{ 2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7, -1 }, { 1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11, -1 },
{ 11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1, -1, -1, -1, -1 }, { 8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6, -1 },
{ 0, 9, 1, 11, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0, -1, -1, -1, -1 },
{ 7, 11, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 8, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 1, 9, 8, 3, 1, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1 }, { 10, 1, 2, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 2, 10, 3, 0, 8, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1 }, { 2, 9, 0, 2, 10, 9, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8, -1, -1, -1, -1 }, { 7, 2, 3, 6, 2, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 7, 0, 8, 7, 6, 0, 6, 2, 0, -1, -1, -1, -1, -1, -1, -1 }, { 2, 7, 6, 2, 3, 7, 0, 1, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6, -1, -1, -1, -1 }, { 10, 7, 6, 10, 1, 7, 1, 3, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8, -1, -1, -1, -1 }, { 0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7, -1, -1, -1, -1 },
{ 7, 6, 10, 7, 10, 8, 8, 10, 9, -1, -1, -1, -1, -1, -1, -1 }, { 6, 8, 4, 11, 8, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 6, 11, 3, 0, 6, 0, 4, 6, -1, -1, -1, -1, -1, -1, -1 }, { 8, 6, 11, 8, 4, 6, 9, 0, 1, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6, -1, -1, -1, -1 }, { 6, 8, 4, 6, 11, 8, 2, 10, 1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6, -1, -1, -1, -1 }, { 4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9, -1, -1, -1, -1 },
{ 10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3, -1 }, { 8, 2, 3, 8, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8, -1, -1, -1, -1 },
{ 1, 9, 4, 1, 4, 2, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1 }, { 8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1, -1, -1, -1, -1 },
{ 10, 1, 0, 10, 0, 6, 6, 0, 4, -1, -1, -1, -1, -1, -1, -1 }, { 4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3, -1 },
{ 10, 9, 4, 6, 10, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 9, 5, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 4, 9, 5, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1 }, { 5, 0, 1, 5, 4, 0, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5, -1, -1, -1, -1 }, { 9, 5, 4, 10, 1, 2, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5, -1, -1, -1, -1 }, { 7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2, -1, -1, -1, -1 },
{ 3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6, -1 }, { 7, 2, 3, 7, 6, 2, 5, 4, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7, -1, -1, -1, -1 }, { 3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0, -1, -1, -1, -1 },
{ 6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8, -1 }, { 9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7, -1, -1, -1, -1 },
{ 1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4, -1 }, { 4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10, -1 },
{ 7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10, -1, -1, -1, -1 }, { 6, 9, 5, 6, 11, 9, 11, 8, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5, -1, -1, -1, -1 }, { 0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11, -1, -1, -1, -1 },
{ 6, 11, 3, 6, 3, 5, 5, 3, 1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6, -1, -1, -1, -1 },
{ 0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10, -1 }, { 11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5, -1 },
{ 6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3, -1, -1, -1, -1 }, { 5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2, -1, -1, -1, -1 },
{ 9, 5, 6, 9, 6, 0, 0, 6, 2, -1, -1, -1, -1, -1, -1, -1 }, { 1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8, -1 },
{ 1, 5, 6, 2, 1, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6, -1 },
{ 10, 1, 0, 10, 0, 6, 9, 5, 0, 5, 6, 0, -1, -1, -1, -1 }, { 0, 3, 8, 5, 6, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 11, 5, 10, 7, 5, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 5, 10, 11, 7, 5, 8, 3, 0, -1, -1, -1, -1, -1, -1, -1 }, { 5, 11, 7, 5, 10, 11, 1, 9, 0, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1, -1, -1, -1, -1 }, { 11, 1, 2, 11, 7, 1, 7, 5, 1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11, -1, -1, -1, -1 }, { 9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7, -1, -1, -1, -1 },
{ 7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2, -1 }, { 2, 5, 10, 2, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5, -1, -1, -1, -1 }, { 9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2, -1, -1, -1, -1 },
{ 9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2, -1 }, { 1, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 7, 0, 7, 1, 1, 7, 5, -1, -1, -1, -1, -1, -1, -1 }, { 9, 0, 3, 9, 3, 5, 5, 3, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 8, 7, 5, 9, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 5, 8, 4, 5, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1 },
{ 5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0, -1, -1, -1, -1 }, { 0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5, -1, -1, -1, -1 },
{ 10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4, -1 }, { 2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8, -1, -1, -1, -1 },
{ 0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11, -1 }, { 0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5, -1 },
{ 9, 4, 5, 2, 11, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4, -1, -1, -1, -1 },
{ 5, 10, 2, 5, 2, 4, 4, 2, 0, -1, -1, -1, -1, -1, -1, -1 }, { 3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9, -1 },
{ 5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2, -1, -1, -1, -1 }, { 8, 4, 5, 8, 5, 3, 3, 5, 1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 4, 5, 1, 0, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5, -1, -1, -1, -1 },
{ 9, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 11, 7, 4, 9, 11, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11, -1, -1, -1, -1 }, { 1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11, -1, -1, -1, -1 },
{ 3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4, -1 }, { 4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2, -1, -1, -1, -1 },
{ 9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3, -1 }, { 11, 7, 4, 11, 4, 2, 2, 4, 0, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4, -1, -1, -1, -1 }, { 2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9, -1, -1, -1, -1 },
{ 9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7, -1 }, { 3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10, -1 },
{ 1, 10, 2, 8, 7, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 9, 1, 4, 1, 7, 7, 1, 3, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1, -1, -1, -1, -1 }, { 4, 0, 3, 7, 4, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 8, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 9, 3, 9, 11, 11, 9, 10, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 10, 0, 10, 8, 8, 10, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 1, 10, 11, 3, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 11, 1, 11, 9, 9, 11, 8, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9, -1, -1, -1, -1 }, { 0, 2, 11, 8, 0, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 3, 8, 2, 8, 10, 10, 8, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 10, 2, 0, 9, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8, -1, -1, -1, -1 },
{ 1, 10, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 3, 8, 9, 1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 9, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 3, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 } };
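The two tables above are indexed the same way in every marching-cubes kernel: the sign of the SDF at the 8 cube corners forms an 8-bit index, `edgeTable[index]` flags which of the 12 edges the surface crosses, and `triangleTable[index]` lists the edge ids of the emitted triangles. A Python sketch of the index/edge lookup (table abbreviated to the entries this example touches):

```python
# Illustration of the edgeTable lookup: corner 0 of the cube sets bit 0,
# corner 1 sets bit 1, etc. Entry 0x109 = 0b1_0000_1001 marks edges 0, 3
# and 8 as cut. (EDGE_TABLE holds just the two entries used here.)

EDGE_TABLE = {0: 0x000, 1: 0x109}

def cube_index(corner_sdf, iso=0.0):
    idx = 0
    for bit, value in enumerate(corner_sdf):
        if value < iso:  # corner is inside the surface
            idx |= 1 << bit
    return idx

# only corner 0 is inside -> index 1
idx = cube_index([-0.2, 0.1, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
cut_edges = [e for e in range(12) if EDGE_TABLE[idx] & (1 << e)]
# → edges [0, 3, 8]
```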


================================================
FILE: third_party/marching_cubes/src/mc_interp_kernel.cu
================================================
#include "mc_data.cuh"

#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>

__device__ static inline float2 query_sdf_raw(uint bx, uint by, uint bz, uint arx, uint ary, uint arz,
                                              const uint max_vec_num,
                                              const IndexerAccessor indexer,
                                              const CubeSDFAccessor cube_sdf,
                                              const CubeSDFAccessor cube_std,
                                              const BackwardMappingAccessor vec_batch_mapping)
{
    if (bx >= indexer.size(0) || by >= indexer.size(1) || bz >= indexer.size(2))
    {
        return make_float2(NAN, NAN);
    }
    //    printf("B-Getting: %d %d %d --> %d, %d, %d\n", bx, by, bz, indexer.size(0), indexer.size(1), indexer.size(2));
    long long vec_ind = indexer[bx][by][bz];
    if (vec_ind == -1 || vec_ind >= max_vec_num)
    {
        return make_float2(NAN, NAN);
    }
    int batch_ind = vec_batch_mapping[vec_ind];
    if (batch_ind == -1)
    {
        return make_float2(NAN, NAN);
    }
    //    printf("Getting: %d %d %d %d --> %d %d\n", batch_ind, arx, ary, arz, cube_sdf.size(0), cube_sdf.size(1));
    float sdf = cube_sdf[batch_ind][arx][ary][arz];
    float std = cube_std[batch_ind][arx][ary][arz];
    return make_float2(sdf, std);
}

// Use stddev to weight sdf value.
// #define STD_W_SDF

__device__ static inline float2 get_sdf(const uint3 bsize, const uint r, uint3 bpos, uint3 rpos, const uint max_vec_num,
                                        const IndexerAccessor indexer,
                                        const CubeSDFAccessor cube_sdf,
                                        const CubeSDFAccessor cube_std,
                                        const BackwardMappingAccessor vec_batch_mapping)
{
    if (bpos.x >= bsize.x)
    {
        bpos.x = bsize.x - 1;
        rpos.x = r - 1;
    }
    if (bpos.y >= bsize.y)
    {
        bpos.y = bsize.y - 1;
        rpos.y = r - 1;
    }
    if (bpos.z >= bsize.z)
    {
        bpos.z = bsize.z - 1;
        rpos.z = r - 1;
    }

    uint rbound = (r - 1) / 2;
    uint rstart = r / 2;
    float rmid = r / 2.0f;

    float w_xm, w_xp;
    int bxm, rxm, bxp, rxp;
    int zero_x;
    if (rpos.x <= rbound)
    {
        bxm = -1;
        rxm = r;
        bxp = 0;
        rxp = 0;
        w_xp = (float)rpos.x + rmid;
        w_xm = rmid - (float)rpos.x;
        zero_x = 1;
    }
    else
    {
        bxm = 0;
        rxm = 0;
        bxp = 1;
        rxp = -r;
        w_xp = (float)rpos.x - rmid;
        w_xm = rmid + r - (float)rpos.x;
        zero_x = 0;
    }
    w_xm /= r;
    w_xp /= r;

    float w_ym, w_yp;
    int bym, rym, byp, ryp;
    int zero_y;
    if (rpos.y <= rbound)
    {
        bym = -1;
        rym = r;
        byp = 0;
        ryp = 0;
        w_yp = (float)rpos.y + rmid;
        w_ym = rmid - (float)rpos.y;
        zero_y = 1;
    }
    else
    {
        bym = 0;
        rym = 0;
        byp = 1;
        ryp = -r;
        w_yp = (float)rpos.y - rmid;
        w_ym = rmid + r - (float)rpos.y;
        zero_y = 0;
    }
    w_ym /= r;
    w_yp /= r;

    float w_zm, w_zp;
    int bzm, rzm, bzp, rzp;
    int zero_z;
    if (rpos.z <= rbound)
    {
        bzm = -1;
        rzm = r;
        bzp = 0;
        rzp = 0;
        w_zp = (float)rpos.z + rmid;
        w_zm = rmid - (float)rpos.z;
        zero_z = 1;
    }
    else
    {
        bzm = 0;
        rzm = 0;
        bzp = 1;
        rzp = -r;
        w_zp = (float)rpos.z - rmid;
        w_zm = rmid + r - (float)rpos.z;
        zero_z = 0;
    }
    w_zm /= r;
    w_zp /= r;

    rpos.x += rstart;
    rpos.y += rstart;
    rpos.z += rstart;

    // printf("%u %u %u %d %d %d %d %d %d\n", rpos.x, rpos.y, rpos.z, rxm, rxp, rym, ryp, rzm, rzp);

    // Tri-linear interpolation of SDF values.
#ifndef STD_W_SDF
    float total_weight = 0.0;
#else
    float2 total_weight{0.0, 0.0};
#endif
    float2 total_sdf{0.0, 0.0};

    int zero_det = zero_x * 4 + zero_y * 2 + zero_z;

    float2 sdfmmm = query_sdf_raw(bpos.x + bxm, bpos.y + bym, bpos.z + bzm, rpos.x + rxm, rpos.y + rym, rpos.z + rzm,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wmmm = w_xm * w_ym * w_zm;
#ifndef STD_W_SDF
    if (!isnan(sdfmmm.x))
    {
        total_sdf += sdfmmm * wmmm;
        total_weight += wmmm;
    }
#else
    if (!isnan(sdfmmm.x))
    {
        total_sdf.x += sdfmmm.x * wmmm * sdfmmm.y;
        total_weight.x += wmmm * sdfmmm.y;
        total_sdf.y += wmmm * sdfmmm.y;
        total_weight.y += wmmm;
    }
#endif
    else if (zero_det == 0)
    {
        return make_float2(NAN, NAN);
    }

    float2 sdfmmp = query_sdf_raw(bpos.x + bxm, bpos.y + bym, bpos.z + bzp, rpos.x + rxm, rpos.y + rym, rpos.z + rzp,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wmmp = w_xm * w_ym * w_zp;
#ifndef STD_W_SDF
    if (!isnan(sdfmmp.x))
    {
        total_sdf += sdfmmp * wmmp;
        total_weight += wmmp;
    }
#else
    if (!isnan(sdfmmp.x))
    {
        total_sdf.x += sdfmmp.x * wmmp * sdfmmp.y;
        total_weight.x += wmmp * sdfmmp.y;
        total_sdf.y += wmmp * sdfmmp.y;
        total_weight.y += wmmp;
    }
#endif
    else if (zero_det == 1)
    {
        return make_float2(NAN, NAN);
    }

    float2 sdfmpm = query_sdf_raw(bpos.x + bxm, bpos.y + byp, bpos.z + bzm, rpos.x + rxm, rpos.y + ryp, rpos.z + rzm,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wmpm = w_xm * w_yp * w_zm;
#ifndef STD_W_SDF
    if (!isnan(sdfmpm.x))
    {
        total_sdf += sdfmpm * wmpm;
        total_weight += wmpm;
    }
#else
    if (!isnan(sdfmpm.x))
    {
        total_sdf.x += sdfmpm.x * wmpm * sdfmpm.y;
        total_weight.x += wmpm * sdfmpm.y;
        total_sdf.y += wmpm * sdfmpm.y;
        total_weight.y += wmpm;
    }
#endif
    else if (zero_det == 2)
    {
        return make_float2(NAN, NAN);
    }

    float2 sdfmpp = query_sdf_raw(bpos.x + bxm, bpos.y + byp, bpos.z + bzp, rpos.x + rxm, rpos.y + ryp, rpos.z + rzp,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wmpp = w_xm * w_yp * w_zp;
#ifndef STD_W_SDF
    if (!isnan(sdfmpp.x))
    {
        total_sdf += sdfmpp * wmpp;
        total_weight += wmpp;
    }
#else
    if (!isnan(sdfmpp.x))
    {
        total_sdf.x += sdfmpp.x * wmpp * sdfmpp.y;
        total_weight.x += wmpp * sdfmpp.y;
        total_sdf.y += wmpp * sdfmpp.y;
        total_weight.y += wmpp;
    }
#endif
    else if (zero_det == 3)
    {
        return make_float2(NAN, NAN);
    }

    float2 sdfpmm = query_sdf_raw(bpos.x + bxp, bpos.y + bym, bpos.z + bzm, rpos.x + rxp, rpos.y + rym, rpos.z + rzm,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wpmm = w_xp * w_ym * w_zm;
#ifndef STD_W_SDF
    if (!isnan(sdfpmm.x))
    {
        total_sdf += sdfpmm * wpmm;
        total_weight += wpmm;
    }
#else
    if (!isnan(sdfpmm.x))
    {
        total_sdf.x += sdfpmm.x * wpmm * sdfpmm.y;
        total_weight.x += wpmm * sdfpmm.y;
        total_sdf.y += wpmm * sdfpmm.y;
        total_weight.y += wpmm;
    }
#endif
    else if (zero_det == 4)
    {
        return make_float2(NAN, NAN);
    }

    float2 sdfpmp = query_sdf_raw(bpos.x + bxp, bpos.y + bym, bpos.z + bzp, rpos.x + rxp, rpos.y + rym, rpos.z + rzp,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wpmp = w_xp * w_ym * w_zp;
#ifndef STD_W_SDF
    if (!isnan(sdfpmp.x))
    {
        total_sdf += sdfpmp * wpmp;
        total_weight += wpmp;
    }
#else
    if (!isnan(sdfpmp.x))
    {
        total_sdf.x += sdfpmp.x * wpmp * sdfpmp.y;
        total_weight.x += wpmp * sdfpmp.y;
        total_sdf.y += wpmp * sdfpmp.y;
        total_weight.y += wpmp;
    }
#endif
    else if (zero_det == 5)
    {
        return make_float2(NAN, NAN);
    }

    float2 sdfppm = query_sdf_raw(bpos.x + bxp, bpos.y + byp, bpos.z + bzm, rpos.x + rxp, rpos.y + ryp, rpos.z + rzm,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wppm = w_xp * w_yp * w_zm;
#ifndef STD_W_SDF
    if (!isnan(sdfppm.x))
    {
        total_sdf += sdfppm * wppm;
        total_weight += wppm;
    }
#else
    if (!isnan(sdfppm.x))
    {
        total_sdf.x += sdfppm.x * wppm * sdfppm.y;
        total_weight.x += wppm * sdfppm.y;
        total_sdf.y += wppm * sdfppm.y;
        total_weight.y += wppm;
    }
#endif
    else if (zero_det == 6)
    {
        return make_float2(NAN, NAN);
    }

    float2 sdfppp = query_sdf_raw(bpos.x + bxp, bpos.y + byp, bpos.z + bzp, rpos.x + rxp, rpos.y + ryp, rpos.z + rzp,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wppp = w_xp * w_yp * w_zp;
#ifndef STD_W_SDF
    if (!isnan(sdfppp.x))
    {
        total_sdf += sdfppp * wppp;
        total_weight += wppp;
    }
#else
    if (!isnan(sdfppp.x))
    {
        total_sdf.x += sdfppp.x * wppp * sdfppp.y;
        total_weight.x += wppp * sdfppp.y;
        total_sdf.y += wppp * sdfppp.y;
        total_weight.y += wppp;
    }
#endif
    else if (zero_det == 7)
    {
        return make_float2(NAN, NAN);
    }

    // If total_weight is zero, the division yields NaN, which the caller treats as invalid.
    return total_sdf / total_weight;
}
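The accumulation above (default, non-`STD_W_SDF` path) amounts to trilinear interpolation that skips NaN (unallocated) corners and renormalizes by the surviving weight; the eight mandatory-corner early-outs via `zero_det` are omitted here. A minimal editor's sketch in Python, not part of the repository source:

```python
import math

def masked_trilerp(corner_sdfs, weights):
    """Trilinear blend that skips NaN (unallocated) corners and
    renormalizes by the total weight of the valid corners, mirroring
    the default (non-STD_W_SDF) accumulation in get_sdf."""
    total_sdf = 0.0
    total_w = 0.0
    for s, w in zip(corner_sdfs, weights):
        if not math.isnan(s):
            total_sdf += s * w
            total_w += w
    # All corners missing: propagate NaN, as the kernel does via 0/0.
    return total_sdf / total_w if total_w > 0.0 else float("nan")
```

With one corner missing, the remaining seven still produce a well-defined value because the weights renormalize.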

__device__ static inline float4 sdf_interp(const float3 p1, const float3 p2, const float stdp1, const float stdp2,
                                           float valp1, float valp2)
{
    if (fabs(0.0f - valp1) < 1.0e-5f)
        return make_float4(p1, stdp1);
    if (fabs(0.0f - valp2) < 1.0e-5f)
        return make_float4(p2, stdp2);
    if (fabs(valp1 - valp2) < 1.0e-5f)
        return make_float4(p1, stdp1);

    float w2 = (0.0f - valp1) / (valp2 - valp1);
    float w1 = 1 - w2;

    return make_float4(p1.x * w1 + p2.x * w2,
                       p1.y * w1 + p2.y * w2,
                       p1.z * w1 + p2.z * w2,
                       stdp1 * w1 + stdp2 * w2);
}
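The `sdf_interp` helper above locates the zero crossing of the SDF on a cube edge by inverse linear interpolation, carrying the stddev along with the position. An editor's Python sketch of the same logic (illustrative only, not repository source):

```python
def sdf_interp(p1, p2, std1, std2, v1, v2, eps=1e-5):
    """Locate the zero crossing of the SDF on the segment p1-p2,
    linearly interpolating both position and stddev."""
    if abs(v1) < eps:          # surface passes through p1
        return (*p1, std1)
    if abs(v2) < eps:          # surface passes through p2
        return (*p2, std2)
    if abs(v1 - v2) < eps:     # degenerate edge: no reliable crossing
        return (*p1, std1)
    w2 = (0.0 - v1) / (v2 - v1)   # fraction of the way from p1 to p2
    w1 = 1.0 - w2
    pos = tuple(a * w1 + b * w2 for a, b in zip(p1, p2))
    return (*pos, std1 * w1 + std2 * w2)
```

For values -1 and +1 at the endpoints the crossing lands exactly at the midpoint, with the stddev averaged.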

__global__ static void meshing_cube(const IndexerAccessor indexer,
                                    const ValidBlocksAccessor valid_blocks,
                                    const BackwardMappingAccessor vec_batch_mapping,
                                    const CubeSDFAccessor cube_sdf,
                                    const CubeSDFAccessor cube_std,
                                    TrianglesAccessor triangles,
                                    TriangleStdAccessor triangle_std,
                                    TriangleVecIdAccessor triangle_flatten_id,
                                    int *__restrict__ triangles_count,
                                    int max_triangles_count,
                                    const uint max_vec_num,
                                    int nx, int ny, int nz,
                                    float max_std)
{
    const uint r = cube_sdf.size(1) / 2;
    const uint r3 = r * r * r;
    const uint num_lif = valid_blocks.size(0);
    const float sbs = 1.0f / r; // sub-block-size

    const uint lif_id = blockIdx.x * blockDim.x + threadIdx.x;
    const uint sub_id = blockIdx.y * blockDim.y + threadIdx.y;

    if (lif_id >= num_lif || sub_id >= r3)
    {
        return;
    }

    const uint3 bpos = make_uint3(
        (valid_blocks[lif_id] / (ny * nz)) % nx,
        (valid_blocks[lif_id] / nz) % ny,
        valid_blocks[lif_id] % nz);
    const uint3 bsize = make_uint3(indexer.size(0), indexer.size(1), indexer.size(2));
    const uint rx = sub_id / (r * r);
    const uint ry = (sub_id / r) % r;
    const uint rz = sub_id % r;

    // Find all 8 neighbours
    float3 points[8];
    float2 sdf_vals[8];

    sdf_vals[0] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[0].x))
        return;
    points[0] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);

    sdf_vals[1] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[1].x))
        return;
    points[1] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);

    sdf_vals[2] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[2].x))
        return;
    points[2] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);

    sdf_vals[3] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[3].x))
        return;
    points[3] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);

    sdf_vals[4] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[4].x))
        return;
    points[4] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);

    sdf_vals[5] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[5].x))
        return;
    points[5] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);

    sdf_vals[6] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[6].x))
        return;
    points[6] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);

    sdf_vals[7] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[7].x))
        return;
    points[7] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);

    // Find triangle config.
    int cube_type = 0;
    if (sdf_vals[0].x < 0)
        cube_type |= 1;
    if (sdf_vals[1].x < 0)
        cube_type |= 2;
    if (sdf_vals[2].x < 0)
        cube_type |= 4;
    if (sdf_vals[3].x < 0)
        cube_type |= 8;
    if (sdf_vals[4].x < 0)
        cube_type |= 16;
    if (sdf_vals[5].x < 0)
        cube_type |= 32;
    if (sdf_vals[6].x < 0)
        cube_type |= 64;
    if (sdf_vals[7].x < 0)
        cube_type |= 128;

    // Find vertex position on each edge (weighted by sdf value)
    int edge_config = edgeTable[cube_type];
    float4 vert_list[12];

    if (edge_config == 0)
        return;
    if (edge_config & 1)
        vert_list[0] = sdf_interp(points[0], points[1], sdf_vals[0].y, sdf_vals[1].y, sdf_vals[0].x, sdf_vals[1].x);
    if (edge_config & 2)
        vert_list[1] = sdf_interp(points[1], points[2], sdf_vals[1].y, sdf_vals[2].y, sdf_vals[1].x, sdf_vals[2].x);
    if (edge_config & 4)
        vert_list[2] = sdf_interp(points[2], points[3], sdf_vals[2].y, sdf_vals[3].y, sdf_vals[2].x, sdf_vals[3].x);
    if (edge_config & 8)
        vert_list[3] = sdf_interp(points[3], points[0], sdf_vals[3].y, sdf_vals[0].y, sdf_vals[3].x, sdf_vals[0].x);
    if (edge_config & 16)
        vert_list[4] = sdf_interp(points[4], points[5], sdf_vals[4].y, sdf_vals[5].y, sdf_vals[4].x, sdf_vals[5].x);
    if (edge_config & 32)
        vert_list[5] = sdf_interp(points[5], points[6], sdf_vals[5].y, sdf_vals[6].y, sdf_vals[5].x, sdf_vals[6].x);
    if (edge_config & 64)
        vert_list[6] = sdf_interp(points[6], points[7], sdf_vals[6].y, sdf_vals[7].y, sdf_vals[6].x, sdf_vals[7].x);
    if (edge_config & 128)
        vert_list[7] = sdf_interp(points[7], points[4], sdf_vals[7].y, sdf_vals[4].y, sdf_vals[7].x, sdf_vals[4].x);
    if (edge_config & 256)
        vert_list[8] = sdf_interp(points[0], points[4], sdf_vals[0].y, sdf_vals[4].y, sdf_vals[0].x, sdf_vals[4].x);
    if (edge_config & 512)
        vert_list[9] = sdf_interp(points[1], points[5], sdf_vals[1].y, sdf_vals[5].y, sdf_vals[1].x, sdf_vals[5].x);
    if (edge_config & 1024)
        vert_list[10] = sdf_interp(points[2], points[6], sdf_vals[2].y, sdf_vals[6].y, sdf_vals[2].x, sdf_vals[6].x);
    if (edge_config & 2048)
        vert_list[11] = sdf_interp(points[3], points[7], sdf_vals[3].y, sdf_vals[7].y, sdf_vals[3].x, sdf_vals[7].x);

    // Write triangles to array.
    float4 vp[3];
    for (int i = 0; triangleTable[cube_type][i] != -1; i += 3)
    {
#pragma unroll
        for (int vi = 0; vi < 3; ++vi)
        {
            vp[vi] = vert_list[triangleTable[cube_type][i + vi]];
        }
        if (vp[0].w > max_std || vp[1].w > max_std || vp[2].w > max_std)
        {
            continue;
        }
        int triangle_id = atomicAdd(triangles_count, 1);
        if (triangle_id < max_triangles_count)
        {
#pragma unroll
            for (int vi = 0; vi < 3; ++vi)
            {
                triangles[triangle_id][vi][0] = vp[vi].x;
                triangles[triangle_id][vi][1] = vp[vi].y;
                triangles[triangle_id][vi][2] = vp[vi].z;
                triangle_std[triangle_id][vi] = vp[vi].w;
            }
            triangle_flatten_id[triangle_id] = valid_blocks[lif_id];
        }
    }
}
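The `cube_type` bitmask built in the kernel above is the standard 8-bit marching-cubes case index: corner i contributes bit i when it lies inside the surface (sdf < 0), and the result selects rows of `edgeTable`/`triangleTable`. An editor's sketch of that classification, not part of the repository:

```python
def cube_type_from_sdf(corner_sdfs):
    """Build the 8-bit marching-cubes case index: bit i is set when
    corner i is inside the surface (sdf < 0)."""
    t = 0
    for i, s in enumerate(corner_sdfs):
        if s < 0:
            t |= 1 << i
    return t
```

Cases 0 (all outside) and 255 (all inside) have an empty `edgeTable` entry and emit no triangles.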

std::vector<torch::Tensor> marching_cubes_sparse_interp_cuda(
    torch::Tensor indexer,           // (nx, ny, nz) -> data_id
    torch::Tensor valid_blocks,      // (K, )
    torch::Tensor vec_batch_mapping, // (max_vec_num, ) -> batch_id, or -1 if absent from the current batch
    torch::Tensor cube_sdf,          // (M, rx, ry, rz)
    torch::Tensor cube_std,          // (M, rx, ry, rz)
    const std::vector<int> &n_xyz,   // [nx, ny, nz]
    float max_std,                   // Prune triangles whose vertex stddev exceeds this
    int max_n_triangles              // Maximum number of triangle buffer
)
{
    CHECK_INPUT(indexer);
    CHECK_INPUT(valid_blocks);
    CHECK_INPUT(cube_sdf);
    CHECK_INPUT(cube_std);
    CHECK_INPUT(vec_batch_mapping);
    assert(max_n_triangles > 0);

    const int r = cube_sdf.size(1) / 2;
    const int r3 = r * r * r;
    const int num_lif = valid_blocks.size(0);
    const uint max_vec_num = vec_batch_mapping.size(0);

    torch::Tensor triangles = torch::empty({max_n_triangles, 3, 3},
                                           torch::dtype(torch::kFloat32).device(torch::kCUDA));
    torch::Tensor triangle_flatten_id = torch::empty({max_n_triangles}, torch::dtype(torch::kLong).device(torch::kCUDA));
    torch::Tensor triangle_std = torch::empty({max_n_triangles, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA));

    dim3 dimBlock = dim3(16, 16);
    uint xBlocks = (num_lif + dimBlock.x - 1) / dimBlock.x;
    uint yBlocks = (r3 + dimBlock.y - 1) / dimBlock.y;
    dim3 dimGrid = dim3(xBlocks, yBlocks);

    thrust::device_vector<int> n_output(1, 0);
    meshing_cube<<<dimGrid, dimBlock, 0, at::cuda::getCurrentCUDAStream()>>>(
        indexer.packed_accessor32<int64_t, 3, torch::RestrictPtrTraits>(),
        valid_blocks.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
        vec_batch_mapping.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
        cube_sdf.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        cube_std.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        triangles.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        triangle_std.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
        triangle_flatten_id.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
        n_output.data().get(), max_n_triangles, max_vec_num,
        n_xyz[0], n_xyz[1], n_xyz[2], max_std);
    cudaStreamSynchronize(at::cuda::getCurrentCUDAStream());

    int output_n_triangles = n_output[0];
    if (output_n_triangles < max_n_triangles)
    {
        // Trim output tensor if it is not full.
        triangles = triangles.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
        triangle_flatten_id = triangle_flatten_id.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
        triangle_std = triangle_std.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
    }
    else
    {
        // Otherwise emit a warning that the triangle buffer overflowed.
        std::cerr << "Warning from marching cubes: triangle buffer too small: generated " << output_n_triangles << " vs capacity " << max_n_triangles << std::endl;
    }

    return {triangles, triangle_flatten_id, triangle_std};
}


================================================
FILE: third_party/marching_cubes/src/mc_kernel.cu
================================================
#include "mc_data.cuh"

#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>

__device__ static inline float2 query_sdf_raw(uint bx, uint by, uint bz, uint arx, uint ary, uint arz,
                                              const uint max_vec_num,
                                              const IndexerAccessor indexer,
                                              const CubeSDFAccessor cube_sdf,
                                              const CubeSDFAccessor cube_std,
                                              const BackwardMappingAccessor vec_batch_mapping)
{
    if (bx >= indexer.size(0) || by >= indexer.size(1) || bz >= indexer.size(2))
    {
        return make_float2(NAN, NAN);
    }
    //    printf("B-Getting: %d %d %d --> %d, %d, %d\n", bx, by, bz, indexer.size(0), indexer.size(1), indexer.size(2));
    long long vec_ind = indexer[bx][by][bz];
    if (vec_ind == -1 || vec_ind >= max_vec_num)
    {
        return make_float2(NAN, NAN);
    }
    int batch_ind = vec_batch_mapping[vec_ind];
    if (batch_ind == -1)
    {
        return make_float2(NAN, NAN);
    }
    //    printf("Getting: %d %d %d %d --> %d %d\n", batch_ind, arx, ary, arz, cube_sdf.size(0), cube_sdf.size(1));
    float sdf = cube_sdf[batch_ind][arx][ary][arz];
    float std = cube_std[batch_ind][arx][ary][arz];
    return make_float2(sdf, std);
}

// Use stddev to weight sdf value.
// #define STD_W_SDF

__device__ static inline float2 get_sdf(const uint3 bsize, const uint r, uint3 bpos, uint3 rpos, const uint max_vec_num,
                                        const IndexerAccessor indexer,
                                        const CubeSDFAccessor cube_sdf,
                                        const CubeSDFAccessor cube_std,
                                        const BackwardMappingAccessor vec_batch_mapping)
{
    if (bpos.x >= bsize.x)
    {
        bpos.x = bsize.x - 1;
        rpos.x = r - 1;
    }
    if (bpos.y >= bsize.y)
    {
        bpos.y = bsize.y - 1;
        rpos.y = r - 1;
    }
    if (bpos.z >= bsize.z)
    {
        bpos.z = bsize.z - 1;
        rpos.z = r - 1;
    }

    if (rpos.x == r)
    {
        bpos.x += 1;
        rpos.x = 0;
    }

    if (rpos.y == r)
    {
        bpos.y += 1;
        rpos.y = 0;
    }

    if (rpos.z == r)
    {
        bpos.z += 1;
        rpos.z = 0;
    }

    float2 total_sdf = query_sdf_raw(bpos.x, bpos.y, bpos.z, rpos.x, rpos.y, rpos.z, max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);

    // NaN (unallocated voxel) is simply propagated to the caller.
    return total_sdf;
}
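The clamp-then-carry logic in `get_sdf` above first pulls out-of-range block coordinates back to the last block, then carries a full sub-index (`rpos == r`) into the next block. An editor's Python model of that normalization, names hypothetical:

```python
def canonicalize(bpos, rpos, bsize, r):
    """Clamp block coords into range, then carry a full sub-index
    (rpos == r) into the neighbouring block, as in get_sdf."""
    b, p = list(bpos), list(rpos)
    for i in range(3):
        if b[i] >= bsize[i]:        # out-of-range block: snap to last cell
            b[i] = bsize[i] - 1
            p[i] = r - 1
        if p[i] == r:               # carry into the next block
            b[i] += 1
            p[i] = 0
    return tuple(b), tuple(p)
```

If the carry pushes `bpos` past the grid, the subsequent bounds check in `query_sdf_raw` returns NaN rather than reading out of bounds.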

__device__ static inline float4 sdf_interp(const float3 p1, const float3 p2, const float stdp1, const float stdp2,
                                           float valp1, float valp2)
{
    if (fabs(0.0f - valp1) < 1.0e-5f)
        return make_float4(p1, stdp1);
    if (fabs(0.0f - valp2) < 1.0e-5f)
        return make_float4(p2, stdp2);
    if (fabs(valp1 - valp2) < 1.0e-5f)
        return make_float4(p1, stdp1);

    float w2 = (0.0f - valp1) / (valp2 - valp1);
    float w1 = 1 - w2;

    return make_float4(p1.x * w1 + p2.x * w2,
                       p1.y * w1 + p2.y * w2,
                       p1.z * w1 + p2.z * w2,
                       stdp1 * w1 + stdp2 * w2);
}

__global__ static void meshing_cube(const IndexerAccessor indexer,
                                    const ValidBlocksAccessor valid_blocks,
                                    const BackwardMappingAccessor vec_batch_mapping,
                                    const CubeSDFAccessor cube_sdf,
                                    const CubeSDFAccessor cube_std,
                                    TrianglesAccessor triangles,
                                    TriangleStdAccessor triangle_std,
                                    TriangleVecIdAccessor triangle_flatten_id,
                                    int *__restrict__ triangles_count,
                                    int max_triangles_count,
                                    const uint max_vec_num,
                                    int nx, int ny, int nz,
                                    float max_std)
{
    const uint r = cube_sdf.size(1);
    const uint r3 = r * r * r;
    const uint num_lif = valid_blocks.size(0);
    const float sbs = 1.0f / r; // sub-block-size

    const uint lif_id = blockIdx.x * blockDim.x + threadIdx.x;
    const uint sub_id = blockIdx.y * blockDim.y + threadIdx.y;

    if (lif_id >= num_lif || sub_id >= r3)
    {
        return;
    }

    const uint3 bpos = make_uint3(
        (valid_blocks[lif_id] / (ny * nz)) % nx,
        (valid_blocks[lif_id] / nz) % ny,
        valid_blocks[lif_id] % nz);
    const uint3 bsize = make_uint3(indexer.size(0), indexer.size(1), indexer.size(2));
    const uint rx = sub_id / (r * r);
    const uint ry = (sub_id / r) % r;
    const uint rz = sub_id % r;

    // Find all 8 neighbours
    float3 points[8];
    float2 sdf_vals[8];

    sdf_vals[0] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[0].x))
        return;
    points[0] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);

    sdf_vals[1] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[1].x))
        return;
    points[1] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);

    sdf_vals[2] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[2].x))
        return;
    points[2] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);

    sdf_vals[3] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[3].x))
        return;
    points[3] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);

    sdf_vals[4] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[4].x))
        return;
    points[4] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);

    sdf_vals[5] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[5].x))
        return;
    points[5] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);

    sdf_vals[6] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[6].x))
        return;
    points[6] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);

    sdf_vals[7] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[7].x))
        return;
    points[7] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);

    // Find triangle config.
    int cube_type = 0;
    if (sdf_vals[0].x < 0)
        cube_type |= 1;
    if (sdf_vals[1].x < 0)
        cube_type |= 2;
    if (sdf_vals[2].x < 0)
        cube_type |= 4;
    if (sdf_vals[3].x < 0)
        cube_type |= 8;
    if (sdf_vals[4].x < 0)
        cube_type |= 16;
    if (sdf_vals[5].x < 0)
        cube_type |= 32;
    if (sdf_vals[6].x < 0)
        cube_type |= 64;
    if (sdf_vals[7].x < 0)
        cube_type |= 128;

    // Find vertex position on each edge (weighted by sdf value)
    int edge_config = edgeTable[cube_type];
    float4 vert_list[12];

    if (edge_config == 0)
        return;
    if (edge_config & 1)
        vert_list[0] = sdf_interp(points[0], points[1], sdf_vals[0].y, sdf_vals[1].y, sdf_vals[0].x, sdf_vals[1].x);
    if (edge_config & 2)
        vert_list[1] = sdf_interp(points[1], points[2], sdf_vals[1].y, sdf_vals[2].y, sdf_vals[1].x, sdf_vals[2].x);
    if (edge_config & 4)
        vert_list[2] = sdf_interp(points[2], points[3], sdf_vals[2].y, sdf_vals[3].y, sdf_vals[2].x, sdf_vals[3].x);
    if (edge_config & 8)
        vert_list[3] = sdf_interp(points[3], points[0], sdf_vals[3].y, sdf_vals[0].y, sdf_vals[3].x, sdf_vals[0].x);
    if (edge_config & 16)
        vert_list[4] = sdf_interp(points[4], points[5], sdf_vals[4].y, sdf_vals[5].y, sdf_vals[4].x, sdf_vals[5].x);
    if (edge_config & 32)
        vert_list[5] = sdf_interp(points[5], points[6], sdf_vals[5].y, sdf_vals[6].y, sdf_vals[5].x, sdf_vals[6].x);
    if (edge_config & 64)
        vert_list[6] = sdf_interp(points[6], points[7], sdf_vals[6].y, sdf_vals[7].y, sdf_vals[6].x, sdf_vals[7].x);
    if (edge_config & 128)
        vert_list[7] = sdf_interp(points[7], points[4], sdf_vals[7].y, sdf_vals[4].y, sdf_vals[7].x, sdf_vals[4].x);
    if (edge_config & 256)
        vert_list[8] = sdf_interp(points[0], points[4], sdf_vals[0].y, sdf_vals[4].y, sdf_vals[0].x, sdf_vals[4].x);
    if (edge_config & 512)
        vert_list[9] = sdf_interp(points[1], points[5], sdf_vals[1].y, sdf_vals[5].y, sdf_vals[1].x, sdf_vals[5].x);
    if (edge_config & 1024)
        vert_list[10] = sdf_interp(points[2], points[6], sdf_vals[2].y, sdf_vals[6].y, sdf_vals[2].x, sdf_vals[6].x);
    if (edge_config & 2048)
        vert_list[11] = sdf_interp(points[3], points[7], sdf_vals[3].y, sdf_vals[7].y, sdf_vals[3].x, sdf_vals[7].x);

    // Write triangles to array.
    float4 vp[3];
    for (int i = 0; triangleTable[cube_type][i] != -1; i += 3)
    {
#pragma unroll
        for (int vi = 0; vi < 3; ++vi)
        {
            vp[vi] = vert_list[triangleTable[cube_type][i + vi]];
        }

        int triangle_id = atomicAdd(triangles_count, 1);
        if (triangle_id < max_triangles_count)
        {
#pragma unroll
            for (int vi = 0; vi < 3; ++vi)
            {
                triangles[triangle_id][vi][0] = vp[vi].x;
                triangles[triangle_id][vi][1] = vp[vi].y;
                triangles[triangle_id][vi][2] = vp[vi].z;
                triangle_std[triangle_id][vi] = vp[vi].w;
            }
            triangle_flatten_id[triangle_id] = valid_blocks[lif_id];
        }
    }
}

std::vector<torch::Tensor> marching_cubes_sparse(
    torch::Tensor indexer,           // (nx, ny, nz) -> data_id
    torch::Tensor valid_blocks,      // (K, )
    torch::Tensor vec_batch_mapping, // (max_vec_num, ) -> batch_id, or -1 if absent from the current batch
    torch::Tensor cube_sdf,          // (M, rx, ry, rz)
    torch::Tensor cube_std,          // (M, rx, ry, rz)
    const std::vector<int> &n_xyz,   // [nx, ny, nz]
    float max_std,                   // Stddev pruning threshold (not applied in this kernel's write loop)
    int max_n_triangles              // Maximum number of triangle buffer
)
{
    CHECK_INPUT(indexer);
    CHECK_INPUT(valid_blocks);
    CHECK_INPUT(cube_sdf);
    CHECK_INPUT(cube_std);
    CHECK_INPUT(vec_batch_mapping);
    assert(max_n_triangles > 0);

    const int r = cube_sdf.size(1);
    const int r3 = r * r * r;
    const int num_lif = valid_blocks.size(0);
    const uint max_vec_num = vec_batch_mapping.size(0);

    torch::Tensor triangles = torch::empty({max_n_triangles, 3, 3},
                                           torch::dtype(torch::kFloat32).device(torch::kCUDA));
    torch::Tensor triangle_flatten_id = torch::empty({max_n_triangles}, torch::dtype(torch::kLong).device(torch::kCUDA));
    torch::Tensor triangle_std = torch::empty({max_n_triangles, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA));

    dim3 dimBlock = dim3(16, 16);
    uint xBlocks = (num_lif + dimBlock.x - 1) / dimBlock.x;
    uint yBlocks = (r3 + dimBlock.y - 1) / dimBlock.y;
    dim3 dimGrid = dim3(xBlocks, yBlocks);

    thrust::device_vector<int> n_output(1, 0);
    meshing_cube<<<dimGrid, dimBlock, 0, at::cuda::getCurrentCUDAStream()>>>(
        indexer.packed_accessor32<int64_t, 3, torch::RestrictPtrTraits>(),
        valid_blocks.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
        vec_batch_mapping.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
        cube_sdf.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        cube_std.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        triangles.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        triangle_std.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
        triangle_flatten_id.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
        n_output.data().get(), max_n_triangles, max_vec_num,
        n_xyz[0], n_xyz[1], n_xyz[2], max_std);
    cudaStreamSynchronize(at::cuda::getCurrentCUDAStream());

    int output_n_triangles = n_output[0];
    if (output_n_triangles < max_n_triangles)
    {
        // Trim output tensor if it is not full.
        triangles = triangles.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
        triangle_flatten_id = triangle_flatten_id.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
        triangle_std = triangle_std.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
    }
    else
    {
        // Otherwise emit a warning that the triangle buffer overflowed.
        std::cerr << "Warning from marching cubes: triangle buffer too small: generated " << output_n_triangles << " vs capacity " << max_n_triangles << std::endl;
    }

    return {triangles, triangle_flatten_id, triangle_std};
}
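Both kernels recover the block coordinate `bpos` from a row-major flattened id stored in `valid_blocks` (x-major over an nx×ny×nz grid). An editor's sketch of the flatten/decode pair, not part of the repository source:

```python
def flatten_block_id(x, y, z, nx, ny, nz):
    """Row-major flatten used by the indexer: id = (x*ny + y)*nz + z."""
    return (x * ny + y) * nz + z

def unflatten_block_id(fid, nx, ny, nz):
    """Inverse decode as done in meshing_cube:
    x = (fid // (ny*nz)) % nx, y = (fid // nz) % ny, z = fid % nz."""
    return ((fid // (ny * nz)) % nx, (fid // nz) % ny, fid % nz)
```

The `% nx` on the x component is a no-op for ids produced by the flatten above; it only matters if an id exceeds the grid extent.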


================================================
FILE: third_party/marching_cubes/src/mc_kernel_colour.cu
================================================
#include "mc_data.cuh"

#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>

__device__ static inline float4 query_sdf_raw(uint bx, uint by, uint bz, uint arx, uint ary, uint arz,
                                              const uint max_vec_num,
                                              const IndexerAccessor indexer,
                                              const CubeSDFRGBAccessor cube_sdf,
                                              const CubeSDFAccessor cube_std,
                                              const BackwardMappingAccessor vec_batch_mapping)
{
    if (bx >= indexer.size(0) || by >= indexer.size(1) || bz >= indexer.size(2))
    {
        return make_float4(NAN, NAN, NAN, NAN);
    }
    //    printf("B-Getting: %d %d %d --> %d, %d, %d\n", bx, by, bz, indexer.size(0), indexer.size(1), indexer.size(2));
    long long vec_ind = indexer[bx][by][bz];
    if (vec_ind == -1 || vec_ind >= max_vec_num)
    {
        return make_float4(NAN, NAN, NAN, NAN);
    }
    int batch_ind = vec_batch_mapping[vec_ind];
    if (batch_ind == -1)
    {
        return make_float4(NAN, NAN, NAN, NAN);
    }
    //    printf("Getting: %d %d %d %d --> %d %d\n", batch_ind, arx, ary, arz, cube_sdf.size(0), cube_sdf.size(1));
    return make_float4(cube_sdf[batch_ind][arx][ary][arz][3],
                       cube_sdf[batch_ind][arx][ary][arz][0],
                       cube_sdf[batch_ind][arx][ary][arz][1],
                       cube_sdf[batch_ind][arx][ary][arz][2]);
}

// Use stddev to weight sdf value.
// #define STD_W_SDF

__device__ static inline float4 get_sdf(const uint3 bsize, const uint r, uint3 bpos, uint3 rpos, const uint max_vec_num,
                                        const IndexerAccessor indexer,
                                        const CubeSDFRGBAccessor cube_sdf,
                                        const CubeSDFAccessor cube_std,
                                        const BackwardMappingAccessor vec_batch_mapping)
{
    if (bpos.x >= bsize.x)
    {
        bpos.x = bsize.x - 1;
        rpos.x = r - 1;
    }
    if (bpos.y >= bsize.y)
    {
        bpos.y = bsize.y - 1;
        rpos.y = r - 1;
    }
    if (bpos.z >= bsize.z)
    {
        bpos.z = bsize.z - 1;
        rpos.z = r - 1;
    }

    if (rpos.x == r)
    {
        bpos.x += 1;
        rpos.x = 0;
    }

    if (rpos.y == r)
    {
        bpos.y += 1;
        rpos.y = 0;
    }

    if (rpos.z == r)
    {
        bpos.z += 1;
        rpos.z = 0;
    }

    float4 total_sdf = query_sdf_raw(bpos.x, bpos.y, bpos.z, rpos.x, rpos.y, rpos.z, max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);

    // A NaN SDF is propagated to the caller, which then skips the cube.
    return total_sdf;
}
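Per axis, `get_sdf` first clamps out-of-range block coordinates to the last voxel, then carries a corner that lands exactly on a block's far face (`rpos == r`) into sub-voxel 0 of the next block. A one-axis CPU sketch of that carry (hypothetical helper name):

```cpp
#include <cassert>
#include <utility>

// Carry an overflowing sub-voxel index into the next block along one axis:
// a corner at rpos == r is the same point as sub-voxel 0 of block bpos + 1.
std::pair<unsigned, unsigned> carry_axis(unsigned bpos, unsigned rpos, unsigned r) {
    if (rpos == r) {
        bpos += 1;
        rpos = 0;
    }
    return {bpos, rpos};
}
```

If the carry steps past the last block at the volume boundary, the bounds check at the top of `query_sdf_raw` catches it and returns NaN.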

__device__ static inline float3 sdf_interp(const float3 p1, const float3 p2,
                                           float valp1, float valp2)
{
    if (fabs(0.0f - valp1) < 1.0e-5f)
        return p1;
    if (fabs(0.0f - valp2) < 1.0e-5f)
        return p2;
    if (fabs(valp1 - valp2) < 1.0e-5f)
        return p1;

    float w2 = (0.0f - valp1) / (valp2 - valp1);
    float w1 = 1 - w2;

    return make_float3(p1.x * w1 + p2.x * w2,
                       p1.y * w1 + p2.y * w2,
                       p1.z * w1 + p2.z * w2);
}
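`sdf_interp` places a mesh vertex at the linear zero crossing of the SDF along a cube edge, and the kernel reuses the same weights to blend per-corner colours. A scalar CPU sketch with the same guards (hypothetical helper name):

```cpp
#include <cassert>
#include <cmath>

// Linear zero-crossing interpolation along one coordinate, mirroring
// sdf_interp's guards: snap to an endpoint whose value is already ~0,
// and avoid dividing by a near-zero value difference.
float interp_zero(float p1, float p2, float valp1, float valp2) {
    if (std::fabs(valp1) < 1.0e-5f) return p1;
    if (std::fabs(valp2) < 1.0e-5f) return p2;
    if (std::fabs(valp1 - valp2) < 1.0e-5f) return p1;
    float w2 = (0.0f - valp1) / (valp2 - valp1);
    return p1 * (1.0f - w2) + p2 * w2;
}
```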

__global__ static void meshing_cube_colour(const IndexerAccessor indexer,
                                           const ValidBlocksAccessor valid_blocks,
                                           const BackwardMappingAccessor vec_batch_mapping,
                                           const CubeSDFRGBAccessor cube_sdf,
                                           const CubeSDFAccessor cube_std,
                                           TrianglesAccessor triangles,
                                           TrianglesAccessor vertex_colours,
                                           TriangleVecIdAccessor triangle_flatten_id,
                                           int *__restrict__ triangles_count,
                                           int max_triangles_count,
                                           const uint max_vec_num,
                                           int nx, int ny, int nz,
                                           float max_std)
{
    const uint r = cube_sdf.size(1);
    const uint r3 = r * r * r;
    const uint num_lif = valid_blocks.size(0);
    const float sbs = 1.0f / r; // sub-block-size

    const uint lif_id = blockIdx.x * blockDim.x + threadIdx.x;
    const uint sub_id = blockIdx.y * blockDim.y + threadIdx.y;

    if (lif_id >= num_lif || sub_id >= r3)
    {
        return;
    }

    const uint3 bpos = make_uint3(
        (valid_blocks[lif_id] / (ny * nz)) % nx,
        (valid_blocks[lif_id] / nz) % ny,
        valid_blocks[lif_id] % nz);
    const uint3 bsize = make_uint3(indexer.size(0), indexer.size(1), indexer.size(2));
    const uint rx = sub_id / (r * r);
    const uint ry = (sub_id / r) % r;
    const uint rz = sub_id % r;

    // Sample SDF and colour at the 8 corners of this cube.
    float3 points[8];
    float3 colours[8];
    float sdf_vals[8];

    float4 sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_val.x))
        return;
    sdf_vals[0] = sdf_val.x;
    points[0] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);
    colours[0] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);

    sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_val.x))
        return;
    sdf_vals[1] = sdf_val.x;
    points[1] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);
    colours[1] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);

    sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_val.x))
        return;
    sdf_vals[2] = sdf_val.x;
    points[2] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);
    colours[2] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);

    sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_val.x))
        return;
    sdf_vals[3] = sdf_val.x;
    points[3] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);
    colours[3] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);

    sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_val.x))
        return;
    sdf_vals[4] = sdf_val.x;
    points[4] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);
    colours[4] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);

    sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_val.x))
        return;
    sdf_vals[5] = sdf_val.x;
    points[5] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);
    colours[5] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);

    sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_val.x))
        return;
    sdf_vals[6] = sdf_val.x;
    points[6] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);
    colours[6] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);

    sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_val.x))
        return;
    sdf_vals[7] = sdf_val.x;
    points[7] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);
    colours[7] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);

    // Find triangle config.
    int cube_type = 0;
    if (sdf_vals[0] < 0)
        cube_type |= 1;
    if (sdf_vals[1] < 0)
        cube_type |= 2;
    if (sdf_vals[2] < 0)
        cube_type |= 4;
    if (sdf_vals[3] < 0)
        cube_type |= 8;
    if (sdf_vals[4] < 0)
        cube_type |= 16;
    if (sdf_vals[5] < 0)
        cube_type |= 32;
    if (sdf_vals[6] < 0)
        cube_type |= 64;
    if (sdf_vals[7] < 0)
        cube_type |= 128;

    // Find vertex position on each edge (weighted by sdf value)
    int edge_config = edgeTable[cube_type];
    float3 vert_list[12];
    float3 rgb_list[12];

    if (edge_config == 0)
        return;
    if (edge_config & 1)
    {
        vert_list[0] = sdf_interp(points[0], points[1], sdf_vals[0], sdf_vals[1]);
        rgb_list[0] = sdf_interp(colours[0], colours[1], sdf_vals[0], sdf_vals[1]);
    }
    if (edge_config & 2)
    {
        vert_list[1] = sdf_interp(points[1], points[2], sdf_vals[1], sdf_vals[2]);
        rgb_list[1] = sdf_interp(colours[1], colours[2], sdf_vals[1], sdf_vals[2]);
    }
    if (edge_config & 4)
    {
        vert_list[2] = sdf_interp(points[2], points[3], sdf_vals[2], sdf_vals[3]);
        rgb_list[2] = sdf_interp(colours[2], colours[3], sdf_vals[2], sdf_vals[3]);
    }
    if (edge_config & 8)
    {
        vert_list[3] = sdf_interp(points[3], points[0], sdf_vals[3], sdf_vals[0]);
        rgb_list[3] = sdf_interp(colours[3], colours[0], sdf_vals[3], sdf_vals[0]);
    }
    if (edge_config & 16)
    {
        vert_list[4] = sdf_interp(points[4], points[5], sdf_vals[4], sdf_vals[5]);
        rgb_list[4] = sdf_interp(colours[4], colours[5], sdf_vals[4], sdf_vals[5]);
    }
    if (edge_config & 32)
    {
        vert_list[5] = sdf_interp(points[5], points[6], sdf_vals[5], sdf_vals[6]);
        rgb_list[5] = sdf_interp(colours[5], colours[6], sdf_vals[5], sdf_vals[6]);
    }
    if (edge_config & 64)
    {
        vert_list[6] = sdf_interp(points[6], points[7], sdf_vals[6], sdf_vals[7]);
        rgb_list[6] = sdf_interp(colours[6], colours[7], sdf_vals[6], sdf_vals[7]);
    }
    if (edge_config & 128)
    {
        vert_list[7] = sdf_interp(points[7], points[4], sdf_vals[7], sdf_vals[4]);
        rgb_list[7] = sdf_interp(colours[7], colours[4], sdf_vals[7], sdf_vals[4]);
    }
    if (edge_config & 256)
    {
        vert_list[8] = sdf_interp(points[0], points[4], sdf_vals[0], sdf_vals[4]);
        rgb_list[8] = sdf_interp(colours[0], colours[4], sdf_vals[0], sdf_vals[4]);
    }
    if (edge_config & 512)
    {
        vert_list[9] = sdf_interp(points[1], points[5], sdf_vals[1], sdf_vals[5]);
        rgb_list[9] = sdf_interp(colours[1], colours[5], sdf_vals[1], sdf_vals[5]);
    }
    if (edge_config & 1024)
    {
        vert_list[10] = sdf_interp(points[2], points[6], sdf_vals[2], sdf_vals[6]);
        rgb_list[10] = sdf_interp(colours[2], colours[6], sdf_vals[2], sdf_vals[6]);
    }
    if (edge_config & 2048)
    {
        vert_list[11] = sdf_interp(points[3], points[7], sdf_vals[3], sdf_vals[7]);
        rgb_list[11] = sdf_interp(colours[3], colours[7], sdf_vals[3], sdf_vals[7]);
    }

    // Write triangles to array.
    float3 vp[3];
    float3 vc[3];
    for (int i = 0; triangleTable[cube_type][i] != -1; i += 3)
    {
#pragma unroll
        for (int vi = 0; vi < 3; ++vi)
        {
            vp[vi] = vert_list[triangleTable[cube_type][i + vi]];
            vc[vi] = rgb_list[triangleTable[cube_type][i + vi]];
        }

        int triangle_id = atomicAdd(triangles_count, 1);
        if (triangle_id < max_triangles_count)
        {
#pragma unroll
            for (int vi = 0; vi < 3; ++vi)
            {
                triangles[triangle_id][vi][0] = vp[vi].x;
                triangles[triangle_id][vi][1] = vp[vi].y;
                triangles[triangle_id][vi][2] = vp[vi].z;
                vertex_colours[triangle_id][vi][0] = vc[vi].x;
                vertex_colours[triangle_id][vi][1] = vc[vi].y;
                vertex_colours[triangle_id][vi][2] = vc[vi].z;
            }
            triangle_flatten_id[triangle_id] = valid_blocks[lif_id];
        }
    }
}
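The per-corner sign tests in the kernel above build the classic marching-cubes case index: one bit per corner with `sdf < 0`, selecting one of 256 entries in `edgeTable`/`triangleTable`. A compact CPU equivalent (hypothetical helper name):

```cpp
#include <cassert>

// Marching-cubes case index: bit i is set iff corner i is inside the
// surface (sdf < 0), giving 256 possible cube configurations.
int cube_case(const float sdf[8]) {
    int c = 0;
    for (int i = 0; i < 8; ++i)
        if (sdf[i] < 0.0f)
            c |= (1 << i);
    return c;
}
```

Cases 0 (all outside) and 255 (all inside) produce no triangles, which is why the kernel returns early when `edge_config == 0`.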

std::vector<torch::Tensor> marching_cubes_sparse_colour(
    torch::Tensor indexer,           // (nx, ny, nz) -> data_id
    torch::Tensor valid_blocks,      // (K, )
    torch::Tensor vec_batch_mapping, //
    torch::Tensor cube_rgb_sdf,      // (M, rx, ry, rz, 4)
    torch::Tensor cube_std,          // (M, rx, ry, rz)
    const std::vector<int> &n_xyz,   // [nx, ny, nz]
    float max_std,                   // Prune vertices whose std exceeds this threshold
    int max_n_triangles              // Capacity of the triangle buffer
)
{
    CHECK_INPUT(indexer);
    CHECK_INPUT(valid_blocks);
    CHECK_INPUT(cube_rgb_sdf);
    CHECK_INPUT(cube_std);
    CHECK_INPUT(vec_batch_mapping);
    assert(max_n_triangles > 0);

    const int r = cube_rgb_sdf.size(1);
    const int r3 = r * r * r;
    const int num_lif = valid_blocks.size(0);
    const uint max_vec_num = vec_batch_mapping.size(0);

    torch::Tensor triangles = torch::empty({max_n_tr
gitextract_t7v5lyn0/

├── .gitignore
├── LICENSE
├── Readme.md
├── configs/
│   ├── kitti/
│   │   ├── kitti.yaml
│   │   ├── kitti_00.yaml
│   │   ├── kitti_01.yaml
│   │   ├── kitti_03.yaml
│   │   ├── kitti_04.yaml
│   │   ├── kitti_05.yaml
│   │   ├── kitti_06.yaml
│   │   ├── kitti_07.yaml
│   │   ├── kitti_08.yaml
│   │   ├── kitti_09.yaml
│   │   ├── kitti_10.yaml
│   │   ├── kitti_base06.yaml
│   │   └── kitti_base10.yaml
│   ├── maicity/
│   │   ├── maicity.yaml
│   │   ├── maicity_00.yaml
│   │   └── maicity_01.yaml
│   └── ncd/
│       ├── ncd.yaml
│       └── ncd_quad.yaml
├── demo/
│   ├── parser.py
│   └── run.py
├── install.sh
├── requirements.txt
├── src/
│   ├── criterion.py
│   ├── dataset/
│   │   ├── kitti.py
│   │   ├── maicity.py
│   │   └── ncd.py
│   ├── lidarFrame.py
│   ├── loggers.py
│   ├── mapping.py
│   ├── nerfloam.py
│   ├── se3pose.py
│   ├── share.py
│   ├── tracking.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── import_util.py
│   │   ├── mesh_util.py
│   │   ├── profile_util.py
│   │   └── sample_util.py
│   └── variations/
│       ├── decode_morton.py
│       ├── lidar.py
│       ├── render_helpers.py
│       └── voxel_helpers.py
└── third_party/
    ├── marching_cubes/
    │   ├── setup.py
    │   └── src/
    │       ├── mc.cpp
    │       ├── mc_data.cuh
    │       ├── mc_interp_kernel.cu
    │       ├── mc_kernel.cu
    │       └── mc_kernel_colour.cu
    ├── sparse_octree/
    │   ├── include/
    │   │   ├── octree.h
    │   │   ├── test.h
    │   │   └── utils.h
    │   ├── setup.py
    │   └── src/
    │       ├── bindings.cpp
    │       └── octree.cpp
    └── sparse_voxels/
        ├── include/
        │   ├── cuda_utils.h
        │   ├── cutil_math.h
        │   ├── intersect.h
        │   ├── octree.h
        │   ├── sample.h
        │   └── utils.h
        ├── setup.py
        └── src/
            ├── binding.cpp
            ├── intersect.cpp
            ├── intersect_gpu.cu
            ├── octree.cpp
            ├── sample.cpp
            └── sample_gpu.cu
SYMBOL INDEX (305 symbols across 33 files)

FILE: demo/parser.py
  class ArgumentParserX (line 4) | class ArgumentParserX(argparse.ArgumentParser):
    method __init__ (line 5) | def __init__(self, **kwargs):
    method parse_args (line 9) | def parse_args(self, args=None, namespace=None):
    method parse_config_yaml (line 23) | def parse_config_yaml(self, yaml_path, args=None):
    method convert_to_namespace (line 39) | def convert_to_namespace(self, dict_in, args=None):
    method update_recursive (line 48) | def update_recursive(self, dict1, dict2):
  function get_parser (line 58) | def get_parser():

FILE: demo/run.py
  function setup_seed (line 12) | def setup_seed(seed):

FILE: src/criterion.py
  class Criterion (line 6) | class Criterion(nn.Module):
    method __init__ (line 7) | def __init__(self, args) -> None:
    method forward (line 16) | def forward(self, outputs, obs, pointsCos, use_color_loss=True,
    method compute_loss (line 59) | def compute_loss(self, x, y, mask=None, loss_type="l2"):
    method get_masks (line 67) | def get_masks(self, z_vals, depth, epsilon):
    method get_sdf_loss (line 92) | def get_sdf_loss(self, z_vals, depth, predicted_sdf, truncation, valid...

FILE: src/dataset/kitti.py
  class DataLoader (line 19) | class DataLoader(Dataset):
    method __init__ (line 20) | def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1...
    method get_init_pose (line 28) | def get_init_pose(self, frame):
    method load_gt_pose (line 35) | def load_gt_pose(self):
    method load_points (line 40) | def load_points(self, index):
    method __len__ (line 72) | def __len__(self):
    method __getitem__ (line 75) | def __getitem__(self, index):

FILE: src/dataset/maicity.py
  class DataLoader (line 20) | class DataLoader(Dataset):
    method __init__ (line 21) | def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1...
    method get_init_pose (line 29) | def get_init_pose(self, frame):
    method load_gt_pose (line 36) | def load_gt_pose(self):
    method load_points (line 41) | def load_points(self, index):
    method __len__ (line 74) | def __len__(self):
    method __getitem__ (line 77) | def __getitem__(self, index):

FILE: src/dataset/ncd.py
  class DataLoader (line 21) | class DataLoader(Dataset):
    method __init__ (line 22) | def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1...
    method get_init_pose (line 30) | def get_init_pose(self, frame):
    method load_gt_pose (line 39) | def load_gt_pose(self):
    method load_points (line 49) | def load_points(self, index):
    method __len__ (line 79) | def __len__(self):
    method __getitem__ (line 82) | def __getitem__(self, index):

FILE: src/lidarFrame.py
  class LidarFrame (line 9) | class LidarFrame(nn.Module):
    method __init__ (line 10) | def __init__(self, index, points, pointsCos, pose=None, new_keyframe=F...
    method get_pose (line 26) | def get_pose(self):
    method get_translation (line 29) | def get_translation(self):
    method get_rotation (line 32) | def get_rotation(self):
    method get_points (line 35) | def get_points(self):
    method get_pointsCos (line 38) | def get_pointsCos(self):
    method set_rel_pose (line 41) | def set_rel_pose(self, rel_pose):
    method get_rel_pose (line 44) | def get_rel_pose(self):
    method get_rays (line 48) | def get_rays(self):
    method sample_rays (line 55) | def sample_rays(self, N_rays, track=False):

FILE: src/loggers.py
  class BasicLogger (line 14) | class BasicLogger:
    method __init__ (line 15) | def __init__(self, args) -> None:
    method get_random_time_str (line 33) | def get_random_time_str(self):
    method log_ckpt (line 36) | def log_ckpt(self, mapper):
    method log_config (line 52) | def log_config(self, config):
    method log_mesh (line 56) | def log_mesh(self, mesh, name="final_mesh.ply"):
    method log_point_cloud (line 60) | def log_point_cloud(self, pcd, name="final_points.ply"):
    method log_numpy_data (line 64) | def log_numpy_data(self, data, name, ind=None):
    method log_debug_data (line 73) | def log_debug_data(self, data, idx):
    method log_raw_image (line 77) | def log_raw_image(self, ind, rgb, depth):
    method log_images (line 88) | def log_images(self, ind, gt_rgb, gt_depth, rgb, depth):
    method npy2txt (line 144) | def npy2txt(self, input_path, output_path):

FILE: src/mapping.py
  function get_network_size (line 23) | def get_network_size(net):
  class Mapping (line 30) | class Mapping:
    method __init__ (line 31) | def __init__(self, args, logger: BasicLogger):
    method spin (line 93) | def spin(self, share_data, kf_buffer):
    method do_mapping (line 172) | def do_mapping(self, share_data, tracked_frame=None,
    method select_optimize_targets (line 205) | def select_optimize_targets(self, tracked_frame=None, selection_method...
    method update_share_data (line 227) | def update_share_data(self, share_data, frameid=None):
    method remove_back_points (line 235) | def remove_back_points(self, frame):
    method frame_maxdistance_change (line 257) | def frame_maxdistance_change(self, frame, distance):
    method insert_keyframe (line 266) | def insert_keyframe(self, frame, valid_distance=-1):
    method create_voxels (line 283) | def create_voxels(self, frame):
    method get_embeddings (line 294) | def get_embeddings(self, points_idx):
    method update_grid_features (line 320) | def update_grid_features(self):
    method get_updated_poses (line 342) | def get_updated_poses(self, offset=-2000):
    method extract_mesh (line 354) | def extract_mesh(self, res=8, clean_mesh=False):
    method extract_voxels (line 381) | def extract_voxels(self, offset=-10):
    method save_debug_data (line 392) | def save_debug_data(self, tracked_frame, offset=-10):

FILE: src/nerfloam.py
  class nerfloam (line 15) | class nerfloam:
    method __init__ (line 16) | def __init__(self, args):
    method start (line 40) | def start(self):
    method wait_child_processes (line 55) | def wait_child_processes(self):
    method get_raw_trajectory (line 60) | def get_raw_trajectory(self):
    method get_keyframe_poses (line 64) | def get_keyframe_poses(self):

FILE: src/se3pose.py
  class OptimizablePose (line 8) | class OptimizablePose(nn.Module):
    method __init__ (line 9) | def __init__(self, init_pose):
    method copy_from (line 15) | def copy_from(self, pose):
    method matrix (line 18) | def matrix(self):
    method rotation (line 24) | def rotation(self):
    method translation (line 34) | def translation(self,):
    method log (line 38) | def log(cls, R, eps=1e-7):  # [...,3,3]
    method from_matrix (line 48) | def from_matrix(cls, Rt, eps=1e-8):  # [...,3,4]
    method skew_symmetric (line 54) | def skew_symmetric(cls, w):
    method taylor_A (line 64) | def taylor_A(cls, x, nth=10):
    method taylor_B (line 75) | def taylor_B(cls, x, nth=10):
    method taylor_C (line 85) | def taylor_C(cls, x, nth=10):

FILE: src/share.py
  class ShareDataProxy (line 8) | class ShareDataProxy(NamespaceProxy):
  class ShareData (line 12) | class ShareData:
    method __init__ (line 16) | def __init__(self):
    method decoder (line 27) | def decoder(self):
    method decoder (line 34) | def decoder(self, decoder):
    method voxels (line 41) | def voxels(self):
    method voxels (line 48) | def voxels(self, voxels):
    method octree (line 55) | def octree(self):
    method octree (line 62) | def octree(self, octree):
    method states (line 69) | def states(self):
    method states (line 76) | def states(self, states):
    method stop_mapping (line 83) | def stop_mapping(self):
    method stop_mapping (line 90) | def stop_mapping(self, stop_mapping):
    method stop_tracking (line 97) | def stop_tracking(self):
    method stop_tracking (line 104) | def stop_tracking(self, stop_tracking):
    method tracking_trajectory (line 111) | def tracking_trajectory(self):
    method push_pose (line 117) | def push_pose(self, pose):

FILE: src/tracking.py
  class Tracking (line 15) | class Tracking:
    method __init__ (line 16) | def __init__(self, args, data_stream, logger):
    method process_first_frame (line 51) | def process_first_frame(self, kf_buffer):
    method spin (line 64) | def spin(self, share_data, kf_buffer):
    method check_keyframe (line 92) | def check_keyframe(self, check_frame, kf_buffer):
    method do_tracking (line 98) | def do_tracking(self, share_data, current_frame, kf_buffer):

FILE: src/utils/import_util.py
  function get_dataset (line 4) | def get_dataset(args):
  function get_decoder (line 8) | def get_decoder(args):
  function get_property (line 12) | def get_property(args, name, default):

FILE: src/utils/mesh_util.py
  class MeshExtractor (line 11) | class MeshExtractor:
    method __init__ (line 12) | def __init__(self, args):
    method linearize_id (line 18) | def linearize_id(self, xyz, n_xyz):
    method downsample_points (line 22) | def downsample_points(self, points, voxel_size=0.01):
    method get_rays (line 29) | def get_rays(self, w=None, h=None, K=None):
    method get_valid_points (line 47) | def get_valid_points(self, frame_poses, depth_maps):
    method create_mesh (line 80) | def create_mesh(self, decoder, map_states, voxel_size, voxels,
    method marching_cubes (line 145) | def marching_cubes(self, voxels, sdf):

FILE: src/utils/profile_util.py
  class Profiler (line 5) | class Profiler(object):
    method __init__ (line 6) | def __init__(self, verbose=False) -> None:
    method enable (line 12) | def enable(self):
    method disable (line 15) | def disable(self):
    method tick (line 18) | def tick(self, name):
    method tok (line 25) | def tok(self, name):

FILE: src/utils/sample_util.py
  function sampling_without_replacement (line 4) | def sampling_without_replacement(logp, k):
  function sample_rays (line 12) | def sample_rays(mask, num_samples):

FILE: src/variations/decode_morton.py
  function compact (line 4) | def compact(value):
  function decode (line 14) | def decode(code):

FILE: src/variations/lidar.py
  class GaussianFourierFeatureTransform (line 6) | class GaussianFourierFeatureTransform(torch.nn.Module):
    method __init__ (line 16) | def __init__(self, num_input_channels, mapping_size=93, scale=25, lear...
    method forward (line 26) | def forward(self, x):
  class Nerf_positional_embedding (line 33) | class Nerf_positional_embedding(torch.nn.Module):
    method __init__ (line 39) | def __init__(self, in_dim, multires, log_sampling=True):
    method forward (line 50) | def forward(self, x):
  class Same (line 71) | class Same(nn.Module):
    method __init__ (line 72) | def __init__(self, in_dim) -> None:
    method forward (line 76) | def forward(self, x):
  class Decoder (line 80) | class Decoder(nn.Module):
    method __init__ (line 81) | def __init__(self,
    method get_values (line 109) | def get_values(self, input):
    method forward (line 125) | def forward(self, inputs):

FILE: src/variations/render_helpers.py
  function ray (line 9) | def ray(ray_start, ray_dir, depths):
  function fill_in (line 13) | def fill_in(shape, mask, input, initial=1.0):
  function masked_scatter (line 21) | def masked_scatter(mask, x):
  function masked_scatter_ones (line 30) | def masked_scatter_ones(mask, x):
  function trilinear_interp (line 40) | def trilinear_interp(p, q, point_feats):
  function offset_points (line 49) | def offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2):
  function get_embeddings (line 63) | def get_embeddings(sampled_xyz, point_xyz, point_feats, voxel_size):
  function get_features (line 74) | def get_features(samples, map_states, voxel_size):
  function get_scores (line 97) | def get_scores(sdf_network, map_states, voxel_size, bits=8):
  function eval_points (line 157) | def eval_points(sdf_network, map_states, sampled_xyz, sampled_idx, voxel...
  function render_rays (line 190) | def render_rays(
  function bundle_adjust_frames (line 321) | def bundle_adjust_frames(
  function track_frame (line 428) | def track_frame(

FILE: src/variations/voxel_helpers.py
  class BallRayIntersect (line 27) | class BallRayIntersect(Function):
    method forward (line 29) | def forward(ctx, radius, n_max, points, ray_start, ray_dir):
    method backward (line 42) | def backward(ctx, a, b, c):
  class AABBRayIntersect (line 49) | class AABBRayIntersect(Function):
    method forward (line 51) | def forward(ctx, voxelsize, n_max, points, ray_start, ray_dir):
    method backward (line 85) | def backward(ctx, a, b, c):
  class SparseVoxelOctreeRayIntersect (line 92) | class SparseVoxelOctreeRayIntersect(Function):
    method forward (line 94) | def forward(ctx, voxelsize, n_max, points, children, ray_start, ray_dir):
    method backward (line 136) | def backward(ctx, a, b, c):
  class TriangleRayIntersect (line 143) | class TriangleRayIntersect(Function):
    method forward (line 145) | def forward(ctx, cagesize, blur_ratio, n_max, points, faces, ray_start...
    method backward (line 187) | def backward(ctx, a, b, c):
  class UniformRaySampling (line 194) | class UniformRaySampling(Function):
    method forward (line 196) | def forward(
    method backward (line 255) | def backward(ctx, a, b, c):
  class InverseCDFRaySampling (line 262) | class InverseCDFRaySampling(Function):
    method forward (line 264) | def forward(
    method backward (line 343) | def backward(ctx, a, b, c):
  function _parallel_ray_sampling (line 352) | def _parallel_ray_sampling(
  function parallel_ray_sampling (line 411) | def parallel_ray_sampling(
  function discretize_points (line 454) | def discretize_points(voxel_points, voxel_size):
  function build_easy_octree (line 467) | def build_easy_octree(points, half_voxel):
  function trilinear_interp (line 478) | def trilinear_interp(p, q, point_feats):
  function offset_points (line 487) | def offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2):
  function splitting_points (line 499) | def splitting_points(point_xyz, point_feats, values, half_voxel):
  function ray_intersect (line 531) | def ray_intersect(ray_start, ray_dir, flatten_centers, flatten_children,...
  function ray_sample (line 571) | def ray_sample(intersection_outputs, step_size=0.01, fixed=False):
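The `aabb_intersect`/`ray_intersect` helpers listed above test rays against axis-aligned voxels, which is conventionally done with the slab method. A CPU sketch of that method (hypothetical helper, not the repo's CUDA implementation; assumes no zero direction components):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Slab test: a ray hits an axis-aligned box iff the parameter intervals
// where the ray lies inside each axis slab all overlap, and the overlap
// is not entirely behind the ray origin.
bool ray_aabb(const float o[3], const float d[3],
              const float lo[3], const float hi[3],
              float* t_near, float* t_far) {
    float tn = -INFINITY, tf = INFINITY;
    for (int i = 0; i < 3; ++i) {
        float t0 = (lo[i] - o[i]) / d[i];  // assumes d[i] != 0
        float t1 = (hi[i] - o[i]) / d[i];
        if (t0 > t1) std::swap(t0, t1);
        tn = std::max(tn, t0);
        tf = std::min(tf, t1);
    }
    *t_near = tn;
    *t_far = tf;
    return tf >= tn && tf >= 0.0f;
}
```

The `[t_near, t_far]` interval is what ray sampling routines then subdivide into depth samples inside each intersected voxel.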

FILE: third_party/marching_cubes/src/mc.cpp
  function PYBIND11_MODULE (line 36) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)

FILE: third_party/sparse_octree/include/octree.h
  type OcType (line 5) | enum OcType
  class Octant (line 12) | class Octant : public torch::CustomClassHolder
  class Octree (line 65) | class Octree : public torch::CustomClassHolder

FILE: third_party/sparse_octree/include/test.h
  function expand (line 41) | inline int64_t expand(int64_t value)
  function compact (line 52) | inline uint64_t compact(uint64_t value)
  function compute_morton (line 63) | inline int64_t compute_morton(int64_t x, int64_t y, int64_t z)

FILE: third_party/sparse_octree/include/utils.h
  type Vector3i (line 28) | typedef Vector3<int> Vector3i;
  type Vector3f (line 29) | typedef Vector3<float> Vector3f;
  function expand (line 64) | inline uint64_t expand(unsigned long long value)
  function compact (line 75) | inline uint64_t compact(uint64_t value)
  function compute_morton (line 86) | inline uint64_t compute_morton(uint64_t x, uint64_t y, uint64_t z)
  function decode (line 98) | inline Eigen::Vector3i decode(const uint64_t code)
  function encode (line 106) | inline uint64_t encode(const int x, const int y, const int z)
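The `expand`/`compact`/`compute_morton` helpers listed above interleave the bits of three coordinates into a Morton (Z-order) code, which linearizes octree cells so nearby voxels get nearby keys. An illustrative CPU encoder using the standard 21-bit spreading constants (the repo's exact bit widths and constants may differ):

```cpp
#include <cassert>
#include <cstdint>

// Spread the low 21 bits of v so that bit k moves to position 3k.
inline uint64_t expand3(uint64_t v) {
    v &= 0x1fffffULL;
    v = (v | v << 32) & 0x1f00000000ffffULL;
    v = (v | v << 16) & 0x1f0000ff0000ffULL;
    v = (v | v << 8)  & 0x100f00f00f00f00fULL;
    v = (v | v << 4)  & 0x10c30c30c30c30c3ULL;
    v = (v | v << 2)  & 0x1249249249249249ULL;
    return v;
}

// Interleave x, y, z bit-by-bit: ... z1 y1 x1 z0 y0 x0.
inline uint64_t morton3(uint64_t x, uint64_t y, uint64_t z) {
    return expand3(x) | (expand3(y) << 1) | (expand3(z) << 2);
}
```

Decoding (`compact`) reverses the spreading masks to extract every third bit back into a contiguous coordinate.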

FILE: third_party/sparse_octree/src/bindings.cpp
  function TORCH_LIBRARY (line 4) | TORCH_LIBRARY(svo, m)

FILE: third_party/sparse_octree/src/octree.cpp
  function Octree::find_octant (line 151) | Octant *Octree::find_octant(std::vector<float> coord)

FILE: third_party/sparse_voxels/include/cuda_utils.h
  function opt_n_threads (line 20) | inline int opt_n_threads(int work_size)
  function dim3 (line 27) | inline dim3 opt_block_config(int x, int y)

FILE: third_party/sparse_voxels/include/cutil_math.h
  type uint (line 30) | typedef unsigned int uint;
  type ushort (line 31) | typedef unsigned short ushort;
  function fminf (line 36) | inline float fminf(float a, float b)
  function fmaxf (line 41) | inline float fmaxf(float a, float b)
  function max (line 46) | inline int max(int a, int b)
  function min (line 51) | inline int min(int a, int b)
  function rsqrtf (line 56) | inline float rsqrtf(float x)
  function lerp (line 67) | float lerp(float a, float b, float t)
  function clamp (line 73) | float clamp(float f, float a, float b)
  function swap (line 78) | void swap(float &a, float &b)
  function swap (line 85) | void swap(int &a, int &b)
  function int2 (line 124) | int2 operator*(int2 a, int2 b)
  function int2 (line 128) | int2 operator*(int2 a, int s)
  function int2 (line 132) | int2 operator*(int s, int2 a)
  function float2 (line 146) | float2 make_float2(float s)
  function float2 (line 150) | float2 make_float2(int2 a)
  function float2 (line 184) | float2 operator*(float2 a, float2 b)
  function float2 (line 188) | float2 operator*(float2 a, float s)
  function float2 (line 192) | float2 operator*(float s, float2 a)
  function float2 (line 224) | float2 lerp(float2 a, float2 b, float t)
  function float2 (line 230) | float2 clamp(float2 v, float a, float b)
  function float2 (line 235) | float2 clamp(float2 v, float2 a, float2 b)
  function dot (line 241) | float dot(float2 a, float2 b)
  function length (line 247) | float length(float2 v)
  function float2 (line 253) | float2 normalize(float2 v)
  function float2 (line 260) | float2 floor(const float2 v)
  function reflect (line 266) | float2 reflect(float2 i, float2 n)
  function fabs (line 272) | float2 fabs(float2 v)
  function make_float3 (line 281) | float3 make_float3(float s)
  function make_float3 (line 285) | float3 make_float3(float2 a)
  function make_float3 (line 289) | float3 make_float3(float2 a, float s)
  function make_float3 (line 293) | float3 make_float3(float4 a)
  function make_float3 (line 297) | float3 make_float3(int3 a)
  function fminf (line 309) | float3 fminf(float3 a, float3 b)
  function fmaxf (line 315) | float3 fmaxf(float3 a, float3 b)
  function operator* (line 353) | float3 operator*(float3 a, float3 b)
  function operator* (line 357) | float3 operator*(float3 a, float s)
  function operator* (line 361) | float3 operator*(float s, float3 a)
  function lerp (line 401) | float3 lerp(float3 a, float3 b, float t)
  function clamp (line 407) | float3 clamp(float3 v, float a, float b)
  function clamp (line 412) | float3 clamp(float3 v, float3 a, float3 b)
  function dot (line 418) | float dot(float3 a, float3 b)
  function cross (line 424) | float3 cross(float3 a, float3 b)
  function length (line 430) | float length(float3 v)
  function normalize (line 436) | float3 normalize(float3 v)
  function floor (line 443) | float3 floor(const float3 v)
  function reflect (line 449) | float3 reflect(float3 i, float3 n)
  function fabs (line 455) | float3 fabs(float3 v)
  function make_float4 (line 464) | float4 make_float4(float s)
  function make_float4 (line 468) | float4 make_float4(float3 a)
  function make_float4 (line 472) | float4 make_float4(float3 a, float w)
  function make_float4 (line 476) | float4 make_float4(int4 a)
  function fminf (line 488) | float4 fminf(float4 a, float4 b)
  function fmaxf (line 494) | float4 fmaxf(float4 a, float4 b)
  function operator* (line 526) | float4 operator*(float4 a, float s)
  function operator* (line 530) | float4 operator*(float s, float4 a)
  function lerp (line 564) | float4 lerp(float4 a, float4 b, float t)
  function clamp (line 570) | float4 clamp(float4 v, float a, float b)
  function clamp (line 575) | float4 clamp(float4 v, float4 a, float4 b)
  function dot (line 581) | float dot(float4 a, float4 b)
  function length (line 587) | float length(float4 r)
  function normalize (line 593) | float4 normalize(float4 v)
  function floor (line 600) | float4 floor(const float4 v)
  function fabs (line 606) | float4 fabs(float4 v)
  function make_int3 (line 615) | int3 make_int3(int s)
  function make_int3 (line 619) | int3 make_int3(float3 a)
  function min (line 631) | int3 min(int3 a, int3 b)
  function max (line 637) | int3 max(int3 a, int3 b)
  function operator* (line 668) | int3 operator*(int3 a, int3 b)
  function operator* (line 672) | int3 operator*(int3 a, int s)
  function operator* (line 676) | int3 operator*(int s, int3 a)
  function clamp (line 708) | int clamp(int f, int a, int b)
  function clamp (line 713) | int3 clamp(int3 v, int a, int b)
  function clamp (line 718) | int3 clamp(int3 v, int3 a, int3 b)
  function make_uint3 (line 727) | uint3 make_uint3(uint s)
  function make_uint3 (line 731) | uint3 make_uint3(float3 a)
  function min (line 737) | uint3 min(uint3 a, uint3 b)
  function max (line 743) | uint3 max(uint3 a, uint3 b)
  function operator* (line 774) | uint3 operator*(uint3 a, uint3 b)
  function operator* (line 778) | uint3 operator*(uint3 a, uint s)
  function operator* (line 782) | uint3 operator*(uint s, uint3 a)
  function clamp (line 814) | uint clamp(uint f, uint a, uint b)
  function clamp (line 819) | uint3 clamp(uint3 v, uint a, uint b)
  function clamp (line 824) | uint3 clamp(uint3 v, uint3 a, uint3 b)
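
  The cutil_math.h symbols above are the usual CUDA componentwise vector helpers. As an illustration of their semantics only (plain Python tuples standing in for float3; this is not the repo's C++ code), the core operations behave like:

  ```python
  import math

  # Componentwise sketches of the float3 helpers listed above.
  # Tuples stand in for CUDA's float3 struct.

  def lerp(a, b, t):
      # linear interpolation: a + t * (b - a), per component
      return tuple(ai + t * (bi - ai) for ai, bi in zip(a, b))

  def clamp(v, lo, hi):
      # clamp each component into [lo, hi]
      return tuple(min(max(x, lo), hi) for x in v)

  def dot(a, b):
      # sum of componentwise products
      return sum(ai * bi for ai, bi in zip(a, b))

  def normalize(v):
      # scale to unit length (undefined for the zero vector, as in CUDA)
      n = math.sqrt(dot(v, v))
      return tuple(x / n for x in v)
  ```

  The int3/uint3 overloads in the list follow the same componentwise pattern with integer arithmetic.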

FILE: third_party/sparse_voxels/src/binding.cpp
  function PYBIND11_MODULE (line 10) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)

FILE: third_party/sparse_voxels/src/intersect.cpp
  function ball_intersect (line 15) | std::tuple<at::Tensor, at::Tensor, at::Tensor> ball_intersect(at::Tensor...
  function aabb_intersect (line 49) | std::tuple<at::Tensor, at::Tensor, at::Tensor> aabb_intersect(at::Tensor...
  function svo_intersect (line 83) | std::tuple<at::Tensor, at::Tensor, at::Tensor> svo_intersect(at::Tensor ...
  function triangle_intersect (line 119) | std::tuple<at::Tensor, at::Tensor, at::Tensor> triangle_intersect(at::Te...
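
  Among these, aabb_intersect follows the pattern of the classic slab test. A minimal Python sketch of that technique (an illustration of the algorithm, not the extension's CUDA kernel or its actual signature):

  ```python
  import math

  def ray_aabb(origin, direction, bmin, bmax):
      """Slab test: return (t_near, t_far) where the ray enters/exits
      the box, or None on a miss. Sketch of the standard algorithm that
      aabb_intersect-style kernels build on."""
      t_near, t_far = -math.inf, math.inf
      for o, d, lo, hi in zip(origin, direction, bmin, bmax):
          if d != 0.0:
              t1, t2 = (lo - o) / d, (hi - o) / d
              t_near = max(t_near, min(t1, t2))
              t_far = min(t_far, max(t1, t2))
          elif not (lo <= o <= hi):
              return None  # parallel to this slab and outside it
      if t_far < max(t_near, 0.0):
          return None  # slabs don't overlap, or box is behind the ray
      return t_near, t_far
  ```

  svo_intersect applies the same entry/exit logic recursively down the octree, visiting only child voxels whose slabs the ray actually crosses.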

FILE: third_party/sparse_voxels/src/octree.cpp
  type OcTree (line 12) | struct OcTree
    type OcTree (line 17) | struct OcTree
    method init (line 18) | void init(at::Tensor center, int d, int i)
  class EasyOctree (line 28) | class EasyOctree
    method EasyOctree (line 38) | EasyOctree(at::Tensor center, int depth)
  function build_octree (line 149) | std::tuple<at::Tensor, at::Tensor> build_octree(at::Tensor center, at::T...

FILE: third_party/sparse_voxels/src/sample.cpp
  function uniform_ray_sampling (line 21) | std::tuple<at::Tensor, at::Tensor, at::Tensor> uniform_ray_sampling(
  function inverse_cdf_sampling (line 56) | std::tuple<at::Tensor, at::Tensor, at::Tensor> inverse_cdf_sampling(
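
  inverse_cdf_sampling draws sample depths along a ray in proportion to per-interval weights. A stdlib-only sketch of that technique (the boundary list ts, weight list, and function signature here are illustrative, not the extension's API):

  ```python
  import bisect
  import random

  def inverse_cdf_sampling(ts, weights, n, rng):
      """Draw n depths along a ray, distributed proportionally to the
      per-interval weights; ts holds len(weights) + 1 interval
      boundaries. Sketch of the inverse-CDF idea only."""
      total = sum(weights)
      cdf, acc = [], 0.0
      for w in weights:
          acc += w / total
          cdf.append(acc)          # cumulative distribution per interval
      out = []
      for _ in range(n):
          u = rng.random()                  # uniform in [0, 1)
          j = bisect.bisect_right(cdf, u)   # interval containing u
          lo = cdf[j - 1] if j > 0 else 0.0
          frac = (u - lo) / (cdf[j] - lo)   # position within interval j
          out.append(ts[j] + frac * (ts[j + 1] - ts[j]))
      return out
  ```

  With all the weight concentrated in one interval, every sample lands inside it, which is the behavior that focuses ray samples near surfaces.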
Condensed preview — 70 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 69,
    "preview": "__pycache__\n*.egg-info\nbuild/\ndist/\nlogs/\n.vscode/\nresults/\ntemp.sh\n\n"
  },
  {
    "path": "LICENSE",
    "chars": 1068,
    "preview": "MIT License\n\nCopyright (c) 2023 JunyuanDeng\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
  },
  {
    "path": "Readme.md",
    "chars": 5131,
    "preview": "# NeRF-LOAM: Neural Implicit Representation for Large-Scale Incremental LiDAR Odometry and Mapping\n\nThis repository cont"
  },
  {
    "path": "configs/kitti/kitti.yaml",
    "chars": 740,
    "preview": "log_dir: './logs'\ndecoder: lidar\ndataset: kitti\n\ncriteria:\n  sdf_weight: 10000.0\n  fs_weight: 1\n  eiko_weight: 0.1\n  sdf"
  },
  {
    "path": "configs/kitti/kitti_00.yaml",
    "chars": 256,
    "preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence00\n\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/kitt"
  },
  {
    "path": "configs/kitti/kitti_01.yaml",
    "chars": 270,
    "preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence01\n\n\n\ndata_specs:\n  data_path: '/home/evsjtu2/disk1/dengj"
  },
  {
    "path": "configs/kitti/kitti_03.yaml",
    "chars": 270,
    "preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence03\n\n\n\ndata_specs:\n  data_path: '/home/evsjtu2/disk1/dengj"
  },
  {
    "path": "configs/kitti/kitti_04.yaml",
    "chars": 260,
    "preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence04\n\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/kitt"
  },
  {
    "path": "configs/kitti/kitti_05.yaml",
    "chars": 273,
    "preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence05\n\n\n\ndata_specs:\n  data_path: '/home/evsjtu2/disk1/dengj"
  },
  {
    "path": "configs/kitti/kitti_06.yaml",
    "chars": 256,
    "preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence06\n\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/kitt"
  },
  {
    "path": "configs/kitti/kitti_07.yaml",
    "chars": 268,
    "preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence07\n\n\n\ndata_specs:\n  data_path: '/home/evsjtu2/disk1/dengj"
  },
  {
    "path": "configs/kitti/kitti_08.yaml",
    "chars": 256,
    "preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence08\n\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/kitt"
  },
  {
    "path": "configs/kitti/kitti_09.yaml",
    "chars": 256,
    "preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence09\n\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/kitt"
  },
  {
    "path": "configs/kitti/kitti_10.yaml",
    "chars": 268,
    "preview": "base_config: configs/kitti/kitti_base10.yaml\n\nexp_name: kitti/sqeuence10\n\n\n\ndata_specs:\n  data_path: '/home/pl21n4/datas"
  },
  {
    "path": "configs/kitti/kitti_base06.yaml",
    "chars": 739,
    "preview": "log_dir: './logs'\ndecoder: lidar\ndataset: kitti\n\ncriteria:\n  sdf_weight: 10000.0\n  fs_weight: 1\n  eiko_weight: 0.1\n  sdf"
  },
  {
    "path": "configs/kitti/kitti_base10.yaml",
    "chars": 932,
    "preview": "log_dir: '/home/evsjtu2/disk1/dengjunyuan/running_logs/'\ndecoder: lidar\ndataset: kitti\n\ncriteria:\n  depth_weight: 0\n  sd"
  },
  {
    "path": "configs/maicity/maicity.yaml",
    "chars": 743,
    "preview": "log_dir: './logs'\ndecoder: lidar\ndataset: maicity\n\ncriteria:\n  sdf_weight: 10000.0\n  fs_weight: 1\n  eiko_weight: 0.1\n  s"
  },
  {
    "path": "configs/maicity/maicity_00.yaml",
    "chars": 265,
    "preview": "base_config: configs/maicity/maicity.yaml\n\nexp_name: maicity/sqeuence00\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset"
  },
  {
    "path": "configs/maicity/maicity_01.yaml",
    "chars": 264,
    "preview": "base_config: configs/maicity/maicity.yaml\n\nexp_name: maicity/sqeuence01\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset"
  },
  {
    "path": "configs/ncd/ncd.yaml",
    "chars": 740,
    "preview": "log_dir: './logs'\ndecoder: lidar\ndataset: ncd\n\ncriteria:\n  sdf_weight: 10000.0\n  fs_weight: 1\n  eiko_weight: 1.0\n  sdf_t"
  },
  {
    "path": "configs/ncd/ncd_quad.yaml",
    "chars": 238,
    "preview": "base_config: configs/ncd/ncd.yaml\n\nexp_name: ncd/quad\n\n\n\ndata_specs:\n  data_path: '/home/pl21n4/dataset/ncd_example/quad"
  },
  {
    "path": "demo/parser.py",
    "chars": 2277,
    "preview": "import yaml\nimport argparse\n\nclass ArgumentParserX(argparse.ArgumentParser):\n    def __init__(self, **kwargs):\n        s"
  },
  {
    "path": "demo/run.py",
    "chars": 683,
    "preview": "import os  # noqa\nimport sys  # noqa\nsys.path.insert(0, os.path.abspath('src')) # noqa\nos.environ[\"CUDA_VISIBLE_DEVICES\""
  },
  {
    "path": "install.sh",
    "chars": 156,
    "preview": "#!/bin/bash\n\ncd third_party/marching_cubes\npython setup.py install\n\ncd ../sparse_octree\npython setup.py install\n\ncd ../s"
  },
  {
    "path": "requirements.txt",
    "chars": 71,
    "preview": "matplotlib\nopen3d\nopencv-python\nPyYAML\nscikit-image\ntqdm\ntrimesh\npyyaml"
  },
  {
    "path": "src/criterion.py",
    "chars": 4853,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.autograd import grad\n\n\nclass Criterion(nn.Module):\n    def __init__(self, "
  },
  {
    "path": "src/dataset/kitti.py",
    "chars": 3326,
    "preview": "import os.path as osp\n\nimport numpy as np\nimport torch\nfrom glob import glob\nfrom torch.utils.data import Dataset\nimport"
  },
  {
    "path": "src/dataset/maicity.py",
    "chars": 3440,
    "preview": "import os.path as osp\n\nimport numpy as np\nimport torch\nfrom glob import glob\nfrom torch.utils.data import Dataset\nimport"
  },
  {
    "path": "src/dataset/ncd.py",
    "chars": 3864,
    "preview": "import os.path as osp\n\nimport numpy as np\nimport torch\nimport open3d as o3d\nfrom glob import glob\nfrom torch.utils.data "
  },
  {
    "path": "src/lidarFrame.py",
    "chars": 1699,
    "preview": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom se3pose import OptimizablePose\nfrom utils.sample_util import "
  },
  {
    "path": "src/loggers.py",
    "chars": 6118,
    "preview": "import os\nimport os.path as osp\nimport pickle\nfrom datetime import datetime\n\nimport cv2\nimport matplotlib.pyplot as plt\n"
  },
  {
    "path": "src/mapping.py",
    "chars": 18659,
    "preview": "from copy import deepcopy\nimport random\nfrom time import sleep\nimport numpy as np\nfrom tqdm import tqdm\nimport torch\n\nfr"
  },
  {
    "path": "src/nerfloam.py",
    "chars": 2052,
    "preview": "from multiprocessing.managers import BaseManager\nfrom time import sleep\n\nimport torch\nimport torch.multiprocessing as mp"
  },
  {
    "path": "src/se3pose.py",
    "chars": 3279,
    "preview": "\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom copy import deepcopy\n\n\nclass OptimizablePose(nn.Module):\n   "
  },
  {
    "path": "src/share.py",
    "chars": 3355,
    "preview": "from multiprocessing.managers import BaseManager, NamespaceProxy\nfrom copy import deepcopy\nimport torch.multiprocessing "
  },
  {
    "path": "src/tracking.py",
    "chars": 5641,
    "preview": "import torch\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom criterion import Criterion\nfrom lidarFrame import LidarFrame"
  },
  {
    "path": "src/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/utils/import_util.py",
    "chars": 652,
    "preview": "from importlib import import_module\nimport argparse\n\ndef get_dataset(args):\n    Dataset = import_module(\"dataset.\"+args."
  },
  {
    "path": "src/utils/mesh_util.py",
    "chars": 6747,
    "preview": "import math\nimport torch\n\nimport numpy as np\nimport open3d as o3d\nfrom scipy.spatial import cKDTree\nfrom skimage.measure"
  },
  {
    "path": "src/utils/profile_util.py",
    "chars": 870,
    "preview": "from time import time\nimport torch\n\n\nclass Profiler(object):\n    def __init__(self, verbose=False) -> None:\n        self"
  },
  {
    "path": "src/utils/sample_util.py",
    "chars": 616,
    "preview": "import torch\n\n\ndef sampling_without_replacement(logp, k):\n    def gumbel_like(u):\n        return -torch.log(-torch.log(t"
  },
  {
    "path": "src/variations/decode_morton.py",
    "chars": 588,
    "preview": "import numpy as np\n\n\ndef compact(value):\n    x = value & 0x1249249249249249\n    x = (x | x >> 2) & 0x10c30c30c30c30c3\n  "
  },
  {
    "path": "src/variations/lidar.py",
    "chars": 4114,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass GaussianFourierFeatureTransform(torch.nn.Modu"
  },
  {
    "path": "src/variations/render_helpers.py",
    "chars": 17255,
    "preview": "from copy import deepcopy\nimport torch\nimport torch.nn.functional as F\n\nfrom .voxel_helpers import ray_intersect, ray_sa"
  },
  {
    "path": "src/variations/voxel_helpers.py",
    "chars": 21098,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MIT license found in the\n#"
  },
  {
    "path": "third_party/marching_cubes/setup.py",
    "chars": 527,
    "preview": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\nimport glob\n\n_ext_sourc"
  },
  {
    "path": "third_party/marching_cubes/src/mc.cpp",
    "chars": 1956,
    "preview": "#include <torch/extension.h>\n\nstd::vector<torch::Tensor> marching_cubes_sparse(\n    torch::Tensor indexer,           // "
  },
  {
    "path": "third_party/marching_cubes/src/mc_data.cuh",
    "chars": 18896,
    "preview": "#include <torch/extension.h>\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <thrust/device_vector.h>\n\n#define CHE"
  },
  {
    "path": "third_party/marching_cubes/src/mc_interp_kernel.cu",
    "chars": 20516,
    "preview": "#include \"mc_data.cuh\"\n\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <ATen/cuda/CUDAContext.h>\n\n__device__ "
  },
  {
    "path": "third_party/marching_cubes/src/mc_kernel.cu",
    "chars": 13447,
    "preview": "#include \"mc_data.cuh\"\n\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <ATen/cuda/CUDAContext.h>\n\n__device__ "
  },
  {
    "path": "third_party/marching_cubes/src/mc_kernel_colour.cu",
    "chars": 15332,
    "preview": "#include \"mc_data.cuh\"\n\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <ATen/cuda/CUDAContext.h>\n\n__device__ "
  },
  {
    "path": "third_party/sparse_octree/include/octree.h",
    "chars": 3004,
    "preview": "#include <memory>\n#include <torch/script.h>\n#include <torch/custom_class.h>\n\nenum OcType\n{\n    NONLEAF = -1,\n    SURFACE"
  },
  {
    "path": "third_party/sparse_octree/include/test.h",
    "chars": 2181,
    "preview": "#pragma once\n#include <iostream>\n\n#define MAX_BITS 21\n// #define SCALE_MASK ((uint64_t)0x1FF)\n#define SCALE_MASK ((uint6"
  },
  {
    "path": "third_party/sparse_octree/include/utils.h",
    "chars": 2483,
    "preview": "#pragma once\n#include <iostream>\n#include <eigen3/Eigen/Dense>\n\n#define MAX_BITS 21\n// #define SCALE_MASK ((uint64_t)0x1"
  },
  {
    "path": "third_party/sparse_octree/setup.py",
    "chars": 657,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MIT license found in the\n#"
  },
  {
    "path": "third_party/sparse_octree/src/bindings.cpp",
    "chars": 1288,
    "preview": "#include \"../include/octree.h\"\n#include \"../include/test.h\"\n\nTORCH_LIBRARY(svo, m)\n{\n    m.def(\"encode\", &encode_torch);"
  },
  {
    "path": "third_party/sparse_octree/src/octree.cpp",
    "chars": 10131,
    "preview": "#include \"../include/octree.h\"\n#include \"../include/utils.h\"\n#include <queue>\n#include <iostream>\n\n// #define MAX_HIT_VO"
  },
  {
    "path": "third_party/sparse_voxels/include/cuda_utils.h",
    "chars": 1629,
    "preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
  },
  {
    "path": "third_party/sparse_voxels/include/cutil_math.h",
    "chars": 18605,
    "preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
  },
  {
    "path": "third_party/sparse_voxels/include/intersect.h",
    "chars": 1194,
    "preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
  },
  {
    "path": "third_party/sparse_voxels/include/octree.h",
    "chars": 341,
    "preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
  },
  {
    "path": "third_party/sparse_voxels/include/sample.h",
    "chars": 682,
    "preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
  },
  {
    "path": "third_party/sparse_voxels/include/utils.h",
    "chars": 1473,
    "preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
  },
  {
    "path": "third_party/sparse_voxels/setup.py",
    "chars": 734,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MIT license found in the\n#"
  },
  {
    "path": "third_party/sparse_voxels/src/binding.cpp",
    "chars": 659,
    "preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
  },
  {
    "path": "third_party/sparse_voxels/src/intersect.cpp",
    "chars": 6968,
    "preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
  },
  {
    "path": "third_party/sparse_voxels/src/intersect_gpu.cu",
    "chars": 12773,
    "preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
  },
  {
    "path": "third_party/sparse_voxels/src/octree.cpp",
    "chars": 4439,
    "preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
  },
  {
    "path": "third_party/sparse_voxels/src/sample.cpp",
    "chars": 4402,
    "preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
  },
  {
    "path": "third_party/sparse_voxels/src/sample_gpu.cu",
    "chars": 8533,
    "preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
  }
]
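
The decode_morton.py preview above cuts off mid-function. Its compact() is the standard 64-bit 3-D Morton decompaction: the first two mask steps match the preview verbatim, while the remaining steps (and the split/decode helpers) follow the standard bit-ladder and are an assumption about the truncated part:

```python
def compact(value):
    # Gather every third bit of a 63-bit Morton code into a 21-bit int.
    # The first two steps appear in decode_morton.py's preview; the
    # rest follow the standard compaction ladder.
    x = value & 0x1249249249249249
    x = (x | x >> 2) & 0x10C30C30C30C30C3
    x = (x | x >> 4) & 0x100F00F00F00F00F
    x = (x | x >> 8) & 0x1F0000FF0000FF
    x = (x | x >> 16) & 0x1F00000000FFFF
    x = (x | x >> 32) & 0x1FFFFF
    return x

def split(value):
    # Inverse of compact: spread a 21-bit int so its bits occupy every
    # third position (the encoding direction).
    x = value & 0x1FFFFF
    x = (x | x << 32) & 0x1F00000000FFFF
    x = (x | x << 16) & 0x1F0000FF0000FF
    x = (x | x << 8) & 0x100F00F00F00F00F
    x = (x | x << 4) & 0x10C30C30C30C30C3
    x = (x | x << 2) & 0x1249249249249249
    return x

def decode(code):
    # x, y, z bits are interleaved as ...z1y1x1 z0y0x0
    return compact(code), compact(code >> 1), compact(code >> 2)
```

Morton keys of this form are also what the sparse_octree extension uses (MAX_BITS 21 in its headers matches the 21-bit-per-axis layout).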

About this extraction

This page contains the full source code of the JunyuanDeng/NeRF-LOAM GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 70 files (271.3 KB, approximately 88.3k tokens) and includes a symbol index with 305 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
