Repository: megvii-research/AAAI2023-PVD Branch: main Commit: e5d0ab174f24 Files: 40 Total size: 415.1 KB Directory structure: gitextract_eepeq680/ ├── LICENSE ├── README.md ├── distill_mutual/ │ ├── network.py │ ├── provider.py │ ├── renderer.py │ └── utils.py ├── gridencoder/ │ ├── __init__.py │ ├── backend.py │ ├── grid.py │ ├── setup.py │ └── src/ │ ├── bindings.cpp │ ├── gridencoder.cu │ └── gridencoder.h ├── just_train_tea/ │ ├── network.py │ ├── provider.py │ ├── renderer.py │ └── utils.py ├── main_distill_mutual.py ├── main_just_train_tea.py ├── raymarching/ │ ├── __init__.py │ ├── backend.py │ ├── raymarching.py │ ├── setup.py │ └── src/ │ ├── bindings.cpp │ ├── pcg32.h │ ├── raymarching.cu │ └── raymarching.h ├── shencoder/ │ ├── __init__.py │ ├── backend.py │ ├── setup.py │ ├── sphere_harmonics.py │ └── src/ │ ├── bindings.cpp │ ├── shencoder.cu │ └── shencoder.h └── tools/ ├── activation.py ├── details.md ├── encoding.py ├── install_extensions.sh ├── requirements.txt └── 中文介绍.md ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ Copyright 2022 Megvii Inc. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================

## One is All: Bridging the Gap Between Neural Radiance Fields Architectures with Progressive Volume Distillation (AAAI Oral)

# :partying_face: ***New*** :partying_face: Code for the more powerful PVD-AL (the follow-up work of PVD) is now provided [here](https://github.com/megvii-research/AAAI2023-PVD/tree/PVD-AL). *(We strongly recommend using [PVD-AL](https://github.com/megvii-research/AAAI2023-PVD/tree/PVD-AL), the follow-up work of PVD, for its better performance.)*

## [Project Page](http://sk-fun.fun/PVD/) | [Paper](https://arxiv.org/abs/2211.15977) | [Datasets](https://drive.google.com/drive/folders/1U06KAEsW53PolLI3U8hWUhzzIH74QGaP?usp=sharing) | [Ckpts](https://drive.google.com/drive/folders/1GGJf-FTmpCJjmEn-AF_S9-HrLRkFe5Ud?usp=sharing) | [Chinese tutorial](https://github.com/megvii-research/AAAI2023-PVD/blob/main/tools/%E4%B8%AD%E6%96%87%E4%BB%8B%E7%BB%8D.md) | [zhihu](https://zhuanlan.zhihu.com/p/605121286)

## Introduction

In this paper, we propose Progressive Volume Distillation (PVD), a systematic distillation method that allows any-to-any conversion between different neural-field architectures, including MLPs (NeRF), sparse tensors (Plenoxels), low-rank tensors (TensoRF), and hash tables (INGP).

## Installation

We recommend using [Anaconda](https://www.anaconda.com/) to set up the environment.
Run the following commands:

*Step 1*: Create a conda environment named 'pvd'.

```
conda create --name pvd python=3.7
conda activate pvd
pip install -r ./tools/requirements.txt
```

*Step 2*: Install the extension modules (drawn from the great project [torch-ngp](https://github.com/ashawkey/torch-ngp), on which we mainly rely).

```
bash ./tools/install_extensions.sh
```

## Datasets & Pretrained-teacher models

You can download the Synthetic-NeRF/LLFF/Tanks&Temples datasets from [google](https://drive.google.com/drive/folders/1U06KAEsW53PolLI3U8hWUhzzIH74QGaP?usp=sharing) or from [baidu](https://pan.baidu.com/s/1ky_TWrbUZG_MpHTBhncAKA?pwd=4h2h), and download the pretrained teacher models from [google](https://drive.google.com/drive/folders/1GGJf-FTmpCJjmEn-AF_S9-HrLRkFe5Ud?usp=sharing) or from [baidu](https://pan.baidu.com/s/1LGLXwLGusX60GpAywLwosg?pwd=34k8). You can also train a teacher model yourself by following the guidance below.

## Train a teacher

```
# train a hash-based (INGP) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type hash --data_type synthetic --workspace ./log/train_teacher/hash_chair

# train a sparse-tensor-based (TensoRF VM-decomposition) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type vm --data_type synthetic --workspace ./log/train_teacher/vm_chair

# train an MLP-based (NeRF) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type mlp --data_type synthetic --workspace ./log/train_teacher/mlp_chair

# train a tensor-based (Plenoxels) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type tensors --data_type synthetic --workspace ./log/train_teacher/tensors_chair
```

## Distill a student

```
# teacher: hash (INGP), student: vm (TensoRF)
python3 main_distill_mutual.py ./data/nerf_synthetic/chair \
    --data_type synthetic \
    --teacher_type hash \
    --ckpt_teacher ./log/train_teacher/hash_chair/checkpoints/XXX.pth \
    --model_type vm \
    --workspace ./log/distill_student/hash2vm/chair

# teacher: MLP (NeRF), student: tensors (Plenoxels)
python3 main_distill_mutual.py ./data/nerf_synthetic/chair \
    --data_type synthetic \
    --teacher_type mlp \
    --ckpt_teacher ./log/train_teacher/mlp_chair/checkpoints/XXX.pth \
    --model_type tensors \
    --workspace ./log/distill_student/mlp2tensors/chair
```

## Evaluation

```
# evaluate a hash teacher
python main_distill_mutual.py ./data/nerf_synthetic/chair --teacher_type hash --ckpt_teacher PATH/TO/CKPT.pth --test_teacher --data_type synthetic --workspace ./log/eval_teacher/hash_chair

# evaluate an mlp student
python main_distill_mutual.py ./data/nerf_synthetic/chair --model_type mlp --ckpt PATH/TO/CKPT.pth --test --data_type synthetic --workspace ./log/eval_student/mlp_chair
```

## More detailed parameter descriptions and running commands

Please refer to [more running description](https://github.com/megvii-research/AAAI2023-PVD/blob/main/tools/details.md) for details on training different types of datasets, parameter tuning, key settings, etc.

## Citation

If you find our code or paper useful, please consider citing:

```
@article{fang2022one,
  title={One is All: Bridging the Gap Between Neural Radiance Fields Architectures with Progressive Volume Distillation},
  author={Fang, Shuangkang and Xu, Weixin and Wang, Heng and Yang, Yi and Wang, Yufeng and Zhou, Shuchang},
  journal={arXiv preprint arXiv:2211.15977},
  year={2022}
}
```

### Acknowledgement

We would like to thank [ingp](https://github.com/NVlabs/instant-ngp), [torch-ngp](https://github.com/ashawkey/torch-ngp), [TensoRF](https://github.com/apchenstu/TensoRF), [Plenoxels](https://github.com/sxyu/svox2), and [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch) for their great frameworks! Also check out [Arch-Net](https://github.com/megvii-research/Arch-Net) for more on general progressive distillation.
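For intuition, the distillation itself runs in stages: during an initial block of iterations the student only aligns its intermediate sigma/color features with the teacher's, and the rendered outputs are matched afterwards (see `stage_iters["stage1"]` and `feature_sigma_color` in `distill_mutual/network.py`). A toy sketch of such a two-stage loss schedule follows; this is illustrative only, not the repository's API, and the names `distill_loss`, `mse`, and `stage1_iters` are hypothetical:

```python
def mse(a, b):
    """Mean squared error between two equal-length flat lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def distill_loss(step, stage1_iters, tea_feat, stu_feat, tea_rgb=None, stu_rgb=None):
    """Two-stage distillation schedule (hypothetical sketch).

    Stage 1 (step < stage1_iters): align only the intermediate
    sigma/color features of teacher and student.
    Stage 2: additionally match the rendered RGB outputs.
    """
    feature_loss = mse(tea_feat, stu_feat)
    if step < stage1_iters:  # stage 1: feature alignment only
        return feature_loss
    return feature_loss + mse(tea_rgb, stu_rgb)  # stage 2: add output matching

# Toy illustration with flattened feature and color vectors:
tea_f, stu_f = [1.0] * 16, [0.0] * 16
rgb = [0.5, 0.5, 0.5]
print(distill_loss(0, 100, tea_f, stu_f))              # stage 1 -> 1.0
print(distill_loss(200, 100, tea_f, stu_f, rgb, rgb))  # stage 2 -> 1.0 (rgb term is 0)
```

In the real training loop these terms operate on batched tensors, and the stage boundary comes from the command-line configuration driven by the commands above.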
================================================ FILE: distill_mutual/network.py ================================================ import torch from time import time import torch.nn as nn import torch.nn.functional as F from tools.encoding import get_encoder from tools.activation import trunc_exp from .renderer import NeRFRenderer import raymarching class NeRFNetwork(NeRFRenderer): def __init__( self, encoding="hashgrid", encoding_dir="sphere_harmonics", encoding_bg="hashgrid", num_layers=2, hidden_dim=64, geo_feat_dim=15, num_layers_color=3, hidden_dim_color=64, num_layers_bg=2, hidden_dim_bg=64, bound=1, model_type="hash", args=None, is_teacher=False, **kwargs, ): super().__init__(bound, **kwargs) # sigma network assert model_type in ["hash", "mlp", "vm", "tensors"] self.is_teacher = is_teacher self.num_layers = num_layers self.hidden_dim = hidden_dim self.geo_feat_dim = geo_feat_dim self.args = args self.opt = args self.model_type = model_type self.plenoxel_degree = args.plenoxel_degree self.plenoxel_res = eval(args.plenoxel_res) assert len(self.plenoxel_res) == 3 self.encoder, self.in_dim = get_encoder( encoding, desired_resolution=2048 * bound, num_levels=14, ) if "hash" != self.model_type: self.encoder = None if self.model_type == "mlp": self.encoder_nerf_pe, self.in_dim_nerf = get_encoder( encoding="frequency", multires=self.args.PE ) self.skips = self.args.skip self.nerf_layer_num = self.args.nerf_layer_num W = self.args.nerf_layer_wide self.nerf_mlp = [nn.Linear(self.in_dim_nerf, W)] for i in range(self.nerf_layer_num - 2): if i != self.skips: self.nerf_mlp.append(nn.Linear(W, W)) else: self.nerf_mlp.append(nn.Linear(W + self.in_dim_nerf, W)) self.nerf_mlp.append(nn.Linear(W, self.in_dim)) self.nerf_mlp = nn.ModuleList(self.nerf_mlp) elif self.model_type == "vm": self.sigma_rank = [16] * 3 self.color_rank = [48] * 3 self.color_feat_dim = 15 # geo_feat_dim self.mat_ids = [[0, 1], [0, 2], [1, 2]] self.vec_ids = [2, 1, 0] self.resolution = 
[self.opt.resolution0] * 3 # mat: paralist[1,16,res0,res0] repeat 3 vec: paralist[1,16,res0,1] repeat 3; repeat3 because decompose 3D grid [H, W, D] to three 2D mat [H, W], [H,D], [W, D] or decompose to three 1D vec [H], [W], [D] self.sigma_mat, self.sigma_vec = self.init_one_vm( self.sigma_rank, self.resolution ) # mat: paralist[1,48,res0,res0] repeat 3 vec: paralist[1,48,res0,1] repeat 3 self.color_mat, self.color_vec = self.init_one_vm( self.color_rank, self.resolution ) # Linear(in_features=144, out_features=27) self.basis_mat = nn.Linear( sum(self.color_rank), self.color_feat_dim, bias=False ) elif self.model_type == "tensors": self.init_plenoxel_volume( s=0.02, fea_dim=self.plenoxel_degree ** 2 * 3 + 1, volume=self.plenoxel_res, ) elif self.model_type == "hash": pass else: raise ValueError(f"error model_type:{self.model_type}") if self.model_type != "vm" and self.model_type != "tensors": sigma_net = [] for l in range(num_layers): if l == 0: in_dim = self.in_dim else: in_dim = hidden_dim if l == num_layers - 1: out_dim = ( 1 + self.geo_feat_dim ) # 1 sigma + 15 SH features for color else: out_dim = hidden_dim sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) self.sigma_net = nn.ModuleList(sigma_net) # color network self.num_layers_color = num_layers_color self.hidden_dim_color = hidden_dim_color # self.encoder_dir, self.in_dim_dir = get_encoder(encoding=encoding_dir) if self.model_type == "tensors": self.encoder_dir, self.in_dim_dir = get_encoder( encoding="sphere_harmonics", degree=self.plenoxel_degree, ) else: self.encoder_dir, self.in_dim_dir = get_encoder( encoding=encoding_dir, input_dim=3, multires=2 ) if self.model_type != "tensors": color_net = [] for l in range(num_layers_color): if l == 0: in_dim = self.in_dim_dir + self.geo_feat_dim else: in_dim = hidden_dim if l == num_layers_color - 1: out_dim = 3 # 3 rgb else: out_dim = hidden_dim color_net.append(nn.Linear(in_dim, out_dim, bias=False)) self.color_net = nn.ModuleList(color_net) # 
background network if self.bg_radius > 0: self.num_layers_bg = num_layers_bg self.hidden_dim_bg = hidden_dim_bg self.encoder_bg, self.in_dim_bg = get_encoder( encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048, ) # much smaller hashgrid bg_net = [] for l in range(num_layers_bg): if l == 0: in_dim = self.in_dim_bg + self.in_dim_dir else: in_dim = hidden_dim_bg if l == num_layers_bg - 1: out_dim = 3 # 3 rgb else: out_dim = hidden_dim_bg bg_net.append(nn.Linear(in_dim, out_dim, bias=False)) self.bg_net = nn.ModuleList(bg_net) else: self.bg_net = None def init_plenoxel_volume(self, s=0.1, fea_dim=27 + 1, volume=[128, 128, 128]): tensor = [] tensor.append( torch.nn.Parameter( s * torch.randn((1, fea_dim, volume[0], volume[1], volume[2])) ) ) self.tensor_volume = torch.nn.ParameterList(tensor).cuda() def init_one_vm(self, n_component, resolution, scale=0.1): # self.mat_ids = [[0, 1], [0, 2], [1, 2]] self.vec_ids = [2, 1, 0] mat, vec = [], [] for i in range(len(self.vec_ids)): vec_id = self.vec_ids[i] mat_id_0, mat_id_1 = self.mat_ids[i] mat.append( nn.Parameter( scale * torch.randn( (1, n_component[i], resolution[mat_id_1], resolution[mat_id_0]) ) ) ) # [1, R, H, W] vec.append( nn.Parameter( scale * torch.randn((1, n_component[i], resolution[vec_id], 1)) ) ) # [1, R, D, 1] (fake 2d to use grid_sample) return nn.ParameterList(mat), nn.ParameterList(vec) def get_sigma_feat(self, x): # x: [N, 3], in [-1, 1] (outliers will be treated as zero due to grid_sample padding mode) # self.mat_ids = [[0, 1], [0, 2], [1, 2]] self.vec_ids = [2, 1, 0] N = x.shape[0] # plane + line basis mat_coord = ( torch.stack( ( x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]], ) ) .detach() .view(3, -1, 1, 2) ) # [3, N, 1, 2] vec_coord = torch.stack( (x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]]) ) vec_coord = ( torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1) .detach() .view(3, -1, 1, 2) ) # [3, N, 
1, 2], fake 2d coord sigma_feat = torch.zeros( [ N, ], device=x.device, ) for i in range(len(self.sigma_mat)): mat_feat = F.grid_sample( self.sigma_mat[i], mat_coord[[i]], align_corners=True ).view( -1, N ) # [1, R, N, 1] --> [R, N] vec_feat = F.grid_sample( self.sigma_vec[i], vec_coord[[i]], align_corners=True ).view( -1, N ) # [R, N] sigma_feat = sigma_feat + torch.sum(mat_feat * vec_feat, dim=0) return sigma_feat def get_color_feat(self, x): # x: [N, 3], in [-1, 1] N = x.shape[0] # plane + line basis mat_coord = ( torch.stack( ( x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]], ) ) .detach() .view(3, -1, 1, 2) ) # [3, N, 1, 2] vec_coord = torch.stack( (x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]]) ) vec_coord = ( torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1) .detach() .view(3, -1, 1, 2) ) # [3, N, 1, 2], fake 2d coord mat_feat, vec_feat = [], [] for i in range(len(self.color_mat)): mat_feat.append( F.grid_sample( self.color_mat[i], mat_coord[[i]], align_corners=True ).view(-1, N) ) # [1, R, N, 1] --> [R, N] vec_feat.append( F.grid_sample( self.color_vec[i], vec_coord[[i]], align_corners=True ).view(-1, N) ) # [R, N] mat_feat = torch.cat(mat_feat, dim=0) # [3 * R, N] vec_feat = torch.cat(vec_feat, dim=0) # [3 * R, N] color_feat = self.basis_mat( (mat_feat * vec_feat).T ) # [N, 3R] --> [N, color_feat_dim] return color_feat def compute_plenoxel_fea(self, x): composed = self.tensor_volume[0] if self.args.enable_edit_plenoxel and self.is_teacher: composed[ :, 0, :, 160:, :128 ] = -100 # This will erase the bucket in the lego scene for resolution 256 composed = ( F.grid_sample(composed, x.view(1, 1, -1, 1, 3), align_corners=True) .view(-1, x.shape[0]) .permute(1, 0) ) return composed # [N, fea_dim] def forward_nerf_mlp(self, x): x = self.encoder_nerf_pe(x) in_pts = x for i in range(len(self.nerf_mlp)): x = self.nerf_mlp[i](x) if i != len(self.nerf_mlp) - 1: x = F.relu(x, inplace=True) if i == 
self.skips: x = torch.cat([in_pts, x], -1) return x def forward(self, x, d): # x: [N, 3], in [-bound, bound] d: [N, 3], nomalized in [-1, 1] # sigma if self.model_type == "hash": x = self.encoder( x, bound=self.bound ) # out_x[N, 28=num_levels * fea_per_level] elif self.model_type == "mlp": x = self.forward_nerf_mlp(x) # 28 elif self.model_type == "vm": x = ( 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 ) # x:[N, 3] sigma_feat = self.get_sigma_feat(x) # sigma_feat:[N] color_feat = self.get_color_feat(x) # color_feat:[N, 15] if self.opt.enable_edit_plenoxel: sigma_feat = torch.clamp(sigma_feat, -100, self.args.sigma_clip_max) else: sigma_feat = torch.clamp( sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max ) color_feat = torch.clamp( color_feat, self.args.sigma_clip_min, self.args.sigma_clip_max ) self.feature_sigma_color = torch.cat( [sigma_feat.unsqueeze(-1), color_feat], dim=-1 ) if ( self.training and self.args.global_step < self.args.stage_iters["stage1"] ): return None, None self.sigma_l = sigma_feat sigma = trunc_exp(sigma_feat) # sigma:[N] enc_d = self.encoder_dir(d) # enc_d:[N, 16] h = torch.cat([enc_d, color_feat], dim=-1) # h:[N, 16+15] for l in range(self.num_layers_color): h = self.color_net[l](h) if l != self.num_layers_color - 1: h = F.relu(h, inplace=True) color = torch.sigmoid(h) self.color_l = color return sigma, color elif self.model_type == "tensors": x = ( 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 ) # x:[N, 3] x = self.compute_plenoxel_fea(x) h = x if self.opt.enable_edit_plenoxel: sigma = torch.clamp(h[..., 0], -100, self.args.sigma_clip_max) else: sigma = torch.clamp( h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max ) self.sigma_l = sigma sigma = trunc_exp(sigma) self.sigma = sigma sh = h[..., 1:].view( -1, 3, self.plenoxel_degree ** 2 ) # [N, 3, 9] ## .permute(1, 0, 2) # [B, 27]-->[9, B, 3] enc_d = self.encoder_dir(d).unsqueeze(1) # [N, 
9]-->[N,1,9] color = (sh * enc_d).sum(-1) # [N, 3] color = torch.sigmoid(color) self.feature_sigma_color = None self.color_l = color return sigma, color else: raise ValueError(f"illegal model_type:{self.model_type}") h = x for l in range(self.num_layers): h = self.sigma_net[l](h) if l != self.num_layers - 1: h = F.relu(h, inplace=True) h[..., 0] = torch.clamp( h[..., 0].clone(), self.args.sigma_clip_min, self.args.sigma_clip_max ) self.feature_sigma_color = h if self.training and self.args.global_step < self.args.stage_iters["stage1"]: return None, None self.sigma_l = h[..., 0] sigma = trunc_exp(h[..., 0]) # sigma: [n] geo_feat = h[..., 1:] # geo_feat: [n, 15] d = self.encoder_dir(d) # d: [n, 16] h = torch.cat([d, geo_feat], dim=-1) # h: [n, 15+16] for l in range(self.num_layers_color): h = self.color_net[l](h) if l != self.num_layers_color - 1: h = F.relu(h, inplace=True) color = torch.sigmoid(h) self.color_l = color return sigma, color def density(self, x): # x: [N, 3], in [-bound, bound] if self.model_type == "hash": x = self.encoder( x, bound=self.bound ) # out_x[N, 32=num_levels * fea_per_level] elif self.model_type == "mlp": x = self.forward_nerf_mlp(x) elif self.model_type == "vm": x = ( 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 ) sigma_feat = self.get_sigma_feat(x) sigma_feat = torch.clamp( sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max ) sigma = trunc_exp(sigma_feat) return {"sigma": sigma} elif self.model_type == "tensors": x = ( 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 ) # x:[N, 3] x = self.compute_plenoxel_fea(x) h = x # clamp before trunc_exp so the density activation sees the clipped value sigma = trunc_exp( torch.clamp( h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max ) ) return {"sigma": sigma} else: raise ValueError(f"illegal model_type:{self.model_type}") h = x for l in range(self.num_layers): h = 
self.sigma_net[l](h) if l != self.num_layers - 1: h = F.relu(h, inplace=True) h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max) sigma = trunc_exp(h[..., 0]) geo_feat = h[..., 1:] return { "sigma": sigma, "geo_feat": geo_feat, } def background(self, x, d): assert 1 == 2 # x: [N, 2], in [-1, 1] h = self.encoder_bg(x) # [N, C] d = self.encoder_dir(d) h = torch.cat([d, h], dim=-1) for l in range(self.num_layers_bg): h = self.bg_net[l](h) if l != self.num_layers_bg - 1: h = F.relu(h, inplace=True) # sigmoid activation for rgb rgbs = torch.sigmoid(h) return rgbs # allow masked inference def color(self, x, d, mask=None, geo_feat=None, **kwargs): assert 1 == 2 # x: [N, 3] in [-bound, bound] # mask: [N,], bool, indicates where we actually needs to compute rgb. if mask is not None: rgbs = torch.zeros( mask.shape[0], 3, dtype=x.dtype, device=x.device ) # [N, 3] # in case of empty mask if not mask.any(): return rgbs x = x[mask] d = d[mask] geo_feat = geo_feat[mask] d = self.encoder_dir(d) h = torch.cat([d, geo_feat], dim=-1) for l in range(self.num_layers_color): h = self.color_net[l](h) if l != self.num_layers_color - 1: h = F.relu(h, inplace=True) # sigmoid activation for rgb h = torch.sigmoid(h) if mask is not None: rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 else: rgbs = h return rgbs # L1 penalty for loss def density_loss(self): loss = 0 for i in range(len(self.sigma_mat)): loss = ( loss + torch.mean(torch.abs(self.sigma_mat[i])) + torch.mean(torch.abs(self.sigma_vec[i])) ) return loss # upsample utils @torch.no_grad() def upsample_params(self, mat, vec, resolution): for i in range(len(self.vec_ids)): vec_id = self.vec_ids[i] mat_id_0, mat_id_1 = self.mat_ids[i] mat[i] = nn.Parameter( F.interpolate( mat[i].data, size=(resolution[mat_id_1], resolution[mat_id_0]), mode="bilinear", align_corners=True, ) ) vec[i] = nn.Parameter( F.interpolate( vec[i].data, size=(resolution[vec_id], 1), mode="bilinear", align_corners=True, ) ) @torch.no_grad() def 
upsample_model(self, resolution): self.upsample_params(self.sigma_mat, self.sigma_vec, resolution) self.upsample_params(self.color_mat, self.color_vec, resolution) self.resolution = resolution @torch.no_grad() def shrink_model(self): # shrink aabb_train and the model so it only represents the space inside aabb_train. half_grid_size = self.bound / self.grid_size thresh = min(self.density_thresh, self.mean_density) # get new aabb from the coarsest density grid (TODO: from the finest that covers current aabb?) valid_grid = self.density_grid[self.cascade - 1] > thresh # [N] valid_pos = raymarching.morton3D_invert( torch.nonzero(valid_grid) ) # [Nz] --> [Nz, 3], in [0, H - 1] # plot_pointcloud(valid_pos.detach().cpu().numpy()) # lots of noisy outliers in hashnerf... valid_pos = (2 * valid_pos / (self.grid_size - 1) - 1) * ( self.bound - half_grid_size ) # [Nz, 3], in [-b+hgs, b-hgs] min_pos = valid_pos.amin(0) - half_grid_size # [3] max_pos = valid_pos.amax(0) + half_grid_size # [3] # shrink model reso = torch.LongTensor(self.resolution).to(self.aabb_train.device) units = (self.aabb_train[3:] - self.aabb_train[:3]) / reso tl = (min_pos - self.aabb_train[:3]) / units br = (max_pos - self.aabb_train[:3]) / units tl = torch.round(tl).long().clamp(min=0) br = torch.minimum(torch.round(br).long(), reso) for i in range(len(self.vec_ids)): vec_id = self.vec_ids[i] mat_id_0, mat_id_1 = self.mat_ids[i] self.sigma_vec[i] = nn.Parameter( self.sigma_vec[i].data[..., tl[vec_id] : br[vec_id], :] ) self.color_vec[i] = nn.Parameter( self.color_vec[i].data[..., tl[vec_id] : br[vec_id], :] ) self.sigma_mat[i] = nn.Parameter( self.sigma_mat[i].data[ ..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0] ] ) self.color_mat[i] = nn.Parameter( self.color_mat[i].data[ ..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0] ] ) self.aabb_train = torch.cat([min_pos, max_pos], dim=0) # [6] print( f"[INFO] shrink slice: {tl.cpu().numpy().tolist()} - {br.cpu().numpy().tolist()}" ) 
print(f"[INFO] new aabb: {self.aabb_train.cpu().numpy().tolist()}") # optimizer utils def get_params(self, lr, lr2=1e-3): if self.model_type == "hash": params = [ {"params": self.encoder.parameters(), "lr": lr}, {"params": self.sigma_net.parameters(), "lr": lr}, {"params": self.encoder_dir.parameters(), "lr": lr}, {"params": self.color_net.parameters(), "lr": lr}, ] elif self.model_type == "mlp": params = [ {"params": self.sigma_net.parameters(), "lr": lr}, {"params": self.encoder_dir.parameters(), "lr": lr}, {"params": self.color_net.parameters(), "lr": lr}, {"params": self.nerf_mlp.parameters(), "lr": lr}, ] elif self.model_type == "vm": params = [ {"params": self.color_net.parameters(), "lr": lr2}, {"params": self.sigma_mat, "lr": lr}, {"params": self.sigma_vec, "lr": lr}, {"params": self.color_mat, "lr": lr}, {"params": self.color_vec, "lr": lr}, {"params": self.basis_mat.parameters(), "lr": lr2}, ] elif self.model_type == "tensors": params = [ {"params": self.tensor_volume.parameters(), "lr": lr}, {"params": self.encoder_dir.parameters(), "lr": lr}, ] else: raise ValueError(f"illegal model_type:{self.model_type}") if self.bg_radius > 0: params.append({"params": self.encoder_bg.parameters(), "lr": lr}) params.append({"params": self.bg_net.parameters(), "lr": lr}) return params ================================================ FILE: distill_mutual/provider.py ================================================ import os import cv2 import glob import json import tqdm import numpy as np from scipy.spatial.transform import Slerp, Rotation import trimesh import torch from torch.utils.data import DataLoader from .utils import get_rays, srgb_to_linear # ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50 def nerf_matrix_to_ngp(pose, scale=0.33): # for the fox dataset, 0.33 scales camera radius to ~ 2 new_pose = np.array( [ [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * 
scale], [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale], [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale], [0, 0, 0, 1], ], dtype=np.float32, ) return new_pose def rand_poses( size, device, radius=1, theta_range=[np.pi / 3, 2 * np.pi / 3], phi_range=[0, 2 * np.pi], ): """generate random poses from an orbit camera Args: size: batch size of generated poses. device: where to allocate the output. radius: camera radius theta_range: [min, max], should be in [0, \pi] phi_range: [min, max], should be in [0, 2\pi] Return: poses: [size, 4, 4] """ def normalize(vectors): return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) thetas = ( torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0] ) phis = ( torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0] ) centers = torch.stack( [ radius * torch.sin(thetas) * torch.sin(phis), radius * torch.cos(thetas), radius * torch.sin(thetas) * torch.cos(phis), ], dim=-1, ) # [B, 3] # lookat forward_vector = -normalize(centers) up_vector = ( torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1) ) # confused at the coordinate system... 
right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1)) up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1)) poses = ( torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1) ) poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) poses[:, :3, 3] = centers return poses def normalize(vectors): return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) interval_nums = torch.tensor( [i * 1 / (size - 1) for i in range(size)], dtype=torch.float32, device=device ) thetas = interval_nums * (theta_range[1] - theta_range[0]) + theta_range[0] phis = interval_nums * (phi_range[1] - phi_range[0]) + phi_range[0] centers = torch.stack( [ radius * torch.sin(thetas) * torch.sin(phis), radius * torch.cos(thetas), radius * torch.sin(thetas) * torch.cos(phis), ], dim=-1, ) # [B, 3] # lookat forward_vector = -normalize(centers) up_vector = ( torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1) ) # confused at the coordinate system... right_vector = normalize( torch.cross(forward_vector, up_vector, dim=-1) ) # cross product up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1)) poses = ( torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1) ) poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) poses[:, :3, 3] = centers return poses class NeRFDataset: def __init__(self, opt, device, type="train", downscale=1, n_test=10): super().__init__() self.opt = opt self.args = opt self.device = device self.type = type # train, val, test self.downscale = downscale self.root_path = opt.path self.mode = opt.mode # only support blender self.preload = opt.preload # preload data into GPU self.scale = ( opt.scale ) # camera radius scale to make sure camera are inside the bounding box. self.bound = ( opt.bound ) # bounding box half length, also used as the radius to random sample poses. 
self.fp16 = opt.fp16 # if preload, load into fp16. self.training = self.type in ["train", "all", "trainval"] self.num_rays = self.opt.num_rays if self.training else -1 if self.mode == "blender": if type == "all": transform_paths = glob.glob(os.path.join(self.root_path, "*.json")) transform = None for transform_path in transform_paths: with open(transform_path, "r") as f: tmp_transform = json.load(f) if transform is None: transform = tmp_transform else: transform["frames"].extend(tmp_transform["frames"]) # load train and val split elif type == "trainval": with open( os.path.join(self.root_path, f"transforms_train.json"), "r" ) as f: transform = json.load(f) with open( os.path.join(self.root_path, f"transforms_val.json"), "r" ) as f: transform_val = json.load(f) transform["frames"].extend(transform_val["frames"]) # only load one specified split else: with open( os.path.join(self.root_path, f"transforms_{type}.json"), "r" ) as f: transform = json.load(f) else: raise NotImplementedError(f"unknown dataset mode: {self.mode}") # load image size if "h" in transform and "w" in transform: self.H = int(transform["h"]) // downscale self.W = int(transform["w"]) // downscale else: # we have to actually read an image to get H and W later. self.H = self.W = None # read images frames = transform["frames"] if True: self.poses = [] self.images = [] for f in tqdm.tqdm(frames, desc=f"Loading {type} data:"): f_path = os.path.join(self.root_path, f["file_path"]) if ( self.mode == "blender" and f_path[-4:].lower() != ".png" and f_path[-4:].lower() != ".jpg" ): f_path += ".png" # so silly... if not os.path.exists(f_path): continue pose = np.array(f["transform_matrix"], dtype=np.float32) # [4, 4] pose = nerf_matrix_to_ngp(pose, scale=self.scale) image = cv2.imread( f_path, cv2.IMREAD_UNCHANGED ) # [H, W, 3] o [H, W, 4] if self.H is None or self.W is None: self.H = image.shape[0] // downscale self.W = image.shape[1] // downscale # add support for the alpha channel as a mask. 
if image.shape[-1] == 3: image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) else: image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) if image.shape[0] != self.H or image.shape[1] != self.W: image = cv2.resize( image, (self.W, self.H), interpolation=cv2.INTER_AREA ) image = image.astype(np.float32) / 255 # [H, W, 3/4] self.poses.append(pose) self.images.append(image) self.poses = torch.from_numpy(np.stack(self.poses, axis=0)) # [N, 4, 4] if self.images is not None: self.images = torch.from_numpy( np.stack(self.images, axis=0) ) # [N, H, W, C] self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item() if self.training and self.opt.error_map: self.error_map = torch.ones( [self.images.shape[0], 128 * 128], dtype=torch.float ) # [B, 128 * 128], flattened for easy indexing, fixed resolution... else: self.error_map = None if self.preload: self.poses = self.poses.to(self.device) if self.images is not None: if self.fp16 and self.opt.color_space != "linear": dtype = torch.half else: dtype = torch.float self.images = self.images.to(dtype).to(self.device) if self.error_map is not None: self.error_map = self.error_map.to(self.device) # load intrinsics if "fl_x" in transform or "fl_y" in transform: fl_x = ( transform["fl_x"] if "fl_x" in transform else transform["fl_y"] ) / downscale fl_y = ( transform["fl_y"] if "fl_y" in transform else transform["fl_x"] ) / downscale elif "camera_angle_x" in transform or "camera_angle_y" in transform: # blender, assert in radians. already downscaled since we use H/W fl_x = ( self.W / (2 * np.tan(transform["camera_angle_x"] / 2)) if "camera_angle_x" in transform else None ) fl_y = ( self.H / (2 * np.tan(transform["camera_angle_y"] / 2)) if "camera_angle_y" in transform else None ) if fl_x is None: fl_x = fl_y if fl_y is None: fl_y = fl_x else: raise RuntimeError( "Failed to load focal length, please check the transforms.json!" 
            )

        # principal point: cx is along the width axis, cy along the height axis
        cx = (transform["cx"] / downscale) if "cx" in transform else (self.W / 2)
        cy = (transform["cy"] / downscale) if "cy" in transform else (self.H / 2)

        self.intrinsics = np.array([fl_x, fl_y, cx, cy])

    def collate(self, index):
        B = len(index)  # a list of length 1

        poses = self.poses[index].to(self.device)  # [B, 4, 4]
        error_map = None if self.error_map is None else self.error_map[index]

        rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, error_map)

        results = {
            "H": self.H,
            "W": self.W,
            "rays_o": rays["rays_o"],
            "rays_d": rays["rays_d"],
        }

        if self.images is not None:
            images = self.images[index].to(self.device)  # [B, H, W, 3/4]
            if self.training:
                C = images.shape[-1]
                images = torch.gather(
                    images.view(B, -1, C), 1, torch.stack(C * [rays["inds"]], -1)
                )  # [B, N, 3/4]
            results["images"] = images

        # need inds to update error_map
        if error_map is not None:
            results["index"] = index
            results["inds_coarse"] = rays["inds_coarse"]

        return results

    def dataloader(self):
        size = len(self.poses)
        loader = DataLoader(
            list(range(size)),
            batch_size=1,
            collate_fn=self.collate,
            shuffle=self.training,
            num_workers=0,
        )
        loader._data = self
        return loader


================================================ FILE: distill_mutual/renderer.py ================================================

import math
import trimesh
import numpy as np
from time import time

import torch
import torch.nn as nn
import torch.nn.functional as F

import raymarching
from .utils import custom_meshgrid
from IPython import embed


def sample_pdf(bins, weights, n_samples, det=False):
    # This implementation is from NeRF
    # bins: [B, T], old_z_vals
    # weights: [B, T - 1], bin weights.
# return: [B, n_samples], new_z_vals # Get pdf weights = weights + 1e-5 # prevent nans pdf = weights / torch.sum(weights, -1, keepdim=True) cdf = torch.cumsum(pdf, -1) cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # Take uniform samples if det: u = torch.linspace( 0.0 + 0.5 / n_samples, 1.0 - 0.5 / n_samples, steps=n_samples ).to(weights.device) u = u.expand(list(cdf.shape[:-1]) + [n_samples]) else: u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device) # Invert CDF u = u.contiguous() inds = torch.searchsorted(cdf, u, right=True) below = torch.max(torch.zeros_like(inds - 1), inds - 1) above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) inds_g = torch.stack([below, above], -1) # (B, n_samples, 2) matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) denom = cdf_g[..., 1] - cdf_g[..., 0] denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) t = (u - cdf_g[..., 0]) / denom samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) return samples def plot_pointcloud(pc, color=None): # pc: [N, 3] # color: [N, 3/4] print("[visualize points]", pc.shape, pc.dtype, pc.min(0), pc.max(0)) pc = trimesh.PointCloud(pc, color) # axis axes = trimesh.creation.axis(axis_length=4) # sphere sphere = trimesh.creation.icosphere(radius=1) trimesh.Scene([pc, axes, sphere]).show() class NeRFRenderer(nn.Module): def __init__( self, bound=1, cuda_ray=False, density_scale=1, # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance. 
min_near=0.2, density_thresh=0.01, bg_radius=-1, grid_size=128, ): super().__init__() print("\n---------------", grid_size, "--------------\n") self.bound = bound self.cascade = 1 + math.ceil(math.log2(bound)) self.grid_size = grid_size self.density_scale = density_scale self.min_near = min_near self.density_thresh = density_thresh self.bg_radius = bg_radius # radius of the background sphere. # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax) # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing. aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound]) aabb_infer = aabb_train.clone() self.register_buffer("aabb_train", aabb_train) self.register_buffer("aabb_infer", aabb_infer) # extra state for cuda raymarching self.cuda_ray = cuda_ray if cuda_ray: # density grid density_grid = torch.zeros( [self.cascade, self.grid_size ** 3] ) # [CAS, H * H * H] density_bitfield = torch.zeros( self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8 ) # [CAS * H * H * H // 8] self.register_buffer("density_grid", density_grid) self.register_buffer("density_bitfield", density_bitfield) self.mean_density = 0 self.iter_density = 0 # step counter step_counter = torch.zeros( 16, 2, dtype=torch.int32 ) # 16 is hardcoded for averaging... self.register_buffer("step_counter", step_counter) self.mean_count = 0 self.local_step = 0 def forward(self, x, d): raise NotImplementedError() # separated density and color query (can accelerate non-cuda-ray mode.) 
def density(self, x): raise NotImplementedError() def color(self, x, d, mask=None, **kwargs): raise NotImplementedError() def reset_extra_state(self): if not self.cuda_ray: return # density grid self.density_grid.zero_() self.mean_density = 0 self.iter_density = 0 # step counter self.step_counter.zero_() self.mean_count = 0 self.local_step = 0 def run( self, rays_o, rays_d, num_steps=128, upsample_steps=128, bg_color=None, perturb=False, **kwargs ): # rays_o, rays_d: [B, N, 3], assumes B == 1 # bg_color: [3] in range [0, 1] # return: image: [B, N, 3], depth: [B, N] prefix = rays_o.shape[:-1] rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) N = rays_o.shape[0] # N = B * N, in fact device = rays_o.device # choose aabb aabb = self.aabb_train if self.training else self.aabb_infer # sample steps nears, fars = raymarching.near_far_from_aabb( rays_o, rays_d, aabb, self.min_near ) nears.unsqueeze_(-1) fars.unsqueeze_(-1) # print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}') z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze( 0 ) # [1, T] z_vals = z_vals.expand((N, num_steps)) # [N, T] z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars] # perturb z_vals sample_dist = (fars - nears) / num_steps if perturb: z_vals = ( z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist ) # z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs. # generate xyzs xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze( -1 ) # [N, 1, 3] * [N, T, 1] -> [N, T, 3] xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip. 
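The coarse sampling above (a `linspace` over [0, 1] rescaled to [near, far], then jittered by up to half a bin when `perturb` is set) can be sketched standalone in NumPy; the function name and the scalar near/far arguments are illustrative, not part of this codebase:

```python
import numpy as np

def stratified_z_vals(near, far, num_steps, perturb=False, rng=None):
    """Evenly spaced depth samples in [near, far]; with perturb=True each
    sample is jittered by up to half a bin, as in the training branch."""
    if rng is None:
        rng = np.random.default_rng(0)
    z = np.linspace(0.0, 1.0, num_steps)  # [T], in [0, 1]
    z = near + (far - near) * z           # rescale to [near, far]
    if perturb:
        sample_dist = (far - near) / num_steps
        z = z + (rng.random(num_steps) - 0.5) * sample_dist
    return z

z = stratified_z_vals(0.2, 4.0, 8)  # deterministic: perturb is off
```

Note that, as in the code above, the jitter can push the first and last samples slightly outside [near, far], which is why the generated points are clamped to the aabb afterwards.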
# plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) # query SDF and RGB density_outputs = self.density(xyzs.reshape(-1, 3)) # sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T] for k, v in density_outputs.items(): density_outputs[k] = v.view(N, num_steps, -1) # upsample z_vals (nerf-like) if upsample_steps > 0: with torch.no_grad(): deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1] deltas = torch.cat( [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1 ) alphas = 1 - torch.exp( -deltas * self.density_scale * density_outputs["sigma"].squeeze(-1) ) # [N, T] alphas_shifted = torch.cat( [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1 ) # [N, T+1] weights = ( alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] ) # [N, T] # sample new z_vals z_vals_mid = z_vals[..., :-1] + 0.5 * deltas[..., :-1] # [N, T-1] new_z_vals = sample_pdf( z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training ).detach() # [N, t] new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze( -2 ) * new_z_vals.unsqueeze( -1 ) # [N, 1, 3] * [N, t, 1] -> [N, t, 3] new_xyzs = torch.min( torch.max(new_xyzs, aabb[:3]), aabb[3:] ) # a manual clip. 
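For reference, the inverse-CDF resampling that `sample_pdf` performs per ray can be sketched for a single ray in NumPy (this mirrors the deterministic `det=True` path; `sample_pdf_np` is an illustrative name, not a function from this repo):

```python
import numpy as np

def sample_pdf_np(bins, weights, n_samples):
    """Inverse-CDF sampling over one ray: bins has T edges, weights has
    T - 1 bin masses; samples concentrate where the weights are large."""
    weights = weights + 1e-5                       # prevent all-zero rows
    pdf = weights / weights.sum()
    cdf = np.concatenate([[0.0], np.cumsum(pdf)])  # [T], cdf[0] = 0, cdf[-1] = 1
    u = (np.arange(n_samples) + 0.5) / n_samples   # evenly spaced in (0, 1)
    inds = np.searchsorted(cdf, u, side="right")
    below = np.clip(inds - 1, 0, len(cdf) - 1)
    above = np.clip(inds, 0, len(cdf) - 1)
    denom = cdf[above] - cdf[below]
    denom = np.where(denom < 1e-5, 1.0, denom)     # guard empty bins
    t = (u - cdf[below]) / denom
    return bins[below] + t * (bins[above] - bins[below])

# all the mass sits in the middle bin [1, 2], so every sample lands there
samples = sample_pdf_np(np.array([0.0, 1.0, 2.0, 3.0]), np.array([0.0, 1.0, 0.0]), 4)
```

The batched `sample_pdf` above does the same per ray with `torch.searchsorted` and `torch.gather`, drawing random `u` during training and stratified `u` at test time.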
# only forward new points to save computation new_density_outputs = self.density(new_xyzs.reshape(-1, 3)) # new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t] for k, v in new_density_outputs.items(): new_density_outputs[k] = v.view(N, upsample_steps, -1) # re-order z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t] z_vals, z_index = torch.sort(z_vals, dim=1) xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3] xyzs = torch.gather( xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs) ) for k in density_outputs: tmp_output = torch.cat( [density_outputs[k], new_density_outputs[k]], dim=1 ) density_outputs[k] = torch.gather( tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output) ) deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1] deltas = torch.cat( [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1 ) alphas = 1 - torch.exp( -deltas * self.density_scale * density_outputs["sigma"].squeeze(-1) ) # [N, T+t] alphas_shifted = torch.cat( [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1 ) # [N, T+t+1] weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t] dirs = rays_d.view(-1, 1, 3).expand_as(xyzs) for k, v in density_outputs.items(): density_outputs[k] = v.view(-1, v.shape[-1]) mask = weights > 1e-4 # hard coded rgbs = self.color( xyzs.reshape(-1, 3), dirs.reshape(-1, 3), mask=mask.reshape(-1), **density_outputs ) rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3] # print(xyzs.shape, 'valid_rgb:', mask.sum().item()) # calculate weight_sum (mask) weights_sum = weights.sum(dim=-1) # [N] # calculate depth ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1) depth = torch.sum(weights * ori_z_vals, dim=-1) # calculate color image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1] # mix background color if self.bg_radius > 0: # use the bg model to calculate bg_color polar = raymarching.polar_from_ray( rays_o, rays_d, self.bg_radius ) # [N, 2] in [-1, 1] 
bg_color = self.background(polar, rays_d.reshape(-1, 3)) # [N, 3] elif bg_color is None: bg_color = 1 image = image + (1 - weights_sum).unsqueeze(-1) * bg_color image = image.view(*prefix, 3) depth = depth.view(*prefix) # tmp: reg loss in mip-nerf 360 # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1) # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T] # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals_shifted) * (weights ** 2)).sum() return { "depth": depth, "image": image, } def run_cuda( self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, inherited_params=[], **kwargs ): # rays_o, rays_d: [B, N, 3], assumes B == 1 # return: image: [B, N, 3], depth: [B, N] prefix = rays_o.shape[:-1] rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) N = rays_o.shape[0] # N = B * N, in fact device = rays_o.device # pre-calculate near far nears, fars = raymarching.near_far_from_aabb( rays_o, rays_d, self.aabb_train if self.training else self.aabb_infer, self.min_near, ) # mix background color if self.bg_radius > 0: # use the bg model to calculate bg_color polar = raymarching.polar_from_ray( rays_o, rays_d, self.bg_radius ) # [N, 2] in [-1, 1] bg_color = self.background(polar, rays_d) # [N, 3] elif bg_color is None: bg_color = 1 if self.training: # different with testing # setup counter time1 = time() counter = self.step_counter[self.local_step % 16] counter.zero_() # set to 0 self.local_step += 1 if ( self.args.render_stu_first ): # if stu first, then using stu to calculate xyzs, and tea will inherite the xyzs """ About xyzs, dirs, deltas, rays: xyzs, dirs are all spatial points sampled by rays_o and rays_d; rays: xyzs[rays[i, 1]:rays[i,1]+rays[i, 2]] --> points belonging to rays[i, 0] deltas: shape is [point_nums, 2]. 
deltas means all generated points' deltas. (first for RGB, second for Depth) """ if not self.is_teacher: xyzs, dirs, deltas, rays = raymarching.march_rays_train( rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps, ) inherited_params = [xyzs, dirs, deltas, rays] else: xyzs, dirs, deltas, rays = inherited_params else: if self.is_teacher: xyzs, dirs, deltas, rays = raymarching.march_rays_train( rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps, ) inherited_params = [xyzs, dirs, deltas, rays] else: xyzs, dirs, deltas, rays = inherited_params # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) sigmas, rgbs = self(xyzs, dirs) if self.args.global_step < self.args.stage_iters["stage1"]: return { "stage1": self.args.global_step, "depth": None, "image": None, "inherited_params": inherited_params, "sigmas": sigmas, "rays": rays, } elif self.args.global_step < self.args.stage_iters["stage2"]: return { "stage2": self.args.global_step, "depth": None, "image": None, "inherited_params": inherited_params, "sigmas": sigmas, "rays": rays, } sigmas = self.density_scale * sigmas weights_sum, depth, image = raymarching.composite_rays_train( sigmas, rgbs, deltas, rays ) image = image + (1 - weights_sum).unsqueeze(-1) * bg_color depth = torch.clamp(depth - nears, min=0) / (fars - nears + 1e-6) image = image.view(*prefix, 3) depth = depth.view(*prefix) else: # allocate outputs # if use autocast, must init as half so it won't be autocasted and lose reference. # dtype = torch.half if torch.is_autocast_enabled() else torch.float32 # output should always be float32! only network inference uses half. 
dtype = torch.float32 weights_sum = torch.zeros(N, dtype=dtype, device=device) depth = torch.zeros(N, dtype=dtype, device=device) image = torch.zeros(N, 3, dtype=dtype, device=device) n_alive = N alive_counter = torch.zeros([1], dtype=torch.int32, device=device) rays_alive = torch.zeros( 2, n_alive, dtype=torch.int32, device=device ) # 2 is used to loop old/new rays_t = torch.zeros(2, n_alive, dtype=dtype, device=device) step = 0 i = 0 while step < max_steps: # count alive rays if step == 0: # init rays at first step. torch.arange(n_alive, out=rays_alive[0]) rays_t[0] = nears else: alive_counter.zero_() raymarching.compact_rays( n_alive, rays_alive[i % 2], rays_alive[(i + 1) % 2], rays_t[i % 2], rays_t[(i + 1) % 2], alive_counter, ) n_alive = alive_counter.item() # must invoke D2H copy here # exit loop if n_alive <= 0: break # decide compact_steps n_step = max(min(N // n_alive, 8), 1) xyzs, dirs, deltas = raymarching.march_rays( n_alive, n_step, rays_alive[i % 2], rays_t[i % 2], rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb, dt_gamma, max_steps, ) sigmas, rgbs = self(xyzs, dirs) # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb. 
# sigmas = density_outputs['sigma'] # rgbs = self.color(xyzs, dirs, **density_outputs) sigmas = self.density_scale * sigmas raymarching.composite_rays( n_alive, n_step, rays_alive[i % 2], rays_t[i % 2], sigmas, rgbs, deltas, weights_sum, depth, image, ) # print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}') step += n_step i += 1 image = image + (1 - weights_sum).unsqueeze(-1) * bg_color depth = torch.clamp(depth - nears, min=0) / (fars - nears) image = image.view(*prefix, 3) depth = depth.view(*prefix) # print('\n--- render time:--- {:6f} {:.6f}'.format(time2-time1, time()-time2)) if self.training: return { "depth": depth, "image": image, "inherited_params": inherited_params, "sigmas": sigmas, "rays": rays, } else: return { "depth": depth, "image": image, "inherited_params": inherited_params, } @torch.no_grad() def mark_untrained_grid(self, poses, intrinsic, S=64): # poses: [B, 4, 4] # intrinsic: [3, 3] if not self.cuda_ray: return if isinstance(poses, np.ndarray): poses = torch.from_numpy(poses) B = poses.shape[0] fx, fy, cx, cy = intrinsic X = torch.arange( self.grid_size, dtype=torch.int32, device=self.density_grid.device ).split(S) Y = torch.arange( self.grid_size, dtype=torch.int32, device=self.density_grid.device ).split(S) Z = torch.arange( self.grid_size, dtype=torch.int32, device=self.density_grid.device ).split(S) count = torch.zeros_like(self.density_grid) poses = poses.to(count.device) # 5-level loop, forgive me... 
for xs in X: for ys in Y: for zs in Z: # construct points xx, yy, zz = custom_meshgrid(xs, ys, zs) coords = torch.cat( [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1, ) # [N, 3], in [0, 128) indices = raymarching.morton3D(coords).long() # [N] world_xyzs = ( 2 * coords.float() / (self.grid_size - 1) - 1 ).unsqueeze( 0 ) # [1, N, 3] in [-1, 1] # cascading for cas in range(self.cascade): bound = min(2 ** cas, self.bound) half_grid_size = bound / self.grid_size # scale to current cascade's resolution cas_world_xyzs = world_xyzs * (bound - half_grid_size) # split batch to avoid OOM head = 0 while head < B: tail = min(head + S, B) # world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.) cam_xyzs = cas_world_xyzs - poses[ head:tail, :3, 3 ].unsqueeze(1) cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3] # query if point is covered by any camera mask_z = cam_xyzs[:, :, 2] > 0 # [S, N] mask_x = ( torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2 ) mask_y = ( torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2 ) mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1) # [N] # update count count[cas, indices] += mask head += S # mark untrained grid as -1 self.density_grid[count == 0] = -1 # print(f'[mark untrained grid] {(count == 0).sum()} from {resolution ** 3 * self.cascade}') @torch.no_grad() def update_extra_state(self, decay=0.95, S=128): # call before each epoch to update extra states. if not self.cuda_ray: return # update density grid tmp_grid = -torch.ones_like(self.density_grid) # full update. 
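`raymarching.morton3D` linearizes 3D grid coordinates into Z-order (Morton) indices by interleaving their bits, so that spatially nearby cells stay close in memory. A pure-Python sketch of the idea (the exact bit ordering is a convention of the CUDA kernel and may differ; this version keeps x in the lowest bit):

```python
def morton3d(x, y, z):
    """Interleave the bits of (x, y, z) into one Morton / Z-order index.
    10 bits per axis covers grids up to 1024^3."""
    code = 0
    for i in range(10):
        code |= ((x >> i) & 1) << (3 * i)      # x bits at positions 0, 3, 6, ...
        code |= ((y >> i) & 1) << (3 * i + 1)  # y bits at 1, 4, 7, ...
        code |= ((z >> i) & 1) << (3 * i + 2)  # z bits at 2, 5, 8, ...
    return code

def morton3d_invert(code):
    """Recover (x, y, z) from a Morton index (the analogue of morton3D_invert)."""
    x = y = z = 0
    for i in range(10):
        x |= ((code >> (3 * i)) & 1) << i
        y |= ((code >> (3 * i + 1)) & 1) << i
        z |= ((code >> (3 * i + 2)) & 1) << i
    return x, y, z
```

Encoding and decoding are exact inverses, which is what lets the partial-update branch below sample occupied cells by index and map them back to coordinates.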
if self.iter_density < 16: # if True: X = torch.arange( self.grid_size, dtype=torch.int32, device=self.density_grid.device ).split(S) Y = torch.arange( self.grid_size, dtype=torch.int32, device=self.density_grid.device ).split(S) Z = torch.arange( self.grid_size, dtype=torch.int32, device=self.density_grid.device ).split(S) for xs in X: for ys in Y: for zs in Z: # construct points xx, yy, zz = custom_meshgrid(xs, ys, zs) coords = torch.cat( [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1, ) # [N, 3], in [0, 128) indices = raymarching.morton3D(coords).long() # [N] xyzs = ( 2 * coords.float() / (self.grid_size - 1) - 1 ) # [N, 3] in [-1, 1] # cascading for cas in range(self.cascade): bound = min(2 ** cas, self.bound) half_grid_size = bound / self.grid_size # scale to current cascade's resolution cas_xyzs = xyzs * (bound - half_grid_size) # add noise in [-hgs, hgs] cas_xyzs += ( torch.rand_like(cas_xyzs) * 2 - 1 ) * half_grid_size # query density sigmas = ( self.density(cas_xyzs)["sigma"].reshape(-1).detach() ) sigmas *= self.density_scale # assign tmp_grid[cas, indices] = sigmas # partial update (half the computation) # TODO: why no need of maxpool ? 
else: N = self.grid_size ** 3 // 4 # H * H * H / 4 for cas in range(self.cascade): # random sample some positions coords = torch.randint( 0, self.grid_size, (N, 3), device=self.density_grid.device ) # [N, 3], in [0, 128) indices = raymarching.morton3D(coords).long() # [N] # random sample occupied positions occ_indices = torch.nonzero(self.density_grid[cas] > 0).squeeze( -1 ) # [Nz] rand_mask = torch.randint( 0, occ_indices.shape[0], [N], dtype=torch.long, device=self.density_grid.device, ) occ_indices = occ_indices[ rand_mask ] # [Nz] --> [N], allow for duplication occ_coords = raymarching.morton3D_invert(occ_indices) # [N, 3] # concat indices = torch.cat([indices, occ_indices], dim=0) coords = torch.cat([coords, occ_coords], dim=0) # same below xyzs = ( 2 * coords.float() / (self.grid_size - 1) - 1 ) # [N, 3] in [-1, 1] bound = min(2 ** cas, self.bound) half_grid_size = bound / self.grid_size # scale to current cascade's resolution cas_xyzs = xyzs * (bound - half_grid_size) # add noise in [-hgs, hgs] cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size # query density sigmas = self.density(cas_xyzs)["sigma"].reshape(-1).detach() sigmas *= self.density_scale # assign tmp_grid[cas, indices] = sigmas ## max-pool on tmp_grid for less aggressive culling [No significant improvement...] # invalid_mask = tmp_grid < 0 # tmp_grid = F.max_pool3d(tmp_grid.view(self.cascade, 1, self.grid_size, self.grid_size, self.grid_size), kernel_size=3, stride=1, padding=1).view(self.cascade, -1) # tmp_grid[invalid_mask] = -1 # ema update valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0) self.density_grid[valid_mask] = torch.maximum( self.density_grid[valid_mask] * decay, tmp_grid[valid_mask] ) self.mean_density = torch.mean( self.density_grid.clamp(min=0) ).item() # -1 non-training regions are viewed as 0 density. 
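The decay-and-maximum EMA update above, followed by thresholding into a bitfield (what `raymarching.packbits` does on the GPU), can be mimicked in NumPy. The function name is illustrative, and note `np.packbits` uses MSB-first bit order, which need not match the CUDA kernel's layout:

```python
import numpy as np

def ema_update_and_pack(density_grid, new_sigmas, thresh, decay=0.95):
    """EMA-update a flat density grid with freshly queried sigmas, then pack
    per-cell occupancy (sigma > thresh) into a uint8 bitfield, 8 cells/byte.
    Cells marked -1 (untrained) are left untouched."""
    valid = (density_grid >= 0) & (new_sigmas >= 0)
    density_grid = density_grid.copy()
    density_grid[valid] = np.maximum(density_grid[valid] * decay, new_sigmas[valid])
    bits = density_grid > thresh                    # occupancy per cell
    bitfield = np.packbits(bits.astype(np.uint8))   # pads the tail with zero bits
    return density_grid, bitfield

grid, bitfield = ema_update_and_pack(
    np.array([0.5, -1.0, 0.0, 2.0], dtype=np.float32),
    np.array([1.0, 5.0, 0.0, 0.0], dtype=np.float32),
    thresh=0.2,
)
```

Taking the maximum of the decayed old value and the new sample lets stale density fade away while never underestimating freshly observed geometry, so occupied cells are not culled prematurely.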
        self.iter_density += 1

        # convert to bitfield
        density_thresh = min(self.mean_density, self.density_thresh)
        self.density_bitfield = raymarching.packbits(
            self.density_grid, density_thresh, self.density_bitfield
        )

        ### update step counter
        total_step = min(16, self.local_step)
        if total_step > 0:
            self.mean_count = int(
                self.step_counter[:total_step, 0].sum().item() / total_step
            )
        self.local_step = 0

        # print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}')

    def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # return: pred_rgb: [B, N, 3]

        if self.cuda_ray:
            _run = self.run_cuda
        else:
            _run = self.run

        B, N = rays_o.shape[:2]
        device = rays_o.device

        # never stage when cuda_ray
        if staged and not self.cuda_ray:
            depth = torch.empty((B, N), device=device)
            image = torch.empty((B, N, 3), device=device)

            for b in range(B):
                head = 0
                while head < N:
                    tail = min(head + max_ray_batch, N)
                    results_ = _run(
                        rays_o[b : b + 1, head:tail],
                        rays_d[b : b + 1, head:tail],
                        **kwargs
                    )
                    depth[b : b + 1, head:tail] = results_["depth"]
                    image[b : b + 1, head:tail] = results_["image"]
                    head += max_ray_batch

            results = {}
            results["depth"] = depth
            results["image"] = image
        else:
            results = _run(rays_o, rays_d, **kwargs)

        return results


================================================ FILE: distill_mutual/utils.py ================================================

import os
import copy
import lpips
import glob
import tqdm
import math
import random
import warnings
import tensorboardX
import numpy as np
import pandas as pd
import imageio
import time
from datetime import datetime
import cv2
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader

import trimesh
import mcubes
from rich.console import Console
from torch_ema import ExponentialMovingAverage
from IPython import embed
import sys
from packaging import version as pver

device = torch.device("cuda")
TINY_NUMBER = 1e-6  # float32 only has 7 decimal digits of precision


def update_loss_rate(cur_lrate, scale=0.99):
    return cur_lrate * scale


def get_softmax_map_mean(a, b):
    return (F.softmax(a) - F.softmax(b)).abs().mean()


def get_kl(inputs, targets):
    return F.kl_div(F.log_softmax(inputs), F.softmax(targets), reduction="sum")


def nerf_matrix_to_ngp(pose, scale=0.8):
    # for the fox dataset, 0.33 scales camera radius to ~ 2
    new_pose = np.array(
        [
            [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale],
            [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale],
            [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale],
            [0, 0, 0, 1],
        ],
        dtype=np.float32,
    )
    return new_pose


def pose_spherical(theta, phi, radius):
    # for synthetic scenes: generates random poses on a sphere
    trans_t = lambda t: np.array(
        [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1]]
    ).astype(np.float32)
    rot_phi = lambda phi: np.array(
        [
            [1, 0, 0, 0],
            [0, np.cos(phi), -np.sin(phi), 0],
            [0, np.sin(phi), np.cos(phi), 0],
            [0, 0, 0, 1],
        ]
    ).astype(np.float32)
    rot_theta = lambda th: np.array(
        [
            [np.cos(th), 0, -np.sin(th), 0],
            [0, 1, 0, 0],
            [np.sin(th), 0, np.cos(th), 0],
            [0, 0, 0, 1],
        ]
    ).astype(np.float32)

    c2w = trans_t(radius)
    c2w = rot_phi(phi / 180.0 * np.pi) @ c2w
    c2w = rot_theta(theta / 180.0 * np.pi) @ c2w
    c2w = (
        np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]).astype(
            np.float32
        )
        @ c2w
    )
    return c2w


def get_rand_poses(data_type="synthetic", original_loader=None):
    """
    Random sampling. Random origins and directions.
""" from scipy.spatial.transform import Slerp, Rotation assert data_type in {"synthetic", "llff", "tank"} def get_single_syn_pose(ph, rand_radius=False): theta1 = -180 theta2 = 180 phi1 = -ph phi2 = 5 - ph if (5 - ph) <= 0 else 0 theta = theta1 + np.random.rand() * (theta2 - theta1) phi = phi1 + np.random.rand() * (phi2 - phi1) if rand_radius: radius = np.random.uniform(3, 4) else: radius = 4 return pose_spherical(theta, phi, radius) def get_syn_poses(): random_poses = np.array([get_single_syn_pose(8) for _ in range(1)]) for a in range(0, 80): rp = np.array( [get_single_syn_pose(a) for _ in range(int(((90 - a) // 15) ** 1 + 1))] ) random_poses = np.concatenate([random_poses, rp], axis=0) for i in range(len(random_poses)): random_poses[i] = nerf_matrix_to_ngp(random_poses[i]) print(f"\nlen(train data): {len(random_poses)}\n") random_poses = torch.from_numpy(random_poses).cuda() return random_poses def get_tank_poses(): random_poses = np.array([get_single_syn_pose(8) for _ in range(1)]) for a in range(5, 20): rp = np.array( [ get_single_syn_pose(a, True) for _ in range(int(((90 - a) // 15) ** 1 + 1)) ] ) random_poses = np.concatenate([random_poses, rp], axis=0) for i in range(len(random_poses)): random_poses[i] = nerf_matrix_to_ngp(random_poses[i]) print(f"\nlen(train data): {len(random_poses)}\n") random_poses = torch.from_numpy(random_poses).cuda() return random_poses def rand_poses_from_cam_centers(centers): def normalize(vectors): return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) size = len(centers) forward_vector = -normalize(centers) up_vector = ( torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1) ) # confused at the coordinate system... 
right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1)) up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1)) poses = ( torch.eye(4, dtype=torch.float, device=device) .unsqueeze(0) .repeat(size, 1, 1) ) poses[:, :3, :3] = torch.stack( (right_vector, up_vector, forward_vector), dim=-1 ) poses[:, :3, 3] = centers return poses def get_llff_poses_rand(): def get_rand_cam_centers_from_bbox(poses, gen_num=30): # use poses to estimate the bbox of the camera trasitions = poses[:, :3, 3] bbox_max = trasitions.max(axis=0) + 1e-6 bbox_min = trasitions.min(axis=0) - 1e-6 rand_xs = np.random.uniform(low=bbox_min[0], high=bbox_max[0], size=gen_num) rand_ys = np.random.uniform(low=bbox_min[1], high=bbox_max[1], size=gen_num) rand_zs = np.random.uniform(low=bbox_min[2], high=bbox_max[2], size=gen_num) centers = np.stack([rand_xs, rand_ys, rand_zs], axis=1) return centers.astype(np.float32) centers = get_rand_cam_centers_from_bbox(original_loader) random_poses = rand_poses_from_cam_centers(torch.from_numpy(centers).cuda()) random_poses[:, 0, 0] = -random_poses[:, 0, 0] return random_poses if data_type == "synthetic": random_poses = get_syn_poses() elif data_type == "llff": random_poses = get_llff_poses_rand() elif data_type == "tank": random_poses = get_tank_poses() else: raise ValueError("illegal") return random_poses def custom_meshgrid(*args): # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid if pver.parse(torch.__version__) < pver.parse("1.10"): return torch.meshgrid(*args) else: return torch.meshgrid(*args, indexing="ij") @torch.jit.script def linear_to_srgb(x): return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055) @torch.jit.script def srgb_to_linear(x): return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4) def compute_ssim( img0, img1, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03, return_map=False, ): """Computes SSIM from two 
images. This function was modeled after tf.image.ssim, and should produce comparable output. Args: img0: torch.tensor. An image of size [..., width, height, num_channels]. img1: torch.tensor. An image of size [..., width, height, num_channels]. max_val: float > 0. The maximum magnitude that `img0` or `img1` can have. filter_size: int >= 1. Window size. filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering. k1: float > 0. One of the SSIM dampening parameters. k2: float > 0. One of the SSIM dampening parameters. return_map: Bool. If True, will cause the per-pixel SSIM "map" to returned Returns: Each image's mean SSIM, or a tensor of individual values if `return_map`. """ device = img0.device img0 = img0.type(torch.float32) img1 = img1.type(torch.float32) ori_shape = img0.size() width, height, num_channels = ori_shape[-3:] img0 = img0.view(-1, width, height, num_channels).permute(0, 3, 1, 2) img1 = img1.view(-1, width, height, num_channels).permute(0, 3, 1, 2) batch_size = img0.shape[0] # Construct a 1D Gaussian blur filter. hw = filter_size // 2 shift = (2 * hw - filter_size + 1) / 2 f_i = ((torch.arange(filter_size, device=device) - hw + shift) / filter_sigma) ** 2 filt = torch.exp(-0.5 * f_i) filt /= torch.sum(filt) # Blur in x and y (faster than the 2D convolution). # z is a tensor of size [B, H, W, C] filt_fn1 = lambda z: F.conv2d( z, filt.view(1, 1, -1, 1).repeat(num_channels, 1, 1, 1), padding=[hw, 0], groups=num_channels, ) filt_fn2 = lambda z: F.conv2d( z, filt.view(1, 1, 1, -1).repeat(num_channels, 1, 1, 1), padding=[0, hw], groups=num_channels, ) # Vmap the blurs to the tensor size, and then compose them. filt_fn = lambda z: filt_fn1(filt_fn2(z)) mu0 = filt_fn(img0) mu1 = filt_fn(img1) mu00 = mu0 * mu0 mu11 = mu1 * mu1 mu01 = mu0 * mu1 sigma00 = filt_fn(img0 ** 2) - mu00 sigma11 = filt_fn(img1 ** 2) - mu11 sigma01 = filt_fn(img0 * img1) - mu01 # Clip the variances and covariances to valid values. 
    # Variance must be non-negative:
    sigma00 = torch.clamp(sigma00, min=0.0)
    sigma11 = torch.clamp(sigma11, min=0.0)
    sigma01 = torch.sign(sigma01) * torch.min(
        torch.sqrt(sigma00 * sigma11), torch.abs(sigma01)
    )

    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2
    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
    ssim_map = numer / denom
    ssim = torch.mean(ssim_map.reshape([-1, num_channels * width * height]), dim=-1)

    return ssim_map if return_map else ssim


def init_lpips(net_name, device):
    assert net_name in ["alex", "vgg"]
    import lpips

    print(f"init_lpips: lpips_{net_name}")
    return lpips.LPIPS(net=net_name, version="0.1").eval().cuda()


lpips_fns = {
    "alex": lpips.LPIPS(net="alex", version="0.1").eval().cuda(),
    "vgg": lpips.LPIPS(net="vgg", version="0.1").eval().cuda(),
}


def rgb_lpips(gt, im, net_name):
    assert net_name in ["alex", "vgg"]
    gt = gt.type(torch.float32).permute([0, 3, 1, 2]).contiguous().cuda()
    im = im.type(torch.float32).permute([0, 3, 1, 2]).contiguous().cuda()
    return lpips_fns[net_name](gt, im, normalize=True).item()


@torch.cuda.amp.autocast(enabled=False)
def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
    """get rays
    Args:
        poses: [B, 4, 4], cam2world
        intrinsics: [4]
        H, W, N: int
        error_map: [B, 128 * 128], sample probability based on training error
    Returns:
        rays_o, rays_d: [B, N, 3]
        inds: [B, N]
    """
    device = poses.device
    B = poses.shape[0]
    fx, fy, cx, cy = intrinsics

    i, j = custom_meshgrid(
        torch.linspace(0, W - 1, W, device=device),
        torch.linspace(0, H - 1, H, device=device),
    )
    i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5
    j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5

    results = {}

    if N > 0:
        N = min(N, H * W)

        if error_map is None:
            inds = torch.randint(0, H * W, size=[N], device=device)  # may duplicate
            inds = inds.expand([B, N])
        else:
            # weighted sample on a low-reso grid
            inds_coarse = torch.multinomial(
                error_map.to(device), N, replacement=False
            )  # [B, N], but in [0, 128*128)

            # map to the original resolution with random perturb.
            inds_x, inds_y = (
                inds_coarse // 128,
                inds_coarse % 128,
            )  # `//` will throw a warning in torch 1.10... anyway.
            sx, sy = H / 128, W / 128
            inds_x = (
                (inds_x * sx + torch.rand(B, N, device=device) * sx)
                .long()
                .clamp(max=H - 1)
            )
            inds_y = (
                (inds_y * sy + torch.rand(B, N, device=device) * sy)
                .long()
                .clamp(max=W - 1)
            )
            inds = inds_x * W + inds_y

            results["inds_coarse"] = inds_coarse  # need this when updating error_map

        i = torch.gather(i, -1, inds)
        j = torch.gather(j, -1, inds)

        results["inds"] = inds
    else:
        inds = torch.arange(H * W, device=device).expand([B, H * W])

    zs = torch.ones_like(i)
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    directions = torch.stack((xs, ys, zs), dim=-1)
    directions = directions / torch.norm(directions, dim=-1, keepdim=True)
    rays_d = directions @ poses[:, :3, :3].transpose(-1, -2)  # (B, N, 3)

    rays_o = poses[..., :3, 3]  # [B, 3]
    rays_o = rays_o[..., None, :].expand_as(rays_d)  # [B, N, 3]

    results["rays_o"] = rays_o
    results["rays_d"] = rays_d

    return results


def seed_everything(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = True


def torch_vis_2d(x, renormalize=False):
    # x: [3, H, W] or [1, H, W] or [H, W]
    import matplotlib.pyplot as plt
    import numpy as np
    import torch

    if isinstance(x, torch.Tensor):
        if len(x.shape) == 3:
            x = x.permute(1, 2, 0).squeeze()
        x = x.detach().cpu().numpy()

    print(f"[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}")

    x = x.astype(np.float32)

    # renormalize
    if renormalize:
        x = (x - x.min(axis=0, keepdims=True)) / (
            x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8
        )

    plt.imshow(x)
    plt.show()


def extract_fields(bound_min, bound_max, resolution, query_func, S=128):
    X = torch.linspace(bound_min[0], bound_max[0], resolution).split(S)
    Y = torch.linspace(bound_min[1], bound_max[1],
resolution).split(S) Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(S) u = np.zeros([resolution, resolution, resolution], dtype=np.float32) with torch.no_grad(): for xi, xs in enumerate(X): for yi, ys in enumerate(Y): for zi, zs in enumerate(Z): xx, yy, zz = custom_meshgrid(xs, ys, zs) pts = torch.cat( [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1, ) # [S, 3] val = ( query_func(pts) .reshape(len(xs), len(ys), len(zs)) .detach() .cpu() .numpy() ) # [S, 1] --> [x, y, z] u[ xi * S : xi * S + len(xs), yi * S : yi * S + len(ys), zi * S : zi * S + len(zs), ] = val return u def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): # print('threshold: {}'.format(threshold)) u = extract_fields(bound_min, bound_max, resolution, query_func) # print(u.shape, u.max(), u.min(), np.percentile(u, 50)) vertices, triangles = mcubes.marching_cubes(u, threshold) b_max_np = bound_max.detach().cpu().numpy() b_min_np = bound_min.detach().cpu().numpy() vertices = ( vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] ) return vertices, triangles class PSNRMeter: def __init__(self): self.V = 0 self.N = 0 self.psnr_list = [] def clear(self): self.V = 0 self.N = 0 self.psnr_list = [] def prepare_inputs(self, *inputs): outputs = [] for i, inp in enumerate(inputs): if torch.is_tensor(inp): inp = inp.detach().cpu().numpy() outputs.append(inp) return outputs def update(self, preds, truths): preds, truths = self.prepare_inputs( preds, truths ) # [B, N, 3] or [B, H, W, 3], range[0, 1] psnr = -10 * np.log10(np.mean((preds - truths) ** 2)) self.psnr_list.append(psnr) self.V += psnr self.N += 1 assert self.N == len(self.psnr_list) def measure(self): return self.V / self.N def write(self, writer, global_step, prefix=""): writer.add_scalar(os.path.join(prefix, "PSNR"), self.measure(), global_step) def report(self): return f"PSNR = {self.measure():.6f}" class Trainer(object): def __init__( self, name, # name of this 
experiment opt, # extra conf model_tea, # network model_stu, criterion=None, # loss function, if None, assume inline implementation in train_step optimizer=None, # optimizer ema_decay=None, # if use EMA, set the decay lr_scheduler=None, # scheduler metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric. local_rank=0, # which GPU am I world_size=1, # total num of GPUs device=None, # device to use, usually setting to None is OK. (auto choose device) mute=False, # whether to mute all print fp16=False, # amp optimize level eval_interval=10e10, # eval once every $ epoch max_keep_ckpt=2, # max num of saved ckpts in disk workspace="workspace", # workspace to save logs & ckpts best_mode="min", # the smaller/larger result, the better use_loss_as_metric=True, # use loss as the first metric report_metric_at_train=False, # also report metrics at training use_checkpoint="latest", # which ckpt to use at init time use_tensorboardX=True, # whether to use tensorboard for logging scheduler_update_every_step=False, # whether to call scheduler.step() after every train step ): self.optimizer_fn = optimizer self.lr_scheduler_fn = lr_scheduler self.name = name self.opt = opt self.args = opt self.mute = mute self.metrics = metrics self.local_rank = local_rank self.world_size = world_size self.workspace = workspace self.ema_decay = ema_decay self.fp16 = fp16 self.best_mode = best_mode self.use_loss_as_metric = use_loss_as_metric self.report_metric_at_train = report_metric_at_train self.max_keep_ckpt = max_keep_ckpt self.eval_interval = eval_interval self.use_checkpoint = use_checkpoint self.use_tensorboardX = use_tensorboardX self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S") self.scheduler_update_every_step = scheduler_update_every_step self.device = ( device if device is not None else torch.device( f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu" ) ) self.console = Console() # move models with self.device: the raw `device` argument may be None, and Module.to(None) is a no-op that would leave the models on CPU self.model_tea = model_tea.to(self.device) self.model_stu = model_stu.to(self.device) if isinstance(criterion, nn.Module): criterion.to(self.device) self.criterion = criterion if optimizer is None: self.optimizer = optim.AdamW( self.model_stu.parameters(), lr=0.001, weight_decay=5e-4 ) # naive adam else: self.optimizer = optimizer(self.model_stu) if lr_scheduler is None: self.lr_scheduler = optim.lr_scheduler.LambdaLR( self.optimizer, lr_lambda=lambda epoch: 1 ) # fake scheduler else: self.ls = lr_scheduler self.lr_scheduler = lr_scheduler(self.optimizer) if ema_decay is not None and ema_decay > 0: self.ema = ExponentialMovingAverage( self.model_stu.parameters(), decay=ema_decay ) else: self.ema = None self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16) # variable init self.epoch = 1 self.global_step = 0 self.local_step = 0 self.stats = { "loss": [], "valid_loss": [], "results": [], # metrics[0], or valid_loss "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt "best_result": None, } # auto fix if len(metrics) == 0 or self.use_loss_as_metric: self.best_mode = "min" # workspace prepare self.log_ptr = None if self.workspace is not None: os.makedirs(self.workspace, exist_ok=True) self.log_path = os.path.join(workspace, f"log_{self.name}.txt") self.log_ptr = open(self.log_path, "a+") self.ckpt_path = os.path.join(self.workspace, "checkpoints") self.best_path = f"{self.ckpt_path}/{self.name}.pth" os.makedirs(self.ckpt_path, exist_ok=True) self.log(self.opt) self.log( f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}' ) self.log( f"[INFO] #parameters: {sum([p.numel() for p in model_stu.parameters() if p.requires_grad])}" ) if ( self.workspace is not None ): # only load state_dict for teacher and share backbone for student self.log(f"[INFO] Loading teacher ckpt from {self.opt.ckpt_teacher} ...") self.load_teacher_checkpoint() self.log(self.model_tea) self.load_student_checkpoint() self.log(self.model_stu) #
self.model_tea.reset_extra_state() # self.model_stu.reset_extra_state() """ if opt.rand_pose >= 0: # =0 means only using CLIP loss, >0 means a hybrid mode. from nerf.clip_utils import CLIPLoss self.clip_loss = CLIPLoss(self.device) self.clip_loss.prepare_text([self.opt.clip_text]) # only support one text prompt now... """ def __del__(self): if self.log_ptr: self.log_ptr.close() def log(self, *args, **kwargs): if self.local_rank == 0: if not self.mute: # print(*args) self.console.print(*args, **kwargs) if self.log_ptr: print(*args, file=self.log_ptr) self.log_ptr.flush() # write immediately to file def train(self, train_loader, valid_loader, max_epochs): self.hard_rays_pool = [torch.tensor([]).cuda(), torch.tensor([]).cuda()] self.is_hard_rays_pool_full = False if self.use_tensorboardX and self.local_rank == 0: self.writer = tensorboardX.SummaryWriter( os.path.join(self.workspace, "run", self.name) ) for p in self.model_tea.parameters(): p.requires_grad = False self.model_tea.eval() # get a ref to error_map self.error_map = train_loader._data.error_map if ( not self.args.use_real_data_for_train ): # using random poses to calculate max_epochs. 
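When training on randomly generated poses, the trainer derives `max_epochs` from a fixed iteration budget rather than from the dataset length. A pure-Python sketch of that arithmetic (hypothetical helper name, not from this repo): truncate `iters` down to a whole multiple of the pose count, then the epoch count is the ceiling of the ratio.

```python
import math

def iters_to_epochs(iters, num_poses):
    """Truncate the iteration budget to a whole number of passes over
    `num_poses` poses and return (adjusted_iters, max_epochs)."""
    adjusted = (iters // num_poses) * num_poses
    max_epochs = math.ceil(adjusted / num_poses)
    return adjusted, max_epochs

print(iters_to_epochs(10000, 300))  # → (9900, 33)
```

Because `adjusted` is an exact multiple of `num_poses`, the ceiling here equals the plain quotient; the truncation just keeps the scheduler's `T_max` consistent with the number of steps actually run.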
random_poses = get_rand_poses( data_type=self.args.data_type, original_loader=copy.deepcopy( train_loader._data.poses.detach().cpu().numpy() ), ) self.opt.iters = int( (self.opt.iters // len(random_poses)) * len(random_poses) ) max_epochs = np.ceil(self.opt.iters / len(random_poses)).astype(np.int32) scheduler = lambda optimizer: optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=self.opt.iters * 1, eta_min=7e-5 ) # update scheduler according to new opt.iters self.lr_scheduler = scheduler(self.optimizer) self.total_epoch = max_epochs self.log(f"\n----------------total epoch:{max_epochs} -----------\n") self.real_train_poses = copy.deepcopy(train_loader._data.poses) for epoch in range(self.epoch, max_epochs + 1): self.epoch = epoch if not self.args.use_real_data_for_train: print(f"\n generate new random poses at epoch{self.epoch}") random_poses = get_rand_poses( data_type=self.args.data_type, original_loader=self.real_train_poses.detach().cpu().numpy(), ) train_loader._data.poses = copy.deepcopy(random_poses) train_loader._data.images = train_loader._data.images[:1].expand( len(random_poses), -1, -1, -1 ) train_loader = train_loader._data.dataloader() self.train_one_epoch(train_loader) print("\n", self.workspace, "\n") if ( self.workspace is not None and self.local_rank == 0 and self.epoch > max_epochs - 1 ): self.save_checkpoint(full=False, best=False) if self.epoch % self.eval_interval == 0: self.evaluate_one_epoch(valid_loader) self.save_checkpoint(full=False, best=True) # to save storage, temporarily skip saving the pth if self.use_tensorboardX and self.local_rank == 0: self.writer.close() def train_one_epoch(self, loader): # self.log( # f"tttttttttt> Start Training Epoch {self.epoch}/{self.total_epoch}, len(train_data):{len(loader)} lr={self.optimizer.param_groups[0]['lr']:.6f} ..." # ) total_loss = 0 total_loss_rgb = 0 total_loss_fea_sc = 0 total_loss_sigma = 0 total_loss_color = 0 psnr_tool = PSNRMeter() psnr_tool.clear() self.pose_psnr = [] # [(pose1, psnr1), (pose2,psnr2)...]
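The trainer rebuilds its scheduler as `CosineAnnealingLR(T_max=opt.iters, eta_min=7e-5)` once the iteration budget is known. A pure-Python sketch of the cosine-annealing curve that schedule follows (illustrative helper, not part of this repo):

```python
import math

def cosine_annealing_lr(base_lr, eta_min, t, t_max):
    """Cosine-annealed learning rate at step t of t_max:
    starts at base_lr, decays smoothly to eta_min."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max))

print(round(cosine_annealing_lr(1e-3, 7e-5, 0, 1000), 6))     # 0.001
print(round(cosine_annealing_lr(1e-3, 7e-5, 1000, 1000), 6))  # 7e-05
```

At the halfway point the rate sits exactly midway between `base_lr` and `eta_min`, which is why cosine schedules spend relatively long near both endpoints.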
if self.local_rank == 0 and self.report_metric_at_train: for metric in self.metrics: metric.clear() self.model_stu.train() self.model_tea.train() if self.world_size > 1: loader.sampler.set_epoch(self.epoch) if self.local_rank == 0: pbar = tqdm.tqdm( total=len(loader) * loader.batch_size, bar_format="{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", ) self.local_step = 0 for data in loader: # update grid every 16 steps. This should run when just training a teacher, but not when distilling a student if ( self.model_tea.cuda_ray and self.global_step % self.opt.update_extra_interval == 0 ): with torch.cuda.amp.autocast(enabled=self.fp16): if self.opt.update_stu_extra: self.model_stu.update_extra_state() else: pass self.local_step += 1 self.global_step += 1 self.args.global_step = self.global_step self.optimizer.zero_grad() with torch.cuda.amp.autocast(enabled=self.fp16): ( preds, truths, loss, loss_rgb, loss_fea_sc, loss_color, loss_sigma, ) = self.train_step(data) if preds is not None: psnr_tool.update(preds, truths) self.scaler.scale(loss).backward() self.scaler.step(self.optimizer) self.scaler.update() if self.scheduler_update_every_step: self.lr_scheduler.step() loss_val = loss.item() total_loss += loss_val total_loss_rgb += loss_rgb total_loss_sigma += loss_sigma total_loss_color += loss_color total_loss_fea_sc += loss_fea_sc if self.local_rank == 0: if self.report_metric_at_train: for metric in self.metrics: metric.update(preds, truths) if self.use_tensorboardX: self.writer.add_scalar("train/loss", loss_val, self.global_step) self.writer.add_scalar("train/loss_rgb", loss_rgb, self.global_step) self.writer.add_scalar( "train/loss_fea_sc", loss_fea_sc, self.global_step ) self.writer.add_scalar( "train/loss_color", loss_color, self.global_step ) self.writer.add_scalar( "train/loss_sigma", loss_sigma, self.global_step ) self.writer.add_scalar( "train/lr", self.optimizer.param_groups[0]["lr"], self.global_step, ) if
self.scheduler_update_every_step: # run this cur_lr = self.optimizer.param_groups[0]["lr"] if self.global_step < self.args.stage_iters["stage1"]: pbar.set_description( f"loss={total_loss/self.local_step:.5f}, fea_sc={total_loss_fea_sc/self.local_step:.5f}, lr={cur_lr:.5f}" ) elif self.global_step < self.args.stage_iters["stage2"]: pbar.set_description( f"loss={total_loss/self.local_step:.5f}, fea_sc={total_loss_fea_sc/self.local_step:.5f}, sigma={total_loss_sigma/self.local_step:.5f}, color={total_loss_color/self.local_step:.5f}, lr={cur_lr:.6f}" ) else: pbar.set_description( f"loss={total_loss/self.local_step:.5f}, rgb={total_loss_rgb/self.local_step:.5f}, lr={cur_lr:.5f}" ) else: pbar.set_description( f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})" ) pbar.update(loader.batch_size) if ( self.opt.model_type == "vm" and self.global_step in self.opt.upsample_model_steps ): # shrink if ( self.model_stu.cuda_ray ): # and self.global_step == self.opt.upsample_model_steps[0]: self.model_stu.shrink_model() # adaptive voxel size from aabb_train n_vox = self.upsample_resolutions.pop(0) ** 3 # n_voxels aabb = self.model_stu.aabb_train.cpu().numpy() vox_size = np.cbrt(np.prod(aabb[3:] - aabb[:3]) / n_vox) reso = ((aabb[3:] - aabb[:3]) / vox_size).astype(np.int32).tolist() self.log( f"[INFO] upsample model at step {self.global_step} from {self.model_stu.resolution} to {reso}" ) self.model_stu.upsample_model(reso) # reset optimizer since params changed.
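The adaptive-resolution step above turns a target voxel count into a per-axis grid size: compute the cube-root voxel edge length from the bounding-box volume, then divide each axis extent by it. A pure-stdlib sketch of the same arithmetic (hypothetical helper, mirroring the numpy code without depending on it):

```python
def adaptive_resolution(aabb, n_voxels):
    """Split an axis-aligned bounding box (xmin, ymin, zmin, xmax, ymax, zmax)
    into roughly n_voxels cubes and return the per-axis grid resolution."""
    extent = [aabb[i + 3] - aabb[i] for i in range(3)]
    volume = extent[0] * extent[1] * extent[2]
    vox_size = (volume / n_voxels) ** (1.0 / 3.0)  # cube root, like np.cbrt
    return [int(e / vox_size) for e in extent]

print(adaptive_resolution((0, 0, 0, 1, 2, 4), 8))  # → [1, 2, 4]
```

Non-cubic boxes therefore get proportionally more voxels along their long axes, which is why the resolution is a list rather than a single number.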
self.optimizer = self.optimizer_fn(self.model_stu) self.lr_scheduler = self.lr_scheduler_fn(self.optimizer) if self.ema is not None: self.ema.update() average_loss = total_loss / self.local_step self.stats["loss"].append(average_loss) if self.local_rank == 0: pbar.close() if self.report_metric_at_train: for metric in self.metrics: self.log(metric.report(), style="red") if self.use_tensorboardX: metric.write(self.writer, self.epoch, prefix="train") metric.clear() if not self.scheduler_update_every_step: if isinstance( self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau ): self.lr_scheduler.step(average_loss) else: self.lr_scheduler.step() psnr_tool.psnr_list.sort() if self.global_step < self.args.stage_iters["stage1"]: self.log( f"tttttttttt> Train stage1 Epoch:{self.epoch}. loss_fea:{total_loss_fea_sc/self.local_step:.6f}" ) elif self.global_step < self.args.stage_iters["stage2"]: self.log( f"tttttttttt> Train stage2 Epoch:{self.epoch}. loss_fea_sc:{total_loss_fea_sc/self.local_step:.3f} loss_sigma:{total_loss_sigma/self.local_step:.3f} loss_color:{total_loss_color/self.local_step:.3f}" ) else: self.log( f"tttttttttt> Train stage3 Epoch:{self.epoch}. loss_rgb:{total_loss_rgb/self.local_step:.3f} loss_fea_sc:{total_loss_fea_sc/self.local_step:.3f} loss_sigma:{total_loss_sigma/self.local_step:.3f} loss_color:{total_loss_color/self.local_step:.3f}" ) self.log( f"tttttttttt> Train PSNR Epoch {self.epoch}.
psnr_min:{psnr_tool.psnr_list[0]:.3f} psnr_max:{psnr_tool.psnr_list[-1]:.3f} psnr_mean:{np.mean(psnr_tool.psnr_list):.3f}" ) def get_loss(self, pred, gt): if self.opt.loss_type == "L2": loss = torch.mean((gt - pred) ** 2) elif self.opt.loss_type == "normL2": loss = torch.norm(pred - gt) elif self.opt.loss_type == "normL1": loss = torch.norm(pred - gt, p=1) elif self.opt.loss_type == "smoothL1": loss = torch.nn.functional.smooth_l1_loss(pred, gt, beta=0.05) else: raise ValueError("error loss_type") return loss def train_step(self, data): rays_o = data["rays_o"] # [B, N, 3] rays_d = data["rays_d"] # [B, N, 3] [1, N=rays_num=4096, 3] loss = 0.0 # if there is no gt image, we train with CLIP loss. if "images" not in data: assert 1 == 2 B, N = rays_o.shape[:2] H, W = data["H"], data["W"] # currently fix white bg, MUST force all rays! outputs = self.model.render( rays_o, rays_d, staged=False, bg_color=None, perturb=True, force_all_rays=True, **vars(self.opt), ) pred_rgb = ( outputs["image"].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() ) loss = self.clip_loss(pred_rgb) return pred_rgb, None, loss images = data["images"] # [B, N, 3/4] B, N, C = images.shape # if self.opt.color_space == 'linear': # images[..., :3] = srgb_to_linear(images[..., :3]) if ( C == 3 or self.model_stu.bg_radius > 0 ): # C=4 in synthetic dataset. C=3 for real dataset bg_color = 1 # train with random background color if not using a bg model and has alpha channel. 
else: bg_color = torch.rand( [B, rays_o.size(1), 3], dtype=images.dtype, device=images.device ) if self.opt.render_stu_first: outputs_stu = self.model_stu.render( rays_o, rays_d, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False, **vars(self.opt), ) pred_rgb_stu = outputs_stu["image"] with torch.no_grad(): outputs_tea = self.model_tea.render( rays_o, rays_d, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False, inherited_params=outputs_stu["inherited_params"], **vars(self.opt), ) pred_rgb_tea = outputs_tea["image"] else: with torch.no_grad(): outputs_tea = self.model_tea.render( rays_o, rays_d, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False, **vars(self.opt), ) pred_rgb_tea = outputs_tea["image"] outputs_stu = self.model_stu.render( rays_o, rays_d, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False, inherited_params=outputs_tea["inherited_params"], **vars(self.opt), ) pred_rgb_stu = outputs_stu["image"] gt_rgb = pred_rgb_tea self.opt.loss_rate_fea_sc = update_loss_rate(self.opt.loss_rate_fea_sc, 0.995) if ( "stage1" in outputs_stu and self.opt.loss_rate_fea_sc > 0.0 and self.model_stu.feature_sigma_color is not None and self.model_tea.feature_sigma_color is not None ): assert ( self.model_stu.feature_sigma_color.shape == self.model_tea.feature_sigma_color.shape ) loss_fea_sc = self.get_loss( self.model_stu.feature_sigma_color, self.model_tea.feature_sigma_color ) loss = loss + self.opt.loss_rate_fea_sc * loss_fea_sc return None, None, loss, 0, loss_fea_sc.detach().item(), 0, 0 if "stage2" in outputs_stu: if self.opt.loss_rate_color > 0.0: assert self.model_stu.color_l.shape == self.model_tea.color_l.shape loss_color = self.get_loss( self.model_stu.color_l, self.model_tea.color_l ) loss = loss + self.opt.loss_rate_color * loss_color else: assert self.model_stu.color_l.shape == self.model_tea.color_l.shape loss_color = self.get_loss( self.model_stu.color_l, self.model_tea.color_l ) if 
self.opt.loss_rate_sigma > 0.0: assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape loss_sigma = self.get_loss( self.model_stu.sigma_l, self.model_tea.sigma_l ) loss = loss + self.opt.loss_rate_sigma * loss_sigma else: assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape loss_sigma = self.get_loss( self.model_stu.sigma_l, self.model_tea.sigma_l ) if ( self.opt.loss_rate_fea_sc > 0.0 and self.model_stu.feature_sigma_color is not None and self.model_tea.feature_sigma_color is not None ): assert ( self.model_stu.feature_sigma_color.shape == self.model_tea.feature_sigma_color.shape ) loss_fea_sc = self.get_loss( self.model_stu.feature_sigma_color, self.model_tea.feature_sigma_color, ) loss = loss + self.opt.loss_rate_fea_sc * loss_fea_sc else: loss_fea_sc = torch.tensor(0.0) return ( None, None, loss, 0, loss_fea_sc.detach().item(), loss_color.detach().item(), loss_sigma.detach().item(), ) if self.opt.loss_type == "normL2": loss_rgb = torch.norm(pred_rgb_tea - pred_rgb_stu) elif self.opt.loss_type == "normL1": loss_rgb = torch.norm(pred_rgb_tea - pred_rgb_stu, p=1) elif self.opt.loss_type == "L2": loss_rgb = self.criterion(pred_rgb_tea, pred_rgb_stu).mean( -1 ) # [B, N, 3] --> [B, N] if len(loss_rgb.shape) == 3: # [K, B, N] loss_rgb = loss_rgb.mean(0) if self.error_map is not None: index = data["index"] # [B] inds = data["inds_coarse"] # [B, N] error_map = self.error_map[index] # [B, H * W] error = loss_rgb.detach().to( error_map.device ) # [B, N], already in [0, 1] ema_error = 0.1 * error_map.gather(1, inds) + 0.9 * error # ema update error_map.scatter_(1, inds, ema_error) self.error_map[index] = error_map # put back loss_rgb = loss_rgb.mean() else: raise ValueError("error loss_type") loss = loss + loss_rgb * self.opt.loss_rate_rgb if self.opt.l1_reg_weight > 0.0 and self.opt.model_type == "vm": loss = loss + self.model_stu.density_loss() * self.opt.l1_reg_weight if ( self.opt.loss_rate_fea_sc > 0.0 and self.model_stu.feature_sigma_color 
is not None and self.model_tea.feature_sigma_color is not None ): assert ( self.model_stu.feature_sigma_color.shape == self.model_tea.feature_sigma_color.shape ) loss_fea_sc = self.get_loss( self.model_stu.feature_sigma_color, self.model_tea.feature_sigma_color ) loss = loss + self.opt.loss_rate_fea_sc * loss_fea_sc elif ( self.model_stu.feature_sigma_color is None or self.model_tea.feature_sigma_color is None ): loss_fea_sc = torch.tensor(0.0) else: assert ( self.model_stu.feature_sigma_color.shape == self.model_tea.feature_sigma_color.shape ) loss_fea_sc = self.get_loss( self.model_stu.feature_sigma_color, self.model_tea.feature_sigma_color ) if self.opt.loss_rate_color > 0.0: assert self.model_stu.color_l.shape == self.model_tea.color_l.shape loss_color = self.get_loss(self.model_stu.color_l, self.model_tea.color_l) loss = loss + self.opt.loss_rate_color * loss_color else: assert self.model_stu.color_l.shape == self.model_tea.color_l.shape loss_color = self.get_loss(self.model_stu.color_l, self.model_tea.color_l) if self.opt.loss_rate_sigma > 0.0: assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape loss_sigma = self.get_loss(self.model_stu.sigma_l, self.model_tea.sigma_l) loss = loss + self.opt.loss_rate_sigma * loss_sigma else: assert self.model_stu.sigma_l.shape == self.model_tea.sigma_l.shape loss_sigma = self.get_loss(self.model_stu.sigma_l, self.model_tea.sigma_l) loss_rgb_show = self.criterion( pred_rgb_tea.detach(), pred_rgb_stu.detach() ).mean() # [B, N, 3] --> [B, N] return ( pred_rgb_stu, gt_rgb, loss, loss_rgb_show.detach().item(), loss_fea_sc.detach().item(), loss_color.detach().item(), loss_sigma.detach().item(), ) ### ------------------------------ def evaluate(self, loader, name=None): self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX self.evaluate_one_epoch(loader, name) self.use_tensorboardX = use_tensorboardX def evaluate_one_epoch(self, loader, name=None): if name is None: name = 
f"{self.name}_ep{self.epoch:04d}" total_loss = 0 if self.local_rank == 0: for metric in self.metrics: metric.clear() if self.opt.test_teacher: self.model_stu = self.model_tea self.model_stu.eval() if self.ema is not None: self.ema.store() self.ema.copy_to() if self.local_rank == 0: pbar = tqdm.tqdm( total=len(loader) * loader.batch_size, bar_format="{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", ) with torch.no_grad(): self.local_step = 0 self.ssim = 0.0 self.lpips_vgg = 0.0 self.lpips_alex = 0.0 # update grid if self.model_stu.cuda_ray: with torch.cuda.amp.autocast(enabled=self.fp16): if self.opt.update_stu_extra: self.model_stu.update_extra_state() else: pass frames = [] frames_depth = [] for data in loader: self.local_step += 1 with torch.cuda.amp.autocast(enabled=self.fp16): preds, preds_depth, truths, loss = self.eval_step(data) # all_gather/reduce the statistics (NCCL only support all_*) if self.world_size > 1: dist.all_reduce(loss, op=dist.ReduceOp.SUM) loss = loss / self.world_size preds_list = [ torch.zeros_like(preds).to(self.device) for _ in range(self.world_size) ] # [[B, ...], [B, ...], ...] dist.all_gather(preds_list, preds) preds = torch.cat(preds_list, dim=0) preds_depth_list = [ torch.zeros_like(preds_depth).to(self.device) for _ in range(self.world_size) ] # [[B, ...], [B, ...], ...] dist.all_gather(preds_depth_list, preds_depth) preds_depth = torch.cat(preds_depth_list, dim=0) truths_list = [ torch.zeros_like(truths).to(self.device) for _ in range(self.world_size) ] # [[B, ...], [B, ...], ...] 
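The evaluation loop accumulates per-image SSIM/LPIPS sums and divides by `local_step` at the end, the same accumulate-then-average pattern `PSNRMeter` uses for PSNR (`-10 * log10(mse)`). A minimal pure-Python sketch of that pattern (illustrative class, not the repo's `PSNRMeter`):

```python
import math

class RunningPSNR:
    """Running-average PSNR meter: feed per-image MSE values,
    read back the mean PSNR over everything seen so far."""
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, mse):
        self.total += -10.0 * math.log10(mse)  # PSNR in dB for images in [0, 1]
        self.count += 1

    def measure(self):
        return self.total / self.count

meter = RunningPSNR()
meter.update(0.01)    # one image at ~20 dB
meter.update(0.0001)  # one image at ~40 dB
print(meter.measure())  # ≈ 30.0
```

Averaging the per-image PSNRs (rather than the PSNR of the pooled MSE) matches how the Trainer reports its metrics.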
dist.all_gather(truths_list, truths) truths = torch.cat(truths_list, dim=0) loss_val = loss.item() total_loss += loss_val if self.local_rank == 0: for metric in self.metrics: metric.update(preds, truths) self.lpips_alex += rgb_lpips(truths, preds, "alex") self.lpips_vgg += rgb_lpips(truths, preds, "vgg") self.ssim += compute_ssim( preds, truths, max_val=max(preds.max().item(), truths.max().item()), ).item() # save image save_path = os.path.join( self.workspace, loader._data.type, f"{name}_{self.local_step:04d}.png", ) save_path_depth = os.path.join( self.workspace, loader._data.type, f"{name}_{self.local_step:04d}_depth.png", ) # save_path_gt = os.path.join(self.workspace, loader._data.type, f'{name}_{self.local_step:04d}_gt.png') os.makedirs(os.path.dirname(save_path), exist_ok=True) if self.opt.color_space == "linear": preds = linear_to_srgb(preds) pred = preds[0].detach().cpu().numpy() truth = truths[0].detach().cpu().numpy() pred_depth = preds_depth[0].detach().cpu().numpy() cv2.imwrite( save_path, cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR), ) cv2.imwrite(save_path_depth, (pred_depth * 255).astype(np.uint8)) frames.append((pred * 255).astype(np.uint8)) frames_depth.append((pred_depth * 255).astype(np.uint8)) pbar.set_description( f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})" ) pbar.update(loader.batch_size) print( f"\n----video num(frames): {len(frames)} depth video num:{len(frames_depth)} ----\n" ) imageio.mimwrite( os.path.join(os.path.dirname(save_path), "video.mp4"), frames, fps=int(30 * 0.7), macro_block_size=8, ) imageio.mimwrite( os.path.join(os.path.dirname(save_path), "video_depth.mp4"), frames_depth, fps=int(30 * 0.7), macro_block_size=8, ) psnr_tool = self.metrics[0] psnr_tool.psnr_list.sort() self.log( f"\neeeeeeeee> {loader._data.type} PSNR Report: Epoch{self.epoch}.
psnr_mean:{np.mean(psnr_tool.psnr_list):.2f}" ) average_loss = total_loss / self.local_step self.stats["valid_loss"].append(average_loss) if self.local_rank == 0: pbar.close() if not self.use_loss_as_metric and len(self.metrics) > 0: result = self.metrics[0].measure() self.stats["results"].append( result if self.best_mode == "min" else -result ) # if max mode, use -result else: self.stats["results"].append( average_loss ) # if no metric, choose best by min loss for metric in self.metrics: # self.log(metric.report(), style="blue") psnr = metric.report().split("=")[-1].strip()[:5] self.psnr = float(psnr) if self.use_tensorboardX and loader._data.type == 'val': metric.write(self.writer, self.epoch, prefix="evaluate") metric.clear() self.ssim /= self.local_step self.lpips_alex /= self.local_step self.lpips_vgg /= self.local_step if self.ema is not None: self.ema.restore() self.log( f"eeeeeeeeee> {loader._data.type} Metric Report: Epoch{self.epoch}. psnr:{psnr} ssim:{self.ssim:.2f} alex:{self.lpips_alex:.2f} vgg:{self.lpips_vgg:.2f}" ) def eval_step(self, data): rays_o = data["rays_o"] # [B, N, 3] rays_d = data["rays_d"] # [B, N, 3] images = data["images"] # [B, H, W, 3/4] B, H, W, C = images.shape if self.opt.color_space == "linear": images[..., :3] = srgb_to_linear(images[..., :3]) # eval with fixed background color bg_color = 1 if C == 4: gt_rgb = images[..., :3] * images[..., 3:] + bg_color * ( 1 - images[..., 3:] ) else: gt_rgb = images outputs = self.model_stu.render( rays_o, rays_d, staged=True, bg_color=bg_color, perturb=False, **vars(self.opt), ) pred_rgb = outputs["image"].reshape(B, H, W, 3) pred_depth = outputs["depth"].reshape(B, H, W) loss = self.criterion(pred_rgb, gt_rgb).mean() return pred_rgb, pred_depth, gt_rgb, loss def save_checkpoint(self, name=None, full=False, best=False, remove_old=True): full = False if name is None: name = f"{self.name}_ep{self.epoch:04d}" if self.opt.model_type == "vm": state = { "epoch": self.epoch, "global_step": 
self.global_step, "stats": self.stats, "resolution": self.model_stu.resolution, } else: state = { "epoch": self.epoch, "global_step": self.global_step, "stats": self.stats, } if self.model_stu.cuda_ray: state["mean_count"] = self.model_stu.mean_count state["mean_density"] = self.model_stu.mean_density if full: state["optimizer"] = self.optimizer.state_dict() state["lr_scheduler"] = self.lr_scheduler.state_dict() state["scaler"] = self.scaler.state_dict() if self.ema is not None: state["ema"] = self.ema.state_dict() if not best: state["model"] = self.model_stu.state_dict() file_path = f"{self.ckpt_path}/{name}.pth" if remove_old: self.stats["checkpoints"].append(file_path) if len(self.stats["checkpoints"]) > self.max_keep_ckpt: old_ckpt = self.stats["checkpoints"].pop(0) if os.path.exists(old_ckpt): os.remove(old_ckpt) torch.save(state, file_path) else: if len(self.stats["results"]) > 0: if ( self.stats["best_result"] is None or self.stats["results"][-1] < self.stats["best_result"] ): self.log( f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}" ) self.stats["best_result"] = self.stats["results"][-1] # save ema results if self.ema is not None: self.ema.store() self.ema.copy_to() state["model"] = self.model_stu.state_dict() if self.ema is not None: self.ema.restore() torch.save(state, self.best_path) else: self.log( f"[WARN] no evaluated results found, skip saving best checkpoint." 
) def load_teacher_checkpoint(self): checkpoint_dict = torch.load(self.opt.ckpt_teacher, map_location=self.device) missing_keys, unexpected_keys = self.model_tea.load_state_dict( checkpoint_dict["model"], strict=False ) self.log("[INFO] loaded teacher model.") if len(missing_keys) > 0: self.log(f"[WARN] missing keys: {missing_keys}") if len(unexpected_keys) > 0: self.log(f"[WARN] unexpected keys: {unexpected_keys}") if self.ema is not None and "ema" in checkpoint_dict: self.ema.load_state_dict(checkpoint_dict["ema"]) if self.model_tea.cuda_ray: if "mean_count" in checkpoint_dict: self.model_tea.mean_count = checkpoint_dict["mean_count"] if "mean_density" in checkpoint_dict: self.model_tea.mean_density = checkpoint_dict["mean_density"] """ self.stats = checkpoint_dict['stats'] self.epoch = checkpoint_dict['epoch'] self.global_step = checkpoint_dict['global_step'] self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}") if self.optimizer and 'optimizer' in checkpoint_dict: try: self.optimizer.load_state_dict(checkpoint_dict['optimizer']) self.log("[INFO] loaded optimizer.") except: self.log("[WARN] Failed to load optimizer.") if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict: try: self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler']) self.log("[INFO] loaded scheduler.") except: self.log("[WARN] Failed to load scheduler.") if self.scaler and 'scaler' in checkpoint_dict: try: self.scaler.load_state_dict(checkpoint_dict['scaler']) self.log("[INFO] loaded scaler.") except: self.log("[WARN] Failed to load scaler.") if self.model_tea.cuda_ray: if 'mean_count' in checkpoint_dict: self.model_tea.mean_count = checkpoint_dict['mean_count'] if 'mean_density' in checkpoint_dict: self.model_tea.mean_density = checkpoint_dict['mean_density'] """ def load_student_checkpoint(self): if self.opt.ckpt_student: checkpoint_dict = torch.load( self.opt.ckpt_student, map_location=self.device ) else: checkpoint_dict = torch.load( 
self.opt.ckpt_teacher, map_location=self.device ) if self.opt.model_type == "vm" and "resolution" in checkpoint_dict: self.model_stu.upsample_model(checkpoint_dict["resolution"]) missing_keys, unexpected_keys = self.model_stu.load_state_dict( checkpoint_dict["model"], strict=False ) self.log("[INFO] loaded student model.") if len(missing_keys) > 0: self.log(f"[WARN] missing keys: {missing_keys}") if len(unexpected_keys) > 0: self.log(f"[WARN] unexpected keys: {unexpected_keys}") if self.model_stu.cuda_ray: if "mean_count" in checkpoint_dict: self.model_stu.mean_count = checkpoint_dict["mean_count"] if "mean_density" in checkpoint_dict: self.model_stu.mean_density = checkpoint_dict["mean_density"] if self.ema is not None and "ema" in checkpoint_dict: self.ema.load_state_dict(checkpoint_dict["ema"]) """ self.stats = checkpoint_dict['stats'] self.epoch = checkpoint_dict['epoch'] self.global_step = checkpoint_dict['global_step'] self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}") if self.optimizer and 'optimizer' in checkpoint_dict: try: self.optimizer.load_state_dict(checkpoint_dict['optimizer']) self.log("[INFO] loaded optimizer.") except: self.log("[WARN] Failed to load optimizer.") if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict: try: self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler']) self.log("[INFO] loaded scheduler.") except: self.log("[WARN] Failed to load scheduler.") if self.scaler and 'scaler' in checkpoint_dict: try: self.scaler.load_state_dict(checkpoint_dict['scaler']) self.log("[INFO] loaded scaler.") except: self.log("[WARN] Failed to load scaler.") """ def load_checkpoint(self, checkpoint=None, model_only=False): if checkpoint is None: checkpoint_list = sorted(glob.glob(f"{self.ckpt_path}/{self.name}_ep*.pth")) if checkpoint_list: checkpoint = checkpoint_list[-1] self.log(f"[INFO] Latest checkpoint is {checkpoint}") else: self.log("[WARN] No checkpoint found, model randomly initialized.") return 
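The rolling-checkpoint policy in `save_checkpoint` above (append the new file path to `stats["checkpoints"]`, then delete the oldest files beyond `max_keep_ckpt`) can be sketched in isolation. A minimal sketch; `prune_checkpoints` is a hypothetical standalone helper, not part of the repository:

```python
import os

def prune_checkpoints(ckpt_paths, new_path, max_keep):
    """Append new_path, then drop the oldest entries beyond max_keep,
    deleting the corresponding files if they exist (mirrors the
    remove_old branch of save_checkpoint)."""
    ckpt_paths.append(new_path)
    removed = []
    while len(ckpt_paths) > max_keep:
        old = ckpt_paths.pop(0)  # oldest checkpoint first
        removed.append(old)
        if os.path.exists(old):
            os.remove(old)
    return removed
```

With `max_keep=2`, adding a third path evicts the first, matching the `self.max_keep_ckpt` behavior above.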
checkpoint_dict = torch.load(checkpoint, map_location=self.device) if "model" not in checkpoint_dict: self.model.load_state_dict(checkpoint_dict) self.log("[INFO] loaded model.") return missing_keys, unexpected_keys = self.model.load_state_dict( checkpoint_dict["model"], strict=False ) self.log("[INFO] loaded model.") if len(missing_keys) > 0: self.log(f"[WARN] missing keys: {missing_keys}") if len(unexpected_keys) > 0: self.log(f"[WARN] unexpected keys: {unexpected_keys}") if self.ema is not None and "ema" in checkpoint_dict: self.ema.load_state_dict(checkpoint_dict["ema"]) if self.model.cuda_ray: if "mean_count" in checkpoint_dict: self.model.mean_count = checkpoint_dict["mean_count"] if "mean_density" in checkpoint_dict: self.model.mean_density = checkpoint_dict["mean_density"] if model_only: return self.stats = checkpoint_dict["stats"] self.epoch = checkpoint_dict["epoch"] self.global_step = checkpoint_dict["global_step"] self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}") if self.optimizer and "optimizer" in checkpoint_dict: try: self.optimizer.load_state_dict(checkpoint_dict["optimizer"]) self.log("[INFO] loaded optimizer.") except: self.log("[WARN] Failed to load optimizer.") if self.lr_scheduler and "lr_scheduler" in checkpoint_dict: try: self.lr_scheduler.load_state_dict(checkpoint_dict["lr_scheduler"]) self.log("[INFO] loaded scheduler.") except: self.log("[WARN] Failed to load scheduler.") if self.scaler and "scaler" in checkpoint_dict: try: self.scaler.load_state_dict(checkpoint_dict["scaler"]) self.log("[INFO] loaded scaler.") except: self.log("[WARN] Failed to load scaler.") def test(self, loader, save_path=None, name=None): assert 1 == 2 if save_path is None: save_path = os.path.join(self.workspace, "results") if name is None: name = f"{self.name}_ep{self.epoch:04d}" os.makedirs(save_path, exist_ok=True) self.log(f"==> Start Test, save results to {save_path}") pbar = tqdm.tqdm( total=len(loader) * loader.batch_size, 
bar_format="{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", ) self.model_stu.eval() with torch.no_grad(): # update grid if self.model_stu.cuda_ray: with torch.cuda.amp.autocast(enabled=self.fp16): self.model_stu.update_extra_state() for i, data in enumerate(loader): with torch.cuda.amp.autocast(enabled=self.fp16): preds, preds_depth = self.test_step(data) path = os.path.join(save_path, f"{name}_{i:04d}.png") path_depth = os.path.join(save_path, f"{name}_{i:04d}_depth.png") # self.log(f"[INFO] saving test image to {path}") if self.opt.color_space == "linear": preds = linear_to_srgb(preds) pred = preds[0].detach().cpu().numpy() pred_depth = preds_depth[0].detach().cpu().numpy() cv2.imwrite( path, cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR) ) cv2.imwrite(path_depth, (pred_depth * 255).astype(np.uint8)) pbar.update(loader.batch_size) self.log(f"==> Finished Test.") # moved out bg_color and perturb for more flexible control... def test_step(self, data, bg_color=None, perturb=False): rays_o = data["rays_o"] # [B, N, 3] rays_d = data["rays_d"] # [B, N, 3] H, W = data["H"], data["W"] if bg_color is not None: bg_color = bg_color.to(self.device) outputs = self.model_stu.render( rays_o, rays_d, staged=True, bg_color=bg_color, perturb=perturb, **vars(self.opt), ) pred_rgb = outputs["image"].reshape(-1, H, W, 3) pred_depth = outputs["depth"].reshape(-1, H, W) return pred_rgb, pred_depth ================================================ FILE: gridencoder/__init__.py ================================================ from .grid import GridEncoder ================================================ FILE: gridencoder/backend.py ================================================ import os from torch.utils.cpp_extension import load _src_path = os.path.dirname(os.path.abspath(__file__)) nvcc_flags = [ "-O3", "-std=c++14", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", ] if os.name == "posix": 
c_flags = ["-O3", "-std=c++14"] elif os.name == "nt": c_flags = ["/O2", "/std:c++17"] # find cl.exe def find_cl_path(): import glob for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: paths = sorted( glob.glob( r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition ), reverse=True, ) if paths: return paths[0] # If cl.exe is not on path, try to find it. if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError( "Could not locate a supported Microsoft Visual C++ installation" ) os.environ["PATH"] += ";" + cl_path _backend = load( name="_grid_encoder", extra_cflags=c_flags, extra_cuda_cflags=nvcc_flags, sources=[ os.path.join(_src_path, "src", f) for f in [ "gridencoder.cu", "bindings.cpp", ] ], ) __all__ = ["_backend"] ================================================ FILE: gridencoder/grid.py ================================================ import numpy as np import torch import torch.nn as nn from torch.autograd import Function from torch.autograd.function import once_differentiable from torch.cuda.amp import custom_bwd, custom_fwd try: import _gridencoder as _backend except ImportError: from .backend import _backend _gridtype_to_id = { "hash": 0, "tiled": 1, } class _grid_encode(Function): @staticmethod @custom_fwd def forward( ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, ): # inputs: [B, D], float in [0, 1] # embeddings: [sO, C], float # offsets: [L + 1], int # RETURN: [B, F], float inputs = inputs.contiguous() B, D = inputs.shape # batch size, coord dim L = offsets.shape[0] - 1 # level C = embeddings.shape[1] # embedding dim for each level S = np.log2( per_level_scale ) # resolution multiplier at each level, apply log2 for later CUDA exp2f H = base_resolution # base resolution # manually handle autocast (only use half precision embeddings, inputs must be float 
for enough precision) # if C % 2 != 0, force float, since half for atomicAdd is very slow. if torch.is_autocast_enabled() and C % 2 == 0: embeddings = embeddings.to(torch.half) # L first, optimize cache for cuda kernel, but needs an extra permute later outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) if calc_grad_inputs: dy_dx = torch.empty( B, L * D * C, device=inputs.device, dtype=embeddings.dtype ) else: dy_dx = torch.empty( 1, device=inputs.device, dtype=embeddings.dtype ) # placeholder... TODO: a better way? _backend.grid_encode_forward( inputs, embeddings, offsets, outputs, B, D, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners, ) # permute back to [B, L * C] outputs = outputs.permute(1, 0, 2).reshape(B, L * C) ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) ctx.dims = [B, D, C, L, S, H, gridtype] ctx.calc_grad_inputs = calc_grad_inputs ctx.align_corners = align_corners return outputs @staticmethod # @once_differentiable @custom_bwd def backward(ctx, grad): inputs, embeddings, offsets, dy_dx = ctx.saved_tensors B, D, C, L, S, H, gridtype = ctx.dims calc_grad_inputs = ctx.calc_grad_inputs align_corners = ctx.align_corners # grad: [B, L * C] --> [L, B, C] grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() grad_embeddings = torch.zeros_like(embeddings) if calc_grad_inputs: grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) else: grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype) _backend.grid_encode_backward( grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners, ) if calc_grad_inputs: grad_inputs = grad_inputs.to(inputs.dtype) return grad_inputs, grad_embeddings, None, None, None, None, None, None else: return None, grad_embeddings, None, None, None, None, None, None grid_encode = _grid_encode.apply class GridEncoder(nn.Module): def __init__( self, input_dim=3, num_levels=16, level_dim=2, 
per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None, gridtype="hash", align_corners=False, ): super().__init__() # the finest resolution desired at the last level, if provided, overrides per_level_scale if desired_resolution is not None: per_level_scale = np.exp2( np.log2(desired_resolution / base_resolution) / (num_levels - 1) ) self.input_dim = input_dim # coord dims, 2 or 3 self.num_levels = num_levels # num levels, each level multiplies resolution by per_level_scale self.level_dim = level_dim # encode channels per level self.per_level_scale = ( per_level_scale # multiply resolution by this scale at each level. ) self.log2_hashmap_size = log2_hashmap_size self.base_resolution = base_resolution self.output_dim = num_levels * level_dim self.gridtype = gridtype self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" self.align_corners = align_corners # allocate parameters offsets = [] offset = 0 self.max_params = 2 ** log2_hashmap_size for i in range(num_levels): resolution = int(np.ceil(base_resolution * per_level_scale ** i)) params_in_level = min( self.max_params, (resolution if align_corners else resolution + 1) ** input_dim, ) # limit max number params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible by 8 offsets.append(offset) offset += params_in_level offsets.append(offset) offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) self.register_buffer("offsets", offsets) self.n_params = offsets[-1] * level_dim # parameters self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) self.reset_parameters() def reset_parameters(self): std = 1e-4 self.embeddings.data.uniform_(-std, std) def __repr__(self): return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)}
gridtype={self.gridtype} align_corners={self.align_corners}" def forward(self, inputs, bound=1): # inputs: [..., input_dim], normalized real world positions in [-bound, bound] # return: [..., num_levels * level_dim] inputs = (inputs + bound) / (2 * bound) # map to [0, 1] # print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) prefix_shape = list(inputs.shape[:-1]) inputs = inputs.view(-1, self.input_dim) outputs = grid_encode( inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, ) outputs = outputs.view(prefix_shape + [self.output_dim]) # print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) return outputs ================================================ FILE: gridencoder/setup.py ================================================ import os from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension _src_path = os.path.dirname(os.path.abspath(__file__)) nvcc_flags = [ "-O3", "-std=c++14", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", ] if os.name == "posix": c_flags = ["-O3", "-std=c++14"] elif os.name == "nt": c_flags = ["/O2", "/std:c++17"] # find cl.exe def find_cl_path(): import glob for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: paths = sorted( glob.glob( r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition ), reverse=True, ) if paths: return paths[0] # If cl.exe is not on path, try to find it. 
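The level geometry set up in `GridEncoder.__init__` (a geometric resolution schedule derived from `desired_resolution`, with per-level parameter counts capped at `2**log2_hashmap_size` and padded to a multiple of 8) can be recomputed without torch. A minimal sketch using the defaults above; `grid_level_offsets` is a hypothetical name for illustration:

```python
import math

def grid_level_offsets(num_levels=16, base_resolution=16,
                       log2_hashmap_size=19, desired_resolution=2048,
                       input_dim=3, align_corners=False):
    """Recompute the per-level parameter offsets the way
    GridEncoder.__init__ does: geometric resolution growth, capped by
    the hash-table size, padded to a multiple of 8."""
    per_level_scale = 2 ** (math.log2(desired_resolution / base_resolution)
                            / (num_levels - 1))
    max_params = 2 ** log2_hashmap_size
    offsets, offset = [], 0
    for i in range(num_levels):
        resolution = int(math.ceil(base_resolution * per_level_scale ** i))
        side = resolution if align_corners else resolution + 1
        params_in_level = min(max_params, side ** input_dim)   # cap at table size
        params_in_level = int(math.ceil(params_in_level / 8) * 8)  # pad to /8
        offsets.append(offset)
        offset += params_in_level
    offsets.append(offset)
    return per_level_scale, offsets
```

With the defaults, the first (16^3) level is uncapped while the finest levels all hit the `2**19` hash-table cap.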
if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError( "Could not locate a supported Microsoft Visual C++ installation" ) os.environ["PATH"] += ";" + cl_path setup( name="gridencoder", # package name, import this to use python API ext_modules=[ CUDAExtension( name="_gridencoder", # extension name, import this to use CUDA API sources=[ os.path.join(_src_path, "src", f) for f in [ "gridencoder.cu", "bindings.cpp", ] ], extra_compile_args={ "cxx": c_flags, "nvcc": nvcc_flags, }, ), ], cmdclass={ "build_ext": BuildExtension, }, ) ================================================ FILE: gridencoder/src/bindings.cpp ================================================ #include #include "gridencoder.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); } ================================================ FILE: gridencoder/src/gridencoder.cu ================================================ #include #include #include #include #include #include #include #include #include #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") // just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) { // requires CUDA >= 10 and ARCH >= 70 // this is very slow compared to float or __half2, and never used. 
//return atomicAdd(reinterpret_cast<__half*>(address), val); } template static inline __host__ __device__ T div_round_up(T val, T divisor) { return (val + divisor - 1) / divisor; } template __device__ uint32_t fast_hash(const uint32_t pos_grid[D]) { static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions."); // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional // coordinates. constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 }; uint32_t result = 0; #pragma unroll for (uint32_t i = 0; i < D; ++i) { result ^= pos_grid[i] * primes[i]; } return result; } template __device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) { uint32_t stride = 1; uint32_t index = 0; #pragma unroll for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) { index += pos_grid[d] * stride; stride *= align_corners ? resolution: (resolution + 1); } // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97. 
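For reference, the spatial hash used by `fast_hash` above (XOR of each coordinate multiplied by a per-dimension prime, in 32-bit unsigned arithmetic) can be mirrored in Python. A sketch for illustration only; the CUDA kernel remains the authoritative implementation:

```python
MASK32 = 0xFFFFFFFF  # emulate uint32_t wraparound
PRIMES = (1, 2654435761, 805459861, 3674653429,
          2097192037, 1434869437, 2165219737)

def fast_hash(pos_grid):
    """XOR-combine coordinate * prime per dimension (up to 7 dims),
    matching the CUDA fast_hash above."""
    assert len(pos_grid) <= 7, "fast_hash can only hash up to 7 dimensions."
    result = 0
    for p, prime in zip(pos_grid, PRIMES):
        result ^= (p * prime) & MASK32  # 32-bit multiply then XOR
    return result & MASK32
```

Because the first "prime" is 1, the hash of a one-dimensional coordinate is the coordinate itself (mod 2^32), which is what gives the memory coherence the comment above mentions.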
// gridtype: 0 == hash, 1 == tiled if (gridtype == 0 && stride > hashmap_size) { index = fast_hash(pos_grid); } return (index % hashmap_size) * C + ch; } template __global__ void kernel_grid( const float * __restrict__ inputs, const scalar_t * __restrict__ grid, const int * __restrict__ offsets, scalar_t * __restrict__ outputs, const uint32_t B, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t * __restrict__ dy_dx, const uint32_t gridtype, const bool align_corners ) { const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; if (b >= B) return; const uint32_t level = blockIdx.y; // locate grid += (uint32_t)offsets[level] * C; inputs += b * D; outputs += level * B * C + b * C; // check input range (should be in [0, 1]) bool flag_oob = false; #pragma unroll for (uint32_t d = 0; d < D; d++) { if (inputs[d] < 0 || inputs[d] > 1) { flag_oob = true; } } // if input out of bound, just set output to 0 if (flag_oob) { #pragma unroll for (uint32_t ch = 0; ch < C; ch++) { outputs[ch] = 0; } if (calc_grad_inputs) { dy_dx += b * D * L * C + level * D * C; // B L D C #pragma unroll for (uint32_t d = 0; d < D; d++) { #pragma unroll for (uint32_t ch = 0; ch < C; ch++) { dy_dx[d * C + ch] = 0; } } } return; } const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; const float scale = exp2f(level * S) * H - 1.0f; const uint32_t resolution = (uint32_t)ceil(scale) + 1; // calculate coordinate float pos[D]; uint32_t pos_grid[D]; #pragma unroll for (uint32_t d = 0; d < D; d++) { pos[d] = inputs[d] * scale + (align_corners ? 
0.0f : 0.5f); pos_grid[d] = floorf(pos[d]); pos[d] -= (float)pos_grid[d]; } //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); // interpolate scalar_t results[C] = {0}; // temp results in register #pragma unroll for (uint32_t idx = 0; idx < (1 << D); idx++) { float w = 1; uint32_t pos_grid_local[D]; #pragma unroll for (uint32_t d = 0; d < D; d++) { if ((idx & (1 << d)) == 0) { w *= 1 - pos[d]; pos_grid_local[d] = pos_grid[d]; } else { w *= pos[d]; pos_grid_local[d] = pos_grid[d] + 1; } } uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); // writing to register (fast) #pragma unroll for (uint32_t ch = 0; ch < C; ch++) { results[ch] += w * grid[index + ch]; } //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]); } // writing to global memory (slow) #pragma unroll for (uint32_t ch = 0; ch < C; ch++) { outputs[ch] = results[ch]; } // prepare dy_dx for calc_grad_inputs // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9 if (calc_grad_inputs) { dy_dx += b * D * L * C + level * D * C; // B L D C #pragma unroll for (uint32_t gd = 0; gd < D; gd++) { scalar_t results_grad[C] = {0}; #pragma unroll for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) { float w = scale; uint32_t pos_grid_local[D]; #pragma unroll for (uint32_t nd = 0; nd < D - 1; nd++) { const uint32_t d = (nd >= gd) ? 
(nd + 1) : nd; if ((idx & (1 << nd)) == 0) { w *= 1 - pos[d]; pos_grid_local[d] = pos_grid[d]; } else { w *= pos[d]; pos_grid_local[d] = pos_grid[d] + 1; } } pos_grid_local[gd] = pos_grid[gd]; uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); pos_grid_local[gd] = pos_grid[gd] + 1; uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); #pragma unroll for (uint32_t ch = 0; ch < C; ch++) { results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]); } } #pragma unroll for (uint32_t ch = 0; ch < C; ch++) { dy_dx[gd * C + ch] = results_grad[ch]; } } } } template __global__ void kernel_grid_backward( const scalar_t * __restrict__ grad, const float * __restrict__ inputs, const scalar_t * __restrict__ grid, const int * __restrict__ offsets, scalar_t * __restrict__ grad_grid, const uint32_t B, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners ) { const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C; if (b >= B) return; const uint32_t level = blockIdx.y; const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C; // locate grad_grid += offsets[level] * C; inputs += b * D; grad += level * B * C + b * C + ch; // L, B, C const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; const float scale = exp2f(level * S) * H - 1.0f; const uint32_t resolution = (uint32_t)ceil(scale) + 1; // check input range (should be in [0, 1]) #pragma unroll for (uint32_t d = 0; d < D; d++) { if (inputs[d] < 0 || inputs[d] > 1) { return; // grad is init as 0, so we simply return. } } // calculate coordinate float pos[D]; uint32_t pos_grid[D]; #pragma unroll for (uint32_t d = 0; d < D; d++) { pos[d] = inputs[d] * scale + (align_corners ? 
0.0f : 0.5f); pos_grid[d] = floorf(pos[d]); pos[d] -= (float)pos_grid[d]; } scalar_t grad_cur[N_C] = {0}; // fetch to register #pragma unroll for (uint32_t c = 0; c < N_C; c++) { grad_cur[c] = grad[c]; } // interpolate #pragma unroll for (uint32_t idx = 0; idx < (1 << D); idx++) { float w = 1; uint32_t pos_grid_local[D]; #pragma unroll for (uint32_t d = 0; d < D; d++) { if ((idx & (1 << d)) == 0) { w *= 1 - pos[d]; pos_grid_local[d] = pos_grid[d]; } else { w *= pos[d]; pos_grid_local[d] = pos_grid[d] + 1; } } uint32_t index = get_grid_index(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local); // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0 // TODO: use float which is better than __half, if N_C % 2 != 0 if (std::is_same::value && N_C % 2 == 0) { #pragma unroll for (uint32_t c = 0; c < N_C; c += 2) { // process two __half at once (by interpreting as a __half2) __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])}; atomicAdd((__half2*)&grad_grid[index + c], v); } // float, or __half when N_C % 2 != 0 (which means C == 1) } else { #pragma unroll for (uint32_t c = 0; c < N_C; c++) { atomicAdd(&grad_grid[index + c], w * grad_cur[c]); } } } } template __global__ void kernel_input_backward( const scalar_t * __restrict__ grad, const scalar_t * __restrict__ dy_dx, scalar_t * __restrict__ grad_inputs, uint32_t B, uint32_t L ) { const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; if (t >= B * D) return; const uint32_t b = t / D; const uint32_t d = t - b * D; dy_dx += b * L * D * C; scalar_t result = 0; # pragma unroll for (int l = 0; l < L; l++) { # pragma unroll for (int ch = 0; ch < C; ch++) { result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; } } grad_inputs[t] = result; } template void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, 
const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) { static constexpr uint32_t N_THREAD = 512; const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; switch (C) { case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; } } // inputs: [B, D], float, in [0, 1] // embeddings: [sO, C], float // offsets: [L + 1], uint32_t // outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.) 
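`kernel_grid` above interpolates over the 2^D corners of the surrounding cell, weighting corner `idx` by the product over dimensions of `pos[d]` (bit set) or `1 - pos[d]` (bit clear). A Python sketch of just that weighting scheme; `multilinear_weights` is a hypothetical helper:

```python
def multilinear_weights(pos):
    """Return {corner_offset: weight} for the 2^D cell corners around a
    point with fractional coordinates pos, as in kernel_grid's idx loop."""
    D = len(pos)
    weights = {}
    for idx in range(1 << D):  # enumerate corners by bitmask
        w = 1.0
        corner = []
        for d in range(D):
            if idx & (1 << d):
                w *= pos[d]          # bit set: far corner along dim d
                corner.append(1)
            else:
                w *= 1.0 - pos[d]    # bit clear: near corner along dim d
                corner.append(0)
        weights[tuple(corner)] = w
    return weights
```

The weights form a partition of unity (they sum to 1), so interpolation reproduces constant grids exactly.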
// H: base resolution // dy_dx: [B, L * D * C] template void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) { switch (D) { case 2: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; case 3: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; default: throw std::runtime_error{"GridEncoding: D (input dim) must be 2 or 3."}; } } template void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) { static constexpr uint32_t N_THREAD = 256; const uint32_t N_C = std::min(2u, C); // n_features_per_thread const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 }; switch (C) { case 1: kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); if (calc_grad_inputs) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); break; case 2: kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); if (calc_grad_inputs) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); break; case 4: kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); if (calc_grad_inputs) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); break; case 8: kernel_grid_backward<<>>(grad, inputs,
embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); if (calc_grad_inputs) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); break; default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; } } // grad: [L, B, C], float // inputs: [B, D], float, in [0, 1] // embeddings: [sO, C], float // offsets: [L + 1], uint32_t // grad_embeddings: [sO, C] // H: base resolution template void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) { switch (D) { case 2: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break; case 3: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break; default: throw std::runtime_error{"GridEncoding: D (input dim) must be 2 or 3."}; } } void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners) { CHECK_CUDA(inputs); CHECK_CUDA(embeddings); CHECK_CUDA(offsets); CHECK_CUDA(outputs); CHECK_CUDA(dy_dx); CHECK_CONTIGUOUS(inputs); CHECK_CONTIGUOUS(embeddings); CHECK_CONTIGUOUS(offsets); CHECK_CONTIGUOUS(outputs); CHECK_CONTIGUOUS(dy_dx); CHECK_IS_FLOATING(inputs); CHECK_IS_FLOATING(embeddings); CHECK_IS_INT(offsets); CHECK_IS_FLOATING(outputs); CHECK_IS_FLOATING(dy_dx); AT_DISPATCH_FLOATING_TYPES_AND_HALF(
embeddings.scalar_type(), "grid_encode_forward", ([&] { grid_encode_forward_cuda(inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), outputs.data_ptr(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr(), gridtype, align_corners); })); } void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners) { CHECK_CUDA(grad); CHECK_CUDA(inputs); CHECK_CUDA(embeddings); CHECK_CUDA(offsets); CHECK_CUDA(grad_embeddings); CHECK_CUDA(dy_dx); CHECK_CUDA(grad_inputs); CHECK_CONTIGUOUS(grad); CHECK_CONTIGUOUS(inputs); CHECK_CONTIGUOUS(embeddings); CHECK_CONTIGUOUS(offsets); CHECK_CONTIGUOUS(grad_embeddings); CHECK_CONTIGUOUS(dy_dx); CHECK_CONTIGUOUS(grad_inputs); CHECK_IS_FLOATING(grad); CHECK_IS_FLOATING(inputs); CHECK_IS_FLOATING(embeddings); CHECK_IS_INT(offsets); CHECK_IS_FLOATING(grad_embeddings); CHECK_IS_FLOATING(dy_dx); CHECK_IS_FLOATING(grad_inputs); AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "grid_encode_backward", ([&] { grid_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr(), grad_inputs.data_ptr(), gridtype, align_corners); })); } ================================================ FILE: gridencoder/src/gridencoder.h ================================================ #ifndef _HASH_ENCODE_H #define _HASH_ENCODE_H #include #include // inputs: [B, D], float, in [0, 1] // embeddings: [sO, C], float // offsets: [L + 1], uint32_t // outputs: [B, L * C], float // H: base resolution void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, 
const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners); void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners); #endif ================================================ FILE: just_train_tea/network.py ================================================ import torch from time import time import torch.nn as nn import torch.nn.functional as F from tools.encoding import get_encoder from tools.activation import trunc_exp from .renderer import NeRFRenderer import raymarching class NeRFNetwork(NeRFRenderer): def __init__( self, encoding="hashgrid", encoding_dir="sphere_harmonics", encoding_bg="hashgrid", num_layers=2, hidden_dim=64, geo_feat_dim=15, num_layers_color=3, hidden_dim_color=64, num_layers_bg=2, hidden_dim_bg=64, bound=1, model_type="hash", args=None, is_teacher=False, **kwargs, ): super().__init__(bound, **kwargs) # sigma network assert model_type in ["hash", "mlp", "vm", "tensors"] self.is_teacher = is_teacher self.num_layers = num_layers self.hidden_dim = hidden_dim self.geo_feat_dim = geo_feat_dim self.args = args self.opt = args self.model_type = model_type self.plenoxel_degree = args.plenoxel_degree self.plenoxel_res = eval(args.plenoxel_res) assert len(self.plenoxel_res) == 3 self.encoder, self.in_dim = get_encoder( encoding, desired_resolution=2048 * bound, num_levels=14, ) if "hash" != self.model_type: self.encoder = None if self.model_type == "mlp": self.encoder_nerf_pe, self.in_dim_nerf = get_encoder( encoding="frequency", multires=self.args.PE ) self.skips = self.args.skip 
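The `mlp` branch above feeds positions through a frequency encoder (`get_encoder(encoding="frequency", multires=self.args.PE)`) before the skip-connected MLP. The NeRF-style encoding can be sketched for a single scalar coordinate; the exact ordering and whether an identity term is included depend on `tools/encoding.py`, so treat this as an illustration only:

```python
import math

def frequency_encode(x, multires):
    """NeRF-style positional encoding of a scalar: for each octave k in
    [0, multires), append sin(2^k * x) and cos(2^k * x)."""
    out = []
    for k in range(multires):
        freq = 2.0 ** k  # octave frequency
        out.append(math.sin(freq * x))
        out.append(math.cos(freq * x))
    return out
```

For a 3-D position this is applied per coordinate, so `multires` octaves yield `3 * 2 * multires` features, which determines `self.in_dim_nerf` above.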
self.nerf_layer_num = self.args.nerf_layer_num W = self.args.nerf_layer_wide self.nerf_mlp = [nn.Linear(self.in_dim_nerf, W)] for i in range(self.nerf_layer_num - 2): if i != self.skips: self.nerf_mlp.append(nn.Linear(W, W)) else: self.nerf_mlp.append(nn.Linear(W + self.in_dim_nerf, W)) self.nerf_mlp.append(nn.Linear(W, self.in_dim)) self.nerf_mlp = nn.ModuleList(self.nerf_mlp) elif self.model_type == "vm": self.sigma_rank = [16] * 3 self.color_rank = [48] * 3 self.color_feat_dim = 15 # geo_feat_dim self.mat_ids = [[0, 1], [0, 2], [1, 2]] self.vec_ids = [2, 1, 0] self.resolution = [self.opt.resolution0] * 3 # mat: paralist[1,16,res0,res0] repeat 3 vec: paralist[1,16,res0,1] repeat 3; repeat3 because decompose 3D grid [H, W, D] to three 2D mat [H, W], [H,D], [W, D] or decompose to three 1D vec [H], [W], [D] self.sigma_mat, self.sigma_vec = self.init_one_vm( self.sigma_rank, self.resolution ) # mat: paralist[1,48,res0,res0] repeat 3 vec: paralist[1,48,res0,1] repeat 3 self.color_mat, self.color_vec = self.init_one_vm( self.color_rank, self.resolution ) # Linear(in_features=144, out_features=27) self.basis_mat = nn.Linear( sum(self.color_rank), self.color_feat_dim, bias=False ) elif self.model_type == "tensors": self.init_plenoxel_volume( s=0.02, fea_dim=self.plenoxel_degree ** 2 * 3 + 1, volume=self.plenoxel_res, ) elif self.model_type == "hash": pass else: raise ValueError(f"error model_type:{self.model_type}") if self.model_type != "vm" and self.model_type != "tensors": sigma_net = [] for l in range(num_layers): if l == 0: in_dim = self.in_dim else: in_dim = hidden_dim if l == num_layers - 1: out_dim = ( 1 + self.geo_feat_dim ) # 1 sigma + 15 SH features for color else: out_dim = hidden_dim sigma_net.append(nn.Linear(in_dim, out_dim, bias=False)) self.sigma_net = nn.ModuleList(sigma_net) # color network self.num_layers_color = num_layers_color self.hidden_dim_color = hidden_dim_color # self.encoder_dir, self.in_dim_dir = get_encoder(encoding=encoding_dir) if 
self.model_type == "tensors": self.encoder_dir, self.in_dim_dir = get_encoder( encoding="sphere_harmonics", degree=self.plenoxel_degree, ) else: self.encoder_dir, self.in_dim_dir = get_encoder( encoding=encoding_dir, input_dim=3, multires=2 ) if self.model_type != "tensors": color_net = [] for l in range(num_layers_color): if l == 0: in_dim = self.in_dim_dir + self.geo_feat_dim else: in_dim = hidden_dim if l == num_layers_color - 1: out_dim = 3 # 3 rgb else: out_dim = hidden_dim color_net.append(nn.Linear(in_dim, out_dim, bias=False)) self.color_net = nn.ModuleList(color_net) # background network if self.bg_radius > 0: self.num_layers_bg = num_layers_bg self.hidden_dim_bg = hidden_dim_bg self.encoder_bg, self.in_dim_bg = get_encoder( encoding_bg, input_dim=2, num_levels=4, log2_hashmap_size=19, desired_resolution=2048, ) # much smaller hashgrid bg_net = [] for l in range(num_layers_bg): if l == 0: in_dim = self.in_dim_bg + self.in_dim_dir else: in_dim = hidden_dim_bg if l == num_layers_bg - 1: out_dim = 3 # 3 rgb else: out_dim = hidden_dim_bg bg_net.append(nn.Linear(in_dim, out_dim, bias=False)) self.bg_net = nn.ModuleList(bg_net) else: self.bg_net = None def init_plenoxel_volume(self, s=0.1, fea_dim=27 + 1, volume=[128, 128, 128]): tensor = [] tensor.append( torch.nn.Parameter( s * torch.randn((1, fea_dim, volume[0], volume[1], volume[2])) ) ) self.tensor_volume = torch.nn.ParameterList(tensor).cuda() def init_one_vm(self, n_component, resolution, scale=0.1): # self.mat_ids = [[0, 1], [0, 2], [1, 2]] self.vec_ids = [2, 1, 0] mat, vec = [], [] for i in range(len(self.vec_ids)): vec_id = self.vec_ids[i] mat_id_0, mat_id_1 = self.mat_ids[i] mat.append( nn.Parameter( scale * torch.randn( (1, n_component[i], resolution[mat_id_1], resolution[mat_id_0]) ) ) ) # [1, R, H, W] vec.append( nn.Parameter( scale * torch.randn((1, n_component[i], resolution[vec_id], 1)) ) ) # [1, R, D, 1] (fake 2d to use grid_sample) return nn.ParameterList(mat), nn.ParameterList(vec) def 
get_sigma_feat(self, x): # x: [N, 3], in [-1, 1] (outliers will be treated as zero due to grid_sample padding mode) # self.mat_ids = [[0, 1], [0, 2], [1, 2]] self.vec_ids = [2, 1, 0] N = x.shape[0] # plane + line basis mat_coord = ( torch.stack( ( x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]], ) ) .detach() .view(3, -1, 1, 2) ) # [3, N, 1, 2] vec_coord = torch.stack( (x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]]) ) vec_coord = ( torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1) .detach() .view(3, -1, 1, 2) ) # [3, N, 1, 2], fake 2d coord sigma_feat = torch.zeros( [ N, ], device=x.device, ) for i in range(len(self.sigma_mat)): mat_feat = F.grid_sample( self.sigma_mat[i], mat_coord[[i]], align_corners=True ).view( -1, N ) # [1, R, N, 1] --> [R, N] vec_feat = F.grid_sample( self.sigma_vec[i], vec_coord[[i]], align_corners=True ).view( -1, N ) # [R, N] sigma_feat = sigma_feat + torch.sum(mat_feat * vec_feat, dim=0) return sigma_feat def get_color_feat(self, x): # x: [N, 3], in [-1, 1] N = x.shape[0] # plane + line basis mat_coord = ( torch.stack( ( x[..., self.mat_ids[0]], x[..., self.mat_ids[1]], x[..., self.mat_ids[2]], ) ) .detach() .view(3, -1, 1, 2) ) # [3, N, 1, 2] vec_coord = torch.stack( (x[..., self.vec_ids[0]], x[..., self.vec_ids[1]], x[..., self.vec_ids[2]]) ) vec_coord = ( torch.stack((torch.zeros_like(vec_coord), vec_coord), dim=-1) .detach() .view(3, -1, 1, 2) ) # [3, N, 1, 2], fake 2d coord mat_feat, vec_feat = [], [] for i in range(len(self.color_mat)): mat_feat.append( F.grid_sample( self.color_mat[i], mat_coord[[i]], align_corners=True ).view(-1, N) ) # [1, R, N, 1] --> [R, N] vec_feat.append( F.grid_sample( self.color_vec[i], vec_coord[[i]], align_corners=True ).view(-1, N) ) # [R, N] mat_feat = torch.cat(mat_feat, dim=0) # [3 * R, N] vec_feat = torch.cat(vec_feat, dim=0) # [3 * R, N] color_feat = self.basis_mat( (mat_feat * vec_feat).T ) # [N, 3R] --> [N, color_feat_dim] return 
color_feat def compute_plenoxel_fea(self, x): composed = self.tensor_volume[0] composed = ( F.grid_sample(composed, x.view(1, 1, -1, 1, 3), align_corners=True) .view(-1, x.shape[0]) .permute(1, 0) ) return composed # [N, fea_dim] def forward_nerf_mlp(self, x): x = self.encoder_nerf_pe(x) in_pts = x for i in range(len(self.nerf_mlp)): x = self.nerf_mlp[i](x) if i != len(self.nerf_mlp) - 1: x = F.relu(x, inplace=True) if i == self.skips: x = torch.cat([in_pts, x], -1) return x def forward(self, x, d): # x: [N, 3], in [-bound, bound] d: [N, 3], nomalized in [-1, 1] # sigma if self.model_type == "hash": x = self.encoder( x, bound=self.bound ) # out_x[N, 28=num_levels * fea_per_level] elif self.model_type == "mlp": x = self.forward_nerf_mlp(x) # 28 elif self.model_type == "vm": x = ( 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 ) # x:[N, 3] sigma_feat = self.get_sigma_feat(x) # sigma_feat:[N] color_feat = self.get_color_feat(x) # color_feat:[N, 15] sigma_feat = torch.clamp( sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max ) # color_feat = torch.clamp(color_feat, self.args.sigma_clip_min, self.args.sigma_clip_max) self.feature_sigma_color = torch.cat( [sigma_feat.unsqueeze(-1), color_feat], dim=-1 ) self.sigma_l = sigma_feat sigma = trunc_exp(sigma_feat) # sigma:[N] enc_d = self.encoder_dir(d) # enc_d:[N, 16] h = torch.cat([enc_d, color_feat], dim=-1) # h:[N, 16+15] for l in range(self.num_layers_color): h = self.color_net[l](h) if l != self.num_layers_color - 1: h = F.relu(h, inplace=True) color = torch.sigmoid(h) self.color_l = color return sigma, color elif self.model_type == "tensors": x = ( 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 ) # x:[N, 3] x = self.compute_plenoxel_fea(x) h = x sigma = torch.clamp( h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max ) self.sigma_l = sigma sigma = trunc_exp(sigma) self.sigma = sigma sh = h[..., 1:].view( -1, 3, self.plenoxel_degree 
** 2 ) # [N, 3, 9] ## .permute(1, 0, 2) # [B, 27]-->[9, B, 3] enc_d = self.encoder_dir(d).unsqueeze(1) # [N, 9]-->[N,1,9] color = (sh * enc_d).sum(-1) # [N, 3] color = torch.sigmoid(color) self.feature_sigma_color = None self.color_l = color return sigma, color else: raise ValueError(f"illegal model_type:{self.model_type}") h = x for l in range(self.num_layers): h = self.sigma_net[l](h) if l != self.num_layers - 1: h = F.relu(h, inplace=True) h[..., 0] = torch.clamp( h[..., 0].clone(), self.args.sigma_clip_min, self.args.sigma_clip_max ) # h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max) self.feature_sigma_color = h self.sigma_l = h[..., 0] sigma = trunc_exp(h[..., 0]) # sigma: [n] geo_feat = h[..., 1:] # geo_feat: [n, 15] d = self.encoder_dir(d) # d: [n, 16] h = torch.cat([d, geo_feat], dim=-1) # h: [n, 15+16] for l in range(self.num_layers_color): h = self.color_net[l](h) if l != self.num_layers_color - 1: h = F.relu(h, inplace=True) color = torch.sigmoid(h) self.color_l = color return sigma, color def density(self, x): # x: [N, 3], in [-bound, bound] if self.model_type == "hash": x = self.encoder( x, bound=self.bound ) # out_x[N, 32=num_levels * fea_per_level] elif self.model_type == "mlp": x = self.forward_nerf_mlp(x) elif self.model_type == "vm": x = ( 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 ) sigma_feat = self.get_sigma_feat(x) sigma_feat = torch.clamp( sigma_feat, self.args.sigma_clip_min, self.args.sigma_clip_max ) sigma = trunc_exp(sigma_feat) return {"sigma": sigma} elif self.model_type == "tensors": x = ( 2 * (x - self.aabb_train[:3]) / (self.aabb_train[3:] - self.aabb_train[:3]) - 1 ) # x:[N, 3] x = self.compute_plenoxel_fea(x) h = x # h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max) sigma = trunc_exp( torch.clamp( h[..., 0], self.args.sigma_clip_min, self.args.sigma_clip_max ) ) return {"sigma": sigma} else: raise ValueError(f"
illegal model_type:{self.model_type}") h = x for l in range(self.num_layers): h = self.sigma_net[l](h) if l != self.num_layers - 1: h = F.relu(h, inplace=True) h = torch.clamp(h, self.args.sigma_clip_min, self.args.sigma_clip_max) sigma = trunc_exp(h[..., 0]) geo_feat = h[..., 1:] return { "sigma": sigma, "geo_feat": geo_feat, } def background(self, x, d): assert 1 == 2 # x: [N, 2], in [-1, 1] h = self.encoder_bg(x) # [N, C] d = self.encoder_dir(d) h = torch.cat([d, h], dim=-1) for l in range(self.num_layers_bg): h = self.bg_net[l](h) if l != self.num_layers_bg - 1: h = F.relu(h, inplace=True) # sigmoid activation for rgb rgbs = torch.sigmoid(h) return rgbs # allow masked inference def color(self, x, d, mask=None, geo_feat=None, **kwargs): assert 1 == 2 # x: [N, 3] in [-bound, bound] # mask: [N,], bool, indicates where we actually needs to compute rgb. if mask is not None: rgbs = torch.zeros( mask.shape[0], 3, dtype=x.dtype, device=x.device ) # [N, 3] # in case of empty mask if not mask.any(): return rgbs x = x[mask] d = d[mask] geo_feat = geo_feat[mask] d = self.encoder_dir(d) h = torch.cat([d, geo_feat], dim=-1) for l in range(self.num_layers_color): h = self.color_net[l](h) if l != self.num_layers_color - 1: h = F.relu(h, inplace=True) # sigmoid activation for rgb h = torch.sigmoid(h) if mask is not None: rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 else: rgbs = h return rgbs # L1 penalty for loss def density_loss(self): loss = 0 for i in range(len(self.sigma_mat)): loss = ( loss + torch.mean(torch.abs(self.sigma_mat[i])) + torch.mean(torch.abs(self.sigma_vec[i])) ) return loss # upsample utils @torch.no_grad() def upsample_params(self, mat, vec, resolution): for i in range(len(self.vec_ids)): vec_id = self.vec_ids[i] mat_id_0, mat_id_1 = self.mat_ids[i] mat[i] = nn.Parameter( F.interpolate( mat[i].data, size=(resolution[mat_id_1], resolution[mat_id_0]), mode="bilinear", align_corners=True, ) ) vec[i] = nn.Parameter( F.interpolate( vec[i].data, 
size=(resolution[vec_id], 1), mode="bilinear", align_corners=True, ) ) @torch.no_grad() def upsample_model(self, resolution): self.upsample_params(self.sigma_mat, self.sigma_vec, resolution) self.upsample_params(self.color_mat, self.color_vec, resolution) self.resolution = resolution @torch.no_grad() def shrink_model(self): # shrink aabb_train and the model so it only represents the space inside aabb_train. half_grid_size = self.bound / self.grid_size thresh = min(self.density_thresh, self.mean_density) valid_grid = self.density_grid[self.cascade - 1] > thresh # [N] valid_pos = raymarching.morton3D_invert( torch.nonzero(valid_grid) ) # [Nz] --> [Nz, 3], in [0, H - 1] # plot_pointcloud(valid_pos.detach().cpu().numpy()) # lots of noisy outliers in hashnerf... valid_pos = (2 * valid_pos / (self.grid_size - 1) - 1) * ( self.bound - half_grid_size ) # [Nz, 3], in [-b+hgs, b-hgs] min_pos = valid_pos.amin(0) - half_grid_size # [3] max_pos = valid_pos.amax(0) + half_grid_size # [3] # shrink model reso = torch.LongTensor(self.resolution).to(self.aabb_train.device) units = (self.aabb_train[3:] - self.aabb_train[:3]) / reso tl = (min_pos - self.aabb_train[:3]) / units br = (max_pos - self.aabb_train[:3]) / units tl = torch.round(tl).long().clamp(min=0) br = torch.minimum(torch.round(br).long(), reso) for i in range(len(self.vec_ids)): vec_id = self.vec_ids[i] mat_id_0, mat_id_1 = self.mat_ids[i] self.sigma_vec[i] = nn.Parameter( self.sigma_vec[i].data[..., tl[vec_id] : br[vec_id], :] ) self.color_vec[i] = nn.Parameter( self.color_vec[i].data[..., tl[vec_id] : br[vec_id], :] ) self.sigma_mat[i] = nn.Parameter( self.sigma_mat[i].data[ ..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0] ] ) self.color_mat[i] = nn.Parameter( self.color_mat[i].data[ ..., tl[mat_id_1] : br[mat_id_1], tl[mat_id_0] : br[mat_id_0] ] ) self.aabb_train = torch.cat([min_pos, max_pos], dim=0) # [6] print( f"[INFO] shrink slice: {tl.cpu().numpy().tolist()} - {br.cpu().numpy().tolist()}" ) 
print(f"[INFO] new aabb: {self.aabb_train.cpu().numpy().tolist()}") # optimizer utils def get_params(self, lr, lr2=1e-3): if self.model_type == "hash": params = [ {"params": self.encoder.parameters(), "lr": lr}, {"params": self.sigma_net.parameters(), "lr": lr}, {"params": self.encoder_dir.parameters(), "lr": lr}, {"params": self.color_net.parameters(), "lr": lr}, ] elif self.model_type == "mlp": params = [ {"params": self.sigma_net.parameters(), "lr": lr}, {"params": self.encoder_dir.parameters(), "lr": lr}, {"params": self.color_net.parameters(), "lr": lr}, {"params": self.nerf_mlp.parameters(), "lr": lr}, ] elif self.model_type == "vm": params = [ {"params": self.color_net.parameters(), "lr": lr2}, {"params": self.sigma_mat, "lr": lr}, {"params": self.sigma_vec, "lr": lr}, {"params": self.color_mat, "lr": lr}, {"params": self.color_vec, "lr": lr}, {"params": self.basis_mat.parameters(), "lr": lr2}, ] elif self.model_type == "tensors": params = [ {"params": self.tensor_volume.parameters(), "lr": lr}, {"params": self.encoder_dir.parameters(), "lr": lr}, ] else: raise ValueError(f"illegal model_type:{self.model_type}") if self.bg_radius > 0: params.append({"params": self.encoder_bg.parameters(), "lr": lr}) params.append({"params": self.bg_net.parameters(), "lr": lr}) return params ================================================ FILE: just_train_tea/provider.py ================================================ import os import cv2 import glob import json import tqdm import numpy as np from scipy.spatial.transform import Slerp, Rotation import trimesh import torch from torch.utils.data import DataLoader from .utils import get_rays, srgb_to_linear # ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50 def nerf_matrix_to_ngp(pose, scale=0.33): # for the fox dataset, 0.33 scales camera radius to ~ 2 new_pose = np.array( [ [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] *
scale], [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale], [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale], [0, 0, 0, 1], ], dtype=np.float32, ) return new_pose def rand_poses( size, device, radius=1, theta_range=[np.pi / 3, 2 * np.pi / 3], phi_range=[0, 2 * np.pi], ): """generate random poses from an orbit camera Args: size: batch size of generated poses. device: where to allocate the output. radius: camera radius theta_range: [min, max], should be in [0, \pi] phi_range: [min, max], should be in [0, 2\pi] Return: poses: [size, 4, 4] """ def normalize(vectors): return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) thetas = ( torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0] ) phis = ( torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0] ) centers = torch.stack( [ radius * torch.sin(thetas) * torch.sin(phis), radius * torch.cos(thetas), radius * torch.sin(thetas) * torch.cos(phis), ], dim=-1, ) # [B, 3] # lookat forward_vector = -normalize(centers) up_vector = ( torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1) ) # confused at the coordinate system... 
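The orbit-camera poses in rand_poses come from a standard look-at construction: forward points from the camera center toward the origin, right = forward x up, and up is re-orthogonalized from those two. A standalone numpy sketch of that construction (illustrative helper, not code from this repo; the axis/column conventions mirror the torch.stack(..., dim=-1) layout above):

```python
import numpy as np

def lookat_pose(center, up=(0.0, -1.0, 0.0)):
    # Camera-to-world pose that looks from `center` toward the origin.
    # Rotation columns are (right, up, forward), as in rand_poses.
    center = np.asarray(center, dtype=np.float64)
    forward = -center / (np.linalg.norm(center) + 1e-10)
    right = np.cross(forward, np.asarray(up, dtype=np.float64))
    right /= np.linalg.norm(right) + 1e-10
    true_up = np.cross(right, forward)
    true_up /= np.linalg.norm(true_up) + 1e-10
    pose = np.eye(4)
    pose[:3, :3] = np.stack([right, true_up, forward], axis=-1)
    pose[:3, 3] = center
    return pose
```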
right_vector = normalize(torch.cross(forward_vector, up_vector, dim=-1)) up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1)) poses = ( torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1) ) poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) poses[:, :3, 3] = centers return poses def normalize(vectors): return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) interval_nums = torch.tensor( [i * 1 / (size - 1) for i in range(size)], dtype=torch.float32, device=device ) thetas = interval_nums * (theta_range[1] - theta_range[0]) + theta_range[0] phis = interval_nums * (phi_range[1] - phi_range[0]) + phi_range[0] centers = torch.stack( [ radius * torch.sin(thetas) * torch.sin(phis), radius * torch.cos(thetas), radius * torch.sin(thetas) * torch.cos(phis), ], dim=-1, ) # [B, 3] # lookat forward_vector = -normalize(centers) up_vector = ( torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0).repeat(size, 1) ) # confused at the coordinate system... right_vector = normalize( torch.cross(forward_vector, up_vector, dim=-1) ) # cross product up_vector = normalize(torch.cross(right_vector, forward_vector, dim=-1)) poses = ( torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1) ) poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) poses[:, :3, 3] = centers return poses class NeRFDataset: def __init__(self, opt, device, type="train", downscale=1, n_test=10): super().__init__() self.opt = opt self.args = opt self.device = device self.type = type # train, val, test self.downscale = downscale self.root_path = opt.path self.mode = opt.mode # only support blender self.preload = opt.preload # preload data into GPU self.scale = ( opt.scale ) # camera radius scale to make sure camera are inside the bounding box. self.bound = ( opt.bound ) # bounding box half length, also used as the radius to random sample poses. 
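NeRFDataset.__init__ handles the "all" and "trainval" splits by loading several Blender-style transforms_*.json files and concatenating their "frames" lists into one transform dict. A minimal standalone sketch of that merge (hypothetical helper name, not part of the repo):

```python
def merge_transforms(splits):
    # Merge Blender-style transform dicts (as loaded from transforms_*.json):
    # intrinsics come from the first split, 'frames' lists are concatenated.
    merged = None
    for t in splits:
        if merged is None:
            merged = dict(t)
            merged["frames"] = list(t["frames"])
        else:
            merged["frames"] = merged["frames"] + list(t["frames"])
    return merged
```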
self.fp16 = opt.fp16 # if preload, load into fp16. self.training = self.type in ["train", "all", "trainval"] self.num_rays = self.opt.num_rays if self.training else -1 if self.mode == "blender": if type == "all": transform_paths = glob.glob(os.path.join(self.root_path, "*.json")) transform = None for transform_path in transform_paths: with open(transform_path, "r") as f: tmp_transform = json.load(f) if transform is None: transform = tmp_transform else: transform["frames"].extend(tmp_transform["frames"]) # load train and val split elif type == "trainval": with open( os.path.join(self.root_path, f"transforms_train.json"), "r" ) as f: transform = json.load(f) with open( os.path.join(self.root_path, f"transforms_val.json"), "r" ) as f: transform_val = json.load(f) transform["frames"].extend(transform_val["frames"]) # only load one specified split else: with open( os.path.join(self.root_path, f"transforms_{type}.json"), "r" ) as f: transform = json.load(f) else: raise NotImplementedError(f"unknown dataset mode: {self.mode}") # load image size if "h" in transform and "w" in transform: self.H = int(transform["h"]) // downscale self.W = int(transform["w"]) // downscale else: # we have to actually read an image to get H and W later. self.H = self.W = None # read images frames = transform["frames"] if True: self.poses = [] self.images = [] for f in tqdm.tqdm(frames, desc=f"Loading {type} data:"): f_path = os.path.join(self.root_path, f["file_path"]) if ( self.mode == "blender" and f_path[-4:].lower() != ".png" and f_path[-4:].lower() != ".jpg" ): f_path += ".png" # so silly... if not os.path.exists(f_path): continue pose = np.array(f["transform_matrix"], dtype=np.float32) # [4, 4] pose = nerf_matrix_to_ngp(pose, scale=self.scale) image = cv2.imread( f_path, cv2.IMREAD_UNCHANGED ) # [H, W, 3] o [H, W, 4] if self.H is None or self.W is None: self.H = image.shape[0] // downscale self.W = image.shape[1] // downscale # add support for the alpha channel as a mask. 
if image.shape[-1] == 3: image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) else: image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) if image.shape[0] != self.H or image.shape[1] != self.W: image = cv2.resize( image, (self.W, self.H), interpolation=cv2.INTER_AREA ) image = image.astype(np.float32) / 255 # [H, W, 3/4] self.poses.append(pose) self.images.append(image) self.poses = torch.from_numpy(np.stack(self.poses, axis=0)) # [N, 4, 4] if self.images is not None: self.images = torch.from_numpy( np.stack(self.images, axis=0) ) # [N, H, W, C] self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item() if self.training and self.opt.error_map: self.error_map = torch.ones( [self.images.shape[0], 128 * 128], dtype=torch.float ) # [B, 128 * 128], flattened for easy indexing, fixed resolution... else: self.error_map = None if self.preload: self.poses = self.poses.to(self.device) if self.images is not None: if self.fp16 and self.opt.color_space != "linear": dtype = torch.half else: dtype = torch.float self.images = self.images.to(dtype).to(self.device) if self.error_map is not None: self.error_map = self.error_map.to(self.device) # load intrinsics if "fl_x" in transform or "fl_y" in transform: fl_x = ( transform["fl_x"] if "fl_x" in transform else transform["fl_y"] ) / downscale fl_y = ( transform["fl_y"] if "fl_y" in transform else transform["fl_x"] ) / downscale elif "camera_angle_x" in transform or "camera_angle_y" in transform: # blender, assert in radians. already downscaled since we use H/W fl_x = ( self.W / (2 * np.tan(transform["camera_angle_x"] / 2)) if "camera_angle_x" in transform else None ) fl_y = ( self.H / (2 * np.tan(transform["camera_angle_y"] / 2)) if "camera_angle_y" in transform else None ) if fl_x is None: fl_x = fl_y if fl_y is None: fl_y = fl_x else: raise RuntimeError( "Failed to load focal length, please check the transforms.json!" 
) cx = (transform["cx"] / downscale) if "cx" in transform else (self.W / 2) cy = (transform["cy"] / downscale) if "cy" in transform else (self.H / 2) self.intrinsics = np.array([fl_x, fl_y, cx, cy]) def collate(self, index): B = len(index) # a list of length 1 poses = self.poses[index].to(self.device) # [B, 4, 4] error_map = None if self.error_map is None else self.error_map[index] rays = get_rays( poses, self.intrinsics, self.H, self.W, self.num_rays, error_map ) results = { "H": self.H, "W": self.W, "rays_o": rays["rays_o"], "rays_d": rays["rays_d"], } if self.images is not None: images = self.images[index].to(self.device) # [B, H, W, 3/4] if self.training: C = images.shape[-1] images = torch.gather( images.view(B, -1, C), 1, torch.stack(C * [rays["inds"]], -1) ) # [B, N, 3/4] results["images"] = images # need inds to update error_map if error_map is not None: results["index"] = index results["inds_coarse"] = rays["inds_coarse"] return results def dataloader(self): size = len(self.poses) loader = DataLoader( list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0, ) loader._data = self return loader ================================================ FILE: just_train_tea/renderer.py ================================================ import math import trimesh import numpy as np from time import time import torch import torch.nn as nn import torch.nn.functional as F import raymarching from .utils import custom_meshgrid def sample_pdf(bins, weights, n_samples, det=False): # This implementation is from NeRF # bins: [B, T], old_z_vals # weights: [B, T - 1], bin weights.
# return: [B, n_samples], new_z_vals # Get pdf weights = weights + 1e-5 # prevent nans pdf = weights / torch.sum(weights, -1, keepdim=True) cdf = torch.cumsum(pdf, -1) cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # Take uniform samples if det: u = torch.linspace( 0.0 + 0.5 / n_samples, 1.0 - 0.5 / n_samples, steps=n_samples ).to(weights.device) u = u.expand(list(cdf.shape[:-1]) + [n_samples]) else: u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(weights.device) # Invert CDF u = u.contiguous() inds = torch.searchsorted(cdf, u, right=True) below = torch.max(torch.zeros_like(inds - 1), inds - 1) above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) inds_g = torch.stack([below, above], -1) # (B, n_samples, 2) matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) denom = cdf_g[..., 1] - cdf_g[..., 0] denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) t = (u - cdf_g[..., 0]) / denom samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) return samples def plot_pointcloud(pc, color=None): # pc: [N, 3] # color: [N, 3/4] print("[visualize points]", pc.shape, pc.dtype, pc.min(0), pc.max(0)) pc = trimesh.PointCloud(pc, color) # axis axes = trimesh.creation.axis(axis_length=4) # sphere sphere = trimesh.creation.icosphere(radius=1) trimesh.Scene([pc, axes, sphere]).show() class NeRFRenderer(nn.Module): def __init__( self, bound=1, cuda_ray=False, density_scale=1, # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance. 
min_near=0.2, density_thresh=0.01, bg_radius=-1, grid_size=128, ): super().__init__() print("\n---------------", grid_size, "--------------\n") self.bound = bound self.cascade = 1 + math.ceil(math.log2(bound)) self.grid_size = grid_size self.density_scale = density_scale self.min_near = min_near self.density_thresh = density_thresh self.bg_radius = bg_radius # radius of the background sphere. # prepare aabb with a 6D tensor (xmin, ymin, zmin, xmax, ymax, zmax) # NOTE: aabb (can be rectangular) is only used to generate points, we still rely on bound (always cubic) to calculate density grid and hashing. aabb_train = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound]) aabb_infer = aabb_train.clone() self.register_buffer("aabb_train", aabb_train) self.register_buffer("aabb_infer", aabb_infer) # extra state for cuda raymarching self.cuda_ray = cuda_ray if cuda_ray: # density grid density_grid = torch.zeros( [self.cascade, self.grid_size ** 3] ) # [CAS, H * H * H] density_bitfield = torch.zeros( self.cascade * self.grid_size ** 3 // 8, dtype=torch.uint8 ) # [CAS * H * H * H // 8] self.register_buffer("density_grid", density_grid) self.register_buffer("density_bitfield", density_bitfield) self.mean_density = 0 self.iter_density = 0 # step counter step_counter = torch.zeros( 16, 2, dtype=torch.int32 ) # 16 is hardcoded for averaging... self.register_buffer("step_counter", step_counter) self.mean_count = 0 self.local_step = 0 def forward(self, x, d): raise NotImplementedError() # separated density and color query (can accelerate non-cuda-ray mode.) 
def density(self, x): raise NotImplementedError() def color(self, x, d, mask=None, **kwargs): raise NotImplementedError() def reset_extra_state(self): if not self.cuda_ray: return # density grid self.density_grid.zero_() self.mean_density = 0 self.iter_density = 0 # step counter self.step_counter.zero_() self.mean_count = 0 self.local_step = 0 def run( self, rays_o, rays_d, num_steps=128, upsample_steps=128, bg_color=None, perturb=False, **kwargs ): # rays_o, rays_d: [B, N, 3], assumes B == 1 # bg_color: [3] in range [0, 1] # return: image: [B, N, 3], depth: [B, N] prefix = rays_o.shape[:-1] rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) N = rays_o.shape[0] # N = B * N, in fact device = rays_o.device # choose aabb aabb = self.aabb_train if self.training else self.aabb_infer # sample steps nears, fars = raymarching.near_far_from_aabb( rays_o, rays_d, aabb, self.min_near ) nears.unsqueeze_(-1) fars.unsqueeze_(-1) # print(f'nears = {nears.min().item()} ~ {nears.max().item()}, fars = {fars.min().item()} ~ {fars.max().item()}') z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze( 0 ) # [1, T] z_vals = z_vals.expand((N, num_steps)) # [N, T] z_vals = nears + (fars - nears) * z_vals # [N, T], in [nears, fars] # perturb z_vals sample_dist = (fars - nears) / num_steps if perturb: z_vals = ( z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist ) # z_vals = z_vals.clamp(nears, fars) # avoid out of bounds xyzs. # generate xyzs xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze( -1 ) # [N, 1, 3] * [N, T, 1] -> [N, T, 3] xyzs = torch.min(torch.max(xyzs, aabb[:3]), aabb[3:]) # a manual clip. 
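The z_vals setup in NeRFRenderer.run spaces samples uniformly between each ray's near and far bound, and with perturb=True jitters every depth by up to half a bin so training sees continuous depths. A numpy sketch of that sampling step (illustrative helper name, same math as the run() code):

```python
import numpy as np

def stratified_z_vals(nears, fars, num_steps, perturb=False, rng=None):
    # nears/fars: [N, 1] per-ray bounds. Returns [N, T] depths spaced
    # uniformly in [near, far]; with perturb, each depth is jittered by
    # up to half a bin, as in the z_vals setup of NeRFRenderer.run.
    z = np.linspace(0.0, 1.0, num_steps)[None, :]      # [1, T]
    z = nears + (fars - nears) * z                     # [N, T]
    if perturb:
        rng = rng or np.random.default_rng(0)
        sample_dist = (fars - nears) / num_steps
        z = z + (rng.random(z.shape) - 0.5) * sample_dist
    return z
```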
# plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) # query SDF and RGB density_outputs = self.density(xyzs.reshape(-1, 3)) # sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T] for k, v in density_outputs.items(): density_outputs[k] = v.view(N, num_steps, -1) # upsample z_vals (nerf-like) if upsample_steps > 0: with torch.no_grad(): deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1] deltas = torch.cat( [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1 ) alphas = 1 - torch.exp( -deltas * self.density_scale * density_outputs["sigma"].squeeze(-1) ) # [N, T] alphas_shifted = torch.cat( [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1 ) # [N, T+1] weights = ( alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] ) # [N, T] # sample new z_vals z_vals_mid = z_vals[..., :-1] + 0.5 * deltas[..., :-1] # [N, T-1] new_z_vals = sample_pdf( z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training ).detach() # [N, t] new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze( -2 ) * new_z_vals.unsqueeze( -1 ) # [N, 1, 3] * [N, t, 1] -> [N, t, 3] new_xyzs = torch.min( torch.max(new_xyzs, aabb[:3]), aabb[3:] ) # a manual clip. 
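sample_pdf, called just above to get new_z_vals, concentrates the fine samples where the coarse weights put density by inverting the per-ray CDF. A single-ray numpy sketch of the same inverse-transform idea, using deterministic bin midpoints as in det=True (hypothetical helper, simplified to one ray):

```python
import numpy as np

def inverse_cdf_sample(bins, weights, n_samples):
    # bins: [T] depths; weights: [T-1] bin weights. Returns [n_samples]
    # depths drawn from the piecewise-constant PDF by inverting its CDF.
    bins = np.asarray(bins, dtype=np.float64)
    w = np.asarray(weights, dtype=np.float64) + 1e-5   # avoid an all-zero PDF
    cdf = np.concatenate([[0.0], np.cumsum(w / w.sum())])
    u = (np.arange(n_samples) + 0.5) / n_samples       # deterministic midpoints
    inds = np.searchsorted(cdf, u, side="right")
    below = np.clip(inds - 1, 0, len(cdf) - 1)
    above = np.clip(inds, 0, len(cdf) - 1)
    denom = np.where(cdf[above] - cdf[below] < 1e-5, 1.0, cdf[above] - cdf[below])
    t = (u - cdf[below]) / denom
    return bins[below] + t * (bins[above] - bins[below])
```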
# only forward new points to save computation new_density_outputs = self.density(new_xyzs.reshape(-1, 3)) # new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t] for k, v in new_density_outputs.items(): new_density_outputs[k] = v.view(N, upsample_steps, -1) # re-order z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t] z_vals, z_index = torch.sort(z_vals, dim=1) xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3] xyzs = torch.gather( xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs) ) for k in density_outputs: tmp_output = torch.cat( [density_outputs[k], new_density_outputs[k]], dim=1 ) density_outputs[k] = torch.gather( tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output) ) deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1] deltas = torch.cat( [deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1 ) alphas = 1 - torch.exp( -deltas * self.density_scale * density_outputs["sigma"].squeeze(-1) ) # [N, T+t] alphas_shifted = torch.cat( [torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1 ) # [N, T+t+1] weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t] dirs = rays_d.view(-1, 1, 3).expand_as(xyzs) for k, v in density_outputs.items(): density_outputs[k] = v.view(-1, v.shape[-1]) mask = weights > 1e-4 # hard coded rgbs = self.color( xyzs.reshape(-1, 3), dirs.reshape(-1, 3), mask=mask.reshape(-1), **density_outputs ) rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3] # print(xyzs.shape, 'valid_rgb:', mask.sum().item()) # calculate weight_sum (mask) weights_sum = weights.sum(dim=-1) # [N] # calculate depth ori_z_vals = ((z_vals - nears) / (fars - nears)).clamp(0, 1) depth = torch.sum(weights * ori_z_vals, dim=-1) # calculate color image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1] # mix background color if self.bg_radius > 0: # use the bg model to calculate bg_color polar = raymarching.polar_from_ray( rays_o, rays_d, self.bg_radius ) # [N, 2] in [-1, 1] 
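The alphas/weights computation above is the standard discrete volume-rendering quadrature: alpha_i = 1 - exp(-sigma_i * delta_i), and each weight is the alpha times the transmittance surviving all earlier samples. A one-ray numpy sketch of that quadrature (illustrative names, not repo code):

```python
import numpy as np

def volume_render(sigmas, deltas, rgbs):
    # sigmas: [T] densities, deltas: [T] step sizes, rgbs: [T, 3] colors.
    # alpha_i = 1 - exp(-sigma_i * delta_i);
    # w_i = alpha_i * prod_{j<i} (1 - alpha_j)   (transmittance weighting)
    alphas = 1.0 - np.exp(-np.asarray(sigmas) * np.asarray(deltas))
    trans = np.cumprod(np.concatenate([[1.0], 1.0 - alphas + 1e-15]))[:-1]
    weights = alphas * trans
    color = (weights[:, None] * np.asarray(rgbs)).sum(axis=0)
    return color, weights
```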
            bg_color = self.background(polar, rays_d.reshape(-1, 3))  # [N, 3]
        elif bg_color is None:
            bg_color = 1
        image = image + (1 - weights_sum).unsqueeze(-1) * bg_color

        image = image.view(*prefix, 3)
        depth = depth.view(*prefix)

        # tmp: reg loss in mip-nerf 360
        # z_vals_shifted = torch.cat([z_vals[..., 1:], sample_dist * torch.ones_like(z_vals[..., :1])], dim=-1)
        # mid_zs = (z_vals + z_vals_shifted) / 2 # [N, T]
        # loss_dist = (torch.abs(mid_zs.unsqueeze(1) - mid_zs.unsqueeze(2)) * (weights.unsqueeze(1) * weights.unsqueeze(2))).sum() + 1/3 * ((z_vals_shifted - z_vals_shifted) * (weights ** 2)).sum()

        return {
            "depth": depth,
            "image": image,
        }

    def run_cuda(
        self,
        rays_o,
        rays_d,
        dt_gamma=0,
        bg_color=None,
        perturb=False,
        force_all_rays=False,
        max_steps=1024,
        inherited_params=[],
        **kwargs
    ):
        # rays_o, rays_d: [B, N, 3], assumes B == 1
        # return: image: [B, N, 3], depth: [B, N]

        prefix = rays_o.shape[:-1]
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        N = rays_o.shape[0]  # N = B * N, in fact
        device = rays_o.device

        # pre-calculate near far
        nears, fars = raymarching.near_far_from_aabb(
            rays_o,
            rays_d,
            self.aabb_train if self.training else self.aabb_infer,
            self.min_near,
        )

        # mix background color
        if self.bg_radius > 0:
            # use the bg model to calculate bg_color
            polar = raymarching.polar_from_ray(
                rays_o, rays_d, self.bg_radius
            )  # [N, 2] in [-1, 1]
            bg_color = self.background(polar, rays_d)  # [N, 3]
        elif bg_color is None:
            bg_color = 1

        if self.training:  # XXX guarantee the inference
            # setup counter
            time1 = time()
            counter = self.step_counter[self.local_step % 16]
            counter.zero_()  # set to 0
            self.local_step += 1
            if (
                self.args.render_stu_first
            ):  # if stu renders first, use stu to compute xyzs; the teacher will inherit them
                if not self.is_teacher:
                    xyzs, dirs, deltas, rays = raymarching.march_rays_train(
                        rays_o,
                        rays_d,
                        self.bound,
                        self.density_bitfield,
                        self.cascade,
                        self.grid_size,
                        nears,
                        fars,
                        counter,
                        self.mean_count,
                        perturb,
                        128,
                        force_all_rays,
dt_gamma, max_steps, ) inherited_params = [xyzs, dirs, deltas, rays] else: xyzs, dirs, deltas, rays = inherited_params else: if self.is_teacher: xyzs, dirs, deltas, rays = raymarching.march_rays_train( rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, counter, self.mean_count, perturb, 128, force_all_rays, dt_gamma, max_steps, ) inherited_params = [xyzs, dirs, deltas, rays] else: xyzs, dirs, deltas, rays = inherited_params # print('\n', self.model_type, self.mean_count, self.mean_density) # from IPython import embed; embed() # plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) sigmas, rgbs = self(xyzs, dirs) # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb. # sigmas = density_outputs['sigma'] # rgbs = self.color(xyzs, dirs, **density_outputs) sigmas = self.density_scale * sigmas # print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})') time2 = time() # special case for CCNeRF's residual learning if len(sigmas.shape) == 2: K = sigmas.shape[0] depths = [] images = [] for k in range(K): weights_sum, depth, image = raymarching.composite_rays_train( sigmas[k], rgbs[k], deltas, rays ) image = image + (1 - weights_sum).unsqueeze(-1) * bg_color depth = torch.clamp(depth - nears, min=0) / (fars - nears) images.append(image.view(*prefix, 3)) depths.append(depth.view(*prefix)) depth = torch.stack(depths, axis=0) # [K, B, N] image = torch.stack(images, axis=0) # [K, B, N, 3] else: weights_sum, depth, image = raymarching.composite_rays_train( sigmas, rgbs, deltas, rays ) image = image + (1 - weights_sum).unsqueeze(-1) * bg_color depth = torch.clamp(depth - nears, min=0) / (fars - nears) image = image.view(*prefix, 3) depth = depth.view(*prefix) else: # allocate outputs # if use autocast, must init as half so it won't be autocasted and lose reference. 
# dtype = torch.half if torch.is_autocast_enabled() else torch.float32 # output should always be float32! only network inference uses half. dtype = torch.float32 weights_sum = torch.zeros(N, dtype=dtype, device=device) depth = torch.zeros(N, dtype=dtype, device=device) image = torch.zeros(N, 3, dtype=dtype, device=device) n_alive = N alive_counter = torch.zeros([1], dtype=torch.int32, device=device) rays_alive = torch.zeros( 2, n_alive, dtype=torch.int32, device=device ) # 2 is used to loop old/new rays_t = torch.zeros(2, n_alive, dtype=dtype, device=device) step = 0 i = 0 while step < max_steps: # count alive rays if step == 0: # init rays at first step. torch.arange(n_alive, out=rays_alive[0]) rays_t[0] = nears else: alive_counter.zero_() raymarching.compact_rays( n_alive, rays_alive[i % 2], rays_alive[(i + 1) % 2], rays_t[i % 2], rays_t[(i + 1) % 2], alive_counter, ) n_alive = alive_counter.item() # must invoke D2H copy here # exit loop if n_alive <= 0: break # decide compact_steps n_step = max(min(N // n_alive, 8), 1) xyzs, dirs, deltas = raymarching.march_rays( n_alive, n_step, rays_alive[i % 2], rays_t[i % 2], rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb, dt_gamma, max_steps, ) sigmas, rgbs = self(xyzs, dirs) # density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb. 
# sigmas = density_outputs['sigma'] # rgbs = self.color(xyzs, dirs, **density_outputs) sigmas = self.density_scale * sigmas raymarching.composite_rays( n_alive, n_step, rays_alive[i % 2], rays_t[i % 2], sigmas, rgbs, deltas, weights_sum, depth, image, ) # print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}') step += n_step i += 1 image = image + (1 - weights_sum).unsqueeze(-1) * bg_color depth = torch.clamp(depth - nears, min=0) / (fars - nears) image = image.view(*prefix, 3) depth = depth.view(*prefix) # print('\n--- render time:--- {:6f} {:.6f}'.format(time2-time1, time()-time2)) return { "depth": depth, "image": image, "inherited_params": inherited_params, } @torch.no_grad() def mark_untrained_grid(self, poses, intrinsic, S=64): # poses: [B, 4, 4] # intrinsic: [3, 3] if not self.cuda_ray: return if isinstance(poses, np.ndarray): poses = torch.from_numpy(poses) B = poses.shape[0] fx, fy, cx, cy = intrinsic X = torch.arange( self.grid_size, dtype=torch.int32, device=self.density_grid.device ).split(S) Y = torch.arange( self.grid_size, dtype=torch.int32, device=self.density_grid.device ).split(S) Z = torch.arange( self.grid_size, dtype=torch.int32, device=self.density_grid.device ).split(S) count = torch.zeros_like(self.density_grid) poses = poses.to(count.device) # 5-level loop, forgive me... 
for xs in X: for ys in Y: for zs in Z: # construct points xx, yy, zz = custom_meshgrid(xs, ys, zs) coords = torch.cat( [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1, ) # [N, 3], in [0, 128) indices = raymarching.morton3D(coords).long() # [N] world_xyzs = ( 2 * coords.float() / (self.grid_size - 1) - 1 ).unsqueeze( 0 ) # [1, N, 3] in [-1, 1] # cascading for cas in range(self.cascade): bound = min(2 ** cas, self.bound) half_grid_size = bound / self.grid_size # scale to current cascade's resolution cas_world_xyzs = world_xyzs * (bound - half_grid_size) # split batch to avoid OOM head = 0 while head < B: tail = min(head + S, B) # world2cam transform (poses is c2w, so we need to transpose it. Another transpose is needed for batched matmul, so the final form is without transpose.) cam_xyzs = cas_world_xyzs - poses[ head:tail, :3, 3 ].unsqueeze(1) cam_xyzs = cam_xyzs @ poses[head:tail, :3, :3] # [S, N, 3] # query if point is covered by any camera mask_z = cam_xyzs[:, :, 2] > 0 # [S, N] mask_x = ( torch.abs(cam_xyzs[:, :, 0]) < cx / fx * cam_xyzs[:, :, 2] + half_grid_size * 2 ) mask_y = ( torch.abs(cam_xyzs[:, :, 1]) < cy / fy * cam_xyzs[:, :, 2] + half_grid_size * 2 ) mask = (mask_z & mask_x & mask_y).sum(0).reshape(-1) # [N] # update count count[cas, indices] += mask head += S # mark untrained grid as -1 self.density_grid[count == 0] = -1 # print(f'[mark untrained grid] {(count == 0).sum()} from {resolution ** 3 * self.cascade}') @torch.no_grad() def update_extra_state(self, decay=0.95, S=128): # call before each epoch to update extra states. if not self.cuda_ray: return # update density grid tmp_grid = -torch.ones_like(self.density_grid) # full update. 
if self.iter_density < 16: # if True: X = torch.arange( self.grid_size, dtype=torch.int32, device=self.density_grid.device ).split(S) Y = torch.arange( self.grid_size, dtype=torch.int32, device=self.density_grid.device ).split(S) Z = torch.arange( self.grid_size, dtype=torch.int32, device=self.density_grid.device ).split(S) for xs in X: for ys in Y: for zs in Z: # construct points xx, yy, zz = custom_meshgrid(xs, ys, zs) coords = torch.cat( [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1, ) # [N, 3], in [0, 128) indices = raymarching.morton3D(coords).long() # [N] xyzs = ( 2 * coords.float() / (self.grid_size - 1) - 1 ) # [N, 3] in [-1, 1] # cascading for cas in range(self.cascade): bound = min(2 ** cas, self.bound) half_grid_size = bound / self.grid_size # scale to current cascade's resolution cas_xyzs = xyzs * (bound - half_grid_size) # add noise in [-hgs, hgs] cas_xyzs += ( torch.rand_like(cas_xyzs) * 2 - 1 ) * half_grid_size # query density sigmas = ( self.density(cas_xyzs)["sigma"].reshape(-1).detach() ) sigmas *= self.density_scale # assign tmp_grid[cas, indices] = sigmas # partial update (half the computation) # TODO: why no need of maxpool ? 
else: N = self.grid_size ** 3 // 4 # H * H * H / 4 for cas in range(self.cascade): # random sample some positions coords = torch.randint( 0, self.grid_size, (N, 3), device=self.density_grid.device ) # [N, 3], in [0, 128) indices = raymarching.morton3D(coords).long() # [N] # random sample occupied positions occ_indices = torch.nonzero(self.density_grid[cas] > 0).squeeze( -1 ) # [Nz] rand_mask = torch.randint( 0, occ_indices.shape[0], [N], dtype=torch.long, device=self.density_grid.device, ) occ_indices = occ_indices[ rand_mask ] # [Nz] --> [N], allow for duplication occ_coords = raymarching.morton3D_invert(occ_indices) # [N, 3] # concat indices = torch.cat([indices, occ_indices], dim=0) coords = torch.cat([coords, occ_coords], dim=0) # same below xyzs = ( 2 * coords.float() / (self.grid_size - 1) - 1 ) # [N, 3] in [-1, 1] bound = min(2 ** cas, self.bound) half_grid_size = bound / self.grid_size # scale to current cascade's resolution cas_xyzs = xyzs * (bound - half_grid_size) # add noise in [-hgs, hgs] cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_size # query density sigmas = self.density(cas_xyzs)["sigma"].reshape(-1).detach() sigmas *= self.density_scale # assign tmp_grid[cas, indices] = sigmas ## max-pool on tmp_grid for less aggressive culling [No significant improvement...] # invalid_mask = tmp_grid < 0 # tmp_grid = F.max_pool3d(tmp_grid.view(self.cascade, 1, self.grid_size, self.grid_size, self.grid_size), kernel_size=3, stride=1, padding=1).view(self.cascade, -1) # tmp_grid[invalid_mask] = -1 # ema update valid_mask = (self.density_grid >= 0) & (tmp_grid >= 0) self.density_grid[valid_mask] = torch.maximum( self.density_grid[valid_mask] * decay, tmp_grid[valid_mask] ) self.mean_density = torch.mean( self.density_grid.clamp(min=0) ).item() # -1 non-training regions are viewed as 0 density. 
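The EMA step above only touches cells that are valid in both the old grid and the fresh density queries, taking `max(old * decay, new)`, and then averages the grid with untrained (`-1`) regions clamped to zero density. A standalone NumPy sketch of that update rule (hypothetical values, not tied to any model):

```python
import numpy as np

decay = 0.95
density_grid = np.array([-1.0, 0.5, 2.0, 0.1])  # -1 marks untrained cells
tmp_grid = np.array([3.0, 0.2, 1.0, -1.0])      # fresh density queries (-1 = not sampled)

# only update cells valid in both the old grid and the new queries
valid_mask = (density_grid >= 0) & (tmp_grid >= 0)
density_grid[valid_mask] = np.maximum(density_grid[valid_mask] * decay, tmp_grid[valid_mask])

# -1 non-training regions are viewed as 0 density for the mean
mean_density = np.mean(np.clip(density_grid, 0, None))
```

The decayed maximum lets occupancy fade slowly where the model no longer predicts density, while a single high-density query can immediately re-occupy a cell.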
self.iter_density += 1 # convert to bitfield density_thresh = min(self.mean_density, self.density_thresh) self.density_bitfield = raymarching.packbits( self.density_grid, density_thresh, self.density_bitfield ) ### update step counter total_step = min(16, self.local_step) if total_step > 0: self.mean_count = int( self.step_counter[:total_step, 0].sum().item() / total_step ) self.local_step = 0 # print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3 * self.cascade):.3f} | [step counter] mean={self.mean_count}') def render(self, rays_o, rays_d, staged=False, max_ray_batch=4096, **kwargs): # rays_o, rays_d: [B, N, 3], assumes B == 1 # return: pred_rgb: [B, N, 3] if self.cuda_ray: _run = self.run_cuda else: _run = self.run B, N = rays_o.shape[:2] device = rays_o.device # never stage when cuda_ray if staged and not self.cuda_ray: depth = torch.empty((B, N), device=device) image = torch.empty((B, N, 3), device=device) for b in range(B): head = 0 while head < N: tail = min(head + max_ray_batch, N) results_ = _run( rays_o[b : b + 1, head:tail], rays_d[b : b + 1, head:tail], **kwargs ) depth[b : b + 1, head:tail] = results_["depth"] image[b : b + 1, head:tail] = results_["image"] head += max_ray_batch results = {} results["depth"] = depth results["image"] = image else: results = _run(rays_o, rays_d, **kwargs) return results ================================================ FILE: just_train_tea/utils.py ================================================ import os import lpips import glob import tqdm import math import random import warnings import tensorboardX import numpy as np import pandas as pd import time from datetime import datetime import cv2 import matplotlib.pyplot as plt import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import torch.distributed as dist from torch.utils.data import Dataset, 
DataLoader import trimesh import mcubes from rich.console import Console from torch_ema import ExponentialMovingAverage from packaging import version as pver device = torch.device("cuda") def custom_meshgrid(*args): # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid if pver.parse(torch.__version__) < pver.parse("1.10"): return torch.meshgrid(*args) else: return torch.meshgrid(*args, indexing="ij") @torch.jit.script def linear_to_srgb(x): return torch.where(x < 0.0031308, 12.92 * x, 1.055 * x ** 0.41666 - 0.055) @torch.jit.script def srgb_to_linear(x): return torch.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4) def compute_ssim( img0, img1, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03, return_map=False, ): """Computes SSIM from two images. This function was modeled after tf.image.ssim, and should produce comparable output. Args: img0: torch.tensor. An image of size [..., width, height, num_channels]. img1: torch.tensor. An image of size [..., width, height, num_channels]. max_val: float > 0. The maximum magnitude that `img0` or `img1` can have. filter_size: int >= 1. Window size. filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering. k1: float > 0. One of the SSIM dampening parameters. k2: float > 0. One of the SSIM dampening parameters. return_map: Bool. If True, will cause the per-pixel SSIM "map" to returned Returns: Each image's mean SSIM, or a tensor of individual values if `return_map`. """ device = img0.device img0 = img0.type(torch.float32) img1 = img1.type(torch.float32) ori_shape = img0.size() width, height, num_channels = ori_shape[-3:] img0 = img0.view(-1, width, height, num_channels).permute(0, 3, 1, 2) img1 = img1.view(-1, width, height, num_channels).permute(0, 3, 1, 2) batch_size = img0.shape[0] # Construct a 1D Gaussian blur filter. 
hw = filter_size // 2 shift = (2 * hw - filter_size + 1) / 2 f_i = ((torch.arange(filter_size, device=device) - hw + shift) / filter_sigma) ** 2 filt = torch.exp(-0.5 * f_i) filt /= torch.sum(filt) # Blur in x and y (faster than the 2D convolution). # z is a tensor of size [B, H, W, C] filt_fn1 = lambda z: F.conv2d( z, filt.view(1, 1, -1, 1).repeat(num_channels, 1, 1, 1), padding=[hw, 0], groups=num_channels, ) filt_fn2 = lambda z: F.conv2d( z, filt.view(1, 1, 1, -1).repeat(num_channels, 1, 1, 1), padding=[0, hw], groups=num_channels, ) # Vmap the blurs to the tensor size, and then compose them. filt_fn = lambda z: filt_fn1(filt_fn2(z)) mu0 = filt_fn(img0) mu1 = filt_fn(img1) mu00 = mu0 * mu0 mu11 = mu1 * mu1 mu01 = mu0 * mu1 sigma00 = filt_fn(img0 ** 2) - mu00 sigma11 = filt_fn(img1 ** 2) - mu11 sigma01 = filt_fn(img0 * img1) - mu01 # Clip the variances and covariances to valid values. # Variance must be non-negative: sigma00 = torch.clamp(sigma00, min=0.0) sigma11 = torch.clamp(sigma11, min=0.0) sigma01 = torch.sign(sigma01) * torch.min( torch.sqrt(sigma00 * sigma11), torch.abs(sigma01) ) c1 = (k1 * max_val) ** 2 c2 = (k2 * max_val) ** 2 numer = (2 * mu01 + c1) * (2 * sigma01 + c2) denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) ssim_map = numer / denom ssim = torch.mean(ssim_map.reshape([-1, num_channels * width * height]), dim=-1) return ssim_map if return_map else ssim def init_lpips(net_name, device): assert net_name in ["alex", "vgg"] import lpips print(f"init_lpips: lpips_{net_name}") return lpips.LPIPS(net=net_name, version="0.1").eval().cuda() lpips_fns = { "alex": lpips.LPIPS(net="alex", version="0.1").eval().cuda(), "vgg": lpips.LPIPS(net="vgg", version="0.1").eval().cuda(), } def rgb_lpips(gt, im, net_name): assert net_name in ["alex", "vgg"] gt = gt.type(torch.float32).permute([0, 3, 1, 2]).contiguous().cuda() im = im.type(torch.float32).permute([0, 3, 1, 2]).contiguous().cuda() return lpips_fns[net_name](gt, im, normalize=True).item() 
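The `linear_to_srgb` / `srgb_to_linear` helpers earlier in this file implement the standard piecewise sRGB transfer functions (the torch version approximates the encoding exponent 1/2.4 as 0.41666). A minimal standalone NumPy sketch, checking that the two curves are inverses on [0, 1]:

```python
import numpy as np

def linear_to_srgb(x):
    # sRGB encoding: linear segment below 0.0031308, gamma 1/2.4 segment above
    return np.where(x < 0.0031308, 12.92 * x, 1.055 * x ** (1 / 2.4) - 0.055)

def srgb_to_linear(x):
    # inverse transfer function: linear segment below 0.04045
    return np.where(x < 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)

x = np.linspace(0.0, 1.0, 1001)
assert np.allclose(srgb_to_linear(linear_to_srgb(x)), x, atol=1e-6)
```

The thresholds match at the seam: `12.92 * 0.0031308 ≈ 0.04045`, so both branches agree where the piecewise definitions meet.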
@torch.cuda.amp.autocast(enabled=False) def get_rays(poses, intrinsics, H, W, N=-1, error_map=None): """get rays Args: poses: [B, 4, 4], cam2world intrinsics: [4] H, W, N: int error_map: [B, 128 * 128], sample probability based on training error Returns: rays_o, rays_d: [B, N, 3] inds: [B, N] """ device = poses.device B = poses.shape[0] fx, fy, cx, cy = intrinsics i, j = custom_meshgrid( torch.linspace(0, W - 1, W, device=device), torch.linspace(0, H - 1, H, device=device), ) i = i.t().reshape([1, H * W]).expand([B, H * W]) + 0.5 j = j.t().reshape([1, H * W]).expand([B, H * W]) + 0.5 results = {} if N > 0: N = min(N, H * W) if error_map is None: inds = torch.randint(0, H * W, size=[N], device=device) # may duplicate inds = inds.expand([B, N]) else: # weighted sample on a low-reso grid inds_coarse = torch.multinomial( error_map.to(device), N, replacement=False ) # [B, N], but in [0, 128*128) # map to the original resolution with random perturb. inds_x, inds_y = ( inds_coarse // 128, inds_coarse % 128, ) # `//` will throw a warning in torch 1.10... anyway. 
sx, sy = H / 128, W / 128 inds_x = ( (inds_x * sx + torch.rand(B, N, device=device) * sx) .long() .clamp(max=H - 1) ) inds_y = ( (inds_y * sy + torch.rand(B, N, device=device) * sy) .long() .clamp(max=W - 1) ) inds = inds_x * W + inds_y results["inds_coarse"] = inds_coarse # need this when updating error_map i = torch.gather(i, -1, inds) j = torch.gather(j, -1, inds) results["inds"] = inds else: inds = torch.arange(H * W, device=device).expand([B, H * W]) zs = torch.ones_like(i) xs = (i - cx) / fx * zs ys = (j - cy) / fy * zs directions = torch.stack((xs, ys, zs), dim=-1) directions = directions / torch.norm(directions, dim=-1, keepdim=True) rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) # (B, N, 3) rays_o = poses[..., :3, 3] # [B, 3] rays_o = rays_o[..., None, :].expand_as(rays_d) # [B, N, 3] results["rays_o"] = rays_o results["rays_d"] = rays_d return results def seed_everything(seed): random.seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) # torch.backends.cudnn.deterministic = True # torch.backends.cudnn.benchmark = True def torch_vis_2d(x, renormalize=False): # x: [3, H, W] or [1, H, W] or [H, W] import matplotlib.pyplot as plt import numpy as np import torch if isinstance(x, torch.Tensor): if len(x.shape) == 3: x = x.permute(1, 2, 0).squeeze() x = x.detach().cpu().numpy() print(f"[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}") x = x.astype(np.float32) # renormalize if renormalize: x = (x - x.min(axis=0, keepdims=True)) / ( x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8 ) plt.imshow(x) plt.show() def extract_fields(bound_min, bound_max, resolution, query_func, S=128): X = torch.linspace(bound_min[0], bound_max[0], resolution).split(S) Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(S) Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(S) u = np.zeros([resolution, resolution, resolution], dtype=np.float32) with 
torch.no_grad(): for xi, xs in enumerate(X): for yi, ys in enumerate(Y): for zi, zs in enumerate(Z): xx, yy, zz = custom_meshgrid(xs, ys, zs) pts = torch.cat( [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1, ) # [S, 3] val = ( query_func(pts) .reshape(len(xs), len(ys), len(zs)) .detach() .cpu() .numpy() ) # [S, 1] --> [x, y, z] u[ xi * S : xi * S + len(xs), yi * S : yi * S + len(ys), zi * S : zi * S + len(zs), ] = val return u def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): # print('threshold: {}'.format(threshold)) u = extract_fields(bound_min, bound_max, resolution, query_func) # print(u.shape, u.max(), u.min(), np.percentile(u, 50)) vertices, triangles = mcubes.marching_cubes(u, threshold) b_max_np = bound_max.detach().cpu().numpy() b_min_np = bound_min.detach().cpu().numpy() vertices = ( vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] ) return vertices, triangles class PSNRMeter: def __init__(self): self.V = 0 self.N = 0 def clear(self): self.V = 0 self.N = 0 def prepare_inputs(self, *inputs): outputs = [] for i, inp in enumerate(inputs): if torch.is_tensor(inp): inp = inp.detach().cpu().numpy() outputs.append(inp) return outputs def update(self, preds, truths): preds, truths = self.prepare_inputs( preds, truths ) # [B, N, 3] or [B, H, W, 3], range[0, 1] # simplified since max_pixel_value is 1 here. 
psnr = -10 * np.log10(np.mean((preds - truths) ** 2)) self.V += psnr self.N += 1 def measure(self): return self.V / self.N def write(self, writer, global_step, prefix=""): writer.add_scalar(os.path.join(prefix, "PSNR"), self.measure(), global_step) def report(self): return f"PSNR = {self.measure():.6f}" class Trainer(object): def __init__( self, name, # name of this experiment opt, # extra conf model_tea, # network model_stu, criterion=None, # loss function, if None, assume inline implementation in train_step optimizer=None, # optimizer ema_decay=None, # if use EMA, set the decay lr_scheduler=None, # scheduler metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric. local_rank=0, # which GPU am I world_size=1, # total num of GPUs device=None, # device to use, usually setting to None is OK. (auto choose device) mute=False, # whether to mute all print fp16=False, # amp optimize level eval_interval=10e10, # eval once every $ epoch max_keep_ckpt=2, # max num of saved ckpts in disk workspace="workspace", # workspace to save logs & ckpts best_mode="min", # the smaller/larger result, the better use_loss_as_metric=True, # use loss as the first metric report_metric_at_train=False, # also report metrics at training use_checkpoint="latest", # which ckpt to use at init time use_tensorboardX=True, # whether to use tensorboard for logging scheduler_update_every_step=False, # whether to call scheduler.step() after every train step ): self.optimizer_fn = optimizer self.lr_scheduler_fn = lr_scheduler self.name = name self.opt = opt self.mute = mute self.metrics = metrics self.local_rank = local_rank self.world_size = world_size self.workspace = workspace self.ema_decay = ema_decay self.fp16 = fp16 self.best_mode = best_mode self.use_loss_as_metric = use_loss_as_metric self.report_metric_at_train = report_metric_at_train self.max_keep_ckpt = max_keep_ckpt self.eval_interval = eval_interval self.use_checkpoint = use_checkpoint 
self.use_tensorboardX = use_tensorboardX self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S") self.scheduler_update_every_step = scheduler_update_every_step self.device = ( device if device is not None else torch.device( f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu" ) ) self.console = Console() self.model_tea = model_tea.to(device) self.model_stu = model_stu.to(device) if isinstance(criterion, nn.Module): criterion.to(self.device) self.criterion = criterion if optimizer is None: self.optimizer = optim.AdamW( self.model_stu.parameters(), lr=0.001, weight_decay=5e-4 ) # naive adam else: self.optimizer = optimizer(self.model_stu) if lr_scheduler is None: self.lr_scheduler = optim.lr_scheduler.LambdaLR( self.optimizer, lr_lambda=lambda epoch: 1 ) # fake scheduler else: self.ls = lr_scheduler self.lr_scheduler = lr_scheduler(self.optimizer) if ema_decay is not None and ema_decay > 0: # ema_decay=0.95 self.ema = ExponentialMovingAverage( self.model_stu.parameters(), decay=ema_decay ) else: self.ema = None self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16) # variable init self.epoch = 1 self.global_step = 0 self.local_step = 0 self.stats = { "loss": [], "valid_loss": [], "results": [], # metrics[0], or valid_loss "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt "best_result": None, } # auto fix if len(metrics) == 0 or self.use_loss_as_metric: self.best_mode = "min" # workspace prepare self.log_ptr = None if self.workspace is not None: os.makedirs(self.workspace, exist_ok=True) self.log_path = os.path.join(workspace, f"log_{self.name}.txt") self.log_ptr = open(self.log_path, "a+") self.ckpt_path = os.path.join(self.workspace, "checkpoints") self.best_path = f"{self.ckpt_path}/{self.name}.pth" os.makedirs(self.ckpt_path, exist_ok=True) self.log(self.opt) self.log( f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}' ) self.log( f"[INFO] #parameters: 
{sum([p.numel() for p in model_stu.parameters() if p.requires_grad])}" ) def __del__(self): if self.log_ptr: self.log_ptr.close() def log(self, *args, **kwargs): if self.local_rank == 0: if not self.mute: # print(*args) self.console.print(*args, **kwargs) if self.log_ptr: print(*args, file=self.log_ptr) self.log_ptr.flush() # write immediately to file def train(self, train_loader, valid_loader, max_epochs): if self.use_tensorboardX and self.local_rank == 0: self.writer = tensorboardX.SummaryWriter( os.path.join(self.workspace, "run", self.name) ) # mark untrained region (i.e., not covered by any camera from the training dataset) if self.model_tea.cuda_ray: self.model_tea.mark_untrained_grid( train_loader._data.poses, train_loader._data.intrinsics ) self.model_stu.mark_untrained_grid( train_loader._data.poses, train_loader._data.intrinsics ) for p in self.model_tea.parameters(): p.requires_grad = False self.model_tea.eval() # get a ref to error_map self.error_map = train_loader._data.error_map for epoch in range(self.epoch, max_epochs + 1): self.epoch = epoch self.train_one_epoch(train_loader) print("\n", self.workspace, "\n") if ( self.workspace is not None and self.local_rank == 0 and self.epoch > max_epochs - 2 ): self.save_checkpoint( full=False, best=False ) # FIXME save should include teacher and student if self.epoch % self.eval_interval == 0: self.evaluate_one_epoch(valid_loader) self.save_checkpoint(full=False, best=True) if self.use_tensorboardX and self.local_rank == 0: self.writer.close() def train_one_epoch(self, loader): self.log( f"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ..." 
) total_loss = 0 total_loss_rgb = 0 total_loss_fea = 0 if self.local_rank == 0 and self.report_metric_at_train: for metric in self.metrics: metric.clear() self.model_stu.train() self.model_tea.train() # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs # ref: https://pytorch.org/docs/stable/data.html if self.world_size > 1: loader.sampler.set_epoch(self.epoch) if self.local_rank == 0: pbar = tqdm.tqdm( total=len(loader) * loader.batch_size, bar_format="{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", ) self.local_step = 0 for data in loader: # update grid every 16 steps if ( self.model_tea.cuda_ray and self.global_step % self.opt.update_extra_interval == 0 ): with torch.cuda.amp.autocast(enabled=self.fp16): if self.opt.update_stu_extra: self.model_stu.update_extra_state() else: pass self.local_step += 1 self.global_step += 1 self.optimizer.zero_grad() with torch.cuda.amp.autocast(enabled=self.fp16): # XXX self.train_step if self.opt.just_train_a_model: loss, preds, truths = self.train_step(data) else: ( preds, truths, loss, loss_rgb, loss_fea, loss_fea_sc, loss_color, loss_sigma, ) = self.train_step(data) self.scaler.scale(loss).backward() self.scaler.step(self.optimizer) self.scaler.update() if self.scheduler_update_every_step: self.lr_scheduler.step() loss_val = loss.item() total_loss += loss_val if self.opt.just_train_a_model: if self.report_metric_at_train: for metric in self.metrics: metric.update(preds, truths) if self.use_tensorboardX: self.writer.add_scalar("train/loss", loss_val, self.global_step) self.writer.add_scalar( "train/lr", self.optimizer.param_groups[0]["lr"], self.global_step, ) if self.scheduler_update_every_step: pbar.set_description( f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}" ) else: pbar.set_description( f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})" ) pbar.update(loader.batch_size) else: 
                total_loss_rgb += loss_rgb
                total_loss_fea += loss_fea

                if self.local_rank == 0:
                    if self.report_metric_at_train:
                        for metric in self.metrics:
                            metric.update(preds, truths)

                    if self.use_tensorboardX:
                        self.writer.add_scalar("train/loss", loss_val, self.global_step)
                        self.writer.add_scalar(
                            "train/loss_rgb", loss_rgb, self.global_step
                        )
                        self.writer.add_scalar(
                            "train/loss_fea", loss_fea, self.global_step
                        )
                        self.writer.add_scalar(
                            "train/loss_fea_sc", loss_fea_sc, self.global_step
                        )
                        self.writer.add_scalar(
                            "train/loss_color", loss_color, self.global_step
                        )
                        self.writer.add_scalar(
                            "train/loss_sigma", loss_sigma, self.global_step
                        )
                        self.writer.add_scalar(
                            "train/lr",
                            self.optimizer.param_groups[0]["lr"],
                            self.global_step,
                        )

                    if self.scheduler_update_every_step:  # run this
                        # pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}")
                        cur_lr = self.optimizer.param_groups[0]["lr"]
                        pbar.set_description(
                            f"loss={total_loss/self.local_step:.5f}, loss_rgb={total_loss_rgb/self.local_step:.5f}, loss_fea={total_loss_fea/self.local_step:.5f} lr={cur_lr:.5f}"
                        )
                    else:
                        pbar.set_description(
                            f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})"
                        )
                    pbar.update(loader.batch_size)

            # only for vm. FIXME: upsample_resolutions should be set first in main
            if (
                self.opt.model_type == "vm"
                and self.global_step in self.opt.upsample_model_steps
            ):
                # shrink
                if (
                    self.model_stu.cuda_ray
                ):  # and self.global_step == self.opt.upsample_model_steps[0]:
                    self.model_stu.shrink_model()

                # adaptive voxel size from aabb_train
                n_vox = self.upsample_resolutions.pop(0) ** 3  # n_voxels
                aabb = self.model_stu.aabb_train.cpu().numpy()
                vox_size = np.cbrt(np.prod(aabb[3:] - aabb[:3]) / n_vox)
                reso = ((aabb[3:] - aabb[:3]) / vox_size).astype(np.int32).tolist()
                self.log(
                    f"[INFO] upsample model at step {self.global_step} from {self.model_stu.resolution} to {reso}"
                )
                self.model_stu.upsample_model(reso)

                # reset optimizer since params changed.
self.optimizer = self.optimizer_fn(self.model_stu) self.lr_scheduler = self.lr_scheduler_fn(self.optimizer) if self.ema is not None: self.ema.update() average_loss = total_loss / self.local_step self.stats["loss"].append(average_loss) if self.local_rank == 0: pbar.close() if self.report_metric_at_train: for metric in self.metrics: self.log(metric.report(), style="red") if self.use_tensorboardX: metric.write(self.writer, self.epoch, prefix="train") metric.clear() if not self.scheduler_update_every_step: if isinstance( self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau ): self.lr_scheduler.step(average_loss) else: self.lr_scheduler.step() self.log(f"==> Finished Epoch {self.epoch}.") ### ------------------------------ def get_loss(self, pred, gt): if self.opt.loss_type == "L2": loss = torch.mean((gt - pred) ** 2) elif self.opt.loss_type == "normL2": loss = torch.norm(pred - gt) elif self.opt.loss_type == "normL1": loss = torch.norm(pred - gt, p=1) elif self.opt.loss_type == "smoothL1": loss = torch.nn.functional.smooth_l1_loss(pred, gt, beta=0.05) else: raise ValueError("error loss_type") return loss def train_step(self, data): rays_o = data["rays_o"] # [B, N, 3] rays_d = data["rays_d"] # [B, N, 3] # if there is no gt image, we train with CLIP loss. if "images" not in data: assert 1 == 2 B, N = rays_o.shape[:2] H, W = data["H"], data["W"] # currently fix white bg, MUST force all rays! outputs = self.model.render( rays_o, rays_d, staged=False, bg_color=None, perturb=True, force_all_rays=True, **vars(self.opt), ) pred_rgb = ( outputs["image"].reshape(B, H, W, 3).permute(0, 3, 1, 2).contiguous() ) loss = self.clip_loss(pred_rgb) return pred_rgb, None, loss images = data["images"] # [B, N, 3/4] B, N, C = images.shape if self.opt.color_space == "linear": images[..., :3] = srgb_to_linear(images[..., :3]) if C == 3 or self.model_stu.bg_radius > 0: bg_color = 1 # train with random background color if not using a bg model and has alpha channel. 
        else:
            bg_color = torch.rand_like(images[..., :3])  # [N, 3], pixel-wise random.

        if C == 4:
            gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (
                1 - images[..., 3:]
            )
        else:
            gt_rgb = images

        if self.opt.render_stu_first:
            outputs_stu = self.model_stu.render(
                rays_o, rays_d, staged=False, bg_color=bg_color,
                perturb=True, force_all_rays=False, **vars(self.opt),
            )
            pred_rgb_stu = outputs_stu["image"]
            if not self.opt.just_train_a_model:
                with torch.no_grad():
                    outputs_tea = self.model_tea.render(
                        rays_o, rays_d, staged=False, bg_color=bg_color,
                        perturb=True, force_all_rays=False,
                        inherited_params=outputs_stu["inherited_params"],
                        **vars(self.opt),
                    )
                    pred_rgb_tea = outputs_tea["image"]
        else:
            with torch.no_grad():
                outputs_tea = self.model_tea.render(
                    rays_o, rays_d, staged=False, bg_color=bg_color,
                    perturb=True, force_all_rays=False, **vars(self.opt),
                )
                pred_rgb_tea = outputs_tea["image"]
            outputs_stu = self.model_stu.render(
                rays_o, rays_d, staged=False, bg_color=bg_color,
                perturb=True, force_all_rays=False,
                inherited_params=outputs_tea["inherited_params"],
                **vars(self.opt),
            )
            pred_rgb_stu = outputs_stu["image"]

        loss = 0.0
        if self.opt.just_train_a_model:
            pred_rgb = pred_rgb_stu
            loss = loss + self.criterion(pred_rgb, gt_rgb).mean()
            if self.opt.model_type == "vm":
                loss = loss + self.model_stu.density_loss() * self.opt.l1_reg_weight

        return loss, pred_rgb, gt_rgb

    def evaluate(self, loader, name=None):
        self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX
        self.evaluate_one_epoch(loader, name)
        self.use_tensorboardX = use_tensorboardX

    def evaluate_one_epoch(self, loader, name=None):
        self.log(f"++> Evaluate at epoch {self.epoch} ...")
        if name is None:
            name = f"{self.name}_ep{self.epoch:04d}"

        total_loss = 0
        if self.local_rank == 0:
            for metric in self.metrics:
                metric.clear()

        if self.opt.test_teacher:
            self.model_stu = self.model_tea
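The `train_step` above composites RGBA ground truth over the chosen background before comparing it with the rendered prediction: `gt_rgb = rgb * alpha + bg * (1 - alpha)`. A minimal stdlib check of that blend (`composite_over` is a name introduced here for illustration):

```python
def composite_over(rgb, alpha, bg):
    # per-channel over-compositing: out = rgb * alpha + bg * (1 - alpha)
    return [c * alpha + b * (1 - alpha) for c, b in zip(rgb, bg)]

# fully opaque keeps the foreground color; fully transparent shows the background
print(composite_over([0.2, 0.4, 0.6], 1.0, [1.0, 1.0, 1.0]))  # -> [0.2, 0.4, 0.6]
print(composite_over([0.2, 0.4, 0.6], 0.0, [1.0, 1.0, 1.0]))  # -> [1.0, 1.0, 1.0]
```

Training uses a pixel-wise random `bg` when an alpha channel is available, which stops the model from baking a fixed background color into its density field.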
self.model_stu.eval() if self.ema is not None: self.ema.store() self.ema.copy_to() if self.local_rank == 0: pbar = tqdm.tqdm( total=len(loader) * loader.batch_size, bar_format="{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", ) with torch.no_grad(): self.local_step = 0 self.ssim = 0.0 self.lpips_vgg = 0.0 self.lpips_alex = 0.0 # update grid if self.model_stu.cuda_ray: with torch.cuda.amp.autocast(enabled=self.fp16): if self.opt.update_stu_extra: self.model_stu.update_extra_state() else: pass for data in loader: self.local_step += 1 with torch.cuda.amp.autocast(enabled=self.fp16): preds, preds_depth, truths, loss = self.eval_step(data) # all_gather/reduce the statistics (NCCL only support all_*) if self.world_size > 1: dist.all_reduce(loss, op=dist.ReduceOp.SUM) loss = loss / self.world_size preds_list = [ torch.zeros_like(preds).to(self.device) for _ in range(self.world_size) ] # [[B, ...], [B, ...], ...] dist.all_gather(preds_list, preds) preds = torch.cat(preds_list, dim=0) preds_depth_list = [ torch.zeros_like(preds_depth).to(self.device) for _ in range(self.world_size) ] # [[B, ...], [B, ...], ...] dist.all_gather(preds_depth_list, preds_depth) preds_depth = torch.cat(preds_depth_list, dim=0) truths_list = [ torch.zeros_like(truths).to(self.device) for _ in range(self.world_size) ] # [[B, ...], [B, ...], ...] dist.all_gather(truths_list, truths) truths = torch.cat(truths_list, dim=0) loss_val = loss.item() total_loss += loss_val # only rank = 0 will perform evaluation. 
if self.local_rank == 0: for metric in self.metrics: metric.update(preds, truths) self.lpips_alex += rgb_lpips(truths, preds, "alex") self.lpips_vgg += rgb_lpips(truths, preds, "vgg") self.ssim += compute_ssim( preds, truths, max_val=max(preds.max().item(), truths.max().item()), ).item() # save image save_path = os.path.join( self.workspace, "validation", f"{name}_{self.local_step:04d}.png", ) save_path_depth = os.path.join( self.workspace, "validation", f"{name}_{self.local_step:04d}_depth.png", ) save_path_gt = os.path.join( self.workspace, "validation", f"{name}_{self.local_step:04d}_gt.png", ) # self.log(f"==> Saving validation image to {save_path}") os.makedirs(os.path.dirname(save_path), exist_ok=True) if self.opt.color_space == "linear": preds = linear_to_srgb(preds) pred = preds[0].detach().cpu().numpy() pred_depth = preds_depth[0].detach().cpu().numpy() if self.local_step < 15: cv2.imwrite( save_path, cv2.cvtColor( (pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR ), ) cv2.imwrite( save_path_depth, (pred_depth * 255).astype(np.uint8) ) cv2.imwrite( save_path_gt, cv2.cvtColor( (truths[0].detach().cpu().numpy() * 255).astype( np.uint8 ), cv2.COLOR_RGB2BGR, ), ) # cv2.imwrite(save_path_gt, cv2.cvtColor((linear_to_srgb(truths[0].detach().cpu().numpy()) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) pbar.set_description( f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})" ) pbar.update(loader.batch_size) average_loss = total_loss / self.local_step self.stats["valid_loss"].append(average_loss) if self.local_rank == 0: pbar.close() if not self.use_loss_as_metric and len(self.metrics) > 0: result = self.metrics[0].measure() self.stats["results"].append( result if self.best_mode == "min" else -result ) # if max mode, use -result else: self.stats["results"].append( average_loss ) # if no metric, choose best by min loss for metric in self.metrics: self.log(metric.report(), style="blue") psnr = metric.report().split("=")[-1].strip()[:5] self.psnr = float(psnr) if 
self.use_tensorboardX: metric.write(self.writer, self.epoch, prefix="evaluate") metric.clear() self.ssim /= self.local_step self.lpips_alex /= self.local_step self.lpips_vgg /= self.local_step if self.ema is not None: self.ema.restore() # from IPython import embed; embed() print( f"\n psnr:{psnr} ssim:{self.ssim} alex:{self.lpips_alex} vgg:{self.lpips_vgg} \n" ) # cmd = f'mv {self.workspace} {self.workspace}-pnsr{psnr}' # print(cmd) # os.system(cmd) self.log(f"++> Evaluate epoch {self.epoch} Finished.") def eval_step(self, data): rays_o = data["rays_o"] # [B, N, 3] rays_d = data["rays_d"] # [B, N, 3] images = data["images"] # [B, H, W, 3/4] B, H, W, C = images.shape if self.opt.color_space == "linear": images[..., :3] = srgb_to_linear(images[..., :3]) # eval with fixed background color bg_color = 1 if C == 4: gt_rgb = images[..., :3] * images[..., 3:] + bg_color * ( 1 - images[..., 3:] ) else: gt_rgb = images outputs = self.model_stu.render( rays_o, rays_d, staged=True, bg_color=bg_color, perturb=False, **vars(self.opt), ) pred_rgb = outputs["image"].reshape(B, H, W, 3) pred_depth = outputs["depth"].reshape(B, H, W) loss = self.criterion(pred_rgb, gt_rgb).mean() return pred_rgb, pred_depth, gt_rgb, loss def save_checkpoint(self, name=None, full=False, best=False, remove_old=True): full = False if name is None: name = f"{self.name}_ep{self.epoch:04d}" if self.opt.model_type == "vm": state = { "epoch": self.epoch, "global_step": self.global_step, "stats": self.stats, "resolution": self.model_stu.resolution, } else: state = { "epoch": self.epoch, "global_step": self.global_step, "stats": self.stats, } if self.model_stu.cuda_ray: state["mean_count"] = self.model_stu.mean_count state["mean_density"] = self.model_stu.mean_density if full: state["optimizer"] = self.optimizer.state_dict() state["lr_scheduler"] = self.lr_scheduler.state_dict() state["scaler"] = self.scaler.state_dict() if self.ema is not None: state["ema"] = self.ema.state_dict() if not best: state["model"] 
= self.model_stu.state_dict() file_path = f"{self.ckpt_path}/{name}.pth" if remove_old: self.stats["checkpoints"].append(file_path) if len(self.stats["checkpoints"]) > self.max_keep_ckpt: old_ckpt = self.stats["checkpoints"].pop(0) if os.path.exists(old_ckpt): os.remove(old_ckpt) torch.save(state, file_path) else: if len(self.stats["results"]) > 0: if ( self.stats["best_result"] is None or self.stats["results"][-1] < self.stats["best_result"] ): self.log( f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}" ) self.stats["best_result"] = self.stats["results"][-1] # save ema results if self.ema is not None: self.ema.store() self.ema.copy_to() state["model"] = self.model_stu.state_dict() if self.ema is not None: self.ema.restore() torch.save(state, self.best_path) else: self.log( f"[WARN] no evaluated results found, skip saving best checkpoint." ) def load_teacher_checkpoint(self): checkpoint_dict = torch.load(self.opt.ckpt_teacher, map_location=self.device) missing_keys, unexpected_keys = self.model_tea.load_state_dict( checkpoint_dict["model"], strict=False ) self.log("[INFO] loaded teacher model.") if len(missing_keys) > 0: self.log(f"[WARN] missing keys: {missing_keys}") if len(unexpected_keys) > 0: self.log(f"[WARN] unexpected keys: {unexpected_keys}") if self.ema is not None and "ema" in checkpoint_dict: self.ema.load_state_dict(checkpoint_dict["ema"]) # if self.ema is not None and 'ema' in checkpoint_dict: # self.ema.load_state_dict(checkpoint_dict['ema']) if self.model_tea.cuda_ray: if "mean_count" in checkpoint_dict: self.model_tea.mean_count = checkpoint_dict["mean_count"] if "mean_density" in checkpoint_dict: self.model_tea.mean_density = checkpoint_dict["mean_density"] def load_student_checkpoint(self): if self.opt.ckpt_student: checkpoint_dict = torch.load( self.opt.ckpt_student, map_location=self.device ) else: checkpoint_dict = torch.load( self.opt.ckpt_teacher, map_location=self.device ) if self.opt.model_type == "vm" 
and "resolution" in checkpoint_dict: self.model_stu.upsample_model(checkpoint_dict["resolution"]) missing_keys, unexpected_keys = self.model_stu.load_state_dict( checkpoint_dict["model"], strict=False ) self.log("[INFO] loaded student model.") if len(missing_keys) > 0: self.log(f"[WARN] missing keys: {missing_keys}") if len(unexpected_keys) > 0: self.log(f"[WARN] unexpected keys: {unexpected_keys}") if self.model_stu.cuda_ray: if "mean_count" in checkpoint_dict: self.model_stu.mean_count = checkpoint_dict["mean_count"] if "mean_density" in checkpoint_dict: self.model_stu.mean_density = checkpoint_dict["mean_density"] if self.ema is not None and "ema" in checkpoint_dict: self.ema.load_state_dict(checkpoint_dict["ema"]) def test(self, loader, save_path=None, name=None): assert 1 == 2 if save_path is None: save_path = os.path.join(self.workspace, "results") if name is None: name = f"{self.name}_ep{self.epoch:04d}" os.makedirs(save_path, exist_ok=True) self.log(f"==> Start Test, save results to {save_path}") pbar = tqdm.tqdm( total=len(loader) * loader.batch_size, bar_format="{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]", ) self.model_stu.eval() with torch.no_grad(): # update grid if self.model_stu.cuda_ray: with torch.cuda.amp.autocast(enabled=self.fp16): self.model_stu.update_extra_state() for i, data in enumerate(loader): with torch.cuda.amp.autocast(enabled=self.fp16): preds, preds_depth = self.test_step(data) path = os.path.join(save_path, f"{name}_{i:04d}.png") path_depth = os.path.join(save_path, f"{name}_{i:04d}_depth.png") # self.log(f"[INFO] saving test image to {path}") if self.opt.color_space == "linear": preds = linear_to_srgb(preds) pred = preds[0].detach().cpu().numpy() pred_depth = preds_depth[0].detach().cpu().numpy() cv2.imwrite( path, cv2.cvtColor((pred * 255).astype(np.uint8), cv2.COLOR_RGB2BGR) ) cv2.imwrite(path_depth, (pred_depth * 255).astype(np.uint8)) pbar.update(loader.batch_size) self.log(f"==> Finished Test.") 
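`save_checkpoint` above keeps at most `max_keep_ckpt` files by appending each new path to `stats["checkpoints"]` and deleting the oldest one once the limit is exceeded. A filesystem-free sketch of that rotation (`rotate_checkpoints` is a made-up helper name; the real code calls `os.remove` on the evicted path):

```python
def rotate_checkpoints(history, new_path, max_keep):
    # append the newest path; evict the oldest once the limit is exceeded
    history.append(new_path)
    if len(history) > max_keep:
        return history.pop(0)  # caller would os.remove() this path
    return None

h = []
for i in range(4):
    rotate_checkpoints(h, f"ckpt_{i:04d}.pth", max_keep=2)
print(h)  # -> ['ckpt_0002.pth', 'ckpt_0003.pth']
```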
# moved out bg_color and perturb for more flexible control... def test_step(self, data, bg_color=None, perturb=False): rays_o = data["rays_o"] # [B, N, 3] rays_d = data["rays_d"] # [B, N, 3] H, W = data["H"], data["W"] if bg_color is not None: bg_color = bg_color.to(self.device) outputs = self.model_stu.render( rays_o, rays_d, staged=True, bg_color=bg_color, perturb=perturb, **vars(self.opt), ) pred_rgb = outputs["image"].reshape(-1, H, W, 3) pred_depth = outputs["depth"].reshape(-1, H, W) return pred_rgb, pred_depth ================================================ FILE: main_distill_mutual.py ================================================ import torch import os import argparse from distill_mutual.network import NeRFNetwork from functools import partial from time import time from distill_mutual.provider import NeRFDataset from distill_mutual.utils import * from IPython import embed device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def save_codes_env(workspace): path = os.path.join(workspace, "codes_env") os.makedirs(path, exist_ok=True) os.system(f"cp *.py {path}") os.system(f"cp -r raymarching {path}") os.system(f"cp -r distill_mutual {path}") os.system(f"cp -r nerf {path}") def load_from_txt(opt, except_space=""): # except_space = {'workspace', 'teacher_type', 'model_type', 'test', 'test_teacher', 'use_spiral_pose', 'ckpt_teacher'} except_space = {"workspace"} with open( os.path.join(opt.ckpt_teacher.split("checkpoints")[0], "args.txt"), "r" ) as f: # change this path to your own params settings load_args = f.readlines() for i in range(1, len(load_args)): if "(" in load_args[i]: k, v = eval(load_args[i]) else: continue if k in opt and k not in except_space and v != opt.__dict__[k]: print(k, v, opt.__dict__[k]) opt.__dict__[k] = v if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("path", type=str) parser.add_argument( "-O", action="store_true", help="equals --fp16 --cuda_ray --preload" ) 
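`load_from_txt` above restores teacher settings from `args.txt`, which `main` writes as one `str((key, value))` tuple per line and recovers by `eval`-ing every line containing `"("`. A small sketch of that (fragile, eval-based) round-trip, using a made-up option dict:

```python
# args.txt stores one str((key, value)) tuple per line; load_from_txt recovers
# settings by eval-ing lines that contain "(" -- fragile, but matches the code above
opts = {"lr": 0.01, "model_type": "vm"}  # hypothetical options
saved = [str(t) for t in sorted(opts.items())]
print(saved)     # -> ["('lr', 0.01)", "('model_type', 'vm')"]

restored = dict(eval(line) for line in saved if "(" in line)
print(restored)  # -> {'lr': 0.01, 'model_type': 'vm'}
```

Note that `eval` on a config file only makes sense for trusted, self-written checkpoints; `ast.literal_eval` would be the safer drop-in.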
parser.add_argument("--test", action="store_true", help="test mode") parser.add_argument("--workspace", type=str, default="workspace") parser.add_argument("--seed", type=int, default=0) # training options parser.add_argument("--iters", type=int, default=30000, help="training iters") parser.add_argument("--lr", type=float, default=1e-2, help="initial learning rate") parser.add_argument("--ckpt", type=str, default="latest") parser.add_argument( "--num_rays", type=int, default=4096, help="num rays sampled per image for each training step", ) parser.add_argument( "--cuda_ray", action="store_true", help="use CUDA raymarching instead of pytorch", ) parser.add_argument( "--max_steps", type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)", ) parser.add_argument( "--num_steps", type=int, default=512, help="num steps sampled per ray (only valid when NOT using --cuda_ray)", ) parser.add_argument( "--upsample_steps", type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)", ) parser.add_argument( "--update_extra_interval", type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)", ) parser.add_argument( "--max_ray_batch", type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)", ) parser.add_argument( "--fp16", action="store_true", help="use amp mixed precision training" ) parser.add_argument( "--mode", type=str, default="blender", help="dataset mode, supports (colmap, blender)", ) parser.add_argument( "--color_space", type=str, default="srgb", help="Color space, supports (linear, srgb)", ) parser.add_argument( "--preload", action="store_true", help="preload all data into GPU, accelerate training but use more GPU memory", ) parser.add_argument( "--bound", type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.", ) 
parser.add_argument( "--scale", type=float, default=0.8, help="scale camera location into box[-bound, bound]^3", ) parser.add_argument( "--dt_gamma", type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)", ) parser.add_argument( "--min_near", type=float, default=0.2, help="minimum near distance for camera" ) parser.add_argument( "--density_thresh", type=float, default=10, help="threshold for density grid to be occupied", ) parser.add_argument( "--bg_radius", type=float, default=-1, help="if positive, use a background model at sphere(bg_radius)", ) # experimental parser.add_argument( "--error_map", action="store_true", help="use error map to sample rays" ) parser.add_argument( "--clip_text", type=str, default="", help="text input for CLIP guidance" ) parser.add_argument( "--loss_type", type=str, default="normL2", choices=["normL2", "L2", "normL1", "L1"], ) parser.add_argument( "--distill_mode", type=str, default="no_fix_mlp", choices=["fix_mlp", "no_fix_mlp"], help="fix mlp for hash", ) parser.add_argument("--loss_rate_rgb", type=float, default=1.0) parser.add_argument("--loss_rate_fea_sc", type=float, default=0.002) parser.add_argument("--loss_rate_color", type=float, default=0.002) parser.add_argument("--loss_rate_sigma", type=float, default=0.002) parser.add_argument("--l1_reg_weight", type=float, default=1e-4) parser.add_argument("--ckpt_teacher", type=str, default="") parser.add_argument("--ckpt_student", type=str, default="") parser.add_argument("--sigma_clip_min", type=float, default=-2) parser.add_argument("--sigma_clip_max", type=float, default=7) parser.add_argument("--render_stu_first", action="store_true", default=False) parser.add_argument("--use_diagonal_matrix", action="store_true", default=False) parser.add_argument("--test_teacher", action="store_true", default=False) parser.add_argument("--test_metric", action="store_true", default=False) 
parser.add_argument( "--test_type_trainval", action="store_true", default=False ) # XXX parser.add_argument("--PE", type=int, default=10) parser.add_argument("--nerf_layer_num", type=int, default=8) parser.add_argument("--nerf_layer_wide", type=int, default=256) parser.add_argument("--skip", type=int, default=3) parser.add_argument("--residual", type=int, default=3) parser.add_argument("--resolution0", type=int, default=300) parser.add_argument("--resolution1", type=int, default=300) parser.add_argument( "--upsample_model_steps", type=int, action="append", default=[1e10] ) parser.add_argument("--teacher_type", default="hash", type=str) parser.add_argument("--model_type", default="hash", type=str) parser.add_argument( "--data_type", default="synthetic", type=str, choices=["synthetic", "llff", "tank"], ) parser.add_argument("--update_stu_extra", action="store_true", default=False) parser.add_argument("--ema_decay", type=float, default=-1.0) parser.add_argument("--grid_size", type=int, default=128) parser.add_argument("--plenoxel_degree", type=int, default=3) parser.add_argument("--plenoxel_res", type=str, default="[128,128,128]") parser.add_argument("--load_args", action="store_true", default=False) parser.add_argument("--eval_interval_epoch", default=1e5, type=int, help="") parser.add_argument( "--use_real_data_for_train", action="store_true", default=False, ) parser.add_argument("--enable_embed", action="store_true") parser.add_argument("--enable_edit_plenoxel", action="store_true") parser.add_argument( "--stage_iters", type=str, default="{'stage1':2000, 'stage2':5000}" ) opt = parser.parse_args() opt.stage_iters = eval(opt.stage_iters) opt.O = True # always use -O opt.render_stu_first = True if opt.model_type == "mlp": opt.lr *= 0.1 if ( "tensors" == opt.model_type or "tensors" == opt.teacher_type ): # plenoxel have no features opt.stage_iters["stage1"] = -1 save_codes_env(opt.workspace) if opt.load_args: load_from_txt(opt) if opt.O: opt.fp16 = True opt.cuda_ray = 
True opt.preload = True assert opt.model_type in ["hash", "mlp", "vm", "tensors"] assert opt.teacher_type in ["hash", "mlp", "vm", "tensors"] print(opt) seed_everything(opt.seed) model_tea = NeRFNetwork( encoding="hashgrid", bound=opt.bound, cuda_ray=opt.cuda_ray, density_scale=1, min_near=opt.min_near, density_thresh=opt.density_thresh, bg_radius=opt.bg_radius, model_type=opt.teacher_type, args=opt, grid_size=opt.grid_size, is_teacher=True, ) model_stu = NeRFNetwork( encoding="hashgrid", bound=opt.bound, cuda_ray=opt.cuda_ray, density_scale=1, min_near=opt.min_near, density_thresh=opt.density_thresh, bg_radius=opt.bg_radius, model_type=opt.model_type, args=opt, grid_size=opt.grid_size, ) print("\nteacher:", model_tea) print(f"\n{opt.model_type}", model_stu) criterion = torch.nn.MSELoss(reduction="none") # ------------------------------------ test-test-test-test-test ---------------------------------------------- if opt.test or opt.test_teacher or opt.test_type_trainval: trainer = Trainer( f"{opt.teacher_type}2{opt.model_type}", opt, model_tea, model_stu, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=[PSNRMeter()], use_checkpoint=opt.ckpt, ema_decay=opt.ema_decay, ) if opt.test_type_trainval: test_loader = NeRFDataset(opt, device=device, type="trainval").dataloader() else: test_loader = NeRFDataset(opt, device=device, type="test").dataloader() if opt.mode == "blender": trainer.evaluate(test_loader) else: trainer.test(test_loader) # ------------------------------------ train-train-train-train ---------------------------------------------- else: for p in model_tea.parameters(): p.requires_grad = False if opt.distill_mode == "fix_mlp": for n, p in model_stu.named_parameters(): if "sigma_net" in n or "color_net" in n: p.requires_grad = False idx = 1 if opt.model_type == "vm" else 3 optimizer = lambda model_stu: torch.optim.AdamW( model_stu.get_params(opt.lr)[idx:], betas=(0.9, 0.99), eps=1e-15, amsgrad=False, ) else: optimizer = 
lambda model_stu: torch.optim.AdamW( model_stu.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15, amsgrad=False, ) # fake train loader. The real random data for distillating will be generated in utils.py train_loader = NeRFDataset(opt, device=device, type="train").dataloader() opt.iters = opt.iters + opt.iters % len( train_loader ) # will be updated in utils according to the number of random data max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32) scheduler = lambda optimizer: optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=opt.iters * 1, eta_min=5e-5 ) trainer = Trainer( f"{opt.teacher_type}2{opt.model_type}", opt, model_tea, model_stu, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=opt.ema_decay, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=[PSNRMeter()], use_checkpoint=opt.ckpt, eval_interval=opt.eval_interval_epoch, ) upsample_resolutions = ( ( np.round( np.exp( np.linspace( np.log(opt.resolution0), np.log(opt.resolution1), len(opt.upsample_model_steps) + 1, ) ) ) ) .astype(np.int32) .tolist()[1:] ) trainer.upsample_resolutions = upsample_resolutions argstxt = sorted(opt.__dict__.items()) with open(os.path.join(opt.workspace, "args.txt"), "w") as f: for t in argstxt: f.write(str(t) + "\n") start_time = time.time() valid_loader = NeRFDataset( opt, device=device, type="val", downscale=1 ).dataloader() test_loader = NeRFDataset(opt, device=device, type="test").dataloader() trainer.train(train_loader, valid_loader, max_epoch) end_time = time.time() train_time = end_time - start_time print(f"\nusing_time : {train_time:.2f}s\n") # run test data test_loader = NeRFDataset(opt, device=device, type="test").dataloader() print(opt.workspace) trainer.evaluate(test_loader) with open(os.path.join(trainer.workspace, "args.txt"), "a+") as f: txt = f"\npsnr: {trainer.psnr:.2f} \nssim: {trainer.ssim:.3f} \nalex: {trainer.lpips_alex:.3f}\nvgg:{trainer.lpips_vgg:.3f}" f.write(txt) 
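The `upsample_resolutions` schedule above interpolates log-linearly between `resolution0` and `resolution1`, so each upsample step grows the grid by the same multiplicative factor. A stdlib mirror of that `np.exp(np.linspace(np.log(r0), np.log(r1), n + 1))` computation (`upsample_schedule` is a name introduced here):

```python
import math

def upsample_schedule(res0, res1, n_steps):
    # log-linear (geometric) interpolation from res0 to res1,
    # dropping the start point as the code above does with .tolist()[1:]
    pts = [
        math.exp(math.log(res0) + (math.log(res1) - math.log(res0)) * i / n_steps)
        for i in range(n_steps + 1)
    ]
    return [int(round(p)) for p in pts[1:]]

# each step grows by the same factor and the schedule ends exactly at res1
print(upsample_schedule(128, 300, 4))
```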
cmd = f"mv {trainer.workspace} {trainer.workspace}-pnsr{trainer.psnr}" os.system(cmd) ================================================ FILE: main_just_train_tea.py ================================================ import torch import os import argparse from just_train_tea.network import NeRFNetwork from functools import partial from just_train_tea.provider import NeRFDataset from just_train_tea.utils import * from time import time if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("path", type=str) parser.add_argument( "-O", action="store_true", help="equals --fp16 --cuda_ray --preload" ) parser.add_argument("--test", action="store_true", help="test mode") parser.add_argument("--workspace", type=str, default="workspace") parser.add_argument("--seed", type=int, default=0) ### training options parser.add_argument("--iters", type=int, default=40000, help="training iters") parser.add_argument("--lr", type=float, default=1e-2, help="initial learning rate") parser.add_argument("--ckpt", type=str, default="latest") parser.add_argument( "--num_rays", type=int, default=8192, help="num rays sampled per image for each training step", ) parser.add_argument( "--cuda_ray", action="store_true", help="use CUDA raymarching instead of pytorch", ) parser.add_argument( "--max_steps", type=int, default=1024, help="max num steps sampled per ray (only valid when using --cuda_ray)", ) parser.add_argument( "--num_steps", type=int, default=512, help="num steps sampled per ray (only valid when NOT using --cuda_ray)", ) parser.add_argument( "--upsample_steps", type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)", ) parser.add_argument( "--update_extra_interval", type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)", ) parser.add_argument( "--max_ray_batch", type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)", ) 
parser.add_argument( "--fp16", action="store_true", help="use amp mixed precision training" ) parser.add_argument("--ff", action="store_true", help="use fully-fused MLP") parser.add_argument("--tcnn", action="store_true", help="use TCNN backend") parser.add_argument( "--mode", type=str, default="blender", help="dataset mode, supports (colmap, blender)", ) parser.add_argument( "--color_space", type=str, default="srgb", help="Color space, supports (linear, srgb)", ) parser.add_argument( "--preload", action="store_true", help="preload all data into GPU, accelerate training but use more GPU memory", ) # (the default value is for the fox dataset) parser.add_argument( "--bound", type=float, default=1, help="assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.", ) parser.add_argument( "--scale", type=float, default=0.8, help="scale camera location into box[-bound, bound]^3", ) parser.add_argument( "--dt_gamma", type=float, default=0, help="dt_gamma (>=0) for adaptive ray marching. 
set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)", ) parser.add_argument( "--min_near", type=float, default=0.2, help="minimum near distance for camera" ) parser.add_argument( "--density_thresh", type=float, default=10, help="threshold for density grid to be occupied", ) parser.add_argument( "--bg_radius", type=float, default=-1, help="if positive, use a background model at sphere(bg_radius)", ) ### GUI options parser.add_argument("--gui", action="store_true", help="start a GUI") parser.add_argument("--W", type=int, default=1920, help="GUI width") parser.add_argument("--H", type=int, default=1080, help="GUI height") parser.add_argument( "--radius", type=float, default=5, help="default GUI camera radius from center" ) parser.add_argument( "--fovy", type=float, default=50, help="default GUI camera fovy" ) parser.add_argument( "--max_spp", type=int, default=64, help="GUI rendering max sample per pixel" ) ### experimental parser.add_argument( "--error_map", action="store_true", help="use error map to sample rays" ) parser.add_argument( "--clip_text", type=str, default="", help="text input for CLIP guidance" ) parser.add_argument( "--rand_pose", type=int, default=-1, help="<0 uses no rand pose, =0 only uses rand pose, >0 sample one rand pose every $ known poses", ) parser.add_argument( "--distill_mode", type=str, default="no_fix_mlp", choices=["fix_mlp", "no_fix_mlp"], ) parser.add_argument("--loss_rate_rgb", type=float, default=1.0) parser.add_argument("--loss_rate_fea", type=float, default=0.1) parser.add_argument("--loss_rate_fea_sc", type=float, default=0.1) parser.add_argument("--loss_rate_color", type=float, default=0.0) parser.add_argument("--loss_rate_sigma", type=float, default=0) parser.add_argument( "--L1_tensorAB_reg", type=float, default=1e-3, help="reg for tensor_ab" ) parser.add_argument("--l1_reg_weight", type=float, default=1e-4) parser.add_argument("--ckpt_teacher", type=str, default="") 
parser.add_argument("--ckpt_student", type=str, default="") parser.add_argument("--sigma_clip_min", type=float, default=-2) parser.add_argument("--sigma_clip_max", type=float, default=7) parser.add_argument("--use_sigma_clip", action="store_true") parser.add_argument("--render_stu_first", action="store_true", default=False) parser.add_argument("--nerf_pe", action="store_true", default=False) parser.add_argument("--use_real_gt", action="store_true", default=False) parser.add_argument("--use_diagonal_matrix", action="store_true", default=False) parser.add_argument( "--loss_rate_real_gt", type=float, default=0, help="range in [0, 1]" ) parser.add_argument("--test_teacher", action="store_true", default=False) parser.add_argument("--test_metric", action="store_true", default=False) parser.add_argument("--resolution0", type=int, default=300) parser.add_argument("--resolution1", type=int, default=300) parser.add_argument( "--upsample_model_steps", type=int, action="append", default=[1e10] ) parser.add_argument( "--loss_type", type=str, default="L2", choices=["normL2", "L2", "normL1", "L1"] ) parser.add_argument("--PE", type=int, default=10) parser.add_argument("--nerf_layer_num", type=int, default=8) parser.add_argument("--nerf_layer_wide", type=int, default=256) parser.add_argument("--skip", type=int, default=3) parser.add_argument("--residual", type=int, default=3) parser.add_argument("--model_type", default="hash", type=str) parser.add_argument("--teacher_type", default="hash", type=str) parser.add_argument("--use_upsample_vm", action="store_true", default=False) parser.add_argument("--update_stu_extra", action="store_true", default=False) parser.add_argument("--ema_decay", type=float, default=-1) parser.add_argument("--grid_size", type=int, default=128) parser.add_argument("--plenoxel_degree", type=int, default=3) parser.add_argument("--plenoxel_res", type=str, default="[128,128,128]") parser.add_argument("--just_train_a_model", action="store_true", default=False) 
parser.add_argument("--data_type", type=str, default="") opt = parser.parse_args() opt.just_train_a_model = True opt.update_stu_extra = True opt.render_stu_first = True opt.O = True if opt.model_type == "mlp": opt.lr *= 0.1 if opt.O: opt.fp16 = True opt.cuda_ray = True opt.preload = True assert opt.model_type in ["hash", "mlp", "vm", "tensors"] print(opt) seed_everything(opt.seed) model_tea = NeRFNetwork( encoding="hashgrid", bound=opt.bound, cuda_ray=opt.cuda_ray, density_scale=1, min_near=opt.min_near, density_thresh=opt.density_thresh, bg_radius=opt.bg_radius, model_type=opt.teacher_type, args=opt, grid_size=opt.grid_size, is_teacher=True, ) model_stu = NeRFNetwork( encoding="hashgrid", bound=opt.bound, cuda_ray=opt.cuda_ray, density_scale=1, min_near=opt.min_near, density_thresh=opt.density_thresh, bg_radius=opt.bg_radius, model_type=opt.model_type, args=opt, grid_size=opt.grid_size, ) criterion = torch.nn.MSELoss(reduction="none") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if opt.test or opt.test_teacher or opt.test_metric: trainer = Trainer( opt.model_type, opt, model_tea, model_stu, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=[PSNRMeter()], use_checkpoint=opt.ckpt, ema_decay=opt.ema_decay, ) test_loader = NeRFDataset(opt, device=device, type="test").dataloader() trainer.evaluate(test_loader) else: for p in model_tea.parameters(): p.requires_grad = False optimizer = lambda model_stu: torch.optim.AdamW( model_stu.get_params(opt.lr, opt.lr * 0.1), betas=(0.9, 0.99), eps=1e-15, amsgrad=False, ) train_loader = NeRFDataset(opt, device=device, type="train").dataloader() valid_loader = NeRFDataset(opt, device=device, type="val").dataloader() test_loader = NeRFDataset(opt, device=device, type="test").dataloader() if opt.just_train_a_model: scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR( optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1) ) else: scheduler = lambda optimizer: 
optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=opt.iters * 1
        )
    print(scheduler)

    trainer = Trainer(
        opt.model_type,
        opt,
        model_tea,
        model_stu,
        device=device,
        workspace=opt.workspace,
        optimizer=optimizer,
        criterion=criterion,
        ema_decay=opt.ema_decay,
        fp16=opt.fp16,
        lr_scheduler=scheduler,
        scheduler_update_every_step=True,
        metrics=[PSNRMeter()],
        use_checkpoint=opt.ckpt,
        eval_interval=500000000,
    )

    upsample_resolutions = (
        (
            np.round(
                np.exp(
                    np.linspace(
                        np.log(opt.resolution0),
                        np.log(opt.resolution1),
                        len(opt.upsample_model_steps) + 1,
                    )
                )
            )
        )
        .astype(np.int32)
        .tolist()[1:]
    )
    trainer.upsample_resolutions = upsample_resolutions

    argstxt = sorted(opt.__dict__.items())
    with open(os.path.join(opt.workspace, "args.txt"), "w") as f:
        for t in argstxt:
            f.write(str(t) + "\n")

    start_time = time()
    max_epoch = np.ceil(opt.iters / len(train_loader)).astype(np.int32)
    trainer.train(train_loader, valid_loader, max_epoch)
    print(opt.workspace)
    trainer.evaluate(test_loader)

    with open(os.path.join(trainer.workspace, "args.txt"), "a+") as f:
        txt = f"\npsnr: {trainer.psnr:.2f} \nssim: {trainer.ssim:.3f} \nalex: {trainer.lpips_alex:.3f}\nvgg:{trainer.lpips_vgg:.3f}"
        f.write(txt)

    cmd = f"mv {trainer.workspace} {trainer.workspace}-psnr{trainer.psnr}"
    print(f"\n{cmd}\n")
    os.system(cmd)


================================================
FILE: raymarching/__init__.py
================================================
from .raymarching import *


================================================
FILE: raymarching/backend.py
================================================
import os
from torch.utils.cpp_extension import load

_src_path = os.path.dirname(os.path.abspath(__file__))

nvcc_flags = [
    "-O3",
    "-std=c++14",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
]

if os.name == "posix":
    c_flags = ["-O3", "-std=c++14"]
elif os.name == "nt":
    c_flags = ["/O2", "/std:c++17"]

    # find cl.exe
    def find_cl_path():
        import glob

        for edition in ["Enterprise",
"Professional", "BuildTools", "Community"]: paths = sorted( glob.glob( r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition ), reverse=True, ) if paths: return paths[0] # If cl.exe is not on path, try to find it. if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError( "Could not locate a supported Microsoft Visual C++ installation" ) os.environ["PATH"] += ";" + cl_path _backend = load( name="_raymarching", extra_cflags=c_flags, extra_cuda_cflags=nvcc_flags, sources=[ os.path.join(_src_path, "src", f) for f in [ "raymarching.cu", "bindings.cpp", ] ], ) __all__ = ["_backend"] ================================================ FILE: raymarching/raymarching.py ================================================ import numpy as np import time import torch import torch.nn as nn from torch.autograd import Function from torch.cuda.amp import custom_bwd, custom_fwd try: import _raymarching as _backend except ImportError: from .backend import _backend # ---------------------------------------- # utils # ---------------------------------------- class _near_far_from_aabb(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, rays_o, rays_d, aabb, min_near=0.2): """near_far_from_aabb, CUDA implementation Calculate rays' intersection time (near and far) with aabb Args: rays_o: float, [N, 3] rays_d: float, [N, 3] aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax) min_near: float, scalar Returns: nears: float, [N] fars: float, [N] """ if not rays_o.is_cuda: rays_o = rays_o.cuda() if not rays_d.is_cuda: rays_d = rays_d.cuda() rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) N = rays_o.shape[0] # num rays nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) _backend.near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars) return 
nears, fars near_far_from_aabb = _near_far_from_aabb.apply class _polar_from_ray(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, rays_o, rays_d, radius): """polar_from_ray, CUDA implementation get polar coordinate on the background sphere from rays. Assume rays_o are inside the Sphere(radius). Args: rays_o: [N, 3] rays_d: [N, 3] radius: scalar, float Return: coords: [N, 2], in [-1, 1], theta and phi on a sphere. """ if not rays_o.is_cuda: rays_o = rays_o.cuda() if not rays_d.is_cuda: rays_d = rays_d.cuda() rays_o = rays_o.contiguous().view(-1, 3) rays_d = rays_d.contiguous().view(-1, 3) N = rays_o.shape[0] # num rays coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device) _backend.polar_from_ray(rays_o, rays_d, radius, N, coords) return coords polar_from_ray = _polar_from_ray.apply class _morton3D(Function): @staticmethod def forward(ctx, coords): """morton3D, CUDA implementation Args: coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...) TODO: check if the coord range is valid! 
(current 128 is safe) Returns: indices: [N], int32, in [0, 128^3) """ if not coords.is_cuda: coords = coords.cuda() N = coords.shape[0] indices = torch.empty(N, dtype=torch.int32, device=coords.device) _backend.morton3D(coords.int(), N, indices) return indices morton3D = _morton3D.apply class _morton3D_invert(Function): @staticmethod def forward(ctx, indices): """morton3D_invert, CUDA implementation Args: indices: [N], int32, in [0, 128^3) Returns: coords: [N, 3], int32, in [0, 128) """ if not indices.is_cuda: indices = indices.cuda() N = indices.shape[0] coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device) _backend.morton3D_invert(indices.int(), N, coords) return coords morton3D_invert = _morton3D_invert.apply class _packbits(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, grid, thresh, bitfield=None): """packbits, CUDA implementation Pack up the density grid into a bit field to accelerate ray marching. Args: grid: float, [C, H * H * H], assume H % 2 == 0 thresh: float, threshold Returns: bitfield: uint8, [C, H * H * H / 8] """ if not grid.is_cuda: grid = grid.cuda() grid = grid.contiguous() C = grid.shape[0] H3 = grid.shape[1] N = C * H3 // 8 if bitfield is None: bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device) _backend.packbits(grid, N, thresh, bitfield) return bitfield packbits = _packbits.apply # ---------------------------------------- # train functions # ---------------------------------------- class _march_rays_train(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward( ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False, dt_gamma=0, max_steps=1024, ): """march rays to generate points (forward only) Args: rays_o/d: float, [N, 3] bound: float, scalar density_bitfield: uint8: [CHHH // 8] C: int H: int nears/fars: float, [N] step_counter: int32, (2), used to count the actual number 
of generated points.
            mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeds this threshold.)
            perturb: bool
            align: int, pad output so its size is divisible by align, set to -1 to disable.
            force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays.
            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse performance)
            max_steps: int, max number of sampled points along each ray, also affects min_stepsize.
        Returns:
            xyzs: float, [M, 3], all generated points' coords. (all rays concatenated; use `rays` to extract the points belonging to each ray)
            dirs: float, [M, 3], all generated points' view dirs.
            deltas: float, [M, 2], all generated points' deltas. (first for RGB, second for depth)
            rays: int32, [N, 3], all rays' (index, point_offset, point_count), e.g., xyzs[rays[i, 1] : rays[i, 1] + rays[i, 2]] --> points belonging to rays[i, 0]
        """
        if not rays_o.is_cuda:
            rays_o = rays_o.cuda()
        if not rays_d.is_cuda:
            rays_d = rays_d.cuda()
        if not density_bitfield.is_cuda:
            density_bitfield = density_bitfield.cuda()

        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)
        density_bitfield = density_bitfield.contiguous()

        N = rays_o.shape[0]  # num rays
        M = N * max_steps  # init max points number in total

        # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp)
        # It estimates the max points number to enable faster training, but will lead to randomly ignored rays if underestimated.
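The padding arithmetic applied to `mean_count` just below (`mean_count += align - mean_count % align`) rounds the point budget up so the buffer size is divisible by `align`; a standalone sketch of that arithmetic with hypothetical numbers:

```python
def round_up(count: int, align: int) -> int:
    # Same arithmetic as `mean_count += align - mean_count % align`:
    # pad count so it becomes divisible by align. Note that when count is
    # already a multiple of align, this still adds one full extra align.
    return count + align - count % align

# e.g. a measured mean of 100 points padded for align=128 becomes 128
```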
if not force_all_rays and mean_count > 0: if align > 0: mean_count += align - mean_count % align M = mean_count xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) rays = torch.empty( N, 3, dtype=torch.int32, device=rays_o.device ) # id, offset, num_steps if step_counter is None: step_counter = torch.zeros( 2, dtype=torch.int32, device=rays_o.device ) # point counter, ray counter _backend.march_rays_train( rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, perturb, ) # m is the actually used points number # print(step_counter, M) # only used at the first (few) epochs. if force_all_rays or mean_count <= 0: m = step_counter[0].item() # D2H copy if align > 0: m += align - m % align xyzs = xyzs[:m] dirs = dirs[:m] deltas = deltas[:m] torch.cuda.empty_cache() return xyzs, dirs, deltas, rays march_rays_train = _march_rays_train.apply class _composite_rays_train(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, sigmas, rgbs, deltas, rays): """composite rays' rgbs, according to the ray marching formula. Args: rgbs: float, [M, 3] sigmas: float, [M,] deltas: float, [M, 2] rays: int32, [N, 3] Returns: weights_sum: float, [N,], the alpha channel depth: float, [N, ], the Depth image: float, [N, 3], the RGB channel (after multiplying alpha!) 
""" sigmas = sigmas.contiguous() rgbs = rgbs.contiguous() M = sigmas.shape[0] N = rays.shape[0] weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) _backend.composite_rays_train_forward( sigmas, rgbs, deltas, rays, M, N, weights_sum, depth, image ) ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image) ctx.dims = [M, N] return weights_sum, depth, image @staticmethod @custom_bwd def backward(ctx, grad_weights_sum, grad_depth, grad_image): # NOTE: grad_depth is not used now! It won't be propagated to sigmas. grad_weights_sum = grad_weights_sum.contiguous() grad_image = grad_image.contiguous() sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors M, N = ctx.dims grad_sigmas = torch.zeros_like(sigmas) grad_rgbs = torch.zeros_like(rgbs) _backend.composite_rays_train_backward( grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, grad_sigmas, grad_rgbs, ) return grad_sigmas, grad_rgbs, None, None composite_rays_train = _composite_rays_train.apply # ---------------------------------------- # infer functions # ---------------------------------------- class _march_rays(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward( ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, align=-1, perturb=False, dt_gamma=0, max_steps=1024, ): """march rays to generate points (forward only, for inference) Args: n_alive: int, number of alive rays n_step: int, how many steps we march rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) rays_t: float, [N], the alive rays' time, we only use the first n_alive. 
rays_o/d: float, [N, 3]
            bound: float, scalar
            density_bitfield: uint8: [CHHH // 8]
            C: int
            H: int
            near/far: float, [N]
            align: int, pad output so its size is divisible by align, set to -1 to disable.
            perturb: bool/int, int > 0 is used as the random seed.
            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerates ray marching if > 0. (very significant effect, but generally leads to worse performance)
            max_steps: int, max number of sampled points along each ray, also affects min_stepsize.
        Returns:
            xyzs: float, [n_alive * n_step, 3], all generated points' coords
            dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
            deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
        """
        if not rays_o.is_cuda:
            rays_o = rays_o.cuda()
        if not rays_d.is_cuda:
            rays_d = rays_d.cuda()

        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        M = n_alive * n_step

        if align > 0:
            M += align - (M % align)

        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        deltas = torch.zeros(
            M, 2, dtype=rays_o.dtype, device=rays_o.device
        )  # 2 vals, one for rgb, one for depth

        _backend.march_rays(
            n_alive,
            n_step,
            rays_alive,
            rays_t,
            rays_o,
            rays_d,
            bound,
            dt_gamma,
            max_steps,
            C,
            H,
            density_bitfield,
            near,
            far,
            xyzs,
            dirs,
            deltas,
            perturb,
        )

        return xyzs, dirs, deltas


march_rays = _march_rays.apply


class _composite_rays(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # need to cast sigmas & rgbs to float
    def forward(
        ctx,
        n_alive,
        n_step,
        rays_alive,
        rays_t,
        sigmas,
        rgbs,
        deltas,
        weights_sum,
        depth,
        image,
    ):
        """composite rays' rgbs, according to the ray marching formula.
(for inference) Args: n_alive: int, number of alive rays n_step: int, how many steps we march rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) rays_t: float, [N], the alive rays' time, we only use the first n_alive. sigmas: float, [n_alive * n_step,] rgbs: float, [n_alive * n_step, 3] deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth). In-place Outputs: weights_sum: float, [N,], the alpha channel depth: float, [N,], the depth value image: float, [N, 3], the RGB channel (after multiplying alpha!) """ _backend.composite_rays( n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, ) return tuple() composite_rays = _composite_rays.apply class _compact_rays(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward( ctx, n_alive, rays_alive, rays_alive_old, rays_t, rays_t_old, alive_counter ): """compact rays, remove dead rays and reallocate alive rays, to accelerate next ray marching. Args: n_alive: int, number of alive rays rays_alive_old: int, [N] rays_t_old: float, [N], dead rays are marked by rays_t < 0 alive_counter: int, [1], used to count remained alive rays. 
In-place Outputs: rays_alive: int, [N] rays_t: float, [N] """ _backend.compact_rays( n_alive, rays_alive, rays_alive_old, rays_t, rays_t_old, alive_counter ) return tuple() compact_rays = _compact_rays.apply ================================================ FILE: raymarching/setup.py ================================================ import os from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension _src_path = os.path.dirname(os.path.abspath(__file__)) nvcc_flags = [ "-O3", "-std=c++14", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", ] if os.name == "posix": c_flags = ["-O3", "-std=c++14"] elif os.name == "nt": c_flags = ["/O2", "/std:c++17"] # find cl.exe def find_cl_path(): import glob for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: paths = sorted( glob.glob( r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition ), reverse=True, ) if paths: return paths[0] # If cl.exe is not on path, try to find it. if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError( "Could not locate a supported Microsoft Visual C++ installation" ) os.environ["PATH"] += ";" + cl_path """ Usage: python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory) python setup.py install # build extensions and install (copy) to PATH. pip install . # ditto but better (e.g., dependency & metadata handling) python setup.py develop # build extensions and install (symbolic) to PATH. pip install -e . 
# ditto but better (e.g., dependency & metadata handling)
"""
setup(
    name="raymarching",  # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name="_raymarching",  # extension name, import this to use CUDA API
            sources=[
                os.path.join(_src_path, "src", f)
                for f in [
                    "raymarching.cu",
                    "bindings.cpp",
                ]
            ],
            extra_compile_args={
                "cxx": c_flags,
                "nvcc": nvcc_flags,
            },
        ),
    ],
    cmdclass={
        "build_ext": BuildExtension,
    },
)


================================================
FILE: raymarching/src/bindings.cpp
================================================
#include <torch/extension.h>

#include "raymarching.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // utils
    m.def("packbits", &packbits, "packbits (CUDA)");
    m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
    m.def("polar_from_ray", &polar_from_ray, "polar_from_ray (CUDA)");
    m.def("morton3D", &morton3D, "morton3D (CUDA)");
    m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
    // train
    m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
    m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
    m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
    // infer
    m.def("march_rays", &march_rays, "march rays (CUDA)");
    m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
    m.def("compact_rays", &compact_rays, "compact rays (CUDA)");
}


================================================
FILE: raymarching/src/pcg32.h
================================================
/*
 * Tiny self-contained version of the PCG Random Number Generation for C++
 * put together from pieces of the much larger C/C++ codebase.
 * Wenzel Jakob, February 2015
 *
 * The PCG random number generator was developed by Melissa O'Neill
 * <oneill@pcg-random.org>
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For additional information about the PCG random number generation scheme, * including its license and other licensing options, visit * * http://www.pcg-random.org * * Note: This code was modified to work with CUDA by the tiny-cuda-nn authors. */ #pragma once #define PCG32_DEFAULT_STATE 0x853c49e6748fea9bULL #define PCG32_DEFAULT_STREAM 0xda3e39cb94b95bdbULL #define PCG32_MULT 0x5851f42d4c957f2dULL #include #include #include #include #include #include /// PCG32 Pseudorandom number generator struct pcg32 { /// Initialize the pseudorandom number generator with default seed __host__ __device__ pcg32() : state(PCG32_DEFAULT_STATE), inc(PCG32_DEFAULT_STREAM) {} /// Initialize the pseudorandom number generator with the \ref seed() function __host__ __device__ pcg32(uint64_t initstate, uint64_t initseq = 1u) { seed(initstate, initseq); } /** * \brief Seed the pseudorandom number generator * * Specified in two parts: a state initializer and a sequence selection * constant (a.k.a. 
stream id) */ __host__ __device__ void seed(uint64_t initstate, uint64_t initseq = 1) { state = 0U; inc = (initseq << 1u) | 1u; next_uint(); state += initstate; next_uint(); } /// Generate a uniformly distributed unsigned 32-bit random number __host__ __device__ uint32_t next_uint() { uint64_t oldstate = state; state = oldstate * PCG32_MULT + inc; uint32_t xorshifted = (uint32_t) (((oldstate >> 18u) ^ oldstate) >> 27u); uint32_t rot = (uint32_t) (oldstate >> 59u); return (xorshifted >> rot) | (xorshifted << ((~rot + 1u) & 31)); } /// Generate a uniformly distributed number, r, where 0 <= r < bound __host__ __device__ uint32_t next_uint(uint32_t bound) { // To avoid bias, we need to make the range of the RNG a multiple of // bound, which we do by dropping output less than a threshold. // A naive scheme to calculate the threshold would be to do // // uint32_t threshold = 0x100000000ull % bound; // // but 64-bit div/mod is slower than 32-bit div/mod (especially on // 32-bit platforms). In essence, we do // // uint32_t threshold = (0x100000000ull-bound) % bound; // // because this version will calculate the same modulus, but the LHS // value is less than 2^32. uint32_t threshold = (~bound+1u) % bound; // Uniformity guarantees that this loop will terminate. In practice, it // should usually terminate quickly; on average (assuming all bounds are // equally likely), 82.25% of the time, we can expect it to require just // one iteration. In the worst case, someone passes a bound of 2^31 + 1 // (i.e., 2147483649), which invalidates almost 50% of the range. In // practice, bounds are typically small and only a tiny amount of the range // is eliminated. for (;;) { uint32_t r = next_uint(); if (r >= threshold) return r % bound; } } /// Generate a single precision floating point value on the interval [0, 1) __host__ __device__ float next_float() { /* Trick from MTGP: generate an uniformly distributed single precision number in [1,2) and subtract 1. 
*/ union { uint32_t u; float f; } x; x.u = (next_uint() >> 9) | 0x3f800000u; return x.f - 1.0f; } /** * \brief Generate a double precision floating point value on the interval [0, 1) * * \remark Since the underlying random number generator produces 32 bit output, * only the first 32 mantissa bits will be filled (however, the resolution is still * finer than in \ref next_float(), which only uses 23 mantissa bits) */ __host__ __device__ double next_double() { /* Trick from MTGP: generate an uniformly distributed double precision number in [1,2) and subtract 1. */ union { uint64_t u; double d; } x; x.u = ((uint64_t) next_uint() << 20) | 0x3ff0000000000000ULL; return x.d - 1.0; } /** * \brief Multi-step advance function (jump-ahead, jump-back) * * The method used here is based on Brown, "Random Number Generation * with Arbitrary Stride", Transactions of the American Nuclear * Society (Nov. 1994). The algorithm is very similar to fast * exponentiation. * * The default value of 2^32 ensures that the PRNG is advanced * sufficiently far that there is (likely) no overlap with * previously drawn random numbers, even if small advancements. * are made inbetween. */ __host__ __device__ void advance(int64_t delta_ = (1ll<<32)) { uint64_t cur_mult = PCG32_MULT, cur_plus = inc, acc_mult = 1u, acc_plus = 0u; /* Even though delta is an unsigned integer, we can pass a signed integer to go backwards, it just goes "the long way round". 
*/ uint64_t delta = (uint64_t) delta_; while (delta > 0) { if (delta & 1) { acc_mult *= cur_mult; acc_plus = acc_plus * cur_mult + cur_plus; } cur_plus = (cur_mult + 1) * cur_plus; cur_mult *= cur_mult; delta /= 2; } state = acc_mult * state + acc_plus; } /// Compute the distance between two PCG32 pseudorandom number generators __host__ __device__ int64_t operator-(const pcg32 &other) const { assert(inc == other.inc); uint64_t cur_mult = PCG32_MULT, cur_plus = inc, cur_state = other.state, the_bit = 1u, distance = 0u; while (state != cur_state) { if ((state & the_bit) != (cur_state & the_bit)) { cur_state = cur_state * cur_mult + cur_plus; distance |= the_bit; } assert((state & the_bit) == (cur_state & the_bit)); the_bit <<= 1; cur_plus = (cur_mult + 1ULL) * cur_plus; cur_mult *= cur_mult; } return (int64_t) distance; } /// Equality operator __host__ __device__ bool operator==(const pcg32 &other) const { return state == other.state && inc == other.inc; } /// Inequality operator __host__ __device__ bool operator!=(const pcg32 &other) const { return state != other.state || inc != other.inc; } uint64_t state; // RNG state. All values are possible. uint64_t inc; // Controls which RNG sequence (stream) is selected. Must *always* be odd. 
};


================================================
FILE: raymarching/src/raymarching.cu
================================================
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

#include <cstdio>
#include <stdint.h>
#include <stdexcept>
#include <limits>

#include "pcg32.h"

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")

inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
inline constexpr __device__ float PI() { return 3.141592653589793f; }
inline constexpr __device__ float RPI() { return 0.3183098861837907f; }

template <typename T>
inline __host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}

inline __host__ __device__ float signf(const float x) {
    return copysignf(1.0, x);
}

inline __host__ __device__ float clamp(const float x, const float min, const float max) {
    return fminf(max, fmaxf(min, x));
}

inline __host__ __device__ void swapf(float& a, float& b) {
    float c = a; a = b; b = c;
}

inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
    const float mx = fmaxf(fabsf(x), fmaxf(fabs(y), fabs(z)));
    int exponent;
    frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...
    return fminf(max_cascade - 1, fmaxf(0, exponent));
}

inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
    const float mx = dt * H * 0.5;
    int exponent;
    frexpf(mx, &exponent);
    return fminf(max_cascade - 1, fmaxf(0, exponent));
}

inline __host__ __device__ uint32_t __expand_bits(uint32_t v) {
    v = (v * 0x00010001u) & 0xFF0000FFu;
    v = (v * 0x00000101u) & 0x0F00F00Fu;
    v = (v * 0x00000011u) & 0xC30C30C3u;
    v = (v * 0x00000005u) & 0x49249249u;
    return v;
}

inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z) {
    uint32_t xx = __expand_bits(x);
    uint32_t yy = __expand_bits(y);
    uint32_t zz = __expand_bits(z);
    return xx | (yy << 1) | (zz << 2);
}

inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x) {
    x = x & 0x49249249;
    x = (x | (x >> 2)) & 0xc30c30c3;
    x = (x | (x >> 4)) & 0x0f00f00f;
    x = (x | (x >> 8)) & 0xff0000ff;
    x = (x | (x >> 16)) & 0x0000ffff;
    return x;
}

////////////////////////////////////////////////////
/////////////           utils          /////////////
////////////////////////////////////////////////////

// rays_o/d: [N, 3]
// nears/fars: [N]
// scalar_t should always be float in use.
template <typename scalar_t>
__global__ void kernel_near_far_from_aabb(
    const scalar_t * __restrict__ rays_o,
    const scalar_t * __restrict__ rays_d,
    const scalar_t * __restrict__ aabb,
    const uint32_t N,
    const float min_near,
    scalar_t * nears, scalar_t * fars
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    rays_o += n * 3;
    rays_d += n * 3;

    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;

    // get near far (assume cube scene)
    float near = (aabb[0] - ox) * rdx;
    float far = (aabb[3] - ox) * rdx;
    if (near > far) swapf(near, far);

    float near_y = (aabb[1] - oy) * rdy;
    float far_y = (aabb[4] - oy) * rdy;
    if (near_y > far_y) swapf(near_y, far_y);

    if (near > far_y || near_y > far) {
        nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
        return;
    }

    if (near_y > near) near = near_y;
    if (far_y < far) far = far_y;

    float near_z = (aabb[2] - oz) * rdz;
    float far_z = (aabb[5] - oz) * rdz;
    if (near_z > far_z) swapf(near_z, far_z);

    if (near > far_z || near_z > far) {
        nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
        return;
    }

    if (near_z > near) near = near_z;
    if (far_z < far) far = far_z;

    if (near < min_near) near = min_near;

    nears[n] = near;
    fars[n] = far;
}

void near_far_from_aabb(at::Tensor rays_o, at::Tensor rays_d, at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {
    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        rays_o.scalar_type(), "near_far_from_aabb", ([&] {
            kernel_near_far_from_aabb<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>());
        }));
}

// rays_o/d: [N, 3]
// radius: float
// coords: [N, 2]
template <typename scalar_t>
__global__ void kernel_polar_from_ray(
    const scalar_t * __restrict__ rays_o,
    const scalar_t * __restrict__ rays_d,
    const float radius,
    const uint32_t N,
    scalar_t * coords
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    rays_o += n * 3;
    rays_d += n * 3;
    coords += n * 2;

    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;

    // solve t from || o + td || = radius
    const float A = dx * dx + dy * dy + dz * dz;
    const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2
    const float C = ox * ox + oy * oy + oz * oz - radius * radius;
    const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)

    // solve theta, phi (assume y is the up axis)
    const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
    const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)
    const float phi = atan2(z, x); // [-PI, PI)

    // normalize to [-1, 1]
    coords[0] = 2 * theta * RPI() - 1;
    coords[1] = phi * RPI();
}

void polar_from_ray(at::Tensor rays_o, at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {
    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        rays_o.scalar_type(), "polar_from_ray", ([&] {
            kernel_polar_from_ray<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>());
        }));
}

// coords: int32, [N, 3]
// indices: int32, [N]
__global__ void kernel_morton3D(
    const int * __restrict__ coords,
    const uint32_t N,
    int * indices
) {
    // parallel
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    coords += n * 3;
    indices[n] = __morton3D(coords[0], coords[1], coords[2]);
}

void morton3D(at::Tensor coords, const uint32_t N, at::Tensor indices) {
    static constexpr uint32_t N_THREAD = 128;
    kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());
}

// indices: int32, [N]
// coords: int32, [N, 3]
__global__ void kernel_morton3D_invert(
    const int * __restrict__ indices,
    const uint32_t N,
    int * coords
) {
    // parallel
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    coords += n * 3;

    const int ind = indices[n];

    coords[0] = __morton3D_invert(ind >> 0);
    coords[1] = __morton3D_invert(ind >> 1);
    coords[2] = __morton3D_invert(ind >> 2);
}

void morton3D_invert(at::Tensor indices, const uint32_t N, at::Tensor coords) {
    static constexpr uint32_t N_THREAD = 128;
    kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());
}

// grid: float, [C, H, H, H]
// N: int, C * H * H * H / 8
// density_thresh: float
// bitfield: uint8, [N]
template <typename scalar_t>
__global__ void kernel_packbits(
    const scalar_t * __restrict__ grid,
    const uint32_t N,
    const float density_thresh,
    uint8_t * bitfield
) {
    // parallel per byte
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    grid += n * 8;

    uint8_t bits = 0;

    #pragma unroll
    for (uint8_t i = 0; i < 8; i++) {
        bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
    }

    bitfield[n] = bits;
}

void packbits(at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {
    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        grid.scalar_type(), "packbits", ([&] {
            kernel_packbits<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>());
        }));
}

////////////////////////////////////////////////////
/////////////         training         /////////////
////////////////////////////////////////////////////

// rays_o/d: [N, 3]
// grid: [CHHH / 8]
// xyzs, dirs, deltas: [M, 3], [M, 3], [M, 2]
// dirs: [M, 3]
// rays: [N, 3], idx, offset, num_steps
template <typename scalar_t>
__global__ void kernel_march_rays_train(
    const scalar_t * __restrict__ rays_o,
    const scalar_t * __restrict__ rays_d,
    const uint8_t * __restrict__ grid,
    const float bound,
    const float dt_gamma, const uint32_t max_steps,
    const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M,
    const scalar_t* __restrict__ nears,
    const scalar_t* __restrict__ fars,
    scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas,
    int * rays,
    int * counter,
    const uint32_t perturb,
    pcg32 rng
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    rays_o += n * 3;
    rays_d += n * 3;

    // ray marching
    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
    const float rH = 1 / (float)H;

    const float near = nears[n];
    const float far = fars[n];

    const float dt_min = 2 * SQRT3() / max_steps;
    const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;

    float t0 = near;

    if (perturb) {
        rng.advance(n);
        t0 += dt_min * rng.next_float();
    }

    // first pass: estimation of num_steps
    float t = t0;
    uint32_t num_steps = 0;

    //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far);

    while (t < far && num_steps < max_steps) {
        // current point
        const float x = clamp(ox + t * dx, -bound, bound);
        const float y = clamp(oy + t * dy, -bound, bound);
        const float z = clamp(oz + t * dz, -bound, bound);

        const float dt = clamp(t * dt_gamma, dt_min, dt_max);

        // get mip level
        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]

        const float mip_bound = fminf((float)(1 << level), bound);
        const float mip_rbound = 1 / mip_bound;

        // convert to nearest grid position
        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));

        const uint32_t index = level * H * H * H + __morton3D(nx, ny, nz);
        const bool occ = grid[index / 8] & (1 << (index % 8));

        // if occupied, advance a small step, and write to output
        //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, num_steps);

        if (occ) {
            num_steps++;
            t += dt;
        // else, skip a large step (basically skip a voxel grid)
        } else {
            // calc distance to next voxel
            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
            // step until next voxel
            do {
                t += clamp(t * dt_gamma, dt_min, dt_max);
            } while (t < tt);
        }
    }

    //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min);

    // second pass: really locate and write points & dirs
    uint32_t point_index = atomicAdd(counter, num_steps);
    uint32_t ray_index = atomicAdd(counter + 1, 1);

    //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index);

    // write rays
    rays[ray_index * 3] = n;
    rays[ray_index * 3 + 1] = point_index;
    rays[ray_index * 3 + 2] = num_steps;

    if (num_steps == 0) return;
    if (point_index + num_steps >= M) return;

    xyzs += point_index * 3;
    dirs += point_index * 3;
    deltas += point_index * 2;

    t = t0;
    uint32_t step = 0;

    float last_t = t;

    while (t < far && step < num_steps) {
        // current point
        const float x = clamp(ox + t * dx, -bound, bound);
        const float y = clamp(oy + t * dy, -bound, bound);
        const float z = clamp(oz + t * dz, -bound, bound);

        const float dt = clamp(t * dt_gamma, dt_min, dt_max);

        // get mip level
        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]

        const float mip_bound = fminf((float)(1 << level), bound);
        const float mip_rbound = 1 / mip_bound;

        // convert to nearest grid position
        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));

        // query grid
        const uint32_t index = level * H * H * H + __morton3D(nx, ny, nz);
        const bool occ = grid[index / 8] & (1 << (index % 8));

        // if occupied, advance a small step, and write to output
        if (occ) {
            // write step
            xyzs[0] = x;
            xyzs[1] = y;
            xyzs[2] = z;
            dirs[0] = dx;
            dirs[1] = dy;
            dirs[2] = dz;
            t += dt;
            deltas[0] = dt;
            deltas[1] = t - last_t; // used to calc depth
            last_t = t;
            xyzs += 3;
            dirs += 3;
            deltas += 2;
            step++;
        // else, skip a large step (basically skip a voxel grid)
        } else {
            // calc distance to next voxel
            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
            // step until next voxel
            do {
                t += clamp(t * dt_gamma, dt_min, dt_max);
            } while (t < tt);
        }
    }
}

void march_rays_train(at::Tensor rays_o, at::Tensor rays_d, at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, const uint32_t perturb) {
    static constexpr uint32_t N_THREAD = 128;

    pcg32 rng = pcg32{(uint64_t)42}; // hard coded random seed

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    rays_o.scalar_type(), "march_rays_train", ([&] {
        kernel_march_rays_train<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(), bound, dt_gamma, max_steps, N, C, H, M, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), counter.data_ptr<int>(), perturb, rng);
    }));
}

// sigmas: [M]
// rgbs: [M, 3]
// deltas: [M, 2]
// rays: [N, 3], idx, offset, num_steps
// weights_sum: [N], final pixel alpha
// depth: [N,]
// image: [N, 3]
template <typename scalar_t>
__global__ void kernel_composite_rays_train_forward(
    const scalar_t * __restrict__ sigmas,
    const scalar_t * __restrict__ rgbs,
    const scalar_t * __restrict__ deltas,
    const int * __restrict__ rays,
    const uint32_t M, const uint32_t N,
    scalar_t * weights_sum,
    scalar_t * depth,
    scalar_t * image
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    uint32_t index = rays[n * 3];
    uint32_t offset = rays[n * 3 + 1];
    uint32_t num_steps = rays[n * 3 + 2];

    // empty ray, or ray that exceed max step count.
    if (num_steps == 0 || offset + num_steps >= M) {
        weights_sum[index] = 0;
        depth[index] = 0;
        image[index * 3] = 0;
        image[index * 3 + 1] = 0;
        image[index * 3 + 2] = 0;
        return;
    }

    sigmas += offset;
    rgbs += offset * 3;
    deltas += offset * 2;

    // accumulate
    uint32_t step = 0;

    scalar_t T = 1.0f;
    scalar_t r = 0, g = 0, b = 0, ws = 0, t = 0, d = 0;

    while (step < num_steps) {
        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
        const scalar_t weight = alpha * T;

        // minimal remained transmittance
        // NOTE: uncomment it won't affect instant-ngp, but totally breaks TensoRF...
        //if (weight < 1e-4f) break;

        r += weight * rgbs[0];
        g += weight * rgbs[1];
        b += weight * rgbs[2];

        t += deltas[1]; // real delta
        d += weight * t;

        ws += weight;

        T *= 1.0f - alpha;

        //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);

        // locate
        sigmas++;
        rgbs += 3;
        deltas += 2;

        step++;
    }

    //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);

    // write
    weights_sum[index] = ws; // weights_sum
    depth[index] = d;
    image[index * 3] = r;
    image[index * 3 + 1] = g;
    image[index * 3 + 2] = b;
}

void composite_rays_train_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, const uint32_t M, const uint32_t N, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {
    static constexpr uint32_t N_THREAD = 128;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
        kernel_composite_rays_train_forward<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), M, N, weights_sum.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
    }));
}

// grad_weights_sum: [N,]
// grad: [N, 3]
// sigmas: [M]
//
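The forward accumulation (alpha = 1 − exp(−σ·δ), weight = alpha·T) and the `grad_sigmas` expression in the backward kernel can be checked numerically in plain Python. The sketch below is an independent illustration of the same math, not project code, and verifies the analytic gradient against central finite differences:

```python
import math

def composite_forward(sigmas, rgbs, deltas):
    """Per-ray compositing: alpha_i = 1 - exp(-sigma_i * dt_i), w_i = alpha_i * T_i."""
    T, ws = 1.0, 0.0
    color = [0.0, 0.0, 0.0]
    for sigma, rgb, dt in zip(sigmas, rgbs, deltas):
        alpha = 1.0 - math.exp(-sigma * dt)
        w = alpha * T
        for c in range(3):
            color[c] += w * rgb[c]
        ws += w
        T *= 1.0 - alpha
    return color, ws

def grad_sigmas(sigmas, rgbs, deltas, grad_image, grad_ws):
    """Analytic gradient in the form used by the backward kernel:
    dL/dsigma_i = dt_i * (sum_c g_c*(T_{i+1}*c_i - (C_final - C_{<=i}))
                          + g_ws*(T_{i+1} - (ws_final - ws_{<=i})))."""
    color_final, ws_final = composite_forward(sigmas, rgbs, deltas)
    T, ws = 1.0, 0.0
    acc = [0.0, 0.0, 0.0]
    grads = []
    for sigma, rgb, dt in zip(sigmas, rgbs, deltas):
        alpha = 1.0 - math.exp(-sigma * dt)
        w = alpha * T
        for c in range(3):
            acc[c] += w * rgb[c]
        ws += w
        T *= 1.0 - alpha  # T is now T_{i+1}, as in the kernel
        g = sum(grad_image[c] * (T * rgb[c] - (color_final[c] - acc[c])) for c in range(3))
        g += grad_ws * (T - (ws_final - ws))
        grads.append(dt * g)
    return grads
```

A useful sanity check: the accumulated weight equals 1 − exp(−Σσ·δ), since each transmittance factor is exp(−σ_i·δ_i).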
rgbs: [M, 3]
// deltas: [M, 2]
// rays: [N, 3], idx, offset, num_steps
// weights_sum: [N,], weights_sum here
// image: [N, 3]
// grad_sigmas: [M]
// grad_rgbs: [M, 3]
template <typename scalar_t>
__global__ void kernel_composite_rays_train_backward(
    const scalar_t * __restrict__ grad_weights_sum,
    const scalar_t * __restrict__ grad_image,
    const scalar_t * __restrict__ sigmas,
    const scalar_t * __restrict__ rgbs,
    const scalar_t * __restrict__ deltas,
    const int * __restrict__ rays,
    const scalar_t * __restrict__ weights_sum,
    const scalar_t * __restrict__ image,
    const uint32_t M, const uint32_t N,
    scalar_t * grad_sigmas,
    scalar_t * grad_rgbs
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    uint32_t index = rays[n * 3];
    uint32_t offset = rays[n * 3 + 1];
    uint32_t num_steps = rays[n * 3 + 2];

    if (num_steps == 0 || offset + num_steps >= M) return;

    grad_weights_sum += index;
    grad_image += index * 3;
    weights_sum += index;
    image += index * 3;
    sigmas += offset;
    rgbs += offset * 3;
    deltas += offset * 2;
    grad_sigmas += offset;
    grad_rgbs += offset * 3;

    // accumulate
    uint32_t step = 0;

    scalar_t T = 1.0f;
    const scalar_t r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0];
    scalar_t r = 0, g = 0, b = 0, ws = 0;

    while (step < num_steps) {
        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);
        const scalar_t weight = alpha * T;

        //if (weight < 1e-4f) break;

        r += weight * rgbs[0];
        g += weight * rgbs[1];
        b += weight * rgbs[2];
        ws += weight;

        T *= 1.0f - alpha;

        // write grad_rgbs
        grad_rgbs[0] = grad_image[0] * weight;
        grad_rgbs[1] = grad_image[1] * weight;
        grad_rgbs[2] = grad_image[2] * weight;

        // write grad_sigmas
        grad_sigmas[0] = deltas[0] * (
            grad_image[0] * (T * rgbs[0] - (r_final - r)) +
            grad_image[1] * (T * rgbs[1] - (g_final - g)) +
            grad_image[2] * (T * rgbs[2] - (b_final - b)) +
            grad_weights_sum[0] * (T - (ws_final - ws))
        );

        //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);

        // locate
        sigmas++;
        rgbs += 3;
        deltas += 2;
        grad_sigmas++;
        grad_rgbs += 3;

        step++;
    }
}

void composite_rays_train_backward(at::Tensor grad_weights_sum, at::Tensor grad_image, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, at::Tensor weights_sum, at::Tensor image, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {
    static constexpr uint32_t N_THREAD = 128;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
        kernel_composite_rays_train_backward<scalar_t><<<div_round_up(N, N_THREAD), N_THREAD>>>(grad_weights_sum.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(), M, N, grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>());
    }));
}

////////////////////////////////////////////////////
/////////////         inference         ///////////
////////////////////////////////////////////////////

template <typename scalar_t>
__global__ void kernel_march_rays(
    const uint32_t n_alive,
    const uint32_t n_step,
    const int* __restrict__ rays_alive,
    const scalar_t* __restrict__ rays_t,
    const scalar_t* __restrict__ rays_o,
    const scalar_t* __restrict__ rays_d,
    const float bound,
    const float dt_gamma, const uint32_t max_steps,
    const uint32_t C, const uint32_t H,
    const uint8_t * __restrict__ grid,
    const scalar_t* __restrict__ nears,
    const scalar_t* __restrict__ fars,
    scalar_t* xyzs, scalar_t* dirs, scalar_t* deltas,
    const uint32_t perturb,
    pcg32 rng
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= n_alive) return;

    const int index = rays_alive[n]; // ray id
    float t = rays_t[n]; // current ray's t

    // locate
    rays_o += index * 3;
    rays_d += index * 3;
    xyzs += n * n_step * 3;
    dirs += n * n_step * 3;
    deltas += n * n_step * 2;

    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
    const float rH = 1 /
(float)H;

    const float near = nears[index], far = fars[index];

    const float dt_min = 2 * SQRT3() / max_steps;
    const float dt_max = 2 * SQRT3() * (1 << (C - 1)) / H;

    // march for n_step steps, record points
    uint32_t step = 0;

    // introduce some randomness (pass in spp as perturb here)
    if (perturb) {
        rng.advance(n);
        t += dt_min * rng.next_float();
    }

    float last_t = t;

    while (t < far && step < n_step) {
        // current point
        const float x = clamp(ox + t * dx, -bound, bound);
        const float y = clamp(oy + t * dy, -bound, bound);
        const float z = clamp(oz + t * dz, -bound, bound);

        const float dt = clamp(t * dt_gamma, dt_min, dt_max);

        // get mip level
        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]

        const float mip_bound = fminf((float)(1 << level), bound);
        const float mip_rbound = 1 / mip_bound;

        // convert to nearest grid position
        const int nx = clamp(0.5 * (x * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int ny = clamp(0.5 * (y * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int nz = clamp(0.5 * (z * mip_rbound + 1) * H, 0.0f, (float)(H - 1));

        const uint32_t index = level * H * H * H + __morton3D(nx, ny, nz);
        const bool occ = grid[index / 8] & (1 << (index % 8));

        // if occupied, advance a small step, and write to output
        if (occ) {
            // write step
            xyzs[0] = x;
            xyzs[1] = y;
            xyzs[2] = z;
            dirs[0] = dx;
            dirs[1] = dy;
            dirs[2] = dz;
            // calc dt
            t += dt;
            deltas[0] = dt;
            deltas[1] = t - last_t; // used to calc depth
            last_t = t;
            // step
            xyzs += 3;
            dirs += 3;
            deltas += 2;
            step++;
        // else, skip a large step (basically skip a voxel grid)
        } else {
            // calc distance to next voxel
            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - x) * rdx;
            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - y) * rdy;
            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - z) * rdz;
            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
            // step until next voxel
            do {
                t += clamp(t * dt_gamma, dt_min, dt_max);
            } while (t < tt);
        }
    }
}

void march_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor rays_o, at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, at::Tensor grid, at::Tensor near, at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, const uint32_t perturb) {
    static constexpr uint32_t N_THREAD = 128;

    pcg32 rng = pcg32{(uint64_t)perturb};

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    rays_o.scalar_type(), "march_rays", ([&] {
        kernel_march_rays<scalar_t><<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), bound, dt_gamma, max_steps, C, H, grid.data_ptr<uint8_t>(), near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(), xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), perturb, rng);
    }));
}

template <typename scalar_t>
__global__ void kernel_composite_rays(
    const uint32_t n_alive,
    const uint32_t n_step,
    const int* __restrict__ rays_alive,
    scalar_t* rays_t,
    const scalar_t* __restrict__ sigmas,
    const scalar_t* __restrict__ rgbs,
    const scalar_t* __restrict__ deltas,
    scalar_t* weights_sum, scalar_t* depth, scalar_t* image
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= n_alive) return;

    const int index = rays_alive[n]; // ray id
    scalar_t t = rays_t[n]; // current ray's t

    // locate
    sigmas += n * n_step;
    rgbs += n * n_step * 3;
    deltas += n * n_step * 2;

    weights_sum += index;
    depth += index;
    image += index * 3;

    scalar_t weight_sum = weights_sum[0];
    scalar_t d = depth[0];
    scalar_t r = image[0];
    scalar_t g = image[1];
    scalar_t b = image[2];

    // accumulate
    uint32_t step = 0;

    while (step < n_step) {
        // ray is terminated if delta == 0
        if (deltas[0] == 0) break;

        const scalar_t alpha = 1.0f - __expf(- sigmas[0] * deltas[0]);

        /*
        T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
        w_i = alpha_i * T_i
        -->
        T_i = 1 - \sum_{j=0}^{i-1} w_j
        */
        const scalar_t T = 1 - weight_sum;
        const scalar_t weight = alpha * T;
        weight_sum += weight;

        t += deltas[1]; // real delta
        d += weight * t;
        r += weight * rgbs[0];
        g += weight * rgbs[1];
        b += weight * rgbs[2];

        //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);

        // ray is terminated if T is too small
        // NOTE: can significantly accelerate inference!
        if (T < 1e-4) break;

        // locate
        sigmas++;
        rgbs += 3;
        deltas += 2;
        step++;
    }

    //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);

    // rays_t = -1 means ray is terminated early.
    if (step < n_step) {
        rays_t[n] = -1;
    } else {
        rays_t[n] = t;
    }

    weights_sum[0] = weight_sum; // this is the thing I needed!
    depth[0] = d;
    image[0] = r;
    image[1] = g;
    image[2] = b;
}

void composite_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) {
    static constexpr uint32_t N_THREAD = 128;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    image.scalar_type(), "composite_rays", ([&] {
        kernel_composite_rays<scalar_t><<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
    }));
}

template <typename scalar_t>
__global__ void kernel_compact_rays(
    const uint32_t n_alive,
    int* rays_alive,
    const int* __restrict__ rays_alive_old,
    scalar_t* rays_t,
    const scalar_t* __restrict__ rays_t_old,
    int* alive_counter
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= n_alive) return;

    // rays_t_old[n] < 0 means ray died in last composite kernel.
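The compaction step that follows (dropping rays whose `rays_t` was set to −1 by the composite kernel) can be sketched in a few lines of Python. This is an illustration of the kernel's effect, not project code; the CUDA version assigns output slots through an atomic counter, so its output order is not deterministic:

```python
def compact_rays(rays_alive_old, rays_t_old):
    """Keep only rays that are still alive (t >= 0); returns the new
    alive-ray ids, their current t values, and the new alive count."""
    rays_alive, rays_t = [], []
    for idx, t in zip(rays_alive_old, rays_t_old):
        if t >= 0:  # negative t marks a terminated ray
            rays_alive.append(idx)
            rays_t.append(t)
    return rays_alive, rays_t, len(rays_alive)
```

At inference time this keeps each `march_rays` / `composite_rays` round working only on the shrinking set of unfinished rays.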
if (rays_t_old[n] >= 0) { const int index = atomicAdd(alive_counter, 1); rays_alive[index] = rays_alive_old[n]; rays_t[index] = rays_t_old[n]; } } void compact_rays(const uint32_t n_alive, at::Tensor rays_alive, at::Tensor rays_alive_old, at::Tensor rays_t, at::Tensor rays_t_old, at::Tensor alive_counter) { static constexpr uint32_t N_THREAD = 128; AT_DISPATCH_FLOATING_TYPES_AND_HALF( rays_t.scalar_type(), "compact_rays", ([&] { kernel_compact_rays<<>>(n_alive, rays_alive.data_ptr(), rays_alive_old.data_ptr(), rays_t.data_ptr(), rays_t_old.data_ptr(), alive_counter.data_ptr()); })); } ================================================ FILE: raymarching/src/raymarching.h ================================================ #pragma once #include #include void near_far_from_aabb(at::Tensor rays_o, at::Tensor rays_d, at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars); void polar_from_ray(at::Tensor rays_o, at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords); void morton3D(at::Tensor coords, const uint32_t N, at::Tensor indices); void morton3D_invert(at::Tensor indices, const uint32_t N, at::Tensor coords); void packbits(at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield); void march_rays_train(at::Tensor rays_o, at::Tensor rays_d, at::Tensor grid, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const uint32_t M, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, const uint32_t perturb); void composite_rays_train_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, const uint32_t M, const uint32_t N, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); void composite_rays_train_backward(at::Tensor grad_weights_sum, at::Tensor grad_image, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor 
rays, at::Tensor weights_sum, at::Tensor image, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs); void march_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor rays_o, at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, at::Tensor grid, at::Tensor nears, at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, const uint32_t perturb); void composite_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image); void compact_rays(const uint32_t n_alive, at::Tensor rays_alive, at::Tensor rays_alive_old, at::Tensor rays_t, at::Tensor rays_t_old, at::Tensor alive_counter); ================================================ FILE: shencoder/__init__.py ================================================ from .sphere_harmonics import SHEncoder ================================================ FILE: shencoder/backend.py ================================================ import os from torch.utils.cpp_extension import load _src_path = os.path.dirname(os.path.abspath(__file__)) nvcc_flags = [ "-O3", "-std=c++14", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", ] if os.name == "posix": c_flags = ["-O3", "-std=c++14"] elif os.name == "nt": c_flags = ["/O2", "/std:c++17"] # find cl.exe def find_cl_path(): import glob for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: paths = sorted( glob.glob( r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition ), reverse=True, ) if paths: return paths[0] # If cl.exe is not on path, try to find it. 
if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError( "Could not locate a supported Microsoft Visual C++ installation" ) os.environ["PATH"] += ";" + cl_path _backend = load( name="_sh_encoder", extra_cflags=c_flags, extra_cuda_cflags=nvcc_flags, sources=[ os.path.join(_src_path, "src", f) for f in [ "shencoder.cu", "bindings.cpp", ] ], ) __all__ = ["_backend"] ================================================ FILE: shencoder/setup.py ================================================ import os from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension _src_path = os.path.dirname(os.path.abspath(__file__)) nvcc_flags = [ "-O3", "-std=c++14", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", ] if os.name == "posix": c_flags = ["-O3", "-std=c++14"] elif os.name == "nt": c_flags = ["/O2", "/std:c++17"] # find cl.exe def find_cl_path(): import glob for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: paths = sorted( glob.glob( r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition ), reverse=True, ) if paths: return paths[0] # If cl.exe is not on path, try to find it. 
if os.system("where cl.exe >nul 2>nul") != 0: cl_path = find_cl_path() if cl_path is None: raise RuntimeError( "Could not locate a supported Microsoft Visual C++ installation" ) os.environ["PATH"] += ";" + cl_path setup( name="shencoder", # package name, import this to use python API ext_modules=[ CUDAExtension( name="_shencoder", # extension name, import this to use CUDA API sources=[ os.path.join(_src_path, "src", f) for f in [ "shencoder.cu", "bindings.cpp", ] ], extra_compile_args={ "cxx": c_flags, "nvcc": nvcc_flags, }, ), ], cmdclass={ "build_ext": BuildExtension, }, ) ================================================ FILE: shencoder/sphere_harmonics.py ================================================ import numpy as np import torch import torch.nn as nn from torch.autograd import Function from torch.autograd.function import once_differentiable from torch.cuda.amp import custom_bwd, custom_fwd try: import _shencoder as _backend except ImportError: from .backend import _backend class _sh_encoder(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision def forward(ctx, inputs, degree, calc_grad_inputs=False): # inputs: [B, input_dim], float in [-1, 1] # RETURN: [B, F], float inputs = inputs.contiguous() B, input_dim = inputs.shape # batch size, coord dim output_dim = degree ** 2 outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) if calc_grad_inputs: dy_dx = torch.empty( B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device ) else: dy_dx = torch.empty(1, dtype=inputs.dtype, device=inputs.device) _backend.sh_encode_forward( inputs, outputs, B, input_dim, degree, calc_grad_inputs, dy_dx ) ctx.save_for_backward(inputs, dy_dx) ctx.dims = [B, input_dim, degree] ctx.calc_grad_inputs = calc_grad_inputs return outputs @staticmethod # @once_differentiable @custom_bwd def backward(ctx, grad): # grad: [B, C * C] if ctx.calc_grad_inputs: grad = grad.contiguous() inputs, dy_dx = 
ctx.saved_tensors B, input_dim, degree = ctx.dims grad_inputs = torch.zeros_like(inputs) _backend.sh_encode_backward( grad, inputs, B, input_dim, degree, dy_dx, grad_inputs ) return grad_inputs, None, None else: return None, None, None sh_encode = _sh_encoder.apply class SHEncoder(nn.Module): def __init__(self, input_dim=3, degree=4): super().__init__() self.input_dim = input_dim # coord dims, must be 3 self.degree = degree # 0 ~ 4 self.output_dim = degree ** 2 assert self.input_dim == 3, "SH encoder only support input dim == 3" assert ( self.degree > 0 and self.degree <= 8 ), "SH encoder only supports degree in [1, 8]" def __repr__(self): return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" def forward(self, inputs, size=1): # inputs: [..., input_dim], normalized real world positions in [-size, size] # return: [..., degree^2] inputs = inputs / size # [-1, 1] prefix_shape = list(inputs.shape[:-1]) inputs = inputs.reshape(-1, self.input_dim) outputs = sh_encode(inputs, self.degree, inputs.requires_grad) outputs = outputs.reshape(prefix_shape + [self.output_dim]) return outputs ================================================ FILE: shencoder/src/bindings.cpp ================================================ #include #include "shencoder.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("sh_encode_forward", &sh_encode_forward, "SH encode forward (CUDA)"); m.def("sh_encode_backward", &sh_encode_backward, "SH encode backward (CUDA)"); } ================================================ FILE: shencoder/src/shencoder.cu ================================================ #include #include #include #include #include #include #include #include #include #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") #define CHECK_IS_FLOATING(x) 
TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") template __host__ __device__ T div_round_up(T val, T divisor) { return (val + divisor - 1) / divisor; } template __global__ void kernel_sh( const scalar_t * __restrict__ inputs, scalar_t * outputs, uint32_t B, uint32_t D, uint32_t C, const bool calc_grad_inputs, scalar_t * dy_dx ) { const uint32_t b = threadIdx.x + blockIdx.x * blockDim.x; if (b >= B) return; const uint32_t C2 = C * C; // locate inputs += b * D; outputs += b * C2; scalar_t x = inputs[0], y = inputs[1], z = inputs[2]; scalar_t xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z; scalar_t x4=x2*x2, y4=y2*y2, z4=z2*z2; scalar_t x6=x4*x2, y6=y4*y2, z6=z4*z2; auto write_sh = [&]() { outputs[0] = 0.28209479177387814f ; // 1/(2*sqrt(pi)) if (C <= 1) { return; } outputs[1] = -0.48860251190291987f*y ; // -sqrt(3)*y/(2*sqrt(pi)) outputs[2] = 0.48860251190291987f*z ; // sqrt(3)*z/(2*sqrt(pi)) outputs[3] = -0.48860251190291987f*x ; // -sqrt(3)*x/(2*sqrt(pi)) if (C <= 2) { return; } outputs[4] = 1.0925484305920792f*xy ; // sqrt(15)*xy/(2*sqrt(pi)) outputs[5] = -1.0925484305920792f*yz ; // -sqrt(15)*yz/(2*sqrt(pi)) outputs[6] = 0.94617469575755997f*z2 - 0.31539156525251999f ; // sqrt(5)*(3*z2 - 1)/(4*sqrt(pi)) outputs[7] = -1.0925484305920792f*xz ; // -sqrt(15)*xz/(2*sqrt(pi)) outputs[8] = 0.54627421529603959f*x2 - 0.54627421529603959f*y2 ; // sqrt(15)*(x2 - y2)/(4*sqrt(pi)) if (C <= 3) { return; } outputs[9] = 0.59004358992664352f*y*(-3.0f*x2 + y2) ; // sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) outputs[10] = 2.8906114426405538f*xy*z ; // sqrt(105)*xy*z/(2*sqrt(pi)) outputs[11] = 0.45704579946446572f*y*(1.0f - 5.0f*z2) ; // sqrt(42)*y*(1 - 5*z2)/(8*sqrt(pi)) outputs[12] = 0.3731763325901154f*z*(5.0f*z2 - 3.0f) ; // sqrt(7)*z*(5*z2 - 3)/(4*sqrt(pi)) outputs[13] = 0.45704579946446572f*x*(1.0f - 5.0f*z2) ; // sqrt(42)*x*(1 - 5*z2)/(8*sqrt(pi)) 
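The hard-coded coefficients in `kernel_sh` are the standard real spherical-harmonic normalization constants given in the trailing comments. A quick Python check of a few of them (an independent numeric verification, not project code):

```python
import math

SQRT_PI = math.sqrt(math.pi)

# l=0 band: 1/(2*sqrt(pi))
assert math.isclose(1.0 / (2.0 * SQRT_PI), 0.28209479177387814)
# l=1 band: sqrt(3)/(2*sqrt(pi))
assert math.isclose(math.sqrt(3.0) / (2.0 * SQRT_PI), 0.48860251190291987)
# l=2 band: sqrt(15)/(2*sqrt(pi)) for the xy/yz/xz terms
assert math.isclose(math.sqrt(15.0) / (2.0 * SQRT_PI), 1.0925484305920792)
# l=2 band: 3*sqrt(5)/(4*sqrt(pi)) for the z2 coefficient
assert math.isclose(3.0 * math.sqrt(5.0) / (4.0 * SQRT_PI), 0.94617469575755997)
```

With `degree` = C, the kernel writes C² basis values per point, which is why `SHEncoder.output_dim` is `degree ** 2`.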
outputs[14] = 1.4453057213202769f*z*(x2 - y2) ; // sqrt(105)*z*(x2 - y2)/(4*sqrt(pi)) outputs[15] = 0.59004358992664352f*x*(-x2 + 3.0f*y2) ; // sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) if (C <= 4) { return; } outputs[16] = 2.5033429417967046f*xy*(x2 - y2) ; // 3*sqrt(35)*xy*(x2 - y2)/(4*sqrt(pi)) outputs[17] = 1.7701307697799304f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(70)*yz*(-3*x2 + y2)/(8*sqrt(pi)) outputs[18] = 0.94617469575756008f*xy*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*xy*(7*z2 - 1)/(4*sqrt(pi)) outputs[19] = 0.66904654355728921f*yz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*yz*(3 - 7*z2)/(8*sqrt(pi)) outputs[20] = -3.1735664074561294f*z2 + 3.7024941420321507f*z4 + 0.31735664074561293f ; // 3*(-30*z2 + 35*z4 + 3)/(16*sqrt(pi)) outputs[21] = 0.66904654355728921f*xz*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*xz*(3 - 7*z2)/(8*sqrt(pi)) outputs[22] = 0.47308734787878004f*(x2 - y2)*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*(x2 - y2)*(7*z2 - 1)/(8*sqrt(pi)) outputs[23] = 1.7701307697799304f*xz*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*xz*(-x2 + 3*y2)/(8*sqrt(pi)) outputs[24] = -3.7550144126950569f*x2*y2 + 0.62583573544917614f*x4 + 0.62583573544917614f*y4 ; // 3*sqrt(35)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) if (C <= 5) { return; } outputs[25] = 0.65638205684017015f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(154)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) outputs[26] = 8.3026492595241645f*xy*z*(x2 - y2) ; // 3*sqrt(385)*xy*z*(x2 - y2)/(4*sqrt(pi)) outputs[27] = -0.48923829943525038f*y*(3.0f*x2 - y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*y*(3*x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) outputs[28] = 4.7935367849733241f*xy*z*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xy*z*(3*z2 - 1)/(4*sqrt(pi)) outputs[29] = 0.45294665119569694f*y*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*y*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) outputs[30] = 0.1169503224534236f*z*(-70.0f*z2 + 63.0f*z4 + 15.0f) ; // sqrt(11)*z*(-70*z2 + 63*z4 + 15)/(16*sqrt(pi)) outputs[31] = 0.45294665119569694f*x*(14.0f*z2 - 21.0f*z4 - 1.0f) ; // sqrt(165)*x*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) outputs[32] = 
2.3967683924866621f*z*(x2 - y2)*(3.0f*z2 - 1.0f) ; // sqrt(1155)*z*(x2 - y2)*(3*z2 - 1)/(8*sqrt(pi)) outputs[33] = -0.48923829943525038f*x*(x2 - 3.0f*y2)*(9.0f*z2 - 1.0f) ; // -sqrt(770)*x*(x2 - 3*y2)*(9*z2 - 1)/(32*sqrt(pi)) outputs[34] = 2.0756623148810411f*z*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) outputs[35] = 0.65638205684017015f*x*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 3*sqrt(154)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) if (C <= 6) { return; } outputs[36] = 1.3663682103838286f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // sqrt(6006)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) outputs[37] = 2.3666191622317521f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) outputs[38] = 2.0182596029148963f*xy*(x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*xy*(x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) outputs[39] = -0.92120525951492349f*yz*(3.0f*x2 - y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*yz*(3*x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) outputs[40] = 0.92120525951492349f*xy*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*xy*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) outputs[41] = 0.58262136251873131f*yz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*yz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) outputs[42] = 6.6747662381009842f*z2 - 20.024298714302954f*z4 + 14.684485723822165f*z6 - 0.31784601133814211f ; // sqrt(13)*(105*z2 - 315*z4 + 231*z6 - 5)/(32*sqrt(pi)) outputs[43] = 0.58262136251873131f*xz*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*xz*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) outputs[44] = 0.46060262975746175f*(x2 - y2)*(11.0f*z2*(3.0f*z2 - 1.0f) - 7.0f*z2 + 1.0f) ; // sqrt(2730)*(x2 - y2)*(11*z2*(3*z2 - 1) - 7*z2 + 1)/(64*sqrt(pi)) outputs[45] = -0.92120525951492349f*xz*(x2 - 3.0f*y2)*(11.0f*z2 - 3.0f) ; // -sqrt(2730)*xz*(x2 - 3*y2)*(11*z2 - 3)/(32*sqrt(pi)) outputs[46] = 0.50456490072872406f*(11.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(91)*(11*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) outputs[47] = 2.3666191622317521f*xz*(10.0f*x2*y2 
- x4 - 5.0f*y4) ; // 3*sqrt(2002)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi)) outputs[48] = 10.247761577878714f*x2*y4 - 10.247761577878714f*x4*y2 + 0.6831841051919143f*x6 - 0.6831841051919143f*y6 ; // sqrt(6006)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) if (C <= 7) { return; } outputs[49] = 0.70716273252459627f*y*(-21.0f*x2*y4 + 35.0f*x4*y2 - 7.0f*x6 + y6) ; // 3*sqrt(715)*y*(-21*x2*y4 + 35*x4*y2 - 7*x6 + y6)/(64*sqrt(pi)) outputs[50] = 5.2919213236038001f*xy*z*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*z*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) outputs[51] = -0.51891557872026028f*y*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // -3*sqrt(385)*y*(13*z2 - 1)*(-10*x2*y2 + 5*x4 + y4)/(64*sqrt(pi)) outputs[52] = 4.1513246297620823f*xy*z*(x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xy*z*(x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) outputs[53] = -0.15645893386229404f*y*(3.0f*x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*y*(3*x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) outputs[54] = 0.44253269244498261f*xy*z*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xy*z*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) outputs[55] = 0.090331607582517306f*y*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*y*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) outputs[56] = 0.068284276912004949f*z*(315.0f*z2 - 693.0f*z4 + 429.0f*z6 - 35.0f) ; // sqrt(15)*z*(315*z2 - 693*z4 + 429*z6 - 35)/(32*sqrt(pi)) outputs[57] = 0.090331607582517306f*x*(-135.0f*z2 + 495.0f*z4 - 429.0f*z6 + 5.0f) ; // sqrt(105)*x*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) outputs[58] = 0.07375544874083044f*z*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*z*(x2 - y2)*(143*z2*(3*z2 - 1) - 187*z2 + 45)/(64*sqrt(pi)) outputs[59] = -0.15645893386229404f*x*(x2 - 3.0f*y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -3*sqrt(35)*x*(x2 - 3*y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) outputs[60] = 1.0378311574405206f*z*(13.0f*z2 - 
3.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 3*sqrt(385)*z*(13*z2 - 3)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi)) outputs[61] = -0.51891557872026028f*x*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // -3*sqrt(385)*x*(13*z2 - 1)*(-10*x2*y2 + x4 + 5*y4)/(64*sqrt(pi)) outputs[62] = 2.6459606618019f*z*(15.0f*x2*y4 - 15.0f*x4*y2 + x6 - y6) ; // 3*sqrt(10010)*z*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi)) outputs[63] = 0.70716273252459627f*x*(-35.0f*x2*y4 + 21.0f*x4*y2 - x6 + 7.0f*y6) ; // 3*sqrt(715)*x*(-35*x2*y4 + 21*x4*y2 - x6 + 7*y6)/(64*sqrt(pi)) }; write_sh(); if (calc_grad_inputs) { scalar_t *dx = dy_dx + b * D * C2; scalar_t *dy = dx + C2; scalar_t *dz = dy + C2; auto write_sh_dx = [&]() { dx[0] = 0.0f ; // 0 if (C <= 1) { return; } dx[1] = 0.0f ; // 0 dx[2] = 0.0f ; // 0 dx[3] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi)) if (C <= 2) { return; } dx[4] = 1.0925484305920792f*y ; // sqrt(15)*y/(2*sqrt(pi)) dx[5] = 0.0f ; // 0 dx[6] = 0.0f ; // 0 dx[7] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi)) dx[8] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi)) if (C <= 3) { return; } dx[9] = -3.5402615395598609f*xy ; // -3*sqrt(70)*xy/(4*sqrt(pi)) dx[10] = 2.8906114426405538f*yz ; // sqrt(105)*yz/(2*sqrt(pi)) dx[11] = 0.0f ; // 0 dx[12] = 0.0f ; // 0 dx[13] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) dx[14] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi)) dx[15] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) if (C <= 4) { return; } dx[16] = 2.5033429417967046f*y*(3.0f*x2 - y2) ; // 3*sqrt(35)*y*(3*x2 - y2)/(4*sqrt(pi)) dx[17] = -10.620784618679583f*xy*z ; // -9*sqrt(70)*xy*z/(4*sqrt(pi)) dx[18] = 0.94617469575756008f*y*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*y*(7*z2 - 1)/(4*sqrt(pi)) dx[19] = 0.0f ; // 0 dx[20] = 0.0f ; // 0 dx[21] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) dx[22] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 
1)/(4*sqrt(pi)) dx[23] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) dx[24] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) if (C <= 5) { return; } dx[25] = 13.127641136803401f*xy*(-x2 + y2) ; // 15*sqrt(154)*xy*(-x2 + y2)/(8*sqrt(pi)) dx[26] = 8.3026492595241645f*yz*(3.0f*x2 - y2) ; // 3*sqrt(385)*yz*(3*x2 - y2)/(4*sqrt(pi)) dx[27] = 2.9354297966115022f*xy*(1.0f - 9.0f*z2) ; // 3*sqrt(770)*xy*(1 - 9*z2)/(16*sqrt(pi)) dx[28] = 4.7935367849733241f*yz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*yz*(3*z2 - 1)/(4*sqrt(pi)) dx[29] = 0.0f ; // 0 dx[30] = 0.0f ; // 0 dx[31] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) dx[32] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) dx[33] = -13.209434084751759f*x2*z2 + 1.4677148983057511f*x2 + 13.209434084751759f*y2*z2 - 1.4677148983057511f*y2 ; // 3*sqrt(770)*(-9*x2*z2 + x2 + 9*y2*z2 - y2)/(32*sqrt(pi)) dx[34] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) dx[35] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) if (C <= 6) { return; } dx[36] = 4.0991046311514854f*y*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 3*sqrt(6006)*y*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) dx[37] = 47.332383244635047f*xy*z*(-x2 + y2) ; // 15*sqrt(2002)*xy*z*(-x2 + y2)/(8*sqrt(pi)) dx[38] = 2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) dx[39] = 5.5272315570895412f*xy*z*(3.0f - 11.0f*z2) ; // 3*sqrt(2730)*xy*z*(3 - 11*z2)/(16*sqrt(pi)) dx[40] = 0.92120525951492349f*y*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*y*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) dx[41] = 0.0f ; // 0 dx[42] = 0.0f ; // 0 dx[43] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) dx[44] = 
0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) dx[45] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) dx[46] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) dx[47] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) dx[48] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) if (C <= 7) { return; } dx[49] = 9.9002782553443485f*xy*(10.0f*x2*y2 - 3.0f*x4 - 3.0f*y4) ; // 21*sqrt(715)*xy*(10*x2*y2 - 3*x4 - 3*y4)/(32*sqrt(pi)) dx[50] = 15.875763970811402f*yz*(-10.0f*x2*y2 + 5.0f*x4 + y4) ; // 9*sqrt(10010)*yz*(-10*x2*y2 + 5*x4 + y4)/(32*sqrt(pi)) dx[51] = -10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // -15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) dx[52] = 4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) dx[53] = 0.93875360317376422f*xy*(66.0f*z2 - 143.0f*z4 - 3.0f) ; // 9*sqrt(35)*xy*(66*z2 - 143*z4 - 3)/(32*sqrt(pi)) dx[54] = 0.44253269244498261f*yz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*yz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) dx[55] = 0.0f ; // 0 dx[56] = 0.0f ; // 0 dx[57] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) dx[58] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) dx[59] = 30.97886890473422f*x2*z2 - 67.120882626924143f*x2*z4 - 1.4081304047606462f*x2 - 30.97886890473422f*y2*z2 + 67.120882626924143f*y2*z4 + 1.4081304047606462f*y2 ; // 9*sqrt(35)*(66*x2*z2 - 143*x2*z4 - 3*x2 - 66*y2*z2 + 143*y2*z4 + 3*y2)/(64*sqrt(pi)) dx[60] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 
3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) dx[61] = -0.51891557872026028f*(13.0f*z2 - 1.0f)*(-10.0f*x2*y2 + 4.0f*x2*(x2 - 5.0f*y2) + x4 + 5.0f*y4) ; // -3*sqrt(385)*(13*z2 - 1)*(-10*x2*y2 + 4*x2*(x2 - 5*y2) + x4 + 5*y4)/(64*sqrt(pi)) dx[62] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) dx[63] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) }; auto write_sh_dy = [&]() { dy[0] = 0.0f ; // 0 if (C <= 1) { return; } dy[1] = -0.48860251190291992f ; // -sqrt(3)/(2*sqrt(pi)) dy[2] = 0.0f ; // 0 dy[3] = 0.0f ; // 0 if (C <= 2) { return; } dy[4] = 1.0925484305920792f*x ; // sqrt(15)*x/(2*sqrt(pi)) dy[5] = -1.0925484305920792f*z ; // -sqrt(15)*z/(2*sqrt(pi)) dy[6] = 0.0f ; // 0 dy[7] = 0.0f ; // 0 dy[8] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi)) if (C <= 3) { return; } dy[9] = -1.7701307697799304f*x2 + 1.7701307697799304f*y2 ; // 3*sqrt(70)*(-x2 + y2)/(8*sqrt(pi)) dy[10] = 2.8906114426405538f*xz ; // sqrt(105)*xz/(2*sqrt(pi)) dy[11] = 0.45704579946446572f - 2.2852289973223288f*z2 ; // sqrt(42)*(1 - 5*z2)/(8*sqrt(pi)) dy[12] = 0.0f ; // 0 dy[13] = 0.0f ; // 0 dy[14] = -2.8906114426405538f*yz ; // -sqrt(105)*yz/(2*sqrt(pi)) dy[15] = 3.5402615395598609f*xy ; // 3*sqrt(70)*xy/(4*sqrt(pi)) if (C <= 4) { return; } dy[16] = 2.5033429417967046f*x*(x2 - 3.0f*y2) ; // 3*sqrt(35)*x*(x2 - 3*y2)/(4*sqrt(pi)) dy[17] = 5.3103923093397913f*z*(-x2 + y2) ; // 9*sqrt(70)*z*(-x2 + y2)/(8*sqrt(pi)) dy[18] = 0.94617469575756008f*x*(7.0f*z2 - 1.0f) ; // 3*sqrt(5)*x*(7*z2 - 1)/(4*sqrt(pi)) dy[19] = 0.66904654355728921f*z*(3.0f - 7.0f*z2) ; // 3*sqrt(10)*z*(3 - 7*z2)/(8*sqrt(pi)) dy[20] = 0.0f ; // 0 dy[21] = 0.0f ; // 0 dy[22] = 0.94617469575756008f*y*(1.0f - 7.0f*z2) ; // 3*sqrt(5)*y*(1 - 7*z2)/(4*sqrt(pi)) dy[23] = 10.620784618679583f*xy*z ; // 9*sqrt(70)*xy*z/(4*sqrt(pi)) dy[24] = 
2.5033429417967046f*y*(-3.0f*x2 + y2) ; // 3*sqrt(35)*y*(-3*x2 + y2)/(4*sqrt(pi)) if (C <= 5) { return; } dy[25] = 19.6914617052051f*x2*y2 - 3.2819102842008503f*x4 - 3.2819102842008503f*y4 ; // 15*sqrt(154)*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) dy[26] = 8.3026492595241645f*xz*(x2 - 3.0f*y2) ; // 3*sqrt(385)*xz*(x2 - 3*y2)/(4*sqrt(pi)) dy[27] = -1.4677148983057511f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // -3*sqrt(770)*(x2 - y2)*(9*z2 - 1)/(32*sqrt(pi)) dy[28] = 4.7935367849733241f*xz*(3.0f*z2 - 1.0f) ; // sqrt(1155)*xz*(3*z2 - 1)/(4*sqrt(pi)) dy[29] = 6.3412531167397574f*z2 - 9.5118796751096362f*z4 - 0.45294665119569694f ; // sqrt(165)*(14*z2 - 21*z4 - 1)/(16*sqrt(pi)) dy[30] = 0.0f ; // 0 dy[31] = 0.0f ; // 0 dy[32] = 4.7935367849733241f*yz*(1.0f - 3.0f*z2) ; // sqrt(1155)*yz*(1 - 3*z2)/(4*sqrt(pi)) dy[33] = 2.9354297966115022f*xy*(9.0f*z2 - 1.0f) ; // 3*sqrt(770)*xy*(9*z2 - 1)/(16*sqrt(pi)) dy[34] = 8.3026492595241645f*yz*(-3.0f*x2 + y2) ; // 3*sqrt(385)*yz*(-3*x2 + y2)/(4*sqrt(pi)) dy[35] = 13.127641136803401f*xy*(x2 - y2) ; // 15*sqrt(154)*xy*(x2 - y2)/(8*sqrt(pi)) if (C <= 6) { return; } dy[36] = 4.0991046311514854f*x*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 3*sqrt(6006)*x*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) dy[37] = 11.833095811158762f*z*(6.0f*x2*y2 - x4 - y4) ; // 15*sqrt(2002)*z*(6*x2*y2 - x4 - y4)/(32*sqrt(pi)) dy[38] = 2.0182596029148963f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // 3*sqrt(91)*x*(x2 - 3*y2)*(11*z2 - 1)/(8*sqrt(pi)) dy[39] = -2.7636157785447706f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // -3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(32*sqrt(pi)) dy[40] = 0.92120525951492349f*x*(-18.0f*z2 + 33.0f*z4 + 1.0f) ; // sqrt(2730)*x*(-18*z2 + 33*z4 + 1)/(32*sqrt(pi)) dy[41] = 0.58262136251873131f*z*(30.0f*z2 - 33.0f*z4 - 5.0f) ; // sqrt(273)*z*(30*z2 - 33*z4 - 5)/(16*sqrt(pi)) dy[42] = 0.0f ; // 0 dy[43] = 0.0f ; // 0 dy[44] = 0.92120525951492349f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // sqrt(2730)*y*(18*z2 - 33*z4 - 1)/(32*sqrt(pi)) dy[45] = 5.5272315570895412f*xy*z*(11.0f*z2 - 3.0f) ; // 
3*sqrt(2730)*xy*z*(11*z2 - 3)/(16*sqrt(pi)) dy[46] = -2.0182596029148963f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(91)*y*(3*x2 - y2)*(11*z2 - 1)/(8*sqrt(pi)) dy[47] = 47.332383244635047f*xy*z*(x2 - y2) ; // 15*sqrt(2002)*xy*z*(x2 - y2)/(8*sqrt(pi)) dy[48] = 4.0991046311514854f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(6006)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) if (C <= 7) { return; } dy[49] = -74.252086915082614f*x2*y4 + 74.252086915082614f*x4*y2 - 4.9501391276721742f*x6 + 4.9501391276721742f*y6 ; // 21*sqrt(715)*(-15*x2*y4 + 15*x4*y2 - x6 + y6)/(64*sqrt(pi)) dy[50] = 15.875763970811402f*xz*(-10.0f*x2*y2 + x4 + 5.0f*y4) ; // 9*sqrt(10010)*xz*(-10*x2*y2 + x4 + 5*y4)/(32*sqrt(pi)) dy[51] = 0.51891557872026028f*(13.0f*z2 - 1.0f)*(10.0f*x2*y2 - 5.0f*x4 + 4.0f*y2*(5.0f*x2 - y2) - y4) ; // 3*sqrt(385)*(13*z2 - 1)*(10*x2*y2 - 5*x4 + 4*y2*(5*x2 - y2) - y4)/(64*sqrt(pi)) dy[52] = 4.1513246297620823f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // 3*sqrt(385)*xz*(x2 - 3*y2)*(13*z2 - 3)/(8*sqrt(pi)) dy[53] = -0.46937680158688211f*(x2 - y2)*(13.0f*z2*(11.0f*z2 - 3.0f) - 27.0f*z2 + 3.0f) ; // -9*sqrt(35)*(x2 - y2)*(13*z2*(11*z2 - 3) - 27*z2 + 3)/(64*sqrt(pi)) dy[54] = 0.44253269244498261f*xz*(-110.0f*z2 + 143.0f*z4 + 15.0f) ; // 3*sqrt(70)*xz*(-110*z2 + 143*z4 + 15)/(32*sqrt(pi)) dy[55] = -12.194767023639836f*z2 + 44.714145753346067f*z4 - 38.752259652899923f*z6 + 0.45165803791258652f ; // sqrt(105)*(-135*z2 + 495*z4 - 429*z6 + 5)/(64*sqrt(pi)) dy[56] = 0.0f ; // 0 dy[57] = 0.0f ; // 0 dy[58] = 0.44253269244498261f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 3*sqrt(70)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi)) dy[59] = 0.93875360317376422f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 9*sqrt(35)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi)) dy[60] = -4.1513246297620823f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -3*sqrt(385)*yz*(3*x2 - y2)*(13*z2 - 3)/(8*sqrt(pi)) dy[61] = 10.378311574405206f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 15*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(16*sqrt(pi)) dy[62] = 
15.875763970811402f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 9*sqrt(10010)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) dy[63] = 9.9002782553443485f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 21*sqrt(715)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi)) }; auto write_sh_dz = [&]() { dz[0] = 0.0f ; // 0 if (C <= 1) { return; } dz[1] = 0.0f ; // 0 dz[2] = 0.48860251190291992f ; // sqrt(3)/(2*sqrt(pi)) dz[3] = 0.0f ; // 0 if (C <= 2) { return; } dz[4] = 0.0f ; // 0 dz[5] = -1.0925484305920792f*y ; // -sqrt(15)*y/(2*sqrt(pi)) dz[6] = 1.8923493915151202f*z ; // 3*sqrt(5)*z/(2*sqrt(pi)) dz[7] = -1.0925484305920792f*x ; // -sqrt(15)*x/(2*sqrt(pi)) dz[8] = 0.0f ; // 0 if (C <= 3) { return; } dz[9] = 0.0f ; // 0 dz[10] = 2.8906114426405538f*xy ; // sqrt(105)*xy/(2*sqrt(pi)) dz[11] = -4.5704579946446566f*yz ; // -5*sqrt(42)*yz/(4*sqrt(pi)) dz[12] = 5.597644988851731f*z2 - 1.1195289977703462f ; // 3*sqrt(7)*(5*z2 - 1)/(4*sqrt(pi)) dz[13] = -4.5704579946446566f*xz ; // -5*sqrt(42)*xz/(4*sqrt(pi)) dz[14] = 1.4453057213202769f*x2 - 1.4453057213202769f*y2 ; // sqrt(105)*(x2 - y2)/(4*sqrt(pi)) dz[15] = 0.0f ; // 0 if (C <= 4) { return; } dz[16] = 0.0f ; // 0 dz[17] = 1.7701307697799304f*y*(-3.0f*x2 + y2) ; // 3*sqrt(70)*y*(-3*x2 + y2)/(8*sqrt(pi)) dz[18] = 13.246445740605839f*xy*z ; // 21*sqrt(5)*xy*z/(2*sqrt(pi)) dz[19] = 2.0071396306718676f*y*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*y*(1 - 7*z2)/(8*sqrt(pi)) dz[20] = 14.809976568128603f*pow(z, 3) - 6.3471328149122579f*z ; // (105*z**3 - 45*z)/(4*sqrt(pi)) dz[21] = 2.0071396306718676f*x*(1.0f - 7.0f*z2) ; // 9*sqrt(10)*x*(1 - 7*z2)/(8*sqrt(pi)) dz[22] = 6.6232228703029197f*z*(x2 - y2) ; // 21*sqrt(5)*z*(x2 - y2)/(4*sqrt(pi)) dz[23] = 1.7701307697799304f*x*(-x2 + 3.0f*y2) ; // 3*sqrt(70)*x*(-x2 + 3*y2)/(8*sqrt(pi)) dz[24] = 0.0f ; // 0 if (C <= 5) { return; } dz[25] = 0.0f ; // 0 dz[26] = 8.3026492595241645f*xy*(x2 - y2) ; // 3*sqrt(385)*xy*(x2 - y2)/(4*sqrt(pi)) dz[27] = 8.8062893898345074f*yz*(-3.0f*x2 + y2) ; // 9*sqrt(770)*yz*(-3*x2 + 
y2)/(16*sqrt(pi)) dz[28] = 4.7935367849733241f*xy*(9.0f*z2 - 1.0f) ; // sqrt(1155)*xy*(9*z2 - 1)/(4*sqrt(pi)) dz[29] = 12.682506233479513f*yz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*yz*(1 - 3*z2)/(4*sqrt(pi)) dz[30] = -24.559567715218954f*z2 + 36.839351572828434f*z4 + 1.754254836801354f ; // 15*sqrt(11)*(-14*z2 + 21*z4 + 1)/(16*sqrt(pi)) dz[31] = 12.682506233479513f*xz*(1.0f - 3.0f*z2) ; // 7*sqrt(165)*xz*(1 - 3*z2)/(4*sqrt(pi)) dz[32] = 2.3967683924866621f*(x2 - y2)*(9.0f*z2 - 1.0f) ; // sqrt(1155)*(x2 - y2)*(9*z2 - 1)/(8*sqrt(pi)) dz[33] = 8.8062893898345074f*xz*(-x2 + 3.0f*y2) ; // 9*sqrt(770)*xz*(-x2 + 3*y2)/(16*sqrt(pi)) dz[34] = -12.453973889286246f*x2*y2 + 2.0756623148810411f*x4 + 2.0756623148810411f*y4 ; // 3*sqrt(385)*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) dz[35] = 0.0f ; // 0 if (C <= 6) { return; } dz[36] = 0.0f ; // 0 dz[37] = 2.3666191622317521f*y*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 3*sqrt(2002)*y*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi)) dz[38] = 44.401711264127719f*xy*z*(x2 - y2) ; // 33*sqrt(91)*xy*z*(x2 - y2)/(4*sqrt(pi)) dz[39] = -2.7636157785447706f*y*(3.0f*x2 - y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*y*(3*x2 - y2)*(11*z2 - 1)/(32*sqrt(pi)) dz[40] = 11.054463114179082f*xy*z*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*xy*z*(11*z2 - 3)/(8*sqrt(pi)) dz[41] = 2.9131068125936568f*y*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*y*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) dz[42] = 2.6699064952403937f*z*(-30.0f*z2 + 33.0f*z4 + 5.0f) ; // 21*sqrt(13)*z*(-30*z2 + 33*z4 + 5)/(16*sqrt(pi)) dz[43] = 2.9131068125936568f*x*(18.0f*z2 - 33.0f*z4 - 1.0f) ; // 5*sqrt(273)*x*(18*z2 - 33*z4 - 1)/(16*sqrt(pi)) dz[44] = 5.5272315570895412f*z*(x2 - y2)*(11.0f*z2 - 3.0f) ; // 3*sqrt(2730)*z*(x2 - y2)*(11*z2 - 3)/(16*sqrt(pi)) dz[45] = -2.7636157785447706f*x*(x2 - 3.0f*y2)*(11.0f*z2 - 1.0f) ; // -3*sqrt(2730)*x*(x2 - 3*y2)*(11*z2 - 1)/(32*sqrt(pi)) dz[46] = 11.10042781603193f*z*(-6.0f*x2*y2 + x4 + y4) ; // 33*sqrt(91)*z*(-6*x2*y2 + x4 + y4)/(16*sqrt(pi)) dz[47] = 2.3666191622317521f*x*(10.0f*x2*y2 - x4 
- 5.0f*y4) ; // 3*sqrt(2002)*x*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
dz[48] = 0.0f ; // 0
if (C <= 7) { return; }
dz[49] = 0.0f ; // 0
dz[50] = 5.2919213236038001f*xy*(-10.0f*x2*y2 + 3.0f*x4 + 3.0f*y4) ; // 3*sqrt(10010)*xy*(-10*x2*y2 + 3*x4 + 3*y4)/(32*sqrt(pi))
dz[51] = 13.491805046726766f*yz*(10.0f*x2*y2 - 5.0f*x4 - y4) ; // 39*sqrt(385)*yz*(10*x2*y2 - 5*x4 - y4)/(32*sqrt(pi))
dz[52] = 12.453973889286248f*xy*(x2 - y2)*(13.0f*z2 - 1.0f) ; // 9*sqrt(385)*xy*(x2 - y2)*(13*z2 - 1)/(8*sqrt(pi))
dz[53] = -6.8841930899409371f*yz*(3.0f*x2 - y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*yz*(3*x2 - y2)*(13*z2 - 3)/(16*sqrt(pi))
dz[54] = 2.2126634622249131f*xy*(-66.0f*z2 + 143.0f*z4 + 3.0f) ; // 15*sqrt(70)*xy*(-66*z2 + 143*z4 + 3)/(32*sqrt(pi))
dz[55] = 1.6259689364853116f*yz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*yz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
dz[56] = 64.528641681844675f*z2 - 236.60501950009714f*z4 + 205.05768356675085f*z6 - 2.3899496919201733f ; // 7*sqrt(15)*(135*z2 - 495*z4 + 429*z6 - 5)/(32*sqrt(pi))
dz[57] = 1.6259689364853116f*xz*(110.0f*z2 - 143.0f*z4 - 15.0f) ; // 9*sqrt(105)*xz*(110*z2 - 143*z4 - 15)/(32*sqrt(pi))
dz[58] = 0.07375544874083044f*(x2 - y2)*(143.0f*z2*(3.0f*z2 - 1.0f) + 132.0f*z2*(13.0f*z2 - 5.0f) - 187.0f*z2 + 45.0f) ; // sqrt(70)*(x2 - y2)*(143*z2*(3*z2 - 1) + 132*z2*(13*z2 - 5) - 187*z2 + 45)/(64*sqrt(pi))
dz[59] = -6.8841930899409371f*xz*(x2 - 3.0f*y2)*(13.0f*z2 - 3.0f) ; // -33*sqrt(35)*xz*(x2 - 3*y2)*(13*z2 - 3)/(16*sqrt(pi))
dz[60] = 3.1134934723215619f*(13.0f*z2 - 1.0f)*(-6.0f*x2*y2 + x4 + y4) ; // 9*sqrt(385)*(13*z2 - 1)*(-6*x2*y2 + x4 + y4)/(32*sqrt(pi))
dz[61] = 13.491805046726766f*xz*(10.0f*x2*y2 - x4 - 5.0f*y4) ; // 39*sqrt(385)*xz*(10*x2*y2 - x4 - 5*y4)/(32*sqrt(pi))
dz[62] = 39.6894099270285f*x2*y4 - 39.6894099270285f*x4*y2 + 2.6459606618019f*x6 - 2.6459606618019f*y6 ; // 3*sqrt(10010)*(15*x2*y4 - 15*x4*y2 + x6 - y6)/(64*sqrt(pi))
dz[63] = 0.0f ; // 0
};
write_sh_dx();
write_sh_dy();
write_sh_dz();
}
}

template <typename scalar_t>
__global__ void kernel_sh_backward(
    const scalar_t * __restrict__ grad,
    const scalar_t * __restrict__ inputs,
    uint32_t B, uint32_t D, uint32_t C,
    const scalar_t * __restrict__ dy_dx,
    scalar_t * grad_inputs
) {
    const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x;
    const uint32_t b = t / D;
    if (b >= B) return;

    const uint32_t d = t - b * D;
    const uint32_t C2 = C * C;

    // locate
    grad += b * C2;
    dy_dx += b * D * C2 + d * C2;

    for (int ch = 0; ch < C2; ch++) {
        grad_inputs[t] += grad[ch] * dy_dx[ch];
        //printf("t=%d, b=%d, d=%d, ch=%d, grad=%f (+= %f * %f)\n", t, b, d, ch, grad_inputs[t], grad[ch], dy_dx[ch]);
    }
}

// inputs: [B, D], float, in [0, 1]
// outputs: [B, L * C], float
template <typename scalar_t>
void sh_encode_forward_cuda(const scalar_t *inputs, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const bool calc_grad_inputs, scalar_t *dy_dx) {
    static constexpr uint32_t N_THREADS = 256;
    kernel_sh<scalar_t><<<div_round_up(B, N_THREADS), N_THREADS>>>(inputs, outputs, B, D, C, calc_grad_inputs, dy_dx);
}

template <typename scalar_t>
void sh_encode_backward_cuda(const scalar_t *grad, const scalar_t *inputs, const uint32_t B, const uint32_t D, const uint32_t C, scalar_t *dy_dx, scalar_t *grad_inputs) {
    static constexpr uint32_t N_THREADS = 256;
    kernel_sh_backward<scalar_t><<<div_round_up(B * D, N_THREADS), N_THREADS>>>(grad, inputs, B, D, C, dy_dx, grad_inputs);
}

void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const bool calc_grad_inputs, at::Tensor dy_dx) {
    CHECK_CUDA(inputs); CHECK_CUDA(outputs); CHECK_CUDA(dy_dx);
    CHECK_CONTIGUOUS(inputs); CHECK_CONTIGUOUS(outputs); CHECK_CONTIGUOUS(dy_dx);
    CHECK_IS_FLOATING(inputs); CHECK_IS_FLOATING(outputs); CHECK_IS_FLOATING(dy_dx);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    inputs.scalar_type(), "sh_encode_forward_cuda", ([&] {
        sh_encode_forward_cuda<scalar_t>(inputs.data_ptr<scalar_t>(), outputs.data_ptr<scalar_t>(), B, D, C, calc_grad_inputs, dy_dx.data_ptr<scalar_t>());
    }));
}

void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs) {
    CHECK_CUDA(grad); CHECK_CUDA(inputs); CHECK_CUDA(dy_dx); CHECK_CUDA(grad_inputs);
    CHECK_CONTIGUOUS(grad); CHECK_CONTIGUOUS(inputs); CHECK_CONTIGUOUS(dy_dx); CHECK_CONTIGUOUS(grad_inputs);
    CHECK_IS_FLOATING(grad); CHECK_IS_FLOATING(inputs); CHECK_IS_FLOATING(dy_dx); CHECK_IS_FLOATING(grad_inputs);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad.scalar_type(), "sh_encode_backward_cuda", ([&] {
        sh_encode_backward_cuda<scalar_t>(grad.data_ptr<scalar_t>(), inputs.data_ptr<scalar_t>(), B, D, C, dy_dx.data_ptr<scalar_t>(), grad_inputs.data_ptr<scalar_t>());
    }));
}

================================================
FILE: shencoder/src/shencoder.h
================================================

# pragma once

#include <stdint.h>
#include <torch/torch.h>

// inputs: [B, D], float, in [-1, 1]
// outputs: [B, F], float
// encode_forward(inputs, outputs, B, input_dim, degree, calc_grad_inputs, dy_dx)
void sh_encode_forward(at::Tensor inputs, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const bool calc_grad_inputs, at::Tensor dy_dx);

// sh_encode_backward(grad, inputs, B, input_dim, degree, ctx.calc_grad_inputs, dy_dx, grad_inputs)
void sh_encode_backward(at::Tensor grad, at::Tensor inputs, const uint32_t B, const uint32_t D, const uint32_t C, at::Tensor dy_dx, at::Tensor grad_inputs);

================================================
FILE: tools/activation.py
================================================

import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd


class _trunc_exp(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # cast to float32
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        # return g * torch.exp(x.clamp(-15, 15))
        return g * torch.exp(x.clamp(-12, 12))


trunc_exp = _trunc_exp.apply

================================================
FILE: tools/details.md
================================================

# custom datasets

Our dataset format is based on the
[torch-ngp](https://github.com/ashawkey/torch-ngp/tree/3b066b6cd6ccd3610cb66a56a54f5daaf12a8033) one, which fully supports custom datasets in COLMAP form. The steps for using a custom dataset are as follows:

1. Take a video / many photos of the scene from different views.
2. Put the video under a path like ./data/custom/video.mp4, or the images under ./data/custom/images/*.jpg.
3. Run the preprocessing code (ffmpeg and COLMAP must be installed first; refer to [colmap2nerf.py](https://github.com/ashawkey/torch-ngp/blob/3b066b6cd6ccd3610cb66a56a54f5daaf12a8033/scripts/colmap2nerf.py) for more options):
   - python colmap2nerf.py --video ./data/custom/video.mp4 --run_colmap # if using a video
   - python colmap2nerf.py --images ./data/custom/images/ --run_colmap # if using images
4. This should create the transform.json, and you can then train. (You will need to experiment with different scale, bound and dt_gamma values so that the object is correctly located in the bounding box and renders fluently.)

Then you can train a teacher and distill students with various structures according to our introduction: https://github.com/megvii-research/AAAI2023-PVD

# some ways to reduce GPU memory

1. Reduce the value of the '--num_rays' parameter.
2. Try disabling the '--preload' parameter. When preload=True, all image data is loaded onto the GPU, which for high-resolution images consumes a lot of memory. (I have not tested this parameter, which is inherited from torch-ngp.)
3. If your requirements on image resolution are not too strict, you can experiment with downsampled images.
4. Distillation loads the student and teacher networks at the same time and therefore consumes more memory. One solution is to run teacher inference separately in advance, record the data required for distillation, and use that data to guide the training of the student.
5. For the different models (INGP/Plenoxels/NeRF/TensoRF), different parameters control model size. For example, you can reduce the number and resolution of hash tables in INGP, reduce the resolution of Plenoxels or TensoRF, or reduce the number of MLP parameters in NeRF.
6. The current code does not support multiple GPUs for now, but that should be easy to add: if the options above do not solve your problem, you can try implementing DDP.

================================================
FILE: tools/encoding.py
================================================

import torch
import torch.nn as nn
import torch.nn.functional as F


class FreqEncoder(nn.Module):
    def __init__(
        self,
        input_dim,
        max_freq_log2,
        N_freqs,
        log_sampling=True,
        include_input=True,
        periodic_fns=(torch.sin, torch.cos),
    ):
        super().__init__()
        self.input_dim = input_dim
        self.include_input = include_input
        self.periodic_fns = periodic_fns

        self.output_dim = 0
        if self.include_input:
            self.output_dim += self.input_dim
        self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns)

        if log_sampling:
            self.freq_bands = 2.0 ** torch.linspace(0.0, max_freq_log2, N_freqs)
        else:
            self.freq_bands = torch.linspace(2.0 ** 0.0, 2.0 ** max_freq_log2, N_freqs)
        self.freq_bands = self.freq_bands.numpy().tolist()

    def forward(self, input, **kwargs):
        out = []
        if self.include_input:
            out.append(input)
        for i in range(len(self.freq_bands)):
            freq = self.freq_bands[i]
            for p_fn in self.periodic_fns:
                out.append(p_fn(input * freq))
        out = torch.cat(out, dim=-1)
        return out


def get_encoder(
    encoding,
    input_dim=3,
    multires=6,
    degree=4,
    num_levels=14,
    level_dim=2,
    base_resolution=16,
    log2_hashmap_size=19,
    desired_resolution=4096,
    align_corners=False,
    **kwargs
):
    if encoding == "None":
        return lambda x, **kwargs: x, input_dim
    elif encoding == "frequency":
        encoder = FreqEncoder(
            input_dim=input_dim,
            max_freq_log2=multires - 1,
            N_freqs=multires,
            log_sampling=True,
        )
    elif encoding == "sphere_harmonics":
        from shencoder import SHEncoder
        encoder = SHEncoder(input_dim=input_dim, degree=degree)
    elif encoding == "hashgrid":
        from gridencoder import GridEncoder
        encoder = GridEncoder(
            input_dim=input_dim,
            num_levels=num_levels,
            level_dim=level_dim,
            base_resolution=base_resolution,
            log2_hashmap_size=log2_hashmap_size,
            desired_resolution=desired_resolution,
            gridtype="hash",
            align_corners=align_corners,
        )
    elif encoding == "tiledgrid":
        from gridencoder import GridEncoder
        encoder = GridEncoder(
            input_dim=input_dim,
            num_levels=num_levels,
            level_dim=level_dim,
            base_resolution=base_resolution,
            log2_hashmap_size=log2_hashmap_size,
            desired_resolution=desired_resolution,
            gridtype="tiled",
            align_corners=align_corners,
        )
    elif encoding == "ash":
        from ashencoder import AshEncoder
        encoder = AshEncoder(
            input_dim=input_dim,
            output_dim=16,
            log2_hashmap_size=log2_hashmap_size,
            resolution=desired_resolution,
        )
    else:
        raise NotImplementedError()

    return encoder, encoder.output_dim

================================================
FILE: tools/install_extensions.sh
================================================

cd raymarching
pip install .
cd ..

cd gridencoder
pip install .
cd ..

cd shencoder
pip install .
cd ..

================================================
FILE: tools/requirements.txt
================================================

torch-ema
ninja
trimesh
opencv-python
tensorboardX
torch
numpy
pandas
tqdm
matplotlib
PyMCubes
rich
pysdf
dearpygui
packaging
scipy
lpips
imageio

================================================
FILE: tools/中文介绍.md
================================================

## One is All: Bridging the Gap Between Neural Radiance Fields Architectures with Progressive Volume Distillation (**Accepted by AAAI 2023**)

# PVD-AL, the follow-up work of PVD, is now open source. It is more powerful than PVD and we recommend using it: [code](https://github.com/megvii-research/AAAI2023-PVD/tree/PVD-AL).
## [Project Page](http://sk-fun.fun/PVD/) | [Paper](https://arxiv.org/abs/2211.15977) | [Datasets](https://drive.google.com/drive/folders/1U06KAEsW53PolLI3U8hWUhzzIH74QGaP?usp=sharing) | [Pretrained Ckpts](https://drive.google.com/drive/folders/1GGJf-FTmpCJjmEn-AF_S9-HrLRkFe5Ud?usp=sharing) | [English README](https://github.com/megvii-research/AAAI2023-PVD/blob/main/README.md) |

## Brief introduction to the paper

- NeRF-style models now come in a dazzling variety: implicit MLP-based NeRF, the purely explicit tensor structure of Plenoxels, the hybrid TensoRF (low-rank tensors + MLP) and INGP (hash tables + MLP), and many other variants.
- Models with different structures really do have different characters. The high-level semantic features of a fully implicit MLP can be used for many tasks, such as [lighting/weather changes](https://nerf-w.github.io/) and [artistic design](https://pfnet-research.github.io/distilled-feature-fields/), while purely explicit tensor models expose a clear spatial structure that makes operations like [cropping/composition/enlarging/shrinking/erasing](https://github.com/ashawkey/CCNeRF) easy. TensoRF and INGP sit in between, with the advantages of fast training and high reconstruction quality. In addition, hardware platforms (e.g. mobile devices) support these structures to very different degrees, so choosing a structure for a downstream task takes some design experience.
- We undertook this work to ease that choice and to transfer characteristics between different structures. Whether conversion between different structures is even possible had not been studied before, and we believe a first attempt is worthwhile.
- Our goal is to transfer the characteristics of one architecture to different ones. For example, INGP's fast convergence quickly yields a model, from which a NeRF can then be trained by distillation, which also accelerates NeRF training and on some datasets even improves quality. The editability of explicit structures can likewise be transferred to non-explicit ones: for example, apply scene composition or scene splitting to the tensor structure of Plenoxels, then distill it into other models so that they too can render the edited scene. Experiments show that the spatial editing ability of explicit structures transfers to other structures successfully and with high quality: [our demos](http://sk-fun.fun/PVD/)
- Why distillation is possible at all offers some insight into the inner workings of these models. For instance, that intermediate features can be aligned between models means that models with different structures can in fact be mapped into nearby spaces.

## Installation

We recommend using [Anaconda](https://www.anaconda.com/) for installation to avoid polluting your local environment. Run the following commands:

*Step1*: create a conda environment named 'pvd'
```
conda create --name pvd python=3.7
conda activate pvd
pip install -r ./tools/requirements.txt
```

*Step2*: install the C++/CUDA extensions (adapted from [torch-ngp](https://github.com/ashawkey/torch-ngp))
```
bash ./tools/install_extensions.sh
```

## Datasets & pretrained models

Synthetic-NeRF/LLFF/Tanks&Temples: [Google Drive](https://drive.google.com/drive/folders/1U06KAEsW53PolLI3U8hWUhzzIH74QGaP?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1ky_TWrbUZG_MpHTBhncAKA?pwd=4h2h). Pretrained models: [Google Drive](https://drive.google.com/drive/folders/1GGJf-FTmpCJjmEn-AF_S9-HrLRkFe5Ud?usp=sharing), [Baidu Drive](https://pan.baidu.com/s/1LGLXwLGusX60GpAywLwosg?pwd=34k8). If you prefer not to download the pretrained models, training a teacher from scratch as described below is also fast.
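The distillation idea described in the introduction above can be sketched as a toy in a few lines of pure Python (illustrative only, not the repo's actual training loop, which lives in main_distill_mutual.py and distill_mutual/): a "student" is fitted to a frozen "teacher" by gradient descent on the squared error between their outputs.

```python
# Toy 1-D "distillation": fit a linear student to a frozen teacher by
# gradient descent on the mean squared error between their outputs.
# PVD applies the same idea between radiance-field architectures,
# matching rendered / intermediate outputs instead of a 1-D function.
teacher = lambda x: 2.0 * x + 1.0        # stand-in for a frozen, pretrained teacher
w, b = 0.0, 0.0                          # student parameters
lr = 0.1
xs = [i / 10.0 for i in range(-10, 11)]  # query points (the "rays")
for _ in range(500):
    # gradients of mean((student(x) - teacher(x))**2) w.r.t. w and b
    gw = sum(2 * (w * x + b - teacher(x)) * x for x in xs) / len(xs)
    gb = sum(2 * (w * x + b - teacher(x)) for x in xs) / len(xs)
    w, b = w - lr * gw, b - lr * gb
# the student converges toward the teacher: w -> 2.0, b -> 1.0
```

The real pipeline distills between heterogeneous architectures (MLP / tensors / low-rank tensors / hash), but the supervision signal is the same in spirit: the teacher's outputs replace ground-truth images.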
## Train a teacher
```
# train a hash-based (INGP) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type hash --data_type synthetic --workspace ./log/train_teacher/hash_chair

# train a sparse-tensor-based (TensoRF VM-decomposition) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type vm --data_type synthetic --workspace ./log/train_teacher/vm_chair

# train a MLP-based (NeRF) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type mlp --data_type synthetic --workspace ./log/train_teacher/mlp_chair

# train a tensors-based (Plenoxels) teacher
python main_just_train_tea.py ./data/nerf_synthetic/chair --model_type tensors --data_type synthetic --workspace ./log/train_teacher/tensors_chair
```

## Distill students
```
# teacher: hash (INGP), student: vm (TensoRF)
python3 main_distill_mutual.py ./data/nerf_synthetic/chair \
    --data_type synthetic \
    --teacher_type hash \
    --ckpt_teacher ./log/train_teacher/hash_chair/checkpoints/XXX.pth \
    --model_type vm \
    --workspace ./log/distill_student/hash2vm/chair

# teacher: MLP (NeRF), student: tensors (Plenoxels)
python3 main_distill_mutual.py ./data/nerf_synthetic/chair \
    --data_type synthetic \
    --teacher_type mlp \
    --ckpt_teacher ./log/train_teacher/mlp_chair/checkpoints/XXX.pth \
    --model_type tensors \
    --workspace ./log/distill_student/mlp2tensors/chair
```

## Evaluation
```
# evaluate a hash teacher
python main_distill_mutual.py ./data/nerf_synthetic/chair --teacher_type hash --ckpt_teacher PATH/TO/CKPT.pth --test_teacher --data_type synthetic --workspace ./log/eval_teacher/hash_chair

# evaluate a mlp student
python main_distill_mutual.py ./data/nerf_synthetic/chair --model_type mlp --ckpt PATH/TO/CKPT.pth --test --data_type synthetic --workspace ./log/eval_student/mlp_chair
```

## For usage on more datasets and more running commands, see the [more running description](https://github.com/megvii-research/AAAI2023-PVD/blob/main/tools/details.md)

## Citation
If you find this work useful, please consider citing our paper:
```
@article{pvd2023,
  author  = {Fang, Shuangkang and Xu, Weixin and Wang, Heng and Yang, Yi and Wang, Yufeng and Zhou, Shuchang},
  title   = {One is All: Bridging the Gap Between Neural Radiance Fields Architectures with Progressive Volume Distillation},
  journal = {AAAI},
  year    = {2023}
}
```

### Acknowledgements
Thanks to [ingp](https://github.com/NVlabs/instant-ngp), [torch-ngp](https://github.com/ashawkey/torch-ngp), [TensoRF](https://github.com/apchenstu/TensoRF), [Plenoxels](https://github.com/sxyu/svox2) and [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch) for their beautiful frameworks! See also [Arch-Net](https://github.com/megvii-research/Arch-Net) for more ideas on progressive distillation.
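As a closing aside on the shencoder CUDA kernel earlier in this dump: the hard-coded floats there are the standard real spherical-harmonics coefficients. A minimal pure-Python rendering of just the first two degrees (my own illustrative snippet, not part of the repo) shows where, e.g., 0.28209479f and 0.48860251f come from:

```python
import math

def sh_degree2(x, y, z):
    # Real SH basis up to degree 2 (4 coefficients), matching the
    # constants hard-coded for outputs[0..3] in kernel_sh:
    #   Y_0^0, Y_1^{-1}, Y_1^0, Y_1^1 up to the kernel's sign convention.
    c0 = 0.5 * math.sqrt(1.0 / math.pi)  # sqrt(1/pi)/2  = 0.28209479...
    c1 = 0.5 * math.sqrt(3.0 / math.pi)  # sqrt(3/pi)/2  = 0.48860251...
    return [c0, -c1 * y, c1 * z, -c1 * x]
```

The higher-degree constants in the kernel arise the same way, as the closed-form prefactors `sqrt(n)*.../(m*sqrt(pi))` recorded in the kernel's trailing comments.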