Repository: tangtaogo/lidar-nerf Branch: main Commit: 8083a1d74eef Files: 90 Total size: 579.5 KB Directory structure: gitextract_gza2uyyg/ ├── .github/ │ └── workflows/ │ └── formatter.yml ├── LICENSE ├── configs/ │ ├── kitti360_1538.txt │ ├── kitti360_1728.txt │ ├── kitti360_1908.txt │ ├── kitti360_3353.txt │ └── nerf_mvl.txt ├── extern/ │ ├── chamfer3D/ │ │ ├── chamfer3D.cu │ │ ├── chamfer_cuda.cpp │ │ ├── dist_chamfer_3D.py │ │ └── setup.py │ └── fscore.py ├── lidarmvl/ │ └── readme.md ├── lidarnerf/ │ ├── __init__.py │ ├── activation.py │ ├── convert.py │ ├── dataset/ │ │ ├── base_dataset.py │ │ ├── kitti360_dataset.py │ │ └── nerfmvl_dataset.py │ ├── encoding.py │ ├── ffmlp/ │ │ ├── __init__.py │ │ ├── backend.py │ │ ├── ffmlp.py │ │ ├── setup.py │ │ └── src/ │ │ ├── bindings.cpp │ │ ├── cutlass_matmul.h │ │ ├── ffmlp.cu │ │ ├── ffmlp.h │ │ └── utils.h │ ├── freqencoder/ │ │ ├── __init__.py │ │ ├── backend.py │ │ ├── freq.py │ │ ├── setup.py │ │ └── src/ │ │ ├── bindings.cpp │ │ ├── freqencoder.cu │ │ └── freqencoder.h │ ├── gridencoder/ │ │ ├── __init__.py │ │ ├── backend.py │ │ ├── grid.py │ │ ├── setup.py │ │ └── src/ │ │ ├── bindings.cpp │ │ ├── gridencoder.cu │ │ └── gridencoder.h │ ├── loss.py │ ├── nerf/ │ │ ├── network.py │ │ ├── network_tcnn.py │ │ ├── renderer.py │ │ └── utils.py │ ├── raymarching/ │ │ ├── __init__.py │ │ ├── backend.py │ │ ├── raymarching.py │ │ ├── setup.py │ │ └── src/ │ │ ├── bindings.cpp │ │ ├── raymarching.cu │ │ └── raymarching.h │ └── shencoder/ │ ├── __init__.py │ ├── backend.py │ ├── setup.py │ ├── sphere_harmonics.py │ └── src/ │ ├── bindings.cpp │ ├── shencoder.cu │ └── shencoder.h ├── lidarnvs/ │ ├── __init__.py │ ├── configs/ │ │ ├── pcgen_kitti360_raydrop.txt │ │ └── pcgen_nerfmvl_raydrop.txt │ ├── eval.py │ ├── lidarnvs_base.py │ ├── lidarnvs_meshing.py │ ├── lidarnvs_nksr.py │ ├── lidarnvs_pcgen.py │ ├── lidarnvs_poisson.py │ ├── loader.py │ ├── plot_possion_grid_search.py │ ├── raydrop_dataset_poisson.py │ ├── 
raydrop_train_pcgen.py │ ├── raydrop_train_poisson.py │ ├── readme.md │ ├── run.py │ └── unet.py ├── main_lidarnerf.py ├── preprocess/ │ ├── cal_centerpose_bound.py │ ├── generate_train_rangeview.py │ ├── kitti360_loader.py │ ├── kitti360_to_nerf.py │ ├── nerfmvl_loader.py │ └── nerfmvl_to_nerf.py ├── readme.md ├── requirements.txt ├── requirements_torch.txt └── setup.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/formatter.yml ================================================ name: Formatter on: push: branches: - main pull_request: types: [opened, reopened, synchronize] jobs: formatter: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: psf/black@stable ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2022 Tao Tang, Yixing Lao Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: configs/kitti360_1538.txt ================================================ sequence_id = 1538 alpha_d = 1000.0 alpha_r = 1 alpha_i = 1e1 alpha_grad = 100.0 grad_loss = True desired_resolution = 32768 change_patch_size_lidar = [2, 8] num_steps = 768 upsample_steps = 64 bound = 1 scale = 0.01150158050828236 offset = [1150.2651429096413, 3997.2217130085182, 109.3943550832148] ================================================ FILE: configs/kitti360_1728.txt ================================================ sequence_id = 1728 alpha_d = 1000.0 alpha_r = 1 alpha_i = 1e1 alpha_grad = 100.0 grad_loss = True desired_resolution = 32768 change_patch_size_lidar = [2, 8] num_steps = 768 upsample_steps = 64 bound = 1 scale = 0.01235117157331213 offset = [1036.6078389848537, 3863.5989919125104, 111.73904860790459] ================================================ FILE: configs/kitti360_1908.txt ================================================ sequence_id = 1908 alpha_d = 1000.0 alpha_r = 1 alpha_i = 1e1 alpha_grad = 100.0 grad_loss = True desired_resolution = 32768 change_patch_size_lidar = [2, 8] num_steps = 768 upsample_steps = 64 bound = 1 scale = 0.010784853507573345 offset = [1069.988979297527, 3765.8807850056446, 113.0212841477088] ================================================ FILE: configs/kitti360_3353.txt ================================================ sequence_id = 3353 alpha_d = 1000.0 alpha_r = 1 alpha_i = 1e1 alpha_grad = 100.0 grad_loss = True desired_resolution = 32768 change_patch_size_lidar = [2, 8] num_steps = 768 upsample_steps = 64 bound = 1 scale = 0.00951045294058913 offset = [1364.3592435499154, 3818.620913210761, 108.69906656243805] 
================================================
FILE: configs/nerf_mvl.txt
================================================
path = data/nerf_mvl
dataloader = nerf_mvl
sequence_id = car
alpha_d = 1000.0
alpha_r = 1
alpha_i = 1
alpha_grad = 100.0
intensity_inv_scale=255.0
grad_loss = False
desired_resolution = 32768
eval_interval=5
num_steps = 768
upsample_steps = 64
bound = 1
scale = 0.005
offset = [973.0483450856506, 648.3910430331337, -8.442160936778045]

================================================
FILE: extern/chamfer3D/chamfer3D.cu
================================================
// NOTE(review): the header names inside <...> were stripped when this dump
// was extracted (presumably <stdio.h>, <ATen/ATen.h>, <cuda.h>,
// <cuda_runtime.h>, <vector>) -- confirm against upstream before building.
#include
#include
#include
#include
#include

// For every point j of cloud `xyz` (laid out as b x n x 3) find the nearest
// point of cloud `xyz2` (b x m x 3). Writes the squared distance to
// result[i * n + j] and the index (into xyz2) of that neighbour to
// result_i[i * n + j]. Points of xyz2 are staged through shared memory in
// tiles of `batch` points; the inner scan is unrolled by 4.
__global__ void NmDistanceKernel(int b, int n, const float *xyz, int m,
                                 const float *xyz2, float *result,
                                 int *result_i) {
  const int batch = 512;
  __shared__ float buf[batch * 3];
  for (int i = blockIdx.x; i < b; i += gridDim.x) {
    for (int k2 = 0; k2 < m; k2 += batch) {
      // Number of xyz2 points in this tile (the last tile may be partial).
      int end_k = min(m, k2 + batch) - k2;
      for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x) {
        buf[j] = xyz2[(i * m + k2) * 3 + j];
      }
      __syncthreads();
      for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n;
           j += blockDim.x * gridDim.y) {
        float x1 = xyz[(i * n + j) * 3 + 0];
        float y1 = xyz[(i * n + j) * 3 + 1];
        float z1 = xyz[(i * n + j) * 3 + 2];
        int best_i = 0;
        float best = 0;
        // Largest multiple of 4 <= end_k; the remainder is handled below.
        int end_ka = end_k - (end_k & 3);
        if (end_ka == batch) {
          // Full tile: constant trip count lets the compiler unroll further.
          for (int k = 0; k < batch; k += 4) {
            {
              float x2 = buf[k * 3 + 0] - x1;
              float y2 = buf[k * 3 + 1] - y1;
              float z2 = buf[k * 3 + 2] - z1;
              float d = x2 * x2 + y2 * y2 + z2 * z2;
              if (k == 0 || d < best) {
                best = d;
                best_i = k + k2;
              }
            }
            {
              float x2 = buf[k * 3 + 3] - x1;
              float y2 = buf[k * 3 + 4] - y1;
              float z2 = buf[k * 3 + 5] - z1;
              float d = x2 * x2 + y2 * y2 + z2 * z2;
              if (d < best) {
                best = d;
                best_i = k + k2 + 1;
              }
            }
            {
              float x2 = buf[k * 3 + 6] - x1;
              float y2 = buf[k * 3 + 7] - y1;
              float z2 = buf[k * 3 + 8] - z1;
              float d = x2 * x2 + y2 * y2 + z2 * z2;
              if (d < best) {
                best = d;
                best_i = k + k2 + 2;
              }
            }
            {
              float x2 = buf[k * 3 + 9] - x1;
              float y2 = buf[k * 3 + 10] - y1;
              float z2 = buf[k * 3 + 11] - z1;
              float d = x2 * x2 + y2 * y2 + z2 * z2;
              if (d < best) {
                best = d;
                best_i = k + k2 + 3;
              }
            }
          }
        } else {
          for (int k = 0; k < end_ka; k += 4) {
            {
              float x2 = buf[k * 3 + 0] - x1;
              float y2 = buf[k * 3 + 1] - y1;
              float z2 = buf[k * 3 + 2] - z1;
              float d = x2 * x2 + y2 * y2 + z2 * z2;
              if (k == 0 || d < best) {
                best = d;
                best_i = k + k2;
              }
            }
            {
              float x2 = buf[k * 3 + 3] - x1;
              float y2 = buf[k * 3 + 4] - y1;
              float z2 = buf[k * 3 + 5] - z1;
              float d = x2 * x2 + y2 * y2 + z2 * z2;
              if (d < best) {
                best = d;
                best_i = k + k2 + 1;
              }
            }
            {
              float x2 = buf[k * 3 + 6] - x1;
              float y2 = buf[k * 3 + 7] - y1;
              float z2 = buf[k * 3 + 8] - z1;
              float d = x2 * x2 + y2 * y2 + z2 * z2;
              if (d < best) {
                best = d;
                best_i = k + k2 + 2;
              }
            }
            {
              float x2 = buf[k * 3 + 9] - x1;
              float y2 = buf[k * 3 + 10] - y1;
              float z2 = buf[k * 3 + 11] - z1;
              float d = x2 * x2 + y2 * y2 + z2 * z2;
              if (d < best) {
                best = d;
                best_i = k + k2 + 3;
              }
            }
          }
        }
        // Remainder (end_k % 4 points) of the tile.
        for (int k = end_ka; k < end_k; k++) {
          float x2 = buf[k * 3 + 0] - x1;
          float y2 = buf[k * 3 + 1] - y1;
          float z2 = buf[k * 3 + 2] - z1;
          float d = x2 * x2 + y2 * y2 + z2 * z2;
          if (k == 0 || d < best) {
            best = d;
            best_i = k + k2;
          }
        }
        // Merge this tile's best with the best of previous tiles.
        if (k2 == 0 || result[(i * n + j)] > best) {
          result[(i * n + j)] = best;
          result_i[(i * n + j)] = best_i;
        }
      }
      // All threads must finish reading buf before the next tile overwrites it.
      __syncthreads();
    }
  }
}

// Nearest-neighbour search in both directions (xyz1 -> xyz2 into dist1/idx1,
// xyz2 -> xyz1 into dist2/idx2). Returns 1 on success, 0 if a CUDA error was
// reported (the error is only printed, not raised).
int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1,
                         at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
  const auto batch_size = xyz1.size(0);
  const auto n = xyz1.size(1);  // num_points point cloud A
  const auto m = xyz2.size(1);  // num_points point cloud B
  // NOTE(review): the <<<grid, block>>> launch configuration and the
  // data<T>() template arguments were stripped by extraction -- restore them
  // from upstream before building.
  NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(),
                       dist1.data(), idx1.data());
  NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(),
                       dist2.data(), idx2.data());

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
    return 0;
  }
  return 1;
}

// Back-propagates grad_dist1 (d loss / d squared distance) onto both clouds.
// atomicAdd is required because several points may share the same nearest
// neighbour j2 in xyz2.
__global__ void NmDistanceGradKernel(int b, int n, const float *xyz1, int m,
                                     const float *xyz2,
                                     const float *grad_dist1, const int *idx1,
                                     float *grad_xyz1, float *grad_xyz2) {
  for (int i = blockIdx.x; i < b; i += gridDim.x) {
    for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n;
         j += blockDim.x * gridDim.y) {
      float x1 = xyz1[(i * n + j) * 3 + 0];
      float y1 = xyz1[(i * n + j) * 3 + 1];
      float z1 = xyz1[(i * n + j) * 3 + 2];
      int j2 = idx1[i * n + j];
      float x2 = xyz2[(i * m + j2) * 3 + 0];
      float y2 = xyz2[(i * m + j2) * 3 + 1];
      float z2 = xyz2[(i * m + j2) * 3 + 2];
      // d/dp ||p - q||^2 = 2 (p - q)
      float g = grad_dist1[i * n + j] * 2;
      atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2));
      atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2));
      atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2));
      atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2)));
      atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2)));
      atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2)));
    }
  }
}

// Gradient of both distance directions w.r.t. both clouds. The gradient
// buffers are accumulated into, so the caller must zero-initialise them
// (dist_chamfer_3D.py allocates them with torch.zeros). Returns 1 on success,
// 0 if a CUDA error was reported.
int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2,
                          at::Tensor gradxyz1, at::Tensor gradxyz2,
                          at::Tensor graddist1, at::Tensor graddist2,
                          at::Tensor idx1, at::Tensor idx2) {
  const auto batch_size = xyz1.size(0);
  const auto n = xyz1.size(1);  // num_points point cloud A
  const auto m = xyz2.size(1);  // num_points point cloud B
  // NOTE(review): launch configuration / data<T>() template arguments were
  // stripped by extraction -- restore them from upstream before building.
  NmDistanceGradKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(),
                           graddist1.data(), idx1.data(), gradxyz1.data(),
                           gradxyz2.data());
  NmDistanceGradKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(),
                           graddist2.data(), idx2.data(), gradxyz2.data(),
                           gradxyz1.data());

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
    return 0;
  }
  return 1;
}

================================================
FILE: extern/chamfer3D/chamfer_cuda.cpp
================================================
// NOTE(review): header names inside <...> were stripped by extraction
// (presumably <torch/torch.h> and <vector>) -- confirm against upstream.
#include
#include

int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1,
                         at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);

int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2,
                          at::Tensor gradxyz1, at::Tensor gradxyz2,
                          at::Tensor graddist1, at::Tensor graddist2,
                          at::Tensor idx1, at::Tensor idx2);

// Thin host-side wrappers exposed to Python.
int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1,
                    at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
  return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
}

int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1,
                     at::Tensor gradxyz2, at::Tensor graddist1,
                     at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
  return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1,
                               graddist2, idx1, idx2);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
  m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
}

================================================
FILE: extern/chamfer3D/dist_chamfer_3D.py
================================================
from torch import nn
from torch.autograd import Function
import torch
import importlib.util
import os
import sys
from pathlib import Path

script_dir = Path(__file__).parent.absolute()
object_dir = script_dir.parent / "tmp"
sys.path.append(str(object_dir))

# Use a pre-compiled `chamfer_3D` extension if one is importable; otherwise
# JIT-compile it. `importlib.find_loader` was deprecated and removed in
# Python 3.12, so probe with `importlib.util.find_spec` instead.
chamfer_found = importlib.util.find_spec("chamfer_3D") is not None
if not chamfer_found:
    ## Cool trick from https://github.com/chrdiller
    # Build into the directory appended to sys.path above, so a later run can
    # presumably pick the compiled module up via the find_spec probe.
    build_path = str(object_dir)
    os.makedirs(build_path, exist_ok=True)

    print(f"Jitting Chamfer 3D to {build_path}")

    from torch.utils.cpp_extension import load

    chamfer_3D = load(
        name="chamfer_3D",
        sources=[
            str(script_dir / "chamfer_cuda.cpp"),
            str(script_dir / "chamfer3D.cu"),
        ],
        build_directory=build_path,
    )
    print(f"Loaded jitted library {chamfer_3D.__file__}")
else:
    import chamfer_3D

    print(f"Loaded pre-compiled library {chamfer_3D.__file__}")


# Chamfer's distance module @thibaultgroueix
# GPU tensors only
class chamfer_3DFunction(Function):
    """Autograd wrapper around the chamfer_3D CUDA extension.

    forward(xyz1 [B, N, 3], xyz2 [B, M, 3]) returns
    (dist1 [B, N], dist2 [B, M], idx1 [B, N] int32, idx2 [B, M] int32), where
    dist* are squared nearest-neighbour distances and idx* the neighbour
    indices into the other cloud.
    """

    @staticmethod
    def forward(ctx, xyz1, xyz2):
        batchsize, n, dim = xyz1.size()
        assert (
            dim == 3
        ), "Wrong last dimension for the chamfer distance 's input! Check with .size()"
        _, m, dim = xyz2.size()
        assert (
            dim == 3
        ), "Wrong last dimension for the chamfer distance 's input! Check with .size()"
        device = xyz1.device

        # Allocate outputs directly on the target device (no CPU round trip).
        dist1 = torch.zeros(batchsize, n, device=device)
        dist2 = torch.zeros(batchsize, m, device=device)
        idx1 = torch.zeros(batchsize, n, dtype=torch.int32, device=device)
        idx2 = torch.zeros(batchsize, m, dtype=torch.int32, device=device)

        torch.cuda.set_device(device)

        chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
        return dist1, dist2, idx1, idx2

    @staticmethod
    def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
        graddist1 = graddist1.contiguous()
        graddist2 = graddist2.contiguous()
        device = graddist1.device

        # The CUDA kernel accumulates with atomicAdd, so these must start at 0.
        gradxyz1 = torch.zeros(xyz1.size(), device=device)
        gradxyz2 = torch.zeros(xyz2.size(), device=device)
        chamfer_3D.backward(
            xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
        )
        return gradxyz1, gradxyz2


class chamfer_3DDist(nn.Module):
    """nn.Module front-end for chamfer_3DFunction (makes inputs contiguous)."""

    def __init__(self):
        super().__init__()

    def forward(self, input1, input2):
        input1 = input1.contiguous()
        input2 = input2.contiguous()
        return chamfer_3DFunction.apply(input1, input2)

================================================
FILE: extern/chamfer3D/setup.py
================================================
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="chamfer_3D",
    ext_modules=[
        CUDAExtension(
            "chamfer_3D",
            [
                "/".join(__file__.split("/")[:-1] + ["chamfer_cuda.cpp"]),
                "/".join(__file__.split("/")[:-1] + ["chamfer3D.cu"]),
            ],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)

================================================
FILE: extern/fscore.py
================================================
import torch


def fscore(dist1, dist2, threshold=0.001):
    """
    Calculates the F-score between two point clouds with the corresponding threshold value.

    :param dist1: Batch, N-Points
    :param dist2: Batch, N-Points
    :param threshold: float, compared directly against dist1 / dist2
    :return: fscore, precision, recall (each of shape [Batch]); the F-score
        is set to 0 where precision + recall == 0.
    """
    # NB : In this repo, dist1 and dist2 are squared pointcloud euclidean
    # distances, so you should adapt the threshold accordingly.
    precision_1 = torch.mean((dist1 < threshold).float(), dim=1)
    precision_2 = torch.mean((dist2 < threshold).float(), dim=1)
    fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
    # 0 / 0 -> NaN when neither direction has any match.
    fscore[torch.isnan(fscore)] = 0
    return fscore, precision_1, precision_2

================================================
FILE: lidarmvl/readme.md
================================================
# LiDAR-MVL

![dataset_vis.png](../assets/dataset_vis.png)

| Sensor | Details (Sensor location: F: front. T: top.) |
| ----------------------------- | ------------------------------------------------------------ |
| LiDAR LiDAR-F | Spinning, 64 beams, 10Hz capture frequency, 360° horizontal FOV, 0.6° horizontal resolution, -52.1° to +52.1° vertical FOV, ≤60m range, ±3cm accuracy.
| | LiDAR-T | Spinning, 64 beams, 20Hz capture frequency, 360° horizontal FOV, 0.4° horizontal resolution, -25° to +15° vertical FOV, ≤200m range, ±2cm accuracy. | We establish an object-centric **m**ulti-**v**iew **L**iDAR dataset, which we dub the **NeRF-MVL** dataset, containing carefully calibrated sensor poses, acquired from multi-LiDAR sensor data from real autonomous vehicles. It contains more than **76k frames** covering two types of collecting vehicles, three LiDAR settings, two collecting paths, and nine object categories. ## Dataset Format ```bash nerf_mvl ├── nerf_mvl_76k │   ├── vehicle_type1 │   │   ├── LiDAR │   │   │   └── {class_name} │   │   │ ├── l │   │   │ └── s │   │   │ ├── {frame_id}.npy │   │   │ └── lidar2world.txt │   │   ├── LiDAR_F │   │   └── LiDAR_T │   └── vehicle_type2 │   ├── LiDAR │   ├── LiDAR_F │   └── LiDAR_T │ └── nerf_mvl_7k    └── {class_name} ├── {frame_id}.npy └── lidar2world.txt Note: {class_name}: {bollard, pedestrian, plant, traffic_cone, water_safety_barrier, car, pier, tire, warning_sign} {frame_id}.npy: local point clouds, (N,4) lidar2world.txt: the lidar to world matrix, (M, 16) l/s: large/small collecting paths ``` For fast validation, we extract a pocket version of the dataset with only 7.3k frames covering the nine categories, called **nerf_mvl_7k**. For all point cloud frames, we crop out the region of interest, i.e., the object. The raw data will also be released to the community soon.
## Citation If you find our dataset helps, please consider citing: ``` @article{tao2023lidar, title={LiDAR-NeRF: Novel LiDAR View Synthesis via Neural Radiance Fields}, author={Tao, Tang and Gao, Longfei and Wang, Guangrun and Lao, Yixing and Chen, Peng and Zhao hengshuang and Hao, Dayang and Liang, Xiaodan and Salzmann, Mathieu and Yu, Kaicheng}, journal={arXiv preprint arXiv:2304.10406}, year={2023} } ``` ================================================ FILE: lidarnerf/__init__.py ================================================ __version__ = "0.1.0" ================================================ FILE: lidarnerf/activation.py ================================================ import torch from torch.autograd import Function from torch.cuda.amp import custom_bwd, custom_fwd class _trunc_exp(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) # cast to float32 def forward(ctx, x): ctx.save_for_backward(x) return torch.exp(x) @staticmethod @custom_bwd def backward(ctx, g): x = ctx.saved_tensors[0] return g * torch.exp(x.clamp(-15, 15)) trunc_exp = _trunc_exp.apply ================================================ FILE: lidarnerf/convert.py ================================================ import numpy as np def lidar_to_pano_with_intensities_with_bbox_mask( local_points_with_intensities: np.ndarray, lidar_H: int, lidar_W: int, lidar_K: int, bbox_local: np.ndarray, max_depth=80, max_intensity=255.0, ): """ Convert lidar frame to pano frame with intensities with bbox_mask. Lidar points are in local coordinates. Args: local_points: (N, 4), float32, in lidar frame, with intensities. lidar_H: pano height. lidar_W: pano width. lidar_K: lidar intrinsics. bbox_local: (8x4), world bbox in local. max_depth: max depth in meters. max_intensity: max intensity. Return: pano: (H, W), float32. intensities: (H, W), float32. """ # Un pack. 
local_points = local_points_with_intensities[:, :3] local_point_intensities = local_points_with_intensities[:, 3] fov_up, fov = lidar_K fov_down = fov - fov_up # Compute dists to lidar center. dists = np.linalg.norm(local_points, axis=1) # Fill pano and intensities. pano = np.zeros((lidar_H, lidar_W)) intensities = np.zeros((lidar_H, lidar_W)) # bbox mask pano[:, :] = -1 r_min, r_max, c_min, c_max = 1e5, -1, 1e5, -1 for bbox_local_point in bbox_local: x, y, z, _ = bbox_local_point beta = np.pi - np.arctan2(y, x) alpha = np.arctan2(z, np.sqrt(x**2 + y**2)) + fov_down / 180 * np.pi c = int(round(beta / (2 * np.pi / lidar_W))) r = int(round(lidar_H - alpha / (fov / 180 * np.pi / lidar_H))) # Check out-of-bounds. if r >= lidar_H or r < 0 or c >= lidar_W or c < 0: continue else: r_min, r_max, c_min, c_max = ( min(r_min, r), max(r_max, r), min(c_min, c), max(c_max, c), ) pano[r_min:r_max, c_min:c_max] = 0 # Fill pano and intensities. for local_points, dist, local_point_intensity in zip( local_points, dists, local_point_intensities, ): # Check max depth. if dist >= max_depth: continue x, y, z = local_points beta = np.pi - np.arctan2(y, x) alpha = np.arctan2(z, np.sqrt(x**2 + y**2)) + fov_down / 180 * np.pi c = int(round(beta / (2 * np.pi / lidar_W))) r = int(round(lidar_H - alpha / (fov / 180 * np.pi / lidar_H))) # Check out-of-bounds. if r >= lidar_H or r < 0 or c >= lidar_W or c < 0: continue # Set to min dist if not set. if pano[r, c] == 0.0: pano[r, c] = dist intensities[r, c] = local_point_intensity / max_intensity elif pano[r, c] > dist: pano[r, c] = dist intensities[r, c] = local_point_intensity / max_intensity return pano, intensities def lidar_to_pano_with_intensities( local_points_with_intensities: np.ndarray, lidar_H: int, lidar_W: int, lidar_K: int, max_depth=80, ): """ Convert lidar frame to pano frame with intensities. Lidar points are in local coordinates. Args: local_points: (N, 4), float32, in lidar frame, with intensities. lidar_H: pano height. 
lidar_W: pano width. lidar_K: lidar intrinsics. max_depth: max depth in meters. Return: pano: (H, W), float32. intensities: (H, W), float32. """ # Un pack. local_points = local_points_with_intensities[:, :3] local_point_intensities = local_points_with_intensities[:, 3] fov_up, fov = lidar_K fov_down = fov - fov_up # Compute dists to lidar center. dists = np.linalg.norm(local_points, axis=1) # Fill pano and intensities. pano = np.zeros((lidar_H, lidar_W)) intensities = np.zeros((lidar_H, lidar_W)) for local_points, dist, local_point_intensity in zip( local_points, dists, local_point_intensities, ): # Check max depth. if dist >= max_depth: continue x, y, z = local_points beta = np.pi - np.arctan2(y, x) alpha = np.arctan2(z, np.sqrt(x**2 + y**2)) + fov_down / 180 * np.pi c = int(round(beta / (2 * np.pi / lidar_W))) r = int(round(lidar_H - alpha / (fov / 180 * np.pi / lidar_H))) # Check out-of-bounds. if r >= lidar_H or r < 0 or c >= lidar_W or c < 0: continue # Set to min dist if not set. if pano[r, c] == 0.0: pano[r, c] = dist intensities[r, c] = local_point_intensity elif pano[r, c] > dist: pano[r, c] = dist intensities[r, c] = local_point_intensity return pano, intensities def lidar_to_pano( local_points: np.ndarray, lidar_H: int, lidar_W: int, lidar_K: int, max_dpeth=80 ): """ Convert lidar frame to pano frame. Lidar points are in local coordinates. Args: local_points: (N, 3), float32, in lidar frame. lidar_H: pano height. lidar_W: pano width. lidar_K: lidar intrinsics. max_depth: max depth in meters. Return: pano: (H, W), float32. """ # (N, 3) -> (N, 4), filled with zeros. 
local_points_with_intensities = np.concatenate( [local_points, np.zeros((local_points.shape[0], 1))], axis=1 ) pano, _ = lidar_to_pano_with_intensities( local_points_with_intensities=local_points_with_intensities, lidar_H=lidar_H, lidar_W=lidar_W, lidar_K=lidar_K, max_dpeth=max_dpeth, ) return pano def pano_to_lidar_with_intensities(pano: np.ndarray, intensities, lidar_K): """ Args: pano: (H, W), float32. intensities: (H, W), float32. lidar_K: lidar intrinsics (fov_up, fov) Return: local_points_with_intensities: (N, 4), float32, in lidar frame. """ fov_up, fov = lidar_K H, W = pano.shape i, j = np.meshgrid( np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing="xy" ) beta = -(i - W / 2) / W * 2 * np.pi alpha = (fov_up - j / H * fov) / 180 * np.pi dirs = np.stack( [ np.cos(alpha) * np.cos(beta), np.cos(alpha) * np.sin(beta), np.sin(alpha), ], -1, ) local_points = dirs * pano.reshape(H, W, 1) # local_points: (H, W, 3) # intensities : (H, W) # local_points_with_intensities: (H, W, 4) local_points_with_intensities = np.concatenate( [local_points, intensities.reshape(H, W, 1)], axis=2 ) # Filter empty points. idx = np.where(pano != 0.0) local_points_with_intensities = local_points_with_intensities[idx] return local_points_with_intensities def pano_to_lidar(pano, lidar_K): """ Args: pano: (H, W), float32. lidar_K: lidar intrinsics (fov_up, fov) Return: local_points: (N, 3), float32, in lidar frame. """ local_points_with_intensities = pano_to_lidar_with_intensities( pano=pano, intensities=np.zeros_like(pano), lidar_K=lidar_K, ) return local_points_with_intensities[:, :3] def lidar_to_pano_with_intensities_fpa( local_points_with_intensities: np.ndarray, lidar_H: int, lidar_W: int, lidar_K: int, max_depth=80, z_buffer_len=10, ): """ Convert lidar frame to pano frame with intensities with bbox_mask. Lidar points are in local coordinates. Args: local_points: (N, 4), float32, in lidar frame, with intensities. lidar_H: pano height. lidar_W: pano width. 
lidar_K: lidar intrinsics. max_depth: max depth in meters. z_buffer_len: length of the z_buffer. Return: rangeview image: (H, W, 3), float32. """ # Un pack. local_points = local_points_with_intensities[:, :3] local_point_intensities = local_points_with_intensities[:, 3] fov_up, fov = lidar_K fov_down = fov - fov_up # Compute dists to lidar center. dists = np.linalg.norm(local_points, axis=1) # Fill pano and intensities. range_view = np.zeros((lidar_H, lidar_W, 3, z_buffer_len + 1)) for local_point, dist, local_point_intensity in zip( local_points, dists, local_point_intensities, ): # Check max depth. if dist >= max_depth: continue x, y, z = local_point beta = np.pi - np.arctan2(y, x) alpha = np.arctan2(z, np.sqrt(x**2 + y**2)) + fov_down / 180 * np.pi c = int(round(beta / (2 * np.pi / lidar_W))) r = int(round(lidar_H - alpha / (fov / 180 * np.pi / lidar_H))) if r >= lidar_H or r < 0 or c >= lidar_W or c < 0: continue position = range_view[r, c, 2, 0] + 1 if position > z_buffer_len: depth_z_buffer = list(range_view[r, c, 2][1:]) + [dist] intensity_z_buffer = list(range_view[r, c, 1][1:]) + [local_point_intensity] position = position - 1 sort_index = np.argsort(depth_z_buffer) depth_z_buffer = np.insert( np.array(depth_z_buffer)[sort_index][:z_buffer_len], 0, position ) intensity_z_buffer = np.insert( np.array(intensity_z_buffer)[sort_index][:z_buffer_len], 0, position ) range_view[r, c, 2] = depth_z_buffer range_view[r, c, 1] = intensity_z_buffer else: range_view[r, c, 2, int(position)] = dist range_view[r, c, 1, int(position)] = local_point_intensity range_view[r, c, 2, 0] = position range_view = parse_z_buffer(range_view, lidar_H, lidar_W) return range_view[:, :, 2], range_view[:, :, 1] def parse_z_buffer(novel_pano, lidar_H, lidar_W, threshold=0.2): range_view = np.zeros((lidar_H, lidar_W, 3)) for i in range(lidar_H): for j in range(lidar_W): range_pixel = novel_pano[i, j, 2] intensity_pixel = novel_pano[i, j, 1] z_buffer_num = int(range_pixel[0]) if z_buffer_num 
== 0: continue if z_buffer_num == 1: range_view[i][j][2] = range_pixel[1] range_view[i][j][1] = intensity_pixel[1] continue depth_z_buffer = range_pixel[1:z_buffer_num] cloest_points = min(depth_z_buffer) index = depth_z_buffer <= (cloest_points + threshold) final_depth_z_buffer = np.array(depth_z_buffer)[index] final_dis = np.average( final_depth_z_buffer, weights=1 / final_depth_z_buffer ) range_view[i][j][2] = final_dis intensity_z_buffer = intensity_pixel[1:z_buffer_num] final_intensity_z_buffer = np.array(intensity_z_buffer)[index] final_intensity = np.average( final_intensity_z_buffer, weights=1 / final_depth_z_buffer ) range_view[i][j][1] = final_intensity return range_view ================================================ FILE: lidarnerf/dataset/base_dataset.py ================================================ import numpy as np import torch import trimesh from packaging import version as pver from dataclasses import dataclass def custom_meshgrid(*args): if pver.parse(torch.__version__) < pver.parse("1.10"): return torch.meshgrid(*args) else: return torch.meshgrid(*args, indexing="ij") @torch.cuda.amp.autocast(enabled=False) def get_lidar_rays(poses, intrinsics, H, W, N=-1, patch_size=1): """ Get lidar rays. 
Args: poses: [B, 4, 4], cam2world intrinsics: [2] H, W, N: int Returns: rays_o, rays_d: [B, N, 3] inds: [B, N] """ device = poses.device B = poses.shape[0] i, j = custom_meshgrid( torch.linspace(0, W - 1, W, device=device), torch.linspace(0, H - 1, H, device=device), ) # float # i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5 # j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5 i = i.t().reshape([1, H * W]).expand([B, H * W]) j = j.t().reshape([1, H * W]).expand([B, H * W]) results = {} if N > 0: N = min(N, H * W) if isinstance(patch_size, int): patch_size_x, patch_size_y = patch_size, patch_size elif len(patch_size) == 1: patch_size_x, patch_size_y = patch_size[0], patch_size[0] else: patch_size_x, patch_size_y = patch_size if patch_size_x > 0: # random sample left-top cores. # NOTE: this impl will lead to less sampling on the image corner # pixels... but I don't have other ideas. num_patch = N // (patch_size_x * patch_size_y) inds_x = torch.randint(0, H - patch_size_x, size=[num_patch], device=device) inds_y = torch.randint(0, W - patch_size_y, size=[num_patch], device=device) inds = torch.stack([inds_x, inds_y], dim=-1) # [np, 2] # create meshgrid for each patch pi, pj = custom_meshgrid( torch.arange(patch_size_x, device=device), torch.arange(patch_size_y, device=device), ) offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1) # [p^2, 2] inds = inds.unsqueeze(1) + offsets.unsqueeze(0) # [np, p^2, 2] inds = inds.view(-1, 2) # [N, 2] inds = inds[:, 0] * W + inds[:, 1] # [N], flatten inds = inds.expand([B, N]) else: inds = torch.randint(0, H * W, size=[N], device=device) # may duplicate inds = inds.expand([B, N]) i = torch.gather(i, -1, inds) j = torch.gather(j, -1, inds) results["inds"] = inds else: inds = torch.arange(H * W, device=device).expand([B, H * W]) results["inds"] = inds fov_up, fov = intrinsics beta = -(i - W / 2) / W * 2 * np.pi alpha = (fov_up - j / H * fov) / 180 * np.pi directions = torch.stack( [ torch.cos(alpha) * 
torch.cos(beta), torch.cos(alpha) * torch.sin(beta), torch.sin(alpha), ], -1, ) # directions = directions / torch.norm(directions, dim=-1, keepdim=True) rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3) rays_o = poses[..., :3, 3] # [B, 3] rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3] results["rays_o"] = rays_o results["rays_d"] = rays_d return results @torch.cuda.amp.autocast(enabled=False) def get_rays(poses, intrinsics, H, W, N=-1, patch_size=1): """get rays Args: poses: [B, 4, 4], cam2world intrinsics: [4] H, W, N: int Returns: rays_o, rays_d: [B, N, 3] inds: [B, N] """ device = poses.device B = poses.shape[0] fx, fy, cx, cy = intrinsics i, j = custom_meshgrid( torch.linspace(0, W - 1, W, device=device), torch.linspace(0, H - 1, H, device=device), ) # float i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5 j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5 results = {} if N > 0: N = min(N, H * W) if patch_size > 1: # random sample left-top cores. # NOTE: this impl will lead to less sampling on the image corner # pixels... but I don't have other ideas. 
num_patch = N // (patch_size**2) # randint's high bound is exclusive: +1 so the last valid top-left corner can be sampled (and H == patch_size does not raise) inds_x = torch.randint(0, H - patch_size + 1, size=[num_patch], device=device) inds_y = torch.randint(0, W - patch_size + 1, size=[num_patch], device=device) inds = torch.stack([inds_x, inds_y], dim=-1) # [np, 2] # create meshgrid for each patch pi, pj = custom_meshgrid( torch.arange(patch_size, device=device), torch.arange(patch_size, device=device), ) offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1) # [p^2, 2] inds = inds.unsqueeze(1) + offsets.unsqueeze(0) # [np, p^2, 2] inds = inds.view(-1, 2) # [np * p^2, 2] inds = inds[:, 0] * W + inds[:, 1] # [np * p^2], flatten # np * p^2 is smaller than N when N is not a multiple of p^2, so keep the actual length (-1) instead of forcing N inds = inds.expand([B, -1]) else: inds = torch.randint(0, H * W, size=[N], device=device) # may duplicate inds = inds.expand([B, N]) i = torch.gather(i, -1, inds) j = torch.gather(j, -1, inds) results["inds"] = inds else: inds = torch.arange(H * W, device=device).expand([B, H * W]) results["inds"] = inds # expose full-image indices too, matching get_lidar_rays zs = torch.ones_like(i) xs = (i - cx) / fx * zs ys = (j - cy) / fy * zs directions = torch.stack((xs, ys, zs), dim=-1) directions = directions / torch.norm(directions, dim=-1, keepdim=True) rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3) rays_o = poses[..., :3, 3] # [B, 3] rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3] results["rays_o"] = rays_o results["rays_d"] = rays_d return results # ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50 def nerf_matrix_to_ngp(pose, scale=0.33, offset=(0, 0, 0)): # offset is only indexed, so an immutable tuple default avoids the shared-mutable-default pitfall # for the fox dataset, 0.33 scales camera radius to ~ 2 new_pose = np.array( [ [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]], [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]], [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]], [0, 0, 0, 1], ], dtype=np.float32, ) return new_pose def visualize_poses(poses, size=0.1): # poses: [B, 4, 4] axes = trimesh.creation.axis(axis_length=4) box =
trimesh.primitives.Box(extents=(2, 2, 2)).as_outline() box.colors = np.array([[128, 128, 128]] * len(box.entities)) objects = [axes, box] for pose in poses: # a camera is visualized with 8 line segments. pos = pose[:3, 3] a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2] c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2] dir = (a + b + c + d) / 4 - pos dir = dir / (np.linalg.norm(dir) + 1e-8) o = pos + dir * 3 segs = np.array( [ [pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o], ] ) segs = trimesh.load_path(segs) objects.append(segs) trimesh.Scene(objects).show() @dataclass class BaseDataset: pass ================================================ FILE: lidarnerf/dataset/kitti360_dataset.py ================================================ import json import os import numpy as np import torch import tqdm from torch.utils.data import DataLoader from dataclasses import dataclass, field from lidarnerf.dataset.base_dataset import get_lidar_rays, BaseDataset @dataclass class KITTI360Dataset(BaseDataset): device: str = "cpu" split: str = "train" # train, val, test root_path: str = "data/kitti360" sequence_id: str = "1908" preload: bool = True # preload data into GPU scale: float = ( 1 # camera radius scale to make sure camera are inside the bounding box. ) offset: list = field(default_factory=list) # offset # bound = opt.bound # bounding box half length, also used as the radius to random sample poses. fp16: bool = True # if preload, load into fp16. patch_size: int = 1 # size of the image to extract from the scene. patch_size_lidar: int = 1 # size of the image to extract from the Lidar. 
enable_lidar: bool = True num_rays: int = 4096 num_rays_lidar: int = 4096 def __post_init__(self): if self.sequence_id == "1538": print("Using sqequence 1538-1601") elif self.sequence_id == "1728": print("Using sqequence 1728-1791") elif self.sequence_id == "1908": print("Using sqequence 1908-1971") elif self.sequence_id == "3353": print("Using sqequence 3353-3416") else: raise ValueError(f"Invalid sequence id: {sequence_id}") self.training = self.split in ["train", "all", "trainval"] self.num_rays = self.num_rays if self.training else -1 self.num_rays_lidar = self.num_rays_lidar if self.training else -1 # load nerf-compatible format data. with open( os.path.join( self.root_path, f"transforms_{self.sequence_id}_{self.split}.json" ), "r", ) as f: transform = json.load(f) # load image size if "h" in transform and "w" in transform: self.H = int(transform["h"]) self.W = int(transform["w"]) else: # we have to actually read an image to get H and W later. self.H = self.W = None if "h_lidar" in transform and "w_lidar" in transform: self.H_lidar = int(transform["h_lidar"]) self.W_lidar = int(transform["w_lidar"]) # read images frames = transform["frames"] # frames = sorted(frames, key=lambda d: d['file_path']) # why do I sort... 
self.poses_lidar = [] self.images_lidar = [] for f in tqdm.tqdm(frames, desc=f"Loading {self.split} data"): pose_lidar = np.array(f["lidar2world"], dtype=np.float32) # [4, 4] f_lidar_path = os.path.join(self.root_path, f["lidar_file_path"]) # channel1 None, channel2 intensity , channel3 depth pc = np.load(f_lidar_path) ray_drop = np.where(pc.reshape(-1, 3)[:, 2] == 0.0, 0.0, 1.0).reshape( self.H_lidar, self.W_lidar, 1 ) image_lidar = np.concatenate( [ray_drop, pc[:, :, 1, None], pc[:, :, 2, None] * self.scale], axis=-1, ) self.poses_lidar.append(pose_lidar) self.images_lidar.append(image_lidar) self.poses_lidar = np.stack(self.poses_lidar, axis=0) self.poses_lidar[:, :3, -1] = ( self.poses_lidar[:, :3, -1] - self.offset ) * self.scale self.poses_lidar = torch.from_numpy(self.poses_lidar) # [N, 4, 4] if self.images_lidar is not None: self.images_lidar = torch.from_numpy( np.stack(self.images_lidar, axis=0) ).float() # [N, H, W, C] # calculate mean radius of all camera poses # self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item() # print(f'[INFO] dataset camera poses: radius = {self.radius:.4f}, bound = {self.bound}') # [debug] uncomment to view all training poses. # visualize_poses(self.poses.numpy()) if self.preload: self.poses_lidar = self.poses_lidar.to(self.device) if self.images_lidar is not None: # TODO: linear use pow, but pow for half is only available for torch >= 1.10 ? 
if self.fp16: dtype = torch.half else: dtype = torch.float self.images_lidar = self.images_lidar.to(dtype).to(self.device) self.intrinsics_lidar = (2.0, 26.9) # fov_up, fov def collate(self, index): B = len(index) # a list of length 1 results = {} if self.enable_lidar: poses_lidar = self.poses_lidar[index].to(self.device) # [B, 4, 4] rays_lidar = get_lidar_rays( poses_lidar, self.intrinsics_lidar, self.H_lidar, self.W_lidar, self.num_rays_lidar, self.patch_size_lidar, ) results.update( { "H_lidar": self.H_lidar, "W_lidar": self.W_lidar, "rays_o_lidar": rays_lidar["rays_o"], "rays_d_lidar": rays_lidar["rays_d"], } ) if self.images_lidar is not None and self.enable_lidar: images_lidar = self.images_lidar[index].to(self.device) # [B, H, W, 3/4] if self.training: C = images_lidar.shape[-1] images_lidar = torch.gather( images_lidar.view(B, -1, C), 1, torch.stack(C * [rays_lidar["inds"]], -1), ) # [B, N, 3/4] results["images_lidar"] = images_lidar return results def dataloader(self): size = len(self.poses_lidar) loader = DataLoader( list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0, ) loader._data = self loader.has_gt = self.images_lidar is not None return loader def __len__(self): """ Returns # of frames in this dataset. """ num_frames = len(self.poses_lidar) return num_frames ================================================ FILE: lidarnerf/dataset/nerfmvl_dataset.py ================================================ import os import json import tqdm import numpy as np import torch from torch.utils.data import DataLoader from dataclasses import dataclass, field from lidarnerf.dataset.base_dataset import get_lidar_rays, BaseDataset @dataclass class NeRFMVLDataset(BaseDataset): device: str = "cpu" split: str = "train" # train, val, test root_path: str = "data/kitti360" sequence_id: str = "car" preload: bool = True # preload data into GPU scale: float = ( 1 # camera radius scale to make sure camera are inside the bounding box. 
) offset: list = field(default_factory=list) # offset # bound = opt.bound # bounding box half length, also used as the radius to random sample poses. fp16: bool = True # if preload, load into fp16. patch_size: int = 1 # size of the image to extract from the scene. patch_size_lidar: int = 1 # size of the image to extract from the Lidar. enable_lidar: bool = True num_rays: int = 4096 num_rays_lidar: int = 4096 def __post_init__(self): self.class_name = self.sequence_id self.training = self.split in ["train", "all", "trainval"] self.testing = self.split in ["test"] self.num_rays = self.num_rays if self.training else -1 self.num_rays_lidar = self.num_rays_lidar if self.training else -1 with open( os.path.join( self.root_path, f"transforms_{self.class_name}_{self.split}.json" ), "r", ) as f: transform = json.load(f) if "h_lidar" in transform and "w_lidar" in transform: self.H_lidar = int(transform["h_lidar"]) self.W_lidar = int(transform["w_lidar"]) # read images frames = transform["frames"] # frames = sorted(frames, key=lambda d: d['file_path']) # why do I sort... 
self.poses_lidar = [] self.images_lidar = [] for f in tqdm.tqdm(frames, desc=f"Loading {self.split} data"): pose_lidar = np.array(f["lidar2world"], dtype=np.float32) # [4, 4] self.poses_lidar.append(pose_lidar) if "lidar_file_path" in f.keys(): f_lidar_path = os.path.join(self.root_path, f["lidar_file_path"]) # channel1 None, channel2 intensity , channel3 depth pc = np.load(f_lidar_path)["data"] # ray_drop = np.where(pc.reshape(-1, 3)[:, 2] == 0.0, 0.0, # 1.0).reshape(self.H_lidar, self.W_lidar, 1) ray_drop = pc.reshape(-1, 3)[:, 2].copy() ray_drop[ray_drop > 0] = 1.0 ray_drop = ray_drop.reshape(self.H_lidar, self.W_lidar, 1) image_lidar = np.concatenate( [ray_drop, pc[:, :, 1, None], pc[:, :, 2, None] * self.scale], axis=-1, ) self.images_lidar.append(image_lidar) else: self.images_lidar = None dataset_bbox = np.load( os.path.join(self.root_path, "dataset_bbox_7k.npy"), allow_pickle=True ).item() self.OBB = dataset_bbox[self.class_name] self.offset = np.mean(self.OBB, axis=0) self.poses_lidar = np.stack(self.poses_lidar, axis=0) self.poses_lidar_wo_scale_offset = self.poses_lidar.copy() self.OBB_pad = np.concatenate([self.OBB, np.ones((8, 1))], axis=1) self.OBB_local = [ self.OBB_pad @ np.linalg.inv(pose_lidar_wo_scale_offset.reshape(4, 4)).T for pose_lidar_wo_scale_offset in self.poses_lidar_wo_scale_offset ] self.OBB_local = np.stack(self.OBB_local, axis=0) self.poses_lidar[:, :3, -1] = ( self.poses_lidar[:, :3, -1] - self.offset ) * self.scale self.poses_lidar = torch.from_numpy(self.poses_lidar) # [N, 4, 4] if self.images_lidar is not None: self.images_lidar = torch.from_numpy( np.stack(self.images_lidar, axis=0) ).float() # [N, H, W, C] if self.preload: self.poses_lidar = self.poses_lidar.to(self.device) if self.images_lidar is not None: # TODO: linear use pow, but pow for half is only available for torch >= 1.10 ? 
if self.fp16: dtype = torch.half else: dtype = torch.float self.images_lidar = self.images_lidar.to(dtype).to(self.device) self.intrinsics_lidar = (15, 40) # fov_up, fov def collate(self, index): B = len(index) # a list of length 1 results = {} if self.enable_lidar: poses_lidar = self.poses_lidar[index].to(self.device) # [B, 4, 4] rays_lidar = get_lidar_rays( poses_lidar, self.intrinsics_lidar, self.H_lidar, self.W_lidar, -1, self.patch_size_lidar, ) results.update( { "H_lidar": self.H_lidar, "W_lidar": self.W_lidar, "rays_o_lidar": rays_lidar["rays_o"], "rays_d_lidar": rays_lidar["rays_d"], } ) if self.testing: results["OBB_local"] = self.OBB_local[index].reshape(8, 4) if self.images_lidar is not None and self.enable_lidar: images_lidar = self.images_lidar[index].to(self.device) # [B, H, W, 3/4] if self.training: C = images_lidar.shape[-1] images_lidar = torch.gather( images_lidar.view(B, -1, C), 1, torch.stack(C * [rays_lidar["inds"]], -1), ) # [B, N, 3/4] mask = images_lidar[:, :, 0] > -1 results["images_lidar"] = images_lidar[mask].view(B, -1, C) results["rays_o_lidar"] = results["rays_o_lidar"][mask].view(B, -1, 3) results["rays_d_lidar"] = results["rays_d_lidar"][mask].view(B, -1, 3) valid_num_rays = results["rays_o_lidar"].shape[1] if valid_num_rays > self.num_rays_lidar: # mask_inds = torch.randint(0, valid_num_rays, size=[self.num_rays_lidar], device=self.device) mask_inds = torch.randperm(valid_num_rays)[: self.num_rays_lidar] results["images_lidar"] = results["images_lidar"][ :, mask_inds, : ].view(B, -1, C) results["rays_o_lidar"] = results["rays_o_lidar"][ :, mask_inds, : ].view(B, -1, 3) results["rays_d_lidar"] = results["rays_d_lidar"][ :, mask_inds, : ].view(B, -1, 3) else: results["images_lidar"] = images_lidar return results def dataloader(self): size = len(self.poses_lidar) loader = DataLoader( list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0, ) loader._data = self loader.has_gt = self.images_lidar is 
not None return loader def __len__(self): """ Returns # of frames in this dataset. """ num_frames = len(self.poses_lidar) return num_frames ================================================ FILE: lidarnerf/encoding.py ================================================ import torch import torch.nn as nn import numpy as np class FreqEncoder(nn.Module): def __init__( self, input_dim, max_freq_log2, N_freqs, log_sampling=True, include_input=True, periodic_fns=(torch.sin, torch.cos), ): super().__init__() self.input_dim = input_dim self.include_input = include_input self.periodic_fns = periodic_fns self.output_dim = 0 if self.include_input: self.output_dim += self.input_dim self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) if log_sampling: self.freq_bands = 2.0 ** torch.linspace(0.0, max_freq_log2, N_freqs) else: self.freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq_log2, N_freqs) self.freq_bands = self.freq_bands.numpy().tolist() def forward(self, input, **kwargs): out = [] if self.include_input: out.append(input) for i in range(len(self.freq_bands)): freq = self.freq_bands[i] for p_fn in self.periodic_fns: out.append(p_fn(input * freq)) out = torch.cat(out, dim=-1) return out def get_encoder( encoding, input_dim=3, multires=6, degree=4, num_levels=16, level_dim=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=2048, align_corners=False, **kwargs ): if encoding == "None": return lambda x, **kwargs: x, input_dim elif encoding == "frequency": # encoder = FreqEncoder(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True) from freqencoder import FreqEncoder encoder = FreqEncoder(input_dim=input_dim, degree=multires) elif encoding == "sphere_harmonics": from shencoder import SHEncoder encoder = SHEncoder(input_dim=input_dim, degree=degree) elif encoding == "hashgrid": from gridencoder import GridEncoder encoder = GridEncoder( input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, 
base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype="hash", align_corners=align_corners, ) elif encoding == "tiledgrid": from gridencoder import GridEncoder encoder = GridEncoder( input_dim=input_dim, num_levels=num_levels, level_dim=level_dim, base_resolution=base_resolution, log2_hashmap_size=log2_hashmap_size, desired_resolution=desired_resolution, gridtype="tiled", align_corners=align_corners, ) elif encoding == "ash": from ashencoder import AshEncoder encoder = AshEncoder( input_dim=input_dim, output_dim=16, log2_hashmap_size=log2_hashmap_size, resolution=desired_resolution, ) else: raise NotImplementedError( "Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]" ) return encoder, encoder.output_dim class PeriodicVolumeEncoding(nn.Module): """Periodic Volume encoding Args: num_levels: Number of feature grids. min_res: Resolution of smallest feature grid. max_res: Resolution of largest feature grid. log2_hashmap_size: Size of hash map is 2^log2_hashmap_size. features_per_level: Number of features per level. hash_init_scale: Value to initialize hash grid. implementation: Implementation of hash encoding. Fallback to torch if tcnn not available. 
""" def __init__( self, num_levels: int = 16, min_res: int = 16, max_res: int = 1024, log2_hashmap_size: int = 19, features_per_level: int = 2, hash_init_scale: float = 0.001, smoothstep: bool = False, ) -> None: super(PeriodicVolumeEncoding, self).__init__() self.in_dim = 3 self.num_levels = num_levels self.features_per_level = features_per_level self.log2_hashmap_size = log2_hashmap_size assert log2_hashmap_size % 3 == 0 self.hash_table_size = 2**log2_hashmap_size self.n_output_dims = num_levels * features_per_level self.smoothstep = smoothstep levels = torch.arange(num_levels) growth_factor = np.exp((np.log(max_res) - np.log(min_res)) / (num_levels - 1)) self.scalings = torch.floor(min_res * growth_factor**levels) self.periodic_volume_resolution = 2 ** (log2_hashmap_size // 3) # self.periodic_resolution = torch.minimum(torch.floor(self.scalings), periodic_volume_resolution) self.hash_offset = levels * self.hash_table_size self.hash_table = ( torch.rand(size=(self.hash_table_size * num_levels, features_per_level)) * 2 - 1 ) self.hash_table *= hash_init_scale self.hash_table = nn.Parameter(self.hash_table) # TODO weight loss by level? self.per_level_weights = 1.0 def parameters(self): return self.hash_table def get_out_dim(self) -> int: return self.num_levels * self.features_per_level def hash_fn(self, in_tensor): """Returns hash tensor using method described in Instant-NGP Args: in_tensor: Tensor to be hashed """ # round to make it perioidic x = in_tensor x %= self.periodic_volume_resolution # xyz to index x = ( x[..., 0] * (self.periodic_volume_resolution**2) + x[..., 1] * (self.periodic_volume_resolution) + x[..., 2] ) # offset by feature levels x += self.hash_offset.to(x.device) return x.long() def pytorch_fwd(self, in_tensor): """Forward pass using pytorch. 
Significantly slower than TCNN implementation.""" assert in_tensor.shape[-1] == 3 in_tensor = in_tensor[..., None, :] # [..., 1, 3] scaled = in_tensor * self.scalings.view(-1, 1).to( in_tensor.device ) # [..., L, 3] scaled_c = torch.ceil(scaled).type(torch.int32) scaled_f = torch.floor(scaled).type(torch.int32) offset = scaled - scaled_f if self.smoothstep: offset = offset * offset * (3.0 - 2.0 * offset) hashed_0 = self.hash_fn(scaled_c) # [..., num_levels] hashed_1 = self.hash_fn( torch.cat( [scaled_c[..., 0:1], scaled_f[..., 1:2], scaled_c[..., 2:3]], dim=-1 ) ) hashed_2 = self.hash_fn( torch.cat( [scaled_f[..., 0:1], scaled_f[..., 1:2], scaled_c[..., 2:3]], dim=-1 ) ) hashed_3 = self.hash_fn( torch.cat( [scaled_f[..., 0:1], scaled_c[..., 1:2], scaled_c[..., 2:3]], dim=-1 ) ) hashed_4 = self.hash_fn( torch.cat( [scaled_c[..., 0:1], scaled_c[..., 1:2], scaled_f[..., 2:3]], dim=-1 ) ) hashed_5 = self.hash_fn( torch.cat( [scaled_c[..., 0:1], scaled_f[..., 1:2], scaled_f[..., 2:3]], dim=-1 ) ) hashed_6 = self.hash_fn(scaled_f) hashed_7 = self.hash_fn( torch.cat( [scaled_f[..., 0:1], scaled_c[..., 1:2], scaled_f[..., 2:3]], dim=-1 ) ) f_0 = self.hash_table[hashed_0] # [..., num_levels, features_per_level] f_1 = self.hash_table[hashed_1] f_2 = self.hash_table[hashed_2] f_3 = self.hash_table[hashed_3] f_4 = self.hash_table[hashed_4] f_5 = self.hash_table[hashed_5] f_6 = self.hash_table[hashed_6] f_7 = self.hash_table[hashed_7] f_03 = f_0 * offset[..., 0:1] + f_3 * (1 - offset[..., 0:1]) f_12 = f_1 * offset[..., 0:1] + f_2 * (1 - offset[..., 0:1]) f_56 = f_5 * offset[..., 0:1] + f_6 * (1 - offset[..., 0:1]) f_47 = f_4 * offset[..., 0:1] + f_7 * (1 - offset[..., 0:1]) f0312 = f_03 * offset[..., 1:2] + f_12 * (1 - offset[..., 1:2]) f4756 = f_47 * offset[..., 1:2] + f_56 * (1 - offset[..., 1:2]) encoded_value = f0312 * offset[..., 2:3] + f4756 * ( 1 - offset[..., 2:3] ) # [..., num_levels, features_per_level] return torch.flatten( encoded_value, start_dim=-2, end_dim=-1 ) # 
[..., num_levels * features_per_level] def forward(self, in_tensor): return self.pytorch_fwd(in_tensor) def get_total_variation_loss(self): """Compute the total variation loss for the feature volume.""" feature_volume = self.hash_table.reshape( self.num_levels, self.periodic_volume_resolution, self.periodic_volume_resolution, self.periodic_volume_resolution, self.features_per_level, ) diffx = feature_volume[:, 1:, :, :, :] - feature_volume[:, :-1, :, :, :] diffy = feature_volume[:, :, 1:, :, :] - feature_volume[:, :, :-1, :, :] diffz = feature_volume[:, :, :, 1:, :] - feature_volume[:, :, :, :-1, :] # TODO how to sum here or should we use mask? resx = diffx.abs().mean(dim=(1, 2, 3, 4)) resy = diffy.abs().mean(dim=(1, 2, 3, 4)) resz = diffz.abs().mean(dim=(1, 2, 3, 4)) return ((resx + resy + resz) * self.per_level_weights).mean() ================================================ FILE: lidarnerf/ffmlp/__init__.py ================================================ ================================================ FILE: lidarnerf/ffmlp/backend.py ================================================ import os from torch.utils.cpp_extension import load _src_path = os.path.dirname(os.path.abspath(__file__)) nvcc_flags = [ "-O3", "-std=c++14", "--expt-extended-lambda", "--expt-relaxed-constexpr", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", ] if os.name == "posix": nvcc_flags += [ "-Xcompiler=-mf16c", "-Xcompiler=-Wno-float-conversion", "-Xcompiler=-fno-strict-aliasing", ] c_flags = ["-O3", "-std=c++14"] elif os.name == "nt": c_flags = ["/O2", "/std:c++17"] # find cl.exe def find_cl_path(): import glob for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: paths = sorted( glob.glob( r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition ), reverse=True, ) if paths: return paths[0] # If cl.exe is not on path, try to find it. 
if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError( "Could not locate a supported Microsoft Visual C++ installation" ) os.environ["PATH"] += ";" + cl_path _backend = load( name="_ffmlp", extra_cflags=c_flags, extra_cuda_cflags=nvcc_flags, extra_include_paths=[ os.path.join(_src_path, "dependencies/cutlass/include"), os.path.join(_src_path, "dependencies/cutlass/tools/util/include"), ], sources=[ os.path.join(_src_path, "src", f) for f in [ "ffmlp.cu", "bindings.cpp", ] ], ) __all__ = ["_backend"] ================================================ FILE: lidarnerf/ffmlp/ffmlp.py ================================================ import math import torch import torch.nn as nn from torch.autograd import Function from torch.cuda.amp import custom_bwd, custom_fwd try: import _ffmlp as _backend except ImportError: from .backend import _backend class _ffmlp_forward(Function): @staticmethod @custom_fwd(cast_inputs=torch.half) def forward( ctx, inputs, weights, input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, inference=False, calc_grad_inputs=False, ): B = inputs.shape[0] inputs = inputs.contiguous() weights = weights.contiguous() # print('[inputs]', torch.any(torch.isnan(inputs)), inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) # print('[weights]', torch.any(torch.isnan(weights)), weights.shape, weights.dtype, weights.min().item(), weights.max().item()) # allocate output outputs = torch.empty(B, output_dim, device=inputs.device, dtype=inputs.dtype) if not inference: forward_buffer = torch.empty( num_layers, B, hidden_dim, device=inputs.device, dtype=inputs.dtype ) _backend.ffmlp_forward( inputs, weights, B, input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, forward_buffer, outputs, ) ctx.save_for_backward(inputs, weights, outputs, forward_buffer) ctx.dims = ( input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, 
calc_grad_inputs, ) # print('[outputs]', torch.any(torch.isnan(outputs)), outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) # print('[forward_buffer]', torch.any(torch.isnan(forward_buffer)), forward_buffer.shape, forward_buffer.dtype, forward_buffer.min().item(), forward_buffer.max().item()) else: inference_buffer = torch.empty( B, hidden_dim, device=inputs.device, dtype=inputs.dtype ) _backend.ffmlp_inference( inputs, weights, B, input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, inference_buffer, outputs, ) # print('[outputs]', torch.any(torch.isnan(outputs)), outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) # print('[inference_buffer]', torch.any(torch.isnan(inference_buffer)), inference_buffer.shape, inference_buffer.dtype, inference_buffer.min().item(), inference_buffer.max().item()) return outputs @staticmethod @custom_bwd def backward(ctx, grad): # grad: [B, output_dim] B = grad.shape[0] grad = grad.contiguous() # print('[grad]', torch.any(torch.isnan(grad)), grad.shape, grad.dtype, grad.min().item(), grad.max().item()) # print(grad) inputs, weights, outputs, forward_buffer = ctx.saved_tensors ( input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, calc_grad_inputs, ) = ctx.dims # allocate outputs if calc_grad_inputs: grad_inputs = torch.zeros_like(inputs) else: grad_inputs = torch.zeros(1, device=grad.device, dtype=grad.dtype) # dummy grad_weights = torch.zeros_like(weights) backward_buffer = torch.zeros( num_layers, B, hidden_dim, device=grad.device, dtype=grad.dtype ) _backend.ffmlp_backward( grad, inputs, weights, forward_buffer, B, input_dim, output_dim, hidden_dim, num_layers, activation, output_activation, calc_grad_inputs, backward_buffer, grad_inputs, grad_weights, ) # print('[grad_inputs]', grad_inputs.shape, grad_inputs.dtype, grad_inputs.min().item(), grad_inputs.max().item()) # print('[grad_weights]', grad_weights.shape, grad_weights.dtype, 
grad_weights.min().item(), grad_weights.max().item()) # print('[backward_buffer]', backward_buffer.shape, backward_buffer.dtype, backward_buffer.min().item(), backward_buffer.max().item()) if calc_grad_inputs: return ( grad_inputs, grad_weights, None, None, None, None, None, None, None, None, ) else: return None, grad_weights, None, None, None, None, None, None, None, None ffmlp_forward = _ffmlp_forward.apply def convert_activation(act): if act == "relu": return 0 elif act == "exponential": return 1 elif act == "sine": return 2 elif act == "sigmoid": return 3 elif act == "squareplus": return 4 elif act == "softplus": return 5 else: return 6 class FFMLP(nn.Module): def __init__( self, input_dim, output_dim, hidden_dim, num_layers, activation="relu" ): super().__init__() self.input_dim = input_dim self.output_dim = output_dim self.hidden_dim = hidden_dim self.num_layers = num_layers self.activation = convert_activation(activation) self.output_activation = convert_activation("none") # not supported currently self.tensorcore_width = 16 assert hidden_dim in [ 16, 32, 64, 128, 256, ], f"FFMLP only support hidden_dim in [16, 32, 64, 128, 256], but got {hidden_dim}" assert ( input_dim > 0 and input_dim % 16 == 0 ), f"FFMLP input_dim should be 16 * m (m > 0), but got {input_dim}" assert ( output_dim <= 16 ), f"FFMLP current only supports output dim <= 16, but got {output_dim}" assert ( num_layers >= 2 ), f"FFMLP num_layers should be larger than 2 (3 matmuls), but got {num_layers}" # pad output self.padded_output_dim = int(math.ceil(output_dim / 16)) * 16 # parameters (continuous in memory) self.num_parameters = hidden_dim * ( input_dim + hidden_dim * (num_layers - 1) + self.padded_output_dim ) self.weights = nn.Parameter(torch.zeros(self.num_parameters)) self.reset_parameters() # allocate streams _backend.allocate_splitk(self.num_layers + 1) # register destructor # atexit.register(self.cleanup) # how to correctly clean? 
this gives CUDA Error: cudaEventDestroy(events[i]) failed with error context is destroyed def cleanup(self): # destroy streams _backend.free_splitk() def __repr__(self): return f"FFMLP: input_dim={self.input_dim} output_dim={self.output_dim} hidden_dim={self.hidden_dim} num_layers={self.num_layers} activation={self.activation}" def reset_parameters(self): torch.manual_seed(42) std = math.sqrt(3 / self.hidden_dim) self.weights.data.uniform_(-std, std) def forward(self, inputs): # inputs: [B, input_dim] # return: [B, outupt_dim] # print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item(), inputs.requires_grad) B, C = inputs.shape # assert B >= 128 and B % 128 == 0, f"ffmlp batch size must be 128 * m (m > 0), but got {B}." # pad input pad = 128 - (B % 128) if pad > 0: inputs = torch.cat( [inputs, torch.zeros(pad, C, dtype=inputs.dtype, device=inputs.device)], dim=0, ) outputs = ffmlp_forward( inputs, self.weights, self.input_dim, self.padded_output_dim, self.hidden_dim, self.num_layers, self.activation, self.output_activation, not self.training, inputs.requires_grad, ) # unpad output if B != outputs.shape[0] or self.padded_output_dim != self.output_dim: outputs = outputs[:B, : self.output_dim] # print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) return outputs ================================================ FILE: lidarnerf/ffmlp/setup.py ================================================ import os from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension _src_path = os.path.dirname(os.path.abspath(__file__)) nvcc_flags = [ "-O3", "-std=c++14", "--expt-extended-lambda", "--expt-relaxed-constexpr", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", ] if os.name == "posix": nvcc_flags += [ "-Xcompiler=-mf16c", "-Xcompiler=-Wno-float-conversion", "-Xcompiler=-fno-strict-aliasing", ] c_flags = ["-O3", "-std=c++14"] elif 
os.name == "nt": c_flags = ["/O2", "/std:c++17"] # find cl.exe def find_cl_path(): import glob for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: paths = sorted( glob.glob( r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition ), reverse=True, ) if paths: return paths[0] # If cl.exe is not on path, try to find it. if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError( "Could not locate a supported Microsoft Visual C++ installation" ) os.environ["PATH"] += ";" + cl_path setup( name="ffmlp", # package name, import this to use python API ext_modules=[ CUDAExtension( name="_ffmlp", # extension name, import this to use CUDA API sources=[ os.path.join(_src_path, "src", f) for f in [ "ffmlp.cu", "bindings.cpp", ] ], extra_compile_args={ "cxx": c_flags, "nvcc": nvcc_flags, }, include_dirs=[ os.path.join(_src_path, "dependencies/cutlass/include"), os.path.join(_src_path, "dependencies/cutlass/tools/util/include"), ], ), ], cmdclass={ "build_ext": BuildExtension, }, ) ================================================ FILE: lidarnerf/ffmlp/src/bindings.cpp ================================================ #include #include "ffmlp.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("ffmlp_forward", &ffmlp_forward, "ffmlp_forward (CUDA)"); m.def("ffmlp_inference", &ffmlp_inference, "ffmlp_inference (CUDA)"); m.def("ffmlp_backward", &ffmlp_backward, "ffmlp_backward (CUDA)"); m.def("allocate_splitk", &allocate_splitk, "allocate_splitk (CUDA)"); m.def("free_splitk", &free_splitk, "free_splitk (CUDA)"); } ================================================ FILE: lidarnerf/ffmlp/src/cutlass_matmul.h ================================================ /* * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, * this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above copyright * notice, this list of conditions and the following disclaimer in the * documentation and/or other materials provided with the distribution. * * Neither the name of the NVIDIA CORPORATION nor the names of its * contributors may be used to endorse or promote products derived from this * software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *//* */ /** @file cutlass_matmul.h * @author Thomas Müller, NVIDIA * @brief Matrix multiplication wrappers that call into CUTLASS (plus some * custom modifications). The parameters are optimized to give optimal * performance in a variety of situations. Parts of this file were adapted by * starting from the CUTLASS sample code (see its BSD 3-clause license). 
*/ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include "utils.h" //#define TCNN_VERBOSE_MEMORY_ALLOCS #define CUTLASS_CHECK(status) \ { \ cutlass::Status error = status; \ if (error != cutlass::Status::kSuccess) { \ std::cerr << "Got cutlass error: " \ << cutlassGetStatusString(error) << " at: " << __LINE__ \ << std::endl; \ exit(EXIT_FAILURE); \ } \ } #define CUDA_CHECK(status) \ { \ cudaError_t error = status; \ if (error != cudaSuccess) { \ std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ << " at line: " << __LINE__ << std::endl; \ exit(EXIT_FAILURE); \ } \ } using SmArch = std::conditional_t< MIN_GPU_ARCH >= 80, std::conditional_t::value, cutlass::arch::Sm75, cutlass::arch::Sm80>, std::conditional_t= 75, cutlass::arch::Sm75, cutlass::arch::Sm70>>; using TypeAccumulator = std::conditional_t::value, float, cutlass::half_t>; using TypeCompute = std::conditional_t::value, float, cutlass::half_t>; template using MMAOp = typename std::conditional::value, cutlass::arch::OpClassSimt, cutlass::arch::OpClassTensorOp>::type; template using ShapeMMAOp = typename std::conditional< std::is_same, cutlass::arch::OpClassTensorOp>::value, typename std::conditional< std::is_same::value || std::is_same::value, cutlass::gemm::GemmShape<16, 8, 8>, cutlass::gemm::GemmShape<8, 8, 4>>::type, cutlass::gemm::GemmShape<1, 1, 1>>::type; template struct LayerConfig { using k_thread_block = thread_block; using k_warp = warp; }; using FullLayerK = typename std::conditional< std::is_same, cutlass::arch::OpClassSimt>::value, LayerConfig, cutlass::gemm::GemmShape<32, 64, 8>>, LayerConfig, cutlass::gemm::GemmShape<32, 32, 32>>>::type; using LastLayerK = FullLayerK; using FullLayer = typename std::conditional< std::is_same, cutlass::arch::OpClassSimt>::value, LayerConfig, cutlass::gemm::GemmShape<32, 64, 8>>, LayerConfig, cutlass::gemm::GemmShape<64, 64, 32>>>::type; using LastLayer = FullLayer; // This code 
section describes how threadblocks are scheduled on GPU using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes the epilogue part of the kernel template struct CutlassFragmentWrapper { static const uint32_t num_elements = V::kElements; V x; }; template class ActivationEpilogue { public: using ElementOutput = ElementOutput_; using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementCompute_; static int const kCount = Count; using FragmentOutput = cutlass::Array; using FragmentAccumulator = cutlass::Array; using ComputeFragment = cutlass::Array; static cutlass::FloatRoundStyle const kRound = Round; struct Params { Activation activation; bool sum_source; }; public: CUTLASS_HOST_DEVICE ActivationEpilogue(Params const ¶ms) : m_activation{params.activation}, m_sum_source{params.sum_source} {} CUTLASS_HOST_DEVICE bool is_source_needed() const { return m_sum_source; } /// Functionally required for serial reduction in the epilogue CUTLASS_HOST_DEVICE void set_k_partition(int k_partition, int k_partition_count) {} CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator) const { cutlass::NumericArrayConverter accumulator_converter; auto intermediate = CutlassFragmentWrapper{ accumulator_converter(accumulator)}; intermediate = warp_activation(m_activation, intermediate); cutlass::NumericArrayConverter destination_converter; return destination_converter(intermediate.x); } CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source) const { cutlass::NumericArrayConverter source_converter; cutlass::NumericArrayConverter accumulator_converter; cutlass::plus plus_op; auto intermediate = CutlassFragmentWrapper{ accumulator_converter(accumulator)}; if (m_sum_source) { intermediate.x = plus_op(intermediate.x, source_converter(source)); } intermediate = warp_activation(m_activation, intermediate); 
cutlass::NumericArrayConverter destination_converter; return destination_converter(intermediate.x); } private: Activation m_activation; bool m_sum_source; }; template class ActivationTransferEpilogue { public: using ElementOutput = ElementOutput_; using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementCompute_; static int const kCount = Count; using FragmentOutput = cutlass::Array; using FragmentAccumulator = cutlass::Array; using ComputeFragment = cutlass::Array; static cutlass::FloatRoundStyle const kRound = Round; /// Host-constructable parameters structure struct Params { Activation activation; }; public: /// Constructs the function object, possibly loading from pointers in host /// memory CUTLASS_HOST_DEVICE ActivationTransferEpilogue(Params const ¶ms) : m_activation{params.activation} {} /// Returns true if source is needed CUTLASS_HOST_DEVICE bool is_source_needed() const { return true; } /// Functionally required for serial reduction in the epilogue CUTLASS_HOST_DEVICE void set_k_partition(int k_partition, int k_partition_count) {} CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator, FragmentOutput const &source) const { cutlass::NumericArrayConverter source_converter; cutlass::NumericArrayConverter accumulator_converter; auto converted_source = CutlassFragmentWrapper{ source_converter(source)}; auto intermediate = CutlassFragmentWrapper{ accumulator_converter(accumulator)}; intermediate = warp_activation_backward( m_activation, intermediate, converted_source); cutlass::NumericArrayConverter destination_converter; return destination_converter(intermediate.x); } CUTLASS_HOST_DEVICE FragmentOutput operator()(FragmentAccumulator const &accumulator) const { cutlass::NumericArrayConverter accumulator_converter; ComputeFragment converted_accumulator = accumulator_converter(accumulator); cutlass::NumericArrayConverter destination_converter; return destination_converter(converted_accumulator); } private: 
Activation m_activation; }; template static constexpr int n_vectorized_elements = std::is_same, cutlass::arch::OpClassTensorOp>::value ? (128 / cutlass::sizeof_bits::value) : 1; template using SumOp = cutlass::epilogue::thread::LinearCombination, TypeAccumulator, TypeCompute>; template using IntermediateActivationOp = ActivationEpilogue; template using IntermediateActivationTransferOp = ActivationTransferEpilogue; template using ActivationOp = ActivationEpilogue, TypeAccumulator, TypeCompute>; template using ActivationTransferOp = ActivationTransferEpilogue, TypeAccumulator, TypeCompute>; template using OurGemm = cutlass::gemm::device::Gemm, SmArch, typename LayerConfig::k_thread_block, typename LayerConfig::k_warp, ShapeMMAOp, EPILOGUE, SwizzleThreadBlock, 2>; template using SplitKGemm = cutlass::gemm::device::GemmSplitKParallel< TypeA, LayoutA, TypeB, LayoutB, TypeOutput, LayoutOutput, TypeAccumulator, MMAOp, SmArch, typename LayerConfig::k_thread_block, typename LayerConfig::k_warp, ShapeMMAOp, EPILOGUE>; inline std::map> &cutlass_workspaces() { static std::map> s_workspaces; return s_workspaces; } inline uint8_t *cutlass_get_workspace(size_t size, cudaStream_t stream) { GPUMemory &workspace = cutlass_workspaces()[stream]; if (size > workspace.size()) { size *= 2; #ifdef TCNN_VERBOSE_MEMORY_ALLOCS std::cout << "CUTLASS GEMM: Allocating temporary workspace of " << bytes_to_string(size) << "." << std::endl; #endif // Allocate twice the requested size to make sure we're not constantly // allocating small increments. workspace.resize(size); } return workspace.data(); } inline void cutlass_free_workspace(cudaStream_t stream) { if (cutlass_workspaces().count(stream) == 0) { return; } #ifdef TCNN_VERBOSE_MEMORY_ALLOCS std::cout << "CUTLASS GEMM: Freeing temporary workspace of " << bytes_to_string(cutlass_workspaces().at(stream).size()) << "." 
<< std::endl; #endif cutlass_workspaces().erase(stream); } template void fc_multiply_impl(cudaStream_t stream, const typename Gemm::Arguments &args) { // Using the arguments, query for extra workspace required for matrix // multiplication computation size_t workspace_size = Gemm::get_workspace_size(args); // Instantiate CUTLASS kernel depending on templates Gemm gemm_op; // Initialize CUTLASS kernel with arguments and workspace pointer cutlass::Status status = gemm_op.initialize( args, cutlass_get_workspace(workspace_size, stream), stream); CUTLASS_CHECK(status); // Launch initialized CUTLASS kernel status = gemm_op(stream); CUTLASS_CHECK(status); } template void fc_multiply_split_k_impl(cudaStream_t stream, const typename Gemm::Arguments &args) { // Using the arguments, query for extra workspace required for matrix // multiplication computation size_t workspace_size = Gemm::get_workspace_size(args); // Instantiate CUTLASS kernel depending on templates Gemm gemm_op; // Initialize CUTLASS kernel with arguments and workspace pointer cutlass::Status status = gemm_op.initialize( args, cutlass_get_workspace(workspace_size, stream)); CUTLASS_CHECK(status); // Launch initialized CUTLASS kernel status = gemm_op(stream); CUTLASS_CHECK(status); } ////////////////////////////////////////////////////////////////////////////////// //////////////////////////// modified /////////////////////////////// ////////////////////////////////////////////////////////////////////////////////// template void fc_multiply(cudaStream_t stream, const int M, const int K, const int N, const __half *A, const __half *B, const __half *C, __half *D, Activation act = Activation::None, bool transfer = false, bool sum_source = false) { // compute D = A @ B + C // A: [M, K] // B: [K, N] // C, D: [M, N] using CutlassLayoutA = typename std::conditional::type; using CutlassLayoutB = typename std::conditional::type; using CutlassLayoutC = typename std::conditional::type; using MatmulTypeCompute = 
cutlass::half_t; using MatmulTypeAccumulator = cutlass::half_t; const int lda = RM_A ? K : M; const int ldb = RM_B ? N : K; const int ldc = RM_C ? N : M; const int ldd = RM_C ? N : M; if (transfer) { using Gemm = OurGemm, config, MatmulTypeCompute, CutlassLayoutA, MatmulTypeCompute, CutlassLayoutB, MatmulTypeAccumulator, CutlassLayoutC>; typename Gemm::Arguments arguments{{M, N, K}, {(MatmulTypeCompute *)A, lda}, {(MatmulTypeCompute *)B, ldb}, {(MatmulTypeAccumulator *)C, ldc}, {(MatmulTypeAccumulator *)D, ldd}, {act}, 1}; fc_multiply_impl(stream, arguments); } else { using Gemm = OurGemm, config, MatmulTypeCompute, CutlassLayoutA, MatmulTypeCompute, CutlassLayoutB, MatmulTypeAccumulator, CutlassLayoutC>; typename Gemm::Arguments arguments{{M, N, K}, {(MatmulTypeCompute *)A, lda}, {(MatmulTypeCompute *)B, ldb}, {(MatmulTypeAccumulator *)C, ldc}, {(MatmulTypeAccumulator *)D, ldd}, {act, sum_source}, 1}; fc_multiply_impl(stream, arguments); } } template void fc_multiply(cudaStream_t stream, const int M, const int K, const int N, const __half *A, const __half *B, __half *D, Activation act = Activation::None) { fc_multiply(stream, M, K, N, A, B, D, D, act); } template void fc_multiply_split_k(cudaStream_t stream, const int M, const int K, const int N, const __half *A, const __half *B, const __half *C, __half *D, int split_k_slices = 1) { // A: [M, K] // B: [K, N] // C, D: [M, N] using CutlassLayoutA = typename std::conditional::type; using CutlassLayoutB = typename std::conditional::type; using CutlassLayoutC = typename std::conditional::type; using MatmulTypeCompute = cutlass::half_t; using MatmulTypeAccumulator = cutlass::half_t; const int lda = RM_A ? K : M; const int ldb = RM_B ? N : K; const int ldc = RM_C ? N : M; const int ldd = RM_C ? 
N : M; using Gemm = SplitKGemm, config, MatmulTypeCompute, CutlassLayoutA, MatmulTypeCompute, CutlassLayoutB, MatmulTypeAccumulator, CutlassLayoutC>; typename Gemm::Arguments arguments{{M, N, K}, {(MatmulTypeCompute *)A, lda}, {(MatmulTypeCompute *)B, ldb}, {(MatmulTypeAccumulator *)C, ldc}, {(MatmulTypeAccumulator *)D, ldd}, {(TypeCompute)1.0f, (TypeCompute)0.0f}, split_k_slices}; fc_multiply_split_k_impl(stream, arguments); } template void fc_multiply_split_k(cudaStream_t stream, const int M, const int K, const int N, const __half *A, const __half *B, __half *D, int split_k_slices = 1) { fc_multiply_split_k(stream, M, K, N, A, B, D, D, split_k_slices); } ================================================ FILE: lidarnerf/ffmlp/src/ffmlp.cu ================================================ #include #include #include #include #include #include #include #include #include #include #include #include "cutlass_matmul.h" #include "utils.h" __host__ __device__ Activation convert_activation(const uint32_t activation) { switch (activation) { case 0: return Activation::ReLU; case 1: return Activation::Exponential; case 2: return Activation::Sine; case 3: return Activation::Sigmoid; case 4: return Activation::Squareplus; case 5: return Activation::Softplus; case 6: return Activation::None; default: return Activation::None; } } template __host__ __device__ T div_round_up(T val, T divisor) { return (val + divisor - 1) / divisor; } void check_shmem_error(cudaError_t error) { if (error != cudaSuccess) { throw std::runtime_error{ "FullyFusedMLP: insufficient shared memory available on the " "GPU. 
" "Reduce `n_neurons` or use `CutlassMLP` (better compatibility " "but " "slower) instead."}; } } template __device__ void threadblock_layer( Activation activation, __half *__restrict__ act_shmem, const __half *__restrict__ weights_this_layer, OUT_T *__restrict__ out_intermediate_threadblock_this_layer, const OUT_T *__restrict__ activation_aux = nullptr) { // act_shmem contains the intermediate activations (shared memory) of the // thread block's chunk of the batch. // Can be forward activations or backward activations, depending // on caller. // weights_this_layer points to the weight matrix of the current layer. // out_intermediate_threadblock_this_layer points to the location where // intermediate activations produced by the thread block should be written // to. // Can be nullptr if nothing should be written. // activation_aux points to additional arguments that the activation // function may depend on. Points to the hidden forward activations when // computing backward activations. constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; constexpr uint32_t N_BLOCKS = WIDTH / 16; using namespace nvcuda; // If we're performing the backward pass, weights must be loaded in // transposed form, which is achieved by interpreting the memory in // row_major instead of col_major order. using weights_layout_t = std::conditional_t; // Fragments wmma::fragment act_frag; wmma::fragment weights_frag[N_BLOCKS]; wmma::fragment result_frag[N_ITERS]; // Indices const uint32_t li = threadIdx.x; // index in warp ("lane index") const uint32_t wi = threadIdx.y; // index in block ("warp index") const uint32_t lane_offset = (8 * li) % WIDTH; const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; const uint32_t weights_col = 16 * wi; __syncthreads(); // Load N_BLOCKS chunks of weights from global memory into registers. 
#pragma unroll for (uint32_t i = 0; i < N_BLOCKS; ++i) { if (BACKWARD) { // If we're performing the backward pass, additional index swizzling // is needed to load the weights in transposed form. wmma::load_matrix_sync( weights_frag[i], weights_this_layer + 16 * i * WIDTH + weights_col, WIDTH); } else { wmma::load_matrix_sync( weights_frag[i], weights_this_layer + 16 * i + weights_col * WIDTH, WIDTH); } } #pragma unroll for (int l = 0; l < N_ITERS; ++l) { wmma::fill_fragment(result_frag[l], 0.0f); #pragma unroll for (uint32_t i = 0; i < N_BLOCKS; ++i) { // Load a chunk of intermediate activations from shared memory and // multiply with chunk of weights wmma::load_matrix_sync( act_frag, act_shmem + 16 * i + (16 * (threadIdx.z + l * BLOCK_DIM_Z)) * (WIDTH + SKEW), WIDTH + SKEW); wmma::mma_sync(result_frag[l], act_frag, weights_frag[i], result_frag[l]); } // Activation if (BACKWARD) { // Load the temporary forward matrix for the relu transfer wmma::load_matrix_sync( act_frag, activation_aux + weights_col + (threadIdx.z + l * BLOCK_DIM_Z) * 16 * WIDTH, WIDTH); warp_activation_backward<__half>(activation, result_frag[l], act_frag, result_frag[l]); } else { warp_activation<__half>(activation, result_frag[l], result_frag[l]); } } __syncthreads(); #pragma unroll for (int l = 0; l < N_ITERS; ++l) { wmma::store_matrix_sync( act_shmem + weights_col + (threadIdx.z + l * BLOCK_DIM_Z) * 16 * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major); } if (out_intermediate_threadblock_this_layer != nullptr) { __syncthreads(); #pragma unroll for (int l = 0; l < N_ITERS; ++l) { *(int4 *)&out_intermediate_threadblock_this_layer [lane_offset + (row + 16 * (threadIdx.z + l * BLOCK_DIM_Z)) * WIDTH] = *(int4 *)&act_shmem[lane_offset + (row + 16 * (threadIdx.z + l * BLOCK_DIM_Z)) * (WIDTH + SKEW)]; } } } template __device__ void threadblock_load_input_static( __half *__restrict__ act_shmem, const __half *__restrict__ input_threadblock) { // act_shmem will be filled by the thread 
block's chunk of input_threadblock constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; // Indices const uint32_t li = threadIdx.x; // index in warp ("lane index") const uint32_t wi = threadIdx.y; // index in block ("warp index") const uint32_t lane_offset = (8 * li) % WIDTH; const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; #pragma unroll for (int i = 0; i < N_ITERS; ++i) { *(int4 *)&act_shmem[lane_offset + (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * (WIDTH + SKEW)] = *(int4 *)&input_threadblock [lane_offset + (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * WIDTH]; } } template __device__ void threadblock_input_layer_forward_dynamic( Activation activation, __half *__restrict__ act_shmem, const __half *__restrict__ input_threadblock, const __half *__restrict__ weights_this_layer, OUT_T *__restrict__ out_intermediate_threadblock_this_layer, const uint32_t in_width) { // act_shmem contains the intermediate activations (shared memory) of the // thread block's chunk of the batch input_threadblock points to the thread // block's chunk of the input batch in global memory weights_this_layer // points to the weight matrix of the current layer // out_intermediate_threadblock_this_layer points to the location where // intermediate activations produced by the thread block should be written // to. // Can be nullptr if nothing should be written. // in_width is the dynamic width of the input layer constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 
8 : 0; constexpr uint32_t INPUT_SKEW = 8; constexpr uint32_t N_BLOCKS = WIDTH / 16; using namespace nvcuda; // Fragments wmma::fragment act_frag; wmma::fragment weights_frag; wmma::fragment result_frag[N_ITERS]; // Indices const uint32_t li = threadIdx.x; // index in warp ("lane index") const uint32_t wi = threadIdx.y; // index in block ("warp index") const uint32_t lane_offset = (8 * li) % WIDTH; const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; const uint32_t weights_col = 16 * wi; __half *__restrict__ weights_shmem = act_shmem + BLOCK_DIM_Z * 16 * (in_width + INPUT_SKEW); // Load input weight matrix (fits completely into shared memory) // Each thread can load 8 fp16 elements (16 bytes) at once; we have // N_BLOCKS*BLOCK_DIM_Z warps const uint32_t n_elems_per_load = N_BLOCKS * 32 * BLOCK_DIM_Z * 8; const uint32_t thread_elem_idx = (li + wi * 32 + threadIdx.z * N_BLOCKS * 32) * 8; const uint32_t n_elems_b = WIDTH * in_width; #pragma unroll for (uint32_t idx = thread_elem_idx; idx < n_elems_b; idx += n_elems_per_load) { const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW; *(int4 *)&weights_shmem[idx_skewed] = *(int4 *)&weights_this_layer[idx]; } const uint32_t n_tensor_ops = in_width / 16; #pragma unroll for (int l = 0; l < N_ITERS; ++l) { // Load chunk of inputs into shmem. // This is faster than loading it from gmem directly, even though it is // only used once. (Possibly due to latency hiding through staging.) 
const uint32_t n_elems_a = BLOCK_DIM_Z * 16 * in_width; #pragma unroll for (uint32_t idx = thread_elem_idx; idx < n_elems_a; idx += n_elems_per_load) { const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW; *(int4 *)&act_shmem[idx_skewed] = *(int4 *)&input_threadblock[l * n_elems_a + idx]; } __syncthreads(); wmma::fill_fragment(result_frag[l], 0.0f); #pragma unroll for (uint32_t i = 0; i < n_tensor_ops; ++i) { // Load chunk of inputs and weights from shared memory and multiply // them wmma::load_matrix_sync( act_frag, act_shmem + 16 * i + (16 * threadIdx.z) * (in_width + INPUT_SKEW), in_width + INPUT_SKEW); wmma::load_matrix_sync( weights_frag, weights_shmem + 16 * i + weights_col * (in_width + INPUT_SKEW), in_width + INPUT_SKEW); wmma::mma_sync(result_frag[l], act_frag, weights_frag, result_frag[l]); } __syncthreads(); warp_activation<__half>(activation, result_frag[l], result_frag[l]); } #pragma unroll for (int l = 0; l < N_ITERS; ++l) { wmma::store_matrix_sync( act_shmem + weights_col + (16 * (threadIdx.z + l * BLOCK_DIM_Z)) * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major); } if (out_intermediate_threadblock_this_layer != nullptr) { __syncthreads(); #pragma unroll for (int i = 0; i < N_ITERS; ++i) { *(int4 *)&out_intermediate_threadblock_this_layer [lane_offset + (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * WIDTH] = *(int4 *)&act_shmem[lane_offset + (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * (WIDTH + SKEW)]; } } } template __device__ void threadblock_last_layer_forward( Activation activation, __half *__restrict__ act_shmem, const __half *__restrict__ weights_this_layer, OUT_T *__restrict__ out, const uint32_t batch_size, const nvcuda::wmma::layout_t output_layout) { // act_shmem contains the intermediate activations (shared memory) of the // thread block's chunk of the batch weights_this_layer points to the weight // matrix of the current layer out points to the location where the result // produced by the thread block should be 
written to. // Can be nullptr if nothing should be written. constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; constexpr uint32_t N_BLOCKS = WIDTH / 16; using namespace nvcuda; // Fragments wmma::fragment act_frag; wmma::fragment weights_frag[N_BLOCKS]; wmma::fragment result_frag; // Indices const uint32_t li = threadIdx.x; // index in warp ("lane index") const uint32_t wi = threadIdx.y; // index in block ("warp index") __half *__restrict__ weights_shmem = act_shmem + N_ITERS * BLOCK_DIM_Z * 16 * (WIDTH + SKEW); const uint32_t weights_row = (8 * li) % WIDTH; const uint32_t weights_col = (8 * li + 8 * 32 * wi) / WIDTH; // Load weight matrix into shared memory for the last multiplication. // Loading into shared memory as opposed to directly into registers is // faster because unlike in the previous layers, each warp uses the same // entries of the weight matrix. if (threadIdx.z == 0) { *(int4 *)&weights_shmem[weights_row + weights_col * (WIDTH + SKEW)] = *(int4 *)&weights_this_layer[weights_row + weights_col * WIDTH]; // printf("[last forward] base=%d, shmem=%d, weight=%d\n", N_ITERS * // BLOCK_DIM_Z * 16 * (WIDTH + SKEW), weights_row + weights_col * (WIDTH // + SKEW), weights_row + weights_col * WIDTH); } __syncthreads(); #pragma unroll for (uint32_t i = 0; i < N_BLOCKS; ++i) wmma::load_matrix_sync(weights_frag[i], weights_shmem + 16 * i, WIDTH + SKEW); // Perform last layer by parallelizing over iters for (uint32_t idx = wi; idx < N_ITERS; idx += N_BLOCKS) { wmma::fill_fragment(result_frag, 0.0f); #pragma unroll for (uint32_t i = 0; i < N_BLOCKS; ++i) { // Load a chunk of intermediate activations from shared memory and // multiply with chunk of the weight matrix wmma::load_matrix_sync( act_frag, act_shmem + 16 * i + (16 * (threadIdx.z + idx * BLOCK_DIM_Z)) * (WIDTH + SKEW), WIDTH + SKEW); wmma::mma_sync(result_frag, act_frag, weights_frag[i], result_frag); } warp_activation<__half>(activation, result_frag, result_frag); if (output_layout == wmma::mem_row_major) { 
wmma::store_matrix_sync( out + (threadIdx.z + idx * BLOCK_DIM_Z) * 16 * 16, result_frag, 16, output_layout); // printf("[last forward] RM write out %d, batch %d\n", (threadIdx.z // + idx // * BLOCK_DIM_Z) * 16 * 16, 16); } else { wmma::store_matrix_sync( out + (threadIdx.z + idx * BLOCK_DIM_Z) * 16, result_frag, batch_size, output_layout); // printf("[last forward] CM write out %d, batch %d\n", (threadIdx.z // + idx // * BLOCK_DIM_Z) * 16, batch_size); } } } template __device__ void threadblock_write_output_static( const __half *__restrict__ act_shmem, __half *__restrict__ output_threadblock) { // output_threadblock will be filled by the thread block's act_shmem constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; // Indices const uint32_t li = threadIdx.x; // index in warp ("lane index") const uint32_t wi = threadIdx.y; // index in block ("warp index") const uint32_t lane_offset = (8 * li) % WIDTH; const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; __syncthreads(); #pragma unroll for (int i = 0; i < N_ITERS; ++i) { *(int4 *)&output_threadblock[lane_offset + (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * WIDTH] = *(int4 *)&act_shmem[lane_offset + (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * (WIDTH + SKEW)]; } } /////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////// template __global__ void kernel_mlp_fused(const Activation activation, const Activation output_activation, const __half *__restrict__ input, const __half *__restrict__ weights, OUT_T *__restrict__ out_intermediate, OUT_T *__restrict__ out, const uint32_t batch_size, const uint32_t in_width, const uint32_t out_width, const uint32_t n_hidden_matmuls, const nvcuda::wmma::layout_t output_layout = nvcuda::wmma::mem_row_major) { // `input` points to the input matrix. 
Can be any width. // `weights` points to the weight matrices (contiguous in memory). // `out_intermediate` points to the memory where intermediate activations // should be written. When performing inference, a value of nullptr is // expected (intermediate results are not written). `out` points to the // memory where the network output should be written. (Output width is // assumed to be 16 neurons.) // if (threadIdx.x == 0) printf("[forward] call kernel_mlp_fused\n"); // if (threadIdx.x == 0) printf("[forward] inputs=%f\n", (float)input[0]); // if (threadIdx.x == 0) printf("[forward] weights=%f\n", // (float)weights[0]); // if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n", // (float)out_intermediate[0]); // Shared memory contains the intermediate activations of blockDim.y*16 // elements. In some cases, it also contains the weight matrix for the first // and last layer. extern __shared__ __half shmem[]; __half *act_shmem = shmem; // Each block computes exactly one 16-element chunk of the batch. const uint32_t elem_idx = 16 * blockIdx.x * N_ITERS * BLOCK_DIM_Z; // First layer if (in_width == WIDTH) { // If the input has the same width as the network, we can simply use the // network's regular layer routine (with static size) instead of using // the slower dynamic input layer routine. threadblock_load_input_static( act_shmem, input + elem_idx * WIDTH); threadblock_layer( activation, act_shmem, weights, !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr); } else { threadblock_input_layer_forward_dynamic( activation, act_shmem, input + elem_idx * in_width, weights, !INFERENCE ? 
(out_intermediate + elem_idx * WIDTH) : nullptr, in_width); } // if (threadIdx.x == 0) printf("[forward] kernel_mlp_fused: passed first // layer\n"); // if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n", // (float)out_intermediate[0]); const uint32_t first_layer_size = WIDTH * in_width; const uint32_t layer_stride = WIDTH * WIDTH; const uint32_t output_stride = WIDTH * batch_size; // Hidden layers for (uint32_t k = 0; k < n_hidden_matmuls; ++k) { threadblock_layer( activation, act_shmem, weights + first_layer_size + layer_stride * k, !INFERENCE ? (out_intermediate + output_stride * (k + 1) + elem_idx * WIDTH) : nullptr); // if (threadIdx.x == 0) printf("[forward] kernel_mlp_fused: passed %d // layer\n", k + 1); // if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n", // (float)out_intermediate[0]); } if (out_width > 16) { // In the forward pass, intermediate activations are already written // out. if (INFERENCE) { threadblock_write_output_static( act_shmem, out_intermediate + elem_idx * WIDTH); } } else if (out) { // Last layer if (output_layout == nvcuda::wmma::mem_row_major) { // printf("[last layer] RM write to out %d\n", elem_idx * 16); // if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n", // (float)out_intermediate[0]); threadblock_last_layer_forward( output_activation, act_shmem, weights + first_layer_size + layer_stride * n_hidden_matmuls, out + elem_idx * 16, 16, output_layout); // if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n", // (float)out_intermediate[0]); } else { // printf("[last layer] CM write to out %d\n", elem_idx); // if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n", // (float)out_intermediate[0]); threadblock_last_layer_forward( output_activation, act_shmem, weights + first_layer_size + layer_stride * n_hidden_matmuls, out + elem_idx, batch_size, output_layout); // if (threadIdx.x == 0) printf("[forward] forward_buffer=%f\n", // (float)out_intermediate[0]); } } } template __global__ void 
// Fused backward kernel of the MLP. The `template __global__ void` prefix of
// this declaration sits on the preceding line; its template parameter list is
// not legible in this dump. The body relies on compile-time WIDTH, N_ITERS and
// BLOCK_DIM_Z.
//
// `dL_doutput` points to the input matrix of the backward pass, i.e. the
// loss gradients. Assumed to be 16 neurons wide. `weights` points to the
// weight matrices (contiguous in memory). `out_intermediate` points to the
// memory where backpropagated activation gradients should be written.
// `forward` points to the memory where the intermediate activations of the
// forward pass are located (needed for activation backprop).
// `dL_dinput`, when non-null, receives the gradient w.r.t. the network input;
// this is only valid when the input width equals WIDTH (see end of body).
kernel_mlp_fused_backward(const Activation activation,
                          const __half *__restrict__ dL_doutput,
                          const __half *__restrict__ weights,
                          __half *__restrict__ out_intermediate,
                          const __half *__restrict__ forward,
                          __half *__restrict__ dL_dinput,
                          const __half *__restrict__ weights_first_layer,
                          const uint32_t batch_size,
                          const uint32_t out_width,
                          const uint32_t n_hidden_matmuls) {
    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;

    // Indices
    const uint32_t li = threadIdx.x;  // index in warp ("lane index")
    const uint32_t wi = threadIdx.y;  // index in block ("warp index")
    const uint32_t bi = blockIdx.x;   // block index

    // Shared memory contains the intermediate activations of blockDim.y*16
    // elements. A skew is applied to the matrix storage to avoid bank
    // conflicts.
    extern __shared__ __half shmem[];
    __half *act_shmem = shmem;

    const uint32_t lane_offset = (8 * li) % WIDTH;
    const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;

    // Multiplying one 16-row chunk of intermediate activations with the weight
    // matrix requires all warps of the block. Thus, each block computes exactly
    // one 16-row chunk of the next layer's intermediate activations.
    const uint32_t elem_idx_base = 16 * bi * N_ITERS * BLOCK_DIM_Z;
    const uint32_t elem_idx = elem_idx_base + 16 * threadIdx.z;

    const uint32_t layer_stride = WIDTH * WIDTH;
    const uint32_t output_stride = WIDTH * batch_size;

    // Backprop through last layer
    if (out_width <= 16) {
        using namespace nvcuda;

        // Fragments in registers
        wmma::fragment act_frag;
        wmma::fragment weights_frag;
        wmma::fragment result_frag[N_ITERS];

        // Load the relevant chunk of the last layer's weight matrix from
        // global memory into registers
        const uint32_t weights_col = 16 * wi;

        wmma::load_matrix_sync(
                weights_frag,
                weights + layer_stride * n_hidden_matmuls + weights_col, WIDTH);

#pragma unroll
        for (int l = 0; l < N_ITERS; ++l) {
            wmma::fill_fragment(result_frag[l], 0.0f);

            // Load a chunk of output gradients from `dL_doutput` and multiply
            // with previously loaded weights
            if (std::is_same::value) {
                wmma::load_matrix_sync(
                        act_frag,
                        dL_doutput +
                                (elem_idx +
                                 16 * (threadIdx.z + l * BLOCK_DIM_Z)) *
                                        16,
                        16);
            } else {
                wmma::load_matrix_sync(
                        act_frag,
                        dL_doutput + (elem_idx +
                                      16 * (threadIdx.z + l * BLOCK_DIM_Z)),
                        batch_size);
            }

            // NOTE: activation transfer of the _output_ activation is expected
            // to be done _prior_ to calling this kernel in a separate pass,
            // because the transferred activation gradient is also needed to
            // compute the weight gradient of the last weight matrix (see
            // backward()).
            wmma::mma_sync(result_frag[l], act_frag, weights_frag,
                           result_frag[l]);

            // Load the temporary forward matrix for the relu transfer
            wmma::fragment forward_frag;
            wmma::load_matrix_sync(
                    forward_frag,
                    forward + output_stride * n_hidden_matmuls + weights_col +
                            (elem_idx + l * BLOCK_DIM_Z * 16) * WIDTH,
                    WIDTH);

            warp_activation_backward<__half>(activation, result_frag[l],
                                             forward_frag, result_frag[l]);
        }

        __syncthreads();

#pragma unroll
        for (int l = 0; l < N_ITERS; ++l) {
            wmma::store_matrix_sync(
                    act_shmem + weights_col +
                            (16 * (threadIdx.z + l * BLOCK_DIM_Z)) *
                                    (WIDTH + SKEW),
                    result_frag[l], WIDTH + SKEW, wmma::mem_row_major);
        }

        __syncthreads();

#pragma unroll
        for (int i = 0; i < N_ITERS; ++i) {
            *(int4 *)&out_intermediate[lane_offset +
                                       (row + elem_idx +
                                        i * BLOCK_DIM_Z * 16) *
                                               WIDTH] =
                    *(int4 *)&act_shmem[lane_offset +
                                        (row + 16 * (threadIdx.z +
                                                     i * BLOCK_DIM_Z)) *
                                                (WIDTH + SKEW)];
        }
    } else {
        // If the output width is larger than 16, we will have used CUTLASS for
        // backpropping through the last layer. Load the resulting gradients.
        threadblock_load_input_static(
                act_shmem, out_intermediate + elem_idx * WIDTH);
    }

    // Backprop through hidden layers
    for (uint32_t k = 0; k < n_hidden_matmuls; ++k) {
        threadblock_layer(
                activation, act_shmem,
                weights + layer_stride * (n_hidden_matmuls - k - 1),
                out_intermediate + output_stride * (k + 1) +
                        elem_idx_base * WIDTH,
                forward + output_stride * (n_hidden_matmuls - k - 1) +
                        elem_idx_base * WIDTH);
    }

    // Compute loss gradients w.r.t. input if desired.
    // THIS CODE ASSUMES THAT THE INPUT WIDTH IS THE SAME AS THE NETWORK WIDTH.
    // DON'T PASS A NON-NULL dL_dinput IF THIS REQUIREMENT IS NOT MET.
    if (dL_dinput != nullptr) {
        threadblock_layer(
                Activation::None, act_shmem, weights_first_layer,
                dL_dinput + elem_idx_base * WIDTH);
    }
}

//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////

// Host-side launcher for kernel_mlp_fused. Chooses the launch geometry from
// WIDTH (= hidden_dim) and INFERENCE, sizes dynamic shared memory, and
// launches with `num_layers - 1` hidden matmuls. Output is written
// row-major (nvcuda::wmma::mem_row_major).
template  // WIDTH is hidden_dim
void ffmlp_forward_cuda(const __half *inputs,
                        const __half *weights,
                        const uint32_t B,
                        const uint32_t input_dim,
                        const uint32_t output_dim,
                        const uint32_t num_layers,
                        const Activation activation,
                        const Activation output_activation,
                        __half *forward_buffer,
                        __half *outputs) {
    // <- always going to be 8 as we only support multiple-of-16 widths
    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;
    constexpr uint32_t INPUT_SKEW = 8;  // <- likewise with inputs
    constexpr uint32_t N_BLOCK_ROWS = WIDTH / 16;

    const int N_ITERS = WIDTH >= 256 ? 2 : 8;
    const uint32_t BLOCK_DIM_Z = (INFERENCE && WIDTH == 128) ? 2 : 1;

    // 32 threads = 1 warp, N_BLOCK_ROWS warps per block for 16 rows, up to
    // 2x 8 warps can share input (does not help vs. 1)
    const dim3 threads = {32u, N_BLOCK_ROWS, BLOCK_DIM_Z};

    uint32_t n_elems_per_block = 16 * BLOCK_DIM_Z * N_ITERS;
    uint32_t n_blocks = div_round_up(B, n_elems_per_block);

    // 16*WIDTH rows of weights (for the last layer; others are in registers
    // only) + 16*WIDTH*BLOCK_DIM_Z*N_ITERS rows of intermediate activations
    size_t shmem_size = sizeof(__half) * (16 + 16 * BLOCK_DIM_Z * N_ITERS) *
                        (WIDTH + SKEW);

    // If the input width is dynamic, the input weight matrix as well as part
    // of the input will live in extra shared memory
    if (input_dim != WIDTH) {
        shmem_size = std::max(shmem_size, sizeof(__half) *
                                                  (WIDTH + 16 * BLOCK_DIM_Z) *
                                                  (input_dim + INPUT_SKEW));
    }

    // printf("[ffmlp_forward_cuda] shmem size = %d\n", shmem_size);

    const dim3 blocks = {n_blocks, 1u, 1u};

    check_shmem_error(cudaFuncSetAttribute(
            kernel_mlp_fused,
            cudaFuncAttributeMaxDynamicSharedMemorySize, (int)shmem_size));

    kernel_mlp_fused <<>>(
            activation, output_activation,
            inputs,          // CM
            weights,         // RM
            forward_buffer,  // CM
            outputs,         // CM
            B, input_dim, output_dim, num_layers - 1,
            nvcuda::wmma::mem_row_major  // reversed outputs's layout
    );
}

// Host-side launcher for kernel_mlp_fused_backward. `weights` is split into
// the first-layer matrix (input_dim * WIDTH halves) and the remaining
// matrices; the kernel runs `num_layers - 1` hidden matmuls.
template  // WIDTH is hidden_dim
void ffmlp_backward_cuda(const __half *grad,
                         const __half *weights,
                         const uint32_t B,
                         const uint32_t input_dim,
                         const uint32_t output_dim,
                         const uint32_t num_layers,
                         const Activation activation,
                         const __half *forward_buffer,
                         __half *backward_buffer,
                         __half *grad_inputs) {
    // locate
    const __half *weights_first = weights;
    const __half *weights_second = weights + input_dim * WIDTH;

    // <- always going to be 8 as we only support multiple-of-16 widths
    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;
    constexpr uint32_t N_BLOCKS = WIDTH / 16;

    const int N_ITERS = WIDTH >= 256 ? 2 : 8;
    const uint32_t BLOCK_DIM_Z = 1;

    // 32 threads = 1 warp, N_BLOCKS warps per block for 16 rows, up to 2x 8
    // warps can share input (does not help vs. 1)
    const dim3 threads = {32u, N_BLOCKS, BLOCK_DIM_Z};

    uint32_t n_elems_per_block = 16 * BLOCK_DIM_Z * N_ITERS;
    uint32_t n_blocks = div_round_up(B, n_elems_per_block);

    // 16 * BLOCK_DIM_Z * N_ITERS skewed rows (WIDTH + SKEW halves each) of
    // activation gradients
    size_t shmem_size =
            sizeof(__half) * ((16 * BLOCK_DIM_Z * N_ITERS) * (WIDTH + SKEW));

    const dim3 blocks = {n_blocks, 1u, 1u};

    // The kernels operate with transposed layouts compared with the MLP code
    check_shmem_error(cudaFuncSetAttribute(
            kernel_mlp_fused_backward,
            cudaFuncAttributeMaxDynamicSharedMemorySize, (int)shmem_size));

    kernel_mlp_fused_backward <<>>(activation,
                                   grad,             // CM
                                   weights_second,   // RM
                                   backward_buffer,  // CM
                                   forward_buffer,   // CM
                                   grad_inputs,      // CM
                                   weights_first,    // RM
                                   B, output_dim, num_layers - 1);
}

// inputs: col-major [input_dim, B]
// weights: row-major [hidden_dim * input_dim]
//          + [hidden_dim * hidden_dim * (num_layers - 1)]
//          + [output_dim * hidden_dim]
// forward_buffer: col-major [num_layers, hidden_dim, B]
// outputs: col-major [output_dim, B]
// Throws std::runtime_error when hidden_dim is not one of 16/32/64/128/256.
void ffmlp_forward(const at::Tensor inputs,
                   const at::Tensor weights,
                   const uint32_t B,
                   const uint32_t input_dim,
                   const uint32_t output_dim,
                   const uint32_t hidden_dim,
                   const uint32_t num_layers,
                   const uint32_t activation_,
                   const uint32_t output_activation_,
                   at::Tensor forward_buffer,
                   at::Tensor outputs) {
    CHECK_CUDA(inputs);
    CHECK_CONTIGUOUS(inputs);
    CHECK_IS_HALF(inputs);

    CHECK_CUDA(weights);
    CHECK_CONTIGUOUS(weights);
    CHECK_IS_HALF(weights);

    Activation activation = convert_activation(activation_);
    Activation output_activation = convert_activation(output_activation_);

    auto inputs_ptr = reinterpret_cast<__half *>(inputs.data_ptr());
    auto weights_ptr = reinterpret_cast<__half *>(weights.data_ptr());
    auto forward_buffer_ptr =
            reinterpret_cast<__half *>(forward_buffer.data_ptr());
    auto outputs_ptr = reinterpret_cast<__half *>(outputs.data_ptr());

    switch (hidden_dim) {
        case 16:
            ffmlp_forward_cuda<16, false>(inputs_ptr, weights_ptr, B,
                                          input_dim, output_dim, num_layers,
                                          activation, output_activation,
                                          forward_buffer_ptr, outputs_ptr);
            break;
        case 32:
            ffmlp_forward_cuda<32, false>(inputs_ptr, weights_ptr, B,
                                          input_dim, output_dim, num_layers,
                                          activation, output_activation,
                                          forward_buffer_ptr, outputs_ptr);
            break;
        case 64:
            ffmlp_forward_cuda<64, false>(inputs_ptr, weights_ptr, B,
                                          input_dim, output_dim, num_layers,
                                          activation, output_activation,
                                          forward_buffer_ptr, outputs_ptr);
            break;
        case 128:
            ffmlp_forward_cuda<128, false>(inputs_ptr, weights_ptr, B,
                                           input_dim, output_dim, num_layers,
                                           activation, output_activation,
                                           forward_buffer_ptr, outputs_ptr);
            break;
        case 256:
            ffmlp_forward_cuda<256, false>(inputs_ptr, weights_ptr, B,
                                           input_dim, output_dim, num_layers,
                                           activation, output_activation,
                                           forward_buffer_ptr, outputs_ptr);
            break;
        default:
            throw std::runtime_error{
                    "hidden_dim should in [16, 32, 64, 128, 256]"};
    }

    // for output_dim > 16 the last layer is computed here with fc_multiply
    // instead of inside the fused kernel
    if (output_dim > 16) {
        fc_multiply(
                0, output_dim, hidden_dim, B,
                (weights_ptr + hidden_dim * input_dim +
                 (num_layers - 1) * hidden_dim *
                         hidden_dim),  // row-major, [output_dim, hidden_dim]
                (forward_buffer_ptr + (num_layers - 1) * hidden_dim *
                                              B),  // col-major [hidden_dim, B]
                outputs_ptr,  // col-major [output_dim, B]
                output_activation);
    }
}

// Inference variant of ffmlp_forward: same argument layout, but launches the
// INFERENCE=true kernel and uses `inference_buffer` as the intermediate
// buffer (it is also the input of the >16-wide last layer below).
void ffmlp_inference(const at::Tensor inputs,
                     const at::Tensor weights,
                     const uint32_t B,
                     const uint32_t input_dim,
                     const uint32_t output_dim,
                     const uint32_t hidden_dim,
                     const uint32_t num_layers,
                     const uint32_t activation_,
                     const uint32_t output_activation_,
                     at::Tensor inference_buffer,
                     at::Tensor outputs) {
    CHECK_CUDA(inputs);
    CHECK_CONTIGUOUS(inputs);
    CHECK_IS_HALF(inputs);

    CHECK_CUDA(weights);
    CHECK_CONTIGUOUS(weights);
    CHECK_IS_HALF(weights);

    Activation activation = convert_activation(activation_);
    Activation output_activation = convert_activation(output_activation_);

    auto inputs_ptr = reinterpret_cast<__half *>(inputs.data_ptr());
    auto weights_ptr = reinterpret_cast<__half *>(weights.data_ptr());
    auto inference_buffer_ptr =
            reinterpret_cast<__half *>(inference_buffer.data_ptr());
    auto outputs_ptr = reinterpret_cast<__half *>(outputs.data_ptr());

    switch (hidden_dim) {
        case 16:
            ffmlp_forward_cuda<16, true>(inputs_ptr, weights_ptr, B,
                                         input_dim, output_dim, num_layers,
                                         activation, output_activation,
                                         inference_buffer_ptr, outputs_ptr);
            break;
        case 32:
            ffmlp_forward_cuda<32, true>(inputs_ptr, weights_ptr, B,
                                         input_dim, output_dim, num_layers,
                                         activation, output_activation,
                                         inference_buffer_ptr, outputs_ptr);
            break;
        case 64:
            ffmlp_forward_cuda<64, true>(inputs_ptr, weights_ptr, B,
                                         input_dim, output_dim, num_layers,
                                         activation, output_activation,
                                         inference_buffer_ptr, outputs_ptr);
            break;
        case 128:
            ffmlp_forward_cuda<128, true>(inputs_ptr, weights_ptr, B,
                                          input_dim, output_dim, num_layers,
                                          activation, output_activation,
                                          inference_buffer_ptr, outputs_ptr);
            break;
        case 256:
            ffmlp_forward_cuda<256, true>(inputs_ptr, weights_ptr, B,
                                          input_dim, output_dim, num_layers,
                                          activation, output_activation,
                                          inference_buffer_ptr, outputs_ptr);
            break;
        default:
            throw std::runtime_error{
                    "hidden_dim should in [16, 32, 64, 128, 256]"};
    }

    // for output_dim > 16 the last layer is computed here with fc_multiply
    if (output_dim > 16) {
        fc_multiply(
                0, output_dim, hidden_dim, B,
                (weights_ptr + hidden_dim * input_dim +
                 (num_layers - 1) * hidden_dim *
                         hidden_dim),  // row-major, [output_dim, hidden_dim]
                inference_buffer_ptr,  // col-major [hidden_dim, B]
                outputs_ptr,           // col-major [output_dim, B]
                output_activation);
    }
}

// Process-wide CUDA streams used by ffmlp_backward for the per-layer split-k
// weight-gradient GEMMs (one slot per layer, indexed by backward_index).
inline std::vector &streams_splitk() {
    static std::vector res;
    return res;
}

// Events paired 1:1 with streams_splitk(), used to order the split-k streams
// against the default stream.
inline std::vector &events_splitk() {
    static std::vector res;
    return res;
}

// Ensures at least `size` split-k streams/events exist. Existing handles are
// reused; only missing slots are created. (Previously every call re-created
// all slots, overwriting and thereby leaking the earlier handles.)
// NOTE(review): the handles are process-wide, so free_splitk() from one
// owner still invalidates them for any other owner.
void allocate_splitk(size_t size) {
    auto &streams = streams_splitk();
    auto &events = events_splitk();

    const size_t n_existing = streams.size();
    if (size <= n_existing) {
        return;
    }

    streams.resize(size);
    events.resize(size);

    for (size_t i = n_existing; i < size; i++) {
        CUDA_CHECK_THROW(cudaStreamCreate(&streams[i]));
        CUDA_CHECK_THROW(cudaEventCreate(&events[i]));
    }
}

// Releases the split-k streams/events (and the per-stream CUTLASS workspace).
// The vectors are cleared afterwards so a repeated call is a no-op instead of
// destroying the same handles twice, and a later ffmlp_backward() fails
// through .at() instead of using destroyed handles.
void free_splitk() {
    auto &streams = streams_splitk();
    auto &events = events_splitk();

    for (size_t i = 0; i < streams.size(); i++) {
        cutlass_free_workspace(streams[i]);
        CUDA_CHECK_PRINT(cudaStreamDestroy(streams[i]));
        CUDA_CHECK_PRINT(cudaEventDestroy(events[i]));
    }

    streams.clear();
    events.clear();
}

// grad: col-major [output_dim, B]
// inputs: col-major [input_dim, B]
// weights: row-major [hidden_dim * input_dim]
//          + [hidden_dim * hidden_dim * (num_layers - 1)]
//          + [output_dim * hidden_dim]
// forward_buffer: col-major [num_layers, hidden_dim, B]
// backward_buffer: col-major [num_layers, hidden_dim, B]
// grad_inputs: col-major [input_dim, B]
// grad_weights: row-major [hidden_dim * input_dim]
//               + [hidden_dim * hidden_dim * (num_layers - 1)]
//               + [output_dim * hidden_dim]
// Requires allocate_splitk() to have provided at least num_layers + 1
// streams/events (indexed via .at(), which throws std::out_of_range otherwise).
// Throws std::runtime_error when hidden_dim is not one of 16/32/64/128/256.
void ffmlp_backward(const at::Tensor grad,
                    const at::Tensor inputs,
                    const at::Tensor weights,
                    const at::Tensor forward_buffer,
                    const uint32_t B,
                    const uint32_t input_dim,
                    const uint32_t output_dim,
                    const uint32_t hidden_dim,
                    const uint32_t num_layers,
                    const uint32_t activation_,
                    const uint32_t output_activation_,
                    const bool calc_grad_inputs,
                    at::Tensor backward_buffer,
                    at::Tensor grad_inputs,
                    at::Tensor grad_weights) {
    CHECK_CUDA(grad);
    CHECK_CONTIGUOUS(grad);
    CHECK_IS_HALF(grad);

    CHECK_CUDA(inputs);
    CHECK_CONTIGUOUS(inputs);
    CHECK_IS_HALF(inputs);

    CHECK_CUDA(weights);
    CHECK_CONTIGUOUS(weights);
    CHECK_IS_HALF(weights);

    CHECK_CUDA(forward_buffer);
    CHECK_CONTIGUOUS(forward_buffer);
    CHECK_IS_HALF(forward_buffer);

    CHECK_CUDA(backward_buffer);
    CHECK_CONTIGUOUS(backward_buffer);
    CHECK_IS_HALF(backward_buffer);

    CHECK_CUDA(grad_weights);
    CHECK_CONTIGUOUS(grad_weights);
    CHECK_IS_HALF(grad_weights);

    CHECK_CUDA(grad_inputs);
    CHECK_CONTIGUOUS(grad_inputs);
    CHECK_IS_HALF(grad_inputs);

    Activation activation = convert_activation(activation_);
    // Converted (and thereby validated) but not applied in this function.
    // NOTE(review): presumably the caller folds the output activation's
    // backward into `grad` beforehand -- confirm.
    Activation output_activation = convert_activation(output_activation_);

    // activation_backward_output_gpu (I gonna discard output_activation ...)

    int split_k_factor = B / std::min((uint32_t)(1 << 12), B);

    uint32_t forward_index = num_layers - 1;
    uint32_t backward_index = 0;

    auto backward_buffer_ptr =
            reinterpret_cast<__half *>(backward_buffer.data_ptr());
    auto forward_buffer_ptr =
            reinterpret_cast<__half *>(forward_buffer.data_ptr());
    auto grad_ptr = reinterpret_cast<__half *>(grad.data_ptr());
    auto inputs_ptr = reinterpret_cast<__half *>(inputs.data_ptr());
    auto weights_ptr = reinterpret_cast<__half *>(weights.data_ptr());
    auto grad_weights_ptr =
            reinterpret_cast<__half *>(grad_weights.data_ptr());
    auto grad_inputs_ptr = calc_grad_inputs ? reinterpret_cast<__half *>(
                                                      grad_inputs.data_ptr())
                                            : nullptr;
    // The fused kernel can only produce input gradients when
    // input_dim == hidden_dim; otherwise they are computed at the end.
    auto grad_inputs_fused_ptr =
            input_dim == hidden_dim ? grad_inputs_ptr : nullptr;

    // calc output layer, grad_weights
    cudaEventRecord(events_splitk().at(backward_index), 0);
    cudaStreamWaitEvent(streams_splitk().at(backward_index),
                        events_splitk().at(backward_index), 0);
    fc_multiply_split_k(
            streams_splitk().at(backward_index), output_dim, B, hidden_dim,
            grad_ptr,  // col-major, [output_dim, B]
            (forward_buffer_ptr +
             forward_index * hidden_dim * B),  // row-major, [B, hidden_dim]
            (grad_weights_ptr + hidden_dim * input_dim +
             (num_layers - 1) * hidden_dim *
                     hidden_dim),  // row-major, [output_dim, hidden_dim]
            split_k_factor);
    cudaEventRecord(events_splitk().at(backward_index),
                    streams_splitk().at(backward_index));

    // prepare the last backward_buffer if output_dim > 16
    if (output_dim > 16) {
        // Backprop through the last layer needs the last-layer WEIGHTS
        // (W^T * grad), exactly like the grad_inputs step below uses
        // weights_ptr. Reading grad_weights_ptr here was wrong and also raced
        // with the split-k GEMM above, which writes that buffer on a
        // different stream.
        fc_multiply(
                0, hidden_dim, output_dim, B,
                (weights_ptr + hidden_dim * input_dim +
                 (num_layers - 1) * hidden_dim *
                         hidden_dim),  // col-major, [hidden_dim, output_dim]
                grad_ptr,              // col-major, [output_dim, B]
                (forward_buffer_ptr +
                 forward_index * hidden_dim * B),  // col-major, [hidden_dim, B]
                (backward_buffer_ptr +
                 backward_index * hidden_dim * B),  // col-major [hidden_dim, B]
                activation, true);
    }

    // prepare backward_buffer
    // calc grad_inputs if input_dim == hidden_dim
    switch (hidden_dim) {
        case 16:
            ffmlp_backward_cuda<16>(grad_ptr, weights_ptr, B, input_dim,
                                    output_dim, num_layers, activation,
                                    forward_buffer_ptr, backward_buffer_ptr,
                                    grad_inputs_fused_ptr);
            break;
        case 32:
            ffmlp_backward_cuda<32>(grad_ptr, weights_ptr, B, input_dim,
                                    output_dim, num_layers, activation,
                                    forward_buffer_ptr, backward_buffer_ptr,
                                    grad_inputs_fused_ptr);
            break;
        case 64:
            ffmlp_backward_cuda<64>(grad_ptr, weights_ptr, B, input_dim,
                                    output_dim, num_layers, activation,
                                    forward_buffer_ptr, backward_buffer_ptr,
                                    grad_inputs_fused_ptr);
            break;
        case 128:
            ffmlp_backward_cuda<128>(grad_ptr, weights_ptr, B, input_dim,
                                     output_dim, num_layers, activation,
                                     forward_buffer_ptr, backward_buffer_ptr,
                                     grad_inputs_fused_ptr);
            break;
        case 256:
            ffmlp_backward_cuda<256>(grad_ptr, weights_ptr, B, input_dim,
                                     output_dim, num_layers, activation,
                                     forward_buffer_ptr, backward_buffer_ptr,
                                     grad_inputs_fused_ptr);
            break;
        default:
            throw std::runtime_error{
                    "hidden_dim should in [16, 32, 64, 128, 256]"};
    }

    // printf("[backward] finished backward kernel\n");

    forward_index--;
    backward_index++;

    // calc middle layer's grad_weights
    for (uint32_t i = 0; i < num_layers - 1; i++) {
        uint32_t matrix_index = num_layers - 2 - i;

        cudaEventRecord(events_splitk().at(backward_index), 0);
        cudaStreamWaitEvent(streams_splitk().at(backward_index),
                            events_splitk().at(backward_index), 0);
        fc_multiply_split_k(
                streams_splitk().at(backward_index), hidden_dim, B,
                hidden_dim,
                (backward_buffer_ptr + (backward_index - 1) * hidden_dim *
                                               B),  // col-major [hidden_dim, B]
                (forward_buffer_ptr +
                 forward_index * hidden_dim * B),  // row-major [B, hidden_dim]
                (grad_weights_ptr + hidden_dim * input_dim +
                 matrix_index * hidden_dim *
                         hidden_dim),  // row-major, [hidden_dim, hidden_dim]
                split_k_factor);
        cudaEventRecord(events_splitk().at(backward_index),
                        streams_splitk().at(backward_index));

        forward_index--;
        backward_index++;
    }

    // calc input layer's grad_weights
    cudaEventRecord(events_splitk().at(backward_index), 0);
    cudaStreamWaitEvent(streams_splitk().at(backward_index),
                        events_splitk().at(backward_index), 0);
    fc_multiply_split_k(
            streams_splitk().at(backward_index), hidden_dim, B, input_dim,
            (backward_buffer_ptr + (backward_index - 1) * hidden_dim *
                                           B),  // col-major [hidden_dim, B]
            inputs_ptr,        // row-major, [B, input_dim]
            grad_weights_ptr,  // row-major, [hidden_dim, input_dim]
            split_k_factor);
    cudaEventRecord(events_splitk().at(backward_index),
                    streams_splitk().at(backward_index));

    // calc grad_inputs if input_dim != hidden_dim
    if (calc_grad_inputs && grad_inputs_fused_ptr == nullptr) {
        fc_multiply(
                0, input_dim, hidden_dim, B,
                weights_ptr,  // col-major [input_dim, hidden_dim]
                (backward_buffer_ptr + (backward_index - 1) * hidden_dim *
                                               B),  // col-major [hidden_dim, B]
                grad_inputs_ptr  // col-major [input_dim, B]
        );
    }

    // All the per-layer split-k matrix multiplications summing over
    // the batch are computed in parallel streams to the actual
    // backpropagation. Here, we need to wait for all of these to complete.
    for (auto &event : events_splitk()) {
        cudaStreamWaitEvent(0, event, 0);
    }
}

================================================
FILE: lidarnerf/ffmlp/src/ffmlp.h
================================================
#pragma once

#include
#include

// activation: should have been enum, here we just use int.
void ffmlp_forward(const at::Tensor inputs, const at::Tensor weights, const uint32_t B, const uint32_t input_dim, const uint32_t output_dim, const uint32_t hidden_dim, const uint32_t num_layers, const uint32_t activation_, const uint32_t output_activation_, at::Tensor forward_buffer, at::Tensor outputs); void ffmlp_inference(const at::Tensor inputs, const at::Tensor weights, const uint32_t B, const uint32_t input_dim, const uint32_t output_dim, const uint32_t hidden_dim, const uint32_t num_layers, const uint32_t activation_, const uint32_t output_activation_, at::Tensor inference_buffer, at::Tensor outputs); void ffmlp_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor weights, const at::Tensor forward_buffer, const uint32_t B, const uint32_t input_dim, const uint32_t output_dim, const uint32_t hidden_dim, const uint32_t num_layers, const uint32_t activation, const uint32_t output_activation, const bool calc_grad_inputs, at::Tensor backward_buffer, at::Tensor grad_inputs, at::Tensor grad_weights); void allocate_splitk(size_t size); void free_splitk(); ================================================ FILE: lidarnerf/ffmlp/src/utils.h ================================================ #pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #define CHECK_CUDA(x) \ TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) \ TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") #define CHECK_IS_INT(x) \ TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ #x " must be an int tensor") #define CHECK_IS_FLOATING(x) \ TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || \ x.scalar_type() == at::ScalarType::Half || \ x.scalar_type() == at::ScalarType::Double, \ #x " must be a floating tensor") #define CHECK_IS_HALF(x) \ TORCH_CHECK(x.scalar_type() == at::ScalarType::Half, \ #x " must be a Half tensor") static constexpr uint32_t 
MIN_GPU_ARCH = 70; using network_precision_t = __half; enum class Activation { ReLU, Exponential, Sine, Sigmoid, Squareplus, Softplus, None, }; static constexpr float PI = 3.14159265358979323846f; static constexpr float SQRT2 = 1.41421356237309504880f; static constexpr float K_ACT = 10.0f; __host__ __device__ inline float logistic(const float x) { return 1.0f / (1.0f + expf(-x)); } __host__ __device__ inline float logit(const float x) { return -logf(1.0f / (fminf(fmaxf(x, 1e-9f), 1.0f - 1e-9f)) - 1.0f); } inline std::atomic &total_n_bytes_allocated() { static std::atomic s_total_n_bytes_allocated{0}; return s_total_n_bytes_allocated; } /// Checks the result of a cudaXXXXXX call and throws an error on failure #define CUDA_CHECK_THROW(x) \ do { \ cudaError_t result = x; \ if (result != cudaSuccess) \ throw std::runtime_error( \ std::string("CUDA Error: " #x " failed with error ") + \ cudaGetErrorString(result)); \ } while (0) /// Checks the result of a cudaXXXXXX call and prints an error on failure #define CUDA_CHECK_PRINT(x) \ do { \ cudaError_t result = x; \ if (result != cudaSuccess) \ std::cout << "CUDA Error: " #x " failed with error " \ << cudaGetErrorString(result) << std::endl; \ } while (0) #define DEBUG_GUARD_SIZE 0 /// Managed memory on the Device template class GPUMemory { private: T *m_data = nullptr; size_t m_size = 0; // Number of elements bool m_owned = true; public: GPUMemory() {} GPUMemory &operator=(GPUMemory &&other) { std::swap(m_data, other.m_data); std::swap(m_size, other.m_size); return *this; } GPUMemory(GPUMemory &&other) { *this = std::move(other); } __host__ __device__ GPUMemory(const GPUMemory &other) : m_data{other.m_data}, m_size{other.m_size}, m_owned{false} {} void check_guards() const { #if DEBUG_GUARD_SIZE > 0 if (!m_data) return; uint8_t buf[DEBUG_GUARD_SIZE]; const uint8_t *rawptr = (const uint8_t *)m_data; cudaMemcpy(buf, rawptr - DEBUG_GUARD_SIZE, DEBUG_GUARD_SIZE, cudaMemcpyDeviceToHost); for (int i = 0; i < DEBUG_GUARD_SIZE; 
++i) if (buf[i] != 0xff) { printf("TRASH BEFORE BLOCK offset %d data %p, read 0x%02x " "expected " "0xff!\n", i, m_data, buf[i]); break; } cudaMemcpy(buf, rawptr + m_size * sizeof(T), DEBUG_GUARD_SIZE, cudaMemcpyDeviceToHost); for (int i = 0; i < DEBUG_GUARD_SIZE; ++i) if (buf[i] != 0xfe) { printf("TRASH AFTER BLOCK offset %d data %p, read 0x%02x " "expected 0xfe!\n", i, m_data, buf[i]); break; } #endif } void allocate_memory(size_t n_bytes) { if (n_bytes == 0) { return; } #ifdef TCNN_VERBOSE_MEMORY_ALLOCS std::cout << "GPUMemory: Allocating " << bytes_to_string(n_bytes) << "." << std::endl; #endif uint8_t *rawptr = nullptr; CUDA_CHECK_THROW(cudaMalloc(&rawptr, n_bytes + DEBUG_GUARD_SIZE * 2)); #if DEBUG_GUARD_SIZE > 0 CUDA_CHECK_THROW(cudaMemset(rawptr, 0xff, DEBUG_GUARD_SIZE)); CUDA_CHECK_THROW(cudaMemset(rawptr + n_bytes + DEBUG_GUARD_SIZE, 0xfe, DEBUG_GUARD_SIZE)); #endif if (rawptr) rawptr += DEBUG_GUARD_SIZE; m_data = (T *)(rawptr); total_n_bytes_allocated() += n_bytes; } void free_memory() { if (!m_data) { return; } uint8_t *rawptr = (uint8_t *)m_data; if (rawptr) rawptr -= DEBUG_GUARD_SIZE; CUDA_CHECK_THROW(cudaFree(rawptr)); total_n_bytes_allocated() -= get_bytes(); m_data = nullptr; } /// Allocates memory for size items of type T GPUMemory(const size_t size) { resize(size); } /// Frees memory again __host__ __device__ ~GPUMemory() { #ifndef __CUDA_ARCH__ if (!m_owned) { return; } try { if (m_data) { free_memory(); m_size = 0; } } catch (std::runtime_error error) { // Don't need to report on memory-free problems when the driver is // shutting down. 
if (std::string{error.what()}.find("driver shutting down") == std::string::npos) { fprintf(stderr, "Could not free memory: %s\n", error.what()); } } #endif } /** @name Resizing/enlargement * @{ */ /// Resizes the array to the exact new size, even if it is already larger void resize(const size_t size) { if (!m_owned) { throw std::runtime_error("Cannot resize non-owned memory."); } if (m_size != size) { if (m_size) { try { free_memory(); } catch (std::runtime_error error) { throw std::runtime_error( std::string("Could not free memory: ") + error.what()); } } if (size > 0) { try { allocate_memory(size * sizeof(T)); } catch (std::runtime_error error) { throw std::runtime_error( std::string("Could not allocate memory: ") + error.what()); } } m_size = size; } } /// Enlarges the array if its size is smaller void enlarge(const size_t size) { if (size > m_size) { resize(size); } } /** @} */ /** @name Memset * @{ */ /// Sets the memory of the first num_elements to value void memset(const int value, const size_t num_elements, const size_t offset = 0) { if (num_elements + offset > m_size) { throw std::runtime_error( "Could not set memory: Number of elements " "larger than allocated memory"); } try { CUDA_CHECK_THROW(cudaMemset(m_data + offset, value, num_elements * sizeof(T))); } catch (std::runtime_error error) { throw std::runtime_error(std::string("Could not set memory: ") + error.what()); } } /// Sets the memory of the all elements to value void memset(const int value) { memset(value, m_size); } /** @} */ /** @name Copy operations * @{ */ /// Copy data of num_elements from the raw pointer on the host void copy_from_host(const T *host_data, const size_t num_elements) { try { CUDA_CHECK_THROW(cudaMemcpy(data(), host_data, num_elements * sizeof(T), cudaMemcpyHostToDevice)); } catch (std::runtime_error error) { throw std::runtime_error(std::string("Could not copy from host: ") + error.what()); } } /// Copy num_elements from the host vector void copy_from_host(const std::vector 
&data, const size_t num_elements) { if (data.size() < num_elements) { throw std::runtime_error( std::string("Trying to copy ") + std::to_string(num_elements) + std::string(" elements, but vector size is only ") + std::to_string(data.size())); } copy_from_host(data.data(), num_elements); } /// Copies data from the raw host pointer to fill the entire array void copy_from_host(const T *data) { copy_from_host(data, m_size); } /// Copies num_elements of data from the raw host pointer after enlarging /// the array so that everything fits in void enlarge_and_copy_from_host(const T *data, const size_t num_elements) { enlarge(num_elements); copy_from_host(data, num_elements); } /// Copies num_elements from the host vector after enlarging the array so /// that everything fits in void enlarge_and_copy_from_host(const std::vector &data, const size_t num_elements) { enlarge_and_copy_from_host(data.data(), num_elements); } /// Copies the entire host vector after enlarging the array so that /// everything fits in void enlarge_and_copy_from_host(const std::vector &data) { enlarge_and_copy_from_host(data.data(), data.size()); } /// Copies num_elements of data from the raw host pointer after resizing the /// array void resize_and_copy_from_host(const T *data, const size_t num_elements) { resize(num_elements); copy_from_host(data, num_elements); } /// Copies num_elements from the host vector after resizing the array void resize_and_copy_from_host(const std::vector &data, const size_t num_elements) { resize_and_copy_from_host(data.data(), num_elements); } /// Copies the entire host vector after resizing the array void resize_and_copy_from_host(const std::vector &data) { resize_and_copy_from_host(data.data(), data.size()); } /// Copies the entire host vector to the device. Fails if there is not /// enough space available. 
void copy_from_host(const std::vector &data) {
    // Only the first m_size elements are copied; a longer vector is truncated.
    if (data.size() < m_size) {
        throw std::runtime_error(
                std::string("Trying to copy ") + std::to_string(m_size) +
                std::string(" elements, but vector size is only ") +
                std::to_string(data.size()));
    }
    copy_from_host(data.data(), m_size);
}

/// Copies num_elements elements from the device to the raw host pointer.
/// Fails if num_elements exceeds the number of allocated device elements.
void copy_to_host(T *host_data, const size_t num_elements) const {
    if (num_elements > m_size) {
        throw std::runtime_error(
                std::string("Trying to copy ") + std::to_string(num_elements) +
                std::string(" elements, but memory size is only ") +
                std::to_string(m_size));
    }
    try {
        CUDA_CHECK_THROW(cudaMemcpy(host_data, data(),
                                    num_elements * sizeof(T),
                                    cudaMemcpyDeviceToHost));
    } catch (const std::runtime_error &error) {
        throw std::runtime_error(std::string("Could not copy to host: ") +
                                 error.what());
    }
}

/// Copies num_elements elements from the device into a host vector. Fails if
/// the vector holds fewer than num_elements elements.
void copy_to_host(std::vector &data, const size_t num_elements) const {
    if (data.size() < num_elements) {
        throw std::runtime_error(
                std::string("Trying to copy ") + std::to_string(num_elements) +
                std::string(" elements, but vector size is only ") +
                std::to_string(data.size()));
    }
    copy_to_host(data.data(), num_elements);
}

/// Copies all m_size elements from the device to a raw host pointer; the
/// caller must provide room for at least m_size elements.
void copy_to_host(T *data) const { copy_to_host(data, m_size); }

/// Copies all elements from the device to a vector on the host. Fails if the
/// vector holds fewer than m_size elements.
void copy_to_host(std::vector &data) const {
    if (data.size() < m_size) {
        throw std::runtime_error(
                std::string("Trying to copy ") + std::to_string(m_size) +
                std::string(" elements, but vector size is only ") +
                std::to_string(data.size()));
    }
    copy_to_host(data.data(), m_size);
}

/// Copies data from another device array to this one, automatically
/// resizing this array to other's size.
void copy_from_device(const GPUMemory &other) {
    if (m_size != other.m_size) {
        resize(other.m_size);
    }

    try {
        CUDA_CHECK_THROW(cudaMemcpy(m_data, other.m_data, m_size * sizeof(T),
                                    cudaMemcpyDeviceToDevice));
    } catch (const std::runtime_error &error) {
        throw std::runtime_error(
                std::string("Could not copy from device: ") + error.what());
    }
}

/// Copies `size` elements from another device array to this one,
/// enlarging this array if needed. Fails if `other` holds fewer than
/// `size` elements.
void copy_from_device(const GPUMemory &other, const size_t size) {
    // Guard the source side too: otherwise cudaMemcpy reads past other's
    // allocation.
    if (other.m_size < size) {
        throw std::runtime_error(
                std::string("Trying to copy ") + std::to_string(size) +
                std::string(" elements, but source memory size is only ") +
                std::to_string(other.m_size));
    }
    if (m_size < size) {
        resize(size);
    }

    try {
        CUDA_CHECK_THROW(cudaMemcpy(m_data, other.m_data, size * sizeof(T),
                                    cudaMemcpyDeviceToDevice));
    } catch (const std::runtime_error &error) {
        throw std::runtime_error(
                std::string("Could not copy from device: ") + error.what());
    }
}

// Creates an (owned) copy of the data
GPUMemory copy() const {
    GPUMemory result{m_size};
    result.copy_from_device(*this);
    return result;
}

/// Raw device pointer (runs check_guards() first).
T *data() const {
    check_guards();
    return m_data;
}

/// Element access without bounds checking unless DEBUG_BUFFER_OVERRUN is set.
__host__ __device__ T &operator[](size_t idx) const {
#ifdef DEBUG_BUFFER_OVERRUN
    // Valid indices are [0, m_size); print the buffer address and the index
    // (the format string expects both).
    if (idx >= m_size) {
        printf("WARNING: buffer overrun of %p at idx %zu\n", (void *)m_data,
               idx);
    }
#endif
    return m_data[idx];
}

__host__ __device__ T &operator[](uint32_t idx) const {
#ifdef DEBUG_BUFFER_OVERRUN
    if (idx >= m_size) {
        printf("WARNING: buffer overrun of %p at idx %u\n", (void *)m_data,
               idx);
    }
#endif
    return m_data[idx];
}

size_t get_num_elements() const { return m_size; }

size_t size() const { return get_num_elements(); }

size_t get_bytes() const { return m_size * sizeof(T); }

size_t bytes() const { return get_bytes(); }
};

/// Formats a byte count with a binary (1024-based) unit suffix, 3
/// significant digits.
inline std::string bytes_to_string(size_t bytes) {
    std::array suffixes = {
            {"B", "KB", "MB", "GB", "TB", "PB", "EB"}};

    double count = (double)bytes;
    uint32_t i = 0;
    for (; i < suffixes.size() && count >= 1024; ++i) {
        count /= 1024;
    }

    std::ostringstream oss;
    oss.precision(3);
    oss << count << " " << suffixes[i];
    return oss.str();
}

// NOTE(review): template parameter lists (e.g. for T, fragment_t,
// forward_fragment_t) appear to have been stripped from this dump.

/// Applies `activation` element-wise to a fragment: result = act(frag).
/// Squareplus/Softplus scale the input by K_ACT and the output by 1/K_ACT.
template __host__ __device__ void warp_activation(Activation activation,
                                                 const fragment_t &frag,
                                                 fragment_t &result) {
    switch (activation) {
        case Activation::ReLU:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                result.x[t] = frag.x[t] * (T)((T)frag.x[t] > (T)0.0f);
            }
            return;
        case Activation::Exponential:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                result.x[t] = (T)(expf((float)frag.x[t]));
            }
            return;
        case Activation::Sine:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                result.x[t] = (T)(sinf((float)frag.x[t]));
            }
            return;
        case Activation::Sigmoid:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                result.x[t] = (T)(logistic((float)frag.x[t]));
            }
            return;
        case Activation::Squareplus:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                float x = (float)frag.x[t] * K_ACT;
                result.x[t] = (T)(0.5f * (x + sqrtf(x * x + 4)) / K_ACT);
            }
            return;
        case Activation::Softplus:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                result.x[t] =
                        (T)(logf(expf((float)frag.x[t] * K_ACT) + 1.0f) / K_ACT);
            }
            return;
        case Activation::None:
            result = frag;
            return;
        default:
            // Unsupported activation
            // assert(false); // Commented out due to isolated strange
            // side-effects on Windows
            return;
    }
}

/// Convenience overload returning the activated fragment by value.
template __host__ __device__ fragment_t
warp_activation(Activation activation, const fragment_t &frag) {
    fragment_t result;
    warp_activation(activation, frag, result);
    return result;
}

/// Backward pass given the activation's INPUT (pre-activation) values:
/// result = frag * act'(forward_frag_in), where frag is the incoming gradient.
template __host__ __device__ void warp_activation_backward_in(
        Activation activation,
        const fragment_t &frag,
        const forward_fragment_t &forward_frag_in,
        fragment_t &result) {
    switch (activation) {
        case Activation::ReLU:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                result.x[t] = frag.x[t] * (T)(forward_frag_in.x[t] > (T)0.0f);
            }
            return;
        case Activation::Exponential:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                result.x[t] = frag.x[t] * (T)(expf(forward_frag_in.x[t]));
            }
            return;
        case Activation::Sine:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                result.x[t] = frag.x[t] * (T)(cosf(forward_frag_in.x[t]));
            }
            return;
        case Activation::Sigmoid:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                float x = logistic(forward_frag_in.x[t]);
                result.x[t] = frag.x[t] * (T)(x * (1.0f - x));
            }
            return;
        case Activation::Squareplus:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                // y is the scaled squareplus value; dy/dx = y^2 / (y^2 + 1).
                float x = (float)forward_frag_in.x[t] * K_ACT;
                float y = 0.5f * (x + sqrtf(x * x + 4));
                result.x[t] = frag.x[t] * (T)(y * y / (y * y + 1));
            }
            return;
        case Activation::Softplus:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                // d/dx [log(1 + exp(K x)) / K] = exp(K x) / (exp(K x) + 1),
                // evaluated at the pre-activation, not at the incoming grad.
                float tmp = expf((float)forward_frag_in.x[t] * K_ACT);
                result.x[t] = frag.x[t] * (T)(tmp / (tmp + 1));
            }
            return;
        case Activation::None:
            result = frag;
            return;
        default:
            // Unsupported activation
            // assert(false); // Commented out due to isolated strange
            // side-effects on Windows
            return;
    }
}

/// Convenience overload returning the input-gradient fragment by value.
template __host__ __device__ fragment_t
warp_activation_backward_in(Activation activation,
                            const fragment_t &frag,
                            const forward_fragment_t &forward_frag_in) {
    fragment_t result;
    warp_activation_backward_in(activation, frag, forward_frag_in, result);
    return result;
}

/// Backward pass given the activation's OUTPUT (post-activation) values
/// stored in forward_frag: result = frag * act'(x) expressed via act(x).
template __host__ __device__ void warp_activation_backward(
        Activation activation,
        const fragment_t &frag,
        const forward_fragment_t &forward_frag,
        fragment_t &result) {
    switch (activation) {
        case Activation::ReLU:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                result.x[t] = frag.x[t] * (T)(forward_frag.x[t] > (T)0.0f);
            }
            return;
        case Activation::Exponential:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                result.x[t] = frag.x[t] * forward_frag.x[t];
            }
            return;
        case Activation::Sine:
            // Sine requires stored pre-activations, which we don't have. We
            // only write out the post-activations.
            // NOTE(review): the comment below says this assert is commented
            // out, but here it is active code; confirm which is intended.
            assert(false);  // Commented
            // out due to isolated strange side-effects on Windows
            return;
        case Activation::Sigmoid:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                result.x[t] =
                        frag.x[t] *
                        (T)(forward_frag.x[t] * ((T)1.0f - forward_frag.x[t]));
            }
            return;
        case Activation::Squareplus:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                // The forward pass divided by K_ACT; undo it to recover y.
                float y = (float)forward_frag.x[t] * K_ACT;
                result.x[t] = frag.x[t] * (T)(y * y / (y * y + 1));
            }
            return;
        case Activation::Softplus:
#pragma unroll
            for (int t = 0; t < result.num_elements; t++) {
                result.x[t] =
                        frag.x[t] *
                        (T)(1.0f - expf(-(float)forward_frag.x[t] * K_ACT));
            }
            return;
        case Activation::None:
            result = frag;
            return;
        default:
            // Unsupported activation
            // assert(false); // Commented out due to isolated strange
            // side-effects on Windows
            return;
    }
}

/// Convenience overload returning the gradient fragment by value.
template __host__ __device__ fragment_t
warp_activation_backward(Activation activation,
                         const fragment_t &frag,
                         const forward_fragment_t &forward_frag) {
    fragment_t result;
    warp_activation_backward(activation, frag, forward_frag, result);
    return result;
}

================================================
FILE: lidarnerf/freqencoder/__init__.py
================================================

================================================
FILE: lidarnerf/freqencoder/backend.py
================================================
"""JIT-compile and load the `_freqencoder` CUDA extension."""
import os

from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    "-O3",
    "-std=c++14",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
    "-use_fast_math",
]

if os.name == "posix":
    c_flags = ["-O3", "-std=c++14"]
elif os.name == "nt":
    c_flags = ["/O2", "/std:c++17"]

    # find cl.exe
    def find_cl_path():
        """Return the lexicographically greatest MSVC Hostx64/x64 bin dir, or None."""
        import glob

        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(
                glob.glob(
                    r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64"
                    % edition
                ),
                reverse=True,
            )
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError(
                "Could not locate a supported Microsoft Visual C++ installation"
            )
        os.environ["PATH"] += ";" + cl_path

_backend = load(
    name="_freqencoder",
    extra_cflags=c_flags,
    extra_cuda_cflags=nvcc_flags,
    sources=[
        os.path.join(_src_path, "src", f)
        for f in [
            "freqencoder.cu",
            "bindings.cpp",
        ]
    ],
)

__all__ = ["_backend"]

================================================
FILE: lidarnerf/freqencoder/freq.py
================================================
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _freqencoder as _backend
except ImportError:
    from .backend import _backend


class _freq_encoder(Function):
    """Autograd wrapper around the `_freqencoder` forward/backward kernels."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # force float32 for better precision
    def forward(ctx, inputs, degree, output_dim):
        # inputs: [B, input_dim], float
        # RETURN: [B, F], float

        if not inputs.is_cuda:
            inputs = inputs.cuda()
        inputs = inputs.contiguous()

        B, input_dim = inputs.shape  # batch size, coord dim

        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)

        _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)

        # The backward kernel reuses the forward sin/cos outputs.
        ctx.save_for_backward(inputs, outputs)
        ctx.dims = [B, input_dim, degree, output_dim]

        return outputs

    @staticmethod
    # @once_differentiable
    @custom_bwd
    def backward(ctx, grad):
        # grad: [B, output_dim] (same shape as the forward outputs)

        grad = grad.contiguous()
        inputs, outputs = ctx.saved_tensors
        B, input_dim, degree, output_dim = ctx.dims

        grad_inputs = torch.zeros_like(inputs)
        _backend.freq_encode_backward(
            grad, outputs, B, input_dim, degree, output_dim, grad_inputs
        )

        # degree and output_dim are not differentiable.
        return grad_inputs, None, None


freq_encode = _freq_encoder.apply


class FreqEncoder(nn.Module):
    """Frequency (sin/cos) positional encoding computed by a CUDA kernel.

    Output width is ``input_dim + input_dim * 2 * degree``: the raw input
    followed by sin/cos features for ``degree`` frequencies.
    """

    def __init__(self, input_dim=3, degree=4):
        super().__init__()

        self.input_dim = input_dim
        self.degree = degree
        self.output_dim = input_dim + input_dim * 2 * degree

    def __repr__(self):
        return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}"

    def forward(self, inputs, **kwargs):
        # inputs: [..., input_dim]
        # return: [..., output_dim]

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.reshape(-1, self.input_dim)

        outputs = freq_encode(inputs, self.degree, self.output_dim)

        outputs = outputs.reshape(prefix_shape + [self.output_dim])

        return outputs

================================================
FILE: lidarnerf/freqencoder/setup.py
================================================
"""setuptools build script for the `_freqencoder` CUDA extension."""
import os

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    "-O3",
    "-std=c++14",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
    "-use_fast_math",
]

if os.name == "posix":
    c_flags = ["-O3", "-std=c++14"]
elif os.name == "nt":
    c_flags = ["/O2", "/std:c++17"]

    # find cl.exe
    def find_cl_path():
        """Return the lexicographically greatest MSVC Hostx64/x64 bin dir, or None."""
        import glob

        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(
                glob.glob(
                    r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64"
                    % edition
                ),
                reverse=True,
            )
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError(
                "Could not locate a supported Microsoft Visual C++ installation"
            )
        os.environ["PATH"] += ";" + cl_path

setup(
    name="freqencoder",  # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name="_freqencoder",  # extension name, import this to use CUDA API
            sources=[
                os.path.join(_src_path, "src", f)
                for f in [
                    "freqencoder.cu",
                    "bindings.cpp",
                ]
            ],
            extra_compile_args={
                "cxx": c_flags,
                "nvcc": nvcc_flags,
            },
        ),
    ],
    cmdclass={
        "build_ext": BuildExtension,
    },
)

================================================
FILE: lidarnerf/freqencoder/src/bindings.cpp
================================================
#include

#include "freqencoder.h"

// Python bindings for the frequency-encoder forward/backward entry points.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("freq_encode_forward", &freq_encode_forward,
          "freq encode forward (CUDA)");
    m.def("freq_encode_backward", &freq_encode_backward,
          "freq encode backward (CUDA)");
}

================================================
FILE: lidarnerf/freqencoder/src/freqencoder.cu
================================================
#include
#include
#include
#include
#include
#include
#include
#include
#include

#define CHECK_CUDA(x) \
    TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
    TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x)                                     \
    TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
                #x " must be an int tensor")
#define CHECK_IS_FLOATING(x)                                      \
    TORCH_CHECK(x.scalar_type() == at::ScalarType::Float ||   \
                        x.scalar_type() == at::ScalarType::Half ||  \
                        x.scalar_type() == at::ScalarType::Double, \
                #x " must be a floating tensor")

inline constexpr __device__ float PI() { return 3.141592653589793f; }

template __host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}

// One thread per output element.
// inputs: [B, D]
// outputs: [B, C], C = D + D * deg * 2
// Column layout: [x (D) | sin(2^0 x) (D) | cos(2^0 x) (D) | sin(2^1 x) ...];
// cos is computed as sin shifted by pi/2.
__global__ void kernel_freq(const float *__restrict__ inputs,
                            uint32_t B,
                            uint32_t D,
                            uint32_t deg,
                            uint32_t C,
                            float *outputs) {
    // parallel on per-element
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * C) return;

    // get index
    const uint32_t b = t / C;
    const uint32_t c = t - b * C;  // t % C;

    // locate
    inputs += b * D;
    outputs += t;

    // write self
    if (c < D) {
        outputs[0] = inputs[c];

        // write freq
    } else {
        const uint32_t col = c / D - 1;
        const uint32_t d = c % D;
        const uint32_t freq = col / 2;
        const float phase_shift = (col % 2) * (PI() / 2);
        // scalbnf(x, freq) == x * 2^freq
        outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);
    }
}

// One thread per input element.
// grad: [B, C], C = D + D * deg * 2
// outputs: [B, C]
// grad_inputs: [B, D]
__global__ void kernel_freq_backward(const float *__restrict__ grad,
                                     const float *__restrict__ outputs,
                                     uint32_t B,
                                     uint32_t D,
                                     uint32_t deg,
                                     uint32_t C,
                                     float *grad_inputs) {
    // parallel on per-element
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * D) return;

    const uint32_t b = t / D;
    const uint32_t d = t - b * D;  // t % D;

    // locate
    grad += b * C;
    outputs += b * C;
    grad_inputs += t;

    // register
    // Identity part contributes the gradient unchanged.
    float result = grad[d];
    grad += D;
    outputs += D;

    // For frequency f: d sin(2^f x)/dx = 2^f cos(2^f x) and
    // d cos(2^f x)/dx = -2^f sin(2^f x), using the stored forward outputs.
    for (uint32_t f = 0; f < deg; f++) {
        result += scalbnf(1.0f, f) *
                  (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);
        grad += 2 * D;
        outputs += 2 * D;
    }

    // write
    grad_inputs[0] = result;
}

void freq_encode_forward(at::Tensor inputs,
                         const uint32_t B,
                         const uint32_t D,
                         const uint32_t deg,
                         const uint32_t C,
                         at::Tensor outputs) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(outputs);

    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(outputs);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(outputs);

    static constexpr uint32_t N_THREADS = 128;

    kernel_freq<<>>(
            inputs.data_ptr(), B, D, deg, C, outputs.data_ptr());
}

void freq_encode_backward(at::Tensor grad,
                          at::Tensor outputs,
                          const uint32_t B,
                          const uint32_t D,
                          const uint32_t deg,
                          const uint32_t C,
                          at::Tensor grad_inputs) {
    CHECK_CUDA(grad);
    CHECK_CUDA(outputs);
    CHECK_CUDA(grad_inputs);

    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(outputs);
    CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(outputs);
    CHECK_IS_FLOATING(grad_inputs);

    static constexpr uint32_t N_THREADS = 128;

    kernel_freq_backward<<>>(
            grad.data_ptr(), outputs.data_ptr(), B, D, deg, C,
            grad_inputs.data_ptr());
}

================================================
FILE: lidarnerf/freqencoder/src/freqencoder.h
================================================
#pragma once

#include
#include

// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim,
// outputs)
void freq_encode_forward(at::Tensor inputs,
                         const uint32_t B,
                         const uint32_t D,
                         const uint32_t deg,
                         const uint32_t C,
                         at::Tensor outputs);

// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree,
// output_dim, grad_inputs)
void freq_encode_backward(at::Tensor grad,
                          at::Tensor outputs,
                          const uint32_t B,
                          const uint32_t D,
                          const uint32_t deg,
                          const uint32_t C,
                          at::Tensor grad_inputs);

================================================
FILE: lidarnerf/gridencoder/__init__.py
================================================

================================================
FILE: lidarnerf/gridencoder/backend.py
================================================
"""JIT-compile and load the grid-encoder CUDA extension."""
import os

from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    "-O3",
    "-std=c++14",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
]

if os.name == "posix":
    c_flags = ["-O3", "-std=c++14"]
elif os.name == "nt":
    c_flags = ["/O2", "/std:c++17"]

    # find cl.exe
    def find_cl_path():
        """Return the lexicographically greatest MSVC Hostx64/x64 bin dir, or None."""
        import glob

        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(
                glob.glob(
                    r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64"
                    % edition
                ),
                reverse=True,
            )
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError(
                "Could not locate a supported Microsoft Visual C++ installation"
            )
        os.environ["PATH"] += ";" + cl_path

_backend = load(
    name="_grid_encoder",
    extra_cflags=c_flags,
    extra_cuda_cflags=nvcc_flags,
    sources=[
        os.path.join(_src_path, "src", f)
        for f in [
            "gridencoder.cu",
            "bindings.cpp",
        ]
    ],
)

__all__ = ["_backend"]

================================================
FILE: lidarnerf/gridencoder/grid.py
================================================
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _gridencoder as _backend
except ImportError:
    from .backend import _backend

_gridtype_to_id = {
    "hash": 0,
    "tiled": 1,
}

_interp_to_id = {
    "linear": 0,
    "smoothstep": 1,
}


class _grid_encode(Function):
    """Autograd wrapper around the grid-encoder forward/backward kernels."""

    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        inputs,
        embeddings,
        offsets,
        per_level_scale,
        base_resolution,
        calc_grad_inputs=False,
        gridtype=0,
        align_corners=False,
        interpolation=0,
    ):
        # inputs: [B, D], float in [0, 1]
        # embeddings: [sO, C], float
        # offsets: [L + 1], int
        # RETURN: [B, F], float

        inputs = inputs.contiguous()

        B, D = inputs.shape  # batch size, coord dim
        L = offsets.shape[0] - 1  # level
        C = embeddings.shape[1]  # embedding dim for each level
        S = np.log2(
            per_level_scale
        )  # resolution multiplier at each level, apply log2 for later CUDA exp2f
        H = base_resolution  # base resolution

        # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
        # if C % 2 != 0, force float, since half for atomicAdd is very slow.
        if torch.is_autocast_enabled() and C % 2 == 0:
            embeddings = embeddings.to(torch.half)

        # L first, optimize cache for cuda kernel, but needs an extra permute later
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)

        if calc_grad_inputs:
            dy_dx = torch.empty(
                B, L * D * C, device=inputs.device, dtype=embeddings.dtype
            )
        else:
            dy_dx = None

        _backend.grid_encode_forward(
            inputs,
            embeddings,
            offsets,
            outputs,
            B,
            D,
            C,
            L,
            S,
            H,
            dy_dx,
            gridtype,
            align_corners,
            interpolation,
        )

        # permute back to [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)

        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [B, D, C, L, S, H, gridtype, interpolation]
        ctx.align_corners = align_corners

        return outputs

    @staticmethod
    # @once_differentiable
    @custom_bwd
    def backward(ctx, grad):
        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H, gridtype, interpolation = ctx.dims
        align_corners = ctx.align_corners

        # grad: [B, L * C] --> [L, B, C]
        # reshape (not view) so a non-contiguous incoming grad is accepted too.
        grad = grad.reshape(B, L, C).permute(1, 0, 2).contiguous()

        grad_embeddings = torch.zeros_like(embeddings)

        if dy_dx is not None:
            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
        else:
            grad_inputs = None

        _backend.grid_encode_backward(
            grad,
            inputs,
            embeddings,
            offsets,
            grad_embeddings,
            B,
            D,
            C,
            L,
            S,
            H,
            dy_dx,
            grad_inputs,
            gridtype,
            align_corners,
            interpolation,
        )

        if dy_dx is not None:
            grad_inputs = grad_inputs.to(inputs.dtype)

        # One entry per forward argument; only inputs/embeddings get grads.
        return grad_inputs, grad_embeddings, None, None, None, None, None, None, None


grid_encode = _grid_encode.apply


class GridEncoder(nn.Module):
    """Multi-resolution grid ("hash" or "tiled") encoder backed by CUDA kernels.

    Args:
        input_dim: coordinate dimensions per point.
        num_levels: number of resolution levels; output_dim = num_levels * level_dim.
        level_dim: feature channels per grid vertex at each level.
        per_level_scale: resolution growth factor between consecutive levels
            (recomputed when ``desired_resolution`` is given and num_levels > 1).
        base_resolution: resolution of the coarsest level.
        log2_hashmap_size: log2 of the max number of entries stored per level.
        desired_resolution: target resolution of the finest level.
        gridtype: "hash" or "tiled".
        align_corners: if False, each level stores ``resolution + 1`` vertices
            per axis instead of ``resolution``; also forwarded to the kernels.
        interpolation: "linear" or "smoothstep".
    """

    def __init__(
        self,
        input_dim=3,
        num_levels=16,
        level_dim=2,
        per_level_scale=2,
        base_resolution=16,
        log2_hashmap_size=19,
        desired_resolution=None,
        gridtype="hash",
        align_corners=False,
        interpolation="linear",
    ):
        super().__init__()

        # the finest resolution desired at the last level, if provided, overrides per_level_scale.
        # With a single level the scale is unused (level 0 uses scale**0), and
        # computing it would divide by zero and yield inf/NaN in the kernels.
        if desired_resolution is not None and num_levels > 1:
            per_level_scale = np.exp2(
                np.log2(desired_resolution / base_resolution) / (num_levels - 1)
            )

        self.input_dim = input_dim  # coord dims, 2 or 3
        self.num_levels = num_levels  # num levels, each level multiplies resolution by per_level_scale
        self.level_dim = level_dim  # encode channels per level
        self.per_level_scale = (
            per_level_scale  # multiply resolution by this scale at each level.
        )
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = num_levels * level_dim
        self.gridtype = gridtype
        self.gridtype_id = _gridtype_to_id[gridtype]  # "tiled" or "hash"
        self.interpolation = interpolation
        self.interp_id = _interp_to_id[interpolation]  # "linear" or "smoothstep"
        self.align_corners = align_corners

        # allocate parameters
        offsets = []
        offset = 0
        self.max_params = 2**log2_hashmap_size
        for i in range(num_levels):
            resolution = int(np.ceil(base_resolution * per_level_scale**i))
            params_in_level = min(
                self.max_params,
                (resolution if align_corners else resolution + 1) ** input_dim,
            )  # limit max number
            params_in_level = int(np.ceil(params_in_level / 8) * 8)  # make divisible
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer("offsets", offsets)

        self.n_params = offsets[-1] * level_dim

        # parameters
        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise embeddings uniformly in [-1e-4, 1e-4]."""
        std = 1e-4
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}"

    def forward(self, inputs, bound=1):
        """Encode points in [-bound, bound]^input_dim.

        Returns a tensor of shape [..., num_levels * level_dim]. Input
        gradients are computed only when ``inputs.requires_grad`` is True.
        """
        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
        # return: [..., num_levels * level_dim]

        inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]

        # print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.view(-1, self.input_dim)

        outputs = grid_encode(
            inputs,
            self.embeddings,
            self.offsets,
            self.per_level_scale,
            self.base_resolution,
            inputs.requires_grad,
            self.gridtype_id,
            self.align_corners,
            self.interp_id,
        )
        outputs = outputs.view(prefix_shape + [self.output_dim])

        # print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())

        return outputs

    # always run in float precision!
    @torch.cuda.amp.autocast(enabled=False)
    def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000):
        """Apply a total-variation term to ``self.embeddings.grad`` via the backend.

        Must be called after ``loss.backward()`` and before ``optimizer.step()``.
        The backend receives ``self.embeddings.grad`` and nothing is returned,
        so it presumably updates the gradient in place (confirm in the kernel).
        If ``inputs`` is None, ``B`` random points in [0, 1] are sampled.

        Raises:
            ValueError: if ``self.embeddings.grad`` is None.
        """
        # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss.

        D = self.input_dim
        C = self.embeddings.shape[1]  # embedding dim for each level
        L = self.offsets.shape[0] - 1  # level
        S = np.log2(
            self.per_level_scale
        )  # resolution multiplier at each level, apply log2 for later CUDA exp2f
        H = self.base_resolution  # base resolution

        if inputs is None:
            # randomized in [0, 1]
            inputs = torch.rand(B, self.input_dim, device=self.embeddings.device)
        else:
            inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]
            inputs = inputs.view(-1, self.input_dim)
            B = inputs.shape[0]

        if self.embeddings.grad is None:
            raise ValueError(
                "grad is None, should be called after loss.backward() and before optimizer.step()!"
            )

        _backend.grad_total_variation(
            inputs,
            self.embeddings,
            self.embeddings.grad,
            self.offsets,
            weight,
            B,
            D,
            C,
            L,
            S,
            H,
            self.gridtype_id,
            self.align_corners,
        )

================================================
FILE: lidarnerf/gridencoder/setup.py
================================================
"""setuptools build script for the `_gridencoder` CUDA extension."""
import os

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    "-O3",
    "-std=c++14",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
]

if os.name == "posix":
    c_flags = ["-O3", "-std=c++14"]
elif os.name == "nt":
    c_flags = ["/O2", "/std:c++17"]

    # find cl.exe
    def find_cl_path():
        """Return the lexicographically greatest MSVC Hostx64/x64 bin dir, or None."""
        import glob

        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(
                glob.glob(
                    r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64"
                    % edition
                ),
                reverse=True,
            )
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError(
                "Could not locate a supported Microsoft Visual C++ installation"
            )
        os.environ["PATH"] += ";" + cl_path

setup(
    name="gridencoder",  # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name="_gridencoder",  # extension name, import this to use CUDA API
            sources=[
                os.path.join(_src_path, "src", f)
                for f in [
                    "gridencoder.cu",
                    "bindings.cpp",
                ]
            ],
            extra_compile_args={
                "cxx": c_flags,
                "nvcc": nvcc_flags,
            },
        ),
    ],
    cmdclass={
        "build_ext": BuildExtension,
    },
)

================================================
FILE: lidarnerf/gridencoder/src/bindings.cpp
================================================
#include

#include "gridencoder.h"

// Python bindings for the grid-encoder entry points.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("grid_encode_forward", &grid_encode_forward,
          "grid_encode_forward (CUDA)");
    m.def("grid_encode_backward", &grid_encode_backward,
          "grid_encode_backward (CUDA)");
    m.def("grad_total_variation", &grad_total_variation,
          "grad_total_variation (CUDA)");
}

================================================
FILE: lidarnerf/gridencoder/src/gridencoder.cu
================================================
#include
#include
#include
#include
#include
#include
#include
#include
#include

#define CHECK_CUDA(x) \
    TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
    TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x)                                     \
    TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
                #x " must be an int tensor")
#define CHECK_IS_FLOATING(x)                                      \
    TORCH_CHECK(x.scalar_type() == at::ScalarType::Float ||   \
                        x.scalar_type() == at::ScalarType::Half ||  \
                        x.scalar_type() == at::ScalarType::Double, \
                #x " must be a floating tensor")

// just for compatibility of half precision in
// AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here!
// NOTE(review): the template parameter lists for the functions below (which
// use T, T2, D, C, N_C and scalar_t) appear to have been stripped from this
// text dump; the code is reproduced as-is.

// Placeholder overload so half-precision dispatch compiles.
// NOTE(review): there is no return statement, so actually reaching this
// function would be undefined behavior.
__device__ inline at::Half atomicAdd(at::Half *address, at::Half val) {
    // requires CUDA >= 10 and ARCH >= 70
    // this is very slow compared to float or __half2, never use it.
    // return atomicAdd(reinterpret_cast<__half*>(address), val);
}

// Integer ceil(val / divisor).
template __host__ __device__ inline T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}

// Clamps v to [lo, hi].
template __host__ __device__ inline T clamp(const T v, const T2 lo, const T2 hi) {
    return min(max(v, lo), hi);
}

// Smoothstep weight 3t^2 - 2t^3 (written as t*t*(3 - 2t)).
template __device__ inline T smoothstep(T val) {
    return val * val * (3.0f - 2.0f * val);
}

// Derivative of smoothstep: 6t(1 - t).
template __device__ inline T smoothstep_derivative(T val) {
    return 6 * val * (1.0f - val);
}

// XOR of per-dimension integer coordinates multiplied by fixed large
// constants (first constant is 1); supports up to 7 dimensions.
template __device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
    // coherent type of hashing
    constexpr uint32_t primes[7] = {1u,          2654435761u, 805459861u,
                                    3674653429u, 2097192037u, 1434869437u,
                                    2165219737u};

    uint32_t result = 0;
#pragma unroll
    for (uint32_t i = 0; i < D; ++i) {
        result ^= pos_grid[i] * primes[i];
    }

    return result;
}

// Maps integer grid coordinates to an offset into this level's embedding
// table (already scaled by C and shifted by channel `ch`).
// A dense row-major index is built while the running stride fits in
// hashmap_size; for gridtype 0 (hash), a level too large for the table falls
// back to fast_hash. The result is always wrapped modulo hashmap_size.
template __device__ uint32_t get_grid_index(const uint32_t gridtype,
                                            const bool align_corners,
                                            const uint32_t ch,
                                            const uint32_t hashmap_size,
                                            const uint32_t resolution,
                                            const uint32_t pos_grid[D]) {
    uint32_t stride = 1;
    uint32_t index = 0;

#pragma unroll
    for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
        index += pos_grid[d] * stride;
        // Vertices per axis: resolution (align_corners) or resolution + 1.
        stride *= align_corners ? resolution : (resolution + 1);
    }

    // NOTE: for NeRF, the hash is in fact not necessary. Check
    // https://github.com/NVlabs/instant-ngp/issues/97. gridtype: 0 == hash, 1
    // == tiled
    if (gridtype == 0 && stride > hashmap_size) {
        index = fast_hash(pos_grid);
    }

    return (index % hashmap_size) * C + ch;
}

// Forward kernel: one thread per (sample b, level), where b comes from the
// x launch dimension and level from blockIdx.y. It writes C interpolated
// features to outputs[level][b][:] and, if dy_dx is non-null, the D x C
// Jacobian block for this (b, level) in a B, L, D, C layout.
template __global__ void kernel_grid(const float *__restrict__ inputs,
                                     const scalar_t *__restrict__ grid,
                                     const int *__restrict__ offsets,
                                     scalar_t *__restrict__ outputs,
                                     const uint32_t B,
                                     const uint32_t L,
                                     const float S,
                                     const uint32_t H,
                                     scalar_t *__restrict__ dy_dx,
                                     const uint32_t gridtype,
                                     const bool align_corners,
                                     const uint32_t interp) {
    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;

    if (b >= B) return;

    const uint32_t level = blockIdx.y;

    // locate
    grid += (uint32_t)offsets[level] * C;
    inputs += b * D;
    outputs += level * B * C + b * C;

    // check input range (should be in [0, 1])
    bool flag_oob = false;
#pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            flag_oob = true;
        }
    }
    // if input out of bound, just set output to 0
    if (flag_oob) {
#pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            outputs[ch] = 0;
        }
        if (dy_dx) {
            dy_dx += b * D * L * C + level * D * C;  // B L D C
#pragma unroll
            for (uint32_t d = 0; d < D; d++) {
#pragma unroll
                for (uint32_t ch = 0; ch < C; ch++) {
                    dy_dx[d * C + ch] = 0;
                }
            }
        }
        return;
    }

    // Entries available to this level, and its resolution: S is log2 of the
    // per-level scale, so exp2f(level * S) == per_level_scale^level.
    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;

    // calculate coordinate (always use float for precision!)
    float pos[D];
    float pos_deriv[D];
    uint32_t pos_grid[D];

#pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        // Split the scaled coordinate into an integer cell (pos_grid) and a
        // fractional offset within the cell (pos).
        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
        pos_grid[d] = floorf(pos[d]);
        pos[d] -= (float)pos_grid[d];
        // smoothstep instead of linear
        if (interp == 1) {
            pos_deriv[d] = smoothstep_derivative(pos[d]);
            pos[d] = smoothstep(pos[d]);
        } else {
            pos_deriv[d] = 1.0f;  // linear deriv is default to 1
        }
    }

    // printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1],
    // pos_grid[0], pos_grid[1]);

    // interpolate
    scalar_t results[C] = {0};  // temp results in register

    // Visit the 2^D corners of the cell; bit d of idx selects the lower (0)
    // or upper (1) corner along dimension d, weighted multilinearly.
#pragma unroll
    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1;
        uint32_t pos_grid_local[D];

#pragma unroll
        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = pos_grid[d] + 1;
            }
        }

        uint32_t index = get_grid_index(gridtype, align_corners, 0,
                                        hashmap_size, resolution,
                                        pos_grid_local);

        // writing to register (fast)
#pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            results[ch] += w * grid[index + ch];
        }

        // printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx,
        // index, w, grid[index]);
    }

    // writing to global memory (slow)
#pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
        outputs[ch] = results[ch];
    }

    // prepare dy_dx
    // differentiable (soft) indexing:
    // https://discuss.pytorch.org/t/differentiable-indexing/17647/9
    if (dy_dx) {
        dy_dx += b * D * L * C + level * D * C;  // B L D C

        // For each gradient dimension gd: sum over the 2^(D-1) corner pairs
        // that differ only along gd of (right - left) times the weight of the
        // other dimensions. w starts at `scale` because pos = input * scale
        // + const, and pos_deriv[gd] applies the smoothstep chain rule.
#pragma unroll
        for (uint32_t gd = 0; gd < D; gd++) {
            scalar_t results_grad[C] = {0};

#pragma unroll
            for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
                float w = scale;
                uint32_t pos_grid_local[D];

#pragma unroll
                for (uint32_t nd = 0; nd < D - 1; nd++) {
                    // Map the (D-1)-dim counter onto all dimensions except gd.
                    const uint32_t d = (nd >= gd) ? (nd + 1) : nd;

                    if ((idx & (1 << nd)) == 0) {
                        w *= 1 - pos[d];
                        pos_grid_local[d] = pos_grid[d];
                    } else {
                        w *= pos[d];
                        pos_grid_local[d] = pos_grid[d] + 1;
                    }
                }

                pos_grid_local[gd] = pos_grid[gd];
                uint32_t index_left = get_grid_index(
                        gridtype, align_corners, 0, hashmap_size, resolution,
                        pos_grid_local);
                pos_grid_local[gd] = pos_grid[gd] + 1;
                uint32_t index_right = get_grid_index(
                        gridtype, align_corners, 0, hashmap_size, resolution,
                        pos_grid_local);

#pragma unroll
                for (uint32_t ch = 0; ch < C; ch++) {
                    results_grad[ch] +=
                            w * (grid[index_right + ch] - grid[index_left + ch]) *
                            pos_deriv[gd];
                }
            }

#pragma unroll
            for (uint32_t ch = 0; ch < C; ch++) {
                dy_dx[gd * C + ch] = results_grad[ch];
            }
        }
    }
}

// Backward kernel: each thread handles N_C consecutive channels of one
// sample b at one level (level from blockIdx.y). Out-of-range samples
// contribute nothing.
// (This definition continues past the end of the visible chunk.)
template __global__ void kernel_grid_backward(const scalar_t *__restrict__ grad,
                                              const float *__restrict__ inputs,
                                              const scalar_t *__restrict__ grid,
                                              const int *__restrict__ offsets,
                                              scalar_t *__restrict__ grad_grid,
                                              const uint32_t B,
                                              const uint32_t L,
                                              const float S,
                                              const uint32_t H,
                                              const uint32_t gridtype,
                                              const bool align_corners,
                                              const uint32_t interp) {
    const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
    if (b >= B) return;

    const uint32_t level = blockIdx.y;
    const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;

    // locate
    grad_grid += offsets[level] * C;
    inputs += b * D;
    grad += level * B * C + b * C + ch;  // L, B, C

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;

    // check input range (should be in [0, 1])
#pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            return;  // grad is init as 0, so we simply return.
        }
    }

    // calculate coordinate
    float pos[D];
    uint32_t pos_grid[D];

#pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        pos[d] = inputs[d] * scale + (align_corners ?
0.0f : 0.5f); pos_grid[d] = floorf(pos[d]); pos[d] -= (float)pos_grid[d]; // smoothstep instead of linear if (interp == 1) { pos[d] = smoothstep(pos[d]); } } scalar_t grad_cur[N_C] = {0}; // fetch to register #pragma unroll for (uint32_t c = 0; c < N_C; c++) { grad_cur[c] = grad[c]; } // interpolate #pragma unroll for (uint32_t idx = 0; idx < (1 << D); idx++) { float w = 1; uint32_t pos_grid_local[D]; #pragma unroll for (uint32_t d = 0; d < D; d++) { if ((idx & (1 << d)) == 0) { w *= 1 - pos[d]; pos_grid_local[d] = pos_grid[d]; } else { w *= pos[d]; pos_grid_local[d] = pos_grid[d] + 1; } } uint32_t index = get_grid_index(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local); // atomicAdd for __half is slow (especially for large values), so we use // __half2 if N_C % 2 == 0 // TODO: use float which is better than __half, if N_C % 2 != 0 if (std::is_same::value && N_C % 2 == 0) { #pragma unroll for (uint32_t c = 0; c < N_C; c += 2) { // process two __half at once (by interpreting as a __half2) __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])}; atomicAdd((__half2 *)&grad_grid[index + c], v); } // float, or __half when N_C % 2 != 0 (which means C == 1) } else { #pragma unroll for (uint32_t c = 0; c < N_C; c++) { atomicAdd(&grad_grid[index + c], w * grad_cur[c]); } } } } template __global__ void kernel_input_backward(const scalar_t *__restrict__ grad, const scalar_t *__restrict__ dy_dx, scalar_t *__restrict__ grad_inputs, uint32_t B, uint32_t L) { const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; if (t >= B * D) return; const uint32_t b = t / D; const uint32_t d = t - b * D; dy_dx += b * L * D * C; scalar_t result = 0; #pragma unroll for (int l = 0; l < L; l++) { #pragma unroll for (int ch = 0; ch < C; ch++) { result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; } } grad_inputs[t] = result; } template void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t 
*outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { static constexpr uint32_t N_THREAD = 512; const dim3 blocks_hashgrid = {div_round_up(B, N_THREAD), L, 1}; switch (C) { case 1: kernel_grid<<>>( inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; case 2: kernel_grid<<>>( inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; case 4: kernel_grid<<>>( inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; case 8: kernel_grid<<>>( inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; } } // inputs: [B, D], float, in [0, 1] // embeddings: [sO, C], float // offsets: [L + 1], uint32_t // outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit // into cache at a time.) 
// H: base resolution
// dy_dx: [B, L * D * C]
// Dispatch on D (input dimensionality) to kernel_grid_wrapper; D must be
// 2, 3, 4 or 5, anything else throws std::runtime_error.
template void grid_encode_forward_cuda(const float *inputs,
                                       const scalar_t *embeddings,
                                       const int *offsets, scalar_t *outputs,
                                       const uint32_t B, const uint32_t D,
                                       const uint32_t C, const uint32_t L,
                                       const float S, const uint32_t H,
                                       scalar_t *dy_dx, const uint32_t gridtype,
                                       const bool align_corners,
                                       const uint32_t interp) {
  switch (D) {
    case 2:
      kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H,
                          dy_dx, gridtype, align_corners, interp);
      break;
    case 3:
      kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H,
                          dy_dx, gridtype, align_corners, interp);
      break;
    case 4:
      kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H,
                          dy_dx, gridtype, align_corners, interp);
      break;
    case 5:
      kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H,
                          dy_dx, gridtype, align_corners, interp);
      break;
    default:
      // This switch is on D, so report the valid D values (the message
      // previously repeated the C constraint).
      throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
  }
}

// Host-side launcher for the backward kernels. Each thread of
// kernel_grid_backward handles N_C = min(2, C) channels, so grid.x covers
// B * C / N_C threads (256 per block) and grid.y == L. When dy_dx is non-null
// the input gradient is also computed via kernel_input_backward.
// C must be 1, 2, 4 or 8; anything else throws std::runtime_error.
// NOTE(review): the launch configurations (`<<>>`) and template arguments are
// missing in this copy; restore them from upstream before compiling.
template void kernel_grid_backward_wrapper(
    const scalar_t *grad, const float *inputs, const scalar_t *embeddings,
    const int *offsets, scalar_t *grad_embeddings, const uint32_t B,
    const uint32_t C, const uint32_t L, const float S, const uint32_t H,
    scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype,
    const bool align_corners, const uint32_t interp) {
  static constexpr uint32_t N_THREAD = 256;
  const uint32_t N_C = std::min(2u, C);  // n_features_per_thread
  const dim3 blocks_hashgrid = {div_round_up(B * C / N_C, N_THREAD), L, 1};
  switch (C) {
    case 1:
      kernel_grid_backward<<>>(grad, inputs, embeddings, offsets,
                               grad_embeddings, B, L, S, H, gridtype,
                               align_corners, interp);
      if (dy_dx)
        kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
      break;
    case 2:
      kernel_grid_backward<<>>(grad, inputs, embeddings, offsets,
                               grad_embeddings, B, L, S, H, gridtype,
                               align_corners, interp);
      if (dy_dx)
        kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
      break;
    case 4:
      kernel_grid_backward<<>>(grad, inputs, embeddings, offsets,
                               grad_embeddings, B, L, S, H, gridtype,
                               align_corners, interp);
      if (dy_dx)
        kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
      break;
    case 8:
      kernel_grid_backward<<>>(grad, inputs, embeddings, offsets,
                               grad_embeddings, B, L, S, H, gridtype,
                               align_corners, interp);
      if (dy_dx)
        kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L);
      break;
    default:
      throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
  }
}

// grad: [L, B, C], float
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// grad_embeddings: [sO, C]
// H: base resolution
// Dispatch on D (2, 3, 4 or 5) to kernel_grid_backward_wrapper.
template void grid_encode_backward_cuda(
    const scalar_t *grad, const float *inputs, const scalar_t *embeddings,
    const int *offsets, scalar_t *grad_embeddings, const uint32_t B,
    const uint32_t D, const uint32_t C, const uint32_t L, const float S,
    const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs,
    const uint32_t gridtype, const bool align_corners, const uint32_t interp) {
  switch (D) {
    case 2:
      kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets,
                                   grad_embeddings, B, C, L, S, H, dy_dx,
                                   grad_inputs, gridtype, align_corners,
                                   interp);
      break;
    case 3:
      kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets,
                                   grad_embeddings, B, C, L, S, H, dy_dx,
                                   grad_inputs, gridtype, align_corners,
                                   interp);
      break;
    case 4:
      kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets,
                                   grad_embeddings, B, C, L, S, H, dy_dx,
                                   grad_inputs, gridtype, align_corners,
                                   interp);
      break;
    case 5:
      kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets,
                                   grad_embeddings, B, C, L, S, H, dy_dx,
                                   grad_inputs, gridtype, align_corners,
                                   interp);
      break;
    default:
      throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
  }
}

// Torch entry point for the forward encoding: validates the required tensors
// (CUDA, contiguous, dtype) and dispatches on embeddings' dtype
// (float, double or half). dy_dx is optional and is NOT validated here (its
// checks are left commented out), so callers must pass a contiguous CUDA
// tensor of the embeddings' dtype when providing it.
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings,
                         const at::Tensor offsets, at::Tensor outputs,
                         const uint32_t B, const uint32_t D, const uint32_t C,
                         const uint32_t L, const float S, const uint32_t H,
                         at::optional dy_dx, const uint32_t gridtype,
                         const bool align_corners, const uint32_t interp) {
  CHECK_CUDA(inputs);
  CHECK_CUDA(embeddings);
  CHECK_CUDA(offsets);
  CHECK_CUDA(outputs);
  // CHECK_CUDA(dy_dx);

  CHECK_CONTIGUOUS(inputs);
  CHECK_CONTIGUOUS(embeddings);
  CHECK_CONTIGUOUS(offsets);
  CHECK_CONTIGUOUS(outputs);
  // CHECK_CONTIGUOUS(dy_dx);

  CHECK_IS_FLOATING(inputs);
  CHECK_IS_FLOATING(embeddings);
  CHECK_IS_INT(offsets);
  CHECK_IS_FLOATING(outputs);
  // CHECK_IS_FLOATING(dy_dx);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      embeddings.scalar_type(), "grid_encode_forward", ([&] {
        grid_encode_forward_cuda(
            inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(),
            outputs.data_ptr(), B, D, C, L, S, H,
            dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, gridtype,
            align_corners, interp);
      }));
}

// Torch entry point for the backward pass: validates the required tensors and
// dispatches on grad's dtype. dy_dx / grad_inputs are optional and, as in the
// forward entry point, are not validated here.
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs,
                          const at::Tensor embeddings, const at::Tensor offsets,
                          at::Tensor grad_embeddings, const uint32_t B,
                          const uint32_t D, const uint32_t C, const uint32_t L,
                          const float S, const uint32_t H,
                          const at::optional dy_dx, at::optional grad_inputs,
                          const uint32_t gridtype, const bool align_corners,
                          const uint32_t interp) {
  CHECK_CUDA(grad);
  CHECK_CUDA(inputs);
  CHECK_CUDA(embeddings);
  CHECK_CUDA(offsets);
  CHECK_CUDA(grad_embeddings);
  // CHECK_CUDA(dy_dx);
  // CHECK_CUDA(grad_inputs);

  CHECK_CONTIGUOUS(grad);
  CHECK_CONTIGUOUS(inputs);
  CHECK_CONTIGUOUS(embeddings);
  CHECK_CONTIGUOUS(offsets);
  CHECK_CONTIGUOUS(grad_embeddings);
  // CHECK_CONTIGUOUS(dy_dx);
  // CHECK_CONTIGUOUS(grad_inputs);

  CHECK_IS_FLOATING(grad);
  CHECK_IS_FLOATING(inputs);
  CHECK_IS_FLOATING(embeddings);
  CHECK_IS_INT(offsets);
  CHECK_IS_FLOATING(grad_embeddings);
  // CHECK_IS_FLOATING(dy_dx);
  // CHECK_IS_FLOATING(grad_inputs);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "grid_encode_backward", ([&] {
        grid_encode_backward_cuda(
            grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(),
            offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, S, H,
            dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr,
            grad_inputs.has_value() ? grad_inputs.value().data_ptr() : nullptr,
            gridtype, align_corners, interp);
      }));
}

// Total-variation gradient on the grid vertex at floor(pos) of each sample.
// One thread per (sample b, level) pair (grid.y == L). For every axis the
// differences to the +1 and -1 neighbours (when they are inside
// [0, resolution]) are summed; the vertex then receives
//   weight / (2 D) * sum / sqrt(sum of squared differences + 1e-9)
// via atomicAdd, since several samples may share the same vertex.
// Samples with any coordinate outside [0, 1] contribute nothing.
template __global__ void kernel_grad_tv(const scalar_t *__restrict__ inputs,
                                        const scalar_t *__restrict__ grid,
                                        scalar_t *__restrict__ grad,
                                        const int *__restrict__ offsets,
                                        const float weight, const uint32_t B,
                                        const uint32_t L, const float S,
                                        const uint32_t H,
                                        const uint32_t gridtype,
                                        const bool align_corners) {
  const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
  if (b >= B) return;
  const uint32_t level = blockIdx.y;

  // locate
  inputs += b * D;
  grid += (uint32_t)offsets[level] * C;
  grad += (uint32_t)offsets[level] * C;

  // check input range (should be in [0, 1])
  bool flag_oob = false;
#pragma unroll
  for (uint32_t d = 0; d < D; d++) {
    if (inputs[d] < 0 || inputs[d] > 1) {
      flag_oob = true;
    }
  }
  // if input out of bound, do nothing
  if (flag_oob) return;

  const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
  const float scale = exp2f(level * S) * H - 1.0f;
  const uint32_t resolution = (uint32_t)ceil(scale) + 1;

  // calculate coordinate (only the integer cell is needed here)
  float pos[D];
  uint32_t pos_grid[D];  // [0, resolution]
#pragma unroll
  for (uint32_t d = 0; d < D; d++) {
    pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
    pos_grid[d] = floorf(pos[d]);
  }

  // total variation on pos_grid
  scalar_t results[C] = {0};  // temp results in register
  scalar_t idelta[C] = {0};
  uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size,
                                  resolution, pos_grid);
  scalar_t w = weight / (2 * D);

#pragma unroll
  for (uint32_t d = 0; d < D; d++) {
    uint32_t cur_d = pos_grid[d];
    scalar_t grad_val;

    // right side
    if (cur_d < resolution) {
      pos_grid[d] = cur_d + 1;
      uint32_t index_right = get_grid_index(gridtype, align_corners, 0,
                                            hashmap_size, resolution, pos_grid);
#pragma unroll
      for (uint32_t ch = 0; ch < C; ch++) {
        grad_val = (grid[index + ch] - grid[index_right + ch]);
        results[ch] += grad_val;
        idelta[ch] += grad_val * grad_val;
      }
    }

    // left side
    if (cur_d > 0) {
      pos_grid[d] = cur_d - 1;
      uint32_t index_left = get_grid_index(gridtype, align_corners, 0,
                                           hashmap_size, resolution, pos_grid);
#pragma unroll
      for (uint32_t ch = 0; ch < C; ch++) {
        grad_val = (grid[index + ch] - grid[index_left + ch]);
        results[ch] += grad_val;
        idelta[ch] += grad_val * grad_val;
      }
    }

    // reset
    pos_grid[d] = cur_d;
  }

  // writing to global memory (slow)
#pragma unroll
  for (uint32_t ch = 0; ch < C; ch++) {
    // index may collide, so use atomic!
    atomicAdd(&grad[index + ch],
              w * results[ch] * rsqrtf(idelta[ch] + 1e-9f));
  }
}

// Host-side launcher for kernel_grad_tv: 512 threads per block and a
// (ceil(B / 512), L) grid. C must be 1, 2, 4 or 8.
template void kernel_grad_tv_wrapper(const scalar_t *inputs,
                                     const scalar_t *embeddings,
                                     scalar_t *grad, const int *offsets,
                                     const float weight, const uint32_t B,
                                     const uint32_t C, const uint32_t L,
                                     const float S, const uint32_t H,
                                     const uint32_t gridtype,
                                     const bool align_corners) {
  static constexpr uint32_t N_THREAD = 512;
  const dim3 blocks_hashgrid = {div_round_up(B, N_THREAD), L, 1};
  switch (C) {
    case 1:
      kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S,
                         H, gridtype, align_corners);
      break;
    case 2:
      kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S,
                         H, gridtype, align_corners);
      break;
    case 4:
      kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S,
                         H, gridtype, align_corners);
      break;
    case 8:
      kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S,
                         H, gridtype, align_corners);
      break;
    default:
      throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
  }
}

// Dispatch on D (2, 3, 4 or 5) to kernel_grad_tv_wrapper.
template void grad_total_variation_cuda(const scalar_t *inputs,
                                        const scalar_t *embeddings,
                                        scalar_t *grad, const int *offsets,
                                        const float weight, const uint32_t B,
                                        const uint32_t D, const uint32_t C,
                                        const uint32_t L, const float S,
                                        const uint32_t H,
                                        const uint32_t gridtype,
                                        const bool align_corners) {
  switch (D) {
    case 2:
      kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C,
                             L, S, H, gridtype, align_corners);
      break;
    case 3:
      kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C,
                             L, S, H, gridtype, align_corners);
      break;
    case 4:
      kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C,
                             L, S, H, gridtype, align_corners);
      break;
    case 5:
      kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C,
                             L, S, H, gridtype, align_corners);
      break;
    default:
      throw std::runtime_error{"GridEncoding: D must be 2, 3, 4, or 5."};
  }
}

// Torch entry point for the total-variation gradient; dispatches on
// embeddings' dtype and performs no tensor validation.
// NOTE(review): unlike the forward/backward paths, `inputs` is read as
// scalar_t here, so it presumably must share embeddings' dtype -- confirm.
void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings,
                          at::Tensor grad, const at::Tensor offsets,
                          const float weight, const uint32_t B,
                          const uint32_t D, const uint32_t C, const uint32_t L,
                          const float S, const uint32_t H,
                          const uint32_t gridtype, const bool align_corners) {
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      embeddings.scalar_type(), "grad_total_variation", ([&] {
        grad_total_variation_cuda(inputs.data_ptr(), embeddings.data_ptr(),
                                  grad.data_ptr(), offsets.data_ptr(), weight,
                                  B, D, C, L, S, H, gridtype, align_corners);
      }));
}

================================================
FILE: lidarnerf/gridencoder/src/gridencoder.h
================================================
#ifndef _HASH_ENCODE_H
#define _HASH_ENCODE_H

#include
#include

// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// outputs: [L, B, C], float (level-major, as written by kernel_grid)
// H: base resolution
// dy_dx: optional, [B, L * D * C]
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings,
                         const at::Tensor offsets, at::Tensor outputs,
                         const uint32_t B, const uint32_t D, const uint32_t C,
                         const uint32_t L, const float S, const uint32_t H,
                         at::optional dy_dx, const uint32_t gridtype,
                         const bool align_corners, const uint32_t interp);

void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs,
                          const at::Tensor embeddings, const at::Tensor offsets,
                          at::Tensor grad_embeddings, const uint32_t B,
                          const uint32_t D, const uint32_t C, const uint32_t L,
                          const float S, const uint32_t H,
                          const at::optional dy_dx, at::optional grad_inputs,
                          const uint32_t gridtype, const bool align_corners,
                          const uint32_t interp);

void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings,
                          at::Tensor grad, const at::Tensor offsets,
                          const float weight, const uint32_t B,
                          const uint32_t D, const uint32_t C, const uint32_t L,
                          const float S, const uint32_t H,
                          const uint32_t gridtype, const bool align_corners);

#endif

================================================
FILE: lidarnerf/loss.py
================================================
import torch import
numpy as np def mape_loss(pred, target, reduction="mean"): # pred, target: [B, 1], torch tenspr difference = (pred - target).abs() scale = 1 / (target.abs() + 1e-2) loss = difference * scale if reduction == "mean": loss = loss.mean() return loss def huber_loss(pred, target, delta=0.1, reduction="mean"): rel = (pred - target).abs() sqr = 0.5 / delta * rel * rel loss = torch.where(rel > delta, rel - 0.5 * delta, sqr) if reduction == "mean": loss = loss.mean() return loss # ref: https://github.com/sunset1995/torch_efficient_distloss/blob/main/torch_efficient_distloss/eff_distloss.py class EffDistLoss(torch.autograd.Function): @staticmethod def forward(ctx, w, m, interval): """ Efficient O(N) realization of distortion loss. There are B rays each with N sampled points. w: Float tensor in shape [B,N]. Volume rendering weights of each point. m: Float tensor in shape [B,N]. Midpoint distance to camera of each point. interval: Scalar or float tensor in shape [B,N]. The query interval of each point. 
""" n_rays = np.prod(w.shape[:-1]) wm = w * m w_cumsum = w.cumsum(dim=-1) wm_cumsum = wm.cumsum(dim=-1) w_total = w_cumsum[..., [-1]] wm_total = wm_cumsum[..., [-1]] w_prefix = torch.cat([torch.zeros_like(w_total), w_cumsum[..., :-1]], dim=-1) wm_prefix = torch.cat([torch.zeros_like(wm_total), wm_cumsum[..., :-1]], dim=-1) loss_uni = (1 / 3) * interval * w.pow(2) loss_bi = 2 * w * (m * w_prefix - wm_prefix) if torch.is_tensor(interval): ctx.save_for_backward( w, m, wm, w_prefix, w_total, wm_prefix, wm_total, interval ) ctx.interval = None else: ctx.save_for_backward(w, m, wm, w_prefix, w_total, wm_prefix, wm_total) ctx.interval = interval ctx.n_rays = n_rays return (loss_bi.sum() + loss_uni.sum()) / n_rays @staticmethod @torch.autograd.function.once_differentiable def backward(ctx, grad_back): interval = ctx.interval n_rays = ctx.n_rays if interval is None: ( w, m, wm, w_prefix, w_total, wm_prefix, wm_total, interval, ) = ctx.saved_tensors else: w, m, wm, w_prefix, w_total, wm_prefix, wm_total = ctx.saved_tensors grad_uni = (1 / 3) * interval * 2 * w w_suffix = w_total - (w_prefix + w) wm_suffix = wm_total - (wm_prefix + wm) grad_bi = 2 * (m * (w_prefix - w_suffix) + (wm_suffix - wm_prefix)) grad = grad_back * (grad_bi + grad_uni) / n_rays return grad, None, None, None eff_distloss = EffDistLoss.apply ================================================ FILE: lidarnerf/nerf/network.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from lidarnerf.encoding import get_encoder from lidarnerf.activation import trunc_exp from .renderer import NeRFRenderer class NeRFNetwork(NeRFRenderer): def __init__( self, encoding="hashgrid", encoding_dir="frequency", multires=15, encoding_bg="hashgrid", desired_resolution=2048, log2_hashmap_size=19, num_layers=2, hidden_dim=64, geo_feat_dim=15, num_layers_color=3, hidden_dim_color=64, num_layers_bg=2, hidden_dim_bg=64, out_color_dim=3, out_lidar_color_dim=2, bound=1, 
**kwargs, ): super().__init__(bound, **kwargs) # sigma network self.num_layers = num_layers self.hidden_dim = hidden_dim self.geo_feat_dim = geo_feat_dim self.out_color_dim = out_color_dim self.out_lidar_color_dim = out_lidar_color_dim self.encoder, self.in_dim = get_encoder( encoding, desired_resolution=desired_resolution, log2_hashmap_size=log2_hashmap_size, ) sigma_net = [] for l in range(num_layers): if l == 0: in_dim = self.in_dim else: in_dim = hidden_dim if l == num_layers - 1: out_dim = 1 + self.geo_feat_dim # 1 sigma + 15 SH features for color else: out_dim = hidden_dim sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) self.sigma_net = nn.ModuleList(sigma_net) # color network self.num_layers_color = num_layers_color self.hidden_dim_color = hidden_dim_color self.encoder_dir, self.in_dim_dir = get_encoder("sphere_harmonics") color_net = [] for l in range(num_layers_color): if l == 0: in_dim = self.in_dim_dir + self.geo_feat_dim else: in_dim = hidden_dim_color if l == num_layers_color - 1: out_dim = self.out_color_dim # 3 rgb else: out_dim = hidden_dim_color color_net.append(nn.Linear(in_dim, out_dim, bias=False)) self.color_net = nn.ModuleList(color_net) # lidar color network self.encoder_lidar_dir, self.in_dim_dir = get_encoder("frequency", multires=12) lidar_color_net = [] for l in range(num_layers_color): if l == 0: in_dim = self.in_dim_dir + self.geo_feat_dim else: in_dim = hidden_dim_color if l == num_layers_color - 1: out_dim = self.out_lidar_color_dim # 2 rgb else: out_dim = hidden_dim_color lidar_color_net.append(nn.Linear(in_dim, out_dim, bias=False)) self.lidar_color_net = nn.ModuleList(lidar_color_net) # background network if self.bg_radius > 0: self.num_layers_bg = num_layers_bg self.hidden_dim_bg = hidden_dim_bg self.encoder_bg, self.in_dim_bg = get_encoder( encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048, ) # much smaller hashgrid bg_net = [] for l in range(num_layers_bg): if l == 0: in_dim = 
self.in_dim_bg + self.in_dim_dir else: in_dim = hidden_dim_bg if l == num_layers_bg - 1: out_dim = 3 # 3 rgb else: out_dim = hidden_dim_bg bg_net.append(nn.Linear(in_dim, out_dim, bias=False)) self.bg_net = nn.ModuleList(bg_net) else: self.bg_net = None def forward(self, x, d): # x: [N, 3], in [-bound, bound] # d: [N, 3], nomalized in [-1, 1] # sigma x = self.encoder(x, bound=self.bound) h = x for l in range(self.num_layers): h = self.sigma_net[l](h) if l != self.num_layers - 1: h = F.relu(h, inplace=True) # sigma = F.relu(h[..., 0]) sigma = trunc_exp(h[..., 0]) geo_feat = h[..., 1:] # color d = self.encoder_dir(d) h = torch.cat([d, geo_feat], dim=-1) for l in range(self.num_layers_color): h = self.color_net[l](h) if l != self.num_layers_color - 1: h = F.relu(h, inplace=True) # sigmoid activation for rgb color = torch.sigmoid(h) return sigma, color def density(self, x): # x: [N, 3], in [-bound, bound] x = self.encoder(x, bound=self.bound) h = x for l in range(self.num_layers): h = self.sigma_net[l](h) if l != self.num_layers - 1: h = F.relu(h, inplace=True) # sigma = F.relu(h[..., 0]) sigma = trunc_exp(h[..., 0]) geo_feat = h[..., 1:] return { "sigma": sigma, "geo_feat": geo_feat, } def background(self, x, d): # x: [N, 2], in [-1, 1] h = self.encoder_bg(x) # [N, C] d = self.encoder_dir(d) h = torch.cat([d, h], dim=-1) for l in range(self.num_layers_bg): h = self.bg_net[l](h) if l != self.num_layers_bg - 1: h = F.relu(h, inplace=True) # sigmoid activation for rgb rgbs = torch.sigmoid(h) return rgbs # allow masked inference def color(self, x, d, cal_lidar_color=False, mask=None, geo_feat=None, **kwargs): # x: [N, 3] in [-bound, bound] # mask: [N,], bool, indicates where we actually needs to compute rgb. 
if mask is not None: rgbs = torch.zeros( mask.shape[0], self.out_dim, dtype=x.dtype, device=x.device ) # [N, 3] # in case of empty mask if not mask.any(): return rgbs x = x[mask] d = d[mask] geo_feat = geo_feat[mask] if cal_lidar_color: d = self.encoder_lidar_dir(d) h = torch.cat([d, geo_feat], dim=-1) for l in range(self.num_layers_color): h = self.lidar_color_net[l](h) if l != self.num_layers_color - 1: h = F.relu(h, inplace=True) else: d = self.encoder_dir(d) h = torch.cat([d, geo_feat], dim=-1) for l in range(self.num_layers_color): h = self.color_net[l](h) if l != self.num_layers_color - 1: h = F.relu(h, inplace=True) # sigmoid activation for rgb h = torch.sigmoid(h) if mask is not None: rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 else: rgbs = h return rgbs # optimizer utils def get_params(self, lr): params = [ {"params": self.encoder.parameters(), "lr": lr}, {"params": self.sigma_net.parameters(), "lr": lr}, {"params": self.encoder_dir.parameters(), "lr": lr}, {"params": self.color_net.parameters(), "lr": lr}, {"params": self.encoder_lidar_dir.parameters(), "lr": lr}, {"params": self.lidar_color_net.parameters(), "lr": lr}, ] if self.bg_radius > 0: params.append({"params": self.encoder_bg.parameters(), "lr": lr}) params.append({"params": self.bg_net.parameters(), "lr": lr}) return params ================================================ FILE: lidarnerf/nerf/network_tcnn.py ================================================ import torch import numpy as np import tinycudann as tcnn from lidarnerf.activation import trunc_exp from .renderer import NeRFRenderer class NeRFNetwork(NeRFRenderer): def __init__( self, encoding="HashGrid", desired_resolution=2048, log2_hashmap_size=19, encoding_dir="SphericalHarmonics", n_features_per_level=2, num_layers=2, hidden_dim=64, geo_feat_dim=15, num_layers_color=3, hidden_dim_color=64, out_color_dim=3, out_lidar_color_dim=2, bound=1, **kwargs, ): super().__init__(bound, **kwargs) # sigma network self.num_layers = num_layers 
self.hidden_dim = hidden_dim self.geo_feat_dim = geo_feat_dim self.desired_resolution = desired_resolution self.log2_hashmap_size = log2_hashmap_size self.out_color_dim = out_color_dim self.out_lidar_color_dim = out_lidar_color_dim self.n_features_per_level = n_features_per_level per_level_scale = np.exp2( np.log2(self.desired_resolution * bound / 16) / (16 - 1) ) print(f"TCNN desired resolution: {self.desired_resolution}") print(f"TCNN per level scale: {per_level_scale}") self.encoder = tcnn.Encoding( n_input_dims=3, encoding_config={ "otype": "HashGrid", "n_levels": 16, "n_features_per_level": self.n_features_per_level, "log2_hashmap_size": self.log2_hashmap_size, "base_resolution": 16, "per_level_scale": per_level_scale, # "interpolation": "Smoothstep" }, ) self.sigma_net = tcnn.Network( n_input_dims=self.encoder.n_output_dims, n_output_dims=1 + self.geo_feat_dim, network_config={ "otype": "FullyFusedMLP", "activation": "ReLU", "output_activation": "None", "n_neurons": hidden_dim, "n_hidden_layers": num_layers - 1, }, ) # color network self.num_layers_color = num_layers_color self.hidden_dim_color = hidden_dim_color # # SH self.encoder_dir = tcnn.Encoding( n_input_dims=3, encoding_config={ "otype": "SphericalHarmonics", "degree": 4, }, ) # # Hash # per_level_scale = np.exp2(np.log2(1024 * bound / 4) / (4 - 1)) # self.encoder_dir = tcnn.Encoding( # n_input_dims=3, # encoding_config={ # "otype": "HashGrid", # "n_levels": 4, # "n_features_per_level": 2, # "log2_hashmap_size": self.log2_hashmap_size, # "base_resolution": 128, # "per_level_scale": per_level_scale, # }, # ) # # freq self.encoder_lidar_dir = tcnn.Encoding( n_input_dims=3, encoding_config={ "otype": "Frequency", "degree": 12, }, ) self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim self.color_net = tcnn.Network( n_input_dims=self.in_dim_color, n_output_dims=self.out_color_dim, network_config={ "otype": "FullyFusedMLP", "activation": "ReLU", "output_activation": "None", "n_neurons": 
hidden_dim_color, "n_hidden_layers": num_layers_color - 1, }, ) self.in_dim_lidar_color = ( self.encoder_lidar_dir.n_output_dims + self.geo_feat_dim ) self.lidar_color_net = tcnn.Network( n_input_dims=self.in_dim_lidar_color, n_output_dims=self.out_lidar_color_dim, network_config={ "otype": "FullyFusedMLP", "activation": "ReLU", "output_activation": "None", "n_neurons": hidden_dim_color, "n_hidden_layers": num_layers_color - 1, }, ) def forward(self, x, d): pass def density(self, x): # x: [N, 3], in [-bound, bound] x = (x + self.bound) / (2 * self.bound) # to [0, 1] x = self.encoder(x) h = self.sigma_net(x) # sigma = F.relu(h[..., 0]) sigma = trunc_exp(h[..., 0]) geo_feat = h[..., 1:] return { "sigma": sigma, "geo_feat": geo_feat, } # allow masked inference def color(self, x, d, cal_lidar_color=False, mask=None, geo_feat=None, **kwargs): # x: [N, 3] in [-bound, bound] # mask: [N,], bool, indicates where we actually needs to compute rgb. x = (x + self.bound) / (2 * self.bound) # to [0, 1] if mask is not None: rgbs = torch.zeros( mask.shape[0], self.out_dim, dtype=x.dtype, device=x.device ) # [N, 3] # in case of empty mask if not mask.any(): return rgbs x = x[mask] d = d[mask] geo_feat = geo_feat[mask] # color # d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1] # d = self.encoder_dir(d) # h = torch.cat([d, geo_feat], dim=-1) if cal_lidar_color: d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1] d = self.encoder_lidar_dir(d) h = torch.cat([d, geo_feat], dim=-1) h = self.lidar_color_net(h) else: d = (d + 1) / 2 d = self.encoder_dir(d) h = torch.cat([d, geo_feat], dim=-1) h = self.color_net(h) # sigmoid activation for rgb h = torch.sigmoid(h) if mask is not None: rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 else: rgbs = h return rgbs # optimizer utils def get_params(self, lr): params = [ {"params": self.encoder.parameters(), "lr": lr}, {"params": self.sigma_net.parameters(), "lr": lr}, {"params": self.encoder_dir.parameters(), "lr": 
lr}, {"params": self.encoder_lidar_dir.parameters(), "lr": lr}, {"params": self.color_net.parameters(), "lr": lr}, {"params": self.lidar_color_net.parameters(), "lr": lr}, ] if self.bg_radius > 0: params.append({"params": self.encoder_bg.parameters(), "lr": lr}) params.append({"params": self.bg_net.parameters(), "lr": lr}) return params ================================================ FILE: lidarnerf/nerf/renderer.py ================================================ import math import trimesh import torch import torch.nn as nn from lidarnerf import raymarching def sample_pdf(bins, weights, n_samples, det=False): # This implementation is from NeRF # bins: [B, T], old_z_vals # weights: [B, T - 1], bin weights. # return: [B, n_samples], new_z_vals # Get pdf weights = weights + 1e-5 # prevent nans pdf = weights / torch.sum(weights, -1, keepdim=True) cdf = torch.cumsum(pdf, -1) cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # Take uniform samples if det: u = torch.linspace( 0.0 + 0.5 / n_samples, 1.0 - 0.5 / n_samples, steps=n_samples ).to(weights.device) u = u.expand(list(cdf.shape[:-1]) + [n_samples]) else: u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device) # Invert CDF u = u.contiguous() inds = torch.searchsorted(cdf, u, right=True) below = torch.max(torch.zeros_like(inds - 1), inds - 1) above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) inds_g = torch.stack([below, above], -1) # (B, n_samples, 2) matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) denom = cdf_g[..., 1] - cdf_g[..., 0] denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) t = (u - cdf_g[..., 0]) / denom samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) return samples def plot_pointcloud(pc, color=None): # pc: [N, 3] # color: [N, 3/4] print("[visualize points]", pc.shape, 
pc.dtype, pc.min(0), pc.max(0)) pc = trimesh.PointCloud(pc, color) # axis axes = trimesh.creation.axis(axis_length=4) # sphere sphere = trimesh.creation.icosphere(radius=1) trimesh.Scene([pc, axes, sphere]).show() class NeRFRenderer(nn.Module): def __init__( self, bound=1, density_scale=1, # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance. min_near=0.2, min_near_lidar=0.2, density_thresh=0.01, bg_radius=-1, ): super().__init__() self.bound = bound self.cascade = 1 + math.ceil(math.log2(bound)) self.grid_size = 128 self.density_scale = density_scale self.min_near = min_near self.min_near_lidar = min_near_lidar self.density_thresh = density_thresh self.bg_radius = bg_radius # radius of the background sphere. # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax) # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing. aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound]) aabb_infer = aabb_train.clone() self.register_buffer("aabb_train", aabb_train) self.register_buffer("aabb_infer", aabb_infer) def forward(self, x, d): raise NotImplementedError() # separated density and color query (can accelerate non-cuda-ray mode.) 
def density(self, x): raise NotImplementedError() def color(self, x, d, mask=None, **kwargs): raise NotImplementedError() def run( self, rays_o, rays_d, cal_lidar_color=False, num_steps=128, upsample_steps=128, bg_color=None, perturb=False, **kwargs ): # rays_o, rays_d: [B, N, 3], assumes B == 1 # bg_color: [3] in range [0, 1] # return: image: [B, N, 3], depth: [B, N] if cal_lidar_color: self.out_dim = self.out_lidar_color_dim else: self.out_dim = self.out_color_dim prefix = rays_o.shape[:-1] rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) N = rays_o.shape[0] # N = B * N, in fact device = rays_o.device # choose aabb aabb = self.aabb_train if self.training else self.aabb_infer # sample steps if cal_lidar_color: nears = ( torch.ones(N, dtype=rays_o.dtype, device=rays_o.device) * self.min_near_lidar ) fars = ( torch.ones(N, dtype=rays_o.dtype, device=rays_o.device) * self.min_near_lidar * 81.0 ) # hard code else: nears, fars = raymarching.near_far_from_aabb( rays_o, rays_d, aabb, self.min_near ) nears.unsqueeze_(-1) fars.unsqueeze_(-1) # print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}') z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze( 0 ) # [1, T] z_vals = z_vals.expand((N, num_steps)) # [N, T] z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars] # perturb z_vals sample_dist = (fars - nears) / num_steps if perturb: z_vals = ( z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist ) # z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs. # generate xyzs xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze( -1 ) # [N, 1, 3] * [N, T, 1] -> [N, T, 3] xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip. 
# plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) # query SDF and RGB density_outputs = self.density(xyzs.reshape(-1, 3)) # sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T] for k, v in density_outputs.items(): density_outputs[k] = v.view(N, num_steps, -1) # upsample z_vals (nerf-like) if upsample_steps > 0: with torch.no_grad(): deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1] deltas = torch.cat( [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1 ) alphas = 1 - torch.exp( -deltas * self.density_scale * density_outputs["sigma"].squeeze(-1) ) # [N, T] alphas_shifted = torch.cat( [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1 ) # [N, T+1] weights = ( alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] ) # [N, T] z_vals_mid = z_vals[..., :-1] + 0.5 * deltas[..., :-1] # [N, T-1] new_z_vals = sample_pdf( z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training ).detach() # [N, t] new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze( -2 ) * new_z_vals.unsqueeze( -1 ) # [N, 1, 3] * [N, t, 1] -> [N, t, 3] new_xyzs = torch.min( torch.max(new_xyzs, aabb[:3]), aabb[3:] ) # a manual clip. 
# only forward new points to save computation new_density_outputs = self.density(new_xyzs.reshape(-1, 3)) # new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t] for k, v in new_density_outputs.items(): new_density_outputs[k] = v.view(N, upsample_steps, -1) # re-order z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t] z_vals, z_index = torch.sort(z_vals, dim=1) xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3] xyzs = torch.gather( xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs) ) for k in density_outputs: tmp_output = torch.cat( [density_outputs[k], new_density_outputs[k]], dim=1 ) density_outputs[k] = torch.gather( tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output) ) deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1] deltas = torch.cat( [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1 ) alphas = 1 - torch.exp( -deltas * self.density_scale * density_outputs["sigma"].squeeze(-1) ) # [N, T+t] alphas_shifted = torch.cat( [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1 ) # [N, T+t+1] weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t] dirs = rays_d.view(-1, 1, 3).expand_as(xyzs) for k, v in density_outputs.items(): density_outputs[k] = v.view(-1, v.shape[-1]) mask = weights > 1e-4 # hard coded rgbs = self.color( xyzs.reshape(-1, 3), dirs.reshape(-1, 3), cal_lidar_color=cal_lidar_color, mask=mask.reshape(-1), **density_outputs ) rgbs = rgbs.view(N, -1, self.out_dim) # [N, T+t, 3] # print(xyzs.shape, 'valid_rgb:', mask.sum().item()) # calculate weight_sum (mask) weights_sum = weights.sum(dim=-1) # [N] # calculate depth Note: not real depth!! 
# ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1) # depth = torch.sum(weights * ori_z_vals, dim=-1) depth = torch.sum(weights * z_vals, dim=-1) # calculate color image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1] # mix background color if self.bg_radius > 0: # use the bg model to calculate bg_color sph = raymarching.sph_from_ray( rays_o, rays_d, self.bg_radius ) # [N, 2] in [-1, 1] bg_color = self.background(sph, rays_d.reshape(-1, 3)) # [N, 3] elif bg_color is None: bg_color = 1 if not cal_lidar_color: image = image + (1 - weights_sum).unsqueeze(-1) * bg_color image = image.view(*prefix, self.out_dim) depth = depth.view(*prefix) # tmp: reg loss in mip-nerf 360 # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1) # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T] # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals_shifted) * (weights ** 2)).sum() return { "depth_lidar": depth, "image_lidar": image, "weights_sum_lidar": weights_sum, } def render( self, rays_o, rays_d, cal_lidar_color=False, staged=False, max_ray_batch=4096, **kwargs ): # rays_o, rays_d: [B, N, 3], assumes B == 1 # return: pred_rgb: [B, N, 3] _run = self.run B, N = rays_o.shape[:2] device = rays_o.device if staged: if cal_lidar_color: out_dim = self.out_lidar_color_dim res_keys = ["depth_lidar", "image_lidar"] depth = torch.empty((B, N), device=device) image = torch.empty((B, N, out_dim), device=device) for b in range(B): head = 0 while head < N: tail = min(head + max_ray_batch, N) results_ = _run( rays_o[b : b + 1, head:tail], rays_d[b : b + 1, head:tail], cal_lidar_color=cal_lidar_color, **kwargs ) depth[b : b + 1, head:tail] = results_[res_keys[0]] image[b : b + 1, head:tail] = results_[res_keys[1]] head += max_ray_batch results = {} results[res_keys[0]] = depth results[res_keys[1]] = image else: results = 
_run(rays_o, rays_d, cal_lidar_color=cal_lidar_color, **kwargs) return results ================================================ FILE: lidarnerf/nerf/utils.py ================================================ import glob import os import random import time import cv2 import imageio import lpips import mcubes import numpy as np import tensorboardX import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import tqdm import trimesh from rich.console import Console from skimage.metrics import structural_similarity from torch_ema import ExponentialMovingAverage from extern.chamfer3D.dist_chamfer_3D import chamfer_3DDist from extern.fscore import fscore from lidarnerf.dataset.base_dataset import custom_meshgrid from lidarnerf.convert import pano_to_lidar def is_ali_cluster(): import socket hostname = socket.gethostname() return "auto-drive" in hostname @torch.jit.script def linear_to_srgb(x): return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x**0.41666 - 0.055) @torch.jit.script def srgb_to_linear(x): return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4) def filter_bbox_dataset(pc, OBB_local): bbox_mask = np.isnan(pc[:, 0]) z_min, z_max = min(OBB_local[:, 2]), max(OBB_local[:, 2]) for i, (c1, c2) in enumerate(zip(pc[:, 2] <= z_max, pc[:, 2] >= z_min)): bbox_mask[i] = c1 and c2 pc = pc[bbox_mask] OBB_local = sorted(OBB_local, key=lambda p: p[2]) OBB_2D = np.array(OBB_local)[:4, :2] pc = filter_poly(pc, OBB_2D) return pc def filter_poly(pcs, OBB_2D): OBB_2D = sort_quadrilateral(OBB_2D) mask = [] for pc in pcs: mask.append(is_in_poly(pc[0], pc[1], OBB_2D)) return pcs[mask] def sort_quadrilateral(points): points = points.tolist() top_left = min(points, key=lambda p: p[0] + p[1]) bottom_right = max(points, key=lambda p: p[0] + p[1]) points.remove(top_left) points.remove(bottom_right) bottom_left, top_right = points if bottom_left[1] > top_right[1]: bottom_left, top_right = top_right, 
bottom_left return [top_left, top_right, bottom_right, bottom_left] def is_in_poly(px, py, poly): """ :param p: [x, y] :param poly: [[], [], [], [], ...] :return: """ is_in = False for i, corner in enumerate(poly): next_i = i + 1 if i + 1 < len(poly) else 0 x1, y1 = corner x2, y2 = poly[next_i] if (x1 == px and y1 == py) or (x2 == px and y2 == py): # if point is on vertex is_in = True break if min(y1, y2) < py <= max(y1, y2): # find horizontal edges of polygon x = x1 + (py - y1) * (x2 - x1) / (y2 - y1) if x == px: # if point is on edge is_in = True break elif x > px: # if point is on left-side of line is_in = not is_in return is_in def seed_everything(seed): random.seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) # torch.backends.cudnn.deterministic = True # torch.backends.cudnn.benchmark = True def torch_vis_2d(x, renormalize=False): # x: [3, H, W] or [1, H, W] or [H, W] import matplotlib.pyplot as plt import numpy as np import torch if isinstance(x, torch.Tensor): if len(x.shape) == 3: x = x.permute(1, 2, 0).squeeze() x = x.detach().cpu().numpy() print(f"[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}") x = x.astype(np.float32) # renormalize if renormalize: x = (x - x.min(axis=0, keepdims=True)) / ( x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8 ) plt.imshow(x) plt.show() def extract_fields(bound_min, bound_max, resolution, query_func, S=128): X = torch.linspace(bound_min[0], bound_max[0], resolution).split(S) Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(S) Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(S) u = np.zeros([resolution, resolution, resolution], dtype=np.float32) with torch.no_grad(): for xi, xs in enumerate(X): for yi, ys in enumerate(Y): for zi, zs in enumerate(Z): xx, yy, zz = custom_meshgrid(xs, ys, zs) pts = torch.cat( [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1, ) # [S, 3] val = ( 
query_func(pts) .reshape(len(xs), len(ys), len(zs)) .detach() .cpu() .numpy() ) # [S, 1] --> [x, y, z] u[ xi * S : xi * S + len(xs), yi * S : yi * S + len(ys), zi * S : zi * S + len(zs), ] = val return u def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): # print('threshold: {}'.format(threshold)) u = extract_fields(bound_min, bound_max, resolution, query_func) # print(u.shape, u.max(), u.min(), np.percentile(u, 50)) vertices, triangles = mcubes.marching_cubes(u, threshold) b_max_np = bound_max.detach().cpu().numpy() b_min_np = bound_min.detach().cpu().numpy() vertices = ( vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] ) return vertices, triangles class PSNRMeter: def __init__(self): self.V = 0 self.N = 0 def clear(self): self.V = 0 self.N = 0 def prepare_inputs(self, *inputs): outputs = [] for i, inp in enumerate(inputs): if torch.is_tensor(inp): inp = inp.detach().cpu().numpy() outputs.append(inp) return outputs def update(self, preds, truths): preds, truths = self.prepare_inputs( preds, truths ) # [B, N, 3] or [B, H, W, 3], range[0, 1] # simplified since max_pixel_value is 1 here. 
psnr = -10 * np.log10(np.mean((preds - truths) ** 2)) self.V += psnr self.N += 1 def measure(self): return self.V / self.N def write(self, writer, global_step, prefix=""): writer.add_scalar(os.path.join(prefix, "PSNR"), self.measure(), global_step) def report(self): return f"PSNR = {self.measure():.6f}" class RMSEMeter: def __init__(self): self.V = 0 self.N = 0 def clear(self): self.V = 0 self.N = 0 def prepare_inputs(self, *inputs): outputs = [] for i, inp in enumerate(inputs): if torch.is_tensor(inp): inp = inp.detach().cpu().numpy() outputs.append(inp) return outputs def update(self, preds, truths): preds, truths = self.prepare_inputs( preds, truths ) # [B, N, 3] or [B, H, W, 3], range[0, 1] rmse = (truths - preds) ** 2 rmse = np.sqrt(rmse.mean()) self.V += rmse self.N += 1 def measure(self): return self.V / self.N def write(self, writer, global_step, prefix=""): writer.add_scalar(os.path.join(prefix, "RMSE"), self.measure(), global_step) def report(self): return f"RMSE = {self.measure():.6f}" class MAEMeter: def __init__(self, intensity_inv_scale=1.0): self.V = 0 self.N = 0 self.intensity_inv_scale = intensity_inv_scale def clear(self): self.V = 0 self.N = 0 def prepare_inputs(self, *inputs): outputs = [] for i, inp in enumerate(inputs): if torch.is_tensor(inp): inp = inp.detach().cpu().numpy() outputs.append(inp) return outputs def update(self, preds, truths): preds, truths = self.prepare_inputs( preds, truths ) # [B, N, 3] or [B, H, W, 3], range[0, 1] # Mean Absolute Error mae = np.abs( truths * self.intensity_inv_scale - preds * self.intensity_inv_scale ).mean() self.V += mae self.N += 1 def measure(self): return self.V / self.N def write(self, writer, global_step, prefix=""): writer.add_scalar(os.path.join(prefix, "MAE"), self.measure(), global_step) def report(self): return f"MAE = {self.measure():.6f}" class DepthMeter: def __init__(self, scale): self.V = [] self.N = 0 self.scale = scale self.device = torch.device("cuda" if torch.cuda.is_available() else 
"cpu") def clear(self): self.V = [] self.N = 0 def prepare_inputs(self, *inputs): outputs = [] for i, inp in enumerate(inputs): if torch.is_tensor(inp): inp = inp.detach().cpu().numpy() outputs.append(inp) return outputs def update(self, preds, truths): preds = preds / self.scale truths = truths / self.scale preds, truths = self.prepare_inputs( preds, truths ) # [B, N, 3] or [B, H, W, 3], range[0, 1] # simplified since max_pixel_value is 1 here. depth_error = self.compute_depth_errors(truths, preds) depth_error = list(depth_error) self.V.append(depth_error) self.N += 1 def compute_depth_errors( self, gt, pred, min_depth=1e-3, max_depth=80, thresh_set=1.25 ): pred[pred < min_depth] = min_depth pred[pred > max_depth] = max_depth gt[gt < min_depth] = min_depth gt[gt > max_depth] = max_depth thresh = np.maximum((gt / pred), (pred / gt)) a1 = (thresh < thresh_set).mean() a2 = (thresh < thresh_set**2).mean() a3 = (thresh < thresh_set**3).mean() rmse = (gt - pred) ** 2 rmse = np.sqrt(rmse.mean()) ssim = structural_similarity( pred.squeeze(0), gt.squeeze(0), data_range=np.max(gt) - np.min(gt) ) return rmse, a1, a2, a3, ssim def measure(self): assert self.N == len(self.V) return np.array(self.V).mean(0) def write(self, writer, global_step, prefix=""): writer.add_scalar( os.path.join(prefix, "depth error"), self.measure()[0], global_step ) def report(self): return f"Depth_error(rmse, a1, a2, a3, ssim) = {self.measure()}" class PointsMeter: def __init__(self, scale, intrinsics): self.V = [] self.N = 0 self.scale = scale self.intrinsics = intrinsics def clear(self): self.V = [] self.N = 0 def prepare_inputs(self, *inputs): outputs = [] for i, inp in enumerate(inputs): if torch.is_tensor(inp): inp = inp.detach().cpu().numpy() outputs.append(inp) return outputs def update(self, preds, truths): preds = preds / self.scale truths = truths / self.scale preds, truths = self.prepare_inputs( preds, truths ) # [B, N, 3] or [B, H, W, 3], range[0, 1] chamLoss = chamfer_3DDist() pred_lidar 
= pano_to_lidar(preds[0], self.intrinsics) gt_lidar = pano_to_lidar(truths[0], self.intrinsics) dist1, dist2, idx1, idx2 = chamLoss( torch.FloatTensor(pred_lidar[None, ...]).cuda(), torch.FloatTensor(gt_lidar[None, ...]).cuda(), ) chamfer_dis = dist1.mean() + dist2.mean() threshold = 0.05 # monoSDF f_score, precision, recall = fscore(dist1, dist2, threshold) f_score = f_score.cpu()[0] self.V.append([chamfer_dis.cpu(), f_score]) self.N += 1 def measure(self): # return self.V / self.N assert self.N == len(self.V) return np.array(self.V).mean(0) def write(self, writer, global_step, prefix=""): writer.add_scalar(os.path.join(prefix, "CD"), self.measure()[0], global_step) def report(self): return f"CD f-score = {self.measure()}" class SSIMMeter: def __init__(self, device=None): self.V = 0 self.N = 0 self.device = ( device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu") ) def clear(self): self.V = 0 self.N = 0 # def prepare_inputs(self, *inputs): # outputs = [] # for i, inp in enumerate(inputs): # inp = inp.permute(0, 3, 1, 2).contiguous() # [B, 3, H, W] # inp = inp.to(self.device) # outputs.append(inp) # return outputs def prepare_inputs(self, *inputs): outputs = [] for i, inp in enumerate(inputs): if torch.is_tensor(inp): inp = inp.detach().cpu().numpy() outputs.append(inp) return outputs def update(self, preds, truths): preds, truths = self.prepare_inputs(preds, truths) ssim = structural_similarity( preds.squeeze(0).squeeze(-1), truths.squeeze(0).squeeze(-1) ) # preds, truths = self.prepare_inputs( # preds, truths) # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1] # ssim = structural_similarity_index_measure(preds, truths) self.V += ssim self.N += 1 def measure(self): return self.V / self.N def write(self, writer, global_step, prefix=""): writer.add_scalar(os.path.join(prefix, "SSIM"), self.measure(), global_step) def report(self): return f"SSIM = {self.measure():.6f}" class LPIPSMeter: def __init__(self, net="alex", 
device=None): self.V = 0 self.N = 0 self.net = net self.device = ( device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu") ) self.fn = lpips.LPIPS(net=net).eval().to(self.device) def clear(self): self.V = 0 self.N = 0 def prepare_inputs(self, *inputs): outputs = [] for i, inp in enumerate(inputs): inp = inp.permute(0, 3, 1, 2).contiguous() # [B, 3, H, W] inp = inp.to(self.device) outputs.append(inp) return outputs def update(self, preds, truths): preds, truths = self.prepare_inputs( preds, truths ) # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1] v = self.fn( truths, preds, normalize=True ).item() # normalize=True: [0, 1] to [-1, 1] self.V += v self.N += 1 def measure(self): return self.V / self.N def write(self, writer, global_step, prefix=""): writer.add_scalar( os.path.join(prefix, f"LPIPS ({self.net})"), self.measure(), global_step ) def report(self): return f"LPIPS ({self.net}) = {self.measure():.6f}" class Trainer(object): def __init__( self, name, # name of this experiment opt, # extra conf model, # network criterion=None, # loss function, if None, assume inline implementation in train_step optimizer=None, # optimizer ema_decay=None, # if use EMA, set the decay lr_scheduler=None, # scheduler metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric. depth_metrics=[], local_rank=0, # which GPU am I world_size=1, # total num of GPUs device=None, # device to use, usually setting to None is OK. 
(auto choose device) mute=False, # whether to mute all print fp16=False, # amp optimize level eval_interval=1, # eval once every $ epoch max_keep_ckpt=2, # max num of saved ckpts in disk workspace="workspace", # workspace to save logs & ckpts best_mode="min", # the smaller/larger result, the better use_loss_as_metric=True, # use loss as the first metric report_metric_at_train=False, # also report metrics at training use_checkpoint="latest", # which ckpt to use at init time use_tensorboardX=True, # whether to use tensorboard for logging scheduler_update_every_step=False, # whether to call scheduler.step() after every train step ): self.name = name self.opt = opt self.mute = mute self.metrics = metrics self.depth_metrics = depth_metrics self.local_rank = local_rank self.world_size = world_size self.workspace = workspace self.ema_decay = ema_decay self.fp16 = fp16 self.best_mode = best_mode self.use_loss_as_metric = use_loss_as_metric self.report_metric_at_train = report_metric_at_train self.max_keep_ckpt = max_keep_ckpt self.eval_interval = eval_interval self.use_checkpoint = use_checkpoint self.use_tensorboardX = use_tensorboardX self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S") self.scheduler_update_every_step = scheduler_update_every_step self.device = ( device if device is not None else torch.device( f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu" ) ) self.console = Console() model.to(self.device) if self.world_size > 1: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[local_rank] ) self.model = model if isinstance(criterion, nn.Module): criterion.to(self.device) self.criterion = criterion # optionally use LPIPS loss for patch-based training # if self.opt.patch_size > 1: # import lpips # self.criterion_lpips = lpips.LPIPS(net='alex').to(self.device) if optimizer is None: self.optimizer = optim.Adam( self.model.parameters(), lr=0.001, weight_decay=5e-4 ) # naive adam 
else: self.optimizer = optimizer(self.model) if lr_scheduler is None: self.lr_scheduler = optim.lr_scheduler.LambdaLR( self.optimizer, lr_lambda=lambda epoch: 1 ) # fake scheduler else: self.lr_scheduler = lr_scheduler(self.optimizer) if ema_decay is not None: self.ema = ExponentialMovingAverage( self.model.parameters(), decay=ema_decay ) else: self.ema = None self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16) # variable init self.epoch = 0 self.global_step = 0 self.local_step = 0 self.stats = { "loss": [], "valid_loss": [], "results": [], # metrics[0], or valid_loss "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt "best_result": None, } # auto fix if len(metrics) == 0 or self.use_loss_as_metric: self.best_mode = "min" # workspace prepare self.log_ptr = None if self.workspace is not None: os.makedirs(self.workspace, exist_ok=True) self.log_path = os.path.join(workspace, f"log_{self.name}.txt") self.log_ptr = open(self.log_path, "a+") self.ckpt_path = os.path.join(self.workspace, "checkpoints") self.best_path = f"{self.ckpt_path}/{self.name}.pth" os.makedirs(self.ckpt_path, exist_ok=True) self.log( f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}' ) self.log( f"[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}" ) if self.workspace is not None: if self.use_checkpoint == "scratch": self.log("[INFO] Training from scratch ...") elif self.use_checkpoint == "latest": self.log("[INFO] Loading latest checkpoint ...") self.load_checkpoint() elif self.use_checkpoint == "latest_model": self.log("[INFO] Loading latest checkpoint (model only)...") self.load_checkpoint(model_only=True) elif self.use_checkpoint == "best": if os.path.exists(self.best_path): self.log("[INFO] Loading best checkpoint ...") self.load_checkpoint(self.best_path) else: self.log(f"[INFO] {self.best_path} not found, loading latest ...") self.load_checkpoint() 
else: # path to ckpt self.log(f"[INFO] Loading {self.use_checkpoint} ...") self.load_checkpoint(self.use_checkpoint) def __del__(self): if self.log_ptr: self.log_ptr.close() def log(self, *args, **kwargs): if self.local_rank == 0: if not self.mute: # print(*args) self.console.print(*args, **kwargs) if self.log_ptr: print(*args, file=self.log_ptr) self.log_ptr.flush() # write immediately to file ### ------------------------------ def train_step(self, data): # Initialize all returned values pred_intensity = None gt_intensity = None pred_depth = None gt_depth = None loss = 0 if self.opt.enable_lidar: rays_o_lidar = data["rays_o_lidar"] # [B, N, 3] rays_d_lidar = data["rays_d_lidar"] # [B, N, 3] images_lidar = data["images_lidar"] # [B, N, 3/4] B_lidar, N_lidar, C_lidar = images_lidar.shape gt_raydrop = images_lidar[:, :, 0] gt_intensity = images_lidar[:, :, 1] * gt_raydrop gt_depth = images_lidar[:, :, 2] * gt_raydrop outputs_lidar = self.model.render( rays_o_lidar, rays_d_lidar, cal_lidar_color=True, staged=False, perturb=True, force_all_rays=False if self.opt.patch_size == 1 else True, **vars(self.opt), ) pred_raydrop = outputs_lidar["image_lidar"][:, :, 0] pred_intensity = outputs_lidar["image_lidar"][:, :, 1] * gt_raydrop pred_depth = outputs_lidar["depth_lidar"] * gt_raydrop lidar_loss = ( self.opt.alpha_d * self.criterion["depth"](pred_depth, gt_depth) + self.opt.alpha_r * self.criterion["raydrop"](pred_raydrop, gt_raydrop) + self.opt.alpha_i * self.criterion["intensity"](pred_intensity, gt_intensity) ) pred_intensity = pred_intensity.unsqueeze(-1) gt_intensity = gt_intensity.unsqueeze(-1) else: lidar_loss = 0 loss = lidar_loss # special case for CCNeRF's rank-residual training if len(loss.shape) == 3: # [K, B, N] loss = loss.mean(0) loss = loss.mean() if isinstance(self.opt.patch_size_lidar, int): patch_size_x, patch_size_y = ( self.opt.patch_size_lidar, self.opt.patch_size_lidar, ) elif len(self.opt.patch_size_lidar) == 1: patch_size_x, patch_size_y = ( 
self.opt.patch_size_lidar[0], self.opt.patch_size_lidar[0], ) else: patch_size_x, patch_size_y = self.opt.patch_size_lidar if self.opt.enable_lidar and patch_size_x > 1: pred_depth = ( pred_depth.view(-1, patch_size_x, patch_size_y, 1) .permute(0, 3, 1, 2) .contiguous() / self.opt.scale ) if self.opt.sobel_grad: pred_grad_x = F.conv2d( pred_depth, torch.tensor( [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32 ) .unsqueeze(0) .unsqueeze(0) .to(self.device), padding=1, ) pred_grad_y = F.conv2d( pred_depth, torch.tensor( [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32 ) .unsqueeze(0) .unsqueeze(0) .to(self.device), padding=1, ) else: pred_grad_y = torch.abs( pred_depth[:, :, :-1, :] - pred_depth[:, :, 1:, :] ) pred_grad_x = torch.abs( pred_depth[:, :, :, :-1] - pred_depth[:, :, :, 1:] ) dy = torch.abs(pred_grad_y) dx = torch.abs(pred_grad_x) if self.opt.grad_norm_smooth: grad_norm = torch.mean(torch.exp(-dx)) + torch.mean(torch.exp(-dy)) # print('grad_norm', grad_norm) loss = loss + self.opt.alpha_grad_norm * grad_norm if self.opt.spatial_smooth: spatial_loss = torch.mean(dx**2) + torch.mean(dy**2) # print('spatial_loss', spatial_loss) loss = loss + self.opt.alpha_spatial * spatial_loss if self.opt.tv_loss: tv_loss = torch.mean(dx) + torch.mean(dy) # print('tv_loss', tv_loss) loss = loss + self.opt.alpha_tv * tv_loss if self.opt.grad_loss: gt_depth = ( gt_depth.view(-1, patch_size_x, patch_size_y, 1) .permute(0, 3, 1, 2) .contiguous() / self.opt.scale ) gt_raydrop = ( gt_raydrop.view(-1, patch_size_x, patch_size_y, 1) .permute(0, 3, 1, 2) .contiguous() ) # sobel if self.opt.sobel_grad: gt_grad_y = F.conv2d( gt_depth, torch.tensor( [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32 ) .unsqueeze(0) .unsqueeze(0) .to(self.device), padding=1, ) gt_grad_x = F.conv2d( gt_depth, torch.tensor( [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32 ) .unsqueeze(0) .unsqueeze(0) .to(self.device), padding=1, ) else: gt_grad_y = gt_depth[:, :, :-1, :] 
- gt_depth[:, :, 1:, :] gt_grad_x = gt_depth[:, :, :, :-1] - gt_depth[:, :, :, 1:] grad_clip_x = 0.01 grad_mask_x = torch.where(torch.abs(gt_grad_x) < grad_clip_x, 1, 0) grad_clip_y = 0.01 grad_mask_y = torch.where(torch.abs(gt_grad_y) < grad_clip_y, 1, 0) if self.opt.sobel_grad: mask_dx = gt_raydrop * grad_mask_x mask_dy = gt_raydrop * grad_mask_y else: mask_dx = gt_raydrop[:, :, :, :-1] * grad_mask_x mask_dy = gt_raydrop[:, :, :-1, :] * grad_mask_y if self.opt.depth_grad_loss == "cos": patch_num = pred_grad_x.shape[0] grad_loss = self.criterion["grad"]( (pred_grad_x * mask_dx).reshape(patch_num, -1), (gt_grad_x * mask_dx).reshape(patch_num, -1), ) grad_loss = 1 - grad_loss else: grad_loss = self.criterion["grad"]( pred_grad_x * mask_dx, gt_grad_x * mask_dx ) loss = loss + self.opt.alpha_grad * grad_loss.mean() return ( pred_intensity, gt_intensity, pred_depth, gt_depth, loss, ) def eval_step(self, data): pred_intensity = None pred_depth = None pred_depth_crop = None pred_raydrop = None gt_intensity = None gt_depth = None gt_depth_crop = None gt_raydrop = None loss = 0 if self.opt.enable_lidar: rays_o_lidar = data["rays_o_lidar"] # [B, N, 3] rays_d_lidar = data["rays_d_lidar"] # [B, N, 3] images_lidar = data["images_lidar"] # [B, H, W, 3/4] gt_raydrop = images_lidar[:, :, :, 0] if self.opt.dataloader == "nerf_mvl": valid_crop = gt_raydrop != -1 valid_crop_idx = torch.nonzero(valid_crop) crop_h, crop_w = ( max(valid_crop_idx[:, 1]) - min(valid_crop_idx[:, 1]) + 1, max(valid_crop_idx[:, 2]) - min(valid_crop_idx[:, 2]) + 1, ) valid_mask = torch.where(gt_raydrop == -1, 0, 1) gt_raydrop = gt_raydrop * valid_mask gt_intensity = images_lidar[:, :, :, 1] * gt_raydrop gt_depth = images_lidar[:, :, :, 2] * gt_raydrop B_lidar, H_lidar, W_lidar, C_lidar = images_lidar.shape outputs_lidar = self.model.render( rays_o_lidar, rays_d_lidar, cal_lidar_color=True, staged=True, perturb=False, **vars(self.opt), ) pred_rgb_lidar = outputs_lidar["image_lidar"].reshape( B_lidar, H_lidar, 
W_lidar, 2 ) pred_raydrop = pred_rgb_lidar[:, :, :, 0] raydrop_mask = torch.where(pred_raydrop > 0.5, 1, 0) if self.opt.dataloader == "nerf_mvl": raydrop_mask = raydrop_mask * valid_mask pred_intensity = pred_rgb_lidar[:, :, :, 1] pred_depth = outputs_lidar["depth_lidar"].reshape(B_lidar, H_lidar, W_lidar) # raydrop_mask = gt_raydrop # TODO if self.opt.alpha_r > 0 and (not torch.all(raydrop_mask == 0)): pred_intensity = pred_intensity * raydrop_mask pred_depth = pred_depth * raydrop_mask lidar_loss = ( self.opt.alpha_d * self.criterion["depth"](pred_depth, gt_depth).mean() + self.opt.alpha_r * self.criterion["raydrop"](pred_raydrop, gt_raydrop).mean() + self.opt.alpha_i * self.criterion["intensity"](pred_intensity, gt_intensity).mean() ) if self.opt.dataloader == "nerf_mvl": pred_intensity = pred_intensity[valid_crop].reshape( B_lidar, crop_h, crop_w ) gt_intensity = gt_intensity[valid_crop].reshape(B_lidar, crop_h, crop_w) pred_depth_crop = pred_depth[valid_crop].reshape( B_lidar, crop_h, crop_w ) gt_depth_crop = gt_depth[valid_crop].reshape(B_lidar, crop_h, crop_w) pred_intensity = pred_intensity.unsqueeze(-1) pred_raydrop = pred_raydrop.unsqueeze(-1) gt_intensity = gt_intensity.unsqueeze(-1) gt_raydrop = gt_raydrop.unsqueeze(-1) else: lidar_loss = 0 loss = lidar_loss return ( pred_intensity, pred_depth, pred_depth_crop, pred_raydrop, gt_intensity, gt_depth, gt_depth_crop, gt_raydrop, loss, ) # moved out bg_color and perturb for more flexible control... 
def test_step(self, data, bg_color=None, perturb=False): pred_raydrop = None pred_intensity = None pred_depth = None if self.opt.enable_lidar: rays_o_lidar = data["rays_o_lidar"] # [B, N, 3] rays_d_lidar = data["rays_d_lidar"] # [B, N, 3] H_lidar, W_lidar = data["H_lidar"], data["W_lidar"] outputs_lidar = self.model.render( rays_o_lidar, rays_d_lidar, cal_lidar_color=True, staged=True, perturb=perturb, **vars(self.opt), ) pred_rgb_lidar = outputs_lidar["image_lidar"].reshape( -1, H_lidar, W_lidar, 2 ) pred_raydrop = pred_rgb_lidar[:, :, :, 0] raydrop_mask = torch.where(pred_raydrop > 0.5, 1, 0) pred_intensity = pred_rgb_lidar[:, :, :, 1] pred_depth = outputs_lidar["depth_lidar"].reshape(-1, H_lidar, W_lidar) if self.opt.alpha_r > 0: pred_intensity = pred_intensity * raydrop_mask pred_depth = pred_depth * raydrop_mask return pred_raydrop, pred_intensity, pred_depth def save_mesh(self, save_path=None, resolution=256, threshold=10): if save_path is None: save_path = os.path.join( self.workspace, "meshes", f"{self.name}_{self.epoch}.ply" ) self.log(f"==> Saving mesh to {save_path}") os.makedirs(os.path.dirname(save_path), exist_ok=True) def query_func(pts): with torch.no_grad(): with torch.cuda.amp.autocast(enabled=self.fp16): sigma = self.model.density(pts.to(self.device))["sigma"] return sigma vertices, triangles = extract_geometry( self.model.aabb_infer[:3], self.model.aabb_infer[3:], resolution=resolution, threshold=threshold, query_func=query_func, ) mesh = trimesh.Trimesh( vertices, triangles, process=False ) # important, process=True leads to seg fault... 
mesh.export(save_path) self.log(f"==> Finished saving mesh.") ### ------------------------------ def train(self, train_loader, valid_loader, max_epochs): if self.use_tensorboardX and self.local_rank == 0: if is_ali_cluster() and self.opt.cluster_summary_path is not None: summary_path = self.opt.cluster_summary_path else: summary_path = os.path.join(self.workspace, "run", self.name) self.writer = tensorboardX.SummaryWriter(summary_path) change_dataloder = False if self.opt.change_patch_size_lidar[0] > 1: change_dataloder = True for epoch in range(self.epoch + 1, max_epochs + 1): self.epoch = epoch if change_dataloder: if self.epoch % self.opt.change_patch_size_epoch == 0: train_loader._data.patch_size_lidar = ( self.opt.change_patch_size_lidar ) self.opt.patch_size_lidar = self.opt.change_patch_size_lidar else: train_loader._data.patch_size_lidar = 1 self.opt.patch_size_lidar = 1 self.train_one_epoch(train_loader) if self.workspace is not None and self.local_rank == 0: self.save_checkpoint(full=True, best=False) if self.epoch % self.eval_interval == 0: self.evaluate_one_epoch(valid_loader) self.save_checkpoint(full=False, best=True) if self.use_tensorboardX and self.local_rank == 0: self.writer.close() def evaluate(self, loader, name=None): self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX self.evaluate_one_epoch(loader, name) self.use_tensorboardX = use_tensorboardX def test(self, loader, save_path=None, name=None, write_video=True): if save_path is None: save_path = os.path.join(self.workspace, "results") if name is None: name = f"{self.name}_ep{self.epoch:04d}" os.makedirs(save_path, exist_ok=True) self.log(f"==> Start Test, save results to {save_path}") pbar = tqdm.tqdm( total=len(loader) * loader.batch_size, bar_format="{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", ) self.model.eval() if write_video: all_preds = [] all_preds_depth = [] with torch.no_grad(): for i, data in enumerate(loader): with 
torch.cuda.amp.autocast(enabled=self.fp16): preds_raydrop, preds_intensity, preds_depth = self.test_step(data) if self.opt.enable_lidar: pred_raydrop = preds_raydrop[0].detach().cpu().numpy() pred_raydrop = (np.where(pred_raydrop > 0.5, 1.0, 0.0)).reshape( loader._data.H_lidar, loader._data.W_lidar ) pred_raydrop = (pred_raydrop * 255).astype(np.uint8) pred_intensity = preds_intensity[0].detach().cpu().numpy() pred_intensity = (pred_intensity * 255).astype(np.uint8) pred_depth = preds_depth[0].detach().cpu().numpy() pred_lidar = pano_to_lidar( pred_depth / self.opt.scale, loader._data.intrinsics_lidar ) if self.opt.dataloader == "nerf_mvl": pred_lidar = filter_bbox_dataset( pred_lidar, data["OBB_local"][:, :3] ) np.save( os.path.join(save_path, f"test_{name}_{i:04d}_depth_lidar.npy"), pred_lidar, ) pred_depth = (pred_depth * 255).astype(np.uint8) # pred_depth = (pred_depth / self.opt.scale).astype(np.uint8) if write_video: all_preds.append(cv2.applyColorMap(pred_intensity, 1)) all_preds_depth.append(cv2.applyColorMap(pred_depth, 9)) else: cv2.imwrite( os.path.join(save_path, f"test_{name}_{i:04d}_raydrop.png"), pred_raydrop, ) cv2.imwrite( os.path.join( save_path, f"test_{name}_{i:04d}_intensity.png" ), cv2.applyColorMap(pred_intensity, 1), ) cv2.imwrite( os.path.join(save_path, f"test_{name}_{i:04d}_depth.png"), cv2.applyColorMap(pred_depth, 9), ) pbar.update(loader.batch_size) if write_video: if self.opt.enable_lidar: all_preds = np.stack(all_preds, axis=0) all_preds_depth = np.stack(all_preds_depth, axis=0) imageio.mimwrite( os.path.join(save_path, f"{name}_lidar_rgb.mp4"), all_preds, fps=25, quality=8, macro_block_size=1, ) imageio.mimwrite( os.path.join(save_path, f"{name}_depth.mp4"), all_preds_depth, fps=25, quality=8, macro_block_size=1, ) self.log(f"==> Finished Test.") def train_one_epoch(self, loader): self.log( f"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ..." 
) total_loss = 0 if self.local_rank == 0 and self.report_metric_at_train: for metric in self.metrics: metric.clear() for metric in self.depth_metrics: metric.clear() self.model.train() # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs # ref: https://pytorch.org/docs/stable/data.html if self.world_size > 1: loader.sampler.set_epoch(self.epoch) if self.local_rank == 0: pbar = tqdm.tqdm( total=len(loader) * loader.batch_size, bar_format="{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", ) self.local_step = 0 for data in loader: self.local_step += 1 self.global_step += 1 self.optimizer.zero_grad() with torch.cuda.amp.autocast(enabled=self.fp16): ( pred_intensity, gt_intensity, pred_depth, gt_depth, loss, ) = self.train_step(data) self.scaler.scale(loss).backward() self.scaler.step(self.optimizer) self.scaler.update() if self.scheduler_update_every_step: self.lr_scheduler.step() loss_val = loss.item() total_loss += loss_val if self.local_rank == 0: if self.report_metric_at_train: for i, metric in enumerate(self.depth_metrics): if i < 2: # hard code metric.update(pred_intensity, gt_intensity) else: metric.update(pred_depth, gt_depth) if self.use_tensorboardX: self.writer.add_scalar("train/loss", loss_val, self.global_step) self.writer.add_scalar( "train/lr", self.optimizer.param_groups[0]["lr"], self.global_step, ) if self.scheduler_update_every_step: pbar.set_description( f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}" ) else: pbar.set_description( f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})" ) pbar.update(loader.batch_size) if self.ema is not None: self.ema.update() average_loss = total_loss / self.local_step self.stats["loss"].append(average_loss) if self.local_rank == 0: pbar.close() if self.report_metric_at_train: for metric in self.depth_metrics: self.log(metric.report(), style="red") if self.use_tensorboardX: 
metric.write(self.writer, self.epoch, prefix="LiDAR_train") metric.clear() if not self.scheduler_update_every_step: if isinstance( self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau ): self.lr_scheduler.step(average_loss) else: self.lr_scheduler.step() self.log(f"==> Finished Epoch {self.epoch}.") def evaluate_one_epoch(self, loader, name=None): self.log(f"++> Evaluate at epoch {self.epoch} ...") if name is None: name = f"{self.name}_ep{self.epoch:04d}" total_loss = 0 if self.local_rank == 0: for metric in self.metrics: metric.clear() for metric in self.depth_metrics: metric.clear() self.model.eval() if self.ema is not None: self.ema.store() self.ema.copy_to() if self.local_rank == 0: pbar = tqdm.tqdm( total=len(loader) * loader.batch_size, bar_format="{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", ) with torch.no_grad(): self.local_step = 0 for data in loader: self.local_step += 1 with torch.cuda.amp.autocast(enabled=self.fp16): ( preds_intensity, preds_depth, preds_depth_crop, preds_raydrop, gt_intensity, gt_depth, gt_depth_crop, gt_raydrop, loss, ) = self.eval_step(data) # all_gather/reduce the statistics (NCCL only support all_*) if self.world_size > 1: dist.all_reduce(loss, op=dist.ReduceOp.SUM) loss = loss / self.world_size preds_list = [ torch.zeros_like(preds).to(self.device) for _ in range(self.world_size) ] # [[B, ...], [B, ...], ...] dist.all_gather(preds_list, preds) preds = torch.cat(preds_list, dim=0) preds_depth_list = [ torch.zeros_like(preds_depth).to(self.device) for _ in range(self.world_size) ] # [[B, ...], [B, ...], ...] dist.all_gather(preds_depth_list, preds_depth) preds_depth = torch.cat(preds_depth_list, dim=0) truths_list = [ torch.zeros_like(truths).to(self.device) for _ in range(self.world_size) ] # [[B, ...], [B, ...], ...] dist.all_gather(truths_list, truths) truths = torch.cat(truths_list, dim=0) loss_val = loss.item() total_loss += loss_val # only rank = 0 will perform evaluation. 
if self.local_rank == 0: for i, metric in enumerate(self.depth_metrics): if i < 2: # hard code metric.update(preds_intensity, gt_intensity) else: if ( self.opt.dataloader == "nerf_mvl" and i == 2 ): # hard code metric.update(preds_depth_crop, gt_depth_crop) else: metric.update(preds_depth, gt_depth) if self.opt.enable_lidar: save_path_raydrop = os.path.join( self.workspace, "validation", f"{name}_{self.local_step:04d}_rarydrop.png", ) save_path_intensity = os.path.join( self.workspace, "validation", f"{name}_{self.local_step:04d}_intensity.png", ) save_path_depth = os.path.join( self.workspace, "validation", f"{name}_{self.local_step:04d}_depth.png", ) os.makedirs(os.path.dirname(save_path_depth), exist_ok=True) pred_intensity = preds_intensity[0].detach().cpu().numpy() pred_intensity = (pred_intensity * 255).astype(np.uint8) pred_raydrop = preds_raydrop[0].detach().cpu().numpy() pred_raydrop = (np.where(pred_raydrop > 0.5, 1.0, 0.0)).reshape( loader._data.H_lidar, loader._data.W_lidar ) pred_raydrop = (pred_raydrop * 255).astype(np.uint8) pred_depth = preds_depth[0].detach().cpu().numpy() pred_lidar = pano_to_lidar( pred_depth / self.opt.scale, loader._data.intrinsics_lidar ) pred_depth = (pred_depth * 255).astype(np.uint8) # pred_depth = (pred_depth / self.opt.scale).astype(np.uint8) # cv2.imwrite(save_path, cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)) cv2.imwrite(save_path_raydrop, pred_raydrop) cv2.imwrite( save_path_intensity, cv2.applyColorMap(pred_intensity, 1) ) cv2.imwrite(save_path_depth, cv2.applyColorMap(pred_depth, 9)) np.save( os.path.join( self.workspace, "validation", f"{name}_{self.local_step:04d}_lidar.npy", ), pred_lidar, ) pbar.set_description( f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})" ) pbar.update(loader.batch_size) average_loss = total_loss / self.local_step self.stats["valid_loss"].append(average_loss) if self.local_rank == 0: pbar.close() if len(self.depth_metrics) > 0: # result = self.metrics[0].measure() result = 
self.depth_metrics[-1].measure()[0] # hard code self.stats["results"].append( result if self.best_mode == "min" else -result ) # if max mode, use -result else: self.stats["results"].append( average_loss ) # if no metric, choose best by min loss for metric in self.depth_metrics: self.log(metric.report(), style="blue") if self.use_tensorboardX: metric.write(self.writer, self.epoch, prefix="LiDAR_evaluate") metric.clear() if self.ema is not None: self.ema.restore() self.log(f"++> Evaluate epoch {self.epoch} Finished.") def save_checkpoint(self, name=None, full=False, best=False, remove_old=True): if name is None: name = f"{self.name}_ep{self.epoch:04d}" state = { "epoch": self.epoch, "global_step": self.global_step, "stats": self.stats, } if full: state["optimizer"] = self.optimizer.state_dict() state["lr_scheduler"] = self.lr_scheduler.state_dict() state["scaler"] = self.scaler.state_dict() if self.ema is not None: state["ema"] = self.ema.state_dict() if not best: state["model"] = self.model.state_dict() file_path = f"{self.ckpt_path}/{name}.pth" if remove_old: self.stats["checkpoints"].append(file_path) if len(self.stats["checkpoints"]) > self.max_keep_ckpt: old_ckpt = self.stats["checkpoints"].pop(0) if os.path.exists(old_ckpt): os.remove(old_ckpt) torch.save(state, file_path) else: if len(self.stats["results"]) > 0: if ( self.stats["best_result"] is None or self.stats["results"][-1] < self.stats["best_result"] ): self.log( f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}" ) self.stats["best_result"] = self.stats["results"][-1] # save ema results if self.ema is not None: self.ema.store() self.ema.copy_to() state["model"] = self.model.state_dict() # we don't consider continued training from the best ckpt, so we discard the unneeded density_grid to save some storage (especially important for dnerf) if "density_grid" in state["model"]: del state["model"]["density_grid"] if self.ema is not None: self.ema.restore() torch.save(state, 
self.best_path) else: self.log( f"[WARN] no evaluated results found, skip saving best checkpoint." ) def load_checkpoint(self, checkpoint=None, model_only=False): if checkpoint is None: checkpoint_list = sorted(glob.glob(f"{self.ckpt_path}/{self.name}_ep*.pth")) if checkpoint_list: checkpoint = checkpoint_list[-1] self.log(f"[INFO] Latest checkpoint is {checkpoint}") else: self.log("[WARN] No checkpoint found, model randomly initialized.") return checkpoint_dict = torch.load(checkpoint, map_location=self.device) if "model" not in checkpoint_dict: self.model.load_state_dict(checkpoint_dict) self.log("[INFO] loaded model.") return missing_keys, unexpected_keys = self.model.load_state_dict( checkpoint_dict["model"], strict=False ) self.log("[INFO] loaded model.") if len(missing_keys) > 0: self.log(f"[WARN] missing keys: {missing_keys}") if len(unexpected_keys) > 0: self.log(f"[WARN] unexpected keys: {unexpected_keys}") if self.ema is not None and "ema" in checkpoint_dict: self.ema.load_state_dict(checkpoint_dict["ema"]) if model_only: return self.stats = checkpoint_dict["stats"] self.epoch = checkpoint_dict["epoch"] self.global_step = checkpoint_dict["global_step"] self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}") if self.optimizer and "optimizer" in checkpoint_dict: try: self.optimizer.load_state_dict(checkpoint_dict["optimizer"]) self.log("[INFO] loaded optimizer.") except: self.log("[WARN] Failed to load optimizer.") if self.lr_scheduler and "lr_scheduler" in checkpoint_dict: try: self.lr_scheduler.load_state_dict(checkpoint_dict["lr_scheduler"]) self.log("[INFO] loaded scheduler.") except: self.log("[WARN] Failed to load scheduler.") if self.scaler and "scaler" in checkpoint_dict: try: self.scaler.load_state_dict(checkpoint_dict["scaler"]) self.log("[INFO] loaded scaler.") except: self.log("[WARN] Failed to load scaler.") ================================================ FILE: lidarnerf/raymarching/__init__.py 
================================================ ================================================ FILE: lidarnerf/raymarching/backend.py ================================================ import os from torch.utils.cpp_extension import load _src_path = os.path.dirname(os.path.abspath(__file__)) nvcc_flags = [ "-O3", "-std=c++14", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", ] if os.name == "posix": c_flags = ["-O3", "-std=c++14"] elif os.name == "nt": c_flags = ["/O2", "/std:c++17"] # find cl.exe def find_cl_path(): import glob for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: paths = sorted( glob.glob( r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition ), reverse=True, ) if paths: return paths[0] # If cl.exe is not on path, try to find it. if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError( "Could not locate a supported Microsoft Visual C++ installation" ) os.environ["PATH"] += ";" + cl_path _backend = load( name="_raymarching", extra_cflags=c_flags, extra_cuda_cflags=nvcc_flags, sources=[ os.path.join(_src_path, "src", f) for f in [ "raymarching.cu", "bindings.cpp", ] ], ) __all__ = ["_backend"] ================================================ FILE: lidarnerf/raymarching/raymarching.py ================================================ import torch from torch.autograd import Function from torch.cuda.amp import custom_bwd, custom_fwd try: import _raymarching as _backend except ImportError: from .backend import _backend # ---------------------------------------- # utils # ---------------------------------------- class _near_far_from_aabb(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, rays_o, rays_d, aabb, min_near=0.2): """near_far_from_aabb, CUDA implementation Calculate rays' intersection time (near and far) with aabb Args: rays_o: float, [N, 
3] rays_d: float, [N, 3] aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax) min_near: float, scalar Returns: nears: float, [N] fars: float, [N] """ if not rays_o.is_cuda: rays_o = rays_o.cuda() if not rays_d.is_cuda: rays_d = rays_d.cuda() rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) N = rays_o.shape[0] # num rays nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars) return nears, fars near_far_from_aabb = _near_far_from_aabb.apply class _sph_from_ray(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, rays_o, rays_d, radius): """sph_from_ray, CUDA implementation get spherical coordinate on the background sphere from rays. Assume rays_o are inside the Sphere(radius). Args: rays_o: [N, 3] rays_d: [N, 3] radius: scalar, float Return: coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface) """ if not rays_o.is_cuda: rays_o = rays_o.cuda() if not rays_d.is_cuda: rays_d = rays_d.cuda() rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) N = rays_o.shape[0] # num rays coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device) _backend.sph_from_ray(rays_o, rays_d, radius, N, coords) return coords sph_from_ray = _sph_from_ray.apply class _morton3D(Function): @staticmethod def forward(ctx, coords): """morton3D, CUDA implementation Args: coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...) TODO: check if the coord range is valid! 
(current 128 is safe) Returns: indices: [N], int32, in [0, 128^3) """ if not coords.is_cuda: coords = coords.cuda() N = coords.shape[0] indices = torch.empty(N, dtype=torch.int32, device=coords.device) _backend.morton3D(coords.int(), N, indices) return indices morton3D = _morton3D.apply class _morton3D_invert(Function): @staticmethod def forward(ctx, indices): """morton3D_invert, CUDA implementation Args: indices: [N], int32, in [0, 128^3) Returns: coords: [N, 3], int32, in [0, 128) """ if not indices.is_cuda: indices = indices.cuda() N = indices.shape[0] coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device) _backend.morton3D_invert(indices.int(), N, coords) return coords morton3D_invert = _morton3D_invert.apply class _packbits(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, grid, thresh, bitfield=None): """packbits, CUDA implementation Pack up the density grid into a bit field to accelerate ray marching. Args: grid: float, [C, H * H * H], assume H % 2 == 0 thresh: float, threshold Returns: bitfield: uint8, [C, H * H * H / 8] """ if not grid.is_cuda: grid = grid.cuda() grid = grid.contiguous() C = grid.shape[0] H3 = grid.shape[1] N = C * H3 // 8 if bitfield is None: bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device) _backend.packbits(grid, N, thresh, bitfield) return bitfield packbits = _packbits.apply # ---------------------------------------- # train functions # ---------------------------------------- class _march_rays_train(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward( ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024, ): """march rays to generate points (forward only) Args: rays_o/d: float, [N, 3] bound: float, scalar density_bitfield: uint8: [CHHH // 8] C: int H: int nears/fars: float, [N] step_counter: int32, (2), used to count the actual number 
of generated points. mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.) perturb: bool align: int, pad output so its size is dividable by align, set to -1 to disable. force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays. dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) max_steps: int, max number of sampled points along each ray, also affect min_stepsize. Returns: xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray) dirs: float, [M, 3], all generated points' view dirs. deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth) rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0] """ if not rays_o.is_cuda: rays_o = rays_o.cuda() if not rays_d.is_cuda: rays_d = rays_d.cuda() if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda() rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) density_bitfield = density_bitfield.contiguous() N = rays_o.shape[0] # num rays M = N * max_steps # init max points number in total # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp) # It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated. 
if not force_all_rays and mean_count > 0: if align > 0: mean_count += align - mean_count % align M = mean_count xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) rays = torch.empty( N, 3, dtype=torch.int32, device=rays_o.device ) # id, offset, num_steps if step_counter is None: step_counter = torch.zeros( 2, dtype=torch.int32, device=rays_o.device ) # point counter, ray counter if perturb: noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device) else: noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device) _backend.march_rays_train( rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises, ) # m is the actually used points number # print(step_counter, M) # only used at the first (few) epochs. if force_all_rays or mean_count <= 0: m = step_counter[0].item() # D2H copy if align > 0: m += align - m % align xyzs = xyzs[:m] dirs = dirs[:m] deltas = deltas[:m] torch.cuda.empty_cache() return xyzs, dirs, deltas, rays march_rays_train = _march_rays_train.apply class _composite_rays_train(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4): """composite rays' rgbs, according to the ray marching formula. Args: rgbs: float, [M, 3] sigmas: float, [M,] deltas: float, [M, 2] rays: int32, [N, 3] Returns: weights_sum: float, [N,], the alpha channel depth: float, [N, ], the Depth image: float, [N, 3], the RGB channel (after multiplying alpha!) 
""" sigmas = sigmas.contiguous() rgbs = rgbs.contiguous() M = sigmas.shape[0] N = rays.shape[0] weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) _backend.composite_rays_train_forward( sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image ) ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image) ctx.dims = [M, N, T_thresh] return weights_sum, depth, image @staticmethod @custom_bwd def backward(ctx, grad_weights_sum, grad_depth, grad_image): # NOTE: grad_depth is not used now! It won't be propagated to sigmas. grad_weights_sum = grad_weights_sum.contiguous() grad_image = grad_image.contiguous() sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors M, N, T_thresh = ctx.dims grad_sigmas = torch.zeros_like(sigmas) grad_rgbs = torch.zeros_like(rgbs) _backend.composite_rays_train_backward( grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs, ) return grad_sigmas, grad_rgbs, None, None, None composite_rays_train = _composite_rays_train.apply # ---------------------------------------- # infer functions # ---------------------------------------- class _march_rays(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward( ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024, ): """march rays to generate points (forward only, for inference) Args: n_alive: int, number of alive rays n_step: int, how many steps we march rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) rays_t: float, [N], the alive rays' time, we only use the first n_alive. 
rays_o/d: float, [N, 3] bound: float, scalar density_bitfield: uint8: [CHHH // 8] C: int H: int nears/fars: float, [N] align: int, pad output so its size is dividable by align, set to -1 to disable. perturb: bool/int, int > 0 is used as the random seed. dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) max_steps: int, max number of sampled points along each ray, also affect min_stepsize. Returns: xyzs: float, [n_alive * n_step, 3], all generated points' coords dirs: float, [n_alive * n_step, 3], all generated points' view dirs. deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). """ if not rays_o.is_cuda: rays_o = rays_o.cuda() if not rays_d.is_cuda: rays_d = rays_d.cuda() rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) M = n_alive * n_step if align > 0: M += align - (M % align) xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) deltas = torch.zeros( M, 2, dtype=rays_o.dtype, device=rays_o.device ) # 2 vals, one for rgb, one for depth if perturb: # torch.manual_seed(perturb) # test_gui uses spp index as seed noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device) else: noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device) _backend.march_rays( n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises, ) return xyzs, dirs, deltas march_rays = _march_rays.apply class _composite_rays(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float def forward( ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2, ): """composite rays' rgbs, according to 
the ray marching formula. (for inference) Args: n_alive: int, number of alive rays n_step: int, how many steps we march rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive) rays_t: float, [N], the alive rays' time sigmas: float, [n_alive * n_step,] rgbs: float, [n_alive * n_step, 3] deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). In-place Outputs: weights_sum: float, [N,], the alpha channel depth: float, [N,], the depth value image: float, [N, 3], the RGB channel (after multiplying alpha!) """ _backend.composite_rays( n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, ) return tuple() composite_rays = _composite_rays.apply ================================================ FILE: lidarnerf/raymarching/setup.py ================================================ import os from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension _src_path = os.path.dirname(os.path.abspath(__file__)) nvcc_flags = [ "-O3", "-std=c++14", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", ] if os.name == "posix": c_flags = ["-O3", "-std=c++14"] elif os.name == "nt": c_flags = ["/O2", "/std:c++17"] # find cl.exe def find_cl_path(): import glob for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: paths = sorted( glob.glob( r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition ), reverse=True, ) if paths: return paths[0] # If cl.exe is not on path, try to find it. 
if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError( "Could not locate a supported Microsoft Visual C++ installation" ) os.environ["PATH"] += ";" + cl_path """ Usage: python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory) python setup.py install # build extensions and install (copy) to PATH. pip install . # ditto but better (e.g., dependency & metadata handling) python setup.py develop # build extensions and install (symbolic) to PATH. pip install -e . # ditto but better (e.g., dependency & metadata handling) """ setup( name="raymarching", # package name, import this to use python API ext_modules=[ CUDAExtension( name="_raymarching", # extension name, import this to use CUDA API sources=[ os.path.join(_src_path, "src", f) for f in [ "raymarching.cu", "bindings.cpp", ] ], extra_compile_args={ "cxx": c_flags, "nvcc": nvcc_flags, }, ), ], cmdclass={ "build_ext": BuildExtension, }, ) ================================================ FILE: lidarnerf/raymarching/src/bindings.cpp ================================================ #include #include "raymarching.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // utils m.def("packbits", &packbits, "packbits (CUDA)"); m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)"); m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)"); m.def("morton3D", &morton3D, "morton3D (CUDA)"); m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)"); // train m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)"); m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)"); m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)"); // infer m.def("march_rays", &march_rays, "march rays (CUDA)"); m.def("composite_rays", &composite_rays, "composite rays (CUDA)"); } 
================================================
FILE: lidarnerf/raymarching/src/raymarching.cu
================================================
// Header names below were stripped in the dump; restored to the set used by
// the torch-ngp raymarching kernels this file is based on.
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <cstdio>
#include <stdint.h>
#include <stdexcept>
#include <limits>

// Argument-validation helpers for tensors passed in from Python.
#define CHECK_CUDA(x) \
  TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x)                                \
  TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
              #x " must be an int tensor")
#define CHECK_IS_FLOATING(x)                                 \
  TORCH_CHECK(x.scalar_type() == at::ScalarType::Float ||    \
                  x.scalar_type() == at::ScalarType::Half || \
                  x.scalar_type() == at::ScalarType::Double, \
              #x " must be a floating tensor")

// Single-precision constants: sqrt(3), 1/sqrt(3), pi and 1/pi.
inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
inline constexpr __device__ float PI() { return 3.141592653589793f; }
inline constexpr __device__ float RPI() { return 0.3183098861837907f; }

// Integer ceil(val / divisor), e.g. for launch-grid sizing.
template <typename T>
inline __host__ __device__ T div_round_up(T val, T divisor) {
  return (val + divisor - 1) / divisor;
}

// +1 or -1 carrying the sign of x (signed zero keeps its sign).
inline __host__ __device__ float signf(const float x) {
  return copysignf(1.0f, x);
}

inline __host__ __device__ float clamp(const float x, const float min,
                                       const float max) {
  return fminf(max, fmaxf(min, x));
}

inline __host__ __device__ void swapf(float &a, float &b) {
  float c = a;
  a = b;
  b = c;
}

// Cascade level for a position: the binary exponent of its largest absolute
// coordinate, clamped to [0, max_cascade - 1]. Uses float-only intrinsics so
// device code never promotes to double.
inline __device__ int mip_from_pos(const float x, const float y, const float z,
                                   const float max_cascade) {
  const float mx = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));
  int exponent;
  frexpf(mx, &exponent);  // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1,
                          // [2, 4) --> 2, ...
  return fminf(max_cascade - 1, fmaxf(0, exponent));
}

// Cascade level whose resolution matches step size dt at grid size H,
// clamped to [0, max_cascade - 1].
inline __device__ int mip_from_dt(const float dt, const float H,
                                  const float max_cascade) {
  const float mx = dt * H * 0.5f;
  int exponent;
  frexpf(mx, &exponent);
  return fminf(max_cascade - 1, fmaxf(0, exponent));
}

// Spread the low bits of v so consecutive bits sit three positions apart
// (final mask 0x49249249). Assumes v < 1024 — TODO confirm at call sites.
inline __host__ __device__ uint32_t __expand_bits(uint32_t v) {
  v = (v * 0x00010001u) & 0xFF0000FFu;
  v = (v * 0x00000101u) & 0x0F00F00Fu;
  v = (v * 0x00000011u) & 0xC30C30C3u;
  v = (v * 0x00000005u) & 0x49249249u;
  return v;
}

// Interleave x, y, z bits into a Morton (Z-order) code: x in bit 0, y in bit
// 1, z in bit 2 of each triple.
inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y,
                                               uint32_t z) {
  uint32_t xx = __expand_bits(x);
  uint32_t yy = __expand_bits(y);
  uint32_t zz = __expand_bits(z);
  return xx | (yy << 1) | (zz << 2);
}

// Inverse of __expand_bits: compact every third bit of x back into the low
// bits. Call with code >> 1 / code >> 2 to recover y / z.
inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x) {
  x = x & 0x49249249;
  x = (x | (x >> 2)) & 0xc30c30c3;
  x = (x | (x >> 4)) & 0x0f00f00f;
  x = (x | (x >> 8)) & 0xff0000ff;
  x = (x | (x >> 16)) & 0x0000ffff;
  return x;
}

////////////////////////////////////////////////////
/////////////           utils          /////////////
////////////////////////////////////////////////////

// rays_o/d: [N, 3]
// nears/fars: [N]
// scalar_t should always be float in use.
// Ray / axis-aligned-box intersection using the slab method.
// aabb is read as [xmin, ymin, zmin, xmax, ymax, zmax]. For each ray n the
// entry/exit distances along rays_d are written to nears[n] / fars[n]; rays
// that miss the box get nears[n] = fars[n] = numeric max of scalar_t. The
// entry distance is clamped from below to min_near.
template <typename scalar_t>
__global__ void kernel_near_far_from_aabb(const scalar_t *__restrict__ rays_o,
                                          const scalar_t *__restrict__ rays_d,
                                          const scalar_t *__restrict__ aabb,
                                          const uint32_t N,
                                          const float min_near, scalar_t *nears,
                                          scalar_t *fars) {
  // parallel per ray
  const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
  if (n >= N) return;

  // locate
  rays_o += n * 3;
  rays_d += n * 3;

  const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
  const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
  const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;

  // get near far (assume cube scene)
  float near = (aabb[0] - ox) * rdx;
  float far = (aabb[3] - ox) * rdx;
  if (near > far) swapf(near, far);

  float near_y = (aabb[1] - oy) * rdy;
  float far_y = (aabb[4] - oy) * rdy;
  if (near_y > far_y) swapf(near_y, far_y);

  // x and y slab intervals do not overlap: the ray misses the box.
  if (near > far_y || near_y > far) {
    nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
    return;
  }

  if (near_y > near) near = near_y;
  if (far_y < far) far = far_y;

  float near_z = (aabb[2] - oz) * rdz;
  float far_z = (aabb[5] - oz) * rdz;
  if (near_z > far_z) swapf(near_z, far_z);

  // z slab interval does not overlap the running interval: miss.
  if (near > far_z || near_z > far) {
    nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
    return;
  }

  if (near_z > near) near = near_z;
  if (far_z < far) far = far_z;

  if (near < min_near) near = min_near;

  nears[n] = near;
  fars[n] = far;
}

// Host launcher: one thread per ray. scalar_t is chosen from rays_o's dtype
// and the other tensors are read/written as the same scalar_t.
void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d,
                        const at::Tensor aabb, const uint32_t N,
                        const float min_near, at::Tensor nears,
                        at::Tensor fars) {
  static constexpr uint32_t N_THREAD = 128;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      rays_o.scalar_type(), "near_far_from_aabb", ([&] {
        kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD), N_THREAD>>>(
            rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(),
            aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(),
            fars.data_ptr<scalar_t>());
      }));
}

// rays_o/d: [N, 3]
// radius: float
// coords: [N, 2]
// Intersects each ray with the origin-centred sphere of the given radius and
// writes the hit point's spherical angles, normalized to [-1, 1]:
// coords[0] = 2 * theta / PI - 1 (theta measured from +y), coords[1] = phi / PI.
template <typename scalar_t>
__global__ void kernel_sph_from_ray(const scalar_t *__restrict__ rays_o,
                                    const scalar_t *__restrict__ rays_d,
                                    const float radius, const uint32_t N,
                                    scalar_t *coords) {
  // parallel per ray
  const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
  if (n >= N) return;

  // locate
  rays_o += n * 3;
  rays_d += n * 3;
  coords += n * 2;

  const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
  const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];

  // solve t from || o + td || = radius
  const float A = dx * dx + dy * dy + dz * dz;
  const float B = ox * dx + oy * dy + oz * dz;  // in fact B / 2
  const float C = ox * ox + oy * oy + oz * oz - radius * radius;
  // Always use the larger root. NOTE(review): if the ray misses the sphere the
  // discriminant is negative and t (hence coords) becomes NaN; callers are
  // presumably expected to pass origins inside the sphere -- confirm.
  const float t = (-B + sqrtf(B * B - A * C)) / A;

  // solve theta, phi (assume y is the up axis)
  const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
  const float theta = atan2(sqrtf(x * x + z * z), y);  // [0, PI)
  const float phi = atan2(z, x);                        // [-PI, PI)

  // normalize to [-1, 1]
  coords[0] = 2 * theta * RPI() - 1;
  coords[1] = phi * RPI();
}

// Host launcher: one thread per ray; scalar_t is chosen from rays_o's dtype.
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d,
                  const float radius, const uint32_t N, at::Tensor coords) {
  static constexpr uint32_t N_THREAD = 128;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      rays_o.scalar_type(), "sph_from_ray", ([&] {
        kernel_sph_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(
            rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius,
            N, coords.data_ptr<scalar_t>());
      }));
}

// coords: int32, [N, 3]
// indices: int32, [N]
// Encodes each (x, y, z) integer coordinate into its 3D Morton code.
__global__ void kernel_morton3D(const int *__restrict__ coords,
                                const uint32_t N, int *indices) {
  // parallel
  const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
  if (n >= N) return;

  // locate
  coords += n * 3;
  indices[n] = __morton3D(coords[0], coords[1], coords[2]);
}

// Host launcher: one thread per coordinate triple.
void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
  static constexpr uint32_t N_THREAD = 128;
  kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(
      coords.data_ptr<int>(), N, indices.data_ptr<int>());
}

// indices: int32, [N]
// coords: int32, [N, 3]
// Decodes each Morton code back into its (x, y, z) integer coordinate.
__global__ void kernel_morton3D_invert(const int *__restrict__ indices,
                                       const uint32_t N, int *coords) {
  // parallel
  const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
  if (n >= N) return;

  // locate
  coords += n * 3;

  const int ind = indices[n];

  coords[0] = __morton3D_invert(ind >> 0);
  coords[1] = __morton3D_invert(ind >> 1);
  coords[2] = __morton3D_invert(ind >> 2);
}

// Host launcher: one thread per Morton code.
void morton3D_invert(const at::Tensor indices, const uint32_t N,
                     at::Tensor coords) {
  static constexpr uint32_t N_THREAD = 128;
  kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(
      indices.data_ptr<int>(), N, coords.data_ptr<int>());
}

// grid: float, [C, H, H, H]
// N: int, C * H * H * H / 8
// density_thresh: float
// bitfield: uint8, [N]
// Packs 8 consecutive grid cells per byte: bit i of bitfield[n] is set iff
// grid[8 * n + i] > density_thresh.
template <typename scalar_t>
__global__ void kernel_packbits(const scalar_t *__restrict__ grid,
                                const uint32_t N, const float density_thresh,
                                uint8_t *bitfield) {
  // parallel per byte
  const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
  if (n >= N) return;

  // locate
  grid += n * 8;

  uint8_t bits = 0;

#pragma unroll
  for (uint8_t i = 0; i < 8; i++) {
    bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
  }

  bitfield[n] = bits;
}

// Host launcher: one thread per output byte; scalar_t from grid's dtype.
void packbits(const at::Tensor grid, const uint32_t N,
              const float density_thresh, at::Tensor bitfield) {
  static constexpr uint32_t N_THREAD = 128;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grid.scalar_type(), "packbits", ([&] {
        kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(
            grid.data_ptr<scalar_t>(), N, density_thresh,
            bitfield.data_ptr<uint8_t>());
      }));
}

////////////////////////////////////////////////////
/////////////         training        /////////////
////////////////////////////////////////////////////

// rays_o/d: [N, 3]
// grid: [CHHH / 8]
// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2]
// dirs: [M, 3]
// rays: [N, 3], idx, offset, num_steps
//
// Generates training samples along each ray in two passes:
//   1. march from near (perturbed by noise) to far, counting occupied samples
//      (at most max_steps); empty cells are skipped voxel by voxel;
//   2. reserve space with atomicAdd on counter[0] (points) and counter[1]
//      (rays), record (n, offset, num_steps) in `rays`, then re-march and
//      write xyzs / dirs / deltas. deltas[0] is the step size dt, deltas[1]
//      the distance travelled since the previous written sample.
// If the reservation would exceed M, the ray entry is still recorded but its
// samples are not written.
template <typename scalar_t>
__global__ void kernel_march_rays_train(
    const scalar_t *__restrict__ rays_o, const scalar_t *__restrict__ rays_d,
    const uint8_t *__restrict__ grid, const float bound, const float dt_gamma,
    const uint32_t max_steps, const uint32_t N, const uint32_t C,
    const uint32_t H, const uint32_t M, const scalar_t *__restrict__ nears,
    const scalar_t *__restrict__ fars, scalar_t *xyzs, scalar_t *dirs,
    scalar_t *deltas, int *rays, int *counter,
    const scalar_t *__restrict__ noises) {
  // parallel per ray
  const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
  if (n >= N) return;

  // locate
  rays_o += n * 3;
  rays_d += n * 3;

  // ray marching
  const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
  const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
  const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
  const float rH = 1 / (float)H;
  // Cells per cascade. Kept integral so that level * H3 + morton is computed
  // exactly (a float here loses integer precision above 2^24).
  const uint32_t H3 = H * H * H;

  const float near = nears[n];
  const float far = fars[n];
  const float noise = noises[n];

  const float dt_min = 2 * SQRT3() / max_steps;
  const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;

  float t0 = near;

  // perturb
  t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;

  // first pass: estimation of num_steps
  float t = t0;
  uint32_t num_steps = 0;

  // if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near,
  // far);

  while (t < far && num_steps < max_steps) {
    // current point
    const float x = clamp(ox + t * dx, -bound, bound);
    const float y = clamp(oy + t * dy, -bound, bound);
    const float z = clamp(oz + t * dz, -bound, bound);

    const float dt = clamp(t * dt_gamma, dt_min, dt_max);

    // get mip level
    const int level = max(mip_from_pos(x, y, z, C),
                          mip_from_dt(dt, H, C));  // range in [0, C - 1]

    const float mip_bound = fminf(scalbnf(1.0f, level), bound);
    const float mip_rbound = 1 / mip_bound;

    // convert to nearest grid position (0.5f keeps the math in FP32; the
    // result is identical to the former double-precision expression)
    const int nx =
        clamp(0.5f * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
    const int ny =
        clamp(0.5f * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
    const int nz =
        clamp(0.5f * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));

    const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
    const bool occ = grid[index / 8] & (1 << (index % 8));

    // if occupied, advance a small step, and count it
    if (occ) {
      num_steps++;
      t += dt;
      // else, skip a large step (basically skip a voxel grid)
    } else {
      // calc distance to next voxel
      const float tx =
          (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) *
          rdx;
      const float ty =
          (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) *
          rdy;
      const float tz =
          (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) *
          rdz;

      const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
      // step until next voxel
      do {
        t += clamp(t * dt_gamma, dt_min, dt_max);
      } while (t < tt);
    }
  }

  // printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n,
  // num_steps, near, far, dt_min, (far - near) / dt_min);

  // second pass: really locate and write points & dirs
  uint32_t point_index = atomicAdd(counter, num_steps);
  uint32_t ray_index = atomicAdd(counter + 1, 1);

  // printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n,
  // num_steps, point_index, ray_index);

  // write rays
  rays[ray_index * 3] = n;
  rays[ray_index * 3 + 1] = point_index;
  rays[ray_index * 3 + 2] = num_steps;

  if (num_steps == 0) return;
  if (point_index + num_steps > M) return;

  xyzs += point_index * 3;
  dirs += point_index * 3;
  deltas += point_index * 2;

  t = t0;
  uint32_t step = 0;

  float last_t = t;

  while (t < far && step < num_steps) {
    // current point
    const float x = clamp(ox + t * dx, -bound, bound);
    const float y = clamp(oy + t * dy, -bound, bound);
    const float z = clamp(oz + t * dz, -bound, bound);

    const float dt = clamp(t * dt_gamma, dt_min, dt_max);

    // get mip level
    const int level = max(mip_from_pos(x, y, z, C),
                          mip_from_dt(dt, H, C));  // range in [0, C - 1]

    const float mip_bound = fminf(scalbnf(1.0f, level), bound);
    const float mip_rbound = 1 / mip_bound;

    // convert to nearest grid position
    const int nx =
        clamp(0.5f * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
    const int ny =
        clamp(0.5f * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
    const int nz =
        clamp(0.5f * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));

    // query grid
    const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
    const bool occ = grid[index / 8] & (1 << (index % 8));

    // if occupied, advance a small step, and write to output
    if (occ) {
      // write step
      xyzs[0] = x;
      xyzs[1] = y;
      xyzs[2] = z;
      dirs[0] = dx;
      dirs[1] = dy;
      dirs[2] = dz;
      t += dt;
      deltas[0] = dt;
      deltas[1] = t - last_t;  // used to calc depth
      last_t = t;
      xyzs += 3;
      dirs += 3;
      deltas += 2;
      step++;
      // else, skip a large step (basically skip a voxel grid)
    } else {
      // calc distance to next voxel
      const float tx =
          (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) *
          rdx;
      const float ty =
          (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) *
          rdy;
      const float tz =
          (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) *
          rdz;
      const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
      // step until next voxel
      do {
        t += clamp(t * dt_gamma, dt_min, dt_max);
      } while (t < tt);
    }
  }
}

// Host launcher: one thread per ray; scalar_t from rays_o's dtype. grid is
// read as uint8 and rays / counter as int32.
void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d,
                      const at::Tensor grid, const float bound,
                      const float dt_gamma, const uint32_t max_steps,
                      const uint32_t N, const uint32_t C, const uint32_t H,
                      const uint32_t M, const at::Tensor nears,
                      const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs,
                      at::Tensor deltas, at::Tensor rays, at::Tensor counter,
                      at::Tensor noises) {
  static constexpr uint32_t N_THREAD = 128;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      rays_o.scalar_type(), "march_rays_train", ([&] {
        kernel_march_rays_train<<<div_round_up(N, N_THREAD), N_THREAD>>>(
            rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(),
            grid.data_ptr<uint8_t>(), bound, dt_gamma, max_steps, N, C, H, M,
            nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(),
            xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(),
            deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(),
            counter.data_ptr<int>(), noises.data_ptr<scalar_t>());
      }));
}

// sigmas: [M]
// rgbs: [M, 3]
// deltas: [M, 2]
// rays: [N, 3], idx, offset, num_steps
// weights_sum: [N], final pixel alpha
// depth: [N,]
// image: [N, 3]
//
// Front-to-back alpha compositing of each ray's samples:
//   alpha_i = 1 - exp(-sigma_i * deltas[i][0]),  w_i = alpha_i * T_i,
//   T_{i+1} = T_i * (1 - alpha_i).
// image accumulates w_i * rgb_i, depth accumulates w_i * t_i (t_i is the
// running sum of deltas[i][1]) and weights_sum accumulates w_i. Compositing
// stops early once T drops below T_thresh. Empty rays, and rays whose samples
// would overrun M, get all outputs set to 0.
template <typename scalar_t>
__global__ void kernel_composite_rays_train_forward(
    const scalar_t *__restrict__ sigmas, const scalar_t *__restrict__ rgbs,
    const scalar_t *__restrict__ deltas, const int *__restrict__ rays,
    const uint32_t M, const uint32_t N, const float T_thresh,
    scalar_t *weights_sum, scalar_t *depth, scalar_t *image) {
  // parallel per ray
  const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
  if (n >= N) return;

  // locate
  uint32_t index = rays[n * 3];
  uint32_t offset = rays[n * 3 + 1];
  uint32_t num_steps = rays[n * 3 + 2];

  // empty ray, or ray that exceed max step count.
  if (num_steps == 0 || offset + num_steps > M) {
    weights_sum[index] = 0;
    depth[index] = 0;
    image[index * 3] = 0;
    image[index * 3 + 1] = 0;
    image[index * 3 + 2] = 0;
    return;
  }

  sigmas += offset;
  rgbs += offset * 3;
  deltas += offset * 2;

  // accumulate
  uint32_t step = 0;

  scalar_t T = 1.0f;
  scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0;

  while (step < num_steps) {
    const scalar_t alpha = 1.0f - __expf(-sigmas[0] * deltas[0]);
    const scalar_t weight = alpha * T;

    r += weight * rgbs[0];
    g += weight * rgbs[1];
    b += weight * rgbs[2];

    t += deltas[1];  // real delta
    d += weight * t;

    ws += weight;

    T *= 1.0f - alpha;

    // minimal remained transmittance
    if (T < T_thresh) break;

    // locate
    sigmas++;
    rgbs += 3;
    deltas += 2;

    step++;
  }

  // printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);

  // write
  weights_sum[index] = ws;  // weights_sum
  depth[index] = d;
  image[index * 3] = r;
  image[index * 3 + 1] = g;
  image[index * 3 + 2] = b;
}

// Host launcher: one thread per entry of `rays`; scalar_t from sigmas' dtype.
void composite_rays_train_forward(const at::Tensor sigmas,
                                  const at::Tensor rgbs,
                                  const at::Tensor deltas,
                                  const at::Tensor rays, const uint32_t M,
                                  const uint32_t N, const float T_thresh,
                                  at::Tensor weights_sum, at::Tensor depth,
                                  at::Tensor image) {
  static constexpr uint32_t N_THREAD = 128;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
        kernel_composite_rays_train_forward<<<div_round_up(N, N_THREAD),
                                              N_THREAD>>>(
            sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(),
            deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, T_thresh,
            weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(),
            image.data_ptr<scalar_t>());
      }));
}

// grad_weights_sum: [N,]
//
grad: [N, 3] // sigmas: [M] // rgbs: [M, 3] // deltas: [M, 2] // rays: [N, 3], idx, offset, num_steps // weights_sum: [N,], weights_sum here // image: [N, 3] // grad_sigmas: [M] // grad_rgbs: [M, 3] template __global__ void kernel_composite_rays_train_backward( const scalar_t *__restrict__ grad_weights_sum, const scalar_t *__restrict__ grad_image, const scalar_t *__restrict__ sigmas, const scalar_t *__restrict__ rgbs, const scalar_t *__restrict__ deltas, const int *__restrict__ rays, const scalar_t *__restrict__ weights_sum, const scalar_t *__restrict__ image, const uint32_t M, const uint32_t N, const float T_thresh, scalar_t *grad_sigmas, scalar_t *grad_rgbs) { // parallel per ray const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; if (n >= N) return; // locate uint32_t index = rays[n * 3]; uint32_t offset = rays[n * 3 + 1]; uint32_t num_steps = rays[n * 3 + 2]; if (num_steps == 0 || offset + num_steps > M) return; grad_weights_sum += index; grad_image += index * 3; weights_sum += index; image += index * 3; sigmas += offset; rgbs += offset * 3; deltas += offset * 2; grad_sigmas += offset; grad_rgbs += offset * 3; // accumulate uint32_t step = 0; scalar_t T = 1.0f; const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0]; scalar_t r = 0, g = 0, b = 0, ws = 0; while (step < num_steps) { const scalar_t alpha = 1.0f - __expf(-sigmas[0] * deltas[0]); const scalar_t weight = alpha * T; r += weight * rgbs[0]; g += weight * rgbs[1]; b += weight * rgbs[2]; ws += weight; T *= 1.0f - alpha; // check https://note.kiui.moe/others/nerf_gradient/ for the gradient // calculation. 
write grad_rgbs grad_rgbs[0] = grad_image[0] * weight; grad_rgbs[1] = grad_image[1] * weight; grad_rgbs[2] = grad_image[2] * weight; // write grad_sigmas grad_sigmas[0] = deltas[0] * (grad_image[0] * (T * rgbs[0] - (r_final - r)) + grad_image[1] * (T * rgbs[1] - (g_final - g)) + grad_image[2] * (T * rgbs[2] - (b_final - b)) + grad_weights_sum[0] * (1 - ws_final)); // printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, // r=%f\n", n, step, T, grad_sigmas[0], r_final, r); // minimal remained transmittence if (T < T_thresh) break; // locate sigmas++; rgbs += 3; deltas += 2; grad_sigmas++; grad_rgbs += 3; step++; } } void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs) { static constexpr uint32_t N_THREAD = 128; AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad_image.scalar_type(), "composite_rays_train_backward", ([&] { kernel_composite_rays_train_backward<<< div_round_up(N, N_THREAD), N_THREAD>>>( grad_weights_sum.data_ptr(), grad_image.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), rays.data_ptr(), weights_sum.data_ptr(), image.data_ptr(), M, N, T_thresh, grad_sigmas.data_ptr(), grad_rgbs.data_ptr()); })); } //////////////////////////////////////////////////// ///////////// infernce ///////////// //////////////////////////////////////////////////// template __global__ void kernel_march_rays(const uint32_t n_alive, const uint32_t n_step, const int *__restrict__ rays_alive, const scalar_t *__restrict__ rays_t, const scalar_t *__restrict__ rays_o, const scalar_t *__restrict__ rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const uint8_t *__restrict__ grid, const scalar_t *__restrict__ 
nears, const scalar_t *__restrict__ fars, scalar_t *xyzs, scalar_t *dirs, scalar_t *deltas, const scalar_t *__restrict__ noises) { const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; if (n >= n_alive) return; const int index = rays_alive[n]; // ray id const float noise = noises[n]; // locate rays_o += index * 3; rays_d += index * 3; xyzs += n * n_step * 3; dirs += n * n_step * 3; deltas += n * n_step * 2; const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2]; const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2]; const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; const float rH = 1 / (float)H; const float H3 = H * H * H; float t = rays_t[index]; // current ray's t const float near = nears[index], far = fars[index]; const float dt_min = 2 * SQRT3() / max_steps; const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H; // march for n_step steps, record points uint32_t step = 0; // introduce some randomness t += clamp(t * dt_gamma, dt_min, dt_max) * noise; float last_t = t; while (t < far && step < n_step) { // current point const float x = clamp(ox + t * dx, -bound, bound); const float y = clamp(oy + t * dy, -bound, bound); const float z = clamp(oz + t * dz, -bound, bound); const float dt = clamp(t * dt_gamma, dt_min, dt_max); // get mip level const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1] const float mip_bound = fminf(scalbnf(1, level), bound); const float mip_rbound = 1 / mip_bound; // convert to nearest grid position const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1)); const uint32_t index = level * H3 + __morton3D(nx, ny, nz); const bool occ = grid[index / 8] & (1 << (index % 8)); // if occpuied, advance a small step, and write to output if (occ) { // write step xyzs[0] = x; xyzs[1] = y; xyzs[2] = z; dirs[0] = dx; dirs[1] = dy; 
dirs[2] = dz; // calc dt t += dt; deltas[0] = dt; deltas[1] = t - last_t; // used to calc depth last_t = t; // step xyzs += 3; dirs += 3; deltas += 2; step++; // else, skip a large step (basically skip a voxel grid) } else { // calc distance to next voxel const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx; const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy; const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz; const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz))); // step until next voxel do { t += clamp(t * dt_gamma, dt_min, dt_max); } while (t < tt); } } } void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises) { static constexpr uint32_t N_THREAD = 128; AT_DISPATCH_FLOATING_TYPES_AND_HALF( rays_o.scalar_type(), "march_rays", ([&] { kernel_march_rays<<>>( n_alive, n_step, rays_alive.data_ptr(), rays_t.data_ptr(), rays_o.data_ptr(), rays_d.data_ptr(), bound, dt_gamma, max_steps, C, H, grid.data_ptr(), near.data_ptr(), far.data_ptr(), xyzs.data_ptr(), dirs.data_ptr(), deltas.data_ptr(), noises.data_ptr()); })); } template __global__ void kernel_composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, int *rays_alive, scalar_t *rays_t, const scalar_t *__restrict__ sigmas, const scalar_t *__restrict__ rgbs, const scalar_t *__restrict__ deltas, scalar_t *weights_sum, scalar_t *depth, scalar_t *image) { const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; if (n >= n_alive) return; const int index = rays_alive[n]; // ray id // locate sigmas += n * n_step; rgbs += n * n_step * 3; 
deltas += n * n_step * 2; rays_t += index; weights_sum += index; depth += index; image += index * 3; scalar_t t = rays_t[0]; // current ray's t scalar_t weight_sum = weights_sum[0]; scalar_t d = depth[0]; scalar_t r = image[0]; scalar_t g = image[1]; scalar_t b = image[2]; // accumulate uint32_t step = 0; while (step < n_step) { // ray is terminated if delta == 0 if (deltas[0] == 0) break; const scalar_t alpha = 1.0f - __expf(-sigmas[0] * deltas[0]); /* T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j) w_i = alpha_i * T_i --> T_i = 1 - \sum_{j=0}^{i-1} w_j */ const scalar_t T = 1 - weight_sum; const scalar_t weight = alpha * T; weight_sum += weight; t += deltas[1]; // real delta d += weight * t; r += weight * rgbs[0]; g += weight * rgbs[1]; b += weight * rgbs[2]; // printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, // d=%f\n", n, step, alpha, weight, T, sum_delta, d); // ray is terminated if T is too small // use a larger bound to further accelerate inference if (T < T_thresh) break; // locate sigmas++; rgbs += 3; deltas += 2; step++; } // printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d); // rays_alive = -1 means ray is terminated early. if (step < n_step) { rays_alive[n] = -1; } else { rays_t[0] = t; } weights_sum[0] = weight_sum; // this is the thing I needed! 
depth[0] = d; image[0] = r; image[1] = g; image[2] = b; } void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) { static constexpr uint32_t N_THREAD = 128; AT_DISPATCH_FLOATING_TYPES_AND_HALF( image.scalar_type(), "composite_rays", ([&] { kernel_composite_rays<<>>( n_alive, n_step, T_thresh, rays_alive.data_ptr(), rays_t.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), weights.data_ptr(), depth.data_ptr(), image.data_ptr()); })); } ================================================ FILE: lidarnerf/raymarching/src/raymarching.h ================================================ #pragma once #include #include void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices); void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords); void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, at::Tensor noises); void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor weights_sum, at::Tensor 
depth, at::Tensor image); void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, at::Tensor grad_sigmas, at::Tensor grad_rgbs); void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor noises); void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); ================================================ FILE: lidarnerf/shencoder/__init__.py ================================================ ================================================ FILE: lidarnerf/shencoder/backend.py ================================================ import os from torch.utils.cpp_extension import load _src_path = os.path.dirname(os.path.abspath(__file__)) nvcc_flags = [ "-O3", "-std=c++14", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", ] if os.name == "posix": c_flags = ["-O3", "-std=c++14"] elif os.name == "nt": c_flags = ["/O2", "/std:c++17"] # find cl.exe def find_cl_path(): import glob for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: paths = sorted( glob.glob( r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition ), reverse=True, ) if paths: return paths[0] # If cl.exe is not on path, try to 
find it. if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError( "Could not locate a supported Microsoft Visual C++ installation" ) os.environ["PATH"] += ";" + cl_path _backend = load( name="_sh_encoder", extra_cflags=c_flags, extra_cuda_cflags=nvcc_flags, sources=[ os.path.join(_src_path, "src", f) for f in [ "shencoder.cu", "bindings.cpp", ] ], ) __all__ = ["_backend"] ================================================ FILE: lidarnerf/shencoder/setup.py ================================================ import os from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension _src_path = os.path.dirname(os.path.abspath(__file__)) nvcc_flags = [ "-O3", "-std=c++14", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", ] if os.name == "posix": c_flags = ["-O3", "-std=c++14"] elif os.name == "nt": c_flags = ["/O2", "/std:c++17"] # find cl.exe def find_cl_path(): import glob for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: paths = sorted( glob.glob( r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition ), reverse=True, ) if paths: return paths[0] # If cl.exe is not on path, try to find it. 
if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError( "Could not locate a supported Microsoft Visual C++ installation" ) os.environ["PATH"] += ";" + cl_path setup( name="shencoder", # package name, import this to use python API ext_modules=[ CUDAExtension( name="_shencoder", # extension name, import this to use CUDA API sources=[ os.path.join(_src_path, "src", f) for f in [ "shencoder.cu", "bindings.cpp", ] ], extra_compile_args={ "cxx": c_flags, "nvcc": nvcc_flags, }, ), ], cmdclass={ "build_ext": BuildExtension, }, ) ================================================ FILE: lidarnerf/shencoder/sphere_harmonics.py ================================================ import torch import torch.nn as nn from torch.autograd import Function from torch.cuda.amp import custom_bwd, custom_fwd try: import _shencoder as _backend except ImportError: from .backend import _backend class _sh_encoder(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision def forward(ctx, inputs, degree, calc_grad_inputs=False): # inputs: [B, input_dim], float in [-1, 1] # RETURN: [B, F], float inputs = inputs.contiguous() B, input_dim = inputs.shape # batch size, coord dim output_dim = degree**2 outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) if calc_grad_inputs: dy_dx = torch.empty( B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device ) else: dy_dx = None _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx) ctx.save_for_backward(inputs, dy_dx) ctx.dims = [B, input_dim, degree] return outputs @staticmethod # @once_differentiable @custom_bwd def backward(ctx, grad): # grad: [B, C * C] inputs, dy_dx = ctx.saved_tensors if dy_dx is not None: grad = grad.contiguous() B, input_dim, degree = ctx.dims grad_inputs = torch.zeros_like(inputs) _backend.sh_encode_backward( grad, inputs, B, input_dim, degree, dy_dx, grad_inputs ) return grad_inputs, 
None, None else: return None, None, None sh_encode = _sh_encoder.apply class SHEncoder(nn.Module): def __init__(self, input_dim=3, degree=4): super().__init__() self.input_dim = input_dim # coord dims, must be 3 self.degree = degree # 0 ~ 4 self.output_dim = degree**2 assert self.input_dim == 3, "SH encoder only support input dim == 3" assert ( self.degree > 0 and self.degree <= 8 ), "SH encoder only supports degree in [1, 8]" def __repr__(self): return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" def forward(self, inputs, size=1): # inputs: [..., input_dim], normalized real world positions in [-size, size] # return: [..., degree^2] inputs = inputs / size # [-1, 1] prefix_shape = list(inputs.shape[:-1]) inputs = inputs.reshape(-1, self.input_dim) outputs = sh_encode(inputs, self.degree, inputs.requires_grad) outputs = outputs.reshape(prefix_shape + [self.output_dim]) return outputs ================================================ FILE: lidarnerf/shencoder/src/bindings.cpp ================================================ #include #include "shencoder.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); } ================================================ FILE: lidarnerf/shencoder/src/shencoder.cu ================================================ #include #include #include #include #include #include #include #include #include #define CHECK_CUDA(x) \ TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) \ TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") #define CHECK_IS_INT(x) \ TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ #x " must be an int tensor") #define CHECK_IS_FLOATING(x) \ TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || \ x.scalar_type() == at::ScalarType::Half || \ x.scalar_type() == at::ScalarType::Double, \ #x " must be a 
floating tensor") template __host__ __device__ T div_round_up(T val, T divisor) { return (val + divisor - 1) / divisor; } template __global__ void kernel_sh(const scalar_t *__restrict__ inputs, scalar_t *outputs, uint32_t B, uint32_t D, uint32_t C, scalar_t *dy_dx) { const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x; if (b >= B) return; const uint32_t C2 = C * C; // locate inputs += b * D; outputs += b * C2; scalar_t x = inputs[0], y = inputs[1], z = inputs[2]; scalar_t xy = x * y, xz = x * z, yz = y * z, x2 = x * x, y2 = y * y, z2 = z * z, xyz = xy * z; scalar_t x4 = x2 * x2, y4 = y2 * y2, z4 = z2 * z2; scalar_t x6 = x4 * x2, y6 = y4 * y2, z6 = z4 * z2; auto write_sh = [&]() { outputs[0] = 0.28209479177387814f; // 1/(2*sqrt(pi)) if (C <= 1) { return; } outputs[1] = -0.48860251190291987f * y; // -sqrt(3)*y/(2*sqrt(pi)) outputs[2] = 0.48860251190291987f * z; // sqrt(3)*z/(2*sqrt(pi)) outputs[3] = -0.48860251190291987f * x; // -sqrt(3)*x/(2*sqrt(pi)) if (C <= 2) { return; } outputs[4] = 1.0925484305920792f * xy; // sqrt(15)*xy/(2*sqrt(pi)) outputs[5] = -1.0925484305920792f * yz; // -sqrt(15)*yz/(2*sqrt(pi)) outputs[6] = 0.94617469575755997f * z2 - 0.31539156525251999f; // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi)) outputs[7] = -1.0925484305920792f * xz; // -sqrt(15)*xz/(2*sqrt(pi)) outputs[8] = 0.54627421529603959f * x2 - 0.54627421529603959f * y2; // sqrt(15)*(x2 - y2)/(4*sqrt(pi)) if (C <= 3) { return; } outputs[9] = 0.59004358992664352f * y * (-3.0f * x2 + y2); // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) outputs[10] = 2.8906114426405538f * xy * z; // sqrt(105)*xy*z/(2*sqrt(pi)) outputs[11] = 0.45704579946446572f * y * (1.0f - 5.0f * z2); // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi)) outputs[12] = 0.3731763325901154f * z * (5.0f * z2 - 3.0f); // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi)) outputs[13] = 0.45704579946446572f * x * (1.0f - 5.0f * z2); // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi)) outputs[14] = 1.4453057213202769f * z * (x2 - y2); // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi)) outputs[15] = 
0.59004358992664352f * x * (-x2 + 3.0f * y2); // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) if (C <= 4) { return; } outputs[16] = 2.5033429417967046f * xy * (x2 - y2); // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi)) outputs[17] = 1.7701307697799304f * yz * (-3.0f * x2 + y2); // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi)) outputs[18] = 0.94617469575756008f * xy * (7.0f * z2 - 1.0f); // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi)) outputs[19] = 0.66904654355728921f * yz * (3.0f - 7.0f * z2); // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi)) outputs[20] = -3.1735664074561294f * z2 + 3.7024941420321507f * z4 + 0.31735664074561293f; // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi)) outputs[21] = 0.66904654355728921f * xz * (3.0f - 7.0f * z2); // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi)) outputs[22] = 0.47308734787878004f * (x2 - y2) * (7.0f * z2 - 1.0f); // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi)) outputs[23] = 1.7701307697799304f * xz * (-x2 + 3.0f * y2); // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi)) outputs[24] = -3.7550144126950569f * x2 * y2 + 0.62583573544917614f * x4 + 0.62583573544917614f * y4; // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) if (C <= 5) { return; } outputs[25] = 0.65638205684017015f * y * (10.0f * x2 * y2 - 5.0f * x4 - y4); // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) outputs[26] = 8.3026492595241645f * xy * z * (x2 - y2); // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi)) outputs[27] = -0.48923829943525038f * y * (3.0f * x2 - y2) * (9.0f * z2 - 1.0f); // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) outputs[28] = 4.7935367849733241f * xy * z * (3.0f * z2 - 1.0f); // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi)) outputs[29] = 0.45294665119569694f * y * (14.0f * z2 - 21.0f * z4 - 1.0f); // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) outputs[30] = 0.1169503224534236f * z * (-70.0f * z2 + 63.0f * z4 + 15.0f); // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi)) outputs[31] = 0.45294665119569694f * x * (14.0f * z2 - 21.0f * z4 - 1.0f); // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) outputs[32] = 
2.3967683924866621f * z * (x2 - y2) * (3.0f * z2 - 1.0f); // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi)) outputs[33] = -0.48923829943525038f * x * (x2 - 3.0f * y2) * (9.0f * z2 - 1.0f); // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi)) outputs[34] = 2.0756623148810411f * z * (-6.0f * x2 * y2 + x4 + y4); // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) outputs[35] = 0.65638205684017015f * x * (10.0f * x2 * y2 - x4 - 5.0f * y4); // 3*sqrt(154)*x*(10*x2*y2 - x4 - // 5*y4)/(32*sqrt(pi)) if (C <= 6) { return; } outputs[36] = 1.3663682103838286f * xy * (-10.0f * x2 * y2 + 3.0f * x4 + 3.0f * y4); // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + // 3*y4)/(32*sqrt(pi)) outputs[37] = 2.3666191622317521f * yz * (10.0f * x2 * y2 - 5.0f * x4 - y4); // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) outputs[38] = 2.0182596029148963f * xy * (x2 - y2) * (11.0f * z2 - 1.0f); // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) outputs[39] = -0.92120525951492349f * yz * (3.0f * x2 - y2) * (11.0f * z2 - 3.0f); // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) outputs[40] = 0.92120525951492349f * xy * (-18.0f * z2 + 33.0f * z4 + 1.0f); // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) outputs[41] = 0.58262136251873131f * yz * (30.0f * z2 - 33.0f * z4 - 5.0f); // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) outputs[42] = 6.6747662381009842f * z2 - 20.024298714302954f * z4 + 14.684485723822165f * z6 - 0.31784601133814211f; // sqrt(13)*(105*z2 - 315*z4 + // 231*z6 - 5)/(32*sqrt(pi)) outputs[43] = 0.58262136251873131f * xz * (30.0f * z2 - 33.0f * z4 - 5.0f); // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) outputs[44] = 0.46060262975746175f * (x2 - y2) * (11.0f * z2 * (3.0f * z2 - 1.0f) - 7.0f * z2 + 1.0f); // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 // + 1)/(64*sqrt(pi)) outputs[45] = -0.92120525951492349f * xz * (x2 - 3.0f * y2) * (11.0f * z2 - 3.0f); // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi)) outputs[46] = 0.50456490072872406f * (11.0f * z2 - 1.0f) * 
(-6.0f * x2 * y2 + x4 + y4); // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + // y4)/(32*sqrt(pi)) outputs[47] = 2.3666191622317521f * xz * (10.0f * x2 * y2 - x4 - 5.0f * y4); // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - // 5*y4)/(32*sqrt(pi)) outputs[48] = 10.247761577878714f * x2 * y4 - 10.247761577878714f * x4 * y2 + 0.6831841051919143f * x6 - 0.6831841051919143f * y6; // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + // x6 - y6)/(64*sqrt(pi)) if (C <= 7) { return; } outputs[49] = 0.70716273252459627f * y * (-21.0f * x2 * y4 + 35.0f * x4 * y2 - 7.0f * x6 + y6); // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + // y6)/(64*sqrt(pi)) outputs[50] = 5.2919213236038001f * xy * z * (-10.0f * x2 * y2 + 3.0f * x4 + 3.0f * y4); // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + // 3*y4)/(32*sqrt(pi)) outputs[51] = -0.51891557872026028f * y * (13.0f * z2 - 1.0f) * (-10.0f * x2 * y2 + 5.0f * x4 + y4); // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + // y4)/(64*sqrt(pi)) outputs[52] = 4.1513246297620823f * xy * z * (x2 - y2) * (13.0f * z2 - 3.0f); // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) outputs[53] = -0.15645893386229404f * y * (3.0f * x2 - y2) * (13.0f * z2 * (11.0f * z2 - 3.0f) - 27.0f * z2 + 3.0f); // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - // 27*z2 + 3)/(64*sqrt(pi)) outputs[54] = 0.44253269244498261f * xy * z * (-110.0f * z2 + 143.0f * z4 + 15.0f); // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + // 15)/(32*sqrt(pi)) outputs[55] = 0.090331607582517306f * y * (-135.0f * z2 + 495.0f * z4 - 429.0f * z6 + 5.0f); // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + // 5)/(64*sqrt(pi)) outputs[56] = 0.068284276912004949f * z * (315.0f * z2 - 693.0f * z4 + 429.0f * z6 - 35.0f); // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - // 35)/(32*sqrt(pi)) outputs[57] = 0.090331607582517306f * x * (-135.0f * z2 + 495.0f * z4 - 429.0f * z6 + 5.0f); // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + // 5)/(64*sqrt(pi)) outputs[58] = 0.07375544874083044f * z * (x2 - y2) * (143.0f * z2 * (3.0f * z2 - 1.0f) - 187.0f * z2 + 45.0f); // 
sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - // 187*z2 + 45)/(64*sqrt(pi)) outputs[59] = -0.15645893386229404f * x * (x2 - 3.0f * y2) * (13.0f * z2 * (11.0f * z2 - 3.0f) - 27.0f * z2 + 3.0f); // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - // 27*z2 + 3)/(64*sqrt(pi)) outputs[60] = 1.0378311574405206f * z * (13.0f * z2 - 3.0f) * (-6.0f * x2 * y2 + x4 + y4); // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + // y4)/(32*sqrt(pi)) outputs[61] = -0.51891557872026028f * x * (13.0f * z2 - 1.0f) * (-10.0f * x2 * y2 + x4 + 5.0f * y4); // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + // x4 + 5*y4)/(64*sqrt(pi)) outputs[62] = 2.6459606618019f * z * (15.0f * x2 * y4 - 15.0f * x4 * y2 + x6 - y6); // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - // y6)/(64*sqrt(pi)) outputs[63] = 0.70716273252459627f * x * (-35.0f * x2 * y4 + 21.0f * x4 * y2 - x6 + 7.0f * y6); // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 // + 7*y6)/(64*sqrt(pi)) }; write_sh(); if (dy_dx) { scalar_t *dx = dy_dx + b * D * C2; scalar_t *dy = dx + C2; scalar_t *dz = dy + C2; auto write_sh_dx = [&]() { dx[0] = 0.0f; // 0 if (C <= 1) { return; } dx[1] = 0.0f; // 0 dx[2] = 0.0f; // 0 dx[3] = -0.48860251190291992f; // -sqrt(3)/(2*sqrt(pi)) if (C <= 2) { return; } dx[4] = 1.0925484305920792f * y; // sqrt(15)*y/(2*sqrt(pi)) dx[5] = 0.0f; // 0 dx[6] = 0.0f; // 0 dx[7] = -1.0925484305920792f * z; // -sqrt(15)*z/(2*sqrt(pi)) dx[8] = 1.0925484305920792f * x; // sqrt(15)*x/(2*sqrt(pi)) if (C <= 3) { return; } dx[9] = -3.5402615395598609f * xy; // -3*sqrt(70)*xy/(4*sqrt(pi)) dx[10] = 2.8906114426405538f * yz; // sqrt(105)*yz/(2*sqrt(pi)) dx[11] = 0.0f; // 0 dx[12] = 0.0f; // 0 dx[13] = 0.45704579946446572f - 2.2852289973223288f * z2; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) dx[14] = 2.8906114426405538f * xz; // sqrt(105)*xz/(2*sqrt(pi)) dx[15] = -1.7701307697799304f * x2 + 1.7701307697799304f * y2; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) if (C <= 4) { return; } dx[16] = 2.5033429417967046f * y * (3.0f * x2 - y2); // 3*sqrt(35)*y*(3*x2 - 
y2)/(4*sqrt(pi)) dx[17] = -10.620784618679583f * xy * z; // -9*sqrt(70)*xy*z/(4*sqrt(pi)) dx[18] = 0.94617469575756008f * y * (7.0f * z2 - 1.0f); // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi)) dx[19] = 0.0f; // 0 dx[20] = 0.0f; // 0 dx[21] = 0.66904654355728921f * z * (3.0f - 7.0f * z2); // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) dx[22] = 0.94617469575756008f * x * (7.0f * z2 - 1.0f); // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) dx[23] = 5.3103923093397913f * z * (-x2 + y2); // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) dx[24] = 2.5033429417967046f * x * (x2 - 3.0f * y2); // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) if (C <= 5) { return; } dx[25] = 13.127641136803401f * xy * (-x2 + y2); // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi)) dx[26] = 8.3026492595241645f * yz * (3.0f * x2 - y2); // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi)) dx[27] = 2.9354297966115022f * xy * (1.0f - 9.0f * z2); // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi)) dx[28] = 4.7935367849733241f * yz * (3.0f * z2 - 1.0f); // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi)) dx[29] = 0.0f; // 0 dx[30] = 0.0f; // 0 dx[31] = 6.3412531167397574f * z2 - 9.5118796751096362f * z4 - 0.45294665119569694f; // sqrt(165)*(14*z2 - 21*z4 - // 1)/(16*sqrt(pi)) dx[32] = 4.7935367849733241f * xz * (3.0f * z2 - 1.0f); // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) dx[33] = -13.209434084751759f * x2 * z2 + 1.4677148983057511f * x2 + 13.209434084751759f * y2 * z2 - 1.4677148983057511f * y2; // 3*sqrt(770)*(-9*x2*z2 + x2 + // 9*y2*z2 - y2)/(32*sqrt(pi)) dx[34] = 8.3026492595241645f * xz * (x2 - 3.0f * y2); // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) dx[35] = 19.6914617052051f * x2 * y2 - 3.2819102842008503f * x4 - 3.2819102842008503f * y4; // 15*sqrt(154)*(6*x2*y2 - x4 - // y4)/(32*sqrt(pi)) if (C <= 6) { return; } dx[36] = 4.0991046311514854f * y * (-10.0f * x2 * y2 + 5.0f * x4 + y4); // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + // y4)/(32*sqrt(pi)) dx[37] = 47.332383244635047f * xy * z * (-x2 + y2); // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi)) dx[38] = 2.0182596029148963f * y * 
(3.0f * x2 - y2) * (11.0f * z2 - 1.0f); // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - // 1)/(8*sqrt(pi)) dx[39] = 5.5272315570895412f * xy * z * (3.0f - 11.0f * z2); // 3*sqrt(2730)*xy*z*(3 - // 11*z2)/(16*sqrt(pi)) dx[40] = 0.92120525951492349f * y * (-18.0f * z2 + 33.0f * z4 + 1.0f); // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) dx[41] = 0.0f; // 0 dx[42] = 0.0f; // 0 dx[43] = 0.58262136251873131f * z * (30.0f * z2 - 33.0f * z4 - 5.0f); // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) dx[44] = 0.92120525951492349f * x * (-18.0f * z2 + 33.0f * z4 + 1.0f); // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) dx[45] = -2.7636157785447706f * z * (x2 - y2) * (11.0f * z2 - 3.0f); // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - // 3)/(32*sqrt(pi)) dx[46] = 2.0182596029148963f * x * (x2 - 3.0f * y2) * (11.0f * z2 - 1.0f); // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - // 1)/(8*sqrt(pi)) dx[47] = 11.833095811158762f * z * (6.0f * x2 * y2 - x4 - y4); // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) dx[48] = 4.0991046311514854f * x * (-10.0f * x2 * y2 + x4 + 5.0f * y4); // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + // 5*y4)/(32*sqrt(pi)) if (C <= 7) { return; } dx[49] = 9.9002782553443485f * xy * (10.0f * x2 * y2 - 3.0f * x4 - 3.0f * y4); // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - // 3*y4)/(32*sqrt(pi)) dx[50] = 15.875763970811402f * yz * (-10.0f * x2 * y2 + 5.0f * x4 + y4); // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + // y4)/(32*sqrt(pi)) dx[51] = -10.378311574405206f * xy * (x2 - y2) * (13.0f * z2 - 1.0f); // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 // - 1)/(16*sqrt(pi)) dx[52] = 4.1513246297620823f * yz * (3.0f * x2 - y2) * (13.0f * z2 - 3.0f); // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 // - 3)/(8*sqrt(pi)) dx[53] = 0.93875360317376422f * xy * (66.0f * z2 - 143.0f * z4 - 3.0f); // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi)) dx[54] = 0.44253269244498261f * yz * (-110.0f * z2 + 143.0f * z4 + 15.0f); // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + // 15)/(32*sqrt(pi)) dx[55] = 0.0f; // 0 dx[56] = 0.0f; // 0 dx[57] = 
-12.194767023639836f * z2 + 44.714145753346067f * z4 - 38.752259652899923f * z6 + 0.45165803791258652f; // sqrt(105)*(-135*z2 + 495*z4 - // 429*z6 + 5)/(64*sqrt(pi)) dx[58] = 0.44253269244498261f * xz * (-110.0f * z2 + 143.0f * z4 + 15.0f); // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + // 15)/(32*sqrt(pi)) dx[59] = 30.97886890473422f * x2 * z2 - 67.120882626924143f * x2 * z4 - 1.4081304047606462f * x2 - 30.97886890473422f * y2 * z2 + 67.120882626924143f * y2 * z4 + 1.4081304047606462f * y2; // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - // 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi)) dx[60] = 4.1513246297620823f * xz * (x2 - 3.0f * y2) * (13.0f * z2 - 3.0f); // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 // - 3)/(8*sqrt(pi)) dx[61] = -0.51891557872026028f * (13.0f * z2 - 1.0f) * (-10.0f * x2 * y2 + 4.0f * x2 * (x2 - 5.0f * y2) + x4 + 5.0f * y4); // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + // 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi)) dx[62] = 15.875763970811402f * xz * (-10.0f * x2 * y2 + x4 + 5.0f * y4); // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + // 5*y4)/(32*sqrt(pi)) dx[63] = -74.252086915082614f * x2 * y4 + 74.252086915082614f * x4 * y2 - 4.9501391276721742f * x6 + 4.9501391276721742f * y6; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + // y6)/(64*sqrt(pi)) }; auto write_sh_dy = [&]() { dy[0] = 0.0f; // 0 if (C <= 1) { return; } dy[1] = -0.48860251190291992f; // -sqrt(3)/(2*sqrt(pi)) dy[2] = 0.0f; // 0 dy[3] = 0.0f; // 0 if (C <= 2) { return; } dy[4] = 1.0925484305920792f * x; // sqrt(15)*x/(2*sqrt(pi)) dy[5] = -1.0925484305920792f * z; // -sqrt(15)*z/(2*sqrt(pi)) dy[6] = 0.0f; // 0 dy[7] = 0.0f; // 0 dy[8] = -1.0925484305920792f * y; // -sqrt(15)*y/(2*sqrt(pi)) if (C <= 3) { return; } dy[9] = -1.7701307697799304f * x2 + 1.7701307697799304f * y2; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) dy[10] = 2.8906114426405538f * xz; // sqrt(105)*xz/(2*sqrt(pi)) dy[11] = 0.45704579946446572f - 2.2852289973223288f * z2; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) dy[12] = 0.0f; // 0 dy[13] = 0.0f; // 0 dy[14] = 
-2.8906114426405538f * yz; // -sqrt(105)*yz/(2*sqrt(pi)) dy[15] = 3.5402615395598609f * xy; // 3*sqrt(70)*xy/(4*sqrt(pi)) if (C <= 4) { return; } dy[16] = 2.5033429417967046f * x * (x2 - 3.0f * y2); // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) dy[17] = 5.3103923093397913f * z * (-x2 + y2); // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) dy[18] = 0.94617469575756008f * x * (7.0f * z2 - 1.0f); // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) dy[19] = 0.66904654355728921f * z * (3.0f - 7.0f * z2); // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) dy[20] = 0.0f; // 0 dy[21] = 0.0f; // 0 dy[22] = 0.94617469575756008f * y * (1.0f - 7.0f * z2); // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi)) dy[23] = 10.620784618679583f * xy * z; // 9*sqrt(70)*xy*z/(4*sqrt(pi)) dy[24] = 2.5033429417967046f * y * (-3.0f * x2 + y2); // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi)) if (C <= 5) { return; } dy[25] = 19.6914617052051f * x2 * y2 - 3.2819102842008503f * x4 - 3.2819102842008503f * y4; // 15*sqrt(154)*(6*x2*y2 - x4 - // y4)/(32*sqrt(pi)) dy[26] = 8.3026492595241645f * xz * (x2 - 3.0f * y2); // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) dy[27] = -1.4677148983057511f * (x2 - y2) * (9.0f * z2 - 1.0f); // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) dy[28] = 4.7935367849733241f * xz * (3.0f * z2 - 1.0f); // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) dy[29] = 6.3412531167397574f * z2 - 9.5118796751096362f * z4 - 0.45294665119569694f; // sqrt(165)*(14*z2 - 21*z4 - // 1)/(16*sqrt(pi)) dy[30] = 0.0f; // 0 dy[31] = 0.0f; // 0 dy[32] = 4.7935367849733241f * yz * (1.0f - 3.0f * z2); // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi)) dy[33] = 2.9354297966115022f * xy * (9.0f * z2 - 1.0f); // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi)) dy[34] = 8.3026492595241645f * yz * (-3.0f * x2 + y2); // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi)) dy[35] = 13.127641136803401f * xy * (x2 - y2); // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi)) if (C <= 6) { return; } dy[36] = 4.0991046311514854f * x * (-10.0f * x2 * y2 + x4 + 5.0f * y4); // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + // 
5*y4)/(32*sqrt(pi)) dy[37] = 11.833095811158762f * z * (6.0f * x2 * y2 - x4 - y4); // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) dy[38] = 2.0182596029148963f * x * (x2 - 3.0f * y2) * (11.0f * z2 - 1.0f); // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - // 1)/(8*sqrt(pi)) dy[39] = -2.7636157785447706f * z * (x2 - y2) * (11.0f * z2 - 3.0f); // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - // 3)/(32*sqrt(pi)) dy[40] = 0.92120525951492349f * x * (-18.0f * z2 + 33.0f * z4 + 1.0f); // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) dy[41] = 0.58262136251873131f * z * (30.0f * z2 - 33.0f * z4 - 5.0f); // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) dy[42] = 0.0f; // 0 dy[43] = 0.0f; // 0 dy[44] = 0.92120525951492349f * y * (18.0f * z2 - 33.0f * z4 - 1.0f); // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi)) dy[45] = 5.5272315570895412f * xy * z * (11.0f * z2 - 3.0f); // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi)) dy[46] = -2.0182596029148963f * y * (3.0f * x2 - y2) * (11.0f * z2 - 1.0f); // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - // 1)/(8*sqrt(pi)) dy[47] = 47.332383244635047f * xy * z * (x2 - y2); // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi)) dy[48] = 4.0991046311514854f * y * (10.0f * x2 * y2 - 5.0f * x4 - y4); // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - // y4)/(32*sqrt(pi)) if (C <= 7) { return; } dy[49] = -74.252086915082614f * x2 * y4 + 74.252086915082614f * x4 * y2 - 4.9501391276721742f * x6 + 4.9501391276721742f * y6; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + // y6)/(64*sqrt(pi)) dy[50] = 15.875763970811402f * xz * (-10.0f * x2 * y2 + x4 + 5.0f * y4); // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + // 5*y4)/(32*sqrt(pi)) dy[51] = 0.51891557872026028f * (13.0f * z2 - 1.0f) * (10.0f * x2 * y2 - 5.0f * x4 + 4.0f * y2 * (5.0f * x2 - y2) - y4); // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + // 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi)) dy[52] = 4.1513246297620823f * xz * (x2 - 3.0f * y2) * (13.0f * z2 - 3.0f); // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 // - 3)/(8*sqrt(pi)) dy[53] = -0.46937680158688211f * (x2 - y2) * 
(13.0f * z2 * (11.0f * z2 - 3.0f) - 27.0f * z2 + 3.0f); // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - // 27*z2 + 3)/(64*sqrt(pi)) dy[54] = 0.44253269244498261f * xz * (-110.0f * z2 + 143.0f * z4 + 15.0f); // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + // 15)/(32*sqrt(pi)) dy[55] = -12.194767023639836f * z2 + 44.714145753346067f * z4 - 38.752259652899923f * z6 + 0.45165803791258652f; // sqrt(105)*(-135*z2 + 495*z4 - // 429*z6 + 5)/(64*sqrt(pi)) dy[56] = 0.0f; // 0 dy[57] = 0.0f; // 0 dy[58] = 0.44253269244498261f * yz * (110.0f * z2 - 143.0f * z4 - 15.0f); // 3*sqrt(70)*yz*(110*z2 - 143*z4 - // 15)/(32*sqrt(pi)) dy[59] = 0.93875360317376422f * xy * (-66.0f * z2 + 143.0f * z4 + 3.0f); // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + // 3)/(32*sqrt(pi)) dy[60] = -4.1513246297620823f * yz * (3.0f * x2 - y2) * (13.0f * z2 - 3.0f); // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 // - 3)/(8*sqrt(pi)) dy[61] = 10.378311574405206f * xy * (x2 - y2) * (13.0f * z2 - 1.0f); // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - // 1)/(16*sqrt(pi)) dy[62] = 15.875763970811402f * yz * (10.0f * x2 * y2 - 5.0f * x4 - y4); // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - // y4)/(32*sqrt(pi)) dy[63] = 9.9002782553443485f * xy * (-10.0f * x2 * y2 + 3.0f * x4 + 3.0f * y4); // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + // 3*y4)/(32*sqrt(pi)) }; auto write_sh_dz = [&]() { dz[0] = 0.0f; // 0 if (C <= 1) { return; } dz[1] = 0.0f; // 0 dz[2] = 0.48860251190291992f; // sqrt(3)/(2*sqrt(pi)) dz[3] = 0.0f; // 0 if (C <= 2) { return; } dz[4] = 0.0f; // 0 dz[5] = -1.0925484305920792f * y; // -sqrt(15)*y/(2*sqrt(pi)) dz[6] = 1.8923493915151202f * z; // 3*sqrt(5)*z/(2*sqrt(pi)) dz[7] = -1.0925484305920792f * x; // -sqrt(15)*x/(2*sqrt(pi)) dz[8] = 0.0f; // 0 if (C <= 3) { return; } dz[9] = 0.0f; // 0 dz[10] = 2.8906114426405538f * xy; // sqrt(105)*xy/(2*sqrt(pi)) dz[11] = -4.5704579946446566f * yz; // -5*sqrt(42)*yz/(4*sqrt(pi)) dz[12] = 5.597644988851731f * z2 - 1.1195289977703462f; // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi)) dz[13] = -4.5704579946446566f * xz; // 
-5*sqrt(42)*xz/(4*sqrt(pi)) dz[14] = 1.4453057213202769f * x2 - 1.4453057213202769f * y2; // sqrt(105)*(x2 - y2)/(4*sqrt(pi)) dz[15] = 0.0f; // 0 if (C <= 4) { return; } dz[16] = 0.0f; // 0 dz[17] = 1.7701307697799304f * y * (-3.0f * x2 + y2); // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) dz[18] = 13.246445740605839f * xy * z; // 21*sqrt(5)*xy*z/(2*sqrt(pi)) dz[19] = 2.0071396306718676f * y * (1.0f - 7.0f * z2); // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi)) dz[20] = 14.809976568128603f * pow(z, 3) - 6.3471328149122579f * z; // (105*z**3 - 45*z)/(4*sqrt(pi)) dz[21] = 2.0071396306718676f * x * (1.0f - 7.0f * z2); // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi)) dz[22] = 6.6232228703029197f * z * (x2 - y2); // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi)) dz[23] = 1.7701307697799304f * x * (-x2 + 3.0f * y2); // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) dz[24] = 0.0f; // 0 if (C <= 5) { return; } dz[25] = 0.0f; // 0 dz[26] = 8.3026492595241645f * xy * (x2 - y2); // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi)) dz[27] = 8.8062893898345074f * yz * (-3.0f * x2 + y2); // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi)) dz[28] = 4.7935367849733241f * xy * (9.0f * z2 - 1.0f); // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi)) dz[29] = 12.682506233479513f * yz * (1.0f - 3.0f * z2); // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi)) dz[30] = -24.559567715218954f * z2 + 36.839351572828434f * z4 + 1.754254836801354f; // 15*sqrt(11)*(-14*z2 + 21*z4 + // 1)/(16*sqrt(pi)) dz[31] = 12.682506233479513f * xz * (1.0f - 3.0f * z2); // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi)) dz[32] = 2.3967683924866621f * (x2 - y2) * (9.0f * z2 - 1.0f); // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi)) dz[33] = 8.8062893898345074f * xz * (-x2 + 3.0f * y2); // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi)) dz[34] = -12.453973889286246f * x2 * y2 + 2.0756623148810411f * x4 + 2.0756623148810411f * y4; // 3*sqrt(385)*(-6*x2*y2 + x4 + // y4)/(16*sqrt(pi)) dz[35] = 0.0f; // 0 if (C <= 6) { return; } dz[36] = 0.0f; // 0 dz[37] = 2.3666191622317521f * y * (10.0f * x2 * y2 - 5.0f * x4 
- y4); // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - // y4)/(32*sqrt(pi)) dz[38] = 44.401711264127719f * xy * z * (x2 - y2); // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi)) dz[39] = -2.7636157785447706f * y * (3.0f * x2 - y2) * (11.0f * z2 - 1.0f); // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 // - 1)/(32*sqrt(pi)) dz[40] = 11.054463114179082f * xy * z * (11.0f * z2 - 3.0f); // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi)) dz[41] = 2.9131068125936568f * y * (18.0f * z2 - 33.0f * z4 - 1.0f); // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) dz[42] = 2.6699064952403937f * z * (-30.0f * z2 + 33.0f * z4 + 5.0f); // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi)) dz[43] = 2.9131068125936568f * x * (18.0f * z2 - 33.0f * z4 - 1.0f); // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) dz[44] = 5.5272315570895412f * z * (x2 - y2) * (11.0f * z2 - 3.0f); // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - // 3)/(16*sqrt(pi)) dz[45] = -2.7636157785447706f * x * (x2 - 3.0f * y2) * (11.0f * z2 - 1.0f); // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 // - 1)/(32*sqrt(pi)) dz[46] = 11.10042781603193f * z * (-6.0f * x2 * y2 + x4 + y4); // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) dz[47] = 2.3666191622317521f * x * (10.0f * x2 * y2 - x4 - 5.0f * y4); // 3*sqrt(2002)*x*(10*x2*y2 - x4 - // 5*y4)/(32*sqrt(pi)) dz[48] = 0.0f; // 0 if (C <= 7) { return; } dz[49] = 0.0f; // 0 dz[50] = 5.2919213236038001f * xy * (-10.0f * x2 * y2 + 3.0f * x4 + 3.0f * y4); // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + // 3*y4)/(32*sqrt(pi)) dz[51] = 13.491805046726766f * yz * (10.0f * x2 * y2 - 5.0f * x4 - y4); // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - // y4)/(32*sqrt(pi)) dz[52] = 12.453973889286248f * xy * (x2 - y2) * (13.0f * z2 - 1.0f); // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - // 1)/(8*sqrt(pi)) dz[53] = -6.8841930899409371f * yz * (3.0f * x2 - y2) * (13.0f * z2 - 3.0f); // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 // - 3)/(16*sqrt(pi)) dz[54] = 2.2126634622249131f * xy * (-66.0f * z2 + 143.0f * z4 + 3.0f); // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + // 3)/(32*sqrt(pi)) 
dz[55] = 1.6259689364853116f * yz * (110.0f * z2 - 143.0f * z4 - 15.0f); // 9*sqrt(105)*yz*(110*z2 - 143*z4 - // 15)/(32*sqrt(pi)) dz[56] = 64.528641681844675f * z2 - 236.60501950009714f * z4 + 205.05768356675085f * z6 - 2.3899496919201733f; // 7*sqrt(15)*(135*z2 - 495*z4 + // 429*z6 - 5)/(32*sqrt(pi)) dz[57] = 1.6259689364853116f * xz * (110.0f * z2 - 143.0f * z4 - 15.0f); // 9*sqrt(105)*xz*(110*z2 - 143*z4 - // 15)/(32*sqrt(pi)) dz[58] = 0.07375544874083044f * (x2 - y2) * (143.0f * z2 * (3.0f * z2 - 1.0f) + 132.0f * z2 * (13.0f * z2 - 5.0f) - 187.0f * z2 + 45.0f); // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + // 132*z2*(13*z2 // - 5) - 187*z2 + 45)/(64*sqrt(pi)) dz[59] = -6.8841930899409371f * xz * (x2 - 3.0f * y2) * (13.0f * z2 - 3.0f); // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 // - 3)/(16*sqrt(pi)) dz[60] = 3.1134934723215619f * (13.0f * z2 - 1.0f) * (-6.0f * x2 * y2 + x4 + y4); // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + // y4)/(32*sqrt(pi)) dz[61] = 13.491805046726766f * xz * (10.0f * x2 * y2 - x4 - 5.0f * y4); // 39*sqrt(385)*xz*(10*x2*y2 - x4 - // 5*y4)/(32*sqrt(pi)) dz[62] = 39.6894099270285f * x2 * y4 - 39.6894099270285f * x4 * y2 + 2.6459606618019f * x6 - 2.6459606618019f * y6; // 3*sqrt(10010)*(15*x2*y4 - // 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) dz[63] = 0.0f; // 0 }; write_sh_dx(); write_sh_dy(); write_sh_dz(); } } template __global__ void kernel_sh_backward(const scalar_t *__restrict__ grad, const scalar_t *__restrict__ inputs, uint32_t B, uint32_t D, uint32_t C, const scalar_t *__restrict__ dy_dx, scalar_t *grad_inputs) { const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; const uint32_t b = t / D; if (b >= B) return; const uint32_t d = t - b * D; const uint32_t C2 = C * C; // locate grad += b * C2; dy_dx += b * D * C2 + d * C2; for (int ch = 0; ch < C2; ch++) { grad_inputs[t] += grad[ch] * dy_dx[ch]; // printf("t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\n", t, b, d, // ch, grad_inputs[t], grad[ch], dy_dx[ch]); } } // inputs: [B, D], float, in [0, 1] 
// outputs: [B, C * C], float -- C * C spherical-harmonic coefficients per
// input row (kernel_sh_backward strides `grad` by C * C per row).
// dy_dx (optional): [B, D, C * C], float -- per-input-dimension Jacobian,
// matching kernel_sh_backward's `b * D * C2 + d * C2` offset.
// NOTE(review): shencoder.h documents inputs as lying in [-1, 1].
template <typename scalar_t>
void sh_encode_forward_cuda(const scalar_t *inputs,
                            scalar_t *outputs,
                            const uint32_t B,
                            const uint32_t D,
                            const uint32_t C,
                            scalar_t *dy_dx) {
    static constexpr uint32_t N_THREADS = 256;
    if (B == 0) {
        return;  // A zero-block launch is an invalid configuration.
    }
    // NOTE(review): assumes kernel_sh (defined earlier in this file) maps
    // one thread per batch row and guards with `b >= B` -- confirm.
    const uint32_t n_blocks = (B + N_THREADS - 1) / N_THREADS;
    kernel_sh<scalar_t><<<n_blocks, N_THREADS>>>(inputs, outputs, B, D, C,
                                                 dy_dx);
}

// Backward launcher: one thread per (row b, input dim d) pair, B * D in
// total, matching kernel_sh_backward's `b = t / D`, `d = t - b * D` indexing.
// grad: [B, C * C]; dy_dx: [B, D, C * C]; grad_inputs: [B, D].
// kernel_sh_backward accumulates with `+=`, so grad_inputs must be
// zero-initialized by the caller.
template <typename scalar_t>
void sh_encode_backward_cuda(const scalar_t *grad,
                             const scalar_t *inputs,
                             const uint32_t B,
                             const uint32_t D,
                             const uint32_t C,
                             scalar_t *dy_dx,
                             scalar_t *grad_inputs) {
    static constexpr uint32_t N_THREADS = 256;
    const uint32_t n_total = B * D;
    if (n_total == 0) {
        return;  // A zero-block launch is an invalid configuration.
    }
    const uint32_t n_blocks = (n_total + N_THREADS - 1) / N_THREADS;
    kernel_sh_backward<scalar_t><<<n_blocks, N_THREADS>>>(
            grad, inputs, B, D, C, dy_dx, grad_inputs);
}

// Extension entry point for the forward pass: validates the tensors, then
// dispatches on the input dtype (float, double or half).
// inputs: [B, D]; outputs: [B, C * C]; dy_dx: optional [B, D, C * C]
// Jacobian buffer, written only when provided.
void sh_encode_forward(at::Tensor inputs,
                       at::Tensor outputs,
                       const uint32_t B,
                       const uint32_t D,
                       const uint32_t C,
                       at::optional<at::Tensor> dy_dx) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(outputs);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(outputs);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(outputs);
    // dy_dx is optional and can only be validated when present. The kernel
    // writes through its raw pointer, so it needs the same guarantees as
    // the other tensors.
    if (dy_dx.has_value()) {
        const at::Tensor &dy_dx_tensor = dy_dx.value();
        CHECK_CUDA(dy_dx_tensor);
        CHECK_CONTIGUOUS(dy_dx_tensor);
        CHECK_IS_FLOATING(dy_dx_tensor);
    }

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            inputs.scalar_type(), "sh_encode_forward_cuda", ([&] {
                sh_encode_forward_cuda<scalar_t>(
                        inputs.data_ptr<scalar_t>(),
                        outputs.data_ptr<scalar_t>(), B, D, C,
                        dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>()
                                          : nullptr);
            }));
}

// Extension entry point for the backward pass:
//   grad_inputs[b, d] += sum_ch grad[b, ch] * dy_dx[b, d, ch]
// grad_inputs is accumulated into, not overwritten, so pass a zeroed tensor.
void sh_encode_backward(at::Tensor grad,
                        at::Tensor inputs,
                        const uint32_t B,
                        const uint32_t D,
                        const uint32_t C,
                        at::Tensor dy_dx,
                        at::Tensor grad_inputs) {
    CHECK_CUDA(grad);
    CHECK_CUDA(inputs);
    CHECK_CUDA(dy_dx);
    CHECK_CUDA(grad_inputs);
    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(dy_dx);
    CHECK_CONTIGUOUS(grad_inputs);
    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(dy_dx);
    CHECK_IS_FLOATING(grad_inputs);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
            grad.scalar_type(), "sh_encode_backward_cuda", ([&] {
                sh_encode_backward_cuda<scalar_t>(
                        grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(),
                        B, D, C, dy_dx.data_ptr<scalar_t>(),
                        grad_inputs.data_ptr<scalar_t>());
            }));
}

================================================
FILE: lidarnerf/shencoder/src/shencoder.h
================================================
#pragma once

#include <stdint.h>
#include <torch/torch.h>

// inputs: [B, D], float, in [-1, 1]
// outputs: [B, C * C], float
// dy_dx (optional): [B, D, C * C], float

void sh_encode_forward(at::Tensor inputs,
                       at::Tensor outputs,
                       const uint32_t B,
                       const uint32_t D,
                       const uint32_t C,
                       at::optional<at::Tensor> dy_dx);

// grad_inputs ([B, D]) is accumulated into and must be zero-initialized.
void sh_encode_backward(at::Tensor grad,
                        at::Tensor inputs,
                        const uint32_t B,
                        const uint32_t D,
                        const uint32_t C,
                        at::Tensor dy_dx,
                        at::Tensor grad_inputs);

================================================
FILE: lidarnvs/__init__.py
================================================
__version__ = "0.0.1"

================================================
FILE: lidarnvs/configs/pcgen_kitti360_raydrop.txt
================================================
basedir = pcgen_raydrop_log/kitti360seq1908
datadir = data/raydrop/pcgen/kitti360_1908
dataset = kitti360
no_batching = False
lrate=5e-3
lrate_decay = 500
rgb_loss_type=mseloss
i_embed=-1
i_embed_views=-1
N_iters = 10000
cosLR=False
netdepth=4
netwidth=128
N_rand = 2048
H=66
W=1030
i_save=5000
i_print=100
i_weights=5000

================================================
FILE: lidarnvs/configs/pcgen_nerfmvl_raydrop.txt
================================================ # ['water_safety_barrier', 'tire', 'pier', 'plant', 'warning_sign', 'bollard', 'pedestrian', 'car', 'traffic_cone'] expname = car basedir = pcgen_raydrop_log datadir = data/raydrop/pcgen/nerf_mvl_car dataset = nerfmvl no_batching = False lrate=5e-3 lrate_decay = 500 rgb_loss_type=mseloss N_iters = 10000 cosLR=False netdepth=4 netwidth=128 N_rand = 2048 H=256 W=1800 i_save=5000 i_print=100 i_weights=5000 ================================================ FILE: lidarnvs/eval.py ================================================ import numpy as np import torch from skimage.metrics import structural_similarity from extern.chamfer3D.dist_chamfer_3D import chamfer_3DDist from extern.fscore import fscore def eval_points_and_pano( gt_local_points: np.ndarray, pd_local_points: np.ndarray, gt_intensities: np.ndarray, pd_intensities: np.ndarray, gt_pano: np.ndarray, pd_pano: np.ndarray, ) -> dict: """ Args: gt_local_points: (N, 3), float32, local point coords in world-scale. pd_local_points: (M, 3), float32, local point coords in world-scale. gt_intensities: (H, W), float32, point intensities, >= 0. pd_intensities: (H, W), float32, point intensities, >= 0. gt_pano: (H, W), float32, range depth image in world-scale. 0 means dropped rays. A dropped ray must not have intensity. pd_pano: (H, W), float32, range depth image in world-scale. 0 means dropped rays. A dropped ray must not have intensity. Returns: # Depth metrics - metrics["depth_rmse"] - metrics["depth_a1"] - metrics["depth_a2"] - metrics["depth_a3"] # Point metrics - metrics["chamfer"] - metrics["f_score"] # Intensity metrics - metrics["intensity_mae"] """ # Sanity checks. 
if not gt_local_points.ndim == 2 or not gt_local_points.shape[1] == 3: raise ValueError( f"gt_local_points must be (N, 3), but got {gt_local_points.shape}" ) if not pd_local_points.ndim == 2 or not pd_local_points.shape[1] == 3: raise ValueError( f"pd_local_points must be (M, 3), but got {pd_local_points.shape}" ) if not gt_intensities.ndim == 2: raise ValueError( f"gt_intensities must be (H, W), but got {gt_intensities.shape}" ) lidar_H, lidar_W = gt_intensities.shape if not pd_intensities.shape == (lidar_H, lidar_W): raise ValueError( f"pd_intensities must be (H, W), but got {pd_intensities.shape}" ) if not gt_pano.shape == (lidar_H, lidar_W): raise ValueError(f"gt_pano must be (H, W), but got {gt_pano.shape}") if not pd_pano.shape == (lidar_H, lidar_W): raise ValueError(f"pd_pano must be (H, W), but got {pd_pano.shape}") # All shall be numpy array is_instance_all = [ isinstance(e, np.ndarray) for e in [ gt_local_points, pd_local_points, gt_intensities, pd_intensities, gt_pano, pd_pano, ] ] if not all(is_instance_all): raise ValueError("All inputs must be numpy array.") def compute_depth_metrics( gt_depths, pd_depths, min_depth=1e-3, max_depth=80, thresh_set=1.25 ): pd_depths[pd_depths < min_depth] = min_depth pd_depths[pd_depths > max_depth] = max_depth gt_depths[gt_depths < min_depth] = min_depth gt_depths[gt_depths > max_depth] = max_depth thresh = np.maximum((gt_depths / pd_depths), (pd_depths / gt_depths)) a1 = (thresh < thresh_set).mean() a2 = (thresh < thresh_set**2).mean() a3 = (thresh < thresh_set**3).mean() rmse = (gt_depths - pd_depths) ** 2 rmse = np.sqrt(rmse.mean()) ssim = structural_similarity( gt_depths, pd_depths, data_range=gt_depths.max() - gt_depths.min(), ) return rmse, a1, a2, a3, ssim def compute_point_metrics(gt_points, pd_points): chamLoss = chamfer_3DDist() dist1, dist2, idx1, idx2 = chamLoss( torch.tensor(pd_points[None, ...]).float().cuda(), torch.tensor(gt_points[None, ...]).float().cuda(), ) chamfer_dis = dist1.mean() + dist2.mean() 
threshold = 0.05 # monoSDF f_score, precision, recall = fscore(dist1, dist2, threshold) chamfer_dis = chamfer_dis.item() f_score = f_score.item() return chamfer_dis, f_score def compute_intensity_metrics(gt_intensities, pd_intensities): mae = np.abs(gt_intensities - pd_intensities).mean() return mae metrics = dict() ( metrics["depth_rmse"], metrics["depth_a1"], metrics["depth_a2"], metrics["depth_a3"], metrics["depth_ssim"], ) = compute_depth_metrics(gt_depths=gt_pano.flatten(), pd_depths=pd_pano.flatten()) ( metrics["chamfer"], metrics["f_score"], ) = compute_point_metrics(gt_points=gt_local_points, pd_points=pd_local_points) metrics["intensity_mae"] = compute_intensity_metrics( gt_intensities=gt_intensities, pd_intensities=pd_intensities ) return metrics ================================================ FILE: lidarnvs/lidarnvs_base.py ================================================ from abc import ABC, abstractmethod import numpy as np class LidarNVSBase(ABC): @abstractmethod def fit(self, dataset) -> None: """ Fit the model to the given train dataset. Args: dataset: A NeRFDataset object. """ @abstractmethod def predict_frame( self, lidar_K: np.ndarray, # (2, ) lidar_pose: np.ndarray, # (4, 4) lidar_H: int, lidar_W: int, ) -> dict: """ Predict (synthesis) the point cloud from the given lidar parameters. All necessary information parameters to model a lidar are given. 
Args: lidar_K: (2, ), float32 lidar_pose: (4, 4), float32 lidar_H: int lidar_W: int Return: predict_dict: dict - ["local_points"]: (N, 3), float32 - ["points"] : (N, 3), float32 - ["pano"] : (H, W), float32 - ["intensities"] : (H, W), float32 """ @abstractmethod def predict_frame_with_raydrop( self, lidar_K: np.ndarray, # (2, ) lidar_pose: np.ndarray, # (4, 4) lidar_H: int, lidar_W: int, ) -> dict: pass ================================================ FILE: lidarnvs/lidarnvs_meshing.py ================================================ import camtools as ct import matplotlib.pyplot as plt import numpy as np import open3d as o3d import open3d.core as o3c import torch import torch.nn.functional as F from tqdm import tqdm from lidarnerf.convert import ( lidar_to_pano_with_intensities, pano_to_lidar_with_intensities, ) from lidarnerf.dataset.base_dataset import get_lidar_rays from lidarnvs.lidarnvs_base import LidarNVSBase from lidarnvs.loader import extract_dataset_frame from lidarnvs.unet import UNet class LidarNVSMeshing(LidarNVSBase): """ Liar novel-view synthesis with meshing and ray casting. This is intended to be a base class, where the children class can use different meshing methods. """ def __init__(self, ckpt_path=None): self.ckpt_path = ckpt_path # Network for predicting ray-drop. if ckpt_path is not None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = UNet(n_channels=10, n_classes=1, bilinear=False) self.model = self.model.to(memory_format=torch.channels_last) self.model = self.model.to(device=self.device) state_dict = torch.load(self.ckpt_path, map_location=self.device) self.model.load_state_dict(state_dict) self.model.eval() print(f"Checkpoint loaded from {self.ckpt_path}") # To be filled in the fit() method. self.points = None self.point_intensities = None self.pcd = None self.kdtree = None self.mesh = None # To be overwritten by the child class. 
# - meshing_func: o3d.geometry.PointCloud -> o3d.geometry.TriangleMesh # - the meshing_func shall already be populated with hyper-parameters self.meshing_func = None def fit(self, dataset) -> None: """ Fit the model to the given train dataset. Args: dataset: A NeRFDataset object. """ # Extract all points, in world coordinates. num_frames = len(dataset) all_points = [] all_point_intensities = [] for frame_idx in tqdm(range(num_frames), "Extract train frames"): frame_dict = extract_dataset_frame(dataset, frame_idx) all_points.append(frame_dict["points"]) all_point_intensities.append(frame_dict["point_intensities"]) all_points = np.vstack(all_points) all_point_intensities = np.hstack(all_point_intensities) assert len(all_points) == len(all_point_intensities) # Build Open3D pcd. self.pcd = o3d.geometry.PointCloud() self.pcd.points = o3d.utility.Vector3dVector(all_points) colors = ct.colormap.query(all_point_intensities) self.pcd.colors = o3d.utility.Vector3dVector(colors) self.pcd.estimate_normals() # Save points and intensities for interpolation. self.points = all_points self.point_intensities = all_point_intensities # Run Poisson recon. self.mesh = self.meshing_func(self.pcd) self.mesh.compute_vertex_normals() # o3d.visualization.draw_geometries([self.mesh]) # Build kdtree for kNN search. self.kdtree = o3d.geometry.KDTreeFlann(self.pcd) # Build scene for ray casting. self.raycasting_scene = o3d.t.geometry.RaycastingScene() self.raycasting_scene.add_triangles( o3d.t.geometry.TriangleMesh.from_legacy(self.mesh) ) def predict_frame( self, lidar_K: np.ndarray, # (2, ) lidar_pose: np.ndarray, # (4, 4) lidar_H: int, lidar_W: int, ) -> dict: """ Predict (synthesis) the point cloud from the given lidar parameters. All necessary information parameters to model a lidar are given. 
Args: lidar_K: (2, ), float32 lidar_pose: (4, 4), float32 lidar_H: int lidar_W: int Return: predict_dict: dict - ["local_points"]: (N, 3), float32 - ["points"] : (N, 3), float32 - ["pano"] : (H, W), float32 - ["intensities"] : (H, W), float32 """ # In world and local coordinates. hit_dict = self.intersect_lidar(lidar_K, lidar_pose, lidar_H, lidar_W) points = hit_dict["points"][hit_dict["masks"]] local_points = ct.project.homo_project( points, ct.convert.pose_to_T(lidar_pose), ) # Point intensities in world/local coordinates. point_intensities = [] for point in points: # ks, indices, distances2 _, indices, _ = self.kdtree.search_knn_vector_3d( point, self.intensity_interpolate_k ) point_intensities.append(np.mean(self.point_intensities[indices])) point_intensities = np.array(point_intensities) local_point_intensities = point_intensities # Pano intensities. local_points_with_intensities = np.concatenate( [local_points, local_point_intensities.reshape((-1, 1))], axis=1 ) pano, intensities = lidar_to_pano_with_intensities( local_points_with_intensities=local_points_with_intensities, lidar_H=lidar_H, lidar_W=lidar_W, lidar_K=lidar_K, ) predict_dict = { # Frame properties. "pano": pano, "intensities": intensities, # Global properties. "points": points, "point_intensities": point_intensities, # Local properties. "local_points": local_points, "local_point_intensities": local_point_intensities, # Hit properties: unfiltered results from ray casting. "hit_dict": hit_dict, } return predict_dict @torch.inference_mode() def predict_frame_with_raydrop( self, lidar_K: np.ndarray, # (2, ) lidar_pose: np.ndarray, # (4, 4) lidar_H: int, lidar_W: int, ) -> dict: """ TODO: I know this is ugly. This is the manual combination of: - generate_raydrop_data() - RaydropDataset::collate_fn() """ nvs_frame = self.predict_frame( lidar_K=lidar_K, lidar_pose=lidar_pose, lidar_H=lidar_H, lidar_W=lidar_W, ) # Compute incidence angle cosine. # TODO: make get rays a function. 
ray_dict = get_lidar_rays( poses=torch.tensor(np.array([lidar_pose])), intrinsics=torch.tensor(lidar_K), H=torch.tensor(lidar_H), W=torch.tensor(lidar_W), ) # generate_raydrop_data() ############################################ rays_o = ray_dict["rays_o"].squeeze().numpy() rays_d = ray_dict["rays_d"].squeeze().numpy() hit_normals = nvs_frame["hit_dict"]["normals"] hit_incidences = np.abs(np.sum(rays_d * hit_normals, axis=-1)) # Reshape. hit_masks = nvs_frame["hit_dict"]["masks"] hit_masks = hit_masks.reshape((lidar_H, lidar_W)) hit_depths = nvs_frame["hit_dict"]["depths"] hit_depths[hit_depths == np.inf] = 0 hit_depths = hit_depths.reshape((lidar_H, lidar_W)) hit_normals = hit_normals.reshape((lidar_H, lidar_W, 3)) hit_incidences = hit_incidences.reshape((lidar_H, lidar_W)) intensities = nvs_frame["intensities"] intensities = intensities.reshape((lidar_H, lidar_W)) rays_o = rays_o.reshape((lidar_H, lidar_W, 3)) rays_d = rays_d.reshape((lidar_H, lidar_W, 3)) # Cast hit_masks = torch.tensor(hit_masks.astype(np.float32)) hit_depths = torch.tensor(hit_depths.astype(np.float32)) hit_normals = torch.tensor(hit_normals.astype(np.float32)) hit_incidences = torch.tensor(hit_incidences.astype(np.float32)) intensities = torch.tensor(intensities.astype(np.float32)) rays_o = torch.tensor(rays_o.astype(np.float32)) rays_d = torch.tensor(rays_d.astype(np.float32)) # Add batch dimension 1 to the front hit_masks = hit_masks.unsqueeze(0) hit_depths = hit_depths.unsqueeze(0) hit_normals = hit_normals.unsqueeze(0) hit_incidences = hit_incidences.unsqueeze(0) intensities = intensities.unsqueeze(0) rays_o = rays_o.unsqueeze(0) rays_d = rays_d.unsqueeze(0) ###################################################################### # RaydropDataset::collate_fn() ####################################### # (N, H, W, C) images = torch.cat( [ hit_masks[..., None].to(self.device), hit_depths[..., None].to(self.device), hit_normals.to(self.device), hit_incidences[..., None].to(self.device), 
intensities[..., None].to(self.device), rays_d.to(self.device), ], dim=3, ) # (N, C, H, W) images = images.permute(0, 3, 1, 2) ###################################################################### # Predict raydrop mask. pd_raydrop_masks = self.model(images) pd_raydrop_masks = (F.sigmoid(pd_raydrop_masks) > 0.5).float() pd_raydrop_masks = pd_raydrop_masks.squeeze().cpu().numpy() if False: plt.imshow(pd_raydrop_masks) plt.show() # Update predict_dict # Frame properties. pano = nvs_frame["pano"] * pd_raydrop_masks intensities = nvs_frame["intensities"] * pd_raydrop_masks # Local properties. local_points_with_intensities = pano_to_lidar_with_intensities( pano=pano, intensities=intensities, lidar_K=lidar_K, ) local_points = local_points_with_intensities[:, :3] local_point_intensities = local_points_with_intensities[:, 3] # Global properties. points = ct.project.homo_project(local_points, lidar_pose) point_intensities = local_point_intensities predict_dict = { # Frame properties. "pano": pano, "intensities": intensities, # Global properties. "points": points, "point_intensities": point_intensities, # Local properties. "local_points": local_points, "local_point_intensities": local_point_intensities, # Hit properties: unfiltered results from ray casting. "hit_dict": nvs_frame["hit_dict"], } return predict_dict def intersect_rays(self, rays): """ Compute ray-mesh intersect and return the hit_dict. The hit_dict will NOT be filtered, but the masks will be provided. Args: mesh: o3d.geometry.TriangleMesh rays: (N, 6), float32 rays, where rays[:, :3] is the origin and rays[:, 3:] is the direction. The directions do not need to be normalized. Return: hit_dict - ["masks"] : (N, ) , boolean mask of ray hit. - ["depths"] : (N, ) , depth in world-scale. - ["points"] : (N, 3), coordinates of the intersection points. - ["normals"]: (N, 3), normals of the hit triangles. """ # Sanity checks. 
if not isinstance(rays, np.ndarray): raise TypeError("rays must be a numpy array.") if rays.ndim != 2 or rays.shape[1] != 6: raise ValueError("rays must be a (N, 6) array.") # Run ray cast. ray_cast_results = self.raycasting_scene.cast_rays(o3c.Tensor(rays)) normals = ray_cast_results["primitive_normals"].numpy() depths = ray_cast_results["t_hit"].numpy() masks = depths != np.inf rays_o = rays[:, :3] rays_d = rays[:, 3:] rays_d = rays_d / np.linalg.norm(rays_d, axis=1, keepdims=True) points = rays_o + rays_d * depths[:, None] hit_dict = { "masks": masks, "depths": depths, "points": points, "normals": normals, } return hit_dict def intersect_lidar( self, lidar_K: np.ndarray, # (2, ) lidar_pose: np.ndarray, # (4, 4) lidar_H: int, lidar_W: int, ): ray_dict = get_lidar_rays( poses=torch.tensor(np.array([lidar_pose])), intrinsics=torch.tensor(lidar_K), H=torch.tensor(lidar_H), W=torch.tensor(lidar_W), ) rays_o = ray_dict["rays_o"].squeeze().numpy() rays_d = ray_dict["rays_d"].squeeze().numpy() rays = np.concatenate([rays_o, rays_d], axis=-1) hit_dict = self.intersect_rays(rays) return hit_dict def generate_raydrop_data_meshing(dataset, nvs: LidarNVSMeshing) -> dict: """ Prepare dataset for learning ray drop. The frames are NOT loaded by our dataset, but GENERATED. Return: raydrop_data = [ { "hit_masks" : (H, W) # Ray cast hit mask "hit_depths" : (H, W) # Hit intersection point depths "hit_normals" : (H, W, 3) # Intersection point normal "hit_incidences" : (H, W) # |cos(normal, ray_d)| "intensities" : (H, W) # Predicted intensities "rays_o" : (H, W, 3) # Lidar ray origin "rays_d" : (H, W, 3) # Lidar ray direction "raydrop_masks" : (H, W) # Ray drop mask, 1 is valid }, ... 
] """ raydrop_data = [] for frame_idx in tqdm(range(len(dataset)), desc="Prepare raydrop dataset"): gt_frame = extract_dataset_frame(dataset, frame_idx=frame_idx) nvs_frame = nvs.predict_frame( lidar_K=gt_frame["lidar_K"], lidar_pose=gt_frame["lidar_pose"], lidar_H=gt_frame["lidar_H"], lidar_W=gt_frame["lidar_W"], ) # The target. raydrop_masks = gt_frame["pano"] != 0 # Compute incidence angle cosine. rays_o = gt_frame["rays"][:, :3] rays_d = gt_frame["rays"][:, 3:] hit_normals = nvs_frame["hit_dict"]["normals"] hit_incidences = np.abs(np.sum(rays_d * hit_normals, axis=-1)) # Pre-processing. # TODO: move the reshape to upper-level lidar_H, lidar_W = gt_frame["lidar_H"], gt_frame["lidar_W"] # Reshape. hit_masks = nvs_frame["hit_dict"]["masks"] hit_masks = hit_masks.reshape((lidar_H, lidar_W)) hit_depths = nvs_frame["hit_dict"]["depths"] hit_depths[hit_depths == np.inf] = 0 hit_depths = hit_depths.reshape((lidar_H, lidar_W)) hit_normals = hit_normals.reshape((lidar_H, lidar_W, 3)) hit_incidences = hit_incidences.reshape((lidar_H, lidar_W)) intensities = nvs_frame["intensities"] intensities = intensities.reshape((lidar_H, lidar_W)) rays_o = rays_o.reshape((lidar_H, lidar_W, 3)) rays_d = rays_d.reshape((lidar_H, lidar_W, 3)) raydrop_masks = raydrop_masks.reshape((lidar_H, lidar_W)) # Cast. 
hit_masks = hit_masks.astype(np.float32) hit_depths = hit_depths.astype(np.float32) hit_normals = hit_normals.astype(np.float32) hit_incidences = hit_incidences.astype(np.float32) intensities = intensities.astype(np.float32) rays_o = rays_o.astype(np.float32) rays_d = rays_d.astype(np.float32) raydrop_masks = raydrop_masks.astype(np.float32) raydrop_datum = { "hit_masks": hit_masks, "hit_depths": hit_depths, "hit_normals": hit_normals, "hit_incidences": hit_incidences, "intensities": intensities, "rays_o": rays_o, "rays_d": rays_d, "raydrop_masks": raydrop_masks, } raydrop_data.append(raydrop_datum) return raydrop_data ================================================ FILE: lidarnvs/lidarnvs_nksr.py ================================================ import open3d as o3d import numpy as np from lidarnvs.lidarnvs_meshing import LidarNVSMeshing import torch import nksr class LidarNVSNksr(LidarNVSMeshing): def __init__(self, ckpt_path=None): super(LidarNVSNksr, self).__init__(ckpt_path=ckpt_path) # To be filled in the fit() method. self.points = None self.point_intensities = None self.pcd = None self.kdtree = None self.mesh = None self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.nksr_reconstructor = nksr.Reconstructor(self.device) # self.meshing_func shall be pre-filled with functools.partial. 
self.meshing_func = self._run_nksr def _run_nksr( self, pcd: o3d.geometry.PointCloud, ) -> o3d.geometry.TriangleMesh: print("Start _run_nksr()") pcd.estimate_normals() input_xyz = torch.from_numpy(np.asarray(pcd.points)).float().to(self.device) input_normal = torch.from_numpy(np.asarray(pcd.normals)).float().to(self.device) field = self.nksr_reconstructor.reconstruct( input_xyz, input_normal, detail_level=0.5 ) mesh = field.extract_dual_mesh(mise_iter=1) vertices = mesh.v.cpu().numpy() triangles = mesh.f.cpu().numpy() mesh = o3d.geometry.TriangleMesh() mesh.vertices = o3d.utility.Vector3dVector(vertices) mesh.triangles = o3d.utility.Vector3iVector(triangles) mesh.compute_vertex_normals() return mesh ================================================ FILE: lidarnvs/lidarnvs_pcgen.py ================================================ import camtools as ct import numpy as np import torch from tqdm import tqdm from lidarnerf.convert import ( lidar_to_pano_with_intensities, lidar_to_pano_with_intensities_fpa, pano_to_lidar_with_intensities, ) from lidarnvs.loader import extract_dataset_frame from lidarnvs.raydrop_train_pcgen import RayDrop, run_network, get_embedder from lidarnvs.lidarnvs_base import LidarNVSBase class LidarNVSPCGen(LidarNVSBase): def __init__(self, raycasting="cp", ckpt_path=None): self.raycasting = raycasting # Network for predicting raydrop. 
if ckpt_path is not None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.embed_fn, input_ch = get_embedder(4, input_dims=1, i=-1) self.embeddirs_fn, input_ch_views = get_embedder(10, input_dims=3, i=-1) total_input_ch = input_ch * 2 + input_ch_views netdepth, netwidth = 4, 128 self.model = RayDrop(D=netdepth, W=netwidth, input_ch=total_input_ch).to( self.device ) ckpt = torch.load(ckpt_path) self.model.load_state_dict(ckpt["network_fn_state_dict"]) self.model.eval() print(f"Checkpoint loaded from {ckpt_path}") def fit(self, dataset) -> None: """ Fit the model to the given train dataset. Args: dataset: A NeRFDataset object. """ # Extract all points, in world coordinates. num_frames = len(dataset) all_points = [] all_point_intensities = [] for frame_idx in tqdm(range(num_frames), "Extract train frames"): frame_dict = extract_dataset_frame(dataset, frame_idx) all_points.append(frame_dict["points"]) all_point_intensities.append(frame_dict["point_intensities"]) all_points = np.vstack(all_points) all_point_intensities = np.hstack(all_point_intensities) assert len(all_points) == len(all_point_intensities) # Save points and intensities for interpolation. self.points = all_points self.point_intensities = all_point_intensities def predict_frame( self, lidar_K: np.ndarray, # (2, ) lidar_pose: np.ndarray, # (4, 4) lidar_H: int, lidar_W: int, ) -> dict: """ Predict (synthesis) the point cloud from the given lidar parameters. All necessary information parameters to model a lidar are given. Args: lidar_K: (2, ), float32 lidar_pose: (4, 4), float32 lidar_H: int lidar_W: int Return: predict_dict: dict - ["local_points"]: (N, 3), float32 - ["points"] : (N, 3), float32 - ["pano"] : (H, W), float32 - ["intensities"] : (H, W), float32 """ # In world and local coordinates. local_points = ct.project.homo_project( self.points, ct.convert.pose_to_T(lidar_pose), ) # Pano intensities. 
local_points_with_intensities = np.concatenate( [local_points, self.point_intensities.reshape((-1, 1))], axis=1 ) if self.raycasting == "cp": pano, intensities = lidar_to_pano_with_intensities( local_points_with_intensities=local_points_with_intensities, lidar_H=lidar_H, lidar_W=lidar_W, lidar_K=lidar_K, ) elif self.raycasting == "fpa": pano, intensities = lidar_to_pano_with_intensities_fpa( local_points_with_intensities=local_points_with_intensities, lidar_H=lidar_H, lidar_W=lidar_W, lidar_K=lidar_K, ) local_points_with_intensities = pano_to_lidar_with_intensities( pano=pano, intensities=intensities, lidar_K=lidar_K ) local_points = local_points_with_intensities[:, :3] local_point_intensities = local_points_with_intensities[:, 3] points = ct.project.homo_project(local_points, lidar_pose) point_intensities = local_point_intensities predict_dict = { # Frame properties. "pano": pano, "intensities": intensities, # Global properties. "points": points, "point_intensities": point_intensities, # Local properties. "local_points": local_points, "local_point_intensities": local_point_intensities, } return predict_dict @torch.inference_mode() def predict_frame_with_raydrop( self, lidar_K: np.ndarray, # (2, ) lidar_pose: np.ndarray, # (4, 4) lidar_H: int, lidar_W: int, ) -> dict: nvs_frame = self.predict_frame( lidar_K=lidar_K, lidar_pose=lidar_pose, lidar_H=lidar_H, lidar_W=lidar_W, ) direction = get_direction(lidar_H, lidar_W, lidar_K) pano = nvs_frame["pano"] intensity = nvs_frame["intensities"] rays_val = np.concatenate( ( np.array(direction).reshape(-1, 3), np.array(pano).reshape(-1, 1), np.array(intensity).reshape(-1, 1), ), -1, ) rays_val = torch.Tensor(rays_val).to(self.device) pd_raydrop_masks = run_network( rays_val, self.model, self.embed_fn, self.embeddirs_fn ) pd_raydrop_masks = np.where( pd_raydrop_masks.cpu().numpy() > 0.5, 1.0, 0.0 ).reshape(lidar_H, lidar_W) # Update predict_dict # Frame properties. 
pano = nvs_frame["pano"] intensities = nvs_frame["intensities"] if not np.all(pd_raydrop_masks == 0): pano = pano * pd_raydrop_masks intensities = intensities * pd_raydrop_masks # Local properties. local_points_with_intensities = pano_to_lidar_with_intensities( pano=pano, intensities=intensities, lidar_K=lidar_K, ) local_points = local_points_with_intensities[:, :3] local_point_intensities = local_points_with_intensities[:, 3] # Global properties. points = ct.project.homo_project(local_points, lidar_pose) point_intensities = local_point_intensities predict_dict = { # Frame properties. "pano": pano, "intensities": intensities, # Global properties. "points": points, "point_intensities": point_intensities, # Local properties. "local_points": local_points, "local_point_intensities": local_point_intensities, } return predict_dict def generate_raydrop_data_pcgen(dataset, nvs: LidarNVSPCGen, rm_pano_mask=True) -> dict: """ Prepare dataset for learning ray drop. The frames are NOT loaded by our dataset, but GENERATED. Return: directions, panos, intensities, raydrop_masks """ raydrop_masks = [] directions = [] panos = [] intensities = [] for frame_idx in tqdm(range(len(dataset)), desc="Prepare raydrop dataset"): gt_frame = extract_dataset_frame( dataset, frame_idx=frame_idx, rm_pano_mask=rm_pano_mask ) nvs_frame = nvs.predict_frame( lidar_K=gt_frame["lidar_K"], lidar_pose=gt_frame["lidar_pose"], lidar_H=gt_frame["lidar_H"], lidar_W=gt_frame["lidar_W"], ) # The target. 
raydrop_masks.append(gt_frame["pano"]) # The inputs lidar_H, lidar_W, lidar_K = ( gt_frame["lidar_H"], gt_frame["lidar_W"], gt_frame["lidar_K"], ) directions.append(get_direction(lidar_H, lidar_W, lidar_K)) panos.append(nvs_frame["pano"]) intensities.append(nvs_frame["intensities"]) return (directions, panos, intensities, raydrop_masks) def get_direction(lidar_H, lidar_W, lidar_K): fov_up, fov = lidar_K i, j = np.meshgrid( np.arange(lidar_W, dtype=np.float32), np.arange(lidar_H, dtype=np.float32), indexing="xy", ) beta = -(i - lidar_W / 2) / lidar_W * 2 * np.pi alpha = (fov_up - j / lidar_H * fov) / 180 * np.pi dirs = np.stack( [np.cos(alpha) * np.cos(beta), np.cos(alpha) * np.sin(beta), np.sin(alpha)], -1 ) return dirs ================================================ FILE: lidarnvs/lidarnvs_poisson.py ================================================ import time import open3d as o3d import numpy as np from lidarnvs.lidarnvs_meshing import LidarNVSMeshing import functools class LidarNVSPoisson(LidarNVSMeshing): @staticmethod def _run_poisson( pcd: o3d.geometry.PointCloud, depth: int, min_density: int, ) -> o3d.geometry.TriangleMesh: print("Start _run_poisson()") s_time = time.time() # Run. mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson( pcd, depth=depth, ) # Filter by density. vertices_to_remove = densities < np.quantile(densities, min_density) mesh.remove_vertices_by_mask(vertices_to_remove) # All-black colors are generated, but we don't need them. mesh.vertex_colors = o3d.utility.Vector3dVector([]) print(f"_run_poisson() time: {time.time() - s_time:.3f} secs") return mesh def __init__( self, poisson_depth=10, poisson_min_density=0.6, intensity_interpolate_k=5, ckpt_path=None, ): super(LidarNVSPoisson, self).__init__(ckpt_path=ckpt_path) self.poisson_depth = poisson_depth self.poisson_min_density = poisson_min_density self.intensity_interpolate_k = intensity_interpolate_k # To be filled in the fit() method. 
self.points = None self.point_intensities = None self.pcd = None self.kdtree = None self.mesh = None # self.meshing_func shall be pre-filled with functools.partial. self.meshing_func = functools.partial( LidarNVSPoisson._run_poisson, depth=self.poisson_depth, min_density=self.poisson_min_density, ) ================================================ FILE: lidarnvs/loader.py ================================================ import camtools as ct import numpy as np from lidarnerf.dataset.base_dataset import get_lidar_rays from lidarnerf.convert import pano_to_lidar_with_intensities def extract_dataset_frame( dataset, frame_idx: int, rm_pano_mask: bool = True, verbose: bool = False ) -> dict: """ Extract a single frame from a dataset object. """ # Unpack dataset. lidar_pose = dataset.poses_lidar[frame_idx].numpy() pano = dataset.images_lidar[frame_idx][:, :, 2].numpy() intensities = dataset.images_lidar[frame_idx][:, :, 1].numpy() lidar_K = dataset.intrinsics_lidar lidar_H = dataset.H_lidar lidar_W = dataset.W_lidar # Process pano mask. # TODO: remove this. pano_mask = pano != -1 if rm_pano_mask: pano[pano == -1] = 0 # Load rays. ray_dict = get_lidar_rays( poses=dataset.poses_lidar[[frame_idx]], intrinsics=dataset.intrinsics_lidar, H=dataset.H_lidar, W=dataset.W_lidar, N=-1, patch_size=1, ) rays_o = ray_dict["rays_o"].squeeze().numpy() rays_d = ray_dict["rays_d"].squeeze().numpy() rays = np.concatenate([rays_o, rays_d], axis=-1) # Generate gt data. # pose: cam to world projection matrix. # T : world to cam projection matrix. # (N, 4) local_points_with_intensities = pano_to_lidar_with_intensities( pano=pano, intensities=intensities, lidar_K=lidar_K, ) local_points = local_points_with_intensities[:, :3] local_point_intensities = local_points_with_intensities[:, 3] # Project local to world coordinates. points = ct.project.homo_project(local_points, lidar_pose) point_intensities = local_point_intensities # "pano" : invalid points marked as 0 (depth). 
# "intensities": 0 means likely invalid, but not 100%. frame_dict = { "rays": rays, "lidar_pose": lidar_pose, "lidar_K": lidar_K, "lidar_H": lidar_H, "lidar_W": lidar_W, # Frame properties. "pano": pano, "pano_mask": pano_mask, "intensities": intensities, # Local coord properties. "local_points": local_points, "local_point_intensities": local_point_intensities, # World coord properties. "points": points, "point_intensities": point_intensities, } if verbose: for key, val in frame_dict.items(): if isinstance(val, np.ndarray): print(f"- {key}: {val.shape}") else: print(f"- {key}: {val}") return frame_dict ================================================ FILE: lidarnvs/plot_possion_grid_search.py ================================================ from pathlib import Path import json import matplotlib.pyplot as plt import numpy as np def main(): json_path = Path("poisson_grid_search.json") with open(json_path, "r") as f: data = json.load(f) min_chamfer = 1e10 min_datum = None for datum in data: if datum["chamfer"] < min_chamfer: min_chamfer = datum["chamfer"] min_datum = datum print(f"min_chamfer: {min_chamfer}") print(f"min_datum: {min_datum}") # Fill confusion matrix. 
col_vals = [8, 9, 10, 11, 12] row_vals = [0.4, 0.3, 0.2] conf_matrix = np.zeros((len(row_vals), len(col_vals))) for datum in data: min_density = datum["poisson_min_density"] poisson_depth = datum["poisson_depth"] if min_density not in row_vals or poisson_depth not in col_vals: continue row_idx = row_vals.index(min_density) col_idx = col_vals.index(poisson_depth) conf_matrix[row_idx, col_idx] = datum["chamfer"] # Print the confusion matrix using Matplotlib fig, ax = plt.subplots(figsize=(7.5, 7.5)) ax.matshow(conf_matrix, cmap=plt.cm.Blues, alpha=0.3) for i in range(conf_matrix.shape[0]): for j in range(conf_matrix.shape[1]): ax.text( x=j, y=i, s=f"{conf_matrix[i, j]:.2f}", va="center", ha="center", size="xx-large", ) ax.set_xticklabels([""] + [str(v) for v in col_vals]) ax.set_yticklabels([""] + [str(v) for v in row_vals]) plt.xlabel("Poisson Depth", fontsize=18) plt.ylabel("Min Density", fontsize=18) plt.show() if __name__ == "__main__": main() ================================================ FILE: lidarnvs/raydrop_dataset_poisson.py ================================================ import pickle from pathlib import Path import torch from torch.utils.data import Dataset class RaydropDataset(Dataset): def __init__(self, data_dir, split): self.data_dir = Path(data_dir) self.split = split if not self.data_dir.is_dir(): raise ValueError(f"Directory {self.data_dir} does not exist.") if self.split not in ["train", "test"]: raise ValueError(f"Split {self.split} not supported.") pkl_path = self.data_dir / f"{self.split}_data.pkl" if not pkl_path.is_file(): raise ValueError(f"File {pkl_path} does not exist.") with open(pkl_path, "rb") as f: self.raydrop_data = pickle.load(f) def __len__(self): return len(self.raydrop_data) def __getitem__(self, idx): return self.raydrop_data[idx] @staticmethod def collate_fn(batch): """ RaydropDataset is a dict-style dataset, where __getitem__(i) returns a dictionary of tensors. 
Essentially, A dataloader will do: ```python for indices in batch_sampler: yield collate_fn([dataset[i] for i in indices]) ``` Args: batch: list of dicts. Return: images: (N, C, H, W) tensor, float32. masks : (N, H, W) tensor, float32. 1 means valid ray. """ # First, call the default colllate_fn. batch = torch.utils.data.default_collate(batch) # (N, H, W, C) images = torch.cat( [ batch["hit_masks"][..., None], batch["hit_depths"][..., None], batch["hit_normals"], batch["hit_incidences"][..., None], batch["intensities"][..., None], batch["rays_d"], ], dim=3, ) # (N, C, H, W) images = images.permute(0, 3, 1, 2) # (N, H, W) masks = batch["raydrop_masks"] return images, masks ================================================ FILE: lidarnvs/raydrop_train_pcgen.py ================================================ import os import numpy as np import imageio import random import torch import torch.nn as nn import matplotlib.pyplot as plt import torch.nn.functional as F from pathlib import Path import pickle l1loss = nn.L1Loss(reduction="mean") mseloss = nn.MSELoss() img2mse = lambda x, y: torch.mean((x - y) ** 2) to8b = ( lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) if np.max(x) < 10 else (255.0 * np.clip(x / 81.0, 0, 1)).astype(np.uint8) ) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def setup_seed(seed): np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) random.seed(seed) torch.backends.cudnn.deterministic = True setup_seed(0) def cal_psnr(im1, im2): mse = (np.abs(im1 - im2) ** 2).mean() psnr = -10 * np.log10(mse) # max_value = 1 return psnr class RayDrop(nn.Module): def __init__(self, D=4, W=128, input_ch=3, output_ch=1): """ """ super(RayDrop, self).__init__() self.D = D self.W = W self.input_ch = input_ch self.linears = nn.ModuleList( [nn.Linear(input_ch, W)] + [nn.Linear(W, W) for i in range(D - 1)] ) self.output_linear = nn.Linear(W, output_ch) self.linears.apply(weights_init) 
self.output_linear.apply(weights_init) def forward(self, x): h = x for i, l in enumerate(self.linears): h = self.linears[i](h) h = F.relu(h) output = self.output_linear(h) return output def weights_init(m): if isinstance(m, nn.Linear): nn.init.kaiming_normal_(m.weight.data) if m.bias is not None: nn.init.zeros_(m.bias.data) def config_parser(): import configargparse parser = configargparse.ArgumentParser() parser.add_argument("--config", is_config_file=True, help="config file path") parser.add_argument( "--expname", type=str, default="raysdrop", help="experiment name" ) parser.add_argument( "--basedir", type=str, default="./log", help="where to store ckpts and logs" ) parser.add_argument( "--datadir", type=str, default="/data/usr/ziguo.tt/working/nerf/data/", help="input data directory", ) parser.add_argument( "--no_reload", action="store_true", help="do not reload weights from saved ckpt" ) parser.add_argument( "--dataset", type=str, default="kitti360", choices=["kitti360", "nerfmvl"], help="The dataset loader to use.", ) # training options parser.add_argument("--netdepth", type=int, default=8, help="layers in network") parser.add_argument("--netwidth", type=int, default=256, help="channels per layer") parser.add_argument( "--N_rand", type=int, default=2048, help="batch size (number of random rays per gradient step)", ) parser.add_argument("--lrate", type=float, default=5e-4, help="learning rate") parser.add_argument( "--lrate_decay", type=int, default=500, help="exponential learning rate decay (in 1000 steps)", ) parser.add_argument( "--no_batching", action="store_true", help="only take random rays from 1 image at a time", ) parser.add_argument( "--ft_path", type=str, default=None, help="specific weights npy file to reload for coarse network", ) # rendering options parser.add_argument( "--multires", type=int, default=10, help="log2 of max freq for positional encoding (3D location)", ) parser.add_argument( "--multires_views", type=int, default=4, help="log2 of max 
freq for positional encoding (2D direction)", ) parser.add_argument( "--i_embed", type=int, default=-1, help="set 1 for hashed embedding, 0 for default positional encoding, 2 for spherical", ) parser.add_argument( "--i_embed_views", type=int, default=-1, help="set 1 for hashed embedding, 0 for default positional encoding, 2 for spherical", ) parser.add_argument( "--render_test", action="store_true", help="render the test set instead of render_poses path", ) # logging/saving options parser.add_argument( "--i_print", type=int, default=100, help="frequency of console printout and metric loggin", ) parser.add_argument( "--i_weights", type=int, default=10000, help="frequency of weight ckpt saving" ) parser.add_argument( "--i_save", type=int, default=1000, help="frequency of rays saving" ) # lidar nerf parser.add_argument("--N_iters", type=int, default=500000) parser.add_argument("--H", type=int, default=66) parser.add_argument("--W", type=int, default=1030) # lr parser.add_argument("--cosLR", action="store_true") parser.add_argument( "--coslrate", type=float, default=5e-4, help="init learning rate" ) parser.add_argument( "--cosminlrate", type=float, default=5e-5, help="min learning rate" ) parser.add_argument("--warmup_iters", type=int, default=1000) # loss type parser.add_argument( "--rgb_loss_type", type=str, default="img2mse", help="options: img2mse / mseloss / l1loss", ) return parser def cosine_scheduler( base_value, final_value, globel_step, warmup_iters=0, start_warmup_value=0 ): warmup_schedule = np.array([]) if warmup_iters > 0: warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) iters = np.arange(globel_step - warmup_iters) schedule = final_value + 0.5 * (base_value - final_value) * ( 1 + np.cos(np.pi * iters / len(iters)) ) schedule = np.concatenate((warmup_schedule, schedule)) assert len(schedule) == globel_step return schedule def get_embedder(multires, input_dims=3, i=0): if i == -1: return nn.Identity(), input_dims elif i == 0: 
embed_kwargs = { "include_input": True, "input_dims": input_dims, "max_freq_log2": multires - 1, "num_freqs": multires, "log_sampling": True, "periodic_fns": [torch.sin, torch.cos], } embedder_obj = Embedder(**embed_kwargs) embed = lambda x, eo=embedder_obj: eo.embed(x) out_dim = embedder_obj.out_dim return embed, out_dim class Embedder: def __init__(self, **kwargs): self.kwargs = kwargs self.create_embedding_fn() def create_embedding_fn(self): embed_fns = [] d = self.kwargs["input_dims"] out_dim = 0 if self.kwargs["include_input"]: embed_fns.append(lambda x: x) # embed_fns.append(lambda x : x/torch.norm(x, dim=-1, keepdim=True)) out_dim += d max_freq = self.kwargs["max_freq_log2"] N_freqs = self.kwargs["num_freqs"] if self.kwargs["log_sampling"]: freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs) else: freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs) for freq in freq_bands: for p_fn in self.kwargs["periodic_fns"]: embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) out_dim += d self.embed_fns = embed_fns self.out_dim = out_dim def embed(self, inputs): return torch.cat([fn(inputs) for fn in self.embed_fns], -1) def run_network(inputs, model, embed_fn, embeddirs_fn): """Prepares inputs and applies network 'fn'.""" ray_direction, depth, intensity = inputs[:, :3], inputs[:, 3], inputs[:, 4] embedded_depth = embed_fn(depth.unsqueeze(1)) embedded_intensity = embed_fn(intensity.unsqueeze(1)) embedded_dirs = embeddirs_fn(ray_direction) input = torch.cat((embedded_dirs, embedded_depth, embedded_intensity), 1) outputs = model(input) return outputs def load_pkl_data(data_dir, split): if not data_dir.is_dir(): raise ValueError(f"Directory {data_dir} does not exist.") if split not in ["train", "test"]: raise ValueError(f"Split {split} not supported.") pkl_path = data_dir / f"{split}_data.pkl" if not pkl_path.is_file(): raise ValueError(f"File {pkl_path} does not exist.") with open(pkl_path, "rb") as f: raydrop_data = pickle.load(f) 
return raydrop_data def train(): parser = config_parser() args = parser.parse_args() cosLR = args.cosLR loss_dict = {"img2mse": img2mse, "mseloss": mseloss, "l1loss": l1loss} H = args.H W = args.W # load dataset data_dir = Path(args.datadir) (directions, panos, intensities, raydrop_masks) = load_pkl_data(data_dir, "train") # print( # np.array(directions).shape, # np.array(panos).shape, # np.array(intensities).shape, # np.array(raydrop_masks).shape) rays_all = np.concatenate( ( np.array(directions).reshape(-1, 3), np.array(panos).reshape(-1, 1), np.array(intensities).reshape(-1, 1), ), -1, ) raydrop_masks = np.array(raydrop_masks) rays_all = rays_all[raydrop_masks.reshape(-1) > -1] raydrop_masks = np.where(raydrop_masks[raydrop_masks > -1] == 0.0, 0.0, 1.0) rays_all = np.concatenate((rays_all, raydrop_masks.reshape(-1, 1)), -1) (directions, panos, intensities, raydrop_masks) = load_pkl_data(data_dir, "test") raydrop_val_list = [] for direction, pano, intensity, raydrop_mask in zip( directions, panos, intensities, raydrop_masks ): raydrop_val_list.append( np.concatenate( ( np.array(direction).reshape(-1, 3), np.array(pano).reshape(-1, 1), np.array(intensity).reshape(-1, 1), np.array(raydrop_mask).reshape(-1, 1), ), -1, ) ) rays_val1 = raydrop_val_list[0] raydrop_masks = np.array(raydrop_masks) mask_val1 = np.where(raydrop_masks[0] > -1, 1, 0) ray_drop_val1 = raydrop_masks[0].reshape(H, W) # Create log dir and copy the config file basedir = args.basedir expname = args.expname os.makedirs(os.path.join(basedir, expname), exist_ok=True) f = os.path.join(basedir, expname, "args.txt") with open(f, "w") as file: for arg in sorted(vars(args)): attr = getattr(args, arg) file.write("{} = {}\n".format(arg, attr)) if args.config is not None: f = os.path.join(basedir, expname, "config.txt") with open(f, "w") as file: file.write(open(args.config, "r").read()) # network embed_fn, input_ch = get_embedder(args.multires, input_dims=1, i=args.i_embed) embeddirs_fn, input_ch_views = 
get_embedder( args.multires_views, input_dims=3, i=args.i_embed_views ) total_input_ch = input_ch * 2 + input_ch_views # model model = RayDrop(D=args.netdepth, W=args.netwidth, input_ch=total_input_ch).to( device ) grad_vars = list(model.parameters()) # optimizer optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) start = 0 if args.ft_path is not None and args.ft_path != "None": ckpts = [args.ft_path] else: ckpts = [ os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if f.endswith(".tar") ] print("Found ckpts", ckpts) if len(ckpts) > 0 and not args.no_reload: ckpt_path = ckpts[-1] print("Reloading from", ckpt_path) ckpt = torch.load(ckpt_path) start = ckpt["global_step"] optimizer.load_state_dict(ckpt["optimizer_state_dict"]) # Load model model.load_state_dict(ckpt["network_fn_state_dict"]) global_step = start # Prepare raybatch tensor if batching random rays N_rand = args.N_rand use_batching = not args.no_batching if use_batching: print("shuffle rays") np.random.shuffle(rays_all) rays_all = torch.Tensor(rays_all).to(device) rays_val1 = torch.Tensor(rays_val1).to(device) if args.render_test: print("RENDER ONLY") for idx, rays_val, raydrop_mask in zip( range(len(raydrop_val_list)), raydrop_val_list, raydrop_masks ): rays_val = torch.Tensor(rays_val).to(device) with torch.no_grad(): predict_drop_val = run_network(rays_val, model, embed_fn, embeddirs_fn) imgbase = os.path.join(basedir, expname, str(idx)) mask_bbox = np.where(raydrop_mask > -1, 1, 0) predict_drop_val = ( np.where(predict_drop_val.cpu().numpy() > 0.5, 1.0, 0.0).reshape(H, W) * mask_bbox ) np.save(imgbase + "_pred_drop.npy", predict_drop_val) imageio.imsave(imgbase + "_pred_drop.png", predict_drop_val.reshape(H, W)) ray_drop_gt = np.where(raydrop_mask > 0, 1, 0) imageio.imsave(imgbase + "_gt_drop.png", ray_drop_gt.reshape(H, W)) return N_iters = args.N_iters + 1 print("Begin") loss_log = [] val_psnr = [] start = start + 1 i_batch = 
0 lr_schedule = cosine_scheduler( base_value=args.coslrate, final_value=args.cosminlrate, globel_step=N_iters - 1, warmup_iters=args.warmup_iters, ) for i in range(start, N_iters): # Sample random ray batch if use_batching: # Random over all images batch = rays_all[i_batch : i_batch + N_rand] # [B, 2+1, 3*?] # ray_direction, depth, intensity, target_drop = batch[:, :3], batch[:, 3], batch[:, 4], batch[:, 5] inputs, target_drop = batch[:, :5], batch[:, 5] i_batch += N_rand if i_batch >= rays_all.shape[0]: print("Shuffle data after an epoch!") rand_idx = torch.randperm(rays_all.shape[0]) rays_all = rays_all[rand_idx] i_batch = 0 ##### Core optimization loop ##### predict_drop = run_network(inputs, model, embed_fn, embeddirs_fn) optimizer.zero_grad() rgb_loss = loss_dict[args.rgb_loss_type] loss = rgb_loss(predict_drop, target_drop.unsqueeze(1)) # loss = KL_loss_fun(predict_drop, target_drop.unsqueeze(1)) loss.backward() optimizer.step() # NOTE: IMPORTANT! ## update learning rate ### decay_rate = 0.1 decay_steps = args.lrate_decay * 1000 new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps)) for param_group in optimizer.param_groups: if cosLR: param_group["lr"] = lr_schedule[global_step] else: param_group["lr"] = new_lrate # Rest is logging if i % args.i_weights == 0: path = os.path.join(basedir, expname, "{:06d}.tar".format(i)) ckpt = { "global_step": global_step, "network_fn_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), } torch.save(ckpt, path) print("Saved checkpoints at", path) if i % args.i_save == 0 and i > 0: # Turn in testing mode with torch.no_grad(): predict_drop_val = run_network(rays_val1, model, embed_fn, embeddirs_fn) imgbase = os.path.join(basedir, expname, "{:06d}_".format(i)) predict_drop_val = ( np.where(predict_drop_val.cpu().numpy() > 0.5, 1.0, 0.0).reshape(H, W) * mask_val1 ) psnr = cal_psnr(predict_drop_val.reshape(H, W), ray_drop_val1) print(psnr) val_psnr.append(psnr) loss_save = 
np.array(val_psnr) plt.plot(loss_save) plt.savefig(os.path.join(basedir, expname, "val_psnr.png")) plt.close() imageio.imsave(imgbase + "val_drop.png", predict_drop_val.reshape(H, W)) loss_log.append(loss.item()) if i % args.i_print == 0: loss_save = np.array(loss_log) plt.plot(loss_save) plt.savefig(os.path.join(basedir, expname, "loss_curve.png")) plt.close() loss_print = [loss.item()] print(f"[TRAIN] Iter: {i} Loss: {loss_print} ") global_step += 1 loss_log = np.array(loss_log) np.save(os.path.join(basedir, expname, "loss_log.npy"), loss_log) val_psnr = np.array(val_psnr) np.save(os.path.join(basedir, expname, "val_psnr.npy"), val_psnr) if __name__ == "__main__": torch.set_default_tensor_type("torch.cuda.FloatTensor") train() ================================================ FILE: lidarnvs/raydrop_train_poisson.py ================================================ import argparse import logging from pathlib import Path import torch import torch.nn as nn import torch.nn.functional as F from torch import optim from torch.utils.data import DataLoader from tqdm import tqdm import wandb from lidarnvs.raydrop_dataset_poisson import RaydropDataset from lidarnvs.unet import UNet, dice_coeff, dice_loss, multiclass_dice_coeff @torch.inference_mode() def evaluate(net, dataloader, device, amp): net.eval() num_val_batches = len(dataloader) dice_score = 0 # Iterate over the test set with torch.autocast(device.type if device.type != "mps" else "cpu", enabled=amp): for batch in tqdm( dataloader, total=num_val_batches, desc="Test round", unit="batch", leave=False, ): images, true_masks = batch # Move images and labels to correct device and type images = images.to( device=device, dtype=torch.float32, memory_format=torch.channels_last ) true_masks = true_masks.to(device=device, dtype=torch.long) # Predict the mask mask_pred = net(images) true_masks = true_masks.reshape(mask_pred.shape) if net.n_classes == 1: assert ( true_masks.min() >= 0 and true_masks.max() <= 1 ), "True mask 
indices should be in [0, 1]" mask_pred = (F.sigmoid(mask_pred) > 0.5).float() # Compute the dice score dice_score += dice_coeff( mask_pred, true_masks, reduce_batch_first=False ) else: assert ( true_masks.min() >= 0 and true_masks.max() < net.n_classes ), "True mask indices should be in [0, n_classes[" # Convert to one-hot format true_masks = ( F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float() ) mask_pred = ( F.one_hot(mask_pred.argmax(dim=1), net.n_classes) .permute(0, 3, 1, 2) .float() ) # Compute the Dice score, ignoring background dice_score += multiclass_dice_coeff( mask_pred[:, 1:], true_masks[:, 1:], reduce_batch_first=False ) net.train() return dice_score / max(num_val_batches, 1) def train_model( model, data_dir, ckpt_dir, device, epochs: int = 5, batch_size: int = 1, learning_rate: float = 1e-5, save_checkpoint: bool = True, img_scale: float = 0.5, amp: bool = False, weight_decay: float = 1e-8, momentum: float = 0.999, gradient_clipping: float = 1.0, ): data_dir = Path(data_dir) ckpt_dir = Path(ckpt_dir) # Create dataset train_dataset = RaydropDataset(data_dir=data_dir, split="train") test_dataset = RaydropDataset(data_dir=data_dir, split="test") n_train = len(train_dataset) n_test = len(test_dataset) # Create data loaders train_loader = DataLoader( train_dataset, batch_size=batch_size, collate_fn=RaydropDataset.collate_fn, shuffle=True, ) test_loader = DataLoader( test_dataset, batch_size=batch_size, collate_fn=RaydropDataset.collate_fn, shuffle=True, ) # Initialize logging experiment = wandb.init(project="U-Net", resume="allow", anonymous="must") experiment.config.update( { "epochs": epochs, "batch_size": batch_size, "learning_rate": learning_rate, "save_checkpoint": save_checkpoint, "img_scale": img_scale, "amp": amp, } ) log_str = ( f"Starting training:\n" f"Epochs: {epochs}\n" f"Batch size: {batch_size}\n" f"Learning rate: {learning_rate}\n" f"Training size: {n_train}\n" f"Validation size: {n_test}\n" f"Checkpoints: 
{save_checkpoint}\n" f"Device: {device.type}\n" f"Images scaling: {img_scale}\n" f"Mixed Precision: {amp}\n" ) logging.info(log_str) # Set up optimizer, loss, lr_scheduler, loss scaling. optimizer = optim.RMSprop( model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True, ) scheduler = optim.lr_scheduler.ReduceLROnPlateau( optimizer, "max", patience=5 ) # goal: maximize Dice score grad_scaler = torch.cuda.amp.GradScaler(enabled=amp) criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss() global_step = 0 # 5. Begin training for epoch in range(1, epochs + 1): model.train() epoch_loss = 0 with tqdm(total=n_train, desc=f"Epoch {epoch}/{epochs}", unit="img") as pbar: for batch in train_loader: images, true_masks = batch if images.shape[1] != model.n_channels: raise ValueError( f"Input channel mismatch: " f"{images.shape[1]} vs {model.n_channels}" ) images = images.to( device=device, dtype=torch.float32, memory_format=torch.channels_last, ) true_masks = true_masks.to(device=device, dtype=torch.long) with torch.autocast( device.type if device.type != "mps" else "cpu", enabled=amp ): masks_pred = model(images) if model.n_classes == 1: loss = criterion(masks_pred.squeeze(1), true_masks.float()) loss += dice_loss( F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False, ) else: loss = criterion(masks_pred, true_masks) loss += dice_loss( F.softmax(masks_pred, dim=1).float(), F.one_hot(true_masks, model.n_classes) .permute(0, 3, 1, 2) .float(), multiclass=True, ) optimizer.zero_grad(set_to_none=True) grad_scaler.scale(loss).backward() torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping) grad_scaler.step(optimizer) grad_scaler.update() pbar.update(images.shape[0]) global_step += 1 epoch_loss += loss.item() experiment.log( {"train loss": loss.item(), "step": global_step, "epoch": epoch} ) pbar.set_postfix(**{"loss (batch)": loss.item()}) # Evaluation round division_step = n_train 
// (5 * batch_size) if division_step > 0: if global_step % division_step == 0: histograms = {} for tag, value in model.named_parameters(): tag = tag.replace("/", ".") if not (torch.isinf(value) | torch.isnan(value)).any(): histograms["Weights/" + tag] = wandb.Histogram( value.data.cpu() ) if not ( torch.isinf(value.grad) | torch.isnan(value.grad) ).any(): histograms["Gradients/" + tag] = wandb.Histogram( value.grad.data.cpu() ) val_score = evaluate(model, test_loader, device, amp) scheduler.step(val_score) logging.info("Validation Dice score: {}".format(val_score)) try: experiment.log( { "learning rate": optimizer.param_groups[0]["lr"], "validation Dice": val_score, "images": wandb.Image(images[0].cpu()), "masks": { "true": wandb.Image( true_masks[0].float().cpu() ), "pred": wandb.Image( masks_pred.argmax(dim=1)[0].float().cpu() ), }, "step": global_step, "epoch": epoch, **histograms, } ) except: pass if save_checkpoint: checkpoint_path = ckpt_dir / f"checkpoint_{epoch:03}.pth" checkpoint_path.parent.mkdir(parents=True, exist_ok=True) state_dict = model.state_dict() torch.save(state_dict, checkpoint_path) logging.info(f"Checkpoint {epoch} saved!") def get_args(): parser = argparse.ArgumentParser( description="Train the UNet on images and target masks" ) parser.add_argument( "--data_dir", type=str, default="N/A", help="Path to the raydrop dataset." ) parser.add_argument( "--ckpt_dir", type=str, default="N/A", help="Path to the checkpoint directory." 
) parser.add_argument("--epochs", "-e", type=int, default=10, help="Number of epochs") parser.add_argument( "--batch-size", "-b", dest="batch_size", type=int, default=2, help="Batch size" ) parser.add_argument( "--learning-rate", "-l", type=float, default=1e-5, help="Learning rate", dest="lr", ) parser.add_argument( "--load", "-f", type=str, default=False, help="Load model from a .pth file" ) parser.add_argument( "--scale", "-s", type=float, default=0.5, help="Downscaling factor of the images", ) parser.add_argument( "--amp", action="store_true", default=False, help="Use mixed precision" ) parser.add_argument( "--bilinear", action="store_true", default=False, help="Use bilinear upsampling" ) parser.add_argument( "--classes", "-c", type=int, default=1, help="Number of classes" ) return parser.parse_args() def main(): args = get_args() logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logging.info(f"Using device {device}") # - n_channels: 10 # - [0] : hit_masks # - [1] : hit_depths # - [2:5] : hit_normals (world), TODO: change to local coord # - [5] : hit_incidences in cosine # - [6] : intensities # - [7:10]: rays_d, TODO: change to local coord # - n_classes: 1 # - number of probabilities you want to get per pixel # - raydrop_masks has only 1 channel model = UNet(n_channels=10, n_classes=args.classes, bilinear=args.bilinear) model = model.to(memory_format=torch.channels_last) logging.info( f"Network:\n" f"\t{model.n_channels} input channels\n" f"\t{model.n_classes} output channels (classes)\n" f"\t{'Bilinear' if model.bilinear else 'Transposed conv'} upscaling" ) if args.load: state_dict = torch.load(args.load, map_location=device) model.load_state_dict(state_dict) logging.info(f"Model loaded from {args.load}") model.to(device=device) train_model( model=model, data_dir=args.data_dir, ckpt_dir=args.ckpt_dir, epochs=args.epochs, batch_size=args.batch_size, 
learning_rate=args.lr, device=device, img_scale=args.scale, amp=args.amp, ) if __name__ == "__main__": main() ================================================ FILE: lidarnvs/readme.md ================================================ # Lidar Novel View Synthesis Baselines ![baseline_render](../assets/baseline-render.png) ## LidarSim ```bash # Generate raydrop dataset. python lidarnvs/run.py --dataset kitti360 --sequence_id "1538" --enable_collect_raydrop_dataset python lidarnvs/run.py --dataset kitti360 --sequence_id "1728" --enable_collect_raydrop_dataset python lidarnvs/run.py --dataset kitti360 --sequence_id "1908" --enable_collect_raydrop_dataset python lidarnvs/run.py --dataset kitti360 --sequence_id "3353" --enable_collect_raydrop_dataset # Train the raydrop model. python raydrop_train.py --data_dir data/raydrop/kitti360_1538 --ckpt_dir log/raydrop/kitti360_1538 python raydrop_train.py --data_dir data/raydrop/kitti360_1728 --ckpt_dir log/raydrop/kitti360_1728 python raydrop_train.py --data_dir data/raydrop/kitti360_1908 --ckpt_dir log/raydrop/kitti360_1908 python raydrop_train.py --data_dir data/raydrop/kitti360_3353 --ckpt_dir log/raydrop/kitti360_3353 # Run lidarnvs again, now with raydrop model. 
python lidarnvs/run.py --dataset kitti360 --sequence_id "1538" python lidarnvs/run.py --dataset kitti360 --sequence_id "1728" python lidarnvs/run.py --dataset kitti360 --sequence_id "1908" python lidarnvs/run.py --dataset kitti360 --sequence_id "3353" # lidarnvs on NeRF-MVL python lidarnvs/run.py --dataset nerf_mvl --sequence_id "bollard" python lidarnvs/run.py --dataset nerf_mvl --sequence_id "car" python lidarnvs/run.py --dataset nerf_mvl --sequence_id "pedestrian" python lidarnvs/run.py --dataset nerf_mvl --sequence_id "pier" python lidarnvs/run.py --dataset nerf_mvl --sequence_id "plant" python lidarnvs/run.py --dataset nerf_mvl --sequence_id "tire" python lidarnvs/run.py --dataset nerf_mvl --sequence_id "traffic_cone" python lidarnvs/run.py --dataset nerf_mvl --sequence_id "warning_sign" python lidarnvs/run.py --dataset nerf_mvl --sequence_id "water_safety_barrier" ``` ## PCGen ### KITTI-360 ```bash python lidarnvs/run.py --dataset kitti360 --sequence_id "1908" --method "pcgen" --enable_collect_raydrop_dataset python lidarnvs/raydrop_train_pcgen.py --config lidarnvs/configs/pcgen_kitti360_raydrop.txt python lidarnvs/run.py --dataset kitti360 --sequence_id "1908" --method "pcgen" --ckpt_path pcgen_raydrop_log/kitti360seq1908/raysdrop/010000.tar ``` ### LiDAR-MVL ```bash python lidarnvs/run.py --dataset nerf_mvl --sequence_id "car" --method "pcgen" --enable_collect_raydrop_dataset python lidarnvs/raydrop_train_pcgen.py --config lidarnvs/configs/pcgen_nerfmvl_raydrop.txt python lidarnvs/run.py --dataset nerf_mvl --sequence_id "car" --method "pcgen" --ckpt_path pcgen_raydrop_log/car/010000.tar ``` ================================================ FILE: lidarnvs/run.py ================================================ from pathlib import Path import numpy as np from lidarnvs.lidarnvs_pcgen import LidarNVSPCGen, generate_raydrop_data_pcgen from lidarnvs.lidarnvs_poisson import LidarNVSPoisson from lidarnvs.lidarnvs_nksr import LidarNVSNksr from 
lidarnvs.lidarnvs_meshing import generate_raydrop_data_meshing from lidarnvs.loader import extract_dataset_frame from lidarnvs.eval import eval_points_and_pano from tqdm import tqdm import pickle import argparse from lidarnerf.dataset.kitti360_dataset import KITTI360Dataset from lidarnerf.dataset.nerfmvl_dataset import NeRFMVLDataset def main(): parser = argparse.ArgumentParser() parser.add_argument( "--dataset", type=str, default="kitti360", choices=["kitti360", "nerf_mvl"], help="The dataset loader to use.", ) parser.add_argument( "--method", type=str, default="poisson", choices=["poisson", "nksr", "pcgen"], help="method for lidarnvs", ) parser.add_argument( "--raycasting", type=str, default="cp", choices=["cp", "fpa"], help="raycasting mehtod", ) # dataset parser.add_argument("--path", type=str, default="data/kitti360") parser.add_argument( "--sequence_id", type=str, default="1908", help="The sequence id within the selected dataset to use.", ) parser.add_argument( "--num_rays_lidar", type=int, default=4096, help="num rays sampled per image for each training step", ) parser.add_argument( "--offset", type=float, nargs="*", default=[0, 0, 0], help="offset of location" ) parser.add_argument( "--enable_collect_raydrop_dataset", action="store_true", help="Whether to collect raydrop dataset. If not enabled, inference " "mode will be used", ) parser.add_argument( "--ckpt_path", type=str, default="", help="The ckpt of raydrop network.", ) parser.add_argument( "--poisson_depth", type=int, default=11, help="Depth of tree for Poisson recon.", ) parser.add_argument( "--poisson_min_density", type=float, default=0.3, help="Minimum density to filter points after Poisson recon.", ) args = parser.parse_args() # Check sequence id. 
kitti360_sequence_ids = [ "1538", "1728", "1908", "3353", ] nerf_mvl_sequence_ids = [ "bollard", "car", "pedestrian", "pier", "plant", "tire", "traffic_cone", "warning_sign", "water_safety_barrier", ] if args.dataset == "kitti360": if args.sequence_id not in kitti360_sequence_ids: raise ValueError( f"Unknown sequence id {args.sequence_id} for {args.dataset}" ) elif args.dataset == "nerf_mvl": if args.sequence_id not in nerf_mvl_sequence_ids: raise ValueError( f"Unknown sequence id {args.sequence_id} for {args.dataset}" ) else: raise ValueError(f"Unknown dataset: {args.dataset}") print("[Config]===============================================") print(f"dataset : {args.dataset}") print(f"sequence_id : {args.sequence_id}") print(f"poisson_depth : {args.poisson_depth}") print(f"poisson_min_density : {args.poisson_min_density}") print(f"dataset_collect_mode: {args.enable_collect_raydrop_dataset}") print("=======================================================") # Init train and test datasets. if args.dataset == "kitti360": train_dataset = KITTI360Dataset( split="train", root_path=args.path, offset=args.offset, num_rays_lidar=args.num_rays_lidar, sequence_id=args.sequence_id, ) test_dataset = KITTI360Dataset( split="train", root_path=args.path, offset=args.offset, num_rays_lidar=args.num_rays_lidar, sequence_id=args.sequence_id, ) elif args.dataset == "nerf_mvl": train_dataset = NeRFMVLDataset( split="train", root_path=args.path, offset=args.offset, num_rays_lidar=args.num_rays_lidar, sequence_id=args.sequence_id, ) test_dataset = NeRFMVLDataset( split="test", root_path=args.path, offset=args.offset, num_rays_lidar=args.num_rays_lidar, sequence_id=args.sequence_id, ) else: raise ValueError(f"Unknown dataset: {args.dataset}") # Train. 
if args.enable_collect_raydrop_dataset: ckpt_path = None else: ckpt_path = Path(args.ckpt_path) if not ckpt_path.is_file(): raise ValueError(f"ckpt_path ({ckpt_path}) does not exist.") if args.method == "poisson": nvs = LidarNVSPoisson( poisson_depth=args.poisson_depth, poisson_min_density=args.poisson_min_density, intensity_interpolate_k=9, ckpt_path=ckpt_path, ) elif args.method == "nksr": nvs = LidarNVSNksr(ckpt_path=ckpt_path) elif args.method == "pcgen": nvs = LidarNVSPCGen( raycasting=args.raycasting, ckpt_path=ckpt_path, ) else: raise ValueError(f"Unknown method: {args.method}") nvs.fit(train_dataset) exit(0) # Eval test frames. all_metrics = [] for frame_idx in tqdm(range(len(test_dataset)), desc="Eval test frames"): gt_frame = extract_dataset_frame(test_dataset, frame_idx=frame_idx) if args.enable_collect_raydrop_dataset: inference_func = nvs.predict_frame else: inference_func = nvs.predict_frame_with_raydrop pd_frame = inference_func( lidar_K=gt_frame["lidar_K"], lidar_pose=gt_frame["lidar_pose"], lidar_H=gt_frame["lidar_H"], lidar_W=gt_frame["lidar_W"], ) if args.dataset == "nerf_mvl": # Load values to be updated. gt_intensities = gt_frame["intensities"] pd_intensities = pd_frame["intensities"] gt_pano = gt_frame["pano"] pd_pano = pd_frame["pano"] # Load mask. 
pano_mask = gt_frame["pano_mask"] nonzero_idx = np.array(np.nonzero(pano_mask)) new_h = max(nonzero_idx[0]) - min(nonzero_idx[0]) + 1 new_w = max(nonzero_idx[1]) - min(nonzero_idx[1]) + 1 gt_intensities = gt_intensities[pano_mask].reshape(new_h, new_w) pd_intensities = pd_intensities[pano_mask].reshape(new_h, new_w) gt_intensities = gt_intensities * 255 pd_intensities = pd_intensities * 255 gt_pano = gt_pano[pano_mask].reshape(new_h, new_w) pd_pano = pd_pano[pano_mask].reshape(new_h, new_w) metrics = eval_points_and_pano( gt_local_points=gt_frame["local_points"], pd_local_points=pd_frame["local_points"], gt_intensities=gt_intensities, pd_intensities=pd_intensities, gt_pano=gt_pano, pd_pano=pd_pano, ) else: metrics = eval_points_and_pano( gt_local_points=gt_frame["local_points"], pd_local_points=pd_frame["local_points"], gt_intensities=gt_frame["intensities"], pd_intensities=pd_frame["intensities"], gt_pano=gt_frame["pano"], pd_pano=pd_frame["pano"], ) all_metrics.append(metrics) # Compute mean metrics. mean_metrics = {} for key in all_metrics[0].keys(): mean_metrics[key] = np.mean([m[key] for m in all_metrics]) print("Mean metrics:") for key in sorted(mean_metrics.keys()): print(f"- {key}: {mean_metrics[key]:.4f}") # # Visualize a single test frame. # gt_pcd = o3d.geometry.PointCloud() # gt_pcd.points = o3d.utility.Vector3dVector(gt_frame["points"]) # gt_pcd.colors = o3d.utility.Vector3dVector( # ct.colormap.query(gt_frame["point_intensities"])) # pd_pcd = o3d.geometry.PointCloud() # pd_pcd.points = o3d.utility.Vector3dVector(train_frame_nvs["points"]) # pd_pcd.colors = o3d.utility.Vector3dVector( # ct.colormap.query(train_frame_nvs["point_intensities"])) # o3d.visualization.draw_geometries([gt_pcd]) # o3d.visualization.draw_geometries([pd_pcd]) # Concat all in to big tensors. 
if args.enable_collect_raydrop_dataset: if args.method == "poisson" and args.dataset != "nerf_mvl": raydrop_train_data = generate_raydrop_data_meshing(train_dataset, nvs) raydrop_test_data = generate_raydrop_data_meshing(test_dataset, nvs) elif args.method == "pcgen": raydrop_train_data = generate_raydrop_data_pcgen( train_dataset, nvs, rm_pano_mask=False ) raydrop_test_data = generate_raydrop_data_pcgen( test_dataset, nvs, rm_pano_mask=False ) else: raise ValueError(f"Unknown method/dataset: {args.method}/{args.dataset}") data_dir = ( Path("data/raydrop") / args.method / f"{args.dataset}_{args.sequence_id}" ) data_dir.mkdir(parents=True, exist_ok=True) train_data_path = data_dir / "train_data.pkl" test_data_path = data_dir / "test_data.pkl" with open(train_data_path, "wb") as f: pickle.dump(raydrop_train_data, f) with open(test_data_path, "wb") as f: pickle.dump(raydrop_test_data, f) if __name__ == "__main__": main() ================================================ FILE: lidarnvs/unet.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor class DoubleConv(nn.Module): """(convolution => [BN] => ReLU) * 2""" def __init__(self, in_channels, out_channels, mid_channels=None): super().__init__() if not mid_channels: mid_channels = out_channels self.double_conv = nn.Sequential( nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(mid_channels), nn.ReLU(inplace=True), nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), ) def forward(self, x): return self.double_conv(x) class Down(nn.Module): """Downscaling with maxpool then double conv""" def __init__(self, in_channels, out_channels): super().__init__() self.maxpool_conv = nn.Sequential( nn.MaxPool2d(2), DoubleConv(in_channels, out_channels) ) def forward(self, x): return self.maxpool_conv(x) class Up(nn.Module): 
"""Upscaling then double conv""" def __init__(self, in_channels, out_channels, bilinear=True): super().__init__() # if bilinear, use the normal convolutions to reduce the number of channels if bilinear: self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) else: self.up = nn.ConvTranspose2d( in_channels, in_channels // 2, kernel_size=2, stride=2 ) self.conv = DoubleConv(in_channels, out_channels) def forward(self, x1, x2): x1 = self.up(x1) # input is CHW diffY = x2.size()[2] - x1.size()[2] diffX = x2.size()[3] - x1.size()[3] x1 = F.pad( x1, [ diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2, ], ) # if you have padding issues, see # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd x = torch.cat([x2, x1], dim=1) return self.conv(x) class OutConv(nn.Module): def __init__(self, in_channels, out_channels): super(OutConv, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) def forward(self, x): return self.conv(x) class UNet(nn.Module): def __init__(self, n_channels, n_classes, bilinear=False): super(UNet, self).__init__() self.n_channels = n_channels self.n_classes = n_classes self.bilinear = bilinear self.inc = DoubleConv(n_channels, 64) self.down1 = Down(64, 128) self.down2 = Down(128, 256) self.down3 = Down(256, 512) factor = 2 if bilinear else 1 self.down4 = Down(512, 1024 // factor) self.up1 = Up(1024, 512 // factor, bilinear) self.up2 = Up(512, 256 // factor, bilinear) self.up3 = Up(256, 128 // factor, bilinear) self.up4 = Up(128, 64, bilinear) self.outc = OutConv(64, n_classes) def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x = self.up1(x5, x4) x = self.up2(x, x3) x = self.up3(x, x2) x = self.up4(x, x1) 
logits = self.outc(x) return logits def dice_coeff( input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6, ): # Average of Dice coefficient for all batches, or for a single mask assert input.size() == target.size() assert input.dim() == 3 or not reduce_batch_first if input.dim() == 2 or not reduce_batch_first: sum_dim = (-1, -2) else: sum_dim = (-1, -2, -3) inter = 2 * (input * target).sum(dim=sum_dim) sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim) sets_sum = torch.where(sets_sum == 0, inter, sets_sum) dice = (inter + epsilon) / (sets_sum + epsilon) return dice.mean() def multiclass_dice_coeff( input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6, ): # Average of Dice coefficient for all classes return dice_coeff( input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon ) def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False): # Dice loss (objective to minimize) between 0 and 1 fn = multiclass_dice_coeff if multiclass else dice_coeff return 1 - fn(input, target, reduce_batch_first=True) def main(): pass if __name__ == "__main__": main() ================================================ FILE: main_lidarnerf.py ================================================ import torch import configargparse import os import numpy as np from lidarnerf.nerf.utils import ( seed_everything, RMSEMeter, MAEMeter, DepthMeter, PointsMeter, Trainer, ) def get_arg_parser(): parser = configargparse.ArgumentParser() parser.add_argument( "--config", is_config_file=True, default="configs/kitti360_1908.txt", help="config file path", ) parser.add_argument("--path", type=str, default="data/kitti360") parser.add_argument( "-L", action="store_true", help="equals --fp16 --tcnn --preload" ) parser.add_argument("--test", action="store_true", help="test mode") parser.add_argument("--test_eval", action="store_true", help="test and eval mode") parser.add_argument("--workspace", type=str, 
default="workspace") parser.add_argument( "--cluster_summary_path", type=str, default="/summary", help="Overwrite default summary path if on cluster", ) parser.add_argument("--seed", type=int, default=0) parser.add_argument( "--dataloader", type=str, choices=("kitti360", "nerf_mvl"), default="kitti360" ) parser.add_argument("--sequence_id", type=str, default="1908") ### lidar-nerf parser.add_argument("--enable_lidar", action="store_true", help="Enable lidar.") parser.add_argument("--alpha_d", type=float, default=1e3) parser.add_argument("--alpha_r", type=float, default=1) parser.add_argument("--alpha_i", type=float, default=1) parser.add_argument("--alpha_grad_norm", type=float, default=1) parser.add_argument("--alpha_spatial", type=float, default=0.1) parser.add_argument("--alpha_tv", type=float, default=1) parser.add_argument("--alpha_grad", type=float, default=1e2) parser.add_argument("--intensity_inv_scale", type=float, default=1) parser.add_argument("--spatial_smooth", action="store_true") parser.add_argument("--grad_norm_smooth", action="store_true") parser.add_argument("--tv_loss", action="store_true") parser.add_argument("--grad_loss", action="store_true") parser.add_argument("--sobel_grad", action="store_true") parser.add_argument( "--desired_resolution", type=int, default=2048, help="TCN finest resolution at the smallest scale", ) parser.add_argument("--log2_hashmap_size", type=int, default=19) parser.add_argument("--n_features_per_level", type=int, default=2) parser.add_argument( "--num_layers", type=int, default=2, help="num_layers of sigmanet" ) parser.add_argument( "--hidden_dim", type=int, default=64, help="hidden_dim of sigmanet" ) parser.add_argument( "--geo_feat_dim", type=int, default=15, help="geo_feat_dim of sigmanet" ) parser.add_argument("--eval_interval", type=int, default=50) parser.add_argument( "--num_rays_lidar", type=int, default=4096, help="num rays sampled per image for each training step", ) parser.add_argument( "--min_near_lidar", 
type=float, default=0.01, help="minimum near distance for camera", ) parser.add_argument( "--depth_loss", type=str, default="l1", help="l1, bce, mse, huber" ) parser.add_argument( "--depth_grad_loss", type=str, default="l1", help="l1, bce, mse, huber" ) parser.add_argument( "--intensity_loss", type=str, default="mse", help="l1, bce, mse, huber" ) parser.add_argument( "--raydrop_loss", type=str, default="mse", help="l1, bce, mse, huber" ) parser.add_argument( "--patch_size_lidar", type=int, default=1, help="[experimental] render patches in training. " "1 means disabled, use [64, 32, 16] to enable", ) parser.add_argument( "--change_patch_size_lidar", nargs="+", type=int, default=[1, 1], help="[experimental] render patches in training. " "1 means disabled, use [64, 32, 16] to enable, change during training", ) parser.add_argument( "--change_patch_size_epoch", type=int, default=2, help="change patch_size intenvel", ) ### training options parser.add_argument( "--iters", type=int, default=30000, help="training iters", ) parser.add_argument("--lr", type=float, default=1e-2, help="initial learning rate") parser.add_argument("--ckpt", type=str, default="latest") parser.add_argument( "--num_rays", type=int, default=4096, help="num rays sampled per image for each training step", ) parser.add_argument( "--num_steps", type=int, default=768, help="num steps sampled per ray" ) parser.add_argument( "--upsample_steps", type=int, default=64, help="num steps up-sampled per ray" ) parser.add_argument( "--max_ray_batch", type=int, default=4096, help="batch size of rays at inference to avoid OOM)", ) parser.add_argument( "--patch_size", type=int, default=1, help="[experimental] render patches in training, so as to apply " "LPIPS loss. 
1 means disabled, use [64, 32, 16] to enable", ) ### network backbone options parser.add_argument( "--fp16", action="store_true", help="use amp mixed precision training" ) parser.add_argument("--tcnn", action="store_true", help="use TCNN backend") ### dataset options parser.add_argument( "--color_space", type=str, default="srgb", help="Color space, supports (linear, srgb)", ) parser.add_argument( "--preload", action="store_true", help="preload all data into GPU, accelerate training but use more GPU memory", ) # (the default value is for the fox dataset) parser.add_argument( "--bound", type=float, default=2, help="assume the scene is bounded in box[-bound, bound]^3, " "if > 1, will invoke adaptive ray marching.", ) parser.add_argument( "--scale", type=float, default=0.33, help="scale camera location into box[-bound, bound]^3", ) parser.add_argument( "--offset", type=float, nargs="*", default=[0, 0, 0], help="offset of camera location", ) parser.add_argument( "--dt_gamma", type=float, default=1 / 128, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 " "to accelerate rendering (but usually with worse quality)", ) parser.add_argument( "--min_near", type=float, default=0.2, help="minimum near distance for camera" ) parser.add_argument( "--density_thresh", type=float, default=10, help="threshold for density grid to be occupied", ) parser.add_argument( "--bg_radius", type=float, default=-1, help="if positive, use a background model at sphere(bg_radius)", ) return parser def main(): parser = get_arg_parser() opt = parser.parse_args() opt.enable_lidar = True # Check sequence id. 
kitti360_sequence_ids = [ "1538", "1728", "1908", "3353", ] nerf_mvl_sequence_ids = [ "bollard", "car", "pedestrian", "pier", "plant", "tire", "traffic_cone", "warning_sign", "water_safety_barrier", ] # Specify dataloader class if opt.dataloader == "kitti360": from lidarnerf.dataset.kitti360_dataset import KITTI360Dataset as NeRFDataset if opt.sequence_id not in kitti360_sequence_ids: raise ValueError( f"Unknown sequence id {opt.sequence_id} for {opt.dataloader}" ) elif opt.dataloader == "nerf_mvl": from lidarnerf.dataset.nerfmvl_dataset import NeRFMVLDataset as NeRFDataset if opt.sequence_id not in nerf_mvl_sequence_ids: raise ValueError( f"Unknown sequence id {opt.sequence_id} for {opt.dataloader}" ) else: raise RuntimeError("Should not reach here.") os.makedirs(opt.workspace, exist_ok=True) f = os.path.join(opt.workspace, "args.txt") with open(f, "w") as file: for arg in vars(opt): attr = getattr(opt, arg) file.write("{} = {}\n".format(arg, attr)) if opt.L: opt.fp16 = True opt.tcnn = True opt.preload = True if opt.patch_size > 1: # assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss." assert ( opt.num_rays % (opt.patch_size**2) == 0 ), "patch_size ** 2 should be dividable by num_rays." 
opt.min_near = opt.scale # hard-code, set min_near ori 1m opt.min_near_lidar = opt.scale if opt.tcnn: opt.fp16 = True assert opt.bg_radius <= 0, "background model is not implemented for --tcnn" from lidarnerf.nerf.network_tcnn import NeRFNetwork model = NeRFNetwork( encoding="hashgrid", desired_resolution=opt.desired_resolution, log2_hashmap_size=opt.log2_hashmap_size, n_features_per_level=opt.n_features_per_level, num_layers=opt.num_layers, hidden_dim=opt.hidden_dim, geo_feat_dim=opt.geo_feat_dim, bound=opt.bound, density_scale=1, min_near=opt.min_near, min_near_lidar=opt.min_near_lidar, density_thresh=opt.density_thresh, bg_radius=opt.bg_radius, ) else: from lidarnerf.nerf.network import NeRFNetwork model = NeRFNetwork( encoding="hashgrid", desired_resolution=opt.desired_resolution, log2_hashmap_size=opt.log2_hashmap_size, num_layers=opt.num_layers, hidden_dim=opt.hidden_dim, geo_feat_dim=opt.geo_feat_dim, bound=opt.bound, density_scale=1, min_near=opt.min_near, density_thresh=opt.density_thresh, bg_radius=opt.bg_radius, ) print(opt) seed_everything(opt.seed) print(model) loss_dict = { "mse": torch.nn.MSELoss(reduction="none"), "l1": torch.nn.L1Loss(reduction="none"), "bce": torch.nn.BCEWithLogitsLoss(reduction="none"), "huber": torch.nn.HuberLoss(reduction="none", delta=0.2 * opt.scale), "cos": torch.nn.CosineSimilarity(), } criterion = { "depth": loss_dict[opt.depth_loss], "raydrop": loss_dict[opt.raydrop_loss], "intensity": loss_dict[opt.intensity_loss], "grad": loss_dict[opt.depth_grad_loss], } device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if opt.test or opt.test_eval: test_loader = NeRFDataset( device=device, split="test", root_path=opt.path, sequence_id=opt.sequence_id, preload=opt.preload, scale=opt.scale, offset=opt.offset, fp16=opt.fp16, patch_size_lidar=opt.patch_size_lidar, enable_lidar=opt.enable_lidar, num_rays_lidar=opt.num_rays_lidar, ).dataloader() if opt.enable_lidar: depth_metrics = [ 
MAEMeter(intensity_inv_scale=opt.intensity_inv_scale), RMSEMeter(), DepthMeter(scale=opt.scale), PointsMeter( scale=opt.scale, intrinsics=test_loader._data.intrinsics_lidar ), ] else: depth_metrics = [] trainer = Trainer( "lidar_nerf", opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, depth_metrics=depth_metrics, use_checkpoint=opt.ckpt, ) if test_loader.has_gt and opt.test_eval: trainer.evaluate(test_loader) # blender has gt, so evaluate it. trainer.test(test_loader, write_video=False) # test and save video trainer.save_mesh(resolution=128, threshold=10) else: optimizer = lambda model: torch.optim.Adam( model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15 ) train_loader = NeRFDataset( device=device, split="train", root_path=opt.path, sequence_id=opt.sequence_id, preload=opt.preload, scale=opt.scale, offset=opt.offset, fp16=opt.fp16, patch_size_lidar=opt.patch_size_lidar, enable_lidar=opt.enable_lidar, num_rays_lidar=opt.num_rays_lidar, ).dataloader() # decay to 0.1 * init_lr at last iter step scheduler = lambda optimizer: torch.optim.lr_scheduler.LambdaLR( optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1) ) if opt.enable_lidar: depth_metrics = [ MAEMeter(intensity_inv_scale=opt.intensity_inv_scale), RMSEMeter(), DepthMeter(scale=opt.scale), PointsMeter( scale=opt.scale, intrinsics=train_loader._data.intrinsics_lidar ), ] else: depth_metrics = [] trainer = Trainer( "lidar_nerf", opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, depth_metrics=depth_metrics, use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval, ) valid_loader = NeRFDataset( device=device, split="val", root_path=opt.path, sequence_id=opt.sequence_id, preload=opt.preload, scale=opt.scale, offset=opt.offset, fp16=opt.fp16, patch_size_lidar=opt.patch_size_lidar, enable_lidar=opt.enable_lidar, 
num_rays_lidar=opt.num_rays_lidar, ).dataloader() max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) print(f"max_epoch: {max_epoch}") trainer.train(train_loader, valid_loader, max_epoch) # also test test_loader = NeRFDataset( device=device, split="test", root_path=opt.path, sequence_id=opt.sequence_id, preload=opt.preload, scale=opt.scale, offset=opt.offset, fp16=opt.fp16, patch_size_lidar=opt.patch_size_lidar, enable_lidar=opt.enable_lidar, num_rays_lidar=opt.num_rays_lidar, ).dataloader() if test_loader.has_gt: trainer.evaluate(test_loader) # blender has gt, so evaluate it. trainer.test(test_loader, write_video=True) # test and save video trainer.save_mesh(resolution=128, threshold=10) if __name__ == "__main__": main() ================================================ FILE: preprocess/cal_centerpose_bound.py ================================================ import numpy as np np.set_printoptions(suppress=True) import os import json import tqdm from lidarnerf.convert import pano_to_lidar def cal_centerpose_bound_scale( lidar_rangeview_paths, lidar2worlds, intrinsics, bound=1.0 ): near = 200 far = 0 points_world_list = [] for i, lidar_rangeview_path in enumerate(lidar_rangeview_paths): pano = np.load(lidar_rangeview_path) point_cloud = pano_to_lidar(pano=pano[:, :, 2], lidar_K=intrinsics) point_cloud = np.concatenate( [point_cloud, np.ones(point_cloud.shape[0]).reshape(-1, 1)], -1 ) dis = np.sqrt( point_cloud[:, 0] ** 2 + point_cloud[:, 1] ** 2 + point_cloud[:, 2] ** 2 ) near = min(min(dis), near) far = max(far, max(dis)) points_world = (point_cloud @ lidar2worlds[i].T)[:, :3] points_world_list.append(points_world) print("near, far:", near, far) # plt.figure(figsize=(16, 16)) pc_all_w = np.concatenate(points_world_list)[:, :3] # plt.scatter(pc_all_w[:, 0], pc_all_w[:, 1], s=0.001) # lidar2world_scene = np.array(lidar2worlds) # plt.plot(lidar2world_scene[:, 0, -1], lidar2world_scene[:, 1, -1]) # plt.savefig('vis/points-trajectory.png') centerpose = [ 
(np.max(pc_all_w[:, 0]) + np.min(pc_all_w[:, 0])) / 2.0, (np.max(pc_all_w[:, 1]) + np.min(pc_all_w[:, 1])) / 2.0, (np.max(pc_all_w[:, 2]) + np.min(pc_all_w[:, 2])) / 2.0, ] print("centerpose: ", centerpose) pc_all_w_centered = pc_all_w - centerpose # plt.figure(figsize=(16, 16)) # plt.scatter(pc_all_w_centered[:, 0], pc_all_w_centered[:, 1], s=0.001) # plt.savefig('vis/points-centered.png') bound_ori = [ np.max(pc_all_w_centered[:, 0]), np.max(pc_all_w_centered[:, 1]), np.max(pc_all_w_centered[:, 2]), ] scale = bound / np.max(bound_ori) print("scale: ", scale) # pc_all_w_centered_scaled = pc_all_w_centered * scale # plt.figure(figsize=(16, 16)) # plt.scatter(pc_all_w_centered_scaled[:, 0], # pc_all_w_centered_scaled[:, 1], # s=0.001) # plt.savefig('vis/points-centered-scaled.png') def get_path_pose_from_json(root_path, sequence_id): with open( os.path.join(root_path, f"transforms_{sequence_id}_train.json"), "r" ) as f: transform = json.load(f) frames = transform["frames"] poses_lidar = [] paths_lidar = [] for f in tqdm.tqdm(frames, desc=f"Loading {type} data"): pose_lidar = np.array(f["lidar2world"], dtype=np.float32) # [4, 4] f_lidar_path = os.path.join(root_path, f["lidar_file_path"]) poses_lidar.append(pose_lidar) paths_lidar.append(f_lidar_path) return paths_lidar, poses_lidar def main(): # kitti360 root_path = "data/kitti360" sequence_id = 1908 lidar_rangeview_paths, lidar2worlds = get_path_pose_from_json( root_path, sequence_id=sequence_id ) intrinsics = (2.0, 26.9) # fov_up, fov cal_centerpose_bound_scale(lidar_rangeview_paths, lidar2worlds, intrinsics) if __name__ == "__main__": main() ================================================ FILE: preprocess/generate_train_rangeview.py ================================================ import numpy as np import os from pathlib import Path from tqdm import tqdm import shutil import argparse from lidarnerf.convert import ( lidar_to_pano_with_intensities, lidar_to_pano_with_intensities_with_bbox_mask, ) def 
all_points_to_world(pcd_path_list, lidar2world_list): pc_w_list = [] for i, pcd_path in enumerate(pcd_path_list): point_cloud = np.load(pcd_path) point_cloud[:, -1] = 1 points_world = (point_cloud @ (lidar2world_list[i].reshape(4, 4)).T)[:, :3] pc_w_list.append(points_world) return pc_w_list def oriented_bounding_box(data): data_norm = data - data.mean(axis=0) C = np.cov(data_norm, rowvar=False) vals, vecs = np.linalg.eig(C) vecs = vecs[:, np.argsort(-vals)] Y = np.matmul(data_norm, vecs) offset = 0.03 xmin = min(Y[:, 0]) - offset xmax = max(Y[:, 0]) + offset ymin = min(Y[:, 1]) - offset ymax = max(Y[:, 1]) + offset temp = list() temp.append([xmin, ymin]) temp.append([xmax, ymin]) temp.append([xmax, ymax]) temp.append([xmin, ymax]) pointInNewCor = np.asarray(temp) OBB = np.matmul(pointInNewCor, vecs.T) + data.mean(0) return OBB def get_dataset_bbox(all_class, dataset_root, out_dir): object_bbox = {} for class_name in all_class: lidar_path = os.path.join(dataset_root, class_name) rt_path = os.path.join(lidar_path, "lidar2world.txt") filenames = os.listdir(lidar_path) filenames.remove("lidar2world.txt") filenames.sort(key=lambda x: int(x.split(".")[0])) show_interval = 1 pcd_path_list = [os.path.join(lidar_path, filename) for filename in filenames][ ::show_interval ] print(f"{lidar_path}: {len(pcd_path_list)} frames") lidar2world_list = list(np.loadtxt(rt_path))[::show_interval] all_points = all_points_to_world(pcd_path_list, lidar2world_list) pcd = np.concatenate(all_points).reshape((-1, 3)) OBB_xy = oriented_bounding_box(pcd[:, :2]) z_min, z_max = min(pcd[:, 2]), max(pcd[:, 2]) OBB_buttum = np.concatenate([OBB_xy, np.tile(z_min, 4).reshape(4, 1)], axis=1) OBB_top = np.concatenate([OBB_xy, np.tile(z_max, 4).reshape(4, 1)], axis=1) OBB = np.concatenate([OBB_top, OBB_buttum]) object_bbox[class_name] = OBB np.save(os.path.join(out_dir, "dataset_bbox_7k.npy"), object_bbox) def LiDAR_2_Pano_NeRF_MVL( local_points_with_intensities, lidar_H, lidar_W, intrinsics, OBB_local, 
max_depth=80.0, ): pano, intensities = lidar_to_pano_with_intensities_with_bbox_mask( local_points_with_intensities=local_points_with_intensities, lidar_H=lidar_H, lidar_W=lidar_W, lidar_K=intrinsics, bbox_local=OBB_local, max_depth=max_depth, ) range_view = np.zeros((lidar_H, lidar_W, 3)) range_view[:, :, 1] = intensities range_view[:, :, 2] = pano return range_view def generate_nerf_mvl_train_data( H, W, intrinsics, all_class, dataset_bbox, nerf_mvl_parent_dir, out_dir, ): """ Args: H: Heights of the range view. W: Width of the range view. intrinsics: (fov_up, fov) of the range view. out_dir: Output directory. """ for class_name in all_class: OBB = dataset_bbox[class_name] lidar_path = os.path.join(nerf_mvl_parent_dir, "nerf_mvl_7k", class_name) filenames = os.listdir(lidar_path) filenames.remove("lidar2world.txt") filenames.sort(key=lambda x: int(x.split(".")[0])) save_path = os.path.join(out_dir, class_name) if not os.path.exists(save_path): os.makedirs(save_path) shutil.copy( os.path.join(lidar_path, "lidar2world.txt"), os.path.join(save_path, "lidar2world.txt"), ) lidar2world = np.loadtxt(os.path.join(lidar_path, "lidar2world.txt")) avaliable_frames = [i for i in range(0, len(filenames))] print(class_name, " frames num ", len(avaliable_frames)) for idx in tqdm(avaliable_frames): pcd = np.load(os.path.join(lidar_path, filenames[idx])) OBB_local = ( np.concatenate([OBB, np.ones((8, 1))], axis=1) @ np.linalg.inv(lidar2world[idx].reshape(4, 4)).T ) pano = LiDAR_2_Pano_NeRF_MVL(pcd, H, W, intrinsics, OBB_local) np.savez_compressed( os.path.join(save_path, "{:010d}.npz").format(idx), data=pano ) def create_nerf_mvl_rangeview(): project_root = Path(__file__).parent.parent nerf_mvl_root = project_root / "data" / "nerf_mvl" / "nerf_mvl_7k" nerf_mvl_parent_dir = nerf_mvl_root.parent out_dir = nerf_mvl_parent_dir / "nerf_mvl_7k_pano" all_class = [ "water_safety_barrier", "tire", "pier", "plant", "warning_sign", "traffic_cone", "bollard", "pedestrian", "car", ] # 
get_dataset_bbox if not os.path.exists(os.path.join(nerf_mvl_parent_dir, "dataset_bbox_7k.npy")): get_dataset_bbox(all_class, nerf_mvl_root, nerf_mvl_parent_dir) dataset_bbox = np.load( os.path.join(nerf_mvl_parent_dir, "dataset_bbox_7k.npy"), allow_pickle=True ).item() # generate train rangeview images H = 256 W = 1800 intrinsics = (15, 40) generate_nerf_mvl_train_data( H=H, W=W, intrinsics=intrinsics, all_class=all_class, dataset_bbox=dataset_bbox, nerf_mvl_parent_dir=nerf_mvl_parent_dir, out_dir=out_dir, ) def LiDAR_2_Pano_KITTI( local_points_with_intensities, lidar_H, lidar_W, intrinsics, max_depth=80.0 ): pano, intensities = lidar_to_pano_with_intensities( local_points_with_intensities=local_points_with_intensities, lidar_H=lidar_H, lidar_W=lidar_W, lidar_K=intrinsics, max_depth=max_depth, ) range_view = np.zeros((lidar_H, lidar_W, 3)) range_view[:, :, 1] = intensities range_view[:, :, 2] = pano return range_view def generate_train_data( H, W, intrinsics, lidar_paths, out_dir, points_dim, ): """ Args: H: Heights of the range view. W: Width of the range view. intrinsics: (fov_up, fov) of the range view. out_dir: Output directory. 
""" out_dir = Path(out_dir) out_dir.mkdir(parents=True, exist_ok=True) for lidar_path in tqdm(lidar_paths): point_cloud = np.fromfile(lidar_path, dtype=np.float32) point_cloud = point_cloud.reshape((-1, points_dim)) pano = LiDAR_2_Pano_KITTI(point_cloud, H, W, intrinsics) frame_name = lidar_path.split("/")[-1] suffix = frame_name.split(".")[-1] frame_name = frame_name.replace(suffix, "npy") np.save(out_dir / frame_name, pano) def create_kitti_rangeview(): project_root = Path(__file__).parent.parent kitti_360_root = project_root / "data" / "kitti360" / "KITTI-360" kitti_360_parent_dir = kitti_360_root.parent out_dir = kitti_360_parent_dir / "train" sequence_name = "2013_05_28_drive_0000" H = 66 W = 1030 intrinsics = (2.0, 26.9) # fov_up, fov s_frame_id = 1908 e_frame_id = 1971 # Inclusive frame_ids = list(range(s_frame_id, e_frame_id + 1)) lidar_dir = ( kitti_360_root / "data_3d_raw" / f"{sequence_name}_sync" / "velodyne_points" / "data" ) lidar_paths = [ os.path.join(lidar_dir, "%010d.bin" % frame_id) for frame_id in frame_ids ] generate_train_data( H=H, W=W, intrinsics=intrinsics, lidar_paths=lidar_paths, out_dir=out_dir, points_dim=4, ) def main(): parser = argparse.ArgumentParser() parser.add_argument( "--dataset", type=str, default="kitti360", choices=["kitti360", "nerf_mvl"], help="The dataset loader to use.", ) args = parser.parse_args() # Check dataset. if args.dataset == "kitti360": create_kitti_rangeview() elif args.dataset == "nerf_mvl": create_nerf_mvl_rangeview() if __name__ == "__main__": main() ================================================ FILE: preprocess/kitti360_loader.py ================================================ from pathlib import Path import numpy as np import camtools as ct import open3d as o3d class KITTI360Loader: def __init__(self, kitti_360_root) -> None: # Root directory. 
self.kitti_360_root = Path(kitti_360_root) if not self.kitti_360_root.is_dir(): raise FileNotFoundError(f"KITTI-360 {kitti_360_root} not found.") # Other directories. self.calibration_dir = self.kitti_360_root / "calibration" self.data_poses_dir = self.kitti_360_root / "data_poses" self.data_2d_raw_dir = self.kitti_360_root / "data_2d_raw" self.data_3d_raw_dir = self.kitti_360_root / "data_3d_raw" # Check if all directories exist. if not self.calibration_dir.is_dir(): raise FileNotFoundError( f"Calibration dir {self.calibration_dir} not found." ) if not self.data_poses_dir.is_dir(): raise FileNotFoundError(f"Data poses dir {self.data_poses_dir} not found.") if not self.data_2d_raw_dir.is_dir(): raise FileNotFoundError( f"Data 2D raw dir {self.data_2d_raw_dir} not found." ) if not self.data_3d_raw_dir.is_dir(): raise FileNotFoundError( f"Data 3D raw dir {self.data_3d_raw_dir} not found." ) @staticmethod def _read_variable(fid, name, M, N): """ Ref: kitti360scripts/devkits/commons/loadCalibration.py """ # Rewind fid.seek(0, 0) # Search for variable identifier line = 1 success = 0 while line: line = fid.readline() if line.startswith(name): success = 1 break # Return if variable identifier not found if success == 0: return None # Fill matrix line = line.replace("%s:" % name, "") line = line.split() assert len(line) == M * N line = [float(x) for x in line] mat = np.array(line).reshape(M, N) return mat @staticmethod def _load_perspective_intrinsics(intrinsics_path): """ Args: intrinsics_path: str, path to perspective.txt. Returns: A dict, containing: - "P_rect_00": 4x4 rectified intrinsic for cam_00. - "P_rect_01": 4x4 rectified intrinsic for cam_01. - "R_rect_00": 3x3 rectification matrix for cam_00. - "R_rect_01": 3xe rectification matrix for cam_01. 
Ref: kitti360scripts/devkits/commons/loadCalibration.py::loadPerspectiveIntrinsic """ intrinsics_path = Path(intrinsics_path) with open(intrinsics_path, "r") as fid: perspective_dict = {} intrinsic_names = ["P_rect_00", "R_rect_00", "P_rect_01", "R_rect_01"] last_row = np.array([0, 0, 0, 1]).reshape(1, 4) for intrinsic in intrinsic_names: if intrinsic.startswith("P_rect"): perspective_dict[intrinsic] = np.concatenate( (KITTI360Loader._read_variable(fid, intrinsic, 3, 4), last_row) ) else: perspective_dict[intrinsic] = KITTI360Loader._read_variable( fid, intrinsic, 3, 3 ) return perspective_dict def load_images(self, camera_name, sequence_name, frame_ids): """ Args: camera_name: str, name of camera. e.g. "cam_00". sequence_name: str, name of sequence. e.g. "2013_05_28_drive_0000". frame_ids: list of int, frame ids. e.g. range(1908, 1971+1). Returns: An np.ndarray, float32, [N, H, W, 3], range 0-1, RGB images. """ im_paths = self.get_image_paths(camera_name, sequence_name, frame_ids) ims = [ct.io.imread(im_path) for im_path in im_paths] ims = np.stack(ims, axis=0) return ims def get_image_paths(self, camera_name, sequence_name, frame_ids): """ Args: camera_name: str, name of camera. e.g. "cam_00". sequence_name: str, name of sequence. e.g. "2013_05_28_drive_0000". frame_ids: list of int, frame ids. e.g. range(1908, 1971+1). Returns: An list of str, image paths. """ # Sanity checks. if camera_name == "cam_00": subdir_name = "image_00" elif camera_name == "cam_01": subdir_name = "image_01" else: raise ValueError(f"Invalid camera_name {camera_name}") # Get image paths. im_dir = ( self.data_2d_raw_dir / f"{sequence_name}_sync" / subdir_name / "data_rect" ) im_paths = [im_dir / f"{frame_id:010d}.png" for frame_id in frame_ids] for im_path in im_paths: if not im_path.is_file(): raise FileNotFoundError(f"Image {im_path} not found.") return im_paths def _load_all_cameras(self, sequence_name): """ Args: sequence_name: str, name of sequence. e.g. "2013_05_28_drive_0000". 
Returns: cam_00_K: 3x3 intrinsics, rectified perspective cam_00. cam_01_K: 3x3 intrinsics, rectified perspective cam_01. cam_00_T_dict: map frame_id to 4x4 T, rectified perspective cam_00. cam_01_T_dict: map frame_id to 4x4 T, rectified perspective cam_01. """ data_poses_dir = self.data_poses_dir / f"{sequence_name}_sync" assert data_poses_dir.is_dir() # Load intrinsics and rectification matrices. perspective_path = self.calibration_dir / "perspective.txt" perspective_dict = KITTI360Loader._load_perspective_intrinsics(perspective_path) cam_00_K = perspective_dict["P_rect_00"][:3, :3] # 3x3 cam_01_K = perspective_dict["P_rect_01"][:3, :3] # 3x3 cam_00_rec = np.eye(4) # 4x4 cam_00_rec[:3, :3] = perspective_dict["R_rect_00"] cam_01_rec = np.eye(4) # 4x4 cam_01_rec[:3, :3] = perspective_dict["R_rect_01"] # IMU to world transformation (poses.txt). poses_path = data_poses_dir / "poses.txt" imu_to_world_dict = dict() frame_ids = [] for line in np.loadtxt(poses_path): frame_id = int(line[0]) frame_ids.append(frame_id) imu_to_world = line[1:].reshape((3, 4)) imu_to_world_dict[frame_id] = imu_to_world # Camera to IMU transformation (calib_cam_to_pose.txt). cam_to_imu_path = self.calibration_dir / "calib_cam_to_pose.txt" with open(cam_to_imu_path, "r") as fid: cam_00_to_imu = KITTI360Loader._read_variable(fid, "image_00", 3, 4) cam_01_to_imu = KITTI360Loader._read_variable(fid, "image_01", 3, 4) cam_02_to_imu = KITTI360Loader._read_variable(fid, "image_02", 3, 4) cam_03_to_imu = KITTI360Loader._read_variable(fid, "image_03", 3, 4) cam_00_to_imu = ct.convert.pad_0001(cam_00_to_imu) cam_01_to_imu = ct.convert.pad_0001(cam_01_to_imu) cam_02_to_imu = ct.convert.pad_0001(cam_02_to_imu) cam_03_to_imu = ct.convert.pad_0001(cam_03_to_imu) # Compute rectified cam_00_to_world, cam_01_to_world. 
cam_00_to_world_dict = dict() for frame_id in frame_ids: imu_to_world = imu_to_world_dict[frame_id] cam_00_to_world_unrec = imu_to_world @ cam_00_to_imu cam_00_to_world = cam_00_to_world_unrec @ np.linalg.inv(cam_00_rec) cam_00_to_world_dict[frame_id] = ct.convert.pad_0001(cam_00_to_world) cam_01_to_world_dict = dict() for frame_id in frame_ids: imu_to_world = imu_to_world_dict[frame_id] cam_01_to_world_unrec = imu_to_world @ cam_01_to_imu cam_01_to_world = cam_01_to_world_unrec @ np.linalg.inv(cam_01_rec) cam_01_to_world_dict[frame_id] = ct.convert.pad_0001(cam_01_to_world) # Sanity check: check our rectified cam0_to_world is the same as the # ones ground-truth given by KITTI-360. cam_00_to_world_path = data_poses_dir / "cam0_to_world.txt" gt_cam_00_to_world_dict = dict() for line in np.loadtxt(cam_00_to_world_path): frame_id = int(line[0]) gt_cam_00_to_world_dict[frame_id] = line[1:].reshape((4, 4)) for frame_id in frame_ids: gt_cam_00_to_world = gt_cam_00_to_world_dict[frame_id] cam_00_to_world = cam_00_to_world_dict[frame_id] assert np.allclose( gt_cam_00_to_world, cam_00_to_world, atol=1e-5, rtol=1e-5 ) # Convert cam_to_world to T. cam_00_T_dict = dict() cam_01_T_dict = dict() for frame_id in frame_ids: cam_00_T = np.linalg.inv(cam_00_to_world_dict[frame_id]) cam_01_T = np.linalg.inv(cam_01_to_world_dict[frame_id]) cam_00_T_dict[frame_id] = cam_00_T cam_01_T_dict[frame_id] = cam_01_T return cam_00_K, cam_01_K, cam_00_T_dict, cam_01_T_dict def load_cameras(self, camera_name, sequence_name, frame_ids): """ Args: camera_name: str, name of camera. e.g. "cam_00". sequence_name: str, name of sequence. e.g. "2013_05_28_drive_0000". frame_ids: list of int, frame ids. e.g. range(1908, 1971+1). 
Returns: Ks, Ts """ ( cam_00_K, cam_01_K, cam_00_T_dict, cam_01_T_dict, ) = self._load_all_cameras(sequence_name) num_cameras = len(frame_ids) if camera_name == "cam_00": Ks = [cam_00_K for _ in range(num_cameras)] Ts = [cam_00_T_dict[frame_id] for frame_id in frame_ids] elif camera_name == "cam_01": Ks = [cam_01_K for _ in range(num_cameras)] Ts = [cam_01_T_dict[frame_id] for frame_id in frame_ids] else: raise ValueError(f"Unknown camera name {camera_name}") Ks = np.stack(Ks) Ts = np.stack(Ts) return Ks, Ts def _load_all_lidars(self, sequence_name): """ Args: sequence_name: str, name of sequence. e.g. "2013_05_28_drive_0000". Returns: velo_to_world: 4x4 metric. """ data_poses_dir = self.data_poses_dir / f"{sequence_name}_sync" assert data_poses_dir.is_dir() # IMU to world transformation (poses.txt). poses_path = data_poses_dir / "poses.txt" imu_to_world_dict = dict() frame_ids = [] for line in np.loadtxt(poses_path): frame_id = int(line[0]) frame_ids.append(frame_id) imu_to_world = line[1:].reshape((3, 4)) imu_to_world_dict[frame_id] = imu_to_world # Camera to IMU transformation (calib_cam_to_pose.txt). cam_to_imu_path = self.calibration_dir / "calib_cam_to_pose.txt" with open(cam_to_imu_path, "r") as fid: cam_00_to_imu = KITTI360Loader._read_variable(fid, "image_00", 3, 4) cam_00_to_imu = ct.convert.pad_0001(cam_00_to_imu) # Camera00 to Velo transformation (calib_cam_to_velo.txt). 
cam00_to_velo_path = self.calibration_dir / "calib_cam_to_velo.txt" with open(cam00_to_velo_path, "r") as fid: line = fid.readline().split() line = [float(x) for x in line] cam_00_to_velo = np.array(line).reshape(3, 4) cam_00_to_velo = ct.convert.pad_0001(cam_00_to_velo) # Compute velo_to_world velo_to_world_dict = dict() for frame_id in frame_ids: imu_to_world = imu_to_world_dict[frame_id] cam_00_to_world_unrec = imu_to_world @ cam_00_to_imu velo_to_world = cam_00_to_world_unrec @ np.linalg.inv(cam_00_to_velo) velo_to_world_dict[frame_id] = ct.convert.pad_0001(velo_to_world) return velo_to_world_dict def load_lidars(self, sequence_name, frame_ids): """ Args: sequence_name: str, name of sequence. e.g. "2013_05_28_drive_0000". frame_ids: list of int, frame ids. e.g. range(1908, 1971+1). Returns: velo_to_worlds """ velo_to_world_dict = self._load_all_lidars(sequence_name) velo_to_worlds = [velo_to_world_dict[frame_id] for frame_id in frame_ids] velo_to_worlds = np.stack(velo_to_worlds) return velo_to_worlds def main(): # Load cameras. k3 = KITTI360Loader(kitti_360_root=Path("data") / "KITTI-360") cam_00_Ks, cam_00_Ts = k3.load_cameras( camera_name="cam_00", sequence_name="2013_05_28_drive_0000", frame_ids=range(1908, 1971 + 1), ) cam_01_Ks, cam_01_Ts = k3.load_cameras( camera_name="cam_01", sequence_name="2013_05_28_drive_0000", frame_ids=range(1908, 1971 + 1), ) # Load images. im_cam_00s = k3.load_images( camera_name="cam_00", sequence_name="2013_05_28_drive_0000", frame_ids=range(1908, 1971 + 1), ) im_cam_01s = k3.load_images( camera_name="cam_01", sequence_name="2013_05_28_drive_0000", frame_ids=range(1908, 1971 + 1), ) # Visualize. 
cam_00_frames = ct.camera.create_camera_ray_frames(cam_00_Ks, cam_00_Ts, size=0.8) cam_01_frames = ct.camera.create_camera_ray_frames(cam_01_Ks, cam_01_Ts, size=0.8) o3d.visualization.draw_geometries([cam_00_frames, cam_01_frames]) if __name__ == "__main__": main() ================================================ FILE: preprocess/kitti360_to_nerf.py ================================================ from pathlib import Path from kitti360_loader import KITTI360Loader import camtools as ct import numpy as np import json def normalize_Ts(Ts): # New Cs. Cs = np.array([ct.convert.T_to_C(T) for T in Ts]) normalize_mat = ct.normalize.compute_normalize_mat(Cs) Cs_new = ct.project.homo_project(Cs.reshape((-1, 3)), normalize_mat) # New Ts. Ts_new = [] for T, C_new in zip(Ts, Cs_new): pose = ct.convert.T_to_pose(T) pose[:3, 3] = C_new T_new = ct.convert.pose_to_T(pose) Ts_new.append(T_new) return Ts_new def main(): project_root = Path(__file__).parent.parent kitti_360_root = project_root / "data" / "kitti360" / "KITTI-360" kitti_360_parent_dir = kitti_360_root.parent # Specify frames and splits. 
sequence_name = "2013_05_28_drive_0000" sequence_id = "1908" if sequence_id == "1538": print("Using sqequence 1538-1601") s_frame_id = 1538 e_frame_id = 1601 # Inclusive val_frame_ids = [1551, 1564, 1577, 1590] elif sequence_id == "1728": print("Using sqequence 1728-1791") s_frame_id = 1728 e_frame_id = 1791 # Inclusive val_frame_ids = [1741, 1754, 1767, 1780] elif sequence_id == "1908": print("Using sqequence 1908-1971") s_frame_id = 1908 e_frame_id = 1971 # Inclusive val_frame_ids = [1921, 1934, 1947, 1960] elif sequence_id == "3353": print("Using sqequence 3353-3416") s_frame_id = 3353 e_frame_id = 3416 # Inclusive val_frame_ids = [3366, 3379, 3392, 3405] else: raise ValueError(f"Invalid sequence id: {sequence_id}") frame_ids = list(range(s_frame_id, e_frame_id + 1)) num_frames = len(frame_ids) test_frame_ids = val_frame_ids train_frame_ids = [x for x in frame_ids if x not in val_frame_ids] # Load KITTI-360 dataset. k3 = KITTI360Loader(kitti_360_root) # Get image paths. cam_00_im_paths = k3.get_image_paths("cam_00", sequence_name, frame_ids) cam_01_im_paths = k3.get_image_paths("cam_01", sequence_name, frame_ids) im_paths = cam_00_im_paths + cam_01_im_paths # Get Ks, Ts. cam_00_Ks, cam_00_Ts = k3.load_cameras("cam_00", sequence_name, frame_ids) cam_01_Ks, cam_01_Ts = k3.load_cameras("cam_01", sequence_name, frame_ids) Ks = np.concatenate([cam_00_Ks, cam_01_Ks], axis=0) Ts = np.concatenate([cam_00_Ts, cam_01_Ts], axis=0) # Ts = normalize_Ts(Ts) # Get image dimensions, assume all images have the same dimensions. im_rgb = ct.io.imread(cam_00_im_paths[0]) im_h, im_w, _ = im_rgb.shape # Get lidar paths (range view not raw data). range_view_dir = kitti_360_parent_dir / "train" range_view_paths = [ range_view_dir / "{:010d}.npy".format(int(frame_id)) for frame_id in frame_ids ] # Get lidar2world. lidar2world = k3.load_lidars(sequence_name, frame_ids) # Get image dimensions, assume all images have the same dimensions. 
lidar_range_image = np.load(range_view_paths[0]) lidar_h, lidar_w, _ = lidar_range_image.shape # Split by train/test/val. all_indices = [i - s_frame_id for i in frame_ids] train_indices = [i - s_frame_id for i in train_frame_ids] val_indices = [i - s_frame_id for i in val_frame_ids] test_indices = [i - s_frame_id for i in test_frame_ids] # all_indices = all_indices + [i + num_frames for i in all_indices] # train_indices = train_indices + [i + num_frames for i in train_indices] # val_indices = val_indices + [i + num_frames for i in val_indices] # test_indices = test_indices + [i + num_frames for i in test_indices] split_to_all_indices = { "train": train_indices, "val": val_indices, "test": test_indices, } for split, indices in split_to_all_indices.items(): print(f"Split {split} has {len(indices)} frames.") im_paths_split = [im_paths[i] for i in indices] lidar_paths_split = [range_view_paths[i] for i in indices] lidar2world_split = [lidar2world[i] for i in indices] Ks_split = [Ks[i] for i in indices] Ts_split = [Ts[i] for i in indices] json_dict = { "w": im_w, "h": im_h, "w_lidar": lidar_w, "h_lidar": lidar_h, "fl_x": Ks_split[0][0, 0], "fl_y": Ks_split[0][1, 1], "cx": Ks_split[0][0, 2], "cy": Ks_split[0][1, 2], "aabb_scale": 2, "frames": [ { "file_path": str(path.relative_to(kitti_360_parent_dir)), "transform_matrix": ct.convert.T_to_pose(T).tolist(), "lidar_file_path": str( lidar_path.relative_to(kitti_360_parent_dir) ), "lidar2world": lidar2world.tolist(), } for ( path, T, lidar_path, lidar2world, ) in zip( im_paths_split, Ts_split, lidar_paths_split, lidar2world_split, ) ], } json_path = kitti_360_parent_dir / f"transforms_{sequence_id}_{split}.json" with open(json_path, "w") as f: json.dump(json_dict, f, indent=2) print(f"Saved {json_path}.") if __name__ == "__main__": main() ================================================ FILE: preprocess/nerfmvl_loader.py ================================================ from pathlib import Path import numpy as np class 
NeRFMVLLoader: def __init__(self, nerf_mvl_root, class_name) -> None: # Root directory. self.nerf_mvl_root = Path(nerf_mvl_root) if not self.nerf_mvl_root.is_dir(): raise FileNotFoundError(f"NeRF_MVL {nerf_mvl_root} not found.") # Other directories. self.data_3d_raw_dir = self.nerf_mvl_root / class_name self.lidar2world_path = self.data_3d_raw_dir / "lidar2world.txt" # Check if all directories exist. if not self.data_3d_raw_dir.is_dir(): raise FileNotFoundError( f"Data 3D raw dir {self.data_3d_raw_dir} not found." ) def _load_all_lidars( self, ): """ Args: Returns: velo_to_world: 4x4 metric. """ velo_to_world_dict = np.loadtxt(self.lidar2world_path) return velo_to_world_dict.reshape(-1, 4, 4) def load_lidars(self, frame_ids): """ Args: frame_ids: list of int, frame ids. e.g. range(1908, 1971+1). Returns: velo_to_worlds """ velo_to_world_dict = self._load_all_lidars() velo_to_worlds = [velo_to_world_dict[frame_id] for frame_id in frame_ids] velo_to_worlds = np.stack(velo_to_worlds) return velo_to_worlds def main(): dataset = NeRFMVLLoader(Path("data") / "nerf_mvl" / "nerf_mvl_7k_pano", "pier") velo_to_world_dict = dataset._load_all_lidars() return if __name__ == "__main__": main() ================================================ FILE: preprocess/nerfmvl_to_nerf.py ================================================ import os from nerfmvl_loader import NeRFMVLLoader import numpy as np import json from pathlib import Path def main(): project_root = Path(__file__).parent.parent nerf_mvl_root = project_root / "data" / "nerf_mvl" / "nerf_mvl_7k_pano" nerf_mvl_parent_dir = nerf_mvl_root.parent # Specify frames and splits. train_split = { "water_safety_barrier": 2, "tire": 2, "pier": 2, "plant": 2, "warning_sign": 2, "bollard": 2, "pedestrian": 3, "car": 3, "traffic_cone": 3, } for class_name, split_intervel in train_split.items(): # Get lidar paths (range view not raw data). 
range_view_dir = nerf_mvl_root / class_name filenames = os.listdir(range_view_dir) filenames.remove("lidar2world.txt") range_view_paths = [ Path(os.path.join(range_view_dir, filename)) for filename in filenames ] num_samples = len(range_view_paths) frame_ids = np.arange(num_samples) train_frame_ids = [i for i in range(0, num_samples, split_intervel)] val_frame_ids = [i for i in range(0, num_samples, split_intervel * 20)] test_frame_ids = val_frame_ids # Load NeRF_MVL dataset. nerf_mvl_dataset = NeRFMVLLoader(nerf_mvl_root, class_name) # Get lidar2world. lidar2world = nerf_mvl_dataset.load_lidars(frame_ids) # Get image dimensions, assume all images have the same dimensions. lidar_range_image = np.load(range_view_paths[0])["data"] lidar_h, lidar_w, _ = lidar_range_image.shape # Split by train/test/val. all_indices = frame_ids train_indices = train_frame_ids val_indices = val_frame_ids test_indices = test_frame_ids split_to_all_indices = { "train": train_indices, "val": val_indices, "test": test_indices, } for split, indices in split_to_all_indices.items(): print(f"Split {split} has {len(indices)} frames.") lidar_paths_split = [range_view_paths[i] for i in indices] lidar2world_split = [lidar2world[i] for i in indices] json_dict = { "w_lidar": lidar_w, "h_lidar": lidar_h, "aabb_scale": 2, "frames": [ { "lidar_file_path": str( lidar_path.relative_to(nerf_mvl_parent_dir) ), "lidar2world": lidar2world.tolist(), } for ( lidar_path, lidar2world, ) in zip( lidar_paths_split, lidar2world_split, ) ], } json_path = nerf_mvl_parent_dir / f"transforms_{class_name}_{split}.json" with open(json_path, "w") as f: json.dump(json_dict, f, indent=2) print(f"Saved {json_path}.") if __name__ == "__main__": main() ================================================ FILE: readme.md ================================================

LiDAR-NeRF: Novel LiDAR View Synthesis via Neural Radiance Fields

Home Page Paper PDF Video MP4

Tao Tang · Longfei Gao · Guangrun Wang · Yixing Lao · Peng Chen · Hengshuang Zhao · Dayang Hao · Xiaodan Liang* · Mathieu Salzmann · Kaicheng Yu

Formatter

![lidar-nerf](./assets/lidar-nerf.png) ![lidar-nerf-res](./assets/lidar-nerf-res.png) This paper introduces a new task of novel LiDAR view synthesis and proposes a differentiable framework called **LiDAR-NeRF** with a structural regularization, as well as an object-centric multi-view LiDAR dataset called **NeRF-MVL**. 1. We formulate the first differentiable framework, LiDAR-NeRF, for novel LiDAR view synthesis, which can render novel point clouds with point intensity and ray-drop probability without explicit 3D reconstruction. 2. We propose a structural regularization method to effectively preserve local structural details, thereby guiding the model towards more precise geometry estimations, leading to more faithful novel LiDAR view synthesis. 3. We establish the NeRF-MVL dataset from LiDAR sensors of real autonomous vehicles to evaluate object-centric novel LiDAR view synthesis. 4. We demonstrate the effectiveness of our LiDAR-NeRF quantitatively and qualitatively in both scene-level and object-level novel LiDAR view synthesis. ## News - [2023/07/14] LiDAR-NeRF v0.1.0 released. NeRF-MVL dataset released. ## Installation ```bash conda create -n lidarnerf python=3.9 conda activate lidarnerf # Dependencies pip install -r requirements_torch.txt pip install -r requirements.txt # tiny-cuda-nn # This may take a while; please refer to the official documentation pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch # camtools pip install git+https://github.com/yxlao/camtools.git # Install lidar-nerf pip install -e . python -c "import lidarnerf; print(lidarnerf.__version__)" ``` ## Dataset ### KITTI-360 dataset First, download the KITTI-360 dataset from [here](https://www.cvlibs.net/datasets/kitti-360/index.php) and put the dataset into `data/kitti360`.
Your folder structure should look like this: ```bash data └── kitti360 └── KITTI-360 ├── calibration ├── data_2d_raw ├── data_3d_raw └── data_poses ``` Next, run KITTI-360 dataset preprocessing: ```bash # Generate train range images python preprocess/generate_train_rangeview.py --dataset kitti360 # Generate jsons python preprocess/kitti360_to_nerf.py # Calculate center pose (optional; you can directly use our config instead) python preprocess/cal_centerpose_bound.py ``` After preprocessing, your folder structure should look like this: ```bash data └── kitti360 ├── train ├── KITTI-360 │ ├── calibration │ ├── data_2d_raw │ ├── data_3d_raw │ └── data_poses ├── transforms_{sequence_id}_test.json ├── transforms_{sequence_id}_train.json └── transforms_{sequence_id}_val.json ``` ### NeRF-MVL dataset First, download our NeRF-MVL dataset from [here](https://drive.google.com/drive/folders/1ZCuM3lCvWATXL79WdqrFxbYd4kwsHoTM?usp=sharing). Your folder structure should look like this: ```bash $ tree data -l -L 2 data └── nerf_mvl └── nerf_mvl_7k └── {class_name} ├── {frame_id}.npy └── lidar2world.txt ``` Next, run NeRF-MVL dataset preprocessing: ```bash # If you only downloaded the raw nerf_mvl_7k, you need to convert it to nerf_mvl_7k_pano (optional) # or directly download our processed dataset from https://drive.google.com/drive/folders/1pwnIjBUMIYg0fmLaeLj-sKfVcnBexlMq?usp=sharing # Generate train range images python preprocess/generate_train_rangeview.py --dataset nerf_mvl # Generate jsons python preprocess/nerfmvl_to_nerf.py ``` After preprocessing, your folder structure should look like this: ```bash data └── nerf_mvl ├── dataset_bbox_7k.npy ├── nerf_mvl_7k │ └── {class_name} │ ├── {frame_id}.npy │ └── lidar2world.txt ├── nerf_mvl_7k_pano │ └── {class_name} │ ├── {frame_id}.npy │ └── lidar2world.txt ├── transforms_{class_name}_test.json ├── transforms_{class_name}_train.json └── transforms_{class_name}_val.json ``` ## Run ```bash # kitti360 python main_lidarnerf.py -L --workspace log/kitti360_lidar #
nerf_mvl python main_lidarnerf.py --config configs/nerf_mvl.txt -L --workspace log/trial_nerf_nerf_mvl ``` ## Pre-trained Models You can download our pre-trained models [here](https://drive.google.com/drive/folders/1pwnIjBUMIYg0fmLaeLj-sKfVcnBexlMq?usp=sharing). ## Upcoming - [ ] Support multi-modality, e.g., RGB & LiDAR - [ ] Support more datasets, e.g., nuScenes, Waymo - [ ] Support more implicit geometry representations, e.g., SDF ## Contribution We welcome all forms of community contributions, including issues, bug fixes, new features, and more. Please [format the code](https://black.readthedocs.io/en/stable/getting_started.html) before submitting a pull request. ## Citation If you find our code or paper helpful, please consider citing: ```bibtex @article{tao2023lidar, title = {LiDAR-NeRF: Novel LiDAR View Synthesis via Neural Radiance Fields}, author = {Tang, Tao and Gao, Longfei and Wang, Guangrun and Lao, Yixing and Chen, Peng and Zhao, Hengshuang and Hao, Dayang and Liang, Xiaodan and Salzmann, Mathieu and Yu, Kaicheng}, journal = {arXiv preprint arXiv:2304.10406}, year = {2023} } ``` ## Acknowledgments This code is built on top of the super-useful [torch-ngp](https://github.com/ashawkey/torch-ngp) implementation. ```bibtex @misc{torch-ngp, author = {Jiaxiang Tang}, year = {2022}, note = {https://github.com/ashawkey/torch-ngp}, title = {Torch-ngp: a PyTorch implementation of instant-ngp} } ``` The raydrop-mlp code for PCGen is borrowed from [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch).
```bibtex @misc{lin2020nerfpytorch, title = {NeRF-pytorch}, author = {Yen-Chen, Lin}, publisher = {GitHub}, journal = {GitHub repository}, howpublished = {\url{https://github.com/yenchenlin/nerf-pytorch/}}, year = {2020} } ``` ================================================ FILE: requirements.txt ================================================ torch-ema torchmetrics ninja trimesh opencv-python tensorboardX numpy pandas tqdm matplotlib PyMCubes rich pysdf dearpygui packaging scipy lpips imageio==2.13.0 torchmetrics imageio-ffmpeg==0.4.8 open3d configargparse scikit-image nksr black # nuscenes nuscenes-devkit>=1.1.1 pyquaternion ================================================ FILE: requirements_torch.txt ================================================ torch==2.0.0 torchvision torchaudio ================================================ FILE: setup.py ================================================ from pathlib import Path from setuptools import setup import re _pwd = Path(__file__).parent.absolute() def main(): cmdclass = dict() version = None init_path = _pwd / "lidarnerf" / "__init__.py" with open(init_path, "r", encoding="utf-8") as f: lines = f.readlines() for line in lines: match_res = re.match(r'^__version__ = "(.*)"', line) if match_res: version = match_res.group(1) break if version is None: raise RuntimeError(f"Cannot find version from {init_path}") print(f"Detected lidarnerf version: {version}") _ = setup( name="lidarnerf", version=version, description="LiDAR-NeRF: Novel LiDAR View Synthesis via Neural Radiance Fields", packages=["lidarnerf", "lidarnvs"], cmdclass=cmdclass, include_package_data=True, ) if __name__ == "__main__": main()