Repository: jingsenzhu/i2-sdf
Branch: main
Commit: 58c9a8241feb
Files: 31
Total size: 217.1 KB

Directory structure:
gitextract_xvem6nqy/
├── .gitignore
├── DATA_CONVENTION.md
├── LICENSE
├── README.md
├── config/
│   ├── synthetic.yml
│   └── synthetic_light_mask.yml
├── data/
│   ├── normalize_cameras.py
│   └── npz_to_blender.py
├── dataset/
│   ├── __init__.py
│   ├── eval_dataset.py
│   └── train_dataset.py
├── environment.yml
├── i2-sdf-dataset-links.csv
├── main_recon.py
├── model/
│   ├── __init__.py
│   ├── eval/
│   │   ├── __init__.py
│   │   └── recon.py
│   ├── network/
│   │   ├── __init__.py
│   │   ├── density.py
│   │   ├── embedder.py
│   │   ├── mlp.py
│   │   └── ray_sampler.py
│   ├── rendering/
│   │   ├── __init__.py
│   │   └── brdf.py
│   └── trainer/
│       ├── __init__.py
│       └── recon.py
└── utils/
    ├── __init__.py
    ├── cfgnode.py
    ├── mesh_util.py
    ├── plots.py
    └── rend_util.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
data/synthetic
exps
__pycache__
archive
tmp*

================================================
FILE: DATA_CONVENTION.md
================================================
# Data Convention

The format of our multi-view dataset is derived from [VolSDF](https://github.com/lioryariv/volsdf/blob/main/DATA_CONVENTION.md).

### Directory Structure

```python
scan/
    cameras.npz
    image/      -> {:04d}.png   # tone-mapped LDR images
    depth/      -> {:04d}.exr
    normal/     -> {:04d}.exr
    mask/       -> {:04d}.png
    val/        -> {:04d}.png   # validation images (LDR)
    hdr/        -> {:04d}.exr   # raw HDR images
    # the following are optional
    light_mask/ -> {:04d}.png   # emitter mask images
    material/   -> {:04d}_kd.exr, {:04d}_ks.exr, {:04d}_rough.exr # diffuse albedo, specular albedo and roughness
```

Zero areas in the depth maps and normal maps indicate invalid regions such as windows. Note that not all areas inside the scenes use the GGX material model; we'll provide a mask of the invalid areas.

### Camera Information

`cameras.npz` contains each image's camera projection matrix `'world_mat_{i}'` and a normalization matrix `'scale_mat_{i}'`, following the VolSDF convention. We also provide a validation set of images for novel view synthesis, whose projection matrices are stored as `'val_mat_{i}'`. The validation set and the training set share the same normalization matrix.

The normalization matrices may not be readily available in `cameras.npz`; in that case, run `data/normalize_cameras.py` manually to generate `cameras_normalize.npz`. Since our method requires the entire scene to lie within a radius-3 bounding sphere, we suggest normalizing the cameras to a radius of 2.0 or 2.5. An example of running `normalize_cameras.py` (`<SCAN_ID>` and `<DATASET_NAME>` select the scene at `<DATASET_NAME>/scan<SCAN_ID>/cameras.npz`):

```shell
python normalize_cameras.py --id <SCAN_ID> -n <DATASET_NAME> -r 2.0
```

Note that we follow the **OpenCV camera coordinate system** (X right, Y downwards, Z into the image plane).
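For reference, here is a minimal sketch of how a projection matrix from `cameras_normalize.npz` can be split into intrinsics and a camera-to-world pose, mirroring `load_K_Rt_from_P` in `data/npz_to_blender.py` and `utils/rend_util.py` (the file path and image index below are only examples):

```python
import cv2
import numpy as np

def decompose_projection(P):
    # Split a 3x4 projection matrix into intrinsics K and a camera-to-world pose
    # (same decomposition as load_K_Rt_from_P in data/npz_to_blender.py).
    K, R, t = cv2.decomposeProjectionMatrix(P)[:3]
    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K
    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()        # world-to-camera rotation -> camera-to-world
    pose[:3, 3] = (t[:3] / t[3])[:, 0]  # camera center in normalized world coordinates
    return intrinsics, pose

cameras = np.load('cameras_normalize.npz')  # example path
i = 0                                       # example image index
P = (cameras['world_mat_%d' % i] @ cameras['scale_mat_%d' % i])[:3, :4]
intrinsics, pose = decompose_projection(P)
```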
### Dataset Format Conversion

If you want to convert the dataset to the NeRF blender format, run `npz_to_blender.py`:

```sh
python npz_to_blender.py --root /path/to/dataset
```

With the `--scale` flag, the script scales all pose matrices to fit within a `[-1, 1]` bounding box.

### About Real Dataset

Our real dataset comes from [Inria](https://repo-sam.inria.fr/fungraph/deep-indoor-relight/) and [Scalable-NISR](https://xchaowu.github.io/papers/scalable-nisr/), with depth estimated by MVS tools and manually labeled light masks (2 living room scenes). All depths have an absolute scale, so no shift alignment (as in MonoSDF) is needed. All camera calibrations and depths are provided by the authors of [Scalable-NISR](https://xchaowu.github.io/papers/scalable-nisr/); we thank them for providing the datasets. Normal maps are not provided in the real dataset, and we find these scenes can be plausibly reconstructed without normal supervision. Of course, you can estimate normals with any method if you want to enable normal supervision.

### About EXR format

We suggest using OpenCV to load an `.exr` format `float32` image:

```python
import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' # Enable OpenCV support for EXR
import cv2
...
im = cv2.imread(im_path, -1) # im will be a numpy.float32 array of shape (H, W, C)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) # cv2 reads images in BGR order, convert to RGB
```

We suggest using [tev](https://github.com/Tom94/tev) to conveniently preview HDR `.exr` images.

### About Mesh

Due to copyright issues, we cannot release the original 3D meshes of our synthetic scenes. Instead, we'll provide a point cloud sampled from the GT mesh to enable 3D reconstruction evaluation.

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 Jingsen Zhu

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================
**News**
- `04/04/2023` dataset preview release: 2 synthetic scenes available
- `15/04/2023` code release: 3D reconstruction and novel view synthesis part
- `21/04/2023` dataset release: real data

**TODO**
- [ ] Full dataset release
- [x] Code release for 3D reconstruction and novel view synthesis
- [ ] Code release for intrinsic decomposition and scene editing

**Dataset released**
- Synthetic: `kitchen_0`, `bedroom_relight_0`, `bedroom_0`, `bedroom_1`, `bedroom_relight_1`, `diningroom_0`, `livingroom_0`, `livingroom_1`, more scenes to be released
- Real: `inria_livingroom`, `nisr_livingroom`, `nisr_coffee_shop_0`, `nisr_coffee_shop_1`, release complete

# I2-SDF: Intrinsic Indoor Scene Reconstruction and Editing via Raytracing in Neural SDFs (CVPR 2023)

### [Project Page](https://jingsenzhu.github.io/i2-sdf/) | [Paper](https://arxiv.org/abs/2303.07634) | [Dataset](i2-sdf-dataset-links.csv)

## Setup

### Installation

```
conda env create -f environment.yml
conda activate i2sdf
```

### Data preparation

Download our synthetic dataset and extract it into `data/synthetic`, so that each scene resides at `data/synthetic/scan<SCAN_ID>/`.
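The data loaders (see `dataset/train_dataset.py`) resolve each scene as `data/<data_dir>/scan<scan_id>/` relative to the repository root. A minimal sanity check before training (the scan id below is only an example):

```python
import os

data_dir, scan_id = 'synthetic', 1  # example values; use your own scan id
instance_dir = os.path.join('data', data_dir, f'scan{scan_id}')
assert os.path.exists(instance_dir), f"Missing scene directory: {instance_dir}"
# cameras_normalize.npz is required by the loaders; if it is missing,
# generate it with data/normalize_cameras.py (see DATA_CONVENTION.md)
assert os.path.exists(os.path.join(instance_dir, 'cameras_normalize.npz')), \
    "cameras_normalize.npz not found"
```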
If you want to run on your own customized dataset, we provide a brief introduction to our data convention [here](DATA_CONVENTION.md).

## Dataset

We provide a high-quality synthetic indoor scene multi-view dataset with ground-truth camera poses and geometry annotations. See [HERE](DATA_CONVENTION.md) for the data conventions. Click [HERE](https://mega.nz/folder/jdhDnTqL#Ija678SU2Va_JJOiwqmdEg) to download.

## 3D Reconstruction and Novel View Synthesis

### Training

```
python main_recon.py --conf config/<CONFIG>.yml --scan_id <SCAN_ID> -d <GPU_ID> -v <VERSION>
```

Note: `config/synthetic.yml` does not include the light mask network, while `config/synthetic_light_mask.yml` does.

If you run out of GPU memory, try reducing `split_n_pixels` (i.e. the validation batch size) and `batch_size` in the config. The default parameters are evaluated on an RTX A6000 (48GB); for an RTX 3090 (24GB), try setting `split_n_pixels` to 5000.

### Evaluation

#### Novel view synthesis

```
python main_recon.py --conf config/<CONFIG>.yml --scan_id <SCAN_ID> -d <GPU_ID> -v <VERSION> --test [--is_val] [--full_res]
```

The optional flag `--is_val` evaluates on the validation set instead of the training set; `--full_res` produces full-resolution rendered images without downsampling.

#### View Interpolation

```
python main_recon.py --conf config/<CONFIG>.yml --scan_id <SCAN_ID> -d <GPU_ID> -v <VERSION> --test --test_mode interpolate --inter_id <ID0> <ID1> [--full_res]
```

Generates a view interpolation video between 2 views. Requires `ffmpeg` to be installed. The number of frames and the frame rate of the video can be specified with the `--n_frames` and `--frame_rate` options.

#### Mesh Extraction

```
python main_recon.py --conf config/<CONFIG>.yml --scan_id <SCAN_ID> -d <GPU_ID> -v <VERSION> --test --test_mode mesh
```

## Intrinsic Decomposition and Scene Editing

**Brewing🍺, code coming soon.**

## Citation

If you find our work useful, please consider citing:

```
@inproceedings{zhu2023i2sdf,
    title = {I$^2$-SDF: Intrinsic Indoor Scene Reconstruction and Editing via Raytracing in Neural SDFs},
    author = {Jingsen Zhu and Yuchi Huo and Qi Ye and Fujun Luan and Jifan Li and Dianbing Xi and Lisha Wang and Rui Tang and Wei Hua and Hujun Bao and Rui Wang},
    booktitle = {CVPR},
    year = {2023}
}
```

## Acknowledgement

- This repository is built upon [PyTorch Lightning](https://lightning.ai/).
- Thanks to Lior Yariv for her excellent work [VolSDF](https://lioryariv.github.io/volsdf/).
- Thanks to the [Scalable-NISR](https://xchaowu.github.io/papers/scalable-nisr/) team for providing their real-world dataset.
================================================ FILE: config/synthetic.yml ================================================ train: expname: synthetic learning_rate: 5.0e-4 steps: 200000 checkpoint_freq: 10000 plot_freq: 500 split_n_pixels: 12000 batch_size: 1600 pdf_criterion: DEPTH plot: plot_nimgs: 1 grid_boundary: [-1.5, 1.5] loss: eikonal_weight: 0.1 smooth_weight: 0.01 smooth_iter: 150000 depth_weight: 0.1 normal_weight: 0.05 bubble_weight: 0.5 min_bubble_iter: 50000 max_bubble_iter: 150000 dataset: data_dir: synthetic img_res: [480, 640] downsample: 2 pdf_prune: 0.05 pdf_max: 0.2 model: feature_vector_size: 256 scene_bounding_sphere: 3.0 implicit_network: d_in: 3 d_out: 1 dims: [ 256, 256, 256, 256, 256, 256, 256, 256 ] geometric_init: True bias: 0.6 skip_in: [4] weight_norm: True embed_type: 'positional' multires: 6 rendering_network: # mode: idr # d_in: 9 # Don't find actual differences between 'nerf' and 'idr' mode # Choose 'nerf' mode for a slight faster performance mode: nerf d_in: 3 d_out: 3 dims: [ 256, 256, 256, 256 ] weight_norm: True embed_type: 'positional' multires: 4 density: params_init: beta: 0.1 beta_min: 0.0001 ray_sampler: near: 0.0 N_samples: 64 N_samples_eval: 128 N_samples_extra: 32 eps: 0.1 beta_iters: 10 max_total_iters: 5 N_samples_inverse_sphere: 32 add_tiny: 1.0e-6 ================================================ FILE: config/synthetic_light_mask.yml ================================================ train: expname: synthetic_light learning_rate: 5.0e-4 steps: 200000 checkpoint_freq: 10000 plot_freq: 500 split_n_pixels: 12000 batch_size: 1600 pdf_criterion: DEPTH plot: plot_nimgs: 1 grid_boundary: [-1.5, 1.5] loss: eikonal_weight: 0.1 smooth_weight: 0.01 smooth_iter: 150000 depth_weight: 0.1 normal_weight: 0.05 bubble_weight: 0.5 light_mask_weight: 0.5 min_bubble_iter: 50000 max_bubble_iter: 150000 dataset: data_dir: synthetic img_res: [480, 640] downsample: 2 pdf_prune: 0.05 pdf_max: 0.2 model: feature_vector_size: 256 scene_bounding_sphere: 3.0 implicit_network: d_in: 3 d_out: 1 dims: [ 256, 256, 256, 256, 256, 256 ] geometric_init: True bias: 0.6 skip_in: [3] weight_norm: True embed_type: 'positional' multires: 6 rendering_network: mode: nerf d_in: 3 d_out: 3 dims: [ 256, 256, 256 ] weight_norm: True embed_type: 'positional' multires: 4 light_network: dims: [ 128 ] weight_norm: True density: params_init: beta: 0.1 beta_min: 0.0001 ray_sampler: near: 0.0 N_samples: 64 N_samples_eval: 128 N_samples_extra: 32 eps: 0.1 beta_iters: 10 max_total_iters: 5 N_samples_inverse_sphere: 32 add_tiny: 1.0e-6 ================================================ FILE: data/normalize_cameras.py ================================================ import cv2 import numpy as np import argparse from copy import deepcopy def get_center_point(num_cams,cameras): A = np.zeros((3 * num_cams, 3 + num_cams)) b = np.zeros((3 * num_cams, 1)) camera_centers=np.zeros((3,num_cams)) for i in range(num_cams): P0 = cameras['world_mat_%d' % i][:3, :] K = cv2.decomposeProjectionMatrix(P0)[0] R = cv2.decomposeProjectionMatrix(P0)[1] c = cv2.decomposeProjectionMatrix(P0)[2] c = c / c[3] camera_centers[:,i]=c[:3].flatten() # v = np.linalg.inv(K) @ np.array([800, 600, 1]) # v = v / np.linalg.norm(v) v=R[2,:] A[3 * i:(3 * i + 3), :3] = np.eye(3) A[3 * i:(3 * i + 3), 3 + i] = -v b[3 * i:(3 * i + 3)] = c[:3] soll= np.linalg.pinv(A) @ b return soll,camera_centers def normalize_cameras(original_cameras_filename,output_cameras_filename,num_of_cameras,radius,convert_coord): cameras = 
np.load(original_cameras_filename) if num_of_cameras==-1: all_files=cameras.files maximal_ind=0 for field in all_files: if 'val' not in field: maximal_ind=np.maximum(maximal_ind,int(field.split('_')[-1])) num_of_cameras=maximal_ind+1 soll, camera_centers = get_center_point(num_of_cameras, cameras) center = soll[:3].flatten() max_radius = np.linalg.norm((center[:, np.newaxis] - camera_centers), axis=0).max() * 1.1 normalization = np.eye(4).astype(np.float32) normalization[0, 3] = center[0] normalization[1, 3] = center[1] normalization[2, 3] = center[2] normalization[0, 0] = max_radius / radius normalization[1, 1] = max_radius / radius normalization[2, 2] = max_radius / radius cameras_new = {} cameras_new = deepcopy(dict(cameras)) for i in range(num_of_cameras): cameras_new['scale_mat_%d' % i] = normalization # cameras_new['world_mat_%d' % i] = cameras['world_mat_%d' % i].copy() # if ('val_mat_%d' % i) in cameras: # cameras_new['val_mat_%d' % i] = cameras['val_mat_%d' % i].copy() def opengl2opencv(P): out = cv2.decomposeProjectionMatrix(P[:3,:]) K, R, t = out[0:3] K = K/K[2,2] intrinsics = np.eye(4, dtype=np.float32) intrinsics[:3, :3] = K t = (t[:3] / t[3]).squeeze() w2c = np.eye(4, dtype=np.float32) w2c[:3,:3] = R w2c[:3,3] = -R @ t T = np.diag([1, -1, -1, 1]) w2c = T @ w2c return intrinsics @ w2c if convert_coord: cameras_new['world_mat_%d' % i] = opengl2opencv(cameras_new['world_mat_%d' % i]) if ('val_mat_%d' % i) in cameras_new: cameras_new['val_mat_%d' % i] = opengl2opencv(cameras_new['val_mat_%d' % i]) # cameras_new['world_mat_%d' % i] = T @ cameras_new['world_mat_%d' % i] np.savez(output_cameras_filename, **cameras_new) if __name__ == "__main__": parser = argparse.ArgumentParser(description='Normalizing cameras') parser.add_argument('-i', '--input_cameras_file', type=str, default="cameras.npz", help='the input cameras file') parser.add_argument('-o', '--output_cameras_file', type=str, default="cameras_normalize.npz", help='the output cameras file') parser.add_argument('--id', type=int, nargs='?') parser.add_argument('-n', '--name', type=str, default='synthetic') parser.add_argument('--number_of_cams',type=int, default=-1, help='Number of cameras, if -1 use all') parser.add_argument('-r', '--radius', type=float, default=2.0) parser.add_argument('-c', '--convert_coord', action='store_true') args = parser.parse_args() if args.id: args.input_cameras_file = f'{args.name}/scan{args.id}/cameras.npz' args.output_cameras_file = f'{args.name}/scan{args.id}/cameras_normalize.npz' normalize_cameras(args.input_cameras_file, args.output_cameras_file, args.number_of_cams, args.radius, args.convert_coord) ================================================ FILE: data/npz_to_blender.py ================================================ """ Transform npz-formatted scenes to json-formatted scene (NeRF blender format) Scale all poses to fit in a [-1, 1] box """ import copy import json import os import cv2 import numpy as np from tqdm import tqdm import argparse def to16b(img): img = img.clip(0, 1) * 65535 return img.astype(np.uint16) def opencv_to_gl(pose): mat = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) pose[:3, :3] = pose[:3, :3] @ mat return pose def get_offset(poses): eyes = np.stack([pose[:3, 3] for pose in poses]) scale = eyes.max(axis=0) - eyes.min(axis=0) print(f'scale : {scale}') offset = -(eyes.max(axis=0) + eyes.min(axis=0)) / 2 print(f'offset : {offset}') return scale / 2, offset def scale_pose(pose, scale, offset): pose[:3, 3] = (pose[:3, 3] + offset) / scale # print(pose[:3, 3]) return 
pose.tolist() def load_K_Rt_from_P(filename, P=None): if P is None: lines = open(filename).read().splitlines() if len(lines) == 4: lines = lines[1:] lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] P = np.asarray(lines).astype(np.float32).squeeze() out = cv2.decomposeProjectionMatrix(P) K = out[0] R = out[1] t = out[2] K = K / K[2, 2] intrinsics = np.eye(4) intrinsics[:3, :3] = K pose = np.eye(4, dtype=np.float32) pose[:3, :3] = R.transpose() pose[:3, 3] = (t[:3] / t[3])[:, 0] return intrinsics, pose def main(): parser = argparse.ArgumentParser() parser.add_argument('--root', required=True) parser.add_argument('--scale', action='store_true') args = parser.parse_args() os.chdir(os.path.join(args.root)) image_dir = 'image' n_images = len(os.listdir(image_dir)) val_dir = 'val' n_val = len(os.listdir(val_dir)) os.makedirs('depths', exist_ok=True) cam_file = 'cameras.npz' camera_dict = np.load(cam_file) world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)] val_mats = [camera_dict['val_mat_%d' % idx].astype(np.float32) for idx in range(n_val)] intrinsics_all = [] pose_all = [] for mat in world_mats + val_mats: P = mat P = P[:3, :4] intrinsics, pose = load_K_Rt_from_P(None, P) intrinsics_all.append(intrinsics) pose_all.append(opencv_to_gl(pose)) train_json = dict() train_json['fl_y'] = intrinsics[1][1] train_json['h'] = int(intrinsics[1, 2] * 2) train_json['fl_x'] = intrinsics[0][0] train_json['w'] = int(intrinsics[0, 2] * 2) if args.scale: scale, offset = get_offset(pose_all) # train_json['enable_depth_loading'] = True # train_json['integer_depth_scale'] = 1 / 65535 train_json['frames'] = [] test_json = copy.deepcopy(train_json) test_json['enable_depth_loading'] = False for i in tqdm(range(n_images)): frames = train_json['frames'] if args.scale: depth = cv2.imread(os.path.join('depth', '{:04d}.exr'.format(i)), -1) cv2.imwrite(os.path.join('depths', '{:04d}.exr'.format(i)), depth / scale.max()) pose = pose_all[i].tolist() if not args.scale else scale_pose(pose_all[i], scale.max(), offset) frame = { 'file_path': f'./image/{i:04d}', 'depth_path': f'./depths/{i:04d}.exr' if args.scale else f'./depth/{i:04d}.exr', 'transform_matrix': pose } frames.append(frame) for i in tqdm(range(n_val)): frames = test_json['frames'] pose = pose_all[i + n_images].tolist() if not args.scale else scale_pose(pose_all[i + n_images], scale.max(), offset) frame = { 'file_path': f'./val/{i:04d}', 'transform_matrix': pose } frames.append(frame) with open('transforms_train.json', 'w') as f: json.dump(train_json, f, indent=4) with open('transforms_test.json', 'w') as f: json.dump(test_json, f, indent=4) with open('transforms_val.json', 'w') as f: json.dump(test_json, f, indent=4) if __name__ == '__main__': main() ================================================ FILE: dataset/__init__.py ================================================ from .train_dataset import * from .eval_dataset import * ================================================ FILE: dataset/eval_dataset.py ================================================ from copy import deepcopy import os import torch import numpy as np from torch.utils.data import Dataset import utils.plots as plt import torch.nn.functional as F import utils from utils import rend_util from tqdm.contrib import tzip, tenumerate from scipy.spatial.transform import Rotation as Rot from scipy.spatial.transform import Slerp import cv2 class GridDataset(Dataset): """ Used for mesh extraction """ def __init__(self, points, xyz) -> None: 
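        """Hold the precomputed grid query points (and their per-axis xyz coordinates) that SDFMeshSystem evaluates the SDF on for marching cubes."""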
super().__init__() self.grid_points = points self.xyz = xyz def __len__(self): return self.grid_points.size(0) def __getitem__(self, index): return self.grid_points[index] class PlotDataset(torch.utils.data.Dataset): def __init__(self, data_dir, plot_nimgs, scan_id=0, is_val=False, data=None, is_hdr=False, indices=None, use_lmask=False, **kwargs ): self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id)) val_dir = '{0}/val'.format(self.instance_dir) is_val = is_val and os.path.exists(val_dir) lmask_dir = '{0}/light_mask'.format(self.instance_dir) self.use_lmask = use_lmask and os.path.exists(lmask_dir) if is_val: print("[INFO] Validation set detected") if is_val or data is None: assert os.path.exists(self.instance_dir), "Data directory is empty" if is_val: image_dir = val_dir elif is_hdr: image_dir = '{0}/hdr'.format(self.instance_dir) else: image_dir = '{0}/image'.format(self.instance_dir) image_paths = sorted(utils.glob_imgs(image_dir)) if indices is not None: print(f"[INFO] Selecting indices: {indices}") image_paths = [image_paths[i] for i in indices] self.n_images = len(image_paths) self.indices = indices if indices is not None else list(range(self.n_images)) self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir) camera_dict = np.load(self.cam_file) scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in self.indices] if not is_val else [camera_dict['scale_mat_0'].astype(np.float32)] * len(self.indices) world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in self.indices] if not is_val else [camera_dict['val_mat_%d' % idx].astype(np.float32) for idx in self.indices] self.intrinsics_all = [] self.pose_all = [] for scale_mat, world_mat in zip(scale_mats, world_mats): P = world_mat @ scale_mat P = P[:3, :4] intrinsics, pose = rend_util.load_K_Rt_from_P(None, P) self.intrinsics_all.append(torch.from_numpy(intrinsics).float()) self.pose_all.append(torch.from_numpy(pose).float()) self.intrinsics_all = torch.stack(self.intrinsics_all, 0) self.pose_all = torch.stack(self.pose_all, 0) self.rgb_images = [] for path in image_paths: rgb = rend_util.load_rgb(path, is_hdr=is_hdr) self.img_res = [rgb.shape[1], rgb.shape[2]] rgb = rgb.reshape(3, -1).transpose(1, 0) self.rgb_images.append(torch.from_numpy(rgb).float()) self.rgb_images = torch.stack(self.rgb_images, 0) if self.use_lmask: self.lightmask_images = [] lmask_paths = sorted(utils.glob_imgs(lmask_dir)) for path in lmask_paths: lmask = rend_util.load_mask(path) lmask = lmask.reshape(-1, 1) self.lightmask_images.append(torch.from_numpy(lmask).float()) self.lightmask_images = torch.stack(self.lightmask_images, 0) self.total_pixels = self.rgb_images.size(1) else: self.intrinsics_all = data['intrinsics'] self.pose_all = data['pose'] self.rgb_images = data['rgb'] self.n_images = len(self.rgb_images) self.img_res = [data['img_res'][0], data['img_res'][1]] self.total_pixels = self.img_res[0] * self.img_res[1] if 'light_mask' in data: self.lightmask_images = data['light_mask'] self.use_lmask = True if (scale := kwargs.get('downsample', 1)) > 1: old_img_res = deepcopy(self.img_res) self.img_res[0] //= scale self.img_res[1] //= scale self.total_pixels = self.img_res[0] * self.img_res[1] self.rgb_images = self.rgb_images.transpose(1, 2).reshape(-1, 3, old_img_res[0], old_img_res[1]) self.rgb_images = F.interpolate(self.rgb_images, self.img_res, mode='area') self.rgb_images = self.rgb_images.reshape(-1, 3, self.total_pixels).transpose(1, 2) if self.use_lmask: 
self.lightmask_images = self.lightmask_images.transpose(1, 2).reshape(-1, 1, old_img_res[0], old_img_res[1]) self.lightmask_images = F.interpolate(self.lightmask_images, self.img_res, mode='area') self.lightmask_images = self.lightmask_images.reshape(-1, 1, self.total_pixels).transpose(1, 2) self.intrinsics_all = self.intrinsics_all.clone() self.intrinsics_all[:,0,0] /= scale self.intrinsics_all[:,1,1] /= scale self.intrinsics_all[:,0,2] /= scale self.intrinsics_all[:,1,2] /= scale print(f"[INFO] Plot image size: {self.img_res[1]}x{self.img_res[0]}, {self.total_pixels} pixels in total") if plot_nimgs == -1: self.plot_nimgs = self.n_images else: self.plot_nimgs = min(plot_nimgs, self.n_images) self.shuffle = kwargs.get('shuffle', True) if self.shuffle: self.shuffle_plot_index() def shuffle_plot_index(self): if self.shuffle: self.plot_index = torch.randperm(self.n_images)[:self.plot_nimgs] def __len__(self): return self.plot_nimgs def get_uv(self): uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32) uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float() uv = uv.reshape(2, -1).transpose(1, 0) return uv def __getitem__(self, idx): if self.shuffle: idx = self.plot_index[idx] uv = self.get_uv() sample = { "uv": uv, "intrinsics": self.intrinsics_all[idx], "pose": self.pose_all[idx] } ground_truth = { "rgb": self.rgb_images[idx] } if self.use_lmask: ground_truth['light_mask'] = self.lightmask_images[idx] return idx, sample, ground_truth def collate_fn(self, batch_list): # get list of dictionaries and returns input, ground_true as dictionary for all batch instances batch_list = zip(*batch_list) all_parsed = [] for entry in batch_list: if type(entry[0]) is dict: # make them all into a new dict ret = {} for k in entry[0].keys(): ret[k] = torch.stack([obj[k] for obj in entry]) all_parsed.append(ret) else: all_parsed.append(torch.LongTensor(entry)) return tuple(all_parsed) class InterpolateDataset(torch.utils.data.Dataset): """ View interpolation: specify 2 view ids from training set and generate a video moving between them """ def __init__(self, data_dir, # img_res, id0, id1, num_frames=60, scan_id=0, **kwargs ): self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id)) assert os.path.exists(self.instance_dir), "Data directory is empty" image_dir = '{0}/image'.format(self.instance_dir) im = cv2.imread(f"{image_dir}/{id0:04d}.png") h, w, _ = im.shape self.img_res = [h, w] self.total_pixels = h * w self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir) camera_dict = np.load(self.cam_file) P0 = camera_dict['world_mat_%d' % id0].astype(np.float32) @ camera_dict['scale_mat_%d' % id0].astype(np.float32) P1 = camera_dict['world_mat_%d' % id1].astype(np.float32) @ camera_dict['scale_mat_%d' % id1].astype(np.float32) P0 = P0[:3,:] P1 = P1[:3,:] K, pose0 = rend_util.load_K_Rt_from_P(None, P0) _, pose1 = rend_util.load_K_Rt_from_P(None, P1) rots = Rot.from_matrix(np.stack([pose0[:3,:3].T, pose1[:3,:3].T])) slerp = Slerp([0, 1], rots) if (scale := kwargs.get('downsample', 1)) > 1: self.img_res[0] = self.img_res[0] // scale self.img_res[1] = self.img_res[1] // scale self.total_pixels = self.img_res[0] * self.img_res[1] K[0,0] /= scale K[1,1] /= scale K[0,2] /= scale K[1,2] /= scale self.intrinsics = torch.from_numpy(K).float() self.pose_all = [] for i in range(num_frames): ratio = np.sin(((i / num_frames) - 0.5) * np.pi) * 0.5 + 0.5 t = (1 - ratio) * pose0[:3,3] + ratio * pose1[:3,3] R = slerp(ratio).as_matrix() pose = np.eye(4, dtype=np.float32) pose[:3,3] = 
t pose[:3,:3] = R.T self.pose_all.append(torch.from_numpy(pose).float()) self.pose_all = torch.stack(self.pose_all) self.n_frames = num_frames def __len__(self): return self.n_frames def __getitem__(self, idx): uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32) uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float() uv = uv.reshape(2, -1).transpose(1, 0) sample = { "uv": uv, "intrinsics": self.intrinsics, "pose": self.pose_all[idx] } return idx, sample def collate_fn(self, batch_list): # get list of dictionaries and returns input, ground_true as dictionary for all batch instances batch_list = zip(*batch_list) all_parsed = [] for entry in batch_list: if type(entry[0]) is dict: # make them all into a new dict ret = {} for k in entry[0].keys(): ret[k] = torch.stack([obj[k] for obj in entry]) all_parsed.append(ret) else: all_parsed.append(torch.LongTensor(entry)) return tuple(all_parsed) class RelightDataset(PlotDataset): def __init__(self, data_dir, edit_cfg, scan_id=0, is_val=False, **kwargs): super().__init__(data_dir, 1, scan_id, is_val, None, False, [edit_cfg['index']], True, **kwargs) self.edit_mask = 'mask' in edit_cfg if self.edit_mask: self.mask = rend_util.load_mask(edit_cfg['mask']).astype(np.float32) mh, mw = self.mask.shape if mh != self.img_res[0] or mw != self.img_res[1]: self.mask = cv2.resize(self.mask, (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_AREA) self.mask = (self.mask > 0.5) self.mask = torch.from_numpy(self.mask).float().flatten() if 'normal' in edit_cfg: self.loadattr(edit_cfg, 'normal', 0) self.normal = self.normal.reshape(-1, 3) self.normal = F.normalize(self.normal, dim=-1, eps=1e-6) if 'rough' in edit_cfg: self.loadattr(edit_cfg, 'rough', 1) self.rough = self.rough.reshape(-1, 1) if 'kd' in edit_cfg: self.loadattr(edit_cfg, 'kd', 2) self.kd = self.kd.reshape(-1, 3) if 'ks' in edit_cfg: self.loadattr(edit_cfg, 'ks', 2) self.ks = self.ks.reshape(-1, 3) self.uv = self.get_uv() def loadattr(self, edit_cfg, attr, mode=0): if mode == 0: im = rend_util.load_normal(edit_cfg[attr]) elif mode == 1: im = cv2.imread(edit_cfg[attr], -1) if len(im.shape) == 3: im = im[:,:,-1] else: im = rend_util.load_rgb(edit_cfg[attr]).transpose(1, 2, 0) h, w = im.shape[:2] if h != self.img_res[0] or w != self.img_res[1]: im = cv2.resize(im, (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_AREA) setattr(self, attr, torch.from_numpy(im).float()) def __len__(self): return self.total_pixels def __getitem__(self, idx): sample = { "uv": self.uv[idx].unsqueeze(0), "intrinsics": self.intrinsics_all[0], "pose": self.pose_all[0] } ground_truth = { "rgb": self.rgb_images[0][idx], 'light_mask': self.lightmask_images[0][idx] # 'edit_mask': self.edit_mask[idx] } if self.edit_mask: ground_truth['mask'] = self.mask[idx] if hasattr(self, 'normal'): ground_truth['normal'] = self.normal[idx] if hasattr(self, 'rough'): ground_truth['rough'] = self.rough[idx] if hasattr(self, 'kd'): ground_truth['kd'] = self.kd[idx] if hasattr(self, 'ks'): ground_truth['ks'] = self.ks[idx] return idx, sample, ground_truth class RelightVideoDataset(PlotDataset): def __init__(self, data_dir, edit_cfg, scan_id=0, is_val=False, **kwargs): self.n_frames = edit_cfg['n_frames'] self.img_idx = edit_cfg['index'] super().__init__(data_dir, 1, scan_id, is_val, None, False, [edit_cfg['index']] * self.n_frames, True, **kwargs) self.edit_mask = 'mask' in edit_cfg if self.edit_mask: self.mask = rend_util.load_mask(edit_cfg['mask']).astype(np.float32) mh, mw = self.mask.shape if mh != 
self.img_res[0] or mw != self.img_res[1]: self.mask = cv2.resize(self.mask, (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_AREA) self.mask = (self.mask > 0.5) self.mask = torch.from_numpy(self.mask).float().flatten() self.uv = self.get_uv() def __len__(self): return self.n_frames def __getitem__(self, idx): sample = { "uv": self.uv, "intrinsics": self.intrinsics_all[idx], "pose": self.pose_all[idx] } ground_truth = { "rgb": self.rgb_images[idx], 'light_mask': self.lightmask_images[idx] # 'edit_mask': self.edit_mask[idx] } if self.edit_mask: ground_truth['mask'] = self.mask return idx, sample, ground_truth ================================================ FILE: dataset/train_dataset.py ================================================ import json import os import cv2 import torch import numpy as np from torch.utils.data import Dataset import torch.nn.functional as F import utils from utils import rend_util from tqdm.contrib import tzip, tenumerate from scipy.spatial.transform import Rotation as Rot from scipy.spatial.transform import Slerp class ReconDataset(torch.utils.data.Dataset): def __init__(self, data_dir, scan_id=0, use_mask=False, use_depth=False, use_normal=False, use_bubble=False, use_lightmask=False, is_hdr=False, **kwargs ): self.sampling_idx = slice(None) self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id)) assert os.path.exists(self.instance_dir), "Data directory is empty" print(f"[INFO] Loading data from {self.instance_dir}") image_dir = '{0}/image'.format(self.instance_dir) if not is_hdr else '{0}/hdr'.format(self.instance_dir) self.is_hdr = is_hdr if is_hdr: print("[INFO] Using HDR image") image_paths = sorted(utils.glob_imgs(image_dir)) self.n_images = len(image_paths) self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir) camera_dict = np.load(self.cam_file) scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] self.intrinsics_all = [] self.pose_all = [] for scale_mat, world_mat in zip(scale_mats, world_mats): P = world_mat @ scale_mat P = P[:3, :4] intrinsics, pose = rend_util.load_K_Rt_from_P(None, P) self.intrinsics_all.append(torch.from_numpy(intrinsics).float()) self.pose_all.append(torch.from_numpy(pose).float()) self.intrinsics_all = torch.stack(self.intrinsics_all, 0) self.pose_all = torch.stack(self.pose_all, 0) self.rgb_images = [] for path in image_paths: rgb = rend_util.load_rgb(path, is_hdr=is_hdr) self.img_res = [rgb.shape[1], rgb.shape[2]] rgb = rgb.reshape(3, -1).transpose(1, 0) self.rgb_images.append(torch.from_numpy(rgb).float()) self.rgb_images = torch.stack(self.rgb_images, 0) self.total_pixels = self.rgb_images.size(1) print(f"[INFO] image size: {self.img_res[1]}x{self.img_res[0]}, {self.total_pixels} pixels in total") uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32) uv = np.flip(uv, axis=0).copy() self.uv = torch.from_numpy(uv).float() self.uv = self.uv.reshape(2, -1).transpose(1, 0) # (h*w, 2) mask_dir = '{0}/mask'.format(self.instance_dir) self.use_mask = use_mask if self.use_mask: if os.path.exists(mask_dir): mask_paths = sorted(utils.glob_imgs(mask_dir)) # assert len(mask_paths) == self.n_images self.mask_images = [] for path in mask_paths: mask = rend_util.load_mask(path) mask = mask.reshape(-1, 1) self.mask_images.append(torch.from_numpy(mask).float()) self.mask_images = torch.stack(self.mask_images, 0) else: print("[INFO] No 
existing mask image, use one mask as default") self.mask_images = torch.ones(self.n_images, self.total_pixels, 1, dtype=torch.float) lmask_dir = '{0}/light_mask'.format(self.instance_dir) self.use_lightmask = use_lightmask and os.path.exists(lmask_dir) if self.use_lightmask: lmask_paths = sorted(utils.glob_imgs(lmask_dir)) self.lightmask_images = [] for path in lmask_paths: lmask = rend_util.load_mask(path) lmask = lmask.reshape(-1, 1) self.lightmask_images.append(torch.from_numpy(lmask).float()) self.lightmask_images = torch.stack(self.lightmask_images, 0) depth_dir = '{0}/depth'.format(self.instance_dir) self.use_depth = use_depth and os.path.exists(depth_dir) self.use_bubble = use_bubble and os.path.exists(depth_dir) if self.use_depth or self.use_bubble: self.depth_images = [] self.depth_masks = [] self.pointcloud = [] # pointcloud for bubble loss, unprojected from depth and poses self.pointlinks = [] # link from pixel index to pointcloud index, value -1 when the pixel is invalid at pointcloud self.pixlinks = [] # link from pointcloud index to pixel index depth_paths = sorted(utils.glob_depths(depth_dir)) n_points = 0 if kwargs.get('noise_scale', 0.0) > 0: print(f"[INFO] Ablation study: using noise scale {kwargs.get('noise_scale')}") for scale_mat, depth_path, intrinsics, pose, i in tzip(scale_mats, depth_paths, self.intrinsics_all, self.pose_all, range(len(self.pose_all))): depth = rend_util.load_depth(depth_path) depth = torch.from_numpy(depth.reshape(-1)).float() depth = depth / scale_mat[2,2] valid_indices = torch.where((depth > 1e-3) & (depth < 6))[0] if i == 0 and scale_mat[2,2] != 1: print(f"[INFO] Depth scaled by {scale_mat[2,2]:.2f}") depth_mask = torch.zeros([self.total_pixels], dtype=torch.bool) depth_mask[valid_indices] = True # if self.use_depth: if (noise_scale := kwargs.get('noise_scale', 0.0)) > 0: depth = rend_util.add_depth_noise(depth, depth_mask.float(), noise_scale) self.depth_images.append(depth) self.depth_masks.append(depth_mask) if self.use_bubble: pointlink = -torch.ones([self.total_pixels], dtype=torch.long) pointlink[depth_mask] = torch.arange(0, len(valid_indices), dtype=torch.long) + n_points pixlink = torch.arange(i * self.total_pixels, (i + 1) * self.total_pixels, dtype=torch.long)[depth_mask] n_points += len(valid_indices) self.pointlinks.append(pointlink) self.pixlinks.append(pixlink) self.pointcloud.append(rend_util.depth_to_world(self.uv, intrinsics, pose, depth, depth_mask)) self.depth_images = torch.stack(self.depth_images, 0) self.depth_masks = torch.stack(self.depth_masks, 0) if self.use_bubble: self.pointcloud = torch.cat(self.pointcloud, 0) self.pointlinks = torch.cat(self.pointlinks, 0) self.pixlinks = torch.cat(self.pixlinks, 0) self.pointcloud = self.pointcloud[:,:3] / self.pointcloud[:,3:] self.pdf_prune = kwargs.get('pdf_prune', 0) self.pdf_max = kwargs.get('pdf_max', None) print(f"[INFO] PDF clamped to {self.pdf_prune}") normal_dir = '{0}/normal'.format(self.instance_dir) self.use_normal = use_normal and os.path.exists(normal_dir) if self.use_normal: self.normal_images = [] self.normal_masks = [] normal_paths = sorted(utils.glob_normal(normal_dir)) for pose, normal_path in tzip(self.pose_all, normal_paths): normal = rend_util.load_normal(normal_path) normal = torch.from_numpy(normal.reshape(-1, 3)).float() valid_indices = torch.where(torch.linalg.vector_norm(normal, dim=1) > 1e-3)[0] R = pose[:3,:3] normal = (R @ normal.T).T # convert normal from view space to world space normal = F.normalize(normal, dim=1, eps=1e-6) 
self.normal_images.append(normal) normal_mask = torch.zeros([self.total_pixels], dtype=torch.bool) normal_mask[valid_indices] = True self.normal_masks.append(normal_mask) self.normal_images = torch.stack(self.normal_images, 0) self.normal_masks = torch.stack(self.normal_masks, 0) def __len__(self): return self.n_images * self.total_pixels def __getitem__(self, idx): pidx = idx % self.total_pixels tidx = idx idx = idx // self.total_pixels sample = { "uv": self.uv[pidx].unsqueeze(0), "intrinsics": self.intrinsics_all[idx], "pose": self.pose_all[idx] } ground_truth = { "rgb": self.rgb_images[idx][pidx] } if self.use_mask: ground_truth['mask'] = self.mask_images[idx][pidx] if self.use_lightmask: ground_truth['light_mask'] = self.lightmask_images[idx][pidx] if self.use_depth or self.use_bubble: ground_truth['depth'] = self.depth_images[idx][pidx] ground_truth['depth_mask'] = self.depth_masks[idx][pidx] if self.use_normal: ground_truth['normal'] = self.normal_images[idx][pidx] ground_truth['normal_mask'] = self.normal_masks[idx][pidx] return tidx, idx, sample, ground_truth def collate_fn(self, batch_list): # get list of dictionaries and returns input, ground_true as dictionary for all batch instances batch_list = zip(*batch_list) all_parsed = [] for entry in batch_list: if type(entry[0]) is dict: # make them all into a new dict ret = {} for k in entry[0].keys(): ret[k] = torch.stack([obj[k] for obj in entry]) all_parsed.append(ret) else: all_parsed.append(torch.LongTensor(entry)) return tuple(all_parsed) class MaterialDataset(torch.utils.data.Dataset): def __init__(self, data_dir, scan_id=0, downsample_train=1, is_hdr=False, **kwargs ): self.sampling_idx = slice(None) self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id)) assert os.path.exists(self.instance_dir), "Data directory is empty" image_dir = '{0}/image'.format(self.instance_dir) if not is_hdr else '{0}/hdr'.format(self.instance_dir) self.is_hdr = is_hdr if is_hdr: print("[INFO] Using HDR image") image_paths = sorted(utils.glob_imgs(image_dir)) self.n_images = len(image_paths) self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir) camera_dict = np.load(self.cam_file) scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] self.intrinsics_all = [] self.pose_all = [] for scale_mat, world_mat in zip(scale_mats, world_mats): P = world_mat @ scale_mat P = P[:3, :4] intrinsics, pose = rend_util.load_K_Rt_from_P(None, P) self.intrinsics_all.append(torch.from_numpy(intrinsics).float()) self.pose_all.append(torch.from_numpy(pose).float()) self.intrinsics_all = torch.stack(self.intrinsics_all, 0) self.pose_all = torch.stack(self.pose_all, 0) self.rgb_images = [] for path in image_paths: rgb = rend_util.load_rgb(path, is_hdr=is_hdr) self.img_res = [rgb.shape[1], rgb.shape[2]] rgb = rgb.reshape(3, -1).transpose(1, 0) self.rgb_images.append(torch.from_numpy(rgb).float()) self.rgb_images = torch.stack(self.rgb_images, 0) self.total_pixels = self.rgb_images.size(1) mask_dir = '{0}/mask'.format(self.instance_dir) self.use_mask = os.path.exists(mask_dir) if self.use_mask: mask_paths = sorted(utils.glob_imgs(mask_dir)) # assert len(mask_paths) == self.n_images self.mask_images = [] for path in mask_paths: mask = rend_util.load_mask(path) mask = mask.reshape(-1, 1) self.mask_images.append(torch.from_numpy(mask).float()) self.mask_images = torch.stack(self.mask_images, 0) 
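        # Emitter (light source) masks are optional: they are loaded only when the scan provides a light_mask/ folder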
lmask_dir = '{0}/light_mask'.format(self.instance_dir) self.use_lightmask = os.path.exists(lmask_dir) if self.use_lightmask: print("[INFO] Light mask detected") lmask_paths = sorted(utils.glob_imgs(lmask_dir)) self.lightmask_images = [] for path in lmask_paths: lmask = rend_util.load_mask(path) lmask = lmask.reshape(-1, 1) self.lightmask_images.append(torch.from_numpy(lmask).float()) self.lightmask_images = torch.stack(self.lightmask_images, 0) if downsample_train > 1: old_res = (self.img_res[0], self.img_res[1]) self.rgb_images = self.rgb_images.transpose(1, 2).reshape(-1, 3, *old_res) self.img_res[0] //= downsample_train self.img_res[1] //= downsample_train self.total_pixels = self.img_res[0] * self.img_res[1] self.rgb_images = F.interpolate(self.rgb_images, self.img_res, mode='area') self.rgb_images = self.rgb_images.reshape(-1, 3, self.total_pixels).transpose(1, 2) self.intrinsics_all = self.intrinsics_all.clone() self.intrinsics_all[:,0,0] /= downsample_train self.intrinsics_all[:,1,1] /= downsample_train self.intrinsics_all[:,0,2] /= downsample_train self.intrinsics_all[:,1,2] /= downsample_train if self.use_mask: self.mask_images = self.mask_images.transpose(1, 2).reshape(-1, 1, *old_res) self.mask_images = F.interpolate(self.mask_images, self.img_res, mode='area') self.mask_images = self.mask_images.reshape(-1, 1, self.total_pixels).transpose(1, 2) if self.use_lightmask: self.lightmask_images = self.lightmask_images.transpose(1, 2).reshape(-1, 1, *old_res) self.lightmask_images = F.interpolate(self.lightmask_images, self.img_res, mode='area') self.lightmask_images[self.lightmask_images > 0] = 1 self.lightmask_images = self.lightmask_images.reshape(-1, 1, self.total_pixels).transpose(1, 2) print(f"[INFO] image size: {self.img_res[1]}x{self.img_res[0]}, {self.total_pixels} pixels in total") uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32) uv = np.flip(uv, axis=0).copy() self.uv = torch.from_numpy(uv).float() self.uv = self.uv.reshape(2, -1).transpose(1, 0) # (h*w, 2) def __len__(self): return self.n_images * self.total_pixels def __getitem__(self, idx): pidx = idx % self.total_pixels tidx = idx idx = idx // self.total_pixels sample = { "uv": self.uv[pidx].unsqueeze(0), "intrinsics": self.intrinsics_all[idx], "pose": self.pose_all[idx] } ground_truth = { "rgb": self.rgb_images[idx][pidx] } if self.use_mask: ground_truth['mask'] = self.mask_images[idx][pidx] if self.use_lightmask: ground_truth['light_mask'] = self.lightmask_images[idx][pidx] return tidx, idx, sample, ground_truth def collate_fn(self, batch_list): # get list of dictionaries and returns input, ground_true as dictionary for all batch instances batch_list = zip(*batch_list) all_parsed = [] for entry in batch_list: if type(entry[0]) is dict: # make them all into a new dict ret = {} for k in entry[0].keys(): ret[k] = torch.stack([obj[k] for obj in entry]) all_parsed.append(ret) else: all_parsed.append(torch.LongTensor(entry)) return tuple(all_parsed) ================================================ FILE: environment.yml ================================================ name: i2sdf channels: - pytorch - conda-forge - defaults dependencies: - cudatoolkit=11.3.1=h9edb442_10 - ffmpeg=4.3=hf484d3e_0 - numpy=1.23.5=py39h14f4228_0 - pip=23.0.1=py39h06a4308_0 - python=3.9.16=h7a1cb2a_2 - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0 - torchaudio=0.12.1=py39_cu113 - torchvision=0.13.1=py39_cu113 - pip: - fast-pytorch-kmeans==0.1.9 - ffmpeg-python==0.2.0 - gputil==1.4.0 - lpips==0.1.4 - open3d==0.17.0 - 
opencv-python==4.7.0.72 - pymcubes==0.1.4 - pytorch-lightning==1.9.0 - pyyaml==6.0 - rich==13.3.3 - scikit-image==0.20.0 - scikit-learn==1.2.2 - scipy==1.9.1 - tensorboard==2.12.0 - tensorboardx==2.6 - torchmetrics==0.11.4 - tqdm==4.65.0 - trimesh==3.21.4 ================================================ FILE: i2-sdf-dataset-links.csv ================================================ file,url interiorverse/i2-sdf/i2-sdf/bedroom_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_0.zip interiorverse/i2-sdf/i2-sdf/bedroom_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_0_preview.png interiorverse/i2-sdf/i2-sdf/bedroom_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_1.zip interiorverse/i2-sdf/i2-sdf/bedroom_1_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_1_preview.png interiorverse/i2-sdf/i2-sdf/bedroom_relight_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_0.zip interiorverse/i2-sdf/i2-sdf/bedroom_relight_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_0_preview.png interiorverse/i2-sdf/i2-sdf/bedroom_relight_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_1.zip interiorverse/i2-sdf/i2-sdf/bedroom_relight_1_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_1_preview.png interiorverse/i2-sdf/i2-sdf/diningroom_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/diningroom_0.zip interiorverse/i2-sdf/i2-sdf/diningroom_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/diningroom_0_preview.png interiorverse/i2-sdf/i2-sdf/kitchen_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/kitchen_0.zip interiorverse/i2-sdf/i2-sdf/kitchen_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/kitchen_0_preview.png interiorverse/i2-sdf/i2-sdf/livingroom_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_0.zip interiorverse/i2-sdf/i2-sdf/livingroom_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_0_preview.png interiorverse/i2-sdf/i2-sdf/livingroom_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_1.zip interiorverse/i2-sdf/i2-sdf/livingroom_1_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_1_preview.png interiorverse/i2-sdf/i2-sdf/real_data/inria_livingroom.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/inria_livingroom.zip interiorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_0.zip interiorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_1.zip interiorverse/i2-sdf/i2-sdf/real_data/nisr_livingroom.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/nisr_livingroom.zip ================================================ FILE: main_recon.py ================================================ import torch import yaml import pytorch_lightning as pl import argparse import os import utils import model from pytorch_lightning import loggers from pytorch_lightning.callbacks import ModelCheckpoint from rich.progress import TextColumn import GPUtil 
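# Entry point: parse CLI arguments, then either train the reconstruction system (default)
# or, with --test, run one of the evaluation modes selected by --test_mode:
#   render      -> novel view synthesis with PSNR / SSIM / LPIPS metrics
#   mesh        -> marching-cubes mesh extraction (optionally scored against a GT mesh via --score)
#   interpolate -> view interpolation video between the two views given by --inter_id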
if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--conf", type=str, required=True, help="Path to (.yml) config file.") parser.add_argument('-d', "--device_ids", type=int, nargs='+', default=None, help="GPU devices to use") parser.add_argument("--exps_folder", type=str, default="exps") parser.add_argument('--expname', type=str, default='') parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.') parser.add_argument('--test', action='store_true') parser.add_argument('--test_mode', choices=['render', 'mesh', 'interpolate'], default='render') parser.add_argument('-v', '--version', type=int, nargs='?') parser.add_argument('--inter_id', type=int, nargs=2, required=False, help='2 view ids for interpolation video.') parser.add_argument('-i', '--indices', nargs='*', type=int, help='If set, render only specified indices of the dataset instead of all images.') parser.add_argument('--n_frames', type=int, default=60, help='Number of frames in the interpolation video.') parser.add_argument('--frame_rate', type=int, default=24, help='Frame rate of the interpolation video.') parser.add_argument('-f', '--full_res', action='store_true', help='If set, dataset downscaling will be ignored.') parser.add_argument('--is_val', action='store_true', help='If set, render the validation set instead of training set.') parser.add_argument('--val_mesh', action='store_true', help='If set, extract and save mesh every validation epoch.') parser.add_argument('--score', action='store_true', help='If set, evaluate the meshing score (need to provide GT mesh).') parser.add_argument('--far_clip', type=float, default=5.0) parser.add_argument('--ckpt', type=str, default='last') parser.add_argument('--resolution', type=int, default=512, help='Resolution for marching cube algorithm') parser.add_argument('--spp', type=int, default=128) parser.add_argument('--seed', type=int, default=42) args = parser.parse_args() with open(args.conf) as f: cfg_dict = yaml.load(f, Loader=yaml.FullLoader) cfg = utils.CfgNode(cfg_dict) expname = args.expname if args.expname else cfg.train.expname scan_id = cfg.dataset.scan_id if args.scan_id == -1 else args.scan_id cfg.dataset.scan_id = scan_id expname = expname + '_' + str(scan_id) if args.version is None and (v := args.conf.find("version_")) != -1: args.version = int(args.conf[v + 8:args.conf.find("/config")]) print(f"[INFO] Loaded version {args.version} from config file") if args.version is not None: logger = loggers.TensorBoardLogger(save_dir=args.exps_folder, name=expname, version=args.version) else: logger = loggers.TensorBoardLogger(save_dir=args.exps_folder, name=expname) if args.device_ids is None: args.device_ids = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False, excludeID=[], excludeUUID=[]) print("Selected GPU {} automatically".format(args.device_ids[0])) torch.cuda.set_device(args.device_ids[0]) torch.set_float32_matmul_precision('medium') progbar_callback = utils.RichProgressBarWithScanId(scan_id, leave=False) pl.seed_everything(args.seed) if args.test: version = args.version if args.version is not None else logger.version - 1 exp_dir = os.path.join(logger.root_dir, f"version_{version}") del logger if args.test_mode == 'render': system = model.VolumeRenderSystem(cfg, exp_dir, indices=args.indices, is_val=args.is_val, full_res=args.full_res) if not args.ckpt.endswith('.ckpt'): args.ckpt += '.ckpt' ckpt = torch.load(os.path.join(exp_dir, 'checkpoints', args.ckpt), 
map_location='cuda') system.load_state_dict(ckpt['state_dict']) model.lpips.cuda() elif args.test_mode == 'mesh': system = model.SDFMeshSystem(cfg, exp_dir, args.resolution, args.score) if not args.ckpt.endswith('.ckpt'): args.ckpt += '.ckpt' ckpt = torch.load(os.path.join(exp_dir, 'checkpoints', args.ckpt), map_location='cuda') system.load_state_dict(ckpt['state_dict']) system.cuda() system.eval() system.initialize() # elif args.test_mode == 'interpolate': else: system = model.ViewInterpolateSystem(cfg, exp_dir, *args.inter_id, n_frames=args.n_frames, frame_rate=args.frame_rate) if not args.ckpt.endswith('.ckpt'): args.ckpt += '.ckpt' ckpt = torch.load(os.path.join(exp_dir, 'checkpoints', args.ckpt), map_location='cuda') system.load_state_dict(ckpt['state_dict']) trainer = pl.Trainer( logger=False, accelerator='gpu', devices=args.device_ids, callbacks=[progbar_callback] ) trainer.test(system) else: max_steps = cfg.train.get('steps', 200000) print(f"Training for {max_steps} steps") exp_dir = logger.log_dir checkpoint_callback = ModelCheckpoint(os.path.join(exp_dir, 'checkpoints'), save_last=True, every_n_train_steps=cfg.train.checkpoint_freq) if hasattr(cfg.train, 'plot_freq'): kwargs = {'val_check_interval': cfg.train.plot_freq} else: kwargs = {'check_val_every_n_epoch': cfg.train.plot_epochs} trainer = pl.Trainer( logger=logger, accelerator='gpu', devices=args.device_ids, strategy=None, callbacks=[checkpoint_callback, progbar_callback], max_steps=max_steps, **kwargs ) system = model.ReconstructionTrainer( cfg, progbar_callback, exp_dir=exp_dir, is_val=args.is_val, val_mesh=args.val_mesh ) trainer.fit(system) torch.cuda.empty_cache() ================================================ FILE: model/__init__.py ================================================ from .network import * from .trainer import * # from .material import * from .eval import * from .rendering import RenderingLayer ================================================ FILE: model/eval/__init__.py ================================================ from .recon import * ================================================ FILE: model/eval/recon.py ================================================ import torch import pytorch_lightning as pl import numpy as np import os from glob import glob from torch.utils.data import DataLoader import utils from utils import rend_util import utils.plots as plt import dataset import model from skimage import measure import cv2 import trimesh from rich.progress import track from torchmetrics.functional import structural_similarity_index_measure as ssim from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS lpips = LPIPS() class SDFMeshSystem(pl.LightningModule): def __init__(self, conf, exp_dir, resolution, score=False, far_clip=5.0) -> None: super().__init__() self.expdir = exp_dir conf_model = conf.model conf_model.use_normal = False self.model = model.I2SDFNetwork(conf_model) self.resolution = resolution self.grid_boundary = conf.plot.grid_boundary self.initialized = False self.instance_dir = os.path.join('data', conf.dataset.data_dir, 'scan{0}'.format(conf.dataset.scan_id)) camera_dict = np.load(os.path.join(self.instance_dir, 'cameras_normalize.npz')) self.scale_mat = camera_dict['scale_mat_0'] self.scan_id = conf.dataset.scan_id self.score = score if score: self.n_imgs = len(os.listdir(os.path.join(self.instance_dir, 'image'))) self.poses = [] self.far_clip = far_clip for i in range(self.n_imgs): K, pose = rend_util.load_K_Rt_from_P(None, 
camera_dict[f'world_mat_{i}'][:3,:]) self.poses.append(pose) self.K = K self.H, self.W, _ = cv2.imread(os.path.join(self.instance_dir, 'image', '0000.png')).shape def initialize(self): grid = plt.get_grid_uniform(100, self.grid_boundary) z = [] points = grid['grid_points'] for pnts in track(torch.split(points, 1000000, dim=0)): z.append(self.model.implicit_network(pnts)[:,0].detach().cpu().numpy()) z = np.concatenate(z, axis=0).astype(np.float32) verts, faces, normals, values = measure.marching_cubes( volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0], grid['xyz'][2].shape[0]).transpose([1, 0, 2]), level=0, spacing=(grid['xyz'][0][2] - grid['xyz'][0][1], grid['xyz'][0][2] - grid['xyz'][0][1], grid['xyz'][0][2] - grid['xyz'][0][1])) verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]]) mesh_low_res = trimesh.Trimesh(verts, faces, normals) recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0] recon_pc = torch.from_numpy(recon_pc).float().cuda() s_mean = recon_pc.mean(dim=0) s_cov = recon_pc - s_mean s_cov = torch.mm(s_cov.transpose(0, 1), s_cov) self.vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0] if torch.det(self.vecs) < 0: self.vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), self.vecs) helper = torch.bmm(self.vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1), (recon_pc - s_mean).unsqueeze(-1)).squeeze().cpu() grid_aligned = plt.get_grid(helper, self.resolution) grid_points = grid_aligned['grid_points'] g = [] for pnts in track(torch.split(grid_points, 1000000, dim=0)): g.append(torch.bmm(self.vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2), pnts.unsqueeze(-1)).squeeze() + s_mean) grid_points = torch.cat(g, dim=0) points = grid_points.cpu() self.test_dataset = dataset.GridDataset(points, grid_aligned['xyz']) self.grid_points = grid_points self.initialized = True def test_dataloader(self): assert self.initialized print(len(self.test_dataset)) return DataLoader(self.test_dataset, batch_size=2000000, shuffle=False, num_workers=32) def test_step(self, batch, batch_idx): return self.model.implicit_network(batch)[:,0].detach().cpu().numpy() def test_epoch_end(self, outputs) -> None: # z = torch.cat(outputs, dim=0).cpu().numpy() z = np.concatenate(outputs, axis=0).astype(np.float32) if (not (np.min(z) > 0 or np.max(z) < 0)): verts, faces, normals, values = measure.marching_cubes( volume=z.reshape(self.test_dataset.xyz[1].shape[0], self.test_dataset.xyz[0].shape[0], self.test_dataset.xyz[2].shape[0]).transpose([1, 0, 2]), level=0, spacing=(self.test_dataset.xyz[0][2] - self.test_dataset.xyz[0][1], self.test_dataset.xyz[0][2] - self.test_dataset.xyz[0][1], self.test_dataset.xyz[0][2] - self.test_dataset.xyz[0][1])) verts = torch.from_numpy(verts).float().cuda() verts = torch.bmm(self.vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2), verts.unsqueeze(-1)).squeeze() verts = (verts + self.grid_points[0]).cpu().numpy() mesh = trimesh.Trimesh(verts, faces, normals) mesh.apply_transform(self.scale_mat) mesh_folder = os.path.join(self.expdir, 'eval/mesh') os.makedirs(mesh_folder, exist_ok=True) mesh.export(os.path.join(mesh_folder, 'scan{0}.ply'.format(self.scan_id)), 'ply') if self.score: from utils import mesh_util import open3d as o3d mesh = mesh_util.refuse(mesh, self.poses, self.K, self.H, self.W) out_mesh_path = os.path.join(mesh_folder, 'scan{0}_refined.ply'.format(self.scan_id)) o3d.io.write_triangle_mesh(out_mesh_path, mesh) mesh = 
trimesh.load(out_mesh_path) print("[INFO] Pred mesh refined") gt_mesh = trimesh.load(os.path.join(self.instance_dir, 'mesh.ply')) gt_mesh = mesh_util.refuse(gt_mesh, self.poses, self.K, self.H, self.W, self.far_clip) out_mesh_path = os.path.join(mesh_folder, 'scan{0}_gt.ply'.format(self.scan_id)) o3d.io.write_triangle_mesh(out_mesh_path, gt_mesh) gt_mesh = trimesh.load(out_mesh_path) print("[INFO] GT mesh refined") metrics = mesh_util.evaluate(mesh, gt_mesh) with open(f"{mesh_folder}/metrics.txt", 'w') as f: for k in metrics: f.write(f"{k.upper()}: {metrics[k]}\n") print(f"[INFO] Metrics saved to {mesh_folder}/metrics.txt\n") def forward(self): raise NotImplementedError("forward not supported by trainer") class VolumeRenderSystem(pl.LightningModule): def __init__(self, conf, exp_dir, indices=None, is_val=False, score_mesh=False, full_res=False) -> None: super().__init__() self.expdir = exp_dir conf_model = conf.model conf_model.use_normal = False self.model = model.I2SDFNetwork(conf_model) self.scan_id = conf.dataset.scan_id dataset_conf = conf.dataset if full_res: dataset_conf.downsample = 1 self.test_dataset = dataset.PlotDataset(**dataset_conf, plot_nimgs=-1, shuffle=False, indices=indices, is_val=is_val) self.total_pixels = self.test_dataset.total_pixels self.img_res = self.test_dataset.img_res self.split_n_pixels = conf.train.split_n_pixels self.expdir = os.path.join(self.expdir, 'eval') if is_val: self.expdir = os.path.join(self.expdir, 'test') os.makedirs(os.path.join(self.expdir, 'rendering'), exist_ok=True) os.makedirs(os.path.join(self.expdir, 'depth'), exist_ok=True) os.makedirs(os.path.join(self.expdir, 'normal'), exist_ok=True) def test_dataloader(self): print(len(self.test_dataset)) return DataLoader(self.test_dataset, batch_size=1, shuffle=False, collate_fn=self.test_dataset.collate_fn) @torch.inference_mode(False) @torch.no_grad() def test_step(self, batch, batch_idx): indices, model_input, ground_truth = batch # idx = batch_idx idx = self.test_dataset.indices[batch_idx] split = utils.split_input(model_input, self.total_pixels, self.split_n_pixels) res = [] for s in split: out = utils.detach_dict(self.model(s)) d = { 'rgb_values': out['rgb_values'].detach(), 'depth_values': out['depth_values'].detach() } d['normal_map'] = out['normal_map'].detach() # d['surface_point'] = out['surface_point'].detach() del out res.append(d) model_outputs = utils.merge_output(res, self.total_pixels, 1) _, num_samples, _ = ground_truth['rgb'].shape model_outputs['rgb_values'] = model_outputs['rgb_values'].reshape(1, num_samples, 3) model_outputs['depth_values'] = model_outputs['depth_values'].reshape(1, num_samples, 1) plt.plot_imgs_wo_gt(model_outputs['normal_map'].reshape(1, num_samples, 3), self.expdir, "{:04d}w".format(idx), 1, self.img_res, is_hdr=True) normal_map = model_outputs['normal_map'].reshape(num_samples, 3).T # (3, h*w) R = model_input['pose'].squeeze()[:3,:3].T normal_map = R @ normal_map model_outputs['normal_map'] = normal_map.T.reshape(1, num_samples, 3) plt.plot_imgs_wo_gt(model_outputs['normal_map'], self.expdir, "{:04d}".format(idx), 1, self.img_res, is_hdr=True) model_outputs['normal_map'] = (model_outputs['normal_map'] + 1.) / 2. 
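# Note on the normal-map handling just above: assuming `pose` follows the camera-to-world
# convention of VolSDF-style loaders, its transposed rotation carries the world-space SDF
# normals into the camera frame, and the (n + 1) / 2 remap moves components from [-1, 1]
# into [0, 1] so they can be saved as an ordinary LDR image. A minimal sketch of the same
# remap on a hypothetical (H*W, 3) tensor `n_cam` (not part of this file):
#     n_vis = (torch.nn.functional.normalize(n_cam, dim=-1) + 1.0) * 0.5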
plt.plot_imgs_wo_gt(model_outputs['normal_map'], self.expdir, "{:04d}".format(idx), 1, self.img_res) # plt.plot_imgs_wo_gt(model_outputs['surface_point'].reshape(1, num_samples, 3), self.expdir, "{:04d}p".format(idx), 1, self.img_res, is_hdr=True) plt.plot_images(model_outputs['rgb_values'], ground_truth['rgb'], self.expdir, "{:04d}".format(idx), 1, self.img_res) plt.plot_imgs_wo_gt(model_outputs['rgb_values'], self.expdir, "{:04d}_pred".format(idx), 1, self.img_res, 'rendering') plt.plot_depths(model_outputs['depth_values'], self.expdir, "{:04d}".format(idx), 1, self.img_res) plt.plot_depths(model_outputs['depth_values'], self.expdir, "{:04d}".format(idx), 1, self.img_res, None) pred_img = model_outputs['rgb_values'].T.reshape(3, *self.img_res).unsqueeze(0) gt_img = ground_truth['rgb'].T.reshape(3, *self.img_res).unsqueeze(0) return { 'psnr': utils.get_psnr(model_outputs['rgb_values'], ground_truth['rgb']).item(), 'ssim': ssim(pred_img, gt_img).item(), 'lpips': lpips(pred_img.clamp(0, 1) * 2 - 1, gt_img.clamp(0, 1) * 2 - 1).item() } def test_epoch_end(self, outputs): with open(os.path.join(self.expdir, 'metrics.txt'), 'w') as f: f.write(f"# IMAGE RESOLUTION {self.img_res}\n") psnr_sum = ssim_sum = lpips_sum = 0 psnrs = [] ssims = [] lpipss = [] for i, metrics in enumerate(outputs): f.write(f"[{i:04d}] [PSNR]{metrics['psnr']:.2f} [SSIM]{metrics['ssim']:.2f} [LPIPS]{metrics['lpips']:.2f}\n") psnrs.append(metrics['psnr']) ssims.append(metrics['ssim']) lpipss.append(metrics['lpips']) psnr_sum += metrics['psnr'] ssim_sum += metrics['ssim'] lpips_sum += metrics['lpips'] f.write(f"[MEAN] [PSNR]{psnr_sum/len(outputs):.2f} [SSIM]{ssim_sum/len(outputs):.2f} [LPIPS]{lpips_sum/len(outputs):.2f}\n") np.savez_compressed(os.path.join(self.expdir, 'metrics.npz'), psnr=np.array(psnrs), ssim=np.array(ssims), lpips=np.array(lpipss)) def forward(self): raise NotImplementedError("forward not supported by trainer") class ViewInterpolateSystem(pl.LightningModule): def __init__(self, conf, exp_dir, id0, id1, n_frames=60, frame_rate=24, use_normal=True) -> None: super().__init__() self.expdir = exp_dir conf_model = conf.model conf_model.use_normal = False self.model = model.I2SDFNetwork(conf_model) self.scan_id = conf.dataset.scan_id dataset_conf = conf.dataset self.test_dataset = dataset.InterpolateDataset(**dataset_conf, id0=id0, id1=id1, num_frames=n_frames) self.total_pixels = self.test_dataset.total_pixels self.img_res = self.test_dataset.img_res self.split_n_pixels = conf.train.split_n_pixels self.n_frames = n_frames self.frame_rate = frame_rate self.video_dir = os.path.join(self.expdir, 'eval/interpolate') self.id0 = id0 self.id1 = id1 self.use_normal = use_normal os.makedirs(self.video_dir, exist_ok=True) self.frame_dir = os.path.join(self.video_dir, f"{self.id0:04d}_{self.id1:04d}") os.makedirs(self.frame_dir, exist_ok=True) if self.use_normal: self.normal_fdir = os.path.join(self.video_dir, f"{self.id0:04d}_{self.id1:04d}_normal") os.makedirs(self.normal_fdir, exist_ok=True) def test_dataloader(self): print(len(self.test_dataset)) return DataLoader(self.test_dataset, batch_size=1, shuffle=False, collate_fn=self.test_dataset.collate_fn) @torch.inference_mode(False) @torch.no_grad() def test_step(self, batch, batch_idx): indices, model_input = batch idx = batch_idx split = utils.split_input(model_input, self.total_pixels, self.split_n_pixels) res = [] res_normal = [] for s in split: out = utils.detach_dict(self.model(s, predict_only=not self.use_normal)) rgb = out['rgb_values'].detach() res.append(rgb) if 
self.use_normal: res_normal.append(out['normal_map']) del out rendered = torch.cat(res, dim=0).reshape(self.img_res[0], self.img_res[1], 3).cpu().numpy() rendered = (rendered * 255).clip(0, 255).astype(np.uint8) cv2.imwrite(f"{self.frame_dir}/{idx:04d}.png", rendered[:,:,::-1]) if self.use_normal: normal_map = torch.cat(res_normal, dim=0).reshape(-1, 3).T R = model_input['pose'].squeeze()[:3,:3].T normal_map = R @ normal_map normal_map = normal_map.T.reshape(self.img_res[0], self.img_res[1], 3).cpu().numpy() normal_map = (((normal_map + 1) * 0.5) * 255).clip(0, 255).astype(np.uint8) cv2.imwrite(f"{self.normal_fdir}/{idx:04d}.png", normal_map[:,:,::-1]) def test_epoch_end(self, outputs): import ffmpeg ( ffmpeg .input(os.path.join(self.frame_dir, '*.png'), pattern_type='glob', framerate=self.frame_rate) .output(os.path.join(self.video_dir, f"scan{self.scan_id}_{self.id0:04d}_{self.id1:04d}.mp4"), vcodec='h264') .overwrite_output() .run() ) if self.use_normal: ( ffmpeg .input(os.path.join(self.normal_fdir, '*.png'), pattern_type='glob', framerate=self.frame_rate) .output(os.path.join(self.video_dir, f"scan{self.scan_id}_{self.id0:04d}_{self.id1:04d}_normal.mp4"), vcodec='h264') .overwrite_output() .run() ) def forward(self): raise NotImplementedError("forward not supported by trainer") ================================================ FILE: model/network/__init__.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import math import numpy as np import utils from model.network.mlp import ImplicitNetwork, RenderingNetwork from model.network.density import LaplaceDensity, AbsDensity from model.network.ray_sampler import ErrorBoundSampler from fast_pytorch_kmeans import KMeans from sklearn.cluster import DBSCAN """ For modeling more complex backgrounds, we follow the inverted sphere parametrization from NeRF++ https://github.com/Kai-46/nerfplusplus """ class I2SDFNetwork(nn.Module): def __init__(self, conf): super().__init__() self.feature_vector_size = conf.feature_vector_size self.scene_bounding_sphere = getattr(conf, 'scene_bounding_sphere', 1.0) # Foreground object's networks self.implicit_network = ImplicitNetwork(self.feature_vector_size, 0.0, **conf.implicit_network) self.rendering_network = RenderingNetwork(self.feature_vector_size, **conf.rendering_network) self.use_light = hasattr(conf, 'light_network') if self.use_light: # self.light_network = RenderingNetwork(self.feature_vector_size, mode='nerf', output_activation='sigmoid', use_dir=False, **conf.light_network) self.light_network = ImplicitNetwork(0, 0, d_in=self.feature_vector_size, d_out=1, geometric_init=False, embed_type=None, output_activation='sigmoid', **conf.light_network) self.density = LaplaceDensity(**conf.density) # Background's networks self.use_bg = hasattr(conf, 'bg_network') if self.use_bg: bg_feature_vector_size = conf.bg_network.feature_vector_size self.bg_implicit_network = ImplicitNetwork(bg_feature_vector_size, 0.0, **conf.bg_network.implicit_network) self.bg_rendering_network = RenderingNetwork(bg_feature_vector_size, **conf.bg_network.rendering_network) self.bg_density = AbsDensity(**getattr(conf.bg_network, 'density', {})) else: print("[INFO] BG Network Disabled") self.ray_sampler = ErrorBoundSampler(self.scene_bounding_sphere, inverse_sphere_bg=self.use_bg, **conf.ray_sampler) self.use_normal = conf.get('use_normal', False) self.detach_light_feature = conf.get('detach_light_feature', True) def init_emission_groups(self, n_emitters, pointcloud, 
init_emission=1.0, use_dbscan=False): if use_dbscan: """ Use DBSCAN algorithm to initialize emitter cluster centroids for K-Means from a small random batch Note that DBSCAN can automatically determine the number of clusters """ pt_samples = pointcloud[torch.randperm(len(pointcloud))[:10000]].cpu().numpy() lab_samples = torch.from_numpy(DBSCAN(n_jobs=16).fit_predict(pt_samples)) if n_emitters != len(torch.unique(lab_samples)): print(f"[ERROR] Inconsistent emitter count: {n_emitters} / {len(torch.unique(lab_samples))}") # n_emitters = len(torch.unique(lab_samples)) exit() init_centroids = torch.zeros(n_emitters, 3) for i in range(n_emitters): idx = (lab_samples == i).int().argmax() init_centroids[i,:] = torch.from_numpy(pt_samples[idx, :]) init_centroids = init_centroids.to(pointcloud.device) else: """ Use K-Means plus plus to initialize emitter cluster centroids for K-Means """ init_centroids = utils.kmeans_pp_centroid(pointcloud, n_emitters) self.emitter_clusters = KMeans(n_emitters) labels = self.emitter_clusters.fit_predict(pointcloud, init_centroids) print("[INFO] emitters clustered") self.emissions = nn.Parameter(torch.empty(n_emitters, 3).fill_(init_emission), True) return labels, self.emitter_clusters.centroids def get_param_groups(self, lr): return [{'params': self.parameters(), 'lr': lr}] def forward(self, input, predict_only=False): intrinsics = input["intrinsics"] uv = input["uv"] pose = input["pose"] ray_dirs, cam_loc = utils.get_camera_params(uv, pose, intrinsics) batch_size, num_pixels, _ = ray_dirs.shape cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3) ray_dirs = ray_dirs.reshape(-1, 3) ray_dirs_norm = torch.linalg.vector_norm(ray_dirs, dim=1) ray_dirs = F.normalize(ray_dirs, dim=1) z_vals, z_samples_eik = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self) if self.use_bg: z_vals, z_vals_bg = z_vals z_max = z_vals[:,-1] z_vals = z_vals[:,:-1] N_samples = z_vals.shape[1] points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1) points_flat = points.reshape(-1, 3) dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1) dirs_flat = dirs.reshape(-1, 3) returns_grad = self.use_normal or (not self.training) or (self.rendering_network.mode == 'idr') # with torch.enable_grad(): with torch.set_grad_enabled(returns_grad): # with torch.inference_mode(not returns_grad): sdf, feature_vectors, gradients = self.implicit_network.get_outputs(points_flat, returns_grad) rgb_flat = self.rendering_network(points_flat, gradients, dirs_flat, feature_vectors) rgb = rgb_flat.reshape(-1, N_samples, 3) weights, bg_transmittance = self.volume_rendering(z_vals, z_max, sdf) fg_rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, 1) weight_sum = torch.sum(weights, -1, keepdim=True) # dist = torch.sum(weights / weight_sum.clamp(min=1e-6) * z_vals, 1) dist = torch.sum(weights * z_vals, 1) depth_values = dist / torch.clamp(ray_dirs_norm, min=1e-6) # depth_values = torch.sum(weights * z_vals, 1) / torch.clamp(torch.sum(weights, 1), min=1e-6) # (bn,) # Background rendering if self.use_bg: N_bg_samples = z_vals_bg.shape[1] z_vals_bg = torch.flip(z_vals_bg, dims=[-1, ]) # 1--->0 bg_dirs = ray_dirs.unsqueeze(1).repeat(1,N_bg_samples,1) bg_locs = cam_loc.unsqueeze(1).repeat(1,N_bg_samples,1) bg_points = self.depth2pts_outside(bg_locs, bg_dirs, z_vals_bg) # [..., N_samples, 4] bg_points_flat = bg_points.reshape(-1, 4) bg_dirs_flat = bg_dirs.reshape(-1, 3) output = self.bg_implicit_network(bg_points_flat) bg_sdf = output[:,:1] bg_feature_vectors = output[:, 1:] bg_rgb_flat = 
self.bg_rendering_network(None, None, bg_dirs_flat, bg_feature_vectors) bg_rgb = bg_rgb_flat.reshape(-1, N_bg_samples, 3) bg_weights = self.bg_volume_rendering(z_vals_bg, bg_sdf) bg_rgb_values = torch.sum(bg_weights.unsqueeze(-1) * bg_rgb, 1) # Composite foreground and background bg_rgb_values = bg_transmittance.unsqueeze(-1) * bg_rgb_values rgb_values = fg_rgb_values + bg_rgb_values else: rgb_values = fg_rgb_values output = { 'rgb_values': rgb_values, 'depth_values': depth_values, 'weight_sum': weight_sum } if self.use_light: light_features = F.relu(feature_vectors) if self.detach_light_feature: light_features = light_features.detach_() # lmask_flat = self.light_network(None, None, None, light_features) lmask_flat = self.light_network(light_features) lmask = lmask_flat.reshape(-1, N_samples, 1) lmask_values = torch.sum(weights.unsqueeze(-1).detach() * lmask, 1) output['light_mask'] = lmask_values if predict_only: return output if self.training: # Sample points for the eikonal loss n_eik_points = batch_size * num_pixels eikonal_points = torch.empty(n_eik_points, 3, device=cam_loc.device).uniform_(-self.scene_bounding_sphere, self.scene_bounding_sphere) # Add some of the near surface points eik_near_points = (cam_loc.unsqueeze(1) + z_samples_eik.unsqueeze(2) * ray_dirs.unsqueeze(1)).reshape(-1, 3) n_eik_near = eik_near_points.size(0) eikonal_points = torch.cat([eikonal_points, eik_near_points], 0) # Add neighbor points near surface for smoothness loss eik_near_neighbors = eik_near_points + torch.empty_like(eik_near_points).uniform_(-0.005, 0.005) eikonal_points = torch.cat([eikonal_points, eik_near_neighbors], 0) grad_theta = self.implicit_network.gradient(eikonal_points) output['grad_theta'] = grad_theta[:n_eik_points+n_eik_near,] normals = grad_theta[n_eik_points:,] normals = F.normalize(normals, dim=1, eps=1e-6) diff_norm = torch.norm(normals[:n_eik_near,:] - normals[n_eik_near:,:], dim=1) output['diff_norm'] = diff_norm # Sample pointclouds for bubble loss if 'pointcloud' in input: surface_points = input['pointcloud'] cam_loc_selected = cam_loc[np.random.randint(0, len(cam_loc)),:] surface_points = torch.cat([surface_points, cam_loc_selected.unsqueeze(0)], dim=0) surface_sdf = self.implicit_network.get_sdf_vals(surface_points) output['surface_sdf'] = surface_sdf[:-1,:] # Accumulate gradients for normal loss if self.use_normal: normals = F.normalize(gradients, dim=-1) normals = normals.reshape(-1, N_samples, 3) normal_map = torch.sum(weights.unsqueeze(-1).detach() * normals, 1) normal_map = F.normalize(normal_map, dim=-1) output['normal_values'] = normal_map # elif not self.training: else: # Accumulate gradients for normal visualization gradients = gradients.detach() normals = F.normalize(gradients, dim=-1) normals = normals.reshape(-1, N_samples, 3) normal_map = torch.sum(weights.unsqueeze(-1) * normals, 1) normal_map = F.normalize(normal_map, dim=-1) output['normal_map'] = normal_map return output def volume_rendering(self, z_vals, z_max, sdf): density_flat = self.density(sdf) density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples # included also the dist from the sphere intersection dists = z_vals[:, 1:] - z_vals[:, :-1] dists = torch.cat([dists, z_max.unsqueeze(-1) - z_vals[:, -1:]], -1) # LOG SPACE free_energy = dists * density shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=free_energy.device), free_energy], dim=-1) # add 0 for transperancy 1 at t_0 alpha = 1 - torch.exp(-free_energy) # probability of it is not empty here 
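# For reference, the discrete quadrature implemented in this function (VolSDF-style),
# with sigma_i the Laplace density of the SDF and delta_i the segment length:
#     alpha_i = 1 - exp(-sigma_i * delta_i)
#     T_i     = exp(-sum_{j < i} sigma_j * delta_j)   # cumsum of shifted_free_energy below
#     w_i     = alpha_i * T_i
# The per-ray color is sum_i w_i * c_i, and the last transmittance entry is returned as
# bg_transmittance to modulate the background rendering.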
transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) # probability of everything is empty up to now fg_transmittance = transmittance[:, :-1] weights = alpha * fg_transmittance # probability of the ray hits something here bg_transmittance = transmittance[:, -1] # factor to be multiplied with the bg volume rendering return weights, bg_transmittance def bg_volume_rendering(self, z_vals_bg, bg_sdf): bg_density_flat = self.bg_density(bg_sdf) bg_density = bg_density_flat.reshape(-1, z_vals_bg.shape[1]) # (batch_size * num_pixels) x N_samples bg_dists = z_vals_bg[:, :-1] - z_vals_bg[:, 1:] bg_dists = torch.cat([bg_dists, torch.tensor([1e10], device=bg_dists.device).unsqueeze(0).repeat(bg_dists.shape[0], 1)], -1) # LOG SPACE bg_free_energy = bg_dists * bg_density bg_shifted_free_energy = torch.cat([torch.zeros(bg_dists.shape[0], 1, device=bg_free_energy.device), bg_free_energy[:, :-1]], dim=-1) # shift one step bg_alpha = 1 - torch.exp(-bg_free_energy) # probability of it is not empty here bg_transmittance = torch.exp(-torch.cumsum(bg_shifted_free_energy, dim=-1)) # probability of everything is empty up to now bg_weights = bg_alpha * bg_transmittance # probability of the ray hits something here return bg_weights def depth2pts_outside(self, ray_o, ray_d, depth): ''' ray_o, ray_d: [..., 3] depth: [...]; inverse of distance to sphere origin ''' o_dot_d = torch.sum(ray_d * ray_o, dim=-1) under_sqrt = o_dot_d ** 2 - ((ray_o ** 2).sum(-1) - self.scene_bounding_sphere ** 2) d_sphere = torch.sqrt(under_sqrt) - o_dot_d p_sphere = ray_o + d_sphere.unsqueeze(-1) * ray_d p_mid = ray_o - o_dot_d.unsqueeze(-1) * ray_d p_mid_norm = torch.norm(p_mid, dim=-1) rot_axis = torch.cross(ray_o, p_sphere, dim=-1) rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True) phi = torch.asin(p_mid_norm / self.scene_bounding_sphere) theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1] rot_angle = (phi - theta).unsqueeze(-1) # [..., 1] # now rotate p_sphere # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula p_sphere_new = p_sphere * torch.cos(rot_angle) + \ torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \ rot_axis * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True) * (1. 
- torch.cos(rot_angle)) p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True) pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1) return pts class I2SDFLoss(nn.Module): def __init__(self, eikonal_weight=0.1, smooth_weight=0.0, mask_weight=0.0, depth_weight=0.1, normal_weight=0.05, angular_weight=0.05, bubble_weight=0.0, min_bubble_iter=0, max_bubble_iter=None, smooth_iter=None, light_mask_weight=0.0, eikonal_weight_bubble=0.0): super().__init__() self.eikonal_weight = eikonal_weight self.rgb_loss = F.l1_loss self.smooth_weight = smooth_weight self.mask_weight = mask_weight self.depth_weight = depth_weight self.normal_weight = normal_weight self.angular_weight = angular_weight self.bubble_weight = bubble_weight # self.eikonal_weight_bubble = eikonal_weight_bubble if eikonal_weight_bubble else self.eikonal_weight self.min_bubble_iter = min_bubble_iter self.max_bubble_iter = max_bubble_iter self.smooth_iter = smooth_iter if self.bubble_weight > 0 and self.max_bubble_iter is not None and self.smooth_iter < self.max_bubble_iter: self.smooth_iter = self.max_bubble_iter # Disable smoothness loss during bubble steps self.light_mask_weight = light_mask_weight def get_rgb_loss(self, rgb_values, rgb_gt): rgb_gt = rgb_gt.reshape(-1, 3) rgb_loss = self.rgb_loss(rgb_values, rgb_gt) return rgb_loss def get_eikonal_loss(self, grad_theta): eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean() return eikonal_loss def get_mask_loss(self, mask_pred, mask_gt): return F.binary_cross_entropy(mask_pred.clip(1e-3, 1.0 - 1e-3), mask_gt) def get_depth_loss(self, depth, depth_gt, depth_mask): depth_gt = depth_gt.flatten() depth_mask = depth_mask.flatten() # TODO: Add support for scale invariant depth loss (like MonoSDF) return F.mse_loss(depth[depth_mask], depth_gt[depth_mask]) def get_normal_l1_loss(self, normal, normal_gt, normal_mask): normal_gt = normal_gt.reshape(-1, 3) normal_mask = normal_mask.flatten() return torch.abs(1 - torch.sum(normal[normal_mask] * normal_gt[normal_mask], dim=-1)).mean() def get_normal_angular_loss(self, normal, normal_gt, normal_mask): normal_gt = normal_gt.reshape(-1, 3) normal_mask = normal_mask.flatten() dot = torch.sum(normal[normal_mask] * normal_gt[normal_mask], dim=-1) angle = torch.acos(torch.clamp(dot, -1.0+1e-6, 1.0-1e-6)) / math.tau return angle.clamp_max(0.5).abs().mean() def forward(self, model_outputs, ground_truth, current_step): rgb_gt = ground_truth['rgb'] rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt) if 'grad_theta' in model_outputs: eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta']) else: eikonal_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float() smooth_activated = self.smooth_iter is None or current_step > self.smooth_iter if smooth_activated and self.smooth_weight > 0 and 'diff_norm' in model_outputs: smooth_loss = model_outputs['diff_norm'].mean() else: smooth_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float() if 'mask' in ground_truth and self.mask_weight > 0: mask_loss = self.get_mask_loss(model_outputs['weight_sum'], ground_truth['mask']) else: mask_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float() if 'depth' in ground_truth and self.depth_weight > 0: depth_loss = self.get_depth_loss(model_outputs['depth_values'], ground_truth['depth'], ground_truth['depth_mask']) else: depth_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float() if 'normal' in ground_truth and self.normal_weight > 0: 
normal_loss = self.get_normal_l1_loss(model_outputs['normal_values'], ground_truth['normal'], ground_truth['normal_mask']) else: normal_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float() if 'normal' in ground_truth and self.angular_weight > 0: angular_loss = self.get_normal_l1_loss(model_outputs['normal_values'], ground_truth['normal'], ground_truth['normal_mask']) else: angular_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float() if 'surface_sdf' in model_outputs and self.bubble_weight > 0: bubble_loss = model_outputs['surface_sdf'].abs().mean() else: bubble_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float() if 'light_mask' in model_outputs and self.light_mask_weight > 0: light_mask_loss = self.get_mask_loss(model_outputs['light_mask'].reshape(-1, 1), ground_truth['light_mask'].reshape(-1, 1)) else: light_mask_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float() loss = rgb_loss + \ self.eikonal_weight * eikonal_loss + \ self.smooth_weight * smooth_loss + \ self.mask_weight * mask_loss + \ self.depth_weight * depth_loss + \ self.normal_weight * normal_loss + \ self.angular_weight * angular_loss + \ self.bubble_weight * bubble_loss + \ self.light_mask_weight * light_mask_loss output = { 'loss': loss, 'rgb_loss': rgb_loss, 'eikonal_loss': eikonal_loss, 'smooth_loss': smooth_loss, 'mask_loss': mask_loss, 'depth_loss': depth_loss, 'normal_loss': normal_loss, 'angular_loss': angular_loss, 'bubble_loss': bubble_loss, 'light_mask_loss': light_mask_loss } return output ================================================ FILE: model/network/density.py ================================================ import torch.nn as nn import torch class Density(nn.Module): def __init__(self, params_init={}): super().__init__() for p in params_init: param = nn.Parameter(torch.tensor(params_init[p])) setattr(self, p, param) def forward(self, sdf, beta=None): return self.density_func(sdf, beta=beta) class LaplaceDensity(Density): # alpha * Laplace(loc=0, scale=beta).cdf(-sdf) def __init__(self, params_init={}, beta_min=0.0001): super().__init__(params_init=params_init) self.beta_min = torch.tensor(beta_min) def density_func(self, sdf, beta=None): if beta is None: beta = self.get_beta() alpha = 1 / beta return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta)) def get_beta(self): beta = self.beta.abs() + self.beta_min return beta class AbsDensity(Density): # like NeRF++ def density_func(self, sdf, beta=None): return torch.abs(sdf) class SimpleDensity(Density): # like NeRF def __init__(self, params_init={}, noise_std=1.0): super().__init__(params_init=params_init) self.noise_std = noise_std def density_func(self, sdf, beta=None): if self.training and self.noise_std > 0.0: noise = torch.randn(sdf.shape).to(sdf.device) * self.noise_std sdf = sdf + noise return torch.relu(sdf) ================================================ FILE: model/network/embedder.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np class Embedder: """ Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. 
""" def __init__(self, **kwargs): self.kwargs = kwargs self.create_embedding_fn() def create_embedding_fn(self): embed_fns = [] d = self.kwargs['input_dims'] out_dim = 0 if self.kwargs['include_input']: embed_fns.append(lambda x: x) out_dim += d max_freq = self.kwargs['max_freq_log2'] N_freqs = self.kwargs['num_freqs'] if self.kwargs['log_sampling']: freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs) else: freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) for freq in freq_bands: for p_fn in self.kwargs['periodic_fns']: embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) out_dim += d self.embed_fns = embed_fns self.out_dim = out_dim def embed(self, inputs): return torch.cat([fn(inputs) for fn in self.embed_fns], -1) class SHEncoder(nn.Module): def __init__(self, input_dims=3, degree=4): super().__init__() self.input_dims = input_dims self.degree = degree assert self.input_dims == 3 assert self.degree >= 1 and self.degree <= 5 self.out_dim = degree ** 2 self.C0 = 0.28209479177387814 self.C1 = 0.4886025119029199 self.C2 = [ 1.0925484305920792, -1.0925484305920792, 0.31539156525252005, -1.0925484305920792, 0.5462742152960396 ] self.C3 = [ -0.5900435899266435, 2.890611442640554, -0.4570457994644658, 0.3731763325901154, -0.4570457994644658, 1.445305721320277, -0.5900435899266435 ] self.C4 = [ 2.5033429417967046, -1.7701307697799304, 0.9461746957575601, -0.6690465435572892, 0.10578554691520431, -0.6690465435572892, 0.47308734787878004, -1.7701307697799304, 0.6258357354491761 ] def forward(self, input, **kwargs): result = torch.empty((*input.shape[:-1], self.out_dim), dtype=input.dtype, device=input.device) x, y, z = input.unbind(-1) result[..., 0] = self.C0 if self.degree > 1: result[..., 1] = -self.C1 * y result[..., 2] = self.C1 * z result[..., 3] = -self.C1 * x if self.degree > 2: xx, yy, zz = x * x, y * y, z * z xy, yz, xz = x * y, y * z, x * z result[..., 4] = self.C2[0] * xy result[..., 5] = self.C2[1] * yz result[..., 6] = self.C2[2] * (2.0 * zz - xx - yy) #result[..., 6] = self.C2[2] * (3.0 * zz - 1) # xx + yy + zz == 1, but this will lead to different backward gradients, interesting... 
result[..., 7] = self.C2[3] * xz result[..., 8] = self.C2[4] * (xx - yy) if self.degree > 3: result[..., 9] = self.C3[0] * y * (3 * xx - yy) result[..., 10] = self.C3[1] * xy * z result[..., 11] = self.C3[2] * y * (4 * zz - xx - yy) result[..., 12] = self.C3[3] * z * (2 * zz - 3 * xx - 3 * yy) result[..., 13] = self.C3[4] * x * (4 * zz - xx - yy) result[..., 14] = self.C3[5] * z * (xx - yy) result[..., 15] = self.C3[6] * x * (xx - 3 * yy) if self.degree > 4: result[..., 16] = self.C4[0] * xy * (xx - yy) result[..., 17] = self.C4[1] * yz * (3 * xx - yy) result[..., 18] = self.C4[2] * xy * (7 * zz - 1) result[..., 19] = self.C4[3] * yz * (7 * zz - 3) result[..., 20] = self.C4[4] * (zz * (35 * zz - 30) + 3) result[..., 21] = self.C4[5] * xz * (7 * zz - 3) result[..., 22] = self.C4[6] * (xx - yy) * (7 * zz - 1) result[..., 23] = self.C4[7] * xz * (xx - 3 * yy) result[..., 24] = self.C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) return result class FourierFeature(nn.Module): def __init__(self, channels, sigma=1.0, input_dims=3, include_input=True) -> None: super().__init__() self.register_buffer('B', torch.randn(input_dims, channels) * sigma, True) self.channels = channels self.out_dim = 2 * self.channels + 3 if include_input else 2 * self.channels self.include_input = include_input def forward(self, x): xp = torch.matmul(2 * np.pi * x, self.B) return torch.cat([x, torch.sin(xp), torch.cos(xp)], dim=-1) if self.include_input else torch.cat([torch.sin(xp), torch.cos(xp)], dim=-1) def get_embedder(embed_type='positional', **kwargs): if embed_type == 'positional': input_dims = kwargs['input_dims'] multires = kwargs['multires'] embed_kwargs = { 'include_input': True, 'input_dims': input_dims, 'max_freq_log2': multires-1, 'num_freqs': multires, 'log_sampling': True, 'periodic_fns': [torch.sin, torch.cos], } embedder_obj = Embedder(**embed_kwargs) def embed(x, eo=embedder_obj): return eo.embed(x) return embed, embedder_obj.out_dim elif embed_type == 'spherical_harmonics': embedder = SHEncoder(**kwargs) return embedder, embedder.out_dim elif embed_type == 'fourier': embedder = FourierFeature(**kwargs) return embedder, embedder.out_dim else: raise ValueError('Unknown embedding type: {}'.format(embed_type)) ================================================ FILE: model/network/mlp.py ================================================ import torch.nn as nn import numpy as np import utils from .embedder import * from .density import LaplaceDensity from .ray_sampler import ErrorBoundSampler class ImplicitNetwork(nn.Module): def __init__( self, feature_vector_size, sdf_bounding_sphere, d_in, d_out, dims, geometric_init=True, bias=1.0, skip_in=(), weight_norm=True, embed_type=None, sphere_scale=1.0, output_activation=None, **kwargs ): super().__init__() self.sdf_bounding_sphere = sdf_bounding_sphere self.sphere_scale = sphere_scale dims = [d_in] + dims + [d_out + feature_vector_size] self.embed_fn = None if embed_type: embed_fn, input_ch = get_embedder(embed_type, input_dims=d_in, **kwargs) self.embed_fn = embed_fn dims[0] = input_ch print(f"[INFO] Implicit network dims: {dims}") self.num_layers = len(dims) self.skip_in = skip_in self.weight_norm = weight_norm for l in range(0, self.num_layers - 1): if l + 1 in self.skip_in: out_dim = dims[l + 1] - dims[0] if out_dim < 0: print(dims) else: out_dim = dims[l + 1] lin = nn.Linear(dims[l], out_dim) if geometric_init: if l == self.num_layers - 2: torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 
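# Geometric initialization (in the spirit of SAL, Atzmon & Lipman): together with the
# constant -bias initialization that follows, drawing the last layer's weights around
# sqrt(pi)/sqrt(d) makes the untrained network approximate the signed distance field of
# a sphere of radius roughly `bias`, giving SDF training a stable starting surface.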
torch.nn.init.constant_(lin.bias, -bias) elif (embed_type or self.use_grid) and l == 0: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.constant_(lin.weight[:, 3:], 0.0) torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) elif (embed_type or self.use_grid) and l in self.skip_in: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) else: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) if weight_norm: lin = nn.utils.weight_norm(lin) setattr(self, "lin" + str(l), lin) self.activation = nn.Softplus(beta=100) self.output_activation = None if output_activation is not None: self.output_activation = activations[output_activation] def get_param_groups(self, lr): return [{'params': self.parameters(), 'lr': lr}] def forward(self, input): if self.embed_fn is not None: input = self.embed_fn(input) x = input for l in range(0, self.num_layers - 1): lin = getattr(self, "lin" + str(l)) if l in self.skip_in: x = torch.cat([x, input], 1) / np.sqrt(2) x = lin(x) if l < self.num_layers - 2: x = self.activation(x) if self.output_activation is not None: x = self.output_activation(x) return x def gradient(self, x): x.requires_grad_(True) y = self.forward(x)[:,:1] d_output = torch.ones_like(y, requires_grad=False, device=y.device) gradients = torch.autograd.grad( outputs=y, inputs=x, grad_outputs=d_output, create_graph=True, retain_graph=True, only_inputs=True)[0] return gradients def feature(self, x): return self.forward(x)[:,1:] def get_outputs(self, x, returns_grad=True): x.requires_grad_(returns_grad) output = self.forward(x) sdf = output[:,:1] ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded ''' if self.sdf_bounding_sphere > 0.0: sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True)) sdf = torch.minimum(sdf, sphere_sdf) feature_vectors = output[:, 1:] if returns_grad: d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device) gradients = torch.autograd.grad( outputs=sdf, inputs=x, grad_outputs=d_output, create_graph=True, retain_graph=True, only_inputs=True)[0] return sdf, feature_vectors, gradients else: return sdf, feature_vectors, None def get_sdf_vals(self, x): sdf = self.forward(x)[:,:1] ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded ''' if self.sdf_bounding_sphere > 0.0: sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True)) sdf = torch.minimum(sdf, sphere_sdf) return sdf activations = { 'sigmoid': nn.Sigmoid(), 'relu': nn.ReLU(), 'softplus': nn.Softplus() } class RenderingNetwork(nn.Module): def __init__( self, feature_vector_size, mode, d_in, d_out, dims, weight_norm=True, embed_type=None, embed_point=None, output_activation='sigmoid', **kwargs ): super().__init__() self.mode = mode dims = [d_in + feature_vector_size] + dims + [d_out] self.d_out = d_out self.embedview_fn = None if embed_type: embedview_fn, input_ch = get_embedder(embed_type, input_dims=3, **kwargs) self.embedview_fn = embedview_fn dims[0] += (input_ch - 3) if mode == 'idr': self.embedpoint_fn = None if embed_point is not None: embedpoint_fn, input_ch = get_embedder(input_dims=3, **embed_point) self.embedpoint_fn = embedpoint_fn dims[0] += (input_ch - 3) print(f"[INFO] Rendering network dims: {dims}") self.num_layers = len(dims) self.weight_norm = weight_norm for 
l in range(0, self.num_layers - 1): out_dim = dims[l + 1] lin = nn.Linear(dims[l], out_dim) if weight_norm: lin = nn.utils.weight_norm(lin) setattr(self, "lin" + str(l), lin) self.activation = nn.ReLU() self.output_activation = activations[output_activation] def forward(self, points, normals, view_dirs, feature_vectors): if self.embedview_fn is not None: view_dirs = self.embedview_fn(view_dirs) if self.mode == 'idr': rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1) # elif self.mode == 'nerf': else: rendering_input = torch.cat([view_dirs, feature_vectors], dim=-1) x = rendering_input for l in range(0, self.num_layers - 1): lin = getattr(self, "lin" + str(l)) x = lin(x) if l < self.num_layers - 2: x = self.activation(x) x = self.output_activation(x) return x ================================================ FILE: model/network/ray_sampler.py ================================================ import abc import torch from utils import rend_util import utils class RaySampler(metaclass=abc.ABCMeta): def __init__(self, near, far): self.near = near self.far = far @abc.abstractmethod def get_z_vals(self, ray_dirs, cam_loc, model): pass class UniformSampler(RaySampler): def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1): super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far) # default far is 2*R self.N_samples = N_samples self.scene_bounding_sphere = scene_bounding_sphere self.take_sphere_intersection = take_sphere_intersection def get_z_vals(self, ray_dirs, cam_loc, model): if not self.take_sphere_intersection: near, far = self.near * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device), self.far * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device) else: sphere_intersections = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere) near = self.near * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device) far = sphere_intersections[:,1:] t_vals = torch.linspace(0., 1., steps=self.N_samples, device=ray_dirs.device) z_vals = near * (1. 
- t_vals) + far * (t_vals) if model.training: # get intervals between samples mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) upper = torch.cat([mids, z_vals[..., -1:]], -1) lower = torch.cat([z_vals[..., :1], mids], -1) # stratified samples in those intervals t_rand = torch.rand(z_vals.shape, device=z_vals.device) z_vals = lower + (upper - lower) * t_rand return z_vals class ErrorBoundSampler(RaySampler): def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra, eps, beta_iters, max_total_iters, inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=0.0): super().__init__(near, 2.0 * scene_bounding_sphere) self.N_samples = N_samples self.N_samples_eval = N_samples_eval self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=inverse_sphere_bg) self.N_samples_extra = N_samples_extra self.eps = eps self.beta_iters = beta_iters self.max_total_iters = max_total_iters self.scene_bounding_sphere = scene_bounding_sphere self.add_tiny = add_tiny self.inverse_sphere_bg = inverse_sphere_bg if inverse_sphere_bg: self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0) def get_z_vals(self, ray_dirs, cam_loc, model): beta0 = model.density.get_beta().detach() # Start with uniform sampling z_vals = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model) samples, samples_idx = z_vals, None # Get maximum beta from the upper bound (Lemma 2) dists = z_vals[:, 1:] - z_vals[:, :-1] bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1) beta = torch.sqrt(bound) # beta = torch.sqrt(bound).clone() total_iters, not_converge = 0, True # Algorithm 1 while not_converge and total_iters < self.max_total_iters: points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1) points_flat = points.reshape(-1, 3) # Calculating the SDF only for the new sampled points with torch.no_grad(): samples_sdf = model.implicit_network.get_sdf_vals(points_flat) if samples_idx is not None: sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]), samples_sdf.reshape(-1, samples.shape[1])], -1) sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1) else: sdf = samples_sdf # Calculating the bound d* (Theorem 1) d = sdf.reshape(z_vals.shape) dists = z_vals[:, 1:] - z_vals[:, :-1] a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs() first_cond = a.pow(2) + b.pow(2) <= c.pow(2) second_cond = a.pow(2) + c.pow(2) <= b.pow(2) # d_star = torch.zeros(z_vals.shape[0], z_vals.shape[1] - 1, device=z_vals.device) # d_star[first_cond] = b[first_cond] # d_star[second_cond] = c[second_cond] s = (a + b + c) / 2.0 area_before_sqrt = s * (s - a) * (s - b) * (s - c) mask = ~first_cond & ~second_cond & (b + c - a > 0) # d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask]) # Optimization: multiplication is 5-20 times faster than indexing first_cond = first_cond & ~second_cond d_star = first_cond * b + second_cond * c + torch.nan_to_num((2.0 * torch.sqrt(area_before_sqrt)) / a) * mask d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star # Fixing the sign # Updating beta using line search curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star) # beta[curr_error <= self.eps] = beta0 # Optimization: multiplication is 5-20 times faster than indexing beta0_mask = curr_error <= self.eps beta = beta * ~beta0_mask + beta0 * beta0_mask beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta for j in 
range(self.beta_iters): beta_mid = (beta_min + beta_max) / 2. curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star) # beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps] # beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps] beta_mid_mask = curr_error <= self.eps beta_max = beta_max * ~beta_mid_mask + beta_mid * beta_mid_mask beta_min = beta_min * beta_mid_mask + beta_mid * ~beta_mid_mask beta = beta_max # Upsample more points # tmp0 = beta.unsqueeze(-1).clone() # tmp0 = beta.unsqueeze(-1) # density = model.density(sdf.reshape(z_vals.shape), beta=tmp0) density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1)) # dists = torch.cat([dists, torch.tensor([1e10], device=dists.device).unsqueeze(0).repeat(dists.shape[0], 1)], -1) dists = torch.cat([dists, torch.full([dists.shape[0], 1], 1e10, device=dists.device)], -1) free_energy = dists * density shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=free_energy.device), free_energy[:, :-1]], dim=-1) alpha = 1 - torch.exp(-free_energy) transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) weights = alpha * transmittance # probability of the ray hits something here # Check if we are done and this is the last sampling total_iters += 1 not_converge = beta.max() > beta0 if not_converge and total_iters < self.max_total_iters: ''' Sample more points proportional to the current error bound''' N = self.N_samples_eval bins = z_vals error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:,:-1] ** 2.) / (4 * beta.unsqueeze(-1) ** 2) # tmp0 = beta.unsqueeze(-1).clone() # tmp1 = -d_star / tmp0 # tmp2 = dists[:,:-1] ** 2. # tmp3 = 4 * tmp0 ** 2 # error_per_section = tmp1 * tmp2 / tmp3 error_integral = torch.cumsum(error_per_section, dim=-1) bound_opacity = (torch.clamp(torch.exp(error_integral),max=1.e6) - 1.0) * transmittance[:,:-1] pdf = bound_opacity + self.add_tiny pdf = pdf / torch.sum(pdf, -1, keepdim=True) cdf = torch.cumsum(pdf, -1) cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) else: ''' Sample the final sample set to be used in the volume rendering integral ''' N = self.N_samples bins = z_vals pdf = weights[..., :-1] pdf = pdf + 1e-5 # prevent nans pdf = pdf / torch.sum(pdf, -1, keepdim=True) cdf = torch.cumsum(pdf, -1) cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins)) # Invert CDF if (not_converge and total_iters < self.max_total_iters) or (not model.training): u = torch.linspace(0., 1., steps=N, device=cdf.device).unsqueeze(0).repeat(cdf.shape[0], 1) else: u = torch.rand(list(cdf.shape[:-1]) + [N], device=cdf.device) u = u.contiguous() inds = torch.searchsorted(cdf, u, right=True) below = torch.max(torch.zeros_like(inds - 1), inds - 1) above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) denom = (cdf_g[..., 1] - cdf_g[..., 0]) denom_mask = denom < 1e-5 denom = denom_mask + ~denom_mask * denom # denom = torch.where(denom < 1e-5, 1.0, denom) t = (u - cdf_g[..., 0]) / denom samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) # Adding samples if we not converged if not_converge and total_iters < self.max_total_iters: z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], 
-1), -1) z_samples = samples near, far = self.near * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device), self.far * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device) if self.inverse_sphere_bg: # if inverse sphere then need to add the far sphere intersection far = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:,1:] if self.N_samples_extra > 0: if model.training: sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra] else: sampling_idx = torch.linspace(0, z_vals.shape[1]-1, self.N_samples_extra).long() z_vals_extra = torch.cat([near, far, z_vals[:,sampling_idx]], -1) else: z_vals_extra = torch.cat([near, far], -1) z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1) # add some of the near surface points idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],), device=z_vals.device) z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1)) if self.inverse_sphere_bg: z_vals_inverse_sphere = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model) z_vals_inverse_sphere = z_vals_inverse_sphere * (1./self.scene_bounding_sphere) z_vals = (z_vals, z_vals_inverse_sphere) return z_vals, z_samples_eik def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star): density = model.density(sdf.reshape(z_vals.shape), beta=beta) shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=dists.device), dists * density[:, :-1]], dim=-1) integral_estimation = torch.cumsum(shifted_free_energy, dim=-1) error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2) error_integral = torch.cumsum(error_per_section, dim=-1) bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1]) return bound_opacity.max(-1)[0] ================================================ FILE: model/rendering/__init__.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import utils import cv2 from .brdf import * class RenderingLayer(nn.Module): def __init__(self, spp, split_n_pixels, preserve_light=True) -> None: super().__init__() self.spp = spp self.split_n_pixels = split_n_pixels self.preserve_light = preserve_light def forward( self, model, surface_points, view_direction, Kd, Ks, normal, rough, radiance_scale=None, intersect_func=None ): """ Render according to material, normal and lighting conditions Params: model: NeRF model to predict radiance surface_points, view_direction, albedo, normal, rough, metal: (bn, c) """ bn = normal.size(0) cx, cy, cz = create_frame(normal) wi_x = torch.sum(cx*view_direction, dim=1) wi_y = torch.sum(cy*view_direction, dim=1) wi_z = torch.sum(cz*view_direction, dim=1) wi = torch.stack([wi_x, wi_y, wi_z], dim=1) wi_mask = (wi[:,2] >= 0.00001) wi[:,2,...] = torch.where(wi[:,2,...] < 0.00001, torch.ones_like(wi[:,2,...]) * 0.00001, wi[:,2,...]) wi = F.normalize(wi, dim=1, eps=1e-6) # wi_mask = torch.where(wi[:,2:3,...] 
< 0, torch.zeros_like(wi[:,2:3,...]), torch.ones_like(wi[:,2:3,...])) wi = wi.unsqueeze(1) # (bn, 1, 3) # with torch.no_grad(): if True: samples = torch.rand(bn, self.spp, 3, device=normal.device) pS = probabilityToSampleSpecular(Kd, Ks) clamp_value = 0.0 pS.clamp_min_(clamp_value) sample_diffuse = samples[:,:,0] >= pS ls_diffuse = square_to_cosine_hemisphere(samples[:,:,1:]) ls_specular = sample_ggx_specular(samples[:,:,1:], rough, wi) wo = torch.where(sample_diffuse.unsqueeze(2).expand(bn, self.spp, 3), ls_diffuse, ls_specular) # (bn, spp, 3) pdfs = pdf_ggx(Kd, Ks, rough, wi, wo, clamp_value).unsqueeze(2) eval_diff, eval_spec, wo_mask = eval_ggx(Kd, Ks, rough, wi, wo) # wo_mask = torch.all(wo_mask, dim=1) direction = to_global(wo, cx.unsqueeze(1), cy.unsqueeze(1), cz.unsqueeze(1)) # surface_points = surface_points + 0.01 * view_direction # prevent self-intersection surface_points = surface_points.unsqueeze(1).expand_as(direction).reshape(-1, 3) direction = direction.reshape(-1, 3) surface_points = surface_points + direction * 0.01 # prevent self-intersection pts_splits = torch.split(surface_points, self.split_n_pixels, dim=0) dirs_splits = torch.split(direction, self.split_n_pixels, dim=0) radiance = [] # with torch.no_grad(): for pts, dirs in zip(pts_splits, dirs_splits): radiance.append(model.get_incident_radiance(pts, dirs, intersect_func)) radiance = torch.cat(radiance, dim=0) radiance = radiance.view(bn, self.spp, 3) if radiance_scale is not None: radiance = radiance * radiance_scale[None,None,:] pdfs = torch.clamp(pdfs, min=0.00001) ndl = torch.clamp(wo[:,:,2:], min=0) brdfDiffuse = eval_diff.expand(bn, self.spp, 3) * ndl / pdfs colorDiffuse = torch.mean(brdfDiffuse * radiance, dim=1) brdfSpec = eval_spec.expand(bn, self.spp, 3) * ndl / pdfs colorSpec = torch.mean(brdfSpec * radiance, dim=1) return colorDiffuse, colorSpec, wi_mask ================================================ FILE: model/rendering/brdf.py ================================================ import torch import torch.nn.functional as F import numpy as np def create_frame(n: torch.Tensor, eps:float = 1e-6): """ Generate orthonormal coordinate system based on surface normal [Duff et al. 17] Building An Orthonormal Basis, Revisited. JCGT. 2017. :param: n (bn, 3, ...) """ z = F.normalize(n, dim=1, eps=eps) sgn = torch.where(z[:,2,...] >= 0, 1.0, -1.0) a = -1.0 / (sgn + z[:,2,...]) b = z[:,0,...] * z[:,1,...] * a x = torch.stack([1.0 + sgn * z[:,0,...] * z[:,0,...] * a, sgn * b, -sgn * z[:,0,...]], dim=1) y = torch.stack([b, sgn + z[:,1,...] * z[:,1,...] * a, -z[:,1,...]], dim=1) return x, y, z def get_rendering_parameters(albedo_raw, rough_raw, use_metallic): if use_metallic: assert albedo_raw.size(-1) == 3 and rough_raw.size(-1) == 2 metal = rough_raw[:,1:] rough = rough_raw[:,:1].clamp_min(0.01) Ks = baseColorToSpecularF0(albedo_raw, metal) Kd = albedo_raw * (1 - metal) else: assert albedo_raw.size(-1) == 6 and rough_raw.size(-1) == 1 Kd = albedo_raw[:,:3] Ks = albedo_raw[:,3:].clamp_min(0.04) rough = rough_raw.clamp_min(0.01) return Kd, Ks, rough def to_global(d, x, y, z): """ d, x, y, z: (*, 3) """ return d[...,0:1] * x + d[...,1:2] * y + d[...,2:3] * z def sqrt_(x: torch.Tensor, eps=1e-8) -> torch.Tensor: """ clamping 0 values of sqrt input to avoid NAN gradients """ return torch.sqrt(torch.clamp(x, min=eps)) def reflect(v: torch.Tensor, h: torch.Tensor): dot = torch.sum(v*h, dim=2, keepdim=True) return 2 * dot * h - v def square_to_cosine_hemisphere(sample: torch.Tensor): u, v = sample[:,:,0,...], sample[:,:,1,...] 
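# Cosine-weighted hemisphere sampling from two uniform variates (u, v):
#     phi = 2*pi*u,   r = sqrt(v),   z = cos(theta) = sqrt(1 - v)
# which yields pdf(omega) = cos(theta) / pi in the local frame (z = surface normal),
# matching the cosine term of the diffuse lobe this routine importance-samples.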
phi = u * 2 * np.pi r = sqrt_(v) cos_theta = sqrt_(torch.clamp(1 - v, 0)) return torch.stack([torch.cos(phi) * r, torch.sin(phi) * r, cos_theta], dim=2) def get_cos_theta(v: torch.Tensor): return v[:,:,2,...] def get_phi(v: torch.Tensor): cos_theta = torch.clamp(v[:,:,2,...], min=0, max=1) sin_theta = torch.clamp(sqrt_(1 - cos_theta*cos_theta), min=1e-8) cos_phi = torch.clamp(v[:,:,0,...] / sin_theta, -1, 1) sin_phi = v[:,:,1,...] / sin_theta phi = torch.acos(cos_phi) # (0, pi) return torch.where(sin_phi > 0, phi, 2*np.pi - phi) def sample_disney_specular(sample: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor): """ :param: sample (bn, spp, 3, h, w) :param: roughness (bn, 1, 1, h, w) :param: wi (*, *, 3, h, w), supposed to be normalized :return: wo (bn, spp, 3, h, w), phi (bn, spp, h, w), cos theta (bn, spp, h, w) """ # a = torch.clamp(roughness, 0.001) a = roughness u, v = sample[:,:,0,...], sample[:,:,1,...] phi = u * 2 * np.pi cos_theta = sqrt_((1 - v) / (1 + (a*a - 1) * v)) sin_theta = sqrt_(1 - cos_theta*cos_theta) cos_phi = torch.cos(phi) sin_phi = torch.sin(phi) half = torch.stack([sin_theta*cos_phi, sin_theta*sin_phi, cos_theta], dim=2) wo = F.normalize(reflect(wi.expand_as(half), half), dim=2, eps=1e-8) return wo #, phi.squeeze(2), cos_theta.squeeze(2) def GTR2(ndh, a): a2 = a*a t = 1.0 + (a2 - 1.0) * ndh * ndh return a2 / (np.pi * t * t) def SchlickFresnel(u): m = torch.clamp(1.0 - u, 0, 1) return m**5 def smithG_GGX(ndv, a): a = a*a b = ndv*ndv return 1.0 / (ndv + sqrt_(a + b - a * b)) def pdf_disney(roughness: torch.Tensor, metallic: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor): """ :param: roughness/metallic (bn, 1, h, w) :param: wi (*, *, 3, h, w), supposed to be normalized :param: wo (*, *, 3, h, w), supposed to be normalized """ # specularAlpha = torch.clamp(roughness, 0.001) specularAlpha = roughness diffuseRatio = 0.5 * (1 - metallic) specularRatio = 1 - diffuseRatio half = F.normalize(wi + wo, dim=2, eps=1e-8) cosTheta = torch.abs(half[:,:,2,...]) pdfGTR2 = GTR2(cosTheta, specularAlpha) * cosTheta pdfSpec = pdfGTR2 / torch.clamp(4.0 * torch.abs(torch.sum(wo*half, dim=2)), min=1e-8) pdfDiff = torch.abs(wo[:,:,2,...]) / np.pi pdf = diffuseRatio * pdfDiff + specularRatio * pdfSpec pdf = torch.where(wi[:,:,2,...] < 0.0001, torch.ones_like(pdf) * 0.0001, pdf) pdf = torch.where(wo[:,:,2,...] < 0.0001, torch.ones_like(pdf) * 0.0001, pdf) return pdf def eval_disney(albedo: torch.Tensor, roughness: torch.Tensor, metallic: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor): """ :param: albedo/roughness/metallic (bn, c, h, w) :param: wi (*, *, 3, h, w), supposed to be normalized :param: wo (*, *, 3, h, w), supposed to be normalized """ h = wi + wo; h = F.normalize(h, dim=2, eps=1e-8) CSpec0 = torch.lerp(torch.ones_like(albedo)*0.04, albedo, metallic).unsqueeze(1) ldh = torch.clamp(torch.sum( (wo * h), dim = 2), 0, 1).unsqueeze(2) ndv = wi[:,:,2:3,...] ndl = wo[:,:,2:3,...] ndh = h[:,:,2:3,...] 
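# wi and wo are expressed in the local shading frame built by create_frame(), where the
# surface normal is the +z axis, so n.v, n.l and n.h reduce to the z components read off
# above; no explicit dot products with the normal are needed.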
FL, FV = SchlickFresnel(ndl), SchlickFresnel(ndv) roughness = roughness.unsqueeze(1) Fd90 = 0.5 + 2.0 * ldh * ldh * roughness Fd = torch.lerp(torch.ones_like(Fd90), Fd90, FL) * torch.lerp(torch.ones_like(Fd90), Fd90, FV) Ds = GTR2(ndh, roughness) FH = SchlickFresnel(ldh) Fs = torch.lerp(CSpec0, torch.ones_like(CSpec0), FH) roughg = (roughness * 0.5 + 0.5) ** 2 Gs1, Gs2 = smithG_GGX(ndl, roughg), smithG_GGX(ndv, roughg) Gs = Gs1 * Gs2 eval_diff = Fd * albedo.unsqueeze(1) * (1.0 - metallic.unsqueeze(1)) / np.pi eval_spec = Gs * Fs * Ds mask = torch.where(ndl < 0, torch.zeros_like(ndl), torch.ones(ndl)) return eval_diff, eval_spec, mask def F_Schlick(SpecularColor, VoH): Fc = (1 - VoH)**5 return torch.clamp(50.0 * SpecularColor[:,:,1:2,...], min=0, max=1) * Fc + (1 - Fc) * SpecularColor def GetSpecularEventProbability(SpecularColor, NoV) -> torch.Tensor: f = F_Schlick(SpecularColor, NoV); return (f[:,:,0,...] + f[:,:,1,...] + f[:,:,2,...]) / 3 def baseColorToSpecularF0(baseColor, metalness): return torch.lerp(torch.empty_like(baseColor).fill_(0.04), baseColor, metalness) def luminance(color): if color.size(1) == 1: return color # return color.mean(dim=1, keepdim=True) return color[:,0:1,...] * 0.212671 + color[:,1:2,...] * 0.715160 + color[:,2:3,...] * 0.072169 def probabilityToSampleSpecular(difColor, specColor) -> torch.Tensor: lumDiffuse = torch.clamp(luminance(difColor), min=0.01) lumSpecular = torch.clamp(luminance(specColor), min=0.01) return lumSpecular / (lumDiffuse + lumSpecular) def shadowedF90(F0): t = 1 / 0.04 return torch.clamp(t * luminance(F0), max=1) def evalFresnel(f0, f90, NdotS): # print(f0.shape, f90.shape, NdotS.shape) return f0 + (f90 - f0) * (1 - NdotS)**5 def Smith_G1_GGX(alphaSquared, NdotSSquared): return 2 / (sqrt_(((alphaSquared * (1 - NdotSSquared)) + NdotSSquared) / NdotSSquared) + 1) def Smith_G2_GGX(alphaSquared, NdotL, NdotV): a = NdotV * sqrt_(alphaSquared + NdotL * (NdotL - alphaSquared * NdotL)) b = NdotL * sqrt_(alphaSquared + NdotV * (NdotV - alphaSquared * NdotV)) return 0.5 / (a + b) def GGX_D(alphaSquared, NdotH): b = ((alphaSquared - 1) * NdotH * NdotH + 1) return alphaSquared / (np.pi * b * b) def pdf_ggx(Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor, ps_min=0.0): """ :param: color (bn, 3, h, w) :param: roughness/metallic (bn, 1, h, w) :param: wi (*, *, 3, h, w), supposed to be normalized :param: wo (*, *, 3, h, w), supposed to be normalized :return: pdf (*, *, h, w) """ alpha = roughness * roughness alphaSquared = alpha * alpha NdotV = wi[:,:,2,...] h = F.normalize(wi + wo, dim=2, eps=1e-8) NdotH = h[:,:,2,...] # print(alphaSquared.min(), NdotH.min(), NdotV.min()) ggxd = GGX_D(torch.clamp(alphaSquared, min=0.00001), NdotH) smith = Smith_G1_GGX(alphaSquared, NdotV * NdotV) # pdf_spec = GGX_D(torch.clamp(alphaSquared, min=0.00001), NdotH) * Smith_G1_GGX(alphaSquared, NdotV * NdotV) / (4 * NdotV) pdf_spec = ggxd * smith / (4 * NdotV) # print(torch.any(torch.isnan(ggxd)), torch.any(torch.isnan(smith)), torch.any(torch.isnan(NdotV))) # print(NdotV.min(), ggxd.min(), smith.min()) with torch.no_grad(): pS = probabilityToSampleSpecular(Kd, Ks).clamp_min(ps_min) pdf_diff = wo[:,:,2,...] 
/ np.pi # print("#########################################") # print("#########################################") # print("#########################################") # print(torch.any(torch.isnan(kS)), torch.any(torch.isnan(pdf_spec)), torch.any(torch.isnan(pdf_diff))) # print("#########################################") # print("#########################################") # print("#########################################") pdf = pS * pdf_spec + (1 - pS) * pdf_diff pdf = torch.where(wi[:,:,2,...] <= 0.0001, torch.ones_like(pdf) * 0.0001, pdf) pdf = torch.where(wo[:,:,2,...] <= 0.0001, torch.ones_like(pdf) * 0.0001, pdf) return pdf def eval_ggx(Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor): """ :param: color (bn, c, h, w) :param: roughness/metallic (bn, 1, h, w) :param: wi (*, *, c, h, w), supposed to be normalized :param: wo (*, *, c, h, w), supposed to be normalized :return: fr(wi, wo) (*, *, c, h, w) """ NDotL = wo[:,:,2:3,...] NDotV = wi[:,:,2:3,...] H = F.normalize(wi + wo, dim=2, eps=1e-8) NDotH = H[:,:,2:3,...] LDotH = torch.sum(wo*H, dim=2, keepdim=True) roughness = roughness.unsqueeze(1) alpha = roughness * roughness alpha2 = alpha * alpha D = GGX_D(torch.clamp(alpha2, min=0.00001), NDotH) G2 = Smith_G2_GGX(alpha2, NDotL, NDotV) f = evalFresnel(Ks.unsqueeze(1), shadowedF90(Ks).unsqueeze(1), LDotH) # spec = torch.where(NDotL <= 0, torch.zeros_like(NDotL), f * G2 * D) # mask = torch.where(NDotL <= 0, torch.zeros_like(NDotL), torch.ones_like(NDotL)) spec = torch.where(NDotL < 0.0001, torch.ones_like(NDotL) * 0.0001, f * G2 * D) # mask = torch.where(NDotL <= 0, torch.zeros_like(NDotL), torch.ones_like(NDotL)) mask = (NDotL >= 0.0001).squeeze(-1) return Kd.unsqueeze(1) / np.pi, spec, mask def sample_weight_ggx(alphaSquared, NdotL, NdotV): G1V = Smith_G1_GGX(alphaSquared, NdotV*NdotV) G1L = Smith_G1_GGX(alphaSquared, NdotL*NdotL) return G1L / (G1V + G1L - G1V * G1L) def sample_ggx(sample: torch.Tensor, Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor): """ :param: sample (bn, spp, 3, h, w) :param: roughness (bn, 1, h, w) :param: wi (*, *, 3, h, w), supposed to be normalized :return: wo (bn, spp, 3, h, w), weight (bn, spp, 3, h, w) """ with torch.no_grad(): pS = probabilityToSampleSpecular(Kd, Ks) sample_diffuse = sample[:,:,2,...] >= pS wo_diff = square_to_cosine_hemisphere(sample[:,:,1:,...]) weight_diff = Kd / (1 - pS) weight_diff = weight_diff.unsqueeze(1) roughness = roughness.unsqueeze(1) alpha = roughness * roughness # alpha = roughness Vh = F.normalize(torch.cat([alpha * wi[:,:,0:1,...], alpha * wi[:,:,1:2,...], wi[:,:,2:3,...]], dim=2), dim=2, eps=1e-8) lensq = Vh[:,:,0:1,...]**2 + Vh[:,:,1:2,...]**2 zero_ = torch.zeros_like(Vh[:,:,0,...]) one_ = torch.ones_like(Vh[:,:,0,...]) T1 = torch.where( lensq > 0, torch.stack([-Vh[:,:,1,...], Vh[:,:,0,...], zero_], dim=2) / sqrt_(lensq), torch.stack([one_, zero_, zero_], dim=2) ) T2 = torch.cross(Vh, T1, dim=2) r = sqrt_(sample[:,:,0:1,...]) phi = 2 * np.pi * sample[:,:,1:2,...] 
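# The specular branch samples the GGX distribution of visible normals (VNDF): the view
# direction is "stretched" by alpha into Vh, an orthonormal basis (T1, T2, Vh) is built,
# a point is drawn on the projected disk (r and phi above, t1 and t2 below), lifted back
# to the hemisphere as Nh and "unstretched" into the half-vector h; wo is the mirror
# reflection of wi about h. The returned weight folds Fresnel, the Smith G-term ratio and
# the 1/pS lobe-selection probability together, so a caller could estimate outgoing
# radiance roughly as follows (sketch only; incident_radiance is hypothetical):
#   wo, w = sample_ggx(u, Kd, Ks, roughness, wi)   # u ~ U[0,1)^3 per pixel
#   Lo_est = w * incident_radiance(wo)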
t1 = r * torch.cos(phi) t2 = r * torch.sin(phi) s = 0.5 * (1 + Vh[:,:,2:3,...]) t2 = torch.lerp(sqrt_(1 - t1**2), t2, s) Nh = t1 * T1 + t2 * T2 + sqrt_(torch.clamp(1 - t1*t1 - t2*t2, min=0)) * Vh h = F.normalize(torch.cat([alpha * Nh[:,:,0:1,...], alpha * Nh[:,:,1:2,...], torch.clamp(Nh[:,:,2:3,...], min=0)], dim=2), dim=2, eps=1e-8) wo = reflect(wi, h) HdotL = torch.clamp(torch.sum(h*wo, dim=2, keepdim=True), min=0.0001, max=1.0) NdotL = torch.clamp(wo[:,:,2:3,...], min=0.0001, max=1.0) NdotV = torch.clamp(wi[:,:,2:3,...], min=0.0001, max=1.0) # NdotH = torch.clamp(h[:,:,2:3,...], min=0.00001, max=1.0) # F = evalFresnel(specularF0, shadowedF90(specularF0), HdotL) weight = evalFresnel(Ks, shadowedF90(Ks), HdotL) * sample_weight_ggx(alpha*alpha, NdotL, NdotV) / pS.unsqueeze(1) wo = torch.where(sample_diffuse.unsqueeze(2), wo_diff, wo) weight = torch.where(sample_diffuse.unsqueeze(2), weight_diff, weight) return wo, weight def sample_ggx_specular(sample: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor): """ :param: sample (bn, spp, 2, h, w) :param: roughness (bn, 1, h, w) :param: wi (*, *, 3, h, w), supposed to be normalized :return: wo (bn, spp, 3, h, w), phi (bn, spp, h, w), cos theta (bn, spp, h, w) """ roughness = roughness.unsqueeze(1) alpha = roughness * roughness # alpha = roughness Vh = F.normalize(torch.cat([alpha * wi[:,:,0:1,...], alpha * wi[:,:,1:2,...], wi[:,:,2:3,...]], dim=2), dim=2, eps=1e-8) # bn, spp, _, row, col = Vh.shape # Vh = Vh.view(-1, 3, row, col) # T1, T2, Vh = utils.hughes_moeller(Vh) # T1 = T1.view(bn, spp, 3, row, col) # T2 = T2.view(bn, spp, 3, row, col) # Vh = Vh.view(bn, spp, 3, row, col) lensq = Vh[:,:,0:1,...]**2 + Vh[:,:,1:2,...]**2 zero_ = torch.zeros_like(Vh[:,:,0,...]) one_ = torch.ones_like(Vh[:,:,0,...]) T1 = torch.where( lensq > 0, torch.stack([-Vh[:,:,1,...], Vh[:,:,0,...], zero_], dim=2) / sqrt_(lensq), torch.stack([one_, zero_, zero_], dim=2) ) T2 = torch.cross(Vh, T1, dim=2) r = sqrt_(sample[:,:,0:1,...]) phi = 2 * np.pi * sample[:,:,1:2,...] 
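# Same VNDF construction as in sample_ggx above, but restricted to the specular lobe:
# only the reflected direction wo is returned, and the corresponding density/weight is
# left to pdf_ggx / eval_ggx so the caller can combine it with an explicit pdf if needed.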
t1 = r * torch.cos(phi) t2 = r * torch.sin(phi) s = 0.5 * (1 + Vh[:,:,2:3,...]) t2 = torch.lerp(sqrt_(1 - t1**2), t2, s) Nh = t1 * T1 + t2 * T2 + sqrt_(torch.clamp(1 - t1*t1 - t2*t2, min=0)) * Vh h = F.normalize(torch.cat([alpha * Nh[:,:,0:1,...], alpha * Nh[:,:,1:2,...], torch.clamp(Nh[:,:,2:3,...], min=0)], dim=2), dim=2, eps=1e-8) wo = reflect(wi, h) return wo ================================================ FILE: model/trainer/__init__.py ================================================ from .recon import ReconstructionTrainer ================================================ FILE: model/trainer/recon.py ================================================ import math import torch import pytorch_lightning as pl import numpy as np import torch.optim as optim import os from torch.utils.data import DataLoader import utils from utils import rend_util import utils.plots as plt import dataset import model from tqdm import trange from pytorch_lightning.callbacks import RichProgressBar from torchmetrics.functional import structural_similarity_index_measure as ssim from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" import cv2 lpips = LPIPS() class ReconstructionTrainer(pl.LightningModule): def __init__(self, conf, prog_bar: RichProgressBar, exp_dir, model_only=False, val_mesh=False, is_val=False, **kwargs) -> None: super().__init__() self.conf = conf self.prog_bar = prog_bar self.batch_size = conf.train.batch_size self.bubble_batch_size = getattr(conf.train, 'bubble_batch_size', self.batch_size) self.expdir = exp_dir self.val_mesh = val_mesh conf_model = conf.model use_normal = (getattr(conf.loss, 'normal_weight', 0) > 0) or (getattr(conf.loss, 'angular_weight', 0) > 0) conf_model.use_normal = use_normal self.model = model.I2SDFNetwork(conf_model) if model_only: return print('[INFO] Loading data ...') dataset_conf = conf.dataset self.scan_id = dataset_conf.scan_id self.train_dataset = dataset.ReconDataset( **dataset_conf, use_mask=getattr(conf.loss, 'mask_weight', 0) > 0, use_depth=getattr(conf.loss, 'depth_weight', 0) > 0, use_normal=use_normal, use_bubble=getattr(conf.loss, 'bubble_weight', 0) > 0, use_lightmask=getattr(conf.loss, 'light_mask_weight', 0) > 0 ) if self.train_dataset.use_bubble: os.makedirs(os.path.join(self.expdir, 'hotmap'), exist_ok=True) os.makedirs(os.path.join(self.expdir, 'countmap'), exist_ok=True) self.pdf_criterion = getattr(conf.train, 'pdf_criterion', 'DEPTH') assert self.pdf_criterion in ['RGB', 'DEPTH'] self.is_hdr = self.train_dataset.is_hdr self.plots_dir = os.path.join(self.expdir, 'plots') self.trace_bub_idx = self.conf.train.get('trace_bub_idx', -1) if self.trace_bub_idx != -1: os.makedirs(f"{self.plots_dir}/bubble", exist_ok=True) print(f"[INFO] Activate hotmap visualization for #{self.trace_bub_idx}") self.plot_dataset = dataset.PlotDataset(**dataset_conf, indices=[self.trace_bub_idx], plot_nimgs=1, is_val=is_val) else: data = { 'intrinsics': self.train_dataset.intrinsics_all, 'pose': self.train_dataset.pose_all, 'rgb': self.train_dataset.rgb_images, 'img_res': self.train_dataset.img_res } if self.train_dataset.use_lightmask: data['light_mask'] = self.train_dataset.lightmask_images self.plot_dataset = dataset.PlotDataset(**dataset_conf, data=data, plot_nimgs=conf.plot.plot_nimgs, is_val=is_val) os.makedirs(self.plots_dir, exist_ok=True) with open(f"{self.expdir}/config.yml", 'w') as f: f.write(self.conf.dump()) if self.train_dataset.use_bubble: points = self.train_dataset.pointcloud index 
= torch.randperm(points.size(0))[:200000] points = points[index,:] plt.visualize_pointcloud(points, f"{self.expdir}/pointcloud.html") print(f"[INFO] Pointcloud visualization success: saved to {self.expdir}/pointcloud.html") self.pdf_prune = self.train_dataset.pdf_prune self.pdf_max = self.train_dataset.pdf_max # self.ds_len = len(self.train_dataset) self.ds_len = self.train_dataset.n_images print('[INFO] Finish loading data. Data-set size: {0}'.format(self.ds_len)) epoch_steps = len(self.train_dataset) / self.batch_size self.nepochs = int(math.ceil(200000 / epoch_steps)) self.loss = model.I2SDFLoss(**conf.loss) self.total_pixels = self.plot_dataset.total_pixels self.img_res = self.plot_dataset.img_res self.bubble_activated = False self.uniform_bubble = getattr(self.conf.train, 'uniform_bubble', False) if self.uniform_bubble: print("[INFO] Ablation study: uniform sampling for bubble loss") self.checkpoint_freq = self.conf.train.checkpoint_freq self.split_n_pixels = self.conf.train.split_n_pixels self.plot_conf = self.conf.plot self.progbar_task = None if self.train_dataset.use_lightmask and getattr(self.conf.train, 'flip_light', False): self.train_dataset.lightmask_images = 1.0 - self.train_dataset.lightmask_images self.plot_dataset.lightmask_images = 1.0 - self.plot_dataset.lightmask_images def forward(self): raise NotImplementedError("forward not supported by trainer") def plot_hotmap(self, path): assert self.bubble_activated ds = self.train_dataset hotmaps = torch.zeros(self.ds_len * ds.total_pixels) hotmaps[ds.pixlinks] = self.pdf.cpu() hotmaps = hotmaps.reshape(self.ds_len, *ds.img_res) for i, hotmap in enumerate(hotmaps): hotmap = hotmap.numpy() # hotmap /= max(1e-4, hotmap.max()) hotmap = (hotmap * 255).astype(np.uint8) hotmap = cv2.applyColorMap(hotmap, cv2.COLORMAP_MAGMA) cv2.imwrite(os.path.join(path, "{:04d}.png".format(i)), hotmap) if self.trace_bub_idx == i: cv2.imwrite(os.path.join(f"{self.plots_dir}/bubble", f"{self.global_step}_hot.png"), hotmap) def plot_countmap(self, path): assert self.bubble_activated ds = self.train_dataset countmaps = torch.zeros(self.ds_len * ds.total_pixels) countmaps[ds.pixlinks] = self.sample_count.cpu().float() countmaps = countmaps.reshape(self.ds_len, *ds.img_res) countmaps = countmaps / max(1, countmaps.max()) for i, countmap in enumerate(countmaps): countmap = countmap.numpy() countmap = (countmap * 255).astype(np.uint8) countmap = cv2.applyColorMap(countmap, cv2.COLORMAP_MAGMA) cv2.imwrite(os.path.join(path, "{:04d}.png".format(i)), countmap) if self.trace_bub_idx == i: cv2.imwrite(os.path.join(f"{self.plots_dir}/bubble", f"{self.global_step}_cnt.png"), countmap) def update_pdf(self, value, idx): assert self.bubble_activated ds = self.train_dataset value = value.to(self.pdf.device) if self.pdf_max is not None: value = value.clamp(max=self.pdf_max) value[value < self.pdf_prune] = 0 # PDF pruning link = ds.pointlinks[idx] mask = (link != -1) link = link[mask] value = value[mask] self.pdf[link] = value def sample_bubble(self, batch_size): assert self.bubble_activated ds = self.train_dataset if self.uniform_bubble: sample_idx = torch.randperm(ds.pointcloud.size(0), device=ds.pointcloud.device)[:batch_size] return ds.pointcloud[sample_idx,:] sample_idx = torch.where(self.pdf > 0)[0] pdf_samples = self.pdf[sample_idx] pointcloud_samples = ds.pointcloud[sample_idx,:] if sample_idx.size(0) >= (1 << 24): # print(sample_idx.size(0), self.pdf.size(0), (1 << 24)) print("[ERROR] PDF capacity exceeds maximum limit of PyTorch") exit(1) idx = 
torch.multinomial(pdf_samples, batch_size, replacement=False) # importance sampling self.sample_count[sample_idx[idx]] += 1 return pointcloud_samples[idx,:] def initialize_bubble_pdf(self, split_size): ds = self.train_dataset # ds.pdf = ds.pdf.cuda() self.register_buffer('pdf', torch.zeros(len(ds.pointcloud)), False) self.register_buffer('sample_count', torch.zeros(len(ds.pointcloud)), False) self.pdf = self.pdf.cuda() # self.sample_count = self.sample_count.cuda() for i in trange(ds.n_images): intrinsics = ds.intrinsics_all[i].cuda().unsqueeze(0) pose = ds.pose_all[i].cuda().unsqueeze(0) img = ds.rgb_images[i].cuda() if self.pdf_criterion != 'DEPTH' else ds.depth_images[i].cuda() uv = ds.uv.cuda().unsqueeze(1) # (h*w, 1, 2) img_splits = torch.split(img, split_size) uv_splits = torch.split(uv, split_size) indices = torch.arange(i * ds.total_pixels, (i + 1) * ds.total_pixels, dtype=torch.long, device='cuda') index_splits = torch.split(indices, split_size) for img_split, uv_split, index_split in zip(img_splits, uv_splits, index_splits): data = { 'uv': uv_split, 'intrinsics': intrinsics.repeat(len(uv_split), 1, 1), 'pose': pose.repeat(len(uv_split), 1, 1) } model_output = self.model.forward(data, True) if self.pdf_criterion == 'RGB': self.update_pdf((model_output['rgb_values'].detach().clamp(0, 1) - img_split.clamp(0, 1)).abs().mean(dim=-1), index_split) # elif self.pdf_criterion == 'DEPTH': else: self.update_pdf((model_output['depth_values'].detach() - img_split).abs(), index_split) def configure_optimizers(self): lr = self.conf.train.learning_rate optimizer = optim.Adam(self.model.get_param_groups(lr), eps=1e-15) decay_rate = getattr(self.conf.train, 'sched_decay_rate', 0.1) decay_steps = self.nepochs * self.ds_len scheduler = optim.lr_scheduler.ExponentialLR(optimizer, decay_rate ** (1./decay_steps)) return [optimizer], [scheduler] def train_dataloader(self): return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.train_dataset.collate_fn, num_workers=4) def val_dataloader(self): return DataLoader(self.plot_dataset, batch_size=self.conf.plot.plot_nimgs, shuffle=False, collate_fn=self.train_dataset.collate_fn) def log_if_nonzero(self, name, value, *args, **kwargs): if value > 0: self.log(name, value, *args, **kwargs) def training_step(self, batch, batch_idx): indices, img_indices, model_input, ground_truth = batch if not self.bubble_activated and self.train_dataset.use_bubble and self.global_step >= self.loss.min_bubble_iter and self.global_step < self.loss.max_bubble_iter: # Start bubble step with torch.no_grad(): self.bubble_activated = True self.train_dataset.pointcloud = self.train_dataset.pointcloud.cuda() # self.loss.eikonal_weight = self.loss.eikonal_weight_pointcloud # Disable normal loss, since it will discourage the growth of bubbles self.loss.normal_weight_bak = self.loss.normal_weight self.loss.normal_weight = 0.0 self.loss.angular_weight_bak = self.loss.angular_weight self.loss.angular_weight = 0.0 if not self.uniform_bubble: print(f"[INFO] Start to initializing pointcloud PDF, criterion: {self.pdf_criterion}") self.initialize_bubble_pdf(self.split_n_pixels) # initialize PDF maps for each image by computing losses torch.save(self.pdf, os.path.join(self.expdir, 'checkpoints', "pdf.pt")) torch.cuda.empty_cache() self.plot_hotmap(os.path.join(self.expdir, 'hotmap')) print("[INFO] Finish to initializing pointcloud PDF") print(f"[INFO] {torch.count_nonzero(self.pdf).item()}/{self.pdf.size(0)} points to be sampled") if self.bubble_activated: 
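# While the bubble phase is active, every training step additionally draws a batch of 3D
# points from the scene point cloud via sample_bubble(): points are importance-sampled in
# proportion to the per-pixel error PDF built by initialize_bubble_pdf() (or uniformly in
# the ablation) and passed to the model as model_input['pointcloud'] for the bubble loss.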
model_input['pointcloud'] = self.sample_bubble(self.bubble_batch_size) model_outputs = self.model(model_input) if self.bubble_activated and not self.uniform_bubble: with torch.no_grad(): if self.pdf_criterion == 'RGB': self.update_pdf((model_outputs['rgb_values'].detach().clamp(0, 1) - ground_truth['rgb'].clamp(0, 1)).abs().mean(dim=-1), indices) # elif self.pdf_criterion == 'DEPTH': else: self.update_pdf((model_outputs['depth_values'].detach() - ground_truth['depth']).abs(), indices) loss_output = self.loss(model_outputs, ground_truth, self.global_step) if self.bubble_activated and self.loss.max_bubble_iter is not None and self.global_step >= self.loss.max_bubble_iter: # End bubble step self.train_dataset.use_bubble = False self.bubble_activated = False del self.train_dataset.pointcloud del self.train_dataset.pointlinks del self.train_dataset.pixlinks if not self.uniform_bubble: delattr(self, 'pdf') delattr(self, 'sample_count') torch.cuda.empty_cache() # self.loss.eikonal_weight = self.conf.loss.eikonal_weight # Restore normal loss self.loss.normal_weight = self.loss.normal_weight_bak self.loss.angular_weight = self.loss.angular_weight_bak loss = loss_output['loss'] with torch.no_grad(): psnr = rend_util.get_psnr(model_outputs['rgb_values'].detach(), ground_truth['rgb'].view(-1, 3)) self.log('train/loss', loss.item()) self.log('train/psnr', psnr.item(), True) self.log('train/rgb_loss', loss_output['rgb_loss'].item()) self.log_if_nonzero('train/eikonal_loss', loss_output['eikonal_loss'].item()) self.log_if_nonzero('train/smooth_loss', loss_output['smooth_loss'].item()) self.log_if_nonzero('train/mask_loss', loss_output['mask_loss'].item()) self.log_if_nonzero('train/depth_loss', loss_output['depth_loss'].item()) self.log_if_nonzero('train/normal_loss', loss_output['normal_loss'].item()) self.log_if_nonzero('train/angular_loss', loss_output['angular_loss'].item()) self.log_if_nonzero('train/bubble_loss', loss_output['bubble_loss'].item()) self.log_if_nonzero('train/light_mask_loss', loss_output['light_mask_loss'].item()) self.log('train/beta', self.model.density.beta.item()) return loss def validation_step(self, batch, batch_idx): indices, model_input, ground_truth = batch split = utils.split_input(model_input, self.total_pixels, self.split_n_pixels) res = [] if self.progbar_task is None and self.prog_bar.progress: self.progbar_task = self.prog_bar.progress.add_task("[cyan]Validation split", total=len(split)) elif self.progbar_task: self.prog_bar.progress.reset(self.progbar_task, total=len(split), visible=True) for s in split: out = utils.detach_dict(self.model(s)) d = { 'rgb_values': out['rgb_values'].detach(), 'depth_values': out['depth_values'].detach() } if 'normal_map' in out: d['normal_map'] = out['normal_map'].detach() if 'light_mask' in out: d['light_mask'] = out['light_mask'].detach() del out res.append(d) if self.progbar_task: self.prog_bar.progress.update(self.progbar_task, advance=1, refresh=True) if self.progbar_task: self.prog_bar.progress.update(self.progbar_task, visible=False) batch_size = ground_truth['rgb'].shape[0] model_outputs = utils.merge_output(res, self.total_pixels, batch_size) def get_plot_data(model_outputs, pose, ground_truth): rgb_gt = ground_truth['rgb'] batch_size, num_samples, _ = rgb_gt.shape rgb_eval = model_outputs['rgb_values'].reshape(batch_size, num_samples, 3) if self.is_hdr: eval_hdr = rgb_eval gt_hdr = rgb_gt rgb_eval = rend_util.linear_to_srgb(rgb_eval.clamp(0, 1)) rgb_gt = rend_util.linear_to_srgb(rgb_gt.clamp(0, 1)) depth_eval = 
model_outputs['depth_values'].reshape(batch_size, num_samples, 1) plot_data = { 'rgb_gt': rgb_gt, 'pose': pose, 'rgb_eval': rgb_eval, 'depth_eval': depth_eval } if self.is_hdr: plot_data['hdr_gt'] = gt_hdr plot_data['hdr_eval'] = eval_hdr if 'normal_map' in model_outputs: normal_map = model_outputs['normal_map'].reshape(batch_size, num_samples, 3) normal_map = normal_map.transpose(1, 2) # (bn, 3, h*w) R = pose[:,:3,:3].transpose(1, 2) normal_map = torch.bmm(R, normal_map) # world to camera normal_map = normal_map.transpose(1, 2) normal_map = (normal_map + 1.) / 2. plot_data['normal_map'] = normal_map if 'light_mask' in model_outputs: plot_data['lmask_eval'] = model_outputs['light_mask'].reshape(batch_size, num_samples, 1) plot_data['lmask_gt'] = ground_truth['light_mask'].reshape(batch_size, num_samples, 1) return plot_data plot_data = get_plot_data(model_outputs, model_input['pose'], ground_truth) return { 'indices': indices, 'plot_data': plot_data } def validation_epoch_end(self, outputs) -> None: self.plot_dataset.shuffle_plot_index() indices = torch.cat([x['indices'] for x in outputs], dim=0) plot_data = utils.merge_dict([x['plot_data'] for x in outputs]) rgb_eval = plot_data['rgb_eval'] rgb_gt = plot_data['rgb_gt'] psnr = rend_util.get_psnr(rgb_eval, rgb_gt) self.log('val/psnr', psnr.item()) rgb_gt = rgb_gt.transpose(1, 2).view(-1, 3, *self.img_res) # (bn, h*w, 3) => (bn, 3, h, w) rgb_eval = rgb_eval.transpose(1, 2).view(-1, 3, *self.img_res) self.log('val/ssim', ssim(rgb_eval, rgb_gt).item()) lpips.to(rgb_eval.device) self.log('val/lpips', lpips(rgb_eval.clamp(0, 1) * 2 - 1, rgb_gt.clamp(0, 1) * 2 - 1).item()) os.makedirs(self.plots_dir, exist_ok=True) os.makedirs('{0}/rendering'.format(self.plots_dir), exist_ok=True) if self.is_hdr: os.makedirs('{0}/hdr'.format(self.plots_dir), exist_ok=True) os.makedirs('{0}/depth'.format(self.plots_dir), exist_ok=True) if 'normal_map' in plot_data: os.makedirs('{0}/normal'.format(self.plots_dir), exist_ok=True) if 'lmask_eval' in plot_data: os.makedirs('{0}/light_mask'.format(self.plots_dir), exist_ok=True) if self.val_mesh: os.makedirs('{0}/mesh'.format(self.plots_dir), exist_ok=True) if self.bubble_activated and not self.uniform_bubble: self.plot_hotmap(os.path.join(self.expdir, 'hotmap')) self.plot_countmap(os.path.join(self.expdir, 'countmap')) plt.plot(self.model.implicit_network, indices, plot_data, self.plots_dir, self.global_step, self.img_res, meshing=self.val_mesh, **self.plot_conf ) ================================================ FILE: utils/__init__.py ================================================ from .cfgnode import CfgNode from .rend_util import * import torch import torch.nn.functional as F import torch.nn as nn from glob import glob import os import numpy as np from pytorch_lightning.callbacks import RichProgressBar from rich.progress import TextColumn class RichProgressBarWithScanId(RichProgressBar): def __init__(self, scan_id, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.custom_column = TextColumn(f"[progress.description]scan_id: {scan_id}") def configure_columns(self, trainer): return super().configure_columns(trainer) + [self.custom_column] def glob_imgs(path): imgs = [] for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG', '*.exr']: imgs.extend(glob(os.path.join(path, ext))) return imgs def glob_depths(path): imgs = [] for ext in ['*.exr']: imgs.extend(glob(os.path.join(path, ext))) return imgs glob_normal = glob_depths def split_input(model_input, total_pixels, n_pixels=10000): ''' Split the input to fit 
Cuda memory for large resolution. Can decrease the value of n_pixels in case of cuda out of memory error. ''' split = [] for i, indx in enumerate(torch.split(torch.arange(total_pixels, device=model_input['uv'].device), n_pixels, dim=0)): data = model_input.copy() data['uv'] = torch.index_select(model_input['uv'], 1, indx) if 'object_mask' in data: data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx) split.append(data) return split def split_dict(d, batch_size=10000): keys = d.keys() splits = {} for k in d: splits[k] = torch.split(d[k], batch_size) n_splits = len(splits[k]) split_inputs = [] for i in range(n_splits): split = {} for k in d: split[k] = splits[k][i] split_inputs.append(split) return split_inputs def detach_dict(d): return {k: v.detach() for k, v in d.items() if torch.is_tensor(v)} def merge_output(res, total_pixels, batch_size): ''' Merge the split output. ''' model_outputs = {} for entry in res[0]: if res[0][entry] is None: continue if len(res[0][entry].shape) == 1: model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, 1) for r in res], 1).reshape(batch_size * total_pixels) else: model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res], 1).reshape(batch_size * total_pixels, -1) return model_outputs def merge_dict(dicts): output = {} for entry in dicts[0]: output[entry] = torch.cat([r[entry] for r in dicts], dim=0) return output from torch.autograd import Function from torch.cuda.amp import custom_bwd, custom_fwd class _trunc_exp(Function): @staticmethod @custom_fwd(cast_inputs=torch.float32) # cast to float32 def forward(ctx, x): ctx.save_for_backward(x) return torch.exp(x) @staticmethod @custom_bwd def backward(ctx, g): x = ctx.saved_tensors[0] return g * torch.exp(x.clamp(-15, 15)) trunc_exp = _trunc_exp.apply def kmeans_pp_centroid(points: torch.Tensor, k): n, c = points.shape centroids = torch.zeros(k, c, device=points.device) centroids[0, :] = points[np.random.randint(0, n), :].clone() d = [0.0] * n for i in range(1, k): sum_all = 0 d = (points.unsqueeze(1) - centroids[:i,:].unsqueeze(0)).norm(p=2, dim=-1).min(dim=1).values sum_all = d.sum() * np.random.random() cumsum = torch.cumsum(d, dim=0) j = ((cumsum - sum_all) > 0).int().argmax() centroids[i,:] = points[j,:].clone() return centroids ================================================ FILE: utils/cfgnode.py ================================================ """ Define a class to hold configurations. Borrows and merges stuff from YACS, fvcore, and detectron2 https://github.com/rbgirshick/yacs https://github.com/facebookresearch/fvcore/ https://github.com/facebookresearch/detectron2/ """ import copy import importlib.util import io import logging import os from ast import literal_eval from typing import Optional import yaml # File exts for yaml _YAML_EXTS = {"", ".yml", ".yaml"} # File exts for python _PY_EXTS = {".py"} # CfgNodes can only contain a limited set of valid types _VALID_TYPES = {tuple, list, str, int, float, bool} # Valid file object types _FILE_TYPES = (io.IOBase,) # Logger logger = logging.getLogger(__name__) class CfgNode(dict): r"""CfgNode is a `node` in the configuration `tree`. It's a simple wrapper around a `dict` and supports access to `attributes` via `keys`. 
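A minimal usage sketch (hypothetical keys; attribute access mirrors key access):

    cfg = CfgNode({"train": {"learning_rate": 1e-3}})
    assert cfg.train.learning_rate == cfg["train"]["learning_rate"]
    print(cfg.dump())  # YAML string of the whole config tree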
""" IMMUTABLE = "__immutable__" DEPRECATED_KEYS = "__deprecated_keys__" RENAMED_KEYS = "__renamed_keys__" NEW_ALLOWED = "__new_allowed__" def __init__( self, init_dict: Optional[dict] = None, key_list: Optional[list] = None, new_allowed: Optional[bool] = False, ): r""" Args: init_dict (dict): A dictionary to initialize the `CfgNode`. key_list (list[str]): A list of names that index this `CfgNode` from the root. Currently, only used for logging. new_allowed (bool): Whether adding a new key is allowed when merging with other `CfgNode` objects. """ # Recursively convert nested dictionaries in `init_dict` to config tree. init_dict = {} if init_dict is None else init_dict key_list = [] if key_list is None else key_list init_dict = self._create_config_tree_from_dict(init_dict, key_list) super(CfgNode, self).__init__(init_dict) # Control the immutability of the `CfgNode`. self.__dict__[CfgNode.IMMUTABLE] = False # Support for deprecated options. # If you choose to remove support for an option in code, but don't want to change all of the config files # (to allow for deprecated config files to run), you can add the full config key as a string to this set. self.__dict__[CfgNode.DEPRECATED_KEYS] = set() # Support for renamed options. # If you rename an option, record the mapping from the old name to the new name in this dictionary. Optionally, # if the type also changed, you can make this value a tuple that specifies two things: the renamed key, and the # instructions to edit the config file. self.__dict__[CfgNode.RENAMED_KEYS] = { # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY', # Dummy example # 'EXAMPLE.OLD.KEY': ( # A more complex example # 'EXAMPLE.NEW.KEY', # "Also convert to a tuple, eg. 'foo' -> ('foo', ) or " # + "'foo.bar' -> ('foo', 'bar')" # ), } # Allow new attributes after initialization. self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed @classmethod def _create_config_tree_from_dict(cls, init_dict: dict, key_list: list): r"""Create a configuration tree using the input dict. Any dict-like objects inside `init_dict` will be treated as new `CfgNode` objects. Args: init_dict (dict): Input dictionary, to create config tree from. key_list (list): A list of names that index this `CfgNode` from the root. Currently only used for logging. 
""" d = copy.deepcopy(init_dict) for k, v in d.items(): if isinstance(v, dict): # Convert dictionary to CfgNode d[k] = cls(v, key_list=key_list + [k]) else: # Check for valid leaf type or nested CfgNode _assert_with_logging( _valid_type(v, allow_cfg_node=False), "Key {} with value {} is not a valid type; valid types: {}".format( ".".join(key_list + [k]), type(v), _VALID_TYPES ), ) return d def __getattr__(self, name: str): if name in self: return self[name] else: raise AttributeError(name) def __setattr__(self, name: str, value): if self.is_frozen(): raise AttributeError( "Attempted to set {} to {}, but CfgNode is immutable".format( name, value ) ) _assert_with_logging( name not in self.__dict__, "Invalid attempt to modify internal CfgNode state: {}".format(name), ) _assert_with_logging( _valid_type(value, allow_cfg_node=True), "Invalid type {} for key {}; valid types = {}".format( type(value), name, _VALID_TYPES ), ) self[name] = value def __str__(self): def _indent(s_, num_spaces): s = s_.split("\n") if len(s) == 1: return s_ first = s.pop(0) s = [(num_spaces * " ") + line for line in s] s = "\n".join(s) s = first + "\n" + s return s r = "" s = [] for k, v in sorted(self.items()): separator = "\n" if isinstance(v, CfgNode) else " " attr_str = "{}:{}{}".format(str(k), separator, str(v)) attr_str = _indent(attr_str, 2) s.append(attr_str) r += "\n".join(s) return r def __repr__(self): return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__()) def dump(self, **kwargs): r"""Dump CfgNode to a string. """ def _convert_to_dict(cfg_node, key_list): if not isinstance(cfg_node, CfgNode): _assert_with_logging( _valid_type(cfg_node), "Key {} with value {} is not a valid type; valid types: {}".format( ".".join(key_list), type(cfg_node), _VALID_TYPES ), ) return cfg_node else: cfg_dict = dict(cfg_node) for k, v in cfg_dict.items(): cfg_dict[k] = _convert_to_dict(v, key_list + [k]) return cfg_dict self_as_dict = _convert_to_dict(self, []) return yaml.safe_dump(self_as_dict, **kwargs) def merge_from_file(self, cfg_filename: str): r"""Load a yaml config file and merge it with this CfgNode. Args: cfg_filename (str): Config file path. """ with open(cfg_filename, "r") as f: cfg = self.load_cfg(f) self.merge_from_other_cfg(cfg) def merge_from_other_cfg(self, cfg_other): r"""Merge `cfg_other` into the current `CfgNode`. Args: cfg_other """ _merge_a_into_b(cfg_other, self, self, []) def merge_from_list(self, cfg_list: list): r"""Merge config (keys, values) in a list (eg. from commandline) into this `CfgNode`. Eg. `cfg_list = ['FOO.BAR', 0.5]`. """ _assert_with_logging( len(cfg_list) % 2 == 0, "Override list has odd lengths: {}; it must be a list of pairs".format( cfg_list ), ) root = self for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): if root.key_is_deprecated(full_key): continue if root.key_is_renamed(full_key): root.raise_key_rename_error(full_key) key_list = full_key.split(".") d = self for subkey in key_list[:-1]: _assert_with_logging( subkey in d, "Non-existent key: {}".format(full_key) ) d = d[subkey] subkey = key_list[-1] _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key)) value = self._decode_cfg_value(v) value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key) d[subkey] = value def freeze(self): r"""Make this `CfgNode` and all of its children immutable. """ self._immutable(True) def defrost(self): r"""Make this `CfgNode` and all of its children mutable. """ self._immutable(False) def is_frozen(self): r"""Return mutability. 
""" return self.__dict__[CfgNode.IMMUTABLE] def _immutable(self, is_immutable: bool): r"""Set mutability and recursively apply to all nested `CfgNode` objects. Args: is_immutable (bool): Whether or not the `CfgNode` and its children are immutable. """ self.__dict__[CfgNode.IMMUTABLE] = is_immutable # Recursively propagate state to all children. for v in self.__dict__.values(): if isinstance(v, CfgNode): v._immutable(is_immutable) for v in self.values(): if isinstance(v, CfgNode): v._immutable(is_immutable) def clone(self): r"""Recursively copy this `CfgNode`. """ return copy.deepcopy(self) def register_deprecated_key(self, key: str): r"""Register key (eg. `FOO.BAR`) a deprecated option. When merging deprecated keys, a warning is generated and the key is ignored. """ _assert_with_logging( key not in self.__dict__[CfgNode.DEPRECATED_KEYS], "key {} is already registered as a deprecated key".format(key), ) self.__dict__[CfgNode.DEPRECATED_KEYS].add(key) def register_renamed_key( self, old_name: str, new_name: str, message: Optional[str] = None ): r"""Register a key as having been renamed from `old_name` to `new_name`. When merging a renamed key, an exception is thrown alerting the user to the fact that the key has been renamed. """ _assert_with_logging( old_name not in self.__dict__[CfgNode.RENAMED_KEYS], "key {} is already registered as a renamed cfg key".format(old_name), ) value = new_name if message: value = (new_name, message) self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value def key_is_deprecated(self, full_key: str): r"""Test if a key is deprecated. """ if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]: logger.warning("deprecated config key (ignoring): {}".format(full_key)) return True return False def key_is_renamed(self, full_key: str): r"""Test if a key is renamed. """ return full_key in self.__dict__[CfgNode.RENAMED_KEYS] def raise_key_rename_error(self, full_key: str): new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key] if isinstance(new_key, tuple): msg = " Note: " + new_key[1] new_key = new_key[0] else: msg = "" raise KeyError( "Key {} was renamed to {}; please update your config.{}".format( full_key, new_key, msg ) ) def is_new_allowed(self): return self.__dict__[CfgNode.NEW_ALLOWED] @classmethod def load_cfg(cls, cfg_file_obj_or_str): r"""Load a configuration into the `CfgNode`. Args: cfg_file_obj_or_str (str or cfg compatible object): Supports loading from: - A file object backed by a YAML file. - A file object backed by a Python source file that exports an sttribute "cfg" (dict or `CfgNode`). - A string that can be parsed as valid YAML. """ _assert_with_logging( isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)), "Expected first argument to be of type {} or {}, but got {}".format( _FILE_TYPES, str, type(cfg_file_obj_or_str) ), ) if isinstance(cfg_file_obj_or_str, str): return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str) elif isinstance(cfg_file_obj_or_str, _FILE_TYPES): return cls._load_cfg_from_file(cfg_file_obj_or_str) else: raise NotImplementedError("Impossible to reach here (unless there's a bug)") @classmethod def _load_cfg_from_file(cls, file_obj): r"""Load a config from a YAML file or a Python source file. 
""" _, file_ext = os.path.splitext(file_obj.name) if file_ext in _YAML_EXTS: return cls._load_cfg_from_yaml_str(file_obj.read()) elif file_ext in _PY_EXTS: return cls._load_cfg_py_source(file_obj.name) else: raise Exception( "Attempt to load from an unsupported filetype {}; only {} supported".format( _YAML_EXTS.union(_PY_EXTS) ) ) @classmethod def _load_cfg_from_yaml_str(cls, str_obj): r"""Load a config from a YAML string encoding. """ cfg_as_dict = yaml.safe_load(str_obj) return cls(cfg_as_dict) @classmethod def _load_cfg_py_source(cls, filename): r"""Load a config from a Python source file. """ module = _load_module_from_file("yacs.config.override", filename) _assert_with_logging( hasattr(module, "cfg"), "Python module from file {} must export a 'cfg' attribute".format(filename), ) VALID_ATTR_TYPES = {dict, CfgNode} _assert_with_logging( type(module.cfg) in VALID_ATTR_TYPES, "Import module 'cfg' attribute must be in {} but is {}".format( VALID_ATTR_TYPES, type(module.cfg) ), ) return cls(module.cfg) @classmethod def _decode_cfg_value(cls, value): r"""Decodes a raw config value (eg. from a yaml config file or commandline argument) into a Python object. If `value` is a dict, it will be interpreted as a new `CfgNode`. If `value` is a str, it will be evaluated as a literal. Otherwise, it is returned as is. """ # Configs parsed from raw yaml will contain dictionary keys that need to be converted to `CfgNode` objects. if isinstance(value, dict): return cls(value) # All remaining processing is only applied to strings. if not isinstance(value, str): return value # Try to interpret `value` as a: string, number, tuple, list, dict, bool, or None try: value = literal_eval(value) # The following two excepts allow `value` to pass through it when it represents a string. # The type of `value` is always a string (before calling `literal_eval`), but sometimes it *represents* a # string and other times a data structure, like a list. In the case that `value` represents a str, what we # got back from the yaml parser is `foo` *without quotes* (so, not `"foo"`). `literal_eval` is ok with `"foo"`, # but will raise a `ValueError` if given `foo`. In other cases, like paths (`val = 'foo/bar'`) `literal_eval` # will raise a `SyntaxError`. except ValueError: pass except SyntaxError: pass return value # Keep this function in global scope, for backward compataibility. load_cfg = CfgNode.load_cfg def _valid_type(value, allow_cfg_node: Optional[bool] = False): return (type(value) in _VALID_TYPES) or ( allow_cfg_node and isinstance(value, CfgNode) ) def _merge_a_into_b(a: CfgNode, b: CfgNode, root: CfgNode, key_list: list): r"""Merge `CfgNode` `a` into `CfgNode` `b`, clobbering the options in `b` wherever they are also specified in `a`. """ _assert_with_logging( isinstance(a, CfgNode), "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode), ) _assert_with_logging( isinstance(b, CfgNode), "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode), ) for k, v_ in a.items(): full_key = ".".join(key_list + [k]) v = copy.deepcopy(v_) v = b._decode_cfg_value(v) if k in b: v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key) # Recursively merge dicts. 
if isinstance(v, CfgNode): try: _merge_a_into_b(v, b[k], root, key_list + [k]) except BaseException: raise else: b[k] = v elif b.is_new_allowed(): b[k] = v else: if root.key_is_deprecated(full_key): continue elif root.key_is_renamed(full_key): root.raise_key_rename_error(full_key) else: raise KeyError("Non-existent config key: {}".format(full_key)) def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): r"""Checks that `replacement`, which is intended to replace `original` is of the right type. The type is correct if it matches exactly or is one of a few cases in which the type can easily be coerced. """ original_type = type(original) replacement_type = type(replacement) if replacement_type == original_type: return replacement # If replacement and original types match, cast replacement from `from_type` to `to_type`. def _conditional_cast(from_type, to_type): if replacement_type == from_type and original_type == to_type: return True, to_type(replacement) else: return False, None # Conditional casts. # list <-> tuple casts = [(tuple, list), (list, tuple)] for (from_type, to_type) in casts: converted, converted_value = _conditional_cast(from_type, to_type) if converted: return converted_value raise ValueError( "Type mismatch ({} vs. {} with values ({} vs. {}) for config key: {}".format( original_type, replacement_type, original, replacement, full_key ) ) def _assert_with_logging(cond, msg): if not cond: logger.debug(msg) assert cond, msg def _load_module_from_file(name, filename): spec = importlib.util.spec_from_file_location(name, filename) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module ================================================ FILE: utils/mesh_util.py ================================================ # adapted from https://github.com/zju3dv/manhattan_sdf import numpy as np import open3d as o3d from sklearn.neighbors import KDTree import trimesh import os os.environ['PYOPENGL_PLATFORM'] = 'egl' import pyrender from tqdm.contrib import tenumerate, tzip def nn_correspondance(verts1, verts2): indices = [] distances = [] if len(verts1) == 0 or len(verts2) == 0: return indices, distances kdtree = KDTree(verts1) distances, indices = kdtree.query(verts2) distances = distances.reshape(-1) return distances def evaluate(mesh_pred, mesh_trgt, threshold=.05, down_sample=.02): pcd_trgt = o3d.geometry.PointCloud() pcd_pred = o3d.geometry.PointCloud() pcd_trgt.points = o3d.utility.Vector3dVector(mesh_trgt.vertices[:, :3]) pcd_pred.points = o3d.utility.Vector3dVector(mesh_pred.vertices[:, :3]) if down_sample: pcd_pred = pcd_pred.voxel_down_sample(down_sample) pcd_trgt = pcd_trgt.voxel_down_sample(down_sample) verts_pred = np.asarray(pcd_pred.points) verts_trgt = np.asarray(pcd_trgt.points) dist1 = nn_correspondance(verts_pred, verts_trgt) dist2 = nn_correspondance(verts_trgt, verts_pred) precision = np.mean((dist2 < threshold).astype('float')) recal = np.mean((dist1 < threshold).astype('float')) fscore = 2 * precision * recal / (precision + recal) metrics = { 'Acc': np.mean(dist2), 'Comp': np.mean(dist1), 'Prec': precision, 'Recal': recal, 'F-score': fscore, } return metrics class Renderer(): def __init__(self, height=480, width=640): self.renderer = pyrender.OffscreenRenderer(width, height) self.scene = pyrender.Scene() # self.render_flags = pyrender.RenderFlags.SKIP_CULL_FACES def __call__(self, height, width, intrinsics, pose, mesh): self.renderer.viewport_height = height self.renderer.viewport_width = width self.scene.clear() 
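# Offscreen rasterisation pass used by refuse(): a pinhole camera with the given
# intrinsics is added to the scene, fix_pose() applies a pi rotation about the x-axis
# (the usual OpenCV-to-OpenGL camera flip), and render() returns a (color, depth) pair;
# downstream only the depth buffer is consumed for TSDF fusion of the predicted mesh.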
self.scene.add(mesh) cam = pyrender.IntrinsicsCamera(cx=intrinsics[0, 2], cy=intrinsics[1, 2], fx=intrinsics[0, 0], fy=intrinsics[1, 1]) self.scene.add(cam, pose=self.fix_pose(pose)) return self.renderer.render(self.scene) # , self.render_flags) def fix_pose(self, pose): # 3D Rotation about the x-axis. t = np.pi c = np.cos(t) s = np.sin(t) R = np.array([[1, 0, 0], [0, c, -s], [0, s, c]]) axis_transform = np.eye(4) axis_transform[:3, :3] = R return pose @ axis_transform def mesh_opengl(self, mesh): return pyrender.Mesh.from_trimesh(mesh) def delete(self): self.renderer.delete() def refuse(mesh, poses, K, H, W, far_clip=5.0): renderer = Renderer() mesh_opengl = renderer.mesh_opengl(mesh) volume = o3d.pipelines.integration.ScalableTSDFVolume( voxel_length=0.01, sdf_trunc=3 * 0.01, color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8 ) for i, pose in tenumerate(poses): intrinsic = K rgb = np.ones((H, W, 3)) rgb = (rgb * 255).astype(np.uint8) rgb = o3d.geometry.Image(rgb) _, depth_pred = renderer(H, W, intrinsic, pose, mesh_opengl) depth_pred = o3d.geometry.Image(depth_pred) rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( rgb, depth_pred, depth_scale=1.0, depth_trunc=far_clip, convert_rgb_to_intensity=False ) fx, fy, cx, cy = intrinsic[0, 0], intrinsic[1, 1], intrinsic[0, 2], intrinsic[1, 2] intrinsic = o3d.camera.PinholeCameraIntrinsic(width=W, height=H, fx=fx, fy=fy, cx=cx, cy=cy) extrinsic = np.linalg.inv(pose) volume.integrate(rgbd, intrinsic, extrinsic) return volume.extract_triangle_mesh() def depth2mesh(depths, poses, K, H, W): volume = o3d.pipelines.integration.ScalableTSDFVolume( voxel_length=0.01, sdf_trunc=3 * 0.01, color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8 ) for depth, pose in tzip(depths, poses): rgb = np.ones((H, W, 3)) rgb = (rgb * 255).astype(np.uint8) rgb = o3d.geometry.Image(rgb) depth = o3d.geometry.Image(depth) rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( rgb, depth, depth_scale=1.0, depth_trunc=5.0, convert_rgb_to_intensity=False ) fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2] intrinsic = o3d.camera.PinholeCameraIntrinsic(width=W, height=H, fx=fx, fy=fy, cx=cx, cy=cy) extrinsic = np.linalg.inv(pose) volume.integrate(rgbd, intrinsic, extrinsic) return volume.extract_triangle_mesh() ================================================ FILE: utils/plots.py ================================================ import plotly.graph_objs as go import plotly.offline as offline from plotly.subplots import make_subplots import numpy as np import torch from skimage import measure import torchvision.utils as vutils import trimesh from PIL import Image import cv2 import mcubes from utils import rend_util def plot(implicit_network, indices, plot_data, path, epoch, img_res, plot_nimgs, resolution=None, grid_boundary=None, meshing=True, level=0): if plot_data is not None: if 'rgb_eval' in plot_data: plot_images(plot_data['rgb_eval'], plot_data['rgb_gt'], path, epoch, plot_nimgs, img_res) if 'hdr_eval' in plot_data: plot_images(plot_data['hdr_eval'], plot_data['hdr_gt'], path, epoch, plot_nimgs, img_res, 'hdr', True) if 'rgb_surface' in plot_data: plot_images(plot_data['rgb_surface'], plot_data['rgb_gt'], path, '{0}s'.format(epoch), plot_nimgs, img_res) if 'rendered' in plot_data: plot_images(plot_data['rendered'], plot_data['rgb_gt'], path, epoch, plot_nimgs, img_res, 'rendered') if 'normal_map' in plot_data: plot_imgs_wo_gt(plot_data['normal_map'], path, epoch, plot_nimgs, img_res) if 'depth_eval' in plot_data: 
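# plot() writes each entry present in plot_data into its own sub-directory of `path`
# (rendering/, hdr/, depth/, normal/, light_mask/, ...); depth maps in particular are
# normalised by their maximum and saved as Viridis colormaps by plot_depths() below.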
plot_depths(plot_data['depth_eval'], path, epoch, plot_nimgs, img_res) if 'lmask_eval' in plot_data: plot_images(plot_data['lmask_eval'], plot_data['lmask_gt'], path, epoch, plot_nimgs, img_res, 'light_mask') if 'Kd' in plot_data: # plot_imgs_wo_gt(plot_data['albedo'], path, epoch, plot_nimgs, img_res, 'albedo') plot_images(plot_data['Ks'], plot_data['Kd'], path, epoch, plot_nimgs, img_res, 'albedo') if 'roughness' in plot_data: plot_colormap(plot_data['roughness'], path, epoch, plot_nimgs, img_res) if 'metallic' in plot_data: plot_colormap(plot_data['metallic'], path, epoch, plot_nimgs, img_res, 'metallic') if 'emission' in plot_data: plot_imgs_wo_gt(plot_data['emission'], path, '{0}'.format(epoch), plot_nimgs, img_res, 'emission', True) if meshing: path = f"{path}/mesh" cam_loc, cam_dir = rend_util.get_camera_for_plot(plot_data['pose']) data = [] # plot surface surface_traces = get_surface_trace(path=path, epoch=epoch, sdf=lambda x: implicit_network(x)[:, 0], resolution=resolution, grid_boundary=grid_boundary, level=level ) if surface_traces is not None: data.append(surface_traces[0]) # plot cameras locations if plot_data is not None: for i, loc, dir in zip(indices, cam_loc, cam_dir): data.append(get_3D_quiver_trace(loc.unsqueeze(0), dir.unsqueeze(0), name='camera_{0}'.format(i))) fig = go.Figure(data=data) scene_dict = dict(xaxis=dict(range=[-6, 6], autorange=False), yaxis=dict(range=[-6, 6], autorange=False), zaxis=dict(range=[-6, 6], autorange=False), aspectratio=dict(x=1, y=1, z=1)) fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True) filename = '{0}/surface_{1}.html'.format(path, epoch) offline.plot(fig, filename=filename, auto_open=False) def visualize_pointcloud(points, filename): fig = go.Figure( data = [ get_3D_scatter_trace(points, 'Pointcloud', 1) ] ) scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False), yaxis=dict(range=[-3, 3], autorange=False), zaxis=dict(range=[-3, 3], autorange=False), aspectratio=dict(x=1, y=1, z=1)) fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True) offline.plot(fig, filename=filename, auto_open=False) def visualize_clustered_pointcloud(points, labels, centroids, filename): fig = go.Figure() if centroids is not None: fig.add_trace(get_3D_scatter_trace(centroids, "Centroids", 10)) for c in torch.unique(labels): cluster = points[labels == c, :] fig.add_trace(get_3D_scatter_trace(cluster, f"Emitter #{int(c)}")) scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False), yaxis=dict(range=[-3, 3], autorange=False), zaxis=dict(range=[-3, 3], autorange=False), aspectratio=dict(x=1, y=1, z=1)) fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True) offline.plot(fig, filename=filename, auto_open=False) def visualize_marked_pointcloud(points, counts, path, epoch): fig = go.Figure( data = [ get_3D_marked_scatter_trace(points, counts, 'Pointcloud samples') ] ) scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False), yaxis=dict(range=[-3, 3], autorange=False), zaxis=dict(range=[-3, 3], autorange=False), aspectratio=dict(x=1, y=1, z=1)) fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True) filename = '{0}/pointcloud/{1}.html'.format(path, epoch) offline.plot(fig, filename=filename, auto_open=False) def get_3D_scatter_trace(points, name='', size=3, caption=None): assert points.shape[1] == 3, "3d scatter plot input points are not correctely shaped " assert len(points.shape) == 2, "3d scatter plot input points are not correctely shaped " trace = go.Scatter3d( 
x=points[:, 0].cpu(), y=points[:, 1].cpu(), z=points[:, 2].cpu(), mode='markers', name=name, marker=dict( size=size, line=dict( width=2, ), opacity=1.0, ), text=caption) return trace def get_3D_marked_scatter_trace(points, marks, name='', size=1, caption=None): assert points.shape[1] == 3, "3d scatter plot input points are not correctely shaped " assert len(points.shape) == 2, "3d scatter plot input points are not correctely shaped " trace = go.Scatter3d( x=points[:, 0].cpu(), y=points[:, 1].cpu(), z=points[:, 2].cpu(), mode='markers', name=name, marker=dict( size=size, line=dict( width=2, ), color=marks.squeeze().cpu(), colorscale='Viridis', opacity=1.0, ), text=caption) return trace def get_3D_quiver_trace(points, directions, color='#bd1540', name=''): assert points.shape[1] == 3, "3d cone plot input points are not correctely shaped " assert len(points.shape) == 2, "3d cone plot input points are not correctely shaped " assert directions.shape[1] == 3, "3d cone plot input directions are not correctely shaped " assert len(directions.shape) == 2, "3d cone plot input directions are not correctely shaped " trace = go.Cone( name=name, x=points[:, 0].cpu(), y=points[:, 1].cpu(), z=points[:, 2].cpu(), u=directions[:, 0].cpu(), v=directions[:, 1].cpu(), w=directions[:, 2].cpu(), sizemode='absolute', sizeref=0.125, showscale=False, colorscale=[[0, color], [1, color]], anchor="tail" ) return trace def get_surface_trace(path, epoch, sdf, resolution=100, grid_boundary=[-2.0, 2.0], return_mesh=False, level=0): grid = get_grid_uniform(resolution, grid_boundary) points = grid['grid_points'] z = [] for i, pnts in enumerate(torch.split(points, 100000, dim=0)): z.append(sdf(pnts).detach().cpu().numpy()) z = np.concatenate(z, axis=0) if (not (np.min(z) > level or np.max(z) < level)): z = z.astype(np.float32) verts, faces, normals, values = measure.marching_cubes( volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0], grid['xyz'][2].shape[0]).transpose([1, 0, 2]), level=level, spacing=(grid['xyz'][0][2] - grid['xyz'][0][1], grid['xyz'][0][2] - grid['xyz'][0][1], grid['xyz'][0][2] - grid['xyz'][0][1])) verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]]) I, J, K = faces.transpose() traces = [go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2], i=I, j=J, k=K, name='implicit_surface', color='#ffffff', opacity=1.0, flatshading=False, lighting=dict(diffuse=1, ambient=0, specular=0), lightposition=dict(x=0, y=0, z=-1), showlegend=True)] meshexport = trimesh.Trimesh(verts, faces, normals) meshexport.export('{0}/surface_{1}.ply'.format(path, epoch), 'ply') if return_mesh: return meshexport return traces return None def extract_fields(bound_min, bound_max, resolution, query_func): N = 64 X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N) Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N) Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N) u = np.zeros([resolution, resolution, resolution], dtype=np.float32) with torch.no_grad(): for xi, xs in enumerate(X): for yi, ys in enumerate(Y): for zi, zs in enumerate(Z): xx, yy, zz = torch.meshgrid(xs, ys, zs) pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val return u def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): print('threshold: {}'.format(threshold)) u = 
extract_fields(bound_min, bound_max, resolution, query_func) vertices, triangles = mcubes.marching_cubes(u, threshold) b_max_np = bound_max.detach().cpu().numpy() b_min_np = bound_min.detach().cpu().numpy() vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] mesh = trimesh.Trimesh(vertices, triangles) return mesh def get_surface_high_res_mesh(sdf, resolution=100, grid_boundary=[-2.0, 2.0], level=0, take_components=True): # get low res mesh to sample point cloud grid = get_grid_uniform(100, grid_boundary) z = [] points = grid['grid_points'] for i, pnts in enumerate(torch.split(points, 100000, dim=0)): z.append(sdf(pnts).detach().cpu().numpy()) z = np.concatenate(z, axis=0) z = z.astype(np.float32) verts, faces, normals, values = measure.marching_cubes( volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0], grid['xyz'][2].shape[0]).transpose([1, 0, 2]), level=level, spacing=(grid['xyz'][0][2] - grid['xyz'][0][1], grid['xyz'][0][2] - grid['xyz'][0][1], grid['xyz'][0][2] - grid['xyz'][0][1])) verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]]) mesh_low_res = trimesh.Trimesh(verts, faces, normals) if take_components: components = mesh_low_res.split(only_watertight=False) areas = np.array([c.area for c in components], dtype=np.float) mesh_low_res = components[areas.argmax()] recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0] recon_pc = torch.from_numpy(recon_pc).float().cuda() # Center and align the recon pc s_mean = recon_pc.mean(dim=0) s_cov = recon_pc - s_mean s_cov = torch.mm(s_cov.transpose(0, 1), s_cov) vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0] if torch.det(vecs) < 0: vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs) helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1), (recon_pc - s_mean).unsqueeze(-1)).squeeze() grid_aligned = get_grid(helper.cpu(), resolution) grid_points = grid_aligned['grid_points'] g = [] for i, pnts in enumerate(torch.split(grid_points, 100000, dim=0)): g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2), pnts.unsqueeze(-1)).squeeze() + s_mean) grid_points = torch.cat(g, dim=0) # MC to new grid points = grid_points z = [] for i, pnts in enumerate(torch.split(points, 100000, dim=0)): z.append(sdf(pnts).detach().cpu().numpy()) z = np.concatenate(z, axis=0) meshexport = None if (not (np.min(z) > level or np.max(z) < level)): z = z.astype(np.float32) verts, faces, normals, values = measure.marching_cubes( volume=z.reshape(grid_aligned['xyz'][1].shape[0], grid_aligned['xyz'][0].shape[0], grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]), level=level, spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1], grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1], grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1])) verts = torch.from_numpy(verts).cuda().float() verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2), verts.unsqueeze(-1)).squeeze() verts = (verts + grid_points[0]).cpu().numpy() meshexport = trimesh.Trimesh(verts, faces, normals) return meshexport def get_surface_by_grid(grid_params, sdf, resolution=100, level=0, higher_res=False): grid_params = grid_params * [[1.5], [1.0]] # params = PLOT_DICT[scan_id] input_min = torch.tensor(grid_params[0]).float() input_max = torch.tensor(grid_params[1]).float() if higher_res: # get low res mesh to sample point cloud grid = get_grid(None, 100, input_min=input_min, 
input_max=input_max, eps=0.0) z = [] points = grid['grid_points'] for i, pnts in enumerate(torch.split(points, 100000, dim=0)): z.append(sdf(pnts).detach().cpu().numpy()) z = np.concatenate(z, axis=0) z = z.astype(np.float32) verts, faces, normals, values = measure.marching_cubes( volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0], grid['xyz'][2].shape[0]).transpose([1, 0, 2]), level=level, spacing=(grid['xyz'][0][2] - grid['xyz'][0][1], grid['xyz'][0][2] - grid['xyz'][0][1], grid['xyz'][0][2] - grid['xyz'][0][1])) verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]]) mesh_low_res = trimesh.Trimesh(verts, faces, normals) components = mesh_low_res.split(only_watertight=False) areas = np.array([c.area for c in components], dtype=np.float) mesh_low_res = components[areas.argmax()] recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0] recon_pc = torch.from_numpy(recon_pc).float().cuda() # Center and align the recon pc s_mean = recon_pc.mean(dim=0) s_cov = recon_pc - s_mean s_cov = torch.mm(s_cov.transpose(0, 1), s_cov) vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0] if torch.det(vecs) < 0: vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs) helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1), (recon_pc - s_mean).unsqueeze(-1)).squeeze() grid_aligned = get_grid(helper.cpu(), resolution, eps=0.01) else: grid_aligned = get_grid(None, resolution, input_min=input_min, input_max=input_max, eps=0.0) grid_points = grid_aligned['grid_points'] if higher_res: g = [] for i, pnts in enumerate(torch.split(grid_points, 100000, dim=0)): g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2), pnts.unsqueeze(-1)).squeeze() + s_mean) grid_points = torch.cat(g, dim=0) # MC to new grid points = grid_points z = [] for i, pnts in enumerate(torch.split(points, 100000, dim=0)): z.append(sdf(pnts).detach().cpu().numpy()) z = np.concatenate(z, axis=0) meshexport = None if (not (np.min(z) > level or np.max(z) < level)): z = z.astype(np.float32) verts, faces, normals, values = measure.marching_cubes( volume=z.reshape(grid_aligned['xyz'][1].shape[0], grid_aligned['xyz'][0].shape[0], grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]), level=level, spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1], grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1], grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1])) if higher_res: verts = torch.from_numpy(verts).cuda().float() verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2), verts.unsqueeze(-1)).squeeze() verts = (verts + grid_points[0]).cpu().numpy() else: verts = verts + np.array([grid_aligned['xyz'][0][0], grid_aligned['xyz'][1][0], grid_aligned['xyz'][2][0]]) meshexport = trimesh.Trimesh(verts, faces, normals) # CUTTING MESH ACCORDING TO THE BOUNDING BOX if higher_res: bb = grid_params transformation = np.eye(4) transformation[:3, 3] = (bb[1,:] + bb[0,:])/2. 
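# The reconstructed mesh is finally cropped to the scene bounding box given by
# grid_params: a box primitive centred at its midpoint is built and the mesh is cut with
# slice_plane against each box facet, discarding geometry that marching cubes produced
# outside the region of interest.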
        bounding_box = trimesh.creation.box(extents=bb[1, :] - bb[0, :], transform=transformation)

        meshexport = meshexport.slice_plane(bounding_box.facets_origin, -bounding_box.facets_normal)

    return meshexport


def get_grid_uniform(resolution, grid_boundary=[-2.0, 2.0]):
    x = np.linspace(grid_boundary[0], grid_boundary[1], resolution)
    y = x
    z = x

    xx, yy, zz = np.meshgrid(x, y, z)
    grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float)

    return {"grid_points": grid_points.cuda(),
            "shortest_axis_length": 2.0,
            "xyz": [x, y, z],
            "shortest_axis_index": 0}


def get_grid(points, resolution, input_min=None, input_max=None, eps=0.1):
    if input_min is None or input_max is None:
        input_min = torch.min(points, dim=0)[0].squeeze().numpy()
        input_max = torch.max(points, dim=0)[0].squeeze().numpy()

    bounding_box = input_max - input_min
    shortest_axis = np.argmin(bounding_box)
    if shortest_axis == 0:
        x = np.linspace(input_min[shortest_axis] - eps, input_max[shortest_axis] + eps, resolution)
        length = np.max(x) - np.min(x)
        y = np.arange(input_min[1] - eps, input_max[1] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1))
        z = np.arange(input_min[2] - eps, input_max[2] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1))
    elif shortest_axis == 1:
        y = np.linspace(input_min[shortest_axis] - eps, input_max[shortest_axis] + eps, resolution)
        length = np.max(y) - np.min(y)
        x = np.arange(input_min[0] - eps, input_max[0] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))
        z = np.arange(input_min[2] - eps, input_max[2] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))
    elif shortest_axis == 2:
        z = np.linspace(input_min[shortest_axis] - eps, input_max[shortest_axis] + eps, resolution)
        length = np.max(z) - np.min(z)
        x = np.arange(input_min[0] - eps, input_max[0] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))
        y = np.arange(input_min[1] - eps, input_max[1] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))

    print(x.shape, y.shape, z.shape)
    xx, yy, zz = np.meshgrid(x, y, z)
    # print(xx.shape, yy.shape, zz.shape)
    # xx = torch.from_numpy(xx.flatten()).cuda().float()
    # yy = torch.from_numpy(yy.flatten()).cuda().float()
    # zz = torch.from_numpy(zz.flatten()).cuda().float()
    # grid_points = torch.cat([xx, yy, zz], dim=1).T
    grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda()

    return {"grid_points": grid_points,
            "shortest_axis_length": length,
            "xyz": [x, y, z],
            "shortest_axis_index": shortest_axis}
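# --- Note (added, not part of the original file) ---
# Both grid helpers return a dict with the same layout:
#   "grid_points"           an (N, 3) float tensor, already on the GPU
#   "xyz"                   the three 1D numpy axes used to build the grid
#   "shortest_axis_length" / "shortest_axis_index"   bookkeeping used for the marching-cubes spacing
# A small, hypothetical sanity check:
#
#   grid = get_grid_uniform(64)
#   assert grid["grid_points"].shape == (64 ** 3, 3)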
def plot_imgs_wo_gt(normal_maps, path, epoch, plot_nrow, img_res, path_name='normal', is_hdr=False):
    normal_maps_plot = lin2img(normal_maps, img_res)

    tensor = vutils.make_grid(normal_maps_plot,
                              scale_each=False,
                              normalize=False,
                              nrow=plot_nrow).cpu().detach().numpy()
    tensor = tensor.transpose(1, 2, 0)
    if not is_hdr:
        scale_factor = 255
        tensor = (tensor * scale_factor).astype(np.uint8)
        img = Image.fromarray(tensor)
        img.save('{0}/{1}/{2}.png'.format(path, path_name, epoch))
    else:
        cv2.imwrite('{0}/{1}/{2}.exr'.format(path, path_name, epoch), tensor[:, :, ::-1])


def plot_imgs_filter(rgb_points, ground_true, path, epoch, img_res, path_name='rendering'):
    output = lin2img(rgb_points, img_res).squeeze(0)  # (1, 3, h, w)
    ground_true = lin2img(ground_true, img_res).squeeze(0)
    output = output.permute(1, 2, 0).cpu().numpy()
    scale_factor = 255
    output = (output * scale_factor).astype(np.uint8)
    ground_true = ground_true.permute(1, 2, 0).cpu().numpy()
    ground_true = (ground_true * scale_factor).astype(np.uint8)
    output = output[:, :, ::-1]
    ground_true = ground_true[:, :, ::-1]
    filtered = cv2.ximgproc.guidedFilter(ground_true, output, 10, 2, -1)
    cv2.imwrite('{0}/{1}/{2}.png'.format(path, path_name, epoch), filtered)


def plot_colormap(mat_info, path, epoch, plot_nrow, img_res, colormap=cv2.COLORMAP_VIRIDIS, path_name='roughness'):
    mat_info_plot = lin2img(mat_info, img_res)

    tensor = vutils.make_grid(mat_info_plot,
                              scale_each=False,
                              normalize=False,
                              nrow=plot_nrow).cpu().detach().numpy()
    tensor = tensor.transpose(1, 2, 0)
    if colormap is None:
        cv2.imwrite('{0}/{1}/{2}.exr'.format(path, path_name, epoch), tensor)
    else:
        tensor = (tensor * 255).astype(np.uint8)
        img = cv2.applyColorMap(tensor, colormap)
        cv2.imwrite('{0}/{1}/{2}.png'.format(path, path_name, epoch), img)


def plot_depths(depth_maps, path, epoch, plot_nrow, img_res, colormap=cv2.COLORMAP_VIRIDIS):
    depth_maps_plot = lin2img(depth_maps, img_res)

    tensor = vutils.make_grid(depth_maps_plot,
                              scale_each=False,
                              normalize=False,
                              nrow=plot_nrow).cpu().detach().numpy()
    tensor = tensor.transpose(1, 2, 0)
    # scale_factor = 255
    # tensor = (tensor * scale_factor).astype(np.uint8)
    if colormap is None:
        cv2.imwrite('{0}/depth/{1}.exr'.format(path, epoch), tensor)
    else:
        tensor = tensor / (tensor.max() + 1e-6)
        tensor = (tensor * 255).astype(np.uint8)
        img = cv2.applyColorMap(tensor, colormap)
        cv2.imwrite('{0}/depth/{1}.png'.format(path, epoch), img)
    # img = Image.fromarray(tensor)
    # img.save('{0}/normal_{1}.png'.format(path, epoch))


def plot_images(rgb_points, ground_true, path, epoch, plot_nrow, img_res, path_name='rendering', is_hdr=False):
    ground_true = ground_true.cuda()
    output_vs_gt = torch.cat((rgb_points, ground_true), dim=0)
    output_vs_gt_plot = lin2img(output_vs_gt, img_res)

    tensor = vutils.make_grid(output_vs_gt_plot,
                              scale_each=False,
                              normalize=False,
                              nrow=plot_nrow).cpu().detach().numpy()
    tensor = tensor.transpose(1, 2, 0)
    if not is_hdr:
        scale_factor = 255
        tensor = (tensor * scale_factor).astype(np.uint8)
        img = Image.fromarray(tensor)
        img.save('{0}/{1}/{2}.png'.format(path, path_name, epoch))
    else:
        cv2.imwrite('{0}/{1}/{2}.exr'.format(path, path_name, epoch), tensor[:, :, ::-1])


def lin2img(tensor, img_res):
    batch_size, num_samples, channels = tensor.shape
    return tensor.permute(0, 2, 1).view(batch_size, channels, img_res[0], img_res[1])
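# --- Usage sketch (added, not part of the original file) ---
# The plot helpers expect flattened per-ray colors of shape (batch, num_samples, 3),
# typically in [0, 1] for LDR output, plus the image resolution. A hypothetical call:
#
#   # rgb_pred, rgb_gt: (1, H * W, 3) torch tensors on the GPU
#   plot_images(rgb_pred, rgb_gt, plots_dir, epoch, plot_nrow=1, img_res=[H, W])
#
# where `plots_dir`, `epoch`, `H` and `W` are placeholders for the caller's values.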
================================================
FILE: utils/rend_util.py
================================================
import numpy as np
import imageio
import skimage
import cv2
import torch
from torch.nn import functional as F


def linear_to_srgb(data):
    return torch.where(data <= 0.0031308, data * 12.92, 1.055 * (data ** (1 / 2.4)) - 0.055)


def get_psnr(img1, img2, normalize_rgb=False):
    if normalize_rgb:
        # [-1,1] --> [0,1]
        img1 = (img1 + 1.) / 2.
        img2 = (img2 + 1.) / 2.

    mse = torch.mean((img1 - img2) ** 2)
    # psnr = -10. * torch.log(mse) / torch.log(torch.Tensor([10.]).cuda())
    psnr = -10. * torch.log(mse) / np.log(10)

    return psnr


def load_rgb(path, normalize_rgb=False, is_hdr=False):
    if not is_hdr:
        img = imageio.imread(path)
        img = skimage.img_as_float32(img)
    else:
        img = cv2.imread(path, -1)[:, :, ::-1].copy()

    if normalize_rgb:
        # [0,1] --> [-1,1]
        img -= 0.5
        img *= 2.
    img = img.transpose(2, 0, 1)
    return img


def load_mask(path):
    img = imageio.imread(path)
    img = skimage.img_as_float32(img)
    if len(img.shape) == 3:
        img = img[:, :, 0]
    return img  # (h, w)


def load_depth(path):
    img = cv2.imread(path, -1)
    if len(img.shape) == 3:
        img = img[:, :, -1]
    return img


def load_normal(path):
    img = cv2.imread(path, -1)[:, :, ::-1]
    return img.copy()


def load_K_Rt_from_P(filename, P=None):
    if P is None:
        lines = open(filename).read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
        P = np.asarray(lines).astype(np.float32).squeeze()

    out = cv2.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    K = K / K[2, 2]
    intrinsics = np.eye(4, dtype=np.float32)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()
    pose[:3, 3] = (t[:3] / t[3])[:, 0]

    return intrinsics, pose


def depth_to_world(uv, intrinsics, pose, depth, depth_mask=None):
    x_cam, y_cam = torch.unbind(uv, dim=1)
    z_cam = torch.ones_like(x_cam)
    xyz_view = lift(x_cam, y_cam, z_cam, intrinsics)
    xyz_view[:, :-1] = xyz_view[:, :-1] * depth.unsqueeze(1)
    if depth_mask is not None:
        xyz_view = xyz_view[depth_mask, :]
    xyz_world = pose @ xyz_view.T
    return xyz_world.T


def get_camera_params(uv, pose, intrinsics):
    if pose.shape[1] == 7:  # In case of quaternion vector representation
        cam_loc = pose[:, 4:]
        R = quat_to_rot(pose[:, :4])
        p = torch.eye(4, device=pose.device).repeat(pose.shape[0], 1, 1).float()
        p[:, :3, :3] = R
        p[:, :3, 3] = cam_loc
    else:  # In case of pose matrix representation
        cam_loc = pose[:, :3, 3]
        p = pose

    batch_size, num_samples, _ = uv.shape

    depth = torch.ones((batch_size, num_samples), device=pose.device)
    x_cam = uv[:, :, 0].view(batch_size, -1)
    y_cam = uv[:, :, 1].view(batch_size, -1)
    # z_cam = -depth.view(batch_size, -1)
    z_cam = depth.view(batch_size, -1)

    pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)
    # pixel_points_cam[:,:,0] = -pixel_points_cam[:,:,0]

    # permute for batch matrix product
    pixel_points_cam = pixel_points_cam.permute(0, 2, 1)

    world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
    ray_dirs = world_coords - cam_loc[:, None, :]
    # ray_dirs = F.normalize(ray_dirs, dim=2)

    return ray_dirs, cam_loc


def get_camera_for_plot(pose):
    if pose.shape[1] == 7:  # In case of quaternion vector representation
        cam_loc = pose[:, 4:].detach()
        R = quat_to_rot(pose[:, :4].detach())
    else:  # In case of pose matrix representation
        cam_loc = pose[:, :3, 3]
        R = pose[:, :3, :3]
    cam_dir = R[:, :3, 2]
    return cam_loc, cam_dir


def lift(x, y, z, intrinsics):
    # parse intrinsics
    fx = intrinsics[..., 0, 0]
    fy = intrinsics[..., 1, 1]
    cx = intrinsics[..., 0, 2]
    cy = intrinsics[..., 1, 2]
    sk = intrinsics[..., 0, 1]

    x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1) -
              sk.unsqueeze(-1) * y / fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
    y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z

    # homogeneous
    return torch.stack((x_lift, y_lift, z, torch.ones_like(z)), dim=-1)


def quat_to_rot(q):
    batch_size, _ = q.shape
    q = F.normalize(q, dim=1)
    R = torch.ones((batch_size, 3, 3), device=q.device)
    qr = q[:, 0]
    qi = q[:, 1]
    qj = q[:, 2]
    qk = q[:, 3]
    R[:, 0, 0] = 1 - 2 * (qj ** 2 + qk ** 2)
    R[:, 0, 1] = 2 * (qj * qi - qk * qr)
    R[:, 0, 2] = 2 * (qi * qk + qr * qj)
    R[:, 1, 0] = 2 * (qj * qi + qk * qr)
    R[:, 1, 1] = 1 - 2 * (qi ** 2 + qk ** 2)
    R[:, 1, 2] = 2 * (qj * qk - qi * qr)
    R[:, 2, 0] = 2 * (qk * qi - qj * qr)
    R[:, 2, 1] = 2 * (qj * qk + qi * qr)
    R[:, 2, 2] = 1 - 2 * (qi ** 2 + qj ** 2)
    return R
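# --- Note (added, not part of the original file) ---
# Quaternions here are ordered (w, x, y, z): q[:, 0] is the scalar part, matching
# quat_to_rot above and rot_to_quat below. A hypothetical round-trip check for a
# quaternion with a clearly positive scalar part:
#
#   q = F.normalize(torch.tensor([[0.9, 0.1, 0.3, 0.2]]), dim=1)
#   assert torch.allclose(rot_to_quat(quat_to_rot(q)), q, atol=1e-5)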
def rot_to_quat(R):
    batch_size, _, _ = R.shape
    q = torch.ones((batch_size, 4), device=R.device)

    R00 = R[:, 0, 0]
    R01 = R[:, 0, 1]
    R02 = R[:, 0, 2]
    R10 = R[:, 1, 0]
    R11 = R[:, 1, 1]
    R12 = R[:, 1, 2]
    R20 = R[:, 2, 0]
    R21 = R[:, 2, 1]
    R22 = R[:, 2, 2]

    q[:, 0] = torch.sqrt(1.0 + R00 + R11 + R22) / 2
    q[:, 1] = (R21 - R12) / (4 * q[:, 0])
    q[:, 2] = (R02 - R20) / (4 * q[:, 0])
    q[:, 3] = (R10 - R01) / (4 * q[:, 0])
    return q


def get_general_sphere_intersections(cam_loc, ray_directions, center, r):
    n_rays = cam_loc.size(0)
    # print(cam_loc.shape, ray_directions.shape)
    cam_loc = cam_loc - center.unsqueeze(0)
    ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3),
                            cam_loc.view(-1, 3, 1)).squeeze(-1)
    under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2)
    intersect_mask = (under_sqrt >= 0).squeeze(-1)  # (n_rays,)
    under_sqrt = under_sqrt[intersect_mask, :]
    ray_cam_dot = ray_cam_dot[intersect_mask, :]
    sphere_intersections = torch.sqrt(under_sqrt) * torch.tensor([-1, 1], device=cam_loc.device).float() - ray_cam_dot
    front_mask = (sphere_intersections > 0).all(dim=-1)
    intersect_mask[intersect_mask.clone()] &= front_mask
    sphere_intersections = sphere_intersections[front_mask, :]
    intersection_normals = cam_loc[intersect_mask, :] + ray_directions[intersect_mask, :] * sphere_intersections[:, :1]
    intersection_points = intersection_normals + center.unsqueeze(0)
    intersection_normals = F.normalize(intersection_normals, dim=1, eps=1e-8)
    return intersection_points, intersection_normals, intersect_mask


def get_sphere_intersections(cam_loc, ray_directions, r=1.0):
    # Input: n_rays x 3 ; n_rays x 3
    # Output: n_rays x 1, n_rays x 1 (close and far)

    ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3),
                            cam_loc.view(-1, 3, 1)).squeeze(-1)
    under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2)

    # sanity check
    if (under_sqrt <= 0).sum() > 0:
        print('BOUNDING SPHERE PROBLEM!')
        exit()

    sphere_intersections = torch.sqrt(under_sqrt) * torch.tensor([-1, 1], device=cam_loc.device).float() - ray_cam_dot
    sphere_intersections = sphere_intersections.clamp_min(0.0)

    return sphere_intersections


def add_depth_noise(depth, depth_mask, scale=1):
    mu = 0.0001125 * depth ** 2 + 0.0048875
    sigma = 0.002925 * depth ** 2 + 0.003325
    noise = torch.randn_like(depth) * sigma + mu
    return (depth + noise * scale) * depth_mask
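# --- Sanity-check sketch (added, not part of the original file) ---
# Minimal example of the ray/bounding-sphere intersection: a camera placed at the
# center of a radius-2 sphere should get near = 0 and far = 2 along every ray.
# The numbers below are illustrative only.
if __name__ == '__main__':
    cam_loc = torch.zeros(4, 3)                             # camera at the sphere center
    ray_directions = F.normalize(torch.randn(4, 3), dim=-1)
    bounds = get_sphere_intersections(cam_loc, ray_directions, r=2.0)
    print(bounds)                                           # expect [[0., 2.], ...] for every ray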