Repository: Kunhao-Liu/StyleRF
Branch: main
Commit: 43611c3ec7ba
Files: 36
Total size: 205.9 KB

Directory structure:
gitextract_j81cq_2h/
├── README.md
├── configs/
│   ├── llff.txt
│   ├── llff_feature.txt
│   ├── llff_style.txt
│   ├── nerf_synthetic.txt
│   ├── nerf_synthetic_feature.txt
│   └── nerf_synthetic_style.txt
├── dataLoader/
│   ├── __init__.py
│   ├── blender.py
│   ├── colmap2nerf.py
│   ├── llff.py
│   ├── nsvf.py
│   ├── ray_utils.py
│   ├── styleLoader.py
│   ├── tankstemple.py
│   └── your_own_data.py
├── extra/
│   ├── auto_run_paramsets.py
│   └── compute_metrics.py
├── models/
│   ├── VGG.py
│   ├── __init__.py
│   ├── sh.py
│   ├── styleModules.py
│   ├── tensoRF.py
│   └── tensorBase.py
├── opt.py
├── renderer.py
├── scripts/
│   ├── test.sh
│   ├── test_feature.sh
│   ├── test_style.sh
│   ├── train.sh
│   ├── train_feature.sh
│   └── train_style.sh
├── train.py
├── train_feature.py
├── train_style.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# [*CVPR 2023*] StyleRF: Zero-shot 3D Style Transfer of Neural Radiance Fields
## [Project page](https://kunhao-liu.github.io/StyleRF/) | [Paper](https://arxiv.org/abs/2303.10598)
This repository contains a PyTorch implementation of the paper [StyleRF: Zero-shot 3D Style Transfer of Neural Radiance Fields](https://arxiv.org/abs/2303.10598). StyleRF is a 3D style transfer technique that achieves superior 3D stylization quality with precise geometry reconstruction, and it can generalize to new styles in a zero-shot manner.
![teaser](https://kunhao-liu.github.io/StyleRF/resources/teaser.png)

---
## Installation

> Tested on Ubuntu 20.04 + PyTorch 1.12.1

Install environment:
```
conda create -n StyleRF python=3.9
conda activate StyleRF
pip install torch torchvision
pip install tqdm scikit-image opencv-python configargparse lpips imageio-ffmpeg kornia tensorboard
```

## Datasets
Please put the datasets in `./data`. You can put them elsewhere if you modify the corresponding paths in the configs.

### 3D scene datasets
* [nerf_synthetic](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
* [llff](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)

### Style image dataset
* [WikiArt](https://www.kaggle.com/datasets/ipythonx/wikiart-gangogh-creating-art-gan)

## Quick Start
We provide some trained checkpoints here: [StyleRF checkpoints](https://drive.google.com/drive/folders/1nF9-6lTIhktG5JjNvnmdYOo1LTvtK7Dw?usp=share_link)

Then modify the following attributes in `scripts/test_style.sh`:
* `--config`: choose `configs/llff_style.txt` or `configs/nerf_synthetic_style.txt` according to the type of dataset being used
* `--datadir`: dataset's path
* `--ckpt`: checkpoint's path
* `--style_img`: reference style image's path

To generate stylized novel views:
```
bash scripts/test_style.sh [GPU ID]
```
The rendered stylized images can then be found in the directory under the checkpoint's path.

## Training
> Current settings in `configs` are tested on one NVIDIA RTX A5000 graphics card with 24 GB of memory. To reduce memory consumption, you can set `batch_size`, `chunk_size` or `patch_size` to a smaller number.

Training consists of the following 3 steps:

### 1. Train original TensoRF
This step reconstructs the density field, which contains more precise geometry details compared to mesh-based methods.
You can skip this step by directly downloading the pre-trained checkpoints provided by [TensoRF checkpoints](https://1drv.ms/u/s!Ard0t_p4QWIMgQ2qSEAs7MUk8hVw?e=dc6hBm).

The configs are stored in `configs/llff.txt` and `configs/nerf_synthetic.txt`. For the details of the settings, please also refer to [TensoRF](https://github.com/apchenstu/TensoRF). The checkpoints are stored in `./log` by default.

You can train the original TensoRF by:
```
bash scripts/train.sh [GPU ID]
```

### 2. Feature grid training stage
This step reconstructs the 3D grid containing the VGG features.

The configs are stored in `configs/llff_feature.txt` and `configs/nerf_synthetic_feature.txt`, in which `ckpt` specifies the checkpoints trained in the **first** step. The checkpoints are stored in `./log_feature` by default.

Then run:
```
bash scripts/train_feature.sh [GPU ID]
```

### 3. Stylization training stage
This step trains the style transfer modules.

The configs are stored in `configs/llff_style.txt` and `configs/nerf_synthetic_style.txt`, in which `ckpt` specifies the checkpoints trained in the **second** step. The checkpoints are stored in `./log_style` by default.

Then run:
```
bash scripts/train_style.sh [GPU ID]
```

---
## Training on 360 Unbounded Scenes
The code for training StyleRF on the Tanks&Temples dataset is available on the `360` branch. To access it, run `git checkout 360`.

## Acknowledgments
This repo is heavily based on [TensoRF](https://github.com/apchenstu/TensoRF). We thank its authors for sharing their amazing work!
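
The grid-size settings in `configs` (`N_voxel_init`, `N_voxel_final`, `upsamp_list`) are inherited from TensoRF. As a rough illustration, and assuming StyleRF keeps TensoRF's coarse-to-fine scheme (the training code is not reproduced in this README), the total voxel count grows log-linearly from `N_voxel_init` to `N_voxel_final`, with one new target per iteration listed in `upsamp_list`, and each count is converted into a per-axis resolution for the scene bounding box. The helpers `voxel_schedule` and `n_to_reso` are hypothetical names used only for this sketch, not functions from this repository.

```python
import math


def voxel_schedule(n_init, n_final, upsamp_list):
    """Target voxel counts after each upsampling step (sketch, assumes a log-spaced schedule)."""
    if not upsamp_list:
        return []
    steps = len(upsamp_list) + 1
    lo, hi = math.log(n_init), math.log(n_final)
    # Log-linear interpolation from n_init to n_final, both ends included.
    counts = [round(math.exp(lo + (hi - lo) * i / (steps - 1))) for i in range(steps)]
    # The first entry is the initial grid; the rest are the upsampling targets.
    return counts[1:]


def n_to_reso(n_voxels, bbox_min, bbox_max):
    """Per-axis resolution giving roughly n_voxels cubic voxels inside the bbox."""
    extents = [hi - lo for lo, hi in zip(bbox_min, bbox_max)]
    voxel_size = (math.prod(extents) / n_voxels) ** (1 / len(extents))
    # Truncate toward zero, mirroring an integer cast of the tensor.
    return [int(e / voxel_size) for e in extents]


if __name__ == "__main__":
    # Schedule values from configs/llff.txt.
    print(voxel_schedule(2097156, 262144000, [2000, 3000, 4000, 5500]))
    # Initial grid for the BlenderDataset scene box [-1.5, 1.5]^3.
    print(n_to_reso(2097156, [-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]))  # [128, 128, 128]
```

With the `configs/llff.txt` values, the last target equals `N_voxel_final`, and `N_voxel_init = 2097156` inside the Blender scene box gives 128 voxels per axis.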
## Citation
If you find our code or paper helpful, please consider citing:
```
@inproceedings{liu2023stylerf,
  title={StyleRF: Zero-shot 3D Style Transfer of Neural Radiance Fields},
  author={Liu, Kunhao and Zhan, Fangneng and Chen, Yiwen and Zhang, Jiahui and Yu, Yingchen and El Saddik, Abdulmotaleb and Lu, Shijian and Xing, Eric P},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={8338--8348},
  year={2023}
}
```

================================================
FILE: configs/llff.txt
================================================
dataset_name = llff
datadir = ./data/nerf_llff_data/trex
expname = trex
basedir = ./log
downsample_train = 4.0
ndc_ray = 1
n_iters = 25000
batch_size = 4096
N_voxel_init = 2097156 # 128**3
N_voxel_final = 262144000 # 640**3
upsamp_list = [2000,3000,4000,5500]
update_AlphaMask_list = [2500]
N_vis = -1 # vis all testing images
vis_every = 10000
render_test = 1
render_path = 1
n_lamb_sigma = [16,4,4]
n_lamb_sh = [48,12,12]
shadingMode = MLP_Fea
fea2denseAct = relu
view_pe = 0
fea_pe = 0
TV_weight_density = 1.0
TV_weight_app = 1.0

================================================
FILE: configs/llff_feature.txt
================================================
dataset_name = llff
datadir = ./data/nerf_llff_data/trex
ckpt = ./log/trex/trex.th
expname = trex
basedir = ./log_feature
TV_weight_feature = 80
downsample_train = 4.0
ndc_ray = 1
n_iters = 25000
patch_size = 256
batch_size = 4096
chunk_size = 4096
N_voxel_init = 2097156 # 128**3
N_voxel_final = 262144000 # 640**3
upsamp_list = [2000,3000,4000,5500]
update_AlphaMask_list = [2500]
n_lamb_sigma = [16,4,4]
n_lamb_sh = [48,12,12]
fea2denseAct = relu

================================================
FILE: configs/llff_style.txt
================================================
dataset_name = llff
datadir = ./data/nerf_llff_data/trex
ckpt = ./log_feature/trex/trex.th
expname = trex
basedir = ./log_style
nSamples = 300
patch_size = 256
chunk_size = 2048
content_weight = 1
style_weight = 20
featuremap_tv_weight = 0
image_tv_weight = 0
rm_weight_mask_thre = 0.001
downsample_train = 4.0
ndc_ray = 1
n_iters = 25000
n_lamb_sigma = [16,4,4]
n_lamb_sh = [48,12,12]
N_voxel_init = 2097156 # 128**3
N_voxel_final = 262144000 # 640**3
fea2denseAct = relu

================================================
FILE: configs/nerf_synthetic.txt
================================================
dataset_name = blender
datadir = ./data/nerf_synthetic/lego
expname = lego
basedir = ./log
n_iters = 30000
batch_size = 4096
N_voxel_init = 2097156 # 128**3
N_voxel_final = 27000000 # 300**3
upsamp_list = [2000,3000,4000,5500,7000]
update_AlphaMask_list = [2000,4000]
N_vis = 5
vis_every = 10000
render_test = 1
n_lamb_sigma = [16,16,16]
n_lamb_sh = [48,48,48]
model_name = TensorVMSplit
shadingMode = MLP_Fea
fea2denseAct = softplus
view_pe = 2
fea_pe = 2
L1_weight_inital = 8e-5
L1_weight_rest = 4e-5
rm_weight_mask_thre = 1e-4

================================================
FILE: configs/nerf_synthetic_feature.txt
================================================
dataset_name = blender
datadir = ./data/nerf_synthetic/lego
ckpt = ./log/lego/lego.th
expname = lego
basedir = ./log_feature
TV_weight_feature = 10
n_iters = 25000
patch_size = 256
batch_size = 4096
chunk_size = 4096
N_voxel_init = 2097156 # 128**3
N_voxel_final = 27000000 # 300**3
upsamp_list = [2000,3000,4000,5500,7000]
update_AlphaMask_list = [2000,4000]
rm_weight_mask_thre = 0.01
n_lamb_sigma = [16,16,16]
n_lamb_sh = [48,48,48]
fea2denseAct = softplus

================================================
FILE: configs/nerf_synthetic_style.txt
================================================
dataset_name = blender
datadir = ./data/nerf_synthetic/lego
ckpt = ./log_feature/lego/lego.th
expname = lego
basedir = ./log_style
patch_size = 256
chunk_size = 2048
content_weight = 1
style_weight = 20
rm_weight_mask_thre = 0.01
n_iters = 25000
n_lamb_sigma = [16,16,16]
n_lamb_sh = [48,48,48]
N_voxel_init = 2097156 # 128**3
N_voxel_final = 27000000 # 300**3
fea2denseAct = softplus

================================================
FILE: dataLoader/__init__.py
================================================
from .llff import LLFFDataset
from .blender import BlenderDataset
from .nsvf import NSVF
from .tankstemple import TanksTempleDataset
from .your_own_data import YourOwnDataset


dataset_dict = {'blender': BlenderDataset,
                'llff': LLFFDataset,
                'tankstemple': TanksTempleDataset,
                'nsvf': NSVF,
                'own_data': YourOwnDataset}

================================================
FILE: dataLoader/blender.py
================================================
import torch,cv2
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T

from .ray_utils import *


class BlenderDataset(Dataset):
    def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):

        self.N_vis = N_vis
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        # set before read_meta() so that images are actually resized to img_wh
        self.downsample = downsample
        self.img_wh = (int(800/downsample),int(800/downsample))
        self.define_transforms()

        self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()

        self.white_bg = True
        self.near_far = [2.0,6.0]

        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def read_depth(self, filename):
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)
        return depth

    def read_meta(self):

        with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
            self.meta = json.load(f)

        w, h = self.img_wh
        self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x'])  # original focal length
        self.focal *= self.img_wh[0] / 800  # modify focal length to match size self.img_wh

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, [self.focal,self.focal])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        self.intrinsics = torch.tensor([[self.focal,0,w/2],[0,self.focal,h/2],[0,0,1]]).float()

        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []
        self.all_depth = []

        img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
        idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#

            frame = self.meta['frames'][i]
            pose = np.array(frame['transform_matrix']) @ self.blender2opencv
            c2w = torch.FloatTensor(pose)
            self.poses += [c2w]

            image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
            self.image_paths += [image_path]
            img = Image.open(image_path)

            if self.downsample!=1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(4, -1).permute(1, 0)  # (h*w, 4) RGBA
            self.all_masks.append(img[:, -1:].reshape(h,w,1)) # (h, w, 1) A
            img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]

            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        self.all_masks = torch.stack(self.all_masks) # (n_frames, h, w, 1)
        self.poses = torch.stack(self.poses)
        all_rays = self.all_rays
        all_rgbs = self.all_rgbs

        self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w,6)
        self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w,3)

        if self.is_stack:
            self.all_rays_stack = torch.stack(all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (len(self.meta['frames]),h,w,6)
            avg_pool = torch.nn.AvgPool2d(4, ceil_mode=True)
            self.ds_all_rays_stack = avg_pool(self.all_rays_stack.permute(0,3,1,2)).permute(0,2,3,1) # (len(self.meta['frames]),h/4,w/4,6)
            self.all_rgbs_stack = torch.stack(all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)

    @torch.no_grad()
    def prepare_feature_data(self, encoder, chunk=8):
        '''
        Prepare feature maps as training data.
        '''
        assert self.is_stack, 'Dataset should contain original stacked training data!'
        print('====> prepare_feature_data ...')

        frames_num, h, w, _ = self.all_rgbs_stack.size()
        features = []

        for chunk_idx in range(frames_num // chunk + int(frames_num % chunk > 0)):
            rgbs_chunk = self.all_rgbs_stack[chunk_idx*chunk : (chunk_idx+1)*chunk].cuda()
            features_chunk = encoder(normalize_vgg(rgbs_chunk.permute(0,3,1,2))).relu3_1
            # resize to the size of rgb map so that rays can match
            features_chunk = T.functional.resize(features_chunk, size=(h,w), interpolation=T.InterpolationMode.BILINEAR)
            features.append(features_chunk.detach().cpu().requires_grad_(False))

        self.all_features_stack = torch.cat(features).permute(0,2,3,1) # (len(self.meta['frames]),h,w,256)
        self.all_features = self.all_features_stack.reshape(-1, 256)
        print('prepare_feature_data Done!')

    def define_transforms(self):
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]

    def world2ndc(self,points,lindisp=None):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}

        else:  # create data for each image separately

            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]
            mask = self.all_masks[idx] # for quantity evaluation

            sample = {'rays': rays,
                      'rgbs': img,
                      'mask': mask}
        return sample

================================================
FILE: dataLoader/colmap2nerf.py
================================================
#!/usr/bin/env python3

# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import argparse
import os
from pathlib import Path, PurePosixPath

import numpy as np
import json
import sys
import math
import cv2
import os
import shutil

def parse_args():
    parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place")

    parser.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also")
    parser.add_argument("--video_fps", default=2)
    parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video")
    parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder")
    parser.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images")
    parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename")
    parser.add_argument("--images", default="images", help="input path to the images")
    parser.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)")
    parser.add_argument("--aabb_scale", default=16, choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16")
    parser.add_argument("--skip_early", default=0, help="skip this many images from the start")
    parser.add_argument("--out", default="transforms.json", help="output path")
    args = parser.parse_args()
    return args

def do_system(arg):
    print(f"==== running: {arg}")
    err = os.system(arg)
    if err:
        print("FATAL: command failed")
        sys.exit(err)

def run_ffmpeg(args):
    if not os.path.isabs(args.images):
        args.images = os.path.join(os.path.dirname(args.video_in), args.images)
    images = args.images
    video = args.video_in
    fps = float(args.video_fps) or 1.0
    print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.")
    if (input(f"warning! folder '{images}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
        sys.exit(1)
    try:
        shutil.rmtree(images)
    except:
        pass
    do_system(f"mkdir {images}")

    time_slice_value = ""
    time_slice = args.time_slice
    if time_slice:
        start, end = time_slice.split(",")
        # ffmpeg needs the commas inside between() escaped as "\,"
        time_slice_value = f",select='between(t\\,{start}\\,{end})'"
    do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg")

def run_colmap(args):
    db=args.colmap_db
    images=args.images
    db_noext=str(Path(db).with_suffix(""))

    if args.text=="text":
        args.text=db_noext+"_text"
    text=args.text
    sparse=db_noext+"_sparse"
    print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}")
    if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
        sys.exit(1)
    if os.path.exists(db):
        os.remove(db)
    do_system(f"colmap feature_extractor --ImageReader.camera_model OPENCV --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}")
    do_system(f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}")
    try:
        shutil.rmtree(sparse)
    except:
        pass
    do_system(f"mkdir {sparse}")
    do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}")
    do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1")
    try:
        shutil.rmtree(text)
    except:
        pass
    do_system(f"mkdir {text}")
    do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT")

def variance_of_laplacian(image):
    return cv2.Laplacian(image, cv2.CV_64F).var()

def sharpness(imagePath):
    image = cv2.imread(imagePath)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    fm = variance_of_laplacian(gray)
    return fm

def qvec2rotmat(qvec):
    return np.array([
        [
            1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
            2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
            2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]
        ], [
            2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
            1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
            2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]
        ], [
            2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
            2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
            1 - 2 * qvec[1]**2 - 2 * qvec[2]**2
        ]
    ])

def rotmat(a, b):
    a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)
    v = np.cross(a, b)
    c = np.dot(a, b)
    s = np.linalg.norm(v)
    kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
    return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10))

def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel
    da = da / np.linalg.norm(da)
    db = db / np.linalg.norm(db)
    c = np.cross(da, db)
    denom = np.linalg.norm(c)**2
    t = ob - oa
    ta = np.linalg.det([t, db, c]) / (denom + 1e-10)
    tb = np.linalg.det([t, da, c]) / (denom + 1e-10)
    if ta > 0:
        ta = 0
    if tb > 0:
        tb = 0
    return (oa+ta*da+ob+tb*db) * 0.5, denom

if __name__ == "__main__":
    args = parse_args()
    if args.video_in != "":
        run_ffmpeg(args)
    if args.run_colmap:
        run_colmap(args)
    AABB_SCALE = int(args.aabb_scale)
    SKIP_EARLY = int(args.skip_early)
    IMAGE_FOLDER = args.images
    TEXT_FOLDER = args.text
    OUT_PATH = args.out
    print(f"outputting to {OUT_PATH}...")
    with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f:
        angle_x = math.pi / 2
        for line in f:
            # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691
            # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224
            # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443
            if line[0] == "#":
                continue
            els = line.split(" ")
            w = float(els[2])
            h = float(els[3])
            fl_x = float(els[4])
            fl_y = float(els[4])
            k1 = 0
            k2 = 0
            p1 = 0
            p2 = 0
            cx = w / 2
            cy = h / 2
            if els[1] == "SIMPLE_PINHOLE":
                cx = float(els[5])
                cy = float(els[6])
            elif els[1] == "PINHOLE":
                fl_y = float(els[5])
                cx = float(els[6])
                cy = float(els[7])
            elif els[1] == "SIMPLE_RADIAL":
                cx = float(els[5])
                cy = float(els[6])
                k1 = float(els[7])
            elif els[1] == "RADIAL":
                cx = float(els[5])
                cy = float(els[6])
                k1 = float(els[7])
                k2 = float(els[8])
            elif els[1] == "OPENCV":
                fl_y = float(els[5])
                cx = float(els[6])
                cy = float(els[7])
                k1 = float(els[8])
                k2 = float(els[9])
                p1 = float(els[10])
                p2 = float(els[11])
            else:
                print("unknown camera model ", els[1])
            # fl = 0.5 * w / tan(0.5 * angle_x);
            angle_x = math.atan(w / (fl_x * 2)) * 2
            angle_y = math.atan(h / (fl_y * 2)) * 2
            fovx = angle_x * 180 / math.pi
            fovy = angle_y * 180 / math.pi

    print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ")

    with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f:
        i = 0
        bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4])
        out = {
            "camera_angle_x": angle_x,
            "camera_angle_y": angle_y,
            "fl_x": fl_x,
            "fl_y": fl_y,
            "k1": k1,
            "k2": k2,
            "p1": p1,
            "p2": p2,
            "cx": cx,
            "cy": cy,
            "w": w,
            "h": h,
            "aabb_scale": AABB_SCALE,
            "frames": [],
        }

        up = np.zeros(3)
        for line in f:
            line = line.strip()
            if line[0] == "#":
                continue
            i = i + 1
            if i < SKIP_EARLY*2:
                continue
            if i % 2 == 1:
                elems=line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces)
                #name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9])))
                # why is this requiring a relative path while using ^
                image_rel = os.path.relpath(IMAGE_FOLDER)
                name = str(f"./{image_rel}/{'_'.join(elems[9:])}")
                b=sharpness(name)
                print(name, "sharpness=",b)
                image_id = int(elems[0])
                qvec = np.array(tuple(map(float, elems[1:5])))
                tvec = np.array(tuple(map(float, elems[5:8])))
                R = qvec2rotmat(-qvec)
                t = tvec.reshape([3,1])
                m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
                c2w = np.linalg.inv(m)
                c2w[0:3,2] *= -1 # flip the y and z axis
                c2w[0:3,1] *= -1
                c2w = c2w[[1,0,2,3],:] # swap y and z
                c2w[2,:] *= -1 # flip whole world upside down

                up += c2w[0:3,1]

                frame={"file_path":name,"sharpness":b,"transform_matrix": c2w}
                out["frames"].append(frame)
    nframes = len(out["frames"])
    up = up / np.linalg.norm(up)
    print("up vector was", up)
    R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1]
    R = np.pad(R,[0,1])
    R[-1, -1] = 1

    for f in out["frames"]:
        f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis

    # find a central point they are all looking at
    print("computing center of attention...")
    totw = 0.0
    totp = np.array([0.0, 0.0, 0.0])
    for f in out["frames"]:
        mf = f["transform_matrix"][0:3,:]
        for g in out["frames"]:
            mg = g["transform_matrix"][0:3,:]
            p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2])
            if w > 0.01:
                totp += p*w
                totw += w
    totp /= totw
    print(totp) # the cameras are looking at totp
    for f in out["frames"]:
        f["transform_matrix"][0:3,3] -= totp

    avglen = 0.
    for f in out["frames"]:
        avglen += np.linalg.norm(f["transform_matrix"][0:3,3])
    avglen /= nframes
    print("avg camera distance from origin", avglen)
    for f in out["frames"]:
        f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized"

    for f in out["frames"]:
        f["transform_matrix"] = f["transform_matrix"].tolist()
    print(nframes,"frames")
    print(f"writing {OUT_PATH}")
    with open(OUT_PATH, "w") as outfile:
        json.dump(out, outfile, indent=2)

================================================
FILE: dataLoader/llff.py
================================================
import torch
from torch.utils.data import Dataset
import glob
import numpy as np
import os
from PIL import Image
from torchvision import transforms as T

from .ray_utils import *


def normalize(v):
    """Normalize a vector."""
    return v / np.linalg.norm(v)


def average_poses(poses):
    """
    Calculate the average pose, which is then used to center all poses
    using @center_poses. Its computation is as follows:
    1. Compute the center: the average of pose centers.
    2. Compute the z axis: the normalized average z axis.
    3. Compute axis y': the average y axis.
    4. Compute x = z cross product y', then normalize it as the x axis.
    5. Compute the y axis: x cross product z.

    Note that at step 3, we cannot directly use y' as y axis since it's
    not necessarily orthogonal to z axis. We need to pass from x to y.
    Inputs:
        poses: (N_images, 3, 4)
    Outputs:
        pose_avg: (3, 4) the average pose
    """
    # 1. Compute the center
    center = poses[..., 3].mean(0)  # (3)

    # 2. Compute the z axis
    z = normalize(poses[..., 2].mean(0))  # (3)

    # 3. Compute axis y' (no need to normalize as it's not the final output)
    y_ = poses[..., 1].mean(0)  # (3)

    # 4. Compute the x axis
    x = normalize(np.cross(z, y_))  # (3)

    # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
    y = np.cross(x, z)  # (3)

    pose_avg = np.stack([x, y, z, center], 1)  # (3, 4)

    return pose_avg


def center_poses(poses, blender2opencv):
    """
    Center the poses so that we can use NDC.
    See https://github.com/bmild/nerf/issues/34
    Inputs:
        poses: (N_images, 3, 4)
    Outputs:
        poses_centered: (N_images, 3, 4) the centered poses
        pose_avg_homo: (4, 4) the average pose in homogeneous coordinates
    """
    poses = poses @ blender2opencv
    pose_avg = average_poses(poses)  # (3, 4)
    pose_avg_homo = np.eye(4)
    pose_avg_homo[:3] = pose_avg  # convert to homogeneous coordinate for faster computation
    pose_avg_homo = pose_avg_homo
    # by simply adding 0, 0, 0, 1 as the last row
    last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1))  # (N_images, 1, 4)
    poses_homo = \
        np.concatenate([poses, last_row], 1)  # (N_images, 4, 4) homogeneous coordinate

    poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo  # (N_images, 4, 4)
    # poses_centered = poses_centered @ blender2opencv
    poses_centered = poses_centered[:, :3]  # (N_images, 3, 4)

    return poses_centered, pose_avg_homo


def viewmatrix(z, up, pos):
    vec2 = normalize(z)
    vec1_avg = up
    vec0 = normalize(np.cross(vec1_avg, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    m = np.eye(4)
    m[:3] = np.stack([-vec0, vec1, vec2, pos], 1)
    return m


def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
    render_poses = []
    rads = np.array(list(rads) + [1.])

    for theta in np.linspace(0., 2. * np.pi * N_rots, N + 1)[:-1]:
        c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)
        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
        render_poses.append(viewmatrix(z, up, c))
    return render_poses


def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
    # center pose
    c2w = average_poses(c2ws_all)

    # Get the average up vector
    up = normalize(c2ws_all[:, :3, 1].sum(0))

    # Find a reasonable "focus depth" for this dataset
    dt = 0.75
    close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0
    focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth))

    # Get radii for spiral path
    zdelta = near_fars.min() * .2
    tt = c2ws_all[:, :3, 3]
    rads = np.percentile(np.abs(tt), 90, 0) * rads_scale
    render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)
    return np.stack(render_poses)


def get_interpolation_path(c2ws_all, steps=30):
    # flower
    # idx0 = 1
    # idx1 = 10
    # trex
    # idx0 = 8
    # idx1 = 53
    # horns
    idx0 = 18
    idx1 = 47
    v = np.linspace(0,1,num=steps)
    c2w0 = c2ws_all[idx0]
    c2w1 = c2ws_all[idx1]
    c2w_ = []
    for i in range(steps):
        c2w_.append(c2w0*v[i] + c2w1*(1-v[i]))
    return np.stack(c2w_)


class LLFFDataset(Dataset):
    def __init__(self, datadir, split='train', downsample=4, is_stack=False, hold_every=8):

        self.root_dir = datadir
        self.split = split
        self.hold_every = hold_every
        self.is_stack = is_stack
        self.downsample = downsample
        self.define_transforms()

        self.blender2opencv = np.eye(4)#np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.white_bg = False

        # self.near_far = [np.min(self.near_fars[:,0]),np.max(self.near_fars[:,1])]
        self.near_far = [0.0, 1.0]
        self.scene_bbox = torch.tensor([[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]])
        # self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]])
        self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
        self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def read_meta(self):

        poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy'))  # (N_images, 17)
        self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*')))  # pre-downsampled images; resized to img_wh when loaded
        if self.split in ['train', 'test']:
            assert len(poses_bounds) == len(self.image_paths), \
                'Mismatch between number of images and number of poses! Please rerun COLMAP!'

        poses = poses_bounds[:, :15].reshape(-1, 3, 5)  # (N_images, 3, 5)
        self.near_fars = poses_bounds[:, -2:]  # (N_images, 2)
        hwf = poses[:, :, -1]

        # Step 1: rescale focal length according to training resolution
        H, W, self.focal = poses[0, :, -1]  # original intrinsics, same for all images
        self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)])
        self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H]

        # Step 2: correct poses
        # Original poses has rotation in form "down right back", change to "right up back"
        # See https://github.com/bmild/nerf/issues/34
        poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)  # (N_images, 3, 4) exclude H, W, focal
        self.poses, self.pose_avg = center_poses(poses, self.blender2opencv)

        # Step 3: correct scale so that the nearest depth is at a little more than 1.0
        # See https://github.com/bmild/nerf/issues/34
        near_original = self.near_fars.min()
        scale_factor = near_original * 0.75  # 0.75 is the default parameter
        # the nearest depth is at 1/0.75=1.33
        self.near_fars /= scale_factor
        self.poses[..., 3] /= scale_factor

        # build rendering path
        N_views, N_rots = 120, 2
        tt = self.poses[:, :3, 3]  # ptstocam(poses[:3,3,:].T, c2w).T
        up = normalize(self.poses[:, :3, 1].sum(0))
        rads = np.percentile(np.abs(tt), 90, 0)

        self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views)
        # self.render_path = get_interpolation_path(self.poses)

        # distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1)
        # val_idx = np.argmin(distances_from_center)  # choose val image as the closest to
        # center image
        # ray directions for all pixels, same for all images (same H, W, focal)
        W, H = self.img_wh
        self.directions = get_ray_directions_blender(H, W, self.focal)  # (H, W, 3)

        average_pose = average_poses(self.poses)
        dists = np.sum(np.square(average_pose[:3, 3] - self.poses[:, :3, 3]), -1)
        i_test = np.arange(0, self.poses.shape[0], self.hold_every)  # [np.argmin(dists)]
        # hold out every `hold_every`-th image for testing; the rest are used for training
        img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test))

        self.all_rays = []
        self.all_rgbs = []
        for i in img_list:
            image_path = self.image_paths[i]
            c2w = torch.FloatTensor(self.poses[i])

            img = Image.open(image_path).convert('RGB')
            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (3, h, w)

            img = img.view(3, -1).permute(1, 0)  # (h*w, 3) RGB
            self.all_rgbs += [img]
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d)
            # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)

            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        all_rays = self.all_rays
        all_rgbs = self.all_rgbs

        self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w,6)
        self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w,3)

        if self.is_stack:
            self.all_rays_stack = torch.stack(all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (len(self.meta['frames]),h,w,6)
            avg_pool = torch.nn.AvgPool2d(4, ceil_mode=True)
            self.ds_all_rays_stack = avg_pool(self.all_rays_stack.permute(0,3,1,2)).permute(0,2,3,1) # (len(self.meta['frames]),h/4,w/4,6)
            self.all_rgbs_stack = torch.stack(all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)

    @torch.no_grad()
    def prepare_feature_data(self, encoder, chunk=8):
        '''
        Prepare feature maps as training data.
        '''
        assert self.is_stack, 'Dataset should contain original stacked training data!'
        print('====> prepare_feature_data ...')
        frames_num, h, w, _ = self.all_rgbs_stack.size()
        features = []
        for chunk_idx in range(frames_num // chunk + int(frames_num % chunk > 0)):
            rgbs_chunk = self.all_rgbs_stack[chunk_idx*chunk : (chunk_idx+1)*chunk].cuda()
            features_chunk = encoder(normalize_vgg(rgbs_chunk.permute(0,3,1,2))).relu3_1
            # resize to the size of rgb map so that rays can match
            features_chunk = T.functional.resize(features_chunk, size=(h,w),
                                                 interpolation=T.InterpolationMode.BILINEAR)
            features.append(features_chunk.detach().cpu().requires_grad_(False))

        self.all_features_stack = torch.cat(features).permute(0,2,3,1)  # (len(self.meta['frames]),h,w,256)
        self.all_features = self.all_features_stack.reshape(-1, 256)
        print('prepare_feature_data Done!')

    def define_transforms(self):
        self.transform = T.ToTensor()

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        sample = {'rays': self.all_rays[idx],
                  'rgbs': self.all_rgbs[idx]}

        return sample

================================================
FILE: dataLoader/nsvf.py
================================================
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T

from .ray_utils import *

trans_t = lambda t : torch.Tensor([
    [1,0,0,0],
    [0,1,0,0],
    [0,0,1,t],
    [0,0,0,1]]).float()

rot_phi = lambda phi : torch.Tensor([
    [1,0,0,0],
    [0,np.cos(phi),-np.sin(phi),0],
    [0,np.sin(phi), np.cos(phi),0],
    [0,0,0,1]]).float()

rot_theta = lambda th : torch.Tensor([
    [np.cos(th),0,-np.sin(th),0],
    [0,1,0,0],
    [np.sin(th),0, np.cos(th),0],
    [0,0,0,1]]).float()


def pose_spherical(theta, phi, radius):
    c2w = trans_t(radius)
    c2w = rot_phi(phi/180.*np.pi) @ c2w
    c2w = rot_theta(theta/180.*np.pi) @ c2w
    c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
    return c2w

class NSVF(Dataset):
    """NSVF Generic Dataset."""
    def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False):
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.downsample = downsample
        self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
        self.define_transforms()

        self.white_bg = True
        self.near_far = [0.5,6.0]
        self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()

        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def bbox2corners(self):
        corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
        for i in range(3):
            corners[i,[0,1],i] = corners[i,[1,0],i]
        return corners.view(-1,3)


    def read_meta(self):
        with open(os.path.join(self.root_dir, "intrinsics.txt")) as f:
            focal = float(f.readline().split()[0])
        self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]])
        self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1)

        pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
        img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))

        if self.split == 'train':
            pose_files = [x for x in pose_files if x.startswith('0_')]
            img_files = [x for x in img_files if x.startswith('0_')]
        elif self.split == 'val':
            pose_files = [x for x in pose_files if x.startswith('1_')]
            img_files = [x for x in img_files if x.startswith('1_')]
        elif self.split == 'test':
            test_pose_files = [x for x in pose_files if x.startswith('2_')]
            test_img_files = [x for x in img_files if x.startswith('2_')]
            if len(test_pose_files) == 0:
                test_pose_files = [x for x in pose_files if x.startswith('1_')]
                test_img_files = [x for x in img_files if x.startswith('1_')]
            pose_files = test_pose_files
            img_files = test_img_files

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]],
                                             center=self.intrinsics[:2,2])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)


        self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)

        self.poses = []
        self.all_rays = []
        self.all_rgbs = []

        assert len(img_files) == len(pose_files)
        for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
            image_path = os.path.join(self.root_dir, 'rgb', img_fname)
            img = Image.open(image_path)
            if self.downsample!=1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(img.shape[0], -1).permute(1, 0)  # (h*w, 4) RGBA
            if img.shape[-1]==4:
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]

            c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))  #@ self.blender2opencv
            c2w = torch.FloatTensor(c2w)
            self.poses.append(c2w)  # C2W
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

            # w2c = torch.inverse(c2w)

        # stack the poses: define_proj_mat() calls torch.inverse on them
        self.poses = torch.stack(self.poses)
        if 'train' == self.split:
            if self.is_stack:
                self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (N_frames, h, w, 6)
                self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (N_frames, h, w, 3)
            else:
                self.all_rays = torch.cat(self.all_rays, 0)  # (N_frames*h*w, 6)
                self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (N_frames*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (N_frames, h*w, 6)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (N_frames, h, w, 3)


    def define_transforms(self):
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]

    def world2ndc(self, points):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        if self.split == 'train':
            return len(self.all_rays)
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}

        else:  # create data for each image separately

            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]

            sample = {'rays': rays,
                      'rgbs': img}
        return sample

================================================
FILE: dataLoader/ray_utils.py
================================================
import torch, re
import numpy as np
from torch import searchsorted
from kornia import create_meshgrid


# from utils import index_point_feature

def depth2dist(z_vals, cos_angle):
    # z_vals: [N_ray N_sample]
    device = z_vals.device
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1)  # [N_rays, N_samples]
    dists = dists * cos_angle.unsqueeze(-1)
    return dists


def ndc2dist(ndc_pts, cos_angle):
    dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1)
    dists = torch.cat([dists, 1e10 * cos_angle.unsqueeze(-1)], -1)  # [N_rays, N_samples]
    return dists


def get_ray_directions(H, W, focal, center=None):
    """
    Get ray directions for all pixels in camera coordinate.
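A minimal NumPy sketch of the pinhole computation this function performs, for illustration only:
`ray_directions_np` is a hypothetical name, not part of this repository, and NumPy stands in for
torch/kornia. As in the code below, each pixel is sampled at its center (+0.5), `focal` is an
(fx, fy) pair and the principal point defaults to the image center. `get_ray_directions_blender`
differs only by negating the y and z components.

```python
import numpy as np


def ray_directions_np(H, W, focal, center=None):
    # Hypothetical NumPy counterpart of get_ray_directions (x right, y down, z forward).
    # i holds pixel-center x coordinates and j pixel-center y coordinates, both of shape (H, W).
    i, j = np.meshgrid(np.arange(W) + 0.5, np.arange(H) + 0.5, indexing='xy')
    cx, cy = center if center is not None else (W / 2, H / 2)
    return np.stack([(i - cx) / focal[0], (j - cy) / focal[1], np.ones_like(i)], axis=-1)


dirs = ray_directions_np(H=4, W=6, focal=(3.0, 3.0))
print(dirs.shape)  # (4, 6, 3)
print(dirs[0, 0])  # top-left pixel center: x = (0.5 - 3) / 3, y = (0.5 - 2) / 3, z = 1
```
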
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        H, W: image height and width
        focal: (fx, fy) focal lengths in pixels
        center: optional principal point (cx, cy); defaults to the image center
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5

    i, j = grid.unbind(-1)
    # the +0.5 offset above places each ray at its pixel center
    # (pixel centering is discussed in https://github.com/bmild/nerf/issues/24)
    cent = center if center is not None else [W / 2, H / 2]
    directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1)  # (H, W, 3)

    return directions


def get_ray_directions_blender(H, W, focal, center=None):
    """
    Get ray directions for all pixels in camera coordinate, using the Blender/OpenGL
    convention (y up, camera looking along -z).
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        H, W: image height and width
        focal: (fx, fy) focal lengths in pixels
        center: optional principal point (cx, cy); defaults to the image center
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5
    i, j = grid.unbind(-1)
    # the +0.5 offset above places each ray at its pixel center
    # (pixel centering is discussed in https://github.com/bmild/nerf/issues/24)
    cent = center if center is not None else [W / 2, H / 2]
    directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)], -1)  # (H, W, 3)

    return directions


def get_rays(directions, c2w):
    """
    Get ray origins and directions in world coordinate for all pixels in one image.
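A minimal NumPy sketch of the same transform, for illustration only (`rays_np` is a hypothetical
name, not part of this repository). Directions are rotated by the 3x3 block of `c2w` and every ray
starts at the camera center `c2w[:3, 3]`. Because a rotation preserves length, the output
directions are unit vectors only when the input `directions` already are.

```python
import numpy as np


def rays_np(directions, c2w):
    # Hypothetical NumPy counterpart of get_rays.
    # directions: (H, W, 3) camera-space directions; c2w: (3, 4) camera-to-world matrix [R | t].
    rays_d = directions @ c2w[:3, :3].T                  # rotate into world space, (H, W, 3)
    rays_o = np.broadcast_to(c2w[:3, 3], rays_d.shape)   # all rays start at the camera center
    return rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)


# Camera at (0, 0, 2), rotated 90 degrees about the world z axis.
c2w = np.array([[0.0, -1.0, 0.0, 0.0],
                [1.0,  0.0, 0.0, 0.0],
                [0.0,  0.0, 1.0, 2.0]])
dirs = np.array([[[1.0, 0.0, 1.0]]])  # a single pixel, shape (1, 1, 3)
rays_o, rays_d = rays_np(dirs, c2w)
print(rays_o)  # expected: origin (0, 0, 2)
print(rays_d)  # expected: direction (0, 1, 1)
```
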
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        directions: (H, W, 3) precomputed ray directions in camera coordinate
        c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
    Outputs:
        rays_o: (H*W, 3), the origin of the rays in world coordinate
        rays_d: (H*W, 3), the direction of the rays in world coordinate
                (unit length only if the input directions are; no normalization is applied here)
    """
    # Rotate ray directions from camera coordinate to the world coordinate
    rays_d = directions @ c2w[:3, :3].T  # (H, W, 3)
    # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
    # The origin of all rays is the camera origin in world coordinate
    rays_o = c2w[:3, 3].expand(rays_d.shape)  # (H, W, 3)

    rays_d = rays_d.view(-1, 3)
    rays_o = rays_o.view(-1, 3)

    return rays_o, rays_d


def ndc_rays_blender(H, W, focal, near, rays_o, rays_d):
    # Shift ray origins to near plane
    t = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d

    # Projection
    o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
    o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1. + 2. * near / rays_o[..., 2]

    d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
    d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
    d2 = -2. * near / rays_o[..., 2]

    rays_o = torch.stack([o0, o1, o2], -1)
    rays_d = torch.stack([d0, d1, d2], -1)

    return rays_o, rays_d


def ndc_rays(H, W, focal, near, rays_o, rays_d):
    # Shift ray origins to near plane
    t = (near - rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d

    # Projection
    o0 = 1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
    o1 = 1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1. - 2. * near / rays_o[..., 2]

    d0 = 1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
    d1 = 1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
    d2 = 2. * near / rays_o[..., 2]

    rays_o = torch.stack([o0, o1, o2], -1)
    rays_d = torch.stack([d0, d1, d2], -1)

    return rays_o, rays_d


# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    device = weights.device
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))

    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., steps=N_samples, device=device)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device)

    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)

    # Invert CDF
    u = u.contiguous()
    inds = searchsorted(cdf.detach(), u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples


def dda(rays_o, rays_d, bbox_3D):
    inv_ray_d = 1.0 / (rays_d + 1e-6)
    t_min = (bbox_3D[:1] - rays_o) * inv_ray_d  # N_rays 3
    t_max = (bbox_3D[1:] - rays_o) * inv_ray_d
    t = torch.stack((t_min, t_max))  # 2 N_rays 3
    t_min = torch.max(torch.min(t, dim=0)[0], dim=-1, keepdim=True)[0]
    t_max = torch.min(torch.max(t, dim=0)[0], dim=-1, keepdim=True)[0]
    return t_min, t_max


def ray_marcher(rays,
                N_samples=64,
                lindisp=False,
                perturb=0,
                bbox_3D=None):
    """
    Sample points along the rays.
    Inputs:
        rays: (N_rays, 8) ray origins [0:3], directions [3:6], near [6:7] and far [7:8]
        N_samples: number of samples per ray
        lindisp: if True, sample linearly in disparity instead of depth
        perturb: if > 0, jitter each sample within its interval (scaled by this value)
        bbox_3D: optional (2, 3) min/max box; if given, near/far come from the ray/box intersection
    Returns:
        xyz_coarse_sampled: (N_rays, N_samples, 3) sample positions
        rays_o, rays_d: (N_rays, 3) each
        z_vals: (N_rays, N_samples) sample depths
    """

    # Decompose the inputs
    N_rays = rays.shape[0]
    rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]  # both (N_rays, 3)
    near, far = rays[:, 6:7], rays[:, 7:8]  # both (N_rays, 1)

    if bbox_3D is not None:
        # compute near/far from the ray/AABB intersection
        near, far = dda(rays_o, rays_d, bbox_3D)

    # Sample depth points
    z_steps = torch.linspace(0, 1, N_samples, device=rays.device)  # (N_samples)
    if not lindisp:  # use linear sampling in depth space
        z_vals = near * (1 - z_steps) + far * z_steps
    else:  # use linear sampling in disparity space
        z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)

    z_vals = z_vals.expand(N_rays, N_samples)

    if perturb > 0:  # perturb sampling depths (z_vals)
        z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])  # (N_rays, N_samples-1) interval mid points
        # get intervals between samples
        upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1)
        lower = torch.cat([z_vals[:, :1], z_vals_mid], -1)

        perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device)
        z_vals = lower + (upper - lower) * perturb_rand

    xyz_coarse_sampled = rays_o.unsqueeze(1) + \
                         rays_d.unsqueeze(1) * z_vals.unsqueeze(2)  # (N_rays, N_samples, 3)

    return xyz_coarse_sampled, rays_o, rays_d, z_vals


def read_pfm(filename):
    file = open(filename, 'rb')
    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().decode('utf-8').rstrip()
    if header == 'PF':
        color = True
    elif header == 'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0:  # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>'  # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    file.close()
    return data, scale


def ndc_bbox(all_rays):
    near_min = torch.min(all_rays[...,:3].view(-1,3),dim=0)[0]
    near_max = torch.max(all_rays[..., :3].view(-1, 3), dim=0)[0]
    far_min = torch.min((all_rays[...,:3]+all_rays[...,3:6]).view(-1,3),dim=0)[0]
    far_max = torch.max((all_rays[...,:3]+all_rays[...,3:6]).view(-1, 3), dim=0)[0]
    print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}')
    return torch.stack((torch.minimum(near_min,far_min),torch.maximum(near_max,far_max)))

import torchvision
normalize_vgg = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                 std=[0.229, 0.224, 0.225])

def denormalize_vgg(img):
    im = img.clone()
    im[:, 0, :, :] *= 0.229
    im[:, 1, :, :] *= 0.224
    im[:, 2, :, :] *= 0.225
    im[:, 0, :, :] += 0.485
    im[:, 1, :, :] += 0.456
    im[:, 2, :, :] += 0.406
    return im

================================================
FILE: dataLoader/styleLoader.py
================================================
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as T

def getDataLoader(dataset_path, batch_size, sampler, image_side_length=256, num_workers=2):
    transform = T.Compose([
        T.Resize(size=(image_side_length*2, image_side_length*2)),
        T.RandomCrop(image_side_length),
        T.ToTensor(),
    ])

    train_dataset = datasets.ImageFolder(dataset_path, transform=transform)

    dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler(len(train_dataset)), num_workers=num_workers)

    return dataloader

================================================
FILE: dataLoader/tankstemple.py
================================================
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T
import random

from .ray_utils import *


def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1):
    if axis == 'z':
        return lambda t: [radius * np.cos(r * t + t0), radius * np.sin(r * t + t0), h]
    elif axis == 'y':
        return lambda t: [radius * np.cos(r * t + t0), h, radius * np.sin(r * t + t0)]
    else:
        return lambda t: [h, radius * np.cos(r * t + t0), radius * np.sin(r * t + t0)]


def cross(x, y, axis=0):
    T = torch if isinstance(x, torch.Tensor) else np
    return T.cross(x, y, axis)


def normalize(x, axis=-1, order=2):
    if isinstance(x, torch.Tensor):
        l2 = x.norm(p=order, dim=axis, keepdim=True)
        return x / (l2 + 1e-8), l2

    else:
        l2 = np.linalg.norm(x, order, axis)
        l2 = np.expand_dims(l2, axis)
        l2[l2 == 0] = 1
        return x / l2,


def cat(x, axis=1):
    if isinstance(x[0], torch.Tensor):
        return torch.cat(x, dim=axis)
    return np.concatenate(x, axis=axis)


def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False):
    """
    This function takes a vector 'camera_position' which specifies the location
    of the camera in world coordinates and two vectors `at` and `up` which
    indicate the position of the object and the up direction of the world
    coordinate system respectively. The object is assumed to be centered at
    the origin.

    The output is a 3x3 rotation matrix whose columns are the camera's x, y and z
    axes in world coordinates (z points from the camera towards `at`), i.e. the
    rotation part of a camera-to-world matrix; gen_path uses it this way.
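A minimal NumPy sketch of the same construction, for illustration only (`look_at_np` is a
hypothetical name, not part of this repository, and it assumes 1-D length-3 inputs). The z axis
points from the camera to `at`, x = normalize(cross(up, z)) and y = normalize(cross(z, x)); the
three axes are stacked as columns, which is the matrix gen_path writes into `c2w[:3, :3]`. Here
`up` defaults to gen_path's (0, -1, 0).

```python
import numpy as np


def look_at_np(camera_position, at=(0.0, 0.0, 0.0), up=(0.0, -1.0, 0.0)):
    # Hypothetical NumPy counterpart of look_at_rotation for 1-D length-3 inputs.
    def unit(v):
        return v / np.linalg.norm(v)

    cam = np.asarray(camera_position, dtype=float)
    z_axis = unit(np.asarray(at, dtype=float) - cam)              # camera -> target
    x_axis = unit(np.cross(np.asarray(up, dtype=float), z_axis))
    y_axis = unit(np.cross(z_axis, x_axis))
    return np.stack([x_axis, y_axis, z_axis], axis=1)             # columns are the camera axes


R = look_at_np([0.0, 0.0, -3.0])  # camera on the -z axis looking at the origin
print(np.round(R, 6))  # expected: diag(-1, -1, 1)
```
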

    Input:
        camera_position: 3
        at: 1 x 3 or N x 3  (0, 0, 0) in default
        up: 1 x 3 or N x 3  (0, 0, -1) in default
        inverse, cv: currently unused
    """

    if at is None:
        at = torch.zeros_like(camera_position)
    else:
        at = torch.tensor(at).type_as(camera_position)
    if up is None:
        up = torch.zeros_like(camera_position)
        up[2] = -1
    else:
        up = torch.tensor(up).type_as(camera_position)

    z_axis = normalize(at - camera_position)[0]
    x_axis = normalize(cross(up, z_axis))[0]
    y_axis = normalize(cross(z_axis, x_axis))[0]

    R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)
    return R


def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):
    c2ws = []
    for t in range(frames):
        c2w = torch.eye(4)
        cam_pos = torch.tensor(pos_gen(t * (360.0 / frames) / 180 * np.pi))
        cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True)
        c2w[:3, 3], c2w[:3, :3] = cam_pos, cam_rot
        c2ws.append(c2w)
    return torch.stack(c2ws)

class TanksTempleDataset(Dataset):
    """Tanks and Temples Dataset."""
    def __init__(self, datadir, split='train', downsample=4.0, wh=[1920,1080], is_stack=False):
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.downsample = downsample
        self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
        self.define_transforms()

        self.white_bg = True
        self.near_far = [0.01,6.0]
        self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)*1.2

        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()

        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def bbox2corners(self):
        corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
        for i in range(3):
            corners[i,[0,1],i] = corners[i,[1,0],i]
        return corners.view(-1,3)


    def read_meta(self):

        self.intrinsics = np.loadtxt(os.path.join(self.root_dir, "intrinsics.txt"))
        self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([1920,1080])).reshape(2,1)
        pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
        img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))

        if self.split == 'train':
            pose_files = [x for idx,x in enumerate(pose_files) if x.startswith('0_') and idx%3==0]
            img_files = [x for idx,x in enumerate(img_files) if x.startswith('0_') and idx%3==0]
        elif self.split == 'test':
            # use the '2_' (test) views, falling back to the '1_' (val) views if there are none
            test_pose_files = [x for idx,x in enumerate(pose_files) if x.startswith('2_') and idx%3==0]
            test_img_files = [x for idx,x in enumerate(img_files) if x.startswith('2_') and idx%3==0]
            if len(test_pose_files) == 0:
                test_pose_files = [x for idx,x in enumerate(pose_files) if x.startswith('1_') and idx%3==0]
                test_img_files = [x for idx,x in enumerate(img_files) if x.startswith('1_') and idx%3==0]
            pose_files = test_pose_files
            img_files = test_img_files

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]],
                                             center=self.intrinsics[:2,2])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)

        w, h = self.img_wh

        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []

        assert len(img_files) == len(pose_files)
        for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
            image_path = os.path.join(self.root_dir, 'rgb', img_fname)
            img = Image.open(image_path)
            if self.downsample!=1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (3, h, w)
            img = img.view(img.shape[0], -1).permute(1, 0)  # (h*w, 3) RGBA
            mask = torch.where(
                img.sum(-1, keepdim=True) == 3.,
                1.,
                0.
            )
            self.all_masks.append(mask.reshape(h,w,1))  # (h, w, 1) A
            if img.shape[-1]==4:
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs.append(img)

            c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))  # @ cam_trans
            c2w = torch.FloatTensor(c2w)
            self.poses.append(c2w)  # C2W
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        self.poses = torch.stack(self.poses)

        center = torch.mean(self.scene_bbox, dim=0)
        radius = torch.norm(self.scene_bbox[1]-center)*1.2
        up = torch.mean(self.poses[:, :3, 1], dim=0).tolist()
        pos_gen = circle(radius=radius, h=-0.2*up[1], axis='y')
        self.render_path = gen_path(pos_gen, up=up,frames=100)
        self.render_path[:, :3, 3] += center

        all_rays = self.all_rays
        all_rgbs = self.all_rgbs
        self.all_masks = torch.stack(self.all_masks)  # (n_frames, h, w, 1)
        self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w,6)
        self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w,3)

        if self.is_stack:
            self.all_rays_stack = torch.stack(all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (len(self.meta['frames]),h,w,6)
            avg_pool = torch.nn.AvgPool2d(4, ceil_mode=True)
            self.ds_all_rays_stack = avg_pool(self.all_rays_stack.permute(0,3,1,2)).permute(0,2,3,1)  # (len(self.meta['frames]),h/4,w/4,6)
            self.all_rgbs_stack = torch.stack(all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)

    @torch.no_grad()
    def prepare_feature_data(self, encoder, chunk=4):
        '''
        Prepare feature maps as training data.
        '''
        assert self.is_stack, 'Dataset should contain original stacked training data!'
        print('====> prepare_feature_data ...')
        frames_num, h, w, _ = self.all_rgbs_stack.size()
        features = []
        for chunk_idx in tqdm(range(frames_num // chunk + int(frames_num % chunk > 0))):
            rgbs_chunk = self.all_rgbs_stack[chunk_idx*chunk : (chunk_idx+1)*chunk].cuda()
            features_chunk = encoder(normalize_vgg(rgbs_chunk.permute(0,3,1,2))).relu3_1
            # resize to the size of rgb map so that rays can match
            features_chunk = T.functional.resize(features_chunk, size=(h,w),
                                                 interpolation=T.InterpolationMode.BILINEAR)
            features.append(features_chunk.detach().cpu().requires_grad_(False))

        self.all_features_stack = torch.cat(features).permute(0,2,3,1)  # (len(self.meta['frames]),h,w,256)
        self.all_features = self.all_features_stack.reshape(-1, 256)
        print('prepare_feature_data Done!')

    def define_transforms(self):
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]

    def world2ndc(self, points):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        if self.split == 'train':
            return len(self.all_rays)
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}

        else:  # create data for each image separately

            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]

            sample = {'rays': rays,
                      'rgbs': img}
        return sample

================================================
FILE: dataLoader/your_own_data.py
================================================
import torch,cv2
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T


from .ray_utils import *


class YourOwnDataset(Dataset):
    def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):

        self.N_vis = N_vis
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.downsample = downsample
        self.define_transforms()

        self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()

        self.white_bg = True
        self.near_far = [0.1,100.0]

        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
        self.downsample=downsample

    def read_depth(self, filename):
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)
        return depth

    def read_meta(self):

        with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
            self.meta = json.load(f)

        w, h = int(self.meta['w']/self.downsample), int(self.meta['h']/self.downsample)
        self.img_wh = [w,h]
        self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x'])  # focal length at the downsampled resolution
        self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y'])  # focal length at the downsampled resolution
        # principal point, rescaled like w, h and the focal lengths
        self.cx, self.cy = self.meta['cx'] / self.downsample, self.meta['cy'] / self.downsample


        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, [self.focal_x,self.focal_y], center=[self.cx, self.cy])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        self.intrinsics = torch.tensor([[self.focal_x,0,self.cx],[0,self.focal_y,self.cy],[0,0,1]]).float()

        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []
        self.all_depth = []


        img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
        idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#

            frame = self.meta['frames'][i]
            pose = np.array(frame['transform_matrix']) @ self.blender2opencv
            c2w = torch.FloatTensor(pose)
            self.poses += [c2w]

            image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
            self.image_paths += [image_path]
            img = Image.open(image_path)

            if self.downsample!=1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(-1, w*h).permute(1, 0)  # (h*w, 4) RGBA
            if img.shape[-1]==4:
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]


            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)


        self.poses = torch.stack(self.poses)
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)

            # self.all_depth = torch.cat(self.all_depth, 0)  # (len(self.meta['frames])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)
            # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1])  # (len(self.meta['frames]),h,w,3)


    def define_transforms(self):
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]

    def world2ndc(self,points,lindisp=None):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}

        else:  # create data for each image separately

            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]
            # no mask is returned: self.all_masks is never filled in read_meta()

            sample = {'rays': rays,
                      'rgbs': img}
        return sample

================================================
FILE: extra/auto_run_paramsets.py
================================================
import os
import threading, queue
import numpy as np
import time


def getFolderLocker(logFolder):
    while True:
        try:
            os.makedirs(logFolder+"/lockFolder")
            break
        except:
            time.sleep(0.01)

def releaseFolderLocker(logFolder):
    os.removedirs(logFolder+"/lockFolder")

def getStopFolder(logFolder):
    return os.path.isdir(logFolder+"/stopFolder")


def get_param_str(key, val):
    if key == 'data_name':
        return f'--datadir {datafolder}/{val} '
    else:
        return f'--{key} {val} '


def get_param_list(param_dict):
    param_keys = list(param_dict.keys())
    param_modes = len(param_keys)
    param_nums = [len(param_dict[key]) for key in param_keys]

    param_ids = np.zeros(param_nums+[param_modes], dtype=int)
    for i in range(param_modes):
        broad_tuple = np.ones(param_modes, dtype=int).tolist()
        broad_tuple[i] = param_nums[i]
        broad_tuple = tuple(broad_tuple)
        print(broad_tuple)
        param_ids[...,i] = np.arange(param_nums[i]).reshape(broad_tuple)
    param_ids = param_ids.reshape(-1, param_modes)
    # print(param_ids)
    print(len(param_ids))

    params = []
    expnames = []
    for i in range(param_ids.shape[0]):
        one = ""
        name = ""
        param_id = param_ids[i]
        for j in range(param_modes):
            key = param_keys[j]
            val = param_dict[key][param_id[j]]
            if type(key) is tuple:
                assert len(key) == len(val)
                for k in range(len(key)):
                    one += get_param_str(key[k], val[k])
                    name += f'{val[k]},'
                name=name[:-1]+'-'
            else:
                one += get_param_str(key, val)
                name += f'{val}-'
        params.append(one)
        name=name.replace(' ','')
        print(name)
        expnames.append(name[:-1])
    # print(params)
    return params, expnames


if __name__ == '__main__':

    # nerf
    expFolder = "nerf/"
    # parameters to iterate, use tuple to couple multiple parameters
    datafolder = '/mnt/new_disk_2/anpei/Dataset/nerf_synthetic/'
    param_dict = {
        'data_name': ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials'],
        'data_dim_color': [13, 27, 54]
    }

    # n_iters = 30000
    # for data_name in ['Robot']:#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'
    #     cmd = f'CUDA_VISIBLE_DEVICES={cuda} python train.py ' \
    #           f'--dataset_name nsvf --datadir /mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/{data_name} '\
    #           f'--expname {data_name} --batch_size {batch_size} ' \
    #           f'--n_iters {n_iters} ' \
    #           f'--N_voxel_init {128**3} --N_voxel_final {300**3} '\
    #           f'--N_vis {5} ' \
    #           f'--n_lamb_sigma "[16,16,16]" --n_lamb_sh "[48,48,48]" ' \
    #           f'--upsamp_list "[2000, 3000, 4000, 5500,7000]" --update_AlphaMask_list "[3000,4000]" ' \
    #           f'--shadingMode MLP_Fea --fea2denseAct softplus --view_pe {2} --fea_pe {2} ' \
    #           f'--L1_weight_inital {8e-5} --L1_weight_rest {4e-5} --rm_weight_mask_thre {1e-4} --add_timestamp 0 ' \
    #           f'--render_test 1 '
    #     print(cmd)
    #     os.system(cmd)

    # nsvf
    # expFolder = "nsvf_0227/"
    # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/'
    # param_dict = {
    #     'data_name': ['Robot','Steamtrain','Bike','Lifestyle','Palace','Spaceship','Toad','Wineholder'],#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'
    #     'shadingMode': ['SH'],
    #     ('n_lamb_sigma', 'n_lamb_sh'): [ ("[8,8,8]", "[8,8,8]")],
    #     ('view_pe', 'fea_pe', 'featureC','fea2denseAct','N_voxel_init') : [(2, 2, 128, 'softplus',128**3)],
    #     ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'):[(4e-5, 4e-5, 1e-4)],
    #     ('n_iters','N_voxel_final'): [(30000,300**3)],
    #     ('dataset_name','N_vis','render_test') : [("nsvf",5,1)],
    #     ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[3000,4000]")]
    #
    # }

    # tankstemple
    # expFolder = "tankstemple_0304/"
    # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/TanksAndTemple/'
    # param_dict = {
    #     'data_name': ['Truck','Barn','Caterpillar','Family','Ignatius'],
    #     'shadingMode': ['MLP_Fea'],
    #     ('n_lamb_sigma', 'n_lamb_sh'): [("[16,16,16]", "[48,48,48]")],
    #     ('view_pe', 'fea_pe','fea2denseAct','N_voxel_init','render_test') : [(2, 2, 'softplus',128**3,1)],
    #     ('TV_weight_density','TV_weight_app'):[(0.1,0.01)],
    #     # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'): [(4e-5, 4e-5, 1e-4)],
    #     ('n_iters','N_voxel_final'): [(15000,300**3)],
    #     ('dataset_name','N_vis') : [("tankstemple",5)],
    #     ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2000,4000]")]
    # }

    # llff
    # expFolder = "real_iconic/"
    # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/real_iconic/'
    # List = os.listdir(datafolder)
    # param_dict = {
    #     'data_name': List,
    #     ('shadingMode', 'view_pe', 'fea_pe','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 'relu',512,128**3)],
    #     ('n_lamb_sigma', 'n_lamb_sh') : [("[16,4,4]", "[48,12,12]")],
    #     ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)],
    #     ('n_iters','N_voxel_final'): [(25000,640**3)],
    #     ('dataset_name','downsample_train','ndc_ray','N_vis','render_path') : [("llff",4.0, 1,-1,1)],
    #     ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")],
    # }

    # expFolder = "llff/"
    # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data'
    # param_dict = {
    #     'data_name': ['fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'],#'fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'
    #     ('n_lamb_sigma', 'n_lamb_sh'): [("[16,4,4]", "[48,12,12]")],
    #     ('shadingMode', 'view_pe', 'fea_pe', 'featureC','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 128, 'relu',512,128**3),('SH', 0, 0, 128, 'relu',512,128**3)],
    #     ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)],
    #     ('n_iters','N_voxel_final'): [(25000,640**3)],
    #     ('dataset_name','downsample_train','ndc_ray','N_vis','render_test','render_path') : [("llff",4.0, 1,-1,1,1)],
    #     ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")],
    # }

    #setting available gpus
    gpus_que = queue.Queue(3)
    for i in [1,2,3]:
        gpus_que.put(i)

    os.makedirs(f"log/{expFolder}", exist_ok=True)

    def run_program(gpu, expname, param):
        cmd = f'CUDA_VISIBLE_DEVICES={gpu} python train.py ' \
              f'--expname {expname} --basedir ./log/{expFolder} --config configs/lego.txt ' \
              f'{param}' \
              f'> "log/{expFolder}{expname}/{expname}.txt"'
        print(cmd)
os.system(cmd) gpus_que.put(gpu) params, expnames = get_param_list(param_dict) logFolder=f"log/{expFolder}" os.makedirs(logFolder, exist_ok=True) ths = [] for i in range(len(params)): if getStopFolder(logFolder): break targetFolder = f"log/{expFolder}{expnames[i]}" gpu = gpus_que.get() getFolderLocker(logFolder) if os.path.isdir(targetFolder): releaseFolderLocker(logFolder) gpus_que.put(gpu) continue else: os.makedirs(targetFolder, exist_ok=True) print("making",targetFolder, "running",expnames[i], params[i]) releaseFolderLocker(logFolder) t = threading.Thread(target=run_program, args=(gpu, expnames[i], params[i]), daemon=True) t.start() ths.append(t) for th in ths: th.join() ================================================ FILE: extra/compute_metrics.py ================================================ import os, math import numpy as np import scipy.signal from typing import List, Optional from PIL import Image import os import torch import configargparse __LPIPS__ = {} def init_lpips(net_name, device): assert net_name in ['alex', 'vgg'] import lpips print(f'init_lpips: lpips_{net_name}') return lpips.LPIPS(net=net_name, version='0.1').eval().to(device) def rgb_lpips(np_gt, np_im, net_name, device): if net_name not in __LPIPS__: __LPIPS__[net_name] = init_lpips(net_name, device) gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device) im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device) return __LPIPS__[net_name](gt, im, normalize=True).item() def findItem(items, target): for one in items: if one[:len(target)]==target: return one return None ''' Evaluation metrics (ssim, lpips) ''' def rgb_ssim(img0, img1, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03, return_map=False): # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58 assert len(img0.shape) == 3 assert img0.shape[-1] == 3 assert img0.shape == img1.shape # Construct a 1D Gaussian blur filter. 
    hw = filter_size // 2
    shift = (2 * hw - filter_size + 1) / 2
    f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2
    filt = np.exp(-0.5 * f_i)
    filt /= np.sum(filt)

    # Blur in x and y (faster than the 2D convolution).
    def convolve2d(z, f):
        return scipy.signal.convolve2d(z, f, mode='valid')

    filt_fn = lambda z: np.stack([
        convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])
        for i in range(z.shape[-1])], -1)
    mu0 = filt_fn(img0)
    mu1 = filt_fn(img1)
    mu00 = mu0 * mu0
    mu11 = mu1 * mu1
    mu01 = mu0 * mu1
    sigma00 = filt_fn(img0**2) - mu00
    sigma11 = filt_fn(img1**2) - mu11
    sigma01 = filt_fn(img0 * img1) - mu01

    # Clip the variances and covariances to valid values.
    # Variance must be non-negative:
    sigma00 = np.maximum(0., sigma00)
    sigma11 = np.maximum(0., sigma11)
    sigma01 = np.sign(sigma01) * np.minimum(
        np.sqrt(sigma00 * sigma11), np.abs(sigma01))
    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2
    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
    ssim_map = numer / denom
    ssim = np.mean(ssim_map)
    return ssim_map if return_map else ssim


if __name__ == '__main__':

    parser = configargparse.ArgumentParser()
    parser.add_argument("--exp", type=str, help="folder of exps")
    parser.add_argument("--paramStr", type=str, help="str of params")
    args = parser.parse_args()

    # NOTE: uncomment one of the two dataset blocks below (and adjust its paths);
    # otherwise datanames, gtFolder and expFolder are undefined and the script fails.
    # datanames = ['drums','hotdog','materials','ficus','lego','mic','ship','chair'] #['ship']#
    # gtFolder = "/home/code-base/user_space/codes/nerf/data/nerf_synthetic"
    # expFolder = "/home/code-base/user_space/codes/TensoRF/log/"+args.exp
    # datanames = ['room','fortress', 'flower','orchids','leaves','horns','trex','fern'] #['ship']#
    # gtFolder = "/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/"
    # expFolder = "/mnt/new_disk_2/anpei/code/TensoRF/log/"+args.exp

    paramStr = args.paramStr
    fileNum = 200

    expitems = os.listdir(expFolder)
    finalFolder = f'{expFolder}/finals/{paramStr}'
    outFile = f'{finalFolder}/{paramStr}_metrics.txt'
    os.makedirs(finalFolder, exist_ok=True)
expitems.sort(reverse=True) with open(outFile, 'w') as f: all_psnr = [] all_ssim = [] all_alex = [] all_vgg = [] for dataname in datanames: gtstr = gtFolder+"/"+dataname+"/test/r_%d.png" expname = findItem(expitems, f'{paramStr}-{dataname}') print("expname: ", expname) if expname is None: print("no ",dataname, "exists") continue resultstr = expFolder+"/"+expname+"/imgs_test_all/"+ dataname+"-"+paramStr+ "_%03d.png" metric_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_mean.txt' video_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_video.mp4' exist_metric=False if os.path.isfile(metric_file): metrics = np.loadtxt(metric_file) print(metrics, metrics.tolist()) if metrics.size == 4: psnr, ssim, l_a, l_v = metrics.tolist() exist_metric = True os.system(f"cp {video_file} {finalFolder}/") if not exist_metric: psnrs = [] ssims = [] l_alex = [] l_vgg = [] for i in range(fileNum): gt = np.asarray(Image.open(gtstr%i),dtype=np.float32) / 255.0 gtmask = gt[...,[3]] gt = gt[...,:3] gt = gt*gtmask + (1-gtmask) img = np.asarray(Image.open(resultstr%i),dtype=np.float32)[...,:3] / 255.0 # print(gt[0,0],img[0,0],gt.shape, img.shape, gt.max(), img.max()) psnr = -10. 
* np.log10(np.mean(np.square(img - gt))) ssim = rgb_ssim(img, gt, 1) lpips_alex = rgb_lpips(gt, img, 'alex','cuda') lpips_vgg = rgb_lpips(gt, img, 'vgg','cuda') print(i, psnr, ssim, lpips_alex, lpips_vgg) psnrs.append(psnr) ssims.append(ssim) l_alex.append(lpips_alex) l_vgg.append(lpips_vgg) psnr = np.mean(np.array(psnrs)) ssim = np.mean(np.array(ssims)) l_a = np.mean(np.array(l_alex)) l_v = np.mean(np.array(l_vgg)) rS=f'{dataname} : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}' print(rS) f.write(rS+"\n") all_psnr.append(psnr) all_ssim.append(ssim) all_alex.append(l_a) all_vgg.append(l_v) psnr = np.mean(np.array(all_psnr)) ssim = np.mean(np.array(all_ssim)) l_a = np.mean(np.array(all_alex)) l_v = np.mean(np.array(all_vgg)) rS=f'mean : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}' print(rS) f.write(rS+"\n") ================================================ FILE: models/VGG.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from collections import namedtuple import torchvision.models as models # pytorch pretrained vgg class Encoder(nn.Module): def __init__(self): super().__init__() #pretrained vgg19 vgg19 = models.vgg19(weights='DEFAULT').features self.relu1_1 = vgg19[:2] self.relu2_1 = vgg19[2:7] self.relu3_1 = vgg19[7:12] self.relu4_1 = vgg19[12:21] #fix parameters self.requires_grad_(False) def forward(self, x): _output = namedtuple('output', ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1']) relu1_1 = self.relu1_1(x) relu2_1 = self.relu2_1(relu1_1) relu3_1 = self.relu3_1(relu2_1) relu4_1 = self.relu4_1(relu3_1) output = _output(relu1_1, relu2_1, relu3_1, relu4_1) return output class Decoder(nn.Module): """ starting from relu 4_1 """ def __init__(self, ckpt_path=None): super().__init__() self.layers = nn.Sequential( # nn.Conv2d(512, 256, 3, padding=1, padding_mode='reflect'), # nn.ReLU(), # nn.Upsample(scale_factor=2, mode='nearest'), # relu4-1 nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'), 
nn.ReLU(), # relu3-4 nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'), nn.ReLU(), # relu3-3 nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'), nn.ReLU(), # relu3-2 nn.Conv2d(256, 128, 3, padding=1, padding_mode='reflect'), nn.ReLU(), nn.Upsample(scale_factor=2, mode='nearest'),# relu3-1 nn.Conv2d(128, 128, 3, padding=1, padding_mode='reflect'), nn.ReLU(), # relu2-2 nn.Conv2d(128, 64, 3, padding=1, padding_mode='reflect'), nn.ReLU(), nn.Upsample(scale_factor=2, mode='nearest'),# relu2-1 nn.Conv2d(64, 64, 3, padding=1, padding_mode='reflect'), nn.ReLU(), # relu1-2 nn.Conv2d(64, 3, 3, padding=1, padding_mode='reflect'), ) if ckpt_path is not None: self.load_state_dict(torch.load(ckpt_path)) def forward(self, x): return self.layers(x) ### high-res unet feature map decoder class DownBlock(nn.Module): def __init__(self, in_dim, out_dim, down='conv'): super(DownBlock, self).__init__() if down == 'conv': self.down_conv = nn.Sequential( nn.Conv2d(in_dim, out_dim, 3, 2, 1), nn.LeakyReLU(), nn.Conv2d(out_dim, out_dim, 3, 1, 1), nn.LeakyReLU(), ) elif down == 'mean': self.down_conv = nn.AvgPool2d(2) else: raise NotImplementedError( '[ERROR] invalid downsampling operator: {:s}'.format(down) ) def forward(self, x): x = self.down_conv(x) return x class UpBlock(nn.Module): def __init__(self, in_dim, out_dim, skip_dim=None, up='nearest'): super(UpBlock, self).__init__() if up == 'conv': self.up_conv = nn.Sequential( nn.ConvTranspose2d(in_dim, out_dim, 3, 2, 1, 1), nn.ReLU(), ) else: assert up in ('bilinear', 'nearest'), \ '[ERROR] invalid upsampling mode: {:s}'.format(up) self.up_conv = nn.Sequential( nn.Upsample(scale_factor=2, mode=up), nn.Conv2d(in_dim, out_dim, 3, 1, 1), nn.ReLU(), ) in_dim = out_dim if skip_dim is not None: in_dim += skip_dim self.conv = nn.Sequential( nn.Conv2d(in_dim, out_dim, 3, 1, 1), nn.ReLU(), ) def _pad(self, x, y): dh = y.size(-2) - x.size(-2) dw = y.size(-1) - x.size(-1) if dh == 0 and dw == 0: return x if dh < 0: x = x[..., :dh, :] 
if dw < 0: x = x[..., :, :dw] if dh > 0 or dw > 0: x = F.pad( x, pad=(dw // 2, dw - dw // 2, dh // 2, dh - dh // 2), mode='reflect' ) return x def forward(self, x, skip=None): x = self.up_conv(x) if skip is not None: x = torch.cat([self._pad(x, skip), skip], 1) x = self.conv(x) return x class UNetDecoder(nn.Module): def __init__(self, in_dim=256): super(UNetDecoder, self).__init__() self.down_layers = nn.ModuleList() self.skip_convs = nn.ModuleList() self.up_layers = nn.ModuleList() in_dim = in_dim self.n_levels = 2 self.up = 1 for i in range(self.n_levels): self.down_layers.append( DownBlock( in_dim, in_dim, ) ) out_dim = in_dim // 2 ** (self.n_levels - i) self.skip_convs.append(nn.Conv2d(in_dim, out_dim, 1)) self.up_layers.append( UpBlock( out_dim * 2, out_dim, out_dim, ) ) out_dim = in_dim // 2 ** self.n_levels self.out_conv = nn.Sequential( nn.Conv2d(out_dim, out_dim, 3, 1, 1), nn.ReLU(), nn.Conv2d(out_dim, 3, 1, 1), ) def forward(self, feats): skips = [] for i in range(self.n_levels): skips.append(self.skip_convs[i](feats)) feats = self.down_layers[i](feats) for i in range(self.n_levels - 1, -1, -1): feats = self.up_layers[i](feats, skips[i]) rgb = self.out_conv(feats) return rgb ### high-res feature map decoder class PlainDecoder(nn.Module): def __init__(self) -> None: super().__init__() self.layers = nn.Sequential( nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'), nn.ReLU(), # relu3-4 nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'), nn.ReLU(), # relu3-3 nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'), nn.ReLU(), # relu3-2 nn.Conv2d(256, 128, 3, padding=1, padding_mode='reflect'), nn.ReLU(), nn.Conv2d(128, 128, 3, padding=1, padding_mode='reflect'), nn.ReLU(), # relu2-2 nn.Conv2d(128, 64, 3, padding=1, padding_mode='reflect'), nn.ReLU(), nn.Conv2d(64, 64, 3, padding=1, padding_mode='reflect'), nn.ReLU(), # relu1-2 nn.Conv2d(64, 3, 3, padding=1, padding_mode='reflect'), ) def forward(self, x): return self.layers(x) 
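`UpBlock._pad` in `models/VGG.py` aligns the upsampled decoder features with the skip connection before they are concatenated: it crops rows/columns from the bottom/right when the upsampled map is too large and reflect-pads it when it is too small. The NumPy sketch below illustrates that alignment step under a few assumptions: `match_spatial` is a hypothetical name, inputs are laid out as `(..., H, W)`, and, unlike `_pad` (which builds its `F.pad` widths from the raw height/width differences), it clamps the differences to zero before padding, so a map that is too wide but too short is cropped in width and padded only in height.

```python
import numpy as np


def match_spatial(x: np.ndarray, ref_shape) -> np.ndarray:
    """Crop or reflect-pad the last two axes of x so they match ref_shape[-2:].

    Hypothetical NumPy counterpart of UpBlock._pad; x is laid out as (..., H, W).
    """
    dh = ref_shape[-2] - x.shape[-2]
    dw = ref_shape[-1] - x.shape[-1]

    # Too large: drop rows/columns from the bottom/right, like x[..., :dh, :].
    if dh < 0:
        x = x[..., :dh, :]
    if dw < 0:
        x = x[..., :, :dw]

    # Too small: reflect-pad, putting the odd extra row/column on the bottom/right.
    ph, pw = max(dh, 0), max(dw, 0)
    if ph or pw:
        pad = [(0, 0)] * (x.ndim - 2) + [(ph // 2, ph - ph // 2),
                                         (pw // 2, pw - pw // 2)]
        x = np.pad(x, pad, mode="reflect")
    return x


# Example: a 4x4 map aligned to a 5x3 skip connection.
x = np.arange(16, dtype=float).reshape(1, 4, 4)
print(match_spatial(x, (1, 5, 3)).shape)  # (1, 5, 3)
```

Reflect padding is a reasonable choice here because it avoids introducing the artificial zero borders that constant padding would add right next to real feature values.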
================================================
FILE: models/__init__.py
================================================

================================================
FILE: models/sh.py
================================================
import torch

################## sh function ##################
C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396
]
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435
]
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]


def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions
    using hardcoded SH polynomials.
    Works with torch/np/jnp.
    ... Can be 0 or more batch dimensions.
    :param deg: int SH max degree. Currently, 0-4 supported
    :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2)
    :param dirs: torch.Tensor unit directions (..., 3)
    :return: (..., C)
    """
    assert deg <= 4 and deg >= 0
    assert (deg + 1) ** 2 == sh.shape[-1]
    C = sh.shape[-2]

    result = C0 * sh[..., 0]
    if deg > 0:
        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
        result = (result -
                  C1 * y * sh[..., 1] +
                  C1 * z * sh[..., 2] -
                  C1 * x * sh[..., 3])
        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result = (result +
                      C2[0] * xy * sh[..., 4] +
                      C2[1] * yz * sh[..., 5] +
                      C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
                      C2[3] * xz * sh[..., 7] +
                      C2[4] * (xx - yy) * sh[..., 8])

            if deg > 2:
                result = (result +
                          C3[0] * y * (3 * xx - yy) * sh[..., 9] +
                          C3[1] * xy * z * sh[..., 10] +
                          C3[2] * y * (4 * zz - xx - yy) * sh[..., 11] +
                          C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
                          C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
                          C3[5] * z * (xx - yy) * sh[..., 14] +
                          C3[6] * x * (xx - 3 * yy) * sh[..., 15])
                if deg > 3:
                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
                              C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
                              C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
                              C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
                              C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
                              C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
                              C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
                              C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
                              C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result


def eval_sh_bases(deg, dirs):
    """
    Evaluate spherical harmonics bases at unit directions,
    without taking linear combination.
    At each point, the final result may then be obtained through simple multiplication.

    :param deg: int SH max degree. Currently, 0-4 supported
    :param dirs: torch.Tensor (..., 3) unit directions

    :return: torch.Tensor (..., (deg+1) ** 2)
    """
    assert deg <= 4 and deg >= 0
    result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device)
    result[..., 0] = C0
    if deg > 0:
        x, y, z = dirs.unbind(-1)
        result[..., 1] = -C1 * y
        result[..., 2] = C1 * z
        result[..., 3] = -C1 * x
        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result[..., 4] = C2[0] * xy
            result[..., 5] = C2[1] * yz
            result[..., 6] = C2[2] * (2.0 * zz - xx - yy)
            result[..., 7] = C2[3] * xz
            result[..., 8] = C2[4] * (xx - yy)

            if deg > 2:
                result[..., 9] = C3[0] * y * (3 * xx - yy)
                result[..., 10] = C3[1] * xy * z
                result[..., 11] = C3[2] * y * (4 * zz - xx - yy)
                result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy)
                result[..., 13] = C3[4] * x * (4 * zz - xx - yy)
                result[..., 14] = C3[5] * z * (xx - yy)
                result[..., 15] = C3[6] * x * (xx - 3 * yy)

                if deg > 3:
                    result[..., 16] = C4[0] * xy * (xx - yy)
                    result[..., 17] = C4[1] * yz * (3 * xx - yy)
                    result[..., 18] = C4[2] * xy * (7 * zz - 1)
                    result[..., 19] = C4[3] * yz * (7 * zz - 3)
                    result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3)
                    result[..., 21] = C4[5] * xz * (7 * zz - 3)
                    result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1)
                    result[..., 23] = C4[7] * xz * (xx - 3 * yy)
                    result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
    return result

================================================
FILE: models/styleModules.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import random


def calc_mean_std(x, eps=1e-8):
    """
    Calculate the channel-wise instance mean and standard deviation.
    x: shape of (N,C,*)
    """
    mean = torch.mean(x.flatten(2), dim=-1, keepdim=True)  # size of (N, C, 1)
    std = torch.std(x.flatten(2), dim=-1, keepdim=True) + eps  # size of (N, C, 1)

    return mean, std


def cal_adain_style_loss(x, y):
    """
    Style loss in one layer.

    Args:
        x, y: feature maps of size [N, C, H, W]
    """
    x_mean, x_std = calc_mean_std(x)
    y_mean, y_std = calc_mean_std(y)

    return nn.functional.mse_loss(x_mean, y_mean) \
        + nn.functional.mse_loss(x_std, y_std)


def cal_mse_content_loss(x, y):
    return nn.functional.mse_loss(x, y)


class LearnableIN(nn.Module):
    '''
    Input: (N, C, L) or (C, L)
    '''
    def __init__(self, dim=256):
        super().__init__()
        self.IN = torch.nn.InstanceNorm1d(dim, momentum=1e-4, track_running_stats=True)

    def forward(self, x):
        # Instance statistics are undefined for a single element; pass it through unchanged.
        if x.size()[-1] <= 1:
            return x
        return self.IN(x)


class SimpleLinearStylizer(nn.Module):
    def __init__(self, input_dim=256, embed_dim=32, n_layers=3) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.embed_dim = embed_dim

        self.IN = LearnableIN(input_dim)

        self.q_embed = nn.Conv1d(input_dim, embed_dim, 1)
        self.k_embed = nn.Conv1d(input_dim, embed_dim, 1)
        self.v_embed = nn.Conv1d(input_dim, embed_dim, 1)

        self.unzipper = nn.Conv1d(embed_dim, input_dim, 1, bias=False)

        s_net = []
        for i in range(n_layers - 1):
            out_dim = max(embed_dim, input_dim // 2)
            s_net.append(
                nn.Sequential(
                    nn.Conv1d(input_dim, out_dim, 1),
                    nn.ReLU(inplace=True),
                )
            )
            input_dim = out_dim
        s_net.append(nn.Conv1d(input_dim, embed_dim, 1))
        self.s_net = nn.Sequential(*s_net)

        self.s_fc = nn.Linear(embed_dim ** 2, embed_dim ** 2)

    def _vectorized_covariance(self, x):
        cov = torch.bmm(x, x.transpose(2, 1)) / x.size(-1)
        cov = cov.flatten(1)
        return cov

    def get_content_matrix(self, c):
        '''
        Args:
            c: content feature [N,input_dim,S]
        Return:
            attn: [N,S,embed_dim,embed_dim]
            normalized_c: [N,input_dim,S]
        '''
        normalized_c = self.IN(c)
        # normalized_c = torch.nn.functional.instance_norm(c)
        q_embed = self.q_embed(normalized_c)
        k_embed = self.k_embed(normalized_c)

        c_cov = q_embed.transpose(1,2).unsqueeze(3) * k_embed.transpose(1,2).unsqueeze(2)  # [N,S,embed_dim,embed_dim]
        attn = torch.softmax(c_cov, -1)  # [N,S,embed_dim,embed_dim]

        return attn, normalized_c

    def get_style_mean_std_matrix(self, s):
        '''
        Args:
            s: style feature [N,input_dim,S]

        Return:
            s_mean: [N,input_dim,1]
            s_std: [N,input_dim,1]
            s_mat: [N,embed_dim,embed_dim]
        '''
        s_mean = s.mean(-1, keepdim=True)
        s_std = s.std(-1, keepdim=True)
        s = s - s_mean

        s_embed = self.s_net(s)
        s_cov = self._vectorized_covariance(s_embed)
        s_mat = self.s_fc(s_cov)
        s_mat = s_mat.reshape(-1, self.embed_dim, self.embed_dim)

        return s_mean, s_std, s_mat

    def transform_content_3D(self, c):
        '''
        Args:
            c: content feature [N,input_dim,S]
        Return:
            transformed_c: [N,embed_dim,S]
        '''
        attn, normalized_c = self.get_content_matrix(c)  # [N,S,embed_dim,embed_dim]
        c = self.v_embed(normalized_c)  # [N,embed_dim,S]
        c = c.transpose(1,2).unsqueeze(3)  # [N,S,embed_dim,1]
        c = torch.matmul(attn, c).squeeze(3)  # [N,S,embed_dim]

        return c.transpose(1,2)

    def transfer_style_2D(self, s_mean_std_mat, c, acc_map):
        '''
        Args:
            s_mean_std_mat: (s_mean, s_std, s_mat) as returned by get_style_mean_std_matrix
                s_mean: [N,input_dim,1]
                s_std: [N,input_dim,1]
                s_mat: style matrix [N,embed_dim,embed_dim]
            c: content feature map after volume rendering [N,embed_dim,S]
            acc_map: [S]
        '''
        s_mean, s_std, s_mat = s_mean_std_mat

        cs = torch.bmm(s_mat, c)  # [N,embed_dim,S]
        cs = self.unzipper(cs)  # [N,input_dim,S]

        cs = cs * s_std + s_mean * acc_map[None,None,...]

        return cs


class AdaAttN(nn.Module):
    """ Attention-weighted AdaIN (Liu et al., ICCV 21) """

    def __init__(self, qk_dim, v_dim):
        """
        Args:
            qk_dim (int): query and key size.
            v_dim (int): value size.
        """
        super(AdaAttN, self).__init__()

        self.q_embed = nn.Conv1d(qk_dim, qk_dim, 1)
        self.k_embed = nn.Conv1d(qk_dim, qk_dim, 1)
        self.s_embed = nn.Conv1d(v_dim, v_dim, 1)

    def forward(self, q, k):
        """
        Args:
            q (float tensor, (bs, qk, *)): query (content) features.
            k (float tensor, (bs, qk, *)): key (style) features.
            c (float tensor, (bs, v, *)): content value features.
            s (float tensor, (bs, v, *)): style value features.

        Returns:
            cs (float tensor, (bs, v, *)): stylized content features.
""" c, s = q, k shape = c.shape q, k = q.flatten(2), k.flatten(2) c, s = c.flatten(2), s.flatten(2) # QKV attention with projected content and style features q = self.q_embed(F.instance_norm(q)).transpose(2, 1) # (bs, n, qk) k = self.k_embed(F.instance_norm(k)) # (bs, qk, m) s = self.s_embed(s).transpose(2, 1) # (bs, m, v) attn = F.softmax(torch.bmm(q, k), -1) # (bs, n, m) # attention-weighted channel-wise statistics mean = torch.bmm(attn, s) # (bs, n, v) var = F.relu(torch.bmm(attn, s ** 2) - mean ** 2) # (bs, n, v) mean = mean.transpose(2, 1) # (bs, v, n) std = torch.sqrt(var).transpose(2, 1) # (bs, v, n) cs = F.instance_norm(c) * std + mean # (bs, v, n) cs = cs.reshape(shape) return cs class AdaAttN_new_IN(nn.Module): """ Attention-weighted AdaIN (Liu et al., ICCV 21) """ def __init__(self, qk_dim, v_dim): """ Args: qk_dim (int): query and key size. v_dim (int): value size. """ super(AdaAttN_new_IN, self).__init__() self.q_embed = nn.Conv1d(qk_dim, qk_dim, 1) self.k_embed = nn.Conv1d(qk_dim, qk_dim, 1) self.s_embed = nn.Conv1d(v_dim, v_dim, 1) self.IN = LearnableIN(qk_dim) def forward(self, q, k): """ Args: q (float tensor, (bs, qk, *)): query (content) features. k (float tensor, (bs, qk, *)): key (style) features. c (float tensor, (bs, v, *)): content value features. s (float tensor, (bs, v, *)): style value features. Returns: cs (float tensor, (bs, v, *)): stylized content features. 
""" c, s = q, k shape = c.shape q, k = q.flatten(2), k.flatten(2) c, s = c.flatten(2), s.flatten(2) # QKV attention with projected content and style features q = self.q_embed(self.IN(q)).transpose(2, 1) # (bs, n, qk) k = self.k_embed(F.instance_norm(k)) # (bs, qk, m) s = self.s_embed(s).transpose(2, 1) # (bs, m, v) attn = F.softmax(torch.bmm(q, k), -1) # (bs, n, m) # attention-weighted channel-wise statistics mean = torch.bmm(attn, s) # (bs, n, v) var = F.relu(torch.bmm(attn, s ** 2) - mean ** 2) # (bs, n, v) mean = mean.transpose(2, 1) # (bs, v, n) std = torch.sqrt(var).transpose(2, 1) # (bs, v, n) cs = self.IN(c) * std + mean # (bs, v, n) cs = cs.reshape(shape) return cs class AdaAttN_woin(nn.Module): """ Attention-weighted AdaIN (Liu et al., ICCV 21) """ def __init__(self, qk_dim, v_dim): """ Args: qk_dim (int): query and key size. v_dim (int): value size. """ super().__init__() self.q_embed = nn.Conv1d(qk_dim, qk_dim, 1) self.k_embed = nn.Conv1d(qk_dim, qk_dim, 1) self.s_embed = nn.Conv1d(v_dim, v_dim, 1) def forward(self, q, k): """ Args: q (float tensor, (bs, qk, *)): query (content) features. k (float tensor, (bs, qk, *)): key (style) features. c (float tensor, (bs, v, *)): content value features. s (float tensor, (bs, v, *)): style value features. Returns: cs (float tensor, (bs, v, *)): stylized content features. 
""" c, s = q, k shape = c.shape q, k = q.flatten(2), k.flatten(2) c, s = c.flatten(2), s.flatten(2) # QKV attention with projected content and style features q = self.q_embed(q).transpose(2, 1) # (bs, n, qk) k = self.k_embed(k) # (bs, qk, m) s = self.s_embed(s).transpose(2, 1) # (bs, m, v) attn = F.softmax(torch.bmm(q, k), -1) # (bs, n, m) # attention-weighted channel-wise statistics mean = torch.bmm(attn, s) # (bs, n, v) var = F.relu(torch.bmm(attn, s ** 2) - mean ** 2) # (bs, n, v) mean = mean.transpose(2, 1) # (bs, v, n) std = torch.sqrt(var).transpose(2, 1) # (bs, v, n) cs = c * std + mean # (bs, v, n) cs = cs.reshape(shape) return cs ================================================ FILE: models/tensoRF.py ================================================ from .tensorBase import * from .VGG import Encoder, Decoder, UNetDecoder, PlainDecoder from .styleModules import LearnableIN, SimpleLinearStylizer class TensorVMSplit(TensorBase): def __init__(self, aabb, gridSize, device, **kargs): super(TensorVMSplit, self).__init__(aabb, gridSize, device, **kargs) def change_to_feature_mod(self, feature_n_comp, device): self.density_line.requires_grad_(False) self.density_plane.requires_grad_(False) self.app_line = None self.app_plane = None self.basis_mat = None self.renderModule = None # Both encoder and decoder do not require grad when initialized self.encoder = Encoder().to(device) self.decoder = PlainDecoder().to(device) # We need to finetune decoder when training a feature grid self.decoder.requires_grad_(True) self.feature_n_comp = feature_n_comp self.init_feature_svd(device) def change_to_style_mod(self, device='cuda'): assert self.feature_line is not None, 'Have to be trained in feature mod first!' 
self.feature_line.requires_grad_(False) self.feature_plane.requires_grad_(False) self.feature_basis_mat.requires_grad_(False) self.decoder.requires_grad_(True) self.IN = LearnableIN().to(device) # self.stylizer = NearestFeatureTransform() # self.stylizer = LearnableIN(1,256,device) # self.stylizer = AdaAttN(256, 256).to(device) # self.stylizer = AdaAttN_woin(256, 256).to(device) self.stylizer = SimpleLinearStylizer(256).to(device) def init_svd_volume(self, res, device): self.density_plane, self.density_line = self.init_one_svd(self.density_n_comp, self.gridSize, 0.1, device) self.app_plane, self.app_line = self.init_one_svd(self.app_n_comp, self.gridSize, 0.1, device) self.basis_mat = torch.nn.Linear(sum(self.app_n_comp), self.app_dim, bias=False).to(device) def init_feature_svd(self, device): self.feature_plane, self.feature_line = self.init_one_svd(self.feature_n_comp, self.gridSize, 0.1, device) self.feature_basis_mat = torch.nn.Linear(sum(self.feature_n_comp), 256, bias=False).to(device) def init_one_svd(self, n_component, gridSize, scale, device): plane_coef, line_coef = [], [] for i in range(len(self.vecMode)): vec_id = self.vecMode[i] mat_id_0, mat_id_1 = self.matMode[i] plane_coef.append(torch.nn.Parameter( scale * torch.randn((1, n_component[i], gridSize[mat_id_1], gridSize[mat_id_0])))) line_coef.append( torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[vec_id], 1)))) return torch.nn.ParameterList(plane_coef).to(device), torch.nn.ParameterList(line_coef).to(device) def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001): grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz}, {'params': self.density_plane, 'lr': lr_init_spatialxyz}, {'params': self.app_line, 'lr': lr_init_spatialxyz}, {'params': self.app_plane, 'lr': lr_init_spatialxyz}, {'params': self.basis_mat.parameters(), 'lr':lr_init_network}] if isinstance(self.renderModule, torch.nn.Module): grad_vars += 
[{'params':self.renderModule.parameters(), 'lr':lr_init_network}]
        return grad_vars

    def get_optparam_groups_feature_mod(self, lr_init_spatialxyz, lr_init_network):
        grad_vars = [{'params': self.feature_line, 'lr': lr_init_spatialxyz}, {'params': self.feature_plane, 'lr': lr_init_spatialxyz},
                     {'params': self.feature_basis_mat.parameters(), 'lr':lr_init_network},
                     {'params': self.decoder.parameters(), 'lr':lr_init_network}]
        return grad_vars

    def get_optparam_groups_style_mod(self, lr_init_network, lr_finetune):
        grad_vars = [
            {'params': self.stylizer.parameters(), 'lr': lr_init_network},
            {'params': self.decoder.parameters(), 'lr': lr_finetune},
        ]
        return grad_vars

    def vectorDiffs(self, vector_comps):
        total = 0
        for idx in range(len(vector_comps)):
            n_comp, n_size = vector_comps[idx].shape[1:-1]
            dotp = torch.matmul(vector_comps[idx].view(n_comp,n_size), vector_comps[idx].view(n_comp,n_size).transpose(-1,-2))
            non_diagonal = dotp.view(-1)[1:].view(n_comp-1, n_comp+1)[...,:-1]
            total = total + torch.mean(torch.abs(non_diagonal))
        return total

    def vector_comp_diffs(self):
        return self.vectorDiffs(self.density_line) + self.vectorDiffs(self.app_line)

    def density_L1(self):
        total = 0
        for idx in range(len(self.density_plane)):
            total = total + torch.mean(torch.abs(self.density_plane[idx])) + torch.mean(torch.abs(self.density_line[idx]))# + torch.mean(torch.abs(self.app_plane[idx])) + torch.mean(torch.abs(self.density_plane[idx]))
        return total

    def TV_loss_density(self, reg):
        total = 0
        for idx in range(len(self.density_plane)):
            total = total + reg(self.density_plane[idx]) * 1e-2 #+ reg(self.density_line[idx]) * 1e-3
        return total

    def TV_loss_app(self, reg):
        total = 0
        for idx in range(len(self.app_plane)):
            total = total + reg(self.app_plane[idx]) * 1e-2 #+ reg(self.app_line[idx]) * 1e-3
        return total

    def TV_loss_feature(self, reg):
        total = 0
        for idx in range(len(self.feature_plane)):
            total = total + reg(self.feature_plane[idx]) + reg(self.feature_line[idx])
        return total

    def compute_densityfeature(self, xyz_sampled):

        # plane + line basis
        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)

        sigma_feature = torch.zeros((xyz_sampled.shape[0],), device=xyz_sampled.device)
        for idx_plane in range(len(self.density_plane)):
            plane_coef_point = F.grid_sample(self.density_plane[idx_plane], coordinate_plane[[idx_plane]],
                                             align_corners=True).view(-1, *xyz_sampled.shape[:1])
            line_coef_point = F.grid_sample(self.density_line[idx_plane], coordinate_line[[idx_plane]],
                                            align_corners=True).view(-1, *xyz_sampled.shape[:1])
            sigma_feature = sigma_feature + torch.sum(plane_coef_point * line_coef_point, dim=0)

        return sigma_feature

    def compute_appfeature(self, xyz_sampled):

        # plane + line basis
        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)

        plane_coef_point,line_coef_point = [],[]
        for idx_plane in range(len(self.app_plane)):
            plane_coef_point.append(F.grid_sample(self.app_plane[idx_plane], coordinate_plane[[idx_plane]],
                                                  align_corners=True).view(-1, *xyz_sampled.shape[:1]))
            line_coef_point.append(F.grid_sample(self.app_line[idx_plane], coordinate_line[[idx_plane]],
                                                 align_corners=True).view(-1, *xyz_sampled.shape[:1]))
        plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point)

        return self.basis_mat((plane_coef_point * line_coef_point).T)

    def compute_feature(self, xyz_sampled):

        # plane + line basis
        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)

        plane_coef_point,line_coef_point = [],[]
        for idx_plane in range(len(self.feature_plane)):
            plane_coef_point.append(F.grid_sample(self.feature_plane[idx_plane], coordinate_plane[[idx_plane]],
                                                  align_corners=True).view(-1, *xyz_sampled.shape[:1]))
            line_coef_point.append(F.grid_sample(self.feature_line[idx_plane], coordinate_line[[idx_plane]],
                                                 align_corners=True).view(-1, *xyz_sampled.shape[:1]))
        plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point)

        return self.feature_basis_mat((plane_coef_point * line_coef_point).T)

    def render_feature_map(self, rays_chunk, s_mean_std_mat=None, is_train=False, ndc_ray=False, N_samples=-1):

        # sample points
        viewdirs = rays_chunk[:, 3:6]
        if ndc_ray:
            xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
            rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
            dists = dists * rays_norm
        else:
            xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)

        if self.alphaMask is not None:
            alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])
            alpha_mask = alphas > 0
            ray_invalid = ~ray_valid
            ray_invalid[ray_valid] |= (~alpha_mask)
            ray_valid = ~ray_invalid

        sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
        if s_mean_std_mat is not None:
            features = torch.zeros((*xyz_sampled.shape[:2], self.stylizer.embed_dim), device=xyz_sampled.device)
        else:
            features = torch.zeros((*xyz_sampled.shape[:2], 256), device=xyz_sampled.device)

        if ray_valid.any():
            xyz_sampled = self.normalize_coord(xyz_sampled)
            sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])
            validsigma = self.feature2density(sigma_feature)
            sigma[ray_valid] = validsigma

        alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)
        app_mask = weight > self.rayMarch_weight_thres

        if app_mask.any():
            valid_features = self.compute_feature(xyz_sampled[app_mask]) # [n_valid_points~40k if not specify nSamples, C=256]
            # transform content on 3d
            if s_mean_std_mat is not None:
                valid_features = self.stylizer.transform_content_3D(valid_features.transpose(0,1)[None,...])
                valid_features = valid_features.squeeze(0).transpose(0,1)
            features[app_mask] = valid_features

        feature_map = torch.sum(weight[..., None] * features, -2)
        acc_map = torch.sum(weight, -1)

        # style transfer on 2d
        if s_mean_std_mat is not None:
            feature_map = self.stylizer.transfer_style_2D(s_mean_std_mat, feature_map.transpose(0,1)[None,...], acc_map)
            feature_map = feature_map.squeeze().transpose(0,1)

        return feature_map, acc_map

    def render_depth_map(self, rays_chunk, is_train=False, ndc_ray=False, N_samples=-1):

        # sample points
        viewdirs = rays_chunk[:, 3:6]
        if ndc_ray:
            xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
            rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
            dists = dists * rays_norm
        else:
            xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)

        if self.alphaMask is not None:
            alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])
            alpha_mask = alphas > 0
            ray_invalid = ~ray_valid
            ray_invalid[ray_valid] |= (~alpha_mask)
            ray_valid = ~ray_invalid

        sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)

        if ray_valid.any():
            xyz_sampled = self.normalize_coord(xyz_sampled)
            sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])
            validsigma = self.feature2density(sigma_feature)
            sigma[ray_valid] = validsigma

        alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)
        acc_map = torch.sum(weight, -1)
        depth_map = torch.sum(weight * z_vals, -1)
        depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1]

        return depth_map # [n_rays]

    @torch.no_grad()
    def up_sampling_VM(self, plane_coef, line_coef, res_target):

        for i in range(len(self.vecMode)):
            vec_id = self.vecMode[i]
            mat_id_0, mat_id_1 = self.matMode[i]
            plane_coef[i] = torch.nn.Parameter(
                F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear',
                              align_corners=True))
            line_coef[i] = torch.nn.Parameter(
                F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))

        return plane_coef, line_coef

    @torch.no_grad()
    def upsample_volume_grid(self, res_target):
        self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target)
        self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target)

        self.update_stepSize(res_target)
        print(f'upsampling to {res_target}')

    @torch.no_grad()
    def shrink(self, new_aabb):
        print("====> shrinking ...")
        xyz_min, xyz_max = new_aabb
        t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units
        # print(new_aabb, self.aabb)
        # print(t_l, b_r,self.alphaMask.alpha_volume.shape)
        t_l, b_r = torch.round(torch.round(t_l)).long(), torch.round(b_r).long() + 1
        b_r = torch.stack([b_r, self.gridSize]).amin(0)

        for i in range(len(self.vecMode)):
            mode0 = self.vecMode[i]
            self.density_line[i] = torch.nn.Parameter(
                self.density_line[i].data[...,t_l[mode0]:b_r[mode0],:]
            )
            self.app_line[i] = torch.nn.Parameter(
                self.app_line[i].data[...,t_l[mode0]:b_r[mode0],:]
            )
            mode0, mode1 = self.matMode[i]
            self.density_plane[i] = torch.nn.Parameter(
                self.density_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]]
            )
            self.app_plane[i] = torch.nn.Parameter(
                self.app_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]]
            )

        if not torch.all(self.alphaMask.gridSize == self.gridSize):
            t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1)
            correct_aabb = torch.zeros_like(new_aabb)
            correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1]
            correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1]
            print("aabb", new_aabb, "\ncorrect aabb", correct_aabb)
            new_aabb = correct_aabb

        newSize = b_r - t_l
        self.aabb = new_aabb
        self.update_stepSize((newSize[0], newSize[1], newSize[2]))



================================================
FILE: models/tensorBase.py
================================================
import torch
import torch.nn
import torch.nn.functional as F
from .sh import eval_sh_bases
import numpy as np
import time


def positional_encoding(positions, freqs):
    freq_bands = (2**torch.arange(freqs).float()).to(positions.device)  # (F,)
    pts = (positions[..., None] * freq_bands).reshape(
        positions.shape[:-1] + (freqs * positions.shape[-1], ))  # (..., DF)
    pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1)
    return pts

def raw2alpha(sigma, dist):
    # sigma, dist  [N_rays, N_samples]
    alpha = 1. - torch.exp(-sigma*dist)

    T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1)

    weights = alpha * T[:, :-1]  # [N_rays, N_samples]
    return alpha, weights, T[:,-1:]


def SHRender(xyz_sampled, viewdirs, features):
    sh_mult = eval_sh_bases(2, viewdirs)[:, None]
    rgb_sh = features.view(-1, 3, sh_mult.shape[-1])
    rgb = torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5)
    return rgb


def RGBRender(xyz_sampled, viewdirs, features):
    rgb = features
    return rgb

class AlphaGridMask(torch.nn.Module):
    def __init__(self, device, aabb, alpha_volume):
        super(AlphaGridMask, self).__init__()
        self.device = device

        self.aabb=aabb.to(self.device)
        self.aabbSize = self.aabb[1] - self.aabb[0]
        self.invgridSize = 1.0/self.aabbSize * 2
        self.alpha_volume = alpha_volume.view(1,1,*alpha_volume.shape[-3:])
        self.gridSize = torch.LongTensor([alpha_volume.shape[-1],alpha_volume.shape[-2],alpha_volume.shape[-3]]).to(self.device)

    def sample_alpha(self, xyz_sampled):
        xyz_sampled = self.normalize_coord(xyz_sampled)
        alpha_vals = F.grid_sample(self.alpha_volume, xyz_sampled.view(1,-1,1,1,3), align_corners=True).view(-1)

        return alpha_vals

    def normalize_coord(self, xyz_sampled):
        return (xyz_sampled-self.aabb[0]) * self.invgridSize - 1


class MLPRender_Fea(torch.nn.Module):
    def __init__(self,inChanel, viewpe=6, feape=6, featureC=128):
        super(MLPRender_Fea, self).__init__()

        self.in_mlpC = 2*viewpe*3 + 2*feape*inChanel + 3 + inChanel
        self.viewpe = viewpe
        self.feape = feape
        layer1 = torch.nn.Linear(self.in_mlpC, featureC)
        layer2 = torch.nn.Linear(featureC, featureC)
        layer3 = torch.nn.Linear(featureC,3)

        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
        torch.nn.init.constant_(self.mlp[-1].bias, 0)

    def forward(self, pts, viewdirs, features):
        indata = [features, viewdirs]
        if self.feape > 0:
            indata += [positional_encoding(features, self.feape)]
        if self.viewpe > 0:
            indata += [positional_encoding(viewdirs, self.viewpe)]
        mlp_in = torch.cat(indata, dim=-1)
        rgb = self.mlp(mlp_in)
        rgb = torch.sigmoid(rgb)

        return rgb
class MLPRender_PE(torch.nn.Module):
    def __init__(self,inChanel, viewpe=6, pospe=6, featureC=128):
        super(MLPRender_PE, self).__init__()

        self.in_mlpC = (3+2*viewpe*3)+ (3+2*pospe*3) + inChanel
        self.viewpe = viewpe
        self.pospe = pospe
        layer1 = torch.nn.Linear(self.in_mlpC, featureC)
        layer2 = torch.nn.Linear(featureC, featureC)
        layer3 = torch.nn.Linear(featureC,3)

        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
        torch.nn.init.constant_(self.mlp[-1].bias, 0)

    def forward(self, pts, viewdirs, features):
        indata = [features, viewdirs]
        if self.pospe > 0:
            indata += [positional_encoding(pts, self.pospe)]
        if self.viewpe > 0:
            indata += [positional_encoding(viewdirs, self.viewpe)]
        mlp_in = torch.cat(indata, dim=-1)
        rgb = self.mlp(mlp_in)
        rgb = torch.sigmoid(rgb)

        return rgb

class MLPRender(torch.nn.Module):
    def __init__(self,inChanel, viewpe=6, featureC=128):
        super(MLPRender, self).__init__()

        self.in_mlpC = (3+2*viewpe*3) + inChanel
        self.viewpe = viewpe

        layer1 = torch.nn.Linear(self.in_mlpC, featureC)
        layer2 = torch.nn.Linear(featureC, featureC)
        layer3 = torch.nn.Linear(featureC,3)

        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
        torch.nn.init.constant_(self.mlp[-1].bias, 0)

    def forward(self, pts, viewdirs, features):
        indata = [features, viewdirs]
        if self.viewpe > 0:
            indata += [positional_encoding(viewdirs, self.viewpe)]
        mlp_in = torch.cat(indata, dim=-1)
        rgb = self.mlp(mlp_in)
        rgb = torch.sigmoid(rgb)

        return rgb



class TensorBase(torch.nn.Module):
    def __init__(self, aabb, gridSize, device, density_n_comp = 8, appearance_n_comp = 24, app_dim = 27,
                    shadingMode = 'MLP_PE', alphaMask = None, near_far=[2.0,6.0],
                    density_shift = -10, alphaMask_thres=0.001, distance_scale=25, rayMarch_weight_thres=0.0001,
                    pos_pe = 6, view_pe = 6, fea_pe = 6, featureC=128, step_ratio=2.0,
                    fea2denseAct = 'softplus'):
        super(TensorBase, self).__init__()

        self.density_n_comp = density_n_comp
        self.app_n_comp = appearance_n_comp
        self.app_dim = app_dim
        self.aabb = aabb
        self.alphaMask = alphaMask
        self.device=device

        self.density_shift = density_shift
        self.alphaMask_thres = alphaMask_thres
        self.distance_scale = distance_scale
        self.rayMarch_weight_thres = rayMarch_weight_thres
        self.fea2denseAct = fea2denseAct

        self.near_far = near_far
        self.step_ratio = step_ratio

        self.update_stepSize(gridSize)

        self.matMode = [[0,1], [0,2], [1,2]]
        self.vecMode = [2, 1, 0]
        self.comp_w = [1,1,1]

        self.init_svd_volume(gridSize[0], device)

        self.shadingMode, self.pos_pe, self.view_pe, self.fea_pe, self.featureC = shadingMode, pos_pe, view_pe, fea_pe, featureC
        self.init_render_func(shadingMode, pos_pe, view_pe, fea_pe, featureC, device)

    def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC, device):
        if shadingMode == 'MLP_PE':
            self.renderModule = MLPRender_PE(self.app_dim, view_pe, pos_pe, featureC).to(device)
        elif shadingMode == 'MLP_Fea':
            self.renderModule = MLPRender_Fea(self.app_dim, view_pe, fea_pe, featureC).to(device)
        elif shadingMode == 'MLP':
            self.renderModule = MLPRender(self.app_dim, view_pe, featureC).to(device)
        elif shadingMode == 'SH':
            self.renderModule = SHRender
        elif shadingMode == 'RGB':
            assert self.app_dim == 3
            self.renderModule = RGBRender
        else:
            print("Unrecognized shading module")
            exit()
        print("pos_pe", pos_pe, "view_pe", view_pe, "fea_pe", fea_pe)

    def update_stepSize(self, gridSize):
        print("aabb", self.aabb.view(-1))
        print("grid size", gridSize)
        self.aabbSize = self.aabb[1] - self.aabb[0]
        self.invaabbSize = 2.0/self.aabbSize
        self.gridSize= torch.LongTensor(gridSize).to(self.device)
        self.units=self.aabbSize / (self.gridSize-1)
        self.stepSize=torch.mean(self.units)*self.step_ratio
        self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize)))
        self.nSamples=int((self.aabbDiag / self.stepSize).item()) + 1
        print("sampling step size: ", self.stepSize)
        print("sampling number: ", self.nSamples)

    def init_svd_volume(self, res, device):
        pass

    def compute_features(self, xyz_sampled):
        pass

    def compute_densityfeature(self, xyz_sampled):
        pass

    def compute_appfeature(self, xyz_sampled):
        pass

    def normalize_coord(self, xyz_sampled):
        return (xyz_sampled-self.aabb[0]) * self.invaabbSize - 1

    def get_optparam_groups(self, lr_init_spatial = 0.02, lr_init_network = 0.001):
        pass

    def get_kwargs(self):
        return {
            'aabb': self.aabb,
            'gridSize':self.gridSize.tolist(),
            'density_n_comp': self.density_n_comp,
            'appearance_n_comp': self.app_n_comp,
            'app_dim': self.app_dim,

            'density_shift': self.density_shift,
            'alphaMask_thres': self.alphaMask_thres,
            'distance_scale': self.distance_scale,
            'rayMarch_weight_thres': self.rayMarch_weight_thres,
            'fea2denseAct': self.fea2denseAct,

            'near_far': self.near_far,
            'step_ratio': self.step_ratio,

            'shadingMode': self.shadingMode,
            'pos_pe': self.pos_pe,
            'view_pe': self.view_pe,
            'fea_pe': self.fea_pe,
            'featureC': self.featureC
        }

    def save(self, path):
        kwargs = self.get_kwargs()
        ckpt = {'kwargs': kwargs, 'state_dict': self.state_dict()}
        if self.alphaMask is not None:
            alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy()
            ckpt.update({'alphaMask.shape':alpha_volume.shape})
            ckpt.update({'alphaMask.mask':np.packbits(alpha_volume.reshape(-1))})
            ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()})
        torch.save(ckpt, path)

    def load(self, ckpt):
        if 'alphaMask.aabb' in ckpt.keys():
            length = np.prod(ckpt['alphaMask.shape'])
            alpha_volume = torch.from_numpy(np.unpackbits(ckpt['alphaMask.mask'])[:length].reshape(ckpt['alphaMask.shape']))
            self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device), alpha_volume.float().to(self.device))
        self.load_state_dict(ckpt['state_dict'])

    def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
        N_samples = N_samples if N_samples > 0 else self.nSamples
        near, far = self.near_far
        interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o)
        if is_train:
            interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples)

        rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]
        mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1)
        return rays_pts, interpx, ~mask_outbbox

    def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1):
        N_samples = N_samples if N_samples>0 else self.nSamples
        stepsize = self.stepSize
        near, far = self.near_far
        vec = torch.where(rays_d==0, torch.full_like(rays_d, 1e-6), rays_d)
        rate_a = (self.aabb[1] - rays_o) / vec
        rate_b = (self.aabb[0] - rays_o) / vec
        t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far)

        rng = torch.arange(N_samples)[None].float()
        if is_train:
            rng = rng.repeat(rays_d.shape[-2],1)
            rng += torch.rand_like(rng[:,[0]])
        step = stepsize * rng.to(rays_o.device)
        interpx = (t_min[...,None] + step)

        rays_pts = rays_o[...,None,:] + rays_d[...,None,:] * interpx[...,None]
        mask_outbbox = ((self.aabb[0]>rays_pts) | (rays_pts>self.aabb[1])).any(dim=-1)

        return rays_pts, interpx, ~mask_outbbox

    def shrink(self, new_aabb, voxel_size):
        pass

    @torch.no_grad()
    def getDenseAlpha(self,gridSize=None):

        gridSize = self.gridSize if gridSize is None else gridSize

        samples = torch.stack(torch.meshgrid(
            torch.linspace(0, 1, gridSize[0]),
            torch.linspace(0, 1, gridSize[1]),
            torch.linspace(0, 1, gridSize[2]),
        ), -1).to(self.device)
        dense_xyz = self.aabb[0] * (1-samples) + self.aabb[1] * samples

        # dense_xyz = dense_xyz
        # print(self.stepSize, self.distance_scale*self.aabbDiag)
        alpha = torch.zeros_like(dense_xyz[...,0])
        for i in range(gridSize[0]):
            alpha[i] = self.compute_alpha(dense_xyz[i].view(-1,3), self.stepSize).view((gridSize[1], gridSize[2]))
        return alpha, dense_xyz

    @torch.no_grad()
    def updateAlphaMask(self, gridSize=(200,200,200)):

        alpha, dense_xyz = self.getDenseAlpha(gridSize)
        dense_xyz = dense_xyz.transpose(0,2).contiguous()
        alpha = alpha.clamp(0,1).transpose(0,2).contiguous()[None,None]
        total_voxels = gridSize[0] * gridSize[1] * gridSize[2]

        ks = 3
        alpha = F.max_pool3d(alpha, kernel_size=ks, padding=ks // 2, stride=1).view(gridSize[::-1])
        alpha[alpha>=self.alphaMask_thres] = 1
        alpha[alpha<self.alphaMask_thres] = 0

        self.alphaMask = AlphaGridMask(self.device, self.aabb, alpha)

        valid_xyz = dense_xyz[alpha>0.5]

        xyz_min = valid_xyz.amin(0)
        xyz_max = valid_xyz.amax(0)

        new_aabb = torch.stack((xyz_min, xyz_max))

        total = torch.sum(alpha)
        print(f"bbox: {xyz_min, xyz_max} alpha rest %%%f"%(total/total_voxels*100))
        return new_aabb

    @torch.no_grad()
    def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=10240*5, bbox_only=False):
        print('========> filtering rays ...')
        tt = time.time()

        N = torch.tensor(all_rays.shape[:-1]).prod()

        mask_filtered = []
        idx_chunks = torch.split(torch.arange(N), chunk)
        for idx_chunk in idx_chunks:
            rays_chunk = all_rays[idx_chunk].to(self.device)

            rays_o, rays_d = rays_chunk[..., :3], rays_chunk[..., 3:6]
            if bbox_only:
                vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
                rate_a = (self.aabb[1] - rays_o) / vec
                rate_b = (self.aabb[0] - rays_o) / vec
                t_min = torch.minimum(rate_a, rate_b).amax(-1)#.clamp(min=near, max=far)
                t_max = torch.maximum(rate_a, rate_b).amin(-1)#.clamp(min=near, max=far)
                mask_inbbox = t_max > t_min

            else:
                xyz_sampled, _,_ = self.sample_ray(rays_o, rays_d, N_samples=N_samples, is_train=False)
                mask_inbbox= (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1)

            mask_filtered.append(mask_inbbox.cpu())

        mask_filtered = torch.cat(mask_filtered).view(all_rgbs.shape[:-1])

        print(f'Ray filtering done! takes {time.time()-tt} s. ray mask ratio: {torch.sum(mask_filtered) / N}')
        return all_rays[mask_filtered], all_rgbs[mask_filtered]

    def feature2density(self, density_features):
        if self.fea2denseAct == "softplus":
            return F.softplus(density_features+self.density_shift)
        elif self.fea2denseAct == "relu":
            return F.relu(density_features)

    def compute_alpha(self, xyz_locs, length=1):

        if self.alphaMask is not None:
            alphas = self.alphaMask.sample_alpha(xyz_locs)
            alpha_mask = alphas > 0
        else:
            alpha_mask = torch.ones_like(xyz_locs[:,0], dtype=bool)

        sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device)

        if alpha_mask.any():
            xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask])
            sigma_feature = self.compute_densityfeature(xyz_sampled)
            validsigma = self.feature2density(sigma_feature)
            sigma[alpha_mask] = validsigma

        alpha = 1 - torch.exp(-sigma*length).view(xyz_locs.shape[:-1])

        return alpha

    def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):

        # sample points
        viewdirs = rays_chunk[:, 3:6]
        if ndc_ray:
            xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
            rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
            dists = dists * rays_norm
            viewdirs = viewdirs / rays_norm
        else:
            xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
        viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape)

        if self.alphaMask is not None:
            alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])
            alpha_mask = alphas > 0
            ray_invalid = ~ray_valid
            ray_invalid[ray_valid] |= (~alpha_mask)
            ray_valid = ~ray_invalid

        sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
        rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device)

        if ray_valid.any():
            xyz_sampled = self.normalize_coord(xyz_sampled)
            sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])

            validsigma = self.feature2density(sigma_feature)
            sigma[ray_valid] = validsigma

        alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)

        app_mask = weight > self.rayMarch_weight_thres

        if app_mask.any():
            app_features = self.compute_appfeature(xyz_sampled[app_mask])
            valid_rgbs = self.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features)
            rgb[app_mask] = valid_rgbs

        acc_map = torch.sum(weight, -1)
        rgb_map = torch.sum(weight[..., None] * rgb, -2)

        if white_bg or (is_train and torch.rand((1,))<0.5):
            rgb_map = rgb_map + (1. - acc_map[..., None])

        rgb_map = rgb_map.clamp(0,1)

        with torch.no_grad():
            depth_map = torch.sum(weight * z_vals, -1)
            depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1]

        return rgb_map, depth_map # rgb, sigma, alpha, weight, bg_weight



================================================
FILE: opt.py
================================================
import configargparse

def config_parser(cmd=None):
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True,
                        help='config file path')
    parser.add_argument("--expname", type=str,
                        help='experiment name')
    parser.add_argument("--basedir", type=str, default='./log',
                        help='where to store ckpts and logs')
    parser.add_argument("--add_timestamp", type=int, default=0,
                        help='add timestamp to dir')
    parser.add_argument("--datadir", type=str, default='./data/llff/fern',
                        help='input data directory')
    parser.add_argument("--wikiartdir", type=str, default='./data/WikiArt',
                        help='WikiArt style image directory')
    parser.add_argument("--progress_refresh_rate", type=int, default=10,
                        help='how many iterations to show psnrs or iters')

    parser.add_argument('--with_depth', action='store_true')
    parser.add_argument('--downsample_train', type=float, default=1.0)
    parser.add_argument('--downsample_test', type=float, default=1.0)

    parser.add_argument('--model_name', type=str, default='TensorVMSplit',
                        choices=['TensorVMSplit', 'TensorCP'])

    # loader options
    parser.add_argument("--batch_size", type=int, default=4096)
    parser.add_argument("--n_iters", type=int, default=30000)

    parser.add_argument('--dataset_name', type=str, default='blender',
                        choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data'])

    # training options
    parser.add_argument("--patch_size", type=int, default=256,
                        help='patch_size for training')
    parser.add_argument("--chunk_size", type=int, default=4096,
                        help='chunk_size for training')
    # learning rate
    parser.add_argument("--lr_init", type=float, default=0.02,
                        help='learning rate')
    parser.add_argument("--lr_basis", type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument("--lr_finetune", type=float, default=1e-5,
                        help='learning rate')
    parser.add_argument("--lr_decay_iters", type=int, default=-1,
                        help = 'number of iterations the lr will decay to the target ratio; -1 will set it to n_iters')
    parser.add_argument("--lr_decay_target_ratio", type=float, default=0.1,
                        help='the target decay ratio; after decay_iters the initial lr decays to lr*ratio')
    parser.add_argument("--lr_upsample_reset", type=int, default=1,
                        help='reset lr to initial after upsampling')

    # loss
    parser.add_argument("--L1_weight_inital", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--L1_weight_rest", type=float, default=0,
                        help='loss weight')
    parser.add_argument("--Ortho_weight", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--TV_weight_density", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--TV_weight_app", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--TV_weight_feature", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--style_weight", type=float, default=0,
                        help='loss weight')
    parser.add_argument("--content_weight", type=float, default=0,
                        help='loss weight')
    parser.add_argument("--image_tv_weight", type=float, default=0,
                        help='loss weight')
    parser.add_argument("--featuremap_tv_weight", type=float, default=0,
                        help='loss weight')

    # model
    # volume options
    parser.add_argument("--n_lamb_sigma", type=int, action="append")
    parser.add_argument("--n_lamb_sh", type=int, action="append")
    parser.add_argument("--data_dim_color", type=int, default=27)

    parser.add_argument("--rm_weight_mask_thre", type=float, default=0.0001,
                        help='mask points in ray marching')
    parser.add_argument("--alpha_mask_thre", type=float, default=0.0001,
                        help='threshold for creating alpha mask volume')
    parser.add_argument("--distance_scale", type=float, default=25,
                        help='scaling sampling distance for computation')
    parser.add_argument("--density_shift", type=float, default=-10,
                        help='shift density in softplus; making density = 0 when feature == 0')

    # network decoder
    parser.add_argument("--shadingMode", type=str, default="MLP_Fea",
                        help='which shading mode to use')
    parser.add_argument("--pos_pe", type=int, default=6,
                        help='number of pe for pos')
    parser.add_argument("--view_pe", type=int, default=6,
                        help='number of pe for view')
    parser.add_argument("--fea_pe", type=int, default=6,
                        help='number of pe for features')
    parser.add_argument("--featureC", type=int, default=128,
                        help='hidden feature channel in MLP')

    # test option
    parser.add_argument("--ckpt", type=str, default=None,
                        help='specific weights npy file to reload for coarse network')
    parser.add_argument("--render_only", type=int, default=0)
    parser.add_argument("--render_test", type=int, default=0)
    parser.add_argument("--render_train", type=int, default=0)
    parser.add_argument("--render_path", type=int, default=0)
    parser.add_argument("--export_mesh", type=int, default=0)
    parser.add_argument("--style_img", type=str, required=False)

    # rendering options
    parser.add_argument('--lindisp', default=False, action="store_true",
                        help='use disparity depth sampling')
    parser.add_argument("--perturb", type=float, default=1.,
                        help='set to 0. for no jitter, 1. for jitter')
    parser.add_argument("--accumulate_decay", type=float, default=0.998)
    parser.add_argument("--fea2denseAct", type=str, default='softplus')
    parser.add_argument('--ndc_ray', type=int, default=0)
    parser.add_argument('--nSamples', type=int, default=1e6,
                        help='sample point each ray, pass 1e6 if automatic adjust')
    parser.add_argument('--step_ratio',type=float,default=0.5)

    ## blender flags
    parser.add_argument("--white_bkgd", action='store_true',
                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')

    parser.add_argument('--N_voxel_init', type=int, default=100**3)
    parser.add_argument('--N_voxel_final', type=int, default=300**3)
    parser.add_argument("--upsamp_list", type=int, action="append")
    parser.add_argument("--update_AlphaMask_list", type=int, action="append")

    parser.add_argument('--idx_view', type=int, default=0)
    # logging/saving options
    parser.add_argument("--N_vis", type=int, default=5,
                        help='N images to vis')
    parser.add_argument("--vis_every", type=int, default=10000,
                        help='frequency of visualize the image')
    if cmd is not None:
        return parser.parse_args(cmd)
    else:
        return parser.parse_args()



================================================
FILE: renderer.py
================================================
import torch,os,imageio,sys
from tqdm.auto import tqdm
from dataLoader.ray_utils import get_rays
from models.tensoRF import raw2alpha, TensorVMSplit, AlphaGridMask
from utils import *
from dataLoader.ray_utils import ndc_rays_blender, denormalize_vgg, normalize_vgg


def OctreeRender_trilinear_fast(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, white_bg=True, is_train=False, render_feature=False, style_img=None, device='cuda'):

    rgbs, alphas, depth_maps, weights, uncertainties = [], [], [], [], []
    features, accs = [], []
    s_mean_std_mat = None
    if style_img is not None:
        with torch.no_grad():
            style_feature = tensorf.encoder(normalize_vgg(style_img))
            s_mean_std_mat = tensorf.stylizer.get_style_mean_std_matrix(style_feature.relu3_1.flatten(2))

    N_rays_all = rays.shape[0]
    for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)):
        rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device)
        if render_feature:
            feature_map, acc_map = tensorf.render_feature_map(rays_chunk, s_mean_std_mat=s_mean_std_mat, is_train=is_train, ndc_ray=ndc_ray, N_samples=N_samples)
            features.append(feature_map)
            accs.append(acc_map)
        else:
            rgb_map, depth_map = tensorf(rays_chunk, is_train=is_train, white_bg=white_bg, ndc_ray=ndc_ray, N_samples=N_samples)
            rgbs.append(rgb_map)
            depth_maps.append(depth_map)

    if render_feature:
        if style_img is not None:
            return torch.cat(features), torch.cat(accs), style_feature
        return torch.cat(features), torch.cat(accs)
    return torch.cat(rgbs), None, torch.cat(depth_maps), None, None

def OctreeRender_trilinear_fast_depth(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, is_train=False, device='cuda'):

    depth_maps = []
    N_rays_all = rays.shape[0]
    for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)):
        rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device)
        depth_map = tensorf.render_depth_map(rays_chunk, is_train=is_train, ndc_ray=ndc_ray, N_samples=N_samples)
        depth_maps.append(depth_map)

    return torch.cat(depth_maps)


@torch.no_grad()
def evaluation_feature(test_dataset, tensorf, args, renderer, chunk_size=2048, savePath=None, N_vis=10, prtx='', N_samples=-1,
                       white_bg=False, ndc_ray=False, compute_extra_metrics=False, style_img=None, device='cuda'):
    '''
    To see if the decoded feature map is similar to gt rgb map
    '''
    PSNRs, rgb_maps, vis_feature_maps = [], [], []
    ssims,l_alex,l_vgg=[],[],[]
    os.makedirs(savePath, exist_ok=True)
    os.makedirs(savePath + '/feature', exist_ok=True)
    W, H = test_dataset.img_wh

    try:
        tqdm._instances.clear()
    except Exception:
        pass

    near_far = test_dataset.near_far
    img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays_stack.shape[0] // N_vis,1)
    idxs = list(range(0, test_dataset.all_rays_stack.shape[0], img_eval_interval))
    for idx, samples in tqdm(enumerate(test_dataset.all_rays_stack[0::img_eval_interval]), file=sys.stdout):

        rays = samples.view(-1,samples.shape[-1])
        if style_img is None:
            feature_map, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples,
                                      ndc_ray=ndc_ray, white_bg = white_bg, render_feature=True, device=device)
        else:
            feature_map, _, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples,
                                         ndc_ray=ndc_ray, white_bg = white_bg, render_feature=True, style_img=style_img, device=device)

        feature_map = feature_map.reshape(H, W, 256)[None,...].permute(0,3,1,2)

        recon_rgb = denormalize_vgg(tensorf.decoder(feature_map))
        recon_rgb = recon_rgb.permute(0,2,3,1).clamp(0,1)
        vis_feature_map = torch.sigmoid(feature_map[:, [1,2,3], :, :].permute(0,2,3,1))

        if test_dataset.white_bg:
            mask = test_dataset.all_masks[idx:idx+1].to(device)
            recon_rgb = mask*recon_rgb + (1.-mask)
            vis_feature_map = mask*vis_feature_map + (1.-mask)

        recon_rgb = recon_rgb.reshape(H, W, 3).cpu()
        vis_feature_map = vis_feature_map.squeeze().cpu()

        if len(test_dataset.all_rgbs_stack):
            gt_rgb = test_dataset.all_rgbs_stack[idxs[idx]].view(H, W, 3)
            loss = torch.mean((recon_rgb - gt_rgb) ** 2)
            PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0))

            if compute_extra_metrics:
                ssim = rgb_ssim(recon_rgb, gt_rgb, 1)
                l_a = rgb_lpips(gt_rgb.numpy(), recon_rgb.numpy(), 'alex', tensorf.device)
                l_v = rgb_lpips(gt_rgb.numpy(), recon_rgb.numpy(), 'vgg', tensorf.device)
                ssims.append(ssim)
                l_alex.append(l_a)
                l_vgg.append(l_v)

        recon_rgb = (recon_rgb.numpy() * 255).astype('uint8')
        vis_feature_map = (vis_feature_map.numpy() * 255).astype('uint8')
        gt_rgb = (gt_rgb.numpy() * 255).astype('uint8')
        if savePath is not None:
            # rgb_map = np.concatenate((recon_rgb, gt_rgb), axis=1)
            rgb_maps.append(recon_rgb)
            vis_feature_maps.append(vis_feature_map)
            imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', recon_rgb)
            imageio.imwrite(f'{savePath}/feature/feature_{idx:03d}.png', vis_feature_map)

    imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)
    imageio.mimwrite(f'{savePath}/feature/feature_video.mp4', np.stack(vis_feature_maps), fps=30, quality=8)

    if PSNRs:
        psnr = np.mean(np.asarray(PSNRs))
        if compute_extra_metrics:
            ssim = np.mean(np.asarray(ssims))
            l_a = np.mean(np.asarray(l_alex))
            l_v = np.mean(np.asarray(l_vgg))
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
        else:
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))

    return PSNRs


@torch.no_grad()
def evaluation_feature_path(test_dataset, tensorf, c2ws, renderer, chunk_size=2048, savePath=None, N_vis=5, prtx='', N_samples=-1, white_bg=False, ndc_ray=False, compute_extra_metrics=False, style_img=None, device='cuda'):
    PSNRs, rgb_maps, vis_feature_maps, depth_maps = [], [], [], []
    ssims,l_alex,l_vgg=[],[],[]
    os.makedirs(savePath, exist_ok=True)
    os.makedirs(savePath + '/feature', exist_ok=True)
    W, H = test_dataset.img_wh

    try:
        tqdm._instances.clear()
    except Exception:
        pass

    near_far = test_dataset.near_far
    for idx, c2w in tqdm(enumerate(c2ws)):

        c2w = torch.FloatTensor(c2w)
        rays_o, rays_d = get_rays(test_dataset.directions, c2w)  # both (h*w, 3)
        if ndc_ray:
            rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)
        rays = torch.cat([rays_o, rays_d], 1).reshape(H, W, 6).permute(2,0,1)  # (6,H,W)
        rays = rays.permute(1,2,0).reshape(-1,6)  # (H * W, 6)

        if style_img is None:
            feature_map, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples, ndc_ray=ndc_ray, white_bg = white_bg, render_feature=True, device=device)
        else:
            feature_map, _, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples, ndc_ray=ndc_ray, white_bg = white_bg, render_feature=True, style_img=style_img, device=device)

        feature_map = feature_map.reshape(H, W, 256)[None,...].permute(0,3,1,2)

        recon_rgb = denormalize_vgg(tensorf.decoder(feature_map))
        recon_rgb = recon_rgb.permute(0,2,3,1).clamp(0,1)
        recon_rgb = recon_rgb.reshape(H, W, 3).cpu()
        recon_rgb = (recon_rgb.numpy() * 255).astype('uint8')
        rgb_maps.append(recon_rgb)

        vis_feature_map = torch.sigmoid(feature_map[:, [1,2,3], :, :].permute(0,2,3,1))
        vis_feature_map = vis_feature_map.squeeze().cpu()
        vis_feature_map = (vis_feature_map.numpy() * 255).astype('uint8')
        vis_feature_maps.append(vis_feature_map)

        if savePath is not None:
            imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', recon_rgb)
            imageio.imwrite(f'{savePath}/feature/feature_{idx:03d}.png', vis_feature_map)

    imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)
    imageio.mimwrite(f'{savePath}/feature/feature_video.mp4', np.stack(vis_feature_maps), fps=30, quality=8)

    if PSNRs:
        psnr = np.mean(np.asarray(PSNRs))
        if compute_extra_metrics:
            ssim = np.mean(np.asarray(ssims))
            l_a = np.mean(np.asarray(l_alex))
            l_v = np.mean(np.asarray(l_vgg))
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
        else:
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))

    return PSNRs


@torch.no_grad()
def evaluation(test_dataset,tensorf, args, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1, white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
    PSNRs, rgb_maps, depth_maps = [], [], []
    ssims,l_alex,l_vgg=[],[],[]
    os.makedirs(savePath, exist_ok=True)
    os.makedirs(savePath+"/rgbd", exist_ok=True)

    try:
        tqdm._instances.clear()
    except Exception:
        pass

    near_far = test_dataset.near_far
    img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays_stack.shape[0] // N_vis,1)
    idxs = list(range(0, test_dataset.all_rays_stack.shape[0], img_eval_interval))
    for idx, samples in tqdm(enumerate(test_dataset.all_rays_stack[0::img_eval_interval]), file=sys.stdout):

        W, H = test_dataset.img_wh
        rays = samples.view(-1,samples.shape[-1])

        rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=4096, N_samples=N_samples, ndc_ray=ndc_ray, white_bg = white_bg, device=device)
        rgb_map = rgb_map.clamp(0.0, 1.0)

        rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()

        depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)
        if len(test_dataset.all_rgbs_stack):
            gt_rgb = test_dataset.all_rgbs_stack[idxs[idx]].view(H, W, 3)
            loss = torch.mean((rgb_map - gt_rgb) ** 2)
            PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0))

            if compute_extra_metrics:
                ssim = rgb_ssim(rgb_map, gt_rgb, 1)
                l_a = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'alex', tensorf.device)
                l_v = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'vgg', tensorf.device)
                ssims.append(ssim)
                l_alex.append(l_a)
                l_vgg.append(l_v)

        rgb_map = (rgb_map.numpy() * 255).astype('uint8')
        # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
        rgb_maps.append(rgb_map)
        depth_maps.append(depth_map)
        if savePath is not None:
            imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
            # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
            imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', depth_map)

    imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=10)
    imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=10)

    if PSNRs:
        psnr = np.mean(np.asarray(PSNRs))
        if compute_extra_metrics:
            ssim = np.mean(np.asarray(ssims))
            l_a = np.mean(np.asarray(l_alex))
            l_v = np.mean(np.asarray(l_vgg))
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
        else:
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))

    return PSNRs


@torch.no_grad()
def evaluation_path(test_dataset,tensorf, c2ws, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1, white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
    PSNRs, rgb_maps, depth_maps = [], [], []
    ssims,l_alex,l_vgg=[],[],[]
    os.makedirs(savePath, exist_ok=True)
    os.makedirs(savePath+"/rgbd", exist_ok=True)

    try:
        tqdm._instances.clear()
    except Exception:
        pass

    near_far = test_dataset.near_far
    for idx, c2w in tqdm(enumerate(c2ws)):

        W, H = test_dataset.img_wh

        c2w = torch.FloatTensor(c2w)
        rays_o, rays_d = get_rays(test_dataset.directions, c2w)  # both (h*w, 3)
        if ndc_ray:
            rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)
        rays = torch.cat([rays_o, rays_d], 1)  # (h*w, 6)

        rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=8192, N_samples=N_samples, ndc_ray=ndc_ray, white_bg = white_bg, device=device)
        rgb_map = rgb_map.clamp(0.0, 1.0)

        rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()

        depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)

        rgb_map = (rgb_map.numpy() * 255).astype('uint8')
        # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
        rgb_maps.append(rgb_map)
        depth_maps.append(depth_map)
        if savePath is not None:
            imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
            rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
            imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)

    imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)
    imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=8)

    if PSNRs:
        psnr = np.mean(np.asarray(PSNRs))
        if compute_extra_metrics:
            ssim = np.mean(np.asarray(ssims))
            l_a = np.mean(np.asarray(l_alex))
            l_v = np.mean(np.asarray(l_vgg))
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
        else:
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))

    return PSNRs


================================================
FILE: scripts/test.sh
================================================
CUDA_VISIBLE_DEVICES=$1 python train.py \
 --config configs/llff.txt \
 --ckpt log/trex/trex.th \
 --render_only 1 --render_test 1


================================================
FILE: scripts/test_feature.sh
================================================
expname=trex
CUDA_VISIBLE_DEVICES=$1 python train_feature.py \
 --config configs/llff_feature.txt \
 --datadir ./data/nerf_llff_data/trex \
 --expname $expname \
 --ckpt ./log_feature/$expname/$expname.th \
 --render_only 1 \
 --render_test 0 \
 --render_path 1 \
 --chunk_size 1024


================================================
FILE: scripts/test_style.sh
================================================
expname=trex
CUDA_VISIBLE_DEVICES=$1 python train_style.py \
 --config configs/llff_style.txt \
 --datadir ./data/nerf_llff_data/trex \
 --expname $expname \
 --ckpt log_style/$expname/$expname.th \
 --style_img path/to/reference/style/image \
 --render_only 1 \
 --render_train 0 \
 --render_test 0 \
 --render_path 1 \
 --chunk_size 1024 \
 --rm_weight_mask_thre 0.0001


================================================
FILE: scripts/train.sh
================================================
CUDA_VISIBLE_DEVICES=$1 python train.py --config=configs/llff.txt


================================================
FILE: scripts/train_feature.sh
================================================
CUDA_VISIBLE_DEVICES=$1 python train_feature.py --config=configs/llff_feature.txt


================================================
FILE: scripts/train_style.sh
================================================
CUDA_VISIBLE_DEVICES=$1 python train_style.py --config=configs/llff_style.txt


================================================
FILE: train.py
================================================

import os
from tqdm.auto import tqdm
from opt import config_parser

import json, random
from renderer import *
from utils import *
from torch.utils.tensorboard import SummaryWriter
import datetime

from dataLoader import dataset_dict
import sys


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

renderer = OctreeRender_trilinear_fast


class SimpleSampler:
    def __init__(self, total, batch):
        self.total = total
        self.batch = batch
        self.curr = total
        self.ids = None

    def nextids(self):
        self.curr+=self.batch
        if self.curr + self.batch > self.total:
            self.ids = torch.LongTensor(np.random.permutation(self.total))
            self.curr = 0
        return self.ids[self.curr:self.curr+self.batch]


@torch.no_grad()
def export_mesh(args):

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.load(ckpt)

    alpha,_ = tensorf.getDenseAlpha()
    convert_sdf_samples_to_ply(alpha.cpu(), f'{args.ckpt[:-3]}.ply',bbox=tensorf.aabb.cpu(), level=0.005)


@torch.no_grad()
def render_test(args):
    # init dataset
    dataset = dataset_dict[args.dataset_name]
    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
    white_bg = test_dataset.white_bg
    ndc_ray = args.ndc_ray

    if not os.path.exists(args.ckpt):
        print('the ckpt path does not exist!!')
        return

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.load(ckpt)

    logfolder = os.path.dirname(args.ckpt)
    if args.render_train:
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
        PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/', N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')

    if args.render_test:
        os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
        evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/', N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)

    if args.render_path:
        c2ws = test_dataset.render_path
        os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)
        evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_path_all/', N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)


def reconstruction(args):

    # init dataset
    dataset = dataset_dict[args.dataset_name]
    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False)
    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
    white_bg = train_dataset.white_bg
    near_far = train_dataset.near_far
    ndc_ray = args.ndc_ray

    # init resolution
    upsamp_list = args.upsamp_list
    update_AlphaMask_list = args.update_AlphaMask_list
    n_lamb_sigma = args.n_lamb_sigma
    n_lamb_sh = args.n_lamb_sh

    if args.add_timestamp:
        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
    else:
        logfolder = f'{args.basedir}/{args.expname}'

    # init log file
    os.makedirs(logfolder, exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
    os.makedirs(f'{logfolder}/rgba', exist_ok=True)
    summary_writer = SummaryWriter(logfolder)

    # init parameters
    # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])
    aabb = train_dataset.scene_bbox.to(device)
    reso_cur = N_to_reso(args.N_voxel_init, aabb)
    nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))

    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt, map_location=device)
        kwargs = ckpt['kwargs']
        kwargs.update({'device':device})
        tensorf = eval(args.model_name)(**kwargs)
        tensorf.load(ckpt)
    else:
        tensorf = eval(args.model_name)(aabb, reso_cur, device,
                    density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh, app_dim=args.data_dim_color, near_far=near_far,
                    shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre, density_shift=args.density_shift, distance_scale=args.distance_scale,
                    pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe, featureC=args.featureC, step_ratio=args.step_ratio, fea2denseAct=args.fea2denseAct)

    grad_vars = tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
    if args.lr_decay_iters > 0:
        lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)
    else:
        args.lr_decay_iters = args.n_iters
        lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)

    print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)

    optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))

    # linear in logarithmic space
    if upsamp_list is not None:
        N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:]

    torch.cuda.empty_cache()
    PSNRs,PSNRs_test = [],[0]

    allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs
    if not args.ndc_ray:
        allrays, allrgbs = tensorf.filtering_rays(allrays, allrgbs, bbox_only=True)
    trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)

    Ortho_reg_weight = args.Ortho_weight
    print("initial Ortho_reg_weight", Ortho_reg_weight)

    L1_reg_weight = args.L1_weight_inital
    print("initial L1_reg_weight", L1_reg_weight)
    TV_weight_density, TV_weight_app = args.TV_weight_density, args.TV_weight_app
    tvreg = TVLoss()
    print(f"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app}")

    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
    for iteration in pbar:

        ray_idx = trainingSampler.nextids()
        rays_train, rgb_train = allrays[ray_idx], allrgbs[ray_idx].to(device)

        #rgb_map, alphas_map, depth_map, weights, uncertainty
        rgb_map, alphas_map, depth_map, weights, uncertainty = renderer(rays_train, tensorf, chunk=args.batch_size,
                                N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, device=device, is_train=True)

        loss = torch.mean((rgb_map - rgb_train) ** 2)

        # loss
        total_loss = loss
        if Ortho_reg_weight > 0:
            loss_reg = tensorf.vector_comp_diffs()
            total_loss += Ortho_reg_weight*loss_reg
            summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)
        if L1_reg_weight > 0:
            loss_reg_L1 = tensorf.density_L1()
            total_loss += L1_reg_weight*loss_reg_L1
            summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)

        if TV_weight_density>0:
            TV_weight_density *= lr_factor
            loss_tv = tensorf.TV_loss_density(tvreg) * TV_weight_density
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)
        if TV_weight_app>0:
            TV_weight_app *= lr_factor
            loss_tv = tensorf.TV_loss_app(tvreg)*TV_weight_app
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        loss = loss.detach().item()

        PSNRs.append(-10.0 * np.log(loss) / np.log(10.0))
        summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
        summary_writer.add_scalar('train/mse', loss, global_step=iteration)

        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * lr_factor

        # Print the current values of the losses.
        if iteration % args.progress_refresh_rate == 0:
            pbar.set_description(
                f'Iteration {iteration:05d}:'
                + f' train_psnr = {float(np.mean(PSNRs)):.2f}'
                + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'
                + f' mse = {loss:.6f}'
            )
            PSNRs = []

        if iteration % args.vis_every == args.vis_every - 1 and args.N_vis!=0:
            PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_vis/', N_vis=args.N_vis,
                                    prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, compute_extra_metrics=False)
            summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)

        if update_AlphaMask_list is not None and iteration in update_AlphaMask_list:

            if reso_cur[0] * reso_cur[1] * reso_cur[2]<256**3:  # update volume resolution
                reso_mask = reso_cur
            new_aabb = tensorf.updateAlphaMask(tuple(reso_mask))
            if iteration == update_AlphaMask_list[0]:
                tensorf.shrink(new_aabb)
                # tensorVM.alphaMask = None
                L1_reg_weight = args.L1_weight_rest
                print("continuing L1_reg_weight", L1_reg_weight)

            if not args.ndc_ray and iteration == update_AlphaMask_list[1]:
                # filter rays outside the bbox
                allrays,allrgbs = tensorf.filtering_rays(allrays,allrgbs)
                trainingSampler = SimpleSampler(allrgbs.shape[0], args.batch_size)

        if upsamp_list is not None and iteration in upsamp_list:
            n_voxels = N_voxel_list.pop(0)
            reso_cur = N_to_reso(n_voxels, tensorf.aabb)
            nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
            tensorf.upsample_volume_grid(reso_cur)

            if args.lr_upsample_reset:
                print("reset lr to initial")
                lr_scale = 1  #0.1 ** (iteration / args.n_iters)
            else:
                lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
            grad_vars = tensorf.get_optparam_groups(args.lr_init*lr_scale, args.lr_basis*lr_scale)
            optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))

    tensorf.save(f'{logfolder}/{args.expname}.th')

    if args.render_train:
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
        PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/', N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')

    if args.render_test:
        os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
        PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_test_all/', N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
        summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')

    if args.render_path:
        c2ws = test_dataset.render_path
        # c2ws = test_dataset.poses
        print('========>',c2ws.shape)
        os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
        evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/imgs_path_all/', N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)


if __name__ == '__main__':

    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20211202)
    np.random.seed(20211202)

    args = config_parser()
    print(args)

    if args.export_mesh:
        export_mesh(args)

    if args.render_only and (args.render_test or args.render_path):
        render_test(args)
    else:
        reconstruction(args)


================================================
FILE: train_feature.py
================================================

import os
from unittest.mock import patch
from tqdm.auto import tqdm
from opt import config_parser

import json, random
from renderer import *
from utils import *
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
import torchvision.transforms.functional as TF
import datetime

from dataLoader import dataset_dict
import sys


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

renderer = OctreeRender_trilinear_fast


class SimpleSampler:
    def __init__(self, total, batch):
        self.total = total
        self.batch = batch
        self.curr = total
        self.ids = None

    def nextids(self):
        self.curr+=self.batch
        if self.curr + self.batch > self.total:
            self.ids = torch.LongTensor(np.random.permutation(self.total))
            self.curr = 0
        return self.ids[self.curr:self.curr+self.batch]


def InfiniteSampler(n):
    # i = 0
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            np.random.seed()
            order = np.random.permutation(n)
            i = 0


class InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        return 2 ** 31


@torch.no_grad()
def render_test(args):
    # init dataset
    dataset = dataset_dict[args.dataset_name]
    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
    white_bg = test_dataset.white_bg
    ndc_ray = args.ndc_ray

    if not os.path.exists(args.ckpt):
        print('the ckpt path does not exist!!')
        return

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.change_to_feature_mod(args.n_lamb_sh ,device)
    tensorf.load(ckpt)
    tensorf.eval()
    tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre

    logfolder = os.path.dirname(args.ckpt)

    if args.render_train:
        os.makedirs(f'{logfolder}/{args.expname}/imgs_train_all', exist_ok=True)
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
        evaluation_feature(train_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_train_all/', N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)

    if args.render_test:
        os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
        evaluation_feature(test_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_test_all/', N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)

    if args.render_path:
        c2ws = test_dataset.render_path
        os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)
        evaluation_feature_path(test_dataset,tensorf, c2ws, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_path_all/', N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)


def reconstruction(args):
    # init dataset
    dataset = dataset_dict[args.dataset_name]
    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
    white_bg = train_dataset.white_bg
    near_far = train_dataset.near_far
    h_rays, w_rays = train_dataset.img_wh[1], train_dataset.img_wh[0]
    ndc_ray = args.ndc_ray

    patch_size = args.patch_size

    if args.add_timestamp:
        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
    else:
        logfolder = f'{args.basedir}/{args.expname}'

    # init log file
    os.makedirs(logfolder, exist_ok=True)
    summary_writer = SummaryWriter(logfolder)

    # init parameters
    # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])
    aabb = train_dataset.scene_bbox.to(device)
    # TODO: need to update reso_cur
    reso_cur = N_to_reso(args.N_voxel_init, aabb)
    nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))

    assert args.ckpt is not None, 'Have to be pre-trained to get the density field!'
    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device':device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.load(ckpt)
    tensorf.change_to_feature_mod(args.n_lamb_sh ,device)
    tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre

    train_dataset.prepare_feature_data(tensorf.encoder)

    grad_vars = tensorf.get_optparam_groups_feature_mod(args.lr_init, args.lr_basis)
    if args.lr_decay_iters > 0:
        lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)
    else:
        args.lr_decay_iters = args.n_iters
        lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)

    print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)

    optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))

    torch.cuda.empty_cache()
    PSNRs = []

    allrays, allfeatures = train_dataset.all_rays, train_dataset.all_features
    allrays_stack, allrgbs_stack = train_dataset.all_rays_stack, train_dataset.all_rgbs_stack
    if not args.ndc_ray:
        allrays, allfeatures = tensorf.filtering_rays(allrays, allfeatures, bbox_only=True)
    trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)
    frameSampler = iter(InfiniteSamplerWrapper(allrays_stack.size(0)))  # every next(sampler) returns a frame index

    TV_weight_feature = args.TV_weight_feature
    tvreg = TVLoss()
    print(f"initial TV_weight_feature: {TV_weight_feature}")

    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
    for iteration in pbar:
        feature_loss, pixel_loss = 0., 0.
        if iteration%2==0:
            ray_idx = trainingSampler.nextids()
            rays_train, features_train = allrays[ray_idx], allfeatures[ray_idx].to(device)

            feature_map, _ = renderer(rays_train, tensorf, chunk=args.chunk_size, N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, render_feature=True, device=device, is_train=True)

            feature_loss = torch.mean((feature_map - features_train) ** 2)
        else:
            frame_idx = next(frameSampler)
            start_h = np.random.randint(0, h_rays-patch_size+1)
            start_w = np.random.randint(0, w_rays-patch_size+1)
            if white_bg:
                # move random sampled patches into center
                mid_h, mid_w = (h_rays-patch_size+1)/2, (w_rays-patch_size+1)/2
                if mid_h-start_h>=1:
                    start_h += np.random.randint(0, mid_h-start_h)
                elif mid_h-start_h<=-1:
                    start_h += np.random.randint(mid_h-start_h, 0)
                if mid_w-start_w>=1:
                    start_w += np.random.randint(0, mid_w-start_w)
                elif mid_w-start_w<=-1:
                    start_w += np.random.randint(mid_w-start_w, 0)

            rays_train = allrays_stack[frame_idx, start_h:start_h+patch_size, start_w:start_w+patch_size, :].reshape(-1, 6).to(device)  # [patch*patch, 6]

            rgbs_train = allrgbs_stack[frame_idx, start_h:(start_h+patch_size), start_w:(start_w+patch_size), :].to(device)  # [patch, patch, 3]

            feature_map, _ = renderer(rays_train, tensorf, chunk=args.chunk_size, N_samples=nSamples, white_bg=white_bg, ndc_ray=ndc_ray, render_feature=True, device=device, is_train=True)

            feature_map = feature_map.reshape(patch_size, patch_size, 256)[None,...].permute(0,3,1,2)
            recon_rgb = tensorf.decoder(feature_map)

            rgbs_train = rgbs_train[None,...].permute(0,3,1,2)
            img_enc = tensorf.encoder(normalize_vgg(rgbs_train))
            recon_rgb_enc = tensorf.encoder(recon_rgb)

            feature_loss =(F.mse_loss(recon_rgb_enc.relu4_1, img_enc.relu4_1) + F.mse_loss(recon_rgb_enc.relu3_1, img_enc.relu3_1)) / 10

            recon_rgb = denormalize_vgg(recon_rgb)

            pixel_loss = torch.mean((recon_rgb - rgbs_train) ** 2)

        total_loss = pixel_loss + feature_loss

        # loss
        # NOTE: Calculate feature TV loss rather than appearance TV loss
        if TV_weight_feature>0:
            TV_weight_feature *= lr_factor
            loss_tv = tensorf.TV_loss_feature(tvreg)*TV_weight_feature
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_feature', loss_tv.detach().item(), global_step=iteration)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if pixel_loss == 0:
            feature_loss = feature_loss.detach().item()
            PSNRs.append(-10.0 * np.log(feature_loss) / np.log(10.0))
            summary_writer.add_scalar('train/PSNR_feature', PSNRs[-1], global_step=iteration)
            summary_writer.add_scalar('train/mse_feature', feature_loss, global_step=iteration)
        else:
            pixel_loss = pixel_loss.detach().item()
            PSNRs.append(-10.0 * np.log(pixel_loss) / np.log(10.0))
            summary_writer.add_scalar('train/PSNR_pixel', PSNRs[-1], global_step=iteration)
            summary_writer.add_scalar('train/mse_pixel', pixel_loss, global_step=iteration)
            summary_writer.add_scalar('train/mse_recon_feature', feature_loss.detach().item(), global_step=iteration)

        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * lr_factor

        # Print the current values of the losses.
        if iteration % args.progress_refresh_rate == 0:
            pbar.set_description(
                f'Iteration {iteration:05d}:'
                + f' psnr = {float(np.mean(PSNRs)):.2f}'
            )
            PSNRs = []

        if iteration % (args.progress_refresh_rate*20) == 1:
            summary_writer.add_image('output', make_grid([rgbs_train.squeeze(), recon_rgb.clamp(0, 1).squeeze()], nrow=2, padding=0, normalize=False),
                                     global_step=iteration)

    tensorf.save(f'{logfolder}/{args.expname}.th')


if __name__ == '__main__':

    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20211202)
    np.random.seed(20211202)

    args = config_parser()
    print(args)

    if args.render_only:
        render_test(args)
    else:
        reconstruction(args)


================================================
FILE: train_style.py
================================================

import os
from tqdm.auto import tqdm
from opt import config_parser
from PIL import Image, ImageFile
from pathlib import Path
from torchvision.utils import make_grid
import torchvision.transforms.functional as TF

from renderer import *
from utils import *
from torch.utils.tensorboard import SummaryWriter
import datetime

from dataLoader import dataset_dict
from dataLoader.styleLoader import getDataLoader
from models.styleModules import cal_mse_content_loss, cal_adain_style_loss
import sys


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

renderer = OctreeRender_trilinear_fast
depth_renderer = OctreeRender_trilinear_fast_depth


class SimpleSampler:
    def __init__(self, total, batch):
        self.total = total
        self.batch = batch
        self.curr = total
        self.ids = None

    def nextids(self):
        self.curr+=self.batch
        if self.curr + self.batch > self.total:
            self.ids = torch.LongTensor(np.random.permutation(self.total))
            self.curr = 0
        return self.ids[self.curr:self.curr+self.batch]


def InfiniteSampler(n):
    # i = 0
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            np.random.seed()
            order = np.random.permutation(n)
            i = 0


class InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        return 2 ** 31


@torch.no_grad()
def render_test(args):
    # init dataset
    dataset = dataset_dict[args.dataset_name]
    ndc_ray = args.ndc_ray

    if not os.path.exists(args.ckpt):
        print('the ckpt path does not exist!!')
        return

    assert args.style_img is not None, 'Must specify a style image!'

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.change_to_feature_mod(args.n_lamb_sh, device)
    tensorf.change_to_style_mod(device)
    tensorf.load(ckpt)
    tensorf.eval()
    tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre

    logfolder = os.path.dirname(args.ckpt)

    trans = T.Compose([T.Resize(size=(256,256)), T.ToTensor()])
    style_img = trans(Image.open(args.style_img)).cuda()[None, ...]
    style_name = Path(args.style_img).stem

    if args.render_train:
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
        os.makedirs(f'{logfolder}/{args.expname}/imgs_train_all/{style_name}', exist_ok=True)
        evaluation_feature(train_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_train_all/{style_name}',
                           N_vis=-1, N_samples=-1, white_bg = train_dataset.white_bg, ndc_ray=ndc_ray, style_img=style_img, device=device)

    if args.render_test:
        test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
        os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all/{style_name}', exist_ok=True)
        evaluation_feature(test_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_test_all/{style_name}',
                           N_vis=-1, N_samples=-1, white_bg = test_dataset.white_bg, ndc_ray=ndc_ray, style_img=style_img, device=device)

    if args.render_path:
        test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
        c2ws = test_dataset.render_path
        os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all/{style_name}', exist_ok=True)
        evaluation_feature_path(test_dataset, tensorf, c2ws, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_path_all/{style_name}',
                                N_vis=-1, N_samples=-1, white_bg = test_dataset.white_bg, ndc_ray=ndc_ray, style_img=style_img, device=device)


def reconstruction(args):
    # init dataset
    dataset = dataset_dict[args.dataset_name]
    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
    white_bg = train_dataset.white_bg
    near_far = train_dataset.near_far
    h_rays, w_rays = train_dataset.img_wh[1], train_dataset.img_wh[0]
    ndc_ray = args.ndc_ray

    patch_size = args.patch_size  # ground truth image patch size when training

    Image.MAX_IMAGE_PIXELS = None  # Disable DecompressionBombError
    ImageFile.LOAD_TRUNCATED_IMAGES = True  # Disable OSError: image file is truncated
    style_loader = getDataLoader(args.wikiartdir, batch_size=1, sampler=InfiniteSamplerWrapper,
                                 image_side_length=256, num_workers=2)
    style_iter = iter(style_loader)

    if args.add_timestamp:
        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
    else:
        logfolder = f'{args.basedir}/{args.expname}'

    # init log file
    os.makedirs(logfolder, exist_ok=True)
    summary_writer = SummaryWriter(logfolder)

    # init parameters
    # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])
    aabb = train_dataset.scene_bbox.to(device)
    # TODO: need to update reso_cur
    reso_cur = N_to_reso(args.N_voxel_init, aabb)
    nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))

    assert args.ckpt is not None, 'Have to be pre-trained to get the density field!'
    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.change_to_feature_mod(args.n_lamb_sh, device)
    tensorf.load(ckpt)
    tensorf.change_to_style_mod(device)
    tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre

    tvreg = TVLoss()

    grad_vars = tensorf.get_optparam_groups_style_mod(args.lr_basis, args.lr_finetune)

    if args.lr_decay_iters > 0:
        lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)
    else:
        args.lr_decay_iters = args.n_iters
        lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)

    print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)

    tensorf.train()
    optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))

    torch.cuda.empty_cache()

    allrays_stack, allrgbs_stack = train_dataset.all_rays_stack, train_dataset.all_rgbs_stack
    frameSampler = iter(InfiniteSamplerWrapper(allrays_stack.size(0)))  # every next(sampler) returns a frame index

    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
    for iteration in pbar:
        # get style_img; this style_img has NOT been normalized for the pretrained VGG model
        style_img = next(style_iter)[0].to(device)

        # randomly sample a patch_size*patch_size patch from the given frame
        frame_idx = next(frameSampler)
        start_h = np.random.randint(0, h_rays-patch_size+1)
        start_w = np.random.randint(0, w_rays-patch_size+1)
        if white_bg:
            # move randomly sampled patches towards the center
            mid_h, mid_w = (h_rays-patch_size+1)/2, (w_rays-patch_size+1)/2
            if mid_h-start_h >= 1:
                start_h += np.random.randint(0, mid_h-start_h)
            elif mid_h-start_h <= -1:
                start_h += np.random.randint(mid_h-start_h, 0)
            if mid_w-start_w >= 1:
                start_w += np.random.randint(0, mid_w-start_w)
            elif mid_w-start_w <= -1:
                start_w += np.random.randint(mid_w-start_w, 0)

        rays_train = allrays_stack[frame_idx, start_h:start_h+patch_size, start_w:start_w+patch_size, :]\
            .reshape(-1, 6).to(device)  # [patch*patch, 6]
        rgbs_train = allrgbs_stack[frame_idx, start_h:(start_h+patch_size), start_w:(start_w+patch_size), :]\
            .to(device)  # [patch, patch, 3]

        feature_map, acc_map, style_feature = renderer(rays_train, tensorf, chunk=args.chunk_size, N_samples=nSamples,
                                                       white_bg=white_bg, ndc_ray=ndc_ray, render_feature=True,
                                                       style_img=style_img, device=device, is_train=True)

        feature_map = feature_map.reshape(patch_size, patch_size, 256)[None,...].permute(0,3,1,2)
        rgb_map = tensorf.decoder(feature_map)
        # feature_map is trained with normalized rgb maps, so here we don't normalize the rgb map again.
        rgbs_train = normalize_vgg(rgbs_train[None,...].permute(0,3,1,2))

        out_image_feature = tensorf.encoder(rgb_map)
        content_feature = tensorf.encoder(rgbs_train)

        if white_bg:
            mask = acc_map.reshape(patch_size, patch_size, 1)[None,...].permute(0,3,1,2)
            if not (mask > 0.5).any():
                continue

            # content loss
            _mask = F.interpolate(mask, size=content_feature.relu4_1.size()[-2:], mode='bilinear').ge(1e-5)
            content_loss = cal_mse_content_loss(torch.masked_select(content_feature.relu4_1, _mask),
                                                torch.masked_select(out_image_feature.relu4_1, _mask))
            # style loss
            style_loss = 0.
            for style_feat, image_feature in zip(style_feature, out_image_feature):
                _mask = F.interpolate(mask, size=image_feature.size()[-2:], mode='bilinear').ge(1e-5)
                C = image_feature.size()[1]
                masked_img_feature = torch.masked_select(image_feature, _mask).reshape(1, C, -1)
                style_loss += cal_adain_style_loss(style_feat, masked_img_feature)

            content_loss *= args.content_weight
            style_loss *= args.style_weight
        else:
            # content loss
            content_loss = cal_mse_content_loss(content_feature.relu4_1, out_image_feature.relu4_1)
            # style loss
            style_loss = 0.
            for style_feat, image_feature in zip(style_feature, out_image_feature):
                style_loss += cal_adain_style_loss(style_feat, image_feature)
            content_loss *= args.content_weight
            style_loss *= args.style_weight

        feature_tv_loss = tvreg(feature_map) * args.featuremap_tv_weight
        image_tv_loss = tvreg(denormalize_vgg(rgb_map)) * args.image_tv_weight
        total_loss = content_loss + style_loss + feature_tv_loss + image_tv_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * lr_factor

        # Print the current values of the losses.
        if iteration % args.progress_refresh_rate == 0:
            summary_writer.add_scalar('train/content_loss', content_loss, global_step=iteration)
            summary_writer.add_scalar('train/style_loss', style_loss, global_step=iteration)
            summary_writer.add_scalar('train/feature_tv_loss', feature_tv_loss, global_step=iteration)
            summary_writer.add_scalar('train/image_tv_loss', image_tv_loss, global_step=iteration)

            pbar.set_description(
                f'Iteration {iteration:05d}:'
                + f' content_loss = {content_loss.item():.2f}'
                + f' style_loss = {style_loss.item():.2f}'
            )

        if iteration % (args.progress_refresh_rate*20) == 0:
            summary_writer.add_image('output', make_grid([denormalize_vgg(rgbs_train).squeeze(),
                                                          denormalize_vgg(rgb_map).clamp(0, 1).squeeze(),
                                                          TF.resize(style_img, (patch_size, patch_size)).squeeze()],
                                                         nrow=3, padding=0, normalize=False),
                                     global_step=iteration)

    tensorf.save(f'{logfolder}/{args.expname}.th')


if __name__ == '__main__':

    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20211202)
    np.random.seed(20211202)

    args = config_parser()
    print(args)

    if args.render_only:
        render_test(args)
    else:
        reconstruction(args)

================================================
FILE: utils.py
================================================
import cv2, torch
import numpy as np
from PIL import Image
import torchvision.transforms as T
import torch.nn.functional as F
import scipy.signal

mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.Tensor([10.]))


def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
    """
    depth: (H, W)
    """

    x = np.nan_to_num(depth)  # change nan to 0
    if minmax is None:
        mi = np.min(x[x>0])  # get minimum positive depth (ignore background)
        ma = np.max(x)
    else:
        mi, ma = minmax

    x = (x-mi)/(ma-mi+1e-8)  # normalize to 0~1
    x = (255*x).astype(np.uint8)
    x_ = cv2.applyColorMap(x, cmap)
    return x_, [mi, ma]


def init_log(log, keys):
    for key in keys:
        log[key] = torch.tensor([0.0], dtype=float)
    return log


def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
    """
    depth: (H, W)
    """
    if type(depth) is not np.ndarray:
        depth = depth.cpu().numpy()

    x = np.nan_to_num(depth)  # change nan to 0
    if minmax is None:
        mi = np.min(x[x>0])  # get minimum positive depth (ignore background)
        ma = np.max(x)
    else:
        mi, ma = minmax

    x = (x-mi)/(ma-mi+1e-8)  # normalize to 0~1
    x = (255*x).astype(np.uint8)
    x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
    x_ = T.ToTensor()(x_)  # (3, H, W)
    return x_, [mi, ma]


def N_to_reso(n_voxels, bbox):
    xyz_min, xyz_max = bbox
    dim = len(xyz_min)
    voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim)
    return ((xyz_max - xyz_min) / voxel_size).long().tolist()


def cal_n_samples(reso, step_ratio=0.5):
    return int(np.linalg.norm(reso)/step_ratio)


__LPIPS__ = {}


def init_lpips(net_name, device):
    assert net_name in ['alex', 'vgg']
    import lpips
    print(f'init_lpips: lpips_{net_name}')
    return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)


def rgb_lpips(np_gt, np_im, net_name, device):
    if net_name not in __LPIPS__:
        __LPIPS__[net_name] = init_lpips(net_name, device)
    gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
    im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
    return __LPIPS__[net_name](gt, im, normalize=True).item()


def findItem(items, target):
    for one in items:
        if one[:len(target)] == target:
            return one
    return None


''' Evaluation metrics (ssim, lpips)
'''
def rgb_ssim(img0, img1, max_val,
             filter_size=11,
             filter_sigma=1.5,
             k1=0.01,
             k2=0.03,
             return_map=False):
    # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
    assert len(img0.shape) == 3
    assert img0.shape[-1] == 3
    assert img0.shape == img1.shape

    # Construct a 1D Gaussian blur filter.
    hw = filter_size // 2
    shift = (2 * hw - filter_size + 1) / 2
    f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2
    filt = np.exp(-0.5 * f_i)
    filt /= np.sum(filt)

    # Blur in x and y (faster than the 2D convolution).
    def convolve2d(z, f):
        return scipy.signal.convolve2d(z, f, mode='valid')

    filt_fn = lambda z: np.stack([
        convolve2d(convolve2d(z[..., i], filt[:, None]), filt[None, :])
        for i in range(z.shape[-1])], -1)
    mu0 = filt_fn(img0)
    mu1 = filt_fn(img1)
    mu00 = mu0 * mu0
    mu11 = mu1 * mu1
    mu01 = mu0 * mu1
    sigma00 = filt_fn(img0**2) - mu00
    sigma11 = filt_fn(img1**2) - mu11
    sigma01 = filt_fn(img0 * img1) - mu01

    # Clip the variances and covariances to valid values.
    # Variance must be non-negative:
    sigma00 = np.maximum(0., sigma00)
    sigma11 = np.maximum(0., sigma11)
    sigma01 = np.sign(sigma01) * np.minimum(
        np.sqrt(sigma00 * sigma11), np.abs(sigma01))
    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2
    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
    ssim_map = numer / denom
    ssim = np.mean(ssim_map)
    return ssim_map if return_map else ssim


import torch.nn as nn


class TVLoss(nn.Module):
    def __init__(self):
        super(TVLoss, self).__init__()

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        if w_x == 1:
            count_h = self._tensor_size(x[:, :, 1:, :])
            h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x-1, :]), 2).sum()
            return 2*(h_tv/count_h)/batch_size
        if h_x == 1:
            count_w = self._tensor_size(x[:, :, :, 1:])
            w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x-1]), 2).sum()
            return 2*(w_tv/count_w)/batch_size
        count_h = self._tensor_size(x[:, :, 1:, :])
        count_w = self._tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x-1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x-1]), 2).sum()
        return 2*(h_tv/count_h + w_tv/count_w)/batch_size

    def _tensor_size(self, t):
        return t.size()[1]*t.size()[2]*t.size()[3]


import plyfile
import skimage.measure


def convert_sdf_samples_to_ply(
    pytorch_3d_sdf_tensor,
    ply_filename_out,
    bbox,
    level=0.5,
    offset=None,
    scale=None,
):
    """
    Convert sdf samples to .ply

    :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n)
    :param ply_filename_out: string, path of the filename to save to
    :param bbox: (2, 3) tensor/array holding the min and max corners of the voxel grid
    :param level: float, the iso-value at which the surface is extracted
    :param offset: optional offset subtracted from the mesh vertices (after scaling)
    :param scale: optional factor by which the mesh vertices are divided

    This function is adapted from: https://github.com/RobotLocomotion/spartan
    """

    numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()
    voxel_size = list((bbox[1]-bbox[0]) / np.array(pytorch_3d_sdf_tensor.shape))

    verts, faces, normals, values = skimage.measure.marching_cubes(
        numpy_3d_sdf_tensor, level=level, spacing=voxel_size
    )
    faces = faces[..., ::-1]  # inverse face orientation

    # transform from voxel coordinates to camera coordinates
    # note x and y are flipped in the output of marching_cubes
    mesh_points = np.zeros_like(verts)
    mesh_points[:, 0] = bbox[0, 0] + verts[:, 0]
    mesh_points[:, 1] = bbox[0, 1] + verts[:, 1]
    mesh_points[:, 2] = bbox[0, 2] + verts[:, 2]

    # apply additional offset and scale
    if scale is not None:
        mesh_points = mesh_points / scale
    if offset is not None:
        mesh_points = mesh_points - offset

    # try writing to the ply file

    num_verts = verts.shape[0]
    num_faces = faces.shape[0]

    verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])

    for i in range(0, num_verts):
        verts_tuple[i] = tuple(mesh_points[i, :])

    faces_building = []
    for i in range(0, num_faces):
        faces_building.append(((faces[i, :].tolist(),)))
    faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])

    el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
    el_faces = plyfile.PlyElement.describe(faces_tuple, "face")

    ply_data = plyfile.PlyData([el_verts, el_faces])
    print("saving mesh to %s" % (ply_filename_out))
    ply_data.write(ply_filename_out)


# Point cloud operations
# import pytorch3d
# import pytorch3d.transforms as transforms
# from pytorch3d.structures import Pointclouds
# from pytorch3d.renderer import (
#     look_at_view_transform,
#     FoVOrthographicCameras,
#     PointsRasterizationSettings,
#     PointsRenderer,
#     PointsRasterizer,
#     AlphaCompositor,
# )
# import imageio

# def construct_points_coordinates(rays, depth):
#     '''
#     Construct points' coordinates of a point cloud, every point corresponds to
#     a point on one ray with specified depth.
#     Args:
#         rays: [n_rays, 6]
#         depth: [n_rays]
#     Return:
#         point_cloud: [n_rays, 3]
#     '''
#     rays_o, rays_d = rays[:, :3], rays[:, 3:6]
#     points_coordinates = rays_o + rays_d * depth[..., None]
#     return points_coordinates

# def plot_point_cloud(points, rgbs, prtx=''):
#     '''
#     Args:
#         points: coordinates of points [n, 3]
#         rgbs: color of points [n, 3]
#     '''
#     point_cloud = Pointclouds(points=[points], features=[rgbs])
#     images = []
#     for degree in range(0, 60, 2):
#         R, T = look_at_view_transform(20, 10, degree)
#         cameras = FoVOrthographicCameras(device='cuda', R=R, T=T, znear=0.01)
#         raster_settings = PointsRasterizationSettings(
#             image_size=512,
#             radius = 0.003,
#             points_per_pixel = 10
#         )
#         rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
#         renderer = PointsRenderer(
#             rasterizer=rasterizer,
#             compositor=AlphaCompositor()
#         )
#         image = renderer(point_cloud).squeeze()
#         image = (image.detach().cpu().numpy() * 255).astype('uint8')
#         images.append(image)
#         # imageio.imwrite(f'./visualization/{prtx}pointcloud_{degree:02d}.png', image)
#     imageio.mimwrite(f'./visualization/{prtx}pointcloud_video.mp4', np.stack(images), fps=10, quality=8)

# def knn_loss(p1, f1, p2, f2, thres=0.0005):
#     '''
#     Args:
#         p1, p2: points [p1,3] [p2,3]
#         f1, f2: features [p1,c] [p2,c]
#         thres: if dist > thres then the two points are not adjacent
#     '''
#     dists, idx, _ = pytorch3d.ops.knn_points(p1[None,...], p2[None,...], K=1)
#     # [N=1, P1, K=1]
#     match_feature = f2[idx.squeeze()]  # [p1, C]
#     square_err = torch.sum((f1 - match_feature)**2, dim=-1)  # [p1]
#     return torch.masked_select(square_err, dists.squeeze()