[
  {
    "path": ".github/workflows/formatter.yml",
    "content": "name: Formatter\n\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    types: [opened, reopened, synchronize]\n\njobs:\n  formatter:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v3\n      - uses: psf/black@stable\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 Tao Tang, Yixing Lao\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "configs/kitti360_1538.txt",
    "content": "sequence_id = 1538\nalpha_d = 1000.0\nalpha_r = 1\nalpha_i = 1e1\nalpha_grad = 100.0\ngrad_loss = True\ndesired_resolution = 32768\nchange_patch_size_lidar = [2, 8]\nnum_steps = 768\nupsample_steps = 64\nbound = 1\nscale = 0.01150158050828236\noffset = [1150.2651429096413, 3997.2217130085182, 109.3943550832148]\n"
  },
  {
    "path": "configs/kitti360_1728.txt",
    "content": "sequence_id = 1728\nalpha_d = 1000.0\nalpha_r = 1\nalpha_i = 1e1\nalpha_grad = 100.0\ngrad_loss = True\ndesired_resolution = 32768\nchange_patch_size_lidar = [2, 8]\nnum_steps = 768\nupsample_steps = 64\nbound = 1\nscale = 0.01235117157331213\noffset = [1036.6078389848537, 3863.5989919125104, 111.73904860790459]"
  },
  {
    "path": "configs/kitti360_1908.txt",
    "content": "sequence_id = 1908\nalpha_d = 1000.0\nalpha_r = 1\nalpha_i = 1e1\nalpha_grad = 100.0\ngrad_loss = True\ndesired_resolution = 32768\nchange_patch_size_lidar = [2, 8]\nnum_steps = 768\nupsample_steps = 64\nbound = 1\nscale = 0.010784853507573345\noffset = [1069.988979297527, 3765.8807850056446, 113.0212841477088]\n"
  },
  {
    "path": "configs/kitti360_3353.txt",
    "content": "sequence_id = 3353\nalpha_d = 1000.0\nalpha_r = 1\nalpha_i = 1e1\nalpha_grad = 100.0\ngrad_loss = True\ndesired_resolution = 32768\nchange_patch_size_lidar = [2, 8]\nnum_steps = 768\nupsample_steps = 64\nbound = 1\nscale = 0.00951045294058913\noffset = [1364.3592435499154, 3818.620913210761, 108.69906656243805]"
  },
  {
    "path": "configs/nerf_mvl.txt",
    "content": "path = data/nerf_mvl\ndataloader = nerf_mvl\nsequence_id = car\nalpha_d = 1000.0\nalpha_r = 1\nalpha_i = 1\nalpha_grad = 100.0\nintensity_inv_scale=255.0\ngrad_loss = False\ndesired_resolution = 32768\neval_interval=5\nnum_steps = 768\nupsample_steps = 64\nbound = 1\nscale = 0.005\noffset = [973.0483450856506, 648.3910430331337, -8.442160936778045]\n"
  },
  {
    "path": "extern/chamfer3D/chamfer3D.cu",
    "content": "\n#include <ATen/ATen.h>\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <stdio.h>\n\n#include <vector>\n\n__global__ void NmDistanceKernel(int b,\n                                 int n,\n                                 const float *xyz,\n                                 int m,\n                                 const float *xyz2,\n                                 float *result,\n                                 int *result_i) {\n    const int batch = 512;\n    __shared__ float buf[batch * 3];\n    for (int i = blockIdx.x; i < b; i += gridDim.x) {\n        for (int k2 = 0; k2 < m; k2 += batch) {\n            int end_k = min(m, k2 + batch) - k2;\n            for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x) {\n                buf[j] = xyz2[(i * m + k2) * 3 + j];\n            }\n            __syncthreads();\n            for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n;\n                 j += blockDim.x * gridDim.y) {\n                float x1 = xyz[(i * n + j) * 3 + 0];\n                float y1 = xyz[(i * n + j) * 3 + 1];\n                float z1 = xyz[(i * n + j) * 3 + 2];\n                int best_i = 0;\n                float best = 0;\n                int end_ka = end_k - (end_k & 3);\n                if (end_ka == batch) {\n                    for (int k = 0; k < batch; k += 4) {\n                        {\n                            float x2 = buf[k * 3 + 0] - x1;\n                            float y2 = buf[k * 3 + 1] - y1;\n                            float z2 = buf[k * 3 + 2] - z1;\n                            float d = x2 * x2 + y2 * y2 + z2 * z2;\n                            if (k == 0 || d < best) {\n                                best = d;\n                                best_i = k + k2;\n                            }\n                        }\n                        {\n                            float x2 = buf[k * 3 + 3] - x1;\n                            float y2 = buf[k * 3 + 4] - y1;\n                            float z2 = buf[k * 3 + 5] - z1;\n                            float d = x2 * x2 + y2 * y2 + z2 * z2;\n                            if (d < best) {\n                                best = d;\n                                best_i = k + k2 + 1;\n                            }\n                        }\n                        {\n                            float x2 = buf[k * 3 + 6] - x1;\n                            float y2 = buf[k * 3 + 7] - y1;\n                            float z2 = buf[k * 3 + 8] - z1;\n                            float d = x2 * x2 + y2 * y2 + z2 * z2;\n                            if (d < best) {\n                                best = d;\n                                best_i = k + k2 + 2;\n                            }\n                        }\n                        {\n                            float x2 = buf[k * 3 + 9] - x1;\n                            float y2 = buf[k * 3 + 10] - y1;\n                            float z2 = buf[k * 3 + 11] - z1;\n                            float d = x2 * x2 + y2 * y2 + z2 * z2;\n                            if (d < best) {\n                                best = d;\n                                best_i = k + k2 + 3;\n                            }\n                        }\n                    }\n                } else {\n                    for (int k = 0; k < end_ka; k += 4) {\n                        {\n                            float x2 = buf[k * 3 + 0] - x1;\n                            float y2 = buf[k * 3 + 1] - y1;\n                            float z2 = buf[k * 3 + 2] - z1;\n                            float d = x2 * x2 + y2 * y2 + z2 * z2;\n                            if (k == 0 || d < best) {\n                                best = d;\n                                best_i = k + k2;\n                            }\n                        }\n                        {\n                            float x2 = buf[k * 3 + 3] - x1;\n                            float y2 = buf[k * 3 + 4] - y1;\n                            float z2 = buf[k * 3 + 5] - z1;\n                            float d = x2 * x2 + y2 * y2 + z2 * z2;\n                            if (d < best) {\n                                best = d;\n                                best_i = k + k2 + 1;\n                            }\n                        }\n                        {\n                            float x2 = buf[k * 3 + 6] - x1;\n                            float y2 = buf[k * 3 + 7] - y1;\n                            float z2 = buf[k * 3 + 8] - z1;\n                            float d = x2 * x2 + y2 * y2 + z2 * z2;\n                            if (d < best) {\n                                best = d;\n                                best_i = k + k2 + 2;\n                            }\n                        }\n                        {\n                            float x2 = buf[k * 3 + 9] - x1;\n                            float y2 = buf[k * 3 + 10] - y1;\n                            float z2 = buf[k * 3 + 11] - z1;\n                            float d = x2 * x2 + y2 * y2 + z2 * z2;\n                            if (d < best) {\n                                best = d;\n                                best_i = k + k2 + 3;\n                            }\n                        }\n                    }\n                }\n                for (int k = end_ka; k < end_k; k++) {\n                    float x2 = buf[k * 3 + 0] - x1;\n                    float y2 = buf[k * 3 + 1] - y1;\n                    float z2 = buf[k * 3 + 2] - z1;\n                    float d = x2 * x2 + y2 * y2 + z2 * z2;\n                    if (k == 0 || d < best) {\n                        best = d;\n                        best_i = k + k2;\n                    }\n                }\n                if (k2 == 0 || result[(i * n + j)] > best) {\n                    result[(i * n + j)] = best;\n                    result_i[(i * n + j)] = best_i;\n                }\n            }\n            __syncthreads();\n        }\n    }\n}\n// int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float *\n// xyz2,float * result,int * result_i,float * result2,int * result2_i,\n// cudaStream_t stream){\nint chamfer_cuda_forward(at::Tensor xyz1,\n                         at::Tensor xyz2,\n                         at::Tensor dist1,\n                         at::Tensor dist2,\n                         at::Tensor idx1,\n                         at::Tensor idx2) {\n    const auto batch_size = xyz1.size(0);\n    const auto n = xyz1.size(1);  // num_points point cloud A\n    const auto m = xyz2.size(1);  // num_points point cloud B\n\n    NmDistanceKernel<<<dim3(32, 16, 1), 512>>>(\n            batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(),\n            dist1.data<float>(), idx1.data<int>());\n    NmDistanceKernel<<<dim3(32, 16, 1), 512>>>(\n            batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(),\n            dist2.data<float>(), idx2.data<int>());\n\n    cudaError_t err = cudaGetLastError();\n    if (err != cudaSuccess) {\n        printf(\"error in nnd updateOutput: %s\\n\", cudaGetErrorString(err));\n        // THError(\"aborting\");\n        return 0;\n    }\n    return 1;\n}\n__global__ void NmDistanceGradKernel(int b,\n                                     int n,\n                                     const float *xyz1,\n                                     int m,\n                                     const float *xyz2,\n                                     const float *grad_dist1,\n                                     const int *idx1,\n                                     float *grad_xyz1,\n                                     float *grad_xyz2) {\n    for (int i = blockIdx.x; i < b; i += gridDim.x) {\n        for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n;\n             j += blockDim.x * gridDim.y) {\n            float x1 = xyz1[(i * n + j) * 3 + 0];\n            float y1 = xyz1[(i * n + j) * 3 + 1];\n            float z1 = xyz1[(i * n + j) * 3 + 2];\n            int j2 = idx1[i * n + j];\n            float x2 = xyz2[(i * m + j2) * 3 + 0];\n            float y2 = xyz2[(i * m + j2) * 3 + 1];\n            float z2 = xyz2[(i * m + j2) * 3 + 2];\n            float g = grad_dist1[i * n + j] * 2;\n            atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2));\n            atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2));\n            atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2));\n            atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2)));\n            atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2)));\n            atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2)));\n        }\n    }\n}\n// int chamfer_cuda_backward(int b,int n,const float * xyz1,int m,const float *\n// xyz2,const float * grad_dist1,const int * idx1,const float * grad_dist2,const\n// int * idx2,float * grad_xyz1,float * grad_xyz2, cudaStream_t stream){\nint chamfer_cuda_backward(at::Tensor xyz1,\n                          at::Tensor xyz2,\n                          at::Tensor gradxyz1,\n                          at::Tensor gradxyz2,\n                          at::Tensor graddist1,\n                          at::Tensor graddist2,\n                          at::Tensor idx1,\n                          at::Tensor idx2) {\n    // cudaMemset(grad_xyz1,0,b*n*3*4);\n    // cudaMemset(grad_xyz2,0,b*m*3*4);\n\n    const auto batch_size = xyz1.size(0);\n    const auto n = xyz1.size(1);  // num_points point cloud A\n    const auto m = xyz2.size(1);  // num_points point cloud B\n\n    NmDistanceGradKernel<<<dim3(1, 16, 1), 256>>>(\n            batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(),\n            graddist1.data<float>(), idx1.data<int>(), gradxyz1.data<float>(),\n            gradxyz2.data<float>());\n    NmDistanceGradKernel<<<dim3(1, 16, 1), 256>>>(\n            batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(),\n            graddist2.data<float>(), idx2.data<int>(), gradxyz2.data<float>(),\n            gradxyz1.data<float>());\n\n    cudaError_t err = cudaGetLastError();\n    if (err != cudaSuccess) {\n        printf(\"error in nnd get grad: %s\\n\", cudaGetErrorString(err));\n        // THError(\"aborting\");\n        return 0;\n    }\n    return 1;\n}\n"
  },
  {
    "path": "extern/chamfer3D/chamfer_cuda.cpp",
    "content": "#include <torch/torch.h>\n\n#include <vector>\n\n/// TMP\n// #include \"common.h\"\n/// NOT TMP\n\nint chamfer_cuda_forward(at::Tensor xyz1,\n                         at::Tensor xyz2,\n                         at::Tensor dist1,\n                         at::Tensor dist2,\n                         at::Tensor idx1,\n                         at::Tensor idx2);\n\nint chamfer_cuda_backward(at::Tensor xyz1,\n                          at::Tensor xyz2,\n                          at::Tensor gradxyz1,\n                          at::Tensor gradxyz2,\n                          at::Tensor graddist1,\n                          at::Tensor graddist2,\n                          at::Tensor idx1,\n                          at::Tensor idx2);\n\nint chamfer_forward(at::Tensor xyz1,\n                    at::Tensor xyz2,\n                    at::Tensor dist1,\n                    at::Tensor dist2,\n                    at::Tensor idx1,\n                    at::Tensor idx2) {\n    return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);\n}\n\nint chamfer_backward(at::Tensor xyz1,\n                     at::Tensor xyz2,\n                     at::Tensor gradxyz1,\n                     at::Tensor gradxyz2,\n                     at::Tensor graddist1,\n                     at::Tensor graddist2,\n                     at::Tensor idx1,\n                     at::Tensor idx2) {\n    return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1,\n                                 graddist2, idx1, idx2);\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"forward\", &chamfer_forward, \"chamfer forward (CUDA)\");\n    m.def(\"backward\", &chamfer_backward, \"chamfer backward (CUDA)\");\n}\n"
  },
  {
    "path": "extern/chamfer3D/dist_chamfer_3D.py",
    "content": "from torch import nn\nfrom torch.autograd import Function\nimport torch\nimport importlib\nimport os\nimport sys\nfrom pathlib import Path\n\nscript_dir = Path(__file__).parent.absolute()\nobject_dir = script_dir.parent / \"tmp\"\nsys.path.append(str(object_dir))\n\nchamfer_found = importlib.find_loader(\"chamfer_3D\") is not None\n\nif not chamfer_found:\n    ## Cool trick from https://github.com/chrdiller\n    cur_path = os.path.dirname(os.path.abspath(__file__))\n    build_path = cur_path.replace(\"chamfer3D\", \"tmp\")\n    os.makedirs(build_path, exist_ok=True)\n    print(f\"Jitting Chamfer 3D to {build_path}\")\n\n    from torch.utils.cpp_extension import load\n\n    chamfer_3D = load(\n        name=\"chamfer_3D\",\n        sources=[\n            \"/\".join(os.path.abspath(__file__).split(\"/\")[:-1] + [\"chamfer_cuda.cpp\"]),\n            \"/\".join(os.path.abspath(__file__).split(\"/\")[:-1] + [\"chamfer3D.cu\"]),\n        ],\n        build_directory=build_path,\n    )\n    print(f\"Loaded jitted library {chamfer_3D.__file__}\")\nelse:\n    import chamfer_3D\n\n    print(f\"Loaded pre-compiled library {chamfer_3D.__file__}\")\n\n\n# Chamfer's distance module @thibaultgroueix\n# GPU tensors only\nclass chamfer_3DFunction(Function):\n    @staticmethod\n    def forward(ctx, xyz1, xyz2):\n        batchsize, n, dim = xyz1.size()\n        assert (\n            dim == 3\n        ), \"Wrong last dimension for the chamfer distance 's input! Check with .size()\"\n        _, m, dim = xyz2.size()\n        assert (\n            dim == 3\n        ), \"Wrong last dimension for the chamfer distance 's input! Check with .size()\"\n        device = xyz1.device\n\n        device = xyz1.device\n\n        dist1 = torch.zeros(batchsize, n)\n        dist2 = torch.zeros(batchsize, m)\n\n        idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)\n        idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)\n\n        dist1 = dist1.to(device)\n        dist2 = dist2.to(device)\n        idx1 = idx1.to(device)\n        idx2 = idx2.to(device)\n        torch.cuda.set_device(device)\n\n        chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)\n        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)\n        return dist1, dist2, idx1, idx2\n\n    @staticmethod\n    def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):\n        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors\n        graddist1 = graddist1.contiguous()\n        graddist2 = graddist2.contiguous()\n        device = graddist1.device\n\n        gradxyz1 = torch.zeros(xyz1.size())\n        gradxyz2 = torch.zeros(xyz2.size())\n\n        gradxyz1 = gradxyz1.to(device)\n        gradxyz2 = gradxyz2.to(device)\n        chamfer_3D.backward(\n            xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2\n        )\n        return gradxyz1, gradxyz2\n\n\nclass chamfer_3DDist(nn.Module):\n    def __init__(self):\n        super(chamfer_3DDist, self).__init__()\n\n    def forward(self, input1, input2):\n        input1 = input1.contiguous()\n        input2 = input2.contiguous()\n        return chamfer_3DFunction.apply(input1, input2)\n"
  },
  {
    "path": "extern/chamfer3D/setup.py",
    "content": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\nsetup(\n    name=\"chamfer_3D\",\n    ext_modules=[\n        CUDAExtension(\n            \"chamfer_3D\",\n            [\n                \"/\".join(__file__.split(\"/\")[:-1] + [\"chamfer_cuda.cpp\"]),\n                \"/\".join(__file__.split(\"/\")[:-1] + [\"chamfer3D.cu\"]),\n            ],\n        ),\n    ],\n    cmdclass={\"build_ext\": BuildExtension},\n)\n"
  },
  {
    "path": "extern/fscore.py",
    "content": "import torch\n\n\ndef fscore(dist1, dist2, threshold=0.001):\n    \"\"\"\n    Calculates the F-score between two point clouds with the corresponding threshold value.\n    :param dist1: Batch, N-Points\n    :param dist2: Batch, N-Points\n    :param th: float\n    :return: fscore, precision, recall\n    \"\"\"\n    # NB : In this depo, dist1 and dist2 are squared pointcloud euclidean\n    # distances, so you should adapt the threshold accordingly.\n    precision_1 = torch.mean((dist1 < threshold).float(), dim=1)\n    precision_2 = torch.mean((dist2 < threshold).float(), dim=1)\n    fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)\n    fscore[torch.isnan(fscore)] = 0\n    return fscore, precision_1, precision_2\n"
  },
  {
    "path": "lidarmvl/readme.md",
    "content": "# LiDAR-MVL\n\n![dataset_vis.png](../assets/dataset_vis.png)\n\n| Sensor                        | Details (Sensor location: F: front. T: top.)                 |\n| ----------------------------- | ------------------------------------------------------------ |\n| LiDAR                 LiDAR-F | Spinning, 64 beams, 10Hz capture frequency, 360° horizontal FOV, 0.6° horizontal resolution, -52.1° to +52.1° vertical FOV, ≤60m range, ±3cm accuracy. |\n| LiDAR-T                       | Spinning, 64 beams, 20Hz capture frequency, 360° horizontal FOV, 0.4° horizontal resolution, -25° to +15° vertical FOV, ≤200m range, ±2cm accuracy. |\n\nWe establish an object-centric **m**ulti-**v**iew **L**iDAR dataset, which we\ndub the **NeRF-MVL** dataset, containing carefully calibrated sensor poses,\nacquired from multi-LiDAR sensor data from real autonomous vehicles. It contains\nmore than **76k frames** covering two types of collecting vehicles, three LiDAR\nsettings, two collecting paths, and nine object categories.\n\n\n\n## Dataset Format\n\n```bash\nnerf_mvl\n├── nerf_mvl_76k\n│   ├── vehicle_type1\n│   │   ├── LiDAR\n│   │   │   └── {class_name}\n│   │   │       ├── l\n│   │   │       └── s\n│   │   │           ├── {frame_id}.npy\n│   │   │           └── lidar2world.txt\n│   │   ├── LiDAR_F\n│   │   └── LiDAR_T\n│   └── vehicle_type2\n│       ├── LiDAR\n│       ├── LiDAR_F\n│       └── LiDAR_T\n│\n└── nerf_mvl_7k\n    └── {class_name}\n        ├── {frame_id}.npy\n        └── lidar2world.txt\n\nNote:\n{class_name}: {bollard, pedestrian, plant, traffic_cone, water_safety_barrier, car, pier, tire, warning_sign}\n{frame_id}.npy: local point clouds, (N,4)\nlidar2world.txt: the lidar to world matrix, (M, 16)\nl/s: large/small collecting paths\n```\n\nFor fast validation, we extract a  pocket version of the dataset with only 7.3k\nframes covering the nine categories, called **nerf_mvl_7k**.\n\nFor all point clound frames, we  crop out the region of interest, i.e., the\nobject. The raw data will also be released to the community soon.\n\n\n\n## Citation\nIf you find our dataset helps, please consider citing:\n\n```\n@article{tao2023lidar,\n  title={LiDAR-NeRF: Novel LiDAR View Synthesis via Neural Radiance Fields},\n  author={Tao, Tang and Gao, Longfei and Wang, Guangrun and Lao, Yixing and Chen, Peng and Zhao hengshuang and Hao, Dayang and Liang, Xiaodan and Salzmann, Mathieu and Yu, Kaicheng},\n  journal={arXiv preprint arXiv:2304.10406},\n  year={2023}\n}\n```\n"
  },
  {
    "path": "lidarnerf/__init__.py",
    "content": "__version__ = \"0.1.0\"\n"
  },
  {
    "path": "lidarnerf/activation.py",
    "content": "import torch\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n\nclass _trunc_exp(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)  # cast to float32\n    def forward(ctx, x):\n        ctx.save_for_backward(x)\n        return torch.exp(x)\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, g):\n        x = ctx.saved_tensors[0]\n        return g * torch.exp(x.clamp(-15, 15))\n\n\ntrunc_exp = _trunc_exp.apply\n"
  },
  {
    "path": "lidarnerf/convert.py",
    "content": "import numpy as np\n\n\ndef lidar_to_pano_with_intensities_with_bbox_mask(\n    local_points_with_intensities: np.ndarray,\n    lidar_H: int,\n    lidar_W: int,\n    lidar_K: int,\n    bbox_local: np.ndarray,\n    max_depth=80,\n    max_intensity=255.0,\n):\n    \"\"\"\n    Convert lidar frame to pano frame with intensities with bbox_mask.\n    Lidar points are in local coordinates.\n\n    Args:\n        local_points: (N, 4), float32, in lidar frame, with intensities.\n        lidar_H: pano height.\n        lidar_W: pano width.\n        lidar_K: lidar intrinsics.\n        bbox_local: (8x4), world bbox in local.\n        max_depth: max depth in meters.\n        max_intensity: max intensity.\n\n    Return:\n        pano: (H, W), float32.\n        intensities: (H, W), float32.\n    \"\"\"\n\n    # Un pack.\n    local_points = local_points_with_intensities[:, :3]\n    local_point_intensities = local_points_with_intensities[:, 3]\n    fov_up, fov = lidar_K\n    fov_down = fov - fov_up\n\n    # Compute dists to lidar center.\n    dists = np.linalg.norm(local_points, axis=1)\n\n    # Fill pano and intensities.\n    pano = np.zeros((lidar_H, lidar_W))\n    intensities = np.zeros((lidar_H, lidar_W))\n\n    # bbox mask\n    pano[:, :] = -1\n    r_min, r_max, c_min, c_max = 1e5, -1, 1e5, -1\n    for bbox_local_point in bbox_local:\n        x, y, z, _ = bbox_local_point\n        beta = np.pi - np.arctan2(y, x)\n        alpha = np.arctan2(z, np.sqrt(x**2 + y**2)) + fov_down / 180 * np.pi\n\n        c = int(round(beta / (2 * np.pi / lidar_W)))\n        r = int(round(lidar_H - alpha / (fov / 180 * np.pi / lidar_H)))\n\n        # Check out-of-bounds.\n        if r >= lidar_H or r < 0 or c >= lidar_W or c < 0:\n            continue\n        else:\n            r_min, r_max, c_min, c_max = (\n                min(r_min, r),\n                max(r_max, r),\n                min(c_min, c),\n                max(c_max, c),\n            )\n\n    pano[r_min:r_max, c_min:c_max] = 0\n\n    # Fill pano and intensities.\n    for local_points, dist, local_point_intensity in zip(\n        local_points,\n        dists,\n        local_point_intensities,\n    ):\n        # Check max depth.\n        if dist >= max_depth:\n            continue\n\n        x, y, z = local_points\n        beta = np.pi - np.arctan2(y, x)\n        alpha = np.arctan2(z, np.sqrt(x**2 + y**2)) + fov_down / 180 * np.pi\n        c = int(round(beta / (2 * np.pi / lidar_W)))\n        r = int(round(lidar_H - alpha / (fov / 180 * np.pi / lidar_H)))\n\n        # Check out-of-bounds.\n        if r >= lidar_H or r < 0 or c >= lidar_W or c < 0:\n            continue\n\n        # Set to min dist if not set.\n        if pano[r, c] == 0.0:\n            pano[r, c] = dist\n            intensities[r, c] = local_point_intensity / max_intensity\n        elif pano[r, c] > dist:\n            pano[r, c] = dist\n            intensities[r, c] = local_point_intensity / max_intensity\n\n    return pano, intensities\n\n\ndef lidar_to_pano_with_intensities(\n    local_points_with_intensities: np.ndarray,\n    lidar_H: int,\n    lidar_W: int,\n    lidar_K: int,\n    max_depth=80,\n):\n    \"\"\"\n    Convert lidar frame to pano frame with intensities.\n    Lidar points are in local coordinates.\n\n    Args:\n        local_points: (N, 4), float32, in lidar frame, with intensities.\n        lidar_H: pano height.\n        lidar_W: pano width.\n        lidar_K: lidar intrinsics.\n        max_depth: max depth in meters.\n\n    Return:\n        pano: (H, W), float32.\n        intensities: (H, W), float32.\n    \"\"\"\n    # Un pack.\n    local_points = local_points_with_intensities[:, :3]\n    local_point_intensities = local_points_with_intensities[:, 3]\n    fov_up, fov = lidar_K\n    fov_down = fov - fov_up\n\n    # Compute dists to lidar center.\n    dists = np.linalg.norm(local_points, axis=1)\n\n    # Fill pano and intensities.\n    pano = np.zeros((lidar_H, lidar_W))\n    intensities = np.zeros((lidar_H, lidar_W))\n    for local_points, dist, local_point_intensity in zip(\n        local_points,\n        dists,\n        local_point_intensities,\n    ):\n        # Check max depth.\n        if dist >= max_depth:\n            continue\n\n        x, y, z = local_points\n        beta = np.pi - np.arctan2(y, x)\n        alpha = np.arctan2(z, np.sqrt(x**2 + y**2)) + fov_down / 180 * np.pi\n        c = int(round(beta / (2 * np.pi / lidar_W)))\n        r = int(round(lidar_H - alpha / (fov / 180 * np.pi / lidar_H)))\n\n        # Check out-of-bounds.\n        if r >= lidar_H or r < 0 or c >= lidar_W or c < 0:\n            continue\n\n        # Set to min dist if not set.\n        if pano[r, c] == 0.0:\n            pano[r, c] = dist\n            intensities[r, c] = local_point_intensity\n        elif pano[r, c] > dist:\n            pano[r, c] = dist\n            intensities[r, c] = local_point_intensity\n\n    return pano, intensities\n\n\ndef lidar_to_pano(\n    local_points: np.ndarray, lidar_H: int, lidar_W: int, lidar_K: int, max_dpeth=80\n):\n    \"\"\"\n    Convert lidar frame to pano frame. Lidar points are in local coordinates.\n\n    Args:\n        local_points: (N, 3), float32, in lidar frame.\n        lidar_H: pano height.\n        lidar_W: pano width.\n        lidar_K: lidar intrinsics.\n        max_depth: max depth in meters.\n\n    Return:\n        pano: (H, W), float32.\n    \"\"\"\n\n    # (N, 3) -> (N, 4), filled with zeros.\n    local_points_with_intensities = np.concatenate(\n        [local_points, np.zeros((local_points.shape[0], 1))], axis=1\n    )\n    pano, _ = lidar_to_pano_with_intensities(\n        local_points_with_intensities=local_points_with_intensities,\n        lidar_H=lidar_H,\n        lidar_W=lidar_W,\n        lidar_K=lidar_K,\n        max_dpeth=max_dpeth,\n    )\n    return pano\n\n\ndef pano_to_lidar_with_intensities(pano: np.ndarray, intensities, lidar_K):\n    \"\"\"\n    Args:\n        pano: (H, W), float32.\n        intensities: (H, W), float32.\n        lidar_K: lidar intrinsics (fov_up, fov)\n\n    Return:\n        local_points_with_intensities: (N, 4), float32, in lidar frame.\n    \"\"\"\n    fov_up, fov = lidar_K\n\n    H, W = pano.shape\n    i, j = np.meshgrid(\n        np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing=\"xy\"\n    )\n    beta = -(i - W / 2) / W * 2 * np.pi\n    alpha = (fov_up - j / H * fov) / 180 * np.pi\n    dirs = np.stack(\n        [\n            np.cos(alpha) * np.cos(beta),\n            np.cos(alpha) * np.sin(beta),\n            np.sin(alpha),\n        ],\n        -1,\n    )\n    local_points = dirs * pano.reshape(H, W, 1)\n\n    # local_points: (H, W, 3)\n    # intensities : (H, W)\n    # local_points_with_intensities: (H, W, 4)\n    local_points_with_intensities = np.concatenate(\n        [local_points, intensities.reshape(H, W, 1)], axis=2\n    )\n\n    # Filter empty points.\n    idx = np.where(pano != 0.0)\n    local_points_with_intensities = local_points_with_intensities[idx]\n\n    return local_points_with_intensities\n\n\ndef pano_to_lidar(pano, lidar_K):\n    \"\"\"\n    Args:\n        pano: (H, W), float32.\n        lidar_K: lidar intrinsics (fov_up, fov)\n\n    Return:\n        local_points: (N, 3), float32, in lidar frame.\n    \"\"\"\n    local_points_with_intensities = pano_to_lidar_with_intensities(\n        pano=pano,\n        intensities=np.zeros_like(pano),\n        lidar_K=lidar_K,\n    )\n    return local_points_with_intensities[:, :3]\n\n\ndef lidar_to_pano_with_intensities_fpa(\n    local_points_with_intensities: np.ndarray,\n    lidar_H: int,\n    lidar_W: int,\n    lidar_K: int,\n    max_depth=80,\n    z_buffer_len=10,\n):\n    \"\"\"\n    Convert lidar frame to pano frame with intensities with bbox_mask.\n    Lidar points are in local coordinates.\n\n    Args:\n        local_points: (N, 4), float32, in lidar frame, with intensities.\n        lidar_H: pano height.\n        lidar_W: pano width.\n        lidar_K: lidar intrinsics.\n        max_depth: max depth in meters.\n        z_buffer_len: length of the z_buffer.\n\n    Return:\n        rangeview image: (H, W, 3), float32.\n    \"\"\"\n\n    # Un pack.\n    local_points = local_points_with_intensities[:, :3]\n    local_point_intensities = local_points_with_intensities[:, 3]\n    fov_up, fov = lidar_K\n    fov_down = fov - fov_up\n\n    # Compute dists to lidar center.\n    dists = np.linalg.norm(local_points, axis=1)\n\n    # Fill pano and intensities.\n    range_view = np.zeros((lidar_H, lidar_W, 3, z_buffer_len + 1))\n\n    for local_point, dist, local_point_intensity in zip(\n        local_points,\n        dists,\n        local_point_intensities,\n    ):\n        # Check max depth.\n        if dist >= max_depth:\n            continue\n\n        x, y, z = local_point\n        beta = np.pi - np.arctan2(y, x)\n        alpha = np.arctan2(z, np.sqrt(x**2 + y**2)) + fov_down / 180 * np.pi\n        c = int(round(beta / (2 * np.pi / lidar_W)))\n        r = int(round(lidar_H - alpha / (fov / 180 * np.pi / lidar_H)))\n\n        if r >= lidar_H or r < 0 or c >= lidar_W or c < 0:\n            continue\n\n        position = range_view[r, c, 2, 0] + 1\n        if position > z_buffer_len:\n            depth_z_buffer = list(range_view[r, c, 2][1:]) + [dist]\n            intensity_z_buffer = list(range_view[r, c, 1][1:]) + [local_point_intensity]\n            position = position - 1\n\n            sort_index = np.argsort(depth_z_buffer)\n            depth_z_buffer = np.insert(\n                np.array(depth_z_buffer)[sort_index][:z_buffer_len], 0, position\n            )\n            intensity_z_buffer = np.insert(\n                np.array(intensity_z_buffer)[sort_index][:z_buffer_len], 0, position\n            )\n            range_view[r, c, 2] = depth_z_buffer\n            range_view[r, c, 1] = intensity_z_buffer\n\n        else:\n            range_view[r, c, 2, int(position)] = dist\n            range_view[r, c, 1, int(position)] = local_point_intensity\n        range_view[r, c, 2, 0] = position\n    range_view = parse_z_buffer(range_view, lidar_H, lidar_W)\n    return range_view[:, :, 2], range_view[:, :, 1]\n\n\ndef parse_z_buffer(novel_pano, lidar_H, lidar_W, threshold=0.2):\n    range_view = np.zeros((lidar_H, lidar_W, 3))\n    for i in range(lidar_H):\n        for j in range(lidar_W):\n            range_pixel = novel_pano[i, j, 2]\n            intensity_pixel = novel_pano[i, j, 1]\n            z_buffer_num = int(range_pixel[0])\n            if z_buffer_num == 0:\n                continue\n            if z_buffer_num == 1:\n                range_view[i][j][2] = range_pixel[1]\n                range_view[i][j][1] = intensity_pixel[1]\n                continue\n\n            depth_z_buffer = range_pixel[1:z_buffer_num]\n            cloest_points = min(depth_z_buffer)\n            index = depth_z_buffer <= (cloest_points + threshold)\n\n            final_depth_z_buffer = np.array(depth_z_buffer)[index]\n            final_dis = np.average(\n                final_depth_z_buffer, weights=1 / final_depth_z_buffer\n            )\n            range_view[i][j][2] = final_dis\n\n            intensity_z_buffer = intensity_pixel[1:z_buffer_num]\n            final_intensity_z_buffer = np.array(intensity_z_buffer)[index]\n            final_intensity = np.average(\n                final_intensity_z_buffer, weights=1 / final_depth_z_buffer\n            )\n            range_view[i][j][1] = final_intensity\n    return range_view\n"
  },
  {
    "path": "lidarnerf/dataset/base_dataset.py",
    "content": "import numpy as np\nimport torch\nimport trimesh\nfrom packaging import version as pver\nfrom dataclasses import dataclass\n\n\ndef custom_meshgrid(*args):\n    if pver.parse(torch.__version__) < pver.parse(\"1.10\"):\n        return torch.meshgrid(*args)\n    else:\n        return torch.meshgrid(*args, indexing=\"ij\")\n\n\n@torch.cuda.amp.autocast(enabled=False)\ndef get_lidar_rays(poses, intrinsics, H, W, N=-1, patch_size=1):\n    \"\"\"\n    Get lidar rays.\n\n    Args:\n        poses: [B, 4, 4], cam2world\n        intrinsics: [2]\n        H, W, N: int\n    Returns:\n        rays_o, rays_d: [B, N, 3]\n        inds: [B, N]\n    \"\"\"\n    device = poses.device\n    B = poses.shape[0]\n\n    i, j = custom_meshgrid(\n        torch.linspace(0, W - 1, W, device=device),\n        torch.linspace(0, H - 1, H, device=device),\n    )  # float\n    # i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5\n    # j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5\n    i = i.t().reshape([1, H * W]).expand([B, H * W])\n    j = j.t().reshape([1, H * W]).expand([B, H * W])\n    results = {}\n    if N > 0:\n        N = min(N, H * W)\n\n        if isinstance(patch_size, int):\n            patch_size_x, patch_size_y = patch_size, patch_size\n        elif len(patch_size) == 1:\n            patch_size_x, patch_size_y = patch_size[0], patch_size[0]\n        else:\n            patch_size_x, patch_size_y = patch_size\n\n        if patch_size_x > 0:\n            # random sample left-top cores.\n            # NOTE: this impl will lead to less sampling on the image corner\n            # pixels... but I don't have other ideas.\n            num_patch = N // (patch_size_x * patch_size_y)\n            inds_x = torch.randint(0, H - patch_size_x, size=[num_patch], device=device)\n            inds_y = torch.randint(0, W - patch_size_y, size=[num_patch], device=device)\n            inds = torch.stack([inds_x, inds_y], dim=-1)  # [np, 2]\n\n            # create meshgrid for each patch\n            pi, pj = custom_meshgrid(\n                torch.arange(patch_size_x, device=device),\n                torch.arange(patch_size_y, device=device),\n            )\n            offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1)  # [p^2, 2]\n\n            inds = inds.unsqueeze(1) + offsets.unsqueeze(0)  # [np, p^2, 2]\n            inds = inds.view(-1, 2)  # [N, 2]\n            inds = inds[:, 0] * W + inds[:, 1]  # [N], flatten\n\n            inds = inds.expand([B, N])\n\n        else:\n            inds = torch.randint(0, H * W, size=[N], device=device)  # may duplicate\n            inds = inds.expand([B, N])\n\n        i = torch.gather(i, -1, inds)\n        j = torch.gather(j, -1, inds)\n\n        results[\"inds\"] = inds\n\n    else:\n        inds = torch.arange(H * W, device=device).expand([B, H * W])\n        results[\"inds\"] = inds\n\n    fov_up, fov = intrinsics\n    beta = -(i - W / 2) / W * 2 * np.pi\n    alpha = (fov_up - j / H * fov) / 180 * np.pi\n\n    directions = torch.stack(\n        [\n            torch.cos(alpha) * torch.cos(beta),\n            torch.cos(alpha) * torch.sin(beta),\n            torch.sin(alpha),\n        ],\n        -1,\n    )\n    # directions = directions / torch.norm(directions, dim=-1, keepdim=True)\n    rays_d = directions @ poses[:, :3, :3].transpose(-1, -2)  # (B, N, 3)\n    rays_o = poses[..., :3, 3]  # [B, 3]\n    rays_o = rays_o[..., None, :].expand_as(rays_d)  # [B, N, 3]\n\n    results[\"rays_o\"] = rays_o\n    results[\"rays_d\"] = rays_d\n\n    return results\n\n\n@torch.cuda.amp.autocast(enabled=False)\ndef get_rays(poses, intrinsics, H, W, N=-1, patch_size=1):\n    \"\"\"get rays\n    Args:\n        poses: [B, 4, 4], cam2world\n        intrinsics: [4]\n        H, W, N: int\n    Returns:\n        rays_o, rays_d: [B, N, 3]\n        inds: [B, N]\n    \"\"\"\n\n    device = poses.device\n    B = poses.shape[0]\n    fx, fy, cx, cy = intrinsics\n\n    i, j = custom_meshgrid(\n        torch.linspace(0, W - 1, W, device=device),\n        torch.linspace(0, H - 1, H, device=device),\n    )  # float\n    i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5\n    j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5\n\n    results = {}\n    if N > 0:\n        N = min(N, H * W)\n\n        if patch_size > 1:\n            # random sample left-top cores.\n            # NOTE: this impl will lead to less sampling on the image corner\n            # pixels... but I don't have other ideas.\n            num_patch = N // (patch_size**2)\n            inds_x = torch.randint(0, H - patch_size, size=[num_patch], device=device)\n            inds_y = torch.randint(0, W - patch_size, size=[num_patch], device=device)\n            inds = torch.stack([inds_x, inds_y], dim=-1)  # [np, 2]\n\n            # create meshgrid for each patch\n            pi, pj = custom_meshgrid(\n                torch.arange(patch_size, device=device),\n                torch.arange(patch_size, device=device),\n            )\n            offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1)  # [p^2, 2]\n\n            inds = inds.unsqueeze(1) + offsets.unsqueeze(0)  # [np, p^2, 2]\n            inds = inds.view(-1, 2)  # [N, 2]\n            inds = inds[:, 0] * W + inds[:, 1]  # [N], flatten\n\n            inds = inds.expand([B, N])\n\n        else:\n            inds = torch.randint(0, H * W, size=[N], device=device)  # may duplicate\n            inds = inds.expand([B, N])\n\n        i = torch.gather(i, -1, inds)\n        j = torch.gather(j, -1, inds)\n\n        results[\"inds\"] = inds\n\n    else:\n        inds = torch.arange(H * W, device=device).expand([B, H * W])\n\n    zs = torch.ones_like(i)\n    xs = (i - cx) / fx * zs\n    ys = (j - cy) / fy * zs\n    directions = torch.stack((xs, ys, zs), dim=-1)\n    directions = directions / torch.norm(directions, dim=-1, keepdim=True)\n    rays_d = directions @ poses[:, :3, :3].transpose(-1, -2)  # (B, N, 3)\n\n    rays_o = poses[..., :3, 3]  # [B, 3]\n    rays_o = rays_o[..., None, :].expand_as(rays_d)  # [B, N, 3]\n\n    results[\"rays_o\"] = rays_o\n    results[\"rays_d\"] = rays_d\n\n    return results\n\n\n# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50\ndef nerf_matrix_to_ngp(pose, scale=0.33, offset=[0, 0, 0]):\n    # for the fox dataset, 0.33 scales camera radius to ~ 2\n    new_pose = np.array(\n        [\n            [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]],\n            [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]],\n            [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]],\n            [0, 0, 0, 1],\n        ],\n        dtype=np.float32,\n    )\n    return new_pose\n\n\ndef visualize_poses(poses, size=0.1):\n    # poses: [B, 4, 4]\n\n    axes = trimesh.creation.axis(axis_length=4)\n    box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()\n    box.colors = np.array([[128, 128, 128]] * len(box.entities))\n    objects = [axes, box]\n\n    for pose in poses:\n        # a camera is visualized with 8 line segments.\n        pos = pose[:3, 3]\n        a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]\n        b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]\n        c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]\n        d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]\n\n        dir = (a + b + c + d) / 4 - pos\n        dir = dir / (np.linalg.norm(dir) + 1e-8)\n        o = pos + dir * 3\n\n        segs = np.array(\n            [\n                [pos, a],\n                [pos, b],\n                [pos, c],\n                [pos, d],\n                [a, b],\n                [b, c],\n                [c, d],\n                [d, a],\n                [pos, o],\n            ]\n        )\n        segs = trimesh.load_path(segs)\n        objects.append(segs)\n\n    trimesh.Scene(objects).show()\n\n\n@dataclass\nclass BaseDataset:\n    pass\n"
  },
  {
    "path": "lidarnerf/dataset/kitti360_dataset.py",
    "content": "import json\nimport os\n\nimport numpy as np\nimport torch\nimport tqdm\nfrom torch.utils.data import DataLoader\nfrom dataclasses import dataclass, field\n\nfrom lidarnerf.dataset.base_dataset import get_lidar_rays, BaseDataset\n\n\n@dataclass\nclass KITTI360Dataset(BaseDataset):\n    device: str = \"cpu\"\n    split: str = \"train\"  # train, val, test\n    root_path: str = \"data/kitti360\"\n    sequence_id: str = \"1908\"\n    preload: bool = True  # preload data into GPU\n    scale: float = (\n        1  # camera radius scale to make sure camera are inside the bounding box.\n    )\n    offset: list = field(default_factory=list)  # offset\n    # bound = opt.bound  # bounding box half length, also used as the radius to random sample poses.\n    fp16: bool = True  # if preload, load into fp16.\n    patch_size: int = 1  # size of the image to extract from the scene.\n    patch_size_lidar: int = 1  # size of the image to extract from the Lidar.\n    enable_lidar: bool = True\n    num_rays: int = 4096\n    num_rays_lidar: int = 4096\n\n    def __post_init__(self):\n        if self.sequence_id == \"1538\":\n            print(\"Using sqequence 1538-1601\")\n        elif self.sequence_id == \"1728\":\n            print(\"Using sqequence 1728-1791\")\n        elif self.sequence_id == \"1908\":\n            print(\"Using sqequence 1908-1971\")\n        elif self.sequence_id == \"3353\":\n            print(\"Using sqequence 3353-3416\")\n        else:\n            raise ValueError(f\"Invalid sequence id: {sequence_id}\")\n\n        self.training = self.split in [\"train\", \"all\", \"trainval\"]\n        self.num_rays = self.num_rays if self.training else -1\n        self.num_rays_lidar = self.num_rays_lidar if self.training else -1\n        # load nerf-compatible format data.\n        with open(\n            os.path.join(\n                self.root_path, f\"transforms_{self.sequence_id}_{self.split}.json\"\n            ),\n            \"r\",\n        ) as f:\n            transform = json.load(f)\n\n        # load image size\n        if \"h\" in transform and \"w\" in transform:\n            self.H = int(transform[\"h\"])\n            self.W = int(transform[\"w\"])\n        else:\n            # we have to actually read an image to get H and W later.\n            self.H = self.W = None\n\n        if \"h_lidar\" in transform and \"w_lidar\" in transform:\n            self.H_lidar = int(transform[\"h_lidar\"])\n            self.W_lidar = int(transform[\"w_lidar\"])\n\n        # read images\n        frames = transform[\"frames\"]\n        # frames = sorted(frames, key=lambda d: d['file_path']) # why do I sort...\n\n        self.poses_lidar = []\n        self.images_lidar = []\n        for f in tqdm.tqdm(frames, desc=f\"Loading {self.split} data\"):\n            pose_lidar = np.array(f[\"lidar2world\"], dtype=np.float32)  # [4, 4]\n\n            f_lidar_path = os.path.join(self.root_path, f[\"lidar_file_path\"])\n\n            # channel1 None, channel2 intensity , channel3 depth\n            pc = np.load(f_lidar_path)\n            ray_drop = np.where(pc.reshape(-1, 3)[:, 2] == 0.0, 0.0, 1.0).reshape(\n                self.H_lidar, self.W_lidar, 1\n            )\n\n            image_lidar = np.concatenate(\n                [ray_drop, pc[:, :, 1, None], pc[:, :, 2, None] * self.scale],\n                axis=-1,\n            )\n\n            self.poses_lidar.append(pose_lidar)\n            self.images_lidar.append(image_lidar)\n\n        self.poses_lidar = np.stack(self.poses_lidar, axis=0)\n        self.poses_lidar[:, :3, -1] = (\n            self.poses_lidar[:, :3, -1] - self.offset\n        ) * self.scale\n        self.poses_lidar = torch.from_numpy(self.poses_lidar)  # [N, 4, 4]\n\n        if self.images_lidar is not None:\n            self.images_lidar = torch.from_numpy(\n                np.stack(self.images_lidar, axis=0)\n            ).float()  # [N, H, W, C]\n\n        # calculate mean radius of all camera poses\n        # self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item()\n        # print(f'[INFO] dataset camera poses: radius = {self.radius:.4f}, bound = {self.bound}')\n\n        # [debug] uncomment to view all training poses.\n        # visualize_poses(self.poses.numpy())\n\n        if self.preload:\n            self.poses_lidar = self.poses_lidar.to(self.device)\n            if self.images_lidar is not None:\n                # TODO: linear use pow, but pow for half is only available for torch >= 1.10 ?\n                if self.fp16:\n                    dtype = torch.half\n                else:\n                    dtype = torch.float\n                self.images_lidar = self.images_lidar.to(dtype).to(self.device)\n\n        self.intrinsics_lidar = (2.0, 26.9)  # fov_up, fov\n\n    def collate(self, index):\n        B = len(index)  # a list of length 1\n\n        results = {}\n\n        if self.enable_lidar:\n            poses_lidar = self.poses_lidar[index].to(self.device)  # [B, 4, 4]\n            rays_lidar = get_lidar_rays(\n                poses_lidar,\n                self.intrinsics_lidar,\n                self.H_lidar,\n                self.W_lidar,\n                self.num_rays_lidar,\n                self.patch_size_lidar,\n            )\n\n            results.update(\n                {\n                    \"H_lidar\": self.H_lidar,\n                    \"W_lidar\": self.W_lidar,\n                    \"rays_o_lidar\": rays_lidar[\"rays_o\"],\n                    \"rays_d_lidar\": rays_lidar[\"rays_d\"],\n                }\n            )\n\n        if self.images_lidar is not None and self.enable_lidar:\n            images_lidar = self.images_lidar[index].to(self.device)  # [B, H, W, 3/4]\n            if self.training:\n                C = images_lidar.shape[-1]\n                images_lidar = torch.gather(\n                    images_lidar.view(B, -1, C),\n                    1,\n                    torch.stack(C * [rays_lidar[\"inds\"]], -1),\n                )  # [B, N, 3/4]\n            results[\"images_lidar\"] = images_lidar\n\n        return results\n\n    def dataloader(self):\n        size = len(self.poses_lidar)\n        loader = DataLoader(\n            list(range(size)),\n            batch_size=1,\n            collate_fn=self.collate,\n            shuffle=self.training,\n            num_workers=0,\n        )\n        loader._data = self\n        loader.has_gt = self.images_lidar is not None\n        return loader\n\n    def __len__(self):\n        \"\"\"\n        Returns # of frames in this dataset.\n        \"\"\"\n        num_frames = len(self.poses_lidar)\n        return num_frames\n"
  },
  {
    "path": "lidarnerf/dataset/nerfmvl_dataset.py",
    "content": "import os\nimport json\nimport tqdm\nimport numpy as np\n\nimport torch\nfrom torch.utils.data import DataLoader\nfrom dataclasses import dataclass, field\n\nfrom lidarnerf.dataset.base_dataset import get_lidar_rays, BaseDataset\n\n\n@dataclass\nclass NeRFMVLDataset(BaseDataset):\n    device: str = \"cpu\"\n    split: str = \"train\"  # train, val, test\n    root_path: str = \"data/kitti360\"\n    sequence_id: str = \"car\"\n    preload: bool = True  # preload data into GPU\n    scale: float = (\n        1  # camera radius scale to make sure camera are inside the bounding box.\n    )\n    offset: list = field(default_factory=list)  # offset\n    # bound = opt.bound  # bounding box half length, also used as the radius to random sample poses.\n    fp16: bool = True  # if preload, load into fp16.\n    patch_size: int = 1  # size of the image to extract from the scene.\n    patch_size_lidar: int = 1  # size of the image to extract from the Lidar.\n    enable_lidar: bool = True\n    num_rays: int = 4096\n    num_rays_lidar: int = 4096\n\n    def __post_init__(self):\n        self.class_name = self.sequence_id\n        self.training = self.split in [\"train\", \"all\", \"trainval\"]\n        self.testing = self.split in [\"test\"]\n        self.num_rays = self.num_rays if self.training else -1\n        self.num_rays_lidar = self.num_rays_lidar if self.training else -1\n\n        with open(\n            os.path.join(\n                self.root_path, f\"transforms_{self.class_name}_{self.split}.json\"\n            ),\n            \"r\",\n        ) as f:\n            transform = json.load(f)\n\n        if \"h_lidar\" in transform and \"w_lidar\" in transform:\n            self.H_lidar = int(transform[\"h_lidar\"])\n            self.W_lidar = int(transform[\"w_lidar\"])\n\n        # read images\n        frames = transform[\"frames\"]\n        # frames = sorted(frames, key=lambda d: d['file_path']) # why do I sort...\n\n        self.poses_lidar = []\n        self.images_lidar = []\n        for f in tqdm.tqdm(frames, desc=f\"Loading {self.split} data\"):\n            pose_lidar = np.array(f[\"lidar2world\"], dtype=np.float32)  # [4, 4]\n            self.poses_lidar.append(pose_lidar)\n            if \"lidar_file_path\" in f.keys():\n                f_lidar_path = os.path.join(self.root_path, f[\"lidar_file_path\"])\n                # channel1 None, channel2 intensity , channel3 depth\n                pc = np.load(f_lidar_path)[\"data\"]\n\n                # ray_drop = np.where(pc.reshape(-1, 3)[:, 2] == 0.0, 0.0,\n                #                     1.0).reshape(self.H_lidar, self.W_lidar, 1)\n                ray_drop = pc.reshape(-1, 3)[:, 2].copy()\n                ray_drop[ray_drop > 0] = 1.0\n                ray_drop = ray_drop.reshape(self.H_lidar, self.W_lidar, 1)\n                image_lidar = np.concatenate(\n                    [ray_drop, pc[:, :, 1, None], pc[:, :, 2, None] * self.scale],\n                    axis=-1,\n                )\n\n                self.images_lidar.append(image_lidar)\n            else:\n                self.images_lidar = None\n\n        dataset_bbox = np.load(\n            os.path.join(self.root_path, \"dataset_bbox_7k.npy\"), allow_pickle=True\n        ).item()\n        self.OBB = dataset_bbox[self.class_name]\n\n        self.offset = np.mean(self.OBB, axis=0)\n\n        self.poses_lidar = np.stack(self.poses_lidar, axis=0)\n        self.poses_lidar_wo_scale_offset = self.poses_lidar.copy()\n        self.OBB_pad = np.concatenate([self.OBB, np.ones((8, 1))], axis=1)\n        self.OBB_local = [\n            self.OBB_pad @ np.linalg.inv(pose_lidar_wo_scale_offset.reshape(4, 4)).T\n            for pose_lidar_wo_scale_offset in self.poses_lidar_wo_scale_offset\n        ]\n        self.OBB_local = np.stack(self.OBB_local, axis=0)\n        self.poses_lidar[:, :3, -1] = (\n            self.poses_lidar[:, :3, -1] - self.offset\n        ) * self.scale\n        self.poses_lidar = torch.from_numpy(self.poses_lidar)  # [N, 4, 4]\n\n        if self.images_lidar is not None:\n            self.images_lidar = torch.from_numpy(\n                np.stack(self.images_lidar, axis=0)\n            ).float()  # [N, H, W, C]\n\n        if self.preload:\n            self.poses_lidar = self.poses_lidar.to(self.device)\n            if self.images_lidar is not None:\n                # TODO: linear use pow, but pow for half is only available for torch >= 1.10 ?\n                if self.fp16:\n                    dtype = torch.half\n                else:\n                    dtype = torch.float\n                self.images_lidar = self.images_lidar.to(dtype).to(self.device)\n\n        self.intrinsics_lidar = (15, 40)  # fov_up, fov\n\n    def collate(self, index):\n        B = len(index)  # a list of length 1\n\n        results = {}\n        if self.enable_lidar:\n            poses_lidar = self.poses_lidar[index].to(self.device)  # [B, 4, 4]\n            rays_lidar = get_lidar_rays(\n                poses_lidar,\n                self.intrinsics_lidar,\n                self.H_lidar,\n                self.W_lidar,\n                -1,\n                self.patch_size_lidar,\n            )\n\n            results.update(\n                {\n                    \"H_lidar\": self.H_lidar,\n                    \"W_lidar\": self.W_lidar,\n                    \"rays_o_lidar\": rays_lidar[\"rays_o\"],\n                    \"rays_d_lidar\": rays_lidar[\"rays_d\"],\n                }\n            )\n\n        if self.testing:\n            results[\"OBB_local\"] = self.OBB_local[index].reshape(8, 4)\n\n        if self.images_lidar is not None and self.enable_lidar:\n            images_lidar = self.images_lidar[index].to(self.device)  # [B, H, W, 3/4]\n            if self.training:\n                C = images_lidar.shape[-1]\n                images_lidar = torch.gather(\n                    images_lidar.view(B, -1, C),\n                    1,\n                    torch.stack(C * [rays_lidar[\"inds\"]], -1),\n                )  # [B, N, 3/4]\n                mask = images_lidar[:, :, 0] > -1\n                results[\"images_lidar\"] = images_lidar[mask].view(B, -1, C)\n                results[\"rays_o_lidar\"] = results[\"rays_o_lidar\"][mask].view(B, -1, 3)\n                results[\"rays_d_lidar\"] = results[\"rays_d_lidar\"][mask].view(B, -1, 3)\n                valid_num_rays = results[\"rays_o_lidar\"].shape[1]\n                if valid_num_rays > self.num_rays_lidar:\n                    # mask_inds = torch.randint(0, valid_num_rays, size=[self.num_rays_lidar], device=self.device)\n                    mask_inds = torch.randperm(valid_num_rays)[: self.num_rays_lidar]\n                    results[\"images_lidar\"] = results[\"images_lidar\"][\n                        :, mask_inds, :\n                    ].view(B, -1, C)\n                    results[\"rays_o_lidar\"] = results[\"rays_o_lidar\"][\n                        :, mask_inds, :\n                    ].view(B, -1, 3)\n                    results[\"rays_d_lidar\"] = results[\"rays_d_lidar\"][\n                        :, mask_inds, :\n                    ].view(B, -1, 3)\n            else:\n                results[\"images_lidar\"] = images_lidar\n\n        return results\n\n    def dataloader(self):\n        size = len(self.poses_lidar)\n        loader = DataLoader(\n            list(range(size)),\n            batch_size=1,\n            collate_fn=self.collate,\n            shuffle=self.training,\n            num_workers=0,\n        )\n        loader._data = self\n        loader.has_gt = self.images_lidar is not None\n        return loader\n\n    def __len__(self):\n        \"\"\"\n        Returns # of frames in this dataset.\n        \"\"\"\n        num_frames = len(self.poses_lidar)\n        return num_frames\n"
  },
  {
    "path": "lidarnerf/encoding.py",
    "content": "import torch\nimport torch.nn as nn\nimport numpy as np\n\n\nclass FreqEncoder(nn.Module):\n    def __init__(\n        self,\n        input_dim,\n        max_freq_log2,\n        N_freqs,\n        log_sampling=True,\n        include_input=True,\n        periodic_fns=(torch.sin, torch.cos),\n    ):\n        super().__init__()\n\n        self.input_dim = input_dim\n        self.include_input = include_input\n        self.periodic_fns = periodic_fns\n\n        self.output_dim = 0\n        if self.include_input:\n            self.output_dim += self.input_dim\n\n        self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)\n\n        if log_sampling:\n            self.freq_bands = 2.0 ** torch.linspace(0.0, max_freq_log2, N_freqs)\n        else:\n            self.freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq_log2, N_freqs)\n\n        self.freq_bands = self.freq_bands.numpy().tolist()\n\n    def forward(self, input, **kwargs):\n        out = []\n        if self.include_input:\n            out.append(input)\n\n        for i in range(len(self.freq_bands)):\n            freq = self.freq_bands[i]\n            for p_fn in self.periodic_fns:\n                out.append(p_fn(input * freq))\n\n        out = torch.cat(out, dim=-1)\n\n        return out\n\n\ndef get_encoder(\n    encoding,\n    input_dim=3,\n    multires=6,\n    degree=4,\n    num_levels=16,\n    level_dim=2,\n    base_resolution=16,\n    log2_hashmap_size=19,\n    desired_resolution=2048,\n    align_corners=False,\n    **kwargs\n):\n    if encoding == \"None\":\n        return lambda x, **kwargs: x, input_dim\n\n    elif encoding == \"frequency\":\n        # encoder = FreqEncoder(input_dim=input_dim, max_freq_log2=multires-1, N_freqs=multires, log_sampling=True)\n        from freqencoder import FreqEncoder\n\n        encoder = FreqEncoder(input_dim=input_dim, degree=multires)\n\n    elif encoding == \"sphere_harmonics\":\n        from shencoder import SHEncoder\n\n        encoder = SHEncoder(input_dim=input_dim, degree=degree)\n\n    elif encoding == \"hashgrid\":\n        from gridencoder import GridEncoder\n\n        encoder = GridEncoder(\n            input_dim=input_dim,\n            num_levels=num_levels,\n            level_dim=level_dim,\n            base_resolution=base_resolution,\n            log2_hashmap_size=log2_hashmap_size,\n            desired_resolution=desired_resolution,\n            gridtype=\"hash\",\n            align_corners=align_corners,\n        )\n\n    elif encoding == \"tiledgrid\":\n        from gridencoder import GridEncoder\n\n        encoder = GridEncoder(\n            input_dim=input_dim,\n            num_levels=num_levels,\n            level_dim=level_dim,\n            base_resolution=base_resolution,\n            log2_hashmap_size=log2_hashmap_size,\n            desired_resolution=desired_resolution,\n            gridtype=\"tiled\",\n            align_corners=align_corners,\n        )\n\n    elif encoding == \"ash\":\n        from ashencoder import AshEncoder\n\n        encoder = AshEncoder(\n            input_dim=input_dim,\n            output_dim=16,\n            log2_hashmap_size=log2_hashmap_size,\n            resolution=desired_resolution,\n        )\n\n    else:\n        raise NotImplementedError(\n            \"Unknown encoding mode, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]\"\n        )\n\n    return encoder, encoder.output_dim\n\n\nclass PeriodicVolumeEncoding(nn.Module):\n    \"\"\"Periodic Volume encoding\n\n    Args:\n        num_levels: Number of feature grids.\n        min_res: Resolution of smallest feature grid.\n        max_res: Resolution of largest feature grid.\n        log2_hashmap_size: Size of hash map is 2^log2_hashmap_size.\n        features_per_level: Number of features per level.\n        hash_init_scale: Value to initialize hash grid.\n        implementation: Implementation of hash encoding. Fallback to torch if tcnn not available.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_levels: int = 16,\n        min_res: int = 16,\n        max_res: int = 1024,\n        log2_hashmap_size: int = 19,\n        features_per_level: int = 2,\n        hash_init_scale: float = 0.001,\n        smoothstep: bool = False,\n    ) -> None:\n        super(PeriodicVolumeEncoding, self).__init__()\n        self.in_dim = 3\n        self.num_levels = num_levels\n        self.features_per_level = features_per_level\n        self.log2_hashmap_size = log2_hashmap_size\n        assert log2_hashmap_size % 3 == 0\n        self.hash_table_size = 2**log2_hashmap_size\n        self.n_output_dims = num_levels * features_per_level\n        self.smoothstep = smoothstep\n\n        levels = torch.arange(num_levels)\n        growth_factor = np.exp((np.log(max_res) - np.log(min_res)) / (num_levels - 1))\n        self.scalings = torch.floor(min_res * growth_factor**levels)\n\n        self.periodic_volume_resolution = 2 ** (log2_hashmap_size // 3)\n        # self.periodic_resolution = torch.minimum(torch.floor(self.scalings), periodic_volume_resolution)\n\n        self.hash_offset = levels * self.hash_table_size\n        self.hash_table = (\n            torch.rand(size=(self.hash_table_size * num_levels, features_per_level)) * 2\n            - 1\n        )\n        self.hash_table *= hash_init_scale\n        self.hash_table = nn.Parameter(self.hash_table)\n\n        # TODO weight loss by level?\n        self.per_level_weights = 1.0\n\n    def parameters(self):\n        return self.hash_table\n\n    def get_out_dim(self) -> int:\n        return self.num_levels * self.features_per_level\n\n    def hash_fn(self, in_tensor):\n        \"\"\"Returns hash tensor using method described in Instant-NGP\n\n        Args:\n            in_tensor: Tensor to be hashed\n        \"\"\"\n\n        # round to make it perioidic\n        x = in_tensor\n        x %= self.periodic_volume_resolution\n        # xyz to index\n        x = (\n            x[..., 0] * (self.periodic_volume_resolution**2)\n            + x[..., 1] * (self.periodic_volume_resolution)\n            + x[..., 2]\n        )\n        # offset by feature levels\n        x += self.hash_offset.to(x.device)\n\n        return x.long()\n\n    def pytorch_fwd(self, in_tensor):\n        \"\"\"Forward pass using pytorch. Significantly slower than TCNN implementation.\"\"\"\n\n        assert in_tensor.shape[-1] == 3\n        in_tensor = in_tensor[..., None, :]  # [..., 1, 3]\n        scaled = in_tensor * self.scalings.view(-1, 1).to(\n            in_tensor.device\n        )  # [..., L, 3]\n        scaled_c = torch.ceil(scaled).type(torch.int32)\n        scaled_f = torch.floor(scaled).type(torch.int32)\n\n        offset = scaled - scaled_f\n\n        if self.smoothstep:\n            offset = offset * offset * (3.0 - 2.0 * offset)\n\n        hashed_0 = self.hash_fn(scaled_c)  # [..., num_levels]\n        hashed_1 = self.hash_fn(\n            torch.cat(\n                [scaled_c[..., 0:1], scaled_f[..., 1:2], scaled_c[..., 2:3]], dim=-1\n            )\n        )\n        hashed_2 = self.hash_fn(\n            torch.cat(\n                [scaled_f[..., 0:1], scaled_f[..., 1:2], scaled_c[..., 2:3]], dim=-1\n            )\n        )\n        hashed_3 = self.hash_fn(\n            torch.cat(\n                [scaled_f[..., 0:1], scaled_c[..., 1:2], scaled_c[..., 2:3]], dim=-1\n            )\n        )\n        hashed_4 = self.hash_fn(\n            torch.cat(\n                [scaled_c[..., 0:1], scaled_c[..., 1:2], scaled_f[..., 2:3]], dim=-1\n            )\n        )\n        hashed_5 = self.hash_fn(\n            torch.cat(\n                [scaled_c[..., 0:1], scaled_f[..., 1:2], scaled_f[..., 2:3]], dim=-1\n            )\n        )\n        hashed_6 = self.hash_fn(scaled_f)\n        hashed_7 = self.hash_fn(\n            torch.cat(\n                [scaled_f[..., 0:1], scaled_c[..., 1:2], scaled_f[..., 2:3]], dim=-1\n            )\n        )\n\n        f_0 = self.hash_table[hashed_0]  # [..., num_levels, features_per_level]\n        f_1 = self.hash_table[hashed_1]\n        f_2 = self.hash_table[hashed_2]\n        f_3 = self.hash_table[hashed_3]\n        f_4 = self.hash_table[hashed_4]\n        f_5 = self.hash_table[hashed_5]\n        f_6 = self.hash_table[hashed_6]\n        f_7 = self.hash_table[hashed_7]\n\n        f_03 = f_0 * offset[..., 0:1] + f_3 * (1 - offset[..., 0:1])\n        f_12 = f_1 * offset[..., 0:1] + f_2 * (1 - offset[..., 0:1])\n        f_56 = f_5 * offset[..., 0:1] + f_6 * (1 - offset[..., 0:1])\n        f_47 = f_4 * offset[..., 0:1] + f_7 * (1 - offset[..., 0:1])\n\n        f0312 = f_03 * offset[..., 1:2] + f_12 * (1 - offset[..., 1:2])\n        f4756 = f_47 * offset[..., 1:2] + f_56 * (1 - offset[..., 1:2])\n\n        encoded_value = f0312 * offset[..., 2:3] + f4756 * (\n            1 - offset[..., 2:3]\n        )  # [..., num_levels, features_per_level]\n\n        return torch.flatten(\n            encoded_value, start_dim=-2, end_dim=-1\n        )  # [..., num_levels * features_per_level]\n\n    def forward(self, in_tensor):\n        return self.pytorch_fwd(in_tensor)\n\n    def get_total_variation_loss(self):\n        \"\"\"Compute the total variation loss for the feature volume.\"\"\"\n        feature_volume = self.hash_table.reshape(\n            self.num_levels,\n            self.periodic_volume_resolution,\n            self.periodic_volume_resolution,\n            self.periodic_volume_resolution,\n            self.features_per_level,\n        )\n        diffx = feature_volume[:, 1:, :, :, :] - feature_volume[:, :-1, :, :, :]\n        diffy = feature_volume[:, :, 1:, :, :] - feature_volume[:, :, :-1, :, :]\n        diffz = feature_volume[:, :, :, 1:, :] - feature_volume[:, :, :, :-1, :]\n\n        # TODO how to sum here or should we use mask?\n        resx = diffx.abs().mean(dim=(1, 2, 3, 4))\n        resy = diffy.abs().mean(dim=(1, 2, 3, 4))\n        resz = diffz.abs().mean(dim=(1, 2, 3, 4))\n\n        return ((resx + resy + resz) * self.per_level_weights).mean()\n"
  },
  {
    "path": "lidarnerf/ffmlp/__init__.py",
    "content": ""
  },
  {
    "path": "lidarnerf/ffmlp/backend.py",
    "content": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"--expt-extended-lambda\",\n    \"--expt-relaxed-constexpr\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n]\n\nif os.name == \"posix\":\n    nvcc_flags += [\n        \"-Xcompiler=-mf16c\",\n        \"-Xcompiler=-Wno-float-conversion\",\n        \"-Xcompiler=-fno-strict-aliasing\",\n    ]\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\n_backend = load(\n    name=\"_ffmlp\",\n    extra_cflags=c_flags,\n    extra_cuda_cflags=nvcc_flags,\n    extra_include_paths=[\n        os.path.join(_src_path, \"dependencies/cutlass/include\"),\n        os.path.join(_src_path, \"dependencies/cutlass/tools/util/include\"),\n    ],\n    sources=[\n        os.path.join(_src_path, \"src\", f)\n        for f in [\n            \"ffmlp.cu\",\n            \"bindings.cpp\",\n        ]\n    ],\n)\n\n__all__ = [\"_backend\"]\n"
  },
  {
    "path": "lidarnerf/ffmlp/ffmlp.py",
    "content": "import math\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\ntry:\n    import _ffmlp as _backend\nexcept ImportError:\n    from .backend import _backend\n\n\nclass _ffmlp_forward(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.half)\n    def forward(\n        ctx,\n        inputs,\n        weights,\n        input_dim,\n        output_dim,\n        hidden_dim,\n        num_layers,\n        activation,\n        output_activation,\n        inference=False,\n        calc_grad_inputs=False,\n    ):\n        B = inputs.shape[0]\n\n        inputs = inputs.contiguous()\n        weights = weights.contiguous()\n\n        # print('[inputs]', torch.any(torch.isnan(inputs)), inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())\n        # print('[weights]', torch.any(torch.isnan(weights)), weights.shape, weights.dtype, weights.min().item(), weights.max().item())\n\n        # allocate output\n        outputs = torch.empty(B, output_dim, device=inputs.device, dtype=inputs.dtype)\n\n        if not inference:\n            forward_buffer = torch.empty(\n                num_layers, B, hidden_dim, device=inputs.device, dtype=inputs.dtype\n            )\n            _backend.ffmlp_forward(\n                inputs,\n                weights,\n                B,\n                input_dim,\n                output_dim,\n                hidden_dim,\n                num_layers,\n                activation,\n                output_activation,\n                forward_buffer,\n                outputs,\n            )\n            ctx.save_for_backward(inputs, weights, outputs, forward_buffer)\n            ctx.dims = (\n                input_dim,\n                output_dim,\n                hidden_dim,\n                num_layers,\n                activation,\n                output_activation,\n                calc_grad_inputs,\n            )\n\n            # print('[outputs]', torch.any(torch.isnan(outputs)), outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())\n            # print('[forward_buffer]', torch.any(torch.isnan(forward_buffer)), forward_buffer.shape, forward_buffer.dtype, forward_buffer.min().item(), forward_buffer.max().item())\n        else:\n            inference_buffer = torch.empty(\n                B, hidden_dim, device=inputs.device, dtype=inputs.dtype\n            )\n            _backend.ffmlp_inference(\n                inputs,\n                weights,\n                B,\n                input_dim,\n                output_dim,\n                hidden_dim,\n                num_layers,\n                activation,\n                output_activation,\n                inference_buffer,\n                outputs,\n            )\n\n            # print('[outputs]', torch.any(torch.isnan(outputs)), outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())\n            # print('[inference_buffer]', torch.any(torch.isnan(inference_buffer)), inference_buffer.shape, inference_buffer.dtype, inference_buffer.min().item(), inference_buffer.max().item())\n\n        return outputs\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad):\n        # grad: [B, output_dim]\n\n        B = grad.shape[0]\n\n        grad = grad.contiguous()\n\n        # print('[grad]', torch.any(torch.isnan(grad)), grad.shape, grad.dtype, grad.min().item(), grad.max().item())\n        # print(grad)\n\n        inputs, weights, outputs, forward_buffer = ctx.saved_tensors\n\n        (\n            input_dim,\n            output_dim,\n            hidden_dim,\n            num_layers,\n            activation,\n            output_activation,\n            calc_grad_inputs,\n        ) = ctx.dims\n\n        # allocate outputs\n        if calc_grad_inputs:\n            grad_inputs = torch.zeros_like(inputs)\n        else:\n            grad_inputs = torch.zeros(1, device=grad.device, dtype=grad.dtype)  # dummy\n\n        grad_weights = torch.zeros_like(weights)\n        backward_buffer = torch.zeros(\n            num_layers, B, hidden_dim, device=grad.device, dtype=grad.dtype\n        )\n\n        _backend.ffmlp_backward(\n            grad,\n            inputs,\n            weights,\n            forward_buffer,\n            B,\n            input_dim,\n            output_dim,\n            hidden_dim,\n            num_layers,\n            activation,\n            output_activation,\n            calc_grad_inputs,\n            backward_buffer,\n            grad_inputs,\n            grad_weights,\n        )\n\n        # print('[grad_inputs]', grad_inputs.shape, grad_inputs.dtype, grad_inputs.min().item(), grad_inputs.max().item())\n        # print('[grad_weights]', grad_weights.shape, grad_weights.dtype, grad_weights.min().item(), grad_weights.max().item())\n        # print('[backward_buffer]', backward_buffer.shape, backward_buffer.dtype, backward_buffer.min().item(), backward_buffer.max().item())\n        if calc_grad_inputs:\n            return (\n                grad_inputs,\n                grad_weights,\n                None,\n                None,\n                None,\n                None,\n                None,\n                None,\n                None,\n                None,\n            )\n        else:\n            return None, grad_weights, None, None, None, None, None, None, None, None\n\n\nffmlp_forward = _ffmlp_forward.apply\n\n\ndef convert_activation(act):\n    if act == \"relu\":\n        return 0\n    elif act == \"exponential\":\n        return 1\n    elif act == \"sine\":\n        return 2\n    elif act == \"sigmoid\":\n        return 3\n    elif act == \"squareplus\":\n        return 4\n    elif act == \"softplus\":\n        return 5\n    else:\n        return 6\n\n\nclass FFMLP(nn.Module):\n    def __init__(\n        self, input_dim, output_dim, hidden_dim, num_layers, activation=\"relu\"\n    ):\n        super().__init__()\n\n        self.input_dim = input_dim\n        self.output_dim = output_dim\n        self.hidden_dim = hidden_dim\n        self.num_layers = num_layers\n        self.activation = convert_activation(activation)\n        self.output_activation = convert_activation(\"none\")  # not supported currently\n\n        self.tensorcore_width = 16\n\n        assert hidden_dim in [\n            16,\n            32,\n            64,\n            128,\n            256,\n        ], f\"FFMLP only support hidden_dim in [16, 32, 64, 128, 256], but got {hidden_dim}\"\n        assert (\n            input_dim > 0 and input_dim % 16 == 0\n        ), f\"FFMLP input_dim should be 16 * m (m  > 0), but got {input_dim}\"\n        assert (\n            output_dim <= 16\n        ), f\"FFMLP current only supports output dim <= 16, but got {output_dim}\"\n        assert (\n            num_layers >= 2\n        ), f\"FFMLP num_layers should be larger than 2 (3 matmuls), but got {num_layers}\"\n\n        # pad output\n        self.padded_output_dim = int(math.ceil(output_dim / 16)) * 16\n\n        # parameters (continuous in memory)\n        self.num_parameters = hidden_dim * (\n            input_dim + hidden_dim * (num_layers - 1) + self.padded_output_dim\n        )\n        self.weights = nn.Parameter(torch.zeros(self.num_parameters))\n        self.reset_parameters()\n\n        # allocate streams\n        _backend.allocate_splitk(self.num_layers + 1)\n\n        # register destructor\n        # atexit.register(self.cleanup) # how to correctly clean? this gives CUDA Error: cudaEventDestroy(events[i]) failed with error context is destroyed\n\n    def cleanup(self):\n        # destroy streams\n        _backend.free_splitk()\n\n    def __repr__(self):\n        return f\"FFMLP: input_dim={self.input_dim} output_dim={self.output_dim} hidden_dim={self.hidden_dim} num_layers={self.num_layers} activation={self.activation}\"\n\n    def reset_parameters(self):\n        torch.manual_seed(42)\n        std = math.sqrt(3 / self.hidden_dim)\n        self.weights.data.uniform_(-std, std)\n\n    def forward(self, inputs):\n        # inputs: [B, input_dim]\n        # return: [B, outupt_dim]\n\n        # print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item(), inputs.requires_grad)\n\n        B, C = inputs.shape\n        # assert B >= 128 and B % 128 == 0, f\"ffmlp batch size must be 128 * m (m > 0), but got {B}.\"\n\n        # pad input\n        pad = 128 - (B % 128)\n        if pad > 0:\n            inputs = torch.cat(\n                [inputs, torch.zeros(pad, C, dtype=inputs.dtype, device=inputs.device)],\n                dim=0,\n            )\n\n        outputs = ffmlp_forward(\n            inputs,\n            self.weights,\n            self.input_dim,\n            self.padded_output_dim,\n            self.hidden_dim,\n            self.num_layers,\n            self.activation,\n            self.output_activation,\n            not self.training,\n            inputs.requires_grad,\n        )\n\n        # unpad output\n        if B != outputs.shape[0] or self.padded_output_dim != self.output_dim:\n            outputs = outputs[:B, : self.output_dim]\n\n        # print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())\n\n        return outputs\n"
  },
  {
    "path": "lidarnerf/ffmlp/setup.py",
    "content": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"--expt-extended-lambda\",\n    \"--expt-relaxed-constexpr\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n]\n\nif os.name == \"posix\":\n    nvcc_flags += [\n        \"-Xcompiler=-mf16c\",\n        \"-Xcompiler=-Wno-float-conversion\",\n        \"-Xcompiler=-fno-strict-aliasing\",\n    ]\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\nsetup(\n    name=\"ffmlp\",  # package name, import this to use python API\n    ext_modules=[\n        CUDAExtension(\n            name=\"_ffmlp\",  # extension name, import this to use CUDA API\n            sources=[\n                os.path.join(_src_path, \"src\", f)\n                for f in [\n                    \"ffmlp.cu\",\n                    \"bindings.cpp\",\n                ]\n            ],\n            extra_compile_args={\n                \"cxx\": c_flags,\n                \"nvcc\": nvcc_flags,\n            },\n            include_dirs=[\n                os.path.join(_src_path, \"dependencies/cutlass/include\"),\n                os.path.join(_src_path, \"dependencies/cutlass/tools/util/include\"),\n            ],\n        ),\n    ],\n    cmdclass={\n        \"build_ext\": BuildExtension,\n    },\n)\n"
  },
  {
    "path": "lidarnerf/ffmlp/src/bindings.cpp",
    "content": "#include <torch/extension.h>\n\n#include \"ffmlp.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"ffmlp_forward\", &ffmlp_forward, \"ffmlp_forward (CUDA)\");\n    m.def(\"ffmlp_inference\", &ffmlp_inference, \"ffmlp_inference (CUDA)\");\n    m.def(\"ffmlp_backward\", &ffmlp_backward, \"ffmlp_backward (CUDA)\");\n    m.def(\"allocate_splitk\", &allocate_splitk, \"allocate_splitk (CUDA)\");\n    m.def(\"free_splitk\", &free_splitk, \"free_splitk (CUDA)\");\n}"
  },
  {
    "path": "lidarnerf/ffmlp/src/cutlass_matmul.h",
    "content": "/*\n * Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *     * Redistributions of source code must retain the above copyright notice,\n * this list of conditions and the following disclaimer.\n *     * Redistributions in binary form must reproduce the above copyright\n * notice, this list of conditions and the following disclaimer in the\n * documentation and/or other materials provided with the distribution.\n *     * Neither the name of the NVIDIA CORPORATION nor the names of its\n * contributors may be used to endorse or promote products derived from this\n * software without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY\n * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *//*\n */\n\n/** @file   cutlass_matmul.h\n *  @author Thomas Müller, NVIDIA\n *  @brief  Matrix multiplication wrappers that call into CUTLASS (plus some\n * custom modifications). The parameters are optimized to give optimal\n * performance in a variety of situations. Parts of this file were adapted by\n * starting from the CUTLASS sample code (see its BSD 3-clause license).\n */\n\n#pragma once\n\n#include <cutlass/array.h>\n#include <cutlass/cutlass.h>\n#include <cutlass/functional.h>\n#include <cutlass/gemm/device/gemm.h>\n#include <cutlass/gemm/device/gemm_splitk_parallel.h>\n#include <cutlass/numeric_conversion.h>\n#include <cutlass/numeric_types.h>\n#include <torch/torch.h>\n\n#include <iostream>\n#include <map>\n#include <type_traits>\n\n#include \"utils.h\"\n\n//#define TCNN_VERBOSE_MEMORY_ALLOCS\n\n#define CUTLASS_CHECK(status)                                                 \\\n    {                                                                         \\\n        cutlass::Status error = status;                                       \\\n        if (error != cutlass::Status::kSuccess) {                             \\\n            std::cerr << \"Got cutlass error: \"                                \\\n                      << cutlassGetStatusString(error) << \" at: \" << __LINE__ \\\n                      << std::endl;                                           \\\n            exit(EXIT_FAILURE);                                               \\\n        }                                                                     \\\n    }\n\n#define CUDA_CHECK(status)                                                    \\\n    {                                                                         \\\n        cudaError_t error = status;                                           \\\n        if (error != cudaSuccess) {                                           \\\n            std::cerr << \"Got bad cuda status: \" << cudaGetErrorString(error) \\\n                      << \" at line: \" << __LINE__ << std::endl;               \\\n            exit(EXIT_FAILURE);                                               \\\n        }                                                                     \\\n    }\n\nusing SmArch = std::conditional_t<\n        MIN_GPU_ARCH >= 80,\n        std::conditional_t<std::is_same<network_precision_t, float>::value,\n                           cutlass::arch::Sm75,\n                           cutlass::arch::Sm80>,\n        std::conditional_t<MIN_GPU_ARCH >= 75,\n                           cutlass::arch::Sm75,\n                           cutlass::arch::Sm70>>;\n\nusing TypeAccumulator =\n        std::conditional_t<std::is_same<network_precision_t, float>::value,\n                           float,\n                           cutlass::half_t>;\nusing TypeCompute =\n        std::conditional_t<std::is_same<network_precision_t, float>::value,\n                           float,\n                           cutlass::half_t>;\n\ntemplate <typename T>\nusing MMAOp = typename std::conditional<std::is_same<T, float>::value,\n                                        cutlass::arch::OpClassSimt,\n                                        cutlass::arch::OpClassTensorOp>::type;\n\ntemplate <typename T>\nusing ShapeMMAOp = typename std::conditional<\n        std::is_same<MMAOp<T>, cutlass::arch::OpClassTensorOp>::value,\n        typename std::conditional<\n                std::is_same<SmArch, cutlass::arch::Sm80>::value ||\n                        std::is_same<SmArch, cutlass::arch::Sm75>::value,\n                cutlass::gemm::GemmShape<16, 8, 8>,\n                cutlass::gemm::GemmShape<8, 8, 4>>::type,\n        cutlass::gemm::GemmShape<1, 1, 1>>::type;\n\ntemplate <typename thread_block, typename warp>\nstruct LayerConfig {\n    using k_thread_block = thread_block;\n    using k_warp = warp;\n};\n\nusing FullLayerK = typename std::conditional<\n        std::is_same<MMAOp<network_precision_t>,\n                     cutlass::arch::OpClassSimt>::value,\n        LayerConfig<cutlass::gemm::GemmShape<128, 128, 8>,\n                    cutlass::gemm::GemmShape<32, 64, 8>>,\n        LayerConfig<cutlass::gemm::GemmShape<64, 64, 32>,\n                    cutlass::gemm::GemmShape<32, 32, 32>>>::type;\nusing LastLayerK = FullLayerK;\n\nusing FullLayer = typename std::conditional<\n        std::is_same<MMAOp<network_precision_t>,\n                     cutlass::arch::OpClassSimt>::value,\n        LayerConfig<cutlass::gemm::GemmShape<128, 128, 8>,\n                    cutlass::gemm::GemmShape<32, 64, 8>>,\n        LayerConfig<cutlass::gemm::GemmShape<128, 128, 32>,\n                    cutlass::gemm::GemmShape<64, 64, 32>>>::type;\nusing LastLayer = FullLayer;\n\n// This code section describes how threadblocks are scheduled on GPU\nusing SwizzleThreadBlock =\n        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;\n\n// This code section describes the epilogue part of the kernel\n\ntemplate <typename V>\nstruct CutlassFragmentWrapper {\n    static const uint32_t num_elements = V::kElements;\n    V x;\n};\n\ntemplate <typename ElementOutput_,  ///< Data type used to load and store\n                                    ///< tensors\n          int Count,  ///< Number of elements computed per operation\n          typename ElementAccumulator_ =\n                  ElementOutput_,  ///< Accumulator data type\n          typename ElementCompute_ =\n                  ElementOutput_,  ///< Data type used to compute linear\n                                   ///< combination\n          cutlass::FloatRoundStyle Round =\n                  cutlass::FloatRoundStyle::round_to_nearest>\nclass ActivationEpilogue {\npublic:\n    using ElementOutput = ElementOutput_;\n    using ElementAccumulator = ElementAccumulator_;\n    using ElementCompute = ElementCompute_;\n\n    static int const kCount = Count;\n\n    using FragmentOutput = cutlass::Array<ElementOutput, kCount>;\n    using FragmentAccumulator = cutlass::Array<ElementAccumulator, kCount>;\n    using ComputeFragment = cutlass::Array<ElementCompute, kCount>;\n\n    static cutlass::FloatRoundStyle const kRound = Round;\n\n    struct Params {\n        Activation activation;\n        bool sum_source;\n    };\n\npublic:\n    CUTLASS_HOST_DEVICE\n    ActivationEpilogue(Params const &params)\n        : m_activation{params.activation}, m_sum_source{params.sum_source} {}\n\n    CUTLASS_HOST_DEVICE\n    bool is_source_needed() const { return m_sum_source; }\n\n    /// Functionally required for serial reduction in the epilogue\n    CUTLASS_HOST_DEVICE\n    void set_k_partition(int k_partition, int k_partition_count) {}\n\n    CUTLASS_HOST_DEVICE\n    FragmentOutput operator()(FragmentAccumulator const &accumulator) const {\n        cutlass::NumericArrayConverter<ElementCompute, ElementAccumulator,\n                                       kCount, Round>\n                accumulator_converter;\n\n        auto intermediate = CutlassFragmentWrapper<ComputeFragment>{\n                accumulator_converter(accumulator)};\n        intermediate =\n                warp_activation<ElementCompute>(m_activation, intermediate);\n\n        cutlass::NumericArrayConverter<ElementOutput, ElementCompute, kCount,\n                                       Round>\n                destination_converter;\n        return destination_converter(intermediate.x);\n    }\n\n    CUTLASS_HOST_DEVICE\n    FragmentOutput operator()(FragmentAccumulator const &accumulator,\n                              FragmentOutput const &source) const {\n        cutlass::NumericArrayConverter<ElementCompute, ElementOutput, kCount,\n                                       Round>\n                source_converter;\n        cutlass::NumericArrayConverter<ElementCompute, ElementAccumulator,\n                                       kCount, Round>\n                accumulator_converter;\n\n        cutlass::plus<ComputeFragment> plus_op;\n        auto intermediate = CutlassFragmentWrapper<ComputeFragment>{\n                accumulator_converter(accumulator)};\n        if (m_sum_source) {\n            intermediate.x = plus_op(intermediate.x, source_converter(source));\n        }\n        intermediate =\n                warp_activation<ElementCompute>(m_activation, intermediate);\n\n        cutlass::NumericArrayConverter<ElementOutput, ElementCompute, kCount,\n                                       Round>\n                destination_converter;\n        return destination_converter(intermediate.x);\n    }\n\nprivate:\n    Activation m_activation;\n    bool m_sum_source;\n};\n\ntemplate <typename ElementOutput_,  ///< Data type used to load and store\n                                    ///< tensors\n          int Count,  ///< Number of elements computed per operation\n          typename ElementAccumulator_ =\n                  ElementOutput_,  ///< Accumulator data type\n          typename ElementCompute_ =\n                  ElementOutput_,  ///< Data type used to compute linear\n                                   ///< combination\n          cutlass::FloatRoundStyle Round =\n                  cutlass::FloatRoundStyle::round_to_nearest>\nclass ActivationTransferEpilogue {\npublic:\n    using ElementOutput = ElementOutput_;\n    using ElementAccumulator = ElementAccumulator_;\n    using ElementCompute = ElementCompute_;\n\n    static int const kCount = Count;\n\n    using FragmentOutput = cutlass::Array<ElementOutput, kCount>;\n    using FragmentAccumulator = cutlass::Array<ElementAccumulator, kCount>;\n    using ComputeFragment = cutlass::Array<ElementCompute, kCount>;\n\n    static cutlass::FloatRoundStyle const kRound = Round;\n\n    /// Host-constructable parameters structure\n    struct Params {\n        Activation activation;\n    };\n\npublic:\n    /// Constructs the function object, possibly loading from pointers in host\n    /// memory\n    CUTLASS_HOST_DEVICE\n    ActivationTransferEpilogue(Params const &params)\n        : m_activation{params.activation} {}\n\n    /// Returns true if source is needed\n    CUTLASS_HOST_DEVICE\n    bool is_source_needed() const { return true; }\n\n    /// Functionally required for serial reduction in the epilogue\n    CUTLASS_HOST_DEVICE\n    void set_k_partition(int k_partition, int k_partition_count) {}\n\n    CUTLASS_HOST_DEVICE\n    FragmentOutput operator()(FragmentAccumulator const &accumulator,\n                              FragmentOutput const &source) const {\n        cutlass::NumericArrayConverter<ElementCompute, ElementOutput, kCount,\n                                       Round>\n                source_converter;\n        cutlass::NumericArrayConverter<ElementCompute, ElementAccumulator,\n                                       kCount, Round>\n                accumulator_converter;\n\n        auto converted_source = CutlassFragmentWrapper<ComputeFragment>{\n                source_converter(source)};\n        auto intermediate = CutlassFragmentWrapper<ComputeFragment>{\n                accumulator_converter(accumulator)};\n\n        intermediate = warp_activation_backward<ElementCompute>(\n                m_activation, intermediate, converted_source);\n\n        cutlass::NumericArrayConverter<ElementOutput, ElementCompute, kCount,\n                                       Round>\n                destination_converter;\n        return destination_converter(intermediate.x);\n    }\n\n    CUTLASS_HOST_DEVICE\n    FragmentOutput operator()(FragmentAccumulator const &accumulator) const {\n        cutlass::NumericArrayConverter<ElementCompute, ElementAccumulator,\n                                       kCount, Round>\n                accumulator_converter;\n\n        ComputeFragment converted_accumulator =\n                accumulator_converter(accumulator);\n\n        cutlass::NumericArrayConverter<ElementOutput, ElementCompute, kCount,\n                                       Round>\n                destination_converter;\n\n        return destination_converter(converted_accumulator);\n    }\n\nprivate:\n    Activation m_activation;\n};\n\ntemplate <typename T>\nstatic constexpr int n_vectorized_elements =\n        std::is_same<MMAOp<T>, cutlass::arch::OpClassTensorOp>::value\n                ? (128 / cutlass::sizeof_bits<T>::value)\n                : 1;\n\ntemplate <typename T>\nusing SumOp =\n        cutlass::epilogue::thread::LinearCombination<T,\n                                                     n_vectorized_elements<T>,\n                                                     TypeAccumulator,\n                                                     TypeCompute>;\n\ntemplate <typename T>\nusing IntermediateActivationOp =\n        ActivationEpilogue<T, 4, TypeAccumulator, TypeCompute>;\n\ntemplate <typename T>\nusing IntermediateActivationTransferOp =\n        ActivationTransferEpilogue<T, 4, TypeAccumulator, TypeCompute>;\n\ntemplate <typename T>\nusing ActivationOp = ActivationEpilogue<T,\n                                        n_vectorized_elements<T>,\n                                        TypeAccumulator,\n                                        TypeCompute>;\n\ntemplate <typename T>\nusing ActivationTransferOp =\n        ActivationTransferEpilogue<T,\n                                   n_vectorized_elements<T>,\n                                   TypeAccumulator,\n                                   TypeCompute>;\n\ntemplate <typename EPILOGUE,\n          typename LayerConfig,\n          typename TypeA,\n          typename LayoutA,\n          typename TypeB,\n          typename LayoutB,\n          typename TypeOutput,\n          typename LayoutOutput>\nusing OurGemm =\n        cutlass::gemm::device::Gemm<TypeA,\n                                    LayoutA,\n                                    TypeB,\n                                    LayoutB,\n                                    TypeOutput,\n                                    LayoutOutput,\n                                    TypeAccumulator,\n                                    MMAOp<TypeA>,\n                                    SmArch,\n                                    typename LayerConfig::k_thread_block,\n                                    typename LayerConfig::k_warp,\n                                    ShapeMMAOp<TypeA>,\n                                    EPILOGUE,\n                                    SwizzleThreadBlock,\n                                    2>;\n\ntemplate <typename EPILOGUE,\n          typename LayerConfig,\n          typename TypeA,\n          typename LayoutA,\n          typename TypeB,\n          typename LayoutB,\n          typename TypeOutput,\n          typename LayoutOutput>\nusing SplitKGemm = cutlass::gemm::device::GemmSplitKParallel<\n        TypeA,\n        LayoutA,\n        TypeB,\n        LayoutB,\n        TypeOutput,\n        LayoutOutput,\n        TypeAccumulator,\n        MMAOp<TypeA>,\n        SmArch,\n        typename LayerConfig::k_thread_block,\n        typename LayerConfig::k_warp,\n        ShapeMMAOp<TypeA>,\n        EPILOGUE>;\n\ninline std::map<cudaStream_t, GPUMemory<uint8_t>> &cutlass_workspaces() {\n    static std::map<cudaStream_t, GPUMemory<uint8_t>> s_workspaces;\n    return s_workspaces;\n}\n\ninline uint8_t *cutlass_get_workspace(size_t size, cudaStream_t stream) {\n    GPUMemory<uint8_t> &workspace = cutlass_workspaces()[stream];\n    if (size > workspace.size()) {\n        size *= 2;\n#ifdef TCNN_VERBOSE_MEMORY_ALLOCS\n        std::cout << \"CUTLASS GEMM: Allocating temporary workspace of \"\n                  << bytes_to_string(size) << \".\" << std::endl;\n#endif\n\n        // Allocate twice the requested size to make sure we're not constantly\n        // allocating small increments.\n        workspace.resize(size);\n    }\n    return workspace.data();\n}\n\ninline void cutlass_free_workspace(cudaStream_t stream) {\n    if (cutlass_workspaces().count(stream) == 0) {\n        return;\n    }\n\n#ifdef TCNN_VERBOSE_MEMORY_ALLOCS\n    std::cout << \"CUTLASS GEMM: Freeing temporary workspace of \"\n              << bytes_to_string(cutlass_workspaces().at(stream).size()) << \".\"\n              << std::endl;\n#endif\n    cutlass_workspaces().erase(stream);\n}\n\ntemplate <class Gemm>\nvoid fc_multiply_impl(cudaStream_t stream,\n                      const typename Gemm::Arguments &args) {\n    // Using the arguments, query for extra workspace required for matrix\n    // multiplication computation\n    size_t workspace_size = Gemm::get_workspace_size(args);\n\n    // Instantiate CUTLASS kernel depending on templates\n    Gemm gemm_op;\n\n    // Initialize CUTLASS kernel with arguments and workspace pointer\n    cutlass::Status status = gemm_op.initialize(\n            args, cutlass_get_workspace(workspace_size, stream), stream);\n    CUTLASS_CHECK(status);\n\n    // Launch initialized CUTLASS kernel\n    status = gemm_op(stream);\n    CUTLASS_CHECK(status);\n}\n\ntemplate <class Gemm>\nvoid fc_multiply_split_k_impl(cudaStream_t stream,\n                              const typename Gemm::Arguments &args) {\n    // Using the arguments, query for extra workspace required for matrix\n    // multiplication computation\n    size_t workspace_size = Gemm::get_workspace_size(args);\n\n    // Instantiate CUTLASS kernel depending on templates\n    Gemm gemm_op;\n\n    // Initialize CUTLASS kernel with arguments and workspace pointer\n    cutlass::Status status = gemm_op.initialize(\n            args, cutlass_get_workspace(workspace_size, stream));\n    CUTLASS_CHECK(status);\n\n    // Launch initialized CUTLASS kernel\n    status = gemm_op(stream);\n    CUTLASS_CHECK(status);\n}\n\n//////////////////////////////////////////////////////////////////////////////////\n////////////////////////////        modified ///////////////////////////////\n//////////////////////////////////////////////////////////////////////////////////\n\ntemplate <typename config, bool RM_A, bool RM_B, bool RM_C>\nvoid fc_multiply(cudaStream_t stream,\n                 const int M,\n                 const int K,\n                 const int N,\n                 const __half *A,\n                 const __half *B,\n                 const __half *C,\n                 __half *D,\n                 Activation act = Activation::None,\n                 bool transfer = false,\n                 bool sum_source = false) {\n    // compute  D = A @ B + C\n    // A: [M, K]\n    // B: [K, N]\n    // C, D: [M, N]\n    using CutlassLayoutA =\n            typename std::conditional<RM_A, cutlass::layout::RowMajor,\n                                      cutlass::layout::ColumnMajor>::type;\n    using CutlassLayoutB =\n            typename std::conditional<RM_B, cutlass::layout::RowMajor,\n                                      cutlass::layout::ColumnMajor>::type;\n    using CutlassLayoutC =\n            typename std::conditional<RM_C, cutlass::layout::RowMajor,\n                                      cutlass::layout::ColumnMajor>::type;\n\n    using MatmulTypeCompute = cutlass::half_t;\n    using MatmulTypeAccumulator = cutlass::half_t;\n\n    const int lda = RM_A ? K : M;\n    const int ldb = RM_B ? N : K;\n    const int ldc = RM_C ? N : M;\n    const int ldd = RM_C ? N : M;\n\n    if (transfer) {\n        using Gemm =\n                OurGemm<ActivationTransferOp<MatmulTypeAccumulator>, config,\n                        MatmulTypeCompute, CutlassLayoutA, MatmulTypeCompute,\n                        CutlassLayoutB, MatmulTypeAccumulator, CutlassLayoutC>;\n        typename Gemm::Arguments arguments{{M, N, K},\n                                           {(MatmulTypeCompute *)A, lda},\n                                           {(MatmulTypeCompute *)B, ldb},\n                                           {(MatmulTypeAccumulator *)C, ldc},\n                                           {(MatmulTypeAccumulator *)D, ldd},\n                                           {act},\n                                           1};\n\n        fc_multiply_impl<Gemm>(stream, arguments);\n    } else {\n        using Gemm =\n                OurGemm<ActivationOp<MatmulTypeAccumulator>, config,\n                        MatmulTypeCompute, CutlassLayoutA, MatmulTypeCompute,\n                        CutlassLayoutB, MatmulTypeAccumulator, CutlassLayoutC>;\n        typename Gemm::Arguments arguments{{M, N, K},\n                                           {(MatmulTypeCompute *)A, lda},\n                                           {(MatmulTypeCompute *)B, ldb},\n                                           {(MatmulTypeAccumulator *)C, ldc},\n                                           {(MatmulTypeAccumulator *)D, ldd},\n                                           {act, sum_source},\n                                           1};\n\n        fc_multiply_impl<Gemm>(stream, arguments);\n    }\n}\n\ntemplate <typename config, bool RM_A, bool RM_B, bool RM_C>\nvoid fc_multiply(cudaStream_t stream,\n                 const int M,\n                 const int K,\n                 const int N,\n                 const __half *A,\n                 const __half *B,\n                 __half *D,\n                 Activation act = Activation::None) {\n    fc_multiply<config, RM_A, RM_B, RM_C>(stream, M, K, N, A, B, D, D, act);\n}\n\ntemplate <typename config, bool RM_A, bool RM_B, bool RM_C>\nvoid fc_multiply_split_k(cudaStream_t stream,\n                         const int M,\n                         const int K,\n                         const int N,\n                         const __half *A,\n                         const __half *B,\n                         const __half *C,\n                         __half *D,\n                         int split_k_slices = 1) {\n    // A: [M, K]\n    // B: [K, N]\n    // C, D: [M, N]\n    using CutlassLayoutA =\n            typename std::conditional<RM_A, cutlass::layout::RowMajor,\n                                      cutlass::layout::ColumnMajor>::type;\n    using CutlassLayoutB =\n            typename std::conditional<RM_B, cutlass::layout::RowMajor,\n                                      cutlass::layout::ColumnMajor>::type;\n    using CutlassLayoutC =\n            typename std::conditional<RM_C, cutlass::layout::RowMajor,\n                                      cutlass::layout::ColumnMajor>::type;\n\n    using MatmulTypeCompute = cutlass::half_t;\n    using MatmulTypeAccumulator = cutlass::half_t;\n\n    const int lda = RM_A ? K : M;\n    const int ldb = RM_B ? N : K;\n    const int ldc = RM_C ? N : M;\n    const int ldd = RM_C ? N : M;\n\n    using Gemm =\n            SplitKGemm<SumOp<MatmulTypeAccumulator>, config, MatmulTypeCompute,\n                       CutlassLayoutA, MatmulTypeCompute, CutlassLayoutB,\n                       MatmulTypeAccumulator, CutlassLayoutC>;\n\n    typename Gemm::Arguments arguments{{M, N, K},\n                                       {(MatmulTypeCompute *)A, lda},\n                                       {(MatmulTypeCompute *)B, ldb},\n                                       {(MatmulTypeAccumulator *)C, ldc},\n                                       {(MatmulTypeAccumulator *)D, ldd},\n                                       {(TypeCompute)1.0f, (TypeCompute)0.0f},\n                                       split_k_slices};\n\n    fc_multiply_split_k_impl<Gemm>(stream, arguments);\n}\n\ntemplate <typename config, bool RM_A, bool RM_B, bool RM_C>\nvoid fc_multiply_split_k(cudaStream_t stream,\n                         const int M,\n                         const int K,\n                         const int N,\n                         const __half *A,\n                         const __half *B,\n                         __half *D,\n                         int split_k_slices = 1) {\n    fc_multiply_split_k<config, RM_A, RM_B, RM_C>(stream, M, K, N, A, B, D, D,\n                                                  split_k_slices);\n}\n"
  },
  {
    "path": "lidarnerf/ffmlp/src/ffmlp.cu",
    "content": "#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <mma.h>\n#include <stdint.h>\n#include <torch/torch.h>\n\n#include <algorithm>\n#include <cstdio>\n#include <stdexcept>\n#include <vector>\n\n#include \"cutlass_matmul.h\"\n#include \"utils.h\"\n\n__host__ __device__ Activation convert_activation(const uint32_t activation) {\n    switch (activation) {\n        case 0:\n            return Activation::ReLU;\n        case 1:\n            return Activation::Exponential;\n        case 2:\n            return Activation::Sine;\n        case 3:\n            return Activation::Sigmoid;\n        case 4:\n            return Activation::Squareplus;\n        case 5:\n            return Activation::Softplus;\n        case 6:\n            return Activation::None;\n        default:\n            return Activation::None;\n    }\n}\n\ntemplate <typename T>\n__host__ __device__ T div_round_up(T val, T divisor) {\n    return (val + divisor - 1) / divisor;\n}\n\nvoid check_shmem_error(cudaError_t error) {\n    if (error != cudaSuccess) {\n        throw std::runtime_error{\n                \"FullyFusedMLP: insufficient shared memory available on the \"\n                \"GPU. \"\n                \"Reduce `n_neurons` or use `CutlassMLP` (better compatibility \"\n                \"but \"\n                \"slower) instead.\"};\n    }\n}\n\ntemplate <int WIDTH,\n          int BLOCK_DIM_Z,\n          int N_ITERS,\n          typename OUT_T,\n          bool BACKWARD = false>\n__device__ void threadblock_layer(\n        Activation activation,\n        __half *__restrict__ act_shmem,\n        const __half *__restrict__ weights_this_layer,\n        OUT_T *__restrict__ out_intermediate_threadblock_this_layer,\n        const OUT_T *__restrict__ activation_aux = nullptr) {\n    // act_shmem contains the intermediate activations (shared memory) of the\n    // thread block's chunk of the batch.\n    //           Can be forward activations or backward activations, depending\n    //           on caller.\n    // weights_this_layer points to the weight matrix of the current layer.\n    // out_intermediate_threadblock_this_layer points to the location where\n    // intermediate activations produced by the thread block should be written\n    // to.\n    //                  Can be nullptr if nothing should be written.\n    // activation_aux points to additional arguments that the activation\n    // function may depend on. Points to the hidden forward activations when\n    // computing backward activations.\n\n    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;\n    constexpr uint32_t N_BLOCKS = WIDTH / 16;\n\n    using namespace nvcuda;\n\n    // If we're performing the backward pass, weights must be loaded in\n    // transposed form, which is achieved by interpreting the memory in\n    // row_major instead of col_major order.\n    using weights_layout_t =\n            std::conditional_t<BACKWARD, wmma::row_major, wmma::col_major>;\n\n    // Fragments\n    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major>\n            act_frag;\n    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, weights_layout_t>\n            weights_frag[N_BLOCKS];\n    wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag[N_ITERS];\n\n    // Indices\n    const uint32_t li = threadIdx.x;  // index in warp (\"lane index\")\n    const uint32_t wi = threadIdx.y;  // index in block (\"warp index\")\n\n    const uint32_t lane_offset = (8 * li) % WIDTH;\n    const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;\n\n    const uint32_t weights_col = 16 * wi;\n\n    __syncthreads();\n\n// Load N_BLOCKS chunks of weights from global memory into registers.\n#pragma unroll\n    for (uint32_t i = 0; i < N_BLOCKS; ++i) {\n        if (BACKWARD) {\n            // If we're performing the backward pass, additional index swizzling\n            // is needed to load the weights in transposed form.\n            wmma::load_matrix_sync(\n                    weights_frag[i],\n                    weights_this_layer + 16 * i * WIDTH + weights_col, WIDTH);\n        } else {\n            wmma::load_matrix_sync(\n                    weights_frag[i],\n                    weights_this_layer + 16 * i + weights_col * WIDTH, WIDTH);\n        }\n    }\n\n#pragma unroll\n    for (int l = 0; l < N_ITERS; ++l) {\n        wmma::fill_fragment(result_frag[l], 0.0f);\n\n#pragma unroll\n        for (uint32_t i = 0; i < N_BLOCKS; ++i) {\n            // Load a chunk of intermediate activations from shared memory and\n            // multiply with chunk of weights\n            wmma::load_matrix_sync(\n                    act_frag,\n                    act_shmem + 16 * i +\n                            (16 * (threadIdx.z + l * BLOCK_DIM_Z)) *\n                                    (WIDTH + SKEW),\n                    WIDTH + SKEW);\n            wmma::mma_sync(result_frag[l], act_frag, weights_frag[i],\n                           result_frag[l]);\n        }\n\n        // Activation\n        if (BACKWARD) {\n            // Load the temporary forward matrix for the relu transfer\n            wmma::load_matrix_sync(\n                    act_frag,\n                    activation_aux + weights_col +\n                            (threadIdx.z + l * BLOCK_DIM_Z) * 16 * WIDTH,\n                    WIDTH);\n            warp_activation_backward<__half>(activation, result_frag[l],\n                                             act_frag, result_frag[l]);\n        } else {\n            warp_activation<__half>(activation, result_frag[l], result_frag[l]);\n        }\n    }\n\n    __syncthreads();\n\n#pragma unroll\n    for (int l = 0; l < N_ITERS; ++l) {\n        wmma::store_matrix_sync(\n                act_shmem + weights_col +\n                        (threadIdx.z + l * BLOCK_DIM_Z) * 16 * (WIDTH + SKEW),\n                result_frag[l], WIDTH + SKEW, wmma::mem_row_major);\n    }\n\n    if (out_intermediate_threadblock_this_layer != nullptr) {\n        __syncthreads();\n\n#pragma unroll\n        for (int l = 0; l < N_ITERS; ++l) {\n            *(int4 *)&out_intermediate_threadblock_this_layer\n                    [lane_offset +\n                     (row + 16 * (threadIdx.z + l * BLOCK_DIM_Z)) * WIDTH] =\n                    *(int4 *)&act_shmem[lane_offset +\n                                        (row +\n                                         16 * (threadIdx.z + l * BLOCK_DIM_Z)) *\n                                                (WIDTH + SKEW)];\n        }\n    }\n}\n\ntemplate <int WIDTH, int BLOCK_DIM_Z, int N_ITERS>\n__device__ void threadblock_load_input_static(\n        __half *__restrict__ act_shmem,\n        const __half *__restrict__ input_threadblock) {\n    // act_shmem will be filled by the thread block's chunk of input_threadblock\n\n    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;\n\n    // Indices\n    const uint32_t li = threadIdx.x;  // index in warp (\"lane index\")\n    const uint32_t wi = threadIdx.y;  // index in block (\"warp index\")\n\n    const uint32_t lane_offset = (8 * li) % WIDTH;\n    const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;\n\n#pragma unroll\n    for (int i = 0; i < N_ITERS; ++i) {\n        *(int4 *)&act_shmem[lane_offset +\n                            (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) *\n                                    (WIDTH + SKEW)] =\n                *(int4 *)&input_threadblock\n                        [lane_offset +\n                         (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * WIDTH];\n    }\n}\n\ntemplate <int WIDTH, int BLOCK_DIM_Z, int N_ITERS, typename OUT_T>\n__device__ void threadblock_input_layer_forward_dynamic(\n        Activation activation,\n        __half *__restrict__ act_shmem,\n        const __half *__restrict__ input_threadblock,\n        const __half *__restrict__ weights_this_layer,\n        OUT_T *__restrict__ out_intermediate_threadblock_this_layer,\n        const uint32_t in_width) {\n    // act_shmem contains the intermediate activations (shared memory) of the\n    // thread block's chunk of the batch input_threadblock points to the thread\n    // block's chunk of the input batch in global memory weights_this_layer\n    // points to the weight matrix of the current layer\n    // out_intermediate_threadblock_this_layer points to the location where\n    // intermediate activations produced by the thread block should be written\n    // to.\n    //                  Can be nullptr if nothing should be written.\n    // in_width is the dynamic width of the input layer\n\n    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;\n    constexpr uint32_t INPUT_SKEW = 8;\n    constexpr uint32_t N_BLOCKS = WIDTH / 16;\n\n    using namespace nvcuda;\n\n    // Fragments\n    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major>\n            act_frag;\n    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major>\n            weights_frag;\n    wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag[N_ITERS];\n\n    // Indices\n    const uint32_t li = threadIdx.x;  // index in warp (\"lane index\")\n    const uint32_t wi = threadIdx.y;  // index in block (\"warp index\")\n\n    const uint32_t lane_offset = (8 * li) % WIDTH;\n    const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;\n\n    const uint32_t weights_col = 16 * wi;\n\n    __half *__restrict__ weights_shmem =\n            act_shmem + BLOCK_DIM_Z * 16 * (in_width + INPUT_SKEW);\n\n    // Load input weight matrix (fits completely into shared memory)\n    // Each thread can load 8 fp16 elements (16 bytes) at once; we have\n    // N_BLOCKS*BLOCK_DIM_Z warps\n    const uint32_t n_elems_per_load = N_BLOCKS * 32 * BLOCK_DIM_Z * 8;\n    const uint32_t thread_elem_idx =\n            (li + wi * 32 + threadIdx.z * N_BLOCKS * 32) * 8;\n\n    const uint32_t n_elems_b = WIDTH * in_width;\n\n#pragma unroll\n    for (uint32_t idx = thread_elem_idx; idx < n_elems_b;\n         idx += n_elems_per_load) {\n        const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW;\n        *(int4 *)&weights_shmem[idx_skewed] = *(int4 *)&weights_this_layer[idx];\n    }\n\n    const uint32_t n_tensor_ops = in_width / 16;\n\n#pragma unroll\n    for (int l = 0; l < N_ITERS; ++l) {\n        // Load chunk of inputs into shmem.\n        // This is faster than loading it from gmem directly, even though it is\n        // only used once. (Possibly due to latency hiding through staging.)\n        const uint32_t n_elems_a = BLOCK_DIM_Z * 16 * in_width;\n\n#pragma unroll\n        for (uint32_t idx = thread_elem_idx; idx < n_elems_a;\n             idx += n_elems_per_load) {\n            const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW;\n            *(int4 *)&act_shmem[idx_skewed] =\n                    *(int4 *)&input_threadblock[l * n_elems_a + idx];\n        }\n\n        __syncthreads();\n\n        wmma::fill_fragment(result_frag[l], 0.0f);\n#pragma unroll\n        for (uint32_t i = 0; i < n_tensor_ops; ++i) {\n            // Load chunk of inputs and weights from shared memory and multiply\n            // them\n            wmma::load_matrix_sync(\n                    act_frag,\n                    act_shmem + 16 * i +\n                            (16 * threadIdx.z) * (in_width + INPUT_SKEW),\n                    in_width + INPUT_SKEW);\n            wmma::load_matrix_sync(\n                    weights_frag,\n                    weights_shmem + 16 * i +\n                            weights_col * (in_width + INPUT_SKEW),\n                    in_width + INPUT_SKEW);\n            wmma::mma_sync(result_frag[l], act_frag, weights_frag,\n                           result_frag[l]);\n        }\n\n        __syncthreads();\n\n        warp_activation<__half>(activation, result_frag[l], result_frag[l]);\n    }\n\n#pragma unroll\n    for (int l = 0; l < N_ITERS; ++l) {\n        wmma::store_matrix_sync(\n                act_shmem + weights_col +\n                        (16 * (threadIdx.z + l * BLOCK_DIM_Z)) * (WIDTH + SKEW),\n                result_frag[l], WIDTH + SKEW, wmma::mem_row_major);\n    }\n\n    if (out_intermediate_threadblock_this_layer != nullptr) {\n        __syncthreads();\n\n#pragma unroll\n        for (int i = 0; i < N_ITERS; ++i) {\n            *(int4 *)&out_intermediate_threadblock_this_layer\n                    [lane_offset +\n                     (row + 16 * (threadIdx.z + i * BLOCK_DIM_Z)) * WIDTH] =\n                    *(int4 *)&act_shmem[lane_offset +\n                                        (row +\n                                         16 * (threadIdx.z + i * BLOCK_DIM_Z)) *\n                                                (WIDTH + SKEW)];\n        }\n    }\n}\n\ntemplate <int WIDTH, int BLOCK_DIM_Z, int N_ITERS, typename OUT_T>\n__device__ void threadblock_last_layer_forward(\n        Activation activation,\n        __half *__restrict__ act_shmem,\n        const __half *__restrict__ weights_this_layer,\n        OUT_T *__restrict__ out,\n        const uint32_t batch_size,\n        const nvcuda::wmma::layout_t output_layout) {\n    // act_shmem contains the intermediate activations (shared memory) of the\n    // thread block's chunk of the batch weights_this_layer points to the weight\n    // matrix of the current layer out points to the location where the result\n    // produced by the thread block should be written to.\n    //   Can be nullptr if nothing should be written.\n\n    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;\n    constexpr uint32_t N_BLOCKS = WIDTH / 16;\n\n    using namespace nvcuda;\n\n    // Fragments\n    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major>\n            act_frag;\n    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major>\n            weights_frag[N_BLOCKS];\n    wmma::fragment<wmma::accumulator, 16, 16, 16, OUT_T> result_frag;\n\n    // Indices\n    const uint32_t li = threadIdx.x;  // index in warp (\"lane index\")\n    const uint32_t wi = threadIdx.y;  // index in block (\"warp index\")\n\n    __half *__restrict__ weights_shmem =\n            act_shmem + N_ITERS * BLOCK_DIM_Z * 16 * (WIDTH + SKEW);\n\n    const uint32_t weights_row = (8 * li) % WIDTH;\n    const uint32_t weights_col = (8 * li + 8 * 32 * wi) / WIDTH;\n\n    // Load weight matrix into shared memory for the last multiplication.\n    // Loading into shared memory as opposed to directly into registers is\n    // faster because unlike in the previous layers, each warp uses the same\n    // entries of the weight matrix.\n    if (threadIdx.z == 0) {\n        *(int4 *)&weights_shmem[weights_row + weights_col * (WIDTH + SKEW)] =\n                *(int4 *)&weights_this_layer[weights_row + weights_col * WIDTH];\n        // printf(\"[last forward] base=%d, shmem=%d, weight=%d\\n\", N_ITERS *\n        // BLOCK_DIM_Z * 16 * (WIDTH + SKEW), weights_row + weights_col * (WIDTH\n        // + SKEW), weights_row + weights_col * WIDTH);\n    }\n\n    __syncthreads();\n\n#pragma unroll\n    for (uint32_t i = 0; i < N_BLOCKS; ++i)\n        wmma::load_matrix_sync(weights_frag[i], weights_shmem + 16 * i,\n                               WIDTH + SKEW);\n\n    // Perform last layer by parallelizing over iters\n    for (uint32_t idx = wi; idx < N_ITERS; idx += N_BLOCKS) {\n        wmma::fill_fragment(result_frag, 0.0f);\n\n#pragma unroll\n        for (uint32_t i = 0; i < N_BLOCKS; ++i) {\n            // Load a chunk of intermediate activations from shared memory and\n            // multiply with chunk of the weight matrix\n            wmma::load_matrix_sync(\n                    act_frag,\n                    act_shmem + 16 * i +\n                            (16 * (threadIdx.z + idx * BLOCK_DIM_Z)) *\n                                    (WIDTH + SKEW),\n                    WIDTH + SKEW);\n            wmma::mma_sync(result_frag, act_frag, weights_frag[i], result_frag);\n        }\n\n        warp_activation<__half>(activation, result_frag, result_frag);\n\n        if (output_layout == wmma::mem_row_major) {\n            wmma::store_matrix_sync(\n                    out + (threadIdx.z + idx * BLOCK_DIM_Z) * 16 * 16,\n                    result_frag, 16, output_layout);\n            // printf(\"[last forward] RM write out %d, batch %d\\n\", (threadIdx.z\n            // + idx\n            // * BLOCK_DIM_Z) * 16 * 16, 16);\n        } else {\n            wmma::store_matrix_sync(\n                    out + (threadIdx.z + idx * BLOCK_DIM_Z) * 16, result_frag,\n                    batch_size, output_layout);\n            // printf(\"[last forward] CM write out %d, batch %d\\n\", (threadIdx.z\n            // + idx\n            // * BLOCK_DIM_Z) * 16, batch_size);\n        }\n    }\n}\n\ntemplate <int WIDTH, int BLOCK_DIM_Z, int N_ITERS>\n__device__ void threadblock_write_output_static(\n        const __half *__restrict__ act_shmem,\n        __half *__restrict__ output_threadblock) {\n    // output_threadblock will be filled by the thread block's act_shmem\n\n    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;\n\n    // Indices\n    const uint32_t li = threadIdx.x;  // index in warp (\"lane index\")\n    const uint32_t wi = threadIdx.y;  // index in block (\"warp index\")\n\n    const uint32_t lane_offset = (8 * li) % WIDTH;\n    const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;\n\n    __syncthreads();\n\n#pragma unroll\n    for (int i = 0; i < N_ITERS; ++i) {\n        *(int4 *)&output_threadblock[lane_offset +\n                                     (row +\n                                      16 * (threadIdx.z + i * BLOCK_DIM_Z)) *\n                                             WIDTH] =\n                *(int4 *)&act_shmem[lane_offset +\n                                    (row +\n                                     16 * (threadIdx.z + i * BLOCK_DIM_Z)) *\n                                            (WIDTH + SKEW)];\n    }\n}\n\n///////////////////////////////////////////////////////////////////////////////////////////////\n///////////////////////////////////////////////////////////////////////////////////////////////\n///////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <int WIDTH,\n          int BLOCK_DIM_Z,\n          int N_ITERS,\n          typename OUT_T,\n          bool INFERENCE>\n__global__ void kernel_mlp_fused(const Activation activation,\n                                 const Activation output_activation,\n                                 const __half *__restrict__ input,\n                                 const __half *__restrict__ weights,\n                                 OUT_T *__restrict__ out_intermediate,\n                                 OUT_T *__restrict__ out,\n                                 const uint32_t batch_size,\n                                 const uint32_t in_width,\n                                 const uint32_t out_width,\n                                 const uint32_t n_hidden_matmuls,\n                                 const nvcuda::wmma::layout_t output_layout =\n                                         nvcuda::wmma::mem_row_major) {\n    // `input` points to the input matrix. Can be any width.\n    // `weights` points to the weight matrices (contiguous in memory).\n    // `out_intermediate` points to the memory where intermediate activations\n    // should be written. When performing inference, a value of nullptr is\n    // expected (intermediate results are not written). `out` points to the\n    // memory where the network output should be written. (Output width is\n    // assumed to be 16 neurons.)\n\n    // if (threadIdx.x == 0) printf(\"[forward] call kernel_mlp_fused\\n\");\n    // if (threadIdx.x == 0) printf(\"[forward] inputs=%f\\n\", (float)input[0]);\n    // if (threadIdx.x == 0) printf(\"[forward] weights=%f\\n\",\n    // (float)weights[0]);\n\n    // if (threadIdx.x == 0) printf(\"[forward] forward_buffer=%f\\n\",\n    // (float)out_intermediate[0]);\n\n    // Shared memory contains the intermediate activations of blockDim.y*16\n    // elements. In some cases, it also contains the weight matrix for the first\n    // and last layer.\n    extern __shared__ __half shmem[];\n    __half *act_shmem = shmem;\n\n    // Each block computes exactly one 16-element chunk of the batch.\n    const uint32_t elem_idx = 16 * blockIdx.x * N_ITERS * BLOCK_DIM_Z;\n\n    // First layer\n    if (in_width == WIDTH) {\n        // If the input has the same width as the network, we can simply use the\n        // network's regular layer routine (with static size) instead of using\n        // the slower dynamic input layer routine.\n        threadblock_load_input_static<WIDTH, BLOCK_DIM_Z, N_ITERS>(\n                act_shmem, input + elem_idx * WIDTH);\n        threadblock_layer<WIDTH, BLOCK_DIM_Z, N_ITERS, OUT_T>(\n                activation, act_shmem, weights,\n                !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr);\n    } else {\n        threadblock_input_layer_forward_dynamic<WIDTH, BLOCK_DIM_Z, N_ITERS,\n                                                OUT_T>(\n                activation, act_shmem, input + elem_idx * in_width, weights,\n                !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr,\n                in_width);\n    }\n\n    // if (threadIdx.x == 0) printf(\"[forward] kernel_mlp_fused: passed first\n    // layer\\n\");\n    // if (threadIdx.x == 0) printf(\"[forward] forward_buffer=%f\\n\",\n    // (float)out_intermediate[0]);\n\n    const uint32_t first_layer_size = WIDTH * in_width;\n    const uint32_t layer_stride = WIDTH * WIDTH;\n    const uint32_t output_stride = WIDTH * batch_size;\n\n    // Hidden layers\n    for (uint32_t k = 0; k < n_hidden_matmuls; ++k) {\n        threadblock_layer<WIDTH, BLOCK_DIM_Z, N_ITERS, OUT_T>(\n                activation, act_shmem,\n                weights + first_layer_size + layer_stride * k,\n                !INFERENCE ? (out_intermediate + output_stride * (k + 1) +\n                              elem_idx * WIDTH)\n                           : nullptr);\n        // if (threadIdx.x == 0) printf(\"[forward] kernel_mlp_fused: passed %d\n        // layer\\n\", k + 1);\n        // if (threadIdx.x == 0) printf(\"[forward] forward_buffer=%f\\n\",\n        // (float)out_intermediate[0]);\n    }\n\n    if (out_width > 16) {\n        // In the forward pass, intermediate activations are already written\n        // out.\n        if (INFERENCE) {\n            threadblock_write_output_static<WIDTH, BLOCK_DIM_Z, N_ITERS>(\n                    act_shmem, out_intermediate + elem_idx * WIDTH);\n        }\n    } else if (out) {\n        // Last layer\n        if (output_layout == nvcuda::wmma::mem_row_major) {\n            // printf(\"[last layer] RM write to out %d\\n\", elem_idx * 16);\n            // if (threadIdx.x == 0) printf(\"[forward] forward_buffer=%f\\n\",\n            // (float)out_intermediate[0]);\n            threadblock_last_layer_forward<WIDTH, BLOCK_DIM_Z, N_ITERS, OUT_T>(\n                    output_activation, act_shmem,\n                    weights + first_layer_size +\n                            layer_stride * n_hidden_matmuls,\n                    out + elem_idx * 16, 16, output_layout);\n            // if (threadIdx.x == 0) printf(\"[forward] forward_buffer=%f\\n\",\n            // (float)out_intermediate[0]);\n        } else {\n            // printf(\"[last layer] CM write to out %d\\n\", elem_idx);\n            // if (threadIdx.x == 0) printf(\"[forward] forward_buffer=%f\\n\",\n            // (float)out_intermediate[0]);\n            threadblock_last_layer_forward<WIDTH, BLOCK_DIM_Z, N_ITERS, OUT_T>(\n                    output_activation, act_shmem,\n                    weights + first_layer_size +\n                            layer_stride * n_hidden_matmuls,\n                    out + elem_idx, batch_size, output_layout);\n            // if (threadIdx.x == 0) printf(\"[forward] forward_buffer=%f\\n\",\n            // (float)out_intermediate[0]);\n        }\n    }\n}\n\ntemplate <int WIDTH, int BLOCK_DIM_Z, int N_ITERS, typename OUTPUT_LAYOUT>\n__global__ void kernel_mlp_fused_backward(\n        const Activation activation,\n        const __half *__restrict__ dL_doutput,\n        const __half *__restrict__ weights,\n        __half *__restrict__ out_intermediate,\n        const __half *__restrict__ forward,\n        __half *__restrict__ dL_dinput,\n        const __half *__restrict__ weights_first_layer,\n        const uint32_t batch_size,\n        const uint32_t out_width,\n        const uint32_t n_hidden_matmuls) {\n    // `dL_doutput` points to the input matrix of the backward pass, i.e. the\n    // loss gradients. Assumed to be 16 neurons wide. `weights` points to the\n    // weight matrices (contiguous in memory). `out_intermediate` points to the\n    // memory where backpropagated activation gradients should be written.\n    // `forward` points to the memory where the intermediate activations of the\n    // forward pass are located. (needed for activation backprop)\n\n    constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0;\n\n    // Indices\n    const uint32_t li = threadIdx.x;  // index in warp (\"lane index\")\n    const uint32_t wi = threadIdx.y;  // index in block (\"warp index\")\n    const uint32_t bi = blockIdx.x;   // block index\n\n    // Shared memory contains the intermediate activations of blockDim.y*16\n    // elements. A skew is applied to the matrix storage to avoid bank\n    // conflicts.\n    extern __shared__ __half shmem[];\n    __half *act_shmem = shmem;\n\n    const uint32_t lane_offset = (8 * li) % WIDTH;\n    const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH;\n\n    // Multipying one 16-row chunk of intermediate activations with the weight\n    // matrix requires all warps of the block. Thus, each block computes exactly\n    // one 16-row chunk of the next layer's intermediate activations.\n    const uint32_t elem_idx_base = 16 * bi * N_ITERS * BLOCK_DIM_Z;\n    const uint32_t elem_idx = elem_idx_base + 16 * threadIdx.z;\n\n    const uint32_t layer_stride = WIDTH * WIDTH;\n    const uint32_t output_stride = WIDTH * batch_size;\n\n    // Backprop through last layer\n    if (out_width <= 16) {\n        using namespace nvcuda;\n\n        // Fragments in registers\n        wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, OUTPUT_LAYOUT>\n                act_frag;\n        wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major>\n                weights_frag;\n        wmma::fragment<wmma::accumulator, 16, 16, 16, __half>\n                result_frag[N_ITERS];\n\n        // Load the relevant chunk of the last layer's weight matrix from global\n        // memory into registers\n        const uint32_t weights_col = 16 * wi;\n\n        wmma::load_matrix_sync(\n                weights_frag,\n                weights + layer_stride * n_hidden_matmuls + weights_col, WIDTH);\n\n#pragma unroll\n        for (int l = 0; l < N_ITERS; ++l) {\n            wmma::fill_fragment(result_frag[l], 0.0f);\n\n            // Load a chunk of output gradients from shared memory and multiply\n            // with previously loaded weights\n            if (std::is_same<OUTPUT_LAYOUT, wmma::row_major>::value) {\n                wmma::load_matrix_sync(\n                        act_frag,\n                        dL_doutput + (elem_idx +\n                                      16 * (threadIdx.z + l * BLOCK_DIM_Z)) *\n                                             16,\n                        16);\n            } else {\n                wmma::load_matrix_sync(\n                        act_frag,\n                        dL_doutput + (elem_idx +\n                                      16 * (threadIdx.z + l * BLOCK_DIM_Z)),\n                        batch_size);\n            }\n\n            // NOTE: activation transfer of the _output_ activation is expected\n            // to be done _prior_ to calling this kernel\n            //       in a separate pass, because the tranfered activation\n            //       gradient is also needed to compute the weight gradient of\n            //       the last weight matrix (see backward()).\n            wmma::mma_sync(result_frag[l], act_frag, weights_frag,\n                           result_frag[l]);\n\n            // Load the temporary forward matrix for the relu transfer\n            wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major>\n                    forward_frag;\n            wmma::load_matrix_sync(\n                    forward_frag,\n                    forward + output_stride * n_hidden_matmuls + weights_col +\n                            (elem_idx + l * BLOCK_DIM_Z * 16) * WIDTH,\n                    WIDTH);\n\n            warp_activation_backward<__half>(activation, result_frag[l],\n                                             forward_frag, result_frag[l]);\n        }\n\n        __syncthreads();\n\n#pragma unroll\n        for (int l = 0; l < N_ITERS; ++l) {\n            wmma::store_matrix_sync(\n                    act_shmem + weights_col +\n                            (16 * (threadIdx.z + l * BLOCK_DIM_Z)) *\n                                    (WIDTH + SKEW),\n                    result_frag[l], WIDTH + SKEW, wmma::mem_row_major);\n        }\n\n        __syncthreads();\n\n#pragma unroll\n        for (int i = 0; i < N_ITERS; ++i) {\n            *(int4 *)&out_intermediate[lane_offset +\n                                       (row + elem_idx + i * BLOCK_DIM_Z * 16) *\n                                               WIDTH] =\n                    *(int4 *)&act_shmem[lane_offset +\n                                        (row +\n                                         16 * (threadIdx.z + i * BLOCK_DIM_Z)) *\n                                                (WIDTH + SKEW)];\n        }\n    } else {\n        // If the output width is larger than 16, we will have used CUTLASS for\n        // backpropping through the last layer. Load the resulting gradients.\n        threadblock_load_input_static<WIDTH, BLOCK_DIM_Z, N_ITERS>(\n                act_shmem, out_intermediate + elem_idx * WIDTH);\n    }\n\n    // Backprop through hidden layers\n    for (uint32_t k = 0; k < n_hidden_matmuls; ++k) {\n        threadblock_layer<WIDTH, BLOCK_DIM_Z, N_ITERS, __half, true>(\n                activation, act_shmem,\n                weights + layer_stride * (n_hidden_matmuls - k - 1),\n                out_intermediate + output_stride * (k + 1) +\n                        elem_idx_base * WIDTH,\n                forward + output_stride * (n_hidden_matmuls - k - 1) +\n                        elem_idx_base * WIDTH);\n    }\n\n    // Compute loss gradients w.r.t. input if desired.\n    // THIS CODE ASSUMES THAT THE INPUT WIDTH IS THE SAME AS THE NETWORK WIDTH.\n    // DON'T PASS A NON-NULL dL_dinput IF THIS REQUIREMENT IS NOT MET.\n    if (dL_dinput != nullptr) {\n        threadblock_layer<WIDTH, BLOCK_DIM_Z, N_ITERS, __half, true>(\n                Activation::None, act_shmem, weights_first_layer,\n                dL_dinput + elem_idx_base * WIDTH);\n    }\n}\n\n//////////////////////////////////////////////////////////////////////////////////////////////\n//////////////////////////////////////////////////////////////////////////////////////////////\n//////////////////////////////////////////////////////////////////////////////////////////////\n\ntemplate <uint32_t WIDTH, bool INFERENCE>  // WIDTH is hidden_dim\nvoid ffmlp_forward_cuda(const __half *inputs,\n                        const __half *weights,\n                        const uint32_t B,\n                        const uint32_t input_dim,\n                        const uint32_t output_dim,\n                        const uint32_t num_layers,\n                        const Activation activation,\n                        const Activation output_activation,\n                        __half *forward_buffer,\n                        __half *outputs) {\n    constexpr uint32_t SKEW =\n            WIDTH % 16 == 0 ? 8 : 0;  // <- always going to be 8 as we only\n                                      // support multiple-of-16 widths\n    constexpr uint32_t INPUT_SKEW = 8;  // <- likewise with inputs\n    constexpr uint32_t N_BLOCK_ROWS = WIDTH / 16;\n\n    const int N_ITERS = WIDTH >= 256 ? 2 : 8;\n    const uint32_t BLOCK_DIM_Z = (INFERENCE && WIDTH == 128) ? 2 : 1;\n\n    const dim3 threads = {\n            32u, N_BLOCK_ROWS,\n            BLOCK_DIM_Z};  // 32 threads = 1 warp, N_BLOCK_ROWS warps\n                           // per block for 16 rows, up to 2x 8 warps\n                           // can share input (does not help vs. 1)\n\n    uint32_t n_elems_per_block = 16 * BLOCK_DIM_Z * N_ITERS;\n    uint32_t n_blocks = div_round_up(B, n_elems_per_block);\n\n    size_t shmem_size =\n            sizeof(__half) * (16 + 16 * BLOCK_DIM_Z * N_ITERS) *\n            (WIDTH +\n             SKEW);  // 16*WIDTH rows of weights (for the last layer; others\n                     // are in registers only) + 16*WIDTH*BLOCK_DIM_Z*N_ITERS\n                     // rows of intermediate activations\n\n    // If the input width is dynamic, the input weight matrix as well as part of\n    // the input will live in extra shared memory\n    if (input_dim != WIDTH) {\n        shmem_size = std::max(shmem_size, sizeof(__half) *\n                                                  (WIDTH + 16 * BLOCK_DIM_Z) *\n                                                  (input_dim + INPUT_SKEW));\n    }\n\n    // printf(\"[ffmlp_forward_cuda] shmem size = %d\\n\", shmem_size);\n\n    const dim3 blocks = {n_blocks, 1u, 1u};\n\n    check_shmem_error(cudaFuncSetAttribute(\n            kernel_mlp_fused<WIDTH, BLOCK_DIM_Z, N_ITERS, __half, INFERENCE>,\n            cudaFuncAttributeMaxDynamicSharedMemorySize, (int)shmem_size));\n\n    kernel_mlp_fused<WIDTH, BLOCK_DIM_Z, N_ITERS, __half, INFERENCE>\n            <<<blocks, threads, shmem_size, 0>>>(\n                    activation, output_activation,\n                    inputs,          // CM\n                    weights,         // RM\n                    forward_buffer,  // CM\n                    outputs,         // CM\n                    B, input_dim, output_dim, num_layers - 1,\n                    nvcuda::wmma::mem_row_major  // reversed outputs's layout\n            );\n}\n\ntemplate <uint32_t WIDTH>  // WIDTH is hidden_dim\nvoid ffmlp_backward_cuda(const __half *grad,\n                         const __half *weights,\n                         const uint32_t B,\n                         const uint32_t input_dim,\n                         const uint32_t output_dim,\n                         const uint32_t num_layers,\n                         const Activation activation,\n                         const __half *forward_buffer,\n                         __half *backward_buffer,\n                         __half *grad_inputs) {\n    // locate\n    const __half *weights_first = weights;\n    const __half *weights_second = weights + input_dim * WIDTH;\n\n    constexpr uint32_t SKEW =\n            WIDTH % 16 == 0 ? 8 : 0;  // <- always going to be 8 as we only\n                                      // support multiple-of-16 widths\n    constexpr uint32_t N_BLOCKS = WIDTH / 16;\n\n    const int N_ITERS = WIDTH >= 256 ? 2 : 8;\n    const uint32_t BLOCK_DIM_Z = 1;\n\n    const dim3 threads = {\n            32u, N_BLOCKS,\n            BLOCK_DIM_Z};  // 32 threads = 1 warp, N_BLOCK_ROWS warps\n                           // per block for 16 rows, up to 2x 8 warps\n                           // can share input (does not help vs. 1)\n\n    uint32_t n_elems_per_block = 16 * BLOCK_DIM_Z * N_ITERS;\n    uint32_t n_blocks = div_round_up(B, n_elems_per_block);\n\n    size_t shmem_size =\n            sizeof(__half) *\n            ((16 * BLOCK_DIM_Z * N_ITERS) *\n             (WIDTH +\n              SKEW));  // WIDTH rows of input and 16 * threads.z rows of weights\n\n    const dim3 blocks = {n_blocks, 1u, 1u};\n\n    // The kernels operate with transposed layouts compared with the MLP code\n    check_shmem_error(cudaFuncSetAttribute(\n            kernel_mlp_fused_backward<WIDTH, BLOCK_DIM_Z, N_ITERS,\n                                      nvcuda::wmma::row_major>,\n            cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));\n\n    kernel_mlp_fused_backward<WIDTH, BLOCK_DIM_Z, N_ITERS,\n                              nvcuda::wmma::row_major>\n            <<<blocks, threads, shmem_size, 0>>>(activation,\n                                                 grad,             // CM\n                                                 weights_second,   // RM\n                                                 backward_buffer,  // CM\n                                                 forward_buffer,   // CM\n                                                 grad_inputs,      // CM\n                                                 weights_first,    // RM\n                                                 B, output_dim, num_layers - 1);\n}\n\n// inputs: col-major [input_dim, B]\n// weights: row-major [hidden_dim * input_dim] + [hidden_dim * hidden_dim *\n// (num_layers - 1)] + [output_dim * hidden_dim] forward_buffer: col-major\n// [num_layers, hidden_dim, B] outputs: col-major [output_dim, B]\nvoid ffmlp_forward(const at::Tensor inputs,\n                   const at::Tensor weights,\n                   const uint32_t B,\n                   const uint32_t input_dim,\n                   const uint32_t output_dim,\n                   const uint32_t hidden_dim,\n                   const uint32_t num_layers,\n                   const uint32_t activation_,\n                   const uint32_t output_activation_,\n                   at::Tensor forward_buffer,\n                   at::Tensor outputs) {\n    CHECK_CUDA(inputs);\n    CHECK_CONTIGUOUS(inputs);\n    CHECK_IS_HALF(inputs);\n\n    CHECK_CUDA(weights);\n    CHECK_CONTIGUOUS(weights);\n    CHECK_IS_HALF(weights);\n\n    Activation activation = convert_activation(activation_);\n    Activation output_activation = convert_activation(output_activation_);\n\n    auto inputs_ptr = reinterpret_cast<__half *>(inputs.data_ptr<at::Half>());\n    auto weights_ptr = reinterpret_cast<__half *>(weights.data_ptr<at::Half>());\n    auto forward_buffer_ptr =\n            reinterpret_cast<__half *>(forward_buffer.data_ptr<at::Half>());\n    auto outputs_ptr = reinterpret_cast<__half *>(outputs.data_ptr<at::Half>());\n\n    switch (hidden_dim) {\n        case 16:\n            ffmlp_forward_cuda<16, false>(inputs_ptr, weights_ptr, B, input_dim,\n                                          output_dim, num_layers, activation,\n                                          output_activation, forward_buffer_ptr,\n                                          outputs_ptr);\n            break;\n        case 32:\n            ffmlp_forward_cuda<32, false>(inputs_ptr, weights_ptr, B, input_dim,\n                                          output_dim, num_layers, activation,\n                                          output_activation, forward_buffer_ptr,\n                                          outputs_ptr);\n            break;\n        case 64:\n            ffmlp_forward_cuda<64, false>(inputs_ptr, weights_ptr, B, input_dim,\n                                          output_dim, num_layers, activation,\n                                          output_activation, forward_buffer_ptr,\n                                          outputs_ptr);\n            break;\n        case 128:\n            ffmlp_forward_cuda<128, false>(inputs_ptr, weights_ptr, B,\n                                           input_dim, output_dim, num_layers,\n                                           activation, output_activation,\n                                           forward_buffer_ptr, outputs_ptr);\n            break;\n        case 256:\n            ffmlp_forward_cuda<256, false>(inputs_ptr, weights_ptr, B,\n                                           input_dim, output_dim, num_layers,\n                                           activation, output_activation,\n                                           forward_buffer_ptr, outputs_ptr);\n            break;\n        default:\n            throw std::runtime_error{\n                    \"hidden_dim should in [16, 32, 64, 128, 256]\"};\n    }\n\n    // for output_dim > 16\n    if (output_dim > 16) {\n        fc_multiply<LastLayer, true, false, false>(\n                0, output_dim, hidden_dim, B,\n                (weights_ptr + hidden_dim * input_dim +\n                 (num_layers - 1) * hidden_dim *\n                         hidden_dim),  // row-major, [output_dim, hidden_dim]\n                (forward_buffer_ptr + (num_layers - 1) * hidden_dim *\n                                              B),  // col-major [hidden_dim, B]\n                outputs_ptr,                       // col-major [outupt_dim, B]\n                output_activation);\n    }\n}\n\nvoid ffmlp_inference(const at::Tensor inputs,\n                     const at::Tensor weights,\n                     const uint32_t B,\n                     const uint32_t input_dim,\n                     const uint32_t output_dim,\n                     const uint32_t hidden_dim,\n                     const uint32_t num_layers,\n                     const uint32_t activation_,\n                     const uint32_t output_activation_,\n                     at::Tensor inference_buffer,\n                     at::Tensor outputs) {\n    CHECK_CUDA(inputs);\n    CHECK_CONTIGUOUS(inputs);\n    CHECK_IS_HALF(inputs);\n\n    CHECK_CUDA(weights);\n    CHECK_CONTIGUOUS(weights);\n    CHECK_IS_HALF(weights);\n\n    Activation activation = convert_activation(activation_);\n    Activation output_activation = convert_activation(output_activation_);\n\n    auto inputs_ptr = reinterpret_cast<__half *>(inputs.data_ptr<at::Half>());\n    auto weights_ptr = reinterpret_cast<__half *>(weights.data_ptr<at::Half>());\n    auto inference_buffer_ptr =\n            reinterpret_cast<__half *>(inference_buffer.data_ptr<at::Half>());\n    auto outputs_ptr = reinterpret_cast<__half *>(outputs.data_ptr<at::Half>());\n\n    switch (hidden_dim) {\n        case 16:\n            ffmlp_forward_cuda<16, true>(inputs_ptr, weights_ptr, B, input_dim,\n                                         output_dim, num_layers, activation,\n                                         output_activation,\n                                         inference_buffer_ptr, outputs_ptr);\n            break;\n        case 32:\n            ffmlp_forward_cuda<32, true>(inputs_ptr, weights_ptr, B, input_dim,\n                                         output_dim, num_layers, activation,\n                                         output_activation,\n                                         inference_buffer_ptr, outputs_ptr);\n            break;\n        case 64:\n            ffmlp_forward_cuda<64, true>(inputs_ptr, weights_ptr, B, input_dim,\n                                         output_dim, num_layers, activation,\n                                         output_activation,\n                                         inference_buffer_ptr, outputs_ptr);\n            break;\n        case 128:\n            ffmlp_forward_cuda<128, true>(inputs_ptr, weights_ptr, B, input_dim,\n                                          output_dim, num_layers, activation,\n                                          output_activation,\n                                          inference_buffer_ptr, outputs_ptr);\n            break;\n        case 256:\n            ffmlp_forward_cuda<256, true>(inputs_ptr, weights_ptr, B, input_dim,\n                                          output_dim, num_layers, activation,\n                                          output_activation,\n                                          inference_buffer_ptr, outputs_ptr);\n            break;\n        default:\n            throw std::runtime_error{\n                    \"hidden_dim should in [16, 32, 64, 128, 256]\"};\n    }\n\n    // for output_dim > 16\n    if (output_dim > 16) {\n        fc_multiply<LastLayer, true, false, false>(\n                0, output_dim, hidden_dim, B,\n                (weights_ptr + hidden_dim * input_dim +\n                 (num_layers - 1) * hidden_dim *\n                         hidden_dim),  // row-major, [output_dim, hidden_dim]\n                inference_buffer_ptr,  // col-major [hidden_dim, B]\n                outputs_ptr,           // col-major [outupt_dim, B]\n                output_activation);\n    }\n}\n\ninline std::vector<cudaStream_t> &streams_splitk() {\n    static std::vector<cudaStream_t> res;\n    return res;\n}\n\ninline std::vector<cudaEvent_t> &events_splitk() {\n    static std::vector<cudaEvent_t> res;\n    return res;\n}\n\nvoid allocate_splitk(size_t size) {\n    auto &streams = streams_splitk();\n    auto &events = events_splitk();\n    streams.resize(size);\n    events.resize(size);\n    for (size_t i = 0; i < size; i++) {\n        CUDA_CHECK_THROW(cudaStreamCreate(&streams[i]));\n        CUDA_CHECK_THROW(cudaEventCreate(&events[i]));\n    }\n}\n\nvoid free_splitk() {\n    auto &streams = streams_splitk();\n    auto &events = events_splitk();\n    for (size_t i = 0; i < streams.size(); i++) {\n        cutlass_free_workspace(streams[i]);\n        CUDA_CHECK_PRINT(cudaStreamDestroy(streams[i]));\n        CUDA_CHECK_PRINT(cudaEventDestroy(events[i]));\n    }\n}\n\n// grad: col-major [output_dim, B]\n// inputs: col-major [input_dim, B]\n// weights: row-major [hidden_dim * input_dim] + [hidden_dim * hidden_dim *\n// (num_layers - 1)] + [output_dim * hidden_dim] forward_buffer: col-major\n// [num_layers, hidden_dim, B] backward_buffer: col-major [num_layers,\n// hidden_dim, B] grad_inputs: col-major [input_dim, B] grad_weights: row-major\n// [hidden_dim * input_dim] + [hidden_dim * hidden_dim * (num_layers - 1)] +\n// [output_dim * hidden_dim]\nvoid ffmlp_backward(const at::Tensor grad,\n                    const at::Tensor inputs,\n                    const at::Tensor weights,\n                    const at::Tensor forward_buffer,\n                    const uint32_t B,\n                    const uint32_t input_dim,\n                    const uint32_t output_dim,\n                    const uint32_t hidden_dim,\n                    const uint32_t num_layers,\n                    const uint32_t activation_,\n                    const uint32_t output_activation_,\n                    const bool calc_grad_inputs,\n                    at::Tensor backward_buffer,\n                    at::Tensor grad_inputs,\n                    at::Tensor grad_weights) {\n    CHECK_CUDA(grad);\n    CHECK_CONTIGUOUS(grad);\n    CHECK_IS_HALF(grad);\n\n    CHECK_CUDA(inputs);\n    CHECK_CONTIGUOUS(inputs);\n    CHECK_IS_HALF(inputs);\n\n    CHECK_CUDA(weights);\n    CHECK_CONTIGUOUS(weights);\n    CHECK_IS_HALF(weights);\n\n    CHECK_CUDA(forward_buffer);\n    CHECK_CONTIGUOUS(forward_buffer);\n    CHECK_IS_HALF(forward_buffer);\n\n    CHECK_CUDA(backward_buffer);\n    CHECK_CONTIGUOUS(backward_buffer);\n    CHECK_IS_HALF(backward_buffer);\n\n    CHECK_CUDA(grad_weights);\n    CHECK_CONTIGUOUS(grad_weights);\n    CHECK_IS_HALF(grad_weights);\n\n    CHECK_CUDA(grad_inputs);\n    CHECK_CONTIGUOUS(grad_inputs);\n    CHECK_IS_HALF(grad_inputs);\n\n    Activation activation = convert_activation(activation_);\n    Activation output_activation = convert_activation(output_activation_);\n\n    // activation_backward_output_gpu (I gonna discard output_activation ...)\n\n    int split_k_factor = B / std::min((uint32_t)(1 << 12), B);\n\n    uint32_t forward_index = num_layers - 1;\n    uint32_t backward_index = 0;\n\n    auto backward_buffer_ptr =\n            reinterpret_cast<__half *>(backward_buffer.data_ptr<at::Half>());\n    auto forward_buffer_ptr =\n            reinterpret_cast<__half *>(forward_buffer.data_ptr<at::Half>());\n    auto grad_ptr = reinterpret_cast<__half *>(grad.data_ptr<at::Half>());\n    auto inputs_ptr = reinterpret_cast<__half *>(inputs.data_ptr<at::Half>());\n    auto weights_ptr = reinterpret_cast<__half *>(weights.data_ptr<at::Half>());\n    auto grad_weights_ptr =\n            reinterpret_cast<__half *>(grad_weights.data_ptr<at::Half>());\n\n    auto grad_inputs_ptr = calc_grad_inputs\n                                   ? reinterpret_cast<__half *>(\n                                             grad_inputs.data_ptr<at::Half>())\n                                   : nullptr;\n    auto grad_inputs_fused_ptr =\n            input_dim == hidden_dim ? grad_inputs_ptr : nullptr;\n\n    // calc output layer, grad_weights\n    cudaEventRecord(events_splitk().at(backward_index), 0);\n    cudaStreamWaitEvent(streams_splitk().at(backward_index),\n                        events_splitk().at(backward_index), 0);\n\n    fc_multiply_split_k<LastLayerK, false, true, true>(\n            streams_splitk().at(backward_index), output_dim, B, hidden_dim,\n            grad_ptr,  // col-major, [output_dim, B]\n            (forward_buffer_ptr +\n             forward_index * hidden_dim * B),  // row-major, [B, hidden_dim]\n            (grad_weights_ptr + hidden_dim * input_dim +\n             (num_layers - 1) * hidden_dim *\n                     hidden_dim),  // row-major, [output_dim, hidden_dim]\n            split_k_factor);\n\n    cudaEventRecord(events_splitk().at(backward_index),\n                    streams_splitk().at(backward_index));\n\n    // prepare the last backward_buffer if output_dim > 16\n    if (output_dim > 16) {\n        fc_multiply<FullLayer, false, false, false>(\n                0, hidden_dim, output_dim, B,\n                (grad_weights_ptr + hidden_dim * input_dim +\n                 (num_layers - 1) * hidden_dim *\n                         hidden_dim),  // col-major, [hidden_dim, output_dim]\n                grad_ptr,              // col-major, [output_dim, B]\n                (forward_buffer_ptr +\n                 forward_index * hidden_dim * B),  // col-major, [hidden_dim, B]\n                (backward_buffer_ptr +\n                 backward_index * hidden_dim * B),  // col-major [hidden_dim, B]\n                activation, true);\n    }\n\n    // prepare backward_buffer\n    // calc grad_inputs if input_dim == hidden_dim\n    switch (hidden_dim) {\n        case 16:\n            ffmlp_backward_cuda<16>(grad_ptr, weights_ptr, B, input_dim,\n                                    output_dim, num_layers, activation,\n                                    forward_buffer_ptr, backward_buffer_ptr,\n                                    grad_inputs_fused_ptr);\n            break;\n        case 32:\n            ffmlp_backward_cuda<32>(grad_ptr, weights_ptr, B, input_dim,\n                                    output_dim, num_layers, activation,\n                                    forward_buffer_ptr, backward_buffer_ptr,\n                                    grad_inputs_fused_ptr);\n            break;\n        case 64:\n            ffmlp_backward_cuda<64>(grad_ptr, weights_ptr, B, input_dim,\n                                    output_dim, num_layers, activation,\n                                    forward_buffer_ptr, backward_buffer_ptr,\n                                    grad_inputs_fused_ptr);\n            break;\n        case 128:\n            ffmlp_backward_cuda<128>(grad_ptr, weights_ptr, B, input_dim,\n                                     output_dim, num_layers, activation,\n                                     forward_buffer_ptr, backward_buffer_ptr,\n                                     grad_inputs_fused_ptr);\n            break;\n        case 256:\n            ffmlp_backward_cuda<256>(grad_ptr, weights_ptr, B, input_dim,\n                                     output_dim, num_layers, activation,\n                                     forward_buffer_ptr, backward_buffer_ptr,\n                                     grad_inputs_fused_ptr);\n            break;\n        default:\n            throw std::runtime_error{\n                    \"hidden_dim should in [16, 32, 64, 128, 256]\"};\n    }\n\n    // printf(\"[backward] finished backward kernel\\n\");\n\n    forward_index--;\n    backward_index++;\n\n    // calc middle layer's grad_weights\n    for (uint32_t i = 0; i < num_layers - 1; i++) {\n        uint32_t matrix_index = num_layers - 2 - i;\n\n        cudaEventRecord(events_splitk().at(backward_index), 0);\n        cudaStreamWaitEvent(streams_splitk().at(backward_index),\n                            events_splitk().at(backward_index), 0);\n\n        fc_multiply_split_k<FullLayerK, false, true, true>(\n                streams_splitk().at(backward_index), hidden_dim, B, hidden_dim,\n                (backward_buffer_ptr + (backward_index - 1) * hidden_dim *\n                                               B),  // col-major [hidden_dim, B]\n                (forward_buffer_ptr +\n                 forward_index * hidden_dim * B),  // row-major [B, hidden_dim]\n                (grad_weights_ptr + hidden_dim * input_dim +\n                 matrix_index * hidden_dim *\n                         hidden_dim),  // row-major, [hidden_dim, hidden_dim]\n                split_k_factor);\n\n        cudaEventRecord(events_splitk().at(backward_index),\n                        streams_splitk().at(backward_index));\n\n        forward_index--;\n        backward_index++;\n    }\n\n    // calc input layer's grad_weights\n    cudaEventRecord(events_splitk().at(backward_index), 0);\n    cudaStreamWaitEvent(streams_splitk().at(backward_index),\n                        events_splitk().at(backward_index), 0);\n\n    fc_multiply_split_k<FullLayerK, false, true, true>(\n            streams_splitk().at(backward_index), hidden_dim, B, input_dim,\n            (backward_buffer_ptr + (backward_index - 1) * hidden_dim *\n                                           B),  // col-major [hidden_dim, B]\n            inputs_ptr,                         // row-major, [B, input_dim]\n            grad_weights_ptr,  // row-major, [hidden_dim, input_dim]\n            split_k_factor);\n\n    cudaEventRecord(events_splitk().at(backward_index),\n                    streams_splitk().at(backward_index));\n\n    // calc grad_inputs if input_dim != hidden_dim\n    if (calc_grad_inputs && grad_inputs_fused_ptr == nullptr) {\n        fc_multiply<FullLayer, false, false, false>(\n                0, input_dim, hidden_dim, B,\n                weights_ptr,  // col-major [input_dim, hidden_dim]\n                (backward_buffer_ptr + (backward_index - 1) * hidden_dim *\n                                               B),  // col-major [hidden_dim, B]\n                grad_inputs_ptr                     // col-major [input_dim, B]\n        );\n    }\n\n    // All the per-layer split-k matrix multiplications summing over\n    // the batch are computed in parallel streams to the actual\n    // backpropagation. Here, we need to wait for all of these to complete.\n    for (auto &event : events_splitk()) {\n        cudaStreamWaitEvent(0, event, 0);\n    }\n}"
  },
  {
    "path": "lidarnerf/ffmlp/src/ffmlp.h",
    "content": "#pragma once\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n// activation: should have been enum, here we just use int.\nvoid ffmlp_forward(const at::Tensor inputs,\n                   const at::Tensor weights,\n                   const uint32_t B,\n                   const uint32_t input_dim,\n                   const uint32_t output_dim,\n                   const uint32_t hidden_dim,\n                   const uint32_t num_layers,\n                   const uint32_t activation_,\n                   const uint32_t output_activation_,\n                   at::Tensor forward_buffer,\n                   at::Tensor outputs);\nvoid ffmlp_inference(const at::Tensor inputs,\n                     const at::Tensor weights,\n                     const uint32_t B,\n                     const uint32_t input_dim,\n                     const uint32_t output_dim,\n                     const uint32_t hidden_dim,\n                     const uint32_t num_layers,\n                     const uint32_t activation_,\n                     const uint32_t output_activation_,\n                     at::Tensor inference_buffer,\n                     at::Tensor outputs);\n\nvoid ffmlp_backward(const at::Tensor grad,\n                    const at::Tensor inputs,\n                    const at::Tensor weights,\n                    const at::Tensor forward_buffer,\n                    const uint32_t B,\n                    const uint32_t input_dim,\n                    const uint32_t output_dim,\n                    const uint32_t hidden_dim,\n                    const uint32_t num_layers,\n                    const uint32_t activation,\n                    const uint32_t output_activation,\n                    const bool calc_grad_inputs,\n                    at::Tensor backward_buffer,\n                    at::Tensor grad_inputs,\n                    at::Tensor grad_weights);\n\nvoid allocate_splitk(size_t size);\nvoid free_splitk();"
  },
  {
    "path": "lidarnerf/ffmlp/src/utils.h",
    "content": "#pragma once\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <array>\n#include <atomic>\n#include <cassert>\n#include <cstdint>\n#include <cstdio>\n#include <iostream>\n#include <sstream>\n#include <stdexcept>\n#include <string>\n#include <vector>\n\n#define CHECK_CUDA(x) \\\n    TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) \\\n    TORCH_CHECK(x.is_contiguous(), #x \" must be a contiguous tensor\")\n#define CHECK_IS_INT(x)                                 \\\n    TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \\\n                #x \" must be an int tensor\")\n#define CHECK_IS_FLOATING(x)                                       \\\n    TORCH_CHECK(x.scalar_type() == at::ScalarType::Float ||        \\\n                        x.scalar_type() == at::ScalarType::Half || \\\n                        x.scalar_type() == at::ScalarType::Double, \\\n                #x \" must be a floating tensor\")\n#define CHECK_IS_HALF(x)                                 \\\n    TORCH_CHECK(x.scalar_type() == at::ScalarType::Half, \\\n                #x \" must be a Half tensor\")\n\nstatic constexpr uint32_t MIN_GPU_ARCH = 70;\n\nusing network_precision_t = __half;\n\nenum class Activation {\n    ReLU,\n    Exponential,\n    Sine,\n    Sigmoid,\n    Squareplus,\n    Softplus,\n    None,\n};\n\nstatic constexpr float PI = 3.14159265358979323846f;\nstatic constexpr float SQRT2 = 1.41421356237309504880f;\nstatic constexpr float K_ACT = 10.0f;\n\n__host__ __device__ inline float logistic(const float x) {\n    return 1.0f / (1.0f + expf(-x));\n}\n\n__host__ __device__ inline float logit(const float x) {\n    return -logf(1.0f / (fminf(fmaxf(x, 1e-9f), 1.0f - 1e-9f)) - 1.0f);\n}\n\ninline std::atomic<size_t> &total_n_bytes_allocated() {\n    static std::atomic<size_t> s_total_n_bytes_allocated{0};\n    return s_total_n_bytes_allocated;\n}\n\n/// Checks the result of a cudaXXXXXX call and throws an error on failure\n#define CUDA_CHECK_THROW(x)                                                \\\n    do {                                                                   \\\n        cudaError_t result = x;                                            \\\n        if (result != cudaSuccess)                                         \\\n            throw std::runtime_error(                                      \\\n                    std::string(\"CUDA Error: \" #x \" failed with error \") + \\\n                    cudaGetErrorString(result));                           \\\n    } while (0)\n\n/// Checks the result of a cudaXXXXXX call and prints an error on failure\n#define CUDA_CHECK_PRINT(x)                                       \\\n    do {                                                          \\\n        cudaError_t result = x;                                   \\\n        if (result != cudaSuccess)                                \\\n            std::cout << \"CUDA Error: \" #x \" failed with error \"  \\\n                      << cudaGetErrorString(result) << std::endl; \\\n    } while (0)\n\n#define DEBUG_GUARD_SIZE 0\n\n/// Managed memory on the Device\ntemplate <class T>\nclass GPUMemory {\nprivate:\n    T *m_data = nullptr;\n    size_t m_size = 0;  // Number of elements\n    bool m_owned = true;\n\npublic:\n    GPUMemory() {}\n\n    GPUMemory<T> &operator=(GPUMemory<T> &&other) {\n        std::swap(m_data, other.m_data);\n        std::swap(m_size, other.m_size);\n        return *this;\n    }\n\n    GPUMemory(GPUMemory<T> &&other) { *this = std::move(other); }\n\n    __host__ __device__ GPUMemory(const GPUMemory<T> &other)\n        : m_data{other.m_data}, m_size{other.m_size}, m_owned{false} {}\n\n    void check_guards() const {\n#if DEBUG_GUARD_SIZE > 0\n        if (!m_data) return;\n        uint8_t buf[DEBUG_GUARD_SIZE];\n        const uint8_t *rawptr = (const uint8_t *)m_data;\n        cudaMemcpy(buf, rawptr - DEBUG_GUARD_SIZE, DEBUG_GUARD_SIZE,\n                   cudaMemcpyDeviceToHost);\n        for (int i = 0; i < DEBUG_GUARD_SIZE; ++i)\n            if (buf[i] != 0xff) {\n                printf(\"TRASH BEFORE BLOCK offset %d data %p, read 0x%02x \"\n                       \"expected \"\n                       \"0xff!\\n\",\n                       i, m_data, buf[i]);\n                break;\n            }\n        cudaMemcpy(buf, rawptr + m_size * sizeof(T), DEBUG_GUARD_SIZE,\n                   cudaMemcpyDeviceToHost);\n        for (int i = 0; i < DEBUG_GUARD_SIZE; ++i)\n            if (buf[i] != 0xfe) {\n                printf(\"TRASH AFTER BLOCK offset %d data %p, read 0x%02x \"\n                       \"expected 0xfe!\\n\",\n                       i, m_data, buf[i]);\n                break;\n            }\n#endif\n    }\n\n    void allocate_memory(size_t n_bytes) {\n        if (n_bytes == 0) {\n            return;\n        }\n\n#ifdef TCNN_VERBOSE_MEMORY_ALLOCS\n        std::cout << \"GPUMemory: Allocating \" << bytes_to_string(n_bytes) << \".\"\n                  << std::endl;\n#endif\n\n        uint8_t *rawptr = nullptr;\n        CUDA_CHECK_THROW(cudaMalloc(&rawptr, n_bytes + DEBUG_GUARD_SIZE * 2));\n#if DEBUG_GUARD_SIZE > 0\n        CUDA_CHECK_THROW(cudaMemset(rawptr, 0xff, DEBUG_GUARD_SIZE));\n        CUDA_CHECK_THROW(cudaMemset(rawptr + n_bytes + DEBUG_GUARD_SIZE, 0xfe,\n                                    DEBUG_GUARD_SIZE));\n#endif\n        if (rawptr) rawptr += DEBUG_GUARD_SIZE;\n        m_data = (T *)(rawptr);\n        total_n_bytes_allocated() += n_bytes;\n    }\n\n    void free_memory() {\n        if (!m_data) {\n            return;\n        }\n\n        uint8_t *rawptr = (uint8_t *)m_data;\n        if (rawptr) rawptr -= DEBUG_GUARD_SIZE;\n        CUDA_CHECK_THROW(cudaFree(rawptr));\n\n        total_n_bytes_allocated() -= get_bytes();\n\n        m_data = nullptr;\n    }\n\n    /// Allocates memory for size items of type T\n    GPUMemory(const size_t size) { resize(size); }\n\n    /// Frees memory again\n    __host__ __device__ ~GPUMemory() {\n#ifndef __CUDA_ARCH__\n        if (!m_owned) {\n            return;\n        }\n\n        try {\n            if (m_data) {\n                free_memory();\n                m_size = 0;\n            }\n        } catch (std::runtime_error error) {\n            // Don't need to report on memory-free problems when the driver is\n            // shutting down.\n            if (std::string{error.what()}.find(\"driver shutting down\") ==\n                std::string::npos) {\n                fprintf(stderr, \"Could not free memory: %s\\n\", error.what());\n            }\n        }\n#endif\n    }\n\n    /** @name Resizing/enlargement\n     *  @{\n     */\n    /// Resizes the array to the exact new size, even if it is already larger\n    void resize(const size_t size) {\n        if (!m_owned) {\n            throw std::runtime_error(\"Cannot resize non-owned memory.\");\n        }\n\n        if (m_size != size) {\n            if (m_size) {\n                try {\n                    free_memory();\n                } catch (std::runtime_error error) {\n                    throw std::runtime_error(\n                            std::string(\"Could not free memory: \") +\n                            error.what());\n                }\n            }\n\n            if (size > 0) {\n                try {\n                    allocate_memory(size * sizeof(T));\n                } catch (std::runtime_error error) {\n                    throw std::runtime_error(\n                            std::string(\"Could not allocate memory: \") +\n                            error.what());\n                }\n            }\n\n            m_size = size;\n        }\n    }\n\n    /// Enlarges the array if its size is smaller\n    void enlarge(const size_t size) {\n        if (size > m_size) {\n            resize(size);\n        }\n    }\n    /** @} */\n\n    /** @name Memset\n     *  @{\n     */\n    /// Sets the memory of the first num_elements to value\n    void memset(const int value,\n                const size_t num_elements,\n                const size_t offset = 0) {\n        if (num_elements + offset > m_size) {\n            throw std::runtime_error(\n                    \"Could not set memory: Number of elements \"\n                    \"larger than allocated memory\");\n        }\n\n        try {\n            CUDA_CHECK_THROW(cudaMemset(m_data + offset, value,\n                                        num_elements * sizeof(T)));\n        } catch (std::runtime_error error) {\n            throw std::runtime_error(std::string(\"Could not set memory: \") +\n                                     error.what());\n        }\n    }\n\n    /// Sets the memory of the all elements to value\n    void memset(const int value) { memset(value, m_size); }\n    /** @} */\n\n    /** @name Copy operations\n     *  @{\n     */\n    /// Copy data of num_elements from the raw pointer on the host\n    void copy_from_host(const T *host_data, const size_t num_elements) {\n        try {\n            CUDA_CHECK_THROW(cudaMemcpy(data(), host_data,\n                                        num_elements * sizeof(T),\n                                        cudaMemcpyHostToDevice));\n        } catch (std::runtime_error error) {\n            throw std::runtime_error(std::string(\"Could not copy from host: \") +\n                                     error.what());\n        }\n    }\n\n    /// Copy num_elements from the host vector\n    void copy_from_host(const std::vector<T> &data, const size_t num_elements) {\n        if (data.size() < num_elements) {\n            throw std::runtime_error(\n                    std::string(\"Trying to copy \") +\n                    std::to_string(num_elements) +\n                    std::string(\" elements, but vector size is only \") +\n                    std::to_string(data.size()));\n        }\n        copy_from_host(data.data(), num_elements);\n    }\n\n    /// Copies data from the raw host pointer to fill the entire array\n    void copy_from_host(const T *data) { copy_from_host(data, m_size); }\n\n    /// Copies num_elements of data from the raw host pointer after enlarging\n    /// the array so that everything fits in\n    void enlarge_and_copy_from_host(const T *data, const size_t num_elements) {\n        enlarge(num_elements);\n        copy_from_host(data, num_elements);\n    }\n\n    /// Copies num_elements from the host vector after enlarging the array so\n    /// that everything fits in\n    void enlarge_and_copy_from_host(const std::vector<T> &data,\n                                    const size_t num_elements) {\n        enlarge_and_copy_from_host(data.data(), num_elements);\n    }\n\n    /// Copies the entire host vector after enlarging the array so that\n    /// everything fits in\n    void enlarge_and_copy_from_host(const std::vector<T> &data) {\n        enlarge_and_copy_from_host(data.data(), data.size());\n    }\n\n    /// Copies num_elements of data from the raw host pointer after resizing the\n    /// array\n    void resize_and_copy_from_host(const T *data, const size_t num_elements) {\n        resize(num_elements);\n        copy_from_host(data, num_elements);\n    }\n\n    /// Copies num_elements from the host vector after resizing the array\n    void resize_and_copy_from_host(const std::vector<T> &data,\n                                   const size_t num_elements) {\n        resize_and_copy_from_host(data.data(), num_elements);\n    }\n\n    /// Copies the entire host vector after resizing the array\n    void resize_and_copy_from_host(const std::vector<T> &data) {\n        resize_and_copy_from_host(data.data(), data.size());\n    }\n\n    /// Copies the entire host vector to the device. Fails if there is not\n    /// enough space available.\n    void copy_from_host(const std::vector<T> &data) {\n        if (data.size() < m_size) {\n            throw std::runtime_error(\n                    std::string(\"Trying to copy \") + std::to_string(m_size) +\n                    std::string(\" elements, but vector size is only \") +\n                    std::to_string(data.size()));\n        }\n        copy_from_host(data.data(), m_size);\n    }\n\n    /// Copies num_elements of data from the raw host pointer to the device.\n    /// Fails if there is not enough space available.\n    void copy_to_host(T *host_data, const size_t num_elements) const {\n        if (num_elements > m_size) {\n            throw std::runtime_error(\n                    std::string(\"Trying to copy \") +\n                    std::to_string(num_elements) +\n                    std::string(\" elements, but vector size is only \") +\n                    std::to_string(m_size));\n        }\n        try {\n            CUDA_CHECK_THROW(cudaMemcpy(host_data, data(),\n                                        num_elements * sizeof(T),\n                                        cudaMemcpyDeviceToHost));\n        } catch (std::runtime_error error) {\n            throw std::runtime_error(std::string(\"Could not copy to host: \") +\n                                     error.what());\n        }\n    }\n\n    /// Copies num_elements from the device to a vector on the host\n    void copy_to_host(std::vector<T> &data, const size_t num_elements) const {\n        if (data.size() < num_elements) {\n            throw std::runtime_error(\n                    std::string(\"Trying to copy \") +\n                    std::to_string(num_elements) +\n                    std::string(\" elements, but vector size is only \") +\n                    std::to_string(data.size()));\n        }\n        copy_to_host(data.data(), num_elements);\n    }\n\n    /// Copies num_elements from the device to a raw pointer on the host\n    void copy_to_host(T *data) const { copy_to_host(data, m_size); }\n\n    /// Copies all elements from the device to a vector on the host\n    void copy_to_host(std::vector<T> &data) const {\n        if (data.size() < m_size) {\n            throw std::runtime_error(\n                    std::string(\"Trying to copy \") + std::to_string(m_size) +\n                    std::string(\" elements, but vector size is only \") +\n                    std::to_string(data.size()));\n        }\n        copy_to_host(data.data(), m_size);\n    }\n\n    /// Copies data from another device array to this one, automatically\n    /// resizing it\n    void copy_from_device(const GPUMemory<T> &other) {\n        if (m_size != other.m_size) {\n            resize(other.m_size);\n        }\n\n        try {\n            CUDA_CHECK_THROW(cudaMemcpy(m_data, other.m_data,\n                                        m_size * sizeof(T),\n                                        cudaMemcpyDeviceToDevice));\n        } catch (std::runtime_error error) {\n            throw std::runtime_error(\n                    std::string(\"Could not copy from device: \") + error.what());\n        }\n    }\n\n    /// Copies size elements from another device array to this one,\n    /// automatically resizing it\n    void copy_from_device(const GPUMemory<T> &other, const size_t size) {\n        if (m_size < size) {\n            resize(size);\n        }\n\n        try {\n            CUDA_CHECK_THROW(cudaMemcpy(m_data, other.m_data, size * sizeof(T),\n                                        cudaMemcpyDeviceToDevice));\n        } catch (std::runtime_error error) {\n            throw std::runtime_error(\n                    std::string(\"Could not copy from device: \") + error.what());\n        }\n    }\n\n    // Created an (owned) copy of the data\n    GPUMemory<T> copy() const {\n        GPUMemory<T> result{m_size};\n        result.copy_from_device(*this);\n        return result;\n    }\n\n    T *data() const {\n        check_guards();\n        return m_data;\n    }\n\n    __host__ __device__ T &operator[](size_t idx) const {\n#ifdef DEBUG_BUFFER_OVERRUN\n        if (idx > m_size) {\n            printf(\"WARNING: buffer overrun of %p at idx %zu\\n\", idx);\n        }\n#endif\n        return m_data[idx];\n    }\n\n    __host__ __device__ T &operator[](uint32_t idx) const {\n#ifdef DEBUG_BUFFER_OVERRUN\n        if (idx > m_size) {\n            printf(\"WARNING: buffer overrun of %p at idx %u\\n\", idx);\n        }\n#endif\n        return m_data[idx];\n    }\n\n    size_t get_num_elements() const { return m_size; }\n\n    size_t size() const { return get_num_elements(); }\n\n    size_t get_bytes() const { return m_size * sizeof(T); }\n\n    size_t bytes() const { return get_bytes(); }\n};\n\ninline std::string bytes_to_string(size_t bytes) {\n    std::array<std::string, 7> suffixes = {\n            {\"B\", \"KB\", \"MB\", \"GB\", \"TB\", \"PB\", \"EB\"}};\n\n    double count = (double)bytes;\n    uint32_t i = 0;\n    for (; i < suffixes.size() && count >= 1024; ++i) {\n        count /= 1024;\n    }\n\n    std::ostringstream oss;\n    oss.precision(3);\n    oss << count << \" \" << suffixes[i];\n    return oss.str();\n}\n\ntemplate <typename T, typename fragment_t>\n__host__ __device__ void warp_activation(Activation activation,\n                                         const fragment_t &frag,\n                                         fragment_t &result) {\n    switch (activation) {\n        case Activation::ReLU:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                result.x[t] = frag.x[t] * (T)((T)frag.x[t] > (T)0.0f);\n            }\n            return;\n        case Activation::Exponential:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                result.x[t] = (T)(expf((float)frag.x[t]));\n            }\n            return;\n        case Activation::Sine:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                result.x[t] = (T)(sinf((float)frag.x[t]));\n            }\n            return;\n        case Activation::Sigmoid:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                result.x[t] = (T)(logistic((float)frag.x[t]));\n            }\n            return;\n        case Activation::Squareplus:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                float x = (float)frag.x[t] * K_ACT;\n                result.x[t] = (T)(0.5f * (x + sqrtf(x * x + 4)) / K_ACT);\n            }\n            return;\n        case Activation::Softplus:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                result.x[t] = (T)(logf(expf((float)frag.x[t] * K_ACT) + 1.0f) /\n                                  K_ACT);\n            }\n            return;\n        case Activation::None:\n            result = frag;\n            return;\n        default:\n            // Unsupported activation\n            // assert(false); // Commented out due to isolated strange\n            // side-effects on Windows\n            return;\n    }\n}\n\ntemplate <typename T, typename fragment_t>\n__host__ __device__ fragment_t warp_activation(Activation activation,\n                                               const fragment_t &frag) {\n    fragment_t result;\n    warp_activation<T>(activation, frag, result);\n    return result;\n}\n\ntemplate <typename T, typename fragment_t, typename forward_fragment_t>\n__host__ __device__ void warp_activation_backward_in(\n        Activation activation,\n        const fragment_t &frag,\n        const forward_fragment_t &forward_frag_in,\n        fragment_t &result) {\n    switch (activation) {\n        case Activation::ReLU:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                result.x[t] = frag.x[t] * (T)(forward_frag_in.x[t] > (T)0.0f);\n            }\n            return;\n        case Activation::Exponential:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                result.x[t] = frag.x[t] * (T)(expf(forward_frag_in.x[t]));\n            }\n            return;\n        case Activation::Sine:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                result.x[t] = frag.x[t] * (T)(cosf(forward_frag_in.x[t]));\n            }\n            return;\n        case Activation::Sigmoid:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                float x = logistic(forward_frag_in.x[t]);\n                result.x[t] = frag.x[t] * (T)(x * (1.0f - x));\n            }\n            return;\n        case Activation::Squareplus:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                float x = (float)forward_frag_in.x[t] * K_ACT;\n                float y = 0.5f * (x + sqrtf(x * x + 4));\n                result.x[t] = frag.x[t] * (T)(y * y / (y * y + 1));\n            }\n            return;\n        case Activation::Softplus:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                float tmp = expf((float)frag.x[t] * K_ACT);\n                result.x[t] = frag.x[t] * (T)(tmp / (tmp + 1));\n            }\n            return;\n        case Activation::None:\n            result = frag;\n            return;\n        default:\n            // Unsupported activation\n            // assert(false); // Commented out due to isolated strange\n            // side-effects on Windows\n            return;\n    }\n}\n\ntemplate <typename T, typename fragment_t, typename forward_fragment_t>\n__host__ __device__ fragment_t\nwarp_activation_backward_in(Activation activation,\n                            const fragment_t &frag,\n                            const forward_fragment_t &forward_frag_in) {\n    fragment_t result;\n    warp_activation_backward_in<T>(activation, frag, forward_frag_in, result);\n    return result;\n}\n\ntemplate <typename T, typename fragment_t, typename forward_fragment_t>\n__host__ __device__ void warp_activation_backward(\n        Activation activation,\n        const fragment_t &frag,\n        const forward_fragment_t &forward_frag,\n        fragment_t &result) {\n    switch (activation) {\n        case Activation::ReLU:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                result.x[t] = frag.x[t] * (T)(forward_frag.x[t] > (T)0.0f);\n            }\n            return;\n        case Activation::Exponential:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                result.x[t] = frag.x[t] * forward_frag.x[t];\n            }\n            return;\n        case Activation::Sine:\n            // Sine requires stored pre-activations, which we don't have. We\n            // only write out the post-activations. assert(false); // Commented\n            // out due to isolated strange side-effects on Windows\n            return;\n        case Activation::Sigmoid:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                result.x[t] = frag.x[t] * (T)(forward_frag.x[t] *\n                                              ((T)1.0f - forward_frag.x[t]));\n            }\n            return;\n        case Activation::Squareplus:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                float y = (float)forward_frag.x[t] * K_ACT;\n                result.x[t] = frag.x[t] * (T)(y * y / (y * y + 1));\n            }\n            return;\n        case Activation::Softplus:\n#pragma unroll\n            for (int t = 0; t < result.num_elements; t++) {\n                result.x[t] =\n                        frag.x[t] *\n                        (T)(1.0f - expf(-(float)forward_frag.x[t] * K_ACT));\n            }\n            return;\n        case Activation::None:\n            result = frag;\n            return;\n        default:\n            // Unsupported activation\n            // assert(false); // Commented out due to isolated strange\n            // side-effects on Windows\n            return;\n    }\n}\n\ntemplate <typename T, typename fragment_t, typename forward_fragment_t>\n__host__ __device__ fragment_t\nwarp_activation_backward(Activation activation,\n                         const fragment_t &frag,\n                         const forward_fragment_t &forward_frag) {\n    fragment_t result;\n    warp_activation_backward<T>(activation, frag, forward_frag, result);\n    return result;\n}"
  },
  {
    "path": "lidarnerf/freqencoder/__init__.py",
    "content": ""
  },
  {
    "path": "lidarnerf/freqencoder/backend.py",
    "content": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n    \"-use_fast_math\",\n]\n\nif os.name == \"posix\":\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\n_backend = load(\n    name=\"_freqencoder\",\n    extra_cflags=c_flags,\n    extra_cuda_cflags=nvcc_flags,\n    sources=[\n        os.path.join(_src_path, \"src\", f)\n        for f in [\n            \"freqencoder.cu\",\n            \"bindings.cpp\",\n        ]\n    ],\n)\n\n__all__ = [\"_backend\"]\n"
  },
  {
    "path": "lidarnerf/freqencoder/freq.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\ntry:\n    import _freqencoder as _backend\nexcept ImportError:\n    from .backend import _backend\n\n\nclass _freq_encoder(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)  # force float32 for better precision\n    def forward(ctx, inputs, degree, output_dim):\n        # inputs: [B, input_dim], float\n        # RETURN: [B, F], float\n\n        if not inputs.is_cuda:\n            inputs = inputs.cuda()\n        inputs = inputs.contiguous()\n\n        B, input_dim = inputs.shape  # batch size, coord dim\n\n        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)\n\n        _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs)\n\n        ctx.save_for_backward(inputs, outputs)\n        ctx.dims = [B, input_dim, degree, output_dim]\n\n        return outputs\n\n    @staticmethod\n    # @once_differentiable\n    @custom_bwd\n    def backward(ctx, grad):\n        # grad: [B, C * C]\n\n        grad = grad.contiguous()\n        inputs, outputs = ctx.saved_tensors\n        B, input_dim, degree, output_dim = ctx.dims\n\n        grad_inputs = torch.zeros_like(inputs)\n        _backend.freq_encode_backward(\n            grad, outputs, B, input_dim, degree, output_dim, grad_inputs\n        )\n\n        return grad_inputs, None, None\n\n\nfreq_encode = _freq_encoder.apply\n\n\nclass FreqEncoder(nn.Module):\n    def __init__(self, input_dim=3, degree=4):\n        super().__init__()\n\n        self.input_dim = input_dim\n        self.degree = degree\n        self.output_dim = input_dim + input_dim * 2 * degree\n\n    def __repr__(self):\n        return f\"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}\"\n\n    def forward(self, inputs, **kwargs):\n        # inputs: [..., input_dim]\n        # return: [..., ]\n\n        prefix_shape = list(inputs.shape[:-1])\n        inputs = inputs.reshape(-1, self.input_dim)\n\n        outputs = freq_encode(inputs, self.degree, self.output_dim)\n\n        outputs = outputs.reshape(prefix_shape + [self.output_dim])\n\n        return outputs\n"
  },
  {
    "path": "lidarnerf/freqencoder/setup.py",
    "content": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n    \"-use_fast_math\",\n]\n\nif os.name == \"posix\":\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\nsetup(\n    name=\"freqencoder\",  # package name, import this to use python API\n    ext_modules=[\n        CUDAExtension(\n            name=\"_freqencoder\",  # extension name, import this to use CUDA API\n            sources=[\n                os.path.join(_src_path, \"src\", f)\n                for f in [\n                    \"freqencoder.cu\",\n                    \"bindings.cpp\",\n                ]\n            ],\n            extra_compile_args={\n                \"cxx\": c_flags,\n                \"nvcc\": nvcc_flags,\n            },\n        ),\n    ],\n    cmdclass={\n        \"build_ext\": BuildExtension,\n    },\n)\n"
  },
  {
    "path": "lidarnerf/freqencoder/src/bindings.cpp",
    "content": "#include <torch/extension.h>\n\n#include \"freqencoder.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"freq_encode_forward\", &freq_encode_forward,\n          \"freq encode forward (CUDA)\");\n    m.def(\"freq_encode_backward\", &freq_encode_backward,\n          \"freq encode backward (CUDA)\");\n}"
  },
  {
    "path": "lidarnerf/freqencoder/src/freqencoder.cu",
    "content": "#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <stdint.h>\n#include <torch/torch.h>\n\n#include <algorithm>\n#include <cstdio>\n#include <stdexcept>\n\n#define CHECK_CUDA(x) \\\n    TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) \\\n    TORCH_CHECK(x.is_contiguous(), #x \" must be a contiguous tensor\")\n#define CHECK_IS_INT(x)                                 \\\n    TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \\\n                #x \" must be an int tensor\")\n#define CHECK_IS_FLOATING(x)                                       \\\n    TORCH_CHECK(x.scalar_type() == at::ScalarType::Float ||        \\\n                        x.scalar_type() == at::ScalarType::Half || \\\n                        x.scalar_type() == at::ScalarType::Double, \\\n                #x \" must be a floating tensor\")\n\ninline constexpr __device__ float PI() { return 3.141592653589793f; }\n\ntemplate <typename T>\n__host__ __device__ T div_round_up(T val, T divisor) {\n    return (val + divisor - 1) / divisor;\n}\n\n// inputs: [B, D]\n// outputs: [B, C], C = D + D * deg * 2\n__global__ void kernel_freq(const float *__restrict__ inputs,\n                            uint32_t B,\n                            uint32_t D,\n                            uint32_t deg,\n                            uint32_t C,\n                            float *outputs) {\n    // parallel on per-element\n    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;\n    if (t >= B * C) return;\n\n    // get index\n    const uint32_t b = t / C;\n    const uint32_t c = t - b * C;  // t % C;\n\n    // locate\n    inputs += b * D;\n    outputs += t;\n\n    // write self\n    if (c < D) {\n        outputs[0] = inputs[c];\n        // write freq\n    } else {\n        const uint32_t col = c / D - 1;\n        const uint32_t d = c % D;\n        const uint32_t freq = col / 2;\n        const float phase_shift = (col % 2) * (PI() / 2);\n        outputs[0] = __sinf(scalbnf(inputs[d], freq) + phase_shift);\n    }\n}\n\n// grad: [B, C], C = D + D * deg * 2\n// outputs: [B, C]\n// grad_inputs: [B, D]\n__global__ void kernel_freq_backward(const float *__restrict__ grad,\n                                     const float *__restrict__ outputs,\n                                     uint32_t B,\n                                     uint32_t D,\n                                     uint32_t deg,\n                                     uint32_t C,\n                                     float *grad_inputs) {\n    // parallel on per-element\n    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;\n    if (t >= B * D) return;\n\n    const uint32_t b = t / D;\n    const uint32_t d = t - b * D;  // t % D;\n\n    // locate\n    grad += b * C;\n    outputs += b * C;\n    grad_inputs += t;\n\n    // register\n    float result = grad[d];\n    grad += D;\n    outputs += D;\n\n    for (uint32_t f = 0; f < deg; f++) {\n        result += scalbnf(1.0f, f) *\n                  (grad[d] * outputs[D + d] - grad[D + d] * outputs[d]);\n        grad += 2 * D;\n        outputs += 2 * D;\n    }\n\n    // write\n    grad_inputs[0] = result;\n}\n\nvoid freq_encode_forward(at::Tensor inputs,\n                         const uint32_t B,\n                         const uint32_t D,\n                         const uint32_t deg,\n                         const uint32_t C,\n                         at::Tensor outputs) {\n    CHECK_CUDA(inputs);\n    CHECK_CUDA(outputs);\n\n    CHECK_CONTIGUOUS(inputs);\n    CHECK_CONTIGUOUS(outputs);\n\n    CHECK_IS_FLOATING(inputs);\n    CHECK_IS_FLOATING(outputs);\n\n    static constexpr uint32_t N_THREADS = 128;\n\n    kernel_freq<<<div_round_up(B * C, N_THREADS), N_THREADS>>>(\n            inputs.data_ptr<float>(), B, D, deg, C, outputs.data_ptr<float>());\n}\n\nvoid freq_encode_backward(at::Tensor grad,\n                          at::Tensor outputs,\n                          const uint32_t B,\n                          const uint32_t D,\n                          const uint32_t deg,\n                          const uint32_t C,\n                          at::Tensor grad_inputs) {\n    CHECK_CUDA(grad);\n    CHECK_CUDA(outputs);\n    CHECK_CUDA(grad_inputs);\n\n    CHECK_CONTIGUOUS(grad);\n    CHECK_CONTIGUOUS(outputs);\n    CHECK_CONTIGUOUS(grad_inputs);\n\n    CHECK_IS_FLOATING(grad);\n    CHECK_IS_FLOATING(outputs);\n    CHECK_IS_FLOATING(grad_inputs);\n\n    static constexpr uint32_t N_THREADS = 128;\n\n    kernel_freq_backward<<<div_round_up(B * D, N_THREADS), N_THREADS>>>(\n            grad.data_ptr<float>(), outputs.data_ptr<float>(), B, D, deg, C,\n            grad_inputs.data_ptr<float>());\n}"
  },
  {
    "path": "lidarnerf/freqencoder/src/freqencoder.h",
    "content": "#pragma once\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n// _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim,\n// outputs)\nvoid freq_encode_forward(at::Tensor inputs,\n                         const uint32_t B,\n                         const uint32_t D,\n                         const uint32_t deg,\n                         const uint32_t C,\n                         at::Tensor outputs);\n\n// _backend.freq_encode_backward(grad, outputs, B, input_dim, degree,\n// output_dim, grad_inputs)\nvoid freq_encode_backward(at::Tensor grad,\n                          at::Tensor outputs,\n                          const uint32_t B,\n                          const uint32_t D,\n                          const uint32_t deg,\n                          const uint32_t C,\n                          at::Tensor grad_inputs);"
  },
  {
    "path": "lidarnerf/gridencoder/__init__.py",
    "content": ""
  },
  {
    "path": "lidarnerf/gridencoder/backend.py",
    "content": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n]\n\nif os.name == \"posix\":\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\n_backend = load(\n    name=\"_grid_encoder\",\n    extra_cflags=c_flags,\n    extra_cuda_cflags=nvcc_flags,\n    sources=[\n        os.path.join(_src_path, \"src\", f)\n        for f in [\n            \"gridencoder.cu\",\n            \"bindings.cpp\",\n        ]\n    ],\n)\n\n__all__ = [\"_backend\"]\n"
  },
  {
    "path": "lidarnerf/gridencoder/grid.py",
    "content": "import numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\ntry:\n    import _gridencoder as _backend\nexcept ImportError:\n    from .backend import _backend\n\n_gridtype_to_id = {\n    \"hash\": 0,\n    \"tiled\": 1,\n}\n\n_interp_to_id = {\n    \"linear\": 0,\n    \"smoothstep\": 1,\n}\n\n\nclass _grid_encode(Function):\n    @staticmethod\n    @custom_fwd\n    def forward(\n        ctx,\n        inputs,\n        embeddings,\n        offsets,\n        per_level_scale,\n        base_resolution,\n        calc_grad_inputs=False,\n        gridtype=0,\n        align_corners=False,\n        interpolation=0,\n    ):\n        # inputs: [B, D], float in [0, 1]\n        # embeddings: [sO, C], float\n        # offsets: [L + 1], int\n        # RETURN: [B, F], float\n\n        inputs = inputs.contiguous()\n\n        B, D = inputs.shape  # batch size, coord dim\n        L = offsets.shape[0] - 1  # level\n        C = embeddings.shape[1]  # embedding dim for each level\n        S = np.log2(\n            per_level_scale\n        )  # resolution multiplier at each level, apply log2 for later CUDA exp2f\n        H = base_resolution  # base resolution\n\n        # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)\n        # if C % 2 != 0, force float, since half for atomicAdd is very slow.\n        if torch.is_autocast_enabled() and C % 2 == 0:\n            embeddings = embeddings.to(torch.half)\n\n        # L first, optimize cache for cuda kernel, but needs an extra permute later\n        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)\n\n        if calc_grad_inputs:\n            dy_dx = torch.empty(\n                B, L * D * C, device=inputs.device, dtype=embeddings.dtype\n            )\n        else:\n            dy_dx = None\n\n        _backend.grid_encode_forward(\n            inputs,\n            embeddings,\n            offsets,\n            outputs,\n            B,\n            D,\n            C,\n            L,\n            S,\n            H,\n            dy_dx,\n            gridtype,\n            align_corners,\n            interpolation,\n        )\n\n        # permute back to [B, L * C]\n        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)\n\n        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)\n        ctx.dims = [B, D, C, L, S, H, gridtype, interpolation]\n        ctx.align_corners = align_corners\n\n        return outputs\n\n    @staticmethod\n    # @once_differentiable\n    @custom_bwd\n    def backward(ctx, grad):\n        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors\n        B, D, C, L, S, H, gridtype, interpolation = ctx.dims\n        align_corners = ctx.align_corners\n\n        # grad: [B, L * C] --> [L, B, C]\n        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()\n\n        grad_embeddings = torch.zeros_like(embeddings)\n\n        if dy_dx is not None:\n            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)\n        else:\n            grad_inputs = None\n\n        _backend.grid_encode_backward(\n            grad,\n            inputs,\n            embeddings,\n            offsets,\n            grad_embeddings,\n            B,\n            D,\n            C,\n            L,\n            S,\n            H,\n            dy_dx,\n            grad_inputs,\n            gridtype,\n            align_corners,\n            interpolation,\n        )\n\n        if dy_dx is not None:\n            grad_inputs = grad_inputs.to(inputs.dtype)\n\n        return grad_inputs, grad_embeddings, None, None, None, None, None, None, None\n\n\ngrid_encode = _grid_encode.apply\n\n\nclass GridEncoder(nn.Module):\n    def __init__(\n        self,\n        input_dim=3,\n        num_levels=16,\n        level_dim=2,\n        per_level_scale=2,\n        base_resolution=16,\n        log2_hashmap_size=19,\n        desired_resolution=None,\n        gridtype=\"hash\",\n        align_corners=False,\n        interpolation=\"linear\",\n    ):\n        super().__init__()\n\n        # the finest resolution desired at the last level, if provided, overridee per_level_scale\n        if desired_resolution is not None:\n            per_level_scale = np.exp2(\n                np.log2(desired_resolution / base_resolution) / (num_levels - 1)\n            )\n\n        self.input_dim = input_dim  # coord dims, 2 or 3\n        self.num_levels = num_levels  # num levels, each level multiply resolution by 2\n        self.level_dim = level_dim  # encode channels per level\n        self.per_level_scale = (\n            per_level_scale  # multiply resolution by this scale at each level.\n        )\n        self.log2_hashmap_size = log2_hashmap_size\n        self.base_resolution = base_resolution\n        self.output_dim = num_levels * level_dim\n        self.gridtype = gridtype\n        self.gridtype_id = _gridtype_to_id[gridtype]  # \"tiled\" or \"hash\"\n        self.interpolation = interpolation\n        self.interp_id = _interp_to_id[interpolation]  # \"linear\" or \"smoothstep\"\n        self.align_corners = align_corners\n\n        # allocate parameters\n        offsets = []\n        offset = 0\n        self.max_params = 2**log2_hashmap_size\n        for i in range(num_levels):\n            resolution = int(np.ceil(base_resolution * per_level_scale**i))\n            params_in_level = min(\n                self.max_params,\n                (resolution if align_corners else resolution + 1) ** input_dim,\n            )  # limit max number\n            params_in_level = int(np.ceil(params_in_level / 8) * 8)  # make divisible\n            offsets.append(offset)\n            offset += params_in_level\n        offsets.append(offset)\n        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))\n        self.register_buffer(\"offsets\", offsets)\n\n        self.n_params = offsets[-1] * level_dim\n\n        # parameters\n        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        std = 1e-4\n        self.embeddings.data.uniform_(-std, std)\n\n    def __repr__(self):\n        return f\"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}\"\n\n    def forward(self, inputs, bound=1):\n        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]\n        # return: [..., num_levels * level_dim]\n\n        inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]\n\n        # print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())\n\n        prefix_shape = list(inputs.shape[:-1])\n        inputs = inputs.view(-1, self.input_dim)\n\n        outputs = grid_encode(\n            inputs,\n            self.embeddings,\n            self.offsets,\n            self.per_level_scale,\n            self.base_resolution,\n            inputs.requires_grad,\n            self.gridtype_id,\n            self.align_corners,\n            self.interp_id,\n        )\n        outputs = outputs.view(prefix_shape + [self.output_dim])\n\n        # print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())\n\n        return outputs\n\n    # always run in float precision!\n    @torch.cuda.amp.autocast(enabled=False)\n    def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000):\n        # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss.\n\n        D = self.input_dim\n        C = self.embeddings.shape[1]  # embedding dim for each level\n        L = self.offsets.shape[0] - 1  # level\n        S = np.log2(\n            self.per_level_scale\n        )  # resolution multiplier at each level, apply log2 for later CUDA exp2f\n        H = self.base_resolution  # base resolution\n\n        if inputs is None:\n            # randomized in [0, 1]\n            inputs = torch.rand(B, self.input_dim, device=self.embeddings.device)\n        else:\n            inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]\n            inputs = inputs.view(-1, self.input_dim)\n            B = inputs.shape[0]\n\n        if self.embeddings.grad is None:\n            raise ValueError(\n                \"grad is None, should be called after loss.backward() and before optimizer.step()!\"\n            )\n\n        _backend.grad_total_variation(\n            inputs,\n            self.embeddings,\n            self.embeddings.grad,\n            self.offsets,\n            weight,\n            B,\n            D,\n            C,\n            L,\n            S,\n            H,\n            self.gridtype_id,\n            self.align_corners,\n        )\n"
  },
  {
    "path": "lidarnerf/gridencoder/setup.py",
    "content": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n]\n\nif os.name == \"posix\":\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\nsetup(\n    name=\"gridencoder\",  # package name, import this to use python API\n    ext_modules=[\n        CUDAExtension(\n            name=\"_gridencoder\",  # extension name, import this to use CUDA API\n            sources=[\n                os.path.join(_src_path, \"src\", f)\n                for f in [\n                    \"gridencoder.cu\",\n                    \"bindings.cpp\",\n                ]\n            ],\n            extra_compile_args={\n                \"cxx\": c_flags,\n                \"nvcc\": nvcc_flags,\n            },\n        ),\n    ],\n    cmdclass={\n        \"build_ext\": BuildExtension,\n    },\n)\n"
  },
  {
    "path": "lidarnerf/gridencoder/src/bindings.cpp",
    "content": "#include <torch/extension.h>\n\n#include \"gridencoder.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"grid_encode_forward\", &grid_encode_forward,\n          \"grid_encode_forward (CUDA)\");\n    m.def(\"grid_encode_backward\", &grid_encode_backward,\n          \"grid_encode_backward (CUDA)\");\n    m.def(\"grad_total_variation\", &grad_total_variation,\n          \"grad_total_variation (CUDA)\");\n}"
  },
  {
    "path": "lidarnerf/gridencoder/src/gridencoder.cu",
    "content": "#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <stdint.h>\n#include <torch/torch.h>\n\n#include <algorithm>\n#include <cstdio>\n#include <stdexcept>\n\n#define CHECK_CUDA(x) \\\n    TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) \\\n    TORCH_CHECK(x.is_contiguous(), #x \" must be a contiguous tensor\")\n#define CHECK_IS_INT(x)                                 \\\n    TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \\\n                #x \" must be an int tensor\")\n#define CHECK_IS_FLOATING(x)                                       \\\n    TORCH_CHECK(x.scalar_type() == at::ScalarType::Float ||        \\\n                        x.scalar_type() == at::ScalarType::Half || \\\n                        x.scalar_type() == at::ScalarType::Double, \\\n                #x \" must be a floating tensor\")\n\n// just for compatability of half precision in\n// AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here!\n__device__ inline at::Half atomicAdd(at::Half *address, at::Half val) {\n    // requires CUDA >= 10 and ARCH >= 70\n    // this is very slow compared to float or __half2, never use it.\n    // return atomicAdd(reinterpret_cast<__half*>(address), val);\n}\n\ntemplate <typename T>\n__host__ __device__ inline T div_round_up(T val, T divisor) {\n    return (val + divisor - 1) / divisor;\n}\n\ntemplate <typename T, typename T2>\n__host__ __device__ inline T clamp(const T v, const T2 lo, const T2 hi) {\n    return min(max(v, lo), hi);\n}\n\ntemplate <typename T>\n__device__ inline T smoothstep(T val) {\n    return val * val * (3.0f - 2.0f * val);\n}\n\ntemplate <typename T>\n__device__ inline T smoothstep_derivative(T val) {\n    return 6 * val * (1.0f - val);\n}\n\ntemplate <uint32_t D>\n__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {\n    // coherent type of hashing\n    constexpr uint32_t primes[7] = {1u,          2654435761u, 805459861u,\n                                    3674653429u, 2097192037u, 1434869437u,\n                                    2165219737u};\n\n    uint32_t result = 0;\n#pragma unroll\n    for (uint32_t i = 0; i < D; ++i) {\n        result ^= pos_grid[i] * primes[i];\n    }\n\n    return result;\n}\n\ntemplate <uint32_t D, uint32_t C>\n__device__ uint32_t get_grid_index(const uint32_t gridtype,\n                                   const bool align_corners,\n                                   const uint32_t ch,\n                                   const uint32_t hashmap_size,\n                                   const uint32_t resolution,\n                                   const uint32_t pos_grid[D]) {\n    uint32_t stride = 1;\n    uint32_t index = 0;\n\n#pragma unroll\n    for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {\n        index += pos_grid[d] * stride;\n        stride *= align_corners ? resolution : (resolution + 1);\n    }\n\n    // NOTE: for NeRF, the hash is in fact not necessary. Check\n    // https://github.com/NVlabs/instant-ngp/issues/97. gridtype: 0 == hash, 1\n    // == tiled\n    if (gridtype == 0 && stride > hashmap_size) {\n        index = fast_hash<D>(pos_grid);\n    }\n\n    return (index % hashmap_size) * C + ch;\n}\n\ntemplate <typename scalar_t, uint32_t D, uint32_t C>\n__global__ void kernel_grid(const float *__restrict__ inputs,\n                            const scalar_t *__restrict__ grid,\n                            const int *__restrict__ offsets,\n                            scalar_t *__restrict__ outputs,\n                            const uint32_t B,\n                            const uint32_t L,\n                            const float S,\n                            const uint32_t H,\n                            scalar_t *__restrict__ dy_dx,\n                            const uint32_t gridtype,\n                            const bool align_corners,\n                            const uint32_t interp) {\n    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;\n\n    if (b >= B) return;\n\n    const uint32_t level = blockIdx.y;\n\n    // locate\n    grid += (uint32_t)offsets[level] * C;\n    inputs += b * D;\n    outputs += level * B * C + b * C;\n\n    // check input range (should be in [0, 1])\n    bool flag_oob = false;\n#pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        if (inputs[d] < 0 || inputs[d] > 1) {\n            flag_oob = true;\n        }\n    }\n    // if input out of bound, just set output to 0\n    if (flag_oob) {\n#pragma unroll\n        for (uint32_t ch = 0; ch < C; ch++) {\n            outputs[ch] = 0;\n        }\n        if (dy_dx) {\n            dy_dx += b * D * L * C + level * D * C;  // B L D C\n#pragma unroll\n            for (uint32_t d = 0; d < D; d++) {\n#pragma unroll\n                for (uint32_t ch = 0; ch < C; ch++) {\n                    dy_dx[d * C + ch] = 0;\n                }\n            }\n        }\n        return;\n    }\n\n    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];\n    const float scale = exp2f(level * S) * H - 1.0f;\n    const uint32_t resolution = (uint32_t)ceil(scale) + 1;\n\n    // calculate coordinate (always use float for precision!)\n    float pos[D];\n    float pos_deriv[D];\n    uint32_t pos_grid[D];\n\n#pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);\n        pos_grid[d] = floorf(pos[d]);\n        pos[d] -= (float)pos_grid[d];\n        // smoothstep instead of linear\n        if (interp == 1) {\n            pos_deriv[d] = smoothstep_derivative(pos[d]);\n            pos[d] = smoothstep(pos[d]);\n        } else {\n            pos_deriv[d] = 1.0f;  // linear deriv is default to 1\n        }\n    }\n\n    // printf(\"[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\\n\", b, level, pos[0], pos[1],\n    // pos_grid[0], pos_grid[1]);\n\n    // interpolate\n    scalar_t results[C] = {0};  // temp results in register\n\n#pragma unroll\n    for (uint32_t idx = 0; idx < (1 << D); idx++) {\n        float w = 1;\n        uint32_t pos_grid_local[D];\n\n#pragma unroll\n        for (uint32_t d = 0; d < D; d++) {\n            if ((idx & (1 << d)) == 0) {\n                w *= 1 - pos[d];\n                pos_grid_local[d] = pos_grid[d];\n            } else {\n                w *= pos[d];\n                pos_grid_local[d] = pos_grid[d] + 1;\n            }\n        }\n\n        uint32_t index =\n                get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size,\n                                     resolution, pos_grid_local);\n\n// writing to register (fast)\n#pragma unroll\n        for (uint32_t ch = 0; ch < C; ch++) {\n            results[ch] += w * grid[index + ch];\n        }\n\n        // printf(\"[b=%d, l=%d] int %d, idx %d, w %f, val %f\\n\", b, level, idx,\n        // index, w, grid[index]);\n    }\n\n// writing to global memory (slow)\n#pragma unroll\n    for (uint32_t ch = 0; ch < C; ch++) {\n        outputs[ch] = results[ch];\n    }\n\n    // prepare dy_dx\n    // differentiable (soft) indexing:\n    // https://discuss.pytorch.org/t/differentiable-indexing/17647/9\n    if (dy_dx) {\n        dy_dx += b * D * L * C + level * D * C;  // B L D C\n\n#pragma unroll\n        for (uint32_t gd = 0; gd < D; gd++) {\n            scalar_t results_grad[C] = {0};\n\n#pragma unroll\n            for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {\n                float w = scale;\n                uint32_t pos_grid_local[D];\n\n#pragma unroll\n                for (uint32_t nd = 0; nd < D - 1; nd++) {\n                    const uint32_t d = (nd >= gd) ? (nd + 1) : nd;\n\n                    if ((idx & (1 << nd)) == 0) {\n                        w *= 1 - pos[d];\n                        pos_grid_local[d] = pos_grid[d];\n                    } else {\n                        w *= pos[d];\n                        pos_grid_local[d] = pos_grid[d] + 1;\n                    }\n                }\n\n                pos_grid_local[gd] = pos_grid[gd];\n                uint32_t index_left = get_grid_index<D, C>(\n                        gridtype, align_corners, 0, hashmap_size, resolution,\n                        pos_grid_local);\n                pos_grid_local[gd] = pos_grid[gd] + 1;\n                uint32_t index_right = get_grid_index<D, C>(\n                        gridtype, align_corners, 0, hashmap_size, resolution,\n                        pos_grid_local);\n\n#pragma unroll\n                for (uint32_t ch = 0; ch < C; ch++) {\n                    results_grad[ch] +=\n                            w *\n                            (grid[index_right + ch] - grid[index_left + ch]) *\n                            pos_deriv[gd];\n                }\n            }\n\n#pragma unroll\n            for (uint32_t ch = 0; ch < C; ch++) {\n                dy_dx[gd * C + ch] = results_grad[ch];\n            }\n        }\n    }\n}\n\ntemplate <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>\n__global__ void kernel_grid_backward(const scalar_t *__restrict__ grad,\n                                     const float *__restrict__ inputs,\n                                     const scalar_t *__restrict__ grid,\n                                     const int *__restrict__ offsets,\n                                     scalar_t *__restrict__ grad_grid,\n                                     const uint32_t B,\n                                     const uint32_t L,\n                                     const float S,\n                                     const uint32_t H,\n                                     const uint32_t gridtype,\n                                     const bool align_corners,\n                                     const uint32_t interp) {\n    const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;\n    if (b >= B) return;\n\n    const uint32_t level = blockIdx.y;\n    const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;\n\n    // locate\n    grad_grid += offsets[level] * C;\n    inputs += b * D;\n    grad += level * B * C + b * C + ch;  // L, B, C\n\n    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];\n    const float scale = exp2f(level * S) * H - 1.0f;\n    const uint32_t resolution = (uint32_t)ceil(scale) + 1;\n\n// check input range (should be in [0, 1])\n#pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        if (inputs[d] < 0 || inputs[d] > 1) {\n            return;  // grad is init as 0, so we simply return.\n        }\n    }\n\n    // calculate coordinate\n    float pos[D];\n    uint32_t pos_grid[D];\n\n#pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);\n        pos_grid[d] = floorf(pos[d]);\n        pos[d] -= (float)pos_grid[d];\n        // smoothstep instead of linear\n        if (interp == 1) {\n            pos[d] = smoothstep(pos[d]);\n        }\n    }\n\n    scalar_t grad_cur[N_C] = {0};  // fetch to register\n#pragma unroll\n    for (uint32_t c = 0; c < N_C; c++) {\n        grad_cur[c] = grad[c];\n    }\n\n// interpolate\n#pragma unroll\n    for (uint32_t idx = 0; idx < (1 << D); idx++) {\n        float w = 1;\n        uint32_t pos_grid_local[D];\n\n#pragma unroll\n        for (uint32_t d = 0; d < D; d++) {\n            if ((idx & (1 << d)) == 0) {\n                w *= 1 - pos[d];\n                pos_grid_local[d] = pos_grid[d];\n            } else {\n                w *= pos[d];\n                pos_grid_local[d] = pos_grid[d] + 1;\n            }\n        }\n\n        uint32_t index =\n                get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size,\n                                     resolution, pos_grid_local);\n\n        // atomicAdd for __half is slow (especially for large values), so we use\n        // __half2 if N_C % 2 == 0\n        // TODO: use float which is better than __half, if N_C % 2 != 0\n        if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {\n#pragma unroll\n            for (uint32_t c = 0; c < N_C; c += 2) {\n                // process two __half at once (by interpreting as a __half2)\n                __half2 v = {(__half)(w * grad_cur[c]),\n                             (__half)(w * grad_cur[c + 1])};\n                atomicAdd((__half2 *)&grad_grid[index + c], v);\n            }\n            // float, or __half when N_C % 2 != 0 (which means C == 1)\n        } else {\n#pragma unroll\n            for (uint32_t c = 0; c < N_C; c++) {\n                atomicAdd(&grad_grid[index + c], w * grad_cur[c]);\n            }\n        }\n    }\n}\n\ntemplate <typename scalar_t, uint32_t D, uint32_t C>\n__global__ void kernel_input_backward(const scalar_t *__restrict__ grad,\n                                      const scalar_t *__restrict__ dy_dx,\n                                      scalar_t *__restrict__ grad_inputs,\n                                      uint32_t B,\n                                      uint32_t L) {\n    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;\n    if (t >= B * D) return;\n\n    const uint32_t b = t / D;\n    const uint32_t d = t - b * D;\n\n    dy_dx += b * L * D * C;\n\n    scalar_t result = 0;\n\n#pragma unroll\n    for (int l = 0; l < L; l++) {\n#pragma unroll\n        for (int ch = 0; ch < C; ch++) {\n            result += grad[l * B * C + b * C + ch] *\n                      dy_dx[l * D * C + d * C + ch];\n        }\n    }\n\n    grad_inputs[t] = result;\n}\n\ntemplate <typename scalar_t, uint32_t D>\nvoid kernel_grid_wrapper(const float *inputs,\n                         const scalar_t *embeddings,\n                         const int *offsets,\n                         scalar_t *outputs,\n                         const uint32_t B,\n                         const uint32_t C,\n                         const uint32_t L,\n                         const float S,\n                         const uint32_t H,\n                         scalar_t *dy_dx,\n                         const uint32_t gridtype,\n                         const bool align_corners,\n                         const uint32_t interp) {\n    static constexpr uint32_t N_THREAD = 512;\n    const dim3 blocks_hashgrid = {div_round_up(B, N_THREAD), L, 1};\n    switch (C) {\n        case 1:\n            kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(\n                    inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx,\n                    gridtype, align_corners, interp);\n            break;\n        case 2:\n            kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(\n                    inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx,\n                    gridtype, align_corners, interp);\n            break;\n        case 4:\n            kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(\n                    inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx,\n                    gridtype, align_corners, interp);\n            break;\n        case 8:\n            kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(\n                    inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx,\n                    gridtype, align_corners, interp);\n            break;\n        default:\n            throw std::runtime_error{\"GridEncoding: C must be 1, 2, 4, or 8.\"};\n    }\n}\n\n// inputs: [B, D], float, in [0, 1]\n// embeddings: [sO, C], float\n// offsets: [L + 1], uint32_t\n// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit\n// into cache at a time.) H: base resolution dy_dx: [B, L * D * C]\ntemplate <typename scalar_t>\nvoid grid_encode_forward_cuda(const float *inputs,\n                              const scalar_t *embeddings,\n                              const int *offsets,\n                              scalar_t *outputs,\n                              const uint32_t B,\n                              const uint32_t D,\n                              const uint32_t C,\n                              const uint32_t L,\n                              const float S,\n                              const uint32_t H,\n                              scalar_t *dy_dx,\n                              const uint32_t gridtype,\n                              const bool align_corners,\n                              const uint32_t interp) {\n    switch (D) {\n        case 2:\n            kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets,\n                                             outputs, B, C, L, S, H, dy_dx,\n                                             gridtype, align_corners, interp);\n            break;\n        case 3:\n            kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets,\n                                             outputs, B, C, L, S, H, dy_dx,\n                                             gridtype, align_corners, interp);\n            break;\n        case 4:\n            kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets,\n                                             outputs, B, C, L, S, H, dy_dx,\n                                             gridtype, align_corners, interp);\n            break;\n        case 5:\n            kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets,\n                                             outputs, B, C, L, S, H, dy_dx,\n                                             gridtype, align_corners, interp);\n            break;\n        default:\n            throw std::runtime_error{\"GridEncoding: C must be 1, 2, 4, or 8.\"};\n    }\n}\n\ntemplate <typename scalar_t, uint32_t D>\nvoid kernel_grid_backward_wrapper(const scalar_t *grad,\n                                  const float *inputs,\n                                  const scalar_t *embeddings,\n                                  const int *offsets,\n                                  scalar_t *grad_embeddings,\n                                  const uint32_t B,\n                                  const uint32_t C,\n                                  const uint32_t L,\n                                  const float S,\n                                  const uint32_t H,\n                                  scalar_t *dy_dx,\n                                  scalar_t *grad_inputs,\n                                  const uint32_t gridtype,\n                                  const bool align_corners,\n                                  const uint32_t interp) {\n    static constexpr uint32_t N_THREAD = 256;\n    const uint32_t N_C = std::min(2u, C);  // n_features_per_thread\n    const dim3 blocks_hashgrid = {div_round_up(B * C / N_C, N_THREAD), L, 1};\n    switch (C) {\n        case 1:\n            kernel_grid_backward<scalar_t, D, 1, 1>\n                    <<<blocks_hashgrid, N_THREAD>>>(\n                            grad, inputs, embeddings, offsets, grad_embeddings,\n                            B, L, S, H, gridtype, align_corners, interp);\n            if (dy_dx)\n                kernel_input_backward<scalar_t, D, 1>\n                        <<<div_round_up(B * D, N_THREAD), N_THREAD>>>(\n                                grad, dy_dx, grad_inputs, B, L);\n            break;\n        case 2:\n            kernel_grid_backward<scalar_t, D, 2, 2>\n                    <<<blocks_hashgrid, N_THREAD>>>(\n                            grad, inputs, embeddings, offsets, grad_embeddings,\n                            B, L, S, H, gridtype, align_corners, interp);\n            if (dy_dx)\n                kernel_input_backward<scalar_t, D, 2>\n                        <<<div_round_up(B * D, N_THREAD), N_THREAD>>>(\n                                grad, dy_dx, grad_inputs, B, L);\n            break;\n        case 4:\n            kernel_grid_backward<scalar_t, D, 4, 2>\n                    <<<blocks_hashgrid, N_THREAD>>>(\n                            grad, inputs, embeddings, offsets, grad_embeddings,\n                            B, L, S, H, gridtype, align_corners, interp);\n            if (dy_dx)\n                kernel_input_backward<scalar_t, D, 4>\n                        <<<div_round_up(B * D, N_THREAD), N_THREAD>>>(\n                                grad, dy_dx, grad_inputs, B, L);\n            break;\n        case 8:\n            kernel_grid_backward<scalar_t, D, 8, 2>\n                    <<<blocks_hashgrid, N_THREAD>>>(\n                            grad, inputs, embeddings, offsets, grad_embeddings,\n                            B, L, S, H, gridtype, align_corners, interp);\n            if (dy_dx)\n                kernel_input_backward<scalar_t, D, 8>\n                        <<<div_round_up(B * D, N_THREAD), N_THREAD>>>(\n                                grad, dy_dx, grad_inputs, B, L);\n            break;\n        default:\n            throw std::runtime_error{\"GridEncoding: C must be 1, 2, 4, or 8.\"};\n    }\n}\n\n// grad: [L, B, C], float\n// inputs: [B, D], float, in [0, 1]\n// embeddings: [sO, C], float\n// offsets: [L + 1], uint32_t\n// grad_embeddings: [sO, C]\n// H: base resolution\ntemplate <typename scalar_t>\nvoid grid_encode_backward_cuda(const scalar_t *grad,\n                               const float *inputs,\n                               const scalar_t *embeddings,\n                               const int *offsets,\n                               scalar_t *grad_embeddings,\n                               const uint32_t B,\n                               const uint32_t D,\n                               const uint32_t C,\n                               const uint32_t L,\n                               const float S,\n                               const uint32_t H,\n                               scalar_t *dy_dx,\n                               scalar_t *grad_inputs,\n                               const uint32_t gridtype,\n                               const bool align_corners,\n                               const uint32_t interp) {\n    switch (D) {\n        case 2:\n            kernel_grid_backward_wrapper<scalar_t, 2>(\n                    grad, inputs, embeddings, offsets, grad_embeddings, B, C, L,\n                    S, H, dy_dx, grad_inputs, gridtype, align_corners, interp);\n            break;\n        case 3:\n            kernel_grid_backward_wrapper<scalar_t, 3>(\n                    grad, inputs, embeddings, offsets, grad_embeddings, B, C, L,\n                    S, H, dy_dx, grad_inputs, gridtype, align_corners, interp);\n            break;\n        case 4:\n            kernel_grid_backward_wrapper<scalar_t, 4>(\n                    grad, inputs, embeddings, offsets, grad_embeddings, B, C, L,\n                    S, H, dy_dx, grad_inputs, gridtype, align_corners, interp);\n            break;\n        case 5:\n            kernel_grid_backward_wrapper<scalar_t, 5>(\n                    grad, inputs, embeddings, offsets, grad_embeddings, B, C, L,\n                    S, H, dy_dx, grad_inputs, gridtype, align_corners, interp);\n            break;\n        default:\n            throw std::runtime_error{\"GridEncoding: C must be 1, 2, 4, or 8.\"};\n    }\n}\n\nvoid grid_encode_forward(const at::Tensor inputs,\n                         const at::Tensor embeddings,\n                         const at::Tensor offsets,\n                         at::Tensor outputs,\n                         const uint32_t B,\n                         const uint32_t D,\n                         const uint32_t C,\n                         const uint32_t L,\n                         const float S,\n                         const uint32_t H,\n                         at::optional<at::Tensor> dy_dx,\n                         const uint32_t gridtype,\n                         const bool align_corners,\n                         const uint32_t interp) {\n    CHECK_CUDA(inputs);\n    CHECK_CUDA(embeddings);\n    CHECK_CUDA(offsets);\n    CHECK_CUDA(outputs);\n    // CHECK_CUDA(dy_dx);\n\n    CHECK_CONTIGUOUS(inputs);\n    CHECK_CONTIGUOUS(embeddings);\n    CHECK_CONTIGUOUS(offsets);\n    CHECK_CONTIGUOUS(outputs);\n    // CHECK_CONTIGUOUS(dy_dx);\n\n    CHECK_IS_FLOATING(inputs);\n    CHECK_IS_FLOATING(embeddings);\n    CHECK_IS_INT(offsets);\n    CHECK_IS_FLOATING(outputs);\n    // CHECK_IS_FLOATING(dy_dx);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n            embeddings.scalar_type(), \"grid_encode_forward\", ([&] {\n                grid_encode_forward_cuda<scalar_t>(\n                        inputs.data_ptr<float>(),\n                        embeddings.data_ptr<scalar_t>(),\n                        offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(),\n                        B, D, C, L, S, H,\n                        dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>()\n                                          : nullptr,\n                        gridtype, align_corners, interp);\n            }));\n}\n\nvoid grid_encode_backward(const at::Tensor grad,\n                          const at::Tensor inputs,\n                          const at::Tensor embeddings,\n                          const at::Tensor offsets,\n                          at::Tensor grad_embeddings,\n                          const uint32_t B,\n                          const uint32_t D,\n                          const uint32_t C,\n                          const uint32_t L,\n                          const float S,\n                          const uint32_t H,\n                          const at::optional<at::Tensor> dy_dx,\n                          at::optional<at::Tensor> grad_inputs,\n                          const uint32_t gridtype,\n                          const bool align_corners,\n                          const uint32_t interp) {\n    CHECK_CUDA(grad);\n    CHECK_CUDA(inputs);\n    CHECK_CUDA(embeddings);\n    CHECK_CUDA(offsets);\n    CHECK_CUDA(grad_embeddings);\n    // CHECK_CUDA(dy_dx);\n    // CHECK_CUDA(grad_inputs);\n\n    CHECK_CONTIGUOUS(grad);\n    CHECK_CONTIGUOUS(inputs);\n    CHECK_CONTIGUOUS(embeddings);\n    CHECK_CONTIGUOUS(offsets);\n    CHECK_CONTIGUOUS(grad_embeddings);\n    // CHECK_CONTIGUOUS(dy_dx);\n    // CHECK_CONTIGUOUS(grad_inputs);\n\n    CHECK_IS_FLOATING(grad);\n    CHECK_IS_FLOATING(inputs);\n    CHECK_IS_FLOATING(embeddings);\n    CHECK_IS_INT(offsets);\n    CHECK_IS_FLOATING(grad_embeddings);\n    // CHECK_IS_FLOATING(dy_dx);\n    // CHECK_IS_FLOATING(grad_inputs);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n            grad.scalar_type(), \"grid_encode_backward\", ([&] {\n                grid_encode_backward_cuda<scalar_t>(\n                        grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(),\n                        embeddings.data_ptr<scalar_t>(),\n                        offsets.data_ptr<int>(),\n                        grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H,\n                        dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>()\n                                          : nullptr,\n                        grad_inputs.has_value()\n                                ? grad_inputs.value().data_ptr<scalar_t>()\n                                : nullptr,\n                        gridtype, align_corners, interp);\n            }));\n}\n\ntemplate <typename scalar_t, uint32_t D, uint32_t C>\n__global__ void kernel_grad_tv(const scalar_t *__restrict__ inputs,\n                               const scalar_t *__restrict__ grid,\n                               scalar_t *__restrict__ grad,\n                               const int *__restrict__ offsets,\n                               const float weight,\n                               const uint32_t B,\n                               const uint32_t L,\n                               const float S,\n                               const uint32_t H,\n                               const uint32_t gridtype,\n                               const bool align_corners) {\n    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;\n\n    if (b >= B) return;\n\n    const uint32_t level = blockIdx.y;\n\n    // locate\n    inputs += b * D;\n    grid += (uint32_t)offsets[level] * C;\n    grad += (uint32_t)offsets[level] * C;\n\n    // check input range (should be in [0, 1])\n    bool flag_oob = false;\n#pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        if (inputs[d] < 0 || inputs[d] > 1) {\n            flag_oob = true;\n        }\n    }\n\n    // if input out of bound, do nothing\n    if (flag_oob) return;\n\n    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];\n    const float scale = exp2f(level * S) * H - 1.0f;\n    const uint32_t resolution = (uint32_t)ceil(scale) + 1;\n\n    // calculate coordinate\n    float pos[D];\n    uint32_t pos_grid[D];  // [0, resolution]\n\n#pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);\n        pos_grid[d] = floorf(pos[d]);\n        // pos[d] -= (float)pos_grid[d]; // not used\n    }\n\n    // printf(\"[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\\n\", b, level, pos[0], pos[1],\n    // pos_grid[0], pos_grid[1]);\n\n    // total variation on pos_grid\n    scalar_t results[C] = {0};  // temp results in register\n    scalar_t idelta[C] = {0};\n\n    uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0,\n                                          hashmap_size, resolution, pos_grid);\n\n    scalar_t w = weight / (2 * D);\n\n#pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        uint32_t cur_d = pos_grid[d];\n        scalar_t grad_val;\n\n        // right side\n        if (cur_d < resolution) {\n            pos_grid[d] = cur_d + 1;\n            uint32_t index_right =\n                    get_grid_index<D, C>(gridtype, align_corners, 0,\n                                         hashmap_size, resolution, pos_grid);\n\n#pragma unroll\n            for (uint32_t ch = 0; ch < C; ch++) {\n                // results[ch] += w * clamp(grid[index + ch] - grid[index_right\n                // + ch], -1.0f, 1.0f);\n                grad_val = (grid[index + ch] - grid[index_right + ch]);\n                results[ch] += grad_val;\n                idelta[ch] += grad_val * grad_val;\n            }\n        }\n\n        // left side\n        if (cur_d > 0) {\n            pos_grid[d] = cur_d - 1;\n            uint32_t index_left =\n                    get_grid_index<D, C>(gridtype, align_corners, 0,\n                                         hashmap_size, resolution, pos_grid);\n\n#pragma unroll\n            for (uint32_t ch = 0; ch < C; ch++) {\n                // results[ch] += w * clamp(grid[index + ch] - grid[index_left +\n                // ch], -1.0f, 1.0f);\n                grad_val = (grid[index + ch] - grid[index_left + ch]);\n                results[ch] += grad_val;\n                idelta[ch] += grad_val * grad_val;\n            }\n        }\n\n        // reset\n        pos_grid[d] = cur_d;\n    }\n\n// writing to global memory (slow)\n#pragma unroll\n    for (uint32_t ch = 0; ch < C; ch++) {\n        // index may collide, so use atomic!\n        atomicAdd(&grad[index + ch],\n                  w * results[ch] * rsqrtf(idelta[ch] + 1e-9f));\n    }\n}\n\ntemplate <typename scalar_t, uint32_t D>\nvoid kernel_grad_tv_wrapper(const scalar_t *inputs,\n                            const scalar_t *embeddings,\n                            scalar_t *grad,\n                            const int *offsets,\n                            const float weight,\n                            const uint32_t B,\n                            const uint32_t C,\n                            const uint32_t L,\n                            const float S,\n                            const uint32_t H,\n                            const uint32_t gridtype,\n                            const bool align_corners) {\n    static constexpr uint32_t N_THREAD = 512;\n    const dim3 blocks_hashgrid = {div_round_up(B, N_THREAD), L, 1};\n    switch (C) {\n        case 1:\n            kernel_grad_tv<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(\n                    inputs, embeddings, grad, offsets, weight, B, L, S, H,\n                    gridtype, align_corners);\n            break;\n        case 2:\n            kernel_grad_tv<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(\n                    inputs, embeddings, grad, offsets, weight, B, L, S, H,\n                    gridtype, align_corners);\n            break;\n        case 4:\n            kernel_grad_tv<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(\n                    inputs, embeddings, grad, offsets, weight, B, L, S, H,\n                    gridtype, align_corners);\n            break;\n        case 8:\n            kernel_grad_tv<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(\n                    inputs, embeddings, grad, offsets, weight, B, L, S, H,\n                    gridtype, align_corners);\n            break;\n        default:\n            throw std::runtime_error{\"GridEncoding: C must be 1, 2, 4, or 8.\"};\n    }\n}\n\ntemplate <typename scalar_t>\nvoid grad_total_variation_cuda(const scalar_t *inputs,\n                               const scalar_t *embeddings,\n                               scalar_t *grad,\n                               const int *offsets,\n                               const float weight,\n                               const uint32_t B,\n                               const uint32_t D,\n                               const uint32_t C,\n                               const uint32_t L,\n                               const float S,\n                               const uint32_t H,\n                               const uint32_t gridtype,\n                               const bool align_corners) {\n    switch (D) {\n        case 2:\n            kernel_grad_tv_wrapper<scalar_t, 2>(inputs, embeddings, grad,\n                                                offsets, weight, B, C, L, S, H,\n                                                gridtype, align_corners);\n            break;\n        case 3:\n            kernel_grad_tv_wrapper<scalar_t, 3>(inputs, embeddings, grad,\n                                                offsets, weight, B, C, L, S, H,\n                                                gridtype, align_corners);\n            break;\n        case 4:\n            kernel_grad_tv_wrapper<scalar_t, 4>(inputs, embeddings, grad,\n                                                offsets, weight, B, C, L, S, H,\n                                                gridtype, align_corners);\n            break;\n        case 5:\n            kernel_grad_tv_wrapper<scalar_t, 5>(inputs, embeddings, grad,\n                                                offsets, weight, B, C, L, S, H,\n                                                gridtype, align_corners);\n            break;\n        default:\n            throw std::runtime_error{\"GridEncoding: C must be 1, 2, 4, or 8.\"};\n    }\n}\n\nvoid grad_total_variation(const at::Tensor inputs,\n                          const at::Tensor embeddings,\n                          at::Tensor grad,\n                          const at::Tensor offsets,\n                          const float weight,\n                          const uint32_t B,\n                          const uint32_t D,\n                          const uint32_t C,\n                          const uint32_t L,\n                          const float S,\n                          const uint32_t H,\n                          const uint32_t gridtype,\n                          const bool align_corners) {\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n            embeddings.scalar_type(), \"grad_total_variation\", ([&] {\n                grad_total_variation_cuda<scalar_t>(\n                        inputs.data_ptr<scalar_t>(),\n                        embeddings.data_ptr<scalar_t>(),\n                        grad.data_ptr<scalar_t>(), offsets.data_ptr<int>(),\n                        weight, B, D, C, L, S, H, gridtype, align_corners);\n            }));\n}"
  },
  {
    "path": "lidarnerf/gridencoder/src/gridencoder.h",
    "content": "#ifndef _HASH_ENCODE_H\n#define _HASH_ENCODE_H\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n// inputs: [B, D], float, in [0, 1]\n// embeddings: [sO, C], float\n// offsets: [L + 1], uint32_t\n// outputs: [B, L * C], float\n// H: base resolution\nvoid grid_encode_forward(const at::Tensor inputs,\n                         const at::Tensor embeddings,\n                         const at::Tensor offsets,\n                         at::Tensor outputs,\n                         const uint32_t B,\n                         const uint32_t D,\n                         const uint32_t C,\n                         const uint32_t L,\n                         const float S,\n                         const uint32_t H,\n                         at::optional<at::Tensor> dy_dx,\n                         const uint32_t gridtype,\n                         const bool align_corners,\n                         const uint32_t interp);\nvoid grid_encode_backward(const at::Tensor grad,\n                          const at::Tensor inputs,\n                          const at::Tensor embeddings,\n                          const at::Tensor offsets,\n                          at::Tensor grad_embeddings,\n                          const uint32_t B,\n                          const uint32_t D,\n                          const uint32_t C,\n                          const uint32_t L,\n                          const float S,\n                          const uint32_t H,\n                          const at::optional<at::Tensor> dy_dx,\n                          at::optional<at::Tensor> grad_inputs,\n                          const uint32_t gridtype,\n                          const bool align_corners,\n                          const uint32_t interp);\n\nvoid grad_total_variation(const at::Tensor inputs,\n                          const at::Tensor embeddings,\n                          at::Tensor grad,\n                          const at::Tensor offsets,\n                          const float weight,\n                          const uint32_t B,\n                          const uint32_t D,\n                          const uint32_t C,\n                          const uint32_t L,\n                          const float S,\n                          const uint32_t H,\n                          const uint32_t gridtype,\n                          const bool align_corners);\n\n#endif"
  },
  {
    "path": "lidarnerf/loss.py",
    "content": "import torch\n\nimport numpy as np\n\n\ndef mape_loss(pred, target, reduction=\"mean\"):\n    # pred, target: [B, 1], torch tenspr\n    difference = (pred - target).abs()\n    scale = 1 / (target.abs() + 1e-2)\n    loss = difference * scale\n\n    if reduction == \"mean\":\n        loss = loss.mean()\n\n    return loss\n\n\ndef huber_loss(pred, target, delta=0.1, reduction=\"mean\"):\n    rel = (pred - target).abs()\n    sqr = 0.5 / delta * rel * rel\n    loss = torch.where(rel > delta, rel - 0.5 * delta, sqr)\n\n    if reduction == \"mean\":\n        loss = loss.mean()\n\n    return loss\n\n\n# ref: https://github.com/sunset1995/torch_efficient_distloss/blob/main/torch_efficient_distloss/eff_distloss.py\nclass EffDistLoss(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, w, m, interval):\n        \"\"\"\n        Efficient O(N) realization of distortion loss.\n        There are B rays each with N sampled points.\n        w:        Float tensor in shape [B,N]. Volume rendering weights of each point.\n        m:        Float tensor in shape [B,N]. Midpoint distance to camera of each point.\n        interval: Scalar or float tensor in shape [B,N]. The query interval of each point.\n        \"\"\"\n        n_rays = np.prod(w.shape[:-1])\n        wm = w * m\n        w_cumsum = w.cumsum(dim=-1)\n        wm_cumsum = wm.cumsum(dim=-1)\n\n        w_total = w_cumsum[..., [-1]]\n        wm_total = wm_cumsum[..., [-1]]\n        w_prefix = torch.cat([torch.zeros_like(w_total), w_cumsum[..., :-1]], dim=-1)\n        wm_prefix = torch.cat([torch.zeros_like(wm_total), wm_cumsum[..., :-1]], dim=-1)\n        loss_uni = (1 / 3) * interval * w.pow(2)\n        loss_bi = 2 * w * (m * w_prefix - wm_prefix)\n        if torch.is_tensor(interval):\n            ctx.save_for_backward(\n                w, m, wm, w_prefix, w_total, wm_prefix, wm_total, interval\n            )\n            ctx.interval = None\n        else:\n            ctx.save_for_backward(w, m, wm, w_prefix, w_total, wm_prefix, wm_total)\n            ctx.interval = interval\n        ctx.n_rays = n_rays\n        return (loss_bi.sum() + loss_uni.sum()) / n_rays\n\n    @staticmethod\n    @torch.autograd.function.once_differentiable\n    def backward(ctx, grad_back):\n        interval = ctx.interval\n        n_rays = ctx.n_rays\n        if interval is None:\n            (\n                w,\n                m,\n                wm,\n                w_prefix,\n                w_total,\n                wm_prefix,\n                wm_total,\n                interval,\n            ) = ctx.saved_tensors\n        else:\n            w, m, wm, w_prefix, w_total, wm_prefix, wm_total = ctx.saved_tensors\n        grad_uni = (1 / 3) * interval * 2 * w\n        w_suffix = w_total - (w_prefix + w)\n        wm_suffix = wm_total - (wm_prefix + wm)\n        grad_bi = 2 * (m * (w_prefix - w_suffix) + (wm_suffix - wm_prefix))\n        grad = grad_back * (grad_bi + grad_uni) / n_rays\n        return grad, None, None, None\n\n\neff_distloss = EffDistLoss.apply\n"
  },
  {
    "path": "lidarnerf/nerf/network.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom lidarnerf.encoding import get_encoder\nfrom lidarnerf.activation import trunc_exp\nfrom .renderer import NeRFRenderer\n\n\nclass NeRFNetwork(NeRFRenderer):\n    def __init__(\n        self,\n        encoding=\"hashgrid\",\n        encoding_dir=\"frequency\",\n        multires=15,\n        encoding_bg=\"hashgrid\",\n        desired_resolution=2048,\n        log2_hashmap_size=19,\n        num_layers=2,\n        hidden_dim=64,\n        geo_feat_dim=15,\n        num_layers_color=3,\n        hidden_dim_color=64,\n        num_layers_bg=2,\n        hidden_dim_bg=64,\n        out_color_dim=3,\n        out_lidar_color_dim=2,\n        bound=1,\n        **kwargs,\n    ):\n        super().__init__(bound, **kwargs)\n\n        # sigma network\n        self.num_layers = num_layers\n        self.hidden_dim = hidden_dim\n        self.geo_feat_dim = geo_feat_dim\n        self.out_color_dim = out_color_dim\n        self.out_lidar_color_dim = out_lidar_color_dim\n        self.encoder, self.in_dim = get_encoder(\n            encoding,\n            desired_resolution=desired_resolution,\n            log2_hashmap_size=log2_hashmap_size,\n        )\n\n        sigma_net = []\n        for l in range(num_layers):\n            if l == 0:\n                in_dim = self.in_dim\n            else:\n                in_dim = hidden_dim\n\n            if l == num_layers - 1:\n                out_dim = 1 + self.geo_feat_dim  # 1 sigma + 15 SH features for color\n            else:\n                out_dim = hidden_dim\n\n            sigma_net.append(nn.Linear(in_dim, out_dim, bias=False))\n\n        self.sigma_net = nn.ModuleList(sigma_net)\n\n        # color network\n        self.num_layers_color = num_layers_color\n        self.hidden_dim_color = hidden_dim_color\n        self.encoder_dir, self.in_dim_dir = get_encoder(\"sphere_harmonics\")\n\n        color_net = []\n        for l in range(num_layers_color):\n            if l == 0:\n                in_dim = self.in_dim_dir + self.geo_feat_dim\n            else:\n                in_dim = hidden_dim_color\n\n            if l == num_layers_color - 1:\n                out_dim = self.out_color_dim  # 3 rgb\n            else:\n                out_dim = hidden_dim_color\n\n            color_net.append(nn.Linear(in_dim, out_dim, bias=False))\n\n        self.color_net = nn.ModuleList(color_net)\n\n        # lidar color network\n        self.encoder_lidar_dir, self.in_dim_dir = get_encoder(\"frequency\", multires=12)\n\n        lidar_color_net = []\n        for l in range(num_layers_color):\n            if l == 0:\n                in_dim = self.in_dim_dir + self.geo_feat_dim\n            else:\n                in_dim = hidden_dim_color\n\n            if l == num_layers_color - 1:\n                out_dim = self.out_lidar_color_dim  # 2 rgb\n            else:\n                out_dim = hidden_dim_color\n\n            lidar_color_net.append(nn.Linear(in_dim, out_dim, bias=False))\n\n        self.lidar_color_net = nn.ModuleList(lidar_color_net)\n\n        # background network\n        if self.bg_radius > 0:\n            self.num_layers_bg = num_layers_bg\n            self.hidden_dim_bg = hidden_dim_bg\n            self.encoder_bg, self.in_dim_bg = get_encoder(\n                encoding_bg,\n                input_dim=2,\n                num_levels=4,\n                log2_hashmap_size=19,\n                desired_resolution=2048,\n            )  # much smaller hashgrid\n\n            bg_net = []\n            for l in range(num_layers_bg):\n                if l == 0:\n                    in_dim = self.in_dim_bg + self.in_dim_dir\n                else:\n                    in_dim = hidden_dim_bg\n\n                if l == num_layers_bg - 1:\n                    out_dim = 3  # 3 rgb\n                else:\n                    out_dim = hidden_dim_bg\n\n                bg_net.append(nn.Linear(in_dim, out_dim, bias=False))\n\n            self.bg_net = nn.ModuleList(bg_net)\n        else:\n            self.bg_net = None\n\n    def forward(self, x, d):\n        # x: [N, 3], in [-bound, bound]\n        # d: [N, 3], nomalized in [-1, 1]\n\n        # sigma\n        x = self.encoder(x, bound=self.bound)\n\n        h = x\n        for l in range(self.num_layers):\n            h = self.sigma_net[l](h)\n            if l != self.num_layers - 1:\n                h = F.relu(h, inplace=True)\n\n        # sigma = F.relu(h[..., 0])\n        sigma = trunc_exp(h[..., 0])\n        geo_feat = h[..., 1:]\n\n        # color\n\n        d = self.encoder_dir(d)\n        h = torch.cat([d, geo_feat], dim=-1)\n        for l in range(self.num_layers_color):\n            h = self.color_net[l](h)\n            if l != self.num_layers_color - 1:\n                h = F.relu(h, inplace=True)\n\n        # sigmoid activation for rgb\n        color = torch.sigmoid(h)\n\n        return sigma, color\n\n    def density(self, x):\n        # x: [N, 3], in [-bound, bound]\n\n        x = self.encoder(x, bound=self.bound)\n        h = x\n        for l in range(self.num_layers):\n            h = self.sigma_net[l](h)\n            if l != self.num_layers - 1:\n                h = F.relu(h, inplace=True)\n\n        # sigma = F.relu(h[..., 0])\n        sigma = trunc_exp(h[..., 0])\n        geo_feat = h[..., 1:]\n\n        return {\n            \"sigma\": sigma,\n            \"geo_feat\": geo_feat,\n        }\n\n    def background(self, x, d):\n        # x: [N, 2], in [-1, 1]\n\n        h = self.encoder_bg(x)  # [N, C]\n        d = self.encoder_dir(d)\n\n        h = torch.cat([d, h], dim=-1)\n        for l in range(self.num_layers_bg):\n            h = self.bg_net[l](h)\n            if l != self.num_layers_bg - 1:\n                h = F.relu(h, inplace=True)\n\n        # sigmoid activation for rgb\n        rgbs = torch.sigmoid(h)\n\n        return rgbs\n\n    # allow masked inference\n    def color(self, x, d, cal_lidar_color=False, mask=None, geo_feat=None, **kwargs):\n        # x: [N, 3] in [-bound, bound]\n        # mask: [N,], bool, indicates where we actually needs to compute rgb.\n\n        if mask is not None:\n            rgbs = torch.zeros(\n                mask.shape[0], self.out_dim, dtype=x.dtype, device=x.device\n            )  # [N, 3]\n            # in case of empty mask\n            if not mask.any():\n                return rgbs\n            x = x[mask]\n            d = d[mask]\n            geo_feat = geo_feat[mask]\n\n        if cal_lidar_color:\n            d = self.encoder_lidar_dir(d)\n            h = torch.cat([d, geo_feat], dim=-1)\n            for l in range(self.num_layers_color):\n                h = self.lidar_color_net[l](h)\n                if l != self.num_layers_color - 1:\n                    h = F.relu(h, inplace=True)\n        else:\n            d = self.encoder_dir(d)\n            h = torch.cat([d, geo_feat], dim=-1)\n            for l in range(self.num_layers_color):\n                h = self.color_net[l](h)\n                if l != self.num_layers_color - 1:\n                    h = F.relu(h, inplace=True)\n\n        # sigmoid activation for rgb\n        h = torch.sigmoid(h)\n\n        if mask is not None:\n            rgbs[mask] = h.to(rgbs.dtype)  # fp16 --> fp32\n        else:\n            rgbs = h\n\n        return rgbs\n\n    # optimizer utils\n    def get_params(self, lr):\n        params = [\n            {\"params\": self.encoder.parameters(), \"lr\": lr},\n            {\"params\": self.sigma_net.parameters(), \"lr\": lr},\n            {\"params\": self.encoder_dir.parameters(), \"lr\": lr},\n            {\"params\": self.color_net.parameters(), \"lr\": lr},\n            {\"params\": self.encoder_lidar_dir.parameters(), \"lr\": lr},\n            {\"params\": self.lidar_color_net.parameters(), \"lr\": lr},\n        ]\n        if self.bg_radius > 0:\n            params.append({\"params\": self.encoder_bg.parameters(), \"lr\": lr})\n            params.append({\"params\": self.bg_net.parameters(), \"lr\": lr})\n\n        return params\n"
  },
  {
    "path": "lidarnerf/nerf/network_tcnn.py",
    "content": "import torch\n\nimport numpy as np\n\nimport tinycudann as tcnn\nfrom lidarnerf.activation import trunc_exp\nfrom .renderer import NeRFRenderer\n\n\nclass NeRFNetwork(NeRFRenderer):\n    def __init__(\n        self,\n        encoding=\"HashGrid\",\n        desired_resolution=2048,\n        log2_hashmap_size=19,\n        encoding_dir=\"SphericalHarmonics\",\n        n_features_per_level=2,\n        num_layers=2,\n        hidden_dim=64,\n        geo_feat_dim=15,\n        num_layers_color=3,\n        hidden_dim_color=64,\n        out_color_dim=3,\n        out_lidar_color_dim=2,\n        bound=1,\n        **kwargs,\n    ):\n        super().__init__(bound, **kwargs)\n\n        # sigma network\n        self.num_layers = num_layers\n        self.hidden_dim = hidden_dim\n        self.geo_feat_dim = geo_feat_dim\n        self.desired_resolution = desired_resolution\n        self.log2_hashmap_size = log2_hashmap_size\n        self.out_color_dim = out_color_dim\n        self.out_lidar_color_dim = out_lidar_color_dim\n        self.n_features_per_level = n_features_per_level\n\n        per_level_scale = np.exp2(\n            np.log2(self.desired_resolution * bound / 16) / (16 - 1)\n        )\n        print(f\"TCNN desired resolution: {self.desired_resolution}\")\n        print(f\"TCNN per level scale: {per_level_scale}\")\n\n        self.encoder = tcnn.Encoding(\n            n_input_dims=3,\n            encoding_config={\n                \"otype\": \"HashGrid\",\n                \"n_levels\": 16,\n                \"n_features_per_level\": self.n_features_per_level,\n                \"log2_hashmap_size\": self.log2_hashmap_size,\n                \"base_resolution\": 16,\n                \"per_level_scale\": per_level_scale,\n                # \"interpolation\": \"Smoothstep\"\n            },\n        )\n\n        self.sigma_net = tcnn.Network(\n            n_input_dims=self.encoder.n_output_dims,\n            n_output_dims=1 + self.geo_feat_dim,\n            network_config={\n                \"otype\": \"FullyFusedMLP\",\n                \"activation\": \"ReLU\",\n                \"output_activation\": \"None\",\n                \"n_neurons\": hidden_dim,\n                \"n_hidden_layers\": num_layers - 1,\n            },\n        )\n\n        # color network\n        self.num_layers_color = num_layers_color\n        self.hidden_dim_color = hidden_dim_color\n\n        # # SH\n        self.encoder_dir = tcnn.Encoding(\n            n_input_dims=3,\n            encoding_config={\n                \"otype\": \"SphericalHarmonics\",\n                \"degree\": 4,\n            },\n        )\n        # # Hash\n        # per_level_scale = np.exp2(np.log2(1024 * bound / 4) / (4 - 1))\n        # self.encoder_dir = tcnn.Encoding(\n        #     n_input_dims=3,\n        #     encoding_config={\n        #         \"otype\": \"HashGrid\",\n        #         \"n_levels\": 4,\n        #         \"n_features_per_level\": 2,\n        #         \"log2_hashmap_size\": self.log2_hashmap_size,\n        #         \"base_resolution\": 128,\n        #         \"per_level_scale\": per_level_scale,\n        #     },\n        # )\n        # # freq\n        self.encoder_lidar_dir = tcnn.Encoding(\n            n_input_dims=3,\n            encoding_config={\n                \"otype\": \"Frequency\",\n                \"degree\": 12,\n            },\n        )\n\n        self.in_dim_color = self.encoder_dir.n_output_dims + self.geo_feat_dim\n\n        self.color_net = tcnn.Network(\n            n_input_dims=self.in_dim_color,\n            n_output_dims=self.out_color_dim,\n            network_config={\n                \"otype\": \"FullyFusedMLP\",\n                \"activation\": \"ReLU\",\n                \"output_activation\": \"None\",\n                \"n_neurons\": hidden_dim_color,\n                \"n_hidden_layers\": num_layers_color - 1,\n            },\n        )\n\n        self.in_dim_lidar_color = (\n            self.encoder_lidar_dir.n_output_dims + self.geo_feat_dim\n        )\n        self.lidar_color_net = tcnn.Network(\n            n_input_dims=self.in_dim_lidar_color,\n            n_output_dims=self.out_lidar_color_dim,\n            network_config={\n                \"otype\": \"FullyFusedMLP\",\n                \"activation\": \"ReLU\",\n                \"output_activation\": \"None\",\n                \"n_neurons\": hidden_dim_color,\n                \"n_hidden_layers\": num_layers_color - 1,\n            },\n        )\n\n    def forward(self, x, d):\n        pass\n\n    def density(self, x):\n        # x: [N, 3], in [-bound, bound]\n\n        x = (x + self.bound) / (2 * self.bound)  # to [0, 1]\n        x = self.encoder(x)\n        h = self.sigma_net(x)\n\n        # sigma = F.relu(h[..., 0])\n        sigma = trunc_exp(h[..., 0])\n        geo_feat = h[..., 1:]\n\n        return {\n            \"sigma\": sigma,\n            \"geo_feat\": geo_feat,\n        }\n\n    # allow masked inference\n    def color(self, x, d, cal_lidar_color=False, mask=None, geo_feat=None, **kwargs):\n        # x: [N, 3] in [-bound, bound]\n        # mask: [N,], bool, indicates where we actually needs to compute rgb.\n\n        x = (x + self.bound) / (2 * self.bound)  # to [0, 1]\n\n        if mask is not None:\n            rgbs = torch.zeros(\n                mask.shape[0], self.out_dim, dtype=x.dtype, device=x.device\n            )  # [N, 3]\n            # in case of empty mask\n            if not mask.any():\n                return rgbs\n            x = x[mask]\n            d = d[mask]\n            geo_feat = geo_feat[mask]\n\n        # color\n        # d = (d + 1) / 2  # tcnn SH encoding requires inputs to be in [0, 1]\n        # d = self.encoder_dir(d)\n\n        # h = torch.cat([d, geo_feat], dim=-1)\n        if cal_lidar_color:\n            d = (d + 1) / 2  # tcnn SH encoding requires inputs to be in [0, 1]\n            d = self.encoder_lidar_dir(d)\n            h = torch.cat([d, geo_feat], dim=-1)\n            h = self.lidar_color_net(h)\n        else:\n            d = (d + 1) / 2\n            d = self.encoder_dir(d)\n            h = torch.cat([d, geo_feat], dim=-1)\n            h = self.color_net(h)\n\n        # sigmoid activation for rgb\n        h = torch.sigmoid(h)\n\n        if mask is not None:\n            rgbs[mask] = h.to(rgbs.dtype)  # fp16 --> fp32\n        else:\n            rgbs = h\n\n        return rgbs\n\n    # optimizer utils\n    def get_params(self, lr):\n        params = [\n            {\"params\": self.encoder.parameters(), \"lr\": lr},\n            {\"params\": self.sigma_net.parameters(), \"lr\": lr},\n            {\"params\": self.encoder_dir.parameters(), \"lr\": lr},\n            {\"params\": self.encoder_lidar_dir.parameters(), \"lr\": lr},\n            {\"params\": self.color_net.parameters(), \"lr\": lr},\n            {\"params\": self.lidar_color_net.parameters(), \"lr\": lr},\n        ]\n        if self.bg_radius > 0:\n            params.append({\"params\": self.encoder_bg.parameters(), \"lr\": lr})\n            params.append({\"params\": self.bg_net.parameters(), \"lr\": lr})\n\n        return params\n"
  },
  {
    "path": "lidarnerf/nerf/renderer.py",
    "content": "import math\nimport trimesh\n\nimport torch\nimport torch.nn as nn\n\nfrom lidarnerf import raymarching\n\n\ndef sample_pdf(bins, weights, n_samples, det=False):\n    # This implementation is from NeRF\n    # bins: [B, T], old_z_vals\n    # weights: [B, T - 1], bin weights.\n    # return: [B, n_samples], new_z_vals\n\n    # Get pdf\n    weights = weights + 1e-5  # prevent nans\n    pdf = weights / torch.sum(weights, -1, keepdim=True)\n    cdf = torch.cumsum(pdf, -1)\n    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)\n    # Take uniform samples\n    if det:\n        u = torch.linspace(\n            0.0 + 0.5 / n_samples, 1.0 - 0.5 / n_samples, steps=n_samples\n        ).to(weights.device)\n        u = u.expand(list(cdf.shape[:-1]) + [n_samples])\n    else:\n        u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)\n\n    # Invert CDF\n    u = u.contiguous()\n    inds = torch.searchsorted(cdf, u, right=True)\n    below = torch.max(torch.zeros_like(inds - 1), inds - 1)\n    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)\n    inds_g = torch.stack([below, above], -1)  # (B, n_samples, 2)\n\n    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]\n    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)\n    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)\n\n    denom = cdf_g[..., 1] - cdf_g[..., 0]\n    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)\n    t = (u - cdf_g[..., 0]) / denom\n    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])\n\n    return samples\n\n\ndef plot_pointcloud(pc, color=None):\n    # pc: [N, 3]\n    # color: [N, 3/4]\n    print(\"[visualize points]\", pc.shape, pc.dtype, pc.min(0), pc.max(0))\n    pc = trimesh.PointCloud(pc, color)\n    # axis\n    axes = trimesh.creation.axis(axis_length=4)\n    # sphere\n    sphere = trimesh.creation.icosphere(radius=1)\n    trimesh.Scene([pc, axes, sphere]).show()\n\n\nclass NeRFRenderer(nn.Module):\n    def __init__(\n        self,\n        bound=1,\n        density_scale=1,  # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance.\n        min_near=0.2,\n        min_near_lidar=0.2,\n        density_thresh=0.01,\n        bg_radius=-1,\n    ):\n        super().__init__()\n\n        self.bound = bound\n        self.cascade = 1 + math.ceil(math.log2(bound))\n        self.grid_size = 128\n        self.density_scale = density_scale\n        self.min_near = min_near\n        self.min_near_lidar = min_near_lidar\n        self.density_thresh = density_thresh\n        self.bg_radius = bg_radius  # radius of the background sphere.\n\n        # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)\n        # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.\n        aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound])\n        aabb_infer = aabb_train.clone()\n        self.register_buffer(\"aabb_train\", aabb_train)\n        self.register_buffer(\"aabb_infer\", aabb_infer)\n\n    def forward(self, x, d):\n        raise NotImplementedError()\n\n    # separated density and color query (can accelerate non-cuda-ray mode.)\n    def density(self, x):\n        raise NotImplementedError()\n\n    def color(self, x, d, mask=None, **kwargs):\n        raise NotImplementedError()\n\n    def run(\n        self,\n        rays_o,\n        rays_d,\n        cal_lidar_color=False,\n        num_steps=128,\n        upsample_steps=128,\n        bg_color=None,\n        perturb=False,\n        **kwargs\n    ):\n        # rays_o, rays_d: [B, N, 3], assumes B == 1\n        # bg_color: [3] in range [0, 1]\n        # return: image: [B, N, 3], depth: [B, N]\n        if cal_lidar_color:\n            self.out_dim = self.out_lidar_color_dim\n        else:\n            self.out_dim = self.out_color_dim\n\n        prefix = rays_o.shape[:-1]\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        N = rays_o.shape[0]  # N = B * N, in fact\n        device = rays_o.device\n\n        # choose aabb\n        aabb = self.aabb_train if self.training else self.aabb_infer\n\n        # sample steps\n        if cal_lidar_color:\n            nears = (\n                torch.ones(N, dtype=rays_o.dtype, device=rays_o.device)\n                * self.min_near_lidar\n            )\n            fars = (\n                torch.ones(N, dtype=rays_o.dtype, device=rays_o.device)\n                * self.min_near_lidar\n                * 81.0\n            )  # hard code\n        else:\n            nears, fars = raymarching.near_far_from_aabb(\n                rays_o, rays_d, aabb, self.min_near\n            )\n\n        nears.unsqueeze_(-1)\n        fars.unsqueeze_(-1)\n\n        # print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')\n\n        z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(\n            0\n        )  # [1, T]\n        z_vals = z_vals.expand((N, num_steps))  # [N, T]\n        z_vals = nears + (fars - nears) * z_vals  # [N, T], in [nears, fars]\n\n        # perturb z_vals\n        sample_dist = (fars - nears) / num_steps\n        if perturb:\n            z_vals = (\n                z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist\n            )\n            # z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.\n\n        # generate xyzs\n        xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(\n            -1\n        )  # [N, 1, 3] * [N, T, 1] -> [N, T, 3]\n        xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:])  # a manual clip.\n\n        # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())\n\n        # query SDF and RGB\n        density_outputs = self.density(xyzs.reshape(-1, 3))\n\n        # sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]\n        for k, v in density_outputs.items():\n            density_outputs[k] = v.view(N, num_steps, -1)\n\n        # upsample z_vals (nerf-like)\n        if upsample_steps > 0:\n            with torch.no_grad():\n                deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T-1]\n                deltas = torch.cat(\n                    [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1\n                )\n\n                alphas = 1 - torch.exp(\n                    -deltas * self.density_scale * density_outputs[\"sigma\"].squeeze(-1)\n                )  # [N, T]\n                alphas_shifted = torch.cat(\n                    [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1\n                )  # [N, T+1]\n                weights = (\n                    alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]\n                )  # [N, T]\n\n                z_vals_mid = z_vals[..., :-1] + 0.5 * deltas[..., :-1]  # [N, T-1]\n                new_z_vals = sample_pdf(\n                    z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training\n                ).detach()  # [N, t]\n\n                new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(\n                    -2\n                ) * new_z_vals.unsqueeze(\n                    -1\n                )  # [N, 1, 3] * [N, t, 1] -> [N, t, 3]\n                new_xyzs = torch.min(\n                    torch.max(new_xyzs, aabb[:3]), aabb[3:]\n                )  # a manual clip.\n\n            # only forward new points to save computation\n            new_density_outputs = self.density(new_xyzs.reshape(-1, 3))\n            # new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]\n            for k, v in new_density_outputs.items():\n                new_density_outputs[k] = v.view(N, upsample_steps, -1)\n\n            # re-order\n            z_vals = torch.cat([z_vals, new_z_vals], dim=1)  # [N, T+t]\n            z_vals, z_index = torch.sort(z_vals, dim=1)\n\n            xyzs = torch.cat([xyzs, new_xyzs], dim=1)  # [N, T+t, 3]\n            xyzs = torch.gather(\n                xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs)\n            )\n\n            for k in density_outputs:\n                tmp_output = torch.cat(\n                    [density_outputs[k], new_density_outputs[k]], dim=1\n                )\n                density_outputs[k] = torch.gather(\n                    tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output)\n                )\n\n        deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T+t-1]\n        deltas = torch.cat(\n            [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1\n        )\n        alphas = 1 - torch.exp(\n            -deltas * self.density_scale * density_outputs[\"sigma\"].squeeze(-1)\n        )  # [N, T+t]\n        alphas_shifted = torch.cat(\n            [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1\n        )  # [N, T+t+1]\n        weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T+t]\n\n        dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)\n        for k, v in density_outputs.items():\n            density_outputs[k] = v.view(-1, v.shape[-1])\n\n        mask = weights > 1e-4  # hard coded\n        rgbs = self.color(\n            xyzs.reshape(-1, 3),\n            dirs.reshape(-1, 3),\n            cal_lidar_color=cal_lidar_color,\n            mask=mask.reshape(-1),\n            **density_outputs\n        )\n\n        rgbs = rgbs.view(N, -1, self.out_dim)  # [N, T+t, 3]\n\n        # print(xyzs.shape, 'valid_rgb:', mask.sum().item())\n\n        # calculate weight_sum (mask)\n        weights_sum = weights.sum(dim=-1)  # [N]\n\n        # calculate depth  Note: not real depth!!\n        # ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)\n        # depth = torch.sum(weights * ori_z_vals, dim=-1)\n        depth = torch.sum(weights * z_vals, dim=-1)\n\n        # calculate color\n        image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2)  # [N, 3], in [0, 1]\n\n        # mix background color\n        if self.bg_radius > 0:\n            # use the bg model to calculate bg_color\n            sph = raymarching.sph_from_ray(\n                rays_o, rays_d, self.bg_radius\n            )  # [N, 2] in [-1, 1]\n            bg_color = self.background(sph, rays_d.reshape(-1, 3))  # [N, 3]\n        elif bg_color is None:\n            bg_color = 1\n\n        if not cal_lidar_color:\n            image = image + (1 - weights_sum).unsqueeze(-1) * bg_color\n\n        image = image.view(*prefix, self.out_dim)\n        depth = depth.view(*prefix)\n\n        # tmp: reg loss in mip-nerf 360\n        # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1)\n        # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T]\n        # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals_shifted) * (weights ** 2)).sum()\n\n        return {\n            \"depth_lidar\": depth,\n            \"image_lidar\": image,\n            \"weights_sum_lidar\": weights_sum,\n        }\n\n    def render(\n        self,\n        rays_o,\n        rays_d,\n        cal_lidar_color=False,\n        staged=False,\n        max_ray_batch=4096,\n        **kwargs\n    ):\n        # rays_o, rays_d: [B, N, 3], assumes B == 1\n        # return: pred_rgb: [B, N, 3]\n\n        _run = self.run\n\n        B, N = rays_o.shape[:2]\n        device = rays_o.device\n\n        if staged:\n            if cal_lidar_color:\n                out_dim = self.out_lidar_color_dim\n                res_keys = [\"depth_lidar\", \"image_lidar\"]\n            depth = torch.empty((B, N), device=device)\n            image = torch.empty((B, N, out_dim), device=device)\n\n            for b in range(B):\n                head = 0\n                while head < N:\n                    tail = min(head + max_ray_batch, N)\n                    results_ = _run(\n                        rays_o[b : b + 1, head:tail],\n                        rays_d[b : b + 1, head:tail],\n                        cal_lidar_color=cal_lidar_color,\n                        **kwargs\n                    )\n                    depth[b : b + 1, head:tail] = results_[res_keys[0]]\n                    image[b : b + 1, head:tail] = results_[res_keys[1]]\n                    head += max_ray_batch\n\n            results = {}\n            results[res_keys[0]] = depth\n            results[res_keys[1]] = image\n\n        else:\n            results = _run(rays_o, rays_d, cal_lidar_color=cal_lidar_color, **kwargs)\n\n        return results\n"
  },
  {
    "path": "lidarnerf/nerf/utils.py",
    "content": "import glob\nimport os\nimport random\nimport time\n\nimport cv2\nimport imageio\nimport lpips\nimport mcubes\nimport numpy as np\nimport tensorboardX\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport tqdm\nimport trimesh\nfrom rich.console import Console\nfrom skimage.metrics import structural_similarity\nfrom torch_ema import ExponentialMovingAverage\n\nfrom extern.chamfer3D.dist_chamfer_3D import chamfer_3DDist\nfrom extern.fscore import fscore\n\nfrom lidarnerf.dataset.base_dataset import custom_meshgrid\n\nfrom lidarnerf.convert import pano_to_lidar\n\n\ndef is_ali_cluster():\n    import socket\n\n    hostname = socket.gethostname()\n    return \"auto-drive\" in hostname\n\n\n@torch.jit.script\ndef linear_to_srgb(x):\n    return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x**0.41666 - 0.055)\n\n\n@torch.jit.script\ndef srgb_to_linear(x):\n    return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)\n\n\ndef filter_bbox_dataset(pc, OBB_local):\n    bbox_mask = np.isnan(pc[:, 0])\n    z_min, z_max = min(OBB_local[:, 2]), max(OBB_local[:, 2])\n    for i, (c1, c2) in enumerate(zip(pc[:, 2] <= z_max, pc[:, 2] >= z_min)):\n        bbox_mask[i] = c1 and c2\n    pc = pc[bbox_mask]\n    OBB_local = sorted(OBB_local, key=lambda p: p[2])\n    OBB_2D = np.array(OBB_local)[:4, :2]\n    pc = filter_poly(pc, OBB_2D)\n    return pc\n\n\ndef filter_poly(pcs, OBB_2D):\n    OBB_2D = sort_quadrilateral(OBB_2D)\n    mask = []\n    for pc in pcs:\n        mask.append(is_in_poly(pc[0], pc[1], OBB_2D))\n    return pcs[mask]\n\n\ndef sort_quadrilateral(points):\n    points = points.tolist()\n    top_left = min(points, key=lambda p: p[0] + p[1])\n    bottom_right = max(points, key=lambda p: p[0] + p[1])\n    points.remove(top_left)\n    points.remove(bottom_right)\n    bottom_left, top_right = points\n    if bottom_left[1] > top_right[1]:\n        bottom_left, top_right = top_right, bottom_left\n    return [top_left, top_right, bottom_right, bottom_left]\n\n\ndef is_in_poly(px, py, poly):\n    \"\"\"\n    :param p: [x, y]\n    :param poly: [[], [], [], [], ...]\n    :return:\n    \"\"\"\n    is_in = False\n    for i, corner in enumerate(poly):\n        next_i = i + 1 if i + 1 < len(poly) else 0\n        x1, y1 = corner\n        x2, y2 = poly[next_i]\n        if (x1 == px and y1 == py) or (x2 == px and y2 == py):  # if point is on vertex\n            is_in = True\n            break\n        if min(y1, y2) < py <= max(y1, y2):  # find horizontal edges of polygon\n            x = x1 + (py - y1) * (x2 - x1) / (y2 - y1)\n            if x == px:  # if point is on edge\n                is_in = True\n                break\n            elif x > px:  # if point is on left-side of line\n                is_in = not is_in\n    return is_in\n\n\ndef seed_everything(seed):\n    random.seed(seed)\n    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    # torch.backends.cudnn.deterministic = True\n    # torch.backends.cudnn.benchmark = True\n\n\ndef torch_vis_2d(x, renormalize=False):\n    # x: [3, H, W] or [1, H, W] or [H, W]\n    import matplotlib.pyplot as plt\n    import numpy as np\n    import torch\n\n    if isinstance(x, torch.Tensor):\n        if len(x.shape) == 3:\n            x = x.permute(1, 2, 0).squeeze()\n        x = x.detach().cpu().numpy()\n\n    print(f\"[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}\")\n\n    x = x.astype(np.float32)\n\n    # renormalize\n    if renormalize:\n        x = (x - x.min(axis=0, keepdims=True)) / (\n            x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8\n        )\n\n    plt.imshow(x)\n    plt.show()\n\n\ndef extract_fields(bound_min, bound_max, resolution, query_func, S=128):\n    X = torch.linspace(bound_min[0], bound_max[0], resolution).split(S)\n    Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(S)\n    Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(S)\n\n    u = np.zeros([resolution, resolution, resolution], dtype=np.float32)\n    with torch.no_grad():\n        for xi, xs in enumerate(X):\n            for yi, ys in enumerate(Y):\n                for zi, zs in enumerate(Z):\n                    xx, yy, zz = custom_meshgrid(xs, ys, zs)\n                    pts = torch.cat(\n                        [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],\n                        dim=-1,\n                    )  # [S, 3]\n                    val = (\n                        query_func(pts)\n                        .reshape(len(xs), len(ys), len(zs))\n                        .detach()\n                        .cpu()\n                        .numpy()\n                    )  # [S, 1] --> [x, y, z]\n                    u[\n                        xi * S : xi * S + len(xs),\n                        yi * S : yi * S + len(ys),\n                        zi * S : zi * S + len(zs),\n                    ] = val\n    return u\n\n\ndef extract_geometry(bound_min, bound_max, resolution, threshold, query_func):\n    # print('threshold: {}'.format(threshold))\n    u = extract_fields(bound_min, bound_max, resolution, query_func)\n\n    # print(u.shape, u.max(), u.min(), np.percentile(u, 50))\n\n    vertices, triangles = mcubes.marching_cubes(u, threshold)\n\n    b_max_np = bound_max.detach().cpu().numpy()\n    b_min_np = bound_min.detach().cpu().numpy()\n\n    vertices = (\n        vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :]\n        + b_min_np[None, :]\n    )\n    return vertices, triangles\n\n\nclass PSNRMeter:\n    def __init__(self):\n        self.V = 0\n        self.N = 0\n\n    def clear(self):\n        self.V = 0\n        self.N = 0\n\n    def prepare_inputs(self, *inputs):\n        outputs = []\n        for i, inp in enumerate(inputs):\n            if torch.is_tensor(inp):\n                inp = inp.detach().cpu().numpy()\n            outputs.append(inp)\n\n        return outputs\n\n    def update(self, preds, truths):\n        preds, truths = self.prepare_inputs(\n            preds, truths\n        )  # [B, N, 3] or [B, H, W, 3], range[0, 1]\n\n        # simplified since max_pixel_value is 1 here.\n        psnr = -10 * np.log10(np.mean((preds - truths) ** 2))\n\n        self.V += psnr\n        self.N += 1\n\n    def measure(self):\n        return self.V / self.N\n\n    def write(self, writer, global_step, prefix=\"\"):\n        writer.add_scalar(os.path.join(prefix, \"PSNR\"), self.measure(), global_step)\n\n    def report(self):\n        return f\"PSNR = {self.measure():.6f}\"\n\n\nclass RMSEMeter:\n    def __init__(self):\n        self.V = 0\n        self.N = 0\n\n    def clear(self):\n        self.V = 0\n        self.N = 0\n\n    def prepare_inputs(self, *inputs):\n        outputs = []\n        for i, inp in enumerate(inputs):\n            if torch.is_tensor(inp):\n                inp = inp.detach().cpu().numpy()\n            outputs.append(inp)\n\n        return outputs\n\n    def update(self, preds, truths):\n        preds, truths = self.prepare_inputs(\n            preds, truths\n        )  # [B, N, 3] or [B, H, W, 3], range[0, 1]\n\n        rmse = (truths - preds) ** 2\n        rmse = np.sqrt(rmse.mean())\n\n        self.V += rmse\n        self.N += 1\n\n    def measure(self):\n        return self.V / self.N\n\n    def write(self, writer, global_step, prefix=\"\"):\n        writer.add_scalar(os.path.join(prefix, \"RMSE\"), self.measure(), global_step)\n\n    def report(self):\n        return f\"RMSE = {self.measure():.6f}\"\n\n\nclass MAEMeter:\n    def __init__(self, intensity_inv_scale=1.0):\n        self.V = 0\n        self.N = 0\n        self.intensity_inv_scale = intensity_inv_scale\n\n    def clear(self):\n        self.V = 0\n        self.N = 0\n\n    def prepare_inputs(self, *inputs):\n        outputs = []\n        for i, inp in enumerate(inputs):\n            if torch.is_tensor(inp):\n                inp = inp.detach().cpu().numpy()\n            outputs.append(inp)\n\n        return outputs\n\n    def update(self, preds, truths):\n        preds, truths = self.prepare_inputs(\n            preds, truths\n        )  # [B, N, 3] or [B, H, W, 3], range[0, 1]\n\n        # Mean Absolute Error\n        mae = np.abs(\n            truths * self.intensity_inv_scale - preds * self.intensity_inv_scale\n        ).mean()\n\n        self.V += mae\n        self.N += 1\n\n    def measure(self):\n        return self.V / self.N\n\n    def write(self, writer, global_step, prefix=\"\"):\n        writer.add_scalar(os.path.join(prefix, \"MAE\"), self.measure(), global_step)\n\n    def report(self):\n        return f\"MAE = {self.measure():.6f}\"\n\n\nclass DepthMeter:\n    def __init__(self, scale):\n        self.V = []\n        self.N = 0\n        self.scale = scale\n        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    def clear(self):\n        self.V = []\n        self.N = 0\n\n    def prepare_inputs(self, *inputs):\n        outputs = []\n        for i, inp in enumerate(inputs):\n            if torch.is_tensor(inp):\n                inp = inp.detach().cpu().numpy()\n            outputs.append(inp)\n\n        return outputs\n\n    def update(self, preds, truths):\n        preds = preds / self.scale\n        truths = truths / self.scale\n        preds, truths = self.prepare_inputs(\n            preds, truths\n        )  # [B, N, 3] or [B, H, W, 3], range[0, 1]\n\n        # simplified since max_pixel_value is 1 here.\n        depth_error = self.compute_depth_errors(truths, preds)\n\n        depth_error = list(depth_error)\n        self.V.append(depth_error)\n        self.N += 1\n\n    def compute_depth_errors(\n        self, gt, pred, min_depth=1e-3, max_depth=80, thresh_set=1.25\n    ):\n        pred[pred < min_depth] = min_depth\n        pred[pred > max_depth] = max_depth\n        gt[gt < min_depth] = min_depth\n        gt[gt > max_depth] = max_depth\n\n        thresh = np.maximum((gt / pred), (pred / gt))\n        a1 = (thresh < thresh_set).mean()\n        a2 = (thresh < thresh_set**2).mean()\n        a3 = (thresh < thresh_set**3).mean()\n\n        rmse = (gt - pred) ** 2\n        rmse = np.sqrt(rmse.mean())\n\n        ssim = structural_similarity(\n            pred.squeeze(0), gt.squeeze(0), data_range=np.max(gt) - np.min(gt)\n        )\n        return rmse, a1, a2, a3, ssim\n\n    def measure(self):\n        assert self.N == len(self.V)\n        return np.array(self.V).mean(0)\n\n    def write(self, writer, global_step, prefix=\"\"):\n        writer.add_scalar(\n            os.path.join(prefix, \"depth error\"), self.measure()[0], global_step\n        )\n\n    def report(self):\n        return f\"Depth_error(rmse, a1, a2, a3, ssim) = {self.measure()}\"\n\n\nclass PointsMeter:\n    def __init__(self, scale, intrinsics):\n        self.V = []\n        self.N = 0\n        self.scale = scale\n        self.intrinsics = intrinsics\n\n    def clear(self):\n        self.V = []\n        self.N = 0\n\n    def prepare_inputs(self, *inputs):\n        outputs = []\n        for i, inp in enumerate(inputs):\n            if torch.is_tensor(inp):\n                inp = inp.detach().cpu().numpy()\n            outputs.append(inp)\n\n        return outputs\n\n    def update(self, preds, truths):\n        preds = preds / self.scale\n        truths = truths / self.scale\n        preds, truths = self.prepare_inputs(\n            preds, truths\n        )  # [B, N, 3] or [B, H, W, 3], range[0, 1]\n        chamLoss = chamfer_3DDist()\n        pred_lidar = pano_to_lidar(preds[0], self.intrinsics)\n        gt_lidar = pano_to_lidar(truths[0], self.intrinsics)\n\n        dist1, dist2, idx1, idx2 = chamLoss(\n            torch.FloatTensor(pred_lidar[None, ...]).cuda(),\n            torch.FloatTensor(gt_lidar[None, ...]).cuda(),\n        )\n        chamfer_dis = dist1.mean() + dist2.mean()\n        threshold = 0.05  # monoSDF\n        f_score, precision, recall = fscore(dist1, dist2, threshold)\n        f_score = f_score.cpu()[0]\n\n        self.V.append([chamfer_dis.cpu(), f_score])\n\n        self.N += 1\n\n    def measure(self):\n        # return self.V / self.N\n        assert self.N == len(self.V)\n        return np.array(self.V).mean(0)\n\n    def write(self, writer, global_step, prefix=\"\"):\n        writer.add_scalar(os.path.join(prefix, \"CD\"), self.measure()[0], global_step)\n\n    def report(self):\n        return f\"CD f-score = {self.measure()}\"\n\n\nclass SSIMMeter:\n    def __init__(self, device=None):\n        self.V = 0\n        self.N = 0\n\n        self.device = (\n            device\n            if device is not None\n            else torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        )\n\n    def clear(self):\n        self.V = 0\n        self.N = 0\n\n    # def prepare_inputs(self, *inputs):\n    #     outputs = []\n    #     for i, inp in enumerate(inputs):\n    #         inp = inp.permute(0, 3, 1, 2).contiguous()  # [B, 3, H, W]\n    #         inp = inp.to(self.device)\n    #         outputs.append(inp)\n    #     return outputs\n\n    def prepare_inputs(self, *inputs):\n        outputs = []\n        for i, inp in enumerate(inputs):\n            if torch.is_tensor(inp):\n                inp = inp.detach().cpu().numpy()\n            outputs.append(inp)\n\n        return outputs\n\n    def update(self, preds, truths):\n        preds, truths = self.prepare_inputs(preds, truths)\n        ssim = structural_similarity(\n            preds.squeeze(0).squeeze(-1), truths.squeeze(0).squeeze(-1)\n        )\n\n        # preds, truths = self.prepare_inputs(\n        #     preds, truths)  # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1]\n\n        # ssim = structural_similarity_index_measure(preds, truths)\n\n        self.V += ssim\n        self.N += 1\n\n    def measure(self):\n        return self.V / self.N\n\n    def write(self, writer, global_step, prefix=\"\"):\n        writer.add_scalar(os.path.join(prefix, \"SSIM\"), self.measure(), global_step)\n\n    def report(self):\n        return f\"SSIM = {self.measure():.6f}\"\n\n\nclass LPIPSMeter:\n    def __init__(self, net=\"alex\", device=None):\n        self.V = 0\n        self.N = 0\n        self.net = net\n\n        self.device = (\n            device\n            if device is not None\n            else torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        )\n        self.fn = lpips.LPIPS(net=net).eval().to(self.device)\n\n    def clear(self):\n        self.V = 0\n        self.N = 0\n\n    def prepare_inputs(self, *inputs):\n        outputs = []\n        for i, inp in enumerate(inputs):\n            inp = inp.permute(0, 3, 1, 2).contiguous()  # [B, 3, H, W]\n            inp = inp.to(self.device)\n            outputs.append(inp)\n        return outputs\n\n    def update(self, preds, truths):\n        preds, truths = self.prepare_inputs(\n            preds, truths\n        )  # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1]\n        v = self.fn(\n            truths, preds, normalize=True\n        ).item()  # normalize=True: [0, 1] to [-1, 1]\n        self.V += v\n        self.N += 1\n\n    def measure(self):\n        return self.V / self.N\n\n    def write(self, writer, global_step, prefix=\"\"):\n        writer.add_scalar(\n            os.path.join(prefix, f\"LPIPS ({self.net})\"), self.measure(), global_step\n        )\n\n    def report(self):\n        return f\"LPIPS ({self.net}) = {self.measure():.6f}\"\n\n\nclass Trainer(object):\n    def __init__(\n        self,\n        name,  # name of this experiment\n        opt,  # extra conf\n        model,  # network\n        criterion=None,  # loss function, if None, assume inline implementation in train_step\n        optimizer=None,  # optimizer\n        ema_decay=None,  # if use EMA, set the decay\n        lr_scheduler=None,  # scheduler\n        metrics=[],  # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.\n        depth_metrics=[],\n        local_rank=0,  # which GPU am I\n        world_size=1,  # total num of GPUs\n        device=None,  # device to use, usually setting to None is OK. (auto choose device)\n        mute=False,  # whether to mute all print\n        fp16=False,  # amp optimize level\n        eval_interval=1,  # eval once every $ epoch\n        max_keep_ckpt=2,  # max num of saved ckpts in disk\n        workspace=\"workspace\",  # workspace to save logs & ckpts\n        best_mode=\"min\",  # the smaller/larger result, the better\n        use_loss_as_metric=True,  # use loss as the first metric\n        report_metric_at_train=False,  # also report metrics at training\n        use_checkpoint=\"latest\",  # which ckpt to use at init time\n        use_tensorboardX=True,  # whether to use tensorboard for logging\n        scheduler_update_every_step=False,  # whether to call scheduler.step() after every train step\n    ):\n        self.name = name\n        self.opt = opt\n        self.mute = mute\n        self.metrics = metrics\n        self.depth_metrics = depth_metrics\n        self.local_rank = local_rank\n        self.world_size = world_size\n        self.workspace = workspace\n        self.ema_decay = ema_decay\n        self.fp16 = fp16\n        self.best_mode = best_mode\n        self.use_loss_as_metric = use_loss_as_metric\n        self.report_metric_at_train = report_metric_at_train\n        self.max_keep_ckpt = max_keep_ckpt\n        self.eval_interval = eval_interval\n        self.use_checkpoint = use_checkpoint\n        self.use_tensorboardX = use_tensorboardX\n        self.time_stamp = time.strftime(\"%Y-%m-%d_%H-%M-%S\")\n        self.scheduler_update_every_step = scheduler_update_every_step\n        self.device = (\n            device\n            if device is not None\n            else torch.device(\n                f\"cuda:{local_rank}\" if torch.cuda.is_available() else \"cpu\"\n            )\n        )\n        self.console = Console()\n\n        model.to(self.device)\n        if self.world_size > 1:\n            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)\n            model = torch.nn.parallel.DistributedDataParallel(\n                model, device_ids=[local_rank]\n            )\n        self.model = model\n\n        if isinstance(criterion, nn.Module):\n            criterion.to(self.device)\n        self.criterion = criterion\n\n        # optionally use LPIPS loss for patch-based training\n        # if self.opt.patch_size > 1:\n        #     import lpips\n        #     self.criterion_lpips = lpips.LPIPS(net='alex').to(self.device)\n\n        if optimizer is None:\n            self.optimizer = optim.Adam(\n                self.model.parameters(), lr=0.001, weight_decay=5e-4\n            )  # naive adam\n        else:\n            self.optimizer = optimizer(self.model)\n\n        if lr_scheduler is None:\n            self.lr_scheduler = optim.lr_scheduler.LambdaLR(\n                self.optimizer, lr_lambda=lambda epoch: 1\n            )  # fake scheduler\n        else:\n            self.lr_scheduler = lr_scheduler(self.optimizer)\n\n        if ema_decay is not None:\n            self.ema = ExponentialMovingAverage(\n                self.model.parameters(), decay=ema_decay\n            )\n        else:\n            self.ema = None\n\n        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)\n\n        # variable init\n        self.epoch = 0\n        self.global_step = 0\n        self.local_step = 0\n        self.stats = {\n            \"loss\": [],\n            \"valid_loss\": [],\n            \"results\": [],  # metrics[0], or valid_loss\n            \"checkpoints\": [],  # record path of saved ckpt, to automatically remove old ckpt\n            \"best_result\": None,\n        }\n\n        # auto fix\n        if len(metrics) == 0 or self.use_loss_as_metric:\n            self.best_mode = \"min\"\n\n        # workspace prepare\n        self.log_ptr = None\n        if self.workspace is not None:\n            os.makedirs(self.workspace, exist_ok=True)\n            self.log_path = os.path.join(workspace, f\"log_{self.name}.txt\")\n            self.log_ptr = open(self.log_path, \"a+\")\n\n            self.ckpt_path = os.path.join(self.workspace, \"checkpoints\")\n            self.best_path = f\"{self.ckpt_path}/{self.name}.pth\"\n            os.makedirs(self.ckpt_path, exist_ok=True)\n\n        self.log(\n            f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {\"fp16\" if self.fp16 else \"fp32\"} | {self.workspace}'\n        )\n        self.log(\n            f\"[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}\"\n        )\n\n        if self.workspace is not None:\n            if self.use_checkpoint == \"scratch\":\n                self.log(\"[INFO] Training from scratch ...\")\n            elif self.use_checkpoint == \"latest\":\n                self.log(\"[INFO] Loading latest checkpoint ...\")\n                self.load_checkpoint()\n            elif self.use_checkpoint == \"latest_model\":\n                self.log(\"[INFO] Loading latest checkpoint (model only)...\")\n                self.load_checkpoint(model_only=True)\n            elif self.use_checkpoint == \"best\":\n                if os.path.exists(self.best_path):\n                    self.log(\"[INFO] Loading best checkpoint ...\")\n                    self.load_checkpoint(self.best_path)\n                else:\n                    self.log(f\"[INFO] {self.best_path} not found, loading latest ...\")\n                    self.load_checkpoint()\n            else:  # path to ckpt\n                self.log(f\"[INFO] Loading {self.use_checkpoint} ...\")\n                self.load_checkpoint(self.use_checkpoint)\n\n    def __del__(self):\n        if self.log_ptr:\n            self.log_ptr.close()\n\n    def log(self, *args, **kwargs):\n        if self.local_rank == 0:\n            if not self.mute:\n                # print(*args)\n                self.console.print(*args, **kwargs)\n            if self.log_ptr:\n                print(*args, file=self.log_ptr)\n                self.log_ptr.flush()  # write immediately to file\n\n    ### ------------------------------\n\n    def train_step(self, data):\n        # Initialize all returned values\n        pred_intensity = None\n        gt_intensity = None\n        pred_depth = None\n        gt_depth = None\n        loss = 0\n\n        if self.opt.enable_lidar:\n            rays_o_lidar = data[\"rays_o_lidar\"]  # [B, N, 3]\n            rays_d_lidar = data[\"rays_d_lidar\"]  # [B, N, 3]\n\n            images_lidar = data[\"images_lidar\"]  # [B, N, 3/4]\n            B_lidar, N_lidar, C_lidar = images_lidar.shape\n\n            gt_raydrop = images_lidar[:, :, 0]\n            gt_intensity = images_lidar[:, :, 1] * gt_raydrop\n            gt_depth = images_lidar[:, :, 2] * gt_raydrop\n\n            outputs_lidar = self.model.render(\n                rays_o_lidar,\n                rays_d_lidar,\n                cal_lidar_color=True,\n                staged=False,\n                perturb=True,\n                force_all_rays=False if self.opt.patch_size == 1 else True,\n                **vars(self.opt),\n            )\n\n            pred_raydrop = outputs_lidar[\"image_lidar\"][:, :, 0]\n            pred_intensity = outputs_lidar[\"image_lidar\"][:, :, 1] * gt_raydrop\n            pred_depth = outputs_lidar[\"depth_lidar\"] * gt_raydrop\n            lidar_loss = (\n                self.opt.alpha_d * self.criterion[\"depth\"](pred_depth, gt_depth)\n                + self.opt.alpha_r * self.criterion[\"raydrop\"](pred_raydrop, gt_raydrop)\n                + self.opt.alpha_i\n                * self.criterion[\"intensity\"](pred_intensity, gt_intensity)\n            )\n            pred_intensity = pred_intensity.unsqueeze(-1)\n            gt_intensity = gt_intensity.unsqueeze(-1)\n        else:\n            lidar_loss = 0\n\n        loss = lidar_loss\n\n        # special case for CCNeRF's rank-residual training\n        if len(loss.shape) == 3:  # [K, B, N]\n            loss = loss.mean(0)\n\n        loss = loss.mean()\n\n        if isinstance(self.opt.patch_size_lidar, int):\n            patch_size_x, patch_size_y = (\n                self.opt.patch_size_lidar,\n                self.opt.patch_size_lidar,\n            )\n        elif len(self.opt.patch_size_lidar) == 1:\n            patch_size_x, patch_size_y = (\n                self.opt.patch_size_lidar[0],\n                self.opt.patch_size_lidar[0],\n            )\n        else:\n            patch_size_x, patch_size_y = self.opt.patch_size_lidar\n        if self.opt.enable_lidar and patch_size_x > 1:\n            pred_depth = (\n                pred_depth.view(-1, patch_size_x, patch_size_y, 1)\n                .permute(0, 3, 1, 2)\n                .contiguous()\n                / self.opt.scale\n            )\n            if self.opt.sobel_grad:\n                pred_grad_x = F.conv2d(\n                    pred_depth,\n                    torch.tensor(\n                        [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32\n                    )\n                    .unsqueeze(0)\n                    .unsqueeze(0)\n                    .to(self.device),\n                    padding=1,\n                )\n                pred_grad_y = F.conv2d(\n                    pred_depth,\n                    torch.tensor(\n                        [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32\n                    )\n                    .unsqueeze(0)\n                    .unsqueeze(0)\n                    .to(self.device),\n                    padding=1,\n                )\n            else:\n                pred_grad_y = torch.abs(\n                    pred_depth[:, :, :-1, :] - pred_depth[:, :, 1:, :]\n                )\n                pred_grad_x = torch.abs(\n                    pred_depth[:, :, :, :-1] - pred_depth[:, :, :, 1:]\n                )\n\n            dy = torch.abs(pred_grad_y)\n            dx = torch.abs(pred_grad_x)\n\n            if self.opt.grad_norm_smooth:\n                grad_norm = torch.mean(torch.exp(-dx)) + torch.mean(torch.exp(-dy))\n                # print('grad_norm', grad_norm)\n                loss = loss + self.opt.alpha_grad_norm * grad_norm\n\n            if self.opt.spatial_smooth:\n                spatial_loss = torch.mean(dx**2) + torch.mean(dy**2)\n                # print('spatial_loss', spatial_loss)\n                loss = loss + self.opt.alpha_spatial * spatial_loss\n\n            if self.opt.tv_loss:\n                tv_loss = torch.mean(dx) + torch.mean(dy)\n                # print('tv_loss', tv_loss)\n                loss = loss + self.opt.alpha_tv * tv_loss\n\n            if self.opt.grad_loss:\n                gt_depth = (\n                    gt_depth.view(-1, patch_size_x, patch_size_y, 1)\n                    .permute(0, 3, 1, 2)\n                    .contiguous()\n                    / self.opt.scale\n                )\n                gt_raydrop = (\n                    gt_raydrop.view(-1, patch_size_x, patch_size_y, 1)\n                    .permute(0, 3, 1, 2)\n                    .contiguous()\n                )\n\n                # sobel\n                if self.opt.sobel_grad:\n                    gt_grad_y = F.conv2d(\n                        gt_depth,\n                        torch.tensor(\n                            [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32\n                        )\n                        .unsqueeze(0)\n                        .unsqueeze(0)\n                        .to(self.device),\n                        padding=1,\n                    )\n\n                    gt_grad_x = F.conv2d(\n                        gt_depth,\n                        torch.tensor(\n                            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32\n                        )\n                        .unsqueeze(0)\n                        .unsqueeze(0)\n                        .to(self.device),\n                        padding=1,\n                    )\n                else:\n                    gt_grad_y = gt_depth[:, :, :-1, :] - gt_depth[:, :, 1:, :]\n                    gt_grad_x = gt_depth[:, :, :, :-1] - gt_depth[:, :, :, 1:]\n\n                grad_clip_x = 0.01\n                grad_mask_x = torch.where(torch.abs(gt_grad_x) < grad_clip_x, 1, 0)\n                grad_clip_y = 0.01\n                grad_mask_y = torch.where(torch.abs(gt_grad_y) < grad_clip_y, 1, 0)\n                if self.opt.sobel_grad:\n                    mask_dx = gt_raydrop * grad_mask_x\n                    mask_dy = gt_raydrop * grad_mask_y\n                else:\n                    mask_dx = gt_raydrop[:, :, :, :-1] * grad_mask_x\n                    mask_dy = gt_raydrop[:, :, :-1, :] * grad_mask_y\n\n                if self.opt.depth_grad_loss == \"cos\":\n                    patch_num = pred_grad_x.shape[0]\n                    grad_loss = self.criterion[\"grad\"](\n                        (pred_grad_x * mask_dx).reshape(patch_num, -1),\n                        (gt_grad_x * mask_dx).reshape(patch_num, -1),\n                    )\n                    grad_loss = 1 - grad_loss\n                else:\n                    grad_loss = self.criterion[\"grad\"](\n                        pred_grad_x * mask_dx, gt_grad_x * mask_dx\n                    )\n                loss = loss + self.opt.alpha_grad * grad_loss.mean()\n\n        return (\n            pred_intensity,\n            gt_intensity,\n            pred_depth,\n            gt_depth,\n            loss,\n        )\n\n    def eval_step(self, data):\n        pred_intensity = None\n        pred_depth = None\n        pred_depth_crop = None\n        pred_raydrop = None\n        gt_intensity = None\n        gt_depth = None\n        gt_depth_crop = None\n        gt_raydrop = None\n        loss = 0\n        if self.opt.enable_lidar:\n            rays_o_lidar = data[\"rays_o_lidar\"]  # [B, N, 3]\n            rays_d_lidar = data[\"rays_d_lidar\"]  # [B, N, 3]\n            images_lidar = data[\"images_lidar\"]  # [B, H, W, 3/4]\n\n            gt_raydrop = images_lidar[:, :, :, 0]\n            if self.opt.dataloader == \"nerf_mvl\":\n                valid_crop = gt_raydrop != -1\n                valid_crop_idx = torch.nonzero(valid_crop)\n                crop_h, crop_w = (\n                    max(valid_crop_idx[:, 1]) - min(valid_crop_idx[:, 1]) + 1,\n                    max(valid_crop_idx[:, 2]) - min(valid_crop_idx[:, 2]) + 1,\n                )\n\n                valid_mask = torch.where(gt_raydrop == -1, 0, 1)\n                gt_raydrop = gt_raydrop * valid_mask\n\n            gt_intensity = images_lidar[:, :, :, 1] * gt_raydrop\n            gt_depth = images_lidar[:, :, :, 2] * gt_raydrop\n            B_lidar, H_lidar, W_lidar, C_lidar = images_lidar.shape\n\n            outputs_lidar = self.model.render(\n                rays_o_lidar,\n                rays_d_lidar,\n                cal_lidar_color=True,\n                staged=True,\n                perturb=False,\n                **vars(self.opt),\n            )\n\n            pred_rgb_lidar = outputs_lidar[\"image_lidar\"].reshape(\n                B_lidar, H_lidar, W_lidar, 2\n            )\n            pred_raydrop = pred_rgb_lidar[:, :, :, 0]\n            raydrop_mask = torch.where(pred_raydrop > 0.5, 1, 0)\n            if self.opt.dataloader == \"nerf_mvl\":\n                raydrop_mask = raydrop_mask * valid_mask\n            pred_intensity = pred_rgb_lidar[:, :, :, 1]\n            pred_depth = outputs_lidar[\"depth_lidar\"].reshape(B_lidar, H_lidar, W_lidar)\n            # raydrop_mask = gt_raydrop  # TODO\n            if self.opt.alpha_r > 0 and (not torch.all(raydrop_mask == 0)):\n                pred_intensity = pred_intensity * raydrop_mask\n                pred_depth = pred_depth * raydrop_mask\n\n            lidar_loss = (\n                self.opt.alpha_d * self.criterion[\"depth\"](pred_depth, gt_depth).mean()\n                + self.opt.alpha_r\n                * self.criterion[\"raydrop\"](pred_raydrop, gt_raydrop).mean()\n                + self.opt.alpha_i\n                * self.criterion[\"intensity\"](pred_intensity, gt_intensity).mean()\n            )\n\n            if self.opt.dataloader == \"nerf_mvl\":\n                pred_intensity = pred_intensity[valid_crop].reshape(\n                    B_lidar, crop_h, crop_w\n                )\n                gt_intensity = gt_intensity[valid_crop].reshape(B_lidar, crop_h, crop_w)\n                pred_depth_crop = pred_depth[valid_crop].reshape(\n                    B_lidar, crop_h, crop_w\n                )\n                gt_depth_crop = gt_depth[valid_crop].reshape(B_lidar, crop_h, crop_w)\n\n            pred_intensity = pred_intensity.unsqueeze(-1)\n            pred_raydrop = pred_raydrop.unsqueeze(-1)\n            gt_intensity = gt_intensity.unsqueeze(-1)\n            gt_raydrop = gt_raydrop.unsqueeze(-1)\n        else:\n            lidar_loss = 0\n\n        loss = lidar_loss\n\n        return (\n            pred_intensity,\n            pred_depth,\n            pred_depth_crop,\n            pred_raydrop,\n            gt_intensity,\n            gt_depth,\n            gt_depth_crop,\n            gt_raydrop,\n            loss,\n        )\n\n    # moved out bg_color and perturb for more flexible control...\n    def test_step(self, data, bg_color=None, perturb=False):\n        pred_raydrop = None\n        pred_intensity = None\n        pred_depth = None\n\n        if self.opt.enable_lidar:\n            rays_o_lidar = data[\"rays_o_lidar\"]  # [B, N, 3]\n            rays_d_lidar = data[\"rays_d_lidar\"]  # [B, N, 3]\n            H_lidar, W_lidar = data[\"H_lidar\"], data[\"W_lidar\"]\n            outputs_lidar = self.model.render(\n                rays_o_lidar,\n                rays_d_lidar,\n                cal_lidar_color=True,\n                staged=True,\n                perturb=perturb,\n                **vars(self.opt),\n            )\n\n            pred_rgb_lidar = outputs_lidar[\"image_lidar\"].reshape(\n                -1, H_lidar, W_lidar, 2\n            )\n            pred_raydrop = pred_rgb_lidar[:, :, :, 0]\n            raydrop_mask = torch.where(pred_raydrop > 0.5, 1, 0)\n            pred_intensity = pred_rgb_lidar[:, :, :, 1]\n            pred_depth = outputs_lidar[\"depth_lidar\"].reshape(-1, H_lidar, W_lidar)\n            if self.opt.alpha_r > 0:\n                pred_intensity = pred_intensity * raydrop_mask\n                pred_depth = pred_depth * raydrop_mask\n\n        return pred_raydrop, pred_intensity, pred_depth\n\n    def save_mesh(self, save_path=None, resolution=256, threshold=10):\n        if save_path is None:\n            save_path = os.path.join(\n                self.workspace, \"meshes\", f\"{self.name}_{self.epoch}.ply\"\n            )\n\n        self.log(f\"==> Saving mesh to {save_path}\")\n\n        os.makedirs(os.path.dirname(save_path), exist_ok=True)\n\n        def query_func(pts):\n            with torch.no_grad():\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    sigma = self.model.density(pts.to(self.device))[\"sigma\"]\n            return sigma\n\n        vertices, triangles = extract_geometry(\n            self.model.aabb_infer[:3],\n            self.model.aabb_infer[3:],\n            resolution=resolution,\n            threshold=threshold,\n            query_func=query_func,\n        )\n\n        mesh = trimesh.Trimesh(\n            vertices, triangles, process=False\n        )  # important, process=True leads to seg fault...\n        mesh.export(save_path)\n\n        self.log(f\"==> Finished saving mesh.\")\n\n    ### ------------------------------\n\n    def train(self, train_loader, valid_loader, max_epochs):\n        if self.use_tensorboardX and self.local_rank == 0:\n            if is_ali_cluster() and self.opt.cluster_summary_path is not None:\n                summary_path = self.opt.cluster_summary_path\n            else:\n                summary_path = os.path.join(self.workspace, \"run\", self.name)\n            self.writer = tensorboardX.SummaryWriter(summary_path)\n\n        change_dataloder = False\n        if self.opt.change_patch_size_lidar[0] > 1:\n            change_dataloder = True\n        for epoch in range(self.epoch + 1, max_epochs + 1):\n            self.epoch = epoch\n            if change_dataloder:\n                if self.epoch % self.opt.change_patch_size_epoch == 0:\n                    train_loader._data.patch_size_lidar = (\n                        self.opt.change_patch_size_lidar\n                    )\n                    self.opt.patch_size_lidar = self.opt.change_patch_size_lidar\n                else:\n                    train_loader._data.patch_size_lidar = 1\n                    self.opt.patch_size_lidar = 1\n\n            self.train_one_epoch(train_loader)\n\n            if self.workspace is not None and self.local_rank == 0:\n                self.save_checkpoint(full=True, best=False)\n\n            if self.epoch % self.eval_interval == 0:\n                self.evaluate_one_epoch(valid_loader)\n                self.save_checkpoint(full=False, best=True)\n\n        if self.use_tensorboardX and self.local_rank == 0:\n            self.writer.close()\n\n    def evaluate(self, loader, name=None):\n        self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX\n        self.evaluate_one_epoch(loader, name)\n        self.use_tensorboardX = use_tensorboardX\n\n    def test(self, loader, save_path=None, name=None, write_video=True):\n        if save_path is None:\n            save_path = os.path.join(self.workspace, \"results\")\n\n        if name is None:\n            name = f\"{self.name}_ep{self.epoch:04d}\"\n\n        os.makedirs(save_path, exist_ok=True)\n\n        self.log(f\"==> Start Test, save results to {save_path}\")\n\n        pbar = tqdm.tqdm(\n            total=len(loader) * loader.batch_size,\n            bar_format=\"{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\",\n        )\n        self.model.eval()\n\n        if write_video:\n            all_preds = []\n            all_preds_depth = []\n\n        with torch.no_grad():\n            for i, data in enumerate(loader):\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    preds_raydrop, preds_intensity, preds_depth = self.test_step(data)\n\n                if self.opt.enable_lidar:\n                    pred_raydrop = preds_raydrop[0].detach().cpu().numpy()\n                    pred_raydrop = (np.where(pred_raydrop > 0.5, 1.0, 0.0)).reshape(\n                        loader._data.H_lidar, loader._data.W_lidar\n                    )\n                    pred_raydrop = (pred_raydrop * 255).astype(np.uint8)\n\n                    pred_intensity = preds_intensity[0].detach().cpu().numpy()\n                    pred_intensity = (pred_intensity * 255).astype(np.uint8)\n\n                    pred_depth = preds_depth[0].detach().cpu().numpy()\n                    pred_lidar = pano_to_lidar(\n                        pred_depth / self.opt.scale, loader._data.intrinsics_lidar\n                    )\n                    if self.opt.dataloader == \"nerf_mvl\":\n                        pred_lidar = filter_bbox_dataset(\n                            pred_lidar, data[\"OBB_local\"][:, :3]\n                        )\n\n                    np.save(\n                        os.path.join(save_path, f\"test_{name}_{i:04d}_depth_lidar.npy\"),\n                        pred_lidar,\n                    )\n\n                    pred_depth = (pred_depth * 255).astype(np.uint8)\n                    # pred_depth = (pred_depth / self.opt.scale).astype(np.uint8)\n\n                    if write_video:\n                        all_preds.append(cv2.applyColorMap(pred_intensity, 1))\n                        all_preds_depth.append(cv2.applyColorMap(pred_depth, 9))\n                    else:\n                        cv2.imwrite(\n                            os.path.join(save_path, f\"test_{name}_{i:04d}_raydrop.png\"),\n                            pred_raydrop,\n                        )\n                        cv2.imwrite(\n                            os.path.join(\n                                save_path, f\"test_{name}_{i:04d}_intensity.png\"\n                            ),\n                            cv2.applyColorMap(pred_intensity, 1),\n                        )\n                        cv2.imwrite(\n                            os.path.join(save_path, f\"test_{name}_{i:04d}_depth.png\"),\n                            cv2.applyColorMap(pred_depth, 9),\n                        )\n\n                pbar.update(loader.batch_size)\n\n        if write_video:\n            if self.opt.enable_lidar:\n                all_preds = np.stack(all_preds, axis=0)\n                all_preds_depth = np.stack(all_preds_depth, axis=0)\n                imageio.mimwrite(\n                    os.path.join(save_path, f\"{name}_lidar_rgb.mp4\"),\n                    all_preds,\n                    fps=25,\n                    quality=8,\n                    macro_block_size=1,\n                )\n                imageio.mimwrite(\n                    os.path.join(save_path, f\"{name}_depth.mp4\"),\n                    all_preds_depth,\n                    fps=25,\n                    quality=8,\n                    macro_block_size=1,\n                )\n\n        self.log(f\"==> Finished Test.\")\n\n    def train_one_epoch(self, loader):\n        self.log(\n            f\"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...\"\n        )\n\n        total_loss = 0\n        if self.local_rank == 0 and self.report_metric_at_train:\n            for metric in self.metrics:\n                metric.clear()\n            for metric in self.depth_metrics:\n                metric.clear()\n\n        self.model.train()\n\n        # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs\n        # ref: https://pytorch.org/docs/stable/data.html\n        if self.world_size > 1:\n            loader.sampler.set_epoch(self.epoch)\n\n        if self.local_rank == 0:\n            pbar = tqdm.tqdm(\n                total=len(loader) * loader.batch_size,\n                bar_format=\"{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\",\n            )\n\n        self.local_step = 0\n\n        for data in loader:\n            self.local_step += 1\n            self.global_step += 1\n\n            self.optimizer.zero_grad()\n\n            with torch.cuda.amp.autocast(enabled=self.fp16):\n                (\n                    pred_intensity,\n                    gt_intensity,\n                    pred_depth,\n                    gt_depth,\n                    loss,\n                ) = self.train_step(data)\n\n            self.scaler.scale(loss).backward()\n            self.scaler.step(self.optimizer)\n            self.scaler.update()\n\n            if self.scheduler_update_every_step:\n                self.lr_scheduler.step()\n\n            loss_val = loss.item()\n            total_loss += loss_val\n\n            if self.local_rank == 0:\n                if self.report_metric_at_train:\n                    for i, metric in enumerate(self.depth_metrics):\n                        if i < 2:  # hard code\n                            metric.update(pred_intensity, gt_intensity)\n                        else:\n                            metric.update(pred_depth, gt_depth)\n\n                if self.use_tensorboardX:\n                    self.writer.add_scalar(\"train/loss\", loss_val, self.global_step)\n                    self.writer.add_scalar(\n                        \"train/lr\",\n                        self.optimizer.param_groups[0][\"lr\"],\n                        self.global_step,\n                    )\n\n                if self.scheduler_update_every_step:\n                    pbar.set_description(\n                        f\"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}\"\n                    )\n                else:\n                    pbar.set_description(\n                        f\"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})\"\n                    )\n                pbar.update(loader.batch_size)\n\n        if self.ema is not None:\n            self.ema.update()\n\n        average_loss = total_loss / self.local_step\n        self.stats[\"loss\"].append(average_loss)\n\n        if self.local_rank == 0:\n            pbar.close()\n            if self.report_metric_at_train:\n                for metric in self.depth_metrics:\n                    self.log(metric.report(), style=\"red\")\n                    if self.use_tensorboardX:\n                        metric.write(self.writer, self.epoch, prefix=\"LiDAR_train\")\n                    metric.clear()\n\n        if not self.scheduler_update_every_step:\n            if isinstance(\n                self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau\n            ):\n                self.lr_scheduler.step(average_loss)\n            else:\n                self.lr_scheduler.step()\n\n        self.log(f\"==> Finished Epoch {self.epoch}.\")\n\n    def evaluate_one_epoch(self, loader, name=None):\n        self.log(f\"++> Evaluate at epoch {self.epoch} ...\")\n\n        if name is None:\n            name = f\"{self.name}_ep{self.epoch:04d}\"\n\n        total_loss = 0\n        if self.local_rank == 0:\n            for metric in self.metrics:\n                metric.clear()\n            for metric in self.depth_metrics:\n                metric.clear()\n\n        self.model.eval()\n\n        if self.ema is not None:\n            self.ema.store()\n            self.ema.copy_to()\n\n        if self.local_rank == 0:\n            pbar = tqdm.tqdm(\n                total=len(loader) * loader.batch_size,\n                bar_format=\"{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\",\n            )\n\n        with torch.no_grad():\n            self.local_step = 0\n\n            for data in loader:\n                self.local_step += 1\n\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    (\n                        preds_intensity,\n                        preds_depth,\n                        preds_depth_crop,\n                        preds_raydrop,\n                        gt_intensity,\n                        gt_depth,\n                        gt_depth_crop,\n                        gt_raydrop,\n                        loss,\n                    ) = self.eval_step(data)\n\n                # all_gather/reduce the statistics (NCCL only support all_*)\n                if self.world_size > 1:\n                    dist.all_reduce(loss, op=dist.ReduceOp.SUM)\n                    loss = loss / self.world_size\n\n                    preds_list = [\n                        torch.zeros_like(preds).to(self.device)\n                        for _ in range(self.world_size)\n                    ]  # [[B, ...], [B, ...], ...]\n                    dist.all_gather(preds_list, preds)\n                    preds = torch.cat(preds_list, dim=0)\n\n                    preds_depth_list = [\n                        torch.zeros_like(preds_depth).to(self.device)\n                        for _ in range(self.world_size)\n                    ]  # [[B, ...], [B, ...], ...]\n                    dist.all_gather(preds_depth_list, preds_depth)\n                    preds_depth = torch.cat(preds_depth_list, dim=0)\n\n                    truths_list = [\n                        torch.zeros_like(truths).to(self.device)\n                        for _ in range(self.world_size)\n                    ]  # [[B, ...], [B, ...], ...]\n                    dist.all_gather(truths_list, truths)\n                    truths = torch.cat(truths_list, dim=0)\n\n                loss_val = loss.item()\n                total_loss += loss_val\n\n                # only rank = 0 will perform evaluation.\n                if self.local_rank == 0:\n                    for i, metric in enumerate(self.depth_metrics):\n                        if i < 2:  # hard code\n                            metric.update(preds_intensity, gt_intensity)\n                        else:\n                            if (\n                                self.opt.dataloader == \"nerf_mvl\" and i == 2\n                            ):  # hard code\n                                metric.update(preds_depth_crop, gt_depth_crop)\n                            else:\n                                metric.update(preds_depth, gt_depth)\n\n                    if self.opt.enable_lidar:\n                        save_path_raydrop = os.path.join(\n                            self.workspace,\n                            \"validation\",\n                            f\"{name}_{self.local_step:04d}_rarydrop.png\",\n                        )\n                        save_path_intensity = os.path.join(\n                            self.workspace,\n                            \"validation\",\n                            f\"{name}_{self.local_step:04d}_intensity.png\",\n                        )\n                        save_path_depth = os.path.join(\n                            self.workspace,\n                            \"validation\",\n                            f\"{name}_{self.local_step:04d}_depth.png\",\n                        )\n                        os.makedirs(os.path.dirname(save_path_depth), exist_ok=True)\n\n                        pred_intensity = preds_intensity[0].detach().cpu().numpy()\n                        pred_intensity = (pred_intensity * 255).astype(np.uint8)\n\n                        pred_raydrop = preds_raydrop[0].detach().cpu().numpy()\n                        pred_raydrop = (np.where(pred_raydrop > 0.5, 1.0, 0.0)).reshape(\n                            loader._data.H_lidar, loader._data.W_lidar\n                        )\n                        pred_raydrop = (pred_raydrop * 255).astype(np.uint8)\n\n                        pred_depth = preds_depth[0].detach().cpu().numpy()\n                        pred_lidar = pano_to_lidar(\n                            pred_depth / self.opt.scale, loader._data.intrinsics_lidar\n                        )\n                        pred_depth = (pred_depth * 255).astype(np.uint8)\n                        # pred_depth = (pred_depth / self.opt.scale).astype(np.uint8)\n\n                        # cv2.imwrite(save_path, cv2.cvtColor(pred, cv2.COLOR_RGB2BGR))\n                        cv2.imwrite(save_path_raydrop, pred_raydrop)\n                        cv2.imwrite(\n                            save_path_intensity, cv2.applyColorMap(pred_intensity, 1)\n                        )\n                        cv2.imwrite(save_path_depth, cv2.applyColorMap(pred_depth, 9))\n                        np.save(\n                            os.path.join(\n                                self.workspace,\n                                \"validation\",\n                                f\"{name}_{self.local_step:04d}_lidar.npy\",\n                            ),\n                            pred_lidar,\n                        )\n\n                    pbar.set_description(\n                        f\"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})\"\n                    )\n                    pbar.update(loader.batch_size)\n\n        average_loss = total_loss / self.local_step\n        self.stats[\"valid_loss\"].append(average_loss)\n\n        if self.local_rank == 0:\n            pbar.close()\n            if len(self.depth_metrics) > 0:\n                # result = self.metrics[0].measure()\n                result = self.depth_metrics[-1].measure()[0]  # hard code\n                self.stats[\"results\"].append(\n                    result if self.best_mode == \"min\" else -result\n                )  # if max mode, use -result\n            else:\n                self.stats[\"results\"].append(\n                    average_loss\n                )  # if no metric, choose best by min loss\n\n            for metric in self.depth_metrics:\n                self.log(metric.report(), style=\"blue\")\n                if self.use_tensorboardX:\n                    metric.write(self.writer, self.epoch, prefix=\"LiDAR_evaluate\")\n                metric.clear()\n\n        if self.ema is not None:\n            self.ema.restore()\n\n        self.log(f\"++> Evaluate epoch {self.epoch} Finished.\")\n\n    def save_checkpoint(self, name=None, full=False, best=False, remove_old=True):\n        if name is None:\n            name = f\"{self.name}_ep{self.epoch:04d}\"\n\n        state = {\n            \"epoch\": self.epoch,\n            \"global_step\": self.global_step,\n            \"stats\": self.stats,\n        }\n\n        if full:\n            state[\"optimizer\"] = self.optimizer.state_dict()\n            state[\"lr_scheduler\"] = self.lr_scheduler.state_dict()\n            state[\"scaler\"] = self.scaler.state_dict()\n            if self.ema is not None:\n                state[\"ema\"] = self.ema.state_dict()\n\n        if not best:\n            state[\"model\"] = self.model.state_dict()\n\n            file_path = f\"{self.ckpt_path}/{name}.pth\"\n\n            if remove_old:\n                self.stats[\"checkpoints\"].append(file_path)\n\n                if len(self.stats[\"checkpoints\"]) > self.max_keep_ckpt:\n                    old_ckpt = self.stats[\"checkpoints\"].pop(0)\n                    if os.path.exists(old_ckpt):\n                        os.remove(old_ckpt)\n\n            torch.save(state, file_path)\n\n        else:\n            if len(self.stats[\"results\"]) > 0:\n                if (\n                    self.stats[\"best_result\"] is None\n                    or self.stats[\"results\"][-1] < self.stats[\"best_result\"]\n                ):\n                    self.log(\n                        f\"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}\"\n                    )\n                    self.stats[\"best_result\"] = self.stats[\"results\"][-1]\n\n                    # save ema results\n                    if self.ema is not None:\n                        self.ema.store()\n                        self.ema.copy_to()\n\n                    state[\"model\"] = self.model.state_dict()\n\n                    # we don't consider continued training from the best ckpt, so we discard the unneeded density_grid to save some storage (especially important for dnerf)\n                    if \"density_grid\" in state[\"model\"]:\n                        del state[\"model\"][\"density_grid\"]\n\n                    if self.ema is not None:\n                        self.ema.restore()\n\n                    torch.save(state, self.best_path)\n            else:\n                self.log(\n                    f\"[WARN] no evaluated results found, skip saving best checkpoint.\"\n                )\n\n    def load_checkpoint(self, checkpoint=None, model_only=False):\n        if checkpoint is None:\n            checkpoint_list = sorted(glob.glob(f\"{self.ckpt_path}/{self.name}_ep*.pth\"))\n            if checkpoint_list:\n                checkpoint = checkpoint_list[-1]\n                self.log(f\"[INFO] Latest checkpoint is {checkpoint}\")\n            else:\n                self.log(\"[WARN] No checkpoint found, model randomly initialized.\")\n                return\n\n        checkpoint_dict = torch.load(checkpoint, map_location=self.device)\n\n        if \"model\" not in checkpoint_dict:\n            self.model.load_state_dict(checkpoint_dict)\n            self.log(\"[INFO] loaded model.\")\n            return\n\n        missing_keys, unexpected_keys = self.model.load_state_dict(\n            checkpoint_dict[\"model\"], strict=False\n        )\n        self.log(\"[INFO] loaded model.\")\n        if len(missing_keys) > 0:\n            self.log(f\"[WARN] missing keys: {missing_keys}\")\n        if len(unexpected_keys) > 0:\n            self.log(f\"[WARN] unexpected keys: {unexpected_keys}\")\n\n        if self.ema is not None and \"ema\" in checkpoint_dict:\n            self.ema.load_state_dict(checkpoint_dict[\"ema\"])\n\n        if model_only:\n            return\n\n        self.stats = checkpoint_dict[\"stats\"]\n        self.epoch = checkpoint_dict[\"epoch\"]\n        self.global_step = checkpoint_dict[\"global_step\"]\n        self.log(f\"[INFO] load at epoch {self.epoch}, global step {self.global_step}\")\n\n        if self.optimizer and \"optimizer\" in checkpoint_dict:\n            try:\n                self.optimizer.load_state_dict(checkpoint_dict[\"optimizer\"])\n                self.log(\"[INFO] loaded optimizer.\")\n            except:\n                self.log(\"[WARN] Failed to load optimizer.\")\n\n        if self.lr_scheduler and \"lr_scheduler\" in checkpoint_dict:\n            try:\n                self.lr_scheduler.load_state_dict(checkpoint_dict[\"lr_scheduler\"])\n                self.log(\"[INFO] loaded scheduler.\")\n            except:\n                self.log(\"[WARN] Failed to load scheduler.\")\n\n        if self.scaler and \"scaler\" in checkpoint_dict:\n            try:\n                self.scaler.load_state_dict(checkpoint_dict[\"scaler\"])\n                self.log(\"[INFO] loaded scaler.\")\n            except:\n                self.log(\"[WARN] Failed to load scaler.\")\n"
  },
  {
    "path": "lidarnerf/raymarching/__init__.py",
    "content": ""
  },
  {
    "path": "lidarnerf/raymarching/backend.py",
    "content": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n]\n\nif os.name == \"posix\":\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\n_backend = load(\n    name=\"_raymarching\",\n    extra_cflags=c_flags,\n    extra_cuda_cflags=nvcc_flags,\n    sources=[\n        os.path.join(_src_path, \"src\", f)\n        for f in [\n            \"raymarching.cu\",\n            \"bindings.cpp\",\n        ]\n    ],\n)\n\n__all__ = [\"_backend\"]\n"
  },
  {
    "path": "lidarnerf/raymarching/raymarching.py",
    "content": "import torch\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\ntry:\n    import _raymarching as _backend\nexcept ImportError:\n    from .backend import _backend\n\n# ----------------------------------------\n# utils\n# ----------------------------------------\n\n\nclass _near_far_from_aabb(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):\n        \"\"\"near_far_from_aabb, CUDA implementation\n        Calculate rays' intersection time (near and far) with aabb\n        Args:\n            rays_o: float, [N, 3]\n            rays_d: float, [N, 3]\n            aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)\n            min_near: float, scalar\n        Returns:\n            nears: float, [N]\n            fars: float, [N]\n        \"\"\"\n        if not rays_o.is_cuda:\n            rays_o = rays_o.cuda()\n        if not rays_d.is_cuda:\n            rays_d = rays_d.cuda()\n\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        N = rays_o.shape[0]  # num rays\n\n        nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)\n        fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)\n\n        _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)\n\n        return nears, fars\n\n\nnear_far_from_aabb = _near_far_from_aabb.apply\n\n\nclass _sph_from_ray(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, rays_o, rays_d, radius):\n        \"\"\"sph_from_ray, CUDA implementation\n        get spherical coordinate on the background sphere from rays.\n        Assume rays_o are inside the Sphere(radius).\n        Args:\n            rays_o: [N, 3]\n            rays_d: [N, 3]\n            radius: scalar, float\n        Return:\n            coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface)\n        \"\"\"\n        if not rays_o.is_cuda:\n            rays_o = rays_o.cuda()\n        if not rays_d.is_cuda:\n            rays_d = rays_d.cuda()\n\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        N = rays_o.shape[0]  # num rays\n\n        coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)\n\n        _backend.sph_from_ray(rays_o, rays_d, radius, N, coords)\n\n        return coords\n\n\nsph_from_ray = _sph_from_ray.apply\n\n\nclass _morton3D(Function):\n    @staticmethod\n    def forward(ctx, coords):\n        \"\"\"morton3D, CUDA implementation\n        Args:\n            coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)\n            TODO: check if the coord range is valid! (current 128 is safe)\n        Returns:\n            indices: [N], int32, in [0, 128^3)\n\n        \"\"\"\n        if not coords.is_cuda:\n            coords = coords.cuda()\n\n        N = coords.shape[0]\n\n        indices = torch.empty(N, dtype=torch.int32, device=coords.device)\n\n        _backend.morton3D(coords.int(), N, indices)\n\n        return indices\n\n\nmorton3D = _morton3D.apply\n\n\nclass _morton3D_invert(Function):\n    @staticmethod\n    def forward(ctx, indices):\n        \"\"\"morton3D_invert, CUDA implementation\n        Args:\n            indices: [N], int32, in [0, 128^3)\n        Returns:\n            coords: [N, 3], int32, in [0, 128)\n\n        \"\"\"\n        if not indices.is_cuda:\n            indices = indices.cuda()\n\n        N = indices.shape[0]\n\n        coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)\n\n        _backend.morton3D_invert(indices.int(), N, coords)\n\n        return coords\n\n\nmorton3D_invert = _morton3D_invert.apply\n\n\nclass _packbits(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, grid, thresh, bitfield=None):\n        \"\"\"packbits, CUDA implementation\n        Pack up the density grid into a bit field to accelerate ray marching.\n        Args:\n            grid: float, [C, H * H * H], assume H % 2 == 0\n            thresh: float, threshold\n        Returns:\n            bitfield: uint8, [C, H * H * H / 8]\n        \"\"\"\n        if not grid.is_cuda:\n            grid = grid.cuda()\n        grid = grid.contiguous()\n\n        C = grid.shape[0]\n        H3 = grid.shape[1]\n        N = C * H3 // 8\n\n        if bitfield is None:\n            bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)\n\n        _backend.packbits(grid, N, thresh, bitfield)\n\n        return bitfield\n\n\npackbits = _packbits.apply\n\n# ----------------------------------------\n# train functions\n# ----------------------------------------\n\n\nclass _march_rays_train(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(\n        ctx,\n        rays_o,\n        rays_d,\n        bound,\n        density_bitfield,\n        C,\n        H,\n        nears,\n        fars,\n        step_counter=None,\n        mean_count=-1,\n        perturb=False,\n        align=-1,\n        force_all_rays=False,\n        dt_gamma=0,\n        max_steps=1024,\n    ):\n        \"\"\"march rays to generate points (forward only)\n        Args:\n            rays_o/d: float, [N, 3]\n            bound: float, scalar\n            density_bitfield: uint8: [CHHH // 8]\n            C: int\n            H: int\n            nears/fars: float, [N]\n            step_counter: int32, (2), used to count the actual number of generated points.\n            mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.)\n            perturb: bool\n            align: int, pad output so its size is dividable by align, set to -1 to disable.\n            force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.\n            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)\n            max_steps: int, max number of sampled points along each ray, also affect min_stepsize.\n        Returns:\n            xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray)\n            dirs: float, [M, 3], all generated points' view dirs.\n            deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth)\n            rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0]\n        \"\"\"\n\n        if not rays_o.is_cuda:\n            rays_o = rays_o.cuda()\n        if not rays_d.is_cuda:\n            rays_d = rays_d.cuda()\n        if not density_bitfield.is_cuda:\n            density_bitfield = density_bitfield.cuda()\n\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n        density_bitfield = density_bitfield.contiguous()\n\n        N = rays_o.shape[0]  # num rays\n        M = N * max_steps  # init max points number in total\n\n        # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)\n        # It estimate the max points number to enable faster training, but will lead to random ignored rays if underestimated.\n        if not force_all_rays and mean_count > 0:\n            if align > 0:\n                mean_count += align - mean_count % align\n            M = mean_count\n\n        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)\n        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)\n        deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)\n        rays = torch.empty(\n            N, 3, dtype=torch.int32, device=rays_o.device\n        )  # id, offset, num_steps\n\n        if step_counter is None:\n            step_counter = torch.zeros(\n                2, dtype=torch.int32, device=rays_o.device\n            )  # point counter, ray counter\n\n        if perturb:\n            noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)\n        else:\n            noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)\n\n        _backend.march_rays_train(\n            rays_o,\n            rays_d,\n            density_bitfield,\n            bound,\n            dt_gamma,\n            max_steps,\n            N,\n            C,\n            H,\n            M,\n            nears,\n            fars,\n            xyzs,\n            dirs,\n            deltas,\n            rays,\n            step_counter,\n            noises,\n        )  # m is the actually used points number\n\n        # print(step_counter, M)\n\n        # only used at the first (few) epochs.\n        if force_all_rays or mean_count <= 0:\n            m = step_counter[0].item()  # D2H copy\n            if align > 0:\n                m += align - m % align\n            xyzs = xyzs[:m]\n            dirs = dirs[:m]\n            deltas = deltas[:m]\n\n            torch.cuda.empty_cache()\n\n        return xyzs, dirs, deltas, rays\n\n\nmarch_rays_train = _march_rays_train.apply\n\n\nclass _composite_rays_train(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4):\n        \"\"\"composite rays' rgbs, according to the ray marching formula.\n        Args:\n            rgbs: float, [M, 3]\n            sigmas: float, [M,]\n            deltas: float, [M, 2]\n            rays: int32, [N, 3]\n        Returns:\n            weights_sum: float, [N,], the alpha channel\n            depth: float, [N, ], the Depth\n            image: float, [N, 3], the RGB channel (after multiplying alpha!)\n        \"\"\"\n\n        sigmas = sigmas.contiguous()\n        rgbs = rgbs.contiguous()\n\n        M = sigmas.shape[0]\n        N = rays.shape[0]\n\n        weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)\n        depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)\n        image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)\n\n        _backend.composite_rays_train_forward(\n            sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image\n        )\n\n        ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image)\n        ctx.dims = [M, N, T_thresh]\n\n        return weights_sum, depth, image\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_weights_sum, grad_depth, grad_image):\n        # NOTE: grad_depth is not used now! It won't be propagated to sigmas.\n\n        grad_weights_sum = grad_weights_sum.contiguous()\n        grad_image = grad_image.contiguous()\n\n        sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors\n        M, N, T_thresh = ctx.dims\n\n        grad_sigmas = torch.zeros_like(sigmas)\n        grad_rgbs = torch.zeros_like(rgbs)\n\n        _backend.composite_rays_train_backward(\n            grad_weights_sum,\n            grad_image,\n            sigmas,\n            rgbs,\n            deltas,\n            rays,\n            weights_sum,\n            image,\n            M,\n            N,\n            T_thresh,\n            grad_sigmas,\n            grad_rgbs,\n        )\n\n        return grad_sigmas, grad_rgbs, None, None, None\n\n\ncomposite_rays_train = _composite_rays_train.apply\n\n# ----------------------------------------\n# infer functions\n# ----------------------------------------\n\n\nclass _march_rays(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(\n        ctx,\n        n_alive,\n        n_step,\n        rays_alive,\n        rays_t,\n        rays_o,\n        rays_d,\n        bound,\n        density_bitfield,\n        C,\n        H,\n        near,\n        far,\n        align=-1,\n        perturb=False,\n        dt_gamma=0,\n        max_steps=1024,\n    ):\n        \"\"\"march rays to generate points (forward only, for inference)\n        Args:\n            n_alive: int, number of alive rays\n            n_step: int, how many steps we march\n            rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)\n            rays_t: float, [N], the alive rays' time, we only use the first n_alive.\n            rays_o/d: float, [N, 3]\n            bound: float, scalar\n            density_bitfield: uint8: [CHHH // 8]\n            C: int\n            H: int\n            nears/fars: float, [N]\n            align: int, pad output so its size is dividable by align, set to -1 to disable.\n            perturb: bool/int, int > 0 is used as the random seed.\n            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)\n            max_steps: int, max number of sampled points along each ray, also affect min_stepsize.\n        Returns:\n            xyzs: float, [n_alive * n_step, 3], all generated points' coords\n            dirs: float, [n_alive * n_step, 3], all generated points' view dirs.\n            deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).\n        \"\"\"\n\n        if not rays_o.is_cuda:\n            rays_o = rays_o.cuda()\n        if not rays_d.is_cuda:\n            rays_d = rays_d.cuda()\n\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        M = n_alive * n_step\n\n        if align > 0:\n            M += align - (M % align)\n\n        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)\n        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)\n        deltas = torch.zeros(\n            M, 2, dtype=rays_o.dtype, device=rays_o.device\n        )  # 2 vals, one for rgb, one for depth\n\n        if perturb:\n            # torch.manual_seed(perturb) # test_gui uses spp index as seed\n            noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)\n        else:\n            noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)\n\n        _backend.march_rays(\n            n_alive,\n            n_step,\n            rays_alive,\n            rays_t,\n            rays_o,\n            rays_d,\n            bound,\n            dt_gamma,\n            max_steps,\n            C,\n            H,\n            density_bitfield,\n            near,\n            far,\n            xyzs,\n            dirs,\n            deltas,\n            noises,\n        )\n\n        return xyzs, dirs, deltas\n\n\nmarch_rays = _march_rays.apply\n\n\nclass _composite_rays(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)  # need to cast sigmas & rgbs to float\n    def forward(\n        ctx,\n        n_alive,\n        n_step,\n        rays_alive,\n        rays_t,\n        sigmas,\n        rgbs,\n        deltas,\n        weights_sum,\n        depth,\n        image,\n        T_thresh=1e-2,\n    ):\n        \"\"\"composite rays' rgbs, according to the ray marching formula. (for inference)\n        Args:\n            n_alive: int, number of alive rays\n            n_step: int, how many steps we march\n            rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)\n            rays_t: float, [N], the alive rays' time\n            sigmas: float, [n_alive * n_step,]\n            rgbs: float, [n_alive * n_step, 3]\n            deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).\n        In-place Outputs:\n            weights_sum: float, [N,], the alpha channel\n            depth: float, [N,], the depth value\n            image: float, [N, 3], the RGB channel (after multiplying alpha!)\n        \"\"\"\n        _backend.composite_rays(\n            n_alive,\n            n_step,\n            T_thresh,\n            rays_alive,\n            rays_t,\n            sigmas,\n            rgbs,\n            deltas,\n            weights_sum,\n            depth,\n            image,\n        )\n        return tuple()\n\n\ncomposite_rays = _composite_rays.apply\n"
  },
  {
    "path": "lidarnerf/raymarching/setup.py",
    "content": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n]\n\nif os.name == \"posix\":\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\"\"\"\nUsage:\n\npython setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)\n\npython setup.py install # build extensions and install (copy) to PATH.\npip install . # ditto but better (e.g., dependency & metadata handling)\n\npython setup.py develop # build extensions and install (symbolic) to PATH.\npip install -e . # ditto but better (e.g., dependency & metadata handling)\n\n\"\"\"\nsetup(\n    name=\"raymarching\",  # package name, import this to use python API\n    ext_modules=[\n        CUDAExtension(\n            name=\"_raymarching\",  # extension name, import this to use CUDA API\n            sources=[\n                os.path.join(_src_path, \"src\", f)\n                for f in [\n                    \"raymarching.cu\",\n                    \"bindings.cpp\",\n                ]\n            ],\n            extra_compile_args={\n                \"cxx\": c_flags,\n                \"nvcc\": nvcc_flags,\n            },\n        ),\n    ],\n    cmdclass={\n        \"build_ext\": BuildExtension,\n    },\n)\n"
  },
  {
    "path": "lidarnerf/raymarching/src/bindings.cpp",
    "content": "#include <torch/extension.h>\n\n#include \"raymarching.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    // utils\n    m.def(\"packbits\", &packbits, \"packbits (CUDA)\");\n    m.def(\"near_far_from_aabb\", &near_far_from_aabb,\n          \"near_far_from_aabb (CUDA)\");\n    m.def(\"sph_from_ray\", &sph_from_ray, \"sph_from_ray (CUDA)\");\n    m.def(\"morton3D\", &morton3D, \"morton3D (CUDA)\");\n    m.def(\"morton3D_invert\", &morton3D_invert, \"morton3D_invert (CUDA)\");\n    // train\n    m.def(\"march_rays_train\", &march_rays_train, \"march_rays_train (CUDA)\");\n    m.def(\"composite_rays_train_forward\", &composite_rays_train_forward,\n          \"composite_rays_train_forward (CUDA)\");\n    m.def(\"composite_rays_train_backward\", &composite_rays_train_backward,\n          \"composite_rays_train_backward (CUDA)\");\n    // infer\n    m.def(\"march_rays\", &march_rays, \"march rays (CUDA)\");\n    m.def(\"composite_rays\", &composite_rays, \"composite rays (CUDA)\");\n}"
  },
  {
    "path": "lidarnerf/raymarching/src/raymarching.cu",
    "content": "#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <stdint.h>\n#include <torch/torch.h>\n\n#include <cstdio>\n#include <limits>\n#include <stdexcept>\n\n#define CHECK_CUDA(x) \\\n    TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) \\\n    TORCH_CHECK(x.is_contiguous(), #x \" must be a contiguous tensor\")\n#define CHECK_IS_INT(x)                                 \\\n    TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \\\n                #x \" must be an int tensor\")\n#define CHECK_IS_FLOATING(x)                                       \\\n    TORCH_CHECK(x.scalar_type() == at::ScalarType::Float ||        \\\n                        x.scalar_type() == at::ScalarType::Half || \\\n                        x.scalar_type() == at::ScalarType::Double, \\\n                #x \" must be a floating tensor\")\n\ninline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }\ninline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }\ninline constexpr __device__ float PI() { return 3.141592653589793f; }\ninline constexpr __device__ float RPI() { return 0.3183098861837907f; }\n\ntemplate <typename T>\ninline __host__ __device__ T div_round_up(T val, T divisor) {\n    return (val + divisor - 1) / divisor;\n}\n\ninline __host__ __device__ float signf(const float x) {\n    return copysignf(1.0, x);\n}\n\ninline __host__ __device__ float clamp(const float x,\n                                       const float min,\n                                       const float max) {\n    return fminf(max, fmaxf(min, x));\n}\n\ninline __host__ __device__ void swapf(float &a, float &b) {\n    float c = a;\n    a = b;\n    b = c;\n}\n\ninline __device__ int mip_from_pos(const float x,\n                                   const float y,\n                                   const float z,\n                                   const float max_cascade) {\n    const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z)));\n    int exponent;\n    frexpf(mx, &exponent);  // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1,\n                            // [2, 4) --> 2, ...\n    return fminf(max_cascade - 1, fmaxf(0, exponent));\n}\n\ninline __device__ int mip_from_dt(const float dt,\n                                  const float H,\n                                  const float max_cascade) {\n    const float mx = dt * H * 0.5;\n    int exponent;\n    frexpf(mx, &exponent);\n    return fminf(max_cascade - 1, fmaxf(0, exponent));\n}\n\ninline __host__ __device__ uint32_t __expand_bits(uint32_t v) {\n    v = (v * 0x00010001u) & 0xFF0000FFu;\n    v = (v * 0x00000101u) & 0x0F00F00Fu;\n    v = (v * 0x00000011u) & 0xC30C30C3u;\n    v = (v * 0x00000005u) & 0x49249249u;\n    return v;\n}\n\ninline __host__ __device__ uint32_t __morton3D(uint32_t x,\n                                               uint32_t y,\n                                               uint32_t z) {\n    uint32_t xx = __expand_bits(x);\n    uint32_t yy = __expand_bits(y);\n    uint32_t zz = __expand_bits(z);\n    return xx | (yy << 1) | (zz << 2);\n}\n\ninline __host__ __device__ uint32_t __morton3D_invert(uint32_t x) {\n    x = x & 0x49249249;\n    x = (x | (x >> 2)) & 0xc30c30c3;\n    x = (x | (x >> 4)) & 0x0f00f00f;\n    x = (x | (x >> 8)) & 0xff0000ff;\n    x = (x | (x >> 16)) & 0x0000ffff;\n    return x;\n}\n\n////////////////////////////////////////////////////\n/////////////           utils          /////////////\n////////////////////////////////////////////////////\n\n// rays_o/d: [N, 3]\n// nears/fars: [N]\n// scalar_t should always be float in use.\ntemplate <typename scalar_t>\n__global__ void kernel_near_far_from_aabb(const scalar_t *__restrict__ rays_o,\n                                          const scalar_t *__restrict__ rays_d,\n                                          const scalar_t *__restrict__ aabb,\n                                          const uint32_t N,\n                                          const float min_near,\n                                          scalar_t *nears,\n                                          scalar_t *fars) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    rays_o += n * 3;\n    rays_d += n * 3;\n\n    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];\n    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];\n    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;\n\n    // get near far (assume cube scene)\n    float near = (aabb[0] - ox) * rdx;\n    float far = (aabb[3] - ox) * rdx;\n    if (near > far) swapf(near, far);\n\n    float near_y = (aabb[1] - oy) * rdy;\n    float far_y = (aabb[4] - oy) * rdy;\n    if (near_y > far_y) swapf(near_y, far_y);\n\n    if (near > far_y || near_y > far) {\n        nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();\n        return;\n    }\n\n    if (near_y > near) near = near_y;\n    if (far_y < far) far = far_y;\n\n    float near_z = (aabb[2] - oz) * rdz;\n    float far_z = (aabb[5] - oz) * rdz;\n    if (near_z > far_z) swapf(near_z, far_z);\n\n    if (near > far_z || near_z > far) {\n        nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();\n        return;\n    }\n\n    if (near_z > near) near = near_z;\n    if (far_z < far) far = far_z;\n\n    if (near < min_near) near = min_near;\n\n    nears[n] = near;\n    fars[n] = far;\n}\n\nvoid near_far_from_aabb(const at::Tensor rays_o,\n                        const at::Tensor rays_d,\n                        const at::Tensor aabb,\n                        const uint32_t N,\n                        const float min_near,\n                        at::Tensor nears,\n                        at::Tensor fars) {\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n            rays_o.scalar_type(), \"near_far_from_aabb\", ([&] {\n                kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD),\n                                            N_THREAD>>>(\n                        rays_o.data_ptr<scalar_t>(),\n                        rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(),\n                        N, min_near, nears.data_ptr<scalar_t>(),\n                        fars.data_ptr<scalar_t>());\n            }));\n}\n\n// rays_o/d: [N, 3]\n// radius: float\n// coords: [N, 2]\ntemplate <typename scalar_t>\n__global__ void kernel_sph_from_ray(const scalar_t *__restrict__ rays_o,\n                                    const scalar_t *__restrict__ rays_d,\n                                    const float radius,\n                                    const uint32_t N,\n                                    scalar_t *coords) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    rays_o += n * 3;\n    rays_d += n * 3;\n    coords += n * 2;\n\n    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];\n    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];\n    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;\n\n    // solve t from || o + td || = radius\n    const float A = dx * dx + dy * dy + dz * dz;\n    const float B = ox * dx + oy * dy + oz * dz;  // in fact B / 2\n    const float C = ox * ox + oy * oy + oz * oz - radius * radius;\n\n    const float t = (-B + sqrtf(B * B - A * C)) /\n                    A;  // always use the larger solution (positive)\n\n    // solve theta, phi (assume y is the up axis)\n    const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;\n    const float theta = atan2(sqrtf(x * x + z * z), y);  // [0, PI)\n    const float phi = atan2(z, x);                       // [-PI, PI)\n\n    // normalize to [-1, 1]\n    coords[0] = 2 * theta * RPI() - 1;\n    coords[1] = phi * RPI();\n}\n\nvoid sph_from_ray(const at::Tensor rays_o,\n                  const at::Tensor rays_d,\n                  const float radius,\n                  const uint32_t N,\n                  at::Tensor coords) {\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n            rays_o.scalar_type(), \"sph_from_ray\", ([&] {\n                kernel_sph_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(\n                        rays_o.data_ptr<scalar_t>(),\n                        rays_d.data_ptr<scalar_t>(), radius, N,\n                        coords.data_ptr<scalar_t>());\n            }));\n}\n\n// coords: int32, [N, 3]\n// indices: int32, [N]\n__global__ void kernel_morton3D(const int *__restrict__ coords,\n                                const uint32_t N,\n                                int *indices) {\n    // parallel\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    coords += n * 3;\n    indices[n] = __morton3D(coords[0], coords[1], coords[2]);\n}\n\nvoid morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {\n    static constexpr uint32_t N_THREAD = 128;\n    kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(\n            coords.data_ptr<int>(), N, indices.data_ptr<int>());\n}\n\n// indices: int32, [N]\n// coords: int32, [N, 3]\n__global__ void kernel_morton3D_invert(const int *__restrict__ indices,\n                                       const uint32_t N,\n                                       int *coords) {\n    // parallel\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    coords += n * 3;\n\n    const int ind = indices[n];\n\n    coords[0] = __morton3D_invert(ind >> 0);\n    coords[1] = __morton3D_invert(ind >> 1);\n    coords[2] = __morton3D_invert(ind >> 2);\n}\n\nvoid morton3D_invert(const at::Tensor indices,\n                     const uint32_t N,\n                     at::Tensor coords) {\n    static constexpr uint32_t N_THREAD = 128;\n    kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(\n            indices.data_ptr<int>(), N, coords.data_ptr<int>());\n}\n\n// grid: float, [C, H, H, H]\n// N: int, C * H * H * H / 8\n// density_thresh: float\n// bitfield: uint8, [N]\ntemplate <typename scalar_t>\n__global__ void kernel_packbits(const scalar_t *__restrict__ grid,\n                                const uint32_t N,\n                                const float density_thresh,\n                                uint8_t *bitfield) {\n    // parallel per byte\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    grid += n * 8;\n\n    uint8_t bits = 0;\n\n#pragma unroll\n    for (uint8_t i = 0; i < 8; i++) {\n        bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;\n    }\n\n    bitfield[n] = bits;\n}\n\nvoid packbits(const at::Tensor grid,\n              const uint32_t N,\n              const float density_thresh,\n              at::Tensor bitfield) {\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n            grid.scalar_type(), \"packbits\", ([&] {\n                kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(\n                        grid.data_ptr<scalar_t>(), N, density_thresh,\n                        bitfield.data_ptr<uint8_t>());\n            }));\n}\n\n////////////////////////////////////////////////////\n/////////////         training         /////////////\n////////////////////////////////////////////////////\n\n// rays_o/d: [N, 3]\n// grid: [CHHH / 8]\n// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2]\n// dirs: [M, 3]\n// rays: [N, 3], idx, offset, num_steps\ntemplate <typename scalar_t>\n__global__ void kernel_march_rays_train(const scalar_t *__restrict__ rays_o,\n                                        const scalar_t *__restrict__ rays_d,\n                                        const uint8_t *__restrict__ grid,\n                                        const float bound,\n                                        const float dt_gamma,\n                                        const uint32_t max_steps,\n                                        const uint32_t N,\n                                        const uint32_t C,\n                                        const uint32_t H,\n                                        const uint32_t M,\n                                        const scalar_t *__restrict__ nears,\n                                        const scalar_t *__restrict__ fars,\n                                        scalar_t *xyzs,\n                                        scalar_t *dirs,\n                                        scalar_t *deltas,\n                                        int *rays,\n                                        int *counter,\n                                        const scalar_t *__restrict__ noises) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    rays_o += n * 3;\n    rays_d += n * 3;\n\n    // ray marching\n    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];\n    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];\n    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;\n    const float rH = 1 / (float)H;\n    const float H3 = H * H * H;\n\n    const float near = nears[n];\n    const float far = fars[n];\n    const float noise = noises[n];\n\n    const float dt_min = 2 * SQRT3() / max_steps;\n    const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;\n\n    float t0 = near;\n\n    // perturb\n    t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;\n\n    // first pass: estimation of num_steps\n    float t = t0;\n    uint32_t num_steps = 0;\n\n    // if (t < far) printf(\"valid ray %d t=%f near=%f far=%f \\n\", n, t, near,\n    // far);\n\n    while (t < far && num_steps < max_steps) {\n        // current point\n        const float x = clamp(ox + t * dx, -bound, bound);\n        const float y = clamp(oy + t * dy, -bound, bound);\n        const float z = clamp(oz + t * dz, -bound, bound);\n\n        const float dt = clamp(t * dt_gamma, dt_min, dt_max);\n\n        // get mip level\n        const int level = max(mip_from_pos(x, y, z, C),\n                              mip_from_dt(dt, H, C));  // range in [0, C - 1]\n\n        const float mip_bound = fminf(scalbnf(1.0f, level), bound);\n        const float mip_rbound = 1 / mip_bound;\n\n        // convert to nearest grid position\n        const int nx =\n                clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int ny =\n                clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int nz =\n                clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n\n        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);\n        const bool occ = grid[index / 8] & (1 << (index % 8));\n\n        // if occpuied, advance a small step, and write to output\n        // if (n == 0) printf(\"t=%f density=%f vs thresh=%f step=%d\\n\", t,\n        // density, density_thresh, num_steps);\n\n        if (occ) {\n            num_steps++;\n            t += dt;\n            // else, skip a large step (basically skip a voxel grid)\n        } else {\n            // calc distance to next voxel\n            const float tx =\n                    (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound -\n                     x) *\n                    rdx;\n            const float ty =\n                    (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound -\n                     y) *\n                    rdy;\n            const float tz =\n                    (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound -\n                     z) *\n                    rdz;\n\n            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));\n            // step until next voxel\n            do {\n                t += clamp(t * dt_gamma, dt_min, dt_max);\n            } while (t < tt);\n        }\n    }\n\n    // printf(\"[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\\n\", n,\n    // num_steps, near, far, dt_min, (far - near) / dt_min);\n\n    // second pass: really locate and write points & dirs\n    uint32_t point_index = atomicAdd(counter, num_steps);\n    uint32_t ray_index = atomicAdd(counter + 1, 1);\n\n    // printf(\"[n=%d] num_steps=%d, point_index=%d, ray_index=%d\\n\", n,\n    // num_steps, point_index, ray_index);\n\n    // write rays\n    rays[ray_index * 3] = n;\n    rays[ray_index * 3 + 1] = point_index;\n    rays[ray_index * 3 + 2] = num_steps;\n\n    if (num_steps == 0) return;\n    if (point_index + num_steps > M) return;\n\n    xyzs += point_index * 3;\n    dirs += point_index * 3;\n    deltas += point_index * 2;\n\n    t = t0;\n    uint32_t step = 0;\n\n    float last_t = t;\n\n    while (t < far && step < num_steps) {\n        // current point\n        const float x = clamp(ox + t * dx, -bound, bound);\n        const float y = clamp(oy + t * dy, -bound, bound);\n        const float z = clamp(oz + t * dz, -bound, bound);\n\n        const float dt = clamp(t * dt_gamma, dt_min, dt_max);\n\n        // get mip level\n        const int level = max(mip_from_pos(x, y, z, C),\n                              mip_from_dt(dt, H, C));  // range in [0, C - 1]\n\n        const float mip_bound = fminf(scalbnf(1.0f, level), bound);\n        const float mip_rbound = 1 / mip_bound;\n\n        // convert to nearest grid position\n        const int nx =\n                clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int ny =\n                clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int nz =\n                clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n\n        // query grid\n        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);\n        const bool occ = grid[index / 8] & (1 << (index % 8));\n\n        // if occpuied, advance a small step, and write to output\n        if (occ) {\n            // write step\n            xyzs[0] = x;\n            xyzs[1] = y;\n            xyzs[2] = z;\n            dirs[0] = dx;\n            dirs[1] = dy;\n            dirs[2] = dz;\n            t += dt;\n            deltas[0] = dt;\n            deltas[1] = t - last_t;  // used to calc depth\n            last_t = t;\n            xyzs += 3;\n            dirs += 3;\n            deltas += 2;\n            step++;\n            // else, skip a large step (basically skip a voxel grid)\n        } else {\n            // calc distance to next voxel\n            const float tx =\n                    (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound -\n                     x) *\n                    rdx;\n            const float ty =\n                    (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound -\n                     y) *\n                    rdy;\n            const float tz =\n                    (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound -\n                     z) *\n                    rdz;\n            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));\n            // step until next voxel\n            do {\n                t += clamp(t * dt_gamma, dt_min, dt_max);\n            } while (t < tt);\n        }\n    }\n}\n\nvoid march_rays_train(const at::Tensor rays_o,\n                      const at::Tensor rays_d,\n                      const at::Tensor grid,\n                      const float bound,\n                      const float dt_gamma,\n                      const uint32_t max_steps,\n                      const uint32_t N,\n                      const uint32_t C,\n                      const uint32_t H,\n                      const uint32_t M,\n                      const at::Tensor nears,\n                      const at::Tensor fars,\n                      at::Tensor xyzs,\n                      at::Tensor dirs,\n                      at::Tensor deltas,\n                      at::Tensor rays,\n                      at::Tensor counter,\n                      at::Tensor noises) {\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n            rays_o.scalar_type(), \"march_rays_train\", ([&] {\n                kernel_march_rays_train<<<div_round_up(N, N_THREAD),\n                                          N_THREAD>>>(\n                        rays_o.data_ptr<scalar_t>(),\n                        rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(),\n                        bound, dt_gamma, max_steps, N, C, H, M,\n                        nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(),\n                        xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(),\n                        deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(),\n                        counter.data_ptr<int>(), noises.data_ptr<scalar_t>());\n            }));\n}\n\n// sigmas: [M]\n// rgbs: [M, 3]\n// deltas: [M, 2]\n// rays: [N, 3], idx, offset, num_steps\n// weights_sum: [N], final pixel alpha\n// depth: [N,]\n// image: [N, 3]\ntemplate <typename scalar_t>\n__global__ void kernel_composite_rays_train_forward(\n        const scalar_t *__restrict__ sigmas,\n        const scalar_t *__restrict__ rgbs,\n        const scalar_t *__restrict__ deltas,\n        const int *__restrict__ rays,\n        const uint32_t M,\n        const uint32_t N,\n        const float T_thresh,\n        scalar_t *weights_sum,\n        scalar_t *depth,\n        scalar_t *image) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    uint32_t index = rays[n * 3];\n    uint32_t offset = rays[n * 3 + 1];\n    uint32_t num_steps = rays[n * 3 + 2];\n\n    // empty ray, or ray that exceed max step count.\n    if (num_steps == 0 || offset + num_steps > M) {\n        weights_sum[index] = 0;\n        depth[index] = 0;\n        image[index * 3] = 0;\n        image[index * 3 + 1] = 0;\n        image[index * 3 + 2] = 0;\n        return;\n    }\n\n    sigmas += offset;\n    rgbs += offset * 3;\n    deltas += offset * 2;\n\n    // accumulate\n    uint32_t step = 0;\n\n    scalar_t T = 1.0f;\n    scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0;\n\n    while (step < num_steps) {\n        const scalar_t alpha = 1.0f - __expf(-sigmas[0] * deltas[0]);\n        const scalar_t weight = alpha * T;\n\n        r += weight * rgbs[0];\n        g += weight * rgbs[1];\n        b += weight * rgbs[2];\n\n        t += deltas[1];  // real delta\n        d += weight * t;\n\n        ws += weight;\n\n        T *= 1.0f - alpha;\n\n        // minimal remained transmittence\n        if (T < T_thresh) break;\n\n        // printf(\"[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f,\n        // d=%f\\n\", n, step, alpha, weight, T, sum_delta, d);\n\n        // locate\n        sigmas++;\n        rgbs += 3;\n        deltas += 2;\n\n        step++;\n    }\n\n    // printf(\"[n=%d] rgb=(%f, %f, %f), d=%f\\n\", n, r, g, b, d);\n\n    // write\n    weights_sum[index] = ws;  // weights_sum\n    depth[index] = d;\n    image[index * 3] = r;\n    image[index * 3 + 1] = g;\n    image[index * 3 + 2] = b;\n}\n\nvoid composite_rays_train_forward(const at::Tensor sigmas,\n                                  const at::Tensor rgbs,\n                                  const at::Tensor deltas,\n                                  const at::Tensor rays,\n                                  const uint32_t M,\n                                  const uint32_t N,\n                                  const float T_thresh,\n                                  at::Tensor weights_sum,\n                                  at::Tensor depth,\n                                  at::Tensor image) {\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n            sigmas.scalar_type(), \"composite_rays_train_forward\", ([&] {\n                kernel_composite_rays_train_forward<<<div_round_up(N, N_THREAD),\n                                                      N_THREAD>>>(\n                        sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(),\n                        deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N,\n                        T_thresh, weights_sum.data_ptr<scalar_t>(),\n                        depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());\n            }));\n}\n\n// grad_weights_sum: [N,]\n// grad: [N, 3]\n// sigmas: [M]\n// rgbs: [M, 3]\n// deltas: [M, 2]\n// rays: [N, 3], idx, offset, num_steps\n// weights_sum: [N,], weights_sum here\n// image: [N, 3]\n// grad_sigmas: [M]\n// grad_rgbs: [M, 3]\ntemplate <typename scalar_t>\n__global__ void kernel_composite_rays_train_backward(\n        const scalar_t *__restrict__ grad_weights_sum,\n        const scalar_t *__restrict__ grad_image,\n        const scalar_t *__restrict__ sigmas,\n        const scalar_t *__restrict__ rgbs,\n        const scalar_t *__restrict__ deltas,\n        const int *__restrict__ rays,\n        const scalar_t *__restrict__ weights_sum,\n        const scalar_t *__restrict__ image,\n        const uint32_t M,\n        const uint32_t N,\n        const float T_thresh,\n        scalar_t *grad_sigmas,\n        scalar_t *grad_rgbs) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    uint32_t index = rays[n * 3];\n    uint32_t offset = rays[n * 3 + 1];\n    uint32_t num_steps = rays[n * 3 + 2];\n\n    if (num_steps == 0 || offset + num_steps > M) return;\n\n    grad_weights_sum += index;\n    grad_image += index * 3;\n    weights_sum += index;\n    image += index * 3;\n    sigmas += offset;\n    rgbs += offset * 3;\n    deltas += offset * 2;\n    grad_sigmas += offset;\n    grad_rgbs += offset * 3;\n\n    // accumulate\n    uint32_t step = 0;\n\n    scalar_t T = 1.0f;\n    const scalar_t r_final = image[0], g_final = image[1], b_final = image[2],\n                   ws_final = weights_sum[0];\n    scalar_t r = 0, g = 0, b = 0, ws = 0;\n\n    while (step < num_steps) {\n        const scalar_t alpha = 1.0f - __expf(-sigmas[0] * deltas[0]);\n        const scalar_t weight = alpha * T;\n\n        r += weight * rgbs[0];\n        g += weight * rgbs[1];\n        b += weight * rgbs[2];\n        ws += weight;\n\n        T *= 1.0f - alpha;\n\n        // check https://note.kiui.moe/others/nerf_gradient/ for the gradient\n        // calculation. write grad_rgbs\n        grad_rgbs[0] = grad_image[0] * weight;\n        grad_rgbs[1] = grad_image[1] * weight;\n        grad_rgbs[2] = grad_image[2] * weight;\n\n        // write grad_sigmas\n        grad_sigmas[0] =\n                deltas[0] * (grad_image[0] * (T * rgbs[0] - (r_final - r)) +\n                             grad_image[1] * (T * rgbs[1] - (g_final - g)) +\n                             grad_image[2] * (T * rgbs[2] - (b_final - b)) +\n                             grad_weights_sum[0] * (1 - ws_final));\n\n        // printf(\"[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f,\n        // r=%f\\n\", n, step, T, grad_sigmas[0], r_final, r);\n        //  minimal remained transmittence\n        if (T < T_thresh) break;\n\n        // locate\n        sigmas++;\n        rgbs += 3;\n        deltas += 2;\n        grad_sigmas++;\n        grad_rgbs += 3;\n\n        step++;\n    }\n}\n\nvoid composite_rays_train_backward(const at::Tensor grad_weights_sum,\n                                   const at::Tensor grad_image,\n                                   const at::Tensor sigmas,\n                                   const at::Tensor rgbs,\n                                   const at::Tensor deltas,\n                                   const at::Tensor rays,\n                                   const at::Tensor weights_sum,\n                                   const at::Tensor image,\n                                   const uint32_t M,\n                                   const uint32_t N,\n                                   const float T_thresh,\n                                   at::Tensor grad_sigmas,\n                                   at::Tensor grad_rgbs) {\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n            grad_image.scalar_type(), \"composite_rays_train_backward\", ([&] {\n                kernel_composite_rays_train_backward<<<\n                        div_round_up(N, N_THREAD), N_THREAD>>>(\n                        grad_weights_sum.data_ptr<scalar_t>(),\n                        grad_image.data_ptr<scalar_t>(),\n                        sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(),\n                        deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(),\n                        weights_sum.data_ptr<scalar_t>(),\n                        image.data_ptr<scalar_t>(), M, N, T_thresh,\n                        grad_sigmas.data_ptr<scalar_t>(),\n                        grad_rgbs.data_ptr<scalar_t>());\n            }));\n}\n\n////////////////////////////////////////////////////\n/////////////          infernce        /////////////\n////////////////////////////////////////////////////\n\ntemplate <typename scalar_t>\n__global__ void kernel_march_rays(const uint32_t n_alive,\n                                  const uint32_t n_step,\n                                  const int *__restrict__ rays_alive,\n                                  const scalar_t *__restrict__ rays_t,\n                                  const scalar_t *__restrict__ rays_o,\n                                  const scalar_t *__restrict__ rays_d,\n                                  const float bound,\n                                  const float dt_gamma,\n                                  const uint32_t max_steps,\n                                  const uint32_t C,\n                                  const uint32_t H,\n                                  const uint8_t *__restrict__ grid,\n                                  const scalar_t *__restrict__ nears,\n                                  const scalar_t *__restrict__ fars,\n                                  scalar_t *xyzs,\n                                  scalar_t *dirs,\n                                  scalar_t *deltas,\n                                  const scalar_t *__restrict__ noises) {\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= n_alive) return;\n\n    const int index = rays_alive[n];  // ray id\n    const float noise = noises[n];\n\n    // locate\n    rays_o += index * 3;\n    rays_d += index * 3;\n    xyzs += n * n_step * 3;\n    dirs += n * n_step * 3;\n    deltas += n * n_step * 2;\n\n    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];\n    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];\n    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;\n    const float rH = 1 / (float)H;\n    const float H3 = H * H * H;\n\n    float t = rays_t[index];  // current ray's t\n    const float near = nears[index], far = fars[index];\n\n    const float dt_min = 2 * SQRT3() / max_steps;\n    const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;\n\n    // march for n_step steps, record points\n    uint32_t step = 0;\n\n    // introduce some randomness\n    t += clamp(t * dt_gamma, dt_min, dt_max) * noise;\n\n    float last_t = t;\n\n    while (t < far && step < n_step) {\n        // current point\n        const float x = clamp(ox + t * dx, -bound, bound);\n        const float y = clamp(oy + t * dy, -bound, bound);\n        const float z = clamp(oz + t * dz, -bound, bound);\n\n        const float dt = clamp(t * dt_gamma, dt_min, dt_max);\n\n        // get mip level\n        const int level = max(mip_from_pos(x, y, z, C),\n                              mip_from_dt(dt, H, C));  // range in [0, C - 1]\n\n        const float mip_bound = fminf(scalbnf(1, level), bound);\n        const float mip_rbound = 1 / mip_bound;\n\n        // convert to nearest grid position\n        const int nx =\n                clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int ny =\n                clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int nz =\n                clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n\n        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);\n        const bool occ = grid[index / 8] & (1 << (index % 8));\n\n        // if occpuied, advance a small step, and write to output\n        if (occ) {\n            // write step\n            xyzs[0] = x;\n            xyzs[1] = y;\n            xyzs[2] = z;\n            dirs[0] = dx;\n            dirs[1] = dy;\n            dirs[2] = dz;\n            // calc dt\n            t += dt;\n            deltas[0] = dt;\n            deltas[1] = t - last_t;  // used to calc depth\n            last_t = t;\n            // step\n            xyzs += 3;\n            dirs += 3;\n            deltas += 2;\n            step++;\n\n            // else, skip a large step (basically skip a voxel grid)\n        } else {\n            // calc distance to next voxel\n            const float tx =\n                    (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound -\n                     x) *\n                    rdx;\n            const float ty =\n                    (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound -\n                     y) *\n                    rdy;\n            const float tz =\n                    (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound -\n                     z) *\n                    rdz;\n            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));\n            // step until next voxel\n            do {\n                t += clamp(t * dt_gamma, dt_min, dt_max);\n            } while (t < tt);\n        }\n    }\n}\n\nvoid march_rays(const uint32_t n_alive,\n                const uint32_t n_step,\n                const at::Tensor rays_alive,\n                const at::Tensor rays_t,\n                const at::Tensor rays_o,\n                const at::Tensor rays_d,\n                const float bound,\n                const float dt_gamma,\n                const uint32_t max_steps,\n                const uint32_t C,\n                const uint32_t H,\n                const at::Tensor grid,\n                const at::Tensor near,\n                const at::Tensor far,\n                at::Tensor xyzs,\n                at::Tensor dirs,\n                at::Tensor deltas,\n                at::Tensor noises) {\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n            rays_o.scalar_type(), \"march_rays\", ([&] {\n                kernel_march_rays<<<div_round_up(n_alive, N_THREAD),\n                                    N_THREAD>>>(\n                        n_alive, n_step, rays_alive.data_ptr<int>(),\n                        rays_t.data_ptr<scalar_t>(),\n                        rays_o.data_ptr<scalar_t>(),\n                        rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps,\n                        C, H, grid.data_ptr<uint8_t>(),\n                        near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(),\n                        xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(),\n                        deltas.data_ptr<scalar_t>(),\n                        noises.data_ptr<scalar_t>());\n            }));\n}\n\ntemplate <typename scalar_t>\n__global__ void kernel_composite_rays(const uint32_t n_alive,\n                                      const uint32_t n_step,\n                                      const float T_thresh,\n                                      int *rays_alive,\n                                      scalar_t *rays_t,\n                                      const scalar_t *__restrict__ sigmas,\n                                      const scalar_t *__restrict__ rgbs,\n                                      const scalar_t *__restrict__ deltas,\n                                      scalar_t *weights_sum,\n                                      scalar_t *depth,\n                                      scalar_t *image) {\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= n_alive) return;\n\n    const int index = rays_alive[n];  // ray id\n\n    // locate\n    sigmas += n * n_step;\n    rgbs += n * n_step * 3;\n    deltas += n * n_step * 2;\n\n    rays_t += index;\n    weights_sum += index;\n    depth += index;\n    image += index * 3;\n\n    scalar_t t = rays_t[0];  // current ray's t\n\n    scalar_t weight_sum = weights_sum[0];\n    scalar_t d = depth[0];\n    scalar_t r = image[0];\n    scalar_t g = image[1];\n    scalar_t b = image[2];\n\n    // accumulate\n    uint32_t step = 0;\n    while (step < n_step) {\n        // ray is terminated if delta == 0\n        if (deltas[0] == 0) break;\n\n        const scalar_t alpha = 1.0f - __expf(-sigmas[0] * deltas[0]);\n\n        /*\n        T_0 = 1; T_i = \\prod_{j=0}^{i-1} (1 - alpha_j)\n        w_i = alpha_i * T_i\n        -->\n        T_i = 1 - \\sum_{j=0}^{i-1} w_j\n        */\n        const scalar_t T = 1 - weight_sum;\n        const scalar_t weight = alpha * T;\n        weight_sum += weight;\n\n        t += deltas[1];  // real delta\n        d += weight * t;\n        r += weight * rgbs[0];\n        g += weight * rgbs[1];\n        b += weight * rgbs[2];\n\n        // printf(\"[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f,\n        // d=%f\\n\", n, step, alpha, weight, T, sum_delta, d);\n\n        // ray is terminated if T is too small\n        // use a larger bound to further accelerate inference\n        if (T < T_thresh) break;\n\n        // locate\n        sigmas++;\n        rgbs += 3;\n        deltas += 2;\n        step++;\n    }\n\n    // printf(\"[n=%d] rgb=(%f, %f, %f), d=%f\\n\", n, r, g, b, d);\n\n    // rays_alive = -1 means ray is terminated early.\n    if (step < n_step) {\n        rays_alive[n] = -1;\n    } else {\n        rays_t[0] = t;\n    }\n\n    weights_sum[0] = weight_sum;  // this is the thing I needed!\n    depth[0] = d;\n    image[0] = r;\n    image[1] = g;\n    image[2] = b;\n}\n\nvoid composite_rays(const uint32_t n_alive,\n                    const uint32_t n_step,\n                    const float T_thresh,\n                    at::Tensor rays_alive,\n                    at::Tensor rays_t,\n                    at::Tensor sigmas,\n                    at::Tensor rgbs,\n                    at::Tensor deltas,\n                    at::Tensor weights,\n                    at::Tensor depth,\n                    at::Tensor image) {\n    static constexpr uint32_t N_THREAD = 128;\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n            image.scalar_type(), \"composite_rays\", ([&] {\n                kernel_composite_rays<<<div_round_up(n_alive, N_THREAD),\n                                        N_THREAD>>>(\n                        n_alive, n_step, T_thresh, rays_alive.data_ptr<int>(),\n                        rays_t.data_ptr<scalar_t>(),\n                        sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(),\n                        deltas.data_ptr<scalar_t>(),\n                        weights.data_ptr<scalar_t>(),\n                        depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());\n            }));\n}"
  },
  {
    "path": "lidarnerf/raymarching/src/raymarching.h",
    "content": "#pragma once\n\n#include <stdint.h>\n#include <torch/torch.h>\n\nvoid near_far_from_aabb(const at::Tensor rays_o,\n                        const at::Tensor rays_d,\n                        const at::Tensor aabb,\n                        const uint32_t N,\n                        const float min_near,\n                        at::Tensor nears,\n                        at::Tensor fars);\nvoid sph_from_ray(const at::Tensor rays_o,\n                  const at::Tensor rays_d,\n                  const float radius,\n                  const uint32_t N,\n                  at::Tensor coords);\nvoid morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);\nvoid morton3D_invert(const at::Tensor indices,\n                     const uint32_t N,\n                     at::Tensor coords);\nvoid packbits(const at::Tensor grid,\n              const uint32_t N,\n              const float density_thresh,\n              at::Tensor bitfield);\n\nvoid march_rays_train(const at::Tensor rays_o,\n                      const at::Tensor rays_d,\n                      const at::Tensor grid,\n                      const float bound,\n                      const float dt_gamma,\n                      const uint32_t max_steps,\n                      const uint32_t N,\n                      const uint32_t C,\n                      const uint32_t H,\n                      const uint32_t M,\n                      const at::Tensor nears,\n                      const at::Tensor fars,\n                      at::Tensor xyzs,\n                      at::Tensor dirs,\n                      at::Tensor deltas,\n                      at::Tensor rays,\n                      at::Tensor counter,\n                      at::Tensor noises);\nvoid composite_rays_train_forward(const at::Tensor sigmas,\n                                  const at::Tensor rgbs,\n                                  const at::Tensor deltas,\n                                  const at::Tensor rays,\n                                  const uint32_t M,\n                                  const uint32_t N,\n                                  const float T_thresh,\n                                  at::Tensor weights_sum,\n                                  at::Tensor depth,\n                                  at::Tensor image);\nvoid composite_rays_train_backward(const at::Tensor grad_weights_sum,\n                                   const at::Tensor grad_image,\n                                   const at::Tensor sigmas,\n                                   const at::Tensor rgbs,\n                                   const at::Tensor deltas,\n                                   const at::Tensor rays,\n                                   const at::Tensor weights_sum,\n                                   const at::Tensor image,\n                                   const uint32_t M,\n                                   const uint32_t N,\n                                   const float T_thresh,\n                                   at::Tensor grad_sigmas,\n                                   at::Tensor grad_rgbs);\n\nvoid march_rays(const uint32_t n_alive,\n                const uint32_t n_step,\n                const at::Tensor rays_alive,\n                const at::Tensor rays_t,\n                const at::Tensor rays_o,\n                const at::Tensor rays_d,\n                const float bound,\n                const float dt_gamma,\n                const uint32_t max_steps,\n                const uint32_t C,\n                const uint32_t H,\n                const at::Tensor grid,\n                const at::Tensor nears,\n                const at::Tensor fars,\n                at::Tensor xyzs,\n                at::Tensor dirs,\n                at::Tensor deltas,\n                at::Tensor noises);\nvoid composite_rays(const uint32_t n_alive,\n                    const uint32_t n_step,\n                    const float T_thresh,\n                    at::Tensor rays_alive,\n                    at::Tensor rays_t,\n                    at::Tensor sigmas,\n                    at::Tensor rgbs,\n                    at::Tensor deltas,\n                    at::Tensor weights_sum,\n                    at::Tensor depth,\n                    at::Tensor image);"
  },
  {
    "path": "lidarnerf/shencoder/__init__.py",
    "content": ""
  },
  {
    "path": "lidarnerf/shencoder/backend.py",
    "content": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n]\n\nif os.name == \"posix\":\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\n_backend = load(\n    name=\"_sh_encoder\",\n    extra_cflags=c_flags,\n    extra_cuda_cflags=nvcc_flags,\n    sources=[\n        os.path.join(_src_path, \"src\", f)\n        for f in [\n            \"shencoder.cu\",\n            \"bindings.cpp\",\n        ]\n    ],\n)\n\n__all__ = [\"_backend\"]\n"
  },
  {
    "path": "lidarnerf/shencoder/setup.py",
    "content": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n]\n\nif os.name == \"posix\":\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\nsetup(\n    name=\"shencoder\",  # package name, import this to use python API\n    ext_modules=[\n        CUDAExtension(\n            name=\"_shencoder\",  # extension name, import this to use CUDA API\n            sources=[\n                os.path.join(_src_path, \"src\", f)\n                for f in [\n                    \"shencoder.cu\",\n                    \"bindings.cpp\",\n                ]\n            ],\n            extra_compile_args={\n                \"cxx\": c_flags,\n                \"nvcc\": nvcc_flags,\n            },\n        ),\n    ],\n    cmdclass={\n        \"build_ext\": BuildExtension,\n    },\n)\n"
  },
  {
    "path": "lidarnerf/shencoder/sphere_harmonics.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\ntry:\n    import _shencoder as _backend\nexcept ImportError:\n    from .backend import _backend\n\n\nclass _sh_encoder(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)  # force float32 for better precision\n    def forward(ctx, inputs, degree, calc_grad_inputs=False):\n        # inputs: [B, input_dim], float in [-1, 1]\n        # RETURN: [B, F], float\n\n        inputs = inputs.contiguous()\n        B, input_dim = inputs.shape  # batch size, coord dim\n        output_dim = degree**2\n\n        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)\n\n        if calc_grad_inputs:\n            dy_dx = torch.empty(\n                B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device\n            )\n        else:\n            dy_dx = None\n\n        _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx)\n\n        ctx.save_for_backward(inputs, dy_dx)\n        ctx.dims = [B, input_dim, degree]\n\n        return outputs\n\n    @staticmethod\n    # @once_differentiable\n    @custom_bwd\n    def backward(ctx, grad):\n        # grad: [B, C * C]\n\n        inputs, dy_dx = ctx.saved_tensors\n\n        if dy_dx is not None:\n            grad = grad.contiguous()\n            B, input_dim, degree = ctx.dims\n            grad_inputs = torch.zeros_like(inputs)\n            _backend.sh_encode_backward(\n                grad, inputs, B, input_dim, degree, dy_dx, grad_inputs\n            )\n            return grad_inputs, None, None\n        else:\n            return None, None, None\n\n\nsh_encode = _sh_encoder.apply\n\n\nclass SHEncoder(nn.Module):\n    def __init__(self, input_dim=3, degree=4):\n        super().__init__()\n\n        self.input_dim = input_dim  # coord dims, must be 3\n        self.degree = degree  # 0 ~ 4\n        self.output_dim = degree**2\n\n        assert self.input_dim == 3, \"SH encoder only support input dim == 3\"\n        assert (\n            self.degree > 0 and self.degree <= 8\n        ), \"SH encoder only supports degree in [1, 8]\"\n\n    def __repr__(self):\n        return f\"SHEncoder: input_dim={self.input_dim} degree={self.degree}\"\n\n    def forward(self, inputs, size=1):\n        # inputs: [..., input_dim], normalized real world positions in [-size, size]\n        # return: [..., degree^2]\n\n        inputs = inputs / size  # [-1, 1]\n\n        prefix_shape = list(inputs.shape[:-1])\n        inputs = inputs.reshape(-1, self.input_dim)\n\n        outputs = sh_encode(inputs, self.degree, inputs.requires_grad)\n        outputs = outputs.reshape(prefix_shape + [self.output_dim])\n\n        return outputs\n"
  },
  {
    "path": "lidarnerf/shencoder/src/bindings.cpp",
    "content": "#include <torch/extension.h>\n\n#include \"shencoder.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"sh_encode_forward\", &sh_encode_forward, \"SH encode forward (CUDA)\");\n    m.def(\"sh_encode_backward\", &sh_encode_backward,\n          \"SH encode backward (CUDA)\");\n}"
  },
  {
    "path": "lidarnerf/shencoder/src/shencoder.cu",
    "content": "#include <ATen/cuda/CUDAContext.h>\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n#include <stdint.h>\n#include <torch/torch.h>\n\n#include <algorithm>\n#include <cstdio>\n#include <stdexcept>\n\n#define CHECK_CUDA(x) \\\n    TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) \\\n    TORCH_CHECK(x.is_contiguous(), #x \" must be a contiguous tensor\")\n#define CHECK_IS_INT(x)                                 \\\n    TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \\\n                #x \" must be an int tensor\")\n#define CHECK_IS_FLOATING(x)                                       \\\n    TORCH_CHECK(x.scalar_type() == at::ScalarType::Float ||        \\\n                        x.scalar_type() == at::ScalarType::Half || \\\n                        x.scalar_type() == at::ScalarType::Double, \\\n                #x \" must be a floating tensor\")\n\ntemplate <typename T>\n__host__ __device__ T div_round_up(T val, T divisor) {\n    return (val + divisor - 1) / divisor;\n}\n\ntemplate <typename scalar_t>\n__global__ void kernel_sh(const scalar_t *__restrict__ inputs,\n                          scalar_t *outputs,\n                          uint32_t B,\n                          uint32_t D,\n                          uint32_t C,\n                          scalar_t *dy_dx) {\n    const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x;\n    if (b >= B) return;\n\n    const uint32_t C2 = C * C;\n\n    // locate\n    inputs += b * D;\n    outputs += b * C2;\n\n    scalar_t x = inputs[0], y = inputs[1], z = inputs[2];\n\n    scalar_t xy = x * y, xz = x * z, yz = y * z, x2 = x * x, y2 = y * y,\n             z2 = z * z, xyz = xy * z;\n    scalar_t x4 = x2 * x2, y4 = y2 * y2, z4 = z2 * z2;\n    scalar_t x6 = x4 * x2, y6 = y4 * y2, z6 = z4 * z2;\n\n    auto write_sh = [&]() {\n        outputs[0] = 0.28209479177387814f;  // 1/(2*sqrt(pi))\n        if (C <= 1) {\n            return;\n        }\n        outputs[1] = -0.48860251190291987f * y;  // -sqrt(3)*y/(2*sqrt(pi))\n        outputs[2] = 0.48860251190291987f * z;   // sqrt(3)*z/(2*sqrt(pi))\n        outputs[3] = -0.48860251190291987f * x;  // -sqrt(3)*x/(2*sqrt(pi))\n        if (C <= 2) {\n            return;\n        }\n        outputs[4] = 1.0925484305920792f * xy;   // sqrt(15)*xy/(2*sqrt(pi))\n        outputs[5] = -1.0925484305920792f * yz;  // -sqrt(15)*yz/(2*sqrt(pi))\n        outputs[6] = 0.94617469575755997f * z2 -\n                     0.31539156525251999f;  // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi))\n        outputs[7] = -1.0925484305920792f * xz;  // -sqrt(15)*xz/(2*sqrt(pi))\n        outputs[8] =\n                0.54627421529603959f * x2 -\n                0.54627421529603959f * y2;  // sqrt(15)*(x2 - y2)/(4*sqrt(pi))\n        if (C <= 3) {\n            return;\n        }\n        outputs[9] = 0.59004358992664352f * y *\n                     (-3.0f * x2 + y2);  // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))\n        outputs[10] =\n                2.8906114426405538f * xy * z;  // sqrt(105)*xy*z/(2*sqrt(pi))\n        outputs[11] = 0.45704579946446572f * y *\n                      (1.0f - 5.0f * z2);  // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi))\n        outputs[12] = 0.3731763325901154f * z *\n                      (5.0f * z2 - 3.0f);  // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi))\n        outputs[13] = 0.45704579946446572f * x *\n                      (1.0f - 5.0f * z2);  // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi))\n        outputs[14] = 1.4453057213202769f * z *\n                      (x2 - y2);  // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi))\n        outputs[15] =\n                0.59004358992664352f * x *\n                (-x2 + 3.0f * y2);  // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))\n        if (C <= 4) {\n            return;\n        }\n        outputs[16] = 2.5033429417967046f * xy *\n                      (x2 - y2);  // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi))\n        outputs[17] =\n                1.7701307697799304f * yz *\n                (-3.0f * x2 + y2);  // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi))\n        outputs[18] =\n                0.94617469575756008f * xy *\n                (7.0f * z2 - 1.0f);  // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi))\n        outputs[19] =\n                0.66904654355728921f * yz *\n                (3.0f - 7.0f * z2);  // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi))\n        outputs[20] =\n                -3.1735664074561294f * z2 + 3.7024941420321507f * z4 +\n                0.31735664074561293f;  // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi))\n        outputs[21] =\n                0.66904654355728921f * xz *\n                (3.0f - 7.0f * z2);  // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi))\n        outputs[22] = 0.47308734787878004f * (x2 - y2) *\n                      (7.0f * z2 -\n                       1.0f);  // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi))\n        outputs[23] =\n                1.7701307697799304f * xz *\n                (-x2 + 3.0f * y2);  // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi))\n        outputs[24] =\n                -3.7550144126950569f * x2 * y2 + 0.62583573544917614f * x4 +\n                0.62583573544917614f *\n                        y4;  // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))\n        if (C <= 5) {\n            return;\n        }\n        outputs[25] =\n                0.65638205684017015f * y *\n                (10.0f * x2 * y2 - 5.0f * x4 -\n                 y4);  // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n        outputs[26] = 8.3026492595241645f * xy * z *\n                      (x2 - y2);  // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi))\n        outputs[27] =\n                -0.48923829943525038f * y * (3.0f * x2 - y2) *\n                (9.0f * z2 -\n                 1.0f);  // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))\n        outputs[28] =\n                4.7935367849733241f * xy * z *\n                (3.0f * z2 - 1.0f);  // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi))\n        outputs[29] = 0.45294665119569694f * y *\n                      (14.0f * z2 - 21.0f * z4 -\n                       1.0f);  // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))\n        outputs[30] =\n                0.1169503224534236f * z *\n                (-70.0f * z2 + 63.0f * z4 +\n                 15.0f);  // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi))\n        outputs[31] = 0.45294665119569694f * x *\n                      (14.0f * z2 - 21.0f * z4 -\n                       1.0f);  // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))\n        outputs[32] = 2.3967683924866621f * z * (x2 - y2) *\n                      (3.0f * z2 -\n                       1.0f);  // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi))\n        outputs[33] =\n                -0.48923829943525038f * x * (x2 - 3.0f * y2) *\n                (9.0f * z2 -\n                 1.0f);  // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi))\n        outputs[34] = 2.0756623148810411f * z *\n                      (-6.0f * x2 * y2 + x4 +\n                       y4);  // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))\n        outputs[35] = 0.65638205684017015f * x *\n                      (10.0f * x2 * y2 - x4 -\n                       5.0f * y4);  // 3*sqrt(154)*x*(10*x2*y2 - x4 -\n                                    // 5*y4)/(32*sqrt(pi))\n        if (C <= 6) {\n            return;\n        }\n        outputs[36] = 1.3663682103838286f * xy *\n                      (-10.0f * x2 * y2 + 3.0f * x4 +\n                       3.0f * y4);  // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 +\n                                    // 3*y4)/(32*sqrt(pi))\n        outputs[37] =\n                2.3666191622317521f * yz *\n                (10.0f * x2 * y2 - 5.0f * x4 -\n                 y4);  // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n        outputs[38] =\n                2.0182596029148963f * xy * (x2 - y2) *\n                (11.0f * z2 -\n                 1.0f);  // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))\n        outputs[39] =\n                -0.92120525951492349f * yz * (3.0f * x2 - y2) *\n                (11.0f * z2 -\n                 3.0f);  // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))\n        outputs[40] =\n                0.92120525951492349f * xy *\n                (-18.0f * z2 + 33.0f * z4 +\n                 1.0f);  // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))\n        outputs[41] = 0.58262136251873131f * yz *\n                      (30.0f * z2 - 33.0f * z4 -\n                       5.0f);  // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))\n        outputs[42] = 6.6747662381009842f * z2 - 20.024298714302954f * z4 +\n                      14.684485723822165f * z6 -\n                      0.31784601133814211f;  // sqrt(13)*(105*z2 - 315*z4 +\n                                             // 231*z6 - 5)/(32*sqrt(pi))\n        outputs[43] = 0.58262136251873131f * xz *\n                      (30.0f * z2 - 33.0f * z4 -\n                       5.0f);  // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))\n        outputs[44] = 0.46060262975746175f * (x2 - y2) *\n                      (11.0f * z2 * (3.0f * z2 - 1.0f) - 7.0f * z2 +\n                       1.0f);  // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2\n                               // + 1)/(64*sqrt(pi))\n        outputs[45] =\n                -0.92120525951492349f * xz * (x2 - 3.0f * y2) *\n                (11.0f * z2 -\n                 3.0f);  // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi))\n        outputs[46] = 0.50456490072872406f * (11.0f * z2 - 1.0f) *\n                      (-6.0f * x2 * y2 + x4 +\n                       y4);  // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 +\n                             // y4)/(32*sqrt(pi))\n        outputs[47] = 2.3666191622317521f * xz *\n                      (10.0f * x2 * y2 - x4 -\n                       5.0f * y4);  // 3*sqrt(2002)*xz*(10*x2*y2 - x4 -\n                                    // 5*y4)/(32*sqrt(pi))\n        outputs[48] =\n                10.247761577878714f * x2 * y4 - 10.247761577878714f * x4 * y2 +\n                0.6831841051919143f * x6 -\n                0.6831841051919143f * y6;  // sqrt(6006)*(15*x2*y4 - 15*x4*y2 +\n                                           // x6 - y6)/(64*sqrt(pi))\n        if (C <= 7) {\n            return;\n        }\n        outputs[49] = 0.70716273252459627f * y *\n                      (-21.0f * x2 * y4 + 35.0f * x4 * y2 - 7.0f * x6 +\n                       y6);  // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 +\n                             // y6)/(64*sqrt(pi))\n        outputs[50] = 5.2919213236038001f * xy * z *\n                      (-10.0f * x2 * y2 + 3.0f * x4 +\n                       3.0f * y4);  // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 +\n                                    // 3*y4)/(32*sqrt(pi))\n        outputs[51] = -0.51891557872026028f * y * (13.0f * z2 - 1.0f) *\n                      (-10.0f * x2 * y2 + 5.0f * x4 +\n                       y4);  // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 +\n                             // y4)/(64*sqrt(pi))\n        outputs[52] =\n                4.1513246297620823f * xy * z * (x2 - y2) *\n                (13.0f * z2 -\n                 3.0f);  // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))\n        outputs[53] = -0.15645893386229404f * y * (3.0f * x2 - y2) *\n                      (13.0f * z2 * (11.0f * z2 - 3.0f) - 27.0f * z2 +\n                       3.0f);  // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) -\n                               // 27*z2 + 3)/(64*sqrt(pi))\n        outputs[54] = 0.44253269244498261f * xy * z *\n                      (-110.0f * z2 + 143.0f * z4 +\n                       15.0f);  // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 +\n                                // 15)/(32*sqrt(pi))\n        outputs[55] = 0.090331607582517306f * y *\n                      (-135.0f * z2 + 495.0f * z4 - 429.0f * z6 +\n                       5.0f);  // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 +\n                               // 5)/(64*sqrt(pi))\n        outputs[56] = 0.068284276912004949f * z *\n                      (315.0f * z2 - 693.0f * z4 + 429.0f * z6 -\n                       35.0f);  // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 -\n                                // 35)/(32*sqrt(pi))\n        outputs[57] = 0.090331607582517306f * x *\n                      (-135.0f * z2 + 495.0f * z4 - 429.0f * z6 +\n                       5.0f);  // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 +\n                               // 5)/(64*sqrt(pi))\n        outputs[58] = 0.07375544874083044f * z * (x2 - y2) *\n                      (143.0f * z2 * (3.0f * z2 - 1.0f) - 187.0f * z2 +\n                       45.0f);  // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) -\n                                // 187*z2 + 45)/(64*sqrt(pi))\n        outputs[59] = -0.15645893386229404f * x * (x2 - 3.0f * y2) *\n                      (13.0f * z2 * (11.0f * z2 - 3.0f) - 27.0f * z2 +\n                       3.0f);  // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) -\n                               // 27*z2 + 3)/(64*sqrt(pi))\n        outputs[60] = 1.0378311574405206f * z * (13.0f * z2 - 3.0f) *\n                      (-6.0f * x2 * y2 + x4 +\n                       y4);  // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 +\n                             // y4)/(32*sqrt(pi))\n        outputs[61] = -0.51891557872026028f * x * (13.0f * z2 - 1.0f) *\n                      (-10.0f * x2 * y2 + x4 +\n                       5.0f * y4);  // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 +\n                                    // x4 + 5*y4)/(64*sqrt(pi))\n        outputs[62] = 2.6459606618019f * z *\n                      (15.0f * x2 * y4 - 15.0f * x4 * y2 + x6 -\n                       y6);  // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 -\n                             // y6)/(64*sqrt(pi))\n        outputs[63] = 0.70716273252459627f * x *\n                      (-35.0f * x2 * y4 + 21.0f * x4 * y2 - x6 +\n                       7.0f * y6);  // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6\n                                    // + 7*y6)/(64*sqrt(pi))\n    };\n\n    write_sh();\n\n    if (dy_dx) {\n        scalar_t *dx = dy_dx + b * D * C2;\n        scalar_t *dy = dx + C2;\n        scalar_t *dz = dy + C2;\n\n        auto write_sh_dx = [&]() {\n            dx[0] = 0.0f;  // 0\n            if (C <= 1) {\n                return;\n            }\n            dx[1] = 0.0f;                   // 0\n            dx[2] = 0.0f;                   // 0\n            dx[3] = -0.48860251190291992f;  // -sqrt(3)/(2*sqrt(pi))\n            if (C <= 2) {\n                return;\n            }\n            dx[4] = 1.0925484305920792f * y;   // sqrt(15)*y/(2*sqrt(pi))\n            dx[5] = 0.0f;                      // 0\n            dx[6] = 0.0f;                      // 0\n            dx[7] = -1.0925484305920792f * z;  // -sqrt(15)*z/(2*sqrt(pi))\n            dx[8] = 1.0925484305920792f * x;   // sqrt(15)*x/(2*sqrt(pi))\n            if (C <= 3) {\n                return;\n            }\n            dx[9] = -3.5402615395598609f * xy;  // -3*sqrt(70)*xy/(4*sqrt(pi))\n            dx[10] = 2.8906114426405538f * yz;  // sqrt(105)*yz/(2*sqrt(pi))\n            dx[11] = 0.0f;                      // 0\n            dx[12] = 0.0f;                      // 0\n            dx[13] = 0.45704579946446572f -\n                     2.2852289973223288f *\n                             z2;  // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))\n            dx[14] = 2.8906114426405538f * xz;  // sqrt(105)*xz/(2*sqrt(pi))\n            dx[15] = -1.7701307697799304f * x2 +\n                     1.7701307697799304f *\n                             y2;  // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))\n            if (C <= 4) {\n                return;\n            }\n            dx[16] = 2.5033429417967046f * y *\n                     (3.0f * x2 - y2);  // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi))\n            dx[17] = -10.620784618679583f * xy *\n                     z;  // -9*sqrt(70)*xy*z/(4*sqrt(pi))\n            dx[18] = 0.94617469575756008f * y *\n                     (7.0f * z2 - 1.0f);  // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi))\n            dx[19] = 0.0f;                // 0\n            dx[20] = 0.0f;                // 0\n            dx[21] =\n                    0.66904654355728921f * z *\n                    (3.0f - 7.0f * z2);  // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))\n            dx[22] = 0.94617469575756008f * x *\n                     (7.0f * z2 - 1.0f);  // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))\n            dx[23] = 5.3103923093397913f * z *\n                     (-x2 + y2);  // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))\n            dx[24] = 2.5033429417967046f * x *\n                     (x2 - 3.0f * y2);  // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))\n            if (C <= 5) {\n                return;\n            }\n            dx[25] = 13.127641136803401f * xy *\n                     (-x2 + y2);  // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi))\n            dx[26] = 8.3026492595241645f * yz *\n                     (3.0f * x2 -\n                      y2);  // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi))\n            dx[27] = 2.9354297966115022f * xy *\n                     (1.0f -\n                      9.0f * z2);  // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi))\n            dx[28] = 4.7935367849733241f * yz *\n                     (3.0f * z2 -\n                      1.0f);  // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi))\n            dx[29] = 0.0f;    // 0\n            dx[30] = 0.0f;    // 0\n            dx[31] = 6.3412531167397574f * z2 - 9.5118796751096362f * z4 -\n                     0.45294665119569694f;  // sqrt(165)*(14*z2 - 21*z4 -\n                                            // 1)/(16*sqrt(pi))\n            dx[32] = 4.7935367849733241f * xz *\n                     (3.0f * z2 -\n                      1.0f);  // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))\n            dx[33] = -13.209434084751759f * x2 * z2 + 1.4677148983057511f * x2 +\n                     13.209434084751759f * y2 * z2 -\n                     1.4677148983057511f * y2;  // 3*sqrt(770)*(-9*x2*z2 + x2 +\n                                                // 9*y2*z2 - y2)/(32*sqrt(pi))\n            dx[34] = 8.3026492595241645f * xz *\n                     (x2 -\n                      3.0f * y2);  // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))\n            dx[35] = 19.6914617052051f * x2 * y2 - 3.2819102842008503f * x4 -\n                     3.2819102842008503f * y4;  // 15*sqrt(154)*(6*x2*y2 - x4 -\n                                                // y4)/(32*sqrt(pi))\n            if (C <= 6) {\n                return;\n            }\n            dx[36] = 4.0991046311514854f * y *\n                     (-10.0f * x2 * y2 + 5.0f * x4 +\n                      y4);  // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 +\n                            // y4)/(32*sqrt(pi))\n            dx[37] = 47.332383244635047f * xy * z *\n                     (-x2 + y2);  // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi))\n            dx[38] = 2.0182596029148963f * y * (3.0f * x2 - y2) *\n                     (11.0f * z2 - 1.0f);  // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 -\n                                           // 1)/(8*sqrt(pi))\n            dx[39] = 5.5272315570895412f * xy * z *\n                     (3.0f - 11.0f * z2);  // 3*sqrt(2730)*xy*z*(3 -\n                                           // 11*z2)/(16*sqrt(pi))\n            dx[40] = 0.92120525951492349f * y *\n                     (-18.0f * z2 + 33.0f * z4 +\n                      1.0f);  // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))\n            dx[41] = 0.0f;    // 0\n            dx[42] = 0.0f;    // 0\n            dx[43] = 0.58262136251873131f * z *\n                     (30.0f * z2 - 33.0f * z4 -\n                      5.0f);  // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))\n            dx[44] = 0.92120525951492349f * x *\n                     (-18.0f * z2 + 33.0f * z4 +\n                      1.0f);  // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))\n            dx[45] = -2.7636157785447706f * z * (x2 - y2) *\n                     (11.0f * z2 - 3.0f);  // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 -\n                                           // 3)/(32*sqrt(pi))\n            dx[46] = 2.0182596029148963f * x * (x2 - 3.0f * y2) *\n                     (11.0f * z2 - 1.0f);  // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 -\n                                           // 1)/(8*sqrt(pi))\n            dx[47] = 11.833095811158762f * z *\n                     (6.0f * x2 * y2 - x4 -\n                      y4);  // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))\n            dx[48] = 4.0991046311514854f * x *\n                     (-10.0f * x2 * y2 + x4 +\n                      5.0f * y4);  // 3*sqrt(6006)*x*(-10*x2*y2 + x4 +\n                                   // 5*y4)/(32*sqrt(pi))\n            if (C <= 7) {\n                return;\n            }\n            dx[49] = 9.9002782553443485f * xy *\n                     (10.0f * x2 * y2 - 3.0f * x4 -\n                      3.0f * y4);  // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 -\n                                   // 3*y4)/(32*sqrt(pi))\n            dx[50] = 15.875763970811402f * yz *\n                     (-10.0f * x2 * y2 + 5.0f * x4 +\n                      y4);  // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 +\n                            // y4)/(32*sqrt(pi))\n            dx[51] = -10.378311574405206f * xy * (x2 - y2) *\n                     (13.0f * z2 - 1.0f);  // -15*sqrt(385)*xy*(x2 - y2)*(13*z2\n                                           // - 1)/(16*sqrt(pi))\n            dx[52] = 4.1513246297620823f * yz * (3.0f * x2 - y2) *\n                     (13.0f * z2 - 3.0f);  // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2\n                                           // - 3)/(8*sqrt(pi))\n            dx[53] =\n                    0.93875360317376422f * xy *\n                    (66.0f * z2 - 143.0f * z4 -\n                     3.0f);  // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi))\n            dx[54] = 0.44253269244498261f * yz *\n                     (-110.0f * z2 + 143.0f * z4 +\n                      15.0f);  // 3*sqrt(70)*yz*(-110*z2 + 143*z4 +\n                               // 15)/(32*sqrt(pi))\n            dx[55] = 0.0f;     // 0\n            dx[56] = 0.0f;     // 0\n            dx[57] = -12.194767023639836f * z2 + 44.714145753346067f * z4 -\n                     38.752259652899923f * z6 +\n                     0.45165803791258652f;  // sqrt(105)*(-135*z2 + 495*z4 -\n                                            // 429*z6 + 5)/(64*sqrt(pi))\n            dx[58] = 0.44253269244498261f * xz *\n                     (-110.0f * z2 + 143.0f * z4 +\n                      15.0f);  // 3*sqrt(70)*xz*(-110*z2 + 143*z4 +\n                               // 15)/(32*sqrt(pi))\n            dx[59] = 30.97886890473422f * x2 * z2 -\n                     67.120882626924143f * x2 * z4 - 1.4081304047606462f * x2 -\n                     30.97886890473422f * y2 * z2 +\n                     67.120882626924143f * y2 * z4 +\n                     1.4081304047606462f *\n                             y2;  // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 -\n                                  // 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi))\n            dx[60] = 4.1513246297620823f * xz * (x2 - 3.0f * y2) *\n                     (13.0f * z2 - 3.0f);  // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2\n                                           // - 3)/(8*sqrt(pi))\n            dx[61] =\n                    -0.51891557872026028f * (13.0f * z2 - 1.0f) *\n                    (-10.0f * x2 * y2 + 4.0f * x2 * (x2 - 5.0f * y2) + x4 +\n                     5.0f * y4);  // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 +\n                                  // 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi))\n            dx[62] = 15.875763970811402f * xz *\n                     (-10.0f * x2 * y2 + x4 +\n                      5.0f * y4);  // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 +\n                                   // 5*y4)/(32*sqrt(pi))\n            dx[63] = -74.252086915082614f * x2 * y4 +\n                     74.252086915082614f * x4 * y2 - 4.9501391276721742f * x6 +\n                     4.9501391276721742f *\n                             y6;  // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 +\n                                  // y6)/(64*sqrt(pi))\n        };\n\n        auto write_sh_dy = [&]() {\n            dy[0] = 0.0f;  // 0\n            if (C <= 1) {\n                return;\n            }\n            dy[1] = -0.48860251190291992f;  // -sqrt(3)/(2*sqrt(pi))\n            dy[2] = 0.0f;                   // 0\n            dy[3] = 0.0f;                   // 0\n            if (C <= 2) {\n                return;\n            }\n            dy[4] = 1.0925484305920792f * x;   // sqrt(15)*x/(2*sqrt(pi))\n            dy[5] = -1.0925484305920792f * z;  // -sqrt(15)*z/(2*sqrt(pi))\n            dy[6] = 0.0f;                      // 0\n            dy[7] = 0.0f;                      // 0\n            dy[8] = -1.0925484305920792f * y;  // -sqrt(15)*y/(2*sqrt(pi))\n            if (C <= 3) {\n                return;\n            }\n            dy[9] = -1.7701307697799304f * x2 +\n                    1.7701307697799304f *\n                            y2;  // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))\n            dy[10] = 2.8906114426405538f * xz;  // sqrt(105)*xz/(2*sqrt(pi))\n            dy[11] = 0.45704579946446572f -\n                     2.2852289973223288f *\n                             z2;  // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))\n            dy[12] = 0.0f;        // 0\n            dy[13] = 0.0f;        // 0\n            dy[14] = -2.8906114426405538f * yz;  // -sqrt(105)*yz/(2*sqrt(pi))\n            dy[15] = 3.5402615395598609f * xy;   // 3*sqrt(70)*xy/(4*sqrt(pi))\n            if (C <= 4) {\n                return;\n            }\n            dy[16] = 2.5033429417967046f * x *\n                     (x2 - 3.0f * y2);  // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))\n            dy[17] = 5.3103923093397913f * z *\n                     (-x2 + y2);  // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))\n            dy[18] = 0.94617469575756008f * x *\n                     (7.0f * z2 - 1.0f);  // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))\n            dy[19] =\n                    0.66904654355728921f * z *\n                    (3.0f - 7.0f * z2);  // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))\n            dy[20] = 0.0f;               // 0\n            dy[21] = 0.0f;               // 0\n            dy[22] = 0.94617469575756008f * y *\n                     (1.0f - 7.0f * z2);  // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi))\n            dy[23] = 10.620784618679583f * xy *\n                     z;  // 9*sqrt(70)*xy*z/(4*sqrt(pi))\n            dy[24] = 2.5033429417967046f * y *\n                     (-3.0f * x2 +\n                      y2);  // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi))\n            if (C <= 5) {\n                return;\n            }\n            dy[25] = 19.6914617052051f * x2 * y2 - 3.2819102842008503f * x4 -\n                     3.2819102842008503f * y4;  // 15*sqrt(154)*(6*x2*y2 - x4 -\n                                                // y4)/(32*sqrt(pi))\n            dy[26] = 8.3026492595241645f * xz *\n                     (x2 -\n                      3.0f * y2);  // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))\n            dy[27] = -1.4677148983057511f * (x2 - y2) *\n                     (9.0f * z2 -\n                      1.0f);  // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))\n            dy[28] = 4.7935367849733241f * xz *\n                     (3.0f * z2 -\n                      1.0f);  // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))\n            dy[29] = 6.3412531167397574f * z2 - 9.5118796751096362f * z4 -\n                     0.45294665119569694f;  // sqrt(165)*(14*z2 - 21*z4 -\n                                            // 1)/(16*sqrt(pi))\n            dy[30] = 0.0f;                  // 0\n            dy[31] = 0.0f;                  // 0\n            dy[32] = 4.7935367849733241f * yz *\n                     (1.0f -\n                      3.0f * z2);  // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi))\n            dy[33] = 2.9354297966115022f * xy *\n                     (9.0f * z2 -\n                      1.0f);  // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi))\n            dy[34] = 8.3026492595241645f * yz *\n                     (-3.0f * x2 +\n                      y2);  // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi))\n            dy[35] = 13.127641136803401f * xy *\n                     (x2 - y2);  // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi))\n            if (C <= 6) {\n                return;\n            }\n            dy[36] = 4.0991046311514854f * x *\n                     (-10.0f * x2 * y2 + x4 +\n                      5.0f * y4);  // 3*sqrt(6006)*x*(-10*x2*y2 + x4 +\n                                   // 5*y4)/(32*sqrt(pi))\n            dy[37] = 11.833095811158762f * z *\n                     (6.0f * x2 * y2 - x4 -\n                      y4);  // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))\n            dy[38] = 2.0182596029148963f * x * (x2 - 3.0f * y2) *\n                     (11.0f * z2 - 1.0f);  // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 -\n                                           // 1)/(8*sqrt(pi))\n            dy[39] = -2.7636157785447706f * z * (x2 - y2) *\n                     (11.0f * z2 - 3.0f);  // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 -\n                                           // 3)/(32*sqrt(pi))\n            dy[40] = 0.92120525951492349f * x *\n                     (-18.0f * z2 + 33.0f * z4 +\n                      1.0f);  // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))\n            dy[41] = 0.58262136251873131f * z *\n                     (30.0f * z2 - 33.0f * z4 -\n                      5.0f);  // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))\n            dy[42] = 0.0f;    // 0\n            dy[43] = 0.0f;    // 0\n            dy[44] = 0.92120525951492349f * y *\n                     (18.0f * z2 - 33.0f * z4 -\n                      1.0f);  // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi))\n            dy[45] = 5.5272315570895412f * xy * z *\n                     (11.0f * z2 -\n                      3.0f);  // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi))\n            dy[46] = -2.0182596029148963f * y * (3.0f * x2 - y2) *\n                     (11.0f * z2 - 1.0f);  // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 -\n                                           // 1)/(8*sqrt(pi))\n            dy[47] = 47.332383244635047f * xy * z *\n                     (x2 - y2);  // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi))\n            dy[48] = 4.0991046311514854f * y *\n                     (10.0f * x2 * y2 - 5.0f * x4 -\n                      y4);  // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 -\n                            // y4)/(32*sqrt(pi))\n            if (C <= 7) {\n                return;\n            }\n            dy[49] = -74.252086915082614f * x2 * y4 +\n                     74.252086915082614f * x4 * y2 - 4.9501391276721742f * x6 +\n                     4.9501391276721742f *\n                             y6;  // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 +\n                                  // y6)/(64*sqrt(pi))\n            dy[50] = 15.875763970811402f * xz *\n                     (-10.0f * x2 * y2 + x4 +\n                      5.0f * y4);  // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 +\n                                   // 5*y4)/(32*sqrt(pi))\n            dy[51] = 0.51891557872026028f * (13.0f * z2 - 1.0f) *\n                     (10.0f * x2 * y2 - 5.0f * x4 +\n                      4.0f * y2 * (5.0f * x2 - y2) -\n                      y4);  // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 +\n                            // 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi))\n            dy[52] = 4.1513246297620823f * xz * (x2 - 3.0f * y2) *\n                     (13.0f * z2 - 3.0f);  // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2\n                                           // - 3)/(8*sqrt(pi))\n            dy[53] = -0.46937680158688211f * (x2 - y2) *\n                     (13.0f * z2 * (11.0f * z2 - 3.0f) - 27.0f * z2 +\n                      3.0f);  // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) -\n                              // 27*z2 + 3)/(64*sqrt(pi))\n            dy[54] = 0.44253269244498261f * xz *\n                     (-110.0f * z2 + 143.0f * z4 +\n                      15.0f);  // 3*sqrt(70)*xz*(-110*z2 + 143*z4 +\n                               // 15)/(32*sqrt(pi))\n            dy[55] = -12.194767023639836f * z2 + 44.714145753346067f * z4 -\n                     38.752259652899923f * z6 +\n                     0.45165803791258652f;  // sqrt(105)*(-135*z2 + 495*z4 -\n                                            // 429*z6 + 5)/(64*sqrt(pi))\n            dy[56] = 0.0f;  // 0\n            dy[57] = 0.0f;  // 0\n            dy[58] = 0.44253269244498261f * yz *\n                     (110.0f * z2 - 143.0f * z4 -\n                      15.0f);  // 3*sqrt(70)*yz*(110*z2 - 143*z4 -\n                               // 15)/(32*sqrt(pi))\n            dy[59] = 0.93875360317376422f * xy *\n                     (-66.0f * z2 + 143.0f * z4 +\n                      3.0f);  // 9*sqrt(35)*xy*(-66*z2 + 143*z4 +\n                              // 3)/(32*sqrt(pi))\n            dy[60] = -4.1513246297620823f * yz * (3.0f * x2 - y2) *\n                     (13.0f * z2 - 3.0f);  // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2\n                                           // - 3)/(8*sqrt(pi))\n            dy[61] = 10.378311574405206f * xy * (x2 - y2) *\n                     (13.0f * z2 - 1.0f);  // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 -\n                                           // 1)/(16*sqrt(pi))\n            dy[62] = 15.875763970811402f * yz *\n                     (10.0f * x2 * y2 - 5.0f * x4 -\n                      y4);  // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 -\n                            // y4)/(32*sqrt(pi))\n            dy[63] = 9.9002782553443485f * xy *\n                     (-10.0f * x2 * y2 + 3.0f * x4 +\n                      3.0f * y4);  // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 +\n                                   // 3*y4)/(32*sqrt(pi))\n        };\n\n        auto write_sh_dz = [&]() {\n            dz[0] = 0.0f;  // 0\n            if (C <= 1) {\n                return;\n            }\n            dz[1] = 0.0f;                  // 0\n            dz[2] = 0.48860251190291992f;  // sqrt(3)/(2*sqrt(pi))\n            dz[3] = 0.0f;                  // 0\n            if (C <= 2) {\n                return;\n            }\n            dz[4] = 0.0f;                      // 0\n            dz[5] = -1.0925484305920792f * y;  // -sqrt(15)*y/(2*sqrt(pi))\n            dz[6] = 1.8923493915151202f * z;   // 3*sqrt(5)*z/(2*sqrt(pi))\n            dz[7] = -1.0925484305920792f * x;  // -sqrt(15)*x/(2*sqrt(pi))\n            dz[8] = 0.0f;                      // 0\n            if (C <= 3) {\n                return;\n            }\n            dz[9] = 0.0f;                        // 0\n            dz[10] = 2.8906114426405538f * xy;   // sqrt(105)*xy/(2*sqrt(pi))\n            dz[11] = -4.5704579946446566f * yz;  // -5*sqrt(42)*yz/(4*sqrt(pi))\n            dz[12] = 5.597644988851731f * z2 -\n                     1.1195289977703462f;  // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi))\n            dz[13] = -4.5704579946446566f * xz;  // -5*sqrt(42)*xz/(4*sqrt(pi))\n            dz[14] = 1.4453057213202769f * x2 -\n                     1.4453057213202769f *\n                             y2;  // sqrt(105)*(x2 - y2)/(4*sqrt(pi))\n            dz[15] = 0.0f;        // 0\n            if (C <= 4) {\n                return;\n            }\n            dz[16] = 0.0f;  // 0\n            dz[17] = 1.7701307697799304f * y *\n                     (-3.0f * x2 +\n                      y2);  // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))\n            dz[18] = 13.246445740605839f * xy *\n                     z;  // 21*sqrt(5)*xy*z/(2*sqrt(pi))\n            dz[19] =\n                    2.0071396306718676f * y *\n                    (1.0f - 7.0f * z2);  // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi))\n            dz[20] = 14.809976568128603f * pow(z, 3) -\n                     6.3471328149122579f * z;  // (105*z**3 - 45*z)/(4*sqrt(pi))\n            dz[21] =\n                    2.0071396306718676f * x *\n                    (1.0f - 7.0f * z2);  // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi))\n            dz[22] = 6.6232228703029197f * z *\n                     (x2 - y2);  // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi))\n            dz[23] = 1.7701307697799304f * x *\n                     (-x2 +\n                      3.0f * y2);  // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))\n            dz[24] = 0.0f;         // 0\n            if (C <= 5) {\n                return;\n            }\n            dz[25] = 0.0f;  // 0\n            dz[26] = 8.3026492595241645f * xy *\n                     (x2 - y2);  // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi))\n            dz[27] = 8.8062893898345074f * yz *\n                     (-3.0f * x2 +\n                      y2);  // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi))\n            dz[28] = 4.7935367849733241f * xy *\n                     (9.0f * z2 -\n                      1.0f);  // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi))\n            dz[29] = 12.682506233479513f * yz *\n                     (1.0f -\n                      3.0f * z2);  // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi))\n            dz[30] = -24.559567715218954f * z2 + 36.839351572828434f * z4 +\n                     1.754254836801354f;  // 15*sqrt(11)*(-14*z2 + 21*z4 +\n                                          // 1)/(16*sqrt(pi))\n            dz[31] = 12.682506233479513f * xz *\n                     (1.0f -\n                      3.0f * z2);  // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi))\n            dz[32] = 2.3967683924866621f * (x2 - y2) *\n                     (9.0f * z2 -\n                      1.0f);  // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi))\n            dz[33] = 8.8062893898345074f * xz *\n                     (-x2 +\n                      3.0f * y2);  // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi))\n            dz[34] = -12.453973889286246f * x2 * y2 + 2.0756623148810411f * x4 +\n                     2.0756623148810411f * y4;  // 3*sqrt(385)*(-6*x2*y2 + x4 +\n                                                // y4)/(16*sqrt(pi))\n            dz[35] = 0.0f;                      // 0\n            if (C <= 6) {\n                return;\n            }\n            dz[36] = 0.0f;  // 0\n            dz[37] = 2.3666191622317521f * y *\n                     (10.0f * x2 * y2 - 5.0f * x4 -\n                      y4);  // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 -\n                            // y4)/(32*sqrt(pi))\n            dz[38] = 44.401711264127719f * xy * z *\n                     (x2 - y2);  // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi))\n            dz[39] = -2.7636157785447706f * y * (3.0f * x2 - y2) *\n                     (11.0f * z2 - 1.0f);  // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2\n                                           // - 1)/(32*sqrt(pi))\n            dz[40] = 11.054463114179082f * xy * z *\n                     (11.0f * z2 -\n                      3.0f);  // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi))\n            dz[41] = 2.9131068125936568f * y *\n                     (18.0f * z2 - 33.0f * z4 -\n                      1.0f);  // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))\n            dz[42] =\n                    2.6699064952403937f * z *\n                    (-30.0f * z2 + 33.0f * z4 +\n                     5.0f);  // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi))\n            dz[43] = 2.9131068125936568f * x *\n                     (18.0f * z2 - 33.0f * z4 -\n                      1.0f);  // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))\n            dz[44] = 5.5272315570895412f * z * (x2 - y2) *\n                     (11.0f * z2 - 3.0f);  // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 -\n                                           // 3)/(16*sqrt(pi))\n            dz[45] = -2.7636157785447706f * x * (x2 - 3.0f * y2) *\n                     (11.0f * z2 - 1.0f);  // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2\n                                           // - 1)/(32*sqrt(pi))\n            dz[46] = 11.10042781603193f * z *\n                     (-6.0f * x2 * y2 + x4 +\n                      y4);  // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))\n            dz[47] = 2.3666191622317521f * x *\n                     (10.0f * x2 * y2 - x4 -\n                      5.0f * y4);  // 3*sqrt(2002)*x*(10*x2*y2 - x4 -\n                                   // 5*y4)/(32*sqrt(pi))\n            dz[48] = 0.0f;         // 0\n            if (C <= 7) {\n                return;\n            }\n            dz[49] = 0.0f;  // 0\n            dz[50] = 5.2919213236038001f * xy *\n                     (-10.0f * x2 * y2 + 3.0f * x4 +\n                      3.0f * y4);  // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 +\n                                   // 3*y4)/(32*sqrt(pi))\n            dz[51] = 13.491805046726766f * yz *\n                     (10.0f * x2 * y2 - 5.0f * x4 -\n                      y4);  // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 -\n                            // y4)/(32*sqrt(pi))\n            dz[52] = 12.453973889286248f * xy * (x2 - y2) *\n                     (13.0f * z2 - 1.0f);  // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 -\n                                           // 1)/(8*sqrt(pi))\n            dz[53] = -6.8841930899409371f * yz * (3.0f * x2 - y2) *\n                     (13.0f * z2 - 3.0f);  // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2\n                                           // - 3)/(16*sqrt(pi))\n            dz[54] = 2.2126634622249131f * xy *\n                     (-66.0f * z2 + 143.0f * z4 +\n                      3.0f);  // 15*sqrt(70)*xy*(-66*z2 + 143*z4 +\n                              // 3)/(32*sqrt(pi))\n            dz[55] = 1.6259689364853116f * yz *\n                     (110.0f * z2 - 143.0f * z4 -\n                      15.0f);  // 9*sqrt(105)*yz*(110*z2 - 143*z4 -\n                               // 15)/(32*sqrt(pi))\n            dz[56] = 64.528641681844675f * z2 - 236.60501950009714f * z4 +\n                     205.05768356675085f * z6 -\n                     2.3899496919201733f;  // 7*sqrt(15)*(135*z2 - 495*z4 +\n                                           // 429*z6 - 5)/(32*sqrt(pi))\n            dz[57] = 1.6259689364853116f * xz *\n                     (110.0f * z2 - 143.0f * z4 -\n                      15.0f);  // 9*sqrt(105)*xz*(110*z2 - 143*z4 -\n                               // 15)/(32*sqrt(pi))\n            dz[58] = 0.07375544874083044f * (x2 - y2) *\n                     (143.0f * z2 * (3.0f * z2 - 1.0f) +\n                      132.0f * z2 * (13.0f * z2 - 5.0f) - 187.0f * z2 +\n                      45.0f);  // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) +\n                               // 132*z2*(13*z2\n                               // - 5) - 187*z2 + 45)/(64*sqrt(pi))\n            dz[59] = -6.8841930899409371f * xz * (x2 - 3.0f * y2) *\n                     (13.0f * z2 - 3.0f);  // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2\n                                           // - 3)/(16*sqrt(pi))\n            dz[60] = 3.1134934723215619f * (13.0f * z2 - 1.0f) *\n                     (-6.0f * x2 * y2 + x4 +\n                      y4);  // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 +\n                            // y4)/(32*sqrt(pi))\n            dz[61] = 13.491805046726766f * xz *\n                     (10.0f * x2 * y2 - x4 -\n                      5.0f * y4);  // 39*sqrt(385)*xz*(10*x2*y2 - x4 -\n                                   // 5*y4)/(32*sqrt(pi))\n            dz[62] =\n                    39.6894099270285f * x2 * y4 - 39.6894099270285f * x4 * y2 +\n                    2.6459606618019f * x6 -\n                    2.6459606618019f * y6;  // 3*sqrt(10010)*(15*x2*y4 -\n                                            // 15*x4*y2 + x6 - y6)/(64*sqrt(pi))\n            dz[63] = 0.0f;                  // 0\n        };\n        write_sh_dx();\n        write_sh_dy();\n        write_sh_dz();\n    }\n}\n\ntemplate <typename scalar_t>\n__global__ void kernel_sh_backward(const scalar_t *__restrict__ grad,\n                                   const scalar_t *__restrict__ inputs,\n                                   uint32_t B,\n                                   uint32_t D,\n                                   uint32_t C,\n                                   const scalar_t *__restrict__ dy_dx,\n                                   scalar_t *grad_inputs) {\n    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;\n    const uint32_t b = t / D;\n    if (b >= B) return;\n\n    const uint32_t d = t - b * D;\n    const uint32_t C2 = C * C;\n\n    // locate\n    grad += b * C2;\n    dy_dx += b * D * C2 + d * C2;\n\n    for (int ch = 0; ch < C2; ch++) {\n        grad_inputs[t] += grad[ch] * dy_dx[ch];\n        // printf(\"t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\\n\", t, b, d,\n        // ch, grad_inputs[t], grad[ch], dy_dx[ch]);\n    }\n}\n\n// inputs: [B, D], float, in [0, 1]\n// outputs: [B, L * C], float\ntemplate <typename scalar_t>\nvoid sh_encode_forward_cuda(const scalar_t *inputs,\n                            scalar_t *outputs,\n                            const uint32_t B,\n                            const uint32_t D,\n                            const uint32_t C,\n                            scalar_t *dy_dx) {\n    static constexpr uint32_t N_THREADS = 256;\n    kernel_sh<scalar_t><<<div_round_up(B, N_THREADS), N_THREADS>>>(\n            inputs, outputs, B, D, C, dy_dx);\n}\n\ntemplate <typename scalar_t>\nvoid sh_encode_backward_cuda(const scalar_t *grad,\n                             const scalar_t *inputs,\n                             const uint32_t B,\n                             const uint32_t D,\n                             const uint32_t C,\n                             scalar_t *dy_dx,\n                             scalar_t *grad_inputs) {\n    static constexpr uint32_t N_THREADS = 256;\n    kernel_sh_backward<scalar_t><<<div_round_up(B * D, N_THREADS), N_THREADS>>>(\n            grad, inputs, B, D, C, dy_dx, grad_inputs);\n}\n\nvoid sh_encode_forward(at::Tensor inputs,\n                       at::Tensor outputs,\n                       const uint32_t B,\n                       const uint32_t D,\n                       const uint32_t C,\n                       at::optional<at::Tensor> dy_dx) {\n    CHECK_CUDA(inputs);\n    CHECK_CUDA(outputs);\n    // CHECK_CUDA(dy_dx);\n\n    CHECK_CONTIGUOUS(inputs);\n    CHECK_CONTIGUOUS(outputs);\n    // CHECK_CONTIGUOUS(dy_dx);\n\n    CHECK_IS_FLOATING(inputs);\n    CHECK_IS_FLOATING(outputs);\n    // CHECK_IS_FLOATING(dy_dx);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n            inputs.scalar_type(), \"sh_encode_forward_cuda\", ([&] {\n                sh_encode_forward_cuda<scalar_t>(\n                        inputs.data_ptr<scalar_t>(),\n                        outputs.data_ptr<scalar_t>(), B, D, C,\n                        dy_dx.has_value() ? dy_dx.value().data_ptr<scalar_t>()\n                                          : nullptr);\n            }));\n}\n\nvoid sh_encode_backward(at::Tensor grad,\n                        at::Tensor inputs,\n                        const uint32_t B,\n                        const uint32_t D,\n                        const uint32_t C,\n                        at::Tensor dy_dx,\n                        at::Tensor grad_inputs) {\n    CHECK_CUDA(grad);\n    CHECK_CUDA(inputs);\n    CHECK_CUDA(dy_dx);\n    CHECK_CUDA(grad_inputs);\n\n    CHECK_CONTIGUOUS(grad);\n    CHECK_CONTIGUOUS(inputs);\n    CHECK_CONTIGUOUS(dy_dx);\n    CHECK_CONTIGUOUS(grad_inputs);\n\n    CHECK_IS_FLOATING(grad);\n    CHECK_IS_FLOATING(inputs);\n    CHECK_IS_FLOATING(dy_dx);\n    CHECK_IS_FLOATING(grad_inputs);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n            grad.scalar_type(), \"sh_encode_backward_cuda\", ([&] {\n                sh_encode_backward_cuda<scalar_t>(\n                        grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(),\n                        B, D, C, dy_dx.data_ptr<scalar_t>(),\n                        grad_inputs.data_ptr<scalar_t>());\n            }));\n}"
  },
  {
    "path": "lidarnerf/shencoder/src/shencoder.h",
    "content": "#pragma once\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n// inputs: [B, D], float, in [-1, 1]\n// outputs: [B, F], float\n\nvoid sh_encode_forward(at::Tensor inputs,\n                       at::Tensor outputs,\n                       const uint32_t B,\n                       const uint32_t D,\n                       const uint32_t C,\n                       at::optional<at::Tensor> dy_dx);\nvoid sh_encode_backward(at::Tensor grad,\n                        at::Tensor inputs,\n                        const uint32_t B,\n                        const uint32_t D,\n                        const uint32_t C,\n                        at::Tensor dy_dx,\n                        at::Tensor grad_inputs);"
  },
  {
    "path": "lidarnvs/__init__.py",
    "content": "__version__ = \"0.0.1\"\n"
  },
  {
    "path": "lidarnvs/configs/pcgen_kitti360_raydrop.txt",
    "content": "basedir = pcgen_raydrop_log/kitti360seq1908\ndatadir = data/raydrop/pcgen/kitti360_1908\ndataset = kitti360\nno_batching = False\nlrate=5e-3\nlrate_decay = 500\nrgb_loss_type=mseloss\ni_embed=-1\ni_embed_views=-1\nN_iters = 10000\ncosLR=False\n\nnetdepth=4\nnetwidth=128\nN_rand = 2048\nH=66\nW=1030\ni_save=5000\ni_print=100\ni_weights=5000\n\n\n\n"
  },
  {
    "path": "lidarnvs/configs/pcgen_nerfmvl_raydrop.txt",
    "content": "# ['water_safety_barrier', 'tire', 'pier', 'plant', 'warning_sign', 'bollard', 'pedestrian', 'car',  'traffic_cone']\nexpname = car\nbasedir = pcgen_raydrop_log\ndatadir = data/raydrop/pcgen/nerf_mvl_car\ndataset = nerfmvl\nno_batching = False\nlrate=5e-3\nlrate_decay = 500\nrgb_loss_type=mseloss\nN_iters = 10000\ncosLR=False\nnetdepth=4\nnetwidth=128\nN_rand = 2048\nH=256\nW=1800\ni_save=5000\ni_print=100\ni_weights=5000\n\n\n\n"
  },
  {
    "path": "lidarnvs/eval.py",
    "content": "import numpy as np\nimport torch\n\nfrom skimage.metrics import structural_similarity\nfrom extern.chamfer3D.dist_chamfer_3D import chamfer_3DDist\nfrom extern.fscore import fscore\n\n\ndef eval_points_and_pano(\n    gt_local_points: np.ndarray,\n    pd_local_points: np.ndarray,\n    gt_intensities: np.ndarray,\n    pd_intensities: np.ndarray,\n    gt_pano: np.ndarray,\n    pd_pano: np.ndarray,\n) -> dict:\n    \"\"\"\n    Args:\n        gt_local_points: (N, 3), float32, local point coords in world-scale.\n        pd_local_points: (M, 3), float32, local point coords in world-scale.\n        gt_intensities: (H, W), float32, point intensities, >= 0.\n        pd_intensities: (H, W), float32, point intensities, >= 0.\n        gt_pano: (H, W), float32, range depth image in world-scale.\n            0 means dropped rays. A dropped ray must not have intensity.\n        pd_pano: (H, W), float32, range depth image in world-scale.\n            0 means dropped rays. A dropped ray must not have intensity.\n\n    Returns:\n        # Depth metrics\n        - metrics[\"depth_rmse\"]\n        - metrics[\"depth_a1\"]\n        - metrics[\"depth_a2\"]\n        - metrics[\"depth_a3\"]\n        # Point metrics\n        - metrics[\"chamfer\"]\n        - metrics[\"f_score\"]\n        # Intensity metrics\n        - metrics[\"intensity_mae\"]\n    \"\"\"\n    # Sanity checks.\n    if not gt_local_points.ndim == 2 or not gt_local_points.shape[1] == 3:\n        raise ValueError(\n            f\"gt_local_points must be (N, 3), but got {gt_local_points.shape}\"\n        )\n    if not pd_local_points.ndim == 2 or not pd_local_points.shape[1] == 3:\n        raise ValueError(\n            f\"pd_local_points must be (M, 3), but got {pd_local_points.shape}\"\n        )\n    if not gt_intensities.ndim == 2:\n        raise ValueError(\n            f\"gt_intensities must be (H, W), but got {gt_intensities.shape}\"\n        )\n    lidar_H, lidar_W = gt_intensities.shape\n    if not pd_intensities.shape == (lidar_H, lidar_W):\n        raise ValueError(\n            f\"pd_intensities must be (H, W), but got {pd_intensities.shape}\"\n        )\n    if not gt_pano.shape == (lidar_H, lidar_W):\n        raise ValueError(f\"gt_pano must be (H, W), but got {gt_pano.shape}\")\n    if not pd_pano.shape == (lidar_H, lidar_W):\n        raise ValueError(f\"pd_pano must be (H, W), but got {pd_pano.shape}\")\n\n    # All shall be numpy array\n    is_instance_all = [\n        isinstance(e, np.ndarray)\n        for e in [\n            gt_local_points,\n            pd_local_points,\n            gt_intensities,\n            pd_intensities,\n            gt_pano,\n            pd_pano,\n        ]\n    ]\n    if not all(is_instance_all):\n        raise ValueError(\"All inputs must be numpy array.\")\n\n    def compute_depth_metrics(\n        gt_depths, pd_depths, min_depth=1e-3, max_depth=80, thresh_set=1.25\n    ):\n        pd_depths[pd_depths < min_depth] = min_depth\n        pd_depths[pd_depths > max_depth] = max_depth\n        gt_depths[gt_depths < min_depth] = min_depth\n        gt_depths[gt_depths > max_depth] = max_depth\n\n        thresh = np.maximum((gt_depths / pd_depths), (pd_depths / gt_depths))\n        a1 = (thresh < thresh_set).mean()\n        a2 = (thresh < thresh_set**2).mean()\n        a3 = (thresh < thresh_set**3).mean()\n        rmse = (gt_depths - pd_depths) ** 2\n        rmse = np.sqrt(rmse.mean())\n        ssim = structural_similarity(\n            gt_depths,\n            pd_depths,\n            data_range=gt_depths.max() - gt_depths.min(),\n        )\n        return rmse, a1, a2, a3, ssim\n\n    def compute_point_metrics(gt_points, pd_points):\n        chamLoss = chamfer_3DDist()\n        dist1, dist2, idx1, idx2 = chamLoss(\n            torch.tensor(pd_points[None, ...]).float().cuda(),\n            torch.tensor(gt_points[None, ...]).float().cuda(),\n        )\n        chamfer_dis = dist1.mean() + dist2.mean()\n        threshold = 0.05  # monoSDF\n        f_score, precision, recall = fscore(dist1, dist2, threshold)\n\n        chamfer_dis = chamfer_dis.item()\n        f_score = f_score.item()\n        return chamfer_dis, f_score\n\n    def compute_intensity_metrics(gt_intensities, pd_intensities):\n        mae = np.abs(gt_intensities - pd_intensities).mean()\n        return mae\n\n    metrics = dict()\n    (\n        metrics[\"depth_rmse\"],\n        metrics[\"depth_a1\"],\n        metrics[\"depth_a2\"],\n        metrics[\"depth_a3\"],\n        metrics[\"depth_ssim\"],\n    ) = compute_depth_metrics(gt_depths=gt_pano.flatten(), pd_depths=pd_pano.flatten())\n\n    (\n        metrics[\"chamfer\"],\n        metrics[\"f_score\"],\n    ) = compute_point_metrics(gt_points=gt_local_points, pd_points=pd_local_points)\n\n    metrics[\"intensity_mae\"] = compute_intensity_metrics(\n        gt_intensities=gt_intensities, pd_intensities=pd_intensities\n    )\n\n    return metrics\n"
  },
  {
    "path": "lidarnvs/lidarnvs_base.py",
    "content": "from abc import ABC, abstractmethod\n\nimport numpy as np\n\n\nclass LidarNVSBase(ABC):\n    @abstractmethod\n    def fit(self, dataset) -> None:\n        \"\"\"\n        Fit the model to the given train dataset.\n\n        Args:\n            dataset: A NeRFDataset object.\n        \"\"\"\n\n    @abstractmethod\n    def predict_frame(\n        self,\n        lidar_K: np.ndarray,  # (2, )\n        lidar_pose: np.ndarray,  # (4, 4)\n        lidar_H: int,\n        lidar_W: int,\n    ) -> dict:\n        \"\"\"\n        Predict (synthesis) the point cloud from the given lidar parameters.\n        All necessary information parameters to model a lidar are given.\n\n        Args:\n            lidar_K: (2, ), float32\n            lidar_pose: (4, 4), float32\n            lidar_H: int\n            lidar_W: int\n\n        Return:\n            predict_dict: dict\n            - [\"local_points\"]: (N, 3), float32\n            - [\"points\"]      : (N, 3), float32\n            - [\"pano\"]        : (H, W), float32\n            - [\"intensities\"] : (H, W), float32\n        \"\"\"\n\n    @abstractmethod\n    def predict_frame_with_raydrop(\n        self,\n        lidar_K: np.ndarray,  # (2, )\n        lidar_pose: np.ndarray,  # (4, 4)\n        lidar_H: int,\n        lidar_W: int,\n    ) -> dict:\n        pass\n"
  },
  {
    "path": "lidarnvs/lidarnvs_meshing.py",
    "content": "import camtools as ct\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport open3d as o3d\nimport open3d.core as o3c\nimport torch\nimport torch.nn.functional as F\nfrom tqdm import tqdm\n\nfrom lidarnerf.convert import (\n    lidar_to_pano_with_intensities,\n    pano_to_lidar_with_intensities,\n)\nfrom lidarnerf.dataset.base_dataset import get_lidar_rays\nfrom lidarnvs.lidarnvs_base import LidarNVSBase\nfrom lidarnvs.loader import extract_dataset_frame\nfrom lidarnvs.unet import UNet\n\n\nclass LidarNVSMeshing(LidarNVSBase):\n    \"\"\"\n    Liar novel-view synthesis with meshing and ray casting.\n\n    This is intended to be a base class, where the children class can use\n    different meshing methods.\n    \"\"\"\n\n    def __init__(self, ckpt_path=None):\n        self.ckpt_path = ckpt_path\n\n        # Network for predicting ray-drop.\n        if ckpt_path is not None:\n            self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n            self.model = UNet(n_channels=10, n_classes=1, bilinear=False)\n            self.model = self.model.to(memory_format=torch.channels_last)\n            self.model = self.model.to(device=self.device)\n\n            state_dict = torch.load(self.ckpt_path, map_location=self.device)\n            self.model.load_state_dict(state_dict)\n            self.model.eval()\n            print(f\"Checkpoint loaded from {self.ckpt_path}\")\n\n        # To be filled in the fit() method.\n        self.points = None\n        self.point_intensities = None\n        self.pcd = None\n        self.kdtree = None\n        self.mesh = None\n\n        # To be overwritten by the child class.\n        # - meshing_func: o3d.geometry.PointCloud -> o3d.geometry.TriangleMesh\n        # - the meshing_func shall already be populated with hyper-parameters\n        self.meshing_func = None\n\n    def fit(self, dataset) -> None:\n        \"\"\"\n        Fit the model to the given train dataset.\n\n        Args:\n            dataset: A NeRFDataset object.\n        \"\"\"\n        # Extract all points, in world coordinates.\n        num_frames = len(dataset)\n        all_points = []\n        all_point_intensities = []\n        for frame_idx in tqdm(range(num_frames), \"Extract train frames\"):\n            frame_dict = extract_dataset_frame(dataset, frame_idx)\n            all_points.append(frame_dict[\"points\"])\n            all_point_intensities.append(frame_dict[\"point_intensities\"])\n        all_points = np.vstack(all_points)\n\n        all_point_intensities = np.hstack(all_point_intensities)\n        assert len(all_points) == len(all_point_intensities)\n\n        # Build Open3D pcd.\n        self.pcd = o3d.geometry.PointCloud()\n        self.pcd.points = o3d.utility.Vector3dVector(all_points)\n        colors = ct.colormap.query(all_point_intensities)\n        self.pcd.colors = o3d.utility.Vector3dVector(colors)\n        self.pcd.estimate_normals()\n\n        # Save points and intensities for interpolation.\n        self.points = all_points\n        self.point_intensities = all_point_intensities\n\n        # Run Poisson recon.\n        self.mesh = self.meshing_func(self.pcd)\n        self.mesh.compute_vertex_normals()\n        # o3d.visualization.draw_geometries([self.mesh])\n\n        # Build kdtree for kNN search.\n        self.kdtree = o3d.geometry.KDTreeFlann(self.pcd)\n\n        # Build scene for ray casting.\n        self.raycasting_scene = o3d.t.geometry.RaycastingScene()\n        self.raycasting_scene.add_triangles(\n            o3d.t.geometry.TriangleMesh.from_legacy(self.mesh)\n        )\n\n    def predict_frame(\n        self,\n        lidar_K: np.ndarray,  # (2, )\n        lidar_pose: np.ndarray,  # (4, 4)\n        lidar_H: int,\n        lidar_W: int,\n    ) -> dict:\n        \"\"\"\n        Predict (synthesis) the point cloud from the given lidar parameters.\n        All necessary information parameters to model a lidar are given.\n\n        Args:\n            lidar_K: (2, ), float32\n            lidar_pose: (4, 4), float32\n            lidar_H: int\n            lidar_W: int\n\n        Return:\n            predict_dict: dict\n            - [\"local_points\"]: (N, 3), float32\n            - [\"points\"]      : (N, 3), float32\n            - [\"pano\"]        : (H, W), float32\n            - [\"intensities\"] : (H, W), float32\n        \"\"\"\n        # In world and local coordinates.\n        hit_dict = self.intersect_lidar(lidar_K, lidar_pose, lidar_H, lidar_W)\n        points = hit_dict[\"points\"][hit_dict[\"masks\"]]\n        local_points = ct.project.homo_project(\n            points,\n            ct.convert.pose_to_T(lidar_pose),\n        )\n\n        # Point intensities in world/local coordinates.\n        point_intensities = []\n        for point in points:\n            # ks, indices, distances2\n            _, indices, _ = self.kdtree.search_knn_vector_3d(\n                point, self.intensity_interpolate_k\n            )\n            point_intensities.append(np.mean(self.point_intensities[indices]))\n        point_intensities = np.array(point_intensities)\n        local_point_intensities = point_intensities\n\n        # Pano intensities.\n        local_points_with_intensities = np.concatenate(\n            [local_points, local_point_intensities.reshape((-1, 1))], axis=1\n        )\n        pano, intensities = lidar_to_pano_with_intensities(\n            local_points_with_intensities=local_points_with_intensities,\n            lidar_H=lidar_H,\n            lidar_W=lidar_W,\n            lidar_K=lidar_K,\n        )\n\n        predict_dict = {\n            # Frame properties.\n            \"pano\": pano,\n            \"intensities\": intensities,\n            # Global properties.\n            \"points\": points,\n            \"point_intensities\": point_intensities,\n            # Local properties.\n            \"local_points\": local_points,\n            \"local_point_intensities\": local_point_intensities,\n            # Hit properties: unfiltered results from ray casting.\n            \"hit_dict\": hit_dict,\n        }\n        return predict_dict\n\n    @torch.inference_mode()\n    def predict_frame_with_raydrop(\n        self,\n        lidar_K: np.ndarray,  # (2, )\n        lidar_pose: np.ndarray,  # (4, 4)\n        lidar_H: int,\n        lidar_W: int,\n    ) -> dict:\n        \"\"\"\n        TODO: I know this is ugly. This is the manual combination of:\n        - generate_raydrop_data()\n        - RaydropDataset::collate_fn()\n        \"\"\"\n        nvs_frame = self.predict_frame(\n            lidar_K=lidar_K,\n            lidar_pose=lidar_pose,\n            lidar_H=lidar_H,\n            lidar_W=lidar_W,\n        )\n\n        # Compute incidence angle cosine.\n        # TODO: make get rays a function.\n        ray_dict = get_lidar_rays(\n            poses=torch.tensor(np.array([lidar_pose])),\n            intrinsics=torch.tensor(lidar_K),\n            H=torch.tensor(lidar_H),\n            W=torch.tensor(lidar_W),\n        )\n\n        # generate_raydrop_data() ############################################\n        rays_o = ray_dict[\"rays_o\"].squeeze().numpy()\n        rays_d = ray_dict[\"rays_d\"].squeeze().numpy()\n        hit_normals = nvs_frame[\"hit_dict\"][\"normals\"]\n        hit_incidences = np.abs(np.sum(rays_d * hit_normals, axis=-1))\n\n        # Reshape.\n        hit_masks = nvs_frame[\"hit_dict\"][\"masks\"]\n        hit_masks = hit_masks.reshape((lidar_H, lidar_W))\n        hit_depths = nvs_frame[\"hit_dict\"][\"depths\"]\n        hit_depths[hit_depths == np.inf] = 0\n        hit_depths = hit_depths.reshape((lidar_H, lidar_W))\n        hit_normals = hit_normals.reshape((lidar_H, lidar_W, 3))\n        hit_incidences = hit_incidences.reshape((lidar_H, lidar_W))\n        intensities = nvs_frame[\"intensities\"]\n        intensities = intensities.reshape((lidar_H, lidar_W))\n        rays_o = rays_o.reshape((lidar_H, lidar_W, 3))\n        rays_d = rays_d.reshape((lidar_H, lidar_W, 3))\n\n        # Cast\n        hit_masks = torch.tensor(hit_masks.astype(np.float32))\n        hit_depths = torch.tensor(hit_depths.astype(np.float32))\n        hit_normals = torch.tensor(hit_normals.astype(np.float32))\n        hit_incidences = torch.tensor(hit_incidences.astype(np.float32))\n        intensities = torch.tensor(intensities.astype(np.float32))\n        rays_o = torch.tensor(rays_o.astype(np.float32))\n        rays_d = torch.tensor(rays_d.astype(np.float32))\n\n        # Add batch dimension 1 to the front\n        hit_masks = hit_masks.unsqueeze(0)\n        hit_depths = hit_depths.unsqueeze(0)\n        hit_normals = hit_normals.unsqueeze(0)\n        hit_incidences = hit_incidences.unsqueeze(0)\n        intensities = intensities.unsqueeze(0)\n        rays_o = rays_o.unsqueeze(0)\n        rays_d = rays_d.unsqueeze(0)\n        ######################################################################\n\n        # RaydropDataset::collate_fn() #######################################\n        # (N, H, W, C)\n        images = torch.cat(\n            [\n                hit_masks[..., None].to(self.device),\n                hit_depths[..., None].to(self.device),\n                hit_normals.to(self.device),\n                hit_incidences[..., None].to(self.device),\n                intensities[..., None].to(self.device),\n                rays_d.to(self.device),\n            ],\n            dim=3,\n        )\n        # (N, C, H, W)\n        images = images.permute(0, 3, 1, 2)\n        ######################################################################\n\n        # Predict raydrop mask.\n        pd_raydrop_masks = self.model(images)\n        pd_raydrop_masks = (F.sigmoid(pd_raydrop_masks) > 0.5).float()\n        pd_raydrop_masks = pd_raydrop_masks.squeeze().cpu().numpy()\n        if False:\n            plt.imshow(pd_raydrop_masks)\n            plt.show()\n\n        # Update predict_dict\n        # Frame properties.\n        pano = nvs_frame[\"pano\"] * pd_raydrop_masks\n        intensities = nvs_frame[\"intensities\"] * pd_raydrop_masks\n        # Local properties.\n        local_points_with_intensities = pano_to_lidar_with_intensities(\n            pano=pano,\n            intensities=intensities,\n            lidar_K=lidar_K,\n        )\n        local_points = local_points_with_intensities[:, :3]\n        local_point_intensities = local_points_with_intensities[:, 3]\n        # Global properties.\n        points = ct.project.homo_project(local_points, lidar_pose)\n        point_intensities = local_point_intensities\n\n        predict_dict = {\n            # Frame properties.\n            \"pano\": pano,\n            \"intensities\": intensities,\n            # Global properties.\n            \"points\": points,\n            \"point_intensities\": point_intensities,\n            # Local properties.\n            \"local_points\": local_points,\n            \"local_point_intensities\": local_point_intensities,\n            # Hit properties: unfiltered results from ray casting.\n            \"hit_dict\": nvs_frame[\"hit_dict\"],\n        }\n\n        return predict_dict\n\n    def intersect_rays(self, rays):\n        \"\"\"\n        Compute ray-mesh intersect and return the hit_dict.\n        The hit_dict will NOT be filtered, but the masks will be provided.\n\n        Args:\n            mesh: o3d.geometry.TriangleMesh\n            rays: (N, 6), float32 rays, where rays[:, :3] is the origin and\n                rays[:, 3:] is the direction. The directions do not need to be\n                normalized.\n\n        Return:\n            hit_dict\n            - [\"masks\"]  : (N, ) , boolean mask of ray hit.\n            - [\"depths\"]  : (N, ) , depth in world-scale.\n            - [\"points\"] : (N, 3), coordinates of the intersection points.\n            - [\"normals\"]: (N, 3), normals of the hit triangles.\n        \"\"\"\n        # Sanity checks.\n        if not isinstance(rays, np.ndarray):\n            raise TypeError(\"rays must be a numpy array.\")\n        if rays.ndim != 2 or rays.shape[1] != 6:\n            raise ValueError(\"rays must be a (N, 6) array.\")\n\n        # Run ray cast.\n        ray_cast_results = self.raycasting_scene.cast_rays(o3c.Tensor(rays))\n        normals = ray_cast_results[\"primitive_normals\"].numpy()\n        depths = ray_cast_results[\"t_hit\"].numpy()\n        masks = depths != np.inf\n        rays_o = rays[:, :3]\n        rays_d = rays[:, 3:]\n        rays_d = rays_d / np.linalg.norm(rays_d, axis=1, keepdims=True)\n        points = rays_o + rays_d * depths[:, None]\n\n        hit_dict = {\n            \"masks\": masks,\n            \"depths\": depths,\n            \"points\": points,\n            \"normals\": normals,\n        }\n\n        return hit_dict\n\n    def intersect_lidar(\n        self,\n        lidar_K: np.ndarray,  # (2, )\n        lidar_pose: np.ndarray,  # (4, 4)\n        lidar_H: int,\n        lidar_W: int,\n    ):\n        ray_dict = get_lidar_rays(\n            poses=torch.tensor(np.array([lidar_pose])),\n            intrinsics=torch.tensor(lidar_K),\n            H=torch.tensor(lidar_H),\n            W=torch.tensor(lidar_W),\n        )\n        rays_o = ray_dict[\"rays_o\"].squeeze().numpy()\n        rays_d = ray_dict[\"rays_d\"].squeeze().numpy()\n        rays = np.concatenate([rays_o, rays_d], axis=-1)\n        hit_dict = self.intersect_rays(rays)\n        return hit_dict\n\n\ndef generate_raydrop_data_meshing(dataset, nvs: LidarNVSMeshing) -> dict:\n    \"\"\"\n    Prepare dataset for learning ray drop.\n    The frames are NOT loaded by our dataset, but GENERATED.\n\n    Return:\n        raydrop_data = [\n            {\n                \"hit_masks\"      : (H, W)    # Ray cast hit mask\n                \"hit_depths\"     : (H, W)    # Hit intersection point depths\n                \"hit_normals\"    : (H, W, 3) # Intersection point normal\n                \"hit_incidences\" : (H, W)    # |cos(normal, ray_d)|\n                \"intensities\"    : (H, W)    # Predicted intensities\n                \"rays_o\"         : (H, W, 3) # Lidar ray origin\n                \"rays_d\"         : (H, W, 3) # Lidar ray direction\n                \"raydrop_masks\"  : (H, W)    # Ray drop mask, 1 is valid\n            },\n            ...\n        ]\n    \"\"\"\n    raydrop_data = []\n    for frame_idx in tqdm(range(len(dataset)), desc=\"Prepare raydrop dataset\"):\n        gt_frame = extract_dataset_frame(dataset, frame_idx=frame_idx)\n        nvs_frame = nvs.predict_frame(\n            lidar_K=gt_frame[\"lidar_K\"],\n            lidar_pose=gt_frame[\"lidar_pose\"],\n            lidar_H=gt_frame[\"lidar_H\"],\n            lidar_W=gt_frame[\"lidar_W\"],\n        )\n\n        # The target.\n        raydrop_masks = gt_frame[\"pano\"] != 0\n\n        # Compute incidence angle cosine.\n        rays_o = gt_frame[\"rays\"][:, :3]\n        rays_d = gt_frame[\"rays\"][:, 3:]\n        hit_normals = nvs_frame[\"hit_dict\"][\"normals\"]\n        hit_incidences = np.abs(np.sum(rays_d * hit_normals, axis=-1))\n\n        # Pre-processing.\n        # TODO: move the reshape to upper-level\n        lidar_H, lidar_W = gt_frame[\"lidar_H\"], gt_frame[\"lidar_W\"]\n\n        # Reshape.\n        hit_masks = nvs_frame[\"hit_dict\"][\"masks\"]\n        hit_masks = hit_masks.reshape((lidar_H, lidar_W))\n        hit_depths = nvs_frame[\"hit_dict\"][\"depths\"]\n        hit_depths[hit_depths == np.inf] = 0\n        hit_depths = hit_depths.reshape((lidar_H, lidar_W))\n        hit_normals = hit_normals.reshape((lidar_H, lidar_W, 3))\n        hit_incidences = hit_incidences.reshape((lidar_H, lidar_W))\n        intensities = nvs_frame[\"intensities\"]\n        intensities = intensities.reshape((lidar_H, lidar_W))\n        rays_o = rays_o.reshape((lidar_H, lidar_W, 3))\n        rays_d = rays_d.reshape((lidar_H, lidar_W, 3))\n        raydrop_masks = raydrop_masks.reshape((lidar_H, lidar_W))\n\n        # Cast.\n        hit_masks = hit_masks.astype(np.float32)\n        hit_depths = hit_depths.astype(np.float32)\n        hit_normals = hit_normals.astype(np.float32)\n        hit_incidences = hit_incidences.astype(np.float32)\n        intensities = intensities.astype(np.float32)\n        rays_o = rays_o.astype(np.float32)\n        rays_d = rays_d.astype(np.float32)\n        raydrop_masks = raydrop_masks.astype(np.float32)\n\n        raydrop_datum = {\n            \"hit_masks\": hit_masks,\n            \"hit_depths\": hit_depths,\n            \"hit_normals\": hit_normals,\n            \"hit_incidences\": hit_incidences,\n            \"intensities\": intensities,\n            \"rays_o\": rays_o,\n            \"rays_d\": rays_d,\n            \"raydrop_masks\": raydrop_masks,\n        }\n        raydrop_data.append(raydrop_datum)\n\n    return raydrop_data\n"
  },
  {
    "path": "lidarnvs/lidarnvs_nksr.py",
    "content": "import open3d as o3d\nimport numpy as np\nfrom lidarnvs.lidarnvs_meshing import LidarNVSMeshing\n\nimport torch\nimport nksr\n\n\nclass LidarNVSNksr(LidarNVSMeshing):\n    def __init__(self, ckpt_path=None):\n        super(LidarNVSNksr, self).__init__(ckpt_path=ckpt_path)\n\n        # To be filled in the fit() method.\n        self.points = None\n        self.point_intensities = None\n        self.pcd = None\n        self.kdtree = None\n        self.mesh = None\n\n        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n        self.nksr_reconstructor = nksr.Reconstructor(self.device)\n\n        # self.meshing_func shall be pre-filled with functools.partial.\n        self.meshing_func = self._run_nksr\n\n    def _run_nksr(\n        self,\n        pcd: o3d.geometry.PointCloud,\n    ) -> o3d.geometry.TriangleMesh:\n        print(\"Start _run_nksr()\")\n\n        pcd.estimate_normals()\n\n        input_xyz = torch.from_numpy(np.asarray(pcd.points)).float().to(self.device)\n        input_normal = torch.from_numpy(np.asarray(pcd.normals)).float().to(self.device)\n\n        field = self.nksr_reconstructor.reconstruct(\n            input_xyz, input_normal, detail_level=0.5\n        )\n        mesh = field.extract_dual_mesh(mise_iter=1)\n\n        vertices = mesh.v.cpu().numpy()\n        triangles = mesh.f.cpu().numpy()\n\n        mesh = o3d.geometry.TriangleMesh()\n        mesh.vertices = o3d.utility.Vector3dVector(vertices)\n        mesh.triangles = o3d.utility.Vector3iVector(triangles)\n        mesh.compute_vertex_normals()\n\n        return mesh\n"
  },
  {
    "path": "lidarnvs/lidarnvs_pcgen.py",
    "content": "import camtools as ct\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom lidarnerf.convert import (\n    lidar_to_pano_with_intensities,\n    lidar_to_pano_with_intensities_fpa,\n    pano_to_lidar_with_intensities,\n)\nfrom lidarnvs.loader import extract_dataset_frame\nfrom lidarnvs.raydrop_train_pcgen import RayDrop, run_network, get_embedder\nfrom lidarnvs.lidarnvs_base import LidarNVSBase\n\n\nclass LidarNVSPCGen(LidarNVSBase):\n    def __init__(self, raycasting=\"cp\", ckpt_path=None):\n        self.raycasting = raycasting\n\n        # Network for predicting raydrop.\n        if ckpt_path is not None:\n            self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n            self.embed_fn, input_ch = get_embedder(4, input_dims=1, i=-1)\n            self.embeddirs_fn, input_ch_views = get_embedder(10, input_dims=3, i=-1)\n            total_input_ch = input_ch * 2 + input_ch_views\n            netdepth, netwidth = 4, 128\n            self.model = RayDrop(D=netdepth, W=netwidth, input_ch=total_input_ch).to(\n                self.device\n            )\n\n            ckpt = torch.load(ckpt_path)\n            self.model.load_state_dict(ckpt[\"network_fn_state_dict\"])\n            self.model.eval()\n            print(f\"Checkpoint loaded from {ckpt_path}\")\n\n    def fit(self, dataset) -> None:\n        \"\"\"\n        Fit the model to the given train dataset.\n\n        Args:\n            dataset: A NeRFDataset object.\n        \"\"\"\n        # Extract all points, in world coordinates.\n        num_frames = len(dataset)\n        all_points = []\n        all_point_intensities = []\n        for frame_idx in tqdm(range(num_frames), \"Extract train frames\"):\n            frame_dict = extract_dataset_frame(dataset, frame_idx)\n            all_points.append(frame_dict[\"points\"])\n            all_point_intensities.append(frame_dict[\"point_intensities\"])\n        all_points = np.vstack(all_points)\n\n        all_point_intensities = np.hstack(all_point_intensities)\n        assert len(all_points) == len(all_point_intensities)\n\n        # Save points and intensities for interpolation.\n        self.points = all_points\n        self.point_intensities = all_point_intensities\n\n    def predict_frame(\n        self,\n        lidar_K: np.ndarray,  # (2, )\n        lidar_pose: np.ndarray,  # (4, 4)\n        lidar_H: int,\n        lidar_W: int,\n    ) -> dict:\n        \"\"\"\n        Predict (synthesis) the point cloud from the given lidar parameters.\n        All necessary information parameters to model a lidar are given.\n\n        Args:\n            lidar_K: (2, ), float32\n            lidar_pose: (4, 4), float32\n            lidar_H: int\n            lidar_W: int\n\n        Return:\n            predict_dict: dict\n            - [\"local_points\"]: (N, 3), float32\n            - [\"points\"]      : (N, 3), float32\n            - [\"pano\"]        : (H, W), float32\n            - [\"intensities\"] : (H, W), float32\n        \"\"\"\n        # In world and local coordinates.\n        local_points = ct.project.homo_project(\n            self.points,\n            ct.convert.pose_to_T(lidar_pose),\n        )\n\n        # Pano intensities.\n        local_points_with_intensities = np.concatenate(\n            [local_points, self.point_intensities.reshape((-1, 1))], axis=1\n        )\n        if self.raycasting == \"cp\":\n            pano, intensities = lidar_to_pano_with_intensities(\n                local_points_with_intensities=local_points_with_intensities,\n                lidar_H=lidar_H,\n                lidar_W=lidar_W,\n                lidar_K=lidar_K,\n            )\n        elif self.raycasting == \"fpa\":\n            pano, intensities = lidar_to_pano_with_intensities_fpa(\n                local_points_with_intensities=local_points_with_intensities,\n                lidar_H=lidar_H,\n                lidar_W=lidar_W,\n                lidar_K=lidar_K,\n            )\n\n        local_points_with_intensities = pano_to_lidar_with_intensities(\n            pano=pano, intensities=intensities, lidar_K=lidar_K\n        )\n        local_points = local_points_with_intensities[:, :3]\n        local_point_intensities = local_points_with_intensities[:, 3]\n\n        points = ct.project.homo_project(local_points, lidar_pose)\n        point_intensities = local_point_intensities\n\n        predict_dict = {\n            # Frame properties.\n            \"pano\": pano,\n            \"intensities\": intensities,\n            # Global properties.\n            \"points\": points,\n            \"point_intensities\": point_intensities,\n            # Local properties.\n            \"local_points\": local_points,\n            \"local_point_intensities\": local_point_intensities,\n        }\n        return predict_dict\n\n    @torch.inference_mode()\n    def predict_frame_with_raydrop(\n        self,\n        lidar_K: np.ndarray,  # (2, )\n        lidar_pose: np.ndarray,  # (4, 4)\n        lidar_H: int,\n        lidar_W: int,\n    ) -> dict:\n        nvs_frame = self.predict_frame(\n            lidar_K=lidar_K,\n            lidar_pose=lidar_pose,\n            lidar_H=lidar_H,\n            lidar_W=lidar_W,\n        )\n        direction = get_direction(lidar_H, lidar_W, lidar_K)\n        pano = nvs_frame[\"pano\"]\n        intensity = nvs_frame[\"intensities\"]\n        rays_val = np.concatenate(\n            (\n                np.array(direction).reshape(-1, 3),\n                np.array(pano).reshape(-1, 1),\n                np.array(intensity).reshape(-1, 1),\n            ),\n            -1,\n        )\n        rays_val = torch.Tensor(rays_val).to(self.device)\n        pd_raydrop_masks = run_network(\n            rays_val, self.model, self.embed_fn, self.embeddirs_fn\n        )\n        pd_raydrop_masks = np.where(\n            pd_raydrop_masks.cpu().numpy() > 0.5, 1.0, 0.0\n        ).reshape(lidar_H, lidar_W)\n        # Update predict_dict\n        # Frame properties.\n        pano = nvs_frame[\"pano\"]\n        intensities = nvs_frame[\"intensities\"]\n        if not np.all(pd_raydrop_masks == 0):\n            pano = pano * pd_raydrop_masks\n            intensities = intensities * pd_raydrop_masks\n        # Local properties.\n        local_points_with_intensities = pano_to_lidar_with_intensities(\n            pano=pano,\n            intensities=intensities,\n            lidar_K=lidar_K,\n        )\n        local_points = local_points_with_intensities[:, :3]\n        local_point_intensities = local_points_with_intensities[:, 3]\n        # Global properties.\n        points = ct.project.homo_project(local_points, lidar_pose)\n        point_intensities = local_point_intensities\n\n        predict_dict = {\n            # Frame properties.\n            \"pano\": pano,\n            \"intensities\": intensities,\n            # Global properties.\n            \"points\": points,\n            \"point_intensities\": point_intensities,\n            # Local properties.\n            \"local_points\": local_points,\n            \"local_point_intensities\": local_point_intensities,\n        }\n\n        return predict_dict\n\n\ndef generate_raydrop_data_pcgen(dataset, nvs: LidarNVSPCGen, rm_pano_mask=True) -> dict:\n    \"\"\"\n    Prepare dataset for learning ray drop.\n    The frames are NOT loaded by our dataset, but GENERATED.\n\n    Return:\n        directions, panos, intensities, raydrop_masks\n    \"\"\"\n\n    raydrop_masks = []\n    directions = []\n    panos = []\n    intensities = []\n    for frame_idx in tqdm(range(len(dataset)), desc=\"Prepare raydrop dataset\"):\n        gt_frame = extract_dataset_frame(\n            dataset, frame_idx=frame_idx, rm_pano_mask=rm_pano_mask\n        )\n        nvs_frame = nvs.predict_frame(\n            lidar_K=gt_frame[\"lidar_K\"],\n            lidar_pose=gt_frame[\"lidar_pose\"],\n            lidar_H=gt_frame[\"lidar_H\"],\n            lidar_W=gt_frame[\"lidar_W\"],\n        )\n\n        # The target.\n        raydrop_masks.append(gt_frame[\"pano\"])\n\n        # The inputs\n        lidar_H, lidar_W, lidar_K = (\n            gt_frame[\"lidar_H\"],\n            gt_frame[\"lidar_W\"],\n            gt_frame[\"lidar_K\"],\n        )\n        directions.append(get_direction(lidar_H, lidar_W, lidar_K))\n        panos.append(nvs_frame[\"pano\"])\n        intensities.append(nvs_frame[\"intensities\"])\n    return (directions, panos, intensities, raydrop_masks)\n\n\ndef get_direction(lidar_H, lidar_W, lidar_K):\n    fov_up, fov = lidar_K\n    i, j = np.meshgrid(\n        np.arange(lidar_W, dtype=np.float32),\n        np.arange(lidar_H, dtype=np.float32),\n        indexing=\"xy\",\n    )\n    beta = -(i - lidar_W / 2) / lidar_W * 2 * np.pi\n    alpha = (fov_up - j / lidar_H * fov) / 180 * np.pi\n    dirs = np.stack(\n        [np.cos(alpha) * np.cos(beta), np.cos(alpha) * np.sin(beta), np.sin(alpha)], -1\n    )\n    return dirs\n"
  },
  {
    "path": "lidarnvs/lidarnvs_poisson.py",
    "content": "import time\n\nimport open3d as o3d\nimport numpy as np\nfrom lidarnvs.lidarnvs_meshing import LidarNVSMeshing\nimport functools\n\n\nclass LidarNVSPoisson(LidarNVSMeshing):\n    @staticmethod\n    def _run_poisson(\n        pcd: o3d.geometry.PointCloud,\n        depth: int,\n        min_density: int,\n    ) -> o3d.geometry.TriangleMesh:\n        print(\"Start _run_poisson()\")\n        s_time = time.time()\n        # Run.\n        mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(\n            pcd,\n            depth=depth,\n        )\n        # Filter by density.\n        vertices_to_remove = densities < np.quantile(densities, min_density)\n        mesh.remove_vertices_by_mask(vertices_to_remove)\n        # All-black colors are generated, but we don't need them.\n        mesh.vertex_colors = o3d.utility.Vector3dVector([])\n        print(f\"_run_poisson() time: {time.time() - s_time:.3f} secs\")\n        return mesh\n\n    def __init__(\n        self,\n        poisson_depth=10,\n        poisson_min_density=0.6,\n        intensity_interpolate_k=5,\n        ckpt_path=None,\n    ):\n        super(LidarNVSPoisson, self).__init__(ckpt_path=ckpt_path)\n\n        self.poisson_depth = poisson_depth\n        self.poisson_min_density = poisson_min_density\n        self.intensity_interpolate_k = intensity_interpolate_k\n\n        # To be filled in the fit() method.\n        self.points = None\n        self.point_intensities = None\n        self.pcd = None\n        self.kdtree = None\n        self.mesh = None\n\n        # self.meshing_func shall be pre-filled with functools.partial.\n        self.meshing_func = functools.partial(\n            LidarNVSPoisson._run_poisson,\n            depth=self.poisson_depth,\n            min_density=self.poisson_min_density,\n        )\n"
  },
  {
    "path": "lidarnvs/loader.py",
    "content": "import camtools as ct\nimport numpy as np\n\nfrom lidarnerf.dataset.base_dataset import get_lidar_rays\nfrom lidarnerf.convert import pano_to_lidar_with_intensities\n\n\ndef extract_dataset_frame(\n    dataset, frame_idx: int, rm_pano_mask: bool = True, verbose: bool = False\n) -> dict:\n    \"\"\"\n    Extract a single frame from a dataset object.\n    \"\"\"\n    # Unpack dataset.\n    lidar_pose = dataset.poses_lidar[frame_idx].numpy()\n    pano = dataset.images_lidar[frame_idx][:, :, 2].numpy()\n    intensities = dataset.images_lidar[frame_idx][:, :, 1].numpy()\n    lidar_K = dataset.intrinsics_lidar\n    lidar_H = dataset.H_lidar\n    lidar_W = dataset.W_lidar\n\n    # Process pano mask.\n    # TODO: remove this.\n    pano_mask = pano != -1\n    if rm_pano_mask:\n        pano[pano == -1] = 0\n\n    # Load rays.\n    ray_dict = get_lidar_rays(\n        poses=dataset.poses_lidar[[frame_idx]],\n        intrinsics=dataset.intrinsics_lidar,\n        H=dataset.H_lidar,\n        W=dataset.W_lidar,\n        N=-1,\n        patch_size=1,\n    )\n    rays_o = ray_dict[\"rays_o\"].squeeze().numpy()\n    rays_d = ray_dict[\"rays_d\"].squeeze().numpy()\n    rays = np.concatenate([rays_o, rays_d], axis=-1)\n\n    # Generate gt data.\n    # pose: cam to world projection matrix.\n    # T   : world to cam projection matrix.\n    # (N, 4)\n    local_points_with_intensities = pano_to_lidar_with_intensities(\n        pano=pano,\n        intensities=intensities,\n        lidar_K=lidar_K,\n    )\n    local_points = local_points_with_intensities[:, :3]\n    local_point_intensities = local_points_with_intensities[:, 3]\n\n    # Project local to world coordinates.\n    points = ct.project.homo_project(local_points, lidar_pose)\n    point_intensities = local_point_intensities\n\n    # \"pano\"       : invalid points marked as 0 (depth).\n    # \"intensities\": 0 means likely invalid, but not 100%.\n    frame_dict = {\n        \"rays\": rays,\n        \"lidar_pose\": lidar_pose,\n        \"lidar_K\": lidar_K,\n        \"lidar_H\": lidar_H,\n        \"lidar_W\": lidar_W,\n        # Frame properties.\n        \"pano\": pano,\n        \"pano_mask\": pano_mask,\n        \"intensities\": intensities,\n        # Local coord properties.\n        \"local_points\": local_points,\n        \"local_point_intensities\": local_point_intensities,\n        # World coord properties.\n        \"points\": points,\n        \"point_intensities\": point_intensities,\n    }\n    if verbose:\n        for key, val in frame_dict.items():\n            if isinstance(val, np.ndarray):\n                print(f\"- {key}: {val.shape}\")\n            else:\n                print(f\"- {key}: {val}\")\n\n    return frame_dict\n"
  },
  {
    "path": "lidarnvs/plot_possion_grid_search.py",
    "content": "from pathlib import Path\nimport json\n\n\nimport matplotlib.pyplot as plt\nimport numpy as np\n\n\ndef main():\n    json_path = Path(\"poisson_grid_search.json\")\n    with open(json_path, \"r\") as f:\n        data = json.load(f)\n\n    min_chamfer = 1e10\n    min_datum = None\n    for datum in data:\n        if datum[\"chamfer\"] < min_chamfer:\n            min_chamfer = datum[\"chamfer\"]\n            min_datum = datum\n    print(f\"min_chamfer: {min_chamfer}\")\n    print(f\"min_datum: {min_datum}\")\n\n    # Fill confusion matrix.\n    col_vals = [8, 9, 10, 11, 12]\n    row_vals = [0.4, 0.3, 0.2]\n    conf_matrix = np.zeros((len(row_vals), len(col_vals)))\n    for datum in data:\n        min_density = datum[\"poisson_min_density\"]\n        poisson_depth = datum[\"poisson_depth\"]\n        if min_density not in row_vals or poisson_depth not in col_vals:\n            continue\n        row_idx = row_vals.index(min_density)\n        col_idx = col_vals.index(poisson_depth)\n        conf_matrix[row_idx, col_idx] = datum[\"chamfer\"]\n\n    # Print the confusion matrix using Matplotlib\n    fig, ax = plt.subplots(figsize=(7.5, 7.5))\n    ax.matshow(conf_matrix, cmap=plt.cm.Blues, alpha=0.3)\n    for i in range(conf_matrix.shape[0]):\n        for j in range(conf_matrix.shape[1]):\n            ax.text(\n                x=j,\n                y=i,\n                s=f\"{conf_matrix[i, j]:.2f}\",\n                va=\"center\",\n                ha=\"center\",\n                size=\"xx-large\",\n            )\n    ax.set_xticklabels([\"\"] + [str(v) for v in col_vals])\n    ax.set_yticklabels([\"\"] + [str(v) for v in row_vals])\n\n    plt.xlabel(\"Poisson Depth\", fontsize=18)\n    plt.ylabel(\"Min Density\", fontsize=18)\n    plt.show()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "lidarnvs/raydrop_dataset_poisson.py",
    "content": "import pickle\nfrom pathlib import Path\n\nimport torch\nfrom torch.utils.data import Dataset\n\n\nclass RaydropDataset(Dataset):\n    def __init__(self, data_dir, split):\n        self.data_dir = Path(data_dir)\n        self.split = split\n        if not self.data_dir.is_dir():\n            raise ValueError(f\"Directory {self.data_dir} does not exist.\")\n        if self.split not in [\"train\", \"test\"]:\n            raise ValueError(f\"Split {self.split} not supported.\")\n\n        pkl_path = self.data_dir / f\"{self.split}_data.pkl\"\n        if not pkl_path.is_file():\n            raise ValueError(f\"File {pkl_path} does not exist.\")\n        with open(pkl_path, \"rb\") as f:\n            self.raydrop_data = pickle.load(f)\n\n    def __len__(self):\n        return len(self.raydrop_data)\n\n    def __getitem__(self, idx):\n        return self.raydrop_data[idx]\n\n    @staticmethod\n    def collate_fn(batch):\n        \"\"\"\n        RaydropDataset is a dict-style dataset, where __getitem__(i) returns\n        a dictionary of tensors. Essentially, A dataloader will do:\n        ```python\n        for indices in batch_sampler:\n            yield collate_fn([dataset[i] for i in indices])\n        ```\n\n        Args:\n            batch: list of dicts.\n\n        Return:\n            images: (N, C, H, W) tensor, float32.\n            masks : (N, H, W) tensor, float32. 1 means valid ray.\n        \"\"\"\n        # First, call the default colllate_fn.\n        batch = torch.utils.data.default_collate(batch)\n\n        # (N, H, W, C)\n        images = torch.cat(\n            [\n                batch[\"hit_masks\"][..., None],\n                batch[\"hit_depths\"][..., None],\n                batch[\"hit_normals\"],\n                batch[\"hit_incidences\"][..., None],\n                batch[\"intensities\"][..., None],\n                batch[\"rays_d\"],\n            ],\n            dim=3,\n        )\n        # (N, C, H, W)\n        images = images.permute(0, 3, 1, 2)\n\n        # (N, H, W)\n        masks = batch[\"raydrop_masks\"]\n\n        return images, masks\n"
  },
  {
    "path": "lidarnvs/raydrop_train_pcgen.py",
    "content": "import os\nimport numpy as np\nimport imageio\nimport random\nimport torch\nimport torch.nn as nn\nimport matplotlib.pyplot as plt\nimport torch.nn.functional as F\nfrom pathlib import Path\nimport pickle\n\nl1loss = nn.L1Loss(reduction=\"mean\")\nmseloss = nn.MSELoss()\nimg2mse = lambda x, y: torch.mean((x - y) ** 2)\nto8b = (\n    lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)\n    if np.max(x) < 10\n    else (255.0 * np.clip(x / 81.0, 0, 1)).astype(np.uint8)\n)\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef setup_seed(seed):\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    random.seed(seed)\n    torch.backends.cudnn.deterministic = True\n\n\nsetup_seed(0)\n\n\ndef cal_psnr(im1, im2):\n    mse = (np.abs(im1 - im2) ** 2).mean()\n    psnr = -10 * np.log10(mse)  # max_value = 1\n    return psnr\n\n\nclass RayDrop(nn.Module):\n    def __init__(self, D=4, W=128, input_ch=3, output_ch=1):\n        \"\"\" \"\"\"\n        super(RayDrop, self).__init__()\n        self.D = D\n        self.W = W\n        self.input_ch = input_ch\n\n        self.linears = nn.ModuleList(\n            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) for i in range(D - 1)]\n        )\n        self.output_linear = nn.Linear(W, output_ch)\n\n        self.linears.apply(weights_init)\n        self.output_linear.apply(weights_init)\n\n    def forward(self, x):\n        h = x\n        for i, l in enumerate(self.linears):\n            h = self.linears[i](h)\n            h = F.relu(h)\n        output = self.output_linear(h)\n        return output\n\n\ndef weights_init(m):\n    if isinstance(m, nn.Linear):\n        nn.init.kaiming_normal_(m.weight.data)\n        if m.bias is not None:\n            nn.init.zeros_(m.bias.data)\n\n\ndef config_parser():\n    import configargparse\n\n    parser = configargparse.ArgumentParser()\n    parser.add_argument(\"--config\", is_config_file=True, help=\"config file path\")\n    parser.add_argument(\n        \"--expname\", type=str, default=\"raysdrop\", help=\"experiment name\"\n    )\n    parser.add_argument(\n        \"--basedir\", type=str, default=\"./log\", help=\"where to store ckpts and logs\"\n    )\n    parser.add_argument(\n        \"--datadir\",\n        type=str,\n        default=\"/data/usr/ziguo.tt/working/nerf/data/\",\n        help=\"input data directory\",\n    )\n\n    parser.add_argument(\n        \"--no_reload\", action=\"store_true\", help=\"do not reload weights from saved ckpt\"\n    )\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"kitti360\",\n        choices=[\"kitti360\", \"nerfmvl\"],\n        help=\"The dataset loader to use.\",\n    )\n\n    # training options\n    parser.add_argument(\"--netdepth\", type=int, default=8, help=\"layers in network\")\n    parser.add_argument(\"--netwidth\", type=int, default=256, help=\"channels per layer\")\n\n    parser.add_argument(\n        \"--N_rand\",\n        type=int,\n        default=2048,\n        help=\"batch size (number of random rays per gradient step)\",\n    )\n    parser.add_argument(\"--lrate\", type=float, default=5e-4, help=\"learning rate\")\n    parser.add_argument(\n        \"--lrate_decay\",\n        type=int,\n        default=500,\n        help=\"exponential learning rate decay (in 1000 steps)\",\n    )\n    parser.add_argument(\n        \"--no_batching\",\n        action=\"store_true\",\n        help=\"only take random rays from 1 image at a time\",\n    )\n    parser.add_argument(\n        \"--ft_path\",\n        type=str,\n        default=None,\n        help=\"specific weights npy file to reload for coarse network\",\n    )\n\n    # rendering options\n\n    parser.add_argument(\n        \"--multires\",\n        type=int,\n        default=10,\n        help=\"log2 of max freq for positional encoding (3D location)\",\n    )\n    parser.add_argument(\n        \"--multires_views\",\n        type=int,\n        default=4,\n        help=\"log2 of max freq for positional encoding (2D direction)\",\n    )\n    parser.add_argument(\n        \"--i_embed\",\n        type=int,\n        default=-1,\n        help=\"set 1 for hashed embedding, 0 for default positional encoding, 2 for spherical\",\n    )\n    parser.add_argument(\n        \"--i_embed_views\",\n        type=int,\n        default=-1,\n        help=\"set 1 for hashed embedding, 0 for default positional encoding, 2 for spherical\",\n    )\n\n    parser.add_argument(\n        \"--render_test\",\n        action=\"store_true\",\n        help=\"render the test set instead of render_poses path\",\n    )\n\n    # logging/saving options\n    parser.add_argument(\n        \"--i_print\",\n        type=int,\n        default=100,\n        help=\"frequency of console printout and metric loggin\",\n    )\n    parser.add_argument(\n        \"--i_weights\", type=int, default=10000, help=\"frequency of weight ckpt saving\"\n    )\n    parser.add_argument(\n        \"--i_save\", type=int, default=1000, help=\"frequency of rays saving\"\n    )\n\n    # lidar nerf\n    parser.add_argument(\"--N_iters\", type=int, default=500000)\n    parser.add_argument(\"--H\", type=int, default=66)\n    parser.add_argument(\"--W\", type=int, default=1030)\n\n    # lr\n    parser.add_argument(\"--cosLR\", action=\"store_true\")\n    parser.add_argument(\n        \"--coslrate\", type=float, default=5e-4, help=\"init learning rate\"\n    )\n    parser.add_argument(\n        \"--cosminlrate\", type=float, default=5e-5, help=\"min learning rate\"\n    )\n    parser.add_argument(\"--warmup_iters\", type=int, default=1000)\n\n    # loss type\n\n    parser.add_argument(\n        \"--rgb_loss_type\",\n        type=str,\n        default=\"img2mse\",\n        help=\"options: img2mse / mseloss / l1loss\",\n    )\n\n    return parser\n\n\ndef cosine_scheduler(\n    base_value, final_value, globel_step, warmup_iters=0, start_warmup_value=0\n):\n    warmup_schedule = np.array([])\n    if warmup_iters > 0:\n        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)\n\n    iters = np.arange(globel_step - warmup_iters)\n    schedule = final_value + 0.5 * (base_value - final_value) * (\n        1 + np.cos(np.pi * iters / len(iters))\n    )\n\n    schedule = np.concatenate((warmup_schedule, schedule))\n    assert len(schedule) == globel_step\n    return schedule\n\n\ndef get_embedder(multires, input_dims=3, i=0):\n    if i == -1:\n        return nn.Identity(), input_dims\n    elif i == 0:\n        embed_kwargs = {\n            \"include_input\": True,\n            \"input_dims\": input_dims,\n            \"max_freq_log2\": multires - 1,\n            \"num_freqs\": multires,\n            \"log_sampling\": True,\n            \"periodic_fns\": [torch.sin, torch.cos],\n        }\n\n        embedder_obj = Embedder(**embed_kwargs)\n        embed = lambda x, eo=embedder_obj: eo.embed(x)\n        out_dim = embedder_obj.out_dim\n    return embed, out_dim\n\n\nclass Embedder:\n    def __init__(self, **kwargs):\n        self.kwargs = kwargs\n        self.create_embedding_fn()\n\n    def create_embedding_fn(self):\n        embed_fns = []\n        d = self.kwargs[\"input_dims\"]\n        out_dim = 0\n        if self.kwargs[\"include_input\"]:\n            embed_fns.append(lambda x: x)\n            # embed_fns.append(lambda x : x/torch.norm(x, dim=-1, keepdim=True))\n            out_dim += d\n\n        max_freq = self.kwargs[\"max_freq_log2\"]\n        N_freqs = self.kwargs[\"num_freqs\"]\n\n        if self.kwargs[\"log_sampling\"]:\n            freq_bands = 2.0 ** torch.linspace(0.0, max_freq, steps=N_freqs)\n        else:\n            freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, steps=N_freqs)\n\n        for freq in freq_bands:\n            for p_fn in self.kwargs[\"periodic_fns\"]:\n                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))\n                out_dim += d\n\n        self.embed_fns = embed_fns\n        self.out_dim = out_dim\n\n    def embed(self, inputs):\n        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)\n\n\ndef run_network(inputs, model, embed_fn, embeddirs_fn):\n    \"\"\"Prepares inputs and applies network 'fn'.\"\"\"\n    ray_direction, depth, intensity = inputs[:, :3], inputs[:, 3], inputs[:, 4]\n    embedded_depth = embed_fn(depth.unsqueeze(1))\n    embedded_intensity = embed_fn(intensity.unsqueeze(1))\n    embedded_dirs = embeddirs_fn(ray_direction)\n    input = torch.cat((embedded_dirs, embedded_depth, embedded_intensity), 1)\n    outputs = model(input)\n    return outputs\n\n\ndef load_pkl_data(data_dir, split):\n    if not data_dir.is_dir():\n        raise ValueError(f\"Directory {data_dir} does not exist.\")\n    if split not in [\"train\", \"test\"]:\n        raise ValueError(f\"Split {split} not supported.\")\n    pkl_path = data_dir / f\"{split}_data.pkl\"\n    if not pkl_path.is_file():\n        raise ValueError(f\"File {pkl_path} does not exist.\")\n    with open(pkl_path, \"rb\") as f:\n        raydrop_data = pickle.load(f)\n    return raydrop_data\n\n\ndef train():\n    parser = config_parser()\n    args = parser.parse_args()\n    cosLR = args.cosLR\n    loss_dict = {\"img2mse\": img2mse, \"mseloss\": mseloss, \"l1loss\": l1loss}\n\n    H = args.H\n    W = args.W\n\n    # load dataset\n    data_dir = Path(args.datadir)\n\n    (directions, panos, intensities, raydrop_masks) = load_pkl_data(data_dir, \"train\")\n    # print(\n    #     np.array(directions).shape,\n    #     np.array(panos).shape,\n    #     np.array(intensities).shape,\n    #     np.array(raydrop_masks).shape)\n    rays_all = np.concatenate(\n        (\n            np.array(directions).reshape(-1, 3),\n            np.array(panos).reshape(-1, 1),\n            np.array(intensities).reshape(-1, 1),\n        ),\n        -1,\n    )\n    raydrop_masks = np.array(raydrop_masks)\n    rays_all = rays_all[raydrop_masks.reshape(-1) > -1]\n    raydrop_masks = np.where(raydrop_masks[raydrop_masks > -1] == 0.0, 0.0, 1.0)\n    rays_all = np.concatenate((rays_all, raydrop_masks.reshape(-1, 1)), -1)\n\n    (directions, panos, intensities, raydrop_masks) = load_pkl_data(data_dir, \"test\")\n    raydrop_val_list = []\n    for direction, pano, intensity, raydrop_mask in zip(\n        directions, panos, intensities, raydrop_masks\n    ):\n        raydrop_val_list.append(\n            np.concatenate(\n                (\n                    np.array(direction).reshape(-1, 3),\n                    np.array(pano).reshape(-1, 1),\n                    np.array(intensity).reshape(-1, 1),\n                    np.array(raydrop_mask).reshape(-1, 1),\n                ),\n                -1,\n            )\n        )\n    rays_val1 = raydrop_val_list[0]\n    raydrop_masks = np.array(raydrop_masks)\n    mask_val1 = np.where(raydrop_masks[0] > -1, 1, 0)\n    ray_drop_val1 = raydrop_masks[0].reshape(H, W)\n\n    # Create log dir and copy the config file\n    basedir = args.basedir\n    expname = args.expname\n    os.makedirs(os.path.join(basedir, expname), exist_ok=True)\n    f = os.path.join(basedir, expname, \"args.txt\")\n    with open(f, \"w\") as file:\n        for arg in sorted(vars(args)):\n            attr = getattr(args, arg)\n            file.write(\"{} = {}\\n\".format(arg, attr))\n    if args.config is not None:\n        f = os.path.join(basedir, expname, \"config.txt\")\n        with open(f, \"w\") as file:\n            file.write(open(args.config, \"r\").read())\n\n    # network\n    embed_fn, input_ch = get_embedder(args.multires, input_dims=1, i=args.i_embed)\n    embeddirs_fn, input_ch_views = get_embedder(\n        args.multires_views, input_dims=3, i=args.i_embed_views\n    )\n    total_input_ch = input_ch * 2 + input_ch_views\n    # model\n    model = RayDrop(D=args.netdepth, W=args.netwidth, input_ch=total_input_ch).to(\n        device\n    )\n    grad_vars = list(model.parameters())\n    # optimizer\n    optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))\n\n    start = 0\n    if args.ft_path is not None and args.ft_path != \"None\":\n        ckpts = [args.ft_path]\n    else:\n        ckpts = [\n            os.path.join(basedir, expname, f)\n            for f in sorted(os.listdir(os.path.join(basedir, expname)))\n            if f.endswith(\".tar\")\n        ]\n\n    print(\"Found ckpts\", ckpts)\n    if len(ckpts) > 0 and not args.no_reload:\n        ckpt_path = ckpts[-1]\n        print(\"Reloading from\", ckpt_path)\n        ckpt = torch.load(ckpt_path)\n\n        start = ckpt[\"global_step\"]\n        optimizer.load_state_dict(ckpt[\"optimizer_state_dict\"])\n        # Load model\n        model.load_state_dict(ckpt[\"network_fn_state_dict\"])\n\n    global_step = start\n\n    # Prepare raybatch tensor if batching random rays\n    N_rand = args.N_rand\n    use_batching = not args.no_batching\n    if use_batching:\n        print(\"shuffle rays\")\n        np.random.shuffle(rays_all)\n        rays_all = torch.Tensor(rays_all).to(device)\n        rays_val1 = torch.Tensor(rays_val1).to(device)\n\n    if args.render_test:\n        print(\"RENDER ONLY\")\n        for idx, rays_val, raydrop_mask in zip(\n            range(len(raydrop_val_list)), raydrop_val_list, raydrop_masks\n        ):\n            rays_val = torch.Tensor(rays_val).to(device)\n            with torch.no_grad():\n                predict_drop_val = run_network(rays_val, model, embed_fn, embeddirs_fn)\n            imgbase = os.path.join(basedir, expname, str(idx))\n            mask_bbox = np.where(raydrop_mask > -1, 1, 0)\n            predict_drop_val = (\n                np.where(predict_drop_val.cpu().numpy() > 0.5, 1.0, 0.0).reshape(H, W)\n                * mask_bbox\n            )\n            np.save(imgbase + \"_pred_drop.npy\", predict_drop_val)\n            imageio.imsave(imgbase + \"_pred_drop.png\", predict_drop_val.reshape(H, W))\n\n            ray_drop_gt = np.where(raydrop_mask > 0, 1, 0)\n            imageio.imsave(imgbase + \"_gt_drop.png\", ray_drop_gt.reshape(H, W))\n\n        return\n\n    N_iters = args.N_iters + 1\n    print(\"Begin\")\n\n    loss_log = []\n    val_psnr = []\n    start = start + 1\n    i_batch = 0\n\n    lr_schedule = cosine_scheduler(\n        base_value=args.coslrate,\n        final_value=args.cosminlrate,\n        globel_step=N_iters - 1,\n        warmup_iters=args.warmup_iters,\n    )\n\n    for i in range(start, N_iters):\n        # Sample random ray batch\n        if use_batching:\n            # Random over all images\n            batch = rays_all[i_batch : i_batch + N_rand]  # [B, 2+1, 3*?]\n            # ray_direction, depth, intensity, target_drop = batch[:, :3], batch[:, 3], batch[:, 4], batch[:, 5]\n            inputs, target_drop = batch[:, :5], batch[:, 5]\n            i_batch += N_rand\n            if i_batch >= rays_all.shape[0]:\n                print(\"Shuffle data after an epoch!\")\n                rand_idx = torch.randperm(rays_all.shape[0])\n                rays_all = rays_all[rand_idx]\n                i_batch = 0\n\n        #####  Core optimization loop  #####\n        predict_drop = run_network(inputs, model, embed_fn, embeddirs_fn)\n        optimizer.zero_grad()\n\n        rgb_loss = loss_dict[args.rgb_loss_type]\n        loss = rgb_loss(predict_drop, target_drop.unsqueeze(1))\n        # loss = KL_loss_fun(predict_drop, target_drop.unsqueeze(1))\n\n        loss.backward()\n        optimizer.step()\n\n        # NOTE: IMPORTANT!\n        ##   update learning rate   ###\n        decay_rate = 0.1\n        decay_steps = args.lrate_decay * 1000\n        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))\n        for param_group in optimizer.param_groups:\n            if cosLR:\n                param_group[\"lr\"] = lr_schedule[global_step]\n            else:\n                param_group[\"lr\"] = new_lrate\n\n        # Rest is logging\n        if i % args.i_weights == 0:\n            path = os.path.join(basedir, expname, \"{:06d}.tar\".format(i))\n            ckpt = {\n                \"global_step\": global_step,\n                \"network_fn_state_dict\": model.state_dict(),\n                \"optimizer_state_dict\": optimizer.state_dict(),\n            }\n            torch.save(ckpt, path)\n            print(\"Saved checkpoints at\", path)\n\n        if i % args.i_save == 0 and i > 0:\n            # Turn in testing mode\n            with torch.no_grad():\n                predict_drop_val = run_network(rays_val1, model, embed_fn, embeddirs_fn)\n            imgbase = os.path.join(basedir, expname, \"{:06d}_\".format(i))\n            predict_drop_val = (\n                np.where(predict_drop_val.cpu().numpy() > 0.5, 1.0, 0.0).reshape(H, W)\n                * mask_val1\n            )\n            psnr = cal_psnr(predict_drop_val.reshape(H, W), ray_drop_val1)\n            print(psnr)\n            val_psnr.append(psnr)\n            loss_save = np.array(val_psnr)\n            plt.plot(loss_save)\n            plt.savefig(os.path.join(basedir, expname, \"val_psnr.png\"))\n            plt.close()\n            imageio.imsave(imgbase + \"val_drop.png\", predict_drop_val.reshape(H, W))\n\n        loss_log.append(loss.item())\n        if i % args.i_print == 0:\n            loss_save = np.array(loss_log)\n            plt.plot(loss_save)\n            plt.savefig(os.path.join(basedir, expname, \"loss_curve.png\"))\n            plt.close()\n\n            loss_print = [loss.item()]\n\n            print(f\"[TRAIN] Iter: {i} Loss: {loss_print} \")\n\n        global_step += 1\n    loss_log = np.array(loss_log)\n    np.save(os.path.join(basedir, expname, \"loss_log.npy\"), loss_log)\n    val_psnr = np.array(val_psnr)\n    np.save(os.path.join(basedir, expname, \"val_psnr.npy\"), val_psnr)\n\n\nif __name__ == \"__main__\":\n    torch.set_default_tensor_type(\"torch.cuda.FloatTensor\")\n\n    train()\n"
  },
  {
    "path": "lidarnvs/raydrop_train_poisson.py",
    "content": "import argparse\nimport logging\nfrom pathlib import Path\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import optim\nfrom torch.utils.data import DataLoader\nfrom tqdm import tqdm\n\nimport wandb\nfrom lidarnvs.raydrop_dataset_poisson import RaydropDataset\nfrom lidarnvs.unet import UNet, dice_coeff, dice_loss, multiclass_dice_coeff\n\n\n@torch.inference_mode()\ndef evaluate(net, dataloader, device, amp):\n    net.eval()\n    num_val_batches = len(dataloader)\n    dice_score = 0\n\n    # Iterate over the test set\n    with torch.autocast(device.type if device.type != \"mps\" else \"cpu\", enabled=amp):\n        for batch in tqdm(\n            dataloader,\n            total=num_val_batches,\n            desc=\"Test round\",\n            unit=\"batch\",\n            leave=False,\n        ):\n            images, true_masks = batch\n\n            # Move images and labels to correct device and type\n            images = images.to(\n                device=device, dtype=torch.float32, memory_format=torch.channels_last\n            )\n            true_masks = true_masks.to(device=device, dtype=torch.long)\n\n            # Predict the mask\n            mask_pred = net(images)\n            true_masks = true_masks.reshape(mask_pred.shape)\n\n            if net.n_classes == 1:\n                assert (\n                    true_masks.min() >= 0 and true_masks.max() <= 1\n                ), \"True mask indices should be in [0, 1]\"\n                mask_pred = (F.sigmoid(mask_pred) > 0.5).float()\n                # Compute the dice score\n                dice_score += dice_coeff(\n                    mask_pred, true_masks, reduce_batch_first=False\n                )\n            else:\n                assert (\n                    true_masks.min() >= 0 and true_masks.max() < net.n_classes\n                ), \"True mask indices should be in [0, n_classes[\"\n                # Convert to one-hot format\n                true_masks = (\n                    F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float()\n                )\n                mask_pred = (\n                    F.one_hot(mask_pred.argmax(dim=1), net.n_classes)\n                    .permute(0, 3, 1, 2)\n                    .float()\n                )\n                # Compute the Dice score, ignoring background\n                dice_score += multiclass_dice_coeff(\n                    mask_pred[:, 1:], true_masks[:, 1:], reduce_batch_first=False\n                )\n\n    net.train()\n    return dice_score / max(num_val_batches, 1)\n\n\ndef train_model(\n    model,\n    data_dir,\n    ckpt_dir,\n    device,\n    epochs: int = 5,\n    batch_size: int = 1,\n    learning_rate: float = 1e-5,\n    save_checkpoint: bool = True,\n    img_scale: float = 0.5,\n    amp: bool = False,\n    weight_decay: float = 1e-8,\n    momentum: float = 0.999,\n    gradient_clipping: float = 1.0,\n):\n    data_dir = Path(data_dir)\n    ckpt_dir = Path(ckpt_dir)\n\n    # Create dataset\n    train_dataset = RaydropDataset(data_dir=data_dir, split=\"train\")\n    test_dataset = RaydropDataset(data_dir=data_dir, split=\"test\")\n    n_train = len(train_dataset)\n    n_test = len(test_dataset)\n\n    # Create data loaders\n    train_loader = DataLoader(\n        train_dataset,\n        batch_size=batch_size,\n        collate_fn=RaydropDataset.collate_fn,\n        shuffle=True,\n    )\n    test_loader = DataLoader(\n        test_dataset,\n        batch_size=batch_size,\n        collate_fn=RaydropDataset.collate_fn,\n        shuffle=True,\n    )\n\n    # Initialize logging\n    experiment = wandb.init(project=\"U-Net\", resume=\"allow\", anonymous=\"must\")\n    experiment.config.update(\n        {\n            \"epochs\": epochs,\n            \"batch_size\": batch_size,\n            \"learning_rate\": learning_rate,\n            \"save_checkpoint\": save_checkpoint,\n            \"img_scale\": img_scale,\n            \"amp\": amp,\n        }\n    )\n\n    log_str = (\n        f\"Starting training:\\n\"\n        f\"Epochs:          {epochs}\\n\"\n        f\"Batch size:      {batch_size}\\n\"\n        f\"Learning rate:   {learning_rate}\\n\"\n        f\"Training size:   {n_train}\\n\"\n        f\"Validation size: {n_test}\\n\"\n        f\"Checkpoints:     {save_checkpoint}\\n\"\n        f\"Device:          {device.type}\\n\"\n        f\"Images scaling:  {img_scale}\\n\"\n        f\"Mixed Precision: {amp}\\n\"\n    )\n    logging.info(log_str)\n\n    # Set up optimizer, loss, lr_scheduler, loss scaling.\n    optimizer = optim.RMSprop(\n        model.parameters(),\n        lr=learning_rate,\n        weight_decay=weight_decay,\n        momentum=momentum,\n        foreach=True,\n    )\n    scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n        optimizer, \"max\", patience=5\n    )  # goal: maximize Dice score\n    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)\n    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()\n    global_step = 0\n\n    # 5. Begin training\n    for epoch in range(1, epochs + 1):\n        model.train()\n        epoch_loss = 0\n        with tqdm(total=n_train, desc=f\"Epoch {epoch}/{epochs}\", unit=\"img\") as pbar:\n            for batch in train_loader:\n                images, true_masks = batch\n\n                if images.shape[1] != model.n_channels:\n                    raise ValueError(\n                        f\"Input channel mismatch: \"\n                        f\"{images.shape[1]} vs {model.n_channels}\"\n                    )\n                images = images.to(\n                    device=device,\n                    dtype=torch.float32,\n                    memory_format=torch.channels_last,\n                )\n                true_masks = true_masks.to(device=device, dtype=torch.long)\n\n                with torch.autocast(\n                    device.type if device.type != \"mps\" else \"cpu\", enabled=amp\n                ):\n                    masks_pred = model(images)\n                    if model.n_classes == 1:\n                        loss = criterion(masks_pred.squeeze(1), true_masks.float())\n                        loss += dice_loss(\n                            F.sigmoid(masks_pred.squeeze(1)),\n                            true_masks.float(),\n                            multiclass=False,\n                        )\n                    else:\n                        loss = criterion(masks_pred, true_masks)\n                        loss += dice_loss(\n                            F.softmax(masks_pred, dim=1).float(),\n                            F.one_hot(true_masks, model.n_classes)\n                            .permute(0, 3, 1, 2)\n                            .float(),\n                            multiclass=True,\n                        )\n\n                optimizer.zero_grad(set_to_none=True)\n                grad_scaler.scale(loss).backward()\n                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)\n                grad_scaler.step(optimizer)\n                grad_scaler.update()\n\n                pbar.update(images.shape[0])\n                global_step += 1\n                epoch_loss += loss.item()\n                experiment.log(\n                    {\"train loss\": loss.item(), \"step\": global_step, \"epoch\": epoch}\n                )\n                pbar.set_postfix(**{\"loss (batch)\": loss.item()})\n\n                # Evaluation round\n                division_step = n_train // (5 * batch_size)\n                if division_step > 0:\n                    if global_step % division_step == 0:\n                        histograms = {}\n                        for tag, value in model.named_parameters():\n                            tag = tag.replace(\"/\", \".\")\n                            if not (torch.isinf(value) | torch.isnan(value)).any():\n                                histograms[\"Weights/\" + tag] = wandb.Histogram(\n                                    value.data.cpu()\n                                )\n                            if not (\n                                torch.isinf(value.grad) | torch.isnan(value.grad)\n                            ).any():\n                                histograms[\"Gradients/\" + tag] = wandb.Histogram(\n                                    value.grad.data.cpu()\n                                )\n\n                        val_score = evaluate(model, test_loader, device, amp)\n                        scheduler.step(val_score)\n\n                        logging.info(\"Validation Dice score: {}\".format(val_score))\n                        try:\n                            experiment.log(\n                                {\n                                    \"learning rate\": optimizer.param_groups[0][\"lr\"],\n                                    \"validation Dice\": val_score,\n                                    \"images\": wandb.Image(images[0].cpu()),\n                                    \"masks\": {\n                                        \"true\": wandb.Image(\n                                            true_masks[0].float().cpu()\n                                        ),\n                                        \"pred\": wandb.Image(\n                                            masks_pred.argmax(dim=1)[0].float().cpu()\n                                        ),\n                                    },\n                                    \"step\": global_step,\n                                    \"epoch\": epoch,\n                                    **histograms,\n                                }\n                            )\n                        except:\n                            pass\n\n        if save_checkpoint:\n            checkpoint_path = ckpt_dir / f\"checkpoint_{epoch:03}.pth\"\n            checkpoint_path.parent.mkdir(parents=True, exist_ok=True)\n            state_dict = model.state_dict()\n            torch.save(state_dict, checkpoint_path)\n            logging.info(f\"Checkpoint {epoch} saved!\")\n\n\ndef get_args():\n    parser = argparse.ArgumentParser(\n        description=\"Train the UNet on images and target masks\"\n    )\n    parser.add_argument(\n        \"--data_dir\", type=str, default=\"N/A\", help=\"Path to the raydrop dataset.\"\n    )\n    parser.add_argument(\n        \"--ckpt_dir\", type=str, default=\"N/A\", help=\"Path to the checkpoint directory.\"\n    )\n    parser.add_argument(\"--epochs\", \"-e\", type=int, default=10, help=\"Number of epochs\")\n    parser.add_argument(\n        \"--batch-size\", \"-b\", dest=\"batch_size\", type=int, default=2, help=\"Batch size\"\n    )\n    parser.add_argument(\n        \"--learning-rate\",\n        \"-l\",\n        type=float,\n        default=1e-5,\n        help=\"Learning rate\",\n        dest=\"lr\",\n    )\n    parser.add_argument(\n        \"--load\", \"-f\", type=str, default=False, help=\"Load model from a .pth file\"\n    )\n    parser.add_argument(\n        \"--scale\",\n        \"-s\",\n        type=float,\n        default=0.5,\n        help=\"Downscaling factor of the images\",\n    )\n    parser.add_argument(\n        \"--amp\", action=\"store_true\", default=False, help=\"Use mixed precision\"\n    )\n    parser.add_argument(\n        \"--bilinear\", action=\"store_true\", default=False, help=\"Use bilinear upsampling\"\n    )\n    parser.add_argument(\n        \"--classes\", \"-c\", type=int, default=1, help=\"Number of classes\"\n    )\n\n    return parser.parse_args()\n\n\ndef main():\n    args = get_args()\n\n    logging.basicConfig(level=logging.INFO, format=\"%(levelname)s: %(message)s\")\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    logging.info(f\"Using device {device}\")\n\n    # - n_channels: 10\n    #   - [0]   : hit_masks\n    #   - [1]   : hit_depths\n    #   - [2:5] : hit_normals (world), TODO: change to local coord\n    #   - [5]   : hit_incidences in cosine\n    #   - [6]   : intensities\n    #   - [7:10]: rays_d, TODO: change to local coord\n    # - n_classes: 1\n    #   - number of probabilities you want to get per pixel\n    #   - raydrop_masks has only 1 channel\n    model = UNet(n_channels=10, n_classes=args.classes, bilinear=args.bilinear)\n    model = model.to(memory_format=torch.channels_last)\n\n    logging.info(\n        f\"Network:\\n\"\n        f\"\\t{model.n_channels} input channels\\n\"\n        f\"\\t{model.n_classes} output channels (classes)\\n\"\n        f\"\\t{'Bilinear' if model.bilinear else 'Transposed conv'} upscaling\"\n    )\n\n    if args.load:\n        state_dict = torch.load(args.load, map_location=device)\n        model.load_state_dict(state_dict)\n        logging.info(f\"Model loaded from {args.load}\")\n\n    model.to(device=device)\n    train_model(\n        model=model,\n        data_dir=args.data_dir,\n        ckpt_dir=args.ckpt_dir,\n        epochs=args.epochs,\n        batch_size=args.batch_size,\n        learning_rate=args.lr,\n        device=device,\n        img_scale=args.scale,\n        amp=args.amp,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "lidarnvs/readme.md",
    "content": "# Lidar Novel View Synthesis Baselines\n\n![baseline_render](../assets/baseline-render.png)\n\n## LidarSim\n\n```bash\n# Generate raydrop dataset.\npython lidarnvs/run.py --dataset kitti360 --sequence_id \"1538\" --enable_collect_raydrop_dataset\npython lidarnvs/run.py --dataset kitti360 --sequence_id \"1728\" --enable_collect_raydrop_dataset\npython lidarnvs/run.py --dataset kitti360 --sequence_id \"1908\" --enable_collect_raydrop_dataset\npython lidarnvs/run.py --dataset kitti360 --sequence_id \"3353\" --enable_collect_raydrop_dataset\n\n# Train the raydrop model.\npython raydrop_train.py --data_dir data/raydrop/kitti360_1538 --ckpt_dir log/raydrop/kitti360_1538\npython raydrop_train.py --data_dir data/raydrop/kitti360_1728 --ckpt_dir log/raydrop/kitti360_1728\npython raydrop_train.py --data_dir data/raydrop/kitti360_1908 --ckpt_dir log/raydrop/kitti360_1908\npython raydrop_train.py --data_dir data/raydrop/kitti360_3353 --ckpt_dir log/raydrop/kitti360_3353\n\n# Run lidarnvs again, now with raydrop model.\npython lidarnvs/run.py --dataset kitti360 --sequence_id \"1538\"\npython lidarnvs/run.py --dataset kitti360 --sequence_id \"1728\"\npython lidarnvs/run.py --dataset kitti360 --sequence_id \"1908\"\npython lidarnvs/run.py --dataset kitti360 --sequence_id \"3353\"\n\n# lidarnvs on NeRF-MVL\npython lidarnvs/run.py --dataset nerf_mvl --sequence_id \"bollard\" \npython lidarnvs/run.py --dataset nerf_mvl --sequence_id \"car\"\npython lidarnvs/run.py --dataset nerf_mvl --sequence_id \"pedestrian\"\npython lidarnvs/run.py --dataset nerf_mvl --sequence_id \"pier\"\npython lidarnvs/run.py --dataset nerf_mvl --sequence_id \"plant\"\npython lidarnvs/run.py --dataset nerf_mvl --sequence_id \"tire\"\npython lidarnvs/run.py --dataset nerf_mvl --sequence_id \"traffic_cone\"\npython lidarnvs/run.py --dataset nerf_mvl --sequence_id \"warning_sign\"\npython lidarnvs/run.py --dataset nerf_mvl --sequence_id \"water_safety_barrier\"\n```\n\n## PCGen\n### KITTI-360\n```bash\npython lidarnvs/run.py --dataset kitti360 --sequence_id \"1908\" --method \"pcgen\" --enable_collect_raydrop_dataset\npython lidarnvs/raydrop_train_pcgen.py --config lidarnvs/configs/pcgen_kitti360_raydrop.txt\npython lidarnvs/run.py --dataset kitti360 --sequence_id \"1908\" --method \"pcgen\" --ckpt_path pcgen_raydrop_log/kitti360seq1908/raysdrop/010000.tar\n```\n\n### LiDAR-MVL\n```bash\npython lidarnvs/run.py --dataset nerf_mvl --sequence_id \"car\" --method \"pcgen\" --enable_collect_raydrop_dataset\npython lidarnvs/raydrop_train_pcgen.py --config lidarnvs/configs/pcgen_nerfmvl_raydrop.txt\npython lidarnvs/run.py --dataset nerf_mvl --sequence_id \"car\" --method \"pcgen\"  --ckpt_path pcgen_raydrop_log/car/010000.tar\n```\n"
  },
  {
    "path": "lidarnvs/run.py",
    "content": "from pathlib import Path\n\nimport numpy as np\n\nfrom lidarnvs.lidarnvs_pcgen import LidarNVSPCGen, generate_raydrop_data_pcgen\nfrom lidarnvs.lidarnvs_poisson import LidarNVSPoisson\nfrom lidarnvs.lidarnvs_nksr import LidarNVSNksr\nfrom lidarnvs.lidarnvs_meshing import generate_raydrop_data_meshing\nfrom lidarnvs.loader import extract_dataset_frame\nfrom lidarnvs.eval import eval_points_and_pano\nfrom tqdm import tqdm\nimport pickle\nimport argparse\nfrom lidarnerf.dataset.kitti360_dataset import KITTI360Dataset\nfrom lidarnerf.dataset.nerfmvl_dataset import NeRFMVLDataset\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"kitti360\",\n        choices=[\"kitti360\", \"nerf_mvl\"],\n        help=\"The dataset loader to use.\",\n    )\n    parser.add_argument(\n        \"--method\",\n        type=str,\n        default=\"poisson\",\n        choices=[\"poisson\", \"nksr\", \"pcgen\"],\n        help=\"method for lidarnvs\",\n    )\n    parser.add_argument(\n        \"--raycasting\",\n        type=str,\n        default=\"cp\",\n        choices=[\"cp\", \"fpa\"],\n        help=\"raycasting mehtod\",\n    )\n    # dataset\n    parser.add_argument(\"--path\", type=str, default=\"data/kitti360\")\n    parser.add_argument(\n        \"--sequence_id\",\n        type=str,\n        default=\"1908\",\n        help=\"The sequence id within the selected dataset to use.\",\n    )\n    parser.add_argument(\n        \"--num_rays_lidar\",\n        type=int,\n        default=4096,\n        help=\"num rays sampled per image for each training step\",\n    )\n    parser.add_argument(\n        \"--offset\", type=float, nargs=\"*\", default=[0, 0, 0], help=\"offset of location\"\n    )\n    parser.add_argument(\n        \"--enable_collect_raydrop_dataset\",\n        action=\"store_true\",\n        help=\"Whether to collect raydrop dataset. If not enabled, inference \"\n        \"mode will be used\",\n    )\n    parser.add_argument(\n        \"--ckpt_path\",\n        type=str,\n        default=\"\",\n        help=\"The ckpt of raydrop network.\",\n    )\n    parser.add_argument(\n        \"--poisson_depth\",\n        type=int,\n        default=11,\n        help=\"Depth of tree for Poisson recon.\",\n    )\n    parser.add_argument(\n        \"--poisson_min_density\",\n        type=float,\n        default=0.3,\n        help=\"Minimum density to filter points after Poisson recon.\",\n    )\n    args = parser.parse_args()\n\n    # Check sequence id.\n    kitti360_sequence_ids = [\n        \"1538\",\n        \"1728\",\n        \"1908\",\n        \"3353\",\n    ]\n    nerf_mvl_sequence_ids = [\n        \"bollard\",\n        \"car\",\n        \"pedestrian\",\n        \"pier\",\n        \"plant\",\n        \"tire\",\n        \"traffic_cone\",\n        \"warning_sign\",\n        \"water_safety_barrier\",\n    ]\n    if args.dataset == \"kitti360\":\n        if args.sequence_id not in kitti360_sequence_ids:\n            raise ValueError(\n                f\"Unknown sequence id {args.sequence_id} for {args.dataset}\"\n            )\n    elif args.dataset == \"nerf_mvl\":\n        if args.sequence_id not in nerf_mvl_sequence_ids:\n            raise ValueError(\n                f\"Unknown sequence id {args.sequence_id} for {args.dataset}\"\n            )\n    else:\n        raise ValueError(f\"Unknown dataset: {args.dataset}\")\n\n    print(\"[Config]===============================================\")\n    print(f\"dataset             : {args.dataset}\")\n    print(f\"sequence_id         : {args.sequence_id}\")\n    print(f\"poisson_depth       : {args.poisson_depth}\")\n    print(f\"poisson_min_density : {args.poisson_min_density}\")\n    print(f\"dataset_collect_mode: {args.enable_collect_raydrop_dataset}\")\n    print(\"=======================================================\")\n\n    # Init train and test datasets.\n    if args.dataset == \"kitti360\":\n        train_dataset = KITTI360Dataset(\n            split=\"train\",\n            root_path=args.path,\n            offset=args.offset,\n            num_rays_lidar=args.num_rays_lidar,\n            sequence_id=args.sequence_id,\n        )\n\n        test_dataset = KITTI360Dataset(\n            split=\"train\",\n            root_path=args.path,\n            offset=args.offset,\n            num_rays_lidar=args.num_rays_lidar,\n            sequence_id=args.sequence_id,\n        )\n    elif args.dataset == \"nerf_mvl\":\n        train_dataset = NeRFMVLDataset(\n            split=\"train\",\n            root_path=args.path,\n            offset=args.offset,\n            num_rays_lidar=args.num_rays_lidar,\n            sequence_id=args.sequence_id,\n        )\n        test_dataset = NeRFMVLDataset(\n            split=\"test\",\n            root_path=args.path,\n            offset=args.offset,\n            num_rays_lidar=args.num_rays_lidar,\n            sequence_id=args.sequence_id,\n        )\n    else:\n        raise ValueError(f\"Unknown dataset: {args.dataset}\")\n\n    # Train.\n    if args.enable_collect_raydrop_dataset:\n        ckpt_path = None\n    else:\n        ckpt_path = Path(args.ckpt_path)\n        if not ckpt_path.is_file():\n            raise ValueError(f\"ckpt_path ({ckpt_path}) does not exist.\")\n\n    if args.method == \"poisson\":\n        nvs = LidarNVSPoisson(\n            poisson_depth=args.poisson_depth,\n            poisson_min_density=args.poisson_min_density,\n            intensity_interpolate_k=9,\n            ckpt_path=ckpt_path,\n        )\n    elif args.method == \"nksr\":\n        nvs = LidarNVSNksr(ckpt_path=ckpt_path)\n    elif args.method == \"pcgen\":\n        nvs = LidarNVSPCGen(\n            raycasting=args.raycasting,\n            ckpt_path=ckpt_path,\n        )\n    else:\n        raise ValueError(f\"Unknown method: {args.method}\")\n\n    nvs.fit(train_dataset)\n    exit(0)\n\n    # Eval test frames.\n    all_metrics = []\n    for frame_idx in tqdm(range(len(test_dataset)), desc=\"Eval test frames\"):\n        gt_frame = extract_dataset_frame(test_dataset, frame_idx=frame_idx)\n        if args.enable_collect_raydrop_dataset:\n            inference_func = nvs.predict_frame\n        else:\n            inference_func = nvs.predict_frame_with_raydrop\n        pd_frame = inference_func(\n            lidar_K=gt_frame[\"lidar_K\"],\n            lidar_pose=gt_frame[\"lidar_pose\"],\n            lidar_H=gt_frame[\"lidar_H\"],\n            lidar_W=gt_frame[\"lidar_W\"],\n        )\n        if args.dataset == \"nerf_mvl\":\n            # Load values to be updated.\n            gt_intensities = gt_frame[\"intensities\"]\n            pd_intensities = pd_frame[\"intensities\"]\n            gt_pano = gt_frame[\"pano\"]\n            pd_pano = pd_frame[\"pano\"]\n\n            # Load mask.\n            pano_mask = gt_frame[\"pano_mask\"]\n            nonzero_idx = np.array(np.nonzero(pano_mask))\n            new_h = max(nonzero_idx[0]) - min(nonzero_idx[0]) + 1\n            new_w = max(nonzero_idx[1]) - min(nonzero_idx[1]) + 1\n            gt_intensities = gt_intensities[pano_mask].reshape(new_h, new_w)\n            pd_intensities = pd_intensities[pano_mask].reshape(new_h, new_w)\n            gt_intensities = gt_intensities * 255\n            pd_intensities = pd_intensities * 255\n            gt_pano = gt_pano[pano_mask].reshape(new_h, new_w)\n            pd_pano = pd_pano[pano_mask].reshape(new_h, new_w)\n\n            metrics = eval_points_and_pano(\n                gt_local_points=gt_frame[\"local_points\"],\n                pd_local_points=pd_frame[\"local_points\"],\n                gt_intensities=gt_intensities,\n                pd_intensities=pd_intensities,\n                gt_pano=gt_pano,\n                pd_pano=pd_pano,\n            )\n        else:\n            metrics = eval_points_and_pano(\n                gt_local_points=gt_frame[\"local_points\"],\n                pd_local_points=pd_frame[\"local_points\"],\n                gt_intensities=gt_frame[\"intensities\"],\n                pd_intensities=pd_frame[\"intensities\"],\n                gt_pano=gt_frame[\"pano\"],\n                pd_pano=pd_frame[\"pano\"],\n            )\n        all_metrics.append(metrics)\n\n    # Compute mean metrics.\n    mean_metrics = {}\n    for key in all_metrics[0].keys():\n        mean_metrics[key] = np.mean([m[key] for m in all_metrics])\n    print(\"Mean metrics:\")\n    for key in sorted(mean_metrics.keys()):\n        print(f\"- {key}: {mean_metrics[key]:.4f}\")\n\n    # # Visualize a single test frame.\n    # gt_pcd = o3d.geometry.PointCloud()\n    # gt_pcd.points = o3d.utility.Vector3dVector(gt_frame[\"points\"])\n    # gt_pcd.colors = o3d.utility.Vector3dVector(\n    #     ct.colormap.query(gt_frame[\"point_intensities\"]))\n\n    # pd_pcd = o3d.geometry.PointCloud()\n    # pd_pcd.points = o3d.utility.Vector3dVector(train_frame_nvs[\"points\"])\n    # pd_pcd.colors = o3d.utility.Vector3dVector(\n    #     ct.colormap.query(train_frame_nvs[\"point_intensities\"]))\n\n    # o3d.visualization.draw_geometries([gt_pcd])\n    # o3d.visualization.draw_geometries([pd_pcd])\n\n    # Concat all in to big tensors.\n    if args.enable_collect_raydrop_dataset:\n        if args.method == \"poisson\" and args.dataset != \"nerf_mvl\":\n            raydrop_train_data = generate_raydrop_data_meshing(train_dataset, nvs)\n            raydrop_test_data = generate_raydrop_data_meshing(test_dataset, nvs)\n        elif args.method == \"pcgen\":\n            raydrop_train_data = generate_raydrop_data_pcgen(\n                train_dataset, nvs, rm_pano_mask=False\n            )\n            raydrop_test_data = generate_raydrop_data_pcgen(\n                test_dataset, nvs, rm_pano_mask=False\n            )\n        else:\n            raise ValueError(f\"Unknown method/dataset: {args.method}/{args.dataset}\")\n\n        data_dir = (\n            Path(\"data/raydrop\") / args.method / f\"{args.dataset}_{args.sequence_id}\"\n        )\n        data_dir.mkdir(parents=True, exist_ok=True)\n        train_data_path = data_dir / \"train_data.pkl\"\n        test_data_path = data_dir / \"test_data.pkl\"\n\n        with open(train_data_path, \"wb\") as f:\n            pickle.dump(raydrop_train_data, f)\n        with open(test_data_path, \"wb\") as f:\n            pickle.dump(raydrop_test_data, f)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "lidarnvs/unet.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\n\n\nclass DoubleConv(nn.Module):\n    \"\"\"(convolution => [BN] => ReLU) * 2\"\"\"\n\n    def __init__(self, in_channels, out_channels, mid_channels=None):\n        super().__init__()\n        if not mid_channels:\n            mid_channels = out_channels\n        self.double_conv = nn.Sequential(\n            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),\n            nn.BatchNorm2d(mid_channels),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),\n            nn.BatchNorm2d(out_channels),\n            nn.ReLU(inplace=True),\n        )\n\n    def forward(self, x):\n        return self.double_conv(x)\n\n\nclass Down(nn.Module):\n    \"\"\"Downscaling with maxpool then double conv\"\"\"\n\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.maxpool_conv = nn.Sequential(\n            nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)\n        )\n\n    def forward(self, x):\n        return self.maxpool_conv(x)\n\n\nclass Up(nn.Module):\n    \"\"\"Upscaling then double conv\"\"\"\n\n    def __init__(self, in_channels, out_channels, bilinear=True):\n        super().__init__()\n\n        # if bilinear, use the normal convolutions to reduce the number of channels\n        if bilinear:\n            self.up = nn.Upsample(scale_factor=2, mode=\"bilinear\", align_corners=True)\n            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)\n        else:\n            self.up = nn.ConvTranspose2d(\n                in_channels, in_channels // 2, kernel_size=2, stride=2\n            )\n            self.conv = DoubleConv(in_channels, out_channels)\n\n    def forward(self, x1, x2):\n        x1 = self.up(x1)\n        # input is CHW\n        diffY = x2.size()[2] - x1.size()[2]\n        diffX = x2.size()[3] - x1.size()[3]\n\n        x1 = F.pad(\n            x1,\n            [\n                diffX // 2,\n                diffX - diffX // 2,\n                diffY // 2,\n                diffY - diffY // 2,\n            ],\n        )\n        # if you have padding issues, see\n        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a\n        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd\n        x = torch.cat([x2, x1], dim=1)\n        return self.conv(x)\n\n\nclass OutConv(nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super(OutConv, self).__init__()\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)\n\n    def forward(self, x):\n        return self.conv(x)\n\n\nclass UNet(nn.Module):\n    def __init__(self, n_channels, n_classes, bilinear=False):\n        super(UNet, self).__init__()\n        self.n_channels = n_channels\n        self.n_classes = n_classes\n        self.bilinear = bilinear\n\n        self.inc = DoubleConv(n_channels, 64)\n        self.down1 = Down(64, 128)\n        self.down2 = Down(128, 256)\n        self.down3 = Down(256, 512)\n        factor = 2 if bilinear else 1\n        self.down4 = Down(512, 1024 // factor)\n        self.up1 = Up(1024, 512 // factor, bilinear)\n        self.up2 = Up(512, 256 // factor, bilinear)\n        self.up3 = Up(256, 128 // factor, bilinear)\n        self.up4 = Up(128, 64, bilinear)\n        self.outc = OutConv(64, n_classes)\n\n    def forward(self, x):\n        x1 = self.inc(x)\n        x2 = self.down1(x1)\n        x3 = self.down2(x2)\n        x4 = self.down3(x3)\n        x5 = self.down4(x4)\n        x = self.up1(x5, x4)\n        x = self.up2(x, x3)\n        x = self.up3(x, x2)\n        x = self.up4(x, x1)\n        logits = self.outc(x)\n        return logits\n\n\ndef dice_coeff(\n    input: Tensor,\n    target: Tensor,\n    reduce_batch_first: bool = False,\n    epsilon: float = 1e-6,\n):\n    # Average of Dice coefficient for all batches, or for a single mask\n    assert input.size() == target.size()\n    assert input.dim() == 3 or not reduce_batch_first\n\n    if input.dim() == 2 or not reduce_batch_first:\n        sum_dim = (-1, -2)\n    else:\n        sum_dim = (-1, -2, -3)\n\n    inter = 2 * (input * target).sum(dim=sum_dim)\n    sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)\n    sets_sum = torch.where(sets_sum == 0, inter, sets_sum)\n\n    dice = (inter + epsilon) / (sets_sum + epsilon)\n    return dice.mean()\n\n\ndef multiclass_dice_coeff(\n    input: Tensor,\n    target: Tensor,\n    reduce_batch_first: bool = False,\n    epsilon: float = 1e-6,\n):\n    # Average of Dice coefficient for all classes\n    return dice_coeff(\n        input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon\n    )\n\n\ndef dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):\n    # Dice loss (objective to minimize) between 0 and 1\n    fn = multiclass_dice_coeff if multiclass else dice_coeff\n    return 1 - fn(input, target, reduce_batch_first=True)\n\n\ndef main():\n    pass\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "main_lidarnerf.py",
    "content": "import torch\nimport configargparse\nimport os\nimport numpy as np\n\nfrom lidarnerf.nerf.utils import (\n    seed_everything,\n    RMSEMeter,\n    MAEMeter,\n    DepthMeter,\n    PointsMeter,\n    Trainer,\n)\n\n\ndef get_arg_parser():\n    parser = configargparse.ArgumentParser()\n\n    parser.add_argument(\n        \"--config\",\n        is_config_file=True,\n        default=\"configs/kitti360_1908.txt\",\n        help=\"config file path\",\n    )\n    parser.add_argument(\"--path\", type=str, default=\"data/kitti360\")\n    parser.add_argument(\n        \"-L\", action=\"store_true\", help=\"equals --fp16 --tcnn --preload\"\n    )\n    parser.add_argument(\"--test\", action=\"store_true\", help=\"test mode\")\n    parser.add_argument(\"--test_eval\", action=\"store_true\", help=\"test and eval mode\")\n    parser.add_argument(\"--workspace\", type=str, default=\"workspace\")\n    parser.add_argument(\n        \"--cluster_summary_path\",\n        type=str,\n        default=\"/summary\",\n        help=\"Overwrite default summary path if on cluster\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=0)\n    parser.add_argument(\n        \"--dataloader\", type=str, choices=(\"kitti360\", \"nerf_mvl\"), default=\"kitti360\"\n    )\n    parser.add_argument(\"--sequence_id\", type=str, default=\"1908\")\n\n    ### lidar-nerf\n    parser.add_argument(\"--enable_lidar\", action=\"store_true\", help=\"Enable lidar.\")\n    parser.add_argument(\"--alpha_d\", type=float, default=1e3)\n    parser.add_argument(\"--alpha_r\", type=float, default=1)\n    parser.add_argument(\"--alpha_i\", type=float, default=1)\n    parser.add_argument(\"--alpha_grad_norm\", type=float, default=1)\n    parser.add_argument(\"--alpha_spatial\", type=float, default=0.1)\n    parser.add_argument(\"--alpha_tv\", type=float, default=1)\n    parser.add_argument(\"--alpha_grad\", type=float, default=1e2)\n\n    parser.add_argument(\"--intensity_inv_scale\", type=float, default=1)\n\n    parser.add_argument(\"--spatial_smooth\", action=\"store_true\")\n    parser.add_argument(\"--grad_norm_smooth\", action=\"store_true\")\n    parser.add_argument(\"--tv_loss\", action=\"store_true\")\n    parser.add_argument(\"--grad_loss\", action=\"store_true\")\n    parser.add_argument(\"--sobel_grad\", action=\"store_true\")\n\n    parser.add_argument(\n        \"--desired_resolution\",\n        type=int,\n        default=2048,\n        help=\"TCN finest resolution at the smallest scale\",\n    )\n    parser.add_argument(\"--log2_hashmap_size\", type=int, default=19)\n    parser.add_argument(\"--n_features_per_level\", type=int, default=2)\n    parser.add_argument(\n        \"--num_layers\", type=int, default=2, help=\"num_layers of sigmanet\"\n    )\n    parser.add_argument(\n        \"--hidden_dim\", type=int, default=64, help=\"hidden_dim of sigmanet\"\n    )\n    parser.add_argument(\n        \"--geo_feat_dim\", type=int, default=15, help=\"geo_feat_dim of sigmanet\"\n    )\n    parser.add_argument(\"--eval_interval\", type=int, default=50)\n    parser.add_argument(\n        \"--num_rays_lidar\",\n        type=int,\n        default=4096,\n        help=\"num rays sampled per image for each training step\",\n    )\n    parser.add_argument(\n        \"--min_near_lidar\",\n        type=float,\n        default=0.01,\n        help=\"minimum near distance for camera\",\n    )\n    parser.add_argument(\n        \"--depth_loss\", type=str, default=\"l1\", help=\"l1, bce, mse, huber\"\n    )\n    parser.add_argument(\n        \"--depth_grad_loss\", type=str, default=\"l1\", help=\"l1, bce, mse, huber\"\n    )\n    parser.add_argument(\n        \"--intensity_loss\", type=str, default=\"mse\", help=\"l1, bce, mse, huber\"\n    )\n    parser.add_argument(\n        \"--raydrop_loss\", type=str, default=\"mse\", help=\"l1, bce, mse, huber\"\n    )\n    parser.add_argument(\n        \"--patch_size_lidar\",\n        type=int,\n        default=1,\n        help=\"[experimental] render patches in training. \"\n        \"1 means disabled, use [64, 32, 16] to enable\",\n    )\n    parser.add_argument(\n        \"--change_patch_size_lidar\",\n        nargs=\"+\",\n        type=int,\n        default=[1, 1],\n        help=\"[experimental] render patches in training. \"\n        \"1 means disabled, use [64, 32, 16] to enable, change during training\",\n    )\n    parser.add_argument(\n        \"--change_patch_size_epoch\",\n        type=int,\n        default=2,\n        help=\"change patch_size intenvel\",\n    )\n\n    ### training options\n    parser.add_argument(\n        \"--iters\",\n        type=int,\n        default=30000,\n        help=\"training iters\",\n    )\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"initial learning rate\")\n    parser.add_argument(\"--ckpt\", type=str, default=\"latest\")\n    parser.add_argument(\n        \"--num_rays\",\n        type=int,\n        default=4096,\n        help=\"num rays sampled per image for each training step\",\n    )\n    parser.add_argument(\n        \"--num_steps\", type=int, default=768, help=\"num steps sampled per ray\"\n    )\n    parser.add_argument(\n        \"--upsample_steps\", type=int, default=64, help=\"num steps up-sampled per ray\"\n    )\n    parser.add_argument(\n        \"--max_ray_batch\",\n        type=int,\n        default=4096,\n        help=\"batch size of rays at inference to avoid OOM)\",\n    )\n    parser.add_argument(\n        \"--patch_size\",\n        type=int,\n        default=1,\n        help=\"[experimental] render patches in training, so as to apply \"\n        \"LPIPS loss. 1 means disabled, use [64, 32, 16] to enable\",\n    )\n\n    ### network backbone options\n    parser.add_argument(\n        \"--fp16\", action=\"store_true\", help=\"use amp mixed precision training\"\n    )\n    parser.add_argument(\"--tcnn\", action=\"store_true\", help=\"use TCNN backend\")\n\n    ### dataset options\n    parser.add_argument(\n        \"--color_space\",\n        type=str,\n        default=\"srgb\",\n        help=\"Color space, supports (linear, srgb)\",\n    )\n    parser.add_argument(\n        \"--preload\",\n        action=\"store_true\",\n        help=\"preload all data into GPU, accelerate training but use more GPU memory\",\n    )\n    # (the default value is for the fox dataset)\n    parser.add_argument(\n        \"--bound\",\n        type=float,\n        default=2,\n        help=\"assume the scene is bounded in box[-bound, bound]^3, \"\n        \"if > 1, will invoke adaptive ray marching.\",\n    )\n    parser.add_argument(\n        \"--scale\",\n        type=float,\n        default=0.33,\n        help=\"scale camera location into box[-bound, bound]^3\",\n    )\n    parser.add_argument(\n        \"--offset\",\n        type=float,\n        nargs=\"*\",\n        default=[0, 0, 0],\n        help=\"offset of camera location\",\n    )\n    parser.add_argument(\n        \"--dt_gamma\",\n        type=float,\n        default=1 / 128,\n        help=\"dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 \"\n        \"to accelerate rendering (but usually with worse quality)\",\n    )\n    parser.add_argument(\n        \"--min_near\", type=float, default=0.2, help=\"minimum near distance for camera\"\n    )\n    parser.add_argument(\n        \"--density_thresh\",\n        type=float,\n        default=10,\n        help=\"threshold for density grid to be occupied\",\n    )\n    parser.add_argument(\n        \"--bg_radius\",\n        type=float,\n        default=-1,\n        help=\"if positive, use a background model at sphere(bg_radius)\",\n    )\n\n    return parser\n\n\ndef main():\n    parser = get_arg_parser()\n    opt = parser.parse_args()\n    opt.enable_lidar = True\n\n    # Check sequence id.\n    kitti360_sequence_ids = [\n        \"1538\",\n        \"1728\",\n        \"1908\",\n        \"3353\",\n    ]\n    nerf_mvl_sequence_ids = [\n        \"bollard\",\n        \"car\",\n        \"pedestrian\",\n        \"pier\",\n        \"plant\",\n        \"tire\",\n        \"traffic_cone\",\n        \"warning_sign\",\n        \"water_safety_barrier\",\n    ]\n\n    # Specify dataloader class\n    if opt.dataloader == \"kitti360\":\n        from lidarnerf.dataset.kitti360_dataset import KITTI360Dataset as NeRFDataset\n\n        if opt.sequence_id not in kitti360_sequence_ids:\n            raise ValueError(\n                f\"Unknown sequence id {opt.sequence_id} for {opt.dataloader}\"\n            )\n    elif opt.dataloader == \"nerf_mvl\":\n        from lidarnerf.dataset.nerfmvl_dataset import NeRFMVLDataset as NeRFDataset\n\n        if opt.sequence_id not in nerf_mvl_sequence_ids:\n            raise ValueError(\n                f\"Unknown sequence id {opt.sequence_id} for {opt.dataloader}\"\n            )\n    else:\n        raise RuntimeError(\"Should not reach here.\")\n\n    os.makedirs(opt.workspace, exist_ok=True)\n    f = os.path.join(opt.workspace, \"args.txt\")\n    with open(f, \"w\") as file:\n        for arg in vars(opt):\n            attr = getattr(opt, arg)\n            file.write(\"{} = {}\\n\".format(arg, attr))\n\n    if opt.L:\n        opt.fp16 = True\n        opt.tcnn = True\n        opt.preload = True\n\n    if opt.patch_size > 1:\n        # assert opt.patch_size > 16, \"patch_size should > 16 to run LPIPS loss.\"\n        assert (\n            opt.num_rays % (opt.patch_size**2) == 0\n        ), \"patch_size ** 2 should be dividable by num_rays.\"\n\n    opt.min_near = opt.scale  # hard-code, set min_near ori 1m\n    opt.min_near_lidar = opt.scale\n\n    if opt.tcnn:\n        opt.fp16 = True\n        assert opt.bg_radius <= 0, \"background model is not implemented for --tcnn\"\n        from lidarnerf.nerf.network_tcnn import NeRFNetwork\n\n        model = NeRFNetwork(\n            encoding=\"hashgrid\",\n            desired_resolution=opt.desired_resolution,\n            log2_hashmap_size=opt.log2_hashmap_size,\n            n_features_per_level=opt.n_features_per_level,\n            num_layers=opt.num_layers,\n            hidden_dim=opt.hidden_dim,\n            geo_feat_dim=opt.geo_feat_dim,\n            bound=opt.bound,\n            density_scale=1,\n            min_near=opt.min_near,\n            min_near_lidar=opt.min_near_lidar,\n            density_thresh=opt.density_thresh,\n            bg_radius=opt.bg_radius,\n        )\n    else:\n        from lidarnerf.nerf.network import NeRFNetwork\n\n        model = NeRFNetwork(\n            encoding=\"hashgrid\",\n            desired_resolution=opt.desired_resolution,\n            log2_hashmap_size=opt.log2_hashmap_size,\n            num_layers=opt.num_layers,\n            hidden_dim=opt.hidden_dim,\n            geo_feat_dim=opt.geo_feat_dim,\n            bound=opt.bound,\n            density_scale=1,\n            min_near=opt.min_near,\n            density_thresh=opt.density_thresh,\n            bg_radius=opt.bg_radius,\n        )\n\n    print(opt)\n    seed_everything(opt.seed)\n    print(model)\n\n    loss_dict = {\n        \"mse\": torch.nn.MSELoss(reduction=\"none\"),\n        \"l1\": torch.nn.L1Loss(reduction=\"none\"),\n        \"bce\": torch.nn.BCEWithLogitsLoss(reduction=\"none\"),\n        \"huber\": torch.nn.HuberLoss(reduction=\"none\", delta=0.2 * opt.scale),\n        \"cos\": torch.nn.CosineSimilarity(),\n    }\n    criterion = {\n        \"depth\": loss_dict[opt.depth_loss],\n        \"raydrop\": loss_dict[opt.raydrop_loss],\n        \"intensity\": loss_dict[opt.intensity_loss],\n        \"grad\": loss_dict[opt.depth_grad_loss],\n    }\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    if opt.test or opt.test_eval:\n        test_loader = NeRFDataset(\n            device=device,\n            split=\"test\",\n            root_path=opt.path,\n            sequence_id=opt.sequence_id,\n            preload=opt.preload,\n            scale=opt.scale,\n            offset=opt.offset,\n            fp16=opt.fp16,\n            patch_size_lidar=opt.patch_size_lidar,\n            enable_lidar=opt.enable_lidar,\n            num_rays_lidar=opt.num_rays_lidar,\n        ).dataloader()\n        if opt.enable_lidar:\n            depth_metrics = [\n                MAEMeter(intensity_inv_scale=opt.intensity_inv_scale),\n                RMSEMeter(),\n                DepthMeter(scale=opt.scale),\n                PointsMeter(\n                    scale=opt.scale, intrinsics=test_loader._data.intrinsics_lidar\n                ),\n            ]\n        else:\n            depth_metrics = []\n        trainer = Trainer(\n            \"lidar_nerf\",\n            opt,\n            model,\n            device=device,\n            workspace=opt.workspace,\n            criterion=criterion,\n            fp16=opt.fp16,\n            depth_metrics=depth_metrics,\n            use_checkpoint=opt.ckpt,\n        )\n\n        if test_loader.has_gt and opt.test_eval:\n            trainer.evaluate(test_loader)  # blender has gt, so evaluate it.\n        trainer.test(test_loader, write_video=False)  # test and save video\n        trainer.save_mesh(resolution=128, threshold=10)\n\n    else:\n        optimizer = lambda model: torch.optim.Adam(\n            model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15\n        )\n\n        train_loader = NeRFDataset(\n            device=device,\n            split=\"train\",\n            root_path=opt.path,\n            sequence_id=opt.sequence_id,\n            preload=opt.preload,\n            scale=opt.scale,\n            offset=opt.offset,\n            fp16=opt.fp16,\n            patch_size_lidar=opt.patch_size_lidar,\n            enable_lidar=opt.enable_lidar,\n            num_rays_lidar=opt.num_rays_lidar,\n        ).dataloader()\n\n        # decay to 0.1 * init_lr at last iter step\n        scheduler = lambda optimizer: torch.optim.lr_scheduler.LambdaLR(\n            optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)\n        )\n        if opt.enable_lidar:\n            depth_metrics = [\n                MAEMeter(intensity_inv_scale=opt.intensity_inv_scale),\n                RMSEMeter(),\n                DepthMeter(scale=opt.scale),\n                PointsMeter(\n                    scale=opt.scale, intrinsics=train_loader._data.intrinsics_lidar\n                ),\n            ]\n        else:\n            depth_metrics = []\n\n        trainer = Trainer(\n            \"lidar_nerf\",\n            opt,\n            model,\n            device=device,\n            workspace=opt.workspace,\n            optimizer=optimizer,\n            criterion=criterion,\n            ema_decay=0.95,\n            fp16=opt.fp16,\n            lr_scheduler=scheduler,\n            scheduler_update_every_step=True,\n            depth_metrics=depth_metrics,\n            use_checkpoint=opt.ckpt,\n            eval_interval=opt.eval_interval,\n        )\n\n        valid_loader = NeRFDataset(\n            device=device,\n            split=\"val\",\n            root_path=opt.path,\n            sequence_id=opt.sequence_id,\n            preload=opt.preload,\n            scale=opt.scale,\n            offset=opt.offset,\n            fp16=opt.fp16,\n            patch_size_lidar=opt.patch_size_lidar,\n            enable_lidar=opt.enable_lidar,\n            num_rays_lidar=opt.num_rays_lidar,\n        ).dataloader()\n\n        max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)\n        print(f\"max_epoch: {max_epoch}\")\n        trainer.train(train_loader, valid_loader, max_epoch)\n\n        # also test\n        test_loader = NeRFDataset(\n            device=device,\n            split=\"test\",\n            root_path=opt.path,\n            sequence_id=opt.sequence_id,\n            preload=opt.preload,\n            scale=opt.scale,\n            offset=opt.offset,\n            fp16=opt.fp16,\n            patch_size_lidar=opt.patch_size_lidar,\n            enable_lidar=opt.enable_lidar,\n            num_rays_lidar=opt.num_rays_lidar,\n        ).dataloader()\n\n        if test_loader.has_gt:\n            trainer.evaluate(test_loader)  # blender has gt, so evaluate it.\n\n        trainer.test(test_loader, write_video=True)  # test and save video\n\n        trainer.save_mesh(resolution=128, threshold=10)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "preprocess/cal_centerpose_bound.py",
    "content": "import numpy as np\n\nnp.set_printoptions(suppress=True)\nimport os\nimport json\nimport tqdm\nfrom lidarnerf.convert import pano_to_lidar\n\n\ndef cal_centerpose_bound_scale(\n    lidar_rangeview_paths, lidar2worlds, intrinsics, bound=1.0\n):\n    near = 200\n    far = 0\n    points_world_list = []\n    for i, lidar_rangeview_path in enumerate(lidar_rangeview_paths):\n        pano = np.load(lidar_rangeview_path)\n        point_cloud = pano_to_lidar(pano=pano[:, :, 2], lidar_K=intrinsics)\n        point_cloud = np.concatenate(\n            [point_cloud, np.ones(point_cloud.shape[0]).reshape(-1, 1)], -1\n        )\n        dis = np.sqrt(\n            point_cloud[:, 0] ** 2 + point_cloud[:, 1] ** 2 + point_cloud[:, 2] ** 2\n        )\n        near = min(min(dis), near)\n        far = max(far, max(dis))\n        points_world = (point_cloud @ lidar2worlds[i].T)[:, :3]\n        points_world_list.append(points_world)\n    print(\"near, far:\", near, far)\n\n    # plt.figure(figsize=(16, 16))\n    pc_all_w = np.concatenate(points_world_list)[:, :3]\n\n    # plt.scatter(pc_all_w[:, 0], pc_all_w[:, 1], s=0.001)\n    # lidar2world_scene = np.array(lidar2worlds)\n    # plt.plot(lidar2world_scene[:, 0, -1], lidar2world_scene[:, 1, -1])\n    # plt.savefig('vis/points-trajectory.png')\n\n    centerpose = [\n        (np.max(pc_all_w[:, 0]) + np.min(pc_all_w[:, 0])) / 2.0,\n        (np.max(pc_all_w[:, 1]) + np.min(pc_all_w[:, 1])) / 2.0,\n        (np.max(pc_all_w[:, 2]) + np.min(pc_all_w[:, 2])) / 2.0,\n    ]\n    print(\"centerpose: \", centerpose)\n    pc_all_w_centered = pc_all_w - centerpose\n\n    # plt.figure(figsize=(16, 16))\n    # plt.scatter(pc_all_w_centered[:, 0], pc_all_w_centered[:, 1], s=0.001)\n    # plt.savefig('vis/points-centered.png')\n\n    bound_ori = [\n        np.max(pc_all_w_centered[:, 0]),\n        np.max(pc_all_w_centered[:, 1]),\n        np.max(pc_all_w_centered[:, 2]),\n    ]\n    scale = bound / np.max(bound_ori)\n    print(\"scale: \", scale)\n\n    # pc_all_w_centered_scaled = pc_all_w_centered * scale\n    # plt.figure(figsize=(16, 16))\n    # plt.scatter(pc_all_w_centered_scaled[:, 0],\n    #             pc_all_w_centered_scaled[:, 1],\n    #             s=0.001)\n    # plt.savefig('vis/points-centered-scaled.png')\n\n\ndef get_path_pose_from_json(root_path, sequence_id):\n    with open(\n        os.path.join(root_path, f\"transforms_{sequence_id}_train.json\"), \"r\"\n    ) as f:\n        transform = json.load(f)\n    frames = transform[\"frames\"]\n    poses_lidar = []\n    paths_lidar = []\n    for f in tqdm.tqdm(frames, desc=f\"Loading {type} data\"):\n        pose_lidar = np.array(f[\"lidar2world\"], dtype=np.float32)  # [4, 4]\n        f_lidar_path = os.path.join(root_path, f[\"lidar_file_path\"])\n        poses_lidar.append(pose_lidar)\n        paths_lidar.append(f_lidar_path)\n    return paths_lidar, poses_lidar\n\n\ndef main():\n    # kitti360\n    root_path = \"data/kitti360\"\n    sequence_id = 1908\n    lidar_rangeview_paths, lidar2worlds = get_path_pose_from_json(\n        root_path, sequence_id=sequence_id\n    )\n    intrinsics = (2.0, 26.9)  # fov_up, fov\n\n    cal_centerpose_bound_scale(lidar_rangeview_paths, lidar2worlds, intrinsics)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "preprocess/generate_train_rangeview.py",
    "content": "import numpy as np\nimport os\nfrom pathlib import Path\nfrom tqdm import tqdm\nimport shutil\nimport argparse\n\nfrom lidarnerf.convert import (\n    lidar_to_pano_with_intensities,\n    lidar_to_pano_with_intensities_with_bbox_mask,\n)\n\n\ndef all_points_to_world(pcd_path_list, lidar2world_list):\n    pc_w_list = []\n    for i, pcd_path in enumerate(pcd_path_list):\n        point_cloud = np.load(pcd_path)\n        point_cloud[:, -1] = 1\n        points_world = (point_cloud @ (lidar2world_list[i].reshape(4, 4)).T)[:, :3]\n        pc_w_list.append(points_world)\n    return pc_w_list\n\n\ndef oriented_bounding_box(data):\n    data_norm = data - data.mean(axis=0)\n    C = np.cov(data_norm, rowvar=False)\n    vals, vecs = np.linalg.eig(C)\n    vecs = vecs[:, np.argsort(-vals)]\n    Y = np.matmul(data_norm, vecs)\n    offset = 0.03\n    xmin = min(Y[:, 0]) - offset\n    xmax = max(Y[:, 0]) + offset\n    ymin = min(Y[:, 1]) - offset\n    ymax = max(Y[:, 1]) + offset\n\n    temp = list()\n    temp.append([xmin, ymin])\n    temp.append([xmax, ymin])\n    temp.append([xmax, ymax])\n    temp.append([xmin, ymax])\n\n    pointInNewCor = np.asarray(temp)\n    OBB = np.matmul(pointInNewCor, vecs.T) + data.mean(0)\n    return OBB\n\n\ndef get_dataset_bbox(all_class, dataset_root, out_dir):\n    object_bbox = {}\n    for class_name in all_class:\n        lidar_path = os.path.join(dataset_root, class_name)\n        rt_path = os.path.join(lidar_path, \"lidar2world.txt\")\n        filenames = os.listdir(lidar_path)\n        filenames.remove(\"lidar2world.txt\")\n        filenames.sort(key=lambda x: int(x.split(\".\")[0]))\n        show_interval = 1\n        pcd_path_list = [os.path.join(lidar_path, filename) for filename in filenames][\n            ::show_interval\n        ]\n        print(f\"{lidar_path}: {len(pcd_path_list)} frames\")\n        lidar2world_list = list(np.loadtxt(rt_path))[::show_interval]\n        all_points = all_points_to_world(pcd_path_list, lidar2world_list)\n        pcd = np.concatenate(all_points).reshape((-1, 3))\n\n        OBB_xy = oriented_bounding_box(pcd[:, :2])\n        z_min, z_max = min(pcd[:, 2]), max(pcd[:, 2])\n        OBB_buttum = np.concatenate([OBB_xy, np.tile(z_min, 4).reshape(4, 1)], axis=1)\n        OBB_top = np.concatenate([OBB_xy, np.tile(z_max, 4).reshape(4, 1)], axis=1)\n        OBB = np.concatenate([OBB_top, OBB_buttum])\n        object_bbox[class_name] = OBB\n    np.save(os.path.join(out_dir, \"dataset_bbox_7k.npy\"), object_bbox)\n\n\ndef LiDAR_2_Pano_NeRF_MVL(\n    local_points_with_intensities,\n    lidar_H,\n    lidar_W,\n    intrinsics,\n    OBB_local,\n    max_depth=80.0,\n):\n    pano, intensities = lidar_to_pano_with_intensities_with_bbox_mask(\n        local_points_with_intensities=local_points_with_intensities,\n        lidar_H=lidar_H,\n        lidar_W=lidar_W,\n        lidar_K=intrinsics,\n        bbox_local=OBB_local,\n        max_depth=max_depth,\n    )\n    range_view = np.zeros((lidar_H, lidar_W, 3))\n    range_view[:, :, 1] = intensities\n    range_view[:, :, 2] = pano\n    return range_view\n\n\ndef generate_nerf_mvl_train_data(\n    H,\n    W,\n    intrinsics,\n    all_class,\n    dataset_bbox,\n    nerf_mvl_parent_dir,\n    out_dir,\n):\n    \"\"\"\n    Args:\n        H: Heights of the range view.\n        W: Width of the range view.\n        intrinsics: (fov_up, fov) of the range view.\n        out_dir: Output directory.\n    \"\"\"\n\n    for class_name in all_class:\n        OBB = dataset_bbox[class_name]\n        lidar_path = os.path.join(nerf_mvl_parent_dir, \"nerf_mvl_7k\", class_name)\n        filenames = os.listdir(lidar_path)\n        filenames.remove(\"lidar2world.txt\")\n        filenames.sort(key=lambda x: int(x.split(\".\")[0]))\n        save_path = os.path.join(out_dir, class_name)\n        if not os.path.exists(save_path):\n            os.makedirs(save_path)\n        shutil.copy(\n            os.path.join(lidar_path, \"lidar2world.txt\"),\n            os.path.join(save_path, \"lidar2world.txt\"),\n        )\n        lidar2world = np.loadtxt(os.path.join(lidar_path, \"lidar2world.txt\"))\n        avaliable_frames = [i for i in range(0, len(filenames))]\n        print(class_name, \" frames num \", len(avaliable_frames))\n        for idx in tqdm(avaliable_frames):\n            pcd = np.load(os.path.join(lidar_path, filenames[idx]))\n            OBB_local = (\n                np.concatenate([OBB, np.ones((8, 1))], axis=1)\n                @ np.linalg.inv(lidar2world[idx].reshape(4, 4)).T\n            )\n            pano = LiDAR_2_Pano_NeRF_MVL(pcd, H, W, intrinsics, OBB_local)\n            np.savez_compressed(\n                os.path.join(save_path, \"{:010d}.npz\").format(idx), data=pano\n            )\n\n\ndef create_nerf_mvl_rangeview():\n    project_root = Path(__file__).parent.parent\n    nerf_mvl_root = project_root / \"data\" / \"nerf_mvl\" / \"nerf_mvl_7k\"\n    nerf_mvl_parent_dir = nerf_mvl_root.parent\n    out_dir = nerf_mvl_parent_dir / \"nerf_mvl_7k_pano\"\n\n    all_class = [\n        \"water_safety_barrier\",\n        \"tire\",\n        \"pier\",\n        \"plant\",\n        \"warning_sign\",\n        \"traffic_cone\",\n        \"bollard\",\n        \"pedestrian\",\n        \"car\",\n    ]\n\n    # get_dataset_bbox\n    if not os.path.exists(os.path.join(nerf_mvl_parent_dir, \"dataset_bbox_7k.npy\")):\n        get_dataset_bbox(all_class, nerf_mvl_root, nerf_mvl_parent_dir)\n    dataset_bbox = np.load(\n        os.path.join(nerf_mvl_parent_dir, \"dataset_bbox_7k.npy\"), allow_pickle=True\n    ).item()\n\n    # generate train rangeview images\n    H = 256\n    W = 1800\n    intrinsics = (15, 40)\n    generate_nerf_mvl_train_data(\n        H=H,\n        W=W,\n        intrinsics=intrinsics,\n        all_class=all_class,\n        dataset_bbox=dataset_bbox,\n        nerf_mvl_parent_dir=nerf_mvl_parent_dir,\n        out_dir=out_dir,\n    )\n\n\ndef LiDAR_2_Pano_KITTI(\n    local_points_with_intensities, lidar_H, lidar_W, intrinsics, max_depth=80.0\n):\n    pano, intensities = lidar_to_pano_with_intensities(\n        local_points_with_intensities=local_points_with_intensities,\n        lidar_H=lidar_H,\n        lidar_W=lidar_W,\n        lidar_K=intrinsics,\n        max_depth=max_depth,\n    )\n    range_view = np.zeros((lidar_H, lidar_W, 3))\n    range_view[:, :, 1] = intensities\n    range_view[:, :, 2] = pano\n    return range_view\n\n\ndef generate_train_data(\n    H,\n    W,\n    intrinsics,\n    lidar_paths,\n    out_dir,\n    points_dim,\n):\n    \"\"\"\n    Args:\n        H: Heights of the range view.\n        W: Width of the range view.\n        intrinsics: (fov_up, fov) of the range view.\n        out_dir: Output directory.\n    \"\"\"\n\n    out_dir = Path(out_dir)\n    out_dir.mkdir(parents=True, exist_ok=True)\n\n    for lidar_path in tqdm(lidar_paths):\n        point_cloud = np.fromfile(lidar_path, dtype=np.float32)\n        point_cloud = point_cloud.reshape((-1, points_dim))\n        pano = LiDAR_2_Pano_KITTI(point_cloud, H, W, intrinsics)\n        frame_name = lidar_path.split(\"/\")[-1]\n        suffix = frame_name.split(\".\")[-1]\n        frame_name = frame_name.replace(suffix, \"npy\")\n        np.save(out_dir / frame_name, pano)\n\n\ndef create_kitti_rangeview():\n    project_root = Path(__file__).parent.parent\n    kitti_360_root = project_root / \"data\" / \"kitti360\" / \"KITTI-360\"\n    kitti_360_parent_dir = kitti_360_root.parent\n    out_dir = kitti_360_parent_dir / \"train\"\n    sequence_name = \"2013_05_28_drive_0000\"\n\n    H = 66\n    W = 1030\n    intrinsics = (2.0, 26.9)  # fov_up, fov\n\n    s_frame_id = 1908\n    e_frame_id = 1971  # Inclusive\n    frame_ids = list(range(s_frame_id, e_frame_id + 1))\n\n    lidar_dir = (\n        kitti_360_root\n        / \"data_3d_raw\"\n        / f\"{sequence_name}_sync\"\n        / \"velodyne_points\"\n        / \"data\"\n    )\n    lidar_paths = [\n        os.path.join(lidar_dir, \"%010d.bin\" % frame_id) for frame_id in frame_ids\n    ]\n\n    generate_train_data(\n        H=H,\n        W=W,\n        intrinsics=intrinsics,\n        lidar_paths=lidar_paths,\n        out_dir=out_dir,\n        points_dim=4,\n    )\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        default=\"kitti360\",\n        choices=[\"kitti360\", \"nerf_mvl\"],\n        help=\"The dataset loader to use.\",\n    )\n    args = parser.parse_args()\n\n    # Check dataset.\n    if args.dataset == \"kitti360\":\n        create_kitti_rangeview()\n    elif args.dataset == \"nerf_mvl\":\n        create_nerf_mvl_rangeview()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "preprocess/kitti360_loader.py",
    "content": "from pathlib import Path\nimport numpy as np\nimport camtools as ct\nimport open3d as o3d\n\n\nclass KITTI360Loader:\n    def __init__(self, kitti_360_root) -> None:\n        # Root directory.\n        self.kitti_360_root = Path(kitti_360_root)\n        if not self.kitti_360_root.is_dir():\n            raise FileNotFoundError(f\"KITTI-360 {kitti_360_root} not found.\")\n\n        # Other directories.\n        self.calibration_dir = self.kitti_360_root / \"calibration\"\n        self.data_poses_dir = self.kitti_360_root / \"data_poses\"\n        self.data_2d_raw_dir = self.kitti_360_root / \"data_2d_raw\"\n        self.data_3d_raw_dir = self.kitti_360_root / \"data_3d_raw\"\n\n        # Check if all directories exist.\n        if not self.calibration_dir.is_dir():\n            raise FileNotFoundError(\n                f\"Calibration dir {self.calibration_dir} not found.\"\n            )\n        if not self.data_poses_dir.is_dir():\n            raise FileNotFoundError(f\"Data poses dir {self.data_poses_dir} not found.\")\n        if not self.data_2d_raw_dir.is_dir():\n            raise FileNotFoundError(\n                f\"Data 2D raw dir {self.data_2d_raw_dir} not found.\"\n            )\n        if not self.data_3d_raw_dir.is_dir():\n            raise FileNotFoundError(\n                f\"Data 3D raw dir {self.data_3d_raw_dir} not found.\"\n            )\n\n    @staticmethod\n    def _read_variable(fid, name, M, N):\n        \"\"\"\n        Ref:\n            kitti360scripts/devkits/commons/loadCalibration.py\n        \"\"\"\n        # Rewind\n        fid.seek(0, 0)\n\n        # Search for variable identifier\n        line = 1\n        success = 0\n        while line:\n            line = fid.readline()\n            if line.startswith(name):\n                success = 1\n                break\n\n        # Return if variable identifier not found\n        if success == 0:\n            return None\n\n        # Fill matrix\n        line = line.replace(\"%s:\" % name, \"\")\n        line = line.split()\n        assert len(line) == M * N\n        line = [float(x) for x in line]\n        mat = np.array(line).reshape(M, N)\n\n        return mat\n\n    @staticmethod\n    def _load_perspective_intrinsics(intrinsics_path):\n        \"\"\"\n        Args:\n            intrinsics_path: str, path to perspective.txt.\n\n        Returns:\n            A dict, containing:\n            - \"P_rect_00\": 4x4 rectified intrinsic for cam_00.\n            - \"P_rect_01\": 4x4 rectified intrinsic for cam_01.\n            - \"R_rect_00\": 3x3 rectification matrix for cam_00.\n            - \"R_rect_01\": 3xe rectification matrix for cam_01.\n\n        Ref:\n            kitti360scripts/devkits/commons/loadCalibration.py::loadPerspectiveIntrinsic\n        \"\"\"\n        intrinsics_path = Path(intrinsics_path)\n        with open(intrinsics_path, \"r\") as fid:\n            perspective_dict = {}\n            intrinsic_names = [\"P_rect_00\", \"R_rect_00\", \"P_rect_01\", \"R_rect_01\"]\n            last_row = np.array([0, 0, 0, 1]).reshape(1, 4)\n            for intrinsic in intrinsic_names:\n                if intrinsic.startswith(\"P_rect\"):\n                    perspective_dict[intrinsic] = np.concatenate(\n                        (KITTI360Loader._read_variable(fid, intrinsic, 3, 4), last_row)\n                    )\n                else:\n                    perspective_dict[intrinsic] = KITTI360Loader._read_variable(\n                        fid, intrinsic, 3, 3\n                    )\n        return perspective_dict\n\n    def load_images(self, camera_name, sequence_name, frame_ids):\n        \"\"\"\n        Args:\n            camera_name: str, name of camera. e.g. \"cam_00\".\n            sequence_name: str, name of sequence. e.g. \"2013_05_28_drive_0000\".\n            frame_ids: list of int, frame ids. e.g. range(1908, 1971+1).\n\n        Returns:\n            An np.ndarray, float32, [N, H, W, 3], range 0-1, RGB images.\n        \"\"\"\n        im_paths = self.get_image_paths(camera_name, sequence_name, frame_ids)\n        ims = [ct.io.imread(im_path) for im_path in im_paths]\n        ims = np.stack(ims, axis=0)\n\n        return ims\n\n    def get_image_paths(self, camera_name, sequence_name, frame_ids):\n        \"\"\"\n        Args:\n            camera_name: str, name of camera. e.g. \"cam_00\".\n            sequence_name: str, name of sequence. e.g. \"2013_05_28_drive_0000\".\n            frame_ids: list of int, frame ids. e.g. range(1908, 1971+1).\n\n        Returns:\n            An list of str, image paths.\n        \"\"\"\n        # Sanity checks.\n        if camera_name == \"cam_00\":\n            subdir_name = \"image_00\"\n        elif camera_name == \"cam_01\":\n            subdir_name = \"image_01\"\n        else:\n            raise ValueError(f\"Invalid camera_name {camera_name}\")\n\n        # Get image paths.\n        im_dir = (\n            self.data_2d_raw_dir / f\"{sequence_name}_sync\" / subdir_name / \"data_rect\"\n        )\n        im_paths = [im_dir / f\"{frame_id:010d}.png\" for frame_id in frame_ids]\n        for im_path in im_paths:\n            if not im_path.is_file():\n                raise FileNotFoundError(f\"Image {im_path} not found.\")\n\n        return im_paths\n\n    def _load_all_cameras(self, sequence_name):\n        \"\"\"\n        Args:\n            sequence_name: str, name of sequence. e.g. \"2013_05_28_drive_0000\".\n\n        Returns:\n            cam_00_K: 3x3 intrinsics, rectified perspective cam_00.\n            cam_01_K: 3x3 intrinsics, rectified perspective cam_01.\n            cam_00_T_dict: map frame_id to 4x4 T, rectified perspective cam_00.\n            cam_01_T_dict: map frame_id to 4x4 T, rectified perspective cam_01.\n        \"\"\"\n        data_poses_dir = self.data_poses_dir / f\"{sequence_name}_sync\"\n        assert data_poses_dir.is_dir()\n\n        # Load intrinsics and rectification matrices.\n        perspective_path = self.calibration_dir / \"perspective.txt\"\n        perspective_dict = KITTI360Loader._load_perspective_intrinsics(perspective_path)\n        cam_00_K = perspective_dict[\"P_rect_00\"][:3, :3]  # 3x3\n        cam_01_K = perspective_dict[\"P_rect_01\"][:3, :3]  # 3x3\n        cam_00_rec = np.eye(4)  # 4x4\n        cam_00_rec[:3, :3] = perspective_dict[\"R_rect_00\"]\n        cam_01_rec = np.eye(4)  # 4x4\n        cam_01_rec[:3, :3] = perspective_dict[\"R_rect_01\"]\n\n        # IMU to world transformation (poses.txt).\n        poses_path = data_poses_dir / \"poses.txt\"\n        imu_to_world_dict = dict()\n        frame_ids = []\n        for line in np.loadtxt(poses_path):\n            frame_id = int(line[0])\n            frame_ids.append(frame_id)\n            imu_to_world = line[1:].reshape((3, 4))\n            imu_to_world_dict[frame_id] = imu_to_world\n\n        # Camera to IMU transformation (calib_cam_to_pose.txt).\n        cam_to_imu_path = self.calibration_dir / \"calib_cam_to_pose.txt\"\n        with open(cam_to_imu_path, \"r\") as fid:\n            cam_00_to_imu = KITTI360Loader._read_variable(fid, \"image_00\", 3, 4)\n            cam_01_to_imu = KITTI360Loader._read_variable(fid, \"image_01\", 3, 4)\n            cam_02_to_imu = KITTI360Loader._read_variable(fid, \"image_02\", 3, 4)\n            cam_03_to_imu = KITTI360Loader._read_variable(fid, \"image_03\", 3, 4)\n            cam_00_to_imu = ct.convert.pad_0001(cam_00_to_imu)\n            cam_01_to_imu = ct.convert.pad_0001(cam_01_to_imu)\n            cam_02_to_imu = ct.convert.pad_0001(cam_02_to_imu)\n            cam_03_to_imu = ct.convert.pad_0001(cam_03_to_imu)\n\n        # Compute rectified cam_00_to_world, cam_01_to_world.\n        cam_00_to_world_dict = dict()\n        for frame_id in frame_ids:\n            imu_to_world = imu_to_world_dict[frame_id]\n            cam_00_to_world_unrec = imu_to_world @ cam_00_to_imu\n            cam_00_to_world = cam_00_to_world_unrec @ np.linalg.inv(cam_00_rec)\n            cam_00_to_world_dict[frame_id] = ct.convert.pad_0001(cam_00_to_world)\n        cam_01_to_world_dict = dict()\n        for frame_id in frame_ids:\n            imu_to_world = imu_to_world_dict[frame_id]\n            cam_01_to_world_unrec = imu_to_world @ cam_01_to_imu\n            cam_01_to_world = cam_01_to_world_unrec @ np.linalg.inv(cam_01_rec)\n            cam_01_to_world_dict[frame_id] = ct.convert.pad_0001(cam_01_to_world)\n\n        # Sanity check: check our rectified cam0_to_world is the same as the\n        # ones ground-truth given by KITTI-360.\n        cam_00_to_world_path = data_poses_dir / \"cam0_to_world.txt\"\n        gt_cam_00_to_world_dict = dict()\n        for line in np.loadtxt(cam_00_to_world_path):\n            frame_id = int(line[0])\n            gt_cam_00_to_world_dict[frame_id] = line[1:].reshape((4, 4))\n        for frame_id in frame_ids:\n            gt_cam_00_to_world = gt_cam_00_to_world_dict[frame_id]\n            cam_00_to_world = cam_00_to_world_dict[frame_id]\n            assert np.allclose(\n                gt_cam_00_to_world, cam_00_to_world, atol=1e-5, rtol=1e-5\n            )\n\n        # Convert cam_to_world to T.\n        cam_00_T_dict = dict()\n        cam_01_T_dict = dict()\n        for frame_id in frame_ids:\n            cam_00_T = np.linalg.inv(cam_00_to_world_dict[frame_id])\n            cam_01_T = np.linalg.inv(cam_01_to_world_dict[frame_id])\n            cam_00_T_dict[frame_id] = cam_00_T\n            cam_01_T_dict[frame_id] = cam_01_T\n\n        return cam_00_K, cam_01_K, cam_00_T_dict, cam_01_T_dict\n\n    def load_cameras(self, camera_name, sequence_name, frame_ids):\n        \"\"\"\n        Args:\n            camera_name: str, name of camera. e.g. \"cam_00\".\n            sequence_name: str, name of sequence. e.g. \"2013_05_28_drive_0000\".\n            frame_ids: list of int, frame ids. e.g. range(1908, 1971+1).\n\n        Returns:\n            Ks, Ts\n        \"\"\"\n        (\n            cam_00_K,\n            cam_01_K,\n            cam_00_T_dict,\n            cam_01_T_dict,\n        ) = self._load_all_cameras(sequence_name)\n        num_cameras = len(frame_ids)\n\n        if camera_name == \"cam_00\":\n            Ks = [cam_00_K for _ in range(num_cameras)]\n            Ts = [cam_00_T_dict[frame_id] for frame_id in frame_ids]\n        elif camera_name == \"cam_01\":\n            Ks = [cam_01_K for _ in range(num_cameras)]\n            Ts = [cam_01_T_dict[frame_id] for frame_id in frame_ids]\n        else:\n            raise ValueError(f\"Unknown camera name {camera_name}\")\n\n        Ks = np.stack(Ks)\n        Ts = np.stack(Ts)\n        return Ks, Ts\n\n    def _load_all_lidars(self, sequence_name):\n        \"\"\"\n        Args:\n            sequence_name: str, name of sequence. e.g. \"2013_05_28_drive_0000\".\n\n        Returns:\n            velo_to_world: 4x4 metric.\n        \"\"\"\n        data_poses_dir = self.data_poses_dir / f\"{sequence_name}_sync\"\n        assert data_poses_dir.is_dir()\n\n        # IMU to world transformation (poses.txt).\n        poses_path = data_poses_dir / \"poses.txt\"\n        imu_to_world_dict = dict()\n        frame_ids = []\n        for line in np.loadtxt(poses_path):\n            frame_id = int(line[0])\n            frame_ids.append(frame_id)\n            imu_to_world = line[1:].reshape((3, 4))\n            imu_to_world_dict[frame_id] = imu_to_world\n\n        # Camera to IMU transformation (calib_cam_to_pose.txt).\n        cam_to_imu_path = self.calibration_dir / \"calib_cam_to_pose.txt\"\n        with open(cam_to_imu_path, \"r\") as fid:\n            cam_00_to_imu = KITTI360Loader._read_variable(fid, \"image_00\", 3, 4)\n            cam_00_to_imu = ct.convert.pad_0001(cam_00_to_imu)\n\n        # Camera00 to Velo transformation (calib_cam_to_velo.txt).\n        cam00_to_velo_path = self.calibration_dir / \"calib_cam_to_velo.txt\"\n        with open(cam00_to_velo_path, \"r\") as fid:\n            line = fid.readline().split()\n            line = [float(x) for x in line]\n            cam_00_to_velo = np.array(line).reshape(3, 4)\n            cam_00_to_velo = ct.convert.pad_0001(cam_00_to_velo)\n\n        # Compute velo_to_world\n        velo_to_world_dict = dict()\n        for frame_id in frame_ids:\n            imu_to_world = imu_to_world_dict[frame_id]\n            cam_00_to_world_unrec = imu_to_world @ cam_00_to_imu\n            velo_to_world = cam_00_to_world_unrec @ np.linalg.inv(cam_00_to_velo)\n            velo_to_world_dict[frame_id] = ct.convert.pad_0001(velo_to_world)\n\n        return velo_to_world_dict\n\n    def load_lidars(self, sequence_name, frame_ids):\n        \"\"\"\n        Args:\n            sequence_name: str, name of sequence. e.g. \"2013_05_28_drive_0000\".\n            frame_ids: list of int, frame ids. e.g. range(1908, 1971+1).\n\n        Returns:\n            velo_to_worlds\n        \"\"\"\n        velo_to_world_dict = self._load_all_lidars(sequence_name)\n        velo_to_worlds = [velo_to_world_dict[frame_id] for frame_id in frame_ids]\n        velo_to_worlds = np.stack(velo_to_worlds)\n        return velo_to_worlds\n\n\ndef main():\n    # Load cameras.\n    k3 = KITTI360Loader(kitti_360_root=Path(\"data\") / \"KITTI-360\")\n    cam_00_Ks, cam_00_Ts = k3.load_cameras(\n        camera_name=\"cam_00\",\n        sequence_name=\"2013_05_28_drive_0000\",\n        frame_ids=range(1908, 1971 + 1),\n    )\n    cam_01_Ks, cam_01_Ts = k3.load_cameras(\n        camera_name=\"cam_01\",\n        sequence_name=\"2013_05_28_drive_0000\",\n        frame_ids=range(1908, 1971 + 1),\n    )\n\n    # Load images.\n    im_cam_00s = k3.load_images(\n        camera_name=\"cam_00\",\n        sequence_name=\"2013_05_28_drive_0000\",\n        frame_ids=range(1908, 1971 + 1),\n    )\n    im_cam_01s = k3.load_images(\n        camera_name=\"cam_01\",\n        sequence_name=\"2013_05_28_drive_0000\",\n        frame_ids=range(1908, 1971 + 1),\n    )\n\n    # Visualize.\n    cam_00_frames = ct.camera.create_camera_ray_frames(cam_00_Ks, cam_00_Ts, size=0.8)\n    cam_01_frames = ct.camera.create_camera_ray_frames(cam_01_Ks, cam_01_Ts, size=0.8)\n    o3d.visualization.draw_geometries([cam_00_frames, cam_01_frames])\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "preprocess/kitti360_to_nerf.py",
    "content": "from pathlib import Path\n\nfrom kitti360_loader import KITTI360Loader\nimport camtools as ct\nimport numpy as np\nimport json\n\n\ndef normalize_Ts(Ts):\n    # New Cs.\n    Cs = np.array([ct.convert.T_to_C(T) for T in Ts])\n    normalize_mat = ct.normalize.compute_normalize_mat(Cs)\n    Cs_new = ct.project.homo_project(Cs.reshape((-1, 3)), normalize_mat)\n\n    # New Ts.\n    Ts_new = []\n    for T, C_new in zip(Ts, Cs_new):\n        pose = ct.convert.T_to_pose(T)\n        pose[:3, 3] = C_new\n        T_new = ct.convert.pose_to_T(pose)\n        Ts_new.append(T_new)\n\n    return Ts_new\n\n\ndef main():\n    project_root = Path(__file__).parent.parent\n    kitti_360_root = project_root / \"data\" / \"kitti360\" / \"KITTI-360\"\n    kitti_360_parent_dir = kitti_360_root.parent\n\n    # Specify frames and splits.\n    sequence_name = \"2013_05_28_drive_0000\"\n    sequence_id = \"1908\"\n\n    if sequence_id == \"1538\":\n        print(\"Using sqequence 1538-1601\")\n        s_frame_id = 1538\n        e_frame_id = 1601  # Inclusive\n        val_frame_ids = [1551, 1564, 1577, 1590]\n    elif sequence_id == \"1728\":\n        print(\"Using sqequence 1728-1791\")\n        s_frame_id = 1728\n        e_frame_id = 1791  # Inclusive\n        val_frame_ids = [1741, 1754, 1767, 1780]\n    elif sequence_id == \"1908\":\n        print(\"Using sqequence 1908-1971\")\n        s_frame_id = 1908\n        e_frame_id = 1971  # Inclusive\n        val_frame_ids = [1921, 1934, 1947, 1960]\n    elif sequence_id == \"3353\":\n        print(\"Using sqequence 3353-3416\")\n        s_frame_id = 3353\n        e_frame_id = 3416  # Inclusive\n        val_frame_ids = [3366, 3379, 3392, 3405]\n    else:\n        raise ValueError(f\"Invalid sequence id: {sequence_id}\")\n\n    frame_ids = list(range(s_frame_id, e_frame_id + 1))\n    num_frames = len(frame_ids)\n\n    test_frame_ids = val_frame_ids\n    train_frame_ids = [x for x in frame_ids if x not in val_frame_ids]\n\n    # Load KITTI-360 dataset.\n    k3 = KITTI360Loader(kitti_360_root)\n\n    # Get image paths.\n    cam_00_im_paths = k3.get_image_paths(\"cam_00\", sequence_name, frame_ids)\n    cam_01_im_paths = k3.get_image_paths(\"cam_01\", sequence_name, frame_ids)\n    im_paths = cam_00_im_paths + cam_01_im_paths\n\n    # Get Ks, Ts.\n    cam_00_Ks, cam_00_Ts = k3.load_cameras(\"cam_00\", sequence_name, frame_ids)\n    cam_01_Ks, cam_01_Ts = k3.load_cameras(\"cam_01\", sequence_name, frame_ids)\n    Ks = np.concatenate([cam_00_Ks, cam_01_Ks], axis=0)\n    Ts = np.concatenate([cam_00_Ts, cam_01_Ts], axis=0)\n    # Ts = normalize_Ts(Ts)\n\n    # Get image dimensions, assume all images have the same dimensions.\n    im_rgb = ct.io.imread(cam_00_im_paths[0])\n    im_h, im_w, _ = im_rgb.shape\n\n    # Get lidar paths (range view not raw data).\n    range_view_dir = kitti_360_parent_dir / \"train\"\n    range_view_paths = [\n        range_view_dir / \"{:010d}.npy\".format(int(frame_id)) for frame_id in frame_ids\n    ]\n\n    # Get lidar2world.\n    lidar2world = k3.load_lidars(sequence_name, frame_ids)\n\n    # Get image dimensions, assume all images have the same dimensions.\n    lidar_range_image = np.load(range_view_paths[0])\n    lidar_h, lidar_w, _ = lidar_range_image.shape\n\n    # Split by train/test/val.\n    all_indices = [i - s_frame_id for i in frame_ids]\n    train_indices = [i - s_frame_id for i in train_frame_ids]\n    val_indices = [i - s_frame_id for i in val_frame_ids]\n    test_indices = [i - s_frame_id for i in test_frame_ids]\n\n    # all_indices = all_indices + [i + num_frames for i in all_indices]\n    # train_indices = train_indices + [i + num_frames for i in train_indices]\n    # val_indices = val_indices + [i + num_frames for i in val_indices]\n    # test_indices = test_indices + [i + num_frames for i in test_indices]\n\n    split_to_all_indices = {\n        \"train\": train_indices,\n        \"val\": val_indices,\n        \"test\": test_indices,\n    }\n    for split, indices in split_to_all_indices.items():\n        print(f\"Split {split} has {len(indices)} frames.\")\n        im_paths_split = [im_paths[i] for i in indices]\n        lidar_paths_split = [range_view_paths[i] for i in indices]\n        lidar2world_split = [lidar2world[i] for i in indices]\n        Ks_split = [Ks[i] for i in indices]\n        Ts_split = [Ts[i] for i in indices]\n\n        json_dict = {\n            \"w\": im_w,\n            \"h\": im_h,\n            \"w_lidar\": lidar_w,\n            \"h_lidar\": lidar_h,\n            \"fl_x\": Ks_split[0][0, 0],\n            \"fl_y\": Ks_split[0][1, 1],\n            \"cx\": Ks_split[0][0, 2],\n            \"cy\": Ks_split[0][1, 2],\n            \"aabb_scale\": 2,\n            \"frames\": [\n                {\n                    \"file_path\": str(path.relative_to(kitti_360_parent_dir)),\n                    \"transform_matrix\": ct.convert.T_to_pose(T).tolist(),\n                    \"lidar_file_path\": str(\n                        lidar_path.relative_to(kitti_360_parent_dir)\n                    ),\n                    \"lidar2world\": lidar2world.tolist(),\n                }\n                for (\n                    path,\n                    T,\n                    lidar_path,\n                    lidar2world,\n                ) in zip(\n                    im_paths_split,\n                    Ts_split,\n                    lidar_paths_split,\n                    lidar2world_split,\n                )\n            ],\n        }\n        json_path = kitti_360_parent_dir / f\"transforms_{sequence_id}_{split}.json\"\n\n        with open(json_path, \"w\") as f:\n            json.dump(json_dict, f, indent=2)\n            print(f\"Saved {json_path}.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "preprocess/nerfmvl_loader.py",
    "content": "from pathlib import Path\nimport numpy as np\n\n\nclass NeRFMVLLoader:\n    def __init__(self, nerf_mvl_root, class_name) -> None:\n        # Root directory.\n        self.nerf_mvl_root = Path(nerf_mvl_root)\n        if not self.nerf_mvl_root.is_dir():\n            raise FileNotFoundError(f\"NeRF_MVL {nerf_mvl_root} not found.\")\n\n        # Other directories.\n        self.data_3d_raw_dir = self.nerf_mvl_root / class_name\n        self.lidar2world_path = self.data_3d_raw_dir / \"lidar2world.txt\"\n\n        # Check if all directories exist.\n        if not self.data_3d_raw_dir.is_dir():\n            raise FileNotFoundError(\n                f\"Data 3D raw dir {self.data_3d_raw_dir} not found.\"\n            )\n\n    def _load_all_lidars(\n        self,\n    ):\n        \"\"\"\n        Args:\n\n        Returns:\n            velo_to_world: 4x4 metric.\n        \"\"\"\n\n        velo_to_world_dict = np.loadtxt(self.lidar2world_path)\n        return velo_to_world_dict.reshape(-1, 4, 4)\n\n    def load_lidars(self, frame_ids):\n        \"\"\"\n        Args:\n            frame_ids: list of int, frame ids. e.g. range(1908, 1971+1).\n\n        Returns:\n            velo_to_worlds\n        \"\"\"\n        velo_to_world_dict = self._load_all_lidars()\n        velo_to_worlds = [velo_to_world_dict[frame_id] for frame_id in frame_ids]\n        velo_to_worlds = np.stack(velo_to_worlds)\n        return velo_to_worlds\n\n\ndef main():\n    dataset = NeRFMVLLoader(Path(\"data\") / \"nerf_mvl\" / \"nerf_mvl_7k_pano\", \"pier\")\n    velo_to_world_dict = dataset._load_all_lidars()\n    return\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "preprocess/nerfmvl_to_nerf.py",
    "content": "import os\nfrom nerfmvl_loader import NeRFMVLLoader\nimport numpy as np\nimport json\nfrom pathlib import Path\n\n\ndef main():\n    project_root = Path(__file__).parent.parent\n    nerf_mvl_root = project_root / \"data\" / \"nerf_mvl\" / \"nerf_mvl_7k_pano\"\n    nerf_mvl_parent_dir = nerf_mvl_root.parent\n\n    # Specify frames and splits.\n    train_split = {\n        \"water_safety_barrier\": 2,\n        \"tire\": 2,\n        \"pier\": 2,\n        \"plant\": 2,\n        \"warning_sign\": 2,\n        \"bollard\": 2,\n        \"pedestrian\": 3,\n        \"car\": 3,\n        \"traffic_cone\": 3,\n    }\n\n    for class_name, split_intervel in train_split.items():\n        # Get lidar paths (range view not raw data).\n        range_view_dir = nerf_mvl_root / class_name\n        filenames = os.listdir(range_view_dir)\n        filenames.remove(\"lidar2world.txt\")\n        range_view_paths = [\n            Path(os.path.join(range_view_dir, filename)) for filename in filenames\n        ]\n        num_samples = len(range_view_paths)\n        frame_ids = np.arange(num_samples)\n\n        train_frame_ids = [i for i in range(0, num_samples, split_intervel)]\n        val_frame_ids = [i for i in range(0, num_samples, split_intervel * 20)]\n        test_frame_ids = val_frame_ids\n\n        # Load NeRF_MVL dataset.\n        nerf_mvl_dataset = NeRFMVLLoader(nerf_mvl_root, class_name)\n\n        # Get lidar2world.\n        lidar2world = nerf_mvl_dataset.load_lidars(frame_ids)\n\n        # Get image dimensions, assume all images have the same dimensions.\n        lidar_range_image = np.load(range_view_paths[0])[\"data\"]\n        lidar_h, lidar_w, _ = lidar_range_image.shape\n\n        # Split by train/test/val.\n        all_indices = frame_ids\n        train_indices = train_frame_ids\n        val_indices = val_frame_ids\n        test_indices = test_frame_ids\n\n        split_to_all_indices = {\n            \"train\": train_indices,\n            \"val\": val_indices,\n            \"test\": test_indices,\n        }\n        for split, indices in split_to_all_indices.items():\n            print(f\"Split {split} has {len(indices)} frames.\")\n\n            lidar_paths_split = [range_view_paths[i] for i in indices]\n            lidar2world_split = [lidar2world[i] for i in indices]\n            json_dict = {\n                \"w_lidar\": lidar_w,\n                \"h_lidar\": lidar_h,\n                \"aabb_scale\": 2,\n                \"frames\": [\n                    {\n                        \"lidar_file_path\": str(\n                            lidar_path.relative_to(nerf_mvl_parent_dir)\n                        ),\n                        \"lidar2world\": lidar2world.tolist(),\n                    }\n                    for (\n                        lidar_path,\n                        lidar2world,\n                    ) in zip(\n                        lidar_paths_split,\n                        lidar2world_split,\n                    )\n                ],\n            }\n            json_path = nerf_mvl_parent_dir / f\"transforms_{class_name}_{split}.json\"\n\n            with open(json_path, \"w\") as f:\n                json.dump(json_dict, f, indent=2)\n                print(f\"Saved {json_path}.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "readme.md",
    "content": "<p align=\"center\">\n   <img src=\"./assets/lidar_nerf_logo_640.png\" width=\"480\" />\n</p>\n\n<h1 align=\"center\">LiDAR-NeRF: Novel LiDAR View Synthesis via Neural Radiance Fields</h1>\n<p align=\"center\">\n   <a href=\"https://tangtaogo.github.io/lidar-nerf-website/\">\n      <img src='https://img.shields.io/badge/project_page-url-yellow?style=for-the-badge' alt='Home Page'></a>\n   <a href=\"https://arxiv.org/abs/2304.10406\">\n      <img src='https://img.shields.io/badge/paper-pdf-green?style=for-the-badge' alt='Paper PDF'></a>\n   <a href=\"https://youtu.be/YX4LX025mZQ\">\n      <img src='https://img.shields.io/badge/video-mp4-blue?style=for-the-badge' alt='Video MP4'></a>\n</p>\n<p align=\"center\">\n   <a href=\"https://scholar.google.com.hk/citations?user=1ltylFwAAAAJ&hl=zh-CN&oi=sra\">Tao Tang</a>\n   ·\n   <a href=\"https://damo.alibaba.com/labs/intelligent-transportation\">Longfei Gao</a>\n   ·\n   <a href=\"https://wanggrun.github.io/\">Guangrun Wang</a>\n   ·\n   <a href=\"https://scholar.google.com/citations?user=2w9VSWIAAAAJ&hl=en\">Yixing Lao</a>\n   ·\n   <a href=\"https://damo.alibaba.com/labs/intelligent-transportation\">Peng Chen</a>\n   ·\n   <a href=\"https://hszhao.github.io/\">Hengshuang Zhao</a>\n   ·\n   <a href=\"https://damo.alibaba.com/labs/intelligent-transportation\">Dayang Hao</a>\n   ·\n   <a href=\"https://scholar.google.com/citations?user=voxznZAAAAAJ\">Xiaodan Liang*</a>\n   ·\n   <a href=\"https://scholar.google.com/citations?user=n-B0jr4AAAAJ\">Mathieu Salzmann</a>\n   ·\n   <a href=\"https://scholar.google.com.hk/citations?user=Jtmq_m0AAAAJ&hl=zh-CN&oi=sra\">Kaicheng Yu</a>\n</p>\n\n<p align=\"center\">\n   <a href=\"https://github.com/tangtaogo/lidar-nerf/actions/workflows/formatter.yml\"><img src=\"https://github.com/tangtaogo/lidar-nerf/actions/workflows/formatter.yml/badge.svg\" alt=\"Formatter\"></a>\n</p>\n\n![lidar-nerf](./assets/lidar-nerf.png)\n\n![lidar-nerf-res](./assets/lidar-nerf-res.png)\n\nThis paper introduces a new task of novel LiDAR view synthesis and proposes a\ndifferentiable framework called **LiDAR-NeRF** with a structural regularization,\nas well as an object-centric multi-view LiDAR dataset called **NeRF-MVL**.\n\n1. We formulate the first differentiable framework, LiDAR-NeRF, for novel LiDAR\n   view synthesis, which can render novel point clouds with point intensity and\n   ray-drop probability without explicit 3D reconstruction.\n2. We propose a structural regularization method to effectively preserve local\n   structural details, thereby guiding the model towards more precise geometry\n   estimations, leading to more faithful novel LiDAR view synthesis.\n3. We establish the NeRF-MVL dataset from LiDAR sensors of real autonomous\n   vehicles to evaluate the object-centric novel LiDAR view synthesis.\n4. We demonstrate the effectiveness of our LiDAR-NeRF quantitatively and\n   qualitatively in both scene-level and object-level novel LiDAR view\n   synthesis.\n\n## News\n\n- [2023/07/14] LiDAR-NeRF v0.1.0 released. NeRF-MVL dataset released.\n\n## Installation\n\n```bash\nconda create -n lidarnerf python=3.9\nconda activate lidarnerf\n\n# Dependencies\npip install -r requirements_torch.txt\npip install -r requirements.txt\n\n# tiny-cuda-nn\n# This may take a while, please refer to the official documentation\npip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch\n\n# camtools\npip install git+https://github.com/yxlao/camtools.git\n\n# Install lidar-nerf\npip install -e .\npython -c \"import lidarnerf; print(lidarnerf.__version__)\"\n```\n\n## Dataset\n\n### KITTI-360 dataset\n\nFirst, download KITTI-360 dataset from\n[here](https://www.cvlibs.net/datasets/kitti-360/index.php) and put the dataset\ninto `data/kitti360`. Your folder structure should look like this:\n\n```bash\ndata\n└── kitti360\n    └── KITTI-360\n        ├── calibration\n        ├── data_2d_raw\n        ├── data_3d_raw\n        └── data_poses\n```\n\nNext, run KITTI-360 dataset preprocessing:\n\n```bash\n# Generate train range images\npython preprocess/generate_train_rangeview.py --dataset kitti360\n\n# Generate jsons\npython preprocess/kitti360_to_nerf.py\n\n# Calculate center pose (optional) can directly use our config\npython preprocess/cal_centerpose_bound.py\n```\n\nAfter preprocessing, your folder structure should look like this:\n\n```bash\ndata\n└── kitti360\n    ├── train\n    ├── KITTI-360\n    │   ├── calibration\n    │   ├── data_2d_raw\n    │   ├── data_3d_raw\n    │   └── data_poses\n    ├── transforms_{sequence_id}test.json\n    ├── transforms_{sequence_id}train.json\n    └── transforms_{sequence_id}val.json\n```\n\n### NeRF-MVL dataset\n\nFirst, download our NeRF-MVL dataset from\n[here](https://drive.google.com/drive/folders/1ZCuM3lCvWATXL79WdqrFxbYd4kwsHoTM?usp=sharing).\nYour folder structure should look like this:\n\n```bash\n$ tree data -l -L 2\ndata\n└── nerf_mvl\n    └── nerf_mvl_7k\n        └── {class_name}\n            ├── {frame_id}.npy\n            └── lidar2world.txt\n```\n\nNext, run NeRF-MVL dataset preprocessing:\n\n```bash\n# If you only download raw nerf_mvl_7k, you need convert it to nerf_mvl_7k_pano(optional)\n# or directly download our processed dataset in https://drive.google.com/drive/folders/1pwnIjBUMIYg0fmLaeLj-sKfVcnBexlMq?usp=sharing\n\n# Generate train range images\npython preprocess/generate_train_rangeview.py --dataset nerf_mvl\n\n# Generate jsons\npython preprocess/nerfmvl_to_nerf.py\n```\n\nAfter preprocessing, your folder structure should look like this:\n\n```bash\ndata\n└── nerf_mvl\n    ├── dataset_bbox_7k.npy\n    ├── nerf_mvl_7k\n    │   └── {class_name}\n    │       ├── {frame_id}.npy\n    │       └── lidar2world.txt\n    ├── nerf_mvl_7k_pano\n    │   └── {class_name}\n    │       ├── {frame_id}.npy\n    │       └── lidar2world.txt\n    ├── transforms_{class_name}_test.json\n    ├── transforms_{class_name}_train.json\n    └── transforms_{class_name}_val.json\n```\n\n## Run\n\n```bash\n# kitti360\npython main_lidarnerf.py -L --workspace log/kitti360_lidar\n\n# nerf_mvl\npython main_lidarnerf.py --config configs/nerf_mvl.txt  -L --workspace log/trial_nerf_nerf_mvl\n```\n\n## Pre-trained Models\n\nYou can download our pre-trained models\n[here](https://drive.google.com/drive/folders/1pwnIjBUMIYg0fmLaeLj-sKfVcnBexlMq?usp=sharing).\n\n## Incoming\n\n- [ ] Support multi-modality, e.g., RGB & LiDAR\n- [ ] Support more datasets, e.g, nuScenes, Waymo\n- [ ] Support more implicit geometry representation, e.g., SDF\n\n# Contribution\n\nWe welcome all forms of community contributions, including issues, bug fixes,\nnew features, and more. Please\n[format the code](https://black.readthedocs.io/en/stable/getting_started.html)\nbefore submitting a pull request.\n\n## Citation\n\nIf you find our code or paper helps, please consider citing:\n\n```bibtex\n@article{tao2023lidar,\n    title   = {LiDAR-NeRF: Novel LiDAR View Synthesis via Neural Radiance Fields},\n    author  = {Tao, Tang and Gao, Longfei and Wang, Guangrun and Lao, Yixing and Chen, Peng and Zhao hengshuang and Hao, Dayang and Liang, Xiaodan and Salzmann, Mathieu and Yu, Kaicheng},\n    journal = {arXiv preprint arXiv:2304.10406},\n    year    = {2023}\n}\n```\n\n## Acknowledgments\n\nThis code is built on top of the super-useful\n[torch-ngp](https://github.com/ashawkey/torch-ngp) implementation.\n\n```bibtex\n@misc{torch-ngp,\n    author = {Jiaxiang Tang},\n    year   = {2022},\n    note   = {https://github.com/ashawkey/torch-ngp},\n    title  = {Torch-ngp: a PyTorch implementation of instant-ngp}\n}\n```\n\nThe raydrop-mlp code for PCGen is borrowed from\n[nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch).\n\n```bibtex\n@misc{lin2020nerfpytorch,\n    title        = {NeRF-pytorch},\n    author       = {Yen-Chen, Lin},\n    publisher    = {GitHub},\n    journal      = {GitHub repository},\n    howpublished = {\\url{https://github.com/yenchenlin/nerf-pytorch/}},\n    year         = {2020}\n}\n```\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch-ema\ntorchmetrics\nninja\ntrimesh\nopencv-python\ntensorboardX\nnumpy\npandas\ntqdm\nmatplotlib\nPyMCubes\nrich\npysdf\ndearpygui\npackaging\nscipy\nlpips\nimageio==2.13.0\ntorchmetrics\nimageio-ffmpeg==0.4.8\nopen3d\nconfigargparse\nscikit-image\nnksr\nblack\n\n# nuscenes\nnuscenes-devkit>=1.1.1\npyquaternion\n"
  },
  {
    "path": "requirements_torch.txt",
    "content": "torch==2.0.0\ntorchvision\ntorchaudio\n"
  },
  {
    "path": "setup.py",
    "content": "from pathlib import Path\nfrom setuptools import setup\nimport re\n\n_pwd = Path(__file__).parent.absolute()\n\n\ndef main():\n    cmdclass = dict()\n\n    version = None\n    init_path = _pwd / \"lidarnerf\" / \"__init__.py\"\n    with open(init_path, \"r\", encoding=\"utf-8\") as f:\n        lines = f.readlines()\n        for line in lines:\n            match_res = re.match(r'^__version__ = \"(.*)\"', line)\n            if match_res:\n                version = match_res.group(1)\n                break\n    if version is None:\n        raise RuntimeError(f\"Cannot find version from {init_path}\")\n    print(f\"Detected lidarnerf version: {version}\")\n\n    _ = setup(\n        name=\"lidarnerf\",\n        version=version,\n        description=\"LiDAR-NeRF: Novel LiDAR View Synthesis via Neural Radiance Fields\",\n        packages=[\"lidarnerf\", \"lidarnvs\"],\n        cmdclass=cmdclass,\n        include_package_data=True,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  }
]