Repository: Kunhao-Liu/StyleRF
Branch: main
Commit: 43611c3ec7ba
Files: 36
Total size: 205.9 KB
Directory structure:
gitextract_j81cq_2h/
├── README.md
├── configs/
│ ├── llff.txt
│ ├── llff_feature.txt
│ ├── llff_style.txt
│ ├── nerf_synthetic.txt
│ ├── nerf_synthetic_feature.txt
│ └── nerf_synthetic_style.txt
├── dataLoader/
│ ├── __init__.py
│ ├── blender.py
│ ├── colmap2nerf.py
│ ├── llff.py
│ ├── nsvf.py
│ ├── ray_utils.py
│ ├── styleLoader.py
│ ├── tankstemple.py
│ └── your_own_data.py
├── extra/
│ ├── auto_run_paramsets.py
│ └── compute_metrics.py
├── models/
│ ├── VGG.py
│ ├── __init__.py
│ ├── sh.py
│ ├── styleModules.py
│ ├── tensoRF.py
│ └── tensorBase.py
├── opt.py
├── renderer.py
├── scripts/
│ ├── test.sh
│ ├── test_feature.sh
│ ├── test_style.sh
│ ├── train.sh
│ ├── train_feature.sh
│ └── train_style.sh
├── train.py
├── train_feature.py
├── train_style.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# [*CVPR 2023*] StyleRF: Zero-shot 3D Style Transfer of Neural Radiance Fields
## [Project page](https://kunhao-liu.github.io/StyleRF/) | [Paper](https://arxiv.org/abs/2303.10598)
This repository contains a PyTorch implementation of the paper [StyleRF: Zero-shot 3D Style Transfer of Neural Radiance Fields](https://arxiv.org/abs/2303.10598). StyleRF is an innovative 3D style transfer technique that achieves superior 3D stylization quality with precise geometry reconstruction, and it can generalize to various new styles in a zero-shot manner.

---
## Installation
> Tested on Ubuntu 20.04 + PyTorch 1.12.1

Install the environment:
```
conda create -n StyleRF python=3.9
conda activate StyleRF
pip install torch torchvision
pip install tqdm scikit-image opencv-python configargparse lpips imageio-ffmpeg kornia tensorboard
```
## Datasets
Please put the datasets in `./data`. You can put the datasets elsewhere if you modify the corresponding paths in the configs.
### 3D scene datasets
* [nerf_synthetic](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
* [llff](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
### Style image dataset
* [WikiArt](https://www.kaggle.com/datasets/ipythonx/wikiart-gangogh-creating-art-gan)
## Quick Start
We provide some trained checkpoints at: [StyleRF checkpoints](https://drive.google.com/drive/folders/1nF9-6lTIhktG5JjNvnmdYOo1LTvtK7Dw?usp=share_link)

Download them, then modify the following attributes in `scripts/test_style.sh`:

* `--config`: choose `configs/llff_style.txt` or `configs/nerf_synthetic_style.txt` depending on the type of dataset being used
* `--datadir`: the dataset's path
* `--ckpt`: the checkpoint's path
* `--style_img`: the reference style image's path

To generate stylized novel views:
```
bash scripts/test_style.sh [GPU ID]
```
The rendered stylized images can then be found in the directory under the checkpoint's path.
## Training
> The current settings in `configs` were tested on one NVIDIA RTX A5000 GPU with 24 GB of memory. To reduce memory consumption, you can set `batch_size`, `chunk_size` or `patch_size` to a smaller value.

Training consists of the following 3 steps:
### 1. Train original TensoRF
This step reconstructs the density field, which contains more precise geometry details than mesh-based methods. You can skip this step by directly downloading the pre-trained checkpoints provided by [TensoRF checkpoints](https://1drv.ms/u/s!Ard0t_p4QWIMgQ2qSEAs7MUk8hVw?e=dc6hBm).

The configs are stored in `configs/llff.txt` and `configs/nerf_synthetic.txt`. For details of the settings, please also refer to [TensoRF](https://github.com/apchenstu/TensoRF). The checkpoints are stored in `./log` by default.

You can train the original TensoRF with:
```
bash scripts/train.sh [GPU ID]
```
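The step-1 configs grow the grid during training through `N_voxel_init`, `N_voxel_final` and `upsamp_list`. The sketch below illustrates one way such settings can be turned into a voxel budget per upsampling iteration and a per-axis grid resolution. It assumes the convention of the upstream TensoRF code (voxel counts spaced evenly in log space between the initial and final values, and a cubic voxel size derived from the scene bounding box). `voxel_schedule` and `n_to_reso` are hypothetical names, not functions of this repository, and `n_to_reso` rounds to the nearest integer where the upstream code may truncate.

```python
import math


def voxel_schedule(n_voxel_init, n_voxel_final, upsamp_list):
    """Log-spaced voxel counts, one per entry of upsamp_list (assumed TensoRF-style)."""
    steps = len(upsamp_list) + 1
    lo, hi = math.log(n_voxel_init), math.log(n_voxel_final)
    counts = [round(math.exp(lo + (hi - lo) * i / (steps - 1))) for i in range(steps)]
    return counts[1:]  # the first entry is the initial grid, which already exists


def n_to_reso(n_voxels, bbox_min, bbox_max):
    """Per-axis resolution whose total voxel count is close to n_voxels."""
    extent = [hi - lo for lo, hi in zip(bbox_min, bbox_max)]
    voxel_size = (math.prod(extent) / n_voxels) ** (1.0 / len(extent))  # cubic voxels
    return [round(e / voxel_size) for e in extent]


# Values from configs/nerf_synthetic.txt and the Blender scene_bbox in dataLoader/blender.py.
upsamp_list = [2000, 3000, 4000, 5500, 7000]
schedule = voxel_schedule(2097156, 27000000, upsamp_list)
for iteration, n in zip(upsamp_list, schedule):
    print(iteration, n, n_to_reso(n, [-1.5] * 3, [1.5] * 3))
```

Under these assumptions the Blender grid reaches 300 cells per axis at iteration 7000, consistent with the `# 300**3` comment in `configs/nerf_synthetic.txt`.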
### 2. Feature grid training stage
This step reconstructs the 3D grid containing the VGG features.

The configs are stored in `configs/llff_feature.txt` and `configs/nerf_synthetic_feature.txt`, in which `ckpt` specifies the checkpoint trained in the **first** step. The checkpoints are stored in `./log_feature` by default.

Then run:
```
bash scripts/train_feature.sh [GPU ID]
```
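Each stage's `ckpt` must point at a checkpoint written by the previous stage. Judging from the shipped configs (step 1 uses `basedir = ./log` and `expname = trex`, while step 2 uses `ckpt = ./log/trex/trex.th`), checkpoints appear to be written to `{basedir}/{expname}/{expname}.th`. The sketch below parses the simple `key = value # comment` config format and checks that inferred convention before launching the next stage. `read_config` and `expected_ckpt` are hypothetical helpers, not part of the repository, and the path convention is an assumption drawn from the configs rather than from the training code.

```python
import posixpath


def read_config(text):
    """Parse `key = value  # comment` lines into a dict of strings (hypothetical helper)."""
    cfg = {}
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()  # drop trailing comments such as "# 128**3"
        if not line or "=" not in line:
            continue
        key, value = (part.strip() for part in line.split("=", 1))
        cfg[key] = value
    return cfg


def expected_ckpt(cfg):
    """Checkpoint path a stage is assumed to write: {basedir}/{expname}/{expname}.th."""
    return posixpath.join(cfg["basedir"], cfg["expname"], cfg["expname"] + ".th")


stage1 = read_config("expname = trex\nbasedir = ./log\nN_voxel_init = 2097156 # 128**3")
stage2 = read_config("ckpt = ./log/trex/trex.th\nexpname = trex\nbasedir = ./log_feature")
print(stage2["ckpt"] == expected_ckpt(stage1))
```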
### 3. Stylization training stage
This step trains the style transfer modules.

The configs are stored in `configs/llff_style.txt` and `configs/nerf_synthetic_style.txt`, in which `ckpt` specifies the checkpoint trained in the **second** step. The checkpoints are stored in `./log_style` by default.

Then run:
```
bash scripts/train_style.sh [GPU ID]
```
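The stage-3 configs weight two objectives with `content_weight` and `style_weight` (1 and 20 in both shipped style configs). The sketch below shows one common way such a weighted objective is formed in zero-shot style transfer: a mean-squared content term plus a style term that matches channel-wise means and standard deviations of feature maps, as in AdaIN-style methods. This is an illustrative assumption on NumPy arrays, not StyleRF's exact loss, which presumably lives in the training code (e.g. `train_style.py`); the random arrays merely stand in for VGG feature maps.

```python
import numpy as np


def mean_std(feat, eps=1e-5):
    """Channel-wise mean and standard deviation of a (C, H, W) feature map."""
    flat = feat.reshape(feat.shape[0], -1)
    return flat.mean(axis=1), np.sqrt(flat.var(axis=1) + eps)


def content_loss(pred, target):
    return float(np.mean((pred - target) ** 2))


def style_loss(pred, style):
    """Match first- and second-order channel statistics (AdaIN-style assumption)."""
    mu_p, sd_p = mean_std(pred)
    mu_s, sd_s = mean_std(style)
    return float(np.mean((mu_p - mu_s) ** 2) + np.mean((sd_p - sd_s) ** 2))


def total_loss(pred, content, style, content_weight=1.0, style_weight=20.0):
    # Defaults mirror content_weight / style_weight in configs/*_style.txt.
    return content_weight * content_loss(pred, content) + style_weight * style_loss(pred, style)


rng = np.random.default_rng(0)
content = rng.normal(size=(256, 32, 32))  # stand-in for content features
style = rng.normal(loc=0.5, scale=2.0, size=(256, 32, 32))  # stand-in for style features
print(total_loss(content, content, style))  # only the style term contributes here
```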
---
## Training on 360 Unbounded Scenes
The code for training StyleRF on the Tanks&Temples dataset is available on the `360` branch. To access it, run `git checkout 360`.
## Acknowledgments
This repo is heavily based on [TensoRF](https://github.com/apchenstu/TensoRF). We thank the authors for sharing their amazing work!
## Citation
If you find our code or paper helpful, please consider citing:
```
@inproceedings{liu2023stylerf,
title={StyleRF: Zero-shot 3D Style Transfer of Neural Radiance Fields},
author={Liu, Kunhao and Zhan, Fangneng and Chen, Yiwen and Zhang, Jiahui and Yu, Yingchen and El Saddik, Abdulmotaleb and Lu, Shijian and Xing, Eric P},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={8338--8348},
year={2023}
}
```
================================================
FILE: configs/llff.txt
================================================
dataset_name = llff
datadir = ./data/nerf_llff_data/trex
expname = trex
basedir = ./log
downsample_train = 4.0
ndc_ray = 1
n_iters = 25000
batch_size = 4096
N_voxel_init = 2097156 # 128**3
N_voxel_final = 262144000 # 640**3
upsamp_list = [2000,3000,4000,5500]
update_AlphaMask_list = [2500]
N_vis = -1 # vis all testing images
vis_every = 10000
render_test = 1
render_path = 1
n_lamb_sigma = [16,4,4]
n_lamb_sh = [48,12,12]
shadingMode = MLP_Fea
fea2denseAct = relu
view_pe = 0
fea_pe = 0
TV_weight_density = 1.0
TV_weight_app = 1.0
================================================
FILE: configs/llff_feature.txt
================================================
dataset_name = llff
datadir = ./data/nerf_llff_data/trex
ckpt = ./log/trex/trex.th
expname = trex
basedir = ./log_feature
TV_weight_feature = 80
downsample_train = 4.0
ndc_ray = 1
n_iters = 25000
patch_size = 256
batch_size = 4096
chunk_size = 4096
N_voxel_init = 2097156 # 128**3
N_voxel_final = 262144000 # 640**3
upsamp_list = [2000,3000,4000,5500]
update_AlphaMask_list = [2500]
n_lamb_sigma = [16,4,4]
n_lamb_sh = [48,12,12]
fea2denseAct = relu
================================================
FILE: configs/llff_style.txt
================================================
dataset_name = llff
datadir = ./data/nerf_llff_data/trex
ckpt = ./log_feature/trex/trex.th
expname = trex
basedir = ./log_style
nSamples = 300
patch_size = 256
chunk_size = 2048
content_weight = 1
style_weight = 20
featuremap_tv_weight = 0
image_tv_weight = 0
rm_weight_mask_thre = 0.001
downsample_train = 4.0
ndc_ray = 1
n_iters = 25000
n_lamb_sigma = [16,4,4]
n_lamb_sh = [48,12,12]
N_voxel_init = 2097156 # 128**3
N_voxel_final = 262144000 # 640**3
fea2denseAct = relu
================================================
FILE: configs/nerf_synthetic.txt
================================================
dataset_name = blender
datadir = ./data/nerf_synthetic/lego
expname = lego
basedir = ./log
n_iters = 30000
batch_size = 4096
N_voxel_init = 2097156 # 128**3
N_voxel_final = 27000000 # 300**3
upsamp_list = [2000,3000,4000,5500,7000]
update_AlphaMask_list = [2000,4000]
N_vis = 5
vis_every = 10000
render_test = 1
n_lamb_sigma = [16,16,16]
n_lamb_sh = [48,48,48]
model_name = TensorVMSplit
shadingMode = MLP_Fea
fea2denseAct = softplus
view_pe = 2
fea_pe = 2
L1_weight_inital = 8e-5
L1_weight_rest = 4e-5
rm_weight_mask_thre = 1e-4
================================================
FILE: configs/nerf_synthetic_feature.txt
================================================
dataset_name = blender
datadir = ./data/nerf_synthetic/lego
ckpt = ./log/lego/lego.th
expname = lego
basedir = ./log_feature
TV_weight_feature = 10
n_iters = 25000
patch_size = 256
batch_size = 4096
chunk_size = 4096
N_voxel_init = 2097156 # 128**3
N_voxel_final = 27000000 # 300**3
upsamp_list = [2000,3000,4000,5500,7000]
update_AlphaMask_list = [2000,4000]
rm_weight_mask_thre = 0.01
n_lamb_sigma = [16,16,16]
n_lamb_sh = [48,48,48]
fea2denseAct = softplus
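`TV_weight_feature` (10 here, 80 in `configs/llff_feature.txt`) presumably scales a total-variation (TV) regularizer on the feature grid, which penalizes differences between neighbouring cells and so favours smooth features. A minimal NumPy sketch of a squared-difference TV term on one 2D grid plane, assuming that form; the upstream TensoRF `TVLoss` may normalize differently, so treat this as an illustration of the idea rather than the code used here. `tv_loss` is a hypothetical name.

```python
import numpy as np


def tv_loss(plane):
    """Squared-difference total variation of a (C, H, W) grid plane (illustrative)."""
    dh = plane[:, 1:, :] - plane[:, :-1, :]  # differences between neighbouring rows
    dw = plane[:, :, 1:] - plane[:, :, :-1]  # differences between neighbouring columns
    return float((dh ** 2).mean() + (dw ** 2).mean())


TV_weight_feature = 10  # value from configs/nerf_synthetic_feature.txt
flat = np.ones((4, 8, 8))
ramp = np.tile(np.arange(8.0), (1, 8, 1))  # shape (1, 8, 8): every row is 0..7
noisy = np.random.default_rng(0).normal(size=(4, 8, 8))
for name, plane in [("flat", flat), ("ramp", ramp), ("noisy", noisy)]:
    print(name, TV_weight_feature * tv_loss(plane))
```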
================================================
FILE: configs/nerf_synthetic_style.txt
================================================
dataset_name = blender
datadir = ./data/nerf_synthetic/lego
ckpt = ./log_feature/lego/lego.th
expname = lego
basedir = ./log_style
patch_size = 256
chunk_size = 2048
content_weight = 1
style_weight = 20
rm_weight_mask_thre = 0.01
n_iters = 25000
n_lamb_sigma = [16,16,16]
n_lamb_sh = [48,48,48]
N_voxel_init = 2097156 # 128**3
N_voxel_final = 27000000 # 300**3
fea2denseAct = softplus
================================================
FILE: dataLoader/__init__.py
================================================
from .llff import LLFFDataset
from .blender import BlenderDataset
from .nsvf import NSVF
from .tankstemple import TanksTempleDataset
from .your_own_data import YourOwnDataset
dataset_dict = {'blender': BlenderDataset,
'llff':LLFFDataset,
'tankstemple':TanksTempleDataset,
'nsvf':NSVF,
'own_data':YourOwnDataset}
================================================
FILE: dataLoader/blender.py
================================================
import torch,cv2
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T
from .ray_utils import *
class BlenderDataset(Dataset):
    def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):
        self.N_vis = N_vis
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        # set before read_meta() so that images are actually resized when downsample != 1
        self.downsample = downsample
        self.img_wh = (int(800/downsample),int(800/downsample))
        self.define_transforms()
        self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()
        self.white_bg = True
        self.near_far = [2.0,6.0]
        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def read_depth(self, filename):
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)
        return depth

    def read_meta(self):
        with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
            self.meta = json.load(f)
        w, h = self.img_wh
        self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x'])  # original focal length
        self.focal *= self.img_wh[0] / 800  # modify focal length to match size self.img_wh
        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, [self.focal,self.focal])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        self.intrinsics = torch.tensor([[self.focal,0,w/2],[0,self.focal,h/2],[0,0,1]]).float()
        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []
        self.all_depth = []
        img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
        idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):
            frame = self.meta['frames'][i]
            pose = np.array(frame['transform_matrix']) @ self.blender2opencv
            c2w = torch.FloatTensor(pose)
            self.poses += [c2w]
            image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
            self.image_paths += [image_path]
            img = Image.open(image_path)
            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(4, -1).permute(1, 0)  # (h*w, 4) RGBA
            self.all_masks.append(img[:, -1:].reshape(h,w,1))  # (h, w, 1) A
            img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
        self.all_masks = torch.stack(self.all_masks)  # (n_frames, h, w, 1)
        self.poses = torch.stack(self.poses)
        all_rays = self.all_rays
        all_rgbs = self.all_rgbs
        self.all_rays = torch.cat(self.all_rays, 0)  # (n_frames*h*w, 6)
        self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (n_frames*h*w, 3)
        if self.is_stack:
            self.all_rays_stack = torch.stack(all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (n_frames, h, w, 6)
            avg_pool = torch.nn.AvgPool2d(4, ceil_mode=True)
            self.ds_all_rays_stack = avg_pool(self.all_rays_stack.permute(0,3,1,2)).permute(0,2,3,1)  # (n_frames, h/4, w/4, 6)
            self.all_rgbs_stack = torch.stack(all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (n_frames, h, w, 3)

    @torch.no_grad()
    def prepare_feature_data(self, encoder, chunk=8):
        '''
        Prepare feature maps as training data.
        '''
        assert self.is_stack, 'Dataset should contain original stacked training data!'
        print('====> prepare_feature_data ...')
        frames_num, h, w, _ = self.all_rgbs_stack.size()
        features = []
        for chunk_idx in range(frames_num // chunk + int(frames_num % chunk > 0)):
            rgbs_chunk = self.all_rgbs_stack[chunk_idx*chunk : (chunk_idx+1)*chunk].cuda()
            features_chunk = encoder(normalize_vgg(rgbs_chunk.permute(0,3,1,2))).relu3_1
            # resize to the size of rgb map so that rays can match
            features_chunk = T.functional.resize(features_chunk, size=(h,w),
                                                 interpolation=T.InterpolationMode.BILINEAR)
            features.append(features_chunk.detach().cpu().requires_grad_(False))
        self.all_features_stack = torch.cat(features).permute(0,2,3,1)  # (n_frames, h, w, 256)
        self.all_features = self.all_features_stack.reshape(-1, 256)
        print('prepare_feature_data Done!')

    def define_transforms(self):
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]

    def world2ndc(self,points,lindisp=None):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}
        else:  # create data for each image separately
            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]
            mask = self.all_masks[idx]  # for quantity evaluation
            sample = {'rays': rays,
                      'rgbs': img,
                      'mask': mask}
        return sample
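`read_meta` derives the pinhole focal length from the horizontal field of view `camera_angle_x` and composites each RGBA frame onto a white background (`white_bg = True`). The NumPy sketch below isolates those two formulas so they can be checked on toy values; `focal_from_fov` and `composite_on_white` are hypothetical helper names, not part of the repository.

```python
import numpy as np


def focal_from_fov(camera_angle_x, width):
    # Pinhole model: half the image width subtends half the horizontal field of view.
    return 0.5 * width / np.tan(0.5 * camera_angle_x)


def composite_on_white(rgba):
    """Blend an (N, 4) RGBA array with values in [0, 1] onto a white background."""
    rgb, alpha = rgba[:, :3], rgba[:, 3:]
    return rgb * alpha + (1.0 - alpha)


# A 90-degree FOV on an 800-px-wide image gives a focal length of about 400 px.
print(focal_from_fov(np.pi / 2, 800))
# Downsampling scales the focal length by the same factor (cf. focal *= img_wh[0] / 800).
print(focal_from_fov(np.pi / 2, 800) * 400 / 800)

pixels = np.array([[1.0, 0.0, 0.0, 1.0],   # opaque red stays red
                   [1.0, 0.0, 0.0, 0.0],   # fully transparent pixel becomes white
                   [0.0, 0.0, 0.0, 0.5]])  # half-transparent black becomes mid grey
print(composite_on_white(pixels))
```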
================================================
FILE: dataLoader/colmap2nerf.py
================================================
#!/usr/bin/env python3
# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import argparse
import os
from pathlib import Path, PurePosixPath
import numpy as np
import json
import sys
import math
import cv2
import os
import shutil
def parse_args():
    parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place")
    parser.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also")
    parser.add_argument("--video_fps", default=2)
    parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video")
    parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder")
    parser.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images")
    parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename")
    parser.add_argument("--images", default="images", help="input path to the images")
    parser.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)")
    parser.add_argument("--aabb_scale", default=16, choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16")
    parser.add_argument("--skip_early", default=0, help="skip this many images from the start")
    parser.add_argument("--out", default="transforms.json", help="output path")
    args = parser.parse_args()
    return args


def do_system(arg):
    print(f"==== running: {arg}")
    err = os.system(arg)
    if err:
        print("FATAL: command failed")
        sys.exit(err)


def run_ffmpeg(args):
    if not os.path.isabs(args.images):
        args.images = os.path.join(os.path.dirname(args.video_in), args.images)
    images = args.images
    video = args.video_in
    fps = float(args.video_fps) or 1.0
    print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.")
    if (input(f"warning! folder '{images}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
        sys.exit(1)
    try:
        shutil.rmtree(images)
    except:
        pass
    do_system(f"mkdir {images}")
    time_slice_value = ""
    time_slice = args.time_slice
    if time_slice:
        start, end = time_slice.split(",")
        # "\\," emits a literal "\," (escaped comma for the ffmpeg filter) without an invalid-escape warning
        time_slice_value = f",select='between(t\\,{start}\\,{end})'"
    do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg")


def run_colmap(args):
    db=args.colmap_db
    images=args.images
    db_noext=str(Path(db).with_suffix(""))
    if args.text=="text":
        args.text=db_noext+"_text"
    text=args.text
    sparse=db_noext+"_sparse"
    print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}")
    if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
        sys.exit(1)
    if os.path.exists(db):
        os.remove(db)
    do_system(f"colmap feature_extractor --ImageReader.camera_model OPENCV --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}")
    do_system(f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}")
    try:
        shutil.rmtree(sparse)
    except:
        pass
    do_system(f"mkdir {sparse}")
    do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}")
    do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1")
    try:
        shutil.rmtree(text)
    except:
        pass
    do_system(f"mkdir {text}")
    do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT")


def variance_of_laplacian(image):
    return cv2.Laplacian(image, cv2.CV_64F).var()


def sharpness(imagePath):
    image = cv2.imread(imagePath)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    fm = variance_of_laplacian(gray)
    return fm


def qvec2rotmat(qvec):
    return np.array([
        [
            1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
            2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
            2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]
        ], [
            2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
            1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
            2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]
        ], [
            2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
            2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
            1 - 2 * qvec[1]**2 - 2 * qvec[2]**2
        ]
    ])


def rotmat(a, b):
    a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)
    v = np.cross(a, b)
    c = np.dot(a, b)
    s = np.linalg.norm(v)
    kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
    return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10))


def closest_point_2_lines(oa, da, ob, db):  # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel
    da = da / np.linalg.norm(da)
    db = db / np.linalg.norm(db)
    c = np.cross(da, db)
    denom = np.linalg.norm(c)**2
    t = ob - oa
    ta = np.linalg.det([t, db, c]) / (denom + 1e-10)
    tb = np.linalg.det([t, da, c]) / (denom + 1e-10)
    if ta > 0:
        ta = 0
    if tb > 0:
        tb = 0
    return (oa+ta*da+ob+tb*db) * 0.5, denom


if __name__ == "__main__":
    args = parse_args()
    if args.video_in != "":
        run_ffmpeg(args)
    if args.run_colmap:
        run_colmap(args)
    AABB_SCALE = int(args.aabb_scale)
    SKIP_EARLY = int(args.skip_early)
    IMAGE_FOLDER = args.images
    TEXT_FOLDER = args.text
    OUT_PATH = args.out
    print(f"outputting to {OUT_PATH}...")
    with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f:
        angle_x = math.pi / 2
        for line in f:
            # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691
            # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224
            # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443
            if line[0] == "#":
                continue
            els = line.split(" ")
            w = float(els[2])
            h = float(els[3])
            fl_x = float(els[4])
            fl_y = float(els[4])
            k1 = 0
            k2 = 0
            p1 = 0
            p2 = 0
            cx = w / 2
            cy = h / 2
            if els[1] == "SIMPLE_PINHOLE":
                cx = float(els[5])
                cy = float(els[6])
            elif els[1] == "PINHOLE":
                fl_y = float(els[5])
                cx = float(els[6])
                cy = float(els[7])
            elif els[1] == "SIMPLE_RADIAL":
                cx = float(els[5])
                cy = float(els[6])
                k1 = float(els[7])
            elif els[1] == "RADIAL":
                cx = float(els[5])
                cy = float(els[6])
                k1 = float(els[7])
                k2 = float(els[8])
            elif els[1] == "OPENCV":
                fl_y = float(els[5])
                cx = float(els[6])
                cy = float(els[7])
                k1 = float(els[8])
                k2 = float(els[9])
                p1 = float(els[10])
                p2 = float(els[11])
            else:
                print("unknown camera model ", els[1])
            # fl = 0.5 * w / tan(0.5 * angle_x);
            angle_x = math.atan(w / (fl_x * 2)) * 2
            angle_y = math.atan(h / (fl_y * 2)) * 2
            fovx = angle_x * 180 / math.pi
            fovy = angle_y * 180 / math.pi
    print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ")
    with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f:
        i = 0
        bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4])
        out = {
            "camera_angle_x": angle_x,
            "camera_angle_y": angle_y,
            "fl_x": fl_x,
            "fl_y": fl_y,
            "k1": k1,
            "k2": k2,
            "p1": p1,
            "p2": p2,
            "cx": cx,
            "cy": cy,
            "w": w,
            "h": h,
            "aabb_scale": AABB_SCALE,
            "frames": [],
        }
        up = np.zeros(3)
        for line in f:
            line = line.strip()
            if line[0] == "#":
                continue
            i = i + 1
            if i < SKIP_EARLY*2:
                continue
            if i % 2 == 1:
                elems=line.split(" ")  # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces)
                #name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9])))
                # why is this requiring a relative path while using ^
                image_rel = os.path.relpath(IMAGE_FOLDER)
                name = str(f"./{image_rel}/{'_'.join(elems[9:])}")
                b=sharpness(name)
                print(name, "sharpness=",b)
                image_id = int(elems[0])
                qvec = np.array(tuple(map(float, elems[1:5])))
                tvec = np.array(tuple(map(float, elems[5:8])))
                R = qvec2rotmat(-qvec)
                t = tvec.reshape([3,1])
                m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
                c2w = np.linalg.inv(m)
                c2w[0:3,2] *= -1  # flip the y and z axis
                c2w[0:3,1] *= -1
                c2w = c2w[[1,0,2,3],:]  # swap y and z
                c2w[2,:] *= -1  # flip whole world upside down
                up += c2w[0:3,1]
                frame={"file_path":name,"sharpness":b,"transform_matrix": c2w}
                out["frames"].append(frame)
    nframes = len(out["frames"])
    up = up / np.linalg.norm(up)
    print("up vector was", up)
    R = rotmat(up,[0,0,1])  # rotate up vector to [0,0,1]
    R = np.pad(R,[0,1])
    R[-1, -1] = 1
    for f in out["frames"]:
        f["transform_matrix"] = np.matmul(R, f["transform_matrix"])  # rotate up to be the z axis
    # find a central point they are all looking at
    print("computing center of attention...")
    totw = 0.0
    totp = np.array([0.0, 0.0, 0.0])
    for f in out["frames"]:
        mf = f["transform_matrix"][0:3,:]
        for g in out["frames"]:
            mg = g["transform_matrix"][0:3,:]
            p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2])
            if w > 0.01:
                totp += p*w
                totw += w
    totp /= totw
    print(totp)  # the cameras are looking at totp
    for f in out["frames"]:
        f["transform_matrix"][0:3,3] -= totp
    avglen = 0.
    for f in out["frames"]:
        avglen += np.linalg.norm(f["transform_matrix"][0:3,3])
    avglen /= nframes
    print("avg camera distance from origin", avglen)
    for f in out["frames"]:
        f["transform_matrix"][0:3,3] *= 4.0 / avglen  # scale to "nerf sized"
    for f in out["frames"]:
        f["transform_matrix"] = f["transform_matrix"].tolist()
    print(nframes,"frames")
    print(f"writing {OUT_PATH}")
    with open(OUT_PATH, "w") as outfile:
        json.dump(out, outfile, indent=2)
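The script relies on two geometric helpers: `qvec2rotmat` turns a COLMAP quaternion `(w, x, y, z)` into a rotation matrix, and `closest_point_2_lines` estimates where two camera axes nearly meet; those points are then weight-averaged into the "center of attention". The self-contained check below restates both formulas from the file above (with the variable names spelled out) and verifies them on simple inputs. It is a sanity test, not part of the conversion pipeline. Two properties worth noting: negating the quaternion, as the script does with `qvec2rotmat(-qvec)`, yields the same matrix because every term is quadratic in the components; and both line parameters are clamped to `t <= 0`, presumably because column 2 of a NeRF-style camera-to-world matrix points backwards, so the scene lies at negative `t`.

```python
import numpy as np


def qvec2rotmat(qvec):
    # COLMAP quaternion (w, x, y, z) -> 3x3 rotation matrix (same formula as above).
    w, x, y, z = qvec
    return np.array([
        [1 - 2 * y**2 - 2 * z**2, 2 * x * y - 2 * w * z, 2 * z * x + 2 * w * y],
        [2 * x * y + 2 * w * z, 1 - 2 * x**2 - 2 * z**2, 2 * y * z - 2 * w * x],
        [2 * z * x - 2 * w * y, 2 * y * z + 2 * w * x, 1 - 2 * x**2 - 2 * y**2],
    ])


def closest_point_2_lines(oa, da, ob, db):
    # Midpoint of the closest points on the lines o + t*d, plus a weight that
    # goes to 0 as the lines become parallel (same formula as above).
    da = da / np.linalg.norm(da)
    db = db / np.linalg.norm(db)
    c = np.cross(da, db)
    denom = np.linalg.norm(c) ** 2
    t = ob - oa
    ta = np.linalg.det([t, db, c]) / (denom + 1e-10)
    tb = np.linalg.det([t, da, c]) / (denom + 1e-10)
    ta, tb = min(ta, 0), min(tb, 0)  # keep only points at t <= 0 on each line
    return (oa + ta * da + ob + tb * db) * 0.5, denom


# A 90-degree rotation about z: q = (cos 45deg, 0, 0, sin 45deg).
q = np.array([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)])
R = qvec2rotmat(q)
print(np.round(R, 6))

# Two cameras at (1,0,0) and (0,1,0) whose +z (backward) axes point away from the origin.
p, weight = closest_point_2_lines(np.array([1.0, 0, 0]), np.array([1.0, 0, 0]),
                                  np.array([0, 1.0, 0]), np.array([0, 1.0, 0]))
print(p, weight)  # both camera axes meet at (approximately) the origin
```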
================================================
FILE: dataLoader/llff.py
================================================
import torch
from torch.utils.data import Dataset
import glob
import numpy as np
import os
from PIL import Image
from torchvision import transforms as T
from .ray_utils import *
def normalize(v):
    """Normalize a vector."""
    return v / np.linalg.norm(v)


def average_poses(poses):
    """
    Calculate the average pose, which is then used to center all poses
    using @center_poses. Its computation is as follows:
    1. Compute the center: the average of pose centers.
    2. Compute the z axis: the normalized average z axis.
    3. Compute axis y': the average y axis.
    4. Compute x' = z cross product y', then normalize it as the x axis.
    5. Compute the y axis: x cross product z.
    Note that at step 3, we cannot directly use y' as y axis since it's
    not necessarily orthogonal to z axis. We need to pass from x to y.
    Inputs:
        poses: (N_images, 3, 4)
    Outputs:
        pose_avg: (3, 4) the average pose
    """
    # 1. Compute the center
    center = poses[..., 3].mean(0)  # (3)
    # 2. Compute the z axis
    z = normalize(poses[..., 2].mean(0))  # (3)
    # 3. Compute axis y' (no need to normalize as it's not the final output)
    y_ = poses[..., 1].mean(0)  # (3)
    # 4. Compute the x axis
    x = normalize(np.cross(z, y_))  # (3)
    # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
    y = np.cross(x, z)  # (3)
    pose_avg = np.stack([x, y, z, center], 1)  # (3, 4)
    return pose_avg


def center_poses(poses, blender2opencv):
    """
    Center the poses so that we can use NDC.
    See https://github.com/bmild/nerf/issues/34
    Inputs:
        poses: (N_images, 3, 4)
    Outputs:
        poses_centered: (N_images, 3, 4) the centered poses
        pose_avg_homo: (4, 4) the average pose in homogeneous coordinates
    """
    poses = poses @ blender2opencv
    pose_avg = average_poses(poses)  # (3, 4)
    # convert to homogeneous coordinates for faster computation
    # by simply adding 0, 0, 0, 1 as the last row
    pose_avg_homo = np.eye(4)
    pose_avg_homo[:3] = pose_avg
    last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1))  # (N_images, 1, 4)
    poses_homo = \
        np.concatenate([poses, last_row], 1)  # (N_images, 4, 4) homogeneous coordinate
    poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo  # (N_images, 4, 4)
    # poses_centered = poses_centered @ blender2opencv
    poses_centered = poses_centered[:, :3]  # (N_images, 3, 4)
    return poses_centered, pose_avg_homo


def viewmatrix(z, up, pos):
    vec2 = normalize(z)
    vec1_avg = up
    vec0 = normalize(np.cross(vec1_avg, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    m = np.eye(4)
    m[:3] = np.stack([-vec0, vec1, vec2, pos], 1)
    return m


def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
    render_poses = []
    rads = np.array(list(rads) + [1.])
    for theta in np.linspace(0., 2. * np.pi * N_rots, N + 1)[:-1]:
        c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)
        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
        render_poses.append(viewmatrix(z, up, c))
    return render_poses


def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
    # center pose
    c2w = average_poses(c2ws_all)
    # Get average pose
    up = normalize(c2ws_all[:, :3, 1].sum(0))
    # Find a reasonable "focus depth" for this dataset
    dt = 0.75
    close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0
    focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth))
    # Get radii for spiral path
    zdelta = near_fars.min() * .2
    tt = c2ws_all[:, :3, 3]
    rads = np.percentile(np.abs(tt), 90, 0) * rads_scale
    render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)
    return np.stack(render_poses)


def get_interpolation_path(c2ws_all, steps=30):
    # flower
    # idx0 = 1
    # idx1 = 10
    # trex
    # idx0 = 8
    # idx1 = 53
    # horns
    idx0 = 18
    idx1 = 47
    v = np.linspace(0,1,num=steps)
    c2w0 = c2ws_all[idx0]
    c2w1 = c2ws_all[idx1]
    c2w_ = []
    for i in range(steps):
        c2w_.append(c2w0*v[i] + c2w1*(1-v[i]))
    return np.stack(c2w_)
class LLFFDataset(Dataset):
def __init__(self, datadir, split='train', downsample=4, is_stack=False, hold_every=8):
self.root_dir = datadir
self.split = split
self.hold_every = hold_every
self.is_stack = is_stack
self.downsample = downsample
self.define_transforms()
self.blender2opencv = np.eye(4)#np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
self.read_meta()
self.white_bg = False
# self.near_far = [np.min(self.near_fars[:,0]),np.max(self.near_fars[:,1])]
self.near_far = [0.0, 1.0]
self.scene_bbox = torch.tensor([[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]])
# self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]])
self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
def read_meta(self):
poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy')) # (N_images, 17)
self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*')))
# load full resolution image then resize
if self.split in ['train', 'test']:
assert len(poses_bounds) == len(self.image_paths), \
'Mismatch between number of images and number of poses! Please rerun COLMAP!'
poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5)
self.near_fars = poses_bounds[:, -2:] # (N_images, 2)
hwf = poses[:, :, -1]
# Step 1: rescale focal length according to training resolution
H, W, self.focal = poses[0, :, -1] # original intrinsics, same for all images
self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)])
self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H]
# Step 2: correct poses
# Original poses has rotation in form "down right back", change to "right up back"
# See https://github.com/bmild/nerf/issues/34
poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
# (N_images, 3, 4) exclude H, W, focal
self.poses, self.pose_avg = center_poses(poses, self.blender2opencv)
# Step 3: correct scale so that the nearest depth is at a little more than 1.0
# See https://github.com/bmild/nerf/issues/34
near_original = self.near_fars.min()
scale_factor = near_original * 0.75 # 0.75 is the default parameter
# the nearest depth is at 1/0.75=1.33
self.near_fars /= scale_factor
self.poses[..., 3] /= scale_factor
# build rendering path
N_views, N_rots = 120, 2
tt = self.poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T
up = normalize(self.poses[:, :3, 1].sum(0))
rads = np.percentile(np.abs(tt), 90, 0)
self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views)
# self.render_path = get_interpolation_path(self.poses)
# distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1)
# val_idx = np.argmin(distances_from_center) # choose val image as the closest to
# center image
# ray directions for all pixels, same for all images (same H, W, focal)
W, H = self.img_wh
self.directions = get_ray_directions_blender(H, W, self.focal) # (H, W, 3)
average_pose = average_poses(self.poses)
dists = np.sum(np.square(average_pose[:3, 3] - self.poses[:, :3, 3]), -1)
i_test = np.arange(0, self.poses.shape[0], self.hold_every) # [np.argmin(dists)]
img_list = i_test if self.split != 'train' else sorted(set(np.arange(len(self.poses))) - set(i_test))
# every `hold_every`-th view is held out for testing; the remaining views are used for training
self.all_rays = []
self.all_rgbs = []
for i in img_list:
image_path = self.image_paths[i]
c2w = torch.FloatTensor(self.poses[i])
img = Image.open(image_path).convert('RGB')
if self.downsample != 1.0:
img = img.resize(self.img_wh, Image.LANCZOS)
img = self.transform(img) # (3, h, w)
img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB
self.all_rgbs += [img]
rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d)
# viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
all_rays = self.all_rays
all_rgbs = self.all_rgbs
self.all_rays = torch.cat(self.all_rays, 0) # (N_images*h*w, 6)
self.all_rgbs = torch.cat(self.all_rgbs, 0) # (N_images*h*w, 3)
if self.is_stack:
self.all_rays_stack = torch.stack(all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (N_images, h, w, 6)
avg_pool = torch.nn.AvgPool2d(4, ceil_mode=True)
self.ds_all_rays_stack = avg_pool(self.all_rays_stack.permute(0,3,1,2)).permute(0,2,3,1) # (N_images, h/4, w/4, 6)
self.all_rgbs_stack = torch.stack(all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (N_images, h, w, 3)
@torch.no_grad()
def prepare_feature_data(self, encoder, chunk=8):
'''
Prepare feature maps as training data.
'''
assert self.is_stack, 'Dataset should contain original stacked training data!'
print('====> prepare_feature_data ...')
frames_num, h, w, _ = self.all_rgbs_stack.size()
features = []
for chunk_idx in range(frames_num // chunk + int(frames_num % chunk > 0)):
rgbs_chunk = self.all_rgbs_stack[chunk_idx*chunk : (chunk_idx+1)*chunk].cuda()
features_chunk = encoder(normalize_vgg(rgbs_chunk.permute(0,3,1,2))).relu3_1
# resize to the size of rgb map so that rays can match
features_chunk = T.functional.resize(features_chunk, size=(h,w),
interpolation=T.InterpolationMode.BILINEAR)
features.append(features_chunk.detach().cpu().requires_grad_(False))
self.all_features_stack = torch.cat(features).permute(0,2,3,1) # (N_images, h, w, 256)
self.all_features = self.all_features_stack.reshape(-1, 256)
print('prepare_feature_data Done!')
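`prepare_feature_data` walks the stacked frames in chunks of `chunk` images. The loop bound `frames_num // chunk + int(frames_num % chunk > 0)` is a ceiling division, so the last partial chunk is not dropped. Each feature map is then resized back to the image resolution so that every pixel (ray) has a 256-D target. Below is a minimal NumPy sketch of that bookkeeping. It makes two assumptions. First, `fake_encoder` is a hypothetical stand-in for VGG `relu3_1`, which has 256 channels at 1/4 resolution after two pooling stages. Second, nearest-neighbour upsampling replaces the bilinear `T.functional.resize` call.

```python
import numpy as np


def num_chunks(n, chunk):
    # ceiling division, equivalent to n // chunk + int(n % chunk > 0)
    return n // chunk + int(n % chunk > 0)


def fake_encoder(x, channels=256):
    # hypothetical stand-in for VGG relu3_1: (B, 3, H, W) -> (B, C, ceil(H/4), ceil(W/4))
    pooled = x[:, :, ::4, ::4].mean(axis=1, keepdims=True)
    return np.repeat(pooled, channels, axis=1)


def prepare_features(rgbs_stack, chunk=8):
    """rgbs_stack: (N, H, W, 3) -> features: (N, H, W, 256), encoded chunk by chunk."""
    n, h, w, _ = rgbs_stack.shape
    feats = []
    for idx in range(num_chunks(n, chunk)):
        # slicing past the end simply yields a shorter final chunk
        batch = rgbs_stack[idx * chunk:(idx + 1) * chunk].transpose(0, 3, 1, 2)
        f = fake_encoder(batch)
        # upsample back to (H, W) so each pixel/ray gets a feature vector
        f = f.repeat(4, axis=2).repeat(4, axis=3)[:, :, :h, :w]
        feats.append(f)
    return np.concatenate(feats).transpose(0, 2, 3, 1)
```

Flattening the result with `reshape(-1, 256)` gives one feature row per ray, in the same order as the flattened RGB stack. This mirrors `self.all_features`.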
def define_transforms(self):
self.transform = T.ToTensor()
def __len__(self):
return len(self.all_rgbs)
def __getitem__(self, idx):
sample = {'rays': self.all_rays[idx],
'rgbs': self.all_rgbs[idx]}
return sample
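Steps 2 and 3 of the LLFF loader above do two things. They reorder the pose columns from LLFF's "down right back" rotation to "right up back", and they rescale the scene so the nearest bound lands at 1/0.75. The NumPy sketch below reproduces only those two steps on a toy pose, which makes their effect easy to check. The function names are hypothetical, and it skips `center_poses`, which the real loader applies in between.

```python
import numpy as np


def llff_to_right_up_back(poses):
    """poses: (N, 3, 5) as stored in poses_bounds.npy (rotation | translation | hwf).

    Returns (N, 3, 4) camera-to-world matrices without the hwf column.
    """
    # new x = old y (right), new y = -old x (up = -down), keep back (z) and translation
    return np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)


def rescale_scene(poses, near_fars, ratio=0.75):
    """Scale translations and bounds so that the nearest bound becomes 1 / ratio."""
    scale = near_fars.min() * ratio
    poses = poses.copy()
    poses[..., 3] /= scale
    return poses, near_fars / scale
```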
================================================
FILE: dataLoader/nsvf.py
================================================
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T
from .ray_utils import *
trans_t = lambda t : torch.Tensor([
[1,0,0,0],
[0,1,0,0],
[0,0,1,t],
[0,0,0,1]]).float()
rot_phi = lambda phi : torch.Tensor([
[1,0,0,0],
[0,np.cos(phi),-np.sin(phi),0],
[0,np.sin(phi), np.cos(phi),0],
[0,0,0,1]]).float()
rot_theta = lambda th : torch.Tensor([
[np.cos(th),0,-np.sin(th),0],
[0,1,0,0],
[np.sin(th),0, np.cos(th),0],
[0,0,0,1]]).float()
def pose_spherical(theta, phi, radius):
c2w = trans_t(radius)
c2w = rot_phi(phi/180.*np.pi) @ c2w
c2w = rot_theta(theta/180.*np.pi) @ c2w
c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
return c2w
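`pose_spherical` builds a camera-to-world matrix in four moves:

1. push the camera `radius` units along +z;
2. tilt it by the elevation `phi`;
3. spin it by the azimuth `theta`;
4. swap axes into the dataset's world frame.

The NumPy transcription below uses the same matrices, with NumPy arrays assumed in place of `torch.Tensor`. It lets one check two properties. The camera sits exactly `radius` from the origin. Its translation is `radius` times its own z column, so a camera looking down -z (the Blender/OpenGL convention) faces the origin.

```python
import numpy as np


def trans_t(t):
    m = np.eye(4)
    m[2, 3] = t
    return m


def rot_phi(phi):
    c, s = np.cos(phi), np.sin(phi)
    return np.array([[1, 0, 0, 0], [0, c, -s, 0], [0, s, c, 0], [0, 0, 0, 1]], dtype=float)


def rot_theta(th):
    c, s = np.cos(th), np.sin(th)
    return np.array([[c, 0, -s, 0], [0, 1, 0, 0], [s, 0, c, 0], [0, 0, 0, 1]], dtype=float)


def pose_spherical(theta, phi, radius):
    """theta, phi in degrees; returns a (4, 4) camera-to-world matrix."""
    c2w = trans_t(radius)
    c2w = rot_phi(np.deg2rad(phi)) @ c2w
    c2w = rot_theta(np.deg2rad(theta)) @ c2w
    # axis swap into the dataset's world frame (same matrix as the torch version)
    flip = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=float)
    return flip @ c2w
```

The NSVF render path is the same call swept over 40 azimuths at -30 degrees elevation and radius 4: `np.stack([pose_spherical(a, -30.0, 4.0) for a in np.linspace(-180, 180, 41)[:-1]])`.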
class NSVF(Dataset):
"""NSVF Generic Dataset."""
def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False):
self.root_dir = datadir
self.split = split
self.is_stack = is_stack
self.downsample = downsample
self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
self.define_transforms()
self.white_bg = True
self.near_far = [0.5,6.0]
self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)
self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
self.read_meta()
self.define_proj_mat()
self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
def bbox2corners(self):
corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
for i in range(3):
corners[i,[0,1],i] = corners[i,[1,0],i]
return corners.view(-1,3)
def read_meta(self):
with open(os.path.join(self.root_dir, "intrinsics.txt")) as f:
focal = float(f.readline().split()[0])
self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]])
self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1)
pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))
if self.split == 'train':
pose_files = [x for x in pose_files if x.startswith('0_')]
img_files = [x for x in img_files if x.startswith('0_')]
elif self.split == 'val':
pose_files = [x for x in pose_files if x.startswith('1_')]
img_files = [x for x in img_files if x.startswith('1_')]
elif self.split == 'test':
test_pose_files = [x for x in pose_files if x.startswith('2_')]
test_img_files = [x for x in img_files if x.startswith('2_')]
if len(test_pose_files) == 0:
test_pose_files = [x for x in pose_files if x.startswith('1_')]
test_img_files = [x for x in img_files if x.startswith('1_')]
pose_files = test_pose_files
img_files = test_img_files
# ray directions for all pixels, same for all images (same H, W, focal)
self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3)
self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
self.poses = []
self.all_rays = []
self.all_rgbs = []
assert len(img_files) == len(pose_files)
for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
image_path = os.path.join(self.root_dir, 'rgb', img_fname)
img = Image.open(image_path)
if self.downsample!=1.0:
img = img.resize(self.img_wh, Image.LANCZOS)
img = self.transform(img) # (4, h, w)
img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA
if img.shape[-1]==4:
img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
self.all_rgbs += [img]
c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) #@ self.blender2opencv
c2w = torch.FloatTensor(c2w)
self.poses.append(c2w) # C2W
rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
# w2c = torch.inverse(c2w)
#
self.poses = torch.stack(self.poses)
if 'train' == self.split:
if self.is_stack:
self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (N_images, h, w, 6)
self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (N_images, h, w, 3)
else:
self.all_rays = torch.cat(self.all_rays, 0) # (N_images*h*w, 6)
self.all_rgbs = torch.cat(self.all_rgbs, 0) # (N_images*h*w, 3)
else:
self.all_rays = torch.stack(self.all_rays, 0) # (N_images, h*w, 6)
self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (N_images, h, w, 3)
def define_transforms(self):
self.transform = T.ToTensor()
def define_proj_mat(self):
self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]
def world2ndc(self, points):
device = points.device
return (points - self.center.to(device)) / self.radius.to(device)
def __len__(self):
if self.split == 'train':
return len(self.all_rays)
return len(self.all_rgbs)
def __getitem__(self, idx):
if self.split == 'train': # use data in the buffers
sample = {'rays': self.all_rays[idx],
'rgbs': self.all_rgbs[idx]}
else: # create data for each image separately
img = self.all_rgbs[idx]
rays = self.all_rays[idx]
sample = {'rays': rays,
'rgbs': img}
return sample
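`NSVF.world2ndc` does not apply a projective NDC warp. It is an affine map that sends the axis-aligned `scene_bbox` onto the cube [-1, 1]^3, using the box center and per-axis half-extent computed in `__init__`. Here is a NumPy sketch of the same normalization; the bbox values in the check are made up.

```python
import numpy as np


def bbox_center_radius(scene_bbox):
    """scene_bbox: (2, 3) array [[xmin, ymin, zmin], [xmax, ymax, zmax]]."""
    center = scene_bbox.mean(axis=0)
    radius = scene_bbox[1] - center  # per-axis half extent
    return center, radius


def world2ndc(points, center, radius):
    # bbox min -> -1 and bbox max -> +1, independently per axis
    return (points - center) / radius
```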
================================================
FILE: dataLoader/ray_utils.py
================================================
import torch, re
import numpy as np
from torch import searchsorted
from kornia import create_meshgrid
# from utils import index_point_feature
def depth2dist(z_vals, cos_angle):
# z_vals: [N_ray N_sample]
device = z_vals.device
dists = z_vals[..., 1:] - z_vals[..., :-1]
dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples]
dists = dists * cos_angle.unsqueeze(-1)
return dists
def ndc2dist(ndc_pts, cos_angle):
dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1)
dists = torch.cat([dists, 1e10 * cos_angle.unsqueeze(-1)], -1) # [N_rays, N_samples]
return dists
def get_ray_directions(H, W, focal, center=None):
"""
Get ray directions for all pixels in camera coordinate.
Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
ray-tracing-generating-camera-rays/standard-coordinate-systems
Inputs:
H, W, focal: image height, width and focal length
Outputs:
directions: (H, W, 3), the direction of the rays in camera coordinate
"""
grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5
i, j = grid.unbind(-1)
# the direction here is without +0.5 pixel centering as calibration is not so accurate
# see https://github.com/bmild/nerf/issues/24
cent = center if center is not None else [W / 2, H / 2]
directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
return directions
def get_ray_directions_blender(H, W, focal, center=None):
"""
Get ray directions for all pixels in camera coordinate.
Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
ray-tracing-generating-camera-rays/standard-coordinate-systems
Inputs:
H, W, focal: image height, width and focal length
Outputs:
directions: (H, W, 3), the direction of the rays in camera coordinate
"""
grid = create_meshgrid(H, W, normalized_coordinates=False)[0]+0.5
i, j = grid.unbind(-1)
# the direction here is without +0.5 pixel centering as calibration is not so accurate
# see https://github.com/bmild/nerf/issues/24
cent = center if center is not None else [W / 2, H / 2]
directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)],
-1) # (H, W, 3)
return directions
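The two direction helpers differ only in camera convention:

- `get_ray_directions` uses the OpenCV frame: x right, y down, looking along +z.
- `get_ray_directions_blender` uses the Blender/OpenGL frame: x right, y up, looking along -z.

That is why the second one negates y and z. Both sample pixel centers because of the `+ 0.5` offset. Below is a NumPy sketch. It assumes that `np.meshgrid(..., indexing='xy')` reproduces the x/y pixel grid of `kornia.create_meshgrid(..., normalized_coordinates=False)`.

```python
import numpy as np


def ray_directions(H, W, focal, center=None, blender=False):
    """Per-pixel camera-space ray directions, shape (H, W, 3).

    focal: (fx, fy); center: (cx, cy), defaults to the image center.
    blender=False -> OpenCV frame (+x right, +y down, look along +z).
    blender=True  -> Blender/OpenGL frame (+x right, +y up, look along -z).
    """
    # pixel-center coordinates: i runs along the width (x), j along the height (y)
    i, j = np.meshgrid(np.arange(W) + 0.5, np.arange(H) + 0.5, indexing='xy')
    cx, cy = center if center is not None else (W / 2, H / 2)
    x = (i - cx) / focal[0]
    y = (j - cy) / focal[1]
    if blender:
        return np.stack([x, -y, -np.ones_like(i)], -1)
    return np.stack([x, y, np.ones_like(i)], -1)
```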
def get_rays(directions, c2w):
"""
Get ray origin and normalized directions in world coordinate for all pixels in one image.
Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
ray-tracing-generating-camera-rays/standard-coordinate-systems
Inputs:
directions: (H, W, 3) precomputed ray directions in camera coordinate
c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
Outputs:
rays_o: (H*W, 3), the origin of the rays in world coordinate
rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
"""
# Rotate ray directions from camera coordinate to the world coordinate
rays_d = directions @ c2w[:3, :3].T # (H, W, 3)
# rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
# The origin of all rays is the camera origin in world coordinate
rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3)
rays_d = rays_d.view(-1, 3)
rays_o = rays_o.view(-1, 3)
return rays_o, rays_d
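`get_rays` does two things. It rotates the camera-space directions into world space: `directions @ c2w[:3, :3].T` is the row-vector form of R·d for every pixel. It then broadcasts the camera center as the origin of every ray. The normalization line is commented out, so the returned directions keep whatever length the input directions had. Here is a NumPy sketch under the same conventions.

```python
import numpy as np


def get_rays_np(directions, c2w):
    """directions: (H, W, 3) camera-space; c2w: (3, 4) or (4, 4) camera-to-world.

    Returns rays_o, rays_d, each (H*W, 3), in world space (rays_d is not normalized).
    """
    rays_d = directions @ c2w[:3, :3].T  # row-vector form of R @ d
    rays_o = np.broadcast_to(c2w[:3, 3], rays_d.shape)  # same origin for every pixel
    return rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
```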
def ndc_rays_blender(H, W, focal, near, rays_o, rays_d):
# Shift ray origins to near plane
t = -(near + rays_o[..., 2]) / rays_d[..., 2]
rays_o = rays_o + t[..., None] * rays_d
# Projection
o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
o2 = 1. + 2. * near / rays_o[..., 2]
d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
d2 = -2. * near / rays_o[..., 2]
rays_o = torch.stack([o0, o1, o2], -1)
rays_d = torch.stack([d0, d1, d2], -1)
return rays_o, rays_d
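`ndc_rays_blender` is the forward-facing NDC warp from NeRF, which the LLFF loader applies with `near=1.0`. It first slides each origin onto the plane z = -near, then projects. Two consequences are easy to verify numerically:

- every warped origin lands on the NDC near plane z' = -1;
- following a warped ray for t' = 1 reaches z' = +1, which is the image of points at infinite depth.

Below is a NumPy transcription using the same formulas as the torch version.

```python
import numpy as np


def ndc_rays_blender_np(H, W, focal, near, rays_o, rays_d):
    # shift origins to the near plane z = -near (the camera looks along -z)
    t = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d
    ox, oy, oz = rays_o[..., 0], rays_o[..., 1], rays_o[..., 2]
    dx, dy, dz = rays_d[..., 0], rays_d[..., 1], rays_d[..., 2]
    ax = -1.0 / (W / (2.0 * focal))
    ay = -1.0 / (H / (2.0 * focal))
    o = np.stack([ax * ox / oz, ay * oy / oz, 1.0 + 2.0 * near / oz], -1)
    d = np.stack([ax * (dx / dz - ox / oz), ay * (dy / dz - oy / oz), -2.0 * near / oz], -1)
    return o, d
```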
def ndc_rays(H, W, focal, near, rays_o, rays_d):
# Shift ray origins to near plane
t = (near - rays_o[..., 2]) / rays_d[..., 2]
rays_o = rays_o + t[..., None] * rays_d
# Projection
o0 = 1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
o1 = 1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
o2 = 1. - 2. * near / rays_o[..., 2]
d0 = 1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
d1 = 1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
d2 = 2. * near / rays_o[..., 2]
rays_o = torch.stack([o0, o1, o2], -1)
rays_d = torch.stack([d0, d1, d2], -1)
return rays_o, rays_d
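# Hierarchical sampling (Section 5.2 of the NeRF paper): inverse-transform sampling of N_samples depths from the piecewise-constant PDF defined by `weights` over `bins`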
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
device = weights.device
# Get pdf
weights = weights + 1e-5 # prevent nans
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins))
# Take uniform samples
if det:
u = torch.linspace(0., 1., steps=N_samples, device=device)
u = u.expand(list(cdf.shape[:-1]) + [N_samples])
else:
u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device)
# Pytest, overwrite u with numpy's fixed random numbers
if pytest:
np.random.seed(0)
new_shape = list(cdf.shape[:-1]) + [N_samples]
if det:
u = np.linspace(0., 1., N_samples)
u = np.broadcast_to(u, new_shape)
else:
u = np.random.rand(*new_shape)
u = torch.Tensor(u)
# Invert CDF
u = u.contiguous()
inds = searchsorted(cdf.detach(), u, right=True)
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = (cdf_g[..., 1] - cdf_g[..., 0])
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
t = (u - cdf_g[..., 0]) / denom
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
return samples
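`sample_pdf` performs inverse-transform sampling from the piecewise-constant PDF that `weights` defines over the intervals between consecutive `bins`. It builds the CDF, locates each uniform `u` with `searchsorted(..., right=True)`, and interpolates linearly inside the matching bin. The NumPy sketch below covers the deterministic (`det=True`) path for a single ray; `side='right'` is NumPy's counterpart of `right=True`. The sanity check is that uniform weights give back `u` itself.

```python
import numpy as np


def sample_pdf_np(bins, weights, n_samples):
    """bins: (B+1,) bin edges; weights: (B,) non-negative. Returns (n_samples,) samples."""
    weights = weights + 1e-5  # avoid an all-zero pdf
    pdf = weights / weights.sum()
    cdf = np.concatenate([[0.0], np.cumsum(pdf)])  # (B+1,), starts at 0
    u = np.linspace(0.0, 1.0, n_samples)
    inds = np.searchsorted(cdf, u, side='right')
    below = np.maximum(inds - 1, 0)
    above = np.minimum(inds, len(cdf) - 1)
    cdf_lo, cdf_hi = cdf[below], cdf[above]
    denom = cdf_hi - cdf_lo
    denom = np.where(denom < 1e-5, 1.0, denom)  # guard near-empty bins
    t = (u - cdf_lo) / denom
    return bins[below] + t * (bins[above] - bins[below])
```

When all the mass sits in one bin, almost every sample falls inside that bin. This is how the fine pass concentrates points near surfaces.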
def dda(rays_o, rays_d, bbox_3D):
inv_ray_d = 1.0 / (rays_d + 1e-6)
t_min = (bbox_3D[:1] - rays_o) * inv_ray_d # N_rays 3
t_max = (bbox_3D[1:] - rays_o) * inv_ray_d
t = torch.stack((t_min, t_max)) # 2 N_rays 3
t_min = torch.max(torch.min(t, dim=0)[0], dim=-1, keepdim=True)[0]
t_max = torch.min(torch.max(t, dim=0)[0], dim=-1, keepdim=True)[0]
return t_min, t_max
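Despite its name, `dda` does not traverse a voxel grid. It performs the slab test for intersecting rays with an axis-aligned box. For each axis it computes the distances to the two bounding planes, then takes the latest entry and the earliest exit. A ray misses the box when the entry comes after the exit. Below is a NumPy sketch of the same computation. As in the original, `+ 1e-6` guards against division by zero for axis-parallel rays.

```python
import numpy as np


def ray_aabb(rays_o, rays_d, bbox):
    """rays_o, rays_d: (N, 3); bbox: (2, 3) [min, max]. Returns t_near, t_far: (N, 1)."""
    inv_d = 1.0 / (rays_d + 1e-6)
    t0 = (bbox[:1] - rays_o) * inv_d  # distances to the min planes, (N, 3)
    t1 = (bbox[1:] - rays_o) * inv_d  # distances to the max planes, (N, 3)
    t_near = np.minimum(t0, t1).max(axis=-1, keepdims=True)  # latest entry
    t_far = np.maximum(t0, t1).min(axis=-1, keepdims=True)   # earliest exit
    return t_near, t_far
```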
def ray_marcher(rays,
N_samples=64,
lindisp=False,
perturb=0,
bbox_3D=None):
"""
sample points along the rays
Inputs:
rays: ()
Returns:
"""
# Decompose the inputs
N_rays = rays.shape[0]
rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3)
near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1)
if bbox_3D is not None:
# compute near/far from the intersection of each ray with the axis-aligned bounding box
near, far = dda(rays_o, rays_d, bbox_3D)
# Sample depth points
z_steps = torch.linspace(0, 1, N_samples, device=rays.device) # (N_samples)
if not lindisp: # use linear sampling in depth space
z_vals = near * (1 - z_steps) + far * z_steps
else: # use linear sampling in disparity space
z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)
z_vals = z_vals.expand(N_rays, N_samples)
if perturb > 0: # perturb sampling depths (z_vals)
z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) # (N_rays, N_samples-1) interval mid points
# get intervals between samples
upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1)
lower = torch.cat([z_vals[:, :1], z_vals_mid], -1)
perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device)
z_vals = lower + (upper - lower) * perturb_rand
xyz_coarse_sampled = rays_o.unsqueeze(1) + \
rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3)
return xyz_coarse_sampled, rays_o, rays_d, z_vals
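`ray_marcher` implements stratified sampling. It places `N_samples` depths evenly between `near` and `far`, or evenly in disparity when `lindisp=True`. When `perturb > 0`, it jitters each depth uniformly within the interval bounded by the midpoints to its neighbours. The NumPy sketch below covers the depth-linear case and shows that jittered samples stay ordered and inside [near, far].

```python
import numpy as np


def stratified_depths(near, far, n_samples, perturb=1.0, rng=None):
    """near, far: (N_rays, 1). Returns z_vals: (N_rays, n_samples)."""
    rng = np.random.default_rng() if rng is None else rng
    steps = np.linspace(0.0, 1.0, n_samples)
    z = near * (1 - steps) + far * steps  # evenly spaced in depth
    if perturb > 0:
        mids = 0.5 * (z[:, :-1] + z[:, 1:])
        upper = np.concatenate([mids, z[:, -1:]], -1)
        lower = np.concatenate([z[:, :1], mids], -1)
        # one uniform draw per interval; the intervals do not overlap, so order is kept
        z = lower + (upper - lower) * perturb * rng.random(z.shape)
    return z
```

The 3D sample points then follow as `rays_o[:, None] + rays_d[:, None] * z[..., None]`, which matches `xyz_coarse_sampled`.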
def read_pfm(filename):
file = open(filename, 'rb')
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().decode('utf-8').rstrip()
if header == 'PF':
color = True
elif header == 'Pf':
color = False
else:
raise Exception('Not a PFM file.')
dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
if dim_match:
width, height = map(int, dim_match.groups())
else:
raise Exception('Malformed PFM header.')
scale = float(file.readline().rstrip())
if scale < 0: # little-endian
endian = '<'
scale = -scale
else:
endian = '>' # big-endian
data = np.fromfile(file, endian + 'f')
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
file.close()
return data, scale
def ndc_bbox(all_rays):
near_min = torch.min(all_rays[...,:3].view(-1,3),dim=0)[0]
near_max = torch.max(all_rays[..., :3].view(-1, 3), dim=0)[0]
far_min = torch.min((all_rays[...,:3]+all_rays[...,3:6]).view(-1,3),dim=0)[0]
far_max = torch.max((all_rays[...,:3]+all_rays[...,3:6]).view(-1, 3), dim=0)[0]
print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}')
return torch.stack((torch.minimum(near_min,far_min),torch.maximum(near_max,far_max)))
import torchvision
normalize_vgg = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
def denormalize_vgg(img):
im = img.clone()
im[:, 0, :, :] *= 0.229
im[:, 1, :, :] *= 0.224
im[:, 2, :, :] *= 0.225
im[:, 0, :, :] += 0.485
im[:, 1, :, :] += 0.456
im[:, 2, :, :] += 0.406
return im
================================================
FILE: dataLoader/styleLoader.py
================================================
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as T
def getDataLoader(dataset_path, batch_size, sampler, image_side_length=256, num_workers=2):
transform = T.Compose([
T.Resize(size=(image_side_length*2, image_side_length*2)),
T.RandomCrop(image_side_length),
T.ToTensor(),
])
train_dataset = datasets.ImageFolder(dataset_path, transform=transform)
dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler(len(train_dataset)), num_workers=num_workers)
return dataloader
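`getDataLoader` calls `sampler(len(train_dataset))`. The `sampler` argument is therefore expected to be a class or factory that takes the dataset size, not a ready-made sampler instance. Style-transfer loops often want a sampler that never runs out, so that `next(iter(loader))` can be called indefinitely. The class below is a hypothetical sketch of such an infinite, reshuffling sampler; it is not necessarily what the training scripts pass in. The PyTorch documentation describes a `DataLoader` sampler as any iterable with `__len__` implemented, which is all this class provides.

```python
import numpy as np


class InfiniteSampler:
    """Hypothetical infinite sampler: yields dataset indices forever,
    reshuffling after every full pass over the dataset."""

    def __init__(self, num_samples, seed=0):
        self.num_samples = num_samples
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        while True:
            for idx in rng.permutation(self.num_samples):
                yield int(idx)

    def __len__(self):
        # effectively unbounded; only consulted if len(dataloader) is requested
        return 2 ** 31
```

With a placeholder path, usage would look like `getDataLoader('path/to/styles', batch_size=4, sampler=InfiniteSampler)`.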
================================================
FILE: dataLoader/tankstemple.py
================================================
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T
import random
from .ray_utils import *
def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1):
if axis == 'z':
return lambda t: [radius * np.cos(r * t + t0), radius * np.sin(r * t + t0), h]
elif axis == 'y':
return lambda t: [radius * np.cos(r * t + t0), h, radius * np.sin(r * t + t0)]
else:
return lambda t: [h, radius * np.cos(r * t + t0), radius * np.sin(r * t + t0)]
def cross(x, y, axis=0):
T = torch if isinstance(x, torch.Tensor) else np
return T.cross(x, y, axis)
def normalize(x, axis=-1, order=2):
if isinstance(x, torch.Tensor):
l2 = x.norm(p=order, dim=axis, keepdim=True)
return x / (l2 + 1e-8), l2
else:
l2 = np.linalg.norm(x, order, axis)
l2 = np.expand_dims(l2, axis)
l2[l2 == 0] = 1
return x / l2, l2
def cat(x, axis=1):
if isinstance(x[0], torch.Tensor):
return torch.cat(x, dim=axis)
return np.concatenate(x, axis=axis)
def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False):
"""
This function takes a vector 'camera_position' which specifies the location
of the camera in world coordinates and two vectors `at` and `up` which
indicate the position of the object and the up directions of the world
coordinate system respectively. The object is assumed to be centered at
the origin.
The output is a rotation matrix representing the transformation
from world coordinates -> view coordinates.
Input:
camera_position: 3
at: 1 x 3 or N x 3 (0, 0, 0) in default
up: 1 x 3 or N x 3 (0, 1, 0) in default
"""
if at is None:
at = torch.zeros_like(camera_position)
else:
at = torch.tensor(at).type_as(camera_position)
if up is None:
up = torch.zeros_like(camera_position)
up[2] = -1
else:
up = torch.tensor(up).type_as(camera_position)
z_axis = normalize(at - camera_position)[0]
x_axis = normalize(cross(up, z_axis))[0]
y_axis = normalize(cross(z_axis, x_axis))[0]
R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)
return R
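`look_at_rotation` builds a camera frame in three steps:

- the z axis points from the camera toward `at`;
- x = up × z;
- y = z × x.

It stacks the three axes as columns. The columns are the camera axes expressed in world coordinates, so `gen_path` can use the result directly as the rotation block of a camera-to-world matrix. Below is a NumPy sketch. Like the torch branch of `normalize`, it adds a small epsilon when normalizing, and it assumes `camera_position` is not parallel to `up`.

```python
import numpy as np


def _unit(v, eps=1e-8):
    return v / (np.linalg.norm(v) + eps)


def look_at(camera_position, at=(0.0, 0.0, 0.0), up=(0.0, 0.0, -1.0)):
    """Return a 3x3 rotation whose columns are the camera x, y, z axes in world
    coordinates, with z pointing from the camera toward `at`."""
    camera_position = np.asarray(camera_position, dtype=float)
    z_axis = _unit(np.asarray(at, dtype=float) - camera_position)
    x_axis = _unit(np.cross(up, z_axis))
    y_axis = _unit(np.cross(z_axis, x_axis))  # component of `up` orthogonal to z
    return np.stack([x_axis, y_axis, z_axis], axis=1)
```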
def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):
c2ws = []
for t in range(frames):
c2w = torch.eye(4)
cam_pos = torch.tensor(pos_gen(t * (360.0 / frames) / 180 * np.pi))
cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True)
c2w[:3, 3], c2w[:3, :3] = cam_pos, cam_rot
c2ws.append(c2w)
return torch.stack(c2ws)
class TanksTempleDataset(Dataset):
"""NSVF Generic Dataset."""
def __init__(self, datadir, split='train', downsample=4.0, wh=[1920,1080], is_stack=False):
self.root_dir = datadir
self.split = split
self.is_stack = is_stack
self.downsample = downsample
self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
self.define_transforms()
self.white_bg = True
self.near_far = [0.01,6.0]
self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)*1.2
self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
self.read_meta()
self.define_proj_mat()
self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
def bbox2corners(self):
corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
for i in range(3):
corners[i,[0,1],i] = corners[i,[1,0],i]
return corners.view(-1,3)
def read_meta(self):
self.intrinsics = np.loadtxt(os.path.join(self.root_dir, "intrinsics.txt"))
self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([1920,1080])).reshape(2,1)
pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))
if self.split == 'train':
pose_files = [x for idx,x in enumerate(pose_files) if x.startswith('0_') and idx%3==0]
img_files = [x for idx,x in enumerate(img_files) if x.startswith('0_') and idx%3==0]
elif self.split == 'test':
test_pose_files = [x for idx,x in enumerate(pose_files) if x.startswith('2_') and idx%3==0]
test_img_files = [x for idx,x in enumerate(img_files) if x.startswith('2_') and idx%3==0]
if len(test_pose_files) == 0:
test_pose_files = [x for idx,x in enumerate(pose_files) if x.startswith('1_') and idx%3==0]
test_img_files = [x for idx,x in enumerate(img_files) if x.startswith('1_') and idx%3==0]
pose_files = test_pose_files
img_files = test_img_files
# ray directions for all pixels, same for all images (same H, W, focal)
self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3)
self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
w, h = self.img_wh
self.poses = []
self.all_rays = []
self.all_rgbs = []
self.all_masks = []
assert len(img_files) == len(pose_files)
for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
image_path = os.path.join(self.root_dir, 'rgb', img_fname)
img = Image.open(image_path)
if self.downsample!=1.0:
img = img.resize(self.img_wh, Image.LANCZOS)
img = self.transform(img) # (C, h, w), C = 3 (RGB) or 4 (RGBA)
img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, C)
mask = torch.where(
img.sum(-1, keepdim=True) == 3.,
1.,
0.
)
self.all_masks.append(mask.reshape(h,w,1)) # (h, w, 1) A
if img.shape[-1]==4:
img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
self.all_rgbs.append(img)
c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))# @ cam_trans
c2w = torch.FloatTensor(c2w)
self.poses.append(c2w) # C2W
rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
self.poses = torch.stack(self.poses)
center = torch.mean(self.scene_bbox, dim=0)
radius = torch.norm(self.scene_bbox[1]-center)*1.2
up = torch.mean(self.poses[:, :3, 1], dim=0).tolist()
pos_gen = circle(radius=radius, h=-0.2*up[1], axis='y')
self.render_path = gen_path(pos_gen, up=up,frames=100)
self.render_path[:, :3, 3] += center
all_rays = self.all_rays
all_rgbs = self.all_rgbs
self.all_masks = torch.stack(self.all_masks) # (n_frames, h, w, 1)
self.all_rays = torch.cat(self.all_rays, 0) # (N_images*h*w, 6)
self.all_rgbs = torch.cat(self.all_rgbs, 0) # (N_images*h*w, 3)
if self.is_stack:
self.all_rays_stack = torch.stack(all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (N_images, h, w, 6)
avg_pool = torch.nn.AvgPool2d(4, ceil_mode=True)
self.ds_all_rays_stack = avg_pool(self.all_rays_stack.permute(0,3,1,2)).permute(0,2,3,1) # (N_images, h/4, w/4, 6)
self.all_rgbs_stack = torch.stack(all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (N_images, h, w, 3)
@torch.no_grad()
def prepare_feature_data(self, encoder, chunk=4):
'''
Prepare feature maps as training data.
'''
assert self.is_stack, 'Dataset should contain original stacked training data!'
print('====> prepare_feature_data ...')
frames_num, h, w, _ = self.all_rgbs_stack.size()
features = []
for chunk_idx in tqdm(range(frames_num // chunk + int(frames_num % chunk > 0))):
rgbs_chunk = self.all_rgbs_stack[chunk_idx*chunk : (chunk_idx+1)*chunk].cuda()
features_chunk = encoder(normalize_vgg(rgbs_chunk.permute(0,3,1,2))).relu3_1
# resize to the size of rgb map so that rays can match
features_chunk = T.functional.resize(features_chunk, size=(h,w),
interpolation=T.InterpolationMode.BILINEAR)
features.append(features_chunk.detach().cpu().requires_grad_(False))
self.all_features_stack = torch.cat(features).permute(0,2,3,1) # (N_images, h, w, 256)
self.all_features = self.all_features_stack.reshape(-1, 256)
print('prepare_feature_data Done!')
def define_transforms(self):
self.transform = T.ToTensor()
def define_proj_mat(self):
self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]
def world2ndc(self, points):
device = points.device
return (points - self.center.to(device)) / self.radius.to(device)
def __len__(self):
if self.split == 'train':
return len(self.all_rays)
return len(self.all_rgbs)
def __getitem__(self, idx):
if self.split == 'train': # use data in the buffers
sample = {'rays': self.all_rays[idx],
'rgbs': self.all_rgbs[idx]}
else: # create data for each image separately
img = self.all_rgbs[idx]
rays = self.all_rays[idx]
sample = {'rays': rays,
'rgbs': img}
return sample
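`TanksTempleDataset.read_meta` selects frames by filename prefix. It uses `0_` for training and `2_` for testing, falling back to `1_` when no `2_` files exist. In both cases it keeps only every third entry. Below is a pure-Python sketch of that selection logic, checked against a made-up file list. Note that `idx % 3` is taken over the index in the full sorted listing, not within each prefix group. Any other split (e.g. `'val'`) leaves the list unfiltered, as in the loader.

```python
def select_split(files, split, stride=3):
    """Filter sorted NSVF-style file names ('0_*', '1_*', '2_*') for a split."""

    def pick(prefix):
        return [x for idx, x in enumerate(files) if x.startswith(prefix) and idx % stride == 0]

    if split == 'train':
        return pick('0_')
    if split == 'test':
        chosen = pick('2_')
        return chosen if chosen else pick('1_')  # fall back to the validation views
    return files
```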
================================================
FILE: dataLoader/your_own_data.py
================================================
import torch,cv2
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T
from .ray_utils import *
class YourOwnDataset(Dataset):
def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):
self.N_vis = N_vis
self.root_dir = datadir
self.split = split
self.is_stack = is_stack
self.downsample = downsample
self.define_transforms()
self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
self.read_meta()
self.define_proj_mat()
self.white_bg = True
self.near_far = [0.1,100.0]
self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
self.downsample=downsample
def read_depth(self, filename):
depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800)
return depth
def read_meta(self):
with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
self.meta = json.load(f)
w, h = int(self.meta['w']/self.downsample), int(self.meta['h']/self.downsample)
self.img_wh = [w,h]
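self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x']) # focal length at the downsampled resolution
self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y']) # focal length at the downsampled resolution
self.cx, self.cy = self.meta['cx'] / self.downsample, self.meta['cy'] / self.downsample # principal point, rescaled to the downsampled resolution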
# ray directions for all pixels, same for all images (same H, W, focal)
self.directions = get_ray_directions(h, w, [self.focal_x,self.focal_y], center=[self.cx, self.cy]) # (h, w, 3)
self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
self.intrinsics = torch.tensor([[self.focal_x,0,self.cx],[0,self.focal_y,self.cy],[0,0,1]]).float()
self.image_paths = []
self.poses = []
self.all_rays = []
self.all_rgbs = []
self.all_masks = []
self.all_depth = []
img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):
frame = self.meta['frames'][i]
pose = np.array(frame['transform_matrix']) @ self.blender2opencv
c2w = torch.FloatTensor(pose)
self.poses += [c2w]
image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
self.image_paths += [image_path]
img = Image.open(image_path)
if self.downsample!=1.0:
img = img.resize(self.img_wh, Image.LANCZOS)
img = self.transform(img) # (4, h, w)
img = img.view(-1, w*h).permute(1, 0) # (h*w, 4) RGBA
if img.shape[-1]==4:
img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
self.all_rgbs += [img]
rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
self.poses = torch.stack(self.poses)
if not self.is_stack:
self.all_rays = torch.cat(self.all_rays, 0) # (N_images*h*w, 6)
self.all_rgbs = torch.cat(self.all_rgbs, 0) # (N_images*h*w, 3)
# self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3)
else:
self.all_rays = torch.stack(self.all_rays, 0) # (N_images, h*w, 6)
self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (N_images, h, w, 3)
# self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3)
def define_transforms(self):
self.transform = T.ToTensor()
def define_proj_mat(self):
self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]
def world2ndc(self,points,lindisp=None):
device = points.device
return (points - self.center.to(device)) / self.radius.to(device)
def __len__(self):
return len(self.all_rgbs)
def __getitem__(self, idx):
if self.split == 'train': # use data in the buffers
sample = {'rays': self.all_rays[idx],
'rgbs': self.all_rgbs[idx]}
else: # create data for each image separately
img = self.all_rgbs[idx]
rays = self.all_rays[idx]
sample = {'rays': rays,
'rgbs': img}
return sample
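`YourOwnDataset.read_meta` recovers focal lengths from the fields of view stored in `transforms_*.json`. For a pinhole camera, half the image width spans half the horizontal field of view, so f_x = 0.5 · w / tan(0.5 · camera_angle_x), and likewise f_y uses `h` and `camera_angle_y`. Because `w` and `h` are already divided by `downsample`, these focal lengths are at the downsampled resolution. The principal point has to be scaled by the same factor to stay consistent. Below is a sketch of the full intrinsics computation; the helper name `intrinsics_from_meta` is hypothetical.

```python
import numpy as np


def intrinsics_from_meta(meta, downsample=1.0):
    """Build a 3x3 pinhole K at the downsampled resolution from NeRF-style metadata
    with keys 'w', 'h', 'camera_angle_x', 'camera_angle_y', 'cx', 'cy'."""
    w = int(meta['w'] / downsample)
    h = int(meta['h'] / downsample)
    fx = 0.5 * w / np.tan(0.5 * meta['camera_angle_x'])
    fy = 0.5 * h / np.tan(0.5 * meta['camera_angle_y'])
    # the principal point is given in original pixels; rescale it with the image
    cx, cy = meta['cx'] / downsample, meta['cy'] / downsample
    return np.array([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]])
```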
================================================
FILE: extra/auto_run_paramsets.py
================================================
import os
import threading, queue
import numpy as np
import time
def getFolderLocker(logFolder):
while True:
try:
os.makedirs(logFolder+"/lockFolder")
break
except FileExistsError:
time.sleep(0.01)
def releaseFolderLocker(logFolder):
os.rmdir(logFolder+"/lockFolder")
def getStopFolder(logFolder):
return os.path.isdir(logFolder+"/stopFolder")
def get_param_str(key, val):
if key == 'data_name':
return f'--datadir {datafolder}/{val} '
else:
return f'--{key} {val} '
def get_param_list(param_dict):
param_keys = list(param_dict.keys())
param_modes = len(param_keys)
param_nums = [len(param_dict[key]) for key in param_keys]
param_ids = np.zeros(param_nums+[param_modes], dtype=int)
for i in range(param_modes):
broad_tuple = np.ones(param_modes, dtype=int).tolist()
broad_tuple[i] = param_nums[i]
broad_tuple = tuple(broad_tuple)
print(broad_tuple)
param_ids[...,i] = np.arange(param_nums[i]).reshape(broad_tuple)
param_ids = param_ids.reshape(-1, param_modes)
# print(param_ids)
print(len(param_ids))
params = []
expnames = []
for i in range(param_ids.shape[0]):
one = ""
name = ""
param_id = param_ids[i]
for j in range(param_modes):
key = param_keys[j]
val = param_dict[key][param_id[j]]
if isinstance(key, tuple):
assert len(key) == len(val)
for k in range(len(key)):
one += get_param_str(key[k], val[k])
name += f'{val[k]},'
name=name[:-1]+'-'
else:
one += get_param_str(key, val)
name += f'{val}-'
params.append(one)
name=name.replace(' ','')
print(name)
expnames.append(name[:-1])
# print(params)
return params, expnames
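`get_param_list` builds an index grid with NumPy broadcasting and flattens it in C order, so the last key varies fastest. That is exactly the Cartesian product of the value lists, which `itertools.product` yields in the same order. A minimal sketch of an equivalent enumerator (hypothetical `param_grid`; tuple keys expand into several flags as above, and the special `data_name` to `--datadir` mapping is left out for brevity):

```python
import itertools

def param_grid(param_dict):
    """Return (cli_strings, exp_names) for every combination of parameter values."""
    keys = list(param_dict)
    params, names = [], []
    for combo in itertools.product(*(param_dict[k] for k in keys)):
        cli, parts = "", []
        for key, val in zip(keys, combo):
            if isinstance(key, tuple):          # coupled parameters share one entry
                assert len(key) == len(val)
                cli += "".join(f"--{k} {v} " for k, v in zip(key, val))
                parts.append(",".join(str(v) for v in val))
            else:
                cli += f"--{key} {val} "
                parts.append(str(val))
        params.append(cli)
        names.append("-".join(parts).replace(" ", ""))
    return params, names

params, names = param_grid({'a': [1, 2], ('b', 'c'): [(3, 4)], 'd': ['x', 'y']})
print(names)   # ['1-3,4-x', '1-3,4-y', '2-3,4-x', '2-3,4-y']
```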
if __name__ == '__main__':
# nerf
expFolder = "nerf/"
# parameters to iterate, use tuple to couple multiple parameters
datafolder = '/mnt/new_disk_2/anpei/Dataset/nerf_synthetic/'
param_dict = {
'data_name': ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials'],
'data_dim_color': [13, 27, 54]
}
# n_iters = 30000
# for data_name in ['Robot']:#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'
# cmd = f'CUDA_VISIBLE_DEVICES={cuda} python train.py ' \
# f'--dataset_name nsvf --datadir /mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/{data_name} '\
# f'--expname {data_name} --batch_size {batch_size} ' \
# f'--n_iters {n_iters} ' \
# f'--N_voxel_init {128**3} --N_voxel_final {300**3} '\
# f'--N_vis {5} ' \
# f'--n_lamb_sigma "[16,16,16]" --n_lamb_sh "[48,48,48]" ' \
# f'--upsamp_list "[2000, 3000, 4000, 5500,7000]" --update_AlphaMask_list "[3000,4000]" ' \
# f'--shadingMode MLP_Fea --fea2denseAct softplus --view_pe {2} --fea_pe {2} ' \
# f'--L1_weight_inital {8e-5} --L1_weight_rest {4e-5} --rm_weight_mask_thre {1e-4} --add_timestamp 0 ' \
# f'--render_test 1 '
# print(cmd)
# os.system(cmd)
# nsvf
# expFolder = "nsvf_0227/"
# datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/'
# param_dict = {
# 'data_name': ['Robot','Steamtrain','Bike','Lifestyle','Palace','Spaceship','Toad','Wineholder'],#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'
# 'shadingMode': ['SH'],
# ('n_lamb_sigma', 'n_lamb_sh'): [ ("[8,8,8]", "[8,8,8]")],
# ('view_pe', 'fea_pe', 'featureC','fea2denseAct','N_voxel_init') : [(2, 2, 128, 'softplus',128**3)],
# ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'):[(4e-5, 4e-5, 1e-4)],
# ('n_iters','N_voxel_final'): [(30000,300**3)],
# ('dataset_name','N_vis','render_test') : [("nsvf",5,1)],
# ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[3000,4000]")]
#
# }
# tankstemple
# expFolder = "tankstemple_0304/"
# datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/TanksAndTemple/'
# param_dict = {
# 'data_name': ['Truck','Barn','Caterpillar','Family','Ignatius'],
# 'shadingMode': ['MLP_Fea'],
# ('n_lamb_sigma', 'n_lamb_sh'): [("[16,16,16]", "[48,48,48]")],
# ('view_pe', 'fea_pe','fea2denseAct','N_voxel_init','render_test') : [(2, 2, 'softplus',128**3,1)],
# ('TV_weight_density','TV_weight_app'):[(0.1,0.01)],
# # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'): [(4e-5, 4e-5, 1e-4)],
# ('n_iters','N_voxel_final'): [(15000,300**3)],
# ('dataset_name','N_vis') : [("tankstemple",5)],
# ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2000,4000]")]
# }
# llff
# expFolder = "real_iconic/"
# datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/real_iconic/'
# List = os.listdir(datafolder)
# param_dict = {
# 'data_name': List,
# ('shadingMode', 'view_pe', 'fea_pe','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 'relu',512,128**3)],
# ('n_lamb_sigma', 'n_lamb_sh') : [("[16,4,4]", "[48,12,12]")],
# ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)],
# ('n_iters','N_voxel_final'): [(25000,640**3)],
# ('dataset_name','downsample_train','ndc_ray','N_vis','render_path') : [("llff",4.0, 1,-1,1)],
# ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")],
# }
# expFolder = "llff/"
# datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data'
# param_dict = {
# 'data_name': ['fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'],#'fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'
# ('n_lamb_sigma', 'n_lamb_sh'): [("[16,4,4]", "[48,12,12]")],
# ('shadingMode', 'view_pe', 'fea_pe', 'featureC','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 128, 'relu',512,128**3),('SH', 0, 0, 128, 'relu',512,128**3)],
# ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)],
# ('n_iters','N_voxel_final'): [(25000,640**3)],
# ('dataset_name','downsample_train','ndc_ray','N_vis','render_test','render_path') : [("llff",4.0, 1,-1,1,1)],
# ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")],
# }
#setting available gpus
gpus_que = queue.Queue(3)
for i in [1,2,3]:
gpus_que.put(i)
os.makedirs(f"log/{expFolder}", exist_ok=True)
def run_program(gpu, expname, param):
cmd = f'CUDA_VISIBLE_DEVICES={gpu} python train.py ' \
f'--expname {expname} --basedir ./log/{expFolder} --config configs/nerf_synthetic.txt ' \
f'{param}' \
f'> "log/{expFolder}{expname}/{expname}.txt"'
print(cmd)
os.system(cmd)
gpus_que.put(gpu)
params, expnames = get_param_list(param_dict)
logFolder=f"log/{expFolder}"
os.makedirs(logFolder, exist_ok=True)
ths = []
for i in range(len(params)):
if getStopFolder(logFolder):
break
targetFolder = f"log/{expFolder}{expnames[i]}"
gpu = gpus_que.get()
getFolderLocker(logFolder)
if os.path.isdir(targetFolder):
releaseFolderLocker(logFolder)
gpus_que.put(gpu)
continue
else:
os.makedirs(targetFolder, exist_ok=True)
print("making",targetFolder, "running",expnames[i], params[i])
releaseFolderLocker(logFolder)
t = threading.Thread(target=run_program, args=(gpu, expnames[i], params[i]), daemon=True)
t.start()
ths.append(t)
for th in ths:
th.join()
================================================
FILE: extra/compute_metrics.py
================================================
import os, math
import numpy as np
import scipy.signal
from PIL import Image
import torch
import configargparse
__LPIPS__ = {}
def init_lpips(net_name, device):
assert net_name in ['alex', 'vgg']
import lpips
print(f'init_lpips: lpips_{net_name}')
return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)
def rgb_lpips(np_gt, np_im, net_name, device):
if net_name not in __LPIPS__:
__LPIPS__[net_name] = init_lpips(net_name, device)
gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
return __LPIPS__[net_name](gt, im, normalize=True).item()
def findItem(items, target):
for one in items:
if one[:len(target)]==target:
return one
return None
''' Evaluation metrics (ssim, lpips)
'''
def rgb_ssim(img0, img1, max_val,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03,
return_map=False):
# Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
assert len(img0.shape) == 3
assert img0.shape[-1] == 3
assert img0.shape == img1.shape
# Construct a 1D Gaussian blur filter.
hw = filter_size // 2
shift = (2 * hw - filter_size + 1) / 2
f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2
filt = np.exp(-0.5 * f_i)
filt /= np.sum(filt)
# Blur in x and y (faster than the 2D convolution).
def convolve2d(z, f):
return scipy.signal.convolve2d(z, f, mode='valid')
filt_fn = lambda z: np.stack([
convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])
for i in range(z.shape[-1])], -1)
mu0 = filt_fn(img0)
mu1 = filt_fn(img1)
mu00 = mu0 * mu0
mu11 = mu1 * mu1
mu01 = mu0 * mu1
sigma00 = filt_fn(img0**2) - mu00
sigma11 = filt_fn(img1**2) - mu11
sigma01 = filt_fn(img0 * img1) - mu01
# Clip the variances and covariances to valid values.
# Variance must be non-negative:
sigma00 = np.maximum(0., sigma00)
sigma11 = np.maximum(0., sigma11)
sigma01 = np.sign(sigma01) * np.minimum(
np.sqrt(sigma00 * sigma11), np.abs(sigma01))
c1 = (k1 * max_val)**2
c2 = (k2 * max_val)**2
numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
ssim_map = numer / denom
ssim = np.mean(ssim_map)
return ssim_map if return_map else ssim
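`rgb_ssim` builds a normalized 1D Gaussian and blurs columns and rows separately instead of running one 2D convolution. Because the 2D Gaussian kernel is the outer product of the 1D one, both give the same result. A NumPy-only check (no SciPy), where `gaussian_1d` repeats the filter construction above and `conv_valid_2d` is a hypothetical stand-in for `scipy.signal.convolve2d(..., mode='valid')`:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def gaussian_1d(filter_size=11, filter_sigma=1.5):
    """Same construction as rgb_ssim: centered, normalized Gaussian taps."""
    hw = filter_size // 2
    shift = (2 * hw - filter_size + 1) / 2
    f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma) ** 2
    filt = np.exp(-0.5 * f_i)
    return filt / filt.sum()

def conv_valid_2d(z, k):
    """'valid' 2D convolution (kernel flipped, as convolve2d does)."""
    windows = sliding_window_view(z, k.shape)         # (H-kh+1, W-kw+1, kh, kw)
    return np.einsum("ijkl,kl->ij", windows, k[::-1, ::-1])

rng = np.random.default_rng(0)
img = rng.random((32, 40))
g = gaussian_1d()
separable = conv_valid_2d(conv_valid_2d(img, g[:, None]), g[None, :])
full = conv_valid_2d(img, np.outer(g, g))
print(separable.shape)   # (22, 30); separable and full agree to rounding error
```

The separable form costs about `2 * filter_size` multiplies per pixel instead of `filter_size ** 2`.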
if __name__ == '__main__':
parser = configargparse.ArgumentParser()
parser.add_argument("--exp", type=str, help="folder of exps")
parser.add_argument("--paramStr", type=str, help="str of params")
args = parser.parse_args()
# NeRF-synthetic layout (test/r_%d.png with an alpha channel); edit these paths for your machine
datanames = ['drums','hotdog','materials','ficus','lego','mic','ship','chair'] #['ship']#
gtFolder = "/home/code-base/user_space/codes/nerf/data/nerf_synthetic"
expFolder = "/home/code-base/user_space/codes/TensoRF/log/"+args.exp
# datanames = ['room','fortress', 'flower','orchids','leaves','horns','trex','fern'] #['ship']#
# gtFolder = "/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/"
# expFolder = "/mnt/new_disk_2/anpei/code/TensoRF/log/"+args.exp
paramStr = args.paramStr
fileNum = 200
expitems = os.listdir(expFolder)
finalFolder = f'{expFolder}/finals/{paramStr}'
outFile = f'{finalFolder}/{paramStr}_metrics.txt'
os.makedirs(finalFolder, exist_ok=True)
expitems.sort(reverse=True)
with open(outFile, 'w') as f:
all_psnr = []
all_ssim = []
all_alex = []
all_vgg = []
for dataname in datanames:
gtstr = gtFolder+"/"+dataname+"/test/r_%d.png"
expname = findItem(expitems, f'{paramStr}-{dataname}')
print("expname: ", expname)
if expname is None:
print("no ",dataname, "exists")
continue
resultstr = expFolder+"/"+expname+"/imgs_test_all/"+ dataname+"-"+paramStr+ "_%03d.png"
metric_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_mean.txt'
video_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_video.mp4'
exist_metric=False
if os.path.isfile(metric_file):
metrics = np.loadtxt(metric_file)
print(metrics, metrics.tolist())
if metrics.size == 4:
psnr, ssim, l_a, l_v = metrics.tolist()
exist_metric = True
os.system(f"cp {video_file} {finalFolder}/")
if not exist_metric:
psnrs = []
ssims = []
l_alex = []
l_vgg = []
for i in range(fileNum):
gt = np.asarray(Image.open(gtstr%i),dtype=np.float32) / 255.0
gtmask = gt[...,[3]]
gt = gt[...,:3]
gt = gt*gtmask + (1-gtmask)
img = np.asarray(Image.open(resultstr%i),dtype=np.float32)[...,:3] / 255.0
# print(gt[0,0],img[0,0],gt.shape, img.shape, gt.max(), img.max())
psnr = -10. * np.log10(np.mean(np.square(img - gt)))
ssim = rgb_ssim(img, gt, 1)
lpips_alex = rgb_lpips(gt, img, 'alex','cuda')
lpips_vgg = rgb_lpips(gt, img, 'vgg','cuda')
print(i, psnr, ssim, lpips_alex, lpips_vgg)
psnrs.append(psnr)
ssims.append(ssim)
l_alex.append(lpips_alex)
l_vgg.append(lpips_vgg)
psnr = np.mean(np.array(psnrs))
ssim = np.mean(np.array(ssims))
l_a = np.mean(np.array(l_alex))
l_v = np.mean(np.array(l_vgg))
rS=f'{dataname} : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}'
print(rS)
f.write(rS+"\n")
all_psnr.append(psnr)
all_ssim.append(ssim)
all_alex.append(l_a)
all_vgg.append(l_v)
psnr = np.mean(np.array(all_psnr))
ssim = np.mean(np.array(all_ssim))
l_a = np.mean(np.array(all_alex))
l_v = np.mean(np.array(all_vgg))
rS=f'mean : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}'
print(rS)
f.write(rS+"\n")
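For each test view the script composites the RGBA ground truth onto a white background (`gt*gtmask + (1-gtmask)`) and computes PSNR as `-10*log10(MSE)`, which assumes a peak value of 1 for images in [0, 1]. A small sketch of those two steps with hypothetical helper names:

```python
import numpy as np

def composite_on_white(rgba):
    """rgba: (H, W, 4) floats in [0, 1] -> (H, W, 3) composited over white."""
    alpha = rgba[..., 3:4]
    return rgba[..., :3] * alpha + (1.0 - alpha)

def psnr(img, gt):
    """PSNR in dB for images scaled to [0, 1] (peak value 1)."""
    mse = np.mean(np.square(img - gt))
    return -10.0 * np.log10(mse)

gt = np.zeros((4, 4, 4))               # transparent everywhere ...
gt[0, 0] = [0.2, 0.4, 0.6, 1.0]        # ... except one opaque pixel
comp = composite_on_white(gt)          # comp[0, 0] keeps its color, the rest is white
noisy = comp - 0.1                     # uniform error of 0.1 -> MSE of 0.01
print(psnr(noisy, comp))               # about 20 dB
```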
================================================
FILE: models/VGG.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple
import torchvision.models as models
# pytorch pretrained vgg
class Encoder(nn.Module):
def __init__(self):
super().__init__()
#pretrained vgg19
vgg19 = models.vgg19(weights='DEFAULT').features
self.relu1_1 = vgg19[:2]
self.relu2_1 = vgg19[2:7]
self.relu3_1 = vgg19[7:12]
self.relu4_1 = vgg19[12:21]
#fix parameters
self.requires_grad_(False)
def forward(self, x):
_output = namedtuple('output', ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1'])
relu1_1 = self.relu1_1(x)
relu2_1 = self.relu2_1(relu1_1)
relu3_1 = self.relu3_1(relu2_1)
relu4_1 = self.relu4_1(relu3_1)
output = _output(relu1_1, relu2_1, relu3_1, relu4_1)
return output
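The slice boundaries in `Encoder` follow the module order of torchvision's `vgg19().features`: every conv is followed by a ReLU, and every stage ends in a max-pool. A pure-Python sketch that regenerates those module names from the VGG19 layer configuration (the `VGG19_CFG` list is written out here as an assumption matching torchvision's configuration "E") and checks that each slice ends at the intended ReLU, without downloading any weights:

```python
# VGG19 ("configuration E"): output channels per conv, "M" = max-pool.
VGG19_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
             512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]

def vgg_layer_names(cfg):
    """Name every module of vgg.features in order: convS_I, reluS_I, poolS."""
    names, stage, idx = [], 1, 1
    for v in cfg:
        if v == "M":
            names.append(f"pool{stage}")
            stage, idx = stage + 1, 1
        else:
            names += [f"conv{stage}_{idx}", f"relu{stage}_{idx}"]
            idx += 1
    return names

names = vgg_layer_names(VGG19_CFG)
# Encoder slices are [:2], [2:7], [7:12], [12:21]; print the last module of each.
for end in (2, 7, 12, 21):
    print(end, names[end - 1])   # relu1_1, relu2_1, relu3_1, relu4_1
```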
class Decoder(nn.Module):
"""
starting from relu 4_1
"""
def __init__(self, ckpt_path=None):
super().__init__()
self.layers = nn.Sequential(
# nn.Conv2d(512, 256, 3, padding=1, padding_mode='reflect'),
# nn.ReLU(),
# nn.Upsample(scale_factor=2, mode='nearest'), # relu4-1
nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),
nn.ReLU(), # relu3-4
nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),
nn.ReLU(), # relu3-3
nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),
nn.ReLU(), # relu3-2
nn.Conv2d(256, 128, 3, padding=1, padding_mode='reflect'),
nn.ReLU(),
nn.Upsample(scale_factor=2, mode='nearest'),# relu3-1
nn.Conv2d(128, 128, 3, padding=1, padding_mode='reflect'),
nn.ReLU(), # relu2-2
nn.Conv2d(128, 64, 3, padding=1, padding_mode='reflect'),
nn.ReLU(),
nn.Upsample(scale_factor=2, mode='nearest'),# relu2-1
nn.Conv2d(64, 64, 3, padding=1, padding_mode='reflect'),
nn.ReLU(), # relu1-2
nn.Conv2d(64, 3, 3, padding=1, padding_mode='reflect'),
)
if ckpt_path is not None:
self.load_state_dict(torch.load(ckpt_path))
def forward(self, x):
return self.layers(x)
### high-res unet feature map decoder
class DownBlock(nn.Module):
def __init__(self, in_dim, out_dim, down='conv'):
super(DownBlock, self).__init__()
if down == 'conv':
self.down_conv = nn.Sequential(
nn.Conv2d(in_dim, out_dim, 3, 2, 1),
nn.LeakyReLU(),
nn.Conv2d(out_dim, out_dim, 3, 1, 1),
nn.LeakyReLU(),
)
elif down == 'mean':
self.down_conv = nn.AvgPool2d(2)
else:
raise NotImplementedError(
'[ERROR] invalid downsampling operator: {:s}'.format(down)
)
def forward(self, x):
x = self.down_conv(x)
return x
class UpBlock(nn.Module):
def __init__(self, in_dim, out_dim, skip_dim=None, up='nearest'):
super(UpBlock, self).__init__()
if up == 'conv':
self.up_conv = nn.Sequential(
nn.ConvTranspose2d(in_dim, out_dim, 3, 2, 1, 1),
nn.ReLU(),
)
else:
assert up in ('bilinear', 'nearest'), \
'[ERROR] invalid upsampling mode: {:s}'.format(up)
self.up_conv = nn.Sequential(
nn.Upsample(scale_factor=2, mode=up),
nn.Conv2d(in_dim, out_dim, 3, 1, 1),
nn.ReLU(),
)
in_dim = out_dim
if skip_dim is not None:
in_dim += skip_dim
self.conv = nn.Sequential(
nn.Conv2d(in_dim, out_dim, 3, 1, 1),
nn.ReLU(),
)
def _pad(self, x, y):
dh = y.size(-2) - x.size(-2)
dw = y.size(-1) - x.size(-1)
if dh == 0 and dw == 0:
return x
if dh < 0:
x = x[..., :dh, :]
if dw < 0:
x = x[..., :, :dw]
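dh, dw = max(dh, 0), max(dw, 0) # already cropped above; pad only the dimensions that are still too small
if dh > 0 or dw > 0:
x = F.pad(
x,
pad=(dw // 2, dw - dw // 2, dh // 2, dh - dh // 2),
mode='reflect'
)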
return x
def forward(self, x, skip=None):
x = self.up_conv(x)
if skip is not None:
x = torch.cat([self._pad(x, skip), skip], 1)
x = self.conv(x)
return x
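`UpBlock._pad` makes the upsampled map match the skip connection's spatial size: it crops when the map is too large and reflect-pads (splitting the extra pixels between both sides) when it is too small. Clamping the differences at zero after cropping keeps negative widths out of the padding call when one dimension is too large and the other too small. A NumPy sketch of the same rule for a (C, H, W) array; `match_size` is a hypothetical name and `np.pad(mode='reflect')` stands in for `F.pad(mode='reflect')`:

```python
import numpy as np

def match_size(x, ref_hw):
    """Crop or reflect-pad x (C, H, W) so its last two dims equal ref_hw."""
    dh = ref_hw[0] - x.shape[-2]
    dw = ref_hw[1] - x.shape[-1]
    if dh < 0:
        x = x[..., :dh, :]          # too tall: drop rows at the bottom
    if dw < 0:
        x = x[..., :, :dw]          # too wide: drop columns on the right
    dh, dw = max(dh, 0), max(dw, 0)
    if dh or dw:                    # too small: reflect-pad, extra pixel at bottom/right
        x = np.pad(x, ((0, 0), (dh // 2, dh - dh // 2), (dw // 2, dw - dw // 2)),
                   mode="reflect")
    return x

# A stride-2 downsample of an odd size followed by 2x upsampling overshoots by one.
up = np.zeros((8, 2 * ((45 + 1) // 2), 2 * ((60 + 1) // 2)))   # (8, 46, 60)
print(match_size(up, (45, 60)).shape)                          # (8, 45, 60)
print(match_size(np.ones((1, 3, 3)), (5, 4)).shape)            # (1, 5, 4)
```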
class UNetDecoder(nn.Module):
def __init__(self, in_dim=256):
super(UNetDecoder, self).__init__()
self.down_layers = nn.ModuleList()
self.skip_convs = nn.ModuleList()
self.up_layers = nn.ModuleList()
self.n_levels = 2
self.up = 1
for i in range(self.n_levels):
self.down_layers.append(
DownBlock(
in_dim, in_dim,
)
)
out_dim = in_dim // 2 ** (self.n_levels - i)
self.skip_convs.append(nn.Conv2d(in_dim, out_dim, 1))
self.up_layers.append(
UpBlock(
out_dim * 2, out_dim, out_dim,
)
)
out_dim = in_dim // 2 ** self.n_levels
self.out_conv = nn.Sequential(
nn.Conv2d(out_dim, out_dim, 3, 1, 1),
nn.ReLU(),
nn.Conv2d(out_dim, 3, 1, 1),
)
def forward(self, feats):
skips = []
for i in range(self.n_levels):
skips.append(self.skip_convs[i](feats))
feats = self.down_layers[i](feats)
for i in range(self.n_levels - 1, -1, -1):
feats = self.up_layers[i](feats, skips[i])
rgb = self.out_conv(feats)
return rgb
### high-res feature map decoder
class PlainDecoder(nn.Module):
def __init__(self) -> None:
super().__init__()
self.layers = nn.Sequential(
nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),
nn.ReLU(), # relu3-4
nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),
nn.ReLU(), # relu3-3
nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),
nn.ReLU(), # relu3-2
nn.Conv2d(256, 128, 3, padding=1, padding_mode='reflect'),
nn.ReLU(),
nn.Conv2d(128, 128, 3, padding=1, padding_mode='reflect'),
nn.ReLU(), # relu2-2
nn.Conv2d(128, 64, 3, padding=1, padding_mode='reflect'),
nn.ReLU(),
nn.Conv2d(64, 64, 3, padding=1, padding_mode='reflect'),
nn.ReLU(), # relu1-2
nn.Conv2d(64, 3, 3, padding=1, padding_mode='reflect'),
)
def forward(self, x):
return self.layers(x)
================================================
FILE: models/__init__.py
================================================
================================================
FILE: models/sh.py
================================================
import torch
################## sh function ##################
C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
1.0925484305920792,
-1.0925484305920792,
0.31539156525252005,
-1.0925484305920792,
0.5462742152960396
]
C3 = [
-0.5900435899266435,
2.890611442640554,
-0.4570457994644658,
0.3731763325901154,
-0.4570457994644658,
1.445305721320277,
-0.5900435899266435
]
C4 = [
2.5033429417967046,
-1.7701307697799304,
0.9461746957575601,
-0.6690465435572892,
0.10578554691520431,
-0.6690465435572892,
0.47308734787878004,
-1.7701307697799304,
0.6258357354491761,
]
def eval_sh(deg, sh, dirs):
"""
Evaluate spherical harmonics at unit directions
using hardcoded SH polynomials.
Works with torch/np/jnp.
... Can be 0 or more batch dimensions.
:param deg: int SH max degree. Currently, 0-4 supported
:param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2)
:param dirs: torch.Tensor unit directions (..., 3)
:return: (..., C)
"""
assert deg <= 4 and deg >= 0
assert (deg + 1) ** 2 == sh.shape[-1]
C = sh.shape[-2]
result = C0 * sh[..., 0]
if deg > 0:
x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
result = (result -
C1 * y * sh[..., 1] +
C1 * z * sh[..., 2] -
C1 * x * sh[..., 3])
if deg > 1:
xx, yy, zz = x * x, y * y, z * z
xy, yz, xz = x * y, y * z, x * z
result = (result +
C2[0] * xy * sh[..., 4] +
C2[1] * yz * sh[..., 5] +
C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
C2[3] * xz * sh[..., 7] +
C2[4] * (xx - yy) * sh[..., 8])
if deg > 2:
result = (result +
C3[0] * y * (3 * xx - yy) * sh[..., 9] +
C3[1] * xy * z * sh[..., 10] +
C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
C3[5] * z * (xx - yy) * sh[..., 14] +
C3[6] * x * (xx - 3 * yy) * sh[..., 15])
if deg > 3:
result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
return result
def eval_sh_bases(deg, dirs):
"""
Evaluate spherical harmonics bases at unit directions,
without taking linear combination.
At each point, the final result may then be
obtained through simple multiplication.
:param deg: int SH max degree. Currently, 0-4 supported
:param dirs: torch.Tensor (..., 3) unit directions
:return: torch.Tensor (..., (deg+1) ** 2)
"""
assert deg <= 4 and deg >= 0
result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device)
result[..., 0] = C0
if deg > 0:
x, y, z = dirs.unbind(-1)
result[..., 1] = -C1 * y
result[..., 2] = C1 * z
result[..., 3] = -C1 * x
if deg > 1:
xx, yy, zz = x * x, y * y, z * z
xy, yz, xz = x * y, y * z, x * z
result[..., 4] = C2[0] * xy
result[..., 5] = C2[1] * yz
result[..., 6] = C2[2] * (2.0 * zz - xx - yy)
result[..., 7] = C2[3] * xz
result[..., 8] = C2[4] * (xx - yy)
if deg > 2:
result[..., 9] = C3[0] * y * (3 * xx - yy)
result[..., 10] = C3[1] * xy * z
result[..., 11] = C3[2] * y * (4 * zz - xx - yy)
result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy)
result[..., 13] = C3[4] * x * (4 * zz - xx - yy)
result[..., 14] = C3[5] * z * (xx - yy)
result[..., 15] = C3[6] * x * (xx - 3 * yy)
if deg > 3:
result[..., 16] = C4[0] * xy * (xx - yy)
result[..., 17] = C4[1] * yz * (3 * xx - yy)
result[..., 18] = C4[2] * xy * (7 * zz - 1)
result[..., 19] = C4[3] * yz * (7 * zz - 3)
result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3)
result[..., 21] = C4[5] * xz * (7 * zz - 3)
result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1)
result[..., 23] = C4[7] * xz * (xx - 3 * yy)
result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
return result
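`eval_sh_bases` returns real spherical-harmonic basis values, and `eval_sh` is their dot product with the coefficients. A NumPy sketch (degrees 0 to 2 only, reusing the constants above) that checks numerically that these basis functions are orthonormal over the unit sphere. The Fibonacci-lattice quadrature is an assumed integration scheme chosen for illustration, not something this repository uses:

```python
import numpy as np

C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [1.0925484305920792, -1.0925484305920792, 0.31539156525252005,
      -1.0925484305920792, 0.5462742152960396]

def sh_bases_deg2(dirs):
    """dirs: (N, 3) unit vectors -> (N, 9) real SH basis values (degree <= 2)."""
    x, y, z = dirs[:, 0], dirs[:, 1], dirs[:, 2]
    return np.stack([
        np.full_like(x, C0),
        -C1 * y, C1 * z, -C1 * x,
        C2[0] * x * y, C2[1] * y * z, C2[2] * (2 * z * z - x * x - y * y),
        C2[3] * x * z, C2[4] * (x * x - y * y),
    ], axis=-1)

def fibonacci_sphere(n):
    """Nearly uniform points on the unit sphere (equal-area Fibonacci lattice)."""
    i = np.arange(n) + 0.5
    z = 1.0 - 2.0 * i / n
    r = np.sqrt(1.0 - z * z)
    phi = np.pi * (3.0 - np.sqrt(5.0)) * i
    return np.stack([r * np.cos(phi), r * np.sin(phi), z], axis=-1)

dirs = fibonacci_sphere(20000)
B = sh_bases_deg2(dirs)
gram = 4 * np.pi * (B.T @ B) / len(dirs)   # approximates the integral of Y_i * Y_j
print(np.abs(gram - np.eye(9)).max())       # small, so the basis is orthonormal
```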
================================================
FILE: models/styleModules.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
def calc_mean_std(x, eps=1e-8):
"""
calculating channel-wise instance mean and standard deviation
x: shape of (N,C,*)
"""
mean = torch.mean(x.flatten(2), dim=-1, keepdim=True) # size of (N, C, 1)
std = torch.std(x.flatten(2), dim=-1, keepdim=True) + eps # size of (N, C, 1)
return mean, std
def cal_adain_style_loss(x, y):
"""
style loss in one layer
Args:
x, y: feature maps of size [N, C, H, W]
"""
x_mean, x_std = calc_mean_std(x)
y_mean, y_std = calc_mean_std(y)
return nn.functional.mse_loss(x_mean, y_mean) \
+ nn.functional.mse_loss(x_std, y_std)
def cal_mse_content_loss(x, y):
return nn.functional.mse_loss(x, y)
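`calc_mean_std` and `cal_adain_style_loss` are built on the statistics used by AdaIN (adaptive instance normalization): the style loss compares per-channel means and standard deviations of two feature maps. A NumPy sketch of the AdaIN transform those statistics define, which normalizes content with its own statistics and re-scales it with the style's. `adain` and `mean_std` are hypothetical helpers, and `ddof=1` mirrors the unbiased default of `torch.std`:

```python
import numpy as np

def mean_std(x, eps=1e-8):
    """Channel-wise mean/std over all spatial positions; x: (N, C, *)."""
    flat = x.reshape(x.shape[0], x.shape[1], -1)
    return (flat.mean(-1, keepdims=True),
            flat.std(-1, ddof=1, keepdims=True) + eps)   # unbiased, like torch.std

def adain(content, style):
    """Give content the per-channel mean/std of style; both (N, C, *)."""
    c_mean, c_std = mean_std(content)
    s_mean, s_std = mean_std(style)
    flat = content.reshape(content.shape[0], content.shape[1], -1)
    out = (flat - c_mean) / c_std * s_std + s_mean
    return out.reshape(content.shape)

rng = np.random.default_rng(0)
content = rng.normal(0.0, 1.0, size=(2, 4, 8, 8))
style = rng.normal(3.0, 0.5, size=(2, 4, 6, 6))
out = adain(content, style)   # per-channel mean/std of out now match style
```

Feeding `out` and `style` to a statistics-matching loss such as `cal_adain_style_loss` would give (almost) zero, which is the sense in which that loss measures style.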
class LearnableIN(nn.Module):
'''
Input: (N, C, L) or (C, L)
'''
def __init__(self, dim=256):
super().__init__()
self.IN = torch.nn.InstanceNorm1d(dim, momentum=1e-4, track_running_stats=True)
def forward(self, x):
if x.size()[-1] <= 1:
return x
return self.IN(x)
class SimpleLinearStylizer(nn.Module):
def __init__(self, input_dim=256, embed_dim=32, n_layers=3) -> None:
super().__init__()
self.input_dim = input_dim
self.embed_dim = embed_dim
self.IN = LearnableIN(input_dim)
self.q_embed = nn.Conv1d(input_dim, embed_dim, 1)
self.k_embed = nn.Conv1d(input_dim, embed_dim, 1)
self.v_embed = nn.Conv1d(input_dim, embed_dim, 1)
self.unzipper = nn.Conv1d(embed_dim, input_dim, 1, bias=False)
s_net = []
for i in range(n_layers - 1):
out_dim = max(embed_dim, input_dim // 2)
s_net.append(
nn.Sequential(
nn.Conv1d(input_dim, out_dim, 1),
nn.ReLU(inplace=True),
)
)
input_dim = out_dim
s_net.append(nn.Conv1d(input_dim, embed_dim, 1))
self.s_net = nn.Sequential(*s_net)
self.s_fc = nn.Linear(embed_dim ** 2, embed_dim ** 2)
def _vectorized_covariance(self, x):
cov = torch.bmm(x, x.transpose(2, 1)) / x.size(-1)
cov = cov.flatten(1)
return cov
def get_content_matrix(self, c):
'''
Args:
c: content feature [N,input_dim,S]
Return:
attn: per-point attention matrices [N,S,embed_dim,embed_dim]
normalized_c: instance-normalized content feature [N,input_dim,S]
'''
normalized_c = self.IN(c)
# normalized_c = torch.nn.functional.instance_norm(c)
q_embed = self.q_embed(normalized_c)
k_embed = self.k_embed(normalized_c)
c_cov = q_embed.transpose(1,2).unsqueeze(3) * k_embed.transpose(1,2).unsqueeze(2) # [N,S,embed_dim,embed_dim]
attn = torch.softmax(c_cov, -1) # [N,S,embed_dim,embed_dim]
return attn, normalized_c
def get_style_mean_std_matrix(self, s):
'''
Args:
s: style feature [N,input_dim,S]
Return:
s_mean: [N,input_dim,1]
s_std: [N,input_dim,1]
s_mat: [N,embed_dim,embed_dim]
'''
s_mean = s.mean(-1, keepdim=True)
s_std = s.std(-1, keepdim=True)
s = s - s_mean
s_embed = self.s_net(s)
s_cov = self._vectorized_covariance(s_embed)
s_mat = self.s_fc(s_cov)
s_mat = s_mat.reshape(-1, self.embed_dim, self.embed_dim)
return s_mean, s_std, s_mat
def transform_content_3D(self, c):
'''
Args:
c: content feature [N,input_dim,S]
Return:
transformed_c: [N,embed_dim,S]
'''
attn, normalized_c = self.get_content_matrix(c) # [N,S,embed_dim,embed_dim]
c = self.v_embed(normalized_c) # [N,embed_dim,S]
c = c.transpose(1,2).unsqueeze(3) # [N,S,embed_dim,1]
c = torch.matmul(attn, c).squeeze(3) # [N,S,embed_dim]
return c.transpose(1,2)
def transfer_style_2D(self, s_mean_std_mat, c, acc_map):
'''
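Args:
s_mean_std_mat: tuple (s_mean, s_std, s_mat) from get_style_mean_std_matrix
s_mean: [N,input_dim,1]
s_std: [N,input_dim,1]
s_mat: style matrix [N,embed_dim,embed_dim]
c: content feature map after volume rendering [N,embed_dim,S]
acc_map: accumulated opacity per ray [S]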
'''
s_mean, s_std, s_mat = s_mean_std_mat
cs = torch.bmm(s_mat, c) # [N,embed_dim,S]
cs = self.unzipper(cs) # [N,input_dim,S]
cs = cs * s_std + s_mean * acc_map[None,None,...]
return cs
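In `SimpleLinearStylizer`, each 3D point gets its own `embed_dim x embed_dim` attention matrix from the outer product of its query and key embeddings (softmax over the last axis), which is applied to its value embedding. After volume rendering, one style matrix mixes the rendered channels, the bias-free `unzipper` lifts them back to `input_dim`, and the style mean is scaled by the accumulated opacity, so a ray with zero opacity (and therefore zero rendered features) maps to zero. A NumPy sketch of that shape flow; random matrices stand in for the learned 1x1 convolutions, and instance norm, biases, and real volume rendering are left out as simplifying assumptions:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
N, C, E, S = 1, 16, 4, 10        # batch, input_dim, embed_dim, points/rays
Wq, Wk, Wv = (rng.normal(size=(E, C)) for _ in range(3))   # stand-ins for q/k/v_embed
W_unzip = rng.normal(size=(C, E))                           # stand-in for unzipper (no bias)

# 3D step: per-point attention from the outer product of query and key embeddings.
feat = rng.normal(size=(N, C, S))
q = np.einsum("ec,ncs->nse", Wq, feat)              # (N, S, E)
k = np.einsum("ec,ncs->nse", Wk, feat)              # (N, S, E)
v = np.einsum("ec,ncs->nse", Wv, feat)              # (N, S, E)
attn = softmax(q[..., :, None] * k[..., None, :])   # (N, S, E, E), rows sum to 1
transformed = np.einsum("nsij,nsj->nis", attn, v)   # (N, E, S)

# 2D step (after rendering): style matrix, unzip, re-add the style statistics.
s_mat = rng.normal(size=(N, E, E))
s_mean, s_std = rng.normal(size=(N, C, 1)), rng.random((N, C, 1))
acc_map = np.linspace(0.0, 1.0, S)                  # accumulated opacity per ray
rendered = transformed * acc_map                    # crude stand-in for volume rendering
cs = np.einsum("ce,nes->ncs", W_unzip, s_mat @ rendered)    # (N, C, S)
cs = cs * s_std + s_mean * acc_map[None, None, :]
```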
class AdaAttN(nn.Module):
""" Attention-weighted AdaIN (Liu et al., ICCV 21) """
def __init__(self, qk_dim, v_dim):
"""
Args:
qk_dim (int): query and key size.
v_dim (int): value size.
"""
super(AdaAttN, self).__init__()
self.q_embed = nn.Conv1d(qk_dim, qk_dim, 1)
self.k_embed = nn.Conv1d(qk_dim, qk_dim, 1)
self.s_embed = nn.Conv1d(v_dim, v_dim, 1)
def forward(self, q, k):
"""
Args:
q (float tensor, (bs, qk, *)): query (content) features; also used as the content values c.
k (float tensor, (bs, qk, *)): key (style) features; also used as the style values s.
Note: because c = q and s = k, qk_dim must equal v_dim.
Returns:
cs (float tensor, (bs, v, *)): stylized content features.
"""
c, s = q, k
shape = c.shape
q, k = q.flatten(2), k.flatten(2)
c, s = c.flatten(2), s.flatten(2)
# QKV attention with projected content and style features
q = self.q_embed(F.instance_norm(q)).transpose(2, 1) # (bs, n, qk)
k = self.k_embed(F.instance_norm(k)) # (bs, qk, m)
s = self.s_embed(s).transpose(2, 1) # (bs, m, v)
attn = F.softmax(torch.bmm(q, k), -1) # (bs, n, m)
# attention-weighted channel-wise statistics
mean = torch.bmm(attn, s) # (bs, n, v)
var = F.relu(torch.bmm(attn, s ** 2) - mean ** 2) # (bs, n, v)
mean = mean.transpose(2, 1) # (bs, v, n)
std = torch.sqrt(var).transpose(2, 1) # (bs, v, n)
cs = F.instance_norm(c) * std + mean # (bs, v, n)
cs = cs.reshape(shape)
return cs
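`AdaAttN` replaces AdaIN's global style statistics with per-position ones: each content position attends over all style positions, and the attention weights give a local mean and variance (`E[s^2] - E[s]^2`, clamped at zero by `F.relu`) that re-scale the instance-normalized content. A NumPy sketch that drops the learned 1x1 embeddings and the instance norm on q and k (simplifying assumptions) and keeps keys and values as separate arguments for clarity. With zero queries the attention is uniform, and the result reduces to plain AdaIN with population statistics:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def instance_norm(x, eps=1e-5):
    """Normalize each channel over positions; x: (n, v). Biased variance, like F.instance_norm."""
    return (x - x.mean(0)) / np.sqrt(x.var(0) + eps)

def adaattn(q, k, c, s):
    """q: (n, d) content queries, k: (m, d) style keys,
    c: (n, v) content values, s: (m, v) style values -> (n, v)."""
    attn = softmax(q @ k.T, axis=-1)                  # (n, m), each row sums to 1
    mean = attn @ s                                   # attention-weighted mean, (n, v)
    var = np.maximum(attn @ (s ** 2) - mean ** 2, 0)  # E[s^2] - E[s]^2, clamped at 0
    return instance_norm(c) * np.sqrt(var) + mean

rng = np.random.default_rng(0)
c = rng.normal(size=(50, 8))
s = rng.normal(2.0, 0.5, size=(30, 8))
# Zero queries -> uniform attention -> every position gets the global style statistics.
out = adaattn(np.zeros((50, 4)), rng.normal(size=(30, 4)), c, s)
```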
class AdaAttN_new_IN(nn.Module):
""" Attention-weighted AdaIN (Liu et al., ICCV 21) """
def __init__(self, qk_dim, v_dim):
"""
Args:
qk_dim (int): query and key size.
v_dim (int): value size.
"""
super(AdaAttN_new_IN, self).__init__()
self.q_embed = nn.Conv1d(qk_dim, qk_dim, 1)
self.k_embed = nn.Conv1d(qk_dim, qk_dim, 1)
self.s_embed = nn.Conv1d(v_dim, v_dim, 1)
self.IN = LearnableIN(qk_dim)
def forward(self, q, k):
"""
Args:
q (float tensor, (bs, qk, *)): query (content) features; also used as the content values c.
k (float tensor, (bs, qk, *)): key (style) features; also used as the style values s.
Note: because c = q and s = k, qk_dim must equal v_dim.
Returns:
cs (float tensor, (bs, v, *)): stylized content features.
"""
c, s = q, k
shape = c.shape
q, k = q.flatten(2), k.flatten(2)
c, s = c.flatten(2), s.flatten(2)
# QKV attention with projected content and style features
q = self.q_embed(self.IN(q)).transpose(2, 1) # (bs, n, qk)
k = self.k_embed(F.instance_norm(k)) # (bs, qk, m)
s = self.s_embed(s).transpose(2, 1) # (bs, m, v)
attn = F.softmax(torch.bmm(q, k), -1) # (bs, n, m)
# attention-weighted channel-wise statistics
mean = torch.bmm(attn, s) # (bs, n, v)
var = F.relu(torch.bmm(attn, s ** 2) - mean ** 2) # (bs, n, v)
mean = mean.transpose(2, 1) # (bs, v, n)
std = torch.sqrt(var).transpose(2, 1) # (bs, v, n)
cs = self.IN(c) * std + mean # (bs, v, n)
cs = cs.reshape(shape)
return cs
class AdaAttN_woin(nn.Module):
""" Attention-weighted AdaIN (Liu et al., ICCV 21) """
def __init__(self, qk_dim, v_dim):
"""
Args:
qk_dim (int): query and key size.
v_dim (int): value size.
"""
super().__init__()
self.q_embed = nn.Conv1d(qk_dim, qk_dim, 1)
self.k_embed = nn.Conv1d(qk_dim, qk_dim, 1)
self.s_embed = nn.Conv1d(v_dim, v_dim, 1)
def forward(self, q, k):
"""
Args:
q (float tensor, (bs, qk, *)): query (content) features; also used as the content values c.
k (float tensor, (bs, qk, *)): key (style) features; also used as the style values s.
Note: because c = q and s = k, qk_dim must equal v_dim.
Returns:
cs (float tensor, (bs, v, *)): stylized content features.
"""
c, s = q, k
shape = c.shape
q, k = q.flatten(2), k.flatten(2)
c, s = c.flatten(2), s.flatten(2)
# QKV attention with projected content and style features
q = self.q_embed(q).transpose(2, 1) # (bs, n, qk)
k = self.k_embed(k) # (bs, qk, m)
s = self.s_embed(s).transpose(2, 1) # (bs, m, v)
attn = F.softmax(torch.bmm(q, k), -1) # (bs, n, m)
# attention-weighted channel-wise statistics
mean = torch.bmm(attn, s) # (bs, n, v)
var = F.relu(torch.bmm(attn, s ** 2) - mean ** 2) # (bs, n, v)
mean = mean.transpose(2, 1) # (bs, v, n)
std = torch.sqrt(var).transpose(2, 1) # (bs, v, n)
cs = c * std + mean # (bs, v, n)
cs = cs.reshape(shape)
return cs
================================================
FILE: models/tensoRF.py
================================================
from .tensorBase import *
from .VGG import Encoder, Decoder, UNetDecoder, PlainDecoder
from .styleModules import LearnableIN, SimpleLinearStylizer
class TensorVMSplit(TensorBase):
def __init__(self, aabb, gridSize, device, **kargs):
super(TensorVMSplit, self).__init__(aabb, gridSize, device, **kargs)
def change_to_feature_mod(self, feature_n_comp, device):
self.density_line.requires_grad_(False)
self.density_plane.requires_grad_(False)
self.app_line = None
self.app_plane = None
self.basis_mat = None
self.renderModule = None
# The encoder freezes its own parameters when constructed; the decoder is made trainable below
self.encoder = Encoder().to(device)
self.decoder = PlainDecoder().to(device)
# We need to finetune decoder when training a feature grid
self.decoder.requires_grad_(True)
self.feature_n_comp = feature_n_comp
self.init_feature_svd(device)
def change_to_style_mod(self, device='cuda'):
assert self.feature_line is not None, 'Have to be trained in feature mod first!'
self.feature_line.requires_grad_(False)
self.feature_plane.requires_grad_(False)
self.feature_basis_mat.requires_grad_(False)
self.decoder.requires_grad_(True)
self.IN = LearnableIN().to(device)
# self.stylizer = NearestFeatureTransform()
# self.stylizer = LearnableIN(1,256,device)
# self.stylizer = AdaAttN(256, 256).to(device)
# self.stylizer = AdaAttN_woin(256, 256).to(device)
self.stylizer = SimpleLinearStylizer(256).to(device)
def init_svd_volume(self, res, device):
self.density_plane, self.density_line = self.init_one_svd(self.density_n_comp, self.gridSize, 0.1, device)
self.app_plane, self.app_line = self.init_one_svd(self.app_n_comp, self.gridSize, 0.1, device)
self.basis_mat = torch.nn.Linear(sum(self.app_n_comp), self.app_dim, bias=False).to(device)
def init_feature_svd(self, device):
self.feature_plane, self.feature_line = self.init_one_svd(self.feature_n_comp, self.gridSize, 0.1, device)
self.feature_basis_mat = torch.nn.Linear(sum(self.feature_n_comp), 256, bias=False).to(device)
def init_one_svd(self, n_component, gridSize, scale, device):
plane_coef, line_coef = [], []
for i in range(len(self.vecMode)):
vec_id = self.vecMode[i]
mat_id_0, mat_id_1 = self.matMode[i]
plane_coef.append(torch.nn.Parameter(
scale * torch.randn((1, n_component[i], gridSize[mat_id_1], gridSize[mat_id_0]))))
line_coef.append(
torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[vec_id], 1))))
return torch.nn.ParameterList(plane_coef).to(device), torch.nn.ParameterList(line_coef).to(device)
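`init_one_svd` stores, for each of the three axis pairings in `matMode`/`vecMode`, a stack of 2D planes and matching 1D lines. `compute_densityfeature` samples both at a point and sums their products over components and pairings. A NumPy sketch that densifies such a vector-matrix (VM) factorization on a small grid. It assumes TensoRF's usual pairing (plane XY with line Z, XZ with Y, YZ with X; `matMode`/`vecMode` are defined in `tensorBase.py`, which is not shown here) and ignores the bilinear interpolation done by `grid_sample`:

```python
import numpy as np

MAT_MODE = [(0, 1), (0, 2), (1, 2)]   # assumed plane axes per component group
VEC_MODE = [2, 1, 0]                  # assumed line axis paired with each plane

def init_vm(n_comp, grid, rng, scale=0.1):
    """Planes (n, grid[m1], grid[m0]) and lines (n, grid[v]), mirroring init_one_svd."""
    planes, lines = [], []
    for i, (m0, m1) in enumerate(MAT_MODE):
        planes.append(scale * rng.normal(size=(n_comp[i], grid[m1], grid[m0])))
        lines.append(scale * rng.normal(size=(n_comp[i], grid[VEC_MODE[i]])))
    return planes, lines

def densify(planes, lines):
    """Dense (X, Y, Z) grid: sum over groups and components of plane * line."""
    p0, p1, p2 = planes
    l0, l1, l2 = lines
    return (np.einsum("cyx,cz->xyz", p0, l0)
            + np.einsum("czx,cy->xyz", p1, l1)
            + np.einsum("czy,cx->xyz", p2, l2))

rng = np.random.default_rng(0)
grid = (6, 5, 4)                      # X, Y, Z resolution
planes, lines = init_vm([3, 3, 3], grid, rng)
sigma = densify(planes, lines)

# One voxel, computed the way compute_densityfeature sums per-point products.
x, y, z = 2, 3, 1
point = ((planes[0][:, y, x] * lines[0][:, z]).sum()
         + (planes[1][:, z, x] * lines[1][:, y]).sum()
         + (planes[2][:, z, y] * lines[2][:, x]).sum())
print(sigma.shape, np.isclose(sigma[x, y, z], point))   # (6, 5, 4) True
```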
def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001):
grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz}, {'params': self.density_plane, 'lr': lr_init_spatialxyz},
{'params': self.app_line, 'lr': lr_init_spatialxyz}, {'params': self.app_plane, 'lr': lr_init_spatialxyz},
{'params': self.basis_mat.parameters(), 'lr':lr_init_network}]
if isinstance(self.renderModule, torch.nn.Module):
grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}]
return grad_vars
def get_optparam_groups_feature_mod(self, lr_init_spatialxyz, lr_init_network):
grad_vars = [{'params': self.feature_line, 'lr': lr_init_spatialxyz},
{'params': self.feature_plane, 'lr': lr_init_spatialxyz},
{'params': self.feature_basis_mat.parameters(), 'lr':lr_init_network},
{'params': self.decoder.parameters(), 'lr':lr_init_network}]
return grad_vars
def get_optparam_groups_style_mod(self, lr_init_network, lr_finetune):
grad_vars = [
{'params': self.stylizer.parameters(), 'lr': lr_init_network},
{'params': self.decoder.parameters(), 'lr': lr_finetune},
]
return grad_vars

    def vectorDiffs(self, vector_comps):
        # Orthogonality penalty: mean |off-diagonal| of each line factor's Gram matrix.
        total = 0
        for idx in range(len(vector_comps)):
            n_comp, n_size = vector_comps[idx].shape[1:-1]
            dotp = torch.matmul(vector_comps[idx].view(n_comp,n_size), vector_comps[idx].view(n_comp,n_size).transpose(-1,-2))
            non_diagonal = dotp.view(-1)[1:].view(n_comp-1, n_comp+1)[...,:-1]
            total = total + torch.mean(torch.abs(non_diagonal))
        return total

    def vector_comp_diffs(self):
        return self.vectorDiffs(self.density_line) + self.vectorDiffs(self.app_line)

    def density_L1(self):
        total = 0
        for idx in range(len(self.density_plane)):
            total = total + torch.mean(torch.abs(self.density_plane[idx])) + torch.mean(torch.abs(self.density_line[idx]))# + torch.mean(torch.abs(self.app_plane[idx])) + torch.mean(torch.abs(self.density_plane[idx]))
        return total

    def TV_loss_density(self, reg):
        total = 0
        for idx in range(len(self.density_plane)):
            total = total + reg(self.density_plane[idx]) * 1e-2 #+ reg(self.density_line[idx]) * 1e-3
        return total

    def TV_loss_app(self, reg):
        total = 0
        for idx in range(len(self.app_plane)):
            total = total + reg(self.app_plane[idx]) * 1e-2 #+ reg(self.app_line[idx]) * 1e-3
        return total

    def TV_loss_feature(self, reg):
        total = 0
        for idx in range(len(self.feature_plane)):
            total = total + reg(self.feature_plane[idx]) + reg(self.feature_line[idx])
        return total

    def compute_densityfeature(self, xyz_sampled):
        # plane + line basis
        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)

        sigma_feature = torch.zeros((xyz_sampled.shape[0],), device=xyz_sampled.device)
        for idx_plane in range(len(self.density_plane)):
            plane_coef_point = F.grid_sample(self.density_plane[idx_plane], coordinate_plane[[idx_plane]],
                                             align_corners=True).view(-1, *xyz_sampled.shape[:1])
            line_coef_point = F.grid_sample(self.density_line[idx_plane], coordinate_line[[idx_plane]],
                                            align_corners=True).view(-1, *xyz_sampled.shape[:1])
            sigma_feature = sigma_feature + torch.sum(plane_coef_point * line_coef_point, dim=0)

        return sigma_feature

    def compute_appfeature(self, xyz_sampled):
        # plane + line basis
        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)

        plane_coef_point,line_coef_point = [],[]
        for idx_plane in range(len(self.app_plane)):
            plane_coef_point.append(F.grid_sample(self.app_plane[idx_plane], coordinate_plane[[idx_plane]],
                                                  align_corners=True).view(-1, *xyz_sampled.shape[:1]))
            line_coef_point.append(F.grid_sample(self.app_line[idx_plane], coordinate_line[[idx_plane]],
                                                 align_corners=True).view(-1, *xyz_sampled.shape[:1]))
        plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point)

        return self.basis_mat((plane_coef_point * line_coef_point).T)

    def compute_feature(self, xyz_sampled):
        # Same VM lookup as compute_appfeature, projected to 256 channels by feature_basis_mat.
        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)

        plane_coef_point,line_coef_point = [],[]
        for idx_plane in range(len(self.feature_plane)):
            plane_coef_point.append(F.grid_sample(self.feature_plane[idx_plane], coordinate_plane[[idx_plane]],
                                                  align_corners=True).view(-1, *xyz_sampled.shape[:1]))
            line_coef_point.append(F.grid_sample(self.feature_line[idx_plane], coordinate_line[[idx_plane]],
                                                 align_corners=True).view(-1, *xyz_sampled.shape[:1]))
        plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point)

        return self.feature_basis_mat((plane_coef_point * line_coef_point).T)

    def render_feature_map(self, rays_chunk, s_mean_std_mat=None, is_train=False, ndc_ray=False, N_samples=-1):

        # sample points
        viewdirs = rays_chunk[:, 3:6]
        if ndc_ray:
            xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
            rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
            dists = dists * rays_norm
        else:
            xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)

        if self.alphaMask is not None:
            alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])
            alpha_mask = alphas > 0
            ray_invalid = ~ray_valid
            ray_invalid[ray_valid] |= (~alpha_mask)
            ray_valid = ~ray_invalid

        sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
        if s_mean_std_mat is not None:
            features = torch.zeros((*xyz_sampled.shape[:2], self.stylizer.embed_dim), device=xyz_sampled.device)
        else:
            features = torch.zeros((*xyz_sampled.shape[:2], 256), device=xyz_sampled.device)

        if ray_valid.any():
            xyz_sampled = self.normalize_coord(xyz_sampled)
            sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])
            validsigma = self.feature2density(sigma_feature)
            sigma[ray_valid] = validsigma

        alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)
        app_mask = weight > self.rayMarch_weight_thres

        if app_mask.any():
            valid_features = self.compute_feature(xyz_sampled[app_mask]) # [n_valid_points (~40k unless N_samples is set), C=256]
            # transform content on 3d
            if s_mean_std_mat is not None:
                valid_features = self.stylizer.transform_content_3D(valid_features.transpose(0,1)[None,...])
                valid_features = valid_features.squeeze(0).transpose(0,1)
            features[app_mask] = valid_features

        feature_map = torch.sum(weight[..., None] * features, -2)
        acc_map = torch.sum(weight, -1)

        # style transfer on 2d
        if s_mean_std_mat is not None:
            feature_map = self.stylizer.transfer_style_2D(s_mean_std_mat, feature_map.transpose(0,1)[None,...], acc_map)
            feature_map = feature_map.squeeze().transpose(0,1)

        return feature_map, acc_map

    def render_depth_map(self, rays_chunk, is_train=False, ndc_ray=False, N_samples=-1):

        # sample points
        viewdirs = rays_chunk[:, 3:6]
        if ndc_ray:
            xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
            rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
            dists = dists * rays_norm
        else:
            xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)

        if self.alphaMask is not None:
            alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])
            alpha_mask = alphas > 0
            ray_invalid = ~ray_valid
            ray_invalid[ray_valid] |= (~alpha_mask)
            ray_valid = ~ray_invalid

        sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)

        if ray_valid.any():
            xyz_sampled = self.normalize_coord(xyz_sampled)
            sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])
            validsigma = self.feature2density(sigma_feature)
            sigma[ray_valid] = validsigma

        alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)

        acc_map = torch.sum(weight, -1)
        depth_map = torch.sum(weight * z_vals, -1)
        depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1]

        return depth_map # [n_rays]

    @torch.no_grad()
    def up_sampling_VM(self, plane_coef, line_coef, res_target):
        for i in range(len(self.vecMode)):
            vec_id = self.vecMode[i]
            mat_id_0, mat_id_1 = self.matMode[i]
            plane_coef[i] = torch.nn.Parameter(
                F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear',
                              align_corners=True))
            line_coef[i] = torch.nn.Parameter(
                F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
        return plane_coef, line_coef

    @torch.no_grad()
    def upsample_volume_grid(self, res_target):
        self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target)
        self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target)

        self.update_stepSize(res_target)
        print(f'upsampling to {res_target}')

    @torch.no_grad()
    def shrink(self, new_aabb):
        print("====> shrinking ...")
        xyz_min, xyz_max = new_aabb
        t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units
        # print(new_aabb, self.aabb)
        # print(t_l, b_r,self.alphaMask.alpha_volume.shape)
        t_l, b_r = torch.round(t_l).long(), torch.round(b_r).long() + 1
        b_r = torch.stack([b_r, self.gridSize]).amin(0)

        for i in range(len(self.vecMode)):
            mode0 = self.vecMode[i]
            self.density_line[i] = torch.nn.Parameter(
                self.density_line[i].data[...,t_l[mode0]:b_r[mode0],:]
            )
            self.app_line[i] = torch.nn.Parameter(
                self.app_line[i].data[...,t_l[mode0]:b_r[mode0],:]
            )
            mode0, mode1 = self.matMode[i]
            self.density_plane[i] = torch.nn.Parameter(
                self.density_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]]
            )
            self.app_plane[i] = torch.nn.Parameter(
                self.app_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]]
            )

        if not torch.all(self.alphaMask.gridSize == self.gridSize):
            t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1)
            correct_aabb = torch.zeros_like(new_aabb)
            correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1]
            correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1]
            print("aabb", new_aabb, "\ncorrect aabb", correct_aabb)
            new_aabb = correct_aabb

        newSize = b_r - t_l
        self.aabb = new_aabb
        self.update_stepSize((newSize[0], newSize[1], newSize[2]))
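

# Illustrative sketch (hypothetical helpers, not part of StyleRF): vectorDiffs above picks
# the off-diagonal entries of each n x n Gram matrix with a reshape instead of a mask.
# Slicing off the first element of the flattened matrix drops diagonal entry 0, and viewing
# the remaining n*n - 1 values as (n-1, n+1) moves diagonal entries 1..n-1 into the last
# column, so removing that column leaves exactly the n*(n-1) off-diagonal entries in
# row-major order. The masked version below is a reference for comparison only.
import torch


def offdiag_by_reshape(mat):
    """Off-diagonal entries of a square matrix, using the view trick from vectorDiffs."""
    n = mat.shape[0]
    return mat.reshape(-1)[1:].view(n - 1, n + 1)[..., :-1]


def offdiag_by_mask(mat):
    """Reference: the same entries selected with a boolean mask (1-D, row-major)."""
    n = mat.shape[0]
    return mat[~torch.eye(n, dtype=torch.bool)]


# Usage: both selections agree once flattened, e.g. for a 4 x 4 matrix
# m = torch.arange(16.).view(4, 4)
# torch.equal(offdiag_by_reshape(m).reshape(-1), offdiag_by_mask(m))  -> True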
================================================
FILE: models/tensorBase.py
================================================
import torch
import torch.nn
import torch.nn.functional as F
from .sh import eval_sh_bases
import numpy as np
import time


def positional_encoding(positions, freqs):
    # (..., D) -> (..., 2*D*F): sines of all scaled coordinates, then cosines.
    freq_bands = (2**torch.arange(freqs).float()).to(positions.device)  # (F,)
    pts = (positions[..., None] * freq_bands).reshape(
        positions.shape[:-1] + (freqs * positions.shape[-1], ))  # (..., DF)
    pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1)
    return pts


def raw2alpha(sigma, dist):
    # sigma, dist [N_rays, N_samples]
    alpha = 1. - torch.exp(-sigma*dist)

    # Transmittance T_i = prod_{j<i} (1 - alpha_j); the 1e-10 is a small epsilon for numerical stability.
    T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1)

    weights = alpha * T[:, :-1]  # [N_rays, N_samples]
    return alpha, weights, T[:,-1:]
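

# Illustrative sketch (hypothetical helper, not part of StyleRF): raw2alpha above is the
# emission-absorption quadrature. With alpha_i = 1 - exp(-sigma_i * delta_i) and
# transmittance T_i = prod_{j<i} (1 - alpha_j), the weights w_i = alpha_i * T_i telescope,
# so sum_i w_i + T_N = 1. That is why forward() can composite a white background with
# rgb_map + (1 - acc_map), and why the last transmittance is returned as bg_weight.
# This re-derivation drops the 1e-10 epsilon, so the identity holds up to rounding.
import torch


def composite_weights(sigma, dist):
    """sigma, dist: [N_rays, N_samples]. Returns (weights, final transmittance per ray)."""
    alpha = 1.0 - torch.exp(-sigma * dist)
    ones = torch.ones_like(alpha[:, :1])
    trans = torch.cumprod(torch.cat([ones, 1.0 - alpha], dim=-1), dim=-1)
    return alpha * trans[:, :-1], trans[:, -1]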


def SHRender(xyz_sampled, viewdirs, features):
    # Degree-2 SH: 9 coefficients per colour channel, so features carries 3 * 9 = 27 channels.
    sh_mult = eval_sh_bases(2, viewdirs)[:, None]
    rgb_sh = features.view(-1, 3, sh_mult.shape[-1])
    rgb = torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5)
    return rgb


def RGBRender(xyz_sampled, viewdirs, features):
    rgb = features
    return rgb


class AlphaGridMask(torch.nn.Module):
    def __init__(self, device, aabb, alpha_volume):
        super(AlphaGridMask, self).__init__()
        self.device = device

        self.aabb=aabb.to(self.device)
        self.aabbSize = self.aabb[1] - self.aabb[0]
        self.invgridSize = 1.0/self.aabbSize * 2  # 2 / aabbSize, used to map into [-1, 1]
        self.alpha_volume = alpha_volume.view(1,1,*alpha_volume.shape[-3:])
        self.gridSize = torch.LongTensor([alpha_volume.shape[-1],alpha_volume.shape[-2],alpha_volume.shape[-3]]).to(self.device)

    def sample_alpha(self, xyz_sampled):
        xyz_sampled = self.normalize_coord(xyz_sampled)
        alpha_vals = F.grid_sample(self.alpha_volume, xyz_sampled.view(1,-1,1,1,3), align_corners=True).view(-1)

        return alpha_vals

    def normalize_coord(self, xyz_sampled):
        # Map world coordinates inside the aabb to [-1, 1] for grid_sample.
        return (xyz_sampled-self.aabb[0]) * self.invgridSize - 1
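

# Illustrative check (hypothetical helper, not part of StyleRF): AlphaGridMask stores its
# volume as (1, 1, D, H, W) and passes points as (x, y, z) in [-1, 1]. For 5-D inputs,
# F.grid_sample reads the last grid dimension as (x, y, z) indexing (W, H, D), which is
# why gridSize above is built from alpha_volume.shape in reverse order. With
# align_corners=True, -1 and +1 land exactly on the first and last voxel centres.
import torch
import torch.nn.functional as F


def sample_volume(volume, xyz):
    """volume: (D, H, W); xyz: (N, 3) in [-1, 1], ordered (x, y, z). Returns (N,)."""
    vol = volume[None, None]           # (1, 1, D, H, W)
    grid = xyz.view(1, -1, 1, 1, 3)    # (1, N, 1, 1, 3)
    return F.grid_sample(vol, grid, align_corners=True).view(-1)


# Usage: with vol = torch.arange(24.).view(2, 3, 4) (D=2, H=3, W=4), the point
# (x, y, z) = (1, -1, -1) returns vol[0, 0, 3], i.e. x moves along the last axis W.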


class MLPRender_Fea(torch.nn.Module):
    def __init__(self,inChanel, viewpe=6, feape=6, featureC=128):
        super(MLPRender_Fea, self).__init__()

        self.in_mlpC = 2*viewpe*3 + 2*feape*inChanel + 3 + inChanel
        self.viewpe = viewpe
        self.feape = feape
        layer1 = torch.nn.Linear(self.in_mlpC, featureC)
        layer2 = torch.nn.Linear(featureC, featureC)
        layer3 = torch.nn.Linear(featureC,3)

        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
        torch.nn.init.constant_(self.mlp[-1].bias, 0)

    def forward(self, pts, viewdirs, features):
        indata = [features, viewdirs]
        if self.feape > 0:
            indata += [positional_encoding(features, self.feape)]
        if self.viewpe > 0:
            indata += [positional_encoding(viewdirs, self.viewpe)]
        mlp_in = torch.cat(indata, dim=-1)
        rgb = self.mlp(mlp_in)
        rgb = torch.sigmoid(rgb)

        return rgb


class MLPRender_PE(torch.nn.Module):
    def __init__(self,inChanel, viewpe=6, pospe=6, featureC=128):
        super(MLPRender_PE, self).__init__()

        self.in_mlpC = (3+2*viewpe*3)+ (3+2*pospe*3) + inChanel
        self.viewpe = viewpe
        self.pospe = pospe
        layer1 = torch.nn.Linear(self.in_mlpC, featureC)
        layer2 = torch.nn.Linear(featureC, featureC)
        layer3 = torch.nn.Linear(featureC,3)

        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
        torch.nn.init.constant_(self.mlp[-1].bias, 0)

    def forward(self, pts, viewdirs, features):
        # in_mlpC reserves 3 channels for the raw sample positions, so pts must be part of the input.
        indata = [features, viewdirs, pts]
        if self.pospe > 0:
            indata += [positional_encoding(pts, self.pospe)]
        if self.viewpe > 0:
            indata += [positional_encoding(viewdirs, self.viewpe)]
        mlp_in = torch.cat(indata, dim=-1)
        rgb = self.mlp(mlp_in)
        rgb = torch.sigmoid(rgb)

        return rgb


class MLPRender(torch.nn.Module):
    def __init__(self,inChanel, viewpe=6, featureC=128):
        super(MLPRender, self).__init__()

        self.in_mlpC = (3+2*viewpe*3) + inChanel
        self.viewpe = viewpe

        layer1 = torch.nn.Linear(self.in_mlpC, featureC)
        layer2 = torch.nn.Linear(featureC, featureC)
        layer3 = torch.nn.Linear(featureC,3)

        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
        torch.nn.init.constant_(self.mlp[-1].bias, 0)

    def forward(self, pts, viewdirs, features):
        indata = [features, viewdirs]
        if self.viewpe > 0:
            indata += [positional_encoding(viewdirs, self.viewpe)]
        mlp_in = torch.cat(indata, dim=-1)
        rgb = self.mlp(mlp_in)
        rgb = torch.sigmoid(rgb)

        return rgb
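

# Worked arithmetic (hypothetical helpers, not part of StyleRF): each renderer's first
# Linear layer must match the width of the tensor its forward() concatenates. Raw
# features add inChanel, raw view directions 3, and every positional encoding adds
# 2 * n_freqs * D (a sine and a cosine per frequency per coordinate). The helpers mirror
# the in_mlpC expressions above; note that MLPRender_PE's in_mlpC also counts 3 raw
# position channels, so its forward() has to concatenate pts for the shapes to agree.
def mlp_fea_in(in_chanel, viewpe=6, feape=6):
    # features, viewdirs, PE(features), PE(viewdirs)
    return in_chanel + 3 + 2 * feape * in_chanel + 2 * viewpe * 3


def mlp_pe_in(in_chanel, viewpe=6, pospe=6):
    # features, viewdirs, raw pts, PE(pts), PE(viewdirs)
    return in_chanel + 3 + 3 + 2 * pospe * 3 + 2 * viewpe * 3


def mlp_in(in_chanel, viewpe=6):
    # features, viewdirs, PE(viewdirs)
    return in_chanel + 3 + 2 * viewpe * 3


# Usage with the default app_dim = 27 and 6 frequencies everywhere:
# mlp_fea_in(27) -> 390, mlp_pe_in(27) -> 105, mlp_in(27) -> 66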


class TensorBase(torch.nn.Module):
    def __init__(self, aabb, gridSize, device, density_n_comp = 8, appearance_n_comp = 24, app_dim = 27,
                 shadingMode = 'MLP_PE', alphaMask = None, near_far=[2.0,6.0],
                 density_shift = -10, alphaMask_thres=0.001, distance_scale=25, rayMarch_weight_thres=0.0001,
                 pos_pe = 6, view_pe = 6, fea_pe = 6, featureC=128, step_ratio=2.0,
                 fea2denseAct = 'softplus'):
        super(TensorBase, self).__init__()

        self.density_n_comp = density_n_comp
        self.app_n_comp = appearance_n_comp
        self.app_dim = app_dim
        self.aabb = aabb
        self.alphaMask = alphaMask
        self.device=device

        self.density_shift = density_shift
        self.alphaMask_thres = alphaMask_thres
        self.distance_scale = distance_scale
        self.rayMarch_weight_thres = rayMarch_weight_thres
        self.fea2denseAct = fea2denseAct

        self.near_far = near_far
        self.step_ratio = step_ratio

        self.update_stepSize(gridSize)

        # Plane i spans the axes matMode[i]; its paired line runs along the remaining axis vecMode[i].
        self.matMode = [[0,1], [0,2], [1,2]]
        self.vecMode = [2, 1, 0]
        self.comp_w = [1,1,1]

        self.init_svd_volume(gridSize[0], device)

        self.shadingMode, self.pos_pe, self.view_pe, self.fea_pe, self.featureC = shadingMode, pos_pe, view_pe, fea_pe, featureC
        self.init_render_func(shadingMode, pos_pe, view_pe, fea_pe, featureC, device)

    def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC, device):
        if shadingMode == 'MLP_PE':
            self.renderModule = MLPRender_PE(self.app_dim, view_pe, pos_pe, featureC).to(device)
        elif shadingMode == 'MLP_Fea':
            self.renderModule = MLPRender_Fea(self.app_dim, view_pe, fea_pe, featureC).to(device)
        elif shadingMode == 'MLP':
            self.renderModule = MLPRender(self.app_dim, view_pe, featureC).to(device)
        elif shadingMode == 'SH':
            self.renderModule = SHRender
        elif shadingMode == 'RGB':
            assert self.app_dim == 3
            self.renderModule = RGBRender
        else:
            raise ValueError(f"Unrecognized shading module: {shadingMode}")
        print("pos_pe", pos_pe, "view_pe", view_pe, "fea_pe", fea_pe)

    def update_stepSize(self, gridSize):
        print("aabb", self.aabb.view(-1))
        print("grid size", gridSize)
        self.aabbSize = self.aabb[1] - self.aabb[0]
        self.invaabbSize = 2.0/self.aabbSize
        self.gridSize= torch.LongTensor(gridSize).to(self.device)
        self.units=self.aabbSize / (self.gridSize-1)
        # Step = mean voxel edge * step_ratio; nSamples is enough steps to cover the aabb diagonal.
        self.stepSize=torch.mean(self.units)*self.step_ratio
        self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize)))
        self.nSamples=int((self.aabbDiag / self.stepSize).item()) + 1
        print("sampling step size: ", self.stepSize)
        print("sampling number: ", self.nSamples)

    def init_svd_volume(self, res, device):
        pass

    def compute_features(self, xyz_sampled):
        pass

    def compute_densityfeature(self, xyz_sampled):
        pass

    def compute_appfeature(self, xyz_sampled):
        pass

    def normalize_coord(self, xyz_sampled):
        return (xyz_sampled-self.aabb[0]) * self.invaabbSize - 1

    def get_optparam_groups(self, lr_init_spatial = 0.02, lr_init_network = 0.001):
        pass

    def get_kwargs(self):
        return {
            'aabb': self.aabb,
            'gridSize':self.gridSize.tolist(),
            'density_n_comp': self.density_n_comp,
            'appearance_n_comp': self.app_n_comp,
            'app_dim': self.app_dim,

            'density_shift': self.density_shift,
            'alphaMask_thres': self.alphaMask_thres,
            'distance_scale': self.distance_scale,
            'rayMarch_weight_thres': self.rayMarch_weight_thres,
            'fea2denseAct': self.fea2denseAct,

            'near_far': self.near_far,
            'step_ratio': self.step_ratio,

            'shadingMode': self.shadingMode,
            'pos_pe': self.pos_pe,
            'view_pe': self.view_pe,
            'fea_pe': self.fea_pe,
            'featureC': self.featureC
        }

    def save(self, path):
        kwargs = self.get_kwargs()
        ckpt = {'kwargs': kwargs, 'state_dict': self.state_dict()}
        if self.alphaMask is not None:
            alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy()
            ckpt.update({'alphaMask.shape':alpha_volume.shape})
            ckpt.update({'alphaMask.mask':np.packbits(alpha_volume.reshape(-1))})
            ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()})
        torch.save(ckpt, path)

    def load(self, ckpt):
        if 'alphaMask.aabb' in ckpt.keys():
            length = np.prod(ckpt['alphaMask.shape'])
            alpha_volume = torch.from_numpy(np.unpackbits(ckpt['alphaMask.mask'])[:length].reshape(ckpt['alphaMask.shape']))
            self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device), alpha_volume.float().to(self.device))
        self.load_state_dict(ckpt['state_dict'])

    def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
        N_samples = N_samples if N_samples > 0 else self.nSamples
        near, far = self.near_far
        interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o)
        if is_train:
            interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples)

        rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]
        mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1)
        return rays_pts, interpx, ~mask_outbbox

    def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1):
        N_samples = N_samples if N_samples>0 else self.nSamples
        stepsize = self.stepSize
        near, far = self.near_far
        # Slab test: distance at which each ray enters the aabb, clamped to [near, far].
        vec = torch.where(rays_d==0, torch.full_like(rays_d, 1e-6), rays_d)
        rate_a = (self.aabb[1] - rays_o) / vec
        rate_b = (self.aabb[0] - rays_o) / vec
        t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far)

        rng = torch.arange(N_samples)[None].float()
        if is_train:
            rng = rng.repeat(rays_d.shape[-2],1)
            rng += torch.rand_like(rng[:,[0]])
        step = stepsize * rng.to(rays_o.device)
        interpx = (t_min[...,None] + step)

        rays_pts = rays_o[...,None,:] + rays_d[...,None,:] * interpx[...,None]
        mask_outbbox = ((self.aabb[0]>rays_pts) | (rays_pts>self.aabb[1])).any(dim=-1)

        return rays_pts, interpx, ~mask_outbbox

    def shrink(self, new_aabb, voxel_size):
        pass

    @torch.no_grad()
    def getDenseAlpha(self,gridSize=None):

        gridSize = self.gridSize if gridSize is None else gridSize

        samples = torch.stack(torch.meshgrid(
            torch.linspace(0, 1, gridSize[0]),
            torch.linspace(0, 1, gridSize[1]),
            torch.linspace(0, 1, gridSize[2]),
        ), -1).to(self.device)
        dense_xyz = self.aabb[0] * (1-samples) + self.aabb[1] * samples

        # dense_xyz = dense_xyz
        # print(self.stepSize, self.distance_scale*self.aabbDiag)
        alpha = torch.zeros_like(dense_xyz[...,0])
        for i in range(gridSize[0]):
            alpha[i] = self.compute_alpha(dense_xyz[i].view(-1,3), self.stepSize).view((gridSize[1], gridSize[2]))
        return alpha, dense_xyz

    @torch.no_grad()
    def updateAlphaMask(self, gridSize=(200,200,200)):

        alpha, dense_xyz = self.getDenseAlpha(gridSize)
        dense_xyz = dense_xyz.transpose(0,2).contiguous()
        alpha = alpha.clamp(0,1).transpose(0,2).contiguous()[None,None]
        total_voxels = gridSize[0] * gridSize[1] * gridSize[2]

        # Dilate with a 3x3x3 max-pool, then binarise against alphaMask_thres.
        ks = 3
        alpha = F.max_pool3d(alpha, kernel_size=ks, padding=ks // 2, stride=1).view(gridSize[::-1])
        alpha[alpha>=self.alphaMask_thres] = 1
        alpha[alpha<self.alphaMask_thres] = 0

        self.alphaMask = AlphaGridMask(self.device, self.aabb, alpha)

        valid_xyz = dense_xyz[alpha>0.5]

        xyz_min = valid_xyz.amin(0)
        xyz_max = valid_xyz.amax(0)

        new_aabb = torch.stack((xyz_min, xyz_max))

        total = torch.sum(alpha)
        print(f"bbox: {xyz_min, xyz_max} alpha rest %%%f"%(total/total_voxels*100))
        return new_aabb

    @torch.no_grad()
    def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=10240*5, bbox_only=False):
        print('========> filtering rays ...')
        tt = time.time()

        N = torch.tensor(all_rays.shape[:-1]).prod()

        mask_filtered = []
        idx_chunks = torch.split(torch.arange(N), chunk)
        for idx_chunk in idx_chunks:
            rays_chunk = all_rays[idx_chunk].to(self.device)

            rays_o, rays_d = rays_chunk[..., :3], rays_chunk[..., 3:6]
            if bbox_only:
                vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
                rate_a = (self.aabb[1] - rays_o) / vec
                rate_b = (self.aabb[0] - rays_o) / vec
                t_min = torch.minimum(rate_a, rate_b).amax(-1)#.clamp(min=near, max=far)
                t_max = torch.maximum(rate_a, rate_b).amin(-1)#.clamp(min=near, max=far)
                mask_inbbox = t_max > t_min

            else:
                xyz_sampled, _,_ = self.sample_ray(rays_o, rays_d, N_samples=N_samples, is_train=False)
                mask_inbbox= (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1)

            mask_filtered.append(mask_inbbox.cpu())

        mask_filtered = torch.cat(mask_filtered).view(all_rgbs.shape[:-1])

        print(f'Ray filtering done! takes {time.time()-tt} s. ray mask ratio: {torch.sum(mask_filtered) / N}')
        return all_rays[mask_filtered], all_rgbs[mask_filtered]

    def feature2density(self, density_features):
        if self.fea2denseAct == "softplus":
            return F.softplus(density_features+self.density_shift)
        elif self.fea2denseAct == "relu":
            return F.relu(density_features)

    def compute_alpha(self, xyz_locs, length=1):

        if self.alphaMask is not None:
            alphas = self.alphaMask.sample_alpha(xyz_locs)
            alpha_mask = alphas > 0
        else:
            alpha_mask = torch.ones_like(xyz_locs[:,0], dtype=bool)

        sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device)

        if alpha_mask.any():
            xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask])
            sigma_feature = self.compute_densityfeature(xyz_sampled)
            validsigma = self.feature2density(sigma_feature)
            sigma[alpha_mask] = validsigma

        alpha = 1 - torch.exp(-sigma*length).view(xyz_locs.shape[:-1])

        return alpha

    def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):

        # sample points
        viewdirs = rays_chunk[:, 3:6]
        if ndc_ray:
            xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
            rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
            dists = dists * rays_norm
            viewdirs = viewdirs / rays_norm
        else:
            xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
        # Broadcast each ray's view direction to every sample along it (both branches).
        viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape)

        if self.alphaMask is not None:
            alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])
            alpha_mask = alphas > 0
            ray_invalid = ~ray_valid
            ray_invalid[ray_valid] |= (~alpha_mask)
            ray_valid = ~ray_invalid

        sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
        rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device)

        if ray_valid.any():
            xyz_sampled = self.normalize_coord(xyz_sampled)
            sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])
            validsigma = self.feature2density(sigma_feature)
            sigma[ray_valid] = validsigma

        alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)

        app_mask = weight > self.rayMarch_weight_thres

        if app_mask.any():
            app_features = self.compute_appfeature(xyz_sampled[app_mask])
            valid_rgbs = self.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features)
            rgb[app_mask] = valid_rgbs

        acc_map = torch.sum(weight, -1)
        rgb_map = torch.sum(weight[..., None] * rgb, -2)

        if white_bg or (is_train and torch.rand((1,))<0.5):
            rgb_map = rgb_map + (1. - acc_map[..., None])

        rgb_map = rgb_map.clamp(0,1)

        with torch.no_grad():
            depth_map = torch.sum(weight * z_vals, -1)
            depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1]

        return rgb_map, depth_map # rgb, sigma, alpha, weight, bg_weight
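

# Illustrative sketch (hypothetical helper, not part of StyleRF): sample_ray and
# filtering_rays(bbox_only=True) clip rays to the scene box with the slab method. For
# each axis, (aabb - o) / d gives the ray parameters where the ray crosses that axis's two
# bounding planes; the entry point is the largest per-axis near value and the exit is the
# smallest per-axis far value. Zero direction components are replaced by 1e-6, as above,
# so the division stays finite. A ray hits the box only if t_max > t_min.
import torch


def ray_aabb_interval(rays_o, rays_d, aabb):
    """rays_o, rays_d: (N, 3); aabb: (2, 3) = [min corner, max corner]. Returns (t_min, t_max)."""
    vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
    rate_a = (aabb[1] - rays_o) / vec
    rate_b = (aabb[0] - rays_o) / vec
    t_min = torch.minimum(rate_a, rate_b).amax(-1)
    t_max = torch.maximum(rate_a, rate_b).amin(-1)
    return t_min, t_max


# Usage: a ray from (-2, 0, 0) along +x enters the box [-1, 1]^3 at t = 1 and leaves at
# t = 3; the same ray shifted to y = 5 misses the box, so its t_max < t_min.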
================================================
FILE: opt.py
================================================
import configargparse
def config_parser(cmd=None):
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True,
                        help='config file path')
    parser.add_argument("--expname", type=str,
                        help='experiment name')
    parser.add_argument("--basedir", type=str, default='./log',
                        help='where to store ckpts and logs')
    parser.add_argument("--add_timestamp", type=int, default=0,
                        help='add timestamp to dir')
    parser.add_argument("--datadir", type=str, default='./data/llff/fern',
                        help='input data directory')
    parser.add_argument("--wikiartdir", type=str, default='./data/WikiArt',
                        help='WikiArt style image directory')
    parser.add_argument("--progress_refresh_rate", type=int, default=10,
                        help='number of iterations between progress-bar (PSNR) updates')

    parser.add_argument('--with_depth', action='store_true')
    parser.add_argument('--downsample_train', type=float, default=1.0)
    parser.add_argument('--downsample_test', type=float, default=1.0)

    parser.add_argument('--model_name', type=str, default='TensorVMSplit',
                        choices=['TensorVMSplit', 'TensorCP'])

    # loader options
    parser.add_argument("--batch_size", type=int, default=4096)
    parser.add_argument("--n_iters", type=int, default=30000)

    parser.add_argument('--dataset_name', type=str, default='blender',
                        choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data'])

    # training options
    parser.add_argument("--patch_size", type=int, default=256,
                        help='patch_size for training')
    parser.add_argument("--chunk_size", type=int, default=4096,
                        help='chunk_size for training')

    # learning rate
    parser.add_argument("--lr_init", type=float, default=0.02,
                        help='learning rate for the tensor (plane/line) factors')
    parser.add_argument("--lr_basis", type=float, default=1e-4,
                        help='learning rate for the basis matrices and networks')
    parser.add_argument("--lr_finetune", type=float, default=1e-5,
                        help='learning rate for fine-tuning the decoder during style training')
    parser.add_argument("--lr_decay_iters", type=int, default=-1,
                        help = 'number of iterations over which the lr decays to the target ratio; -1 sets it to n_iters')
    parser.add_argument("--lr_decay_target_ratio", type=float, default=0.1,
                        help='the target decay ratio; after lr_decay_iters the initial lr has decayed to lr*ratio')
    parser.add_argument("--lr_upsample_reset", type=int, default=1,
                        help='reset lr to initial after upsampling')

    # loss
    parser.add_argument("--L1_weight_inital", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--L1_weight_rest", type=float, default=0,
                        help='loss weight')
    parser.add_argument("--Ortho_weight", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--TV_weight_density", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--TV_weight_app", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--TV_weight_feature", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--style_weight", type=float, default=0,
                        help='loss weight')
    parser.add_argument("--content_weight", type=float, default=0,
                        help='loss weight')
    parser.add_argument("--image_tv_weight", type=float, default=0,
                        help='loss weight')
    parser.add_argument("--featuremap_tv_weight", type=float, default=0,
                        help='loss weight')

    # model
    # volume options
    parser.add_argument("--n_lamb_sigma", type=int, action="append")
    parser.add_argument("--n_lamb_sh", type=int, action="append")
    parser.add_argument("--data_dim_color", type=int, default=27)

    parser.add_argument("--rm_weight_mask_thre", type=float, default=0.0001,
                        help='mask points in ray marching')
    parser.add_argument("--alpha_mask_thre", type=float, default=0.0001,
                        help='threshold for creating alpha mask volume')
    parser.add_argument("--distance_scale", type=float, default=25,
                        help='scaling sampling distance for computation')
    parser.add_argument("--density_shift", type=float, default=-10,
                        help='shift density in softplus; makes density close to 0 when feature == 0')

    # network decoder
    parser.add_argument("--shadingMode", type=str, default="MLP_Fea",
                        help='which shading mode to use')
    parser.add_argument("--pos_pe", type=int, default=6,
                        help='number of pe for pos')
    parser.add_argument("--view_pe", type=int, default=6,
                        help='number of pe for view')
    parser.add_argument("--fea_pe", type=int, default=6,
                        help='number of pe for features')
    parser.add_argument("--featureC", type=int, default=128,
                        help='hidden feature channel in MLP')

    # test option
    parser.add_argument("--ckpt", type=str, default=None,
                        help='checkpoint file to reload')
    parser.add_argument("--render_only", type=int, default=0)
    parser.add_argument("--render_test", type=int, default=0)
    parser.add_argument("--render_train", type=int, default=0)
    parser.add_argument("--render_path", type=int, default=0)
    parser.add_argument("--export_mesh", type=int, default=0)
    parser.add_argument("--style_img", type=str, required=False)

    # rendering options
    parser.add_argument('--lindisp', default=False, action="store_true",
                        help='use disparity depth sampling')
    parser.add_argument("--perturb", type=float, default=1.,
                        help='set to 0. for no jitter, 1. for jitter')
    parser.add_argument("--accumulate_decay", type=float, default=0.998)
    parser.add_argument("--fea2denseAct", type=str, default='softplus')
    parser.add_argument('--ndc_ray', type=int, default=0)
    parser.add_argument('--nSamples', type=int, default=int(1e6),
                        help='sample points per ray; keep the default (1000000) to adjust automatically')
    parser.add_argument('--step_ratio',type=float,default=0.5)

    ## blender flags
    parser.add_argument("--white_bkgd", action='store_true',
                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')

    parser.add_argument('--N_voxel_init',
                        type=int,
                        default=100**3)
    parser.add_argument('--N_voxel_final',
                        type=int,
                        default=300**3)
    parser.add_argument("--upsamp_list", type=int, action="append")
    parser.add_argument("--update_AlphaMask_list", type=int, action="append")

    parser.add_argument('--idx_view',
                        type=int,
                        default=0)
    # logging/saving options
    parser.add_argument("--N_vis", type=int, default=5,
                        help='N images to vis')
    parser.add_argument("--vis_every", type=int, default=10000,
                        help='frequency of visualize the image')
    if cmd is not None:
        return parser.parse_args(cmd)
    else:
        return parser.parse_args()
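

# Illustrative sketch (hypothetical parser built with the standard-library argparse, which
# configargparse.ArgumentParser extends): flags declared with action="append", such as
# --n_lamb_sigma, --n_lamb_sh, --upsamp_list and --update_AlphaMask_list, collect one list
# element per occurrence and stay None when never given. argparse also applies `type`
# only to string defaults, so a non-string default such as 1e6 stays a float even when
# type=int; only values typed on the command line are converted.
import argparse


def demo_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_lamb_sigma", type=int, action="append")
    parser.add_argument("--upsamp_list", type=int, action="append")
    parser.add_argument("--nSamples", type=int, default=1e6)
    return parser


# Usage:
# args = demo_parser().parse_args(["--n_lamb_sigma", "16", "--n_lamb_sigma", "4"])
# args.n_lamb_sigma -> [16, 4]; args.upsamp_list -> None; args.nSamples -> 1000000.0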
================================================
FILE: renderer.py
================================================
import torch,os,imageio,sys
from tqdm.auto import tqdm
from dataLoader.ray_utils import get_rays
from models.tensoRF import raw2alpha, TensorVMSplit, AlphaGridMask
from utils import *
from dataLoader.ray_utils import ndc_rays_blender, denormalize_vgg, normalize_vgg
def OctreeRender_trilinear_fast(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, white_bg=True, is_train=False,
render_feature=False, style_img=None, device='cuda'):
rgbs, alphas, depth_maps, weights, uncertainties = [], [], [], [], []
features, accs = [], []
s_mean_std_mat = None
if style_img is not None:
with torch.no_grad():
style_feature = tensorf.encoder(normalize_vgg(style_img))
s_mean_std_mat = tensorf.stylizer.get_style_mean_std_matrix(style_feature.relu3_1.flatten(2))
N_rays_all = rays.shape[0]
for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)):
rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device)
if render_feature:
feature_map, acc_map = tensorf.render_feature_map(rays_chunk, s_mean_std_mat=s_mean_std_mat, is_train=is_train, ndc_ray=ndc_ray, N_samples=N_samples)
features.append(feature_map)
accs.append(acc_map)
else:
rgb_map, depth_map = tensorf(rays_chunk, is_train=is_train, white_bg=white_bg, ndc_ray=ndc_ray, N_samples=N_samples)
rgbs.append(rgb_map)
depth_maps.append(depth_map)
if render_feature:
if style_img is not None:
return torch.cat(features), torch.cat(accs), style_feature
return torch.cat(features), torch.cat(accs)
return torch.cat(rgbs), None, torch.cat(depth_maps), None, None
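`OctreeRender_trilinear_fast` renders rays in fixed-size chunks, and `N_rays_all // chunk + int(N_rays_all % chunk > 0)` is ceiling division, so a final partial chunk is still rendered. The plain-Python sketch below mirrors that loop; `render_fn` is a hypothetical stand-in for the per-chunk model call, and list concatenation replaces `torch.cat`.

```python
from typing import Callable, List, Sequence

def num_chunks(n_rays: int, chunk: int) -> int:
    # Same expression as in OctreeRender_trilinear_fast: ceil(n_rays / chunk)
    return n_rays // chunk + int(n_rays % chunk > 0)

def render_in_chunks(rays: Sequence, render_fn: Callable[[Sequence], List], chunk: int = 4096) -> List:
    """Apply render_fn to consecutive slices of `rays` and concatenate the results.

    render_fn is a hypothetical stand-in for the model call; it must return
    one output per input ray.
    """
    outputs: List = []
    for chunk_idx in range(num_chunks(len(rays), chunk)):
        rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk]
        outputs.extend(render_fn(rays_chunk))  # torch.cat(...) in the original
    return outputs

# Usage: 10 "rays" rendered 4 at a time -> chunks of size 4, 4 and 2
rays = list(range(10))
out = render_in_chunks(rays, lambda rs: [r * 2 for r in rs], chunk=4)
print(num_chunks(10, 4), out)  # 3 [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
```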
def OctreeRender_trilinear_fast_depth(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, is_train=False, device='cuda'):
depth_maps = []
N_rays_all = rays.shape[0]
for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)):
rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device)
depth_map = tensorf.render_depth_map(rays_chunk, is_train=is_train, ndc_ray=ndc_ray, N_samples=N_samples)
depth_maps.append(depth_map)
return torch.cat(depth_maps)
@torch.no_grad()
def evaluation_feature(test_dataset, tensorf, args, renderer, chunk_size=2048, savePath=None, N_vis=10, prtx='', N_samples=-1,
white_bg=False, ndc_ray=False, compute_extra_metrics=False, style_img=None, device='cuda'):
    '''
    Decode the rendered feature maps to RGB and compare them with the ground-truth images.
    '''
PSNRs, rgb_maps, vis_feature_maps = [], [], []
ssims,l_alex,l_vgg=[],[],[]
os.makedirs(savePath, exist_ok=True)
os.makedirs(savePath + '/feature', exist_ok=True)
W, H = test_dataset.img_wh
try:
tqdm._instances.clear()
except Exception:
pass
near_far = test_dataset.near_far
img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays_stack.shape[0] // N_vis,1)
idxs = list(range(0, test_dataset.all_rays_stack.shape[0], img_eval_interval))
for idx, samples in tqdm(enumerate(test_dataset.all_rays_stack[0::img_eval_interval]), file=sys.stdout):
rays = samples.view(-1,samples.shape[-1])
if style_img is None:
feature_map, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples, ndc_ray=ndc_ray,
white_bg = white_bg, render_feature=True, device=device)
else:
feature_map, _, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples, ndc_ray=ndc_ray,
white_bg = white_bg, render_feature=True, style_img=style_img, device=device)
feature_map = feature_map.reshape(H, W, 256)[None,...].permute(0,3,1,2)
recon_rgb = denormalize_vgg(tensorf.decoder(feature_map))
recon_rgb = recon_rgb.permute(0,2,3,1).clamp(0,1)
vis_feature_map = torch.sigmoid(feature_map[:, [1,2,3], :, :].permute(0,2,3,1))
if test_dataset.white_bg:
            mask = test_dataset.all_masks[idxs[idx]:idxs[idx]+1].to(device)
recon_rgb = mask*recon_rgb + (1.-mask)
vis_feature_map = mask*vis_feature_map + (1.-mask)
recon_rgb = recon_rgb.reshape(H, W, 3).cpu()
vis_feature_map = vis_feature_map.squeeze().cpu()
if len(test_dataset.all_rgbs_stack):
gt_rgb = test_dataset.all_rgbs_stack[idxs[idx]].view(H, W, 3)
loss = torch.mean((recon_rgb - gt_rgb) ** 2)
PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0))
if compute_extra_metrics:
ssim = rgb_ssim(recon_rgb, gt_rgb, 1)
l_a = rgb_lpips(gt_rgb.numpy(), recon_rgb.numpy(), 'alex', tensorf.device)
l_v = rgb_lpips(gt_rgb.numpy(), recon_rgb.numpy(), 'vgg', tensorf.device)
ssims.append(ssim)
l_alex.append(l_a)
l_vgg.append(l_v)
recon_rgb = (recon_rgb.numpy() * 255).astype('uint8')
vis_feature_map = (vis_feature_map.numpy() * 255).astype('uint8')
if savePath is not None:
# rgb_map = np.concatenate((recon_rgb, gt_rgb), axis=1)
rgb_maps.append(recon_rgb)
vis_feature_maps.append(vis_feature_map)
imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', recon_rgb)
imageio.imwrite(f'{savePath}/feature/feature_{idx:03d}.png', vis_feature_map)
imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)
imageio.mimwrite(f'{savePath}/feature/feature_video.mp4', np.stack(vis_feature_maps), fps=30, quality=8)
if PSNRs:
psnr = np.mean(np.asarray(PSNRs))
if compute_extra_metrics:
ssim = np.mean(np.asarray(ssims))
l_a = np.mean(np.asarray(l_alex))
l_v = np.mean(np.asarray(l_vgg))
np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
else:
np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
return PSNRs
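The evaluation helpers render every `img_eval_interval`-th test view (all views when `N_vis < 0`) and score each one with `-10 * log(mse) / log(10)`, which is `-10 * log10(mse)`, the PSNR for images with peak value 1.0. Because the loop enumerates the strided slice, the local counter `idx` has to be mapped back through `idxs[idx]` before indexing per-frame data such as ground-truth images. A plain-Python sketch of both pieces, with hypothetical helper names:

```python
import math
from typing import List

def eval_indices(n_frames: int, n_vis: int) -> List[int]:
    # Mirrors img_eval_interval / idxs in evaluation(): stride 1 when n_vis < 0
    interval = 1 if n_vis < 0 else max(n_frames // n_vis, 1)
    return list(range(0, n_frames, interval))

def psnr_from_mse(mse: float) -> float:
    # PSNR for signals in [0, 1]; equals -10 * ln(mse) / ln(10)
    return -10.0 * math.log10(mse)

idxs = eval_indices(n_frames=10, n_vis=5)
print(idxs)                           # [0, 2, 4, 6, 8]
print(round(psnr_from_mse(0.01), 6))  # 20.0
```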
@torch.no_grad()
def evaluation_feature_path(test_dataset, tensorf, c2ws, renderer, chunk_size=2048, savePath=None, N_vis=5, prtx='', N_samples=-1,
white_bg=False, ndc_ray=False, compute_extra_metrics=False, style_img=None, device='cuda'):
PSNRs, rgb_maps, vis_feature_maps, depth_maps = [], [], [], []
ssims,l_alex,l_vgg=[],[],[]
os.makedirs(savePath, exist_ok=True)
os.makedirs(savePath + '/feature', exist_ok=True)
W, H = test_dataset.img_wh
try:
tqdm._instances.clear()
except Exception:
pass
near_far = test_dataset.near_far
for idx, c2w in tqdm(enumerate(c2ws)):
c2w = torch.FloatTensor(c2w)
rays_o, rays_d = get_rays(test_dataset.directions, c2w) # both (h*w, 3)
if ndc_ray:
rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)
        rays = torch.cat([rays_o, rays_d], 1)  # (H * W, 6)
if style_img is None:
feature_map, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples, ndc_ray=ndc_ray,
white_bg = white_bg, render_feature=True, device=device)
else:
feature_map, _, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples, ndc_ray=ndc_ray,
white_bg = white_bg, render_feature=True, style_img=style_img, device=device)
feature_map = feature_map.reshape(H, W, 256)[None,...].permute(0,3,1,2)
recon_rgb = denormalize_vgg(tensorf.decoder(feature_map))
recon_rgb = recon_rgb.permute(0,2,3,1).clamp(0,1)
recon_rgb = recon_rgb.reshape(H, W, 3).cpu()
recon_rgb = (recon_rgb.numpy() * 255).astype('uint8')
rgb_maps.append(recon_rgb)
vis_feature_map = torch.sigmoid(feature_map[:, [1,2,3], :, :].permute(0,2,3,1))
vis_feature_map = vis_feature_map.squeeze().cpu()
vis_feature_map = (vis_feature_map.numpy() * 255).astype('uint8')
vis_feature_maps.append(vis_feature_map)
if savePath is not None:
imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', recon_rgb)
imageio.imwrite(f'{savePath}/feature/feature_{idx:03d}.png', vis_feature_map)
imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)
imageio.mimwrite(f'{savePath}/feature/feature_video.mp4', np.stack(vis_feature_maps), fps=30, quality=8)
if PSNRs:
psnr = np.mean(np.asarray(PSNRs))
if compute_extra_metrics:
ssim = np.mean(np.asarray(ssims))
l_a = np.mean(np.asarray(l_alex))
l_v = np.mean(np.asarray(l_vgg))
np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
else:
np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
return PSNRs
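Both feature-evaluation functions preview the 256-channel feature map by taking channels 1 to 3, squashing them with a sigmoid and scaling to 8-bit; for white-background scenes, `evaluation_feature` also composites with `mask * x + (1 - mask)`. The per-pixel sketch below restates those two steps in plain Python; the helper names are hypothetical.

```python
import math
from typing import Sequence, Tuple

def feature_to_rgb8(feature: Sequence[float]) -> Tuple[int, int, int]:
    # Use channels 1, 2, 3 (as feature_map[:, [1, 2, 3]] does), sigmoid into (0, 1),
    # then scale to [0, 255]; int() truncates like .astype('uint8') for in-range values.
    r, g, b = (1.0 / (1.0 + math.exp(-feature[c])) for c in (1, 2, 3))
    return int(r * 255), int(g * 255), int(b * 255)

def composite_on_white(value: float, mask: float) -> float:
    # mask = 1 keeps the rendered value, mask = 0 yields white (1.0)
    return mask * value + (1.0 - mask)

print(feature_to_rgb8([5.0, 0.0, 0.0, 0.0]))  # (127, 127, 127)
print(composite_on_white(0.2, 0.0))           # 1.0
```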
@torch.no_grad()
def evaluation(test_dataset,tensorf, args, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
PSNRs, rgb_maps, depth_maps = [], [], []
ssims,l_alex,l_vgg=[],[],[]
os.makedirs(savePath, exist_ok=True)
os.makedirs(savePath+"/rgbd", exist_ok=True)
try:
tqdm._instances.clear()
except Exception:
pass
near_far = test_dataset.near_far
img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays_stack.shape[0] // N_vis,1)
idxs = list(range(0, test_dataset.all_rays_stack.shape[0], img_eval_interval))
for idx, samples in tqdm(enumerate(test_dataset.all_rays_stack[0::img_eval_interval]), file=sys.stdout):
W, H = test_dataset.img_wh
rays = samples.view(-1,samples.shape[-1])
rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=4096, N_samples=N_samples,
ndc_ray=ndc_ray, white_bg = white_bg, device=device)
rgb_map = rgb_map.clamp(0.0, 1.0)
rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()
depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)
if len(test_dataset.all_rgbs_stack):
gt_rgb = test_dataset.all_rgbs_stack[idxs[idx]].view(H, W, 3)
loss = torch.mean((rgb_map - gt_rgb) ** 2)
PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0))
if compute_extra_metrics:
ssim = rgb_ssim(rgb_map, gt_rgb, 1)
l_a = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'alex', tensorf.device)
l_v = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'vgg', tensorf.device)
ssims.append(ssim)
l_alex.append(l_a)
l_vgg.append(l_v)
rgb_map = (rgb_map.numpy() * 255).astype('uint8')
# rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
rgb_maps.append(rgb_map)
depth_maps.append(depth_map)
if savePath is not None:
imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
# rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', depth_map)
imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=10)
imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=10)
if PSNRs:
psnr = np.mean(np.asarray(PSNRs))
if compute_extra_metrics:
ssim = np.mean(np.asarray(ssims))
l_a = np.mean(np.asarray(l_alex))
l_v = np.mean(np.asarray(l_vgg))
np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
else:
np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
return PSNRs
@torch.no_grad()
def evaluation_path(test_dataset,tensorf, c2ws, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
PSNRs, rgb_maps, depth_maps = [], [], []
ssims,l_alex,l_vgg=[],[],[]
os.makedirs(savePath, exist_ok=True)
os.makedirs(savePath+"/rgbd", exist_ok=True)
try:
tqdm._instances.clear()
except Exception:
pass
near_far = test_dataset.near_far
for idx, c2w in tqdm(enumerate(c2ws)):
W, H = test_dataset.img_wh
c2w = torch.FloatTensor(c2w)
rays_o, rays_d = get_rays(test_dataset.directions, c2w) # both (h*w, 3)
if ndc_ray:
rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)
rays = torch.cat([rays_o, rays_d], 1) # (h*w, 6)
rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=8192, N_samples=N_samples,
ndc_ray=ndc_ray, white_bg = white_bg, device=device)
rgb_map = rgb_map.clamp(0.0, 1.0)
rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()
depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)
rgb_map = (rgb_map.numpy() * 255).astype('uint8')
# rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
rgb_maps.append(rgb_map)
depth_maps.append(depth_map)
if savePath is not None:
imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)
imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)
imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=8)
if PSNRs:
psnr = np.mean(np.asarray(PSNRs))
if compute_extra_metrics:
ssim = np.mean(np.asarray(ssims))
l_a = np.mean(np.asarray(l_alex))
l_v = np.mean(np.asarray(l_vgg))
np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
else:
np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
return PSNRs
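For forward-facing scenes (`--ndc_ray 1`), the path renderers map rays to normalized device coordinates with `ndc_rays_blender(H, W, focal, 1.0, rays_o, rays_d)`. Its body lives in `dataLoader/ray_utils.py`, which is not shown here; the sketch below implements the standard NeRF NDC conversion for a single ray, on the assumption that `ndc_rays_blender` follows the same formulation. A ray leaving the camera along -z starts at NDC depth -1 and reaches the far plane (depth 1, infinitely far in world space) at t = 1.

```python
from typing import Tuple

Vec3 = Tuple[float, float, float]

def ndc_ray(H: int, W: int, focal: float, near: float, o: Vec3, d: Vec3) -> Tuple[Vec3, Vec3]:
    """Standard NeRF NDC mapping for one ray (camera looks down -z)."""
    ox, oy, oz = o
    dx, dy, dz = d
    # Shift the origin onto the near plane z = -near
    t = -(near + oz) / dz
    ox, oy, oz = ox + t * dx, oy + t * dy, oz + t * dz
    ax = -1.0 / (W / (2.0 * focal))
    ay = -1.0 / (H / (2.0 * focal))
    o_ndc = (ax * ox / oz, ay * oy / oz, 1.0 + 2.0 * near / oz)
    d_ndc = (ax * (dx / dz - ox / oz),
             ay * (dy / dz - oy / oz),
             -2.0 * near / oz)
    return o_ndc, d_ndc

o_ndc, d_ndc = ndc_ray(H=378, W=504, focal=400.0, near=1.0,
                       o=(0.0, 0.0, 0.0), d=(0.0, 0.0, -1.0))
print(o_ndc[2], d_ndc[2])  # -1.0 2.0
```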
================================================
FILE: scripts/test.sh
================================================
CUDA_VISIBLE_DEVICES=$1 python train.py \
--config configs/llff.txt \
--ckpt log/trex/trex.th \
--render_only 1 \
--render_test 1
================================================
FILE: scripts/test_feature.sh
================================================
expname=trex
CUDA_VISIBLE_DEVICES=$1 python train_feature.py \
--config configs/llff_feature.txt \
--datadir ./data/nerf_llff_data/trex \
--expname $expname \
--ckpt ./log_feature/$expname/$expname.th \
--render_only 1 \
--render_test 0 \
--render_path 1 \
--chunk_size 1024
================================================
FILE: scripts/test_style.sh
================================================
expname=trex
CUDA_VISIBLE_DEVICES=$1 python train_style.py \
--config configs/llff_style.txt \
--datadir ./data/nerf_llff_data/trex \
--expname $expname \
--ckpt log_style/$expname/$expname.th \
--style_img path/to/reference/style/image \
--render_only 1 \
--render_train 0 \
--render_test 0 \
--render_path 1 \
--chunk_size 1024 \
--rm_weight_mask_thre 0.0001
================================================
FILE: scripts/train.sh
================================================
CUDA_VISIBLE_DEVICES=$1 python train.py --config=configs/llff.txt
================================================
FILE: scripts/train_feature.sh
================================================
CUDA_VISIBLE_DEVICES=$1 python train_feature.py --config=configs/llff_feature.txt
================================================
FILE: scripts/train_style.sh
================================================
CUDA_VISIBLE_DEVICES=$1 python train_style.py --config=configs/llff_style.txt
================================================
FILE: train.py
================================================
import os
from tqdm.auto import tqdm
from opt import config_parser
import json, random
from renderer import *
from utils import *
from torch.utils.tensorboard import SummaryWriter
import datetime
from dataLoader import dataset_dict
import sys
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
renderer = OctreeRender_trilinear_fast
class SimpleSampler:
def __init__(self, total, batch):
self.total = total
self.batch = batch
self.curr = total
self.ids = None
def nextids(self):
self.curr+=self.batch
if self.curr + self.batch > self.total:
self.ids = torch.LongTensor(np.random.permutation(self.total))
self.curr = 0
return self.ids[self.curr:self.curr+self.batch]
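`SimpleSampler` draws ray batches without replacement: it reshuffles all indices once the next batch would run past the end, so each epoch visits every ray at most once and drops the final partial batch. The pure-Python restatement below (hypothetical class name, `random.Random` in place of `np.random.permutation`, lists in place of `torch.LongTensor`) makes that behavior easy to check.

```python
import random
from typing import List, Optional

class SimpleSamplerSketch:
    """Plain-Python analogue of SimpleSampler (hypothetical name)."""

    def __init__(self, total: int, batch: int, seed: Optional[int] = None):
        self.total = total
        self.batch = batch
        self.curr = total          # forces a shuffle on the first call
        self.ids: List[int] = []
        self.rng = random.Random(seed)

    def nextids(self) -> List[int]:
        self.curr += self.batch
        if self.curr + self.batch > self.total:
            # New epoch: reshuffle and restart from the beginning
            self.ids = list(range(self.total))
            self.rng.shuffle(self.ids)
            self.curr = 0
        return self.ids[self.curr:self.curr + self.batch]

sampler = SimpleSamplerSketch(total=10, batch=3, seed=0)
epoch = [sampler.nextids() for _ in range(3)]  # 3 full batches fit into 10 ids
print([len(b) for b in epoch])  # [3, 3, 3]
```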
@torch.no_grad()
def export_mesh(args):
ckpt = torch.load(args.ckpt, map_location=device)
kwargs = ckpt['kwargs']
kwargs.update({'device': device})
tensorf = eval(args.model_name)(**kwargs)
tensorf.load(ckpt)
alpha,_ = tensorf.getDenseAlpha()
convert_sdf_samples_to_ply(alpha.cpu(), f'{args.ckpt[:-3]}.ply',bbox=tensorf.aabb.cpu(), level=0.005)
@torch.no_grad()
def render_test(args):
# init dataset
dataset = dataset_dict[args.dataset_name]
test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
white_bg = test_dataset.white_bg
ndc_ray = args.ndc_ray
if not os.path.exists(args.ckpt):
        print('The ckpt path does not exist!')
return
ckpt = torch.load(args.ckpt, map_location=device)
kwargs = ckpt['kwargs']
kwargs.update({'device': device})
tensorf = eval(args.model_name)(**kwargs)
tensorf.load(ckpt)
logfolder = os.path.dirname(args.ckpt)
if args.render_train:
os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
if args.render_test:
os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/',
N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
if args.render_path:
c2ws = test_dataset.render_path
os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)
evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_path_all/',
N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
def reconstruction(args):
# init dataset
dataset = dataset_dict[args.dataset_name]
train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False)
test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
white_bg = train_dataset.white_bg
near_far = train_dataset.near_far
ndc_ray = args.ndc_ray
# init resolution
upsamp_list = args.upsamp_list
update_AlphaMask_list = args.update_AlphaMask_list
n_lamb_sigma = args.n_lamb_sigma
n_lamb_sh = args.n_lamb_sh
if args.add_timestamp:
logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
else:
logfolder = f'{args.basedir}/{args.expname}'
# init log file
os.makedirs(logfolder, exist_ok=True)
os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
os.makedirs(f'{logfolder}/rgba', exist_ok=True)
summary_writer = SummaryWriter(logfolder)
# init parameters
# tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])
aabb = train_dataset.scene_bbox.to(device)
reso_cur = N_to_reso(args.N_voxel_init, aabb)
nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
if args.ckpt is not None:
ckpt = torch.load(args.ckpt, map_location=device)
kwargs = ckpt['kwargs']
kwargs.update({'device':device})
tensorf = eval(args.model_name)(**kwargs)
tensorf.load(ckpt)
else:
tensorf = eval(args.model_name)(aabb, reso_cur, device,
density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh, app_dim=args.data_dim_color, near_far=near_far,
shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre, density_shift=args.density_shift, distance_scale=args.distance_scale,
pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe, featureC=args.featureC, step_ratio=args.step_ratio, fea2denseAct=args.fea2denseAct)
grad_vars = tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
if args.lr_decay_iters > 0:
lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)
else:
args.lr_decay_iters = args.n_iters
lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)
print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)
optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))
    # linear in logarithmic space
if upsamp_list is not None:
N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:]
torch.cuda.empty_cache()
PSNRs,PSNRs_test = [],[0]
allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs
if not args.ndc_ray:
allrays, allrgbs = tensorf.filtering_rays(allrays, allrgbs, bbox_only=True)
trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)
Ortho_reg_weight = args.Ortho_weight
print("initial Ortho_reg_weight", Ortho_reg_weight)
L1_reg_weight = args.L1_weight_inital
print("initial L1_reg_weight", L1_reg_weight)
TV_weight_density, TV_weight_app = args.TV_weight_density, args.TV_weight_app
tvreg = TVLoss()
print(f"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app}")
pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
for iteration in pbar:
ray_idx = trainingSampler.nextids()
rays_train, rgb_train = allrays[ray_idx], allrgbs[ray_idx].to(device)
#rgb_map, alphas_map, depth_map, weights, uncertainty
rgb_map, alphas_map, depth_map, weights, uncertainty = renderer(rays_train, tensorf, chunk=args.batch_size,
N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, device=device, is_train=True)
loss = torch.mean((rgb_map - rgb_train) ** 2)
# loss
total_loss = loss
if Ortho_reg_weight > 0:
loss_reg = tensorf.vector_comp_diffs()
total_loss += Ortho_reg_weight*loss_reg
summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)
if L1_reg_weight > 0:
loss_reg_L1 = tensorf.density_L1()
total_loss += L1_reg_weight*loss_reg_L1
summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)
if TV_weight_density>0:
TV_weight_density *= lr_factor
loss_tv = tensorf.TV_loss_density(tvreg) * TV_weight_density
total_loss = total_loss + loss_tv
summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)
if TV_weight_app>0:
TV_weight_app *= lr_factor
loss_tv = tensorf.TV_loss_app(tvreg)*TV_weight_app
total_loss = total_loss + loss_tv
summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
loss = loss.detach().item()
PSNRs.append(-10.0 * np.log(loss) / np.log(10.0))
summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
summary_writer.add_scalar('train/mse', loss, global_step=iteration)
for param_group in optimizer.param_groups:
param_group['lr'] = param_group['lr'] * lr_factor
# Print the current values of the losses.
if iteration % args.progress_refresh_rate == 0:
pbar.set_description(
f'Iteration {iteration:05d}:'
+ f' train_psnr = {float(np.mean(PSNRs)):.2f}'
+ f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'
+ f' mse = {loss:.6f}'
)
PSNRs = []
if iteration % args.vis_every == args.vis_every - 1 and args.N_vis!=0:
PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_vis/', N_vis=args.N_vis,
prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, compute_extra_metrics=False)
summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)
if update_AlphaMask_list is not None and iteration in update_AlphaMask_list:
if reso_cur[0] * reso_cur[1] * reso_cur[2]<256**3:# update volume resolution
reso_mask = reso_cur
new_aabb = tensorf.updateAlphaMask(tuple(reso_mask))
if iteration == update_AlphaMask_list[0]:
tensorf.shrink(new_aabb)
# tensorVM.alphaMask = None
L1_reg_weight = args.L1_weight_rest
print("continuing L1_reg_weight", L1_reg_weight)
if not args.ndc_ray and iteration == update_AlphaMask_list[1]:
# filter rays outside the bbox
allrays,allrgbs = tensorf.filtering_rays(allrays,allrgbs)
trainingSampler = SimpleSampler(allrgbs.shape[0], args.batch_size)
if upsamp_list is not None and iteration in upsamp_list:
n_voxels = N_voxel_list.pop(0)
reso_cur = N_to_reso(n_voxels, tensorf.aabb)
nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
tensorf.upsample_volume_grid(reso_cur)
if args.lr_upsample_reset:
print("reset lr to initial")
lr_scale = 1 #0.1 ** (iteration / args.n_iters)
else:
lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
grad_vars = tensorf.get_optparam_groups(args.lr_init*lr_scale, args.lr_basis*lr_scale)
optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
tensorf.save(f'{logfolder}/{args.expname}.th')
if args.render_train:
os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
if args.render_test:
os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_test_all/',
N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')
if args.render_path:
c2ws = test_dataset.render_path
# c2ws = test_dataset.poses
print('========>',c2ws.shape)
os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/imgs_path_all/',
N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
if __name__ == '__main__':
torch.set_default_dtype(torch.float32)
torch.manual_seed(20211202)
np.random.seed(20211202)
args = config_parser()
print(args)
if args.export_mesh:
export_mesh(args)
if args.render_only and (args.render_test or args.render_path):
render_test(args)
else:
reconstruction(args)
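`reconstruction()` in `train.py` relies on two schedules. The learning rate (and each TV weight) is multiplied by `lr_factor = lr_decay_target_ratio ** (1 / lr_decay_iters)` every iteration, so after `lr_decay_iters` steps it has shrunk by `lr_decay_target_ratio`. The voxel counts used at the `upsamp_list` iterations are evenly spaced in log space between `N_voxel_init` and `N_voxel_final`, with the initial value dropped. The plain-Python sketch below mirrors both computations with hypothetical helper names; the original uses float32 `torch` ops, so its rounded voxel counts may differ slightly.

```python
import math
from typing import List

def lr_decay_factor(target_ratio: float, decay_iters: int) -> float:
    # Per-iteration multiplier; applying it decay_iters times gives target_ratio
    return target_ratio ** (1.0 / decay_iters)

def voxel_schedule(n_init: int, n_final: int, n_upsamples: int) -> List[int]:
    # n_upsamples + 1 points evenly spaced in log space; the first (n_init) is dropped
    lo, hi = math.log(n_init), math.log(n_final)
    return [round(math.exp(lo + (hi - lo) * k / n_upsamples))
            for k in range(1, n_upsamples + 1)]

factor = lr_decay_factor(0.1, 30000)
print(factor ** 30000)                    # approximately 0.1
print(voxel_schedule(100**3, 300**3, 5))  # 5 increasing counts ending at 27000000
```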
================================================
FILE: train_feature.py
================================================
import os
from tqdm.auto import tqdm
from opt import config_parser
import json, random
from renderer import *
from utils import *
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
import torchvision.transforms.functional as TF
import datetime
from dataLoader import dataset_dict
import sys
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
renderer = OctreeRender_trilinear_fast
class SimpleSampler:
def __init__(self, total, batch):
self.total = total
self.batch = batch
self.curr = total
self.ids = None
def nextids(self):
self.curr+=self.batch
if self.curr + self.batch > self.total:
self.ids = torch.LongTensor(np.random.permutation(self.total))
self.curr = 0
return self.ids[self.curr:self.curr+self.batch]
def InfiniteSampler(n):
# i = 0
i = n - 1
order = np.random.permutation(n)
while True:
yield order[i]
i += 1
if i >= n:
np.random.seed()
order = np.random.permutation(n)
i = 0
class InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):
def __init__(self, num_samples):
self.num_samples = num_samples
def __iter__(self):
return iter(InfiniteSampler(self.num_samples))
def __len__(self):
return 2 ** 31
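`InfiniteSampler` yields frame indices forever, drawing a fresh random permutation each time it exhausts the current one. The original starts at `i = n - 1`, so it emits a single index before its first reshuffle, and it calls `np.random.seed()` on every pass, which reseeds NumPy's global RNG. The sketch below is a simplified, stand-alone variant (hypothetical name) that uses a private `random.Random` and leaves global RNG state alone; it is not a drop-in copy of the original.

```python
import itertools
import random
from typing import Iterator, Optional

def infinite_permutations(n: int, seed: Optional[int] = None) -> Iterator[int]:
    """Yield indices 0..n-1 in shuffled order, reshuffling after every full pass."""
    rng = random.Random(seed)  # private RNG: does not touch numpy/random global state
    while True:
        order = list(range(n))
        rng.shuffle(order)
        yield from order

frames = infinite_permutations(4, seed=0)
first_two_passes = list(itertools.islice(frames, 8))
print(sorted(first_two_passes[:4]), sorted(first_two_passes[4:]))  # [0, 1, 2, 3] [0, 1, 2, 3]
```

Wrapping such a generator in a `torch.utils.data.Sampler` subclass, as `InfiniteSamplerWrapper` does, lets it drive a `DataLoader` or be consumed directly with `next(iter(...))`.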
@torch.no_grad()
def render_test(args):
# init dataset
dataset = dataset_dict[args.dataset_name]
test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
white_bg = test_dataset.white_bg
ndc_ray = args.ndc_ray
if not os.path.exists(args.ckpt):
        print('The ckpt path does not exist!')
return
ckpt = torch.load(args.ckpt, map_location=device)
kwargs = ckpt['kwargs']
kwargs.update({'device': device})
tensorf = eval(args.model_name)(**kwargs)
tensorf.change_to_feature_mod(args.n_lamb_sh ,device)
tensorf.load(ckpt)
tensorf.eval()
tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre
logfolder = os.path.dirname(args.ckpt)
if args.render_train:
os.makedirs(f'{logfolder}/{args.expname}/imgs_train_all', exist_ok=True)
train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
evaluation_feature(train_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_train_all/',
N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
if args.render_test:
os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
evaluation_feature(test_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_test_all/',
N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
if args.render_path:
c2ws = test_dataset.render_path
os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)
evaluation_feature_path(test_dataset,tensorf, c2ws, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_path_all/',
N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
def reconstruction(args):
# init dataset
dataset = dataset_dict[args.dataset_name]
train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
white_bg = train_dataset.white_bg
near_far = train_dataset.near_far
h_rays, w_rays = train_dataset.img_wh[1], train_dataset.img_wh[0]
ndc_ray = args.ndc_ray
patch_size = args.patch_size
if args.add_timestamp:
logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
else:
logfolder = f'{args.basedir}/{args.expname}'
# init log file
os.makedirs(logfolder, exist_ok=True)
summary_writer = SummaryWriter(logfolder)
# init parameters
# tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])
aabb = train_dataset.scene_bbox.to(device)
# TODO: need to update reso_cur
reso_cur = N_to_reso(args.N_voxel_init, aabb)
nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
    assert args.ckpt is not None, 'A pre-trained checkpoint (--ckpt) is required to provide the density field!'
ckpt = torch.load(args.ckpt, map_location=device)
kwargs = ckpt['kwargs']
kwargs.update({'device':device})
tensorf = eval(args.model_name)(**kwargs)
tensorf.load(ckpt)
tensorf.change_to_feature_mod(args.n_lamb_sh ,device)
tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre
train_dataset.prepare_feature_data(tensorf.encoder)
grad_vars = tensorf.get_optparam_groups_feature_mod(args.lr_init, args.lr_basis)
if args.lr_decay_iters > 0:
lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)
else:
args.lr_decay_iters = args.n_iters
lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)
print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)
optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))
torch.cuda.empty_cache()
PSNRs = []
allrays, allfeatures = train_dataset.all_rays, train_dataset.all_features
allrays_stack, allrgbs_stack = train_dataset.all_rays_stack, train_dataset.all_rgbs_stack
if not args.ndc_ray:
allrays, allfeatures = tensorf.filtering_rays(allrays, allfeatures, bbox_only=True)
trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)
frameSampler = iter(InfiniteSamplerWrapper(allrays_stack.size(0))) # every next(sampler) returns a frame index
TV_weight_feature = args.TV_weight_feature
tvreg = TVLoss()
print(f"initial TV_weight_feature: {TV_weight_feature}")
pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
for iteration in pbar:
feature_loss, pixel_loss = 0., 0.
if iteration%2==0:
ray_idx = trainingSampler.nextids()
rays_train, features_train = allrays[ray_idx], allfeatures[ray_idx].to(device)
feature_map, _ = renderer(rays_train, tensorf, chunk=args.chunk_size, N_samples=nSamples, white_bg = white_bg,
ndc_ray=ndc_ray, render_feature=True, device=device, is_train=True)
feature_loss = torch.mean((feature_map - features_train) ** 2)
else:
frame_idx = next(frameSampler)
start_h = np.random.randint(0, h_rays-patch_size+1)
start_w = np.random.randint(0, w_rays-patch_size+1)
if white_bg:
                # bias randomly sampled patches toward the image center
mid_h, mid_w = (h_rays-patch_size+1)/2, (w_rays-patch_size+1)/2
if mid_h-start_h>=1:
start_h += np.random.randint(0, mid_h-start_h)
elif mid_h-start_h<=-1:
start_h += np.random.randint(mid_h-start_h, 0)
if mid_w-start_w>=1:
start_w += np.random.randint(0, mid_w-start_w)
elif mid_w-start_w<=-1:
start_w += np.random.randint(mid_w-start_w, 0)
rays_train = allrays_stack[frame_idx, start_h:start_h+patch_size,
start_w:start_w+patch_size, :].reshape(-1, 6).to(device)
# [patch*patch, 6]
rgbs_train = allrgbs_stack[frame_idx, start_h:(start_h+patch_size),
start_w:(start_w+patch_size), :].to(device)
# [patch, patch, 3]
feature_map, _ = renderer(rays_train, tensorf, chunk=args.chunk_size, N_samples=nSamples, white_bg=white_bg,
ndc_ray=ndc_ray, render_feature=True, device=device, is_train=True)
feature_map = feature_map.reshape(patch_size, patch_size, 256)[None,...].permute(0,3,1,2)
recon_rgb = tensorf.decoder(feature_map)
rgbs_train = rgbs_train[None,...].permute(0,3,1,2)
img_enc = tensorf.encoder(normalize_vgg(rgbs_train))
recon_rgb_enc = tensorf.encoder(recon_rgb)
feature_loss =(F.mse_loss(recon_rgb_enc.relu4_1, img_enc.relu4_1) +
F.mse_loss(recon_rgb_enc.relu3_1, img_enc.relu3_1)) / 10
recon_rgb = denormalize_vgg(recon_rgb)
pixel_loss = torch.mean((recon_rgb - rgbs_train) ** 2)
total_loss = pixel_loss + feature_loss
# loss
        # NOTE: calculate the feature TV loss rather than the appearance TV loss
if TV_weight_feature>0:
TV_weight_feature *= lr_factor
loss_tv = tensorf.TV_loss_feature(tvreg)*TV_weight_feature
total_loss = total_loss + loss_tv
summary_writer.add_scalar('train/reg_tv_feature', loss_tv.detach().item(), global_step=iteration)
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
if pixel_loss == 0:
feature_loss = feature_loss.detach().item()
PSNRs.append(-10.0 * np.log(feature_loss) / np.log(10.0))
summary_writer.add_scalar('train/PSNR_feature', PSNRs[-1], global_step=iteration)
summary_writer.add_scalar('train/mse_feature', feature_loss, global_step=iteration)
else:
pixel_loss = pixel_loss.detach().item()
PSNRs.append(-10.0 * np.log(pixel_loss) / np.log(10.0))
summary_writer.add_scalar('train/PSNR_pixel', PSNRs[-1], global_step=iteration)
summary_writer.add_scalar('train/mse_pixel', pixel_loss, global_step=iteration)
summary_writer.add_scalar('train/mse_recon_feature', feature_loss.detach().item(), global_step=iteration)
for param_group in optimizer.param_groups:
param_group['lr'] = param_group['lr'] * lr_factor
# Print the current values of the losses.
if iteration % args.progress_refresh_rate == 0:
pbar.set_description(
f'Iteration {iteration:05d}:'
+ f' psnr = {float(np.mean(PSNRs)):.2f}'
)
PSNRs = []
if iteration % (args.progress_refresh_rate*20) == 1:
summary_writer.add_image('output', make_grid([rgbs_train.squeeze(),
recon_rgb.clamp(0, 1).squeeze()],
nrow=2, padding=0, normalize=False),
global_step=iteration)
tensorf.save(f'{logfolder}/{args.expname}.th')


if __name__ == '__main__':
    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20211202)
    np.random.seed(20211202)

    args = config_parser()
    print(args)

    if args.render_only:
        render_test(args)
    else:
        reconstruction(args)
================================================
FILE: train_style.py
================================================
import os
from tqdm.auto import tqdm
from opt import config_parser
from PIL import Image, ImageFile
from pathlib import Path
from torchvision.utils import make_grid
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from renderer import *
from utils import *
from torch.utils.tensorboard import SummaryWriter
import datetime
from dataLoader import dataset_dict
from dataLoader.styleLoader import getDataLoader
from models.styleModules import cal_mse_content_loss, cal_adain_style_loss
import sys
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
renderer = OctreeRender_trilinear_fast
depth_renderer = OctreeRender_trilinear_fast_depth


class SimpleSampler:
    def __init__(self, total, batch):
        self.total = total
        self.batch = batch
        self.curr = total
        self.ids = None

    def nextids(self):
        self.curr+=self.batch
        if self.curr + self.batch > self.total:
            self.ids = torch.LongTensor(np.random.permutation(self.total))
            self.curr = 0
        return self.ids[self.curr:self.curr+self.batch]


def InfiniteSampler(n):
    # i = 0
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            np.random.seed()
            order = np.random.permutation(n)
            i = 0


class InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        return 2 ** 31


@torch.no_grad()
def render_test(args):
    # init dataset
    dataset = dataset_dict[args.dataset_name]
    ndc_ray = args.ndc_ray

    if not os.path.exists(args.ckpt):
        print('the ckpt path does not exist!')
        return
    assert args.style_img is not None, 'Must specify a style image!'

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.change_to_feature_mod(args.n_lamb_sh, device)
    tensorf.change_to_style_mod(device)
    tensorf.load(ckpt)
    tensorf.eval()
    tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre

    logfolder = os.path.dirname(args.ckpt)

    trans = T.Compose([T.Resize(size=(256,256)), T.ToTensor()])
    # use the selected device instead of .cuda() so the CPU fallback also works
    style_img = trans(Image.open(args.style_img)).to(device)[None, ...]
    style_name = Path(args.style_img).stem

    if args.render_train:
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
        os.makedirs(f'{logfolder}/{args.expname}/imgs_train_all/{style_name}', exist_ok=True)
        evaluation_feature(train_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_train_all/{style_name}',
                           N_vis=-1, N_samples=-1, white_bg = train_dataset.white_bg, ndc_ray=ndc_ray, style_img=style_img, device=device)

    if args.render_test:
        test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
        os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all/{style_name}', exist_ok=True)
        evaluation_feature(test_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_test_all/{style_name}',
                           N_vis=-1, N_samples=-1, white_bg = test_dataset.white_bg, ndc_ray=ndc_ray, style_img=style_img, device=device)

    if args.render_path:
        test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
        c2ws = test_dataset.render_path
        os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all/{style_name}', exist_ok=True)
        evaluation_feature_path(test_dataset, tensorf, c2ws, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_path_all/{style_name}',
                                N_vis=-1, N_samples=-1, white_bg = test_dataset.white_bg, ndc_ray=ndc_ray, style_img=style_img, device=device)


def reconstruction(args):
    # init dataset
    dataset = dataset_dict[args.dataset_name]
    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
    white_bg = train_dataset.white_bg
    near_far = train_dataset.near_far
    h_rays, w_rays = train_dataset.img_wh[1], train_dataset.img_wh[0]
    ndc_ray = args.ndc_ray
    patch_size = args.patch_size # ground truth image patch size when training

    Image.MAX_IMAGE_PIXELS = None # Disable DecompressionBombError
    ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated
    style_loader = getDataLoader(args.wikiartdir, batch_size=1, sampler=InfiniteSamplerWrapper,
                                 image_side_length=256, num_workers=2)
    style_iter = iter(style_loader)

    if args.add_timestamp:
        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
    else:
        logfolder = f'{args.basedir}/{args.expname}'

    # init log file
    os.makedirs(logfolder, exist_ok=True)
    summary_writer = SummaryWriter(logfolder)

    # init parameters
    # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])
    aabb = train_dataset.scene_bbox.to(device)
    # TODO: need to update reso_cur
    reso_cur = N_to_reso(args.N_voxel_init, aabb)
    nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))

    assert args.ckpt is not None, 'A pre-trained checkpoint is required to provide the density field!'
    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device':device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.change_to_feature_mod(args.n_lamb_sh, device)
    tensorf.load(ckpt)
    tensorf.change_to_style_mod(device)
    tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre

    tvreg = TVLoss()

    grad_vars = tensorf.get_optparam_groups_style_mod(args.lr_basis, args.lr_finetune)
    if args.lr_decay_iters > 0:
        lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)
    else:
        args.lr_decay_iters = args.n_iters
        lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)
    print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)

    tensorf.train()
    optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))
    torch.cuda.empty_cache()

    allrays_stack, allrgbs_stack = train_dataset.all_rays_stack, train_dataset.all_rgbs_stack
    frameSampler = iter(InfiniteSamplerWrapper(allrays_stack.size(0))) # every next(sampler) returns a frame index

    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
    for iteration in pbar:
        # get style_img; this style_img has NOT been normalized for the pretrained VGG model
        style_img = next(style_iter)[0].to(device)

        # randomly sample a patch_size*patch_size patch from the given frame
        frame_idx = next(frameSampler)
        start_h = np.random.randint(0, h_rays-patch_size+1)
        start_w = np.random.randint(0, w_rays-patch_size+1)
        if white_bg:
            # move randomly sampled patches toward the center
            mid_h, mid_w = (h_rays-patch_size+1)/2, (w_rays-patch_size+1)/2
            if mid_h-start_h>=1:
                start_h += np.random.randint(0, mid_h-start_h)
            elif mid_h-start_h<=-1:
                start_h += np.random.randint(mid_h-start_h, 0)
            if mid_w-start_w>=1:
                start_w += np.random.randint(0, mid_w-start_w)
            elif mid_w-start_w<=-1:
                start_w += np.random.randint(mid_w-start_w, 0)

        rays_train = allrays_stack[frame_idx, start_h:start_h+patch_size, start_w:start_w+patch_size, :]\
            .reshape(-1, 6).to(device)
        # [patch*patch, 6]
        rgbs_train = allrgbs_stack[frame_idx, start_h:(start_h+patch_size),
                                   start_w:(start_w+patch_size), :].to(device)
        # [patch, patch, 3]

        feature_map, acc_map, style_feature = renderer(rays_train, tensorf, chunk=args.chunk_size, N_samples=nSamples, white_bg = white_bg,
                                                       ndc_ray=ndc_ray, render_feature=True, style_img=style_img, device=device, is_train=True)
        feature_map = feature_map.reshape(patch_size, patch_size, 256)[None,...].permute(0,
SYMBOL INDEX (246 symbols across 21 files)
FILE: dataLoader/blender.py
class BlenderDataset (line 13) | class BlenderDataset(Dataset):
method __init__ (line 14) | def __init__(self, datadir, split='train', downsample=1.0, is_stack=Fa...
method read_depth (line 35) | def read_depth(self, filename):
method read_meta (line 39) | def read_meta(self):
method prepare_feature_data (line 103) | def prepare_feature_data(self, encoder, chunk=8):
method define_transforms (line 125) | def define_transforms(self):
method define_proj_mat (line 128) | def define_proj_mat(self):
method world2ndc (line 131) | def world2ndc(self,points,lindisp=None):
method __len__ (line 135) | def __len__(self):
method __getitem__ (line 138) | def __getitem__(self, idx):
FILE: dataLoader/colmap2nerf.py
function parse_args (line 23) | def parse_args():
function do_system (line 40) | def do_system(arg):
function run_ffmpeg (line 47) | def run_ffmpeg(args):
function run_colmap (line 69) | def run_colmap(args):
function variance_of_laplacian (line 99) | def variance_of_laplacian(image):
function sharpness (line 102) | def sharpness(imagePath):
function qvec2rotmat (line 108) | def qvec2rotmat(qvec):
function rotmat (line 125) | def rotmat(a, b):
function closest_point_2_lines (line 133) | def closest_point_2_lines(oa, da, ob, db): # returns point closest to bo...
FILE: dataLoader/llff.py
function normalize (line 12) | def normalize(v):
function average_poses (line 17) | def average_poses(poses):
function center_poses (line 54) | def center_poses(poses, blender2opencv):
function viewmatrix (line 81) | def viewmatrix(z, up, pos):
function render_path_spiral (line 91) | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=...
function get_spiral (line 102) | def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
function get_interpolation_path (line 121) | def get_interpolation_path(c2ws_all, steps=30):
class LLFFDataset (line 147) | class LLFFDataset(Dataset):
method __init__ (line 148) | def __init__(self, datadir, split='train', downsample=4, is_stack=Fals...
method read_meta (line 168) | def read_meta(self):
method prepare_feature_data (line 257) | def prepare_feature_data(self, encoder, chunk=8):
method define_transforms (line 279) | def define_transforms(self):
method __len__ (line 282) | def __len__(self):
method __getitem__ (line 285) | def __getitem__(self, idx):
FILE: dataLoader/nsvf.py
function pose_spherical (line 29) | def pose_spherical(theta, phi, radius):
class NSVF (line 36) | class NSVF(Dataset):
method __init__ (line 38) | def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800...
method bbox2corners (line 56) | def bbox2corners(self):
method read_meta (line 63) | def read_meta(self):
method define_transforms (line 132) | def define_transforms(self):
method define_proj_mat (line 135) | def define_proj_mat(self):
method world2ndc (line 138) | def world2ndc(self, points):
method __len__ (line 142) | def __len__(self):
method __getitem__ (line 147) | def __getitem__(self, idx):
FILE: dataLoader/ray_utils.py
function depth2dist (line 9) | def depth2dist(z_vals, cos_angle):
function ndc2dist (line 18) | def ndc2dist(ndc_pts, cos_angle):
function get_ray_directions (line 24) | def get_ray_directions(H, W, focal, center=None):
function get_ray_directions_blender (line 45) | def get_ray_directions_blender(H, W, focal, center=None):
function get_rays (line 66) | def get_rays(directions, c2w):
function ndc_rays_blender (line 90) | def ndc_rays_blender(H, W, focal, near, rays_o, rays_d):
function ndc_rays (line 109) | def ndc_rays(H, W, focal, near, rays_o, rays_d):
function sample_pdf (line 129) | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
function dda (line 174) | def dda(rays_o, rays_d, bbox_3D):
function ray_marcher (line 184) | def ray_marcher(rays,
function read_pfm (line 231) | def read_pfm(filename):
function ndc_bbox (line 269) | def ndc_bbox(all_rays):
function denormalize_vgg (line 281) | def denormalize_vgg(img):
FILE: dataLoader/styleLoader.py
function getDataLoader (line 6) | def getDataLoader(dataset_path, batch_size, sampler, image_side_length=2...
FILE: dataLoader/tankstemple.py
function circle (line 12) | def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1):
function cross (line 21) | def cross(x, y, axis=0):
function normalize (line 26) | def normalize(x, axis=-1, order=2):
function cat (line 38) | def cat(x, axis=1):
function look_at_rotation (line 44) | def look_at_rotation(camera_position, at=None, up=None, inverse=False, c...
function gen_path (line 77) | def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):
class TanksTempleDataset (line 87) | class TanksTempleDataset(Dataset):
method __init__ (line 89) | def __init__(self, datadir, split='train', downsample=4.0, wh=[1920,10...
method bbox2corners (line 108) | def bbox2corners(self):
method read_meta (line 115) | def read_meta(self):
method prepare_feature_data (line 196) | def prepare_feature_data(self, encoder, chunk=4):
method define_transforms (line 219) | def define_transforms(self):
method define_proj_mat (line 222) | def define_proj_mat(self):
method world2ndc (line 225) | def world2ndc(self, points):
method __len__ (line 229) | def __len__(self):
method __getitem__ (line 234) | def __getitem__(self, idx):
FILE: dataLoader/your_own_data.py
class YourOwnDataset (line 13) | class YourOwnDataset(Dataset):
method __init__ (line 14) | def __init__(self, datadir, split='train', downsample=1.0, is_stack=Fa...
method read_depth (line 35) | def read_depth(self, filename):
method read_meta (line 39) | def read_meta(self):
method define_transforms (line 102) | def define_transforms(self):
method define_proj_mat (line 105) | def define_proj_mat(self):
method world2ndc (line 108) | def world2ndc(self,points,lindisp=None):
method __len__ (line 112) | def __len__(self):
method __getitem__ (line 115) | def __getitem__(self, idx):
FILE: extra/auto_run_paramsets.py
function getFolderLocker (line 7) | def getFolderLocker(logFolder):
function releaseFolderLocker (line 15) | def releaseFolderLocker(logFolder):
function getStopFolder (line 18) | def getStopFolder(logFolder):
function get_param_str (line 22) | def get_param_str(key, val):
function get_param_list (line 28) | def get_param_list(param_dict):
function run_program (line 167) | def run_program(gpu, expname, param):
FILE: extra/compute_metrics.py
function init_lpips (line 11) | def init_lpips(net_name, device):
function rgb_lpips (line 17) | def rgb_lpips(np_gt, np_im, net_name, device):
function findItem (line 25) | def findItem(items, target):
function rgb_ssim (line 34) | def rgb_ssim(img0, img1, max_val,
FILE: models/VGG.py
class Encoder (line 8) | class Encoder(nn.Module):
method __init__ (line 9) | def __init__(self):
method forward (line 23) | def forward(self, x):
class Decoder (line 34) | class Decoder(nn.Module):
method __init__ (line 38) | def __init__(self, ckpt_path=None):
method forward (line 67) | def forward(self, x):
class DownBlock (line 74) | class DownBlock(nn.Module):
method __init__ (line 76) | def __init__(self, in_dim, out_dim, down='conv'):
method forward (line 93) | def forward(self, x):
class UpBlock (line 98) | class UpBlock(nn.Module):
method __init__ (line 100) | def __init__(self, in_dim, out_dim, skip_dim=None, up='nearest'):
method _pad (line 125) | def _pad(self, x, y):
method forward (line 142) | def forward(self, x, skip=None):
class UNetDecoder (line 150) | class UNetDecoder(nn.Module):
method __init__ (line 152) | def __init__(self, in_dim=256):
method forward (line 184) | def forward(self, feats):
class PlainDecoder (line 197) | class PlainDecoder(nn.Module):
method __init__ (line 198) | def __init__(self) -> None:
method forward (line 219) | def forward(self, x):
FILE: models/sh.py
function eval_sh (line 34) | def eval_sh(deg, sh, dirs):
function eval_sh_bases (line 87) | def eval_sh_bases(deg, dirs):
FILE: models/styleModules.py
function calc_mean_std (line 6) | def calc_mean_std(x, eps=1e-8):
function cal_adain_style_loss (line 16) | def cal_adain_style_loss(x, y):
function cal_mse_content_loss (line 29) | def cal_mse_content_loss(x, y):
class LearnableIN (line 34) | class LearnableIN(nn.Module):
method __init__ (line 38) | def __init__(self, dim=256):
method forward (line 42) | def forward(self, x):
class SimpleLinearStylizer (line 47) | class SimpleLinearStylizer(nn.Module):
method __init__ (line 48) | def __init__(self, input_dim=256, embed_dim=32, n_layers=3 ) -> None:
method _vectorized_covariance (line 76) | def _vectorized_covariance(self, x):
method get_content_matrix (line 81) | def get_content_matrix(self, c):
method get_style_mean_std_matrix (line 98) | def get_style_mean_std_matrix(self, s):
method transform_content_3D (line 117) | def transform_content_3D(self, c):
method transfer_style_2D (line 131) | def transfer_style_2D(self, s_mean_std_mat, c, acc_map):
class AdaAttN (line 151) | class AdaAttN(nn.Module):
method __init__ (line 154) | def __init__(self, qk_dim, v_dim):
method forward (line 166) | def forward(self, q, k):
class AdaAttN_new_IN (line 199) | class AdaAttN_new_IN(nn.Module):
method __init__ (line 202) | def __init__(self, qk_dim, v_dim):
method forward (line 215) | def forward(self, q, k):
class AdaAttN_woin (line 248) | class AdaAttN_woin(nn.Module):
method __init__ (line 251) | def __init__(self, qk_dim, v_dim):
method forward (line 263) | def forward(self, q, k):
FILE: models/tensoRF.py
class TensorVMSplit (line 5) | class TensorVMSplit(TensorBase):
method __init__ (line 6) | def __init__(self, aabb, gridSize, device, **kargs):
method change_to_feature_mod (line 11) | def change_to_feature_mod(self, feature_n_comp, device):
method change_to_style_mod (line 29) | def change_to_style_mod(self, device='cuda'):
method init_svd_volume (line 45) | def init_svd_volume(self, res, device):
method init_feature_svd (line 50) | def init_feature_svd(self, device):
method init_one_svd (line 54) | def init_one_svd(self, n_component, gridSize, scale, device):
method get_optparam_groups (line 68) | def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_netwo...
method get_optparam_groups_feature_mod (line 76) | def get_optparam_groups_feature_mod(self, lr_init_spatialxyz, lr_init_...
method get_optparam_groups_style_mod (line 83) | def get_optparam_groups_style_mod(self, lr_init_network, lr_finetune):
method vectorDiffs (line 92) | def vectorDiffs(self, vector_comps):
method vector_comp_diffs (line 103) | def vector_comp_diffs(self):
method density_L1 (line 106) | def density_L1(self):
method TV_loss_density (line 112) | def TV_loss_density(self, reg):
method TV_loss_app (line 118) | def TV_loss_app(self, reg):
method TV_loss_feature (line 124) | def TV_loss_feature(self, reg):
method compute_densityfeature (line 132) | def compute_densityfeature(self, xyz_sampled):
method compute_appfeature (line 149) | def compute_appfeature(self, xyz_sampled):
method compute_feature (line 167) | def compute_feature(self, xyz_sampled):
method render_feature_map (line 187) | def render_feature_map(self, rays_chunk, s_mean_std_mat=None, is_train...
method render_depth_map (line 247) | def render_depth_map(self, rays_chunk, is_train=False, ndc_ray=False, ...
method up_sampling_VM (line 287) | def up_sampling_VM(self, plane_coef, line_coef, res_target):
method upsample_volume_grid (line 302) | def upsample_volume_grid(self, res_target):
method shrink (line 310) | def shrink(self, new_aabb):
FILE: models/tensorBase.py
function positional_encoding (line 9) | def positional_encoding(positions, freqs):
function raw2alpha (line 17) | def raw2alpha(sigma, dist):
function SHRender (line 27) | def SHRender(xyz_sampled, viewdirs, features):
function RGBRender (line 34) | def RGBRender(xyz_sampled, viewdirs, features):
class AlphaGridMask (line 39) | class AlphaGridMask(torch.nn.Module):
method __init__ (line 40) | def __init__(self, device, aabb, alpha_volume):
method sample_alpha (line 50) | def sample_alpha(self, xyz_sampled):
method normalize_coord (line 56) | def normalize_coord(self, xyz_sampled):
class MLPRender_Fea (line 60) | class MLPRender_Fea(torch.nn.Module):
method __init__ (line 61) | def __init__(self,inChanel, viewpe=6, feape=6, featureC=128):
method forward (line 74) | def forward(self, pts, viewdirs, features):
class MLPRender_PE (line 86) | class MLPRender_PE(torch.nn.Module):
method __init__ (line 87) | def __init__(self,inChanel, viewpe=6, pospe=6, featureC=128):
method forward (line 100) | def forward(self, pts, viewdirs, features):
class MLPRender (line 112) | class MLPRender(torch.nn.Module):
method __init__ (line 113) | def __init__(self,inChanel, viewpe=6, featureC=128):
method forward (line 126) | def forward(self, pts, viewdirs, features):
class TensorBase (line 138) | class TensorBase(torch.nn.Module):
method __init__ (line 139) | def __init__(self, aabb, gridSize, device, density_n_comp = 8, appeara...
method init_render_func (line 175) | def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featu...
method update_stepSize (line 192) | def update_stepSize(self, gridSize):
method init_svd_volume (line 205) | def init_svd_volume(self, res, device):
method compute_features (line 208) | def compute_features(self, xyz_sampled):
method compute_densityfeature (line 211) | def compute_densityfeature(self, xyz_sampled):
method compute_appfeature (line 214) | def compute_appfeature(self, xyz_sampled):
method normalize_coord (line 217) | def normalize_coord(self, xyz_sampled):
method get_optparam_groups (line 220) | def get_optparam_groups(self, lr_init_spatial = 0.02, lr_init_network ...
method get_kwargs (line 223) | def get_kwargs(self):
method save (line 247) | def save(self, path):
method load (line 257) | def load(self, ckpt):
method sample_ray_ndc (line 265) | def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
method sample_ray (line 276) | def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1):
method shrink (line 298) | def shrink(self, new_aabb, voxel_size):
method getDenseAlpha (line 302) | def getDenseAlpha(self,gridSize=None):
method updateAlphaMask (line 320) | def updateAlphaMask(self, gridSize=(200,200,200)):
method filtering_rays (line 346) | def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=1024...
method feature2density (line 378) | def feature2density(self, density_features):
method compute_alpha (line 385) | def compute_alpha(self, xyz_locs, length=1):
method forward (line 408) | def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=F...
FILE: opt.py
function config_parser (line 3) | def config_parser(cmd=None):
FILE: renderer.py
function OctreeRender_trilinear_fast (line 9) | def OctreeRender_trilinear_fast(rays, tensorf, chunk=4096, N_samples=-1,...
function OctreeRender_trilinear_fast_depth (line 40) | def OctreeRender_trilinear_fast_depth(rays, tensorf, chunk=4096, N_sampl...
function evaluation_feature (line 53) | def evaluation_feature(test_dataset, tensorf, args, renderer, chunk_size...
function evaluation_feature_path (line 139) | def evaluation_feature_path(test_dataset, tensorf, c2ws, renderer, chunk...
function evaluation (line 203) | def evaluation(test_dataset,tensorf, args, renderer, savePath=None, N_vi...
function evaluation_path (line 269) | def evaluation_path(test_dataset,tensorf, c2ws, renderer, savePath=None,...
FILE: train.py
class SimpleSampler (line 24) | class SimpleSampler:
method __init__ (line 25) | def __init__(self, total, batch):
method nextids (line 31) | def nextids(self):
function export_mesh (line 40) | def export_mesh(args):
function render_test (line 53) | def render_test(args):
function reconstruction (line 89) | def reconstruction(args):
FILE: train_feature.py
class SimpleSampler (line 27) | class SimpleSampler:
method __init__ (line 28) | def __init__(self, total, batch):
method nextids (line 34) | def nextids(self):
function InfiniteSampler (line 41) | def InfiniteSampler(n):
class InfiniteSamplerWrapper (line 53) | class InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):
method __init__ (line 54) | def __init__(self, num_samples):
method __iter__ (line 57) | def __iter__(self):
method __len__ (line 60) | def __len__(self):
function render_test (line 64) | def render_test(args):
function reconstruction (line 104) | def reconstruction(args):
FILE: train_style.py
class SimpleSampler (line 30) | class SimpleSampler:
method __init__ (line 31) | def __init__(self, total, batch):
method nextids (line 37) | def nextids(self):
function InfiniteSampler (line 44) | def InfiniteSampler(n):
class InfiniteSamplerWrapper (line 56) | class InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):
method __init__ (line 57) | def __init__(self, num_samples):
method __iter__ (line 60) | def __iter__(self):
method __len__ (line 63) | def __len__(self):
function render_test (line 67) | def render_test(args):
function reconstruction (line 113) | def reconstruction(args):
FILE: utils.py
function visualize_depth_numpy (line 11) | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
function init_log (line 28) | def init_log(log, keys):
function visualize_depth (line 33) | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
function N_to_reso (line 53) | def N_to_reso(n_voxels, bbox):
function cal_n_samples (line 59) | def cal_n_samples(reso, step_ratio=0.5):
function init_lpips (line 66) | def init_lpips(net_name, device):
function rgb_lpips (line 72) | def rgb_lpips(np_gt, np_im, net_name, device):
function findItem (line 80) | def findItem(items, target):
function rgb_ssim (line 89) | def rgb_ssim(img0, img1, max_val,
class TVLoss (line 139) | class TVLoss(nn.Module):
method __init__ (line 140) | def __init__(self):
method forward (line 143) | def forward(self,x):
method _tensor_size (line 164) | def _tensor_size(self,t):
function convert_sdf_samples_to_ply (line 171) | def convert_sdf_samples_to_ply(
function get_checkerboard (line 355) | def get_checkerboard(fg, n=8):
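
The patch sampling in `train_feature.py` and `train_style.py` above draws a uniform top-left corner and, for white-background scenes (`white_bg`), moves it a random distance toward the middle of the valid range. The learning rate is multiplied by `lr_factor = lr_decay_target_ratio**(1/lr_decay_iters)` (or `1/n_iters` when no decay length is set), so after that many multiplications it has been scaled by `lr_decay_target_ratio`. The sketch below restates both pieces of arithmetic in isolation. It is an illustration, not code from the repository: `sample_patch_start` and `lr_decay_factor` are hypothetical helper names, and it uses NumPy's `Generator` API rather than the global `np.random` state the scripts use.

```python
import numpy as np


def sample_patch_start(size, patch_size, rng, center_bias=False):
    """Return a top-left index for a patch along one image axis (hypothetical helper).

    Restates the training-script arithmetic: draw a uniform start in
    [0, size - patch_size], then, if center_bias is set (the white_bg case),
    move it a random number of pixels toward the midpoint of the valid range.
    """
    n_valid = size - patch_size + 1            # number of legal start positions
    start = int(rng.integers(0, n_valid))      # uniform over [0, n_valid)
    if center_bias:
        gap = n_valid / 2 - start              # signed distance to the midpoint
        if gap >= 1:
            # step toward the midpoint by 0 .. int(gap) - 1 pixels
            start += int(rng.integers(0, int(gap)))
        elif gap <= -1:
            # step toward the midpoint by int(gap) .. -1 pixels
            start += int(rng.integers(int(gap), 0))
    return start


def lr_decay_factor(target_ratio, n_steps):
    """Per-step multiplier that scales a learning rate by target_ratio after n_steps."""
    return target_ratio ** (1.0 / n_steps)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    h, patch = 378, 256  # hypothetical frame height and patch size
    starts = [sample_patch_start(h, patch, rng, center_bias=True) for _ in range(5)]
    print(starts)  # every value lies in [0, h - patch]

    factor = lr_decay_factor(0.1, 30000)
    lr = 0.1
    for _ in range(30000):
        lr *= factor
    print(lr)  # approximately 0.01
```

Moving the start by a random fraction of its distance to the midpoint, instead of clamping it, keeps every position reachable while making patches that are mostly empty background less likely.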
Condensed preview: 36 files, each showing path, character count, and a content snippet.
[
{
"path": "README.md",
"chars": 4457,
"preview": "# [*CVPR 2023*] StyleRF: Zero-shot 3D Style Transfer of Neural Radiance Fields\n## [Project page](https://kunhao-liu.gith"
},
{
"path": "configs/llff.txt",
"chars": 544,
"preview": "\ndataset_name = llff\ndatadir = ./data/nerf_llff_data/trex\nexpname = trex\nbasedir = ./log\n\ndownsample_train = 4.0\nndc_ray"
},
{
"path": "configs/llff_feature.txt",
"chars": 457,
"preview": "dataset_name = llff\ndatadir = ./data/nerf_llff_data/trex\nckpt = ./log/trex/trex.th\nexpname = trex\nbasedir = ./log_featur"
},
{
"path": "configs/llff_style.txt",
"chars": 480,
"preview": "dataset_name = llff\ndatadir = ./data/nerf_llff_data/trex\nckpt = ./log_feature/trex/trex.th\nexpname = trex\nbasedir = ./lo"
},
{
"path": "configs/nerf_synthetic.txt",
"chars": 541,
"preview": "\ndataset_name = blender\ndatadir = ./data/nerf_synthetic/lego\nexpname = lego\nbasedir = ./log\n\nn_iters = 30000\nbatch_size"
},
{
"path": "configs/nerf_synthetic_feature.txt",
"chars": 466,
"preview": "dataset_name = blender\ndatadir = ./data/nerf_synthetic/lego\nckpt = ./log/lego/lego.th\nexpname = lego\nbasedir = ./log_fea"
},
{
"path": "configs/nerf_synthetic_style.txt",
"chars": 391,
"preview": "dataset_name = blender\ndatadir = ./data/nerf_synthetic/lego\nckpt = ./log_feature/lego/lego.th\nexpname = lego\nbasedir = ."
},
{
"path": "dataLoader/__init__.py",
"chars": 375,
"preview": "from .llff import LLFFDataset\nfrom .blender import BlenderDataset\nfrom .nsvf import NSVF\nfrom .tankstemple import TanksT"
},
{
"path": "dataLoader/blender.py",
"chars": 6247,
"preview": "import torch,cv2\nfrom torch.utils.data import Dataset\nimport json\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\n"
},
{
"path": "dataLoader/colmap2nerf.py",
"chars": 10564,
"preview": "#!/usr/bin/env python3\n\n# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and "
},
{
"path": "dataLoader/llff.py",
"chars": 11105,
"preview": "import torch\nfrom torch.utils.data import Dataset\nimport glob\nimport numpy as np\nimport os\nfrom PIL import Image\nfrom to"
},
{
"path": "dataLoader/nsvf.py",
"chars": 6560,
"preview": "import torch\nfrom torch.utils.data import Dataset\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nfrom torchvision"
},
{
"path": "dataLoader/ray_utils.py",
"chars": 10543,
"preview": "import torch, re\nimport numpy as np\nfrom torch import searchsorted\nfrom kornia import create_meshgrid\n\n\n# from utils imp"
},
{
"path": "dataLoader/styleLoader.py",
"chars": 626,
"preview": "from torch.utils.data import DataLoader\nfrom torchvision import datasets\nimport torchvision.transforms as T\n\n\ndef getDat"
},
{
"path": "dataLoader/tankstemple.py",
"chars": 9953,
"preview": "import torch\nfrom torch.utils.data import Dataset\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nfrom torchvision"
},
{
"path": "dataLoader/your_own_data.py",
"chars": 5026,
"preview": "import torch,cv2\nfrom torch.utils.data import Dataset\nimport json\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\n"
},
{
"path": "extra/auto_run_paramsets.py",
"chars": 8274,
"preview": "import os\nimport threading, queue\nimport numpy as np\nimport time\n\n\ndef getFolderLocker(logFolder):\n while True:\n "
},
{
"path": "extra/compute_metrics.py",
"chars": 6532,
"preview": "import os, math\nimport numpy as np\nimport scipy.signal\nfrom typing import List, Optional\nfrom PIL import Image\nimport os"
},
{
"path": "models/VGG.py",
"chars": 6818,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom collections import namedtuple\nimport torchvision"
},
{
"path": "models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "models/sh.py",
"chars": 5231,
"preview": "import torch\n\n################## sh function ##################\nC0 = 0.28209479177387814\nC1 = 0.4886025119029199\nC2 = [\n"
},
{
"path": "models/styleModules.py",
"chars": 10046,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport random\n\ndef calc_mean_std(x, eps=1e-8):\n "
},
{
"path": "models/tensoRF.py",
"chars": 16669,
"preview": "from .tensorBase import *\nfrom .VGG import Encoder, Decoder, UNetDecoder, PlainDecoder\nfrom .styleModules import Learnab"
},
{
"path": "models/tensorBase.py",
"chars": 18150,
"preview": "import torch\nimport torch.nn\nimport torch.nn.functional as F\nfrom .sh import eval_sh_bases\nimport numpy as np\nimport tim"
},
{
"path": "opt.py",
"chars": 7633,
"preview": "import configargparse\n\ndef config_parser(cmd=None):\n parser = configargparse.ArgumentParser()\n parser.add_argument"
},
{
"path": "renderer.py",
"chars": 14465,
"preview": "import torch,os,imageio,sys\nfrom tqdm.auto import tqdm\nfrom dataLoader.ray_utils import get_rays\nfrom models.tensoRF imp"
},
{
"path": "scripts/test.sh",
"chars": 130,
"preview": "CUDA_VISIBLE_DEVICES=$1 python train.py \\\n--config configs/llff.txt \\\n--ckpt log/trex/trex.th \\ \n--render_only 1 \n--rend"
},
{
"path": "scripts/test_feature.sh",
"chars": 274,
"preview": "expname=trex\nCUDA_VISIBLE_DEVICES=$1 python train_feature.py \\\n--config configs/llff_feature.txt \\\n--datadir ./data/nerf"
},
{
"path": "scripts/test_style.sh",
"chars": 362,
"preview": "expname=trex\nCUDA_VISIBLE_DEVICES=$1 python train_style.py \\\n--config configs/llff_style.txt \\\n--datadir ./data/nerf_llf"
},
{
"path": "scripts/train.sh",
"chars": 65,
"preview": "CUDA_VISIBLE_DEVICES=$1 python train.py --config=configs/llff.txt"
},
{
"path": "scripts/train_feature.sh",
"chars": 81,
"preview": "CUDA_VISIBLE_DEVICES=$1 python train_feature.py --config=configs/llff_feature.txt"
},
{
"path": "scripts/train_style.sh",
"chars": 77,
"preview": "CUDA_VISIBLE_DEVICES=$1 python train_style.py --config=configs/llff_style.txt"
},
{
"path": "train.py",
"chars": 12844,
"preview": "\nimport os\nfrom tqdm.auto import tqdm\nfrom opt import config_parser\n\n\n\nimport json, random\nfrom renderer import *\nfrom u"
},
{
"path": "train_feature.py",
"chars": 11107,
"preview": "\nimport os\nfrom unittest.mock import patch\nfrom tqdm.auto import tqdm\nfrom opt import config_parser\n\n\n\nimport json, rand"
},
{
"path": "train_style.py",
"chars": 12355,
"preview": "\nimport os\nfrom tqdm.auto import tqdm\nfrom opt import config_parser\nfrom PIL import Image, ImageFile\nfrom pathlib import"
},
{
"path": "utils.py",
"chars": 11395,
"preview": "import cv2,torch\nimport numpy as np\nfrom PIL import Image\nimport torchvision.transforms as T\nimport torch.nn.functional "
}
]
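The index above is a JSON array. Each object records a file's `path`, its size in characters (`chars`), and a truncated `preview` of its contents. The sketch below tallies those sizes per top-level directory. It assumes the full index parses as standard JSON. The names `chars_per_directory` and `SAMPLE_INDEX` are hypothetical, and the sample reuses a handful of entries listed above with their previews dropped.

```python
import json
from collections import defaultdict

# A few entries copied from the index above (previews omitted). The full index is
# assumed to be a JSON array of objects with "path", "chars" and "preview" keys.
SAMPLE_INDEX = """[
  {"path": "scripts/test.sh", "chars": 130},
  {"path": "scripts/test_feature.sh", "chars": 274},
  {"path": "scripts/test_style.sh", "chars": 362},
  {"path": "scripts/train.sh", "chars": 65},
  {"path": "scripts/train_feature.sh", "chars": 81},
  {"path": "scripts/train_style.sh", "chars": 77},
  {"path": "models/__init__.py", "chars": 0},
  {"path": "opt.py", "chars": 7633}
]"""


def chars_per_directory(index_json: str) -> dict:
    """Sum the "chars" field of every entry, grouped by top-level directory.

    Files that sit at the repository root are grouped under ".".
    """
    totals = defaultdict(int)
    for entry in json.loads(index_json):
        parts = entry["path"].split("/")
        top = parts[0] if len(parts) > 1 else "."
        totals[top] += entry["chars"]
    return dict(totals)


if __name__ == "__main__":
    for directory, total in sorted(chars_per_directory(SAMPLE_INDEX).items()):
        print(f"{directory}: {total} chars")
```

Grouping by the first path component keeps the tally independent of nesting depth. Empty files such as `models/__init__.py` still appear, with a total of 0.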
About this extraction
This page contains the full source code of the Kunhao-Liu/StyleRF GitHub repository, extracted and formatted as plain text. The extraction includes 36 files (205.9 KB), approximately 60.2k tokens, and a symbol index with 246 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.