Full Code of megvii-research/AAAI2023-PVD

Repository: megvii-research/AAAI2023-PVD
Branch: main
Commit: e5d0ab174f24
Files: 40
Total size: 415.1 KB

Directory structure:
AAAI2023-PVD/

├── LICENSE
├── README.md
├── distill_mutual/
│   ├── network.py
│   ├── provider.py
│   ├── renderer.py
│   └── utils.py
├── gridencoder/
│   ├── __init__.py
│   ├── backend.py
│   ├── grid.py
│   ├── setup.py
│   └── src/
│       ├── bindings.cpp
│       ├── gridencoder.cu
│       └── gridencoder.h
├── just_train_tea/
│   ├── network.py
│   ├── provider.py
│   ├── renderer.py
│   └── utils.py
├── main_distill_mutual.py
├── main_just_train_tea.py
├── raymarching/
│   ├── __init__.py
│   ├── backend.py
│   ├── raymarching.py
│   ├── setup.py
│   └── src/
│       ├── bindings.cpp
│       ├── pcg32.h
│       ├── raymarching.cu
│       └── raymarching.h
├── shencoder/
│   ├── __init__.py
│   ├── backend.py
│   ├── setup.py
│   ├── sphere_harmonics.py
│   └── src/
│       ├── bindings.cpp
│       ├── shencoder.cu
│       └── shencoder.h
└── tools/
    ├── activation.py
    ├── details.md
    ├── encoding.py
    ├── install_extensions.sh
    ├── requirements.txt
    └── 中文介绍.md

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
Copyright 2022 Megvii Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


================================================
FILE: README.md
================================================
## One is All: Bridging the Gap Between Neural Radiance Fields Architectures with Progressive Volume Distillation (AAAI 2023 Oral)


# :partying_face: ***New*** :partying_face: Code for the more powerful PVD-AL (the follow-up work of PVD) is now provided [here](https://github.com/megvii-research/AAAI2023-PVD/tree/PVD-AL).
*(We strongly recommend using [PVD-AL](https://github.com/megvii-research/AAAI2023-PVD/tree/PVD-AL), which achieves better performance.)*


## [Project Page](http://sk-fun.fun/PVD/) | [Paper](https://arxiv.org/abs/2211.15977) | [Datasets](https://drive.google.com/drive/folders/1U06KAEsW53PolLI3U8hWUhzzIH74QGaP?usp=sharing) | [Ckpts](https://drive.google.com/drive/folders/1GGJf-FTmpCJjmEn-AF_S9-HrLRkFe5Ud?usp=sharing) | [Chinese tutorial](https://github.com/megvii-research/AAAI2023-PVD/blob/main/tools/%E4%B8%AD%E6%96%87%E4%BB%8B%E7%BB%8D.md) | [zhihu](https://zhuanlan.zhihu.com/p/605121286)

## Introduction
In this paper, we propose Progressive Volume Distillation (PVD), a systematic distillation method that allows any-to-any conversion between different neural-field architectures, including MLPs (NeRF), sparse tensors (Plenoxels), low-rank tensors (TensoRF), and hash tables (INGP).

## Installation
We recommend using [Anaconda](https://www.anaconda.com/) to set up the environment. Run the following commands:

*Step 1*: Create a conda environment named 'pvd'
```
conda create --name pvd python=3.7
conda activate pvd
pip install -r ./tools/requirements.txt
```
*Step 2*: Install the extension modules. (These are drawn from the great [torch-ngp](https://github.com/ashawkey/torch-ngp) project that we build on.)
```
bash ./tools/install_extensions.sh
```
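
To sanity-check the build, try importing the compiled extensions (a minimal check, run from the repository root; the first import may trigger a just-in-time build via each module's `backend.py`):
```
python -c "import torch, gridencoder, shencoder, raymarching; print('extensions OK, CUDA:', torch.cuda.is_available())"
```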

## Datasets & pretrained teacher models
You can download Synthetic-NeRF/LLFF/Tanks&Temples datasets from [google](https://drive.google.com/drive/folders/1U06KAEsW53PolLI3U8hWUhzzIH74QGaP?usp=sharing), or from [baidu](https://pan.baidu.com/s/1ky_TWrbUZG_MpHTBhncAKA?pwd=4h2h).

And download the pretrained teacher models from [google](https://drive.google.com/drive/folders/1GGJf-FTmpCJjmEn-AF_S9-HrLRkFe5Ud?usp=sharing), or from [baidu](https://pan.baidu.com/s/1LGLXwLGusX60GpAywLwosg?pwd=34k8).

You can also train a teacher model by following the guidance below.

## Train a teacher
```
# train a hash-based (INGP) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type hash --data_type synthetic  --workspace ./log/train_teacher/hash_chair

# train a sparse-tensor-based (TensoRF VM-decomposition) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type vm --data_type synthetic  --workspace ./log/train_teacher/vm_chair

# train an MLP-based (NeRF) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type mlp --data_type synthetic  --workspace ./log/train_teacher/mlp_chair

# train a tensor-based (Plenoxels) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type tensors --data_type synthetic  --workspace ./log/train_teacher/tensors_chair

```

## Distill a student
```
# teacher: hash (INGP),  student: vm (TensoRF)
python3 main_distill_mutual.py  ./data/nerf_synthetic/chair \
                    --data_type synthetic \
                    --teacher_type hash \
                    --ckpt_teacher ./log/train_teacher/hash_chair/checkpoints/XXX.pth \
                    --model_type vm \
                    --workspace ./log/distill_student/hash2vm/chair
                    
# teacher: mlp (NeRF),  student: tensors (Plenoxels)
python3 main_distill_mutual.py  ./data/nerf_synthetic/chair \
                    --data_type synthetic \
                    --teacher_type mlp \
                    --ckpt_teacher ./log/train_teacher/mlp_chair/checkpoints/XXX.pth \
                    --model_type tensors \
                    --workspace ./log/distill_student/mlp2tensors/chair
                   
```
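
The same pattern covers the other teacher/student pairings: any of the four model types shown above (`hash`, `mlp`, `vm`, `tensors`) can be passed as `--teacher_type` and `--model_type` to realize the any-to-any conversions described in the introduction.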

## Evaluation

```
# evaluate a hash teacher
python main_distill_mutual.py ./data/nerf_synthetic/chair  --teacher_type hash --ckpt_teacher PATH/TO/CKPT.pth --test_teacher --data_type synthetic --workspace ./log/eval_teacher/hash_chair

# evaluate an mlp student
python main_distill_mutual.py ./data/nerf_synthetic/chair --model_type mlp --ckpt PATH/TO/CKPT.pth --test --data_type synthetic --workspace ./log/eval_student/mlp_chair
```

## More detailed parameter descriptions and running commands
Please refer to [more running description](https://github.com/megvii-research/AAAI2023-PVD/blob/main/tools/details.md) for details on training different types of datasets, parameter adjustment, key settings, etc.

## Citation

If you find our code or paper useful, please consider citing
```
@article{fang2022one,
  title={One is All: Bridging the Gap Between Neural Radiance Fields Architectures with Progressive Volume Distillation},
  author={Fang, Shuangkang and Xu, Weixin and Wang, Heng and Yang, Yi and Wang, Yufeng and Zhou, Shuchang},
  journal={arXiv preprint arXiv:2211.15977},
  year={2022}
}
```

### Acknowledgement
We would like to thank [ingp](https://github.com/NVlabs/instant-ngp),  [torch-ngp](https://github.com/ashawkey/torch-ngp), [TensoRF](https://github.com/apchenstu/TensoRF), [Plenoxels](https://github.com/sxyu/svox2), [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch)  for their great frameworks!

Also check out [Arch-Net](https://github.com/megvii-research/Arch-Net) for more on general progressive distillation.


================================================
FILE: distill_mutual/network.py
================================================
import torch
from time import time
import torch.nn as nn
import torch.nn.functional as F

from tools.encoding import get_encoder
from tools.activation import trunc_exp
from .renderer import NeRFRenderer
import raymarching


class NeRFNetwork(NeRFRenderer):
    def __init__(
        self,
        encoding="hashgrid",
        encoding_dir="sphere_harmonics",
        encoding_bg="hashgrid",
        num_layers=2,
        hidden_dim=64,
        geo_feat_dim=15,
        num_layers_color=3,
        hidden_dim_color=64,
        num_layers_bg=2,
        hidden_dim_bg=64,
        bound=1,
        model_type="hash",
        args=None,
        is_teacher=False,
        **kwargs,
    ):
        super().__init__(bound, **kwargs)
        # sigma network
        assert model_type in ["hash", "mlp", "vm", "tensors"]
        self.is_teacher = is_teacher
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.geo_feat_dim = geo_feat_dim
        self.args = args
        self.opt = args
        self.model_type = model_type

        self.plenoxel_degree = args.plenoxel_degree
        self.plenoxel_res = eval(args.plenoxel_res)

        assert len(self.plenoxel_res) == 3

        self.encoder, self.in_dim = get_encoder(
            encoding,
            desired_resolution=2048 * bound,
            num_levels=14,
        )

        if "hash" != self.model_type:
            self.encoder = None

        if self.model_type == "mlp":
            self.encoder_nerf_pe, self.in_dim_nerf = get_encoder(
                encoding="frequency", multires=self.args.PE
            )
            self.skips = self.args.skip
            self.nerf_layer_num = self.args.nerf_layer_num
            W = self.args.nerf_layer_wide
            self.nerf_mlp = [nn.Linear(self.in_dim_nerf, W)]
            for i in range(self.nerf_layer_num - 2):
                if i != self.skips:
                    self.nerf_mlp.append(nn.Linear(W, W))
                else:
                    self.nerf_mlp.append(nn.Linear(W + self.in_dim_nerf, W))
            self.nerf_mlp.append(nn.Linear(W, self.in_dim))
            self.nerf_mlp = nn.ModuleList(self.nerf_mlp)

        elif self.model_type == "vm":
            self.sigma_rank = [16] * 3
            self.color_rank = [48] * 3
            self.color_feat_dim = 15  # geo_feat_dim
            self.mat_ids = [[0, 1], [0, 2], [1, 2]]
            self.vec_ids = [2, 1, 0]
            self.resolution = [self.opt.resolution0] * 3
            # mat: paralist[1,16,res0,res0] repeat 3   vec: paralist[1,16,res0,1] repeat 3; repeat3 because decompose 3D grid [H, W, D] to three 2D mat [H, W], [H,D], [W, D] or decompose to three 1D vec [H], [W], [D]
            self.sigma_mat, self.sigma_vec = self.init_one_vm(
                self.sigma_rank, self.resolution
            )
            # mat: paralist[1,48,res0,res0] repeat 3   vec: paralist[1,48,res0,1] repeat 3
            self.color_mat, self.color_vec = self.init_one_vm(
                self.color_rank, self.resolution
            )
            # basis_mat: Linear(in_features=sum(color_rank)=144, out_features=color_feat_dim=15)
            self.basis_mat = nn.Linear(
                sum(self.color_rank), self.color_feat_dim, bias=False
            )
        elif self.model_type == "tensors":
            self.init_plenoxel_volume(
                s=0.02,
                fea_dim=self.plenoxel_degree ** 2 * 3 + 1,
                volume=self.plenoxel_res,
            )

        elif self.model_type == "hash":
            pass
        else:
            raise ValueError(f"error model_type:{self.model_type}")

        if self.model_type != "vm" and self.model_type != "tensors":
            sigma_net = []
            for l in range(num_layers):
                if l == 0:
                    in_dim = self.in_dim
                else:
                    in_dim = hidden_dim

                if l == num_layers - 1:
                    out_dim = (
                        1 + self.geo_feat_dim
                    )  # 1 sigma + 15 SH features for color
                else:
                    out_dim = hidden_dim

                sigma_net.append(nn.Linear(in_dim, out_dim, bias=False))

            self.sigma_net = nn.ModuleList(sigma_net)

        # color network
        self.num_layers_color = num_layers_color
        self.hidden_dim_color = hidden_dim_color
        # self.encoder_dir, self.in_dim_dir = get_encoder(encoding=encoding_dir)
        if self.model_type == "tensors":
            self.encoder_dir, self.in_dim_dir = get_encoder(
                encoding="sphere_harmonics",
                degree=self.plenoxel_degree,
            )

        else:
            self.encoder_dir, self.in_dim_dir = get_encoder(
                encoding=encoding_dir, input_dim=3, multires=2
            )

        if self.model_type != "tensors":
            color_net = []
            for l in range(num_layers_color):
                if l == 0:
                    in_dim = self.in_dim_dir + self.geo_feat_dim
                else:
                    in_dim = hidden_dim

                if l == num_layers_color - 1:
                    out_dim = 3  # 3 rgb
                else:
                    out_dim = hidden_dim

                color_net.append(nn.Linear(in_dim, out_dim, bias=False))

            self.color_net = nn.ModuleList(color_net)

        # background network
        if self.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg
            self.encoder_bg, self.in_dim_bg = get_encoder(
                encoding_bg,
                input_dim=2,
                num_levels=4,
                log2_hashmap_size=19,
                desired_resolution=2048,
            )  # much smaller hashgrid

            bg_net = []
            for l in range(num_layers_bg):
                if l == 0:
                    in_dim = self.in_dim_bg + self.in_dim_dir
                else:
                    in_dim = hidden_dim_bg

                if l == num_layers_bg - 1:
                    out_dim = 3  # 3 rgb
                else:
                    out_dim = hidden_dim_bg

                bg_net.append(nn.Linear(in_dim, out_dim, bias=False))

            self.bg_net = nn.ModuleList(bg_net)
        else:
            self.bg_net = None

    def init_plenoxel_volume(self, s=0.1, fea_dim=27 + 1, volume=[128, 128, 128]):
        tensor = []
        tensor.append(
            torch.nn.Parameter(
                s * torch.randn((1, fea_dim, volume[0], volume[1], volume[2]))
            )
        )
        self.tensor_volume = torch.nn.ParameterList(tensor).cuda()

    def init_one_vm(self, n_component, resolution, scale=0.1):
        # self.mat_ids = [[0, 1], [0, 2], [1, 2]]  self.vec_ids = [2, 1, 0]
        mat, vec = [], []

        for i in range(len(self.vec_ids)):
            vec_id = self.vec_ids[i]
            mat_id_0, mat_id_1 = self.mat_ids[i]
            mat.append(
                nn.Parameter(
                    scale
                    * torch.randn(
                        (1, n_component[i], resolution[mat_id_1], resolution[mat_id_0])
                    )
                )
            )  # [1, R, H, W]
            vec.append(
                nn.Parameter(
                    scale * torch.randn((1, n_component[i], resolution[vec_id], 1))
                )
            )  # [1, R, D, 1] (fake 2d to use grid_sample)

        return nn.ParameterList(mat), nn.ParameterList(vec)
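
    # Illustrative shapes, assuming resolution = [128] * 3 and n_component = [16] * 3:
    #   mat[i]: [1, 16, 128, 128] -- plane over the two axes in self.mat_ids[i]
    #   vec[i]: [1, 16, 128, 1]   -- line over the axis self.vec_ids[i] (fake 2D)
    # A scalar feature at point x is reconstructed as
    #   sum_i sum_r mat[i][r](x_plane) * vec[i][r](x_line),
    # which is the VM (vector-matrix) low-rank factorization from TensoRF.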

    def get_sigma_feat(self, x):
        # x: [N, 3], in [-1, 1] (outliers will be treated as zero due to grid_sample padding mode)
        # self.mat_ids = [[0, 1], [0, 2], [1, 2]]  self.vec_ids = [2, 1, 0]
        N = x.shape[0]

        # plane + line basis
        mat_coord = (
            torch.stack(
                (
                    x[..., self.mat_ids[0]],
                    x[..., self.mat_ids[1]],
                    x[..., self.mat_ids[2]],
                )
            )
            .detach()
            .view(3, -1, 1, 2)
        )  # [3, N, 1, 2]
        vec_coord = torch.stack(
            (x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])
        )
        vec_coord = (
            torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1)
            .detach()
            .view(3, -1, 1, 2)
        )  # [3, N, 1, 2], fake 2d coord

        sigma_feat = torch.zeros(
            [
                N,
            ],
            device=x.device,
        )

        for i in range(len(self.sigma_mat)):
            mat_feat = F.grid_sample(
                self.sigma_mat[i], mat_coord[[i]], align_corners=True
            ).view(
                -1, N
            )  # [1, R, N, 1] --> [R, N]
            vec_feat = F.grid_sample(
                self.sigma_vec[i], vec_coord[[i]], align_corners=True
            ).view(
                -1, N
            )  # [R, N]
            sigma_feat = sigma_feat + torch.sum(mat_feat * vec_feat, dim=0)

        return sigma_feat

    def get_color_feat(self, x):
        # x: [N, 3], in [-1, 1]
        N = x.shape[0]

        # plane + line basis
        mat_coord = (
            torch.stack(
                (
                    x[..., self.mat_ids[0]],
                    x[..., self.mat_ids[1]],
                    x[..., self.mat_ids[2]],
                )
            )
            .detach()
            .view(3, -1, 1, 2)
        )  # [3, N, 1, 2]
        vec_coord = torch.stack(
            (x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])
        )
        vec_coord = (
            torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1)
            .detach()
            .view(3, -1, 1, 2)
        )  # [3, N, 1, 2], fake 2d coord

        mat_feat, vec_feat = [], []
        for i in range(len(self.color_mat)):
            mat_feat.append(
                F.grid_sample(
                    self.color_mat[i], mat_coord[[i]], align_corners=True
                ).view(-1, N)
            )  # [1, R, N, 1] --> [R, N]
            vec_feat.append(
                F.grid_sample(
                    self.color_vec[i], vec_coord[[i]], align_corners=True
                ).view(-1, N)
            )  # [R, N]

        mat_feat = torch.cat(mat_feat, dim=0)  # [3 * R, N]
        vec_feat = torch.cat(vec_feat, dim=0)  # [3 * R, N]

        color_feat = self.basis_mat(
            (mat_feat * vec_feat).T
        )  # [N, 3R] --> [N, color_feat_dim]

        return color_feat

    def compute_plenoxel_fea(self, x):
        composed = self.tensor_volume[0]
        if self.args.enable_edit_plenoxel and self.is_teacher:
            composed[
                :, 0, :, 160:, :128
            ] = -100  # This will erase the bucket in the lego scene for resolution 256
        composed = (
            F.grid_sample(composed, x.view(1, 1, -1, 1, 3), align_corners=True)
            .view(-1, x.shape[0])
            .permute(1, 0)
        )
        return composed  # [N, fea_dim]
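
    # Note: each voxel of tensor_volume holds 1 density value plus
    # 3 * plenoxel_degree**2 spherical-harmonic coefficients (one SH set per
    # RGB channel); grid_sample above trilinearly interpolates all channels.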

    def forward_nerf_mlp(self, x):
        x = self.encoder_nerf_pe(x)
        in_pts = x
        for i in range(len(self.nerf_mlp)):
            x = self.nerf_mlp[i](x)
            if i != len(self.nerf_mlp) - 1:
                x = F.relu(x, inplace=True)
            if i == self.skips:
                x = torch.cat([in_pts, x], -1)
        return x
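
    # The loop above implements the standard NeRF skip connection: after the
    # layer at index self.skips, the positional encoding of the input is
    # concatenated back in, matching the wider Linear(W + in_dim_nerf, W)
    # created in __init__.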

    def forward(self, x, d):
        # x: [N, 3], in [-bound, bound];  d: [N, 3], normalized, in [-1, 1]
        # sigma
        if self.model_type == "hash":
            x = self.encoder(
                x, bound=self.bound
            )  # out_x[N, 28=num_levels * fea_per_level]
        elif self.model_type == "mlp":
            x = self.forward_nerf_mlp(x)  # 28
        elif self.model_type == "vm":
            x = (
                2
                * (x - self.aabb_train[:3])
                / (self.aabb_train[3:] - self.aabb_train[:3])
                - 1
            )  # x:[N, 3]
            sigma_feat = self.get_sigma_feat(x)  # sigma_feat:[N]
            color_feat = self.get_color_feat(x)  # color_feat:[N, 15]
            if self.opt.enable_edit_plenoxel:
                sigma_feat = torch.clamp(sigma_feat, -100, self.args.sigma_clip_max)
            else:
                sigma_feat = torch.clamp(
                    sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max
                )
            color_feat = torch.clamp(
                color_feat, self.args.sigma_clip_min, self.args.sigma_clip_max
            )
            self.feature_sigma_color = torch.cat(
                [sigma_feat.unsqueeze(-1), color_feat], dim=-1
            )
            if (
                self.training
                and self.args.global_step < self.args.stage_iters["stage1"]
            ):
                return None, None
            self.sigma_l = sigma_feat
            sigma = trunc_exp(sigma_feat)  # sigma:[N]
            enc_d = self.encoder_dir(d)  # enc_d:[N, 16]
            h = torch.cat([enc_d, color_feat], dim=-1)  # h:[N, 16+15]
            for l in range(self.num_layers_color):
                h = self.color_net[l](h)
                if l != self.num_layers_color - 1:
                    h = F.relu(h, inplace=True)

            color = torch.sigmoid(h)
            self.color_l = color

            return sigma, color
        elif self.model_type == "tensors":
            x = (
                2
                * (x - self.aabb_train[:3])
                / (self.aabb_train[3:] - self.aabb_train[:3])
                - 1
            )  # x:[N, 3]
            x = self.compute_plenoxel_fea(x)
            h = x
            if self.opt.enable_edit_plenoxel:
                sigma = torch.clamp(h[..., 0], -100, self.args.sigma_clip_max)
            else:
                sigma = torch.clamp(
                    h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max
                )
            self.sigma_l = sigma
            sigma = trunc_exp(sigma)
            self.sigma = sigma
            sh = h[..., 1:].view(
                -1, 3, self.plenoxel_degree ** 2
            )  # [N, 3, 9]   ## .permute(1, 0, 2)  # [B, 27]-->[9, B, 3]
            enc_d = self.encoder_dir(d).unsqueeze(1)  # [N, 9]-->[N,1,9]
            color = (sh * enc_d).sum(-1)  # [N, 3]
            color = torch.sigmoid(color)
            self.feature_sigma_color = None
            self.color_l = color
            return sigma, color
        else:
            raise ValueError(f"illegal model_type: {self.model_type}")

        h = x
        for l in range(self.num_layers):
            h = self.sigma_net[l](h)
            if l != self.num_layers - 1:
                h = F.relu(h, inplace=True)
        h[..., 0] = torch.clamp(
            h[..., 0].clone(), self.args.sigma_clip_min, self.args.sigma_clip_max
        )
        self.feature_sigma_color = h
        if self.training and self.args.global_step < self.args.stage_iters["stage1"]:
            return None, None
        self.sigma_l = h[..., 0]
        sigma = trunc_exp(h[..., 0])  # sigma: [n]
        geo_feat = h[..., 1:]  # geo_feat: [n, 15]

        d = self.encoder_dir(d)  # d: [n, 16]
        h = torch.cat([d, geo_feat], dim=-1)  # h: [n, 15+16]
        for l in range(self.num_layers_color):
            h = self.color_net[l](h)
            if l != self.num_layers_color - 1:
                h = F.relu(h, inplace=True)

        color = torch.sigmoid(h)
        self.color_l = color
        return sigma, color

    def density(self, x):
        # x: [N, 3], in [-bound, bound]
        if self.model_type == "hash":
            x = self.encoder(
                x, bound=self.bound
            )  # out_x[N, 28=num_levels * fea_per_level]
        elif self.model_type == "mlp":
            x = self.forward_nerf_mlp(x)
        elif self.model_type == "vm":
            x = (
                2
                * (x - self.aabb_train[:3])
                / (self.aabb_train[3:] - self.aabb_train[:3])
                - 1
            )
            sigma_feat = self.get_sigma_feat(x)
            sigma_feat = torch.clamp(
                sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max
            )
            sigma = trunc_exp(sigma_feat)
            return {"sigma": sigma}
        elif self.model_type == "tensors":
            x = (
                2
                * (x - self.aabb_train[:3])
                / (self.aabb_train[3:] - self.aabb_train[:3])
                - 1
            )  # x:[N, 3]
            x = self.compute_plenoxel_fea(x)
            h = x
            sigma = trunc_exp(
                torch.clamp(
                    h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max
                )
            )
            return {"sigma": sigma}

        else:
            raise ValueError(f"illegal model_type: {self.model_type}")

        h = x
        for l in range(self.num_layers):
            h = self.sigma_net[l](h)
            if l != self.num_layers - 1:
                h = F.relu(h, inplace=True)

        h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max)
        sigma = trunc_exp(h[..., 0])
        geo_feat = h[..., 1:]

        return {
            "sigma": sigma,
            "geo_feat": geo_feat,
        }

    def background(self, x, d):
        raise NotImplementedError("background model is not used (bg_radius <= 0 in this codebase)")
        # x: [N, 2], in [-1, 1]

        h = self.encoder_bg(x)  # [N, C]
        d = self.encoder_dir(d)

        h = torch.cat([d, h], dim=-1)
        for l in range(self.num_layers_bg):
            h = self.bg_net[l](h)
            if l != self.num_layers_bg - 1:
                h = F.relu(h, inplace=True)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)

        return rgbs

    # allow masked inference
    def color(self, x, d, mask=None, geo_feat=None, **kwargs):
        raise NotImplementedError("masked color() is not used in this codebase")
        # x: [N, 3] in [-bound, bound]
        # mask: [N,], bool, indicates where we actually needs to compute rgb.

        if mask is not None:
            rgbs = torch.zeros(
                mask.shape[0], 3, dtype=x.dtype, device=x.device
            )  # [N, 3]
            # in case of empty mask
            if not mask.any():
                return rgbs
            x = x[mask]
            d = d[mask]
            geo_feat = geo_feat[mask]

        d = self.encoder_dir(d)
        h = torch.cat([d, geo_feat], dim=-1)
        for l in range(self.num_layers_color):
            h = self.color_net[l](h)
            if l != self.num_layers_color - 1:
                h = F.relu(h, inplace=True)

        # sigmoid activation for rgb
        h = torch.sigmoid(h)

        if mask is not None:
            rgbs[mask] = h.to(rgbs.dtype)  # fp16 --> fp32
        else:
            rgbs = h

        return rgbs

    # L1 sparsity penalty on the density factors (only meaningful for the "vm" model)
    def density_loss(self):
        loss = 0
        for i in range(len(self.sigma_mat)):
            loss = (
                loss
                + torch.mean(torch.abs(self.sigma_mat[i]))
                + torch.mean(torch.abs(self.sigma_vec[i]))
            )
        return loss

    # upsample utils
    @torch.no_grad()
    def upsample_params(self, mat, vec, resolution):

        for i in range(len(self.vec_ids)):
            vec_id = self.vec_ids[i]
            mat_id_0, mat_id_1 = self.mat_ids[i]
            mat[i] = nn.Parameter(
                F.interpolate(
                    mat[i].data,
                    size=(resolution[mat_id_1], resolution[mat_id_0]),
                    mode="bilinear",
                    align_corners=True,
                )
            )
            vec[i] = nn.Parameter(
                F.interpolate(
                    vec[i].data,
                    size=(resolution[vec_id], 1),
                    mode="bilinear",
                    align_corners=True,
                )
            )

    @torch.no_grad()
    def upsample_model(self, resolution):
        self.upsample_params(self.sigma_mat, self.sigma_vec, resolution)
        self.upsample_params(self.color_mat, self.color_vec, resolution)
        self.resolution = resolution

    @torch.no_grad()
    def shrink_model(self):
        # shrink aabb_train and the model so it only represents the space inside aabb_train.

        half_grid_size = self.bound / self.grid_size
        thresh = min(self.density_thresh, self.mean_density)

        # get new aabb from the coarsest density grid (TODO: from the finest that covers current aabb?)
        valid_grid = self.density_grid[self.cascade - 1] > thresh  # [N]
        valid_pos = raymarching.morton3D_invert(
            torch.nonzero(valid_grid)
        )  # [Nz] --> [Nz, 3], in [0, H - 1]
        # plot_pointcloud(valid_pos.detach().cpu().numpy()) # lots of noisy outliers in hashnerf...
        valid_pos = (2 * valid_pos / (self.grid_size - 1) - 1) * (
            self.bound - half_grid_size
        )  # [Nz, 3], in [-b+hgs, b-hgs]
        min_pos = valid_pos.amin(0) - half_grid_size  # [3]
        max_pos = valid_pos.amax(0) + half_grid_size  # [3]

        # shrink model
        reso = torch.LongTensor(self.resolution).to(self.aabb_train.device)
        units = (self.aabb_train[3:] - self.aabb_train[:3]) / reso
        tl = (min_pos - self.aabb_train[:3]) / units
        br = (max_pos - self.aabb_train[:3]) / units
        tl = torch.round(tl).long().clamp(min=0)
        br = torch.minimum(torch.round(br).long(), reso)

        for i in range(len(self.vec_ids)):
            vec_id = self.vec_ids[i]
            mat_id_0, mat_id_1 = self.mat_ids[i]

            self.sigma_vec[i] = nn.Parameter(
                self.sigma_vec[i].data[..., tl[vec_id] : br[vec_id], :]
            )
            self.color_vec[i] = nn.Parameter(
                self.color_vec[i].data[..., tl[vec_id] : br[vec_id], :]
            )

            self.sigma_mat[i] = nn.Parameter(
                self.sigma_mat[i].data[
                    ..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0]
                ]
            )
            self.color_mat[i] = nn.Parameter(
                self.color_mat[i].data[
                    ..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0]
                ]
            )

        self.aabb_train = torch.cat([min_pos, max_pos], dim=0)  # [6]

        print(
            f"[INFO] shrink slice: {tl.cpu().numpy().tolist()} - {br.cpu().numpy().tolist()}"
        )
        print(f"[INFO] new aabb: {self.aabb_train.cpu().numpy().tolist()}")

    # optimizer utils
    def get_params(self, lr, lr2=1e-3):
        if self.model_type == "hash":
            params = [
                {"params": self.encoder.parameters(), "lr": lr},
                {"params": self.sigma_net.parameters(), "lr": lr},
                {"params": self.encoder_dir.parameters(), "lr": lr},
                {"params": self.color_net.parameters(), "lr": lr},
            ]
        elif self.model_type == "mlp":
            params = [
                {"params": self.sigma_net.parameters(), "lr": lr},
                {"params": self.encoder_dir.parameters(), "lr": lr},
                {"params": self.color_net.parameters(), "lr": lr},
                {"params": self.nerf_mlp.parameters(), "lr": lr},
            ]
        elif self.model_type == "vm":
            params = [
                {"params": self.color_net.parameters(), "lr": lr2},
                {"params": self.sigma_mat, "lr": lr},
                {"params": self.sigma_vec, "lr": lr},
                {"params": self.color_mat, "lr": lr},
                {"params": self.color_vec, "lr": lr},
                {"params": self.basis_mat.parameters(), "lr": lr2},
            ]
        elif self.model_type == "tensors":
            params = [
                {"params": self.tensor_volume.parameters(), "lr": lr},
                {"params": self.encoder_dir.parameters(), "lr": lr},
            ]

        else:
            raise ValueError(f"illegal model_type: {self.model_type}")

        if self.bg_radius > 0:
            params.append({"params": self.encoder_bg.parameters(), "lr": lr})
            params.append({"params": self.bg_net.parameters(), "lr": lr})

        return params


================================================
FILE: distill_mutual/provider.py
================================================
import os
import cv2
import glob
import json
import tqdm
import numpy as np
from scipy.spatial.transform import Slerp, Rotation

import trimesh

import torch
from torch.utils.data import DataLoader

from .utils import get_rays, srgb_to_linear


# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
def nerf_matrix_to_ngp(pose, scale=0.33):
    # for the fox dataset, 0.33 scales camera radius to ~ 2
    new_pose = np.array(
        [
            [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale],
            [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale],
            [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale],
            [0, 0, 0, 1],
        ],
        dtype=np.float32,
    )
    return new_pose
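
# Hedged note on the conversion above: the row permutation cycles the axes
# (x, y, z) -> (y, z, x) and the sign flips negate the second and third
# rotation columns, mapping NeRF/Blender "transform_matrix" poses into
# instant-ngp's camera convention; `scale` shrinks the camera distance so the
# scene fits inside the unit bound.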


def rand_poses(
    size,
    device,
    radius=1,
    theta_range=[np.pi / 3, 2 * np.pi / 3],
    phi_range=[0, 2 * np.pi],
):
    """generate random poses from an orbit camera
    Args:
        size: batch size of generated poses.
        device: where to allocate the output.
        radius: camera radius
        theta_range: [min, max], should be in [0, \pi]
        phi_range: [min, max], should be in [0, 2\pi]
    Return:
        poses: [size, 4, 4]
    """

    def normalize(vectors):
        return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)

    thetas = (
        torch.rand(size, device=device) * (theta_range[1] - theta_range[0])
        + theta_range[0]
    )
    phis = (
        torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
    )

    centers = torch.stack(
        [
            radius * torch.sin(thetas) * torch.sin(phis),
            radius * torch.cos(thetas),
            radius * torch.sin(thetas) * torch.cos(phis),
        ],
        dim=-1,
    )  # [B, 3]

    # lookat
    forward_vector = -normalize(centers)
    up_vector = (
        torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
    )  # confused at the coordinate system...
    right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1))
    up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1))

    poses = (
        torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
    )
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    return poses


class NeRFDataset:
    def __init__(self, opt, device, type="train", downscale=1, n_test=10):
        super().__init__()

        self.opt = opt
        self.args = opt
        self.device = device
        self.type = type  # train, val, test
        self.downscale = downscale
        self.root_path = opt.path
        self.mode = opt.mode  # only support blender
        self.preload = opt.preload  # preload data into GPU
        self.scale = (
            opt.scale
        )  # camera radius scale to make sure camera are inside the bounding box.
        self.bound = (
            opt.bound
        )  # bounding box half length, also used as the radius to random sample poses.
        self.fp16 = opt.fp16  # if preload, load into fp16.

        self.training = self.type in ["train", "all", "trainval"]
        self.num_rays = self.opt.num_rays if self.training else -1

        if self.mode == "blender":
            if type == "all":
                transform_paths = glob.glob(os.path.join(self.root_path, "*.json"))
                transform = None
                for transform_path in transform_paths:
                    with open(transform_path, "r") as f:
                        tmp_transform = json.load(f)
                        if transform is None:
                            transform = tmp_transform
                        else:
                            transform["frames"].extend(tmp_transform["frames"])
            # load train and val split
            elif type == "trainval":
                with open(
                    os.path.join(self.root_path, f"transforms_train.json"), "r"
                ) as f:
                    transform = json.load(f)
                with open(
                    os.path.join(self.root_path, f"transforms_val.json"), "r"
                ) as f:
                    transform_val = json.load(f)
                transform["frames"].extend(transform_val["frames"])
            # only load one specified split
            else:
                with open(
                    os.path.join(self.root_path, f"transforms_{type}.json"), "r"
                ) as f:
                    transform = json.load(f)

        else:
            raise NotImplementedError(f"unknown dataset mode: {self.mode}")

        # load image size
        if "h" in transform and "w" in transform:
            self.H = int(transform["h"]) // downscale
            self.W = int(transform["w"]) // downscale
        else:
            # we have to actually read an image to get H and W later.
            self.H = self.W = None
        # read images
        frames = transform["frames"]
        if True:
            self.poses = []
            self.images = []
            for f in tqdm.tqdm(frames, desc=f"Loading {type} data:"):
                f_path = os.path.join(self.root_path, f["file_path"])
                if (
                    self.mode == "blender"
                    and f_path[-4:].lower() != ".png"
                    and f_path[-4:].lower() != ".jpg"
                ):
                    f_path += ".png"  # so silly...
                if not os.path.exists(f_path):
                    continue
                pose = np.array(f["transform_matrix"], dtype=np.float32)  # [4, 4]
                pose = nerf_matrix_to_ngp(pose, scale=self.scale)

                image = cv2.imread(
                    f_path, cv2.IMREAD_UNCHANGED
                )  # [H, W, 3] or [H, W, 4]
                if self.H is None or self.W is None:
                    self.H = image.shape[0] // downscale
                    self.W = image.shape[1] // downscale

                # add support for the alpha channel as a mask.
                if image.shape[-1] == 3:
                    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                else:
                    image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)

                if image.shape[0] != self.H or image.shape[1] != self.W:
                    image = cv2.resize(
                        image, (self.W, self.H), interpolation=cv2.INTER_AREA
                    )

                image = image.astype(np.float32) / 255  # [H, W, 3/4]

                self.poses.append(pose)
                self.images.append(image)
        self.poses = torch.from_numpy(np.stack(self.poses, axis=0))  # [N, 4, 4]
        if self.images is not None:
            self.images = torch.from_numpy(
                np.stack(self.images, axis=0)
            )  # [N, H, W, C]
        self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item()

        if self.training and self.opt.error_map:
            self.error_map = torch.ones(
                [self.images.shape[0], 128 * 128], dtype=torch.float
            )  # [B, 128 * 128], flattened for easy indexing, fixed resolution...
        else:
            self.error_map = None

        if self.preload:
            self.poses = self.poses.to(self.device)
            if self.images is not None:
                if self.fp16 and self.opt.color_space != "linear":
                    dtype = torch.half
                else:
                    dtype = torch.float
                self.images = self.images.to(dtype).to(self.device)
            if self.error_map is not None:
                self.error_map = self.error_map.to(self.device)

        # load intrinsics
        if "fl_x" in transform or "fl_y" in transform:
            fl_x = (
                transform["fl_x"] if "fl_x" in transform else transform["fl_y"]
            ) / downscale
            fl_y = (
                transform["fl_y"] if "fl_y" in transform else transform["fl_x"]
            ) / downscale
        elif "camera_angle_x" in transform or "camera_angle_y" in transform:
            # blender, assert in radians. already downscaled since we use H/W
            fl_x = (
                self.W / (2 * np.tan(transform["camera_angle_x"] / 2))
                if "camera_angle_x" in transform
                else None
            )
            fl_y = (
                self.H / (2 * np.tan(transform["camera_angle_y"] / 2))
                if "camera_angle_y" in transform
                else None
            )
            if fl_x is None:
                fl_x = fl_y
            if fl_y is None:
                fl_y = fl_x
        else:
            raise RuntimeError(
                "Failed to load focal length, please check the transforms.json!"
            )

        cx = (transform["cx"] / downscale) if "cx" in transform else (self.W / 2)
        cy = (transform["cy"] / downscale) if "cy" in transform else (self.H / 2)

        self.intrinsics = np.array([fl_x, fl_y, cx, cy])

    def collate(self, index):

        B = len(index)  # a list of length 1
        poses = self.poses[index].to(self.device)  # [B, 4, 4]

        error_map = None if self.error_map is None else self.error_map[index]
        rays = get_rays(
            poses, self.intrinsics, self.H, self.W, self.num_rays, error_map
        )
        results = {
            "H": self.H,
            "W": self.W,
            "rays_o": rays["rays_o"],
            "rays_d": rays["rays_d"],
        }

        if self.images is not None:
            images = self.images[index].to(self.device)  # [B, H, W, 3/4]
            if self.training:
                C = images.shape[-1]
                images = torch.gather(
                    images.view(B, -1, C), 1, torch.stack(C * [rays["inds"]], -1)
                )  # [B, N, 3/4]
            results["images"] = images

        # need inds to update error_map
        if error_map is not None:
            results["index"] = index
            results["inds_coarse"] = rays["inds_coarse"]

        return results

    def dataloader(self):
        size = len(self.poses)
        loader = DataLoader(
            list(range(size)),
            batch_size=1,
            collate_fn=self.collate,
            shuffle=self.training,
            num_workers=0,
        )
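        # attach the dataset itself so the trainer can reach extra attributes
        # (intrinsics, poses, error_map) through loader._data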
        loader._data = self
        return loader


================================================
FILE: distill_mutual/renderer.py
================================================
import math
import trimesh
import numpy as np
from time import time

import torch
import torch.nn as nn
import torch.nn.functional as F

import raymarching
from .utils import custom_meshgrid
from IPython import embed


def sample_pdf(bins, weights, n_samples, det=False):
    # This implementation is from NeRF
    # bins: [B, T], old_z_vals
    # weights: [B, T - 1], bin weights.
    # return: [B, n_samples], new_z_vals

    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
    # Take uniform samples
    if det:
        u = torch.linspace(
            0.0 + 0.5 / n_samples, 1.0 - 0.5 / n_samples, steps=n_samples
        ).to(weights.device)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (B, n_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
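
# Hedged usage sketch (shapes only): given per-ray bin midpoints z_mid [N, T-1]
# and the interior weights w = weights[:, 1:-1] [N, T-2] from a coarse pass,
#   new_z = sample_pdf(z_mid, w, n_samples, det=True)
# draws n_samples extra depths per ray by inverse-CDF sampling, concentrating
# them where the coarse weights are large (NeRF's hierarchical sampling).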


def plot_pointcloud(pc, color=None):
    # pc: [N, 3]
    # color: [N, 3/4]
    print("[visualize points]", pc.shape, pc.dtype, pc.min(0), pc.max(0))
    pc = trimesh.PointCloud(pc, color)
    # axis
    axes = trimesh.creation.axis(axis_length=4)
    # sphere
    sphere = trimesh.creation.icosphere(radius=1)
    trimesh.Scene([pc, axes, sphere]).show()


class NeRFRenderer(nn.Module):
    def __init__(
        self,
        bound=1,
        cuda_ray=False,
        density_scale=1,  # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance.
        min_near=0.2,
        density_thresh=0.01,
        bg_radius=-1,
        grid_size=128,
    ):
        super().__init__()

        print("\n---------------", grid_size, "--------------\n")
        self.bound = bound
        self.cascade = 1 + math.ceil(math.log2(bound))
        self.grid_size = grid_size
        self.density_scale = density_scale
        self.min_near = min_near
        self.density_thresh = density_thresh
        self.bg_radius = bg_radius  # radius of the background sphere.

        # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
        # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
        aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound])
        aabb_infer = aabb_train.clone()
        self.register_buffer("aabb_train", aabb_train)
        self.register_buffer("aabb_infer", aabb_infer)

        # extra state for cuda raymarching
        self.cuda_ray = cuda_ray
        if cuda_ray:
            # density grid
            density_grid = torch.zeros(
                [self.cascade, self.grid_size ** 3]
            )  # [CAS, H * H * H]
            density_bitfield = torch.zeros(
                self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8
            )  # [CAS * H * H * H // 8]
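            # one bit per grid cell (each uint8 packs 8 cells); derived from
            # density_grid and consumed by the CUDA marching kernels to skip
            # empty space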
            self.register_buffer("density_grid", density_grid)
            self.register_buffer("density_bitfield", density_bitfield)
            self.mean_density = 0
            self.iter_density = 0
            # step counter
            step_counter = torch.zeros(
                16, 2, dtype=torch.int32
            )  # 16 is hardcoded for averaging...
            self.register_buffer("step_counter", step_counter)
            self.mean_count = 0
            self.local_step = 0

    def forward(self, x, d):
        raise NotImplementedError()

    # separated density and color query (can accelerate non-cuda-ray mode.)
    def density(self, x):
        raise NotImplementedError()

    def color(self, x, d, mask=None, **kwargs):
        raise NotImplementedError()

    def reset_extra_state(self):
        if not self.cuda_ray:
            return
        # density grid
        self.density_grid.zero_()
        self.mean_density = 0
        self.iter_density = 0
        # step counter
        self.step_counter.zero_()
        self.mean_count = 0
        self.local_step = 0

    def run(
        self,
        rays_o,
        rays_d,
        num_steps=128,
        upsample_steps=128,
        bg_color=None,
        perturb=False,
        **kwargs
    ):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # bg_color: [3] in range [0, 1]
        # return: image: [B, N, 3], depth: [B, N]
        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        # choose aabb
        aabb = self.aabb_train if self.training else self.aabb_infer

        # sample steps
        nears, fars = raymarching.near_far_from_aabb(
            rays_o, rays_d, aabb, self.min_near
        )
        nears.unsqueeze_(-1)
        fars.unsqueeze_(-1)

        # print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')

        z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(
            0
        )  # [1, T]
        z_vals = z_vals.expand((N, num_steps))  # [N, T]
        z_vals = nears + (fars - nears) * z_vals  # [N, T], in [nears, fars]

        # perturb z_vals
        sample_dist = (fars - nears) / num_steps
        if perturb:
            z_vals = (
                z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
            )
            # z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.

        # generate xyzs
        xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(
            -1
        )  # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
        xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:])  # a manual clip.

        # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())

        # query density (sigma) at the sampled points
        density_outputs = self.density(xyzs.reshape(-1, 3))

        # sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(N, num_steps, -1)

        # upsample z_vals (nerf-like)
        if upsample_steps > 0:
            with torch.no_grad():

                deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T-1]
                deltas = torch.cat(
                    [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1
                )

                alphas = 1 - torch.exp(
                    -deltas * self.density_scale * density_outputs["sigma"].squeeze(-1)
                )  # [N, T]
                alphas_shifted = torch.cat(
                    [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1
                )  # [N, T+1]
                weights = (
                    alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]
                )  # [N, T]

                # sample new z_vals
                z_vals_mid = z_vals[..., :-1] + 0.5 * deltas[..., :-1]  # [N, T-1]
                new_z_vals = sample_pdf(
                    z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training
                ).detach()  # [N, t]

                new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(
                    -2
                ) * new_z_vals.unsqueeze(
                    -1
                )  # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
                new_xyzs = torch.min(
                    torch.max(new_xyzs, aabb[:3]), aabb[3:]
                )  # a manual clip.

            # only forward new points to save computation
            new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
            # new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
            for k, v in new_density_outputs.items():
                new_density_outputs[k] = v.view(N, upsample_steps, -1)

            # re-order
            z_vals = torch.cat([z_vals, new_z_vals], dim=1)  # [N, T+t]
            z_vals, z_index = torch.sort(z_vals, dim=1)

            xyzs = torch.cat([xyzs, new_xyzs], dim=1)  # [N, T+t, 3]
            xyzs = torch.gather(
                xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs)
            )

            for k in density_outputs:
                tmp_output = torch.cat(
                    [density_outputs[k], new_density_outputs[k]], dim=1
                )
                density_outputs[k] = torch.gather(
                    tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output)
                )

        deltas = z_vals[..., 1:] - z_vals[..., :-1]  # [N, T+t-1]
        deltas = torch.cat(
            [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1
        )
        alphas = 1 - torch.exp(
            -deltas * self.density_scale * density_outputs["sigma"].squeeze(-1)
        )  # [N, T+t]
        alphas_shifted = torch.cat(
            [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1
        )  # [N, T+t+1]
        weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]  # [N, T+t]

        dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
        for k, v in density_outputs.items():
            density_outputs[k] = v.view(-1, v.shape[-1])

        mask = weights > 1e-4  # hard coded
        rgbs = self.color(
            xyzs.reshape(-1, 3),
            dirs.reshape(-1, 3),
            mask=mask.reshape(-1),
            **density_outputs
        )
        rgbs = rgbs.view(N, -1, 3)  # [N, T+t, 3]

        # print(xyzs.shape, 'valid_rgb:', mask.sum().item())

        # calculate weight_sum (mask)
        weights_sum = weights.sum(dim=-1)  # [N]

        # calculate depth
        ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
        depth = torch.sum(weights * ori_z_vals, dim=-1)

        # calculate color
        image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2)  # [N, 3], in [0, 1]

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            polar = raymarching.polar_from_ray(
                rays_o, rays_d, self.bg_radius
            )  # [N, 2] in [-1, 1]
            bg_color = self.background(polar, rays_d.reshape(-1, 3))  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color

        image = image.view(*prefix, 3)
        depth = depth.view(*prefix)

        # tmp: reg loss in mip-nerf 360
        # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1)
        # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T]
        # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals) * (weights ** 2)).sum()

        return {
            "depth": depth,
            "image": image,
        }

    def run_cuda(
        self,
        rays_o,
        rays_d,
        dt_gamma=0,
        bg_color=None,
        perturb=False,
        force_all_rays=False,
        max_steps=1024,
        inherited_params=[],
        **kwargs
    ):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # return: image: [B, N, 3], depth: [B, N]

        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        # pre-calculate near far
        nears, fars = raymarching.near_far_from_aabb(
            rays_o,
            rays_d,
            self.aabb_train if self.training else self.aabb_infer,
            self.min_near,
        )

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            polar = raymarching.polar_from_ray(
                rays_o, rays_d, self.bg_radius
            )  # [N, 2] in [-1, 1]
            bg_color = self.background(polar, rays_d)  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        if self.training:  # differs from testing
            # setup counter
            time1 = time()
            counter = self.step_counter[self.local_step % 16]
            counter.zero_()  # set to 0
            self.local_step += 1
            if (
                self.args.render_stu_first
            ):  # if the student goes first, the student computes the xyzs and the teacher inherits them
                """
                About xyzs, dirs, deltas, rays:
                    xyzs, dirs are the spatial points (and view directions) sampled along rays_o, rays_d;
                    rays: rays[i] = (ray_index, offset, count), i.e. xyzs[rays[i, 1] : rays[i, 1] + rays[i, 2]] are the points belonging to ray rays[i, 0];
                    deltas: shape [point_nums, 2], the per-point step sizes (first column used for RGB compositing, second for depth).
                """
                if not self.is_teacher:
                    xyzs, dirs, deltas, rays = raymarching.march_rays_train(
                        rays_o,
                        rays_d,
                        self.bound,
                        self.density_bitfield,
                        self.cascade,
                        self.grid_size,
                        nears,
                        fars,
                        counter,
                        self.mean_count,
                        perturb,
                        128,
                        force_all_rays,
                        dt_gamma,
                        max_steps,
                    )
                    inherited_params = [xyzs, dirs, deltas, rays]
                else:
                    xyzs, dirs, deltas, rays = inherited_params
            else:
                if self.is_teacher:
                    xyzs, dirs, deltas, rays = raymarching.march_rays_train(
                        rays_o,
                        rays_d,
                        self.bound,
                        self.density_bitfield,
                        self.cascade,
                        self.grid_size,
                        nears,
                        fars,
                        counter,
                        self.mean_count,
                        perturb,
                        128,
                        force_all_rays,
                        dt_gamma,
                        max_steps,
                    )
                    inherited_params = [xyzs, dirs, deltas, rays]
                else:
                    xyzs, dirs, deltas, rays = inherited_params

            # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
            sigmas, rgbs = self(xyzs, dirs)

            if self.args.global_step < self.args.stage_iters["stage1"]:
                return {
                    "stage1": self.args.global_step,
                    "depth": None,
                    "image": None,
                    "inherited_params": inherited_params,
                    "sigmas": sigmas,
                    "rays": rays,
                }
            elif self.args.global_step < self.args.stage_iters["stage2"]:
                return {
                    "stage2": self.args.global_step,
                    "depth": None,
                    "image": None,
                    "inherited_params": inherited_params,
                    "sigmas": sigmas,
                    "rays": rays,
                }

            sigmas = self.density_scale * sigmas

            weights_sum, depth, image = raymarching.composite_rays_train(
                sigmas, rgbs, deltas, rays
            )
            image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
            depth = torch.clamp(depth - nears, min=0) / (fars - nears + 1e-6)
            image = image.view(*prefix, 3)
            depth = depth.view(*prefix)

        else:
            # allocate outputs
            # if use autocast, must init as half so it won't be autocasted and lose reference.
            # dtype = torch.half if torch.is_autocast_enabled() else torch.float32
            # output should always be float32! only network inference uses half.
            dtype = torch.float32

            weights_sum = torch.zeros(N, dtype=dtype, device=device)
            depth = torch.zeros(N, dtype=dtype, device=device)
            image = torch.zeros(N, 3, dtype=dtype, device=device)

            n_alive = N
            alive_counter = torch.zeros([1], dtype=torch.int32, device=device)

            rays_alive = torch.zeros(
                2, n_alive, dtype=torch.int32, device=device
            )  # 2 is used to loop old/new
            rays_t = torch.zeros(2, n_alive, dtype=dtype, device=device)

            step = 0
            i = 0
            while step < max_steps:

                # count alive rays
                if step == 0:
                    # init rays at first step.
                    torch.arange(n_alive, out=rays_alive[0])
                    rays_t[0] = nears
                else:
                    alive_counter.zero_()
                    raymarching.compact_rays(
                        n_alive,
                        rays_alive[i % 2],
                        rays_alive[(i + 1) % 2],
                        rays_t[i % 2],
                        rays_t[(i + 1) % 2],
                        alive_counter,
                    )
                    n_alive = alive_counter.item()  # must invoke D2H copy here

                # exit loop
                if n_alive <= 0:
                    break

                # decide compact_steps
                n_step = max(min(N // n_alive, 8), 1)
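                # Editor's note: as rays terminate, each surviving ray marches more
                # steps per kernel launch, e.g. with N = 4096 total rays and
                # n_alive = 1024, n_step = min(4096 // 1024, 8) = 4.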

                xyzs, dirs, deltas = raymarching.march_rays(
                    n_alive,
                    n_step,
                    rays_alive[i % 2],
                    rays_t[i % 2],
                    rays_o,
                    rays_d,
                    self.bound,
                    self.density_bitfield,
                    self.cascade,
                    self.grid_size,
                    nears,
                    fars,
                    128,
                    perturb,
                    dt_gamma,
                    max_steps,
                )

                sigmas, rgbs = self(xyzs, dirs)
                # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.
                # sigmas = density_outputs['sigma']
                # rgbs = self.color(xyzs, dirs, **density_outputs)
                sigmas = self.density_scale * sigmas

                raymarching.composite_rays(
                    n_alive,
                    n_step,
                    rays_alive[i % 2],
                    rays_t[i % 2],
                    sigmas,
                    rgbs,
                    deltas,
                    weights_sum,
                    depth,
                    image,
                )

                # print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')

                step += n_step
                i += 1

            image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
            depth = torch.clamp(depth - nears, min=0) / (fars - nears)
            image = image.view(*prefix, 3)
            depth = depth.view(*prefix)

        # print('\n--- render time:--- {:6f}  {:.6f}'.format(time2-time1, time()-time2))
        if self.training:
            return {
                "depth": depth,
                "image": image,
                "inherited_params": inherited_params,
                "sigmas": sigmas,
                "rays": rays,
            }
        else:
            return {
                "depth": depth,
                "image": image,
                "inherited_params": inherited_params,
            }

    @torch.no_grad()
    def mark_untrained_grid(self, poses, intrinsic, S=64):
        # poses: [B, 4, 4]
        # intrinsic: [3, 3]

        if not self.cuda_ray:
            return

        if isinstance(poses, np.ndarray):
            poses = torch.from_numpy(poses)

        B = poses.shape[0]

        fx, fy, cx, cy = intrinsic

        X = torch.arange(
            self.grid_size, dtype=torch.int32, device=self.density_grid.device
        ).split(S)
        Y = torch.arange(
            self.grid_size, dtype=torch.int32, device=self.density_grid.device
        ).split(S)
        Z = torch.arange(
            self.grid_size, dtype=torch.int32, device=self.density_grid.device
        ).split(S)

        count = torch.zeros_like(self.density_grid)
        poses = poses.to(count.device)

        # 5-level loop, forgive me...

        for xs in X:
            for ys in Y:
                for zs in Z:

                    # construct points
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    coords = torch.cat(
                        [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
                        dim=-1,
                    )  # [N, 3], in [0, 128)
                    indices = raymarching.morton3D(coords).long()  # [N]
                    world_xyzs = (
                        2 * coords.float() / (self.grid_size - 1) - 1
                    ).unsqueeze(
                        0
                    )  # [1, N, 3] in [-1, 1]

                    # cascading
                    for cas in range(self.cascade):
                        bound = min(2 ** cas, self.bound)
                        half_grid_size = bound / self.grid_size
                        # scale to current cascade's resolution
                        cas_world_xyzs = world_xyzs * (bound - half_grid_size)

                        # split batch to avoid OOM
                        head = 0
                        while head < B:
                            tail = min(head + S, B)

                            # world2cam transform (poses is c2w, so the rotation must be inverted, i.e. transposed;
                            # that transpose cancels the one needed for batched matmul, so no explicit transpose appears)
                            cam_xyzs = cas_world_xyzs - poses[
                                head:tail, :3, 3
                            ].unsqueeze(1)
                            cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3]  # [S, N, 3]

                            # query if point is covered by any camera
                            mask_z = cam_xyzs[:, :, 2] > 0  # [S, N]
                            mask_x = (
                                torch.abs(cam_xyzs[:, :, 0])
                                < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2
                            )
                            mask_y = (
                                torch.abs(cam_xyzs[:, :, 1])
                                < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2
                            )
                            mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1)  # [N]
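                            # Editor's note: this is the pinhole frustum test. A point
                            # (x, y, z) projects to u = fx * x / z + cx; requiring
                            # 0 <= u <= W ~= 2 * cx gives |x| < (cx / fx) * z, with a
                            # half_grid_size * 2 slack so border cells are kept.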

                            # update count
                            count[cas, indices] += mask
                            head += S

        # mark untrained grid as -1
        self.density_grid[count == 0] = -1

        # print(f'[mark untrained grid] {(count == 0).sum()} from {self.grid_size ** 3 * self.cascade}')

    @torch.no_grad()
    def update_extra_state(self, decay=0.95, S=128):
        # call before each epoch to update extra states.

        if not self.cuda_ray:
            return

        # update density grid
        tmp_grid = -torch.ones_like(self.density_grid)

        # full update.
        if self.iter_density < 16:
            # if True:
            X = torch.arange(
                self.grid_size, dtype=torch.int32, device=self.density_grid.device
            ).split(S)
            Y = torch.arange(
                self.grid_size, dtype=torch.int32, device=self.density_grid.device
            ).split(S)
            Z = torch.arange(
                self.grid_size, dtype=torch.int32, device=self.density_grid.device
            ).split(S)

            for xs in X:
                for ys in Y:
                    for zs in Z:
                        # construct points
                        xx, yy, zz = custom_meshgrid(xs, ys, zs)
                        coords = torch.cat(
                            [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
                            dim=-1,
                        )  # [N, 3], in [0, 128)
                        indices = raymarching.morton3D(coords).long()  # [N]
                        xyzs = (
                            2 * coords.float() / (self.grid_size - 1) - 1
                        )  # [N, 3] in [-1, 1]

                        # cascading
                        for cas in range(self.cascade):
                            bound = min(2 ** cas, self.bound)
                            half_grid_size = bound / self.grid_size
                            # scale to current cascade's resolution
                            cas_xyzs = xyzs * (bound - half_grid_size)
                            # add noise in [-hgs, hgs]
                            cas_xyzs += (
                                torch.rand_like(cas_xyzs) * 2 - 1
                            ) * half_grid_size
                            # query density
                            sigmas = (
                                self.density(cas_xyzs)["sigma"].reshape(-1).detach()
                            )
                            sigmas *= self.density_scale
                            # assign
                            tmp_grid[cas, indices] = sigmas

        # partial update (half the computation)
        # TODO: why is max-pooling not needed here?
        else:
            N = self.grid_size ** 3 // 4  # H * H * H / 4
            for cas in range(self.cascade):
                # random sample some positions
                coords = torch.randint(
                    0, self.grid_size, (N, 3), device=self.density_grid.device
                )  # [N, 3], in [0, 128)
                indices = raymarching.morton3D(coords).long()  # [N]
                # random sample occupied positions
                occ_indices = torch.nonzero(self.density_grid[cas] > 0).squeeze(
                    -1
                )  # [Nz]
                rand_mask = torch.randint(
                    0,
                    occ_indices.shape[0],
                    [N],
                    dtype=torch.long,
                    device=self.density_grid.device,
                )
                occ_indices = occ_indices[
                    rand_mask
                ]  # [Nz] --> [N], allow for duplication
                occ_coords = raymarching.morton3D_invert(occ_indices)  # [N, 3]
                # concat
                indices = torch.cat([indices, occ_indices], dim=0)
                coords = torch.cat([coords, occ_coords], dim=0)
                # same below
                xyzs = (
                    2 * coords.float() / (self.grid_size - 1) - 1
                )  # [N, 3] in [-1, 1]
                bound = min(2 ** cas, self.bound)
                half_grid_size = bound / self.grid_size
                # scale to current cascade's resolution
                cas_xyzs = xyzs * (bound - half_grid_size)
                # add noise in [-hgs, hgs]
                cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
                # query density
                sigmas = self.density(cas_xyzs)["sigma"].reshape(-1).detach()
                sigmas *= self.density_scale
                # assign
                tmp_grid[cas, indices] = sigmas

        ## max-pool on tmp_grid for less aggressive culling [No significant improvement...]
        # invalid_mask = tmp_grid < 0
        # tmp_grid = F.max_pool3d(tmp_grid.view(self.cascade, 1, self.grid_size, self.grid_size, self.grid_size), kernel_size=3, stride=1, padding=1).view(self.cascade, -1)
        # tmp_grid[invalid_mask] = -1

        # ema update
        valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
        self.density_grid[valid_mask] = torch.maximum(
            self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]
        )
        self.mean_density = torch.mean(
            self.density_grid.clamp(min=0)
        ).item()  # -1 non-training regions are viewed as 0 density.
        self.iter_density += 1

        # convert to bitfield
        density_thresh = min(self.mean_density, self.density_thresh)
        self.density_bitfield = raymarching.packbits(
            self.density_grid, density_thresh, self.density_bitfield
        )
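        # Editor's note: packbits thresholds the float grid and packs it into one
        # bit per cell (8 cells per byte), so each cascade of grid_size**3 cells
        # becomes grid_size**3 / 8 bytes consumed by the CUDA ray marcher.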

        ### update step counter
        total_step = min(16, self.local_step)
        if total_step > 0:
            self.mean_count = int(
                self.step_counter[:total_step, 0].sum().item() / total_step
            )
        self.local_step = 0

        # print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')

    def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # return: pred_rgb: [B, N, 3]

        if self.cuda_ray:
            _run = self.run_cuda
        else:
            _run = self.run

        B, N = rays_o.shape[:2]
        device = rays_o.device

        # never stage when cuda_ray
        if staged and not self.cuda_ray:
            depth = torch.empty((B, N), device=device)
            image = torch.empty((B, N, 3), device=device)

            for b in range(B):
                head = 0
                while head < N:
                    tail = min(head + max_ray_batch, N)
                    results_ = _run(
                        rays_o[b : b + 1, head:tail],
                        rays_d[b : b + 1, head:tail],
                        **kwargs
                    )
                    depth[b : b + 1, head:tail] = results_["depth"]
                    image[b : b + 1, head:tail] = results_["image"]
                    head += max_ray_batch

            results = {}
            results["depth"] = depth
            results["image"] = image

        else:
            results = _run(rays_o, rays_d, **kwargs)

        return results


================================================
FILE: distill_mutual/utils.py
================================================
import os
import copy
import lpips
import glob
import tqdm
import math
import random
import warnings
import tensorboardX

import numpy as np
import pandas as pd

import imageio

import time
from datetime import datetime

import cv2
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
import trimesh
import mcubes
from rich.console import Console
from torch_ema import ExponentialMovingAverage
from IPython import embed
import sys

from packaging import version as pver

device = torch.device("cuda")
TINY_NUMBER = 1e-6  # float32 only has 7 decimal digits precision


def update_loss_rate(cur_lrate, scale=0.99):
    return cur_lrate * scale


def get_softmax_map_mean(a, b):
    return (F.softmax(a, dim=-1) - F.softmax(b, dim=-1)).abs().mean()


def get_kl(inputs, targets):
    # F.kl_div expects log-probabilities as `inputs` and probabilities as `targets`
    return F.kl_div(
        F.log_softmax(inputs, dim=-1), F.softmax(targets, dim=-1), reduction="sum"
    )
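
# Usage sketch (editor's illustration; `a`, `b` are hypothetical logit tensors):
#   a, b = torch.randn(4, 8), torch.randn(4, 8)
#   d = get_kl(a, b)  # scalar: KL(softmax(b) || softmax(a)), summed over all entries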


def nerf_matrix_to_ngp(pose, scale=0.8):
    # convert a NeRF/OpenGL c2w pose to the ngp convention; for the fox dataset, scale=0.33 brings the camera radius to ~2
    new_pose = np.array(
        [
            [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale],
            [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale],
            [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale],
            [0, 0, 0, 1],
        ],
        dtype=np.float32,
    )
    return new_pose


def pose_spherical(theta, phi, radius):
    # for synthetic. it generates sphere random poses
    trans_t = lambda t: np.array(
        [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1]]
    ).astype(np.float32)
    rot_phi = lambda phi: np.array(
        [
            [1, 0, 0, 0],
            [0, np.cos(phi), -np.sin(phi), 0],
            [0, np.sin(phi), np.cos(phi), 0],
            [0, 0, 0, 1],
        ]
    ).astype(np.float32)
    rot_theta = lambda th: np.array(
        [
            [np.cos(th), 0, -np.sin(th), 0],
            [0, 1, 0, 0],
            [np.sin(th), 0, np.cos(th), 0],
            [0, 0, 0, 1],
        ]
    ).astype(np.float32)
    c2w = trans_t(radius)
    c2w = rot_phi(phi / 180.0 * np.pi) @ c2w
    c2w = rot_theta(theta / 180.0 * np.pi) @ c2w
    c2w = (
        np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]).astype(
            np.float32
        )
        @ c2w
    )
    return c2w
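
# Usage sketch (editor's illustration): a camera 4 units from the origin,
# 30 degrees around the azimuth and 30 degrees below the equator:
#   c2w = pose_spherical(theta=30.0, phi=-30.0, radius=4.0)  # [4, 4] float32 c2w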


def get_rand_poses(data_type="synthetic", original_loader=None):
    """
    Random sampling. Random origins and directions.
    """
    from scipy.spatial.transform import Slerp, Rotation

    assert data_type in {"synthetic", "llff", "tank"}

    def get_single_syn_pose(ph, rand_radius=False):
        theta1 = -180
        theta2 = 180
        phi1 = -ph
        phi2 = 5 - ph if (5 - ph) <= 0 else 0
        theta = theta1 + np.random.rand() * (theta2 - theta1)
        phi = phi1 + np.random.rand() * (phi2 - phi1)
        if rand_radius:
            radius = np.random.uniform(3, 4)
        else:
            radius = 4
        return pose_spherical(theta, phi, radius)

    def get_syn_poses():
        random_poses = np.array([get_single_syn_pose(8) for _ in range(1)])
        for a in range(0, 80):
            rp = np.array(
                [get_single_syn_pose(a) for _ in range(int(((90 - a) // 15) ** 1 + 1))]
            )
            random_poses = np.concatenate([random_poses, rp], axis=0)
        for i in range(len(random_poses)):
            random_poses[i] = nerf_matrix_to_ngp(random_poses[i])
        print(f"\nlen(train data): {len(random_poses)}\n")
        random_poses = torch.from_numpy(random_poses).cuda()
        return random_poses

    def get_tank_poses():
        random_poses = np.array([get_single_syn_pose(8) for _ in range(1)])
        for a in range(5, 20):
            rp = np.array(
                [
                    get_single_syn_pose(a, True)
                    for _ in range(int(((90 - a) // 15) ** 1 + 1))
                ]
            )
            random_poses = np.concatenate([random_poses, rp], axis=0)
        for i in range(len(random_poses)):
            random_poses[i] = nerf_matrix_to_ngp(random_poses[i])
        print(f"\nlen(train data): {len(random_poses)}\n")
        random_poses = torch.from_numpy(random_poses).cuda()
        return random_poses

    def rand_poses_from_cam_centers(centers):
        def normalize(vectors):
            return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)

        size = len(centers)
        forward_vector = -normalize(centers)
        up_vector = (
            torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
        )  # confused by the coordinate system...
        right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1))
        up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1))

        poses = (
            torch.eye(4, dtype=torch.float, device=device)
            .unsqueeze(0)
            .repeat(size, 1, 1)
        )
        poses[:, :3, :3] = torch.stack(
            (right_vector, up_vector, forward_vector), dim=-1
        )
        poses[:, :3, 3] = centers
        return poses

    def get_llff_poses_rand():
        def get_rand_cam_centers_from_bbox(poses, gen_num=30):
            # use poses to estimate the bbox of the camera centers
            translations = poses[:, :3, 3]
            bbox_max = translations.max(axis=0) + 1e-6
            bbox_min = translations.min(axis=0) - 1e-6
            rand_xs = np.random.uniform(low=bbox_min[0], high=bbox_max[0], size=gen_num)
            rand_ys = np.random.uniform(low=bbox_min[1], high=bbox_max[1], size=gen_num)
            rand_zs = np.random.uniform(low=bbox_min[2], high=bbox_max[2], size=gen_num)
            centers = np.stack([rand_xs, rand_ys, rand_zs], axis=1)
            return centers.astype(np.float32)

        centers = get_rand_cam_centers_from_bbox(original_loader)
        random_poses = rand_poses_from_cam_centers(torch.from_numpy(centers).cuda())
        random_poses[:, 0, 0] = -random_poses[:, 0, 0]
        return random_poses

    if data_type == "synthetic":
        random_poses = get_syn_poses()
    elif data_type == "llff":
        random_poses = get_llff_poses_rand()
    elif data_type == "tank":
        random_poses = get_tank_poses()
    else:
        raise ValueError(f"illegal data_type: {data_type}")
    return random_poses


def custom_meshgrid(*args):
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) < pver.parse("1.10"):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing="ij")


@torch.jit.script
def linear_to_srgb(x):
    return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)


@torch.jit.script
def srgb_to_linear(x):
    return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
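
# Editor's note, a quick numeric check of the pair above:
#   linear_to_srgb(torch.tensor(0.5)) ~= 0.7354, and
#   srgb_to_linear(torch.tensor(0.7354)) ~= 0.5,
# i.e. the two are inverses up to the exponent 0.41666 approximating 1 / 2.4.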


def compute_ssim(
    img0,
    img1,
    max_val,
    filter_size=11,
    filter_sigma=1.5,
    k1=0.01,
    k2=0.03,
    return_map=False,
):
    """Computes SSIM from two images.
    This function was modeled after tf.image.ssim, and should produce comparable
    output.
    Args:
      img0: torch.tensor. An image of size [..., width, height, num_channels].
      img1: torch.tensor. An image of size [..., width, height, num_channels].
      max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
      filter_size: int >= 1. Window size.
      filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
      k1: float > 0. One of the SSIM dampening parameters.
      k2: float > 0. One of the SSIM dampening parameters.
      return_map: Bool. If True, the per-pixel SSIM "map" is returned.
    Returns:
      Each image's mean SSIM, or a tensor of individual values if `return_map`.
    """
    device = img0.device
    img0 = img0.type(torch.float32)
    img1 = img1.type(torch.float32)
    ori_shape = img0.size()
    width, height, num_channels = ori_shape[-3:]
    img0 = img0.view(-1, width, height, num_channels).permute(0, 3, 1, 2)
    img1 = img1.view(-1, width, height, num_channels).permute(0, 3, 1, 2)
    batch_size = img0.shape[0]

    # Construct a 1D Gaussian blur filter.
    hw = filter_size // 2
    shift = (2 * hw - filter_size + 1) / 2
    f_i = ((torch.arange(filter_size, device=device) - hw + shift) / filter_sigma) ** 2
    filt = torch.exp(-0.5 * f_i)
    filt /= torch.sum(filt)

    # Blur in x and y (faster than the 2D convolution).
    # z is a tensor of size [B, C, H, W]
    filt_fn1 = lambda z: F.conv2d(
        z,
        filt.view(1, 1, -1, 1).repeat(num_channels, 1, 1, 1),
        padding=[hw, 0],
        groups=num_channels,
    )
    filt_fn2 = lambda z: F.conv2d(
        z,
        filt.view(1, 1, 1, -1).repeat(num_channels, 1, 1, 1),
        padding=[0, hw],
        groups=num_channels,
    )

    # Vmap the blurs to the tensor size, and then compose them.
    filt_fn = lambda z: filt_fn1(filt_fn2(z))
    mu0 = filt_fn(img0)
    mu1 = filt_fn(img1)
    mu00 = mu0 * mu0
    mu11 = mu1 * mu1
    mu01 = mu0 * mu1
    sigma00 = filt_fn(img0 ** 2) - mu00
    sigma11 = filt_fn(img1 ** 2) - mu11
    sigma01 = filt_fn(img0 * img1) - mu01

    # Clip the variances and covariances to valid values.
    # Variance must be non-negative:
    sigma00 = torch.clamp(sigma00, min=0.0)
    sigma11 = torch.clamp(sigma11, min=0.0)
    sigma01 = torch.sign(sigma01) * torch.min(
        torch.sqrt(sigma00 * sigma11), torch.abs(sigma01)
    )

    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2
    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
    ssim_map = numer / denom
    ssim = torch.mean(ssim_map.reshape([-1, num_channels * width * height]), dim=-1)
    return ssim_map if return_map else ssim
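
# Usage sketch (editor's illustration; hypothetical [H, W, 3] tensors in [0, 1]):
#   img0, img1 = torch.rand(64, 64, 3), torch.rand(64, 64, 3)
#   ssim = compute_ssim(img0, img1, max_val=1.0)  # [1], the per-image mean SSIM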


def init_lpips(net_name, device):
    assert net_name in ["alex", "vgg"]

    print(f"init_lpips: lpips_{net_name}")
    return lpips.LPIPS(net=net_name, version="0.1").eval().cuda()


lpips_fns = {
    "alex": lpips.LPIPS(net="alex", version="0.1").eval().cuda(),
    "vgg": lpips.LPIPS(net="vgg", version="0.1").eval().cuda(),
}


def rgb_lpips(gt, im, net_name):
    assert net_name in ["alex", "vgg"]
    gt = gt.type(torch.float32).permute([0, 3, 1, 2]).contiguous().cuda()
    im = im.type(torch.float32).permute([0, 3, 1, 2]).contiguous().cuda()
    return lpips_fns[net_name](gt, im, normalize=True).item()


@torch.cuda.amp.autocast(enabled=False)
def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
    """get rays
    Args:
        poses: [B, 4, 4], cam2world
        intrinsics: [4]
        H, W, N: int
        error_map: [B, 128 * 128], sample probability based on training error
    Returns:
        rays_o, rays_d: [B, N, 3]
        inds: [B, N]
    """

    device = poses.device
    B = poses.shape[0]
    fx, fy, cx, cy = intrinsics

    i, j = custom_meshgrid(
        torch.linspace(0, W - 1, W, device=device),
        torch.linspace(0, H - 1, H, device=device),
    )
    i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5
    j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5

    results = {}

    if N > 0:
        N = min(N, H * W)

        if error_map is None:
            inds = torch.randint(0, H * W, size=[N], device=device)  # may duplicate
            inds = inds.expand([B, N])
        else:

            # weighted sample on a low-reso grid
            inds_coarse = torch.multinomial(
                error_map.to(device), N, replacement=False
            )  # [B, N], but in [0, 128*128)

            # map to the original resolution with random perturb.
            inds_x, inds_y = (
                inds_coarse // 128,
                inds_coarse % 128,
            )  # `//` will throw a warning in torch 1.10... anyway.
            sx, sy = H / 128, W / 128
            inds_x = (
                (inds_x * sx + torch.rand(B, N, device=device) * sx)
                .long()
                .clamp(max=H - 1)
            )
            inds_y = (
                (inds_y * sy + torch.rand(B, N, device=device) * sy)
                .long()
                .clamp(max=W - 1)
            )
            inds = inds_x * W + inds_y

            results["inds_coarse"] = inds_coarse  # need this when updating error_map

        i = torch.gather(i, -1, inds)
        j = torch.gather(j, -1, inds)

        results["inds"] = inds

    else:
        inds = torch.arange(H * W, device=device).expand([B, H * W])

    zs = torch.ones_like(i)
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    directions = torch.stack((xs, ys, zs), dim=-1)
    directions = directions / torch.norm(directions, dim=-1, keepdim=True)
    rays_d = directions @ poses[:, :3, :3].transpose(-1, -2)  # (B, N, 3)

    rays_o = poses[..., :3, 3]  # [B, 3]
    rays_o = rays_o[..., None, :].expand_as(rays_d)  # [B, N, 3]

    results["rays_o"] = rays_o
    results["rays_d"] = rays_d

    return results
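
# Usage sketch (editor's illustration; `poses` is a hypothetical [1, 4, 4] c2w
# matrix and fl_x, fl_y are focal lengths in pixels):
#   out = get_rays(poses, (fl_x, fl_y, 400.0, 400.0), H=800, W=800, N=4096)
#   rays_o, rays_d = out["rays_o"], out["rays_d"]  # both [1, 4096, 3]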


def seed_everything(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True


def torch_vis_2d(x, renormalize=False):
    # x: [3, H, W] or [1, H, W] or [H, W]
    import matplotlib.pyplot as plt
    import numpy as np
    import torch

    if isinstance(x, torch.Tensor):
        if len(x.shape) == 3:
            x = x.permute(1, 2, 0).squeeze()
        x = x.detach().cpu().numpy()

    print(f"[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}")

    x = x.astype(np.float32)

    # renormalize
    if renormalize:
        x = (x - x.min(axis=0, keepdims=True)) / (
            x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8
        )

    plt.imshow(x)
    plt.show()


def extract_fields(bound_min, bound_max, resolution, query_func, S=128):

    X = torch.linspace(bound_min[0], bound_max[0], resolution).split(S)
    Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(S)
    Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(S)

    u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
    with torch.no_grad():
        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    xx, yy, zz = custom_meshgrid(xs, ys, zs)
                    pts = torch.cat(
                        [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
                        dim=-1,
                    )  # [S, 3]
                    val = (
                        query_func(pts)
                        .reshape(len(xs), len(ys), len(zs))
                        .detach()
                        .cpu()
                        .numpy()
                    )  # [S, 1] --> [x, y, z]
                    u[
                        xi * S : xi * S + len(xs),
                        yi * S : yi * S + len(ys),
                        zi * S : zi * S + len(zs),
                    ] = val
    return u


def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
    # print('threshold: {}'.format(threshold))
    u = extract_fields(bound_min, bound_max, resolution, query_func)

    # print(u.shape, u.max(), u.min(), np.percentile(u, 50))

    vertices, triangles = mcubes.marching_cubes(u, threshold)

    b_max_np = bound_max.detach().cpu().numpy()
    b_min_np = bound_min.detach().cpu().numpy()

    vertices = (
        vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :]
        + b_min_np[None, :]
    )
    return vertices, triangles
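
# Usage sketch (editor's illustration; `model` is a hypothetical trained NeRF whose
# .density() returns a dict with a "sigma" entry, as elsewhere in this repo):
#   bmin = torch.tensor([-1.0, -1.0, -1.0])
#   bmax = torch.tensor([1.0, 1.0, 1.0])
#   v, t = extract_geometry(bmin, bmax, resolution=256, threshold=10,
#                           query_func=lambda pts: model.density(pts.cuda())["sigma"])
#   trimesh.Trimesh(v, t).export("mesh.ply")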


class PSNRMeter:
    def __init__(self):
        self.V = 0
        self.N = 0
        self.psnr_list = []

    def clear(self):
        self.V = 0
        self.N = 0
        self.psnr_list = []

    def prepare_inputs(self, *inputs):
        outputs = []
        for i, inp in enumerate(inputs):
            if torch.is_tensor(inp):
                inp = inp.detach().cpu().numpy()
            outputs.append(inp)

        return outputs

    def update(self, preds, truths):
        preds, truths = self.prepare_inputs(
            preds, truths
        )  # [B, N, 3] or [B, H, W, 3], range[0, 1]

        psnr = -10 * np.log10(np.mean((preds - truths) ** 2))
        self.psnr_list.append(psnr)
        self.V += psnr
        self.N += 1
        assert self.N == len(self.psnr_list)

    def measure(self):
        return self.V / self.N

    def write(self, writer, global_step, prefix=""):
        writer.add_scalar(os.path.join(prefix, "PSNR"), self.measure(), global_step)

    def report(self):
        return f"PSNR = {self.measure():.6f}"


class Trainer(object):
    def __init__(
        self,
        name,  # name of this experiment
        opt,  # extra conf
        model_tea,  # network
        model_stu,
        criterion=None,  # loss function, if None, assume inline implementation in train_step
        optimizer=None,  # optimizer
        ema_decay=None,  # if use EMA, set the decay
        lr_scheduler=None,  # scheduler
        metrics=[],  # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.
        local_rank=0,  # which GPU am I
        world_size=1,  # total num of GPUs
        device=None,  # device to use, usually setting to None is OK. (auto choose device)
        mute=False,  # whether to mute all print
        fp16=False,  # amp optimize level
        eval_interval=10e10,  # eval once every `eval_interval` epochs
        max_keep_ckpt=2,  # max num of saved ckpts in disk
        workspace="workspace",  # workspace to save logs & ckpts
        best_mode="min",  # the smaller/larger result, the better
        use_loss_as_metric=True,  # use loss as the first metric
        report_metric_at_train=False,  # also report metrics at training
        use_checkpoint="latest",  # which ckpt to use at init time
        use_tensorboardX=True,  # whether to use tensorboard for logging
        scheduler_update_every_step=False,  # whether to call scheduler.step() after every train step
    ):

        self.optimizer_fn = optimizer
        self.lr_scheduler_fn = lr_scheduler
        self.name = name
        self.opt = opt
        self.args = opt
        self.mute = mute
        self.metrics = metrics
        self.local_rank = local_rank
        self.world_size = world_size
        self.workspace = workspace
        self.ema_decay = ema_decay
        self.fp16 = fp16
        self.best_mode = best_mode
        self.use_loss_as_metric = use_loss_as_metric
        self.report_metric_at_train = report_metric_at_train
        self.max_keep_ckpt = max_keep_ckpt
        self.eval_interval = eval_interval
        self.use_checkpoint = use_checkpoint
        self.use_tensorboardX = use_tensorboardX
        self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
        self.scheduler_update_every_step = scheduler_update_every_step
        self.device = (
            device
            if device is not None
            else torch.device(
                f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
            )
        )
        self.console = Console()

        self.model_tea = model_tea.to(self.device)
        self.model_stu = model_stu.to(self.device)

        if isinstance(criterion, nn.Module):
            criterion.to(self.device)
        self.criterion = criterion

        if optimizer is None:
            self.optimizer = optim.AdamW(
                self.model_stu.parameters(), lr=0.001, weight_decay=5e-4
            )  # naive adam
        else:
            self.optimizer = optimizer(self.model_stu)

        if lr_scheduler is None:
            self.lr_scheduler = optim.lr_scheduler.LambdaLR(
                self.optimizer, lr_lambda=lambda epoch: 1
            )  # fake scheduler
        else:
            self.ls = lr_scheduler
            self.lr_scheduler = lr_scheduler(self.optimizer)
        if ema_decay is not None and ema_decay > 0:
            self.ema = ExponentialMovingAverage(
                self.model_stu.parameters(), decay=ema_decay
            )
        else:
            self.ema = None

        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)

        # variable init
        self.epoch = 1
        self.global_step = 0
        self.local_step = 0
        self.stats = {
            "loss": [],
            "valid_loss": [],
            "results": [],  # metrics[0], or valid_loss
            "checkpoints": [],  # record path of saved ckpt, to automatically remove old ckpt
            "best_result": None,
        }

        # auto fix
        if len(metrics) == 0 or self.use_loss_as_metric:
            self.best_mode = "min"

        # workspace prepare
        self.log_ptr = None
        if self.workspace is not None:
            os.makedirs(self.workspace, exist_ok=True)
            self.log_path = os.path.join(workspace, f"log_{self.name}.txt")
            self.log_ptr = open(self.log_path, "a+")

            self.ckpt_path = os.path.join(self.workspace, "checkpoints")
            self.best_path = f"{self.ckpt_path}/{self.name}.pth"
            os.makedirs(self.ckpt_path, exist_ok=True)
        self.log(self.opt)

        self.log(
            f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}'
        )
        self.log(
            f"[INFO] #parameters: {sum([p.numel() for p in model_stu.parameters() if p.requires_grad])}"
        )

        if (
            self.workspace is not None
        ):  # only load state_dict for teacher and share backbone for student
            self.log(f"[INFO] Loading teacher ckpt from {self.opt.ckpt_teacher} ...")
            self.load_teacher_checkpoint()
            self.log(self.model_tea)
            self.load_student_checkpoint()
            self.log(self.model_stu)
            # self.model_tea.reset_extra_state()
            # self.model_stu.reset_extra_state()
        """
        if opt.rand_pose >= 0: # =0 means only using CLIP loss, >0 means a hybrid mode.
            from nerf.clip_utils import CLIPLoss
            self.clip_loss = CLIPLoss(self.device)
            self.clip_loss.prepare_text([self.opt.clip_text]) # only support one text prompt now...
        """

    def __del__(self):
        if self.log_ptr:
            self.log_ptr.close()

    def log(self, *args, **kwargs):
        if self.local_rank == 0:
            if not self.mute:
                # print(*args)
                self.console.print(*args, **kwargs)
            if self.log_ptr:
                print(*args, file=self.log_ptr)
                self.log_ptr.flush()  # write immediately to file

    def train(self, train_loader, valid_loader, max_epochs):
        self.hard_rays_pool = [torch.tensor([]).cuda(), torch.tensor([]).cuda()]
        self.is_hard_rays_pool_full = False

        if self.use_tensorboardX and self.local_rank == 0:
            self.writer = tensorboardX.SummaryWriter(
                os.path.join(self.workspace, "run", self.name)
            )

        for p in self.model_tea.parameters():
            p.requires_grad = False
        self.model_tea.eval()

        # get a ref to error_map
        self.error_map = train_loader._data.error_map

        if (
            not self.args.use_real_data_for_train
        ):  # using random poses to calculate max_epochs.
            random_poses = get_rand_poses(
                data_type=self.args.data_type,
                original_loader=copy.deepcopy(
                    train_loader._data.poses.detach().cpu().numpy()
                ),
            )
            self.opt.iters = int(
                (self.opt.iters // len(random_poses)) * len(random_poses)
            )
            max_epochs = np.ceil(self.opt.iters / len(random_poses)).astype(np.int32)
            scheduler = lambda optimizer: optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.opt.iters * 1, eta_min=7e-5
            )  # update scheduler according to new opt.iters
            self.lr_scheduler = scheduler(self.optimizer)

        self.total_epoch = max_epochs
        self.log(f"\n----------------total epoch:{max_epochs} -----------\n")

        self.real_train_poses = copy.deepcopy(train_loader._data.poses)
        for epoch in range(self.epoch, max_epochs + 1):
            self.epoch = epoch
            if not self.args.use_real_data_for_train:
                print(f"\n generate new random poses at epoch{self.epoch}")
                random_poses = get_rand_poses(
                    data_type=self.args.data_type,
                    original_loader=self.real_train_poses.detach().cpu().numpy(),
                )
                train_loader._data.poses = copy.deepcopy(random_poses)
                train_loader._data.images = train_loader._data.images[:1].expand(
                    len(random_poses), -1, -1, -1
                )
                train_loader = train_loader._data.dataloader()
            self.train_one_epoch(train_loader)
            print("\n", self.workspace, "\n")

            if (
                self.workspace is not None
                and self.local_rank == 0
                and self.epoch > max_epochs - 1
            ):
                self.save_checkpoint(full=False, best=False)

            if self.epoch % self.eval_interval == 0:
                self.evaluate_one_epoch(valid_loader)
                self.save_checkpoint(full=False, best=True)  # to save storage, temporarily do not store the .pth

        if self.use_tensorboardX and self.local_rank == 0:
            self.writer.close()

    def train_one_epoch(self, loader):
        # self.log(
        #    f"tttttttttt> Start Training Epoch {self.epoch}/{self.total_epoch}, len(train_data):{len(loader)} lr={self.optimizer.param_groups[0]['lr']:.6f} ..."
        # )

        total_loss = 0
        total_loss_rgb = 0
        total_loss_fea_sc = 0
        total_loss_sigma = 0
        total_loss_color = 0

        psnr_tool = PSNRMeter()
        psnr_tool.clear()
        self.pose_psnr = []  # [(pose1, psnr1), (pose2,psnr2)...]

        if self.local_rank == 0 and self.report_metric_at_train:
            for metric in self.metrics:
                metric.clear()

        self.model_stu.train()
        self.model_tea.train()

        if self.world_size > 1:
            loader.sampler.set_epoch(self.epoch)

        if self.local_rank == 0:
            pbar = tqdm.tqdm(
                total=len(loader) * loader.batch_size,
                bar_format="{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
            )

        self.local_step = 0

        for data in loader:
            # update grid every 16 steps. This should run when just training a teacher, but not when distilling a student
            if (
                self.model_tea.cuda_ray
                and self.global_step % self.opt.update_extra_interval == 0
            ):
                with torch.cuda.amp.autocast(enabled=self.fp16):
                    if self.opt.update_stu_extra:
                        self.model_stu.update_extra_state()
                    else:
                        pass

            self.local_step += 1
            self.global_step += 1
            self.args.global_step = self.global_step

            self.optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=self.fp16):
                (
                    preds,
                    truths,
                    loss,
                    loss_rgb,
                    loss_fea_sc,
                    loss_color,
                    loss_sigma,
                ) = self.train_step(data)
                if preds is not None:
                    psnr_tool.update(preds, truths)

            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            if self.scheduler_update_every_step:
                self.lr_scheduler.step()

            loss_val = loss.item()
            total_loss += loss_val
            total_loss_rgb += loss_rgb
            total_loss_sigma += loss_sigma
            total_loss_color += loss_color
            total_loss_fea_sc += loss_fea_sc

            if self.local_rank == 0:
                if self.report_metric_at_train:
                    for metric in self.metrics:
                        metric.update(preds, truths)

                if self.use_tensorboardX:
                    self.writer.add_scalar("train/loss", loss_val, self.global_step)
                    self.writer.add_scalar("train/loss_rgb", loss_rgb, self.global_step)
                    self.writer.add_scalar(
                        "train/loss_fea_sc", loss_fea_sc, self.global_step
                    )
                    self.writer.add_scalar(
                        "train/loss_color", loss_color, self.global_step
                    )
                    self.writer.add_scalar(
                        "train/loss_sigma", loss_sigma, self.global_step
                    )
                    self.writer.add_scalar(
                        "train/lr",
                        self.optimizer.param_groups[0]["lr"],
                        self.global_step,
                    )

                if self.scheduler_update_every_step:  # run this
                    cur_lr = self.optimizer.param_groups[0]["lr"]
                    if self.global_step < self.args.stage_iters["stage1"]:
                        pbar.set_description(
                            f"loss={total_loss/self.local_step:.5f}, fea_sc={total_loss_fea_sc/self.local_step:.5f}, lr={cur_lr:.5f}"
                        )
                    elif self.global_step < self.args.stage_iters["stage2"]:
                        pbar.set_description(
                            f"loss={total_loss/self.local_step:.5f}, fea_sc={total_loss_fea_sc/self.local_step:.5f}, sigma={total_loss_sigma/self.local_step:.5f}, color={total_loss_color/self.local_step:.5f}, lr={cur_lr:.6f}"
                        )
                    else:
                        pbar.set_description(
                            f"loss={total_loss/self.local_step:.5f}, rgb={total_loss_rgb/self.local_step:.5f},  lr={cur_lr:.5f}"
                        )
                else:
                    pbar.set_description(
                        f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})"
                    )
                pbar.update(loader.batch_size)

            if (
                self.opt.model_type == "vm"
                and self.global_step in self.opt.upsample_model_steps
            ):
                # shrink
                if (
                    self.model_stu.cuda_ray
                ):  # and self.global_step == self.opt.upsample_model_steps[0]:
                    self.model_stu.shrink_model()

                # adaptive voxel size from aabb_train
                n_vox = self.upsample_resolutions.pop(0) ** 3  # n_voxels
                aabb = self.model_stu.aabb_train.cpu().numpy()
                vox_size = np.cbrt(np.prod(aabb[3:] - aabb[:3]) / n_vox)
                reso = ((aabb[3:] - aabb[:3]) / vox_size).astype(np.int32).tolist()
                self.log(
                    f"[INFO] upsample model at step {self.global_step} from {self.model_stu.resolution} to {reso}"
                )
                self.model_stu.upsample_model(reso)

                # reset optimizer since params changed.
                self.optimizer = self.optimizer_fn(self.model_stu)
                self.lr_scheduler = self.lr_scheduler_fn(self.optimizer)

        if self.ema is not None:
            self.ema.update()

        average_loss = total_loss / self.local_step
        self.stats["loss"].append(average_loss)

        if self.local_rank == 0:
            pbar.close()
            if self.report_metric_at_train:
                for metric in self.metrics:
                    self.log(metric.report(), style="red")
                    if self.use_tensorboardX:
                        metric.write(self.writer, self.epoch, prefix="train")
                    metric.clear()

        if not self.scheduler_update_every_step:
            if isinstance(
                self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
            ):
                self.lr_scheduler.step(average_loss)
            else:
                self.lr_scheduler.step()

        psnr_tool.psnr_list.sort()
        if self.global_step < self.args.stage_iters["stage1"]:
            self.log(
                f"tttttttttt> Train stage1 Epoch:{self.epoch}. loss_fea:{total_loss_fea_sc/self.local_step:.6f}"
            )
        elif self.global_step < self.args.stage_iters["stage2"]:
            self.log(
                f"tttttttttt> Train stage2 Epoch:{self.epoch}. loss_fea_sc:{total_loss_fea_sc/self.local_step:.3f} loss_sigma:{total_loss_sigma/self.local_step:.3f} loss_color:{total_loss_color/self.local_step:.3f}"
            )
        else:
            self.log(
                f"tttttttttt> Train stage3 Epoch:{self.epoch}. loss_rgb:{total_loss_rgb/self.local_step:.3f} loss_fea_sc:{total_loss_fea_sc/self.local_step:.3f} loss_sigma:{total_loss_sigma/self.local_step:.3f} loss_color:{total_loss_color/self.local_step:.3f}"
            )
            self.log(
                f"tttttttttt> Train PSNR Epoch {self.epoch}. psnr_min:{psnr_tool.psnr_list[0]:.3f} psnr_max:{psnr_tool.psnr_list[-1]:.3f} psnr_mean:{np.mean(psnr_tool.psnr_list):.3f}"
            )

    def get_loss(self, pred, gt):
        if self.opt.loss_type == "L2":
            loss = torch.mean((gt - pred) ** 2)
        elif self.opt.loss_type == "normL2":
            loss = torch.norm(pred - gt)
        elif self.opt.loss_type == "normL1":
            loss = torch.norm(pred - gt, p=1)
        elif self.opt.loss_type == "smoothL1":
            loss = torch.nn.functional.smooth_l1_loss(pred, gt, beta=0.05)
        else:
            raise ValueError(f"unknown loss_type: {self.opt.loss_type}")
        return loss

    def train_step(self, data):
        rays_o = data["rays_o"]  # [B, N, 3]
        rays_d = data["rays_d"]  # [B, N, 3]  [1, N=rays_num=4096, 3]

        loss = 0.0

        # if there is no gt image, we would train with CLIP loss (not supported in this distillation trainer).
        if "images" not in data:
            assert False, "CLIP-loss training without gt images is not supported here"
            B, N = rays_o.shape[:2]
            H, W = data["H"], data["W"]
            # currently fix white bg, MUST force all rays!
            outputs = self.model.render(
                rays_o,
                rays_d,
                staged=False,
                bg_color=None,
                perturb=True,
                force_all_rays=True,
                **vars(self.opt),
            )
            pred_rgb = (
                outputs["image"].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous()
            )
            loss = self.clip_loss(pred_rgb)
            return pred_rgb, None, loss

        images = data["images"]  # [B, N, 3/4]
        B, N, C = images.shape

        # if self.opt.color_space == 'linear':
        #    images[..., :3] = srgb_to_linear(images[..., :3])

        if (
            C == 3 or self.model_stu.bg_radius > 0
        ):  #  C=4 in synthetic dataset. C=3 for real dataset
            bg_color = 1
        # train with random background color if not using a bg model and has alpha channel.
        else:
            bg_color = torch.rand(
                [B, rays_o.size(1), 3], dtype=images.dtype, device=images.device
            )

        if self.opt.render_stu_first:
            outputs_stu = self.model_stu.render(
                rays_o,
                rays_d,
                staged=False,
                bg_color=bg_color,
                perturb=True,
                force_all_rays=False,
                **vars(self.opt),
            )
            pred_rgb_stu = outputs_stu["image"]
            with torch.no_grad():
                outputs_tea = self.model_tea.render(
                    rays_o,
                    rays_d,
                    staged=False,
                    bg_color=bg_color,
                    perturb=True,
                    force_all_rays=False,
                    inherited_params=outputs_stu["inherited_params"],
                    **vars(self.opt),
                )
                pred_rgb_tea = outputs_tea["image"]
        else:
            with torch.no_grad():
                outputs_tea = self.model_tea.render(
                    rays_o,
                    rays_d,
                    staged=False,
                    bg_color=bg_color,
                    perturb=True,
                    force_all_rays=False,
                    **vars(self.opt),
                )
                pred_rgb_tea = outputs_tea["image"]
            outputs_stu = self.model_stu.render(
                rays_o,
                rays_d,
                staged=False,
                bg_color=bg_color,
                perturb=True,
                force_all_rays=False,
                inherited_params=outputs_tea["inherited_params"],
                **vars(self.opt),
            )
            pred_rgb_stu = outputs_stu["image"]
        gt_rgb = pred_rgb_tea
        self.opt.loss_rate_fea_sc = update_loss_rate(self.opt.loss_rate_fea_sc, 0.995)
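        # Editor's note: the feature-distillation weight decays geometrically each
        # training step; with scale 0.995, after 1000 steps it has shrunk by a
        # factor of 0.995 ** 1000 ~= 0.0067.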

        if (
            "stage1" in outputs_stu
            and self.opt.loss_rate_fea_sc > 0.0
            and self.model_stu.feature_sigma_color is not None
            and self.model_tea.feature_sigma_color is not None
        ):
            assert (
                self.model_stu.feature_sigma_color.shape
                == self.model_tea.feature_sigma_color.shape
            )
            loss_fea_sc = self.get_loss(
                self.model_stu.feature_sigma_color, self.model_tea.feature_sigma_color
            )
            loss = loss + self.opt.loss_rate_fea_sc * loss_fea_sc
            return None, None, loss, 0, loss_fea_sc.detach().item(), 0, 0
        if "stage2" in outputs_stu:
            if self.opt.loss_rate_color > 0.0:
                assert self.model_stu.color_l.shape == self.model_tea.color_l.shape
                loss_color = self.get_loss(
                    self.model_stu.color_l, self.model_tea.color_l
                )
                loss = loss + self.opt.loss_rate_color * loss_color
            else:
                assert self.model_stu.color_l.shape == self.model_tea.color_l.shape
                loss_color = self.get_loss(
                    self.model_stu.color_l, self.model_tea.color_l
                )
            if self.opt.loss_rate_sigma > 0.0:
                assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape
                loss_sigma = self.get_loss(
                    self.model_stu.sigma_l, self.model_tea.sigma_l
                )
                loss = loss + self.opt.loss_rate_sigma * loss_sigma
            else:
                assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape
                loss_sigma = self.get_loss(
                    self.model_stu.sigma_l, self.model_tea.sigma_l
                )
            if (
                self.opt.loss_rate_fea_sc > 0.0
                and self.model_stu.feature_sigma_color is not None
                and self.model_tea.feature_sigma_color is not None
            ):
                assert (
                    self.model_stu.feature_sigma_color.shape
                    == self.model_tea.feature_sigma_color.shape
                )
                loss_fea_sc = self.get_loss(
                    self.model_stu.feature_sigma_color,
                    self.model_tea.feature_sigma_color,
                )
                loss = loss + self.opt.loss_rate_fea_sc * loss_fea_sc
            else:
                loss_fea_sc = torch.tensor(0.0)
            return (
                None,
                None,
                loss,
                0,
                loss_fea_sc.detach().item(),
                loss_color.detach().item(),
                loss_sigma.detach().item(),
            )

        if self.opt.loss_type == "normL2":
            loss_rgb = torch.norm(pred_rgb_tea - pred_rgb_stu)
        elif self.opt.loss_type == "normL1":
            loss_rgb = torch.norm(pred_rgb_tea - pred_rgb_stu, p=1)
        elif self.opt.loss_type == "L2":
            loss_rgb = self.criterion(pred_rgb_tea, pred_rgb_stu).mean(
                -1
            )  # [B, N, 3] --> [B, N]
            if len(loss_rgb.shape) == 3:  # [K, B, N]
                loss_rgb = loss_rgb.mean(0)
            if self.error_map is not None:
                index = data["index"]  # [B]
                inds = data["inds_coarse"]  # [B, N]
                error_map = self.error_map[index]  # [B, H * W]
                error = loss_rgb.detach().to(
                    error_map.device
                )  # [B, N], already in [0, 1]
                ema_error = 0.1 * error_map.gather(1, inds) + 0.9 * error  # ema update
                error_map.scatter_(1, inds, ema_error)
                self.error_map[index] = error_map  # put back
            loss_rgb = loss_rgb.mean()
        else:
            raise ValueError("error loss_type")
        loss = loss + loss_rgb * self.opt.loss_rate_rgb

        if self.opt.l1_reg_weight > 0.0 and self.opt.model_type == "vm":
            loss = loss + self.model_stu.density_loss() * self.opt.l1_reg_weight
        if (
            self.model_stu.feature_sigma_color is None
            or self.model_tea.feature_sigma_color is None
        ):
            loss_fea_sc = torch.tensor(0.0)
        else:
            assert (
                self.model_stu.feature_sigma_color.shape
                == self.model_tea.feature_sigma_color.shape
            )
            loss_fea_sc = self.get_loss(
                self.model_stu.feature_sigma_color, self.model_tea.feature_sigma_color
            )
            if self.opt.loss_rate_fea_sc > 0.0:
                loss = loss + self.opt.loss_rate_fea_sc * loss_fea_sc
        assert self.model_stu.color_l.shape == self.model_tea.color_l.shape
        loss_color = self.get_loss(self.model_stu.color_l, self.model_tea.color_l)
        if self.opt.loss_rate_color > 0.0:
            loss = loss + self.opt.loss_rate_color * loss_color
        assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape
        loss_sigma = self.get_loss(self.model_stu.sigma_l, self.model_tea.sigma_l)
        if self.opt.loss_rate_sigma > 0.0:
            loss = loss + self.opt.loss_rate_sigma * loss_sigma

        loss_rgb_show = self.criterion(
            pred_rgb_tea.detach(), pred_rgb_stu.detach()
        ).mean()  # scalar mean teacher-student RGB error, logged for monitoring
        return (
            pred_rgb_stu,
            gt_rgb,
            loss,
            loss_rgb_show.detach().item(),
            loss_fea_sc.detach().item(),
            loss_color.detach().item(),
            loss_sigma.detach().item(),
        )
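
    # A minimal sketch (hypothetical helper, not called anywhere) of the loss
    # weighting performed in train_step above: each distillation term only
    # contributes when its rate is positive, mirroring the branches above.
    @staticmethod
    def _weighted_distill_loss_sketch(loss_rgb, loss_fea_sc, loss_color, loss_sigma, opt):
        total = opt.loss_rate_rgb * loss_rgb
        if opt.loss_rate_fea_sc > 0.0:
            total = total + opt.loss_rate_fea_sc * loss_fea_sc
        if opt.loss_rate_color > 0.0:
            total = total + opt.loss_rate_color * loss_color
        if opt.loss_rate_sigma > 0.0:
            total = total + opt.loss_rate_sigma * loss_sigma
        return total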

    ### ------------------------------

    def evaluate(self, loader, name=None):
        self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX
        self.evaluate_one_epoch(loader, name)
        self.use_tensorboardX = use_tensorboardX

    def evaluate_one_epoch(self, loader, name=None):
        if name is None:
            name = f"{self.name}_ep{self.epoch:04d}"

        total_loss = 0
        if self.local_rank == 0:
            for metric in self.metrics:
                metric.clear()
        if self.opt.test_teacher:
            self.model_stu = self.model_tea  # evaluate the teacher in place of the student
        self.model_stu.eval()

        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

        if self.local_rank == 0:
            pbar = tqdm.tqdm(
                total=len(loader) * loader.batch_size,
                bar_format="{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
            )

        with torch.no_grad():
            self.local_step = 0
            self.ssim = 0.0
            self.lpips_vgg = 0.0
            self.lpips_alex = 0.0

            # update the occupancy grid before evaluation (optional)
            if self.model_stu.cuda_ray and self.opt.update_stu_extra:
                with torch.cuda.amp.autocast(enabled=self.fp16):
                    self.model_stu.update_extra_state()

            frames = []
            frames_depth = []
            for data in loader:
                self.local_step += 1

                with torch.cuda.amp.autocast(enabled=self.fp16):
                    preds, preds_depth, truths, loss = self.eval_step(data)

                # all_gather/reduce the statistics (NCCL only support all_*)
                if self.world_size > 1:
                    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                    loss = loss / self.world_size
                    preds_list = [
                        torch.zeros_like(preds).to(self.device)
                        for _ in range(self.world_size)
                    ]  # [[B, ...], [B, ...], ...]
                    dist.all_gather(preds_list, preds)
                    preds = torch.cat(preds_list, dim=0)

                    preds_depth_list = [
                        torch.zeros_like(preds_depth).to(self.device)
                        for _ in range(self.world_size)
                    ]  # [[B, ...], [B, ...], ...]
                    dist.all_gather(preds_depth_list, preds_depth)
                    preds_depth = torch.cat(preds_depth_list, dim=0)

                    truths_list = [
                        torch.zeros_like(truths).to(self.device)
                        for _ in range(self.world_size)
                    ]  # [[B, ...], [B, ...], ...]
                    dist.all_gather(truths_list, truths)
                    truths = torch.cat(truths_list, dim=0)
                loss_val = loss.item()
                total_loss += loss_val

                if self.local_rank == 0:

                    for metric in self.metrics:
                        metric.update(preds, truths)
                    self.lpips_alex += rgb_lpips(truths, preds, "alex")
                    self.lpips_vgg += rgb_lpips(truths, preds, "vgg")
                    self.ssim += compute_ssim(
                        preds,
                        truths,
                        max_val=max(preds.max().item(), truths.max().item()),
                    ).item()

                    # save image
                    save_path = os.path.join(
                        self.workspace,
                        loader._data.type,
                        f"{name}_{self.local_step:04d}.png",
                    )
                    save_path_depth = os.path.join(
                        self.workspace,
                        loader._data.type,
                        f"{name}_{self.local_step:04d}_depth.png",
                    )
                    # save_path_gt = os.path.join(self.workspace, loader._data.type, f'{name}_{self.local_step:04d}_gt.png')

                    os.makedirs(os.path.dirname(save_path), exist_ok=True)

                    if self.opt.color_space == "linear":
                        preds = linear_to_srgb(preds)

                    pred = preds[0].detach().cpu().numpy()
                    truth = truths[0].detach().cpu().numpy()
                    pred_depth = preds_depth[0].detach().cpu().numpy()
                    cv2.imwrite(
                        save_path,
                        cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR),
                    )
                    cv2.imwrite(save_path_depth, (pred_depth * 255).astype(np.uint8))
                    frames.append((pred * 255).astype(np.uint8))
                    frames_depth.append((pred_depth * 255).astype(np.uint8))

                    pbar.set_description(
                        f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})"
                    )
                    pbar.update(loader.batch_size)

            print(
                f"\n---- video frames: {len(frames)}  depth video frames: {len(frames_depth)} ----\n"
            )
            imageio.mimwrite(
                os.path.join(os.path.dirname(save_path), "video.mp4"),
                frames,
                fps=int(30 * 0.7),
                macro_block_size=8,
            )
            imageio.mimwrite(
                os.path.join(os.path.dirname(save_path), "video_depth.mp4"),
                frames_depth,
                fps=int(30 * 0.7),
                macro_block_size=8,
            )

        psnr_tool = self.metrics[0]

        psnr_tool.psnr_list.sort()
        self.log(
            f"\neeeeeeeee> {loader._data.type} PSNR Report: Epoch{self.epoch}.  psnr_mean:{np.mean(psnr_tool.psnr_list):.2f}"
        )

        average_loss = total_loss / self.local_step
        self.stats["valid_loss"].append(average_loss)

        if self.local_rank == 0:
            pbar.close()
            if not self.use_loss_as_metric and len(self.metrics) > 0:
                result = self.metrics[0].measure()
                self.stats["results"].append(
                    result if self.best_mode == "min" else -result
                )  # if max mode, use -result
            else:
                self.stats["results"].append(
                    average_loss
                )  # if no metric, choose best by min loss

            for metric in self.metrics:
                # self.log(metric.report(), style="blue")
                psnr = metric.report().split("=")[-1].strip()[:5]  # crude parse: value after "=", truncated to 5 chars
                self.psnr = float(psnr)
                if self.use_tensorboardX and loader._data.type == 'val':
                    metric.write(self.writer, self.epoch, prefix="evaluate")
                metric.clear()

        self.ssim /= self.local_step
        self.lpips_alex /= self.local_step
        self.lpips_vgg /= self.local_step
        if self.ema is not None:
            self.ema.restore()
        if self.local_rank == 0:  # self.psnr is only populated on rank 0
            self.log(
                f"eeeeeeeeee> {loader._data.type} Metric Report: Epoch{self.epoch}. psnr:{self.psnr} ssim:{self.ssim:.2f} alex:{self.lpips_alex:.2f} vgg:{self.lpips_vgg:.2f}"
            )
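
    # Sketch (hypothetical helper, unused) of the DDP gather pattern used in
    # evaluate_one_epoch above: allocate one buffer per rank, all_gather, then
    # concatenate along the batch dimension.
    def _gather_across_ranks_sketch(self, t):
        buffers = [torch.zeros_like(t).to(self.device) for _ in range(self.world_size)]
        dist.all_gather(buffers, t)
        return torch.cat(buffers, dim=0)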

    def eval_step(self, data):

        rays_o = data["rays_o"]  # [B, N, 3]
        rays_d = data["rays_d"]  # [B, N, 3]
        images = data["images"]  # [B, H, W, 3/4]
        B, H, W, C = images.shape

        if self.opt.color_space == "linear":
            images[..., :3] = srgb_to_linear(images[..., :3])

        # eval with fixed background color
        bg_color = 1
        if C == 4:
            gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (
                1 - images[..., 3:]
            )
        else:
            gt_rgb = images

        outputs = self.model_stu.render(
            rays_o,
            rays_d,
            staged=True,
            bg_color=bg_color,
            perturb=False,
            **vars(self.opt),
        )

        pred_rgb = outputs["image"].reshape(B, H, W, 3)
        pred_depth = outputs["depth"].reshape(B, H, W)

        loss = self.criterion(pred_rgb, gt_rgb).mean()

        return pred_rgb, pred_depth, gt_rgb, loss
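
    # Alpha compositing used in eval_step, as a standalone sketch (hypothetical
    # helper): gt = rgb * alpha + bg * (1 - alpha).
    @staticmethod
    def _composite_on_background_sketch(rgba, bg_color):
        rgb, alpha = rgba[..., :3], rgba[..., 3:]
        return rgb * alpha + bg_color * (1 - alpha)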

    def save_checkpoint(self, name=None, full=False, best=False, remove_old=True):
        full = False  # full (optimizer / scheduler / scaler) state saving is disabled here
        if name is None:
            name = f"{self.name}_ep{self.epoch:04d}"
        if self.opt.model_type == "vm":
            state = {
                "epoch": self.epoch,
                "global_step": self.global_step,
                "stats": self.stats,
                "resolution": self.model_stu.resolution,
            }
        else:
            state = {
                "epoch": self.epoch,
                "global_step": self.global_step,
                "stats": self.stats,
            }

        if self.model_stu.cuda_ray:
            state["mean_count"] = self.model_stu.mean_count
            state["mean_density"] = self.model_stu.mean_density

        if full:
            state["optimizer"] = self.optimizer.state_dict()
            state["lr_scheduler"] = self.lr_scheduler.state_dict()
            state["scaler"] = self.scaler.state_dict()
            if self.ema is not None:
                state["ema"] = self.ema.state_dict()

        if not best:

            state["model"] = self.model_stu.state_dict()

            file_path = f"{self.ckpt_path}/{name}.pth"

            if remove_old:
                self.stats["checkpoints"].append(file_path)

                if len(self.stats["checkpoints"]) > self.max_keep_ckpt:
                    old_ckpt = self.stats["checkpoints"].pop(0)
                    if os.path.exists(old_ckpt):
                        os.remove(old_ckpt)

            torch.save(state, file_path)

        else:
            if len(self.stats["results"]) > 0:
                if (
                    self.stats["best_result"] is None
                    or self.stats["results"][-1] < self.stats["best_result"]
                ):
                    self.log(
                        f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}"
                    )
                    self.stats["best_result"] = self.stats["results"][-1]

                    # save ema results
                    if self.ema is not None:
                        self.ema.store()
                        self.ema.copy_to()

                    state["model"] = self.model_stu.state_dict()

                    if self.ema is not None:
                        self.ema.restore()

                    torch.save(state, self.best_path)
            else:
                self.log(
                    "[WARN] no evaluated results found, skip saving best checkpoint."
                )
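
    # FIFO checkpoint retention from save_checkpoint, as a standalone sketch
    # (hypothetical helper): once more than max_keep paths accumulate, delete
    # the oldest file on disk.
    @staticmethod
    def _prune_old_checkpoints_sketch(paths, max_keep):
        while len(paths) > max_keep:
            old = paths.pop(0)
            if os.path.exists(old):
                os.remove(old)
        return paths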

    def load_teacher_checkpoint(self):
        checkpoint_dict = torch.load(self.opt.ckpt_teacher, map_location=self.device)

        missing_keys, unexpected_keys = self.model_tea.load_state_dict(
            checkpoint_dict["model"], strict=False
        )
        self.log("[INFO] loaded teacher model.")
        if len(missing_keys) > 0:
            self.log(f"[WARN] missing keys: {missing_keys}")
        if len(unexpected_keys) > 0:
            self.log(f"[WARN] unexpected keys: {unexpected_keys}")
        if self.ema is not None and "ema" in checkpoint_dict:
            self.ema.load_state_dict(checkpoint_dict["ema"])

        if self.model_tea.cuda_ray:
            if "mean_count" in checkpoint_dict:
                self.model_tea.mean_count = checkpoint_dict["mean_count"]
            if "mean_density" in checkpoint_dict:
                self.model_tea.mean_density = checkpoint_dict["mean_density"]
        """
        self.stats = checkpoint_dict['stats']
        self.epoch = checkpoint_dict['epoch']
        self.global_step = checkpoint_dict['global_step']
        self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")
        
        if self.optimizer and  'optimizer' in checkpoint_dict:
            try:
                self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
                self.log("[INFO] loaded optimizer.")
            except:
                self.log("[WARN] Failed to load optimizer.")
        
        if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
            try:
                self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
                self.log("[INFO] loaded scheduler.")
            except:
                self.log("[WARN] Failed to load scheduler.")
        
        if self.scaler and 'scaler' in checkpoint_dict:
            try:
                self.scaler.load_state_dict(checkpoint_dict['scaler'])
                self.log("[INFO] loaded scaler.")
            except:
                self.log("[WARN] Failed to load scaler.")


        if self.model_tea.cuda_ray:
            if 'mean_count' in checkpoint_dict:
                self.model_tea.mean_count = checkpoint_dict['mean_count']
            if 'mean_density' in checkpoint_dict:
                self.model_tea.mean_density = checkpoint_dict['mean_density']
        """

    def load_student_checkpoint(self):
        if self.opt.ckpt_student:
            checkpoint_dict = torch.load(
                self.opt.ckpt_student, map_location=self.device
            )
        else:
            checkpoint_dict = torch.load(
                self.opt.ckpt_teacher, map_location=self.device
            )

        if self.opt.model_type == "vm" and "resolution" in checkpoint_dict:
            self.model_stu.upsample_model(checkpoint_dict["resolution"])
        missing_keys, unexpected_keys = self.model_stu.load_state_dict(
            checkpoint_dict["model"], strict=False
        )
        self.log("[INFO] loaded student model.")
        if len(missing_keys) > 0:
            self.log(f"[WARN] missing keys: {missing_keys}")
        if len(unexpected_keys) > 0:
            self.log(f"[WARN] unexpected keys: {unexpected_keys}")

        if self.model_stu.cuda_ray:
            if "mean_count" in checkpoint_dict:
                self.model_stu.mean_count = checkpoint_dict["mean_count"]
            if "mean_density" in checkpoint_dict:
                self.model_stu.mean_density = checkpoint_dict["mean_density"]

        if self.ema is not None and "ema" in checkpoint_dict:
            self.ema.load_state_dict(checkpoint_dict["ema"])

        """
        self.stats = checkpoint_dict['stats']
        self.epoch = checkpoint_dict['epoch']
        self.global_step = checkpoint_dict['global_step']
        self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")
        
        if self.optimizer and  'optimizer' in checkpoint_dict:
            try:
                self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
                self.log("[INFO] loaded optimizer.")
            except:
                self.log("[WARN] Failed to load optimizer.")
        
        if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
            try:
                self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
                self.log("[INFO] loaded scheduler.")
            except:
                self.log("[WARN] Failed to load scheduler.")
        
        if self.scaler and 'scaler' in checkpoint_dict:
            try:
                self.scaler.load_state_dict(checkpoint_dict['scaler'])
                self.log("[INFO] loaded scaler.")
            except:
                self.log("[WARN] Failed to load scaler.")
        """

    def load_checkpoint(self, checkpoint=None, model_only=False):
        if checkpoint is None:
            checkpoint_list = sorted(glob.glob(f"{self.ckpt_path}/{self.name}_ep*.pth"))
            if checkpoint_list:
                checkpoint = checkpoint_list[-1]
                self.log(f"[INFO] Latest checkpoint is {checkpoint}")
            else:
                self.log("[WARN] No checkpoint found, model randomly initialized.")
                return

        checkpoint_dict = torch.load(checkpoint, map_location=self.device)

        if "model" not in checkpoint_dict:
            self.model.load_state_dict(checkpoint_dict)
            self.log("[INFO] loaded model.")
            return

        missing_keys, unexpected_keys = self.model.load_state_dict(
            checkpoint_dict["model"], strict=False
        )
        self.log("[INFO] loaded model.")
        if len(missing_keys) > 0:
            self.log(f"[WARN] missing keys: {missing_keys}")
        if len(unexpected_keys) > 0:
            self.log(f"[WARN] unexpected keys: {unexpected_keys}")

        if self.ema is not None and "ema" in checkpoint_dict:
            self.ema.load_state_dict(checkpoint_dict["ema"])

        if self.model.cuda_ray:
            if "mean_count" in checkpoint_dict:
                self.model.mean_count = checkpoint_dict["mean_count"]
            if "mean_density" in checkpoint_dict:
                self.model.mean_density = checkpoint_dict["mean_density"]

        if model_only:
            return

        self.stats = checkpoint_dict["stats"]
        self.epoch = checkpoint_dict["epoch"]
        self.global_step = checkpoint_dict["global_step"]
        self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")

        if self.optimizer and "optimizer" in checkpoint_dict:
            try:
                self.optimizer.load_state_dict(checkpoint_dict["optimizer"])
                self.log("[INFO] loaded optimizer.")
            except:
                self.log("[WARN] Failed to load optimizer.")

        if self.lr_scheduler and "lr_scheduler" in checkpoint_dict:
            try:
                self.lr_scheduler.load_state_dict(checkpoint_dict["lr_scheduler"])
                self.log("[INFO] loaded scheduler.")
            except:
                self.log("[WARN] Failed to load scheduler.")

        if self.scaler and "scaler" in checkpoint_dict:
            try:
                self.scaler.load_state_dict(checkpoint_dict["scaler"])
                self.log("[INFO] loaded scaler.")
            except:
                self.log("[WARN] Failed to load scaler.")

    def test(self, loader, save_path=None, name=None):
        raise NotImplementedError("test() is intentionally disabled in this trainer; the code below is kept for reference.")
        if save_path is None:
            save_path = os.path.join(self.workspace, "results")

        if name is None:
            name = f"{self.name}_ep{self.epoch:04d}"

        os.makedirs(save_path, exist_ok=True)

        self.log(f"==> Start Test, save results to {save_path}")

        pbar = tqdm.tqdm(
            total=len(loader) * loader.batch_size,
            bar_format="{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
        )
        self.model_stu.eval()
        with torch.no_grad():

            # update grid
            if self.model_stu.cuda_ray:
                with torch.cuda.amp.autocast(enabled=self.fp16):
                    self.model_stu.update_extra_state()

            for i, data in enumerate(loader):

                with torch.cuda.amp.autocast(enabled=self.fp16):
                    preds, preds_depth = self.test_step(data)

                path = os.path.join(save_path, f"{name}_{i:04d}.png")
                path_depth = os.path.join(save_path, f"{name}_{i:04d}_depth.png")

                # self.log(f"[INFO] saving test image to {path}")

                if self.opt.color_space == "linear":
                    preds = linear_to_srgb(preds)

                pred = preds[0].detach().cpu().numpy()
                pred_depth = preds_depth[0].detach().cpu().numpy()

                cv2.imwrite(
                    path, cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
                )
                cv2.imwrite(path_depth, (pred_depth * 255).astype(np.uint8))

                pbar.update(loader.batch_size)

        self.log(f"==> Finished Test.")

    # moved out bg_color and perturb for more flexible control...
    def test_step(self, data, bg_color=None, perturb=False):

        rays_o = data["rays_o"]  # [B, N, 3]
        rays_d = data["rays_d"]  # [B, N, 3]
        H, W = data["H"], data["W"]

        if bg_color is not None:
            bg_color = bg_color.to(self.device)

        outputs = self.model_stu.render(
            rays_o,
            rays_d,
            staged=True,
            bg_color=bg_color,
            perturb=perturb,
            **vars(self.opt),
        )

        pred_rgb = outputs["image"].reshape(-1, H, W, 3)
        pred_depth = outputs["depth"].reshape(-1, H, W)

        return pred_rgb, pred_depth
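

# Staged rendering (staged=True in eval_step/test_step above) renders rays in
# chunks to bound peak GPU memory. A self-contained sketch of the chunking
# idea (hypothetical helper; the real chunking lives in the renderer):
def _render_in_chunks_sketch(fn, rays, chunk=4096):
    outs = [fn(rays[i : i + chunk]) for i in range(0, rays.shape[0], chunk)]
    return torch.cat(outs, dim=0)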


================================================
FILE: gridencoder/__init__.py
================================================
from .grid import GridEncoder


================================================
FILE: gridencoder/backend.py
================================================
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    "-O3",
    "-std=c++14",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
]

if os.name == "posix":
    c_flags = ["-O3", "-std=c++14"]
elif os.name == "nt":
    c_flags = ["/O2", "/std:c++17"]

    # find cl.exe
    def find_cl_path():
        import glob

        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(
                glob.glob(
                    r"C:\Program Files (x86)\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64"
                    % edition
                ),
                reverse=True,
            )
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError(
                "Could not locate a supported Microsoft Visual C++ installation"
            )
        os.environ["PATH"] += ";" + cl_path

_backend = load(
    name="_grid_encoder",
    extra_cflags=c_flags,
    extra_cuda_cflags=nvcc_flags,
    sources=[
        os.path.join(_src_path, "src", f)
        for f in [
            "gridencoder.cu",
            "bindings.cpp",
        ]
    ],
)

__all__ = ["_backend"]


================================================
FILE: gridencoder/grid.py
================================================
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    import _gridencoder as _backend
except ImportError:
    from .backend import _backend

_gridtype_to_id = {
    "hash": 0,
    "tiled": 1,
}


class _grid_encode(Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        inputs,
        embeddings,
        offsets,
        per_level_scale,
        base_resolution,
        calc_grad_inputs=False,
        gridtype=0,
        align_corners=False,
    ):
        # inputs: [B, D], float in [0, 1]
        # embeddings: [sO, C], float
        # offsets: [L + 1], int
        # RETURN: [B, F], float

        inputs = inputs.contiguous()

        B, D = inputs.shape  # batch size, coord dim
        L = offsets.shape[0] - 1  # level
        C = embeddings.shape[1]  # embedding dim for each level
        S = np.log2(
            per_level_scale
        )  # resolution multiplier at each level, apply log2 for later CUDA exp2f
        H = base_resolution  # base resolution

        # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
        # if C % 2 != 0, force float, since half for atomicAdd is very slow.
        if torch.is_autocast_enabled() and C % 2 == 0:
            embeddings = embeddings.to(torch.half)

        # L first, optimize cache for cuda kernel, but needs an extra permute later
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)

        if calc_grad_inputs:
            dy_dx = torch.empty(
                B, L * D * C, device=inputs.device, dtype=embeddings.dtype
            )
        else:
            dy_dx = torch.empty(
                1, device=inputs.device, dtype=embeddings.dtype
            )  # placeholder... TODO: a better way?

        _backend.grid_encode_forward(
            inputs,
            embeddings,
            offsets,
            outputs,
            B,
            D,
            C,
            L,
            S,
            H,
            calc_grad_inputs,
            dy_dx,
            gridtype,
            align_corners,
        )

        # permute back to [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)

        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [B, D, C, L, S, H, gridtype]
        ctx.calc_grad_inputs = calc_grad_inputs
        ctx.align_corners = align_corners

        return outputs

    @staticmethod
    # @once_differentiable
    @custom_bwd
    def backward(ctx, grad):

        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H, gridtype = ctx.dims
        calc_grad_inputs = ctx.calc_grad_inputs
        align_corners = ctx.align_corners

        # grad: [B, L * C] --> [L, B, C]
        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()

        grad_embeddings = torch.zeros_like(embeddings)

        if calc_grad_inputs:
            grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
        else:
            grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype)

        _backend.grid_encode_backward(
            grad,
            inputs,
            embeddings,
            offsets,
            grad_embeddings,
            B,
            D,
            C,
            L,
            S,
            H,
            calc_grad_inputs,
            dy_dx,
            grad_inputs,
            gridtype,
            align_corners,
        )

        if calc_grad_inputs:
            grad_inputs = grad_inputs.to(inputs.dtype)
            return grad_inputs, grad_embeddings, None, None, None, None, None, None
        else:
            return None, grad_embeddings, None, None, None, None, None, None


grid_encode = _grid_encode.apply
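

# The custom_fwd/custom_bwd pair above keeps the Function correct under
# torch.cuda.amp autocast: the backward pass is handled with the same autocast
# behavior as the forward pass. A minimal sketch of the same pattern
# (illustrative only; unused by this module):
class _ScaleSketch(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, s):
        ctx.s = s
        return x * s

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        # one gradient per forward input: d(x*s)/dx = s, and None for the scalar s
        return grad * ctx.s, None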


class GridEncoder(nn.Module):
    def __init__(
        self,
        input_dim=3,
        num_levels=16,
        level_dim=2,
        per_level_scale=2,
        base_resolution=16,
        log2_hashmap_size=19,
        desired_resolution=None,
        gridtype="hash",
        align_corners=False,
    ):
        super().__init__()

        # the finest resolution desired at the last level; if provided, this overrides per_level_scale
        if desired_resolution is not None:
            per_level_scale = np.exp2(
                np.log2(desired_resolution / base_resolution) / (num_levels - 1)
            )

        self.input_dim = input_dim  # coord dims, 2 or 3
        self.num_levels = num_levels  # num levels, each level multiply resolution by 2
        self.level_dim = level_dim  # encode channels per level
        self.per_level_scale = (
            per_level_scale  # multiply resolution by this scale at each level.
        )
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = num_levels * level_dim
        self.gridtype = gridtype
        self.gridtype_id = _gridtype_to_id[gridtype]  # "hash" -> 0, "tiled" -> 1
        self.align_corners = align_corners

        # allocate parameters
        offsets = []
        offset = 0
        self.max_params = 2 ** log2_hashmap_size
        for i in range(num_levels):
            resolution = int(np.ceil(base_resolution * per_level_scale ** i))
            params_in_level = min(
                self.max_params,
                (resolution if align_corners else resolution + 1) ** input_dim,
            )  # limit max number
            params_in_level = int(np.ceil(params_in_level / 8) * 8)  # make divisible
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer("offsets", offsets)

        self.n_params = offsets[-1] * level_dim

        # parameters
        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))

        self.reset_parameters()

    def reset_parameters(self):
        std = 1e-4
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"

    def forward(self, inputs, bound=1):
        # inputs: [..., input_dim], normalized real world positions in [-bound, bound]
        # return: [..., num_levels * level_dim]

        inputs = (inputs + bound) / (2 * bound)  # map to [0, 1]

        # print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.view(-1, self.input_dim)

        outputs = grid_encode(
            inputs,
            self.embeddings,
            self.offsets,
            self.per_level_scale,
            self.base_resolution,
            inputs.requires_grad,
            self.gridtype_id,
            self.align_corners,
        )
        outputs = outputs.view(prefix_shape + [self.output_dim])

        # print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())

        return outputs
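

# Illustrative usage (a sketch): encode 3D points in [-bound, bound]; the
# output dimension is num_levels * level_dim (16 * 2 = 32 here). Requires a
# CUDA device, since the encoder dispatches to the CUDA extension.
if __name__ == "__main__":
    enc = GridEncoder(input_dim=3, num_levels=16, level_dim=2, desired_resolution=2048).cuda()
    pts = torch.rand(4, 3, device="cuda") * 2 - 1  # positions in [-1, 1]
    print(enc(pts, bound=1).shape)  # torch.Size([4, 32])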


================================================
FILE: gridencoder/setup.py
================================================
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    "-O3",
    "-std=c++14",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
]

if os.name == "posix":
    c_flags = ["-O3", "-std=c++14"]
elif os.name == "nt":
    c_flags = ["/O2", "/std:c++17"]

    # find cl.exe
    def find_cl_path():
        import glob

        for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
            paths = sorted(
                glob.glob(
                    r"C:\Program Files (x86)\Microsoft Visual Studio\*\%s\VC\Tools\MSVC\*\bin\Hostx64\x64"
                    % edition
                ),
                reverse=True,
            )
            if paths:
                return paths[0]

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError(
                "Could not locate a supported Microsoft Visual C++ installation"
            )
        os.environ["PATH"] += ";" + cl_path

setup(
    name="gridencoder",  # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name="_gridencoder",  # extension name, import this to use CUDA API
            sources=[
                os.path.join(_src_path, "src", f)
                for f in [
                    "gridencoder.cu",
                    "bindings.cpp",
                ]
            ],
            extra_compile_args={
                "cxx": c_flags,
                "nvcc": nvcc_flags,
            },
        ),
    ],
    cmdclass={
        "build_ext": BuildExtension,
    },
)


================================================
FILE: gridencoder/src/bindings.cpp
================================================
#include <torch/extension.h>

#include "gridencoder.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
    m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
}

================================================
FILE: gridencoder/src/gridencoder.cu
================================================
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <algorithm>
#include <stdexcept>

#include <stdint.h>
#include <cstdio>


#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")


// just for compatibility of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF...
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
  // requires CUDA >= 10 and ARCH >= 70
  // this is very slow compared to float or __half2, and never called; the
  // explicit return below is a stub so the non-void overload compiles cleanly.
  //return atomicAdd(reinterpret_cast<__half*>(address), val);
  return val;
}


template <typename T>
static inline __host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}


template <uint32_t D>
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
    static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");

    // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
    // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
    // coordinates.
    constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 };

    uint32_t result = 0;
    #pragma unroll
    for (uint32_t i = 0; i < D; ++i) {
        result ^= pos_grid[i] * primes[i];
    }

    return result;
}


template <uint32_t D, uint32_t C>
__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
    uint32_t stride = 1;
    uint32_t index = 0;

    #pragma unroll
    for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
        index += pos_grid[d] * stride;
        stride *= align_corners ? resolution: (resolution + 1);
    }

    // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
    // gridtype: 0 == hash, 1 == tiled
    if (gridtype == 0 && stride > hashmap_size) {
        index = fast_hash<D>(pos_grid);
    }

    return (index % hashmap_size) * C + ch;
}


template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_grid(
    const float * __restrict__ inputs, 
    const scalar_t * __restrict__ grid, 
    const int * __restrict__ offsets, 
    scalar_t * __restrict__ outputs, 
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    const bool calc_grad_inputs, 
    scalar_t * __restrict__ dy_dx,
    const uint32_t gridtype,
    const bool align_corners
) {
    const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
    
    if (b >= B) return;

    const uint32_t level = blockIdx.y;
    
    // locate
    grid += (uint32_t)offsets[level] * C;
    inputs += b * D;
    outputs += level * B * C + b * C;

    // check input range (should be in [0, 1])
    bool flag_oob = false;
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            flag_oob = true;
        }
    }
    // if input out of bound, just set output to 0
    if (flag_oob) {
        #pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            outputs[ch] = 0; 
        }
        if (calc_grad_inputs) {
            dy_dx += b * D * L * C + level * D * C; // B L D C
            #pragma unroll
            for (uint32_t d = 0; d < D; d++) {
                #pragma unroll
                for (uint32_t ch = 0; ch < C; ch++) {
                    dy_dx[d * C + ch] = 0; 
                }       
            }
        }
        return;
    }

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;
    
    // calculate coordinate
    float pos[D];
    uint32_t pos_grid[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
        pos_grid[d] = floorf(pos[d]);
        pos[d] -= (float)pos_grid[d];
    }

    //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);

    // interpolate
    scalar_t results[C] = {0}; // temp results in register

    #pragma unroll
    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1;
        uint32_t pos_grid_local[D];

        #pragma unroll
        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = pos_grid[d] + 1;
            }
        }

        uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);

        // writing to register (fast)
        #pragma unroll
        for (uint32_t ch = 0; ch < C; ch++) {
            results[ch] += w * grid[index + ch];
        }

        //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
    }    

    // writing to global memory (slow)
    #pragma unroll
    for (uint32_t ch = 0; ch < C; ch++) {
        outputs[ch] = results[ch]; 
    }

    // prepare dy_dx for calc_grad_inputs
    // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
    if (calc_grad_inputs) {

        dy_dx += b * D * L * C + level * D * C; // B L D C

        #pragma unroll
        for (uint32_t gd = 0; gd < D; gd++) {

            scalar_t results_grad[C] = {0};

            #pragma unroll
            for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
                float w = scale;
                uint32_t pos_grid_local[D];

                #pragma unroll
                for (uint32_t nd = 0; nd < D - 1; nd++) {
                    const uint32_t d = (nd >= gd) ? (nd + 1) : nd;

                    if ((idx & (1 << nd)) == 0) {
                        w *= 1 - pos[d];
                        pos_grid_local[d] = pos_grid[d];
                    } else {
                        w *= pos[d];
                        pos_grid_local[d] = pos_grid[d] + 1;
                    }
                }

                pos_grid_local[gd] = pos_grid[gd];
                uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
                pos_grid_local[gd] = pos_grid[gd] + 1;
                uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);

                #pragma unroll
                for (uint32_t ch = 0; ch < C; ch++) {
                    results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]);
                }
            }

            #pragma unroll
            for (uint32_t ch = 0; ch < C; ch++) {
                dy_dx[gd * C + ch] = results_grad[ch];
            }
        }
    }
}


template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
__global__ void kernel_grid_backward(
    const scalar_t * __restrict__ grad,
    const float * __restrict__ inputs, 
    const scalar_t * __restrict__ grid, 
    const int * __restrict__ offsets, 
    scalar_t * __restrict__ grad_grid, 
    const uint32_t B, const uint32_t L, const float S, const uint32_t H,
    const uint32_t gridtype,
    const bool align_corners
) {
    const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
    if (b >= B) return;

    const uint32_t level = blockIdx.y;
    const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;

    // locate
    grad_grid += offsets[level] * C;
    inputs += b * D;
    grad += level * B * C + b * C + ch; // L, B, C

    const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
    const float scale = exp2f(level * S) * H - 1.0f;
    const uint32_t resolution = (uint32_t)ceil(scale) + 1;

    // check input range (should be in [0, 1])
    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        if (inputs[d] < 0 || inputs[d] > 1) {
            return; // grad is init as 0, so we simply return.
        }
    }

    // calculate coordinate
    float pos[D];
    uint32_t pos_grid[D];

    #pragma unroll
    for (uint32_t d = 0; d < D; d++) {
        pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
        pos_grid[d] = floorf(pos[d]);
        pos[d] -= (float)pos_grid[d];
    }

    scalar_t grad_cur[N_C] = {0}; // fetch to register
    #pragma unroll
    for (uint32_t c = 0; c < N_C; c++) {
        grad_cur[c] = grad[c];
    }

    // interpolate
    #pragma unroll
    for (uint32_t idx = 0; idx < (1 << D); idx++) {
        float w = 1;
        uint32_t pos_grid_local[D];

        #pragma unroll
        for (uint32_t d = 0; d < D; d++) {
            if ((idx & (1 << d)) == 0) {
                w *= 1 - pos[d];
                pos_grid_local[d] = pos_grid[d];
            } else {
                w *= pos[d];
                pos_grid_local[d] = pos_grid[d] + 1;
            }
        }

        uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);

        // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
        // TODO: use float which is better than __half, if N_C % 2 != 0
        if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
            #pragma unroll
            for (uint32_t c = 0; c < N_C; c += 2) {
                // process two __half at once (by interpreting as a __half2)
                __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
                atomicAdd((__half2*)&grad_grid[index + c], v);
            }
        // float, or __half when N_C % 2 != 0 (which means C == 1)
        } else {
            #pragma unroll
            for (uint32_t c = 0; c < N_C; c++) {
                atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
            }
        }
    }    
}


template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_input_backward(
    const scalar_t * __restrict__ grad,
    const scalar_t * __restrict__ dy_dx,  
    scalar_t * __restrict__ grad_inputs, 
    uint32_t B, uint32_t L
) {
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    if (t >= B * D) return;

    const uint32_t b = t / D;
    const uint32_t d = t - b * D;

    dy_dx += b * L * D * C;

    scalar_t result = 0;
    
    # pragma unroll
    for (int l = 0; l < L; l++) {
        # pragma unroll
        for (int ch = 0; ch < C; ch++) {
            result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
        }
    }

    grad_inputs[t] = result;
}


template <typename scalar_t, uint32_t D>
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
    static constexpr uint32_t N_THREAD = 512;
    const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
    switch (C) {
        case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
        case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
        case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
        case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
}

// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
// H: base resolution
// dy_dx: [B, L * D * C]
template <typename scalar_t>
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
    switch (D) {
        case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
        case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
    
}

template <typename scalar_t, uint32_t D>
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
    static constexpr uint32_t N_THREAD = 256;
    const uint32_t N_C = std::min(2u, C); // n_features_per_thread
    const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
    switch (C) {
        case 1: 
            kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); 
            if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 2: 
            kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 4: 
            kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        case 8: 
            kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
            if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
            break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
}


// grad: [L, B, C], float
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// grad_embeddings: [sO, C]
// H: base resolution
template <typename scalar_t>
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
    switch (D) {
        case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
    }
}



void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners) {
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(outputs);
    CHECK_CUDA(dy_dx);
    
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(outputs);
    CHECK_CONTIGUOUS(dy_dx);

    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(outputs);
    CHECK_IS_FLOATING(dy_dx);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    embeddings.scalar_type(), "grid_encode_forward", ([&] {
        grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), gridtype, align_corners);
    }));
}

void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners) {
    CHECK_CUDA(grad);
    CHECK_CUDA(inputs);
    CHECK_CUDA(embeddings);
    CHECK_CUDA(offsets);
    CHECK_CUDA(grad_embeddings);
    CHECK_CUDA(dy_dx);
    CHECK_CUDA(grad_inputs);
    
    CHECK_CONTIGUOUS(grad);
    CHECK_CONTIGUOUS(inputs);
    CHECK_CONTIGUOUS(embeddings);
    CHECK_CONTIGUOUS(offsets);
    CHECK_CONTIGUOUS(grad_embeddings);
    CHECK_CONTIGUOUS(dy_dx);
    CHECK_CONTIGUOUS(grad_inputs);

    CHECK_IS_FLOATING(grad);
    CHECK_IS_FLOATING(inputs);
    CHECK_IS_FLOATING(embeddings);
    CHECK_IS_INT(offsets);
    CHECK_IS_FLOATING(grad_embeddings);
    CHECK_IS_FLOATING(dy_dx);
    CHECK_IS_FLOATING(grad_inputs);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad.scalar_type(), "grid_encode_backward", ([&] {
        grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>(), gridtype, align_corners);
    }));
    
}


================================================
FILE: gridencoder/src/gridencoder.h
================================================
#ifndef _HASH_ENCODE_H
#define _HASH_ENCODE_H

#include <stdint.h>
#include <torch/torch.h>

// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// outputs: [B, L * C], float
// H: base resolution
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners);
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners);
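
// These two functions are the extension's only entry points: they are exposed
// to Python via pybind11 (src/bindings.cpp) and driven by the autograd
// Function in gridencoder/grid.py, which carries dy_dx from the forward pass
// to the backward pass when input gradients are requested.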

#endif

================================================
FILE: just_train_tea/network.py
================================================
import torch
from time import time
import torch.nn as nn
import torch.nn.functional as F

from tools.encoding import get_encoder
from tools.activation import trunc_exp
from .renderer import NeRFRenderer
import raymarching


class NeRFNetwork(NeRFRenderer):
    def __init__(
        self,
        encoding="hashgrid",
        encoding_dir="sphere_harmonics",
        encoding_bg="hashgrid",
        num_layers=2,
        hidden_dim=64,
        geo_feat_dim=15,
        num_layers_color=3,
        hidden_dim_color=64,
        num_layers_bg=2,
        hidden_dim_bg=64,
        bound=1,
        model_type="hash",
        args=None,
        is_teacher=False,
        **kwargs,
    ):
        super().__init__(bound, **kwargs)
        # sigma network
        assert model_type in ["hash", "mlp", "vm", "tensors"]
        self.is_teacher = is_teacher
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.geo_feat_dim = geo_feat_dim
        self.args = args
        self.opt = args
        self.model_type = model_type

        self.plenoxel_degree = args.plenoxel_degree
        self.plenoxel_res = eval(args.plenoxel_res)

        assert len(self.plenoxel_res) == 3

        self.encoder, self.in_dim = get_encoder(
            encoding,
            desired_resolution=2048 * bound,
            num_levels=14,
        )

        if "hash" != self.model_type:
            self.encoder = None

        if self.model_type == "mlp":
            self.encoder_nerf_pe, self.in_dim_nerf = get_encoder(
                encoding="frequency", multires=self.args.PE
            )
            self.skips = self.args.skip
            self.nerf_layer_num = self.args.nerf_layer_num
            W = self.args.nerf_layer_wide
            self.nerf_mlp = [nn.Linear(self.in_dim_nerf, W)]
            for i in range(self.nerf_layer_num - 2):
                if i != self.skips:
                    self.nerf_mlp.append(nn.Linear(W, W))
                else:
                    self.nerf_mlp.append(nn.Linear(W + self.in_dim_nerf, W))
            self.nerf_mlp.append(nn.Linear(W, self.in_dim))
            self.nerf_mlp = nn.ModuleList(self.nerf_mlp)

        elif self.model_type == "vm":
            self.sigma_rank = [16] * 3
            self.color_rank = [48] * 3
            self.color_feat_dim = 15  # geo_feat_dim
            self.mat_ids = [[0, 1], [0, 2], [1, 2]]
            self.vec_ids = [2, 1, 0]
            self.resolution = [self.opt.resolution0] * 3
            # mat: ParameterList of three [1, 16, res0, res0] planes; vec: ParameterList of three [1, 16, res0, 1] lines.
            # Three of each because the 3D grid [H, W, D] is decomposed into the 2D planes
            # [H, W], [H, D], [W, D], each paired with the 1D line along the remaining axis.
            self.sigma_mat, self.sigma_vec = self.init_one_vm(
                self.sigma_rank, self.resolution
            )
            # mat: paralist[1,48,res0,res0] repeat 3   vec: paralist[1,48,res0,1] repeat 3
            self.color_mat, self.color_vec = self.init_one_vm(
                self.color_rank, self.resolution
            )
            # Linear(in_features=sum(color_rank)=144, out_features=color_feat_dim=15)
            self.basis_mat = nn.Linear(
                sum(self.color_rank), self.color_feat_dim, bias=False
            )
        elif self.model_type == "tensors":
            self.init_plenoxel_volume(
                s=0.02,
                fea_dim=self.plenoxel_degree ** 2 * 3 + 1,
                volume=self.plenoxel_res,
            )

        elif self.model_type == "hash":
            pass
        else:
            raise ValueError(f"error model_type:{self.model_type}")

        if self.model_type != "vm" and self.model_type != "tensors":
            sigma_net = []
            for l in range(num_layers):
                if l == 0:
                    in_dim = self.in_dim
                else:
                    in_dim = hidden_dim

                if l == num_layers - 1:
                    out_dim = (
                        1 + self.geo_feat_dim
                    )  # 1 sigma + geo_feat_dim (15) geometric features for color
                else:
                    out_dim = hidden_dim

                sigma_net.append(nn.Linear(in_dim, out_dim, bias=False))

            self.sigma_net = nn.ModuleList(sigma_net)

        # color network
        self.num_layers_color = num_layers_color
        self.hidden_dim_color = hidden_dim_color
        # self.encoder_dir, self.in_dim_dir = get_encoder(encoding=encoding_dir)
        if self.model_type == "tensors":
            self.encoder_dir, self.in_dim_dir = get_encoder(
                encoding="sphere_harmonics",
                degree=self.plenoxel_degree,
            )

        else:
            self.encoder_dir, self.in_dim_dir = get_encoder(
                encoding=encoding_dir, input_dim=3, multires=2
            )

        if self.model_type != "tensors":
            color_net = []
            for l in range(num_layers_color):
                if l == 0:
                    in_dim = self.in_dim_dir + self.geo_feat_dim
                else:
                    in_dim = hidden_dim

                if l == num_layers_color - 1:
                    out_dim = 3  # 3 rgb
                else:
                    out_dim = hidden_dim

                color_net.append(nn.Linear(in_dim, out_dim, bias=False))

            self.color_net = nn.ModuleList(color_net)

        # background network
        if self.bg_radius > 0:
            self.num_layers_bg = num_layers_bg
            self.hidden_dim_bg = hidden_dim_bg
            self.encoder_bg, self.in_dim_bg = get_encoder(
                encoding_bg,
                input_dim=2,
                num_levels=4,
                log2_hashmap_size=19,
                desired_resolution=2048,
            )  # much smaller hashgrid

            bg_net = []
            for l in range(num_layers_bg):
                if l == 0:
                    in_dim = self.in_dim_bg + self.in_dim_dir
                else:
                    in_dim = hidden_dim_bg

                if l == num_layers_bg - 1:
                    out_dim = 3  # 3 rgb
                else:
                    out_dim = hidden_dim_bg

                bg_net.append(nn.Linear(in_dim, out_dim, bias=False))

            self.bg_net = nn.ModuleList(bg_net)
        else:
            self.bg_net = None
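
    # The four interchangeable backbones selected by model_type:
    #   "hash"    - multiresolution hash-grid encoder + small sigma/color MLPs
    #   "mlp"     - frequency positional encoding + a plain NeRF-style MLP
    #   "vm"      - TensoRF-style vector-matrix factor grids + a tiny color MLP
    #   "tensors" - a dense Plenoxels-style voxel grid of sigma + SH coefficients
    # All four share the ray-marching machinery of the NeRFRenderer base class.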

    def init_plenoxel_volume(self, s=0.1, fea_dim=27 + 1, volume=[128, 128, 128]):
        tensor = []
        tensor.append(
            torch.nn.Parameter(
                s * torch.randn((1, fea_dim, volume[0], volume[1], volume[2]))
            )
        )
        self.tensor_volume = torch.nn.ParameterList(tensor).cuda()
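        # Per-voxel feature layout: channel 0 holds the raw density and the
        # remaining 3 * degree^2 channels hold spherical-harmonic RGB
        # coefficients, e.g. plenoxel_degree = 3 gives 3 * 9 + 1 = 28 channels.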

    def init_one_vm(self, n_component, resolution, scale=0.1):
        # self.mat_ids = [[0, 1], [0, 2], [1, 2]]  self.vec_ids = [2, 1, 0]
        mat, vec = [], []

        for i in range(len(self.vec_ids)):
            vec_id = self.vec_ids[i]
            mat_id_0, mat_id_1 = self.mat_ids[i]
            mat.append(
                nn.Parameter(
                    scale
                    * torch.randn(
                        (1, n_component[i], resolution[mat_id_1], resolution[mat_id_0])
                    )
                )
            )  # [1, R, H, W]
            vec.append(
                nn.Parameter(
                    scale * torch.randn((1, n_component[i], resolution[vec_id], 1))
                )
            )  # [1, R, D, 1] (fake 2d to use grid_sample)

        return nn.ParameterList(mat), nn.ParameterList(vec)
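
    # Vector-matrix (VM) decomposition: the dense 3D feature grid is
    # approximated as a sum of rank-1 terms, each the outer product of a 2D
    # plane and the 1D line along the remaining axis, so storage is
    # O(R * (H*W + D)) per axis pair instead of O(H * W * D).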

    def get_sigma_feat(self, x):
        # x: [N, 3], in [-1, 1] (outliers will be treated as zero due to grid_sample padding mode)
        # self.mat_ids = [[0, 1], [0, 2], [1, 2]]  self.vec_ids = [2, 1, 0]
        N = x.shape[0]

        # plane + line basis
        mat_coord = (
            torch.stack(
                (
                    x[..., self.mat_ids[0]],
                    x[..., self.mat_ids[1]],
                    x[..., self.mat_ids[2]],
                )
            )
            .detach()
            .view(3, -1, 1, 2)
        )  # [3, N, 1, 2]
        vec_coord = torch.stack(
            (x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])
        )
        vec_coord = (
            torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1)
            .detach()
            .view(3, -1, 1, 2)
        )  # [3, N, 1, 2], fake 2d coord

        sigma_feat = torch.zeros(
            [
                N,
            ],
            device=x.device,
        )

        for i in range(len(self.sigma_mat)):
            mat_feat = F.grid_sample(
                self.sigma_mat[i], mat_coord[[i]], align_corners=True
            ).view(
                -1, N
            )  # [1, R, N, 1] --> [R, N]
            vec_feat = F.grid_sample(
                self.sigma_vec[i], vec_coord[[i]], align_corners=True
            ).view(
                -1, N
            )  # [R, N]
            sigma_feat = sigma_feat + torch.sum(mat_feat * vec_feat, dim=0)

        return sigma_feat
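
    # i.e. sigma_feat(x) = sum over axis pairs i and ranks r of
    # M_i_r(x projected on plane i) * v_i_r(x on the paired axis), evaluated
    # with bilinear grid_sample on the planes and on the fake-2D lines.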

    def get_color_feat(self, x):
        # x: [N, 3], in [-1, 1]
        N = x.shape[0]

        # plane + line basis
        mat_coord = (
            torch.stack(
                (
                    x[..., self.mat_ids[0]],
                    x[..., self.mat_ids[1]],
                    x[..., self.mat_ids[2]],
                )
            )
            .detach()
            .view(3, -1, 1, 2)
        )  # [3, N, 1, 2]
        vec_coord = torch.stack(
            (x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])
        )
        vec_coord = (
            torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1)
            .detach()
            .view(3, -1, 1, 2)
        )  # [3, N, 1, 2], fake 2d coord

        mat_feat, vec_feat = [], []
        for i in range(len(self.color_mat)):
            mat_feat.append(
                F.grid_sample(
                    self.color_mat[i], mat_coord[[i]], align_corners=True
                ).view(-1, N)
            )  # [1, R, N, 1] --> [R, N]
            vec_feat.append(
                F.grid_sample(
                    self.color_vec[i], vec_coord[[i]], align_corners=True
                ).view(-1, N)
            )  # [R, N]

        mat_feat = torch.cat(mat_feat, dim=0)  # [3 * R, N]
        vec_feat = torch.cat(vec_feat, dim=0)  # [3 * R, N]

        color_feat = self.basis_mat(
            (mat_feat * vec_feat).T
        )  # [N, 3R] --> [N, color_feat_dim]

        return color_feat

    def compute_plenoxel_fea(self, x):
        composed = self.tensor_volume[0]
        composed = (
            F.grid_sample(composed, x.view(1, 1, -1, 1, 3), align_corners=True)
            .view(-1, x.shape[0])
            .permute(1, 0)
        )
        return composed  # [N, fea_dim]
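
    # grid_sample treats tensor_volume[0] as a [1, fea_dim, X, Y, Z] volume and
    # trilinearly interpolates it at the N query points (viewed as a
    # [1, 1, N, 1, 3] coordinate grid), returning one fea_dim vector per point.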

    def forward_nerf_mlp(self, x):
        x = self.encoder_nerf_pe(x)
        in_pts = x
        for i in range(len(self.nerf_mlp)):
            x = self.nerf_mlp[i](x)
            if i != len(self.nerf_mlp) - 1:
                x = F.relu(x, inplace=True)
            if i == self.skips:
                x = torch.cat([in_pts, x], -1)
        return x
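
    # Classic NeRF trunk: positional-encoded input, ReLU hidden layers, and a
    # skip connection that re-concatenates the encoded input after layer
    # `self.skips` (the following layer was sized W + in_dim_nerf for this).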

    def forward(self, x, d):
        # x: [N, 3], in [-bound, bound]  d: [N, 3], normalized in [-1, 1]
        # sigma
        if self.model_type == "hash":
            x = self.encoder(
                x, bound=self.bound
            )  # out_x[N, 28=num_levels * fea_per_level]
        elif self.model_type == "mlp":
            x = self.forward_nerf_mlp(x)  # 28
        elif self.model_type == "vm":
            x = (
                2
                * (x - self.aabb_train[:3])
                / (self.aabb_train[3:] - self.aabb_train[:3])
                - 1
            )  # x:[N, 3]
            sigma_feat = self.get_sigma_feat(x)  # sigma_feat:[N]
            color_feat = self.get_color_feat(x)  # color_feat:[N, 15]
            sigma_feat = torch.clamp(
                sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max
            )
            # color_feat = torch.clamp(color_feat, self.args.sigma_clip_min, self.args.sigma_clip_max)
            self.feature_sigma_color = torch.cat(
                [sigma_feat.unsqueeze(-1), color_feat], dim=-1
            )
            self.sigma_l = sigma_feat
            sigma = trunc_exp(sigma_feat)  # sigma:[N]
            enc_d = self.encoder_dir(d)  # enc_d:[N, 16]
            h = torch.cat([enc_d, color_feat], dim=-1)  # h:[N, 16+15]
            for l in range(self.num_layers_color):
                h = self.color_net[l](h)
                if l != self.num_layers_color - 1:
                    h = F.relu(h, inplace=True)

            color = torch.sigmoid(h)
            self.color_l = color

            return sigma, color
        elif self.model_type == "tensors":
            x = (
                2
                * (x - self.aabb_train[:3])
                / (self.aabb_train[3:] - self.aabb_train[:3])
                - 1
            )  # x:[N, 3]
            x = self.compute_plenoxel_fea(x)
            h = x
            sigma = torch.clamp(
                h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max
            )
            self.sigma_l = sigma
            sigma = trunc_exp(sigma)
            self.sigma = sigma
            sh = h[..., 1:].view(
                -1, 3, self.plenoxel_degree ** 2
            )  # [N, 3, 9]   ## .permute(1, 0, 2)  # [B, 27]-->[9, B, 3]
            enc_d = self.encoder_dir(d).unsqueeze(1)  # [N, 9]-->[N,1,9]
            color = (sh * enc_d).sum(-1)  # [N, 3]
            color = torch.sigmoid(color)
            self.feature_sigma_color = None
            self.color_l = color
            return sigma, color

        else:
            raise ValueError(f"not illegal model_type:{self.model_type}")

        h = x
        for l in range(self.num_layers):
            h = self.sigma_net[l](h)
            if l != self.num_layers - 1:
                h = F.relu(h, inplace=True)
        h[..., 0] = torch.clamp(
            h[..., 0].clone(), self.args.sigma_clip_min, self.args.sigma_clip_max
        )
        # h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max)
        self.feature_sigma_color = h
        self.sigma_l = h[..., 0]
        sigma = trunc_exp(h[..., 0])  # sigma: [n]
        geo_feat = h[..., 1:]  # geo_feat: [n, 15]

        d = self.encoder_dir(d)  # d: [n, 16]
        h = torch.cat([d, geo_feat], dim=-1)  # h: [n, 15+16]
        for l in range(self.num_layers_color):
            h = self.color_net[l](h)
            if l != self.num_layers_color - 1:
                h = F.relu(h, inplace=True)

        color = torch.sigmoid(h)
        self.color_l = color
        return sigma, color
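
    # A minimal usage sketch (hypothetical shapes, not part of the original file):
    #   x = (torch.rand(N, 3, device="cuda") * 2 - 1) * model.bound   # positions
    #   d = F.normalize(torch.randn(N, 3, device="cuda"), dim=-1)     # view dirs
    #   sigma, color = model(x, d)  # sigma: [N], color: [N, 3] in (0, 1)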

    def density(self, x):
        # x: [N, 3], in [-bound, bound]
        if self.model_type == "hash":
            x = self.encoder(
                x, bound=self.bound
            )  # out_x[N, 28=num_levels * fea_per_level]
        elif self.model_type == "mlp":
            x = self.forward_nerf_mlp(x)
        elif self.model_type == "vm":
            x = (
                2
                * (x - self.aabb_train[:3])
                / (self.aabb_train[3:] - self.aabb_train[:3])
                - 1
            )
            sigma_feat = self.get_sigma_feat(x)
            sigma_feat = torch.clamp(
                sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max
            )
            sigma = trunc_exp(sigma_feat)
            return {"sigma": sigma}
        elif self.model_type == "tensors":
            x = (
                2
                * (x - self.aabb_train[:3])
                / (self.aabb_train[3:] - self.aabb_train[:3])
                - 1
            )  # x:[N, 3]
            x = self.compute_plenoxel_fea(x)
            h = x
            # h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max)
            sigma = trunc_exp(
                torch.clamp(
                    h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max
                )
            )
            return {"sigma": sigma}

        else:
            raise ValueError(f"not illegal model_type:{self.model_type}")

        h = x
        for l in range(self.num_layers):
            h = self.sigma_net[l](h)
            if l != self.num_layers - 1:
                h = F.relu(h, inplace=True)

        h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max)
        sigma = trunc_exp(h[..., 0])
        geo_feat = h[..., 1:]

        return {
            "sigma": sigma,
            "geo_feat": geo_feat,
        }
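
    # density() is the geometry-only query (no view direction); the renderer
    # presumably uses it for occupancy-grid maintenance (update_extra_state),
    # where only sigma is needed.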

    def background(self, x, d):
        raise NotImplementedError("background() is not used in this pipeline")
        # x: [N, 2], in [-1, 1]

        h = self.encoder_bg(x)  # [N, C]
        d = self.encoder_dir(d)

        h = torch.cat([d, h], dim=-1)
        for l in range(self.num_layers_bg):
            h = self.bg_net[l](h)
            if l != self.num_layers_bg - 1:
                h = F.relu(h, inplace=True)

        # sigmoid activation for rgb
        rgbs = torch.sigmoid(h)

        return rgbs

    # allow masked inference
    def color(self, x, d, mask=None, geo_feat=None, **kwargs):
        raise NotImplementedError("color() is not used in this pipeline")
        # x: [N, 3] in [-bound, bound]
        # mask: [N,], bool, indicates where we actually needs to compute rgb.

        if mask is not None:
            rgbs = torch.zeros(
                mask.shape[0], 3, dtype=x.dtype, device=x.device
            )  # [N, 3]
            # in case of empty mask
            if not mask.any():
                return rgbs
            x = x[mask]
            d = d[mask]
            geo_feat = geo_feat[mask]

        d = self.encoder_dir(d)
        h = torch.cat([d, geo_feat], dim=-1)
        for l in range(self.num_layers_color):
            h = self.color_net[l](h)
            if l != self.num_layers_color - 1:
                h = F.relu(h, inplace=True)

        # sigmoid activation for rgb
        h = torch.sigmoid(h)

        if mask is not None:
            rgbs[mask] = h.to(rgbs.dtype)  # fp16 --> fp32
        else:
            rgbs = h

        return rgbs

    # L1 penalty for loss
    def density_loss(self):
        loss = 0
        for i in range(len(self.sigma_mat)):
            loss = (
                loss
                + torch.mean(torch.abs(self.sigma_mat[i]))
                + torch.mean(torch.abs(self.sigma_vec[i]))
            )
        return loss
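
    # L1 sparsity penalty on the VM sigma factors; only meaningful for
    # model_type "vm", since the other backbones define no sigma_mat/sigma_vec.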

    # upsample utils
    @torch.no_grad()
    def upsample_params(self, mat, vec, resolution):

        for i in range(len(self.vec_ids)):
            vec_id = self.vec_ids[i]
            mat_id_0, mat_id_1 = self.mat_ids[i]
            mat[i] = nn.Parameter(
                F.interpolate(
                    mat[i].data,
                    size=(resolution[mat_id_1], resolution[mat_id_0]),
                    mode="bilinear",
                    align_corners=True,
                )
            )
            vec[i] = nn.Parameter(
                F.interpolate(
                    vec[i].data,
                    size=(resolution[vec_id], 1),
                    mode="bilinear",
                    align_corners=True,
                )
            )
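
        # Bilinearly re-interpolating the factor grids enables coarse-to-fine
        # training: optimization can start at resolution0 and the planes/lines
        # can later be upsampled in place to a higher target resolution.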

    @torch.no_grad()
    def upsample_model(self, resolution):
        self.upsample_params(self.sigma_mat, self.sigma_vec, resolution)
        self.upsample_params(self.color_mat, self.color_vec, resolution)
        self.resolution = resolution

    @torch.no_grad()
    def shrink_model(self):
        # shrink aabb_train and the model so it only represents the space inside aabb_train.

        half_grid_size = self.bound / self.grid_size
        thresh = min(self.density_thresh, self.mean_density)

        valid_grid = self.density_grid[self.cascade - 1] > thresh  # [N]
        valid_pos = raymarching.morton3D_invert(
            torch.nonzero(valid_grid)
        )  # [Nz] --> [Nz, 3], in [0, H - 1]
        # plot_pointcloud(valid_pos.detach().cpu().numpy()) # lots of noisy outliers in hashnerf...
        valid_pos = (2 * valid_pos / (self.grid_size - 1) - 1) * (
            self.bound - half_grid_size
        )  # [Nz, 3], in [-b+hgs, b-hgs]
        min_pos = valid_pos.amin(0) - half_grid_size  # [3]
        max_pos = valid_pos.amax(0) + half_grid_size  # [3]

        # shrink model
        reso = torch.LongTensor(self.resolution).to(self.aabb_train.device)
        units = (self.aabb_train[3:] - self.aabb_train[:3]) / reso
        tl = (min_pos - self.aabb_train[:3]) / units
        br = (max_pos - self.aabb_train[:3]) / units
        tl = torch.round(tl).long().clamp(min=0)
        br = torch.minimum(torch.round(br).long(), reso)

        for i in range(len(self.vec_ids)):
            vec_id = self.vec_ids[i]
            mat_id_0, mat_id_1 = self.mat_ids[i]

            self.sigma_vec[i] = nn.Parameter(
                self.sigma_vec[i].data[..., tl[vec_id] : br[vec_id], :]
            )
            self.color_vec[i] = nn.Parameter(
                self.color_vec[i].data[..., tl[vec_id] : br[vec_id], :]
            )

            self.sigma_mat[i] = nn.Parameter(
                self.sigma_mat[i].data[
                    ..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0]
                ]
            )
            self.color_mat[i] = nn.Parameter(
                self.color_mat[i].data[
                    ..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0]
                ]
            )

        self.aabb_train = torch.cat([min_pos, max_pos], dim=0)  # [6]

        print(
            f"[INFO] shrink slice: {tl.cpu().numpy().tolist()} - {br.cpu().numpy().tolist()}"
        )
        print(f"[INFO] new aabb: {self.aabb_train.cpu().numpy().tolist()}")

    # optimizer utils
    def get_params(self, lr, lr2=1e-3):
        if self.model_type == "hash":
            params = [
                {"params": self.encoder.parameters(), "lr": lr},
                {"params": self.sigma_net.parameters(), "lr": lr},
                {"params": self.encoder_dir.parameters(), "lr": lr},
                {"params": self.color_net.parameters(), "lr": lr},
            ]
        elif self.model_type == "mlp":
            params = [
                {"params": self.sigma_net.parameters(), "lr": lr},
                {"params": self.encoder_dir.parameters(), "lr": lr},
                {"params": self.color_net.parameters(), "lr": lr},
                {"params": self.nerf_mlp.parameters(), "lr": lr},
            ]
        elif self.model_type == "vm":
            params = [
                {"params": self.color_net.parameters(), "lr": lr2},
                {"params": self.sigma_mat, "lr": lr},
                {"params": self.sigma_vec, "lr": lr},
                {"params": self.color_mat, "lr": lr},
                {"params": self.color_vec, "lr": lr},
                {"params": self.basis_mat.parameters(), "lr": lr2},
            ]
        elif self.model_type == "tensors":
            params = [
                {"params": self.tensor_volume.parameters(), "lr": lr},
                {"params": self.encoder_dir.parameters(), "lr": lr},
            ]

        else:
            raise ValueError(f"not illegal model_type:{self.model_type}")

        if self.bg_radius > 0:
            params.append({"params": self.encoder_bg.parameters(), "lr": lr})
            params.append({"params": self.bg_net.parameters(), "lr": lr})

        return params
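
    # Note the two learning rates: grid-like parameters (hash tables, VM
    # factors, dense volume) use `lr`, while the small MLP heads of the "vm"
    # model use the separate (by default smaller) `lr2`.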


================================================
FILE: just_train_tea/provider.py
================================================
import os
import cv2
import glob
import json
import tqdm
import numpy as np
from scipy.spatial.transform import Slerp, Rotation

import trimesh

import torch
from torch.utils.data import DataLoader

from .utils import get_rays, srgb_to_linear


# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
def nerf_matrix_to_ngp(pose, scale=0.33):
    # for the fox dataset, 0.33 scales camera radius to ~ 2
    new_pose = np.array(
        [
            [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale],
            [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale],
            [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale],
            [0, 0, 0, 1],
        ],
        dtype=np.float32,
    )
    return new_pose
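
# The conversion cyclically permutes the rows (x, y, z -> y, z, x), flips the
# sign of the second and third rotation columns, and scales the translation,
# mapping the Blender/NeRF camera convention onto instant-ngp's.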


def rand_poses(
    size,
    device,
    radius=1,
    theta_range=[np.pi / 3, 2 * np.pi / 3],
    phi_range=[0, 2 * np.pi],
):
    """generate random poses from an orbit camera
    Args:
        size: batch size of generated poses.
        device: where to allocate the output.
        radius: camera radius
        theta_range: [min, max], should be in [0, \pi]
        phi_range: [min, max], should be in [0, 2\pi]
    Return:
        poses: [size, 4, 4]
    """

    def normalize(vectors):
        return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)

    thetas = (
        torch.rand(size, device=device) * (theta_range[1] - theta_range[0])
        + theta_range[0]
    )
    phis = (
        torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
    )

    centers = torch.stack(
        [
            radius * torch.sin(thetas) * torch.sin(phis),
            radius * torch.cos(thetas),
            radius * torch.sin(thetas) * torch.cos(phis),
        ],
        dim=-1,
    )  # [B, 3]

    # lookat
    forward_vector = -normalize(centers)
    up_vector = (
        torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
    )  # confused at the coordinate system...
    right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1))
    up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1))

    poses = (
        torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
    )
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    return poses
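
    # Geometry recap: centers = radius * (sin(theta)sin(phi), cos(theta),
    # sin(theta)cos(phi)) places cameras on a sphere with theta measured from
    # +y; the look-at construction stores (right, up, forward) as the rotation
    # columns, with forward pointing at the origin.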



class NeRFDataset:
    def __init__(self, opt, device, type="train", downscale=1, n_test=10):
        super().__init__()

        self.opt = opt
        self.args = opt
        self.device = device
        self.type = type  # train, val, test
        self.downscale = downscale
        self.root_path = opt.path
        self.mode = opt.mode  # only support blender
        self.preload = opt.preload  # preload data into GPU
        self.scale = (
            opt.scale
        )  # camera radius scale to make sure camera are inside the bounding box.
        self.bound = (
            opt.bound
        )  # bounding box half length, also used as the radius to random sample poses.
        self.fp16 = opt.fp16  # if preload, load into fp16.

        self.training = self.type in ["train", "all", "trainval"]
        self.num_rays = self.opt.num_rays if self.training else -1

        if self.mode == "blender":
            if type == "all":
                transform_paths = glob.glob(os.path.join(self.root_path, "*.json"))
                transform = None
                for transform_path in transform_paths:
                    with open(transform_path, "r") as f:
                        tmp_transform = json.load(f)
                        if transform is None:
                            transform = tmp_transform
                        else:
                            transform["frames"].extend(tmp_transform["frames"])
            # load train and val split
            elif type == "trainval":
                with open(
                    os.path.join(self.root_path, f"t
================================================
SYMBOL INDEX (210 symbols across 24 files)
================================================

FILE: distill_mutual/network.py
  class NeRFNetwork (line 12) | class NeRFNetwork(NeRFRenderer):
    method __init__ (line 13) | def __init__(
    method init_plenoxel_volume (line 184) | def init_plenoxel_volume(self, s=0.1, fea_dim=27 + 1, volume=[128, 128...
    method init_one_vm (line 193) | def init_one_vm(self, n_component, resolution, scale=0.1):
    method get_sigma_feat (line 216) | def get_sigma_feat(self, x):
    method get_color_feat (line 264) | def get_color_feat(self, x):
    method compute_plenoxel_fea (line 311) | def compute_plenoxel_fea(self, x):
    method forward_nerf_mlp (line 324) | def forward_nerf_mlp(self, x):
    method forward (line 335) | def forward(self, x, d):
    method density (line 439) | def density(self, x):
    method background (line 496) | def background(self, x, d):
    method color (line 515) | def color(self, x, d, mask=None, geo_feat=None, **kwargs):
    method density_loss (line 549) | def density_loss(self):
    method upsample_params (line 561) | def upsample_params(self, mat, vec, resolution):
    method upsample_model (line 584) | def upsample_model(self, resolution):
    method shrink_model (line 590) | def shrink_model(self):
    method get_params (line 646) | def get_params(self, lr, lr2=1e-3):

FILE: distill_mutual/provider.py
  function nerf_matrix_to_ngp (line 18) | def nerf_matrix_to_ngp(pose, scale=0.33):
  function rand_poses (line 32) | def rand_poses(
  class NeRFDataset (line 123) | class NeRFDataset:
    method __init__ (line 124) | def __init__(self, opt, device, type="train", downscale=1, n_test=10):
    method collate (line 284) | def collate(self, index):
    method dataloader (line 316) | def dataloader(self):

FILE: distill_mutual/renderer.py
  function sample_pdf (line 15) | def sample_pdf(bins, weights, n_samples, det=False):
  function plot_pointcloud (line 54) | def plot_pointcloud(pc, color=None):
  class NeRFRenderer (line 66) | class NeRFRenderer(nn.Module):
    method __init__ (line 67) | def __init__(
    method forward (line 117) | def forward(self, x, d):
    method density (line 121) | def density(self, x):
    method color (line 124) | def color(self, x, d, mask=None, **kwargs):
    method reset_extra_state (line 127) | def reset_extra_state(self):
    method run (line 139) | def run(
    method run_cuda (line 319) | def run_cuda(
    method mark_untrained_grid (line 562) | def mark_untrained_grid(self, poses, intrinsic, S=64):
    method update_extra_state (line 648) | def update_extra_state(self, decay=0.95, S=128):
    method render (line 777) | def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **k...

FILE: distill_mutual/utils.py
  function update_loss_rate (line 41) | def update_loss_rate(cur_lrate, scale=0.99):
  function get_softmax_map_mean (line 45) | def get_softmax_map_mean(a, b):
  function get_kl (line 49) | def get_kl(inputs, targets):
  function nerf_matrix_to_ngp (line 53) | def nerf_matrix_to_ngp(pose, scale=0.8):
  function pose_spherical (line 67) | def pose_spherical(theta, phi, radius):
  function get_rand_poses (line 100) | def get_rand_poses(data_type="synthetic", original_loader=None):
  function custom_meshgrid (line 201) | def custom_meshgrid(*args):
  function linear_to_srgb (line 210) | def linear_to_srgb(x):
  function srgb_to_linear (line 215) | def srgb_to_linear(x):
  function compute_ssim (line 219) | def compute_ssim(
  function init_lpips (line 303) | def init_lpips(net_name, device):
  function rgb_lpips (line 317) | def rgb_lpips(gt, im, net_name):
  function get_rays (line 325) | def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
  function seed_everything (line 407) | def seed_everything(seed):
  function torch_vis_2d (line 417) | def torch_vis_2d(x, renormalize=False):
  function extract_fields (line 442) | def extract_fields(bound_min, bound_max, resolution, query_func, S=128):
  function extract_geometry (line 473) | def extract_geometry(bound_min, bound_max, resolution, threshold, query_...
  class PSNRMeter (line 491) | class PSNRMeter:
    method __init__ (line 492) | def __init__(self):
    method clear (line 497) | def clear(self):
    method prepare_inputs (line 502) | def prepare_inputs(self, *inputs):
    method update (line 511) | def update(self, preds, truths):
    method measure (line 522) | def measure(self):
    method write (line 525) | def write(self, writer, global_step, prefix=""):
    method report (line 528) | def report(self):
  class Trainer (line 532) | class Trainer(object):
    method __init__ (line 533) | def __init__(
    method __del__ (line 672) | def __del__(self):
    method log (line 676) | def log(self, *args, **kwargs):
    method train (line 685) | def train(self, train_loader, valid_loader, max_epochs):
    method train_one_epoch (line 753) | def train_one_epoch(self, loader):
    method get_loss (line 941) | def get_loss(self, pred, gt):
    method train_step (line 954) | def train_step(self, data):
    method evaluate (line 1193) | def evaluate(self, loader, name=None):
    method evaluate_one_epoch (line 1198) | def evaluate_one_epoch(self, loader, name=None):
    method eval_step (line 1370) | def eval_step(self, data):
    method save_checkpoint (line 1405) | def save_checkpoint(self, name=None, full=False, best=False, remove_ol...
    method load_teacher_checkpoint (line 1477) | def load_teacher_checkpoint(self):
    method load_student_checkpoint (line 1531) | def load_student_checkpoint(self):
    method load_checkpoint (line 1589) | def load_checkpoint(self, checkpoint=None, model_only=False):
    method test (line 1653) | def test(self, loader, save_path=None, name=None):
    method test_step (line 1703) | def test_step(self, data, bg_color=None, perturb=False):

FILE: gridencoder/backend.py
  function find_cl_path (line 20) | def find_cl_path():

FILE: gridencoder/grid.py
  class _grid_encode (line 20) | class _grid_encode(Function):
    method forward (line 23) | def forward(
    method backward (line 96) | def backward(ctx, grad):
  class GridEncoder (line 142) | class GridEncoder(nn.Module):
    method __init__ (line 143) | def __init__(
    method reset_parameters (line 200) | def reset_parameters(self):
    method __repr__ (line 204) | def __repr__(self):
    method forward (line 207) | def forward(self, inputs, bound=1):

FILE: gridencoder/setup.py
  function find_cl_path (line 21) | def find_cl_path():

FILE: gridencoder/src/bindings.cpp
  function PYBIND11_MODULE (line 5) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: just_train_tea/network.py
  class NeRFNetwork (line 12) | class NeRFNetwork(NeRFRenderer):
    method __init__ (line 13) | def __init__(
    method init_plenoxel_volume (line 184) | def init_plenoxel_volume(self, s=0.1, fea_dim=27 + 1, volume=[128, 128...
    method init_one_vm (line 193) | def init_one_vm(self, n_component, resolution, scale=0.1):
    method get_sigma_feat (line 216) | def get_sigma_feat(self, x):
    method get_color_feat (line 264) | def get_color_feat(self, x):
    method compute_plenoxel_fea (line 311) | def compute_plenoxel_fea(self, x):
    method forward_nerf_mlp (line 320) | def forward_nerf_mlp(self, x):
    method forward (line 331) | def forward(self, x, d):
    method density (line 422) | def density(self, x):
    method background (line 479) | def background(self, x, d):
    method color (line 498) | def color(self, x, d, mask=None, geo_feat=None, **kwargs):
    method density_loss (line 532) | def density_loss(self):
    method upsample_params (line 544) | def upsample_params(self, mat, vec, resolution):
    method upsample_model (line 567) | def upsample_model(self, resolution):
    method shrink_model (line 573) | def shrink_model(self):
    method get_params (line 628) | def get_params(self, lr, lr2=1e-3):

FILE: just_train_tea/provider.py
  function nerf_matrix_to_ngp (line 18) | def nerf_matrix_to_ngp(pose, scale=0.33):
  function rand_poses (line 32) | def rand_poses(
  class NeRFDataset (line 123) | class NeRFDataset:
    method __init__ (line 124) | def __init__(self, opt, device, type="train", downscale=1, n_test=10):
    method collate (line 284) | def collate(self, index):
    method dataloader (line 316) | def dataloader(self):

FILE: just_train_tea/renderer.py
  function sample_pdf (line 14) | def sample_pdf(bins, weights, n_samples, det=False):
  function plot_pointcloud (line 53) | def plot_pointcloud(pc, color=None):
  class NeRFRenderer (line 65) | class NeRFRenderer(nn.Module):
    method __init__ (line 66) | def __init__(
    method forward (line 116) | def forward(self, x, d):
    method density (line 120) | def density(self, x):
    method color (line 123) | def color(self, x, d, mask=None, **kwargs):
    method reset_extra_state (line 126) | def reset_extra_state(self):
    method run (line 138) | def run(
    method run_cuda (line 319) | def run_cuda(
    method mark_untrained_grid (line 555) | def mark_untrained_grid(self, poses, intrinsic, S=64):
    method update_extra_state (line 641) | def update_extra_state(self, decay=0.95, S=128):
    method render (line 770) | def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **k...

FILE: just_train_tea/utils.py
  function custom_meshgrid (line 36) | def custom_meshgrid(*args):
  function linear_to_srgb (line 45) | def linear_to_srgb(x):
  function srgb_to_linear (line 50) | def srgb_to_linear(x):
  function compute_ssim (line 54) | def compute_ssim(
  function init_lpips (line 138) | def init_lpips(net_name, device):
  function rgb_lpips (line 152) | def rgb_lpips(gt, im, net_name):
  function get_rays (line 160) | def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
  function seed_everything (line 242) | def seed_everything(seed):
  function torch_vis_2d (line 252) | def torch_vis_2d(x, renormalize=False):
  function extract_fields (line 277) | def extract_fields(bound_min, bound_max, resolution, query_func, S=128):
  function extract_geometry (line 308) | def extract_geometry(bound_min, bound_max, resolution, threshold, query_...
  class PSNRMeter (line 326) | class PSNRMeter:
    method __init__ (line 327) | def __init__(self):
    method clear (line 331) | def clear(self):
    method prepare_inputs (line 335) | def prepare_inputs(self, *inputs):
    method update (line 344) | def update(self, preds, truths):
    method measure (line 355) | def measure(self):
    method write (line 358) | def write(self, writer, global_step, prefix=""):
    method report (line 361) | def report(self):
  class Trainer (line 365) | class Trainer(object):
    method __init__ (line 366) | def __init__(
    method __del__ (line 487) | def __del__(self):
    method log (line 491) | def log(self, *args, **kwargs):
    method train (line 500) | def train(self, train_loader, valid_loader, max_epochs):
    method train_one_epoch (line 543) | def train_one_epoch(self, loader):
    method get_loss (line 733) | def get_loss(self, pred, gt):
    method train_step (line 746) | def train_step(self, data):
    method evaluate (line 848) | def evaluate(self, loader, name=None):
    method evaluate_one_epoch (line 853) | def evaluate_one_epoch(self, loader, name=None):
    method eval_step (line 1028) | def eval_step(self, data):
    method save_checkpoint (line 1063) | def save_checkpoint(self, name=None, full=False, best=False, remove_ol...
    method load_teacher_checkpoint (line 1135) | def load_teacher_checkpoint(self):
    method load_student_checkpoint (line 1158) | def load_student_checkpoint(self):
    method test (line 1187) | def test(self, loader, save_path=None, name=None):
    method test_step (line 1237) | def test_step(self, data, bg_color=None, perturb=False):

FILE: main_distill_mutual.py
  function save_codes_env (line 15) | def save_codes_env(workspace):
  function load_from_txt (line 24) | def load_from_txt(opt, except_space=""):

FILE: raymarching/backend.py
  function find_cl_path (line 20) | def find_cl_path():

FILE: raymarching/raymarching.py
  class _near_far_from_aabb (line 20) | class _near_far_from_aabb(Function):
    method forward (line 23) | def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
  class _polar_from_ray (line 56) | class _polar_from_ray(Function):
    method forward (line 59) | def forward(ctx, rays_o, rays_d, radius):
  class _morton3D (line 90) | class _morton3D(Function):
    method forward (line 92) | def forward(ctx, coords):
  class _morton3D_invert (line 116) | class _morton3D_invert(Function):
    method forward (line 118) | def forward(ctx, indices):
  class _packbits (line 141) | class _packbits(Function):
    method forward (line 144) | def forward(ctx, grid, thresh, bitfield=None):
  class _march_rays_train (line 176) | class _march_rays_train(Function):
    method forward (line 179) | def forward(
  class _composite_rays_train (line 292) | class _composite_rays_train(Function):
    method forward (line 295) | def forward(ctx, sigmas, rgbs, deltas, rays):
    method backward (line 329) | def backward(ctx, grad_weights_sum, grad_depth, grad_image):
  class _march_rays (line 367) | class _march_rays(Function):
    method forward (line 370) | def forward(
  class _composite_rays (line 457) | class _composite_rays(Function):
    method forward (line 460) | def forward(
  class _compact_rays (line 505) | class _compact_rays(Function):
    method forward (line 508) | def forward(

FILE: raymarching/setup.py
  function find_cl_path (line 21) | def find_cl_path():

FILE: raymarching/src/bindings.cpp
  function PYBIND11_MODULE (line 5) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: raymarching/src/pcg32.h
  type pcg32 (line 44) | struct pcg32 {
  function next_uint (line 66) | uint32_t next_uint() {
  function next_uint (line 75) | uint32_t next_uint(uint32_t bound) {
  function next_float (line 107) | float next_float() {
  function next_double (line 125) | double next_double() {
  function operator (line 198) | bool operator==(const pcg32 &other) const { return state == other.state ...
  function operator (line 201) | bool operator!=(const pcg32 &other) const { return state != other.state ...

FILE: shencoder/backend.py
  function find_cl_path (line 20) | def find_cl_path():

FILE: shencoder/setup.py
  function find_cl_path (line 21) | def find_cl_path():

FILE: shencoder/sphere_harmonics.py
  class _sh_encoder (line 15) | class _sh_encoder(Function):
    method forward (line 18) | def forward(ctx, inputs, degree, calc_grad_inputs=False):
    method backward (line 48) | def backward(ctx, grad):
  class SHEncoder (line 67) | class SHEncoder(nn.Module):
    method __init__ (line 68) | def __init__(self, input_dim=3, degree=4):
    method __repr__ (line 80) | def __repr__(self):
    method forward (line 83) | def forward(self, inputs, size=1):

FILE: shencoder/src/bindings.cpp
  function PYBIND11_MODULE (line 5) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: tools/activation.py
  class _trunc_exp (line 6) | class _trunc_exp(Function):
    method forward (line 9) | def forward(ctx, x):
    method backward (line 15) | def backward(ctx, g):

FILE: tools/encoding.py
  class FreqEncoder (line 6) | class FreqEncoder(nn.Module):
    method __init__ (line 7) | def __init__(
    method forward (line 36) | def forward(self, input, **kwargs):
  function get_encoder (line 52) | def get_encoder(
