Repository: jingsenzhu/i2-sdf
Branch: main
Commit: 58c9a8241feb
Files: 31
Total size: 217.1 KB
Directory structure:
gitextract_xvem6nqy/
├── .gitignore
├── DATA_CONVENTION.md
├── LICENSE
├── README.md
├── config/
│ ├── synthetic.yml
│ └── synthetic_light_mask.yml
├── data/
│ ├── normalize_cameras.py
│ └── npz_to_blender.py
├── dataset/
│ ├── __init__.py
│ ├── eval_dataset.py
│ └── train_dataset.py
├── environment.yml
├── i2-sdf-dataset-links.csv
├── main_recon.py
├── model/
│ ├── __init__.py
│ ├── eval/
│ │ ├── __init__.py
│ │ └── recon.py
│ ├── network/
│ │ ├── __init__.py
│ │ ├── density.py
│ │ ├── embedder.py
│ │ ├── mlp.py
│ │ └── ray_sampler.py
│ ├── rendering/
│ │ ├── __init__.py
│ │ └── brdf.py
│ └── trainer/
│ ├── __init__.py
│ └── recon.py
└── utils/
├── __init__.py
├── cfgnode.py
├── mesh_util.py
├── plots.py
└── rend_util.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
data/synthetic
exps
__pycache__
archive
tmp*
================================================
FILE: DATA_CONVENTION.md
================================================
# Data Convention
The format of our multi-view dataset is derived from [VolSDF](https://github.com/lioryariv/volsdf/blob/main/DATA_CONVENTION.md).
### Directory Structure
```python
scan<scan_id>/
cameras.npz
image/ -> {:04d}.png # tone-mapped LDR images
depth/ -> {:04d}.exr
normal/ -> {:04d}.exr
mask/ -> {:04d}.png
val/ -> {:04d}.png # validation images (LDR)
hdr/ -> {:04d}.exr # raw HDR images
# the following are optional
light_mask/ -> {:04d}.png # emitter mask images
material/ -> {:04d}_kd.exr, {:04d}_ks.exr, {:04d}_rough.exr # diffuse, specular albedo and roughness
```
Zero-valued areas in the depth and normal maps indicate invalid regions such as windows.
Note that not all surfaces in the scenes use the GGX material model; we'll provide a mask of these invalid areas.
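These zero conventions make it straightforward to build a per-pixel validity mask. A minimal sketch (the helper name `valid_mask` is ours; `depth` is an `(H, W)` array and `normal` an `(H, W, 3)` array loaded as `float32`):
```python
import numpy as np

def valid_mask(depth, normal):
    # depth: (H, W) float32, normal: (H, W, 3) float32
    # zero depth or an all-zero normal marks an invalid pixel (e.g. a window)
    return (depth > 0) & np.any(normal != 0, axis=-1)
```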
### Camera Information
The `cameras.npz` file contains each image's associated camera projection matrix `'world_mat_{i}'` and a normalization matrix `'scale_mat_{i}'`, the same as VolSDF. In addition, we provide a validation set of images for novel view synthesis, whose camera projection matrices are stored as `'val_mat_{i}'`. The validation and training sets share the same normalization matrix.
The normalization matrices may not be readily available in `cameras.npz`. You can manually run `data/normalize_cameras.py` to generate `cameras_normalize.npz`. Since our method requires the entire scene to be within a radius-3 bounding sphere, we suggest normalizing cameras by radius 2.0 or 2.5.
An example of running `normalize_cameras.py`:
```shell
python normalize_cameras.py --id <scan_id> -n <synthetic/real/...> -r 2.0
```
Note that we follow **OpenCV camera coordinate system** (X right, Y downwards, Z into the image plane).
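For reference, the normalization matrix written by `normalize_cameras.py` is a similarity transform: a uniform scale of `max_radius / radius` on the diagonal and the scene center in the translation column, mapping normalized coordinates back into world coordinates. A minimal sketch (the helper name is ours):
```python
import numpy as np

def make_scale_mat(center, max_radius, target_radius):
    # Maps normalized coordinates back to world coordinates:
    # a point at distance `target_radius` from the origin lands at
    # distance `max_radius` from `center` in world space.
    S = np.eye(4, dtype=np.float32)
    S[0, 0] = S[1, 1] = S[2, 2] = max_radius / target_radius
    S[:3, 3] = center
    return S
```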
### Dataset Format Conversion
If you want to convert the dataset to the NeRF blender format, run `npz_to_blender.py`:
```sh
python npz_to_blender.py --root /path/to/dataset
```
The script will automatically scale all pose matrices to fit within a `[-1, 1]` bounding box.
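The scaling mirrors `get_offset` and `scale_pose` in `npz_to_blender.py`: camera centers are recentered around the origin and divided by the largest half-extent. A minimal self-contained sketch (the helper name is ours):
```python
import numpy as np

def fit_unit_box(poses):
    # poses: list of 4x4 camera-to-world matrices
    eyes = np.stack([p[:3, 3] for p in poses])               # camera centers
    offset = -(eyes.max(axis=0) + eyes.min(axis=0)) / 2      # recenter around origin
    scale = (eyes.max(axis=0) - eyes.min(axis=0)).max() / 2  # largest half-extent
    scaled = []
    for p in poses:
        q = p.copy()
        q[:3, 3] = (q[:3, 3] + offset) / scale
        scaled.append(q)
    return scaled
```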
### About Real Dataset
Our real dataset comes from [Inria](https://repo-sam.inria.fr/fungraph/deep-indoor-relight/) and [Scalable-NISR](https://xchaowu.github.io/papers/scalable-nisr/), with estimated depth from MVS tools and manually-labeled light masks (2 living room scenes). All depths have an absolute scale, so no shift estimation is needed (unlike MonoSDF). All camera calibrations and depths are provided by the authors of [Scalable-NISR](https://xchaowu.github.io/papers/scalable-nisr/). We thank them for providing the datasets.
Normals are not provided in the real dataset; we find reconstruction remains plausible without normal supervision in these scenes. Of course, you can estimate normals with any method if you want to enable normal supervision.
### About EXR format
We suggest using OpenCV to load an `.exr` format `float32` image:
```python
import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' # Enable OpenCV support for EXR
import cv2
...
im = cv2.imread(im_path, -1) # im will be a numpy.float32 array of shape (H, W, C)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) # cv2 loads images in BGR order, convert to RGB
```
We suggest using [tev](https://github.com/Tom94/tev) to preview HDR `.exr` images conveniently.
### About Mesh
Due to copyright issues, we could not release the original 3D mesh of our synthetic scenes. Instead, we'll provide a point cloud sampled from the GT mesh to enable 3D reconstruction evaluations.
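Evaluation against a sampled point cloud is commonly done with a point-set metric such as the symmetric Chamfer distance. A brute-force sketch (a generic helper for illustration, not this repository's evaluation code):
```python
import numpy as np

def chamfer_distance(pred, gt):
    # pred: (N, 3), gt: (M, 3); brute-force O(N*M), fine for small clouds
    d = np.linalg.norm(pred[:, None, :] - gt[None, :, :], axis=-1)
    return d.min(axis=1).mean() + d.min(axis=0).mean()
```
For large point clouds, a spatial index (e.g. a KD-tree) replaces the dense distance matrix.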
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 Jingsen Zhu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
**News**
- `04/04/2023` dataset preview release: 2 synthetic scenes available
- `15/04/2023` code release: 3D reconstruction and novel view synthesis part
- `21/04/2023` dataset release: real data
**TODO**
- [ ] Full dataset release
- [x] Code release for 3D reconstruction and novel view synthesis
- [ ] Code release for intrinsic decomposition and scene editing
**Dataset released**
- Synthetic: `kitchen_0`, `bedroom_relight_0`, `bedroom_0`, `bedroom_1`, `bedroom_relight_1`, `diningroom_0`, `livingroom_0`, `livingroom_1`, more scenes to be released
- Real: `inria_livingroom`, `nisr_livingroom`, `nisr_coffee_shop_0`, `nisr_coffee_shop_1`, release complete
# I<sup>2</sup>-SDF: Intrinsic Indoor Scene Reconstruction and Editing via Raytracing in Neural SDFs (CVPR 2023)
### [Project Page](https://jingsenzhu.github.io/i2-sdf/) | [Paper](https://arxiv.org/abs/2303.07634) | [Dataset](i2-sdf-dataset-links.csv)
## Setup
### Installation
```
conda env create -f environment.yml
conda activate i2sdf
```
### Data preparation
Download our synthetic dataset and extract it into `data/synthetic`. If you want to run on your own dataset, we provide a brief introduction to our data convention [here](DATA_CONVENTION.md).
## Dataset
We provide a high-quality synthetic indoor scene multi-view dataset, with ground truth camera pose and geometry annotations. See [HERE](DATA_CONVENTION.md) for data conventions. Click [HERE](https://mega.nz/folder/jdhDnTqL#Ija678SU2Va_JJOiwqmdEg) to download.
## 3D Reconstruction and Novel View Synthesis
### Training
```
python main_recon.py --conf config/<config_file>.yml --scan_id <scan_id> -d <gpu_id> -v <version>
```
Note: `config/synthetic.yml` does not include the light mask network, while `config/synthetic_light_mask.yml` does.
If you run out of GPU memory, try reducing `split_n_pixels` (i.e. the validation batch size) and `batch_size` in the config. The default parameters were tuned on an RTX A6000 (48GB). For an RTX 3090 (24GB), try setting `split_n_pixels` to 5000.
### Evaluation
#### Novel view synthesis
```
python main_recon.py --conf config/<config_file>.yml --scan_id <scan_id> -d <gpu_id> -v <version> --test [--is_val] [--full]
```
The optional flag `--is_val` evaluates on the validation set instead of the training set; `--full` produces full-resolution rendered images without downsampling.
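Rendered views are typically compared against ground truth using image metrics such as PSNR. A minimal sketch of one such metric (a generic helper for illustration, not necessarily what this repository reports), assuming images normalized to `[0, max_val]`:
```python
import numpy as np

def psnr(pred, gt, max_val=1.0):
    # pred, gt: float arrays in [0, max_val]; higher PSNR = closer match
    mse = np.mean((pred - gt) ** 2)
    return 10.0 * np.log10(max_val ** 2 / mse)
```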
#### View Interpolation
```
python main_recon.py --conf config/<config_file>.yml --scan_id <scan_id> -d <gpu_id> -v <version> --test --test_mode interpolate --inter_id <view_id_0> <view_id_1> [--full]
```
Generates a view interpolation video between two views. Requires `ffmpeg` to be installed.
The number of frames and the frame rate of the video can be specified via command-line options.
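Internally, `InterpolateDataset` (in `dataset/eval_dataset.py`) eases the camera position with a sinusoidal weight and slerps the rotations between the two poses. The position schedule can be sketched as:
```python
import numpy as np

def ease_ratio(i, num_frames):
    # sinusoidal ease-in-out: 0 at frame 0, 1 at frame num_frames
    return np.sin(((i / num_frames) - 0.5) * np.pi) * 0.5 + 0.5

def lerp_translation(t0, t1, ratio):
    # camera positions are linearly blended with the eased ratio
    return (1 - ratio) * t0 + ratio * t1
```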
#### Mesh Extraction
```
python main_recon.py --conf config/<config_file>.yml --scan_id <scan_id> -d <gpu_id> -v <version> --test --test_mode mesh
```
## Intrinsic Decomposition and Scene Editing
**Brewing🍺, code coming soon.**
## Citation
If you find our work useful, please consider citing:
```
@inproceedings{zhu2023i2sdf,
title = {I$^2$-SDF: Intrinsic Indoor Scene Reconstruction and Editing via Raytracing in Neural SDFs},
author = {Jingsen Zhu and Yuchi Huo and Qi Ye and Fujun Luan and Jifan Li and Dianbing Xi and Lisha Wang and Rui Tang and Wei Hua and Hujun Bao and Rui Wang},
booktitle = {CVPR},
year = {2023}
}
```
## Acknowledgement
- This repository is built upon [PyTorch Lightning](https://lightning.ai/).
- Thanks to Lior Yariv for her excellent work [VolSDF](https://lioryariv.github.io/volsdf/).
- Thanks to [Scalable-NISR](https://xchaowu.github.io/papers/scalable-nisr/) team for providing their real-world dataset.
================================================
FILE: config/synthetic.yml
================================================
train:
expname: synthetic
learning_rate: 5.0e-4
steps: 200000
checkpoint_freq: 10000
plot_freq: 500
split_n_pixels: 12000
batch_size: 1600
pdf_criterion: DEPTH
plot:
plot_nimgs: 1
grid_boundary: [-1.5, 1.5]
loss:
eikonal_weight: 0.1
smooth_weight: 0.01
smooth_iter: 150000
depth_weight: 0.1
normal_weight: 0.05
bubble_weight: 0.5
min_bubble_iter: 50000
max_bubble_iter: 150000
dataset:
data_dir: synthetic
img_res: [480, 640]
downsample: 2
pdf_prune: 0.05
pdf_max: 0.2
model:
feature_vector_size: 256
scene_bounding_sphere: 3.0
implicit_network:
d_in: 3
d_out: 1
dims: [ 256, 256, 256, 256, 256, 256, 256, 256 ]
geometric_init: True
bias: 0.6
skip_in: [4]
weight_norm: True
embed_type: 'positional'
multires: 6
rendering_network:
# mode: idr
# d_in: 9
# We found no practical difference in quality between 'nerf' and 'idr' modes;
# 'nerf' mode is slightly faster
mode: nerf
d_in: 3
d_out: 3
dims: [ 256, 256, 256, 256 ]
weight_norm: True
embed_type: 'positional'
multires: 4
density:
params_init:
beta: 0.1
beta_min: 0.0001
ray_sampler:
near: 0.0
N_samples: 64
N_samples_eval: 128
N_samples_extra: 32
eps: 0.1
beta_iters: 10
max_total_iters: 5
N_samples_inverse_sphere: 32
add_tiny: 1.0e-6
================================================
FILE: config/synthetic_light_mask.yml
================================================
train:
expname: synthetic_light
learning_rate: 5.0e-4
steps: 200000
checkpoint_freq: 10000
plot_freq: 500
split_n_pixels: 12000
batch_size: 1600
pdf_criterion: DEPTH
plot:
plot_nimgs: 1
grid_boundary: [-1.5, 1.5]
loss:
eikonal_weight: 0.1
smooth_weight: 0.01
smooth_iter: 150000
depth_weight: 0.1
normal_weight: 0.05
bubble_weight: 0.5
light_mask_weight: 0.5
min_bubble_iter: 50000
max_bubble_iter: 150000
dataset:
data_dir: synthetic
img_res: [480, 640]
downsample: 2
pdf_prune: 0.05
pdf_max: 0.2
model:
feature_vector_size: 256
scene_bounding_sphere: 3.0
implicit_network:
d_in: 3
d_out: 1
dims: [ 256, 256, 256, 256, 256, 256 ]
geometric_init: True
bias: 0.6
skip_in: [3]
weight_norm: True
embed_type: 'positional'
multires: 6
rendering_network:
mode: nerf
d_in: 3
d_out: 3
dims: [ 256, 256, 256 ]
weight_norm: True
embed_type: 'positional'
multires: 4
light_network:
dims: [ 128 ]
weight_norm: True
density:
params_init:
beta: 0.1
beta_min: 0.0001
ray_sampler:
near: 0.0
N_samples: 64
N_samples_eval: 128
N_samples_extra: 32
eps: 0.1
beta_iters: 10
max_total_iters: 5
N_samples_inverse_sphere: 32
add_tiny: 1.0e-6
================================================
FILE: data/normalize_cameras.py
================================================
import cv2
import numpy as np
import argparse
from copy import deepcopy
def get_center_point(num_cams,cameras):
A = np.zeros((3 * num_cams, 3 + num_cams))
b = np.zeros((3 * num_cams, 1))
camera_centers=np.zeros((3,num_cams))
for i in range(num_cams):
P0 = cameras['world_mat_%d' % i][:3, :]
K, R, c = cv2.decomposeProjectionMatrix(P0)[:3]  # decompose once instead of three times
c = c / c[3]
camera_centers[:,i]=c[:3].flatten()
# v = np.linalg.inv(K) @ np.array([800, 600, 1])
# v = v / np.linalg.norm(v)
v=R[2,:]
A[3 * i:(3 * i + 3), :3] = np.eye(3)
A[3 * i:(3 * i + 3), 3 + i] = -v
b[3 * i:(3 * i + 3)] = c[:3]
soll = np.linalg.pinv(A) @ b
return soll, camera_centers
def normalize_cameras(original_cameras_filename,output_cameras_filename,num_of_cameras,radius,convert_coord):
cameras = np.load(original_cameras_filename)
if num_of_cameras==-1:
all_files=cameras.files
maximal_ind=0
for field in all_files:
if 'val' not in field:
maximal_ind=np.maximum(maximal_ind,int(field.split('_')[-1]))
num_of_cameras=maximal_ind+1
soll, camera_centers = get_center_point(num_of_cameras, cameras)
center = soll[:3].flatten()
max_radius = np.linalg.norm((center[:, np.newaxis] - camera_centers), axis=0).max() * 1.1
normalization = np.eye(4).astype(np.float32)
normalization[0, 3] = center[0]
normalization[1, 3] = center[1]
normalization[2, 3] = center[2]
normalization[0, 0] = max_radius / radius
normalization[1, 1] = max_radius / radius
normalization[2, 2] = max_radius / radius
cameras_new = {}
cameras_new = deepcopy(dict(cameras))
for i in range(num_of_cameras):
cameras_new['scale_mat_%d' % i] = normalization
# cameras_new['world_mat_%d' % i] = cameras['world_mat_%d' % i].copy()
# if ('val_mat_%d' % i) in cameras:
# cameras_new['val_mat_%d' % i] = cameras['val_mat_%d' % i].copy()
def opengl2opencv(P):
out = cv2.decomposeProjectionMatrix(P[:3,:])
K, R, t = out[0:3]
K = K/K[2,2]
intrinsics = np.eye(4, dtype=np.float32)
intrinsics[:3, :3] = K
t = (t[:3] / t[3]).squeeze()
w2c = np.eye(4, dtype=np.float32)
w2c[:3,:3] = R
w2c[:3,3] = -R @ t
T = np.diag([1, -1, -1, 1])
w2c = T @ w2c
return intrinsics @ w2c
if convert_coord:
cameras_new['world_mat_%d' % i] = opengl2opencv(cameras_new['world_mat_%d' % i])
if ('val_mat_%d' % i) in cameras_new:
cameras_new['val_mat_%d' % i] = opengl2opencv(cameras_new['val_mat_%d' % i])
# cameras_new['world_mat_%d' % i] = T @ cameras_new['world_mat_%d' % i]
np.savez(output_cameras_filename, **cameras_new)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Normalizing cameras')
parser.add_argument('-i', '--input_cameras_file', type=str, default="cameras.npz",
help='the input cameras file')
parser.add_argument('-o', '--output_cameras_file', type=str, default="cameras_normalize.npz",
help='the output cameras file')
parser.add_argument('--id', type=int, nargs='?')
parser.add_argument('-n', '--name', type=str, default='synthetic')
parser.add_argument('--number_of_cams',type=int, default=-1,
help='Number of cameras, if -1 use all')
parser.add_argument('-r', '--radius', type=float, default=2.0)
parser.add_argument('-c', '--convert_coord', action='store_true')
args = parser.parse_args()
if args.id:
args.input_cameras_file = f'{args.name}/scan{args.id}/cameras.npz'
args.output_cameras_file = f'{args.name}/scan{args.id}/cameras_normalize.npz'
normalize_cameras(args.input_cameras_file, args.output_cameras_file, args.number_of_cams, args.radius, args.convert_coord)
================================================
FILE: data/npz_to_blender.py
================================================
"""
Transform npz-formatted scenes to json-formatted scene (NeRF blender format)
Scale all poses to fit in a [-1, 1] box
"""
import copy
import json
import os
import cv2
import numpy as np
from tqdm import tqdm
import argparse
def to16b(img):
img = img.clip(0, 1) * 65535
return img.astype(np.uint16)
def opencv_to_gl(pose):
mat = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
pose[:3, :3] = pose[:3, :3] @ mat
return pose
def get_offset(poses):
eyes = np.stack([pose[:3, 3] for pose in poses])
scale = eyes.max(axis=0) - eyes.min(axis=0)
print(f'scale : {scale}')
offset = -(eyes.max(axis=0) + eyes.min(axis=0)) / 2
print(f'offset : {offset}')
return scale / 2, offset
def scale_pose(pose, scale, offset):
pose[:3, 3] = (pose[:3, 3] + offset) / scale
# print(pose[:3, 3])
return pose.tolist()
def load_K_Rt_from_P(filename, P=None):
if P is None:
lines = open(filename).read().splitlines()
if len(lines) == 4:
lines = lines[1:]
lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
P = np.asarray(lines).astype(np.float32).squeeze()
out = cv2.decomposeProjectionMatrix(P)
K = out[0]
R = out[1]
t = out[2]
K = K / K[2, 2]
intrinsics = np.eye(4)
intrinsics[:3, :3] = K
pose = np.eye(4, dtype=np.float32)
pose[:3, :3] = R.transpose()
pose[:3, 3] = (t[:3] / t[3])[:, 0]
return intrinsics, pose
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--root', required=True)
parser.add_argument('--scale', action='store_true')
args = parser.parse_args()
os.chdir(os.path.join(args.root))
image_dir = 'image'
n_images = len(os.listdir(image_dir))
val_dir = 'val'
n_val = len(os.listdir(val_dir))
os.makedirs('depths', exist_ok=True)
cam_file = 'cameras.npz'
camera_dict = np.load(cam_file)
world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
val_mats = [camera_dict['val_mat_%d' % idx].astype(np.float32) for idx in range(n_val)]
intrinsics_all = []
pose_all = []
for mat in world_mats + val_mats:
P = mat
P = P[:3, :4]
intrinsics, pose = load_K_Rt_from_P(None, P)
intrinsics_all.append(intrinsics)
pose_all.append(opencv_to_gl(pose))
train_json = dict()
train_json['fl_y'] = intrinsics[1][1]
train_json['h'] = int(intrinsics[1, 2] * 2)
train_json['fl_x'] = intrinsics[0][0]
train_json['w'] = int(intrinsics[0, 2] * 2)
if args.scale:
scale, offset = get_offset(pose_all)
# train_json['enable_depth_loading'] = True
# train_json['integer_depth_scale'] = 1 / 65535
train_json['frames'] = []
test_json = copy.deepcopy(train_json)
test_json['enable_depth_loading'] = False
for i in tqdm(range(n_images)):
frames = train_json['frames']
if args.scale:
depth = cv2.imread(os.path.join('depth', '{:04d}.exr'.format(i)), -1)
cv2.imwrite(os.path.join('depths', '{:04d}.exr'.format(i)), depth / scale.max())
pose = pose_all[i].tolist() if not args.scale else scale_pose(pose_all[i], scale.max(), offset)
frame = {
'file_path': f'./image/{i:04d}',
'depth_path': f'./depths/{i:04d}.exr' if args.scale else f'./depth/{i:04d}.exr',
'transform_matrix': pose
}
frames.append(frame)
for i in tqdm(range(n_val)):
frames = test_json['frames']
pose = pose_all[i + n_images].tolist() if not args.scale else scale_pose(pose_all[i + n_images], scale.max(), offset)
frame = {
'file_path': f'./val/{i:04d}',
'transform_matrix': pose
}
frames.append(frame)
with open('transforms_train.json', 'w') as f:
json.dump(train_json, f, indent=4)
with open('transforms_test.json', 'w') as f:
json.dump(test_json, f, indent=4)
with open('transforms_val.json', 'w') as f:
json.dump(test_json, f, indent=4)
if __name__ == '__main__':
main()
================================================
FILE: dataset/__init__.py
================================================
from .train_dataset import *
from .eval_dataset import *
================================================
FILE: dataset/eval_dataset.py
================================================
from copy import deepcopy
import os
import torch
import numpy as np
from torch.utils.data import Dataset
import utils.plots as plt
import torch.nn.functional as F
import utils
from utils import rend_util
from tqdm.contrib import tzip, tenumerate
from scipy.spatial.transform import Rotation as Rot
from scipy.spatial.transform import Slerp
import cv2
class GridDataset(Dataset):
"""
Used for mesh extraction
"""
def __init__(self, points, xyz) -> None:
super().__init__()
self.grid_points = points
self.xyz = xyz
def __len__(self):
return self.grid_points.size(0)
def __getitem__(self, index):
return self.grid_points[index]
class PlotDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir,
plot_nimgs,
scan_id=0,
is_val=False,
data=None,
is_hdr=False,
indices=None,
use_lmask=False,
**kwargs
):
self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id))
val_dir = '{0}/val'.format(self.instance_dir)
is_val = is_val and os.path.exists(val_dir)
lmask_dir = '{0}/light_mask'.format(self.instance_dir)
self.use_lmask = use_lmask and os.path.exists(lmask_dir)
if is_val:
print("[INFO] Validation set detected")
if is_val or data is None:
assert os.path.exists(self.instance_dir), "Data directory is empty"
if is_val:
image_dir = val_dir
elif is_hdr:
image_dir = '{0}/hdr'.format(self.instance_dir)
else:
image_dir = '{0}/image'.format(self.instance_dir)
image_paths = sorted(utils.glob_imgs(image_dir))
if indices is not None:
print(f"[INFO] Selecting indices: {indices}")
image_paths = [image_paths[i] for i in indices]
self.n_images = len(image_paths)
self.indices = indices if indices is not None else list(range(self.n_images))
self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir)
camera_dict = np.load(self.cam_file)
scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in self.indices] if not is_val else [camera_dict['scale_mat_0'].astype(np.float32)] * len(self.indices)
world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in self.indices] if not is_val else [camera_dict['val_mat_%d' % idx].astype(np.float32) for idx in self.indices]
self.intrinsics_all = []
self.pose_all = []
for scale_mat, world_mat in zip(scale_mats, world_mats):
P = world_mat @ scale_mat
P = P[:3, :4]
intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)
self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
self.pose_all.append(torch.from_numpy(pose).float())
self.intrinsics_all = torch.stack(self.intrinsics_all, 0)
self.pose_all = torch.stack(self.pose_all, 0)
self.rgb_images = []
for path in image_paths:
rgb = rend_util.load_rgb(path, is_hdr=is_hdr)
self.img_res = [rgb.shape[1], rgb.shape[2]]
rgb = rgb.reshape(3, -1).transpose(1, 0)
self.rgb_images.append(torch.from_numpy(rgb).float())
self.rgb_images = torch.stack(self.rgb_images, 0)
if self.use_lmask:
self.lightmask_images = []
lmask_paths = sorted(utils.glob_imgs(lmask_dir))
for path in lmask_paths:
lmask = rend_util.load_mask(path)
lmask = lmask.reshape(-1, 1)
self.lightmask_images.append(torch.from_numpy(lmask).float())
self.lightmask_images = torch.stack(self.lightmask_images, 0)
self.total_pixels = self.rgb_images.size(1)
else:
self.intrinsics_all = data['intrinsics']
self.pose_all = data['pose']
self.rgb_images = data['rgb']
self.n_images = len(self.rgb_images)
self.img_res = [data['img_res'][0], data['img_res'][1]]
self.total_pixels = self.img_res[0] * self.img_res[1]
if 'light_mask' in data:
self.lightmask_images = data['light_mask']
self.use_lmask = True
if (scale := kwargs.get('downsample', 1)) > 1:
old_img_res = deepcopy(self.img_res)
self.img_res[0] //= scale
self.img_res[1] //= scale
self.total_pixels = self.img_res[0] * self.img_res[1]
self.rgb_images = self.rgb_images.transpose(1, 2).reshape(-1, 3, old_img_res[0], old_img_res[1])
self.rgb_images = F.interpolate(self.rgb_images, self.img_res, mode='area')
self.rgb_images = self.rgb_images.reshape(-1, 3, self.total_pixels).transpose(1, 2)
if self.use_lmask:
self.lightmask_images = self.lightmask_images.transpose(1, 2).reshape(-1, 1, old_img_res[0], old_img_res[1])
self.lightmask_images = F.interpolate(self.lightmask_images, self.img_res, mode='area')
self.lightmask_images = self.lightmask_images.reshape(-1, 1, self.total_pixels).transpose(1, 2)
self.intrinsics_all = self.intrinsics_all.clone()
self.intrinsics_all[:,0,0] /= scale
self.intrinsics_all[:,1,1] /= scale
self.intrinsics_all[:,0,2] /= scale
self.intrinsics_all[:,1,2] /= scale
print(f"[INFO] Plot image size: {self.img_res[1]}x{self.img_res[0]}, {self.total_pixels} pixels in total")
if plot_nimgs == -1:
self.plot_nimgs = self.n_images
else:
self.plot_nimgs = min(plot_nimgs, self.n_images)
self.shuffle = kwargs.get('shuffle', True)
if self.shuffle:
self.shuffle_plot_index()
def shuffle_plot_index(self):
if self.shuffle:
self.plot_index = torch.randperm(self.n_images)[:self.plot_nimgs]
def __len__(self):
return self.plot_nimgs
def get_uv(self):
uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()
uv = uv.reshape(2, -1).transpose(1, 0)
return uv
def __getitem__(self, idx):
if self.shuffle:
idx = self.plot_index[idx]
uv = self.get_uv()
sample = {
"uv": uv,
"intrinsics": self.intrinsics_all[idx],
"pose": self.pose_all[idx]
}
ground_truth = {
"rgb": self.rgb_images[idx]
}
if self.use_lmask:
ground_truth['light_mask'] = self.lightmask_images[idx]
return idx, sample, ground_truth
def collate_fn(self, batch_list):
# takes a list of (idx, sample, ground_truth) tuples and returns them batched as dicts
batch_list = zip(*batch_list)
all_parsed = []
for entry in batch_list:
if type(entry[0]) is dict:
# make them all into a new dict
ret = {}
for k in entry[0].keys():
ret[k] = torch.stack([obj[k] for obj in entry])
all_parsed.append(ret)
else:
all_parsed.append(torch.LongTensor(entry))
return tuple(all_parsed)
class InterpolateDataset(torch.utils.data.Dataset):
"""
View interpolation: specify 2 view ids from training set and generate a video moving between them
"""
def __init__(self,
data_dir,
# img_res,
id0,
id1,
num_frames=60,
scan_id=0,
**kwargs
):
self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id))
assert os.path.exists(self.instance_dir), "Data directory is empty"
image_dir = '{0}/image'.format(self.instance_dir)
im = cv2.imread(f"{image_dir}/{id0:04d}.png")
h, w, _ = im.shape
self.img_res = [h, w]
self.total_pixels = h * w
self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir)
camera_dict = np.load(self.cam_file)
P0 = camera_dict['world_mat_%d' % id0].astype(np.float32) @ camera_dict['scale_mat_%d' % id0].astype(np.float32)
P1 = camera_dict['world_mat_%d' % id1].astype(np.float32) @ camera_dict['scale_mat_%d' % id1].astype(np.float32)
P0 = P0[:3,:]
P1 = P1[:3,:]
K, pose0 = rend_util.load_K_Rt_from_P(None, P0)
_, pose1 = rend_util.load_K_Rt_from_P(None, P1)
rots = Rot.from_matrix(np.stack([pose0[:3,:3].T, pose1[:3,:3].T]))
slerp = Slerp([0, 1], rots)
if (scale := kwargs.get('downsample', 1)) > 1:
self.img_res[0] = self.img_res[0] // scale
self.img_res[1] = self.img_res[1] // scale
self.total_pixels = self.img_res[0] * self.img_res[1]
K[0,0] /= scale
K[1,1] /= scale
K[0,2] /= scale
K[1,2] /= scale
self.intrinsics = torch.from_numpy(K).float()
self.pose_all = []
for i in range(num_frames):
ratio = np.sin(((i / num_frames) - 0.5) * np.pi) * 0.5 + 0.5
t = (1 - ratio) * pose0[:3,3] + ratio * pose1[:3,3]
R = slerp(ratio).as_matrix()
pose = np.eye(4, dtype=np.float32)
pose[:3,3] = t
pose[:3,:3] = R.T
self.pose_all.append(torch.from_numpy(pose).float())
self.pose_all = torch.stack(self.pose_all)
self.n_frames = num_frames
def __len__(self):
return self.n_frames
def __getitem__(self, idx):
uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()
uv = uv.reshape(2, -1).transpose(1, 0)
sample = {
"uv": uv,
"intrinsics": self.intrinsics,
"pose": self.pose_all[idx]
}
return idx, sample
def collate_fn(self, batch_list):
# takes a list of (idx, sample) tuples and returns them batched as dicts
batch_list = zip(*batch_list)
all_parsed = []
for entry in batch_list:
if type(entry[0]) is dict:
# make them all into a new dict
ret = {}
for k in entry[0].keys():
ret[k] = torch.stack([obj[k] for obj in entry])
all_parsed.append(ret)
else:
all_parsed.append(torch.LongTensor(entry))
return tuple(all_parsed)
class RelightDataset(PlotDataset):
def __init__(self, data_dir, edit_cfg, scan_id=0, is_val=False, **kwargs):
super().__init__(data_dir, 1, scan_id, is_val, None, False, [edit_cfg['index']], True, **kwargs)
self.edit_mask = 'mask' in edit_cfg
if self.edit_mask:
self.mask = rend_util.load_mask(edit_cfg['mask']).astype(np.float32)
mh, mw = self.mask.shape
if mh != self.img_res[0] or mw != self.img_res[1]:
self.mask = cv2.resize(self.mask, (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_AREA)
self.mask = (self.mask > 0.5)
self.mask = torch.from_numpy(self.mask).float().flatten()
if 'normal' in edit_cfg:
self.loadattr(edit_cfg, 'normal', 0)
self.normal = self.normal.reshape(-1, 3)
self.normal = F.normalize(self.normal, dim=-1, eps=1e-6)
if 'rough' in edit_cfg:
self.loadattr(edit_cfg, 'rough', 1)
self.rough = self.rough.reshape(-1, 1)
if 'kd' in edit_cfg:
self.loadattr(edit_cfg, 'kd', 2)
self.kd = self.kd.reshape(-1, 3)
if 'ks' in edit_cfg:
self.loadattr(edit_cfg, 'ks', 2)
self.ks = self.ks.reshape(-1, 3)
self.uv = self.get_uv()
def loadattr(self, edit_cfg, attr, mode=0):
if mode == 0:
im = rend_util.load_normal(edit_cfg[attr])
elif mode == 1:
im = cv2.imread(edit_cfg[attr], -1)
if len(im.shape) == 3:
im = im[:,:,-1]
else:
im = rend_util.load_rgb(edit_cfg[attr]).transpose(1, 2, 0)
h, w = im.shape[:2]
if h != self.img_res[0] or w != self.img_res[1]:
im = cv2.resize(im, (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_AREA)
setattr(self, attr, torch.from_numpy(im).float())
def __len__(self):
return self.total_pixels
def __getitem__(self, idx):
sample = {
"uv": self.uv[idx].unsqueeze(0),
"intrinsics": self.intrinsics_all[0],
"pose": self.pose_all[0]
}
ground_truth = {
"rgb": self.rgb_images[0][idx],
'light_mask': self.lightmask_images[0][idx]
# 'edit_mask': self.edit_mask[idx]
}
if self.edit_mask:
ground_truth['mask'] = self.mask[idx]
if hasattr(self, 'normal'):
ground_truth['normal'] = self.normal[idx]
if hasattr(self, 'rough'):
ground_truth['rough'] = self.rough[idx]
if hasattr(self, 'kd'):
ground_truth['kd'] = self.kd[idx]
if hasattr(self, 'ks'):
ground_truth['ks'] = self.ks[idx]
return idx, sample, ground_truth
class RelightVideoDataset(PlotDataset):
def __init__(self, data_dir, edit_cfg, scan_id=0, is_val=False, **kwargs):
self.n_frames = edit_cfg['n_frames']
self.img_idx = edit_cfg['index']
super().__init__(data_dir, 1, scan_id, is_val, None, False, [edit_cfg['index']] * self.n_frames, True, **kwargs)
self.edit_mask = 'mask' in edit_cfg
if self.edit_mask:
self.mask = rend_util.load_mask(edit_cfg['mask']).astype(np.float32)
mh, mw = self.mask.shape
if mh != self.img_res[0] or mw != self.img_res[1]:
self.mask = cv2.resize(self.mask, (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_AREA)
self.mask = (self.mask > 0.5)
self.mask = torch.from_numpy(self.mask).float().flatten()
self.uv = self.get_uv()
def __len__(self):
return self.n_frames
def __getitem__(self, idx):
sample = {
"uv": self.uv,
"intrinsics": self.intrinsics_all[idx],
"pose": self.pose_all[idx]
}
ground_truth = {
"rgb": self.rgb_images[idx],
'light_mask': self.lightmask_images[idx]
# 'edit_mask': self.edit_mask[idx]
}
if self.edit_mask:
ground_truth['mask'] = self.mask
return idx, sample, ground_truth
================================================
FILE: dataset/train_dataset.py
================================================
import json
import os
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset
import torch.nn.functional as F
import utils
from utils import rend_util
from tqdm.contrib import tzip, tenumerate
from scipy.spatial.transform import Rotation as Rot
from scipy.spatial.transform import Slerp
class ReconDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir,
scan_id=0,
use_mask=False,
use_depth=False,
use_normal=False,
use_bubble=False,
use_lightmask=False,
is_hdr=False,
**kwargs
):
self.sampling_idx = slice(None)
self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id))
assert os.path.exists(self.instance_dir), f"Data directory {self.instance_dir} does not exist"
print(f"[INFO] Loading data from {self.instance_dir}")
image_dir = '{0}/image'.format(self.instance_dir) if not is_hdr else '{0}/hdr'.format(self.instance_dir)
self.is_hdr = is_hdr
if is_hdr:
print("[INFO] Using HDR image")
image_paths = sorted(utils.glob_imgs(image_dir))
self.n_images = len(image_paths)
self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir)
camera_dict = np.load(self.cam_file)
scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
self.intrinsics_all = []
self.pose_all = []
for scale_mat, world_mat in zip(scale_mats, world_mats):
P = world_mat @ scale_mat
P = P[:3, :4]
intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)
self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
self.pose_all.append(torch.from_numpy(pose).float())
self.intrinsics_all = torch.stack(self.intrinsics_all, 0)
self.pose_all = torch.stack(self.pose_all, 0)
self.rgb_images = []
for path in image_paths:
rgb = rend_util.load_rgb(path, is_hdr=is_hdr)
self.img_res = [rgb.shape[1], rgb.shape[2]]
rgb = rgb.reshape(3, -1).transpose(1, 0)
self.rgb_images.append(torch.from_numpy(rgb).float())
self.rgb_images = torch.stack(self.rgb_images, 0)
self.total_pixels = self.rgb_images.size(1)
print(f"[INFO] image size: {self.img_res[1]}x{self.img_res[0]}, {self.total_pixels} pixels in total")
uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
uv = np.flip(uv, axis=0).copy()
self.uv = torch.from_numpy(uv).float()
self.uv = self.uv.reshape(2, -1).transpose(1, 0) # (h*w, 2)
mask_dir = '{0}/mask'.format(self.instance_dir)
self.use_mask = use_mask
if self.use_mask:
if os.path.exists(mask_dir):
mask_paths = sorted(utils.glob_imgs(mask_dir))
# assert len(mask_paths) == self.n_images
self.mask_images = []
for path in mask_paths:
mask = rend_util.load_mask(path)
mask = mask.reshape(-1, 1)
self.mask_images.append(torch.from_numpy(mask).float())
self.mask_images = torch.stack(self.mask_images, 0)
else:
print("[INFO] No mask images found, using all-ones masks by default")
self.mask_images = torch.ones(self.n_images, self.total_pixels, 1, dtype=torch.float)
lmask_dir = '{0}/light_mask'.format(self.instance_dir)
self.use_lightmask = use_lightmask and os.path.exists(lmask_dir)
if self.use_lightmask:
lmask_paths = sorted(utils.glob_imgs(lmask_dir))
self.lightmask_images = []
for path in lmask_paths:
lmask = rend_util.load_mask(path)
lmask = lmask.reshape(-1, 1)
self.lightmask_images.append(torch.from_numpy(lmask).float())
self.lightmask_images = torch.stack(self.lightmask_images, 0)
depth_dir = '{0}/depth'.format(self.instance_dir)
self.use_depth = use_depth and os.path.exists(depth_dir)
self.use_bubble = use_bubble and os.path.exists(depth_dir)
if self.use_depth or self.use_bubble:
self.depth_images = []
self.depth_masks = []
self.pointcloud = [] # pointcloud for bubble loss, unprojected from depth and poses
self.pointlinks = [] # maps each pixel index to its point-cloud index; -1 where the pixel has no valid depth
self.pixlinks = [] # maps each point-cloud index back to its global pixel index
depth_paths = sorted(utils.glob_depths(depth_dir))
n_points = 0
if kwargs.get('noise_scale', 0.0) > 0:
print(f"[INFO] Ablation study: using noise scale {kwargs.get('noise_scale')}")
for scale_mat, depth_path, intrinsics, pose, i in tzip(scale_mats, depth_paths, self.intrinsics_all, self.pose_all, range(len(self.pose_all))):
depth = rend_util.load_depth(depth_path)
depth = torch.from_numpy(depth.reshape(-1)).float()
depth = depth / scale_mat[2,2]
valid_indices = torch.where((depth > 1e-3) & (depth < 6))[0]
if i == 0 and scale_mat[2,2] != 1:
print(f"[INFO] Depth scaled by {scale_mat[2,2]:.2f}")
depth_mask = torch.zeros([self.total_pixels], dtype=torch.bool)
depth_mask[valid_indices] = True
# if self.use_depth:
if (noise_scale := kwargs.get('noise_scale', 0.0)) > 0:
depth = rend_util.add_depth_noise(depth, depth_mask.float(), noise_scale)
self.depth_images.append(depth)
self.depth_masks.append(depth_mask)
if self.use_bubble:
pointlink = -torch.ones([self.total_pixels], dtype=torch.long)
pointlink[depth_mask] = torch.arange(0, len(valid_indices), dtype=torch.long) + n_points
pixlink = torch.arange(i * self.total_pixels, (i + 1) * self.total_pixels, dtype=torch.long)[depth_mask]
n_points += len(valid_indices)
self.pointlinks.append(pointlink)
self.pixlinks.append(pixlink)
self.pointcloud.append(rend_util.depth_to_world(self.uv, intrinsics, pose, depth, depth_mask))
self.depth_images = torch.stack(self.depth_images, 0)
self.depth_masks = torch.stack(self.depth_masks, 0)
if self.use_bubble:
self.pointcloud = torch.cat(self.pointcloud, 0)
self.pointlinks = torch.cat(self.pointlinks, 0)
self.pixlinks = torch.cat(self.pixlinks, 0)
self.pointcloud = self.pointcloud[:,:3] / self.pointcloud[:,3:]
self.pdf_prune = kwargs.get('pdf_prune', 0)
self.pdf_max = kwargs.get('pdf_max', None)
print(f"[INFO] PDF clamped to {self.pdf_prune}")
normal_dir = '{0}/normal'.format(self.instance_dir)
self.use_normal = use_normal and os.path.exists(normal_dir)
if self.use_normal:
self.normal_images = []
self.normal_masks = []
normal_paths = sorted(utils.glob_normal(normal_dir))
for pose, normal_path in tzip(self.pose_all, normal_paths):
normal = rend_util.load_normal(normal_path)
normal = torch.from_numpy(normal.reshape(-1, 3)).float()
valid_indices = torch.where(torch.linalg.vector_norm(normal, dim=1) > 1e-3)[0]
R = pose[:3,:3]
normal = (R @ normal.T).T # convert normal from view space to world space
normal = F.normalize(normal, dim=1, eps=1e-6)
self.normal_images.append(normal)
normal_mask = torch.zeros([self.total_pixels], dtype=torch.bool)
normal_mask[valid_indices] = True
self.normal_masks.append(normal_mask)
self.normal_images = torch.stack(self.normal_images, 0)
self.normal_masks = torch.stack(self.normal_masks, 0)
def __len__(self):
return self.n_images * self.total_pixels
def __getitem__(self, idx):
pidx = idx % self.total_pixels
tidx = idx
idx = idx // self.total_pixels
sample = {
"uv": self.uv[pidx].unsqueeze(0),
"intrinsics": self.intrinsics_all[idx],
"pose": self.pose_all[idx]
}
ground_truth = {
"rgb": self.rgb_images[idx][pidx]
}
if self.use_mask:
ground_truth['mask'] = self.mask_images[idx][pidx]
if self.use_lightmask:
ground_truth['light_mask'] = self.lightmask_images[idx][pidx]
if self.use_depth or self.use_bubble:
ground_truth['depth'] = self.depth_images[idx][pidx]
ground_truth['depth_mask'] = self.depth_masks[idx][pidx]
if self.use_normal:
ground_truth['normal'] = self.normal_images[idx][pidx]
ground_truth['normal_mask'] = self.normal_masks[idx][pidx]
return tidx, idx, sample, ground_truth
def collate_fn(self, batch_list):
# Takes a list of (tidx, idx, sample, ground_truth) tuples and returns batched tensors, stacking the entries of each dict
batch_list = zip(*batch_list)
all_parsed = []
for entry in batch_list:
if type(entry[0]) is dict:
# make them all into a new dict
ret = {}
for k in entry[0].keys():
ret[k] = torch.stack([obj[k] for obj in entry])
all_parsed.append(ret)
else:
all_parsed.append(torch.LongTensor(entry))
return tuple(all_parsed)
class MaterialDataset(torch.utils.data.Dataset):
def __init__(self,
data_dir,
scan_id=0,
downsample_train=1,
is_hdr=False,
**kwargs
):
self.sampling_idx = slice(None)
self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id))
assert os.path.exists(self.instance_dir), f"Data directory {self.instance_dir} does not exist"
image_dir = '{0}/image'.format(self.instance_dir) if not is_hdr else '{0}/hdr'.format(self.instance_dir)
self.is_hdr = is_hdr
if is_hdr:
print("[INFO] Using HDR image")
image_paths = sorted(utils.glob_imgs(image_dir))
self.n_images = len(image_paths)
self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir)
camera_dict = np.load(self.cam_file)
scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
self.intrinsics_all = []
self.pose_all = []
for scale_mat, world_mat in zip(scale_mats, world_mats):
P = world_mat @ scale_mat
P = P[:3, :4]
intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)
self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
self.pose_all.append(torch.from_numpy(pose).float())
self.intrinsics_all = torch.stack(self.intrinsics_all, 0)
self.pose_all = torch.stack(self.pose_all, 0)
self.rgb_images = []
for path in image_paths:
rgb = rend_util.load_rgb(path, is_hdr=is_hdr)
self.img_res = [rgb.shape[1], rgb.shape[2]]
rgb = rgb.reshape(3, -1).transpose(1, 0)
self.rgb_images.append(torch.from_numpy(rgb).float())
self.rgb_images = torch.stack(self.rgb_images, 0)
self.total_pixels = self.rgb_images.size(1)
mask_dir = '{0}/mask'.format(self.instance_dir)
self.use_mask = os.path.exists(mask_dir)
if self.use_mask:
mask_paths = sorted(utils.glob_imgs(mask_dir))
# assert len(mask_paths) == self.n_images
self.mask_images = []
for path in mask_paths:
mask = rend_util.load_mask(path)
mask = mask.reshape(-1, 1)
self.mask_images.append(torch.from_numpy(mask).float())
self.mask_images = torch.stack(self.mask_images, 0)
lmask_dir = '{0}/light_mask'.format(self.instance_dir)
self.use_lightmask = os.path.exists(lmask_dir)
if self.use_lightmask:
print("[INFO] Light mask detected")
lmask_paths = sorted(utils.glob_imgs(lmask_dir))
self.lightmask_images = []
for path in lmask_paths:
lmask = rend_util.load_mask(path)
lmask = lmask.reshape(-1, 1)
self.lightmask_images.append(torch.from_numpy(lmask).float())
self.lightmask_images = torch.stack(self.lightmask_images, 0)
if downsample_train > 1:
old_res = (self.img_res[0], self.img_res[1])
self.rgb_images = self.rgb_images.transpose(1, 2).reshape(-1, 3, *old_res)
self.img_res[0] //= downsample_train
self.img_res[1] //= downsample_train
self.total_pixels = self.img_res[0] * self.img_res[1]
self.rgb_images = F.interpolate(self.rgb_images, self.img_res, mode='area')
self.rgb_images = self.rgb_images.reshape(-1, 3, self.total_pixels).transpose(1, 2)
self.intrinsics_all = self.intrinsics_all.clone()
self.intrinsics_all[:,0,0] /= downsample_train
self.intrinsics_all[:,1,1] /= downsample_train
self.intrinsics_all[:,0,2] /= downsample_train
self.intrinsics_all[:,1,2] /= downsample_train
if self.use_mask:
self.mask_images = self.mask_images.transpose(1, 2).reshape(-1, 1, *old_res)
self.mask_images = F.interpolate(self.mask_images, self.img_res, mode='area')
self.mask_images = self.mask_images.reshape(-1, 1, self.total_pixels).transpose(1, 2)
if self.use_lightmask:
self.lightmask_images = self.lightmask_images.transpose(1, 2).reshape(-1, 1, *old_res)
self.lightmask_images = F.interpolate(self.lightmask_images, self.img_res, mode='area')
self.lightmask_images[self.lightmask_images > 0] = 1
self.lightmask_images = self.lightmask_images.reshape(-1, 1, self.total_pixels).transpose(1, 2)
print(f"[INFO] image size: {self.img_res[1]}x{self.img_res[0]}, {self.total_pixels} pixels in total")
uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
uv = np.flip(uv, axis=0).copy()
self.uv = torch.from_numpy(uv).float()
self.uv = self.uv.reshape(2, -1).transpose(1, 0) # (h*w, 2)
def __len__(self):
return self.n_images * self.total_pixels
def __getitem__(self, idx):
pidx = idx % self.total_pixels
tidx = idx
idx = idx // self.total_pixels
sample = {
"uv": self.uv[pidx].unsqueeze(0),
"intrinsics": self.intrinsics_all[idx],
"pose": self.pose_all[idx]
}
ground_truth = {
"rgb": self.rgb_images[idx][pidx]
}
if self.use_mask:
ground_truth['mask'] = self.mask_images[idx][pidx]
if self.use_lightmask:
ground_truth['light_mask'] = self.lightmask_images[idx][pidx]
return tidx, idx, sample, ground_truth
def collate_fn(self, batch_list):
# Takes a list of (tidx, idx, sample, ground_truth) tuples and returns batched tensors, stacking the entries of each dict
batch_list = zip(*batch_list)
all_parsed = []
for entry in batch_list:
if type(entry[0]) is dict:
# make them all into a new dict
ret = {}
for k in entry[0].keys():
ret[k] = torch.stack([obj[k] for obj in entry])
all_parsed.append(ret)
else:
all_parsed.append(torch.LongTensor(entry))
return tuple(all_parsed)
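A small illustrative helper (not part of the repo) showing the flat-indexing scheme both `ReconDataset` and `MaterialDataset` use above: the dataset exposes `n_images * total_pixels` items, and `__getitem__` splits the flat index back into an image index and a pixel index.

```python
# Sketch of the flat-index arithmetic in __len__/__getitem__ above.
def split_index(flat_idx, total_pixels):
    """Return (image_idx, pixel_idx) for a flat per-pixel dataset index."""
    return flat_idx // total_pixels, flat_idx % total_pixels

# Toy sizes: 3 images of 4 pixels each -> 12 dataset items.
assert split_index(0, 4) == (0, 0)    # first pixel of first image
assert split_index(7, 4) == (1, 3)    # last pixel of second image
assert split_index(11, 4) == (2, 3)   # last item overall
```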
================================================
FILE: environment.yml
================================================
name: i2sdf
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- cudatoolkit=11.3.1=h9edb442_10
- ffmpeg=4.3=hf484d3e_0
- numpy=1.23.5=py39h14f4228_0
- pip=23.0.1=py39h06a4308_0
- python=3.9.16=h7a1cb2a_2
- pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0
- torchaudio=0.12.1=py39_cu113
- torchvision=0.13.1=py39_cu113
- pip:
- fast-pytorch-kmeans==0.1.9
- ffmpeg-python==0.2.0
- gputil==1.4.0
- lpips==0.1.4
- open3d==0.17.0
- opencv-python==4.7.0.72
- pymcubes==0.1.4
- pytorch-lightning==1.9.0
- pyyaml==6.0
- rich==13.3.3
- scikit-image==0.20.0
- scikit-learn==1.2.2
- scipy==1.9.1
- tensorboard==2.12.0
- tensorboardx==2.6
- torchmetrics==0.11.4
- tqdm==4.65.0
- trimesh==3.21.4
================================================
FILE: i2-sdf-dataset-links.csv
================================================
file,url
interiorverse/i2-sdf/i2-sdf/bedroom_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_0.zip
interiorverse/i2-sdf/i2-sdf/bedroom_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_0_preview.png
interiorverse/i2-sdf/i2-sdf/bedroom_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_1.zip
interiorverse/i2-sdf/i2-sdf/bedroom_1_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_1_preview.png
interiorverse/i2-sdf/i2-sdf/bedroom_relight_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_0.zip
interiorverse/i2-sdf/i2-sdf/bedroom_relight_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_0_preview.png
interiorverse/i2-sdf/i2-sdf/bedroom_relight_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_1.zip
interiorverse/i2-sdf/i2-sdf/bedroom_relight_1_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_1_preview.png
interiorverse/i2-sdf/i2-sdf/diningroom_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/diningroom_0.zip
interiorverse/i2-sdf/i2-sdf/diningroom_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/diningroom_0_preview.png
interiorverse/i2-sdf/i2-sdf/kitchen_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/kitchen_0.zip
interiorverse/i2-sdf/i2-sdf/kitchen_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/kitchen_0_preview.png
interiorverse/i2-sdf/i2-sdf/livingroom_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_0.zip
interiorverse/i2-sdf/i2-sdf/livingroom_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_0_preview.png
interiorverse/i2-sdf/i2-sdf/livingroom_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_1.zip
interiorverse/i2-sdf/i2-sdf/livingroom_1_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_1_preview.png
interiorverse/i2-sdf/i2-sdf/real_data/inria_livingroom.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/inria_livingroom.zip
interiorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_0.zip
interiorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_1.zip
interiorverse/i2-sdf/i2-sdf/real_data/nisr_livingroom.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/nisr_livingroom.zip
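A stdlib-only sketch of consuming the links CSV above: parse the `file,url` rows and collect the scene archive URLs, skipping the `*_preview.png` entries. The inline string (with placeholder `example.com` URLs) stands in for `i2-sdf-dataset-links.csv` so the example is self-contained.

```python
# Sketch: filter the dataset-links CSV down to the .zip archive URLs.
import csv
import io

csv_text = """file,url
interiorverse/i2-sdf/i2-sdf/bedroom_0.zip,https://example.com/bedroom_0.zip
interiorverse/i2-sdf/i2-sdf/bedroom_0_preview.png,https://example.com/bedroom_0_preview.png
"""  # stand-in for open('i2-sdf-dataset-links.csv')

rows = list(csv.DictReader(io.StringIO(csv_text)))
archives = [r["url"] for r in rows if r["file"].endswith(".zip")]
assert archives == ["https://example.com/bedroom_0.zip"]
```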
================================================
FILE: main_recon.py
================================================
import torch
import yaml
import pytorch_lightning as pl
import argparse
import os
import utils
import model
from pytorch_lightning import loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from rich.progress import TextColumn
import GPUtil
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--conf", type=str, required=True, help="Path to (.yml) config file.")
parser.add_argument('-d', "--device_ids", type=int, nargs='+', default=None, help="GPU devices to use")
parser.add_argument("--exps_folder", type=str, default="exps")
parser.add_argument('--expname', type=str, default='')
parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.')
parser.add_argument('--test', action='store_true')
parser.add_argument('--test_mode', choices=['render', 'mesh', 'interpolate'], default='render')
parser.add_argument('-v', '--version', type=int, nargs='?')
parser.add_argument('--inter_id', type=int, nargs=2, required=False, help='2 view ids for interpolation video.')
parser.add_argument('-i', '--indices', nargs='*', type=int, help='If set, render only specified indices of the dataset instead of all images.')
parser.add_argument('--n_frames', type=int, default=60, help='Number of frames in the interpolation video.')
parser.add_argument('--frame_rate', type=int, default=24, help='Frame rate of the interpolation video.')
parser.add_argument('-f', '--full_res', action='store_true', help='If set, dataset downscaling will be ignored.')
parser.add_argument('--is_val', action='store_true', help='If set, render the validation set instead of training set.')
parser.add_argument('--val_mesh', action='store_true', help='If set, extract and save mesh every validation epoch.')
parser.add_argument('--score', action='store_true', help='If set, evaluate the meshing score (need to provide GT mesh).')
parser.add_argument('--far_clip', type=float, default=5.0)
parser.add_argument('--ckpt', type=str, default='last')
parser.add_argument('--resolution', type=int, default=512, help='Resolution for marching cube algorithm')
parser.add_argument('--spp', type=int, default=128)
parser.add_argument('--seed', type=int, default=42)
args = parser.parse_args()
with open(args.conf) as f:
cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
cfg = utils.CfgNode(cfg_dict)
expname = args.expname if args.expname else cfg.train.expname
scan_id = cfg.dataset.scan_id if args.scan_id == -1 else args.scan_id
cfg.dataset.scan_id = scan_id
expname = expname + '_' + str(scan_id)
if args.version is None and (v := args.conf.find("version_")) != -1:
args.version = int(args.conf[v + 8:args.conf.find("/config")])
print(f"[INFO] Loaded version {args.version} from config file")
if args.version is not None:
logger = loggers.TensorBoardLogger(save_dir=args.exps_folder, name=expname, version=args.version)
else:
logger = loggers.TensorBoardLogger(save_dir=args.exps_folder, name=expname)
if args.device_ids is None:
args.device_ids = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False,
excludeID=[], excludeUUID=[])
print("Selected GPU {} automatically".format(args.device_ids[0]))
torch.cuda.set_device(args.device_ids[0])
torch.set_float32_matmul_precision('medium')
progbar_callback = utils.RichProgressBarWithScanId(scan_id, leave=False)
pl.seed_everything(args.seed)
if args.test:
version = args.version if args.version is not None else logger.version - 1
exp_dir = os.path.join(logger.root_dir, f"version_{version}")
del logger
if args.test_mode == 'render':
system = model.VolumeRenderSystem(cfg, exp_dir, indices=args.indices, is_val=args.is_val, full_res=args.full_res)
if not args.ckpt.endswith('.ckpt'):
args.ckpt += '.ckpt'
ckpt = torch.load(os.path.join(exp_dir, 'checkpoints', args.ckpt), map_location='cuda')
system.load_state_dict(ckpt['state_dict'])
model.lpips.cuda()
elif args.test_mode == 'mesh':
system = model.SDFMeshSystem(cfg, exp_dir, args.resolution, args.score)
if not args.ckpt.endswith('.ckpt'):
args.ckpt += '.ckpt'
ckpt = torch.load(os.path.join(exp_dir, 'checkpoints', args.ckpt), map_location='cuda')
system.load_state_dict(ckpt['state_dict'])
system.cuda()
system.eval()
system.initialize()
# elif args.test_mode == 'interpolate':
else:
system = model.ViewInterpolateSystem(cfg, exp_dir, *args.inter_id, n_frames=args.n_frames, frame_rate=args.frame_rate)
if not args.ckpt.endswith('.ckpt'):
args.ckpt += '.ckpt'
ckpt = torch.load(os.path.join(exp_dir, 'checkpoints', args.ckpt), map_location='cuda')
system.load_state_dict(ckpt['state_dict'])
trainer = pl.Trainer(
logger=False,
accelerator='gpu',
devices=args.device_ids,
callbacks=[progbar_callback]
)
trainer.test(system)
else:
max_steps = cfg.train.get('steps', 200000)
print(f"Training for {max_steps} steps")
exp_dir = logger.log_dir
checkpoint_callback = ModelCheckpoint(os.path.join(exp_dir, 'checkpoints'), save_last=True, every_n_train_steps=cfg.train.checkpoint_freq)
if hasattr(cfg.train, 'plot_freq'):
kwargs = {'val_check_interval': cfg.train.plot_freq}
else:
kwargs = {'check_val_every_n_epoch': cfg.train.plot_epochs}
trainer = pl.Trainer(
logger=logger,
accelerator='gpu',
devices=args.device_ids,
strategy=None,
callbacks=[checkpoint_callback, progbar_callback],
max_steps=max_steps,
**kwargs
)
system = model.ReconstructionTrainer(
cfg, progbar_callback,
exp_dir=exp_dir,
is_val=args.is_val,
val_mesh=args.val_mesh
)
trainer.fit(system)
torch.cuda.empty_cache()
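A standalone re-implementation (for illustration only) of the version auto-detection in `main_recon.py` above: when `--version` is not given but the config path looks like `.../version_<n>/config...`, the script recovers `<n>` from the path with the same `find()` logic.

```python
# Sketch mirroring the "version_" parsing in main_recon.py.
def version_from_conf_path(conf_path):
    """Extract the experiment version from a path like exps/name/version_3/config.yml."""
    v = conf_path.find("version_")
    if v == -1:
        return None  # no version embedded in the path
    return int(conf_path[v + len("version_"):conf_path.find("/config")])

assert version_from_conf_path("exps/room_1/version_3/config.yml") == 3
assert version_from_conf_path("config/synthetic.yml") is None
```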
================================================
FILE: model/__init__.py
================================================
from .network import *
from .trainer import *
# from .material import *
from .eval import *
from .rendering import RenderingLayer
================================================
FILE: model/eval/__init__.py
================================================
from .recon import *
================================================
FILE: model/eval/recon.py
================================================
import torch
import pytorch_lightning as pl
import numpy as np
import os
from glob import glob
from torch.utils.data import DataLoader
import utils
from utils import rend_util
import utils.plots as plt
import dataset
import model
from skimage import measure
import cv2
import trimesh
from rich.progress import track
from torchmetrics.functional import structural_similarity_index_measure as ssim
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
lpips = LPIPS()
class SDFMeshSystem(pl.LightningModule):
def __init__(self, conf, exp_dir, resolution, score=False, far_clip=5.0) -> None:
super().__init__()
self.expdir = exp_dir
conf_model = conf.model
conf_model.use_normal = False
self.model = model.I2SDFNetwork(conf_model)
self.resolution = resolution
self.grid_boundary = conf.plot.grid_boundary
self.initialized = False
self.instance_dir = os.path.join('data', conf.dataset.data_dir, 'scan{0}'.format(conf.dataset.scan_id))
camera_dict = np.load(os.path.join(self.instance_dir, 'cameras_normalize.npz'))
self.scale_mat = camera_dict['scale_mat_0']
self.scan_id = conf.dataset.scan_id
self.score = score
if score:
self.n_imgs = len(os.listdir(os.path.join(self.instance_dir, 'image')))
self.poses = []
self.far_clip = far_clip
for i in range(self.n_imgs):
K, pose = rend_util.load_K_Rt_from_P(None, camera_dict[f'world_mat_{i}'][:3,:])
self.poses.append(pose)
self.K = K
self.H, self.W, _ = cv2.imread(os.path.join(self.instance_dir, 'image', '0000.png')).shape
def initialize(self):
grid = plt.get_grid_uniform(100, self.grid_boundary)
z = []
points = grid['grid_points']
for pnts in track(torch.split(points, 1000000, dim=0)):
z.append(self.model.implicit_network(pnts)[:,0].detach().cpu().numpy())
z = np.concatenate(z, axis=0).astype(np.float32)
verts, faces, normals, values = measure.marching_cubes(
volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],
grid['xyz'][2].shape[0]).transpose([1, 0, 2]),
level=0,
spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
grid['xyz'][0][2] - grid['xyz'][0][1],
grid['xyz'][0][2] - grid['xyz'][0][1]))
verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])
mesh_low_res = trimesh.Trimesh(verts, faces, normals)
recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]
recon_pc = torch.from_numpy(recon_pc).float().cuda()
s_mean = recon_pc.mean(dim=0)
s_cov = recon_pc - s_mean
s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)
self.vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0]
if torch.det(self.vecs) < 0:
self.vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), self.vecs)
helper = torch.bmm(self.vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1),
(recon_pc - s_mean).unsqueeze(-1)).squeeze().cpu()
grid_aligned = plt.get_grid(helper, self.resolution)
grid_points = grid_aligned['grid_points']
g = []
for pnts in track(torch.split(grid_points, 1000000, dim=0)):
g.append(torch.bmm(self.vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2),
pnts.unsqueeze(-1)).squeeze() + s_mean)
grid_points = torch.cat(g, dim=0)
points = grid_points.cpu()
self.test_dataset = dataset.GridDataset(points, grid_aligned['xyz'])
self.grid_points = grid_points
self.initialized = True
def test_dataloader(self):
assert self.initialized
print(f"[INFO] {len(self.test_dataset)} grid points to evaluate")
return DataLoader(self.test_dataset, batch_size=2000000, shuffle=False, num_workers=32)
def test_step(self, batch, batch_idx):
return self.model.implicit_network(batch)[:,0].detach().cpu().numpy()
def test_epoch_end(self, outputs) -> None:
# z = torch.cat(outputs, dim=0).cpu().numpy()
z = np.concatenate(outputs, axis=0).astype(np.float32)
if np.min(z) <= 0 <= np.max(z): # a mesh exists only if the SDF changes sign inside the grid
verts, faces, normals, values = measure.marching_cubes(
volume=z.reshape(self.test_dataset.xyz[1].shape[0], self.test_dataset.xyz[0].shape[0], self.test_dataset.xyz[2].shape[0]).transpose([1, 0, 2]),
level=0,
spacing=(self.test_dataset.xyz[0][2] - self.test_dataset.xyz[0][1],
self.test_dataset.xyz[0][2] - self.test_dataset.xyz[0][1],
self.test_dataset.xyz[0][2] - self.test_dataset.xyz[0][1]))
verts = torch.from_numpy(verts).float().cuda()
verts = torch.bmm(self.vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2),
verts.unsqueeze(-1)).squeeze()
verts = (verts + self.grid_points[0]).cpu().numpy()
mesh = trimesh.Trimesh(verts, faces, normals)
mesh.apply_transform(self.scale_mat)
mesh_folder = os.path.join(self.expdir, 'eval/mesh')
os.makedirs(mesh_folder, exist_ok=True)
mesh.export(os.path.join(mesh_folder, 'scan{0}.ply'.format(self.scan_id)), 'ply')
if self.score:
from utils import mesh_util
import open3d as o3d
mesh = mesh_util.refuse(mesh, self.poses, self.K, self.H, self.W)
out_mesh_path = os.path.join(mesh_folder, 'scan{0}_refined.ply'.format(self.scan_id))
o3d.io.write_triangle_mesh(out_mesh_path, mesh)
mesh = trimesh.load(out_mesh_path)
print("[INFO] Pred mesh refined")
gt_mesh = trimesh.load(os.path.join(self.instance_dir, 'mesh.ply'))
gt_mesh = mesh_util.refuse(gt_mesh, self.poses, self.K, self.H, self.W, self.far_clip)
out_mesh_path = os.path.join(mesh_folder, 'scan{0}_gt.ply'.format(self.scan_id))
o3d.io.write_triangle_mesh(out_mesh_path, gt_mesh)
gt_mesh = trimesh.load(out_mesh_path)
print("[INFO] GT mesh refined")
metrics = mesh_util.evaluate(mesh, gt_mesh)
with open(f"{mesh_folder}/metrics.txt", 'w') as f:
for k in metrics:
f.write(f"{k.upper()}: {metrics[k]}\n")
print(f"[INFO] Metrics saved to {mesh_folder}/metrics.txt\n")
def forward(self):
raise NotImplementedError("forward not supported by trainer")
class VolumeRenderSystem(pl.LightningModule):
def __init__(self, conf, exp_dir, indices=None, is_val=False, score_mesh=False, full_res=False) -> None:
super().__init__()
self.expdir = exp_dir
conf_model = conf.model
conf_model.use_normal = False
self.model = model.I2SDFNetwork(conf_model)
self.scan_id = conf.dataset.scan_id
dataset_conf = conf.dataset
if full_res:
dataset_conf.downsample = 1
self.test_dataset = dataset.PlotDataset(**dataset_conf, plot_nimgs=-1, shuffle=False, indices=indices, is_val=is_val)
self.total_pixels = self.test_dataset.total_pixels
self.img_res = self.test_dataset.img_res
self.split_n_pixels = conf.train.split_n_pixels
self.expdir = os.path.join(self.expdir, 'eval')
if is_val:
self.expdir = os.path.join(self.expdir, 'test')
os.makedirs(os.path.join(self.expdir, 'rendering'), exist_ok=True)
os.makedirs(os.path.join(self.expdir, 'depth'), exist_ok=True)
os.makedirs(os.path.join(self.expdir, 'normal'), exist_ok=True)
def test_dataloader(self):
print(f"[INFO] {len(self.test_dataset)} views to render")
return DataLoader(self.test_dataset, batch_size=1, shuffle=False, collate_fn=self.test_dataset.collate_fn)
@torch.inference_mode(False)
@torch.no_grad()
def test_step(self, batch, batch_idx):
indices, model_input, ground_truth = batch
# idx = batch_idx
idx = self.test_dataset.indices[batch_idx]
split = utils.split_input(model_input, self.total_pixels, self.split_n_pixels)
res = []
for s in split:
out = utils.detach_dict(self.model(s))
d = {
'rgb_values': out['rgb_values'].detach(),
'depth_values': out['depth_values'].detach()
}
d['normal_map'] = out['normal_map'].detach()
# d['surface_point'] = out['surface_point'].detach()
del out
res.append(d)
model_outputs = utils.merge_output(res, self.total_pixels, 1)
_, num_samples, _ = ground_truth['rgb'].shape
model_outputs['rgb_values'] = model_outputs['rgb_values'].reshape(1, num_samples, 3)
model_outputs['depth_values'] = model_outputs['depth_values'].reshape(1, num_samples, 1)
plt.plot_imgs_wo_gt(model_outputs['normal_map'].reshape(1, num_samples, 3), self.expdir, "{:04d}w".format(idx), 1, self.img_res, is_hdr=True)
normal_map = model_outputs['normal_map'].reshape(num_samples, 3).T # (3, h*w)
R = model_input['pose'].squeeze()[:3,:3].T
normal_map = R @ normal_map
model_outputs['normal_map'] = normal_map.T.reshape(1, num_samples, 3)
plt.plot_imgs_wo_gt(model_outputs['normal_map'], self.expdir, "{:04d}".format(idx), 1, self.img_res, is_hdr=True)
model_outputs['normal_map'] = (model_outputs['normal_map'] + 1.) / 2.
plt.plot_imgs_wo_gt(model_outputs['normal_map'], self.expdir, "{:04d}".format(idx), 1, self.img_res)
# plt.plot_imgs_wo_gt(model_outputs['surface_point'].reshape(1, num_samples, 3), self.expdir, "{:04d}p".format(idx), 1, self.img_res, is_hdr=True)
plt.plot_images(model_outputs['rgb_values'], ground_truth['rgb'], self.expdir, "{:04d}".format(idx), 1, self.img_res)
plt.plot_imgs_wo_gt(model_outputs['rgb_values'], self.expdir, "{:04d}_pred".format(idx), 1, self.img_res, 'rendering')
plt.plot_depths(model_outputs['depth_values'], self.expdir, "{:04d}".format(idx), 1, self.img_res)
plt.plot_depths(model_outputs['depth_values'], self.expdir, "{:04d}".format(idx), 1, self.img_res, None)
pred_img = model_outputs['rgb_values'].T.reshape(3, *self.img_res).unsqueeze(0)
gt_img = ground_truth['rgb'].T.reshape(3, *self.img_res).unsqueeze(0)
return {
'psnr': utils.get_psnr(model_outputs['rgb_values'], ground_truth['rgb']).item(),
'ssim': ssim(pred_img, gt_img).item(),
'lpips': lpips(pred_img.clamp(0, 1) * 2 - 1, gt_img.clamp(0, 1) * 2 - 1).item()
}
def test_epoch_end(self, outputs):
with open(os.path.join(self.expdir, 'metrics.txt'), 'w') as f:
f.write(f"# IMAGE RESOLUTION {self.img_res}\n")
psnr_sum = ssim_sum = lpips_sum = 0
psnrs = []
ssims = []
lpipss = []
for i, metrics in enumerate(outputs):
f.write(f"[{i:04d}] [PSNR]{metrics['psnr']:.2f} [SSIM]{metrics['ssim']:.2f} [LPIPS]{metrics['lpips']:.2f}\n")
psnrs.append(metrics['psnr'])
ssims.append(metrics['ssim'])
lpipss.append(metrics['lpips'])
psnr_sum += metrics['psnr']
ssim_sum += metrics['ssim']
lpips_sum += metrics['lpips']
f.write(f"[MEAN] [PSNR]{psnr_sum/len(outputs):.2f} [SSIM]{ssim_sum/len(outputs):.2f} [LPIPS]{lpips_sum/len(outputs):.2f}\n")
np.savez_compressed(os.path.join(self.expdir, 'metrics.npz'), psnr=np.array(psnrs), ssim=np.array(ssims), lpips=np.array(lpipss))
def forward(self):
raise NotImplementedError("forward not supported by trainer")
class ViewInterpolateSystem(pl.LightningModule):
def __init__(self, conf, exp_dir, id0, id1, n_frames=60, frame_rate=24, use_normal=True) -> None:
super().__init__()
self.expdir = exp_dir
conf_model = conf.model
conf_model.use_normal = False
self.model = model.I2SDFNetwork(conf_model)
self.scan_id = conf.dataset.scan_id
dataset_conf = conf.dataset
self.test_dataset = dataset.InterpolateDataset(**dataset_conf, id0=id0, id1=id1, num_frames=n_frames)
self.total_pixels = self.test_dataset.total_pixels
self.img_res = self.test_dataset.img_res
self.split_n_pixels = conf.train.split_n_pixels
self.n_frames = n_frames
self.frame_rate = frame_rate
self.video_dir = os.path.join(self.expdir, 'eval/interpolate')
self.id0 = id0
self.id1 = id1
self.use_normal = use_normal
os.makedirs(self.video_dir, exist_ok=True)
self.frame_dir = os.path.join(self.video_dir, f"{self.id0:04d}_{self.id1:04d}")
os.makedirs(self.frame_dir, exist_ok=True)
if self.use_normal:
self.normal_fdir = os.path.join(self.video_dir, f"{self.id0:04d}_{self.id1:04d}_normal")
os.makedirs(self.normal_fdir, exist_ok=True)
def test_dataloader(self):
print(len(self.test_dataset))
return DataLoader(self.test_dataset, batch_size=1, shuffle=False, collate_fn=self.test_dataset.collate_fn)
@torch.inference_mode(False)
@torch.no_grad()
def test_step(self, batch, batch_idx):
indices, model_input = batch
idx = batch_idx
split = utils.split_input(model_input, self.total_pixels, self.split_n_pixels)
res = []
res_normal = []
for s in split:
out = utils.detach_dict(self.model(s, predict_only=not self.use_normal))
rgb = out['rgb_values'].detach()
res.append(rgb)
if self.use_normal:
res_normal.append(out['normal_map'])
del out
rendered = torch.cat(res, dim=0).reshape(self.img_res[0], self.img_res[1], 3).cpu().numpy()
rendered = (rendered * 255).clip(0, 255).astype(np.uint8)
cv2.imwrite(f"{self.frame_dir}/{idx:04d}.png", rendered[:,:,::-1])
if self.use_normal:
normal_map = torch.cat(res_normal, dim=0).reshape(-1, 3).T
R = model_input['pose'].squeeze()[:3,:3].T
normal_map = R @ normal_map
normal_map = normal_map.T.reshape(self.img_res[0], self.img_res[1], 3).cpu().numpy()
normal_map = (((normal_map + 1) * 0.5) * 255).clip(0, 255).astype(np.uint8)
cv2.imwrite(f"{self.normal_fdir}/{idx:04d}.png", normal_map[:,:,::-1])
def test_epoch_end(self, outputs):
import ffmpeg
(
ffmpeg
.input(os.path.join(self.frame_dir, '*.png'), pattern_type='glob', framerate=self.frame_rate)
.output(os.path.join(self.video_dir, f"scan{self.scan_id}_{self.id0:04d}_{self.id1:04d}.mp4"), vcodec='h264')
.overwrite_output()
.run()
)
if self.use_normal:
(
ffmpeg
.input(os.path.join(self.normal_fdir, '*.png'), pattern_type='glob', framerate=self.frame_rate)
.output(os.path.join(self.video_dir, f"scan{self.scan_id}_{self.id0:04d}_{self.id1:04d}_normal.mp4"), vcodec='h264')
.overwrite_output()
.run()
)
def forward(self):
raise NotImplementedError("forward not supported by trainer")
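The metrics written by `test_epoch_end` above come from `utils.get_psnr` plus SSIM/LPIPS modules. As a hedged sketch of the PSNR term only (the standard `-10 * log10(MSE)` definition for images in [0, 1]; the repo's actual helper is assumed, not reproduced):

```python
import torch

def psnr(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    # Peak signal-to-noise ratio for images normalized to [0, 1]
    mse = torch.mean((pred - gt) ** 2)
    return -10.0 * torch.log10(mse)

gt = torch.rand(128, 128, 3)
noisy = (gt + 0.01 * torch.randn_like(gt)).clamp(0, 1)
val = psnr(noisy, gt)  # around 40 dB for noise sigma = 0.01
```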
================================================
FILE: model/network/__init__.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import utils
from model.network.mlp import ImplicitNetwork, RenderingNetwork
from model.network.density import LaplaceDensity, AbsDensity
from model.network.ray_sampler import ErrorBoundSampler
from fast_pytorch_kmeans import KMeans
from sklearn.cluster import DBSCAN
"""
For modeling more complex backgrounds, we follow the inverted sphere parametrization from NeRF++
https://github.com/Kai-46/nerfplusplus
"""
class I2SDFNetwork(nn.Module):
def __init__(self, conf):
super().__init__()
self.feature_vector_size = conf.feature_vector_size
self.scene_bounding_sphere = getattr(conf, 'scene_bounding_sphere', 1.0)
# Foreground object's networks
self.implicit_network = ImplicitNetwork(self.feature_vector_size, 0.0, **conf.implicit_network)
self.rendering_network = RenderingNetwork(self.feature_vector_size, **conf.rendering_network)
self.use_light = hasattr(conf, 'light_network')
if self.use_light:
# self.light_network = RenderingNetwork(self.feature_vector_size, mode='nerf', output_activation='sigmoid', use_dir=False, **conf.light_network)
self.light_network = ImplicitNetwork(0, 0, d_in=self.feature_vector_size, d_out=1, geometric_init=False, embed_type=None, output_activation='sigmoid', **conf.light_network)
self.density = LaplaceDensity(**conf.density)
# Background's networks
self.use_bg = hasattr(conf, 'bg_network')
if self.use_bg:
bg_feature_vector_size = conf.bg_network.feature_vector_size
self.bg_implicit_network = ImplicitNetwork(bg_feature_vector_size, 0.0, **conf.bg_network.implicit_network)
self.bg_rendering_network = RenderingNetwork(bg_feature_vector_size, **conf.bg_network.rendering_network)
self.bg_density = AbsDensity(**getattr(conf.bg_network, 'density', {}))
else:
print("[INFO] BG Network Disabled")
self.ray_sampler = ErrorBoundSampler(self.scene_bounding_sphere, inverse_sphere_bg=self.use_bg, **conf.ray_sampler)
self.use_normal = conf.get('use_normal', False)
self.detach_light_feature = conf.get('detach_light_feature', True)
def init_emission_groups(self, n_emitters, pointcloud, init_emission=1.0, use_dbscan=False):
if use_dbscan:
"""
Use DBSCAN algorithm to initialize emitter cluster centroids for K-Means from a small random batch
Note that DBSCAN can automatically determine the number of clusters
"""
pt_samples = pointcloud[torch.randperm(len(pointcloud))[:10000]].cpu().numpy()
lab_samples = torch.from_numpy(DBSCAN(n_jobs=16).fit_predict(pt_samples))
if n_emitters != len(torch.unique(lab_samples)):
print(f"[ERROR] Inconsistent emitter count: {n_emitters} / {len(torch.unique(lab_samples))}")
# n_emitters = len(torch.unique(lab_samples))
exit()
init_centroids = torch.zeros(n_emitters, 3)
for i in range(n_emitters):
idx = (lab_samples == i).int().argmax()
init_centroids[i,:] = torch.from_numpy(pt_samples[idx, :])
init_centroids = init_centroids.to(pointcloud.device)
else:
"""
Use K-Means plus plus to initialize emitter cluster centroids for K-Means
"""
init_centroids = utils.kmeans_pp_centroid(pointcloud, n_emitters)
self.emitter_clusters = KMeans(n_emitters)
labels = self.emitter_clusters.fit_predict(pointcloud, init_centroids)
print("[INFO] emitters clustered")
self.emissions = nn.Parameter(torch.empty(n_emitters, 3).fill_(init_emission), True)
return labels, self.emitter_clusters.centroids
def get_param_groups(self, lr):
return [{'params': self.parameters(), 'lr': lr}]
def forward(self, input, predict_only=False):
intrinsics = input["intrinsics"]
uv = input["uv"]
pose = input["pose"]
ray_dirs, cam_loc = utils.get_camera_params(uv, pose, intrinsics)
batch_size, num_pixels, _ = ray_dirs.shape
cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)
ray_dirs = ray_dirs.reshape(-1, 3)
ray_dirs_norm = torch.linalg.vector_norm(ray_dirs, dim=1)
ray_dirs = F.normalize(ray_dirs, dim=1)
z_vals, z_samples_eik = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self)
if self.use_bg:
z_vals, z_vals_bg = z_vals
z_max = z_vals[:,-1]
z_vals = z_vals[:,:-1]
N_samples = z_vals.shape[1]
points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1)
points_flat = points.reshape(-1, 3)
dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1)
dirs_flat = dirs.reshape(-1, 3)
returns_grad = self.use_normal or (not self.training) or (self.rendering_network.mode == 'idr')
# with torch.enable_grad():
with torch.set_grad_enabled(returns_grad):
# with torch.inference_mode(not returns_grad):
sdf, feature_vectors, gradients = self.implicit_network.get_outputs(points_flat, returns_grad)
rgb_flat = self.rendering_network(points_flat, gradients, dirs_flat, feature_vectors)
rgb = rgb_flat.reshape(-1, N_samples, 3)
weights, bg_transmittance = self.volume_rendering(z_vals, z_max, sdf)
fg_rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, 1)
weight_sum = torch.sum(weights, -1, keepdim=True)
# dist = torch.sum(weights / weight_sum.clamp(min=1e-6) * z_vals, 1)
dist = torch.sum(weights * z_vals, 1)
depth_values = dist / torch.clamp(ray_dirs_norm, min=1e-6)
# depth_values = torch.sum(weights * z_vals, 1) / torch.clamp(torch.sum(weights, 1), min=1e-6) # (bn,)
# Background rendering
if self.use_bg:
N_bg_samples = z_vals_bg.shape[1]
z_vals_bg = torch.flip(z_vals_bg, dims=[-1, ]) # 1--->0
bg_dirs = ray_dirs.unsqueeze(1).repeat(1,N_bg_samples,1)
bg_locs = cam_loc.unsqueeze(1).repeat(1,N_bg_samples,1)
bg_points = self.depth2pts_outside(bg_locs, bg_dirs, z_vals_bg) # [..., N_samples, 4]
bg_points_flat = bg_points.reshape(-1, 4)
bg_dirs_flat = bg_dirs.reshape(-1, 3)
output = self.bg_implicit_network(bg_points_flat)
bg_sdf = output[:,:1]
bg_feature_vectors = output[:, 1:]
bg_rgb_flat = self.bg_rendering_network(None, None, bg_dirs_flat, bg_feature_vectors)
bg_rgb = bg_rgb_flat.reshape(-1, N_bg_samples, 3)
bg_weights = self.bg_volume_rendering(z_vals_bg, bg_sdf)
bg_rgb_values = torch.sum(bg_weights.unsqueeze(-1) * bg_rgb, 1)
# Composite foreground and background
bg_rgb_values = bg_transmittance.unsqueeze(-1) * bg_rgb_values
rgb_values = fg_rgb_values + bg_rgb_values
else:
rgb_values = fg_rgb_values
output = {
'rgb_values': rgb_values,
'depth_values': depth_values,
'weight_sum': weight_sum
}
if self.use_light:
light_features = F.relu(feature_vectors)
if self.detach_light_feature:
light_features = light_features.detach_()
# lmask_flat = self.light_network(None, None, None, light_features)
lmask_flat = self.light_network(light_features)
lmask = lmask_flat.reshape(-1, N_samples, 1)
lmask_values = torch.sum(weights.unsqueeze(-1).detach() * lmask, 1)
output['light_mask'] = lmask_values
if predict_only:
return output
if self.training:
# Sample points for the eikonal loss
n_eik_points = batch_size * num_pixels
eikonal_points = torch.empty(n_eik_points, 3, device=cam_loc.device).uniform_(-self.scene_bounding_sphere, self.scene_bounding_sphere)
# Add some of the near surface points
eik_near_points = (cam_loc.unsqueeze(1) + z_samples_eik.unsqueeze(2) * ray_dirs.unsqueeze(1)).reshape(-1, 3)
n_eik_near = eik_near_points.size(0)
eikonal_points = torch.cat([eikonal_points, eik_near_points], 0)
# Add neighbor points near surface for smoothness loss
eik_near_neighbors = eik_near_points + torch.empty_like(eik_near_points).uniform_(-0.005, 0.005)
eikonal_points = torch.cat([eikonal_points, eik_near_neighbors], 0)
grad_theta = self.implicit_network.gradient(eikonal_points)
output['grad_theta'] = grad_theta[:n_eik_points+n_eik_near,]
normals = grad_theta[n_eik_points:,]
normals = F.normalize(normals, dim=1, eps=1e-6)
diff_norm = torch.norm(normals[:n_eik_near,:] - normals[n_eik_near:,:], dim=1)
output['diff_norm'] = diff_norm
# Sample pointclouds for bubble loss
if 'pointcloud' in input:
surface_points = input['pointcloud']
cam_loc_selected = cam_loc[np.random.randint(0, len(cam_loc)),:]
surface_points = torch.cat([surface_points, cam_loc_selected.unsqueeze(0)], dim=0)
surface_sdf = self.implicit_network.get_sdf_vals(surface_points)
output['surface_sdf'] = surface_sdf[:-1,:]
# Accumulate gradients for normal loss
if self.use_normal:
normals = F.normalize(gradients, dim=-1)
normals = normals.reshape(-1, N_samples, 3)
normal_map = torch.sum(weights.unsqueeze(-1).detach() * normals, 1)
normal_map = F.normalize(normal_map, dim=-1)
output['normal_values'] = normal_map
# elif not self.training:
else:
# Accumulate gradients for normal visualization
gradients = gradients.detach()
normals = F.normalize(gradients, dim=-1)
normals = normals.reshape(-1, N_samples, 3)
normal_map = torch.sum(weights.unsqueeze(-1) * normals, 1)
normal_map = F.normalize(normal_map, dim=-1)
output['normal_map'] = normal_map
return output
def volume_rendering(self, z_vals, z_max, sdf):
density_flat = self.density(sdf)
density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples
# also include the dist from the last sample to the sphere intersection (z_max)
dists = z_vals[:, 1:] - z_vals[:, :-1]
dists = torch.cat([dists, z_max.unsqueeze(-1) - z_vals[:, -1:]], -1)
# LOG SPACE
free_energy = dists * density
shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=free_energy.device), free_energy], dim=-1) # add 0 for transparency 1 at t_0
alpha = 1 - torch.exp(-free_energy) # probability that this segment is not empty
transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) # probability that everything is empty up to this point
fg_transmittance = transmittance[:, :-1]
weights = alpha * fg_transmittance # probability that the ray hits something here
bg_transmittance = transmittance[:, -1] # factor to be multiplied with the bg volume rendering
return weights, bg_transmittance
def bg_volume_rendering(self, z_vals_bg, bg_sdf):
bg_density_flat = self.bg_density(bg_sdf)
bg_density = bg_density_flat.reshape(-1, z_vals_bg.shape[1]) # (batch_size * num_pixels) x N_samples
bg_dists = z_vals_bg[:, :-1] - z_vals_bg[:, 1:]
bg_dists = torch.cat([bg_dists, torch.tensor([1e10], device=bg_dists.device).unsqueeze(0).repeat(bg_dists.shape[0], 1)], -1)
# LOG SPACE
bg_free_energy = bg_dists * bg_density
bg_shifted_free_energy = torch.cat([torch.zeros(bg_dists.shape[0], 1, device=bg_free_energy.device), bg_free_energy[:, :-1]], dim=-1) # shift one step
bg_alpha = 1 - torch.exp(-bg_free_energy) # probability that this segment is not empty
bg_transmittance = torch.exp(-torch.cumsum(bg_shifted_free_energy, dim=-1)) # probability that everything is empty up to this point
bg_weights = bg_alpha * bg_transmittance # probability that the ray hits something here
return bg_weights
def depth2pts_outside(self, ray_o, ray_d, depth):
'''
ray_o, ray_d: [..., 3]
depth: [...]; inverse of distance to sphere origin
'''
o_dot_d = torch.sum(ray_d * ray_o, dim=-1)
under_sqrt = o_dot_d ** 2 - ((ray_o ** 2).sum(-1) - self.scene_bounding_sphere ** 2)
d_sphere = torch.sqrt(under_sqrt) - o_dot_d
p_sphere = ray_o + d_sphere.unsqueeze(-1) * ray_d
p_mid = ray_o - o_dot_d.unsqueeze(-1) * ray_d
p_mid_norm = torch.norm(p_mid, dim=-1)
rot_axis = torch.cross(ray_o, p_sphere, dim=-1)
rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
phi = torch.asin(p_mid_norm / self.scene_bounding_sphere)
theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1]
rot_angle = (phi - theta).unsqueeze(-1) # [..., 1]
# now rotate p_sphere
# Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
p_sphere_new = p_sphere * torch.cos(rot_angle) + \
torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \
rot_axis * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True) * (1. - torch.cos(rot_angle))
p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)
pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)
return pts
class I2SDFLoss(nn.Module):
def __init__(self, eikonal_weight=0.1, smooth_weight=0.0, mask_weight=0.0, depth_weight=0.1, normal_weight=0.05, angular_weight=0.05, bubble_weight=0.0, min_bubble_iter=0, max_bubble_iter=None, smooth_iter=None, light_mask_weight=0.0, eikonal_weight_bubble=0.0):
super().__init__()
self.eikonal_weight = eikonal_weight
self.rgb_loss = F.l1_loss
self.smooth_weight = smooth_weight
self.mask_weight = mask_weight
self.depth_weight = depth_weight
self.normal_weight = normal_weight
self.angular_weight = angular_weight
self.bubble_weight = bubble_weight
# self.eikonal_weight_bubble = eikonal_weight_bubble if eikonal_weight_bubble else self.eikonal_weight
self.min_bubble_iter = min_bubble_iter
self.max_bubble_iter = max_bubble_iter
self.smooth_iter = smooth_iter
if self.bubble_weight > 0 and self.max_bubble_iter is not None and (self.smooth_iter is None or self.smooth_iter < self.max_bubble_iter):
self.smooth_iter = self.max_bubble_iter # Disable smoothness loss during bubble steps
self.light_mask_weight = light_mask_weight
def get_rgb_loss(self, rgb_values, rgb_gt):
rgb_gt = rgb_gt.reshape(-1, 3)
rgb_loss = self.rgb_loss(rgb_values, rgb_gt)
return rgb_loss
def get_eikonal_loss(self, grad_theta):
eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()
return eikonal_loss
def get_mask_loss(self, mask_pred, mask_gt):
return F.binary_cross_entropy(mask_pred.clip(1e-3, 1.0 - 1e-3), mask_gt)
def get_depth_loss(self, depth, depth_gt, depth_mask):
depth_gt = depth_gt.flatten()
depth_mask = depth_mask.flatten()
# TODO: Add support for scale invariant depth loss (like MonoSDF)
return F.mse_loss(depth[depth_mask], depth_gt[depth_mask])
def get_normal_l1_loss(self, normal, normal_gt, normal_mask):
normal_gt = normal_gt.reshape(-1, 3)
normal_mask = normal_mask.flatten()
return torch.abs(1 - torch.sum(normal[normal_mask] * normal_gt[normal_mask], dim=-1)).mean()
def get_normal_angular_loss(self, normal, normal_gt, normal_mask):
normal_gt = normal_gt.reshape(-1, 3)
normal_mask = normal_mask.flatten()
dot = torch.sum(normal[normal_mask] * normal_gt[normal_mask], dim=-1)
angle = torch.acos(torch.clamp(dot, -1.0+1e-6, 1.0-1e-6)) / math.tau
return angle.clamp_max(0.5).abs().mean()
def forward(self, model_outputs, ground_truth, current_step):
rgb_gt = ground_truth['rgb']
rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt)
if 'grad_theta' in model_outputs:
eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])
else:
eikonal_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()
smooth_activated = self.smooth_iter is None or current_step > self.smooth_iter
if smooth_activated and self.smooth_weight > 0 and 'diff_norm' in model_outputs:
smooth_loss = model_outputs['diff_norm'].mean()
else:
smooth_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()
if 'mask' in ground_truth and self.mask_weight > 0:
mask_loss = self.get_mask_loss(model_outputs['weight_sum'], ground_truth['mask'])
else:
mask_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()
if 'depth' in ground_truth and self.depth_weight > 0:
depth_loss = self.get_depth_loss(model_outputs['depth_values'], ground_truth['depth'], ground_truth['depth_mask'])
else:
depth_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()
if 'normal' in ground_truth and self.normal_weight > 0:
normal_loss = self.get_normal_l1_loss(model_outputs['normal_values'], ground_truth['normal'], ground_truth['normal_mask'])
else:
normal_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()
if 'normal' in ground_truth and self.angular_weight > 0:
angular_loss = self.get_normal_angular_loss(model_outputs['normal_values'], ground_truth['normal'], ground_truth['normal_mask'])
else:
angular_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()
if 'surface_sdf' in model_outputs and self.bubble_weight > 0:
bubble_loss = model_outputs['surface_sdf'].abs().mean()
else:
bubble_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()
if 'light_mask' in model_outputs and self.light_mask_weight > 0:
light_mask_loss = self.get_mask_loss(model_outputs['light_mask'].reshape(-1, 1), ground_truth['light_mask'].reshape(-1, 1))
else:
light_mask_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()
loss = rgb_loss + \
self.eikonal_weight * eikonal_loss + \
self.smooth_weight * smooth_loss + \
self.mask_weight * mask_loss + \
self.depth_weight * depth_loss + \
self.normal_weight * normal_loss + \
self.angular_weight * angular_loss + \
self.bubble_weight * bubble_loss + \
self.light_mask_weight * light_mask_loss
output = {
'loss': loss,
'rgb_loss': rgb_loss,
'eikonal_loss': eikonal_loss,
'smooth_loss': smooth_loss,
'mask_loss': mask_loss,
'depth_loss': depth_loss,
'normal_loss': normal_loss,
'angular_loss': angular_loss,
'bubble_loss': bubble_loss,
'light_mask_loss': light_mask_loss
}
return output
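The log-space compositing in `volume_rendering` above makes the foreground weights and the residual background transmittance partition each ray's probability mass, which is why `fg_rgb_values` and `bg_transmittance * bg_rgb_values` can be summed without renormalizing. A minimal NumPy sanity check with made-up densities and segment lengths (not values from the repo):

```python
import numpy as np

# Hypothetical inputs for a single ray with 5 samples
dists = np.array([0.1, 0.1, 0.2, 0.2, 0.4])    # segment lengths (incl. dist to z_max)
density = np.array([0.0, 0.5, 3.0, 8.0, 1.0])  # densities at the samples

free_energy = dists * density                      # optical thickness per segment
shifted = np.concatenate([[0.0], free_energy])     # transparency is 1 at t_0
alpha = 1.0 - np.exp(-free_energy)                 # probability a segment is hit
transmittance = np.exp(-np.cumsum(shifted))        # survival probability up to each sample
weights = alpha * transmittance[:-1]               # foreground rendering weights
bg_transmittance = transmittance[-1]               # mass left for the background

# The sum telescopes: sum(weights) + bg_transmittance == 1
assert np.isclose(weights.sum() + bg_transmittance, 1.0)
```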
================================================
FILE: model/network/density.py
================================================
import torch.nn as nn
import torch
class Density(nn.Module):
def __init__(self, params_init={}):
super().__init__()
for p in params_init:
param = nn.Parameter(torch.tensor(params_init[p]))
setattr(self, p, param)
def forward(self, sdf, beta=None):
return self.density_func(sdf, beta=beta)
class LaplaceDensity(Density): # alpha * Laplace(loc=0, scale=beta).cdf(-sdf)
def __init__(self, params_init={}, beta_min=0.0001):
super().__init__(params_init=params_init)
self.beta_min = torch.tensor(beta_min)
def density_func(self, sdf, beta=None):
if beta is None:
beta = self.get_beta()
alpha = 1 / beta
return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta))
def get_beta(self):
beta = self.beta.abs() + self.beta_min
return beta
class AbsDensity(Density): # like NeRF++
def density_func(self, sdf, beta=None):
return torch.abs(sdf)
class SimpleDensity(Density): # like NeRF
def __init__(self, params_init={}, noise_std=1.0):
super().__init__(params_init=params_init)
self.noise_std = noise_std
def density_func(self, sdf, beta=None):
if self.training and self.noise_std > 0.0:
noise = torch.randn(sdf.shape).to(sdf.device) * self.noise_std
sdf = sdf + noise
return torch.relu(sdf)
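As the class comment notes, `LaplaceDensity` is a numerically stable `expm1`-based form of `alpha * Laplace(loc=0, scale=beta).cdf(-sdf)`. A quick check against `torch.distributions.Laplace` (the beta value here is arbitrary):

```python
import torch

beta = 0.1
alpha = 1.0 / beta
sdf = torch.tensor([-0.3, -0.05, 0.0, 0.05, 0.3])

# Stable expm1-based form used by LaplaceDensity.density_func
dens = alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta))

# Reference: scaled Laplace CDF evaluated at -sdf
ref = alpha * torch.distributions.Laplace(0.0, beta).cdf(-sdf)

assert torch.allclose(dens, ref, atol=1e-6)
```

Negative SDF (inside the surface) maps to densities near `alpha`, positive SDF decays toward zero.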
================================================
FILE: model/network/embedder.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class Embedder:
""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs['input_dims']
out_dim = 0
if self.kwargs['include_input']:
embed_fns.append(lambda x: x)
out_dim += d
max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']
if self.kwargs['log_sampling']:
freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
else:
freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']:
embed_fns.append(lambda x, p_fn=p_fn,
freq=freq: p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
class SHEncoder(nn.Module):
def __init__(self, input_dims=3, degree=4):
super().__init__()
self.input_dims = input_dims
self.degree = degree
assert self.input_dims == 3
assert self.degree >= 1 and self.degree <= 5
self.out_dim = degree ** 2
self.C0 = 0.28209479177387814
self.C1 = 0.4886025119029199
self.C2 = [
1.0925484305920792,
-1.0925484305920792,
0.31539156525252005,
-1.0925484305920792,
0.5462742152960396
]
self.C3 = [
-0.5900435899266435,
2.890611442640554,
-0.4570457994644658,
0.3731763325901154,
-0.4570457994644658,
1.445305721320277,
-0.5900435899266435
]
self.C4 = [
2.5033429417967046,
-1.7701307697799304,
0.9461746957575601,
-0.6690465435572892,
0.10578554691520431,
-0.6690465435572892,
0.47308734787878004,
-1.7701307697799304,
0.6258357354491761
]
def forward(self, input, **kwargs):
result = torch.empty((*input.shape[:-1], self.out_dim), dtype=input.dtype, device=input.device)
x, y, z = input.unbind(-1)
result[..., 0] = self.C0
if self.degree > 1:
result[..., 1] = -self.C1 * y
result[..., 2] = self.C1 * z
result[..., 3] = -self.C1 * x
if self.degree > 2:
xx, yy, zz = x * x, y * y, z * z
xy, yz, xz = x * y, y * z, x * z
result[..., 4] = self.C2[0] * xy
result[..., 5] = self.C2[1] * yz
result[..., 6] = self.C2[2] * (2.0 * zz - xx - yy)
#result[..., 6] = self.C2[2] * (3.0 * zz - 1) # xx + yy + zz == 1, but this will lead to different backward gradients, interesting...
result[..., 7] = self.C2[3] * xz
result[..., 8] = self.C2[4] * (xx - yy)
if self.degree > 3:
result[..., 9] = self.C3[0] * y * (3 * xx - yy)
result[..., 10] = self.C3[1] * xy * z
result[..., 11] = self.C3[2] * y * (4 * zz - xx - yy)
result[..., 12] = self.C3[3] * z * (2 * zz - 3 * xx - 3 * yy)
result[..., 13] = self.C3[4] * x * (4 * zz - xx - yy)
result[..., 14] = self.C3[5] * z * (xx - yy)
result[..., 15] = self.C3[6] * x * (xx - 3 * yy)
if self.degree > 4:
result[..., 16] = self.C4[0] * xy * (xx - yy)
result[..., 17] = self.C4[1] * yz * (3 * xx - yy)
result[..., 18] = self.C4[2] * xy * (7 * zz - 1)
result[..., 19] = self.C4[3] * yz * (7 * zz - 3)
result[..., 20] = self.C4[4] * (zz * (35 * zz - 30) + 3)
result[..., 21] = self.C4[5] * xz * (7 * zz - 3)
result[..., 22] = self.C4[6] * (xx - yy) * (7 * zz - 1)
result[..., 23] = self.C4[7] * xz * (xx - 3 * yy)
result[..., 24] = self.C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
return result
class FourierFeature(nn.Module):
def __init__(self, channels, sigma=1.0, input_dims=3, include_input=True) -> None:
super().__init__()
self.register_buffer('B', torch.randn(input_dims, channels) * sigma, True)
self.channels = channels
self.out_dim = 2 * self.channels + 3 if include_input else 2 * self.channels
self.include_input = include_input
def forward(self, x):
xp = torch.matmul(2 * np.pi * x, self.B)
return torch.cat([x, torch.sin(xp), torch.cos(xp)], dim=-1) if self.include_input else torch.cat([torch.sin(xp), torch.cos(xp)], dim=-1)
def get_embedder(embed_type='positional', **kwargs):
if embed_type == 'positional':
input_dims = kwargs['input_dims']
multires = kwargs['multires']
embed_kwargs = {
'include_input': True,
'input_dims': input_dims,
'max_freq_log2': multires-1,
'num_freqs': multires,
'log_sampling': True,
'periodic_fns': [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
def embed(x, eo=embedder_obj): return eo.embed(x)
return embed, embedder_obj.out_dim
elif embed_type == 'spherical_harmonics':
embedder = SHEncoder(**kwargs)
return embedder, embedder.out_dim
elif embed_type == 'fourier':
embedder = FourierFeature(**kwargs)
return embedder, embedder.out_dim
else:
raise ValueError('Unknown embedding type: {}'.format(embed_type))
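For the `'positional'` branch, the output width is `d + 2 * d * num_freqs`: the raw input passthrough plus sin/cos at every frequency. A standalone sketch of that layout mirroring `Embedder` (values `input_dims=3, multires=6` are chosen for illustration, not taken from a repo config):

```python
import torch

d, multires = 3, 6
freq_bands = 2.0 ** torch.linspace(0.0, multires - 1, multires)  # log-sampled, as in Embedder

x = torch.randn(5, d)
# include_input first, then sin/cos per frequency (same ordering as create_embedding_fn)
feats = [x] + [fn(x * f) for f in freq_bands for fn in (torch.sin, torch.cos)]
embedded = torch.cat(feats, dim=-1)

assert embedded.shape == (5, d + 2 * d * multires)  # (5, 39)
```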
================================================
FILE: model/network/mlp.py
================================================
import torch.nn as nn
import numpy as np
import utils
from .embedder import *
from .density import LaplaceDensity
from .ray_sampler import ErrorBoundSampler
class ImplicitNetwork(nn.Module):
def __init__(
self,
feature_vector_size,
sdf_bounding_sphere,
d_in,
d_out,
dims,
geometric_init=True,
bias=1.0,
skip_in=(),
weight_norm=True,
embed_type=None,
sphere_scale=1.0,
output_activation=None,
**kwargs
):
super().__init__()
self.sdf_bounding_sphere = sdf_bounding_sphere
self.sphere_scale = sphere_scale
dims = [d_in] + dims + [d_out + feature_vector_size]
self.embed_fn = None
if embed_type:
embed_fn, input_ch = get_embedder(embed_type, input_dims=d_in, **kwargs)
self.embed_fn = embed_fn
dims[0] = input_ch
print(f"[INFO] Implicit network dims: {dims}")
self.num_layers = len(dims)
self.skip_in = skip_in
self.weight_norm = weight_norm
for l in range(0, self.num_layers - 1):
if l + 1 in self.skip_in:
out_dim = dims[l + 1] - dims[0]
if out_dim < 0:
print(dims)
else:
out_dim = dims[l + 1]
lin = nn.Linear(dims[l], out_dim)
if geometric_init:
if l == self.num_layers - 2:
torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
torch.nn.init.constant_(lin.bias, -bias)
elif (embed_type or getattr(self, 'use_grid', False)) and l == 0:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
elif (embed_type or getattr(self, 'use_grid', False)) and l in self.skip_in:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
else:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
if weight_norm:
lin = nn.utils.weight_norm(lin)
setattr(self, "lin" + str(l), lin)
self.activation = nn.Softplus(beta=100)
self.output_activation = None
if output_activation is not None:
self.output_activation = activations[output_activation]
def get_param_groups(self, lr):
return [{'params': self.parameters(), 'lr': lr}]
def forward(self, input):
if self.embed_fn is not None:
input = self.embed_fn(input)
x = input
for l in range(0, self.num_layers - 1):
lin = getattr(self, "lin" + str(l))
if l in self.skip_in:
x = torch.cat([x, input], 1) / np.sqrt(2)
x = lin(x)
if l < self.num_layers - 2:
x = self.activation(x)
if self.output_activation is not None:
x = self.output_activation(x)
return x
def gradient(self, x):
x.requires_grad_(True)
y = self.forward(x)[:,:1]
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
gradients = torch.autograd.grad(
outputs=y,
inputs=x,
grad_outputs=d_output,
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
return gradients
def feature(self, x):
return self.forward(x)[:,1:]
def get_outputs(self, x, returns_grad=True):
x.requires_grad_(returns_grad)
output = self.forward(x)
sdf = output[:,:1]
''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded '''
if self.sdf_bounding_sphere > 0.0:
sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))
sdf = torch.minimum(sdf, sphere_sdf)
feature_vectors = output[:, 1:]
if returns_grad:
d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
gradients = torch.autograd.grad(
outputs=sdf,
inputs=x,
grad_outputs=d_output,
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
return sdf, feature_vectors, gradients
else:
return sdf, feature_vectors, None
def get_sdf_vals(self, x):
sdf = self.forward(x)[:,:1]
''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded '''
if self.sdf_bounding_sphere > 0.0:
sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))
sdf = torch.minimum(sdf, sphere_sdf)
return sdf
activations = {
'sigmoid': nn.Sigmoid(),
'relu': nn.ReLU(),
'softplus': nn.Softplus()
}
class RenderingNetwork(nn.Module):
def __init__(
self,
feature_vector_size,
mode,
d_in,
d_out,
dims,
weight_norm=True,
embed_type=None,
embed_point=None,
output_activation='sigmoid',
**kwargs
):
super().__init__()
self.mode = mode
dims = [d_in + feature_vector_size] + dims + [d_out]
self.d_out = d_out
self.embedview_fn = None
if embed_type:
embedview_fn, input_ch = get_embedder(embed_type, input_dims=3, **kwargs)
self.embedview_fn = embedview_fn
dims[0] += (input_ch - 3)
if mode == 'idr':
self.embedpoint_fn = None
if embed_point is not None:
embedpoint_fn, input_ch = get_embedder(input_dims=3, **embed_point)
self.embedpoint_fn = embedpoint_fn
dims[0] += (input_ch - 3)
print(f"[INFO] Rendering network dims: {dims}")
self.num_layers = len(dims)
self.weight_norm = weight_norm
for l in range(0, self.num_layers - 1):
out_dim = dims[l + 1]
lin = nn.Linear(dims[l], out_dim)
if weight_norm:
lin = nn.utils.weight_norm(lin)
setattr(self, "lin" + str(l), lin)
self.activation = nn.ReLU()
self.output_activation = activations[output_activation]
def forward(self, points, normals, view_dirs, feature_vectors):
if self.embedview_fn is not None:
view_dirs = self.embedview_fn(view_dirs)
if self.mode == 'idr':
rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
# elif self.mode == 'nerf':
else:
rendering_input = torch.cat([view_dirs, feature_vectors], dim=-1)
x = rendering_input
for l in range(0, self.num_layers - 1):
lin = getattr(self, "lin" + str(l))
x = lin(x)
if l < self.num_layers - 2:
x = self.activation(x)
x = self.output_activation(x)
return x
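The skip connections in `ImplicitNetwork` shrink the layer that feeds a skip so that re-concatenating the (embedded) input lands back on the configured width, then rescale by `1/sqrt(2)` to keep the activation variance roughly constant. A toy-sized sketch of that bookkeeping (layer sizes are illustrative, not the repo's config):

```python
import numpy as np
import torch
import torch.nn as nn

d_in, d_out, hidden, skip_in = 3, 1, [8, 8, 8], (2,)
dims = [d_in] + hidden + [d_out]  # [3, 8, 8, 8, 1]

layers = []
for l in range(len(dims) - 1):
    # a layer feeding a skip outputs dims[l+1] - dims[0] features, so the
    # concatenation with the raw input restores dims[l+1]
    out_dim = dims[l + 1] - dims[0] if (l + 1) in skip_in else dims[l + 1]
    layers.append(nn.Linear(dims[l], out_dim))

x = inp = torch.randn(4, d_in)
for l, lin in enumerate(layers):
    if l in skip_in:
        x = torch.cat([x, inp], dim=1) / np.sqrt(2)  # variance-preserving rescale
    x = lin(x)
    if l < len(layers) - 1:
        x = torch.relu(x)

assert x.shape == (4, d_out)
```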
================================================
FILE: model/network/ray_sampler.py
================================================
import abc
import torch
from utils import rend_util
import utils
class RaySampler(metaclass=abc.ABCMeta):
def __init__(self, near, far):
self.near = near
self.far = far
@abc.abstractmethod
def get_z_vals(self, ray_dirs, cam_loc, model):
pass
class UniformSampler(RaySampler):
def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1):
super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far) # default far is 2*R
self.N_samples = N_samples
self.scene_bounding_sphere = scene_bounding_sphere
self.take_sphere_intersection = take_sphere_intersection
def get_z_vals(self, ray_dirs, cam_loc, model):
if not self.take_sphere_intersection:
near, far = self.near * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device), self.far * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device)
else:
sphere_intersections = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)
near = self.near * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device)
far = sphere_intersections[:,1:]
t_vals = torch.linspace(0., 1., steps=self.N_samples, device=ray_dirs.device)
z_vals = near * (1. - t_vals) + far * (t_vals)
if model.training:
# get intervals between samples
mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
upper = torch.cat([mids, z_vals[..., -1:]], -1)
lower = torch.cat([z_vals[..., :1], mids], -1)
# stratified samples in those intervals
t_rand = torch.rand(z_vals.shape, device=z_vals.device)
z_vals = lower + (upper - lower) * t_rand
return z_vals
class ErrorBoundSampler(RaySampler):
def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra,
eps, beta_iters, max_total_iters,
inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=0.0):
super().__init__(near, 2.0 * scene_bounding_sphere)
self.N_samples = N_samples
self.N_samples_eval = N_samples_eval
self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=inverse_sphere_bg)
self.N_samples_extra = N_samples_extra
self.eps = eps
self.beta_iters = beta_iters
self.max_total_iters = max_total_iters
self.scene_bounding_sphere = scene_bounding_sphere
self.add_tiny = add_tiny
self.inverse_sphere_bg = inverse_sphere_bg
if inverse_sphere_bg:
self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0)
def get_z_vals(self, ray_dirs, cam_loc, model):
beta0 = model.density.get_beta().detach()
# Start with uniform sampling
z_vals = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model)
samples, samples_idx = z_vals, None
# Get maximum beta from the upper bound (Lemma 2)
dists = z_vals[:, 1:] - z_vals[:, :-1]
bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1)
beta = torch.sqrt(bound)
# beta = torch.sqrt(bound).clone()
total_iters, not_converge = 0, True
# Algorithm 1
while not_converge and total_iters < self.max_total_iters:
points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1)
points_flat = points.reshape(-1, 3)
# Calculate the SDF only for the newly sampled points
with torch.no_grad():
samples_sdf = model.implicit_network.get_sdf_vals(points_flat)
if samples_idx is not None:
sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]),
samples_sdf.reshape(-1, samples.shape[1])], -1)
sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1)
else:
sdf = samples_sdf
# Calculating the bound d* (Theorem 1)
d = sdf.reshape(z_vals.shape)
dists = z_vals[:, 1:] - z_vals[:, :-1]
a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs()
first_cond = a.pow(2) + b.pow(2) <= c.pow(2)
second_cond = a.pow(2) + c.pow(2) <= b.pow(2)
# d_star = torch.zeros(z_vals.shape[0], z_vals.shape[1] - 1, device=z_vals.device)
# d_star[first_cond] = b[first_cond]
# d_star[second_cond] = c[second_cond]
s = (a + b + c) / 2.0
area_before_sqrt = s * (s - a) * (s - b) * (s - c)
mask = ~first_cond & ~second_cond & (b + c - a > 0)
# d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask])
# Optimization: multiplication is 5-20 times faster than indexing
first_cond = first_cond & ~second_cond
d_star = first_cond * b + second_cond * c + torch.nan_to_num((2.0 * torch.sqrt(area_before_sqrt)) / a) * mask
d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star # Fixing the sign
# Updating beta using line search
curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star)
# beta[curr_error <= self.eps] = beta0
# Optimization: multiplication is 5-20 times faster than indexing
beta0_mask = curr_error <= self.eps
beta = beta * ~beta0_mask + beta0 * beta0_mask
beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta
for j in range(self.beta_iters):
beta_mid = (beta_min + beta_max) / 2.
curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star)
# beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps]
# beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps]
beta_mid_mask = curr_error <= self.eps
beta_max = beta_max * ~beta_mid_mask + beta_mid * beta_mid_mask
beta_min = beta_min * beta_mid_mask + beta_mid * ~beta_mid_mask
beta = beta_max
# Upsample more points
# tmp0 = beta.unsqueeze(-1).clone()
# tmp0 = beta.unsqueeze(-1)
# density = model.density(sdf.reshape(z_vals.shape), beta=tmp0)
density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1))
# dists = torch.cat([dists, torch.tensor([1e10], device=dists.device).unsqueeze(0).repeat(dists.shape[0], 1)], -1)
dists = torch.cat([dists, torch.full([dists.shape[0], 1], 1e10, device=dists.device)], -1)
free_energy = dists * density
shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=free_energy.device), free_energy[:, :-1]], dim=-1)
alpha = 1 - torch.exp(-free_energy)
transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))
weights = alpha * transmittance # probability that the ray hits something here
# Check if we are done and this is the last sampling
total_iters += 1
not_converge = beta.max() > beta0
if not_converge and total_iters < self.max_total_iters:
''' Sample more points proportional to the current error bound'''
N = self.N_samples_eval
bins = z_vals
error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:,:-1] ** 2.) / (4 * beta.unsqueeze(-1) ** 2)
# tmp0 = beta.unsqueeze(-1).clone()
# tmp1 = -d_star / tmp0
# tmp2 = dists[:,:-1] ** 2.
# tmp3 = 4 * tmp0 ** 2
# error_per_section = tmp1 * tmp2 / tmp3
error_integral = torch.cumsum(error_per_section, dim=-1)
bound_opacity = (torch.clamp(torch.exp(error_integral),max=1.e6) - 1.0) * transmittance[:,:-1]
pdf = bound_opacity + self.add_tiny
pdf = pdf / torch.sum(pdf, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
else:
''' Sample the final sample set to be used in the volume rendering integral '''
N = self.N_samples
bins = z_vals
pdf = weights[..., :-1]
pdf = pdf + 1e-5 # prevent nans
pdf = pdf / torch.sum(pdf, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins))
# Invert CDF
if (not_converge and total_iters < self.max_total_iters) or (not model.training):
u = torch.linspace(0., 1., steps=N, device=cdf.device).unsqueeze(0).repeat(cdf.shape[0], 1)
else:
u = torch.rand(list(cdf.shape[:-1]) + [N], device=cdf.device)
u = u.contiguous()
inds = torch.searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = (cdf_g[..., 1] - cdf_g[..., 0])
denom_mask = denom < 1e-5
denom = denom_mask + ~denom_mask * denom
# denom = torch.where(denom < 1e-5, 1.0, denom)
t = (u - cdf_g[..., 0]) / denom
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
# Add more samples if we have not converged
if not_converge and total_iters < self.max_total_iters:
z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1)
z_samples = samples
near, far = self.near * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device), self.far * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device)
if self.inverse_sphere_bg: # if inverse sphere then need to add the far sphere intersection
far = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:,1:]
if self.N_samples_extra > 0:
if model.training:
sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra]
else:
sampling_idx = torch.linspace(0, z_vals.shape[1]-1, self.N_samples_extra).long()
z_vals_extra = torch.cat([near, far, z_vals[:,sampling_idx]], -1)
else:
z_vals_extra = torch.cat([near, far], -1)
z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1)
# add some of the near surface points
idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],), device=z_vals.device)
z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1))
if self.inverse_sphere_bg:
z_vals_inverse_sphere = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model)
z_vals_inverse_sphere = z_vals_inverse_sphere * (1./self.scene_bounding_sphere)
z_vals = (z_vals, z_vals_inverse_sphere)
return z_vals, z_samples_eik
def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):
density = model.density(sdf.reshape(z_vals.shape), beta=beta)
shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=dists.device), dists * density[:, :-1]], dim=-1)
integral_estimation = torch.cumsum(shifted_free_energy, dim=-1)
error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2)
error_integral = torch.cumsum(error_per_section, dim=-1)
bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1])
return bound_opacity.max(-1)[0]
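The CDF inversion at the heart of `get_z_vals` (the `searchsorted`/`gather` block) can be exercised in isolation. A minimal sketch with made-up weights and bin edges standing in for the sampler's `z_vals` and error-bound PDF:

```python
import torch

torch.manual_seed(0)
near, far, n_bins, n_new = 0.0, 4.0, 64, 32

bins = torch.linspace(near, far, n_bins + 1).unsqueeze(0)   # (1, n_bins+1) depth bin edges
weights = torch.rand(1, n_bins)                             # stand-in for per-interval weights

pdf = weights + 1e-5                                        # prevent zero-probability bins
pdf = pdf / pdf.sum(-1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (1, n_bins+1), starts at 0

u = torch.rand(1, n_new)                                    # uniform variates to invert
inds = torch.searchsorted(cdf, u, right=True)
below = (inds - 1).clamp(min=0)
above = inds.clamp(max=cdf.shape[-1] - 1)

cdf_lo, cdf_hi = torch.gather(cdf, 1, below), torch.gather(cdf, 1, above)
bin_lo, bin_hi = torch.gather(bins, 1, below), torch.gather(bins, 1, above)

denom = (cdf_hi - cdf_lo).clamp(min=1e-5)                   # guard against empty bins
t = (u - cdf_lo) / denom
samples = bin_lo + t * (bin_hi - bin_lo)                    # depths distributed ∝ weights
```

The new depths land in bins proportionally to their weight, which is exactly how the sampler concentrates samples where the error bound (or final rendering weight) is largest.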
================================================
FILE: model/rendering/__init__.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import utils
import cv2
from .brdf import *
class RenderingLayer(nn.Module):
def __init__(self, spp, split_n_pixels, preserve_light=True) -> None:
super().__init__()
self.spp = spp
self.split_n_pixels = split_n_pixels
self.preserve_light = preserve_light
def forward(
self,
model,
surface_points,
view_direction,
Kd,
Ks,
normal,
rough,
radiance_scale=None,
intersect_func=None
):
"""
Render according to material, normal and lighting conditions
Params:
model: NeRF model to predict radiance
surface_points, view_direction, albedo, normal, rough, metal: (bn, c)
"""
bn = normal.size(0)
cx, cy, cz = create_frame(normal)
wi_x = torch.sum(cx*view_direction, dim=1)
wi_y = torch.sum(cy*view_direction, dim=1)
wi_z = torch.sum(cz*view_direction, dim=1)
wi = torch.stack([wi_x, wi_y, wi_z], dim=1)
wi_mask = (wi[:,2] >= 0.00001)
wi[:,2,...] = torch.where(wi[:,2,...] < 0.00001, torch.ones_like(wi[:,2,...]) * 0.00001, wi[:,2,...])
wi = F.normalize(wi, dim=1, eps=1e-6)
# wi_mask = torch.where(wi[:,2:3,...] < 0, torch.zeros_like(wi[:,2:3,...]), torch.ones_like(wi[:,2:3,...]))
wi = wi.unsqueeze(1) # (bn, 1, 3)
# with torch.no_grad():
if True:
samples = torch.rand(bn, self.spp, 3, device=normal.device)
pS = probabilityToSampleSpecular(Kd, Ks)
clamp_value = 0.0
pS.clamp_min_(clamp_value)
sample_diffuse = samples[:,:,0] >= pS
ls_diffuse = square_to_cosine_hemisphere(samples[:,:,1:])
ls_specular = sample_ggx_specular(samples[:,:,1:], rough, wi)
wo = torch.where(sample_diffuse.unsqueeze(2).expand(bn, self.spp, 3), ls_diffuse, ls_specular) # (bn, spp, 3)
pdfs = pdf_ggx(Kd, Ks, rough, wi, wo, clamp_value).unsqueeze(2)
eval_diff, eval_spec, wo_mask = eval_ggx(Kd, Ks, rough, wi, wo)
# wo_mask = torch.all(wo_mask, dim=1)
direction = to_global(wo, cx.unsqueeze(1), cy.unsqueeze(1), cz.unsqueeze(1))
# surface_points = surface_points + 0.01 * view_direction # prevent self-intersection
surface_points = surface_points.unsqueeze(1).expand_as(direction).reshape(-1, 3)
direction = direction.reshape(-1, 3)
surface_points = surface_points + direction * 0.01 # prevent self-intersection
pts_splits = torch.split(surface_points, self.split_n_pixels, dim=0)
dirs_splits = torch.split(direction, self.split_n_pixels, dim=0)
radiance = []
# with torch.no_grad():
for pts, dirs in zip(pts_splits, dirs_splits):
radiance.append(model.get_incident_radiance(pts, dirs, intersect_func))
radiance = torch.cat(radiance, dim=0)
radiance = radiance.view(bn, self.spp, 3)
if radiance_scale is not None:
radiance = radiance * radiance_scale[None,None,:]
pdfs = torch.clamp(pdfs, min=0.00001)
ndl = torch.clamp(wo[:,:,2:], min=0)
brdfDiffuse = eval_diff.expand(bn, self.spp, 3) * ndl / pdfs
colorDiffuse = torch.mean(brdfDiffuse * radiance, dim=1)
brdfSpec = eval_spec.expand(bn, self.spp, 3) * ndl / pdfs
colorSpec = torch.mean(brdfSpec * radiance, dim=1)
return colorDiffuse, colorSpec, wi_mask
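The diffuse branch of `forward` is a Monte Carlo estimate of the reflection integral ∫ (Kd/π) · L · cosθ dω. Under constant incident radiance and cosine-weighted sampling, the cosθ/pdf factor cancels to π and the estimator recovers Kd·L with zero variance. A small standalone sanity sketch (the albedo and radiance values here are made up for illustration):

```python
import math
import torch

torch.manual_seed(0)
spp = 4096
u = torch.rand(spp, 2)
cos_theta = torch.sqrt(torch.clamp(1.0 - u[:, 1], min=1e-8))  # cosine-weighted hemisphere
pdf = cos_theta / math.pi                                      # pdf of that sampler

Kd = torch.tensor([0.6, 0.4, 0.2])  # hypothetical diffuse albedo
L = 2.0                             # constant incident radiance
brdf = Kd / math.pi                 # Lambertian BRDF

# Same estimator shape as colorDiffuse: mean over spp of brdf * ndl / pdf * radiance
estimate = (brdf[None, :] * L * cos_theta[:, None] / pdf[:, None]).mean(0)
# cos_theta / pdf == pi for every sample, so the estimate equals Kd * L
```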
================================================
FILE: model/rendering/brdf.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
def create_frame(n: torch.Tensor, eps:float = 1e-6):
"""
Generate orthonormal coordinate system based on surface normal
[Duff et al. 17] Building An Orthonormal Basis, Revisited. JCGT. 2017.
:param: n (bn, 3, ...)
"""
z = F.normalize(n, dim=1, eps=eps)
sgn = torch.where(z[:,2,...] >= 0, 1.0, -1.0)
a = -1.0 / (sgn + z[:,2,...])
b = z[:,0,...] * z[:,1,...] * a
x = torch.stack([1.0 + sgn * z[:,0,...] * z[:,0,...] * a, sgn * b, -sgn * z[:,0,...]], dim=1)
y = torch.stack([b, sgn + z[:,1,...] * z[:,1,...] * a, -z[:,1,...]], dim=1)
return x, y, z
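A quick way to sanity-check `create_frame` is to verify that the returned basis is orthonormal for random normals. A standalone sketch that mirrors the `(bn, 3)` case of the function above:

```python
import torch
import torch.nn.functional as F

def create_frame(n, eps=1e-6):
    # mirrors the (bn, 3) case of brdf.create_frame (Duff et al. 2017)
    z = F.normalize(n, dim=1, eps=eps)
    sgn = torch.where(z[:, 2] >= 0, 1.0, -1.0)
    a = -1.0 / (sgn + z[:, 2])
    b = z[:, 0] * z[:, 1] * a
    x = torch.stack([1.0 + sgn * z[:, 0] * z[:, 0] * a, sgn * b, -sgn * z[:, 0]], dim=1)
    y = torch.stack([b, sgn + z[:, 1] * z[:, 1] * a, -z[:, 1]], dim=1)
    return x, y, z

torch.manual_seed(0)
n = F.normalize(torch.randn(16, 3), dim=1)
x, y, z = create_frame(n)
# each axis is unit length and the axes are pairwise orthogonal
```

The sign trick (`sgn`) keeps the denominator `sgn + z[:, 2]` away from zero, which is what makes this construction stable for normals pointing anywhere on the sphere.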
def get_rendering_parameters(albedo_raw, rough_raw, use_metallic):
if use_metallic:
assert albedo_raw.size(-1) == 3 and rough_raw.size(-1) == 2
metal = rough_raw[:,1:]
rough = rough_raw[:,:1].clamp_min(0.01)
Ks = baseColorToSpecularF0(albedo_raw, metal)
Kd = albedo_raw * (1 - metal)
else:
assert albedo_raw.size(-1) == 6 and rough_raw.size(-1) == 1
Kd = albedo_raw[:,:3]
Ks = albedo_raw[:,3:].clamp_min(0.04)
rough = rough_raw.clamp_min(0.01)
return Kd, Ks, rough
def to_global(d, x, y, z):
"""
d, x, y, z: (*, 3)
"""
return d[...,0:1] * x + d[...,1:2] * y + d[...,2:3] * z
def sqrt_(x: torch.Tensor, eps=1e-8) -> torch.Tensor:
"""
Clamp near-zero sqrt inputs to avoid NaN gradients
"""
return torch.sqrt(torch.clamp(x, min=eps))
def reflect(v: torch.Tensor, h: torch.Tensor):
dot = torch.sum(v*h, dim=2, keepdim=True)
return 2 * dot * h - v
def square_to_cosine_hemisphere(sample: torch.Tensor):
u, v = sample[:,:,0,...], sample[:,:,1,...]
phi = u * 2 * np.pi
r = sqrt_(v)
cos_theta = sqrt_(torch.clamp(1 - v, 0))
return torch.stack([torch.cos(phi) * r, torch.sin(phi) * r, cos_theta], dim=2)
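`square_to_cosine_hemisphere` maps unit-square samples to cosine-weighted directions on the upper hemisphere. A standalone sketch for a flat `(n, 2)` batch, asserting the two properties the renderer relies on (unit length, non-negative z):

```python
import math
import torch

def square_to_cosine_hemisphere(uv):
    # flat (n, 2) variant of brdf.square_to_cosine_hemisphere
    u, v = uv[:, 0], uv[:, 1]
    phi = u * 2 * math.pi
    r = torch.sqrt(torch.clamp(v, min=1e-8))
    cos_theta = torch.sqrt(torch.clamp(1 - v, min=1e-8))
    return torch.stack([torch.cos(phi) * r, torch.sin(phi) * r, cos_theta], dim=1)

torch.manual_seed(0)
d = square_to_cosine_hemisphere(torch.rand(4096, 2))
# r^2 + cos_theta^2 = v + (1 - v) = 1, so every direction is (near-)unit length
```

Since cos θ = √(1 − v) with v uniform, the sample density is proportional to cos θ, which is exactly the pdf assumed by `pdf_ggx`'s diffuse term (cos θ / π).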
def get_cos_theta(v: torch.Tensor):
return v[:,:,2,...]
def get_phi(v: torch.Tensor):
cos_theta = torch.clamp(v[:,:,2,...], min=0, max=1)
sin_theta = torch.clamp(sqrt_(1 - cos_theta*cos_theta), min=1e-8)
cos_phi = torch.clamp(v[:,:,0,...] / sin_theta, -1, 1)
sin_phi = v[:,:,1,...] / sin_theta
phi = torch.acos(cos_phi) # (0, pi)
return torch.where(sin_phi > 0, phi, 2*np.pi - phi)
def sample_disney_specular(sample: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor):
"""
:param: sample (bn, spp, 3, h, w)
:param: roughness (bn, 1, 1, h, w)
:param: wi (*, *, 3, h, w), supposed to be normalized
:return: wo (bn, spp, 3, h, w)
"""
# a = torch.clamp(roughness, 0.001)
a = roughness
u, v = sample[:,:,0,...], sample[:,:,1,...]
phi = u * 2 * np.pi
cos_theta = sqrt_((1 - v) / (1 + (a*a - 1) * v))
sin_theta = sqrt_(1 - cos_theta*cos_theta)
cos_phi = torch.cos(phi)
sin_phi = torch.sin(phi)
half = torch.stack([sin_theta*cos_phi, sin_theta*sin_phi, cos_theta], dim=2)
wo = F.normalize(reflect(wi.expand_as(half), half), dim=2, eps=1e-8)
return wo
def GTR2(ndh, a):
a2 = a*a
t = 1.0 + (a2 - 1.0) * ndh * ndh
return a2 / (np.pi * t * t)
def SchlickFresnel(u):
m = torch.clamp(1.0 - u, 0, 1)
return m**5
def smithG_GGX(ndv, a):
a = a*a
b = ndv*ndv
return 1.0 / (ndv + sqrt_(a + b - a * b))
def pdf_disney(roughness: torch.Tensor, metallic: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor):
"""
:param: roughness/metallic (bn, 1, h, w)
:param: wi (*, *, 3, h, w), supposed to be normalized
:param: wo (*, *, 3, h, w), supposed to be normalized
"""
# specularAlpha = torch.clamp(roughness, 0.001)
specularAlpha = roughness
diffuseRatio = 0.5 * (1 - metallic)
specularRatio = 1 - diffuseRatio
half = F.normalize(wi + wo, dim=2, eps=1e-8)
cosTheta = torch.abs(half[:,:,2,...])
pdfGTR2 = GTR2(cosTheta, specularAlpha) * cosTheta
pdfSpec = pdfGTR2 / torch.clamp(4.0 * torch.abs(torch.sum(wo*half, dim=2)), min=1e-8)
pdfDiff = torch.abs(wo[:,:,2,...]) / np.pi
pdf = diffuseRatio * pdfDiff + specularRatio * pdfSpec
pdf = torch.where(wi[:,:,2,...] < 0.0001, torch.ones_like(pdf) * 0.0001, pdf)
pdf = torch.where(wo[:,:,2,...] < 0.0001, torch.ones_like(pdf) * 0.0001, pdf)
return pdf
def eval_disney(albedo: torch.Tensor, roughness: torch.Tensor, metallic: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor):
"""
:param: albedo/roughness/metallic (bn, c, h, w)
:param: wi (*, *, 3, h, w), supposed to be normalized
:param: wo (*, *, 3, h, w), supposed to be normalized
"""
h = wi + wo
h = F.normalize(h, dim=2, eps=1e-8)
CSpec0 = torch.lerp(torch.ones_like(albedo)*0.04, albedo, metallic).unsqueeze(1)
ldh = torch.clamp(torch.sum(wo * h, dim=2), 0, 1).unsqueeze(2)
ndv = wi[:,:,2:3,...]
ndl = wo[:,:,2:3,...]
ndh = h[:,:,2:3,...]
FL, FV = SchlickFresnel(ndl), SchlickFresnel(ndv)
roughness = roughness.unsqueeze(1)
Fd90 = 0.5 + 2.0 * ldh * ldh * roughness
Fd = torch.lerp(torch.ones_like(Fd90), Fd90, FL) * torch.lerp(torch.ones_like(Fd90), Fd90, FV)
Ds = GTR2(ndh, roughness)
FH = SchlickFresnel(ldh)
Fs = torch.lerp(CSpec0, torch.ones_like(CSpec0), FH)
roughg = (roughness * 0.5 + 0.5) ** 2
Gs1, Gs2 = smithG_GGX(ndl, roughg), smithG_GGX(ndv, roughg)
Gs = Gs1 * Gs2
eval_diff = Fd * albedo.unsqueeze(1) * (1.0 - metallic.unsqueeze(1)) / np.pi
eval_spec = Gs * Fs * Ds
mask = torch.where(ndl < 0, torch.zeros_like(ndl), torch.ones_like(ndl))
return eval_diff, eval_spec, mask
def F_Schlick(SpecularColor, VoH):
Fc = (1 - VoH)**5
return torch.clamp(50.0 * SpecularColor[:,:,1:2,...], min=0, max=1) * Fc + (1 - Fc) * SpecularColor
def GetSpecularEventProbability(SpecularColor, NoV) -> torch.Tensor:
f = F_Schlick(SpecularColor, NoV)
return (f[:,:,0,...] + f[:,:,1,...] + f[:,:,2,...]) / 3
def baseColorToSpecularF0(baseColor, metalness):
return torch.lerp(torch.empty_like(baseColor).fill_(0.04), baseColor, metalness)
def luminance(color):
if color.size(1) == 1:
return color
# return color.mean(dim=1, keepdim=True)
return color[:,0:1,...] * 0.212671 + color[:,1:2,...] * 0.715160 + color[:,2:3,...] * 0.072169
def probabilityToSampleSpecular(difColor, specColor) -> torch.Tensor:
lumDiffuse = torch.clamp(luminance(difColor), min=0.01)
lumSpecular = torch.clamp(luminance(specColor), min=0.01)
return lumSpecular / (lumDiffuse + lumSpecular)
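For a mid-grey diffuse albedo and a dielectric-level specular F0 of 0.04, the lobe-selection probability works out to 0.04 / (0.5 + 0.04) ≈ 0.074. A standalone sketch (the albedo values are illustrative):

```python
import torch

def luminance(color):
    # Rec. 709 luminance weights, as in brdf.luminance
    return color[:, 0:1] * 0.212671 + color[:, 1:2] * 0.715160 + color[:, 2:3] * 0.072169

def probability_to_sample_specular(kd, ks):
    lum_d = torch.clamp(luminance(kd), min=0.01)
    lum_s = torch.clamp(luminance(ks), min=0.01)
    return lum_s / (lum_d + lum_s)

kd = torch.full((1, 3), 0.5)   # mid-grey diffuse albedo
ks = torch.full((1, 3), 0.04)  # dielectric specular F0
p = probability_to_sample_specular(kd, ks)  # ≈ 0.04 / 0.54 ≈ 0.074
```

Weighting the lobe choice by relative luminance keeps most samples in whichever lobe actually carries the energy, which lowers the variance of the combined estimator.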
def shadowedF90(F0):
t = 1 / 0.04
return torch.clamp(t * luminance(F0), max=1)
def evalFresnel(f0, f90, NdotS):
# print(f0.shape, f90.shape, NdotS.shape)
return f0 + (f90 - f0) * (1 - NdotS)**5
def Smith_G1_GGX(alphaSquared, NdotSSquared):
return 2 / (sqrt_(((alphaSquared * (1 - NdotSSquared)) + NdotSSquared) / NdotSSquared) + 1)
def Smith_G2_GGX(alphaSquared, NdotL, NdotV):
a = NdotV * sqrt_(alphaSquared + NdotL * (NdotL - alphaSquared * NdotL))
b = NdotL * sqrt_(alphaSquared + NdotV * (NdotV - alphaSquared * NdotV))
return 0.5 / (a + b)
def GGX_D(alphaSquared, NdotH):
b = ((alphaSquared - 1) * NdotH * NdotH + 1)
return alphaSquared / (np.pi * b * b)
def pdf_ggx(Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor, ps_min=0.0):
"""
:param: Kd/Ks (bn, 3, h, w)
:param: roughness (bn, 1, h, w)
:param: wi (*, *, 3, h, w), supposed to be normalized
:param: wo (*, *, 3, h, w), supposed to be normalized
:return: pdf (*, *, h, w)
"""
alpha = roughness * roughness
alphaSquared = alpha * alpha
NdotV = wi[:,:,2,...]
h = F.normalize(wi + wo, dim=2, eps=1e-8)
NdotH = h[:,:,2,...]
# print(alphaSquared.min(), NdotH.min(), NdotV.min())
ggxd = GGX_D(torch.clamp(alphaSquared, min=0.00001), NdotH)
smith = Smith_G1_GGX(alphaSquared, NdotV * NdotV)
# pdf_spec = GGX_D(torch.clamp(alphaSquared, min=0.00001), NdotH) * Smith_G1_GGX(alphaSquared, NdotV * NdotV) / (4 * NdotV)
pdf_spec = ggxd * smith / (4 * NdotV)
# print(torch.any(torch.isnan(ggxd)), torch.any(torch.isnan(smith)), torch.any(torch.isnan(NdotV)))
# print(NdotV.min(), ggxd.min(), smith.min())
with torch.no_grad():
pS = probabilityToSampleSpecular(Kd, Ks).clamp_min(ps_min)
pdf_diff = wo[:,:,2,...] / np.pi
# print("#########################################")
# print("#########################################")
# print("#########################################")
# print(torch.any(torch.isnan(kS)), torch.any(torch.isnan(pdf_spec)), torch.any(torch.isnan(pdf_diff)))
# print("#########################################")
# print("#########################################")
# print("#########################################")
pdf = pS * pdf_spec + (1 - pS) * pdf_diff
pdf = torch.where(wi[:,:,2,...] <= 0.0001, torch.ones_like(pdf) * 0.0001, pdf)
pdf = torch.where(wo[:,:,2,...] <= 0.0001, torch.ones_like(pdf) * 0.0001, pdf)
return pdf
def eval_ggx(Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor):
"""
:param: Kd/Ks (bn, c, h, w)
:param: roughness (bn, 1, h, w)
:param: wi (*, *, c, h, w), supposed to be normalized
:param: wo (*, *, c, h, w), supposed to be normalized
:return: fr(wi, wo) (*, *, c, h, w)
"""
NDotL = wo[:,:,2:3,...]
NDotV = wi[:,:,2:3,...]
H = F.normalize(wi + wo, dim=2, eps=1e-8)
NDotH = H[:,:,2:3,...]
LDotH = torch.sum(wo*H, dim=2, keepdim=True)
roughness = roughness.unsqueeze(1)
alpha = roughness * roughness
alpha2 = alpha * alpha
D = GGX_D(torch.clamp(alpha2, min=0.00001), NDotH)
G2 = Smith_G2_GGX(alpha2, NDotL, NDotV)
f = evalFresnel(Ks.unsqueeze(1), shadowedF90(Ks).unsqueeze(1), LDotH)
# spec = torch.where(NDotL <= 0, torch.zeros_like(NDotL), f * G2 * D)
# mask = torch.where(NDotL <= 0, torch.zeros_like(NDotL), torch.ones_like(NDotL))
spec = torch.where(NDotL < 0.0001, torch.ones_like(NDotL) * 0.0001, f * G2 * D)
# mask = torch.where(NDotL <= 0, torch.zeros_like(NDotL), torch.ones_like(NDotL))
mask = (NDotL >= 0.0001).squeeze(-1)
return Kd.unsqueeze(1) / np.pi, spec, mask
def sample_weight_ggx(alphaSquared, NdotL, NdotV):
G1V = Smith_G1_GGX(alphaSquared, NdotV*NdotV)
G1L = Smith_G1_GGX(alphaSquared, NdotL*NdotL)
return G1L / (G1V + G1L - G1V * G1L)
def sample_ggx(sample: torch.Tensor, Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor):
"""
:param: sample (bn, spp, 3, h, w)
:param: roughness (bn, 1, h, w)
:param: wi (*, *, 3, h, w), supposed to be normalized
:return: wo (bn, spp, 3, h, w), weight (bn, spp, 3, h, w)
"""
with torch.no_grad():
pS = probabilityToSampleSpecular(Kd, Ks)
sample_diffuse = sample[:,:,2,...] >= pS
wo_diff = square_to_cosine_hemisphere(sample[:,:,1:,...])
weight_diff = Kd / (1 - pS)
weight_diff = weight_diff.unsqueeze(1)
roughness = roughness.unsqueeze(1)
alpha = roughness * roughness
# alpha = roughness
Vh = F.normalize(torch.cat([alpha * wi[:,:,0:1,...], alpha * wi[:,:,1:2,...], wi[:,:,2:3,...]], dim=2), dim=2, eps=1e-8)
lensq = Vh[:,:,0:1,...]**2 + Vh[:,:,1:2,...]**2
zero_ = torch.zeros_like(Vh[:,:,0,...])
one_ = torch.ones_like(Vh[:,:,0,...])
T1 = torch.where(
lensq > 0,
torch.stack([-Vh[:,:,1,...], Vh[:,:,0,...], zero_], dim=2) / sqrt_(lensq),
torch.stack([one_, zero_, zero_], dim=2)
)
T2 = torch.cross(Vh, T1, dim=2)
r = sqrt_(sample[:,:,0:1,...])
phi = 2 * np.pi * sample[:,:,1:2,...]
t1 = r * torch.cos(phi)
t2 = r * torch.sin(phi)
s = 0.5 * (1 + Vh[:,:,2:3,...])
t2 = torch.lerp(sqrt_(1 - t1**2), t2, s)
Nh = t1 * T1 + t2 * T2 + sqrt_(torch.clamp(1 - t1*t1 - t2*t2, min=0)) * Vh
h = F.normalize(torch.cat([alpha * Nh[:,:,0:1,...], alpha * Nh[:,:,1:2,...], torch.clamp(Nh[:,:,2:3,...], min=0)], dim=2), dim=2, eps=1e-8)
wo = reflect(wi, h)
HdotL = torch.clamp(torch.sum(h*wo, dim=2, keepdim=True), min=0.0001, max=1.0)
NdotL = torch.clamp(wo[:,:,2:3,...], min=0.0001, max=1.0)
NdotV = torch.clamp(wi[:,:,2:3,...], min=0.0001, max=1.0)
# NdotH = torch.clamp(h[:,:,2:3,...], min=0.00001, max=1.0)
# F = evalFresnel(specularF0, shadowedF90(specularF0), HdotL)
weight = evalFresnel(Ks, shadowedF90(Ks), HdotL) * sample_weight_ggx(alpha*alpha, NdotL, NdotV) / pS.unsqueeze(1)
wo = torch.where(sample_diffuse.unsqueeze(2), wo_diff, wo)
weight = torch.where(sample_diffuse.unsqueeze(2), weight_diff, weight)
return wo, weight
def sample_ggx_specular(sample: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor):
"""
:param: sample (bn, spp, 2, h, w)
:param: roughness (bn, 1, h, w)
:param: wi (*, *, 3, h, w), supposed to be normalized
:return: wo (bn, spp, 3, h, w)
"""
roughness = roughness.unsqueeze(1)
alpha = roughness * roughness
# alpha = roughness
Vh = F.normalize(torch.cat([alpha * wi[:,:,0:1,...], alpha * wi[:,:,1:2,...], wi[:,:,2:3,...]], dim=2), dim=2, eps=1e-8)
# bn, spp, _, row, col = Vh.shape
# Vh = Vh.view(-1, 3, row, col)
# T1, T2, Vh = utils.hughes_moeller(Vh)
# T1 = T1.view(bn, spp, 3, row, col)
# T2 = T2.view(bn, spp, 3, row, col)
# Vh = Vh.view(bn, spp, 3, row, col)
lensq = Vh[:,:,0:1,...]**2 + Vh[:,:,1:2,...]**2
zero_ = torch.zeros_like(Vh[:,:,0,...])
one_ = torch.ones_like(Vh[:,:,0,...])
T1 = torch.where(
lensq > 0,
torch.stack([-Vh[:,:,1,...], Vh[:,:,0,...], zero_], dim=2) / sqrt_(lensq),
torch.stack([one_, zero_, zero_], dim=2)
)
T2 = torch.cross(Vh, T1, dim=2)
r = sqrt_(sample[:,:,0:1,...])
phi = 2 * np.pi * sample[:,:,1:2,...]
t1 = r * torch.cos(phi)
t2 = r * torch.sin(phi)
s = 0.5 * (1 + Vh[:,:,2:3,...])
t2 = torch.lerp(sqrt_(1 - t1**2), t2, s)
Nh = t1 * T1 + t2 * T2 + sqrt_(torch.clamp(1 - t1*t1 - t2*t2, min=0)) * Vh
h = F.normalize(torch.cat([alpha * Nh[:,:,0:1,...], alpha * Nh[:,:,1:2,...], torch.clamp(Nh[:,:,2:3,...], min=0)], dim=2), dim=2, eps=1e-8)
wo = reflect(wi, h)
return wo
================================================
FILE: model/trainer/__init__.py
================================================
from .recon import ReconstructionTrainer
================================================
FILE: model/trainer/recon.py
================================================
import math
import torch
import pytorch_lightning as pl
import numpy as np
import torch.optim as optim
import os
from torch.utils.data import DataLoader
import utils
from utils import rend_util
import utils.plots as plt
import dataset
import model
from tqdm import trange
from pytorch_lightning.callbacks import RichProgressBar
from torchmetrics.functional import structural_similarity_index_measure as ssim
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
import cv2
lpips = LPIPS()
class ReconstructionTrainer(pl.LightningModule):
def __init__(self, conf, prog_bar: RichProgressBar, exp_dir, model_only=False, val_mesh=False, is_val=False, **kwargs) -> None:
super().__init__()
self.conf = conf
self.prog_bar = prog_bar
self.batch_size = conf.train.batch_size
self.bubble_batch_size = getattr(conf.train, 'bubble_batch_size', self.batch_size)
self.expdir = exp_dir
self.val_mesh = val_mesh
conf_model = conf.model
use_normal = (getattr(conf.loss, 'normal_weight', 0) > 0) or (getattr(conf.loss, 'angular_weight', 0) > 0)
conf_model.use_normal = use_normal
self.model = model.I2SDFNetwork(conf_model)
if model_only:
return
print('[INFO] Loading data ...')
dataset_conf = conf.dataset
self.scan_id = dataset_conf.scan_id
self.train_dataset = dataset.ReconDataset(
**dataset_conf,
use_mask=getattr(conf.loss, 'mask_weight', 0) > 0,
use_depth=getattr(conf.loss, 'depth_weight', 0) > 0,
use_normal=use_normal,
use_bubble=getattr(conf.loss, 'bubble_weight', 0) > 0,
use_lightmask=getattr(conf.loss, 'light_mask_weight', 0) > 0
)
if self.train_dataset.use_bubble:
os.makedirs(os.path.join(self.expdir, 'hotmap'), exist_ok=True)
os.makedirs(os.path.join(self.expdir, 'countmap'), exist_ok=True)
self.pdf_criterion = getattr(conf.train, 'pdf_criterion', 'DEPTH')
assert self.pdf_criterion in ['RGB', 'DEPTH']
self.is_hdr = self.train_dataset.is_hdr
self.plots_dir = os.path.join(self.expdir, 'plots')
self.trace_bub_idx = self.conf.train.get('trace_bub_idx', -1)
if self.trace_bub_idx != -1:
os.makedirs(f"{self.plots_dir}/bubble", exist_ok=True)
print(f"[INFO] Activate hotmap visualization for #{self.trace_bub_idx}")
self.plot_dataset = dataset.PlotDataset(**dataset_conf, indices=[self.trace_bub_idx], plot_nimgs=1, is_val=is_val)
else:
data = {
'intrinsics': self.train_dataset.intrinsics_all,
'pose': self.train_dataset.pose_all,
'rgb': self.train_dataset.rgb_images,
'img_res': self.train_dataset.img_res
}
if self.train_dataset.use_lightmask:
data['light_mask'] = self.train_dataset.lightmask_images
self.plot_dataset = dataset.PlotDataset(**dataset_conf, data=data, plot_nimgs=conf.plot.plot_nimgs, is_val=is_val)
os.makedirs(self.plots_dir, exist_ok=True)
with open(f"{self.expdir}/config.yml", 'w') as f:
f.write(self.conf.dump())
if self.train_dataset.use_bubble:
points = self.train_dataset.pointcloud
index = torch.randperm(points.size(0))[:200000]
points = points[index,:]
plt.visualize_pointcloud(points, f"{self.expdir}/pointcloud.html")
print(f"[INFO] Pointcloud visualization success: saved to {self.expdir}/pointcloud.html")
self.pdf_prune = self.train_dataset.pdf_prune
self.pdf_max = self.train_dataset.pdf_max
# self.ds_len = len(self.train_dataset)
self.ds_len = self.train_dataset.n_images
print('[INFO] Finished loading data. Dataset size: {0}'.format(self.ds_len))
epoch_steps = len(self.train_dataset) / self.batch_size
self.nepochs = int(math.ceil(200000 / epoch_steps))
self.loss = model.I2SDFLoss(**conf.loss)
self.total_pixels = self.plot_dataset.total_pixels
self.img_res = self.plot_dataset.img_res
self.bubble_activated = False
self.uniform_bubble = getattr(self.conf.train, 'uniform_bubble', False)
if self.uniform_bubble:
print("[INFO] Ablation study: uniform sampling for bubble loss")
self.checkpoint_freq = self.conf.train.checkpoint_freq
self.split_n_pixels = self.conf.train.split_n_pixels
self.plot_conf = self.conf.plot
self.progbar_task = None
if self.train_dataset.use_lightmask and getattr(self.conf.train, 'flip_light', False):
self.train_dataset.lightmask_images = 1.0 - self.train_dataset.lightmask_images
self.plot_dataset.lightmask_images = 1.0 - self.plot_dataset.lightmask_images
def forward(self):
raise NotImplementedError("forward not supported by trainer")
def plot_hotmap(self, path):
assert self.bubble_activated
ds = self.train_dataset
hotmaps = torch.zeros(self.ds_len * ds.total_pixels)
hotmaps[ds.pixlinks] = self.pdf.cpu()
hotmaps = hotmaps.reshape(self.ds_len, *ds.img_res)
for i, hotmap in enumerate(hotmaps):
hotmap = hotmap.numpy()
# hotmap /= max(1e-4, hotmap.max())
hotmap = (hotmap * 255).astype(np.uint8)
hotmap = cv2.applyColorMap(hotmap, cv2.COLORMAP_MAGMA)
cv2.imwrite(os.path.join(path, "{:04d}.png".format(i)), hotmap)
if self.trace_bub_idx == i:
cv2.imwrite(os.path.join(f"{self.plots_dir}/bubble", f"{self.global_step}_hot.png"), hotmap)
def plot_countmap(self, path):
assert self.bubble_activated
ds = self.train_dataset
countmaps = torch.zeros(self.ds_len * ds.total_pixels)
countmaps[ds.pixlinks] = self.sample_count.cpu().float()
countmaps = countmaps.reshape(self.ds_len, *ds.img_res)
countmaps = countmaps / max(1, countmaps.max())
for i, countmap in enumerate(countmaps):
countmap = countmap.numpy()
countmap = (countmap * 255).astype(np.uint8)
countmap = cv2.applyColorMap(countmap, cv2.COLORMAP_MAGMA)
cv2.imwrite(os.path.join(path, "{:04d}.png".format(i)), countmap)
if self.trace_bub_idx == i:
cv2.imwrite(os.path.join(f"{self.plots_dir}/bubble", f"{self.global_step}_cnt.png"), countmap)
def update_pdf(self, value, idx):
assert self.bubble_activated
ds = self.train_dataset
value = value.to(self.pdf.device)
if self.pdf_max is not None:
value = value.clamp(max=self.pdf_max)
value[value < self.pdf_prune] = 0 # PDF pruning
link = ds.pointlinks[idx]
mask = (link != -1)
link = link[mask]
value = value[mask]
self.pdf[link] = value
def sample_bubble(self, batch_size):
assert self.bubble_activated
ds = self.train_dataset
if self.uniform_bubble:
sample_idx = torch.randperm(ds.pointcloud.size(0), device=ds.pointcloud.device)[:batch_size]
return ds.pointcloud[sample_idx,:]
sample_idx = torch.where(self.pdf > 0)[0]
pdf_samples = self.pdf[sample_idx]
pointcloud_samples = ds.pointcloud[sample_idx,:]
if sample_idx.size(0) >= (1 << 24):
print("[ERROR] PDF support size exceeds torch.multinomial's 2^24 category limit")
exit(1)
idx = torch.multinomial(pdf_samples, batch_size, replacement=False) # importance sampling
self.sample_count[sample_idx[idx]] += 1
return pointcloud_samples[idx,:]
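`sample_bubble` draws point-cloud samples without replacement, weighted by the per-point PDF, via `torch.multinomial`. A minimal pure-Python sketch of the same weighted draw-without-replacement (using the Efraimidis-Spirakis exponential-key trick; the function name is illustrative, not from this repo):

```python
import random

def weighted_sample_without_replacement(weights, k):
    # Efraimidis-Spirakis: draw key u**(1/w) per positive-weight item;
    # the k largest keys form a weighted sample without replacement,
    # mirroring torch.multinomial(weights, k, replacement=False)
    keyed = [(random.random() ** (1.0 / w), i)
             for i, w in enumerate(weights) if w > 0]  # zero-weight entries are never drawn
    keyed.sort(reverse=True)
    return [i for _, i in keyed[:k]]

pdf = [0.0, 0.5, 0.0, 0.2, 0.3]
picked = weighted_sample_without_replacement(pdf, 2)
```

As in `sample_bubble`, pruned (zero-PDF) points are excluded before sampling, so only the nonzero support is ever drawn from.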
def initialize_bubble_pdf(self, split_size):
ds = self.train_dataset
# ds.pdf = ds.pdf.cuda()
self.register_buffer('pdf', torch.zeros(len(ds.pointcloud)), False)
self.register_buffer('sample_count', torch.zeros(len(ds.pointcloud)), False)
self.pdf = self.pdf.cuda()
# self.sample_count = self.sample_count.cuda()
for i in trange(ds.n_images):
intrinsics = ds.intrinsics_all[i].cuda().unsqueeze(0)
pose = ds.pose_all[i].cuda().unsqueeze(0)
img = ds.rgb_images[i].cuda() if self.pdf_criterion != 'DEPTH' else ds.depth_images[i].cuda()
uv = ds.uv.cuda().unsqueeze(1) # (h*w, 1, 2)
img_splits = torch.split(img, split_size)
uv_splits = torch.split(uv, split_size)
indices = torch.arange(i * ds.total_pixels, (i + 1) * ds.total_pixels, dtype=torch.long, device='cuda')
index_splits = torch.split(indices, split_size)
for img_split, uv_split, index_split in zip(img_splits, uv_splits, index_splits):
data = {
'uv': uv_split,
'intrinsics': intrinsics.repeat(len(uv_split), 1, 1),
'pose': pose.repeat(len(uv_split), 1, 1)
}
model_output = self.model.forward(data, True)
if self.pdf_criterion == 'RGB':
self.update_pdf((model_output['rgb_values'].detach().clamp(0, 1) - img_split.clamp(0, 1)).abs().mean(dim=-1), index_split)
# elif self.pdf_criterion == 'DEPTH':
else:
self.update_pdf((model_output['depth_values'].detach() - img_split).abs(), index_split)
def configure_optimizers(self):
lr = self.conf.train.learning_rate
optimizer = optim.Adam(self.model.get_param_groups(lr), eps=1e-15)
decay_rate = getattr(self.conf.train, 'sched_decay_rate', 0.1)
decay_steps = self.nepochs * self.ds_len
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, decay_rate ** (1./decay_steps))
return [optimizer], [scheduler]
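`configure_optimizers` converts a total decay factor over the whole run (e.g. reach 0.1x the initial LR by the final step) into a per-step `ExponentialLR` gamma. The arithmetic, as a standalone sketch:

```python
import math

def per_step_gamma(decay_rate, total_steps):
    # gamma chosen so that gamma ** total_steps == decay_rate,
    # i.e. the LR decays to decay_rate * lr0 over the full schedule
    return decay_rate ** (1.0 / total_steps)

gamma = per_step_gamma(0.1, 1000)
lr_final = 5e-4 * gamma ** 1000  # ~ 5e-5 after 1000 steps
```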
def train_dataloader(self):
return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.train_dataset.collate_fn, num_workers=4)
def val_dataloader(self):
return DataLoader(self.plot_dataset, batch_size=self.conf.plot.plot_nimgs, shuffle=False, collate_fn=self.train_dataset.collate_fn)
def log_if_nonzero(self, name, value, *args, **kwargs):
if value > 0:
self.log(name, value, *args, **kwargs)
def training_step(self, batch, batch_idx):
indices, img_indices, model_input, ground_truth = batch
if not self.bubble_activated and self.train_dataset.use_bubble and self.global_step >= self.loss.min_bubble_iter and self.global_step < self.loss.max_bubble_iter:
# Start bubble step
with torch.no_grad():
self.bubble_activated = True
self.train_dataset.pointcloud = self.train_dataset.pointcloud.cuda()
# self.loss.eikonal_weight = self.loss.eikonal_weight_pointcloud
# Disable normal loss, since it will discourage the growth of bubbles
self.loss.normal_weight_bak = self.loss.normal_weight
self.loss.normal_weight = 0.0
self.loss.angular_weight_bak = self.loss.angular_weight
self.loss.angular_weight = 0.0
if not self.uniform_bubble:
print(f"[INFO] Initializing pointcloud PDF, criterion: {self.pdf_criterion}")
self.initialize_bubble_pdf(self.split_n_pixels) # initialize PDF maps for each image by computing losses
torch.save(self.pdf, os.path.join(self.expdir, 'checkpoints', "pdf.pt"))
torch.cuda.empty_cache()
self.plot_hotmap(os.path.join(self.expdir, 'hotmap'))
print("[INFO] Finished initializing pointcloud PDF")
print(f"[INFO] {torch.count_nonzero(self.pdf).item()}/{self.pdf.size(0)} points to be sampled")
if self.bubble_activated:
model_input['pointcloud'] = self.sample_bubble(self.bubble_batch_size)
model_outputs = self.model(model_input)
if self.bubble_activated and not self.uniform_bubble:
with torch.no_grad():
if self.pdf_criterion == 'RGB':
self.update_pdf((model_outputs['rgb_values'].detach().clamp(0, 1) - ground_truth['rgb'].clamp(0, 1)).abs().mean(dim=-1), indices)
# elif self.pdf_criterion == 'DEPTH':
else:
self.update_pdf((model_outputs['depth_values'].detach() - ground_truth['depth']).abs(), indices)
loss_output = self.loss(model_outputs, ground_truth, self.global_step)
if self.bubble_activated and self.loss.max_bubble_iter is not None and self.global_step >= self.loss.max_bubble_iter:
# End bubble step
self.train_dataset.use_bubble = False
self.bubble_activated = False
del self.train_dataset.pointcloud
del self.train_dataset.pointlinks
del self.train_dataset.pixlinks
if not self.uniform_bubble:
delattr(self, 'pdf')
delattr(self, 'sample_count')
torch.cuda.empty_cache()
# self.loss.eikonal_weight = self.conf.loss.eikonal_weight
# Restore normal loss
self.loss.normal_weight = self.loss.normal_weight_bak
self.loss.angular_weight = self.loss.angular_weight_bak
loss = loss_output['loss']
with torch.no_grad():
psnr = rend_util.get_psnr(model_outputs['rgb_values'].detach(), ground_truth['rgb'].view(-1, 3))
self.log('train/loss', loss.item())
self.log('train/psnr', psnr.item(), True)
self.log('train/rgb_loss', loss_output['rgb_loss'].item())
self.log_if_nonzero('train/eikonal_loss', loss_output['eikonal_loss'].item())
self.log_if_nonzero('train/smooth_loss', loss_output['smooth_loss'].item())
self.log_if_nonzero('train/mask_loss', loss_output['mask_loss'].item())
self.log_if_nonzero('train/depth_loss', loss_output['depth_loss'].item())
self.log_if_nonzero('train/normal_loss', loss_output['normal_loss'].item())
self.log_if_nonzero('train/angular_loss', loss_output['angular_loss'].item())
self.log_if_nonzero('train/bubble_loss', loss_output['bubble_loss'].item())
self.log_if_nonzero('train/light_mask_loss', loss_output['light_mask_loss'].item())
self.log('train/beta', self.model.density.beta.item())
return loss
def validation_step(self, batch, batch_idx):
indices, model_input, ground_truth = batch
split = utils.split_input(model_input, self.total_pixels, self.split_n_pixels)
res = []
if self.progbar_task is None and self.prog_bar.progress:
self.progbar_task = self.prog_bar.progress.add_task("[cyan]Validation split", total=len(split))
elif self.progbar_task:
self.prog_bar.progress.reset(self.progbar_task, total=len(split), visible=True)
for s in split:
out = utils.detach_dict(self.model(s))
d = {
'rgb_values': out['rgb_values'].detach(),
'depth_values': out['depth_values'].detach()
}
if 'normal_map' in out:
d['normal_map'] = out['normal_map'].detach()
if 'light_mask' in out:
d['light_mask'] = out['light_mask'].detach()
del out
res.append(d)
if self.progbar_task:
self.prog_bar.progress.update(self.progbar_task, advance=1, refresh=True)
if self.progbar_task:
self.prog_bar.progress.update(self.progbar_task, visible=False)
batch_size = ground_truth['rgb'].shape[0]
model_outputs = utils.merge_output(res, self.total_pixels, batch_size)
def get_plot_data(model_outputs, pose, ground_truth):
rgb_gt = ground_truth['rgb']
batch_size, num_samples, _ = rgb_gt.shape
rgb_eval = model_outputs['rgb_values'].reshape(batch_size, num_samples, 3)
if self.is_hdr:
eval_hdr = rgb_eval
gt_hdr = rgb_gt
rgb_eval = rend_util.linear_to_srgb(rgb_eval.clamp(0, 1))
rgb_gt = rend_util.linear_to_srgb(rgb_gt.clamp(0, 1))
depth_eval = model_outputs['depth_values'].reshape(batch_size, num_samples, 1)
plot_data = {
'rgb_gt': rgb_gt,
'pose': pose,
'rgb_eval': rgb_eval,
'depth_eval': depth_eval
}
if self.is_hdr:
plot_data['hdr_gt'] = gt_hdr
plot_data['hdr_eval'] = eval_hdr
if 'normal_map' in model_outputs:
normal_map = model_outputs['normal_map'].reshape(batch_size, num_samples, 3)
normal_map = normal_map.transpose(1, 2) # (bn, 3, h*w)
R = pose[:,:3,:3].transpose(1, 2)
normal_map = torch.bmm(R, normal_map) # world to camera
normal_map = normal_map.transpose(1, 2)
normal_map = (normal_map + 1.) / 2.
plot_data['normal_map'] = normal_map
if 'light_mask' in model_outputs:
plot_data['lmask_eval'] = model_outputs['light_mask'].reshape(batch_size, num_samples, 1)
plot_data['lmask_gt'] = ground_truth['light_mask'].reshape(batch_size, num_samples, 1)
return plot_data
plot_data = get_plot_data(model_outputs, model_input['pose'], ground_truth)
return {
'indices': indices,
'plot_data': plot_data
}
def validation_epoch_end(self, outputs) -> None:
self.plot_dataset.shuffle_plot_index()
indices = torch.cat([x['indices'] for x in outputs], dim=0)
plot_data = utils.merge_dict([x['plot_data'] for x in outputs])
rgb_eval = plot_data['rgb_eval']
rgb_gt = plot_data['rgb_gt']
psnr = rend_util.get_psnr(rgb_eval, rgb_gt)
self.log('val/psnr', psnr.item())
rgb_gt = rgb_gt.transpose(1, 2).view(-1, 3, *self.img_res) # (bn, h*w, 3) => (bn, 3, h, w)
rgb_eval = rgb_eval.transpose(1, 2).view(-1, 3, *self.img_res)
self.log('val/ssim', ssim(rgb_eval, rgb_gt).item())
lpips.to(rgb_eval.device)
self.log('val/lpips', lpips(rgb_eval.clamp(0, 1) * 2 - 1, rgb_gt.clamp(0, 1) * 2 - 1).item())
os.makedirs(self.plots_dir, exist_ok=True)
os.makedirs('{0}/rendering'.format(self.plots_dir), exist_ok=True)
if self.is_hdr:
os.makedirs('{0}/hdr'.format(self.plots_dir), exist_ok=True)
os.makedirs('{0}/depth'.format(self.plots_dir), exist_ok=True)
if 'normal_map' in plot_data:
os.makedirs('{0}/normal'.format(self.plots_dir), exist_ok=True)
if 'lmask_eval' in plot_data:
os.makedirs('{0}/light_mask'.format(self.plots_dir), exist_ok=True)
if self.val_mesh:
os.makedirs('{0}/mesh'.format(self.plots_dir), exist_ok=True)
if self.bubble_activated and not self.uniform_bubble:
self.plot_hotmap(os.path.join(self.expdir, 'hotmap'))
self.plot_countmap(os.path.join(self.expdir, 'countmap'))
plt.plot(self.model.implicit_network,
indices,
plot_data,
self.plots_dir,
self.global_step,
self.img_res,
meshing=self.val_mesh,
**self.plot_conf
)
================================================
FILE: utils/__init__.py
================================================
from .cfgnode import CfgNode
from .rend_util import *
import torch
import torch.nn.functional as F
import torch.nn as nn
from glob import glob
import os
import numpy as np
from pytorch_lightning.callbacks import RichProgressBar
from rich.progress import TextColumn
class RichProgressBarWithScanId(RichProgressBar):
def __init__(self, scan_id, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.custom_column = TextColumn(f"[progress.description]scan_id: {scan_id}")
def configure_columns(self, trainer):
return super().configure_columns(trainer) + [self.custom_column]
def glob_imgs(path):
imgs = []
for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG', '*.exr']:
imgs.extend(glob(os.path.join(path, ext)))
return imgs
def glob_depths(path):
imgs = []
for ext in ['*.exr']:
imgs.extend(glob(os.path.join(path, ext)))
return imgs
glob_normal = glob_depths
def split_input(model_input, total_pixels, n_pixels=10000):
'''
Split the input to fit CUDA memory at high resolutions.
Decrease n_pixels if a CUDA out-of-memory error occurs.
'''
split = []
for i, indx in enumerate(torch.split(torch.arange(total_pixels, device=model_input['uv'].device), n_pixels, dim=0)):
data = model_input.copy()
data['uv'] = torch.index_select(model_input['uv'], 1, indx)
if 'object_mask' in data:
data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx)
split.append(data)
return split
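`split_input` chunks the pixel index range so that a full-resolution image can be rendered in CUDA-sized pieces. The index bookkeeping alone, without torch (function name illustrative):

```python
def split_indices(total_pixels, n_pixels):
    # contiguous blocks of at most n_pixels covering [0, total_pixels),
    # the same partition torch.split(torch.arange(total_pixels), n_pixels) produces
    return [list(range(s, min(s + n_pixels, total_pixels)))
            for s in range(0, total_pixels, n_pixels)]

chunks = split_indices(25, 10)  # blocks of 10, 10, 5
```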
def split_dict(d, batch_size=10000):
splits = {}
for k in d:
splits[k] = torch.split(d[k], batch_size)
n_splits = len(splits[k])
split_inputs = []
for i in range(n_splits):
split = {}
for k in d:
split[k] = splits[k][i]
split_inputs.append(split)
return split_inputs
def detach_dict(d):
return {k: v.detach() for k, v in d.items() if torch.is_tensor(v)}
def merge_output(res, total_pixels, batch_size):
''' Merge the split output. '''
model_outputs = {}
for entry in res[0]:
if res[0][entry] is None:
continue
if len(res[0][entry].shape) == 1:
model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, 1) for r in res],
1).reshape(batch_size * total_pixels)
else:
model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res],
1).reshape(batch_size * total_pixels, -1)
return model_outputs
def merge_dict(dicts):
output = {}
for entry in dicts[0]:
output[entry] = torch.cat([r[entry] for r in dicts], dim=0)
return output
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd
class _trunc_exp(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # cast to float32
def forward(ctx, x):
ctx.save_for_backward(x)
return torch.exp(x)
@staticmethod
@custom_bwd
def backward(ctx, g):
x = ctx.saved_tensors[0]
return g * torch.exp(x.clamp(-15, 15))
trunc_exp = _trunc_exp.apply
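`_trunc_exp` keeps the exact `exp` in the forward pass but clamps the input to [-15, 15] when computing the gradient, so a single large activation cannot blow up to an inf/NaN gradient. The backward rule in scalar form (a sketch of the rule, not the autograd code itself):

```python
import math

def trunc_exp_grad(x, upstream):
    # d/dx exp(x) == exp(x), but evaluated at x clamped to [-15, 15],
    # bounding the gradient magnitude by e**15
    return upstream * math.exp(max(-15.0, min(15.0, x)))
```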
def kmeans_pp_centroid(points: torch.Tensor, k):
n, c = points.shape
centroids = torch.zeros(k, c, device=points.device)
centroids[0, :] = points[np.random.randint(0, n), :].clone()
for i in range(1, k):
# distance from each point to its nearest already-chosen centroid
d = (points.unsqueeze(1) - centroids[:i,:].unsqueeze(0)).norm(p=2, dim=-1).min(dim=1).values
sum_all = d.sum() * np.random.random()
cumsum = torch.cumsum(d, dim=0)
j = ((cumsum - sum_all) > 0).int().argmax()
centroids[i,:] = points[j,:].clone()
return centroids
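`kmeans_pp_centroid` selects each next seed by inverting the CDF of the nearest-centroid distances: draw u in [0, 1), scale by the total distance mass, and take the first index whose cumulative sum exceeds it. The selection step in plain Python (random draw passed in for determinism; names are illustrative):

```python
def pick_next_centroid(dists, u):
    # invert the CDF: the first index whose cumulative distance mass
    # exceeds u * sum(dists) is chosen with probability dists[j] / sum(dists)
    target = u * sum(dists)
    acc = 0.0
    for j, d in enumerate(dists):
        acc += d
        if acc > target:
            return j
    return len(dists) - 1

j = pick_next_centroid([1.0, 3.0, 0.0, 6.0], 0.5)
```

Points with zero distance (i.e. already coincident with a centroid) carry no mass and are never picked, which is the k-means++ behavior.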
================================================
FILE: utils/cfgnode.py
================================================
"""
Define a class to hold configurations.
Borrows and merges stuff from YACS, fvcore, and detectron2
https://github.com/rbgirshick/yacs
https://github.com/facebookresearch/fvcore/
https://github.com/facebookresearch/detectron2/
"""
import copy
import importlib.util
import io
import logging
import os
from ast import literal_eval
from typing import Optional
import yaml
# File exts for yaml
_YAML_EXTS = {"", ".yml", ".yaml"}
# File exts for python
_PY_EXTS = {".py"}
# CfgNodes can only contain a limited set of valid types
_VALID_TYPES = {tuple, list, str, int, float, bool}
# Valid file object types
_FILE_TYPES = (io.IOBase,)
# Logger
logger = logging.getLogger(__name__)
class CfgNode(dict):
r"""CfgNode is a `node` in the configuration `tree`. It's a simple wrapper around a `dict` and supports access to
`attributes` via `keys`.
"""
IMMUTABLE = "__immutable__"
DEPRECATED_KEYS = "__deprecated_keys__"
RENAMED_KEYS = "__renamed_keys__"
NEW_ALLOWED = "__new_allowed__"
def __init__(
self,
init_dict: Optional[dict] = None,
key_list: Optional[list] = None,
new_allowed: Optional[bool] = False,
):
r"""
Args:
init_dict (dict): A dictionary to initialize the `CfgNode`.
key_list (list[str]): A list of names that index this `CfgNode` from the root. Currently, only used for
logging.
new_allowed (bool): Whether adding a new key is allowed when merging with other `CfgNode` objects.
"""
# Recursively convert nested dictionaries in `init_dict` to config tree.
init_dict = {} if init_dict is None else init_dict
key_list = [] if key_list is None else key_list
init_dict = self._create_config_tree_from_dict(init_dict, key_list)
super(CfgNode, self).__init__(init_dict)
# Control the immutability of the `CfgNode`.
self.__dict__[CfgNode.IMMUTABLE] = False
# Support for deprecated options.
# If you choose to remove support for an option in code, but don't want to change all of the config files
# (to allow for deprecated config files to run), you can add the full config key as a string to this set.
self.__dict__[CfgNode.DEPRECATED_KEYS] = set()
# Support for renamed options.
# If you rename an option, record the mapping from the old name to the new name in this dictionary. Optionally,
# if the type also changed, you can make this value a tuple that specifies two things: the renamed key, and the
# instructions to edit the config file.
self.__dict__[CfgNode.RENAMED_KEYS] = {
# 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY', # Dummy example
# 'EXAMPLE.OLD.KEY': ( # A more complex example
# 'EXAMPLE.NEW.KEY',
# "Also convert to a tuple, eg. 'foo' -> ('foo', ) or "
# + "'foo.bar' -> ('foo', 'bar')"
# ),
}
# Allow new attributes after initialization.
self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed
@classmethod
def _create_config_tree_from_dict(cls, init_dict: dict, key_list: list):
r"""Create a configuration tree using the input dict. Any dict-like objects inside `init_dict` will be treated
as new `CfgNode` objects.
Args:
init_dict (dict): Input dictionary, to create config tree from.
key_list (list): A list of names that index this `CfgNode` from the root. Currently only used for logging.
"""
d = copy.deepcopy(init_dict)
for k, v in d.items():
if isinstance(v, dict):
# Convert dictionary to CfgNode
d[k] = cls(v, key_list=key_list + [k])
else:
# Check for valid leaf type or nested CfgNode
_assert_with_logging(
_valid_type(v, allow_cfg_node=False),
"Key {} with value {} is not a valid type; valid types: {}".format(
".".join(key_list + [k]), type(v), _VALID_TYPES
),
)
return d
def __getattr__(self, name: str):
if name in self:
return self[name]
else:
raise AttributeError(name)
def __setattr__(self, name: str, value):
if self.is_frozen():
raise AttributeError(
"Attempted to set {} to {}, but CfgNode is immutable".format(
name, value
)
)
_assert_with_logging(
name not in self.__dict__,
"Invalid attempt to modify internal CfgNode state: {}".format(name),
)
_assert_with_logging(
_valid_type(value, allow_cfg_node=True),
"Invalid type {} for key {}; valid types = {}".format(
type(value), name, _VALID_TYPES
),
)
self[name] = value
def __str__(self):
def _indent(s_, num_spaces):
s = s_.split("\n")
if len(s) == 1:
return s_
first = s.pop(0)
s = [(num_spaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
r = ""
s = []
for k, v in sorted(self.items()):
separator = "\n" if isinstance(v, CfgNode) else " "
attr_str = "{}:{}{}".format(str(k), separator, str(v))
attr_str = _indent(attr_str, 2)
s.append(attr_str)
r += "\n".join(s)
return r
def __repr__(self):
return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())
def dump(self, **kwargs):
r"""Dump CfgNode to a string.
"""
def _convert_to_dict(cfg_node, key_list):
if not isinstance(cfg_node, CfgNode):
_assert_with_logging(
_valid_type(cfg_node),
"Key {} with value {} is not a valid type; valid types: {}".format(
".".join(key_list), type(cfg_node), _VALID_TYPES
),
)
return cfg_node
else:
cfg_dict = dict(cfg_node)
for k, v in cfg_dict.items():
cfg_dict[k] = _convert_to_dict(v, key_list + [k])
return cfg_dict
self_as_dict = _convert_to_dict(self, [])
return yaml.safe_dump(self_as_dict, **kwargs)
def merge_from_file(self, cfg_filename: str):
r"""Load a yaml config file and merge it with this CfgNode.
Args:
cfg_filename (str): Config file path.
"""
with open(cfg_filename, "r") as f:
cfg = self.load_cfg(f)
self.merge_from_other_cfg(cfg)
def merge_from_other_cfg(self, cfg_other):
r"""Merge `cfg_other` into the current `CfgNode`.
Args:
cfg_other (CfgNode): The `CfgNode` to merge into this one.
"""
_merge_a_into_b(cfg_other, self, self, [])
def merge_from_list(self, cfg_list: list):
r"""Merge config (keys, values) in a list (eg. from commandline) into this `CfgNode`.
Eg. `cfg_list = ['FOO.BAR', 0.5]`.
"""
_assert_with_logging(
len(cfg_list) % 2 == 0,
"Override list has odd length: {}; it must be a list of pairs".format(
cfg_list
),
)
root = self
for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
if root.key_is_deprecated(full_key):
continue
if root.key_is_renamed(full_key):
root.raise_key_rename_error(full_key)
key_list = full_key.split(".")
d = self
for subkey in key_list[:-1]:
_assert_with_logging(
subkey in d, "Non-existent key: {}".format(full_key)
)
d = d[subkey]
subkey = key_list[-1]
_assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key))
value = self._decode_cfg_value(v)
value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)
d[subkey] = value
def freeze(self):
r"""Make this `CfgNode` and all of its children immutable. """
self._immutable(True)
def defrost(self):
r"""Make this `CfgNode` and all of its children mutable. """
self._immutable(False)
def is_frozen(self):
r"""Return mutability. """
return self.__dict__[CfgNode.IMMUTABLE]
def _immutable(self, is_immutable: bool):
r"""Set mutability and recursively apply to all nested `CfgNode` objects.
Args:
is_immutable (bool): Whether or not the `CfgNode` and its children are immutable.
"""
self.__dict__[CfgNode.IMMUTABLE] = is_immutable
# Recursively propagate state to all children.
for v in self.__dict__.values():
if isinstance(v, CfgNode):
v._immutable(is_immutable)
for v in self.values():
if isinstance(v, CfgNode):
v._immutable(is_immutable)
def clone(self):
r"""Recursively copy this `CfgNode`. """
return copy.deepcopy(self)
def register_deprecated_key(self, key: str):
r"""Register a key (eg. `FOO.BAR`) as a deprecated option. When merging deprecated keys, a warning is generated and
the key is ignored.
"""
_assert_with_logging(
key not in self.__dict__[CfgNode.DEPRECATED_KEYS],
"key {} is already registered as a deprecated key".format(key),
)
self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)
def register_renamed_key(
self, old_name: str, new_name: str, message: Optional[str] = None
):
r"""Register a key as having been renamed from `old_name` to `new_name`. When merging a renamed key, an
exception is thrown alerting the user to the fact that the key has been renamed.
"""
_assert_with_logging(
old_name not in self.__dict__[CfgNode.RENAMED_KEYS],
"key {} is already registered as a renamed cfg key".format(old_name),
)
value = new_name
if message:
value = (new_name, message)
self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value
def key_is_deprecated(self, full_key: str):
r"""Test if a key is deprecated. """
if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]:
logger.warning("deprecated config key (ignoring): {}".format(full_key))
return True
return False
def key_is_renamed(self, full_key: str):
r"""Test if a key is renamed. """
return full_key in self.__dict__[CfgNode.RENAMED_KEYS]
def raise_key_rename_error(self, full_key: str):
new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key]
if isinstance(new_key, tuple):
msg = " Note: " + new_key[1]
new_key = new_key[0]
else:
msg = ""
raise KeyError(
"Key {} was renamed to {}; please update your config.{}".format(
full_key, new_key, msg
)
)
def is_new_allowed(self):
return self.__dict__[CfgNode.NEW_ALLOWED]
@classmethod
def load_cfg(cls, cfg_file_obj_or_str):
r"""Load a configuration into the `CfgNode`.
Args:
cfg_file_obj_or_str (str or cfg compatible object): Supports loading from:
- A file object backed by a YAML file.
- A file object backed by a Python source file that exports an attribute "cfg" (dict or `CfgNode`).
- A string that can be parsed as valid YAML.
"""
_assert_with_logging(
isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),
"Expected first argument to be of type {} or {}, but got {}".format(
_FILE_TYPES, str, type(cfg_file_obj_or_str)
),
)
if isinstance(cfg_file_obj_or_str, str):
return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str)
elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
return cls._load_cfg_from_file(cfg_file_obj_or_str)
else:
raise NotImplementedError("Impossible to reach here (unless there's a bug)")
@classmethod
def _load_cfg_from_file(cls, file_obj):
r"""Load a config from a YAML file or a Python source file. """
_, file_ext = os.path.splitext(file_obj.name)
if file_ext in _YAML_EXTS:
return cls._load_cfg_from_yaml_str(file_obj.read())
elif file_ext in _PY_EXTS:
return cls._load_cfg_py_source(file_obj.name)
else:
raise Exception(
"Attempt to load from an unsupported filetype {}; only {} supported".format(
file_ext, _YAML_EXTS.union(_PY_EXTS)
)
)
@classmethod
def _load_cfg_from_yaml_str(cls, str_obj):
r"""Load a config from a YAML string encoding. """
cfg_as_dict = yaml.safe_load(str_obj)
return cls(cfg_as_dict)
@classmethod
def _load_cfg_py_source(cls, filename):
r"""Load a config from a Python source file. """
module = _load_module_from_file("yacs.config.override", filename)
_assert_with_logging(
hasattr(module, "cfg"),
"Python module from file {} must export a 'cfg' attribute".format(filename),
)
VALID_ATTR_TYPES = {dict, CfgNode}
_assert_with_logging(
type(module.cfg) in VALID_ATTR_TYPES,
"The imported module's 'cfg' attribute must have a type in {} but is {}".format(
VALID_ATTR_TYPES, type(module.cfg)
),
)
return cls(module.cfg)
@classmethod
def _decode_cfg_value(cls, value):
r"""Decodes a raw config value (eg. from a yaml config file or commandline argument) into a Python object.
If `value` is a dict, it will be interpreted as a new `CfgNode`.
If `value` is a str, it will be evaluated as a literal.
Otherwise, it is returned as is.
"""
# Configs parsed from raw yaml will contain dictionary keys that need to be converted to `CfgNode` objects.
if isinstance(value, dict):
return cls(value)
# All remaining processing is only applied to strings.
if not isinstance(value, str):
return value
# Try to interpret `value` as a: string, number, tuple, list, dict, bool, or None
try:
value = literal_eval(value)
# The following two excepts allow `value` to pass through when it represents a string.
# The type of `value` is always a string (before calling `literal_eval`), but sometimes it *represents* a
# string and other times a data structure, like a list. In the case that `value` represents a str, what we
# got back from the yaml parser is `foo` *without quotes* (so, not `"foo"`). `literal_eval` is ok with `"foo"`,
# but will raise a `ValueError` if given `foo`. In other cases, like paths (`val = 'foo/bar'`) `literal_eval`
# will raise a `SyntaxError`.
except ValueError:
pass
except SyntaxError:
pass
return value
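The `literal_eval` round-trip described in the comment above can be isolated: numbers, lists, and booleans written as YAML scalars come back typed, while bare words and paths stay strings. A minimal standalone version of the string branch (helper name illustrative):

```python
from ast import literal_eval

def decode(value):
    # mirrors the string branch of CfgNode._decode_cfg_value:
    # parse the string as a Python literal if possible, else keep it raw
    if not isinstance(value, str):
        return value
    try:
        return literal_eval(value)
    except (ValueError, SyntaxError):
        return value
```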
# Keep this function in global scope, for backward compatibility.
load_cfg = CfgNode.load_cfg
def _valid_type(value, allow_cfg_node: Optional[bool] = False):
return (type(value) in _VALID_TYPES) or (
allow_cfg_node and isinstance(value, CfgNode)
)
def _merge_a_into_b(a: CfgNode, b: CfgNode, root: CfgNode, key_list: list):
r"""Merge `CfgNode` `a` into `CfgNode` `b`, clobbering the options in `b` wherever they are also specified in `a`.
"""
_assert_with_logging(
isinstance(a, CfgNode),
"`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode),
)
_assert_with_logging(
isinstance(b, CfgNode),
"`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode),
)
for k, v_ in a.items():
full_key = ".".join(key_list + [k])
v = copy.deepcopy(v_)
v = b._decode_cfg_value(v)
if k in b:
v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
# Recursively merge dicts.
if isinstance(v, CfgNode):
_merge_a_into_b(v, b[k], root, key_list + [k])
else:
b[k] = v
elif b.is_new_allowed():
b[k] = v
else:
if root.key_is_deprecated(full_key):
continue
elif root.key_is_renamed(full_key):
root.raise_key_rename_error(full_key)
else:
raise KeyError("Non-existent config key: {}".format(full_key))
def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
r"""Checks that `replacement`, which is intended to replace `original` is of the right type. The type is correct if
it matches exactly or is one of a few cases in which the type can easily be coerced.
"""
original_type = type(original)
replacement_type = type(replacement)
if replacement_type == original_type:
return replacement
# Cast `replacement` from `from_type` to `to_type` when the (replacement, original) types match (from_type, to_type).
def _conditional_cast(from_type, to_type):
if replacement_type == from_type and original_type == to_type:
return True, to_type(replacement)
else:
return False, None
# Conditional casts.
# list <-> tuple
casts = [(tuple, list), (list, tuple)]
for (from_type, to_type) in casts:
converted, converted_value = _conditional_cast(from_type, to_type)
if converted:
return converted_value
raise ValueError(
"Type mismatch ({} vs. {}) with values ({} vs. {}) for config key: {}".format(
original_type, replacement_type, original, replacement, full_key
)
)
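The only coercion `_check_and_coerce_cfg_value_type` performs is list<->tuple; every other mismatch is rejected. Condensed (hypothetical helper name):

```python
def coerce(replacement, original):
    # exact type match passes through; list<->tuple are inter-converted;
    # any other mismatch is an error, as in _check_and_coerce_cfg_value_type
    if type(replacement) == type(original):
        return replacement
    for from_type, to_type in [(tuple, list), (list, tuple)]:
        if type(replacement) == from_type and type(original) == to_type:
            return to_type(replacement)
    raise ValueError("type mismatch")
```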
def _assert_with_logging(cond, msg):
if not cond:
logger.debug(msg)
assert cond, msg
def _load_module_from_file(name, filename):
spec = importlib.util.spec_from_file_location(name, filename)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
================================================
FILE: utils/mesh_util.py
================================================
# adapted from https://github.com/zju3dv/manhattan_sdf
import numpy as np
import open3d as o3d
from sklearn.neighbors import KDTree
import trimesh
import os
os.environ['PYOPENGL_PLATFORM'] = 'egl'
import pyrender
from tqdm.contrib import tenumerate, tzip
def nn_correspondance(verts1, verts2):
'''For each vertex in verts2, return the distance to its nearest neighbor in verts1.'''
if len(verts1) == 0 or len(verts2) == 0:
return np.array([])
kdtree = KDTree(verts1)
distances, _ = kdtree.query(verts2)
return distances.reshape(-1)
def evaluate(mesh_pred, mesh_trgt, threshold=.05, down_sample=.02):
pcd_trgt = o3d.geometry.PointCloud()
pcd_pred = o3d.geometry.PointCloud()
pcd_trgt.points = o3d.utility.Vector3dVector(mesh_trgt.vertices[:, :3])
pcd_pred.points = o3d.utility.Vector3dVector(mesh_pred.vertices[:, :3])
if down_sample:
pcd_pred = pcd_pred.voxel_down_sample(down_sample)
pcd_trgt = pcd_trgt.voxel_down_sample(down_sample)
verts_pred = np.asarray(pcd_pred.points)
verts_trgt = np.asarray(pcd_trgt.points)
dist1 = nn_correspondance(verts_pred, verts_trgt)
dist2 = nn_correspondance(verts_trgt, verts_pred)
precision = np.mean((dist2 < threshold).astype('float')) # pred -> GT
recall = np.mean((dist1 < threshold).astype('float')) # GT -> pred
fscore = 2 * precision * recall / (precision + recall)
metrics = {
'Acc': np.mean(dist2),
'Comp': np.mean(dist1),
'Prec': precision,
'Recall': recall,
'F-score': fscore,
}
return metrics
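`evaluate` turns the two one-sided nearest-neighbor distance sets into precision (pred to GT), recall (GT to pred), and their harmonic mean. The metric arithmetic on plain lists, with illustrative inputs (a sketch, not the repo's torch/open3d pipeline):

```python
def f_score(dists_pred_to_gt, dists_gt_to_pred, threshold=0.05):
    # precision: fraction of predicted points within threshold of the GT surface
    # recall:    fraction of GT points within threshold of the predicted surface
    precision = sum(d < threshold for d in dists_pred_to_gt) / len(dists_pred_to_gt)
    recall = sum(d < threshold for d in dists_gt_to_pred) / len(dists_gt_to_pred)
    return 2 * precision * recall / (precision + recall)

f = f_score([0.01, 0.02, 0.5, 0.01], [0.03, 0.9])  # precision 0.75, recall 0.5
```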
class Renderer():
def __init__(self, height=480, width=640):
self.renderer = pyrender.OffscreenRenderer(width, height)
self.scene = pyrender.Scene()
# self.render_flags = pyrender.RenderFlags.SKIP_CULL_FACES
def __call__(self, height, width, intrinsics, pose, mesh):
self.renderer.viewport_height = height
self.renderer.viewport_width = width
self.scene.clear()
self.scene.add(mesh)
cam = pyrender.IntrinsicsCamera(cx=intrinsics[0, 2], cy=intrinsics[1, 2],
fx=intrinsics[0, 0], fy=intrinsics[1, 1])
self.scene.add(cam, pose=self.fix_pose(pose))
return self.renderer.render(self.scene) # , self.render_flags)
def fix_pose(self, pose):
# 3D Rotation about the x-axis.
t = np.pi
c = np.cos(t)
s = np.sin(t)
R = np.array([[1, 0, 0],
[0, c, -s],
[0, s, c]])
axis_transform = np.eye(4)
axis_transform[:3, :3] = R
return pose @ axis_transform
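`fix_pose` right-multiplies each camera pose by a 180-degree rotation about x, negating the camera's y and z axes; this is the usual flip between an OpenCV-style (y-down, z-forward) camera and the OpenGL-style convention pyrender expects. The rotation itself, in isolation:

```python
import math

def rot_x(t):
    # rotation matrix about the x-axis by angle t (as built in Renderer.fix_pose)
    c, s = math.cos(t), math.sin(t)
    return [[1.0, 0.0, 0.0],
            [0.0, c, -s],
            [0.0, s, c]]

R = rot_x(math.pi)  # flips the y and z axes
```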
def mesh_opengl(self, mesh):
return pyrender.Mesh.from_trimesh(mesh)
def delete(self):
self.renderer.delete()
def refuse(mesh, poses, K, H, W, far_clip=5.0):
    renderer = Renderer()
    mesh_opengl = renderer.mesh_opengl(mesh)
    volume = o3d.pipelines.integration.ScalableTSDFVolume(
        voxel_length=0.01,
        sdf_trunc=3 * 0.01,
        color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8
    )
    for i, pose in tenumerate(poses):
        intrinsic = K
        rgb = np.ones((H, W, 3))
        rgb = (rgb * 255).astype(np.uint8)
        rgb = o3d.geometry.Image(rgb)
        _, depth_pred = renderer(H, W, intrinsic, pose, mesh_opengl)
        depth_pred = o3d.geometry.Image(depth_pred)
        rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
            rgb, depth_pred, depth_scale=1.0, depth_trunc=far_clip, convert_rgb_to_intensity=False
        )
        fx, fy, cx, cy = intrinsic[0, 0], intrinsic[1, 1], intrinsic[0, 2], intrinsic[1, 2]
        intrinsic = o3d.camera.PinholeCameraIntrinsic(width=W, height=H, fx=fx, fy=fy, cx=cx, cy=cy)
        extrinsic = np.linalg.inv(pose)
        volume.integrate(rgbd, intrinsic, extrinsic)
    return volume.extract_triangle_mesh()
def depth2mesh(depths, poses, K, H, W):
    volume = o3d.pipelines.integration.ScalableTSDFVolume(
        voxel_length=0.01,
        sdf_trunc=3 * 0.01,
        color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8
    )
    for depth, pose in tzip(depths, poses):
        rgb = np.ones((H, W, 3))
        rgb = (rgb * 255).astype(np.uint8)
        rgb = o3d.geometry.Image(rgb)
        depth = o3d.geometry.Image(depth)
        rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
            rgb, depth, depth_scale=1.0, depth_trunc=5.0, convert_rgb_to_intensity=False
        )
        fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
        intrinsic = o3d.camera.PinholeCameraIntrinsic(width=W, height=H, fx=fx, fy=fy, cx=cx, cy=cy)
        extrinsic = np.linalg.inv(pose)
        volume.integrate(rgbd, intrinsic, extrinsic)
    return volume.extract_triangle_mesh()
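Both fusion functions above pull `fx, fy, cx, cy` out of the 3x3 intrinsic matrix `K` so Open3D can back-project each depth pixel into a camera-space 3D point before TSDF integration. A pure-numpy sketch of that pinhole back-projection (the intrinsics, pixel, and depth values are hypothetical):

```python
import numpy as np

# Hypothetical intrinsics for a 640x480 image.
fx, fy, cx, cy = 500.0, 500.0, 320.0, 240.0
u, v, d = 420.0, 300.0, 2.0  # pixel coordinates and metric depth

# Pinhole back-projection: pixel + depth -> camera-space 3D point.
x = (u - cx) / fx * d
y = (v - cy) / fy * d
point_cam = np.array([x, y, d])

# Forward projection recovers the original pixel, confirming the round trip.
u2 = fx * point_cam[0] / point_cam[2] + cx
v2 = fy * point_cam[1] / point_cam[2] + cy
```

The TSDF volume then transforms `point_cam` by the extrinsic (`np.linalg.inv(pose)`, i.e. world-to-camera inverted back to world) before accumulating signed distances.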
================================================
FILE: utils/plots.py
================================================
import plotly.graph_objs as go
import plotly.offline as offline
from plotly.subplots import make_subplots
import numpy as np
import torch
from skimage import measure
import torchvision.utils as vutils
import trimesh
from PIL import Image
import cv2
import mcubes
from utils import rend_util
def plot(implicit_network, indices, plot_data, path, epoch, img_res, plot_nimgs, resolution=None, grid_boundary=None, meshing=True, level=0):
    if plot_data is not None:
        if 'rgb_eval' in plot_data:
            plot_images(plot_data['rgb_eval'], plot_data['rgb_gt'], path, epoch, plot_nimgs, img_res)
        if 'hdr_eval' in plot_data:
            plot_images(plot_data['hdr_eval'], plot_data['hdr_gt'], path, epoch, plot_nimgs, img_res, 'hdr', True)
        if 'rgb_surface' in plot_data:
            plot_images(plot_data['rgb_surface'], plot_data['rgb_gt'], path, '{0}s'.format(epoch), plot_nimgs, img_res)
        if 'rendered' in plot_data:
            plot_images(plot_data['rendered'], plot_data['rgb_gt'], path, epoch, plot_nimgs, img_res, 'rendered')
        if 'normal_map' in plot_data:
            plot_imgs_wo_gt(plot_data['normal_map'], path, epoch, plot_nimgs, img_res)
        if 'depth_eval' in plot_data:
            plot_depths(plot_data['depth_eval'], path, epoch, plot_nimgs, img_res)
        if 'lmask_eval' in plot_data:
            plot_images(plot_data['lmask_eval'], plot_data['lmask_gt'], path, epoch, plot_nimgs, img_res, 'light_mask')
        if 'Kd' in plot_data:
            # plot_imgs_wo_gt(plot_data['albedo'], path, epoch, plot_nimgs, img_res, 'albedo')
            plot_images(plot_data['Ks'], plot_data['Kd'], path, epoch, plot_nimgs, img_res, 'albedo')
        if 'roughness' in plot_data:
            plot_colormap(plot_data['roughness'], path, epoch, plot_nimgs, img_res)
        if 'metallic' in plot_data:
            plot_colormap(plot_data['metallic'], path, epoch, plot_nimgs, img_res, 'metallic')
        if 'emission' in plot_data:
            plot_imgs_wo_gt(plot_data['emission'], path, '{0}'.format(epoch), plot_nimgs, img_res, 'emission', True)
    if meshing:
        path = f"{path}/mesh"
        cam_loc, cam_dir = rend_util.get_camera_for_plot(plot_data['pose'])
        data = []
        # plot surface
        surface_traces = get_surface_trace(path=path,
                                           epoch=epoch,
                                           sdf=lambda x: implicit_network(x)[:, 0],
                                           resolution=resolution,
                                           grid_boundary=grid_boundary,
                                           level=level
                                           )
        if surface_traces is not None:
            data.append(surface_traces[0])
        # plot cameras locations
        if plot_data is not None:
            for i, loc, dir in zip(indices, cam_loc, cam_dir):
                data.append(get_3D_quiver_trace(loc.unsqueeze(0), dir.unsqueeze(0), name='camera_{0}'.format(i)))
        fig = go.Figure(data=data)
        scene_dict = dict(xaxis=dict(range=[-6, 6], autorange=False),
                          yaxis=dict(range=[-6, 6], autorange=False),
                          zaxis=dict(range=[-6, 6], autorange=False),
                          aspectratio=dict(x=1, y=1, z=1))
        fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True)
        filename = '{0}/surface_{1}.html'.format(path, epoch)
        offline.plot(fig, filename=filename, auto_open=False)
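The `sdf=lambda x: implicit_network(x)[:, 0]` pattern hands the mesher only the first output channel of the network (the signed distance), discarding the feature channels. A numpy sketch of the same slicing, with a hypothetical sphere SDF standing in for the real `implicit_network`:

```python
import numpy as np

# Stand-in for implicit_network: maps N x 3 points to N x (1 + features),
# where column 0 is the signed distance. The sphere and feature width are hypothetical.
def implicit_network_stub(x):
    d = np.linalg.norm(x, axis=1, keepdims=True) - 0.5  # sphere of radius 0.5
    feat = np.zeros((x.shape[0], 3))                    # dummy feature channels
    return np.concatenate([d, feat], axis=1)

sdf = lambda x: implicit_network_stub(x)[:, 0]  # keep only the SDF channel

pts = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.5, 0.0, 0.0]])
vals = sdf(pts)  # negative inside, positive outside, zero on the surface
```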
def visualize_pointcloud(points, filename):
    fig = go.Figure(
        data=[
            get_3D_scatter_trace(points, 'Pointcloud', 1)
        ]
    )
    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),
                      yaxis=dict(range=[-3, 3], autorange=False),
                      zaxis=dict(range=[-3, 3], autorange=False),
                      aspectratio=dict(x=1, y=1, z=1))
    fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True)
    offline.plot(fig, filename=filename, auto_open=False)
def visualize_clustered_pointcloud(points, labels, centroids, filename):
    fig = go.Figure()
    if centroids is not None:
        fig.add_trace(get_3D_scatter_trace(centroids, "Centroids", 10))
    for c in torch.unique(labels):
        cluster = points[labels == c, :]
        fig.add_trace(get_3D_scatter_trace(cluster, f"Emitter #{int(c)}"))
    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),
                      yaxis=dict(range=[-3, 3], autorange=False),
                      zaxis=dict(range=[-3, 3], autorange=False),
                      aspectratio=dict(x=1, y=1, z=1))
    fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True)
    offline.plot(fig, filename=filename, auto_open=False)
def visualize_marked_pointcloud(points, counts, path, epoch):
    fig = go.Figure(
        data=[
            get_3D_marked_scatter_trace(points, counts, 'Pointcloud samples')
        ]
    )
    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),
                      yaxis=dict(range=[-3, 3], autorange=False),
                      zaxis=dict(range=[-3, 3], autorange=False),
                      aspectratio=dict(x=1, y=1, z=1))
    fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True)
    filename = '{0}/pointcloud/{1}.html'.format(path, epoch)
    offline.plot(fig, filename=filename, auto_open=False)
def get_3D_scatter_trace(points, name='', size=3, caption=None):
    assert points.shape[1] == 3, "3d scatter plot input points are not correctly shaped"
    assert len(points.shape) == 2, "3d scatter plot input points are not correctly shaped"
    trace = go.Scatter3d(
        x=points[:, 0].cpu(),
        y=points[:, 1].cpu(),
        z=points[:, 2].cpu(),
        mode='markers',
        name=name,
        marker=dict(
            size=size,
            line=dict(
                width=2,
            ),
            opacity=1.0,
        ), text=caption)
    return trace
def get_3D_marked_scatter_trace(points, marks, name='', size=1, caption=None):
    assert points.shape[1] == 3, "3d scatter plot input points are not correctly shaped"
    assert len(points.shape) == 2, "3d scatter plot input points are not correctly shaped"
    trace = go.Scatter3d(
        x=points[:, 0].cpu(),
        y=points[:, 1].cpu(),
        z=points[:, 2].cpu(),
        mode='markers',
        name=name,
        marker=dict(
            size=size,
li
SYMBOL INDEX (262 symbols across 18 files)
FILE: data/normalize_cameras.py
function get_center_point (line 6) | def get_center_point(num_cams,cameras):
function normalize_cameras (line 31) | def normalize_cameras(original_cameras_filename,output_cameras_filename,...
FILE: data/npz_to_blender.py
function to16b (line 14) | def to16b(img):
function opencv_to_gl (line 19) | def opencv_to_gl(pose):
function get_offset (line 25) | def get_offset(poses):
function scale_pose (line 37) | def scale_pose(pose, scale, offset):
function load_K_Rt_from_P (line 43) | def load_K_Rt_from_P(filename, P=None):
function main (line 67) | def main():
FILE: dataset/eval_dataset.py
class GridDataset (line 15) | class GridDataset(Dataset):
method __init__ (line 19) | def __init__(self, points, xyz) -> None:
method __len__ (line 24) | def __len__(self):
method __getitem__ (line 27) | def __getitem__(self, index):
class PlotDataset (line 31) | class PlotDataset(torch.utils.data.Dataset):
method __init__ (line 32) | def __init__(self,
method shuffle_plot_index (line 137) | def shuffle_plot_index(self):
method __len__ (line 141) | def __len__(self):
method get_uv (line 144) | def get_uv(self):
method __getitem__ (line 150) | def __getitem__(self, idx):
method collate_fn (line 170) | def collate_fn(self, batch_list):
class InterpolateDataset (line 188) | class InterpolateDataset(torch.utils.data.Dataset):
method __init__ (line 192) | def __init__(self,
method __len__ (line 244) | def __len__(self):
method __getitem__ (line 247) | def __getitem__(self, idx):
method collate_fn (line 258) | def collate_fn(self, batch_list):
class RelightDataset (line 276) | class RelightDataset(PlotDataset):
method __init__ (line 277) | def __init__(self, data_dir, edit_cfg, scan_id=0, is_val=False, **kwar...
method loadattr (line 302) | def loadattr(self, edit_cfg, attr, mode=0):
method __len__ (line 316) | def __len__(self):
method __getitem__ (line 319) | def __getitem__(self, idx):
class RelightVideoDataset (line 343) | class RelightVideoDataset(PlotDataset):
method __init__ (line 344) | def __init__(self, data_dir, edit_cfg, scan_id=0, is_val=False, **kwar...
method __len__ (line 358) | def __len__(self):
method __getitem__ (line 361) | def __getitem__(self, idx):
FILE: dataset/train_dataset.py
class ReconDataset (line 15) | class ReconDataset(torch.utils.data.Dataset):
method __init__ (line 17) | def __init__(self,
method __len__ (line 166) | def __len__(self):
method __getitem__ (line 169) | def __getitem__(self, idx):
method collate_fn (line 194) | def collate_fn(self, batch_list):
class MaterialDataset (line 212) | class MaterialDataset(torch.utils.data.Dataset):
method __init__ (line 214) | def __init__(self,
method __len__ (line 315) | def __len__(self):
method __getitem__ (line 318) | def __getitem__(self, idx):
method collate_fn (line 336) | def collate_fn(self, batch_list):
FILE: model/eval/recon.py
class SDFMeshSystem (line 21) | class SDFMeshSystem(pl.LightningModule):
method __init__ (line 22) | def __init__(self, conf, exp_dir, resolution, score=False, far_clip=5....
method initialize (line 46) | def initialize(self):
method test_dataloader (line 84) | def test_dataloader(self):
method test_step (line 89) | def test_step(self, batch, batch_idx):
method test_epoch_end (line 92) | def test_epoch_end(self, outputs) -> None:
method forward (line 131) | def forward(self):
class VolumeRenderSystem (line 135) | class VolumeRenderSystem(pl.LightningModule):
method __init__ (line 136) | def __init__(self, conf, exp_dir, indices=None, is_val=False, score_me...
method test_dataloader (line 157) | def test_dataloader(self):
method test_step (line 163) | def test_step(self, batch, batch_idx):
method test_epoch_end (line 205) | def test_epoch_end(self, outputs):
method forward (line 223) | def forward(self):
class ViewInterpolateSystem (line 227) | class ViewInterpolateSystem(pl.LightningModule):
method __init__ (line 228) | def __init__(self, conf, exp_dir, id0, id1, n_frames=60, frame_rate=24...
method test_dataloader (line 253) | def test_dataloader(self):
method test_step (line 259) | def test_step(self, batch, batch_idx):
method test_epoch_end (line 284) | def test_epoch_end(self, outputs):
method forward (line 302) | def forward(self):
FILE: model/network/__init__.py
class I2SDFNetwork (line 19) | class I2SDFNetwork(nn.Module):
method __init__ (line 20) | def __init__(self, conf):
method init_emission_groups (line 49) | def init_emission_groups(self, n_emitters, pointcloud, init_emission=1...
method get_param_groups (line 77) | def get_param_groups(self, lr):
method forward (line 80) | def forward(self, input, predict_only=False):
method volume_rendering (line 223) | def volume_rendering(self, z_vals, z_max, sdf):
method bg_volume_rendering (line 242) | def bg_volume_rendering(self, z_vals_bg, bg_sdf):
method depth2pts_outside (line 258) | def depth2pts_outside(self, ray_o, ray_d, depth):
class I2SDFLoss (line 289) | class I2SDFLoss(nn.Module):
method __init__ (line 290) | def __init__(self, eikonal_weight=0.1, smooth_weight=0.0, mask_weight=...
method get_rgb_loss (line 308) | def get_rgb_loss(self, rgb_values, rgb_gt):
method get_eikonal_loss (line 313) | def get_eikonal_loss(self, grad_theta):
method get_mask_loss (line 317) | def get_mask_loss(self, mask_pred, mask_gt):
method get_depth_loss (line 320) | def get_depth_loss(self, depth, depth_gt, depth_mask):
method get_normal_l1_loss (line 326) | def get_normal_l1_loss(self, normal, normal_gt, normal_mask):
method get_normal_angular_loss (line 331) | def get_normal_angular_loss(self, normal, normal_gt, normal_mask):
method forward (line 338) | def forward(self, model_outputs, ground_truth, current_step):
FILE: model/network/density.py
class Density (line 5) | class Density(nn.Module):
method __init__ (line 6) | def __init__(self, params_init={}):
method forward (line 12) | def forward(self, sdf, beta=None):
class LaplaceDensity (line 16) | class LaplaceDensity(Density): # alpha * Laplace(loc=0, scale=beta).cdf...
method __init__ (line 17) | def __init__(self, params_init={}, beta_min=0.0001):
method density_func (line 21) | def density_func(self, sdf, beta=None):
method get_beta (line 28) | def get_beta(self):
class AbsDensity (line 33) | class AbsDensity(Density): # like NeRF++
method density_func (line 34) | def density_func(self, sdf, beta=None):
class SimpleDensity (line 38) | class SimpleDensity(Density): # like NeRF
method __init__ (line 39) | def __init__(self, params_init={}, noise_std=1.0):
method density_func (line 43) | def density_func(self, sdf, beta=None):
FILE: model/network/embedder.py
class Embedder (line 6) | class Embedder:
method __init__ (line 8) | def __init__(self, **kwargs):
method create_embedding_fn (line 12) | def create_embedding_fn(self):
method embed (line 37) | def embed(self, inputs):
class SHEncoder (line 41) | class SHEncoder(nn.Module):
method __init__ (line 42) | def __init__(self, input_dims=3, degree=4):
method forward (line 84) | def forward(self, input, **kwargs):
class FourierFeature (line 125) | class FourierFeature(nn.Module):
method __init__ (line 126) | def __init__(self, channels, sigma=1.0, input_dims=3, include_input=Tr...
method forward (line 133) | def forward(self, x):
function get_embedder (line 138) | def get_embedder(embed_type='positional', **kwargs):
FILE: model/network/mlp.py
class ImplicitNetwork (line 10) | class ImplicitNetwork(nn.Module):
method __init__ (line 11) | def __init__(
method get_param_groups (line 81) | def get_param_groups(self, lr):
method forward (line 84) | def forward(self, input):
method gradient (line 107) | def gradient(self, x):
method feature (line 120) | def feature(self, x):
method get_outputs (line 123) | def get_outputs(self, x, returns_grad=True):
method get_sdf_vals (line 145) | def get_sdf_vals(self, x):
class RenderingNetwork (line 159) | class RenderingNetwork(nn.Module):
method __init__ (line 160) | def __init__(
method forward (line 208) | def forward(self, points, normals, view_dirs, feature_vectors):
FILE: model/network/ray_sampler.py
class RaySampler (line 6) | class RaySampler(metaclass=abc.ABCMeta):
method __init__ (line 7) | def __init__(self, near, far):
method get_z_vals (line 12) | def get_z_vals(self, ray_dirs, cam_loc, model):
class UniformSampler (line 15) | class UniformSampler(RaySampler):
method __init__ (line 16) | def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere...
method get_z_vals (line 22) | def get_z_vals(self, ray_dirs, cam_loc, model):
class ErrorBoundSampler (line 46) | class ErrorBoundSampler(RaySampler):
method __init__ (line 47) | def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_e...
method get_z_vals (line 67) | def get_z_vals(self, ray_dirs, cam_loc, model):
method get_error_bound (line 243) | def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):
FILE: model/rendering/__init__.py
class RenderingLayer (line 10) | class RenderingLayer(nn.Module):
method __init__ (line 11) | def __init__(self, spp, split_n_pixels, preserve_light=True) -> None:
method forward (line 17) | def forward(
FILE: model/rendering/brdf.py
function create_frame (line 5) | def create_frame(n: torch.Tensor, eps:float = 1e-6):
function get_rendering_parameters (line 20) | def get_rendering_parameters(albedo_raw, rough_raw, use_metallic):
function to_global (line 35) | def to_global(d, x, y, z):
function sqrt_ (line 41) | def sqrt_(x: torch.Tensor, eps=1e-8) -> torch.Tensor:
function reflect (line 47) | def reflect(v: torch.Tensor, h: torch.Tensor):
function square_to_cosine_hemisphere (line 51) | def square_to_cosine_hemisphere(sample: torch.Tensor):
function get_cos_theta (line 59) | def get_cos_theta(v: torch.Tensor):
function get_phi (line 63) | def get_phi(v: torch.Tensor):
function sample_disney_specular (line 72) | def sample_disney_specular(sample: torch.Tensor, roughness: torch.Tensor...
function GTR2 (line 94) | def GTR2(ndh, a):
function SchlickFresnel (line 99) | def SchlickFresnel(u):
function smithG_GGX (line 103) | def smithG_GGX(ndv, a):
function pdf_disney (line 109) | def pdf_disney(roughness: torch.Tensor, metallic: torch.Tensor, wi: torc...
function eval_disney (line 130) | def eval_disney(albedo: torch.Tensor, roughness: torch.Tensor, metallic:...
function F_Schlick (line 164) | def F_Schlick(SpecularColor, VoH):
function GetSpecularEventProbability (line 168) | def GetSpecularEventProbability(SpecularColor, NoV) -> torch.Tensor:
function baseColorToSpecularF0 (line 172) | def baseColorToSpecularF0(baseColor, metalness):
function luminance (line 175) | def luminance(color):
function probabilityToSampleSpecular (line 181) | def probabilityToSampleSpecular(difColor, specColor) -> torch.Tensor:
function shadowedF90 (line 186) | def shadowedF90(F0):
function evalFresnel (line 190) | def evalFresnel(f0, f90, NdotS):
function Smith_G1_GGX (line 194) | def Smith_G1_GGX(alphaSquared, NdotSSquared):
function Smith_G2_GGX (line 197) | def Smith_G2_GGX(alphaSquared, NdotL, NdotV):
function GGX_D (line 202) | def GGX_D(alphaSquared, NdotH):
function pdf_ggx (line 206) | def pdf_ggx(Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor,...
function eval_ggx (line 241) | def eval_ggx(Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor...
function sample_weight_ggx (line 268) | def sample_weight_ggx(alphaSquared, NdotL, NdotV):
function sample_ggx (line 273) | def sample_ggx(sample: torch.Tensor, Kd: torch.Tensor, Ks: torch.Tensor,...
function sample_ggx_specular (line 325) | def sample_ggx_specular(sample: torch.Tensor, roughness: torch.Tensor, w...
FILE: model/trainer/recon.py
class ReconstructionTrainer (line 23) | class ReconstructionTrainer(pl.LightningModule):
method __init__ (line 24) | def __init__(self, conf, prog_bar: RichProgressBar, exp_dir, model_onl...
method forward (line 109) | def forward(self):
method plot_hotmap (line 112) | def plot_hotmap(self, path):
method plot_countmap (line 127) | def plot_countmap(self, path):
method update_pdf (line 142) | def update_pdf(self, value, idx):
method sample_bubble (line 155) | def sample_bubble(self, batch_size):
method initialize_bubble_pdf (line 172) | def initialize_bubble_pdf(self, split_size):
method configure_optimizers (line 201) | def configure_optimizers(self):
method train_dataloader (line 209) | def train_dataloader(self):
method val_dataloader (line 212) | def val_dataloader(self):
method log_if_nonzero (line 215) | def log_if_nonzero(self, name, value, *args, **kwargs):
method training_step (line 219) | def training_step(self, batch, batch_idx):
method validation_step (line 290) | def validation_step(self, batch, batch_idx):
method validation_epoch_end (line 358) | def validation_epoch_end(self, outputs) -> None:
FILE: utils/__init__.py
class RichProgressBarWithScanId (line 12) | class RichProgressBarWithScanId(RichProgressBar):
method __init__ (line 13) | def __init__(self, scan_id, *args, **kwargs) -> None:
method configure_columns (line 17) | def configure_columns(self, trainer):
function glob_imgs (line 21) | def glob_imgs(path):
function glob_depths (line 27) | def glob_depths(path):
function split_input (line 35) | def split_input(model_input, total_pixels, n_pixels=10000):
function split_dict (line 50) | def split_dict(d, batch_size=10000):
function detach_dict (line 66) | def detach_dict(d):
function merge_output (line 70) | def merge_output(res, total_pixels, batch_size):
function merge_dict (line 87) | def merge_dict(dicts):
class _trunc_exp (line 96) | class _trunc_exp(Function):
method forward (line 99) | def forward(ctx, x):
method backward (line 105) | def backward(ctx, g):
function kmeans_pp_centroid (line 111) | def kmeans_pp_centroid(points: torch.Tensor, k):
FILE: utils/cfgnode.py
class CfgNode (line 34) | class CfgNode(dict):
method __init__ (line 44) | def __init__(
method _create_config_tree_from_dict (line 87) | def _create_config_tree_from_dict(cls, init_dict: dict, key_list: list):
method __getattr__ (line 110) | def __getattr__(self, name: str):
method __setattr__ (line 116) | def __setattr__(self, name: str, value):
method __str__ (line 138) | def __str__(self):
method __repr__ (line 159) | def __repr__(self):
method dump (line 162) | def dump(self, **kwargs):
method merge_from_file (line 184) | def merge_from_file(self, cfg_filename: str):
method merge_from_other_cfg (line 193) | def merge_from_other_cfg(self, cfg_other):
method merge_from_list (line 200) | def merge_from_list(self, cfg_list: list):
method freeze (line 229) | def freeze(self):
method defrost (line 233) | def defrost(self):
method is_frozen (line 237) | def is_frozen(self):
method _immutable (line 241) | def _immutable(self, is_immutable: bool):
method clone (line 255) | def clone(self):
method register_deprecated_key (line 259) | def register_deprecated_key(self, key: str):
method register_renamed_key (line 270) | def register_renamed_key(
method key_is_deprecated (line 286) | def key_is_deprecated(self, full_key: str):
method key_is_renamed (line 293) | def key_is_renamed(self, full_key: str):
method raise_key_rename_error (line 297) | def raise_key_rename_error(self, full_key: str):
method is_new_allowed (line 310) | def is_new_allowed(self):
method load_cfg (line 314) | def load_cfg(cls, cfg_file_obj_or_str):
method _load_cfg_from_file (line 336) | def _load_cfg_from_file(cls, file_obj):
method _load_cfg_from_yaml_str (line 351) | def _load_cfg_from_yaml_str(cls, str_obj):
method _load_cfg_py_source (line 357) | def _load_cfg_py_source(cls, filename):
method _decode_cfg_value (line 374) | def _decode_cfg_value(cls, value):
function _valid_type (line 406) | def _valid_type(value, allow_cfg_node: Optional[bool] = False):
function _merge_a_into_b (line 412) | def _merge_a_into_b(a: CfgNode, b: CfgNode, root: CfgNode, key_list: list):
function _check_and_coerce_cfg_value_type (line 450) | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
function _assert_with_logging (line 482) | def _assert_with_logging(cond, msg):
function _load_module_from_file (line 488) | def _load_module_from_file(name, filename):
FILE: utils/mesh_util.py
function nn_correspondance (line 12) | def nn_correspondance(verts1, verts2):
function evaluate (line 25) | def evaluate(mesh_pred, mesh_trgt, threshold=.05, down_sample=.02):
class Renderer (line 55) | class Renderer():
method __init__ (line 56) | def __init__(self, height=480, width=640):
method __call__ (line 61) | def __call__(self, height, width, intrinsics, pose, mesh):
method fix_pose (line 71) | def fix_pose(self, pose):
method mesh_opengl (line 83) | def mesh_opengl(self, mesh):
method delete (line 86) | def delete(self):
function refuse (line 90) | def refuse(mesh, poses, K, H, W, far_clip=5.0):
function depth2mesh (line 117) | def depth2mesh(depths, poses, K, H, W):
FILE: utils/plots.py
function plot (line 15) | def plot(implicit_network, indices, plot_data, path, epoch, img_res, plo...
function visualize_pointcloud (line 76) | def visualize_pointcloud(points, filename):
function visualize_clustered_pointcloud (line 90) | def visualize_clustered_pointcloud(points, labels, centroids, filename):
function visualize_marked_pointcloud (line 105) | def visualize_marked_pointcloud(points, counts, path, epoch):
function get_3D_scatter_trace (line 120) | def get_3D_scatter_trace(points, name='', size=3, caption=None):
function get_3D_marked_scatter_trace (line 141) | def get_3D_marked_scatter_trace(points, marks, name='', size=1, caption=...
function get_3D_quiver_trace (line 164) | def get_3D_quiver_trace(points, directions, color='#bd1540', name=''):
function get_surface_trace (line 188) | def get_surface_trace(path, epoch, sdf, resolution=100, grid_boundary=[-...
function extract_fields (line 228) | def extract_fields(bound_min, bound_max, resolution, query_func):
function extract_geometry (line 246) | def extract_geometry(bound_min, bound_max, resolution, threshold, query_...
function get_surface_high_res_mesh (line 258) | def get_surface_high_res_mesh(sdf, resolution=100, grid_boundary=[-2.0, ...
function get_surface_by_grid (line 339) | def get_surface_by_grid(grid_params, sdf, resolution=100, level=0, highe...
function get_grid_uniform (line 440) | def get_grid_uniform(resolution, grid_boundary=[-2.0, 2.0]):
function get_grid (line 453) | def get_grid(points, resolution, input_min=None, input_max=None, eps=0.1):
function plot_imgs_wo_gt (line 492) | def plot_imgs_wo_gt(normal_maps, path, epoch, plot_nrow, img_res, path_n...
function plot_imgs_filter (line 508) | def plot_imgs_filter(rgb_points, ground_true, path, epoch, img_res, path...
function plot_colormap (line 522) | def plot_colormap(mat_info, path, epoch, plot_nrow, img_res, colormap=cv...
function plot_depths (line 538) | def plot_depths(depth_maps, path, epoch, plot_nrow, img_res, colormap=cv...
function plot_images (line 560) | def plot_images(rgb_points, ground_true, path, epoch, plot_nrow, img_res...
function lin2img (line 581) | def lin2img(tensor, img_res):
FILE: utils/rend_util.py
function linear_to_srgb (line 9) | def linear_to_srgb(data):
function get_psnr (line 13) | def get_psnr(img1, img2, normalize_rgb=False):
function load_rgb (line 25) | def load_rgb(path, normalize_rgb = False, is_hdr = False):
function load_mask (line 38) | def load_mask(path):
function load_depth (line 46) | def load_depth(path):
function load_normal (line 52) | def load_normal(path):
function load_K_Rt_from_P (line 57) | def load_K_Rt_from_P(filename, P=None):
function depth_to_world (line 81) | def depth_to_world(uv, intrinsics, pose, depth, depth_mask=None):
function get_camera_params (line 92) | def get_camera_params(uv, pose, intrinsics):
function get_camera_for_plot (line 123) | def get_camera_for_plot(pose):
function lift (line 134) | def lift(x, y, z, intrinsics):
function quat_to_rot (line 150) | def quat_to_rot(q):
function rot_to_quat (line 170) | def rot_to_quat(R):
function get_general_sphere_intersections (line 191) | def get_general_sphere_intersections(cam_loc, ray_directions, center, r):
function get_sphere_intersections (line 211) | def get_sphere_intersections(cam_loc, ray_directions, r = 1.0):
function add_depth_noise (line 229) | def add_depth_noise(depth, depth_mask, scale=1):