Repository: JunyuanDeng/NeRF-LOAM
Branch: master
Commit: 2fe4e8d8dd9a
Files: 70
Total size: 271.3 KB
Directory structure:
gitextract_t7v5lyn0/
├── .gitignore
├── LICENSE
├── Readme.md
├── configs/
│   ├── kitti/
│   │   ├── kitti.yaml
│   │   ├── kitti_00.yaml
│   │   ├── kitti_01.yaml
│   │   ├── kitti_03.yaml
│   │   ├── kitti_04.yaml
│   │   ├── kitti_05.yaml
│   │   ├── kitti_06.yaml
│   │   ├── kitti_07.yaml
│   │   ├── kitti_08.yaml
│   │   ├── kitti_09.yaml
│   │   ├── kitti_10.yaml
│   │   ├── kitti_base06.yaml
│   │   └── kitti_base10.yaml
│   ├── maicity/
│   │   ├── maicity.yaml
│   │   ├── maicity_00.yaml
│   │   └── maicity_01.yaml
│   └── ncd/
│       ├── ncd.yaml
│       └── ncd_quad.yaml
├── demo/
│   ├── parser.py
│   └── run.py
├── install.sh
├── requirements.txt
├── src/
│   ├── criterion.py
│   ├── dataset/
│   │   ├── kitti.py
│   │   ├── maicity.py
│   │   └── ncd.py
│   ├── lidarFrame.py
│   ├── loggers.py
│   ├── mapping.py
│   ├── nerfloam.py
│   ├── se3pose.py
│   ├── share.py
│   ├── tracking.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── import_util.py
│   │   ├── mesh_util.py
│   │   ├── profile_util.py
│   │   └── sample_util.py
│   └── variations/
│       ├── decode_morton.py
│       ├── lidar.py
│       ├── render_helpers.py
│       └── voxel_helpers.py
└── third_party/
    ├── marching_cubes/
    │   ├── setup.py
    │   └── src/
    │       ├── mc.cpp
    │       ├── mc_data.cuh
    │       ├── mc_interp_kernel.cu
    │       ├── mc_kernel.cu
    │       └── mc_kernel_colour.cu
    ├── sparse_octree/
    │   ├── include/
    │   │   ├── octree.h
    │   │   ├── test.h
    │   │   └── utils.h
    │   ├── setup.py
    │   └── src/
    │       ├── bindings.cpp
    │       └── octree.cpp
    └── sparse_voxels/
        ├── include/
        │   ├── cuda_utils.h
        │   ├── cutil_math.h
        │   ├── intersect.h
        │   ├── octree.h
        │   ├── sample.h
        │   └── utils.h
        ├── setup.py
        └── src/
            ├── binding.cpp
            ├── intersect.cpp
            ├── intersect_gpu.cu
            ├── octree.cpp
            ├── sample.cpp
            └── sample_gpu.cu
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__
*.egg-info
build/
dist/
logs/
.vscode/
results/
temp.sh
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 JunyuanDeng
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: Readme.md
================================================
# NeRF-LOAM: Neural Implicit Representation for Large-Scale Incremental LiDAR Odometry and Mapping
This repository contains the implementation of our paper:
> **NeRF-LOAM: Neural Implicit Representation for Large-Scale Incremental LiDAR Odometry and Mapping** ([PDF](https://arxiv.org/pdf/2303.10709))\
> [Junyuan Deng](https://github.com/JunyuanDeng), [Qi Wu](https://github.com/Gatsby23), [Xieyuanli Chen](https://github.com/Chen-Xieyuanli), Songpengcheng Xia, Zhen Sun, Guoqing Liu, Wenxian Yu and Ling Pei\
> If you use our code in your work, please star our repo and cite our paper.
```
@inproceedings{deng2023nerfloam,
  title={NeRF-LOAM: Neural Implicit Representation for Large-Scale Incremental LiDAR Odometry and Mapping},
  author={Junyuan Deng and Qi Wu and Xieyuanli Chen and Songpengcheng Xia and Zhen Sun and Guoqing Liu and Wenxian Yu and Ling Pei},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  year={2023}
}
```
<div align=center>
<img src="./docs/NeRFLOAM.gif">
</div>
- *Our incremental simultaneous odometry and mapping results on the Newer College dataset and KITTI sequence 00.*
- *The maps are dense meshes; the red line shows the odometry trajectory.*
- *We use the same network without pre-training to demonstrate the generalization ability of our design.*
## Overview

**Overview of our method.** Our method builds on our neural SDF and consists of three main components:
- Neural odometry takes the pre-processed scan and optimizes the pose by back-projecting the queried neural SDF;
- Neural mapping jointly optimizes the voxel embedding map and the poses while selecting key-scans;
- The key-scan-refined map returns SDF values, and the final mesh is reconstructed by marching cubes.
## Qualitative results
**The reconstructed maps**

*The qualitative results of our odometry and mapping on the KITTI dataset. From upper left to lower right, we list the results of sequences 00, 01, 03, 04, 05, 09, 10.*
**The odometry results**

*The qualitative results of our odometry on the KITTI dataset. From left to right, we list the results of sequences 00, 01, 03, 04, 05, 07, 09, 10. The dashed line corresponds to the ground truth and the blue line to our odometry method.*
## Data
1. Newer College real-world LiDAR dataset: [website](https://ori-drs.github.io/newer-college-dataset/download/).
2. MaiCity synthetic LiDAR dataset: [website](https://www.ipb.uni-bonn.de/data/mai-city-dataset/).
3. KITTI dataset: [website](https://www.cvlibs.net/datasets/kitti/).
## Environment Setup
To run the code, a GPU with large memory is preferred. We tested the code with an RTX 3090 and a GTX TITAN.
We use Conda to create a virtual environment and install dependencies:
- Python environment: we tested our code with Python 3.8.13.
- [PyTorch](https://pytorch.org/get-started/locally/): we tested version 1.10 with CUDA 10.2 (and CUDA 11.1).
- Other dependencies are specified in `requirements.txt`. You can install them all using `pip` or `conda`:
```
pip3 install -r requirements.txt
```
- After you have installed all third-party libraries, run the following script to build the extra PyTorch modules used in this project.
```bash
sh install.sh
```
- Replace the library path in `src/mapping.py` with the path of the built library:
```python
torch.classes.load_library("third_party/sparse_octree/build/lib.xxx/svo.xxx.so")
```
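Since the `build/lib.xxx` directory name depends on your platform and Python version, a small helper can locate the built library automatically (a sketch; `find_built_library` is a hypothetical helper assuming the default `setup.py install` build layout):

```python
import glob
import os

def find_built_library(root, name):
    """Search a setuptools build tree for a compiled extension.

    `root` is the third-party package directory and `name` the extension
    stem (e.g. "svo"); returns the first matching shared object, or None.
    """
    pattern = os.path.join(root, "build", "lib*", name + "*.so")
    matches = sorted(glob.glob(pattern))
    return matches[0] if matches else None

# Example (assumed usage):
# lib = find_built_library("third_party/sparse_octree", "svo")
# if lib is not None:
#     torch.classes.load_library(lib)
```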
- Install [patchwork-plusplus](https://github.com/url-kaist/patchwork-plusplus) to separate ground points from LiDAR scans.
- Replace the module path in `src/dataset/*.py` with the path of the built Python wrapper:
```python
patchwork_module_path ="/xxx/patchwork-plusplus/build/python_wrapper"
```
## Demo
- The full datasets can be downloaded as mentioned above. You can also download example subsets of the datasets using these [scripts](https://github.com/PRBonn/SHINE_mapping/tree/master/scripts).
- Take the MaiCity seq. 01 dataset as an example: modify `configs/maicity/maicity_01.yaml` so that the `data_path` entry points to the real dataset path. Then you are all set to run the code:
```
python demo/run.py configs/maicity/maicity_01.yaml
```
## Note
- For the KITTI dataset, if you want to process it faster, you can switch to the `subscene` branch:
```
git checkout subscene
```
- Then run `python demo/run.py configs/kitti/kitti_00.yaml`.
- This branch cuts the full scene into subscenes to speed up processing and concatenates them afterwards. This introduces some map inconsistency and degrades tracking accuracy.
## Evaluation
- We follow the evaluation protocol proposed [here](https://github.com/PRBonn/SHINE_mapping/tree/master/eval), but we do not use `crop_intersection.py`.
## Acknowledgement
Some of our code is adapted from [Vox-Fusion](https://github.com/zju3dv/Vox-Fusion).
## Contact
Any questions or suggestions are welcome!
Junyuan Deng: d.juney@sjtu.edu.cn and Xieyuanli Chen: xieyuanli.chen@nudt.edu.cn
## License
This project is free software made available under the MIT License. For details see the LICENSE file.
================================================
FILE: configs/kitti/kitti.yaml
================================================
log_dir: './logs'
decoder: lidar
dataset: kitti
criteria:
  sdf_weight: 10000.0
  fs_weight: 1
  eiko_weight: 0.1
  sdf_truncation: 0.30
decoder_specs:
  depth: 2
  width: 256
  in_dim: 16
  skips: []
  embedder: none
  multires: 0
tracker_specs:
  N_rays: 2048
  learning_rate: 0.06
  step_size: 0.2
  max_voxel_hit: 20
  num_iterations: 25
mapper_specs:
  N_rays_each: 2048
  use_local_coord: False
  voxel_size: 0.3
  step_size: 0.5
  window_size: 4
  num_iterations: 25
  max_voxel_hit: 20
  final_iter: True
  mesh_res: 2
  learning_rate_emb: 0.01
  learning_rate_decorder: 0.005
  learning_rate_pose: 0.001
  freeze_frame: 5
  keyframe_gap: 8
  remove_back: False
  key_distance: 12
debug_args:
  verbose: False
  mesh_freq: 100
================================================
FILE: configs/kitti/kitti_00.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sequence00
data_specs:
  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/00'
  use_gt: False
  max_depth: 40
  min_depth: 5
tracker_specs:
  start_frame: 0
  end_frame: -1
  read_offset: 1
================================================
FILE: configs/kitti/kitti_01.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sequence01
data_specs:
  data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/01/'
  use_gt: False
  max_depth: 30
  min_depth: 5
tracker_specs:
  start_frame: 0
  end_frame: 1101
  read_offset: 1
================================================
FILE: configs/kitti/kitti_03.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sequence03
data_specs:
  data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/03/'
  use_gt: False
  max_depth: 30
  min_depth: 5
tracker_specs:
  start_frame: 0
  end_frame: 1101
  read_offset: 1
================================================
FILE: configs/kitti/kitti_04.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sequence04
data_specs:
  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/04'
  use_gt: False
  max_depth: 50
  min_depth: 2.75
tracker_specs:
  start_frame: 0
  end_frame: 270
  read_offset: 1
================================================
FILE: configs/kitti/kitti_05.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sequence05
data_specs:
  data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/05/'
  use_gt: False
  max_depth: 50
  min_depth: 5
tracker_specs:
  start_frame: 2299
  end_frame: 2760
  read_offset: 1
================================================
FILE: configs/kitti/kitti_06.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sequence06
data_specs:
  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/06'
  use_gt: False
  max_depth: 40
  min_depth: 5
tracker_specs:
  start_frame: 0
  end_frame: -1
  read_offset: 1
================================================
FILE: configs/kitti/kitti_07.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sequence07
data_specs:
  data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/07'
  use_gt: True
  max_depth: 25
  min_depth: 5
tracker_specs:
  start_frame: 0
  end_frame: 1100
  read_offset: 1
================================================
FILE: configs/kitti/kitti_08.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sequence08
data_specs:
  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/08'
  use_gt: False
  max_depth: 40
  min_depth: 5
tracker_specs:
  start_frame: 0
  end_frame: -1
  read_offset: 1
================================================
FILE: configs/kitti/kitti_09.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sequence09
data_specs:
  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/09'
  use_gt: False
  max_depth: 40
  min_depth: 5
tracker_specs:
  start_frame: 0
  end_frame: -1
  read_offset: 1
================================================
FILE: configs/kitti/kitti_10.yaml
================================================
base_config: configs/kitti/kitti_base10.yaml
exp_name: kitti/sequence10
data_specs:
  data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/10'
  use_gt: False
  max_depth: 70
  min_depth: 2.75
tracker_specs:
  start_frame: 0
  end_frame: 1200
  read_offset: 1
================================================
FILE: configs/kitti/kitti_base06.yaml
================================================
log_dir: './logs'
decoder: lidar
dataset: kitti
criteria:
  sdf_weight: 10000.0
  fs_weight: 1
  eiko_weight: 0.1
  sdf_truncation: 0.30
decoder_specs:
  depth: 2
  width: 256
  in_dim: 16
  skips: []
  embedder: none
  multires: 0
tracker_specs:
  N_rays: 2048
  learning_rate: 0.06
  step_size: 0.2
  max_voxel_hit: 20
  num_iterations: 25
mapper_specs:
  N_rays_each: 2048
  use_local_coord: False
  voxel_size: 0.3
  step_size: 0.5
  window_size: 4
  num_iterations: 25
  max_voxel_hit: 20
  final_iter: True
  mesh_res: 2
  learning_rate_emb: 0.01
  learning_rate_decorder: 0.005
  learning_rate_pose: 0.001
  freeze_frame: 5
  keyframe_gap: 8
  remove_back: False
  key_distance: 12
debug_args:
  verbose: False
  mesh_freq: 100
================================================
FILE: configs/kitti/kitti_base10.yaml
================================================
log_dir: '/home/evsjtu2/disk1/dengjunyuan/running_logs/'
decoder: lidar
dataset: kitti
criteria:
  depth_weight: 0
  sdf_weight: 12000.0
  fs_weight: 1
  eiko_weight: 0
  sdf_truncation: 0.50
decoder_specs:
  depth: 2
  width: 256
  in_dim: 16
  skips: []
  embedder: none
  multires: 0
tracker_specs:
  N_rays: 2048
  learning_rate: 0.1
  start_frame: 0
  end_frame: -1
  step_size: 0.2
  show_imgs: False
  max_voxel_hit: 20
  keyframe_freq: 10
  num_iterations: 40
mapper_specs:
  N_rays_each: 2048
  num_embeddings: 20000000
  use_local_coord: False
  voxel_size: 0.2
  step_size: 0.2
  window_size: 4
  num_iterations: 20
  max_voxel_hit: 20
  final_iter: True
  mesh_res: 2
  overlap_th: 0.8
  learning_rate_emb: 0.03
  learning_rate_decorder: 0.005
  learning_rate_pose: 0.001
  # max_depth_first: 20
  freeze_frame: 10
  keyframe_gap: 7
  remove_back: True
  key_distance: 7
debug_args:
  verbose: False
  mesh_freq: 100
================================================
FILE: configs/maicity/maicity.yaml
================================================
log_dir: './logs'
decoder: lidar
dataset: maicity
criteria:
  sdf_weight: 10000.0
  fs_weight: 1
  eiko_weight: 0.1
  sdf_truncation: 0.30
decoder_specs:
  depth: 2
  width: 256
  in_dim: 16
  skips: []
  embedder: none
  multires: 0
tracker_specs:
  N_rays: 2048
  learning_rate: 0.005
  step_size: 0.2
  max_voxel_hit: 20
  num_iterations: 20
mapper_specs:
  N_rays_each: 2048
  use_local_coord: False
  voxel_size: 0.2
  step_size: 0.5
  window_size: 4
  num_iterations: 20
  max_voxel_hit: 20
  final_iter: True
  mesh_res: 2
  learning_rate_emb: 0.03
  learning_rate_decorder: 0.005
  learning_rate_pose: 0.001
  freeze_frame: 5
  keyframe_gap: 8
  remove_back: False
  key_distance: 12
debug_args:
  verbose: False
  mesh_freq: 100
================================================
FILE: configs/maicity/maicity_00.yaml
================================================
base_config: configs/maicity/maicity.yaml
exp_name: maicity/sequence00
data_specs:
  data_path: '/home/pl21n4/dataset/mai_city/bin/sequences/00'
  use_gt: False
  max_depth: 50.0
  min_depth: 1.5
tracker_specs:
  start_frame: 0
  end_frame: 699
  read_offset: 1
================================================
FILE: configs/maicity/maicity_01.yaml
================================================
base_config: configs/maicity/maicity.yaml
exp_name: maicity/sequence01
data_specs:
  data_path: '/home/pl21n4/dataset/mai_city/bin/sequences/01'
  use_gt: False
  max_depth: 50.0
  min_depth: 1.5
tracker_specs:
  start_frame: 0
  end_frame: 99
  read_offset: 1
================================================
FILE: configs/ncd/ncd.yaml
================================================
log_dir: './logs'
decoder: lidar
dataset: ncd
criteria:
  sdf_weight: 10000.0
  fs_weight: 1
  eiko_weight: 1.0
  sdf_truncation: 0.30
decoder_specs:
  depth: 2
  width: 256
  in_dim: 16
  skips: []
  embedder: none
  multires: 0
tracker_specs:
  N_rays: 2048
  learning_rate: 0.04
  step_size: 0.1
  max_voxel_hit: 20
  num_iterations: 30
mapper_specs:
  N_rays_each: 2048
  use_local_coord: False
  voxel_size: 0.2
  step_size: 0.2
  window_size: 5
  num_iterations: 15
  max_voxel_hit: 20
  final_iter: True
  mesh_res: 2
  learning_rate_emb: 0.002
  learning_rate_decorder: 0.005
  learning_rate_pose: 0.001
  freeze_frame: 20
  keyframe_gap: 8
  remove_back: False
  key_distance: 20
debug_args:
  verbose: False
  mesh_freq: 500
================================================
FILE: configs/ncd/ncd_quad.yaml
================================================
base_config: configs/ncd/ncd.yaml
exp_name: ncd/quad
data_specs:
  data_path: '/home/pl21n4/dataset/ncd_example/quad'
  use_gt: False
  max_depth: 50
  min_depth: 1.5
tracker_specs:
  start_frame: 0
  end_frame: -1
  read_offset: 5
================================================
FILE: demo/parser.py
================================================
import yaml
import argparse


class ArgumentParserX(argparse.ArgumentParser):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_argument("config", type=str)

    def parse_args(self, args=None, namespace=None):
        _args = self.parse_known_args(args, namespace)[0]
        file_args = argparse.Namespace()
        file_args = self.parse_config_yaml(_args.config, file_args)
        file_args = self.convert_to_namespace(file_args)
        for ckey, cvalue in file_args.__dict__.items():
            try:
                self.add_argument('--' + ckey, type=type(cvalue),
                                  default=cvalue, required=False)
            except argparse.ArgumentError:
                continue
        _args = super().parse_args(args, namespace)
        return _args

    def parse_config_yaml(self, yaml_path, args=None):
        with open(yaml_path, 'r') as f:
            configs = yaml.load(f, Loader=yaml.FullLoader)
        if configs is not None:
            base_config = configs.get('base_config')
            if base_config is not None:
                base_config = self.parse_config_yaml(configs["base_config"])
                if base_config is not None:
                    configs = self.update_recursive(base_config, configs)
                else:
                    raise FileNotFoundError("base_config specified but not found!")
        return configs

    def convert_to_namespace(self, dict_in, args=None):
        if args is None:
            args = argparse.Namespace()
        for ckey, cvalue in dict_in.items():
            if ckey not in args.__dict__.keys():
                args.__dict__[ckey] = cvalue
        return args

    def update_recursive(self, dict1, dict2):
        for k, v in dict2.items():
            if k not in dict1:
                dict1[k] = dict()
            if isinstance(v, dict):
                self.update_recursive(dict1[k], v)
            else:
                dict1[k] = v
        return dict1


def get_parser():
    parser = ArgumentParserX()
    parser.add_argument("--resume", default=None, type=str)
    parser.add_argument("--debug", action='store_true')
    return parser


if __name__ == '__main__':
    args = ArgumentParserX()
    print(args.parse_args())
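The `base_config` inheritance implemented by `parse_config_yaml` reduces to the recursive dictionary merge in `update_recursive`: the child config overrides leaf values of the base config while untouched base entries survive. A standalone sketch of that merge (names are illustrative, not part of the repo):

```python
def merge_configs(base, override):
    """Recursively merge `override` into `base`.

    Nested dicts are merged key by key; scalar values in `override`
    replace those in `base`.
    """
    for key, value in override.items():
        if isinstance(value, dict):
            base.setdefault(key, {})
            merge_configs(base[key], value)
        else:
            base[key] = value
    return base

# A sequence config overrides only a few fields of its base config:
base = {"mapper_specs": {"voxel_size": 0.3, "window_size": 4}}
child = {"mapper_specs": {"voxel_size": 0.2}, "exp_name": "kitti/00"}
merged = merge_configs(base, child)
# merged keeps window_size from the base and takes voxel_size from the child
```

This mirrors why, e.g., `kitti_00.yaml` only needs to state `data_specs` and `tracker_specs` overrides on top of `kitti.yaml`.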
================================================
FILE: demo/run.py
================================================
import os  # noqa
import sys  # noqa
sys.path.insert(0, os.path.abspath('src'))  # noqa
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import random
from parser import get_parser
import numpy as np
import torch
from nerfloam import nerfloam
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


if __name__ == '__main__':
    args = get_parser().parse_args()
    if hasattr(args, 'seeding'):
        setup_seed(args.seeding)
    else:
        setup_seed(777)
    slam = nerfloam(args)
    slam.start()
    slam.wait_child_processes()
================================================
FILE: install.sh
================================================
#!/bin/bash
cd third_party/marching_cubes
python setup.py install
cd ../sparse_octree
python setup.py install
cd ../sparse_voxels
python setup.py install
================================================
FILE: requirements.txt
================================================
matplotlib
open3d
opencv-python
PyYAML
scikit-image
tqdm
trimesh
================================================
FILE: src/criterion.py
================================================
import torch
import torch.nn as nn
from torch.autograd import grad


class Criterion(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.args = args
        self.eiko_weight = args.criteria["eiko_weight"]
        self.sdf_weight = args.criteria["sdf_weight"]
        self.fs_weight = args.criteria["fs_weight"]
        self.truncation = args.criteria["sdf_truncation"]
        self.max_depth = args.data_specs["max_depth"]

    def forward(self, outputs, obs, pointsCos, use_color_loss=True,
                use_depth_loss=True, compute_sdf_loss=True,
                weight_depth_loss=False, compute_eikonal_loss=False):
        points = obs
        loss = 0
        loss_dict = {}
        # pred_depth = outputs["depth"]
        pred_sdf = outputs["sdf"]
        z_vals = outputs["z_vals"]
        ray_mask = outputs["ray_mask"]
        valid_mask = outputs["valid_mask"]
        sampled_xyz = outputs["sampled_xyz"]
        gt_points = points[ray_mask]
        pointsCos = pointsCos[ray_mask]
        gt_distance = torch.norm(gt_points, 2, -1)
        gt_distance = gt_distance * pointsCos.view(-1)
        z_vals = z_vals * pointsCos.view(-1, 1)
        if compute_sdf_loss:
            fs_loss, sdf_loss, eikonal_loss = self.get_sdf_loss(
                z_vals, gt_distance, pred_sdf,
                truncation=self.truncation,
                loss_type='l2',
                valid_mask=valid_mask,
                compute_eikonal_loss=compute_eikonal_loss,
                points=sampled_xyz if compute_eikonal_loss else None
            )
            loss += self.fs_weight * fs_loss
            loss += self.sdf_weight * sdf_loss
            # loss += self.bs_weight * back_loss
            loss_dict["fs_loss"] = fs_loss.item()
            # loss_dict["bs_loss"] = back_loss.item()
            loss_dict["sdf_loss"] = sdf_loss.item()
            if compute_eikonal_loss:
                loss += self.eiko_weight * eikonal_loss
                loss_dict["eiko_loss"] = eikonal_loss.item()
        loss_dict["loss"] = loss.item()
        return loss, loss_dict

    def compute_loss(self, x, y, mask=None, loss_type="l2"):
        if mask is None:
            mask = torch.ones_like(x).bool()
        if loss_type == "l1":
            return torch.mean(torch.abs(x - y)[mask])
        elif loss_type == "l2":
            return torch.mean(torch.square(x - y)[mask])

    def get_masks(self, z_vals, depth, epsilon):
        front_mask = torch.where(
            z_vals < (depth - epsilon),
            torch.ones_like(z_vals),
            torch.zeros_like(z_vals),
        )
        back_mask = torch.where(
            z_vals > (depth + epsilon),
            torch.ones_like(z_vals),
            torch.zeros_like(z_vals),
        )
        depth_mask = torch.where(
            (depth > 0.0) & (depth < self.max_depth),
            torch.ones_like(depth), torch.zeros_like(depth)
        )
        sdf_mask = (1.0 - front_mask) * (1.0 - back_mask) * depth_mask
        num_fs_samples = torch.count_nonzero(front_mask).float()
        num_sdf_samples = torch.count_nonzero(sdf_mask).float()
        num_samples = num_sdf_samples + num_fs_samples
        fs_weight = 1.0 - num_fs_samples / num_samples
        sdf_weight = 1.0 - num_sdf_samples / num_samples
        return front_mask, sdf_mask, fs_weight, sdf_weight

    def get_sdf_loss(self, z_vals, depth, predicted_sdf, truncation,
                     valid_mask, loss_type="l2",
                     compute_eikonal_loss=False, points=None):
        front_mask, sdf_mask, fs_weight, sdf_weight = self.get_masks(
            z_vals, depth.unsqueeze(-1).expand(*z_vals.shape), truncation
        )
        fs_loss = (self.compute_loss(
            predicted_sdf * front_mask * valid_mask,
            torch.ones_like(predicted_sdf) * front_mask,
            loss_type=loss_type) * fs_weight)
        sdf_loss = (self.compute_loss(
            (z_vals + predicted_sdf * truncation) * sdf_mask * valid_mask,
            depth.unsqueeze(-1).expand(*z_vals.shape) * sdf_mask,
            loss_type=loss_type) * sdf_weight)
        # back_loss = (self.compute_loss(predicted_sdf * back_mask,
        #     -torch.ones_like(predicted_sdf) * back_mask,
        #     loss_type=loss_type) * back_weight)
        eikonal_loss = None
        if compute_eikonal_loss:
            sdf = (predicted_sdf * sdf_mask * truncation)
            sdf = sdf[valid_mask]
            d_points = torch.ones_like(sdf, requires_grad=False,
                                       device=sdf.device)
            sdf_grad = grad(outputs=sdf,
                            inputs=points,
                            grad_outputs=d_points,
                            retain_graph=True,
                            only_inputs=True)[0]
            eikonal_loss = self.compute_loss(
                sdf_grad[0].norm(2, -1), 1.0, loss_type=loss_type)
        return fs_loss, sdf_loss, eikonal_loss
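The mask logic in `get_masks` partitions the samples along each ray into free space (more than `truncation` in front of the measured surface), the region behind it, and a truncation band around the surface; only the band contributes to the SDF loss, while free-space samples are pushed toward SDF = 1. A toy NumPy sketch of the same partition (illustrative only; the real code also applies a depth-validity mask):

```python
import numpy as np

def sdf_masks(z_vals, depth, truncation):
    """Split ray samples into free-space and truncation-band masks.

    z_vals: (N, S) sample distances along each ray
    depth:  (N, 1) measured range of each ray
    """
    front = z_vals < depth - truncation   # free space before the surface
    back = z_vals > depth + truncation    # occluded region behind it
    band = ~front & ~back                 # truncation band near the surface
    return front, band

# One ray with measured depth 5.0 and a 0.5 m truncation band:
z = np.array([[1.0, 4.7, 5.0, 5.4, 9.0]])
front, band = sdf_masks(z, np.array([[5.0]]), truncation=0.5)
# front marks only the 1.0 sample; band marks 4.7, 5.0 and 5.4
```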
================================================
FILE: src/dataset/kitti.py
================================================
import os.path as osp
import numpy as np
import torch
from glob import glob
from torch.utils.data import Dataset
import sys
from scipy.spatial import cKDTree

patchwork_module_path = "/home/pl21n4/Programmes/patchwork-plusplus/build/python_wrapper"
sys.path.insert(0, patchwork_module_path)
import pypatchworkpp

params = pypatchworkpp.Parameters()
# params.verbose = True
PatchworkPLUSPLUS = pypatchworkpp.patchworkpp(params)


class DataLoader(Dataset):
    def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1) -> None:
        self.data_path = data_path
        self.num_bin = len(glob(osp.join(self.data_path, "velodyne/*.bin")))
        self.use_gt = use_gt
        self.max_depth = max_depth
        self.min_depth = min_depth
        self.gt_pose = self.load_gt_pose() if use_gt else None

    def get_init_pose(self, frame):
        if self.gt_pose is not None:
            return np.concatenate((self.gt_pose[frame],
                                   [0, 0, 0, 1])).reshape(4, 4)
        else:
            return np.eye(4)

    def load_gt_pose(self):
        gt_file = osp.join(self.data_path, "poses_lidar.txt")
        gt_pose = np.loadtxt(gt_file)
        return gt_pose

    def load_points(self, index):
        remove_abnormal_z = True
        path = osp.join(self.data_path, "velodyne/{:06d}.bin".format(index))
        points = np.fromfile(path, dtype=np.float32, count=-1).reshape([-1, 4])
        if remove_abnormal_z:
            points = points[points[:, 2] > -3.0]
        points_norm = np.linalg.norm(points[:, :3], axis=-1)
        point_mask = True
        if self.max_depth != -1:
            point_mask = (points_norm < self.max_depth) & point_mask
        if self.min_depth != -1:
            point_mask = (points_norm > self.min_depth) & point_mask
        if isinstance(point_mask, np.ndarray):
            points = points[point_mask]
        PatchworkPLUSPLUS.estimateGround(points)
        ground = PatchworkPLUSPLUS.getGround()
        nonground = PatchworkPLUSPLUS.getNonground()
        Patchcenters = PatchworkPLUSPLUS.getCenters()
        normals = PatchworkPLUSPLUS.getNormals()
        # weight each ground point by |cos| between its ray and the
        # normal of the nearest patch center
        T = cKDTree(Patchcenters)
        _, nearest = T.query(ground)
        groundcos = np.abs(np.sum(normals[nearest] * ground, axis=-1)
                           / np.linalg.norm(ground, axis=-1))
        points = np.concatenate((ground, nonground), axis=0)
        pointcos = np.concatenate((groundcos,
                                   np.ones(nonground.shape[0])), axis=0)
        return points, pointcos

    def __len__(self):
        return self.num_bin

    def __getitem__(self, index):
        points, pointcos = self.load_points(index)
        points = torch.from_numpy(points).float()
        pointcos = torch.from_numpy(pointcos).float()
        pose = np.concatenate((self.gt_pose[index],
                               [0, 0, 0, 1])).reshape(4, 4) \
            if self.use_gt else None
        return index, points, pointcos, pose


if __name__ == "__main__":
    path = "/home/pl21n4/dataset/kitti/dataset/sequences/00/"
    loader = DataLoader(path)
    for data in loader:
        index, points, pointcos, pose = data
        print("current index ", index)
        print("first 10 points:\n", points[:10])
        if index > 10:
            break
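The ground-weighting step in `load_points` above can be illustrated standalone: each ground point gets a weight equal to the absolute cosine between its viewing ray (from the sensor origin) and the normal of the nearest Patchwork++ patch. A brute-force NumPy sketch with toy data (no patchwork dependency; the nearest-neighbor search stands in for the `cKDTree` query):

```python
import numpy as np

def ground_cosine_weights(ground, centers, normals):
    """Weight each ground point by |cos| of the angle between its ray
    and the normal of the nearest patch center."""
    # nearest patch center for every ground point (brute force, O(N*M))
    d2 = ((ground[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
    nearest = d2.argmin(axis=1)
    # |n . p| / |p| == |cos(angle between ray and normal)|
    dots = np.abs((normals[nearest] * ground).sum(-1))
    return dots / np.linalg.norm(ground, axis=-1)

# A point straight below the sensor on flat ground gets weight 1,
# a grazing point far ahead gets a weight close to 0:
centers = np.array([[0.0, 0.0, -2.0]])
normals = np.array([[0.0, 0.0, 1.0]])
ground = np.array([[0.0, 0.0, -2.0], [20.0, 0.0, -2.0]])
w = ground_cosine_weights(ground, centers, normals)
```

Down-weighting grazing ground returns in this way reduces the influence of unreliable, near-tangential ground measurements on the SDF loss.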
================================================
FILE: src/dataset/maicity.py
================================================
import os.path as osp
import numpy as np
import torch
from glob import glob
from torch.utils.data import Dataset
import sys
from scipy.spatial import cKDTree

patchwork_module_path = "/home/pl21n4/Programmes/patchwork-plusplus/build/python_wrapper"
sys.path.insert(0, patchwork_module_path)
import pypatchworkpp

params = pypatchworkpp.Parameters()
# params.verbose = True
PatchworkPLUSPLUS = pypatchworkpp.patchworkpp(params)


class DataLoader(Dataset):
    def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1) -> None:
        self.data_path = data_path
        self.num_bin = len(glob(osp.join(self.data_path, "velodyne/*.bin")))
        self.use_gt = use_gt
        self.max_depth = max_depth
        self.min_depth = min_depth
        self.gt_pose = self.load_gt_pose() if use_gt else None

    def get_init_pose(self, frame):
        if self.gt_pose is not None:
            return np.concatenate((self.gt_pose[frame],
                                   [0, 0, 0, 1])).reshape(4, 4)
        else:
            return np.eye(4)

    def load_gt_pose(self):
        gt_file = osp.join(self.data_path, "poses.txt")
        gt_pose = np.loadtxt(gt_file)
        return gt_pose

    def load_points(self, index):
        path = osp.join(self.data_path, "velodyne/{:05d}.bin".format(index))
        points = np.fromfile(path, dtype=np.float32, count=-1).reshape([-1, 4])
        points_norm = np.linalg.norm(points[:, :3], axis=-1)
        point_mask = True
        if self.max_depth != -1:
            point_mask = (points_norm < self.max_depth) & point_mask
        if self.min_depth != -1:
            point_mask = (points_norm > self.min_depth) & point_mask
        if isinstance(point_mask, np.ndarray):
            points = points[point_mask]
        PatchworkPLUSPLUS.estimateGround(points)
        ground = PatchworkPLUSPLUS.getGround()
        nonground = PatchworkPLUSPLUS.getNonground()
        Patchcenters = PatchworkPLUSPLUS.getCenters()
        normals = PatchworkPLUSPLUS.getNormals()
        # weight each ground point by |cos| between its ray and the
        # normal of the nearest patch center
        T = cKDTree(Patchcenters)
        _, nearest = T.query(ground)
        groundcos = np.abs(np.sum(normals[nearest] * ground, axis=-1)
                           / np.linalg.norm(ground, axis=-1))
        # groundnorm = np.linalg.norm(ground, axis=-1)
        # groundcos = np.where(groundnorm > 10.0, np.ones(ground.shape[0]), groundcos)
        points = np.concatenate((ground, nonground), axis=0)
        pointcos = np.concatenate((groundcos,
                                   np.ones(nonground.shape[0])), axis=0)
        return points, pointcos

    def __len__(self):
        return self.num_bin

    def __getitem__(self, index):
        points, pointcos = self.load_points(index)
        points = torch.from_numpy(points).float()
        pointcos = torch.from_numpy(pointcos).float()
        pose = np.concatenate((self.gt_pose[index],
                               [0, 0, 0, 1])).reshape(4, 4) \
            if self.use_gt else None
        return index, points, pointcos, pose


if __name__ == "__main__":
    path = "/home/pl21n4/dataset/mai_city/bin/sequences/00/"
    loader = DataLoader(path)
    for data in loader:
        index, points, pointcos, pose = data
        print("current index ", index)
        print("first 10 points:\n", points[:10])
        if index > 10:
            break
================================================
FILE: src/dataset/ncd.py
================================================
import os.path as osp
import numpy as np
import torch
import open3d as o3d
from glob import glob
from torch.utils.data import Dataset
import sys
from scipy.spatial import cKDTree
patchwork_module_path ="/home/pl21n4/Programmes/patchwork-plusplus/build/python_wrapper"
sys.path.insert(0, patchwork_module_path)
import pypatchworkpp
params = pypatchworkpp.Parameters()
params.enable_RNR = False
# params.verbose = True
PatchworkPLUSPLUS = pypatchworkpp.patchworkpp(params)
class DataLoader(Dataset):
def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1) -> None:
self.data_path = data_path
self.num_bin = len(glob(osp.join(self.data_path, "pcd/*.pcd")))
self.use_gt = use_gt
self.max_depth = max_depth
self.min_depth = min_depth
self.gt_pose = self.load_gt_pose() if use_gt else None
def get_init_pose(self, frame):
if self.gt_pose is not None:
return np.concatenate((self.gt_pose[frame], [0, 0, 0, 1])
).reshape(4, 4)
else:
return np.array([[5.925493285036220747e-01, -8.038419275143061649e-01, 5.218676416200035417e-02, -2.422443415414985424e-01],
[8.017167514002809803e-01, 5.948020209102693467e-01, 5.882863457495644127e-02, 3.667865561670570873e+00],
[-7.832971094540422397e-02, 6.980134849334420320e-03, 9.969030746023688216e-01, 6.809443654823238434e-01]])
def load_gt_pose(self):
gt_file = osp.join(self.data_path, "poses.txt")
gt_pose = np.loadtxt(gt_file)
# with open(gt_file, mode='r', encoding="utf-8") as g:
# line = g.readline()
# while line:
# # TODO:write transfomation of kitti and pose matrix
# pose = np.zeros(16)
return gt_pose
def load_points(self, index):
path = osp.join(self.data_path, "pcd/{:05d}.pcd".format(index+500))
pc_load = o3d.io.read_point_cloud(path)
points = np.asarray(pc_load.points)
points_norm = np.linalg.norm(points, axis=-1)
point_mask = np.ones(points_norm.shape, dtype=bool)
if self.max_depth != -1:
point_mask = (points_norm < self.max_depth) & point_mask
if self.min_depth != -1:
point_mask = (points_norm > self.min_depth) & point_mask
if isinstance(point_mask, np.ndarray):
points = points[point_mask]
PatchworkPLUSPLUS.estimateGround(points)
ground = PatchworkPLUSPLUS.getGround()
nonground = PatchworkPLUSPLUS.getNonground()
Patchcenters = PatchworkPLUSPLUS.getCenters()
normals = PatchworkPLUSPLUS.getNormals()
T = cKDTree(Patchcenters)
_, index = T.query(ground)
if True:
groundcos = np.abs(np.sum(normals[index] * ground, axis=-1)/np.linalg.norm(ground, axis=-1))
else:
groundcos = np.ones(ground.shape[0])
points = np.concatenate((ground, nonground), axis=0)
pointcos = np.concatenate((groundcos, np.ones(nonground.shape[0])), axis=0)
return points, pointcos
def __len__(self):
return self.num_bin
def __getitem__(self, index):
points, pointcos = self.load_points(index)
points = torch.from_numpy(points).float()
pointcos = torch.from_numpy(pointcos).float()
pose = np.concatenate((self.gt_pose[index], [0, 0, 0, 1])
).reshape(4, 4) if self.use_gt else None
return index, points, pointcos, pose
if __name__ == "__main__":
path = "/home/pl21n4/dataset/kitti/dataset/sequences/00/"
loader = DataLoader(path)
for data in loader:
index, points, pointcos, pose = data
print("current index ", index)
print("first 10 points:\n", points[:10])
if index > 10:
break
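The range filter in `load_points` above (min/max depth bounds on the point norms) can be sketched standalone. A minimal numpy-only version; the `depth_mask` helper name is illustrative and not part of the repo:

```python
import numpy as np

def depth_mask(points, min_depth=-1.0, max_depth=-1.0):
    # Keep points whose Euclidean range lies within (min_depth, max_depth);
    # a value of -1 disables the corresponding bound, as in DataLoader.
    norms = np.linalg.norm(points, axis=-1)
    mask = np.ones(points.shape[0], dtype=bool)
    if max_depth != -1:
        mask &= norms < max_depth
    if min_depth != -1:
        mask &= norms > min_depth
    return points[mask]

pts = np.array([[0.5, 0.0, 0.0],     # too close
                [10.0, 0.0, 0.0],    # in range
                [100.0, 0.0, 0.0]])  # too far
filtered = depth_mask(pts, min_depth=1.0, max_depth=50.0)
```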
================================================
FILE: src/lidarFrame.py
================================================
import torch
import torch.nn as nn
import numpy as np
from se3pose import OptimizablePose
from utils.sample_util import *
import random
class LidarFrame(nn.Module):
def __init__(self, index, points, pointsCos, pose=None, new_keyframe=False) -> None:
super().__init__()
self.index = index
self.num_point = len(points)
self.points = points
self.pointsCos = pointsCos
if (not new_keyframe) and (pose is not None):
# TODO: fix this offset
pose[:3, 3] += 2000
pose = torch.tensor(pose, requires_grad=True, dtype=torch.float32)
self.pose = OptimizablePose.from_matrix(pose)
elif new_keyframe:
self.pose = pose
self.rays_d = self.get_rays()
self.rel_pose = None
def get_pose(self):
return self.pose.matrix()
def get_translation(self):
return self.pose.translation()
def get_rotation(self):
return self.pose.rotation()
def get_points(self):
return self.points
def get_pointsCos(self):
return self.pointsCos
def set_rel_pose(self, rel_pose):
self.rel_pose = rel_pose
def get_rel_pose(self):
return self.rel_pose
@torch.no_grad()
def get_rays(self):
self.rays_norm = (torch.norm(self.points, 2, -1, keepdim=True)+1e-8)
rays_d = self.points / self.rays_norm
# TODO: kept for consistency (adds a singleton dim, not strictly needed)
return rays_d.unsqueeze(1).float()
@torch.no_grad()
def sample_rays(self, N_rays, track=False):
self.sample_mask = sample_rays(
torch.ones((self.num_point, 1))[None, ...], N_rays)[0, ...]
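`get_rays` factors each LiDAR point into a unit direction and a range (the per-ray depth kept in `rays_norm`). The same decomposition in a minimal numpy sketch:

```python
import numpy as np

# Each LiDAR return satisfies p = depth * direction; recover both factors
# exactly as LidarFrame.get_rays does (with the same 1e-8 stabilizer).
points = np.array([[3.0, 4.0, 0.0],
                   [0.0, 0.0, 2.0]])
rays_norm = np.linalg.norm(points, axis=-1, keepdims=True) + 1e-8  # depths
rays_d = points / rays_norm                                        # unit dirs
```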
================================================
FILE: src/loggers.py
================================================
import os
import os.path as osp
import pickle
from datetime import datetime
import cv2
import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import torch
import yaml
class BasicLogger:
def __init__(self, args) -> None:
self.args = args
self.log_dir = osp.join(
args.log_dir, args.exp_name, self.get_random_time_str())
self.img_dir = osp.join(self.log_dir, "imgs")
self.mesh_dir = osp.join(self.log_dir, "mesh")
self.ckpt_dir = osp.join(self.log_dir, "ckpt")
self.backup_dir = osp.join(self.log_dir, "bak")
self.misc_dir = osp.join(self.log_dir, "misc")
os.makedirs(self.img_dir)
os.makedirs(self.ckpt_dir)
os.makedirs(self.mesh_dir)
os.makedirs(self.misc_dir)
os.makedirs(self.backup_dir)
self.log_config(args)
def get_random_time_str(self):
return datetime.strftime(datetime.now(), "%Y-%m-%d-%H-%M-%S")
def log_ckpt(self, mapper):
print("******* saving *******")
decoder_state = {f: v.cpu()
for f, v in mapper.decoder.state_dict().items()}
map_state = {f: v.cpu() for f, v in mapper.map_states.items()}
embeddings = mapper.dynamic_embeddings.cpu()
svo = mapper.svo
torch.save({
"decoder_state": decoder_state,
# "map_state": map_state,
"embeddings": embeddings,
"svo": svo
},
os.path.join(self.ckpt_dir, "final_ckpt.pth"))
print("******* finish saving *******")
def log_config(self, config):
out_path = osp.join(self.backup_dir, "config.yaml")
yaml.dump(config, open(out_path, 'w'))
def log_mesh(self, mesh, name="final_mesh.ply"):
out_path = osp.join(self.mesh_dir, name)
o3d.io.write_triangle_mesh(out_path, mesh)
def log_point_cloud(self, pcd, name="final_points.ply"):
out_path = osp.join(self.mesh_dir, name)
o3d.io.write_point_cloud(out_path, pcd)
def log_numpy_data(self, data, name, ind=None):
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()
if ind is not None:
np.save(osp.join(self.misc_dir, "{}-{:05d}.npy".format(name, ind)), data)
else:
np.save(osp.join(self.misc_dir, f"{name}.npy"), data)
self.npy2txt(osp.join(self.misc_dir, f"{name}.npy"), osp.join(self.misc_dir, f"{name}.txt"))
def log_debug_data(self, data, idx):
with open(os.path.join(self.misc_dir, f"scene_data_{idx}.pkl"), 'wb') as f:
pickle.dump(data, f)
def log_raw_image(self, ind, rgb, depth):
if isinstance(rgb, torch.Tensor):
rgb = rgb.detach().cpu().numpy()
if isinstance(depth, torch.Tensor):
depth = depth.detach().cpu().numpy()
rgb = cv2.cvtColor(rgb*255, cv2.COLOR_RGB2BGR)
cv2.imwrite(osp.join(self.img_dir, "{:05d}.jpg".format(
ind)), (rgb).astype(np.uint8))
cv2.imwrite(osp.join(self.img_dir, "{:05d}.png".format(
ind)), (depth*5000).astype(np.uint16))
def log_images(self, ind, gt_rgb, gt_depth, rgb, depth):
gt_depth_np = gt_depth.detach().cpu().numpy()
gt_color_np = gt_rgb.detach().cpu().numpy()
depth_np = depth.squeeze().detach().cpu().numpy()
color_np = rgb.detach().cpu().numpy()
h, w = depth_np.shape
gt_depth_np = cv2.resize(
gt_depth_np, (w, h), interpolation=cv2.INTER_NEAREST)
gt_color_np = cv2.resize(
gt_color_np, (w, h), interpolation=cv2.INTER_AREA)
depth_residual = np.abs(gt_depth_np - depth_np)
depth_residual[gt_depth_np == 0.0] = 0.0
color_residual = np.abs(gt_color_np - color_np)
color_residual[gt_depth_np == 0.0] = 0.0
fig, axs = plt.subplots(2, 3)
fig.tight_layout()
max_depth = np.max(gt_depth_np)
axs[0, 0].imshow(gt_depth_np, cmap="plasma",
vmin=0, vmax=max_depth)
axs[0, 0].set_title('Input Depth')
axs[0, 0].set_xticks([])
axs[0, 0].set_yticks([])
axs[0, 1].imshow(depth_np, cmap="plasma",
vmin=0, vmax=max_depth)
axs[0, 1].set_title('Generated Depth')
axs[0, 1].set_xticks([])
axs[0, 1].set_yticks([])
axs[0, 2].imshow(depth_residual, cmap="plasma",
vmin=0, vmax=max_depth)
axs[0, 2].set_title('Depth Residual')
axs[0, 2].set_xticks([])
axs[0, 2].set_yticks([])
gt_color_np = np.clip(gt_color_np, 0, 1)
color_np = np.clip(color_np, 0, 1)
color_residual = np.clip(color_residual, 0, 1)
axs[1, 0].imshow(gt_color_np, cmap="plasma")
axs[1, 0].set_title('Input RGB')
axs[1, 0].set_xticks([])
axs[1, 0].set_yticks([])
axs[1, 1].imshow(color_np, cmap="plasma")
axs[1, 1].set_title('Generated RGB')
axs[1, 1].set_xticks([])
axs[1, 1].set_yticks([])
axs[1, 2].imshow(color_residual, cmap="plasma")
axs[1, 2].set_title('RGB Residual')
axs[1, 2].set_xticks([])
axs[1, 2].set_yticks([])
plt.subplots_adjust(wspace=0, hspace=0)
plt.savefig(osp.join(self.img_dir, "{:05d}.jpg".format(
ind)), bbox_inches='tight', pad_inches=0.2)
plt.clf()
plt.close()
def npy2txt(self, input_path, output_path):
poses = np.load(input_path)
with open(output_path, mode='w') as w:
shape = poses.shape
print(shape)
for i in range(shape[0]):
one_pose = str()
for j in range(shape[1]):
if j == (shape[1]-1):
continue
for k in range(shape[2]):
if j == (shape[1]-2) and k == (shape[2]-1):
one_pose += (str(poses[i][j][k])+"\n")
else:
one_pose += (str(poses[i][j][k])+" ")
w.write(one_pose)
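`npy2txt` above writes the top 3×4 block of each 4×4 pose as one whitespace-separated line of 12 values (the KITTI odometry pose format). The nested loops are equivalent to this vectorized sketch (the helper name is illustrative, not part of the repo):

```python
import numpy as np

def poses_to_kitti_lines(poses):
    # Drop the constant [0, 0, 0, 1] bottom row, then flatten row-major:
    # 12 values per pose, one pose per output line.
    flat = poses[:, :3, :].reshape(len(poses), 12)
    return [" ".join(str(v) for v in row) for row in flat]

poses = np.tile(np.eye(4), (2, 1, 1))  # two identity poses
lines = poses_to_kitti_lines(poses)
```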
================================================
FILE: src/mapping.py
================================================
from copy import deepcopy
import random
from time import sleep
import numpy as np
from tqdm import tqdm
import torch
from criterion import Criterion
from loggers import BasicLogger
from utils.import_util import get_decoder, get_property
from variations.render_helpers import bundle_adjust_frames
from utils.mesh_util import MeshExtractor
from utils.profile_util import Profiler
from lidarFrame import LidarFrame
import torch.nn.functional as F
from pathlib import Path
import open3d as o3d
torch.classes.load_library(
"/home/pl21n4/Programmes/Vox-Fusion/third_party/sparse_octree/build/lib.linux-x86_64-cpython-38/svo.cpython-38-x86_64-linux-gnu.so")
def get_network_size(net):
size = 0
for param in net.parameters():
size += param.element_size() * param.numel()
return size / 1024 / 1024
class Mapping:
def __init__(self, args, logger: BasicLogger):
super().__init__()
self.args = args
self.logger = logger
self.decoder = get_decoder(args).cuda()
print(self.decoder)
self.loss_criteria = Criterion(args)
self.keyframe_graph = []
self.initialized = False
mapper_specs = args.mapper_specs
# optional args
self.ckpt_freq = get_property(args, "ckpt_freq", -1)
self.final_iter = get_property(mapper_specs, "final_iter", False)
self.mesh_res = get_property(mapper_specs, "mesh_res", 8)
self.save_data_freq = get_property(
args.debug_args, "save_data_freq", 0)
# required args
self.voxel_size = mapper_specs["voxel_size"]
self.window_size = mapper_specs["window_size"]
self.num_iterations = mapper_specs["num_iterations"]
self.n_rays = mapper_specs["N_rays_each"]
self.sdf_truncation = args.criteria["sdf_truncation"]
self.max_voxel_hit = mapper_specs["max_voxel_hit"]
self.step_size = mapper_specs["step_size"]
self.learning_rate_emb = mapper_specs["learning_rate_emb"]
self.learning_rate_decorder = mapper_specs["learning_rate_decorder"]
self.learning_rate_pose = mapper_specs["learning_rate_pose"]
self.step_size = self.step_size * self.voxel_size
self.max_distance = args.data_specs["max_depth"]
self.freeze_frame = mapper_specs["freeze_frame"]
self.keyframe_gap = mapper_specs["keyframe_gap"]
self.remove_back = mapper_specs["remove_back"]
self.key_distance = mapper_specs["key_distance"]
embed_dim = args.decoder_specs["in_dim"]
use_local_coord = mapper_specs["use_local_coord"]
self.embed_dim = embed_dim - 3 if use_local_coord else embed_dim
#num_embeddings = mapper_specs["num_embeddings"]
self.mesh_freq = args.debug_args["mesh_freq"]
self.mesher = MeshExtractor(args)
self.voxel_id2embedding_id = -torch.ones((int(2e9), 1), dtype=torch.int)
self.embeds_exist_search = dict()
self.current_num_embeds = 0
self.dynamic_embeddings = None
self.svo = torch.classes.svo.Octree()
self.svo.init(256*256*4, embed_dim, self.voxel_size)
self.frame_poses = []
self.depth_maps = []
self.last_tracked_frame_id = 0
self.final_poses=[]
verbose = get_property(args.debug_args, "verbose", False)
self.profiler = Profiler(verbose=verbose)
self.profiler.enable()
def spin(self, share_data, kf_buffer):
print("mapping process started!!!!!!!!!")
while True:
torch.cuda.empty_cache()
if not kf_buffer.empty():
tracked_frame = kf_buffer.get()
# self.create_voxels(tracked_frame)
if not self.initialized:
self.first_frame_id = tracked_frame.index
if self.mesher is not None:
self.mesher.rays_d = tracked_frame.get_rays()
self.create_voxels(tracked_frame)
self.insert_keyframe(tracked_frame)
while kf_buffer.empty():
self.do_mapping(share_data, tracked_frame, selection_method='current')
self.initialized = True
else:
if self.remove_back:
tracked_frame = self.remove_back_points(tracked_frame)
self.do_mapping(share_data, tracked_frame)
self.create_voxels(tracked_frame)
if (torch.norm(tracked_frame.pose.translation().cpu()
- self.current_keyframe.pose.translation().cpu())) > self.keyframe_gap:
self.insert_keyframe(tracked_frame)
print(
f"********** current num kfs: { len(self.keyframe_graph) } **********")
# self.create_voxels(tracked_frame)
tracked_pose = tracked_frame.get_pose().detach()
ref_pose = self.current_keyframe.get_pose().detach()
rel_pose = torch.linalg.inv(ref_pose) @ tracked_pose
self.frame_poses += [(len(self.keyframe_graph) -
1, rel_pose.cpu())]
if self.mesh_freq > 0 and (tracked_frame.index) % self.mesh_freq == 0:
if self.final_iter and len(self.keyframe_graph) > 20:
print(f"********** post-processing steps **********")
#self.num_iterations = 1
final_num_iter = len(self.keyframe_graph) + 1
progress_bar = tqdm(
range(0, final_num_iter), position=0)
progress_bar.set_description(" post-processing steps")
for iter in progress_bar:
#tracked_frame=self.keyframe_graph[iter//self.window_size]
self.do_mapping(share_data, tracked_frame=None,
update_pose=False, update_decoder=False, selection_method='random')
self.logger.log_mesh(self.extract_mesh(res=self.mesh_res, clean_mesh=False),name=f"mesh_{tracked_frame.index:05d}.ply")
pose = self.get_updated_poses()
self.logger.log_numpy_data(np.asarray(pose), f"frame_poses_{tracked_frame.index:05d}")
if self.final_iter and len(self.keyframe_graph) > 20:
self.keyframe_graph = []
self.keyframe_graph += [self.current_keyframe]
if self.save_data_freq > 0 and (tracked_frame.stamp + 1) % self.save_data_freq == 0:
self.save_debug_data(tracked_frame)
elif share_data.stop_mapping:
break
print("******* extracting mesh without replay *******")
self.logger.log_mesh(self.extract_mesh(res=self.mesh_res, clean_mesh=False), name="final_mesh_noreplay.ply")
if self.final_iter:
print(f"********** post-processing steps **********")
#self.num_iterations = 1
final_num_iter = len(self.keyframe_graph) + 1
progress_bar = tqdm(
range(0, final_num_iter), position=0)
progress_bar.set_description(" post-processing steps")
for iter in progress_bar:
tracked_frame=self.keyframe_graph[iter//self.window_size]
self.do_mapping(share_data, tracked_frame=None,
update_pose=False, update_decoder=False, selection_method='random')
print("******* extracting final mesh *******")
pose = self.get_updated_poses()
self.logger.log_numpy_data(np.asarray(pose), "frame_poses")
self.logger.log_mesh(self.extract_mesh(res=self.mesh_res, clean_mesh=False))
print("******* mapping process died *******")
def do_mapping(self, share_data, tracked_frame=None,
update_pose=True, update_decoder=True, selection_method = 'current'):
self.profiler.tick("do_mapping")
self.decoder.train()
optimize_targets = self.select_optimize_targets(tracked_frame, selection_method=selection_method)
torch.cuda.empty_cache()
self.profiler.tick("bundle_adjust_frames")
bundle_adjust_frames(
optimize_targets,
self.dynamic_embeddings,
self.map_states,
self.decoder,
self.loss_criteria,
self.voxel_size,
self.step_size,
self.n_rays * 2 if selection_method=='random' else self.n_rays,
self.num_iterations,
self.sdf_truncation,
self.max_voxel_hit,
self.max_distance,
learning_rate=[self.learning_rate_emb,
self.learning_rate_decorder,
self.learning_rate_pose],
update_pose=update_pose,
update_decoder=update_decoder if tracked_frame is None or (tracked_frame.index - self.first_frame_id) < self.freeze_frame else False,
profiler=self.profiler
)
self.profiler.tok("bundle_adjust_frames")
# optimize_targets = [f.cpu() for f in optimize_targets]
self.update_share_data(share_data)
self.profiler.tok("do_mapping")
# sleep(0.01)
def select_optimize_targets(self, tracked_frame=None, selection_method='previous'):
# TODO: better ways
targets = []
if selection_method == 'current':
if tracked_frame is None:
raise ValueError("selection_method 'current' requires a tracked frame")
else:
return [tracked_frame]
if len(self.keyframe_graph) <= self.window_size:
targets = self.keyframe_graph[:]
elif selection_method == 'random':
targets = random.sample(self.keyframe_graph, self.window_size)
elif selection_method == 'previous':
targets = self.keyframe_graph[-self.window_size:]
elif selection_method == 'overlap':
raise NotImplementedError(
f"selection method {selection_method} not implemented")
if tracked_frame is not None and tracked_frame != self.current_keyframe:
targets += [tracked_frame]
return targets
def update_share_data(self, share_data, frameid=None):
share_data.decoder = deepcopy(self.decoder)
tmp_states = {}
for k, v in self.map_states.items():
tmp_states[k] = v.detach().cpu()
share_data.states = tmp_states
# self.last_tracked_frame_id = frameid
def remove_back_points(self, frame):
rel_pose = frame.get_rel_pose()
points = frame.get_points()
points_norm = torch.norm(points, 2, -1)
points_xy = points[:, :2]
if rel_pose is None:
x = 1
y = 0
else:
x = rel_pose[0, 3]
y = rel_pose[1, 3]
rel_xy = torch.ones((1, 2))
rel_xy[0, 0] = x
rel_xy[0, 1] = y
point_cos = torch.sum(-points_xy * rel_xy, dim=-1)/(
torch.norm(points_xy, 2, -1)*(torch.norm(rel_xy, 2, -1)))
remove_index = ((point_cos >= 0.7) & (points_norm > self.key_distance))
new_points = frame.points[~remove_index]
new_cos = frame.get_pointsCos()[~remove_index]
return LidarFrame(frame.index, new_points, new_cos,
frame.pose, new_keyframe=True)
def frame_maxdistance_change(self, frame, distance):
# kf check
valid_distance = distance + 0.5
new_keyframe_rays_norm = frame.rays_norm.reshape(-1)
new_keyframe_points = frame.points[new_keyframe_rays_norm <= valid_distance]
new_keyframe_pointsCos = frame.get_pointsCos()[new_keyframe_rays_norm <= valid_distance]
return LidarFrame(frame.index, new_keyframe_points, new_keyframe_pointsCos,
frame.pose, new_keyframe=True)
def insert_keyframe(self, frame, valid_distance=-1):
# kf check
print("insert keyframe")
valid_distance = self.key_distance + 0.01
new_keyframe_rays_norm = frame.rays_norm.reshape(-1)
mask = (torch.abs(frame.points[:, 0]) < valid_distance) & (torch.abs(frame.points[:, 1])
< valid_distance) & (torch.abs(frame.points[:, 2]) < valid_distance)
new_keyframe_points = frame.points[mask]
new_keyframe_pointsCos = frame.get_pointsCos()[mask]
new_keyframe = LidarFrame(frame.index, new_keyframe_points, new_keyframe_pointsCos,
frame.pose, new_keyframe=True)
if new_keyframe_points.shape[0] < 2*self.n_rays:
raise ValueError('valid_distance too small')
self.current_keyframe = new_keyframe
self.keyframe_graph += [new_keyframe]
# self.update_grid_features()
def create_voxels(self, frame):
points = frame.get_points().cuda()
pose = frame.get_pose().cuda()
print("frame id", frame.index+1)
print("trans ", pose[:3, 3]-2000)
points = points@pose[:3, :3].transpose(-1, -2) + pose[:3, 3]
voxels = torch.div(points, self.voxel_size, rounding_mode='floor')
self.svo.insert(voxels.cpu().int())
self.update_grid_features()
@torch.enable_grad()
def get_embeddings(self, points_idx):
flatten_idx = points_idx.reshape(-1).long()
valid_flatten_idx = flatten_idx[flatten_idx.ne(-1)]
existence = F.embedding(valid_flatten_idx, self.voxel_id2embedding_id)
torch_add_idx = existence.eq(-1).view(-1)
torch_add = valid_flatten_idx[torch_add_idx]
if torch_add.shape[0] == 0:
return
start_num = self.current_num_embeds
end_num = start_num + torch_add.shape[0]
embeddings_add = torch.zeros((end_num-start_num, self.embed_dim),
dtype=torch.bfloat16)
# torch.nn.init.normal_(embeddings_add, std=0.01)
if self.dynamic_embeddings is None:
embeddings = [embeddings_add]
else:
embeddings = [self.dynamic_embeddings.detach().cpu(), embeddings_add]
embeddings = torch.cat(embeddings, dim=0)
self.dynamic_embeddings = embeddings.cuda().requires_grad_()
self.current_num_embeds = end_num
self.voxel_id2embedding_id[torch_add] = torch.arange(start_num, end_num, dtype=torch.int).view(-1, 1)
@torch.enable_grad()
def update_grid_features(self):
voxels, children, features = self.svo.get_centres_and_children()
centres = (voxels[:, :3] + voxels[:, -1:] / 2) * self.voxel_size
children = torch.cat([children, voxels[:, -1:]], -1)
centres = centres.float()
children = children.int()
map_states = {}
map_states["voxel_vertex_idx"] = features
centres.requires_grad_()
map_states["voxel_center_xyz"] = centres
map_states["voxel_structure"] = children
self.profiler.tick("Creating embedding")
self.get_embeddings(map_states["voxel_vertex_idx"])
self.profiler.tok("Creating embedding")
map_states["voxel_vertex_emb"] = self.dynamic_embeddings
map_states["voxel_id2embedding_id"] = self.voxel_id2embedding_id
self.map_states = map_states
@torch.no_grad()
def get_updated_poses(self, offset=-2000):
for i in range(len(self.frame_poses)):
ref_frame_ind, rel_pose = self.frame_poses[i]
ref_frame = self.keyframe_graph[ref_frame_ind]
ref_pose = ref_frame.get_pose().detach().cpu()
pose = ref_pose @ rel_pose
pose[:3, 3] += offset
self.final_poses += [pose.detach().cpu().numpy()]
self.frame_poses = []
return self.final_poses
@torch.no_grad()
def extract_mesh(self, res=8, clean_mesh=False):
sdf_network = self.decoder
sdf_network.eval()
voxels, _, features = self.svo.get_centres_and_children()
index = features.eq(-1).any(-1)
voxels = voxels[~index, :]
features = features[~index, :]
centres = (voxels[:, :3] + voxels[:, -1:] / 2) * self.voxel_size
encoder_states = {}
encoder_states["voxel_vertex_idx"] = features
encoder_states["voxel_center_xyz"] = centres
self.profiler.tick("Creating embedding")
self.get_embeddings(encoder_states["voxel_vertex_idx"])
self.profiler.tok("Creating embedding")
encoder_states["voxel_vertex_emb"] = self.dynamic_embeddings
encoder_states["voxel_id2embedding_id"] = self.voxel_id2embedding_id
frame_poses = self.get_updated_poses()
mesh = self.mesher.create_mesh(
self.decoder, encoder_states, self.voxel_size, voxels,
frame_poses=None, depth_maps=None,
clean_mseh=clean_mesh, require_color=False, offset=-2000, res=res)
return mesh
@torch.no_grad()
def extract_voxels(self, offset=-10):
voxels, _, features = self.svo.get_centres_and_children()
index = features.eq(-1).any(-1)
voxels = voxels[~index, :]
features = features[~index, :]
voxels = (voxels[:, :3] + voxels[:, -1:] / 2) * \
self.voxel_size + offset
# print(torch.max(features)-torch.count_nonzero(index))
return voxels
@torch.no_grad()
def save_debug_data(self, tracked_frame, offset=-10):
"""
save per-frame voxel, mesh and pose
"""
pose = tracked_frame.get_pose().detach().cpu().numpy()
pose[:3, 3] += offset
frame_poses = self.get_updated_poses()
mesh = self.extract_mesh(res=8, clean_mesh=True)
voxels = self.extract_voxels().detach().cpu().numpy()
keyframe_poses = [p.get_pose().detach().cpu().numpy()
for p in self.keyframe_graph]
for f in frame_poses:
f[:3, 3] += offset
for kf in keyframe_poses:
kf[:3, 3] += offset
verts = np.asarray(mesh.vertices)
faces = np.asarray(mesh.triangles)
color = np.asarray(mesh.vertex_colors)
self.logger.log_debug_data({
"pose": pose,
"updated_poses": frame_poses,
"mesh": {"verts": verts, "faces": faces, "color": color},
"voxels": voxels,
"voxel_size": self.voxel_size,
"keyframes": keyframe_poses,
"is_keyframe": (tracked_frame == self.current_keyframe)
}, tracked_frame.stamp)
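`create_voxels` discretizes world-frame points by floor division with the voxel edge length before inserting them into the sparse octree. The metric-to-integer-index mapping, sketched in numpy (the 0.2 m edge length is an illustrative value, not a repo default):

```python
import numpy as np

voxel_size = 0.2  # metres per voxel edge, analogous to mapper_specs["voxel_size"]
points = np.array([[0.05, 0.19, -0.01],
                   [0.41, -0.30, 1.05]])
# Floor (not truncation) so negative coordinates round towards -inf,
# matching torch.div(..., rounding_mode='floor') in create_voxels.
voxels = np.floor(points / voxel_size).astype(int)
```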
================================================
FILE: src/nerfloam.py
================================================
from multiprocessing.managers import BaseManager
from time import sleep
import torch
import torch.multiprocessing as mp
from loggers import BasicLogger
from mapping import Mapping
from share import ShareData, ShareDataProxy
from tracking import Tracking
from utils.import_util import get_dataset
class nerfloam:
def __init__(self, args):
self.args = args
# logger (optional)
self.logger = BasicLogger(args)
# shared data
mp.set_start_method('spawn', force=True)
BaseManager.register('ShareData', ShareData, ShareDataProxy)
manager = BaseManager()
manager.start()
self.share_data = manager.ShareData()
# keyframe buffer
self.kf_buffer = mp.Queue(maxsize=1)
# data stream
self.data_stream = get_dataset(args)
# tracker
self.tracker = Tracking(args, self.data_stream, self.logger)
# mapper
self.mapper = Mapping(args, self.logger)
# initialize map with first frame
self.tracker.process_first_frame(self.kf_buffer)
self.processes = []
def start(self):
mapping_process = mp.Process(
target=self.mapper.spin, args=(self.share_data, self.kf_buffer))
mapping_process.start()
print("initializing the first frame ...")
sleep(20)
# self.share_data.stop_mapping=True
tracking_process = mp.Process(
target=self.tracker.spin, args=(self.share_data, self.kf_buffer))
tracking_process.start()
self.processes = [mapping_process, tracking_process]
def wait_child_processes(self):
for p in self.processes:
p.join()
@torch.no_grad()
def get_raw_trajectory(self):
return self.share_data.tracking_trajectory
@torch.no_grad()
def get_keyframe_poses(self):
keyframe_graph = self.mapper.keyframe_graph
poses = []
for keyframe in keyframe_graph:
poses.append(keyframe.get_pose().detach().cpu().numpy())
return poses
================================================
FILE: src/se3pose.py
================================================
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
class OptimizablePose(nn.Module):
def __init__(self, init_pose):
super().__init__()
assert (isinstance(init_pose, torch.FloatTensor))
self.register_parameter('data', nn.Parameter(init_pose))
self.data.requires_grad = True
def copy_from(self, pose):
self.data = deepcopy(pose.data)
def matrix(self):
Rt = torch.eye(4)
Rt[:3, :3] = self.rotation()
Rt[:3, 3] = self.translation()
return Rt
def rotation(self):
w = self.data[3:]
wx = self.skew_symmetric(w)
theta = w.norm(dim=-1)[..., None, None]
I = torch.eye(3, device=w.device, dtype=torch.float32)
A = self.taylor_A(theta)
B = self.taylor_B(theta)
R = I+A*wx+B*wx@wx
return R
def translation(self,):
return self.data[:3]
@classmethod
def log(cls, R, eps=1e-7): # [...,3,3]
trace = R[..., 0, 0]+R[..., 1, 1]+R[..., 2, 2]
# ln(R) will explode if theta==pi
theta = ((trace-1)/2).clamp(-1+eps, 1-eps).acos_()[..., None, None] % np.pi
lnR = 1/(2*cls.taylor_A(theta)+1e-8) * (R-R.transpose(-2, -1)) # FIXME: wei-chiu finds it weird
w0, w1, w2 = lnR[..., 2, 1], lnR[..., 0, 2], lnR[..., 1, 0]
w = torch.stack([w0, w1, w2], dim=-1)
return w
@classmethod
def from_matrix(cls, Rt, eps=1e-8): # [...,3,4]
R, u = Rt[:3, :3], Rt[:3, 3]
w = cls.log(R)
return OptimizablePose(torch.cat([u, w], dim=-1))
@classmethod
def skew_symmetric(cls, w):
w0, w1, w2 = w.unbind(dim=-1)
O = torch.zeros_like(w0)
wx = torch.stack([
torch.stack([O, -w2, w1], dim=-1),
torch.stack([w2, O, -w0], dim=-1),
torch.stack([-w1, w0, O], dim=-1)], dim=-2)
return wx
@classmethod
def taylor_A(cls, x, nth=10):
# Taylor expansion of sin(x)/x
ans = torch.zeros_like(x)
denom = 1.
for i in range(nth+1):
if i > 0:
denom *= (2*i)*(2*i+1)
ans = ans+(-1)**i*x**(2*i)/denom
return ans
@classmethod
def taylor_B(cls, x, nth=10):
# Taylor expansion of (1-cos(x))/x**2
ans = torch.zeros_like(x)
denom = 1.
for i in range(nth+1):
denom *= (2*i+1)*(2*i+2)
ans = ans+(-1)**i*x**(2*i)/denom
return ans
@classmethod
def taylor_C(cls, x, nth=10):
# Taylor expansion of (x-sin(x))/x**3
ans = torch.zeros_like(x)
denom = 1.
for i in range(nth+1):
denom *= (2*i+2)*(2*i+3)
ans = ans+(-1)**i*x**(2*i)/denom
return ans
if __name__ == '__main__':
before = torch.tensor([[-0.955421, 0.119616, -0.269932, 2.655830],
[0.295248, 0.388339, -0.872939, 2.981598],
[0.000408, -0.913720, -0.406343, 1.368648],
[0.000000, 0.000000, 0.000000, 1.000000]])
pose = OptimizablePose.from_matrix(before)
print(pose.rotation())
print(pose.translation())
after = pose.matrix()
print(after)
print(torch.abs((before-after)[:3, 3]))
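The class stores rotation as an axis-angle vector and maps it through the exp/log pair above (Rodrigues' formula with the Taylor coefficients A and B). A numpy round-trip check mirroring `rotation()` and `log()` (standalone sketch, not the class itself):

```python
import numpy as np

def so3_exp(w):
    # Rodrigues' formula, mirroring OptimizablePose.rotation():
    # R = I + A*[w]x + B*[w]x^2 with A = sin(t)/t, B = (1-cos(t))/t^2.
    theta = np.linalg.norm(w)
    wx = np.array([[0.0, -w[2], w[1]],
                   [w[2], 0.0, -w[0]],
                   [-w[1], w[0], 0.0]])
    if theta < 1e-8:
        return np.eye(3) + wx  # first-order fallback near zero
    A = np.sin(theta) / theta
    B = (1.0 - np.cos(theta)) / theta ** 2
    return np.eye(3) + A * wx + B * (wx @ wx)

def so3_log(R, eps=1e-7):
    # Inverse map, mirroring OptimizablePose.log(): recover the axis-angle
    # vector from the skew-symmetric part of R.
    theta = np.arccos(np.clip((np.trace(R) - 1.0) / 2.0, -1.0 + eps, 1.0 - eps))
    lnR = (R - R.T) / (2.0 * np.sin(theta) / theta + 1e-8)
    return np.array([lnR[2, 1], lnR[0, 2], lnR[1, 0]])

w = np.array([0.1, -0.2, 0.3])
R = so3_exp(w)
w_back = so3_log(R)
```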
================================================
FILE: src/share.py
================================================
from multiprocessing.managers import BaseManager, NamespaceProxy
from copy import deepcopy
import torch.multiprocessing as mp
from time import sleep
import sys
class ShareDataProxy(NamespaceProxy):
_exposed_ = ('__getattribute__', '__setattr__')
class ShareData:
global lock
lock = mp.RLock()
def __init__(self):
self.__stop_mapping = False
self.__stop_tracking = False
self.__decoder = None
self.__voxels = None
self.__octree = None
self.__states = None
self.__tracking_trajectory = []
@property
def decoder(self):
with lock:
return deepcopy(self.__decoder)
@decoder.setter
def decoder(self, decoder):
with lock:
self.__decoder = deepcopy(decoder)
# print("========== decoder set ==========")
sys.stdout.flush()
@property
def voxels(self):
with lock:
return deepcopy(self.__voxels)
@voxels.setter
def voxels(self, voxels):
with lock:
self.__voxels = deepcopy(voxels)
print("========== voxels set ==========")
sys.stdout.flush()
@property
def octree(self):
with lock:
return deepcopy(self.__octree)
@octree.setter
def octree(self, octree):
with lock:
self.__octree = deepcopy(octree)
print("========== octree set ==========")
sys.stdout.flush()
@property
def states(self):
with lock:
return self.__states
@states.setter
def states(self, states):
with lock:
self.__states = states
# print("========== states set ==========")
sys.stdout.flush()
@property
def stop_mapping(self):
with lock:
return self.__stop_mapping
@stop_mapping.setter
def stop_mapping(self, stop_mapping):
with lock:
self.__stop_mapping = stop_mapping
print("========== stop_mapping set ==========")
sys.stdout.flush()
@property
def stop_tracking(self):
with lock:
return self.__stop_tracking
@stop_tracking.setter
def stop_tracking(self, stop_tracking):
with lock:
self.__stop_tracking = stop_tracking
print("========== stop_tracking set ==========")
sys.stdout.flush()
@property
def tracking_trajectory(self):
with lock:
return deepcopy(self.__tracking_trajectory)
def push_pose(self, pose):
with lock:
self.__tracking_trajectory.append(deepcopy(pose))
# print("========== push_pose ==========")
sys.stdout.flush()
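`ShareData` guards every accessor with one process-wide `RLock` and deep-copies on set and get, so the tracker and mapper never alias each other's objects. The same pattern in miniature (a single-field sketch using `threading.RLock`; the real class uses `torch.multiprocessing` and a `BaseManager` proxy):

```python
import threading
from copy import deepcopy

class SharedBox:
    # Same pattern as ShareData: one re-entrant lock for all accessors,
    # deepcopy on both set and get so callers cannot mutate shared state.
    _lock = threading.RLock()

    def __init__(self):
        self.__value = None

    @property
    def value(self):
        with self._lock:
            return deepcopy(self.__value)

    @value.setter
    def value(self, v):
        with self._lock:
            self.__value = deepcopy(v)

box = SharedBox()
payload = [1, 2, 3]
box.value = payload
payload.append(4)  # later mutation does not leak into the shared copy
```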
================================================
FILE: src/tracking.py
================================================
import torch
import numpy as np
from tqdm import tqdm
from criterion import Criterion
from lidarFrame import LidarFrame
from utils.import_util import get_property
from utils.profile_util import Profiler
from variations.render_helpers import fill_in, render_rays, track_frame
from se3pose import OptimizablePose
from time import sleep
from copy import deepcopy
class Tracking:
def __init__(self, args, data_stream, logger):
self.args = args
self.last_frame_id = 0
self.last_frame = None
self.data_stream = data_stream
self.logger = logger
self.loss_criteria = Criterion(args)
self.voxel_size = args.mapper_specs["voxel_size"]
self.N_rays = args.tracker_specs["N_rays"]
self.num_iterations = args.tracker_specs["num_iterations"]
self.sdf_truncation = args.criteria["sdf_truncation"]
self.learning_rate = args.tracker_specs["learning_rate"]
self.start_frame = args.tracker_specs["start_frame"]
self.end_frame = args.tracker_specs["end_frame"]
self.step_size = args.tracker_specs["step_size"]
# self.keyframe_freq = args.tracker_specs["keyframe_freq"]
self.max_voxel_hit = args.tracker_specs["max_voxel_hit"]
self.max_distance = args.data_specs["max_depth"]
self.step_size = self.step_size * self.voxel_size
self.read_offset = args.tracker_specs["read_offset"]
self.mesh_freq = args.debug_args["mesh_freq"]
if self.end_frame <= 0:
self.end_frame = len(self.data_stream)-1
# sanity check on the lower/upper bounds
self.start_frame = min(self.start_frame, len(self.data_stream))
self.end_frame = min(self.end_frame, len(self.data_stream))
self.rel_pose = None
# profiler
verbose = get_property(args.debug_args, "verbose", False)
self.profiler = Profiler(verbose=verbose)
self.profiler.enable()
def process_first_frame(self, kf_buffer):
init_pose = self.data_stream.get_init_pose(self.start_frame)
index, points, pointcos, _ = self.data_stream[self.start_frame]
first_frame = LidarFrame(index, points, pointcos, init_pose)
first_frame.pose.requires_grad_(False)
first_frame.points.requires_grad_(False)
print("******* initializing first_frame:", first_frame.index)
kf_buffer.put(first_frame, block=True)
self.last_frame = first_frame
self.start_frame += 1
def spin(self, share_data, kf_buffer):
print("******* tracking process started! *******")
progress_bar = tqdm(
range(self.start_frame, self.end_frame+1), position=0)
progress_bar.set_description("tracking frame")
for frame_id in progress_bar:
if frame_id % self.read_offset != 0:
continue
if share_data.stop_tracking:
break
data_in = self.data_stream[frame_id]
current_frame = LidarFrame(*data_in)
if isinstance(data_in[3], np.ndarray):
self.last_frame = current_frame
self.check_keyframe(current_frame, kf_buffer)
else:
self.do_tracking(share_data, current_frame, kf_buffer)
share_data.stop_mapping = True
print("******* tracking process finished *******")
sleep(60)
while not kf_buffer.empty():
sleep(60)
def check_keyframe(self, check_frame, kf_buffer):
try:
kf_buffer.put(check_frame, block=True)
except Exception:
pass
def do_tracking(self, share_data, current_frame, kf_buffer):
self.profiler.tick("before track1111")
decoder = share_data.decoder.cuda()
self.profiler.tok("before track1111")
self.profiler.tick("before track2222")
map_states = share_data.states
map_states["voxel_vertex_emb"] = map_states["voxel_vertex_emb"].cuda()
self.profiler.tok("before track2222")
constant_move_pose = self.last_frame.get_pose().detach()
input_pose = deepcopy(self.last_frame.pose)
input_pose.requires_grad_(False)
if self.rel_pose is not None:
constant_move_pose[:3, 3] = (constant_move_pose @ (self.rel_pose))[:3, 3]
input_pose.data[:3] = constant_move_pose[:3, 3].T
torch.cuda.empty_cache()
self.profiler.tick("track frame")
frame_pose, hit_mask = track_frame(
input_pose,
current_frame,
map_states,
decoder,
self.loss_criteria,
self.voxel_size,
self.N_rays,
self.step_size,
self.num_iterations if self.rel_pose is not None else self.num_iterations*5,
self.sdf_truncation,
self.learning_rate if self.rel_pose is not None else self.learning_rate,
self.max_voxel_hit,
self.max_distance,
profiler=self.profiler,
depth_variance=True
)
self.profiler.tok("track frame")
if hit_mask is None:
current_frame.pose = OptimizablePose.from_matrix(constant_move_pose)
else:
current_frame.pose = frame_pose
current_frame.hit_ratio = hit_mask.sum() / self.N_rays
self.rel_pose = torch.linalg.inv(self.last_frame.get_pose().detach()) @ current_frame.get_pose().detach()
current_frame.set_rel_pose(self.rel_pose)
self.last_frame = current_frame
self.profiler.tick("transport frame")
self.check_keyframe(current_frame, kf_buffer)
self.profiler.tok("transport frame")
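`do_tracking` seeds each new pose with a constant-velocity prior: the relative transform `rel_pose` between the last two frames is re-applied to the previous pose before optimization starts. A translation-only, plain-Python sketch of that prediction step (the real code works on torch 4x4 matrices and uses `torch.linalg.inv`; `matmul4`/`predict_pose` here are illustrative helpers, not repo functions):

```python
# Constant-velocity pose prediction: re-apply the last relative transform.
# Translation-only sketch; rotations are handled the same way by the 4x4 ops.

def matmul4(a, b):
    # multiply two 4x4 matrices given as nested lists
    return [[sum(a[i][k] * b[k][j] for k in range(4)) for j in range(4)]
            for i in range(4)]

def translation_pose(x, y, z):
    # homogeneous 4x4 pose with identity rotation and translation (x, y, z)
    return [[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]]

def predict_pose(prev_pose, last_pose):
    # rel_pose = inv(prev_pose) @ last_pose; for pure translations the
    # inverse simply negates the offset
    rel = translation_pose(last_pose[0][3] - prev_pose[0][3],
                           last_pose[1][3] - prev_pose[1][3],
                           last_pose[2][3] - prev_pose[2][3])
    return matmul4(last_pose, rel)

# a sensor moving +1 m in x per frame is predicted to keep doing so
pred = predict_pose(translation_pose(0, 0, 0), translation_pose(1, 0, 0))
```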
================================================
FILE: src/utils/__init__.py
================================================
================================================
FILE: src/utils/import_util.py
================================================
from importlib import import_module
import argparse
def get_dataset(args):
Dataset = import_module("dataset."+args.dataset)
return Dataset.DataLoader(**args.data_specs)
def get_decoder(args):
Decoder = import_module("variations."+args.decoder)
return Decoder.Decoder(**args.decoder_specs)
def get_property(args, name, default):
if isinstance(args, dict):
return args.get(name, default)
elif isinstance(args, argparse.Namespace):
if hasattr(args, name):
return vars(args)[name]
else:
return default
else:
raise ValueError(f"unknown dict/namespace type: {type(args)}")
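`get_property` accepts either a plain dict or an `argparse.Namespace`, falling back to a default when the field is absent. A standalone copy with usage, for illustration only (same contract as the function above):

```python
import argparse

def get_property(args, name, default):
    # read a named field from a dict or an argparse.Namespace
    if isinstance(args, dict):
        return args.get(name, default)
    if isinstance(args, argparse.Namespace):
        return vars(args).get(name, default)
    raise ValueError(f"unknown dict/namespace type: {type(args)}")

assert get_property({"verbose": True}, "verbose", False) is True
assert get_property(argparse.Namespace(mesh_freq=50), "mesh_freq", 10) == 50
assert get_property(argparse.Namespace(), "verbose", False) is False
```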
================================================
FILE: src/utils/mesh_util.py
================================================
import math
import torch
import numpy as np
import open3d as o3d
from scipy.spatial import cKDTree
from skimage.measure import marching_cubes
from variations.render_helpers import get_scores, eval_points
class MeshExtractor:
def __init__(self, args):
self.voxel_size = args.mapper_specs["voxel_size"]
self.rays_d = None
self.depth_points = None
@ torch.no_grad()
def linearize_id(self, xyz, n_xyz):
return xyz[:, 2] + n_xyz[-1] * xyz[:, 1] + (n_xyz[-1] * n_xyz[-2]) * xyz[:, 0]
@torch.no_grad()
def downsample_points(self, points, voxel_size=0.01):
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)
pcd = pcd.voxel_down_sample(voxel_size)
return np.asarray(pcd.points)
@torch.no_grad()
def get_rays(self, w=None, h=None, K=None):
w = self.w if w is None else w
h = self.h if h is None else h
if K is None:
K = np.eye(3)
K[0, 0] = self.K[0, 0] * w / self.w
K[1, 1] = self.K[1, 1] * h / self.h
K[0, 2] = self.K[0, 2] * w / self.w
K[1, 2] = self.K[1, 2] * h / self.h
ix, iy = torch.meshgrid(
torch.arange(w), torch.arange(h), indexing='xy')
rays_d = torch.stack(
[(ix-K[0, 2]) / K[0, 0],
(iy-K[1, 2]) / K[1, 1],
torch.ones_like(ix)], -1).float()
return rays_d
@torch.no_grad()
def get_valid_points(self, frame_poses, depth_maps):
if isinstance(frame_poses, list):
all_points = []
print("extracting all points")
for i in range(0, len(frame_poses), 5):
pose = frame_poses[i]
depth = depth_maps[i]
points = self.rays_d * depth.unsqueeze(-1)
points = points.reshape(-1, 3)
points = points @ pose[:3, :3].transpose(-1, -2) + pose[:3, 3]
if len(all_points) == 0:
all_points = points.detach().cpu().numpy()
else:
all_points = np.concatenate(
[all_points, points.detach().cpu().numpy()], 0)
print("downsample all points")
all_points = self.downsample_points(all_points)
return all_points
else:
pose = frame_poses
depth = depth_maps
points = self.rays_d * depth.unsqueeze(-1)
points = points.reshape(-1, 3)
points = points @ pose[:3, :3].transpose(-1, -2) + pose[:3, 3]
if self.depth_points is None:
self.depth_points = points.detach().cpu().numpy()
else:
self.depth_points = np.concatenate(
[self.depth_points, points.detach().cpu().numpy()], 0)
self.depth_points = self.downsample_points(self.depth_points)
return self.depth_points
@ torch.no_grad()
def create_mesh(self, decoder, map_states, voxel_size, voxels,
frame_poses=None, depth_maps=None, clean_mseh=False,
require_color=False, offset=-80, res=8):
sdf_grid = get_scores(decoder, map_states, voxel_size, bits=res)
sdf_grid = sdf_grid.reshape(-1, res, res, res, 1)
voxel_centres = map_states["voxel_center_xyz"]
verts, faces = self.marching_cubes(voxel_centres, sdf_grid)
if clean_mseh:
print("********** get points from frames **********")
all_points = self.get_valid_points(frame_poses, depth_maps)
print("********** construct kdtree **********")
kdtree = cKDTree(all_points)
print("********** query kdtree **********")
point_mask = kdtree.query_ball_point(
verts, voxel_size * 0.5, workers=12, return_length=True)
print("********** finished querying kdtree **********")
point_mask = point_mask > 0
face_mask = point_mask[faces.reshape(-1)].reshape(-1, 3).any(-1)
faces = faces[face_mask]
if require_color:
print("********** get color from network **********")
verts_torch = torch.from_numpy(verts).float().cuda()
batch_points = torch.split(verts_torch, 1000)
colors = []
for points in batch_points:
voxel_pos = points // self.voxel_size
batch_voxels = voxels[:, :3].cuda()
batch_voxels = batch_voxels.unsqueeze(
0).repeat(voxel_pos.shape[0], 1, 1)
# filter outliers
nonzeros = (batch_voxels == voxel_pos.unsqueeze(1)).all(-1)
nonzeros = torch.where(nonzeros, torch.ones_like(
nonzeros).int(), -torch.ones_like(nonzeros).int())
sorted, index = torch.sort(nonzeros, dim=-1, descending=True)
sorted = sorted[:, 0]
index = index[:, 0]
valid = (sorted != -1)
color_empty = torch.zeros_like(points)
points = points[valid, :]
index = index[valid]
# get color
if len(points) > 0:
color = eval_points(decoder, map_states,
points, index, voxel_size).cuda()
color_empty[valid] = color
colors += [color_empty]
colors = torch.cat(colors, 0)
mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(verts+offset)
mesh.triangles = o3d.utility.Vector3iVector(faces)
if require_color:
mesh.vertex_colors = o3d.utility.Vector3dVector(
colors.detach().cpu().numpy())
mesh.compute_vertex_normals()
return mesh
@ torch.no_grad()
def marching_cubes(self, voxels, sdf):
voxels = voxels[:, :3]
sdf = sdf[..., 0]
res = 1.0 / (sdf.shape[1] - 1)
spacing = [res, res, res]
num_verts = 0
total_verts = []
total_faces = []
for i in range(len(voxels)):
sdf_volume = sdf[i].detach().cpu().numpy()
if np.min(sdf_volume) > 0 or np.max(sdf_volume) < 0:
continue
verts, faces, _, _ = marching_cubes(sdf_volume, 0, spacing=spacing)
verts -= 0.5
verts *= self.voxel_size
verts += voxels[i].detach().cpu().numpy()
faces += num_verts
num_verts += verts.shape[0]
total_verts += [verts]
total_faces += [faces]
total_verts = np.concatenate(total_verts)
total_faces = np.concatenate(total_faces)
return total_verts, total_faces
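`marching_cubes` above runs scikit-image per voxel and stitches the per-voxel meshes together by offsetting each chunk's face indices by the running vertex count. The stitching pattern in isolation, with plain lists standing in for the numpy arrays (`stitch` is an illustrative helper, not a repo function):

```python
# Stitching per-chunk meshes: face indices are local to each chunk, so each
# chunk's faces are shifted by the number of vertices emitted so far.

def stitch(chunks):
    # chunks: list of (verts, faces); verts is a list of 3D points,
    # faces is a list of integer index triples into that chunk's verts
    total_verts, total_faces, num_verts = [], [], 0
    for verts, faces in chunks:
        total_faces += [[i + num_verts for i in face] for face in faces]
        total_verts += verts
        num_verts += len(verts)
    return total_verts, total_faces

v, f = stitch([
    ([(0, 0, 0), (1, 0, 0), (0, 1, 0)], [(0, 1, 2)]),
    ([(0, 0, 1), (1, 0, 1), (0, 1, 1)], [(0, 1, 2)]),
])
# the second triangle now indexes the global vertices 3, 4, 5
```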
================================================
FILE: src/utils/profile_util.py
================================================
from time import time
import torch
class Profiler(object):
def __init__(self, verbose=False) -> None:
self.timer = dict()
self.time_log = dict()
self.enabled = False
self.verbose = verbose
def enable(self):
self.enabled = True
def disable(self):
self.enabled = False
def tick(self, name):
if not self.enabled:
return
self.timer[name] = time()
if name not in self.time_log:
self.time_log[name] = list()
def tok(self, name):
if not self.enabled:
return
if name not in self.timer:
return
torch.cuda.synchronize()
elapsed = time() - self.timer[name]
if self.verbose:
print(f"{name}: {elapsed*1000:.2f} ms")
else:
self.time_log[name].append(elapsed * 1000)
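The `tick`/`tok` pattern above in a minimal standalone form (the `torch.cuda.synchronize()` call is omitted here; it only matters when timing asynchronous GPU work, so wall-clock timing of CPU work behaves the same):

```python
from time import sleep, time

# module-level state mirroring Profiler.timer / Profiler.time_log
timer, time_log = {}, {}

def tick(name):
    # record the start time and make sure a log slot exists
    timer[name] = time()
    time_log.setdefault(name, [])

def tok(name):
    # append the elapsed time in milliseconds
    if name in timer:
        time_log[name].append((time() - timer[name]) * 1000.0)

tick("work")
sleep(0.01)
tok("work")
```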
================================================
FILE: src/utils/sample_util.py
================================================
import torch
def sampling_without_replacement(logp, k):
def gumbel_like(u):
return -torch.log(-torch.log(torch.rand_like(u) + 1e-7) + 1e-7)
scores = logp + gumbel_like(logp)
return scores.topk(k, dim=-1)[1]
def sample_rays(mask, num_samples):
B, H, W = mask.shape
probs = mask / (mask.sum() + 1e-9)
flatten_probs = probs.reshape(B, -1)
sampled_index = sampling_without_replacement(
torch.log(flatten_probs + 1e-9), num_samples)
sampled_masks = (torch.zeros_like(
flatten_probs).scatter_(-1, sampled_index, 1).reshape(B, H, W) > 0)
return sampled_masks
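`sampling_without_replacement` uses the Gumbel-top-k trick: perturbing log-probabilities with Gumbel noise and taking the top k indices draws k samples without replacement from the corresponding distribution. A plain-Python sketch of the same idea (stdlib only; the repo version does this batched in torch):

```python
import math
import random

def gumbel():
    # Gumbel(0, 1) sample via the inverse-CDF trick, with the same 1e-7
    # guards against log(0) as the torch version
    u = random.random()
    return -math.log(-math.log(u + 1e-7) + 1e-7)

def sample_without_replacement(logp, k):
    # perturb each log-probability with Gumbel noise, keep the top-k indices
    scores = [(lp + gumbel(), i) for i, lp in enumerate(logp)]
    scores.sort(reverse=True)
    return [i for _, i in scores[:k]]

logp = [math.log(p) for p in (0.1, 0.2, 0.3, 0.4)]
idx = sample_without_replacement(logp, 2)  # two distinct indices in 0..3
```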
================================================
FILE: src/variations/decode_morton.py
================================================
import numpy as np
import torch
def compact(value):
x = value & 0x1249249249249249
x = (x | x >> 2) & 0x10c30c30c30c30c3
x = (x | x >> 4) & 0x100f00f00f00f00f
x = (x | x >> 8) & 0x1f0000ff0000ff
x = (x | x >> 16) & 0x1f00000000ffff
x = (x | x >> 32) & 0x1fffff
return x
def decode(code):
return compact(code >> 0), compact(code >> 1), compact(code >> 2)
# Debugging snippet: `samples_valid` comes from an interactive render_rays
# session (see render_helpers.py), so this loop does not run standalone.
for i in range(10):
x, y, z = decode(samples_valid['sampled_point_voxel_idx'][i])
print(x, y, z)
print(torch.sqrt((x-80)**2+(y-80)**2+(z-80)**2))
print(samples_valid['sampled_point_depth'][i])
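`compact` above is the bit-packing half of 3D Morton decoding. A self-contained round-trip check, adding a hypothetical `split`/`encode` pair (the standard inverse bit-spreading; not part of this repo) to verify the masks:

```python
def compact(value):
    # same bit-packing as decode_morton.compact: keep every third bit
    x = value & 0x1249249249249249
    x = (x | x >> 2) & 0x10c30c30c30c30c3
    x = (x | x >> 4) & 0x100f00f00f00f00f
    x = (x | x >> 8) & 0x1f0000ff0000ff
    x = (x | x >> 16) & 0x1f00000000ffff
    x = (x | x >> 32) & 0x1fffff
    return x

def split(x):
    # inverse of compact: spread a 21-bit integer so its bits occupy
    # every third position
    x &= 0x1fffff
    x = (x | x << 32) & 0x1f00000000ffff
    x = (x | x << 16) & 0x1f0000ff0000ff
    x = (x | x << 8) & 0x100f00f00f00f00f
    x = (x | x << 4) & 0x10c30c30c30c30c3
    x = (x | x << 2) & 0x1249249249249249
    return x

def encode(x, y, z):
    # interleave three coordinates into one Morton code
    return split(x) | (split(y) << 1) | (split(z) << 2)

def decode(code):
    return compact(code >> 0), compact(code >> 1), compact(code >> 2)
```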
================================================
FILE: src/variations/lidar.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class GaussianFourierFeatureTransform(torch.nn.Module):
"""
Modified based on the implementation of Gaussian Fourier feature mapping.
"Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains":
https://arxiv.org/abs/2006.10739
https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html
"""
def __init__(self, num_input_channels, mapping_size=93, scale=25, learnable=True):
super().__init__()
if learnable:
self._B = nn.Parameter(torch.randn(
(num_input_channels, mapping_size)) * scale)
else:
self._B = torch.randn((num_input_channels, mapping_size)) * scale
self.embedding_size = mapping_size
def forward(self, x):
# x = x.squeeze(0)
assert x.dim() == 2, 'Expected 2D input (got {}D input)'.format(x.dim())
x = x @ self._B.to(x.device)
return torch.sin(x)
class Nerf_positional_embedding(torch.nn.Module):
"""
Nerf positional embedding.
"""
def __init__(self, in_dim, multires, log_sampling=True):
super().__init__()
self.log_sampling = log_sampling
self.include_input = True
self.periodic_fns = [torch.sin, torch.cos]
self.max_freq_log2 = multires-1
self.num_freqs = multires
self.max_freq = self.max_freq_log2
self.N_freqs = self.num_freqs
self.embedding_size = multires*in_dim*2 + in_dim
def forward(self, x):
# x = x.squeeze(0)
assert x.dim() == 2, 'Expected 2D input (got {}D input)'.format(
x.dim())
if self.log_sampling:
freq_bands = 2.**torch.linspace(0.,
self.max_freq, steps=self.N_freqs)
else:
freq_bands = torch.linspace(
2.**0., 2.**self.max_freq, steps=self.N_freqs)
output = []
if self.include_input:
output.append(x)
for freq in freq_bands:
for p_fn in self.periodic_fns:
output.append(p_fn(x * freq))
ret = torch.cat(output, dim=1)
return ret
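`embedding_size = multires*in_dim*2 + in_dim` counts the passed-through input plus sin and cos at each of `multires` frequency bands. A minimal plain-Python sketch of the same encoding, assuming the log-spaced band layout used above (`nerf_embed` is an illustrative stand-in, not a repo function):

```python
import math

def nerf_embed(x, multires):
    # x: list of floats; log-spaced frequency bands 2^0 .. 2^(multires-1),
    # input included first as in Nerf_positional_embedding
    out = list(x)
    for j in range(multires):
        freq = 2.0 ** j
        for fn in (math.sin, math.cos):
            out.extend(fn(v * freq) for v in x)
    return out

emb = nerf_embed([0.1, 0.2, 0.3], multires=6)
# output size matches embedding_size: 6 * 3 * 2 + 3 = 39
```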
class Same(nn.Module):
def __init__(self, in_dim) -> None:
super().__init__()
self.embedding_size = in_dim
def forward(self, x):
return x
class Decoder(nn.Module):
def __init__(self,
depth=8,
width=258,
in_dim=3,
sdf_dim=128,
skips=[4],
multires=6,
embedder='none',
point_dim=3,
local_coord=False,
**kwargs) -> None:
super().__init__()
self.D = depth
self.W = width
self.skips = skips
self.point_dim = point_dim
if embedder == 'nerf':
self.pe = Nerf_positional_embedding(in_dim, multires)
elif embedder == 'none':
self.pe = Same(in_dim)
elif embedder == 'gaussian':
self.pe = GaussianFourierFeatureTransform(in_dim)
else:
raise NotImplementedError("unknown positional encoder")
self.pts_linears = nn.ModuleList(
[nn.Linear(self.pe.embedding_size, width)] + [nn.Linear(width, width) if i not in self.skips else nn.Linear(width + self.pe.embedding_size, width) for i in range(depth-1)])
self.sdf_out = nn.Linear(width, 1)
def get_values(self, input):
x = self.pe(input)
# point = input[:, -3:]
h = x
for i, l in enumerate(self.pts_linears):
h = self.pts_linears[i](h)
h = F.relu(h)
if i in self.skips:
h = torch.cat([x, h], -1)
# outputs = self.output_linear(h)
# outputs[:, :3] = torch.sigmoid(outputs[:, :3])
sdf_out = self.sdf_out(h)
return sdf_out
def forward(self, inputs):
outputs = self.get_values(inputs)
return {
'sdf': outputs,
# 'depth': outputs[:, 1]
}
================================================
FILE: src/variations/render_helpers.py
================================================
from copy import deepcopy
import torch
import torch.nn.functional as F
from .voxel_helpers import ray_intersect, ray_sample
from torch.autograd import grad
def ray(ray_start, ray_dir, depths):
return ray_start + ray_dir * depths
def fill_in(shape, mask, input, initial=1.0):
if isinstance(initial, torch.Tensor):
output = initial.expand(*shape)
else:
output = input.new_ones(*shape) * initial
return output.masked_scatter(mask.unsqueeze(-1).expand(*shape), input)
def masked_scatter(mask, x):
B, K = mask.size()
if x.dim() == 1:
return x.new_zeros(B, K).masked_scatter(mask, x)
return x.new_zeros(B, K, x.size(-1)).masked_scatter(
mask.unsqueeze(-1).expand(B, K, x.size(-1)), x
)
def masked_scatter_ones(mask, x):
B, K = mask.size()
if x.dim() == 1:
return x.new_ones(B, K).masked_scatter(mask, x)
return x.new_ones(B, K, x.size(-1)).masked_scatter(
mask.unsqueeze(-1).expand(B, K, x.size(-1)), x
)
@torch.enable_grad()
def trilinear_interp(p, q, point_feats):
weights = (p * q + (1 - p) * (1 - q)).prod(dim=-1, keepdim=True)
if point_feats.dim() == 2:
point_feats = point_feats.view(point_feats.size(0), 8, -1)
point_feats = (weights * point_feats).sum(1)
return point_feats
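`trilinear_interp` uses the standard trilinear weight `prod(p*q + (1-p)*(1-q))` over the three axes, where `q` ranges over the eight voxel corners in {0,1}^3 and `p` is the fractional position in [0,1]^3. A quick plain-Python check that the eight corner weights are non-negative and sum to one (`trilinear_weights` is an illustrative helper, not a repo function):

```python
from itertools import product

def trilinear_weights(p):
    # weight of each corner q in {0,1}^3 for fractional position p
    return [(p[0] * q0 + (1 - p[0]) * (1 - q0)) *
            (p[1] * q1 + (1 - p[1]) * (1 - q1)) *
            (p[2] * q2 + (1 - p[2]) * (1 - q2))
            for q0, q1, q2 in product((0, 1), repeat=3)]

w = trilinear_weights((0.25, 0.5, 0.8))
# a convex combination: per axis the two factors sum to 1, so the
# product over axes sums to 1 across all eight corners
```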
def offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2):
c = torch.arange(1, 2 * bits, 2, device=point_xyz.device)
ox, oy, oz = torch.meshgrid([c, c, c], indexing='ij')
offset = (torch.cat([
ox.reshape(-1, 1),
oy.reshape(-1, 1),
oz.reshape(-1, 1)], 1).type_as(point_xyz) - bits) / float(bits - 1)
if not offset_only:
return (
point_xyz.unsqueeze(1) + offset.unsqueeze(0).type_as(point_xyz) * quarter_voxel)
return offset.type_as(point_xyz) * quarter_voxel
@torch.enable_grad()
def get_embeddings(sampled_xyz, point_xyz, point_feats, voxel_size):
# tri-linear interpolation
p = ((sampled_xyz - point_xyz) / voxel_size + 0.5).unsqueeze(1)
q = offset_points(p, 0.5, offset_only=True).unsqueeze(0) + 0.5
feats = trilinear_interp(p, q, point_feats).float()
# if self.args.local_coord:
# feats = torch.cat([(p-.5).squeeze(1).float(), feats], dim=-1)
return feats
@torch.enable_grad()
def get_features(samples, map_states, voxel_size):
# encoder states
point_idx = map_states["voxel_vertex_idx"].cuda()
point_xyz = map_states["voxel_center_xyz"].cuda()
values = map_states["voxel_vertex_emb"]
point_id2embedid = map_states["voxel_id2embedding_id"]
# ray point samples
sampled_idx = samples["sampled_point_voxel_idx"].long()
sampled_xyz = samples["sampled_point_xyz"]
sampled_dis = samples["sampled_point_distance"]
point_xyz = F.embedding(sampled_idx, point_xyz).requires_grad_()
selected_points_idx = F.embedding(sampled_idx, point_idx)
flatten_selected_points_idx = selected_points_idx.view(-1)
embed_idx = F.embedding(flatten_selected_points_idx.cpu(), point_id2embedid).squeeze(-1)
point_feats = F.embedding(embed_idx.cuda(), values).view(point_xyz.size(0), -1)
feats = get_embeddings(sampled_xyz, point_xyz, point_feats, voxel_size)
inputs = {"xyz": point_xyz, "dists": sampled_dis, "emb": feats.cuda()}
return inputs
@torch.no_grad()
def get_scores(sdf_network, map_states, voxel_size, bits=8):
feats = map_states["voxel_vertex_idx"]
points = map_states["voxel_center_xyz"]
values = map_states["voxel_vertex_emb"]
point_id2embedid = map_states["voxel_id2embedding_id"]
chunk_size = 10000
res = bits # -1
@torch.no_grad()
def get_scores_once(feats, points, values, point_id2embedid):
torch.cuda.empty_cache()
# sample points inside voxels
start = -0.5
end = 0.5 # - 1./bits
x = y = z = torch.linspace(start, end, res)
# z = torch.linspace(1, 1, res)
xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')
sampled_xyz = torch.stack([xx, yy, zz], dim=-1).float().cuda()
sampled_xyz *= voxel_size
sampled_xyz = sampled_xyz.reshape(1, -1, 3) + points.unsqueeze(1)
sampled_idx = torch.arange(points.size(0), device=points.device)
sampled_idx = sampled_idx[:, None].expand(*sampled_xyz.size()[:2])
sampled_idx = sampled_idx.reshape(-1)
sampled_xyz = sampled_xyz.reshape(-1, 3)
if sampled_xyz.shape[0] == 0:
return
field_inputs = get_features(
{
"sampled_point_xyz": sampled_xyz,
"sampled_point_voxel_idx": sampled_idx,
"sampled_point_ray_direction": None,
"sampled_point_distance": None,
},
{
"voxel_vertex_idx": feats,
"voxel_center_xyz": points,
"voxel_vertex_emb": values,
"voxel_id2embedding_id": point_id2embedid
},
voxel_size
)
field_inputs = field_inputs["emb"]
# evaluation with density
sdf_values = sdf_network.get_values(field_inputs.float().cuda())
return sdf_values.reshape(-1, res ** 3, 1).detach().cpu()
return torch.cat([
get_scores_once(feats[i: i + chunk_size],
points[i: i + chunk_size].cuda(), values, point_id2embedid)
for i in range(0, points.size(0), chunk_size)], 0).view(-1, res, res, res, 1)
@torch.no_grad()
def eval_points(sdf_network, map_states, sampled_xyz, sampled_idx, voxel_size):
feats = map_states["voxel_vertex_idx"]
points = map_states["voxel_center_xyz"]
values = map_states["voxel_vertex_emb"]
point_id2embedid = map_states["voxel_id2embedding_id"]
# sampled_xyz = sampled_xyz.reshape(1, 3) + points.unsqueeze(1)
# sampled_idx = sampled_idx[None, :].expand(*sampled_xyz.size()[:2])
sampled_idx = sampled_idx.reshape(-1)
sampled_xyz = sampled_xyz.reshape(-1, 3)
if sampled_xyz.shape[0] == 0:
return
field_inputs = get_features(
{
"sampled_point_xyz": sampled_xyz,
"sampled_point_voxel_idx": sampled_idx,
"sampled_point_ray_direction": None,
"sampled_point_distance": None,
},
{
"voxel_vertex_idx": feats,
"voxel_center_xyz": points,
"voxel_vertex_emb": values,
"voxel_id2embedding_id": point_id2embedid,
},
voxel_size
)
# evaluation with density
sdf_values = sdf_network.get_values(field_inputs['emb'].float().cuda())
return sdf_values.reshape(-1, 4)[:, :3].detach().cpu()
def render_rays(
rays_o,
rays_d,
map_states,
sdf_network,
step_size,
voxel_size,
truncation,
max_voxel_hit,
max_distance,
chunk_size=10000,
profiler=None,
return_raw=False
):
torch.cuda.empty_cache()
centres = map_states["voxel_center_xyz"].cuda()
childrens = map_states["voxel_structure"].cuda()
if profiler is not None:
profiler.tick("ray_intersect")
# print("Center", rays_o[0][0])
intersections, hits = ray_intersect(
rays_o, rays_d, centres,
childrens, voxel_size, max_voxel_hit, max_distance)
if profiler is not None:
profiler.tok("ray_intersect")
if hits.sum() <= 0:
return
ray_mask = hits.view(1, -1)
intersections = {
name: outs[ray_mask].reshape(-1, outs.size(-1))
for name, outs in intersections.items()
}
rays_o = rays_o[ray_mask].reshape(-1, 3)
rays_d = rays_d[ray_mask].reshape(-1, 3)
if profiler is not None:
profiler.tick("ray_sample")
samples = ray_sample(intersections, step_size=step_size)
if samples is None:
return
if profiler is not None:
profiler.tok("ray_sample")
sampled_depth = samples['sampled_point_depth']
sampled_idx = samples['sampled_point_voxel_idx'].long()
# only compute when the ray hits
sample_mask = sampled_idx.ne(-1)
if sample_mask.sum() == 0:  # missed everything, skip
return None
sampled_xyz = ray(rays_o.unsqueeze(
1), rays_d.unsqueeze(1), sampled_depth.unsqueeze(2))
sampled_dir = rays_d.unsqueeze(1).expand(
*sampled_depth.size(), rays_d.size()[-1])
sampled_dir = sampled_dir / \
(torch.norm(sampled_dir, 2, -1, keepdim=True) + 1e-8)
samples['sampled_point_xyz'] = sampled_xyz
samples['sampled_point_ray_direction'] = sampled_dir
# apply mask
samples_valid = {name: s[sample_mask] for name, s in samples.items()}
# print("samples_valid_xyz", samples["sampled_point_xyz"].shape)
num_points = samples_valid['sampled_point_depth'].shape[0]
field_outputs = []
if chunk_size < 0:
chunk_size = num_points
final_xyz = []
xyz = 0
for i in range(0, num_points, chunk_size):
torch.cuda.empty_cache()
chunk_samples = {name: s[i:i+chunk_size]
for name, s in samples_valid.items()}
# get encoder features as inputs
if profiler is not None:
profiler.tick("get_features")
chunk_inputs = get_features(chunk_samples, map_states, voxel_size)
xyz = chunk_inputs["xyz"]
if profiler is not None:
profiler.tok("get_features")
# add coordinate information
chunk_inputs = chunk_inputs["emb"]
# forward implicit fields
if profiler is not None:
profiler.tick("render_core")
chunk_outputs = sdf_network(chunk_inputs)
if profiler is not None:
profiler.tok("render_core")
final_xyz.append(xyz)
field_outputs.append(chunk_outputs)
field_outputs = {name: torch.cat(
[r[name] for r in field_outputs], dim=0) for name in field_outputs[0]}
final_xyz = torch.cat(final_xyz, 0)
outputs = field_outputs['sdf']
d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
sdf_grad = grad(outputs=outputs,
inputs=xyz,
grad_outputs=d_points,
retain_graph=True,)
outputs = {'sample_mask': sample_mask}
sdf = masked_scatter_ones(sample_mask, field_outputs['sdf']).squeeze(-1)
# depth = masked_scatter(sample_mask, field_outputs['depth'])
# colour = torch.sigmoid(colour)
sample_mask = outputs['sample_mask']
valid_mask = torch.where(
sample_mask, torch.ones_like(
sample_mask), torch.zeros_like(sample_mask)
)
return {
"z_vals": samples["sampled_point_depth"],
"sdf": sdf,
"ray_mask": ray_mask,
"valid_mask": valid_mask,
"sampled_xyz": xyz,
}
def bundle_adjust_frames(
keyframe_graph,
embeddings,
map_states,
sdf_network,
loss_criteria,
voxel_size,
step_size,
N_rays=512,
num_iterations=10,
truncation=0.1,
max_voxel_hit=10,
max_distance=10,
learning_rate=[1e-2, 1e-2, 5e-3],
update_pose=True,
update_decoder=True,
profiler=None
):
if profiler is not None:
profiler.tick("mapping_add_optim")
optimize_params = [{'params': embeddings, 'lr': learning_rate[0]}]
if update_decoder:
optimize_params += [{'params': sdf_network.parameters(),
'lr': learning_rate[1]}]
for keyframe in keyframe_graph:
if keyframe.index != 0 and update_pose:
keyframe.pose.requires_grad_(True)
optimize_params += [{
'params': keyframe.pose.parameters(), 'lr': learning_rate[2]
}]
optim = torch.optim.Adam(optimize_params)
if profiler is not None:
profiler.tok("mapping_add_optim")
for iter in range(num_iterations):
torch.cuda.empty_cache()
rays_o = []
rays_d = []
rgb_samples = []
depth_samples = []
points_samples = []
pointsCos_samples = []
if iter == 0 and profiler is not None:
profiler.tick("mapping sample_rays")
for frame in keyframe_graph:
torch.cuda.empty_cache()
pose = frame.get_pose().cuda()
frame.sample_rays(N_rays)
sample_mask = frame.sample_mask.cuda()
sampled_rays_d = frame.rays_d[sample_mask].cuda()
# print(sampled_rays_d)
R = pose[: 3, : 3].transpose(-1, -2)
sampled_rays_d = sampled_rays_d@R
sampled_rays_o = pose[: 3, 3].reshape(1, -1).expand_as(sampled_rays_d)
rays_d += [sampled_rays_d]
rays_o += [sampled_rays_o]
points_samples += [frame.points.unsqueeze(1).cuda()[sample_mask]]
pointsCos_samples += [frame.pointsCos.unsqueeze(1).cuda()[sample_mask]]
# rgb_samples += [frame.rgb.cuda()[sample_mask]]
# depth_samples += [frame.depth.cuda()[sample_mask]]
rays_d = torch.cat(rays_d, dim=0).unsqueeze(0)
rays_o = torch.cat(rays_o, dim=0).unsqueeze(0)
points_samples = torch.cat(points_samples, dim=0).unsqueeze(0)
pointsCos_samples = torch.cat(pointsCos_samples, dim=0).unsqueeze(0)
if iter == 0 and profiler is not None:
profiler.tok("mapping sample_rays")
if iter == 0 and profiler is not None:
profiler.tick("mapping rendering")
final_outputs = render_rays(
rays_o,
rays_d,
map_states,
sdf_network,
step_size,
voxel_size,
truncation,
max_voxel_hit,
max_distance,
chunk_size=-1,
profiler=profiler if iter == 0 else None
)
if final_outputs is None:
print("Encountered a known issue during mapping (not yet fixed); skipping this iteration.")
hit_mask = None
continue
if iter == 0 and profiler is not None:
profiler.tok("mapping rendering")
# if final_outputs == None:
# continue
if iter == 0 and profiler is not None:
profiler.tick("mapping back proj")
torch.cuda.empty_cache()
loss, _ = loss_criteria(
final_outputs, points_samples, pointsCos_samples)
optim.zero_grad()
loss.backward()
optim.step()
if iter == 0 and profiler is not None:
profiler.tok("mapping back proj")
def track_frame(
frame_pose,
curr_frame,
map_states,
sdf_network,
loss_criteria,
voxel_size,
N_rays=512,
step_size=0.05,
num_iterations=10,
truncation=0.1,
learning_rate=1e-3,
max_voxel_hit=10,
max_distance=10,
profiler=None,
depth_variance=False
):
torch.cuda.empty_cache()
init_pose = deepcopy(frame_pose).cuda()
init_pose.requires_grad_(True)
optim = torch.optim.Adam(init_pose.parameters(),
lr=learning_rate*2 if curr_frame.index < 2
else learning_rate/3)
for iter in range(num_iterations):
torch.cuda.empty_cache()
if iter == 0 and profiler is not None:
profiler.tick("track sample_rays")
curr_frame.sample_rays(N_rays, track=True)
if iter == 0 and profiler is not None:
profiler.tok("track sample_rays")
sample_mask = curr_frame.sample_mask
ray_dirs = curr_frame.rays_d[sample_mask].unsqueeze(0).cuda()
points_samples = curr_frame.points.unsqueeze(1).cuda()[sample_mask]
pointsCos_samples = curr_frame.pointsCos.unsqueeze(1).cuda()[sample_mask]
ray_dirs_iter = ray_dirs.squeeze(
0) @ init_pose.rotation().transpose(-1, -2)
ray_dirs_iter = ray_dirs_iter.unsqueeze(0)
ray_start_iter = init_pose.translation().reshape(
1, 1, -1).expand_as(ray_dirs_iter).cuda().contiguous()
if iter == 0 and profiler is not None:
profiler.tick("track render_rays")
final_outputs = render_rays(
ray_start_iter,
ray_dirs_iter,
map_states,
sdf_network,
step_size,
voxel_size,
truncation,
max_voxel_hit,
max_distance,
chunk_size=-2,
profiler=profiler if iter == 0 else None
)
if final_outputs is None:
print("Encountered a known issue during tracking (not yet fixed); falling back to the motion prior.")
hit_mask = None
break
torch.cuda.empty_cache()
if iter == 0 and profiler is not None:
profiler.tok("track render_rays")
hit_mask = final_outputs["ray_mask"].view(N_rays)
final_outputs["ray_mask"] = hit_mask
if iter == 0 and profiler is not None:
profiler.tick("track loss_criteria")
loss, _ = loss_criteria(
final_outputs, points_samples, pointsCos_samples, weight_depth_loss=depth_variance)
if iter == 0 and profiler is not None:
profiler.tok("track loss_criteria")
if iter == 0 and profiler is not None:
profiler.tick("track backward step")
optim.zero_grad()
loss.backward()
optim.step()
if iter == 0 and profiler is not None:
profiler.tok("track backward step")
return init_pose, hit_mask
================================================
FILE: src/variations/voxel_helpers.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
""" Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch """
from __future__ import (
division,
absolute_import,
with_statement,
print_function,
unicode_literals,
)
import os
import sys
import torch
import torch.nn.functional as F
from torch.autograd import Function
import torch.nn as nn
import sys
import numpy as np
import grid as _ext
MAX_DEPTH = 80
class BallRayIntersect(Function):
@staticmethod
def forward(ctx, radius, n_max, points, ray_start, ray_dir):
inds, min_depth, max_depth = _ext.ball_intersect(
ray_start.float(), ray_dir.float(), points.float(), radius, n_max
)
min_depth = min_depth.type_as(ray_start)
max_depth = max_depth.type_as(ray_start)
ctx.mark_non_differentiable(inds)
ctx.mark_non_differentiable(min_depth)
ctx.mark_non_differentiable(max_depth)
return inds, min_depth, max_depth
@staticmethod
def backward(ctx, a, b, c):
return None, None, None, None, None
ball_ray_intersect = BallRayIntersect.apply
class AABBRayIntersect(Function):
@staticmethod
def forward(ctx, voxelsize, n_max, points, ray_start, ray_dir):
# HACK: speed-up ray-voxel intersection by batching...
# HACK: avoid out-of-memory
G = min(2048, int(2 * 10 ** 9 / points.numel()))
S, N = ray_start.shape[:2]
K = int(np.ceil(N / G))
H = K * G
if H > N:
ray_start = torch.cat([ray_start, ray_start[:, : H - N]], 1)
ray_dir = torch.cat([ray_dir, ray_dir[:, : H - N]], 1)
ray_start = ray_start.reshape(S * G, K, 3)
ray_dir = ray_dir.reshape(S * G, K, 3)
points = points.expand(S * G, *points.size()[1:]).contiguous()
inds, min_depth, max_depth = _ext.aabb_intersect(
ray_start.float(), ray_dir.float(), points.float(), voxelsize, n_max
)
min_depth = min_depth.type_as(ray_start)
max_depth = max_depth.type_as(ray_start)
inds = inds.reshape(S, H, -1)
min_depth = min_depth.reshape(S, H, -1)
max_depth = max_depth.reshape(S, H, -1)
if H > N:
inds = inds[:, :N]
min_depth = min_depth[:, :N]
max_depth = max_depth[:, :N]
ctx.mark_non_differentiable(inds)
ctx.mark_non_differentiable(min_depth)
ctx.mark_non_differentiable(max_depth)
return inds, min_depth, max_depth
@staticmethod
def backward(ctx, a, b, c):
return None, None, None, None, None
aabb_ray_intersect = AABBRayIntersect.apply
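The `*RayIntersect` wrappers share a pad-reshape-trim batching trick: N rays are padded (by recycling the first rays) up to H = ceil(N/G)*G, processed in G groups of K = H/G, then trimmed back to N. A plain-Python sketch of just that indexing pattern, with lists standing in for the CUDA tensors (`pad_reshape_trim` is an illustrative helper, not a repo function):

```python
import math

def pad_reshape_trim(rays, G):
    # pad N rays to a multiple of G, split into G groups of K, flatten back
    N = len(rays)
    K = math.ceil(N / G)
    H = K * G
    padded = rays + rays[: H - N]        # recycle the head as padding
    groups = [padded[g * K:(g + 1) * K] for g in range(G)]
    # ... per-group intersection work would happen here ...
    flat = [r for grp in groups for r in grp]
    return flat[:N]                      # trim the padding back off

# the reshape/trim round-trip leaves the ray order unchanged
out = pad_reshape_trim(list(range(10)), 4)
```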
class SparseVoxelOctreeRayIntersect(Function):
@staticmethod
def forward(ctx, voxelsize, n_max, points, children, ray_start, ray_dir):
# HACK: avoid out-of-memory
torch.cuda.empty_cache()
G = min(256, int(2 * 10 ** 9 / (points.numel() + children.numel())))
S, N = ray_start.shape[:2]
K = int(np.ceil(N / G))
H = K * G
if H > N:
ray_start = torch.cat([ray_start, ray_start[:, : H - N]], 1)
ray_dir = torch.cat([ray_dir, ray_dir[:, : H - N]], 1)
ray_start = ray_start.reshape(S * G, K, 3)
ray_dir = ray_dir.reshape(S * G, K, 3)
points = points.expand(S * G, *points.size()).contiguous()
torch.cuda.empty_cache()
children = children.expand(S * G, *children.size()).contiguous()
torch.cuda.empty_cache()
inds, min_depth, max_depth = _ext.svo_intersect(
ray_start.float(),
ray_dir.float(),
points.float(),
children.int(),
voxelsize,
n_max,
)
torch.cuda.empty_cache()
min_depth = min_depth.type_as(ray_start)
max_depth = max_depth.type_as(ray_start)
inds = inds.reshape(S, H, -1)
min_depth = min_depth.reshape(S, H, -1)
max_depth = max_depth.reshape(S, H, -1)
if H > N:
inds = inds[:, :N]
min_depth = min_depth[:, :N]
max_depth = max_depth[:, :N]
ctx.mark_non_differentiable(inds)
ctx.mark_non_differentiable(min_depth)
ctx.mark_non_differentiable(max_depth)
return inds, min_depth, max_depth
@staticmethod
def backward(ctx, a, b, c):
return None, None, None, None, None
svo_ray_intersect = SparseVoxelOctreeRayIntersect.apply
class TriangleRayIntersect(Function):
@staticmethod
def forward(ctx, cagesize, blur_ratio, n_max, points, faces, ray_start, ray_dir):
# HACK: speed-up ray-voxel intersection by batching...
# HACK: avoid out-of-memory
G = min(2048, int(2 * 10 ** 9 / (3 * faces.numel())))
S, N = ray_start.shape[:2]
K = int(np.ceil(N / G))
H = K * G
if H > N:
ray_start = torch.cat([ray_start, ray_start[:, : H - N]], 1)
ray_dir = torch.cat([ray_dir, ray_dir[:, : H - N]], 1)
ray_start = ray_start.reshape(S * G, K, 3)
ray_dir = ray_dir.reshape(S * G, K, 3)
face_points = F.embedding(faces.reshape(-1, 3), points.reshape(-1, 3))
face_points = (
face_points.unsqueeze(0).expand(
S * G, *face_points.size()).contiguous()
)
inds, depth, uv = _ext.triangle_intersect(
ray_start.float(),
ray_dir.float(),
face_points.float(),
cagesize,
blur_ratio,
n_max,
)
depth = depth.type_as(ray_start)
uv = uv.type_as(ray_start)
inds = inds.reshape(S, H, -1)
depth = depth.reshape(S, H, -1, 3)
uv = uv.reshape(S, H, -1)
if H > N:
inds = inds[:, :N]
depth = depth[:, :N]
uv = uv[:, :N]
ctx.mark_non_differentiable(inds)
ctx.mark_non_differentiable(depth)
ctx.mark_non_differentiable(uv)
return inds, depth, uv
@staticmethod
def backward(ctx, a, b, c):
return None, None, None, None, None, None
triangle_ray_intersect = TriangleRayIntersect.apply
class UniformRaySampling(Function):
@staticmethod
def forward(
ctx,
pts_idx,
min_depth,
max_depth,
step_size,
max_ray_length,
deterministic=False,
):
G, N, P = 256, pts_idx.size(0), pts_idx.size(1)
H = int(np.ceil(N / G)) * G
if H > N:
pts_idx = torch.cat([pts_idx, pts_idx[: H - N]], 0)
min_depth = torch.cat([min_depth, min_depth[: H - N]], 0)
max_depth = torch.cat([max_depth, max_depth[: H - N]], 0)
pts_idx = pts_idx.reshape(G, -1, P)
min_depth = min_depth.reshape(G, -1, P)
max_depth = max_depth.reshape(G, -1, P)
# pre-generate noise
max_steps = int(max_ray_length / step_size)
max_steps = max_steps + min_depth.size(-1) * 2
noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps)
if deterministic:
noise += 0.5
else:
noise = noise.uniform_()
# call cuda function
sampled_idx, sampled_depth, sampled_dists = _ext.uniform_ray_sampling(
pts_idx,
min_depth.float(),
max_depth.float(),
noise.float(),
step_size,
max_steps,
)
sampled_depth = sampled_depth.type_as(min_depth)
sampled_dists = sampled_dists.type_as(min_depth)
sampled_idx = sampled_idx.reshape(H, -1)
sampled_depth = sampled_depth.reshape(H, -1)
sampled_dists = sampled_dists.reshape(H, -1)
if H > N:
sampled_idx = sampled_idx[:N]
sampled_depth = sampled_depth[:N]
sampled_dists = sampled_dists[:N]
max_len = sampled_idx.ne(-1).sum(-1).max()
sampled_idx = sampled_idx[:, :max_len]
sampled_depth = sampled_depth[:, :max_len]
sampled_dists = sampled_dists[:, :max_len]
ctx.mark_non_differentiable(sampled_idx)
ctx.mark_non_differentiable(sampled_depth)
ctx.mark_non_differentiable(sampled_dists)
return sampled_idx, sampled_depth, sampled_dists
@staticmethod
def backward(ctx, a, b, c):
return None, None, None, None, None, None
uniform_ray_sampling = UniformRaySampling.apply
class InverseCDFRaySampling(Function):
@staticmethod
def forward(
ctx,
pts_idx,
min_depth,
max_depth,
probs,
steps,
fixed_step_size=-1,
deterministic=False,
):
G, N, P = 200, pts_idx.size(0), pts_idx.size(1)
H = int(np.ceil(N / G)) * G
if H > N:
pts_idx = torch.cat([pts_idx, pts_idx[:1].expand(H - N, P)], 0)
min_depth = torch.cat(
[min_depth, min_depth[:1].expand(H - N, P)], 0)
max_depth = torch.cat(
[max_depth, max_depth[:1].expand(H - N, P)], 0)
probs = torch.cat([probs, probs[:1].expand(H - N, P)], 0)
steps = torch.cat([steps, steps[:1].expand(H - N)], 0)
# print(G, P, np.ceil(N / G), N, H, pts_idx.shape, min_depth.device)
pts_idx = pts_idx.reshape(G, -1, P)
min_depth = min_depth.reshape(G, -1, P)
max_depth = max_depth.reshape(G, -1, P)
probs = probs.reshape(G, -1, P)
steps = steps.reshape(G, -1)
# pre-generate noise
max_steps = steps.ceil().long().max() + P
# print(max_steps)
# print(*min_depth.size()[:-1]," ", max_steps)
noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps)
if deterministic:
noise += 0.5
else:
noise = noise.uniform_().clamp(min=0.001, max=0.999) # keep noise strictly inside (0, 1)
# call cuda function
chunk_size = 4 * G # process in chunks to avoid out-of-memory
results = [
_ext.inverse_cdf_sampling(
pts_idx[:, i: i + chunk_size].contiguous(),
min_depth.float()[:, i: i + chunk_size].contiguous(),
max_depth.float()[:, i: i + chunk_size].contiguous(),
noise.float()[:, i: i + chunk_size].contiguous(),
probs.float()[:, i: i + chunk_size].contiguous(),
steps.float()[:, i: i + chunk_size].contiguous(),
fixed_step_size,
)
for i in range(0, min_depth.size(1), chunk_size)
]
sampled_idx, sampled_depth, sampled_dists = [
torch.cat([r[i] for r in results], 1) for i in range(3)
]
sampled_depth = sampled_depth.type_as(min_depth)
sampled_dists = sampled_dists.type_as(min_depth)
sampled_idx = sampled_idx.reshape(H, -1)
sampled_depth = sampled_depth.reshape(H, -1)
sampled_dists = sampled_dists.reshape(H, -1)
if H > N:
sampled_idx = sampled_idx[:N]
sampled_depth = sampled_depth[:N]
sampled_dists = sampled_dists[:N]
max_len = sampled_idx.ne(-1).sum(-1).max()
sampled_idx = sampled_idx[:, :max_len]
sampled_depth = sampled_depth[:, :max_len]
sampled_dists = sampled_dists[:, :max_len]
ctx.mark_non_differentiable(sampled_idx)
ctx.mark_non_differentiable(sampled_depth)
ctx.mark_non_differentiable(sampled_dists)
return sampled_idx, sampled_depth, sampled_dists
@staticmethod
def backward(ctx, a, b, c):
return None, None, None, None, None, None, None
inverse_cdf_sampling = InverseCDFRaySampling.apply
# back-up for ray point sampling
@torch.no_grad()
def _parallel_ray_sampling(
MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False
):
# uniform sampling
_min_depth = min_depth.min(1)[0]
_max_depth = max_depth.masked_fill(max_depth.eq(MAX_DEPTH), 0).max(1)[0]
max_ray_length = (_max_depth - _min_depth).max()
delta = torch.arange(
int(max_ray_length / MARCH_SIZE), device=min_depth.device, dtype=min_depth.dtype
)
delta = delta[None, :].expand(min_depth.size(0), delta.size(-1))
if deterministic:
delta = delta + 0.5
else:
delta = delta + delta.clone().uniform_().clamp(min=0.01, max=0.99)
delta = delta * MARCH_SIZE
sampled_depth = min_depth[:, :1] + delta
sampled_idx = (sampled_depth[:, :, None] >=
min_depth[:, None, :]).sum(-1) - 1
sampled_idx = pts_idx.gather(1, sampled_idx)
# include all boundary points
sampled_depth = torch.cat([min_depth, max_depth, sampled_depth], -1)
sampled_idx = torch.cat([pts_idx, pts_idx, sampled_idx], -1)
# reorder
sampled_depth, ordered_index = sampled_depth.sort(-1)
sampled_idx = sampled_idx.gather(1, ordered_index)
sampled_dists = sampled_depth[:, 1:] - sampled_depth[:, :-1] # distances
sampled_depth = 0.5 * \
(sampled_depth[:, 1:] + sampled_depth[:, :-1]) # mid-points
# remove all invalid depths
min_ids = (sampled_depth[:, :, None] >= min_depth[:, None, :]).sum(-1) - 1
max_ids = (sampled_depth[:, :, None] >= max_depth[:, None, :]).sum(-1)
sampled_depth.masked_fill_(
(max_ids.ne(min_ids))
| (sampled_depth > _max_depth[:, None])
| (sampled_dists == 0.0),
MAX_DEPTH,
)
sampled_depth, ordered_index = sampled_depth.sort(-1) # sort again
sampled_masks = sampled_depth.eq(MAX_DEPTH)
num_max_steps = (~sampled_masks).sum(-1).max()
sampled_depth = sampled_depth[:, :num_max_steps]
sampled_dists = sampled_dists.gather(1, ordered_index).masked_fill_(
sampled_masks, 0.0
)[:, :num_max_steps]
sampled_idx = sampled_idx.gather(1, ordered_index).masked_fill_(sampled_masks, -1)[
:, :num_max_steps
]
return sampled_idx, sampled_depth, sampled_dists
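`_parallel_ray_sampling` converts the sorted boundary depths into midpoint samples and per-segment lengths. A small pure-Python sketch of that step (names are illustrative):

```python
def midpoints_and_dists(depths):
    """Sorted boundary depths -> segment midpoints and lengths, as in the
    sampled_depth / sampled_dists computation of _parallel_ray_sampling."""
    dists = [b - a for a, b in zip(depths, depths[1:])]   # segment lengths
    mids = [0.5 * (a + b) for a, b in zip(depths, depths[1:])]  # mid-points
    return mids, dists

mids, dists = midpoints_and_dists([0.0, 1.0, 4.0])
assert mids == [0.5, 2.5]
assert dists == [1.0, 3.0]
```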
@torch.no_grad()
def parallel_ray_sampling(
MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False
):
chunk_size = 4096
full_size = min_depth.shape[0]
if full_size <= chunk_size:
return _parallel_ray_sampling(
MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=deterministic
)
outputs = zip(
*[
_parallel_ray_sampling(
MARCH_SIZE,
pts_idx[i: i + chunk_size],
min_depth[i: i + chunk_size],
max_depth[i: i + chunk_size],
deterministic=deterministic,
)
for i in range(0, full_size, chunk_size)
]
)
sampled_idx, sampled_depth, sampled_dists = outputs
def padding_points(xs, pad):
if len(xs) == 1:
return xs[0]
maxlen = max([x.size(1) for x in xs])
full_size = sum([x.size(0) for x in xs])
xt = xs[0].new_ones(full_size, maxlen).fill_(pad)
st = 0
for i in range(len(xs)):
xt[st: st + xs[i].size(0), : xs[i].size(1)] = xs[i]
st += xs[i].size(0)
return xt
sampled_idx = padding_points(sampled_idx, -1)
sampled_depth = padding_points(sampled_depth, MAX_DEPTH)
sampled_dists = padding_points(sampled_dists, 0.0)
return sampled_idx, sampled_depth, sampled_dists
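The inner `padding_points` helper right-pads chunks of different widths with a fill value before stacking them. A list-of-lists sketch of the same idea (simplified to one row per chunk; names are illustrative):

```python
def padding_points_sketch(rows, pad):
    """Right-pad variable-length rows to a common width with `pad`,
    mirroring what padding_points does for 2-D tensor chunks."""
    if len(rows) == 1:
        return rows[0]                     # single chunk: nothing to pad
    maxlen = max(len(r) for r in rows)
    return [r + [pad] * (maxlen - len(r)) for r in rows]

out = padding_points_sketch([[1, 2, 3], [4]], pad=-1)
assert out == [[1, 2, 3], [4, -1, -1]]
```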
def discretize_points(voxel_points, voxel_size):
# this function turns voxel centers/corners into integer indices
# we assume all points are already laid out as voxels (real numbers)
minimal_voxel_point = voxel_points.min(dim=0, keepdim=True)[0]
voxel_indices = (
((voxel_points - minimal_voxel_point) / voxel_size).round_().long()
) # integer voxel indices
residual = (voxel_points - voxel_indices.type_as(voxel_points) * voxel_size).mean(
0, keepdim=True
)
return voxel_indices, residual
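`discretize_points` quantizes real voxel coordinates to integer indices relative to the minimum corner. A torch-free sketch of the quantization (the residual bookkeeping is omitted; names are illustrative):

```python
def discretize_sketch(points, voxel_size):
    """Quantize real 3-D coordinates to integer voxel indices relative to the
    per-axis minimum, as discretize_points does before computing the residual."""
    m = [min(p[d] for p in points) for d in range(3)]  # minimal corner
    return [[round((p[d] - m[d]) / voxel_size) for d in range(3)]
            for p in points]

idx = discretize_sketch([[0.0, 0.0, 0.0], [0.2, 0.4, 0.0]], voxel_size=0.2)
assert idx == [[0, 0, 0], [1, 2, 0]]
```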
def build_easy_octree(points, half_voxel):
coords, residual = discretize_points(points, half_voxel)
ranges = coords.max(0)[0] - coords.min(0)[0]
depths = torch.log2(ranges.max().float()).ceil_().long() - 1
center = (coords.max(0)[0] + coords.min(0)[0]) / 2
centers, children = _ext.build_octree(center, coords, int(depths))
centers = centers.float() * half_voxel + residual # transform back to float
return centers, children
@torch.enable_grad()
def trilinear_interp(p, q, point_feats):
weights = (p * q + (1 - p) * (1 - q)).prod(dim=-1, keepdim=True)
if point_feats.dim() == 2:
point_feats = point_feats.view(point_feats.size(0), 8, -1)
point_feats = (weights * point_feats).sum(1)
return point_feats
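The weight rule in `trilinear_interp` is standard trilinear interpolation: for a point p in the unit cube and a corner q in {0,1}^3, the weight is the product over axes of p*q + (1-p)*(1-q), and the 8 weights form a partition of unity. A pure-Python sketch (names are illustrative):

```python
from itertools import product

def trilinear_weights(p):
    """Weights of the 8 voxel corners for p in [0,1]^3, using the same rule
    as trilinear_interp: prod over dims of p*q + (1-p)*(1-q), q in {0,1}."""
    return [
        (p[0] * q[0] + (1 - p[0]) * (1 - q[0]))
        * (p[1] * q[1] + (1 - p[1]) * (1 - q[1]))
        * (p[2] * q[2] + (1 - p[2]) * (1 - q[2]))
        for q in product((0, 1), repeat=3)
    ]

w = trilinear_weights((0.25, 0.5, 0.75))
assert abs(sum(w) - 1.0) < 1e-12     # weights sum to one
```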
def offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2):
c = torch.arange(1, 2 * bits, 2, device=point_xyz.device)
ox, oy, oz = torch.meshgrid([c, c, c], indexing='ij')
offset = (torch.cat([ox.reshape(-1, 1),
oy.reshape(-1, 1),
oz.reshape(-1, 1)], 1).type_as(point_xyz) - bits) / float(bits - 1)
if not offset_only:
return (
point_xyz.unsqueeze(1) + offset.unsqueeze(0).type_as(point_xyz) * quarter_voxel)
return offset.type_as(point_xyz) * quarter_voxel
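With the default `bits=2`, `offset_points` enumerates the 8 corner offsets {-1, +1}^3 scaled by `quarter_voxel` (c = [1, 3], then (c - 2) / 1 gives -1 and +1 per axis). A plain-Python equivalent of that enumeration (names are illustrative):

```python
from itertools import product

def corner_offsets(quarter_voxel=1.0):
    """The 8 offsets {-1, +1}^3 * quarter_voxel produced by offset_points
    with bits=2, in the same meshgrid('ij') ordering."""
    return [tuple(o * quarter_voxel for o in c)
            for c in product((-1.0, 1.0), repeat=3)]

offs = corner_offsets(0.25)
assert len(offs) == 8
assert offs[0] == (-0.25, -0.25, -0.25) and offs[-1] == (0.25, 0.25, 0.25)
```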
def splitting_points(point_xyz, point_feats, values, half_voxel):
# generate new centers
quarter_voxel = half_voxel * 0.5
new_points = offset_points(point_xyz, quarter_voxel).reshape(-1, 3)
old_coords = discretize_points(point_xyz, quarter_voxel)[0]
new_coords = offset_points(old_coords).reshape(-1, 3)
new_keys0 = offset_points(new_coords).reshape(-1, 3)
# get unique keys and inverse indices (for each original key in new_keys0, the position it maps to in new_keys)
new_keys, new_feats = torch.unique(
new_keys0, dim=0, sorted=True, return_inverse=True)
new_keys_idx = new_feats.new_zeros(new_keys.size(0)).scatter_(
0, new_feats, torch.arange(new_keys0.size(0), device=new_feats.device) // 64)
# recompute key vectors using trilinear interpolation
new_feats = new_feats.reshape(-1, 8)
if values is not None:
# (1/4 voxel size)
p = (new_keys - old_coords[new_keys_idx]
).type_as(point_xyz).unsqueeze(1) * 0.25 + 0.5
q = offset_points(p, 0.5, offset_only=True).unsqueeze(0) + 0.5 # BUG?
point_feats = point_feats[new_keys_idx]
point_feats = F.embedding(point_feats, values).view(
point_feats.size(0), -1)
new_values = trilinear_interp(p, q, point_feats)
else:
new_values = None
return new_points, new_feats, new_values, new_keys
@torch.no_grad()
def ray_intersect(ray_start, ray_dir, flatten_centers, flatten_children, voxel_size, max_hits, max_distance=MAX_DEPTH):
# ray-voxel intersection
max_hits_temp = 20
pts_idx, min_depth, max_depth = svo_ray_intersect(
voxel_size,
max_hits_temp,
flatten_centers,
flatten_children,
ray_start,
ray_dir)
torch.cuda.empty_cache()
# sort the depths
min_depth.masked_fill_(pts_idx.eq(-1), max_distance)
max_depth.masked_fill_(pts_idx.eq(-1), max_distance)
min_depth, sorted_idx = min_depth.sort(dim=-1)
max_depth = max_depth.gather(-1, sorted_idx)
pts_idx = pts_idx.gather(-1, sorted_idx)
# print(max_depth.max())
pts_idx[max_depth > 2*max_distance] = -1
pts_idx[min_depth > max_distance] = -1
min_depth.masked_fill_(pts_idx.eq(-1), max_distance)
max_depth.masked_fill_(pts_idx.eq(-1), max_distance)
# remove all points that completely miss the object
max_hits = torch.max(pts_idx.ne(-1).sum(-1))
min_depth = min_depth[..., :max_hits]
max_depth = max_depth[..., :max_hits]
pts_idx = pts_idx[..., :max_hits]
hits = pts_idx.ne(-1).any(-1)
intersection_outputs = {
"min_depth": min_depth,
"max_depth": max_depth,
"intersected_voxel_idx": pts_idx,
}
return intersection_outputs, hits
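After the CUDA call, `ray_intersect` pushes misses (index -1) out to the maximum distance and sorts the remaining hits by entry depth. A per-ray pure-Python sketch of that post-processing (MAX_D stands in for MAX_DEPTH; names are illustrative):

```python
MAX_D = 1e4  # illustrative stand-in for MAX_DEPTH

def sort_hits(pts_idx, min_depth, max_depth):
    """Mirror ray_intersect's post-processing for one ray: fill misses
    (idx == -1) with MAX_D, then order hit voxels by entry depth."""
    md = [d if i != -1 else MAX_D for i, d in zip(pts_idx, min_depth)]
    order = sorted(range(len(md)), key=md.__getitem__)
    return ([pts_idx[j] for j in order],
            [md[j] for j in order],
            [max_depth[j] if pts_idx[j] != -1 else MAX_D for j in order])

idx, lo, hi = sort_hits([3, -1, 1], [2.0, 0.5, 1.0], [2.5, 0.7, 1.4])
assert idx == [1, 3, -1]             # nearest hit first, miss pushed to the end
assert lo == [1.0, 2.0, MAX_D]
```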
@torch.no_grad()
def ray_sample(intersection_outputs, step_size=0.01, fixed=False):
dists = (
intersection_outputs["max_depth"] -
intersection_outputs["min_depth"]
).masked_fill(intersection_outputs["intersected_voxel_idx"].eq(-1), 0)
intersection_outputs["probs"] = dists / dists.sum(dim=-1, keepdim=True)
intersection_outputs["steps"] = dists.sum(-1) / step_size
# TODO: a serious bug that still needs fixing!
if dists.sum(-1).max() > 10 * MAX_DEPTH:
return
# sample points and use middle point approximation
sampled_idx, sampled_depth, sampled_dists = inverse_cdf_sampling(
intersection_outputs["intersected_voxel_idx"],
intersection_outputs["min_depth"],
intersection_outputs["max_depth"],
intersection_outputs["probs"],
intersection_outputs["steps"], -1, fixed)
sampled_dists = sampled_dists.clamp(min=0.0)
sampled_depth.masked_fill_(sampled_idx.eq(-1), MAX_DEPTH)
sampled_dists.masked_fill_(sampled_idx.eq(-1), 0.0)
samples = {
"sampled_point_depth": sampled_depth,
"sampled_point_distance": sampled_dists,
"sampled_point_voxel_idx": sampled_idx,
}
return samples
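`ray_sample` allocates the sampling budget before calling `inverse_cdf_sampling`: each voxel gets a probability proportional to the length of the ray segment inside it, and the total number of steps is the total in-voxel length divided by the step size. A pure-Python sketch of that bookkeeping (names are illustrative):

```python
def sample_budget(min_depth, max_depth, idx, step_size):
    """Mirror ray_sample's setup: per-voxel probabilities proportional to the
    in-voxel ray segment, and a step budget of total length / step_size."""
    dists = [hi - lo if i != -1 else 0.0
             for lo, hi, i in zip(min_depth, max_depth, idx)]
    total = sum(dists)
    probs = [d / total for d in dists]
    return probs, total / step_size

probs, steps = sample_budget([0.0, 1.0], [1.0, 3.0], [5, 7], step_size=0.5)
assert probs == [1 / 3, 2 / 3]       # longer segments get more samples
assert steps == 6.0                  # 3.0 units of ray / 0.5 step size
```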
================================================
FILE: third_party/marching_cubes/setup.py
================================================
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import glob
_ext_sources = glob.glob("src/*.cpp") + glob.glob("src/*.cu")
setup(
name='marching_cubes',
ext_modules=[
CUDAExtension(
name='marching_cubes',
sources=_ext_sources,
extra_compile_args={
"cxx": ["-O2", "-I./include"],
"nvcc": ["-I./include"]
},
)
],
cmdclass={
'build_ext': BuildExtension
}
)
================================================
FILE: third_party/marching_cubes/src/mc.cpp
================================================
#include <torch/extension.h>
std::vector<torch::Tensor> marching_cubes_sparse(
torch::Tensor indexer, // (nx, ny, nz) -> data_id
torch::Tensor valid_blocks, // (K, )
torch::Tensor vec_batch_mapping, //
torch::Tensor cube_sdf, // (M, rx, ry, rz)
torch::Tensor cube_std, // (M, rx, ry, rz)
const std::vector<int> &n_xyz, // [nx, ny, nz]
float max_std, // Prune all vertices whose std exceeds this threshold
int max_n_triangles // Maximum size of the triangle buffer.
);
std::vector<torch::Tensor> marching_cubes_sparse_colour(
torch::Tensor indexer, // (nx, ny, nz) -> data_id
torch::Tensor valid_blocks, // (K, )
torch::Tensor vec_batch_mapping, //
torch::Tensor cube_sdf, // (M, rx, ry, rz)
torch::Tensor cube_colour, // (M, rx, ry, rz, 4)
const std::vector<int> &n_xyz, // [nx, ny, nz]
float max_std, // Prune all vertices whose std exceeds this threshold
int max_n_triangles // Maximum size of the triangle buffer.
);
std::vector<torch::Tensor> marching_cubes_sparse_interp_cuda(
torch::Tensor indexer, // (nx, ny, nz) -> data_id
torch::Tensor valid_blocks, // (K, )
torch::Tensor vec_batch_mapping, //
torch::Tensor cube_sdf, // (M, rx, ry, rz)
torch::Tensor cube_std, // (M, rx, ry, rz)
const std::vector<int> &n_xyz, // [nx, ny, nz]
float max_std, // Prune all vertices whose std exceeds this threshold
int max_n_triangles // Maximum size of the triangle buffer.
);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("marching_cubes_sparse", &marching_cubes_sparse, "Marching Cubes without Interpolation (CUDA)");
m.def("marching_cubes_sparse_colour", &marching_cubes_sparse_colour, "Marching Cubes with Colour, without Interpolation (CUDA)");
m.def("marching_cubes_sparse_interp", &marching_cubes_sparse_interp_cuda, "Marching Cubes with Interpolation (CUDA)");
}
================================================
FILE: third_party/marching_cubes/src/mc_data.cuh
================================================
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <thrust/device_vector.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
using IndexerAccessor = torch::PackedTensorAccessor32<int64_t, 3, torch::RestrictPtrTraits>;
using ValidBlocksAccessor = torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits>;
using BackwardMappingAccessor = torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits>;
using CubeSDFAccessor = torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits>;
using CubeSDFRGBAccessor = torch::PackedTensorAccessor32<float, 5, torch::RestrictPtrTraits>;
using TrianglesAccessor = torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits>;
using TriangleStdAccessor = torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits>;
using TriangleVecIdAccessor = torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits>;
__inline__ __device__ float4 make_float4(const float3& xyz, float w) {
return make_float4(xyz.x, xyz.y, xyz.z, w);
}
inline __host__ __device__ void operator+=(float2 &a, const float2& b) {
a.x += b.x; a.y += b.y;
}
inline __host__ __device__ float2 operator*(const float2& a, float b) {
return make_float2(a.x * b, a.y * b);
}
inline __host__ __device__ float2 operator/(const float2& a, float b) {
return make_float2(a.x / b, a.y / b);
}
inline __host__ __device__ float2 operator/(const float2& a, const float2& b) {
return make_float2(a.x / b.x, a.y / b.y);
}
__constant__ int edgeTable[256] = { 0x0, 0x109, 0x203, 0x30a, 0x406, 0x50f, 0x605, 0x70c, 0x80c, 0x905, 0xa0f, 0xb06, 0xc0a, 0xd03, 0xe09, 0xf00,
0x190, 0x99, 0x393, 0x29a, 0x596, 0x49f, 0x795, 0x69c, 0x99c, 0x895, 0xb9f, 0xa96, 0xd9a, 0xc93, 0xf99, 0xe90, 0x230, 0x339, 0x33, 0x13a,
0x636, 0x73f, 0x435, 0x53c, 0xa3c, 0xb35, 0x83f, 0x936, 0xe3a, 0xf33, 0xc39, 0xd30, 0x3a0, 0x2a9, 0x1a3, 0xaa, 0x7a6, 0x6af, 0x5a5, 0x4ac,
0xbac, 0xaa5, 0x9af, 0x8a6, 0xfaa, 0xea3, 0xda9, 0xca0, 0x460, 0x569, 0x663, 0x76a, 0x66, 0x16f, 0x265, 0x36c, 0xc6c, 0xd65, 0xe6f, 0xf66,
0x86a, 0x963, 0xa69, 0xb60, 0x5f0, 0x4f9, 0x7f3, 0x6fa, 0x1f6, 0xff, 0x3f5, 0x2fc, 0xdfc, 0xcf5, 0xfff, 0xef6, 0x9fa, 0x8f3, 0xbf9, 0xaf0,
0x650, 0x759, 0x453, 0x55a, 0x256, 0x35f, 0x55, 0x15c, 0xe5c, 0xf55, 0xc5f, 0xd56, 0xa5a, 0xb53, 0x859, 0x950, 0x7c0, 0x6c9, 0x5c3, 0x4ca,
0x3c6, 0x2cf, 0x1c5, 0xcc, 0xfcc, 0xec5, 0xdcf, 0xcc6, 0xbca, 0xac3, 0x9c9, 0x8c0, 0x8c0, 0x9c9, 0xac3, 0xbca, 0xcc6, 0xdcf, 0xec5, 0xfcc,
0xcc, 0x1c5, 0x2cf, 0x3c6, 0x4ca, 0x5c3, 0x6c9, 0x7c0, 0x950, 0x859, 0xb53, 0xa5a, 0xd56, 0xc5f, 0xf55, 0xe5c, 0x15c, 0x55, 0x35f, 0x256,
0x55a, 0x453, 0x759, 0x650, 0xaf0, 0xbf9, 0x8f3, 0x9fa, 0xef6, 0xfff, 0xcf5, 0xdfc, 0x2fc, 0x3f5, 0xff, 0x1f6, 0x6fa, 0x7f3, 0x4f9, 0x5f0,
0xb60, 0xa69, 0x963, 0x86a, 0xf66, 0xe6f, 0xd65, 0xc6c, 0x36c, 0x265, 0x16f, 0x66, 0x76a, 0x663, 0x569, 0x460, 0xca0, 0xda9, 0xea3, 0xfaa,
0x8a6, 0x9af, 0xaa5, 0xbac, 0x4ac, 0x5a5, 0x6af, 0x7a6, 0xaa, 0x1a3, 0x2a9, 0x3a0, 0xd30, 0xc39, 0xf33, 0xe3a, 0x936, 0x83f, 0xb35, 0xa3c,
0x53c, 0x435, 0x73f, 0x636, 0x13a, 0x33, 0x339, 0x230, 0xe90, 0xf99, 0xc93, 0xd9a, 0xa96, 0xb9f, 0x895, 0x99c, 0x69c, 0x795, 0x49f, 0x596,
0x29a, 0x393, 0x99, 0x190, 0xf00, 0xe09, 0xd03, 0xc0a, 0xb06, 0xa0f, 0x905, 0x80c, 0x70c, 0x605, 0x50f, 0x406, 0x30a, 0x203, 0x109, 0x0 };
__constant__ int triangleTable[256][16] = { { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 8, 3, 9, 8, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 2, 10, 0, 2, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 2, 8, 3, 2, 10, 8, 10, 9, 8, -1, -1, -1, -1, -1, -1, -1 }, { 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 11, 2, 8, 11, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 9, 0, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 11, 2, 1, 9, 11, 9, 8, 11, -1, -1, -1, -1, -1, -1, -1 }, { 3, 10, 1, 11, 10, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 10, 1, 0, 8, 10, 8, 11, 10, -1, -1, -1, -1, -1, -1, -1 }, { 3, 9, 0, 3, 11, 9, 11, 10, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 3, 0, 7, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 1, 9, 4, 7, 1, 7, 3, 1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 4, 7, 3, 0, 4, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1 }, { 9, 2, 10, 9, 0, 2, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4, -1, -1, -1, -1 }, { 8, 4, 7, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 4, 7, 11, 2, 4, 2, 0, 4, -1, -1, -1, -1, -1, -1, -1 }, { 9, 0, 1, 8, 4, 7, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1, -1, -1, -1, -1 }, { 3, 10, 1, 3, 11, 10, 7, 8, 4, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4, -1, -1, -1, -1 }, { 4, 7, 8, 9, 0, 11, 9, 11, 10, 11, 0, 3, -1, -1, -1, -1 },
{ 4, 7, 11, 4, 11, 9, 9, 11, 10, -1, -1, -1, -1, -1, -1, -1 }, { 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 5, 4, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 5, 4, 1, 5, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 5, 4, 8, 3, 5, 3, 1, 5, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 8, 1, 2, 10, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1 }, { 5, 2, 10, 5, 4, 2, 4, 0, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8, -1, -1, -1, -1 }, { 9, 5, 4, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 11, 2, 0, 8, 11, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1 }, { 0, 5, 4, 0, 1, 5, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5, -1, -1, -1, -1 }, { 10, 3, 11, 10, 1, 3, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10, -1, -1, -1, -1 }, { 5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3, -1, -1, -1, -1 },
{ 5, 4, 8, 5, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1 }, { 9, 7, 8, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 3, 0, 9, 5, 3, 5, 7, 3, -1, -1, -1, -1, -1, -1, -1 }, { 0, 7, 8, 0, 1, 7, 1, 5, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 7, 8, 9, 5, 7, 10, 1, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3, -1, -1, -1, -1 }, { 8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2, -1, -1, -1, -1 },
{ 2, 10, 5, 2, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1 }, { 7, 9, 5, 7, 8, 9, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11, -1, -1, -1, -1 }, { 2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7, -1, -1, -1, -1 },
{ 11, 2, 1, 11, 1, 7, 7, 1, 5, -1, -1, -1, -1, -1, -1, -1 }, { 9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11, -1, -1, -1, -1 },
{ 5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0, -1 }, { 11, 10, 0, 11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0, -1 },
{ 11, 10, 5, 7, 11, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 0, 1, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 8, 3, 1, 9, 8, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1 }, { 1, 6, 5, 2, 6, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 6, 5, 1, 2, 6, 3, 0, 8, -1, -1, -1, -1, -1, -1, -1 }, { 9, 6, 5, 9, 0, 6, 0, 2, 6, -1, -1, -1, -1, -1, -1, -1 },
{ 5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8, -1, -1, -1, -1 }, { 2, 3, 11, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 0, 8, 11, 2, 0, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1 },
{ 5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11, -1, -1, -1, -1 }, { 6, 3, 11, 6, 5, 3, 5, 1, 3, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6, -1, -1, -1, -1 }, { 3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9, -1, -1, -1, -1 },
{ 6, 5, 9, 6, 9, 11, 11, 9, 8, -1, -1, -1, -1, -1, -1, -1 }, { 5, 10, 6, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 3, 0, 4, 7, 3, 6, 5, 10, -1, -1, -1, -1, -1, -1, -1 }, { 1, 9, 0, 5, 10, 6, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4, -1, -1, -1, -1 }, { 6, 1, 2, 6, 5, 1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7, -1, -1, -1, -1 }, { 8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6, -1, -1, -1, -1 },
{ 7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9, -1 }, { 3, 11, 2, 7, 8, 4, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1 },
{ 5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11, -1, -1, -1, -1 }, { 0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1 },
{ 9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6, -1 }, { 8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6, -1, -1, -1, -1 },
{ 5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11, -1 }, { 0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7, -1 },
{ 6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9, -1, -1, -1, -1 }, { 10, 4, 9, 6, 4, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 10, 6, 4, 9, 10, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1 }, { 10, 0, 1, 10, 6, 0, 6, 4, 0, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10, -1, -1, -1, -1 }, { 1, 4, 9, 1, 2, 4, 2, 6, 4, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4, -1, -1, -1, -1 }, { 0, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 3, 2, 8, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1 }, { 10, 4, 9, 10, 6, 4, 11, 2, 3, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6, -1, -1, -1, -1 }, { 3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10, -1, -1, -1, -1 },
{ 6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1, -1 }, { 9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3, -1, -1, -1, -1 },
{ 8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1, -1 }, { 3, 11, 6, 3, 6, 0, 0, 6, 4, -1, -1, -1, -1, -1, -1, -1 },
{ 6, 4, 8, 11, 6, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 7, 10, 6, 7, 8, 10, 8, 9, 10, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10, -1, -1, -1, -1 }, { 10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0, -1, -1, -1, -1 },
{ 10, 6, 7, 10, 7, 1, 1, 7, 3, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7, -1, -1, -1, -1 },
{ 2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9, -1 }, { 7, 8, 0, 7, 0, 6, 6, 0, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 7, 3, 2, 6, 7, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7, -1, -1, -1, -1 },
{ 2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7, -1 }, { 1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11, -1 },
{ 11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1, -1, -1, -1, -1 }, { 8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6, -1 },
{ 0, 9, 1, 11, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0, -1, -1, -1, -1 },
{ 7, 11, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 8, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 1, 9, 8, 3, 1, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1 }, { 10, 1, 2, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 2, 10, 3, 0, 8, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1 }, { 2, 9, 0, 2, 10, 9, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8, -1, -1, -1, -1 }, { 7, 2, 3, 6, 2, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 7, 0, 8, 7, 6, 0, 6, 2, 0, -1, -1, -1, -1, -1, -1, -1 }, { 2, 7, 6, 2, 3, 7, 0, 1, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6, -1, -1, -1, -1 }, { 10, 7, 6, 10, 1, 7, 1, 3, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8, -1, -1, -1, -1 }, { 0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7, -1, -1, -1, -1 },
{ 7, 6, 10, 7, 10, 8, 8, 10, 9, -1, -1, -1, -1, -1, -1, -1 }, { 6, 8, 4, 11, 8, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 6, 11, 3, 0, 6, 0, 4, 6, -1, -1, -1, -1, -1, -1, -1 }, { 8, 6, 11, 8, 4, 6, 9, 0, 1, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6, -1, -1, -1, -1 }, { 6, 8, 4, 6, 11, 8, 2, 10, 1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6, -1, -1, -1, -1 }, { 4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9, -1, -1, -1, -1 },
{ 10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3, -1 }, { 8, 2, 3, 8, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8, -1, -1, -1, -1 },
{ 1, 9, 4, 1, 4, 2, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1 }, { 8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1, -1, -1, -1, -1 },
{ 10, 1, 0, 10, 0, 6, 6, 0, 4, -1, -1, -1, -1, -1, -1, -1 }, { 4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3, -1 },
{ 10, 9, 4, 6, 10, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 9, 5, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 4, 9, 5, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1 }, { 5, 0, 1, 5, 4, 0, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5, -1, -1, -1, -1 }, { 9, 5, 4, 10, 1, 2, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5, -1, -1, -1, -1 }, { 7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2, -1, -1, -1, -1 },
{ 3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6, -1 }, { 7, 2, 3, 7, 6, 2, 5, 4, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7, -1, -1, -1, -1 }, { 3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0, -1, -1, -1, -1 },
{ 6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8, -1 }, { 9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7, -1, -1, -1, -1 },
{ 1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4, -1 }, { 4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10, -1 },
{ 7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10, -1, -1, -1, -1 }, { 6, 9, 5, 6, 11, 9, 11, 8, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5, -1, -1, -1, -1 }, { 0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11, -1, -1, -1, -1 },
{ 6, 11, 3, 6, 3, 5, 5, 3, 1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6, -1, -1, -1, -1 },
{ 0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10, -1 }, { 11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5, -1 },
{ 6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3, -1, -1, -1, -1 }, { 5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2, -1, -1, -1, -1 },
{ 9, 5, 6, 9, 6, 0, 0, 6, 2, -1, -1, -1, -1, -1, -1, -1 }, { 1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8, -1 },
{ 1, 5, 6, 2, 1, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6, -1 },
{ 10, 1, 0, 10, 0, 6, 9, 5, 0, 5, 6, 0, -1, -1, -1, -1 }, { 0, 3, 8, 5, 6, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 11, 5, 10, 7, 5, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 5, 10, 11, 7, 5, 8, 3, 0, -1, -1, -1, -1, -1, -1, -1 }, { 5, 11, 7, 5, 10, 11, 1, 9, 0, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1, -1, -1, -1, -1 }, { 11, 1, 2, 11, 7, 1, 7, 5, 1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11, -1, -1, -1, -1 }, { 9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7, -1, -1, -1, -1 },
{ 7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2, -1 }, { 2, 5, 10, 2, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5, -1, -1, -1, -1 }, { 9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2, -1, -1, -1, -1 },
{ 9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2, -1 }, { 1, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 7, 0, 7, 1, 1, 7, 5, -1, -1, -1, -1, -1, -1, -1 }, { 9, 0, 3, 9, 3, 5, 5, 3, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 8, 7, 5, 9, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 5, 8, 4, 5, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1 },
{ 5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0, -1, -1, -1, -1 }, { 0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5, -1, -1, -1, -1 },
{ 10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4, -1 }, { 2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8, -1, -1, -1, -1 },
{ 0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11, -1 }, { 0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5, -1 },
{ 9, 4, 5, 2, 11, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4, -1, -1, -1, -1 },
{ 5, 10, 2, 5, 2, 4, 4, 2, 0, -1, -1, -1, -1, -1, -1, -1 }, { 3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9, -1 },
{ 5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2, -1, -1, -1, -1 }, { 8, 4, 5, 8, 5, 3, 3, 5, 1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 4, 5, 1, 0, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5, -1, -1, -1, -1 },
{ 9, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 11, 7, 4, 9, 11, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11, -1, -1, -1, -1 }, { 1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11, -1, -1, -1, -1 },
{ 3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4, -1 }, { 4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2, -1, -1, -1, -1 },
{ 9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3, -1 }, { 11, 7, 4, 11, 4, 2, 2, 4, 0, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4, -1, -1, -1, -1 }, { 2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9, -1, -1, -1, -1 },
{ 9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7, -1 }, { 3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10, -1 },
{ 1, 10, 2, 8, 7, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 9, 1, 4, 1, 7, 7, 1, 3, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1, -1, -1, -1, -1 }, { 4, 0, 3, 7, 4, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 8, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 9, 3, 9, 11, 11, 9, 10, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 10, 0, 10, 8, 8, 10, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 1, 10, 11, 3, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 11, 1, 11, 9, 9, 11, 8, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9, -1, -1, -1, -1 }, { 0, 2, 11, 8, 0, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 3, 8, 2, 8, 10, 10, 8, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 10, 2, 0, 9, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8, -1, -1, -1, -1 },
{ 1, 10, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 3, 8, 9, 1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 9, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 3, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 } };
================================================
FILE: third_party/marching_cubes/src/mc_interp_kernel.cu
================================================
#include "mc_data.cuh"
#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
__device__ static inline float2 query_sdf_raw(uint bx, uint by, uint bz, uint arx, uint ary, uint arz,
const uint max_vec_num,
const IndexerAccessor indexer,
const CubeSDFAccessor cube_sdf,
const CubeSDFAccessor cube_std,
const BackwardMappingAccessor vec_batch_mapping)
{
if (bx >= indexer.size(0) || by >= indexer.size(1) || bz >= indexer.size(2))
{
return make_float2(NAN, NAN);
}
// printf("B-Getting: %d %d %d --> %d, %d, %d\n", bx, by, bz, indexer.size(0), indexer.size(1), indexer.size(2));
long long vec_ind = indexer[bx][by][bz];
if (vec_ind == -1 || vec_ind >= max_vec_num)
{
return make_float2(NAN, NAN);
}
int batch_ind = vec_batch_mapping[vec_ind];
if (batch_ind == -1)
{
return make_float2(NAN, NAN);
}
// printf("Getting: %d %d %d %d --> %d %d\n", batch_ind, arx, ary, arz, cube_sdf.size(0), cube_sdf.size(1));
float sdf = cube_sdf[batch_ind][arx][ary][arz];
float std = cube_std[batch_ind][arx][ary][arz];
return make_float2(sdf, std);
}
// Use stddev to weight sdf value.
// #define STD_W_SDF
__device__ static inline float2 get_sdf(const uint3 bsize, const uint r, uint3 bpos, uint3 rpos, const uint max_vec_num,
const IndexerAccessor indexer,
const CubeSDFAccessor cube_sdf,
const CubeSDFAccessor cube_std,
const BackwardMappingAccessor vec_batch_mapping)
{
if (bpos.x >= bsize.x)
{
bpos.x = bsize.x - 1;
rpos.x = r - 1;
}
if (bpos.y >= bsize.y)
{
bpos.y = bsize.y - 1;
rpos.y = r - 1;
}
if (bpos.z >= bsize.z)
{
bpos.z = bsize.z - 1;
rpos.z = r - 1;
}
uint rbound = (r - 1) / 2;
uint rstart = r / 2;
float rmid = r / 2.0f;
float w_xm, w_xp;
int bxm, rxm, bxp, rxp;
int zero_x;
if (rpos.x <= rbound)
{
bxm = -1;
rxm = r;
bxp = 0;
rxp = 0;
w_xp = (float)rpos.x + rmid;
w_xm = rmid - (float)rpos.x;
zero_x = 1;
}
else
{
bxm = 0;
rxm = 0;
bxp = 1;
rxp = -r;
w_xp = (float)rpos.x - rmid;
w_xm = rmid + r - (float)rpos.x;
zero_x = 0;
}
w_xm /= r;
w_xp /= r;
float w_ym, w_yp;
int bym, rym, byp, ryp;
int zero_y;
if (rpos.y <= rbound)
{
bym = -1;
rym = r;
byp = 0;
ryp = 0;
w_yp = (float)rpos.y + rmid;
w_ym = rmid - (float)rpos.y;
zero_y = 1;
}
else
{
bym = 0;
rym = 0;
byp = 1;
ryp = -r;
w_yp = (float)rpos.y - rmid;
w_ym = rmid + r - (float)rpos.y;
zero_y = 0;
}
w_ym /= r;
w_yp /= r;
float w_zm, w_zp;
int bzm, rzm, bzp, rzp;
int zero_z;
if (rpos.z <= rbound)
{
bzm = -1;
rzm = r;
bzp = 0;
rzp = 0;
w_zp = (float)rpos.z + rmid;
w_zm = rmid - (float)rpos.z;
zero_z = 1;
}
else
{
bzm = 0;
rzm = 0;
bzp = 1;
rzp = -r;
w_zp = (float)rpos.z - rmid;
w_zm = rmid + r - (float)rpos.z;
zero_z = 0;
}
w_zm /= r;
w_zp /= r;
rpos.x += rstart;
rpos.y += rstart;
rpos.z += rstart;
// printf("%u %u %u %d %d %d %d %d %d\n", rpos.x, rpos.y, rpos.z, rxm, rxp, rym, ryp, rzm, rzp);
// Tri-linear interpolation of SDF values.
#ifndef STD_W_SDF
float total_weight = 0.0;
#else
float2 total_weight{0.0, 0.0};
#endif
float2 total_sdf{0.0, 0.0};
int zero_det = zero_x * 4 + zero_y * 2 + zero_z;
float2 sdfmmm = query_sdf_raw(bpos.x + bxm, bpos.y + bym, bpos.z + bzm, rpos.x + rxm, rpos.y + rym, rpos.z + rzm,
max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
float wmmm = w_xm * w_ym * w_zm;
#ifndef STD_W_SDF
if (!isnan(sdfmmm.x))
{
total_sdf += sdfmmm * wmmm;
total_weight += wmmm;
}
#else
if (!isnan(sdfmmm.x))
{
total_sdf.x += sdfmmm.x * wmmm * sdfmmm.y;
total_weight.x += wmmm * sdfmmm.y;
total_sdf.y += wmmm * sdfmmm.y;
total_weight.y += wmmm;
}
#endif
else if (zero_det == 0)
{
return make_float2(NAN, NAN);
}
float2 sdfmmp = query_sdf_raw(bpos.x + bxm, bpos.y + bym, bpos.z + bzp, rpos.x + rxm, rpos.y + rym, rpos.z + rzp,
max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
float wmmp = w_xm * w_ym * w_zp;
#ifndef STD_W_SDF
if (!isnan(sdfmmp.x))
{
total_sdf += sdfmmp * wmmp;
total_weight += wmmp;
}
#else
if (!isnan(sdfmmp.x))
{
total_sdf.x += sdfmmp.x * wmmp * sdfmmp.y;
total_weight.x += wmmp * sdfmmp.y;
total_sdf.y += wmmp * sdfmmp.y;
total_weight.y += wmmp;
}
#endif
else if (zero_det == 1)
{
return make_float2(NAN, NAN);
}
float2 sdfmpm = query_sdf_raw(bpos.x + bxm, bpos.y + byp, bpos.z + bzm, rpos.x + rxm, rpos.y + ryp, rpos.z + rzm,
max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
float wmpm = w_xm * w_yp * w_zm;
#ifndef STD_W_SDF
if (!isnan(sdfmpm.x))
{
total_sdf += sdfmpm * wmpm;
total_weight += wmpm;
}
#else
if (!isnan(sdfmpm.x))
{
total_sdf.x += sdfmpm.x * wmpm * sdfmpm.y;
total_weight.x += wmpm * sdfmpm.y;
total_sdf.y += wmpm * sdfmpm.y;
total_weight.y += wmpm;
}
#endif
else if (zero_det == 2)
{
return make_float2(NAN, NAN);
}
float2 sdfmpp = query_sdf_raw(bpos.x + bxm, bpos.y + byp, bpos.z + bzp, rpos.x + rxm, rpos.y + ryp, rpos.z + rzp,
max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
float wmpp = w_xm * w_yp * w_zp;
#ifndef STD_W_SDF
if (!isnan(sdfmpp.x))
{
total_sdf += sdfmpp * wmpp;
total_weight += wmpp;
}
#else
if (!isnan(sdfmpp.x))
{
total_sdf.x += sdfmpp.x * wmpp * sdfmpp.y;
total_weight.x += wmpp * sdfmpp.y;
total_sdf.y += wmpp * sdfmpp.y;
total_weight.y += wmpp;
}
#endif
else if (zero_det == 3)
{
return make_float2(NAN, NAN);
}
float2 sdfpmm = query_sdf_raw(bpos.x + bxp, bpos.y + bym, bpos.z + bzm, rpos.x + rxp, rpos.y + rym, rpos.z + rzm,
max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
float wpmm = w_xp * w_ym * w_zm;
#ifndef STD_W_SDF
if (!isnan(sdfpmm.x))
{
total_sdf += sdfpmm * wpmm;
total_weight += wpmm;
}
#else
if (!isnan(sdfpmm.x))
{
total_sdf.x += sdfpmm.x * wpmm * sdfpmm.y;
total_weight.x += wpmm * sdfpmm.y;
total_sdf.y += wpmm * sdfpmm.y;
total_weight.y += wpmm;
}
#endif
else if (zero_det == 4)
{
return make_float2(NAN, NAN);
}
float2 sdfpmp = query_sdf_raw(bpos.x + bxp, bpos.y + bym, bpos.z + bzp, rpos.x + rxp, rpos.y + rym, rpos.z + rzp,
max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
float wpmp = w_xp * w_ym * w_zp;
#ifndef STD_W_SDF
if (!isnan(sdfpmp.x))
{
total_sdf += sdfpmp * wpmp;
total_weight += wpmp;
}
#else
if (!isnan(sdfpmp.x))
{
total_sdf.x += sdfpmp.x * wpmp * sdfpmp.y;
total_weight.x += wpmp * sdfpmp.y;
total_sdf.y += wpmp * sdfpmp.y;
total_weight.y += wpmp;
}
#endif
else if (zero_det == 5)
{
return make_float2(NAN, NAN);
}
float2 sdfppm = query_sdf_raw(bpos.x + bxp, bpos.y + byp, bpos.z + bzm, rpos.x + rxp, rpos.y + ryp, rpos.z + rzm,
max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
float wppm = w_xp * w_yp * w_zm;
#ifndef STD_W_SDF
if (!isnan(sdfppm.x))
{
total_sdf += sdfppm * wppm;
total_weight += wppm;
}
#else
if (!isnan(sdfppm.x))
{
total_sdf.x += sdfppm.x * wppm * sdfppm.y;
total_weight.x += wppm * sdfppm.y;
total_sdf.y += wppm * sdfppm.y;
total_weight.y += wppm;
}
#endif
else if (zero_det == 6)
{
return make_float2(NAN, NAN);
}
float2 sdfppp = query_sdf_raw(bpos.x + bxp, bpos.y + byp, bpos.z + bzp, rpos.x + rxp, rpos.y + ryp, rpos.z + rzp,
max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
float wppp = w_xp * w_yp * w_zp;
#ifndef STD_W_SDF
if (!isnan(sdfppp.x))
{
total_sdf += sdfppp * wppp;
total_weight += wppp;
}
#else
if (!isnan(sdfppp.x))
{
total_sdf.x += sdfppp.x * wppp * sdfppp.y;
total_weight.x += wppp * sdfppp.y;
total_sdf.y += wppp * sdfppp.y;
total_weight.y += wppp;
}
#endif
else if (zero_det == 7)
{
return make_float2(NAN, NAN);
}
// A NaN result (e.g. zero total_weight) propagates to the caller, which treats it as invalid.
return total_sdf / total_weight;
}
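The per-axis weighting in `get_sdf` above always forms a convex combination: with `rmid = r / 2`, the "minus" and "plus" weights on each axis sum to exactly one before the eight corner samples are blended. A minimal pure-Python sketch (`axis_weights` is a hypothetical helper mirroring the branch structure of `get_sdf`, not part of the repo):

```python
# Hypothetical mirror of the per-axis weight computation in get_sdf:
# a sample at sub-voxel index rpos interpolates between the "minus" and
# "plus" neighbouring samples; rmid = r / 2 is the block-centre offset.
def axis_weights(rpos, r):
    rbound = (r - 1) // 2
    rmid = r / 2.0
    if rpos <= rbound:            # pair with the previous block's sample
        w_p = (rpos + rmid) / r
        w_m = (rmid - rpos) / r
    else:                         # pair with the next block's sample
        w_p = (rpos - rmid) / r
        w_m = (rmid + r - rpos) / r
    return w_m, w_p
```

In both branches the weights sum to `r / r = 1`, so the tri-linear combination of the eight corner samples is a convex average.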
__device__ static inline float4 sdf_interp(const float3 p1, const float3 p2, const float stdp1, const float stdp2,
float valp1, float valp2)
{
if (fabs(0.0f - valp1) < 1.0e-5f)
return make_float4(p1, stdp1);
if (fabs(0.0f - valp2) < 1.0e-5f)
return make_float4(p2, stdp2);
if (fabs(valp1 - valp2) < 1.0e-5f)
return make_float4(p1, stdp1);
float w2 = (0.0f - valp1) / (valp2 - valp1);
float w1 = 1 - w2;
return make_float4(p1.x * w1 + p2.x * w2,
p1.y * w1 + p2.y * w2,
p1.z * w1 + p2.z * w2,
stdp1 * w1 + stdp2 * w2);
}
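The zero-crossing logic of `sdf_interp` above is plain linear interpolation: the vertex is placed where the SDF changes sign along the cube edge. A pure-Python sketch of the same position logic (hypothetical helper, omitting the stddev channel carried by the CUDA version):

```python
# Hypothetical pure-Python mirror of the CUDA sdf_interp helper:
# linearly locate the zero crossing of the SDF between two cube corners.
EPS = 1.0e-5

def sdf_interp(p1, p2, valp1, valp2):
    """Return the point on segment p1-p2 where the SDF crosses zero."""
    if abs(valp1) < EPS:          # p1 already lies on the surface
        return p1
    if abs(valp2) < EPS:          # p2 already lies on the surface
        return p2
    if abs(valp1 - valp2) < EPS:  # degenerate edge: no defined crossing
        return p1
    w2 = (0.0 - valp1) / (valp2 - valp1)
    w1 = 1.0 - w2
    return tuple(a * w1 + b * w2 for a, b in zip(p1, p2))
```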
__global__ static void meshing_cube(const IndexerAccessor indexer,
const ValidBlocksAccessor valid_blocks,
const BackwardMappingAccessor vec_batch_mapping,
const CubeSDFAccessor cube_sdf,
const CubeSDFAccessor cube_std,
TrianglesAccessor triangles,
TriangleStdAccessor triangle_std,
TriangleVecIdAccessor triangle_flatten_id,
int *__restrict__ triangles_count,
int max_triangles_count,
const uint max_vec_num,
int nx, int ny, int nz,
float max_std)
{
const uint r = cube_sdf.size(1) / 2;
const uint r3 = r * r * r;
const uint num_lif = valid_blocks.size(0);
const float sbs = 1.0f / r; // sub-block-size
const uint lif_id = blockIdx.x * blockDim.x + threadIdx.x;
const uint sub_id = blockIdx.y * blockDim.y + threadIdx.y;
if (lif_id >= num_lif || sub_id >= r3)
{
return;
}
const uint3 bpos = make_uint3(
(valid_blocks[lif_id] / (ny * nz)) % nx,
(valid_blocks[lif_id] / nz) % ny,
valid_blocks[lif_id] % nz);
const uint3 bsize = make_uint3(indexer.size(0), indexer.size(1), indexer.size(2));
const uint rx = sub_id / (r * r);
const uint ry = (sub_id / r) % r;
const uint rz = sub_id % r;
// Sample the SDF at the 8 corners of this sub-cube
float3 points[8];
float2 sdf_vals[8];
sdf_vals[0] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[0].x))
return;
points[0] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);
sdf_vals[1] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[1].x))
return;
points[1] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);
sdf_vals[2] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[2].x))
return;
points[2] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);
sdf_vals[3] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[3].x))
return;
points[3] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);
sdf_vals[4] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[4].x))
return;
points[4] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);
sdf_vals[5] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[5].x))
return;
points[5] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);
sdf_vals[6] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[6].x))
return;
points[6] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);
sdf_vals[7] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[7].x))
return;
points[7] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);
// Find triangle config.
int cube_type = 0;
if (sdf_vals[0].x < 0)
cube_type |= 1;
if (sdf_vals[1].x < 0)
cube_type |= 2;
if (sdf_vals[2].x < 0)
cube_type |= 4;
if (sdf_vals[3].x < 0)
cube_type |= 8;
if (sdf_vals[4].x < 0)
cube_type |= 16;
if (sdf_vals[5].x < 0)
cube_type |= 32;
if (sdf_vals[6].x < 0)
cube_type |= 64;
if (sdf_vals[7].x < 0)
cube_type |= 128;
// Find vertex position on each edge (weighted by sdf value)
int edge_config = edgeTable[cube_type];
float4 vert_list[12];
if (edge_config == 0)
return;
if (edge_config & 1)
vert_list[0] = sdf_interp(points[0], points[1], sdf_vals[0].y, sdf_vals[1].y, sdf_vals[0].x, sdf_vals[1].x);
if (edge_config & 2)
vert_list[1] = sdf_interp(points[1], points[2], sdf_vals[1].y, sdf_vals[2].y, sdf_vals[1].x, sdf_vals[2].x);
if (edge_config & 4)
vert_list[2] = sdf_interp(points[2], points[3], sdf_vals[2].y, sdf_vals[3].y, sdf_vals[2].x, sdf_vals[3].x);
if (edge_config & 8)
vert_list[3] = sdf_interp(points[3], points[0], sdf_vals[3].y, sdf_vals[0].y, sdf_vals[3].x, sdf_vals[0].x);
if (edge_config & 16)
vert_list[4] = sdf_interp(points[4], points[5], sdf_vals[4].y, sdf_vals[5].y, sdf_vals[4].x, sdf_vals[5].x);
if (edge_config & 32)
vert_list[5] = sdf_interp(points[5], points[6], sdf_vals[5].y, sdf_vals[6].y, sdf_vals[5].x, sdf_vals[6].x);
if (edge_config & 64)
vert_list[6] = sdf_interp(points[6], points[7], sdf_vals[6].y, sdf_vals[7].y, sdf_vals[6].x, sdf_vals[7].x);
if (edge_config & 128)
vert_list[7] = sdf_interp(points[7], points[4], sdf_vals[7].y, sdf_vals[4].y, sdf_vals[7].x, sdf_vals[4].x);
if (edge_config & 256)
vert_list[8] = sdf_interp(points[0], points[4], sdf_vals[0].y, sdf_vals[4].y, sdf_vals[0].x, sdf_vals[4].x);
if (edge_config & 512)
vert_list[9] = sdf_interp(points[1], points[5], sdf_vals[1].y, sdf_vals[5].y, sdf_vals[1].x, sdf_vals[5].x);
if (edge_config & 1024)
vert_list[10] = sdf_interp(points[2], points[6], sdf_vals[2].y, sdf_vals[6].y, sdf_vals[2].x, sdf_vals[6].x);
if (edge_config & 2048)
vert_list[11] = sdf_interp(points[3], points[7], sdf_vals[3].y, sdf_vals[7].y, sdf_vals[3].x, sdf_vals[7].x);
// Write triangles to array.
float4 vp[3];
for (int i = 0; triangleTable[cube_type][i] != -1; i += 3)
{
#pragma unroll
for (int vi = 0; vi < 3; ++vi)
{
vp[vi] = vert_list[triangleTable[cube_type][i + vi]];
}
if (vp[0].w > max_std || vp[1].w > max_std || vp[2].w > max_std)
{
continue;
}
int triangle_id = atomicAdd(triangles_count, 1);
if (triangle_id < max_triangles_count)
{
#pragma unroll
for (int vi = 0; vi < 3; ++vi)
{
triangles[triangle_id][vi][0] = vp[vi].x;
triangles[triangle_id][vi][1] = vp[vi].y;
triangles[triangle_id][vi][2] = vp[vi].z;
triangle_std[triangle_id][vi] = vp[vi].w;
}
triangle_flatten_id[triangle_id] = valid_blocks[lif_id];
}
}
}
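The corner-sign bitmask accumulated inside `meshing_cube` (the `cube_type |= ...` chain) encodes which of the 8 corners are inside the surface, giving one of the 256 marching-cubes cases used to index `edgeTable` and `triangleTable`. A plain-Python sketch of that accumulation (`cube_type` below is a hypothetical mirror):

```python
# Hypothetical mirror of the cube_type computation: each of the eight
# corner SDF signs contributes one bit of the marching-cubes case index.
def cube_type(corner_sdf):
    t = 0
    for i, s in enumerate(corner_sdf):
        if s < 0:        # corner i is inside the surface
            t |= 1 << i
    return t
```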
std::vector<torch::Tensor> marching_cubes_sparse_interp_cuda(
torch::Tensor indexer, // (nx, ny, nz) -> data_id
torch::Tensor valid_blocks, // (K, )
torch::Tensor vec_batch_mapping, // (V, ) vec_id -> batch_id
torch::Tensor cube_sdf, // (M, rx, ry, rz)
torch::Tensor cube_std, // (M, rx, ry, rz)
const std::vector<int> &n_xyz, // [nx, ny, nz]
float max_std, // Prune triangles whose vertex stddev exceeds this
int max_n_triangles // Capacity of the triangle buffer
)
{
CHECK_INPUT(indexer);
CHECK_INPUT(valid_blocks);
CHECK_INPUT(cube_sdf);
CHECK_INPUT(cube_std);
CHECK_INPUT(vec_batch_mapping);
assert(max_n_triangles > 0);
const int r = cube_sdf.size(1) / 2;
const int r3 = r * r * r;
const int num_lif = valid_blocks.size(0);
const uint max_vec_num = vec_batch_mapping.size(0);
torch::Tensor triangles = torch::empty({max_n_triangles, 3, 3},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
torch::Tensor triangle_flatten_id = torch::empty({max_n_triangles}, torch::dtype(torch::kLong).device(torch::kCUDA));
torch::Tensor triangle_std = torch::empty({max_n_triangles, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
dim3 dimBlock = dim3(16, 16);
uint xBlocks = (num_lif + dimBlock.x - 1) / dimBlock.x;
uint yBlocks = (r3 + dimBlock.y - 1) / dimBlock.y;
dim3 dimGrid = dim3(xBlocks, yBlocks);
thrust::device_vector<int> n_output(1, 0);
meshing_cube<<<dimGrid, dimBlock, 0, at::cuda::getCurrentCUDAStream()>>>(
indexer.packed_accessor32<int64_t, 3, torch::RestrictPtrTraits>(),
valid_blocks.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
vec_batch_mapping.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
cube_sdf.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
cube_std.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
triangles.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
triangle_std.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
triangle_flatten_id.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
n_output.data().get(), max_n_triangles, max_vec_num,
n_xyz[0], n_xyz[1], n_xyz[2], max_std);
cudaStreamSynchronize(at::cuda::getCurrentCUDAStream());
int output_n_triangles = n_output[0];
if (output_n_triangles < max_n_triangles)
{
// Trim output tensor if it is not full.
triangles = triangles.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
triangle_flatten_id = triangle_flatten_id.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
triangle_std = triangle_std.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
}
else
{
// Otherwise emit a warning: the buffer overflowed and triangles were dropped.
std::cerr << "Warning from marching cubes: max triangle count is too small: " << output_n_triangles << " vs " << max_n_triangles << std::endl;
}
return {triangles, triangle_flatten_id, triangle_std};
}
================================================
FILE: third_party/marching_cubes/src/mc_kernel.cu
================================================
#include "mc_data.cuh"
#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
__device__ static inline float2 query_sdf_raw(uint bx, uint by, uint bz, uint arx, uint ary, uint arz,
const uint max_vec_num,
const IndexerAccessor indexer,
const CubeSDFAccessor cube_sdf,
const CubeSDFAccessor cube_std,
const BackwardMappingAccessor vec_batch_mapping)
{
if (bx >= indexer.size(0) || by >= indexer.size(1) || bz >= indexer.size(2))
{
return make_float2(NAN, NAN);
}
// printf("B-Getting: %d %d %d --> %d, %d, %d\n", bx, by, bz, indexer.size(0), indexer.size(1), indexer.size(2));
long long vec_ind = indexer[bx][by][bz];
if (vec_ind == -1 || vec_ind >= max_vec_num)
{
return make_float2(NAN, NAN);
}
int batch_ind = vec_batch_mapping[vec_ind];
if (batch_ind == -1)
{
return make_float2(NAN, NAN);
}
// printf("Getting: %d %d %d %d --> %d %d\n", batch_ind, arx, ary, arz, cube_sdf.size(0), cube_sdf.size(1));
float sdf = cube_sdf[batch_ind][arx][ary][arz];
float std = cube_std[batch_ind][arx][ary][arz];
return make_float2(sdf, std);
}
// Use stddev to weight sdf value.
// #define STD_W_SDF
__device__ static inline float2 get_sdf(const uint3 bsize, const uint r, uint3 bpos, uint3 rpos, const uint max_vec_num,
const IndexerAccessor indexer,
const CubeSDFAccessor cube_sdf,
const CubeSDFAccessor cube_std,
const BackwardMappingAccessor vec_batch_mapping)
{
if (bpos.x >= bsize.x)
{
bpos.x = bsize.x - 1;
rpos.x = r - 1;
}
if (bpos.y >= bsize.y)
{
bpos.y = bsize.y - 1;
rpos.y = r - 1;
}
if (bpos.z >= bsize.z)
{
bpos.z = bsize.z - 1;
rpos.z = r - 1;
}
if (rpos.x == r)
{
bpos.x += 1;
rpos.x = 0;
}
if (rpos.y == r)
{
bpos.y += 1;
rpos.y = 0;
}
if (rpos.z == r)
{
bpos.z += 1;
rpos.z = 0;
}
float2 total_sdf = query_sdf_raw(bpos.x, bpos.y, bpos.z, rpos.x, rpos.y, rpos.z, max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
// A NaN result propagates to the caller, which treats it as invalid.
return total_sdf;
}
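Unlike the interpolated variant, this `get_sdf` does a nearest-sample lookup: out-of-range block indices are clamped to the last valid sample, and a local index equal to `r` "carries" into the neighbouring block. A per-axis sketch of that normalisation (hypothetical helper `normalize_axis`, mirroring the two `if` chains above):

```python
# Hypothetical mirror of the per-axis index normalisation in get_sdf:
# first clamp to the boundary, then carry a full local index into the
# next block (an out-of-range result is caught later by query_sdf_raw).
def normalize_axis(b, rp, bsize, r):
    if b >= bsize:         # clamp to the last valid block/sample
        b, rp = bsize - 1, r - 1
    if rp == r:            # carry into the neighbouring block
        b, rp = b + 1, 0
    return b, rp
```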
__device__ static inline float4 sdf_interp(const float3 p1, const float3 p2, const float stdp1, const float stdp2,
float valp1, float valp2)
{
if (fabs(0.0f - valp1) < 1.0e-5f)
return make_float4(p1, stdp1);
if (fabs(0.0f - valp2) < 1.0e-5f)
return make_float4(p2, stdp2);
if (fabs(valp1 - valp2) < 1.0e-5f)
return make_float4(p1, stdp1);
float w2 = (0.0f - valp1) / (valp2 - valp1);
float w1 = 1 - w2;
return make_float4(p1.x * w1 + p2.x * w2,
p1.y * w1 + p2.y * w2,
p1.z * w1 + p2.z * w2,
stdp1 * w1 + stdp2 * w2);
}
__global__ static void meshing_cube(const IndexerAccessor indexer,
const ValidBlocksAccessor valid_blocks,
const BackwardMappingAccessor vec_batch_mapping,
const CubeSDFAccessor cube_sdf,
const CubeSDFAccessor cube_std,
TrianglesAccessor triangles,
TriangleStdAccessor triangle_std,
TriangleVecIdAccessor triangle_flatten_id,
int *__restrict__ triangles_count,
int max_triangles_count,
const uint max_vec_num,
int nx, int ny, int nz,
float max_std)
{
const uint r = cube_sdf.size(1);
const uint r3 = r * r * r;
const uint num_lif = valid_blocks.size(0);
const float sbs = 1.0f / r; // sub-block-size
const uint lif_id = blockIdx.x * blockDim.x + threadIdx.x;
const uint sub_id = blockIdx.y * blockDim.y + threadIdx.y;
if (lif_id >= num_lif || sub_id >= r3)
{
return;
}
const uint3 bpos = make_uint3(
(valid_blocks[lif_id] / (ny * nz)) % nx,
(valid_blocks[lif_id] / nz) % ny,
valid_blocks[lif_id] % nz);
const uint3 bsize = make_uint3(indexer.size(0), indexer.size(1), indexer.size(2));
const uint rx = sub_id / (r * r);
const uint ry = (sub_id / r) % r;
const uint rz = sub_id % r;
// Sample the SDF at the 8 corners of this sub-cube
float3 points[8];
float2 sdf_vals[8];
sdf_vals[0] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[0].x))
return;
points[0] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);
sdf_vals[1] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[1].x))
return;
points[1] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);
sdf_vals[2] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[2].x))
return;
points[2] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);
sdf_vals[3] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[3].x))
return;
points[3] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);
sdf_vals[4] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[4].x))
return;
points[4] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);
sdf_vals[5] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[5].x))
return;
points[5] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);
sdf_vals[6] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[6].x))
return;
points[6] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);
sdf_vals[7] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_vals[7].x))
return;
points[7] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);
// Find triangle config.
int cube_type = 0;
if (sdf_vals[0].x < 0)
cube_type |= 1;
if (sdf_vals[1].x < 0)
cube_type |= 2;
if (sdf_vals[2].x < 0)
cube_type |= 4;
if (sdf_vals[3].x < 0)
cube_type |= 8;
if (sdf_vals[4].x < 0)
cube_type |= 16;
if (sdf_vals[5].x < 0)
cube_type |= 32;
if (sdf_vals[6].x < 0)
cube_type |= 64;
if (sdf_vals[7].x < 0)
cube_type |= 128;
// Find vertex position on each edge (weighted by sdf value)
int edge_config = edgeTable[cube_type];
float4 vert_list[12];
if (edge_config == 0)
return;
if (edge_config & 1)
vert_list[0] = sdf_interp(points[0], points[1], sdf_vals[0].y, sdf_vals[1].y, sdf_vals[0].x, sdf_vals[1].x);
if (edge_config & 2)
vert_list[1] = sdf_interp(points[1], points[2], sdf_vals[1].y, sdf_vals[2].y, sdf_vals[1].x, sdf_vals[2].x);
if (edge_config & 4)
vert_list[2] = sdf_interp(points[2], points[3], sdf_vals[2].y, sdf_vals[3].y, sdf_vals[2].x, sdf_vals[3].x);
if (edge_config & 8)
vert_list[3] = sdf_interp(points[3], points[0], sdf_vals[3].y, sdf_vals[0].y, sdf_vals[3].x, sdf_vals[0].x);
if (edge_config & 16)
vert_list[4] = sdf_interp(points[4], points[5], sdf_vals[4].y, sdf_vals[5].y, sdf_vals[4].x, sdf_vals[5].x);
if (edge_config & 32)
vert_list[5] = sdf_interp(points[5], points[6], sdf_vals[5].y, sdf_vals[6].y, sdf_vals[5].x, sdf_vals[6].x);
if (edge_config & 64)
vert_list[6] = sdf_interp(points[6], points[7], sdf_vals[6].y, sdf_vals[7].y, sdf_vals[6].x, sdf_vals[7].x);
if (edge_config & 128)
vert_list[7] = sdf_interp(points[7], points[4], sdf_vals[7].y, sdf_vals[4].y, sdf_vals[7].x, sdf_vals[4].x);
if (edge_config & 256)
vert_list[8] = sdf_interp(points[0], points[4], sdf_vals[0].y, sdf_vals[4].y, sdf_vals[0].x, sdf_vals[4].x);
if (edge_config & 512)
vert_list[9] = sdf_interp(points[1], points[5], sdf_vals[1].y, sdf_vals[5].y, sdf_vals[1].x, sdf_vals[5].x);
if (edge_config & 1024)
vert_list[10] = sdf_interp(points[2], points[6], sdf_vals[2].y, sdf_vals[6].y, sdf_vals[2].x, sdf_vals[6].x);
if (edge_config & 2048)
vert_list[11] = sdf_interp(points[3], points[7], sdf_vals[3].y, sdf_vals[7].y, sdf_vals[3].x, sdf_vals[7].x);
// Write triangles to array.
float4 vp[3];
for (int i = 0; triangleTable[cube_type][i] != -1; i += 3)
{
#pragma unroll
for (int vi = 0; vi < 3; ++vi)
{
vp[vi] = vert_list[triangleTable[cube_type][i + vi]];
}
int triangle_id = atomicAdd(triangles_count, 1);
if (triangle_id < max_triangles_count)
{
#pragma unroll
for (int vi = 0; vi < 3; ++vi)
{
triangles[triangle_id][vi][0] = vp[vi].x;
triangles[triangle_id][vi][1] = vp[vi].y;
triangles[triangle_id][vi][2] = vp[vi].z;
triangle_std[triangle_id][vi] = vp[vi].w;
}
triangle_flatten_id[triangle_id] = valid_blocks[lif_id];
}
}
}
std::vector<torch::Tensor> marching_cubes_sparse(
torch::Tensor indexer, // (nx, ny, nz) -> data_id
torch::Tensor valid_blocks, // (K, )
torch::Tensor vec_batch_mapping, // (V, ) vec_id -> batch_id
torch::Tensor cube_sdf, // (M, rx, ry, rz)
torch::Tensor cube_std, // (M, rx, ry, rz)
const std::vector<int> &n_xyz, // [nx, ny, nz]
float max_std, // Prune threshold (not applied inside this kernel)
int max_n_triangles // Capacity of the triangle buffer
)
{
CHECK_INPUT(indexer);
CHECK_INPUT(valid_blocks);
CHECK_INPUT(cube_sdf);
CHECK_INPUT(cube_std);
CHECK_INPUT(vec_batch_mapping);
assert(max_n_triangles > 0);
const int r = cube_sdf.size(1);
const int r3 = r * r * r;
const int num_lif = valid_blocks.size(0);
const uint max_vec_num = vec_batch_mapping.size(0);
torch::Tensor triangles = torch::empty({max_n_triangles, 3, 3},
torch::dtype(torch::kFloat32).device(torch::kCUDA));
torch::Tensor triangle_flatten_id = torch::empty({max_n_triangles}, torch::dtype(torch::kLong).device(torch::kCUDA));
torch::Tensor triangle_std = torch::empty({max_n_triangles, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
dim3 dimBlock = dim3(16, 16);
uint xBlocks = (num_lif + dimBlock.x - 1) / dimBlock.x;
uint yBlocks = (r3 + dimBlock.y - 1) / dimBlock.y;
dim3 dimGrid = dim3(xBlocks, yBlocks);
thrust::device_vector<int> n_output(1, 0);
meshing_cube<<<dimGrid, dimBlock, 0, at::cuda::getCurrentCUDAStream()>>>(
indexer.packed_accessor32<int64_t, 3, torch::RestrictPtrTraits>(),
valid_blocks.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
vec_batch_mapping.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
cube_sdf.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
cube_std.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
triangles.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
triangle_std.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
triangle_flatten_id.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>(),
n_output.data().get(), max_n_triangles, max_vec_num,
n_xyz[0], n_xyz[1], n_xyz[2], max_std);
cudaStreamSynchronize(at::cuda::getCurrentCUDAStream());
int output_n_triangles = n_output[0];
if (output_n_triangles < max_n_triangles)
{
// Trim output tensor if it is not full.
triangles = triangles.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
triangle_flatten_id = triangle_flatten_id.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
triangle_std = triangle_std.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
}
else
{
// Otherwise emit a warning: the buffer overflowed and triangles were dropped.
std::cerr << "Warning from marching cubes: max triangle count is too small: " << output_n_triangles << " vs " << max_n_triangles << std::endl;
}
return {triangles, triangle_flatten_id, triangle_std};
}
================================================
FILE: third_party/marching_cubes/src/mc_kernel_colour.cu
================================================
#include "mc_data.cuh"
#include <ATen/ATen.h>
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
__device__ static inline float4 query_sdf_raw(uint bx, uint by, uint bz, uint arx, uint ary, uint arz,
const uint max_vec_num,
const IndexerAccessor indexer,
const CubeSDFRGBAccessor cube_sdf,
const CubeSDFAccessor cube_std,
const BackwardMappingAccessor vec_batch_mapping)
{
if (bx >= indexer.size(0) || by >= indexer.size(1) || bz >= indexer.size(2))
{
return make_float4(NAN, NAN, NAN, NAN);
}
// printf("B-Getting: %d %d %d --> %d, %d, %d\n", bx, by, bz, indexer.size(0), indexer.size(1), indexer.size(2));
long long vec_ind = indexer[bx][by][bz];
if (vec_ind == -1 || vec_ind >= max_vec_num)
{
return make_float4(NAN, NAN, NAN, NAN);
}
int batch_ind = vec_batch_mapping[vec_ind];
if (batch_ind == -1)
{
return make_float4(NAN, NAN, NAN, NAN);
}
// printf("Getting: %d %d %d %d --> %d %d\n", batch_ind, arx, ary, arz, cube_sdf.size(0), cube_sdf.size(1));
return make_float4(cube_sdf[batch_ind][arx][ary][arz][3],
cube_sdf[batch_ind][arx][ary][arz][0],
cube_sdf[batch_ind][arx][ary][arz][1],
cube_sdf[batch_ind][arx][ary][arz][2]);
}
// Use stddev to weight sdf value.
// #define STD_W_SDF
__device__ static inline float4 get_sdf(const uint3 bsize, const uint r, uint3 bpos, uint3 rpos, const uint max_vec_num,
const IndexerAccessor indexer,
const CubeSDFRGBAccessor cube_sdf,
const CubeSDFAccessor cube_std,
const BackwardMappingAccessor vec_batch_mapping)
{
if (bpos.x >= bsize.x)
{
bpos.x = bsize.x - 1;
rpos.x = r - 1;
}
if (bpos.y >= bsize.y)
{
bpos.y = bsize.y - 1;
rpos.y = r - 1;
}
if (bpos.z >= bsize.z)
{
bpos.z = bsize.z - 1;
rpos.z = r - 1;
}
if (rpos.x == r)
{
bpos.x += 1;
rpos.x = 0;
}
if (rpos.y == r)
{
bpos.y += 1;
rpos.y = 0;
}
if (rpos.z == r)
{
bpos.z += 1;
rpos.z = 0;
}
float4 total_sdf = query_sdf_raw(bpos.x, bpos.y, bpos.z, rpos.x, rpos.y, rpos.z, max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    // NAN values are propagated to the caller, which skips the affected cube.
return total_sdf;
}
__device__ static inline float3 sdf_interp(const float3 p1, const float3 p2,
float valp1, float valp2)
{
if (fabs(0.0f - valp1) < 1.0e-5f)
return p1;
if (fabs(0.0f - valp2) < 1.0e-5f)
return p2;
if (fabs(valp1 - valp2) < 1.0e-5f)
return p1;
float w2 = (0.0f - valp1) / (valp2 - valp1);
float w1 = 1 - w2;
return make_float3(p1.x * w1 + p2.x * w2,
p1.y * w1 + p2.y * w2,
p1.z * w1 + p2.z * w2);
}
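// Worked example for sdf_interp (editorial comment, not in the original source):
// with valp1 = -0.2 and valp2 = 0.3 the zero crossing sits at
// w2 = (0 - (-0.2)) / (0.3 - (-0.2)) = 0.4, so the vertex is placed 40% of the
// way from p1 towards p2. The early returns guard the degenerate cases where a
// corner already lies on the surface or the two values (nearly) coincide.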
__global__ static void meshing_cube_colour(const IndexerAccessor indexer,
const ValidBlocksAccessor valid_blocks,
const BackwardMappingAccessor vec_batch_mapping,
const CubeSDFRGBAccessor cube_sdf,
const CubeSDFAccessor cube_std,
TrianglesAccessor triangles,
TrianglesAccessor vertex_colours,
TriangleVecIdAccessor triangle_flatten_id,
int *__restrict__ triangles_count,
int max_triangles_count,
const uint max_vec_num,
int nx, int ny, int nz,
float max_std)
{
const uint r = cube_sdf.size(1);
const uint r3 = r * r * r;
const uint num_lif = valid_blocks.size(0);
const float sbs = 1.0f / r; // sub-block-size
const uint lif_id = blockIdx.x * blockDim.x + threadIdx.x;
const uint sub_id = blockIdx.y * blockDim.y + threadIdx.y;
if (lif_id >= num_lif || sub_id >= r3)
{
return;
}
const uint3 bpos = make_uint3(
(valid_blocks[lif_id] / (ny * nz)) % nx,
(valid_blocks[lif_id] / nz) % ny,
valid_blocks[lif_id] % nz);
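    // Example of the flat-index decode above (editorial comment, not in the
    // original source): with (nx, ny, nz) = (4, 4, 4), flat id 27 yields
    // bpos = (27/16 % 4, 27/4 % 4, 27 % 4) = (1, 2, 3).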
const uint3 bsize = make_uint3(indexer.size(0), indexer.size(1), indexer.size(2));
const uint rx = sub_id / (r * r);
const uint ry = (sub_id / r) % r;
const uint rz = sub_id % r;
// Find all 8 neighbours
float3 points[8];
float3 colours[8];
float sdf_vals[8];
float4 sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[0] = sdf_val.x;
points[0] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);
colours[0] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[1] = sdf_val.x;
points[1] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);
colours[1] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[2] = sdf_val.x;
points[2] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);
colours[2] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[3] = sdf_val.x;
points[3] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);
colours[3] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[4] = sdf_val.x;
points[4] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);
colours[4] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[5] = sdf_val.x;
points[5] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);
colours[5] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[6] = sdf_val.x;
points[6] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);
colours[6] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[7] = sdf_val.x;
points[7] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);
colours[7] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
// Find triangle config.
int cube_type = 0;
if (sdf_vals[0] < 0)
cube_type |= 1;
if (sdf_vals[1] < 0)
cube_type |= 2;
if (sdf_vals[2] < 0)
cube_type |= 4;
if (sdf_vals[3] < 0)
cube_type |= 8;
if (sdf_vals[4] < 0)
cube_type |= 16;
if (sdf_vals[5] < 0)
cube_type |= 32;
if (sdf_vals[6] < 0)
cube_type |= 64;
if (sdf_vals[7] < 0)
cube_type |= 128;
// Find vertex position on each edge (weighted by sdf value)
int edge_config = edgeTable[cube_type];
float3 vert_list[12];
float3 rgb_list[12];
if (edge_config == 0)
return;
if (edge_config & 1)
{
vert_list[0] = sdf_interp(points[0], points[1], sdf_vals[0], sdf_vals[1]);
rgb_list[0] = sdf_interp(colours[0], colours[1], sdf_vals[0], sdf_vals[1]);
}
if (edge_config & 2)
{
vert_list[1] = sdf_interp(points[1], points[2], sdf_vals[1], sdf_vals[2]);
rgb_list[1] = sdf_interp(colours[1], colours[2], sdf_vals[1], sdf_vals[2]);
}
if (edge_config & 4)
{
vert_list[2] = sdf_interp(points[2], points[3], sdf_vals[2], sdf_vals[3]);
rgb_list[2] = sdf_interp(colours[2], colours[3], sdf_vals[2], sdf_vals[3]);
}
if (edge_config & 8)
{
vert_list[3] = sdf_interp(points[3], points[0], sdf_vals[3], sdf_vals[0]);
rgb_list[3] = sdf_interp(colours[3], colours[0], sdf_vals[3], sdf_vals[0]);
}
if (edge_config & 16)
{
vert_list[4] = sdf_interp(points[4], points[5], sdf_vals[4], sdf_vals[5]);
rgb_list[4] = sdf_interp(colours[4], colours[5], sdf_vals[4], sdf_vals[5]);
}
if (edge_config & 32)
{
vert_list[5] = sdf_interp(points[5], points[6], sdf_vals[5], sdf_vals[6]);
rgb_list[5] = sdf_interp(colours[5], colours[6], sdf_vals[5], sdf_vals[6]);
}
if (edge_config & 64)
{
vert_list[6] = sdf_interp(points[6], points[7], sdf_vals[6], sdf_vals[7]);
rgb_list[6] = sdf_interp(colours[6], colours[7], sdf_vals[6], sdf_vals[7]);
}
if (edge_config & 128)
{
vert_list[7] = sdf_interp(points[7], points[4], sdf_vals[7], sdf_vals[4]);
rgb_list[7] = sdf_interp(colours[7], colours[4], sdf_vals[7], sdf_vals[4]);
}
if (edge_config & 256)
{
vert_list[8] = sdf_interp(points[0], points[4], sdf_vals[0], sdf_vals[4]);
rgb_list[8] = sdf_interp(colours[0], colours[4], sdf_vals[0], sdf_vals[4]);
}
if (edge_config & 512)
{
vert_list[9] = sdf_interp(points[1], points[5], sdf_vals[1], sdf_vals[5]);
rgb_list[9] = sdf_interp(colours[1], colours[5], sdf_vals[1], sdf_vals[5]);
}
if (edge_config & 1024)
{
vert_list[10] = sdf_interp(points[2], points[6], sdf_vals[2], sdf_vals[6]);
rgb_list[10] = sdf_interp(colours[2], colours[6], sdf_vals[2], sdf_vals[6]);
}
if (edge_config & 2048)
{
vert_list[11] = sdf_interp(points[3], points[7], sdf_vals[3], sdf_vals[7]);
rgb_list[11] = sdf_interp(colours[3], colours[7], sdf_vals[3], sdf_vals[7]);
}
// Write triangles to array.
float3 vp[3];
float3 vc[3];
for (int i = 0; triangleTable[cube_type][i] != -1; i += 3)
{
#pragma unroll
for (int vi = 0; vi < 3; ++vi)
{
vp[vi] = vert_list[triangleTable[cube_type][i + vi]];
vc[vi] = rgb_list[triangleTable[cube_type][i + vi]];
}
int triangle_id = atomicAdd(triangles_count, 1);
if (triangle_id < max_triangles_count)
{
#pragma unroll
for (int vi = 0; vi < 3; ++vi)
{
triangles[triangle_id][vi][0] = vp[vi].x;
triangles[triangle_id][vi][1] = vp[vi].y;
triangles[triangle_id][vi][2] = vp[vi].z;
vertex_colours[triangle_id][vi][0] = vc[vi].x;
vertex_colours[triangle_id][vi][1] = vc[vi].y;
vertex_colours[triangle_id][vi][2] = vc[vi].z;
}
triangle_flatten_id[triangle_id] = valid_blocks[lif_id];
}
}
}
std::vector<torch::Tensor> marching_cubes_sparse_colour(
torch::Tensor indexer, // (nx, ny, nz) -> data_id
torch::Tensor valid_blocks, // (K, )
torch::Tensor vec_batch_mapping, //
torch::Tensor cube_rgb_sdf, // (M, rx, ry, rz, 4)
torch::Tensor cube_std, // (M, rx, ry, rz)
const std::vector<int> &n_xyz, // [nx, ny, nz]
    float max_std,                   // Prune vertices whose std exceeds this threshold
    int max_n_triangles              // Capacity of the triangle output buffer
)
{
CHECK_INPUT(indexer);
CHECK_INPUT(valid_blocks);
CHECK_INPUT(cube_rgb_sdf);
CHECK_INPUT(cube_std);
CHECK_INPUT(vec_batch_mapping);
assert(max_n_triangles > 0);
const int r = cube_rgb_sdf.size(1);
const int r3 = r * r * r;
const int num_lif = valid_blocks.size(0);
const uint max_vec_num = vec_batch_mapping.size(0);
torch::Tensor triangles = torch::empty({max_n_tr
SYMBOL INDEX (305 symbols across 33 files)
FILE: demo/parser.py
class ArgumentParserX (line 4) | class ArgumentParserX(argparse.ArgumentParser):
method __init__ (line 5) | def __init__(self, **kwargs):
method parse_args (line 9) | def parse_args(self, args=None, namespace=None):
method parse_config_yaml (line 23) | def parse_config_yaml(self, yaml_path, args=None):
method convert_to_namespace (line 39) | def convert_to_namespace(self, dict_in, args=None):
method update_recursive (line 48) | def update_recursive(self, dict1, dict2):
function get_parser (line 58) | def get_parser():
FILE: demo/run.py
function setup_seed (line 12) | def setup_seed(seed):
FILE: src/criterion.py
class Criterion (line 6) | class Criterion(nn.Module):
method __init__ (line 7) | def __init__(self, args) -> None:
method forward (line 16) | def forward(self, outputs, obs, pointsCos, use_color_loss=True,
method compute_loss (line 59) | def compute_loss(self, x, y, mask=None, loss_type="l2"):
method get_masks (line 67) | def get_masks(self, z_vals, depth, epsilon):
method get_sdf_loss (line 92) | def get_sdf_loss(self, z_vals, depth, predicted_sdf, truncation, valid...
FILE: src/dataset/kitti.py
class DataLoader (line 19) | class DataLoader(Dataset):
method __init__ (line 20) | def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1...
method get_init_pose (line 28) | def get_init_pose(self, frame):
method load_gt_pose (line 35) | def load_gt_pose(self):
method load_points (line 40) | def load_points(self, index):
method __len__ (line 72) | def __len__(self):
method __getitem__ (line 75) | def __getitem__(self, index):
FILE: src/dataset/maicity.py
class DataLoader (line 20) | class DataLoader(Dataset):
method __init__ (line 21) | def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1...
method get_init_pose (line 29) | def get_init_pose(self, frame):
method load_gt_pose (line 36) | def load_gt_pose(self):
method load_points (line 41) | def load_points(self, index):
method __len__ (line 74) | def __len__(self):
method __getitem__ (line 77) | def __getitem__(self, index):
FILE: src/dataset/ncd.py
class DataLoader (line 21) | class DataLoader(Dataset):
method __init__ (line 22) | def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1...
method get_init_pose (line 30) | def get_init_pose(self, frame):
method load_gt_pose (line 39) | def load_gt_pose(self):
method load_points (line 49) | def load_points(self, index):
method __len__ (line 79) | def __len__(self):
method __getitem__ (line 82) | def __getitem__(self, index):
FILE: src/lidarFrame.py
class LidarFrame (line 9) | class LidarFrame(nn.Module):
method __init__ (line 10) | def __init__(self, index, points, pointsCos, pose=None, new_keyframe=F...
method get_pose (line 26) | def get_pose(self):
method get_translation (line 29) | def get_translation(self):
method get_rotation (line 32) | def get_rotation(self):
method get_points (line 35) | def get_points(self):
method get_pointsCos (line 38) | def get_pointsCos(self):
method set_rel_pose (line 41) | def set_rel_pose(self, rel_pose):
method get_rel_pose (line 44) | def get_rel_pose(self):
method get_rays (line 48) | def get_rays(self):
method sample_rays (line 55) | def sample_rays(self, N_rays, track=False):
FILE: src/loggers.py
class BasicLogger (line 14) | class BasicLogger:
method __init__ (line 15) | def __init__(self, args) -> None:
method get_random_time_str (line 33) | def get_random_time_str(self):
method log_ckpt (line 36) | def log_ckpt(self, mapper):
method log_config (line 52) | def log_config(self, config):
method log_mesh (line 56) | def log_mesh(self, mesh, name="final_mesh.ply"):
method log_point_cloud (line 60) | def log_point_cloud(self, pcd, name="final_points.ply"):
method log_numpy_data (line 64) | def log_numpy_data(self, data, name, ind=None):
method log_debug_data (line 73) | def log_debug_data(self, data, idx):
method log_raw_image (line 77) | def log_raw_image(self, ind, rgb, depth):
method log_images (line 88) | def log_images(self, ind, gt_rgb, gt_depth, rgb, depth):
method npy2txt (line 144) | def npy2txt(self, input_path, output_path):
FILE: src/mapping.py
function get_network_size (line 23) | def get_network_size(net):
class Mapping (line 30) | class Mapping:
method __init__ (line 31) | def __init__(self, args, logger: BasicLogger):
method spin (line 93) | def spin(self, share_data, kf_buffer):
method do_mapping (line 172) | def do_mapping(self, share_data, tracked_frame=None,
method select_optimize_targets (line 205) | def select_optimize_targets(self, tracked_frame=None, selection_method...
method update_share_data (line 227) | def update_share_data(self, share_data, frameid=None):
method remove_back_points (line 235) | def remove_back_points(self, frame):
method frame_maxdistance_change (line 257) | def frame_maxdistance_change(self, frame, distance):
method insert_keyframe (line 266) | def insert_keyframe(self, frame, valid_distance=-1):
method create_voxels (line 283) | def create_voxels(self, frame):
method get_embeddings (line 294) | def get_embeddings(self, points_idx):
method update_grid_features (line 320) | def update_grid_features(self):
method get_updated_poses (line 342) | def get_updated_poses(self, offset=-2000):
method extract_mesh (line 354) | def extract_mesh(self, res=8, clean_mesh=False):
method extract_voxels (line 381) | def extract_voxels(self, offset=-10):
method save_debug_data (line 392) | def save_debug_data(self, tracked_frame, offset=-10):
FILE: src/nerfloam.py
class nerfloam (line 15) | class nerfloam:
method __init__ (line 16) | def __init__(self, args):
method start (line 40) | def start(self):
method wait_child_processes (line 55) | def wait_child_processes(self):
method get_raw_trajectory (line 60) | def get_raw_trajectory(self):
method get_keyframe_poses (line 64) | def get_keyframe_poses(self):
FILE: src/se3pose.py
class OptimizablePose (line 8) | class OptimizablePose(nn.Module):
method __init__ (line 9) | def __init__(self, init_pose):
method copy_from (line 15) | def copy_from(self, pose):
method matrix (line 18) | def matrix(self):
method rotation (line 24) | def rotation(self):
method translation (line 34) | def translation(self,):
method log (line 38) | def log(cls, R, eps=1e-7): # [...,3,3]
method from_matrix (line 48) | def from_matrix(cls, Rt, eps=1e-8): # [...,3,4]
method skew_symmetric (line 54) | def skew_symmetric(cls, w):
method taylor_A (line 64) | def taylor_A(cls, x, nth=10):
method taylor_B (line 75) | def taylor_B(cls, x, nth=10):
method taylor_C (line 85) | def taylor_C(cls, x, nth=10):
FILE: src/share.py
class ShareDataProxy (line 8) | class ShareDataProxy(NamespaceProxy):
class ShareData (line 12) | class ShareData:
method __init__ (line 16) | def __init__(self):
method decoder (line 27) | def decoder(self):
method decoder (line 34) | def decoder(self, decoder):
method voxels (line 41) | def voxels(self):
method voxels (line 48) | def voxels(self, voxels):
method octree (line 55) | def octree(self):
method octree (line 62) | def octree(self, octree):
method states (line 69) | def states(self):
method states (line 76) | def states(self, states):
method stop_mapping (line 83) | def stop_mapping(self):
method stop_mapping (line 90) | def stop_mapping(self, stop_mapping):
method stop_tracking (line 97) | def stop_tracking(self):
method stop_tracking (line 104) | def stop_tracking(self, stop_tracking):
method tracking_trajectory (line 111) | def tracking_trajectory(self):
method push_pose (line 117) | def push_pose(self, pose):
FILE: src/tracking.py
class Tracking (line 15) | class Tracking:
method __init__ (line 16) | def __init__(self, args, data_stream, logger):
method process_first_frame (line 51) | def process_first_frame(self, kf_buffer):
method spin (line 64) | def spin(self, share_data, kf_buffer):
method check_keyframe (line 92) | def check_keyframe(self, check_frame, kf_buffer):
method do_tracking (line 98) | def do_tracking(self, share_data, current_frame, kf_buffer):
FILE: src/utils/import_util.py
function get_dataset (line 4) | def get_dataset(args):
function get_decoder (line 8) | def get_decoder(args):
function get_property (line 12) | def get_property(args, name, default):
FILE: src/utils/mesh_util.py
class MeshExtractor (line 11) | class MeshExtractor:
method __init__ (line 12) | def __init__(self, args):
method linearize_id (line 18) | def linearize_id(self, xyz, n_xyz):
method downsample_points (line 22) | def downsample_points(self, points, voxel_size=0.01):
method get_rays (line 29) | def get_rays(self, w=None, h=None, K=None):
method get_valid_points (line 47) | def get_valid_points(self, frame_poses, depth_maps):
method create_mesh (line 80) | def create_mesh(self, decoder, map_states, voxel_size, voxels,
method marching_cubes (line 145) | def marching_cubes(self, voxels, sdf):
FILE: src/utils/profile_util.py
class Profiler (line 5) | class Profiler(object):
method __init__ (line 6) | def __init__(self, verbose=False) -> None:
method enable (line 12) | def enable(self):
method disable (line 15) | def disable(self):
method tick (line 18) | def tick(self, name):
method tok (line 25) | def tok(self, name):
FILE: src/utils/sample_util.py
function sampling_without_replacement (line 4) | def sampling_without_replacement(logp, k):
function sample_rays (line 12) | def sample_rays(mask, num_samples):
FILE: src/variations/decode_morton.py
function compact (line 4) | def compact(value):
function decode (line 14) | def decode(code):
FILE: src/variations/lidar.py
class GaussianFourierFeatureTransform (line 6) | class GaussianFourierFeatureTransform(torch.nn.Module):
method __init__ (line 16) | def __init__(self, num_input_channels, mapping_size=93, scale=25, lear...
method forward (line 26) | def forward(self, x):
class Nerf_positional_embedding (line 33) | class Nerf_positional_embedding(torch.nn.Module):
method __init__ (line 39) | def __init__(self, in_dim, multires, log_sampling=True):
method forward (line 50) | def forward(self, x):
class Same (line 71) | class Same(nn.Module):
method __init__ (line 72) | def __init__(self, in_dim) -> None:
method forward (line 76) | def forward(self, x):
class Decoder (line 80) | class Decoder(nn.Module):
method __init__ (line 81) | def __init__(self,
method get_values (line 109) | def get_values(self, input):
method forward (line 125) | def forward(self, inputs):
FILE: src/variations/render_helpers.py
function ray (line 9) | def ray(ray_start, ray_dir, depths):
function fill_in (line 13) | def fill_in(shape, mask, input, initial=1.0):
function masked_scatter (line 21) | def masked_scatter(mask, x):
function masked_scatter_ones (line 30) | def masked_scatter_ones(mask, x):
function trilinear_interp (line 40) | def trilinear_interp(p, q, point_feats):
function offset_points (line 49) | def offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2):
function get_embeddings (line 63) | def get_embeddings(sampled_xyz, point_xyz, point_feats, voxel_size):
function get_features (line 74) | def get_features(samples, map_states, voxel_size):
function get_scores (line 97) | def get_scores(sdf_network, map_states, voxel_size, bits=8):
function eval_points (line 157) | def eval_points(sdf_network, map_states, sampled_xyz, sampled_idx, voxel...
function render_rays (line 190) | def render_rays(
function bundle_adjust_frames (line 321) | def bundle_adjust_frames(
function track_frame (line 428) | def track_frame(
FILE: src/variations/voxel_helpers.py
class BallRayIntersect (line 27) | class BallRayIntersect(Function):
method forward (line 29) | def forward(ctx, radius, n_max, points, ray_start, ray_dir):
method backward (line 42) | def backward(ctx, a, b, c):
class AABBRayIntersect (line 49) | class AABBRayIntersect(Function):
method forward (line 51) | def forward(ctx, voxelsize, n_max, points, ray_start, ray_dir):
method backward (line 85) | def backward(ctx, a, b, c):
class SparseVoxelOctreeRayIntersect (line 92) | class SparseVoxelOctreeRayIntersect(Function):
method forward (line 94) | def forward(ctx, voxelsize, n_max, points, children, ray_start, ray_dir):
method backward (line 136) | def backward(ctx, a, b, c):
class TriangleRayIntersect (line 143) | class TriangleRayIntersect(Function):
method forward (line 145) | def forward(ctx, cagesize, blur_ratio, n_max, points, faces, ray_start...
method backward (line 187) | def backward(ctx, a, b, c):
class UniformRaySampling (line 194) | class UniformRaySampling(Function):
method forward (line 196) | def forward(
method backward (line 255) | def backward(ctx, a, b, c):
class InverseCDFRaySampling (line 262) | class InverseCDFRaySampling(Function):
method forward (line 264) | def forward(
method backward (line 343) | def backward(ctx, a, b, c):
function _parallel_ray_sampling (line 352) | def _parallel_ray_sampling(
function parallel_ray_sampling (line 411) | def parallel_ray_sampling(
function discretize_points (line 454) | def discretize_points(voxel_points, voxel_size):
function build_easy_octree (line 467) | def build_easy_octree(points, half_voxel):
function trilinear_interp (line 478) | def trilinear_interp(p, q, point_feats):
function offset_points (line 487) | def offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2):
function splitting_points (line 499) | def splitting_points(point_xyz, point_feats, values, half_voxel):
function ray_intersect (line 531) | def ray_intersect(ray_start, ray_dir, flatten_centers, flatten_children,...
function ray_sample (line 571) | def ray_sample(intersection_outputs, step_size=0.01, fixed=False):
FILE: third_party/marching_cubes/src/mc.cpp
function PYBIND11_MODULE (line 36) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
FILE: third_party/sparse_octree/include/octree.h
type OcType (line 5) | enum OcType
function class (line 12) | class Octant : public torch::CustomClassHolder
function class (line 65) | class Octree : public torch::CustomClassHolder
FILE: third_party/sparse_octree/include/test.h
function expand (line 41) | inline int64_t expand(int64_t value)
function compact (line 52) | inline uint64_t compact(uint64_t value)
function compute_morton (line 63) | inline int64_t compute_morton(int64_t x, int64_t y, int64_t z)
FILE: third_party/sparse_octree/include/utils.h
type Vector3 (line 28) | typedef Vector3<int> Vector3i;
type Vector3 (line 29) | typedef Vector3<float> Vector3f;
function expand (line 64) | inline uint64_t expand(unsigned long long value)
function compact (line 75) | inline uint64_t compact(uint64_t value)
function compute_morton (line 86) | inline uint64_t compute_morton(uint64_t x, uint64_t y, uint64_t z)
function Eigen (line 98) | inline Eigen::Vector3i decode(const uint64_t code)
function encode (line 106) | inline uint64_t encode(const int x, const int y, const int z)
FILE: third_party/sparse_octree/src/bindings.cpp
function TORCH_LIBRARY (line 4) | TORCH_LIBRARY(svo, m)
FILE: third_party/sparse_octree/src/octree.cpp
function Octant (line 151) | Octant *Octree::find_octant(std::vector<float> coord)
FILE: third_party/sparse_voxels/include/cuda_utils.h
function opt_n_threads (line 20) | inline int opt_n_threads(int work_size)
function dim3 (line 27) | inline dim3 opt_block_config(int x, int y)
FILE: third_party/sparse_voxels/include/cutil_math.h
type uint (line 30) | typedef unsigned int uint;
type ushort (line 31) | typedef unsigned short ushort;
function fminf (line 36) | inline float fminf(float a, float b)
function fmaxf (line 41) | inline float fmaxf(float a, float b)
function max (line 46) | inline int max(int a, int b)
function min (line 51) | inline int min(int a, int b)
function rsqrtf (line 56) | inline float rsqrtf(float x)
function lerp (line 67) | float lerp(float a, float b, float t)
function clamp (line 73) | float clamp(float f, float a, float b)
function swap (line 78) | void swap(float &a, float &b)
function swap (line 85) | void swap(int &a, int &b)
function int2 (line 124) | int2 operator*(int2 a, int2 b)
function int2 (line 128) | int2 operator*(int2 a, int s)
function int2 (line 132) | int2 operator*(int s, int2 a)
function float2 (line 146) | float2 make_float2(float s)
function float2 (line 150) | float2 make_float2(int2 a)
function float2 (line 184) | float2 operator*(float2 a, float2 b)
function float2 (line 188) | float2 operator*(float2 a, float s)
function float2 (line 192) | float2 operator*(float s, float2 a)
function float2 (line 224) | float2 lerp(float2 a, float2 b, float t)
function float2 (line 230) | float2 clamp(float2 v, float a, float b)
function float2 (line 235) | float2 clamp(float2 v, float2 a, float2 b)
function dot (line 241) | float dot(float2 a, float2 b)
function length (line 247) | float length(float2 v)
function float2 (line 253) | float2 normalize(float2 v)
function float2 (line 260) | float2 floor(const float2 v)
function float2 (line 266) | float2 reflect(float2 i, float2 n)
function float2 (line 272) | float2 fabs(float2 v)
function float3 (line 281) | float3 make_float3(float s)
function float3 (line 285) | float3 make_float3(float2 a)
function float3 (line 289) | float3 make_float3(float2 a, float s)
function float3 (line 293) | float3 make_float3(float4 a)
function float3 (line 297) | float3 make_float3(int3 a)
function float3 (line 309) | float3 fminf(float3 a, float3 b)
function float3 (line 315) | float3 fmaxf(float3 a, float3 b)
function float3 (line 353) | float3 operator*(float3 a, float3 b)
function float3 (line 357) | float3 operator*(float3 a, float s)
function float3 (line 361) | float3 operator*(float s, float3 a)
function float3 (line 401) | float3 lerp(float3 a, float3 b, float t)
function float3 (line 407) | float3 clamp(float3 v, float a, float b)
function float3 (line 412) | float3 clamp(float3 v, float3 a, float3 b)
function dot (line 418) | float dot(float3 a, float3 b)
function float3 (line 424) | float3 cross(float3 a, float3 b)
function length (line 430) | float length(float3 v)
function float3 (line 436) | float3 normalize(float3 v)
function float3 (line 443) | float3 floor(const float3 v)
function float3 (line 449) | float3 reflect(float3 i, float3 n)
function float3 (line 455) | float3 fabs(float3 v)
function float4 (line 464) | float4 make_float4(float s)
function float4 (line 468) | float4 make_float4(float3 a)
function float4 (line 472) | float4 make_float4(float3 a, float w)
function float4 (line 476) | float4 make_float4(int4 a)
function float4 (line 488) | float4 fminf(float4 a, float4 b)
function float4 (line 494) | float4 fmaxf(float4 a, float4 b)
function float4 (line 526) | float4 operator*(float4 a, float s)
function float4 (line 530) | float4 operator*(float s, float4 a)
function float4 (line 564) | float4 lerp(float4 a, float4 b, float t)
function float4 (line 570) | float4 clamp(float4 v, float a, float b)
function float4 (line 575) | float4 clamp(float4 v, float4 a, float4 b)
function dot (line 581) | float dot(float4 a, float4 b)
function length (line 587) | float length(float4 r)
function float4 (line 593) | float4 normalize(float4 v)
function float4 (line 600) | float4 floor(const float4 v)
function float4 (line 606) | float4 fabs(float4 v)
function int3 (line 615) | int3 make_int3(int s)
function int3 (line 619) | int3 make_int3(float3 a)
function int3 (line 631) | int3 min(int3 a, int3 b)
function int3 (line 637) | int3 max(int3 a, int3 b)
function int3 (line 668) | int3 operator*(int3 a, int3 b)
function int3 (line 672) | int3 operator*(int3 a, int s)
function int3 (line 676) | int3 operator*(int s, int3 a)
function clamp (line 708) | int clamp(int f, int a, int b)
function int3 (line 713) | int3 clamp(int3 v, int a, int b)
function int3 (line 718) | int3 clamp(int3 v, int3 a, int3 b)
function uint3 (line 727) | uint3 make_uint3(uint s)
function uint3 (line 731) | uint3 make_uint3(float3 a)
function uint3 (line 737) | uint3 min(uint3 a, uint3 b)
function uint3 (line 743) | uint3 max(uint3 a, uint3 b)
function uint3 (line 774) | uint3 operator*(uint3 a, uint3 b)
function uint3 (line 778) | uint3 operator*(uint3 a, uint s)
function uint3 (line 782) | uint3 operator*(uint s, uint3 a)
function uint (line 814) | uint clamp(uint f, uint a, uint b)
function uint3 (line 819) | uint3 clamp(uint3 v, uint a, uint b)
function uint3 (line 824) | uint3 clamp(uint3 v, uint3 a, uint3 b)
FILE: third_party/sparse_voxels/src/binding.cpp
function PYBIND11_MODULE (line 10) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
FILE: third_party/sparse_voxels/src/intersect.cpp
function ball_intersect (line 15) | std::tuple<at::Tensor, at::Tensor, at::Tensor> ball_intersect(at::Tensor...
function aabb_intersect (line 49) | std::tuple<at::Tensor, at::Tensor, at::Tensor> aabb_intersect(at::Tensor...
function svo_intersect (line 83) | std::tuple<at::Tensor, at::Tensor, at::Tensor> svo_intersect(at::Tensor ...
function triangle_intersect (line 119) | std::tuple<at::Tensor, at::Tensor, at::Tensor> triangle_intersect(at::Te...
FILE: third_party/sparse_voxels/src/octree.cpp
type OcTree (line 12) | struct OcTree
type OcTree (line 17) | struct OcTree
method init (line 18) | void init(at::Tensor center, int d, int i)
class EasyOctree (line 28) | class EasyOctree
method EasyOctree (line 38) | EasyOctree(at::Tensor center, int depth)
function build_octree (line 149) | std::tuple<at::Tensor, at::Tensor> build_octree(at::Tensor center, at::T...
FILE: third_party/sparse_voxels/src/sample.cpp
function uniform_ray_sampling (line 21) | std::tuple<at::Tensor, at::Tensor, at::Tensor> uniform_ray_sampling(
function inverse_cdf_sampling (line 56) | std::tuple<at::Tensor, at::Tensor, at::Tensor> inverse_cdf_sampling(
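The intersect functions indexed above (ball_intersect, aabb_intersect, svo_intersect) each return a tuple of three tensors computed by CUDA kernels in intersect_gpu.cu. As orientation only, a minimal NumPy sketch of the classic slab test that ray/axis-aligned-box intersection is based on — a hypothetical standalone helper, not the repository's batched CUDA implementation — looks like this:

```python
import numpy as np

def ray_aabb(origin, direction, box_min, box_max):
    """Slab-method ray/AABB test.

    Returns (hit, t_near, t_far): whether the ray hits the box and the
    entry/exit distances along the ray. Assumes no exactly-zero direction
    components (real kernels guard this with an epsilon).
    """
    inv_d = 1.0 / direction
    t0 = (box_min - origin) * inv_d   # distance to each min-plane
    t1 = (box_max - origin) * inv_d   # distance to each max-plane
    t_near = np.minimum(t0, t1).max() # latest entry across the three slabs
    t_far = np.maximum(t0, t1).min()  # earliest exit across the three slabs
    return t_far >= max(t_near, 0.0), t_near, t_far
```

The CUDA kernels run this per ray/per voxel and additionally sort and pack the hit voxel indices, which is why three tensors come back instead of a single boolean.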
Condensed preview — 70 files, each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 69,
"preview": "__pycache__\n*.egg-info\nbuild/\ndist/\nlogs/\n.vscode/\nresults/\ntemp.sh\n\n"
},
{
"path": "LICENSE",
"chars": 1068,
"preview": "MIT License\n\nCopyright (c) 2023 JunyuanDeng\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
},
{
"path": "Readme.md",
"chars": 5131,
"preview": "# NeRF-LOAM: Neural Implicit Representation for Large-Scale Incremental LiDAR Odometry and Mapping\n\nThis repository cont"
},
{
"path": "configs/kitti/kitti.yaml",
"chars": 740,
"preview": "log_dir: './logs'\ndecoder: lidar\ndataset: kitti\n\ncriteria:\n sdf_weight: 10000.0\n fs_weight: 1\n eiko_weight: 0.1\n sdf"
},
{
"path": "configs/kitti/kitti_00.yaml",
"chars": 256,
"preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence00\n\n\n\ndata_specs:\n data_path: '/home/pl21n4/dataset/kitt"
},
{
"path": "configs/kitti/kitti_01.yaml",
"chars": 270,
"preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence01\n\n\n\ndata_specs:\n data_path: '/home/evsjtu2/disk1/dengj"
},
{
"path": "configs/kitti/kitti_03.yaml",
"chars": 270,
"preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence03\n\n\n\ndata_specs:\n data_path: '/home/evsjtu2/disk1/dengj"
},
{
"path": "configs/kitti/kitti_04.yaml",
"chars": 260,
"preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence04\n\n\n\ndata_specs:\n data_path: '/home/pl21n4/dataset/kitt"
},
{
"path": "configs/kitti/kitti_05.yaml",
"chars": 273,
"preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence05\n\n\n\ndata_specs:\n data_path: '/home/evsjtu2/disk1/dengj"
},
{
"path": "configs/kitti/kitti_06.yaml",
"chars": 256,
"preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence06\n\n\n\ndata_specs:\n data_path: '/home/pl21n4/dataset/kitt"
},
{
"path": "configs/kitti/kitti_07.yaml",
"chars": 268,
"preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence07\n\n\n\ndata_specs:\n data_path: '/home/evsjtu2/disk1/dengj"
},
{
"path": "configs/kitti/kitti_08.yaml",
"chars": 256,
"preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence08\n\n\n\ndata_specs:\n data_path: '/home/pl21n4/dataset/kitt"
},
{
"path": "configs/kitti/kitti_09.yaml",
"chars": 256,
"preview": "base_config: configs/kitti/kitti.yaml\n\nexp_name: kitti/sqeuence09\n\n\n\ndata_specs:\n data_path: '/home/pl21n4/dataset/kitt"
},
{
"path": "configs/kitti/kitti_10.yaml",
"chars": 268,
"preview": "base_config: configs/kitti/kitti_base10.yaml\n\nexp_name: kitti/sqeuence10\n\n\n\ndata_specs:\n data_path: '/home/pl21n4/datas"
},
{
"path": "configs/kitti/kitti_base06.yaml",
"chars": 739,
"preview": "log_dir: './logs'\ndecoder: lidar\ndataset: kitti\n\ncriteria:\n sdf_weight: 10000.0\n fs_weight: 1\n eiko_weight: 0.1\n sdf"
},
{
"path": "configs/kitti/kitti_base10.yaml",
"chars": 932,
"preview": "log_dir: '/home/evsjtu2/disk1/dengjunyuan/running_logs/'\ndecoder: lidar\ndataset: kitti\n\ncriteria:\n depth_weight: 0\n sd"
},
{
"path": "configs/maicity/maicity.yaml",
"chars": 743,
"preview": "log_dir: './logs'\ndecoder: lidar\ndataset: maicity\n\ncriteria:\n sdf_weight: 10000.0\n fs_weight: 1\n eiko_weight: 0.1\n s"
},
{
"path": "configs/maicity/maicity_00.yaml",
"chars": 265,
"preview": "base_config: configs/maicity/maicity.yaml\n\nexp_name: maicity/sqeuence00\n\n\ndata_specs:\n data_path: '/home/pl21n4/dataset"
},
{
"path": "configs/maicity/maicity_01.yaml",
"chars": 264,
"preview": "base_config: configs/maicity/maicity.yaml\n\nexp_name: maicity/sqeuence01\n\n\ndata_specs:\n data_path: '/home/pl21n4/dataset"
},
{
"path": "configs/ncd/ncd.yaml",
"chars": 740,
"preview": "log_dir: './logs'\ndecoder: lidar\ndataset: ncd\n\ncriteria:\n sdf_weight: 10000.0\n fs_weight: 1\n eiko_weight: 1.0\n sdf_t"
},
{
"path": "configs/ncd/ncd_quad.yaml",
"chars": 238,
"preview": "base_config: configs/ncd/ncd.yaml\n\nexp_name: ncd/quad\n\n\n\ndata_specs:\n data_path: '/home/pl21n4/dataset/ncd_example/quad"
},
{
"path": "demo/parser.py",
"chars": 2277,
"preview": "import yaml\nimport argparse\n\nclass ArgumentParserX(argparse.ArgumentParser):\n def __init__(self, **kwargs):\n s"
},
{
"path": "demo/run.py",
"chars": 683,
"preview": "import os # noqa\nimport sys # noqa\nsys.path.insert(0, os.path.abspath('src')) # noqa\nos.environ[\"CUDA_VISIBLE_DEVICES\""
},
{
"path": "install.sh",
"chars": 156,
"preview": "#!/bin/bash\n\ncd third_party/marching_cubes\npython setup.py install\n\ncd ../sparse_octree\npython setup.py install\n\ncd ../s"
},
{
"path": "requirements.txt",
"chars": 71,
"preview": "matplotlib\nopen3d\nopencv-python\nPyYAML\nscikit-image\ntqdm\ntrimesh\npyyaml"
},
{
"path": "src/criterion.py",
"chars": 4853,
"preview": "import torch\nimport torch.nn as nn\nfrom torch.autograd import grad\n\n\nclass Criterion(nn.Module):\n def __init__(self, "
},
{
"path": "src/dataset/kitti.py",
"chars": 3326,
"preview": "import os.path as osp\n\nimport numpy as np\nimport torch\nfrom glob import glob\nfrom torch.utils.data import Dataset\nimport"
},
{
"path": "src/dataset/maicity.py",
"chars": 3440,
"preview": "import os.path as osp\n\nimport numpy as np\nimport torch\nfrom glob import glob\nfrom torch.utils.data import Dataset\nimport"
},
{
"path": "src/dataset/ncd.py",
"chars": 3864,
"preview": "import os.path as osp\n\nimport numpy as np\nimport torch\nimport open3d as o3d\nfrom glob import glob\nfrom torch.utils.data "
},
{
"path": "src/lidarFrame.py",
"chars": 1699,
"preview": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom se3pose import OptimizablePose\nfrom utils.sample_util import "
},
{
"path": "src/loggers.py",
"chars": 6118,
"preview": "import os\nimport os.path as osp\nimport pickle\nfrom datetime import datetime\n\nimport cv2\nimport matplotlib.pyplot as plt\n"
},
{
"path": "src/mapping.py",
"chars": 18659,
"preview": "from copy import deepcopy\nimport random\nfrom time import sleep\nimport numpy as np\nfrom tqdm import tqdm\nimport torch\n\nfr"
},
{
"path": "src/nerfloam.py",
"chars": 2052,
"preview": "from multiprocessing.managers import BaseManager\nfrom time import sleep\n\nimport torch\nimport torch.multiprocessing as mp"
},
{
"path": "src/se3pose.py",
"chars": 3279,
"preview": "\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom copy import deepcopy\n\n\nclass OptimizablePose(nn.Module):\n "
},
{
"path": "src/share.py",
"chars": 3355,
"preview": "from multiprocessing.managers import BaseManager, NamespaceProxy\nfrom copy import deepcopy\nimport torch.multiprocessing "
},
{
"path": "src/tracking.py",
"chars": 5641,
"preview": "import torch\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom criterion import Criterion\nfrom lidarFrame import LidarFrame"
},
{
"path": "src/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/utils/import_util.py",
"chars": 652,
"preview": "from importlib import import_module\nimport argparse\n\ndef get_dataset(args):\n Dataset = import_module(\"dataset.\"+args."
},
{
"path": "src/utils/mesh_util.py",
"chars": 6747,
"preview": "import math\nimport torch\n\nimport numpy as np\nimport open3d as o3d\nfrom scipy.spatial import cKDTree\nfrom skimage.measure"
},
{
"path": "src/utils/profile_util.py",
"chars": 870,
"preview": "from time import time\nimport torch\n\n\nclass Profiler(object):\n def __init__(self, verbose=False) -> None:\n self"
},
{
"path": "src/utils/sample_util.py",
"chars": 616,
"preview": "import torch\n\n\ndef sampling_without_replacement(logp, k):\n def gumbel_like(u):\n return -torch.log(-torch.log(t"
},
{
"path": "src/variations/decode_morton.py",
"chars": 588,
"preview": "import numpy as np\n\n\ndef compact(value):\n x = value & 0x1249249249249249\n x = (x | x >> 2) & 0x10c30c30c30c30c3\n "
},
{
"path": "src/variations/lidar.py",
"chars": 4114,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass GaussianFourierFeatureTransform(torch.nn.Modu"
},
{
"path": "src/variations/render_helpers.py",
"chars": 17255,
"preview": "from copy import deepcopy\nimport torch\nimport torch.nn.functional as F\n\nfrom .voxel_helpers import ray_intersect, ray_sa"
},
{
"path": "src/variations/voxel_helpers.py",
"chars": 21098,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MIT license found in the\n#"
},
{
"path": "third_party/marching_cubes/setup.py",
"chars": 527,
"preview": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\nimport glob\n\n_ext_sourc"
},
{
"path": "third_party/marching_cubes/src/mc.cpp",
"chars": 1956,
"preview": "#include <torch/extension.h>\n\nstd::vector<torch::Tensor> marching_cubes_sparse(\n torch::Tensor indexer, // "
},
{
"path": "third_party/marching_cubes/src/mc_data.cuh",
"chars": 18896,
"preview": "#include <torch/extension.h>\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <thrust/device_vector.h>\n\n#define CHE"
},
{
"path": "third_party/marching_cubes/src/mc_interp_kernel.cu",
"chars": 20516,
"preview": "#include \"mc_data.cuh\"\n\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <ATen/cuda/CUDAContext.h>\n\n__device__ "
},
{
"path": "third_party/marching_cubes/src/mc_kernel.cu",
"chars": 13447,
"preview": "#include \"mc_data.cuh\"\n\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <ATen/cuda/CUDAContext.h>\n\n__device__ "
},
{
"path": "third_party/marching_cubes/src/mc_kernel_colour.cu",
"chars": 15332,
"preview": "#include \"mc_data.cuh\"\n\n#include <ATen/ATen.h>\n#include <ATen/Context.h>\n#include <ATen/cuda/CUDAContext.h>\n\n__device__ "
},
{
"path": "third_party/sparse_octree/include/octree.h",
"chars": 3004,
"preview": "#include <memory>\n#include <torch/script.h>\n#include <torch/custom_class.h>\n\nenum OcType\n{\n NONLEAF = -1,\n SURFACE"
},
{
"path": "third_party/sparse_octree/include/test.h",
"chars": 2181,
"preview": "#pragma once\n#include <iostream>\n\n#define MAX_BITS 21\n// #define SCALE_MASK ((uint64_t)0x1FF)\n#define SCALE_MASK ((uint6"
},
{
"path": "third_party/sparse_octree/include/utils.h",
"chars": 2483,
"preview": "#pragma once\n#include <iostream>\n#include <eigen3/Eigen/Dense>\n\n#define MAX_BITS 21\n// #define SCALE_MASK ((uint64_t)0x1"
},
{
"path": "third_party/sparse_octree/setup.py",
"chars": 657,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MIT license found in the\n#"
},
{
"path": "third_party/sparse_octree/src/bindings.cpp",
"chars": 1288,
"preview": "#include \"../include/octree.h\"\n#include \"../include/test.h\"\n\nTORCH_LIBRARY(svo, m)\n{\n m.def(\"encode\", &encode_torch);"
},
{
"path": "third_party/sparse_octree/src/octree.cpp",
"chars": 10131,
"preview": "#include \"../include/octree.h\"\n#include \"../include/utils.h\"\n#include <queue>\n#include <iostream>\n\n// #define MAX_HIT_VO"
},
{
"path": "third_party/sparse_voxels/include/cuda_utils.h",
"chars": 1629,
"preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
},
{
"path": "third_party/sparse_voxels/include/cutil_math.h",
"chars": 18605,
"preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
},
{
"path": "third_party/sparse_voxels/include/intersect.h",
"chars": 1194,
"preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
},
{
"path": "third_party/sparse_voxels/include/octree.h",
"chars": 341,
"preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
},
{
"path": "third_party/sparse_voxels/include/sample.h",
"chars": 682,
"preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
},
{
"path": "third_party/sparse_voxels/include/utils.h",
"chars": 1473,
"preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
},
{
"path": "third_party/sparse_voxels/setup.py",
"chars": 734,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n#\n# This source code is licensed under the MIT license found in the\n#"
},
{
"path": "third_party/sparse_voxels/src/binding.cpp",
"chars": 659,
"preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
},
{
"path": "third_party/sparse_voxels/src/intersect.cpp",
"chars": 6968,
"preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
},
{
"path": "third_party/sparse_voxels/src/intersect_gpu.cu",
"chars": 12773,
"preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
},
{
"path": "third_party/sparse_voxels/src/octree.cpp",
"chars": 4439,
"preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
},
{
"path": "third_party/sparse_voxels/src/sample.cpp",
"chars": 4402,
"preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
},
{
"path": "third_party/sparse_voxels/src/sample_gpu.cu",
"chars": 8533,
"preview": "// Copyright (c) Facebook, Inc. and its affiliates.\n//\n// This source code is licensed under the MIT license found in th"
}
]
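The src/variations/decode_morton.py preview above shows the first steps of a compact() bit-gather with masks 0x1249249249249249 and 0x10c30c30c30c30c3 — the standard 64-bit 3D Morton decode used for octree voxel keys (the repository's octree headers also define MAX_BITS 21, i.e. 21 bits per axis). A self-contained sketch of the full encode/decode pair follows; the encode side and the helper names are reconstructed from the standard mask sequence, since the preview only shows the decode side:

```python
def _spread(v):
    """Insert two zero bits between each of the low 21 bits of v."""
    v &= 0x1FFFFF
    v = (v | v << 32) & 0x1F00000000FFFF
    v = (v | v << 16) & 0x1F0000FF0000FF
    v = (v | v << 8)  & 0x100F00F00F00F00F
    v = (v | v << 4)  & 0x10C30C30C30C30C3
    v = (v | v << 2)  & 0x1249249249249249
    return v

def _compact(v):
    """Inverse of _spread: gather every third bit (the compact() step
    previewed in decode_morton.py)."""
    v &= 0x1249249249249249
    v = (v | v >> 2)  & 0x10C30C30C30C30C3
    v = (v | v >> 4)  & 0x100F00F00F00F00F
    v = (v | v >> 8)  & 0x1F0000FF0000FF
    v = (v | v >> 16) & 0x1F00000000FFFF
    v = (v | v >> 32) & 0x1FFFFF
    return v

def morton_encode(x, y, z):
    """Interleave three 21-bit coordinates into one 63-bit Morton key."""
    return _spread(x) | (_spread(y) << 1) | (_spread(z) << 2)

def morton_decode(m):
    """Recover (x, y, z) from a Morton key."""
    return _compact(m), _compact(m >> 1), _compact(m >> 2)
```

Morton keys make spatially adjacent voxels numerically adjacent, which is why sparse octrees (like the ones in third_party/sparse_octree and sparse_voxels) commonly use them as hashable node identifiers.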
About this extraction
This page contains the full source code of the JunyuanDeng/NeRF-LOAM GitHub repository, extracted as plain text: 70 files (271.3 KB, approximately 88.3k tokens) plus a symbol index of 305 extracted functions, classes, methods, constants, and types.