Repository: megvii-research/AAAI2023-PVD
Branch: main
Commit: e5d0ab174f24
Files: 40
Total size: 415.1 KB
Directory structure:
gitextract_eepeq680/
├── LICENSE
├── README.md
├── distill_mutual/
│ ├── network.py
│ ├── provider.py
│ ├── renderer.py
│ └── utils.py
├── gridencoder/
│ ├── __init__.py
│ ├── backend.py
│ ├── grid.py
│ ├── setup.py
│ └── src/
│ ├── bindings.cpp
│ ├── gridencoder.cu
│ └── gridencoder.h
├── just_train_tea/
│ ├── network.py
│ ├── provider.py
│ ├── renderer.py
│ └── utils.py
├── main_distill_mutual.py
├── main_just_train_tea.py
├── raymarching/
│ ├── __init__.py
│ ├── backend.py
│ ├── raymarching.py
│ ├── setup.py
│ └── src/
│ ├── bindings.cpp
│ ├── pcg32.h
│ ├── raymarching.cu
│ └── raymarching.h
├── shencoder/
│ ├── __init__.py
│ ├── backend.py
│ ├── setup.py
│ ├── sphere_harmonics.py
│ └── src/
│ ├── bindings.cpp
│ ├── shencoder.cu
│ └── shencoder.h
└── tools/
├── activation.py
├── details.md
├── encoding.py
├── install_extensions.sh
├── requirements.txt
└── 中文介绍.md
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
Copyright 2022 Megvii Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
================================================
FILE: README.md
================================================
## One is All: Bridging the Gap Between Neural Radiance Fields Architectures with Progressive Volume Distillation (AAAI 2023 Oral)
# :partying_face: ***New*** :partying_face: Code for more powerful PVD-AL (the follow-up work of PVD) is now provided [here](https://github.com/megvii-research/AAAI2023-PVD/tree/PVD-AL).
*(We strongly recommend using [PVD-AL](https://github.com/megvii-research/AAAI2023-PVD/tree/PVD-AL), the follow-up work of PVD, for better performance.)*
## [Project Page](http://sk-fun.fun/PVD/) | [Paper](https://arxiv.org/abs/2211.15977) | [Datasets](https://drive.google.com/drive/folders/1U06KAEsW53PolLI3U8hWUhzzIH74QGaP?usp=sharing) | [Ckpts](https://drive.google.com/drive/folders/1GGJf-FTmpCJjmEn-AF_S9-HrLRkFe5Ud?usp=sharing) | [Chinese tutorial](https://github.com/megvii-research/AAAI2023-PVD/blob/main/tools/%E4%B8%AD%E6%96%87%E4%BB%8B%E7%BB%8D.md) | [zhihu](https://zhuanlan.zhihu.com/p/605121286)|
## Introduction
In this paper, we propose Progressive Volume Distillation (PVD), a systematic distillation method that allows any-to-any conversion between different neural architectures, including MLPs (NeRF), sparse tensors (Plenoxels), low-rank tensors (TensoRF), and hash tables (INGP).
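Conceptually, distillation replaces ground-truth pixels with the teacher's own predictions as the supervision signal for the student. Below is a minimal numpy sketch of the naive point-wise version of this idea (the `teacher` function and the voxel-grid student are illustrative stand-ins, not the models in this repo; PVD itself adds progressive, block-wise alignment on top of point-wise supervision):

```python
import numpy as np

def teacher(pts):
    # stand-in for a pretrained radiance field: density of a Gaussian blob
    return np.exp(-np.sum(pts ** 2, axis=-1))

# student: a dense voxel grid filled by querying the teacher at cell centers
res = 17
idx = np.indices((res, res, res)).reshape(3, -1).T  # [res^3, 3] integer coords
pts = idx / (res - 1) * 2.0 - 1.0                   # cell centers in [-1, 1]^3
student_grid = teacher(pts).reshape(res, res, res)

c = res // 2
print(student_grid[c, c, c])  # density at the origin: exp(0) = 1.0
```

The same query-and-fit view explains why any architecture can act as teacher or student: only the ability to evaluate density/color at 3D points is required.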
## Installation
We recommend using [Anaconda](https://www.anaconda.com/) to set up the environment. Run the following commands:
*Step 1*: Create a conda environment named 'pvd'
```
conda create --name pvd python=3.7
conda activate pvd
pip install -r ./tools/requirements.txt
```
*Step 2*: Install the extension modules. (Adapted from the great project [torch-ngp](https://github.com/ashawkey/torch-ngp), which we mainly build on.)
```
bash ./tools/install_extensions.sh
```
## Datasets & Pretrained teacher models
You can download Synthetic-NeRF/LLFF/Tanks&Temples datasets from [google](https://drive.google.com/drive/folders/1U06KAEsW53PolLI3U8hWUhzzIH74QGaP?usp=sharing), or from [baidu](https://pan.baidu.com/s/1ky_TWrbUZG_MpHTBhncAKA?pwd=4h2h).
And download pretrained-teacher-models from [google](https://drive.google.com/drive/folders/1GGJf-FTmpCJjmEn-AF_S9-HrLRkFe5Ud?usp=sharing), or from [baidu](https://pan.baidu.com/s/1LGLXwLGusX60GpAywLwosg?pwd=34k8).
You can also train a teacher model yourself by following the guidance below.
## Train a teacher
```
# train a hash-based (INGP) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type hash --data_type synthetic --workspace ./log/train_teacher/hash_chair
# train a sparse-tensor-based (TensoRF VM-decomposition) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type vm --data_type synthetic --workspace ./log/train_teacher/vm_chair
# train an MLP-based (NeRF) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type mlp --data_type synthetic --workspace ./log/train_teacher/mlp_chair
# train a tensor-based (Plenoxels) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type tensors --data_type synthetic --workspace ./log/train_teacher/tensors_chair
```
## Distill a student
```
# teacher: hash (INGP), student: vm (TensoRF)
python3 main_distill_mutual.py ./data/nerf_synthetic/chair \
--data_type synthetic \
--teacher_type hash \
--ckpt_teacher ./log/train_teacher/hash_chair/checkpoints/XXX.pth \
--model_type vm \
--workspace ./log/distill_student/hash2vm/chair
# teacher: MLP (NeRF), student: tensors (Plenoxels)
python3 main_distill_mutual.py ./data/nerf_synthetic/chair \
--data_type synthetic \
--teacher_type mlp \
--ckpt_teacher ./log/train_teacher/mlp_chair/checkpoints/XXX.pth \
--model_type tensors \
--workspace ./log/distill_student/mlp2tensors/chair
```
## Evaluation
```
# evaluate a hash teacher
python main_distill_mutual.py ./data/nerf_synthetic/chair --teacher_type hash --ckpt_teacher PATH/TO/CKPT.pth --test_teacher --data_type synthetic --workspace ./log/eval_teacher/hash_chair
# evaluate an MLP student
python main_distill_mutual.py ./data/nerf_synthetic/chair --model_type mlp --ckpt PATH/TO/CKPT.pth --test --data_type synthetic --workspace ./log/eval_student/mlp_chair
```
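Whichever architecture is evaluated, the per-sample densities and colors it predicts are turned into pixels by the standard volume-rendering quadrature. A minimal numpy sketch of that compositing step for a single ray (a hypothetical helper for illustration, not the repo's CUDA renderer):

```python
import numpy as np

def composite(sigmas, colors, deltas):
    # alpha_i = 1 - exp(-sigma_i * delta_i); T_i = prod_{j<i} (1 - alpha_j)
    # pixel = sum_i T_i * alpha_i * color_i; acc = sum_i T_i * alpha_i
    alphas = 1.0 - np.exp(-sigmas * deltas)
    trans = np.cumprod(np.concatenate([[1.0], 1.0 - alphas[:-1]]))
    weights = trans * alphas
    return (weights[:, None] * colors).sum(axis=0), weights.sum()

# three samples along one ray: only the middle (green) one is dense
sigmas = np.array([0.0, 10.0, 0.0])
colors = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
rgb, acc = composite(sigmas, colors, np.full(3, 0.1))
print(rgb, acc)  # only the green sample contributes, with weight 1 - exp(-1)
```

PSNR between such rendered images and the ground-truth test views is what the evaluation commands above report.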
## More detailed parameter descriptions and running commands
Please refer to [more running description](https://github.com/megvii-research/AAAI2023-PVD/blob/main/tools/details.md) for details on training different types of datasets, parameter tuning, key settings, etc.
## Citation
If you find our code or paper useful, please consider citing:
```
@article{fang2022one,
title={One is All: Bridging the Gap Between Neural Radiance Fields Architectures with Progressive Volume Distillation},
author={Fang, Shuangkang and Xu, Weixin and Wang, Heng and Yang, Yi and Wang, Yufeng and Zhou, Shuchang},
journal={arXiv preprint arXiv:2211.15977},
year={2022}
}
```
### Acknowledgement
We would like to thank [ingp](https://github.com/NVlabs/instant-ngp), [torch-ngp](https://github.com/ashawkey/torch-ngp), [TensoRF](https://github.com/apchenstu/TensoRF), [Plenoxels](https://github.com/sxyu/svox2), [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch) for their great frameworks!
Also check out [Arch-Net](https://github.com/megvii-research/Arch-Net) for more on general progressive distillation.
================================================
FILE: distill_mutual/network.py
================================================
import torch
from time import time
import torch.nn as nn
import torch.nn.functional as F
from tools.encoding import get_encoder
from tools.activation import trunc_exp
from .renderer import NeRFRenderer
import raymarching
class NeRFNetwork(NeRFRenderer):
def __init__(
self,
encoding="hashgrid",
encoding_dir="sphere_harmonics",
encoding_bg="hashgrid",
num_layers=2,
hidden_dim=64,
geo_feat_dim=15,
num_layers_color=3,
hidden_dim_color=64,
num_layers_bg=2,
hidden_dim_bg=64,
bound=1,
model_type="hash",
args=None,
is_teacher=False,
**kwargs,
):
super().__init__(bound, **kwargs)
# sigma network
assert model_type in ["hash", "mlp", "vm", "tensors"]
self.is_teacher = is_teacher
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.geo_feat_dim = geo_feat_dim
self.args = args
self.opt = args
self.model_type = model_type
self.plenoxel_degree = args.plenoxel_degree
self.plenoxel_res = eval(args.plenoxel_res)
assert len(self.plenoxel_res) == 3
self.encoder, self.in_dim = get_encoder(
encoding,
desired_resolution=2048 * bound,
num_levels=14,
)
if "hash" != self.model_type:
self.encoder = None
if self.model_type == "mlp":
self.encoder_nerf_pe, self.in_dim_nerf = get_encoder(
encoding="frequency", multires=self.args.PE
)
self.skips = self.args.skip
self.nerf_layer_num = self.args.nerf_layer_num
W = self.args.nerf_layer_wide
self.nerf_mlp = [nn.Linear(self.in_dim_nerf, W)]
for i in range(self.nerf_layer_num - 2):
if i != self.skips:
self.nerf_mlp.append(nn.Linear(W, W))
else:
self.nerf_mlp.append(nn.Linear(W + self.in_dim_nerf, W))
self.nerf_mlp.append(nn.Linear(W, self.in_dim))
self.nerf_mlp = nn.ModuleList(self.nerf_mlp)
elif self.model_type == "vm":
self.sigma_rank = [16] * 3
self.color_rank = [48] * 3
self.color_feat_dim = 15 # geo_feat_dim
self.mat_ids = [[0, 1], [0, 2], [1, 2]]
self.vec_ids = [2, 1, 0]
self.resolution = [self.opt.resolution0] * 3
# mat: paralist[1,16,res0,res0] repeat 3 vec: paralist[1,16,res0,1] repeat 3; repeat3 because decompose 3D grid [H, W, D] to three 2D mat [H, W], [H,D], [W, D] or decompose to three 1D vec [H], [W], [D]
self.sigma_mat, self.sigma_vec = self.init_one_vm(
self.sigma_rank, self.resolution
)
# mat: paralist[1,48,res0,res0] repeat 3 vec: paralist[1,48,res0,1] repeat 3
self.color_mat, self.color_vec = self.init_one_vm(
self.color_rank, self.resolution
)
# Linear(in_features=144, out_features=27)
self.basis_mat = nn.Linear(
sum(self.color_rank), self.color_feat_dim, bias=False
)
elif self.model_type == "tensors":
self.init_plenoxel_volume(
s=0.02,
fea_dim=self.plenoxel_degree ** 2 * 3 + 1,
volume=self.plenoxel_res,
)
elif self.model_type == "hash":
pass
else:
raise ValueError(f"error model_type:{self.model_type}")
if self.model_type != "vm" and self.model_type != "tensors":
sigma_net = []
for l in range(num_layers):
if l == 0:
in_dim = self.in_dim
else:
in_dim = hidden_dim
if l == num_layers - 1:
out_dim = (
1 + self.geo_feat_dim
) # 1 sigma + 15 SH features for color
else:
out_dim = hidden_dim
sigma_net.append(nn.Linear(in_dim, out_dim, bias=False))
self.sigma_net = nn.ModuleList(sigma_net)
# color network
self.num_layers_color = num_layers_color
self.hidden_dim_color = hidden_dim_color
# self.encoder_dir, self.in_dim_dir = get_encoder(encoding=encoding_dir)
if self.model_type == "tensors":
self.encoder_dir, self.in_dim_dir = get_encoder(
encoding="sphere_harmonics",
degree=self.plenoxel_degree,
)
else:
self.encoder_dir, self.in_dim_dir = get_encoder(
encoding=encoding_dir, input_dim=3, multires=2
)
if self.model_type != "tensors":
color_net = []
for l in range(num_layers_color):
if l == 0:
in_dim = self.in_dim_dir + self.geo_feat_dim
else:
in_dim = hidden_dim
if l == num_layers_color - 1:
out_dim = 3 # 3 rgb
else:
out_dim = hidden_dim
color_net.append(nn.Linear(in_dim, out_dim, bias=False))
self.color_net = nn.ModuleList(color_net)
# background network
if self.bg_radius > 0:
self.num_layers_bg = num_layers_bg
self.hidden_dim_bg = hidden_dim_bg
self.encoder_bg, self.in_dim_bg = get_encoder(
encoding_bg,
input_dim=2,
num_levels=4,
log2_hashmap_size=19,
desired_resolution=2048,
) # much smaller hashgrid
bg_net = []
for l in range(num_layers_bg):
if l == 0:
in_dim = self.in_dim_bg + self.in_dim_dir
else:
in_dim = hidden_dim_bg
if l == num_layers_bg - 1:
out_dim = 3 # 3 rgb
else:
out_dim = hidden_dim_bg
bg_net.append(nn.Linear(in_dim, out_dim, bias=False))
self.bg_net = nn.ModuleList(bg_net)
else:
self.bg_net = None
def init_plenoxel_volume(self, s=0.1, fea_dim=27 + 1, volume=[128, 128, 128]):
tensor = []
tensor.append(
torch.nn.Parameter(
s * torch.randn((1, fea_dim, volume[0], volume[1], volume[2]))
)
)
self.tensor_volume = torch.nn.ParameterList(tensor).cuda()
def init_one_vm(self, n_component, resolution, scale=0.1):
# self.mat_ids = [[0, 1], [0, 2], [1, 2]] self.vec_ids = [2, 1, 0]
mat, vec = [], []
for i in range(len(self.vec_ids)):
vec_id = self.vec_ids[i]
mat_id_0, mat_id_1 = self.mat_ids[i]
mat.append(
nn.Parameter(
scale
* torch.randn(
(1, n_component[i], resolution[mat_id_1], resolution[mat_id_0])
)
)
) # [1, R, H, W]
vec.append(
nn.Parameter(
scale * torch.randn((1, n_component[i], resolution[vec_id], 1))
)
) # [1, R, D, 1] (fake 2d to use grid_sample)
return nn.ParameterList(mat), nn.ParameterList(vec)
def get_sigma_feat(self, x):
# x: [N, 3], in [-1, 1] (outliers will be treated as zero due to grid_sample padding mode)
# self.mat_ids = [[0, 1], [0, 2], [1, 2]] self.vec_ids = [2, 1, 0]
N = x.shape[0]
# plane + line basis
mat_coord = (
torch.stack(
(
x[..., self.mat_ids[0]],
x[..., self.mat_ids[1]],
x[..., self.mat_ids[2]],
)
)
.detach()
.view(3, -1, 1, 2)
) # [3, N, 1, 2]
vec_coord = torch.stack(
(x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])
)
vec_coord = (
torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1)
.detach()
.view(3, -1, 1, 2)
) # [3, N, 1, 2], fake 2d coord
sigma_feat = torch.zeros(
[
N,
],
device=x.device,
)
for i in range(len(self.sigma_mat)):
mat_feat = F.grid_sample(
self.sigma_mat[i], mat_coord[[i]], align_corners=True
).view(
-1, N
) # [1, R, N, 1] --> [R, N]
vec_feat = F.grid_sample(
self.sigma_vec[i], vec_coord[[i]], align_corners=True
).view(
-1, N
) # [R, N]
sigma_feat = sigma_feat + torch.sum(mat_feat * vec_feat, dim=0)
return sigma_feat
def get_color_feat(self, x):
# x: [N, 3], in [-1, 1]
N = x.shape[0]
# plane + line basis
mat_coord = (
torch.stack(
(
x[..., self.mat_ids[0]],
x[..., self.mat_ids[1]],
x[..., self.mat_ids[2]],
)
)
.detach()
.view(3, -1, 1, 2)
) # [3, N, 1, 2]
vec_coord = torch.stack(
(x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])
)
vec_coord = (
torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1)
.detach()
.view(3, -1, 1, 2)
) # [3, N, 1, 2], fake 2d coord
mat_feat, vec_feat = [], []
for i in range(len(self.color_mat)):
mat_feat.append(
F.grid_sample(
self.color_mat[i], mat_coord[[i]], align_corners=True
).view(-1, N)
) # [1, R, N, 1] --> [R, N]
vec_feat.append(
F.grid_sample(
self.color_vec[i], vec_coord[[i]], align_corners=True
).view(-1, N)
) # [R, N]
mat_feat = torch.cat(mat_feat, dim=0) # [3 * R, N]
vec_feat = torch.cat(vec_feat, dim=0) # [3 * R, N]
color_feat = self.basis_mat(
(mat_feat * vec_feat).T
) # [N, 3R] --> [N, color_feat_dim]
return color_feat
def compute_plenoxel_fea(self, x):
composed = self.tensor_volume[0]
if self.args.enable_edit_plenoxel and self.is_teacher:
composed[
:, 0, :, 160:, :128
] = -100 # This will erase the bucket in the lego scene for resolution 256
composed = (
F.grid_sample(composed, x.view(1, 1, -1, 1, 3), align_corners=True)
.view(-1, x.shape[0])
.permute(1, 0)
)
return composed # [N, fea_dim]
def forward_nerf_mlp(self, x):
x = self.encoder_nerf_pe(x)
in_pts = x
for i in range(len(self.nerf_mlp)):
x = self.nerf_mlp[i](x)
if i != len(self.nerf_mlp) - 1:
x = F.relu(x, inplace=True)
if i == self.skips:
x = torch.cat([in_pts, x], -1)
return x
def forward(self, x, d):
# x: [N, 3], in [-bound, bound] d: [N, 3], normalized in [-1, 1]
# sigma
if self.model_type == "hash":
x = self.encoder(
x, bound=self.bound
) # out_x[N, 28=num_levels * fea_per_level]
elif self.model_type == "mlp":
x = self.forward_nerf_mlp(x) # 28
elif self.model_type == "vm":
x = (
2
* (x - self.aabb_train[:3])
/ (self.aabb_train[3:] - self.aabb_train[:3])
- 1
) # x:[N, 3]
sigma_feat = self.get_sigma_feat(x) # sigma_feat:[N]
color_feat = self.get_color_feat(x) # color_feat:[N, 15]
if self.opt.enable_edit_plenoxel:
sigma_feat = torch.clamp(sigma_feat, -100, self.args.sigma_clip_max)
else:
sigma_feat = torch.clamp(
sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max
)
color_feat = torch.clamp(
color_feat, self.args.sigma_clip_min, self.args.sigma_clip_max
)
self.feature_sigma_color = torch.cat(
[sigma_feat.unsqueeze(-1), color_feat], dim=-1
)
if (
self.training
and self.args.global_step < self.args.stage_iters["stage1"]
):
return None, None
self.sigma_l = sigma_feat
sigma = trunc_exp(sigma_feat) # sigma:[N]
enc_d = self.encoder_dir(d) # enc_d:[N, 16]
h = torch.cat([enc_d, color_feat], dim=-1) # h:[N, 16+15]
for l in range(self.num_layers_color):
h = self.color_net[l](h)
if l != self.num_layers_color - 1:
h = F.relu(h, inplace=True)
color = torch.sigmoid(h)
self.color_l = color
return sigma, color
elif self.model_type == "tensors":
x = (
2
* (x - self.aabb_train[:3])
/ (self.aabb_train[3:] - self.aabb_train[:3])
- 1
) # x:[N, 3]
x = self.compute_plenoxel_fea(x)
h = x
if self.opt.enable_edit_plenoxel:
sigma = torch.clamp(h[..., 0], -100, self.args.sigma_clip_max)
else:
sigma = torch.clamp(
h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max
)
self.sigma_l = sigma
sigma = trunc_exp(sigma)
self.sigma = sigma
sh = h[..., 1:].view(
-1, 3, self.plenoxel_degree ** 2
) # [N, 3, 9] ## .permute(1, 0, 2) # [B, 27]-->[9, B, 3]
enc_d = self.encoder_dir(d).unsqueeze(1) # [N, 9]-->[N,1,9]
color = (sh * enc_d).sum(-1) # [N, 3]
color = torch.sigmoid(color)
self.feature_sigma_color = None
self.color_l = color
return sigma, color
else:
raise ValueError(f"illegal model_type: {self.model_type}")
h = x
for l in range(self.num_layers):
h = self.sigma_net[l](h)
if l != self.num_layers - 1:
h = F.relu(h, inplace=True)
h[..., 0] = torch.clamp(
h[..., 0].clone(), self.args.sigma_clip_min, self.args.sigma_clip_max
)
self.feature_sigma_color = h
if self.training and self.args.global_step < self.args.stage_iters["stage1"]:
return None, None
self.sigma_l = h[..., 0]
sigma = trunc_exp(h[..., 0]) # sigma: [n]
geo_feat = h[..., 1:] # geo_feat: [n, 15]
d = self.encoder_dir(d) # d: [n, 16]
h = torch.cat([d, geo_feat], dim=-1) # h: [n, 15+16]
for l in range(self.num_layers_color):
h = self.color_net[l](h)
if l != self.num_layers_color - 1:
h = F.relu(h, inplace=True)
color = torch.sigmoid(h)
self.color_l = color
return sigma, color
def density(self, x):
# x: [N, 3], in [-bound, bound]
if self.model_type == "hash":
x = self.encoder(
x, bound=self.bound
) # out_x[N, 32=num_levels * fea_per_level]
elif self.model_type == "mlp":
x = self.forward_nerf_mlp(x)
elif self.model_type == "vm":
x = (
2
* (x - self.aabb_train[:3])
/ (self.aabb_train[3:] - self.aabb_train[:3])
- 1
)
sigma_feat = self.get_sigma_feat(x)
sigma_feat = torch.clamp(
sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max
)
sigma = trunc_exp(sigma_feat)
return {"sigma": sigma}
elif self.model_type == "tensors":
x = (
2
* (x - self.aabb_train[:3])
/ (self.aabb_train[3:] - self.aabb_train[:3])
- 1
) # x:[N, 3]
x = self.compute_plenoxel_fea(x)
h = x
# h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max)
sigma = trunc_exp(
torch.clamp(
h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max
)
)
return {"sigma": sigma}
else:
raise ValueError(f"illegal model_type: {self.model_type}")
h = x
for l in range(self.num_layers):
h = self.sigma_net[l](h)
if l != self.num_layers - 1:
h = F.relu(h, inplace=True)
h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max)
sigma = trunc_exp(h[..., 0])
geo_feat = h[..., 1:]
return {
"sigma": sigma,
"geo_feat": geo_feat,
}
def background(self, x, d):
raise NotImplementedError("background() is not used in this pipeline")
# x: [N, 2], in [-1, 1]
h = self.encoder_bg(x) # [N, C]
d = self.encoder_dir(d)
h = torch.cat([d, h], dim=-1)
for l in range(self.num_layers_bg):
h = self.bg_net[l](h)
if l != self.num_layers_bg - 1:
h = F.relu(h, inplace=True)
# sigmoid activation for rgb
rgbs = torch.sigmoid(h)
return rgbs
# allow masked inference
def color(self, x, d, mask=None, geo_feat=None, **kwargs):
raise NotImplementedError("color() is not used in this pipeline")
# x: [N, 3] in [-bound, bound]
# mask: [N,], bool, indicates where we actually needs to compute rgb.
if mask is not None:
rgbs = torch.zeros(
mask.shape[0], 3, dtype=x.dtype, device=x.device
) # [N, 3]
# in case of empty mask
if not mask.any():
return rgbs
x = x[mask]
d = d[mask]
geo_feat = geo_feat[mask]
d = self.encoder_dir(d)
h = torch.cat([d, geo_feat], dim=-1)
for l in range(self.num_layers_color):
h = self.color_net[l](h)
if l != self.num_layers_color - 1:
h = F.relu(h, inplace=True)
# sigmoid activation for rgb
h = torch.sigmoid(h)
if mask is not None:
rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32
else:
rgbs = h
return rgbs
# L1 penalty for loss
def density_loss(self):
loss = 0
for i in range(len(self.sigma_mat)):
loss = (
loss
+ torch.mean(torch.abs(self.sigma_mat[i]))
+ torch.mean(torch.abs(self.sigma_vec[i]))
)
return loss
# upsample utils
@torch.no_grad()
def upsample_params(self, mat, vec, resolution):
for i in range(len(self.vec_ids)):
vec_id = self.vec_ids[i]
mat_id_0, mat_id_1 = self.mat_ids[i]
mat[i] = nn.Parameter(
F.interpolate(
mat[i].data,
size=(resolution[mat_id_1], resolution[mat_id_0]),
mode="bilinear",
align_corners=True,
)
)
vec[i] = nn.Parameter(
F.interpolate(
vec[i].data,
size=(resolution[vec_id], 1),
mode="bilinear",
align_corners=True,
)
)
@torch.no_grad()
def upsample_model(self, resolution):
self.upsample_params(self.sigma_mat, self.sigma_vec, resolution)
self.upsample_params(self.color_mat, self.color_vec, resolution)
self.resolution = resolution
@torch.no_grad()
def shrink_model(self):
# shrink aabb_train and the model so it only represents the space inside aabb_train.
half_grid_size = self.bound / self.grid_size
thresh = min(self.density_thresh, self.mean_density)
# get new aabb from the coarsest density grid (TODO: from the finest that covers current aabb?)
valid_grid = self.density_grid[self.cascade - 1] > thresh # [N]
valid_pos = raymarching.morton3D_invert(
torch.nonzero(valid_grid)
) # [Nz] --> [Nz, 3], in [0, H - 1]
# plot_pointcloud(valid_pos.detach().cpu().numpy()) # lots of noisy outliers in hashnerf...
valid_pos = (2 * valid_pos / (self.grid_size - 1) - 1) * (
self.bound - half_grid_size
) # [Nz, 3], in [-b+hgs, b-hgs]
min_pos = valid_pos.amin(0) - half_grid_size # [3]
max_pos = valid_pos.amax(0) + half_grid_size # [3]
# shrink model
reso = torch.LongTensor(self.resolution).to(self.aabb_train.device)
units = (self.aabb_train[3:] - self.aabb_train[:3]) / reso
tl = (min_pos - self.aabb_train[:3]) / units
br = (max_pos - self.aabb_train[:3]) / units
tl = torch.round(tl).long().clamp(min=0)
br = torch.minimum(torch.round(br).long(), reso)
for i in range(len(self.vec_ids)):
vec_id = self.vec_ids[i]
mat_id_0, mat_id_1 = self.mat_ids[i]
self.sigma_vec[i] = nn.Parameter(
self.sigma_vec[i].data[..., tl[vec_id] : br[vec_id], :]
)
self.color_vec[i] = nn.Parameter(
self.color_vec[i].data[..., tl[vec_id] : br[vec_id], :]
)
self.sigma_mat[i] = nn.Parameter(
self.sigma_mat[i].data[
..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0]
]
)
self.color_mat[i] = nn.Parameter(
self.color_mat[i].data[
..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0]
]
)
self.aabb_train = torch.cat([min_pos, max_pos], dim=0) # [6]
print(
f"[INFO] shrink slice: {tl.cpu().numpy().tolist()} - {br.cpu().numpy().tolist()}"
)
print(f"[INFO] new aabb: {self.aabb_train.cpu().numpy().tolist()}")
# optimizer utils
def get_params(self, lr, lr2=1e-3):
if self.model_type == "hash":
params = [
{"params": self.encoder.parameters(), "lr": lr},
{"params": self.sigma_net.parameters(), "lr": lr},
{"params": self.encoder_dir.parameters(), "lr": lr},
{"params": self.color_net.parameters(), "lr": lr},
]
elif self.model_type == "mlp":
params = [
{"params": self.sigma_net.parameters(), "lr": lr},
{"params": self.encoder_dir.parameters(), "lr": lr},
{"params": self.color_net.parameters(), "lr": lr},
{"params": self.nerf_mlp.parameters(), "lr": lr},
]
elif self.model_type == "vm":
params = [
{"params": self.color_net.parameters(), "lr": lr2},
{"params": self.sigma_mat, "lr": lr},
{"params": self.sigma_vec, "lr": lr},
{"params": self.color_mat, "lr": lr},
{"params": self.color_vec, "lr": lr},
{"params": self.basis_mat.parameters(), "lr": lr2},
]
elif self.model_type == "tensors":
params = [
{"params": self.tensor_volume.parameters(), "lr": lr},
{"params": self.encoder_dir.parameters(), "lr": lr},
]
else:
raise ValueError(f"illegal model_type: {self.model_type}")
if self.bg_radius > 0:
params.append({"params": self.encoder_bg.parameters(), "lr": lr})
params.append({"params": self.bg_net.parameters(), "lr": lr})
return params
================================================
FILE: distill_mutual/provider.py
================================================
import os
import cv2
import glob
import json
import tqdm
import numpy as np
from scipy.spatial.transform import Slerp, Rotation
import trimesh
import torch
from torch.utils.data import DataLoader
from .utils import get_rays, srgb_to_linear
# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
def nerf_matrix_to_ngp(pose, scale=0.33):
# for the fox dataset, 0.33 scales camera radius to ~ 2
new_pose = np.array(
[
[pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale],
[pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale],
[pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale],
[0, 0, 0, 1],
],
dtype=np.float32,
)
return new_pose
def rand_poses(
size,
device,
radius=1,
theta_range=[np.pi / 3, 2 * np.pi / 3],
phi_range=[0, 2 * np.pi],
):
"""generate random poses from an orbit camera
Args:
size: batch size of generated poses.
device: where to allocate the output.
radius: camera radius
theta_range: [min, max], should be in [0, \pi]
phi_range: [min, max], should be in [0, 2\pi]
Return:
poses: [size, 4, 4]
"""
def normalize(vectors):
return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)
thetas = (
torch.rand(size, device=device) * (theta_range[1] - theta_range[0])
+ theta_range[0]
)
phis = (
torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
)
centers = torch.stack(
[
radius * torch.sin(thetas) * torch.sin(phis),
radius * torch.cos(thetas),
radius * torch.sin(thetas) * torch.cos(phis),
],
dim=-1,
) # [B, 3]
# lookat
forward_vector = -normalize(centers)
up_vector = (
torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
) # confused at the coordinate system...
right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1))
up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1))
poses = (
torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
)
poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
poses[:, :3, 3] = centers
return poses
def normalize(vectors):
return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)
interval_nums = torch.tensor(
[i * 1 / (size - 1) for i in range(size)], dtype=torch.float32, device=device
)
thetas = interval_nums * (theta_range[1] - theta_range[0]) + theta_range[0]
phis = interval_nums * (phi_range[1] - phi_range[0]) + phi_range[0]
centers = torch.stack(
[
radius * torch.sin(thetas) * torch.sin(phis),
radius * torch.cos(thetas),
radius * torch.sin(thetas) * torch.cos(phis),
],
dim=-1,
) # [B, 3]
# lookat
forward_vector = -normalize(centers)
up_vector = (
torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
) # confused at the coordinate system...
right_vector = normalize(
torch.cross(forward_vector, up_vector, dim=-1)
) # cross product
up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1))
poses = (
torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
)
poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
poses[:, :3, 3] = centers
return poses
class NeRFDataset:
def __init__(self, opt, device, type="train", downscale=1, n_test=10):
super().__init__()
self.opt = opt
self.args = opt
self.device = device
self.type = type # train, val, test
self.downscale = downscale
self.root_path = opt.path
self.mode = opt.mode # only support blender
self.preload = opt.preload # preload data into GPU
self.scale = (
opt.scale
) # camera radius scale to make sure camera are inside the bounding box.
self.bound = (
opt.bound
) # bounding box half length, also used as the radius to random sample poses.
self.fp16 = opt.fp16 # if preload, load into fp16.
self.training = self.type in ["train", "all", "trainval"]
self.num_rays = self.opt.num_rays if self.training else -1
if self.mode == "blender":
if type == "all":
transform_paths = glob.glob(os.path.join(self.root_path, "*.json"))
transform = None
for transform_path in transform_paths:
with open(transform_path, "r") as f:
tmp_transform = json.load(f)
if transform is None:
transform = tmp_transform
else:
transform["frames"].extend(tmp_transform["frames"])
# load train and val split
elif type == "trainval":
with open(
os.path.join(self.root_path, f"transforms_train.json"), "r"
) as f:
transform = json.load(f)
with open(
os.path.join(self.root_path, f"transforms_val.json"), "r"
) as f:
transform_val = json.load(f)
transform["frames"].extend(transform_val["frames"])
# only load one specified split
else:
with open(
os.path.join(self.root_path, f"transforms_{type}.json"), "r"
) as f:
transform = json.load(f)
else:
raise NotImplementedError(f"unknown dataset mode: {self.mode}")
# load image size
if "h" in transform and "w" in transform:
self.H = int(transform["h"]) // downscale
self.W = int(transform["w"]) // downscale
else:
# we have to actually read an image to get H and W later.
self.H = self.W = None
# read images
frames = transform["frames"]
if True:
self.poses = []
self.images = []
for f in tqdm.tqdm(frames, desc=f"Loading {type} data:"):
f_path = os.path.join(self.root_path, f["file_path"])
if (
self.mode == "blender"
and f_path[-4:].lower() != ".png"
and f_path[-4:].lower() != ".jpg"
):
f_path += ".png" # so silly...
if not os.path.exists(f_path):
continue
pose = np.array(f["transform_matrix"], dtype=np.float32) # [4, 4]
pose = nerf_matrix_to_ngp(pose, scale=self.scale)
image = cv2.imread(
f_path, cv2.IMREAD_UNCHANGED
) # [H, W, 3] or [H, W, 4]
if self.H is None or self.W is None:
self.H = image.shape[0] // downscale
self.W = image.shape[1] // downscale
# add support for the alpha channel as a mask.
if image.shape[-1] == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
else:
image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
if image.shape[0] != self.H or image.shape[1] != self.W:
image = cv2.resize(
image, (self.W, self.H), interpolation=cv2.INTER_AREA
)
image = image.astype(np.float32) / 255 # [H, W, 3/4]
self.poses.append(pose)
self.images.append(image)
self.poses = torch.from_numpy(np.stack(self.poses, axis=0)) # [N, 4, 4]
if self.images is not None:
self.images = torch.from_numpy(
np.stack(self.images, axis=0)
) # [N, H, W, C]
self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item()
if self.training and self.opt.error_map:
self.error_map = torch.ones(
[self.images.shape[0], 128 * 128], dtype=torch.float
) # [B, 128 * 128], flattened for easy indexing, fixed resolution...
else:
self.error_map = None
if self.preload:
self.poses = self.poses.to(self.device)
if self.images is not None:
if self.fp16 and self.opt.color_space != "linear":
dtype = torch.half
else:
dtype = torch.float
self.images = self.images.to(dtype).to(self.device)
if self.error_map is not None:
self.error_map = self.error_map.to(self.device)
# load intrinsics
if "fl_x" in transform or "fl_y" in transform:
fl_x = (
transform["fl_x"] if "fl_x" in transform else transform["fl_y"]
) / downscale
fl_y = (
transform["fl_y"] if "fl_y" in transform else transform["fl_x"]
) / downscale
elif "camera_angle_x" in transform or "camera_angle_y" in transform:
# Blender stores camera_angle in radians; no extra downscale needed since we derive focal length from the already-downscaled H/W
fl_x = (
self.W / (2 * np.tan(transform["camera_angle_x"] / 2))
if "camera_angle_x" in transform
else None
)
fl_y = (
self.H / (2 * np.tan(transform["camera_angle_y"] / 2))
if "camera_angle_y" in transform
else None
)
if fl_x is None:
fl_x = fl_y
if fl_y is None:
fl_y = fl_x
else:
raise RuntimeError(
"Failed to load focal length, please check the transforms.json!"
)
cx = (transform["cx"] / downscale) if "cx" in transform else (self.H / 2)
cy = (transform["cy"] / downscale) if "cy" in transform else (self.W / 2)
self.intrinsics = np.array([fl_x, fl_y, cx, cy])
def collate(self, index):
B = len(index) # a list of length 1
poses = self.poses[index].to(self.device) # [B, 4, 4]
error_map = None if self.error_map is None else self.error_map[index]
rays = get_rays(
poses, self.intrinsics, self.H, self.W, self.num_rays, error_map
)
results = {
"H": self.H,
"W": self.W,
"rays_o": rays["rays_o"],
"rays_d": rays["rays_d"],
}
if self.images is not None:
images = self.images[index].to(self.device) # [B, H, W, 3/4]
if self.training:
C = images.shape[-1]
images = torch.gather(
images.view(B, -1, C), 1, torch.stack(C * [rays["inds"]], -1)
) # [B, N, 3/4]
results["images"] = images
# need inds to update error_map
if error_map is not None:
results["index"] = index
results["inds_coarse"] = rays["inds_coarse"]
return results
def dataloader(self):
size = len(self.poses)
loader = DataLoader(
list(range(size)),
batch_size=1,
collate_fn=self.collate,
shuffle=self.training,
num_workers=0,
)
loader._data = self
return loader
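`dataloader` feeds plain integer indices through `DataLoader` and lets `collate` build the actual ray batch from them. A minimal, torch-free sketch of that pattern (the names here are illustrative, not part of the repo):

```python
def collate_indices(index):
    """Stand-in for collate(): turn a batch of dataset indices into a result dict."""
    return {"index": index, "B": len(index)}

# batch_size=1 without shuffling yields index batches [0], [1], [2], ... in order.
batches = [collate_indices([i]) for i in range(4)]
```

The real `collate` does the heavy lifting (pose lookup, ray generation); the loader itself only ever sees integers.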
================================================
FILE: distill_mutual/renderer.py
================================================
import math
import trimesh
import numpy as np
from time import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import raymarching
from .utils import custom_meshgrid
from IPython import embed
def sample_pdf(bins, weights, n_samples, det=False):
# This implementation is from NeRF
# bins: [B, T], old_z_vals
# weights: [B, T - 1], bin weights.
# return: [B, n_samples], new_z_vals
# Get pdf
weights = weights + 1e-5 # prevent nans
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
# Take uniform samples
if det:
u = torch.linspace(
0.0 + 0.5 / n_samples, 1.0 - 0.5 / n_samples, steps=n_samples
).to(weights.device)
u = u.expand(list(cdf.shape[:-1]) + [n_samples])
else:
u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device)
# Invert CDF
u = u.contiguous()
inds = torch.searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (B, n_samples, 2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = cdf_g[..., 1] - cdf_g[..., 0]
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
t = (u - cdf_g[..., 0]) / denom
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
return samples
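A 1D NumPy sketch of the inverse-CDF sampling that `sample_pdf` performs (deterministic variant; `sample_pdf_1d` is an illustrative name, not part of the repo):

```python
import numpy as np

def sample_pdf_1d(bins, weights, n_samples):
    """1D deterministic inverse-CDF sampling, mirroring sample_pdf above."""
    weights = weights + 1e-5                        # prevent division by zero
    pdf = weights / weights.sum()
    cdf = np.concatenate([[0.0], np.cumsum(pdf)])   # len(cdf) == len(bins)
    u = np.linspace(0.5 / n_samples, 1.0 - 0.5 / n_samples, n_samples)
    inds = np.searchsorted(cdf, u, side="right")
    below = np.clip(inds - 1, 0, len(cdf) - 1)
    above = np.clip(inds, 0, len(cdf) - 1)
    denom = cdf[above] - cdf[below]
    denom = np.where(denom < 1e-5, 1.0, denom)
    t = (u - cdf[below]) / denom
    return bins[below] + t * (bins[above] - bins[below])

# All the weight sits in the bin [1, 2], so every new sample lands there.
samples = sample_pdf_1d(np.array([0.0, 1.0, 2.0, 3.0, 4.0]),
                        np.array([0.0, 10.0, 0.0, 0.0]), 4)
```

This is the same "sample more where the coarse pass found mass" idea the batched torch version applies per ray.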
def plot_pointcloud(pc, color=None):
# pc: [N, 3]
# color: [N, 3/4]
print("[visualize points]", pc.shape, pc.dtype, pc.min(0), pc.max(0))
pc = trimesh.PointCloud(pc, color)
# axis
axes = trimesh.creation.axis(axis_length=4)
# sphere
sphere = trimesh.creation.icosphere(radius=1)
trimesh.Scene([pc, axes, sphere]).show()
class NeRFRenderer(nn.Module):
def __init__(
self,
bound=1,
cuda_ray=False,
density_scale=1, # scale up deltas (or sigmas) to sharpen the density grid; values larger than 1 usually improve performance.
min_near=0.2,
density_thresh=0.01,
bg_radius=-1,
grid_size=128,
):
super().__init__()
print("\n---------------", grid_size, "--------------\n")
self.bound = bound
self.cascade = 1 + math.ceil(math.log2(bound))
self.grid_size = grid_size
self.density_scale = density_scale
self.min_near = min_near
self.density_thresh = density_thresh
self.bg_radius = bg_radius # radius of the background sphere.
# prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax)
# NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing.
aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound])
aabb_infer = aabb_train.clone()
self.register_buffer("aabb_train", aabb_train)
self.register_buffer("aabb_infer", aabb_infer)
# extra state for cuda raymarching
self.cuda_ray = cuda_ray
if cuda_ray:
# density grid
density_grid = torch.zeros(
[self.cascade, self.grid_size ** 3]
) # [CAS, H * H * H]
density_bitfield = torch.zeros(
self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8
) # [CAS * H * H * H // 8]
self.register_buffer("density_grid", density_grid)
self.register_buffer("density_bitfield", density_bitfield)
self.mean_density = 0
self.iter_density = 0
# step counter
step_counter = torch.zeros(
16, 2, dtype=torch.int32
) # 16 is hardcoded for averaging...
self.register_buffer("step_counter", step_counter)
self.mean_count = 0
self.local_step = 0
def forward(self, x, d):
raise NotImplementedError()
# separated density and color query (can accelerate non-cuda-ray mode.)
def density(self, x):
raise NotImplementedError()
def color(self, x, d, mask=None, **kwargs):
raise NotImplementedError()
def reset_extra_state(self):
if not self.cuda_ray:
return
# density grid
self.density_grid.zero_()
self.mean_density = 0
self.iter_density = 0
# step counter
self.step_counter.zero_()
self.mean_count = 0
self.local_step = 0
def run(
self,
rays_o,
rays_d,
num_steps=128,
upsample_steps=128,
bg_color=None,
perturb=False,
**kwargs
):
# rays_o, rays_d: [B, N, 3], assumes B == 1
# bg_color: [3] in range [0, 1]
# return: image: [B, N, 3], depth: [B, N]
prefix = rays_o.shape[:-1]
rays_o = rays_o.contiguous().view(-1, 3)
rays_d = rays_d.contiguous().view(-1, 3)
N = rays_o.shape[0] # N = B * N, in fact
device = rays_o.device
# choose aabb
aabb = self.aabb_train if self.training else self.aabb_infer
# sample steps
nears, fars = raymarching.near_far_from_aabb(
rays_o, rays_d, aabb, self.min_near
)
nears.unsqueeze_(-1)
fars.unsqueeze_(-1)
# print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}')
z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(
0
) # [1, T]
z_vals = z_vals.expand((N, num_steps)) # [N, T]
z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars]
# perturb z_vals
sample_dist = (fars - nears) / num_steps
if perturb:
z_vals = (
z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist
)
# z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs.
# generate xyzs
xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(
-1
) # [N, 1, 3] * [N, T, 1] -> [N, T, 3]
xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip.
# plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
# query SDF and RGB
density_outputs = self.density(xyzs.reshape(-1, 3))
# sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T]
for k, v in density_outputs.items():
density_outputs[k] = v.view(N, num_steps, -1)
# upsample z_vals (nerf-like)
if upsample_steps > 0:
with torch.no_grad():
deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1]
deltas = torch.cat(
[deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1
)
alphas = 1 - torch.exp(
-deltas * self.density_scale * density_outputs["sigma"].squeeze(-1)
) # [N, T]
alphas_shifted = torch.cat(
[torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1
) # [N, T+1]
weights = (
alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1]
) # [N, T]
# sample new z_vals
z_vals_mid = z_vals[..., :-1] + 0.5 * deltas[..., :-1] # [N, T-1]
new_z_vals = sample_pdf(
z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training
).detach() # [N, t]
new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(
-2
) * new_z_vals.unsqueeze(
-1
) # [N, 1, 3] * [N, t, 1] -> [N, t, 3]
new_xyzs = torch.min(
torch.max(new_xyzs, aabb[:3]), aabb[3:]
) # a manual clip.
# only forward new points to save computation
new_density_outputs = self.density(new_xyzs.reshape(-1, 3))
# new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t]
for k, v in new_density_outputs.items():
new_density_outputs[k] = v.view(N, upsample_steps, -1)
# re-order
z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t]
z_vals, z_index = torch.sort(z_vals, dim=1)
xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3]
xyzs = torch.gather(
xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs)
)
for k in density_outputs:
tmp_output = torch.cat(
[density_outputs[k], new_density_outputs[k]], dim=1
)
density_outputs[k] = torch.gather(
tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output)
)
deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1]
deltas = torch.cat(
[deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1
)
alphas = 1 - torch.exp(
-deltas * self.density_scale * density_outputs["sigma"].squeeze(-1)
) # [N, T+t]
alphas_shifted = torch.cat(
[torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1
) # [N, T+t+1]
weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t]
dirs = rays_d.view(-1, 1, 3).expand_as(xyzs)
for k, v in density_outputs.items():
density_outputs[k] = v.view(-1, v.shape[-1])
mask = weights > 1e-4 # hard coded
rgbs = self.color(
xyzs.reshape(-1, 3),
dirs.reshape(-1, 3),
mask=mask.reshape(-1),
**density_outputs
)
rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3]
# print(xyzs.shape, 'valid_rgb:', mask.sum().item())
# calculate weight_sum (mask)
weights_sum = weights.sum(dim=-1) # [N]
# calculate depth
ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1)
depth = torch.sum(weights * ori_z_vals, dim=-1)
# calculate color
image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1]
# mix background color
if self.bg_radius > 0:
# use the bg model to calculate bg_color
polar = raymarching.polar_from_ray(
rays_o, rays_d, self.bg_radius
) # [N, 2] in [-1, 1]
bg_color = self.background(polar, rays_d.reshape(-1, 3)) # [N, 3]
elif bg_color is None:
bg_color = 1
image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
image = image.view(*prefix, 3)
depth = depth.view(*prefix)
# tmp: reg loss in mip-nerf 360
# z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1)
# mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T]
# loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals) * (weights ** 2)).sum()
return {
"depth": depth,
"image": image,
}
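The compositing math inside `run` can be checked in a few lines of NumPy: alpha from sigma and delta, transmittance as the shifted cumulative product, weights as their product (a sketch of the formula, not the repo's API):

```python
import numpy as np

# alpha_i = 1 - exp(-sigma_i * delta_i);  T_i = prod_{j<i} (1 - alpha_j);
# w_i = alpha_i * T_i  -- matching the alphas / alphas_shifted / weights lines above.
sigmas = np.array([0.0, 5.0, 0.0])
deltas = np.array([0.1, 0.1, 0.1])
alphas = 1.0 - np.exp(-sigmas * deltas)
transmittance = np.cumprod(np.concatenate([[1.0], 1.0 - alphas]))[:-1]
weights = alphas * transmittance
weights_sum = weights.sum()  # opacity; (1 - weights_sum) weights the background color
```

Only the middle sample has density, so it receives all the weight, and `weights_sum` never exceeds 1, which is what makes the background mixing step valid.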
def run_cuda(
self,
rays_o,
rays_d,
dt_gamma=0,
bg_color=None,
perturb=False,
force_all_rays=False,
max_steps=1024,
inherited_params=[],
**kwargs
):
# rays_o, rays_d: [B, N, 3], assumes B == 1
# return: image: [B, N, 3], depth: [B, N]
prefix = rays_o.shape[:-1]
rays_o = rays_o.contiguous().view(-1, 3)
rays_d = rays_d.contiguous().view(-1, 3)
N = rays_o.shape[0] # N = B * N, in fact
device = rays_o.device
# pre-calculate near far
nears, fars = raymarching.near_far_from_aabb(
rays_o,
rays_d,
self.aabb_train if self.training else self.aabb_infer,
self.min_near,
)
# mix background color
if self.bg_radius > 0:
# use the bg model to calculate bg_color
polar = raymarching.polar_from_ray(
rays_o, rays_d, self.bg_radius
) # [N, 2] in [-1, 1]
bg_color = self.background(polar, rays_d) # [N, 3]
elif bg_color is None:
bg_color = 1
if self.training: # different with testing
# setup counter
time1 = time()
counter = self.step_counter[self.local_step % 16]
counter.zero_() # set to 0
self.local_step += 1
if (
self.args.render_stu_first
): # if the student renders first, it computes the sampled xyzs and the teacher inherits them
"""
About xyzs, dirs, deltas, rays:
xyzs, dirs are the spatial points (and view directions) sampled along rays_o and rays_d;
rays: xyzs[rays[i, 1]:rays[i, 1] + rays[i, 2]] --> the points belonging to ray rays[i, 0];
deltas: shape [num_points, 2], the step sizes of all sampled points (first column for RGB, second for depth).
"""
if not self.is_teacher:
xyzs, dirs, deltas, rays = raymarching.march_rays_train(
rays_o,
rays_d,
self.bound,
self.density_bitfield,
self.cascade,
self.grid_size,
nears,
fars,
counter,
self.mean_count,
perturb,
128,
force_all_rays,
dt_gamma,
max_steps,
)
inherited_params = [xyzs, dirs, deltas, rays]
else:
xyzs, dirs, deltas, rays = inherited_params
else:
if self.is_teacher:
xyzs, dirs, deltas, rays = raymarching.march_rays_train(
rays_o,
rays_d,
self.bound,
self.density_bitfield,
self.cascade,
self.grid_size,
nears,
fars,
counter,
self.mean_count,
perturb,
128,
force_all_rays,
dt_gamma,
max_steps,
)
inherited_params = [xyzs, dirs, deltas, rays]
else:
xyzs, dirs, deltas, rays = inherited_params
# plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy())
sigmas, rgbs = self(xyzs, dirs)
if self.args.global_step < self.args.stage_iters["stage1"]:
return {
"stage1": self.args.global_step,
"depth": None,
"image": None,
"inherited_params": inherited_params,
"sigmas": sigmas,
"rays": rays,
}
elif self.args.global_step < self.args.stage_iters["stage2"]:
return {
"stage2": self.args.global_step,
"depth": None,
"image": None,
"inherited_params": inherited_params,
"sigmas": sigmas,
"rays": rays,
}
sigmas = self.density_scale * sigmas
weights_sum, depth, image = raymarching.composite_rays_train(
sigmas, rgbs, deltas, rays
)
image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
depth = torch.clamp(depth - nears, min=0) / (fars - nears + 1e-6)
image = image.view(*prefix, 3)
depth = depth.view(*prefix)
else:
# allocate outputs
# if use autocast, must init as half so it won't be autocasted and lose reference.
# dtype = torch.half if torch.is_autocast_enabled() else torch.float32
# output should always be float32! only network inference uses half.
dtype = torch.float32
weights_sum = torch.zeros(N, dtype=dtype, device=device)
depth = torch.zeros(N, dtype=dtype, device=device)
image = torch.zeros(N, 3, dtype=dtype, device=device)
n_alive = N
alive_counter = torch.zeros([1], dtype=torch.int32, device=device)
rays_alive = torch.zeros(
2, n_alive, dtype=torch.int32, device=device
) # 2 is used to loop old/new
rays_t = torch.zeros(2, n_alive, dtype=dtype, device=device)
step = 0
i = 0
while step < max_steps:
# count alive rays
if step == 0:
# init rays at first step.
torch.arange(n_alive, out=rays_alive[0])
rays_t[0] = nears
else:
alive_counter.zero_()
raymarching.compact_rays(
n_alive,
rays_alive[i % 2],
rays_alive[(i + 1) % 2],
rays_t[i % 2],
rays_t[(i + 1) % 2],
alive_counter,
)
n_alive = alive_counter.item() # must invoke D2H copy here
# exit loop
if n_alive <= 0:
break
# decide compact_steps
n_step = max(min(N // n_alive, 8), 1)
xyzs, dirs, deltas = raymarching.march_rays(
n_alive,
n_step,
rays_alive[i % 2],
rays_t[i % 2],
rays_o,
rays_d,
self.bound,
self.density_bitfield,
self.cascade,
self.grid_size,
nears,
fars,
128,
perturb,
dt_gamma,
max_steps,
)
sigmas, rgbs = self(xyzs, dirs)
# density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.
# sigmas = density_outputs['sigma']
# rgbs = self.color(xyzs, dirs, **density_outputs)
sigmas = self.density_scale * sigmas
raymarching.composite_rays(
n_alive,
n_step,
rays_alive[i % 2],
rays_t[i % 2],
sigmas,
rgbs,
deltas,
weights_sum,
depth,
image,
)
# print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')
step += n_step
i += 1
image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
depth = torch.clamp(depth - nears, min=0) / (fars - nears)
image = image.view(*prefix, 3)
depth = depth.view(*prefix)
# print('\n--- render time:--- {:6f} {:.6f}'.format(time2-time1, time()-time2))
if self.training:
return {
"depth": depth,
"image": image,
"inherited_params": inherited_params,
"sigmas": sigmas,
"rays": rays,
}
else:
return {
"depth": depth,
"image": image,
"inherited_params": inherited_params,
}
@torch.no_grad()
def mark_untrained_grid(self, poses, intrinsic, S=64):
# poses: [B, 4, 4]
# intrinsic: [3, 3]
if not self.cuda_ray:
return
if isinstance(poses, np.ndarray):
poses = torch.from_numpy(poses)
B = poses.shape[0]
fx, fy, cx, cy = intrinsic
X = torch.arange(
self.grid_size, dtype=torch.int32, device=self.density_grid.device
).split(S)
Y = torch.arange(
self.grid_size, dtype=torch.int32, device=self.density_grid.device
).split(S)
Z = torch.arange(
self.grid_size, dtype=torch.int32, device=self.density_grid.device
).split(S)
count = torch.zeros_like(self.density_grid)
poses = poses.to(count.device)
# 5-level loop, forgive me...
for xs in X:
for ys in Y:
for zs in Z:
# construct points
xx, yy, zz = custom_meshgrid(xs, ys, zs)
coords = torch.cat(
[xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
dim=-1,
) # [N, 3], in [0, 128)
indices = raymarching.morton3D(coords).long() # [N]
world_xyzs = (
2 * coords.float() / (self.grid_size - 1) - 1
).unsqueeze(
0
) # [1, N, 3] in [-1, 1]
# cascading
for cas in range(self.cascade):
bound = min(2 ** cas, self.bound)
half_grid_size = bound / self.grid_size
# scale to current cascade's resolution
cas_world_xyzs = world_xyzs * (bound - half_grid_size)
# split batch to avoid OOM
head = 0
while head < B:
tail = min(head + S, B)
# world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.)
cam_xyzs = cas_world_xyzs - poses[
head:tail, :3, 3
].unsqueeze(1)
cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3]
# query if point is covered by any camera
mask_z = cam_xyzs[:, :, 2] > 0 # [S, N]
mask_x = (
torch.abs(cam_xyzs[:, :, 0])
< cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2
)
mask_y = (
torch.abs(cam_xyzs[:, :, 1])
< cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2
)
mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1) # [N]
# update count
count[cas, indices] += mask
head += S
# mark untrained grid as -1
self.density_grid[count == 0] = -1
# print(f'[mark untrained grid] {(count == 0).sum()} from {self.grid_size ** 3 * self.cascade}')
@torch.no_grad()
def update_extra_state(self, decay=0.95, S=128):
# call before each epoch to update extra states.
if not self.cuda_ray:
return
# update density grid
tmp_grid = -torch.ones_like(self.density_grid)
# full update.
if self.iter_density < 16:
# if True:
X = torch.arange(
self.grid_size, dtype=torch.int32, device=self.density_grid.device
).split(S)
Y = torch.arange(
self.grid_size, dtype=torch.int32, device=self.density_grid.device
).split(S)
Z = torch.arange(
self.grid_size, dtype=torch.int32, device=self.density_grid.device
).split(S)
for xs in X:
for ys in Y:
for zs in Z:
# construct points
xx, yy, zz = custom_meshgrid(xs, ys, zs)
coords = torch.cat(
[xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
dim=-1,
) # [N, 3], in [0, 128)
indices = raymarching.morton3D(coords).long() # [N]
xyzs = (
2 * coords.float() / (self.grid_size - 1) - 1
) # [N, 3] in [-1, 1]
# cascading
for cas in range(self.cascade):
bound = min(2 ** cas, self.bound)
half_grid_size = bound / self.grid_size
# scale to current cascade's resolution
cas_xyzs = xyzs * (bound - half_grid_size)
# add noise in [-hgs, hgs]
cas_xyzs += (
torch.rand_like(cas_xyzs) * 2 - 1
) * half_grid_size
# query density
sigmas = (
self.density(cas_xyzs)["sigma"].reshape(-1).detach()
)
sigmas *= self.density_scale
# assign
tmp_grid[cas, indices] = sigmas
# partial update (half the computation)
# TODO: why is max-pooling not needed here?
else:
N = self.grid_size ** 3 // 4 # H * H * H / 4
for cas in range(self.cascade):
# random sample some positions
coords = torch.randint(
0, self.grid_size, (N, 3), device=self.density_grid.device
) # [N, 3], in [0, 128)
indices = raymarching.morton3D(coords).long() # [N]
# random sample occupied positions
occ_indices = torch.nonzero(self.density_grid[cas] > 0).squeeze(
-1
) # [Nz]
rand_mask = torch.randint(
0,
occ_indices.shape[0],
[N],
dtype=torch.long,
device=self.density_grid.device,
)
occ_indices = occ_indices[
rand_mask
] # [Nz] --> [N], allow for duplication
occ_coords = raymarching.morton3D_invert(occ_indices) # [N, 3]
# concat
indices = torch.cat([indices, occ_indices], dim=0)
coords = torch.cat([coords, occ_coords], dim=0)
# same below
xyzs = (
2 * coords.float() / (self.grid_size - 1) - 1
) # [N, 3] in [-1, 1]
bound = min(2 ** cas, self.bound)
half_grid_size = bound / self.grid_size
# scale to current cascade's resolution
cas_xyzs = xyzs * (bound - half_grid_size)
# add noise in [-hgs, hgs]
cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size
# query density
sigmas = self.density(cas_xyzs)["sigma"].reshape(-1).detach()
sigmas *= self.density_scale
# assign
tmp_grid[cas, indices] = sigmas
## max-pool on tmp_grid for less aggressive culling [No significant improvement...]
# invalid_mask = tmp_grid < 0
# tmp_grid = F.max_pool3d(tmp_grid.view(self.cascade, 1, self.grid_size, self.grid_size, self.grid_size), kernel_size=3, stride=1, padding=1).view(self.cascade, -1)
# tmp_grid[invalid_mask] = -1
# ema update
valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0)
self.density_grid[valid_mask] = torch.maximum(
self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]
)
self.mean_density = torch.mean(
self.density_grid.clamp(min=0)
).item() # -1 non-training regions are viewed as 0 density.
self.iter_density += 1
# convert to bitfield
density_thresh = min(self.mean_density, self.density_thresh)
self.density_bitfield = raymarching.packbits(
self.density_grid, density_thresh, self.density_bitfield
)
### update step counter
total_step = min(16, self.local_step)
if total_step > 0:
self.mean_count = int(
self.step_counter[:total_step, 0].sum().item() / total_step
)
self.local_step = 0
# print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')
def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
# rays_o, rays_d: [B, N, 3], assumes B == 1
# return: pred_rgb: [B, N, 3]
if self.cuda_ray:
_run = self.run_cuda
else:
_run = self.run
B, N = rays_o.shape[:2]
device = rays_o.device
# never stage when cuda_ray
if staged and not self.cuda_ray:
depth = torch.empty((B, N), device=device)
image = torch.empty((B, N, 3), device=device)
for b in range(B):
head = 0
while head < N:
tail = min(head + max_ray_batch, N)
results_ = _run(
rays_o[b : b + 1, head:tail],
rays_d[b : b + 1, head:tail],
**kwargs
)
depth[b : b + 1, head:tail] = results_["depth"]
image[b : b + 1, head:tail] = results_["image"]
head += max_ray_batch
results = {}
results["depth"] = depth
results["image"] = image
else:
results = _run(rays_o, rays_d, **kwargs)
return results
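The staged branch of `render` walks rays in fixed-size slices and writes each result back into a pre-allocated buffer. The slicing logic alone looks like this (illustrative helper, not part of the repo):

```python
def chunk_ranges(n_rays, max_ray_batch):
    """Yield (head, tail) slices covering [0, n_rays), as in the staged loop."""
    head = 0
    while head < n_rays:
        tail = min(head + max_ray_batch, n_rays)
        yield head, tail
        head += max_ray_batch

slices = list(chunk_ranges(10, 4))
```

The last slice may be shorter than `max_ray_batch`; the `min` clamps `tail` so the buffer indexing never runs past `n_rays`.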
================================================
FILE: distill_mutual/utils.py
================================================
import os
import copy
import lpips
import glob
import tqdm
import math
import random
import warnings
import tensorboardX
import numpy as np
import pandas as pd
import imageio
import time
from datetime import datetime
import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
import trimesh
import mcubes
from rich.console import Console
from torch_ema import ExponentialMovingAverage
from IPython import embed
import sys
from packaging import version as pver
device = torch.device("cuda")
TINY_NUMBER = 1e-6 # float32 only has 7 decimal digits precision
def update_loss_rate(cur_lrate, scale=0.99):
return cur_lrate * scale
def get_softmax_map_mean(a, b):
return (F.softmax(a, dim=-1) - F.softmax(b, dim=-1)).abs().mean()
def get_kl(inputs, targets):
# F.kl_div expects log-probabilities as input and probabilities as target,
# i.e. this computes KL(targets || inputs).
return F.kl_div(
F.log_softmax(inputs, dim=-1), F.softmax(targets, dim=-1), reduction="sum"
)
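`F.kl_div` takes log-probabilities as its first argument and probabilities as its second, so `get_kl` computes KL(targets ‖ inputs). A NumPy sketch of that convention (`kl_div_np` is an illustrative name):

```python
import numpy as np

def kl_div_np(log_q, p):
    """Mirrors F.kl_div(log_q, p, reduction='sum'): sum of p * (log p - log q)."""
    return float(np.sum(p * (np.log(p) - log_q)))

p = np.array([0.7, 0.2, 0.1])
kl_same = kl_div_np(np.log(p), p)                       # identical distributions -> 0
kl_diff = kl_div_np(np.log([1 / 3, 1 / 3, 1 / 3]), p)   # mismatch -> positive
```

Passing raw probabilities (instead of log-probabilities) as the first argument is a common mistake with this API; the distillation loss here gets the order right.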
def nerf_matrix_to_ngp(pose, scale=0.8):
# for the fox dataset, 0.33 scales camera radius to ~ 2
new_pose = np.array(
[
[pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale],
[pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale],
[pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale],
[0, 0, 0, 1],
],
dtype=np.float32,
)
return new_pose
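`nerf_matrix_to_ngp` cyclically permutes the axes (y, z, x), flips two of them, and scales the translation. A quick NumPy check that this keeps the rotation proper and scales the camera center (a sketch reproducing the same row construction):

```python
import numpy as np

pose = np.eye(4, dtype=np.float32)
pose[:3, 3] = [1.0, 2.0, 3.0]
scale = 0.8
new_pose = np.array(
    [
        [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale],
        [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale],
        [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale],
        [0, 0, 0, 1],
    ],
    dtype=np.float32,
)
det = np.linalg.det(new_pose[:3, :3])  # determinant +1: still a proper rotation
center = new_pose[:3, 3]               # translation cyclically permuted and scaled
```

Because the determinant stays +1 the conversion never mirrors the scene; it only re-labels axes for instant-ngp's convention.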
def pose_spherical(theta, phi, radius):
# for synthetic scenes: generates random poses on a sphere
trans_t = lambda t: np.array(
[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1]]
).astype(np.float32)
rot_phi = lambda phi: np.array(
[
[1, 0, 0, 0],
[0, np.cos(phi), -np.sin(phi), 0],
[0, np.sin(phi), np.cos(phi), 0],
[0, 0, 0, 1],
]
).astype(np.float32)
rot_theta = lambda th: np.array(
[
[np.cos(th), 0, -np.sin(th), 0],
[0, 1, 0, 0],
[np.sin(th), 0, np.cos(th), 0],
[0, 0, 0, 1],
]
).astype(np.float32)
c2w = trans_t(radius)
c2w = rot_phi(phi / 180.0 * np.pi) @ c2w
c2w = rot_theta(theta / 180.0 * np.pi) @ c2w
c2w = (
np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]).astype(
np.float32
)
@ c2w
)
return c2w
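Every transform composed in `pose_spherical` is rigid (two rotations, an axis relabeling, and one translation by `radius`), so the resulting camera center always sits at distance `radius` from the origin. A NumPy check of that invariant, rebuilding the same matrices (`_spherical_center` is an illustrative helper name):

```python
import numpy as np

def _spherical_center(theta, phi, radius):
    """Camera center produced by the same composition as pose_spherical above."""
    c2w = np.eye(4)
    c2w[2, 3] = radius                              # trans_t(radius)
    ph, th = np.deg2rad(phi), np.deg2rad(theta)
    rot_phi = np.array([[1, 0, 0, 0],
                        [0, np.cos(ph), -np.sin(ph), 0],
                        [0, np.sin(ph), np.cos(ph), 0],
                        [0, 0, 0, 1]])
    rot_theta = np.array([[np.cos(th), 0, -np.sin(th), 0],
                          [0, 1, 0, 0],
                          [np.sin(th), 0, np.cos(th), 0],
                          [0, 0, 0, 1]])
    flip = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
    return (flip @ rot_theta @ rot_phi @ c2w)[:3, 3]

center = _spherical_center(theta=30.0, phi=-20.0, radius=4.0)
```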
def get_rand_poses(data_type="synthetic", original_loader=None):
"""
Random sampling. Random origins and directions.
"""
from scipy.spatial.transform import Slerp, Rotation
assert data_type in {"synthetic", "llff", "tank"}
def get_single_syn_pose(ph, rand_radius=False):
theta1 = -180
theta2 = 180
phi1 = -ph
phi2 = 5 - ph if (5 - ph) <= 0 else 0
theta = theta1 + np.random.rand() * (theta2 - theta1)
phi = phi1 + np.random.rand() * (phi2 - phi1)
if rand_radius:
radius = np.random.uniform(3, 4)
else:
radius = 4
return pose_spherical(theta, phi, radius)
def get_syn_poses():
random_poses = np.array([get_single_syn_pose(8) for _ in range(1)])
for a in range(0, 80):
rp = np.array(
[get_single_syn_pose(a) for _ in range(int(((90 - a) // 15) ** 1 + 1))]
)
random_poses = np.concatenate([random_poses, rp], axis=0)
for i in range(len(random_poses)):
random_poses[i] = nerf_matrix_to_ngp(random_poses[i])
print(f"\nlen(train data): {len(random_poses)}\n")
random_poses = torch.from_numpy(random_poses).cuda()
return random_poses
def get_tank_poses():
random_poses = np.array([get_single_syn_pose(8) for _ in range(1)])
for a in range(5, 20):
rp = np.array(
[
get_single_syn_pose(a, True)
for _ in range(int(((90 - a) // 15) ** 1 + 1))
]
)
random_poses = np.concatenate([random_poses, rp], axis=0)
for i in range(len(random_poses)):
random_poses[i] = nerf_matrix_to_ngp(random_poses[i])
print(f"\nlen(train data): {len(random_poses)}\n")
random_poses = torch.from_numpy(random_poses).cuda()
return random_poses
def rand_poses_from_cam_centers(centers):
def normalize(vectors):
return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)
size = len(centers)
forward_vector = -normalize(centers)
up_vector = (
torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
) # confused at the coordinate system...
right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1))
up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1))
poses = (
torch.eye(4, dtype=torch.float, device=device)
.unsqueeze(0)
.repeat(size, 1, 1)
)
poses[:, :3, :3] = torch.stack(
(right_vector, up_vector, forward_vector), dim=-1
)
poses[:, :3, 3] = centers
return poses
def get_llff_poses_rand():
def get_rand_cam_centers_from_bbox(poses, gen_num=30):
# use poses to estimate the bbox of the camera
translations = poses[:, :3, 3]
bbox_max = translations.max(axis=0) + 1e-6
bbox_min = translations.min(axis=0) - 1e-6
rand_xs = np.random.uniform(low=bbox_min[0], high=bbox_max[0], size=gen_num)
rand_ys = np.random.uniform(low=bbox_min[1], high=bbox_max[1], size=gen_num)
rand_zs = np.random.uniform(low=bbox_min[2], high=bbox_max[2], size=gen_num)
centers = np.stack([rand_xs, rand_ys, rand_zs], axis=1)
return centers.astype(np.float32)
centers = get_rand_cam_centers_from_bbox(original_loader)
random_poses = rand_poses_from_cam_centers(torch.from_numpy(centers).cuda())
random_poses[:, 0, 0] = -random_poses[:, 0, 0]
return random_poses
if data_type == "synthetic":
random_poses = get_syn_poses()
elif data_type == "llff":
random_poses = get_llff_poses_rand()
elif data_type == "tank":
random_poses = get_tank_poses()
else:
raise ValueError("illegal")
return random_poses
def custom_meshgrid(*args):
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
if pver.parse(torch.__version__) < pver.parse("1.10"):
return torch.meshgrid(*args)
else:
return torch.meshgrid(*args, indexing="ij")
@torch.jit.script
def linear_to_srgb(x):
return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)
@torch.jit.script
def srgb_to_linear(x):
return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)
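These two conversions are (approximate) inverses -- the exponent `0.41666` stands in for `1/2.4` -- which a NumPy round-trip confirms (a sketch of the same piecewise formulas):

```python
import numpy as np

def linear_to_srgb_np(x):
    # 0.41666 approximates the exact sRGB exponent 1/2.4
    return np.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055)

def srgb_to_linear_np(x):
    return np.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)

x = np.linspace(0.0, 1.0, 11)
roundtrip = srgb_to_linear_np(linear_to_srgb_np(x))
```

The thresholds 0.0031308 and 0.04045 match up (one is the image of the other), so the piecewise branches stay consistent across the round trip.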
def compute_ssim(
img0,
img1,
max_val,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03,
return_map=False,
):
"""Computes SSIM from two images.
This function was modeled after tf.image.ssim, and should produce comparable
output.
Args:
img0: torch.tensor. An image of size [..., width, height, num_channels].
img1: torch.tensor. An image of size [..., width, height, num_channels].
max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
filter_size: int >= 1. Window size.
filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
k1: float > 0. One of the SSIM dampening parameters.
k2: float > 0. One of the SSIM dampening parameters.
return_map: Bool. If True, return the per-pixel SSIM map instead of the mean.
Returns:
Each image's mean SSIM, or a tensor of individual values if `return_map`.
"""
device = img0.device
img0 = img0.type(torch.float32)
img1 = img1.type(torch.float32)
ori_shape = img0.size()
width, height, num_channels = ori_shape[-3:]
img0 = img0.view(-1, width, height, num_channels).permute(0, 3, 1, 2)
img1 = img1.view(-1, width, height, num_channels).permute(0, 3, 1, 2)
batch_size = img0.shape[0]
# Construct a 1D Gaussian blur filter.
hw = filter_size // 2
shift = (2 * hw - filter_size + 1) / 2
f_i = ((torch.arange(filter_size, device=device) - hw + shift) / filter_sigma) ** 2
filt = torch.exp(-0.5 * f_i)
filt /= torch.sum(filt)
# Blur in x and y (faster than the 2D convolution).
# z is a tensor of size [B, C, W, H] (after the view/permute above)
filt_fn1 = lambda z: F.conv2d(
z,
filt.view(1, 1, -1, 1).repeat(num_channels, 1, 1, 1),
padding=[hw, 0],
groups=num_channels,
)
filt_fn2 = lambda z: F.conv2d(
z,
filt.view(1, 1, 1, -1).repeat(num_channels, 1, 1, 1),
padding=[0, hw],
groups=num_channels,
)
# Vmap the blurs to the tensor size, and then compose them.
filt_fn = lambda z: filt_fn1(filt_fn2(z))
mu0 = filt_fn(img0)
mu1 = filt_fn(img1)
mu00 = mu0 * mu0
mu11 = mu1 * mu1
mu01 = mu0 * mu1
sigma00 = filt_fn(img0 ** 2) - mu00
sigma11 = filt_fn(img1 ** 2) - mu11
sigma01 = filt_fn(img0 * img1) - mu01
# Clip the variances and covariances to valid values.
# Variance must be non-negative:
sigma00 = torch.clamp(sigma00, min=0.0)
sigma11 = torch.clamp(sigma11, min=0.0)
sigma01 = torch.sign(sigma01) * torch.min(
torch.sqrt(sigma00 * sigma11), torch.abs(sigma01)
)
c1 = (k1 * max_val) ** 2
c2 = (k2 * max_val) ** 2
numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
ssim_map = numer / denom
ssim = torch.mean(ssim_map.reshape([-1, num_channels * width * height]), dim=-1)
return ssim_map if return_map else ssim
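When `img0 == img1`, the means and (co)variances coincide (`mu01 == mu0*mu0`, `sigma01 == sigma00`), so the numerator equals the denominator and SSIM is exactly 1. A scalar NumPy check of that identity (a sketch of the per-window formula only, not the windowed filtering above):

```python
import numpy as np

def ssim_scalar(mu0, mu1, s00, s11, s01, max_val=1.0, k1=0.01, k2=0.03):
    """The per-window SSIM formula used by compute_ssim above."""
    c1, c2 = (k1 * max_val) ** 2, (k2 * max_val) ** 2
    numer = (2 * mu0 * mu1 + c1) * (2 * s01 + c2)
    denom = (mu0 ** 2 + mu1 ** 2 + c1) * (s00 + s11 + c2)
    return numer / denom

identical = ssim_scalar(0.5, 0.5, 0.02, 0.02, 0.02)  # img0 == img1 -> exactly 1
different = ssim_scalar(0.5, 0.8, 0.02, 0.05, 0.01)  # mismatched stats -> below 1
```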
def init_lpips(net_name, device):
assert net_name in ["alex", "vgg"]
import lpips
print(f"init_lpips: lpips_{net_name}")
return lpips.LPIPS(net=net_name, version="0.1").eval().cuda()
lpips_fns = {
"alex": lpips.LPIPS(net="alex", version="0.1").eval().cuda(),
"vgg": lpips.LPIPS(net="vgg", version="0.1").eval().cuda(),
}
def rgb_lpips(gt, im, net_name):
assert net_name in ["alex", "vgg"]
gt = gt.type(torch.float32).permute([0, 3, 1, 2]).contiguous().cuda()
im = im.type(torch.float32).permute([0, 3, 1, 2]).contiguous().cuda()
return lpips_fns[net_name](gt, im, normalize=True).item()
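# LPIPS consumes NCHW tensors, hence the permute in rgb_lpips above. A tiny
# NumPy sketch of the same NHWC -> NCHW layout change (`_nhwc_to_nchw` is an
# illustrative helper only):
import numpy as np

def _nhwc_to_nchw(x):
    # [B, H, W, C] -> [B, C, H, W], matching .permute([0, 3, 1, 2])
    return np.transpose(x, (0, 3, 1, 2))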
@torch.cuda.amp.autocast(enabled=False)
def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
"""get rays
Args:
poses: [B, 4, 4], cam2world
intrinsics: [4]
H, W, N: int
error_map: [B, 128 * 128], sample probability based on training error
Returns:
rays_o, rays_d: [B, N, 3]
inds: [B, N]
"""
device = poses.device
B = poses.shape[0]
fx, fy, cx, cy = intrinsics
i, j = custom_meshgrid(
torch.linspace(0, W - 1, W, device=device),
torch.linspace(0, H - 1, H, device=device),
)
i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5
j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5
results = {}
if N > 0:
N = min(N, H * W)
if error_map is None:
inds = torch.randint(0, H * W, size=[N], device=device) # may duplicate
inds = inds.expand([B, N])
else:
# weighted sample on a low-reso grid
inds_coarse = torch.multinomial(
error_map.to(device), N, replacement=False
) # [B, N], but in [0, 128*128)
# map to the original resolution with random perturb.
inds_x, inds_y = (
inds_coarse // 128,
inds_coarse % 128,
) # `//` will throw a warning in torch 1.10... anyway.
sx, sy = H / 128, W / 128
inds_x = (
(inds_x * sx + torch.rand(B, N, device=device) * sx)
.long()
.clamp(max=H - 1)
)
inds_y = (
(inds_y * sy + torch.rand(B, N, device=device) * sy)
.long()
.clamp(max=W - 1)
)
inds = inds_x * W + inds_y
results["inds_coarse"] = inds_coarse # need this when updating error_map
i = torch.gather(i, -1, inds)
j = torch.gather(j, -1, inds)
results["inds"] = inds
else:
inds = torch.arange(H * W, device=device).expand([B, H * W])
zs = torch.ones_like(i)
xs = (i - cx) / fx * zs
ys = (j - cy) / fy * zs
directions = torch.stack((xs, ys, zs), dim=-1)
directions = directions / torch.norm(directions, dim=-1, keepdim=True)
rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3)
rays_o = poses[..., :3, 3] # [B, 3]
rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3]
results["rays_o"] = rays_o
results["rays_d"] = rays_d
return results
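# The pixel-to-ray mapping used by get_rays can be checked in isolation. A
# minimal NumPy sketch for a single pixel (`_pixel_to_ray_dir` is an
# illustrative helper; pixel centers are offset by 0.5 exactly as above):
import numpy as np

def _pixel_to_ray_dir(i, j, fx, fy, cx, cy):
    # camera-space ray through the pixel center (i + 0.5, j + 0.5)
    xs = (i + 0.5 - cx) / fx
    ys = (j + 0.5 - cy) / fy
    d = np.array([xs, ys, 1.0])
    return d / np.linalg.norm(d)  # unit direction, as in get_rays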
def seed_everything(seed):
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = True
def torch_vis_2d(x, renormalize=False):
# x: [3, H, W] or [1, H, W] or [H, W]
import matplotlib.pyplot as plt
import numpy as np
import torch
if isinstance(x, torch.Tensor):
if len(x.shape) == 3:
x = x.permute(1, 2, 0).squeeze()
x = x.detach().cpu().numpy()
print(f"[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}")
x = x.astype(np.float32)
# renormalize
if renormalize:
x = (x - x.min(axis=0, keepdims=True)) / (
x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8
)
plt.imshow(x)
plt.show()
def extract_fields(bound_min, bound_max, resolution, query_func, S=128):
X = torch.linspace(bound_min[0], bound_max[0], resolution).split(S)
Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(S)
Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(S)
u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
with torch.no_grad():
for xi, xs in enumerate(X):
for yi, ys in enumerate(Y):
for zi, zs in enumerate(Z):
xx, yy, zz = custom_meshgrid(xs, ys, zs)
pts = torch.cat(
[xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
dim=-1,
) # [S, 3]
val = (
query_func(pts)
.reshape(len(xs), len(ys), len(zs))
.detach()
.cpu()
.numpy()
) # [S, 1] --> [x, y, z]
u[
xi * S : xi * S + len(xs),
yi * S : yi * S + len(ys),
zi * S : zi * S + len(zs),
] = val
return u
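# extract_fields fills the density grid in S x S x S chunks to bound peak
# memory. A 1-D NumPy sketch of the same chunking pattern (`_chunked_eval` is
# an illustrative helper only):
import numpy as np

def _chunked_eval(f, xs, S):
    out = np.empty_like(xs)
    for i0 in range(0, len(xs), S):
        out[i0:i0 + S] = f(xs[i0:i0 + S])  # evaluate one chunk at a time
    return out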
def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
# print('threshold: {}'.format(threshold))
u = extract_fields(bound_min, bound_max, resolution, query_func)
# print(u.shape, u.max(), u.min(), np.percentile(u, 50))
vertices, triangles = mcubes.marching_cubes(u, threshold)
b_max_np = bound_max.detach().cpu().numpy()
b_min_np = bound_min.detach().cpu().numpy()
vertices = (
vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :]
+ b_min_np[None, :]
)
return vertices, triangles
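# Marching cubes returns vertices in grid-index coordinates; extract_geometry
# rescales them into the [bound_min, bound_max] box. A NumPy sketch of that
# mapping (`_grid_to_world` is an illustrative helper only):
import numpy as np

def _grid_to_world(verts, b_min, b_max, resolution):
    verts = np.asarray(verts, dtype=np.float64)
    b_min, b_max = np.asarray(b_min, float), np.asarray(b_max, float)
    # index 0 maps to b_min, index (resolution - 1) maps to b_max
    return verts / (resolution - 1.0) * (b_max - b_min) + b_min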
class PSNRMeter:
def __init__(self):
self.V = 0
self.N = 0
self.psnr_list = []
def clear(self):
self.V = 0
self.N = 0
self.psnr_list = []
def prepare_inputs(self, *inputs):
outputs = []
for i, inp in enumerate(inputs):
if torch.is_tensor(inp):
inp = inp.detach().cpu().numpy()
outputs.append(inp)
return outputs
def update(self, preds, truths):
preds, truths = self.prepare_inputs(
preds, truths
) # [B, N, 3] or [B, H, W, 3], range[0, 1]
psnr = -10 * np.log10(np.mean((preds - truths) ** 2))
self.psnr_list.append(psnr)
self.V += psnr
self.N += 1
assert self.N == len(self.psnr_list)
def measure(self):
return self.V / self.N
def write(self, writer, global_step, prefix=""):
writer.add_scalar(os.path.join(prefix, "PSNR"), self.measure(), global_step)
def report(self):
return f"PSNR = {self.measure():.6f}"
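# PSNRMeter.update computes PSNR from the MSE of [0, 1]-ranged images:
# PSNR = -10 * log10(MSE). A standalone sketch (`_psnr` is an illustrative
# helper only):
import numpy as np

def _psnr(pred, gt):
    mse = np.mean((np.asarray(pred) - np.asarray(gt)) ** 2)
    return -10.0 * np.log10(mse)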
class Trainer(object):
def __init__(
self,
name, # name of this experiment
opt, # extra conf
model_tea, # network
model_stu,
criterion=None, # loss function, if None, assume inline implementation in train_step
optimizer=None, # optimizer
ema_decay=None, # if use EMA, set the decay
lr_scheduler=None, # scheduler
metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric.
local_rank=0, # which GPU am I
world_size=1, # total num of GPUs
device=None, # device to use, usually setting to None is OK. (auto choose device)
mute=False, # whether to mute all print
fp16=False, # amp optimize level
eval_interval=10e10,  # eval once every `eval_interval` epochs (huge default effectively disables eval)
max_keep_ckpt=2, # max num of saved ckpts in disk
workspace="workspace", # workspace to save logs & ckpts
best_mode="min", # the smaller/larger result, the better
use_loss_as_metric=True, # use loss as the first metric
report_metric_at_train=False, # also report metrics at training
use_checkpoint="latest", # which ckpt to use at init time
use_tensorboardX=True, # whether to use tensorboard for logging
scheduler_update_every_step=False, # whether to call scheduler.step() after every train step
):
self.optimizer_fn = optimizer
self.lr_scheduler_fn = lr_scheduler
self.name = name
self.opt = opt
self.args = opt
self.mute = mute
self.metrics = metrics
self.local_rank = local_rank
self.world_size = world_size
self.workspace = workspace
self.ema_decay = ema_decay
self.fp16 = fp16
self.best_mode = best_mode
self.use_loss_as_metric = use_loss_as_metric
self.report_metric_at_train = report_metric_at_train
self.max_keep_ckpt = max_keep_ckpt
self.eval_interval = eval_interval
self.use_checkpoint = use_checkpoint
self.use_tensorboardX = use_tensorboardX
self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
self.scheduler_update_every_step = scheduler_update_every_step
self.device = (
device
if device is not None
else torch.device(
f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
)
)
self.console = Console()
self.model_tea = model_tea.to(self.device)
self.model_stu = model_stu.to(self.device)
if isinstance(criterion, nn.Module):
criterion.to(self.device)
self.criterion = criterion
if optimizer is None:
self.optimizer = optim.AdamW(
self.model_stu.parameters(), lr=0.001, weight_decay=5e-4
) # naive adam
else:
self.optimizer = optimizer(self.model_stu)
if lr_scheduler is None:
self.lr_scheduler = optim.lr_scheduler.LambdaLR(
self.optimizer, lr_lambda=lambda epoch: 1
) # fake scheduler
else:
self.ls = lr_scheduler
self.lr_scheduler = lr_scheduler(self.optimizer)
if ema_decay is not None and ema_decay > 0:
self.ema = ExponentialMovingAverage(
self.model_stu.parameters(), decay=ema_decay
)
else:
self.ema = None
self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
# variable init
self.epoch = 1
self.global_step = 0
self.local_step = 0
self.stats = {
"loss": [],
"valid_loss": [],
"results": [], # metrics[0], or valid_loss
"checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt
"best_result": None,
}
# auto fix
if len(metrics) == 0 or self.use_loss_as_metric:
self.best_mode = "min"
# workspace prepare
self.log_ptr = None
if self.workspace is not None:
os.makedirs(self.workspace, exist_ok=True)
self.log_path = os.path.join(workspace, f"log_{self.name}.txt")
self.log_ptr = open(self.log_path, "a+")
self.ckpt_path = os.path.join(self.workspace, "checkpoints")
self.best_path = f"{self.ckpt_path}/{self.name}.pth"
os.makedirs(self.ckpt_path, exist_ok=True)
self.log(self.opt)
self.log(
f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}'
)
self.log(
f"[INFO] #parameters: {sum([p.numel() for p in model_stu.parameters() if p.requires_grad])}"
)
if (
self.workspace is not None
): # only load state_dict for teacher and share backbone for student
self.log(f"[INFO] Loading teacher ckpt from {self.opt.ckpt_teacher} ...")
self.load_teacher_checkpoint()
self.log(self.model_tea)
self.load_student_checkpoint()
self.log(self.model_stu)
# self.model_tea.reset_extra_state()
# self.model_stu.reset_extra_state()
"""
if opt.rand_pose >= 0: # =0 means only using CLIP loss, >0 means a hybrid mode.
from nerf.clip_utils import CLIPLoss
self.clip_loss = CLIPLoss(self.device)
self.clip_loss.prepare_text([self.opt.clip_text]) # only support one text prompt now...
"""
def __del__(self):
if self.log_ptr:
self.log_ptr.close()
def log(self, *args, **kwargs):
if self.local_rank == 0:
if not self.mute:
# print(*args)
self.console.print(*args, **kwargs)
if self.log_ptr:
print(*args, file=self.log_ptr)
self.log_ptr.flush() # write immediately to file
def train(self, train_loader, valid_loader, max_epochs):
self.hard_rays_pool = [torch.tensor([]).cuda(), torch.tensor([]).cuda()]
self.is_hard_rays_pool_full = False
if self.use_tensorboardX and self.local_rank == 0:
self.writer = tensorboardX.SummaryWriter(
os.path.join(self.workspace, "run", self.name)
)
for p in self.model_tea.parameters():
p.requires_grad = False
self.model_tea.eval()
# get a ref to error_map
self.error_map = train_loader._data.error_map
if (
not self.args.use_real_data_for_train
): # using random poses to calculate max_epochs.
random_poses = get_rand_poses(
data_type=self.args.data_type,
original_loader=copy.deepcopy(
train_loader._data.poses.detach().cpu().numpy()
),
)
self.opt.iters = int(
(self.opt.iters // len(random_poses)) * len(random_poses)
)
max_epochs = np.ceil(self.opt.iters / len(random_poses)).astype(np.int32)
scheduler = lambda optimizer: optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=self.opt.iters * 1, eta_min=7e-5
) # update scheduler according to new opt.iters
self.lr_scheduler = scheduler(self.optimizer)
self.total_epoch = max_epochs
self.log(f"\n----------------total epoch:{max_epochs} -----------\n")
self.real_train_poses = copy.deepcopy(train_loader._data.poses)
for epoch in range(self.epoch, max_epochs + 1):
self.epoch = epoch
if not self.args.use_real_data_for_train:
print(f"\n generate new random poses at epoch{self.epoch}")
random_poses = get_rand_poses(
data_type=self.args.data_type,
original_loader=self.real_train_poses.detach().cpu().numpy(),
)
train_loader._data.poses = copy.deepcopy(random_poses)
train_loader._data.images = train_loader._data.images[:1].expand(
len(random_poses), -1, -1, -1
)
train_loader = train_loader._data.dataloader()
self.train_one_epoch(train_loader)
print("\n", self.workspace, "\n")
if (
self.workspace is not None
and self.local_rank == 0
and self.epoch > max_epochs - 1
):
self.save_checkpoint(full=False, best=False)
if self.epoch % self.eval_interval == 0:
self.evaluate_one_epoch(valid_loader)
self.save_checkpoint(full=False, best=True)  # skip writing the best .pth for now, to save disk space
if self.use_tensorboardX and self.local_rank == 0:
self.writer.close()
def train_one_epoch(self, loader):
# self.log(
# f"tttttttttt> Start Training Epoch {self.epoch}/{self.total_epoch}, len(train_data):{len(loader)} lr={self.optimizer.param_groups[0]['lr']:.6f} ..."
# )
total_loss = 0
total_loss_rgb = 0
total_loss_fea_sc = 0
total_loss_sigma = 0
total_loss_color = 0
psnr_tool = PSNRMeter()
psnr_tool.clear()
self.pose_psnr = [] # [(pose1, psnr1), (pose2,psnr2)...]
if self.local_rank == 0 and self.report_metric_at_train:
for metric in self.metrics:
metric.clear()
self.model_stu.train()
self.model_tea.train()
if self.world_size > 1:
loader.sampler.set_epoch(self.epoch)
if self.local_rank == 0:
pbar = tqdm.tqdm(
total=len(loader) * loader.batch_size,
bar_format="{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
)
self.local_step = 0
for data in loader:
# Update the occupancy grid every `update_extra_interval` steps. This should run when training a teacher from scratch, but not when distilling a student.
if (
self.model_tea.cuda_ray
and self.global_step % self.opt.update_extra_interval == 0
):
with torch.cuda.amp.autocast(enabled=self.fp16):
if self.opt.update_stu_extra:
self.model_stu.update_extra_state()
else:
pass
self.local_step += 1
self.global_step += 1
self.args.global_step = self.global_step
self.optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=self.fp16):
(
preds,
truths,
loss,
loss_rgb,
loss_fea_sc,
loss_color,
loss_sigma,
) = self.train_step(data)
if preds is not None:
psnr_tool.update(preds, truths)
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
if self.scheduler_update_every_step:
self.lr_scheduler.step()
loss_val = loss.item()
total_loss += loss_val
total_loss_rgb += loss_rgb
total_loss_sigma += loss_sigma
total_loss_color += loss_color
total_loss_fea_sc += loss_fea_sc
if self.local_rank == 0:
if self.report_metric_at_train:
for metric in self.metrics:
metric.update(preds, truths)
if self.use_tensorboardX:
self.writer.add_scalar("train/loss", loss_val, self.global_step)
self.writer.add_scalar("train/loss_rgb", loss_rgb, self.global_step)
self.writer.add_scalar(
"train/loss_fea_sc", loss_fea_sc, self.global_step
)
self.writer.add_scalar(
"train/loss_coloc", loss_color, self.global_step
)
self.writer.add_scalar(
"train/loss_sigma", loss_sigma, self.global_step
)
self.writer.add_scalar(
"train/lr",
self.optimizer.param_groups[0]["lr"],
self.global_step,
)
if self.scheduler_update_every_step: # run this
cur_lr = self.optimizer.param_groups[0]["lr"]
if self.global_step < self.args.stage_iters["stage1"]:
pbar.set_description(
f"loss={total_loss/self.local_step:.5f}, fea_sc={total_loss_fea_sc/self.local_step:.5f}, lr={cur_lr:.5f}"
)
elif self.global_step < self.args.stage_iters["stage2"]:
pbar.set_description(
f"loss={total_loss/self.local_step:.5f}, fea_sc={total_loss_fea_sc/self.local_step:.5f}, sigma={total_loss_sigma/self.local_step:.5f}, color={total_loss_color/self.local_step:.5f}, lr={cur_lr:.6f}"
)
else:
pbar.set_description(
f"loss={total_loss/self.local_step:.5f}, rgb={total_loss_rgb/self.local_step:.5f}, lr={cur_lr:.5f}"
)
else:
pbar.set_description(
f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})"
)
pbar.update(loader.batch_size)
if (
self.opt.model_type == "vm"
and self.global_step in self.opt.upsample_model_steps
):
# shrink
if (
self.model_stu.cuda_ray
): # and self.global_step == self.opt.upsample_model_steps[0]:
self.model_stu.shrink_model()
# adaptive voxel size from aabb_train
n_vox = self.upsample_resolutions.pop(0) ** 3 # n_voxels
aabb = self.model_stu.aabb_train.cpu().numpy()
vox_size = np.cbrt(np.prod(aabb[3:] - aabb[:3]) / n_vox)
reso = ((aabb[3:] - aabb[:3]) / vox_size).astype(np.int32).tolist()
self.log(
f"[INFO] upsample model at step {self.global_step} from {self.model_stu.resolution} to {reso}"
)
self.model_stu.upsample_model(reso)
# reset optimizer since params changed.
self.optimizer = self.optimizer_fn(self.model_stu)
self.lr_scheduler = self.lr_scheduler_fn(self.optimizer)
if self.ema is not None:
self.ema.update()
average_loss = total_loss / self.local_step
self.stats["loss"].append(average_loss)
if self.local_rank == 0:
pbar.close()
if self.report_metric_at_train:
for metric in self.metrics:
self.log(metric.report(), style="red")
if self.use_tensorboardX:
metric.write(self.writer, self.epoch, prefix="train")
metric.clear()
if not self.scheduler_update_every_step:
if isinstance(
self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
):
self.lr_scheduler.step(average_loss)
else:
self.lr_scheduler.step()
psnr_tool.psnr_list.sort()
if self.global_step < self.args.stage_iters["stage1"]:
self.log(
f"tttttttttt> Train stage1 Epoch:{self.epoch}. loss_fea:{total_loss_fea_sc/self.local_step:.6f}"
)
elif self.global_step < self.args.stage_iters["stage2"]:
self.log(
f"tttttttttt> Train stage2 Epoch:{self.epoch}. loss_fea_sc:{total_loss_fea_sc/self.local_step:.3f} loss_sigma:{total_loss_sigma/self.local_step:.3f} loss_color:{total_loss_color/self.local_step:.3f}"
)
else:
self.log(
f"tttttttttt> Train stage3 Epoch:{self.epoch}. loss_rgb:{total_loss_rgb/self.global_step:.3f} loss_fea_sc:{total_loss_fea_sc/self.local_step:.3f} loss_sigma:{total_loss_sigma/self.local_step:.3f} loss_color:{total_loss_color/self.local_step:.3f}"
)
self.log(
f"tttttttttt> Train PSNR Epoch {self.epoch}. psnr_min:{psnr_tool.psnr_list[0]:.3f} psnr_max:{psnr_tool.psnr_list[-1]:.3f} psnr_mean:{np.mean(psnr_tool.psnr_list):.3f}"
)
def get_loss(self, pred, gt):
if self.opt.loss_type == "L2":
loss = torch.mean((gt - pred) ** 2)
elif self.opt.loss_type == "normL2":
loss = torch.norm(pred - gt)
elif self.opt.loss_type == "normL1":
loss = torch.norm(pred - gt, p=1)
elif self.opt.loss_type == "smoothL1":
loss = torch.nn.functional.smooth_l1_loss(pred, gt, beta=0.05)
else:
raise ValueError("error loss_type")
return loss
def train_step(self, data):
rays_o = data["rays_o"] # [B, N, 3]
rays_d = data["rays_d"] # [B, N, 3] [1, N=rays_num=4096, 3]
loss = 0.0
# if there is no gt image, we train with CLIP loss.
if "images" not in data:
raise NotImplementedError("training without GT images (CLIP loss) is not supported in distillation")
B, N = rays_o.shape[:2]
H, W = data["H"], data["W"]
# currently fix white bg, MUST force all rays!
outputs = self.model.render(
rays_o,
rays_d,
staged=False,
bg_color=None,
perturb=True,
force_all_rays=True,
**vars(self.opt),
)
pred_rgb = (
outputs["image"].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous()
)
loss = self.clip_loss(pred_rgb)
return pred_rgb, None, loss
images = data["images"] # [B, N, 3/4]
B, N, C = images.shape
# if self.opt.color_space == 'linear':
# images[..., :3] = srgb_to_linear(images[..., :3])
if (
C == 3 or self.model_stu.bg_radius > 0
): # C=4 in synthetic dataset. C=3 for real dataset
bg_color = 1
# train with random background color if not using a bg model and has alpha channel.
else:
bg_color = torch.rand(
[B, rays_o.size(1), 3], dtype=images.dtype, device=images.device
)
if self.opt.render_stu_first:
outputs_stu = self.model_stu.render(
rays_o,
rays_d,
staged=False,
bg_color=bg_color,
perturb=True,
force_all_rays=False,
**vars(self.opt),
)
pred_rgb_stu = outputs_stu["image"]
with torch.no_grad():
outputs_tea = self.model_tea.render(
rays_o,
rays_d,
staged=False,
bg_color=bg_color,
perturb=True,
force_all_rays=False,
inherited_params=outputs_stu["inherited_params"],
**vars(self.opt),
)
pred_rgb_tea = outputs_tea["image"]
else:
with torch.no_grad():
outputs_tea = self.model_tea.render(
rays_o,
rays_d,
staged=False,
bg_color=bg_color,
perturb=True,
force_all_rays=False,
**vars(self.opt),
)
pred_rgb_tea = outputs_tea["image"]
outputs_stu = self.model_stu.render(
rays_o,
rays_d,
staged=False,
bg_color=bg_color,
perturb=True,
force_all_rays=False,
inherited_params=outputs_tea["inherited_params"],
**vars(self.opt),
)
pred_rgb_stu = outputs_stu["image"]
gt_rgb = pred_rgb_tea
self.opt.loss_rate_fea_sc = update_loss_rate(self.opt.loss_rate_fea_sc, 0.995)
if (
"stage1" in outputs_stu
and self.opt.loss_rate_fea_sc > 0.0
and self.model_stu.feature_sigma_color is not None
and self.model_tea.feature_sigma_color is not None
):
assert (
self.model_stu.feature_sigma_color.shape
== self.model_tea.feature_sigma_color.shape
)
loss_fea_sc = self.get_loss(
self.model_stu.feature_sigma_color, self.model_tea.feature_sigma_color
)
loss = loss + self.opt.loss_rate_fea_sc * loss_fea_sc
return None, None, loss, 0, loss_fea_sc.detach().item(), 0, 0
if "stage2" in outputs_stu:
if self.opt.loss_rate_color > 0.0:
assert self.model_stu.color_l.shape == self.model_tea.color_l.shape
loss_color = self.get_loss(
self.model_stu.color_l, self.model_tea.color_l
)
loss = loss + self.opt.loss_rate_color * loss_color
else:
assert self.model_stu.color_l.shape == self.model_tea.color_l.shape
loss_color = self.get_loss(
self.model_stu.color_l, self.model_tea.color_l
)
if self.opt.loss_rate_sigma > 0.0:
assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape
loss_sigma = self.get_loss(
self.model_stu.sigma_l, self.model_tea.sigma_l
)
loss = loss + self.opt.loss_rate_sigma * loss_sigma
else:
assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape
loss_sigma = self.get_loss(
self.model_stu.sigma_l, self.model_tea.sigma_l
)
if (
self.opt.loss_rate_fea_sc > 0.0
and self.model_stu.feature_sigma_color is not None
and self.model_tea.feature_sigma_color is not None
):
assert (
self.model_stu.feature_sigma_color.shape
== self.model_tea.feature_sigma_color.shape
)
loss_fea_sc = self.get_loss(
self.model_stu.feature_sigma_color,
self.model_tea.feature_sigma_color,
)
loss = loss + self.opt.loss_rate_fea_sc * loss_fea_sc
else:
loss_fea_sc = torch.tensor(0.0)
return (
None,
None,
loss,
0,
loss_fea_sc.detach().item(),
loss_color.detach().item(),
loss_sigma.detach().item(),
)
if self.opt.loss_type == "normL2":
loss_rgb = torch.norm(pred_rgb_tea - pred_rgb_stu)
elif self.opt.loss_type == "normL1":
loss_rgb = torch.norm(pred_rgb_tea - pred_rgb_stu, p=1)
elif self.opt.loss_type == "L2":
loss_rgb = self.criterion(pred_rgb_tea, pred_rgb_stu).mean(
-1
) # [B, N, 3] --> [B, N]
if len(loss_rgb.shape) == 3: # [K, B, N]
loss_rgb = loss_rgb.mean(0)
if self.error_map is not None:
index = data["index"] # [B]
inds = data["inds_coarse"] # [B, N]
error_map = self.error_map[index] # [B, H * W]
error = loss_rgb.detach().to(
error_map.device
) # [B, N], already in [0, 1]
ema_error = 0.1 * error_map.gather(1, inds) + 0.9 * error # ema update
error_map.scatter_(1, inds, ema_error)
self.error_map[index] = error_map # put back
loss_rgb = loss_rgb.mean()
else:
raise ValueError("error loss_type")
loss = loss + loss_rgb * self.opt.loss_rate_rgb
if self.opt.l1_reg_weight > 0.0 and self.opt.model_type == "vm":
loss = loss + self.model_stu.density_loss() * self.opt.l1_reg_weight
if (
self.opt.loss_rate_fea_sc > 0.0
and self.model_stu.feature_sigma_color is not None
and self.model_tea.feature_sigma_color is not None
):
assert (
self.model_stu.feature_sigma_color.shape
== self.model_tea.feature_sigma_color.shape
)
loss_fea_sc = self.get_loss(
self.model_stu.feature_sigma_color, self.model_tea.feature_sigma_color
)
loss = loss + self.opt.loss_rate_fea_sc * loss_fea_sc
elif (
self.model_stu.feature_sigma_color is None
or self.model_tea.feature_sigma_color is None
):
loss_fea_sc = torch.tensor(0.0)
else:
assert (
self.model_stu.feature_sigma_color.shape
== self.model_tea.feature_sigma_color.shape
)
loss_fea_sc = self.get_loss(
self.model_stu.feature_sigma_color, self.model_tea.feature_sigma_color
)
if self.opt.loss_rate_color > 0.0:
assert self.model_stu.color_l.shape == self.model_tea.color_l.shape
loss_color = self.get_loss(self.model_stu.color_l, self.model_tea.color_l)
loss = loss + self.opt.loss_rate_color * loss_color
else:
assert self.model_stu.color_l.shape == self.model_tea.color_l.shape
loss_color = self.get_loss(self.model_stu.color_l, self.model_tea.color_l)
if self.opt.loss_rate_sigma > 0.0:
assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape
loss_sigma = self.get_loss(self.model_stu.sigma_l, self.model_tea.sigma_l)
loss = loss + self.opt.loss_rate_sigma * loss_sigma
else:
assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape
loss_sigma = self.get_loss(self.model_stu.sigma_l, self.model_tea.sigma_l)
loss_rgb_show = self.criterion(
pred_rgb_tea.detach(), pred_rgb_stu.detach()
).mean()  # scalar, for logging only
return (
pred_rgb_stu,
gt_rgb,
loss,
loss_rgb_show.detach().item(),
loss_fea_sc.detach().item(),
loss_color.detach().item(),
loss_sigma.detach().item(),
)
### ------------------------------
def evaluate(self, loader, name=None):
self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX
self.evaluate_one_epoch(loader, name)
self.use_tensorboardX = use_tensorboardX
def evaluate_one_epoch(self, loader, name=None):
if name is None:
name = f"{self.name}_ep{self.epoch:04d}"
total_loss = 0
if self.local_rank == 0:
for metric in self.metrics:
metric.clear()
if self.opt.test_teacher:
self.model_stu = self.model_tea
self.model_stu.eval()
if self.ema is not None:
self.ema.store()
self.ema.copy_to()
if self.local_rank == 0:
pbar = tqdm.tqdm(
total=len(loader) * loader.batch_size,
bar_format="{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
)
with torch.no_grad():
self.local_step = 0
self.ssim = 0.0
self.lpips_vgg = 0.0
self.lpips_alex = 0.0
# update grid
if self.model_stu.cuda_ray:
with torch.cuda.amp.autocast(enabled=self.fp16):
if self.opt.update_stu_extra:
self.model_stu.update_extra_state()
else:
pass
frames = []
frames_depth = []
for data in loader:
self.local_step += 1
with torch.cuda.amp.autocast(enabled=self.fp16):
preds, preds_depth, truths, loss = self.eval_step(data)
# all_gather/reduce the statistics (NCCL only support all_*)
if self.world_size > 1:
dist.all_reduce(loss, op=dist.ReduceOp.SUM)
loss = loss / self.world_size
preds_list = [
torch.zeros_like(preds).to(self.device)
for _ in range(self.world_size)
] # [[B, ...], [B, ...], ...]
dist.all_gather(preds_list, preds)
preds = torch.cat(preds_list, dim=0)
preds_depth_list = [
torch.zeros_like(preds_depth).to(self.device)
for _ in range(self.world_size)
] # [[B, ...], [B, ...], ...]
dist.all_gather(preds_depth_list, preds_depth)
preds_depth = torch.cat(preds_depth_list, dim=0)
truths_list = [
torch.zeros_like(truths).to(self.device)
for _ in range(self.world_size)
] # [[B, ...], [B, ...], ...]
dist.all_gather(truths_list, truths)
truths = torch.cat(truths_list, dim=0)
loss_val = loss.item()
total_loss += loss_val
if self.local_rank == 0:
for metric in self.metrics:
metric.update(preds, truths)
self.lpips_alex += rgb_lpips(truths, preds, "alex")
self.lpips_vgg += rgb_lpips(truths, preds, "vgg")
self.ssim += compute_ssim(
preds,
truths,
max_val=max(preds.max().item(), truths.max().item()),
).item()
# save image
save_path = os.path.join(
self.workspace,
loader._data.type,
f"{name}_{self.local_step:04d}.png",
)
save_path_depth = os.path.join(
self.workspace,
loader._data.type,
f"{name}_{self.local_step:04d}_depth.png",
)
# save_path_gt = os.path.join(self.workspace, loader._data.type, f'{name}_{self.local_step:04d}_gt.png')
os.makedirs(os.path.dirname(save_path), exist_ok=True)
if self.opt.color_space == "linear":
preds = linear_to_srgb(preds)
pred = preds[0].detach().cpu().numpy()
truth = truths[0].detach().cpu().numpy()
pred_depth = preds_depth[0].detach().cpu().numpy()
cv2.imwrite(
save_path,
cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR),
)
cv2.imwrite(save_path_depth, (pred_depth * 255).astype(np.uint8))
frames.append((pred * 255).astype(np.uint8))
frames_depth.append((pred_depth * 255).astype(np.uint8))
pbar.set_description(
f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})"
)
pbar.update(loader.batch_size)
print(
f"\n----video num(frames): {len(frames)} depth video num:{len(frames_depth)} ----\n"
)
imageio.mimwrite(
os.path.join(os.path.dirname(save_path), "video.mp4"),
frames,
fps=int(30 * 0.7),
macro_block_size=8,
)
imageio.mimwrite(
os.path.join(os.path.dirname(save_path), "video_depth.mp4"),
frames_depth,
fps=int(30 * 0.7),
macro_block_size=8,
)
psnr_tool = self.metrics[0]
psnr_tool.psnr_list.sort()
self.log(
f"\neeeeeeeee> {loader._data.type} PSRN Report: Epoch{self.epoch}. psnr_mean:{np.mean(psnr_tool.psnr_list):.2f}"
)
average_loss = total_loss / self.local_step
self.stats["valid_loss"].append(average_loss)
if self.local_rank == 0:
pbar.close()
if not self.use_loss_as_metric and len(self.metrics) > 0:
result = self.metrics[0].measure()
self.stats["results"].append(
result if self.best_mode == "min" else -result
) # if max mode, use -result
else:
self.stats["results"].append(
average_loss
) # if no metric, choose best by min loss
for metric in self.metrics:
# self.log(metric.report(), style="blue")
psnr = metric.report().split("=")[-1].strip()[:5]
self.psnr = float(psnr)
if self.use_tensorboardX and loader._data.type == "val":
metric.write(self.writer, self.epoch, prefix="evaluate")
metric.clear()
self.ssim /= self.local_step
self.lpips_alex /= self.local_step
self.lpips_vgg /= self.local_step
if self.ema is not None:
self.ema.restore()
self.log(
f"eeeeeeeeee> {loader._data.type} Metric Report: Epoch{self.epoch}. psnr:{psnr} ssim:{self.ssim:.2f} alex:{self.lpips_alex:.2f} vgg:{self.lpips_vgg:.2f}"
)
def eval_step(self, data):
rays_o = data["rays_o"] # [B, N, 3]
rays_d = data["rays_d"] # [B, N, 3]
images = data["images"] # [B, H, W, 3/4]
B, H, W, C = images.shape
if self.opt.color_space == "linear":
images[..., :3] = srgb_to_linear(images[..., :3])
# eval with fixed background color
bg_color = 1
if C == 4:
gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (
1 - images[..., 3:]
)
else:
gt_rgb = images
outputs = self.model_stu.render(
rays_o,
rays_d,
staged=True,
bg_color=bg_color,
perturb=False,
**vars(self.opt),
)
pred_rgb = outputs["image"].reshape(B, H, W, 3)
pred_depth = outputs["depth"].reshape(B, H, W)
loss = self.criterion(pred_rgb, gt_rgb).mean()
return pred_rgb, pred_depth, gt_rgb, loss
def save_checkpoint(self, name=None, full=False, best=False, remove_old=True):
full = False  # always save a slim checkpoint; optimizer/scheduler state is intentionally dropped
if name is None:
name = f"{self.name}_ep{self.epoch:04d}"
if self.opt.model_type == "vm":
state = {
"epoch": self.epoch,
"global_step": self.global_step,
"stats": self.stats,
"resolution": self.model_stu.resolution,
}
else:
state = {
"epoch": self.epoch,
"global_step": self.global_step,
"stats": self.stats,
}
if self.model_stu.cuda_ray:
state["mean_count"] = self.model_stu.mean_count
state["mean_density"] = self.model_stu.mean_density
if full:
state["optimizer"] = self.optimizer.state_dict()
state["lr_scheduler"] = self.lr_scheduler.state_dict()
state["scaler"] = self.scaler.state_dict()
if self.ema is not None:
state["ema"] = self.ema.state_dict()
if not best:
state["model"] = self.model_stu.state_dict()
file_path = f"{self.ckpt_path}/{name}.pth"
if remove_old:
self.stats["checkpoints"].append(file_path)
if len(self.stats["checkpoints"]) > self.max_keep_ckpt:
old_ckpt = self.stats["checkpoints"].pop(0)
if os.path.exists(old_ckpt):
os.remove(old_ckpt)
torch.save(state, file_path)
else:
if len(self.stats["results"]) > 0:
if (
self.stats["best_result"] is None
or self.stats["results"][-1] < self.stats["best_result"]
):
self.log(
f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}"
)
self.stats["best_result"] = self.stats["results"][-1]
# save ema results
if self.ema is not None:
self.ema.store()
self.ema.copy_to()
state["model"] = self.model_stu.state_dict()
if self.ema is not None:
self.ema.restore()
torch.save(state, self.best_path)
else:
self.log(
f"[WARN] no evaluated results found, skip saving best checkpoint."
)
def load_teacher_checkpoint(self):
checkpoint_dict = torch.load(self.opt.ckpt_teacher, map_location=self.device)
missing_keys, unexpected_keys = self.model_tea.load_state_dict(
checkpoint_dict["model"], strict=False
)
self.log("[INFO] loaded teacher model.")
if len(missing_keys) > 0:
self.log(f"[WARN] missing keys: {missing_keys}")
if len(unexpected_keys) > 0:
self.log(f"[WARN] unexpected keys: {unexpected_keys}")
if self.ema is not None and "ema" in checkpoint_dict:
self.ema.load_state_dict(checkpoint_dict["ema"])
if self.model_tea.cuda_ray:
if "mean_count" in checkpoint_dict:
self.model_tea.mean_count = checkpoint_dict["mean_count"]
if "mean_density" in checkpoint_dict:
self.model_tea.mean_density = checkpoint_dict["mean_density"]
"""
self.stats = checkpoint_dict['stats']
self.epoch = checkpoint_dict['epoch']
self.global_step = checkpoint_dict['global_step']
self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")
if self.optimizer and 'optimizer' in checkpoint_dict:
try:
self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
self.log("[INFO] loaded optimizer.")
except:
self.log("[WARN] Failed to load optimizer.")
if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
try:
self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
self.log("[INFO] loaded scheduler.")
except:
self.log("[WARN] Failed to load scheduler.")
if self.scaler and 'scaler' in checkpoint_dict:
try:
self.scaler.load_state_dict(checkpoint_dict['scaler'])
self.log("[INFO] loaded scaler.")
except:
self.log("[WARN] Failed to load scaler.")
if self.model_tea.cuda_ray:
if 'mean_count' in checkpoint_dict:
self.model_tea.mean_count = checkpoint_dict['mean_count']
if 'mean_density' in checkpoint_dict:
self.model_tea.mean_density = checkpoint_dict['mean_density']
"""
def load_student_checkpoint(self):
if self.opt.ckpt_student:
checkpoint_dict = torch.load(
self.opt.ckpt_student, map_location=self.device
)
else:
checkpoint_dict = torch.load(
self.opt.ckpt_teacher, map_location=self.device
)
if self.opt.model_type == "vm" and "resolution" in checkpoint_dict:
self.model_stu.upsample_model(checkpoint_dict["resolution"])
missing_keys, unexpected_keys = self.model_stu.load_state_dict(
checkpoint_dict["model"], strict=False
)
self.log("[INFO] loaded student model.")
if len(missing_keys) > 0:
self.log(f"[WARN] missing keys: {missing_keys}")
if len(unexpected_keys) > 0:
self.log(f"[WARN] unexpected keys: {unexpected_keys}")
if self.model_stu.cuda_ray:
if "mean_count" in checkpoint_dict:
self.model_stu.mean_count = checkpoint_dict["mean_count"]
if "mean_density" in checkpoint_dict:
self.model_stu.mean_density = checkpoint_dict["mean_density"]
if self.ema is not None and "ema" in checkpoint_dict:
self.ema.load_state_dict(checkpoint_dict["ema"])
"""
self.stats = checkpoint_dict['stats']
self.epoch = checkpoint_dict['epoch']
self.global_step = checkpoint_dict['global_step']
self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")
if self.optimizer and 'optimizer' in checkpoint_dict:
try:
self.optimizer.load_state_dict(checkpoint_dict['optimizer'])
self.log("[INFO] loaded optimizer.")
except:
self.log("[WARN] Failed to load optimizer.")
if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict:
try:
self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler'])
self.log("[INFO] loaded scheduler.")
except:
self.log("[WARN] Failed to load scheduler.")
if self.scaler and 'scaler' in checkpoint_dict:
try:
self.scaler.load_state_dict(checkpoint_dict['scaler'])
self.log("[INFO] loaded scaler.")
except:
self.log("[WARN] Failed to load scaler.")
"""
def load_checkpoint(self, checkpoint=None, model_only=False):
if checkpoint is None:
checkpoint_list = sorted(glob.glob(f"{self.ckpt_path}/{self.name}_ep*.pth"))
if checkpoint_list:
checkpoint = checkpoint_list[-1]
self.log(f"[INFO] Latest checkpoint is {checkpoint}")
else:
self.log("[WARN] No checkpoint found, model randomly initialized.")
return
checkpoint_dict = torch.load(checkpoint, map_location=self.device)
if "model" not in checkpoint_dict:
self.model.load_state_dict(checkpoint_dict)
self.log("[INFO] loaded model.")
return
missing_keys, unexpected_keys = self.model.load_state_dict(
checkpoint_dict["model"], strict=False
)
self.log("[INFO] loaded model.")
if len(missing_keys) > 0:
self.log(f"[WARN] missing keys: {missing_keys}")
if len(unexpected_keys) > 0:
self.log(f"[WARN] unexpected keys: {unexpected_keys}")
if self.ema is not None and "ema" in checkpoint_dict:
self.ema.load_state_dict(checkpoint_dict["ema"])
if self.model.cuda_ray:
if "mean_count" in checkpoint_dict:
self.model.mean_count = checkpoint_dict["mean_count"]
if "mean_density" in checkpoint_dict:
self.model.mean_density = checkpoint_dict["mean_density"]
if model_only:
return
self.stats = checkpoint_dict["stats"]
self.epoch = checkpoint_dict["epoch"]
self.global_step = checkpoint_dict["global_step"]
self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")
if self.optimizer and "optimizer" in checkpoint_dict:
try:
self.optimizer.load_state_dict(checkpoint_dict["optimizer"])
self.log("[INFO] loaded optimizer.")
except:
self.log("[WARN] Failed to load optimizer.")
if self.lr_scheduler and "lr_scheduler" in checkpoint_dict:
try:
self.lr_scheduler.load_state_dict(checkpoint_dict["lr_scheduler"])
self.log("[INFO] loaded scheduler.")
except:
self.log("[WARN] Failed to load scheduler.")
if self.scaler and "scaler" in checkpoint_dict:
try:
self.scaler.load_state_dict(checkpoint_dict["scaler"])
self.log("[INFO] loaded scaler.")
except:
self.log("[WARN] Failed to load scaler.")
def test(self, loader, save_path=None, name=None):
# this entry point is intentionally disabled in the distillation trainer
raise NotImplementedError("test() is not supported here; use test_step() instead.")
if save_path is None:
save_path = os.path.join(self.workspace, "results")
if name is None:
name = f"{self.name}_ep{self.epoch:04d}"
os.makedirs(save_path, exist_ok=True)
self.log(f"==> Start Test, save results to {save_path}")
pbar = tqdm.tqdm(
total=len(loader) * loader.batch_size,
bar_format="{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
)
self.model_stu.eval()
with torch.no_grad():
# update grid
if self.model_stu.cuda_ray:
with torch.cuda.amp.autocast(enabled=self.fp16):
self.model_stu.update_extra_state()
for i, data in enumerate(loader):
with torch.cuda.amp.autocast(enabled=self.fp16):
preds, preds_depth = self.test_step(data)
path = os.path.join(save_path, f"{name}_{i:04d}.png")
path_depth = os.path.join(save_path, f"{name}_{i:04d}_depth.png")
# self.log(f"[INFO] saving test image to {path}")
if self.opt.color_space == "linear":
preds = linear_to_srgb(preds)
pred = preds[0].detach().cpu().numpy()
pred_depth = preds_depth[0].detach().cpu().numpy()
cv2.imwrite(
path, cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
)
cv2.imwrite(path_depth, (pred_depth * 255).astype(np.uint8))
pbar.update(loader.batch_size)
self.log(f"==> Finished Test.")
# moved out bg_color and perturb for more flexible control...
def test_step(self, data, bg_color=None, perturb=False):
rays_o = data["rays_o"] # [B, N, 3]
rays_d = data["rays_d"] # [B, N, 3]
H, W = data["H"], data["W"]
if bg_color is not None:
bg_color = bg_color.to(self.device)
outputs = self.model_stu.render(
rays_o,
rays_d,
staged=True,
bg_color=bg_color,
perturb=perturb,
**vars(self.opt),
)
pred_rgb = outputs["image"].reshape(-1, H, W, 3)
pred_depth = outputs["depth"].reshape(-1, H, W)
return pred_rgb, pred_depth
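The checkpoint rotation in save_checkpoint above (append the new path to stats["checkpoints"], then prune the oldest once more than max_keep_ckpt are tracked) can be sketched standalone; the function and file names below are hypothetical, not part of the repo:

```python
import os
import tempfile

def save_with_rotation(state_bytes, ckpt_dir, name, kept, max_keep=2):
    """Write a checkpoint and delete the oldest one beyond max_keep.

    Mirrors the bookkeeping in save_checkpoint: append the new path to
    `kept`, and once the list exceeds max_keep, pop and remove the oldest.
    """
    file_path = os.path.join(ckpt_dir, f"{name}.pth")
    kept.append(file_path)
    if len(kept) > max_keep:
        old_ckpt = kept.pop(0)
        if os.path.exists(old_ckpt):
            os.remove(old_ckpt)
    with open(file_path, "wb") as f:
        f.write(state_bytes)
    return kept

# usage: only the two most recent checkpoints survive on disk
with tempfile.TemporaryDirectory() as d:
    kept = []
    for ep in range(4):
        save_with_rotation(b"state", d, f"model_ep{ep:04d}", kept)
    print(sorted(os.listdir(d)))  # ['model_ep0002.pth', 'model_ep0003.pth']
```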
================================================
FILE: gridencoder/__init__.py
================================================
from .grid import GridEncoder
================================================
FILE: gridencoder/backend.py
================================================
import os
from torch.utils.cpp_extension import load
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
"-O3",
"-std=c++14",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
]
if os.name == "posix":
c_flags = ["-O3", "-std=c++14"]
elif os.name == "nt":
c_flags = ["/O2", "/std:c++17"]
# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(
glob.glob(
r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64"
% edition
),
reverse=True,
)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError(
"Could not locate a supported Microsoft Visual C++ installation"
)
os.environ["PATH"] += ";" + cl_path
_backend = load(
name="_grid_encoder",
extra_cflags=c_flags,
extra_cuda_cflags=nvcc_flags,
sources=[
os.path.join(_src_path, "src", f)
for f in [
"gridencoder.cu",
"bindings.cpp",
]
],
)
__all__ = ["_backend"]
================================================
FILE: gridencoder/grid.py
================================================
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import _gridencoder as _backend
except ImportError:
from .backend import _backend
_gridtype_to_id = {
"hash": 0,
"tiled": 1,
}
class _grid_encode(Function):
@staticmethod
@custom_fwd
def forward(
ctx,
inputs,
embeddings,
offsets,
per_level_scale,
base_resolution,
calc_grad_inputs=False,
gridtype=0,
align_corners=False,
):
# inputs: [B, D], float in [0, 1]
# embeddings: [sO, C], float
# offsets: [L + 1], int
# RETURN: [B, F], float
inputs = inputs.contiguous()
B, D = inputs.shape # batch size, coord dim
L = offsets.shape[0] - 1 # level
C = embeddings.shape[1] # embedding dim for each level
S = np.log2(
per_level_scale
) # resolution multiplier at each level, apply log2 for later CUDA exp2f
H = base_resolution # base resolution
# manually handle autocast (only use half precision embeddings, inputs must be float for enough precision)
# if C % 2 != 0, force float, since half for atomicAdd is very slow.
if torch.is_autocast_enabled() and C % 2 == 0:
embeddings = embeddings.to(torch.half)
# L first, optimize cache for cuda kernel, but needs an extra permute later
outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype)
if calc_grad_inputs:
dy_dx = torch.empty(
B, L * D * C, device=inputs.device, dtype=embeddings.dtype
)
else:
dy_dx = torch.empty(
1, device=inputs.device, dtype=embeddings.dtype
) # placeholder... TODO: a better way?
_backend.grid_encode_forward(
inputs,
embeddings,
offsets,
outputs,
B,
D,
C,
L,
S,
H,
calc_grad_inputs,
dy_dx,
gridtype,
align_corners,
)
# permute back to [B, L * C]
outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
ctx.dims = [B, D, C, L, S, H, gridtype]
ctx.calc_grad_inputs = calc_grad_inputs
ctx.align_corners = align_corners
return outputs
@staticmethod
# @once_differentiable
@custom_bwd
def backward(ctx, grad):
inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
B, D, C, L, S, H, gridtype = ctx.dims
calc_grad_inputs = ctx.calc_grad_inputs
align_corners = ctx.align_corners
# grad: [B, L * C] --> [L, B, C]
grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
grad_embeddings = torch.zeros_like(embeddings)
if calc_grad_inputs:
grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype)
else:
grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype)
_backend.grid_encode_backward(
grad,
inputs,
embeddings,
offsets,
grad_embeddings,
B,
D,
C,
L,
S,
H,
calc_grad_inputs,
dy_dx,
grad_inputs,
gridtype,
align_corners,
)
if calc_grad_inputs:
grad_inputs = grad_inputs.to(inputs.dtype)
return grad_inputs, grad_embeddings, None, None, None, None, None, None
else:
return None, grad_embeddings, None, None, None, None, None, None
grid_encode = _grid_encode.apply
class GridEncoder(nn.Module):
def __init__(
self,
input_dim=3,
num_levels=16,
level_dim=2,
per_level_scale=2,
base_resolution=16,
log2_hashmap_size=19,
desired_resolution=None,
gridtype="hash",
align_corners=False,
):
super().__init__()
# if desired_resolution (the finest resolution at the last level) is provided, it overrides per_level_scale
if desired_resolution is not None:
per_level_scale = np.exp2(
np.log2(desired_resolution / base_resolution) / (num_levels - 1)
)
self.input_dim = input_dim # coord dims, 2 or 3
self.num_levels = num_levels # number of levels; resolution grows by per_level_scale each level
self.level_dim = level_dim # encode channels per level
self.per_level_scale = (
per_level_scale # multiply resolution by this scale at each level.
)
self.log2_hashmap_size = log2_hashmap_size
self.base_resolution = base_resolution
self.output_dim = num_levels * level_dim
self.gridtype = gridtype
self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash"
self.align_corners = align_corners
# allocate parameters
offsets = []
offset = 0
self.max_params = 2 ** log2_hashmap_size
for i in range(num_levels):
resolution = int(np.ceil(base_resolution * per_level_scale ** i))
params_in_level = min(
self.max_params,
(resolution if align_corners else resolution + 1) ** input_dim,
) # limit max number
params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible
offsets.append(offset)
offset += params_in_level
offsets.append(offset)
offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
self.register_buffer("offsets", offsets)
self.n_params = offsets[-1] * level_dim
# parameters
self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
self.reset_parameters()
def reset_parameters(self):
std = 1e-4
self.embeddings.data.uniform_(-std, std)
def __repr__(self):
return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}"
def forward(self, inputs, bound=1):
# inputs: [..., input_dim], normalized real world positions in [-bound, bound]
# return: [..., num_levels * level_dim]
inputs = (inputs + bound) / (2 * bound) # map to [0, 1]
# print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item())
prefix_shape = list(inputs.shape[:-1])
inputs = inputs.view(-1, self.input_dim)
outputs = grid_encode(
inputs,
self.embeddings,
self.offsets,
self.per_level_scale,
self.base_resolution,
inputs.requires_grad,
self.gridtype_id,
self.align_corners,
)
outputs = outputs.view(prefix_shape + [self.output_dim])
# print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item())
return outputs
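The per-level table sizing in GridEncoder.__init__ above (resolution grows by per_level_scale, entries are capped at 2**log2_hashmap_size, and each level is padded to a multiple of 8) can be reproduced in plain NumPy. The helper name below is illustrative; the defaults match the constructor's:

```python
import numpy as np

def hashgrid_offsets(input_dim=3, num_levels=16, per_level_scale=2.0,
                     base_resolution=16, log2_hashmap_size=19,
                     align_corners=False):
    """Return the [L + 1] offsets array used to index each level's table."""
    max_params = 2 ** log2_hashmap_size
    offsets, offset = [], 0
    for i in range(num_levels):
        resolution = int(np.ceil(base_resolution * per_level_scale ** i))
        # a dense level needs (resolution + 1) ** D vertices (resolution ** D
        # with align_corners); overflowing levels are capped at max_params
        params_in_level = min(
            max_params,
            (resolution if align_corners else resolution + 1) ** input_dim,
        )
        params_in_level = int(np.ceil(params_in_level / 8) * 8)  # pad to 8
        offsets.append(offset)
        offset += params_in_level
    offsets.append(offset)
    return np.array(offsets, dtype=np.int32)

offs = hashgrid_offsets()
print(offs[:3])  # the first level is dense: 17**3 = 4913, padded to 4920
print(np.diff(offs).max())  # fine levels are capped at 2**19 entries
```

This makes it easy to see at which level the hash fallback kicks in: with the defaults, levels of resolution 128 and above would need more than 2**19 dense entries and are therefore capped.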
================================================
FILE: gridencoder/setup.py
================================================
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
_src_path = os.path.dirname(os.path.abspath(__file__))
nvcc_flags = [
"-O3",
"-std=c++14",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
]
if os.name == "posix":
c_flags = ["-O3", "-std=c++14"]
elif os.name == "nt":
c_flags = ["/O2", "/std:c++17"]
# find cl.exe
def find_cl_path():
import glob
for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
paths = sorted(
glob.glob(
r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64"
% edition
),
reverse=True,
)
if paths:
return paths[0]
# If cl.exe is not on path, try to find it.
if os.system("where cl.exe >nul 2>nul") != 0:
cl_path = find_cl_path()
if cl_path is None:
raise RuntimeError(
"Could not locate a supported Microsoft Visual C++ installation"
)
os.environ["PATH"] += ";" + cl_path
setup(
name="gridencoder", # package name, import this to use python API
ext_modules=[
CUDAExtension(
name="_gridencoder", # extension name, import this to use CUDA API
sources=[
os.path.join(_src_path, "src", f)
for f in [
"gridencoder.cu",
"bindings.cpp",
]
],
extra_compile_args={
"cxx": c_flags,
"nvcc": nvcc_flags,
},
),
],
cmdclass={
"build_ext": BuildExtension,
},
)
================================================
FILE: gridencoder/src/bindings.cpp
================================================
#include <torch/extension.h>
#include "gridencoder.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)");
m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)");
}
================================================
FILE: gridencoder/src/gridencoder.cu
================================================
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>
#include <algorithm>
#include <stdexcept>
#include <stdint.h>
#include <cstdio>
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")
// just for compatibility of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF...
static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) {
// requires CUDA >= 10 and ARCH >= 70 to enable the real __half overload below.
// __half atomicAdd is very slow compared to float or __half2, and this overload is never called.
//return atomicAdd(reinterpret_cast<__half*>(address), val);
return val; // stub return so the never-called overload is still well-formed
}
template <typename T>
static inline __host__ __device__ T div_round_up(T val, T divisor) {
return (val + divisor - 1) / divisor;
}
template <uint32_t D>
__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) {
static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions.");
// While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence
// and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional
// coordinates.
constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 };
uint32_t result = 0;
#pragma unroll
for (uint32_t i = 0; i < D; ++i) {
result ^= pos_grid[i] * primes[i];
}
return result;
}
template <uint32_t D, uint32_t C>
__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) {
uint32_t stride = 1;
uint32_t index = 0;
#pragma unroll
for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) {
index += pos_grid[d] * stride;
stride *= align_corners ? resolution: (resolution + 1);
}
// NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97.
// gridtype: 0 == hash, 1 == tiled
if (gridtype == 0 && stride > hashmap_size) {
index = fast_hash<D>(pos_grid);
}
return (index % hashmap_size) * C + ch;
}
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_grid(
const float * __restrict__ inputs,
const scalar_t * __restrict__ grid,
const int * __restrict__ offsets,
scalar_t * __restrict__ outputs,
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
const bool calc_grad_inputs,
scalar_t * __restrict__ dy_dx,
const uint32_t gridtype,
const bool align_corners
) {
const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x;
if (b >= B) return;
const uint32_t level = blockIdx.y;
// locate
grid += (uint32_t)offsets[level] * C;
inputs += b * D;
outputs += level * B * C + b * C;
// check input range (should be in [0, 1])
bool flag_oob = false;
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
if (inputs[d] < 0 || inputs[d] > 1) {
flag_oob = true;
}
}
// if input out of bound, just set output to 0
if (flag_oob) {
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
outputs[ch] = 0;
}
if (calc_grad_inputs) {
dy_dx += b * D * L * C + level * D * C; // B L D C
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
dy_dx[d * C + ch] = 0;
}
}
}
return;
}
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
const float scale = exp2f(level * S) * H - 1.0f;
const uint32_t resolution = (uint32_t)ceil(scale) + 1;
// calculate coordinate
float pos[D];
uint32_t pos_grid[D];
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
pos_grid[d] = floorf(pos[d]);
pos[d] -= (float)pos_grid[d];
}
//printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]);
// interpolate
scalar_t results[C] = {0}; // temp results in register
#pragma unroll
for (uint32_t idx = 0; idx < (1 << D); idx++) {
float w = 1;
uint32_t pos_grid_local[D];
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
if ((idx & (1 << d)) == 0) {
w *= 1 - pos[d];
pos_grid_local[d] = pos_grid[d];
} else {
w *= pos[d];
pos_grid_local[d] = pos_grid[d] + 1;
}
}
uint32_t index = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
// writing to register (fast)
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
results[ch] += w * grid[index + ch];
}
//printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]);
}
// writing to global memory (slow)
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
outputs[ch] = results[ch];
}
// prepare dy_dx for calc_grad_inputs
// differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9
if (calc_grad_inputs) {
dy_dx += b * D * L * C + level * D * C; // B L D C
#pragma unroll
for (uint32_t gd = 0; gd < D; gd++) {
scalar_t results_grad[C] = {0};
#pragma unroll
for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) {
float w = scale;
uint32_t pos_grid_local[D];
#pragma unroll
for (uint32_t nd = 0; nd < D - 1; nd++) {
const uint32_t d = (nd >= gd) ? (nd + 1) : nd;
if ((idx & (1 << nd)) == 0) {
w *= 1 - pos[d];
pos_grid_local[d] = pos_grid[d];
} else {
w *= pos[d];
pos_grid_local[d] = pos_grid[d] + 1;
}
}
pos_grid_local[gd] = pos_grid[gd];
uint32_t index_left = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
pos_grid_local[gd] = pos_grid[gd] + 1;
uint32_t index_right = get_grid_index<D, C>(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local);
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]);
}
}
#pragma unroll
for (uint32_t ch = 0; ch < C; ch++) {
dy_dx[gd * C + ch] = results_grad[ch];
}
}
}
}
template <typename scalar_t, uint32_t D, uint32_t C, uint32_t N_C>
__global__ void kernel_grid_backward(
const scalar_t * __restrict__ grad,
const float * __restrict__ inputs,
const scalar_t * __restrict__ grid,
const int * __restrict__ offsets,
scalar_t * __restrict__ grad_grid,
const uint32_t B, const uint32_t L, const float S, const uint32_t H,
const uint32_t gridtype,
const bool align_corners
) {
const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C;
if (b >= B) return;
const uint32_t level = blockIdx.y;
const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C;
// locate
grad_grid += offsets[level] * C;
inputs += b * D;
grad += level * B * C + b * C + ch; // L, B, C
const uint32_t hashmap_size = offsets[level + 1] - offsets[level];
const float scale = exp2f(level * S) * H - 1.0f;
const uint32_t resolution = (uint32_t)ceil(scale) + 1;
// check input range (should be in [0, 1])
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
if (inputs[d] < 0 || inputs[d] > 1) {
return; // grad is init as 0, so we simply return.
}
}
// calculate coordinate
float pos[D];
uint32_t pos_grid[D];
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f);
pos_grid[d] = floorf(pos[d]);
pos[d] -= (float)pos_grid[d];
}
scalar_t grad_cur[N_C] = {0}; // fetch to register
#pragma unroll
for (uint32_t c = 0; c < N_C; c++) {
grad_cur[c] = grad[c];
}
// interpolate
#pragma unroll
for (uint32_t idx = 0; idx < (1 << D); idx++) {
float w = 1;
uint32_t pos_grid_local[D];
#pragma unroll
for (uint32_t d = 0; d < D; d++) {
if ((idx & (1 << d)) == 0) {
w *= 1 - pos[d];
pos_grid_local[d] = pos_grid[d];
} else {
w *= pos[d];
pos_grid_local[d] = pos_grid[d] + 1;
}
}
uint32_t index = get_grid_index<D, C>(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local);
// atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0
// TODO: use float which is better than __half, if N_C % 2 != 0
if (std::is_same<scalar_t, at::Half>::value && N_C % 2 == 0) {
#pragma unroll
for (uint32_t c = 0; c < N_C; c += 2) {
// process two __half at once (by interpreting as a __half2)
__half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])};
atomicAdd((__half2*)&grad_grid[index + c], v);
}
// float, or __half when N_C % 2 != 0 (which means C == 1)
} else {
#pragma unroll
for (uint32_t c = 0; c < N_C; c++) {
atomicAdd(&grad_grid[index + c], w * grad_cur[c]);
}
}
}
}
template <typename scalar_t, uint32_t D, uint32_t C>
__global__ void kernel_input_backward(
const scalar_t * __restrict__ grad,
const scalar_t * __restrict__ dy_dx,
scalar_t * __restrict__ grad_inputs,
uint32_t B, uint32_t L
) {
const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
if (t >= B * D) return;
const uint32_t b = t / D;
const uint32_t d = t - b * D;
dy_dx += b * L * D * C;
scalar_t result = 0;
# pragma unroll
for (int l = 0; l < L; l++) {
# pragma unroll
for (int ch = 0; ch < C; ch++) {
result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch];
}
}
grad_inputs[t] = result;
}
template <typename scalar_t, uint32_t D>
void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
static constexpr uint32_t N_THREAD = 512;
const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 };
switch (C) {
case 1: kernel_grid<scalar_t, D, 1><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
case 2: kernel_grid<scalar_t, D, 2><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
case 4: kernel_grid<scalar_t, D, 4><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
case 8: kernel_grid<scalar_t, D, 8><<<blocks_hashgrid, N_THREAD>>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
}
}
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], int32_t
// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.)
// H: base resolution
// dy_dx: [B, L * D * C]
template <typename scalar_t>
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
switch (D) {
case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
default: throw std::runtime_error{"GridEncoding: D (input dim) must be 2 or 3."};
}
}
template <typename scalar_t, uint32_t D>
void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
static constexpr uint32_t N_THREAD = 256;
const uint32_t N_C = std::min(2u, C); // n_features_per_thread
const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 };
switch (C) {
case 1:
kernel_grid_backward<scalar_t, D, 1, 1><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 1><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
break;
case 2:
kernel_grid_backward<scalar_t, D, 2, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 2><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
break;
case 4:
kernel_grid_backward<scalar_t, D, 4, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 4><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
break;
case 8:
kernel_grid_backward<scalar_t, D, 8, 2><<<blocks_hashgrid, N_THREAD>>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners);
if (calc_grad_inputs) kernel_input_backward<scalar_t, D, 8><<<div_round_up(B * D, N_THREAD), N_THREAD>>>(grad, dy_dx, grad_inputs, B, L);
break;
default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."};
}
}
// grad: [L, B, C], float
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], int32_t
// grad_embeddings: [sO, C]
// H: base resolution
template <typename scalar_t>
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
switch (D) {
case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
default: throw std::runtime_error{"GridEncoding: D (input dim) must be 2 or 3."};
}
}
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners) {
CHECK_CUDA(inputs);
CHECK_CUDA(embeddings);
CHECK_CUDA(offsets);
CHECK_CUDA(outputs);
CHECK_CUDA(dy_dx);
CHECK_CONTIGUOUS(inputs);
CHECK_CONTIGUOUS(embeddings);
CHECK_CONTIGUOUS(offsets);
CHECK_CONTIGUOUS(outputs);
CHECK_CONTIGUOUS(dy_dx);
CHECK_IS_FLOATING(inputs);
CHECK_IS_FLOATING(embeddings);
CHECK_IS_INT(offsets);
CHECK_IS_FLOATING(outputs);
CHECK_IS_FLOATING(dy_dx);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
embeddings.scalar_type(), "grid_encode_forward", ([&] {
grid_encode_forward_cuda<scalar_t>(inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), outputs.data_ptr<scalar_t>(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), gridtype, align_corners);
}));
}
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners) {
CHECK_CUDA(grad);
CHECK_CUDA(inputs);
CHECK_CUDA(embeddings);
CHECK_CUDA(offsets);
CHECK_CUDA(grad_embeddings);
CHECK_CUDA(dy_dx);
CHECK_CUDA(grad_inputs);
CHECK_CONTIGUOUS(grad);
CHECK_CONTIGUOUS(inputs);
CHECK_CONTIGUOUS(embeddings);
CHECK_CONTIGUOUS(offsets);
CHECK_CONTIGUOUS(grad_embeddings);
CHECK_CONTIGUOUS(dy_dx);
CHECK_CONTIGUOUS(grad_inputs);
CHECK_IS_FLOATING(grad);
CHECK_IS_FLOATING(inputs);
CHECK_IS_FLOATING(embeddings);
CHECK_IS_INT(offsets);
CHECK_IS_FLOATING(grad_embeddings);
CHECK_IS_FLOATING(dy_dx);
CHECK_IS_FLOATING(grad_inputs);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "grid_encode_backward", ([&] {
grid_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<float>(), embeddings.data_ptr<scalar_t>(), offsets.data_ptr<int>(), grad_embeddings.data_ptr<scalar_t>(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>(), gridtype, align_corners);
}));
}
================================================
FILE: gridencoder/src/gridencoder.h
================================================
#ifndef _HASH_ENCODE_H
#define _HASH_ENCODE_H
#include <stdint.h>
#include <torch/torch.h>
// inputs: [B, D], float, in [0, 1]
// embeddings: [sO, C], float
// offsets: [L + 1], uint32_t
// outputs: [B, L * C], float
// H: base resolution
void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners);
void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners);
#endif
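A hedged sketch (not code from this repo) of how the `offsets: [L + 1]` table documented above can be laid out for a multi-level grid: level `l` has resolution `floor(H * b**l)` and its parameter count is capped by the hash-table size, instant-ngp style. The growth factor `b` and the cap are illustrative assumptions, not values read from this codebase.

```python
import math

def build_offsets(L=16, H=16, b=1.3819, D=3, log2_hashmap_size=19):
    # offsets[l] .. offsets[l+1] delimits the embedding rows of level l
    offsets = [0]
    for l in range(L):
        res = int(math.floor(H * b ** l))
        # dense grid until it outgrows the hash table, then hashed
        params_in_level = min((res + 1) ** D, 2 ** log2_hashmap_size)
        offsets.append(offsets[-1] + params_in_level)
    return offsets  # length L + 1, matching the header comment
```

Coarse levels fit densely; once `(res + 1) ** D` exceeds the table size, every finer level contributes exactly `2 ** log2_hashmap_size` rows.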
================================================
FILE: just_train_tea/network.py
================================================
import torch
from time import time
import torch.nn as nn
import torch.nn.functional as F
from tools.encoding import get_encoder
from tools.activation import trunc_exp
from .renderer import NeRFRenderer
import raymarching
class NeRFNetwork(NeRFRenderer):
def __init__(
self,
encoding="hashgrid",
encoding_dir="sphere_harmonics",
encoding_bg="hashgrid",
num_layers=2,
hidden_dim=64,
geo_feat_dim=15,
num_layers_color=3,
hidden_dim_color=64,
num_layers_bg=2,
hidden_dim_bg=64,
bound=1,
model_type="hash",
args=None,
is_teacher=False,
**kwargs,
):
super().__init__(bound, **kwargs)
# sigma network
assert model_type in ["hash", "mlp", "vm", "tensors"]
self.is_teacher = is_teacher
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.geo_feat_dim = geo_feat_dim
self.args = args
self.opt = args
self.model_type = model_type
self.plenoxel_degree = args.plenoxel_degree
self.plenoxel_res = eval(args.plenoxel_res)
assert len(self.plenoxel_res) == 3
self.encoder, self.in_dim = get_encoder(
encoding,
desired_resolution=2048 * bound,
num_levels=14,
)
if "hash" != self.model_type:
self.encoder = None
if self.model_type == "mlp":
self.encoder_nerf_pe, self.in_dim_nerf = get_encoder(
encoding="frequency", multires=self.args.PE
)
self.skips = self.args.skip
self.nerf_layer_num = self.args.nerf_layer_num
W = self.args.nerf_layer_wide
self.nerf_mlp = [nn.Linear(self.in_dim_nerf, W)]
for i in range(self.nerf_layer_num - 2):
if i != self.skips:
self.nerf_mlp.append(nn.Linear(W, W))
else:
self.nerf_mlp.append(nn.Linear(W + self.in_dim_nerf, W))
self.nerf_mlp.append(nn.Linear(W, self.in_dim))
self.nerf_mlp = nn.ModuleList(self.nerf_mlp)
elif self.model_type == "vm":
self.sigma_rank = [16] * 3
self.color_rank = [48] * 3
self.color_feat_dim = 15 # geo_feat_dim
self.mat_ids = [[0, 1], [0, 2], [1, 2]]
self.vec_ids = [2, 1, 0]
self.resolution = [self.opt.resolution0] * 3
            # mat: ParameterList of three [1, 16, res0, res0] planes; vec: ParameterList
            # of three [1, 16, res0, 1] lines. Three of each because the 3D grid
            # [H, W, D] is decomposed into the planes [H, W], [H, D], [W, D], each
            # paired with the complementary line [D], [W], [H].
self.sigma_mat, self.sigma_vec = self.init_one_vm(
self.sigma_rank, self.resolution
)
# mat: paralist[1,48,res0,res0] repeat 3 vec: paralist[1,48,res0,1] repeat 3
self.color_mat, self.color_vec = self.init_one_vm(
self.color_rank, self.resolution
)
# Linear(in_features=144, out_features=27)
self.basis_mat = nn.Linear(
sum(self.color_rank), self.color_feat_dim, bias=False
)
elif self.model_type == "tensors":
self.init_plenoxel_volume(
s=0.02,
fea_dim=self.plenoxel_degree ** 2 * 3 + 1,
volume=self.plenoxel_res,
)
elif self.model_type == "hash":
pass
else:
raise ValueError(f"error model_type:{self.model_type}")
if self.model_type != "vm" and self.model_type != "tensors":
sigma_net = []
for l in range(num_layers):
if l == 0:
in_dim = self.in_dim
else:
in_dim = hidden_dim
if l == num_layers - 1:
out_dim = (
1 + self.geo_feat_dim
) # 1 sigma + 15 SH features for color
else:
out_dim = hidden_dim
sigma_net.append(nn.Linear(in_dim, out_dim, bias=False))
self.sigma_net = nn.ModuleList(sigma_net)
# color network
self.num_layers_color = num_layers_color
self.hidden_dim_color = hidden_dim_color
# self.encoder_dir, self.in_dim_dir = get_encoder(encoding=encoding_dir)
if self.model_type == "tensors":
self.encoder_dir, self.in_dim_dir = get_encoder(
encoding="sphere_harmonics",
degree=self.plenoxel_degree,
)
else:
self.encoder_dir, self.in_dim_dir = get_encoder(
encoding=encoding_dir, input_dim=3, multires=2
)
if self.model_type != "tensors":
color_net = []
for l in range(num_layers_color):
if l == 0:
in_dim = self.in_dim_dir + self.geo_feat_dim
else:
in_dim = hidden_dim
if l == num_layers_color - 1:
out_dim = 3 # 3 rgb
else:
out_dim = hidden_dim
color_net.append(nn.Linear(in_dim, out_dim, bias=False))
self.color_net = nn.ModuleList(color_net)
# background network
if self.bg_radius > 0:
self.num_layers_bg = num_layers_bg
self.hidden_dim_bg = hidden_dim_bg
self.encoder_bg, self.in_dim_bg = get_encoder(
encoding_bg,
input_dim=2,
num_levels=4,
log2_hashmap_size=19,
desired_resolution=2048,
) # much smaller hashgrid
bg_net = []
for l in range(num_layers_bg):
if l == 0:
in_dim = self.in_dim_bg + self.in_dim_dir
else:
in_dim = hidden_dim_bg
if l == num_layers_bg - 1:
out_dim = 3 # 3 rgb
else:
out_dim = hidden_dim_bg
bg_net.append(nn.Linear(in_dim, out_dim, bias=False))
self.bg_net = nn.ModuleList(bg_net)
else:
self.bg_net = None
def init_plenoxel_volume(self, s=0.1, fea_dim=27 + 1, volume=[128, 128, 128]):
tensor = []
tensor.append(
torch.nn.Parameter(
s * torch.randn((1, fea_dim, volume[0], volume[1], volume[2]))
)
)
self.tensor_volume = torch.nn.ParameterList(tensor).cuda()
def init_one_vm(self, n_component, resolution, scale=0.1):
# self.mat_ids = [[0, 1], [0, 2], [1, 2]] self.vec_ids = [2, 1, 0]
mat, vec = [], []
for i in range(len(self.vec_ids)):
vec_id = self.vec_ids[i]
mat_id_0, mat_id_1 = self.mat_ids[i]
mat.append(
nn.Parameter(
scale
* torch.randn(
(1, n_component[i], resolution[mat_id_1], resolution[mat_id_0])
)
)
) # [1, R, H, W]
vec.append(
nn.Parameter(
scale * torch.randn((1, n_component[i], resolution[vec_id], 1))
)
) # [1, R, D, 1] (fake 2d to use grid_sample)
return nn.ParameterList(mat), nn.ParameterList(vec)
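A toy, dependency-free illustration (not repo code) of the vector-matrix decomposition that `init_one_vm` parameterizes: a dense 3D grid value `T[x, y, z]` is approximated as a sum of plane-times-line products, one pair per axis split, shown here with rank 1 per pair for readability (the real model uses rank 16/48 and learns the factors).

```python
def vm_value(mat_xy, vec_z, mat_xz, vec_y, mat_yz, vec_x, x, y, z):
    # T[x, y, z] ~ M_xy[x, y] * v_z[z] + M_xz[x, z] * v_y[y] + M_yz[y, z] * v_x[x]
    return (mat_xy[x][y] * vec_z[z]
            + mat_xz[x][z] * vec_y[y]
            + mat_yz[y][z] * vec_x[x])
```

This is why storage scales with `3 * (res**2 + res)` per rank instead of `res**3` for a dense grid.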
def get_sigma_feat(self, x):
# x: [N, 3], in [-1, 1] (outliers will be treated as zero due to grid_sample padding mode)
# self.mat_ids = [[0, 1], [0, 2], [1, 2]] self.vec_ids = [2, 1, 0]
N = x.shape[0]
# plane + line basis
mat_coord = (
torch.stack(
(
x[..., self.mat_ids[0]],
x[..., self.mat_ids[1]],
x[..., self.mat_ids[2]],
)
)
.detach()
.view(3, -1, 1, 2)
) # [3, N, 1, 2]
vec_coord = torch.stack(
(x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])
)
vec_coord = (
torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1)
.detach()
.view(3, -1, 1, 2)
) # [3, N, 1, 2], fake 2d coord
sigma_feat = torch.zeros(
[
N,
],
device=x.device,
)
for i in range(len(self.sigma_mat)):
mat_feat = F.grid_sample(
self.sigma_mat[i], mat_coord[[i]], align_corners=True
).view(
-1, N
) # [1, R, N, 1] --> [R, N]
vec_feat = F.grid_sample(
self.sigma_vec[i], vec_coord[[i]], align_corners=True
).view(
-1, N
) # [R, N]
sigma_feat = sigma_feat + torch.sum(mat_feat * vec_feat, dim=0)
return sigma_feat
def get_color_feat(self, x):
# x: [N, 3], in [-1, 1]
N = x.shape[0]
# plane + line basis
mat_coord = (
torch.stack(
(
x[..., self.mat_ids[0]],
x[..., self.mat_ids[1]],
x[..., self.mat_ids[2]],
)
)
.detach()
.view(3, -1, 1, 2)
) # [3, N, 1, 2]
vec_coord = torch.stack(
(x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]])
)
vec_coord = (
torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1)
.detach()
.view(3, -1, 1, 2)
) # [3, N, 1, 2], fake 2d coord
mat_feat, vec_feat = [], []
for i in range(len(self.color_mat)):
mat_feat.append(
F.grid_sample(
self.color_mat[i], mat_coord[[i]], align_corners=True
).view(-1, N)
) # [1, R, N, 1] --> [R, N]
vec_feat.append(
F.grid_sample(
self.color_vec[i], vec_coord[[i]], align_corners=True
).view(-1, N)
) # [R, N]
mat_feat = torch.cat(mat_feat, dim=0) # [3 * R, N]
vec_feat = torch.cat(vec_feat, dim=0) # [3 * R, N]
color_feat = self.basis_mat(
(mat_feat * vec_feat).T
) # [N, 3R] --> [N, color_feat_dim]
return color_feat
def compute_plenoxel_fea(self, x):
composed = self.tensor_volume[0]
composed = (
F.grid_sample(composed, x.view(1, 1, -1, 1, 3), align_corners=True)
.view(-1, x.shape[0])
.permute(1, 0)
)
return composed # [N, fea_dim]
def forward_nerf_mlp(self, x):
x = self.encoder_nerf_pe(x)
in_pts = x
for i in range(len(self.nerf_mlp)):
x = self.nerf_mlp[i](x)
if i != len(self.nerf_mlp) - 1:
x = F.relu(x, inplace=True)
if i == self.skips:
x = torch.cat([in_pts, x], -1)
return x
def forward(self, x, d):
        # x: [N, 3], in [-bound, bound]; d: [N, 3], normalized in [-1, 1]
# sigma
if self.model_type == "hash":
x = self.encoder(
x, bound=self.bound
) # out_x[N, 28=num_levels * fea_per_level]
elif self.model_type == "mlp":
x = self.forward_nerf_mlp(x) # 28
elif self.model_type == "vm":
x = (
2
* (x - self.aabb_train[:3])
/ (self.aabb_train[3:] - self.aabb_train[:3])
- 1
) # x:[N, 3]
sigma_feat = self.get_sigma_feat(x) # sigma_feat:[N]
color_feat = self.get_color_feat(x) # color_feat:[N, 15]
sigma_feat = torch.clamp(
sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max
)
# color_feat = torch.clamp(color_feat, self.args.sigma_clip_min, self.args.sigma_clip_max)
self.feature_sigma_color = torch.cat(
[sigma_feat.unsqueeze(-1), color_feat], dim=-1
)
self.sigma_l = sigma_feat
sigma = trunc_exp(sigma_feat) # sigma:[N]
enc_d = self.encoder_dir(d) # enc_d:[N, 16]
h = torch.cat([enc_d, color_feat], dim=-1) # h:[N, 16+15]
for l in range(self.num_layers_color):
h = self.color_net[l](h)
if l != self.num_layers_color - 1:
h = F.relu(h, inplace=True)
color = torch.sigmoid(h)
self.color_l = color
return sigma, color
elif self.model_type == "tensors":
x = (
2
* (x - self.aabb_train[:3])
/ (self.aabb_train[3:] - self.aabb_train[:3])
- 1
) # x:[N, 3]
x = self.compute_plenoxel_fea(x)
h = x
sigma = torch.clamp(
h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max
)
self.sigma_l = sigma
sigma = trunc_exp(sigma)
self.sigma = sigma
sh = h[..., 1:].view(
-1, 3, self.plenoxel_degree ** 2
) # [N, 3, 9] ## .permute(1, 0, 2) # [B, 27]-->[9, B, 3]
enc_d = self.encoder_dir(d).unsqueeze(1) # [N, 9]-->[N,1,9]
color = (sh * enc_d).sum(-1) # [N, 3]
color = torch.sigmoid(color)
self.feature_sigma_color = None
self.color_l = color
return sigma, color
else:
            raise ValueError(f"illegal model_type: {self.model_type}")
h = x
for l in range(self.num_layers):
h = self.sigma_net[l](h)
if l != self.num_layers - 1:
h = F.relu(h, inplace=True)
h[..., 0] = torch.clamp(
h[..., 0].clone(), self.args.sigma_clip_min, self.args.sigma_clip_max
)
# h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max)
self.feature_sigma_color = h
self.sigma_l = h[..., 0]
sigma = trunc_exp(h[..., 0]) # sigma: [n]
geo_feat = h[..., 1:] # geo_feat: [n, 15]
d = self.encoder_dir(d) # d: [n, 16]
h = torch.cat([d, geo_feat], dim=-1) # h: [n, 15+16]
for l in range(self.num_layers_color):
h = self.color_net[l](h)
if l != self.num_layers_color - 1:
h = F.relu(h, inplace=True)
color = torch.sigmoid(h)
self.color_l = color
return sigma, color
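A hedged sketch (illustrative, not repo code) of the spherical-harmonics shading step in the "tensors" branch above: per-point SH coefficients of shape `[3, degree**2]` are dotted with the SH basis evaluated at the view direction, then squashed with a sigmoid. Shown for degree 2 (4 basis terms) using the standard real-SH constants; the repo's `shencoder` extension plays the role of `sh_basis_deg2` here.

```python
import math

def sh_basis_deg2(d):
    # first two real-SH bands for a unit direction d = (x, y, z)
    x, y, z = d
    return [0.28209479177387814,       # Y_0^0 (constant term)
            -0.48860251190291987 * y,  # Y_1^-1
            0.4886025119029199 * z,    # Y_1^0
            -0.48860251190291987 * x]  # Y_1^1

def shade(sh_coeffs, d):
    # sh_coeffs: [3][4] per-channel coefficients; returns sigmoid(rgb)
    basis = sh_basis_deg2(d)
    raw = [sum(c * b for c, b in zip(channel, basis)) for channel in sh_coeffs]
    return [1.0 / (1.0 + math.exp(-v)) for v in raw]
```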
def density(self, x):
# x: [N, 3], in [-bound, bound]
if self.model_type == "hash":
x = self.encoder(
x, bound=self.bound
) # out_x[N, 32=num_levels * fea_per_level]
elif self.model_type == "mlp":
x = self.forward_nerf_mlp(x)
elif self.model_type == "vm":
x = (
2
* (x - self.aabb_train[:3])
/ (self.aabb_train[3:] - self.aabb_train[:3])
- 1
)
sigma_feat = self.get_sigma_feat(x)
sigma_feat = torch.clamp(
sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max
)
sigma = trunc_exp(sigma_feat)
return {"sigma": sigma}
elif self.model_type == "tensors":
x = (
2
* (x - self.aabb_train[:3])
/ (self.aabb_train[3:] - self.aabb_train[:3])
- 1
) # x:[N, 3]
x = self.compute_plenoxel_fea(x)
h = x
# h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max)
            sigma = trunc_exp(
                torch.clamp(
                    h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max
                )
            )
            return {"sigma": sigma}
else:
            raise ValueError(f"illegal model_type: {self.model_type}")
h = x
for l in range(self.num_layers):
h = self.sigma_net[l](h)
if l != self.num_layers - 1:
h = F.relu(h, inplace=True)
h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max)
sigma = trunc_exp(h[..., 0])
geo_feat = h[..., 1:]
return {
"sigma": sigma,
"geo_feat": geo_feat,
}
def background(self, x, d):
        raise NotImplementedError  # unused code path, kept for reference
# x: [N, 2], in [-1, 1]
h = self.encoder_bg(x) # [N, C]
d = self.encoder_dir(d)
h = torch.cat([d, h], dim=-1)
for l in range(self.num_layers_bg):
h = self.bg_net[l](h)
if l != self.num_layers_bg - 1:
h = F.relu(h, inplace=True)
# sigmoid activation for rgb
rgbs = torch.sigmoid(h)
return rgbs
# allow masked inference
def color(self, x, d, mask=None, geo_feat=None, **kwargs):
        raise NotImplementedError  # unused code path, kept for reference
# x: [N, 3] in [-bound, bound]
        # mask: [N,], bool, indicates where we actually need to compute rgb.
if mask is not None:
rgbs = torch.zeros(
mask.shape[0], 3, dtype=x.dtype, device=x.device
) # [N, 3]
# in case of empty mask
if not mask.any():
return rgbs
x = x[mask]
d = d[mask]
geo_feat = geo_feat[mask]
d = self.encoder_dir(d)
h = torch.cat([d, geo_feat], dim=-1)
for l in range(self.num_layers_color):
h = self.color_net[l](h)
if l != self.num_layers_color - 1:
h = F.relu(h, inplace=True)
# sigmoid activation for rgb
h = torch.sigmoid(h)
if mask is not None:
rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32
else:
rgbs = h
return rgbs
# L1 penalty for loss
def density_loss(self):
loss = 0
for i in range(len(self.sigma_mat)):
loss = (
loss
+ torch.mean(torch.abs(self.sigma_mat[i]))
+ torch.mean(torch.abs(self.sigma_vec[i]))
)
return loss
# upsample utils
@torch.no_grad()
def upsample_params(self, mat, vec, resolution):
for i in range(len(self.vec_ids)):
vec_id = self.vec_ids[i]
mat_id_0, mat_id_1 = self.mat_ids[i]
mat[i] = nn.Parameter(
F.interpolate(
mat[i].data,
size=(resolution[mat_id_1], resolution[mat_id_0]),
mode="bilinear",
align_corners=True,
)
)
vec[i] = nn.Parameter(
F.interpolate(
vec[i].data,
size=(resolution[vec_id], 1),
mode="bilinear",
align_corners=True,
)
)
@torch.no_grad()
def upsample_model(self, resolution):
self.upsample_params(self.sigma_mat, self.sigma_vec, resolution)
self.upsample_params(self.color_mat, self.color_vec, resolution)
self.resolution = resolution
@torch.no_grad()
def shrink_model(self):
# shrink aabb_train and the model so it only represents the space inside aabb_train.
half_grid_size = self.bound / self.grid_size
thresh = min(self.density_thresh, self.mean_density)
valid_grid = self.density_grid[self.cascade - 1] > thresh # [N]
valid_pos = raymarching.morton3D_invert(
torch.nonzero(valid_grid)
) # [Nz] --> [Nz, 3], in [0, H - 1]
# plot_pointcloud(valid_pos.detach().cpu().numpy()) # lots of noisy outliers in hashnerf...
valid_pos = (2 * valid_pos / (self.grid_size - 1) - 1) * (
self.bound - half_grid_size
) # [Nz, 3], in [-b+hgs, b-hgs]
min_pos = valid_pos.amin(0) - half_grid_size # [3]
max_pos = valid_pos.amax(0) + half_grid_size # [3]
# shrink model
reso = torch.LongTensor(self.resolution).to(self.aabb_train.device)
units = (self.aabb_train[3:] - self.aabb_train[:3]) / reso
tl = (min_pos - self.aabb_train[:3]) / units
br = (max_pos - self.aabb_train[:3]) / units
tl = torch.round(tl).long().clamp(min=0)
br = torch.minimum(torch.round(br).long(), reso)
for i in range(len(self.vec_ids)):
vec_id = self.vec_ids[i]
mat_id_0, mat_id_1 = self.mat_ids[i]
self.sigma_vec[i] = nn.Parameter(
self.sigma_vec[i].data[..., tl[vec_id] : br[vec_id], :]
)
self.color_vec[i] = nn.Parameter(
self.color_vec[i].data[..., tl[vec_id] : br[vec_id], :]
)
self.sigma_mat[i] = nn.Parameter(
self.sigma_mat[i].data[
..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0]
]
)
self.color_mat[i] = nn.Parameter(
self.color_mat[i].data[
..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0]
]
)
self.aabb_train = torch.cat([min_pos, max_pos], dim=0) # [6]
print(
f"[INFO] shrink slice: {tl.cpu().numpy().tolist()} - {br.cpu().numpy().tolist()}"
)
print(f"[INFO] new aabb: {self.aabb_train.cpu().numpy().tolist()}")
# optimizer utils
def get_params(self, lr, lr2=1e-3):
if self.model_type == "hash":
params = [
{"params": self.encoder.parameters(), "lr": lr},
{"params": self.sigma_net.parameters(), "lr": lr},
{"params": self.encoder_dir.parameters(), "lr": lr},
{"params": self.color_net.parameters(), "lr": lr},
]
elif self.model_type == "mlp":
params = [
{"params": self.sigma_net.parameters(), "lr": lr},
{"params": self.encoder_dir.parameters(), "lr": lr},
{"params": self.color_net.parameters(), "lr": lr},
{"params": self.nerf_mlp.parameters(), "lr": lr},
]
elif self.model_type == "vm":
params = [
{"params": self.color_net.parameters(), "lr": lr2},
{"params": self.sigma_mat, "lr": lr},
{"params": self.sigma_vec, "lr": lr},
{"params": self.color_mat, "lr": lr},
{"params": self.color_vec, "lr": lr},
{"params": self.basis_mat.parameters(), "lr": lr2},
]
elif self.model_type == "tensors":
params = [
{"params": self.tensor_volume.parameters(), "lr": lr},
{"params": self.encoder_dir.parameters(), "lr": lr},
]
else:
            raise ValueError(f"illegal model_type: {self.model_type}")
if self.bg_radius > 0:
params.append({"params": self.encoder_bg.parameters(), "lr": lr})
params.append({"params": self.bg_net.parameters(), "lr": lr})
return params
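`trunc_exp`, imported from `tools.activation` at the top of this file, is used everywhere a raw sigma logit becomes a density. A common definition (sketched here from the usual torch-ngp recipe, as an assumption about its behavior rather than this repo's exact code) is a plain `exp()` in the forward pass whose gradient uses the input clamped to a cap, so `d/dx exp(x)` cannot overflow in fp16 training.

```python
import math

def trunc_exp_sketch(x, cap=15.0):
    # forward: plain exp; backward: exp of the clamped input
    value = math.exp(x)
    grad = math.exp(min(x, cap))
    return value, grad
```

Below the cap the pair is an ordinary exponential; above it the gradient saturates at `exp(cap)` while the value keeps growing.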
================================================
FILE: just_train_tea/provider.py
================================================
import os
import cv2
import glob
import json
import tqdm
import numpy as np
from scipy.spatial.transform import Slerp, Rotation
import trimesh
import torch
from torch.utils.data import DataLoader
from .utils import get_rays, srgb_to_linear
# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
def nerf_matrix_to_ngp(pose, scale=0.33):
# for the fox dataset, 0.33 scales camera radius to ~ 2
new_pose = np.array(
[
[pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale],
[pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale],
[pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale],
[0, 0, 0, 1],
],
dtype=np.float32,
)
return new_pose
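A self-contained check (illustrative, plain nested lists instead of numpy) of what `nerf_matrix_to_ngp` above does: the rotation rows are cycled (y, z, x), the second and third columns are negated, and the translation column is scaled.

```python
def ngp_rows(pose, scale=0.33):
    # same row permutation, sign flips, and translation scaling as above
    return [
        [pose[1][0], -pose[1][1], -pose[1][2], pose[1][3] * scale],
        [pose[2][0], -pose[2][1], -pose[2][2], pose[2][3] * scale],
        [pose[0][0], -pose[0][1], -pose[0][2], pose[0][3] * scale],
        [0, 0, 0, 1],
    ]
```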
def rand_poses(
size,
device,
radius=1,
theta_range=[np.pi / 3, 2 * np.pi / 3],
phi_range=[0, 2 * np.pi],
):
"""generate random poses from an orbit camera
Args:
size: batch size of generated poses.
device: where to allocate the output.
radius: camera radius
theta_range: [min, max], should be in [0, \pi]
phi_range: [min, max], should be in [0, 2\pi]
Return:
poses: [size, 4, 4]
"""
def normalize(vectors):
return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)
thetas = (
torch.rand(size, device=device) * (theta_range[1] - theta_range[0])
+ theta_range[0]
)
phis = (
torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0]
)
centers = torch.stack(
[
radius * torch.sin(thetas) * torch.sin(phis),
radius * torch.cos(thetas),
radius * torch.sin(thetas) * torch.cos(phis),
],
dim=-1,
) # [B, 3]
# lookat
forward_vector = -normalize(centers)
up_vector = (
torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
) # confused at the coordinate system...
right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1))
up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1))
poses = (
torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
)
poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
poses[:, :3, 3] = centers
return poses
def uniform_poses(
    size,
    device,
    radius=1,
    theta_range=[np.pi / 3, 2 * np.pi / 3],
    phi_range=[0, 2 * np.pi],
):
    """deterministic counterpart of rand_poses: poses at evenly spaced angles"""

    def normalize(vectors):
        return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)

    interval_nums = torch.tensor(
        [i * 1 / (size - 1) for i in range(size)], dtype=torch.float32, device=device
    )
    thetas = interval_nums * (theta_range[1] - theta_range[0]) + theta_range[0]
    phis = interval_nums * (phi_range[1] - phi_range[0]) + phi_range[0]
    centers = torch.stack(
        [
            radius * torch.sin(thetas) * torch.sin(phis),
            radius * torch.cos(thetas),
            radius * torch.sin(thetas) * torch.cos(phis),
        ],
        dim=-1,
    )  # [B, 3]
    # lookat
    forward_vector = -normalize(centers)
    up_vector = (
        torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1)
    )  # confused at the coordinate system...
    right_vector = normalize(
        torch.cross(forward_vector, up_vector, dim=-1)
    )  # cross product
    up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1))
    poses = (
        torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
    )
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers
    return poses
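The spherical-to-Cartesian convention used in the pose generators above (y is "up": theta is the polar angle from +y, phi the azimuth), extracted as a standalone sketch; any point it produces sits exactly `radius` from the origin.

```python
import math

def orbit_center(radius, theta, phi):
    # camera center on an orbit sphere, y-up convention as in the code above
    return (radius * math.sin(theta) * math.sin(phi),
            radius * math.cos(theta),
            radius * math.sin(theta) * math.cos(phi))
```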
class NeRFDataset:
def __init__(self, opt, device, type="train", downscale=1, n_test=10):
super().__init__()
self.opt = opt
self.args = opt
self.device = device
self.type = type # train, val, test
self.downscale = downscale
self.root_path = opt.path
self.mode = opt.mode # only support blender
self.preload = opt.preload # preload data into GPU
self.scale = (
opt.scale
) # camera radius scale to make sure camera are inside the bounding box.
self.bound = (
opt.bound
) # bounding box half length, also used as the radius to random sample poses.
self.fp16 = opt.fp16 # if preload, load into fp16.
self.training = self.type in ["train", "all", "trainval"]
self.num_rays = self.opt.num_rays if self.training else -1
if self.mode == "blender":
if type == "all":
transform_paths = glob.glob(os.path.join(self.root_path, "*.json"))
transform = None
for transform_path in transform_paths:
with open(transform_path, "r") as f:
tmp_transform = json.load(f)
if transform is None:
transform = tmp_transform
else:
transform["frames"].extend(tmp_transform["frames"])
# load train and val split
elif type == "trainval":
with open(
os.path.join(self.root_path, f"t
================================================
SYMBOL INDEX (210 symbols across 24 files)
================================================
FILE: distill_mutual/network.py
class NeRFNetwork (line 12) | class NeRFNetwork(NeRFRenderer):
method __init__ (line 13) | def __init__(
method init_plenoxel_volume (line 184) | def init_plenoxel_volume(self, s=0.1, fea_dim=27 + 1, volume=[128, 128...
method init_one_vm (line 193) | def init_one_vm(self, n_component, resolution, scale=0.1):
method get_sigma_feat (line 216) | def get_sigma_feat(self, x):
method get_color_feat (line 264) | def get_color_feat(self, x):
method compute_plenoxel_fea (line 311) | def compute_plenoxel_fea(self, x):
method forward_nerf_mlp (line 324) | def forward_nerf_mlp(self, x):
method forward (line 335) | def forward(self, x, d):
method density (line 439) | def density(self, x):
method background (line 496) | def background(self, x, d):
method color (line 515) | def color(self, x, d, mask=None, geo_feat=None, **kwargs):
method density_loss (line 549) | def density_loss(self):
method upsample_params (line 561) | def upsample_params(self, mat, vec, resolution):
method upsample_model (line 584) | def upsample_model(self, resolution):
method shrink_model (line 590) | def shrink_model(self):
method get_params (line 646) | def get_params(self, lr, lr2=1e-3):
FILE: distill_mutual/provider.py
function nerf_matrix_to_ngp (line 18) | def nerf_matrix_to_ngp(pose, scale=0.33):
function rand_poses (line 32) | def rand_poses(
class NeRFDataset (line 123) | class NeRFDataset:
method __init__ (line 124) | def __init__(self, opt, device, type="train", downscale=1, n_test=10):
method collate (line 284) | def collate(self, index):
method dataloader (line 316) | def dataloader(self):
FILE: distill_mutual/renderer.py
function sample_pdf (line 15) | def sample_pdf(bins, weights, n_samples, det=False):
function plot_pointcloud (line 54) | def plot_pointcloud(pc, color=None):
class NeRFRenderer (line 66) | class NeRFRenderer(nn.Module):
method __init__ (line 67) | def __init__(
method forward (line 117) | def forward(self, x, d):
method density (line 121) | def density(self, x):
method color (line 124) | def color(self, x, d, mask=None, **kwargs):
method reset_extra_state (line 127) | def reset_extra_state(self):
method run (line 139) | def run(
method run_cuda (line 319) | def run_cuda(
method mark_untrained_grid (line 562) | def mark_untrained_grid(self, poses, intrinsic, S=64):
method update_extra_state (line 648) | def update_extra_state(self, decay=0.95, S=128):
method render (line 777) | def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **k...
FILE: distill_mutual/utils.py
function update_loss_rate (line 41) | def update_loss_rate(cur_lrate, scale=0.99):
function get_softmax_map_mean (line 45) | def get_softmax_map_mean(a, b):
function get_kl (line 49) | def get_kl(inputs, targets):
function nerf_matrix_to_ngp (line 53) | def nerf_matrix_to_ngp(pose, scale=0.8):
function pose_spherical (line 67) | def pose_spherical(theta, phi, radius):
function get_rand_poses (line 100) | def get_rand_poses(data_type="synthetic", original_loader=None):
function custom_meshgrid (line 201) | def custom_meshgrid(*args):
function linear_to_srgb (line 210) | def linear_to_srgb(x):
function srgb_to_linear (line 215) | def srgb_to_linear(x):
function compute_ssim (line 219) | def compute_ssim(
function init_lpips (line 303) | def init_lpips(net_name, device):
function rgb_lpips (line 317) | def rgb_lpips(gt, im, net_name):
function get_rays (line 325) | def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
function seed_everything (line 407) | def seed_everything(seed):
function torch_vis_2d (line 417) | def torch_vis_2d(x, renormalize=False):
function extract_fields (line 442) | def extract_fields(bound_min, bound_max, resolution, query_func, S=128):
function extract_geometry (line 473) | def extract_geometry(bound_min, bound_max, resolution, threshold, query_...
class PSNRMeter (line 491) | class PSNRMeter:
method __init__ (line 492) | def __init__(self):
method clear (line 497) | def clear(self):
method prepare_inputs (line 502) | def prepare_inputs(self, *inputs):
method update (line 511) | def update(self, preds, truths):
method measure (line 522) | def measure(self):
method write (line 525) | def write(self, writer, global_step, prefix=""):
method report (line 528) | def report(self):
class Trainer (line 532) | class Trainer(object):
method __init__ (line 533) | def __init__(
method __del__ (line 672) | def __del__(self):
method log (line 676) | def log(self, *args, **kwargs):
method train (line 685) | def train(self, train_loader, valid_loader, max_epochs):
method train_one_epoch (line 753) | def train_one_epoch(self, loader):
method get_loss (line 941) | def get_loss(self, pred, gt):
method train_step (line 954) | def train_step(self, data):
method evaluate (line 1193) | def evaluate(self, loader, name=None):
method evaluate_one_epoch (line 1198) | def evaluate_one_epoch(self, loader, name=None):
method eval_step (line 1370) | def eval_step(self, data):
method save_checkpoint (line 1405) | def save_checkpoint(self, name=None, full=False, best=False, remove_ol...
method load_teacher_checkpoint (line 1477) | def load_teacher_checkpoint(self):
method load_student_checkpoint (line 1531) | def load_student_checkpoint(self):
method load_checkpoint (line 1589) | def load_checkpoint(self, checkpoint=None, model_only=False):
method test (line 1653) | def test(self, loader, save_path=None, name=None):
method test_step (line 1703) | def test_step(self, data, bg_color=None, perturb=False):
FILE: gridencoder/backend.py
function find_cl_path (line 20) | def find_cl_path():
FILE: gridencoder/grid.py
class _grid_encode (line 20) | class _grid_encode(Function):
method forward (line 23) | def forward(
method backward (line 96) | def backward(ctx, grad):
class GridEncoder (line 142) | class GridEncoder(nn.Module):
method __init__ (line 143) | def __init__(
method reset_parameters (line 200) | def reset_parameters(self):
method __repr__ (line 204) | def __repr__(self):
method forward (line 207) | def forward(self, inputs, bound=1):
FILE: gridencoder/setup.py
function find_cl_path (line 21) | def find_cl_path():
FILE: gridencoder/src/bindings.cpp
function PYBIND11_MODULE (line 5) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: just_train_tea/network.py
class NeRFNetwork (line 12) | class NeRFNetwork(NeRFRenderer):
method __init__ (line 13) | def __init__(
method init_plenoxel_volume (line 184) | def init_plenoxel_volume(self, s=0.1, fea_dim=27 + 1, volume=[128, 128...
method init_one_vm (line 193) | def init_one_vm(self, n_component, resolution, scale=0.1):
method get_sigma_feat (line 216) | def get_sigma_feat(self, x):
method get_color_feat (line 264) | def get_color_feat(self, x):
method compute_plenoxel_fea (line 311) | def compute_plenoxel_fea(self, x):
method forward_nerf_mlp (line 320) | def forward_nerf_mlp(self, x):
method forward (line 331) | def forward(self, x, d):
method density (line 422) | def density(self, x):
method background (line 479) | def background(self, x, d):
method color (line 498) | def color(self, x, d, mask=None, geo_feat=None, **kwargs):
method density_loss (line 532) | def density_loss(self):
method upsample_params (line 544) | def upsample_params(self, mat, vec, resolution):
method upsample_model (line 567) | def upsample_model(self, resolution):
method shrink_model (line 573) | def shrink_model(self):
method get_params (line 628) | def get_params(self, lr, lr2=1e-3):
FILE: just_train_tea/provider.py
function nerf_matrix_to_ngp (line 18) | def nerf_matrix_to_ngp(pose, scale=0.33):
function rand_poses (line 32) | def rand_poses(
class NeRFDataset (line 123) | class NeRFDataset:
method __init__ (line 124) | def __init__(self, opt, device, type="train", downscale=1, n_test=10):
method collate (line 284) | def collate(self, index):
method dataloader (line 316) | def dataloader(self):
FILE: just_train_tea/renderer.py
function sample_pdf (line 14) | def sample_pdf(bins, weights, n_samples, det=False):
function plot_pointcloud (line 53) | def plot_pointcloud(pc, color=None):
class NeRFRenderer (line 65) | class NeRFRenderer(nn.Module):
method __init__ (line 66) | def __init__(
method forward (line 116) | def forward(self, x, d):
method density (line 120) | def density(self, x):
method color (line 123) | def color(self, x, d, mask=None, **kwargs):
method reset_extra_state (line 126) | def reset_extra_state(self):
method run (line 138) | def run(
method run_cuda (line 319) | def run_cuda(
method mark_untrained_grid (line 555) | def mark_untrained_grid(self, poses, intrinsic, S=64):
method update_extra_state (line 641) | def update_extra_state(self, decay=0.95, S=128):
method render (line 770) | def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **k...
FILE: just_train_tea/utils.py
function custom_meshgrid (line 36) | def custom_meshgrid(*args):
function linear_to_srgb (line 45) | def linear_to_srgb(x):
function srgb_to_linear (line 50) | def srgb_to_linear(x):
function compute_ssim (line 54) | def compute_ssim(
function init_lpips (line 138) | def init_lpips(net_name, device):
function rgb_lpips (line 152) | def rgb_lpips(gt, im, net_name):
function get_rays (line 160) | def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
function seed_everything (line 242) | def seed_everything(seed):
function torch_vis_2d (line 252) | def torch_vis_2d(x, renormalize=False):
function extract_fields (line 277) | def extract_fields(bound_min, bound_max, resolution, query_func, S=128):
function extract_geometry (line 308) | def extract_geometry(bound_min, bound_max, resolution, threshold, query_...
class PSNRMeter (line 326) | class PSNRMeter:
method __init__ (line 327) | def __init__(self):
method clear (line 331) | def clear(self):
method prepare_inputs (line 335) | def prepare_inputs(self, *inputs):
method update (line 344) | def update(self, preds, truths):
method measure (line 355) | def measure(self):
method write (line 358) | def write(self, writer, global_step, prefix=""):
method report (line 361) | def report(self):
class Trainer (line 365) | class Trainer(object):
method __init__ (line 366) | def __init__(
method __del__ (line 487) | def __del__(self):
method log (line 491) | def log(self, *args, **kwargs):
method train (line 500) | def train(self, train_loader, valid_loader, max_epochs):
method train_one_epoch (line 543) | def train_one_epoch(self, loader):
method get_loss (line 733) | def get_loss(self, pred, gt):
method train_step (line 746) | def train_step(self, data):
method evaluate (line 848) | def evaluate(self, loader, name=None):
method evaluate_one_epoch (line 853) | def evaluate_one_epoch(self, loader, name=None):
method eval_step (line 1028) | def eval_step(self, data):
method save_checkpoint (line 1063) | def save_checkpoint(self, name=None, full=False, best=False, remove_ol...
method load_teacher_checkpoint (line 1135) | def load_teacher_checkpoint(self):
method load_student_checkpoint (line 1158) | def load_student_checkpoint(self):
method test (line 1187) | def test(self, loader, save_path=None, name=None):
method test_step (line 1237) | def test_step(self, data, bg_color=None, perturb=False):
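The `PSNRMeter` indexed above accumulates image-quality scores from prediction/ground-truth pairs. As a reference for the metric itself (not a transcription of the repository's tensor implementation — the function name below is illustrative), PSNR over pixel values normalized to [0, 1] reduces to a one-liner:

```python
import math

def psnr(preds, truths):
    """PSNR in dB for two equal-length sequences of pixel values in [0, 1]."""
    # mean squared error over the flattened pixels
    mse = sum((p - t) ** 2 for p, t in zip(preds, truths)) / len(preds)
    return -10.0 * math.log10(mse)
```

A uniform error of 0.1 gives an MSE of 0.01 and hence a PSNR of 20 dB; larger errors give strictly lower scores.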
FILE: main_distill_mutual.py
function save_codes_env (line 15) | def save_codes_env(workspace):
function load_from_txt (line 24) | def load_from_txt(opt, except_space=""):
FILE: raymarching/backend.py
function find_cl_path (line 20) | def find_cl_path():
FILE: raymarching/raymarching.py
class _near_far_from_aabb (line 20) | class _near_far_from_aabb(Function):
method forward (line 23) | def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
class _polar_from_ray (line 56) | class _polar_from_ray(Function):
method forward (line 59) | def forward(ctx, rays_o, rays_d, radius):
class _morton3D (line 90) | class _morton3D(Function):
method forward (line 92) | def forward(ctx, coords):
class _morton3D_invert (line 116) | class _morton3D_invert(Function):
method forward (line 118) | def forward(ctx, indices):
class _packbits (line 141) | class _packbits(Function):
method forward (line 144) | def forward(ctx, grid, thresh, bitfield=None):
class _march_rays_train (line 176) | class _march_rays_train(Function):
method forward (line 179) | def forward(
class _composite_rays_train (line 292) | class _composite_rays_train(Function):
method forward (line 295) | def forward(ctx, sigmas, rgbs, deltas, rays):
method backward (line 329) | def backward(ctx, grad_weights_sum, grad_depth, grad_image):
class _march_rays (line 367) | class _march_rays(Function):
method forward (line 370) | def forward(
class _composite_rays (line 457) | class _composite_rays(Function):
method forward (line 460) | def forward(
class _compact_rays (line 505) | class _compact_rays(Function):
method forward (line 508) | def forward(
FILE: raymarching/setup.py
function find_cl_path (line 21) | def find_cl_path():
FILE: raymarching/src/bindings.cpp
function PYBIND11_MODULE (line 5) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: raymarching/src/pcg32.h
type pcg32 (line 44) | struct pcg32 {
function next_uint (line 66) | uint32_t next_uint() {
function next_uint (line 75) | uint32_t next_uint(uint32_t bound) {
function next_float (line 107) | float next_float() {
function next_double (line 125) | double next_double() {
function operator (line 198) | bool operator==(const pcg32 &other) const { return state == other.state ...
function operator (line 201) | bool operator!=(const pcg32 &other) const { return state != other.state ...
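pcg32.h above is a tiny header-only PCG32 generator used by the ray-marching kernels for per-ray jitter. A pure-Python sketch of the core algorithm follows; the `PCG32` class here is an illustrative reimplementation of the standard PCG32 output and seeding procedure (with the usual default seed constants), not a binding to the header:

```python
PCG32_MULT = 6364136223846793005
PCG32_DEFAULT_STATE = 0x853C49E6748FEA9B
PCG32_DEFAULT_STREAM = 0xDA3E39CB94B95BDB
MASK64 = (1 << 64) - 1

class PCG32:
    def __init__(self, initstate=PCG32_DEFAULT_STATE, initseq=PCG32_DEFAULT_STREAM):
        # standard PCG seeding: step the LCG once, mix in the seed, step again
        self.state = 0
        self.inc = ((initseq << 1) | 1) & MASK64
        self.next_uint()
        self.state = (self.state + initstate) & MASK64
        self.next_uint()

    def next_uint(self):
        """Next uniformly distributed 32-bit unsigned integer."""
        old = self.state
        self.state = (old * PCG32_MULT + self.inc) & MASK64
        # output function: xorshift the high bits, then a data-dependent rotation
        xorshifted = (((old >> 18) ^ old) >> 27) & 0xFFFFFFFF
        rot = old >> 59
        return ((xorshifted >> rot) | (xorshifted << ((32 - rot) & 31))) & 0xFFFFFFFF
```

Two generators built from the same `(initstate, initseq)` pair produce identical streams, which is what makes the per-ray noise reproducible across runs.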
FILE: shencoder/backend.py
function find_cl_path (line 20) | def find_cl_path():
FILE: shencoder/setup.py
function find_cl_path (line 21) | def find_cl_path():
FILE: shencoder/sphere_harmonics.py
class _sh_encoder (line 15) | class _sh_encoder(Function):
method forward (line 18) | def forward(ctx, inputs, degree, calc_grad_inputs=False):
method backward (line 48) | def backward(ctx, grad):
class SHEncoder (line 67) | class SHEncoder(nn.Module):
method __init__ (line 68) | def __init__(self, input_dim=3, degree=4):
method __repr__ (line 80) | def __repr__(self):
method forward (line 83) | def forward(self, inputs, size=1):
FILE: shencoder/src/bindings.cpp
function PYBIND11_MODULE (line 5) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: tools/activation.py
class _trunc_exp (line 6) | class _trunc_exp(Function):
method forward (line 9) | def forward(ctx, x):
method backward (line 15) | def backward(ctx, g):
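tools/activation.py wraps a truncated-exponential activation in an autograd `Function`: the forward pass is a plain `exp`, while the backward pass clamps its saved input before exponentiating so mixed-precision training cannot overflow the gradient. A framework-free sketch of the two rules (the clamp bound of 15 is an assumption based on common implementations of this trick, not read from the file):

```python
import math

def trunc_exp_forward(x):
    # forward pass: ordinary exponential
    return math.exp(x)

def trunc_exp_backward(x, grad_out):
    # backward pass: d/dx exp(x) = exp(x), but evaluated with the input
    # clamped so the gradient stays finite under fp16/mixed precision
    return grad_out * math.exp(max(-15.0, min(15.0, x)))
```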
FILE: tools/encoding.py
class FreqEncoder (line 6) | class FreqEncoder(nn.Module):
method __init__ (line 7) | def __init__(
method forward (line 36) | def forward(self, input, **kwargs):
function get_encoder (line 52) | def get_encoder(
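Several `raymarching` entry points indexed above (`_morton3D`, `_morton3D_invert`) reorder the occupancy grid along a Z-order (Morton) curve so that spatially close cells stay close in memory. A pure-Python reference of the standard 30-bit 3D Morton encoding — this is the widely used magic-number bit-spreading construction, given as a sketch of the idea rather than a transcription of the CUDA kernel:

```python
def expand_bits(v):
    """Spread the low 10 bits of v so each lands 3 positions apart."""
    v = (v * 0x00010001) & 0xFF0000FF
    v = (v * 0x00000101) & 0x0F00F00F
    v = (v * 0x00000011) & 0xC30C30C3
    v = (v * 0x00000005) & 0x49249249
    return v

def morton3d(x, y, z):
    """Interleave the bits of (x, y, z), each in [0, 1024), into one index."""
    return expand_bits(x) | (expand_bits(y) << 1) | (expand_bits(z) << 2)
```

Neighbouring cells such as (0,0,0) through (1,1,1) map to the contiguous indices 0..7, which is what makes the ordering cache-friendly when marching rays through the grid.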
About this extraction
This page contains the full source code of the megvii-research/AAAI2023-PVD GitHub repository, extracted and formatted as plain text. The extraction includes 40 files (415.1 KB), approximately 116.2k tokens, and a symbol index with 210 extracted functions, classes, methods, constants, and types. Extracted by GitExtract.