Repository: JunyuanDeng/NeRF-LOAM
Branch: master
Commit: 2fe4e8d8dd9a
Files: 70
Total size: 271.3 KB
Directory structure:
gitextract_t7v5lyn0/
├── .gitignore
├── LICENSE
├── Readme.md
├── configs/
│ ├── kitti/
│ │ ├── kitti.yaml
│ │ ├── kitti_00.yaml
│ │ ├── kitti_01.yaml
│ │ ├── kitti_03.yaml
│ │ ├── kitti_04.yaml
│ │ ├── kitti_05.yaml
│ │ ├── kitti_06.yaml
│ │ ├── kitti_07.yaml
│ │ ├── kitti_08.yaml
│ │ ├── kitti_09.yaml
│ │ ├── kitti_10.yaml
│ │ ├── kitti_base06.yaml
│ │ └── kitti_base10.yaml
│ ├── maicity/
│ │ ├── maicity.yaml
│ │ ├── maicity_00.yaml
│ │ └── maicity_01.yaml
│ └── ncd/
│ ├── ncd.yaml
│ └── ncd_quad.yaml
├── demo/
│ ├── parser.py
│ └── run.py
├── install.sh
├── requirements.txt
├── src/
│ ├── criterion.py
│ ├── dataset/
│ │ ├── kitti.py
│ │ ├── maicity.py
│ │ └── ncd.py
│ ├── lidarFrame.py
│ ├── loggers.py
│ ├── mapping.py
│ ├── nerfloam.py
│ ├── se3pose.py
│ ├── share.py
│ ├── tracking.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── import_util.py
│ │ ├── mesh_util.py
│ │ ├── profile_util.py
│ │ └── sample_util.py
│ └── variations/
│ ├── decode_morton.py
│ ├── lidar.py
│ ├── render_helpers.py
│ └── voxel_helpers.py
└── third_party/
├── marching_cubes/
│ ├── setup.py
│ └── src/
│ ├── mc.cpp
│ ├── mc_data.cuh
│ ├── mc_interp_kernel.cu
│ ├── mc_kernel.cu
│ └── mc_kernel_colour.cu
├── sparse_octree/
│ ├── include/
│ │ ├── octree.h
│ │ ├── test.h
│ │ └── utils.h
│ ├── setup.py
│ └── src/
│ ├── bindings.cpp
│ └── octree.cpp
└── sparse_voxels/
├── include/
│ ├── cuda_utils.h
│ ├── cutil_math.h
│ ├── intersect.h
│ ├── octree.h
│ ├── sample.h
│ └── utils.h
├── setup.py
└── src/
├── binding.cpp
├── intersect.cpp
├── intersect_gpu.cu
├── octree.cpp
├── sample.cpp
└── sample_gpu.cu
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__
*.egg-info
build/
dist/
logs/
.vscode/
results/
temp.sh
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 JunyuanDeng
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: Readme.md
================================================
# NeRF-LOAM: Neural Implicit Representation for Large-Scale Incremental LiDAR Odometry and Mapping
This repository contains the implementation of our paper:
> **NeRF-LOAM: Neural Implicit Representation for Large-Scale Incremental LiDAR Odometry and Mapping** ([PDF](https://arxiv.org/pdf/2303.10709))\
> [Junyuan Deng](https://github.com/JunyuanDeng), [Qi Wu](https://github.com/Gatsby23), [Xieyuanli Chen](https://github.com/Chen-Xieyuanli), Songpengcheng Xia, Zhen Sun, Guoqing Liu, Wenxian Yu and Ling Pei\
> If you use our code in your work, please star our repo and cite our paper.
```
@inproceedings{deng2023nerfloam,
title={NeRF-LOAM: Neural Implicit Representation for Large-Scale Incremental LiDAR Odometry and Mapping},
author={Junyuan Deng and Qi Wu and Xieyuanli Chen and Songpengcheng Xia and Zhen Sun and Guoqing Liu and Wenxian Yu and Ling Pei},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
year={2023}
}
```
- *Our incrementally simultaneous odometry and mapping results on the Newer College dataset and the KITTI dataset sequence 00.*
- *The maps are dense with a form of mesh, the red line indicates the odometry results.*
- *We use the same network without training to prove the ability of generalization of our design.*
## Overview

**Overview of our method.** Our method is based on our neural SDF and composed of three main components:
- Neural odometry takes the pre-processed scan and optimizes the pose via back projecting the queried neural SDF;
- Neural mapping jointly optimizes the voxel embeddings map and pose while selecting the key-scans;
- Key-scans refined map returns SDF value and the final mesh is reconstructed by marching cube.
## Qualitative results
**The reconstructed maps**

*The qualitative result of our odometry mapping on the KITTI dataset. From left upper to right bottom, we list the results of sequences 00, 01, 03, 04, 05, 09, 10.*
**The odometry results**

*The qualitative results of our odometry on the KITTI dataset. From left to right, we list the results of sequences 00, 01, 03, 04, 05, 07, 09, 10. The dashed line corresponds to the ground truth and the blue line to our odometry method.*
## Data
1. Newer College real-world LiDAR dataset: [website](https://ori-drs.github.io/newer-college-dataset/download/).
2. MaiCity synthetic LiDAR dataset: [website](https://www.ipb.uni-bonn.de/data/mai-city-dataset/).
3. KITTI dataset: [website](https://www.cvlibs.net/datasets/kitti/).
## Environment Setup
To run the code, a GPU with large memory is preferred. We tested the code with RTX3090 and GTX TITAN.
We use Conda to create a virtual environment and install dependencies:
- python environment: We tested our code with Python 3.8.13
- [Pytorch](https://pytorch.org/get-started/locally/): The Version we tested is 1.10 with cuda10.2 (and cuda11.1)
- Other dependencies are specified in requirements.txt. You can then install all dependencies using `pip` or `conda`:
```
pip3 install -r requirements.txt
```
- After you have installed all third party libraries, run the following script to build extra Pytorch modules used in this project.
```bash
sh install.sh
```
- Replace the filename in mapping.py with the built library
```python
torch.classes.load_library("third_party/sparse_octree/build/lib.xxx/svo.xxx.so")
```
- [patchwork-plusplus](https://github.com/url-kaist/patchwork-plusplus) to separate ground from LiDAR points.
- Replace the filename in src/dataset/*.py with the built library
```python
patchwork_module_path ="/xxx/patchwork-plusplus/build/python_wrapper"
```
## Demo
- The full dataset can be downloaded as mentioned. You can also download the example part of dataset, use these [scripts](https://github.com/PRBonn/SHINE_mapping/tree/master/scripts) to download.
- Take maicity seq.01 dataset as example: Modify `configs/maicity/maicity_01.yaml` so the data_path section points to the real dataset path. Now you are all set to run the code:
```
python demo/run.py configs/maicity/maicity_01.yaml
```
## Note
- For the KITTI dataset, if you want to process it faster, you can switch to the branch `subscene`:
```
git checkout subscene
```
- Then run with `python demo/run.py configs/kitti/kitti_00.yaml`
- This branch cuts the full scene into subscenes to speed up processing and concatenates them together. This will certainly add map inconsistency and degrade tracking accuracy...
## Evaluation
- We follow the evaluation proposed [here](https://github.com/PRBonn/SHINE_mapping/tree/master/eval), but we did not use the `crop_intersection.py`
## Acknowledgement
Some of our codes are adapted from [Vox-Fusion](https://github.com/zju3dv/Vox-Fusion).
## Contact
Any questions or suggestions are welcome!
Junyuan Deng: d.juney@sjtu.edu.cn and Xieyuanli Chen: xieyuanli.chen@nudt.edu.cn
## License
This project is free software made available under the MIT License. For details see the LICENSE file.
================================================
FILE: configs/kitti/kitti.yaml
================================================
log_dir: './logs'
decoder: lidar
dataset: kitti
criteria:
sdf_weight: 10000.0
fs_weight: 1
eiko_weight: 0.1
sdf_truncation: 0.30
decoder_specs:
depth: 2
width: 256
in_dim: 16
skips: []
embedder: none
multires: 0
tracker_specs:
N_rays: 2048
learning_rate: 0.06
step_size: 0.2
max_voxel_hit: 20
num_iterations: 25
mapper_specs:
N_rays_each: 2048
use_local_coord: False
voxel_size: 0.3
step_size: 0.5
window_size: 4
num_iterations: 25
max_voxel_hit: 20
final_iter: True
mesh_res: 2
learning_rate_emb: 0.01
learning_rate_decorder: 0.005
learning_rate_pose: 0.001
freeze_frame: 5
keyframe_gap: 8
remove_back: False
key_distance: 12
debug_args:
verbose: False
mesh_freq: 100
================================================
FILE: configs/kitti/kitti_00.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sqeuence00
data_specs:
data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/00'
use_gt: False
max_depth: 40
min_depth: 5
tracker_specs:
start_frame: 0
end_frame: -1
read_offset: 1
================================================
FILE: configs/kitti/kitti_01.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sqeuence01
data_specs:
data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/01/'
use_gt: False
max_depth: 30
min_depth: 5
tracker_specs:
start_frame: 0
end_frame: 1101
read_offset: 1
================================================
FILE: configs/kitti/kitti_03.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sqeuence03
data_specs:
data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/03/'
use_gt: False
max_depth: 30
min_depth: 5
tracker_specs:
start_frame: 0
end_frame: 1101
read_offset: 1
================================================
FILE: configs/kitti/kitti_04.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sqeuence04
data_specs:
data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/04'
use_gt: False
max_depth: 50
min_depth: 2.75
tracker_specs:
start_frame: 0
end_frame: 270
read_offset: 1
================================================
FILE: configs/kitti/kitti_05.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sqeuence05
data_specs:
data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/05/'
use_gt: False
max_depth: 50
min_depth: 5
tracker_specs:
start_frame: 2299
end_frame: 2760
read_offset: 1
================================================
FILE: configs/kitti/kitti_06.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sqeuence06
data_specs:
data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/06'
use_gt: False
max_depth: 40
min_depth: 5
tracker_specs:
start_frame: 0
end_frame: -1
read_offset: 1
================================================
FILE: configs/kitti/kitti_07.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sqeuence07
data_specs:
data_path: '/home/evsjtu2/disk1/dengjunyuan/kitti/dataset/sequences/07'
use_gt: True
max_depth: 25
min_depth: 5
tracker_specs:
start_frame: 0
end_frame: 1100
read_offset: 1
================================================
FILE: configs/kitti/kitti_08.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sqeuence08
data_specs:
data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/08'
use_gt: False
max_depth: 40
min_depth: 5
tracker_specs:
start_frame: 0
end_frame: -1
read_offset: 1
================================================
FILE: configs/kitti/kitti_09.yaml
================================================
base_config: configs/kitti/kitti.yaml
exp_name: kitti/sqeuence09
data_specs:
data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/09'
use_gt: False
max_depth: 40
min_depth: 5
tracker_specs:
start_frame: 0
end_frame: -1
read_offset: 1
================================================
FILE: configs/kitti/kitti_10.yaml
================================================
base_config: configs/kitti/kitti_base10.yaml
exp_name: kitti/sqeuence10
data_specs:
data_path: '/home/pl21n4/dataset/kitti/dataset/sequences/10'
use_gt: False
max_depth: 70
min_depth: 2.75
tracker_specs:
start_frame: 0
end_frame: 1200
read_offset: 1
================================================
FILE: configs/kitti/kitti_base06.yaml
================================================
log_dir: './logs'
decoder: lidar
dataset: kitti
criteria:
sdf_weight: 10000.0
fs_weight: 1
eiko_weight: 0.1
sdf_truncation: 0.30
decoder_specs:
depth: 2
width: 256
in_dim: 16
skips: []
embedder: none
multires: 0
tracker_specs:
N_rays: 2048
learning_rate: 0.06
step_size: 0.2
max_voxel_hit: 20
num_iterations: 25
mapper_specs:
N_rays_each: 2048
use_local_coord: False
voxel_size: 0.3
step_size: 0.5
window_size: 4
num_iterations: 25
max_voxel_hit: 20
final_iter: True
mesh_res: 2
learning_rate_emb: 0.01
learning_rate_decorder: 0.005
learning_rate_pose: 0.001
freeze_frame: 5
keyframe_gap: 8
remove_back: False
key_distance: 12
debug_args:
verbose: False
mesh_freq: 100
================================================
FILE: configs/kitti/kitti_base10.yaml
================================================
log_dir: '/home/evsjtu2/disk1/dengjunyuan/running_logs/'
decoder: lidar
dataset: kitti
criteria:
depth_weight: 0
sdf_weight: 12000.0
fs_weight: 1
eiko_weight: 0
sdf_truncation: 0.50
decoder_specs:
depth: 2
width: 256
in_dim: 16
skips: []
embedder: none
multires: 0
tracker_specs:
N_rays: 2048
learning_rate: 0.1
start_frame: 0
end_frame: -1
step_size: 0.2
show_imgs: False
max_voxel_hit: 20
keyframe_freq: 10
num_iterations: 40
mapper_specs:
N_rays_each: 2048
num_embeddings: 20000000
use_local_coord: False
voxel_size: 0.2
step_size: 0.2
window_size: 4
num_iterations: 20
max_voxel_hit: 20
final_iter: True
mesh_res: 2
overlap_th: 0.8
learning_rate_emb: 0.03
learning_rate_decorder: 0.005
learning_rate_pose: 0.001
#max_depth_first: 20
freeze_frame: 10
keyframe_gap: 7
remove_back: True
key_distance: 7
debug_args:
verbose: False
mesh_freq: 100
================================================
FILE: configs/maicity/maicity.yaml
================================================
log_dir: './logs'
decoder: lidar
dataset: maicity
criteria:
sdf_weight: 10000.0
fs_weight: 1
eiko_weight: 0.1
sdf_truncation: 0.30
decoder_specs:
depth: 2
width: 256
in_dim: 16
skips: []
embedder: none
multires: 0
tracker_specs:
N_rays: 2048
learning_rate: 0.005
step_size: 0.2
max_voxel_hit: 20
num_iterations: 20
mapper_specs:
N_rays_each: 2048
use_local_coord: False
voxel_size: 0.2
step_size: 0.5
window_size: 4
num_iterations: 20
max_voxel_hit: 20
final_iter: True
mesh_res: 2
learning_rate_emb: 0.03
learning_rate_decorder: 0.005
learning_rate_pose: 0.001
freeze_frame: 5
keyframe_gap: 8
remove_back: False
key_distance: 12
debug_args:
verbose: False
mesh_freq: 100
================================================
FILE: configs/maicity/maicity_00.yaml
================================================
base_config: configs/maicity/maicity.yaml
exp_name: maicity/sqeuence00
data_specs:
data_path: '/home/pl21n4/dataset/mai_city/bin/sequences/00'
use_gt: False
max_depth: 50.0
min_depth: 1.5
tracker_specs:
start_frame: 0
end_frame: 699
read_offset: 1
================================================
FILE: configs/maicity/maicity_01.yaml
================================================
base_config: configs/maicity/maicity.yaml
exp_name: maicity/sqeuence01
data_specs:
data_path: '/home/pl21n4/dataset/mai_city/bin/sequences/01'
use_gt: False
max_depth: 50.0
min_depth: 1.5
tracker_specs:
start_frame: 0
end_frame: 99
read_offset: 1
================================================
FILE: configs/ncd/ncd.yaml
================================================
log_dir: './logs'
decoder: lidar
dataset: ncd
criteria:
sdf_weight: 10000.0
fs_weight: 1
eiko_weight: 1.0
sdf_truncation: 0.30
decoder_specs:
depth: 2
width: 256
in_dim: 16
skips: []
embedder: none
multires: 0
tracker_specs:
N_rays: 2048
learning_rate: 0.04
step_size: 0.1
max_voxel_hit: 20
num_iterations: 30
mapper_specs:
N_rays_each: 2048
use_local_coord: False
voxel_size: 0.2
step_size: 0.2
window_size: 5
num_iterations: 15
max_voxel_hit: 20
final_iter: True
mesh_res: 2
learning_rate_emb: 0.002
learning_rate_decorder: 0.005
learning_rate_pose: 0.001
freeze_frame: 20
keyframe_gap: 8
remove_back: False
key_distance: 20
debug_args:
verbose: False
mesh_freq: 500
================================================
FILE: configs/ncd/ncd_quad.yaml
================================================
base_config: configs/ncd/ncd.yaml
exp_name: ncd/quad
data_specs:
data_path: '/home/pl21n4/dataset/ncd_example/quad'
use_gt: False
max_depth: 50
min_depth: 1.5
tracker_specs:
start_frame: 0
end_frame: -1
read_offset: 5
================================================
FILE: demo/parser.py
================================================
import yaml
import argparse
class ArgumentParserX(argparse.ArgumentParser):
    """Argument parser that merges CLI flags with a (possibly chained) YAML config.

    The required positional ``config`` argument names a YAML file; every key
    found in it (after resolving the ``base_config`` inheritance chain) is
    registered as an optional ``--key`` flag, so values from the file become
    defaults that the command line can override.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Required positional: path to the YAML configuration file.
        self.add_argument("config", type=str)

    def parse_args(self, args=None, namespace=None):
        # First pass: extract only the config path, tolerating unknown flags.
        _args = self.parse_known_args(args, namespace)[0]
        file_args = argparse.Namespace()
        file_args = self.parse_config_yaml(_args.config, file_args)
        file_args = self.convert_to_namespace(file_args)
        # Register each config key as an optional flag whose default is the
        # config value; typed from the value so CLI overrides are converted.
        for ckey, cvalue in file_args.__dict__.items():
            try:
                self.add_argument('--' + ckey, type=type(cvalue),
                                  default=cvalue, required=False)
            except argparse.ArgumentError:
                # Flag already registered elsewhere (e.g. in get_parser); keep it.
                continue
        # Second pass: full parse so explicit CLI flags win over config defaults.
        _args = super().parse_args(args, namespace)
        return _args

    def parse_config_yaml(self, yaml_path, args=None):
        """Load a YAML file and recursively resolve its ``base_config`` chain.

        Returns the merged dict (this file's values win over the base's).
        Note: the ``args`` parameter is accepted but never used here.
        """
        with open(yaml_path, 'r') as f:
            configs = yaml.load(f, Loader=yaml.FullLoader)
        if configs is not None:
            base_config = configs.get('base_config')
            if base_config is not None:
                # Load the parent config first, then overlay this file on top.
                base_config = self.parse_config_yaml(configs["base_config"])
                if base_config is not None:
                    configs = self.update_recursive(base_config, configs)
                else:
                    raise FileNotFoundError("base_config specified but not found!")
        return configs

    def convert_to_namespace(self, dict_in, args=None):
        # Copy dict entries into a Namespace without clobbering existing keys.
        if args is None:
            args = argparse.Namespace()
        for ckey, cvalue in dict_in.items():
            if ckey not in args.__dict__.keys():
                args.__dict__[ckey] = cvalue
        return args

    def update_recursive(self, dict1, dict2):
        """Deep-merge ``dict2`` into ``dict1`` in place; dict2's leaves win.

        Nested dicts are merged key by key; any other value type replaces the
        existing entry outright.
        """
        for k, v in dict2.items():
            if k not in dict1:
                dict1[k] = dict()
            if isinstance(v, dict):
                self.update_recursive(dict1[k], v)
            else:
                dict1[k] = v
        return dict1
def get_parser():
    """Build the project-wide parser with the common optional flags attached."""
    p = ArgumentParserX()
    p.add_argument("--resume", default=None, type=str)
    p.add_argument("--debug", action='store_true')
    return p


if __name__ == '__main__':
    # Quick manual check: parse the CLI and show the merged arguments.
    print(ArgumentParserX().parse_args())
================================================
FILE: demo/run.py
================================================
import os # noqa
import sys # noqa
sys.path.insert(0, os.path.abspath('src')) # noqa
# Pin the process to GPU 0; set before torch is imported below so the CUDA
# runtime only ever sees that device.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import random
from parser import get_parser
import numpy as np
import torch
from nerfloam import nerfloam
import os
#os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
def setup_seed(seed):
    """Seed every RNG the pipeline draws from (python, numpy, torch CPU/CUDA)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
if __name__ == '__main__':
    args = get_parser().parse_args()
    # Use the config-provided seed when present, otherwise a fixed default,
    # so runs are reproducible either way.
    if hasattr(args, 'seeding'):
        setup_seed(args.seeding)
    else:
        setup_seed(777)
    # Launch the SLAM system and block until its child processes finish.
    slam = nerfloam(args)
    slam.start()
    slam.wait_child_processes()
================================================
FILE: install.sh
================================================
#!/bin/bash
cd third_party/marching_cubes
python setup.py install
cd ../sparse_octree
python setup.py install
cd ../sparse_voxels
python setup.py install
================================================
FILE: requirements.txt
================================================
matplotlib
open3d
opencv-python
PyYAML
scikit-image
tqdm
trimesh
pyyaml
================================================
FILE: src/criterion.py
================================================
import torch
import torch.nn as nn
from torch.autograd import grad
class Criterion(nn.Module):
    """SDF-based training loss for the LiDAR NeRF.

    Combines a free-space loss, a truncated-SDF loss and (optionally) an
    eikonal regularizer, weighted by the coefficients in ``args.criteria``.
    """

    def __init__(self, args) -> None:
        super().__init__()
        self.args = args
        self.eiko_weight = args.criteria["eiko_weight"]
        self.sdf_weight = args.criteria["sdf_weight"]
        self.fs_weight = args.criteria["fs_weight"]
        self.truncation = args.criteria["sdf_truncation"]
        # NOTE(review): attribute name has a typo ("dpeth"); kept as-is in
        # case code outside this file references it.
        self.max_dpeth = args.data_specs["max_depth"]

    def forward(self, outputs, obs, pointsCos, use_color_loss=True,
                use_depth_loss=True, compute_sdf_loss=True,
                weight_depth_loss=False, compute_eikonal_loss=False):
        """Compute the total loss and a dict of its scalar components.

        outputs: renderer outputs providing "sdf", "z_vals", "ray_mask",
            "valid_mask" and "sampled_xyz".
        obs: observed LiDAR points, one per ray.
        pointsCos: per-point cosine correction used to turn ray distances
            into distances along the surface normal.
        Returns (loss, loss_dict).
        NOTE(review): use_color_loss, use_depth_loss and weight_depth_loss
        are accepted but never read in this body.
        """
        points = obs
        loss = 0
        loss_dict = {}
        # pred_depth = outputs["depth"]
        pred_sdf = outputs["sdf"]
        z_vals = outputs["z_vals"]
        ray_mask = outputs["ray_mask"]
        valid_mask = outputs["valid_mask"]
        sampled_xyz = outputs["sampled_xyz"]
        # Keep only the rays that actually intersected voxels.
        gt_points = points[ray_mask]
        pointsCos = pointsCos[ray_mask]
        # Scale ray lengths by the incidence cosine (projects onto the normal).
        gt_distance = torch.norm(gt_points, 2, -1)
        gt_distance = gt_distance * pointsCos.view(-1)
        z_vals = z_vals * pointsCos.view(-1, 1)
        if compute_sdf_loss:
            fs_loss, sdf_loss, eikonal_loss = self.get_sdf_loss(
                z_vals, gt_distance, pred_sdf,
                truncation=self.truncation,
                loss_type='l2',
                valid_mask=valid_mask,
                compute_eikonal_loss=compute_eikonal_loss,
                points=sampled_xyz if compute_eikonal_loss else None
            )
            loss += self.fs_weight * fs_loss
            loss += self.sdf_weight * sdf_loss
            # loss += self.bs_weight * back_loss
            loss_dict["fs_loss"] = fs_loss.item()
            # loss_dict["bs_loss"] = back_loss.item()
            loss_dict["sdf_loss"] = sdf_loss.item()
            if compute_eikonal_loss:
                loss += self.eiko_weight * eikonal_loss
                loss_dict["eiko_loss"] = eikonal_loss.item()
        # NOTE(review): assumes compute_sdf_loss=True; otherwise ``loss`` is
        # still the plain int 0 and .item() would raise AttributeError.
        loss_dict["loss"] = loss.item()
        # print(loss_dict)
        return loss, loss_dict

    def compute_loss(self, x, y, mask=None, loss_type="l2"):
        # Masked mean L1/L2 distance between x and y. Returns None for any
        # other loss_type (falls off the end).
        if mask is None:
            mask = torch.ones_like(x).bool()
        if loss_type == "l1":
            return torch.mean(torch.abs(x - y)[mask])
        elif loss_type == "l2":
            return torch.mean(torch.square(x - y)[mask])

    def get_masks(self, z_vals, depth, epsilon):
        """Partition samples into free-space and near-surface sets.

        Returns (front_mask, sdf_mask, fs_weight, sdf_weight) where the two
        scalar weights rebalance the losses by relative sample counts.
        """
        # Samples more than epsilon in front of the measured surface.
        front_mask = torch.where(
            z_vals < (depth - epsilon),
            torch.ones_like(z_vals),
            torch.zeros_like(z_vals),
        )
        # Samples more than epsilon behind the surface (excluded from both losses).
        back_mask = torch.where(
            z_vals > (depth + epsilon),
            torch.ones_like(z_vals),
            torch.zeros_like(z_vals),
        )
        # Valid depth readings only.
        depth_mask = torch.where(
            (depth > 0.0) & (depth < self.max_dpeth), torch.ones_like(
                depth), torch.zeros_like(depth)
        )
        # Near-surface samples: within +/- epsilon of a valid depth.
        sdf_mask = (1.0 - front_mask) * (1.0 - back_mask) * depth_mask
        num_fs_samples = torch.count_nonzero(front_mask).float()
        num_sdf_samples = torch.count_nonzero(sdf_mask).float()
        num_samples = num_sdf_samples + num_fs_samples
        # Each weight is the complement of its class's share of the samples,
        # so the rarer class gets the larger weight.
        # NOTE(review): divides by zero if both masks are empty — confirm
        # callers guarantee at least one valid sample.
        fs_weight = 1.0 - num_fs_samples / num_samples
        sdf_weight = 1.0 - num_sdf_samples / num_samples
        return front_mask, sdf_mask, fs_weight, sdf_weight

    def get_sdf_loss(self, z_vals, depth, predicted_sdf, truncation, valid_mask, loss_type="l2", compute_eikonal_loss=False, points=None):
        """Free-space + truncated-SDF (+ optional eikonal) losses.

        z_vals: per-ray sample distances; depth: measured distance per ray;
        predicted_sdf: network SDF at each sample, in units of ``truncation``.
        Returns (fs_loss, sdf_loss, eikonal_loss) with eikonal_loss None when
        compute_eikonal_loss is False.
        """
        front_mask, sdf_mask, fs_weight, sdf_weight = self.get_masks(
            z_vals, depth.unsqueeze(-1).expand(*z_vals.shape), truncation
        )
        # Free-space samples should predict the maximal (normalized) SDF of 1.
        fs_loss = (self.compute_loss(predicted_sdf * front_mask * valid_mask, torch.ones_like(
            predicted_sdf) * front_mask, loss_type=loss_type,) * fs_weight)
        # Near-surface: sample depth + sdf*truncation should equal the
        # measured depth (predicted_sdf is scaled back to metric units here).
        sdf_loss = (self.compute_loss((z_vals + predicted_sdf * truncation) * sdf_mask * valid_mask,
                                      depth.unsqueeze(-1).expand(*z_vals.shape) * sdf_mask, loss_type=loss_type,) * sdf_weight)
        # back_loss = (self.compute_loss(predicted_sdf * back_mask, -torch.ones_like(
        #     predicted_sdf) * back_mask, loss_type=loss_type,) * back_weight)
        eikonal_loss = None
        if compute_eikonal_loss:
            # Gradient of the metric SDF w.r.t. sample positions should have
            # unit norm near the surface (eikonal property).
            sdf = (predicted_sdf*sdf_mask*truncation)
            sdf = sdf[valid_mask]
            d_points = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
            sdf_grad = grad(outputs=sdf,
                            inputs=points,
                            grad_outputs=d_points,
                            retain_graph=True,
                            only_inputs=True)[0]
            # NOTE(review): sdf_grad is already the gradient tensor; the extra
            # [0] keeps only its first slice — confirm this is intended rather
            # than penalizing the norm over all samples.
            eikonal_loss = self.compute_loss(sdf_grad[0].norm(2, -1), 1.0, loss_type=loss_type,)
        return fs_loss, sdf_loss, eikonal_loss
================================================
FILE: src/dataset/kitti.py
================================================
import os.path as osp
import numpy as np
import torch
from glob import glob
from torch.utils.data import Dataset
import sys
from scipy.spatial import cKDTree
patchwork_module_path ="/home/pl21n4/Programmes/patchwork-plusplus/build/python_wrapper"
sys.path.insert(0, patchwork_module_path)
import pypatchworkpp
# Module-level Patchwork++ ground-segmentation instance, shared by every
# DataLoader created from this module.
params = pypatchworkpp.Parameters()
# params.verbose = True
PatchworkPLUSPLUS = pypatchworkpp.patchworkpp(params)
class DataLoader(Dataset):
    """KITTI velodyne dataset: yields (index, points, pointcos, pose).

    Each scan is range-filtered and ground-segmented with Patchwork++;
    ``pointcos`` holds, for every ground point, |cos| of the angle between
    the point direction and its patch normal (1.0 for non-ground points).
    """

    def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1) -> None:
        self.data_path = data_path
        # Number of scans = number of .bin files in velodyne/.
        self.num_bin = len(glob(osp.join(self.data_path, "velodyne/*.bin")))
        self.use_gt = use_gt
        # -1 disables the corresponding range filter.
        self.max_depth = max_depth
        self.min_depth = min_depth
        self.gt_pose = self.load_gt_pose() if use_gt else None

    def get_init_pose(self, frame):
        # 4x4 initial pose: GT row (3x4) + homogeneous row, or identity.
        if self.gt_pose is not None:
            return np.concatenate((self.gt_pose[frame], [0, 0, 0, 1])
                                  ).reshape(4, 4)
        else:
            return np.eye(4)

    def load_gt_pose(self):
        # One 12-element row (flattened 3x4 pose) per scan, in LiDAR frame.
        gt_file = osp.join(self.data_path, "poses_lidar.txt")
        gt_pose = np.loadtxt(gt_file)
        return gt_pose

    def load_points(self, index):
        """Read scan ``index``, filter by range and run ground segmentation.

        Returns (points, pointcos) with ground points first, followed by
        non-ground points; pointcos aligns one weight per point.
        """
        remove_abnormal_z = True
        path = osp.join(self.data_path, "velodyne/{:06d}.bin".format(index))
        points = np.fromfile(path, dtype=np.float32, count=-1).reshape([-1, 4])
        if remove_abnormal_z:
            # Drop points far below the sensor (spurious returns).
            points = points[points[:, 2] > -3.0]
        points_norm = np.linalg.norm(points[:, :3], axis=-1)
        point_mask = True
        if self.max_depth != -1:
            point_mask = (points_norm < self.max_depth) & point_mask
        if self.min_depth != -1:
            point_mask = (points_norm > self.min_depth) & point_mask
        if isinstance(point_mask, np.ndarray):
            # Only index when at least one range filter was applied.
            points = points[point_mask]
        PatchworkPLUSPLUS.estimateGround(points)
        ground = PatchworkPLUSPLUS.getGround()
        nonground = PatchworkPLUSPLUS.getNonground()
        Patchcenters = PatchworkPLUSPLUS.getCenters()
        normals = PatchworkPLUSPLUS.getNormals()
        # Nearest patch center for each ground point.
        # NOTE(review): the query result shadows the ``index`` parameter.
        T = cKDTree(Patchcenters)
        _, index = T.query(ground)
        if True:
            # |cos| between each ground point's direction and its patch normal.
            groundcos = np.abs(np.sum(normals[index] * ground, axis=-1)/np.linalg.norm(ground, axis=-1))
        else:
            # Dead branch: uniform weights (kept for experimentation).
            groundcos = np.ones(ground.shape[0])
        points = np.concatenate((ground, nonground), axis=0)
        pointcos = np.concatenate((groundcos, np.ones(nonground.shape[0])), axis=0)
        return points, pointcos

    def __len__(self):
        return self.num_bin

    def __getitem__(self, index):
        points, pointcos = self.load_points(index)
        points = torch.from_numpy(points).float()
        pointcos = torch.from_numpy(pointcos).float()
        # Pose only when GT is enabled; otherwise the tracker estimates it.
        pose = np.concatenate((self.gt_pose[index], [0, 0, 0, 1])
                              ).reshape(4, 4) if self.use_gt else None
        return index, points, pointcos, pose
if __name__ == "__main__":
    # Smoke test: iterate the first few scans and print them.
    path = "/home/pl21n4/dataset/kitti/dataset/sequences/00/"
    loader = DataLoader(path)
    for data in loader:
        # Fix: __getitem__ returns 4 values; the previous 3-name unpacking
        # raised ValueError on the first iteration.
        index, points, pointcos, pose = data
        print("current index ", index)
        print("first 10th points:\n", points[:10])
        if index > 10:
            break
================================================
FILE: src/dataset/maicity.py
================================================
import os.path as osp
import numpy as np
import torch
from glob import glob
from torch.utils.data import Dataset
import sys
from scipy.spatial import cKDTree
patchwork_module_path ="/home/pl21n4/Programmes/patchwork-plusplus/build/python_wrapper"
sys.path.insert(0, patchwork_module_path)
import pypatchworkpp
# Module-level Patchwork++ ground-segmentation instance, shared by every
# DataLoader created from this module.
params = pypatchworkpp.Parameters()
# params.verbose = True
PatchworkPLUSPLUS = pypatchworkpp.patchworkpp(params)
class DataLoader(Dataset):
    """MaiCity velodyne dataset: yields (index, points, pointcos, pose).

    Same pipeline as the KITTI loader (range filter + Patchwork++ ground
    segmentation) but with 5-digit scan names and GT in poses.txt.
    """

    def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1) -> None:
        self.data_path = data_path
        # Number of scans = number of .bin files in velodyne/.
        self.num_bin = len(glob(osp.join(self.data_path, "velodyne/*.bin")))
        self.use_gt = use_gt
        # -1 disables the corresponding range filter.
        self.max_depth = max_depth
        self.min_depth = min_depth
        self.gt_pose = self.load_gt_pose() if use_gt else None

    def get_init_pose(self, frame):
        # 4x4 initial pose: GT row (3x4) + homogeneous row, or identity.
        if self.gt_pose is not None:
            return np.concatenate((self.gt_pose[frame], [0, 0, 0, 1])
                                  ).reshape(4, 4)
        else:
            return np.eye(4)

    def load_gt_pose(self):
        # One 12-element row (flattened 3x4 pose) per scan.
        gt_file = osp.join(self.data_path, "poses.txt")
        gt_pose = np.loadtxt(gt_file)
        return gt_pose

    def load_points(self, index):
        """Read scan ``index``, filter by range and run ground segmentation.

        Returns (points, pointcos) with ground points first, followed by
        non-ground points; pointcos aligns one weight per point.
        """
        path = osp.join(self.data_path, "velodyne/{:05d}.bin".format(index))
        points = np.fromfile(path, dtype=np.float32, count=-1).reshape([-1, 4])
        # points = points[:,:3]
        # points = np.delete(points, -1, axis=1)
        points_norm = np.linalg.norm(points[:, :3], axis=-1)
        point_mask = True
        if self.max_depth != -1:
            point_mask = (points_norm < self.max_depth) & point_mask
        if self.min_depth != -1:
            point_mask = (points_norm > self.min_depth) & point_mask
        if isinstance(point_mask, np.ndarray):
            # Only index when at least one range filter was applied.
            points = points[point_mask]
        PatchworkPLUSPLUS.estimateGround(points)
        ground = PatchworkPLUSPLUS.getGround()
        nonground = PatchworkPLUSPLUS.getNonground()
        Patchcenters = PatchworkPLUSPLUS.getCenters()
        normals = PatchworkPLUSPLUS.getNormals()
        # Nearest patch center for each ground point.
        # NOTE(review): the query result shadows the ``index`` parameter.
        T = cKDTree(Patchcenters)
        _, index = T.query(ground)
        if True:
            # |cos| between each ground point's direction and its patch normal.
            groundcos = np.abs(np.sum(normals[index] * ground, axis=-1)/np.linalg.norm(ground, axis=-1))
            # groundnorm = np.linalg.norm(ground, axis=-1)
            # groundcos = np.where(groundnorm > 10.0, np.ones(ground.shape[0]), groundcos)
        else:
            # Dead branch: uniform weights (kept for experimentation).
            groundcos = np.ones(ground.shape[0])
        points = np.concatenate((ground, nonground), axis=0)
        pointcos = np.concatenate((groundcos, np.ones(nonground.shape[0])), axis=0)
        return points, pointcos

    def __len__(self):
        return self.num_bin

    def __getitem__(self, index):
        points, pointcos = self.load_points(index)
        points = torch.from_numpy(points).float()
        pointcos = torch.from_numpy(pointcos).float()
        # Pose only when GT is enabled; otherwise the tracker estimates it.
        pose = np.concatenate((self.gt_pose[index], [0, 0, 0, 1])
                              ).reshape(4, 4) if self.use_gt else None
        return index, points, pointcos, pose
if __name__ == "__main__":
    # Smoke test: iterate the first few scans and print them.
    # NOTE(review): path points at a KITTI sequence, presumably a copy-paste
    # leftover — adjust to a MaiCity sequence when running this by hand.
    path = "/home/pl21n4/dataset/kitti/dataset/sequences/00/"
    loader = DataLoader(path)
    for data in loader:
        # Fix: __getitem__ returns 4 values; the previous 3-name unpacking
        # raised ValueError on the first iteration.
        index, points, pointcos, pose = data
        print("current index ", index)
        print("first 10th points:\n", points[:10])
        if index > 10:
            break
================================================
FILE: src/dataset/ncd.py
================================================
import os.path as osp
import numpy as np
import torch
import open3d as o3d
from glob import glob
from torch.utils.data import Dataset
import sys
from scipy.spatial import cKDTree
patchwork_module_path ="/home/pl21n4/Programmes/patchwork-plusplus/build/python_wrapper"
sys.path.insert(0, patchwork_module_path)
import pypatchworkpp
# Module-level Patchwork++ ground-segmentation instance, shared by every
# DataLoader created from this module.
params = pypatchworkpp.Parameters()
# NOTE(review): RNR disabled for NCD scans (presumably reflected-noise
# removal — confirm against the patchwork-plusplus docs).
params.enable_RNR = False
# params.verbose = True
PatchworkPLUSPLUS = pypatchworkpp.patchworkpp(params)
class DataLoader(Dataset):
    """Newer College Dataset loader: yields (index, points, pointcos, pose).

    Scans are read from ``pcd/`` (file names offset by 500), range-filtered
    and ground-segmented with Patchwork++; ``pointcos`` holds, for each
    ground point, |cos| of the angle between the point direction and its
    patch normal (1.0 for non-ground points).
    """

    def __init__(self, data_path, use_gt=False, max_depth=-1, min_depth=-1) -> None:
        self.data_path = data_path
        # Number of scans = number of .pcd files in pcd/.
        self.num_bin = len(glob(osp.join(self.data_path, "pcd/*.pcd")))
        self.use_gt = use_gt
        # -1 disables the corresponding range filter.
        self.max_depth = max_depth
        self.min_depth = min_depth
        self.gt_pose = self.load_gt_pose() if use_gt else None

    def get_init_pose(self, frame):
        """Return a 4x4 initial pose for ``frame``.

        Uses the ground-truth pose when loaded; otherwise a fixed extrinsic
        for the NCD quad sequence. Fix: the hard-coded fallback previously
        returned a 3x4 matrix, inconsistent with the GT branch and with the
        kitti/maicity loaders (both 4x4); the homogeneous row is appended.
        """
        if self.gt_pose is not None:
            return np.concatenate((self.gt_pose[frame], [0, 0, 0, 1])
                                  ).reshape(4, 4)
        return np.array([[5.925493285036220747e-01, -8.038419275143061649e-01, 5.218676416200035417e-02, -2.422443415414985424e-01],
                         [8.017167514002809803e-01, 5.948020209102693467e-01, 5.882863457495644127e-02, 3.667865561670570873e+00],
                         [-7.832971094540422397e-02, 6.980134849334420320e-03, 9.969030746023688216e-01, 6.809443654823238434e-01],
                         [0.0, 0.0, 0.0, 1.0]])

    def load_gt_pose(self):
        # One 12-element row (flattened 3x4 pose) per scan.
        gt_file = osp.join(self.data_path, "poses.txt")
        return np.loadtxt(gt_file)

    def load_points(self, index):
        """Read scan ``index``, filter by range and run ground segmentation.

        Returns (points, pointcos) with ground points first, followed by
        non-ground points; pointcos aligns one weight per point.
        """
        # NCD pcd files for this sequence start at 00500.
        path = osp.join(self.data_path, "pcd/{:05d}.pcd".format(index + 500))
        pc_load = o3d.io.read_point_cloud(path)
        points = np.asarray(pc_load.points)
        points_norm = np.linalg.norm(points, axis=-1)
        point_mask = True
        if self.max_depth != -1:
            point_mask = (points_norm < self.max_depth) & point_mask
        if self.min_depth != -1:
            point_mask = (points_norm > self.min_depth) & point_mask
        if isinstance(point_mask, np.ndarray):
            # Only index when at least one range filter was applied.
            points = points[point_mask]
        PatchworkPLUSPLUS.estimateGround(points)
        ground = PatchworkPLUSPLUS.getGround()
        nonground = PatchworkPLUSPLUS.getNonground()
        Patchcenters = PatchworkPLUSPLUS.getCenters()
        normals = PatchworkPLUSPLUS.getNormals()
        # Nearest patch center for each ground point (renamed from ``index``
        # to stop shadowing the method parameter).
        T = cKDTree(Patchcenters)
        _, nn_idx = T.query(ground)
        # |cos| between each ground point's direction and its patch normal.
        groundcos = np.abs(np.sum(normals[nn_idx] * ground, axis=-1)
                           / np.linalg.norm(ground, axis=-1))
        points = np.concatenate((ground, nonground), axis=0)
        pointcos = np.concatenate((groundcos, np.ones(nonground.shape[0])), axis=0)
        return points, pointcos

    def __len__(self):
        return self.num_bin

    def __getitem__(self, index):
        points, pointcos = self.load_points(index)
        points = torch.from_numpy(points).float()
        pointcos = torch.from_numpy(pointcos).float()
        # Pose only when GT is enabled; otherwise the tracker estimates it.
        pose = np.concatenate((self.gt_pose[index], [0, 0, 0, 1])
                              ).reshape(4, 4) if self.use_gt else None
        return index, points, pointcos, pose
if __name__ == "__main__":
    # Smoke test: iterate the loader and print the first few scans.
    path = "/home/pl21n4/dataset/kitti/dataset/sequences/00/"
    loader = DataLoader(path)
    for data in loader:
        # __getitem__ returns a 4-tuple (index, points, pointcos, pose);
        # the previous 3-way unpack raised ValueError on the first iteration.
        index, points, pointcos, pose = data
        print("current index ", index)
        print("first 10th points:\n", points[:10])
        if index > 10:
            break
        # (dead `index += 1` removed: the loop variable is rebound from the
        # iterator on every pass anyway)
================================================
FILE: src/lidarFrame.py
================================================
import torch
import torch.nn as nn
import numpy as np
from se3pose import OptimizablePose
from utils.sample_util import *
import random
class LidarFrame(nn.Module):
    """A single lidar scan: raw points, per-point weights, unit ray
    directions, and an (optionally optimizable) SE(3) pose."""

    def __init__(self, index, points, pointsCos, pose=None, new_keyframe=False) -> None:
        super().__init__()
        self.index = index
        self.points = points
        self.pointsCos = pointsCos
        self.num_point = len(points)
        if new_keyframe:
            # keyframes reuse an already-built pose object unchanged
            self.pose = pose
        elif pose is not None:
            # TODO: fix this offset
            pose[:3, 3] += 2000
            pose = torch.tensor(pose, requires_grad=True, dtype=torch.float32)
            self.pose = OptimizablePose.from_matrix(pose)
        # NOTE(review): when pose is None and new_keyframe is False,
        # self.pose is intentionally left unset — presumably assigned by
        # the tracker afterwards; confirm against callers.
        self.rays_d = self.get_rays()
        self.rel_pose = None

    def get_pose(self):
        return self.pose.matrix()

    def get_translation(self):
        return self.pose.translation()

    def get_rotation(self):
        return self.pose.rotation()

    def get_points(self):
        return self.points

    def get_pointsCos(self):
        return self.pointsCos

    def set_rel_pose(self, rel_pose):
        self.rel_pose = rel_pose

    def get_rel_pose(self):
        return self.rel_pose

    @torch.no_grad()
    def get_rays(self):
        """Unit direction of every point, shape (N, 1, 3); also caches the
        per-point range in ``self.rays_norm``."""
        norms = torch.norm(self.points, 2, -1, keepdim=True) + 1e-8
        self.rays_norm = norms
        # the extra singleton dim is kept only for downstream consistency
        return (self.points / norms).unsqueeze(1).float()

    @torch.no_grad()
    def sample_rays(self, N_rays, track=False):
        """Draw ``N_rays`` uniformly-weighted samples; the resulting mask is
        stored in ``self.sample_mask``."""
        uniform_w = torch.ones((self.num_point, 1))[None, ...]
        self.sample_mask = sample_rays(uniform_w, N_rays)[0, ...]
================================================
FILE: src/loggers.py
================================================
import os
import os.path as osp
import pickle
from datetime import datetime
import cv2
import matplotlib.pyplot as plt
import numpy as np
import open3d as o3d
import torch
import yaml
class BasicLogger:
    """Filesystem logger: creates a timestamped run directory under
    <log_dir>/<exp_name>/ with sub-folders for images, meshes, checkpoints,
    a config backup, and miscellaneous numpy/pickle dumps."""

    def __init__(self, args) -> None:
        self.args = args
        # logs/<exp_name>/<timestamp>/
        self.log_dir = osp.join(
            args.log_dir, args.exp_name, self.get_random_time_str())
        self.img_dir = osp.join(self.log_dir, "imgs")
        self.mesh_dir = osp.join(self.log_dir, "mesh")
        self.ckpt_dir = osp.join(self.log_dir, "ckpt")
        self.backup_dir = osp.join(self.log_dir, "bak")
        self.misc_dir = osp.join(self.log_dir, "misc")
        # makedirs without exist_ok: a duplicate timestamp raises, which
        # guarantees each run gets a fresh directory
        os.makedirs(self.img_dir)
        os.makedirs(self.ckpt_dir)
        os.makedirs(self.mesh_dir)
        os.makedirs(self.misc_dir)
        os.makedirs(self.backup_dir)
        self.log_config(args)

    def get_random_time_str(self):
        # not actually random: a second-resolution timestamp string
        return datetime.strftime(datetime.now(), "%Y-%m-%d-%H-%M-%S")

    def log_ckpt(self, mapper):
        """Save the mapper's decoder weights, embedding table and octree to
        ckpt/final_ckpt.pth (map_state is built but currently not saved)."""
        print("******* saving *******")
        decoder_state = {f: v.cpu()
                         for f, v in mapper.decoder.state_dict().items()}
        # computed but excluded from the checkpoint below (see commented key)
        map_state = {f: v.cpu() for f, v in mapper.map_states.items()}
        embeddings = mapper.dynamic_embeddings.cpu()
        svo = mapper.svo
        torch.save({
            "decoder_state": decoder_state,
            # "map_state": map_state,
            "embeddings": embeddings,
            "svo": svo
        },
            os.path.join(self.ckpt_dir, "final_ckpt.pth"))
        print("******* finish saving *******")

    def log_config(self, config):
        """Dump the run configuration to bak/config.yaml."""
        out_path = osp.join(self.backup_dir, "config.yaml")
        yaml.dump(config, open(out_path, 'w'))

    def log_mesh(self, mesh, name="final_mesh.ply"):
        """Write an open3d triangle mesh into the mesh directory."""
        out_path = osp.join(self.mesh_dir, name)
        o3d.io.write_triangle_mesh(out_path, mesh)

    def log_point_cloud(self, pcd, name="final_points.ply"):
        """Write an open3d point cloud into the mesh directory."""
        out_path = osp.join(self.mesh_dir, name)
        o3d.io.write_point_cloud(out_path, pcd)

    def log_numpy_data(self, data, name, ind=None):
        """Save an array (tensor accepted) as misc/<name>[-<ind>].npy; the
        un-indexed variant is additionally converted to a txt pose file."""
        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()
        if ind is not None:
            np.save(osp.join(self.misc_dir, "{}-{:05d}.npy".format(name, ind)), data)
        else:
            np.save(osp.join(self.misc_dir, f"{name}.npy"), data)
            # NOTE(review): only the ind=None branch also writes the txt
            # export — confirm this asymmetry is intentional
            self.npy2txt(osp.join(self.misc_dir, f"{name}.npy"), osp.join(self.misc_dir, f"{name}.txt"))

    def log_debug_data(self, data, idx):
        """Pickle an arbitrary debug payload to misc/scene_data_<idx>.pkl."""
        with open(os.path.join(self.misc_dir, f"scene_data_{idx}.pkl"), 'wb') as f:
            pickle.dump(data, f)

    def log_raw_image(self, ind, rgb, depth):
        """Write rgb (as 8-bit jpg) and depth (scaled x5000, 16-bit png)."""
        if isinstance(rgb, torch.Tensor):
            rgb = rgb.detach().cpu().numpy()
        if isinstance(depth, torch.Tensor):
            depth = depth.detach().cpu().numpy()
        rgb = cv2.cvtColor(rgb*255, cv2.COLOR_RGB2BGR)
        cv2.imwrite(osp.join(self.img_dir, "{:05d}.jpg".format(
            ind)), (rgb).astype(np.uint8))
        cv2.imwrite(osp.join(self.img_dir, "{:05d}.png".format(
            ind)), (depth*5000).astype(np.uint16))

    def log_images(self, ind, gt_rgb, gt_depth, rgb, depth):
        """Save a 2x3 comparison figure: input/generated/residual for depth
        (top row) and RGB (bottom row), as imgs/<ind>.jpg."""
        gt_depth_np = gt_depth.detach().cpu().numpy()
        gt_color_np = gt_rgb.detach().cpu().numpy()
        depth_np = depth.squeeze().detach().cpu().numpy()
        color_np = rgb.detach().cpu().numpy()
        h, w = depth_np.shape
        # resize ground truth to the rendered resolution before differencing
        gt_depth_np = cv2.resize(
            gt_depth_np, (w, h), interpolation=cv2.INTER_NEAREST)
        gt_color_np = cv2.resize(
            gt_color_np, (w, h), interpolation=cv2.INTER_AREA)
        depth_residual = np.abs(gt_depth_np - depth_np)
        # mask out pixels with no ground-truth depth
        depth_residual[gt_depth_np == 0.0] = 0.0
        color_residual = np.abs(gt_color_np - color_np)
        color_residual[gt_depth_np == 0.0] = 0.0
        fig, axs = plt.subplots(2, 3)
        fig.tight_layout()
        max_depth = np.max(gt_depth_np)
        axs[0, 0].imshow(gt_depth_np, cmap="plasma",
                         vmin=0, vmax=max_depth)
        axs[0, 0].set_title('Input Depth')
        axs[0, 0].set_xticks([])
        axs[0, 0].set_yticks([])
        axs[0, 1].imshow(depth_np, cmap="plasma",
                         vmin=0, vmax=max_depth)
        axs[0, 1].set_title('Generated Depth')
        axs[0, 1].set_xticks([])
        axs[0, 1].set_yticks([])
        axs[0, 2].imshow(depth_residual, cmap="plasma",
                         vmin=0, vmax=max_depth)
        axs[0, 2].set_title('Depth Residual')
        axs[0, 2].set_xticks([])
        axs[0, 2].set_yticks([])
        gt_color_np = np.clip(gt_color_np, 0, 1)
        color_np = np.clip(color_np, 0, 1)
        color_residual = np.clip(color_residual, 0, 1)
        axs[1, 0].imshow(gt_color_np, cmap="plasma")
        axs[1, 0].set_title('Input RGB')
        axs[1, 0].set_xticks([])
        axs[1, 0].set_yticks([])
        axs[1, 1].imshow(color_np, cmap="plasma")
        axs[1, 1].set_title('Generated RGB')
        axs[1, 1].set_xticks([])
        axs[1, 1].set_yticks([])
        axs[1, 2].imshow(color_residual, cmap="plasma")
        axs[1, 2].set_title('RGB Residual')
        axs[1, 2].set_xticks([])
        axs[1, 2].set_yticks([])
        plt.subplots_adjust(wspace=0, hspace=0)
        plt.savefig(osp.join(self.img_dir, "{:05d}.jpg".format(
            ind)), bbox_inches='tight', pad_inches=0.2)
        plt.clf()
        plt.close()

    def npy2txt(self, input_path, output_path):
        """Convert a (K, R, C) pose array to text: one pose per line, the
        last matrix row skipped (KITTI-style 3x4 export)."""
        poses = np.load(input_path)
        with open(output_path, mode='w') as w:
            shape = poses.shape
            print(shape)
            for i in range(shape[0]):
                one_pose = str()
                for j in range(shape[1]):
                    # skip the final (homogeneous) row of each pose
                    if j == (shape[1]-1):
                        continue
                    for k in range(shape[2]):
                        # NOTE(review): `shape[1]-1` is used as the last
                        # column index; for square 4x4 poses this equals
                        # shape[2]-1 and works, but it is fragile for
                        # non-square inputs — confirm intended shape
                        if j == (shape[1]-2) and k == (shape[1]-1):
                            one_pose += (str(poses[i][j][k])+"\n")
                        else:
                            one_pose += (str(poses[i][j][k])+" ")
                w.write(one_pose)
================================================
FILE: src/mapping.py
================================================
from copy import deepcopy
import random
from time import sleep
import numpy as np
from tqdm import tqdm
import torch
from criterion import Criterion
from loggers import BasicLogger
from utils.import_util import get_decoder, get_property
from variations.render_helpers import bundle_adjust_frames
from utils.mesh_util import MeshExtractor
from utils.profile_util import Profiler
from lidarFrame import LidarFrame
import torch.nn.functional as F
from pathlib import Path
import open3d as o3d
# Load the compiled sparse-octree extension, registering torch.classes.svo.
# NOTE(review): hard-coded absolute path to a per-machine Vox-Fusion build
# (also tied to Python 3.8 / x86_64) — make configurable.
torch.classes.load_library(
    "/home/pl21n4/Programmes/Vox-Fusion/third_party/sparse_octree/build/lib.linux-x86_64-cpython-38/svo.cpython-38-x86_64-linux-gnu.so")
def get_network_size(net):
    """Total parameter footprint of ``net`` in MiB (element_size * numel,
    summed over all parameters)."""
    total_bytes = sum(p.element_size() * p.numel() for p in net.parameters())
    return total_bytes / 1024 / 1024
class Mapping:
    """Mapping back-end: consumes tracked LidarFrames from a queue, maintains
    a sparse voxel octree with per-vertex embeddings, jointly optimizes them
    with the SDF decoder, and periodically extracts meshes and poses."""

    def __init__(self, args, logger: BasicLogger):
        super().__init__()
        self.args = args
        self.logger = logger
        # SDF decoder network, optimized together with the voxel embeddings
        self.decoder = get_decoder(args).cuda()
        print(self.decoder)
        self.loss_criteria = Criterion(args)
        self.keyframe_graph = []
        self.initialized = False
        mapper_specs = args.mapper_specs
        # optional args
        self.ckpt_freq = get_property(args, "ckpt_freq", -1)
        self.final_iter = get_property(mapper_specs, "final_iter", False)
        self.mesh_res = get_property(mapper_specs, "mesh_res", 8)
        self.save_data_freq = get_property(
            args.debug_args, "save_data_freq", 0)
        # required args
        self.voxel_size = mapper_specs["voxel_size"]
        self.window_size = mapper_specs["window_size"]
        self.num_iterations = mapper_specs["num_iterations"]
        self.n_rays = mapper_specs["N_rays_each"]
        self.sdf_truncation = args.criteria["sdf_truncation"]
        self.max_voxel_hit = mapper_specs["max_voxel_hit"]
        self.step_size = mapper_specs["step_size"]
        self.learning_rate_emb = mapper_specs["learning_rate_emb"]
        self.learning_rate_decorder = mapper_specs["learning_rate_decorder"]
        self.learning_rate_pose = mapper_specs["learning_rate_pose"]
        # step size is specified in voxel units in the config
        self.step_size = self.step_size * self.voxel_size
        self.max_distance = args.data_specs["max_depth"]
        self.freeze_frame = mapper_specs["freeze_frame"]
        self.keyframe_gap = mapper_specs["keyframe_gap"]
        self.remove_back = mapper_specs["remove_back"]
        self.key_distance = mapper_specs["key_distance"]
        embed_dim = args.decoder_specs["in_dim"]
        use_local_coord = mapper_specs["use_local_coord"]
        # when local coordinates are appended to the decoder input, the
        # embedding itself is 3 dims smaller
        self.embed_dim = embed_dim - 3 if use_local_coord else embed_dim
        # num_embeddings = mapper_specs["num_embeddings"]
        self.mesh_freq = args.debug_args["mesh_freq"]
        self.mesher = MeshExtractor(args)
        # maps global voxel-vertex id -> row in dynamic_embeddings (-1 = absent)
        self.voxel_id2embedding_id = -torch.ones((int(2e9), 1), dtype=torch.int)
        self.embeds_exist_search = dict()
        self.current_num_embeds = 0
        # lazily grown table of per-vertex embeddings (see get_embeddings)
        self.dynamic_embeddings = None
        self.svo = torch.classes.svo.Octree()
        self.svo.init(256*256*4, embed_dim, self.voxel_size)
        self.frame_poses = []
        self.depth_maps = []
        self.last_tracked_frame_id = 0
        self.final_poses = []
        verbose = get_property(args.debug_args, "verbose", False)
        self.profiler = Profiler(verbose=verbose)
        self.profiler.enable()

    def spin(self, share_data, kf_buffer):
        """Main mapping loop (runs in its own process): pull frames from
        kf_buffer, optimize, grow the octree, manage keyframes, and on
        shutdown run optional post-processing and final mesh extraction."""
        print("mapping process started!!!!!!!!!")
        while True:
            torch.cuda.empty_cache()
            if not kf_buffer.empty():
                tracked_frame = kf_buffer.get()
                # self.create_voxels(tracked_frame)
                if not self.initialized:
                    # first frame: seed the map and keep optimizing it until
                    # the tracker delivers the next frame
                    self.first_frame_id = tracked_frame.index
                    if self.mesher is not None:
                        self.mesher.rays_d = tracked_frame.get_rays()
                    self.create_voxels(tracked_frame)
                    self.insert_keyframe(tracked_frame)
                    while kf_buffer.empty():
                        self.do_mapping(share_data, tracked_frame, selection_method='current')
                    self.initialized = True
                else:
                    if self.remove_back:
                        # drop far points behind the direction of travel
                        tracked_frame = self.remove_back_points(tracked_frame)
                    self.do_mapping(share_data, tracked_frame)
                    self.create_voxels(tracked_frame)
                    # promote to keyframe once we moved far enough from the
                    # current keyframe
                    if (torch.norm(tracked_frame.pose.translation().cpu()
                                   - self.current_keyframe.pose.translation().cpu())) > self.keyframe_gap:
                        self.insert_keyframe(tracked_frame)
                        print(
                            f"********** current num kfs: { len(self.keyframe_graph) } **********")
                # self.create_voxels(tracked_frame)
                # record this frame's pose relative to the newest keyframe
                tracked_pose = tracked_frame.get_pose().detach()
                ref_pose = self.current_keyframe.get_pose().detach()
                rel_pose = torch.linalg.inv(ref_pose) @ tracked_pose
                self.frame_poses += [(len(self.keyframe_graph) -
                                      1, rel_pose.cpu())]
                if self.mesh_freq > 0 and (tracked_frame.index) % self.mesh_freq == 0:
                    # periodic intermediate mesh/pose dump, with optional
                    # extra optimization passes once enough keyframes exist
                    if self.final_iter and len(self.keyframe_graph) > 20:
                        print(f"********** post-processing steps **********")
                        # self.num_iterations = 1
                        final_num_iter = len(self.keyframe_graph) + 1
                        progress_bar = tqdm(
                            range(0, final_num_iter), position=0)
                        progress_bar.set_description(" post-processing steps")
                        for iter in progress_bar:
                            # tracked_frame=self.keyframe_graph[iter//self.window_size]
                            self.do_mapping(share_data, tracked_frame=None,
                                            update_pose=False, update_decoder=False, selection_method='random')
                    self.logger.log_mesh(self.extract_mesh(res=self.mesh_res, clean_mesh=False), name=f"mesh_{tracked_frame.index:05d}.ply")
                    pose = self.get_updated_poses()
                    self.logger.log_numpy_data(np.asarray(pose), f"frame_poses_{tracked_frame.index:05d}")
                    if self.final_iter and len(self.keyframe_graph) > 20:
                        # after replay, keep only the current keyframe
                        self.keyframe_graph = []
                        self.keyframe_graph += [self.current_keyframe]
                # NOTE(review): LidarFrame defines .index, not .stamp —
                # this branch looks like it would raise if save_data_freq > 0;
                # confirm
                if self.save_data_freq > 0 and (tracked_frame.stamp + 1) % self.save_data_freq == 0:
                    self.save_debug_data(tracked_frame)
            elif share_data.stop_mapping:
                break
        print("******* extracting mesh without replay *******")
        self.logger.log_mesh(self.extract_mesh(res=self.mesh_res, clean_mesh=False), name="final_mesh_noreplay.ply")
        if self.final_iter:
            print(f"********** post-processing steps **********")
            # self.num_iterations = 1
            final_num_iter = len(self.keyframe_graph) + 1
            progress_bar = tqdm(
                range(0, final_num_iter), position=0)
            progress_bar.set_description(" post-processing steps")
            for iter in progress_bar:
                # NOTE(review): with window_size == 1 the last iteration
                # indexes past the end of keyframe_graph — verify bounds
                tracked_frame = self.keyframe_graph[iter//self.window_size]
                self.do_mapping(share_data, tracked_frame=None,
                                update_pose=False, update_decoder=False, selection_method='random')
        print("******* extracting final mesh *******")
        pose = self.get_updated_poses()
        self.logger.log_numpy_data(np.asarray(pose), "frame_poses")
        self.logger.log_mesh(self.extract_mesh(res=self.mesh_res, clean_mesh=False))
        print("******* mapping process died *******")

    def do_mapping(self, share_data, tracked_frame=None,
                   update_pose=True, update_decoder=True, selection_method='current'):
        """One bundle-adjustment round over the selected frames, then publish
        the updated decoder/map state to the shared data object."""
        self.profiler.tick("do_mapping")
        self.decoder.train()
        optimize_targets = self.select_optimize_targets(tracked_frame, selection_method=selection_method)
        torch.cuda.empty_cache()
        self.profiler.tick("bundle_adjust_frames")
        bundle_adjust_frames(
            optimize_targets,
            self.dynamic_embeddings,
            self.map_states,
            self.decoder,
            self.loss_criteria,
            self.voxel_size,
            self.step_size,
            # random (replay) rounds use twice the ray budget
            self.n_rays * 2 if selection_method == 'random' else self.n_rays,
            self.num_iterations,
            self.sdf_truncation,
            self.max_voxel_hit,
            self.max_distance,
            learning_rate=[self.learning_rate_emb,
                           self.learning_rate_decorder,
                           self.learning_rate_pose],
            update_pose=update_pose,
            # decoder weights are frozen after `freeze_frame` frames
            update_decoder=update_decoder if tracked_frame == None or (tracked_frame.index - self.first_frame_id) < self.freeze_frame else False,
            profiler=self.profiler
        )
        self.profiler.tok("bundle_adjust_frames")
        # optimize_targets = [f.cpu() for f in optimize_targets]
        self.update_share_data(share_data)
        self.profiler.tok("do_mapping")
        # sleep(0.01)

    def select_optimize_targets(self, tracked_frame=None, selection_method='previous'):
        """Pick the frames to optimize this round: 'current' = just the
        tracked frame; otherwise a window of keyframes ('random' sample or
        the most recent 'previous' ones) plus the tracked frame."""
        # TODO: better ways
        targets = []
        if selection_method == 'current':
            if tracked_frame == None:
                raise ValueError('select one track frame')
            else:
                return [tracked_frame]
        if len(self.keyframe_graph) <= self.window_size:
            # not enough keyframes yet: take them all
            targets = self.keyframe_graph[:]
        elif selection_method == 'random':
            targets = random.sample(self.keyframe_graph, self.window_size)
        elif selection_method == 'previous':
            targets = self.keyframe_graph[-self.window_size:]
        elif selection_method == 'overlap':
            # NOTE(review): 'overlap' selection is not implemented; this
            # branch raises (and the message text says "unknown")
            raise NotImplementedError(
                f"seletion method {selection_method} unknown")
        if tracked_frame is not None and tracked_frame != self.current_keyframe:
            targets += [tracked_frame]
        return targets

    def update_share_data(self, share_data, frameid=None):
        """Publish a deep copy of the decoder and detached CPU map states to
        the cross-process share object."""
        share_data.decoder = deepcopy(self.decoder)
        tmp_states = {}
        for k, v in self.map_states.items():
            tmp_states[k] = v.detach().cpu()
        share_data.states = tmp_states
        # self.last_tracked_frame_id = frameid

    def remove_back_points(self, frame):
        """Return a copy of `frame` with far points roughly opposite the
        direction of motion removed (cos >= 0.7 against the backward motion
        direction AND range > key_distance)."""
        rel_pose = frame.get_rel_pose()
        points = frame.get_points()
        points_norm = torch.norm(points, 2, -1)
        points_xy = points[:, :2]
        if rel_pose == None:
            # no motion estimate yet: assume motion along +x
            x = 1
            y = 0
        else:
            x = rel_pose[0, 3]
            y = rel_pose[1, 3]
        rel_xy = torch.ones((1, 2))
        rel_xy[0, 0] = x
        rel_xy[0, 1] = y
        # cosine between each point's xy direction and the *negated* motion
        point_cos = torch.sum(-points_xy * rel_xy, dim=-1)/(
            torch.norm(points_xy, 2, -1)*(torch.norm(rel_xy, 2, -1)))
        remove_index = ((point_cos >= 0.7) & (points_norm > self.key_distance))
        new_points = frame.points[~remove_index]
        new_cos = frame.get_pointsCos()[~remove_index]
        return LidarFrame(frame.index, new_points, new_cos,
                          frame.pose, new_keyframe=True)

    def frame_maxdistance_change(self, frame, distance):
        """Rebuild `frame` keeping only points within `distance` (+0.5 slack)
        of the sensor."""
        # kf check
        valid_distance = distance + 0.5
        new_keyframe_rays_norm = frame.rays_norm.reshape(-1)
        new_keyframe_points = frame.points[new_keyframe_rays_norm <= valid_distance]
        new_keyframe_pointsCos = frame.get_pointsCos()[new_keyframe_rays_norm <= valid_distance]
        return LidarFrame(frame.index, new_keyframe_points, new_keyframe_pointsCos,
                          frame.pose, new_keyframe=True)

    def insert_keyframe(self, frame, valid_distance=-1):
        """Crop `frame` to an axis-aligned cube of half-size ~key_distance
        and store the result as the new current keyframe.

        NOTE(review): the `valid_distance` parameter is immediately
        overwritten below and therefore has no effect."""
        # kf check
        print("insert keyframe")
        valid_distance = self.key_distance + 0.01
        # computed but not used below (the mask is per-axis, not by range)
        new_keyframe_rays_norm = frame.rays_norm.reshape(-1)
        mask = (torch.abs(frame.points[:, 0]) < valid_distance) & (torch.abs(frame.points[:, 1])
                                                                   < valid_distance) & (torch.abs(frame.points[:, 2]) < valid_distance)
        new_keyframe_points = frame.points[mask]
        new_keyframe_pointsCos = frame.get_pointsCos()[mask]
        new_keyframe = LidarFrame(frame.index, new_keyframe_points, new_keyframe_pointsCos,
                                  frame.pose, new_keyframe=True)
        # keyframes must keep enough points to sample rays from
        if new_keyframe_points.shape[0] < 2*self.n_rays:
            raise ValueError('valid_distance too small')
        self.current_keyframe = new_keyframe
        self.keyframe_graph += [new_keyframe]
        # self.update_grid_features()

    def create_voxels(self, frame):
        """Transform the frame's points into world coordinates and insert the
        voxels they occupy into the sparse octree, then refresh map_states."""
        points = frame.get_points().cuda()
        pose = frame.get_pose().cuda()
        print("frame id", frame.index+1)
        # stored poses carry a +2000 translation offset (see LidarFrame)
        print("trans ", pose[:3, 3]-2000)
        points = points@pose[:3, :3].transpose(-1, -2) + pose[:3, 3]
        voxels = torch.div(points, self.voxel_size, rounding_mode='floor')
        self.svo.insert(voxels.cpu().int())
        self.update_grid_features()

    @torch.enable_grad()
    def get_embeddings(self, points_idx):
        """Ensure every referenced voxel-vertex id has a row in the dynamic
        embedding table, growing it (zero-initialized bfloat16 rows) for ids
        seen for the first time; updates voxel_id2embedding_id accordingly."""
        flatten_idx = points_idx.reshape(-1).long()
        # -1 marks empty vertex slots
        valid_flatten_idx = flatten_idx[flatten_idx.ne(-1)]
        existence = F.embedding(valid_flatten_idx, self.voxel_id2embedding_id)
        torch_add_idx = existence.eq(-1).view(-1)
        torch_add = valid_flatten_idx[torch_add_idx]
        if torch_add.shape[0] == 0:
            return
        start_num = self.current_num_embeds
        end_num = start_num + torch_add.shape[0]
        embeddings_add = torch.zeros((end_num-start_num, self.embed_dim),
                                     dtype=torch.bfloat16)
        # torch.nn.init.normal_(embeddings_add, std=0.01)
        if self.dynamic_embeddings == None:
            embeddings = [embeddings_add]
        else:
            embeddings = [self.dynamic_embeddings.detach().cpu(), embeddings_add]
        embeddings = torch.cat(embeddings, dim=0)
        # rebuilt table becomes the new optimizable leaf tensor
        self.dynamic_embeddings = embeddings.cuda().requires_grad_()
        self.current_num_embeds = end_num
        self.voxel_id2embedding_id[torch_add] = torch.arange(start_num, end_num, dtype=torch.int).view(-1, 1)

    @torch.enable_grad()
    def update_grid_features(self):
        """Refresh self.map_states (voxel centres, structure, vertex
        embeddings) from the current octree contents."""
        voxels, children, features = self.svo.get_centres_and_children()
        # voxel centre = corner + half side length, scaled to metric units
        centres = (voxels[:, :3] + voxels[:, -1:] / 2) * self.voxel_size
        children = torch.cat([children, voxels[:, -1:]], -1)
        centres = centres.float()
        children = children.int()
        map_states = {}
        map_states["voxel_vertex_idx"] = features
        centres.requires_grad_()
        map_states["voxel_center_xyz"] = centres
        map_states["voxel_structure"] = children
        self.profiler.tick("Creating embedding")
        self.get_embeddings(map_states["voxel_vertex_idx"])
        self.profiler.tok("Creating embedding")
        map_states["voxel_vertex_emb"] = self.dynamic_embeddings
        map_states["voxel_id2embedding_id"] = self.voxel_id2embedding_id
        self.map_states = map_states

    @torch.no_grad()
    def get_updated_poses(self, offset=-2000):
        """Convert the accumulated keyframe-relative poses into world poses
        (undoing the +2000 translation offset); drains self.frame_poses into
        self.final_poses and returns the full list."""
        for i in range(len(self.frame_poses)):
            ref_frame_ind, rel_pose = self.frame_poses[i]
            ref_frame = self.keyframe_graph[ref_frame_ind]
            ref_pose = ref_frame.get_pose().detach().cpu()
            pose = ref_pose @ rel_pose
            pose[:3, 3] += offset
            self.final_poses += [pose.detach().cpu().numpy()]
        self.frame_poses = []
        return self.final_poses

    @torch.no_grad()
    def extract_mesh(self, res=8, clean_mesh=False):
        """Run the mesher over all voxels whose vertex features are fully
        populated and return the resulting mesh."""
        sdf_network = self.decoder
        sdf_network.eval()
        voxels, _, features = self.svo.get_centres_and_children()
        # drop voxels with any missing (-1) vertex feature
        index = features.eq(-1).any(-1)
        voxels = voxels[~index, :]
        features = features[~index, :]
        centres = (voxels[:, :3] + voxels[:, -1:] / 2) * self.voxel_size
        encoder_states = {}
        encoder_states["voxel_vertex_idx"] = features
        encoder_states["voxel_center_xyz"] = centres
        self.profiler.tick("Creating embedding")
        self.get_embeddings(encoder_states["voxel_vertex_idx"])
        self.profiler.tok("Creating embedding")
        encoder_states["voxel_vertex_emb"] = self.dynamic_embeddings
        encoder_states["voxel_id2embedding_id"] = self.voxel_id2embedding_id
        # called for its side effect of draining frame_poses into
        # final_poses; the result itself is not handed to the mesher
        frame_poses = self.get_updated_poses()
        mesh = self.mesher.create_mesh(
            self.decoder, encoder_states, self.voxel_size, voxels,
            frame_poses=None, depth_maps=None,
            # `clean_mseh` matches the callee's parameter spelling
            clean_mseh=clean_mesh, require_color=False, offset=-2000, res=res)
        return mesh

    @torch.no_grad()
    def extract_voxels(self, offset=-10):
        """Return metric-space centres of all fully-populated voxels,
        shifted by `offset`."""
        voxels, _, features = self.svo.get_centres_and_children()
        index = features.eq(-1).any(-1)
        voxels = voxels[~index, :]
        features = features[~index, :]
        voxels = (voxels[:, :3] + voxels[:, -1:] / 2) * \
            self.voxel_size + offset
        # print(torch.max(features)-torch.count_nonzero(index))
        return voxels

    @torch.no_grad()
    def save_debug_data(self, tracked_frame, offset=-10):
        """
        save per-frame voxel, mesh and pose
        """
        pose = tracked_frame.get_pose().detach().cpu().numpy()
        pose[:3, 3] += offset
        frame_poses = self.get_updated_poses()
        mesh = self.extract_mesh(res=8, clean_mesh=True)
        voxels = self.extract_voxels().detach().cpu().numpy()
        keyframe_poses = [p.get_pose().detach().cpu().numpy()
                          for p in self.keyframe_graph]
        for f in frame_poses:
            f[:3, 3] += offset
        for kf in keyframe_poses:
            kf[:3, 3] += offset
        verts = np.asarray(mesh.vertices)
        faces = np.asarray(mesh.triangles)
        color = np.asarray(mesh.vertex_colors)
        # NOTE(review): LidarFrame defines .index, not .stamp — confirm
        # tracked_frame.stamp exists before enabling save_data_freq
        self.logger.log_debug_data({
            "pose": pose,
            "updated_poses": frame_poses,
            "mesh": {"verts": verts, "faces": faces, "color": color},
            "voxels": voxels,
            "voxel_size": self.voxel_size,
            "keyframes": keyframe_poses,
            "is_keyframe": (tracked_frame == self.current_keyframe)
        }, tracked_frame.stamp)
================================================
FILE: src/nerfloam.py
================================================
from multiprocessing.managers import BaseManager
from time import sleep
import torch
import torch.multiprocessing as mp
from loggers import BasicLogger
from mapping import Mapping
from share import ShareData, ShareDataProxy
from tracking import Tracking
from utils.import_util import get_dataset
class nerfloam:
    """Top-level driver: spawns the mapping and tracking processes and wires
    them together through a managed ShareData object and a 1-slot keyframe
    queue."""

    def __init__(self, args):
        self.args = args
        # logger (optional)
        self.logger = BasicLogger(args)
        # shared data: ShareData lives in a manager process, both workers
        # talk to it through the registered proxy
        mp.set_start_method('spawn', force=True)
        BaseManager.register('ShareData', ShareData, ShareDataProxy)
        manager = BaseManager()
        manager.start()
        self.share_data = manager.ShareData()
        # keyframe buffer; maxsize=1 makes the tracker block until the
        # mapper has consumed the previous frame
        self.kf_buffer = mp.Queue(maxsize=1)
        # data stream
        self.data_stream = get_dataset(args)
        # tracker
        self.tracker = Tracking(args, self.data_stream, self.logger)
        # mapper
        self.mapper = Mapping(args, self.logger)
        # initialize map with first frame
        self.tracker.process_first_frame(self.kf_buffer)
        self.processes = []

    def start(self):
        """Launch mapping first, give it time to initialize on the first
        frame, then launch tracking."""
        mapping_process = mp.Process(
            target=self.mapper.spin, args=(self.share_data, self.kf_buffer))
        mapping_process.start()
        print("initializing the first frame ...")
        # fixed grace period so the mapper can optimize the first frame
        # before tracking begins
        sleep(20)
        # self.share_data.stop_mapping=True
        tracking_process = mp.Process(
            target=self.tracker.spin, args=(self.share_data, self.kf_buffer))
        tracking_process.start()
        self.processes = [mapping_process, tracking_process]

    def wait_child_processes(self):
        # block until both workers have exited
        for p in self.processes:
            p.join()

    @torch.no_grad()
    def get_raw_trajectory(self):
        # tracker-side trajectory as accumulated in the shared object
        return self.share_data.tracking_trajectory

    @torch.no_grad()
    def get_keyframe_poses(self):
        # poses of all keyframes currently held by the mapper
        keyframe_graph = self.mapper.keyframe_graph
        poses = []
        for keyframe in keyframe_graph:
            poses.append(keyframe.get_pose().detach().cpu().numpy())
        return poses
================================================
FILE: src/se3pose.py
================================================
import torch
import torch.nn as nn
import numpy as np
from copy import deepcopy
class OptimizablePose(nn.Module):
    """SE(3) pose parameterized as a 6-vector [tx, ty, tz, wx, wy, wz]
    (translation + so(3) axis-angle), stored as an nn.Parameter so it can be
    refined by gradient descent. Rotations are reconstructed with Rodrigues'
    formula using Taylor-expanded coefficients for stability near theta=0."""

    def __init__(self, init_pose):
        """init_pose: FloatTensor of shape (6,) = [translation, log-rotation]."""
        super().__init__()
        assert (isinstance(init_pose, torch.FloatTensor))
        self.register_parameter('data', nn.Parameter(init_pose))
        # Fix: the old `self.data.required_grad_ = True` was a typo'd
        # attribute assignment (correct name is `requires_grad`) and had no
        # effect; nn.Parameter already defaults to requires_grad=True, so
        # the line is simply removed.

    def copy_from(self, pose):
        """Replace our parameter with a deep copy of another pose's data."""
        self.data = deepcopy(pose.data)

    def matrix(self):
        """Return the 4x4 homogeneous transform for the current parameters."""
        Rt = torch.eye(4)
        Rt[:3, :3] = self.rotation()
        Rt[:3, 3] = self.translation()
        return Rt

    def rotation(self):
        """Exp map so(3) -> SO(3): R = I + A*[w]_x + B*[w]_x^2."""
        w = self.data[3:]
        wx = self.skew_symmetric(w)
        theta = w.norm(dim=-1)[..., None, None]
        I = torch.eye(3, device=w.device, dtype=torch.float32)
        A = self.taylor_A(theta)
        B = self.taylor_B(theta)
        R = I+A*wx+B*wx@wx
        return R

    def translation(self,):
        return self.data[:3]

    @classmethod
    def log(cls, R, eps=1e-7):  # [...,3,3]
        """Log map SO(3) -> so(3): rotation matrix to axis-angle vector."""
        trace = R[..., 0, 0]+R[..., 1, 1]+R[..., 2, 2]
        # ln(R) will explode if theta==pi
        theta = ((trace-1)/2).clamp(-1+eps, 1-eps).acos_()[..., None, None] % np.pi
        lnR = 1/(2*cls.taylor_A(theta)+1e-8) * (R-R.transpose(-2, -1))  # FIXME: wei-chiu finds it weird
        w0, w1, w2 = lnR[..., 2, 1], lnR[..., 0, 2], lnR[..., 1, 0]
        w = torch.stack([w0, w1, w2], dim=-1)
        return w

    @classmethod
    def from_matrix(cls, Rt, eps=1e-8):  # [...,3,4]
        """Build an OptimizablePose from a 3x4 (or 4x4) transform matrix.
        NOTE: `eps` is currently unused (kept for interface stability)."""
        R, u = Rt[:3, :3], Rt[:3, 3]
        w = cls.log(R)
        return OptimizablePose(torch.cat([u, w], dim=-1))

    @classmethod
    def skew_symmetric(cls, w):
        """Cross-product matrix [w]_x of a 3-vector."""
        w0, w1, w2 = w.unbind(dim=-1)
        O = torch.zeros_like(w0)
        wx = torch.stack([
            torch.stack([O, -w2, w1], dim=-1),
            torch.stack([w2, O, -w0], dim=-1),
            torch.stack([-w1, w0, O], dim=-1)], dim=-2)
        return wx

    @classmethod
    def taylor_A(cls, x, nth=10):
        # Taylor expansion of sin(x)/x
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            if i > 0:
                denom *= (2*i)*(2*i+1)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans

    @classmethod
    def taylor_B(cls, x, nth=10):
        # Taylor expansion of (1-cos(x))/x**2
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            denom *= (2*i+1)*(2*i+2)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans

    @classmethod
    def taylor_C(cls, x, nth=10):
        # Taylor expansion of (x-sin(x))/x**3
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            denom *= (2*i+2)*(2*i+3)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans
if __name__ == '__main__':
    # Round-trip sanity check: matrix -> 6-vector -> matrix.
    reference = torch.tensor([[-0.955421, 0.119616, -0.269932, 2.655830],
                              [0.295248, 0.388339, -0.872939, 2.981598],
                              [0.000408, -0.913720, -0.406343, 1.368648],
                              [0.000000, 0.000000, 0.000000, 1.000000]])
    pose = OptimizablePose.from_matrix(reference)
    print(pose.rotation())
    print(pose.translation())
    rebuilt = pose.matrix()
    print(rebuilt)
    # translation reconstruction error
    print(torch.abs((reference - rebuilt)[:3, 3]))
================================================
FILE: src/share.py
================================================
from multiprocessing.managers import BaseManager, NamespaceProxy
from copy import deepcopy
import torch.multiprocessing as mp
from time import sleep
import sys
class ShareDataProxy(NamespaceProxy):
    # Proxy type registered with BaseManager so attribute reads and writes
    # on a ShareData instance are forwarded across process boundaries.
    _exposed_ = ('__getattribute__', '__setattr__')
class ShareData:
    """Process-shared state container (served via BaseManager). All access
    is serialized through a single module-level re-entrant lock; decoder /
    voxels / octree / trajectory are deep-copied on both get and set so the
    caller never aliases internal state, while `states` is shared by
    reference.

    Fix: every property getter previously contained a `print(...)` and
    `sys.stdout.flush()` placed *after* the `return` statement — unreachable
    code, now removed. Setter-side prints are preserved unchanged.
    """
    # lock is created once at class-definition time and shared by all users
    global lock
    lock = mp.RLock()

    def __init__(self):
        self.__stop_mapping = False
        self.__stop_tracking = False
        self.__decoder = None
        self.__voxels = None
        self.__octree = None
        self.__states = None
        self.__tracking_trajectory = []

    @property
    def decoder(self):
        with lock:
            return deepcopy(self.__decoder)

    @decoder.setter
    def decoder(self, decoder):
        with lock:
            self.__decoder = deepcopy(decoder)
            # print("========== decoder set ==========")
            sys.stdout.flush()

    @property
    def voxels(self):
        with lock:
            return deepcopy(self.__voxels)

    @voxels.setter
    def voxels(self, voxels):
        with lock:
            self.__voxels = deepcopy(voxels)
            print("========== voxels set ==========")
            sys.stdout.flush()

    @property
    def octree(self):
        with lock:
            return deepcopy(self.__octree)

    @octree.setter
    def octree(self, octree):
        with lock:
            self.__octree = deepcopy(octree)
            print("========== octree set ==========")
            sys.stdout.flush()

    @property
    def states(self):
        # returned by reference (no deep copy), unlike the fields above
        with lock:
            return self.__states

    @states.setter
    def states(self, states):
        with lock:
            self.__states = states
            # print("========== states set ==========")
            sys.stdout.flush()

    @property
    def stop_mapping(self):
        with lock:
            return self.__stop_mapping

    @stop_mapping.setter
    def stop_mapping(self, stop_mapping):
        with lock:
            self.__stop_mapping = stop_mapping
            print("========== stop_mapping set ==========")
            sys.stdout.flush()

    @property
    def stop_tracking(self):
        with lock:
            return self.__stop_tracking

    @stop_tracking.setter
    def stop_tracking(self, stop_tracking):
        with lock:
            self.__stop_tracking = stop_tracking
            print("========== stop_tracking set ==========")
            sys.stdout.flush()

    @property
    def tracking_trajectory(self):
        with lock:
            return deepcopy(self.__tracking_trajectory)

    def push_pose(self, pose):
        """Append a deep-copied pose to the tracking trajectory."""
        with lock:
            self.__tracking_trajectory.append(deepcopy(pose))
            # print("========== push_pose ==========")
            sys.stdout.flush()
================================================
FILE: src/tracking.py
================================================
import torch
import numpy as np
from tqdm import tqdm
from criterion import Criterion
from lidarFrame import LidarFrame
from utils.import_util import get_property
from utils.profile_util import Profiler
from variations.render_helpers import fill_in, render_rays, track_frame
from se3pose import OptimizablePose
from time import sleep
from copy import deepcopy
class Tracking:
def __init__(self, args, data_stream, logger):
    """Tracking front-end: reads config, clamps the frame range to the data
    stream length, and sets up the loss criterion and profiler."""
    self.args = args
    self.data_stream = data_stream
    self.logger = logger
    self.last_frame_id = 0
    self.last_frame = None
    self.loss_criteria = Criterion(args)

    tracker_specs = args.tracker_specs
    self.voxel_size = args.mapper_specs["voxel_size"]
    self.N_rays = tracker_specs["N_rays"]
    self.num_iterations = tracker_specs["num_iterations"]
    self.sdf_truncation = args.criteria["sdf_truncation"]
    self.learning_rate = tracker_specs["learning_rate"]
    self.start_frame = tracker_specs["start_frame"]
    self.end_frame = tracker_specs["end_frame"]
    # self.keyframe_freq = args.tracker_specs["keyframe_freq"]
    self.max_voxel_hit = tracker_specs["max_voxel_hit"]
    self.max_distance = args.data_specs["max_depth"]
    # config expresses the ray-marching step in voxel units
    self.step_size = tracker_specs["step_size"] * self.voxel_size
    self.read_offset = tracker_specs["read_offset"]
    self.mesh_freq = args.debug_args["mesh_freq"]

    if self.end_frame <= 0:
        # non-positive end means "run to the last frame"
        self.end_frame = len(self.data_stream) - 1
    # sanity check on the lower/upper bounds
    self.start_frame = min(self.start_frame, len(self.data_stream))
    self.end_frame = min(self.end_frame, len(self.data_stream))

    self.rel_pose = None
    # profiler
    verbose = get_property(args.debug_args, "verbose", False)
    self.profiler = Profiler(verbose=verbose)
    self.profiler.enable()
def process_first_frame(self, kf_buffer):
init_pose = self.data_stream.get_init_pose(self.start_frame)
index, points, pointcos, _ = self.data_stream[self.start_frame]
first_frame = LidarFrame(index, points, pointcos, init_pose)
first_frame.pose.requires_grad_(False)
first_frame.points.requires_grad_(False)
print("******* initializing first_frame:", first_frame.index)
kf_buffer.put(first_frame, block=True)
self.last_frame = first_frame
self.start_frame += 1
def spin(self, share_data, kf_buffer):
print("******* tracking process started! *******")
progress_bar = tqdm(
range(self.start_frame, self.end_frame+1), position=0)
progress_bar.set_description("tracking frame")
for frame_id in progress_bar:
if frame_id % self.read_offset != 0:
continue
if share_data.stop_tracking:
break
data_in = self.data_stream[frame_id]
current_frame = LidarFrame(*data_in)
if isinstance(data_in[3], np.ndarray):
self.last_frame = current_frame
self.check_keyframe(current_frame, kf_buffer)
else:
self.do_tracking(share_data, current_frame, kf_buffer)
share_data.stop_mapping = True
print("******* tracking process died *******")
sleep(60)
while not kf_buffer.empty():
sleep(60)
def check_keyframe(self, check_frame, kf_buffer):
try:
kf_buffer.put(check_frame, block=True)
except:
pass
def do_tracking(self, share_data, current_frame, kf_buffer):
self.profiler.tick("before track1111")
decoder = share_data.decoder.cuda()
self.profiler.tok("before track1111")
self.profiler.tick("before track2222")
map_states = share_data.states
map_states["voxel_vertex_emb"] = map_states["voxel_vertex_emb"].cuda()
self.profiler.tok("before track2222")
constant_move_pose = self.last_frame.get_pose().detach()
input_pose = deepcopy(self.last_frame.pose)
input_pose.requires_grad_(False)
if self.rel_pose != None:
constant_move_pose[:3, 3] = (constant_move_pose @ (self.rel_pose))[:3, 3]
input_pose.data[:3] = constant_move_pose[:3, 3].T
torch.cuda.empty_cache()
self.profiler.tick("track frame")
frame_pose, hit_mask = track_frame(
input_pose,
current_frame,
map_states,
decoder,
self.loss_criteria,
self.voxel_size,
self.N_rays,
self.step_size,
self.num_iterations if self.rel_pose != None else self.num_iterations*5,
self.sdf_truncation,
self.learning_rate if self.rel_pose != None else self.learning_rate,
self.max_voxel_hit,
self.max_distance,
profiler=self.profiler,
depth_variance=True
)
self.profiler.tok("track frame")
if hit_mask == None:
current_frame.pose = OptimizablePose.from_matrix(constant_move_pose)
else:
current_frame.pose = frame_pose
current_frame.hit_ratio = hit_mask.sum() / self.N_rays
self.rel_pose = torch.linalg.inv(self.last_frame.get_pose().detach()) @ current_frame.get_pose().detach()
current_frame.set_rel_pose(self.rel_pose)
self.last_frame = current_frame
self.profiler.tick("transport frame")
self.check_keyframe(current_frame, kf_buffer)
self.profiler.tok("transport frame")
================================================
FILE: src/utils/__init__.py
================================================
================================================
FILE: src/utils/import_util.py
================================================
from importlib import import_module
import argparse
def get_dataset(args):
    """Instantiate the data loader for the dataset named by ``args.dataset``."""
    dataset_module = import_module("dataset." + args.dataset)
    return dataset_module.DataLoader(**args.data_specs)
def get_decoder(args):
    """Instantiate the decoder network named by ``args.decoder``."""
    decoder_module = import_module("variations." + args.decoder)
    return decoder_module.Decoder(**args.decoder_specs)
def get_property(args, name, default):
    """Look up ``name`` in a dict or ``argparse.Namespace``, with a fallback.

    Args:
        args: a ``dict`` or an ``argparse.Namespace``.
        name: the key/attribute to read.
        default: value returned when ``name`` is absent.

    Returns:
        The stored value, or ``default`` when missing.

    Raises:
        ValueError: if ``args`` is neither a dict nor a Namespace.
    """
    if isinstance(args, dict):
        return args.get(name, default)
    if isinstance(args, argparse.Namespace):
        # getattr with a default replaces the hasattr/vars two-step lookup
        return getattr(args, name, default)
    # fixed typo in the error message ("unkown" -> "unknown")
    raise ValueError(f"unknown dict/namespace type: {type(args)}")
================================================
FILE: src/utils/mesh_util.py
================================================
import math
import torch
import numpy as np
import open3d as o3d
from scipy.spatial import cKDTree
from skimage.measure import marching_cubes
from variations.render_helpers import get_scores, eval_points
class MeshExtractor:
    """Extracts a triangle mesh from the learned SDF via marching cubes.

    Optionally cleans the mesh against back-projected depth points and
    colours vertices by querying the decoder network.
    """

    def __init__(self, args):
        self.voxel_size = args.mapper_specs["voxel_size"]
        # cached per-pixel ray directions; populated externally or via get_rays
        self.rays_d = None
        # accumulated back-projected depth points (used for mesh cleaning)
        self.depth_points = None

    @ torch.no_grad()
    def linearize_id(self, xyz, n_xyz):
        # flatten integer 3D grid coordinates into a single linear index
        # (z fastest, x slowest), given grid extents n_xyz
        return xyz[:, 2] + n_xyz[-1] * xyz[:, 1] + (n_xyz[-1] * n_xyz[-2]) * xyz[:, 0]

    @torch.no_grad()
    def downsample_points(self, points, voxel_size=0.01):
        """Voxel-grid downsample an (N, 3) numpy point array via open3d."""
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        pcd = pcd.voxel_down_sample(voxel_size)
        return np.asarray(pcd.points)

    @torch.no_grad()
    def get_rays(self, w=None, h=None, K=None):
        """Build per-pixel camera ray directions for a (w, h) image.

        NOTE(review): falls back to ``self.w``/``self.h``/``self.K``, which
        are not set in ``__init__`` — presumably assigned by a caller before
        use; verify.
        """
        w = self.w if w == None else w
        h = self.h if h == None else h
        if K is None:
            # rescale the stored intrinsics to the requested resolution
            K = np.eye(3)
            K[0, 0] = self.K[0, 0] * w / self.w
            K[1, 1] = self.K[1, 1] * h / self.h
            K[0, 2] = self.K[0, 2] * w / self.w
            K[1, 2] = self.K[1, 2] * h / self.h
        ix, iy = torch.meshgrid(
            torch.arange(w), torch.arange(h), indexing='xy')
        # pinhole back-projection: ((u-cx)/fx, (v-cy)/fy, 1)
        rays_d = torch.stack(
            [(ix-K[0, 2]) / K[0, 0],
             (iy-K[1, 2]) / K[1, 1],
             torch.ones_like(ix)], -1).float()
        return rays_d

    @torch.no_grad()
    def get_valid_points(self, frame_poses, depth_maps):
        """Back-project depth maps to world-frame points and downsample.

        Accepts either a list of poses/depths (batch mode, every 5th frame)
        or a single pose/depth pair (incremental mode, accumulating into
        ``self.depth_points``).
        """
        if isinstance(frame_poses, list):
            all_points = []
            print("extracting all points")
            # stride 5: use every 5th frame to bound memory/time
            for i in range(0, len(frame_poses), 5):
                pose = frame_poses[i]
                depth = depth_maps[i]
                points = self.rays_d * depth.unsqueeze(-1)
                points = points.reshape(-1, 3)
                # camera -> world transform
                points = points @ pose[:3, :3].transpose(-1, -2) + pose[:3, 3]
                if len(all_points) == 0:
                    all_points = points.detach().cpu().numpy()
                else:
                    all_points = np.concatenate(
                        [all_points, points.detach().cpu().numpy()], 0)
            print("downsample all points")
            all_points = self.downsample_points(all_points)
            return all_points
        else:
            pose = frame_poses
            depth = depth_maps
            points = self.rays_d * depth.unsqueeze(-1)
            points = points.reshape(-1, 3)
            points = points @ pose[:3, :3].transpose(-1, -2) + pose[:3, 3]
            if self.depth_points is None:
                self.depth_points = points.detach().cpu().numpy()
            else:
                # NOTE(review): concatenates a numpy array with a torch
                # tensor here (no .cpu().numpy() on `points`) — works only
                # if numpy can coerce the tensor; verify this branch.
                self.depth_points = np.concatenate(
                    [self.depth_points, points], 0)
            self.depth_points = self.downsample_points(self.depth_points)
            return self.depth_points

    @ torch.no_grad()
    def create_mesh(self, decoder, map_states, voxel_size, voxels,
                    frame_poses=None, depth_maps=None, clean_mseh=False,
                    require_color=False, offset=-80, res=8):
        """Build an open3d TriangleMesh from the SDF stored in ``map_states``.

        Args:
            decoder: SDF network used to score lattice points (and colours).
            map_states: map dictionaries (voxel centers, embeddings, ...).
            voxel_size: metric voxel edge length.
            voxels: (N, >=3) voxel integer coordinates, used for colour lookup.
            frame_poses, depth_maps: needed when ``clean_mseh`` is set.
            clean_mseh: drop faces far from observed depth points
                (parameter name keeps the original spelling for callers).
            require_color: query the decoder for per-vertex colours.
            offset: translation added to vertices before export.
            res: SDF lattice resolution per voxel axis.
        """
        sdf_grid = get_scores(decoder, map_states, voxel_size, bits=res)
        sdf_grid = sdf_grid.reshape(-1, res, res, res, 1)
        voxel_centres = map_states["voxel_center_xyz"]
        verts, faces = self.marching_cubes(voxel_centres, sdf_grid)
        if clean_mseh:
            print("********** get points from frames **********")
            all_points = self.get_valid_points(frame_poses, depth_maps)
            print("********** construct kdtree **********")
            kdtree = cKDTree(all_points)
            print("********** query kdtree **********")
            # count observed points within half a voxel of each vertex
            point_mask = kdtree.query_ball_point(
                verts, voxel_size * 0.5, workers=12, return_length=True)
            print("********** finished querying kdtree **********")
            point_mask = point_mask > 0
            # keep faces whose any vertex is supported by observations
            face_mask = point_mask[faces.reshape(-1)].reshape(-1, 3).any(-1)
            faces = faces[face_mask]
        if require_color:
            print("********** get color from network **********")
            verts_torch = torch.from_numpy(verts).float().cuda()
            batch_points = torch.split(verts_torch, 1000)
            colors = []
            for points in batch_points:
                # voxel coordinate of each vertex
                voxel_pos = points // self.voxel_size
                batch_voxels = voxels[:, :3].cuda()
                batch_voxels = batch_voxels.unsqueeze(
                    0).repeat(voxel_pos.shape[0], 1, 1)
                # filter outliers
                nonzeros = (batch_voxels == voxel_pos.unsqueeze(1)).all(-1)
                nonzeros = torch.where(nonzeros, torch.ones_like(
                    nonzeros).int(), -torch.ones_like(nonzeros).int())
                # pick the first matching voxel index per vertex (if any)
                sorted, index = torch.sort(nonzeros, dim=-1, descending=True)
                sorted = sorted[:, 0]
                index = index[:, 0]
                valid = (sorted != -1)
                color_empty = torch.zeros_like(points)
                points = points[valid, :]
                index = index[valid]
                # get color
                if len(points) > 0:
                    color = eval_points(decoder, map_states,
                                        points, index, voxel_size).cuda()
                    color_empty[valid] = color
                colors += [color_empty]
            colors = torch.cat(colors, 0)
        mesh = o3d.geometry.TriangleMesh()
        mesh.vertices = o3d.utility.Vector3dVector(verts+offset)
        mesh.triangles = o3d.utility.Vector3iVector(faces)
        if require_color:
            mesh.vertex_colors = o3d.utility.Vector3dVector(
                colors.detach().cpu().numpy())
        mesh.compute_vertex_normals()
        return mesh

    @ torch.no_grad()
    def marching_cubes(self, voxels, sdf):
        """Run skimage marching cubes per voxel and merge the results.

        Args:
            voxels: (N, >=3) voxel center coordinates.
            sdf: (N, res, res, res, 1) per-voxel SDF lattices.

        Returns:
            (total_verts, total_faces) numpy arrays of the merged mesh.
        """
        voxels = voxels[:, :3]
        sdf = sdf[..., 0]
        # normalized lattice spacing inside one voxel
        res = 1.0 / (sdf.shape[1] - 1)
        spacing = [res, res, res]
        num_verts = 0
        total_verts = []
        total_faces = []
        for i in range(len(voxels)):
            sdf_volume = sdf[i].detach().cpu().numpy()
            # skip voxels whose SDF does not cross zero (no surface inside)
            if np.min(sdf_volume) > 0 or np.max(sdf_volume) < 0:
                continue
            verts, faces, _, _ = marching_cubes(sdf_volume, 0, spacing=spacing)
            # recenter to the voxel and scale to metric coordinates
            verts -= 0.5
            verts *= self.voxel_size
            verts += voxels[i].detach().cpu().numpy()
            # shift face indices past vertices of previous voxels
            faces += num_verts
            num_verts += verts.shape[0]
            total_verts += [verts]
            total_faces += [faces]
        total_verts = np.concatenate(total_verts)
        total_faces = np.concatenate(total_faces)
        return total_verts, total_faces
================================================
FILE: src/utils/profile_util.py
================================================
from time import time
import torch
class Profiler(object):
    """Lightweight named-section wall-clock profiler.

    ``tick(name)`` starts a section; ``tok(name)`` stops it. With
    ``verbose`` set, elapsed times are printed immediately; otherwise they
    accumulate (in milliseconds) in ``time_log``. The profiler starts
    disabled: ``tick``/``tok`` are no-ops until ``enable()`` is called.
    """

    def __init__(self, verbose=False) -> None:
        self.timer = {}
        self.time_log = {}
        self.enabled = False
        self.verbose = verbose

    def enable(self):
        """Turn measurement on."""
        self.enabled = True

    def disable(self):
        """Turn measurement off (tick/tok become no-ops)."""
        self.enabled = False

    def tick(self, name):
        """Record the start time for section ``name``."""
        if not self.enabled:
            return
        self.timer[name] = time()
        self.time_log.setdefault(name, [])

    def tok(self, name):
        """Stop section ``name`` and record/print the elapsed time."""
        if not self.enabled or name not in self.timer:
            return
        # wait for pending GPU work so the interval covers kernel time
        torch.cuda.synchronize()
        elapsed_ms = (time() - self.timer[name]) * 1000
        if self.verbose:
            print(f"{name}: {elapsed_ms:.2f} ms")
        else:
            self.time_log[name].append(elapsed_ms)
================================================
FILE: src/utils/sample_util.py
================================================
import torch
def sampling_without_replacement(logp, k):
    """Draw ``k`` indices per row without replacement (Gumbel-top-k trick)."""
    # perturb log-probabilities with Gumbel noise, then take the top-k
    gumbel = -torch.log(-torch.log(torch.rand_like(logp) + 1e-7) + 1e-7)
    perturbed = logp + gumbel
    _, indices = perturbed.topk(k, dim=-1)
    return indices
def sample_rays(mask, num_samples):
    """Randomly select ``num_samples`` pixels per batch item, weighted by ``mask``.

    Returns a boolean (B, H, W) mask with exactly ``num_samples`` True
    entries per batch item.
    """
    B, H, W = mask.shape
    weights = (mask / (mask.sum() + 1e-9)).reshape(B, -1)
    chosen = sampling_without_replacement(torch.log(weights + 1e-9), num_samples)
    flat = torch.zeros_like(weights).scatter_(-1, chosen, 1)
    return flat.reshape(B, H, W) > 0
================================================
FILE: src/variations/decode_morton.py
================================================
import numpy as np
def compact(value):
    """Collapse every third bit of a 3D Morton code into contiguous low bits.

    Inverse of bit interleaving along one axis: keeps bits 0, 3, 6, ... of
    ``value`` and packs them into the low 21 bits.
    """
    masks = (
        (0, 0x1249249249249249),
        (2, 0x10C30C30C30C30C3),
        (4, 0x100F00F00F00F00F),
        (8, 0x1F0000FF0000FF),
        (16, 0x1F00000000FFFF),
        (32, 0x1FFFFF),
    )
    x = value & masks[0][1]
    for shift, mask in masks[1:]:
        x = (x | x >> shift) & mask
    return x
def decode(code):
    """Split a 3D Morton code into its (x, y, z) integer coordinates."""
    return (compact(code), compact(code >> 1), compact(code >> 2))
# NOTE(review): ad-hoc debug snippet left at module scope — it references
# `samples_valid` and `torch`, neither of which is defined or imported in
# this module, so importing this file executes it and raises NameError.
# It appears to print decoded voxel coordinates, their distance to the
# grid point (80, 80, 80), and the sampled depth — verify before treating
# this module as importable.
for i in range(10):
    x, y, z = decode(samples_valid['sampled_point_voxel_idx'][i])
    print(x, y, z)
    print(torch.sqrt((x-80)**2+(y-80)**2+(z-80)**2))
    print(samples_valid['sampled_point_depth'][i])
================================================
FILE: src/variations/lidar.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class GaussianFourierFeatureTransform(torch.nn.Module):
    """Gaussian Fourier feature mapping (sine-only variant).

    Projects 2D-batched input coordinates through a random Gaussian matrix
    and applies a sine. Based on "Fourier Features Let Networks Learn High
    Frequency Functions in Low Dimensional Domains"
    (https://arxiv.org/abs/2006.10739,
    https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html).
    """

    def __init__(self, num_input_channels, mapping_size=93, scale=25, learnable=True):
        super().__init__()
        projection = torch.randn((num_input_channels, mapping_size)) * scale
        # learnable: the projection matrix is optimized with the network
        self._B = nn.Parameter(projection) if learnable else projection
        self.embedding_size = mapping_size

    def forward(self, x):
        assert x.dim() == 2, 'Expected 2D input (got {}D input)'.format(x.dim())
        return torch.sin(x @ self._B.to(x.device))
class Nerf_positional_embedding(torch.nn.Module):
    """NeRF-style sinusoidal positional embedding.

    Maps each coordinate x to [x, sin(f1 x), cos(f1 x), ..., sin(fN x),
    cos(fN x)] over ``multires`` frequency bands (log- or linearly spaced).
    """

    def __init__(self, in_dim, multires, log_sampling=True):
        super().__init__()
        self.log_sampling = log_sampling
        self.include_input = True
        self.periodic_fns = [torch.sin, torch.cos]
        self.max_freq_log2 = multires - 1
        self.num_freqs = multires
        self.max_freq = self.max_freq_log2
        self.N_freqs = self.num_freqs
        # raw input + sin/cos per frequency per input dimension
        self.embedding_size = multires * in_dim * 2 + in_dim

    def forward(self, x):
        assert x.dim() == 2, 'Expected 2D input (got {}D input)'.format(
            x.dim())
        if self.log_sampling:
            # frequencies 2^0 ... 2^max_freq, log-spaced
            freq_bands = 2. ** torch.linspace(0.,
                                              self.max_freq, steps=self.N_freqs)
        else:
            freq_bands = torch.linspace(
                2. ** 0., 2. ** self.max_freq, steps=self.N_freqs)
        pieces = [x] if self.include_input else []
        for freq in freq_bands:
            for periodic_fn in self.periodic_fns:
                pieces.append(periodic_fn(x * freq))
        return torch.cat(pieces, dim=1)
class Same(nn.Module):
    """Identity embedding: passes input coordinates through unchanged."""

    def __init__(self, in_dim) -> None:
        super().__init__()
        # embedding size equals the raw input dimensionality
        self.embedding_size = in_dim

    def forward(self, x):
        return x
class Decoder(nn.Module):
def __init__(self,
depth=8,
width=258,
in_dim=3,
sdf_dim=128,
skips=[4],
multires=6,
embedder='none',
point_dim=3,
local_coord=False,
**kwargs) -> None:
super().__init__()
self.D = depth
self.W = width
self.skips = skips
self.point_dim = point_dim
if embedder == 'nerf':
self.pe = Nerf_positional_embedding(in_dim, multires)
elif embedder == 'none':
self.pe = Same(in_dim)
elif embedder == 'gaussian':
self.pe = GaussianFourierFeatureTransform(in_dim)
else:
raise NotImplementedError("unknown positional encoder")
self.pts_linears = nn.ModuleList(
[nn.Linear(self.pe.embedding_size, width)] + [nn.Linear(width, width) if i not in self.skips else nn.Linear(width + self.pe.embedding_size, width) for i in range(depth-1)])
self.sdf_out = nn.Linear(width, 1)
def get_values(self, input):
x = self.pe(input)
# point = input[:, -3:]
h = x
for i, l in enumerate(self.pts_linears):
h = self.pts_linears[i](h)
h = F.relu(h)
if i in self.skips:
h = torch.cat([x, h], -1)
# outputs = self.output_linear(h)
# outputs[:, :3] = torch.sigmoid(outputs[:, :3])
sdf_out = self.sdf_out(h)
return sdf_out
def forward(self, inputs):
outputs = self.get_values(inputs)
return {
'sdf': outputs,
# 'depth': outputs[:, 1]
}
================================================
FILE: src/variations/render_helpers.py
================================================
from copy import deepcopy
import torch
import torch.nn.functional as F
from .voxel_helpers import ray_intersect, ray_sample
from torch.autograd import grad
def ray(ray_start, ray_dir, depths):
    """Points along rays: origin + direction * depth (broadcasting)."""
    return ray_dir * depths + ray_start
def fill_in(shape, mask, input, initial=1.0):
    """Scatter ``input`` into a tensor of ``shape``.

    Positions where ``mask`` is true take values from ``input`` (in order);
    all others are filled with ``initial`` (scalar or broadcastable tensor).
    """
    if isinstance(initial, torch.Tensor):
        canvas = initial.expand(*shape)
    else:
        canvas = input.new_ones(*shape) * initial
    full_mask = mask.unsqueeze(-1).expand(*shape)
    return canvas.masked_scatter(full_mask, input)
def masked_scatter(mask, x):
    """Scatter ``x`` into a zero tensor shaped (B, K) or (B, K, C).

    Entries where ``mask`` (B, K) is true are filled from ``x`` in order;
    the rest stay zero.
    """
    B, K = mask.size()
    if x.dim() == 1:
        return x.new_zeros(B, K).masked_scatter(mask, x)
    C = x.size(-1)
    expanded_mask = mask.unsqueeze(-1).expand(B, K, C)
    return x.new_zeros(B, K, C).masked_scatter(expanded_mask, x)
def masked_scatter_ones(mask, x):
    """Scatter ``x`` into a ones tensor shaped (B, K) or (B, K, C).

    Entries where ``mask`` (B, K) is true are filled from ``x`` in order;
    the rest stay one.
    """
    B, K = mask.size()
    if x.dim() == 1:
        return x.new_ones(B, K).masked_scatter(mask, x)
    C = x.size(-1)
    expanded_mask = mask.unsqueeze(-1).expand(B, K, C)
    return x.new_ones(B, K, C).masked_scatter(expanded_mask, x)
@torch.enable_grad()
def trilinear_interp(p, q, point_feats):
    """Blend 8 corner features with trilinear weights.

    ``p`` is the fractional position, ``q`` the per-corner offsets; the
    per-corner weight is the product over the 3 axes of
    ``p*q + (1-p)*(1-q)``.
    """
    weights = (p * q + (1 - p) * (1 - q)).prod(dim=-1, keepdim=True)
    if point_feats.dim() == 2:
        # flat (N, 8*C) features -> (N, 8, C)
        point_feats = point_feats.view(point_feats.size(0), 8, -1)
    return (weights * point_feats).sum(1)
def offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2):
    """Generate a bits^3 lattice of normalized offsets around ``point_xyz``.

    Offsets span [-1, 1] per axis and are scaled by ``quarter_voxel``.
    With ``offset_only`` the (bits^3, 3) offsets are returned alone;
    otherwise each input point is translated by every offset, giving
    (N, bits^3, 3).
    """
    c = torch.arange(1, 2 * bits, 2, device=point_xyz.device)
    ox, oy, oz = torch.meshgrid([c, c, c], indexing='ij')
    lattice = torch.cat(
        [ox.reshape(-1, 1), oy.reshape(-1, 1), oz.reshape(-1, 1)], 1)
    # normalize odd-integer lattice coords to [-1, 1]
    offset = (lattice.type_as(point_xyz) - bits) / float(bits - 1)
    if offset_only:
        return offset.type_as(point_xyz) * quarter_voxel
    return point_xyz.unsqueeze(1) + offset.unsqueeze(0).type_as(point_xyz) * quarter_voxel
@torch.enable_grad()
def get_embeddings(sampled_xyz, point_xyz, point_feats, voxel_size):
    """Trilinearly interpolate per-vertex features at sample positions.

    ``point_xyz`` holds the voxel centers; the sample's fractional position
    inside its voxel drives the interpolation of the 8 vertex features.
    """
    # fractional position of the sample inside its voxel, in [0, 1]
    p = ((sampled_xyz - point_xyz) / voxel_size + 0.5).unsqueeze(1)
    # per-corner weight coordinates shifted into [0, 1]
    q = offset_points(p, 0.5, offset_only=True).unsqueeze(0) + 0.5
    return trilinear_interp(p, q, point_feats).float()
@torch.enable_grad()
def get_features(samples, map_states, voxel_size):
    """Gather interpolated vertex embeddings for each ray sample.

    Args:
        samples: dict with "sampled_point_voxel_idx", "sampled_point_xyz"
            and "sampled_point_distance".
        map_states: dict with "voxel_vertex_idx", "voxel_center_xyz",
            "voxel_vertex_emb" and "voxel_id2embedding_id".
        voxel_size: metric voxel edge length.

    Returns:
        dict with per-sample voxel centers ("xyz", grad-enabled), sample
        distances ("dists") and interpolated embeddings ("emb", on GPU).
    """
    # encoder states
    point_idx = map_states["voxel_vertex_idx"].cuda()
    point_xyz = map_states["voxel_center_xyz"].cuda()
    values = map_states["voxel_vertex_emb"]
    point_id2embedid = map_states["voxel_id2embedding_id"]
    # ray point samples
    sampled_idx = samples["sampled_point_voxel_idx"].long()
    sampled_xyz = samples["sampled_point_xyz"]
    sampled_dis = samples["sampled_point_distance"]
    # center of the voxel each sample falls in; grad is enabled so SDF
    # gradients w.r.t. position can be taken downstream
    point_xyz = F.embedding(sampled_idx, point_xyz).requires_grad_()
    # look up the 8 vertex ids of each sample's voxel ...
    selected_points_idx = F.embedding(sampled_idx, point_idx)
    flatten_selected_points_idx = selected_points_idx.view(-1)
    # ... then map vertex ids to embedding-table rows (lookup done on CPU,
    # matching where the id->embedding table lives)
    embed_idx = F.embedding(flatten_selected_points_idx.cpu(), point_id2embedid).squeeze(-1)
    point_feats = F.embedding(embed_idx.cuda(), values).view(point_xyz.size(0), -1)
    # trilinear interpolation of the 8 vertex embeddings at the sample point
    feats = get_embeddings(sampled_xyz, point_xyz, point_feats, voxel_size)
    inputs = {"xyz": point_xyz, "dists": sampled_dis, "emb": feats.cuda()}
    return inputs
@torch.no_grad()
def get_scores(sdf_network, map_states, voxel_size, bits=8):
    """Evaluate the SDF network on a regular bits^3 lattice inside every voxel.

    Args:
        sdf_network: decoder exposing ``get_values``.
        map_states: dict with "voxel_vertex_idx", "voxel_center_xyz",
            "voxel_vertex_emb" and "voxel_id2embedding_id".
        voxel_size: metric voxel edge length.
        bits: lattice resolution per voxel axis.

    Returns:
        Tensor of shape (num_voxels, bits, bits, bits, 1) of SDF values,
        computed in chunks of voxels to bound GPU memory.
    """
    feats = map_states["voxel_vertex_idx"]
    points = map_states["voxel_center_xyz"]
    values = map_states["voxel_vertex_emb"]
    point_id2embedid = map_states["voxel_id2embedding_id"]
    chunk_size = 10000
    res = bits  # -1

    @torch.no_grad()
    def get_scores_once(feats, points, values, point_id2embedid):
        torch.cuda.empty_cache()
        # sample points inside voxels: a res^3 grid spanning one voxel
        start = -0.5
        end = 0.5  # - 1./bits
        x = y = z = torch.linspace(start, end, res)
        # z = torch.linspace(1, 1, res)
        # indexing='ij' matches the historical default and silences the
        # torch.meshgrid deprecation warning (behavior unchanged)
        xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')
        sampled_xyz = torch.stack([xx, yy, zz], dim=-1).float().cuda()
        sampled_xyz *= voxel_size
        sampled_xyz = sampled_xyz.reshape(1, -1, 3) + points.unsqueeze(1)
        # every lattice point keeps the index of the voxel it belongs to
        sampled_idx = torch.arange(points.size(0), device=points.device)
        sampled_idx = sampled_idx[:, None].expand(*sampled_xyz.size()[:2])
        sampled_idx = sampled_idx.reshape(-1)
        sampled_xyz = sampled_xyz.reshape(-1, 3)
        if sampled_xyz.shape[0] == 0:
            return
        field_inputs = get_features(
            {
                "sampled_point_xyz": sampled_xyz,
                "sampled_point_voxel_idx": sampled_idx,
                "sampled_point_ray_direction": None,
                "sampled_point_distance": None,
            },
            {
                "voxel_vertex_idx": feats,
                "voxel_center_xyz": points,
                "voxel_vertex_emb": values,
                "voxel_id2embedding_id": point_id2embedid
            },
            voxel_size
        )
        field_inputs = field_inputs["emb"]
        # evaluation with density
        sdf_values = sdf_network.get_values(field_inputs.float().cuda())
        return sdf_values.reshape(-1, res ** 3, 1).detach().cpu()

    return torch.cat([
        get_scores_once(feats[i: i + chunk_size],
                        points[i: i + chunk_size].cuda(), values, point_id2embedid)
        for i in range(0, points.size(0), chunk_size)], 0).view(-1, res, res, res, 1)
@torch.no_grad()
def eval_points(sdf_network, map_states, sampled_xyz, sampled_idx, voxel_size):
    """Evaluate the network at explicit sample positions (vertex colouring).

    Args:
        sdf_network: decoder exposing ``get_values``.
        map_states: map dicts; must include "voxel_id2embedding_id".
        sampled_xyz: (N, 3) query positions.
        sampled_idx: (N,) voxel index of each query.
        voxel_size: metric voxel edge length.

    Returns:
        (N/4, 3) CPU tensor, or None when there are no samples.
    """
    feats = map_states["voxel_vertex_idx"]
    points = map_states["voxel_center_xyz"]
    values = map_states["voxel_vertex_emb"]
    # sampled_xyz = sampled_xyz.reshape(1, 3) + points.unsqueeze(1)
    # sampled_idx = sampled_idx[None, :].expand(*sampled_xyz.size()[:2])
    sampled_idx = sampled_idx.reshape(-1)
    sampled_xyz = sampled_xyz.reshape(-1, 3)
    if sampled_xyz.shape[0] == 0:
        return
    field_inputs = get_features(
        {
            "sampled_point_xyz": sampled_xyz,
            "sampled_point_voxel_idx": sampled_idx,
            "sampled_point_ray_direction": None,
            "sampled_point_distance": None,
        },
        {
            "voxel_vertex_idx": feats,
            "voxel_center_xyz": points,
            "voxel_vertex_emb": values,
            # BUGFIX: get_features indexes map_states["voxel_id2embedding_id"];
            # omitting the key here raised a KeyError on every call
            "voxel_id2embedding_id": map_states["voxel_id2embedding_id"],
        },
        voxel_size
    )
    # evaluation with density
    sdf_values = sdf_network.get_values(field_inputs['emb'].float().cuda())
    # NOTE(review): reshaping a single-channel SDF output to (-1, 4) and
    # keeping 3 columns looks like leftover colour-head code — verify
    # against callers (mesh_util colour path).
    return sdf_values.reshape(-1, 4)[:, :3].detach().cpu()
def render_rays(
    rays_o,
    rays_d,
    map_states,
    sdf_network,
    step_size,
    voxel_size,
    truncation,
    max_voxel_hit,
    max_distance,
    chunk_size=10000,
    profiler=None,
    return_raw=False
):
    """Evaluate SDF values along rays intersecting the sparse voxel map.

    Intersects rays with the octree, samples points inside the hit voxels,
    and runs the SDF network over the samples in chunks
    (``chunk_size < 0`` processes everything in one chunk).

    Returns:
        None when no ray hits a voxel or sampling fails, otherwise a dict
        with per-sample depths ("z_vals"), SDF values ("sdf"), the per-ray
        hit mask ("ray_mask"), per-sample validity ("valid_mask") and the
        last chunk's voxel centers ("sampled_xyz").

    NOTE(review): the "miss everything" path returns the tuple (None, 0)
    while the other failure paths return bare None — callers compare
    against None only; verify the tuple path is intended.
    """
    torch.cuda.empty_cache()
    centres = map_states["voxel_center_xyz"].cuda()
    childrens = map_states["voxel_structure"].cuda()
    if profiler is not None:
        profiler.tick("ray_intersect")
    # print("Center", rays_o[0][0])
    intersections, hits = ray_intersect(
        rays_o, rays_d, centres,
        childrens, voxel_size, max_voxel_hit, max_distance)
    if profiler is not None:
        profiler.tok("ray_intersect")
    if hits.sum() <= 0:
        return
    ray_mask = hits.view(1, -1)
    # keep only rays that hit at least one voxel
    intersections = {
        name: outs[ray_mask].reshape(-1, outs.size(-1))
        for name, outs in intersections.items()
    }
    rays_o = rays_o[ray_mask].reshape(-1, 3)
    rays_d = rays_d[ray_mask].reshape(-1, 3)
    if profiler is not None:
        profiler.tick("ray_sample")
    samples = ray_sample(intersections, step_size=step_size)
    if samples == None:
        return
    if profiler is not None:
        profiler.tok("ray_sample")
    sampled_depth = samples['sampled_point_depth']
    sampled_idx = samples['sampled_point_voxel_idx'].long()
    # only compute when the ray hits
    sample_mask = sampled_idx.ne(-1)
    if sample_mask.sum() == 0:  # miss everything skip
        return None, 0
    # world-space sample positions and (normalized) ray directions
    sampled_xyz = ray(rays_o.unsqueeze(
        1), rays_d.unsqueeze(1), sampled_depth.unsqueeze(2))
    sampled_dir = rays_d.unsqueeze(1).expand(
        *sampled_depth.size(), rays_d.size()[-1])
    sampled_dir = sampled_dir / \
        (torch.norm(sampled_dir, 2, -1, keepdim=True) + 1e-8)
    samples['sampled_point_xyz'] = sampled_xyz
    samples['sampled_point_ray_direction'] = sampled_dir
    # apply mask
    samples_valid = {name: s[sample_mask] for name, s in samples.items()}
    # print("samples_valid_xyz", samples["sampled_point_xyz"].shape)
    num_points = samples_valid['sampled_point_depth'].shape[0]
    field_outputs = []
    if chunk_size < 0:
        chunk_size = num_points
    final_xyz = []
    xyz = 0
    for i in range(0, num_points, chunk_size):
        torch.cuda.empty_cache()
        chunk_samples = {name: s[i:i+chunk_size]
                         for name, s in samples_valid.items()}
        # get encoder features as inputs
        if profiler is not None:
            profiler.tick("get_features")
        chunk_inputs = get_features(chunk_samples, map_states, voxel_size)
        xyz = chunk_inputs["xyz"]
        if profiler is not None:
            profiler.tok("get_features")
        # add coordinate information
        chunk_inputs = chunk_inputs["emb"]
        # forward implicit fields
        if profiler is not None:
            profiler.tick("render_core")
        chunk_outputs = sdf_network(chunk_inputs)
        if profiler is not None:
            profiler.tok("render_core")
        final_xyz.append(xyz)
        field_outputs.append(chunk_outputs)
    # merge per-chunk outputs into flat tensors
    field_outputs = {name: torch.cat(
        [r[name] for r in field_outputs], dim=0) for name in field_outputs[0]}
    final_xyz = torch.cat(final_xyz, 0)
    outputs = field_outputs['sdf']
    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
    # NOTE(review): the gradient of the SDF w.r.t. the (last chunk's) voxel
    # centers is computed but its result is discarded — presumably only the
    # retain_graph side effect matters; verify.
    sdf_grad = grad(outputs=outputs,
                    inputs=xyz,
                    grad_outputs=d_points,
                    retain_graph=True,)
    outputs = {'sample_mask': sample_mask}
    # scatter the flat per-sample SDF values back to (rays, samples) layout
    sdf = masked_scatter_ones(sample_mask, field_outputs['sdf']).squeeze(-1)
    # depth = masked_scatter(sample_mask, field_outputs['depth'])
    # colour = torch.sigmoid(colour)
    sample_mask = outputs['sample_mask']
    valid_mask = torch.where(
        sample_mask, torch.ones_like(
            sample_mask), torch.zeros_like(sample_mask)
    )
    return {
        "z_vals": samples["sampled_point_depth"],
        "sdf": sdf,
        "ray_mask": ray_mask,
        "valid_mask": valid_mask,
        "sampled_xyz": xyz,
    }
def bundle_adjust_frames(
    keyframe_graph,
    embeddings,
    map_states,
    sdf_network,
    loss_criteria,
    voxel_size,
    step_size,
    N_rays=512,
    num_iterations=10,
    truncation=0.1,
    max_voxel_hit=10,
    max_distance=10,
    learning_rate=[1e-2, 1e-2, 5e-3],
    update_pose=True,
    update_decoder=True,
    profiler=None
):
    """Jointly optimize voxel embeddings, the decoder and keyframe poses.

    Samples N_rays per keyframe each iteration, renders them against the
    map and steps Adam on the rendering loss.

    Args:
        keyframe_graph: keyframes providing rays/points for the loss.
        embeddings: voxel vertex embedding parameters to optimize.
        learning_rate: [embeddings, decoder, pose] learning rates
            (read-only; the mutable default is never modified).
        update_pose: optimize poses of all keyframes except index 0
            (the anchor frame stays fixed).
        update_decoder: include the SDF network parameters.
    """
    if profiler is not None:
        profiler.tick("mapping_add_optim")
    optimize_params = [{'params': embeddings, 'lr': learning_rate[0]}]
    if update_decoder:
        optimize_params += [{'params': sdf_network.parameters(),
                             'lr': learning_rate[1]}]
    for keyframe in keyframe_graph:
        # the first keyframe anchors the map; its pose is never optimized
        if keyframe.index != 0 and update_pose:
            keyframe.pose.requires_grad_(True)
            optimize_params += [{
                'params': keyframe.pose.parameters(), 'lr': learning_rate[2]
            }]
    optim = torch.optim.Adam(optimize_params)
    if profiler is not None:
        profiler.tok("mapping_add_optim")
    for iter in range(num_iterations):
        torch.cuda.empty_cache()
        rays_o = []
        rays_d = []
        # NOTE(review): rgb_samples/depth_samples are collected nowhere
        # (their appends are commented out below) — dead accumulators.
        rgb_samples = []
        depth_samples = []
        points_samples = []
        pointsCos_samples = []
        # profile only the first iteration to avoid per-iter sync overhead
        if iter == 0 and profiler is not None:
            profiler.tick("mapping sample_rays")
        for frame in keyframe_graph:
            torch.cuda.empty_cache()
            pose = frame.get_pose().cuda()
            frame.sample_rays(N_rays)
            sample_mask = frame.sample_mask.cuda()
            sampled_rays_d = frame.rays_d[sample_mask].cuda()
            # print(sampled_rays_d)
            # rotate ray directions into the world frame
            R = pose[: 3, : 3].transpose(-1, -2)
            sampled_rays_d = sampled_rays_d@R
            sampled_rays_o = pose[: 3, 3].reshape(1, -1).expand_as(sampled_rays_d)
            rays_d += [sampled_rays_d]
            rays_o += [sampled_rays_o]
            points_samples += [frame.points.unsqueeze(1).cuda()[sample_mask]]
            pointsCos_samples += [frame.pointsCos.unsqueeze(1).cuda()[sample_mask]]
            # rgb_samples += [frame.rgb.cuda()[sample_mask]]
            # depth_samples += [frame.depth.cuda()[sample_mask]]
        rays_d = torch.cat(rays_d, dim=0).unsqueeze(0)
        rays_o = torch.cat(rays_o, dim=0).unsqueeze(0)
        points_samples = torch.cat(points_samples, dim=0).unsqueeze(0)
        pointsCos_samples = torch.cat(pointsCos_samples, dim=0).unsqueeze(0)
        if iter == 0 and profiler is not None:
            profiler.tok("mapping sample_rays")
        if iter == 0 and profiler is not None:
            profiler.tick("mapping rendering")
        final_outputs = render_rays(
            rays_o,
            rays_d,
            map_states,
            sdf_network,
            step_size,
            voxel_size,
            truncation,
            max_voxel_hit,
            max_distance,
            chunk_size=-1,
            profiler=profiler if iter == 0 else None
        )
        if final_outputs == None:
            # rendering failed for this batch: skip the optimizer step
            print("Encouter a bug while Mapping, currently not be fixed, Continue!!")
            hit_mask = None
            continue
        if iter == 0 and profiler is not None:
            profiler.tok("mapping rendering")
        # if final_outputs == None:
        #     continue
        if iter == 0 and profiler is not None:
            profiler.tick("mapping back proj")
        torch.cuda.empty_cache()
        loss, _ = loss_criteria(
            final_outputs, points_samples, pointsCos_samples)
        optim.zero_grad()
        loss.backward()
        optim.step()
        if iter == 0 and profiler is not None:
            profiler.tok("mapping back proj")
def track_frame(
    frame_pose,
    curr_frame,
    map_states,
    sdf_network,
    loss_criteria,
    voxel_size,
    N_rays=512,
    step_size=0.05,
    num_iterations=10,
    truncation=0.1,
    learning_rate=1e-3,
    max_voxel_hit=10,
    max_distance=10,
    profiler=None,
    depth_variance=False
):
    """Optimize a copy of ``frame_pose`` so the frame's rays fit the map SDF.

    Runs ``num_iterations`` Adam steps on the pose parameters, re-sampling
    N_rays from the frame each iteration and rendering them against the map.

    Returns:
        (optimized_pose, hit_mask) — ``hit_mask`` is the per-ray hit mask
        of the last iteration, or None when rendering failed and the loop
        aborted early.

    NOTE(review): with ``num_iterations == 0`` the loop never runs and
    ``hit_mask`` would be unbound — callers presumably always pass a
    positive iteration count; verify.
    """
    torch.cuda.empty_cache()
    # optimize a copy so the caller's pose is untouched on failure
    init_pose = deepcopy(frame_pose).cuda()
    init_pose.requires_grad_(True)
    # larger steps for the very first frames, smaller once tracking settles
    optim = torch.optim.Adam(init_pose.parameters(),
                             lr=learning_rate*2 if curr_frame.index < 2
                             else learning_rate/3)
    for iter in range(num_iterations):
        torch.cuda.empty_cache()
        # profile only the first iteration
        if iter == 0 and profiler is not None:
            profiler.tick("track sample_rays")
        curr_frame.sample_rays(N_rays, track=True)
        if iter == 0 and profiler is not None:
            profiler.tok("track sample_rays")
        sample_mask = curr_frame.sample_mask
        ray_dirs = curr_frame.rays_d[sample_mask].unsqueeze(0).cuda()
        points_samples = curr_frame.points.unsqueeze(1).cuda()[sample_mask]
        pointsCos_samples = curr_frame.pointsCos.unsqueeze(1).cuda()[sample_mask]
        # transform rays into the world frame with the current pose estimate
        ray_dirs_iter = ray_dirs.squeeze(
            0) @ init_pose.rotation().transpose(-1, -2)
        ray_dirs_iter = ray_dirs_iter.unsqueeze(0)
        ray_start_iter = init_pose.translation().reshape(
            1, 1, -1).expand_as(ray_dirs_iter).cuda().contiguous()
        if iter == 0 and profiler is not None:
            profiler.tick("track render_rays")
        final_outputs = render_rays(
            ray_start_iter,
            ray_dirs_iter,
            map_states,
            sdf_network,
            step_size,
            voxel_size,
            truncation,
            max_voxel_hit,
            max_distance,
            chunk_size=-2,
            profiler=profiler if iter == 0 else None
        )
        if final_outputs == None:
            # rendering failed: abort and signal failure via hit_mask=None
            print("Encouter a bug while Tracking, currently not be fixed, Restarting!!")
            hit_mask = None
            break
        torch.cuda.empty_cache()
        if iter == 0 and profiler is not None:
            profiler.tok("track render_rays")
        hit_mask = final_outputs["ray_mask"].view(N_rays)
        final_outputs["ray_mask"] = hit_mask
        if iter == 0 and profiler is not None:
            profiler.tick("track loss_criteria")
        loss, _ = loss_criteria(
            final_outputs, points_samples, pointsCos_samples, weight_depth_loss=depth_variance)
        if iter == 0 and profiler is not None:
            profiler.tok("track loss_criteria")
        if iter == 0 and profiler is not None:
            profiler.tick("track backward step")
        optim.zero_grad()
        loss.backward()
        optim.step()
        if iter == 0 and profiler is not None:
            profiler.tok("track backward step")
    return init_pose, hit_mask
================================================
FILE: src/variations/voxel_helpers.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
""" Modified based on: https://github.com/erikwijmans/Pointnet2_PyTorch """
from __future__ import (
division,
absolute_import,
with_statement,
print_function,
unicode_literals,
)
import os
import sys
import torch
import torch.nn.functional as F
from torch.autograd import Function
import torch.nn as nn
import sys
import numpy as np
import grid as _ext
MAX_DEPTH = 80
class BallRayIntersect(Function):
    """Autograd wrapper for the CUDA ball--ray intersection kernel.

    Forward returns, per ray, hit ball indices (presumably up to ``n_max``
    per ray — verify against the kernel) with entry/exit depths. All
    outputs are marked non-differentiable.
    """

    @staticmethod
    def forward(ctx, radius, n_max, points, ray_start, ray_dir):
        inds, min_depth, max_depth = _ext.ball_intersect(
            ray_start.float(), ray_dir.float(), points.float(), radius, n_max
        )
        # cast depths back to the caller's dtype
        min_depth = min_depth.type_as(ray_start)
        max_depth = max_depth.type_as(ray_start)
        ctx.mark_non_differentiable(inds)
        ctx.mark_non_differentiable(min_depth)
        ctx.mark_non_differentiable(max_depth)
        return inds, min_depth, max_depth

    @staticmethod
    def backward(ctx, a, b, c):
        # one None per forward input (radius, n_max, points, ray_start, ray_dir)
        return None, None, None, None, None


ball_ray_intersect = BallRayIntersect.apply
class AABBRayIntersect(Function):
    """Autograd wrapper for the CUDA axis-aligned-box ray intersection.

    Rays are processed in G groups (a memory-bounded batching hack); the
    ray count is padded to a multiple of G and the padding is dropped
    afterwards. All outputs are marked non-differentiable.
    """

    @staticmethod
    def forward(ctx, voxelsize, n_max, points, ray_start, ray_dir):
        # HACK: speed-up ray-voxel intersection by batching...
        # HACK: avoid out-of-memory
        G = min(2048, int(2 * 10 ** 9 / points.numel()))
        S, N = ray_start.shape[:2]
        K = int(np.ceil(N / G))
        H = K * G
        if H > N:
            # pad ray count to a multiple of G by repeating leading rays
            ray_start = torch.cat([ray_start, ray_start[:, : H - N]], 1)
            ray_dir = torch.cat([ray_dir, ray_dir[:, : H - N]], 1)
        ray_start = ray_start.reshape(S * G, K, 3)
        ray_dir = ray_dir.reshape(S * G, K, 3)
        points = points.expand(S * G, *points.size()[1:]).contiguous()
        inds, min_depth, max_depth = _ext.aabb_intersect(
            ray_start.float(), ray_dir.float(), points.float(), voxelsize, n_max
        )
        # cast depths back to the caller's dtype
        min_depth = min_depth.type_as(ray_start)
        max_depth = max_depth.type_as(ray_start)
        inds = inds.reshape(S, H, -1)
        min_depth = min_depth.reshape(S, H, -1)
        max_depth = max_depth.reshape(S, H, -1)
        if H > N:
            # drop the padded rays
            inds = inds[:, :N]
            min_depth = min_depth[:, :N]
            max_depth = max_depth[:, :N]
        ctx.mark_non_differentiable(inds)
        ctx.mark_non_differentiable(min_depth)
        ctx.mark_non_differentiable(max_depth)
        return inds, min_depth, max_depth

    @staticmethod
    def backward(ctx, a, b, c):
        # one None per forward input (voxelsize, n_max, points, ray_start, ray_dir)
        return None, None, None, None, None


aabb_ray_intersect = AABBRayIntersect.apply
class SparseVoxelOctreeRayIntersect(Function):
    """Autograd wrapper for the CUDA sparse-voxel-octree ray intersection.

    Rays are processed in G groups (memory-bounded batching); the ray count
    is padded to a multiple of G and the padding dropped afterwards. All
    outputs are marked non-differentiable.
    """

    @staticmethod
    def forward(ctx, voxelsize, n_max, points, children, ray_start, ray_dir):
        # HACK: avoid out-of-memory
        torch.cuda.empty_cache()
        G = min(256, int(2 * 10 ** 9 / (points.numel() + children.numel())))
        S, N = ray_start.shape[:2]
        K = int(np.ceil(N / G))
        H = K * G
        if H > N:
            # pad ray count to a multiple of G by repeating leading rays
            ray_start = torch.cat([ray_start, ray_start[:, : H - N]], 1)
            ray_dir = torch.cat([ray_dir, ray_dir[:, : H - N]], 1)
        ray_start = ray_start.reshape(S * G, K, 3)
        ray_dir = ray_dir.reshape(S * G, K, 3)
        # NOTE(review): unlike AABBRayIntersect this expands with the full
        # points.size() rather than points.size()[1:] — presumably `points`
        # and `children` arrive unbatched here; verify against callers.
        points = points.expand(S * G, *points.size()).contiguous()
        torch.cuda.empty_cache()
        children = children.expand(S * G, *children.size()).contiguous()
        torch.cuda.empty_cache()
        inds, min_depth, max_depth = _ext.svo_intersect(
            ray_start.float(),
            ray_dir.float(),
            points.float(),
            children.int(),
            voxelsize,
            n_max,
        )
        torch.cuda.empty_cache()
        # cast depths back to the caller's dtype
        min_depth = min_depth.type_as(ray_start)
        max_depth = max_depth.type_as(ray_start)
        inds = inds.reshape(S, H, -1)
        min_depth = min_depth.reshape(S, H, -1)
        max_depth = max_depth.reshape(S, H, -1)
        if H > N:
            # drop the padded rays
            inds = inds[:, :N]
            min_depth = min_depth[:, :N]
            max_depth = max_depth[:, :N]
        ctx.mark_non_differentiable(inds)
        ctx.mark_non_differentiable(min_depth)
        ctx.mark_non_differentiable(max_depth)
        return inds, min_depth, max_depth

    @staticmethod
    def backward(ctx, a, b, c):
        # BUGFIX: backward must return one gradient per forward input
        # (voxelsize, n_max, points, children, ray_start, ray_dir) = 6,
        # not 5 as before. The outputs are marked non-differentiable, so
        # this path is normally never exercised — but the arity must match.
        return None, None, None, None, None, None


svo_ray_intersect = SparseVoxelOctreeRayIntersect.apply
class TriangleRayIntersect(Function):
    """Ray / triangle-mesh intersection via the `_ext.triangle_intersect`
    CUDA kernel.

    All outputs are marked non-differentiable; this is a pure geometry lookup.
    """

    @staticmethod
    def forward(ctx, cagesize, blur_ratio, n_max, points, faces, ray_start, ray_dir):
        # HACK: speed-up ray-voxel intersection by batching...
        # HACK: avoid out-of-memory
        # G ray groups, capped by a rough 2GB budget over the face data.
        G = min(2048, int(2 * 10 ** 9 / (3 * faces.numel())))
        S, N = ray_start.shape[:2]
        K = int(np.ceil(N / G))
        H = K * G
        # Pad (by repeating leading rays) so the rays split evenly into G groups.
        if H > N:
            ray_start = torch.cat([ray_start, ray_start[:, : H - N]], 1)
            ray_dir = torch.cat([ray_dir, ray_dir[:, : H - N]], 1)
        ray_start = ray_start.reshape(S * G, K, 3)
        ray_dir = ray_dir.reshape(S * G, K, 3)
        # Gather the 3 vertex positions of every face, replicated per group.
        face_points = F.embedding(faces.reshape(-1, 3), points.reshape(-1, 3))
        face_points = (
            face_points.unsqueeze(0).expand(
                S * G, *face_points.size()).contiguous()
        )
        inds, depth, uv = _ext.triangle_intersect(
            ray_start.float(),
            ray_dir.float(),
            face_points.float(),
            cagesize,
            blur_ratio,
            n_max,
        )
        depth = depth.type_as(ray_start)
        uv = uv.type_as(ray_start)
        # Unfold back to (S, H, ...) per-ray rows and drop the padded rays.
        inds = inds.reshape(S, H, -1)
        depth = depth.reshape(S, H, -1, 3)
        uv = uv.reshape(S, H, -1)
        if H > N:
            inds = inds[:, :N]
            depth = depth[:, :N]
            uv = uv[:, :N]
        ctx.mark_non_differentiable(inds)
        ctx.mark_non_differentiable(depth)
        ctx.mark_non_differentiable(uv)
        return inds, depth, uv

    @staticmethod
    def backward(ctx, a, b, c):
        # BUGFIX: forward takes 7 inputs (cagesize, blur_ratio, n_max, points,
        # faces, ray_start, ray_dir), so backward must return 7 gradients; the
        # original returned only 6, which would raise if autograd ever invoked
        # it. All outputs are non-differentiable, so each gradient is None.
        return None, None, None, None, None, None, None


triangle_ray_intersect = TriangleRayIntersect.apply
class UniformRaySampling(Function):
    """Uniform-step point sampling along intersected rays via the
    `_ext.uniform_ray_sampling` CUDA kernel; all outputs are
    non-differentiable."""

    @staticmethod
    def forward(
        ctx,
        pts_idx,
        min_depth,
        max_depth,
        step_size,
        max_ray_length,
        deterministic=False,
    ):
        # Fold the N rays into G=256 groups, padding by repeating the first
        # rows so the count divides evenly.
        G, N, P = 256, pts_idx.size(0), pts_idx.size(1)
        H = int(np.ceil(N / G)) * G
        if H > N:
            pts_idx = torch.cat([pts_idx, pts_idx[: H - N]], 0)
            min_depth = torch.cat([min_depth, min_depth[: H - N]], 0)
            max_depth = torch.cat([max_depth, max_depth[: H - N]], 0)
        pts_idx = pts_idx.reshape(G, -1, P)
        min_depth = min_depth.reshape(G, -1, P)
        max_depth = max_depth.reshape(G, -1, P)
        # pre-generate noise: one jitter value per potential sample slot
        max_steps = int(max_ray_length / step_size)
        max_steps = max_steps + min_depth.size(-1) * 2
        noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps)
        if deterministic:
            noise += 0.5  # mid-point of each step instead of random jitter
        else:
            noise = noise.uniform_()
        # call cuda function
        sampled_idx, sampled_depth, sampled_dists = _ext.uniform_ray_sampling(
            pts_idx,
            min_depth.float(),
            max_depth.float(),
            noise.float(),
            step_size,
            max_steps,
        )
        sampled_depth = sampled_depth.type_as(min_depth)
        sampled_dists = sampled_dists.type_as(min_depth)
        # Flatten groups back to per-ray rows and drop the padding.
        sampled_idx = sampled_idx.reshape(H, -1)
        sampled_depth = sampled_depth.reshape(H, -1)
        sampled_dists = sampled_dists.reshape(H, -1)
        if H > N:
            sampled_idx = sampled_idx[:N]
            sampled_depth = sampled_depth[:N]
            sampled_dists = sampled_dists[:N]
        # Trim trailing columns that are invalid (-1) on every ray.
        max_len = sampled_idx.ne(-1).sum(-1).max()
        sampled_idx = sampled_idx[:, :max_len]
        sampled_depth = sampled_depth[:, :max_len]
        sampled_dists = sampled_dists[:, :max_len]
        ctx.mark_non_differentiable(sampled_idx)
        ctx.mark_non_differentiable(sampled_depth)
        ctx.mark_non_differentiable(sampled_dists)
        return sampled_idx, sampled_depth, sampled_dists

    @staticmethod
    def backward(ctx, a, b, c):
        # Non-differentiable sampling: one None per forward input (6).
        return None, None, None, None, None, None


uniform_ray_sampling = UniformRaySampling.apply
class InverseCDFRaySampling(Function):
    """Importance sampling of points along rays by inverting the CDF of
    per-voxel probabilities (`_ext.inverse_cdf_sampling` CUDA kernel);
    all outputs are non-differentiable."""

    @staticmethod
    def forward(
        ctx,
        pts_idx,
        min_depth,
        max_depth,
        probs,
        steps,
        fixed_step_size=-1,
        deterministic=False,
    ):
        # Fold the N rays into G=200 groups, padding with copies of ray 0 so
        # the count divides evenly.
        G, N, P = 200, pts_idx.size(0), pts_idx.size(1)
        H = int(np.ceil(N / G)) * G
        if H > N:
            pts_idx = torch.cat([pts_idx, pts_idx[:1].expand(H - N, P)], 0)
            min_depth = torch.cat(
                [min_depth, min_depth[:1].expand(H - N, P)], 0)
            max_depth = torch.cat(
                [max_depth, max_depth[:1].expand(H - N, P)], 0)
            probs = torch.cat([probs, probs[:1].expand(H - N, P)], 0)
            steps = torch.cat([steps, steps[:1].expand(H - N)], 0)
        # print(G, P, np.ceil(N / G), N, H, pts_idx.shape, min_depth.device)
        pts_idx = pts_idx.reshape(G, -1, P)
        min_depth = min_depth.reshape(G, -1, P)
        max_depth = max_depth.reshape(G, -1, P)
        probs = probs.reshape(G, -1, P)
        steps = steps.reshape(G, -1)
        # pre-generate noise: one jitter value per potential sample slot
        max_steps = steps.ceil().long().max() + P
        # print(max_steps)
        # print(*min_depth.size()[:-1]," ", max_steps)
        noise = min_depth.new_zeros(*min_depth.size()[:-1], max_steps)
        if deterministic:
            noise += 0.5
        else:
            noise = noise.uniform_().clamp(min=0.001, max=0.999)  # in case
        # call cuda function in row chunks to limit peak memory
        chunk_size = 4 * G  # to avoid oom?
        results = [
            _ext.inverse_cdf_sampling(
                pts_idx[:, i: i + chunk_size].contiguous(),
                min_depth.float()[:, i: i + chunk_size].contiguous(),
                max_depth.float()[:, i: i + chunk_size].contiguous(),
                noise.float()[:, i: i + chunk_size].contiguous(),
                probs.float()[:, i: i + chunk_size].contiguous(),
                steps.float()[:, i: i + chunk_size].contiguous(),
                fixed_step_size,
            )
            for i in range(0, min_depth.size(1), chunk_size)
        ]
        # Re-join the per-chunk (idx, depth, dists) triples along the row axis.
        sampled_idx, sampled_depth, sampled_dists = [
            torch.cat([r[i] for r in results], 1) for i in range(3)
        ]
        sampled_depth = sampled_depth.type_as(min_depth)
        sampled_dists = sampled_dists.type_as(min_depth)
        # Flatten groups back to per-ray rows and drop the padding.
        sampled_idx = sampled_idx.reshape(H, -1)
        sampled_depth = sampled_depth.reshape(H, -1)
        sampled_dists = sampled_dists.reshape(H, -1)
        if H > N:
            sampled_idx = sampled_idx[:N]
            sampled_depth = sampled_depth[:N]
            sampled_dists = sampled_dists[:N]
        # Trim trailing columns that are invalid (-1) on every ray.
        max_len = sampled_idx.ne(-1).sum(-1).max()
        sampled_idx = sampled_idx[:, :max_len]
        sampled_depth = sampled_depth[:, :max_len]
        sampled_dists = sampled_dists[:, :max_len]
        ctx.mark_non_differentiable(sampled_idx)
        ctx.mark_non_differentiable(sampled_depth)
        ctx.mark_non_differentiable(sampled_dists)
        return sampled_idx, sampled_depth, sampled_dists

    @staticmethod
    def backward(ctx, a, b, c):
        # Non-differentiable sampling: one None per forward input (7).
        return None, None, None, None, None, None, None


inverse_cdf_sampling = InverseCDFRaySampling.apply
# back-up for ray point sampling
@torch.no_grad()
def _parallel_ray_sampling(
    MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False
):
    """Pure-PyTorch uniform ray sampling (backup for the CUDA kernel).

    Draws one sample per MARCH_SIZE interval along each ray between its voxel
    entry depths (min_depth) and exit depths (max_depth), adds all voxel
    boundary points, then discards invalid samples.

    Returns:
        (sampled_idx, sampled_depth, sampled_dists) padded with -1 /
        MAX_DEPTH / 0.0 respectively for invalid trailing entries.
    """
    # uniform sampling
    _min_depth = min_depth.min(1)[0]
    _max_depth = max_depth.masked_fill(max_depth.eq(MAX_DEPTH), 0).max(1)[0]
    max_ray_length = (_max_depth - _min_depth).max()
    # one step index per MARCH_SIZE interval along the longest ray
    delta = torch.arange(
        int(max_ray_length / MARCH_SIZE), device=min_depth.device, dtype=min_depth.dtype
    )
    delta = delta[None, :].expand(min_depth.size(0), delta.size(-1))
    if deterministic:
        delta = delta + 0.5  # mid-point of every interval
    else:
        delta = delta + delta.clone().uniform_().clamp(min=0.01, max=0.99)  # jittered
    delta = delta * MARCH_SIZE
    sampled_depth = min_depth[:, :1] + delta
    # voxel slot of each sample: how many entry depths it has already passed
    sampled_idx = (sampled_depth[:, :, None] >=
                   min_depth[:, None, :]).sum(-1) - 1
    sampled_idx = pts_idx.gather(1, sampled_idx)
    # include all boundary points
    sampled_depth = torch.cat([min_depth, max_depth, sampled_depth], -1)
    sampled_idx = torch.cat([pts_idx, pts_idx, sampled_idx], -1)
    # reorder
    sampled_depth, ordered_index = sampled_depth.sort(-1)
    sampled_idx = sampled_idx.gather(1, ordered_index)
    sampled_dists = sampled_depth[:, 1:] - sampled_depth[:, :-1]  # distances
    sampled_depth = 0.5 * \
        (sampled_depth[:, 1:] + sampled_depth[:, :-1])  # mid-points
    # remove all invalid depths: a mid-point is kept only if it falls inside
    # exactly the voxel interval it indexes, before the last exit, and has a
    # non-zero segment length
    min_ids = (sampled_depth[:, :, None] >= min_depth[:, None, :]).sum(-1) - 1
    max_ids = (sampled_depth[:, :, None] >= max_depth[:, None, :]).sum(-1)
    sampled_depth.masked_fill_(
        (max_ids.ne(min_ids))
        | (sampled_depth > _max_depth[:, None])
        | (sampled_dists == 0.0),
        MAX_DEPTH,
    )
    sampled_depth, ordered_index = sampled_depth.sort(-1)  # sort again
    sampled_masks = sampled_depth.eq(MAX_DEPTH)
    # keep only up to the longest run of valid samples across rays
    num_max_steps = (~sampled_masks).sum(-1).max()
    sampled_depth = sampled_depth[:, :num_max_steps]
    sampled_dists = sampled_dists.gather(1, ordered_index).masked_fill_(
        sampled_masks, 0.0
    )[:, :num_max_steps]
    sampled_idx = sampled_idx.gather(1, ordered_index).masked_fill_(sampled_masks, -1)[
        :, :num_max_steps
    ]
    return sampled_idx, sampled_depth, sampled_dists
@torch.no_grad()
def parallel_ray_sampling(
    MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=False
):
    """Chunked driver around `_parallel_ray_sampling`.

    Splits the ray set into chunks of 4096 rows to bound peak memory, runs the
    backup sampler on each chunk, then right-pads the per-chunk results to a
    common width (-1 for indices, MAX_DEPTH for depths, 0.0 for distances).
    """
    CHUNK = 4096
    total = min_depth.shape[0]
    if total <= CHUNK:
        # Small enough: no chunking needed.
        return _parallel_ray_sampling(
            MARCH_SIZE, pts_idx, min_depth, max_depth, deterministic=deterministic
        )

    chunk_results = [
        _parallel_ray_sampling(
            MARCH_SIZE,
            pts_idx[start: start + CHUNK],
            min_depth[start: start + CHUNK],
            max_depth[start: start + CHUNK],
            deterministic=deterministic,
        )
        for start in range(0, total, CHUNK)
    ]
    idx_parts, depth_parts, dist_parts = zip(*chunk_results)

    def _pad_and_stack(parts, fill):
        # Concatenate row-wise, padding narrower chunks with `fill`.
        if len(parts) == 1:
            return parts[0]
        width = max(part.size(1) for part in parts)
        rows = sum(part.size(0) for part in parts)
        merged = parts[0].new_ones(rows, width).fill_(fill)
        row = 0
        for part in parts:
            merged[row: row + part.size(0), : part.size(1)] = part
            row += part.size(0)
        return merged

    sampled_idx = _pad_and_stack(idx_parts, -1)
    sampled_depth = _pad_and_stack(depth_parts, MAX_DEPTH)
    sampled_dists = _pad_and_stack(dist_parts, 0.0)
    return sampled_idx, sampled_depth, sampled_dists
def discretize_points(voxel_points, voxel_size):
    """Map real-valued voxel centers/corners onto an integer grid.

    Points are assumed to already lie on a regular lattice of pitch
    `voxel_size` (up to float noise).

    Args:
        voxel_points: float tensor of lattice points; reduced along dim 0.
        voxel_size: lattice pitch.

    Returns:
        (indices, residual): long-int grid coordinates relative to the
        per-axis minimum point, and the mean float offset between the points
        and `indices * voxel_size` (used to map indices back to world space).
    """
    grid_origin = voxel_points.min(dim=0, keepdim=True)[0]
    grid_idx = ((voxel_points - grid_origin) / voxel_size).round_().long()
    reconstructed = grid_idx.type_as(voxel_points) * voxel_size
    residual = (voxel_points - reconstructed).mean(0, keepdim=True)
    return grid_idx, residual
def build_easy_octree(points, half_voxel):
    """Build a sparse voxel octree over the given voxel centers.

    Args:
        points: float voxel-center coordinates (assumed (P, 3) — confirm with callers).
        half_voxel: half the voxel edge length, used as the discretisation pitch.

    Returns:
        (centers, children) as produced by the CUDA `_ext.build_octree` op,
        with `centers` converted back to float world coordinates.
    """
    coords, residual = discretize_points(points, half_voxel)
    ranges = coords.max(0)[0] - coords.min(0)[0]
    # tree depth chosen to cover the largest integer extent
    depths = torch.log2(ranges.max().float()).ceil_().long() - 1
    # integer mid-point of the occupied bounding box, used as the root center
    center = (coords.max(0)[0] + coords.min(0)[0]) / 2
    centers, children = _ext.build_octree(center, coords, int(depths))
    centers = centers.float() * half_voxel + residual  # transform back to float
    return centers, children
@torch.enable_grad()
def trilinear_interp(p, q, point_feats):
    """Blend the 8 corner feature vectors of a voxel with trilinear-style weights.

    The per-corner weight is the product over the last (coordinate) axis of
    `p*q + (1-p)*(1-q)`, where `p` is the query position and `q` the corner
    coordinate (broadcast against each other).

    Args:
        p: query positions, broadcastable against q (e.g. (V, 1, 3)).
        q: per-corner coordinates (e.g. (V or 1, 8, 3)).
        point_feats: corner features, either (V, 8, C) or flattened (V, 8*C).

    Returns:
        Weighted sum over the 8 corners, shape (V, C). Gradients are enabled
        so this stays differentiable even under no_grad callers.
    """
    corner_w = (p * q + (1 - p) * (1 - q)).prod(dim=-1, keepdim=True)
    feats = point_feats
    if feats.dim() == 2:
        # flattened layout: recover the explicit 8-corner axis
        feats = feats.view(feats.size(0), 8, -1)
    return (corner_w * feats).sum(1)
def offset_points(point_xyz, quarter_voxel=1, offset_only=False, bits=2):
    """Enumerate the bits**3 (8 for bits=2) corner offsets around each point.

    Offsets are the normalized odd lattice positions in [-1, 1]^3 scaled by
    `quarter_voxel`.

    Args:
        point_xyz: (N, 3) points to offset.
        quarter_voxel: scale applied to the unit offsets.
        offset_only: if True, return just the (bits**3, 3) offsets.
        bits: number of subdivisions per axis.

    Returns:
        (N, bits**3, 3) offset points, or (bits**3, 3) offsets when
        `offset_only` is set.
    """
    ticks = torch.arange(1, 2 * bits, 2, device=point_xyz.device)
    gx, gy, gz = torch.meshgrid([ticks, ticks, ticks], indexing='ij')
    corners = torch.stack(
        [gx.reshape(-1), gy.reshape(-1), gz.reshape(-1)], dim=1)
    offset = (corners.type_as(point_xyz) - bits) / float(bits - 1)
    if offset_only:
        return offset.type_as(point_xyz) * quarter_voxel
    return point_xyz.unsqueeze(1) + offset.unsqueeze(0).type_as(point_xyz) * quarter_voxel
def splitting_points(point_xyz, point_feats, values, half_voxel):
    """Subdivide each voxel into 8 children and interpolate child corner features.

    Args:
        point_xyz: current voxel centers.
        point_feats: per-voxel corner-feature index lists (into `values`).
        values: feature embedding table, or None to skip interpolation.
        half_voxel: half edge length of the *current* (parent) voxels.

    Returns:
        (new_points, new_feats, new_values, new_keys) describing the
        child voxels: centers, per-child corner index lists, interpolated
        corner features (or None), and unique integer corner keys.
    """
    # generate new centers
    quarter_voxel = half_voxel * 0.5
    new_points = offset_points(point_xyz, quarter_voxel).reshape(-1, 3)
    old_coords = discretize_points(point_xyz, quarter_voxel)[0]
    new_coords = offset_points(old_coords).reshape(-1, 3)
    new_keys0 = offset_points(new_coords).reshape(-1, 3)
    # get unique keys and inverse indices (for original key0, where it maps to in keys)
    new_keys, new_feats = torch.unique(
        new_keys0, dim=0, sorted=True, return_inverse=True)
    # map each unique key back to the parent voxel it came from
    # (// 64 — presumably 8 children x 8 corners per parent; confirm)
    new_keys_idx = new_feats.new_zeros(new_keys.size(0)).scatter_(
        0, new_feats, torch.arange(new_keys0.size(0), device=new_feats.device) // 64)
    # recompute key vectors using trilinear interpolation
    new_feats = new_feats.reshape(-1, 8)
    if values is not None:
        # key position relative to its parent, rescaled into [0, 1]
        # (1/4 voxel size)
        p = (new_keys - old_coords[new_keys_idx]
             ).type_as(point_xyz).unsqueeze(1) * 0.25 + 0.5
        q = offset_points(p, 0.5, offset_only=True).unsqueeze(0) + 0.5  # BUG?
        point_feats = point_feats[new_keys_idx]
        point_feats = F.embedding(point_feats, values).view(
            point_feats.size(0), -1)
        new_values = trilinear_interp(p, q, point_feats)
    else:
        new_values = None
    return new_points, new_feats, new_values, new_keys
@torch.no_grad()
def ray_intersect(ray_start, ray_dir, flatten_centers, flatten_children, voxel_size, max_hits, max_distance=MAX_DEPTH):
    """Intersect rays with the sparse voxel octree and pack sorted results.

    NOTE(review): the `max_hits` parameter is effectively unused — the kernel
    is called with a hard-coded `max_hits_temp = 20` and the name is rebound
    to the observed hit count below; consider honouring or removing it.

    Returns:
        (intersection_outputs, hits): dict with per-ray depth-sorted
        "min_depth", "max_depth" and "intersected_voxel_idx" tensors, plus a
        boolean per-ray hit mask.
    """
    # ray-voxel intersection
    max_hits_temp = 20
    pts_idx, min_depth, max_depth = svo_ray_intersect(
        voxel_size,
        max_hits_temp,
        flatten_centers,
        flatten_children,
        ray_start,
        ray_dir)
    torch.cuda.empty_cache()
    # sort the depths; misses (idx == -1) are first pushed out to max_distance
    min_depth.masked_fill_(pts_idx.eq(-1), max_distance)
    max_depth.masked_fill_(pts_idx.eq(-1), max_distance)
    min_depth, sorted_idx = min_depth.sort(dim=-1)
    max_depth = max_depth.gather(-1, sorted_idx)
    pts_idx = pts_idx.gather(-1, sorted_idx)
    # print(max_depth.max())
    # invalidate hits beyond the distance limits
    pts_idx[max_depth > 2*max_distance] = -1
    pts_idx[min_depth > max_distance] = -1
    min_depth.masked_fill_(pts_idx.eq(-1), max_distance)
    max_depth.masked_fill_(pts_idx.eq(-1), max_distance)
    # remove all points that completely miss the object
    max_hits = torch.max(pts_idx.ne(-1).sum(-1))
    min_depth = min_depth[..., :max_hits]
    max_depth = max_depth[..., :max_hits]
    pts_idx = pts_idx[..., :max_hits]
    hits = pts_idx.ne(-1).any(-1)
    intersection_outputs = {
        "min_depth": min_depth,
        "max_depth": max_depth,
        "intersected_voxel_idx": pts_idx,
    }
    return intersection_outputs, hits
@torch.no_grad()
def ray_sample(intersection_outputs, step_size=0.01, fixed=False):
    """Sample points along intersected rays via inverse-CDF sampling.

    Args:
        intersection_outputs: dict from `ray_intersect` with "min_depth",
            "max_depth" and "intersected_voxel_idx" entries (mutated here to
            also carry "probs" and "steps").
        step_size: nominal spacing between samples along a ray.
        fixed: if True, sample deterministically (no jitter).

    Returns:
        dict with "sampled_point_depth", "sampled_point_distance" and
        "sampled_point_voxel_idx" — or None when the degenerate-depth guard
        below trips, so callers must handle a None result.
    """
    # per-voxel traversal length; zero for missed (-1) entries
    dists = (
        intersection_outputs["max_depth"] -
        intersection_outputs["min_depth"]
    ).masked_fill(intersection_outputs["intersected_voxel_idx"].eq(-1), 0)
    # NOTE(review): if every dist on a ray is 0 this division would produce
    # NaN probabilities — not guarded against here.
    intersection_outputs["probs"] = dists / dists.sum(dim=-1, keepdim=True)
    intersection_outputs["steps"] = dists.sum(-1) / step_size
    # TODO:A serious BUG need to fix!
    # (author-flagged: this bare `return` yields None instead of a samples dict
    # when the accumulated ray length exceeds 10 * MAX_DEPTH)
    if dists.sum(-1).max() > 10 * MAX_DEPTH:
        return
    # sample points and use middle point approximation
    sampled_idx, sampled_depth, sampled_dists = inverse_cdf_sampling(
        intersection_outputs["intersected_voxel_idx"],
        intersection_outputs["min_depth"],
        intersection_outputs["max_depth"],
        intersection_outputs["probs"],
        intersection_outputs["steps"], -1, fixed)
    sampled_dists = sampled_dists.clamp(min=0.0)
    # invalidate padding samples
    sampled_depth.masked_fill_(sampled_idx.eq(-1), MAX_DEPTH)
    sampled_dists.masked_fill_(sampled_idx.eq(-1), 0.0)
    samples = {
        "sampled_point_depth": sampled_depth,
        "sampled_point_distance": sampled_dists,
        "sampled_point_voxel_idx": sampled_idx,
    }
    return samples
================================================
FILE: third_party/marching_cubes/setup.py
================================================
# Build script for the `marching_cubes` CUDA extension
# (src/mc.cpp + src/*.cu kernels).
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import glob

# Compile every C++/CUDA translation unit under src/ into one module.
_ext_sources = glob.glob("src/*.cpp") + glob.glob("src/*.cu")

setup(
    name='marching_cubes',
    ext_modules=[
        CUDAExtension(
            name='marching_cubes',
            sources=_ext_sources,
            extra_compile_args={
                # headers live under include/, not on the default search path
                "cxx": ["-O2", "-I./include"],
                "nvcc": ["-I./include"]
            },
        )
    ],
    cmdclass={
        # BuildExtension drives nvcc and adds the torch include paths
        'build_ext': BuildExtension
    }
)
================================================
FILE: third_party/marching_cubes/src/mc.cpp
================================================
#include
std::vector marching_cubes_sparse(
torch::Tensor indexer, // (nx, ny, nz) -> data_id
torch::Tensor valid_blocks, // (K, )
torch::Tensor vec_batch_mapping, //
torch::Tensor cube_sdf, // (M, rx, ry, rz)
torch::Tensor cube_std, // (M, rx, ry, rz)
const std::vector &n_xyz, // [nx, ny, nz]
float max_std, // Prune all vertices
int max_n_triangles // Maximum number of triangle buffer.
);
std::vector marching_cubes_sparse_colour(
torch::Tensor indexer, // (nx, ny, nz) -> data_id
torch::Tensor valid_blocks, // (K, )
torch::Tensor vec_batch_mapping, //
torch::Tensor cube_sdf, // (M, rx, ry, rz, 4)
torch::Tensor cube_colour, // (M, rx, ry, rz)
const std::vector &n_xyz, // [nx, ny, nz]
float max_std, // Prune all vertices
int max_n_triangles // Maximum number of triangle buffer.
);
std::vector marching_cubes_sparse_interp_cuda(
torch::Tensor indexer, // (nx, ny, nz) -> data_id
torch::Tensor valid_blocks, // (K, )
torch::Tensor vec_batch_mapping, //
torch::Tensor cube_sdf, // (M, rx, ry, rz)
torch::Tensor cube_std, // (M, rx, ry, rz)
const std::vector &n_xyz, // [nx, ny, nz]
float max_std, // Prune all vertices
int max_n_triangles // Maximum number of triangle buffer.
);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("marching_cubes_sparse", &marching_cubes_sparse, "Marching Cubes without Interpolation (CUDA)");
m.def("marching_cubes_sparse_colour", &marching_cubes_sparse_colour, "Marching Cubes without Interpolation (CUDA)");
m.def("marching_cubes_sparse_interp", &marching_cubes_sparse_interp_cuda, "Marching Cubes with Interpolation (CUDA)");
}
================================================
FILE: third_party/marching_cubes/src/mc_data.cuh
================================================
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

// Input-validation helpers for tensors handed to the CUDA kernels.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// NOTE(review): the template arguments below were stripped by text extraction
// and are reconstructed from how the accessors are indexed in the kernels
// (e.g. `long long v = indexer[bx][by][bz]`, 4-D float cube_sdf lookups in
// mc_interp_kernel.cu) — confirm every scalar type and rank against upstream.
using IndexerAccessor = torch::PackedTensorAccessor32<long, 3, torch::RestrictPtrTraits>;
using ValidBlocksAccessor = torch::PackedTensorAccessor32<long, 1, torch::RestrictPtrTraits>;
using BackwardMappingAccessor = torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits>;
using CubeSDFAccessor = torch::PackedTensorAccessor32<float, 4, torch::RestrictPtrTraits>;
using CubeSDFRGBAccessor = torch::PackedTensorAccessor32<float, 5, torch::RestrictPtrTraits>;
using TrianglesAccessor = torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits>;
using TriangleStdAccessor = torch::PackedTensorAccessor32<float, 1, torch::RestrictPtrTraits>;
using TriangleVecIdAccessor = torch::PackedTensorAccessor32<long, 1, torch::RestrictPtrTraits>;
// Convenience constructor: widen a float3 with an extra w component.
__inline__ __device__ float4 make_float4(const float3& xyz, float w) {
    return make_float4(xyz.x, xyz.y, xyz.z, w);
}

// Component-wise arithmetic for float2, used in this extension for the
// (sdf, std) value pairs returned by the query helpers.
inline __host__ __device__ void operator+=(float2 &a, const float2& b) {
    a.x += b.x; a.y += b.y;
}
inline __host__ __device__ float2 operator*(const float2& a, float b) {
    return make_float2(a.x * b, a.y * b);
}
inline __host__ __device__ float2 operator/(const float2& a, float b) {
    return make_float2(a.x / b, a.y / b);
}
inline __host__ __device__ float2 operator/(const float2& a, const float2& b) {
    return make_float2(a.x / b.x, a.y / b.y);
}
__constant__ int edgeTable[256] = { 0x0, 0x109, 0x203, 0x30a, 0x406, 0x50f, 0x605, 0x70c, 0x80c, 0x905, 0xa0f, 0xb06, 0xc0a, 0xd03, 0xe09, 0xf00,
0x190, 0x99, 0x393, 0x29a, 0x596, 0x49f, 0x795, 0x69c, 0x99c, 0x895, 0xb9f, 0xa96, 0xd9a, 0xc93, 0xf99, 0xe90, 0x230, 0x339, 0x33, 0x13a,
0x636, 0x73f, 0x435, 0x53c, 0xa3c, 0xb35, 0x83f, 0x936, 0xe3a, 0xf33, 0xc39, 0xd30, 0x3a0, 0x2a9, 0x1a3, 0xaa, 0x7a6, 0x6af, 0x5a5, 0x4ac,
0xbac, 0xaa5, 0x9af, 0x8a6, 0xfaa, 0xea3, 0xda9, 0xca0, 0x460, 0x569, 0x663, 0x76a, 0x66, 0x16f, 0x265, 0x36c, 0xc6c, 0xd65, 0xe6f, 0xf66,
0x86a, 0x963, 0xa69, 0xb60, 0x5f0, 0x4f9, 0x7f3, 0x6fa, 0x1f6, 0xff, 0x3f5, 0x2fc, 0xdfc, 0xcf5, 0xfff, 0xef6, 0x9fa, 0x8f3, 0xbf9, 0xaf0,
0x650, 0x759, 0x453, 0x55a, 0x256, 0x35f, 0x55, 0x15c, 0xe5c, 0xf55, 0xc5f, 0xd56, 0xa5a, 0xb53, 0x859, 0x950, 0x7c0, 0x6c9, 0x5c3, 0x4ca,
0x3c6, 0x2cf, 0x1c5, 0xcc, 0xfcc, 0xec5, 0xdcf, 0xcc6, 0xbca, 0xac3, 0x9c9, 0x8c0, 0x8c0, 0x9c9, 0xac3, 0xbca, 0xcc6, 0xdcf, 0xec5, 0xfcc,
0xcc, 0x1c5, 0x2cf, 0x3c6, 0x4ca, 0x5c3, 0x6c9, 0x7c0, 0x950, 0x859, 0xb53, 0xa5a, 0xd56, 0xc5f, 0xf55, 0xe5c, 0x15c, 0x55, 0x35f, 0x256,
0x55a, 0x453, 0x759, 0x650, 0xaf0, 0xbf9, 0x8f3, 0x9fa, 0xef6, 0xfff, 0xcf5, 0xdfc, 0x2fc, 0x3f5, 0xff, 0x1f6, 0x6fa, 0x7f3, 0x4f9, 0x5f0,
0xb60, 0xa69, 0x963, 0x86a, 0xf66, 0xe6f, 0xd65, 0xc6c, 0x36c, 0x265, 0x16f, 0x66, 0x76a, 0x663, 0x569, 0x460, 0xca0, 0xda9, 0xea3, 0xfaa,
0x8a6, 0x9af, 0xaa5, 0xbac, 0x4ac, 0x5a5, 0x6af, 0x7a6, 0xaa, 0x1a3, 0x2a9, 0x3a0, 0xd30, 0xc39, 0xf33, 0xe3a, 0x936, 0x83f, 0xb35, 0xa3c,
0x53c, 0x435, 0x73f, 0x636, 0x13a, 0x33, 0x339, 0x230, 0xe90, 0xf99, 0xc93, 0xd9a, 0xa96, 0xb9f, 0x895, 0x99c, 0x69c, 0x795, 0x49f, 0x596,
0x29a, 0x393, 0x99, 0x190, 0xf00, 0xe09, 0xd03, 0xc0a, 0xb06, 0xa0f, 0x905, 0x80c, 0x70c, 0x605, 0x50f, 0x406, 0x30a, 0x203, 0x109, 0x0 };
__constant__ int triangleTable[256][16] = { { -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 8, 3, 9, 8, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 2, 10, 0, 2, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 2, 8, 3, 2, 10, 8, 10, 9, 8, -1, -1, -1, -1, -1, -1, -1 }, { 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 11, 2, 8, 11, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 9, 0, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 11, 2, 1, 9, 11, 9, 8, 11, -1, -1, -1, -1, -1, -1, -1 }, { 3, 10, 1, 11, 10, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 10, 1, 0, 8, 10, 8, 11, 10, -1, -1, -1, -1, -1, -1, -1 }, { 3, 9, 0, 3, 11, 9, 11, 10, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 3, 0, 7, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 1, 9, 4, 7, 1, 7, 3, 1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 4, 7, 3, 0, 4, 1, 2, 10, -1, -1, -1, -1, -1, -1, -1 }, { 9, 2, 10, 9, 0, 2, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 2, 10, 9, 2, 9, 7, 2, 7, 3, 7, 9, 4, -1, -1, -1, -1 }, { 8, 4, 7, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 4, 7, 11, 2, 4, 2, 0, 4, -1, -1, -1, -1, -1, -1, -1 }, { 9, 0, 1, 8, 4, 7, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 7, 11, 9, 4, 11, 9, 11, 2, 9, 2, 1, -1, -1, -1, -1 }, { 3, 10, 1, 3, 11, 10, 7, 8, 4, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 11, 10, 1, 4, 11, 1, 0, 4, 7, 11, 4, -1, -1, -1, -1 }, { 4, 7, 8, 9, 0, 11, 9, 11, 10, 11, 0, 3, -1, -1, -1, -1 },
{ 4, 7, 11, 4, 11, 9, 9, 11, 10, -1, -1, -1, -1, -1, -1, -1 }, { 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 5, 4, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 5, 4, 1, 5, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 5, 4, 8, 3, 5, 3, 1, 5, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 8, 1, 2, 10, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1 }, { 5, 2, 10, 5, 4, 2, 4, 0, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 2, 10, 5, 3, 2, 5, 3, 5, 4, 3, 4, 8, -1, -1, -1, -1 }, { 9, 5, 4, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 11, 2, 0, 8, 11, 4, 9, 5, -1, -1, -1, -1, -1, -1, -1 }, { 0, 5, 4, 0, 1, 5, 2, 3, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 2, 1, 5, 2, 5, 8, 2, 8, 11, 4, 8, 5, -1, -1, -1, -1 }, { 10, 3, 11, 10, 1, 3, 9, 5, 4, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 9, 5, 0, 8, 1, 8, 10, 1, 8, 11, 10, -1, -1, -1, -1 }, { 5, 4, 0, 5, 0, 11, 5, 11, 10, 11, 0, 3, -1, -1, -1, -1 },
{ 5, 4, 8, 5, 8, 10, 10, 8, 11, -1, -1, -1, -1, -1, -1, -1 }, { 9, 7, 8, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 3, 0, 9, 5, 3, 5, 7, 3, -1, -1, -1, -1, -1, -1, -1 }, { 0, 7, 8, 0, 1, 7, 1, 5, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 7, 8, 9, 5, 7, 10, 1, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 1, 2, 9, 5, 0, 5, 3, 0, 5, 7, 3, -1, -1, -1, -1 }, { 8, 0, 2, 8, 2, 5, 8, 5, 7, 10, 5, 2, -1, -1, -1, -1 },
{ 2, 10, 5, 2, 5, 3, 3, 5, 7, -1, -1, -1, -1, -1, -1, -1 }, { 7, 9, 5, 7, 8, 9, 3, 11, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 5, 7, 9, 7, 2, 9, 2, 0, 2, 7, 11, -1, -1, -1, -1 }, { 2, 3, 11, 0, 1, 8, 1, 7, 8, 1, 5, 7, -1, -1, -1, -1 },
{ 11, 2, 1, 11, 1, 7, 7, 1, 5, -1, -1, -1, -1, -1, -1, -1 }, { 9, 5, 8, 8, 5, 7, 10, 1, 3, 10, 3, 11, -1, -1, -1, -1 },
{ 5, 7, 0, 5, 0, 9, 7, 11, 0, 1, 0, 10, 11, 10, 0, -1 }, { 11, 10, 0, 11, 0, 3, 10, 5, 0, 8, 0, 7, 5, 7, 0, -1 },
{ 11, 10, 5, 7, 11, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 0, 1, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 8, 3, 1, 9, 8, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1 }, { 1, 6, 5, 2, 6, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 6, 5, 1, 2, 6, 3, 0, 8, -1, -1, -1, -1, -1, -1, -1 }, { 9, 6, 5, 9, 0, 6, 0, 2, 6, -1, -1, -1, -1, -1, -1, -1 },
{ 5, 9, 8, 5, 8, 2, 5, 2, 6, 3, 2, 8, -1, -1, -1, -1 }, { 2, 3, 11, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 0, 8, 11, 2, 0, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1, -1, -1, -1 },
{ 5, 10, 6, 1, 9, 2, 9, 11, 2, 9, 8, 11, -1, -1, -1, -1 }, { 6, 3, 11, 6, 5, 3, 5, 1, 3, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 11, 0, 11, 5, 0, 5, 1, 5, 11, 6, -1, -1, -1, -1 }, { 3, 11, 6, 0, 3, 6, 0, 6, 5, 0, 5, 9, -1, -1, -1, -1 },
{ 6, 5, 9, 6, 9, 11, 11, 9, 8, -1, -1, -1, -1, -1, -1, -1 }, { 5, 10, 6, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 3, 0, 4, 7, 3, 6, 5, 10, -1, -1, -1, -1, -1, -1, -1 }, { 1, 9, 0, 5, 10, 6, 8, 4, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 6, 5, 1, 9, 7, 1, 7, 3, 7, 9, 4, -1, -1, -1, -1 }, { 6, 1, 2, 6, 5, 1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 2, 5, 5, 2, 6, 3, 0, 4, 3, 4, 7, -1, -1, -1, -1 }, { 8, 4, 7, 9, 0, 5, 0, 6, 5, 0, 2, 6, -1, -1, -1, -1 },
{ 7, 3, 9, 7, 9, 4, 3, 2, 9, 5, 9, 6, 2, 6, 9, -1 }, { 3, 11, 2, 7, 8, 4, 10, 6, 5, -1, -1, -1, -1, -1, -1, -1 },
{ 5, 10, 6, 4, 7, 2, 4, 2, 0, 2, 7, 11, -1, -1, -1, -1 }, { 0, 1, 9, 4, 7, 8, 2, 3, 11, 5, 10, 6, -1, -1, -1, -1 },
{ 9, 2, 1, 9, 11, 2, 9, 4, 11, 7, 11, 4, 5, 10, 6, -1 }, { 8, 4, 7, 3, 11, 5, 3, 5, 1, 5, 11, 6, -1, -1, -1, -1 },
{ 5, 1, 11, 5, 11, 6, 1, 0, 11, 7, 11, 4, 0, 4, 11, -1 }, { 0, 5, 9, 0, 6, 5, 0, 3, 6, 11, 6, 3, 8, 4, 7, -1 },
{ 6, 5, 9, 6, 9, 11, 4, 7, 9, 7, 11, 9, -1, -1, -1, -1 }, { 10, 4, 9, 6, 4, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 10, 6, 4, 9, 10, 0, 8, 3, -1, -1, -1, -1, -1, -1, -1 }, { 10, 0, 1, 10, 6, 0, 6, 4, 0, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 3, 1, 8, 1, 6, 8, 6, 4, 6, 1, 10, -1, -1, -1, -1 }, { 1, 4, 9, 1, 2, 4, 2, 6, 4, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 8, 1, 2, 9, 2, 4, 9, 2, 6, 4, -1, -1, -1, -1 }, { 0, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 3, 2, 8, 2, 4, 4, 2, 6, -1, -1, -1, -1, -1, -1, -1 }, { 10, 4, 9, 10, 6, 4, 11, 2, 3, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 2, 2, 8, 11, 4, 9, 10, 4, 10, 6, -1, -1, -1, -1 }, { 3, 11, 2, 0, 1, 6, 0, 6, 4, 6, 1, 10, -1, -1, -1, -1 },
{ 6, 4, 1, 6, 1, 10, 4, 8, 1, 2, 1, 11, 8, 11, 1, -1 }, { 9, 6, 4, 9, 3, 6, 9, 1, 3, 11, 6, 3, -1, -1, -1, -1 },
{ 8, 11, 1, 8, 1, 0, 11, 6, 1, 9, 1, 4, 6, 4, 1, -1 }, { 3, 11, 6, 3, 6, 0, 0, 6, 4, -1, -1, -1, -1, -1, -1, -1 },
{ 6, 4, 8, 11, 6, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 7, 10, 6, 7, 8, 10, 8, 9, 10, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 7, 3, 0, 10, 7, 0, 9, 10, 6, 7, 10, -1, -1, -1, -1 }, { 10, 6, 7, 1, 10, 7, 1, 7, 8, 1, 8, 0, -1, -1, -1, -1 },
{ 10, 6, 7, 10, 7, 1, 1, 7, 3, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 6, 1, 6, 8, 1, 8, 9, 8, 6, 7, -1, -1, -1, -1 },
{ 2, 6, 9, 2, 9, 1, 6, 7, 9, 0, 9, 3, 7, 3, 9, -1 }, { 7, 8, 0, 7, 0, 6, 6, 0, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 7, 3, 2, 6, 7, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 3, 11, 10, 6, 8, 10, 8, 9, 8, 6, 7, -1, -1, -1, -1 },
{ 2, 0, 7, 2, 7, 11, 0, 9, 7, 6, 7, 10, 9, 10, 7, -1 }, { 1, 8, 0, 1, 7, 8, 1, 10, 7, 6, 7, 10, 2, 3, 11, -1 },
{ 11, 2, 1, 11, 1, 7, 10, 6, 1, 6, 7, 1, -1, -1, -1, -1 }, { 8, 9, 6, 8, 6, 7, 9, 1, 6, 11, 6, 3, 1, 3, 6, -1 },
{ 0, 9, 1, 11, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 7, 8, 0, 7, 0, 6, 3, 11, 0, 11, 6, 0, -1, -1, -1, -1 },
{ 7, 11, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 8, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 9, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 1, 9, 8, 3, 1, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1 }, { 10, 1, 2, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 2, 10, 3, 0, 8, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1 }, { 2, 9, 0, 2, 10, 9, 6, 11, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 6, 11, 7, 2, 10, 3, 10, 8, 3, 10, 9, 8, -1, -1, -1, -1 }, { 7, 2, 3, 6, 2, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 7, 0, 8, 7, 6, 0, 6, 2, 0, -1, -1, -1, -1, -1, -1, -1 }, { 2, 7, 6, 2, 3, 7, 0, 1, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 6, 2, 1, 8, 6, 1, 9, 8, 8, 7, 6, -1, -1, -1, -1 }, { 10, 7, 6, 10, 1, 7, 1, 3, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 7, 6, 1, 7, 10, 1, 8, 7, 1, 0, 8, -1, -1, -1, -1 }, { 0, 3, 7, 0, 7, 10, 0, 10, 9, 6, 10, 7, -1, -1, -1, -1 },
{ 7, 6, 10, 7, 10, 8, 8, 10, 9, -1, -1, -1, -1, -1, -1, -1 }, { 6, 8, 4, 11, 8, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 6, 11, 3, 0, 6, 0, 4, 6, -1, -1, -1, -1, -1, -1, -1 }, { 8, 6, 11, 8, 4, 6, 9, 0, 1, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 4, 6, 9, 6, 3, 9, 3, 1, 11, 3, 6, -1, -1, -1, -1 }, { 6, 8, 4, 6, 11, 8, 2, 10, 1, -1, -1, -1, -1, -1, -1, -1 },
{ 1, 2, 10, 3, 0, 11, 0, 6, 11, 0, 4, 6, -1, -1, -1, -1 }, { 4, 11, 8, 4, 6, 11, 0, 2, 9, 2, 10, 9, -1, -1, -1, -1 },
{ 10, 9, 3, 10, 3, 2, 9, 4, 3, 11, 3, 6, 4, 6, 3, -1 }, { 8, 2, 3, 8, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 4, 2, 4, 6, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 9, 0, 2, 3, 4, 2, 4, 6, 4, 3, 8, -1, -1, -1, -1 },
{ 1, 9, 4, 1, 4, 2, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1 }, { 8, 1, 3, 8, 6, 1, 8, 4, 6, 6, 10, 1, -1, -1, -1, -1 },
{ 10, 1, 0, 10, 0, 6, 6, 0, 4, -1, -1, -1, -1, -1, -1, -1 }, { 4, 6, 3, 4, 3, 8, 6, 10, 3, 0, 3, 9, 10, 9, 3, -1 },
{ 10, 9, 4, 6, 10, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 9, 5, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 4, 9, 5, 11, 7, 6, -1, -1, -1, -1, -1, -1, -1 }, { 5, 0, 1, 5, 4, 0, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 7, 6, 8, 3, 4, 3, 5, 4, 3, 1, 5, -1, -1, -1, -1 }, { 9, 5, 4, 10, 1, 2, 7, 6, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 6, 11, 7, 1, 2, 10, 0, 8, 3, 4, 9, 5, -1, -1, -1, -1 }, { 7, 6, 11, 5, 4, 10, 4, 2, 10, 4, 0, 2, -1, -1, -1, -1 },
{ 3, 4, 8, 3, 5, 4, 3, 2, 5, 10, 5, 2, 11, 7, 6, -1 }, { 7, 2, 3, 7, 6, 2, 5, 4, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 5, 4, 0, 8, 6, 0, 6, 2, 6, 8, 7, -1, -1, -1, -1 }, { 3, 6, 2, 3, 7, 6, 1, 5, 0, 5, 4, 0, -1, -1, -1, -1 },
{ 6, 2, 8, 6, 8, 7, 2, 1, 8, 4, 8, 5, 1, 5, 8, -1 }, { 9, 5, 4, 10, 1, 6, 1, 7, 6, 1, 3, 7, -1, -1, -1, -1 },
{ 1, 6, 10, 1, 7, 6, 1, 0, 7, 8, 7, 0, 9, 5, 4, -1 }, { 4, 0, 10, 4, 10, 5, 0, 3, 10, 6, 10, 7, 3, 7, 10, -1 },
{ 7, 6, 10, 7, 10, 8, 5, 4, 10, 4, 8, 10, -1, -1, -1, -1 }, { 6, 9, 5, 6, 11, 9, 11, 8, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 6, 11, 0, 6, 3, 0, 5, 6, 0, 9, 5, -1, -1, -1, -1 }, { 0, 11, 8, 0, 5, 11, 0, 1, 5, 5, 6, 11, -1, -1, -1, -1 },
{ 6, 11, 3, 6, 3, 5, 5, 3, 1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 10, 9, 5, 11, 9, 11, 8, 11, 5, 6, -1, -1, -1, -1 },
{ 0, 11, 3, 0, 6, 11, 0, 9, 6, 5, 6, 9, 1, 2, 10, -1 }, { 11, 8, 5, 11, 5, 6, 8, 0, 5, 10, 5, 2, 0, 2, 5, -1 },
{ 6, 11, 3, 6, 3, 5, 2, 10, 3, 10, 5, 3, -1, -1, -1, -1 }, { 5, 8, 9, 5, 2, 8, 5, 6, 2, 3, 8, 2, -1, -1, -1, -1 },
{ 9, 5, 6, 9, 6, 0, 0, 6, 2, -1, -1, -1, -1, -1, -1, -1 }, { 1, 5, 8, 1, 8, 0, 5, 6, 8, 3, 8, 2, 6, 2, 8, -1 },
{ 1, 5, 6, 2, 1, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 3, 6, 1, 6, 10, 3, 8, 6, 5, 6, 9, 8, 9, 6, -1 },
{ 10, 1, 0, 10, 0, 6, 9, 5, 0, 5, 6, 0, -1, -1, -1, -1 }, { 0, 3, 8, 5, 6, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 11, 5, 10, 7, 5, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 5, 10, 11, 7, 5, 8, 3, 0, -1, -1, -1, -1, -1, -1, -1 }, { 5, 11, 7, 5, 10, 11, 1, 9, 0, -1, -1, -1, -1, -1, -1, -1 },
{ 10, 7, 5, 10, 11, 7, 9, 8, 1, 8, 3, 1, -1, -1, -1, -1 }, { 11, 1, 2, 11, 7, 1, 7, 5, 1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 1, 2, 7, 1, 7, 5, 7, 2, 11, -1, -1, -1, -1 }, { 9, 7, 5, 9, 2, 7, 9, 0, 2, 2, 11, 7, -1, -1, -1, -1 },
{ 7, 5, 2, 7, 2, 11, 5, 9, 2, 3, 2, 8, 9, 8, 2, -1 }, { 2, 5, 10, 2, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1 },
{ 8, 2, 0, 8, 5, 2, 8, 7, 5, 10, 2, 5, -1, -1, -1, -1 }, { 9, 0, 1, 5, 10, 3, 5, 3, 7, 3, 10, 2, -1, -1, -1, -1 },
{ 9, 8, 2, 9, 2, 1, 8, 7, 2, 10, 2, 5, 7, 5, 2, -1 }, { 1, 3, 5, 3, 7, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 7, 0, 7, 1, 1, 7, 5, -1, -1, -1, -1, -1, -1, -1 }, { 9, 0, 3, 9, 3, 5, 5, 3, 7, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 8, 7, 5, 9, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 5, 8, 4, 5, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1 },
{ 5, 0, 4, 5, 11, 0, 5, 10, 11, 11, 3, 0, -1, -1, -1, -1 }, { 0, 1, 9, 8, 4, 10, 8, 10, 11, 10, 4, 5, -1, -1, -1, -1 },
{ 10, 11, 4, 10, 4, 5, 11, 3, 4, 9, 4, 1, 3, 1, 4, -1 }, { 2, 5, 1, 2, 8, 5, 2, 11, 8, 4, 5, 8, -1, -1, -1, -1 },
{ 0, 4, 11, 0, 11, 3, 4, 5, 11, 2, 11, 1, 5, 1, 11, -1 }, { 0, 2, 5, 0, 5, 9, 2, 11, 5, 4, 5, 8, 11, 8, 5, -1 },
{ 9, 4, 5, 2, 11, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 5, 10, 3, 5, 2, 3, 4, 5, 3, 8, 4, -1, -1, -1, -1 },
{ 5, 10, 2, 5, 2, 4, 4, 2, 0, -1, -1, -1, -1, -1, -1, -1 }, { 3, 10, 2, 3, 5, 10, 3, 8, 5, 4, 5, 8, 0, 1, 9, -1 },
{ 5, 10, 2, 5, 2, 4, 1, 9, 2, 9, 4, 2, -1, -1, -1, -1 }, { 8, 4, 5, 8, 5, 3, 3, 5, 1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 4, 5, 1, 0, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 8, 4, 5, 8, 5, 3, 9, 0, 5, 0, 3, 5, -1, -1, -1, -1 },
{ 9, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 11, 7, 4, 9, 11, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 8, 3, 4, 9, 7, 9, 11, 7, 9, 10, 11, -1, -1, -1, -1 }, { 1, 10, 11, 1, 11, 4, 1, 4, 0, 7, 4, 11, -1, -1, -1, -1 },
{ 3, 1, 4, 3, 4, 8, 1, 10, 4, 7, 4, 11, 10, 11, 4, -1 }, { 4, 11, 7, 9, 11, 4, 9, 2, 11, 9, 1, 2, -1, -1, -1, -1 },
{ 9, 7, 4, 9, 11, 7, 9, 1, 11, 2, 11, 1, 0, 8, 3, -1 }, { 11, 7, 4, 11, 4, 2, 2, 4, 0, -1, -1, -1, -1, -1, -1, -1 },
{ 11, 7, 4, 11, 4, 2, 8, 3, 4, 3, 2, 4, -1, -1, -1, -1 }, { 2, 9, 10, 2, 7, 9, 2, 3, 7, 7, 4, 9, -1, -1, -1, -1 },
{ 9, 10, 7, 9, 7, 4, 10, 2, 7, 8, 7, 0, 2, 0, 7, -1 }, { 3, 7, 10, 3, 10, 2, 7, 4, 10, 1, 10, 0, 4, 0, 10, -1 },
{ 1, 10, 2, 8, 7, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 4, 9, 1, 4, 1, 7, 7, 1, 3, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 9, 1, 4, 1, 7, 0, 8, 1, 8, 7, 1, -1, -1, -1, -1 }, { 4, 0, 3, 7, 4, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 4, 8, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 9, 10, 8, 10, 11, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 9, 3, 9, 11, 11, 9, 10, -1, -1, -1, -1, -1, -1, -1 }, { 0, 1, 10, 0, 10, 8, 8, 10, 11, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 1, 10, 11, 3, 10, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 2, 11, 1, 11, 9, 9, 11, 8, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 0, 9, 3, 9, 11, 1, 2, 9, 2, 11, 9, -1, -1, -1, -1 }, { 0, 2, 11, 8, 0, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 3, 2, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 3, 8, 2, 8, 10, 10, 8, 9, -1, -1, -1, -1, -1, -1, -1 },
{ 9, 10, 2, 0, 9, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 2, 3, 8, 2, 8, 10, 0, 1, 8, 1, 10, 8, -1, -1, -1, -1 },
{ 1, 10, 2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 1, 3, 8, 9, 1, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ 0, 9, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 }, { 0, 3, 8, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 },
{ -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 } };
================================================
FILE: third_party/marching_cubes/src/mc_interp_kernel.cu
================================================
#include "mc_data.cuh"
#include
#include
#include
// Fetch the stored (sdf, stddev) sample for sub-voxel (arx, ary, arz) of
// block (bx, by, bz).  Returns (NAN, NAN) when the block is outside the
// indexer volume, unallocated, or not part of the current batch.
__device__ static inline float2 query_sdf_raw(uint bx, uint by, uint bz, uint arx, uint ary, uint arz,
                                              const uint max_vec_num,
                                              const IndexerAccessor indexer,
                                              const CubeSDFAccessor cube_sdf,
                                              const CubeSDFAccessor cube_std,
                                              const BackwardMappingAccessor vec_batch_mapping)
{
    // Outside the sparse indexer volume.
    const bool in_bounds = bx < indexer.size(0) && by < indexer.size(1) && bz < indexer.size(2);
    if (!in_bounds)
    {
        return make_float2(NAN, NAN);
    }
    // Block never allocated, or its id exceeds the mapping table.
    const long long flat_id = indexer[bx][by][bz];
    if (flat_id == -1 || flat_id >= max_vec_num)
    {
        return make_float2(NAN, NAN);
    }
    // Block allocated but absent from the batch currently on the GPU.
    const int row = vec_batch_mapping[flat_id];
    if (row == -1)
    {
        return make_float2(NAN, NAN);
    }
    return make_float2(cube_sdf[row][arx][ary][arz],
                       cube_std[row][arx][ary][arz]);
}
// Use stddev to weight sdf value.
// #define STD_W_SDF
// Tri-linearly interpolated (sdf, stddev) lookup for the "interp" meshing
// variant.  SDF samples are treated as sitting at sub-voxel centres, so the
// 8 samples bracketing (bpos, rpos) may straddle neighbouring blocks.  Per
// axis the code picks a minus/plus sample pair:
//   b?m / b?p : block-index offsets of the minus/plus sample,
//   r?m / r?p : sub-voxel-index offsets (applied after rpos is shifted by r/2),
//   w_?m / w_?p : linear interpolation weights (normalised by r).
// Missing samples (NAN from query_sdf_raw) are dropped from the weighted
// average — except the corner selected by `zero_det`, which is the one
// corner whose block offsets are all zero (i.e. it lies inside the current
// block); if even that one is missing the lookup fails with (NAN, NAN).
__device__ static inline float2 get_sdf(const uint3 bsize, const uint r, uint3 bpos, uint3 rpos, const uint max_vec_num,
                                        const IndexerAccessor indexer,
                                        const CubeSDFAccessor cube_sdf,
                                        const CubeSDFAccessor cube_std,
                                        const BackwardMappingAccessor vec_batch_mapping)
{
    // Defensive clamp: an out-of-range block collapses onto the last block's
    // last sub-voxel.
    if (bpos.x >= bsize.x)
    {
        bpos.x = bsize.x - 1;
        rpos.x = r - 1;
    }
    if (bpos.y >= bsize.y)
    {
        bpos.y = bsize.y - 1;
        rpos.y = r - 1;
    }
    if (bpos.z >= bsize.z)
    {
        bpos.z = bsize.z - 1;
        rpos.z = r - 1;
    }
    uint rbound = (r - 1) / 2; // largest sub-voxel index still in the lower half
    uint rstart = r / 2;       // shift applied below into the raw sample grid
    float rmid = r / 2.0f;     // centre offset used by the linear weights
    // --- x axis: pick minus/plus sample offsets and weights ---
    float w_xm, w_xp;
    int bxm, rxm, bxp, rxp;
    int zero_x;
    if (rpos.x <= rbound)
    {
        // Lower half: the minus sample comes from the previous block.
        bxm = -1;
        rxm = r;
        bxp = 0;
        rxp = 0;
        w_xp = (float)rpos.x + rmid;
        w_xm = rmid - (float)rpos.x;
        zero_x = 1; // the "plus" sample is the in-block one
    }
    else
    {
        // Upper half: the plus sample comes from the next block.
        bxm = 0;
        rxm = 0;
        bxp = 1;
        rxp = -r;
        w_xp = (float)rpos.x - rmid;
        w_xm = rmid + r - (float)rpos.x;
        zero_x = 0; // the "minus" sample is the in-block one
    }
    w_xm /= r;
    w_xp /= r;
    // --- y axis (same scheme) ---
    float w_ym, w_yp;
    int bym, rym, byp, ryp;
    int zero_y;
    if (rpos.y <= rbound)
    {
        bym = -1;
        rym = r;
        byp = 0;
        ryp = 0;
        w_yp = (float)rpos.y + rmid;
        w_ym = rmid - (float)rpos.y;
        zero_y = 1;
    }
    else
    {
        bym = 0;
        rym = 0;
        byp = 1;
        ryp = -r;
        w_yp = (float)rpos.y - rmid;
        w_ym = rmid + r - (float)rpos.y;
        zero_y = 0;
    }
    w_ym /= r;
    w_yp /= r;
    // --- z axis (same scheme) ---
    float w_zm, w_zp;
    int bzm, rzm, bzp, rzp;
    int zero_z;
    if (rpos.z <= rbound)
    {
        bzm = -1;
        rzm = r;
        bzp = 0;
        rzp = 0;
        w_zp = (float)rpos.z + rmid;
        w_zm = rmid - (float)rpos.z;
        zero_z = 1;
    }
    else
    {
        bzm = 0;
        rzm = 0;
        bzp = 1;
        rzp = -r;
        w_zp = (float)rpos.z - rmid;
        w_zm = rmid + r - (float)rpos.z;
        zero_z = 0;
    }
    w_zm /= r;
    w_zp /= r;
    // Shift into the raw sample grid; the stored volume appears to be 2r per
    // side (the meshing kernel uses cube_sdf.size(1) / 2 as r).
    rpos.x += rstart;
    rpos.y += rstart;
    rpos.z += rstart;
    // printf("%u %u %u %d %d %d %d %d %d\n", rpos.x, rpos.y, rpos.z, rxm, rxp, rym, ryp, rzm, rzp);
    // Tri-linear interpolation of SDF values.
#ifndef STD_W_SDF
    float total_weight = 0.0;
#else
    // STD_W_SDF: additionally weight every sample by its stddev.
    float2 total_weight{0.0, 0.0};
#endif
    float2 total_sdf{0.0, 0.0};
    // Encodes which of the 8 corners lies inside the current (valid) block.
    int zero_det = zero_x * 4 + zero_y * 2 + zero_z;
    // Corner (-,-,-).  Note: bpos.? + b??m may wrap around (unsigned 0 - 1);
    // the wrapped index is rejected by query_sdf_raw's bounds check.
    float2 sdfmmm = query_sdf_raw(bpos.x + bxm, bpos.y + bym, bpos.z + bzm, rpos.x + rxm, rpos.y + rym, rpos.z + rzm,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wmmm = w_xm * w_ym * w_zm;
#ifndef STD_W_SDF
    if (!isnan(sdfmmm.x))
    {
        total_sdf += sdfmmm * wmmm;
        total_weight += wmmm;
    }
#else
    if (!isnan(sdfmmm.x))
    {
        total_sdf.x += sdfmmm.x * wmmm * sdfmmm.y;
        total_weight.x += wmmm * sdfmmm.y;
        total_sdf.y += wmmm * sdfmmm.y;
        total_weight.y += wmmm;
    }
#endif
    // This else pairs with whichever isnan-if the preprocessor kept above:
    // the in-block corner (zero_det match) must exist, or the lookup fails.
    else if (zero_det == 0)
    {
        return make_float2(NAN, NAN);
    }
    // Corner (-,-,+).
    float2 sdfmmp = query_sdf_raw(bpos.x + bxm, bpos.y + bym, bpos.z + bzp, rpos.x + rxm, rpos.y + rym, rpos.z + rzp,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wmmp = w_xm * w_ym * w_zp;
#ifndef STD_W_SDF
    if (!isnan(sdfmmp.x))
    {
        total_sdf += sdfmmp * wmmp;
        total_weight += wmmp;
    }
#else
    if (!isnan(sdfmmp.x))
    {
        total_sdf.x += sdfmmp.x * wmmp * sdfmmp.y;
        total_weight.x += wmmp * sdfmmp.y;
        total_sdf.y += wmmp * sdfmmp.y;
        total_weight.y += wmmp;
    }
#endif
    else if (zero_det == 1)
    {
        return make_float2(NAN, NAN);
    }
    // Corner (-,+,-).
    float2 sdfmpm = query_sdf_raw(bpos.x + bxm, bpos.y + byp, bpos.z + bzm, rpos.x + rxm, rpos.y + ryp, rpos.z + rzm,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wmpm = w_xm * w_yp * w_zm;
#ifndef STD_W_SDF
    if (!isnan(sdfmpm.x))
    {
        total_sdf += sdfmpm * wmpm;
        total_weight += wmpm;
    }
#else
    if (!isnan(sdfmpm.x))
    {
        total_sdf.x += sdfmpm.x * wmpm * sdfmpm.y;
        total_weight.x += wmpm * sdfmpm.y;
        total_sdf.y += wmpm * sdfmpm.y;
        total_weight.y += wmpm;
    }
#endif
    else if (zero_det == 2)
    {
        return make_float2(NAN, NAN);
    }
    // Corner (-,+,+).
    float2 sdfmpp = query_sdf_raw(bpos.x + bxm, bpos.y + byp, bpos.z + bzp, rpos.x + rxm, rpos.y + ryp, rpos.z + rzp,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wmpp = w_xm * w_yp * w_zp;
#ifndef STD_W_SDF
    if (!isnan(sdfmpp.x))
    {
        total_sdf += sdfmpp * wmpp;
        total_weight += wmpp;
    }
#else
    if (!isnan(sdfmpp.x))
    {
        total_sdf.x += sdfmpp.x * wmpp * sdfmpp.y;
        total_weight.x += wmpp * sdfmpp.y;
        total_sdf.y += wmpp * sdfmpp.y;
        total_weight.y += wmpp;
    }
#endif
    else if (zero_det == 3)
    {
        return make_float2(NAN, NAN);
    }
    // Corner (+,-,-).
    float2 sdfpmm = query_sdf_raw(bpos.x + bxp, bpos.y + bym, bpos.z + bzm, rpos.x + rxp, rpos.y + rym, rpos.z + rzm,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wpmm = w_xp * w_ym * w_zm;
#ifndef STD_W_SDF
    if (!isnan(sdfpmm.x))
    {
        total_sdf += sdfpmm * wpmm;
        total_weight += wpmm;
    }
#else
    if (!isnan(sdfpmm.x))
    {
        total_sdf.x += sdfpmm.x * wpmm * sdfpmm.y;
        total_weight.x += wpmm * sdfpmm.y;
        total_sdf.y += wpmm * sdfpmm.y;
        total_weight.y += wpmm;
    }
#endif
    else if (zero_det == 4)
    {
        return make_float2(NAN, NAN);
    }
    // Corner (+,-,+).
    float2 sdfpmp = query_sdf_raw(bpos.x + bxp, bpos.y + bym, bpos.z + bzp, rpos.x + rxp, rpos.y + rym, rpos.z + rzp,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wpmp = w_xp * w_ym * w_zp;
#ifndef STD_W_SDF
    if (!isnan(sdfpmp.x))
    {
        total_sdf += sdfpmp * wpmp;
        total_weight += wpmp;
    }
#else
    if (!isnan(sdfpmp.x))
    {
        total_sdf.x += sdfpmp.x * wpmp * sdfpmp.y;
        total_weight.x += wpmp * sdfpmp.y;
        total_sdf.y += wpmp * sdfpmp.y;
        total_weight.y += wpmp;
    }
#endif
    else if (zero_det == 5)
    {
        return make_float2(NAN, NAN);
    }
    // Corner (+,+,-).
    float2 sdfppm = query_sdf_raw(bpos.x + bxp, bpos.y + byp, bpos.z + bzm, rpos.x + rxp, rpos.y + ryp, rpos.z + rzm,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wppm = w_xp * w_yp * w_zm;
#ifndef STD_W_SDF
    if (!isnan(sdfppm.x))
    {
        total_sdf += sdfppm * wppm;
        total_weight += wppm;
    }
#else
    if (!isnan(sdfppm.x))
    {
        total_sdf.x += sdfppm.x * wppm * sdfppm.y;
        total_weight.x += wppm * sdfppm.y;
        total_sdf.y += wppm * sdfppm.y;
        total_weight.y += wppm;
    }
#endif
    else if (zero_det == 6)
    {
        return make_float2(NAN, NAN);
    }
    // Corner (+,+,+).
    float2 sdfppp = query_sdf_raw(bpos.x + bxp, bpos.y + byp, bpos.z + bzp, rpos.x + rxp, rpos.y + ryp, rpos.z + rzp,
                                  max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    float wppp = w_xp * w_yp * w_zp;
#ifndef STD_W_SDF
    if (!isnan(sdfppp.x))
    {
        total_sdf += sdfppp * wppp;
        total_weight += wppp;
    }
#else
    if (!isnan(sdfppp.x))
    {
        total_sdf.x += sdfppp.x * wppp * sdfppp.y;
        total_weight.x += wppp * sdfppp.y;
        total_sdf.y += wppp * sdfppp.y;
        total_weight.y += wppp;
    }
#endif
    else if (zero_det == 7)
    {
        return make_float2(NAN, NAN);
    }
    // If NAN, will also be handled (total_weight == 0 yields NAN downstream).
    return total_sdf / total_weight;
}
// Locate the linear zero crossing of the SDF along segment p1-p2 and blend
// the per-endpoint stddevs with the same weights.
// Returns (x, y, z, interpolated_std).
__device__ static inline float4 sdf_interp(const float3 p1, const float3 p2, const float stdp1, const float stdp2,
                                           float valp1, float valp2)
{
    // Degenerate cases: an endpoint already (almost) on the surface, or a
    // numerically flat segment — snap to an endpoint.
    if (fabs(0.0f - valp1) < 1.0e-5f)
        return make_float4(p1, stdp1);
    if (fabs(0.0f - valp2) < 1.0e-5f)
        return make_float4(p2, stdp2);
    if (fabs(valp1 - valp2) < 1.0e-5f)
        return make_float4(p1, stdp1);
    // p = p1 + t * (p2 - p1) with t chosen so the SDF interpolates to zero.
    const float t = (0.0f - valp1) / (valp2 - valp1);
    const float s = 1 - t;
    return make_float4(p1.x * s + p2.x * t,
                       p1.y * s + p2.y * t,
                       p1.z * s + p2.z * t,
                       stdp1 * s + stdp2 * t);
}
// Marching-cubes meshing kernel, interpolated variant.
//
// Launch layout: the x dimension (gridDim.x * blockDim.x) indexes entries of
// valid_blocks; the y dimension indexes the r^3 sub-voxel cells of a block.
// Each thread meshes one cell:
//   1. gather the 8 corner (sdf, std) values via the tri-linear get_sdf
//      (any missing corner aborts the cell),
//   2. classify the sign pattern against edgeTable/triangleTable,
//   3. interpolate a vertex on every crossed edge,
//   4. reserve output slots with atomicAdd on triangles_count and write the
//      triangles.  Slots past max_triangles_count are counted but not
//      written, so the host can detect buffer overflow afterwards.
// Triangles touching any vertex whose stddev exceeds max_std are pruned
// before a slot is reserved.
__global__ static void meshing_cube(const IndexerAccessor indexer,
                                    const ValidBlocksAccessor valid_blocks,
                                    const BackwardMappingAccessor vec_batch_mapping,
                                    const CubeSDFAccessor cube_sdf,
                                    const CubeSDFAccessor cube_std,
                                    TrianglesAccessor triangles,
                                    TriangleStdAccessor triangle_std,
                                    TriangleVecIdAccessor triangle_flatten_id,
                                    int *__restrict__ triangles_count,
                                    int max_triangles_count,
                                    const uint max_vec_num,
                                    int nx, int ny, int nz,
                                    float max_std)
{
    // Mesh at half the stored per-block sample resolution; get_sdf re-centres
    // indices by r/2 into the full-width sample volume.
    const uint r = cube_sdf.size(1) / 2;
    const uint r3 = r * r * r;
    const uint num_lif = valid_blocks.size(0);
    const float sbs = 1.0f / r; // sub-block-size
    const uint lif_id = blockIdx.x * blockDim.x + threadIdx.x;
    const uint sub_id = blockIdx.y * blockDim.y + threadIdx.y;
    if (lif_id >= num_lif || sub_id >= r3)
    {
        return;
    }
    // Un-flatten the block id into 3D block coordinates.
    const uint3 bpos = make_uint3(
        (valid_blocks[lif_id] / (ny * nz)) % nx,
        (valid_blocks[lif_id] / nz) % ny,
        valid_blocks[lif_id] % nz);
    const uint3 bsize = make_uint3(indexer.size(0), indexer.size(1), indexer.size(2));
    // Un-flatten the sub-voxel id inside the block.
    const uint rx = sub_id / (r * r);
    const uint ry = (sub_id / r) % r;
    const uint rz = sub_id % r;
    // Find all 8 neighbours.  Corner numbering matches the convention of the
    // edgeTable/triangleTable lookup tables; abort on any missing sample.
    float3 points[8];
    float2 sdf_vals[8];
    sdf_vals[0] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[0].x))
        return;
    points[0] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);
    sdf_vals[1] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[1].x))
        return;
    points[1] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);
    sdf_vals[2] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[2].x))
        return;
    points[2] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);
    sdf_vals[3] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[3].x))
        return;
    points[3] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);
    sdf_vals[4] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[4].x))
        return;
    points[4] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);
    sdf_vals[5] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[5].x))
        return;
    points[5] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);
    sdf_vals[6] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[6].x))
        return;
    points[6] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);
    sdf_vals[7] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[7].x))
        return;
    points[7] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);
    // Find triangle config: bit i of cube_type is set iff corner i is inside
    // the surface (negative SDF).
    int cube_type = 0;
    if (sdf_vals[0].x < 0)
        cube_type |= 1;
    if (sdf_vals[1].x < 0)
        cube_type |= 2;
    if (sdf_vals[2].x < 0)
        cube_type |= 4;
    if (sdf_vals[3].x < 0)
        cube_type |= 8;
    if (sdf_vals[4].x < 0)
        cube_type |= 16;
    if (sdf_vals[5].x < 0)
        cube_type |= 32;
    if (sdf_vals[6].x < 0)
        cube_type |= 64;
    if (sdf_vals[7].x < 0)
        cube_type |= 128;
    // Find vertex position on each edge (weighted by sdf value).  Bit k of
    // edge_config marks that edge k is crossed by the surface.
    int edge_config = edgeTable[cube_type];
    float4 vert_list[12];
    if (edge_config == 0)
        return;
    if (edge_config & 1)
        vert_list[0] = sdf_interp(points[0], points[1], sdf_vals[0].y, sdf_vals[1].y, sdf_vals[0].x, sdf_vals[1].x);
    if (edge_config & 2)
        vert_list[1] = sdf_interp(points[1], points[2], sdf_vals[1].y, sdf_vals[2].y, sdf_vals[1].x, sdf_vals[2].x);
    if (edge_config & 4)
        vert_list[2] = sdf_interp(points[2], points[3], sdf_vals[2].y, sdf_vals[3].y, sdf_vals[2].x, sdf_vals[3].x);
    if (edge_config & 8)
        vert_list[3] = sdf_interp(points[3], points[0], sdf_vals[3].y, sdf_vals[0].y, sdf_vals[3].x, sdf_vals[0].x);
    if (edge_config & 16)
        vert_list[4] = sdf_interp(points[4], points[5], sdf_vals[4].y, sdf_vals[5].y, sdf_vals[4].x, sdf_vals[5].x);
    if (edge_config & 32)
        vert_list[5] = sdf_interp(points[5], points[6], sdf_vals[5].y, sdf_vals[6].y, sdf_vals[5].x, sdf_vals[6].x);
    if (edge_config & 64)
        vert_list[6] = sdf_interp(points[6], points[7], sdf_vals[6].y, sdf_vals[7].y, sdf_vals[6].x, sdf_vals[7].x);
    if (edge_config & 128)
        vert_list[7] = sdf_interp(points[7], points[4], sdf_vals[7].y, sdf_vals[4].y, sdf_vals[7].x, sdf_vals[4].x);
    if (edge_config & 256)
        vert_list[8] = sdf_interp(points[0], points[4], sdf_vals[0].y, sdf_vals[4].y, sdf_vals[0].x, sdf_vals[4].x);
    if (edge_config & 512)
        vert_list[9] = sdf_interp(points[1], points[5], sdf_vals[1].y, sdf_vals[5].y, sdf_vals[1].x, sdf_vals[5].x);
    if (edge_config & 1024)
        vert_list[10] = sdf_interp(points[2], points[6], sdf_vals[2].y, sdf_vals[6].y, sdf_vals[2].x, sdf_vals[6].x);
    if (edge_config & 2048)
        vert_list[11] = sdf_interp(points[3], points[7], sdf_vals[3].y, sdf_vals[7].y, sdf_vals[3].x, sdf_vals[7].x);
    // Write triangles to array (triangleTable rows are -1 terminated).
    float4 vp[3];
    for (int i = 0; triangleTable[cube_type][i] != -1; i += 3)
    {
#pragma unroll
        for (int vi = 0; vi < 3; ++vi)
        {
            vp[vi] = vert_list[triangleTable[cube_type][i + vi]];
        }
        // Prune uncertain triangles BEFORE reserving a slot, so pruned
        // triangles do not consume buffer space.
        if (vp[0].w > max_std || vp[1].w > max_std || vp[2].w > max_std)
        {
            continue;
        }
        int triangle_id = atomicAdd(triangles_count, 1);
        if (triangle_id < max_triangles_count)
        {
#pragma unroll
            for (int vi = 0; vi < 3; ++vi)
            {
                triangles[triangle_id][vi][0] = vp[vi].x;
                triangles[triangle_id][vi][1] = vp[vi].y;
                triangles[triangle_id][vi][2] = vp[vi].z;
                triangle_std[triangle_id][vi] = vp[vi].w;
            }
            // Remember which flat block id produced this triangle.
            triangle_flatten_id[triangle_id] = valid_blocks[lif_id];
        }
    }
}
/**
 * Host entry for the interpolated sparse marching cubes.
 *
 * indexer           : (nx, ny, nz) tensor mapping block coords -> block id (-1 = empty)
 * valid_blocks      : (K,) flattened ids of the blocks to mesh
 * vec_batch_mapping : block id -> batch row in cube_sdf/cube_std (-1 = absent)
 * cube_sdf/cube_std : (M, rx, ry, rz) per-block SDF / stddev sample volumes
 * n_xyz             : [nx, ny, nz] grid extents used to un-flatten valid_blocks
 * max_std           : triangles touching a vertex with larger stddev are pruned
 * max_n_triangles   : capacity of the output buffers (must be > 0)
 *
 * Returns {triangles (T, 3, 3), triangle_flatten_id (T,), triangle_std (T, 3)}.
 *
 * NOTE(review): the angle-bracket text (return type, packed_accessor32
 * template arguments, launch configuration, element type of n_xyz) was lost
 * in extraction and has been reconstructed — the accessor template arguments
 * must match the accessor typedefs in mc_data.cuh; verify against them.
 */
std::vector<torch::Tensor> marching_cubes_sparse_interp_cuda(
    torch::Tensor indexer,           // (nx, ny, nz) -> data_id
    torch::Tensor valid_blocks,      // (K, )
    torch::Tensor vec_batch_mapping, //
    torch::Tensor cube_sdf,          // (M, rx, ry, rz)
    torch::Tensor cube_std,          // (M, rx, ry, rz)
    const std::vector<int> &n_xyz,   // [nx, ny, nz]
    float max_std,                   // Prune all vertices
    int max_n_triangles              // Maximum number of triangle buffer
)
{
    CHECK_INPUT(indexer);
    CHECK_INPUT(valid_blocks);
    CHECK_INPUT(cube_sdf);
    CHECK_INPUT(cube_std);
    CHECK_INPUT(vec_batch_mapping);
    assert(max_n_triangles > 0);
    // The kernel meshes at half the stored sample resolution.
    const int r = cube_sdf.size(1) / 2;
    const int r3 = r * r * r;
    const int num_lif = valid_blocks.size(0);
    const uint max_vec_num = vec_batch_mapping.size(0);
    torch::Tensor triangles = torch::empty({max_n_triangles, 3, 3},
                                           torch::dtype(torch::kFloat32).device(torch::kCUDA));
    torch::Tensor triangle_flatten_id = torch::empty({max_n_triangles}, torch::dtype(torch::kLong).device(torch::kCUDA));
    torch::Tensor triangle_std = torch::empty({max_n_triangles, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
    // Launching with a zero-sized grid is a CUDA error: bail out early with
    // empty outputs when there is nothing to mesh.
    if (num_lif == 0 || r3 == 0)
    {
        auto none = torch::indexing::Slice(torch::indexing::None, 0);
        return {triangles.index({none}), triangle_flatten_id.index({none}), triangle_std.index({none})};
    }
    dim3 dimBlock = dim3(16, 16);
    uint xBlocks = (num_lif + dimBlock.x - 1) / dimBlock.x;
    uint yBlocks = (r3 + dimBlock.y - 1) / dimBlock.y;
    dim3 dimGrid = dim3(xBlocks, yBlocks);
    // Device-side triangle counter, zero-initialised.
    thrust::device_vector<int> n_output(1, 0);
    meshing_cube<<<dimGrid, dimBlock>>>(
        indexer.packed_accessor32<long, 3, torch::RestrictPtrTraits>(),
        valid_blocks.packed_accessor32<long, 1, torch::RestrictPtrTraits>(),
        vec_batch_mapping.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
        cube_sdf.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        cube_std.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        triangles.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        triangle_std.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
        triangle_flatten_id.packed_accessor32<long, 1, torch::RestrictPtrTraits>(),
        n_output.data().get(), max_n_triangles, max_vec_num,
        n_xyz[0], n_xyz[1], n_xyz[2], max_std);
    // Kernel launches fail silently; surface configuration errors explicitly.
    cudaError_t launch_err = cudaGetLastError();
    if (launch_err != cudaSuccess)
    {
        std::cerr << "marching_cubes_sparse_interp_cuda launch failed: "
                  << cudaGetErrorString(launch_err) << std::endl;
    }
    cudaStreamSynchronize(at::cuda::getCurrentCUDAStream());
    int output_n_triangles = n_output[0];
    if (output_n_triangles < max_n_triangles)
    {
        // Trim output tensor if it is not full.
        triangles = triangles.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
        triangle_flatten_id = triangle_flatten_id.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
        triangle_std = triangle_std.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
    }
    else
    {
        // Otherwise spawn a warning: the kernel dropped the surplus triangles.
        std::cerr << "Warning from marching cube: the max triangle number is too small " << output_n_triangles << " vs " << max_n_triangles << std::endl;
    }
    return {triangles, triangle_flatten_id, triangle_std};
}
================================================
FILE: third_party/marching_cubes/src/mc_kernel.cu
================================================
#include "mc_data.cuh"
#include
#include
#include
// Fetch the (sdf, stddev) sample stored for sub-voxel (arx, ary, arz) of
// block (bx, by, bz).  Yields (NAN, NAN) when the block falls outside the
// indexer volume, was never allocated, or is absent from the current batch.
__device__ static inline float2 query_sdf_raw(uint bx, uint by, uint bz, uint arx, uint ary, uint arz,
                                              const uint max_vec_num,
                                              const IndexerAccessor indexer,
                                              const CubeSDFAccessor cube_sdf,
                                              const CubeSDFAccessor cube_std,
                                              const BackwardMappingAccessor vec_batch_mapping)
{
    // Reject coordinates outside the sparse indexer volume.
    if (bx >= indexer.size(0) || by >= indexer.size(1) || bz >= indexer.size(2))
        return make_float2(NAN, NAN);
    // Reject unallocated blocks and ids beyond the mapping table.
    const long long flat_id = indexer[bx][by][bz];
    if (flat_id == -1 || flat_id >= max_vec_num)
        return make_float2(NAN, NAN);
    // Reject blocks not resident in the current batch.
    const int row = vec_batch_mapping[flat_id];
    if (row == -1)
        return make_float2(NAN, NAN);
    return make_float2(cube_sdf[row][arx][ary][arz],
                       cube_std[row][arx][ary][arz]);
}
// Use stddev to weight sdf value.
// #define STD_W_SDF
// Resolve a (block, sub-voxel) coordinate to a single raw sample.  The block
// index is clamped into the grid, and a sub-voxel index of exactly r rolls
// over to sub-voxel 0 of the neighbouring block along that axis.  Missing
// data propagates as (NAN, NAN) from query_sdf_raw.
__device__ static inline float2 get_sdf(const uint3 bsize, const uint r, uint3 bpos, uint3 rpos, const uint max_vec_num,
                                        const IndexerAccessor indexer,
                                        const CubeSDFAccessor cube_sdf,
                                        const CubeSDFAccessor cube_std,
                                        const BackwardMappingAccessor vec_batch_mapping)
{
    // Clamp out-of-grid blocks onto the last block's last sub-voxel.
    if (bpos.x >= bsize.x) { bpos.x = bsize.x - 1; rpos.x = r - 1; }
    if (bpos.y >= bsize.y) { bpos.y = bsize.y - 1; rpos.y = r - 1; }
    if (bpos.z >= bsize.z) { bpos.z = bsize.z - 1; rpos.z = r - 1; }
    // Carry a full sub-voxel index into the next block (which may be out of
    // bounds; query_sdf_raw then reports NAN).
    if (rpos.x == r) { ++bpos.x; rpos.x = 0; }
    if (rpos.y == r) { ++bpos.y; rpos.y = 0; }
    if (rpos.z == r) { ++bpos.z; rpos.z = 0; }
    return query_sdf_raw(bpos.x, bpos.y, bpos.z, rpos.x, rpos.y, rpos.z,
                         max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
}
// Find the linear zero crossing of the SDF on segment p1-p2, blending the
// endpoint stddevs with the same weights.  Returns (x, y, z, std).
__device__ static inline float4 sdf_interp(const float3 p1, const float3 p2, const float stdp1, const float stdp2,
                                           float valp1, float valp2)
{
    // Endpoint already on the surface, or flat segment: snap to an endpoint.
    if (fabs(0.0f - valp1) < 1.0e-5f)
        return make_float4(p1, stdp1);
    if (fabs(0.0f - valp2) < 1.0e-5f)
        return make_float4(p2, stdp2);
    if (fabs(valp1 - valp2) < 1.0e-5f)
        return make_float4(p1, stdp1);
    // Choose t so that the interpolated SDF value is exactly zero.
    const float t = (0.0f - valp1) / (valp2 - valp1);
    const float s = 1 - t;
    return make_float4(p1.x * s + p2.x * t,
                       p1.y * s + p2.y * t,
                       p1.z * s + p2.z * t,
                       stdp1 * s + stdp2 * t);
}
// Marching-cubes meshing kernel (nearest-sample variant — corner values come
// straight from the stored grid, no tri-linear interpolation).
//
// Launch layout: the x dimension indexes valid_blocks entries, the y
// dimension the r^3 sub-voxel cells of a block.  Each thread classifies one
// cell against edgeTable/triangleTable and appends the resulting triangles
// through the global atomic counter triangles_count; slots past
// max_triangles_count are counted but not written so the host can detect
// overflow.  Note: unlike the interp variant, max_std is accepted here for
// signature parity but is never used — no stddev-based pruning happens.
__global__ static void meshing_cube(const IndexerAccessor indexer,
                                    const ValidBlocksAccessor valid_blocks,
                                    const BackwardMappingAccessor vec_batch_mapping,
                                    const CubeSDFAccessor cube_sdf,
                                    const CubeSDFAccessor cube_std,
                                    TrianglesAccessor triangles,
                                    TriangleStdAccessor triangle_std,
                                    TriangleVecIdAccessor triangle_flatten_id,
                                    int *__restrict__ triangles_count,
                                    int max_triangles_count,
                                    const uint max_vec_num,
                                    int nx, int ny, int nz,
                                    float max_std)
{
    // Full stored resolution is meshed here (contrast: /2 in the interp file).
    const uint r = cube_sdf.size(1);
    const uint r3 = r * r * r;
    const uint num_lif = valid_blocks.size(0);
    const float sbs = 1.0f / r; // sub-block-size
    const uint lif_id = blockIdx.x * blockDim.x + threadIdx.x;
    const uint sub_id = blockIdx.y * blockDim.y + threadIdx.y;
    if (lif_id >= num_lif || sub_id >= r3)
    {
        return;
    }
    // Un-flatten the block id into 3D block coordinates.
    const uint3 bpos = make_uint3(
        (valid_blocks[lif_id] / (ny * nz)) % nx,
        (valid_blocks[lif_id] / nz) % ny,
        valid_blocks[lif_id] % nz);
    const uint3 bsize = make_uint3(indexer.size(0), indexer.size(1), indexer.size(2));
    // Un-flatten the sub-voxel id inside the block.
    const uint rx = sub_id / (r * r);
    const uint ry = (sub_id / r) % r;
    const uint rz = sub_id % r;
    // Find all 8 neighbours (corner numbering matches the lookup tables);
    // abort the cell if any sample is missing.
    float3 points[8];
    float2 sdf_vals[8];
    sdf_vals[0] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[0].x))
        return;
    points[0] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);
    sdf_vals[1] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[1].x))
        return;
    points[1] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);
    sdf_vals[2] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[2].x))
        return;
    points[2] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);
    sdf_vals[3] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[3].x))
        return;
    points[3] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);
    sdf_vals[4] = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[4].x))
        return;
    points[4] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);
    sdf_vals[5] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[5].x))
        return;
    points[5] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);
    sdf_vals[6] = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[6].x))
        return;
    points[6] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);
    sdf_vals[7] = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
    if (isnan(sdf_vals[7].x))
        return;
    points[7] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);
    // Find triangle config: bit i of cube_type is set iff corner i is inside
    // the surface (negative SDF).
    int cube_type = 0;
    if (sdf_vals[0].x < 0)
        cube_type |= 1;
    if (sdf_vals[1].x < 0)
        cube_type |= 2;
    if (sdf_vals[2].x < 0)
        cube_type |= 4;
    if (sdf_vals[3].x < 0)
        cube_type |= 8;
    if (sdf_vals[4].x < 0)
        cube_type |= 16;
    if (sdf_vals[5].x < 0)
        cube_type |= 32;
    if (sdf_vals[6].x < 0)
        cube_type |= 64;
    if (sdf_vals[7].x < 0)
        cube_type |= 128;
    // Find vertex position on each edge (weighted by sdf value); bit k of
    // edge_config marks edge k as crossed by the surface.
    int edge_config = edgeTable[cube_type];
    float4 vert_list[12];
    if (edge_config == 0)
        return;
    if (edge_config & 1)
        vert_list[0] = sdf_interp(points[0], points[1], sdf_vals[0].y, sdf_vals[1].y, sdf_vals[0].x, sdf_vals[1].x);
    if (edge_config & 2)
        vert_list[1] = sdf_interp(points[1], points[2], sdf_vals[1].y, sdf_vals[2].y, sdf_vals[1].x, sdf_vals[2].x);
    if (edge_config & 4)
        vert_list[2] = sdf_interp(points[2], points[3], sdf_vals[2].y, sdf_vals[3].y, sdf_vals[2].x, sdf_vals[3].x);
    if (edge_config & 8)
        vert_list[3] = sdf_interp(points[3], points[0], sdf_vals[3].y, sdf_vals[0].y, sdf_vals[3].x, sdf_vals[0].x);
    if (edge_config & 16)
        vert_list[4] = sdf_interp(points[4], points[5], sdf_vals[4].y, sdf_vals[5].y, sdf_vals[4].x, sdf_vals[5].x);
    if (edge_config & 32)
        vert_list[5] = sdf_interp(points[5], points[6], sdf_vals[5].y, sdf_vals[6].y, sdf_vals[5].x, sdf_vals[6].x);
    if (edge_config & 64)
        vert_list[6] = sdf_interp(points[6], points[7], sdf_vals[6].y, sdf_vals[7].y, sdf_vals[6].x, sdf_vals[7].x);
    if (edge_config & 128)
        vert_list[7] = sdf_interp(points[7], points[4], sdf_vals[7].y, sdf_vals[4].y, sdf_vals[7].x, sdf_vals[4].x);
    if (edge_config & 256)
        vert_list[8] = sdf_interp(points[0], points[4], sdf_vals[0].y, sdf_vals[4].y, sdf_vals[0].x, sdf_vals[4].x);
    if (edge_config & 512)
        vert_list[9] = sdf_interp(points[1], points[5], sdf_vals[1].y, sdf_vals[5].y, sdf_vals[1].x, sdf_vals[5].x);
    if (edge_config & 1024)
        vert_list[10] = sdf_interp(points[2], points[6], sdf_vals[2].y, sdf_vals[6].y, sdf_vals[2].x, sdf_vals[6].x);
    if (edge_config & 2048)
        vert_list[11] = sdf_interp(points[3], points[7], sdf_vals[3].y, sdf_vals[7].y, sdf_vals[3].x, sdf_vals[7].x);
    // Write triangles to array (triangleTable rows are -1 terminated).
    float4 vp[3];
    for (int i = 0; triangleTable[cube_type][i] != -1; i += 3)
    {
#pragma unroll
        for (int vi = 0; vi < 3; ++vi)
        {
            vp[vi] = vert_list[triangleTable[cube_type][i + vi]];
        }
        int triangle_id = atomicAdd(triangles_count, 1);
        if (triangle_id < max_triangles_count)
        {
#pragma unroll
            for (int vi = 0; vi < 3; ++vi)
            {
                triangles[triangle_id][vi][0] = vp[vi].x;
                triangles[triangle_id][vi][1] = vp[vi].y;
                triangles[triangle_id][vi][2] = vp[vi].z;
                triangle_std[triangle_id][vi] = vp[vi].w;
            }
            // Remember which flat block id produced this triangle.
            triangle_flatten_id[triangle_id] = valid_blocks[lif_id];
        }
    }
}
/**
 * Host entry for the (non-interpolated) sparse marching cubes.
 *
 * indexer           : (nx, ny, nz) tensor mapping block coords -> block id (-1 = empty)
 * valid_blocks      : (K,) flattened ids of the blocks to mesh
 * vec_batch_mapping : block id -> batch row in cube_sdf/cube_std (-1 = absent)
 * cube_sdf/cube_std : (M, rx, ry, rz) per-block SDF / stddev sample volumes
 * n_xyz             : [nx, ny, nz] grid extents used to un-flatten valid_blocks
 * max_std           : forwarded to the kernel, which currently ignores it
 * max_n_triangles   : capacity of the output buffers (must be > 0)
 *
 * Returns {triangles (T, 3, 3), triangle_flatten_id (T,), triangle_std (T, 3)}.
 *
 * NOTE(review): the angle-bracket text (return type, packed_accessor32
 * template arguments, launch configuration, element type of n_xyz) was lost
 * in extraction and has been reconstructed — the accessor template arguments
 * must match the accessor typedefs in mc_data.cuh; verify against them.
 */
std::vector<torch::Tensor> marching_cubes_sparse(
    torch::Tensor indexer,           // (nx, ny, nz) -> data_id
    torch::Tensor valid_blocks,      // (K, )
    torch::Tensor vec_batch_mapping, //
    torch::Tensor cube_sdf,          // (M, rx, ry, rz)
    torch::Tensor cube_std,          // (M, rx, ry, rz)
    const std::vector<int> &n_xyz,   // [nx, ny, nz]
    float max_std,                   // Prune all vertices
    int max_n_triangles              // Maximum number of triangle buffer
)
{
    CHECK_INPUT(indexer);
    CHECK_INPUT(valid_blocks);
    CHECK_INPUT(cube_sdf);
    CHECK_INPUT(cube_std);
    CHECK_INPUT(vec_batch_mapping);
    assert(max_n_triangles > 0);
    // Full stored resolution (the interp variant uses size(1) / 2 instead).
    const int r = cube_sdf.size(1);
    const int r3 = r * r * r;
    const int num_lif = valid_blocks.size(0);
    const uint max_vec_num = vec_batch_mapping.size(0);
    torch::Tensor triangles = torch::empty({max_n_triangles, 3, 3},
                                           torch::dtype(torch::kFloat32).device(torch::kCUDA));
    torch::Tensor triangle_flatten_id = torch::empty({max_n_triangles}, torch::dtype(torch::kLong).device(torch::kCUDA));
    torch::Tensor triangle_std = torch::empty({max_n_triangles, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
    // Launching with a zero-sized grid is a CUDA error: bail out early with
    // empty outputs when there is nothing to mesh.
    if (num_lif == 0 || r3 == 0)
    {
        auto none = torch::indexing::Slice(torch::indexing::None, 0);
        return {triangles.index({none}), triangle_flatten_id.index({none}), triangle_std.index({none})};
    }
    dim3 dimBlock = dim3(16, 16);
    uint xBlocks = (num_lif + dimBlock.x - 1) / dimBlock.x;
    uint yBlocks = (r3 + dimBlock.y - 1) / dimBlock.y;
    dim3 dimGrid = dim3(xBlocks, yBlocks);
    // Device-side triangle counter, zero-initialised.
    thrust::device_vector<int> n_output(1, 0);
    meshing_cube<<<dimGrid, dimBlock>>>(
        indexer.packed_accessor32<long, 3, torch::RestrictPtrTraits>(),
        valid_blocks.packed_accessor32<long, 1, torch::RestrictPtrTraits>(),
        vec_batch_mapping.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
        cube_sdf.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        cube_std.packed_accessor32<float, 4, torch::RestrictPtrTraits>(),
        triangles.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
        triangle_std.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
        triangle_flatten_id.packed_accessor32<long, 1, torch::RestrictPtrTraits>(),
        n_output.data().get(), max_n_triangles, max_vec_num,
        n_xyz[0], n_xyz[1], n_xyz[2], max_std);
    // Kernel launches fail silently; surface configuration errors explicitly.
    cudaError_t launch_err = cudaGetLastError();
    if (launch_err != cudaSuccess)
    {
        std::cerr << "marching_cubes_sparse launch failed: "
                  << cudaGetErrorString(launch_err) << std::endl;
    }
    cudaStreamSynchronize(at::cuda::getCurrentCUDAStream());
    int output_n_triangles = n_output[0];
    if (output_n_triangles < max_n_triangles)
    {
        // Trim output tensor if it is not full.
        triangles = triangles.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
        triangle_flatten_id = triangle_flatten_id.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
        triangle_std = triangle_std.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
    }
    else
    {
        // Otherwise spawn a warning: the kernel dropped the surplus triangles.
        std::cerr << "Warning from marching cube: the max triangle number is too small " << output_n_triangles << " vs " << max_n_triangles << std::endl;
    }
    return {triangles, triangle_flatten_id, triangle_std};
}
================================================
FILE: third_party/marching_cubes/src/mc_kernel_colour.cu
================================================
#include "mc_data.cuh"
#include
#include
#include
// Fetch the stored sample for sub-voxel (arx, ary, arz) of block (bx, by, bz)
// in the coloured grid.  Channel 3 of cube_sdf holds the SDF, channels 0-2
// hold RGB; the result packs them as (sdf, r, g, b).  Returns all-NAN when
// the block is out of bounds, unallocated, or absent from the current batch.
__device__ static inline float4 query_sdf_raw(uint bx, uint by, uint bz, uint arx, uint ary, uint arz,
                                              const uint max_vec_num,
                                              const IndexerAccessor indexer,
                                              const CubeSDFRGBAccessor cube_sdf,
                                              const CubeSDFAccessor cube_std,
                                              const BackwardMappingAccessor vec_batch_mapping)
{
    // Outside the sparse indexer volume.
    if (bx >= indexer.size(0) || by >= indexer.size(1) || bz >= indexer.size(2))
        return make_float4(NAN, NAN, NAN, NAN);
    // Block never allocated, or id beyond the mapping table.
    const long long flat_id = indexer[bx][by][bz];
    if (flat_id == -1 || flat_id >= max_vec_num)
        return make_float4(NAN, NAN, NAN, NAN);
    // Block not resident in the current batch.
    const int row = vec_batch_mapping[flat_id];
    if (row == -1)
        return make_float4(NAN, NAN, NAN, NAN);
    // Pack as (sdf = channel 3, r, g, b = channels 0..2).
    return make_float4(cube_sdf[row][arx][ary][arz][3],
                       cube_sdf[row][arx][ary][arz][0],
                       cube_sdf[row][arx][ary][arz][1],
                       cube_sdf[row][arx][ary][arz][2]);
}
// Use stddev to weight sdf value.
// #define STD_W_SDF
// Resolve a (block, sub-voxel) coordinate to a single raw (sdf, r, g, b)
// sample.  The block index is clamped into the grid, and a sub-voxel index
// of exactly r rolls over to sub-voxel 0 of the neighbouring block along
// that axis.  Missing data propagates as all-NAN from query_sdf_raw.
__device__ static inline float4 get_sdf(const uint3 bsize, const uint r, uint3 bpos, uint3 rpos, const uint max_vec_num,
                                        const IndexerAccessor indexer,
                                        const CubeSDFRGBAccessor cube_sdf,
                                        const CubeSDFAccessor cube_std,
                                        const BackwardMappingAccessor vec_batch_mapping)
{
    // Clamp out-of-grid blocks onto the last block's last sub-voxel.
    if (bpos.x >= bsize.x) { bpos.x = bsize.x - 1; rpos.x = r - 1; }
    if (bpos.y >= bsize.y) { bpos.y = bsize.y - 1; rpos.y = r - 1; }
    if (bpos.z >= bsize.z) { bpos.z = bsize.z - 1; rpos.z = r - 1; }
    // Carry a full sub-voxel index into the next block (which may be out of
    // bounds; query_sdf_raw then reports NAN).
    if (rpos.x == r) { ++bpos.x; rpos.x = 0; }
    if (rpos.y == r) { ++bpos.y; rpos.y = 0; }
    if (rpos.z == r) { ++bpos.z; rpos.z = 0; }
    return query_sdf_raw(bpos.x, bpos.y, bpos.z, rpos.x, rpos.y, rpos.z,
                         max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
}
// Find the linear zero crossing of the SDF on segment p1-p2.  Also reused
// by the caller with colour triplets in place of positions, since the
// blending weights are the same.
__device__ static inline float3 sdf_interp(const float3 p1, const float3 p2,
                                           float valp1, float valp2)
{
    // Endpoint already on the surface, or flat segment: snap to an endpoint.
    if (fabs(0.0f - valp1) < 1.0e-5f)
        return p1;
    if (fabs(0.0f - valp2) < 1.0e-5f)
        return p2;
    if (fabs(valp1 - valp2) < 1.0e-5f)
        return p1;
    // Choose t so that the interpolated SDF value is exactly zero.
    const float t = (0.0f - valp1) / (valp2 - valp1);
    const float s = 1 - t;
    return make_float3(p1.x * s + p2.x * t,
                       p1.y * s + p2.y * t,
                       p1.z * s + p2.z * t);
}
__global__ static void meshing_cube_colour(const IndexerAccessor indexer,
const ValidBlocksAccessor valid_blocks,
const BackwardMappingAccessor vec_batch_mapping,
const CubeSDFRGBAccessor cube_sdf,
const CubeSDFAccessor cube_std,
TrianglesAccessor triangles,
TrianglesAccessor vertex_colours,
TriangleVecIdAccessor triangle_flatten_id,
int *__restrict__ triangles_count,
int max_triangles_count,
const uint max_vec_num,
int nx, int ny, int nz,
float max_std)
{
const uint r = cube_sdf.size(1);
const uint r3 = r * r * r;
const uint num_lif = valid_blocks.size(0);
const float sbs = 1.0f / r; // sub-block-size
const uint lif_id = blockIdx.x * blockDim.x + threadIdx.x;
const uint sub_id = blockIdx.y * blockDim.y + threadIdx.y;
if (lif_id >= num_lif || sub_id >= r3)
{
return;
}
const uint3 bpos = make_uint3(
(valid_blocks[lif_id] / (ny * nz)) % nx,
(valid_blocks[lif_id] / nz) % ny,
valid_blocks[lif_id] % nz);
const uint3 bsize = make_uint3(indexer.size(0), indexer.size(1), indexer.size(2));
const uint rx = sub_id / (r * r);
const uint ry = (sub_id / r) % r;
const uint rz = sub_id % r;
// Find all 8 neighbours
float3 points[8];
float3 colours[8];
float sdf_vals[8];
float4 sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[0] = sdf_val.x;
points[0] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);
colours[0] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[1] = sdf_val.x;
points[1] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + rz * sbs);
colours[1] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[2] = sdf_val.x;
points[2] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);
colours[2] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[3] = sdf_val.x;
points[3] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + rz * sbs);
colours[3] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[4] = sdf_val.x;
points[4] = make_float3(bpos.x + rx * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);
colours[4] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[5] = sdf_val.x;
points[5] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + ry * sbs, bpos.z + (rz + 1) * sbs);
colours[5] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx + 1, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[6] = sdf_val.x;
points[6] = make_float3(bpos.x + (rx + 1) * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);
colours[6] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
sdf_val = get_sdf(bsize, r, bpos, make_uint3(rx, ry + 1, rz + 1), max_vec_num, indexer, cube_sdf, cube_std, vec_batch_mapping);
if (isnan(sdf_val.x))
return;
sdf_vals[7] = sdf_val.x;
points[7] = make_float3(bpos.x + rx * sbs, bpos.y + (ry + 1) * sbs, bpos.z + (rz + 1) * sbs);
colours[7] = make_float3(sdf_val.y, sdf_val.z, sdf_val.w);
// Find triangle config.
int cube_type = 0;
if (sdf_vals[0] < 0)
cube_type |= 1;
if (sdf_vals[1] < 0)
cube_type |= 2;
if (sdf_vals[2] < 0)
cube_type |= 4;
if (sdf_vals[3] < 0)
cube_type |= 8;
if (sdf_vals[4] < 0)
cube_type |= 16;
if (sdf_vals[5] < 0)
cube_type |= 32;
if (sdf_vals[6] < 0)
cube_type |= 64;
if (sdf_vals[7] < 0)
cube_type |= 128;
// Find vertex position on each edge (weighted by sdf value)
int edge_config = edgeTable[cube_type];
float3 vert_list[12];
float3 rgb_list[12];
if (edge_config == 0)
return;
if (edge_config & 1)
{
vert_list[0] = sdf_interp(points[0], points[1], sdf_vals[0], sdf_vals[1]);
rgb_list[0] = sdf_interp(colours[0], colours[1], sdf_vals[0], sdf_vals[1]);
}
if (edge_config & 2)
{
vert_list[1] = sdf_interp(points[1], points[2], sdf_vals[1], sdf_vals[2]);
rgb_list[1] = sdf_interp(colours[1], colours[2], sdf_vals[1], sdf_vals[2]);
}
if (edge_config & 4)
{
vert_list[2] = sdf_interp(points[2], points[3], sdf_vals[2], sdf_vals[3]);
rgb_list[2] = sdf_interp(colours[2], colours[3], sdf_vals[2], sdf_vals[3]);
}
if (edge_config & 8)
{
vert_list[3] = sdf_interp(points[3], points[0], sdf_vals[3], sdf_vals[0]);
rgb_list[3] = sdf_interp(colours[3], colours[0], sdf_vals[3], sdf_vals[0]);
}
if (edge_config & 16)
{
vert_list[4] = sdf_interp(points[4], points[5], sdf_vals[4], sdf_vals[5]);
rgb_list[4] = sdf_interp(colours[4], colours[5], sdf_vals[4], sdf_vals[5]);
}
if (edge_config & 32)
{
vert_list[5] = sdf_interp(points[5], points[6], sdf_vals[5], sdf_vals[6]);
rgb_list[5] = sdf_interp(colours[5], colours[6], sdf_vals[5], sdf_vals[6]);
}
if (edge_config & 64)
{
vert_list[6] = sdf_interp(points[6], points[7], sdf_vals[6], sdf_vals[7]);
rgb_list[6] = sdf_interp(colours[6], colours[7], sdf_vals[6], sdf_vals[7]);
}
if (edge_config & 128)
{
vert_list[7] = sdf_interp(points[7], points[4], sdf_vals[7], sdf_vals[4]);
rgb_list[7] = sdf_interp(colours[7], colours[4], sdf_vals[7], sdf_vals[4]);
}
if (edge_config & 256)
{
vert_list[8] = sdf_interp(points[0], points[4], sdf_vals[0], sdf_vals[4]);
rgb_list[8] = sdf_interp(colours[0], colours[4], sdf_vals[0], sdf_vals[4]);
}
if (edge_config & 512)
{
vert_list[9] = sdf_interp(points[1], points[5], sdf_vals[1], sdf_vals[5]);
rgb_list[9] = sdf_interp(colours[1], colours[5], sdf_vals[1], sdf_vals[5]);
}
if (edge_config & 1024)
{
vert_list[10] = sdf_interp(points[2], points[6], sdf_vals[2], sdf_vals[6]);
rgb_list[10] = sdf_interp(colours[2], colours[6], sdf_vals[2], sdf_vals[6]);
}
if (edge_config & 2048)
{
vert_list[11] = sdf_interp(points[3], points[7], sdf_vals[3], sdf_vals[7]);
rgb_list[11] = sdf_interp(colours[3], colours[7], sdf_vals[3], sdf_vals[7]);
}
// Write triangles to array.
float3 vp[3];
float3 vc[3];
for (int i = 0; triangleTable[cube_type][i] != -1; i += 3)
{
#pragma unroll
for (int vi = 0; vi < 3; ++vi)
{
vp[vi] = vert_list[triangleTable[cube_type][i + vi]];
vc[vi] = rgb_list[triangleTable[cube_type][i + vi]];
}
int triangle_id = atomicAdd(triangles_count, 1);
if (triangle_id < max_triangles_count)
{
#pragma unroll
for (int vi = 0; vi < 3; ++vi)
{
triangles[triangle_id][vi][0] = vp[vi].x;
triangles[triangle_id][vi][1] = vp[vi].y;
triangles[triangle_id][vi][2] = vp[vi].z;
vertex_colours[triangle_id][vi][0] = vc[vi].x;
vertex_colours[triangle_id][vi][1] = vc[vi].y;
vertex_colours[triangle_id][vi][2] = vc[vi].z;
}
triangle_flatten_id[triangle_id] = valid_blocks[lif_id];
}
}
}
// Host wrapper for the meshing_cube_colour kernel: runs marching cubes over
// the K valid sparse SDF blocks and returns {triangles, vertex_colours,
// triangle_flatten_id}, trimmed to the number of triangles actually produced.
// NOTE(review): the <...> template arguments and the <<<...>>> launch
// configuration appear to have been stripped from this text (extraction
// artifact); restore them from upstream before compiling.
std::vector marching_cubes_sparse_colour(
torch::Tensor indexer, // (nx, ny, nz) -> data_id
torch::Tensor valid_blocks, // (K, )
torch::Tensor vec_batch_mapping, //
torch::Tensor cube_rgb_sdf, // (M, rx, ry, rz, 4)
torch::Tensor cube_std, // (M, rx, ry, rz)
const std::vector &n_xyz, // [nx, ny, nz]
float max_std, // Prune all vertices
int max_n_triangles // Maximum number of triangle buffer
)
{
// All inputs must be contiguous CUDA tensors.
CHECK_INPUT(indexer);
CHECK_INPUT(valid_blocks);
CHECK_INPUT(cube_rgb_sdf);
CHECK_INPUT(cube_std);
CHECK_INPUT(vec_batch_mapping);
assert(max_n_triangles > 0);
const int r = cube_rgb_sdf.size(1); // samples per block edge
const int r3 = r * r * r; // cells per block
const int num_lif = valid_blocks.size(0);
const uint max_vec_num = vec_batch_mapping.size(0);
// Pre-allocated output buffers; trimmed (or warned about) after the launch.
torch::Tensor triangles = torch::empty({max_n_triangles, 3, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
torch::Tensor vertex_colours = torch::empty({max_n_triangles, 3, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
torch::Tensor triangle_flatten_id = torch::empty({max_n_triangles}, torch::dtype(torch::kLong).device(torch::kCUDA));
// NOTE(review): triangle_std is allocated but never filled or returned here.
torch::Tensor triangle_std = torch::empty({max_n_triangles, 3}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
// 2D grid: x indexes the valid blocks, y indexes the r^3 cells per block.
dim3 dimBlock = dim3(16, 16);
uint xBlocks = (num_lif + dimBlock.x - 1) / dimBlock.x;
uint yBlocks = (r3 + dimBlock.y - 1) / dimBlock.y;
dim3 dimGrid = dim3(xBlocks, yBlocks);
// Device-side triangle counter; the kernel bumps it with atomicAdd.
thrust::device_vector n_output(1, 0);
meshing_cube_colour<<>>(
indexer.packed_accessor32(),
valid_blocks.packed_accessor32(),
vec_batch_mapping.packed_accessor32(),
cube_rgb_sdf.packed_accessor32(),
cube_std.packed_accessor32(),
triangles.packed_accessor32(),
vertex_colours.packed_accessor32(),
triangle_flatten_id.packed_accessor32(),
n_output.data().get(), max_n_triangles, max_vec_num,
n_xyz[0], n_xyz[1], n_xyz[2], max_std);
// Block until the kernel finishes so n_output can be read on the host.
cudaStreamSynchronize(at::cuda::getCurrentCUDAStream());
int output_n_triangles = n_output[0];
if (output_n_triangles < max_n_triangles)
{
// Trim output tensor if it is not full.
triangles = triangles.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
vertex_colours = vertex_colours.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
triangle_flatten_id = triangle_flatten_id.index({torch::indexing::Slice(torch::indexing::None, output_n_triangles)});
}
else
{
// Otherwise spawn a warning. (The counter kept incrementing past the
// buffer, so output_n_triangles reports how many were attempted.)
std::cerr << "Warning from marching cube: the max triangle number is too small " << output_n_triangles << " vs " << max_n_triangles << std::endl;
}
return {triangles, vertex_colours, triangle_flatten_id};
}
================================================
FILE: third_party/sparse_octree/include/octree.h
================================================
#include
#include
#include
enum OcType
{
NONLEAF = -1,
SURFACE = 0,
FEATURE = 1
};
// One node of the sparse octree. Each octant carries a Morton code, its side
// length in voxels, a type (NONLEAF / SURFACE / FEATURE) and raw pointers to
// up to 8 children. index_ is a globally increasing id used to address rows
// of the dense tensors built by Octree::get_centres_and_children().
class Octant : public torch::CustomClassHolder
{
public:
inline Octant()
{
code_ = 0;
side_ = 0;
index_ = next_index_++; // unique id in creation order
depth_ = -1;
is_leaf_ = false;
children_mask_ = 0;
type_ = NONLEAF;
for (unsigned int i = 0; i < 8; i++)
{
child_ptr_[i] = nullptr;
// feature_index_[i] = -1;
}
}
~Octant() {}
// std::shared_ptr &child(const int x, const int y, const int z)
// {
// return child_ptr_[x + y * 2 + z * 4];
// };
// std::shared_ptr &child(const int offset)
// {
// return child_ptr_[offset];
// }
// Child slot addressed by one bit per axis (x fastest: offset = x + 2y + 4z).
Octant *&child(const int x, const int y, const int z)
{
return child_ptr_[x + y * 2 + z * 4];
};
Octant *&child(const int offset)
{
return child_ptr_[offset];
}
uint64_t code_; // Morton code truncated to this node's depth
bool is_leaf_;
unsigned int side_; // side length in voxels at this node's level
unsigned char children_mask_; // bit i set when child i has been allocated
// std::shared_ptr child_ptr_[8];
// int feature_index_[8];
int index_;
int depth_;
int type_; // OcType value
// int feat_index_;
Octant *child_ptr_[8];
static int next_index_;
};
// Sparse voxel octree over a cubic grid of side size_ (a power of two).
// Registered as a TorchScript custom class in bindings.cpp; all_pts keeps
// every inserted batch so the tree can be pickled and replayed on load.
// NOTE(review): std::vector / std::set / std::tuple template arguments were
// stripped from this text by extraction; restore before compiling.
class Octree : public torch::CustomClassHolder
{
public:
Octree();
// temporal solution
Octree(int64_t grid_dim, int64_t feat_dim, double voxel_size, std::vector all_pts);
~Octree();
void init(int64_t grid_dim, int64_t feat_dim, double voxel_size);
// allocate voxels
void insert(torch::Tensor vox);
double try_insert(torch::Tensor pts);
// find a particular octant
Octant *find_octant(std::vector coord);
// test intersections
bool has_voxel(torch::Tensor pose);
// query features
torch::Tensor get_features(torch::Tensor pts);
// get all voxels
torch::Tensor get_voxels();
std::vector get_voxel_recursive(Octant *n);
// get leaf voxels
torch::Tensor get_leaf_voxels();
std::vector get_leaf_voxel_recursive(Octant *n);
// count nodes
int64_t count_nodes();
int64_t count_recursive(Octant *n);
// count leaf nodes
int64_t count_leaf_nodes();
// int64_t leaves_count_recursive(std::shared_ptr n);
int64_t leaves_count_recursive(Octant *n);
// get voxel centres and childrens
std::tuple get_centres_and_children();
public:
int size_; // grid side length in voxels
int feat_dim_;
int max_level_; // log2(size_): depth of the finest level
// temporal solution
double voxel_size_;
std::vector all_pts; // every inserted batch, kept for pickling
private:
std::set all_keys; // Morton keys of all allocated voxels
// std::shared_ptr root_;
Octant *root_;
// static int feature_index;
// internal count function
std::pair count_nodes_internal();
std::pair count_recursive_internal(Octant *n);
};
================================================
FILE: third_party/sparse_octree/include/test.h
================================================
#pragma once
#include
#define MAX_BITS 21
// #define SCALE_MASK ((uint64_t)0x1FF)
#define SCALE_MASK ((uint64_t)0x1)
/*
* Mask generated with:
MASK[0] = 0x7000000000000000,
for(int i = 1; i < 21; ++i) {
MASK[i] = MASK[i-1] | (MASK[0] >> (i*3));
std::bitset<64> b(MASK[i]);
std::cout << std::hex << b.to_ullong() << std::endl;
}
*
*/
constexpr uint64_t MASK[] = {
0x7000000000000000,
0x7e00000000000000,
0x7fc0000000000000,
0x7ff8000000000000,
0x7fff000000000000,
0x7fffe00000000000,
0x7ffffc0000000000,
0x7fffff8000000000,
0x7ffffff000000000,
0x7ffffffe00000000,
0x7fffffffc0000000,
0x7ffffffff8000000,
0x7fffffffff000000,
0x7fffffffffe00000,
0x7ffffffffffc0000,
0x7fffffffffff8000,
0x7ffffffffffff000,
0x7ffffffffffffe00,
0x7fffffffffffffc0,
0x7ffffffffffffff8,
0x7fffffffffffffff};
// Spread the low 21 bits of `value` so that input bit k lands on output
// bit 3k, leaving two zero bits between payload bits (one coordinate's
// share of a 3D Morton code).
inline int64_t expand(int64_t value)
{
    int64_t v = value & 0x1fffff; // keep the 21-bit payload only
    v = (v | v << 32) & 0x1f00000000ffff;
    v = (v | v << 16) & 0x1f0000ff0000ff;
    v = (v | v << 8) & 0x100f00f00f00f00f;
    v = (v | v << 4) & 0x10c30c30c30c30c3;
    v = (v | v << 2) & 0x1249249249249249;
    return v;
}
// Inverse of expand(): gather every third bit (bits 0, 3, 6, ...) back
// into a contiguous 21-bit integer.
inline uint64_t compact(uint64_t value)
{
    uint64_t v = value & 0x1249249249249249;
    v = (v | v >> 2) & 0x10c30c30c30c30c3;
    v = (v | v >> 4) & 0x100f00f00f00f00f;
    v = (v | v >> 8) & 0x1f0000ff0000ff;
    v = (v | v >> 16) & 0x1f00000000ffff;
    v = (v | v >> 32) & 0x1fffff;
    return v;
}
// Interleave the bits of (x, y, z) into a 3D Morton (Z-order) code:
// x occupies bits 0,3,6,..., y bits 1,4,7,..., z bits 2,5,8,...
inline int64_t compute_morton(int64_t x, int64_t y, int64_t z)
{
    const int64_t xs = expand(x);
    const int64_t ys = expand(y) << 1;
    const int64_t zs = expand(z) << 2;
    return xs | ys | zs;
}
// Morton-encode an (N, 3) int64 tensor of grid coordinates into an (N, 1)
// int64 tensor of codes, masked to the MAX_BITS-level payload.
// Fixes vs. the original: `z` read coords[i*3] (the x column) instead of
// coords[i*3 + 2], so every code was computed with z == x; also restores the
// data_ptr<int64_t>() template arguments stripped from this text.
inline torch::Tensor encode_torch(torch::Tensor coords)
{
    torch::Tensor outs = torch::zeros({coords.size(0), 1}, dtype(torch::kInt64));
    for (int i = 0; i < coords.size(0); ++i)
    {
        int64_t x = coords.data_ptr<int64_t>()[i * 3];
        int64_t y = coords.data_ptr<int64_t>()[i * 3 + 1];
        int64_t z = coords.data_ptr<int64_t>()[i * 3 + 2];
        outs.data_ptr<int64_t>()[i] = (compute_morton(x, y, z) & MASK[MAX_BITS - 1]);
    }
    return outs;
}
================================================
FILE: third_party/sparse_octree/include/utils.h
================================================
#pragma once
#include
#include
#define MAX_BITS 21
// #define SCALE_MASK ((uint64_t)0x1FF)
#define SCALE_MASK ((uint64_t)0x1)
// Minimal 3-component vector supporting only addition and subtraction.
// NOTE(review): the template parameter list (presumably <typename T>) and the
// typedef arguments were stripped from this text by extraction; restore
// before compiling.
template
struct Vector3
{
Vector3() : x(0), y(0), z(0) {}
Vector3(T x_, T y_, T z_) : x(x_), y(y_), z(z_) {}
// componentwise sum
Vector3 operator+(const Vector3 &b)
{
return Vector3(x + b.x, y + b.y, z + b.z);
}
// componentwise difference
Vector3 operator-(const Vector3 &b)
{
return Vector3(x - b.x, y - b.y, z - b.z);
}
T x, y, z;
};
typedef Vector3 Vector3i; // presumably Vector3<int> — template args stripped
typedef Vector3 Vector3f; // presumably Vector3<float> — template args stripped
/*
* Mask generated with:
MASK[0] = 0x7000000000000000,
for(int i = 1; i < 21; ++i) {
MASK[i] = MASK[i-1] | (MASK[0] >> (i*3));
std::bitset<64> b(MASK[i]);
std::cout << std::hex << b.to_ullong() << std::endl;
}
*
*/
constexpr uint64_t MASK[] = {
0x7000000000000000,
0x7e00000000000000,
0x7fc0000000000000,
0x7ff8000000000000,
0x7fff000000000000,
0x7fffe00000000000,
0x7ffffc0000000000,
0x7fffff8000000000,
0x7ffffff000000000,
0x7ffffffe00000000,
0x7fffffffc0000000,
0x7ffffffff8000000,
0x7fffffffff000000,
0x7fffffffffe00000,
0x7ffffffffffc0000,
0x7fffffffffff8000,
0x7ffffffffffff000,
0x7ffffffffffffe00,
0x7fffffffffffffc0,
0x7ffffffffffffff8,
0x7fffffffffffffff};
// Spread the low 21 bits of `value` so that input bit k lands on output
// bit 3k (one coordinate's share of a 3D Morton code).
inline uint64_t expand(unsigned long long value)
{
    uint64_t v = value & 0x1fffff; // keep the 21-bit payload only
    v = (v | v << 32) & 0x1f00000000ffff;
    v = (v | v << 16) & 0x1f0000ff0000ff;
    v = (v | v << 8) & 0x100f00f00f00f00f;
    v = (v | v << 4) & 0x10c30c30c30c30c3;
    v = (v | v << 2) & 0x1249249249249249;
    return v;
}
// Inverse of expand(): gather every third bit (bits 0, 3, 6, ...) back
// into a contiguous 21-bit integer.
inline uint64_t compact(uint64_t value)
{
    uint64_t v = value & 0x1249249249249249;
    v = (v | v >> 2) & 0x10c30c30c30c30c3;
    v = (v | v >> 4) & 0x100f00f00f00f00f;
    v = (v | v >> 8) & 0x1f0000ff0000ff;
    v = (v | v >> 16) & 0x1f00000000ffff;
    v = (v | v >> 32) & 0x1fffff;
    return v;
}
// Interleave the bits of (x, y, z) into a 3D Morton (Z-order) code:
// x occupies bits 0,3,6,..., y bits 1,4,7,..., z bits 2,5,8,...
inline uint64_t compute_morton(uint64_t x, uint64_t y, uint64_t z)
{
    const uint64_t xs = expand(x);
    const uint64_t ys = expand(y) << 1;
    const uint64_t zs = expand(z) << 2;
    return xs | ys | zs;
}
// Recover the (x, y, z) integer coordinates packed in a Morton code by
// compacting the bit lanes starting at offsets 0, 1 and 2.
inline Eigen::Vector3i decode(const uint64_t code)
{
    const int x = static_cast<int>(compact(code));
    const int y = static_cast<int>(compact(code >> 1ull));
    const int z = static_cast<int>(compact(code >> 2ull));
    return {x, y, z};
}
// Morton-encode integer coordinates and mask the result down to the
// MAX_BITS-level payload using the precomputed MASK table.
inline uint64_t encode(const int x, const int y, const int z)
{
    const uint64_t morton = compute_morton(x, y, z);
    return morton & MASK[MAX_BITS - 1];
}
================================================
FILE: third_party/sparse_octree/setup.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
import glob

# Build the sparse-voxel-octree (svo) C++ extension from every .cpp in src/.
EXT_NAME = "svo"
INCLUDE_DIRS = ["./include"]
CXX_FLAGS = ["-O2", "-I./include"]

sources = glob.glob("src/*.cpp")

setup(
    name=EXT_NAME,
    ext_modules=[
        CppExtension(
            name=EXT_NAME,
            sources=sources,
            include_dirs=INCLUDE_DIRS,
            extra_compile_args={"cxx": CXX_FLAGS},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
================================================
FILE: third_party/sparse_octree/src/bindings.cpp
================================================
#include "../include/octree.h"
#include "../include/test.h"
// TorchScript registrations for the svo extension: the free function
// svo.encode plus the Octant and Octree custom classes. Octree is picklable
// by saving (size, feat_dim, voxel_size, all_pts) and replaying insert()
// through the 4-argument constructor on load.
// NOTE(review): class_/tuple/intrusive_ptr template arguments were stripped
// from this text by extraction; restore before compiling.
TORCH_LIBRARY(svo, m)
{
m.def("encode", &encode_torch);
m.class_("Octant")
.def(torch::init<>());
m.class_("Octree")
.def(torch::init<>())
.def("init", &Octree::init)
.def("insert", &Octree::insert)
.def("try_insert", &Octree::try_insert)
.def("get_voxels", &Octree::get_voxels)
.def("get_leaf_voxels", &Octree::get_leaf_voxels)
.def("get_features", &Octree::get_features)
.def("count_nodes", &Octree::count_nodes)
.def("count_leaf_nodes", &Octree::count_leaf_nodes)
.def("has_voxel", &Octree::has_voxel)
.def("get_centres_and_children", &Octree::get_centres_and_children)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr& self) -> std::tuple> {
return std::make_tuple(self->size_, self->feat_dim_, self->voxel_size_, self->all_pts);
},
// __setstate__
[](std::tuple> state) {
return c10::make_intrusive(std::get<0>(state), std::get<1>(state), std::get<2>(state), std::get<3>(state));
});
}
================================================
FILE: third_party/sparse_octree/src/octree.cpp
================================================
#include "../include/octree.h"
#include "../include/utils.h"
#include
#include
// #define MAX_HIT_VOXELS 10
// #define MAX_NUM_VOXELS 10000
// Monotonically increasing id handed to each new Octant (reset per rebuild).
int Octant::next_index_ = 0;
// int Octree::feature_index = 0;
// Offsets of the 8 corner/neighbour positions of a voxel along x/y/z.
int incr_x[8] = {0, 0, 0, 0, 1, 1, 1, 1};
int incr_y[8] = {0, 0, 1, 1, 0, 0, 1, 1};
int incr_z[8] = {0, 1, 0, 1, 0, 1, 0, 1};
// Default constructor: leaves the tree unconfigured; call init() before use.
Octree::Octree()
{
}
// Rebuild a tree from saved state (used by the pickle __setstate__ in
// bindings.cpp): re-initialise, then replay every stored point batch.
// NOTE(review): the std::vector element type (torch::Tensor) was stripped
// from this text by extraction.
Octree::Octree(int64_t grid_dim, int64_t feat_dim, double voxel_size, std::vector all_pts)
{
Octant::next_index_ = 0; // restart octant ids so indices match a fresh build
init(grid_dim, feat_dim, voxel_size);
for (auto &pt : all_pts)
{
insert(pt);
}
}
// Destructor.
// NOTE(review): no recursive delete here — Octants allocated with `new` in
// insert()/init() appear to be leaked when the tree is destroyed.
Octree::~Octree()
{
}
void Octree::init(int64_t grid_dim, int64_t feat_dim, double voxel_size)
{
size_ = grid_dim;
feat_dim_ = feat_dim;
voxel_size_ = voxel_size;
max_level_ = log2(size_);
// root_ = std::make_shared();
root_ = new Octant();
root_->side_ = size_;
// root_->depth_ = 0;
root_->is_leaf_ = false;
// feats_allocated_ = 0;
// auto options = torch::TensorOptions().requires_grad(true);
// feats_array_ = torch::randn({MAX_NUM_VOXELS, feat_dim}, options) * 0.01;
}
// Insert a batch of integer voxel coordinates (N, 3) into the tree.
// For each point, the 8 voxels at its corner offsets are allocated down to
// the finest level; the voxel containing the point itself (j == 0) is marked
// SURFACE, its 7 corner neighbours FEATURE.
void Octree::insert(torch::Tensor pts)
{
// temporal solution
all_pts.push_back(pts); // keep raw input so the tree can be pickled/replayed
if (root_ == nullptr)
{
std::cout << "Octree not initialized!" << std::endl;
}
// NOTE(review): accessor's template arguments were stripped from this text
// by extraction.
auto points = pts.accessor();
if (points.size(1) != 3)
{
std::cout << "Point dimensions mismatch: inputs are " << points.size(1) << " expect 3" << std::endl;
return;
}
for (int i = 0; i < points.size(0); ++i)
{
for (int j = 0; j < 8; ++j)
{
// Corner offset j of the voxel at point i.
int x = points[i][0] + incr_x[j];
int y = points[i][1] + incr_y[j];
int z = points[i][2] + incr_z[j];
uint64_t key = encode(x, y, z);
all_keys.insert(key); // Morton keys kept for overlap tests (try_insert)
const unsigned int shift = MAX_BITS - max_level_ - 1;
auto n = root_;
unsigned edge = size_ / 2;
// Walk from the root to the finest level, creating octants on demand.
for (int d = 1; d <= max_level_; edge /= 2, ++d)
{
const int childid = ((x & edge) > 0) + 2 * ((y & edge) > 0) + 4 * ((z & edge) > 0);
// std::cout << "Level: " << d << " ChildID: " << childid << std::endl;
auto tmp = n->child(childid);
if (!tmp)
{
const uint64_t code = key & MASK[d + shift]; // key truncated to this depth
const bool is_leaf = (d == max_level_);
// tmp = std::make_shared();
tmp = new Octant();
tmp->code_ = code;
tmp->side_ = edge;
tmp->is_leaf_ = is_leaf;
// Leaf hit directly by the point is SURFACE; corner neighbours
// are FEATURE; interior nodes are NONLEAF.
tmp->type_ = is_leaf ? (j == 0 ? SURFACE : FEATURE) : NONLEAF;
n->children_mask_ = n->children_mask_ | (1 << childid);
n->child(childid) = tmp;
}
else
{
// Promote an existing FEATURE leaf when the point lands on it.
if (tmp->type_ == FEATURE && j == 0)
tmp->type_ = SURFACE;
}
n = tmp;
}
}
}
}
// Without modifying the tree, compute the fraction of the Morton keys that
// `pts` would allocate which already exist (overlap ratio in [0, 1]);
// returns -1.0 when the input is not (N, 3).
double Octree::try_insert(torch::Tensor pts)
{
if (root_ == nullptr)
{
std::cout << "Octree not initialized!" << std::endl;
}
// NOTE(review): accessor / std::set template arguments were stripped from
// this text by extraction.
auto points = pts.accessor();
if (points.size(1) != 3)
{
std::cout << "Point dimensions mismatch: inputs are " << points.size(1) << " expect 3" << std::endl;
return -1.0;
}
std::set tmp_keys;
for (int i = 0; i < points.size(0); ++i)
{
for (int j = 0; j < 8; ++j)
{
// Same 8 corner offsets that insert() would touch.
int x = points[i][0] + incr_x[j];
int y = points[i][1] + incr_y[j];
int z = points[i][2] + incr_z[j];
uint64_t key = encode(x, y, z);
tmp_keys.insert(key);
}
}
// overlap = |existing keys ∩ candidate keys| / |candidate keys|
std::set result;
std::set_intersection(all_keys.begin(), all_keys.end(),
tmp_keys.begin(), tmp_keys.end(),
std::inserter(result, result.end()));
double overlap_ratio = 1.0 * result.size() / tmp_keys.size();
return overlap_ratio;
}
// Descend from the root to the finest-level octant containing integer
// coordinate `coord`; returns nullptr when the path is not allocated.
// NOTE(review): std::vector's element type was stripped from this text by
// extraction.
Octant *Octree::find_octant(std::vector coord)
{
int x = int(coord[0]);
int y = int(coord[1]);
int z = int(coord[2]);
// uint64_t key = encode(x, y, z);
// const unsigned int shift = MAX_BITS - max_level_ - 1;
auto n = root_;
unsigned edge = size_ / 2;
for (int d = 1; d <= max_level_; edge /= 2, ++d)
{
// Child slot from the bit of each coordinate at this level.
const int childid = ((x & edge) > 0) + 2 * ((y & edge) > 0) + 4 * ((z & edge) > 0);
auto tmp = n->child(childid);
if (!tmp)
return nullptr;
n = tmp;
}
return n;
}
// True when the finest-level voxel at the integer coordinate given by the
// length-3 tensor `pts` is allocated in the tree.
bool Octree::has_voxel(torch::Tensor pts)
{
if (root_ == nullptr)
{
std::cout << "Octree not initialized!" << std::endl;
}
// NOTE(review): accessor's template arguments were stripped from this text
// by extraction.
auto points = pts.accessor();
if (points.size(0) != 3)
{
return false;
}
int x = int(points[0]);
int y = int(points[1]);
int z = int(points[2]);
auto n = root_;
unsigned edge = size_ / 2;
for (int d = 1; d <= max_level_; edge /= 2, ++d)
{
const int childid = ((x & edge) > 0) + 2 * ((y & edge) > 0) + 4 * ((z & edge) > 0);
auto tmp = n->child(childid);
if (!tmp)
return false;
n = tmp;
}
// n is non-null here (the loop returned early otherwise).
if (!n)
return false;
else
return true;
}
// Query per-point features. Not implemented.
// Fix: the original body was empty, so control fell off the end of a
// non-void function — undefined behavior in C++. Return an explicit empty
// (undefined) tensor so callers get a well-defined "no features" result.
torch::Tensor Octree::get_features(torch::Tensor pts)
{
    (void)pts; // unused until feature lookup is implemented
    return torch::Tensor();
}
// Collect the coordinates of all SURFACE leaf voxels as an (N, 3) tensor.
// NOTE(review): std::vector's element type was stripped from this text by
// extraction.
torch::Tensor Octree::get_leaf_voxels()
{
std::vector voxel_coords = get_leaf_voxel_recursive(root_);
int N = voxel_coords.size() / 3;
// from_blob does not own the data, so clone before the vector dies.
torch::Tensor voxels = torch::from_blob(voxel_coords.data(), {N, 3});
return voxels.clone();
}
// Depth-first collection of SURFACE leaf coordinates, returned as flattened
// x,y,z triples decoded from each octant's Morton code.
// NOTE(review): the return type's element type was stripped from this text
// by extraction.
std::vector Octree::get_leaf_voxel_recursive(Octant *n)
{
if (!n)
return std::vector();
if (n->is_leaf_ && n->type_ == SURFACE)
{
auto xyz = decode(n->code_);
return {xyz[0], xyz[1], xyz[2]};
}
std::vector coords;
for (int i = 0; i < 8; i++)
{
auto temp = get_leaf_voxel_recursive(n->child(i));
coords.insert(coords.end(), temp.begin(), temp.end());
}
return coords;
}
// Collect every octant (all levels) as an (N, 4) float tensor of
// (x, y, z, side).
torch::Tensor Octree::get_voxels()
{
std::vector voxel_coords = get_voxel_recursive(root_);
int N = voxel_coords.size() / 4;
auto options = torch::TensorOptions().dtype(torch::kFloat32);
// from_blob does not own the data, so clone before the vector dies.
torch::Tensor voxels = torch::from_blob(voxel_coords.data(), {N, 4}, options);
return voxels.clone();
}
// Depth-first collection of (x, y, z, side) for every octant in the
// subtree rooted at n.
std::vector Octree::get_voxel_recursive(Octant *n)
{
if (!n)
return std::vector();
auto xyz = decode(n->code_);
std::vector coords = {xyz[0], xyz[1], xyz[2], float(n->side_)};
for (int i = 0; i < 8; i++)
{
auto temp = get_voxel_recursive(n->child(i));
coords.insert(coords.end(), temp.begin(), temp.end());
}
return coords;
}
// (total node count, leaf count) over the whole tree.
std::pair Octree::count_nodes_internal()
{
return count_recursive_internal(root_);
}
// int64_t Octree::leaves_count_recursive(std::shared_ptr n)
// Recursive helper for count_nodes_internal(): first = nodes in this
// subtree, second = leaf nodes in this subtree.
std::pair Octree::count_recursive_internal(Octant *n)
{
if (!n)
return std::make_pair(0, 0);
if (n->is_leaf_)
return std::make_pair(1, 1);
auto sum = std::make_pair(1, 0); // count this interior node
for (int i = 0; i < 8; i++)
{
auto temp = count_recursive_internal(n->child(i));
sum.first += temp.first;
sum.second += temp.second;
}
return sum;
}
// Flatten the tree (breadth-first) into dense tensors indexed by octant id:
//   all_voxels   (N, 4): (x, y, z, side) per octant;
//   all_children (N, 8): child octant index per slot, -1 when absent;
//   all_features (N, 8): for SURFACE octants, the octant index found at each
//                        of the 8 corner offsets, -1 otherwise.
// NOTE(review): std::tuple / std::queue template arguments were stripped
// from this text by extraction.
std::tuple Octree::get_centres_and_children()
{
auto node_count = count_nodes_internal();
auto total_count = node_count.first;
auto leaf_count = node_count.second;
auto all_voxels = torch::zeros({total_count, 4}, dtype(torch::kFloat32));
auto all_children = -torch::ones({total_count, 8}, dtype(torch::kFloat32));
auto all_features = -torch::ones({total_count, 8}, dtype(torch::kInt32));
std::queue all_nodes;
all_nodes.push(root_);
while (!all_nodes.empty())
{
auto node_ptr = all_nodes.front();
all_nodes.pop();
auto xyz = decode(node_ptr->code_);
std::vector coords = {xyz[0], xyz[1], xyz[2], float(node_ptr->side_)};
auto voxel = torch::from_blob(coords.data(), {4}, dtype(torch::kFloat32));
all_voxels[node_ptr->index_] = voxel;
if (node_ptr->type_ == SURFACE)
{
// Record the octant index found at each of the 8 corner neighbours.
for (int i = 0; i < 8; ++i)
{
std::vector vcoords = coords;
vcoords[0] += incr_x[i];
vcoords[1] += incr_y[i];
vcoords[2] += incr_z[i];
auto voxel = find_octant(vcoords);
if (voxel)
all_features.data_ptr()[node_ptr->index_ * 8 + i] = voxel->index_;
}
}
// Enqueue children, skipping FEATURE leaves.
for (int i = 0; i < 8; i++)
{
auto child_ptr = node_ptr->child(i);
if (child_ptr && child_ptr->type_ != FEATURE)
{
all_nodes.push(child_ptr);
all_children[node_ptr->index_][i] = float(child_ptr->index_);
}
}
}
return std::make_tuple(all_voxels, all_children, all_features);
}
// Total number of octants in the tree (root included).
int64_t Octree::count_nodes()
{
return count_recursive(root_);
}
// int64_t Octree::leaves_count_recursive(std::shared_ptr n)
// Number of octants in the subtree rooted at n (n itself included).
int64_t Octree::count_recursive(Octant *n)
{
    if (n == nullptr)
        return 0;
    int64_t total = 1; // this node
    for (int child = 0; child < 8; ++child)
        total += count_recursive(n->child(child));
    return total;
}
// Number of SURFACE octants in the tree.
int64_t Octree::count_leaf_nodes()
{
return leaves_count_recursive(root_);
}
// int64_t Octree::leaves_count_recursive(std::shared_ptr n)
// Count SURFACE-typed octants in the subtree rooted at n. A SURFACE node
// terminates the descent (its children are not examined).
int64_t Octree::leaves_count_recursive(Octant *n)
{
    if (n == nullptr)
        return 0;
    if (n->type_ == SURFACE)
        return 1;
    int64_t total = 0;
    for (int child = 0; child < 8; ++child)
        total += leaves_count_recursive(n->child(child));
    return total;
}
================================================
FILE: third_party/sparse_voxels/include/cuda_utils.h
================================================
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H
#include
#include
#include
#include
#include
#include
#define TOTAL_THREADS 512
// Pick a power-of-two thread count roughly matching work_size, capped at
// TOTAL_THREADS and at least 1.
// NOTE(review): static_cast's template argument (presumably <double>) was
// stripped from this text by extraction; restore before compiling.
inline int opt_n_threads(int work_size)
{
const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0);
return max(min(1 << pow_2, TOTAL_THREADS), 1);
}
// Choose a 2D thread-block shape for an (x, y) problem: give x its optimal
// thread count first, then fill the remaining TOTAL_THREADS budget along y.
inline dim3 opt_block_config(int x, int y)
{
    const int x_threads = opt_n_threads(x);
    const int y_budget = TOTAL_THREADS / x_threads;
    const int y_threads = max(min(opt_n_threads(y), y_budget), 1);
    return dim3(x_threads, y_threads, 1);
}
// Abort the process if the most recent CUDA API call or kernel launch left
// an error in the sticky last-error state (call right after kernel launches).
#define CUDA_CHECK_ERRORS() \
do \
{ \
cudaError_t err = cudaGetLastError(); \
if (cudaSuccess != err) \
{ \
fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
__FILE__); \
exit(-1); \
} \
} while (0)
#endif
================================================
FILE: third_party/sparse_voxels/include/cutil_math.h
================================================
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
/*
* Copyright 1993-2009 NVIDIA Corporation. All rights reserved.
*
* NVIDIA Corporation and its licensors retain all intellectual property and
* proprietary rights in and to this software and related documentation and
* any modifications thereto. Any use, reproduction, disclosure, or distribution
* of this software and related documentation without an express license
* agreement from NVIDIA Corporation is strictly prohibited.
*
*/
/*
This file implements common mathematical operations on vector types
(float3, float4 etc.) since these are not provided as standard by CUDA.
The syntax is modelled on the Cg standard library.
*/
#ifndef CUTIL_MATH_H
#define CUTIL_MATH_H
#include "cuda_runtime.h"
////////////////////////////////////////////////////////////////////////////////
typedef unsigned int uint;
typedef unsigned short ushort;
#ifndef __CUDACC__
#include
inline float fminf(float a, float b)
{
return a < b ? a : b;
}
inline float fmaxf(float a, float b)
{
return a > b ? a : b;
}
inline int max(int a, int b)
{
return a > b ? a : b;
}
inline int min(int a, int b)
{
return a < b ? a : b;
}
inline float rsqrtf(float x)
{
return 1.0f / sqrtf(x);
}
#endif
// float functions
////////////////////////////////////////////////////////////////////////////////
// lerp
inline __device__ __host__ float lerp(float a, float b, float t)
{
return a + t * (b - a);
}
// clamp
inline __device__ __host__ float clamp(float f, float a, float b)
{
return fmaxf(a, fminf(f, b));
}
inline __device__ __host__ void swap(float &a, float &b)
{
float c = a;
a = b;
b = c;
}
// Exchange two ints.
// Fix: the original used a float temporary, which silently rounds values
// with magnitude above 2^24 (float has a 24-bit significand).
inline __device__ __host__ void swap(int &a, int &b)
{
    int c = a;
    a = b;
    b = c;
}
// int2 functions
////////////////////////////////////////////////////////////////////////////////
// negate
inline __host__ __device__ int2 operator-(int2 &a)
{
return make_int2(-a.x, -a.y);
}
// addition
inline __host__ __device__ int2 operator+(int2 a, int2 b)
{
return make_int2(a.x + b.x, a.y + b.y);
}
inline __host__ __device__ void operator+=(int2 &a, int2 b)
{
a.x += b.x;
a.y += b.y;
}
// subtract
inline __host__ __device__ int2 operator-(int2 a, int2 b)
{
return make_int2(a.x - b.x, a.y - b.y);
}
inline __host__ __device__ void operator-=(int2 &a, int2 b)
{
a.x -= b.x;
a.y -= b.y;
}
// multiply
inline __host__ __device__ int2 operator*(int2 a, int2 b)
{
return make_int2(a.x * b.x, a.y * b.y);
}
inline __host__ __device__ int2 operator*(int2 a, int s)
{
return make_int2(a.x * s, a.y * s);
}
inline __host__ __device__ int2 operator*(int s, int2 a)
{
return make_int2(a.x * s, a.y * s);
}
inline __host__ __device__ void operator*=(int2 &a, int s)
{
a.x *= s;
a.y *= s;
}
// float2 functions
////////////////////////////////////////////////////////////////////////////////
// additional constructors
inline __host__ __device__ float2 make_float2(float s)
{
return make_float2(s, s);
}
inline __host__ __device__ float2 make_float2(int2 a)
{
return make_float2(float(a.x), float(a.y));
}
// negate
inline __host__ __device__ float2 operator-(float2 &a)
{
return make_float2(-a.x, -a.y);
}
// addition
inline __host__ __device__ float2 operator+(float2 a, float2 b)
{
return make_float2(a.x + b.x, a.y + b.y);
}
inline __host__ __device__ void operator+=(float2 &a, float2 b)
{
a.x += b.x;
a.y += b.y;
}
// subtract
inline __host__ __device__ float2 operator-(float2 a, float2 b)
{
return make_float2(a.x - b.x, a.y - b.y);
}
inline __host__ __device__ void operator-=(float2 &a, float2 b)
{
a.x -= b.x;
a.y -= b.y;
}
// multiply
inline __host__ __device__ float2 operator*(float2 a, float2 b)
{
return make_float2(a.x * b.x, a.y * b.y);
}
inline __host__ __device__ float2 operator*(float2 a, float s)
{
return make_float2(a.x * s, a.y * s);
}
inline __host__ __device__ float2 operator*(float s, float2 a)
{
return make_float2(a.x * s, a.y * s);
}
inline __host__ __device__ void operator*=(float2 &a, float s)
{
a.x *= s;
a.y *= s;
}
// divide
inline __host__ __device__ float2 operator/(float2 a, float2 b)
{
return make_float2(a.x / b.x, a.y / b.y);
}
inline __host__ __device__ float2 operator/(float2 a, float s)
{
float inv = 1.0f / s;
return a * inv;
}
// Scalar divided by vector, componentwise: (s / a.x, s / a.y).
// Fix: the original returned a * (1/s), i.e. it computed a/s instead of
// s/a; CUDA's helper_math.h defines this operator componentwise.
inline __host__ __device__ float2 operator/(float s, float2 a)
{
    return make_float2(s / a.x, s / a.y);
}
inline __host__ __device__ void operator/=(float2 &a, float s)
{
float inv = 1.0f / s;
a *= inv;
}
// lerp
inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
{
return a + t * (b - a);
}
// clamp
inline __device__ __host__ float2 clamp(float2 v, float a, float b)
{
return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));
}
inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)
{
return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
}
// dot product
inline __host__ __device__ float dot(float2 a, float2 b)
{
return a.x * b.x + a.y * b.y;
}
// length
inline __host__ __device__ float length(float2 v)
{
return sqrtf(dot(v, v));
}
// normalize
inline __host__ __device__ float2 normalize(float2 v)
{
float invLen = rsqrtf(dot(v, v));
return v * invLen;
}
// floor
inline __host__ __device__ float2 floor(const float2 v)
{
return make_float2(floor(v.x), floor(v.y));
}
// reflect
inline __host__ __device__ float2 reflect(float2 i, float2 n)
{
return i - 2.0f * n * dot(n, i);
}
// absolute value
inline __host__ __device__ float2 fabs(float2 v)
{
return make_float2(fabs(v.x), fabs(v.y));
}
// float3 functions
////////////////////////////////////////////////////////////////////////////////
// additional constructors
inline __host__ __device__ float3 make_float3(float s)
{
return make_float3(s, s, s);
}
inline __host__ __device__ float3 make_float3(float2 a)
{
return make_float3(a.x, a.y, 0.0f);
}
inline __host__ __device__ float3 make_float3(float2 a, float s)
{
return make_float3(a.x, a.y, s);
}
inline __host__ __device__ float3 make_float3(float4 a)
{
return make_float3(a.x, a.y, a.z); // discards w
}
inline __host__ __device__ float3 make_float3(int3 a)
{
return make_float3(float(a.x), float(a.y), float(a.z));
}
// negate
inline __host__ __device__ float3 operator-(float3 &a)
{
return make_float3(-a.x, -a.y, -a.z);
}
// min
static __inline__ __host__ __device__ float3 fminf(float3 a, float3 b)
{
return make_float3(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z));
}
// max
static __inline__ __host__ __device__ float3 fmaxf(float3 a, float3 b)
{
return make_float3(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z));
}
// addition
inline __host__ __device__ float3 operator+(float3 a, float3 b)
{
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
inline __host__ __device__ float3 operator+(float3 a, float b)
{
return make_float3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ void operator+=(float3 &a, float3 b)
{
a.x += b.x;
a.y += b.y;
a.z += b.z;
}
// subtract
inline __host__ __device__ float3 operator-(float3 a, float3 b)
{
return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
}
inline __host__ __device__ float3 operator-(float3 a, float b)
{
return make_float3(a.x - b, a.y - b, a.z - b);
}
inline __host__ __device__ void operator-=(float3 &a, float3 b)
{
a.x -= b.x;
a.y -= b.y;
a.z -= b.z;
}
// multiply
inline __host__ __device__ float3 operator*(float3 a, float3 b)
{
return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
}
inline __host__ __device__ float3 operator*(float3 a, float s)
{
return make_float3(a.x * s, a.y * s, a.z * s);
}
inline __host__ __device__ float3 operator*(float s, float3 a)
{
return make_float3(a.x * s, a.y * s, a.z * s);
}
inline __host__ __device__ void operator*=(float3 &a, float s)
{
a.x *= s;
a.y *= s;
a.z *= s;
}
inline __host__ __device__ void operator*=(float3 &a, float3 b)
{
a.x *= b.x;
a.y *= b.y;
a.z *= b.z;
;
}
// divide
inline __host__ __device__ float3 operator/(float3 a, float3 b)
{
return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
}
inline __host__ __device__ float3 operator/(float3 a, float s)
{
float inv = 1.0f / s;
return a * inv;
}
// Scalar divided by vector, componentwise: (s / a.x, s / a.y, s / a.z).
// Fix: the original returned a * (1/s), i.e. it computed a/s instead of
// s/a; CUDA's helper_math.h defines this operator componentwise.
inline __host__ __device__ float3 operator/(float s, float3 a)
{
    return make_float3(s / a.x, s / a.y, s / a.z);
}
// in-place scalar division (multiplies by the reciprocal)
inline __host__ __device__ void operator/=(float3 &a, float s)
{
    float inv = 1.0f / s;
    a *= inv;
}
// linear interpolation: a + t * (b - a); t in [0, 1] blends from a to b
inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
{
    return a + t * (b - a);
}
// component-wise clamp to scalar bounds [a, b]
inline __device__ __host__ float3 clamp(float3 v, float a, float b)
{
    return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
}
// component-wise clamp to per-component bounds [a, b]
inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
{
    return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
}
// dot product
inline __host__ __device__ float dot(float3 a, float3 b)
{
    return a.x * b.x + a.y * b.y + a.z * b.z;
}
// cross product
inline __host__ __device__ float3 cross(float3 a, float3 b)
{
    return make_float3(a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x);
}
// Euclidean length
inline __host__ __device__ float length(float3 v)
{
    return sqrtf(dot(v, v));
}
// normalize to unit length via reciprocal sqrt (undefined for v == 0)
inline __host__ __device__ float3 normalize(float3 v)
{
    float invLen = rsqrtf(dot(v, v));
    return v * invLen;
}
// component-wise floor
inline __host__ __device__ float3 floor(const float3 v)
{
    return make_float3(floor(v.x), floor(v.y), floor(v.z));
}
// reflect incident vector i about normal n (n assumed unit length)
inline __host__ __device__ float3 reflect(float3 i, float3 n)
{
    return i - 2.0f * n * dot(n, i);
}
// component-wise absolute value
inline __host__ __device__ float3 fabs(float3 v)
{
    return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));
}
// float4 functions
////////////////////////////////////////////////////////////////////////////////
// additional constructors (w defaults to 0 when built from a float3)
inline __host__ __device__ float4 make_float4(float s)
{
    return make_float4(s, s, s, s);
}
inline __host__ __device__ float4 make_float4(float3 a)
{
    return make_float4(a.x, a.y, a.z, 0.0f);
}
inline __host__ __device__ float4 make_float4(float3 a, float w)
{
    return make_float4(a.x, a.y, a.z, w);
}
inline __host__ __device__ float4 make_float4(int4 a)
{
    return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
}
// negate
inline __host__ __device__ float4 operator-(float4 &a)
{
    return make_float4(-a.x, -a.y, -a.z, -a.w);
}
// component-wise min
static __inline__ __host__ __device__ float4 fminf(float4 a, float4 b)
{
    return make_float4(fminf(a.x, b.x), fminf(a.y, b.y), fminf(a.z, b.z), fminf(a.w, b.w));
}
// component-wise max
static __inline__ __host__ __device__ float4 fmaxf(float4 a, float4 b)
{
    return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z), fmaxf(a.w, b.w));
}
// addition
inline __host__ __device__ float4 operator+(float4 a, float4 b)
{
    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
inline __host__ __device__ void operator+=(float4 &a, float4 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
    a.w += b.w;
}
// subtract
inline __host__ __device__ float4 operator-(float4 a, float4 b)
{
    return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
inline __host__ __device__ void operator-=(float4 &a, float4 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
    a.w -= b.w;
}
// scalar multiply
inline __host__ __device__ float4 operator*(float4 a, float s)
{
    return make_float4(a.x * s, a.y * s, a.z * s, a.w * s);
}
inline __host__ __device__ float4 operator*(float s, float4 a)
{
    return make_float4(a.x * s, a.y * s, a.z * s, a.w * s);
}
inline __host__ __device__ void operator*=(float4 &a, float s)
{
    a.x *= s;
    a.y *= s;
    a.z *= s;
    a.w *= s;
}
// division: component-wise, and by scalar (via reciprocal)
inline __host__ __device__ float4 operator/(float4 a, float4 b)
{
    return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
}
inline __host__ __device__ float4 operator/(float4 a, float s)
{
    float inv = 1.0f / s;
    return a * inv;
}
// scalar / vector division (component-wise): returns (s/a.x, s/a.y, s/a.z, s/a.w).
// BUGFIX: the previous version computed a * (1/s), i.e. a / s — reversed
// operands — a known defect of the legacy cutil_math.h that NVIDIA's
// helper_math.h corrects.
inline __host__ __device__ float4 operator/(float s, float4 a)
{
    return make_float4(s / a.x, s / a.y, s / a.z, s / a.w);
}
// in-place scalar division (multiplies by the reciprocal)
inline __host__ __device__ void operator/=(float4 &a, float s)
{
    float inv = 1.0f / s;
    a *= inv;
}
// linear interpolation: a + t * (b - a)
inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
{
    return a + t * (b - a);
}
// component-wise clamp to scalar bounds [a, b]
inline __device__ __host__ float4 clamp(float4 v, float a, float b)
{
    return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
}
// component-wise clamp to per-component bounds [a, b]
inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)
{
    return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
}
// dot product
inline __host__ __device__ float dot(float4 a, float4 b)
{
    return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}
// Euclidean length
inline __host__ __device__ float length(float4 r)
{
    return sqrtf(dot(r, r));
}
// normalize to unit length via reciprocal sqrt (undefined for v == 0)
inline __host__ __device__ float4 normalize(float4 v)
{
    float invLen = rsqrtf(dot(v, v));
    return v * invLen;
}
// component-wise floor
inline __host__ __device__ float4 floor(const float4 v)
{
    return make_float4(floor(v.x), floor(v.y), floor(v.z), floor(v.w));
}
// component-wise absolute value
inline __host__ __device__ float4 fabs(float4 v)
{
    return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));
}
// int3 functions
////////////////////////////////////////////////////////////////////////////////
// additional constructors
inline __host__ __device__ int3 make_int3(int s)
{
    return make_int3(s, s, s);
}
inline __host__ __device__ int3 make_int3(float3 a)
{
    // truncates toward zero, as with built-in float -> int conversion
    return make_int3(int(a.x), int(a.y), int(a.z));
}
// negate
inline __host__ __device__ int3 operator-(int3 &a)
{
    return make_int3(-a.x, -a.y, -a.z);
}
// component-wise min
inline __host__ __device__ int3 min(int3 a, int3 b)
{
    return make_int3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z));
}
// component-wise max
inline __host__ __device__ int3 max(int3 a, int3 b)
{
    return make_int3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z));
}
// addition
inline __host__ __device__ int3 operator+(int3 a, int3 b)
{
    return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);
}
inline __host__ __device__ void operator+=(int3 &a, int3 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
}
// subtract
inline __host__ __device__ int3 operator-(int3 a, int3 b)
{
    return make_int3(a.x - b.x, a.y - b.y, a.z - b.z);
}
inline __host__ __device__ void operator-=(int3 &a, int3 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
}
// multiply
inline __host__ __device__ int3 operator*(int3 a, int3 b)
{
    return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);
}
inline __host__ __device__ int3 operator*(int3 a, int s)
{
    return make_int3(a.x * s, a.y * s, a.z * s);
}
inline __host__ __device__ int3 operator*(int s, int3 a)
{
    return make_int3(a.x * s, a.y * s, a.z * s);
}
inline __host__ __device__ void operator*=(int3 &a, int s)
{
    a.x *= s;
    a.y *= s;
    a.z *= s;
}
// divide (truncating integer division)
inline __host__ __device__ int3 operator/(int3 a, int3 b)
{
    return make_int3(a.x / b.x, a.y / b.y, a.z / b.z);
}
inline __host__ __device__ int3 operator/(int3 a, int s)
{
    return make_int3(a.x / s, a.y / s, a.z / s);
}
// scalar / vector division (component-wise, truncating integer division):
// returns (s/a.x, s/a.y, s/a.z).
// BUGFIX: the previous version computed a / s (reversed operands), the same
// defect as the float3/float4 scalar-division overloads in this header.
inline __host__ __device__ int3 operator/(int s, int3 a)
{
    return make_int3(s / a.x, s / a.y, s / a.z);
}
// in-place scalar division (truncating)
inline __host__ __device__ void operator/=(int3 &a, int s)
{
    a.x /= s;
    a.y /= s;
    a.z /= s;
}
// clamp f into [a, b] (scalar helper used by the vector overloads below)
inline __device__ __host__ int clamp(int f, int a, int b)
{
    return max(a, min(f, b));
}
// component-wise clamp to scalar bounds [a, b]
inline __device__ __host__ int3 clamp(int3 v, int a, int b)
{
    return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
}
// component-wise clamp to per-component bounds [a, b]
inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)
{
    return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
}
// uint3 functions
////////////////////////////////////////////////////////////////////////////////
// additional constructors
inline __host__ __device__ uint3 make_uint3(uint s)
{
    return make_uint3(s, s, s);
}
inline __host__ __device__ uint3 make_uint3(float3 a)
{
    // truncates toward zero; negative inputs are undefined behaviour for uint
    return make_uint3(uint(a.x), uint(a.y), uint(a.z));
}
// component-wise min
inline __host__ __device__ uint3 min(uint3 a, uint3 b)
{
    return make_uint3(min(a.x, b.x), min(a.y, b.y), min(a.z, b.z));
}
// component-wise max
inline __host__ __device__ uint3 max(uint3 a, uint3 b)
{
    return make_uint3(max(a.x, b.x), max(a.y, b.y), max(a.z, b.z));
}
// addition (wraps modulo 2^32 per component, as unsigned arithmetic does)
inline __host__ __device__ uint3 operator+(uint3 a, uint3 b)
{
    return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
}
inline __host__ __device__ void operator+=(uint3 &a, uint3 b)
{
    a.x += b.x;
    a.y += b.y;
    a.z += b.z;
}
// subtract
inline __host__ __device__ uint3 operator-(uint3 a, uint3 b)
{
    return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
}
inline __host__ __device__ void operator-=(uint3 &a, uint3 b)
{
    a.x -= b.x;
    a.y -= b.y;
    a.z -= b.z;
}
// multiply
inline __host__ __device__ uint3 operator*(uint3 a, uint3 b)
{
    return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
}
inline __host__ __device__ uint3 operator*(uint3 a, uint s)
{
    return make_uint3(a.x * s, a.y * s, a.z * s);
}
inline __host__ __device__ uint3 operator*(uint s, uint3 a)
{
    return make_uint3(a.x * s, a.y * s, a.z * s);
}
inline __host__ __device__ void operator*=(uint3 &a, uint s)
{
    a.x *= s;
    a.y *= s;
    a.z *= s;
}
// divide (truncating unsigned division)
inline __host__ __device__ uint3 operator/(uint3 a, uint3 b)
{
    return make_uint3(a.x / b.x, a.y / b.y, a.z / b.z);
}
inline __host__ __device__ uint3 operator/(uint3 a, uint s)
{
    return make_uint3(a.x / s, a.y / s, a.z / s);
}
// scalar / vector division (component-wise, truncating unsigned division):
// returns (s/a.x, s/a.y, s/a.z).
// BUGFIX: the previous version computed a / s (reversed operands), the same
// defect as the float3/float4/int3 scalar-division overloads in this header.
inline __host__ __device__ uint3 operator/(uint s, uint3 a)
{
    return make_uint3(s / a.x, s / a.y, s / a.z);
}
// in-place scalar division (truncating)
inline __host__ __device__ void operator/=(uint3 &a, uint s)
{
    a.x /= s;
    a.y /= s;
    a.z /= s;
}
// clamp f into [a, b] (scalar helper used by the vector overloads below)
inline __device__ __host__ uint clamp(uint f, uint a, uint b)
{
    return max(a, min(f, b));
}
// component-wise clamp to scalar bounds [a, b]
inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)
{
    return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
}
// component-wise clamp to per-component bounds [a, b]
inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)
{
    return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
}
#endif
================================================
FILE: third_party/sparse_voxels/include/intersect.h
================================================
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include
#include
std::tuple ball_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points,
const float radius, const int n_max);
std::tuple aabb_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points,
const float voxelsize, const int n_max);
std::tuple svo_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points, at::Tensor children,
const float voxelsize, const int n_max);
std::tuple triangle_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor face_points,
const float cagesize, const float blur, const int n_max);
================================================
FILE: third_party/sparse_voxels/include/octree.h
================================================
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include
#include
std::tuple build_octree(at::Tensor center, at::Tensor points, int depth);
================================================
FILE: third_party/sparse_voxels/include/sample.h
================================================
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include
#include
std::tuple uniform_ray_sampling(
at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise,
const float step_size, const int max_steps);
std::tuple inverse_cdf_sampling(
at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise,
at::Tensor probs, at::Tensor steps, float fixed_step_size);
================================================
FILE: third_party/sparse_voxels/include/utils.h
================================================
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include
#include
#define CHECK_CUDA(x) \
do \
{ \
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \
} while (0)
#define CHECK_CONTIGUOUS(x) \
do \
{ \
TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \
} while (0)
#define CHECK_IS_INT(x) \
do \
{ \
TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
#x " must be an int tensor"); \
} while (0)
#define CHECK_IS_FLOAT(x) \
do \
{ \
TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \
#x " must be a float tensor"); \
} while (0)
================================================
FILE: third_party/sparse_voxels/setup.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Build script for the `grid` CUDA extension (sparse-voxel ray ops)."""
import glob

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Every C++ and CUDA translation unit under src/ is compiled into one module.
sources = glob.glob("src/*.cpp") + glob.glob("src/*.cu")

setup(
    name="grid",
    ext_modules=[
        CUDAExtension(
            name="grid",
            sources=sources,
            include_dirs=["./include"],
            extra_compile_args={
                # identical flags for the host and device compilers
                "cxx": ["-O2", "-I./include"],
                "nvcc": ["-O2", "-I./include"],
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
================================================
FILE: third_party/sparse_voxels/src/binding.cpp
================================================
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include "../include/intersect.h"
#include "../include/octree.h"
#include "../include/sample.h"
// Python bindings for the sparse-voxel CUDA ops. The module name is injected
// by the build via TORCH_EXTENSION_NAME (see setup.py, name="grid").
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // ray/primitive intersection (declared in include/intersect.h)
    m.def("ball_intersect", &ball_intersect);
    m.def("aabb_intersect", &aabb_intersect);
    m.def("svo_intersect", &svo_intersect);
    m.def("triangle_intersect", &triangle_intersect);
    // depth sampling along rays (declared in include/sample.h)
    m.def("uniform_ray_sampling", &uniform_ray_sampling);
    m.def("inverse_cdf_sampling", &inverse_cdf_sampling);
    // octree construction (declared in include/octree.h)
    m.def("build_octree", &build_octree);
}
================================================
FILE: third_party/sparse_voxels/src/intersect.cpp
================================================
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include "../include/intersect.h"
#include "../include/utils.h"
#include
void ball_intersect_point_kernel_wrapper(
int b, int n, int m, float radius, int n_max,
const float *ray_start, const float *ray_dir, const float *points,
int *idx, float *min_depth, float *max_depth);
std::tuple ball_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points,
const float radius, const int n_max)
{
CHECK_CONTIGUOUS(ray_start);
CHECK_CONTIGUOUS(ray_dir);
CHECK_CONTIGUOUS(points);
CHECK_IS_FLOAT(ray_start);
CHECK_IS_FLOAT(ray_dir);
CHECK_IS_FLOAT(points);
CHECK_CUDA(ray_start);
CHECK_CUDA(ray_dir);
CHECK_CUDA(points);
at::Tensor idx =
torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
at::device(ray_start.device()).dtype(at::ScalarType::Int));
at::Tensor min_depth =
torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
at::device(ray_start.device()).dtype(at::ScalarType::Float));
at::Tensor max_depth =
torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
at::device(ray_start.device()).dtype(at::ScalarType::Float));
ball_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1),
radius, n_max,
ray_start.data_ptr(), ray_dir.data_ptr(), points.data_ptr(),
idx.data_ptr(), min_depth.data_ptr(), max_depth.data_ptr());
return std::make_tuple(idx, min_depth, max_depth);
}
void aabb_intersect_point_kernel_wrapper(
int b, int n, int m, float voxelsize, int n_max,
const float *ray_start, const float *ray_dir, const float *points,
int *idx, float *min_depth, float *max_depth);
std::tuple aabb_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points,
const float voxelsize, const int n_max)
{
CHECK_CONTIGUOUS(ray_start);
CHECK_CONTIGUOUS(ray_dir);
CHECK_CONTIGUOUS(points);
CHECK_IS_FLOAT(ray_start);
CHECK_IS_FLOAT(ray_dir);
CHECK_IS_FLOAT(points);
CHECK_CUDA(ray_start);
CHECK_CUDA(ray_dir);
CHECK_CUDA(points);
at::Tensor idx =
torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
at::device(ray_start.device()).dtype(at::ScalarType::Int));
at::Tensor min_depth =
torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
at::device(ray_start.device()).dtype(at::ScalarType::Float));
at::Tensor max_depth =
torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
at::device(ray_start.device()).dtype(at::ScalarType::Float));
aabb_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1),
voxelsize, n_max,
ray_start.data_ptr(), ray_dir.data_ptr(), points.data_ptr(),
idx.data_ptr(), min_depth.data_ptr(), max_depth.data_ptr());
return std::make_tuple(idx, min_depth, max_depth);
}
void svo_intersect_point_kernel_wrapper(
int b, int n, int m, float voxelsize, int n_max,
const float *ray_start, const float *ray_dir, const float *points, const int *children,
int *idx, float *min_depth, float *max_depth);
std::tuple svo_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor points,
at::Tensor children, const float voxelsize, const int n_max)
{
CHECK_CONTIGUOUS(ray_start);
CHECK_CONTIGUOUS(ray_dir);
CHECK_CONTIGUOUS(points);
CHECK_CONTIGUOUS(children);
CHECK_IS_FLOAT(ray_start);
CHECK_IS_FLOAT(ray_dir);
CHECK_IS_FLOAT(points);
CHECK_CUDA(ray_start);
CHECK_CUDA(ray_dir);
CHECK_CUDA(points);
CHECK_CUDA(children);
at::Tensor idx =
torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
at::device(ray_start.device()).dtype(at::ScalarType::Int));
at::Tensor min_depth =
torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
at::device(ray_start.device()).dtype(at::ScalarType::Float));
at::Tensor max_depth =
torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
at::device(ray_start.device()).dtype(at::ScalarType::Float));
svo_intersect_point_kernel_wrapper(points.size(0), points.size(1), ray_start.size(1),
voxelsize, n_max,
ray_start.data_ptr(), ray_dir.data_ptr(), points.data_ptr(),
children.data_ptr(), idx.data_ptr(), min_depth.data_ptr(), max_depth.data_ptr());
return std::make_tuple(idx, min_depth, max_depth);
}
void triangle_intersect_point_kernel_wrapper(
int b, int n, int m, float cagesize, float blur, int n_max,
const float *ray_start, const float *ray_dir, const float *face_points,
int *idx, float *depth, float *uv);
std::tuple triangle_intersect(at::Tensor ray_start, at::Tensor ray_dir, at::Tensor face_points,
const float cagesize, const float blur, const int n_max)
{
CHECK_CONTIGUOUS(ray_start);
CHECK_CONTIGUOUS(ray_dir);
CHECK_CONTIGUOUS(face_points);
CHECK_IS_FLOAT(ray_start);
CHECK_IS_FLOAT(ray_dir);
CHECK_IS_FLOAT(face_points);
CHECK_CUDA(ray_start);
CHECK_CUDA(ray_dir);
CHECK_CUDA(face_points);
at::Tensor idx =
torch::zeros({ray_start.size(0), ray_start.size(1), n_max},
at::device(ray_start.device()).dtype(at::ScalarType::Int));
at::Tensor depth =
torch::zeros({ray_start.size(0), ray_start.size(1), n_max * 3},
at::device(ray_start.device()).dtype(at::ScalarType::Float));
at::Tensor uv =
torch::zeros({ray_start.size(0), ray_start.size(1), n_max * 2},
at::device(ray_start.device()).dtype(at::ScalarType::Float));
triangle_intersect_point_kernel_wrapper(face_points.size(0), face_points.size(1), ray_start.size(1),
cagesize, blur, n_max,
ray_start.data_ptr(), ray_dir.data_ptr(), face_points.data_ptr(),
idx.data_ptr(), depth.data_ptr(), uv.data_ptr());
return std::make_tuple(idx, depth, uv);
}
================================================
FILE: third_party/sparse_voxels/src/intersect_gpu.cu
================================================
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include
#include
#include
#include "../include/cuda_utils.h"
#include "../include/cutil_math.h" // required for float3 vector math
// Ray/sphere intersection. One CUDA block per batch element; threads stride
// over the m rays. For each ray, records up to n_max points whose distance to
// the ray line is below `radius`, storing the point index and the
// [depth - blur, depth + blur] interval where the ray is inside the sphere.
// Unused idx slots are -1; hits are kept in point order, not sorted by depth.
// NOTE(review): depth math assumes ray_dir is unit length — confirm callers.
__global__ void ball_intersect_point_kernel(
    int b, int n, int m, float radius,
    int n_max,
    const float *__restrict__ ray_start,
    const float *__restrict__ ray_dir,
    const float *__restrict__ points,
    int *__restrict__ idx,
    float *__restrict__ min_depth,
    float *__restrict__ max_depth)
{
    // advance all pointers to this batch element's slice
    int batch_index = blockIdx.x;
    points += batch_index * n * 3;
    ray_start += batch_index * m * 3;
    ray_dir += batch_index * m * 3;
    idx += batch_index * m * n_max;
    min_depth += batch_index * m * n_max;
    max_depth += batch_index * m * n_max;

    int index = threadIdx.x;
    int stride = blockDim.x;
    float radius2 = radius * radius;

    for (int j = index; j < m; j += stride)
    {
        float x0 = ray_start[j * 3 + 0];
        float y0 = ray_start[j * 3 + 1];
        float z0 = ray_start[j * 3 + 2];
        float xw = ray_dir[j * 3 + 0];
        float yw = ray_dir[j * 3 + 1];
        float zw = ray_dir[j * 3 + 2];

        for (int l = 0; l < n_max; ++l)
        {
            idx[j * n_max + l] = -1;
        }
        for (int k = 0, cnt = 0; k < n && cnt < n_max; ++k)
        {
            // vector from the ray origin to the candidate point
            float x = points[k * 3 + 0] - x0;
            float y = points[k * 3 + 1] - y0;
            float z = points[k * 3 + 2] - z0;
            float d2 = x * x + y * y + z * z;
            // FIX(perf): was pow(..., 2), which promotes to double-precision
            // pow(); a single-precision multiply is exact and far cheaper.
            float proj = x * xw + y * yw + z * zw;
            float d2_proj = proj * proj;
            float r2 = d2 - d2_proj; // squared distance from the point to the ray line
            if (r2 < radius2)
            {
                idx[j * n_max + cnt] = k;
                // sqrt -> sqrtf: keep the arithmetic in single precision
                float depth = sqrtf(d2_proj);
                float depth_blur = sqrtf(radius2 - r2);
                min_depth[j * n_max + cnt] = depth - depth_blur;
                max_depth[j * n_max + cnt] = depth + depth_blur;
                ++cnt;
            }
        }
    }
}
// Slab-method ray vs. axis-aligned-cube intersection.
// Returns (t_near, t_far) along the ray for the cube centred at `center`
// with half edge length `half_voxel`, or (-1, -1) on a miss. The entry time
// starts at 0, so boxes entirely behind the ray origin are rejected.
// RESTORED(review): the `&center` parameter name was mojibake ("¢er") from
// HTML-entity mangling of "&center" and has been reinstated.
__device__ float2 RayAABBIntersection(
    const float3 &ori,
    const float3 &dir,
    const float3 &center,
    float half_voxel)
{
    float f_low = 0;
    float f_high = 100000.f; // effectively +infinity here (was a double literal)
    float f_dim_low, f_dim_high, temp, inv_ray_dir, start, aabb;
    for (int d = 0; d < 3; ++d)
    {
        // pick the d-th component (float3 has no operator[])
        switch (d)
        {
        case 0:
            inv_ray_dir = __fdividef(1.0f, dir.x); // fast approximate reciprocal
            start = ori.x;
            aabb = center.x;
            break;
        case 1:
            inv_ray_dir = __fdividef(1.0f, dir.y);
            start = ori.y;
            aabb = center.y;
            break;
        case 2:
            inv_ray_dir = __fdividef(1.0f, dir.z);
            start = ori.z;
            aabb = center.z;
            break;
        }
        // parametric distances to this axis's two slab planes
        f_dim_low = (aabb - half_voxel - start) * inv_ray_dir;
        f_dim_high = (aabb + half_voxel - start) * inv_ray_dir;
        // Make sure low is less than high
        if (f_dim_high < f_dim_low)
        {
            temp = f_dim_low;
            f_dim_low = f_dim_high;
            f_dim_high = temp;
        }
        // If this dimension's high is less than the low we got then we definitely missed.
        if (f_dim_high < f_low)
        {
            return make_float2(-1.0f, -1.0f);
        }
        // Likewise if the low is greater than the running high.
        if (f_dim_low > f_high)
        {
            return make_float2(-1.0f, -1.0f);
        }
        // Intersect the running [f_low, f_high] window with this axis's interval
        f_low = (f_dim_low > f_low) ? f_dim_low : f_low;
        f_high = (f_dim_high < f_high) ? f_dim_high : f_high;
        if (f_low > f_high)
        {
            return make_float2(-1.0f, -1.0f);
        }
    }
    return make_float2(f_low, f_high);
}
// Ray/voxel-grid intersection. One CUDA block per batch element; threads
// stride over the m rays. For each ray, runs the slab test against every
// voxel centre (axis-aligned cube of edge `voxelsize`) and records up to
// n_max hits as (voxel index, entry depth, exit depth). Unused idx slots are
// -1; hits are kept in voxel order, not sorted by depth.
__global__ void aabb_intersect_point_kernel(
    int b, int n, int m, float voxelsize,
    int n_max,
    const float *__restrict__ ray_start,
    const float *__restrict__ ray_dir,
    const float *__restrict__ points,
    int *__restrict__ idx,
    float *__restrict__ min_depth,
    float *__restrict__ max_depth)
{
    // advance all pointers to this batch element's slice
    int batch_index = blockIdx.x;
    points += batch_index * n * 3;
    ray_start += batch_index * m * 3;
    ray_dir += batch_index * m * 3;
    idx += batch_index * m * n_max;
    min_depth += batch_index * m * n_max;
    max_depth += batch_index * m * n_max;

    int index = threadIdx.x;
    int stride = blockDim.x;
    // FIX: 0.5f instead of 0.5 — avoids a double-precision multiply in a float kernel
    float half_voxel = voxelsize * 0.5f;

    for (int j = index; j < m; j += stride)
    {
        for (int l = 0; l < n_max; ++l)
        {
            idx[j * n_max + l] = -1;
        }
        for (int k = 0, cnt = 0; k < n && cnt < n_max; ++k)
        {
            float2 depths = RayAABBIntersection(
                make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]),
                make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]),
                make_float3(points[k * 3 + 0], points[k * 3 + 1], points[k * 3 + 2]),
                half_voxel);
            if (depths.x > -1.0f)
            {
                idx[j * n_max + cnt] = k;
                min_depth[j * n_max + cnt] = depths.x;
                max_depth[j * n_max + cnt] = depths.y;
                ++cnt;
            }
        }
    }
}
// Ray vs. sparse-voxel-octree intersection. One CUDA block per batch element;
// threads stride over the m rays. Each ray does an explicit-stack DFS from
// the root node: children[k*9 + 0..7] are child node indices (-1 = absent),
// children[k*9 + 8] is the node's size multiplier (1 for terminal nodes).
// Terminal-node hits are recorded as (node index, entry depth, exit depth),
// up to n_max per ray; unused idx slots are -1.
__global__ void svo_intersect_point_kernel(
    int b, int n, int m, float voxelsize,
    int n_max,
    const float *__restrict__ ray_start,
    const float *__restrict__ ray_dir,
    const float *__restrict__ points,
    const int *__restrict__ children,
    int *__restrict__ idx,
    float *__restrict__ min_depth,
    float *__restrict__ max_depth)
{
    /*
    TODO: this is an inefficient implementation of the
    naive Ray -- Sparse Voxel Octree Intersection.
    It can be further improved using:
    Revelles, Jorge, Carlos Urena, and Miguel Lastra.
    "An efficient parametric algorithm for octree traversal." (2000).
    */
    // advance all pointers to this batch element's slice
    int batch_index = blockIdx.x;
    points += batch_index * n * 3;
    children += batch_index * n * 9;
    ray_start += batch_index * m * 3;
    ray_dir += batch_index * m * 3;
    idx += batch_index * m * n_max;
    min_depth += batch_index * m * n_max;
    max_depth += batch_index * m * n_max;

    int index = threadIdx.x;
    int stride = blockDim.x;
    float half_voxel = voxelsize * 0.5;

    for (int j = index; j < m; j += stride)
    {
        for (int l = 0; l < n_max; ++l)
        {
            idx[j * n_max + l] = -1;
        }
        int stack[256] = {-1}; // DFS, initialize the stack
        int ptr = 0, cnt = 0, k = -1;
        // stack[ptr] = n - 1; // ROOT node is always the last
        // NOTE(review): here traversal instead starts from node 0 — confirm
        // this matches the node ordering produced by the octree builder.
        stack[ptr] = 0;
        while (ptr > -1 && cnt < n_max)
        {
            assert((ptr < 256)); // stack overflow guard (device-side assert)
            // evaluate the current node: its cube is half_voxel scaled by the
            // node size stored in children[k*9 + 8]
            k = stack[ptr];
            float2 depths = RayAABBIntersection(
                make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]),
                make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]),
                make_float3(points[k * 3 + 0], points[k * 3 + 1], points[k * 3 + 2]),
                half_voxel * float(children[k * 9 + 8]));
            stack[ptr] = -1; // pop the node before (maybe) pushing its children
            ptr--;
            if (depths.x > -1.0f)
            { // ray did not miss the voxel
                // TODO: here it should be able to know which children is ok, further optimize the code
                if (children[k * 9 + 8] == 1)
                { // this is a terminal node: record the hit interval
                    idx[j * n_max + cnt] = k;
                    min_depth[j * n_max + cnt] = depths.x;
                    max_depth[j * n_max + cnt] = depths.y;
                    ++cnt;
                    continue;
                }
                for (int u = 0; u < 8; u++)
                {
                    if (children[k * 9 + u] > -1)
                    {
                        ptr++;
                        stack[ptr] = children[k * 9 + u]; // push child to the stack
                    }
                }
            }
        }
    }
}
// Möller–Trumbore-style ray/triangle intersection against (v0, v1, v2).
// Returns (t, u, v): t is the ray parameter of the hit and (u, v) the
// barycentric coordinates; returns (-1, 0, 0) when the hit falls outside the
// triangle. Each range test is relaxed by `blur` to accept near-edge hits.
// NOTE(review): the parallel/degenerate case (det == 0) is not guarded; the
// resulting inf/NaN values are expected to fail the range checks — confirm.
__device__ float3 RayTriangleIntersection(
    const float3 &ori,
    const float3 &dir,
    const float3 &v0,
    const float3 &v1,
    const float3 &v2,
    float blur)
{
    float3 v0v1 = v1 - v0; // triangle edge v0 -> v1
    float3 v0v2 = v2 - v0; // triangle edge v0 -> v2
    float3 v0O = ori - v0; // v0 -> ray origin

    float3 dir_crs_v0v2 = cross(dir, v0v2);

    float det = dot(v0v1, dir_crs_v0v2);
    det = __fdividef(1.0f, det); // CUDA intrinsic function (fast reciprocal)

    float u = dot(v0O, dir_crs_v0v2) * det;
    if ((u < 0.0f - blur) || (u > 1.0f + blur))
        return make_float3(-1.0f, 0.0f, 0.0f);

    float3 v0O_crs_v0v1 = cross(v0O, v0v1);
    float v = dot(dir, v0O_crs_v0v1) * det;
    if ((v < 0.0f - blur) || (v > 1.0f + blur))
        return make_float3(-1.0f, 0.0f, 0.0f);

    if (((u + v) < 0.0f - blur) || ((u + v) > 1.0f + blur))
        return make_float3(-1.0f, 0.0f, 0.0f);

    float t = dot(v0v2, v0O_crs_v0v1) * det;
    return make_float3(t, u, v);
}
// Ray/triangle intersection with depth-sorted output. One CUDA block per
// batch element; threads stride over the m rays. For each ray, keeps up to
// n_max hits sorted by increasing depth via insertion sort, then derives a
// [min, max] depth window per hit (half the gap to the neighbouring hit,
// capped at cagesize). Layout per ray: idx has n_max ints (-1-padded);
// depth packs (t, min_offset, max_offset) per hit; uv packs (u, v) per hit.
__global__ void triangle_intersect_point_kernel(
    int b, int n, int m, float cagesize,
    float blur, int n_max,
    const float *__restrict__ ray_start,
    const float *__restrict__ ray_dir,
    const float *__restrict__ face_points,
    int *__restrict__ idx,
    float *__restrict__ depth,
    float *__restrict__ uv)
{
    // advance all pointers to this batch element's slice
    int batch_index = blockIdx.x;
    face_points += batch_index * n * 9;
    ray_start += batch_index * m * 3;
    ray_dir += batch_index * m * 3;
    idx += batch_index * m * n_max;
    depth += batch_index * m * n_max * 3;
    uv += batch_index * m * n_max * 2;

    int index = threadIdx.x;
    int stride = blockDim.x;
    for (int j = index; j < m; j += stride)
    {
        // go over rays
        for (int l = 0; l < n_max; ++l)
        {
            idx[j * n_max + l] = -1;
        }
        int cnt = 0;
        for (int k = 0; k < n && cnt < n_max; ++k)
        {
            // go over triangles (9 floats per face: three vertices)
            float3 tuv = RayTriangleIntersection(
                make_float3(ray_start[j * 3 + 0], ray_start[j * 3 + 1], ray_start[j * 3 + 2]),
                make_float3(ray_dir[j * 3 + 0], ray_dir[j * 3 + 1], ray_dir[j * 3 + 2]),
                make_float3(face_points[k * 9 + 0], face_points[k * 9 + 1], face_points[k * 9 + 2]),
                make_float3(face_points[k * 9 + 3], face_points[k * 9 + 4], face_points[k * 9 + 5]),
                make_float3(face_points[k * 9 + 6], face_points[k * 9 + 7], face_points[k * 9 + 8]),
                blur);
            if (tuv.x > 0)
            { // hit in front of the origin
                int ki = k;
                float d = tuv.x, u = tuv.y, v = tuv.z;
                // insertion sort by depth: bubble the new hit into place by
                // swapping it with every stored hit that lies deeper
                for (int l = 0; l < cnt; l++)
                {
                    if (d < depth[j * n_max * 3 + l * 3])
                    {
                        swap(ki, idx[j * n_max + l]);
                        swap(d, depth[j * n_max * 3 + l * 3]);
                        swap(u, uv[j * n_max * 2 + l * 2]);
                        swap(v, uv[j * n_max * 2 + l * 2 + 1]);
                    }
                }
                idx[j * n_max + cnt] = ki;
                depth[j * n_max * 3 + cnt * 3] = d;
                uv[j * n_max * 2 + cnt * 2] = u;
                uv[j * n_max * 2 + cnt * 2 + 1] = v;
                cnt++;
            }
        }
        // second pass: per-hit depth windows from the sorted depths
        for (int l = 0; l < cnt; l++)
        {
            // compute min_depth (offset before the hit)
            if (l == 0)
                depth[j * n_max * 3 + l * 3 + 1] = -cagesize;
            else
                depth[j * n_max * 3 + l * 3 + 1] = -fminf(cagesize,
                                                          .5 * (depth[j * n_max * 3 + l * 3] - depth[j * n_max * 3 + l * 3 - 3]));
            // compute max_depth (offset after the hit)
            if (l == cnt - 1)
                depth[j * n_max * 3 + l * 3 + 2] = cagesize;
            else
                depth[j * n_max * 3 + l * 3 + 2] = fminf(cagesize,
                                                         .5 * (depth[j * n_max * 3 + l * 3 + 3] - depth[j * n_max * 3 + l * 3]));
        }
    }
}
// Host launcher: one block per batch element, thread count derived from the
// ray count m on PyTorch's current CUDA stream.
// RESTORED(review): the <<<...>>> execution configuration was stripped during
// extraction; reinstated as <<<b, opt_n_threads(m), 0, stream>>> using the
// opt_n_threads helper from cuda_utils.h — confirm the helper name there.
void ball_intersect_point_kernel_wrapper(
    int b, int n, int m, float radius, int n_max,
    const float *ray_start, const float *ray_dir, const float *points,
    int *idx, float *min_depth, float *max_depth)
{
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    ball_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
        b, n, m, radius, n_max, ray_start, ray_dir, points, idx, min_depth, max_depth);
    CUDA_CHECK_ERRORS();
}
// Host launcher: one block per batch element, thread count derived from the
// ray count m on PyTorch's current CUDA stream.
// RESTORED(review): the <<<...>>> execution configuration was stripped during
// extraction; reinstated as <<<b, opt_n_threads(m), 0, stream>>> using the
// opt_n_threads helper from cuda_utils.h — confirm the helper name there.
void aabb_intersect_point_kernel_wrapper(
    int b, int n, int m, float voxelsize, int n_max,
    const float *ray_start, const float *ray_dir, const float *points,
    int *idx, float *min_depth, float *max_depth)
{
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    aabb_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
        b, n, m, voxelsize, n_max, ray_start, ray_dir, points, idx, min_depth, max_depth);
    CUDA_CHECK_ERRORS();
}
// Host launcher: one block per batch element, thread count derived from the
// ray count m on PyTorch's current CUDA stream.
// RESTORED(review): the <<<...>>> execution configuration was stripped during
// extraction; reinstated as <<<b, opt_n_threads(m), 0, stream>>> using the
// opt_n_threads helper from cuda_utils.h — confirm the helper name there.
void svo_intersect_point_kernel_wrapper(
    int b, int n, int m, float voxelsize, int n_max,
    const float *ray_start, const float *ray_dir, const float *points, const int *children,
    int *idx, float *min_depth, float *max_depth)
{
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    svo_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
        b, n, m, voxelsize, n_max, ray_start, ray_dir, points, children, idx, min_depth, max_depth);
    CUDA_CHECK_ERRORS();
}
// Host launcher: one block per batch element, thread count derived from the
// ray count m on PyTorch's current CUDA stream.
// RESTORED(review): the <<<...>>> execution configuration was stripped during
// extraction; reinstated as <<<b, opt_n_threads(m), 0, stream>>> using the
// opt_n_threads helper from cuda_utils.h — confirm the helper name there.
void triangle_intersect_point_kernel_wrapper(
    int b, int n, int m, float cagesize, float blur, int n_max,
    const float *ray_start, const float *ray_dir, const float *face_points,
    int *idx, float *depth, float *uv)
{
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    triangle_intersect_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
        b, n, m, cagesize, blur, n_max, ray_start, ray_dir, face_points, idx, depth, uv);
    CUDA_CHECK_ERRORS();
}
================================================
FILE: third_party/sparse_voxels/src/octree.cpp
================================================
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include "../include/octree.h"
#include "../include/utils.h"
#include <chrono>
#include <queue>
using namespace std::chrono;
// A node of a simple pointer-based octree.
// depth >= 0: internal node with `depth` levels remaining below it;
// depth == -1: terminal node holding an inserted point.
// index: for terminal nodes, the inserted point's index; for internal nodes
// it is assigned later by EasyOctree::finalize().
typedef struct OcTree
{
    int depth;
    int index;
    at::Tensor center;
    struct OcTree *children[8];
    void init(at::Tensor center, int d, int i)
    {
        this->center = center;
        this->depth = d;
        this->index = i;
        // FIX: loop counter renamed from `i`, which shadowed the parameter `i`
        for (int c = 0; c < 8; c++)
            this->children[c] = nullptr;
    }
} OcTree;
class EasyOctree
{
public:
OcTree *root;
int total;
int terminal;
at::Tensor all_centers;
at::Tensor all_children;
EasyOctree(at::Tensor center, int depth)
{
root = new OcTree;
root->init(center, depth, -1);
total = -1;
terminal = -1;
}
~EasyOctree()
{
OcTree *p = root;
destory(p);
}
void destory(OcTree *&p);
void insert(OcTree *&p, at::Tensor point, int index);
void finalize();
std::pair count(OcTree *&p);
};
// Recursively delete the subtree rooted at p and null out the caller's
// pointer. (The name "destory" [sic] is kept for API compatibility.)
void EasyOctree::destory(OcTree *&p)
{
    if (p != nullptr)
    {
        for (int i = 0; i < 8; i++)
        {
            // children are freed (and nulled) before their parent
            if (p->children[i] != nullptr)
                destory(p->children[i]);
        }
        delete p;
        p = nullptr;
    }
}
// Insert `point` (integer grid coordinates) carrying payload `index` into the
// subtree rooted at p. The octant is chosen per axis by comparing the point
// against the node centre; at depth 0 the chosen slot receives a terminal
// node holding the point itself.
// RESTORED(review): the template argument of Tensor::item<int>() was stripped
// during extraction and has been reinstated.
// NOTE(review): a second point landing in the same depth-0 octant replaces
// (and leaks) the previous terminal node — confirm inputs are deduplicated.
void EasyOctree::insert(OcTree *&p, at::Tensor point, int index)
{
    // per-axis bit: 1 if the point lies above the centre on that axis
    at::Tensor diff = (point > p->center).to(at::kInt);
    int idx = diff[0].item<int>() + 2 * diff[1].item<int>() + 4 * diff[2].item<int>();
    if (p->depth == 0)
    {
        // leaf level: attach a terminal node (depth -1) carrying the point index
        p->children[idx] = new OcTree;
        p->children[idx]->init(point, -1, index);
    }
    else
    {
        if (p->children[idx] == nullptr)
        {
            // create the missing internal child, offset by half this node's extent
            int length = 1 << (p->depth - 1);
            at::Tensor new_center = p->center + (2 * diff - 1) * length;
            p->children[idx] = new OcTree;
            p->children[idx]->init(new_center, p->depth - 1, -1);
        }
        insert(p->children[idx], point, index);
    }
}
// Count nodes in the subtree rooted at p.
// Returns (total nodes including p itself, terminal nodes with depth == -1).
// NOTE(review): std::pair template arguments restored (stripped during
// extraction).
std::pair<int, int> EasyOctree::count(OcTree *&p)
{
    int total = 0, terminal = 0;
    for (int i = 0; i < 8; i++)
    {
        if (p->children[i] != nullptr)
        {
            std::pair<int, int> sub = count(p->children[i]);
            total += sub.first;
            terminal += sub.second;
        }
    }
    total += 1;
    if (p->depth == -1)
        terminal += 1;
    return std::make_pair(total, terminal);
}
void EasyOctree::finalize()
{
std::pair outs = count(root);
total = outs.first;
terminal = outs.second;
all_centers =
torch::zeros({outs.first, 3}, at::device(root->center.device()).dtype(at::ScalarType::Int));
all_children =
-torch::ones({outs.first, 9}, at::device(root->center.device()).dtype(at::ScalarType::Int));
int node_idx = outs.first - 1;
root->index = node_idx;
std::queue all_leaves;
all_leaves.push(root);
while (!all_leaves.empty())
{
OcTree *node_ptr = all_leaves.front();
all_leaves.pop();
for (int i = 0; i < 8; i++)
{
if (node_ptr->children[i] != nullptr)
{
if (node_ptr->children[i]->depth > -1)
{
node_idx--;
node_ptr->children[i]->index = node_idx;
}
all_leaves.push(node_ptr->children[i]);
all_children[node_ptr->index][i] = node_ptr->children[i]->index;
}
}
all_children[node_ptr->index][8] = 1 << (node_ptr->depth + 1);
all_centers[node_ptr->index] = node_ptr->center;
}
assert(node_idx == outs.second);
};
std::tuple build_octree(at::Tensor center, at::Tensor points, int depth)
{
auto start = high_resolution_clock::now();
EasyOctree tree(center, depth);
for (int k = 0; k < points.size(0); k++)
tree.insert(tree.root, points[k], k);
tree.finalize();
auto stop = high_resolution_clock::now();
auto duration = duration_cast(stop - start);
printf("Building EasyOctree done. total #nodes = %d, terminal #nodes = %d (time taken %f s)\n",
tree.total, tree.terminal, float(duration.count()) / 1000000.);
return std::make_tuple(tree.all_centers, tree.all_children);
}
================================================
FILE: third_party/sparse_voxels/src/sample.cpp
================================================
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include "../include/sample.h"
#include "../include/utils.h"
#include <utility>
// Forward declarations of the CUDA launchers defined in sample_gpu.cu.
void uniform_ray_sampling_kernel_wrapper(
    int b, int num_rays, int max_hits, int max_steps, float step_size,
    const int *pts_idx, const float *min_depth, const float *max_depth, const float *uniform_noise,
    int *sampled_idx, float *sampled_depth, float *sampled_dists);
void inverse_cdf_sampling_kernel_wrapper(
    int b, int num_rays, int max_hits, int max_steps, float fixed_step_size,
    const int *pts_idx, const float *min_depth, const float *max_depth,
    const float *uniform_noise, const float *probs, const float *steps,
    int *sampled_idx, float *sampled_depth, float *sampled_dists);
std::tuple uniform_ray_sampling(
at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise,
const float step_size, const int max_steps)
{
CHECK_CONTIGUOUS(pts_idx);
CHECK_CONTIGUOUS(min_depth);
CHECK_CONTIGUOUS(max_depth);
CHECK_CONTIGUOUS(uniform_noise);
CHECK_IS_FLOAT(min_depth);
CHECK_IS_FLOAT(max_depth);
CHECK_IS_FLOAT(uniform_noise);
CHECK_IS_INT(pts_idx);
CHECK_CUDA(pts_idx);
CHECK_CUDA(min_depth);
CHECK_CUDA(max_depth);
CHECK_CUDA(uniform_noise);
at::Tensor sampled_idx =
-torch::ones({pts_idx.size(0), pts_idx.size(1), max_steps},
at::device(pts_idx.device()).dtype(at::ScalarType::Int));
at::Tensor sampled_depth =
torch::zeros({min_depth.size(0), min_depth.size(1), max_steps},
at::device(min_depth.device()).dtype(at::ScalarType::Float));
at::Tensor sampled_dists =
torch::zeros({min_depth.size(0), min_depth.size(1), max_steps},
at::device(min_depth.device()).dtype(at::ScalarType::Float));
uniform_ray_sampling_kernel_wrapper(min_depth.size(0), min_depth.size(1), min_depth.size(2), sampled_depth.size(2),
step_size,
pts_idx.data_ptr(), min_depth.data_ptr(), max_depth.data_ptr(),
uniform_noise.data_ptr(), sampled_idx.data_ptr(),
sampled_depth.data_ptr(), sampled_dists.data_ptr());
return std::make_tuple(sampled_idx, sampled_depth, sampled_dists);
}
std::tuple inverse_cdf_sampling(
at::Tensor pts_idx, at::Tensor min_depth, at::Tensor max_depth, at::Tensor uniform_noise,
at::Tensor probs, at::Tensor steps, float fixed_step_size)
{
CHECK_CONTIGUOUS(pts_idx);
CHECK_CONTIGUOUS(min_depth);
CHECK_CONTIGUOUS(max_depth);
CHECK_CONTIGUOUS(probs);
CHECK_CONTIGUOUS(steps);
CHECK_CONTIGUOUS(uniform_noise);
CHECK_IS_FLOAT(min_depth);
CHECK_IS_FLOAT(max_depth);
CHECK_IS_FLOAT(uniform_noise);
CHECK_IS_FLOAT(probs);
CHECK_IS_FLOAT(steps);
CHECK_IS_INT(pts_idx);
CHECK_CUDA(pts_idx);
CHECK_CUDA(min_depth);
CHECK_CUDA(max_depth);
CHECK_CUDA(uniform_noise);
CHECK_CUDA(probs);
CHECK_CUDA(steps);
int max_steps = uniform_noise.size(-1);
at::Tensor sampled_idx =
-torch::ones({pts_idx.size(0), pts_idx.size(1), max_steps},
at::device(pts_idx.device()).dtype(at::ScalarType::Int));
at::Tensor sampled_depth =
torch::zeros({min_depth.size(0), min_depth.size(1), max_steps},
at::device(min_depth.device()).dtype(at::ScalarType::Float));
at::Tensor sampled_dists =
torch::zeros({min_depth.size(0), min_depth.size(1), max_steps},
at::device(min_depth.device()).dtype(at::ScalarType::Float));
inverse_cdf_sampling_kernel_wrapper(min_depth.size(0), min_depth.size(1), min_depth.size(2), sampled_depth.size(2), fixed_step_size,
pts_idx.data_ptr(), min_depth.data_ptr(), max_depth.data_ptr(),
uniform_noise.data_ptr(), probs.data_ptr(), steps.data_ptr(),
sampled_idx.data_ptr(), sampled_depth.data_ptr(), sampled_dists.data_ptr());
return std::make_tuple(sampled_idx, sampled_depth, sampled_dists);
}
================================================
FILE: third_party/sparse_voxels/src/sample_gpu.cu
================================================
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "../include/cuda_utils.h"
#include "../include/cutil_math.h" // required for float3 vector math
// Uniform depth sampling along voxel-intersected rays.
// Launch layout: one block per batch element (blockIdx.x); threads stride
// over rays. Per-batch slices: pts_idx/min_depth/max_depth are
// (num_rays, max_hits); uniform_noise and the outputs are
// (num_rays, max_steps).
// Phase 1 merges three non-decreasing depth streams — voxel entry depths,
// voxel exit depths and jittered uniform steps — into
// sampled_depth/sampled_idx. Phase 2 converts adjacent depths into
// midpoint + interval samples and compacts those lying inside a voxel;
// the remaining output slots end with idx == -1.
__global__ void uniform_ray_sampling_kernel(
    int b, int num_rays,
    int max_hits,
    int max_steps,
    float step_size,
    const int *__restrict__ pts_idx,
    const float *__restrict__ min_depth,
    const float *__restrict__ max_depth,
    const float *__restrict__ uniform_noise,
    int *__restrict__ sampled_idx,
    float *__restrict__ sampled_depth,
    float *__restrict__ sampled_dists)
{
    int batch_index = blockIdx.x;
    int index = threadIdx.x;
    int stride = blockDim.x;
    // Shift all pointers to this batch element's slice.
    pts_idx += batch_index * num_rays * max_hits;
    min_depth += batch_index * num_rays * max_hits;
    max_depth += batch_index * num_rays * max_hits;
    uniform_noise += batch_index * num_rays * max_steps;
    sampled_idx += batch_index * num_rays * max_steps;
    sampled_depth += batch_index * num_rays * max_steps;
    sampled_dists += batch_index * num_rays * max_steps;
    // loop over all rays
    for (int j = index; j < num_rays; j += stride)
    {
        int H = j * max_hits, K = j * max_steps;
        // umin/umax walk the entry/exit depth streams; ucur the uniform steps.
        int s = 0, ucur = 0, umin = 0, umax = 0;
        float last_min_depth, last_max_depth, curr_depth;
        // sort all depths
        while (true)
        {
            if ((umax == max_hits) || (ucur == max_steps) || (pts_idx[H + umax] == -1))
            {
                break; // reach the maximum
            }
            // 10000.0 acts as +infinity once a stream is exhausted.
            if (umin < max_hits)
            {
                last_min_depth = min_depth[H + umin];
            }
            else
            {
                last_min_depth = 10000.0;
            }
            if (umax < max_hits)
            {
                last_max_depth = max_depth[H + umax];
            }
            else
            {
                last_max_depth = 10000.0;
            }
            if (ucur < max_steps)
            {
                // Jittered uniform step measured from the first voxel entry.
                curr_depth = min_depth[H] + (float(ucur) + uniform_noise[K + ucur]) * step_size;
            }
            // Emit whichever of the three candidate depths is smallest.
            if ((last_max_depth <= curr_depth) && (last_max_depth <= last_min_depth))
            {
                sampled_depth[K + s] = last_max_depth;
                sampled_idx[K + s] = pts_idx[H + umax];
                umax++;
                s++;
                continue;
            }
            if ((curr_depth <= last_min_depth) && (curr_depth <= last_max_depth))
            {
                sampled_depth[K + s] = curr_depth;
                // NOTE(review): reads pts_idx[H + umin - 1]; when umin == 0
                // this indexes before the ray's slice. Presumably a uniform
                // step can never precede the first voxel entry (curr_depth
                // starts at min_depth[H]), so umin > 0 here — TODO confirm.
                sampled_idx[K + s] = pts_idx[H + umin - 1];
                ucur++;
                s++;
                continue;
            }
            if ((last_min_depth <= curr_depth) && (last_min_depth <= last_max_depth))
            {
                sampled_depth[K + s] = last_min_depth;
                sampled_idx[K + s] = pts_idx[H + umin];
                umin++;
                s++;
                continue;
            }
        }
        // Phase 2: convert consecutive sorted depths into midpoint + length
        // samples, keeping only those inside a voxel (umin - 1 == umax).
        float l_depth, r_depth;
        int step = 0;
        for (ucur = 0, umin = 0, umax = 0; ucur < max_steps - 1; ucur++)
        {
            if (sampled_idx[K + ucur + 1] == -1)
                break;
            l_depth = sampled_depth[K + ucur];
            r_depth = sampled_depth[K + ucur + 1];
            sampled_depth[K + ucur] = (l_depth + r_depth) * .5;
            sampled_dists[K + ucur] = (r_depth - l_depth);
            // Advance the entry/exit cursors past this midpoint.
            if ((umin < max_hits) && (sampled_depth[K + ucur] >= min_depth[H + umin]) && (pts_idx[H + umin] > -1))
                umin++;
            if ((umax < max_hits) && (sampled_depth[K + ucur] >= max_depth[H + umax]) && (pts_idx[H + umax] > -1))
                umax++;
            if ((umax == max_hits) || (pts_idx[H + umax] == -1))
                break;
            if ((umin - 1 == umax) && (sampled_dists[K + ucur] > 0))
            {
                // Compact the kept sample down to position `step`.
                sampled_depth[K + step] = sampled_depth[K + ucur];
                sampled_dists[K + step] = sampled_dists[K + ucur];
                sampled_idx[K + step] = sampled_idx[K + ucur];
                step++;
            }
        }
        // Invalidate the unused tail of the output.
        for (int s = step; s < max_steps; s++)
        {
            sampled_idx[K + s] = -1;
        }
    }
}
// Inverse-CDF (importance) sampling along each ray.
// Launch layout: one block per batch element (blockIdx.x); threads stride
// over rays. Per-batch slices: pts_idx/min_depth/max_depth/probs are
// (num_rays, max_hits); steps is (num_rays,); uniform_noise and outputs are
// (num_rays, max_steps).
// Jittered positions in CDF space are walked bin by bin and mapped back to
// depths; outputs are interval midpoints (sampled_depth), interval lengths
// (sampled_dists) and the voxel id of each sample (sampled_idx).
__global__ void inverse_cdf_sampling_kernel(
    int b, int num_rays,
    int max_hits,
    int max_steps,
    float fixed_step_size,
    const int *__restrict__ pts_idx,
    const float *__restrict__ min_depth,
    const float *__restrict__ max_depth,
    const float *__restrict__ uniform_noise,
    const float *__restrict__ probs,
    const float *__restrict__ steps,
    int *__restrict__ sampled_idx,
    float *__restrict__ sampled_depth,
    float *__restrict__ sampled_dists)
{
    int batch_index = blockIdx.x;
    int index = threadIdx.x;
    int stride = blockDim.x;
    // Shift all pointers to this batch element's slice.
    pts_idx += batch_index * num_rays * max_hits;
    min_depth += batch_index * num_rays * max_hits;
    max_depth += batch_index * num_rays * max_hits;
    probs += batch_index * num_rays * max_hits;
    steps += batch_index * num_rays;
    uniform_noise += batch_index * num_rays * max_steps;
    sampled_idx += batch_index * num_rays * max_steps;
    sampled_depth += batch_index * num_rays * max_steps;
    sampled_dists += batch_index * num_rays * max_steps;
    // loop over all rays
    for (int j = index; j < num_rays; j += stride)
    {
        int H = j * max_hits, K = j * max_steps;
        int curr_bin = 0, s = 0;             // current index (bin)
        float curr_min_depth = min_depth[H]; // lower depth
        float curr_max_depth = max_depth[H]; // upper depth
        float curr_min_cdf = 0;
        float curr_max_cdf = probs[H];
        float step_size = 1.0 / steps[j];
        float z_low = curr_min_depth;
        int total_steps = int(ceil(steps[j]));
        bool done = false;
        // optional use a fixed step size
        if (fixed_step_size > 0.0)
            step_size = fixed_step_size;
        // sample points
        for (int curr_step = 0; curr_step < total_steps; curr_step++)
        {
            // Jittered position in CDF space for this step.
            float curr_cdf = (float(curr_step) + uniform_noise[K + curr_step]) * step_size;
            while (curr_cdf > curr_max_cdf)
            {
                // first include max cdf
                sampled_idx[K + s] = pts_idx[H + curr_bin];
                sampled_dists[K + s] = (curr_max_depth - z_low);
                sampled_depth[K + s] = (curr_max_depth + z_low) * .5;
                // move to next cdf
                curr_bin++;
                s++;
                if ((curr_bin >= max_hits) || (pts_idx[H + curr_bin] == -1))
                {
                    done = true;
                    break;
                }
                curr_min_depth = min_depth[H + curr_bin];
                curr_max_depth = max_depth[H + curr_bin];
                curr_min_cdf = curr_max_cdf;
                curr_max_cdf = curr_max_cdf + probs[H + curr_bin];
                z_low = curr_min_depth;
            }
            if (done)
                break;
            // if the sampled cdf is inside bin
            float u = (curr_cdf - curr_min_cdf) / (curr_max_cdf - curr_min_cdf);
            float z = curr_min_depth + u * (curr_max_depth - curr_min_depth);
            sampled_idx[K + s] = pts_idx[H + curr_bin];
            sampled_dists[K + s] = (z - z_low);
            sampled_depth[K + s] = (z + z_low) * .5;
            z_low = z;
            s++;
        }
        // if there are bins still remained
        // NOTE(review): the guard `num_rays > (H + curr_bin)` compares a ray
        // count against a hit-array offset; presumably intended to bound the
        // bin index — kept as-is, verify against callers.
        while ((z_low < curr_max_depth) && (!done) && (num_rays > (H + curr_bin)))
        {
            sampled_idx[K + s] = pts_idx[H + curr_bin];
            sampled_dists[K + s] = (curr_max_depth - z_low);
            sampled_depth[K + s] = (curr_max_depth + z_low) * .5;
            curr_bin++;
            s++;
            // BUGFIX: was `pts_idx[curr_bin]`, missing the per-ray offset H —
            // for every ray but the first this read another ray's hit list.
            // Every other access in this kernel uses pts_idx[H + curr_bin].
            if ((curr_bin >= max_hits) || (pts_idx[H + curr_bin] == -1))
                break;
            curr_min_depth = min_depth[H + curr_bin];
            curr_max_depth = max_depth[H + curr_bin];
            z_low = curr_min_depth;
        }
    }
}
// Host launcher for uniform_ray_sampling_kernel: one block per batch
// element, opt_n_threads(num_rays) threads, on the current ATen CUDA stream.
// All pointers are device pointers.
void uniform_ray_sampling_kernel_wrapper(
    int b, int num_rays, int max_hits, int max_steps, float step_size,
    const int *pts_idx, const float *min_depth, const float *max_depth, const float *uniform_noise,
    int *sampled_idx, float *sampled_depth, float *sampled_dists)
{
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // NOTE(review): the <<<...>>> launch configuration was stripped during
    // extraction; restored to match the upstream NSVF source.
    uniform_ray_sampling_kernel<<<b, opt_n_threads(num_rays), 0, stream>>>(
        b, num_rays, max_hits, max_steps, step_size, pts_idx,
        min_depth, max_depth, uniform_noise, sampled_idx, sampled_depth, sampled_dists);
    CUDA_CHECK_ERRORS();
}
// Host launcher for inverse_cdf_sampling_kernel: one block per batch
// element, opt_n_threads(num_rays) threads, on the current ATen CUDA stream.
// All pointers are device pointers.
void inverse_cdf_sampling_kernel_wrapper(
    int b, int num_rays, int max_hits, int max_steps, float fixed_step_size,
    const int *pts_idx, const float *min_depth, const float *max_depth,
    const float *uniform_noise, const float *probs, const float *steps,
    int *sampled_idx, float *sampled_depth, float *sampled_dists)
{
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // NOTE(review): the <<<...>>> launch configuration was stripped during
    // extraction; restored to match the upstream NSVF source.
    inverse_cdf_sampling_kernel<<<b, opt_n_threads(num_rays), 0, stream>>>(
        b, num_rays, max_hits, max_steps, fixed_step_size,
        pts_idx, min_depth, max_depth, uniform_noise, probs, steps,
        sampled_idx, sampled_depth, sampled_dists);
    CUDA_CHECK_ERRORS();
}