[
  {
    "path": "LICENSE",
    "content": "Copyright 2022 Megvii Inc.\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "## One is All: Bridging the Gap Between Neural Radiance Fields Architectures with Progressive Volume Distillation (AAAI Oral)\n\n\n# :partying_face: ***New*** :partying_face: Code for more powerful PVD-AL (the follow-up work of PVD) is now provided [here](https://github.com/megvii-research/AAAI2023-PVD/tree/PVD-AL). \n *(We strongly recommend using [PVD-AL](https://github.com/megvii-research/AAAI2023-PVD/tree/PVD-AL) (the follow-up work of pvd) with better performance).*\n\n\n## [Project Page](http://sk-fun.fun/PVD/) | [Paper](https://arxiv.org/abs/2211.15977) | [Datasets](https://drive.google.com/drive/folders/1U06KAEsW53PolLI3U8hWUhzzIH74QGaP?usp=sharing) | [Ckpts](https://drive.google.com/drive/folders/1GGJf-FTmpCJjmEn-AF_S9-HrLRkFe5Ud?usp=sharing) | [Chinese tutorial](https://github.com/megvii-research/AAAI2023-PVD/blob/main/tools/%E4%B8%AD%E6%96%87%E4%BB%8B%E7%BB%8D.md) | [zhihu](https://zhuanlan.zhihu.com/p/605121286)|\n\n## Introduction\nIn this paper, we propose Progressive Volume Distillation (PVD), a systematic distillation method that allows any-to-any conversions between different neural architectures, including MLP(NeRF), sparse(Plenoxels) or low-rank tensors(TensoRF), hash tables(INGP).\n\n## Installation\nWe recommend using [Anaconda](https://www.anaconda.com/) to setup the environment. Run the following commands:\n\n*Step1*: Create a conda environment named 'pvd'\n```\nconda create --name pvd python=3.7\nconda activate pvd\npip install -r ./tools/requirements.txt\n```\n*Step2*: Install extension modules. 
(These extensions are adapted from [torch-ngp](https://github.com/ashawkey/torch-ngp), the great project this codebase mainly builds on.)\n```\nbash ./tools/install_extensions.sh\n```\n\n## Datasets & Pretrained teacher models\nYou can download the Synthetic-NeRF/LLFF/Tanks&Temples datasets from [google](https://drive.google.com/drive/folders/1U06KAEsW53PolLI3U8hWUhzzIH74QGaP?usp=sharing), or from [baidu](https://pan.baidu.com/s/1ky_TWrbUZG_MpHTBhncAKA?pwd=4h2h).\n\nDownload pretrained teacher models from [google](https://drive.google.com/drive/folders/1GGJf-FTmpCJjmEn-AF_S9-HrLRkFe5Ud?usp=sharing), or from [baidu](https://pan.baidu.com/s/1LGLXwLGusX60GpAywLwosg?pwd=34k8).\n\nYou can also train a teacher model by following the guidance below.\n\n## Train a teacher\n```\n# train a hash-based(INGP) teacher\npython main_just_train_tea.py ./data/nerf_synthetic/chair --model_type hash --data_type synthetic  --workspace ./log/train_teacher/hash_chair\n\n# train a sparse-tensor-based(TensoRF VM-decomposition) teacher\npython main_just_train_tea.py ./data/nerf_synthetic/chair --model_type vm --data_type synthetic  --workspace ./log/train_teacher/vm_chair\n\n# train an MLP-based(NeRF) teacher\npython main_just_train_tea.py ./data/nerf_synthetic/chair --model_type mlp --data_type synthetic  --workspace ./log/train_teacher/mlp_chair\n\n# train a tensors-based(Plenoxels) teacher\npython main_just_train_tea.py ./data/nerf_synthetic/chair --model_type tensors --data_type synthetic  --workspace ./log/train_teacher/tensors_chair\n```\n\n## Distill a student\n```\n# teacher: hash(INGP),  student: vm(TensoRF)\npython3 main_distill_mutual.py  ./data/nerf_synthetic/chair \\\n                    --data_type synthetic \\\n                    --teacher_type hash \\\n                    --ckpt_teacher ./log/train_teacher/hash_chair/checkpoints/XXX.pth \\\n                    --model_type vm \\\n                    --workspace ./log/distill_student/hash2vm/chair\n\n# teacher: MLP(NeRF),  student: 
tensors(Plenoxels)\npython3 main_distill_mutual.py  ./data/nerf_synthetic/chair \\\n                    --data_type synthetic \\\n                    --teacher_type mlp \\\n                    --ckpt_teacher ./log/train_teacher/mlp_chair/checkpoints/XXX.pth \\\n                    --model_type tensors \\\n                    --workspace ./log/distill_student/mlp2tensors/chair\n```\n\n## Evaluation\n\n```\n# evaluate a hash teacher\npython main_distill_mutual.py ./data/nerf_synthetic/chair  --teacher_type hash --ckpt_teacher PATH/TO/CKPT.pth --test_teacher --data_type synthetic --workspace ./log/eval_teacher/hash_chair\n\n# evaluate an mlp student\npython main_distill_mutual.py ./data/nerf_synthetic/chair --model_type mlp --ckpt PATH/TO/CKPT.pth --test --data_type synthetic --workspace ./log/eval_student/mlp_chair\n```\n\n## More detailed parameter descriptions and running commands\nPlease refer to [more running description](https://github.com/megvii-research/AAAI2023-PVD/blob/main/tools/details.md) for details on training different types of datasets, parameter adjustment, key settings, etc.\n\n## Citation\n\nIf you find our code or paper useful, please consider citing:\n```\n@article{fang2022one,\n  title={One is All: Bridging the Gap Between Neural Radiance Fields Architectures with Progressive Volume Distillation},\n  author={Fang, Shuangkang and Xu, Weixin and Wang, Heng and Yang, Yi and Wang, Yufeng and Zhou, Shuchang},\n  journal={arXiv preprint arXiv:2211.15977},\n  year={2022}\n}\n```\n\n### Acknowledgement\nWe would like to thank [ingp](https://github.com/NVlabs/instant-ngp), [torch-ngp](https://github.com/ashawkey/torch-ngp), [TensoRF](https://github.com/apchenstu/TensoRF), [Plenoxels](https://github.com/sxyu/svox2), [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch) for their great frameworks!\n\nAlso check out [Arch-Net](https://github.com/megvii-research/Arch-Net) for more on general progressive distillation.\n"
  },
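The distillation step in the README trains a student architecture to reproduce a frozen teacher's outputs at sampled 3D points. Below is a minimal, self-contained numpy sketch of that supervision signal; it is illustrative only — `teacher_sigma` and the frequency-encoded linear student stand in for the real networks, and the actual pipeline (gradient descent on sigma/color features) lives in `main_distill_mutual.py`.

```python
import numpy as np

# Toy sketch of distillation: query a frozen "teacher" density at random 3D
# points, then fit a student to reproduce those outputs. A linear student over
# a frequency (positional) encoding plus least squares keeps this self-contained;
# all names here are illustrative, not the repo's API.

rng = np.random.default_rng(42)

def teacher_sigma(pts):
    # stand-in for a pretrained teacher's density field
    return np.sin(2 * pts[:, 0]) + 0.5 * np.cos(4 * pts[:, 1]) - 0.3 * np.sin(pts[:, 2])

def features(pts, n_freq=4):
    # frequency (positional) encoding, as an MLP-type student would use
    out = [pts]
    for k in range(n_freq):
        out += [np.sin(2 ** k * pts), np.cos(2 ** k * pts)]
    return np.concatenate(out, axis=-1)

pts = rng.uniform(-1.0, 1.0, size=(2048, 3))  # sampled query points
X, y = features(pts), teacher_sigma(pts)      # student inputs, teacher targets
w, *_ = np.linalg.lstsq(X, y, rcond=None)     # fit the linear student
mse = float(np.mean((X @ w - y) ** 2))        # student-vs-teacher error
```

Because the teacher can be queried densely and noise-free, this kind of fit converges far faster than training from photographs — which is the intuition behind distilling between architectures rather than retraining from scratch.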
  {
    "path": "distill_mutual/network.py",
    "content": "import torch\nfrom time import time\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom tools.encoding import get_encoder\nfrom tools.activation import trunc_exp\nfrom .renderer import NeRFRenderer\nimport raymarching\n\n\nclass NeRFNetwork(NeRFRenderer):\n    def __init__(\n        self,\n        encoding=\"hashgrid\",\n        encoding_dir=\"sphere_harmonics\",\n        encoding_bg=\"hashgrid\",\n        num_layers=2,\n        hidden_dim=64,\n        geo_feat_dim=15,\n        num_layers_color=3,\n        hidden_dim_color=64,\n        num_layers_bg=2,\n        hidden_dim_bg=64,\n        bound=1,\n        model_type=\"hash\",\n        args=None,\n        is_teacher=False,\n        **kwargs,\n    ):\n        super().__init__(bound, **kwargs)\n        # sigma network\n        assert model_type in [\"hash\", \"mlp\", \"vm\", \"tensors\"]\n        self.is_teacher = is_teacher\n        self.num_layers = num_layers\n        self.hidden_dim = hidden_dim\n        self.geo_feat_dim = geo_feat_dim\n        self.args = args\n        self.opt = args\n        self.model_type = model_type\n\n        self.plenoxel_degree = args.plenoxel_degree\n        self.plenoxel_res = eval(args.plenoxel_res)\n\n        assert len(self.plenoxel_res) == 3\n\n        self.encoder, self.in_dim = get_encoder(\n            encoding,\n            desired_resolution=2048 * bound,\n            num_levels=14,\n        )\n\n        if \"hash\" != self.model_type:\n            self.encoder = None\n\n        if self.model_type == \"mlp\":\n            self.encoder_nerf_pe, self.in_dim_nerf = get_encoder(\n                encoding=\"frequency\", multires=self.args.PE\n            )\n            self.skips = self.args.skip\n            self.nerf_layer_num = self.args.nerf_layer_num\n            W = self.args.nerf_layer_wide\n            self.nerf_mlp = [nn.Linear(self.in_dim_nerf, W)]\n            for i in range(self.nerf_layer_num - 2):\n                if i != self.skips:\n     
               self.nerf_mlp.append(nn.Linear(W, W))\n                else:\n                    self.nerf_mlp.append(nn.Linear(W + self.in_dim_nerf, W))\n            self.nerf_mlp.append(nn.Linear(W, self.in_dim))\n            self.nerf_mlp = nn.ModuleList(self.nerf_mlp)\n\n        elif self.model_type == \"vm\":\n            self.sigma_rank = [16] * 3\n            self.color_rank = [48] * 3\n            self.color_feat_dim = 15  # geo_feat_dim\n            self.mat_ids = [[0, 1], [0, 2], [1, 2]]\n            self.vec_ids = [2, 1, 0]\n            self.resolution = [self.opt.resolution0] * 3\n            # mat: paralist[1,16,res0,res0] repeat 3   vec: paralist[1,16,res0,1] repeat 3; repeat3 because decompose 3D grid [H, W, D] to three 2D mat [H, W], [H,D], [W, D] or decompose to three 1D vec [H], [W], [D]\n            self.sigma_mat, self.sigma_vec = self.init_one_vm(\n                self.sigma_rank, self.resolution\n            )\n            # mat: paralist[1,48,res0,res0] repeat 3   vec: paralist[1,48,res0,1] repeat 3\n            self.color_mat, self.color_vec = self.init_one_vm(\n                self.color_rank, self.resolution\n            )\n            # Linear(in_features=144, out_features=27)\n            self.basis_mat = nn.Linear(\n                sum(self.color_rank), self.color_feat_dim, bias=False\n            )\n        elif self.model_type == \"tensors\":\n            self.init_plenoxel_volume(\n                s=0.02,\n                fea_dim=self.plenoxel_degree ** 2 * 3 + 1,\n                volume=self.plenoxel_res,\n            )\n\n        elif self.model_type == \"hash\":\n            pass\n        else:\n            raise ValueError(f\"error model_type:{self.model_type}\")\n\n        if self.model_type != \"vm\" and self.model_type != \"tensors\":\n            sigma_net = []\n            for l in range(num_layers):\n                if l == 0:\n                    in_dim = self.in_dim\n                else:\n                    in_dim = 
hidden_dim\n\n                if l == num_layers - 1:\n                    out_dim = (\n                        1 + self.geo_feat_dim\n                    )  # 1 sigma + 15 SH features for color\n                else:\n                    out_dim = hidden_dim\n\n                sigma_net.append(nn.Linear(in_dim, out_dim, bias=False))\n\n            self.sigma_net = nn.ModuleList(sigma_net)\n\n        # color network\n        self.num_layers_color = num_layers_color\n        self.hidden_dim_color = hidden_dim_color\n        # self.encoder_dir, self.in_dim_dir = get_encoder(encoding=encoding_dir)\n        if self.model_type == \"tensors\":\n            self.encoder_dir, self.in_dim_dir = get_encoder(\n                encoding=\"sphere_harmonics\",\n                degree=self.plenoxel_degree,\n            )\n\n        else:\n            self.encoder_dir, self.in_dim_dir = get_encoder(\n                encoding=encoding_dir, input_dim=3, multires=2\n            )\n\n        if self.model_type != \"tensors\":\n            color_net = []\n            for l in range(num_layers_color):\n                if l == 0:\n                    in_dim = self.in_dim_dir + self.geo_feat_dim\n                else:\n                    in_dim = hidden_dim\n\n                if l == num_layers_color - 1:\n                    out_dim = 3  # 3 rgb\n                else:\n                    out_dim = hidden_dim\n\n                color_net.append(nn.Linear(in_dim, out_dim, bias=False))\n\n            self.color_net = nn.ModuleList(color_net)\n\n        # background network\n        if self.bg_radius > 0:\n            self.num_layers_bg = num_layers_bg\n            self.hidden_dim_bg = hidden_dim_bg\n            self.encoder_bg, self.in_dim_bg = get_encoder(\n                encoding_bg,\n                input_dim=2,\n                num_levels=4,\n                log2_hashmap_size=19,\n                desired_resolution=2048,\n            )  # much smaller hashgrid\n\n            bg_net = 
[]\n            for l in range(num_layers_bg):\n                if l == 0:\n                    in_dim = self.in_dim_bg + self.in_dim_dir\n                else:\n                    in_dim = hidden_dim_bg\n\n                if l == num_layers_bg - 1:\n                    out_dim = 3  # 3 rgb\n                else:\n                    out_dim = hidden_dim_bg\n\n                bg_net.append(nn.Linear(in_dim, out_dim, bias=False))\n\n            self.bg_net = nn.ModuleList(bg_net)\n        else:\n            self.bg_net = None\n\n    def init_plenoxel_volume(self, s=0.1, fea_dim=27 + 1, volume=[128, 128, 128]):\n        tensor = []\n        tensor.append(\n            torch.nn.Parameter(\n                s * torch.randn((1, fea_dim, volume[0], volume[1], volume[2]))\n            )\n        )\n        self.tensor_volume = torch.nn.ParameterList(tensor).cuda()\n\n    def init_one_vm(self, n_component, resolution, scale=0.1):\n        # self.mat_ids = [[0, 1], [0, 2], [1, 2]]  self.vec_ids = [2, 1, 0]\n        mat, vec = [], []\n\n        for i in range(len(self.vec_ids)):\n            vec_id = self.vec_ids[i]\n            mat_id_0, mat_id_1 = self.mat_ids[i]\n            mat.append(\n                nn.Parameter(\n                    scale\n                    * torch.randn(\n                        (1, n_component[i], resolution[mat_id_1], resolution[mat_id_0])\n                    )\n                )\n            )  # [1, R, H, W]\n            vec.append(\n                nn.Parameter(\n                    scale * torch.randn((1, n_component[i], resolution[vec_id], 1))\n                )\n            )  # [1, R, D, 1] (fake 2d to use grid_sample)\n\n        return nn.ParameterList(mat), nn.ParameterList(vec)\n\n    def get_sigma_feat(self, x):\n        # x: [N, 3], in [-1, 1] (outliers will be treated as zero due to grid_sample padding mode)\n        # self.mat_ids = [[0, 1], [0, 2], [1, 2]]  self.vec_ids = [2, 1, 0]\n        N = x.shape[0]\n\n        # plane + 
line basis\n        mat_coord = (\n            torch.stack(\n                (\n                    x[..., self.mat_ids[0]],\n                    x[..., self.mat_ids[1]],\n                    x[..., self.mat_ids[2]],\n                )\n            )\n            .detach()\n            .view(3, -1, 1, 2)\n        )  # [3, N, 1, 2]\n        vec_coord = torch.stack(\n            (x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])\n        )\n        vec_coord = (\n            torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1)\n            .detach()\n            .view(3, -1, 1, 2)\n        )  # [3, N, 1, 2], fake 2d coord\n\n        sigma_feat = torch.zeros(\n            [\n                N,\n            ],\n            device=x.device,\n        )\n\n        for i in range(len(self.sigma_mat)):\n            mat_feat = F.grid_sample(\n                self.sigma_mat[i], mat_coord[[i]], align_corners=True\n            ).view(\n                -1, N\n            )  # [1, R, N, 1] --> [R, N]\n            vec_feat = F.grid_sample(\n                self.sigma_vec[i], vec_coord[[i]], align_corners=True\n            ).view(\n                -1, N\n            )  # [R, N]\n            sigma_feat = sigma_feat + torch.sum(mat_feat * vec_feat, dim=0)\n\n        return sigma_feat\n\n    def get_color_feat(self, x):\n        # x: [N, 3], in [-1, 1]\n        N = x.shape[0]\n\n        # plane + line basis\n        mat_coord = (\n            torch.stack(\n                (\n                    x[..., self.mat_ids[0]],\n                    x[..., self.mat_ids[1]],\n                    x[..., self.mat_ids[2]],\n                )\n            )\n            .detach()\n            .view(3, -1, 1, 2)\n        )  # [3, N, 1, 2]\n        vec_coord = torch.stack(\n            (x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])\n        )\n        vec_coord = (\n            torch.stack((torch.zeros_like(vec_coord), vec_coord), 
dim=-1)\n            .detach()\n            .view(3, -1, 1, 2)\n        )  # [3, N, 1, 2], fake 2d coord\n\n        mat_feat, vec_feat = [], []\n        for i in range(len(self.color_mat)):\n            mat_feat.append(\n                F.grid_sample(\n                    self.color_mat[i], mat_coord[[i]], align_corners=True\n                ).view(-1, N)\n            )  # [1, R, N, 1] --> [R, N]\n            vec_feat.append(\n                F.grid_sample(\n                    self.color_vec[i], vec_coord[[i]], align_corners=True\n                ).view(-1, N)\n            )  # [R, N]\n\n        mat_feat = torch.cat(mat_feat, dim=0)  # [3 * R, N]\n        vec_feat = torch.cat(vec_feat, dim=0)  # [3 * R, N]\n\n        color_feat = self.basis_mat(\n            (mat_feat * vec_feat).T\n        )  # [N, 3R] --> [N, color_feat_dim]\n\n        return color_feat\n\n    def compute_plenoxel_fea(self, x):\n        composed = self.tensor_volume[0]\n        if self.args.enable_edit_plenoxel and self.is_teacher:\n            composed[\n                :, 0, :, 160:, :128\n            ] = -100  # This will erase the bucket in the lego scene for resolution 256\n        composed = (\n            F.grid_sample(composed, x.view(1, 1, -1, 1, 3), align_corners=True)\n            .view(-1, x.shape[0])\n            .permute(1, 0)\n        )\n        return composed  # [N, fea_dim]\n\n    def forward_nerf_mlp(self, x):\n        x = self.encoder_nerf_pe(x)\n        in_pts = x\n        for i in range(len(self.nerf_mlp)):\n            x = self.nerf_mlp[i](x)\n            if i != len(self.nerf_mlp) - 1:\n                x = F.relu(x, inplace=True)\n            if i == self.skips:\n                x = torch.cat([in_pts, x], -1)\n        return x\n\n    def forward(self, x, d):\n        # x: [N, 3], in [-bound, bound]  d: [N, 3], normalized in [-1, 1]\n        # sigma\n        if self.model_type == \"hash\":\n            x = self.encoder(\n                x, bound=self.bound\n            )  # 
out_x[N, 28=num_levels * fea_per_level]\n        elif self.model_type == \"mlp\":\n            x = self.forward_nerf_mlp(x)  # 28\n        elif self.model_type == \"vm\":\n            x = (\n                2\n                * (x - self.aabb_train[:3])\n                / (self.aabb_train[3:] - self.aabb_train[:3])\n                - 1\n            )  # x:[N, 3]\n            sigma_feat = self.get_sigma_feat(x)  # sigma_feat:[N]\n            color_feat = self.get_color_feat(x)  # color_feat:[N, 15]\n            if self.opt.enable_edit_plenoxel:\n                sigma_feat = torch.clamp(sigma_feat, -100, self.args.sigma_clip_max)\n            else:\n                sigma_feat = torch.clamp(\n                    sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max\n                )\n            color_feat = torch.clamp(\n                color_feat, self.args.sigma_clip_min, self.args.sigma_clip_max\n            )\n            self.feature_sigma_color = torch.cat(\n                [sigma_feat.unsqueeze(-1), color_feat], dim=-1\n            )\n            if (\n                self.training\n                and self.args.global_step < self.args.stage_iters[\"stage1\"]\n            ):\n                return None, None\n            self.sigma_l = sigma_feat\n            sigma = trunc_exp(sigma_feat)  # sigma:[N]\n            enc_d = self.encoder_dir(d)  # enc_d:[N, 16]\n            h = torch.cat([enc_d, color_feat], dim=-1)  # h:[N, 16+15]\n            for l in range(self.num_layers_color):\n                h = self.color_net[l](h)\n                if l != self.num_layers_color - 1:\n                    h = F.relu(h, inplace=True)\n\n            color = torch.sigmoid(h)\n            self.color_l = color\n\n            return sigma, color\n        elif self.model_type == \"tensors\":\n            x = (\n                2\n                * (x - self.aabb_train[:3])\n                / (self.aabb_train[3:] - self.aabb_train[:3])\n                - 1\n            
)  # x:[N, 3]\n            x = self.compute_plenoxel_fea(x)\n            h = x\n            if self.opt.enable_edit_plenoxel:\n                sigma = torch.clamp(h[..., 0], -100, self.args.sigma_clip_max)\n            else:\n                sigma = torch.clamp(\n                    h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max\n                )\n            self.sigma_l = sigma\n            sigma = trunc_exp(sigma)\n            self.sigma = sigma\n            sh = h[..., 1:].view(\n                -1, 3, self.plenoxel_degree ** 2\n            )  # [N, 3, 9]   ## .permute(1, 0, 2)  # [B, 27]-->[9, B, 3]\n            enc_d = self.encoder_dir(d).unsqueeze(1)  # [N, 9]-->[N,1,9]\n            color = (sh * enc_d).sum(-1)  # [N, 3]\n            color = torch.sigmoid(color)\n            self.feature_sigma_color = None\n            self.color_l = color\n            return sigma, color\n        else:\n            raise ValueError(f\"illegal model_type: {self.model_type}\")\n\n        h = x\n        for l in range(self.num_layers):\n            h = self.sigma_net[l](h)\n            if l != self.num_layers - 1:\n                h = F.relu(h, inplace=True)\n        h[..., 0] = torch.clamp(\n            h[..., 0].clone(), self.args.sigma_clip_min, self.args.sigma_clip_max\n        )\n        self.feature_sigma_color = h\n        if self.training and self.args.global_step < self.args.stage_iters[\"stage1\"]:\n            return None, None\n        self.sigma_l = h[..., 0]\n        sigma = trunc_exp(h[..., 0])  # sigma: [n]\n        geo_feat = h[..., 1:]  # geo_feat: [n, 15]\n\n        d = self.encoder_dir(d)  # d: [n, 16]\n        h = torch.cat([d, geo_feat], dim=-1)  # h: [n, 16+15]\n        for l in range(self.num_layers_color):\n            h = self.color_net[l](h)\n            if l != self.num_layers_color - 1:\n                h = F.relu(h, inplace=True)\n\n        color = torch.sigmoid(h)\n        self.color_l = color\n        return sigma, color\n\n   
 def density(self, x):\n        # x: [N, 3], in [-bound, bound]\n        if self.model_type == \"hash\":\n            x = self.encoder(\n                x, bound=self.bound\n            )  # out_x[N, 28=num_levels * fea_per_level]\n        elif self.model_type == \"mlp\":\n            x = self.forward_nerf_mlp(x)\n        elif self.model_type == \"vm\":\n            x = (\n                2\n                * (x - self.aabb_train[:3])\n                / (self.aabb_train[3:] - self.aabb_train[:3])\n                - 1\n            )\n            sigma_feat = self.get_sigma_feat(x)\n            sigma_feat = torch.clamp(\n                sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max\n            )\n            sigma = trunc_exp(sigma_feat)\n            return {\"sigma\": sigma}\n        elif self.model_type == \"tensors\":\n            x = (\n                2\n                * (x - self.aabb_train[:3])\n                / (self.aabb_train[3:] - self.aabb_train[:3])\n                - 1\n            )  # x:[N, 3]\n            x = self.compute_plenoxel_fea(x)\n            h = x\n            # clamp before activation, matching forward()\n            sigma = trunc_exp(\n                torch.clamp(\n                    h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max\n                )\n            )\n            return {\"sigma\": sigma}\n\n        else:\n            raise ValueError(f\"illegal model_type: {self.model_type}\")\n\n        h = x\n        for l in range(self.num_layers):\n            h = self.sigma_net[l](h)\n            if l != self.num_layers - 1:\n                h = F.relu(h, inplace=True)\n\n        h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max)\n        sigma = trunc_exp(h[..., 0])\n        geo_feat = h[..., 1:]\n\n        return {\n            \"sigma\": sigma,\n            \"geo_feat\": geo_feat,\n        }\n\n    def 
background(self, x, d):\n        raise NotImplementedError(\"background() is not used in this codebase\")\n        # x: [N, 2], in [-1, 1]\n\n        h = self.encoder_bg(x)  # [N, C]\n        d = self.encoder_dir(d)\n\n        h = torch.cat([d, h], dim=-1)\n        for l in range(self.num_layers_bg):\n            h = self.bg_net[l](h)\n            if l != self.num_layers_bg - 1:\n                h = F.relu(h, inplace=True)\n\n        # sigmoid activation for rgb\n        rgbs = torch.sigmoid(h)\n\n        return rgbs\n\n    # allow masked inference\n    def color(self, x, d, mask=None, geo_feat=None, **kwargs):\n        raise NotImplementedError(\"color() is not used in this codebase\")\n        # x: [N, 3] in [-bound, bound]\n        # mask: [N,], bool, indicates where we actually need to compute rgb.\n\n        if mask is not None:\n            rgbs = torch.zeros(\n                mask.shape[0], 3, dtype=x.dtype, device=x.device\n            )  # [N, 3]\n            # in case of empty mask\n            if not mask.any():\n                return rgbs\n            x = x[mask]\n            d = d[mask]\n            geo_feat = geo_feat[mask]\n\n        d = self.encoder_dir(d)\n        h = torch.cat([d, geo_feat], dim=-1)\n        for l in range(self.num_layers_color):\n            h = self.color_net[l](h)\n            if l != self.num_layers_color - 1:\n                h = F.relu(h, inplace=True)\n\n        # sigmoid activation for rgb\n        h = torch.sigmoid(h)\n\n        if mask is not None:\n            rgbs[mask] = h.to(rgbs.dtype)  # fp16 --> fp32\n        else:\n            rgbs = h\n\n        return rgbs\n\n    # L1 penalty for loss\n    def density_loss(self):\n        loss = 0\n        for i in range(len(self.sigma_mat)):\n            loss = (\n                loss\n                + torch.mean(torch.abs(self.sigma_mat[i]))\n                + torch.mean(torch.abs(self.sigma_vec[i]))\n            )\n        return loss\n\n    # upsample utils\n    @torch.no_grad()\n    def upsample_params(self, mat, vec, resolution):\n\n        for i in range(len(self.vec_ids)):\n    
        vec_id = self.vec_ids[i]\n            mat_id_0, mat_id_1 = self.mat_ids[i]\n            mat[i] = nn.Parameter(\n                F.interpolate(\n                    mat[i].data,\n                    size=(resolution[mat_id_1], resolution[mat_id_0]),\n                    mode=\"bilinear\",\n                    align_corners=True,\n                )\n            )\n            vec[i] = nn.Parameter(\n                F.interpolate(\n                    vec[i].data,\n                    size=(resolution[vec_id], 1),\n                    mode=\"bilinear\",\n                    align_corners=True,\n                )\n            )\n\n    @torch.no_grad()\n    def upsample_model(self, resolution):\n        self.upsample_params(self.sigma_mat, self.sigma_vec, resolution)\n        self.upsample_params(self.color_mat, self.color_vec, resolution)\n        self.resolution = resolution\n\n    @torch.no_grad()\n    def shrink_model(self):\n        # shrink aabb_train and the model so it only represents the space inside aabb_train.\n\n        half_grid_size = self.bound / self.grid_size\n        thresh = min(self.density_thresh, self.mean_density)\n\n        # get new aabb from the coarsest density grid (TODO: from the finest that covers current aabb?)\n        valid_grid = self.density_grid[self.cascade - 1] > thresh  # [N]\n        valid_pos = raymarching.morton3D_invert(\n            torch.nonzero(valid_grid)\n        )  # [Nz] --> [Nz, 3], in [0, H - 1]\n        # plot_pointcloud(valid_pos.detach().cpu().numpy()) # lots of noisy outliers in hashnerf...\n        valid_pos = (2 * valid_pos / (self.grid_size - 1) - 1) * (\n            self.bound - half_grid_size\n        )  # [Nz, 3], in [-b+hgs, b-hgs]\n        min_pos = valid_pos.amin(0) - half_grid_size  # [3]\n        max_pos = valid_pos.amax(0) + half_grid_size  # [3]\n\n        # shrink model\n        reso = torch.LongTensor(self.resolution).to(self.aabb_train.device)\n        units = (self.aabb_train[3:] - 
self.aabb_train[:3]) / reso\n        tl = (min_pos - self.aabb_train[:3]) / units\n        br = (max_pos - self.aabb_train[:3]) / units\n        tl = torch.round(tl).long().clamp(min=0)\n        br = torch.minimum(torch.round(br).long(), reso)\n\n        for i in range(len(self.vec_ids)):\n            vec_id = self.vec_ids[i]\n            mat_id_0, mat_id_1 = self.mat_ids[i]\n\n            self.sigma_vec[i] = nn.Parameter(\n                self.sigma_vec[i].data[..., tl[vec_id] : br[vec_id], :]\n            )\n            self.color_vec[i] = nn.Parameter(\n                self.color_vec[i].data[..., tl[vec_id] : br[vec_id], :]\n            )\n\n            self.sigma_mat[i] = nn.Parameter(\n                self.sigma_mat[i].data[\n                    ..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0]\n                ]\n            )\n            self.color_mat[i] = nn.Parameter(\n                self.color_mat[i].data[\n                    ..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0]\n                ]\n            )\n\n        self.aabb_train = torch.cat([min_pos, max_pos], dim=0)  # [6]\n\n        print(\n            f\"[INFO] shrink slice: {tl.cpu().numpy().tolist()} - {br.cpu().numpy().tolist()}\"\n        )\n        print(f\"[INFO] new aabb: {self.aabb_train.cpu().numpy().tolist()}\")\n\n    # optimizer utils\n    def get_params(self, lr, lr2=1e-3):\n        if self.model_type == \"hash\":\n            params = [\n                {\"params\": self.encoder.parameters(), \"lr\": lr},\n                {\"params\": self.sigma_net.parameters(), \"lr\": lr},\n                {\"params\": self.encoder_dir.parameters(), \"lr\": lr},\n                {\"params\": self.color_net.parameters(), \"lr\": lr},\n            ]\n        elif self.model_type == \"mlp\":\n            params = [\n                {\"params\": self.sigma_net.parameters(), \"lr\": lr},\n                {\"params\": self.encoder_dir.parameters(), \"lr\": lr},\n             
   {\"params\": self.color_net.parameters(), \"lr\": lr},\n                {\"params\": self.nerf_mlp.parameters(), \"lr\": lr},\n            ]\n        elif self.model_type == \"vm\":\n            params = [\n                {\"params\": self.color_net.parameters(), \"lr\": lr2},\n                {\"params\": self.sigma_mat, \"lr\": lr},\n                {\"params\": self.sigma_vec, \"lr\": lr},\n                {\"params\": self.color_mat, \"lr\": lr},\n                {\"params\": self.color_vec, \"lr\": lr},\n                {\"params\": self.basis_mat.parameters(), \"lr\": lr2},\n            ]\n        elif self.model_type == \"tensors\":\n            params = [\n                {\"params\": self.tensor_volume.parameters(), \"lr\": lr},\n                {\"params\": self.encoder_dir.parameters(), \"lr\": lr},\n            ]\n\n        else:\n            raise ValueError(f\"not illegal model_type:{self.model_type}\")\n\n        if self.bg_radius > 0:\n            params.append({\"params\": self.encoder_bg.parameters(), \"lr\": lr})\n            params.append({\"params\": self.bg_net.parameters(), \"lr\": lr})\n\n        return params\n"
  },
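The "vm" branch of `NeRFNetwork` above stores the density grid as three (plane matrix, line vector) pairs and sums `matrix[plane projection] * vector[remaining axis]` per query, sampled via `F.grid_sample`. A minimal numpy sketch of that vector-matrix low-rank lookup, using integer grid coordinates in place of interpolated sampling (ranks, resolutions, and names here are illustrative):

```python
import numpy as np

# Vector-matrix (VM) factorization sketch: a dense [res, res, res] tensor is
# approximated by three rank-R (plane, line) pairs; a point's density feature
# is the sum over pairs of plane[proj] * line[axis].

rng = np.random.default_rng(0)
R, res = 4, 32                      # rank per component, grid resolution
mat_ids = [(0, 1), (0, 2), (1, 2)]  # plane axes, as in network.py
vec_ids = [2, 1, 0]                 # complementary line axis per component

mats = [rng.standard_normal((R, res, res)) for _ in range(3)]
vecs = [rng.standard_normal((R, res)) for _ in range(3)]

def sigma_feat(idx):
    """Density feature at integer grid coordinate idx = (i, j, k)."""
    total = 0.0
    for (a, b), c, M, v in zip(mat_ids, vec_ids, mats, vecs):
        total += float(np.sum(M[:, idx[a], idx[b]] * v[:, idx[c]]))
    return total

# Dense reconstruction of the same tensor, to sanity-check the factored lookup.
dense = np.zeros((res, res, res))
for (a, b), c, M, v in zip(mat_ids, vec_ids, mats, vecs):
    for r in range(R):
        outer = np.tensordot(M[r], v[r], axes=0)       # axes ordered (a, b, c)
        perm = [(a, b, c).index(k) for k in range(3)]  # route onto grid axes 0,1,2
        dense += np.transpose(outer, perm)

n_dense = res ** 3                                     # full-grid parameter count
n_factored = sum(m.size for m in mats) + sum(v.size for v in vecs)
```

The point of the decomposition is the parameter count: the factors grow as O(res²) rather than O(res³), which is why `upsample_model` can cheaply interpolate them to higher resolutions during training.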
  {
    "path": "distill_mutual/provider.py",
    "content": "import os\nimport cv2\nimport glob\nimport json\nimport tqdm\nimport numpy as np\nfrom scipy.spatial.transform import Slerp, Rotation\n\nimport trimesh\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom .utils import get_rays, srgb_to_linear\n\n\n# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50\ndef nerf_matrix_to_ngp(pose, scale=0.33):\n    # for the fox dataset, 0.33 scales camera radius to ~ 2\n    new_pose = np.array(\n        [\n            [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale],\n            [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale],\n            [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale],\n            [0, 0, 0, 1],\n        ],\n        dtype=np.float32,\n    )\n    return new_pose\n\n\ndef rand_poses(\n    size,\n    device,\n    radius=1,\n    theta_range=[np.pi / 3, 2 * np.pi / 3],\n    phi_range=[0, 2 * np.pi],\n):\n    \"\"\"generate random poses from an orbit camera\n    Args:\n        size: batch size of generated poses.\n        device: where to allocate the output.\n        radius: camera radius\n        theta_range: [min, max], should be in [0, \\pi]\n        phi_range: [min, max], should be in [0, 2\\pi]\n    Return:\n        poses: [size, 4, 4]\n    \"\"\"\n\n    def normalize(vectors):\n        return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)\n\n    thetas = (\n        torch.rand(size, device=device) * (theta_range[1] - theta_range[0])\n        + theta_range[0]\n    )\n    phis = (\n        torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]\n    )\n\n    centers = torch.stack(\n        [\n            radius * torch.sin(thetas) * torch.sin(phis),\n            radius * torch.cos(thetas),\n            radius * torch.sin(thetas) * torch.cos(phis),\n        ],\n        dim=-1,\n    )  # [B, 3]\n\n    # lookat\n    forward_vector 
= -normalize(centers)\n    up_vector = (\n        torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)\n    )  # confused at the coordinate system...\n    right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1))\n    up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1))\n\n    poses = (\n        torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)\n    )\n    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)\n    poses[:, :3, 3] = centers\n\n    return poses\n\n\nclass NeRFDataset:\n    def __init__(self, opt, device, type=\"train\", downscale=1, n_test=10):\n        super().__init__()\n\n        
self.opt = opt\n        self.args = opt\n        self.device = device\n        self.type = type  # train, val, test\n        self.downscale = downscale\n        self.root_path = opt.path\n        self.mode = opt.mode  # only 'blender' mode is supported\n        self.preload = opt.preload  # preload data into GPU\n        self.scale = (\n            opt.scale\n        )  # camera radius scale to make sure cameras are inside the bounding box.\n        self.bound = (\n            opt.bound\n        )  # bounding box half length; also used as the radius to randomly sample poses.\n        self.fp16 = opt.fp16  # if preload, load into fp16.\n\n        self.training = self.type in [\"train\", \"all\", \"trainval\"]\n        self.num_rays = self.opt.num_rays if self.training else -1\n\n        if self.mode == \"blender\":\n            if type == \"all\":\n                transform_paths = glob.glob(os.path.join(self.root_path, \"*.json\"))\n                transform = None\n                for transform_path in transform_paths:\n                    with open(transform_path, \"r\") as f:\n                        tmp_transform = json.load(f)\n                        if transform is None:\n                            transform = tmp_transform\n                        else:\n                            transform[\"frames\"].extend(tmp_transform[\"frames\"])\n            # load train and val split\n            elif type == \"trainval\":\n                with open(\n                    os.path.join(self.root_path, \"transforms_train.json\"), \"r\"\n                ) as f:\n                    transform = json.load(f)\n                with open(\n                    os.path.join(self.root_path, \"transforms_val.json\"), \"r\"\n                ) as f:\n                    transform_val = json.load(f)\n                transform[\"frames\"].extend(transform_val[\"frames\"])\n            # only load one specified split\n            else:\n                with open(\n                    
os.path.join(self.root_path, f\"transforms_{type}.json\"), \"r\"\n                ) as f:\n                    transform = json.load(f)\n\n        else:\n            raise NotImplementedError(f\"unknown dataset mode: {self.mode}\")\n\n        # load image size\n        if \"h\" in transform and \"w\" in transform:\n            self.H = int(transform[\"h\"]) // downscale\n            self.W = int(transform[\"w\"]) // downscale\n        else:\n            # we have to actually read an image to get H and W later.\n            self.H = self.W = None\n        # read images\n        frames = transform[\"frames\"]\n        if True:\n            self.poses = []\n            self.images = []\n            for f in tqdm.tqdm(frames, desc=f\"Loading {type} data:\"):\n                f_path = os.path.join(self.root_path, f[\"file_path\"])\n                if (\n                    self.mode == \"blender\"\n                    and f_path[-4:].lower() != \".png\"\n                    and f_path[-4:].lower() != \".jpg\"\n                ):\n                    f_path += \".png\"  # blender file_path may omit the extension\n                if not os.path.exists(f_path):\n                    continue\n                pose = np.array(f[\"transform_matrix\"], dtype=np.float32)  # [4, 4]\n                pose = nerf_matrix_to_ngp(pose, scale=self.scale)\n\n                image = cv2.imread(\n                    f_path, cv2.IMREAD_UNCHANGED\n                )  # [H, W, 3] or [H, W, 4]\n                if self.H is None or self.W is None:\n                    self.H = image.shape[0] // downscale\n                    self.W = image.shape[1] // downscale\n\n                # add support for the alpha channel as a mask.\n                if image.shape[-1] == 3:\n                    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n                else:\n                    image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)\n\n                if image.shape[0] != self.H or image.shape[1] != self.W:\n                    
image = cv2.resize(\n                        image, (self.W, self.H), interpolation=cv2.INTER_AREA\n                    )\n\n                image = image.astype(np.float32) / 255  # [H, W, 3/4]\n\n                self.poses.append(pose)\n                self.images.append(image)\n        self.poses = torch.from_numpy(np.stack(self.poses, axis=0))  # [N, 4, 4]\n        if self.images is not None:\n            self.images = torch.from_numpy(\n                np.stack(self.images, axis=0)\n            )  # [N, H, W, C]\n        self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item()\n\n        if self.training and self.opt.error_map:\n            self.error_map = torch.ones(\n                [self.images.shape[0], 128 * 128], dtype=torch.float\n            )  # [B, 128 * 128], flattened for easy indexing, fixed resolution...\n        else:\n            self.error_map = None\n\n        if self.preload:\n            self.poses = self.poses.to(self.device)\n            if self.images is not None:\n                if self.fp16 and self.opt.color_space != \"linear\":\n                    dtype = torch.half\n                else:\n                    dtype = torch.float\n                self.images = self.images.to(dtype).to(self.device)\n            if self.error_map is not None:\n                self.error_map = self.error_map.to(self.device)\n\n        # load intrinsics\n        if \"fl_x\" in transform or \"fl_y\" in transform:\n            fl_x = (\n                transform[\"fl_x\"] if \"fl_x\" in transform else transform[\"fl_y\"]\n            ) / downscale\n            fl_y = (\n                transform[\"fl_y\"] if \"fl_y\" in transform else transform[\"fl_x\"]\n            ) / downscale\n        elif \"camera_angle_x\" in transform or \"camera_angle_y\" in transform:\n            # blender, assert in radians. 
already downscaled since we use H/W\n            fl_x = (\n                self.W / (2 * np.tan(transform[\"camera_angle_x\"] / 2))\n                if \"camera_angle_x\" in transform\n                else None\n            )\n            fl_y = (\n                self.H / (2 * np.tan(transform[\"camera_angle_y\"] / 2))\n                if \"camera_angle_y\" in transform\n                else None\n            )\n            if fl_x is None:\n                fl_x = fl_y\n            if fl_y is None:\n                fl_y = fl_x\n        else:\n            raise RuntimeError(\n                \"Failed to load focal length, please check the transforms.json!\"\n            )\n\n        cx = (transform[\"cx\"] / downscale) if \"cx\" in transform else (self.W / 2)\n        cy = (transform[\"cy\"] / downscale) if \"cy\" in transform else (self.H / 2)\n\n        self.intrinsics = np.array([fl_x, fl_y, cx, cy])\n\n    def collate(self, index):\n\n        B = len(index)  # index is a list of length 1 (batch_size is 1)\n        poses = self.poses[index].to(self.device)  # [B, 4, 4]\n\n        error_map = None if self.error_map is None else self.error_map[index]\n        rays = get_rays(\n            poses, self.intrinsics, self.H, self.W, self.num_rays, error_map\n        )\n        results = {\n            \"H\": self.H,\n            \"W\": self.W,\n            \"rays_o\": rays[\"rays_o\"],\n            \"rays_d\": rays[\"rays_d\"],\n        }\n\n        if self.images is not None:\n            images = self.images[index].to(self.device)  # [B, H, W, 3/4]\n            if self.training:\n                C = images.shape[-1]\n                images = torch.gather(\n                    images.view(B, -1, C), 1, torch.stack(C * [rays[\"inds\"]], -1)\n                )  # [B, N, 3/4]\n            results[\"images\"] = images\n\n        # need inds to update error_map\n        if error_map is not None:\n            results[\"index\"] = index\n            results[\"inds_coarse\"] = rays[\"inds_coarse\"]\n\n        return results\n\n    def dataloader(self):\n        size = len(self.poses)\n        loader = DataLoader(\n            list(range(size)),\n            batch_size=1,\n            collate_fn=self.collate,\n            shuffle=self.training,\n            num_workers=0,\n        )\n        loader._data = self  # keep a reference so the trainer can access dataset attributes\n        return loader\n"
  },
  {
    "path": "distill_mutual/renderer.py",
"content": "import math\nimport trimesh\nimport numpy as np\nfrom time import time\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport raymarching\nfrom .utils import custom_meshgrid\n\n\ndef sample_pdf(bins, weights, n_samples, det=False):\n    # This implementation is from NeRF\n    # bins: [B, T], old_z_vals\n    # weights: [B, T - 1], bin weights.\n    # return: [B, n_samples], new_z_vals\n\n    # Get pdf\n    weights = weights + 1e-5  # prevent nans\n    pdf = weights / torch.sum(weights, -1, keepdim=True)\n    cdf = torch.cumsum(pdf, -1)\n    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)\n    # Take uniform samples\n    if det:\n        u = torch.linspace(\n            0.0 + 0.5 / n_samples, 1.0 - 0.5 / n_samples, steps=n_samples\n        ).to(weights.device)\n        u = u.expand(list(cdf.shape[:-1]) + [n_samples])\n    else:\n        u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)\n\n    # Invert CDF\n    u = u.contiguous()\n    inds = torch.searchsorted(cdf, u, right=True)\n    below = torch.max(torch.zeros_like(inds - 1), inds - 1)\n    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)\n    inds_g = torch.stack([below, above], -1)  # (B, n_samples, 2)\n\n    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]\n    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)\n    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)\n\n    denom = cdf_g[..., 1] - cdf_g[..., 0]\n    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)\n    t = (u - cdf_g[..., 0]) / denom\n    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])\n\n    return samples\n\n\ndef plot_pointcloud(pc, color=None):\n    # pc: [N, 3]\n    # color: [N, 3/4]\n    print(\"[visualize points]\", pc.shape, pc.dtype, pc.min(0), pc.max(0))\n    pc = trimesh.PointCloud(pc, color)\n    # axis\n    axes = 
trimesh.creation.axis(axis_length=4)\n    # sphere\n    sphere = trimesh.creation.icosphere(radius=1)\n    trimesh.Scene([pc, axes, sphere]).show()\n\n\nclass NeRFRenderer(nn.Module):\n    def __init__(\n        self,\n        bound=1,\n        cuda_ray=False,\n        density_scale=1,  # scale up deltas (or sigmas) to sharpen the density grid; values larger than 1 usually improve performance.\n        min_near=0.2,\n        density_thresh=0.01,\n        bg_radius=-1,\n        grid_size=128,\n    ):\n        super().__init__()\n\n        self.bound = bound\n        self.cascade = 1 + math.ceil(math.log2(bound))\n        self.grid_size = grid_size\n        self.density_scale = density_scale\n        self.min_near = min_near\n        self.density_thresh = density_thresh\n        self.bg_radius = bg_radius  # radius of the background sphere.\n\n        # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)\n        # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.\n        aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound])\n        aabb_infer = aabb_train.clone()\n        self.register_buffer(\"aabb_train\", aabb_train)\n        self.register_buffer(\"aabb_infer\", aabb_infer)\n\n        # extra state for cuda raymarching\n        self.cuda_ray = cuda_ray\n        if cuda_ray:\n            # density grid\n            density_grid = torch.zeros(\n                [self.cascade, self.grid_size ** 3]\n            )  # [CAS, H * H * H]\n            density_bitfield = torch.zeros(\n                self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8\n            )  # [CAS * H * H * H // 8]\n            self.register_buffer(\"density_grid\", density_grid)\n            self.register_buffer(\"density_bitfield\", density_bitfield)\n            
self.mean_density = 0\n            self.iter_density = 0\n            # step counter\n            step_counter = torch.zeros(\n                16, 2, dtype=torch.int32\n            )  # 16 is hardcoded for averaging...\n            self.register_buffer(\"step_counter\", step_counter)\n            self.mean_count = 0\n            self.local_step = 0\n\n    def forward(self, x, d):\n        raise NotImplementedError()\n\n    # separated density and color query (can accelerate non-cuda-ray mode.)\n    def density(self, x):\n        raise NotImplementedError()\n\n    def color(self, x, d, mask=None, **kwargs):\n        raise NotImplementedError()\n\n    def reset_extra_state(self):\n        if not self.cuda_ray:\n            return\n        # density grid\n        self.density_grid.zero_()\n        self.mean_density = 0\n        self.iter_density = 0\n        # step counter\n        self.step_counter.zero_()\n        self.mean_count = 0\n        self.local_step = 0\n\n    def run(\n        self,\n        rays_o,\n        rays_d,\n        num_steps=128,\n        upsample_steps=128,\n        bg_color=None,\n        perturb=False,\n        **kwargs\n    ):\n        # rays_o, rays_d: [B, N, 3], assumes B == 1\n        # bg_color: [3] in range [0, 1]\n        # return: image: [B, N, 3], depth: [B, N]\n        prefix = rays_o.shape[:-1]\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        N = rays_o.shape[0]  # N = B * N, in fact\n        device = rays_o.device\n\n        # choose aabb\n        aabb = self.aabb_train if self.training else self.aabb_infer\n\n        # sample steps\n        nears, fars = raymarching.near_far_from_aabb(\n            rays_o, rays_d, aabb, self.min_near\n        )\n        nears.unsqueeze_(-1)\n        fars.unsqueeze_(-1)\n\n        # print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')\n\n        z_vals = torch.linspace(0.0, 1.0, 
num_steps, device=device).unsqueeze(\n            0\n        )  # [1, T]\n        z_vals = z_vals.expand((N, num_steps))  # [N, T]\n        z_vals = nears + (fars - nears) * z_vals  # [N, T], in [nears, fars]\n\n        # perturb z_vals\n        sample_dist = (fars - nears) / num_steps\n        if perturb:\n            z_vals = (\n                z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist\n            )\n            # z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.\n\n        # generate xyzs\n        xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(\n            -1\n        )  # [N, 1, 3] * [N, T, 1] -> [N, T, 3]\n        xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:])  # a manual clip.\n\n        # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())\n\n        # query SDF and RGB\n        density_outputs = self.density(xyzs.reshape(-1, 3))\n\n        # sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]\n        for k, v in density_outputs.items():\n            density_outputs[k] = v.view(N, num_steps, -1)\n\n        # upsample z_vals (nerf-like)\n        if upsample_steps > 0:\n            with torch.no_grad():\n\n                deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T-1]\n                deltas = torch.cat(\n                    [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1\n                )\n\n                alphas = 1 - torch.exp(\n                    -deltas * self.density_scale * density_outputs[\"sigma\"].squeeze(-1)\n                )  # [N, T]\n                alphas_shifted = torch.cat(\n                    [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1\n                )  # [N, T+1]\n                weights = (\n                    alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]\n                )  # [N, T]\n\n                # sample new z_vals\n                z_vals_mid = z_vals[..., :-1] + 0.5 * 
deltas[..., :-1]  # [N, T-1]\n                new_z_vals = sample_pdf(\n                    z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training\n                ).detach()  # [N, t]\n\n                new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(\n                    -2\n                ) * new_z_vals.unsqueeze(\n                    -1\n                )  # [N, 1, 3] * [N, t, 1] -> [N, t, 3]\n                new_xyzs = torch.min(\n                    torch.max(new_xyzs, aabb[:3]), aabb[3:]\n                )  # a manual clip.\n\n            # only forward new points to save computation\n            new_density_outputs = self.density(new_xyzs.reshape(-1, 3))\n            # new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]\n            for k, v in new_density_outputs.items():\n                new_density_outputs[k] = v.view(N, upsample_steps, -1)\n\n            # re-order\n            z_vals = torch.cat([z_vals, new_z_vals], dim=1)  # [N, T+t]\n            z_vals, z_index = torch.sort(z_vals, dim=1)\n\n            xyzs = torch.cat([xyzs, new_xyzs], dim=1)  # [N, T+t, 3]\n            xyzs = torch.gather(\n                xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs)\n            )\n\n            for k in density_outputs:\n                tmp_output = torch.cat(\n                    [density_outputs[k], new_density_outputs[k]], dim=1\n                )\n                density_outputs[k] = torch.gather(\n                    tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output)\n                )\n\n        deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T+t-1]\n        deltas = torch.cat(\n            [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1\n        )\n        alphas = 1 - torch.exp(\n            -deltas * self.density_scale * density_outputs[\"sigma\"].squeeze(-1)\n        )  # [N, T+t]\n        alphas_shifted = torch.cat(\n            [torch.ones_like(alphas[..., 
:1]), 1 - alphas + 1e-15], dim=-1\n        )  # [N, T+t+1]\n        weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T+t]\n\n        dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)\n        for k, v in density_outputs.items():\n            density_outputs[k] = v.view(-1, v.shape[-1])\n\n        mask = weights > 1e-4  # hard-coded threshold\n        rgbs = self.color(\n            xyzs.reshape(-1, 3),\n            dirs.reshape(-1, 3),\n            mask=mask.reshape(-1),\n            **density_outputs\n        )\n        rgbs = rgbs.view(N, -1, 3)  # [N, T+t, 3]\n\n        # print(xyzs.shape, 'valid_rgb:', mask.sum().item())\n\n        # calculate weight_sum (mask)\n        weights_sum = weights.sum(dim=-1)  # [N]\n\n        # calculate depth\n        ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)\n        depth = torch.sum(weights * ori_z_vals, dim=-1)\n\n        # calculate color\n        image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2)  # [N, 3], in [0, 1]\n\n        # mix background color\n        if self.bg_radius > 0:\n            # use the bg model to calculate bg_color\n            polar = raymarching.polar_from_ray(\n                rays_o, rays_d, self.bg_radius\n            )  # [N, 2] in [-1, 1]\n            bg_color = self.background(polar, rays_d.reshape(-1, 3))  # [N, 3]\n        elif bg_color is None:\n            bg_color = 1\n\n        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color\n\n        image = image.view(*prefix, 3)\n        depth = depth.view(*prefix)\n\n        # tmp: reg loss in mip-nerf 360 (interval length is z_vals_shifted - z_vals)\n        # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1)\n        # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T]\n        # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals) * (weights ** 2)).sum()\n\n        return {\n            
\"depth\": depth,\n            \"image\": image,\n        }\n\n    def run_cuda(\n        self,\n        rays_o,\n        rays_d,\n        dt_gamma=0,\n        bg_color=None,\n        perturb=False,\n        force_all_rays=False,\n        max_steps=1024,\n        inherited_params=[],\n        **kwargs\n    ):\n        # rays_o, rays_d: [B, N, 3], assumes B == 1\n        # return: image: [B, N, 3], depth: [B, N]\n\n        prefix = rays_o.shape[:-1]\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        N = rays_o.shape[0]  # N = B * N, in fact\n        device = rays_o.device\n\n        # pre-calculate near far\n        nears, fars = raymarching.near_far_from_aabb(\n            rays_o,\n            rays_d,\n            self.aabb_train if self.training else self.aabb_infer,\n            self.min_near,\n        )\n\n        # mix background color\n        if self.bg_radius > 0:\n            # use the bg model to calculate bg_color\n            polar = raymarching.polar_from_ray(\n                rays_o, rays_d, self.bg_radius\n            )  # [N, 2] in [-1, 1]\n            bg_color = self.background(polar, rays_d)  # [N, 3]\n        elif bg_color is None:\n            bg_color = 1\n\n        if self.training:  # different from testing\n            # setup counter\n            time1 = time()\n            counter = self.step_counter[self.local_step % 16]\n            counter.zero_()  # set to 0\n            self.local_step += 1\n            if (\n                self.args.render_stu_first\n            ):  # if the student renders first, the student computes the sampled xyzs and the teacher inherits them\n                \"\"\"\n                About xyzs, dirs, deltas, rays:\n                    xyzs, dirs: the spatial points sampled along rays_o and rays_d;\n                    rays: xyzs[rays[i, 1] : rays[i, 1] + rays[i, 2]] --> points belonging to ray rays[i, 0];\n                    deltas: shape is [point_nums, 2]. 
deltas means all generated points' deltas. (first for RGB, second for Depth)\n                \"\"\"\n                if not self.is_teacher:\n                    xyzs, dirs, deltas, rays = raymarching.march_rays_train(\n                        rays_o,\n                        rays_d,\n                        self.bound,\n                        self.density_bitfield,\n                        self.cascade,\n                        self.grid_size,\n                        nears,\n                        fars,\n                        counter,\n                        self.mean_count,\n                        perturb,\n                        128,\n                        force_all_rays,\n                        dt_gamma,\n                        max_steps,\n                    )\n                    inherited_params = [xyzs, dirs, deltas, rays]\n                else:\n                    xyzs, dirs, deltas, rays = inherited_params\n            else:\n                if self.is_teacher:\n                    xyzs, dirs, deltas, rays = raymarching.march_rays_train(\n                        rays_o,\n                        rays_d,\n                        self.bound,\n                        self.density_bitfield,\n                        self.cascade,\n                        self.grid_size,\n                        nears,\n                        fars,\n                        counter,\n                        self.mean_count,\n                        perturb,\n                        128,\n                        force_all_rays,\n                        dt_gamma,\n                        max_steps,\n                    )\n                    inherited_params = [xyzs, dirs, deltas, rays]\n                else:\n                    xyzs, dirs, deltas, rays = inherited_params\n\n            # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())\n            sigmas, rgbs = self(xyzs, dirs)\n\n            if self.args.global_step < 
self.args.stage_iters[\"stage1\"]:\n                return {\n                    \"stage1\": self.args.global_step,\n                    \"depth\": None,\n                    \"image\": None,\n                    \"inherited_params\": inherited_params,\n                    \"sigmas\": sigmas,\n                    \"rays\": rays,\n                }\n            elif self.args.global_step < self.args.stage_iters[\"stage2\"]:\n                return {\n                    \"stage2\": self.args.global_step,\n                    \"depth\": None,\n                    \"image\": None,\n                    \"inherited_params\": inherited_params,\n                    \"sigmas\": sigmas,\n                    \"rays\": rays,\n                }\n\n            sigmas = self.density_scale * sigmas\n\n            weights_sum, depth, image = raymarching.composite_rays_train(\n                sigmas, rgbs, deltas, rays\n            )\n            image = image + (1 - weights_sum).unsqueeze(-1) * bg_color\n            depth = torch.clamp(depth - nears, min=0) / (fars - nears + 1e-6)\n            image = image.view(*prefix, 3)\n            depth = depth.view(*prefix)\n\n        else:\n            # allocate outputs\n            # if use autocast, must init as half so it won't be autocasted and lose reference.\n            # dtype = torch.half if torch.is_autocast_enabled() else torch.float32\n            # output should always be float32! 
only network inference uses half.\n            dtype = torch.float32\n\n            weights_sum = torch.zeros(N, dtype=dtype, device=device)\n            depth = torch.zeros(N, dtype=dtype, device=device)\n            image = torch.zeros(N, 3, dtype=dtype, device=device)\n\n            n_alive = N\n            alive_counter = torch.zeros([1], dtype=torch.int32, device=device)\n\n            rays_alive = torch.zeros(\n                2, n_alive, dtype=torch.int32, device=device\n            )  # 2 is used to loop old/new\n            rays_t = torch.zeros(2, n_alive, dtype=dtype, device=device)\n\n            step = 0\n            i = 0\n            while step < max_steps:\n\n                # count alive rays\n                if step == 0:\n                    # init rays at first step.\n                    torch.arange(n_alive, out=rays_alive[0])\n                    rays_t[0] = nears\n                else:\n                    alive_counter.zero_()\n                    raymarching.compact_rays(\n                        n_alive,\n                        rays_alive[i % 2],\n                        rays_alive[(i + 1) % 2],\n                        rays_t[i % 2],\n                        rays_t[(i + 1) % 2],\n                        alive_counter,\n                    )\n                    n_alive = alive_counter.item()  # must invoke D2H copy here\n\n                # exit loop\n                if n_alive <= 0:\n                    break\n\n                # decide compact_steps\n                n_step = max(min(N // n_alive, 8), 1)\n\n                xyzs, dirs, deltas = raymarching.march_rays(\n                    n_alive,\n                    n_step,\n                    rays_alive[i % 2],\n                    rays_t[i % 2],\n                    rays_o,\n                    rays_d,\n                    self.bound,\n                    self.density_bitfield,\n                    self.cascade,\n                    self.grid_size,\n                    nears,\n       
             fars,\n                    128,\n                    perturb,\n                    dt_gamma,\n                    max_steps,\n                )\n\n                sigmas, rgbs = self(xyzs, dirs)\n                # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.\n                # sigmas = density_outputs['sigma']\n                # rgbs = self.color(xyzs, dirs, **density_outputs)\n                sigmas = self.density_scale * sigmas\n\n                raymarching.composite_rays(\n                    n_alive,\n                    n_step,\n                    rays_alive[i % 2],\n                    rays_t[i % 2],\n                    sigmas,\n                    rgbs,\n                    deltas,\n                    weights_sum,\n                    depth,\n                    image,\n                )\n\n                # print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')\n\n                step += n_step\n                i += 1\n\n            image = image + (1 - weights_sum).unsqueeze(-1) * bg_color\n            depth = torch.clamp(depth - nears, min=0) / (fars - nears)\n            image = image.view(*prefix, 3)\n            depth = depth.view(*prefix)\n\n        # print('\\n--- render time:--- {:6f}  {:.6f}'.format(time2-time1, time()-time2))\n        if self.training:\n            return {\n                \"depth\": depth,\n                \"image\": image,\n                \"inherited_params\": inherited_params,\n                \"sigmas\": sigmas,\n                \"rays\": rays,\n            }\n        else:\n            return {\n                \"depth\": depth,\n                \"image\": image,\n                \"inherited_params\": inherited_params,\n            }\n\n    @torch.no_grad()\n    def mark_untrained_grid(self, poses, intrinsic, S=64):\n        # poses: [B, 4, 4]\n        # intrinsic: [3, 3]\n\n        if not self.cuda_ray:\n      
      return\n\n        if isinstance(poses, np.ndarray):\n            poses = torch.from_numpy(poses)\n\n        B = poses.shape[0]\n\n        fx, fy, cx, cy = intrinsic\n\n        X = torch.arange(\n            self.grid_size, dtype=torch.int32, device=self.density_grid.device\n        ).split(S)\n        Y = torch.arange(\n            self.grid_size, dtype=torch.int32, device=self.density_grid.device\n        ).split(S)\n        Z = torch.arange(\n            self.grid_size, dtype=torch.int32, device=self.density_grid.device\n        ).split(S)\n\n        count = torch.zeros_like(self.density_grid)\n        poses = poses.to(count.device)\n\n        # 5-level loop, forgive me...\n\n        for xs in X:\n            for ys in Y:\n                for zs in Z:\n\n                    # construct points\n                    xx, yy, zz = custom_meshgrid(xs, ys, zs)\n                    coords = torch.cat(\n                        [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],\n                        dim=-1,\n                    )  # [N, 3], in [0, 128)\n                    indices = raymarching.morton3D(coords).long()  # [N]\n                    world_xyzs = (\n                        2 * coords.float() / (self.grid_size - 1) - 1\n                    ).unsqueeze(\n                        0\n                    )  # [1, N, 3] in [-1, 1]\n\n                    # cascading\n                    for cas in range(self.cascade):\n                        bound = min(2 ** cas, self.bound)\n                        half_grid_size = bound / self.grid_size\n                        # scale to current cascade's resolution\n                        cas_world_xyzs = world_xyzs * (bound - half_grid_size)\n\n                        # split batch to avoid OOM\n                        head = 0\n                        while head < B:\n                            tail = min(head + S, B)\n\n                            # world2cam transform (poses is c2w, so we need to transpose 
it. Another transpose is needed for batched matmul, so the final form is without transpose.)\n                            cam_xyzs = cas_world_xyzs - poses[\n                                head:tail, :3, 3\n                            ].unsqueeze(1)\n                            cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3]  # [S, N, 3]\n\n                            # query if point is covered by any camera\n                            mask_z = cam_xyzs[:, :, 2] > 0  # [S, N]\n                            mask_x = (\n                                torch.abs(cam_xyzs[:, :, 0])\n                                < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2\n                            )\n                            mask_y = (\n                                torch.abs(cam_xyzs[:, :, 1])\n                                < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2\n                            )\n                            mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1)  # [N]\n\n                            # update count\n                            count[cas, indices] += mask\n                            head += S\n\n        # mark untrained grid as -1\n        self.density_grid[count == 0] = -1\n\n        # print(f'[mark untrained grid] {(count == 0).sum()} from {resolution ** 3 * self.cascade}')\n\n    @torch.no_grad()\n    def update_extra_state(self, decay=0.95, S=128):\n        # call before each epoch to update extra states.\n\n        if not self.cuda_ray:\n            return\n\n        # update density grid\n        tmp_grid = -torch.ones_like(self.density_grid)\n\n        # full update.\n        if self.iter_density < 16:\n            # if True:\n            X = torch.arange(\n                self.grid_size, dtype=torch.int32, device=self.density_grid.device\n            ).split(S)\n            Y = torch.arange(\n                self.grid_size, dtype=torch.int32, device=self.density_grid.device\n            ).split(S)\n            Z = 
torch.arange(\n                self.grid_size, dtype=torch.int32, device=self.density_grid.device\n            ).split(S)\n\n            for xs in X:\n                for ys in Y:\n                    for zs in Z:\n                        # construct points\n                        xx, yy, zz = custom_meshgrid(xs, ys, zs)\n                        coords = torch.cat(\n                            [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],\n                            dim=-1,\n                        )  # [N, 3], in [0, 128)\n                        indices = raymarching.morton3D(coords).long()  # [N]\n                        xyzs = (\n                            2 * coords.float() / (self.grid_size - 1) - 1\n                        )  # [N, 3] in [-1, 1]\n\n                        # cascading\n                        for cas in range(self.cascade):\n                            bound = min(2 ** cas, self.bound)\n                            half_grid_size = bound / self.grid_size\n                            # scale to current cascade's resolution\n                            cas_xyzs = xyzs * (bound - half_grid_size)\n                            # add noise in [-hgs, hgs]\n                            cas_xyzs += (\n                                torch.rand_like(cas_xyzs) * 2 - 1\n                            ) * half_grid_size\n                            # query density\n                            sigmas = (\n                                self.density(cas_xyzs)[\"sigma\"].reshape(-1).detach()\n                            )\n                            sigmas *= self.density_scale\n                            # assign\n                            tmp_grid[cas, indices] = sigmas\n\n        # partial update (half the computation)\n        # TODO: why no need of maxpool ?\n        else:\n            N = self.grid_size ** 3 // 4  # H * H * H / 4\n            for cas in range(self.cascade):\n                # random sample some positions\n                
coords = torch.randint(\n                    0, self.grid_size, (N, 3), device=self.density_grid.device\n                )  # [N, 3], in [0, 128)\n                indices = raymarching.morton3D(coords).long()  # [N]\n                # random sample occupied positions\n                occ_indices = torch.nonzero(self.density_grid[cas] > 0).squeeze(\n                    -1\n                )  # [Nz]\n                rand_mask = torch.randint(\n                    0,\n                    occ_indices.shape[0],\n                    [N],\n                    dtype=torch.long,\n                    device=self.density_grid.device,\n                )\n                occ_indices = occ_indices[\n                    rand_mask\n                ]  # [Nz] --> [N], allow for duplication\n                occ_coords = raymarching.morton3D_invert(occ_indices)  # [N, 3]\n                # concat\n                indices = torch.cat([indices, occ_indices], dim=0)\n                coords = torch.cat([coords, occ_coords], dim=0)\n                # same below\n                xyzs = (\n                    2 * coords.float() / (self.grid_size - 1) - 1\n                )  # [N, 3] in [-1, 1]\n                bound = min(2 ** cas, self.bound)\n                half_grid_size = bound / self.grid_size\n                # scale to current cascade's resolution\n                cas_xyzs = xyzs * (bound - half_grid_size)\n                # add noise in [-hgs, hgs]\n                cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size\n                # query density\n                sigmas = self.density(cas_xyzs)[\"sigma\"].reshape(-1).detach()\n                sigmas *= self.density_scale\n                # assign\n                tmp_grid[cas, indices] = sigmas\n\n        ## max-pool on tmp_grid for less aggressive culling [No significant improvement...]\n        # invalid_mask = tmp_grid < 0\n        # tmp_grid = F.max_pool3d(tmp_grid.view(self.cascade, 1, self.grid_size, 
self.grid_size, self.grid_size), kernel_size=3, stride=1, padding=1).view(self.cascade, -1)\n        # tmp_grid[invalid_mask] = -1\n\n        # ema update\n        valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)\n        self.density_grid[valid_mask] = torch.maximum(\n            self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]\n        )\n        self.mean_density = torch.mean(\n            self.density_grid.clamp(min=0)\n        ).item()  # -1 non-training regions are viewed as 0 density.\n        self.iter_density += 1\n\n        # convert to bitfield\n        density_thresh = min(self.mean_density, self.density_thresh)\n        self.density_bitfield = raymarching.packbits(\n            self.density_grid, density_thresh, self.density_bitfield\n        )\n\n        ### update step counter\n        total_step = min(16, self.local_step)\n        if total_step > 0:\n            self.mean_count = int(\n                self.step_counter[:total_step, 0].sum().item() / total_step\n            )\n        self.local_step = 0\n\n        # print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')\n\n    def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):\n        # rays_o, rays_d: [B, N, 3], assumes B == 1\n        # return: pred_rgb: [B, N, 3]\n\n        if self.cuda_ray:\n            _run = self.run_cuda\n        else:\n            _run = self.run\n\n        B, N = rays_o.shape[:2]\n        device = rays_o.device\n\n        # never stage when cuda_ray\n        if staged and not self.cuda_ray:\n            depth = torch.empty((B, N), device=device)\n            image = torch.empty((B, N, 3), device=device)\n\n            for b in range(B):\n                head = 0\n                while head < N:\n                    tail = min(head + 
max_ray_batch, N)\n                    results_ = _run(\n                        rays_o[b : b + 1, head:tail],\n                        rays_d[b : b + 1, head:tail],\n                        **kwargs\n                    )\n                    depth[b : b + 1, head:tail] = results_[\"depth\"]\n                    image[b : b + 1, head:tail] = results_[\"image\"]\n                    head += max_ray_batch\n\n            results = {}\n            results[\"depth\"] = depth\n            results[\"image\"] = image\n\n        else:\n            results = _run(rays_o, rays_d, **kwargs)\n\n        return results\n"
  },
  {
    "path": "distill_mutual/utils.py",
    "content": "import os\nimport copy\nimport lpips\nimport glob\nimport tqdm\nimport math\nimport random\nimport warnings\nimport tensorboardX\n\nimport numpy as np\nimport pandas as pd\n\nimport imageio\n\nimport time\nfrom datetime import datetime\n\nimport cv2\nimport matplotlib.pyplot as plt\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nimport torch.distributed as dist\nfrom torch.utils.data import Dataset, DataLoader\nimport trimesh\nimport mcubes\nfrom rich.console import Console\nfrom torch_ema import ExponentialMovingAverage\nfrom IPython import embed\nimport sys\n\nfrom packaging import version as pver\n\ndevice = torch.device(\"cuda\")\nTINY_NUMBER = 1e-6  # float32 only has 7 decimal digits precision\n\n\ndef update_loss_rate(cur_lrate, scale=0.99):\n    return cur_lrate * scale\n\n\ndef get_softmax_map_mean(a, b):\n    return (F.softmax(a) - F.softmax(b)).abs().mean()\n\n\ndef get_kl(inputs, targets):\n    return F.kl_div(F.log_softmax(inputs), F.softmax(targets), reduction=\"sum\")\n\n\ndef nerf_matrix_to_ngp(pose, scale=0.8):\n    # for the fox dataset, 0.33 scales camera radius to ~ 2\n    new_pose = np.array(\n        [\n            [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale],\n            [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale],\n            [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale],\n            [0, 0, 0, 1],\n        ],\n        dtype=np.float32,\n    )\n    return new_pose\n\n\ndef pose_spherical(theta, phi, radius):\n    # for synthetic. 
it generates random poses on a sphere\n    trans_t = lambda t: np.array(\n        [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1]]\n    ).astype(np.float32)\n    rot_phi = lambda phi: np.array(\n        [\n            [1, 0, 0, 0],\n            [0, np.cos(phi), -np.sin(phi), 0],\n            [0, np.sin(phi), np.cos(phi), 0],\n            [0, 0, 0, 1],\n        ]\n    ).astype(np.float32)\n    rot_theta = lambda th: np.array(\n        [\n            [np.cos(th), 0, -np.sin(th), 0],\n            [0, 1, 0, 0],\n            [np.sin(th), 0, np.cos(th), 0],\n            [0, 0, 0, 1],\n        ]\n    ).astype(np.float32)\n    c2w = trans_t(radius)\n    c2w = rot_phi(phi / 180.0 * np.pi) @ c2w\n    c2w = rot_theta(theta / 180.0 * np.pi) @ c2w\n    c2w = (\n        np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]).astype(\n            np.float32\n        )\n        @ c2w\n    )\n    return c2w\n\n\ndef get_rand_poses(data_type=\"synthetic\", original_loader=None):\n    \"\"\"\n    Random sampling. 
Random origins and directions.\n    \"\"\"\n    from scipy.spatial.transform import Slerp, Rotation\n\n    assert data_type in {\"synthetic\", \"llff\", \"tank\"}\n\n    def get_single_syn_pose(ph, rand_radius=False):\n        theta1 = -180\n        theta2 = 180\n        phi1 = -ph\n        phi2 = 5 - ph if (5 - ph) <= 0 else 0\n        theta = theta1 + np.random.rand() * (theta2 - theta1)\n        phi = phi1 + np.random.rand() * (phi2 - phi1)\n        if rand_radius:\n            radius = np.random.uniform(3, 4)\n        else:\n            radius = 4\n        return pose_spherical(theta, phi, radius)\n\n    def get_syn_poses():\n        random_poses = np.array([get_single_syn_pose(8) for _ in range(1)])\n        for a in range(0, 80):\n            rp = np.array(\n                [get_single_syn_pose(a) for _ in range(int(((90 - a) // 15) ** 1 + 1))]\n            )\n            random_poses = np.concatenate([random_poses, rp], axis=0)\n        for i in range(len(random_poses)):\n            random_poses[i] = nerf_matrix_to_ngp(random_poses[i])\n        print(f\"\\nlen(train data): {len(random_poses)}\\n\")\n        random_poses = torch.from_numpy(random_poses).cuda()\n        return random_poses\n\n    def get_tank_poses():\n        random_poses = np.array([get_single_syn_pose(8) for _ in range(1)])\n        for a in range(5, 20):\n            rp = np.array(\n                [\n                    get_single_syn_pose(a, True)\n                    for _ in range(int(((90 - a) // 15) ** 1 + 1))\n                ]\n            )\n            random_poses = np.concatenate([random_poses, rp], axis=0)\n        for i in range(len(random_poses)):\n            random_poses[i] = nerf_matrix_to_ngp(random_poses[i])\n        print(f\"\\nlen(train data): {len(random_poses)}\\n\")\n        random_poses = torch.from_numpy(random_poses).cuda()\n        return random_poses\n\n    def rand_poses_from_cam_centers(centers):\n        def normalize(vectors):\n            return vectors 
/ (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)\n\n        size = len(centers)\n        forward_vector = -normalize(centers)\n        up_vector = (\n            torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)\n        )  # confused at the coordinate system...\n        right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1))\n        up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1))\n\n        poses = (\n            torch.eye(4, dtype=torch.float, device=device)\n            .unsqueeze(0)\n            .repeat(size, 1, 1)\n        )\n        poses[:, :3, :3] = torch.stack(\n            (right_vector, up_vector, forward_vector), dim=-1\n        )\n        poses[:, :3, 3] = centers\n        return poses\n\n    def get_llff_poses_rand():\n        def get_rand_cam_centers_from_bbox(poses, gen_num=30):\n            # use the camera translations to estimate the bbox of the cameras\n            translations = poses[:, :3, 3]\n            bbox_max = translations.max(axis=0) + 1e-6\n            bbox_min = translations.min(axis=0) - 1e-6\n            rand_xs = np.random.uniform(low=bbox_min[0], high=bbox_max[0], size=gen_num)\n            rand_ys = np.random.uniform(low=bbox_min[1], high=bbox_max[1], size=gen_num)\n            rand_zs = np.random.uniform(low=bbox_min[2], high=bbox_max[2], size=gen_num)\n            centers = np.stack([rand_xs, rand_ys, rand_zs], axis=1)\n            return centers.astype(np.float32)\n\n        centers = get_rand_cam_centers_from_bbox(original_loader)\n        random_poses = rand_poses_from_cam_centers(torch.from_numpy(centers).cuda())\n        random_poses[:, 0, 0] = -random_poses[:, 0, 0]\n        return random_poses\n\n    if data_type == \"synthetic\":\n        random_poses = get_syn_poses()\n    elif data_type == \"llff\":\n        random_poses = get_llff_poses_rand()\n    elif data_type == \"tank\":\n        random_poses = get_tank_poses()\n    else:\n        raise ValueError(f\"unknown data_type: {data_type}\")\n    
return random_poses\n\n\ndef custom_meshgrid(*args):\n    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid\n    if pver.parse(torch.__version__) < pver.parse(\"1.10\"):\n        return torch.meshgrid(*args)\n    else:\n        return torch.meshgrid(*args, indexing=\"ij\")\n\n\n@torch.jit.script\ndef linear_to_srgb(x):\n    return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)\n\n\n@torch.jit.script\ndef srgb_to_linear(x):\n    return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)\n\n\ndef compute_ssim(\n    img0,\n    img1,\n    max_val,\n    filter_size=11,\n    filter_sigma=1.5,\n    k1=0.01,\n    k2=0.03,\n    return_map=False,\n):\n    \"\"\"Computes SSIM from two images.\n    This function was modeled after tf.image.ssim, and should produce comparable\n    output.\n    Args:\n      img0: torch.tensor. An image of size [..., width, height, num_channels].\n      img1: torch.tensor. An image of size [..., width, height, num_channels].\n      max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.\n      filter_size: int >= 1. Window size.\n      filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.\n      k1: float > 0. One of the SSIM dampening parameters.\n      k2: float > 0. One of the SSIM dampening parameters.\n      return_map: Bool. 
If True, the per-pixel SSIM \"map\" is returned instead of the mean.\n    Returns:\n      Each image's mean SSIM, or a tensor of individual values if `return_map`.\n    \"\"\"\n    device = img0.device\n    img0 = img0.type(torch.float32)\n    img1 = img1.type(torch.float32)\n    ori_shape = img0.size()\n    width, height, num_channels = ori_shape[-3:]\n    img0 = img0.view(-1, width, height, num_channels).permute(0, 3, 1, 2)\n    img1 = img1.view(-1, width, height, num_channels).permute(0, 3, 1, 2)\n    batch_size = img0.shape[0]\n\n    # Construct a 1D Gaussian blur filter.\n    hw = filter_size // 2\n    shift = (2 * hw - filter_size + 1) / 2\n    f_i = ((torch.arange(filter_size, device=device) - hw + shift) / filter_sigma) ** 2\n    filt = torch.exp(-0.5 * f_i)\n    filt /= torch.sum(filt)\n\n    # Blur in x and y (faster than the 2D convolution).\n    # z is a channel-first tensor of size [B, C, width, height] (after the permute above)\n    filt_fn1 = lambda z: F.conv2d(\n        z,\n        filt.view(1, 1, -1, 1).repeat(num_channels, 1, 1, 1),\n        padding=[hw, 0],\n        groups=num_channels,\n    )\n    filt_fn2 = lambda z: F.conv2d(\n        z,\n        filt.view(1, 1, 1, -1).repeat(num_channels, 1, 1, 1),\n        padding=[0, hw],\n        groups=num_channels,\n    )\n\n    # Vmap the blurs to the tensor size, and then compose them.\n    filt_fn = lambda z: filt_fn1(filt_fn2(z))\n    mu0 = filt_fn(img0)\n    mu1 = filt_fn(img1)\n    mu00 = mu0 * mu0\n    mu11 = mu1 * mu1\n    mu01 = mu0 * mu1\n    sigma00 = filt_fn(img0 ** 2) - mu00\n    sigma11 = filt_fn(img1 ** 2) - mu11\n    sigma01 = filt_fn(img0 * img1) - mu01\n\n    # Clip the variances and covariances to valid values.\n    # Variance must be non-negative:\n    sigma00 = torch.clamp(sigma00, min=0.0)\n    sigma11 = torch.clamp(sigma11, min=0.0)\n    sigma01 = torch.sign(sigma01) * torch.min(\n        torch.sqrt(sigma00 * sigma11), torch.abs(sigma01)\n    )\n\n    c1 = (k1 * max_val) ** 2\n    c2 = (k2 * max_val) ** 2\n    numer = (2 * mu01 + c1) * (2 
* sigma01 + c2)\n    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)\n    ssim_map = numer / denom\n    ssim = torch.mean(ssim_map.reshape([-1, num_channels * width * height]), dim=-1)\n    return ssim_map if return_map else ssim\n\n\ndef init_lpips(net_name, device):\n    assert net_name in [\"alex\", \"vgg\"]\n    print(f\"init_lpips: lpips_{net_name}\")\n    # honor the device argument instead of hard-coding .cuda()\n    return lpips.LPIPS(net=net_name, version=\"0.1\").eval().to(device)\n\n\nlpips_fns = {\n    \"alex\": lpips.LPIPS(net=\"alex\", version=\"0.1\").eval().cuda(),\n    \"vgg\": lpips.LPIPS(net=\"vgg\", version=\"0.1\").eval().cuda(),\n}\n\n\ndef rgb_lpips(gt, im, net_name):\n    assert net_name in [\"alex\", \"vgg\"]\n    gt = gt.type(torch.float32).permute([0, 3, 1, 2]).contiguous().cuda()\n    im = im.type(torch.float32).permute([0, 3, 1, 2]).contiguous().cuda()\n    return lpips_fns[net_name](gt, im, normalize=True).item()\n\n\n@torch.cuda.amp.autocast(enabled=False)\ndef get_rays(poses, intrinsics, H, W, N=-1, error_map=None):\n    \"\"\"get rays\n    Args:\n        poses: [B, 4, 4], cam2world\n        intrinsics: [4]\n        H, W, N: int\n        error_map: [B, 128 * 128], sample probability based on training error\n    Returns:\n        rays_o, rays_d: [B, N, 3]\n        inds: [B, N]\n    \"\"\"\n\n    device = poses.device\n    B = poses.shape[0]\n    fx, fy, cx, cy = intrinsics\n\n    i, j = custom_meshgrid(\n        torch.linspace(0, W - 1, W, device=device),\n        torch.linspace(0, H - 1, H, device=device),\n    )\n    i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5\n    j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5\n\n    results = {}\n\n    if N > 0:\n        N = min(N, H * W)\n\n        if error_map is None:\n            inds = torch.randint(0, H * W, size=[N], device=device)  # may duplicate\n            inds = inds.expand([B, N])\n        else:\n\n            # weighted sample on a low-reso grid\n            inds_coarse = torch.multinomial(\n              
  error_map.to(device), N, replacement=False\n            )  # [B, N], but in [0, 128*128)\n\n            # map to the original resolution with random perturb.\n            inds_x, inds_y = (\n                inds_coarse // 128,\n                inds_coarse % 128,\n            )  # `//` will throw a warning in torch 1.10... anyway.\n            sx, sy = H / 128, W / 128\n            inds_x = (\n                (inds_x * sx + torch.rand(B, N, device=device) * sx)\n                .long()\n                .clamp(max=H - 1)\n            )\n            inds_y = (\n                (inds_y * sy + torch.rand(B, N, device=device) * sy)\n                .long()\n                .clamp(max=W - 1)\n            )\n            inds = inds_x * W + inds_y\n\n            results[\"inds_coarse\"] = inds_coarse  # need this when updating error_map\n\n        i = torch.gather(i, -1, inds)\n        j = torch.gather(j, -1, inds)\n\n        results[\"inds\"] = inds\n\n    else:\n        inds = torch.arange(H * W, device=device).expand([B, H * W])\n\n    zs = torch.ones_like(i)\n    xs = (i - cx) / fx * zs\n    ys = (j - cy) / fy * zs\n    directions = torch.stack((xs, ys, zs), dim=-1)\n    directions = directions / torch.norm(directions, dim=-1, keepdim=True)\n    rays_d = directions @ poses[:, :3, :3].transpose(-1, -2)  # (B, N, 3)\n\n    rays_o = poses[..., :3, 3]  # [B, 3]\n    rays_o = rays_o[..., None, :].expand_as(rays_d)  # [B, N, 3]\n\n    results[\"rays_o\"] = rays_o\n    results[\"rays_d\"] = rays_d\n\n    return results\n\n\ndef seed_everything(seed):\n    random.seed(seed)\n    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    # torch.backends.cudnn.deterministic = True\n    # torch.backends.cudnn.benchmark = True\n\n\ndef torch_vis_2d(x, renormalize=False):\n    # x: [3, H, W] or [1, H, W] or [H, W]\n    import matplotlib.pyplot as plt\n    import numpy as np\n    import torch\n\n    if 
isinstance(x, torch.Tensor):\n        if len(x.shape) == 3:\n            x = x.permute(1, 2, 0).squeeze()\n        x = x.detach().cpu().numpy()\n\n    print(f\"[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}\")\n\n    x = x.astype(np.float32)\n\n    # renormalize\n    if renormalize:\n        x = (x - x.min(axis=0, keepdims=True)) / (\n            x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8\n        )\n\n    plt.imshow(x)\n    plt.show()\n\n\ndef extract_fields(bound_min, bound_max, resolution, query_func, S=128):\n\n    X = torch.linspace(bound_min[0], bound_max[0], resolution).split(S)\n    Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(S)\n    Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(S)\n\n    u = np.zeros([resolution, resolution, resolution], dtype=np.float32)\n    with torch.no_grad():\n        for xi, xs in enumerate(X):\n            for yi, ys in enumerate(Y):\n                for zi, zs in enumerate(Z):\n                    xx, yy, zz = custom_meshgrid(xs, ys, zs)\n                    pts = torch.cat(\n                        [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],\n                        dim=-1,\n                    )  # [S, 3]\n                    val = (\n                        query_func(pts)\n                        .reshape(len(xs), len(ys), len(zs))\n                        .detach()\n                        .cpu()\n                        .numpy()\n                    )  # [S, 1] --> [x, y, z]\n                    u[\n                        xi * S : xi * S + len(xs),\n                        yi * S : yi * S + len(ys),\n                        zi * S : zi * S + len(zs),\n                    ] = val\n    return u\n\n\ndef extract_geometry(bound_min, bound_max, resolution, threshold, query_func):\n    # print('threshold: {}'.format(threshold))\n    u = extract_fields(bound_min, bound_max, resolution, query_func)\n\n    # print(u.shape, u.max(), u.min(), 
np.percentile(u, 50))\n\n    vertices, triangles = mcubes.marching_cubes(u, threshold)\n\n    b_max_np = bound_max.detach().cpu().numpy()\n    b_min_np = bound_min.detach().cpu().numpy()\n\n    vertices = (\n        vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :]\n        + b_min_np[None, :]\n    )\n    return vertices, triangles\n\n\nclass PSNRMeter:\n    def __init__(self):\n        self.V = 0\n        self.N = 0\n        self.psnr_list = []\n\n    def clear(self):\n        self.V = 0\n        self.N = 0\n        self.psnr_list = []\n\n    def prepare_inputs(self, *inputs):\n        outputs = []\n        for i, inp in enumerate(inputs):\n            if torch.is_tensor(inp):\n                inp = inp.detach().cpu().numpy()\n            outputs.append(inp)\n\n        return outputs\n\n    def update(self, preds, truths):\n        preds, truths = self.prepare_inputs(\n            preds, truths\n        )  # [B, N, 3] or [B, H, W, 3], range[0, 1]\n\n        psnr = -10 * np.log10(np.mean((preds - truths) ** 2))\n        self.psnr_list.append(psnr)\n        self.V += psnr\n        self.N += 1\n        assert self.N == len(self.psnr_list)\n\n    def measure(self):\n        return self.V / self.N\n\n    def write(self, writer, global_step, prefix=\"\"):\n        writer.add_scalar(os.path.join(prefix, \"PSNR\"), self.measure(), global_step)\n\n    def report(self):\n        return f\"PSNR = {self.measure():.6f}\"\n\n\nclass Trainer(object):\n    def __init__(\n        self,\n        name,  # name of this experiment\n        opt,  # extra conf\n        model_tea,  # network\n        model_stu,\n        criterion=None,  # loss function, if None, assume inline implementation in train_step\n        optimizer=None,  # optimizer\n        ema_decay=None,  # if use EMA, set the decay\n        lr_scheduler=None,  # scheduler\n        metrics=[],  # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.\n        
local_rank=0,  # which GPU am I\n        world_size=1,  # total num of GPUs\n        device=None,  # device to use, usually setting to None is OK. (auto choose device)\n        mute=False,  # whether to mute all print\n        fp16=False,  # amp optimize level\n        eval_interval=10e10,  # eval once every $ epoch\n        max_keep_ckpt=2,  # max num of saved ckpts in disk\n        workspace=\"workspace\",  # workspace to save logs & ckpts\n        best_mode=\"min\",  # the smaller/larger result, the better\n        use_loss_as_metric=True,  # use loss as the first metric\n        report_metric_at_train=False,  # also report metrics at training\n        use_checkpoint=\"latest\",  # which ckpt to use at init time\n        use_tensorboardX=True,  # whether to use tensorboard for logging\n        scheduler_update_every_step=False,  # whether to call scheduler.step() after every train step\n    ):\n\n        self.optimizer_fn = optimizer\n        self.lr_scheduler_fn = lr_scheduler\n        self.name = name\n        self.opt = opt\n        self.args = opt\n        self.mute = mute\n        self.metrics = metrics\n        self.local_rank = local_rank\n        self.world_size = world_size\n        self.workspace = workspace\n        self.ema_decay = ema_decay\n        self.fp16 = fp16\n        self.best_mode = best_mode\n        self.use_loss_as_metric = use_loss_as_metric\n        self.report_metric_at_train = report_metric_at_train\n        self.max_keep_ckpt = max_keep_ckpt\n        self.eval_interval = eval_interval\n        self.use_checkpoint = use_checkpoint\n        self.use_tensorboardX = use_tensorboardX\n        self.time_stamp = time.strftime(\"%Y-%m-%d_%H-%M-%S\")\n        self.scheduler_update_every_step = scheduler_update_every_step\n        self.device = (\n            device\n            if device is not None\n            else torch.device(\n                f\"cuda:{local_rank}\" if torch.cuda.is_available() else \"cpu\"\n            )\n        )\n    
    self.console = Console()\n\n        self.model_tea = model_tea.to(device)\n        self.model_stu = model_stu.to(device)\n\n        if isinstance(criterion, nn.Module):\n            criterion.to(self.device)\n        self.criterion = criterion\n\n        if optimizer is None:\n            self.optimizer = optim.AdamW(\n                self.model_stu.parameters(), lr=0.001, weight_decay=5e-4\n            )  # naive adam\n        else:\n            self.optimizer = optimizer(self.model_stu)\n\n        if lr_scheduler is None:\n            self.lr_scheduler = optim.lr_scheduler.LambdaLR(\n                self.optimizer, lr_lambda=lambda epoch: 1\n            )  # fake scheduler\n        else:\n            self.ls = lr_scheduler\n            self.lr_scheduler = lr_scheduler(self.optimizer)\n        if ema_decay is not None and ema_decay > 0:\n            self.ema = ExponentialMovingAverage(\n                self.model_stu.parameters(), decay=ema_decay\n            )\n        else:\n            self.ema = None\n\n        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)\n\n        # variable init\n        self.epoch = 1\n        self.global_step = 0\n        self.local_step = 0\n        self.stats = {\n            \"loss\": [],\n            \"valid_loss\": [],\n            \"results\": [],  # metrics[0], or valid_loss\n            \"checkpoints\": [],  # record path of saved ckpt, to automatically remove old ckpt\n            \"best_result\": None,\n        }\n\n        # auto fix\n        if len(metrics) == 0 or self.use_loss_as_metric:\n            self.best_mode = \"min\"\n\n        # workspace prepare\n        self.log_ptr = None\n        if self.workspace is not None:\n            os.makedirs(self.workspace, exist_ok=True)\n            self.log_path = os.path.join(workspace, f\"log_{self.name}.txt\")\n            self.log_ptr = open(self.log_path, \"a+\")\n\n            self.ckpt_path = os.path.join(self.workspace, \"checkpoints\")\n            
self.best_path = f\"{self.ckpt_path}/{self.name}.pth\"\n            os.makedirs(self.ckpt_path, exist_ok=True)\n        self.log(self.opt)\n\n        self.log(\n            f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {\"fp16\" if self.fp16 else \"fp32\"} | {self.workspace}'\n        )\n        self.log(\n            f\"[INFO] #parameters: {sum([p.numel() for p in model_stu.parameters() if p.requires_grad])}\"\n        )\n\n        if (\n            self.workspace is not None\n        ):  # only load state_dict for teacher and share backbone for student\n            self.log(f\"[INFO] Loading teacher ckpt from {self.opt.ckpt_teacher} ...\")\n            self.load_teacher_checkpoint()\n            self.log(self.model_tea)\n            self.load_student_checkpoint()\n            self.log(self.model_stu)\n            # self.model_tea.reset_extra_state()\n            # self.model_stu.reset_extra_state()\n        \"\"\"\n        if opt.rand_pose >= 0: # =0 means only using CLIP loss, >0 means a hybrid mode.\n            from nerf.clip_utils import CLIPLoss\n            self.clip_loss = CLIPLoss(self.device)\n            self.clip_loss.prepare_text([self.opt.clip_text]) # only support one text prompt now...\n        \"\"\"\n\n    def __del__(self):\n        if self.log_ptr:\n            self.log_ptr.close()\n\n    def log(self, *args, **kwargs):\n        if self.local_rank == 0:\n            if not self.mute:\n                # print(*args)\n                self.console.print(*args, **kwargs)\n            if self.log_ptr:\n                print(*args, file=self.log_ptr)\n                self.log_ptr.flush()  # write immediately to file\n\n    def train(self, train_loader, valid_loader, max_epochs):\n        self.hard_rays_pool = [torch.tensor([]).cuda(), torch.tensor([]).cuda()]\n        self.is_hard_rays_pool_full = False\n\n        if self.use_tensorboardX and self.local_rank == 0:\n            self.writer = tensorboardX.SummaryWriter(\n          
      os.path.join(self.workspace, \"run\", self.name)\n            )\n\n        for p in self.model_tea.parameters():\n            p.requires_grad = False\n        self.model_tea.eval()\n\n        # get a ref to error_map\n        self.error_map = train_loader._data.error_map\n\n        if (\n            not self.args.use_real_data_for_train\n        ):  # using random poses to calculate max_epochs.\n            random_poses = get_rand_poses(\n                data_type=self.args.data_type,\n                original_loader=copy.deepcopy(\n                    train_loader._data.poses.detach().cpu().numpy()\n                ),\n            )\n            self.opt.iters = int(\n                (self.opt.iters // len(random_poses)) * len(random_poses)\n            )\n            max_epochs = np.ceil(self.opt.iters / len(random_poses)).astype(np.int32)\n            scheduler = lambda optimizer: optim.lr_scheduler.CosineAnnealingLR(\n                optimizer, T_max=self.opt.iters * 1, eta_min=7e-5\n            )  # update scheduler according to new opt.iters\n            self.lr_scheduler = scheduler(self.optimizer)\n\n        self.total_epoch = max_epochs\n        self.log(f\"\\n----------------total epoch:{max_epochs} -----------\\n\")\n\n        self.real_train_poses = copy.deepcopy(train_loader._data.poses)\n        for epoch in range(self.epoch, max_epochs + 1):\n            self.epoch = epoch\n            if not self.args.use_real_data_for_train:\n                print(f\"\\n generate new random poses at epoch{self.epoch}\")\n                random_poses = get_rand_poses(\n                    data_type=self.args.data_type,\n                    original_loader=self.real_train_poses.detach().cpu().numpy(),\n                )\n                train_loader._data.poses = copy.deepcopy(random_poses)\n                train_loader._data.images = train_loader._data.images[:1].expand(\n                    len(random_poses), -1, -1, -1\n                )\n                
train_loader = train_loader._data.dataloader()
            self.train_one_epoch(train_loader)
            print("\n", self.workspace, "\n")

            if (
                self.workspace is not None
                and self.local_rank == 0
                and self.epoch > max_epochs - 1
            ):
                self.save_checkpoint(full=False, best=False)

            if self.epoch % self.eval_interval == 0:
                self.evaluate_one_epoch(valid_loader)
                self.save_checkpoint(full=False, best=True)  # to save storage, temporarily skip saving the .pth here

        if self.use_tensorboardX and self.local_rank == 0:
            self.writer.close()

    def train_one_epoch(self, loader):
        total_loss = 0
        total_loss_rgb = 0
        total_loss_fea_sc = 0
        total_loss_sigma = 0
        total_loss_color = 0

        psnr_tool = PSNRMeter()
        psnr_tool.clear()
        self.pose_psnr = []  # [(pose1, psnr1), (pose2, psnr2), ...]

        if self.local_rank == 0 and self.report_metric_at_train:
            for metric in self.metrics:
                metric.clear()

        self.model_stu.train()
        self.model_tea.train()

        if self.world_size > 1:
            loader.sampler.set_epoch(self.epoch)

        if self.local_rank == 0:
            pbar = tqdm.tqdm(
                total=len(loader) * loader.batch_size,
                bar_format="{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
            )

        self.local_step = 0

        for data in loader:
            # update grid every 16 steps. 
 It should only run when training a teacher alone, not when distilling a student.
            if (
                self.model_tea.cuda_ray
                and self.global_step % self.opt.update_extra_interval == 0
            ):
                with torch.cuda.amp.autocast(enabled=self.fp16):
                    if self.opt.update_stu_extra:
                        self.model_stu.update_extra_state()

            self.local_step += 1
            self.global_step += 1
            self.args.global_step = self.global_step

            self.optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=self.fp16):
                (
                    preds,
                    truths,
                    loss,
                    loss_rgb,
                    loss_fea_sc,
                    loss_color,
                    loss_sigma,
                ) = self.train_step(data)
                if preds is not None:
                    psnr_tool.update(preds, truths)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_update_every_step:
                self.lr_scheduler.step()

            loss_val = loss.item()
            total_loss += loss_val
            total_loss_rgb += loss_rgb
            total_loss_sigma += loss_sigma
            total_loss_color += loss_color
            total_loss_fea_sc += loss_fea_sc

            if self.local_rank == 0:
                if self.report_metric_at_train and preds is not None:
                    for metric in self.metrics:
                        metric.update(preds, truths)

                if self.use_tensorboardX:
                    self.writer.add_scalar("train/loss", loss_val, self.global_step)
                    self.writer.add_scalar("train/loss_rgb", loss_rgb, self.global_step)
                    self.writer.add_scalar(
  
                      \"train/loss_fea_sc\", loss_fea_sc, self.global_step\n                    )\n                    self.writer.add_scalar(\n                        \"train/loss_coloc\", loss_color, self.global_step\n                    )\n                    self.writer.add_scalar(\n                        \"train/loss_sigma\", loss_sigma, self.global_step\n                    )\n                    self.writer.add_scalar(\n                        \"train/lr\",\n                        self.optimizer.param_groups[0][\"lr\"],\n                        self.global_step,\n                    )\n\n                if self.scheduler_update_every_step:  # run this\n                    cur_lr = self.optimizer.param_groups[0][\"lr\"]\n                    if self.global_step < self.args.stage_iters[\"stage1\"]:\n                        pbar.set_description(\n                            f\"loss={total_loss/self.local_step:.5f}, fea_sc={total_loss_fea_sc/self.local_step:.5f}, lr={cur_lr:.5f}\"\n                        )\n                    elif self.global_step < self.args.stage_iters[\"stage2\"]:\n                        pbar.set_description(\n                            f\"loss={total_loss/self.local_step:.5f}, fea_sc={total_loss_fea_sc/self.local_step:.5f}, sigma={total_loss_sigma/self.local_step:.5f}, color={total_loss_color/self.local_step:.5f}, lr={cur_lr:.6f}\"\n                        )\n                    else:\n                        pbar.set_description(\n                            f\"loss={total_loss/self.local_step:.5f}, rgb={total_loss_rgb/self.local_step:.5f},  lr={cur_lr:.5f}\"\n                        )\n                else:\n                    pbar.set_description(\n                        f\"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})\"\n                    )\n                pbar.update(loader.batch_size)\n\n            if (\n                self.opt.model_type == \"vm\"\n                and self.global_step in 
self.opt.upsample_model_steps
            ):
                # shrink
                if (
                    self.model_stu.cuda_ray
                ):  # and self.global_step == self.opt.upsample_model_steps[0]:
                    self.model_stu.shrink_model()

                # adaptive voxel size from aabb_train
                n_vox = self.upsample_resolutions.pop(0) ** 3  # n_voxels
                aabb = self.model_stu.aabb_train.cpu().numpy()
                vox_size = np.cbrt(np.prod(aabb[3:] - aabb[:3]) / n_vox)
                reso = ((aabb[3:] - aabb[:3]) / vox_size).astype(np.int32).tolist()
                self.log(
                    f"[INFO] upsample model at step {self.global_step} from {self.model_stu.resolution} to {reso}"
                )
                self.model_stu.upsample_model(reso)

                # reset optimizer since params changed.
                self.optimizer = self.optimizer_fn(self.model_stu)
                self.lr_scheduler = self.lr_scheduler_fn(self.optimizer)

        if self.ema is not None:
            self.ema.update()

        average_loss = total_loss / self.local_step
        self.stats["loss"].append(average_loss)

        if self.local_rank == 0:
            pbar.close()
            if self.report_metric_at_train:
                for metric in self.metrics:
                    self.log(metric.report(), style="red")
                    if self.use_tensorboardX:
                        metric.write(self.writer, self.epoch, prefix="train")
                    metric.clear()

        if not self.scheduler_update_every_step:
            if isinstance(
                self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
            ):
                self.lr_scheduler.step(average_loss)
            else:
                self.lr_scheduler.step()

        psnr_tool.psnr_list.sort()
      
  if self.global_step < self.args.stage_iters["stage1"]:
            self.log(
                f"tttttttttt> Train stage1 Epoch:{self.epoch}. loss_fea_sc:{total_loss_fea_sc/self.local_step:.6f}"
            )
        elif self.global_step < self.args.stage_iters["stage2"]:
            self.log(
                f"tttttttttt> Train stage2 Epoch:{self.epoch}. loss_fea_sc:{total_loss_fea_sc/self.local_step:.3f} loss_sigma:{total_loss_sigma/self.local_step:.3f} loss_color:{total_loss_color/self.local_step:.3f}"
            )
        else:
            self.log(
                f"tttttttttt> Train stage3 Epoch:{self.epoch}. loss_rgb:{total_loss_rgb/self.local_step:.3f} loss_fea_sc:{total_loss_fea_sc/self.local_step:.3f} loss_sigma:{total_loss_sigma/self.local_step:.3f} loss_color:{total_loss_color/self.local_step:.3f}"
            )
            self.log(
                f"tttttttttt> Train PSNR Epoch {self.epoch}. psnr_min:{psnr_tool.psnr_list[0]:.3f} psnr_max:{psnr_tool.psnr_list[-1]:.3f} psnr_mean:{np.mean(psnr_tool.psnr_list):.3f}"
            )

    def get_loss(self, pred, gt):
        if self.opt.loss_type == "L2":
            loss = torch.mean((gt - pred) ** 2)
        elif self.opt.loss_type == "normL2":
            loss = torch.norm(pred - gt)
        elif self.opt.loss_type == "normL1":
            loss = torch.norm(pred - gt, p=1)
        elif self.opt.loss_type == "smoothL1":
            loss = torch.nn.functional.smooth_l1_loss(pred, gt, beta=0.05)
        else:
            raise ValueError(f"unknown loss_type: {self.opt.loss_type}")
        return loss

    def train_step(self, data):
        rays_o = data["rays_o"]  # [B, N, 3]
        rays_d = data["rays_d"]  # [B, N, 3], N = rays per image (e.g. 4096)

        loss = 0.0

        # if there is no gt image, we would train with CLIP loss; this branch is disabled.
        if "images" not in data:
            raise NotImplementedError("training without GT images (CLIP loss) is disabled")
            B, N = rays_o.shape[:2]
            H, W = data["H"], data["W"]
     
       # currently fix white bg, MUST force all rays!\n            outputs = self.model.render(\n                rays_o,\n                rays_d,\n                staged=False,\n                bg_color=None,\n                perturb=True,\n                force_all_rays=True,\n                **vars(self.opt),\n            )\n            pred_rgb = (\n                outputs[\"image\"].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous()\n            )\n            loss = self.clip_loss(pred_rgb)\n            return pred_rgb, None, loss\n\n        images = data[\"images\"]  # [B, N, 3/4]\n        B, N, C = images.shape\n\n        # if self.opt.color_space == 'linear':\n        #    images[..., :3] = srgb_to_linear(images[..., :3])\n\n        if (\n            C == 3 or self.model_stu.bg_radius > 0\n        ):  #  C=4 in synthetic dataset. C=3 for real dataset\n            bg_color = 1\n        # train with random background color if not using a bg model and has alpha channel.\n        else:\n            bg_color = torch.rand(\n                [B, rays_o.size(1), 3], dtype=images.dtype, device=images.device\n            )\n\n        if self.opt.render_stu_first:\n            outputs_stu = self.model_stu.render(\n                rays_o,\n                rays_d,\n                staged=False,\n                bg_color=bg_color,\n                perturb=True,\n                force_all_rays=False,\n                **vars(self.opt),\n            )\n            pred_rgb_stu = outputs_stu[\"image\"]\n            with torch.no_grad():\n                outputs_tea = self.model_tea.render(\n                    rays_o,\n                    rays_d,\n                    staged=False,\n                    bg_color=bg_color,\n                    perturb=True,\n                    force_all_rays=False,\n                    inherited_params=outputs_stu[\"inherited_params\"],\n                    **vars(self.opt),\n                )\n                pred_rgb_tea = 
outputs_tea[\"image\"]\n        else:\n            with torch.no_grad():\n                outputs_tea = self.model_tea.render(\n                    rays_o,\n                    rays_d,\n                    staged=False,\n                    bg_color=bg_color,\n                    perturb=True,\n                    force_all_rays=False,\n                    **vars(self.opt),\n                )\n                pred_rgb_tea = outputs_tea[\"image\"]\n            outputs_stu = self.model_stu.render(\n                rays_o,\n                rays_d,\n                staged=False,\n                bg_color=bg_color,\n                perturb=True,\n                force_all_rays=False,\n                inherited_params=outputs_tea[\"inherited_params\"],\n                **vars(self.opt),\n            )\n            pred_rgb_stu = outputs_stu[\"image\"]\n        gt_rgb = pred_rgb_tea\n        self.opt.loss_rate_fea_sc = update_loss_rate(self.opt.loss_rate_fea_sc, 0.995)\n\n        if (\n            \"stage1\" in outputs_stu\n            and self.opt.loss_rate_fea_sc > 0.0\n            and self.model_stu.feature_sigma_color is not None\n            and self.model_tea.feature_sigma_color is not None\n        ):\n            assert (\n                self.model_stu.feature_sigma_color.shape\n                == self.model_tea.feature_sigma_color.shape\n            )\n            loss_fea_sc = self.get_loss(\n                self.model_stu.feature_sigma_color, self.model_tea.feature_sigma_color\n            )\n            loss = loss + self.opt.loss_rate_fea_sc * loss_fea_sc\n            return None, None, loss, 0, loss_fea_sc.detach().item(), 0, 0\n        if \"stage2\" in outputs_stu:\n            if self.opt.loss_rate_color > 0.0:\n                assert self.model_stu.color_l.shape == self.model_tea.color_l.shape\n                loss_color = self.get_loss(\n                    self.model_stu.color_l, self.model_tea.color_l\n                )\n                loss = loss + 
self.opt.loss_rate_color * loss_color\n            else:\n                assert self.model_stu.color_l.shape == self.model_tea.color_l.shape\n                loss_color = self.get_loss(\n                    self.model_stu.color_l, self.model_tea.color_l\n                )\n            if self.opt.loss_rate_sigma > 0.0:\n                assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape\n                loss_sigma = self.get_loss(\n                    self.model_stu.sigma_l, self.model_tea.sigma_l\n                )\n                loss = loss + self.opt.loss_rate_sigma * loss_sigma\n            else:\n                assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape\n                loss_sigma = self.get_loss(\n                    self.model_stu.sigma_l, self.model_tea.sigma_l\n                )\n            if (\n                self.opt.loss_rate_fea_sc > 0.0\n                and self.model_stu.feature_sigma_color is not None\n                and self.model_tea.feature_sigma_color is not None\n            ):\n                assert (\n                    self.model_stu.feature_sigma_color.shape\n                    == self.model_tea.feature_sigma_color.shape\n                )\n                loss_fea_sc = self.get_loss(\n                    self.model_stu.feature_sigma_color,\n                    self.model_tea.feature_sigma_color,\n                )\n                loss = loss + self.opt.loss_rate_fea_sc * loss_fea_sc\n            else:\n                loss_fea_sc = torch.tensor(0.0)\n            return (\n                None,\n                None,\n                loss,\n                0,\n                loss_fea_sc.detach().item(),\n                loss_color.detach().item(),\n                loss_sigma.detach().item(),\n            )\n\n        if self.opt.loss_type == \"normL2\":\n            loss_rgb = torch.norm(pred_rgb_tea - pred_rgb_stu)\n        elif self.opt.loss_type == \"normL1\":\n            loss_rgb = 
torch.norm(pred_rgb_tea - pred_rgb_stu, p=1)\n        elif self.opt.loss_type == \"L2\":\n            loss_rgb = self.criterion(pred_rgb_tea, pred_rgb_stu).mean(\n                -1\n            )  # [B, N, 3] --> [B, N]\n            if len(loss_rgb.shape) == 3:  # [K, B, N]\n                loss_rgb = loss_rgb.mean(0)\n            if self.error_map is not None:\n                index = data[\"index\"]  # [B]\n                inds = data[\"inds_coarse\"]  # [B, N]\n                error_map = self.error_map[index]  # [B, H * W]\n                error = loss_rgb.detach().to(\n                    error_map.device\n                )  # [B, N], already in [0, 1]\n                ema_error = 0.1 * error_map.gather(1, inds) + 0.9 * error  # ema update\n                error_map.scatter_(1, inds, ema_error)\n                self.error_map[index] = error_map  # put back\n            loss_rgb = loss_rgb.mean()\n        else:\n            raise ValueError(\"error loss_type\")\n        loss = loss + loss_rgb * self.opt.loss_rate_rgb\n\n        if self.opt.l1_reg_weight > 0.0 and self.opt.model_type == \"vm\":\n            loss = loss + self.model_stu.density_loss() * self.opt.l1_reg_weight\n        if (\n            self.opt.loss_rate_fea_sc > 0.0\n            and self.model_stu.feature_sigma_color is not None\n            and self.model_tea.feature_sigma_color is not None\n        ):\n            assert (\n                self.model_stu.feature_sigma_color.shape\n                == self.model_tea.feature_sigma_color.shape\n            )\n            loss_fea_sc = self.get_loss(\n                self.model_stu.feature_sigma_color, self.model_tea.feature_sigma_color\n            )\n            loss = loss + self.opt.loss_rate_fea_sc * loss_fea_sc\n        elif (\n            self.model_stu.feature_sigma_color is None\n            or self.model_tea.feature_sigma_color is None\n        ):\n            loss_fea_sc = torch.tensor(0.0)\n        else:\n            assert (\n         
       self.model_stu.feature_sigma_color.shape\n                == self.model_tea.feature_sigma_color.shape\n            )\n            loss_fea_sc = self.get_loss(\n                self.model_stu.feature_sigma_color, self.model_tea.feature_sigma_color\n            )\n        if self.opt.loss_rate_color > 0.0:\n            assert self.model_stu.color_l.shape == self.model_tea.color_l.shape\n            loss_color = self.get_loss(self.model_stu.color_l, self.model_tea.color_l)\n            loss = loss + self.opt.loss_rate_color * loss_color\n        else:\n            assert self.model_stu.color_l.shape == self.model_tea.color_l.shape\n            loss_color = self.get_loss(self.model_stu.color_l, self.model_tea.color_l)\n        if self.opt.loss_rate_sigma > 0.0:\n            assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape\n            loss_sigma = self.get_loss(self.model_stu.sigma_l, self.model_tea.sigma_l)\n            loss = loss + self.opt.loss_rate_sigma * loss_sigma\n        else:\n            assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape\n            loss_sigma = self.get_loss(self.model_stu.sigma_l, self.model_tea.sigma_l)\n\n        loss_rgb_show = self.criterion(\n            pred_rgb_tea.detach(), pred_rgb_stu.detach()\n        ).mean()  # [B, N, 3] --> [B, N]\n        return (\n            pred_rgb_stu,\n            gt_rgb,\n            loss,\n            loss_rgb_show.detach().item(),\n            loss_fea_sc.detach().item(),\n            loss_color.detach().item(),\n            loss_sigma.detach().item(),\n        )\n\n    ### ------------------------------\n\n    def evaluate(self, loader, name=None):\n        self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX\n        self.evaluate_one_epoch(loader, name)\n        self.use_tensorboardX = use_tensorboardX\n\n    def evaluate_one_epoch(self, loader, name=None):\n        if name is None:\n            name = 
f\"{self.name}_ep{self.epoch:04d}\"\n\n        total_loss = 0\n        if self.local_rank == 0:\n            for metric in self.metrics:\n                metric.clear()\n        if self.opt.test_teacher:\n            self.model_stu = self.model_tea\n        self.model_stu.eval()\n\n        if self.ema is not None:\n            self.ema.store()\n            self.ema.copy_to()\n\n        if self.local_rank == 0:\n            pbar = tqdm.tqdm(\n                total=len(loader) * loader.batch_size,\n                bar_format=\"{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\",\n            )\n\n        with torch.no_grad():\n            self.local_step = 0\n            self.ssim = 0.0\n            self.lpips_vgg = 0.0\n            self.lpips_alex = 0.0\n\n            # update grid\n            if self.model_stu.cuda_ray:\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    if self.opt.update_stu_extra:\n                        self.model_stu.update_extra_state()\n                    else:\n                        pass\n\n            frames = []\n            frames_depth = []\n            for data in loader:\n                self.local_step += 1\n\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    preds, preds_depth, truths, loss = self.eval_step(data)\n\n                # all_gather/reduce the statistics (NCCL only support all_*)\n                if self.world_size > 1:\n                    dist.all_reduce(loss, op=dist.ReduceOp.SUM)\n                    loss = loss / self.world_size\n                    preds_list = [\n                        torch.zeros_like(preds).to(self.device)\n                        for _ in range(self.world_size)\n                    ]  # [[B, ...], [B, ...], ...]\n                    dist.all_gather(preds_list, preds)\n                    preds = torch.cat(preds_list, dim=0)\n\n                    preds_depth_list = [\n                   
     torch.zeros_like(preds_depth).to(self.device)\n                        for _ in range(self.world_size)\n                    ]  # [[B, ...], [B, ...], ...]\n                    dist.all_gather(preds_depth_list, preds_depth)\n                    preds_depth = torch.cat(preds_depth_list, dim=0)\n\n                    truths_list = [\n                        torch.zeros_like(truths).to(self.device)\n                        for _ in range(self.world_size)\n                    ]  # [[B, ...], [B, ...], ...]\n                    dist.all_gather(truths_list, truths)\n                    truths = torch.cat(truths_list, dim=0)\n                loss_val = loss.item()\n                total_loss += loss_val\n\n                if self.local_rank == 0:\n\n                    for metric in self.metrics:\n                        metric.update(preds, truths)\n                    self.lpips_alex += rgb_lpips(truths, preds, \"alex\")\n                    self.lpips_vgg += rgb_lpips(truths, preds, \"vgg\")\n                    self.ssim += compute_ssim(\n                        preds,\n                        truths,\n                        max_val=max(preds.max().item(), truths.max().item()),\n                    ).item()\n\n                    # save image\n                    save_path = os.path.join(\n                        self.workspace,\n                        loader._data.type,\n                        f\"{name}_{self.local_step:04d}.png\",\n                    )\n                    save_path_depth = os.path.join(\n                        self.workspace,\n                        loader._data.type,\n                        f\"{name}_{self.local_step:04d}_depth.png\",\n                    )\n                    # save_path_gt = os.path.join(self.workspace, loader._data.type, f'{name}_{self.local_step:04d}_gt.png')\n\n                    os.makedirs(os.path.dirname(save_path), exist_ok=True)\n\n                    if self.opt.color_space == \"linear\":\n                  
      preds = linear_to_srgb(preds)

                    pred = preds[0].detach().cpu().numpy()
                    truth = truths[0].detach().cpu().numpy()
                    pred_depth = preds_depth[0].detach().cpu().numpy()
                    cv2.imwrite(
                        save_path,
                        cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR),
                    )
                    cv2.imwrite(save_path_depth, (pred_depth * 255).astype(np.uint8))
                    frames.append((pred * 255).astype(np.uint8))
                    frames_depth.append((pred_depth * 255).astype(np.uint8))

                    pbar.set_description(
                        f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})"
                    )
                    pbar.update(loader.batch_size)

            print(
                f"\n---- video frames: {len(frames)}, depth video frames: {len(frames_depth)} ----\n"
            )
            imageio.mimwrite(
                os.path.join(os.path.dirname(save_path), "video.mp4"),
                frames,
                fps=int(30 * 0.7),
                macro_block_size=8,
            )
            imageio.mimwrite(
                os.path.join(os.path.dirname(save_path), "video_depth.mp4"),
                frames_depth,
                fps=int(30 * 0.7),
                macro_block_size=8,
            )

        psnr_tool = self.metrics[0]

        psnr_tool.psnr_list.sort()
        self.log(
            f"\neeeeeeeee> {loader._data.type} PSNR Report: Epoch{self.epoch}.  
psnr_mean:{np.mean(psnr_tool.psnr_list):.2f}\"\n        )\n\n        average_loss = total_loss / self.local_step\n        self.stats[\"valid_loss\"].append(average_loss)\n\n        if self.local_rank == 0:\n            pbar.close()\n            if not self.use_loss_as_metric and len(self.metrics) > 0:\n                result = self.metrics[0].measure()\n                self.stats[\"results\"].append(\n                    result if self.best_mode == \"min\" else -result\n                )  # if max mode, use -result\n            else:\n                self.stats[\"results\"].append(\n                    average_loss\n                )  # if no metric, choose best by min loss\n\n            for metric in self.metrics:\n                # self.log(metric.report(), style=\"blue\")\n                psnr = metric.report().split(\"=\")[-1].strip()[:5]\n                self.psnr = float(psnr)\n                if self.use_tensorboardX and loader._data.type == 'val':\n                    metric.write(self.writer, self.epoch, prefix=\"evaluate\")\n                metric.clear()\n\n        self.ssim /= self.local_step\n        self.lpips_alex /= self.local_step\n        self.lpips_vgg /= self.local_step\n        if self.ema is not None:\n            self.ema.restore()\n        self.log(\n            f\"eeeeeeeeee> {loader._data.type} Metric Report: Epoch{self.epoch}. 
psnr:{psnr} ssim:{self.ssim:.2f} alex:{self.lpips_alex:.2f} vgg:{self.lpips_vgg:.2f}"
        )

    def eval_step(self, data):

        rays_o = data["rays_o"]  # [B, N, 3]
        rays_d = data["rays_d"]  # [B, N, 3]
        images = data["images"]  # [B, H, W, 3/4]
        B, H, W, C = images.shape

        if self.opt.color_space == "linear":
            images[..., :3] = srgb_to_linear(images[..., :3])

        # eval with fixed background color
        bg_color = 1
        if C == 4:
            gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (
                1 - images[..., 3:]
            )
        else:
            gt_rgb = images

        outputs = self.model_stu.render(
            rays_o,
            rays_d,
            staged=True,
            bg_color=bg_color,
            perturb=False,
            **vars(self.opt),
        )

        pred_rgb = outputs["image"].reshape(B, H, W, 3)
        pred_depth = outputs["depth"].reshape(B, H, W)

        loss = self.criterion(pred_rgb, gt_rgb).mean()

        return pred_rgb, pred_depth, gt_rgb, loss

    def save_checkpoint(self, name=None, full=False, best=False, remove_old=True):
        full = False  # NOTE: full checkpoints (optimizer/scaler state) are intentionally disabled; only model weights are saved
        if name is None:
            name = f"{self.name}_ep{self.epoch:04d}"
        if self.opt.model_type == "vm":
            state = {
                "epoch": self.epoch,
                "global_step": self.global_step,
                "stats": self.stats,
                "resolution": self.model_stu.resolution,
            }
        else:
            state = {
                "epoch": self.epoch,
                "global_step": self.global_step,
                "stats": self.stats,
            }

        if self.model_stu.cuda_ray:
            state["mean_count"] = self.model_stu.mean_count
            state["mean_density"] = self.model_stu.mean_density

        if full:
            
state[\"optimizer\"] = self.optimizer.state_dict()\n            state[\"lr_scheduler\"] = self.lr_scheduler.state_dict()\n            state[\"scaler\"] = self.scaler.state_dict()\n            if self.ema is not None:\n                state[\"ema\"] = self.ema.state_dict()\n\n        if not best:\n\n            state[\"model\"] = self.model_stu.state_dict()\n\n            file_path = f\"{self.ckpt_path}/{name}.pth\"\n\n            if remove_old:\n                self.stats[\"checkpoints\"].append(file_path)\n\n                if len(self.stats[\"checkpoints\"]) > self.max_keep_ckpt:\n                    old_ckpt = self.stats[\"checkpoints\"].pop(0)\n                    if os.path.exists(old_ckpt):\n                        os.remove(old_ckpt)\n\n            torch.save(state, file_path)\n\n        else:\n            if len(self.stats[\"results\"]) > 0:\n                if (\n                    self.stats[\"best_result\"] is None\n                    or self.stats[\"results\"][-1] < self.stats[\"best_result\"]\n                ):\n                    self.log(\n                        f\"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}\"\n                    )\n                    self.stats[\"best_result\"] = self.stats[\"results\"][-1]\n\n                    # save ema results\n                    if self.ema is not None:\n                        self.ema.store()\n                        self.ema.copy_to()\n\n                    state[\"model\"] = self.model_stu.state_dict()\n\n                    if self.ema is not None:\n                        self.ema.restore()\n\n                    torch.save(state, self.best_path)\n            else:\n                self.log(\n                    f\"[WARN] no evaluated results found, skip saving best checkpoint.\"\n                )\n\n    def load_teacher_checkpoint(self):\n        checkpoint_dict = torch.load(self.opt.ckpt_teacher, map_location=self.device)\n\n        missing_keys, 
unexpected_keys = self.model_tea.load_state_dict(\n            checkpoint_dict[\"model\"], strict=False\n        )\n        self.log(\"[INFO] loaded teacher model.\")\n        if len(missing_keys) > 0:\n            self.log(f\"[WARN] missing keys: {missing_keys}\")\n        if len(unexpected_keys) > 0:\n            self.log(f\"[WARN] unexpected keys: {unexpected_keys}\")\n        if self.ema is not None and \"ema\" in checkpoint_dict:\n            self.ema.load_state_dict(checkpoint_dict[\"ema\"])\n\n        if self.model_tea.cuda_ray:\n            if \"mean_count\" in checkpoint_dict:\n                self.model_tea.mean_count = checkpoint_dict[\"mean_count\"]\n            if \"mean_density\" in checkpoint_dict:\n                self.model_tea.mean_density = checkpoint_dict[\"mean_density\"]\n        \"\"\"\n        self.stats = checkpoint_dict['stats']\n        self.epoch = checkpoint_dict['epoch']\n        self.global_step = checkpoint_dict['global_step']\n        self.log(f\"[INFO] load at epoch {self.epoch}, global step {self.global_step}\")\n        \n        if self.optimizer and  'optimizer' in checkpoint_dict:\n            try:\n                self.optimizer.load_state_dict(checkpoint_dict['optimizer'])\n                self.log(\"[INFO] loaded optimizer.\")\n            except:\n                self.log(\"[WARN] Failed to load optimizer.\")\n        \n        if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:\n            try:\n                self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])\n                self.log(\"[INFO] loaded scheduler.\")\n            except:\n                self.log(\"[WARN] Failed to load scheduler.\")\n        \n        if self.scaler and 'scaler' in checkpoint_dict:\n            try:\n                self.scaler.load_state_dict(checkpoint_dict['scaler'])\n                self.log(\"[INFO] loaded scaler.\")\n            except:\n                self.log(\"[WARN] Failed to load scaler.\")\n\n\n       
 if self.model_tea.cuda_ray:\n            if 'mean_count' in checkpoint_dict:\n                self.model_tea.mean_count = checkpoint_dict['mean_count']\n            if 'mean_density' in checkpoint_dict:\n                self.model_tea.mean_density = checkpoint_dict['mean_density']\n        \"\"\"\n\n    def load_student_checkpoint(self):\n        if self.opt.ckpt_student:\n            checkpoint_dict = torch.load(\n                self.opt.ckpt_student, map_location=self.device\n            )\n        else:\n            checkpoint_dict = torch.load(\n                self.opt.ckpt_teacher, map_location=self.device\n            )\n\n        if self.opt.model_type == \"vm\" and \"resolution\" in checkpoint_dict:\n            self.model_stu.upsample_model(checkpoint_dict[\"resolution\"])\n        missing_keys, unexpected_keys = self.model_stu.load_state_dict(\n            checkpoint_dict[\"model\"], strict=False\n        )\n        self.log(\"[INFO] loaded student model.\")\n        if len(missing_keys) > 0:\n            self.log(f\"[WARN] missing keys: {missing_keys}\")\n        if len(unexpected_keys) > 0:\n            self.log(f\"[WARN] unexpected keys: {unexpected_keys}\")\n\n        if self.model_stu.cuda_ray:\n            if \"mean_count\" in checkpoint_dict:\n                self.model_stu.mean_count = checkpoint_dict[\"mean_count\"]\n            if \"mean_density\" in checkpoint_dict:\n                self.model_stu.mean_density = checkpoint_dict[\"mean_density\"]\n\n        if self.ema is not None and \"ema\" in checkpoint_dict:\n            self.ema.load_state_dict(checkpoint_dict[\"ema\"])\n\n        \"\"\"\n        self.stats = checkpoint_dict['stats']\n        self.epoch = checkpoint_dict['epoch']\n        self.global_step = checkpoint_dict['global_step']\n        self.log(f\"[INFO] load at epoch {self.epoch}, global step {self.global_step}\")\n        \n        if self.optimizer and  'optimizer' in checkpoint_dict:\n            try:\n                
self.optimizer.load_state_dict(checkpoint_dict['optimizer'])\n                self.log(\"[INFO] loaded optimizer.\")\n            except:\n                self.log(\"[WARN] Failed to load optimizer.\")\n        \n        if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:\n            try:\n                self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])\n                self.log(\"[INFO] loaded scheduler.\")\n            except:\n                self.log(\"[WARN] Failed to load scheduler.\")\n        \n        if self.scaler and 'scaler' in checkpoint_dict:\n            try:\n                self.scaler.load_state_dict(checkpoint_dict['scaler'])\n                self.log(\"[INFO] loaded scaler.\")\n            except:\n                self.log(\"[WARN] Failed to load scaler.\")\n        \"\"\"\n\n    def load_checkpoint(self, checkpoint=None, model_only=False):\n        if checkpoint is None:\n            checkpoint_list = sorted(glob.glob(f\"{self.ckpt_path}/{self.name}_ep*.pth\"))\n            if checkpoint_list:\n                checkpoint = checkpoint_list[-1]\n                self.log(f\"[INFO] Latest checkpoint is {checkpoint}\")\n            else:\n                self.log(\"[WARN] No checkpoint found, model randomly initialized.\")\n                return\n\n        checkpoint_dict = torch.load(checkpoint, map_location=self.device)\n\n        if \"model\" not in checkpoint_dict:\n            self.model.load_state_dict(checkpoint_dict)\n            self.log(\"[INFO] loaded model.\")\n            return\n\n        missing_keys, unexpected_keys = self.model.load_state_dict(\n            checkpoint_dict[\"model\"], strict=False\n        )\n        self.log(\"[INFO] loaded model.\")\n        if len(missing_keys) > 0:\n            self.log(f\"[WARN] missing keys: {missing_keys}\")\n        if len(unexpected_keys) > 0:\n            self.log(f\"[WARN] unexpected keys: {unexpected_keys}\")\n\n        if self.ema is not None and \"ema\" in 
checkpoint_dict:\n            self.ema.load_state_dict(checkpoint_dict[\"ema\"])\n\n        if self.model.cuda_ray:\n            if \"mean_count\" in checkpoint_dict:\n                self.model.mean_count = checkpoint_dict[\"mean_count\"]\n            if \"mean_density\" in checkpoint_dict:\n                self.model.mean_density = checkpoint_dict[\"mean_density\"]\n\n        if model_only:\n            return\n\n        self.stats = checkpoint_dict[\"stats\"]\n        self.epoch = checkpoint_dict[\"epoch\"]\n        self.global_step = checkpoint_dict[\"global_step\"]\n        self.log(f\"[INFO] load at epoch {self.epoch}, global step {self.global_step}\")\n\n        if self.optimizer and \"optimizer\" in checkpoint_dict:\n            try:\n                self.optimizer.load_state_dict(checkpoint_dict[\"optimizer\"])\n                self.log(\"[INFO] loaded optimizer.\")\n            except Exception:\n                self.log(\"[WARN] Failed to load optimizer.\")\n\n        if self.lr_scheduler and \"lr_scheduler\" in checkpoint_dict:\n            try:\n                self.lr_scheduler.load_state_dict(checkpoint_dict[\"lr_scheduler\"])\n                self.log(\"[INFO] loaded scheduler.\")\n            except Exception:\n                self.log(\"[WARN] Failed to load scheduler.\")\n\n        if self.scaler and \"scaler\" in checkpoint_dict:\n            try:\n                self.scaler.load_state_dict(checkpoint_dict[\"scaler\"])\n                self.log(\"[INFO] loaded scaler.\")\n            except Exception:\n                self.log(\"[WARN] Failed to load scaler.\")\n\n    def test(self, loader, save_path=None, name=None):\n        if save_path is None:\n            save_path = os.path.join(self.workspace, \"results\")\n\n        if name is None:\n            name = f\"{self.name}_ep{self.epoch:04d}\"\n\n        os.makedirs(save_path, exist_ok=True)\n\n        self.log(f\"==> Start Test, save results to {save_path}\")\n\n        pbar = tqdm.tqdm(\n            total=len(loader) * loader.batch_size,\n            bar_format=\"{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\",\n        )\n        self.model_stu.eval()\n        with torch.no_grad():\n\n            # update grid\n            if self.model_stu.cuda_ray:\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    self.model_stu.update_extra_state()\n\n            for i, data in enumerate(loader):\n\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    preds, preds_depth = self.test_step(data)\n\n                path = os.path.join(save_path, f\"{name}_{i:04d}.png\")\n                path_depth = os.path.join(save_path, f\"{name}_{i:04d}_depth.png\")\n\n                # self.log(f\"[INFO] saving test image to {path}\")\n\n                if self.opt.color_space == \"linear\":\n                    preds = linear_to_srgb(preds)\n\n                pred = preds[0].detach().cpu().numpy()\n                pred_depth = preds_depth[0].detach().cpu().numpy()\n\n                cv2.imwrite(\n                    path, cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)\n                )\n                cv2.imwrite(path_depth, (pred_depth * 255).astype(np.uint8))\n\n                pbar.update(loader.batch_size)\n\n        self.log(\"==> Finished Test.\")\n\n    # moved out bg_color and perturb for more flexible control...\n    def test_step(self, data, bg_color=None, perturb=False):\n\n        rays_o = data[\"rays_o\"]  # [B, N, 3]\n        rays_d = data[\"rays_d\"]  # [B, N, 3]\n        H, W = data[\"H\"], data[\"W\"]\n\n        if bg_color is not None:\n            bg_color = bg_color.to(self.device)\n\n        outputs = self.model_stu.render(\n            rays_o,\n            rays_d,\n            staged=True,\n            bg_color=bg_color,\n            perturb=perturb,\n            **vars(self.opt),\n        )\n\n        pred_rgb = outputs[\"image\"].reshape(-1, H, W, 3)\n        pred_depth = outputs[\"depth\"].reshape(-1, H, W)\n\n        return pred_rgb, pred_depth\n"
  },
  {
    "path": "gridencoder/__init__.py",
    "content": "from .grid import GridEncoder\n"
  },
  {
    "path": "gridencoder/backend.py",
    "content": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n]\n\nif os.name == \"posix\":\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\n_backend = load(\n    name=\"_grid_encoder\",\n    extra_cflags=c_flags,\n    extra_cuda_cflags=nvcc_flags,\n    sources=[\n        os.path.join(_src_path, \"src\", f)\n        for f in [\n            \"gridencoder.cu\",\n            \"bindings.cpp\",\n        ]\n    ],\n)\n\n__all__ = [\"_backend\"]\n"
  },
  {
    "path": "gridencoder/grid.py",
    "content": "import numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.autograd.function import once_differentiable\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\ntry:\n    import _gridencoder as _backend\nexcept ImportError:\n    from .backend import _backend\n\n_gridtype_to_id = {\n    \"hash\": 0,\n    \"tiled\": 1,\n}\n\n\nclass _grid_encode(Function):\n    @staticmethod\n    @custom_fwd\n    def forward(\n        ctx,\n        inputs,\n        embeddings,\n        offsets,\n        per_level_scale,\n        base_resolution,\n        calc_grad_inputs=False,\n        gridtype=0,\n        align_corners=False,\n    ):\n        # inputs: [B, D], float in [0, 1]\n        # embeddings: [sO, C], float\n        # offsets: [L + 1], int\n        # RETURN: [B, F], float\n\n        inputs = inputs.contiguous()\n\n        B, D = inputs.shape  # batch size, coord dim\n        L = offsets.shape[0] - 1  # level\n        C = embeddings.shape[1]  # embedding dim for each level\n        S = np.log2(\n            per_level_scale\n        )  # resolution multiplier at each level, apply log2 for later CUDA exp2f\n        H = base_resolution  # base resolution\n\n        # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)\n        # if C % 2 != 0, force float, since half for atomicAdd is very slow.\n        if torch.is_autocast_enabled() and C % 2 == 0:\n            embeddings = embeddings.to(torch.half)\n\n        # L first, optimize cache for cuda kernel, but needs an extra permute later\n        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)\n\n        if calc_grad_inputs:\n            dy_dx = torch.empty(\n                B, L * D * C, device=inputs.device, dtype=embeddings.dtype\n            )\n        else:\n            dy_dx = torch.empty(\n                1, device=inputs.device, dtype=embeddings.dtype\n            )  # placeholder... 
TODO: a better way?\n\n        _backend.grid_encode_forward(\n            inputs,\n            embeddings,\n            offsets,\n            outputs,\n            B,\n            D,\n            C,\n            L,\n            S,\n            H,\n            calc_grad_inputs,\n            dy_dx,\n            gridtype,\n            align_corners,\n        )\n\n        # permute back to [B, L * C]\n        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)\n\n        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)\n        ctx.dims = [B, D, C, L, S, H, gridtype]\n        ctx.calc_grad_inputs = calc_grad_inputs\n        ctx.align_corners = align_corners\n\n        return outputs\n\n    @staticmethod\n    # @once_differentiable\n    @custom_bwd\n    def backward(ctx, grad):\n\n        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors\n        B, D, C, L, S, H, gridtype = ctx.dims\n        calc_grad_inputs = ctx.calc_grad_inputs\n        align_corners = ctx.align_corners\n\n        # grad: [B, L * C] --> [L, B, C]\n        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()\n\n        grad_embeddings = torch.zeros_like(embeddings)\n\n        if calc_grad_inputs:\n            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)\n        else:\n            grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype)\n\n        _backend.grid_encode_backward(\n            grad,\n            inputs,\n            embeddings,\n            offsets,\n            grad_embeddings,\n            B,\n            D,\n            C,\n            L,\n            S,\n            H,\n            calc_grad_inputs,\n            dy_dx,\n            grad_inputs,\n            gridtype,\n            align_corners,\n        )\n\n        if calc_grad_inputs:\n            grad_inputs = grad_inputs.to(inputs.dtype)\n            return grad_inputs, grad_embeddings, None, None, None, None, None, None\n        else:\n            return None, 
grad_embeddings, None, None, None, None, None, None\n\n\ngrid_encode = _grid_encode.apply\n\n\nclass GridEncoder(nn.Module):\n    def __init__(\n        self,\n        input_dim=3,\n        num_levels=16,\n        level_dim=2,\n        per_level_scale=2,\n        base_resolution=16,\n        log2_hashmap_size=19,\n        desired_resolution=None,\n        gridtype=\"hash\",\n        align_corners=False,\n    ):\n        super().__init__()\n\n        # the finest resolution desired at the last level; if provided, overrides per_level_scale\n        if desired_resolution is not None:\n            per_level_scale = np.exp2(\n                np.log2(desired_resolution / base_resolution) / (num_levels - 1)\n            )\n\n        self.input_dim = input_dim  # coord dims, 2 or 3\n        self.num_levels = num_levels  # num levels; each level multiplies the resolution by per_level_scale\n        self.level_dim = level_dim  # encode channels per level\n        self.per_level_scale = (\n            per_level_scale  # multiply resolution by this scale at each level.\n        )\n        self.log2_hashmap_size = log2_hashmap_size\n        self.base_resolution = base_resolution\n        self.output_dim = num_levels * level_dim\n        self.gridtype = gridtype\n        self.gridtype_id = _gridtype_to_id[gridtype]  # \"tiled\" or \"hash\"\n        self.align_corners = align_corners\n\n        # allocate parameters\n        offsets = []\n        offset = 0\n        self.max_params = 2 ** log2_hashmap_size\n        for i in range(num_levels):\n            resolution = int(np.ceil(base_resolution * per_level_scale ** i))\n            params_in_level = min(\n                self.max_params,\n                (resolution if align_corners else resolution + 1) ** input_dim,\n            )  # limit max number\n            params_in_level = int(np.ceil(params_in_level / 8) * 8)  # make divisible\n            offsets.append(offset)\n            offset += params_in_level\n        offsets.append(offset)\n        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))\n        self.register_buffer(\"offsets\", offsets)\n\n        self.n_params = offsets[-1] * level_dim\n\n        # parameters\n        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        std = 1e-4\n        self.embeddings.data.uniform_(-std, std)\n\n    def __repr__(self):\n        return f\"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}\"\n\n    def forward(self, inputs, bound=1):\n        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]\n        # return: [..., num_levels * level_dim]\n\n        inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]\n\n        # print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())\n\n        prefix_shape = list(inputs.shape[:-1])\n        inputs = inputs.view(-1, self.input_dim)\n\n        outputs = grid_encode(\n            inputs,\n            self.embeddings,\n            self.offsets,\n            self.per_level_scale,\n            self.base_resolution,\n            inputs.requires_grad,\n            self.gridtype_id,\n            self.align_corners,\n        )\n        outputs = outputs.view(prefix_shape + [self.output_dim])\n\n        # print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())\n\n        return outputs\n"
  },
  {
    "path": "gridencoder/setup.py",
    "content": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n]\n\nif os.name == \"posix\":\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\nsetup(\n    name=\"gridencoder\",  # package name, import this to use python API\n    ext_modules=[\n        CUDAExtension(\n            name=\"_gridencoder\",  # extension name, import this to use CUDA API\n            sources=[\n                os.path.join(_src_path, \"src\", f)\n                for f in [\n                    \"gridencoder.cu\",\n                    \"bindings.cpp\",\n                ]\n            ],\n            extra_compile_args={\n                \"cxx\": c_flags,\n                \"nvcc\": nvcc_flags,\n            },\n        ),\n    ],\n    cmdclass={\n        \"build_ext\": BuildExtension,\n    },\n)\n"
  },
  {
    "path": "gridencoder/src/bindings.cpp",
    "content": "#include <torch/extension.h>\n\n#include \"gridencoder.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"grid_encode_forward\", &grid_encode_forward, \"grid_encode_forward (CUDA)\");\n    m.def(\"grid_encode_backward\", &grid_encode_backward, \"grid_encode_backward (CUDA)\");\n}"
  },
  {
    "path": "gridencoder/src/gridencoder.cu",
    "content": "#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/torch.h>\n\n#include <algorithm>\n#include <stdexcept>\n\n#include <stdint.h>\n#include <cstdio>\n\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be a contiguous tensor\")\n#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x \" must be an int tensor\")\n#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x \" must be a floating tensor\")\n\n\n// just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF...\nstatic inline  __device__ at::Half atomicAdd(at::Half *address, at::Half val) {\n  // requires CUDA >= 10 and ARCH >= 70\n  // this is very slow compared to float or __half2, and never used.\n  //return atomicAdd(reinterpret_cast<__half*>(address), val);\n}\n\n\ntemplate <typename T>\nstatic inline __host__ __device__ T div_round_up(T val, T divisor) {\n    return (val + divisor - 1) / divisor;\n}\n\n\ntemplate <uint32_t D>\n__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {\n    static_assert(D <= 7, \"fast_hash can only hash up to 7 dimensions.\");\n\n    // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence\n    // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional\n    // coordinates.\n    constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 };\n\n    uint32_t result = 0;\n    #pragma unroll\n    for (uint32_t i = 0; i < D; ++i) {\n        result ^= pos_grid[i] * primes[i];\n    }\n\n    return result;\n}\n\n\ntemplate <uint32_t D, uint32_t C>\n__device__ uint32_t get_grid_index(const 
uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {\n    uint32_t stride = 1;\n    uint32_t index = 0;\n\n    #pragma unroll\n    for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {\n        index += pos_grid[d] * stride;\n        stride *= align_corners ? resolution: (resolution + 1);\n    }\n\n    // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.\n    // gridtype: 0 == hash, 1 == tiled\n    if (gridtype == 0 && stride > hashmap_size) {\n        index = fast_hash<D>(pos_grid);\n    }\n\n    return (index % hashmap_size) * C + ch;\n}\n\n\ntemplate <typename scalar_t, uint32_t D, uint32_t C>\n__global__ void kernel_grid(\n    const float * __restrict__ inputs, \n    const scalar_t * __restrict__ grid, \n    const int * __restrict__ offsets, \n    scalar_t * __restrict__ outputs, \n    const uint32_t B, const uint32_t L, const float S, const uint32_t H,\n    const bool calc_grad_inputs, \n    scalar_t * __restrict__ dy_dx,\n    const uint32_t gridtype,\n    const bool align_corners\n) {\n    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;\n    \n    if (b >= B) return;\n\n    const uint32_t level = blockIdx.y;\n    \n    // locate\n    grid += (uint32_t)offsets[level] * C;\n    inputs += b * D;\n    outputs += level * B * C + b * C;\n\n    // check input range (should be in [0, 1])\n    bool flag_oob = false;\n    #pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        if (inputs[d] < 0 || inputs[d] > 1) {\n            flag_oob = true;\n        }\n    }\n    // if input out of bound, just set output to 0\n    if (flag_oob) {\n        #pragma unroll\n        for (uint32_t ch = 0; ch < C; ch++) {\n            outputs[ch] = 0; \n        }\n        if (calc_grad_inputs) {\n            dy_dx += b * D * L * C + level * D * C; // B L D C\n            #pragma unroll\n            for (uint32_t d 
= 0; d < D; d++) {\n                #pragma unroll\n                for (uint32_t ch = 0; ch < C; ch++) {\n                    dy_dx[d * C + ch] = 0; \n                }       \n            }\n        }\n        return;\n    }\n\n    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];\n    const float scale = exp2f(level * S) * H - 1.0f;\n    const uint32_t resolution = (uint32_t)ceil(scale) + 1;\n    \n    // calculate coordinate\n    float pos[D];\n    uint32_t pos_grid[D];\n\n    #pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);\n        pos_grid[d] = floorf(pos[d]);\n        pos[d] -= (float)pos_grid[d];\n    }\n\n    //printf(\"[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\\n\", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);\n\n    // interpolate\n    scalar_t results[C] = {0}; // temp results in register\n\n    #pragma unroll\n    for (uint32_t idx = 0; idx < (1 << D); idx++) {\n        float w = 1;\n        uint32_t pos_grid_local[D];\n\n        #pragma unroll\n        for (uint32_t d = 0; d < D; d++) {\n            if ((idx & (1 << d)) == 0) {\n                w *= 1 - pos[d];\n                pos_grid_local[d] = pos_grid[d];\n            } else {\n                w *= pos[d];\n                pos_grid_local[d] = pos_grid[d] + 1;\n            }\n        }\n\n        uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);\n\n        // writing to register (fast)\n        #pragma unroll\n        for (uint32_t ch = 0; ch < C; ch++) {\n            results[ch] += w * grid[index + ch];\n        }\n\n        //printf(\"[b=%d, l=%d] int %d, idx %d, w %f, val %f\\n\", b, level, idx, index, w, grid[index]);\n    }    \n\n    // writing to global memory (slow)\n    #pragma unroll\n    for (uint32_t ch = 0; ch < C; ch++) {\n        outputs[ch] = results[ch]; \n    }\n\n    // prepare dy_dx for calc_grad_inputs\n    // differentiable 
(soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9\n    if (calc_grad_inputs) {\n\n        dy_dx += b * D * L * C + level * D * C; // B L D C\n\n        #pragma unroll\n        for (uint32_t gd = 0; gd < D; gd++) {\n\n            scalar_t results_grad[C] = {0};\n\n            #pragma unroll\n            for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {\n                float w = scale;\n                uint32_t pos_grid_local[D];\n\n                #pragma unroll\n                for (uint32_t nd = 0; nd < D - 1; nd++) {\n                    const uint32_t d = (nd >= gd) ? (nd + 1) : nd;\n\n                    if ((idx & (1 << nd)) == 0) {\n                        w *= 1 - pos[d];\n                        pos_grid_local[d] = pos_grid[d];\n                    } else {\n                        w *= pos[d];\n                        pos_grid_local[d] = pos_grid[d] + 1;\n                    }\n                }\n\n                pos_grid_local[gd] = pos_grid[gd];\n                uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);\n                pos_grid_local[gd] = pos_grid[gd] + 1;\n                uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);\n\n                #pragma unroll\n                for (uint32_t ch = 0; ch < C; ch++) {\n                    results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]);\n                }\n            }\n\n            #pragma unroll\n            for (uint32_t ch = 0; ch < C; ch++) {\n                dy_dx[gd * C + ch] = results_grad[ch];\n            }\n        }\n    }\n}\n\n\ntemplate <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>\n__global__ void kernel_grid_backward(\n    const scalar_t * __restrict__ grad,\n    const float * __restrict__ inputs, \n    const scalar_t * __restrict__ grid, \n    const int * __restrict__ offsets, \n    
scalar_t * __restrict__ grad_grid, \n    const uint32_t B, const uint32_t L, const float S, const uint32_t H,\n    const uint32_t gridtype,\n    const bool align_corners\n) {\n    const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;\n    if (b >= B) return;\n\n    const uint32_t level = blockIdx.y;\n    const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;\n\n    // locate\n    grad_grid += offsets[level] * C;\n    inputs += b * D;\n    grad += level * B * C + b * C + ch; // L, B, C\n\n    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];\n    const float scale = exp2f(level * S) * H - 1.0f;\n    const uint32_t resolution = (uint32_t)ceil(scale) + 1;\n\n    // check input range (should be in [0, 1])\n    #pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        if (inputs[d] < 0 || inputs[d] > 1) {\n            return; // grad is init as 0, so we simply return.\n        }\n    }\n\n    // calculate coordinate\n    float pos[D];\n    uint32_t pos_grid[D];\n\n    #pragma unroll\n    for (uint32_t d = 0; d < D; d++) {\n        pos[d] = inputs[d] * scale + (align_corners ? 
0.0f : 0.5f);\n        pos_grid[d] = floorf(pos[d]);\n        pos[d] -= (float)pos_grid[d];\n    }\n\n    scalar_t grad_cur[N_C] = {0}; // fetch to register\n    #pragma unroll\n    for (uint32_t c = 0; c < N_C; c++) {\n        grad_cur[c] = grad[c];\n    }\n\n    // interpolate\n    #pragma unroll\n    for (uint32_t idx = 0; idx < (1 << D); idx++) {\n        float w = 1;\n        uint32_t pos_grid_local[D];\n\n        #pragma unroll\n        for (uint32_t d = 0; d < D; d++) {\n            if ((idx & (1 << d)) == 0) {\n                w *= 1 - pos[d];\n                pos_grid_local[d] = pos_grid[d];\n            } else {\n                w *= pos[d];\n                pos_grid_local[d] = pos_grid[d] + 1;\n            }\n        }\n\n        uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);\n\n        // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0\n        // TODO: use float which is better than __half, if N_C % 2 != 0\n        if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {\n            #pragma unroll\n            for (uint32_t c = 0; c < N_C; c += 2) {\n                // process two __half at once (by interpreting as a __half2)\n                __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};\n                atomicAdd((__half2*)&grad_grid[index + c], v);\n            }\n        // float, or __half when N_C % 2 != 0 (which means C == 1)\n        } else {\n            #pragma unroll\n            for (uint32_t c = 0; c < N_C; c++) {\n                atomicAdd(&grad_grid[index + c], w * grad_cur[c]);\n            }\n        }\n    }    \n}\n\n\ntemplate <typename scalar_t, uint32_t D, uint32_t C>\n__global__ void kernel_input_backward(\n    const scalar_t * __restrict__ grad,\n    const scalar_t * __restrict__ dy_dx,  \n    scalar_t * __restrict__ grad_inputs, \n    uint32_t B, uint32_t L\n) {\n    const uint32_t t = 
threadIdx.x + blockIdx.x * blockDim.x;\n    if (t >= B * D) return;\n\n    const uint32_t b = t / D;\n    const uint32_t d = t - b * D;\n\n    dy_dx += b * L * D * C;\n\n    scalar_t result = 0;\n    \n    # pragma unroll\n    for (int l = 0; l < L; l++) {\n        # pragma unroll\n        for (int ch = 0; ch < C; ch++) {\n            result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];\n        }\n    }\n\n    grad_inputs[t] = result;\n}\n\n\ntemplate <typename scalar_t, uint32_t D>\nvoid kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {\n    static constexpr uint32_t N_THREAD = 512;\n    const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };\n    switch (C) {\n        case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;\n        case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;\n        case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;\n        case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;\n        default: throw std::runtime_error{\"GridEncoding: C must be 1, 2, 4, or 8.\"};\n    }\n}\n\n// inputs: [B, D], float, in [0, 1]\n// embeddings: [sO, C], float\n// offsets: [L + 1], uint32_t\n// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)\n// H: base 
resolution\n// dy_dx: [B, L * D * C]\ntemplate <typename scalar_t>\nvoid grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {\n    switch (D) {\n        case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;\n        case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;\n        default: throw std::runtime_error{\"GridEncoding: D must be 2 or 3.\"};\n    }\n}\n\ntemplate <typename scalar_t, uint32_t D>\nvoid kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {\n    static constexpr uint32_t N_THREAD = 256;\n    const uint32_t N_C = std::min(2u, C); // n_features_per_thread\n    const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };\n    switch (C) {\n        case 1: \n            kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); \n            if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);\n            break;\n        case 2: \n            kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, 
align_corners);\n            if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);\n            break;\n        case 4: \n            kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);\n            if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);\n            break;\n        case 8: \n            kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);\n            if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);\n            break;\n        default: throw std::runtime_error{\"GridEncoding: C must be 1, 2, 4, or 8.\"};\n    }\n}\n\n\n// grad: [L, B, C], float\n// inputs: [B, D], float, in [0, 1]\n// embeddings: [sO, C], float\n// offsets: [L + 1], uint32_t\n// grad_embeddings: [sO, C]\n// H: base resolution\ntemplate <typename scalar_t>\nvoid grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {\n    switch (D) {\n        case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;\n        case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, 
align_corners); break;\n        default: throw std::runtime_error{\"GridEncoding: D must be 2 or 3.\"};\n    }\n}\n\n\n\nvoid grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners) {\n    CHECK_CUDA(inputs);\n    CHECK_CUDA(embeddings);\n    CHECK_CUDA(offsets);\n    CHECK_CUDA(outputs);\n    CHECK_CUDA(dy_dx);\n\n    CHECK_CONTIGUOUS(inputs);\n    CHECK_CONTIGUOUS(embeddings);\n    CHECK_CONTIGUOUS(offsets);\n    CHECK_CONTIGUOUS(outputs);\n    CHECK_CONTIGUOUS(dy_dx);\n\n    CHECK_IS_FLOATING(inputs);\n    CHECK_IS_FLOATING(embeddings);\n    CHECK_IS_INT(offsets);\n    CHECK_IS_FLOATING(outputs);\n    CHECK_IS_FLOATING(dy_dx);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    embeddings.scalar_type(), \"grid_encode_forward\", ([&] {\n        grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), gridtype, align_corners);\n    }));\n}\n\nvoid grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners) {\n    CHECK_CUDA(grad);\n    CHECK_CUDA(inputs);\n    CHECK_CUDA(embeddings);\n    CHECK_CUDA(offsets);\n    CHECK_CUDA(grad_embeddings);\n    CHECK_CUDA(dy_dx);\n    CHECK_CUDA(grad_inputs);\n\n    CHECK_CONTIGUOUS(grad);\n    CHECK_CONTIGUOUS(inputs);\n    CHECK_CONTIGUOUS(embeddings);\n    CHECK_CONTIGUOUS(offsets);\n    CHECK_CONTIGUOUS(grad_embeddings);\n    CHECK_CONTIGUOUS(dy_dx);\n    CHECK_CONTIGUOUS(grad_inputs);\n\n    CHECK_IS_FLOATING(grad);\n    CHECK_IS_FLOATING(inputs);\n    CHECK_IS_FLOATING(embeddings);\n    CHECK_IS_INT(offsets);\n    CHECK_IS_FLOATING(grad_embeddings);\n    CHECK_IS_FLOATING(dy_dx);\n    CHECK_IS_FLOATING(grad_inputs);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    grad.scalar_type(), \"grid_encode_backward\", ([&] {\n        grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>(), gridtype, align_corners);\n    }));\n}\n"
  },
  {
    "path": "gridencoder/src/gridencoder.h",
    "content": "#ifndef _HASH_ENCODE_H\n#define _HASH_ENCODE_H\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n// inputs: [B, D], float, in [0, 1]\n// embeddings: [sO, C], float\n// offsets: [L + 1], uint32_t\n// outputs: [L, B, C], float\n// H: base resolution\nvoid grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners);\nvoid grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners);\n\n#endif"
  },
  {
    "path": "just_train_tea/network.py",
    "content": "import torch\nfrom time import time\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom tools.encoding import get_encoder\nfrom tools.activation import trunc_exp\nfrom .renderer import NeRFRenderer\nimport raymarching\n\n\nclass NeRFNetwork(NeRFRenderer):\n    def __init__(\n        self,\n        encoding=\"hashgrid\",\n        encoding_dir=\"sphere_harmonics\",\n        encoding_bg=\"hashgrid\",\n        num_layers=2,\n        hidden_dim=64,\n        geo_feat_dim=15,\n        num_layers_color=3,\n        hidden_dim_color=64,\n        num_layers_bg=2,\n        hidden_dim_bg=64,\n        bound=1,\n        model_type=\"hash\",\n        args=None,\n        is_teacher=False,\n        **kwargs,\n    ):\n        super().__init__(bound, **kwargs)\n        # sigma network\n        assert model_type in [\"hash\", \"mlp\", \"vm\", \"tensors\"]\n        self.is_teacher = is_teacher\n        self.num_layers = num_layers\n        self.hidden_dim = hidden_dim\n        self.geo_feat_dim = geo_feat_dim\n        self.args = args\n        self.opt = args\n        self.model_type = model_type\n\n        self.plenoxel_degree = args.plenoxel_degree\n        self.plenoxel_res = eval(args.plenoxel_res)\n\n        assert len(self.plenoxel_res) == 3\n\n        self.encoder, self.in_dim = get_encoder(\n            encoding,\n            desired_resolution=2048 * bound,\n            num_levels=14,\n        )\n\n        if \"hash\" != self.model_type:\n            self.encoder = None\n\n        if self.model_type == \"mlp\":\n            self.encoder_nerf_pe, self.in_dim_nerf = get_encoder(\n                encoding=\"frequency\", multires=self.args.PE\n            )\n            self.skips = self.args.skip\n            self.nerf_layer_num = self.args.nerf_layer_num\n            W = self.args.nerf_layer_wide\n            self.nerf_mlp = [nn.Linear(self.in_dim_nerf, W)]\n            for i in range(self.nerf_layer_num - 2):\n                if i != self.skips:\n     
               self.nerf_mlp.append(nn.Linear(W, W))\n                else:\n                    self.nerf_mlp.append(nn.Linear(W + self.in_dim_nerf, W))\n            self.nerf_mlp.append(nn.Linear(W, self.in_dim))\n            self.nerf_mlp = nn.ModuleList(self.nerf_mlp)\n\n        elif self.model_type == \"vm\":\n            self.sigma_rank = [16] * 3\n            self.color_rank = [48] * 3\n            self.color_feat_dim = 15  # geo_feat_dim\n            self.mat_ids = [[0, 1], [0, 2], [1, 2]]\n            self.vec_ids = [2, 1, 0]\n            self.resolution = [self.opt.resolution0] * 3\n            # mat: ParameterList of [1, 16, res0, res0] x3; vec: ParameterList of [1, 16, res0, 1] x3. Three copies because the 3D grid [H, W, D] is decomposed into three 2D matrices [H, W], [H, D], [W, D] and three 1D vectors [D], [W], [H]\n            self.sigma_mat, self.sigma_vec = self.init_one_vm(\n                self.sigma_rank, self.resolution\n            )\n            # mat: ParameterList of [1, 48, res0, res0] x3; vec: ParameterList of [1, 48, res0, 1] x3\n            self.color_mat, self.color_vec = self.init_one_vm(\n                self.color_rank, self.resolution\n            )\n            # Linear(in_features=144, out_features=15)\n            self.basis_mat = nn.Linear(\n                sum(self.color_rank), self.color_feat_dim, bias=False\n            )\n        elif self.model_type == \"tensors\":\n            self.init_plenoxel_volume(\n                s=0.02,\n                fea_dim=self.plenoxel_degree ** 2 * 3 + 1,\n                volume=self.plenoxel_res,\n            )\n\n        elif self.model_type == \"hash\":\n            pass\n        else:\n            raise ValueError(f\"unknown model_type: {self.model_type}\")\n\n        if self.model_type != \"vm\" and self.model_type != \"tensors\":\n            sigma_net = []\n            for l in range(num_layers):\n                if l == 0:\n                    in_dim = self.in_dim\n                else:\n                    in_dim = 
hidden_dim\n\n                if l == num_layers - 1:\n                    out_dim = (\n                        1 + self.geo_feat_dim\n                    )  # 1 sigma + 15 SH features for color\n                else:\n                    out_dim = hidden_dim\n\n                sigma_net.append(nn.Linear(in_dim, out_dim, bias=False))\n\n            self.sigma_net = nn.ModuleList(sigma_net)\n\n        # color network\n        self.num_layers_color = num_layers_color\n        self.hidden_dim_color = hidden_dim_color\n        # self.encoder_dir, self.in_dim_dir = get_encoder(encoding=encoding_dir)\n        if self.model_type == \"tensors\":\n            self.encoder_dir, self.in_dim_dir = get_encoder(\n                encoding=\"sphere_harmonics\",\n                degree=self.plenoxel_degree,\n            )\n\n        else:\n            self.encoder_dir, self.in_dim_dir = get_encoder(\n                encoding=encoding_dir, input_dim=3, multires=2\n            )\n\n        if self.model_type != \"tensors\":\n            color_net = []\n            for l in range(num_layers_color):\n                if l == 0:\n                    in_dim = self.in_dim_dir + self.geo_feat_dim\n                else:\n                    in_dim = hidden_dim\n\n                if l == num_layers_color - 1:\n                    out_dim = 3  # 3 rgb\n                else:\n                    out_dim = hidden_dim\n\n                color_net.append(nn.Linear(in_dim, out_dim, bias=False))\n\n            self.color_net = nn.ModuleList(color_net)\n\n        # background network\n        if self.bg_radius > 0:\n            self.num_layers_bg = num_layers_bg\n            self.hidden_dim_bg = hidden_dim_bg\n            self.encoder_bg, self.in_dim_bg = get_encoder(\n                encoding_bg,\n                input_dim=2,\n                num_levels=4,\n                log2_hashmap_size=19,\n                desired_resolution=2048,\n            )  # much smaller hashgrid\n\n            bg_net = 
[]\n            for l in range(num_layers_bg):\n                if l == 0:\n                    in_dim = self.in_dim_bg + self.in_dim_dir\n                else:\n                    in_dim = hidden_dim_bg\n\n                if l == num_layers_bg - 1:\n                    out_dim = 3  # 3 rgb\n                else:\n                    out_dim = hidden_dim_bg\n\n                bg_net.append(nn.Linear(in_dim, out_dim, bias=False))\n\n            self.bg_net = nn.ModuleList(bg_net)\n        else:\n            self.bg_net = None\n\n    def init_plenoxel_volume(self, s=0.1, fea_dim=27 + 1, volume=[128, 128, 128]):\n        tensor = []\n        tensor.append(\n            torch.nn.Parameter(\n                s * torch.randn((1, fea_dim, volume[0], volume[1], volume[2]))\n            )\n        )\n        self.tensor_volume = torch.nn.ParameterList(tensor).cuda()\n\n    def init_one_vm(self, n_component, resolution, scale=0.1):\n        # self.mat_ids = [[0, 1], [0, 2], [1, 2]]  self.vec_ids = [2, 1, 0]\n        mat, vec = [], []\n\n        for i in range(len(self.vec_ids)):\n            vec_id = self.vec_ids[i]\n            mat_id_0, mat_id_1 = self.mat_ids[i]\n            mat.append(\n                nn.Parameter(\n                    scale\n                    * torch.randn(\n                        (1, n_component[i], resolution[mat_id_1], resolution[mat_id_0])\n                    )\n                )\n            )  # [1, R, H, W]\n            vec.append(\n                nn.Parameter(\n                    scale * torch.randn((1, n_component[i], resolution[vec_id], 1))\n                )\n            )  # [1, R, D, 1] (fake 2d to use grid_sample)\n\n        return nn.ParameterList(mat), nn.ParameterList(vec)\n\n    def get_sigma_feat(self, x):\n        # x: [N, 3], in [-1, 1] (outliers will be treated as zero due to grid_sample padding mode)\n        # self.mat_ids = [[0, 1], [0, 2], [1, 2]]  self.vec_ids = [2, 1, 0]\n        N = x.shape[0]\n\n        # plane + 
line basis\n        mat_coord = (\n            torch.stack(\n                (\n                    x[..., self.mat_ids[0]],\n                    x[..., self.mat_ids[1]],\n                    x[..., self.mat_ids[2]],\n                )\n            )\n            .detach()\n            .view(3, -1, 1, 2)\n        )  # [3, N, 1, 2]\n        vec_coord = torch.stack(\n            (x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])\n        )\n        vec_coord = (\n            torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1)\n            .detach()\n            .view(3, -1, 1, 2)\n        )  # [3, N, 1, 2], fake 2d coord\n\n        sigma_feat = torch.zeros(\n            [\n                N,\n            ],\n            device=x.device,\n        )\n\n        for i in range(len(self.sigma_mat)):\n            mat_feat = F.grid_sample(\n                self.sigma_mat[i], mat_coord[[i]], align_corners=True\n            ).view(\n                -1, N\n            )  # [1, R, N, 1] --> [R, N]\n            vec_feat = F.grid_sample(\n                self.sigma_vec[i], vec_coord[[i]], align_corners=True\n            ).view(\n                -1, N\n            )  # [R, N]\n            sigma_feat = sigma_feat + torch.sum(mat_feat * vec_feat, dim=0)\n\n        return sigma_feat\n\n    def get_color_feat(self, x):\n        # x: [N, 3], in [-1, 1]\n        N = x.shape[0]\n\n        # plane + line basis\n        mat_coord = (\n            torch.stack(\n                (\n                    x[..., self.mat_ids[0]],\n                    x[..., self.mat_ids[1]],\n                    x[..., self.mat_ids[2]],\n                )\n            )\n            .detach()\n            .view(3, -1, 1, 2)\n        )  # [3, N, 1, 2]\n        vec_coord = torch.stack(\n            (x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])\n        )\n        vec_coord = (\n            torch.stack((torch.zeros_like(vec_coord), vec_coord), 
dim=-1)\n            .detach()\n            .view(3, -1, 1, 2)\n        )  # [3, N, 1, 2], fake 2d coord\n\n        mat_feat, vec_feat = [], []\n        for i in range(len(self.color_mat)):\n            mat_feat.append(\n                F.grid_sample(\n                    self.color_mat[i], mat_coord[[i]], align_corners=True\n                ).view(-1, N)\n            )  # [1, R, N, 1] --> [R, N]\n            vec_feat.append(\n                F.grid_sample(\n                    self.color_vec[i], vec_coord[[i]], align_corners=True\n                ).view(-1, N)\n            )  # [R, N]\n\n        mat_feat = torch.cat(mat_feat, dim=0)  # [3 * R, N]\n        vec_feat = torch.cat(vec_feat, dim=0)  # [3 * R, N]\n\n        color_feat = self.basis_mat(\n            (mat_feat * vec_feat).T\n        )  # [N, 3R] --> [N, color_feat_dim]\n\n        return color_feat\n\n    def compute_plenoxel_fea(self, x):\n        composed = self.tensor_volume[0]\n        composed = (\n            F.grid_sample(composed, x.view(1, 1, -1, 1, 3), align_corners=True)\n            .view(-1, x.shape[0])\n            .permute(1, 0)\n        )\n        return composed  # [N, fea_dim]\n\n    def forward_nerf_mlp(self, x):\n        x = self.encoder_nerf_pe(x)\n        in_pts = x\n        for i in range(len(self.nerf_mlp)):\n            x = self.nerf_mlp[i](x)\n            if i != len(self.nerf_mlp) - 1:\n                x = F.relu(x, inplace=True)\n            if i == self.skips:\n                x = torch.cat([in_pts, x], -1)\n        return x\n\n    def forward(self, x, d):\n        # x: [N, 3], in [-bound, bound]  d: [N, 3], normalized in [-1, 1]\n        # sigma\n        if self.model_type == \"hash\":\n            x = self.encoder(\n                x, bound=self.bound\n            )  # out_x[N, 28=num_levels * fea_per_level]\n        elif self.model_type == \"mlp\":\n            x = self.forward_nerf_mlp(x)  # 28\n        elif self.model_type == \"vm\":\n            x = (\n                2\n   
             * (x - self.aabb_train[:3])\n                / (self.aabb_train[3:] - self.aabb_train[:3])\n                - 1\n            )  # x:[N, 3]\n            sigma_feat = self.get_sigma_feat(x)  # sigma_feat:[N]\n            color_feat = self.get_color_feat(x)  # color_feat:[N, 15]\n            sigma_feat = torch.clamp(\n                sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max\n            )\n            # color_feat = torch.clamp(color_feat, self.args.sigma_clip_min, self.args.sigma_clip_max)\n            self.feature_sigma_color = torch.cat(\n                [sigma_feat.unsqueeze(-1), color_feat], dim=-1\n            )\n            self.sigma_l = sigma_feat\n            sigma = trunc_exp(sigma_feat)  # sigma:[N]\n            enc_d = self.encoder_dir(d)  # enc_d:[N, 16]\n            h = torch.cat([enc_d, color_feat], dim=-1)  # h:[N, 16+15]\n            for l in range(self.num_layers_color):\n                h = self.color_net[l](h)\n                if l != self.num_layers_color - 1:\n                    h = F.relu(h, inplace=True)\n\n            color = torch.sigmoid(h)\n            self.color_l = color\n\n            return sigma, color\n        elif self.model_type == \"tensors\":\n            x = (\n                2\n                * (x - self.aabb_train[:3])\n                / (self.aabb_train[3:] - self.aabb_train[:3])\n                - 1\n            )  # x:[N, 3]\n            x = self.compute_plenoxel_fea(x)\n            h = x\n            sigma = torch.clamp(\n                h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max\n            )\n            self.sigma_l = sigma\n            sigma = trunc_exp(sigma)\n            self.sigma = sigma\n            sh = h[..., 1:].view(\n                -1, 3, self.plenoxel_degree ** 2\n            )  # [N, 3, 9]   ## .permute(1, 0, 2)  # [B, 27]-->[9, B, 3]\n            enc_d = self.encoder_dir(d).unsqueeze(1)  # [N, 9]-->[N,1,9]\n            color = (sh * enc_d).sum(-1)  
# [N, 3]\n            color = torch.sigmoid(color)\n            self.feature_sigma_color = None\n            self.color_l = color\n            return sigma, color\n\n        else:\n            raise ValueError(f\"unknown model_type: {self.model_type}\")\n\n        h = x\n        for l in range(self.num_layers):\n            h = self.sigma_net[l](h)\n            if l != self.num_layers - 1:\n                h = F.relu(h, inplace=True)\n        h[..., 0] = torch.clamp(\n            h[..., 0].clone(), self.args.sigma_clip_min, self.args.sigma_clip_max\n        )\n        # h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max)\n        self.feature_sigma_color = h\n        self.sigma_l = h[..., 0]\n        sigma = trunc_exp(h[..., 0])  # sigma: [n]\n        geo_feat = h[..., 1:]  # geo_feat: [n, 15]\n\n        d = self.encoder_dir(d)  # d: [n, 16]\n        h = torch.cat([d, geo_feat], dim=-1)  # h: [n, 15+16]\n        for l in range(self.num_layers_color):\n            h = self.color_net[l](h)\n            if l != self.num_layers_color - 1:\n                h = F.relu(h, inplace=True)\n\n        color = torch.sigmoid(h)\n        self.color_l = color\n        return sigma, color\n\n    def density(self, x):\n        # x: [N, 3], in [-bound, bound]\n        if self.model_type == \"hash\":\n            x = self.encoder(\n                x, bound=self.bound\n            )  # out_x[N, 28=num_levels * fea_per_level]\n        elif self.model_type == \"mlp\":\n            x = self.forward_nerf_mlp(x)\n        elif self.model_type == \"vm\":\n            x = (\n                2\n                * (x - self.aabb_train[:3])\n                / (self.aabb_train[3:] - self.aabb_train[:3])\n                - 1\n            )\n            sigma_feat = self.get_sigma_feat(x)\n            sigma_feat = torch.clamp(\n                sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max\n            )\n            sigma = trunc_exp(sigma_feat)\n            return {\"sigma\": sigma}\n        elif self.model_type == \"tensors\":\n            x = (\n                2\n                * (x - self.aabb_train[:3])\n                / (self.aabb_train[3:] - self.aabb_train[:3])\n                - 1\n            )  # x:[N, 3]\n            x = self.compute_plenoxel_fea(x)\n            h = x\n            # clamp before trunc_exp, consistent with forward()\n            sigma = trunc_exp(\n                torch.clamp(\n                    h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max\n                )\n            )\n            return {\"sigma\": sigma}\n\n        else:\n            raise ValueError(f\"unknown model_type: {self.model_type}\")\n\n        h = x\n        for l in range(self.num_layers):\n            h = self.sigma_net[l](h)\n            if l != self.num_layers - 1:\n                h = F.relu(h, inplace=True)\n\n        h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max)\n        sigma = trunc_exp(h[..., 0])\n        geo_feat = h[..., 1:]\n\n        return {\n            \"sigma\": sigma,\n            \"geo_feat\": geo_feat,\n        }\n\n    def background(self, x, d):\n        raise NotImplementedError  # intentionally disabled in this model\n        # x: [N, 2], in [-1, 1]\n\n        h = self.encoder_bg(x)  # [N, C]\n        d = self.encoder_dir(d)\n\n        h = torch.cat([d, h], dim=-1)\n        for l in range(self.num_layers_bg):\n            h = self.bg_net[l](h)\n            if l != self.num_layers_bg - 1:\n                h = F.relu(h, inplace=True)\n\n        # sigmoid activation for rgb\n        rgbs = torch.sigmoid(h)\n\n        return rgbs\n\n    # allow masked inference\n    def color(self, x, d, mask=None, geo_feat=None, **kwargs):\n        raise NotImplementedError  # intentionally disabled in this model\n        # x: [N, 3] in [-bound, bound]\n        # mask: [N,], bool, indicates where we actually need to compute rgb.\n\n        if mask is not None:\n            rgbs = torch.zeros(\n                
mask.shape[0], 3, dtype=x.dtype, device=x.device\n            )  # [N, 3]\n            # in case of empty mask\n            if not mask.any():\n                return rgbs\n            x = x[mask]\n            d = d[mask]\n            geo_feat = geo_feat[mask]\n\n        d = self.encoder_dir(d)\n        h = torch.cat([d, geo_feat], dim=-1)\n        for l in range(self.num_layers_color):\n            h = self.color_net[l](h)\n            if l != self.num_layers_color - 1:\n                h = F.relu(h, inplace=True)\n\n        # sigmoid activation for rgb\n        h = torch.sigmoid(h)\n\n        if mask is not None:\n            rgbs[mask] = h.to(rgbs.dtype)  # fp16 --> fp32\n        else:\n            rgbs = h\n\n        return rgbs\n\n    # L1 penalty for loss\n    def density_loss(self):\n        loss = 0\n        for i in range(len(self.sigma_mat)):\n            loss = (\n                loss\n                + torch.mean(torch.abs(self.sigma_mat[i]))\n                + torch.mean(torch.abs(self.sigma_vec[i]))\n            )\n        return loss\n\n    # upsample utils\n    @torch.no_grad()\n    def upsample_params(self, mat, vec, resolution):\n\n        for i in range(len(self.vec_ids)):\n            vec_id = self.vec_ids[i]\n            mat_id_0, mat_id_1 = self.mat_ids[i]\n            mat[i] = nn.Parameter(\n                F.interpolate(\n                    mat[i].data,\n                    size=(resolution[mat_id_1], resolution[mat_id_0]),\n                    mode=\"bilinear\",\n                    align_corners=True,\n                )\n            )\n            vec[i] = nn.Parameter(\n                F.interpolate(\n                    vec[i].data,\n                    size=(resolution[vec_id], 1),\n                    mode=\"bilinear\",\n                    align_corners=True,\n                )\n            )\n\n    @torch.no_grad()\n    def upsample_model(self, resolution):\n        self.upsample_params(self.sigma_mat, self.sigma_vec, resolution)\n  
      self.upsample_params(self.color_mat, self.color_vec, resolution)\n        self.resolution = resolution\n\n    @torch.no_grad()\n    def shrink_model(self):\n        # shrink aabb_train and the model so it only represents the space inside aabb_train.\n\n        half_grid_size = self.bound / self.grid_size\n        thresh = min(self.density_thresh, self.mean_density)\n\n        valid_grid = self.density_grid[self.cascade - 1] > thresh  # [N]\n        valid_pos = raymarching.morton3D_invert(\n            torch.nonzero(valid_grid)\n        )  # [Nz] --> [Nz, 3], in [0, H - 1]\n        # plot_pointcloud(valid_pos.detach().cpu().numpy()) # lots of noisy outliers in hashnerf...\n        valid_pos = (2 * valid_pos / (self.grid_size - 1) - 1) * (\n            self.bound - half_grid_size\n        )  # [Nz, 3], in [-b+hgs, b-hgs]\n        min_pos = valid_pos.amin(0) - half_grid_size  # [3]\n        max_pos = valid_pos.amax(0) + half_grid_size  # [3]\n\n        # shrink model\n        reso = torch.LongTensor(self.resolution).to(self.aabb_train.device)\n        units = (self.aabb_train[3:] - self.aabb_train[:3]) / reso\n        tl = (min_pos - self.aabb_train[:3]) / units\n        br = (max_pos - self.aabb_train[:3]) / units\n        tl = torch.round(tl).long().clamp(min=0)\n        br = torch.minimum(torch.round(br).long(), reso)\n\n        for i in range(len(self.vec_ids)):\n            vec_id = self.vec_ids[i]\n            mat_id_0, mat_id_1 = self.mat_ids[i]\n\n            self.sigma_vec[i] = nn.Parameter(\n                self.sigma_vec[i].data[..., tl[vec_id] : br[vec_id], :]\n            )\n            self.color_vec[i] = nn.Parameter(\n                self.color_vec[i].data[..., tl[vec_id] : br[vec_id], :]\n            )\n\n            self.sigma_mat[i] = nn.Parameter(\n                self.sigma_mat[i].data[\n                    ..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0]\n                ]\n            )\n            self.color_mat[i] = 
nn.Parameter(\n                self.color_mat[i].data[\n                    ..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0]\n                ]\n            )\n\n        self.aabb_train = torch.cat([min_pos, max_pos], dim=0)  # [6]\n\n        print(\n            f\"[INFO] shrink slice: {tl.cpu().numpy().tolist()} - {br.cpu().numpy().tolist()}\"\n        )\n        print(f\"[INFO] new aabb: {self.aabb_train.cpu().numpy().tolist()}\")\n\n    # optimizer utils\n    def get_params(self, lr, lr2=1e-3):\n        if self.model_type == \"hash\":\n            params = [\n                {\"params\": self.encoder.parameters(), \"lr\": lr},\n                {\"params\": self.sigma_net.parameters(), \"lr\": lr},\n                {\"params\": self.encoder_dir.parameters(), \"lr\": lr},\n                {\"params\": self.color_net.parameters(), \"lr\": lr},\n            ]\n        elif self.model_type == \"mlp\":\n            params = [\n                {\"params\": self.sigma_net.parameters(), \"lr\": lr},\n                {\"params\": self.encoder_dir.parameters(), \"lr\": lr},\n                {\"params\": self.color_net.parameters(), \"lr\": lr},\n                {\"params\": self.nerf_mlp.parameters(), \"lr\": lr},\n            ]\n        elif self.model_type == \"vm\":\n            params = [\n                {\"params\": self.color_net.parameters(), \"lr\": lr2},\n                {\"params\": self.sigma_mat, \"lr\": lr},\n                {\"params\": self.sigma_vec, \"lr\": lr},\n                {\"params\": self.color_mat, \"lr\": lr},\n                {\"params\": self.color_vec, \"lr\": lr},\n                {\"params\": self.basis_mat.parameters(), \"lr\": lr2},\n            ]\n        elif self.model_type == \"tensors\":\n            params = [\n                {\"params\": self.tensor_volume.parameters(), \"lr\": lr},\n                {\"params\": self.encoder_dir.parameters(), \"lr\": lr},\n            ]\n\n        else:\n            raise 
ValueError(f\"unsupported model_type: {self.model_type}\")\n\n        if self.bg_radius > 0:\n            params.append({\"params\": self.encoder_bg.parameters(), \"lr\": lr})\n            params.append({\"params\": self.bg_net.parameters(), \"lr\": lr})\n\n        return params\n"
  },
  {
    "path": "just_train_tea/provider.py",
    "content": "import os\nimport cv2\nimport glob\nimport json\nimport tqdm\nimport numpy as np\nfrom scipy.spatial.transform import Slerp, Rotation\n\nimport trimesh\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom .utils import get_rays, srgb_to_linear\n\n\n# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50\ndef nerf_matrix_to_ngp(pose, scale=0.33):\n    # for the fox dataset, 0.33 scales camera radius to ~ 2\n    new_pose = np.array(\n        [\n            [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale],\n            [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale],\n            [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale],\n            [0, 0, 0, 1],\n        ],\n        dtype=np.float32,\n    )\n    return new_pose\n\n\ndef rand_poses(\n    size,\n    device,\n    radius=1,\n    theta_range=[np.pi / 3, 2 * np.pi / 3],\n    phi_range=[0, 2 * np.pi],\n):\n    \"\"\"generate random poses from an orbit camera\n    Args:\n        size: batch size of generated poses.\n        device: where to allocate the output.\n        radius: camera radius\n        theta_range: [min, max], should be in [0, \\pi]\n        phi_range: [min, max], should be in [0, 2\\pi]\n    Return:\n        poses: [size, 4, 4]\n    \"\"\"\n\n    def normalize(vectors):\n        return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)\n\n    thetas = (\n        torch.rand(size, device=device) * (theta_range[1] - theta_range[0])\n        + theta_range[0]\n    )\n    phis = (\n        torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]\n    )\n\n    centers = torch.stack(\n        [\n            radius * torch.sin(thetas) * torch.sin(phis),\n            radius * torch.cos(thetas),\n            radius * torch.sin(thetas) * torch.cos(phis),\n        ],\n        dim=-1,\n    )  # [B, 3]\n\n    # lookat\n    forward_vector 
= -normalize(centers)\n    up_vector = (\n        torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)\n    )  # confused at the coordinate system...\n    right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1))\n    up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1))\n\n    poses = (\n        torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)\n    )\n    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)\n    poses[:, :3, 3] = centers\n\n    return poses\n\n\nclass NeRFDataset:\n    def __init__(self, opt, device, type=\"train\", downscale=1, n_test=10):\n        super().__init__()\n\n        
self.opt = opt\n        self.args = opt\n        self.device = device\n        self.type = type  # train, val, test\n        self.downscale = downscale\n        self.root_path = opt.path\n        self.mode = opt.mode  # only support blender\n        self.preload = opt.preload  # preload data into GPU\n        self.scale = (\n            opt.scale\n        )  # camera radius scale to make sure camera are inside the bounding box.\n        self.bound = (\n            opt.bound\n        )  # bounding box half length, also used as the radius to random sample poses.\n        self.fp16 = opt.fp16  # if preload, load into fp16.\n\n        self.training = self.type in [\"train\", \"all\", \"trainval\"]\n        self.num_rays = self.opt.num_rays if self.training else -1\n\n        if self.mode == \"blender\":\n            if type == \"all\":\n                transform_paths = glob.glob(os.path.join(self.root_path, \"*.json\"))\n                transform = None\n                for transform_path in transform_paths:\n                    with open(transform_path, \"r\") as f:\n                        tmp_transform = json.load(f)\n                        if transform is None:\n                            transform = tmp_transform\n                        else:\n                            transform[\"frames\"].extend(tmp_transform[\"frames\"])\n            # load train and val split\n            elif type == \"trainval\":\n                with open(\n                    os.path.join(self.root_path, f\"transforms_train.json\"), \"r\"\n                ) as f:\n                    transform = json.load(f)\n                with open(\n                    os.path.join(self.root_path, f\"transforms_val.json\"), \"r\"\n                ) as f:\n                    transform_val = json.load(f)\n                transform[\"frames\"].extend(transform_val[\"frames\"])\n            # only load one specified split\n            else:\n                with open(\n                    
os.path.join(self.root_path, f\"transforms_{type}.json\"), \"r\"\n                ) as f:\n                    transform = json.load(f)\n\n        else:\n            raise NotImplementedError(f\"unknown dataset mode: {self.mode}\")\n\n        # load image size\n        if \"h\" in transform and \"w\" in transform:\n            self.H = int(transform[\"h\"]) // downscale\n            self.W = int(transform[\"w\"]) // downscale\n        else:\n            # we have to actually read an image to get H and W later.\n            self.H = self.W = None\n        # read images\n        frames = transform[\"frames\"]\n        if True:\n            self.poses = []\n            self.images = []\n            for f in tqdm.tqdm(frames, desc=f\"Loading {type} data:\"):\n                f_path = os.path.join(self.root_path, f[\"file_path\"])\n                if (\n                    self.mode == \"blender\"\n                    and f_path[-4:].lower() != \".png\"\n                    and f_path[-4:].lower() != \".jpg\"\n                ):\n                    f_path += \".png\"  # so silly...\n                if not os.path.exists(f_path):\n                    continue\n                pose = np.array(f[\"transform_matrix\"], dtype=np.float32)  # [4, 4]\n                pose = nerf_matrix_to_ngp(pose, scale=self.scale)\n\n                image = cv2.imread(\n                    f_path, cv2.IMREAD_UNCHANGED\n                )  # [H, W, 3] or [H, W, 4]\n                if self.H is None or self.W is None:\n                    self.H = image.shape[0] // downscale\n                    self.W = image.shape[1] // downscale\n\n                # add support for the alpha channel as a mask.\n                if image.shape[-1] == 3:\n                    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n                else:\n                    image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)\n\n                if image.shape[0] != self.H or image.shape[1] != self.W:\n                    
image = cv2.resize(\n                        image, (self.W, self.H), interpolation=cv2.INTER_AREA\n                    )\n\n                image = image.astype(np.float32) / 255  # [H, W, 3/4]\n\n                self.poses.append(pose)\n                self.images.append(image)\n        self.poses = torch.from_numpy(np.stack(self.poses, axis=0))  # [N, 4, 4]\n        if self.images is not None:\n            self.images = torch.from_numpy(\n                np.stack(self.images, axis=0)\n            )  # [N, H, W, C]\n        self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item()\n\n        if self.training and self.opt.error_map:\n            self.error_map = torch.ones(\n                [self.images.shape[0], 128 * 128], dtype=torch.float\n            )  # [B, 128 * 128], flattened for easy indexing, fixed resolution...\n        else:\n            self.error_map = None\n\n        if self.preload:\n            self.poses = self.poses.to(self.device)\n            if self.images is not None:\n                if self.fp16 and self.opt.color_space != \"linear\":\n                    dtype = torch.half\n                else:\n                    dtype = torch.float\n                self.images = self.images.to(dtype).to(self.device)\n            if self.error_map is not None:\n                self.error_map = self.error_map.to(self.device)\n\n        # load intrinsics\n        if \"fl_x\" in transform or \"fl_y\" in transform:\n            fl_x = (\n                transform[\"fl_x\"] if \"fl_x\" in transform else transform[\"fl_y\"]\n            ) / downscale\n            fl_y = (\n                transform[\"fl_y\"] if \"fl_y\" in transform else transform[\"fl_x\"]\n            ) / downscale\n        elif \"camera_angle_x\" in transform or \"camera_angle_y\" in transform:\n            # blender, assert in radians. 
already downscaled since we use H/W\n            fl_x = (\n                self.W / (2 * np.tan(transform[\"camera_angle_x\"] / 2))\n                if \"camera_angle_x\" in transform\n                else None\n            )\n            fl_y = (\n                self.H / (2 * np.tan(transform[\"camera_angle_y\"] / 2))\n                if \"camera_angle_y\" in transform\n                else None\n            )\n            if fl_x is None:\n                fl_x = fl_y\n            if fl_y is None:\n                fl_y = fl_x\n        else:\n            raise RuntimeError(\n                \"Failed to load focal length, please check the transforms.json!\"\n            )\n\n        cx = (transform[\"cx\"] / downscale) if \"cx\" in transform else (self.W / 2)\n        cy = (transform[\"cy\"] / downscale) if \"cy\" in transform else (self.H / 2)\n\n        self.intrinsics = np.array([fl_x, fl_y, cx, cy])\n\n    def collate(self, index):\n\n        B = len(index)  # a list of length 1\n        poses = self.poses[index].to(self.device)  # [B, 4, 4]\n\n        error_map = None if self.error_map is None else self.error_map[index]\n        rays = get_rays(\n            poses, self.intrinsics, self.H, self.W, self.num_rays, error_map\n        )\n        results = {\n            \"H\": self.H,\n            \"W\": self.W,\n            \"rays_o\": rays[\"rays_o\"],\n            \"rays_d\": rays[\"rays_d\"],\n        }\n\n        if self.images is not None:\n            images = self.images[index].to(self.device)  # [B, H, W, 3/4]\n            if self.training:\n                C = images.shape[-1]\n                images = torch.gather(\n                    images.view(B, -1, C), 1, torch.stack(C * [rays[\"inds\"]], -1)\n                )  # [B, N, 3/4]\n            results[\"images\"] = images\n\n        # need inds to update error_map\n        if error_map is not None:\n            results[\"index\"] = index\n            results[\"inds_coarse\"] = rays[\"inds_coarse\"]\n\n 
       return results\n\n    def dataloader(self):\n        size = len(self.poses)\n        loader = DataLoader(\n            list(range(size)),\n            batch_size=1,\n            collate_fn=self.collate,\n            shuffle=self.training,\n            num_workers=0,\n        )\n        loader._data = self\n        return loader\n"
  },
  {
    "path": "just_train_tea/renderer.py",
    "content": "import math\nimport trimesh\nimport numpy as np\nfrom time import time\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport raymarching\nfrom .utils import custom_meshgrid\n\n\ndef sample_pdf(bins, weights, n_samples, det=False):\n    # This implementation is from NeRF\n    # bins: [B, T], old_z_vals\n    # weights: [B, T - 1], bin weights.\n    # return: [B, n_samples], new_z_vals\n\n    # Get pdf\n    weights = weights + 1e-5  # prevent nans\n    pdf = weights / torch.sum(weights, -1, keepdim=True)\n    cdf = torch.cumsum(pdf, -1)\n    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)\n    # Take uniform samples\n    if det:\n        u = torch.linspace(\n            0.0 + 0.5 / n_samples, 1.0 - 0.5 / n_samples, steps=n_samples\n        ).to(weights.device)\n        u = u.expand(list(cdf.shape[:-1]) + [n_samples])\n    else:\n        u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)\n\n    # Invert CDF\n    u = u.contiguous()\n    inds = torch.searchsorted(cdf, u, right=True)\n    below = torch.max(torch.zeros_like(inds - 1), inds - 1)\n    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)\n    inds_g = torch.stack([below, above], -1)  # (B, n_samples, 2)\n\n    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]\n    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)\n    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)\n\n    denom = cdf_g[..., 1] - cdf_g[..., 0]\n    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)\n    t = (u - cdf_g[..., 0]) / denom\n    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])\n\n    return samples\n\n\ndef plot_pointcloud(pc, color=None):\n    # pc: [N, 3]\n    # color: [N, 3/4]\n    print(\"[visualize points]\", pc.shape, pc.dtype, pc.min(0), pc.max(0))\n    pc = trimesh.PointCloud(pc, color)\n    # axis\n    axes = 
trimesh.creation.axis(axis_length=4)\n    # sphere\n    sphere = trimesh.creation.icosphere(radius=1)\n    trimesh.Scene([pc, axes, sphere]).show()\n\n\nclass NeRFRenderer(nn.Module):\n    def __init__(\n        self,\n        bound=1,\n        cuda_ray=False,\n        density_scale=1,  # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance.\n        min_near=0.2,\n        density_thresh=0.01,\n        bg_radius=-1,\n        grid_size=128,\n    ):\n        super().__init__()\n\n        print(\"\\n---------------\", grid_size, \"--------------\\n\")\n        self.bound = bound\n        self.cascade = 1 + math.ceil(math.log2(bound))\n        self.grid_size = grid_size\n        self.density_scale = density_scale\n        self.min_near = min_near\n        self.density_thresh = density_thresh\n        self.bg_radius = bg_radius  # radius of the background sphere.\n\n        # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)\n        # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.\n        aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound])\n        aabb_infer = aabb_train.clone()\n        self.register_buffer(\"aabb_train\", aabb_train)\n        self.register_buffer(\"aabb_infer\", aabb_infer)\n\n        # extra state for cuda raymarching\n        self.cuda_ray = cuda_ray\n        if cuda_ray:\n            # density grid\n            density_grid = torch.zeros(\n                [self.cascade, self.grid_size ** 3]\n            )  # [CAS, H * H * H]\n            density_bitfield = torch.zeros(\n                self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8\n            )  # [CAS * H * H * H // 8]\n            self.register_buffer(\"density_grid\", density_grid)\n            self.register_buffer(\"density_bitfield\", density_bitfield)\n            
self.mean_density = 0\n            self.iter_density = 0\n            # step counter\n            step_counter = torch.zeros(\n                16, 2, dtype=torch.int32\n            )  # 16 is hardcoded for averaging...\n            self.register_buffer(\"step_counter\", step_counter)\n            self.mean_count = 0\n            self.local_step = 0\n\n    def forward(self, x, d):\n        raise NotImplementedError()\n\n    # separated density and color query (can accelerate non-cuda-ray mode.)\n    def density(self, x):\n        raise NotImplementedError()\n\n    def color(self, x, d, mask=None, **kwargs):\n        raise NotImplementedError()\n\n    def reset_extra_state(self):\n        if not self.cuda_ray:\n            return\n        # density grid\n        self.density_grid.zero_()\n        self.mean_density = 0\n        self.iter_density = 0\n        # step counter\n        self.step_counter.zero_()\n        self.mean_count = 0\n        self.local_step = 0\n\n    def run(\n        self,\n        rays_o,\n        rays_d,\n        num_steps=128,\n        upsample_steps=128,\n        bg_color=None,\n        perturb=False,\n        **kwargs\n    ):\n        # rays_o, rays_d: [B, N, 3], assumes B == 1\n        # bg_color: [3] in range [0, 1]\n        # return: image: [B, N, 3], depth: [B, N]\n\n        prefix = rays_o.shape[:-1]\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        N = rays_o.shape[0]  # N = B * N, in fact\n        device = rays_o.device\n\n        # choose aabb\n        aabb = self.aabb_train if self.training else self.aabb_infer\n\n        # sample steps\n        nears, fars = raymarching.near_far_from_aabb(\n            rays_o, rays_d, aabb, self.min_near\n        )\n        nears.unsqueeze_(-1)\n        fars.unsqueeze_(-1)\n\n        # print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')\n\n        z_vals = torch.linspace(0.0, 
1.0, num_steps, device=device).unsqueeze(\n            0\n        )  # [1, T]\n        z_vals = z_vals.expand((N, num_steps))  # [N, T]\n        z_vals = nears + (fars - nears) * z_vals  # [N, T], in [nears, fars]\n\n        # perturb z_vals\n        sample_dist = (fars - nears) / num_steps\n        if perturb:\n            z_vals = (\n                z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist\n            )\n            # z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.\n\n        # generate xyzs\n        xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(\n            -1\n        )  # [N, 1, 3] * [N, T, 1] -> [N, T, 3]\n        xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:])  # a manual clip.\n\n        # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())\n\n        # query SDF and RGB\n        density_outputs = self.density(xyzs.reshape(-1, 3))\n\n        # sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]\n        for k, v in density_outputs.items():\n            density_outputs[k] = v.view(N, num_steps, -1)\n\n        # upsample z_vals (nerf-like)\n        if upsample_steps > 0:\n            with torch.no_grad():\n\n                deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T-1]\n                deltas = torch.cat(\n                    [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1\n                )\n\n                alphas = 1 - torch.exp(\n                    -deltas * self.density_scale * density_outputs[\"sigma\"].squeeze(-1)\n                )  # [N, T]\n                alphas_shifted = torch.cat(\n                    [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1\n                )  # [N, T+1]\n                weights = (\n                    alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]\n                )  # [N, T]\n\n                # sample new z_vals\n                z_vals_mid = z_vals[..., :-1] + 0.5 * 
deltas[..., :-1]  # [N, T-1]\n                new_z_vals = sample_pdf(\n                    z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training\n                ).detach()  # [N, t]\n\n                new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(\n                    -2\n                ) * new_z_vals.unsqueeze(\n                    -1\n                )  # [N, 1, 3] * [N, t, 1] -> [N, t, 3]\n                new_xyzs = torch.min(\n                    torch.max(new_xyzs, aabb[:3]), aabb[3:]\n                )  # a manual clip.\n\n            # only forward new points to save computation\n            new_density_outputs = self.density(new_xyzs.reshape(-1, 3))\n            # new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]\n            for k, v in new_density_outputs.items():\n                new_density_outputs[k] = v.view(N, upsample_steps, -1)\n\n            # re-order\n            z_vals = torch.cat([z_vals, new_z_vals], dim=1)  # [N, T+t]\n            z_vals, z_index = torch.sort(z_vals, dim=1)\n\n            xyzs = torch.cat([xyzs, new_xyzs], dim=1)  # [N, T+t, 3]\n            xyzs = torch.gather(\n                xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs)\n            )\n\n            for k in density_outputs:\n                tmp_output = torch.cat(\n                    [density_outputs[k], new_density_outputs[k]], dim=1\n                )\n                density_outputs[k] = torch.gather(\n                    tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output)\n                )\n\n        deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T+t-1]\n        deltas = torch.cat(\n            [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1\n        )\n        alphas = 1 - torch.exp(\n            -deltas * self.density_scale * density_outputs[\"sigma\"].squeeze(-1)\n        )  # [N, T+t]\n        alphas_shifted = torch.cat(\n            [torch.ones_like(alphas[..., 
:1]), 1 - alphas + 1e-15], dim=-1\n        )  # [N, T+t+1]\n        weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T+t]\n\n        dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)\n        for k, v in density_outputs.items():\n            density_outputs[k] = v.view(-1, v.shape[-1])\n\n        mask = weights > 1e-4  # hard coded\n        rgbs = self.color(\n            xyzs.reshape(-1, 3),\n            dirs.reshape(-1, 3),\n            mask=mask.reshape(-1),\n            **density_outputs\n        )\n        rgbs = rgbs.view(N, -1, 3)  # [N, T+t, 3]\n\n        # print(xyzs.shape, 'valid_rgb:', mask.sum().item())\n\n        # calculate weight_sum (mask)\n        weights_sum = weights.sum(dim=-1)  # [N]\n\n        # calculate depth\n        ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)\n        depth = torch.sum(weights * ori_z_vals, dim=-1)\n\n        # calculate color\n        image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2)  # [N, 3], in [0, 1]\n\n        # mix background color\n        if self.bg_radius > 0:\n            # use the bg model to calculate bg_color\n            polar = raymarching.polar_from_ray(\n                rays_o, rays_d, self.bg_radius\n            )  # [N, 2] in [-1, 1]\n            bg_color = self.background(polar, rays_d.reshape(-1, 3))  # [N, 3]\n        elif bg_color is None:\n            bg_color = 1\n\n        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color\n\n        image = image.view(*prefix, 3)\n        depth = depth.view(*prefix)\n\n        # tmp: reg loss in mip-nerf 360\n        # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1)\n        # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T]\n        # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals_shifted) * (weights ** 2)).sum()\n\n        return {\n            
\"depth\": depth,\n            \"image\": image,\n        }\n\n    def run_cuda(\n        self,\n        rays_o,\n        rays_d,\n        dt_gamma=0,\n        bg_color=None,\n        perturb=False,\n        force_all_rays=False,\n        max_steps=1024,\n        inherited_params=[],\n        **kwargs\n    ):\n        # rays_o, rays_d: [B, N, 3], assumes B == 1\n        # return: image: [B, N, 3], depth: [B, N]\n\n        prefix = rays_o.shape[:-1]\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        N = rays_o.shape[0]  # N = B * N, in fact\n        device = rays_o.device\n\n        # pre-calculate near far\n        nears, fars = raymarching.near_far_from_aabb(\n            rays_o,\n            rays_d,\n            self.aabb_train if self.training else self.aabb_infer,\n            self.min_near,\n        )\n\n        # mix background color\n        if self.bg_radius > 0:\n            # use the bg model to calculate bg_color\n            polar = raymarching.polar_from_ray(\n                rays_o, rays_d, self.bg_radius\n            )  # [N, 2] in [-1, 1]\n            bg_color = self.background(polar, rays_d)  # [N, 3]\n        elif bg_color is None:\n            bg_color = 1\n\n        if self.training:  # XXX guarantee the inference\n            # setup counter\n            time1 = time()\n            counter = self.step_counter[self.local_step % 16]\n            counter.zero_()  # set to 0\n            self.local_step += 1\n            if (\n                self.args.render_stu_first\n            ):  # if the student renders first, the student computes xyzs and the teacher inherits them\n                if not self.is_teacher:\n                    xyzs, dirs, deltas, rays = raymarching.march_rays_train(\n                        rays_o,\n                        rays_d,\n                        self.bound,\n                        self.density_bitfield,\n                        self.cascade,\n                      
  self.grid_size,\n                        nears,\n                        fars,\n                        counter,\n                        self.mean_count,\n                        perturb,\n                        128,\n                        force_all_rays,\n                        dt_gamma,\n                        max_steps,\n                    )\n                    inherited_params = [xyzs, dirs, deltas, rays]\n                else:\n                    xyzs, dirs, deltas, rays = inherited_params\n            else:\n                if self.is_teacher:\n                    xyzs, dirs, deltas, rays = raymarching.march_rays_train(\n                        rays_o,\n                        rays_d,\n                        self.bound,\n                        self.density_bitfield,\n                        self.cascade,\n                        self.grid_size,\n                        nears,\n                        fars,\n                        counter,\n                        self.mean_count,\n                        perturb,\n                        128,\n                        force_all_rays,\n                        dt_gamma,\n                        max_steps,\n                    )\n                    inherited_params = [xyzs, dirs, deltas, rays]\n                else:\n                    xyzs, dirs, deltas, rays = inherited_params\n\n            # print('\\n', self.model_type, self.mean_count, self.mean_density)\n            # from IPython import embed; embed()\n            # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())\n            sigmas, rgbs = self(xyzs, dirs)\n            # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.\n            # sigmas = density_outputs['sigma']\n            # rgbs = self.color(xyzs, dirs, **density_outputs)\n            sigmas = self.density_scale * sigmas\n\n            # print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} 
(total = {mask.sum().item()})')\n            time2 = time()\n\n            # special case for CCNeRF's residual learning\n            if len(sigmas.shape) == 2:\n                K = sigmas.shape[0]\n                depths = []\n                images = []\n                for k in range(K):\n                    weights_sum, depth, image = raymarching.composite_rays_train(\n                        sigmas[k], rgbs[k], deltas, rays\n                    )\n                    image = image + (1 - weights_sum).unsqueeze(-1) * bg_color\n                    depth = torch.clamp(depth - nears, min=0) / (fars - nears)\n                    images.append(image.view(*prefix, 3))\n                    depths.append(depth.view(*prefix))\n\n                depth = torch.stack(depths, axis=0)  # [K, B, N]\n                image = torch.stack(images, axis=0)  # [K, B, N, 3]\n\n            else:\n\n                weights_sum, depth, image = raymarching.composite_rays_train(\n                    sigmas, rgbs, deltas, rays\n                )\n                image = image + (1 - weights_sum).unsqueeze(-1) * bg_color\n                depth = torch.clamp(depth - nears, min=0) / (fars - nears)\n                image = image.view(*prefix, 3)\n                depth = depth.view(*prefix)\n\n        else:\n\n            # allocate outputs\n            # if use autocast, must init as half so it won't be autocasted and lose reference.\n            # dtype = torch.half if torch.is_autocast_enabled() else torch.float32\n            # output should always be float32! 
only network inference uses half.\n            dtype = torch.float32\n\n            weights_sum = torch.zeros(N, dtype=dtype, device=device)\n            depth = torch.zeros(N, dtype=dtype, device=device)\n            image = torch.zeros(N, 3, dtype=dtype, device=device)\n\n            n_alive = N\n            alive_counter = torch.zeros([1], dtype=torch.int32, device=device)\n\n            rays_alive = torch.zeros(\n                2, n_alive, dtype=torch.int32, device=device\n            )  # 2 is used to loop old/new\n            rays_t = torch.zeros(2, n_alive, dtype=dtype, device=device)\n\n            step = 0\n            i = 0\n            while step < max_steps:\n\n                # count alive rays\n                if step == 0:\n                    # init rays at first step.\n                    torch.arange(n_alive, out=rays_alive[0])\n                    rays_t[0] = nears\n                else:\n                    alive_counter.zero_()\n                    raymarching.compact_rays(\n                        n_alive,\n                        rays_alive[i % 2],\n                        rays_alive[(i + 1) % 2],\n                        rays_t[i % 2],\n                        rays_t[(i + 1) % 2],\n                        alive_counter,\n                    )\n                    n_alive = alive_counter.item()  # must invoke D2H copy here\n\n                # exit loop\n                if n_alive <= 0:\n                    break\n\n                # decide compact_steps\n                n_step = max(min(N // n_alive, 8), 1)\n\n                xyzs, dirs, deltas = raymarching.march_rays(\n                    n_alive,\n                    n_step,\n                    rays_alive[i % 2],\n                    rays_t[i % 2],\n                    rays_o,\n                    rays_d,\n                    self.bound,\n                    self.density_bitfield,\n                    self.cascade,\n                    self.grid_size,\n                    nears,\n       
             fars,\n                    128,\n                    perturb,\n                    dt_gamma,\n                    max_steps,\n                )\n\n                sigmas, rgbs = self(xyzs, dirs)\n                # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.\n                # sigmas = density_outputs['sigma']\n                # rgbs = self.color(xyzs, dirs, **density_outputs)\n                sigmas = self.density_scale * sigmas\n\n                raymarching.composite_rays(\n                    n_alive,\n                    n_step,\n                    rays_alive[i % 2],\n                    rays_t[i % 2],\n                    sigmas,\n                    rgbs,\n                    deltas,\n                    weights_sum,\n                    depth,\n                    image,\n                )\n\n                # print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')\n\n                step += n_step\n                i += 1\n\n            image = image + (1 - weights_sum).unsqueeze(-1) * bg_color\n            depth = torch.clamp(depth - nears, min=0) / (fars - nears)\n            image = image.view(*prefix, 3)\n            depth = depth.view(*prefix)\n\n        # print('\\n--- render time:--- {:6f}  {:.6f}'.format(time2-time1, time()-time2))\n        return {\n            \"depth\": depth,\n            \"image\": image,\n            \"inherited_params\": inherited_params,\n        }\n\n    @torch.no_grad()\n    def mark_untrained_grid(self, poses, intrinsic, S=64):\n        # poses: [B, 4, 4]\n        # intrinsic: [3, 3]\n\n        if not self.cuda_ray:\n            return\n\n        if isinstance(poses, np.ndarray):\n            poses = torch.from_numpy(poses)\n\n        B = poses.shape[0]\n\n        fx, fy, cx, cy = intrinsic\n\n        X = torch.arange(\n            self.grid_size, dtype=torch.int32, device=self.density_grid.device\n        
).split(S)\n        Y = torch.arange(\n            self.grid_size, dtype=torch.int32, device=self.density_grid.device\n        ).split(S)\n        Z = torch.arange(\n            self.grid_size, dtype=torch.int32, device=self.density_grid.device\n        ).split(S)\n\n        count = torch.zeros_like(self.density_grid)\n        poses = poses.to(count.device)\n\n        # 5-level loop, forgive me...\n\n        for xs in X:\n            for ys in Y:\n                for zs in Z:\n\n                    # construct points\n                    xx, yy, zz = custom_meshgrid(xs, ys, zs)\n                    coords = torch.cat(\n                        [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],\n                        dim=-1,\n                    )  # [N, 3], in [0, 128)\n                    indices = raymarching.morton3D(coords).long()  # [N]\n                    world_xyzs = (\n                        2 * coords.float() / (self.grid_size - 1) - 1\n                    ).unsqueeze(\n                        0\n                    )  # [1, N, 3] in [-1, 1]\n\n                    # cascading\n                    for cas in range(self.cascade):\n                        bound = min(2 ** cas, self.bound)\n                        half_grid_size = bound / self.grid_size\n                        # scale to current cascade's resolution\n                        cas_world_xyzs = world_xyzs * (bound - half_grid_size)\n\n                        # split batch to avoid OOM\n                        head = 0\n                        while head < B:\n                            tail = min(head + S, B)\n\n                            # world2cam transform (poses is c2w, so we need to transpose it. 
Another transpose is needed for batched matmul, so the final form is without transpose.)\n                            cam_xyzs = cas_world_xyzs - poses[\n                                head:tail, :3, 3\n                            ].unsqueeze(1)\n                            cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3]  # [S, N, 3]\n\n                            # query if point is covered by any camera\n                            mask_z = cam_xyzs[:, :, 2] > 0  # [S, N]\n                            mask_x = (\n                                torch.abs(cam_xyzs[:, :, 0])\n                                < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2\n                            )\n                            mask_y = (\n                                torch.abs(cam_xyzs[:, :, 1])\n                                < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2\n                            )\n                            mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1)  # [N]\n\n                            # update count\n                            count[cas, indices] += mask\n                            head += S\n\n        # mark untrained grid as -1\n        self.density_grid[count == 0] = -1\n\n        # print(f'[mark untrained grid] {(count == 0).sum()} from {resolution ** 3 * self.cascade}')\n\n    @torch.no_grad()\n    def update_extra_state(self, decay=0.95, S=128):\n        # call before each epoch to update extra states.\n\n        if not self.cuda_ray:\n            return\n\n        # update density grid\n        tmp_grid = -torch.ones_like(self.density_grid)\n\n        # full update.\n        if self.iter_density < 16:\n            # if True:\n            X = torch.arange(\n                self.grid_size, dtype=torch.int32, device=self.density_grid.device\n            ).split(S)\n            Y = torch.arange(\n                self.grid_size, dtype=torch.int32, device=self.density_grid.device\n            ).split(S)\n            Z = 
torch.arange(\n                self.grid_size, dtype=torch.int32, device=self.density_grid.device\n            ).split(S)\n\n            for xs in X:\n                for ys in Y:\n                    for zs in Z:\n                        # construct points\n                        xx, yy, zz = custom_meshgrid(xs, ys, zs)\n                        coords = torch.cat(\n                            [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],\n                            dim=-1,\n                        )  # [N, 3], in [0, 128)\n                        indices = raymarching.morton3D(coords).long()  # [N]\n                        xyzs = (\n                            2 * coords.float() / (self.grid_size - 1) - 1\n                        )  # [N, 3] in [-1, 1]\n\n                        # cascading\n                        for cas in range(self.cascade):\n                            bound = min(2 ** cas, self.bound)\n                            half_grid_size = bound / self.grid_size\n                            # scale to current cascade's resolution\n                            cas_xyzs = xyzs * (bound - half_grid_size)\n                            # add noise in [-hgs, hgs]\n                            cas_xyzs += (\n                                torch.rand_like(cas_xyzs) * 2 - 1\n                            ) * half_grid_size\n                            # query density\n                            sigmas = (\n                                self.density(cas_xyzs)[\"sigma\"].reshape(-1).detach()\n                            )\n                            sigmas *= self.density_scale\n                            # assign\n                            tmp_grid[cas, indices] = sigmas\n\n        # partial update (half the computation)\n        # TODO: why no need of maxpool ?\n        else:\n            N = self.grid_size ** 3 // 4  # H * H * H / 4\n            for cas in range(self.cascade):\n                # random sample some positions\n                
coords = torch.randint(\n                    0, self.grid_size, (N, 3), device=self.density_grid.device\n                )  # [N, 3], in [0, 128)\n                indices = raymarching.morton3D(coords).long()  # [N]\n                # random sample occupied positions\n                occ_indices = torch.nonzero(self.density_grid[cas] > 0).squeeze(\n                    -1\n                )  # [Nz]\n                rand_mask = torch.randint(\n                    0,\n                    occ_indices.shape[0],\n                    [N],\n                    dtype=torch.long,\n                    device=self.density_grid.device,\n                )\n                occ_indices = occ_indices[\n                    rand_mask\n                ]  # [Nz] --> [N], allow for duplication\n                occ_coords = raymarching.morton3D_invert(occ_indices)  # [N, 3]\n                # concat\n                indices = torch.cat([indices, occ_indices], dim=0)\n                coords = torch.cat([coords, occ_coords], dim=0)\n                # same below\n                xyzs = (\n                    2 * coords.float() / (self.grid_size - 1) - 1\n                )  # [N, 3] in [-1, 1]\n                bound = min(2 ** cas, self.bound)\n                half_grid_size = bound / self.grid_size\n                # scale to current cascade's resolution\n                cas_xyzs = xyzs * (bound - half_grid_size)\n                # add noise in [-hgs, hgs]\n                cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size\n                # query density\n                sigmas = self.density(cas_xyzs)[\"sigma\"].reshape(-1).detach()\n                sigmas *= self.density_scale\n                # assign\n                tmp_grid[cas, indices] = sigmas\n\n        ## max-pool on tmp_grid for less aggressive culling [No significant improvement...]\n        # invalid_mask = tmp_grid < 0\n        # tmp_grid = F.max_pool3d(tmp_grid.view(self.cascade, 1, self.grid_size, 
self.grid_size, self.grid_size), kernel_size=3, stride=1, padding=1).view(self.cascade, -1)\n        # tmp_grid[invalid_mask] = -1\n\n        # ema update\n        valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)\n        self.density_grid[valid_mask] = torch.maximum(\n            self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]\n        )\n        self.mean_density = torch.mean(\n            self.density_grid.clamp(min=0)\n        ).item()  # -1 non-training regions are viewed as 0 density.\n        self.iter_density += 1\n\n        # convert to bitfield\n        density_thresh = min(self.mean_density, self.density_thresh)\n        self.density_bitfield = raymarching.packbits(\n            self.density_grid, density_thresh, self.density_bitfield\n        )\n\n        ### update step counter\n        total_step = min(16, self.local_step)\n        if total_step > 0:\n            self.mean_count = int(\n                self.step_counter[:total_step, 0].sum().item() / total_step\n            )\n        self.local_step = 0\n\n        # print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')\n\n    def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):\n        # rays_o, rays_d: [B, N, 3], assumes B == 1\n        # return: pred_rgb: [B, N, 3]\n\n        if self.cuda_ray:\n            _run = self.run_cuda\n        else:\n            _run = self.run\n\n        B, N = rays_o.shape[:2]\n        device = rays_o.device\n\n        # never stage when cuda_ray\n        if staged and not self.cuda_ray:\n            depth = torch.empty((B, N), device=device)\n            image = torch.empty((B, N, 3), device=device)\n\n            for b in range(B):\n                head = 0\n                while head < N:\n                    tail = min(head + 
max_ray_batch, N)\n                    results_ = _run(\n                        rays_o[b : b + 1, head:tail],\n                        rays_d[b : b + 1, head:tail],\n                        **kwargs\n                    )\n                    depth[b : b + 1, head:tail] = results_[\"depth\"]\n                    image[b : b + 1, head:tail] = results_[\"image\"]\n                    head += max_ray_batch\n            results = {}\n            results[\"depth\"] = depth\n            results[\"image\"] = image\n\n        else:\n            results = _run(rays_o, rays_d, **kwargs)\n\n        return results\n"
  },
  {
    "path": "just_train_tea/utils.py",
    "content": "import os\nimport lpips\nimport glob\nimport tqdm\nimport math\nimport random\nimport warnings\nimport tensorboardX\n\nimport numpy as np\nimport pandas as pd\n\nimport time\nfrom datetime import datetime\n\nimport cv2\nimport matplotlib.pyplot as plt\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nimport torch.distributed as dist\nfrom torch.utils.data import Dataset, DataLoader\n\nimport trimesh\nimport mcubes\nfrom rich.console import Console\nfrom torch_ema import ExponentialMovingAverage\n\nfrom packaging import version as pver\n\ndevice = torch.device(\"cuda\")\n\n\ndef custom_meshgrid(*args):\n    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid\n    if pver.parse(torch.__version__) < pver.parse(\"1.10\"):\n        return torch.meshgrid(*args)\n    else:\n        return torch.meshgrid(*args, indexing=\"ij\")\n\n\n@torch.jit.script\ndef linear_to_srgb(x):\n    return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)\n\n\n@torch.jit.script\ndef srgb_to_linear(x):\n    return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)\n\n\ndef compute_ssim(\n    img0,\n    img1,\n    max_val,\n    filter_size=11,\n    filter_sigma=1.5,\n    k1=0.01,\n    k2=0.03,\n    return_map=False,\n):\n    \"\"\"Computes SSIM from two images.\n    This function was modeled after tf.image.ssim, and should produce comparable\n    output.\n    Args:\n      img0: torch.tensor. An image of size [..., width, height, num_channels].\n      img1: torch.tensor. An image of size [..., width, height, num_channels].\n      max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.\n      filter_size: int >= 1. Window size.\n      filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.\n      k1: float > 0. One of the SSIM dampening parameters.\n      k2: float > 0. 
One of the SSIM dampening parameters.\n      return_map: Bool. If True, will cause the per-pixel SSIM \"map\" to be returned\n    Returns:\n      Each image's mean SSIM, or a tensor of individual values if `return_map`.\n    \"\"\"\n    device = img0.device\n    img0 = img0.type(torch.float32)\n    img1 = img1.type(torch.float32)\n    ori_shape = img0.size()\n    width, height, num_channels = ori_shape[-3:]\n    img0 = img0.view(-1, width, height, num_channels).permute(0, 3, 1, 2)\n    img1 = img1.view(-1, width, height, num_channels).permute(0, 3, 1, 2)\n    batch_size = img0.shape[0]\n\n    # Construct a 1D Gaussian blur filter.\n    hw = filter_size // 2\n    shift = (2 * hw - filter_size + 1) / 2\n    f_i = ((torch.arange(filter_size, device=device) - hw + shift) / filter_sigma) ** 2\n    filt = torch.exp(-0.5 * f_i)\n    filt /= torch.sum(filt)\n\n    # Blur in x and y (faster than the 2D convolution).\n    # z is a tensor of size [B, C, H, W]\n    filt_fn1 = lambda z: F.conv2d(\n        z,\n        filt.view(1, 1, -1, 1).repeat(num_channels, 1, 1, 1),\n        padding=[hw, 0],\n        groups=num_channels,\n    )\n    filt_fn2 = lambda z: F.conv2d(\n        z,\n        filt.view(1, 1, 1, -1).repeat(num_channels, 1, 1, 1),\n        padding=[0, hw],\n        groups=num_channels,\n    )\n\n    # Vmap the blurs to the tensor size, and then compose them.\n    filt_fn = lambda z: filt_fn1(filt_fn2(z))\n    mu0 = filt_fn(img0)\n    mu1 = filt_fn(img1)\n    mu00 = mu0 * mu0\n    mu11 = mu1 * mu1\n    mu01 = mu0 * mu1\n    sigma00 = filt_fn(img0 ** 2) - mu00\n    sigma11 = filt_fn(img1 ** 2) - mu11\n    sigma01 = filt_fn(img0 * img1) - mu01\n\n    # Clip the variances and covariances to valid values.\n    # Variance must be non-negative:\n    sigma00 = torch.clamp(sigma00, min=0.0)\n    sigma11 = torch.clamp(sigma11, min=0.0)\n    sigma01 = torch.sign(sigma01) * torch.min(\n        torch.sqrt(sigma00 * sigma11), torch.abs(sigma01)\n    )\n\n    c1 = (k1 * max_val) ** 2\n 
   c2 = (k2 * max_val) ** 2\n    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)\n    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)\n    ssim_map = numer / denom\n    ssim = torch.mean(ssim_map.reshape([-1, num_channels * width * height]), dim=-1)\n    return ssim_map if return_map else ssim\n\n\ndef init_lpips(net_name, device):\n    assert net_name in [\"alex\", \"vgg\"]\n    print(f\"init_lpips: lpips_{net_name}\")\n    return lpips.LPIPS(net=net_name, version=\"0.1\").eval().to(device)\n\n\nlpips_fns = {\n    \"alex\": lpips.LPIPS(net=\"alex\", version=\"0.1\").eval().cuda(),\n    \"vgg\": lpips.LPIPS(net=\"vgg\", version=\"0.1\").eval().cuda(),\n}\n\n\ndef rgb_lpips(gt, im, net_name):\n    assert net_name in [\"alex\", \"vgg\"]\n    gt = gt.type(torch.float32).permute([0, 3, 1, 2]).contiguous().cuda()\n    im = im.type(torch.float32).permute([0, 3, 1, 2]).contiguous().cuda()\n    return lpips_fns[net_name](gt, im, normalize=True).item()\n\n\n@torch.cuda.amp.autocast(enabled=False)\ndef get_rays(poses, intrinsics, H, W, N=-1, error_map=None):\n    \"\"\"get rays\n    Args:\n        poses: [B, 4, 4], cam2world\n        intrinsics: [4]\n        H, W, N: int\n        error_map: [B, 128 * 128], sample probability based on training error\n    Returns:\n        rays_o, rays_d: [B, N, 3]\n        inds: [B, N]\n    \"\"\"\n\n    device = poses.device\n    B = poses.shape[0]\n    fx, fy, cx, cy = intrinsics\n\n    i, j = custom_meshgrid(\n        torch.linspace(0, W - 1, W, device=device),\n        torch.linspace(0, H - 1, H, device=device),\n    )\n    i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5\n    j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5\n\n    results = {}\n\n    if N > 0:\n        N = min(N, H * W)\n\n        if error_map is None:\n            inds = torch.randint(0, H * W, size=[N], device=device)  # may duplicate\n            inds = inds.expand([B, N])\n        else:\n\n            # weighted sample on a low-reso 
grid\n            inds_coarse = torch.multinomial(\n                error_map.to(device), N, replacement=False\n            )  # [B, N], but in [0, 128*128)\n\n            # map to the original resolution with random perturb.\n            inds_x, inds_y = (\n                inds_coarse // 128,\n                inds_coarse % 128,\n            )  # `//` will throw a warning in torch 1.10... anyway.\n            sx, sy = H / 128, W / 128\n            inds_x = (\n                (inds_x * sx + torch.rand(B, N, device=device) * sx)\n                .long()\n                .clamp(max=H - 1)\n            )\n            inds_y = (\n                (inds_y * sy + torch.rand(B, N, device=device) * sy)\n                .long()\n                .clamp(max=W - 1)\n            )\n            inds = inds_x * W + inds_y\n\n            results[\"inds_coarse\"] = inds_coarse  # need this when updating error_map\n\n        i = torch.gather(i, -1, inds)\n        j = torch.gather(j, -1, inds)\n\n        results[\"inds\"] = inds\n\n    else:\n        inds = torch.arange(H * W, device=device).expand([B, H * W])\n\n    zs = torch.ones_like(i)\n    xs = (i - cx) / fx * zs\n    ys = (j - cy) / fy * zs\n    directions = torch.stack((xs, ys, zs), dim=-1)\n    directions = directions / torch.norm(directions, dim=-1, keepdim=True)\n    rays_d = directions @ poses[:, :3, :3].transpose(-1, -2)  # (B, N, 3)\n\n    rays_o = poses[..., :3, 3]  # [B, 3]\n    rays_o = rays_o[..., None, :].expand_as(rays_d)  # [B, N, 3]\n\n    results[\"rays_o\"] = rays_o\n    results[\"rays_d\"] = rays_d\n\n    return results\n\n\ndef seed_everything(seed):\n    random.seed(seed)\n    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    # torch.backends.cudnn.deterministic = True\n    # torch.backends.cudnn.benchmark = True\n\n\ndef torch_vis_2d(x, renormalize=False):\n    # x: [3, H, W] or [1, H, W] or [H, W]\n    import 
matplotlib.pyplot as plt\n    import numpy as np\n    import torch\n\n    if isinstance(x, torch.Tensor):\n        if len(x.shape) == 3:\n            x = x.permute(1, 2, 0).squeeze()\n        x = x.detach().cpu().numpy()\n\n    print(f\"[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}\")\n\n    x = x.astype(np.float32)\n\n    # renormalize\n    if renormalize:\n        x = (x - x.min(axis=0, keepdims=True)) / (\n            x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8\n        )\n\n    plt.imshow(x)\n    plt.show()\n\n\ndef extract_fields(bound_min, bound_max, resolution, query_func, S=128):\n\n    X = torch.linspace(bound_min[0], bound_max[0], resolution).split(S)\n    Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(S)\n    Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(S)\n\n    u = np.zeros([resolution, resolution, resolution], dtype=np.float32)\n    with torch.no_grad():\n        for xi, xs in enumerate(X):\n            for yi, ys in enumerate(Y):\n                for zi, zs in enumerate(Z):\n                    xx, yy, zz = custom_meshgrid(xs, ys, zs)\n                    pts = torch.cat(\n                        [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],\n                        dim=-1,\n                    )  # [S, 3]\n                    val = (\n                        query_func(pts)\n                        .reshape(len(xs), len(ys), len(zs))\n                        .detach()\n                        .cpu()\n                        .numpy()\n                    )  # [S, 1] --> [x, y, z]\n                    u[\n                        xi * S : xi * S + len(xs),\n                        yi * S : yi * S + len(ys),\n                        zi * S : zi * S + len(zs),\n                    ] = val\n    return u\n\n\ndef extract_geometry(bound_min, bound_max, resolution, threshold, query_func):\n    # print('threshold: {}'.format(threshold))\n    u = extract_fields(bound_min, 
bound_max, resolution, query_func)\n\n    # print(u.shape, u.max(), u.min(), np.percentile(u, 50))\n\n    vertices, triangles = mcubes.marching_cubes(u, threshold)\n\n    b_max_np = bound_max.detach().cpu().numpy()\n    b_min_np = bound_min.detach().cpu().numpy()\n\n    vertices = (\n        vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :]\n        + b_min_np[None, :]\n    )\n    return vertices, triangles\n\n\nclass PSNRMeter:\n    def __init__(self):\n        self.V = 0\n        self.N = 0\n\n    def clear(self):\n        self.V = 0\n        self.N = 0\n\n    def prepare_inputs(self, *inputs):\n        outputs = []\n        for i, inp in enumerate(inputs):\n            if torch.is_tensor(inp):\n                inp = inp.detach().cpu().numpy()\n            outputs.append(inp)\n\n        return outputs\n\n    def update(self, preds, truths):\n        preds, truths = self.prepare_inputs(\n            preds, truths\n        )  # [B, N, 3] or [B, H, W, 3], range[0, 1]\n\n        # simplified since max_pixel_value is 1 here.\n        psnr = -10 * np.log10(np.mean((preds - truths) ** 2))\n\n        self.V += psnr\n        self.N += 1\n\n    def measure(self):\n        return self.V / self.N\n\n    def write(self, writer, global_step, prefix=\"\"):\n        writer.add_scalar(os.path.join(prefix, \"PSNR\"), self.measure(), global_step)\n\n    def report(self):\n        return f\"PSNR = {self.measure():.6f}\"\n\n\nclass Trainer(object):\n    def __init__(\n        self,\n        name,  # name of this experiment\n        opt,  # extra conf\n        model_tea,  # network\n        model_stu,\n        criterion=None,  # loss function, if None, assume inline implementation in train_step\n        optimizer=None,  # optimizer\n        ema_decay=None,  # if use EMA, set the decay\n        lr_scheduler=None,  # scheduler\n        metrics=[],  # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.\n        local_rank=0,  # 
which GPU am I\n        world_size=1,  # total num of GPUs\n        device=None,  # device to use, usually setting to None is OK. (auto choose device)\n        mute=False,  # whether to mute all print\n        fp16=False,  # amp optimize level\n        eval_interval=10e10,  # eval once every $ epoch\n        max_keep_ckpt=2,  # max num of saved ckpts in disk\n        workspace=\"workspace\",  # workspace to save logs & ckpts\n        best_mode=\"min\",  # the smaller/larger result, the better\n        use_loss_as_metric=True,  # use loss as the first metric\n        report_metric_at_train=False,  # also report metrics at training\n        use_checkpoint=\"latest\",  # which ckpt to use at init time\n        use_tensorboardX=True,  # whether to use tensorboard for logging\n        scheduler_update_every_step=False,  # whether to call scheduler.step() after every train step\n    ):\n\n        self.optimizer_fn = optimizer\n        self.lr_scheduler_fn = lr_scheduler\n        self.name = name\n        self.opt = opt\n        self.mute = mute\n        self.metrics = metrics\n        self.local_rank = local_rank\n        self.world_size = world_size\n        self.workspace = workspace\n        self.ema_decay = ema_decay\n        self.fp16 = fp16\n        self.best_mode = best_mode\n        self.use_loss_as_metric = use_loss_as_metric\n        self.report_metric_at_train = report_metric_at_train\n        self.max_keep_ckpt = max_keep_ckpt\n        self.eval_interval = eval_interval\n        self.use_checkpoint = use_checkpoint\n        self.use_tensorboardX = use_tensorboardX\n        self.time_stamp = time.strftime(\"%Y-%m-%d_%H-%M-%S\")\n        self.scheduler_update_every_step = scheduler_update_every_step\n        self.device = (\n            device\n            if device is not None\n            else torch.device(\n                f\"cuda:{local_rank}\" if torch.cuda.is_available() else \"cpu\"\n            )\n        )\n        self.console = Console()\n\n        
self.model_tea = model_tea.to(self.device)\n        self.model_stu = model_stu.to(self.device)\n\n        if isinstance(criterion, nn.Module):\n            criterion.to(self.device)\n        self.criterion = criterion\n\n        if optimizer is None:\n            self.optimizer = optim.AdamW(\n                self.model_stu.parameters(), lr=0.001, weight_decay=5e-4\n            )  # naive adam\n        else:\n            self.optimizer = optimizer(self.model_stu)\n\n        if lr_scheduler is None:\n            self.lr_scheduler = optim.lr_scheduler.LambdaLR(\n                self.optimizer, lr_lambda=lambda epoch: 1\n            )  # fake scheduler\n        else:\n            self.ls = lr_scheduler\n            self.lr_scheduler = lr_scheduler(self.optimizer)\n        if ema_decay is not None and ema_decay > 0:  # ema_decay=0.95\n            self.ema = ExponentialMovingAverage(\n                self.model_stu.parameters(), decay=ema_decay\n            )\n        else:\n            self.ema = None\n\n        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)\n\n        # variable init\n        self.epoch = 1\n        self.global_step = 0\n        self.local_step = 0\n        self.stats = {\n            \"loss\": [],\n            \"valid_loss\": [],\n            \"results\": [],  # metrics[0], or valid_loss\n            \"checkpoints\": [],  # record path of saved ckpt, to automatically remove old ckpt\n            \"best_result\": None,\n        }\n\n        # auto fix\n        if len(metrics) == 0 or self.use_loss_as_metric:\n            self.best_mode = \"min\"\n\n        # workspace prepare\n        self.log_ptr = None\n        if self.workspace is not None:\n            os.makedirs(self.workspace, exist_ok=True)\n            self.log_path = os.path.join(workspace, f\"log_{self.name}.txt\")\n            self.log_ptr = open(self.log_path, \"a+\")\n\n            self.ckpt_path = os.path.join(self.workspace, \"checkpoints\")\n            self.best_path = 
f\"{self.ckpt_path}/{self.name}.pth\"\n            os.makedirs(self.ckpt_path, exist_ok=True)\n        self.log(self.opt)\n\n        self.log(\n            f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {\"fp16\" if self.fp16 else \"fp32\"} | {self.workspace}'\n        )\n        self.log(\n            f\"[INFO] #parameters: {sum([p.numel() for p in model_stu.parameters() if p.requires_grad])}\"\n        )\n\n    def __del__(self):\n        if self.log_ptr:\n            self.log_ptr.close()\n\n    def log(self, *args, **kwargs):\n        if self.local_rank == 0:\n            if not self.mute:\n                # print(*args)\n                self.console.print(*args, **kwargs)\n            if self.log_ptr:\n                print(*args, file=self.log_ptr)\n                self.log_ptr.flush()  # write immediately to file\n\n    def train(self, train_loader, valid_loader, max_epochs):\n        if self.use_tensorboardX and self.local_rank == 0:\n            self.writer = tensorboardX.SummaryWriter(\n                os.path.join(self.workspace, \"run\", self.name)\n            )\n\n        # mark untrained region (i.e., not covered by any camera from the training dataset)\n        if self.model_tea.cuda_ray:\n            self.model_tea.mark_untrained_grid(\n                train_loader._data.poses, train_loader._data.intrinsics\n            )\n            self.model_stu.mark_untrained_grid(\n                train_loader._data.poses, train_loader._data.intrinsics\n            )\n\n        for p in self.model_tea.parameters():\n            p.requires_grad = False\n        self.model_tea.eval()\n\n        # get a ref to error_map\n        self.error_map = train_loader._data.error_map\n\n        for epoch in range(self.epoch, max_epochs + 1):\n            self.epoch = epoch\n            self.train_one_epoch(train_loader)\n            print(\"\\n\", self.workspace, \"\\n\")\n\n            if (\n                self.workspace is not None\n                
and self.local_rank == 0\n                and self.epoch > max_epochs - 2\n            ):\n                self.save_checkpoint(\n                    full=False, best=False\n                )  # FIXME save should include teacher and student\n\n            if self.epoch % self.eval_interval == 0:\n                self.evaluate_one_epoch(valid_loader)\n                self.save_checkpoint(full=False, best=True)\n\n        if self.use_tensorboardX and self.local_rank == 0:\n            self.writer.close()\n\n    def train_one_epoch(self, loader):\n        self.log(\n            f\"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...\"\n        )\n\n        total_loss = 0\n        total_loss_rgb = 0\n        total_loss_fea = 0\n        if self.local_rank == 0 and self.report_metric_at_train:\n            for metric in self.metrics:\n                metric.clear()\n\n        self.model_stu.train()\n        self.model_tea.train()\n\n        # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs\n        # ref: https://pytorch.org/docs/stable/data.html\n        if self.world_size > 1:\n            loader.sampler.set_epoch(self.epoch)\n\n        if self.local_rank == 0:\n            pbar = tqdm.tqdm(\n                total=len(loader) * loader.batch_size,\n                bar_format=\"{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\",\n            )\n\n        self.local_step = 0\n\n        for data in loader:\n            # update grid every 16 steps\n            if (\n                self.model_tea.cuda_ray\n                and self.global_step % self.opt.update_extra_interval == 0\n            ):\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    if self.opt.update_stu_extra:\n                        self.model_stu.update_extra_state()\n                    else:\n                        pass\n\n            self.local_step += 1\n       
     self.global_step += 1\n\n            self.optimizer.zero_grad()\n\n            with torch.cuda.amp.autocast(enabled=self.fp16):\n                # XXX self.train_step\n                if self.opt.just_train_a_model:\n                    loss, preds, truths = self.train_step(data)\n                else:\n                    (\n                        preds,\n                        truths,\n                        loss,\n                        loss_rgb,\n                        loss_fea,\n                        loss_fea_sc,\n                        loss_color,\n                        loss_sigma,\n                    ) = self.train_step(data)\n\n            self.scaler.scale(loss).backward()\n            self.scaler.step(self.optimizer)\n            self.scaler.update()\n\n            if self.scheduler_update_every_step:\n                self.lr_scheduler.step()\n\n            loss_val = loss.item()\n            total_loss += loss_val\n\n            if self.opt.just_train_a_model:\n                # only rank 0 owns pbar and the tensorboard writer (mirrors the distillation branch below)\n                if self.local_rank == 0:\n                    if self.report_metric_at_train:\n                        for metric in self.metrics:\n                            metric.update(preds, truths)\n\n                    if self.use_tensorboardX:\n                        self.writer.add_scalar(\"train/loss\", loss_val, self.global_step)\n                        self.writer.add_scalar(\n                            \"train/lr\",\n                            self.optimizer.param_groups[0][\"lr\"],\n                            self.global_step,\n                        )\n\n                    if self.scheduler_update_every_step:\n                        pbar.set_description(\n                            f\"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}\"\n                        )\n                    else:\n                        pbar.set_description(\n                            f\"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})\"\n                        )\n                    pbar.update(loader.batch_size)\n\n            else:\n                total_loss_rgb += loss_rgb\n                total_loss_fea += loss_fea\n\n                if self.local_rank == 0:\n                    if self.report_metric_at_train:\n                        for metric in self.metrics:\n                            metric.update(preds, truths)\n\n                    if self.use_tensorboardX:\n                        self.writer.add_scalar(\"train/loss\", loss_val, self.global_step)\n                        self.writer.add_scalar(\n                            \"train/loss_rgb\", loss_rgb, self.global_step\n                        )\n                        self.writer.add_scalar(\n                            \"train/loss_fea\", loss_fea, self.global_step\n                        )\n                        self.writer.add_scalar(\n                            \"train/loss_fea_sc\", loss_fea_sc, self.global_step\n                        )\n                        self.writer.add_scalar(\n                            \"train/loss_color\", loss_color, self.global_step\n                        )\n                        self.writer.add_scalar(\n                            \"train/loss_sigma\", loss_sigma, self.global_step\n                        )\n                        self.writer.add_scalar(\n                            \"train/lr\",\n                            self.optimizer.param_groups[0][\"lr\"],\n                            self.global_step,\n                        )\n\n                    if self.scheduler_update_every_step:  # run this\n                        # pbar.set_description(f\"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}\")\n                        cur_lr = self.optimizer.param_groups[0][\"lr\"]\n                        pbar.set_description(\n                            f\"loss={total_loss/self.local_step:.5f}, loss_rgb={total_loss_rgb/self.local_step:.5f}, loss_fea={total_loss_fea/self.local_step:.5f} lr={cur_lr:.5f}\"\n
       )\n                    else:\n                        pbar.set_description(\n                            f\"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})\"\n                        )\n                    pbar.update(loader.batch_size)\n\n            # only for vm FIXME upsample_resolutions should be set first in main\n            if (\n                self.opt.model_type == \"vm\"\n                and self.global_step in self.opt.upsample_model_steps\n            ):\n                # shrink\n                if (\n                    self.model_stu.cuda_ray\n                ):  # and self.global_step == self.opt.upsample_model_steps[0]:\n                    self.model_stu.shrink_model()\n\n                # adaptive voxel size from aabb_train\n                n_vox = self.upsample_resolutions.pop(0) ** 3  # n_voxels\n                aabb = self.model_stu.aabb_train.cpu().numpy()\n                vox_size = np.cbrt(np.prod(aabb[3:] - aabb[:3]) / n_vox)\n                reso = ((aabb[3:] - aabb[:3]) / vox_size).astype(np.int32).tolist()\n                self.log(\n                    f\"[INFO] upsample model at step {self.global_step} from {self.model_stu.resolution} to {reso}\"\n                )\n                self.model_stu.upsample_model(reso)\n\n                # reset optimizer since params changed.\n                self.optimizer = self.optimizer_fn(self.model_stu)\n                self.lr_scheduler = self.lr_scheduler_fn(self.optimizer)\n\n        if self.ema is not None:\n            self.ema.update()\n\n        average_loss = total_loss / self.local_step\n        self.stats[\"loss\"].append(average_loss)\n\n        if self.local_rank == 0:\n            pbar.close()\n            if self.report_metric_at_train:\n                for metric in self.metrics:\n                    self.log(metric.report(), style=\"red\")\n                    if self.use_tensorboardX:\n                        metric.write(self.writer, self.epoch, 
prefix=\"train\")\n                    metric.clear()\n\n        if not self.scheduler_update_every_step:\n            if isinstance(\n                self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau\n            ):\n                self.lr_scheduler.step(average_loss)\n            else:\n                self.lr_scheduler.step()\n\n        self.log(f\"==> Finished Epoch {self.epoch}.\")\n\n    ### ------------------------------\n\n    def get_loss(self, pred, gt):\n        if self.opt.loss_type == \"L2\":\n            loss = torch.mean((gt - pred) ** 2)\n        elif self.opt.loss_type == \"L1\":\n            loss = torch.mean(torch.abs(pred - gt))\n        elif self.opt.loss_type == \"normL2\":\n            loss = torch.norm(pred - gt)\n        elif self.opt.loss_type == \"normL1\":\n            loss = torch.norm(pred - gt, p=1)\n        elif self.opt.loss_type == \"smoothL1\":\n            loss = torch.nn.functional.smooth_l1_loss(pred, gt, beta=0.05)\n        else:\n            raise ValueError(f\"unknown loss_type: {self.opt.loss_type}\")\n        return loss\n\n    def train_step(self, data):\n\n        rays_o = data[\"rays_o\"]  # [B, N, 3]\n        rays_d = data[\"rays_d\"]  # [B, N, 3]\n\n        # if there is no gt image, we train with CLIP loss.\n        if \"images\" not in data:\n            # guard: a bare assert would be stripped under python -O\n            raise NotImplementedError(\"CLIP-guided training is disabled\")\n            B, N = rays_o.shape[:2]\n            H, W = data[\"H\"], data[\"W\"]\n            # currently fix white bg, MUST force all rays!\n            outputs = self.model_stu.render(\n                rays_o,\n                rays_d,\n                staged=False,\n                bg_color=None,\n                perturb=True,\n                force_all_rays=True,\n                **vars(self.opt),\n            )\n            pred_rgb = (\n                outputs[\"image\"].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous()\n            )\n            loss = self.clip_loss(pred_rgb)\n            return pred_rgb, None, loss\n\n        images = data[\"images\"]  # [B, N, 3/4]\n        B, N, C = images.shape\n\n        if self.opt.color_space == \"linear\":\n   
         images[..., :3] = srgb_to_linear(images[..., :3])\n\n        if C == 3 or self.model_stu.bg_radius > 0:\n            bg_color = 1\n        # train with random background color if not using a bg model and has alpha channel.\n        else:\n            bg_color = torch.rand_like(images[..., :3])  # [N, 3], pixel-wise random.\n        if C == 4:\n            gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (\n                1 - images[..., 3:]\n            )\n        else:\n            gt_rgb = images\n\n        # outputs = self.model.render(rays_o,rays_d,staged=False,bg_color=bg_color,perturb=True,force_all_rays=False,**vars(self.opt))\n        if self.opt.render_stu_first:\n            outputs_stu = self.model_stu.render(\n                rays_o,\n                rays_d,\n                staged=False,\n                bg_color=bg_color,\n                perturb=True,\n                force_all_rays=False,\n                **vars(self.opt),\n            )\n            pred_rgb_stu = outputs_stu[\"image\"]\n            if not self.opt.just_train_a_model:\n                with torch.no_grad():\n                    outputs_tea = self.model_tea.render(\n                        rays_o,\n                        rays_d,\n                        staged=False,\n                        bg_color=bg_color,\n                        perturb=True,\n                        force_all_rays=False,\n                        inherited_params=outputs_stu[\"inherited_params\"],\n                        **vars(self.opt),\n                    )\n                    pred_rgb_tea = outputs_tea[\"image\"]\n        else:\n            with torch.no_grad():\n                outputs_tea = self.model_tea.render(\n                    rays_o,\n                    rays_d,\n                    staged=False,\n                    bg_color=bg_color,\n                    perturb=True,\n                    force_all_rays=False,\n                    **vars(self.opt),\n                )\n       
         pred_rgb_tea = outputs_tea[\"image\"]\n            outputs_stu = self.model_stu.render(\n                rays_o,\n                rays_d,\n                staged=False,\n                bg_color=bg_color,\n                perturb=True,\n                force_all_rays=False,\n                inherited_params=outputs_tea[\"inherited_params\"],\n                **vars(self.opt),\n            )\n            pred_rgb_stu = outputs_stu[\"image\"]\n\n        loss = 0.0\n\n        if self.opt.just_train_a_model:\n            pred_rgb = pred_rgb_stu\n            loss = loss + self.criterion(pred_rgb, gt_rgb).mean()\n            if self.opt.model_type == \"vm\":\n                loss = loss + self.model_stu.density_loss() * self.opt.l1_reg_weight\n            return loss, pred_rgb, gt_rgb\n\n    def evaluate(self, loader, name=None):\n        self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX\n        self.evaluate_one_epoch(loader, name)\n        self.use_tensorboardX = use_tensorboardX\n\n    def evaluate_one_epoch(self, loader, name=None):\n        self.log(f\"++> Evaluate at epoch {self.epoch} ...\")\n\n        if name is None:\n            name = f\"{self.name}_ep{self.epoch:04d}\"\n\n        total_loss = 0\n        if self.local_rank == 0:\n            for metric in self.metrics:\n                metric.clear()\n        if self.opt.test_teacher:\n            self.model_stu = self.model_tea\n        self.model_stu.eval()\n\n        if self.ema is not None:\n            self.ema.store()\n            self.ema.copy_to()\n\n        if self.local_rank == 0:\n            pbar = tqdm.tqdm(\n                total=len(loader) * loader.batch_size,\n                bar_format=\"{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\",\n            )\n\n        with torch.no_grad():\n            self.local_step = 0\n            self.ssim = 0.0\n            self.lpips_vgg = 0.0\n            self.lpips_alex = 0.0\n\n            
# update grid\n            if self.model_stu.cuda_ray:\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    if self.opt.update_stu_extra:\n                        self.model_stu.update_extra_state()\n                    else:\n                        pass\n\n            for data in loader:\n                self.local_step += 1\n\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    preds, preds_depth, truths, loss = self.eval_step(data)\n\n                # all_gather/reduce the statistics (NCCL only support all_*)\n                if self.world_size > 1:\n                    dist.all_reduce(loss, op=dist.ReduceOp.SUM)\n                    loss = loss / self.world_size\n\n                    preds_list = [\n                        torch.zeros_like(preds).to(self.device)\n                        for _ in range(self.world_size)\n                    ]  # [[B, ...], [B, ...], ...]\n                    dist.all_gather(preds_list, preds)\n                    preds = torch.cat(preds_list, dim=0)\n\n                    preds_depth_list = [\n                        torch.zeros_like(preds_depth).to(self.device)\n                        for _ in range(self.world_size)\n                    ]  # [[B, ...], [B, ...], ...]\n                    dist.all_gather(preds_depth_list, preds_depth)\n                    preds_depth = torch.cat(preds_depth_list, dim=0)\n\n                    truths_list = [\n                        torch.zeros_like(truths).to(self.device)\n                        for _ in range(self.world_size)\n                    ]  # [[B, ...], [B, ...], ...]\n                    dist.all_gather(truths_list, truths)\n                    truths = torch.cat(truths_list, dim=0)\n\n                loss_val = loss.item()\n                total_loss += loss_val\n\n                # only rank = 0 will perform evaluation.\n                if self.local_rank == 0:\n\n                    for metric in 
self.metrics:\n                        metric.update(preds, truths)\n                    self.lpips_alex += rgb_lpips(truths, preds, \"alex\")\n                    self.lpips_vgg += rgb_lpips(truths, preds, \"vgg\")\n                    self.ssim += compute_ssim(\n                        preds,\n                        truths,\n                        max_val=max(preds.max().item(), truths.max().item()),\n                    ).item()\n\n                    # save image\n                    save_path = os.path.join(\n                        self.workspace,\n                        \"validation\",\n                        f\"{name}_{self.local_step:04d}.png\",\n                    )\n                    save_path_depth = os.path.join(\n                        self.workspace,\n                        \"validation\",\n                        f\"{name}_{self.local_step:04d}_depth.png\",\n                    )\n                    save_path_gt = os.path.join(\n                        self.workspace,\n                        \"validation\",\n                        f\"{name}_{self.local_step:04d}_gt.png\",\n                    )\n\n                    # self.log(f\"==> Saving validation image to {save_path}\")\n                    os.makedirs(os.path.dirname(save_path), exist_ok=True)\n\n                    if self.opt.color_space == \"linear\":\n                        preds = linear_to_srgb(preds)\n\n                    pred = preds[0].detach().cpu().numpy()\n                    pred_depth = preds_depth[0].detach().cpu().numpy()\n                    if self.local_step < 15:\n                        cv2.imwrite(\n                            save_path,\n                            cv2.cvtColor(\n                                (pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR\n                            ),\n                        )\n                        cv2.imwrite(\n                            save_path_depth, (pred_depth * 255).astype(np.uint8)\n                     
   )\n                        cv2.imwrite(\n                            save_path_gt,\n                            cv2.cvtColor(\n                                (truths[0].detach().cpu().numpy() * 255).astype(\n                                    np.uint8\n                                ),\n                                cv2.COLOR_RGB2BGR,\n                            ),\n                        )\n                        # cv2.imwrite(save_path_gt, cv2.cvtColor((linear_to_srgb(truths[0].detach().cpu().numpy()) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))\n\n                    pbar.set_description(\n                        f\"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})\"\n                    )\n                    pbar.update(loader.batch_size)\n\n        average_loss = total_loss / self.local_step\n        self.stats[\"valid_loss\"].append(average_loss)\n\n        if self.local_rank == 0:\n            pbar.close()\n            if not self.use_loss_as_metric and len(self.metrics) > 0:\n                result = self.metrics[0].measure()\n                self.stats[\"results\"].append(\n                    result if self.best_mode == \"min\" else -result\n                )  # if max mode, use -result\n            else:\n                self.stats[\"results\"].append(\n                    average_loss\n                )  # if no metric, choose best by min loss\n\n            for metric in self.metrics:\n                self.log(metric.report(), style=\"blue\")\n                psnr = metric.report().split(\"=\")[-1].strip()[:5]\n                self.psnr = float(psnr)\n                if self.use_tensorboardX:\n                    metric.write(self.writer, self.epoch, prefix=\"evaluate\")\n                metric.clear()\n\n        self.ssim /= self.local_step\n        self.lpips_alex /= self.local_step\n        self.lpips_vgg /= self.local_step\n        if self.ema is not None:\n            self.ema.restore()\n        # from IPython import embed; 
embed()\n        if self.local_rank == 0:\n            print(\n                f\"\\n psnr:{self.psnr} ssim:{self.ssim} alex:{self.lpips_alex} vgg:{self.lpips_vgg} \\n\"\n            )\n\n        # cmd = f'mv {self.workspace} {self.workspace}-pnsr{psnr}'\n        # print(cmd)\n        # os.system(cmd)\n        self.log(f\"++> Evaluate epoch {self.epoch} Finished.\")\n\n    def eval_step(self, data):\n\n        rays_o = data[\"rays_o\"]  # [B, N, 3]\n        rays_d = data[\"rays_d\"]  # [B, N, 3]\n        images = data[\"images\"]  # [B, H, W, 3/4]\n        B, H, W, C = images.shape\n\n        if self.opt.color_space == \"linear\":\n            images[..., :3] = srgb_to_linear(images[..., :3])\n\n        # eval with fixed background color\n        bg_color = 1\n        if C == 4:\n            gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (\n                1 - images[..., 3:]\n            )\n        else:\n            gt_rgb = images\n\n        outputs = self.model_stu.render(\n            rays_o,\n            rays_d,\n            staged=True,\n            bg_color=bg_color,\n            perturb=False,\n            **vars(self.opt),\n        )\n\n        pred_rgb = outputs[\"image\"].reshape(B, H, W, 3)\n        pred_depth = outputs[\"depth\"].reshape(B, H, W)\n\n        loss = self.criterion(pred_rgb, gt_rgb).mean()\n\n        return pred_rgb, pred_depth, gt_rgb, loss\n\n    def save_checkpoint(self, name=None, full=False, best=False, remove_old=True):\n        full = False\n        if name is None:\n            name = f\"{self.name}_ep{self.epoch:04d}\"\n        if self.opt.model_type == \"vm\":\n            state = {\n                \"epoch\": self.epoch,\n                \"global_step\": self.global_step,\n                \"stats\": self.stats,\n                \"resolution\": self.model_stu.resolution,\n            }\n        else:\n            state = {\n                \"epoch\": self.epoch,\n                \"global_step\": self.global_step,\n                \"stats\": self.stats,\n            
}\n\n        if self.model_stu.cuda_ray:\n            state[\"mean_count\"] = self.model_stu.mean_count\n            state[\"mean_density\"] = self.model_stu.mean_density\n\n        if full:\n            state[\"optimizer\"] = self.optimizer.state_dict()\n            state[\"lr_scheduler\"] = self.lr_scheduler.state_dict()\n            state[\"scaler\"] = self.scaler.state_dict()\n            if self.ema is not None:\n                state[\"ema\"] = self.ema.state_dict()\n\n        if not best:\n\n            state[\"model\"] = self.model_stu.state_dict()\n\n            file_path = f\"{self.ckpt_path}/{name}.pth\"\n\n            if remove_old:\n                self.stats[\"checkpoints\"].append(file_path)\n\n                if len(self.stats[\"checkpoints\"]) > self.max_keep_ckpt:\n                    old_ckpt = self.stats[\"checkpoints\"].pop(0)\n                    if os.path.exists(old_ckpt):\n                        os.remove(old_ckpt)\n\n            torch.save(state, file_path)\n\n        else:\n            if len(self.stats[\"results\"]) > 0:\n                if (\n                    self.stats[\"best_result\"] is None\n                    or self.stats[\"results\"][-1] < self.stats[\"best_result\"]\n                ):\n                    self.log(\n                        f\"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}\"\n                    )\n                    self.stats[\"best_result\"] = self.stats[\"results\"][-1]\n\n                    # save ema results\n                    if self.ema is not None:\n                        self.ema.store()\n                        self.ema.copy_to()\n\n                    state[\"model\"] = self.model_stu.state_dict()\n\n                    if self.ema is not None:\n                        self.ema.restore()\n\n                    torch.save(state, self.best_path)\n            else:\n                self.log(\n                    f\"[WARN] no evaluated results found, skip 
saving best checkpoint.\"\n                )\n\n    def load_teacher_checkpoint(self):\n        checkpoint_dict = torch.load(self.opt.ckpt_teacher, map_location=self.device)\n\n        missing_keys, unexpected_keys = self.model_tea.load_state_dict(\n            checkpoint_dict[\"model\"], strict=False\n        )\n        self.log(\"[INFO] loaded teacher model.\")\n        if len(missing_keys) > 0:\n            self.log(f\"[WARN] missing keys: {missing_keys}\")\n        if len(unexpected_keys) > 0:\n            self.log(f\"[WARN] unexpected keys: {unexpected_keys}\")\n\n        if self.ema is not None and \"ema\" in checkpoint_dict:\n            self.ema.load_state_dict(checkpoint_dict[\"ema\"])\n\n        if self.model_tea.cuda_ray:\n            if \"mean_count\" in checkpoint_dict:\n                self.model_tea.mean_count = checkpoint_dict[\"mean_count\"]\n            if \"mean_density\" in checkpoint_dict:\n                self.model_tea.mean_density = checkpoint_dict[\"mean_density\"]\n\n    def load_student_checkpoint(self):\n        if self.opt.ckpt_student:\n            checkpoint_dict = torch.load(\n                self.opt.ckpt_student, map_location=self.device\n            )\n        else:\n            checkpoint_dict = torch.load(\n                self.opt.ckpt_teacher, map_location=self.device\n            )\n\n        if self.opt.model_type == \"vm\" and \"resolution\" in checkpoint_dict:\n            self.model_stu.upsample_model(checkpoint_dict[\"resolution\"])\n        missing_keys, unexpected_keys = self.model_stu.load_state_dict(\n            checkpoint_dict[\"model\"], strict=False\n        )\n        self.log(\"[INFO] loaded student model.\")\n        if len(missing_keys) > 0:\n            self.log(f\"[WARN] missing keys: {missing_keys}\")\n        if len(unexpected_keys) > 0:\n            self.log(f\"[WARN] 
unexpected keys: {unexpected_keys}\")\n        if self.model_stu.cuda_ray:\n            if \"mean_count\" in checkpoint_dict:\n                self.model_stu.mean_count = checkpoint_dict[\"mean_count\"]\n            if \"mean_density\" in checkpoint_dict:\n                self.model_stu.mean_density = checkpoint_dict[\"mean_density\"]\n\n        if self.ema is not None and \"ema\" in checkpoint_dict:\n            self.ema.load_state_dict(checkpoint_dict[\"ema\"])\n\n    def test(self, loader, save_path=None, name=None):\n        # guard: a bare assert would be stripped under python -O\n        raise NotImplementedError(\"test() is disabled\")\n        if save_path is None:\n            save_path = os.path.join(self.workspace, \"results\")\n\n        if name is None:\n            name = f\"{self.name}_ep{self.epoch:04d}\"\n\n        os.makedirs(save_path, exist_ok=True)\n\n        self.log(f\"==> Start Test, save results to {save_path}\")\n\n        pbar = tqdm.tqdm(\n            total=len(loader) * loader.batch_size,\n            bar_format=\"{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\",\n        )\n        self.model_stu.eval()\n        with torch.no_grad():\n\n            # update grid\n            if self.model_stu.cuda_ray:\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    self.model_stu.update_extra_state()\n\n            for i, data in enumerate(loader):\n\n                with torch.cuda.amp.autocast(enabled=self.fp16):\n                    preds, preds_depth = self.test_step(data)\n\n                path = os.path.join(save_path, f\"{name}_{i:04d}.png\")\n                path_depth = os.path.join(save_path, f\"{name}_{i:04d}_depth.png\")\n\n                # self.log(f\"[INFO] saving test image to {path}\")\n\n                if self.opt.color_space == \"linear\":\n                    preds = linear_to_srgb(preds)\n\n                pred = preds[0].detach().cpu().numpy()\n                pred_depth = preds_depth[0].detach().cpu().numpy()\n\n                cv2.imwrite(\n                    
path, cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)\n                )\n                cv2.imwrite(path_depth, (pred_depth * 255).astype(np.uint8))\n\n                pbar.update(loader.batch_size)\n\n        self.log(f\"==> Finished Test.\")\n\n    # moved out bg_color and perturb for more flexible control...\n    def test_step(self, data, bg_color=None, perturb=False):\n\n        rays_o = data[\"rays_o\"]  # [B, N, 3]\n        rays_d = data[\"rays_d\"]  # [B, N, 3]\n        H, W = data[\"H\"], data[\"W\"]\n\n        if bg_color is not None:\n            bg_color = bg_color.to(self.device)\n\n        outputs = self.model_stu.render(\n            rays_o,\n            rays_d,\n            staged=True,\n            bg_color=bg_color,\n            perturb=perturb,\n            **vars(self.opt),\n        )\n\n        pred_rgb = outputs[\"image\"].reshape(-1, H, W, 3)\n        pred_depth = outputs[\"depth\"].reshape(-1, H, W)\n\n        return pred_rgb, pred_depth\n"
  },
  {
    "path": "main_distill_mutual.py",
    "content": "import torch\nimport os\nimport argparse\n\nfrom distill_mutual.network import NeRFNetwork\nfrom functools import partial\nfrom time import time\nfrom distill_mutual.provider import NeRFDataset\nfrom distill_mutual.utils import *\nfrom IPython import embed\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef save_codes_env(workspace):\n    path = os.path.join(workspace, \"codes_env\")\n    os.makedirs(path, exist_ok=True)\n    os.system(f\"cp *.py {path}\")\n    os.system(f\"cp -r raymarching {path}\")\n    os.system(f\"cp -r distill_mutual {path}\")\n    os.system(f\"cp -r nerf {path}\")\n\n\ndef load_from_txt(opt, except_space=\"\"):\n    # except_space = {'workspace', 'teacher_type', 'model_type', 'test', 'test_teacher', 'use_spiral_pose', 'ckpt_teacher'}\n    except_space = {\"workspace\"}\n    with open(\n        os.path.join(opt.ckpt_teacher.split(\"checkpoints\")[0], \"args.txt\"), \"r\"\n    ) as f:  # change this path to your own params settings\n        load_args = f.readlines()\n    for i in range(1, len(load_args)):\n        if \"(\" in load_args[i]:\n            k, v = eval(load_args[i])\n        else:\n            continue\n        if k in opt and k not in except_space and v != opt.__dict__[k]:\n            print(k, v, opt.__dict__[k])\n            opt.__dict__[k] = v\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"path\", type=str)\n    parser.add_argument(\n        \"-O\", action=\"store_true\", help=\"equals --fp16 --cuda_ray --preload\"\n    )\n    parser.add_argument(\"--test\", action=\"store_true\", help=\"test mode\")\n    parser.add_argument(\"--workspace\", type=str, default=\"workspace\")\n    parser.add_argument(\"--seed\", type=int, default=0)\n\n    # training options\n    parser.add_argument(\"--iters\", type=int, default=30000, help=\"training iters\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"initial 
learning rate\")\n    parser.add_argument(\"--ckpt\", type=str, default=\"latest\")\n    parser.add_argument(\n        \"--num_rays\",\n        type=int,\n        default=4096,\n        help=\"num rays sampled per image for each training step\",\n    )\n    parser.add_argument(\n        \"--cuda_ray\",\n        action=\"store_true\",\n        help=\"use CUDA raymarching instead of pytorch\",\n    )\n    parser.add_argument(\n        \"--max_steps\",\n        type=int,\n        default=1024,\n        help=\"max num steps sampled per ray (only valid when using --cuda_ray)\",\n    )\n    parser.add_argument(\n        \"--num_steps\",\n        type=int,\n        default=512,\n        help=\"num steps sampled per ray (only valid when NOT using --cuda_ray)\",\n    )\n    parser.add_argument(\n        \"--upsample_steps\",\n        type=int,\n        default=0,\n        help=\"num steps up-sampled per ray (only valid when NOT using --cuda_ray)\",\n    )\n    parser.add_argument(\n        \"--update_extra_interval\",\n        type=int,\n        default=16,\n        help=\"iter interval to update extra status (only valid when using --cuda_ray)\",\n    )\n    parser.add_argument(\n        \"--max_ray_batch\",\n        type=int,\n        default=4096,\n        help=\"batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)\",\n    )\n\n    parser.add_argument(\n        \"--fp16\", action=\"store_true\", help=\"use amp mixed precision training\"\n    )\n\n    parser.add_argument(\n        \"--mode\",\n        type=str,\n        default=\"blender\",\n        help=\"dataset mode, supports (colmap, blender)\",\n    )\n    parser.add_argument(\n        \"--color_space\",\n        type=str,\n        default=\"srgb\",\n        help=\"Color space, supports (linear, srgb)\",\n    )\n    parser.add_argument(\n        \"--preload\",\n        action=\"store_true\",\n        help=\"preload all data into GPU, accelerate training but use more GPU memory\",\n    
)\n    parser.add_argument(\n        \"--bound\",\n        type=float,\n        default=1,\n        help=\"assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.\",\n    )\n    parser.add_argument(\n        \"--scale\",\n        type=float,\n        default=0.8,\n        help=\"scale camera location into box[-bound, bound]^3\",\n    )\n    parser.add_argument(\n        \"--dt_gamma\",\n        type=float,\n        default=0,\n        help=\"dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)\",\n    )\n    parser.add_argument(\n        \"--min_near\", type=float, default=0.2, help=\"minimum near distance for camera\"\n    )\n    parser.add_argument(\n        \"--density_thresh\",\n        type=float,\n        default=10,\n        help=\"threshold for density grid to be occupied\",\n    )\n    parser.add_argument(\n        \"--bg_radius\",\n        type=float,\n        default=-1,\n        help=\"if positive, use a background model at sphere(bg_radius)\",\n    )\n\n    # experimental\n    parser.add_argument(\n        \"--error_map\", action=\"store_true\", help=\"use error map to sample rays\"\n    )\n    parser.add_argument(\n        \"--clip_text\", type=str, default=\"\", help=\"text input for CLIP guidance\"\n    )\n\n    parser.add_argument(\n        \"--loss_type\",\n        type=str,\n        default=\"normL2\",\n        choices=[\"normL2\", \"L2\", \"normL1\", \"L1\", \"smoothL1\"],\n    )\n    parser.add_argument(\n        \"--distill_mode\",\n        type=str,\n        default=\"no_fix_mlp\",\n        choices=[\"fix_mlp\", \"no_fix_mlp\"],\n        help=\"fix mlp for hash\",\n    )\n    parser.add_argument(\"--loss_rate_rgb\", type=float, default=1.0)\n    parser.add_argument(\"--loss_rate_fea_sc\", type=float, default=0.002)\n    parser.add_argument(\"--loss_rate_color\", type=float, default=0.002)\n    parser.add_argument(\"--loss_rate_sigma\", type=float, 
default=0.002)\n    parser.add_argument(\"--l1_reg_weight\", type=float, default=1e-4)\n\n    parser.add_argument(\"--ckpt_teacher\", type=str, default=\"\")\n    parser.add_argument(\"--ckpt_student\", type=str, default=\"\")\n    parser.add_argument(\"--sigma_clip_min\", type=float, default=-2)\n    parser.add_argument(\"--sigma_clip_max\", type=float, default=7)\n    parser.add_argument(\"--render_stu_first\", action=\"store_true\", default=False)\n    parser.add_argument(\"--use_diagonal_matrix\", action=\"store_true\", default=False)\n\n    parser.add_argument(\"--test_teacher\", action=\"store_true\", default=False)\n    parser.add_argument(\"--test_metric\", action=\"store_true\", default=False)\n    parser.add_argument(\n        \"--test_type_trainval\", action=\"store_true\", default=False\n    )  # XXX\n\n    parser.add_argument(\"--PE\", type=int, default=10)\n    parser.add_argument(\"--nerf_layer_num\", type=int, default=8)\n    parser.add_argument(\"--nerf_layer_wide\", type=int, default=256)\n    parser.add_argument(\"--skip\", type=int, default=3)\n    parser.add_argument(\"--residual\", type=int, default=3)\n\n    parser.add_argument(\"--resolution0\", type=int, default=300)\n    parser.add_argument(\"--resolution1\", type=int, default=300)\n    parser.add_argument(\n        \"--upsample_model_steps\", type=int, action=\"append\", default=[1e10]\n    )\n\n    parser.add_argument(\"--teacher_type\", default=\"hash\", type=str)\n    parser.add_argument(\"--model_type\", default=\"hash\", type=str)\n    parser.add_argument(\n        \"--data_type\",\n        default=\"synthetic\",\n        type=str,\n        choices=[\"synthetic\", \"llff\", \"tank\"],\n    )\n\n    parser.add_argument(\"--update_stu_extra\", action=\"store_true\", default=False)\n    parser.add_argument(\"--ema_decay\", type=float, default=-1.0)\n    parser.add_argument(\"--grid_size\", type=int, default=128)\n\n    parser.add_argument(\"--plenoxel_degree\", type=int, default=3)\n    
parser.add_argument(\"--plenoxel_res\", type=str, default=\"[128,128,128]\")\n\n    parser.add_argument(\"--load_args\", action=\"store_true\", default=False)\n\n    parser.add_argument(\"--eval_interval_epoch\", default=1e5, type=int, help=\"\")\n\n    parser.add_argument(\n        \"--use_real_data_for_train\",\n        action=\"store_true\",\n        default=False,\n    )\n\n    parser.add_argument(\"--enable_embed\", action=\"store_true\")\n    parser.add_argument(\"--enable_edit_plenoxel\", action=\"store_true\")\n    parser.add_argument(\n        \"--stage_iters\", type=str, default=\"{'stage1':2000, 'stage2':5000}\"\n    )\n\n    opt = parser.parse_args()\n    opt.stage_iters = eval(opt.stage_iters)\n    opt.O = True  # always use -O\n    opt.render_stu_first = True\n    if opt.model_type == \"mlp\":\n        opt.lr *= 0.1\n    if (\n        \"tensors\" == opt.model_type or \"tensors\" == opt.teacher_type\n    ):  # plenoxels (\"tensors\") have no features, so disable stage1\n        opt.stage_iters[\"stage1\"] = -1\n    save_codes_env(opt.workspace)\n\n    if opt.load_args:\n        load_from_txt(opt)\n    if opt.O:\n        opt.fp16 = True\n        opt.cuda_ray = True\n        opt.preload = True\n\n    assert opt.model_type in [\"hash\", \"mlp\", \"vm\", \"tensors\"]\n    assert opt.teacher_type in [\"hash\", \"mlp\", \"vm\", \"tensors\"]\n    print(opt)\n    seed_everything(opt.seed)\n\n    model_tea = NeRFNetwork(\n        encoding=\"hashgrid\",\n        bound=opt.bound,\n        cuda_ray=opt.cuda_ray,\n        density_scale=1,\n        min_near=opt.min_near,\n        density_thresh=opt.density_thresh,\n        bg_radius=opt.bg_radius,\n        model_type=opt.teacher_type,\n        args=opt,\n        grid_size=opt.grid_size,\n        is_teacher=True,\n    )\n\n    model_stu = NeRFNetwork(\n        encoding=\"hashgrid\",\n        bound=opt.bound,\n        cuda_ray=opt.cuda_ray,\n        density_scale=1,\n        min_near=opt.min_near,\n        density_thresh=opt.density_thresh,\n        
bg_radius=opt.bg_radius,\n        model_type=opt.model_type,\n        args=opt,\n        grid_size=opt.grid_size,\n    )\n\n    print(\"\\nteacher:\", model_tea)\n    print(f\"\\n{opt.model_type}\", model_stu)\n\n    criterion = torch.nn.MSELoss(reduction=\"none\")\n\n    # ------------------------------------ test-test-test-test-test  ----------------------------------------------\n    if opt.test or opt.test_teacher or opt.test_type_trainval:\n        trainer = Trainer(\n            f\"{opt.teacher_type}2{opt.model_type}\",\n            opt,\n            model_tea,\n            model_stu,\n            device=device,\n            workspace=opt.workspace,\n            criterion=criterion,\n            fp16=opt.fp16,\n            metrics=[PSNRMeter()],\n            use_checkpoint=opt.ckpt,\n            ema_decay=opt.ema_decay,\n        )\n\n        if opt.test_type_trainval:\n            test_loader = NeRFDataset(opt, device=device, type=\"trainval\").dataloader()\n        else:\n            test_loader = NeRFDataset(opt, device=device, type=\"test\").dataloader()\n        if opt.mode == \"blender\":\n            trainer.evaluate(test_loader)\n        else:\n            trainer.test(test_loader)\n\n    # ------------------------------------ train-train-train-train  ----------------------------------------------\n    else:\n        for p in model_tea.parameters():\n            p.requires_grad = False\n        if opt.distill_mode == \"fix_mlp\":\n            for n, p in model_stu.named_parameters():\n                if \"sigma_net\" in n or \"color_net\" in n:\n                    p.requires_grad = False\n            idx = 1 if opt.model_type == \"vm\" else 3\n            optimizer = lambda model_stu: torch.optim.AdamW(\n                model_stu.get_params(opt.lr)[idx:],\n                betas=(0.9, 0.99),\n                eps=1e-15,\n                amsgrad=False,\n            )\n        else:\n            optimizer = lambda model_stu: torch.optim.AdamW(\n           
     model_stu.get_params(opt.lr),\n                betas=(0.9, 0.99),\n                eps=1e-15,\n                amsgrad=False,\n            )\n        # fake train loader. The real random data for distillation will be generated in utils.py\n        train_loader = NeRFDataset(opt, device=device, type=\"train\").dataloader()\n        opt.iters = opt.iters + opt.iters % len(\n            train_loader\n        )  # will be updated in utils according to the number of random samples\n        max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)\n        scheduler = lambda optimizer: optim.lr_scheduler.CosineAnnealingLR(\n            optimizer, T_max=opt.iters * 1, eta_min=5e-5\n        )\n\n        trainer = Trainer(\n            f\"{opt.teacher_type}2{opt.model_type}\",\n            opt,\n            model_tea,\n            model_stu,\n            device=device,\n            workspace=opt.workspace,\n            optimizer=optimizer,\n            criterion=criterion,\n            ema_decay=opt.ema_decay,\n            fp16=opt.fp16,\n            lr_scheduler=scheduler,\n            scheduler_update_every_step=True,\n            metrics=[PSNRMeter()],\n            use_checkpoint=opt.ckpt,\n            eval_interval=opt.eval_interval_epoch,\n        )\n        upsample_resolutions = (\n            (\n                np.round(\n                    np.exp(\n                        np.linspace(\n                            np.log(opt.resolution0),\n                            np.log(opt.resolution1),\n                            len(opt.upsample_model_steps) + 1,\n                        )\n                    )\n                )\n            )\n            .astype(np.int32)\n            .tolist()[1:]\n        )\n        trainer.upsample_resolutions = upsample_resolutions\n        argstxt = sorted(opt.__dict__.items())\n        with open(os.path.join(opt.workspace, \"args.txt\"), \"w\") as f:\n            for t in argstxt:\n                f.write(str(t) + 
\"\\n\")\n        start_time = time.time()\n        valid_loader = NeRFDataset(\n            opt, device=device, type=\"val\", downscale=1\n        ).dataloader()\n        trainer.train(train_loader, valid_loader, max_epoch)\n\n        end_time = time.time()\n        train_time = end_time - start_time\n        print(f\"\\ntraining_time : {train_time:.2f}s\\n\")\n\n        # run test data\n        test_loader = NeRFDataset(opt, device=device, type=\"test\").dataloader()\n        print(opt.workspace)\n\n        trainer.evaluate(test_loader)\n\n        with open(os.path.join(trainer.workspace, \"args.txt\"), \"a+\") as f:\n            txt = f\"\\npsnr: {trainer.psnr:.2f} \\nssim: {trainer.ssim:.3f} \\nalex: {trainer.lpips_alex:.3f}\\nvgg:{trainer.lpips_vgg:.3f}\"\n            f.write(txt)\n        cmd = f\"mv {trainer.workspace} {trainer.workspace}-psnr{trainer.psnr}\"\n        os.system(cmd)\n"
  },
  {
    "path": "main_just_train_tea.py",
    "content": "import torch\nimport os\nimport argparse\n\nfrom just_train_tea.network import NeRFNetwork\n\nfrom functools import partial\nfrom just_train_tea.provider import NeRFDataset\nfrom just_train_tea.utils import *\nfrom time import time\n\n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"path\", type=str)\n    parser.add_argument(\n        \"-O\", action=\"store_true\", help=\"equals --fp16 --cuda_ray --preload\"\n    )\n    parser.add_argument(\"--test\", action=\"store_true\", help=\"test mode\")\n    parser.add_argument(\"--workspace\", type=str, default=\"workspace\")\n    parser.add_argument(\"--seed\", type=int, default=0)\n\n    ### training options\n    parser.add_argument(\"--iters\", type=int, default=40000, help=\"training iters\")\n    parser.add_argument(\"--lr\", type=float, default=1e-2, help=\"initial learning rate\")\n    parser.add_argument(\"--ckpt\", type=str, default=\"latest\")\n    parser.add_argument(\n        \"--num_rays\",\n        type=int,\n        default=8192,\n        help=\"num rays sampled per image for each training step\",\n    )\n    parser.add_argument(\n        \"--cuda_ray\",\n        action=\"store_true\",\n        help=\"use CUDA raymarching instead of pytorch\",\n    )\n    parser.add_argument(\n        \"--max_steps\",\n        type=int,\n        default=1024,\n        help=\"max num steps sampled per ray (only valid when using --cuda_ray)\",\n    )\n    parser.add_argument(\n        \"--num_steps\",\n        type=int,\n        default=512,\n        help=\"num steps sampled per ray (only valid when NOT using --cuda_ray)\",\n    )\n    parser.add_argument(\n        \"--upsample_steps\",\n        type=int,\n        default=0,\n        help=\"num steps up-sampled per ray (only valid when NOT using --cuda_ray)\",\n    )\n    parser.add_argument(\n        \"--update_extra_interval\",\n        type=int,\n        default=16,\n        help=\"iter interval to update 
extra status (only valid when using --cuda_ray)\",\n    )\n    parser.add_argument(\n        \"--max_ray_batch\",\n        type=int,\n        default=4096,\n        help=\"batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)\",\n    )\n\n    parser.add_argument(\n        \"--fp16\", action=\"store_true\", help=\"use amp mixed precision training\"\n    )\n    parser.add_argument(\"--ff\", action=\"store_true\", help=\"use fully-fused MLP\")\n    parser.add_argument(\"--tcnn\", action=\"store_true\", help=\"use TCNN backend\")\n\n    parser.add_argument(\n        \"--mode\",\n        type=str,\n        default=\"blender\",\n        help=\"dataset mode, supports (colmap, blender)\",\n    )\n    parser.add_argument(\n        \"--color_space\",\n        type=str,\n        default=\"srgb\",\n        help=\"Color space, supports (linear, srgb)\",\n    )\n    parser.add_argument(\n        \"--preload\",\n        action=\"store_true\",\n        help=\"preload all data into GPU, accelerate training but use more GPU memory\",\n    )\n    # (the default value is for the fox dataset)\n    parser.add_argument(\n        \"--bound\",\n        type=float,\n        default=1,\n        help=\"assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.\",\n    )\n    parser.add_argument(\n        \"--scale\",\n        type=float,\n        default=0.8,\n        help=\"scale camera location into box[-bound, bound]^3\",\n    )\n    parser.add_argument(\n        \"--dt_gamma\",\n        type=float,\n        default=0,\n        help=\"dt_gamma (>=0) for adaptive ray marching. 
set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)\",\n    )\n    parser.add_argument(\n        \"--min_near\", type=float, default=0.2, help=\"minimum near distance for camera\"\n    )\n    parser.add_argument(\n        \"--density_thresh\",\n        type=float,\n        default=10,\n        help=\"threshold for density grid to be occupied\",\n    )\n    parser.add_argument(\n        \"--bg_radius\",\n        type=float,\n        default=-1,\n        help=\"if positive, use a background model at sphere(bg_radius)\",\n    )\n\n    ### GUI options\n    parser.add_argument(\"--gui\", action=\"store_true\", help=\"start a GUI\")\n    parser.add_argument(\"--W\", type=int, default=1920, help=\"GUI width\")\n    parser.add_argument(\"--H\", type=int, default=1080, help=\"GUI height\")\n    parser.add_argument(\n        \"--radius\", type=float, default=5, help=\"default GUI camera radius from center\"\n    )\n    parser.add_argument(\n        \"--fovy\", type=float, default=50, help=\"default GUI camera fovy\"\n    )\n    parser.add_argument(\n        \"--max_spp\", type=int, default=64, help=\"GUI rendering max sample per pixel\"\n    )\n\n    ### experimental\n    parser.add_argument(\n        \"--error_map\", action=\"store_true\", help=\"use error map to sample rays\"\n    )\n    parser.add_argument(\n        \"--clip_text\", type=str, default=\"\", help=\"text input for CLIP guidance\"\n    )\n    parser.add_argument(\n        \"--rand_pose\",\n        type=int,\n        default=-1,\n        help=\"<0 uses no rand pose, =0 only uses rand pose, >0 sample one rand pose every $ known poses\",\n    )\n\n    parser.add_argument(\n        \"--distill_mode\",\n        type=str,\n        default=\"no_fix_mlp\",\n        choices=[\"fix_mlp\", \"no_fix_mlp\"],\n    )\n    parser.add_argument(\"--loss_rate_rgb\", type=float, default=1.0)\n    parser.add_argument(\"--loss_rate_fea\", type=float, default=0.1)\n    
parser.add_argument(\"--loss_rate_fea_sc\", type=float, default=0.1)\n    parser.add_argument(\"--loss_rate_color\", type=float, default=0.0)\n    parser.add_argument(\"--loss_rate_sigma\", type=float, default=0)\n    parser.add_argument(\n        \"--L1_tensorAB_reg\", type=float, default=1e-3, help=\"reg for tensor_ab\"\n    )\n    parser.add_argument(\"--l1_reg_weight\", type=float, default=1e-4)\n\n    parser.add_argument(\"--ckpt_teacher\", type=str, default=\"\")\n    parser.add_argument(\"--ckpt_student\", type=str, default=\"\")\n    parser.add_argument(\"--sigma_clip_min\", type=float, default=-2)\n    parser.add_argument(\"--sigma_clip_max\", type=float, default=7)\n    parser.add_argument(\"--use_sigma_clip\", action=\"store_true\")\n    parser.add_argument(\"--render_stu_first\", action=\"store_true\", default=False)\n    parser.add_argument(\"--nerf_pe\", action=\"store_true\", default=False)\n    parser.add_argument(\"--use_real_gt\", action=\"store_true\", default=False)\n    parser.add_argument(\"--use_diagonal_matrix\", action=\"store_true\", default=False)\n    parser.add_argument(\n        \"--loss_rate_real_gt\", type=float, default=0, help=\"range in [0, 1]\"\n    )\n    parser.add_argument(\"--test_teacher\", action=\"store_true\", default=False)\n    parser.add_argument(\"--test_metric\", action=\"store_true\", default=False)\n\n    parser.add_argument(\"--resolution0\", type=int, default=300)\n    parser.add_argument(\"--resolution1\", type=int, default=300)\n    parser.add_argument(\n        \"--upsample_model_steps\", type=int, action=\"append\", default=[1e10]\n    )\n\n    parser.add_argument(\n        \"--loss_type\", type=str, default=\"L2\", choices=[\"normL2\", \"L2\", \"normL1\", \"L1\"]\n    )\n\n    parser.add_argument(\"--PE\", type=int, default=10)\n    parser.add_argument(\"--nerf_layer_num\", type=int, default=8)\n    parser.add_argument(\"--nerf_layer_wide\", type=int, default=256)\n    parser.add_argument(\"--skip\", 
type=int, default=3)\n    parser.add_argument(\"--residual\", type=int, default=3)\n\n    parser.add_argument(\"--model_type\", default=\"hash\", type=str)\n    parser.add_argument(\"--teacher_type\", default=\"hash\", type=str)\n\n    parser.add_argument(\"--use_upsample_vm\", action=\"store_true\", default=False)\n    parser.add_argument(\"--update_stu_extra\", action=\"store_true\", default=False)\n    parser.add_argument(\"--ema_decay\", type=float, default=-1)\n    parser.add_argument(\"--grid_size\", type=int, default=128)\n\n    parser.add_argument(\"--plenoxel_degree\", type=int, default=3)\n    parser.add_argument(\"--plenoxel_res\", type=str, default=\"[128,128,128]\")\n    parser.add_argument(\"--just_train_a_model\", action=\"store_true\", default=False)\n    parser.add_argument(\"--data_type\", type=str, default=\"\")\n\n    opt = parser.parse_args()\n    opt.just_train_a_model = True\n    opt.update_stu_extra = True\n    opt.render_stu_first = True\n    opt.O = True\n    if opt.model_type == \"mlp\":\n        opt.lr *= 0.1\n\n    if opt.O:\n        opt.fp16 = True\n        opt.cuda_ray = True\n        opt.preload = True\n\n    assert opt.model_type in [\"hash\", \"mlp\", \"vm\", \"tensors\"]\n    print(opt)\n    seed_everything(opt.seed)\n\n    model_tea = NeRFNetwork(\n        encoding=\"hashgrid\",\n        bound=opt.bound,\n        cuda_ray=opt.cuda_ray,\n        density_scale=1,\n        min_near=opt.min_near,\n        density_thresh=opt.density_thresh,\n        bg_radius=opt.bg_radius,\n        model_type=opt.teacher_type,\n        args=opt,\n        grid_size=opt.grid_size,\n        is_teacher=True,\n    )\n\n    model_stu = NeRFNetwork(\n        encoding=\"hashgrid\",\n        bound=opt.bound,\n        cuda_ray=opt.cuda_ray,\n        density_scale=1,\n        min_near=opt.min_near,\n        density_thresh=opt.density_thresh,\n        bg_radius=opt.bg_radius,\n        model_type=opt.model_type,\n        args=opt,\n        
grid_size=opt.grid_size,\n    )\n\n    criterion = torch.nn.MSELoss(reduction=\"none\")\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    if opt.test or opt.test_teacher or opt.test_metric:\n        trainer = Trainer(\n            opt.model_type,\n            opt,\n            model_tea,\n            model_stu,\n            device=device,\n            workspace=opt.workspace,\n            criterion=criterion,\n            fp16=opt.fp16,\n            metrics=[PSNRMeter()],\n            use_checkpoint=opt.ckpt,\n            ema_decay=opt.ema_decay,\n        )\n        test_loader = NeRFDataset(opt, device=device, type=\"test\").dataloader()\n        trainer.evaluate(test_loader)\n    else:\n        for p in model_tea.parameters():\n            p.requires_grad = False\n        optimizer = lambda model_stu: torch.optim.AdamW(\n            model_stu.get_params(opt.lr, opt.lr * 0.1),\n            betas=(0.9, 0.99),\n            eps=1e-15,\n            amsgrad=False,\n        )\n        train_loader = NeRFDataset(opt, device=device, type=\"train\").dataloader()\n        valid_loader = NeRFDataset(opt, device=device, type=\"val\").dataloader()\n        test_loader = NeRFDataset(opt, device=device, type=\"test\").dataloader()\n\n        if opt.just_train_a_model:\n            scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(\n                optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1)\n            )\n        else:\n            scheduler = lambda optimizer: optim.lr_scheduler.CosineAnnealingLR(\n                optimizer, T_max=opt.iters * 1\n            )\n        print(scheduler)\n\n        trainer = Trainer(\n            opt.model_type,\n            opt,\n            model_tea,\n            model_stu,\n            device=device,\n            workspace=opt.workspace,\n            optimizer=optimizer,\n            criterion=criterion,\n            ema_decay=opt.ema_decay,\n            fp16=opt.fp16,\n            
lr_scheduler=scheduler,\n            scheduler_update_every_step=True,\n            metrics=[PSNRMeter()],\n            use_checkpoint=opt.ckpt,\n            eval_interval=500000000,\n        )\n        upsample_resolutions = (\n            (\n                np.round(\n                    np.exp(\n                        np.linspace(\n                            np.log(opt.resolution0),\n                            np.log(opt.resolution1),\n                            len(opt.upsample_model_steps) + 1,\n                        )\n                    )\n                )\n            )\n            .astype(np.int32)\n            .tolist()[1:]\n        )\n        trainer.upsample_resolutions = upsample_resolutions\n        argstxt = sorted(opt.__dict__.items())\n        with open(os.path.join(opt.workspace, \"args.txt\"), \"w\") as f:\n            for t in argstxt:\n                f.write(str(t) + \"\\n\")\n\n        start_time = time()\n\n        max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)\n        trainer.train(train_loader, valid_loader, max_epoch)\n        print(opt.workspace)\n\n        trainer.evaluate(test_loader)\n\n        with open(os.path.join(trainer.workspace, \"args.txt\"), \"a+\") as f:\n            txt = f\"\\npsnr: {trainer.psnr:.2f} \\nssim: {trainer.ssim:.3f} \\nalex: {trainer.lpips_alex:.3f}\\nvgg:{trainer.lpips_vgg:.3f}\"\n            f.write(txt)\n        cmd = f\"mv {trainer.workspace} {trainer.workspace}-psnr{trainer.psnr}\"\n        print(f\"\\n{cmd}\\n\")\n        os.system(cmd)\n"
  },
  {
    "path": "raymarching/__init__.py",
    "content": "from .raymarching import *\n"
  },
  {
    "path": "raymarching/backend.py",
    "content": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n]\n\nif os.name == \"posix\":\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\n_backend = load(\n    name=\"_raymarching\",\n    extra_cflags=c_flags,\n    extra_cuda_cflags=nvcc_flags,\n    sources=[\n        os.path.join(_src_path, \"src\", f)\n        for f in [\n            \"raymarching.cu\",\n            \"bindings.cpp\",\n        ]\n    ],\n)\n\n__all__ = [\"_backend\"]\n"
  },
  {
    "path": "raymarching/raymarching.py",
    "content": "import numpy as np\nimport time\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\ntry:\n    import _raymarching as _backend\nexcept ImportError:\n    from .backend import _backend\n\n\n# ----------------------------------------\n# utils\n# ----------------------------------------\n\n\nclass _near_far_from_aabb(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):\n        \"\"\"near_far_from_aabb, CUDA implementation\n        Calculate rays' intersection time (near and far) with aabb\n        Args:\n            rays_o: float, [N, 3]\n            rays_d: float, [N, 3]\n            aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)\n            min_near: float, scalar\n        Returns:\n            nears: float, [N]\n            fars: float, [N]\n        \"\"\"\n        if not rays_o.is_cuda:\n            rays_o = rays_o.cuda()\n        if not rays_d.is_cuda:\n            rays_d = rays_d.cuda()\n\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        N = rays_o.shape[0]  # num rays\n\n        nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)\n        fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device)\n\n        _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars)\n\n        return nears, fars\n\n\nnear_far_from_aabb = _near_far_from_aabb.apply\n\n\nclass _polar_from_ray(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, rays_o, rays_d, radius):\n        \"\"\"polar_from_ray, CUDA implementation\n        get polar coordinate on the background sphere from rays.\n        Assume rays_o are inside the Sphere(radius).\n        Args:\n            rays_o: [N, 3]\n            rays_d: [N, 3]\n            radius: scalar, float\n        Return:\n         
   coords: [N, 2], in [-1, 1], theta and phi on a sphere.\n        \"\"\"\n        if not rays_o.is_cuda:\n            rays_o = rays_o.cuda()\n        if not rays_d.is_cuda:\n            rays_d = rays_d.cuda()\n\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        N = rays_o.shape[0]  # num rays\n\n        coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device)\n\n        _backend.polar_from_ray(rays_o, rays_d, radius, N, coords)\n\n        return coords\n\n\npolar_from_ray = _polar_from_ray.apply\n\n\nclass _morton3D(Function):\n    @staticmethod\n    def forward(ctx, coords):\n        \"\"\"morton3D, CUDA implementation\n        Args:\n            coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...)\n            TODO: check if the coord range is valid! (current 128 is safe)\n        Returns:\n            indices: [N], int32, in [0, 128^3)\n\n        \"\"\"\n        if not coords.is_cuda:\n            coords = coords.cuda()\n\n        N = coords.shape[0]\n\n        indices = torch.empty(N, dtype=torch.int32, device=coords.device)\n\n        _backend.morton3D(coords.int(), N, indices)\n\n        return indices\n\n\nmorton3D = _morton3D.apply\n\n\nclass _morton3D_invert(Function):\n    @staticmethod\n    def forward(ctx, indices):\n        \"\"\"morton3D_invert, CUDA implementation\n        Args:\n            indices: [N], int32, in [0, 128^3)\n        Returns:\n            coords: [N, 3], int32, in [0, 128)\n\n        \"\"\"\n        if not indices.is_cuda:\n            indices = indices.cuda()\n\n        N = indices.shape[0]\n\n        coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device)\n\n        _backend.morton3D_invert(indices.int(), N, coords)\n\n        return coords\n\n\nmorton3D_invert = _morton3D_invert.apply\n\n\nclass _packbits(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, grid, 
thresh, bitfield=None):\n        \"\"\"packbits, CUDA implementation\n        Pack up the density grid into a bit field to accelerate ray marching.\n        Args:\n            grid: float, [C, H * H * H], assume H % 2 == 0\n            thresh: float, threshold\n        Returns:\n            bitfield: uint8, [C, H * H * H / 8]\n        \"\"\"\n        if not grid.is_cuda:\n            grid = grid.cuda()\n        grid = grid.contiguous()\n\n        C = grid.shape[0]\n        H3 = grid.shape[1]\n        N = C * H3 // 8\n\n        if bitfield is None:\n            bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device)\n\n        _backend.packbits(grid, N, thresh, bitfield)\n\n        return bitfield\n\n\npackbits = _packbits.apply\n\n# ----------------------------------------\n# train functions\n# ----------------------------------------\n\n\nclass _march_rays_train(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(\n        ctx,\n        rays_o,\n        rays_d,\n        bound,\n        density_bitfield,\n        C,\n        H,\n        nears,\n        fars,\n        step_counter=None,\n        mean_count=-1,\n        perturb=False,\n        align=-1,\n        force_all_rays=False,\n        dt_gamma=0,\n        max_steps=1024,\n    ):\n        \"\"\"march rays to generate points (forward only)\n        Args:\n            rays_o/d: float, [N, 3]\n            bound: float, scalar\n            density_bitfield: uint8: [CHHH // 8]\n            C: int\n            H: int\n            nears/fars: float, [N]\n            step_counter: int32, (2), used to count the actual number of generated points.\n            mean_count: int32, estimated mean steps to accelerate training. 
(but will randomly drop rays if the actual point count exceeds this threshold.)\n            perturb: bool\n            align: int, pad output so its size is divisible by align, set to -1 to disable.\n            force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.\n            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse quality)\n            max_steps: int, max number of sampled points along each ray, also affects min_stepsize.\n        Returns:\n            xyzs: float, [M, 3], all generated points' coords. (all rays concatenated, need to use `rays` to extract points belonging to each ray)\n            dirs: float, [M, 3], all generated points' view dirs.\n            deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for Depth)\n            rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1]:rays[i, 2]] --> points belonging to rays[i, 0]\n        \"\"\"\n\n        if not rays_o.is_cuda:\n            rays_o = rays_o.cuda()\n        if not rays_d.is_cuda:\n            rays_d = rays_d.cuda()\n        if not density_bitfield.is_cuda:\n            density_bitfield = density_bitfield.cuda()\n\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n        density_bitfield = density_bitfield.contiguous()\n\n        N = rays_o.shape[0]  # num rays\n        M = N * max_steps  # init max points number in total\n\n        # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)\n        # It estimates the max number of points to enable faster training, but will lead to randomly ignored rays if underestimated.\n        if not force_all_rays and mean_count > 0:\n            if align > 0:\n                mean_count += align - 
mean_count % align\n            M = mean_count\n\n        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)\n        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)\n        deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)\n        rays = torch.empty(\n            N, 3, dtype=torch.int32, device=rays_o.device\n        )  # id, offset, num_steps\n\n        if step_counter is None:\n            step_counter = torch.zeros(\n                2, dtype=torch.int32, device=rays_o.device\n            )  # point counter, ray counter\n\n        _backend.march_rays_train(\n            rays_o,\n            rays_d,\n            density_bitfield,\n            bound,\n            dt_gamma,\n            max_steps,\n            N,\n            C,\n            H,\n            M,\n            nears,\n            fars,\n            xyzs,\n            dirs,\n            deltas,\n            rays,\n            step_counter,\n            perturb,\n        )  # m is the actually used points number\n\n        # print(step_counter, M)\n\n        # only used at the first (few) epochs.\n        if force_all_rays or mean_count <= 0:\n            m = step_counter[0].item()  # D2H copy\n            if align > 0:\n                m += align - m % align\n            xyzs = xyzs[:m]\n            dirs = dirs[:m]\n            deltas = deltas[:m]\n\n            torch.cuda.empty_cache()\n\n        return xyzs, dirs, deltas, rays\n\n\nmarch_rays_train = _march_rays_train.apply\n\n\nclass _composite_rays_train(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(ctx, sigmas, rgbs, deltas, rays):\n        \"\"\"composite rays' rgbs, according to the ray marching formula.\n        Args:\n            rgbs: float, [M, 3]\n            sigmas: float, [M,]\n            deltas: float, [M, 2]\n            rays: int32, [N, 3]\n        Returns:\n            weights_sum: float, [N,], the alpha channel\n            depth: 
float, [N, ], the Depth\n            image: float, [N, 3], the RGB channel (after multiplying alpha!)\n        \"\"\"\n\n        sigmas = sigmas.contiguous()\n        rgbs = rgbs.contiguous()\n\n        M = sigmas.shape[0]\n        N = rays.shape[0]\n\n        weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)\n        depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)\n        image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)\n\n        _backend.composite_rays_train_forward(\n            sigmas, rgbs, deltas, rays, M, N, weights_sum, depth, image\n        )\n\n        ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image)\n        ctx.dims = [M, N]\n\n        return weights_sum, depth, image\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, grad_weights_sum, grad_depth, grad_image):\n\n        # NOTE: grad_depth is not used now! It won't be propagated to sigmas.\n\n        grad_weights_sum = grad_weights_sum.contiguous()\n        grad_image = grad_image.contiguous()\n\n        sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors\n        M, N = ctx.dims\n\n        grad_sigmas = torch.zeros_like(sigmas)\n        grad_rgbs = torch.zeros_like(rgbs)\n\n        _backend.composite_rays_train_backward(\n            grad_weights_sum,\n            grad_image,\n            sigmas,\n            rgbs,\n            deltas,\n            rays,\n            weights_sum,\n            image,\n            M,\n            N,\n            grad_sigmas,\n            grad_rgbs,\n        )\n\n        return grad_sigmas, grad_rgbs, None, None\n\n\ncomposite_rays_train = _composite_rays_train.apply\n\n# ----------------------------------------\n# infer functions\n# ----------------------------------------\n\n\nclass _march_rays(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(\n        ctx,\n        n_alive,\n        n_step,\n        
rays_alive,\n        rays_t,\n        rays_o,\n        rays_d,\n        bound,\n        density_bitfield,\n        C,\n        H,\n        near,\n        far,\n        align=-1,\n        perturb=False,\n        dt_gamma=0,\n        max_steps=1024,\n    ):\n        \"\"\"march rays to generate points (forward only, for inference)\n        Args:\n            n_alive: int, number of alive rays\n            n_step: int, how many steps we march\n            rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)\n            rays_t: float, [N], the alive rays' time, we only use the first n_alive.\n            rays_o/d: float, [N, 3]\n            bound: float, scalar\n            density_bitfield: uint8: [CHHH // 8]\n            C: int\n            H: int\n            nears/fars: float, [N]\n            align: int, pad output so its size is dividable by align, set to -1 to disable.\n            perturb: bool/int, int > 0 is used as the random seed.\n            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. 
(very significant effect, but generally lead to worse performance)\n            max_steps: int, max number of sampled points along each ray, also affect min_stepsize.\n        Returns:\n            xyzs: float, [n_alive * n_step, 3], all generated points' coords\n            dirs: float, [n_alive * n_step, 3], all generated points' view dirs.\n            deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).\n        \"\"\"\n\n        if not rays_o.is_cuda:\n            rays_o = rays_o.cuda()\n        if not rays_d.is_cuda:\n            rays_d = rays_d.cuda()\n\n        rays_o = rays_o.contiguous().view(-1, 3)\n        rays_d = rays_d.contiguous().view(-1, 3)\n\n        M = n_alive * n_step\n\n        if align > 0:\n            M += align - (M % align)\n\n        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)\n        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)\n        deltas = torch.zeros(\n            M, 2, dtype=rays_o.dtype, device=rays_o.device\n        )  # 2 vals, one for rgb, one for depth\n\n        _backend.march_rays(\n            n_alive,\n            n_step,\n            rays_alive,\n            rays_t,\n            rays_o,\n            rays_d,\n            bound,\n            dt_gamma,\n            max_steps,\n            C,\n            H,\n            density_bitfield,\n            near,\n            far,\n            xyzs,\n            dirs,\n            deltas,\n            perturb,\n        )\n\n        return xyzs, dirs, deltas\n\n\nmarch_rays = _march_rays.apply\n\n\nclass _composite_rays(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)  # need to cast sigmas & rgbs to float\n    def forward(\n        ctx,\n        n_alive,\n        n_step,\n        rays_alive,\n        rays_t,\n        sigmas,\n        rgbs,\n        deltas,\n        weights_sum,\n        depth,\n        image,\n    ):\n     
   \"\"\"composite rays' rgbs, according to the ray marching formula. (for inference)\n        Args:\n            n_alive: int, number of alive rays\n            n_step: int, how many steps we march\n            rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)\n            rays_t: float, [N], the alive rays' time, we only use the first n_alive.\n            sigmas: float, [n_alive * n_step,]\n            rgbs: float, [n_alive * n_step, 3]\n            deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).\n        In-place Outputs:\n            weights_sum: float, [N,], the alpha channel\n            depth: float, [N,], the depth value\n            image: float, [N, 3], the RGB channel (after multiplying alpha!)\n        \"\"\"\n        _backend.composite_rays(\n            n_alive,\n            n_step,\n            rays_alive,\n            rays_t,\n            sigmas,\n            rgbs,\n            deltas,\n            weights_sum,\n            depth,\n            image,\n        )\n        return tuple()\n\n\ncomposite_rays = _composite_rays.apply\n\n\nclass _compact_rays(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)\n    def forward(\n        ctx, n_alive, rays_alive, rays_alive_old, rays_t, rays_t_old, alive_counter\n    ):\n        \"\"\"compact rays, remove dead rays and reallocate alive rays, to accelerate next ray marching.\n        Args:\n            n_alive: int, number of alive rays\n            rays_alive_old: int, [N]\n            rays_t_old: float, [N], dead rays are marked by rays_t < 0\n            alive_counter: int, [1], used to count remained alive rays.\n        In-place Outputs:\n            rays_alive: int, [N]\n            rays_t: float, [N]\n        \"\"\"\n        _backend.compact_rays(\n            n_alive, rays_alive, rays_alive_old, rays_t, rays_t_old, alive_counter\n        )\n     
   return tuple()\n\n\ncompact_rays = _compact_rays.apply\n"
  },
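Between files, a note on the math implemented above: `composite_rays_train` applies the standard volume-rendering rule, `alpha = 1 - exp(-sigma * dt)` and `weight = alpha * T`, with transmittance `T` accumulated front to back, and depth accumulated from the second delta channel. The reference sketch below mirrors that rule in plain Python; it is an illustration only (the helper name `composite_ray` is ours, and the depth accumulation follows the CUDA kernel's convention of summing `deltas[:, 1]` along the ray), not the CUDA kernel itself.

```python
import math

def composite_ray(sigmas, rgbs, deltas):
    """Sketch of the forward compositing rule: alpha = 1 - exp(-sigma * dt),
    weight = alpha * T, with T the accumulated transmittance.
    deltas[i] = (dt_rgb, dt_depth) mirrors the [M, 2] layout used above."""
    T = 1.0                      # transmittance, starts fully transparent
    r = g = b = ws = d = t = 0.0
    for sigma, rgb, (dt, dd) in zip(sigmas, rgbs, deltas):
        alpha = 1.0 - math.exp(-sigma * dt)
        weight = alpha * T
        r += weight * rgb[0]
        g += weight * rgb[1]
        b += weight * rgb[2]
        t += dd                  # accumulated marched distance, for depth
        d += weight * t
        ws += weight             # weights_sum, i.e. the alpha channel
        T *= 1.0 - alpha
    return ws, d, (r, g, b)
```

A fully opaque single sample (very large sigma) yields `weights_sum ≈ 1` and returns the sample's own color, while `sigma = 0` contributes nothing, which is the sanity check the kernel's behavior should satisfy.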
  {
    "path": "raymarching/setup.py",
    "content": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n]\n\nif os.name == \"posix\":\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\n\"\"\"\nUsage:\n\npython setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)\n\npython setup.py install # build extensions and install (copy) to PATH.\npip install . # ditto but better (e.g., dependency & metadata handling)\n\npython setup.py develop # build extensions and install (symbolic) to PATH.\npip install -e . 
# ditto but better (e.g., dependency & metadata handling)\n\n\"\"\"\nsetup(\n    name=\"raymarching\",  # package name, import this to use python API\n    ext_modules=[\n        CUDAExtension(\n            name=\"_raymarching\",  # extension name, import this to use CUDA API\n            sources=[\n                os.path.join(_src_path, \"src\", f)\n                for f in [\n                    \"raymarching.cu\",\n                    \"bindings.cpp\",\n                ]\n            ],\n            extra_compile_args={\n                \"cxx\": c_flags,\n                \"nvcc\": nvcc_flags,\n            },\n        ),\n    ],\n    cmdclass={\n        \"build_ext\": BuildExtension,\n    },\n)\n"
  },
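A small arithmetic detail shared across this package: both `march_rays_train` and `march_rays` pad their output buffers so the size is divisible by `align` (`m += align - m % align`), and the CUDA source uses a `div_round_up` helper for grid sizing. The minimal sketch below illustrates that arithmetic; the helper names `pad_to_align`/`div_round_up` are ours (the latter mirrors the CUDA template of the same name).

```python
def pad_to_align(m, align):
    """Pad m up so it is divisible by align (mirrors `m += align - m % align`).
    Note the original form adds a full `align` when m is already aligned."""
    if align <= 0:
        return m
    return m + (align - m % align)

def div_round_up(val, divisor):
    """Ceiling division, as in the CUDA helper of the same name."""
    return (val + divisor - 1) // divisor
```

Padding to a multiple of the warp/block granularity keeps the downstream kernels launching over evenly sized buffers; `align = -1` disables it, as the docstrings note.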
  {
    "path": "raymarching/src/bindings.cpp",
    "content": "#include <torch/extension.h>\n\n#include \"raymarching.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    // utils\n    m.def(\"packbits\", &packbits, \"packbits (CUDA)\");\n    m.def(\"near_far_from_aabb\", &near_far_from_aabb, \"near_far_from_aabb (CUDA)\");\n    m.def(\"polar_from_ray\", &polar_from_ray, \"polar_from_ray (CUDA)\");\n    m.def(\"morton3D\", &morton3D, \"morton3D (CUDA)\");\n    m.def(\"morton3D_invert\", &morton3D_invert, \"morton3D_invert (CUDA)\");\n    // train\n    m.def(\"march_rays_train\", &march_rays_train, \"march_rays_train (CUDA)\");\n    m.def(\"composite_rays_train_forward\", &composite_rays_train_forward, \"composite_rays_train_forward (CUDA)\");\n    m.def(\"composite_rays_train_backward\", &composite_rays_train_backward, \"composite_rays_train_backward (CUDA)\");\n    // infer\n    m.def(\"march_rays\", &march_rays, \"march rays (CUDA)\");\n    m.def(\"composite_rays\", &composite_rays, \"composite rays (CUDA)\");\n    m.def(\"compact_rays\", &compact_rays, \"compact rays (CUDA)\");\n}"
  },
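The `morton3D`/`morton3D_invert` bindings above expose the Z-order curve used to index the occupancy grid: each 10-bit coordinate is spread so its bits occupy every third position, then the three are interleaved. A pure-Python sketch of the bit tricks in `raymarching.cu` (`__expand_bits`, `__morton3D`, `__morton3D_invert`), for reading along rather than for use:

```python
def expand_bits(v):
    """Spread the low 10 bits of v so they occupy every third bit
    (mirrors __expand_bits in raymarching.cu)."""
    v = (v * 0x00010001) & 0xFF0000FF
    v = (v * 0x00000101) & 0x0F00F00F
    v = (v * 0x00000011) & 0xC30C30C3
    v = (v * 0x00000005) & 0x49249249
    return v

def morton3d(x, y, z):
    """Interleave three 10-bit coordinates into a 30-bit Morton index."""
    return expand_bits(x) | (expand_bits(y) << 1) | (expand_bits(z) << 2)

def morton3d_invert(x):
    """Recover one coordinate from a Morton index (mirrors __morton3D_invert).
    Shift the index by 0/1/2 first to select the x/y/z component."""
    x &= 0x49249249
    x = (x | (x >> 2)) & 0xC30C30C3
    x = (x | (x >> 4)) & 0x0F00F00F
    x = (x | (x >> 8)) & 0xFF0000FF
    x = (x | (x >> 16)) & 0x0000FFFF
    return x
```

This matches how `kernel_morton3D_invert` decodes an index: `coords[k] = invert(ind >> k)` for k in 0..2.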
  {
    "path": "raymarching/src/pcg32.h",
    "content": "/*\n * Tiny self-contained version of the PCG Random Number Generation for C++\n * put together from pieces of the much larger C/C++ codebase.\n * Wenzel Jakob, February 2015\n *\n * The PCG random number generator was developed by Melissa O'Neill\n * <oneill@pcg-random.org>\n *\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *     http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n *\n * For additional information about the PCG random number generation scheme,\n * including its license and other licensing options, visit\n *\n *     http://www.pcg-random.org\n *\n * Note: This code was modified to work with CUDA by the tiny-cuda-nn authors.\n */\n\n#pragma once\n\n#define PCG32_DEFAULT_STATE  0x853c49e6748fea9bULL\n#define PCG32_DEFAULT_STREAM 0xda3e39cb94b95bdbULL\n#define PCG32_MULT           0x5851f42d4c957f2dULL\n\n#include <stdint.h>\n#include <cmath>\n#include <cassert>\n\n#include <cuda.h>\n#include <cuda_runtime.h>\n#include <cuda_fp16.h>\n\n/// PCG32 Pseudorandom number generator\nstruct pcg32 {\n\t/// Initialize the pseudorandom number generator with default seed\n\t__host__ __device__ pcg32() : state(PCG32_DEFAULT_STATE), inc(PCG32_DEFAULT_STREAM) {}\n\n\t/// Initialize the pseudorandom number generator with the \\ref seed() function\n\t__host__ __device__ pcg32(uint64_t initstate, uint64_t initseq = 1u) { seed(initstate, initseq); }\n\n\t/**\n\t * \\brief Seed the pseudorandom number generator\n\t *\n\t * Specified in two parts: a state initializer and a sequence selection\n\t * 
constant (a.k.a. stream id)\n\t */\n\t__host__ __device__ void seed(uint64_t initstate, uint64_t initseq = 1) {\n\t\tstate = 0U;\n\t\tinc = (initseq << 1u) | 1u;\n\t\tnext_uint();\n\t\tstate += initstate;\n\t\tnext_uint();\n\t}\n\n\t/// Generate a uniformly distributed unsigned 32-bit random number\n\t__host__ __device__ uint32_t next_uint() {\n\t\tuint64_t oldstate = state;\n\t\tstate = oldstate * PCG32_MULT + inc;\n\t\tuint32_t xorshifted = (uint32_t) (((oldstate >> 18u) ^ oldstate) >> 27u);\n\t\tuint32_t rot = (uint32_t) (oldstate >> 59u);\n\t\treturn (xorshifted >> rot) | (xorshifted << ((~rot + 1u) & 31));\n\t}\n\n\t/// Generate a uniformly distributed number, r, where 0 <= r < bound\n\t__host__ __device__ uint32_t next_uint(uint32_t bound) {\n\t\t// To avoid bias, we need to make the range of the RNG a multiple of\n\t\t// bound, which we do by dropping output less than a threshold.\n\t\t// A naive scheme to calculate the threshold would be to do\n\t\t//\n\t\t//     uint32_t threshold = 0x100000000ull % bound;\n\t\t//\n\t\t// but 64-bit div/mod is slower than 32-bit div/mod (especially on\n\t\t// 32-bit platforms).  In essence, we do\n\t\t//\n\t\t//     uint32_t threshold = (0x100000000ull-bound) % bound;\n\t\t//\n\t\t// because this version will calculate the same modulus, but the LHS\n\t\t// value is less than 2^32.\n\n\t\tuint32_t threshold = (~bound+1u) % bound;\n\n\t\t// Uniformity guarantees that this loop will terminate.  In practice, it\n\t\t// should usually terminate quickly; on average (assuming all bounds are\n\t\t// equally likely), 82.25% of the time, we can expect it to require just\n\t\t// one iteration.  In the worst case, someone passes a bound of 2^31 + 1\n\t\t// (i.e., 2147483649), which invalidates almost 50% of the range.  
In\n\t\t// practice, bounds are typically small and only a tiny amount of the range\n\t\t// is eliminated.\n\t\tfor (;;) {\n\t\t\tuint32_t r = next_uint();\n\t\t\tif (r >= threshold)\n\t\t\t\treturn r % bound;\n\t\t}\n\t}\n\n\t/// Generate a single precision floating point value on the interval [0, 1)\n\t__host__ __device__ float next_float() {\n\t\t/* Trick from MTGP: generate an uniformly distributed\n\t\t\tsingle precision number in [1,2) and subtract 1. */\n\t\tunion {\n\t\t\tuint32_t u;\n\t\t\tfloat f;\n\t\t} x;\n\t\tx.u = (next_uint() >> 9) | 0x3f800000u;\n\t\treturn x.f - 1.0f;\n\t}\n\n\t/**\n\t * \\brief Generate a double precision floating point value on the interval [0, 1)\n\t *\n\t * \\remark Since the underlying random number generator produces 32 bit output,\n\t * only the first 32 mantissa bits will be filled (however, the resolution is still\n\t * finer than in \\ref next_float(), which only uses 23 mantissa bits)\n\t */\n\t__host__ __device__ double next_double() {\n\t\t/* Trick from MTGP: generate an uniformly distributed\n\t\t\tdouble precision number in [1,2) and subtract 1. */\n\t\tunion {\n\t\t\tuint64_t u;\n\t\t\tdouble d;\n\t\t} x;\n\t\tx.u = ((uint64_t) next_uint() << 20) | 0x3ff0000000000000ULL;\n\t\treturn x.d - 1.0;\n\t}\n\n\t/**\n\t * \\brief Multi-step advance function (jump-ahead, jump-back)\n\t *\n\t * The method used here is based on Brown, \"Random Number Generation\n\t * with Arbitrary Stride\", Transactions of the American Nuclear\n\t * Society (Nov. 1994). 
The algorithm is very similar to fast\n\t * exponentiation.\n\t *\n\t * The default value of 2^32 ensures that the PRNG is advanced\n\t * sufficiently far that there is (likely) no overlap with\n\t * previously drawn random numbers, even if small advancements.\n\t * are made inbetween.\n\t */\n\t__host__ __device__ void advance(int64_t delta_ = (1ll<<32)) {\n\t\tuint64_t\n\t\t\tcur_mult = PCG32_MULT,\n\t\t\tcur_plus = inc,\n\t\t\tacc_mult = 1u,\n\t\t\tacc_plus = 0u;\n\n\t\t/* Even though delta is an unsigned integer, we can pass a signed\n\t\t\tinteger to go backwards, it just goes \"the long way round\". */\n\t\tuint64_t delta = (uint64_t) delta_;\n\n\t\twhile (delta > 0) {\n\t\t\tif (delta & 1) {\n\t\t\t\tacc_mult *= cur_mult;\n\t\t\t\tacc_plus = acc_plus * cur_mult + cur_plus;\n\t\t\t}\n\t\t\tcur_plus = (cur_mult + 1) * cur_plus;\n\t\t\tcur_mult *= cur_mult;\n\t\t\tdelta /= 2;\n\t\t}\n\t\tstate = acc_mult * state + acc_plus;\n\t}\n\n\t/// Compute the distance between two PCG32 pseudorandom number generators\n\t__host__ __device__ int64_t operator-(const pcg32 &other) const {\n\t\tassert(inc == other.inc);\n\n\t\tuint64_t\n\t\t\tcur_mult = PCG32_MULT,\n\t\t\tcur_plus = inc,\n\t\t\tcur_state = other.state,\n\t\t\tthe_bit = 1u,\n\t\t\tdistance = 0u;\n\n\t\twhile (state != cur_state) {\n\t\t\tif ((state & the_bit) != (cur_state & the_bit)) {\n\t\t\t\tcur_state = cur_state * cur_mult + cur_plus;\n\t\t\t\tdistance |= the_bit;\n\t\t\t}\n\t\t\tassert((state & the_bit) == (cur_state & the_bit));\n\t\t\tthe_bit <<= 1;\n\t\t\tcur_plus = (cur_mult + 1ULL) * cur_plus;\n\t\t\tcur_mult *= cur_mult;\n\t\t}\n\n\t\treturn (int64_t) distance;\n\t}\n\n\t/// Equality operator\n\t__host__ __device__ bool operator==(const pcg32 &other) const { return state == other.state && inc == other.inc; }\n\n\t/// Inequality operator\n\t__host__ __device__ bool operator!=(const pcg32 &other) const { return state != other.state || inc != other.inc; }\n\n\tuint64_t state;  // RNG state.  
All values are possible.\n\tuint64_t inc;    // Controls which RNG sequence (stream) is selected. Must *always* be odd.\n};"
  },
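For readers unfamiliar with PCG32, here is a pure-Python sketch of the generator defined in `pcg32.h` above: the state advances by a 64-bit LCG, and the output is an xorshift of the old state rotated by its top bits. The class name `PCG32` is ours; 64-bit wraparound is emulated with explicit masking, and the constructor mimics `seed()` rather than the default constructor.

```python
import struct

class PCG32:
    """Pure-Python sketch of the pcg32 struct (state/inc, seed, next_uint,
    next_float). Not a replacement for the CUDA-side RNG."""
    MULT = 0x5851F42D4C957F2D
    M64 = (1 << 64) - 1

    def __init__(self, initstate, initseq=1):
        # mirrors pcg32::seed(): mix the stream id and the state initializer
        self.state = 0
        self.inc = ((initseq << 1) | 1) & self.M64  # inc must always be odd
        self.next_uint()
        self.state = (self.state + initstate) & self.M64
        self.next_uint()

    def next_uint(self):
        """Uniform 32-bit output: xorshift of the old state, rotated."""
        oldstate = self.state
        self.state = (oldstate * self.MULT + self.inc) & self.M64
        xorshifted = (((oldstate >> 18) ^ oldstate) >> 27) & 0xFFFFFFFF
        rot = oldstate >> 59
        return ((xorshifted >> rot) | (xorshifted << ((-rot) & 31))) & 0xFFFFFFFF

    def next_float(self):
        """MTGP trick: build a float in [1, 2) from 23 random bits, subtract 1."""
        u = (self.next_uint() >> 9) | 0x3F800000
        return struct.unpack("f", struct.pack("I", u))[0] - 1.0
```

The same seeding discipline explains the `rng.advance(n)` call in `kernel_march_rays_train`: each ray jumps the shared stream ahead by its own index so perturbed start offsets stay decorrelated across threads.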
  {
    "path": "raymarching/src/raymarching.cu",
    "content": "#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/torch.h>\n\n#include <cstdio>\n#include <stdint.h>\n#include <stdexcept>\n#include <limits>\n\n#include \"pcg32.h\"\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be a contiguous tensor\")\n#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x \" must be an int tensor\")\n#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x \" must be a floating tensor\")\n\n\ninline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }\ninline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }\ninline constexpr __device__ float PI() { return 3.141592653589793f; }\ninline constexpr __device__ float RPI() { return 0.3183098861837907f; }\n\n\ntemplate <typename T>\ninline __host__ __device__ T div_round_up(T val, T divisor) {\n    return (val + divisor - 1) / divisor;\n}\n\ninline __host__ __device__ float signf(const float x) {\n    return copysignf(1.0, x);\n}\n\ninline __host__ __device__ float clamp(const float x, const float min, const float max) {\n    return fminf(max, fmaxf(min, x));\n}\n\ninline __host__ __device__ void swapf(float& a, float& b) {\n    float c = a; a = b; b = c;\n}\n\ninline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {\n    const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z)));\n    int exponent;\n    frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...\n    return fminf(max_cascade - 1, fmaxf(0, exponent));\n}\n\ninline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {\n    const float mx = dt * H * 0.5;\n 
   int exponent;\n    frexpf(mx, &exponent);\n    return fminf(max_cascade - 1, fmaxf(0, exponent));\n}\n\ninline __host__ __device__ uint32_t __expand_bits(uint32_t v)\n{\n\tv = (v * 0x00010001u) & 0xFF0000FFu;\n\tv = (v * 0x00000101u) & 0x0F00F00Fu;\n\tv = (v * 0x00000011u) & 0xC30C30C3u;\n\tv = (v * 0x00000005u) & 0x49249249u;\n\treturn v;\n}\n\ninline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)\n{\n\tuint32_t xx = __expand_bits(x);\n\tuint32_t yy = __expand_bits(y);\n\tuint32_t zz = __expand_bits(z);\n\treturn xx | (yy << 1) | (zz << 2);\n}\n\ninline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)\n{\n\tx = x & 0x49249249;\n\tx = (x | (x >> 2)) & 0xc30c30c3;\n\tx = (x | (x >> 4)) & 0x0f00f00f;\n\tx = (x | (x >> 8)) & 0xff0000ff;\n\tx = (x | (x >> 16)) & 0x0000ffff;\n\treturn x;\n}\n\n\n////////////////////////////////////////////////////\n/////////////           utils          /////////////\n////////////////////////////////////////////////////\n\n// rays_o/d: [N, 3]\n// nears/fars: [N]\n// scalar_t should always be float in use.\ntemplate <typename scalar_t>\n__global__ void kernel_near_far_from_aabb(\n    const scalar_t * __restrict__ rays_o,\n    const scalar_t * __restrict__ rays_d,\n    const scalar_t * __restrict__ aabb,\n    const uint32_t N,\n    const float min_near,\n    scalar_t * nears, scalar_t * fars\n) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    rays_o += n * 3;\n    rays_d += n * 3;\n\n    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];\n    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];\n    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;\n\n    // get near far (assume cube scene)\n    float near = (aabb[0] - ox) * rdx;\n    float far = (aabb[3] - ox) * rdx;\n    if (near > far) swapf(near, far);\n\n    float near_y = (aabb[1] - oy) * rdy;\n    float far_y = (aabb[4] - oy) * rdy;\n 
   if (near_y > far_y) swapf(near_y, far_y);\n\n    if (near > far_y || near_y > far) {\n        nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();\n        return;\n    }\n\n    if (near_y > near) near = near_y;\n    if (far_y < far) far = far_y;\n\n    float near_z = (aabb[2] - oz) * rdz;\n    float far_z = (aabb[5] - oz) * rdz;\n    if (near_z > far_z) swapf(near_z, far_z);\n\n    if (near > far_z || near_z > far) {\n        nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();\n        return;\n    }\n\n    if (near_z > near) near = near_z;\n    if (far_z < far) far = far_z;\n\n    if (near < min_near) near = min_near;\n\n    nears[n] = near;\n    fars[n] = far;\n}\n\n\nvoid near_far_from_aabb(at::Tensor rays_o, at::Tensor rays_d, at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {\n\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    rays_o.scalar_type(), \"near_far_from_aabb\", ([&] {\n        kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>());\n    }));\n}\n\n\n// rays_o/d: [N, 3]\n// radius: float\n// coords: [N, 2]\ntemplate <typename scalar_t>\n__global__ void kernel_polar_from_ray(\n    const scalar_t * __restrict__ rays_o,\n    const scalar_t * __restrict__ rays_d,\n    const float radius,\n    const uint32_t N,\n    scalar_t * coords\n) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    rays_o += n * 3;\n    rays_d += n * 3;\n    coords += n * 2;\n\n    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];\n    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];\n    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;\n\n    // solve t from || o + td || = radius\n    const float A = dx * 
dx + dy * dy + dz * dz;\n    const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2\n    const float C = ox * ox + oy * oy + oz * oz - radius * radius;\n\n    const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)\n\n    // solve theta, phi (assume y is the up axis)\n    const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;\n    const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)\n    const float phi = atan2(z, x); // [-PI, PI)\n\n    // normalize to [-1, 1]\n    coords[0] = 2 * theta * RPI() - 1;\n    coords[1] = phi * RPI();\n}\n\n\nvoid polar_from_ray(at::Tensor rays_o, at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {\n\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    rays_o.scalar_type(), \"polar_from_ray\", ([&] {\n        kernel_polar_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>());\n    }));\n}\n\n\n// coords: int32, [N, 3]\n// indices: int32, [N]\n__global__ void kernel_morton3D(\n    const int * __restrict__ coords,\n    const uint32_t N,\n    int * indices\n) {\n    // parallel\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    coords += n * 3;\n    indices[n] = __morton3D(coords[0], coords[1], coords[2]);\n}\n\n\nvoid morton3D(at::Tensor coords, const uint32_t N, at::Tensor indices) {\n    static constexpr uint32_t N_THREAD = 128;\n    kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());\n}\n\n\n// indices: int32, [N]\n// coords: int32, [N, 3]\n__global__ void kernel_morton3D_invert(\n    const int * __restrict__ indices,\n    const uint32_t N,\n    int * coords\n) {\n    // parallel\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    coords += 
n * 3;\n\n    const int ind = indices[n];\n\n    coords[0] = __morton3D_invert(ind >> 0);\n    coords[1] = __morton3D_invert(ind >> 1);\n    coords[2] = __morton3D_invert(ind >> 2);\n}\n\n\nvoid morton3D_invert(at::Tensor indices, const uint32_t N, at::Tensor coords) {\n    static constexpr uint32_t N_THREAD = 128;\n    kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());\n}\n\n\n// grid: float, [C, H, H, H]\n// N: int, C * H * H * H / 8\n// density_thresh: float\n// bitfield: uint8, [N]\ntemplate <typename scalar_t>\n__global__ void kernel_packbits(\n    const scalar_t * __restrict__ grid,\n    const uint32_t N,\n    const float density_thresh,\n    uint8_t * bitfield\n) {\n    // parallel per byte\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    grid += n * 8;\n\n    uint8_t bits = 0;\n\n    #pragma unroll\n    for (uint8_t i = 0; i < 8; i++) {\n        bits |= (grid[i] > density_thresh) ? 
((uint8_t)1 << i) : 0;\n    }\n\n    bitfield[n] = bits;\n}\n\n\nvoid packbits(at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {\n\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    grid.scalar_type(), \"packbits\", ([&] {\n        kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>());\n    }));\n}\n\n////////////////////////////////////////////////////\n/////////////         training         /////////////\n////////////////////////////////////////////////////\n\n// rays_o/d: [N, 3]\n// grid: [CHHH / 8]\n// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2]\n// dirs: [M, 3]\n// rays: [N, 3], idx, offset, num_steps\ntemplate <typename scalar_t>\n__global__ void kernel_march_rays_train(\n    const scalar_t * __restrict__ rays_o,\n    const scalar_t * __restrict__ rays_d,  \n    const uint8_t * __restrict__ grid,\n    const float bound,\n    const float dt_gamma, const uint32_t max_steps,\n    const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M,\n    const scalar_t* __restrict__ nears, \n    const scalar_t* __restrict__ fars,\n    scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas,\n    int * rays,\n    int * counter,\n    const uint32_t perturb,\n    pcg32 rng\n) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate\n    rays_o += n * 3;\n    rays_d += n * 3;\n\n    // ray marching\n    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];\n    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];\n    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;\n    const float rH = 1 / (float)H;\n\n    const float near = nears[n];\n    const float far = fars[n];\n\n    const float dt_min = 2 * SQRT3() / max_steps;\n    const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;\n    \n    float t0 = near;\n    \n    if 
(perturb) {\n        rng.advance(n);\n        t0 += dt_min * rng.next_float();\n    }\n    \n    // first pass: estimation of num_steps\n    float t = t0;\n    uint32_t num_steps = 0;\n\n    //if (t < far) printf(\"valid ray %d t=%f near=%f far=%f \\n\", n, t, near, far);\n    \n    while (t < far && num_steps < max_steps) {\n        // current point\n        const float x = clamp(ox + t * dx, -bound, bound);\n        const float y = clamp(oy + t * dy, -bound, bound);\n        const float z = clamp(oz + t * dz, -bound, bound);\n\n        const float dt = clamp(t * dt_gamma, dt_min, dt_max);\n\n        // get mip level\n        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]\n\n        const float mip_bound = fminf((float)(1 << level), bound);\n        const float mip_rbound = 1 / mip_bound;\n        \n        // convert to nearest grid position\n        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n\n        const uint32_t index = level * H * H * H + __morton3D(nx, ny, nz);\n        const bool occ = grid[index / 8] & (1 << (index % 8));\n\n        // if occpuied, advance a small step, and write to output\n        //if (n == 0) printf(\"t=%f density=%f vs thresh=%f step=%d\\n\", t, density, density_thresh, num_steps);\n\n        if (occ) {\n            num_steps++;\n            t += dt;\n        // else, skip a large step (basically skip a voxel grid)\n        } else {\n            // calc distance to next voxel\n            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;\n            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;\n            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;\n\n           
 const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));\n            // step until next voxel\n            do { \n                t += clamp(t * dt_gamma, dt_min, dt_max);\n            } while (t < tt);\n        }\n    }\n\n    //printf(\"[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\\n\", n, num_steps, near, far, dt_min, (far - near) / dt_min);\n\n    // second pass: really locate and write points & dirs\n    uint32_t point_index = atomicAdd(counter, num_steps);\n    uint32_t ray_index = atomicAdd(counter + 1, 1);\n    \n    //printf(\"[n=%d] num_steps=%d, point_index=%d, ray_index=%d\\n\", n, num_steps, point_index, ray_index);\n\n    // write rays\n    rays[ray_index * 3] = n;\n    rays[ray_index * 3 + 1] = point_index;\n    rays[ray_index * 3 + 2] = num_steps;\n\n    if (num_steps == 0) return;\n    if (point_index + num_steps >= M) return;\n\n    xyzs += point_index * 3;\n    dirs += point_index * 3;\n    deltas += point_index * 2;\n\n    t = t0;\n    uint32_t step = 0;\n\n    float last_t = t;\n\n    while (t < far && step < num_steps) {\n        // current point\n        const float x = clamp(ox + t * dx, -bound, bound);\n        const float y = clamp(oy + t * dy, -bound, bound);\n        const float z = clamp(oz + t * dz, -bound, bound);\n\n        const float dt = clamp(t * dt_gamma, dt_min, dt_max);\n\n        // get mip level\n        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]\n\n        const float mip_bound = fminf((float)(1 << level), bound);\n        const float mip_rbound = 1 / mip_bound;\n        \n        // convert to nearest grid position\n        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n\n        // query grid\n        const uint32_t index = level * H * H * H + 
__morton3D(nx, ny, nz);\n        const bool occ = grid[index / 8] & (1 << (index % 8));\n\n        // if occupied, advance a small step, and write to output\n        if (occ) {\n            // write step\n            xyzs[0] = x;\n            xyzs[1] = y;\n            xyzs[2] = z;\n            dirs[0] = dx;\n            dirs[1] = dy;\n            dirs[2] = dz;\n            t += dt;\n            deltas[0] = dt;\n            deltas[1] = t - last_t; // used to calc depth\n            last_t = t;\n            xyzs += 3;\n            dirs += 3;\n            deltas += 2;\n            step++;\n        // else, skip a large step (basically skip a voxel grid)\n        } else {\n            // calc distance to next voxel\n            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;\n            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;\n            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;\n            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));\n            // step until next voxel\n            do { \n                t += clamp(t * dt_gamma, dt_min, dt_max); \n            } while (t < tt);\n        }\n    }\n}\n\nvoid march_rays_train(at::Tensor rays_o, at::Tensor rays_d, at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, const uint32_t perturb) {\n\n    static constexpr uint32_t N_THREAD = 128;\n    pcg32 rng = pcg32{(uint64_t)42}; // hard-coded random seed\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    rays_o.scalar_type(), \"march_rays_train\", ([&] {\n        kernel_march_rays_train<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), 
grid.data_ptr<uint8_t>(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), counter.data_ptr<int>(), perturb, rng);\n    }));\n}\n\n\n// sigmas: [M]\n// rgbs: [M, 3]\n// deltas: [M, 2]\n// rays: [N, 3], idx, offset, num_steps\n// weights_sum: [N], final pixel alpha\n// depth: [N,]\n// image: [N, 3]\ntemplate <typename scalar_t>\n__global__ void kernel_composite_rays_train_forward(\n    const scalar_t * __restrict__ sigmas,\n    const scalar_t * __restrict__ rgbs,  \n    const scalar_t * __restrict__ deltas,\n    const int * __restrict__ rays,\n    const uint32_t M, const uint32_t N,\n    scalar_t * weights_sum,\n    scalar_t * depth,\n    scalar_t * image\n) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate \n    uint32_t index = rays[n * 3];\n    uint32_t offset = rays[n * 3 + 1];\n    uint32_t num_steps = rays[n * 3 + 2];\n\n    // empty ray, or a ray that exceeds the max step count.\n    if (num_steps == 0 || offset + num_steps >= M) {\n        weights_sum[index] = 0;\n        depth[index] = 0;\n        image[index * 3] = 0;\n        image[index * 3 + 1] = 0;\n        image[index * 3 + 2] = 0;\n        return;\n    }\n\n    sigmas += offset;\n    rgbs += offset * 3;\n    deltas += offset * 2;\n\n    // accumulate \n    uint32_t step = 0;\n\n    scalar_t T = 1.0f;\n    scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0;\n\n    while (step < num_steps) {\n\n        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);\n        const scalar_t weight = alpha * T;\n\n        // minimal remaining transmittance\n        // NOTE: uncommenting this won't affect instant-ngp, but totally breaks TensoRF...\n        //if (weight < 1e-4f) break;\n\n        r += weight * rgbs[0];\n        g += weight * rgbs[1];\n        b += weight * rgbs[2];\n\n        
t += deltas[1]; // real delta\n        d += weight * t;\n\n        ws += weight;\n\n        T *= 1.0f - alpha;\n\n        //printf(\"[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\\n\", n, step, alpha, weight, T, sum_delta, d);\n\n        // locate\n        sigmas++;\n        rgbs += 3;\n        deltas += 2;\n\n        step++;\n    }\n\n    //printf(\"[n=%d] rgb=(%f, %f, %f), d=%f\\n\", n, r, g, b, d);\n\n    // write\n    weights_sum[index] = ws; // weights_sum\n    depth[index] = d;\n    image[index * 3] = r;\n    image[index * 3 + 1] = g;\n    image[index * 3 + 2] = b;\n}\n\n\nvoid composite_rays_train_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, const uint32_t M, const uint32_t N, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {\n\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    sigmas.scalar_type(), \"composite_rays_train_forward\", ([&] {\n        kernel_composite_rays_train_forward<<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());\n    }));\n}\n\n\n// grad_weights_sum: [N,]\n// grad: [N, 3]\n// sigmas: [M]\n// rgbs: [M, 3]\n// deltas: [M, 2]\n// rays: [N, 3], idx, offset, num_steps\n// weights_sum: [N,], weights_sum here \n// image: [N, 3]\n// grad_sigmas: [M]\n// grad_rgbs: [M, 3]\ntemplate <typename scalar_t>\n__global__ void kernel_composite_rays_train_backward(\n    const scalar_t * __restrict__ grad_weights_sum,\n    const scalar_t * __restrict__ grad_image,\n    const scalar_t * __restrict__ sigmas,\n    const scalar_t * __restrict__ rgbs, \n    const scalar_t * __restrict__ deltas,\n    const int * __restrict__ rays,\n    const scalar_t * __restrict__ weights_sum,\n    const scalar_t * __restrict__ image,\n    const uint32_t M, const uint32_t N,\n    
scalar_t * grad_sigmas,\n    scalar_t * grad_rgbs\n) {\n    // parallel per ray\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= N) return;\n\n    // locate \n    uint32_t index = rays[n * 3];\n    uint32_t offset = rays[n * 3 + 1];\n    uint32_t num_steps = rays[n * 3 + 2];\n\n    if (num_steps == 0 || offset + num_steps >= M) return;\n\n    grad_weights_sum += index;\n    grad_image += index * 3;\n    weights_sum += index;\n    image += index * 3;\n    sigmas += offset;\n    rgbs += offset * 3;\n    deltas += offset * 2;\n    grad_sigmas += offset;\n    grad_rgbs += offset * 3;\n\n    // accumulate \n    uint32_t step = 0;\n    \n    scalar_t T = 1.0f;\n    const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0];\n    scalar_t r = 0, g = 0, b = 0, ws = 0;\n\n    while (step < num_steps) {\n        \n        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);\n        const scalar_t weight = alpha * T;\n\n        //if (weight < 1e-4f) break;\n\n        r += weight * rgbs[0];\n        g += weight * rgbs[1];\n        b += weight * rgbs[2];\n        ws += weight;\n\n        T *= 1.0f - alpha;\n\n        // write grad_rgbs\n        grad_rgbs[0] = grad_image[0] * weight;\n        grad_rgbs[1] = grad_image[1] * weight;\n        grad_rgbs[2] = grad_image[2] * weight;\n\n        // write grad_sigmas\n        grad_sigmas[0] = deltas[0] * (\n            grad_image[0] * (T * rgbs[0] - (r_final - r)) + \n            grad_image[1] * (T * rgbs[1] - (g_final - g)) + \n            grad_image[2] * (T * rgbs[2] - (b_final - b)) +\n            grad_weights_sum[0] * (T - (ws_final - ws))\n        );\n\n        //printf(\"[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\\n\", n, step, T, grad_sigmas[0], r_final, r);\n    \n        // locate\n        sigmas++;\n        rgbs += 3;\n        deltas += 2;\n        grad_sigmas++;\n        grad_rgbs += 3;\n\n        step++;\n    }\n}\n\n\nvoid 
composite_rays_train_backward(at::Tensor grad_weights_sum, at::Tensor grad_image, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, at::Tensor weights_sum, at::Tensor image, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {\n\n    static constexpr uint32_t N_THREAD = 128;\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    grad_image.scalar_type(), \"composite_rays_train_backward\", ([&] {\n        kernel_composite_rays_train_backward<<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>());\n    }));\n}\n\n\n////////////////////////////////////////////////////\n/////////////          inference       /////////////\n////////////////////////////////////////////////////\n\ntemplate <typename scalar_t>\n__global__ void kernel_march_rays(\n    const uint32_t n_alive, \n    const uint32_t n_step, \n    const int* __restrict__ rays_alive, \n    const scalar_t* __restrict__ rays_t, \n    const scalar_t* __restrict__ rays_o, \n    const scalar_t* __restrict__ rays_d, \n    const float bound,\n    const float dt_gamma, const uint32_t max_steps,\n    const uint32_t C, const uint32_t H,\n    const uint8_t * __restrict__ grid,\n    const scalar_t* __restrict__ nears,\n    const scalar_t* __restrict__ fars,\n    scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas,\n    const uint32_t perturb,\n    pcg32 rng\n) {\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= n_alive) return;\n\n    const int index = rays_alive[n]; // ray id\n    float t = rays_t[n]; // current ray's t\n\n    // locate\n    rays_o += index * 3;\n    rays_d += index * 3;\n    xyzs += n * n_step * 3;\n    dirs += n * n_step * 3;\n    deltas += n 
* n_step * 2;\n\n    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];\n    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];\n    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;\n    const float rH = 1 / (float)H;\n\n    const float near = nears[index], far = fars[index];\n\n    const float dt_min = 2 * SQRT3() / max_steps;\n    const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;\n\n    // march for n_step steps, record points\n    uint32_t step = 0;\n\n    // introduce some randomness (pass in spp as perturb here)\n    if (perturb) {\n        rng.advance(n);\n        t += dt_min * rng.next_float();\n    }\n\n    float last_t = t;\n\n    while (t < far && step < n_step) {\n        // current point\n        const float x = clamp(ox + t * dx, -bound, bound);\n        const float y = clamp(oy + t * dy, -bound, bound);\n        const float z = clamp(oz + t * dz, -bound, bound);\n\n        const float dt = clamp(t * dt_gamma, dt_min, dt_max);\n\n        // get mip level\n        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]\n\n        const float mip_bound = fminf((float)(1 << level), bound);\n        const float mip_rbound = 1 / mip_bound;\n        \n        // convert to nearest grid position\n        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));\n\n        const uint32_t index = level * H * H * H + __morton3D(nx, ny, nz);\n        const bool occ = grid[index / 8] & (1 << (index % 8));\n\n        // if occupied, advance a small step, and write to output\n        if (occ) {\n            // write step\n            xyzs[0] = x;\n            xyzs[1] = y;\n            xyzs[2] = z;\n            dirs[0] = dx;\n            dirs[1] = dy;\n            dirs[2] = dz;\n            // calc dt\n            t 
+= dt;\n            deltas[0] = dt;\n            deltas[1] = t - last_t; // used to calc depth\n            last_t = t;\n            // step\n            xyzs += 3;\n            dirs += 3;\n            deltas += 2;\n            step++;\n\n        // else, skip a large step (basically skip a voxel grid)\n        } else {\n            // calc distance to next voxel\n            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;\n            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;\n            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;\n            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));\n            // step until next voxel\n            do { \n                t += clamp(t * dt_gamma, dt_min, dt_max);\n            } while (t < tt);\n        }\n    }\n}\n\n\nvoid march_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor rays_o, at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, at::Tensor grid, at::Tensor near, at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, const uint32_t perturb) {\n    static constexpr uint32_t N_THREAD = 128;\n    pcg32 rng = pcg32{(uint64_t)perturb};\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    rays_o.scalar_type(), \"march_rays\", ([&] {\n        kernel_march_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps, C, H, grid.data_ptr<uint8_t>(), near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), perturb, rng);\n    }));\n}\n\n\ntemplate <typename scalar_t>\n__global__ void kernel_composite_rays(\n    const uint32_t 
n_alive, \n    const uint32_t n_step, \n    const int* __restrict__ rays_alive, \n    scalar_t* rays_t, \n    const scalar_t* __restrict__ sigmas, \n    const scalar_t* __restrict__ rgbs, \n    const scalar_t* __restrict__ deltas, \n    scalar_t* weights_sum, scalar_t* depth, scalar_t* image\n) {\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= n_alive) return;\n\n    const int index = rays_alive[n]; // ray id\n    scalar_t t = rays_t[n]; // current ray's t\n\n    // locate \n    sigmas += n * n_step;\n    rgbs += n * n_step * 3;\n    deltas += n * n_step * 2;\n\n    weights_sum += index;\n    depth += index;\n    image += index * 3;\n    \n    scalar_t weight_sum = weights_sum[0];\n    scalar_t d = depth[0];\n    scalar_t r = image[0];\n    scalar_t g = image[1];\n    scalar_t b = image[2];\n\n    // accumulate \n    uint32_t step = 0;\n    while (step < n_step) {\n        \n        // ray is terminated if delta == 0\n        if (deltas[0] == 0) break;\n        \n        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);\n\n        /* \n        T_0 = 1; T_i = \\prod_{j=0}^{i-1} (1 - alpha_j)\n        w_i = alpha_i * T_i\n        --> \n        T_i = 1 - \\sum_{j=0}^{i-1} w_j\n        */\n        const scalar_t T = 1 - weight_sum;\n        const scalar_t weight = alpha * T;\n        weight_sum += weight;\n\n        t += deltas[1]; // real delta\n        d += weight * t;\n        r += weight * rgbs[0];\n        g += weight * rgbs[1];\n        b += weight * rgbs[2];\n\n        //printf(\"[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\\n\", n, step, alpha, weight, T, sum_delta, d);\n\n        // ray is terminated if T is too small\n        // NOTE: can significantly accelerate inference!\n        if (T < 1e-4) break;\n\n        // locate\n        sigmas++;\n        rgbs += 3;\n        deltas += 2;\n        step++;\n    }\n\n    //printf(\"[n=%d] rgb=(%f, %f, %f), d=%f\\n\", n, r, g, b, d);\n\n    // rays_t = -1 means 
ray is terminated early.\n    if (step < n_step) {\n        rays_t[n] = -1;\n    } else {\n        rays_t[n] = t;\n    }\n\n    weights_sum[0] = weight_sum; // write back the accumulated alpha\n    depth[0] = d;\n    image[0] = r;\n    image[1] = g;\n    image[2] = b;\n}\n\n\nvoid composite_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) {\n    static constexpr uint32_t N_THREAD = 128;\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    image.scalar_type(), \"composite_rays\", ([&] {\n        kernel_composite_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());\n    }));\n}\n\n\ntemplate <typename scalar_t>\n__global__ void kernel_compact_rays(\n    const uint32_t n_alive, \n    int* rays_alive, \n    const int* __restrict__ rays_alive_old, \n    scalar_t* rays_t, \n    const scalar_t* __restrict__ rays_t_old, \n    int* alive_counter\n) {\n    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;\n    if (n >= n_alive) return;\n\n    // rays_t_old[n] < 0 means ray died in last composite kernel.\n    if (rays_t_old[n] >= 0) {\n        const int index = atomicAdd(alive_counter, 1);\n        rays_alive[index] = rays_alive_old[n];\n        rays_t[index] = rays_t_old[n];\n    }\n}\n\n\nvoid compact_rays(const uint32_t n_alive, at::Tensor rays_alive, at::Tensor rays_alive_old, at::Tensor rays_t, at::Tensor rays_t_old, at::Tensor alive_counter) {\n    static constexpr uint32_t N_THREAD = 128;\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    rays_t.scalar_type(), \"compact_rays\", ([&] {\n        kernel_compact_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, 
rays_alive.data_ptr<int>(), rays_alive_old.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_t_old.data_ptr<scalar_t>(), alive_counter.data_ptr<int>());\n    }));\n}"
  },
  {
    "path": "raymarching/src/raymarching.h",
    "content": "#pragma once\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n\nvoid near_far_from_aabb(at::Tensor rays_o, at::Tensor rays_d, at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);\nvoid polar_from_ray(at::Tensor rays_o, at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);\nvoid morton3D(at::Tensor coords, const uint32_t N, at::Tensor indices);\nvoid morton3D_invert(at::Tensor indices, const uint32_t N, at::Tensor coords);\nvoid packbits(at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);\n\nvoid march_rays_train(at::Tensor rays_o, at::Tensor rays_d, at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, const uint32_t perturb);\nvoid composite_rays_train_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, const uint32_t M, const uint32_t N, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);\nvoid composite_rays_train_backward(at::Tensor grad_weights_sum, at::Tensor grad_image, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, at::Tensor weights_sum, at::Tensor image, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs);\n\nvoid march_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor rays_o, at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, at::Tensor grid, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, const uint32_t perturb);\nvoid composite_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, 
at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);\nvoid compact_rays(const uint32_t n_alive, at::Tensor rays_alive, at::Tensor rays_alive_old, at::Tensor rays_t, at::Tensor rays_t_old, at::Tensor alive_counter);"
  },
  {
    "path": "shencoder/__init__.py",
    "content": "from .sphere_harmonics import SHEncoder\n"
  },
  {
    "path": "shencoder/backend.py",
    "content": "import os\nfrom torch.utils.cpp_extension import load\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n]\n\nif os.name == \"posix\":\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\n_backend = load(\n    name=\"_sh_encoder\",\n    extra_cflags=c_flags,\n    extra_cuda_cflags=nvcc_flags,\n    sources=[\n        os.path.join(_src_path, \"src\", f)\n        for f in [\n            \"shencoder.cu\",\n            \"bindings.cpp\",\n        ]\n    ],\n)\n\n__all__ = [\"_backend\"]\n"
  },
  {
    "path": "shencoder/setup.py",
    "content": "import os\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n_src_path = os.path.dirname(os.path.abspath(__file__))\n\nnvcc_flags = [\n    \"-O3\",\n    \"-std=c++14\",\n    \"-U__CUDA_NO_HALF_OPERATORS__\",\n    \"-U__CUDA_NO_HALF_CONVERSIONS__\",\n    \"-U__CUDA_NO_HALF2_OPERATORS__\",\n]\n\nif os.name == \"posix\":\n    c_flags = [\"-O3\", \"-std=c++14\"]\nelif os.name == \"nt\":\n    c_flags = [\"/O2\", \"/std:c++17\"]\n\n    # find cl.exe\n    def find_cl_path():\n        import glob\n\n        for edition in [\"Enterprise\", \"Professional\", \"BuildTools\", \"Community\"]:\n            paths = sorted(\n                glob.glob(\n                    r\"C:\\\\Program Files (x86)\\\\Microsoft Visual Studio\\\\*\\\\%s\\\\VC\\\\Tools\\\\MSVC\\\\*\\\\bin\\\\Hostx64\\\\x64\"\n                    % edition\n                ),\n                reverse=True,\n            )\n            if paths:\n                return paths[0]\n\n    # If cl.exe is not on path, try to find it.\n    if os.system(\"where cl.exe >nul 2>nul\") != 0:\n        cl_path = find_cl_path()\n        if cl_path is None:\n            raise RuntimeError(\n                \"Could not locate a supported Microsoft Visual C++ installation\"\n            )\n        os.environ[\"PATH\"] += \";\" + cl_path\n\nsetup(\n    name=\"shencoder\",  # package name, import this to use python API\n    ext_modules=[\n        CUDAExtension(\n            name=\"_shencoder\",  # extension name, import this to use CUDA API\n            sources=[\n                os.path.join(_src_path, \"src\", f)\n                for f in [\n                    \"shencoder.cu\",\n                    \"bindings.cpp\",\n                ]\n            ],\n            extra_compile_args={\n                \"cxx\": c_flags,\n                \"nvcc\": nvcc_flags,\n            },\n        ),\n    ],\n    cmdclass={\n        \"build_ext\": BuildExtension,\n    },\n)\n"
  },
  {
    "path": "shencoder/sphere_harmonics.py",
    "content": "import numpy as np\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.autograd.function import once_differentiable\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\ntry:\n    import _shencoder as _backend\nexcept ImportError:\n    from .backend import _backend\n\n\nclass _sh_encoder(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)  # force float32 for better precision\n    def forward(ctx, inputs, degree, calc_grad_inputs=False):\n        # inputs: [B, input_dim], float in [-1, 1]\n        # RETURN: [B, F], float\n\n        inputs = inputs.contiguous()\n        B, input_dim = inputs.shape  # batch size, coord dim\n        output_dim = degree ** 2\n\n        outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device)\n\n        if calc_grad_inputs:\n            dy_dx = torch.empty(\n                B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device\n            )\n        else:\n            dy_dx = torch.empty(1, dtype=inputs.dtype, device=inputs.device)\n\n        _backend.sh_encode_forward(\n            inputs, outputs, B, input_dim, degree, calc_grad_inputs, dy_dx\n        )\n\n        ctx.save_for_backward(inputs, dy_dx)\n        ctx.dims = [B, input_dim, degree]\n        ctx.calc_grad_inputs = calc_grad_inputs\n\n        return outputs\n\n    @staticmethod\n    # @once_differentiable\n    @custom_bwd\n    def backward(ctx, grad):\n        # grad: [B, C * C]\n\n        if ctx.calc_grad_inputs:\n            grad = grad.contiguous()\n            inputs, dy_dx = ctx.saved_tensors\n            B, input_dim, degree = ctx.dims\n            grad_inputs = torch.zeros_like(inputs)\n            _backend.sh_encode_backward(\n                grad, inputs, B, input_dim, degree, dy_dx, grad_inputs\n            )\n            return grad_inputs, None, None\n        else:\n            return None, None, None\n\n\nsh_encode = _sh_encoder.apply\n\n\nclass 
SHEncoder(nn.Module):\n    def __init__(self, input_dim=3, degree=4):\n        super().__init__()\n\n        self.input_dim = input_dim  # coord dims, must be 3\n        self.degree = degree  # 1 ~ 8\n        self.output_dim = degree ** 2\n\n        assert self.input_dim == 3, \"SH encoder only supports input dim == 3\"\n        assert (\n            self.degree > 0 and self.degree <= 8\n        ), \"SH encoder only supports degree in [1, 8]\"\n\n    def __repr__(self):\n        return f\"SHEncoder: input_dim={self.input_dim} degree={self.degree}\"\n\n    def forward(self, inputs, size=1):\n        # inputs: [..., input_dim], normalized real-world positions in [-size, size]\n        # return: [..., degree^2]\n\n        inputs = inputs / size  # [-1, 1]\n\n        prefix_shape = list(inputs.shape[:-1])\n        inputs = inputs.reshape(-1, self.input_dim)\n\n        outputs = sh_encode(inputs, self.degree, inputs.requires_grad)\n        outputs = outputs.reshape(prefix_shape + [self.output_dim])\n\n        return outputs\n
  },
  {
    "path": "shencoder/src/bindings.cpp",
    "content": "#include <torch/extension.h>\n\n#include \"shencoder.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"sh_encode_forward\", &sh_encode_forward, \"SH encode forward (CUDA)\");\n    m.def(\"sh_encode_backward\", &sh_encode_backward, \"SH encode backward (CUDA)\");\n}"
  },
  {
    "path": "shencoder/src/shencoder.cu",
    "content": "#include <stdint.h>\n\n#include <cuda.h>\n#include <cuda_fp16.h>\n#include <cuda_runtime.h>\n\n#include <ATen/cuda/CUDAContext.h>\n#include <torch/torch.h>\n\n#include <algorithm>\n#include <stdexcept>\n\n#include <cstdio>\n\n\n#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x \" must be a CUDA tensor\")\n#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x \" must be a contiguous tensor\")\n#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x \" must be an int tensor\")\n#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x \" must be a floating tensor\")\n\n\ntemplate <typename T>\n__host__ __device__ T div_round_up(T val, T divisor) {\n\treturn (val + divisor - 1) / divisor;\n}\n\ntemplate <typename scalar_t>\n__global__ void kernel_sh(\n    const scalar_t * __restrict__ inputs, \n    scalar_t * outputs, \n    uint32_t B, uint32_t D, uint32_t C,\n    const bool calc_grad_inputs, \n    scalar_t * dy_dx\n) {\n\tconst uint32_t b = threadIdx.x + blockIdx.x * blockDim.x;\n\tif (b >= B) return;\n\n\tconst uint32_t C2 = C * C;\n\n\t// locate\n\tinputs += b * D;\n\toutputs += b * C2;\n\n\tscalar_t x = inputs[0], y = inputs[1], z = inputs[2];\n\n\tscalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z;\n\tscalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2;\n\tscalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2;\n\n\tauto write_sh = [&]() {\n\t\toutputs[0] = 0.28209479177387814f ;                          // 1/(2*sqrt(pi))\n\t\tif (C <= 1) { return; }\n\t\toutputs[1] = -0.48860251190291987f*y ;                               // -sqrt(3)*y/(2*sqrt(pi))\n\t\toutputs[2] = 0.48860251190291987f*z ;                                // sqrt(3)*z/(2*sqrt(pi))\n\t\toutputs[3] = -0.48860251190291987f*x ;                               // -sqrt(3)*x/(2*sqrt(pi))\n\t\tif (C <= 2) { return; }\n\t\toutputs[4] = 
1.0925484305920792f*xy ;                                // sqrt(15)*xy/(2*sqrt(pi))\n\t\toutputs[5] = -1.0925484305920792f*yz ;                               // -sqrt(15)*yz/(2*sqrt(pi))\n\t\toutputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f ;                         // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi))\n\t\toutputs[7] = -1.0925484305920792f*xz ;                               // -sqrt(15)*xz/(2*sqrt(pi))\n\t\toutputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ;                              // sqrt(15)*(x2 - y2)/(4*sqrt(pi))\n\t\tif (C <= 3) { return; }\n\t\toutputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ;                         // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))\n\t\toutputs[10] = 2.8906114426405538f*xy*z ;                             // sqrt(105)*xy*z/(2*sqrt(pi))\n\t\toutputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ;                                // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi))\n\t\toutputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ;                         // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi))\n\t\toutputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ;                                // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi))\n\t\toutputs[14] = 1.4453057213202769f*z*(x2 - y2) ;                              // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi))\n\t\toutputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ;                                // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi))\n\t\tif (C <= 4) { return; }\n\t\toutputs[16] = 2.5033429417967046f*xy*(x2 - y2) ;                             // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi))\n\t\toutputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ;                                // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi))\n\t\toutputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ;                               // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi))\n\t\toutputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ;                               // 3*sqrt(10)*yz*(3 - 
7*z2)/(8*sqrt(pi))\n\t\toutputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f ;                                // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi))\n\t\toutputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ;                               // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi))\n\t\toutputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ;                                // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi))\n\t\toutputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ;                                // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi))\n\t\toutputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ;                         // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))\n\t\tif (C <= 5) { return; }\n\t\toutputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                            // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n\t\toutputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ;                           // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi))\n\t\toutputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ;                         // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))\n\t\toutputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ;                              // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi))\n\t\toutputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ;                             // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))\n\t\toutputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ;                            // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi))\n\t\toutputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ;                             // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))\n\t\toutputs[32] = 2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ;                               // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi))\n\t\toutputs[33] = 
-0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ;                         // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi))\n\t\toutputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ;                         // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))\n\t\toutputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ;                            // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))\n\t\tif (C <= 6) { return; }\n\t\toutputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ;                               // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))\n\t\toutputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                            // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n\t\toutputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f) ;                             // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))\n\t\toutputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ;                               // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))\n\t\toutputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ;                           // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))\n\t\toutputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ;                            // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))\n\t\toutputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ;                         // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi))\n\t\toutputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ;                            // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))\n\t\toutputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ;                               // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi))\n\t\toutputs[45] = -0.92120525951492349f*xz*(x2 - 
3.0f*y2)*(11.0f*z2 - 3.0f) ;                               // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi))\n\t\toutputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ;                          // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))\n\t\toutputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ;                            // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))\n\t\toutputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ;                         // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))\n\t\tif (C <= 7) { return; }\n\t\toutputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ;                              // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi))\n\t\toutputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ;                             // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))\n\t\toutputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ;                          // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi))\n\t\toutputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ;                           // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))\n\t\toutputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ;                              // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))\n\t\toutputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ;                              // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))\n\t\toutputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ;                              // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))\n\t\toutputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 
429.0f*z6 - 35.0f) ;                              // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi))\n\t\toutputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ;                              // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))\n\t\toutputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ;                         // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi))\n\t\toutputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ;                              // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))\n\t\toutputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 3.0f)*(-6.0f*x2*y2 + x4 + y4) ;                         // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))\n\t\toutputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                          // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi))\n\t\toutputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6) ;                               // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))\n\t\toutputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ;                              // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi))\n\t};\n\n\twrite_sh();\n\n\tif (calc_grad_inputs) {\n\t\tscalar_t *dx = dy_dx + b * D * C2;\n\t\tscalar_t *dy = dx + C2;\n\t\tscalar_t *dz = dy + C2;\n\n\t\tauto write_sh_dx = [&]() {\n\t\t\tdx[0] = 0.0f ;                             // 0\n\t\t\tif (C <= 1) { return; }\n\t\t\tdx[1] = 0.0f ;                             // 0\n\t\t\tdx[2] = 0.0f ;                             // 0\n\t\t\tdx[3] = -0.48860251190291992f ;                          // -sqrt(3)/(2*sqrt(pi))\n\t\t\tif (C <= 2) { return; }\n\t\t\tdx[4] = 1.0925484305920792f*y ;                          // 
sqrt(15)*y/(2*sqrt(pi))\n\t\t\tdx[5] = 0.0f ;                             // 0\n\t\t\tdx[6] = 0.0f ;                             // 0\n\t\t\tdx[7] = -1.0925484305920792f*z ;                         // -sqrt(15)*z/(2*sqrt(pi))\n\t\t\tdx[8] = 1.0925484305920792f*x ;                          // sqrt(15)*x/(2*sqrt(pi))\n\t\t\tif (C <= 3) { return; }\n\t\t\tdx[9] = -3.5402615395598609f*xy ;                                // -3*sqrt(70)*xy/(4*sqrt(pi))\n\t\t\tdx[10] = 2.8906114426405538f*yz ;                                // sqrt(105)*yz/(2*sqrt(pi))\n\t\t\tdx[11] = 0.0f ;                            // 0\n\t\t\tdx[12] = 0.0f ;                            // 0\n\t\t\tdx[13] = 0.45704579946446572f - 2.2852289973223288f*z2 ;                          // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))\n\t\t\tdx[14] = 2.8906114426405538f*xz ;                                // sqrt(105)*xz/(2*sqrt(pi))\n\t\t\tdx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ;                               // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))\n\t\t\tif (C <= 4) { return; }\n\t\t\tdx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2) ;                           // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi))\n\t\t\tdx[17] = -10.620784618679583f*xy*z ;                             // -9*sqrt(70)*xy*z/(4*sqrt(pi))\n\t\t\tdx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f) ;                         // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi))\n\t\t\tdx[19] = 0.0f ;                            // 0\n\t\t\tdx[20] = 0.0f ;                            // 0\n\t\t\tdx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ;                         // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))\n\t\t\tdx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ;                         // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))\n\t\t\tdx[23] = 5.3103923093397913f*z*(-x2 + y2) ;                              // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))\n\t\t\tdx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ;                           // 3*sqrt(35)*x*(x2 - 
3*y2)/(4*sqrt(pi))\n\t\t\tif (C <= 5) { return; }\n\t\t\tdx[25] = 13.127641136803401f*xy*(-x2 + y2) ;                             // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi))\n\t\t\tdx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2) ;                          // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi))\n\t\t\tdx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2) ;                         // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi))\n\t\t\tdx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f) ;                         // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi))\n\t\t\tdx[29] = 0.0f ;                            // 0\n\t\t\tdx[30] = 0.0f ;                            // 0\n\t\t\tdx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ;                          // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))\n\t\t\tdx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ;                         // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))\n\t\t\tdx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2 ;                         // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi))\n\t\t\tdx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ;                          // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))\n\t\t\tdx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ;                               // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))\n\t\t\tif (C <= 6) { return; }\n\t\t\tdx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4) ;                             // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))\n\t\t\tdx[37] = 47.332383244635047f*xy*z*(-x2 + y2) ;                           // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi))\n\t\t\tdx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ;                           // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))\n\t\t\tdx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2) ;                              // 
3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi))\n\t\t\tdx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 33.0f*z4 + 1.0f) ;                             // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))\n\t\t\tdx[41] = 0.0f ;                            // 0\n\t\t\tdx[42] = 0.0f ;                            // 0\n\t\t\tdx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ;                              // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))\n\t\t\tdx[44] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ;                             // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))\n\t\t\tdx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ;                              // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))\n\t\t\tdx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ;                           // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))\n\t\t\tdx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ;                           // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))\n\t\t\tdx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                             // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))\n\t\t\tif (C <= 7) { return; }\n\t\t\tdx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4) ;                         // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi))\n\t\t\tdx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4) ;                            // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi))\n\t\t\tdx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ;                             // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))\n\t\t\tdx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ;                          // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))\n\t\t\tdx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f) ;                            // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 
3)/(32*sqrt(pi))\n\t\t\tdx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ;                         // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))\n\t\t\tdx[55] = 0.0f ;                            // 0\n\t\t\tdx[56] = 0.0f ;                            // 0\n\t\t\tdx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ;                         // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))\n\t\t\tdx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ;                         // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))\n\t\t\tdx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2 ;                              // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi))\n\t\t\tdx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ;                          // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi))\n\t\t\tdx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4) ;                              // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi))\n\t\t\tdx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                            // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))\n\t\t\tdx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ;                         // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))\n\t\t};\n\n\t\tauto write_sh_dy = [&]() {\n\t\t\tdy[0] = 0.0f ;                             // 0\n\t\t\tif (C <= 1) { return; }\n\t\t\tdy[1] = -0.48860251190291992f ;                          // -sqrt(3)/(2*sqrt(pi))\n\t\t\tdy[2] = 0.0f ;                             // 0\n\t\t\tdy[3] = 0.0f ;                             // 
0\n\t\t\tif (C <= 2) { return; }\n\t\t\tdy[4] = 1.0925484305920792f*x ;                          // sqrt(15)*x/(2*sqrt(pi))\n\t\t\tdy[5] = -1.0925484305920792f*z ;                         // -sqrt(15)*z/(2*sqrt(pi))\n\t\t\tdy[6] = 0.0f ;                             // 0\n\t\t\tdy[7] = 0.0f ;                             // 0\n\t\t\tdy[8] = -1.0925484305920792f*y ;                         // -sqrt(15)*y/(2*sqrt(pi))\n\t\t\tif (C <= 3) { return; }\n\t\t\tdy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ;                                // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi))\n\t\t\tdy[10] = 2.8906114426405538f*xz ;                                // sqrt(105)*xz/(2*sqrt(pi))\n\t\t\tdy[11] = 0.45704579946446572f - 2.2852289973223288f*z2 ;                          // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi))\n\t\t\tdy[12] = 0.0f ;                            // 0\n\t\t\tdy[13] = 0.0f ;                            // 0\n\t\t\tdy[14] = -2.8906114426405538f*yz ;                               // -sqrt(105)*yz/(2*sqrt(pi))\n\t\t\tdy[15] = 3.5402615395598609f*xy ;                                // 3*sqrt(70)*xy/(4*sqrt(pi))\n\t\t\tif (C <= 4) { return; }\n\t\t\tdy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ;                           // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi))\n\t\t\tdy[17] = 5.3103923093397913f*z*(-x2 + y2) ;                              // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi))\n\t\t\tdy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ;                         // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi))\n\t\t\tdy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ;                         // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi))\n\t\t\tdy[20] = 0.0f ;                            // 0\n\t\t\tdy[21] = 0.0f ;                            // 0\n\t\t\tdy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2) ;                         // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi))\n\t\t\tdy[23] = 10.620784618679583f*xy*z ;                              // 9*sqrt(70)*xy*z/(4*sqrt(pi))\n\t\t\tdy[24] = 
2.5033429417967046f*y*(-3.0f*x2 + y2) ;                          // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi))\n\t\t\tif (C <= 5) { return; }\n\t\t\tdy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ;                               // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))\n\t\t\tdy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ;                          // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi))\n\t\t\tdy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f) ;                         // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi))\n\t\t\tdy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ;                         // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi))\n\t\t\tdy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ;                          // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi))\n\t\t\tdy[30] = 0.0f ;                            // 0\n\t\t\tdy[31] = 0.0f ;                            // 0\n\t\t\tdy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2) ;                         // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi))\n\t\t\tdy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f) ;                         // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi))\n\t\t\tdy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2) ;                         // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi))\n\t\t\tdy[35] = 13.127641136803401f*xy*(x2 - y2) ;                              // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi))\n\t\t\tif (C <= 6) { return; }\n\t\t\tdy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                             // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))\n\t\t\tdy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ;                           // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi))\n\t\t\tdy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ;                           // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi))\n\t\t\tdy[39] = -2.7636157785447706f*z*(x2 - 
y2)*(11.0f*z2 - 3.0f) ;                              // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi))\n\t\t\tdy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ;                             // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi))\n\t\t\tdy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ;                              // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi))\n\t\t\tdy[42] = 0.0f ;                            // 0\n\t\t\tdy[43] = 0.0f ;                            // 0\n\t\t\tdy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ;                              // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi))\n\t\t\tdy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f) ;                              // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi))\n\t\t\tdy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ;                          // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi))\n\t\t\tdy[47] = 47.332383244635047f*xy*z*(x2 - y2) ;                            // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi))\n\t\t\tdy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                              // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n\t\t\tif (C <= 7) { return; }\n\t\t\tdy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ;                         // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi))\n\t\t\tdy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ;                            // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi))\n\t\t\tdy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4) ;                                // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi))\n\t\t\tdy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ;                          // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 
3)/(8*sqrt(pi))\n\t\t\tdy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ;                             // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi))\n\t\t\tdy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ;                         // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi))\n\t\t\tdy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ;                         // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi))\n\t\t\tdy[56] = 0.0f ;                            // 0\n\t\t\tdy[57] = 0.0f ;                            // 0\n\t\t\tdy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ;                          // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))\n\t\t\tdy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ;                           // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))\n\t\t\tdy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ;                         // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi))\n\t\t\tdy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ;                              // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi))\n\t\t\tdy[62] = 15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                             // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n\t\t\tdy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ;                                // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))\n\t\t};\n\n\t\tauto write_sh_dz = [&]() {\n\t\t\tdz[0] = 0.0f ;                             // 0\n\t\t\tif (C <= 1) { return; }\n\t\t\tdz[1] = 0.0f ;                             // 0\n\t\t\tdz[2] = 0.48860251190291992f ;                           // sqrt(3)/(2*sqrt(pi))\n\t\t\tdz[3] = 0.0f ;                             // 0\n\t\t\tif (C <= 2) { return; }\n\t\t\tdz[4] = 0.0f ;              
               // 0\n\t\t\tdz[5] = -1.0925484305920792f*y ;                         // -sqrt(15)*y/(2*sqrt(pi))\n\t\t\tdz[6] = 1.8923493915151202f*z ;                          // 3*sqrt(5)*z/(2*sqrt(pi))\n\t\t\tdz[7] = -1.0925484305920792f*x ;                         // -sqrt(15)*x/(2*sqrt(pi))\n\t\t\tdz[8] = 0.0f ;                             // 0\n\t\t\tif (C <= 3) { return; }\n\t\t\tdz[9] = 0.0f ;                             // 0\n\t\t\tdz[10] = 2.8906114426405538f*xy ;                                // sqrt(105)*xy/(2*sqrt(pi))\n\t\t\tdz[11] = -4.5704579946446566f*yz ;                               // -5*sqrt(42)*yz/(4*sqrt(pi))\n\t\t\tdz[12] = 5.597644988851731f*z2 - 1.1195289977703462f ;                            // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi))\n\t\t\tdz[13] = -4.5704579946446566f*xz ;                               // -5*sqrt(42)*xz/(4*sqrt(pi))\n\t\t\tdz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2 ;                                // sqrt(105)*(x2 - y2)/(4*sqrt(pi))\n\t\t\tdz[15] = 0.0f ;                            // 0\n\t\t\tif (C <= 4) { return; }\n\t\t\tdz[16] = 0.0f ;                            // 0\n\t\t\tdz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2) ;                          // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi))\n\t\t\tdz[18] = 13.246445740605839f*xy*z ;                              // 21*sqrt(5)*xy*z/(2*sqrt(pi))\n\t\t\tdz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2) ;                          // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi))\n\t\t\tdz[20] = 14.809976568128603f*pow(z, 3) - 6.3471328149122579f*z ;                          // (105*z**3 - 45*z)/(4*sqrt(pi))\n\t\t\tdz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2) ;                          // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi))\n\t\t\tdz[22] = 6.6232228703029197f*z*(x2 - y2) ;                               // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi))\n\t\t\tdz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2) ;                          // 3*sqrt(70)*x*(-x2 + 
3*y2)/(8*sqrt(pi))\n\t\t\tdz[24] = 0.0f ;                            // 0\n\t\t\tif (C <= 5) { return; }\n\t\t\tdz[25] = 0.0f ;                            // 0\n\t\t\tdz[26] = 8.3026492595241645f*xy*(x2 - y2) ;                              // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi))\n\t\t\tdz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2) ;                         // 9*sqrt(770)*yz*(-3*x2 + y2)/(16*sqrt(pi))\n\t\t\tdz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f) ;                         // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi))\n\t\t\tdz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2) ;                         // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi))\n\t\t\tdz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f ;                           // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi))\n\t\t\tdz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2) ;                         // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi))\n\t\t\tdz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f) ;                          // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi))\n\t\t\tdz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2) ;                         // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi))\n\t\t\tdz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4 ;                            // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))\n\t\t\tdz[35] = 0.0f ;                            // 0\n\t\t\tif (C <= 6) { return; }\n\t\t\tdz[36] = 0.0f ;                            // 0\n\t\t\tdz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                              // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n\t\t\tdz[38] = 44.401711264127719f*xy*z*(x2 - y2) ;                            // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi))\n\t\t\tdz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ;                          // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi))\n\t\t\tdz[40] = 
11.054463114179082f*xy*z*(11.0f*z2 - 3.0f) ;                              // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi))\n\t\t\tdz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ;                               // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))\n\t\t\tdz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f) ;                              // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi))\n\t\t\tdz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f) ;                               // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi))\n\t\t\tdz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ;                               // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi))\n\t\t\tdz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ;                          // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi))\n\t\t\tdz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4) ;                           // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi))\n\t\t\tdz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ;                              // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))\n\t\t\tdz[48] = 0.0f ;                            // 0\n\t\t\tif (C <= 7) { return; }\n\t\t\tdz[49] = 0.0f ;                            // 0\n\t\t\tdz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ;                                // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))\n\t\t\tdz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ;                             // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))\n\t\t\tdz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ;                              // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi))\n\t\t\tdz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ;                         // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi))\n\t\t\tdz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 
143.0f*z4 + 3.0f) ;                            // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))\n\t\t\tdz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ;                           // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))\n\t\t\tdz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f ;                           // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi))\n\t\t\tdz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f) ;                           // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))\n\t\t\tdz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f) ;                         // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi))\n\t\t\tdz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ;                         // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi))\n\t\t\tdz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ;                            // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))\n\t\t\tdz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ;                             // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))\n\t\t\tdz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6 ;                            // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))\n\t\t\tdz[63] = 0.0f ;                            // 0\n\t\t};\n\t\twrite_sh_dx();\n\t\twrite_sh_dy();\n\t\twrite_sh_dz();\n\t}\n}\n\n\ntemplate <typename scalar_t>\n__global__ void kernel_sh_backward(\n    const scalar_t * __restrict__ grad,\n\tconst scalar_t * __restrict__ inputs,\n    uint32_t B, uint32_t D, uint32_t C,\n    const scalar_t * __restrict__ dy_dx,\n    scalar_t * grad_inputs\n) {\n\tconst uint32_t t = threadIdx.x + blockIdx.x * 
blockDim.x;\n\tconst uint32_t b = t / D;\n\tif (b >= B) return;\n\n\tconst uint32_t d = t - b * D;\n\tconst uint32_t C2 = C * C;\n\n\t// locate\n\tgrad += b * C2;\n\tdy_dx += b * D * C2 + d * C2;\n\n\tfor (int ch = 0; ch < C2; ch++) {\n\t\tgrad_inputs[t] += grad[ch] * dy_dx[ch];\n\t\t//printf(\"t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\\n\", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]);\n\t}\n\n}\n\n// inputs: [B, D], float, in [-1, 1]\n// outputs: [B, C * C], float\ntemplate <typename scalar_t>\nvoid sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const bool calc_grad_inputs, scalar_t *dy_dx) {\n\tstatic constexpr uint32_t N_THREADS = 256;\n\tkernel_sh<scalar_t><<<div_round_up(B, N_THREADS), N_THREADS>>>(inputs, outputs, B, D, C, calc_grad_inputs, dy_dx);\n}\n\n\ntemplate <typename scalar_t>\nvoid sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) {\n\tstatic constexpr uint32_t N_THREADS = 256;\n\tkernel_sh_backward<scalar_t><<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad, inputs, B, D, C, dy_dx, grad_inputs);\n}\n\n\n\nvoid sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const bool calc_grad_inputs, at::Tensor dy_dx) {\n    CHECK_CUDA(inputs);\n    CHECK_CUDA(outputs);\n    CHECK_CUDA(dy_dx);\n    \n    CHECK_CONTIGUOUS(inputs);\n    CHECK_CONTIGUOUS(outputs);\n    CHECK_CONTIGUOUS(dy_dx);\n\n    CHECK_IS_FLOATING(inputs);\n    CHECK_IS_FLOATING(outputs);\n    CHECK_IS_FLOATING(dy_dx);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    inputs.scalar_type(), \"sh_encode_forward_cuda\", ([&] {\n\t\tsh_encode_forward_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), outputs.data_ptr<scalar_t>(), B, D, C, calc_grad_inputs, dy_dx.data_ptr<scalar_t>());\n    }));\t\n}\n\nvoid sh_encode_backward(at::Tensor grad, 
at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) {    \n    CHECK_CUDA(grad);\n    CHECK_CUDA(inputs);\n    CHECK_CUDA(dy_dx);\n    CHECK_CUDA(grad_inputs);\n    \n    CHECK_CONTIGUOUS(grad);\n    CHECK_CONTIGUOUS(inputs);\n    CHECK_CONTIGUOUS(dy_dx);\n    CHECK_CONTIGUOUS(grad_inputs);\n\n    CHECK_IS_FLOATING(grad);\n    CHECK_IS_FLOATING(inputs);\n    CHECK_IS_FLOATING(dy_dx);\n    CHECK_IS_FLOATING(grad_inputs);\n\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n    grad.scalar_type(), \"sh_encode_backward_cuda\", ([&] {\n    \tsh_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(), B, D, C, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>());\n    }));\t\n}"
  },
  {
    "path": "shencoder/src/shencoder.h",
    "content": "#pragma once\n\n#include <stdint.h>\n#include <torch/torch.h>\n\n// inputs: [B, D], float, in [-1, 1]\n// outputs: [B, F], float\n\n// sh_encode_forward(inputs, outputs, B, input_dim, degree, calc_grad_inputs, dy_dx)\nvoid sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const bool calc_grad_inputs, at::Tensor dy_dx);\n\n// sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs)\nvoid sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs);"
  },
  {
    "path": "tools/activation.py",
    "content": "import torch\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd, custom_fwd\n\n\nclass _trunc_exp(Function):\n    \"\"\"exp(x) whose backward pass clamps the saved input to avoid overflow under AMP.\"\"\"\n\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32)  # always compute the forward pass in float32\n    def forward(ctx, x):\n        ctx.save_for_backward(x)\n        return torch.exp(x)\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, g):\n        x = ctx.saved_tensors[0]\n        # clamp x so the gradient exp(x) cannot overflow\n        return g * torch.exp(x.clamp(-12, 12))\n\n\ntrunc_exp = _trunc_exp.apply\n"
  },
  {
    "path": "tools/details.md",
    "content": "# custom datasets\n\nOur dataset format is based on [torch-ngp](https://github.com/ashawkey/torch-ngp/tree/3b066b6cd6ccd3610cb66a56a54f5daaf12a8033), which fully supports custom datasets in COLMAP format.\nThe steps for using a custom dataset are as follows:\n\n- 1. Take a video / many photos of the scene from different views.\n- 2. Put the video under a path like ./data/custom/video.mp4, or the images under ./data/custom/images/*.jpg.\n- 3. Run the preprocessing script (ffmpeg and colmap must be installed first; see [colmap2nerf.py](https://github.com/ashawkey/torch-ngp/blob/3b066b6cd6ccd3610cb66a56a54f5daaf12a8033/scripts/colmap2nerf.py) for more options):\n    - python colmap2nerf.py --video ./data/custom/video.mp4 --run_colmap # if using a video\n    - python colmap2nerf.py --images ./data/custom/images/ --run_colmap # if using images\n- 4. This should create the transforms.json file, and you can then train on it. (You will need to try different scale, bound and dt_gamma values so that the object is correctly located in the bounding box and renders smoothly.)\n\n\nThen you can train a teacher and distill students with various structures as described in our introduction: https://github.com/megvii-research/AAAI2023-PVD\n\n\n# some ways to reduce GPU memory\n\n- 1. Reduce the value of the '--num_rays' parameter.\n- 2. Try disabling the '--preload' parameter. When preload=True, all image data is loaded onto the GPU, which consumes a lot of memory for high-resolution images. (This parameter is inherited from torch-ngp and has not been tested extensively.)\n- 3. If you do not have strict requirements on image resolution, you can experiment with downsampled images.\n- 4. The distillation process loads the student and the teacher network at the same time, so it consumes more memory. One solution is to run the teacher network's inference in advance, record the data required for distillation, and use this cached data to guide the training of the student.\n- 5. For the different models (INGP/Plenoxels/NeRF/TensoRF), different parameters control the model size. For example, you can reduce the number and resolution of the hash tables in INGP, reduce the resolution of Plenoxels or TensoRF, or reduce the number of MLP parameters in NeRF.\n- 6. The current code does not support multiple GPUs yet, but that should be easy to implement: if none of the above solves your problem, you can try implementing DDP.\n\n"
  },
  {
    "path": "tools/encoding.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass FreqEncoder(nn.Module):\n    def __init__(\n        self,\n        input_dim,\n        max_freq_log2,\n        N_freqs,\n        log_sampling=True,\n        include_input=True,\n        periodic_fns=(torch.sin, torch.cos),\n    ):\n\n        super().__init__()\n\n        self.input_dim = input_dim\n        self.include_input = include_input\n        self.periodic_fns = periodic_fns\n\n        self.output_dim = 0\n        if self.include_input:\n            self.output_dim += self.input_dim\n\n        self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)\n\n        if log_sampling:\n            self.freq_bands = 2.0 ** torch.linspace(0.0, max_freq_log2, N_freqs)\n        else:\n            self.freq_bands = torch.linspace(2.0 ** 0.0, 2.0 ** max_freq_log2, N_freqs)\n\n        self.freq_bands = self.freq_bands.numpy().tolist()\n\n    def forward(self, input, **kwargs):\n\n        out = []\n        if self.include_input:\n            out.append(input)\n\n        for i in range(len(self.freq_bands)):\n            freq = self.freq_bands[i]\n            for p_fn in self.periodic_fns:\n                out.append(p_fn(input * freq))\n\n        out = torch.cat(out, dim=-1)\n\n        return out\n\n\ndef get_encoder(\n    encoding,\n    input_dim=3,\n    multires=6,\n    degree=4,\n    num_levels=14,\n    level_dim=2,\n    base_resolution=16,\n    log2_hashmap_size=19,\n    desired_resolution=4096,\n    align_corners=False,\n    **kwargs\n):\n\n    if encoding == \"None\":\n        return lambda x, **kwargs: x, input_dim\n\n    elif encoding == \"frequency\":\n        encoder = FreqEncoder(\n            input_dim=input_dim,\n            max_freq_log2=multires - 1,\n            N_freqs=multires,\n            log_sampling=True,\n        )\n\n    elif encoding == \"sphere_harmonics\":\n        from shencoder import SHEncoder\n\n        encoder = 
SHEncoder(input_dim=input_dim, degree=degree)\n\n    elif encoding == \"hashgrid\":\n        from gridencoder import GridEncoder\n\n        encoder = GridEncoder(\n            input_dim=input_dim,\n            num_levels=num_levels,\n            level_dim=level_dim,\n            base_resolution=base_resolution,\n            log2_hashmap_size=log2_hashmap_size,\n            desired_resolution=desired_resolution,\n            gridtype=\"hash\",\n            align_corners=align_corners,\n        )\n\n    elif encoding == \"tiledgrid\":\n        from gridencoder import GridEncoder\n\n        encoder = GridEncoder(\n            input_dim=input_dim,\n            num_levels=num_levels,\n            level_dim=level_dim,\n            base_resolution=base_resolution,\n            log2_hashmap_size=log2_hashmap_size,\n            desired_resolution=desired_resolution,\n            gridtype=\"tiled\",\n            align_corners=align_corners,\n        )\n\n    elif encoding == \"ash\":\n        from ashencoder import AshEncoder\n\n        encoder = AshEncoder(\n            input_dim=input_dim,\n            output_dim=16,\n            log2_hashmap_size=log2_hashmap_size,\n            resolution=desired_resolution,\n        )\n\n    else:\n        raise NotImplementedError()\n\n    return encoder, encoder.output_dim\n"
  },
  {
    "path": "tools/install_extensions.sh",
    "content": "#!/bin/bash\n# Stop on the first error so a failed build is not silently ignored.\nset -e\n\ncd raymarching\npip install .\ncd ..\n\ncd gridencoder\npip install .\ncd ..\n\ncd shencoder\npip install .\ncd ..\n"
  },
  {
    "path": "tools/requirements.txt",
    "content": "torch-ema\nninja\ntrimesh\nopencv-python\ntensorboardX\ntorch\nnumpy \npandas\ntqdm\nmatplotlib\nPyMCubes\nrich\npysdf\ndearpygui\npackaging\nscipy\nlpips\nimageio\n"
  },
  {
    "path": "tools/中文介绍.md",
    "content": "## One is All: Bridging the Gap Between Neural Radiance Fields Architectures with Progressive Volume Distillation\n(**Accepted by AAAI 2023**)\n\n# PVD-AL, the follow-up work of PVD, is now open source. It is more powerful than PVD, and we recommend using it: [code](https://github.com/megvii-research/AAAI2023-PVD/tree/PVD-AL).\n\n## [Project Video](http://sk-fun.fun/PVD/) | [Paper](https://arxiv.org/abs/2211.15977) | [Datasets](https://drive.google.com/drive/folders/1U06KAEsW53PolLI3U8hWUhzzIH74QGaP?usp=sharing) | [Pretrained Models](https://drive.google.com/drive/folders/1GGJf-FTmpCJjmEn-AF_S9-HrLRkFe5Ud?usp=sharing) | [English README](https://github.com/megvii-research/AAAI2023-PVD/blob/main/README.md) |\n\n## Brief introduction to the paper\n\n- NeRF-style models now come in a dazzling variety: NeRF, based on an implicit MLP; Plenoxels, based on a purely explicit tensor structure; the hybrid TensoRF (low-rank tensors + MLP) and INGP (hash tables + MLP); and many other variants.\n\n- Models with different structures have genuinely different characteristics. The high-level semantic features of a fully implicit, pure-MLP model can be used for many tasks, such as [changing lighting/weather](https://nerf-w.github.io/) or [artistic design](https://pfnet-research.github.io/distilled-feature-fields/), while a model with a purely explicit tensor structure has a clear spatial layout and is easy to [crop/compose/enlarge/shrink/erase](https://github.com/ashawkey/CCNeRF). TensoRF and INGP sit between the two, and their main advantages are fast training and high reconstruction quality. Moreover, hardware platforms (such as mobile devices) support different structures to very different degrees, so choosing a structure for a downstream task takes some design experience.\n\n- To ease this choice for designers and to transfer properties between structures, we carried out the study in this paper. Whether conversion between different structures is possible had not yet been investigated, and we believe a first attempt is worthwhile.\n\n- Our goal is to transfer the properties of one architecture to other, different architectures. For example, the fast convergence of INGP quickly yields a model, which can then be distilled into a NeRF, accelerating NeRF training and even improving quality on some datasets. We can also transfer the editability of explicit structures to non-explicit ones: for example, perform scene composition or scene splitting on the tensor structure of Plenoxels and then distill it into other models, so that they too can render the edited scene. Experiments show that the spatial editing ability of explicit structures can be transferred to other structures successfully and with high quality: [our examples](http://sk-fun.fun/PVD/).\n\n- Understanding why distillation works also provides some insight into the inner workings of these models. For example, the fact that intermediate features can be aligned across models implies that models with different structures can essentially be mapped into similar spaces.\n\n\n\n## Installation\nWe recommend using [Anaconda](https://www.anaconda.com/) to avoid polluting your local environment. Run the following commands:\n\n*Step1*: Create a conda environment named 'pvd'\n```\nconda create --name pvd python=3.7\nconda activate pvd\npip install -r ./tools/requirements.txt\n```\n*Step2*: Install the C++/CUDA extensions (adapted from [torch-ngp](https://github.com/ashawkey/torch-ngp))\n```\nbash ./tools/install_extensions.sh\n```\n\n## Datasets & pretrained models\nSynthetic-NeRF/LLFF/Tanks&Temples: [Google Drive](https://drive.google.com/drive/folders/1U06KAEsW53PolLI3U8hWUhzzIH74QGaP?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1ky_TWrbUZG_MpHTBhncAKA?pwd=4h2h).\n\nPretrained models: [Google Drive](https://drive.google.com/drive/folders/1GGJf-FTmpCJjmEn-AF_S9-HrLRkFe5Ud?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1LGLXwLGusX60GpAywLwosg?pwd=34k8).\n\nIf you prefer not to download the pretrained models, training a teacher from scratch as described below is also fast.\n\n## Train a teacher\n```\n# train a hash-based(INGP) teacher\npython main_just_train_tea.py ./data/nerf_synthetic/chair --model_type hash --data_type synthetic  --workspace ./log/train_teacher/hash_chair\n\n# train a sparse-tensor-based(TensoRF VM-decomposition) teacher\npython main_just_train_tea.py ./data/nerf_synthetic/chair --model_type vm --data_type synthetic  --workspace ./log/train_teacher/vm_chair\n\n# train a MLP-based(NeRF) teacher\npython main_just_train_tea.py ./data/nerf_synthetic/chair --model_type mlp --data_type synthetic  --workspace ./log/train_teacher/mlp_chair\n\n# train a tensors-based(Plenoxels) teacher\npython main_just_train_tea.py ./data/nerf_synthetic/chair --model_type tensors --data_type synthetic  --workspace ./log/train_teacher/tensors_chair\n\n```\n\n## Distill a student\n```\n# teacher: hash(INGP),  student: vm(tensoRF)\npython3 main_distill_mutual.py  ./data/nerf_synthetic/chair \\\\\n                    --data_type synthetic \\\\\n                    --teacher_type hash \\\\\n                    --ckpt_teacher ./log/train_teacher/hash_chair/checkpoints/XXX.pth \\\\\n                    --model_type vm \\\\\n                    --workspace ./log/distill_student/hash2vm/chair\n                    \n# teacher: MLP(NeRF),  student: tensors(Plenoxels)\npython3 main_distill_mutual.py  ./data/nerf_synthetic/chair \\\\\n                    --data_type synthetic \\\\\n                    --teacher_type mlp \\\\\n                    --ckpt_teacher ./log/train_teacher/mlp_chair/checkpoints/XXX.pth \\\\\n                    --model_type tensors \\\\\n                    --workspace ./log/distill_student/mlp2tensors/chair\n                   \n```\n\n## Evaluation\n\n```\n# evaluate a hash teacher\npython main_distill_mutual.py ./data/nerf_synthetic/chair  --teacher_type hash --ckpt_teacher PATH/TO/CKPT.pth --test_teacher --data_type synthetic --workspace ./log/eval_teacher/hash_chair\n\n# evaluate a mlp student\npython main_distill_mutual.py ./data/nerf_synthetic/chair --model_type mlp --ckpt PATH/TO/CKPT.pth --test --data_type synthetic --workspace ./log/eval_student/mlp_chair\n```\n\n## Usage on more datasets and more running commands\n[more running description](https://github.com/megvii-research/AAAI2023-PVD/blob/main/tools/details.md)\n\n## Citation\n\nIf you find this work useful, please consider citing our paper:\n```\n@article{pvd2023,\n  author    = {Fang, Shuangkang and Xu, Weixin and Wang, Heng and Yang, Yi and Wang, Yufeng and Zhou, Shuchang},\n  title     = {One is All: Bridging the Gap Between Neural Radiance Fields Architectures with Progressive Volume Distillation},\n  journal   = {AAAI},\n  year      = {2023}\n}\n```\n\n### Acknowledgements\nThanks to [ingp](https://github.com/NVlabs/instant-ngp),  [torch-ngp](https://github.com/ashawkey/torch-ngp), [TensoRF](https://github.com/apchenstu/TensoRF), [Plenoxels](https://github.com/sxyu/svox2), [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch) for their wonderful frameworks!\n\nYou can also refer to [Arch-Net](https://github.com/megvii-research/Arch-Net) for more ideas about progressive distillation.\n"
  }
]