Repository: WU-CVGL/BAD-NeRF
Branch: main
Commit: aed2c4a4b230
Files: 35
Total size: 113.4 KB
Directory structure:
gitextract_1dub1nd4/
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── Spline.py
├── configs/
│ ├── cozy2room-cubic.txt
│ ├── cozy2room.txt
│ ├── factory-cubic.txt
│ ├── factory.txt
│ ├── pool-cubic.txt
│ ├── pool.txt
│ ├── roomdark-cubic.txt
│ ├── roomdark.txt
│ ├── roomhigh-cubic.txt
│ ├── roomhigh.txt
│ ├── roomlow-cubic.txt
│ ├── roomlow.txt
│ ├── tanabata-cubic.txt
│ ├── tanabata.txt
│ ├── wine-cubic.txt
│ └── wine.txt
├── load_llff.py
├── lpips/
│ ├── __init__.py
│ ├── lpips.py
│ └── pretrained_networks.py
├── metrics.py
├── nerf.py
├── novel_view_test.py
├── optimize_pose_cubic.py
├── optimize_pose_linear.py
├── requirements.txt
├── run_nerf.py
├── run_nerf_helpers.py
├── test.py
└── train.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.idea/
.vscode/
__pycache__/
logs*/
data/
configs_test/
weights/
================================================
FILE: Dockerfile
================================================
FROM nvcr.io/nvidia/pytorch:23.02-py3
ARG DEBIAN_FRONTEND=noninteractive
ENV TZ=Asia/Shanghai LANG=C.UTF-8 LC_ALL=C.UTF-8 PIP_NO_CACHE_DIR=1 PIP_CACHE_DIR=/tmp/ PYTHONUNBUFFERED=1 PYTHONFAULTHANDLER=1 PYTHONHASHSEED=0
RUN \
# uncomment to use apt mirror
# sed -i "s/archive.ubuntu.com/mirrors.ustc.edu.cn/g" /etc/apt/sources.list &&\
# sed -i "s/security.ubuntu.com/mirrors.ustc.edu.cn/g" /etc/apt/sources.list &&\
rm -f /etc/apt/sources.list.d/* &&\
rm -rf /opt/hpcx/ &&\
apt-get update && apt-get upgrade -y &&\
apt-get install -y --no-install-recommends \
autoconf automake autotools-dev build-essential ca-certificates \
make cmake ninja-build pkg-config g++ ccache yasm openmpi-bin \
git curl wget unzip nano net-tools htop iotop \
cloc rsync xz-utils software-properties-common tzdata \
&& apt-get purge -y unattended-upgrades \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
ADD requirements.txt /tmp
RUN \
# uncomment to use pypi mirror
# pip config set global.index-url https://mirrors.bfsu.edu.cn/pypi/web/simple &&\
pip install -U pip &&\
pip install -r /tmp/requirements.txt &&\
pip install "jupyterlab~=3.5.0" "jupyter-archive~=3.2" &&\
rm -rf /tmp/*
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 Peng Wang, Lingzhe Zhao, Ruijie Ma, Peidong Liu
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# 😈BAD-NeRF
This is an official PyTorch implementation of the paper [BAD-NeRF: Bundle Adjusted Deblur Neural Radiance Fields](https://arxiv.org/abs/2211.12853) (CVPR 2023). Authors: [Peng Wang](https://github.com/wangpeng000), [Lingzhe Zhao](https://github.com/LingzheZhao), Ruijie Ma and [Peidong Liu](https://ethliup.github.io/).
BAD-NeRF jointly learns the 3D representation and optimizes the camera motion trajectories within exposure time from blurry images and inaccurate initial poses.
Here is the [Project page](https://wangpeng000.github.io/BAD-NeRF/).
## ✨News
📺 **[2023.12]** We were invited to give an online talk on Bilibili and Wechat, hosted by **计算机视觉life**. Archive of our live stream (in Chinese): [[link to Bilibili]](https://www.bilibili.com/video/BV1Rb4y1G7Ek/)
⚡ **[2023.11]** We made our **nerfstudio**-framework-based implementation: [[BAD-NeRFstudio]](https://github.com/WU-CVGL/BAD-NeRFstudio) public. Now you can train a scene from blurry images in minutes!
## Novel View Synthesis
## Deblurring Result

## Pose Estimation Result

## Method overview

We follow the real physical image formation process of a motion-blurred image to synthesize blurry images from NeRF. Both NeRF and the motion trajectories are estimated by maximizing the photometric consistency between the synthesized blurry images and the real blurry images.
## Quickstart
### 1. Setup environment
```
git clone https://github.com/WU-CVGL/BAD-NeRF
cd BAD-NeRF
pip install -r requirements.txt
```
### 2. Download datasets
You can download the data and weights [here](https://westlakeu-my.sharepoint.com/:f:/g/personal/cvgl_westlake_edu_cn/EsgdW2cRic5JqerhNbTsxtkBqy9m6cbnb2ugYZtvaib3qA?e=bjK7op).
For the scenes of Deblur-NeRF (*cozy2room*, *factory* etc.), the folder `images` only includes blurry images and the folder `images_1` additionally includes novel view images. But for our scenes (*room-low*, *room-high* and *dark*), there are no novel view images. Note that the images in the *dark* scene are undistorted, causing that there are some useless pixels, and you should uncomment the code of `Graph.forward` in `nerf.py`.
### 3. Configs
Change the data path and other parameters (if needed) in `configs/cozy2room.txt`. We use *cozy2room* scene as an example.
### 4. Demo with our pre-trained model
You can test our code and render sharp images with the provided weight files. To do this, you should first put the weight file under the corresponding logs folder `./logs/cozy2room-linear`, and then change the parameter `load_weights=True` in `cozy2room.txt`, finally run
```
python test.py --config configs/cozy2room.txt
```
### 5. Training
```
python train.py --config configs/cozy2room.txt
```
After training, you can get deblurred images, optimized camera poses and synthesized novel view images.
## Notes
### Camera poses
The poses (`poses_bounds.npy`) are generated from only blurred images (folder `images`) by COLMAP.
### Spline model
We use `linear interpolation` as the default spline model in our experiments, you can simply change the parameter `linear` (all the parameters can be changed in `configs/***.txt` or `run_nerf.py`) to `False` to use the higher-order spline model (i.e. `cubic B-Spline`).
### Virtual images
You can change the important parameter `deblur_images` to a smaller/bigger value for lightly/severely blurred images.
### Learning rate
After rebuttal, we found that sometimes the gradients will be NaN if `cubic B-Spline` model with a `pose_lrate=1e-3` is used. Therefore, we set the initial pose learning rate to 1e-4 and it may achieve a better performance compared to that in our paper. If the gradient appears NaN in your experiments unfortunately, just kill it and try again or decrease the `pose_lrate`.
## Your own data
`images`: This folder is used to estimate initial camera poses from blurry images. Specifically, just put your own data in the folder `images` (only blurry images), and run `imgs2poses.py` script from the [LLFF code](https://github.com/fyusion/llff) to estimate camera poses and generate `poses_bounds.npy`.
`images_1`: This is the default training folder, which includes the same blurry images in `images` folder and (optional) several novel view sharp images. If you want to add novel view images (sharp images), please put them into the folder `images_1` with an interval of `llffhold` (a parameter used for novel view testing). Remember that, set parameter `novel_view` to `True` if `images_1` includes novel view images. Otherwise, if there are no novel view images, you can directly put the blurry images to the folder `images_1` and set parameter `novel_view` to `False`.
`images_test`: To compute deblurring metrics, this folder contains ground truth images theoretically. However, you can copy the blurry images in `images` folder to `images_test` folder if you don't have ground truth images, which is the easiest way to run the code correctly (remember the computed metrics are wrong).
```
#-----------------------------------------------------------------------------------------#
# images folder: img_blur_*.png is the blurry image. #
#-----------------------------------------------------------------------------------------#
# images_1 folder: img_blur_*.png is the same as that in `images` and (optional) #
# img_novel_*.png is the sharp novel view image. #
#-----------------------------------------------------------------------------------------#
# images_test folder: img_test_*.png should be the ground truth image corrseponds to #
# img_blur_*.png to compute PSNR metric. Of course, you can directly put img_blur_*.png #
# to run the code if you don't have gt images (then the metrics are wrong). #
#-----------------------------------------------------------------------------------------#
images folder: (suppose 10 images)
img_blur_0.png
img_blur_1.png
.
.
.
img_blur_9.png
#-----------------------------------------------------------------------------------------#
images_1 folder: (suppose novel view images are placed with an `llffhold=5` interval.)
img_novel_0.png (optional)
img_blur_0.png
img_blur_1.png
.
img_blur_4.png
img_novel_1.png (optional)
img_blur_5.png
.
img_blur_9.png
img_novel_2.png (optional)
#-----------------------------------------------------------------------------------------#
images_test folder: (theoretically gt images, but can be other images)
img_test_0.png
img_test_1.png
.
.
.
img_test_9.png
```
## Citation
If you find this useful, please consider citing our paper:
```bibtex
@InProceedings{wang2023badnerf,
author = {Wang, Peng and Zhao, Lingzhe and Ma, Ruijie and Liu, Peidong},
title = {{BAD-NeRF: Bundle Adjusted Deblur Neural Radiance Fields}},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2023},
pages = {4170-4179}
}
```
## Acknowledgment
The overall framework, metrics computing and camera transformation are derived from [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch/), [Deblur-NeRF](https://github.com/limacv/Deblur-NeRF) and [BARF](https://github.com/chenhsuanlin/bundle-adjusting-NeRF) respectively. We appreciate the effort of the contributors to these repositories.
================================================
FILE: Spline.py
================================================
import torch
import numpy as np
delt = 0
def skew_symmetric(w):
w0, w1, w2 = w.unbind(dim=-1)
O = torch.zeros_like(w0)
wx = torch.stack(
[
torch.stack([O, -w2, w1], dim=-1),
torch.stack([w2, O, -w0], dim=-1),
torch.stack([-w1, w0, O], dim=-1),
],
dim=-2,
)
return wx
def taylor_A(x, nth=10):
# Taylor expansion of sin(x)/x
ans = torch.zeros_like(x)
denom = 1.0
for i in range(nth + 1):
if i > 0:
denom *= (2 * i) * (2 * i + 1)
ans = ans + (-1) ** i * x ** (2 * i) / denom
return ans
def taylor_B(x, nth=10):
# Taylor expansion of (1-cos(x))/x**2
ans = torch.zeros_like(x)
denom = 1.0
for i in range(nth + 1):
denom *= (2 * i + 1) * (2 * i + 2)
ans = ans + (-1) ** i * x ** (2 * i) / denom
return ans
def taylor_C(x, nth=10):
# Taylor expansion of (x-sin(x))/x**3
ans = torch.zeros_like(x)
denom = 1.0
for i in range(nth + 1):
denom *= (2 * i + 2) * (2 * i + 3)
ans = ans + (-1) ** i * x ** (2 * i) / denom
return ans
def exp_r2q_parallel(r, eps=1e-9):
x, y, z = r[..., 0], r[..., 1], r[..., 2]
theta = 0.5 * torch.sqrt(x**2 + y**2 + z**2)
bool_criterion = (theta < eps).unsqueeze(-1).repeat(1, 1, 4)
return torch.where(
bool_criterion, exp_r2q_taylor(x, y, z, theta), exp_r2q(x, y, z, theta)
)
def exp_r2q(x, y, z, theta):
lambda_ = torch.sin(theta) / (2.0 * theta)
qx = lambda_ * x
qy = lambda_ * y
qz = lambda_ * z
qw = torch.cos(theta)
return torch.stack([qx, qy, qz, qw], -1)
def exp_r2q_taylor(x, y, z, theta):
qx = (1.0 / 2.0 - 1.0 / 12.0 * theta**2 - 1.0 / 240.0 * theta**4) * x
qy = (1.0 / 2.0 - 1.0 / 12.0 * theta**2 - 1.0 / 240.0 * theta**4) * y
qz = (1.0 / 2.0 - 1.0 / 12.0 * theta**2 - 1.0 / 240.0 * theta**4) * z
qw = 1.0 - 1.0 / 2.0 * theta**2 + 1.0 / 24.0 * theta**4
return torch.stack([qx, qy, qz, qw], -1)
def q_to_R_parallel(q):
qb, qc, qd, qa = q.unbind(dim=-1)
R = torch.stack(
[
torch.stack(
[
1 - 2 * (qc**2 + qd**2),
2 * (qb * qc - qa * qd),
2 * (qa * qc + qb * qd),
],
dim=-1,
),
torch.stack(
[
2 * (qb * qc + qa * qd),
1 - 2 * (qb**2 + qd**2),
2 * (qc * qd - qa * qb),
],
dim=-1,
),
torch.stack(
[
2 * (qb * qd - qa * qc),
2 * (qa * qb + qc * qd),
1 - 2 * (qb**2 + qc**2),
],
dim=-1,
),
],
dim=-2,
)
return R
def q_to_Q_parallel(q):
x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
Q_0 = torch.stack([w, -z, y, x], -1).unsqueeze(-2)
Q_1 = torch.stack([z, w, -x, y], -1).unsqueeze(-2)
Q_2 = torch.stack([-y, x, w, z], -1).unsqueeze(-2)
Q_3 = torch.stack([-x, -y, -z, w], -1).unsqueeze(-2)
Q_ = torch.cat([Q_0, Q_1, Q_2, Q_3], -2)
return Q_
def q_to_q_conj_parallel(q):
x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
q_conj_ = torch.stack([-x, -y, -z, w], -1)
return q_conj_
def log_q2r_parallel(q, eps_theta=1e-20, eps_w=1e-10):
x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
theta = torch.sqrt(x**2 + y**2 + z**2)
bool_theta_0 = theta < eps_theta
bool_w_0 = torch.abs(w) < eps_w
bool_w_0_left = torch.logical_and(bool_w_0, w < 0)
lambda_ = torch.where(
bool_w_0,
torch.where(
bool_w_0_left,
log_q2r_lim_w_0_left(theta),
log_q2r_lim_w_0_right(theta)
),
torch.where(
bool_theta_0,
log_q2r_taylor_theta_0(w, theta),
log_q2r(w, theta)
),
)
r_ = torch.stack([lambda_ * x, lambda_ * y, lambda_ * z], -1)
return r_
def log_q2r(w, theta):
return 2.0 * (torch.arctan(theta / w)) / theta
def log_q2r_taylor_theta_0(w, theta):
return 2.0 / w - 2.0 / 3.0 * (theta**2) / (w * w * w)
def log_q2r_lim_w_0_left(theta):
return -torch.pi / theta
def log_q2r_lim_w_0_right(theta):
return torch.pi / theta
def SE3_to_se3(Rt, eps=1e-8): # [...,3,4]
R, t = Rt.split([3, 1], dim=-1)
w = SO3_to_so3(R)
wx = skew_symmetric(w)
theta = w.norm(dim=-1)[..., None, None]
I = torch.eye(3, device=w.device, dtype=torch.float32)
A = taylor_A(theta)
B = taylor_B(theta)
invV = I - 0.5 * wx + (1 - A / (2 * B)) / (theta**2 + eps) * wx @ wx
u = (invV @ t)[..., 0]
wu = torch.cat([w, u], dim=-1)
return wu
def SO3_to_so3(R, eps=1e-7): # [...,3,3]
trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
theta = ((trace - 1) / 2).clamp(-1 + eps, 1 - eps).acos_()[
..., None, None
] % np.pi # ln(R) will explode if theta==pi
lnR = (
1 / (2 * taylor_A(theta) + 1e-8) * (R - R.transpose(-2, -1))
) # FIXME: wei-chiu finds it weird
w0, w1, w2 = lnR[..., 2, 1], lnR[..., 0, 2], lnR[..., 1, 0]
w = torch.stack([w0, w1, w2], dim=-1)
return w
def se3_to_SE3(wu): # [...,3]
w, u = wu.split([3, 3], dim=-1)
wx = skew_symmetric(w) # wx=[0 -w(2) w(1);w(2) 0 -w(0);-w(1) w(0) 0]
theta = w.norm(dim=-1)[..., None, None] # theta=sqrt(w'*w)
I = torch.eye(3, device=w.device, dtype=torch.float32)
A = taylor_A(theta)
B = taylor_B(theta)
C = taylor_C(theta)
R = I + A * wx + B * wx @ wx
V = I + B * wx + C * wx @ wx
Rt = torch.cat([R, (V @ u[..., None])], dim=-1)
return Rt
def SE3_to_se3_N(poses_rt):
poses_se3_list = []
for i in range(poses_rt.shape[0]):
pose_se3 = SE3_to_se3(poses_rt[i])
poses_se3_list.append(pose_se3)
poses = torch.stack(poses_se3_list, 0)
return poses
def se3_to_SE3_N(poses_wu):
poses_se3_list = []
for i in range(poses_wu.shape[0]):
pose_se3 = se3_to_SE3(poses_wu[i])
poses_se3_list.append(pose_se3)
poses = torch.stack(poses_se3_list, 0)
return poses
def se3_2_qt_parallel(wu):
w, u = wu.split([3, 3], dim=-1)
wx = skew_symmetric(w)
theta = w.norm(dim=-1)[..., None, None]
I = torch.eye(3, device=w.device, dtype=torch.float32)
# A = taylor_A(theta)
B = taylor_B(theta)
C = taylor_C(theta)
# R = I + A * wx + B * wx @ wx
V = I + B * wx + C * wx @ wx
t = V @ u[..., None]
q = exp_r2q_parallel(w)
return q, t.squeeze(-1)
def SplineN_linear(start_pose, end_pose, poses_number, NUM, device=None):
pose_time = poses_number / (NUM - 1)
# parallel
pos_0 = torch.where(pose_time == 0)
pose_time[pos_0] = pose_time[pos_0] + 0.000001
pos_1 = torch.where(pose_time == 1)
pose_time[pos_1] = pose_time[pos_1] - 0.000001
q_start, t_start = se3_2_qt_parallel(start_pose)
q_end, t_end = se3_2_qt_parallel(end_pose)
# sample t_vector
t_t = (1 - pose_time)[..., None] * t_start + pose_time[..., None] * t_end
# sample rotation_vector
q_tau_0 = q_to_Q_parallel(q_to_q_conj_parallel(q_start)) @ q_end[..., None]
r = pose_time[..., None] * log_q2r_parallel(q_tau_0.squeeze(-1))
q_t_0 = exp_r2q_parallel(r)
q_t = q_to_Q_parallel(q_start) @ q_t_0[..., None]
# convert q&t to RT
R = q_to_R_parallel(q_t.squeeze(dim=-1))
t = t_t.unsqueeze(dim=-1)
pose_spline = torch.cat([R, t], -1)
poses = pose_spline.reshape([-1, 3, 4])
return poses
def SplineN_cubic(pose0, pose1, pose2, pose3, poses_number, NUM):
sample_time = poses_number / (NUM - 1)
# parallel
pos_0 = torch.where(sample_time == 0)
sample_time[pos_0] = sample_time[pos_0] + 0.000001
pos_1 = torch.where(sample_time == 1)
sample_time[pos_1] = sample_time[pos_1] - 0.000001
sample_time = sample_time.unsqueeze(-1)
q0, t0 = se3_2_qt_parallel(pose0)
q1, t1 = se3_2_qt_parallel(pose1)
q2, t2 = se3_2_qt_parallel(pose2)
q3, t3 = se3_2_qt_parallel(pose3)
u = sample_time
uu = sample_time**2
uuu = sample_time**3
one_over_six = 1.0 / 6.0
half_one = 0.5
# t
coeff0 = one_over_six - half_one * u + half_one * uu - one_over_six * uuu
coeff1 = 4 * one_over_six - uu + half_one * uuu
coeff2 = one_over_six + half_one * u + half_one * uu - half_one * uuu
coeff3 = one_over_six * uuu
# spline t
t_t = coeff0 * t0 + coeff1 * t1 + coeff2 * t2 + coeff3 * t3
# R
coeff1_r = 5 * one_over_six + half_one * u - half_one * uu + one_over_six * uuu
coeff2_r = one_over_six + half_one * u + half_one * uu - 2 * one_over_six * uuu
coeff3_r = one_over_six * uuu
# spline R
q_01 = q_to_Q_parallel(q_to_q_conj_parallel(q0)) @ q1[..., None] # [1]
q_12 = q_to_Q_parallel(q_to_q_conj_parallel(q1)) @ q2[..., None] # [2]
q_23 = q_to_Q_parallel(q_to_q_conj_parallel(q2)) @ q3[..., None] # [3]
r_01 = log_q2r_parallel(q_01.squeeze(-1)) * coeff1_r # [4]
r_12 = log_q2r_parallel(q_12.squeeze(-1)) * coeff2_r # [5]
r_23 = log_q2r_parallel(q_23.squeeze(-1)) * coeff3_r # [6]
q_t_0 = exp_r2q_parallel(r_01) # [7]
q_t_1 = exp_r2q_parallel(r_12) # [8]
q_t_2 = exp_r2q_parallel(r_23) # [9]
q_product1 = q_to_Q_parallel(q_t_1) @ q_t_2[..., None] # [10]
q_product2 = q_to_Q_parallel(q_t_0) @ q_product1 # [10]
q_t = q_to_Q_parallel(q0) @ q_product2 # [10]
R = q_to_R_parallel(q_t.squeeze(-1))
t = t_t.unsqueeze(dim=-1)
pose_spline = torch.cat([R, t], -1)
poses = pose_spline.reshape([-1, 3, 4])
return poses
================================================
FILE: configs/cozy2room-cubic.txt
================================================
expname = cozy2room-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/blurcozy2room
dataset_type = llff
factor = 1
linear = False
pose_lrate = 1e-4
novel_view = True
factor_pose_novel = 20.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/cozy2room.txt
================================================
expname = cozy2room-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/blurcozy2room
dataset_type = llff
factor = 1
linear = True
novel_view = True
factor_pose_novel = 2.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/factory-cubic.txt
================================================
expname = factory-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/blurfactory
dataset_type = llff
factor = 1
linear = False
pose_lrate = 1e-4
novel_view = True
factor_pose_novel = 200.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/factory.txt
================================================
expname = factory-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/blurfactory
dataset_type = llff
factor = 1
linear = True
novel_view = True
factor_pose_novel = 20.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/pool-cubic.txt
================================================
expname = pool-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/blurpool
dataset_type = llff
factor = 1
linear = False
pose_lrate = 1e-4
novel_view = True
factor_pose_novel = 20.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/pool.txt
================================================
expname = pool-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/blurpool
dataset_type = llff
factor = 1
linear = True
novel_view = True
factor_pose_novel = 2.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/roomdark-cubic.txt
================================================
expname = roomdark-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/dark
dataset_type = llff
factor = 1
linear = False
pose_lrate = 1e-4
novel_view = False
factor_pose_novel = 2.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/roomdark.txt
================================================
expname = roomdark-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/dark
dataset_type = llff
factor = 1
linear = True
novel_view = False
factor_pose_novel = 2.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/roomhigh-cubic.txt
================================================
expname = roomhigh-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/roomblur_high
dataset_type = llff
factor = 1
linear = False
pose_lrate = 1e-4
novel_view = False
factor_pose_novel = 2.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/roomhigh.txt
================================================
expname = roomhigh-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/roomblur_high
dataset_type = llff
factor = 1
linear = True
novel_view = False
factor_pose_novel = 2.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/roomlow-cubic.txt
================================================
expname = roomlow-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/roomblur_low
dataset_type = llff
factor = 1
linear = False
pose_lrate = 1e-4
novel_view = False
factor_pose_novel = 2.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/roomlow.txt
================================================
expname = roomlow-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/roomblur_low
dataset_type = llff
factor = 1
linear = True
novel_view = False
factor_pose_novel = 2.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/tanabata-cubic.txt
================================================
expname = tanabata-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/blurtanabata
dataset_type = llff
factor = 1
linear = False
pose_lrate = 1e-4
novel_view = True
factor_pose_novel = 50.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/tanabata.txt
================================================
expname = tanabata-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/blurtanabata
dataset_type = llff
factor = 1
linear = True
novel_view = True
factor_pose_novel = 5.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/wine-cubic.txt
================================================
expname = wine-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/blurwine
dataset_type = llff
factor = 1
linear = False
pose_lrate = 1e-4
novel_view = True
factor_pose_novel = 20.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: configs/wine.txt
================================================
expname = wine-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/blurwine
dataset_type = llff
factor = 1
linear = True
novel_view = True
factor_pose_novel = 2.0
i_novel_view = 200000
N_rand = 5000
deblur_images = 7
N_samples = 64
N_importance = 64
use_viewdirs = True
raw_noise_std = 1.0
load_weights = False
weight_iter = 200000
i_img = 25000
i_video = 200000
i_weights = 10000
================================================
FILE: load_llff.py
================================================
import numpy as np
import os, imageio
import torch
########## Slightly modified version of LLFF data loading code
########## see https://github.com/Fyusion/LLFF for original
# downsample
def _minify(basedir, factors=[], resolutions=[]): # basedir: ./data/nerf_llff_data/fern
needtoload = False
for r in factors:
imgdir = os.path.join(basedir, 'images_{}'.format(r))
if not os.path.exists(imgdir):
needtoload = True
for r in resolutions:
imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))
if not os.path.exists(imgdir):
needtoload = True
if not needtoload:
return
from shutil import copy
from subprocess import check_output
imgdir = os.path.join(basedir, 'images')
imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]
imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]
imgdir_orig = imgdir
wd = os.getcwd()
for r in factors + resolutions:
if isinstance(r, int):
name = 'images_{}'.format(r)
resizearg = '{}%'.format(100. / r)
else:
name = 'images_{}x{}'.format(r[1], r[0])
resizearg = '{}x{}'.format(r[1], r[0])
imgdir = os.path.join(basedir, name)
if os.path.exists(imgdir):
continue
print('Minifying', r, basedir)
os.makedirs(imgdir)
check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)
ext = imgs[0].split('.')[-1]
args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])
print(args)
os.chdir(imgdir)
check_output(args, shell=True)
os.chdir(wd)
if ext != 'png':
check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)
print('Removed duplicates')
print('Done')
def _load_data(basedir, pose_state, factor=None, width=None, height=None, load_imgs=True):
poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))
poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
bds = poses_arr[:, -2:].transpose([1, 0])
img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \
if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]
sh = imageio.imread(img0).shape
sfx = ''
if factor is not None:
sfx = '_{}'.format(factor)
_minify(basedir, factors=[factor])
factor = factor
elif height is not None:
factor = sh[0] / float(height)
width = int(sh[1] / factor)
_minify(basedir, resolutions=[[height, width]])
sfx = '_{}x{}'.format(width, height)
elif width is not None:
factor = sh[1] / float(width)
height = int(sh[0] / factor)
_minify(basedir, resolutions=[[height, width]])
sfx = '_{}x{}'.format(width, height)
else:
factor = 1
imgdir = os.path.join(basedir, 'images' + sfx)
if not os.path.exists(imgdir):
print(imgdir, 'does not exist, returning')
return
imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if
f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]
if poses.shape[-1] != len(imgfiles):
print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) )
# return
sh = imageio.imread(imgfiles[0]).shape
poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
poses[2, 4, :] = poses[2, 4, :] * 1. / factor
if not load_imgs:
return poses, bds
def imread(f):
if f.endswith('png'):
return imageio.imread(f, ignoregamma=True)
else:
return imageio.imread(f)
imgs = [imread(f)[..., :3] / 255. for f in imgfiles]
imgs = np.stack(imgs, -1)
print('Loaded image data', imgs.shape, poses[:, -1, 0])
return poses, bds, imgs
def normalize(x):
return x / np.linalg.norm(x)
def viewmatrix(z, up, pos):
vec2 = normalize(z)
vec1_avg = up
vec0 = normalize(np.cross(vec1_avg, vec2))
vec1 = normalize(np.cross(vec2, vec0))
m = np.stack([vec0, vec1, vec2, pos], 1)
return m
def ptstocam(pts, c2w):
tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0]
return tt
def poses_avg(poses):
hwf = poses[0, :3, -1:]
center = poses[:, :3, 3].mean(0)
vec2 = normalize(poses[:, :3, 2].sum(0))
up = poses[:, :3, 1].sum(0)
c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
return c2w
def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
render_poses = []
rads = np.array(list(rads) + [1.])
hwf = c2w[:, 4:5]
for theta in np.linspace(0., 2. * np.pi * rots, N + 1)[:-1]:
c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)
z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
return render_poses
def recenter_poses(poses):
poses_ = poses + 0
bottom = np.reshape([0, 0, 0, 1.], [1, 4])
c2w = poses_avg(poses)
c2w = np.concatenate([c2w[:3, :4], bottom], -2)
bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
poses = np.concatenate([poses[:, :3, :4], bottom], -2)
poses = np.linalg.inv(c2w) @ poses
poses_[:, :3, :4] = poses[:, :3, :4]
poses = poses_
return poses
#####################
def spherify_poses(poses, bds):
p34_to_44 = lambda p: np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1)
rays_d = poses[:, :3, 2:3]
rays_o = poses[:, :3, 3:4]
def min_line_dist(rays_o, rays_d):
A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
b_i = -A_i @ rays_o
pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0))
return pt_mindist
pt_mindist = min_line_dist(rays_o, rays_d)
center = pt_mindist
up = (poses[:, :3, 3] - center).mean(0)
vec0 = normalize(up)
vec1 = normalize(np.cross([.1, .2, .3], vec0))
vec2 = normalize(np.cross(vec0, vec1))
pos = center
c2w = np.stack([vec1, vec2, vec0, pos], 1)
poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])
rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
sc = 1. / rad
poses_reset[:, :3, 3] *= sc
bds *= sc
rad *= sc
centroid = np.mean(poses_reset[:, :3, 3], 0)
zh = centroid[2]
radcircle = np.sqrt(rad ** 2 - zh ** 2)
new_poses = []
for th in np.linspace(0., 2. * np.pi, 120):
camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
up = np.array([0, 0, -1.])
vec2 = normalize(camorigin)
vec0 = normalize(np.cross(vec2, up))
vec1 = normalize(np.cross(vec2, vec0))
pos = camorigin
p = np.stack([vec0, vec1, vec2, pos], 1)
new_poses.append(p)
new_poses = np.stack(new_poses, 0)
new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1)
poses_reset = np.concatenate(
[poses_reset[:, :3, :4], np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)], -1)
return poses_reset, new_poses, bds
def load_llff_data(basedir, pose_state=None, factor=1, recenter=True, bd_factor=.75, spherify=False,
path_zflat=False):
poses, bds, imgs = _load_data(basedir, pose_state=pose_state,factor=factor)
print('Loaded', basedir, bds.min(), bds.max(), 'Pose State: ', pose_state)
poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
poses = np.moveaxis(poses, -1, 0).astype(np.float32)
imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
bds = np.moveaxis(bds, -1, 0).astype(np.float32)
# Rescale if bd_factor is provided
sc = 1. if bd_factor is None else 1. / (bds.min() * bd_factor)
poses[:, :3, 3] *= sc # T
bds *= sc
if recenter:
poses = recenter_poses(poses)
if spherify:
poses, render_poses, bds = spherify_poses(poses, bds)
else:
c2w = poses_avg(poses)
up = normalize(poses[:, :3, 1].sum(0))
# Find a reasonable "focus depth" for this dataset
close_depth, inf_depth = bds.min() * .9, bds.max() * 5.
dt = .75
mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth))
focal = mean_dz
# Get radii for spiral path
shrink_factor = .8
zdelta = close_depth * .2
tt = poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T
rads = np.percentile(np.abs(tt), 90, 0)
c2w_path = c2w
N_views = 120
N_rots = 2
if path_zflat:
# zloc = np.percentile(tt, 10, 0)[2]
zloc = -close_depth * .1
c2w_path[:3, 3] = c2w_path[:3, 3] + zloc * c2w_path[:3, 2]
rads[2] = 0.
N_rots = 1
N_views /= 2
# Generate poses for spiral path
render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views)
render_poses = np.array(render_poses).astype(np.float32)
imgs = torch.Tensor(imgs)
poses = torch.Tensor(poses)
bds = torch.Tensor(bds)
render_poses = torch.Tensor(render_poses)
return imgs, poses, bds, render_poses
================================================
FILE: lpips/__init__.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
# from torch.autograd import Variable
from .lpips import *
================================================
FILE: lpips/lpips.py
================================================
from __future__ import absolute_import
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import numpy as np
from . import pretrained_networks as pn
import torch.nn
def normalize_tensor(in_feat, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True))
return in_feat / (norm_factor + eps)
def l2(p0, p1, range=255.):
return .5 * np.mean((p0 / range - p1 / range) ** 2)
def psnr(p0, p1, peak=255.):
return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2))
def dssim(p0, p1, range=255.):
from skimage.measure import compare_ssim
return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
def rgb2lab(in_img, mean_cent=False):
from skimage import color
img_lab = color.rgb2lab(in_img)
if mean_cent:
img_lab[:, :, 0] = img_lab[:, :, 0] - 50
return img_lab
def tensor2np(tensor_obj):
# change dimension of a tensor object into a numpy array
return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))
def np2tensor(np_obj):
# change dimenion of np array into tensor array
return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
# image tensor to lab tensor
from skimage import color
img = tensor2im(image_tensor)
img_lab = color.rgb2lab(img)
if mc_only:
img_lab[:, :, 0] = img_lab[:, :, 0] - 50
if to_norm and not mc_only:
img_lab[:, :, 0] = img_lab[:, :, 0] - 50
img_lab = img_lab / 100.
return np2tensor(img_lab)
def tensorlab2tensor(lab_tensor, return_inbnd=False):
from skimage import color
import warnings
warnings.filterwarnings("ignore")
lab = tensor2np(lab_tensor) * 100.
lab[:, :, 0] = lab[:, :, 0] + 50
rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1)
if return_inbnd:
# convert back to lab, see if we match
lab_back = color.rgb2lab(rgb_back.astype('uint8'))
mask = 1. * np.isclose(lab_back, lab, atol=2.)
mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
return im2tensor(rgb_back), mask
else:
return im2tensor(rgb_back)
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
return image_numpy.astype(imtype)
def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
return torch.tensor((image / factor - cent)
[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
def tensor2vec(vector_tensor):
return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
def voc_ap(rec, prec, use_07_metric=False):
""" ap = voc_ap(rec, prec, [use_07_metric])
Compute VOC AP given precision and recall.
If use_07_metric is true, uses the
VOC 07 11 point method (default:False).
"""
if use_07_metric:
# 11 point metric
ap = 0.
for t in np.arange(0., 1.1, 0.1):
if np.sum(rec >= t) == 0:
p = 0
else:
p = np.max(prec[rec >= t])
ap = ap + p / 11.
else:
# correct AP calculation
# first append sentinel values at the end
mrec = np.concatenate(([0.], rec, [1.]))
mpre = np.concatenate(([0.], prec, [0.]))
# compute the precision envelope
for i in range(mpre.size - 1, 0, -1):
mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
# to calculate area under PR curve, look for points
# where X axis (recall) changes value
i = np.where(mrec[1:] != mrec[:-1])[0]
# and sum (\Delta recall) * prec
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
return ap
def spatial_average(in_tens, keepdim=True):
return in_tens.mean([2, 3], keepdim=keepdim)
def upsample(in_tens, out_HW=(64, 64)): # assumes scale factor is same for H and W
in_H, in_W = in_tens.shape[2], in_tens.shape[3]
return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)
# Learned perceptual metric
class LPIPS(nn.Module):
def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True):
# lpips - [True] means with linear calibration on top of base network
# pretrained - [True] means load linear weights
super(LPIPS, self).__init__()
if verbose:
print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' %
('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))
self.pnet_type = net
self.pnet_tune = pnet_tune
self.pnet_rand = pnet_rand
self.spatial = spatial
self.lpips = lpips # false means baseline of just averaging all layers
self.version = version
self.scaling_layer = ScalingLayer()
if self.pnet_type in ['vgg', 'vgg16']:
net_type = pn.vgg16
self.chns = [64, 128, 256, 512, 512]
elif self.pnet_type == 'alex':
net_type = pn.alexnet
self.chns = [64, 192, 384, 256, 256]
elif self.pnet_type == 'squeeze':
net_type = pn.squeezenet
self.chns = [64, 128, 256, 384, 384, 512, 512]
self.L = len(self.chns)
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
if lpips:
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
if self.pnet_type == 'squeeze': # 7 layers for squeezenet
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
self.lins += [self.lin5, self.lin6]
self.lins = nn.ModuleList(self.lins)
if pretrained:
if model_path is None:
import inspect
import os
model_path = os.path.abspath(
os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))
if verbose:
print('Loading model from: %s' % model_path)
self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
if eval_mode:
self.eval()
def forward(self, in0, in1, retPerLayer=False, normalize=False):
if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
in0 = 2 * in0 - 1
in1 = 2 * in1 - 1
# v0.0 - original release had a bug, where input was not scaled
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else (
in0, in1)
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
feats0, feats1, diffs = {}, {}, {}
for kk in range(self.L):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
if self.lpips:
if self.spatial:
res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
else:
res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
else:
if self.spatial:
res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
else:
res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]
val = res[0]
for l in range(1, self.L):
val += res[l]
# a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
# b = torch.max(self.lins[kk](feats0[kk]**2))
# for kk in range(self.L):
# a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
# b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2)))
# a = a/self.L
# from IPython import embed
# embed()
# return 10*torch.log10(b/a)
if retPerLayer:
return val, res
else:
return val
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
''' A single linear layer which does a 1x1 conv '''
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = [nn.Dropout(), ] if use_dropout else []
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class Dist2LogitLayer(nn.Module):
''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
def __init__(self, chn_mid=32, use_sigmoid=True):
super(Dist2LogitLayer, self).__init__()
layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ]
layers += [nn.LeakyReLU(0.2, True), ]
layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ]
layers += [nn.LeakyReLU(0.2, True), ]
layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ]
if use_sigmoid:
layers += [nn.Sigmoid(), ]
self.model = nn.Sequential(*layers)
def forward(self, d0, d1, eps=0.1):
return self.model.forward(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1))
class BCERankingLoss(nn.Module):
def __init__(self, chn_mid=32):
super(BCERankingLoss, self).__init__()
self.net = Dist2LogitLayer(chn_mid=chn_mid)
# self.parameters = list(self.net.parameters())
self.loss = torch.nn.BCELoss()
def forward(self, d0, d1, judge):
per = (judge + 1.) / 2.
self.logit = self.net.forward(d0, d1)
return self.loss(self.logit, per)
# L2, DSSIM metrics
class FakeNet(nn.Module):
def __init__(self, use_gpu=True, colorspace='Lab'):
super(FakeNet, self).__init__()
self.use_gpu = use_gpu
self.colorspace = colorspace
class L2(FakeNet):
def forward(self, in0, in1, retPerLayer=None):
assert (in0.size()[0] == 1) # currently only supports batchSize 1
if self.colorspace == 'RGB':
(N, C, X, Y) = in0.size()
value = torch.mean(torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y),
dim=3).view(N)
return value
elif self.colorspace == 'Lab':
value = l2(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype(
'float')
ret_var = Variable(torch.Tensor((value,)))
if self.use_gpu:
ret_var = ret_var.cuda()
return ret_var
class DSSIM(FakeNet):
def forward(self, in0, in1, retPerLayer=None):
assert (in0.size()[0] == 1) # currently only supports batchSize 1
if self.colorspace == 'RGB':
value = dssim(1. * tensor2im(in0.data), 1. * tensor2im(in1.data), range=255.).astype(
'float')
elif self.colorspace == 'Lab':
value = dssim(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype(
'float')
ret_var = Variable(torch.Tensor((value,)))
if self.use_gpu:
ret_var = ret_var.cuda()
return ret_var
def print_network(net):
num_params = 0
for param in net.parameters():
num_params += param.numel()
print('Network', net)
print('Total number of parameters: %d' % num_params)
================================================
FILE: lpips/pretrained_networks.py
================================================
from collections import namedtuple
import torch
from torchvision import models as tv
class squeezenet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(squeezenet, self).__init__()
pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.slice6 = torch.nn.Sequential()
self.slice7 = torch.nn.Sequential()
self.N_slices = 7
for x in range(2):
self.slice1.add_module(str(x), pretrained_features[x])
for x in range(2, 5):
self.slice2.add_module(str(x), pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), pretrained_features[x])
for x in range(10, 11):
self.slice5.add_module(str(x), pretrained_features[x])
for x in range(11, 12):
self.slice6.add_module(str(x), pretrained_features[x])
for x in range(12, 13):
self.slice7.add_module(str(x), pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
h = self.slice6(h)
h_relu6 = h
h = self.slice7(h)
h_relu7 = h
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'])
out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)
return out
class alexnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(alexnet, self).__init__()
alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(2):
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
for x in range(2, 5):
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
for x in range(10, 12):
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
return out
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
class resnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True, num=18):
super(resnet, self).__init__()
if (num == 18):
self.net = tv.resnet18(pretrained=pretrained)
elif (num == 34):
self.net = tv.resnet34(pretrained=pretrained)
elif (num == 50):
self.net = tv.resnet50(pretrained=pretrained)
elif (num == 101):
self.net = tv.resnet101(pretrained=pretrained)
elif (num == 152):
self.net = tv.resnet152(pretrained=pretrained)
self.N_slices = 5
self.conv1 = self.net.conv1
self.bn1 = self.net.bn1
self.relu = self.net.relu
self.maxpool = self.net.maxpool
self.layer1 = self.net.layer1
self.layer2 = self.net.layer2
self.layer3 = self.net.layer3
self.layer4 = self.net.layer4
def forward(self, X):
h = self.conv1(X)
h = self.bn1(h)
h = self.relu(h)
h_relu1 = h
h = self.maxpool(h)
h = self.layer1(h)
h_conv2 = h
h = self.layer2(h)
h_conv3 = h
h = self.layer3(h)
h_conv4 = h
h = self.layer4(h)
h_conv5 = h
outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
return out
================================================
FILE: metrics.py
================================================
from skimage import metrics
import torch
import torch.hub
from lpips.lpips import LPIPS
import os
import numpy as np
photometric = {
"mse": None,
"ssim": None,
"psnr": None,
"lpips": None
}
def compute_img_metric(im1t: torch.Tensor, im2t: torch.Tensor,
metric="mse", margin=0, mask=None):
"""
im1t, im2t: torch.tensors with batched imaged shape, range from (0, 1)
"""
if metric not in photometric.keys():
raise RuntimeError(f"img_utils:: metric {metric} not recognized")
if photometric[metric] is None:
if metric == "mse":
photometric[metric] = metrics.mean_squared_error
elif metric == "ssim":
photometric[metric] = metrics.structural_similarity
elif metric == "psnr":
photometric[metric] = metrics.peak_signal_noise_ratio
elif metric == "lpips":
photometric[metric] = LPIPS().cpu()
if mask is not None:
if mask.dim() == 3:
mask = mask.unsqueeze(1)
if mask.shape[1] == 1:
mask = mask.expand(-1, 3, -1, -1)
mask = mask.permute(0, 2, 3, 1).numpy()
batchsz, hei, wid, _ = mask.shape
if margin > 0:
marginh = int(hei * margin) + 1
marginw = int(wid * margin) + 1
mask = mask[:, marginh:hei - marginh, marginw:wid - marginw]
# convert from [0, 1] to [-1, 1]
im1t = (im1t * 2 - 1).clamp(-1, 1)
im2t = (im2t * 2 - 1).clamp(-1, 1)
if im1t.dim() == 3:
im1t = im1t.unsqueeze(0)
im2t = im2t.unsqueeze(0)
im1t = im1t.detach().cpu()
im2t = im2t.detach().cpu()
if im1t.shape[-1] == 3:
im1t = im1t.permute(0, 3, 1, 2)
im2t = im2t.permute(0, 3, 1, 2)
im1 = im1t.permute(0, 2, 3, 1).numpy()
im2 = im2t.permute(0, 2, 3, 1).numpy()
batchsz, hei, wid, _ = im1.shape
if margin > 0:
marginh = int(hei * margin) + 1
marginw = int(wid * margin) + 1
im1 = im1[:, marginh:hei - marginh, marginw:wid - marginw]
im2 = im2[:, marginh:hei - marginh, marginw:wid - marginw]
values = []
for i in range(batchsz):
if metric in ["mse", "psnr"]:
if mask is not None:
im1 = im1 * mask[i]
im2 = im2 * mask[i]
value = photometric[metric](
im1[i], im2[i]
)
if mask is not None:
hei, wid, _ = im1[i].shape
pixelnum = mask[i, ..., 0].sum()
value = value - 10 * np.log10(hei * wid / pixelnum)
elif metric in ["ssim"]:
value, ssimmap = photometric["ssim"](
im1[i], im2[i], multichannel=True, full=True
)
if mask is not None:
value = (ssimmap * mask[i]).sum() / mask[i].sum()
elif metric in ["lpips"]:
value = photometric[metric](
im1t[i:i + 1], im2t[i:i + 1]
)
else:
raise NotImplementedError
values.append(value)
return sum(values) / len(values)
================================================
FILE: nerf.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from run_nerf import *
from Spline import se3_to_SE3
max_iter = 200000
T = max_iter+1
BOUNDARY = 20
class Embedder:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs['input_dims']
out_dim = 0
if self.kwargs['include_input']:
embed_fns.append(lambda x: x)
out_dim += d
max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']
if self.kwargs['log_sampling']:
freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
else:
freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']:
embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(args, multires, i=0):
if i == -1:
return nn.Identity(), 3
embed_kwargs = {
'include_input': False if args.barf else True,
'input_dims': 3,
'max_freq_log2': multires - 1,
'num_freqs': multires,
'log_sampling': True,
'periodic_fns': [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
embed = lambda x, eo=embedder_obj: eo.embed(x)
return embed, embedder_obj.out_dim
class Model():
def __init__(self):
super().__init__()
def build_network(self, args):
self.graph = Graph(args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True)
return self.graph
def setup_optimizer(self, args):
grad_vars = list(self.graph.nerf.parameters())
if args.N_importance>0:
grad_vars += list(self.graph.nerf_fine.parameters())
self.optim = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
return self.optim
class NeRF(nn.Module):
def __init__(self, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=False):
super().__init__()
self.D = D
self.W = W
self.input_ch = input_ch
self.input_ch_views = input_ch_views
self.skips = skips
self.use_viewdirs = use_viewdirs
# network
self.pts_linears = nn.ModuleList(
[nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in
range(D - 1)])
self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)])
if use_viewdirs:
self.feature_linear = nn.Linear(W, W)
self.alpha_linear = nn.Linear(W, 1)
self.rgb_linear = nn.Linear(W // 2, 3)
else:
self.output_linear = nn.Linear(W, output_ch)
def forward(self, barf_i, pts, viewdirs, args):
embed_fn, input_ch = get_embedder(args, args.multires, args.i_embed)
embeddirs_fn = None
if args.use_viewdirs:
embeddirs_fn, input_ch_views = get_embedder(args, args.multires_views, args.i_embed)
pts_flat = torch.reshape(pts, [-1, pts.shape[-1]])
embedded = embed_fn(pts_flat)
if viewdirs is not None:
input_dirs = viewdirs[:, None].expand(pts.shape)
input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
embedded_dirs = embeddirs_fn(input_dirs_flat)
embedded = torch.cat([embedded, embedded_dirs], -1)
input_pts, input_views = torch.split(embedded, [self.input_ch, self.input_ch_views], dim=-1)
h = input_pts
for i, l in enumerate(self.pts_linears):
h = self.pts_linears[i](h)
h = F.relu(h)
if i in self.skips:
h = torch.cat([input_pts, h], -1)
if self.use_viewdirs:
alpha = self.alpha_linear(h)
feature = self.feature_linear(h)
h = torch.cat([feature, input_views], -1)
for i, l in enumerate(self.views_linears):
h = self.views_linears[i](h)
h = F.relu(h)
rgb = self.rgb_linear(h)
outputs = torch.cat([rgb, alpha], -1)
else:
outputs = self.output_linear(h)
outputs = torch.reshape(outputs, list(pts.shape[:-1]) + [outputs.shape[-1]])
return outputs
def raw2output(self, raw, z_vals, rays_d, raw_noise_std=0.0):
raw2alpha = lambda raw, dists, act_fn=F.relu: 1. - torch.exp(-act_fn(raw) * dists)
dists = z_vals[..., 1:] - z_vals[..., :-1]
dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)], -1)
dists = dists * torch.norm(rays_d[..., None, :], dim=-1)
rgb = torch.sigmoid(raw[..., :3])
noise = 0.
if raw_noise_std > 0.:
noise = torch.randn(raw[..., 3].shape) * raw_noise_std
alpha = raw2alpha(raw[..., 3] + noise, dists)
weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1. - alpha + 1e-10], -1), -1)[:,:-1]
rgb_map = torch.sum(weights[..., None] * rgb, -2)
depth_map = torch.sum(weights * z_vals, -1)
# disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
disp_map = torch.max(1e-6 * torch.ones_like(depth_map), depth_map / (torch.sum(weights, -1)+1e-6))
acc_map = torch.sum(weights, -1)
sigma = F.relu(raw[..., 3] + noise)
return rgb_map, disp_map, acc_map, weights, depth_map, sigma
class Graph(nn.Module):
def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=False):
super().__init__()
self.nerf = NeRF(D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs)
if args.N_importance > 0:
self.nerf_fine = NeRF(D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs)
def forward(self, i, img_idx, poses_num, H, W, K, args, novel_view=False):
if novel_view:
poses_sharp = se3_to_SE3(self.se3_sharp.weight)
ray_idx_sharp = torch.randperm(H * W)[:300]
ret = self.render(i, poses_sharp, ray_idx_sharp, H, W, K, args)
return ret, ray_idx_sharp, poses_sharp
spline_poses = self.get_pose(i, img_idx, args)
ray_idx = torch.randperm(H * W)[:args.N_rand // poses_num]
'''
# only used in distorted data
# aims to prevent the ray_idx lying on the edges
for j in range(ray_idx.shape[0]):
h = torch.randperm(H - 1)[0]
w = torch.randperm(W - 1)[0]
while (h < BOUNDARY or h > (H - 1 - BOUNDARY) or w < BOUNDARY or w > (W - 1 - BOUNDARY)):
h = torch.randperm(H - 1)[0]
w = torch.randperm(W - 1)[0]
index = h * W + w
ray_idx[j] = index
'''
ret = self.render(i, spline_poses, ray_idx, H, W, K, args, near=0, far=1.0, ray_idx_tv=None, training=True)
if (i % args.i_img == 0 or i % args.i_novel_view == 0) and i > 0:
if args.deblur_images % 2 == 0:
all_poses = self.get_pose_even(i, torch.arange(self.se3.weight.shape[0]), args.deblur_images)
else:
all_poses = self.get_pose(i, torch.arange(self.se3.weight.shape[0]), args)
return ret, ray_idx, spline_poses, all_poses
else:
return ret, ray_idx, spline_poses
def get_pose(self, i, img_idx, args):
return i
def get_gt_pose(self, poses, args):
return poses
def render(self, barf_i, poses, ray_idx, H, W, K, args, near=0., far=1., ray_idx_tv=None, training=False):
if training:
ray_idx_ = ray_idx.repeat(poses.shape[0])
poses = poses.unsqueeze(1).repeat(1, ray_idx.shape[0], 1, 1).reshape(-1, 3, 4)
j = ray_idx_.reshape(-1, 1).squeeze() // W
i = ray_idx_.reshape(-1, 1).squeeze() % W
rays_o_, rays_d_ = get_specific_rays(i, j, K, poses)
rays_o_d = torch.stack([rays_o_, rays_d_], 0)
batch_rays = torch.permute(rays_o_d, [1, 0, 2])
else:
rays_list = []
for p in poses[:, :3, :4]:
rays_o_, rays_d_ = get_rays(H, W, K, p)
rays_o_d = torch.stack([rays_o_, rays_d_], 0)
rays_list.append(rays_o_d)
rays = torch.stack(rays_list, 0)
rays = rays.reshape(-1, 2, H * W, 3)
rays = torch.permute(rays, [0, 2, 1, 3])
batch_rays = rays[:, ray_idx]
batch_rays = batch_rays.reshape(-1, 2, 3)
batch_rays = torch.transpose(batch_rays, 0, 1)
# get standard rays
rays_o, rays_d = batch_rays
if args.use_viewdirs:
viewdirs = rays_d
viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
viewdirs = torch.reshape(viewdirs, [-1, 3]).float()
sh = rays_d.shape
if args.ndc:
rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
# Create ray batch
rays_o = torch.reshape(rays_o, [-1, 3]).float()
rays_d = torch.reshape(rays_d, [-1, 3]).float()
near, far = near * torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1])
rays = torch.cat([rays_o, rays_d, near, far], -1)
if args.use_viewdirs:
rays = torch.cat([rays, viewdirs], -1)
N_rays = rays.shape[0]
rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]
viewdirs = rays[:, -3:] if rays.shape[-1] > 8 else None
bounds = torch.reshape(rays[..., 6:8], [-1, 1, 2])
near, far = bounds[..., 0], bounds[..., 1]
t_vals = torch.linspace(0., 1., steps=args.N_samples)
z_vals = near * (1. - t_vals) + far * (t_vals)
z_vals = z_vals.expand([N_rays, args.N_samples])
# perturb
mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
upper = torch.cat([mids, z_vals[..., -1:]], -1)
lower = torch.cat([z_vals[..., :1], mids], -1)
# stratified samples in those intervals
t_rand = torch.rand(z_vals.shape)
z_vals = lower + (upper - lower) * t_rand
pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
raw_output = self.nerf.forward(barf_i, pts, viewdirs, args)
rgb_map, disp_map, acc_map, weights, depth_map, sigma = self.nerf.raw2output(raw_output, z_vals, rays_d)
if args.N_importance > 0:
rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], args.N_importance)
z_samples = z_samples.detach()
z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]
raw_output = self.nerf_fine.forward(barf_i, pts, viewdirs, args)
rgb_map, disp_map, acc_map, weights, depth_map, sigma = self.nerf_fine.raw2output(raw_output, z_vals, rays_d)
ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map}
if args.N_importance > 0:
ret['rgb0'] = rgb_map_0
ret['disp0'] = disp_map_0
ret['acc0'] = acc_map_0
ret['sigma'] = sigma
return ret
@torch.no_grad()
def render_video(self, barf_i, poses, H, W, K, args):
all_ret = {}
ray_idx = torch.arange(0, H*W)
for i in range(0, ray_idx.shape[0], args.chunk):
ret = self.render(barf_i, poses, ray_idx[i:i+args.chunk], H, W, K, args)
for k in ret:
if k not in all_ret:
all_ret[k] = []
all_ret[k].append(ret[k])
all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}
for k in all_ret:
k_sh = list([H, W]) + list(all_ret[k].shape[1:])
all_ret[k] = torch.reshape(all_ret[k], k_sh)
return all_ret
================================================
FILE: novel_view_test.py
================================================
import nerf
import torch.nn
class Model(nerf.Model):
def __init__(self, se3_start, graph):
super().__init__()
self.start = se3_start
self.graph_fixed = graph
def build_network(self, args):
self.graph_fixed.se3_sharp = torch.nn.Embedding(self.start.shape[0], 6) # 22和25
self.graph_fixed.se3_sharp.weight.data = torch.nn.Parameter(self.start)
return self.graph_fixed
def setup_optimizer(self, args):
grad_vars_se3 = list(self.graph_fixed.se3_sharp.parameters())
self.optim_se3_sharp = torch.optim.Adam(params=grad_vars_se3, lr=args.lrate)
return self.optim_se3_sharp
================================================
FILE: optimize_pose_cubic.py
================================================
import torch.nn
import Spline
import nerf
class Model(nerf.Model):
def __init__(self, se3_0, se3_1, se3_2, se3_3):
super().__init__()
self.se3_0 = se3_0
self.se3_1 = se3_1
self.se3_2 = se3_2
self.se3_3 = se3_3
def build_network(self, args):
self.graph = Graph(args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True)
self.graph.se3 = torch.nn.Embedding(self.se3_0.shape[0], 6*4)
start_end = torch.cat([self.se3_0, self.se3_1, self.se3_2, self.se3_3], -1)
self.graph.se3.weight.data = torch.nn.Parameter(start_end)
return self.graph
def setup_optimizer(self, args):
grad_vars = list(self.graph.nerf.parameters())
if args.N_importance > 0:
grad_vars += list(self.graph.nerf_fine.parameters())
self.optim = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
grad_vars_se3 = list(self.graph.se3.parameters())
self.optim_se3 = torch.optim.Adam(params=grad_vars_se3, lr=args.lrate)
return self.optim, self.optim_se3
class Graph(nerf.Graph):
def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True):
super().__init__(args, D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs)
self.pose_eye = torch.eye(3, 4)
self.se3_start = None
self.se3_end = None
def get_pose(self, i, img_idx, args):
se3_0 = self.se3.weight[:, :6][img_idx]
se3_1 = self.se3.weight[:, 6:12][img_idx]
se3_2 = self.se3.weight[:, 12:18][img_idx]
se3_3 = self.se3.weight[:, 18:][img_idx]
pose_nums = torch.arange(args.deblur_images).reshape(1, -1).repeat(se3_0.shape[0], 1)
seg_pos_x = torch.arange(se3_0.shape[0]).reshape([se3_0.shape[0], 1]).repeat(1, args.deblur_images)
se3_0 = se3_0[seg_pos_x, :]
se3_1 = se3_1[seg_pos_x, :]
se3_2 = se3_2[seg_pos_x, :]
se3_3 = se3_3[seg_pos_x, :]
spline_poses = Spline.SplineN_cubic(se3_0, se3_1, se3_2, se3_3, pose_nums, args.deblur_images)
return spline_poses
def get_pose_even(self, i, img_idx, num):
deblur_images_num = num+1
se3_0 = self.se3.weight[:, :6][img_idx]
se3_1 = self.se3.weight[:, 6:12][img_idx]
se3_2 = self.se3.weight[:, 12:18][img_idx]
se3_3 = self.se3.weight[:, 18:][img_idx]
pose_nums = torch.arange(deblur_images_num).reshape(1, -1).repeat(se3_0.shape[0], 1)
seg_pos_x = torch.arange(se3_0.shape[0]).reshape([se3_0.shape[0], 1]).repeat(1, deblur_images_num)
se3_0 = se3_0[seg_pos_x, :]
se3_1 = se3_1[seg_pos_x, :]
se3_2 = se3_2[seg_pos_x, :]
se3_3 = se3_3[seg_pos_x, :]
spline_poses = Spline.SplineN_cubic(se3_0, se3_1, se3_2, se3_3, pose_nums, deblur_images_num)
return spline_poses
def get_gt_pose(self, poses, args):
a = self.pose_eye
return poses
================================================
FILE: optimize_pose_linear.py
================================================
import torch.nn
import Spline
import nerf
class Model(nerf.Model):
def __init__(self, se3_start, se3_end):
super().__init__()
self.start = se3_start
self.end = se3_end
def build_network(self, args):
self.graph = Graph(args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True)
self.graph.se3 = torch.nn.Embedding(self.start.shape[0], 6*2)
start_end = torch.cat([self.start, self.end], -1)
self.graph.se3.weight.data = torch.nn.Parameter(start_end)
return self.graph
def setup_optimizer(self, args):
grad_vars = list(self.graph.nerf.parameters())
if args.N_importance > 0:
grad_vars += list(self.graph.nerf_fine.parameters())
self.optim = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
grad_vars_se3 = list(self.graph.se3.parameters())
self.optim_se3 = torch.optim.Adam(params=grad_vars_se3, lr=args.lrate)
return self.optim, self.optim_se3
class Graph(nerf.Graph):
def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True):
super().__init__(args, D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs)
self.pose_eye = torch.eye(3, 4)
self.se3_start = None
self.se3_end = None
def get_pose(self, i, img_idx, args):
se3_start = self.se3.weight[:, :6][img_idx]
se3_end = self.se3.weight[:, 6:][img_idx]
pose_nums = torch.arange(args.deblur_images).reshape(1, -1).repeat(se3_start.shape[0], 1)
seg_pos_x = torch.arange(se3_start.shape[0]).reshape([se3_start.shape[0], 1]).repeat(1, args.deblur_images)
se3_start = se3_start[seg_pos_x, :]
se3_end = se3_end[seg_pos_x, :]
spline_poses = Spline.SplineN_linear(se3_start, se3_end, pose_nums, args.deblur_images)
return spline_poses
def get_pose_even(self, i, img_idx, num):
deblur_images_num = num+1
se3_start = self.se3.weight[:, :6][img_idx]
se3_end = self.se3.weight[:, 6:][img_idx]
pose_nums = torch.arange(deblur_images_num).reshape(1, -1).repeat(se3_start.shape[0],1)
seg_pos_x = torch.arange(se3_start.shape[0]).reshape([se3_start.shape[0], 1]).repeat(1, deblur_images_num)
se3_start = se3_start[seg_pos_x, :]
se3_end = se3_end[seg_pos_x, :]
spline_poses = Spline.SplineN_linear(se3_start, se3_end, pose_nums, deblur_images_num)
return spline_poses
def get_gt_pose(self, poses, args):
a = self.pose_eye
return poses
================================================
FILE: requirements.txt
================================================
configargparse
imageio<2.28.0,>=2.26.0
imageio-ffmpeg
matplotlib
numpy
scikit-learn<1
scikit-image<0.20,>=0.19
torch>=1.8
torchvision>=0.9.1
tqdm
================================================
FILE: run_nerf.py
================================================
import os, sys
import numpy as np
import imageio
import json
import random
import time
import torch
from run_nerf_helpers import *
from Spline import *
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange
import matplotlib.pyplot as plt
from load_llff import load_llff_data
import torchvision.transforms.functional as torchvision_F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
np.random.seed(0)
DEBUG = False
def config_parser():
import configargparse
parser = configargparse.ArgumentParser()
parser.add_argument('--config', is_config_file=True, default='configs/cozy2room.txt',
help='config file path')
parser.add_argument("--expname", type=str,
help='experiment name')
parser.add_argument("--basedir", type=str, default='./logs/',
help='where to store ckpts and logs')
parser.add_argument("--datadir", type=str, default='./data/llff/fern',
help='input data directory')
# training options
parser.add_argument("--N_iters", type=int, default=200000,
help='the number of sharp images one blur image corresponds to')
parser.add_argument("--deblur_images", type=int, default=5,
help='the number of sharp images one blur image corresponds to')
parser.add_argument("--skip", type=int, default=8,
help='original llffhold before concatenate images')
parser.add_argument("--netdepth", type=int, default=8,
help='layers in network')
parser.add_argument("--netwidth", type=int, default=256,
help='channels per layer')
parser.add_argument("--netdepth_fine", type=int, default=8,
help='layers in fine network')
parser.add_argument("--netwidth_fine", type=int, default=256,
help='channels per layer in fine network')
parser.add_argument("--N_rand", type=int, default=32*32*4,
help='batch size (number of random rays per gradient step)')
parser.add_argument("--lrate", type=float, default=5e-4,
help='learning rate')
parser.add_argument("--pose_lrate", type=float, default=1e-3,
help='learning rate')
parser.add_argument("--lrate_decay", type=int, default=200,
help='exponential learning rate decay (in 1000 steps)')
parser.add_argument("--chunk", type=int, default=1024*2,
help='number of rays processed in parallel, decrease if running out of memory')
parser.add_argument("--netchunk", type=int, default=1024*32,
help='number of pts sent through network in parallel, decrease if running out of memory')
parser.add_argument("--no_batching", action='store_true',
help='only take random rays from 1 image at a time')
parser.add_argument("--no_reload", action='store_true',
help='do not reload weights from saved ckpt')
parser.add_argument("--ft_path", type=str, default=None,
help='specific weights npy file to reload for coarse network')
# rendering options
parser.add_argument("--N_samples", type=int, default=64,
help='number of coarse samples per ray')
parser.add_argument("--N_importance", type=int, default=0,
help='number of additional fine samples per ray')
parser.add_argument("--perturb", type=float, default=1.,
help='set to 0. for no jitter, 1. for jitter')
parser.add_argument("--use_viewdirs", action='store_true',
help='use full 5D input instead of 3D')
parser.add_argument("--i_embed", type=int, default=0,
help='set 0 for default positional encoding, -1 for none')
parser.add_argument("--multires", type=int, default=10,
help='log2 of max freq for positional encoding (3D location)')
parser.add_argument("--multires_views", type=int, default=4,
help='log2 of max freq for positional encoding (2D direction)')
parser.add_argument("--raw_noise_std", type=float, default=0.,
help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
parser.add_argument("--render_only", action='store_true',
help='do not optimize, reload weights and render out render_poses path')
parser.add_argument("--render_test", action='store_true',
help='render the test set instead of render_poses path')
parser.add_argument("--render_factor", type=int, default=0,
help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
parser.add_argument("--ndc", type=bool, default=True,
help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
# training options
parser.add_argument("--precrop_iters", type=int, default=0,
help='number of steps to train on central crops')
parser.add_argument("--precrop_frac", type=float,
default=.5, help='fraction of img taken for central crops')
# dataset options
parser.add_argument("--dataset_type", type=str, default='llff',
help='options: llff / blender / deepvoxels')
parser.add_argument("--testskip", type=int, default=8,
help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
## deepvoxels flags
parser.add_argument("--shape", type=str, default='greek',
help='options : armchair / cube / greek / vase')
## blender flags
parser.add_argument("--white_bkgd", action='store_true',
help='set to render synthetic data on a white bkgd (always use for dvoxels)')
parser.add_argument("--half_res", action='store_true',
help='load blender synthetic data at 400x400 instead of 800x800')
## llff flags
parser.add_argument("--factor", type=int, default=8,
help='downsample factor for LLFF images')
parser.add_argument("--no_ndc", action='store_true',
help='do not use normalized device coordinates (set for non-forward facing scenes)')
parser.add_argument("--lindisp", action='store_true',
help='sampling linearly in disparity rather than depth')
parser.add_argument("--spherify", action='store_true',
help='set for spherical 360 scenes')
parser.add_argument("--llffhold", type=int, default=8,
help='will take every 1/N images as LLFF test set, paper uses 8')
# logging/saving options
parser.add_argument("--i_print", type=int, default=100,
help='frequency of console printout and metric loggin')
parser.add_argument("--i_img", type=int, default=15000,
help='frequency of tensorboard image logging')
parser.add_argument("--i_weights", type=int, default=10000,
help='frequency of weight ckpt saving')
parser.add_argument("--i_testset", type=int, default=50000,
help='frequency of testset saving')
parser.add_argument("--i_video", type=int, default=30000,
help='frequency of render_poses video saving')
parser.add_argument("--load_weights", action='store_true',
help='frequency of weight ckpt loading')
# barf: up & down
parser.add_argument("--barf", action='store_true',
help='barf')
parser.add_argument("--barf_start", type=float, default=0.1,
help='barf start')
parser.add_argument("--barf_end", type=float, default=0.9,
help='barf start')
# test option
parser.add_argument("--only_optim_one", action='store_true', default=False,
help='frequency of weight ckpt loading')
parser.add_argument("--split_train_data", action='store_true', default=False,
help='frequency of weight ckpt loading')
parser.add_argument("--weight_iter", type=int, default=20000,
help='weight_iter')
# pose noise
parser.add_argument("--pose_noise", type=float, default=0.1,
help='random noise of pose')
# linear
parser.add_argument("--linear", action='store_true', default=False,
help='linear or cubic spline')
# novel view
parser.add_argument("--novel_view", action='store_true', default=False,
help='novel view test')
parser.add_argument("--i_novel_view", type=int, default=200000,
help='novel view iter')
parser.add_argument("--factor_pose_novel", type=float, default=2.0,
help='factor of learning rate')
parser.add_argument("--N_novel_view", type=int, default=20000,
help='novel view iter for optimizing poses')
return parser
================================================
FILE: run_nerf_helpers.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm, trange
import os
import imageio
# Misc
img2mse = lambda x, y : torch.mean((x - y) ** 2)
img2se = lambda x, y : (x - y) ** 2
mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) # logab = logcb / logca
to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
to8b_tensor = lambda x : (255*torch.clip(x,0,1)).type(torch.int)
def imread(f):
if f.endswith('png'):
return imageio.imread(f, ignoregamma=True)
else:
return imageio.imread(f)
def load_imgs(path):
imgfiles = [os.path.join(path, f) for f in sorted(os.listdir(path)) if
f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]
imgs = [imread(f)[..., :3] / 255. for f in imgfiles]
imgs = np.stack(imgs, -1)
imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
imgs = imgs.astype(np.float32)
imgs = torch.tensor(imgs).cuda()
return imgs
# Ray helpers
def get_rays(H, W, K, c2w):
i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij'
i = i.t()
j = j.t()
dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = c2w[:3,-1].expand(rays_d.shape)
return rays_o, rays_d
def ndc_rays(H, W, focal, near, rays_o, rays_d):
# Shift ray origins to near plane
t = -(near + rays_o[...,2]) / rays_d[...,2]
rays_o = rays_o + t[...,None] * rays_d
# Projection
o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]
o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]
o2 = 1. + 2. * near / rays_o[...,2]
d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])
d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])
d2 = -2. * near / rays_o[...,2]
rays_o = torch.stack([o0,o1,o2], -1)
rays_d = torch.stack([d0,d1,d2], -1)
return rays_o, rays_d
# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
# Get pdf
weights = weights + 1e-5 # prevent nans
pdf = weights / torch.sum(weights, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins))
# Take uniform samples
if det:
u = torch.linspace(0., 1., steps=N_samples)
u = u.expand(list(cdf.shape[:-1]) + [N_samples])
else:
u = torch.rand(list(cdf.shape[:-1]) + [N_samples])
# Pytest, overwrite u with numpy's fixed random numbers
if pytest:
np.random.seed(0)
new_shape = list(cdf.shape[:-1]) + [N_samples]
if det:
u = np.linspace(0., 1., N_samples)
u = np.broadcast_to(u, new_shape)
else:
u = np.random.rand(*new_shape)
u = torch.Tensor(u)
# Invert CDF
u = u.contiguous()
inds = torch.searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds-1), inds-1)
above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
# cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
# bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = (cdf_g[...,1]-cdf_g[...,0])
denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
t = (u-cdf_g[...,0])/denom
samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
return samples
def render_video_test(i_, graph, render_poses, H, W, K, args):
rgbs = []
disps = []
# t = time.time()
for i, pose in enumerate(tqdm(render_poses)):
# print(i, time.time() - t)
# t = time.time()
pose = pose[None, :3, :4]
ret = graph.render_video(i_, pose[:3, :4], H, W, K, args)
rgbs.append(ret['rgb_map'].cpu().numpy())
disps.append(ret['disp_map'].cpu().numpy())
if i==0:
print(ret['rgb_map'].shape, ret['disp_map'].shape)
rgbs = np.stack(rgbs, 0)
disps = np.stack(disps, 0)
return rgbs, disps
to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
def render_image_test(i, graph, render_poses, H, W, K, args, novel_view=False, need_depth=False):
if novel_view:
img_dir = os.path.join(args.basedir, args.expname, 'img_novel_{:06d}'.format(i))
else:
img_dir = os.path.join(args.basedir, args.expname, 'img_test_{:06d}'.format(i))
os.makedirs(img_dir, exist_ok=True)
imgs =[]
for j, pose in enumerate(tqdm(render_poses)):
# print(i, time.time() - t)
# t = time.time()
pose = pose[None, :3, :4]
ret = graph.render_video(i, pose[:3, :4], H, W, K, args)
imgs.append(ret['rgb_map'])
rgbs = ret['rgb_map'].cpu().numpy()
rgb8 = to8b(rgbs)
imageio.imwrite(os.path.join(img_dir, 'rgb_{:03d}.png'.format(j)), rgb8)
if need_depth:
depths = ret['disp_map'].cpu().numpy()
depths_ = depths/np.max(depths)
depth8 = to8b(depths_)
imageio.imwrite(os.path.join(img_dir, 'depth_{:03d}.png'.format(j)), depth8)
imgs = torch.stack(imgs, 0)
return imgs
def init_weights(linear):
# use Xavier init instead of Kaiming init
torch.nn.init.kaiming_normal_(linear.weight)
torch.nn.init.zeros_(linear.bias)
def init_nerf(nerf):
for linear_pt in nerf.pts_linears:
init_weights(linear_pt)
for linear_view in nerf.views_linears:
init_weights(linear_view)
init_weights(nerf.feature_linear)
init_weights(nerf.alpha_linear)
init_weights(nerf.rgb_linear)
# Ray helpers only get specific rays
def get_specific_rays(i, j, K, c2w):
# i, j = torch.meshgrid(torch.linspace(0, W - 1, W),
# torch.linspace(0, H - 1, H)) # pytorch's meshgrid has indexing='ij'
# i = i.t()
# j = j.t()
dirs = torch.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -torch.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[..., :3, :3], -1)
# dot product, equals to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = c2w[..., :3, -1]
return rays_o, rays_d
def save_render_pose(poses, path):
poses_np = poses.cpu().detach().numpy()
N = poses_np.shape[0]
bottom = np.reshape([0., 0., 0., 1.], [1, 4])
bottom_all = np.expand_dims(bottom, 0).repeat(N, axis=0)
poses_Rt = np.concatenate([poses_np, bottom_all], 1)
poses_txt = os.path.join(path, 'poses_render.txt')
for j in range(poses_np.shape[0]):
poses_flat = poses_Rt[j].reshape(16, 1).squeeze()
for k in range(16):
with open(poses_txt, 'a') as outfile:
if k == 0:
outfile.write(f"pose{j} ")
if k != 15:
outfile.write(f"{poses_flat[k]} ")
if k == 15:
outfile.write(f"{poses_flat[k]}\n")
================================================
FILE: test.py
================================================
import torch
from nerf import *
import optimize_pose_linear, optimize_pose_cubic
import torchvision.transforms.functional as torchvision_F
import matplotlib.pyplot as plt
from metrics import compute_img_metric
import novel_view_test
def test():
parser = config_parser()
args = parser.parse_args()
print('spline numbers: ', args.deblur_images)
imgs_sharp_dir = os.path.join(args.datadir, 'images_test')
imgs_sharp = load_imgs(imgs_sharp_dir)
# Load data images and groundtruth
K = None
if args.dataset_type == 'llff':
images_all, poses_start, bds_start, render_poses = load_llff_data(args.datadir, pose_state=None,
factor=args.factor, recenter=True,
bd_factor=.75, spherify=args.spherify)
hwf = poses_start[0, :3, -1]
# split train/val/test
if args.novel_view:
i_test = torch.arange(0, images_all.shape[0], args.llffhold)
else:
i_test = torch.tensor([100]).long()
i_val = i_test
i_train = torch.Tensor([i for i in torch.arange(int(images_all.shape[0])) if
(i not in i_test and i not in i_val)]).long()
# train data
images = images_all[i_train]
# novel view data
if args.novel_view:
images_novel = images_all[i_test]
# gt data
imgs_sharp = imgs_sharp
# get poses
poses_end = poses_start
poses_start_se3 = SE3_to_se3_N(poses_start[:, :3, :4])
poses_end_se3 = poses_start_se3
poses_org = poses_start.repeat(args.deblur_images, 1, 1)
poses = poses_org[:, :, :4]
print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
print('DEFINING BOUNDS')
if args.no_ndc:
near = torch.min(bds_start) * .9
far = torch.max(bds_start) * 1.
else:
near = 0.
far = 1.
print('NEAR FAR', near, far)
else:
print('Unknown dataset type', args.dataset_type, 'exiting')
return
# Cast intrinsics to right types
H, W, focal = hwf
H, W = int(H), int(W)
hwf = [H, W, focal]
if K is None:
K = torch.Tensor([
[focal, 0, 0.5 * W],
[0, focal, 0.5 * H],
[0, 0, 1]
])
# Create log dir and copy the config file
basedir = args.basedir
expname = args.expname
test_metric_file = os.path.join(basedir, expname, 'test_metrics.txt')
test_metric_file_novel = os.path.join(basedir, expname, 'test_metrics_novel.txt')
# print_file = os.path.join(basedir, expname, 'print.txt')
os.makedirs(os.path.join(basedir, expname), exist_ok=True)
f = os.path.join(basedir, expname, 'args.txt')
with open(f, 'w') as file:
for arg in sorted(vars(args)):
attr = getattr(args, arg)
file.write('{} = {}\n'.format(arg, attr))
if args.config is not None:
f = os.path.join(basedir, expname, 'config.txt')
with open(f, 'w') as file:
file.write(open(args.config, 'r').read())
if args.linear:
print('Linear Spline Model Loading!')
model = optimize_pose_linear.Model(poses_start_se3, poses_end_se3)
else:
print('Cubic Spline Model Loading!')
model = optimize_pose_cubic.Model(poses_start_se3, poses_start_se3, poses_start_se3, poses_start_se3)
graph = model.build_network(args)
optimizer, optimizer_se3 = model.setup_optimizer(args)
path = os.path.join(basedir, expname, '{:06d}.tar'.format(args.weight_iter))
graph_ckpt = torch.load(path)
graph.load_state_dict(graph_ckpt['graph'])
optimizer.load_state_dict(graph_ckpt['optimizer'])
optimizer_se3.load_state_dict(graph_ckpt['optimizer_se3'])
global_step = graph_ckpt['global_step']
if args.deblur_images % 2 == 0:
all_poses = graph.get_pose_even(0, torch.arange(graph.se3.weight.shape[0]), args.deblur_images)
else:
all_poses = graph.get_pose(0, torch.arange(graph.se3.weight.shape[0]), args)
# Turn on testing mode
with torch.no_grad():
if args.deblur_images % 2 == 0:
i_render = torch.arange(i_train.shape[0]) * (args.deblur_images + 1) + args.deblur_images // 2
else:
i_render = torch.arange(i_train.shape[0]) * args.deblur_images + args.deblur_images // 2
imgs_render = render_image_test(0, graph, all_poses[i_render], H, W, K, args)
mse_render = compute_img_metric(imgs_sharp, imgs_render, 'mse')
psnr_render = compute_img_metric(imgs_sharp, imgs_render, 'psnr')
ssim_render = compute_img_metric(imgs_sharp, imgs_render, 'ssim')
lpips_render = compute_img_metric(imgs_sharp, imgs_render, 'lpips')
with open(test_metric_file, 'a') as outfile:
outfile.write(f"test: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}"
f" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\n")
# Turn on novel view testing mode
if args.novel_view:
i_ = torch.arange(0, images.shape[0], args.llffhold - 1)
poses_test_se3_ = graph.se3.weight[i_, :6]
model_test = novel_view_test.Model(poses_test_se3_, graph)
graph_test = model_test.build_network(args)
optimizer_test = model_test.setup_optimizer(args)
for j in range(args.N_novel_view):
ret_sharp, ray_idx_sharp, poses_sharp = graph_test.forward(0, 0, 0, H, W, K, args,
novel_view=True)
target_s_novel = images_novel.reshape(-1, H * W, 3)[:, ray_idx_sharp]
target_s_novel = target_s_novel.reshape(-1, 3)
loss_sharp = img2mse(ret_sharp['rgb_map'], target_s_novel)
psnr_sharp = mse2psnr(loss_sharp)
if 'rgb0' in ret_sharp:
img_loss0 = img2mse(ret_sharp['rgb0'], target_s_novel)
loss_sharp = loss_sharp + img_loss0
if j % 100 == 0:
print(psnr_sharp.item(), loss_sharp.item())
optimizer_test.zero_grad()
loss_sharp.backward()
optimizer_test.step()
decay_rate_sharp = 0.01
decay_steps_sharp = args.lrate_decay * 100
new_lrate_novel = args.pose_lrate * (decay_rate_sharp ** (j / decay_steps_sharp))
for param_group in optimizer_test.param_groups:
if (j / decay_steps_sharp) <= 1.:
param_group['lr'] = new_lrate_novel * args.factor_pose_novel
with torch.no_grad():
imgs_render_novel = render_image_test(0, graph, poses_sharp, H, W, K, args, novel_view=True)
mse_render = compute_img_metric(images_novel, imgs_render_novel, 'mse')
psnr_render = compute_img_metric(images_novel, imgs_render_novel, 'psnr')
ssim_render = compute_img_metric(images_novel, imgs_render_novel, 'ssim')
lpips_render = compute_img_metric(images_novel, imgs_render_novel, 'lpips')
with open(test_metric_file_novel, 'a') as outfile:
outfile.write(f"novel view test: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}"
f" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\n")
return 0
if __name__=='__main__':
torch.set_default_tensor_type('torch.cuda.FloatTensor')
test()
================================================
FILE: train.py
================================================
import torch
from nerf import *
import optimize_pose_linear, optimize_pose_cubic
import torchvision.transforms.functional as torchvision_F
import matplotlib.pyplot as plt
from metrics import compute_img_metric
import novel_view_test
def train():
parser = config_parser()
args = parser.parse_args()
print('spline numbers: ', args.deblur_images)
imgs_sharp_dir = os.path.join(args.datadir, 'images_test')
imgs_sharp = load_imgs(imgs_sharp_dir)
# Load data images and groundtruth
K = None
if args.dataset_type == 'llff':
images_all, poses_start, bds_start, render_poses = load_llff_data(args.datadir, pose_state=None,
factor=args.factor, recenter=True,
bd_factor=.75, spherify=args.spherify)
hwf = poses_start[0, :3, -1]
# split train/val/test
if args.novel_view:
i_test = torch.arange(0, images_all.shape[0], args.llffhold)
else:
i_test = torch.tensor([100]).long()
i_val = i_test
i_train = torch.Tensor([i for i in torch.arange(int(images_all.shape[0])) if
(i not in i_test and i not in i_val)]).long()
# train data
images = images_all[i_train]
# novel view data
if args.novel_view:
images_novel = images_all[i_test]
# gt data
imgs_sharp = imgs_sharp
# get poses
poses_end = poses_start
poses_start_se3 = SE3_to_se3_N(poses_start[:, :3, :4])
poses_end_se3 = poses_start_se3
poses_org = poses_start.repeat(args.deblur_images, 1, 1)
poses = poses_org[:, :, :4]
print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
print('DEFINING BOUNDS')
if args.no_ndc:
near = torch.min(bds_start) * .9
far = torch.max(bds_start) * 1.
else:
near = 0.
far = 1.
print('NEAR FAR', near, far)
else:
print('Unknown dataset type', args.dataset_type, 'exiting')
return
# Cast intrinsics to right types
H, W, focal = hwf
H, W = int(H), int(W)
hwf = [H, W, focal]
if K is None:
K = torch.Tensor([
[focal, 0, 0.5 * W],
[0, focal, 0.5 * H],
[0, 0, 1]
])
# Create log dir and copy the config file
basedir = args.basedir
expname = args.expname
test_metric_file = os.path.join(basedir, expname, 'test_metrics.txt')
test_metric_file_novel = os.path.join(basedir, expname, 'test_metrics_novel.txt')
print_file = os.path.join(basedir, expname, 'print.txt')
os.makedirs(os.path.join(basedir, expname), exist_ok=True)
f = os.path.join(basedir, expname, 'args.txt')
with open(f, 'w') as file:
for arg in sorted(vars(args)):
attr = getattr(args, arg)
file.write('{} = {}\n'.format(arg, attr))
if args.config is not None:
f = os.path.join(basedir, expname, 'config.txt')
with open(f, 'w') as file:
file.write(open(args.config, 'r').read())
if args.load_weights:
if args.linear:
print('Linear Spline Model Loading!')
model = optimize_pose_linear.Model(poses_start_se3, poses_end_se3)
else:
print('Cubic Spline Model Loading!')
model = optimize_pose_cubic.Model(poses_start_se3, poses_start_se3, poses_start_se3, poses_start_se3)
graph = model.build_network(args)
optimizer, optimizer_se3 = model.setup_optimizer(args)
path = os.path.join(basedir, expname, '{:06d}.tar'.format(args.weight_iter)) # here
graph_ckpt = torch.load(path)
graph.load_state_dict(graph_ckpt['graph'])
optimizer.load_state_dict(graph_ckpt['optimizer'])
optimizer_se3.load_state_dict(graph_ckpt['optimizer_se3'])
global_step = graph_ckpt['global_step']
else:
if args.linear:
low, high = 0.0001, 0.005
rand = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low
poses_start_se3 = poses_start_se3 + rand
model = optimize_pose_linear.Model(poses_start_se3, poses_end_se3)
else:
low, high = 0.0001, 0.01
rand1 = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low
rand2 = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low
rand3 = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low
poses_se3_1 = poses_start_se3 + rand1
poses_se3_2 = poses_start_se3 + rand2
poses_se3_3 = poses_start_se3 + rand3
model = optimize_pose_cubic.Model(poses_start_se3, poses_se3_1, poses_se3_2, poses_se3_3)
graph = model.build_network(args) # nerf, nerf_fine, forward
optimizer, optimizer_se3 = model.setup_optimizer(args)
N_iters = args.N_iters + 1
print('Begin')
print('TRAIN views are', i_train)
print('TEST views are', i_test)
print('VAL views are', i_val)
start = 0
if not args.load_weights:
global_step = start
global_step_ = global_step
threshold = N_iters + 1
poses_num = poses.shape[0]
for i in trange(start, threshold):
### core optimization loop ###
i = i+global_step_
if i == 0:
init_nerf(graph.nerf)
init_nerf(graph.nerf_fine)
img_idx = torch.randperm(images.shape[0])
if (i % args.i_img == 0 or i % args.i_novel_view == 0) and i > 0:
ret, ray_idx, spline_poses, all_poses = graph.forward(i, img_idx, poses_num, H, W, K, args)
else:
ret, ray_idx, spline_poses = graph.forward(i, img_idx, poses_num, H, W, K, args)
# get image ground truth
target_s = images[img_idx].reshape(-1, H * W, 3)
target_s = target_s[:, ray_idx]
target_s = target_s.reshape(-1, 3)
# average
shape0 = img_idx.shape[0]
interval = target_s.shape[0] // shape0
rgb_list = []
extras_list = []
rgb_ = 0
extras_ = 0
for j in range(0, shape0 * args.deblur_images):
rgb_ += ret['rgb_map'][j * interval:(j + 1) * interval]
if 'rgb0' in ret:
extras_ += ret['rgb0'][j * interval:(j + 1) * interval]
if (j + 1) % args.deblur_images == 0:
rgb_ = rgb_ / args.deblur_images
rgb_list.append(rgb_)
rgb_ = 0
if 'rgb0' in ret:
extras_ = extras_ / args.deblur_images
extras_list.append(extras_)
extras_ = 0
rgb_blur = torch.stack(rgb_list, 0)
rgb_blur = rgb_blur.reshape(-1, 3)
if 'rgb0' in ret:
extras_blur = torch.stack(extras_list, 0)
extras_blur = extras_blur.reshape(-1, 3)
# backward
optimizer_se3.zero_grad()
optimizer.zero_grad()
img_loss = img2mse(rgb_blur, target_s)
loss = img_loss
psnr = mse2psnr(img_loss)
if 'rgb0' in ret:
img_loss0 = img2mse(extras_blur, target_s)
loss = loss + img_loss0
psnr0 = mse2psnr(img_loss0)
loss.backward()
optimizer.step()
optimizer_se3.step()
# NOTE: IMPORTANT!
### update learning rate ###
decay_rate = 0.1
decay_steps = args.lrate_decay * 1000
new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
for param_group in optimizer.param_groups:
param_group['lr'] = new_lrate
decay_rate_pose = 0.01
new_lrate_pose = args.pose_lrate * (decay_rate_pose ** (global_step / decay_steps))
for param_group in optimizer_se3.param_groups:
param_group['lr'] = new_lrate_pose
###############################
if i%args.i_print==0:
tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} coarse_loss:, {img_loss0.item()}, PSNR: {psnr.item()}")
with open(print_file, 'a') as outfile:
outfile.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} coarse_loss:, {img_loss0.item()}, PSNR: {psnr.item()}\n")
if i < 10:
print('coarse_loss:', img_loss0.item())
with open(print_file, 'a') as outfile:
outfile.write(f"coarse loss: {img_loss0.item()}\n")
if i % args.i_weights == 0 and i > 0:
path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
torch.save({
'global_step': global_step,
'graph': graph.state_dict(),
'optimizer': optimizer.state_dict(),
'optimizer_se3': optimizer_se3.state_dict(),
}, path)
print('Saved checkpoints at', path)
if i % args.i_img == 0 and i > 0:
# Turn on testing mode
with torch.no_grad():
if args.deblur_images % 2 == 0:
i_render = torch.arange(i_train.shape[0]) * (args.deblur_images+1) + args.deblur_images // 2
else:
i_render = torch.arange(i_train.shape[0]) * args.deblur_images + args.deblur_images // 2
imgs_render = render_image_test(i, graph, all_poses[i_render], H, W, K, args)
mse_render = compute_img_metric(imgs_sharp, imgs_render, 'mse')
psnr_render = compute_img_metric(imgs_sharp, imgs_render, 'psnr')
ssim_render = compute_img_metric(imgs_sharp, imgs_render, 'ssim')
lpips_render = compute_img_metric(imgs_sharp, imgs_render, 'lpips')
with open(test_metric_file, 'a') as outfile:
outfile.write(f"iter{i}: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}"
f" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\n")
if i % args.i_video == 0 and i > 0:
# Turn on testing mode
with torch.no_grad():
rgbs, disps = render_video_test(i, graph, render_poses, H, W, K, args)
print('Done, saving', rgbs.shape, disps.shape)
moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)
if args.novel_view and i % args.i_novel_view == 0 and i > 0:
# Turn on novel view testing mode
i_ = torch.arange(0, images.shape[0], args.llffhold-1)
poses_test_se3_ = graph.se3.weight[i_,:6]
model_test = novel_view_test.Model(poses_test_se3_, graph)
graph_test = model_test.build_network(args)
optimizer_test = model_test.setup_optimizer(args)
for j in range(args.N_novel_view):
ret_sharp, ray_idx_sharp, poses_sharp = graph_test.forward(i, img_idx, poses_num, H, W, K, args, novel_view=True)
target_s_novel = images_novel.reshape(-1, H*W, 3)[:, ray_idx_sharp]
target_s_novel = target_s_novel.reshape(-1, 3)
loss_sharp = img2mse(ret_sharp['rgb_map'], target_s_novel)
psnr_sharp = mse2psnr(loss_sharp)
if 'rgb0' in ret_sharp:
img_loss0 = img2mse(ret_sharp['rgb0'], target_s_novel)
loss_sharp = loss_sharp + img_loss0
if j%100==0:
print(psnr_sharp.item(), loss_sharp.item())
optimizer_test.zero_grad()
loss_sharp.backward()
optimizer_test.step()
decay_rate_sharp = 0.01
decay_steps_sharp = args.lrate_decay * 100
new_lrate_novel = args.pose_lrate * (decay_rate_sharp ** (j / decay_steps_sharp))
for param_group in optimizer_test.param_groups:
if (j / decay_steps_sharp) <= 1.:
param_group['lr'] = new_lrate_novel * args.factor_pose_novel
with torch.no_grad():
imgs_render_novel = render_image_test(i, graph, poses_sharp, H, W, K, args, novel_view=True)
mse_render = compute_img_metric(images_novel, imgs_render_novel, 'mse')
psnr_render = compute_img_metric(images_novel, imgs_render_novel, 'psnr')
ssim_render = compute_img_metric(images_novel, imgs_render_novel, 'ssim')
lpips_render = compute_img_metric(images_novel, imgs_render_novel, 'lpips')
with open(test_metric_file_novel, 'a') as outfile:
outfile.write(f"iter{i}: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}"
f" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\n")
if i % args.N_iters == 0 and i > 0:
# Turn on testing mode
with torch.no_grad():
path_pose = os.path.join(basedir, expname)
i_render_pose = torch.arange(i_train.shape[0]) * args.deblur_images + args.deblur_images // 2
render_poses_final = all_poses[i_render_pose]
save_render_pose(render_poses_final, path_pose)
global_step += 1
if __name__=='__main__':
torch.set_default_tensor_type('torch.cuda.FloatTensor')
train()