Repository: WU-CVGL/BAD-NeRF Branch: main Commit: aed2c4a4b230 Files: 35 Total size: 113.4 KB Directory structure: gitextract_1dub1nd4/ ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── Spline.py ├── configs/ │ ├── cozy2room-cubic.txt │ ├── cozy2room.txt │ ├── factory-cubic.txt │ ├── factory.txt │ ├── pool-cubic.txt │ ├── pool.txt │ ├── roomdark-cubic.txt │ ├── roomdark.txt │ ├── roomhigh-cubic.txt │ ├── roomhigh.txt │ ├── roomlow-cubic.txt │ ├── roomlow.txt │ ├── tanabata-cubic.txt │ ├── tanabata.txt │ ├── wine-cubic.txt │ └── wine.txt ├── load_llff.py ├── lpips/ │ ├── __init__.py │ ├── lpips.py │ └── pretrained_networks.py ├── metrics.py ├── nerf.py ├── novel_view_test.py ├── optimize_pose_cubic.py ├── optimize_pose_linear.py ├── requirements.txt ├── run_nerf.py ├── run_nerf_helpers.py ├── test.py └── train.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .idea/ .vscode/ __pycache__/ logs*/ data/ configs_test/ weights/ ================================================ FILE: Dockerfile ================================================ FROM nvcr.io/nvidia/pytorch:23.02-py3 ARG DEBIAN_FRONTEND=noninteractive ENV TZ=Asia/Shanghai LANG=C.UTF-8 LC_ALL=C.UTF-8 PIP_NO_CACHE_DIR=1 PIP_CACHE_DIR=/tmp/ PYTHONUNBUFFERED=1 PYTHONFAULTHANDLER=1 PYTHONHASHSEED=0 RUN \ # uncomment to use apt mirror # sed -i "s/archive.ubuntu.com/mirrors.ustc.edu.cn/g" /etc/apt/sources.list &&\ # sed -i "s/security.ubuntu.com/mirrors.ustc.edu.cn/g" /etc/apt/sources.list &&\ rm -f /etc/apt/sources.list.d/* &&\ rm -rf /opt/hpcx/ &&\ apt-get update && apt-get upgrade -y &&\ apt-get install -y --no-install-recommends \ autoconf automake autotools-dev build-essential ca-certificates \ make cmake ninja-build pkg-config g++ ccache yasm openmpi-bin \ git curl wget unzip nano net-tools htop iotop \ cloc rsync xz-utils software-properties-common tzdata \ && apt-get purge -y unattended-upgrades \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* ADD requirements.txt /tmp RUN \ # uncomment to use pypi mirror # pip config set global.index-url https://mirrors.bfsu.edu.cn/pypi/web/simple &&\ pip install -U pip &&\ pip install -r /tmp/requirements.txt &&\ pip install "jupyterlab~=3.5.0" "jupyter-archive~=3.2" &&\ rm -rf /tmp/* ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2023 Peng Wang, Lingzhe Zhao, Ruijie Ma, Peidong Liu Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================

# 😈BAD-NeRF

This is an official PyTorch implementation of the paper [BAD-NeRF: Bundle Adjusted Deblur Neural Radiance Fields](https://arxiv.org/abs/2211.12853) (CVPR 2023).

Authors: [Peng Wang](https://github.com/wangpeng000), [Lingzhe Zhao](https://github.com/LingzheZhao), Ruijie Ma and [Peidong Liu](https://ethliup.github.io/).

BAD-NeRF jointly learns the 3D representation and optimizes the camera motion trajectories within the exposure time, given blurry images and inaccurate initial poses. See the [project page](https://wangpeng000.github.io/BAD-NeRF/) for more results.

## ✨News

📺 **[2023.12]** We were invited to give an online talk on Bilibili and WeChat, hosted by **计算机视觉life**. An archive of our live stream (in Chinese) is available on [Bilibili](https://www.bilibili.com/video/BV1Rb4y1G7Ek/).

⚡ **[2023.11]** Our **nerfstudio**-based implementation, [BAD-NeRFstudio](https://github.com/WU-CVGL/BAD-NeRFstudio), is now public. Now you can train a scene from blurry images in minutes!

## Novel View Synthesis
## Deblurring Result

![teaser](./doc/bad-nerf.jpg)

## Pose Estimation Result

![pose estimation](./doc/pose-estimation.jpg)

## Method overview

![method](./doc/overview.jpg)

We follow the real physical image formation process of a motion-blurred image to synthesize blurry images from NeRF. Both NeRF and the motion trajectories are estimated by maximizing the photometric consistency between the synthesized blurry images and the real blurry images.

## Quickstart

### 1. Setup environment

```
git clone https://github.com/WU-CVGL/BAD-NeRF
cd BAD-NeRF
pip install -r requirements.txt
```

### 2. Download datasets

You can download the data and weights [here](https://westlakeu-my.sharepoint.com/:f:/g/personal/cvgl_westlake_edu_cn/EsgdW2cRic5JqerhNbTsxtkBqy9m6cbnb2ugYZtvaib3qA?e=bjK7op). For the scenes from Deblur-NeRF (*cozy2room*, *factory*, etc.), the folder `images` contains only blurry images, while the folder `images_1` additionally contains novel-view images. For our own scenes (*room-low*, *room-high* and *dark*), there are no novel-view images. Note that the images in the *dark* scene have been undistorted, which leaves some useless pixels near the image borders; for this scene, you should uncomment the corresponding code in `Graph.forward` in `nerf.py`, which keeps sampled rays away from the borders.

### 3. Configs

Change the data path and other parameters (if needed) in `configs/cozy2room.txt`. We use the *cozy2room* scene as an example.

### 4. Demo with our pre-trained model

You can test our code and render sharp images with the provided weight files. To do this, first put the weight file under the corresponding logs folder `./logs/cozy2room-linear`, then set `load_weights = True` in `cozy2room.txt`, and finally run

```
python test.py --config configs/cozy2room.txt
```

### 5. Training

```
python train.py --config configs/cozy2room.txt
```

After training, you will get the deblurred images, optimized camera poses and synthesized novel-view images.

## Notes

### Camera poses

The poses (`poses_bounds.npy`) are generated by COLMAP from the blurry images alone (folder `images`).

### Spline model

We use linear interpolation as the default spline model in our experiments. You can simply set the parameter `linear` (all parameters can be changed in `configs/***.txt` or `run_nerf.py`) to `False` to use the higher-order spline model (i.e., cubic B-Spline).

### Virtual images

The parameter `deblur_images` controls how many virtual sharp images are synthesized within each exposure; use a smaller value for lightly blurred images and a bigger value for severely blurred ones.

### Learning rate

After the rebuttal, we found that the gradients can occasionally become NaN when the cubic B-Spline model is used with `pose_lrate = 1e-3`. We therefore set the initial pose learning rate to `1e-4`, which may even achieve better performance than reported in our paper. If the gradients become NaN in your experiments, simply kill the run and try again, or decrease `pose_lrate`.

## Your own data

`images`: This folder is used to estimate initial camera poses from blurry images. Put your own data (blurry images only) in the folder `images`, then run the `imgs2poses.py` script from the [LLFF code](https://github.com/fyusion/llff) to estimate camera poses and generate `poses_bounds.npy`.

`images_1`: This is the default training folder. It contains the same blurry images as the `images` folder plus, optionally, several sharp novel-view images. If you want to add novel-view (sharp) images, place them in `images_1` at an interval of `llffhold` (a parameter used for novel-view testing); see the sketch below.
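For intuition, here is a minimal sketch of how such an interval selects the novel-view (test) indices. It assumes the nerf-pytorch-style holdout convention that this codebase derives from (every `llffhold`-th image becomes a test view); the exact selection logic lives in the training script and may differ:

```python
# Sketch only (assumes nerf-pytorch-style LLFF holdout): split the images
# in `images_1` into novel-view (test) indices and blurry training indices.
import numpy as np

llffhold = 5
num_images = 13  # e.g. 10 blurry images plus 3 sharp novel-view images

all_idx = np.arange(num_images)
i_test = all_idx[::llffhold]                          # every llffhold-th image
i_train = np.array([i for i in all_idx if i not in i_test])

print("novel-view (test) indices:", i_test)           # -> [ 0  5 10]
print("training (blurry) indices:", i_train)
```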
Remember to set the parameter `novel_view` to `True` if `images_1` includes novel-view images. Otherwise, if there are no novel-view images, you can directly put the blurry images into `images_1` and set `novel_view` to `False`.

`images_test`: To compute deblurring metrics, this folder should in principle contain the ground-truth images. If you don't have ground truth, you can copy the blurry images from `images` into `images_test`; this is the easiest way to get the code running, but remember that the computed metrics are then meaningless.

```
#-----------------------------------------------------------------------------------------#
# images folder: img_blur_*.png is the blurry image.                                      #
#-----------------------------------------------------------------------------------------#
# images_1 folder: img_blur_*.png is the same as that in `images` and (optional)          #
# img_novel_*.png is the sharp novel view image.                                          #
#-----------------------------------------------------------------------------------------#
# images_test folder: img_test_*.png should be the ground truth image corresponding to    #
# img_blur_*.png to compute the PSNR metric. Of course, you can directly put              #
# img_blur_*.png there to run the code if you don't have gt images (then the metrics      #
# are wrong).                                                                             #
#-----------------------------------------------------------------------------------------#

images folder: (suppose 10 images)
img_blur_0.png
img_blur_1.png
.
.
.
img_blur_9.png

#-----------------------------------------------------------------------------------------#

images_1 folder: (suppose novel view images are placed with an `llffhold=5` interval.)
img_novel_0.png (optional)
img_blur_0.png
img_blur_1.png
.
img_blur_4.png
img_novel_1.png (optional)
img_blur_5.png
.
img_blur_9.png
img_novel_2.png (optional)

#-----------------------------------------------------------------------------------------#

images_test folder: (theoretically gt images, but can be other images)
img_test_0.png
img_test_1.png
.
.
.
img_test_9.png
```

## Citation

If you find this useful, please consider citing our paper:

```bibtex
@InProceedings{wang2023badnerf,
    author    = {Wang, Peng and Zhao, Lingzhe and Ma, Ruijie and Liu, Peidong},
    title     = {{BAD-NeRF: Bundle Adjusted Deblur Neural Radiance Fields}},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {4170-4179}
}
```

## Acknowledgment

The overall framework, the metrics computation and the camera transformations are derived from [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch/), [Deblur-NeRF](https://github.com/limacv/Deblur-NeRF) and [BARF](https://github.com/chenhsuanlin/bundle-adjusting-NeRF), respectively. We appreciate the effort of the contributors to these repositories.
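*Supplementary note on the spline model:* the sketch below illustrates the idea behind the default linear trajectory model: the translation is linearly interpolated and the rotation is spherically interpolated between the start and end poses of an exposure. It is a simplified, self-contained illustration, not the exact implementation; the actual `SplineN_linear` in `Spline.py` works on se(3) pose parameters with its own quaternion numerics.

```python
# Minimal sketch of linear pose interpolation within one exposure:
# translation lerp plus quaternion slerp. Illustrative only; not the
# exact implementation in Spline.py.
import torch

def slerp(q0: torch.Tensor, q1: torch.Tensor, u: float) -> torch.Tensor:
    """Spherical linear interpolation between unit quaternions (x, y, z, w)."""
    dot = (q0 * q1).sum(-1)
    if dot < 0:                            # take the shorter arc
        q1, dot = -q1, -dot
    theta = torch.arccos(dot.clamp(max=1.0))
    if theta < 1e-4:                       # nearly identical: fall back to lerp
        q = (1 - u) * q0 + u * q1
    else:
        q = (torch.sin((1 - u) * theta) * q0 + torch.sin(u * theta) * q1) / torch.sin(theta)
    return q / q.norm()

# Start/end camera pose of one exposure (rotation quaternion + translation)
q_start = torch.tensor([0.0, 0.0, 0.0, 1.0])        # identity rotation
q_end   = torch.tensor([0.0, 0.3827, 0.0, 0.9239])  # ~45 degrees about y
t_start = torch.tensor([0.0, 0.0, 0.0])
t_end   = torch.tensor([0.1, 0.0, 0.0])

deblur_images = 7  # number of virtual sharp poses within the exposure
for k in range(deblur_images):
    u = k / (deblur_images - 1)              # u in [0, 1] across the exposure
    q_u = slerp(q_start, q_end, u)           # interpolated rotation
    t_u = (1 - u) * t_start + u * t_end      # interpolated translation
    # (q_u, t_u) is one virtual camera pose; NeRF renders a sharp image at
    # each such pose, and averaging them synthesizes the blurry image.
```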
================================================ FILE: Spline.py ================================================ import torch import numpy as np delt = 0 def skew_symmetric(w): w0, w1, w2 = w.unbind(dim=-1) O = torch.zeros_like(w0) wx = torch.stack( [ torch.stack([O, -w2, w1], dim=-1), torch.stack([w2, O, -w0], dim=-1), torch.stack([-w1, w0, O], dim=-1), ], dim=-2, ) return wx def taylor_A(x, nth=10): # Taylor expansion of sin(x)/x ans = torch.zeros_like(x) denom = 1.0 for i in range(nth + 1): if i > 0: denom *= (2 * i) * (2 * i + 1) ans = ans + (-1) ** i * x ** (2 * i) / denom return ans def taylor_B(x, nth=10): # Taylor expansion of (1-cos(x))/x**2 ans = torch.zeros_like(x) denom = 1.0 for i in range(nth + 1): denom *= (2 * i + 1) * (2 * i + 2) ans = ans + (-1) ** i * x ** (2 * i) / denom return ans def taylor_C(x, nth=10): # Taylor expansion of (x-sin(x))/x**3 ans = torch.zeros_like(x) denom = 1.0 for i in range(nth + 1): denom *= (2 * i + 2) * (2 * i + 3) ans = ans + (-1) ** i * x ** (2 * i) / denom return ans def exp_r2q_parallel(r, eps=1e-9): x, y, z = r[..., 0], r[..., 1], r[..., 2] theta = 0.5 * torch.sqrt(x**2 + y**2 + z**2) bool_criterion = (theta < eps).unsqueeze(-1).repeat(1, 1, 4) return torch.where( bool_criterion, exp_r2q_taylor(x, y, z, theta), exp_r2q(x, y, z, theta) ) def exp_r2q(x, y, z, theta): lambda_ = torch.sin(theta) / (2.0 * theta) qx = lambda_ * x qy = lambda_ * y qz = lambda_ * z qw = torch.cos(theta) return torch.stack([qx, qy, qz, qw], -1) def exp_r2q_taylor(x, y, z, theta): qx = (1.0 / 2.0 - 1.0 / 12.0 * theta**2 - 1.0 / 240.0 * theta**4) * x qy = (1.0 / 2.0 - 1.0 / 12.0 * theta**2 - 1.0 / 240.0 * theta**4) * y qz = (1.0 / 2.0 - 1.0 / 12.0 * theta**2 - 1.0 / 240.0 * theta**4) * z qw = 1.0 - 1.0 / 2.0 * theta**2 + 1.0 / 24.0 * theta**4 return torch.stack([qx, qy, qz, qw], -1) def q_to_R_parallel(q): qb, qc, qd, qa = q.unbind(dim=-1) R = torch.stack( [ torch.stack( [ 1 - 2 * (qc**2 + qd**2), 2 * (qb * qc - qa * qd), 2 * (qa * qc + qb * qd), ], dim=-1, ), torch.stack( [ 2 * (qb * qc + qa * qd), 1 - 2 * (qb**2 + qd**2), 2 * (qc * qd - qa * qb), ], dim=-1, ), torch.stack( [ 2 * (qb * qd - qa * qc), 2 * (qa * qb + qc * qd), 1 - 2 * (qb**2 + qc**2), ], dim=-1, ), ], dim=-2, ) return R def q_to_Q_parallel(q): x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3] Q_0 = torch.stack([w, -z, y, x], -1).unsqueeze(-2) Q_1 = torch.stack([z, w, -x, y], -1).unsqueeze(-2) Q_2 = torch.stack([-y, x, w, z], -1).unsqueeze(-2) Q_3 = torch.stack([-x, -y, -z, w], -1).unsqueeze(-2) Q_ = torch.cat([Q_0, Q_1, Q_2, Q_3], -2) return Q_ def q_to_q_conj_parallel(q): x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3] q_conj_ = torch.stack([-x, -y, -z, w], -1) return q_conj_ def log_q2r_parallel(q, eps_theta=1e-20, eps_w=1e-10): x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3] theta = torch.sqrt(x**2 + y**2 + z**2) bool_theta_0 = theta < eps_theta bool_w_0 = torch.abs(w) < eps_w bool_w_0_left = torch.logical_and(bool_w_0, w < 0) lambda_ = torch.where( bool_w_0, torch.where( bool_w_0_left, log_q2r_lim_w_0_left(theta), log_q2r_lim_w_0_right(theta) ), torch.where( bool_theta_0, log_q2r_taylor_theta_0(w, theta), log_q2r(w, theta) ), ) r_ = torch.stack([lambda_ * x, lambda_ * y, lambda_ * z], -1) return r_ def log_q2r(w, theta): return 2.0 * (torch.arctan(theta / w)) / theta def log_q2r_taylor_theta_0(w, theta): return 2.0 / w - 2.0 / 3.0 * (theta**2) / (w * w * w) def log_q2r_lim_w_0_left(theta): return -torch.pi / theta def log_q2r_lim_w_0_right(theta): return torch.pi / theta def 
SE3_to_se3(Rt, eps=1e-8): # [...,3,4] R, t = Rt.split([3, 1], dim=-1) w = SO3_to_so3(R) wx = skew_symmetric(w) theta = w.norm(dim=-1)[..., None, None] I = torch.eye(3, device=w.device, dtype=torch.float32) A = taylor_A(theta) B = taylor_B(theta) invV = I - 0.5 * wx + (1 - A / (2 * B)) / (theta**2 + eps) * wx @ wx u = (invV @ t)[..., 0] wu = torch.cat([w, u], dim=-1) return wu def SO3_to_so3(R, eps=1e-7): # [...,3,3] trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2] theta = ((trace - 1) / 2).clamp(-1 + eps, 1 - eps).acos_()[ ..., None, None ] % np.pi # ln(R) will explode if theta==pi lnR = ( 1 / (2 * taylor_A(theta) + 1e-8) * (R - R.transpose(-2, -1)) ) # FIXME: wei-chiu finds it weird w0, w1, w2 = lnR[..., 2, 1], lnR[..., 0, 2], lnR[..., 1, 0] w = torch.stack([w0, w1, w2], dim=-1) return w def se3_to_SE3(wu): # [...,3] w, u = wu.split([3, 3], dim=-1) wx = skew_symmetric(w) # wx=[0 -w(2) w(1);w(2) 0 -w(0);-w(1) w(0) 0] theta = w.norm(dim=-1)[..., None, None] # theta=sqrt(w'*w) I = torch.eye(3, device=w.device, dtype=torch.float32) A = taylor_A(theta) B = taylor_B(theta) C = taylor_C(theta) R = I + A * wx + B * wx @ wx V = I + B * wx + C * wx @ wx Rt = torch.cat([R, (V @ u[..., None])], dim=-1) return Rt def SE3_to_se3_N(poses_rt): poses_se3_list = [] for i in range(poses_rt.shape[0]): pose_se3 = SE3_to_se3(poses_rt[i]) poses_se3_list.append(pose_se3) poses = torch.stack(poses_se3_list, 0) return poses def se3_to_SE3_N(poses_wu): poses_se3_list = [] for i in range(poses_wu.shape[0]): pose_se3 = se3_to_SE3(poses_wu[i]) poses_se3_list.append(pose_se3) poses = torch.stack(poses_se3_list, 0) return poses def se3_2_qt_parallel(wu): w, u = wu.split([3, 3], dim=-1) wx = skew_symmetric(w) theta = w.norm(dim=-1)[..., None, None] I = torch.eye(3, device=w.device, dtype=torch.float32) # A = taylor_A(theta) B = taylor_B(theta) C = taylor_C(theta) # R = I + A * wx + B * wx @ wx V = I + B * wx + C * wx @ wx t = V @ u[..., None] q = exp_r2q_parallel(w) return q, t.squeeze(-1) def SplineN_linear(start_pose, end_pose, poses_number, NUM, device=None): pose_time = poses_number / (NUM - 1) # parallel pos_0 = torch.where(pose_time == 0) pose_time[pos_0] = pose_time[pos_0] + 0.000001 pos_1 = torch.where(pose_time == 1) pose_time[pos_1] = pose_time[pos_1] - 0.000001 q_start, t_start = se3_2_qt_parallel(start_pose) q_end, t_end = se3_2_qt_parallel(end_pose) # sample t_vector t_t = (1 - pose_time)[..., None] * t_start + pose_time[..., None] * t_end # sample rotation_vector q_tau_0 = q_to_Q_parallel(q_to_q_conj_parallel(q_start)) @ q_end[..., None] r = pose_time[..., None] * log_q2r_parallel(q_tau_0.squeeze(-1)) q_t_0 = exp_r2q_parallel(r) q_t = q_to_Q_parallel(q_start) @ q_t_0[..., None] # convert q&t to RT R = q_to_R_parallel(q_t.squeeze(dim=-1)) t = t_t.unsqueeze(dim=-1) pose_spline = torch.cat([R, t], -1) poses = pose_spline.reshape([-1, 3, 4]) return poses def SplineN_cubic(pose0, pose1, pose2, pose3, poses_number, NUM): sample_time = poses_number / (NUM - 1) # parallel pos_0 = torch.where(sample_time == 0) sample_time[pos_0] = sample_time[pos_0] + 0.000001 pos_1 = torch.where(sample_time == 1) sample_time[pos_1] = sample_time[pos_1] - 0.000001 sample_time = sample_time.unsqueeze(-1) q0, t0 = se3_2_qt_parallel(pose0) q1, t1 = se3_2_qt_parallel(pose1) q2, t2 = se3_2_qt_parallel(pose2) q3, t3 = se3_2_qt_parallel(pose3) u = sample_time uu = sample_time**2 uuu = sample_time**3 one_over_six = 1.0 / 6.0 half_one = 0.5 # t coeff0 = one_over_six - half_one * u + half_one * uu - one_over_six * uuu coeff1 = 4 * 
one_over_six - uu + half_one * uuu coeff2 = one_over_six + half_one * u + half_one * uu - half_one * uuu coeff3 = one_over_six * uuu # spline t t_t = coeff0 * t0 + coeff1 * t1 + coeff2 * t2 + coeff3 * t3 # R coeff1_r = 5 * one_over_six + half_one * u - half_one * uu + one_over_six * uuu coeff2_r = one_over_six + half_one * u + half_one * uu - 2 * one_over_six * uuu coeff3_r = one_over_six * uuu # spline R q_01 = q_to_Q_parallel(q_to_q_conj_parallel(q0)) @ q1[..., None] # [1] q_12 = q_to_Q_parallel(q_to_q_conj_parallel(q1)) @ q2[..., None] # [2] q_23 = q_to_Q_parallel(q_to_q_conj_parallel(q2)) @ q3[..., None] # [3] r_01 = log_q2r_parallel(q_01.squeeze(-1)) * coeff1_r # [4] r_12 = log_q2r_parallel(q_12.squeeze(-1)) * coeff2_r # [5] r_23 = log_q2r_parallel(q_23.squeeze(-1)) * coeff3_r # [6] q_t_0 = exp_r2q_parallel(r_01) # [7] q_t_1 = exp_r2q_parallel(r_12) # [8] q_t_2 = exp_r2q_parallel(r_23) # [9] q_product1 = q_to_Q_parallel(q_t_1) @ q_t_2[..., None] # [10] q_product2 = q_to_Q_parallel(q_t_0) @ q_product1 # [10] q_t = q_to_Q_parallel(q0) @ q_product2 # [10] R = q_to_R_parallel(q_t.squeeze(-1)) t = t_t.unsqueeze(dim=-1) pose_spline = torch.cat([R, t], -1) poses = pose_spline.reshape([-1, 3, 4]) return poses ================================================ FILE: configs/cozy2room-cubic.txt ================================================ expname = cozy2room-cubic basedir = ./logs datadir = ./data/nerf_llff_data/blurcozy2room dataset_type = llff factor = 1 linear = False pose_lrate = 1e-4 novel_view = True factor_pose_novel = 20.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: configs/cozy2room.txt ================================================ expname = cozy2room-linear basedir = ./logs datadir = ./data/nerf_llff_data/blurcozy2room dataset_type = llff factor = 1 linear = True novel_view = True factor_pose_novel = 2.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: configs/factory-cubic.txt ================================================ expname = factory-cubic basedir = ./logs datadir = ./data/nerf_llff_data/blurfactory dataset_type = llff factor = 1 linear = False pose_lrate = 1e-4 novel_view = True factor_pose_novel = 200.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: configs/factory.txt ================================================ expname = factory-linear basedir = ./logs datadir = ./data/nerf_llff_data/blurfactory dataset_type = llff factor = 1 linear = True novel_view = True factor_pose_novel = 20.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: configs/pool-cubic.txt ================================================ expname = pool-cubic basedir = ./logs datadir = 
./data/nerf_llff_data/blurpool dataset_type = llff factor = 1 linear = False pose_lrate = 1e-4 novel_view = True factor_pose_novel = 20.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: configs/pool.txt ================================================ expname = pool-linear basedir = ./logs datadir = ./data/nerf_llff_data/blurpool dataset_type = llff factor = 1 linear = True novel_view = True factor_pose_novel = 2.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: configs/roomdark-cubic.txt ================================================ expname = roomdark-cubic basedir = ./logs datadir = ./data/nerf_llff_data/dark dataset_type = llff factor = 1 linear = False pose_lrate = 1e-4 novel_view = False factor_pose_novel = 2.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: configs/roomdark.txt ================================================ expname = roomdark-linear basedir = ./logs datadir = ./data/nerf_llff_data/dark dataset_type = llff factor = 1 linear = True novel_view = False factor_pose_novel = 2.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: configs/roomhigh-cubic.txt ================================================ expname = roomhigh-cubic basedir = ./logs datadir = ./data/nerf_llff_data/roomblur_high dataset_type = llff factor = 1 linear = False pose_lrate = 1e-4 novel_view = False factor_pose_novel = 2.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: configs/roomhigh.txt ================================================ expname = roomhigh-linear basedir = ./logs datadir = ./data/nerf_llff_data/roomblur_high dataset_type = llff factor = 1 linear = True novel_view = False factor_pose_novel = 2.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: configs/roomlow-cubic.txt ================================================ expname = roomlow-cubic basedir = ./logs datadir = ./data/nerf_llff_data/roomblur_low dataset_type = llff factor = 1 linear = False pose_lrate = 1e-4 novel_view = False factor_pose_novel = 2.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 
================================================ FILE: configs/roomlow.txt ================================================ expname = roomlow-linear basedir = ./logs datadir = ./data/nerf_llff_data/roomblur_low dataset_type = llff factor = 1 linear = True novel_view = False factor_pose_novel = 2.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: configs/tanabata-cubic.txt ================================================ expname = tanabata-cubic basedir = ./logs datadir = ./data/nerf_llff_data/blurtanabata dataset_type = llff factor = 1 linear = False pose_lrate = 1e-4 novel_view = True factor_pose_novel = 50.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: configs/tanabata.txt ================================================ expname = tanabata-linear basedir = ./logs datadir = ./data/nerf_llff_data/blurtanabata dataset_type = llff factor = 1 linear = True novel_view = True factor_pose_novel = 5.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: configs/wine-cubic.txt ================================================ expname = wine-cubic basedir = ./logs datadir = ./data/nerf_llff_data/blurwine dataset_type = llff factor = 1 linear = False pose_lrate = 1e-4 novel_view = True factor_pose_novel = 20.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: configs/wine.txt ================================================ expname = wine-linear basedir = ./logs datadir = ./data/nerf_llff_data/blurwine dataset_type = llff factor = 1 linear = True novel_view = True factor_pose_novel = 2.0 i_novel_view = 200000 N_rand = 5000 deblur_images = 7 N_samples = 64 N_importance = 64 use_viewdirs = True raw_noise_std = 1.0 load_weights = False weight_iter = 200000 i_img = 25000 i_video = 200000 i_weights = 10000 ================================================ FILE: load_llff.py ================================================ import numpy as np import os, imageio import torch ########## Slightly modified version of LLFF data loading code ########## see https://github.com/Fyusion/LLFF for original # downsample def _minify(basedir, factors=[], resolutions=[]): # basedir: ./data/nerf_llff_data/fern needtoload = False for r in factors: imgdir = os.path.join(basedir, 'images_{}'.format(r)) if not os.path.exists(imgdir): needtoload = True for r in resolutions: imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) if not os.path.exists(imgdir): needtoload = True if not needtoload: return from shutil import copy from subprocess import check_output imgdir = os.path.join(basedir, 'images') imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 
'png', 'jpeg', 'PNG']])] imgdir_orig = imgdir wd = os.getcwd() for r in factors + resolutions: if isinstance(r, int): name = 'images_{}'.format(r) resizearg = '{}%'.format(100. / r) else: name = 'images_{}x{}'.format(r[1], r[0]) resizearg = '{}x{}'.format(r[1], r[0]) imgdir = os.path.join(basedir, name) if os.path.exists(imgdir): continue print('Minifying', r, basedir) os.makedirs(imgdir) check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True) ext = imgs[0].split('.')[-1] args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)]) print(args) os.chdir(imgdir) check_output(args, shell=True) os.chdir(wd) if ext != 'png': check_output('rm {}/*.{}'.format(imgdir, ext), shell=True) print('Removed duplicates') print('Done') def _load_data(basedir, pose_state, factor=None, width=None, height=None, load_imgs=True): poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy')) poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0]) bds = poses_arr[:, -2:].transpose([1, 0]) img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \ if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0] sh = imageio.imread(img0).shape sfx = '' if factor is not None: sfx = '_{}'.format(factor) _minify(basedir, factors=[factor]) factor = factor elif height is not None: factor = sh[0] / float(height) width = int(sh[1] / factor) _minify(basedir, resolutions=[[height, width]]) sfx = '_{}x{}'.format(width, height) elif width is not None: factor = sh[1] / float(width) height = int(sh[0] / factor) _minify(basedir, resolutions=[[height, width]]) sfx = '_{}x{}'.format(width, height) else: factor = 1 imgdir = os.path.join(basedir, 'images' + sfx) if not os.path.exists(imgdir): print(imgdir, 'does not exist, returning') return imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] if poses.shape[-1] != len(imgfiles): print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) ) # return sh = imageio.imread(imgfiles[0]).shape poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) poses[2, 4, :] = poses[2, 4, :] * 1. / factor if not load_imgs: return poses, bds def imread(f): if f.endswith('png'): return imageio.imread(f, ignoregamma=True) else: return imageio.imread(f) imgs = [imread(f)[..., :3] / 255. for f in imgfiles] imgs = np.stack(imgs, -1) print('Loaded image data', imgs.shape, poses[:, -1, 0]) return poses, bds, imgs def normalize(x): return x / np.linalg.norm(x) def viewmatrix(z, up, pos): vec2 = normalize(z) vec1_avg = up vec0 = normalize(np.cross(vec1_avg, vec2)) vec1 = normalize(np.cross(vec2, vec0)) m = np.stack([vec0, vec1, vec2, pos], 1) return m def ptstocam(pts, c2w): tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0] return tt def poses_avg(poses): hwf = poses[0, :3, -1:] center = poses[:, :3, 3].mean(0) vec2 = normalize(poses[:, :3, 2].sum(0)) up = poses[:, :3, 1].sum(0) c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) return c2w def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): render_poses = [] rads = np.array(list(rads) + [1.]) hwf = c2w[:, 4:5] for theta in np.linspace(0., 2. 
* np.pi * rots, N + 1)[:-1]: c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads) z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.]))) render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) return render_poses def recenter_poses(poses): poses_ = poses + 0 bottom = np.reshape([0, 0, 0, 1.], [1, 4]) c2w = poses_avg(poses) c2w = np.concatenate([c2w[:3, :4], bottom], -2) bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1]) poses = np.concatenate([poses[:, :3, :4], bottom], -2) poses = np.linalg.inv(c2w) @ poses poses_[:, :3, :4] = poses[:, :3, :4] poses = poses_ return poses ##################### def spherify_poses(poses, bds): p34_to_44 = lambda p: np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1) rays_d = poses[:, :3, 2:3] rays_o = poses[:, :3, 3:4] def min_line_dist(rays_o, rays_d): A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1]) b_i = -A_i @ rays_o pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0)) return pt_mindist pt_mindist = min_line_dist(rays_o, rays_d) center = pt_mindist up = (poses[:, :3, 3] - center).mean(0) vec0 = normalize(up) vec1 = normalize(np.cross([.1, .2, .3], vec0)) vec2 = normalize(np.cross(vec0, vec1)) pos = center c2w = np.stack([vec1, vec2, vec0, pos], 1) poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4]) rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1))) sc = 1. / rad poses_reset[:, :3, 3] *= sc bds *= sc rad *= sc centroid = np.mean(poses_reset[:, :3, 3], 0) zh = centroid[2] radcircle = np.sqrt(rad ** 2 - zh ** 2) new_poses = [] for th in np.linspace(0., 2. * np.pi, 120): camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) up = np.array([0, 0, -1.]) vec2 = normalize(camorigin) vec0 = normalize(np.cross(vec2, up)) vec1 = normalize(np.cross(vec2, vec0)) pos = camorigin p = np.stack([vec0, vec1, vec2, pos], 1) new_poses.append(p) new_poses = np.stack(new_poses, 0) new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1) poses_reset = np.concatenate( [poses_reset[:, :3, :4], np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)], -1) return poses_reset, new_poses, bds def load_llff_data(basedir, pose_state=None, factor=1, recenter=True, bd_factor=.75, spherify=False, path_zflat=False): poses, bds, imgs = _load_data(basedir, pose_state=pose_state,factor=factor) print('Loaded', basedir, bds.min(), bds.max(), 'Pose State: ', pose_state) poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) poses = np.moveaxis(poses, -1, 0).astype(np.float32) imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) bds = np.moveaxis(bds, -1, 0).astype(np.float32) # Rescale if bd_factor is provided sc = 1. if bd_factor is None else 1. / (bds.min() * bd_factor) poses[:, :3, 3] *= sc # T bds *= sc if recenter: poses = recenter_poses(poses) if spherify: poses, render_poses, bds = spherify_poses(poses, bds) else: c2w = poses_avg(poses) up = normalize(poses[:, :3, 1].sum(0)) # Find a reasonable "focus depth" for this dataset close_depth, inf_depth = bds.min() * .9, bds.max() * 5. dt = .75 mean_dz = 1. / (((1. 
- dt) / close_depth + dt / inf_depth)) focal = mean_dz # Get radii for spiral path shrink_factor = .8 zdelta = close_depth * .2 tt = poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T rads = np.percentile(np.abs(tt), 90, 0) c2w_path = c2w N_views = 120 N_rots = 2 if path_zflat: # zloc = np.percentile(tt, 10, 0)[2] zloc = -close_depth * .1 c2w_path[:3, 3] = c2w_path[:3, 3] + zloc * c2w_path[:3, 2] rads[2] = 0. N_rots = 1 N_views /= 2 # Generate poses for spiral path render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views) render_poses = np.array(render_poses).astype(np.float32) imgs = torch.Tensor(imgs) poses = torch.Tensor(poses) bds = torch.Tensor(bds) render_poses = torch.Tensor(render_poses) return imgs, poses, bds, render_poses ================================================ FILE: lpips/__init__.py ================================================ from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np import torch # from torch.autograd import Variable from .lpips import * ================================================ FILE: lpips/lpips.py ================================================ from __future__ import absolute_import import torch import torch.nn as nn import torch.nn.init as init from torch.autograd import Variable import numpy as np from . import pretrained_networks as pn import torch.nn def normalize_tensor(in_feat, eps=1e-10): norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) return in_feat / (norm_factor + eps) def l2(p0, p1, range=255.): return .5 * np.mean((p0 / range - p1 / range) ** 2) def psnr(p0, p1, peak=255.): return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2)) def dssim(p0, p1, range=255.): from skimage.measure import compare_ssim return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. def rgb2lab(in_img, mean_cent=False): from skimage import color img_lab = color.rgb2lab(in_img) if mean_cent: img_lab[:, :, 0] = img_lab[:, :, 0] - 50 return img_lab def tensor2np(tensor_obj): # change dimension of a tensor object into a numpy array return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) def np2tensor(np_obj): # change dimenion of np array into tensor array return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): # image tensor to lab tensor from skimage import color img = tensor2im(image_tensor) img_lab = color.rgb2lab(img) if mc_only: img_lab[:, :, 0] = img_lab[:, :, 0] - 50 if to_norm and not mc_only: img_lab[:, :, 0] = img_lab[:, :, 0] - 50 img_lab = img_lab / 100. return np2tensor(img_lab) def tensorlab2tensor(lab_tensor, return_inbnd=False): from skimage import color import warnings warnings.filterwarnings("ignore") lab = tensor2np(lab_tensor) * 100. lab[:, :, 0] = lab[:, :, 0] + 50 rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1) if return_inbnd: # convert back to lab, see if we match lab_back = color.rgb2lab(rgb_back.astype('uint8')) mask = 1. * np.isclose(lab_back, lab, atol=2.) mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) return im2tensor(rgb_back), mask else: return im2tensor(rgb_back) def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. 
/ 2.): image_numpy = image_tensor[0].cpu().float().numpy() image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor return image_numpy.astype(imtype) def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.): return torch.tensor((image / factor - cent) [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) def tensor2vec(vector_tensor): return vector_tensor.data.cpu().numpy()[:, :, 0, 0] def voc_ap(rec, prec, use_07_metric=False): """ ap = voc_ap(rec, prec, [use_07_metric]) Compute VOC AP given precision and recall. If use_07_metric is true, uses the VOC 07 11 point method (default:False). """ if use_07_metric: # 11 point metric ap = 0. for t in np.arange(0., 1.1, 0.1): if np.sum(rec >= t) == 0: p = 0 else: p = np.max(prec[rec >= t]) ap = ap + p / 11. else: # correct AP calculation # first append sentinel values at the end mrec = np.concatenate(([0.], rec, [1.])) mpre = np.concatenate(([0.], prec, [0.])) # compute the precision envelope for i in range(mpre.size - 1, 0, -1): mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) # to calculate area under PR curve, look for points # where X axis (recall) changes value i = np.where(mrec[1:] != mrec[:-1])[0] # and sum (\Delta recall) * prec ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) return ap def spatial_average(in_tens, keepdim=True): return in_tens.mean([2, 3], keepdim=keepdim) def upsample(in_tens, out_HW=(64, 64)): # assumes scale factor is same for H and W in_H, in_W = in_tens.shape[2], in_tens.shape[3] return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens) # Learned perceptual metric class LPIPS(nn.Module): def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True): # lpips - [True] means with linear calibration on top of base network # pretrained - [True] means load linear weights super(LPIPS, self).__init__() if verbose: print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' % ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) self.pnet_type = net self.pnet_tune = pnet_tune self.pnet_rand = pnet_rand self.spatial = spatial self.lpips = lpips # false means baseline of just averaging all layers self.version = version self.scaling_layer = ScalingLayer() if self.pnet_type in ['vgg', 'vgg16']: net_type = pn.vgg16 self.chns = [64, 128, 256, 512, 512] elif self.pnet_type == 'alex': net_type = pn.alexnet self.chns = [64, 192, 384, 256, 256] elif self.pnet_type == 'squeeze': net_type = pn.squeezenet self.chns = [64, 128, 256, 384, 384, 512, 512] self.L = len(self.chns) self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) if lpips: self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] if self.pnet_type == 'squeeze': # 7 layers for squeezenet self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) self.lins += [self.lin5, self.lin6] self.lins = nn.ModuleList(self.lins) if pretrained: if model_path is None: import inspect import os model_path = os.path.abspath( os.path.join(inspect.getfile(self.__init__), '..', 
'weights/v%s/%s.pth' % (version, net))) if verbose: print('Loading model from: %s' % model_path) self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) if eval_mode: self.eval() def forward(self, in0, in1, retPerLayer=False, normalize=False): if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] in0 = 2 * in0 - 1 in1 = 2 * in1 - 1 # v0.0 - original release had a bug, where input was not scaled in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else ( in0, in1) outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) feats0, feats1, diffs = {}, {}, {} for kk in range(self.L): feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 if self.lpips: if self.spatial: res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] else: res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] else: if self.spatial: res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] else: res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] val = res[0] for l in range(1, self.L): val += res[l] # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True) # b = torch.max(self.lins[kk](feats0[kk]**2)) # for kk in range(self.L): # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True) # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2))) # a = a/self.L # from IPython import embed # embed() # return 10*torch.log10(b/a) if retPerLayer: return val, res else: return val class ScalingLayer(nn.Module): def __init__(self): super(ScalingLayer, self).__init__() self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) def forward(self, inp): return (inp - self.shift) / self.scale class NetLinLayer(nn.Module): ''' A single linear layer which does a 1x1 conv ''' def __init__(self, chn_in, chn_out=1, use_dropout=False): super(NetLinLayer, self).__init__() layers = [nn.Dropout(), ] if use_dropout else [] layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x) class Dist2LogitLayer(nn.Module): ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' def __init__(self, chn_mid=32, use_sigmoid=True): super(Dist2LogitLayer, self).__init__() layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ] layers += [nn.LeakyReLU(0.2, True), ] layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ] layers += [nn.LeakyReLU(0.2, True), ] layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ] if use_sigmoid: layers += [nn.Sigmoid(), ] self.model = nn.Sequential(*layers) def forward(self, d0, d1, eps=0.1): return self.model.forward(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)) class BCERankingLoss(nn.Module): def __init__(self, chn_mid=32): super(BCERankingLoss, self).__init__() self.net = Dist2LogitLayer(chn_mid=chn_mid) # self.parameters = list(self.net.parameters()) self.loss = torch.nn.BCELoss() def forward(self, d0, d1, judge): per = (judge + 1.) / 2. 
self.logit = self.net.forward(d0, d1) return self.loss(self.logit, per) # L2, DSSIM metrics class FakeNet(nn.Module): def __init__(self, use_gpu=True, colorspace='Lab'): super(FakeNet, self).__init__() self.use_gpu = use_gpu self.colorspace = colorspace class L2(FakeNet): def forward(self, in0, in1, retPerLayer=None): assert (in0.size()[0] == 1) # currently only supports batchSize 1 if self.colorspace == 'RGB': (N, C, X, Y) = in0.size() value = torch.mean(torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y), dim=3).view(N) return value elif self.colorspace == 'Lab': value = l2(tensor2np(tensor2tensorlab(in0.data, to_norm=False)), tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype( 'float') ret_var = Variable(torch.Tensor((value,))) if self.use_gpu: ret_var = ret_var.cuda() return ret_var class DSSIM(FakeNet): def forward(self, in0, in1, retPerLayer=None): assert (in0.size()[0] == 1) # currently only supports batchSize 1 if self.colorspace == 'RGB': value = dssim(1. * tensor2im(in0.data), 1. * tensor2im(in1.data), range=255.).astype( 'float') elif self.colorspace == 'Lab': value = dssim(tensor2np(tensor2tensorlab(in0.data, to_norm=False)), tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype( 'float') ret_var = Variable(torch.Tensor((value,))) if self.use_gpu: ret_var = ret_var.cuda() return ret_var def print_network(net): num_params = 0 for param in net.parameters(): num_params += param.numel() print('Network', net) print('Total number of parameters: %d' % num_params) ================================================ FILE: lpips/pretrained_networks.py ================================================ from collections import namedtuple import torch from torchvision import models as tv class squeezenet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(squeezenet, self).__init__() pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.slice6 = torch.nn.Sequential() self.slice7 = torch.nn.Sequential() self.N_slices = 7 for x in range(2): self.slice1.add_module(str(x), pretrained_features[x]) for x in range(2, 5): self.slice2.add_module(str(x), pretrained_features[x]) for x in range(5, 8): self.slice3.add_module(str(x), pretrained_features[x]) for x in range(8, 10): self.slice4.add_module(str(x), pretrained_features[x]) for x in range(10, 11): self.slice5.add_module(str(x), pretrained_features[x]) for x in range(11, 12): self.slice6.add_module(str(x), pretrained_features[x]) for x in range(12, 13): self.slice7.add_module(str(x), pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1 = h h = self.slice2(h) h_relu2 = h h = self.slice3(h) h_relu3 = h h = self.slice4(h) h_relu4 = h h = self.slice5(h) h_relu5 = h h = self.slice6(h) h_relu6 = h h = self.slice7(h) h_relu7 = h vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7']) out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) return out class alexnet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(alexnet, self).__init__() alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() 
self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 for x in range(2): self.slice1.add_module(str(x), alexnet_pretrained_features[x]) for x in range(2, 5): self.slice2.add_module(str(x), alexnet_pretrained_features[x]) for x in range(5, 8): self.slice3.add_module(str(x), alexnet_pretrained_features[x]) for x in range(8, 10): self.slice4.add_module(str(x), alexnet_pretrained_features[x]) for x in range(10, 12): self.slice5.add_module(str(x), alexnet_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1 = h h = self.slice2(h) h_relu2 = h h = self.slice3(h) h_relu3 = h h = self.slice4(h) h_relu4 = h h = self.slice5(h) h_relu5 = h alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) return out class vgg16(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(vgg16, self).__init__() vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 for x in range(4): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(4, 9): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(9, 16): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(16, 23): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(23, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1_2 = h h = self.slice2(h) h_relu2_2 = h h = self.slice3(h) h_relu3_3 = h h = self.slice4(h) h_relu4_3 = h h = self.slice5(h) h_relu5_3 = h vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) return out class resnet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True, num=18): super(resnet, self).__init__() if (num == 18): self.net = tv.resnet18(pretrained=pretrained) elif (num == 34): self.net = tv.resnet34(pretrained=pretrained) elif (num == 50): self.net = tv.resnet50(pretrained=pretrained) elif (num == 101): self.net = tv.resnet101(pretrained=pretrained) elif (num == 152): self.net = tv.resnet152(pretrained=pretrained) self.N_slices = 5 self.conv1 = self.net.conv1 self.bn1 = self.net.bn1 self.relu = self.net.relu self.maxpool = self.net.maxpool self.layer1 = self.net.layer1 self.layer2 = self.net.layer2 self.layer3 = self.net.layer3 self.layer4 = self.net.layer4 def forward(self, X): h = self.conv1(X) h = self.bn1(h) h = self.relu(h) h_relu1 = h h = self.maxpool(h) h = self.layer1(h) h_conv2 = h h = self.layer2(h) h_conv3 = h h = self.layer3(h) h_conv4 = h h = self.layer4(h) h_conv5 = h outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5']) out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) return out ================================================ FILE: metrics.py ================================================ from skimage import metrics import torch import torch.hub from lpips.lpips import LPIPS import os import 
numpy as np photometric = { "mse": None, "ssim": None, "psnr": None, "lpips": None } def compute_img_metric(im1t: torch.Tensor, im2t: torch.Tensor, metric="mse", margin=0, mask=None): """ im1t, im2t: torch.tensors with batched imaged shape, range from (0, 1) """ if metric not in photometric.keys(): raise RuntimeError(f"img_utils:: metric {metric} not recognized") if photometric[metric] is None: if metric == "mse": photometric[metric] = metrics.mean_squared_error elif metric == "ssim": photometric[metric] = metrics.structural_similarity elif metric == "psnr": photometric[metric] = metrics.peak_signal_noise_ratio elif metric == "lpips": photometric[metric] = LPIPS().cpu() if mask is not None: if mask.dim() == 3: mask = mask.unsqueeze(1) if mask.shape[1] == 1: mask = mask.expand(-1, 3, -1, -1) mask = mask.permute(0, 2, 3, 1).numpy() batchsz, hei, wid, _ = mask.shape if margin > 0: marginh = int(hei * margin) + 1 marginw = int(wid * margin) + 1 mask = mask[:, marginh:hei - marginh, marginw:wid - marginw] # convert from [0, 1] to [-1, 1] im1t = (im1t * 2 - 1).clamp(-1, 1) im2t = (im2t * 2 - 1).clamp(-1, 1) if im1t.dim() == 3: im1t = im1t.unsqueeze(0) im2t = im2t.unsqueeze(0) im1t = im1t.detach().cpu() im2t = im2t.detach().cpu() if im1t.shape[-1] == 3: im1t = im1t.permute(0, 3, 1, 2) im2t = im2t.permute(0, 3, 1, 2) im1 = im1t.permute(0, 2, 3, 1).numpy() im2 = im2t.permute(0, 2, 3, 1).numpy() batchsz, hei, wid, _ = im1.shape if margin > 0: marginh = int(hei * margin) + 1 marginw = int(wid * margin) + 1 im1 = im1[:, marginh:hei - marginh, marginw:wid - marginw] im2 = im2[:, marginh:hei - marginh, marginw:wid - marginw] values = [] for i in range(batchsz): if metric in ["mse", "psnr"]: if mask is not None: im1 = im1 * mask[i] im2 = im2 * mask[i] value = photometric[metric]( im1[i], im2[i] ) if mask is not None: hei, wid, _ = im1[i].shape pixelnum = mask[i, ..., 0].sum() value = value - 10 * np.log10(hei * wid / pixelnum) elif metric in ["ssim"]: value, ssimmap = photometric["ssim"]( im1[i], im2[i], multichannel=True, full=True ) if mask is not None: value = (ssimmap * mask[i]).sum() / mask[i].sum() elif metric in ["lpips"]: value = photometric[metric]( im1t[i:i + 1], im2t[i:i + 1] ) else: raise NotImplementedError values.append(value) return sum(values) / len(values) ================================================ FILE: nerf.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from run_nerf import * from Spline import se3_to_SE3 max_iter = 200000 T = max_iter+1 BOUNDARY = 20 class Embedder: def __init__(self, **kwargs): self.kwargs = kwargs self.create_embedding_fn() def create_embedding_fn(self): embed_fns = [] d = self.kwargs['input_dims'] out_dim = 0 if self.kwargs['include_input']: embed_fns.append(lambda x: x) out_dim += d max_freq = self.kwargs['max_freq_log2'] N_freqs = self.kwargs['num_freqs'] if self.kwargs['log_sampling']: freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs) else: freq_bands = torch.linspace(2. ** 0., 2. 
** max_freq, steps=N_freqs) for freq in freq_bands: for p_fn in self.kwargs['periodic_fns']: embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) out_dim += d self.embed_fns = embed_fns self.out_dim = out_dim def embed(self, inputs): return torch.cat([fn(inputs) for fn in self.embed_fns], -1) def get_embedder(args, multires, i=0): if i == -1: return nn.Identity(), 3 embed_kwargs = { 'include_input': False if args.barf else True, 'input_dims': 3, 'max_freq_log2': multires - 1, 'num_freqs': multires, 'log_sampling': True, 'periodic_fns': [torch.sin, torch.cos], } embedder_obj = Embedder(**embed_kwargs) embed = lambda x, eo=embedder_obj: eo.embed(x) return embed, embedder_obj.out_dim class Model(): def __init__(self): super().__init__() def build_network(self, args): self.graph = Graph(args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True) return self.graph def setup_optimizer(self, args): grad_vars = list(self.graph.nerf.parameters()) if args.N_importance>0: grad_vars += list(self.graph.nerf_fine.parameters()) self.optim = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) return self.optim class NeRF(nn.Module): def __init__(self, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=False): super().__init__() self.D = D self.W = W self.input_ch = input_ch self.input_ch_views = input_ch_views self.skips = skips self.use_viewdirs = use_viewdirs # network self.pts_linears = nn.ModuleList( [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D - 1)]) self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)]) if use_viewdirs: self.feature_linear = nn.Linear(W, W) self.alpha_linear = nn.Linear(W, 1) self.rgb_linear = nn.Linear(W // 2, 3) else: self.output_linear = nn.Linear(W, output_ch) def forward(self, barf_i, pts, viewdirs, args): embed_fn, input_ch = get_embedder(args, args.multires, args.i_embed) embeddirs_fn = None if args.use_viewdirs: embeddirs_fn, input_ch_views = get_embedder(args, args.multires_views, args.i_embed) pts_flat = torch.reshape(pts, [-1, pts.shape[-1]]) embedded = embed_fn(pts_flat) if viewdirs is not None: input_dirs = viewdirs[:, None].expand(pts.shape) input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) embedded_dirs = embeddirs_fn(input_dirs_flat) embedded = torch.cat([embedded, embedded_dirs], -1) input_pts, input_views = torch.split(embedded, [self.input_ch, self.input_ch_views], dim=-1) h = input_pts for i, l in enumerate(self.pts_linears): h = self.pts_linears[i](h) h = F.relu(h) if i in self.skips: h = torch.cat([input_pts, h], -1) if self.use_viewdirs: alpha = self.alpha_linear(h) feature = self.feature_linear(h) h = torch.cat([feature, input_views], -1) for i, l in enumerate(self.views_linears): h = self.views_linears[i](h) h = F.relu(h) rgb = self.rgb_linear(h) outputs = torch.cat([rgb, alpha], -1) else: outputs = self.output_linear(h) outputs = torch.reshape(outputs, list(pts.shape[:-1]) + [outputs.shape[-1]]) return outputs def raw2output(self, raw, z_vals, rays_d, raw_noise_std=0.0): raw2alpha = lambda raw, dists, act_fn=F.relu: 1. - torch.exp(-act_fn(raw) * dists) dists = z_vals[..., 1:] - z_vals[..., :-1] dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)], -1) dists = dists * torch.norm(rays_d[..., None, :], dim=-1) rgb = torch.sigmoid(raw[..., :3]) noise = 0. 
if raw_noise_std > 0.: noise = torch.randn(raw[..., 3].shape) * raw_noise_std alpha = raw2alpha(raw[..., 3] + noise, dists) weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1. - alpha + 1e-10], -1), -1)[:,:-1] rgb_map = torch.sum(weights[..., None] * rgb, -2) depth_map = torch.sum(weights * z_vals, -1) # disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1)) disp_map = torch.max(1e-6 * torch.ones_like(depth_map), depth_map / (torch.sum(weights, -1)+1e-6)) acc_map = torch.sum(weights, -1) sigma = F.relu(raw[..., 3] + noise) return rgb_map, disp_map, acc_map, weights, depth_map, sigma class Graph(nn.Module): def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=False): super().__init__() self.nerf = NeRF(D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs) if args.N_importance > 0: self.nerf_fine = NeRF(D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs) def forward(self, i, img_idx, poses_num, H, W, K, args, novel_view=False): if novel_view: poses_sharp = se3_to_SE3(self.se3_sharp.weight) ray_idx_sharp = torch.randperm(H * W)[:300] ret = self.render(i, poses_sharp, ray_idx_sharp, H, W, K, args) return ret, ray_idx_sharp, poses_sharp spline_poses = self.get_pose(i, img_idx, args) ray_idx = torch.randperm(H * W)[:args.N_rand // poses_num] ''' # only used in distorted data # aims to prevent the ray_idx lying on the edges for j in range(ray_idx.shape[0]): h = torch.randperm(H - 1)[0] w = torch.randperm(W - 1)[0] while (h < BOUNDARY or h > (H - 1 - BOUNDARY) or w < BOUNDARY or w > (W - 1 - BOUNDARY)): h = torch.randperm(H - 1)[0] w = torch.randperm(W - 1)[0] index = h * W + w ray_idx[j] = index ''' ret = self.render(i, spline_poses, ray_idx, H, W, K, args, near=0, far=1.0, ray_idx_tv=None, training=True) if (i % args.i_img == 0 or i % args.i_novel_view == 0) and i > 0: if args.deblur_images % 2 == 0: all_poses = self.get_pose_even(i, torch.arange(self.se3.weight.shape[0]), args.deblur_images) else: all_poses = self.get_pose(i, torch.arange(self.se3.weight.shape[0]), args) return ret, ray_idx, spline_poses, all_poses else: return ret, ray_idx, spline_poses def get_pose(self, i, img_idx, args): return i def get_gt_pose(self, poses, args): return poses def render(self, barf_i, poses, ray_idx, H, W, K, args, near=0., far=1., ray_idx_tv=None, training=False): if training: ray_idx_ = ray_idx.repeat(poses.shape[0]) poses = poses.unsqueeze(1).repeat(1, ray_idx.shape[0], 1, 1).reshape(-1, 3, 4) j = ray_idx_.reshape(-1, 1).squeeze() // W i = ray_idx_.reshape(-1, 1).squeeze() % W rays_o_, rays_d_ = get_specific_rays(i, j, K, poses) rays_o_d = torch.stack([rays_o_, rays_d_], 0) batch_rays = torch.permute(rays_o_d, [1, 0, 2]) else: rays_list = [] for p in poses[:, :3, :4]: rays_o_, rays_d_ = get_rays(H, W, K, p) rays_o_d = torch.stack([rays_o_, rays_d_], 0) rays_list.append(rays_o_d) rays = torch.stack(rays_list, 0) rays = rays.reshape(-1, 2, H * W, 3) rays = torch.permute(rays, [0, 2, 1, 3]) batch_rays = rays[:, ray_idx] batch_rays = batch_rays.reshape(-1, 2, 3) batch_rays = torch.transpose(batch_rays, 0, 1) # get standard rays rays_o, rays_d = batch_rays if args.use_viewdirs: viewdirs = rays_d viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) viewdirs = torch.reshape(viewdirs, [-1, 3]).float() sh = rays_d.shape if args.ndc: rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) # Create ray batch rays_o = torch.reshape(rays_o, [-1, 3]).float() 
rays_d = torch.reshape(rays_d, [-1, 3]).float() near, far = near * torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1]) rays = torch.cat([rays_o, rays_d, near, far], -1) if args.use_viewdirs: rays = torch.cat([rays, viewdirs], -1) N_rays = rays.shape[0] rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] viewdirs = rays[:, -3:] if rays.shape[-1] > 8 else None bounds = torch.reshape(rays[..., 6:8], [-1, 1, 2]) near, far = bounds[..., 0], bounds[..., 1] t_vals = torch.linspace(0., 1., steps=args.N_samples) z_vals = near * (1. - t_vals) + far * (t_vals) z_vals = z_vals.expand([N_rays, args.N_samples]) # perturb mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) upper = torch.cat([mids, z_vals[..., -1:]], -1) lower = torch.cat([z_vals[..., :1], mids], -1) # stratified samples in those intervals t_rand = torch.rand(z_vals.shape) z_vals = lower + (upper - lower) * t_rand pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None] raw_output = self.nerf.forward(barf_i, pts, viewdirs, args) rgb_map, disp_map, acc_map, weights, depth_map, sigma = self.nerf.raw2output(raw_output, z_vals, rays_d) if args.N_importance > 0: rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], args.N_importance) z_samples = z_samples.detach() z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1) pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None] raw_output = self.nerf_fine.forward(barf_i, pts, viewdirs, args) rgb_map, disp_map, acc_map, weights, depth_map, sigma = self.nerf_fine.raw2output(raw_output, z_vals, rays_d) ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map} if args.N_importance > 0: ret['rgb0'] = rgb_map_0 ret['disp0'] = disp_map_0 ret['acc0'] = acc_map_0 ret['sigma'] = sigma return ret @torch.no_grad() def render_video(self, barf_i, poses, H, W, K, args): all_ret = {} ray_idx = torch.arange(0, H*W) for i in range(0, ray_idx.shape[0], args.chunk): ret = self.render(barf_i, poses, ray_idx[i:i+args.chunk], H, W, K, args) for k in ret: if k not in all_ret: all_ret[k] = [] all_ret[k].append(ret[k]) all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret} for k in all_ret: k_sh = list([H, W]) + list(all_ret[k].shape[1:]) all_ret[k] = torch.reshape(all_ret[k], k_sh) return all_ret ================================================ FILE: novel_view_test.py ================================================ import nerf import torch.nn class Model(nerf.Model): def __init__(self, se3_start, graph): super().__init__() self.start = se3_start self.graph_fixed = graph def build_network(self, args): self.graph_fixed.se3_sharp = torch.nn.Embedding(self.start.shape[0], 6) # 22 and 25 self.graph_fixed.se3_sharp.weight.data = torch.nn.Parameter(self.start) return self.graph_fixed def setup_optimizer(self, args): grad_vars_se3 = list(self.graph_fixed.se3_sharp.parameters()) self.optim_se3_sharp = torch.optim.Adam(params=grad_vars_se3, lr=args.lrate) return self.optim_se3_sharp ================================================ FILE: optimize_pose_cubic.py ================================================ import torch.nn import Spline import nerf class Model(nerf.Model): def __init__(self, se3_0, se3_1, se3_2, se3_3): super().__init__() self.se3_0 = se3_0 self.se3_1 = se3_1 self.se3_2 = se3_2 self.se3_3 = se3_3 def build_network(self, args): self.graph = Graph(args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True)
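# Each blurry image owns four se(3) control poses for the cubic spline,
# packed side by side into one nn.Embedding row of width 6*4 = 24; get_pose()
# below recovers them as weight[:, :6], [:, 6:12], [:, 12:18], [:, 18:].
# Illustrative shapes, with N the number of blurry training images:
#   self.graph.se3.weight    # (N, 24) -- optimized jointly with the NeRF
#   se3_0 = weight[:, :6]    # (N, 6)  -- first control pose of each trajectory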
self.graph.se3 = torch.nn.Embedding(self.se3_0.shape[0], 6*4) start_end = torch.cat([self.se3_0, self.se3_1, self.se3_2, self.se3_3], -1) self.graph.se3.weight.data = torch.nn.Parameter(start_end) return self.graph def setup_optimizer(self, args): grad_vars = list(self.graph.nerf.parameters()) if args.N_importance > 0: grad_vars += list(self.graph.nerf_fine.parameters()) self.optim = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) grad_vars_se3 = list(self.graph.se3.parameters()) self.optim_se3 = torch.optim.Adam(params=grad_vars_se3, lr=args.lrate) return self.optim, self.optim_se3 class Graph(nerf.Graph): def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True): super().__init__(args, D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs) self.pose_eye = torch.eye(3, 4) self.se3_start = None self.se3_end = None def get_pose(self, i, img_idx, args): se3_0 = self.se3.weight[:, :6][img_idx] se3_1 = self.se3.weight[:, 6:12][img_idx] se3_2 = self.se3.weight[:, 12:18][img_idx] se3_3 = self.se3.weight[:, 18:][img_idx] pose_nums = torch.arange(args.deblur_images).reshape(1, -1).repeat(se3_0.shape[0], 1) seg_pos_x = torch.arange(se3_0.shape[0]).reshape([se3_0.shape[0], 1]).repeat(1, args.deblur_images) se3_0 = se3_0[seg_pos_x, :] se3_1 = se3_1[seg_pos_x, :] se3_2 = se3_2[seg_pos_x, :] se3_3 = se3_3[seg_pos_x, :] spline_poses = Spline.SplineN_cubic(se3_0, se3_1, se3_2, se3_3, pose_nums, args.deblur_images) return spline_poses def get_pose_even(self, i, img_idx, num): deblur_images_num = num+1 se3_0 = self.se3.weight[:, :6][img_idx] se3_1 = self.se3.weight[:, 6:12][img_idx] se3_2 = self.se3.weight[:, 12:18][img_idx] se3_3 = self.se3.weight[:, 18:][img_idx] pose_nums = torch.arange(deblur_images_num).reshape(1, -1).repeat(se3_0.shape[0], 1) seg_pos_x = torch.arange(se3_0.shape[0]).reshape([se3_0.shape[0], 1]).repeat(1, deblur_images_num) se3_0 = se3_0[seg_pos_x, :] se3_1 = se3_1[seg_pos_x, :] se3_2 = se3_2[seg_pos_x, :] se3_3 = se3_3[seg_pos_x, :] spline_poses = Spline.SplineN_cubic(se3_0, se3_1, se3_2, se3_3, pose_nums, deblur_images_num) return spline_poses def get_gt_pose(self, poses, args): a = self.pose_eye return poses ================================================ FILE: optimize_pose_linear.py ================================================ import torch.nn import Spline import nerf class Model(nerf.Model): def __init__(self, se3_start, se3_end): super().__init__() self.start = se3_start self.end = se3_end def build_network(self, args): self.graph = Graph(args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True) self.graph.se3 = torch.nn.Embedding(self.start.shape[0], 6*2) start_end = torch.cat([self.start, self.end], -1) self.graph.se3.weight.data = torch.nn.Parameter(start_end) return self.graph def setup_optimizer(self, args): grad_vars = list(self.graph.nerf.parameters()) if args.N_importance > 0: grad_vars += list(self.graph.nerf_fine.parameters()) self.optim = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) grad_vars_se3 = list(self.graph.se3.parameters()) self.optim_se3 = torch.optim.Adam(params=grad_vars_se3, lr=args.lrate) return self.optim, self.optim_se3 class Graph(nerf.Graph): def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True): super().__init__(args, D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs) self.pose_eye = torch.eye(3, 4) self.se3_start = None 
self.se3_end = None def get_pose(self, i, img_idx, args): se3_start = self.se3.weight[:, :6][img_idx] se3_end = self.se3.weight[:, 6:][img_idx] pose_nums = torch.arange(args.deblur_images).reshape(1, -1).repeat(se3_start.shape[0], 1) seg_pos_x = torch.arange(se3_start.shape[0]).reshape([se3_start.shape[0], 1]).repeat(1, args.deblur_images) se3_start = se3_start[seg_pos_x, :] se3_end = se3_end[seg_pos_x, :] spline_poses = Spline.SplineN_linear(se3_start, se3_end, pose_nums, args.deblur_images) return spline_poses def get_pose_even(self, i, img_idx, num): deblur_images_num = num+1 se3_start = self.se3.weight[:, :6][img_idx] se3_end = self.se3.weight[:, 6:][img_idx] pose_nums = torch.arange(deblur_images_num).reshape(1, -1).repeat(se3_start.shape[0],1) seg_pos_x = torch.arange(se3_start.shape[0]).reshape([se3_start.shape[0], 1]).repeat(1, deblur_images_num) se3_start = se3_start[seg_pos_x, :] se3_end = se3_end[seg_pos_x, :] spline_poses = Spline.SplineN_linear(se3_start, se3_end, pose_nums, deblur_images_num) return spline_poses def get_gt_pose(self, poses, args): a = self.pose_eye return poses ================================================ FILE: requirements.txt ================================================ configargparse imageio<2.28.0,>=2.26.0 imageio-ffmpeg matplotlib numpy scikit-learn<1 scikit-image<0.20,>=0.19 torch>=1.8 torchvision>=0.9.1 tqdm ================================================ FILE: run_nerf.py ================================================ import os, sys import numpy as np import imageio import json import random import time import torch from run_nerf_helpers import * from Spline import * import torch.nn as nn import torch.nn.functional as F from tqdm import tqdm, trange import matplotlib.pyplot as plt from load_llff import load_llff_data import torchvision.transforms.functional as torchvision_F device = torch.device("cuda" if torch.cuda.is_available() else "cpu") np.random.seed(0) DEBUG = False def config_parser(): import configargparse parser = configargparse.ArgumentParser() parser.add_argument('--config', is_config_file=True, default='configs/cozy2room.txt', help='config file path') parser.add_argument("--expname", type=str, help='experiment name') parser.add_argument("--basedir", type=str, default='./logs/', help='where to store ckpts and logs') parser.add_argument("--datadir", type=str, default='./data/llff/fern', help='input data directory') # training options parser.add_argument("--N_iters", type=int, default=200000, help='number of training iterations') parser.add_argument("--deblur_images", type=int, default=5, help='the number of sharp images one blur image corresponds to') parser.add_argument("--skip", type=int, default=8, help='original llffhold before concatenating images') parser.add_argument("--netdepth", type=int, default=8, help='layers in network') parser.add_argument("--netwidth", type=int, default=256, help='channels per layer') parser.add_argument("--netdepth_fine", type=int, default=8, help='layers in fine network') parser.add_argument("--netwidth_fine", type=int, default=256, help='channels per layer in fine network') parser.add_argument("--N_rand", type=int, default=32*32*4, help='batch size (number of random rays per gradient step)') parser.add_argument("--lrate", type=float, default=5e-4, help='learning rate') parser.add_argument("--pose_lrate", type=float, default=1e-3, help='pose learning rate') parser.add_argument("--lrate_decay", type=int, default=200, help='exponential learning rate decay (in 1000 steps)')
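# Note: train.py decays this rate exponentially,
#   new_lrate = lrate * 0.1 ** (global_step / (lrate_decay * 1000))
# so with the defaults (lrate=5e-4, lrate_decay=200) the learning rate falls
# to 5e-5 after 200k iterations; the pose optimizer follows the same schedule
# with pose_lrate and a decay rate of 0.01.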
parser.add_argument("--chunk", type=int, default=1024*2, help='number of rays processed in parallel, decrease if running out of memory') parser.add_argument("--netchunk", type=int, default=1024*32, help='number of pts sent through network in parallel, decrease if running out of memory') parser.add_argument("--no_batching", action='store_true', help='only take random rays from 1 image at a time') parser.add_argument("--no_reload", action='store_true', help='do not reload weights from saved ckpt') parser.add_argument("--ft_path", type=str, default=None, help='specific weights npy file to reload for coarse network') # rendering options parser.add_argument("--N_samples", type=int, default=64, help='number of coarse samples per ray') parser.add_argument("--N_importance", type=int, default=0, help='number of additional fine samples per ray') parser.add_argument("--perturb", type=float, default=1., help='set to 0. for no jitter, 1. for jitter') parser.add_argument("--use_viewdirs", action='store_true', help='use full 5D input instead of 3D') parser.add_argument("--i_embed", type=int, default=0, help='set 0 for default positional encoding, -1 for none') parser.add_argument("--multires", type=int, default=10, help='log2 of max freq for positional encoding (3D location)') parser.add_argument("--multires_views", type=int, default=4, help='log2 of max freq for positional encoding (2D direction)') parser.add_argument("--raw_noise_std", type=float, default=0., help='std dev of noise added to regularize sigma_a output, 1e0 recommended') parser.add_argument("--render_only", action='store_true', help='do not optimize, reload weights and render out render_poses path') parser.add_argument("--render_test", action='store_true', help='render the test set instead of render_poses path') parser.add_argument("--render_factor", type=int, default=0, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') parser.add_argument("--ndc", type=bool, default=True, help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') # training options parser.add_argument("--precrop_iters", type=int, default=0, help='number of steps to train on central crops') parser.add_argument("--precrop_frac", type=float, default=.5, help='fraction of img taken for central crops') # dataset options parser.add_argument("--dataset_type", type=str, default='llff', help='options: llff / blender / deepvoxels') parser.add_argument("--testskip", type=int, default=8, help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') ## deepvoxels flags parser.add_argument("--shape", type=str, default='greek', help='options : armchair / cube / greek / vase') ## blender flags parser.add_argument("--white_bkgd", action='store_true', help='set to render synthetic data on a white bkgd (always use for dvoxels)') parser.add_argument("--half_res", action='store_true', help='load blender synthetic data at 400x400 instead of 800x800') ## llff flags parser.add_argument("--factor", type=int, default=8, help='downsample factor for LLFF images') parser.add_argument("--no_ndc", action='store_true', help='do not use normalized device coordinates (set for non-forward facing scenes)') parser.add_argument("--lindisp", action='store_true', help='sampling linearly in disparity rather than depth') parser.add_argument("--spherify", action='store_true', help='set for spherical 360 scenes') parser.add_argument("--llffhold", type=int, default=8, help='will take every 1/N images as LLFF test set, paper uses 8') # 
# logging/saving options parser.add_argument("--i_print", type=int, default=100, help='frequency of console printout and metric logging') parser.add_argument("--i_img", type=int, default=15000, help='frequency of tensorboard image logging') parser.add_argument("--i_weights", type=int, default=10000, help='frequency of weight ckpt saving') parser.add_argument("--i_testset", type=int, default=50000, help='frequency of testset saving') parser.add_argument("--i_video", type=int, default=30000, help='frequency of render_poses video saving') parser.add_argument("--load_weights", action='store_true', help='load weights from a saved ckpt') # barf: up & down parser.add_argument("--barf", action='store_true', help='barf') parser.add_argument("--barf_start", type=float, default=0.1, help='barf start') parser.add_argument("--barf_end", type=float, default=0.9, help='barf end') # test option parser.add_argument("--only_optim_one", action='store_true', default=False, help='test option: only optimize one pose') parser.add_argument("--split_train_data", action='store_true', default=False, help='test option: split the training data') parser.add_argument("--weight_iter", type=int, default=20000, help='iteration of the checkpoint to load') # pose noise parser.add_argument("--pose_noise", type=float, default=0.1, help='random noise of pose') # linear parser.add_argument("--linear", action='store_true', default=False, help='linear or cubic spline') # novel view parser.add_argument("--novel_view", action='store_true', default=False, help='novel view test') parser.add_argument("--i_novel_view", type=int, default=200000, help='frequency of novel view testing') parser.add_argument("--factor_pose_novel", type=float, default=2.0, help='learning rate factor for novel-view pose optimization') parser.add_argument("--N_novel_view", type=int, default=20000, help='iterations for optimizing novel-view poses') return parser ================================================ FILE: run_nerf_helpers.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from tqdm import tqdm, trange import os import imageio # Misc img2mse = lambda x, y : torch.mean((x - y) ** 2) img2se = lambda x, y : (x - y) ** 2 mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) # change of base: log10(x) = log(x) / log(10) to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) to8b_tensor = lambda x : (255*torch.clip(x,0,1)).type(torch.int) def imread(f): if f.endswith('png'): return imageio.imread(f, ignoregamma=True) else: return imageio.imread(f) def load_imgs(path): imgfiles = [os.path.join(path, f) for f in sorted(os.listdir(path)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] imgs = [imread(f)[..., :3] / 255. for f in imgfiles] imgs = np.stack(imgs, -1) imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) imgs = imgs.astype(np.float32) imgs = torch.tensor(imgs).cuda() return imgs # Ray helpers def get_rays(H, W, K, c2w): i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij' i = i.t() j = j.t() dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1) # Rotate ray directions from camera frame to the world frame rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] # Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = c2w[:3,-1].expand(rays_d.shape) return rays_o, rays_d def ndc_rays(H, W, focal, near, rays_o, rays_d): # Shift ray origins to near plane t = -(near + rays_o[...,2]) / rays_d[...,2] rays_o = rays_o + t[...,None] * rays_d # Projection o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2] o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2] o2 = 1. + 2. * near / rays_o[...,2] d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2]) d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2]) d2 = -2. * near / rays_o[...,2] rays_o = torch.stack([o0,o1,o2], -1) rays_d = torch.stack([d0,d1,d2], -1) return rays_o, rays_d # Hierarchical sampling (section 5.2) def sample_pdf(bins, weights, N_samples, det=False, pytest=False): # Get pdf weights = weights + 1e-5 # prevent nans pdf = weights / torch.sum(weights, -1, keepdim=True) cdf = torch.cumsum(pdf, -1) cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins)) # Take uniform samples if det: u = torch.linspace(0., 1., steps=N_samples) u = u.expand(list(cdf.shape[:-1]) + [N_samples]) else: u = torch.rand(list(cdf.shape[:-1]) + [N_samples]) # Pytest, overwrite u with numpy's fixed random numbers if pytest: np.random.seed(0) new_shape = list(cdf.shape[:-1]) + [N_samples] if det: u = np.linspace(0., 1., N_samples) u = np.broadcast_to(u, new_shape) else: u = np.random.rand(*new_shape) u = torch.Tensor(u) # Invert CDF u = u.contiguous() inds = torch.searchsorted(cdf, u, right=True) below = torch.max(torch.zeros_like(inds-1), inds-1) above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds) inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) denom = (cdf_g[...,1]-cdf_g[...,0]) denom = torch.where(denom<1e-5, torch.ones_like(denom), denom) t = (u-cdf_g[...,0])/denom samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0]) return samples def render_video_test(i_, graph, render_poses, H, W, K, args): rgbs = [] disps = [] # t = time.time() for i, pose in enumerate(tqdm(render_poses)): # print(i, time.time() - t) # t = time.time() pose = pose[None, :3, :4] ret = graph.render_video(i_, pose[:3, :4], H, W, K, args) rgbs.append(ret['rgb_map'].cpu().numpy()) disps.append(ret['disp_map'].cpu().numpy()) if i==0: print(ret['rgb_map'].shape, ret['disp_map'].shape) rgbs = np.stack(rgbs, 0) disps = np.stack(disps, 0) return rgbs, disps to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) def render_image_test(i, graph, render_poses, H, W, K, args, novel_view=False, need_depth=False): if novel_view: img_dir = os.path.join(args.basedir, args.expname, 'img_novel_{:06d}'.format(i)) else: img_dir = os.path.join(args.basedir, args.expname, 'img_test_{:06d}'.format(i)) os.makedirs(img_dir, exist_ok=True) imgs =[] for j, pose in enumerate(tqdm(render_poses)): # print(i, time.time() - t) # t = time.time() pose = pose[None, :3, :4] ret = graph.render_video(i, pose[:3, :4], H, W, K, args) imgs.append(ret['rgb_map']) rgbs = ret['rgb_map'].cpu().numpy() rgb8 = to8b(rgbs) imageio.imwrite(os.path.join(img_dir, 'rgb_{:03d}.png'.format(j)), rgb8) if need_depth: depths = ret['disp_map'].cpu().numpy() depths_ = 
depths/np.max(depths) depth8 = to8b(depths_) imageio.imwrite(os.path.join(img_dir, 'depth_{:03d}.png'.format(j)), depth8) imgs = torch.stack(imgs, 0) return imgs def init_weights(linear): # Kaiming (He) normal init for the linear layers torch.nn.init.kaiming_normal_(linear.weight) torch.nn.init.zeros_(linear.bias) def init_nerf(nerf): for linear_pt in nerf.pts_linears: init_weights(linear_pt) for linear_view in nerf.views_linears: init_weights(linear_view) init_weights(nerf.feature_linear) init_weights(nerf.alpha_linear) init_weights(nerf.rgb_linear) # Ray helpers only get specific rays def get_specific_rays(i, j, K, c2w): # i, j = torch.meshgrid(torch.linspace(0, W - 1, W), # torch.linspace(0, H - 1, H)) # pytorch's meshgrid has indexing='ij' # i = i.t() # j = j.t() dirs = torch.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -torch.ones_like(i)], -1) # Rotate ray directions from camera frame to the world frame rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[..., :3, :3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] # Translate camera frame's origin to the world frame. It is the origin of all rays. rays_o = c2w[..., :3, -1] return rays_o, rays_d def save_render_pose(poses, path): poses_np = poses.cpu().detach().numpy() N = poses_np.shape[0] bottom = np.reshape([0., 0., 0., 1.], [1, 4]) bottom_all = np.expand_dims(bottom, 0).repeat(N, axis=0) poses_Rt = np.concatenate([poses_np, bottom_all], 1) poses_txt = os.path.join(path, 'poses_render.txt') with open(poses_txt, 'a') as outfile: # open once instead of once per value for j in range(poses_np.shape[0]): poses_flat = poses_Rt[j].reshape(16, 1).squeeze() for k in range(16): if k == 0: outfile.write(f"pose{j} ") if k != 15: outfile.write(f"{poses_flat[k]} ") if k == 15: outfile.write(f"{poses_flat[k]}\n") ================================================ FILE: test.py ================================================ import torch from nerf import * import optimize_pose_linear, optimize_pose_cubic import torchvision.transforms.functional as torchvision_F import matplotlib.pyplot as plt from metrics import compute_img_metric import novel_view_test def test(): parser = config_parser() args = parser.parse_args() print('spline numbers: ', args.deblur_images) imgs_sharp_dir = os.path.join(args.datadir, 'images_test') imgs_sharp = load_imgs(imgs_sharp_dir) # Load data images and groundtruth K = None if args.dataset_type == 'llff': images_all, poses_start, bds_start, render_poses = load_llff_data(args.datadir, pose_state=None, factor=args.factor, recenter=True, bd_factor=.75, spherify=args.spherify) hwf = poses_start[0, :3, -1] # split train/val/test if args.novel_view: i_test = torch.arange(0, images_all.shape[0], args.llffhold) else: i_test = torch.tensor([100]).long() i_val = i_test i_train = torch.Tensor([i for i in torch.arange(int(images_all.shape[0])) if (i not in i_test and i not in i_val)]).long() # train data images = images_all[i_train] # novel view data if args.novel_view: images_novel = images_all[i_test] # gt data imgs_sharp = imgs_sharp # get poses poses_end = poses_start poses_start_se3 = SE3_to_se3_N(poses_start[:, :3, :4]) poses_end_se3 = poses_start_se3 poses_org = poses_start.repeat(args.deblur_images, 1, 1) poses = poses_org[:, :, :4] print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) print('DEFINING BOUNDS') if args.no_ndc: near = torch.min(bds_start) * .9 far = torch.max(bds_start) * 1. else: near = 0. far = 1.
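# With NDC parameterization (the default for these forward-facing LLFF
# scenes, see ndc_rays in run_nerf_helpers.py) sample depths are warped into
# the unit interval, so the bounds can be fixed at near=0, far=1; only with
# --no_ndc are they derived from the scene's estimated depth range bds_start.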
print('NEAR FAR', near, far) else: print('Unknown dataset type', args.dataset_type, 'exiting') return # Cast intrinsics to right types H, W, focal = hwf H, W = int(H), int(W) hwf = [H, W, focal] if K is None: K = torch.Tensor([ [focal, 0, 0.5 * W], [0, focal, 0.5 * H], [0, 0, 1] ]) # Create log dir and copy the config file basedir = args.basedir expname = args.expname test_metric_file = os.path.join(basedir, expname, 'test_metrics.txt') test_metric_file_novel = os.path.join(basedir, expname, 'test_metrics_novel.txt') # print_file = os.path.join(basedir, expname, 'print.txt') os.makedirs(os.path.join(basedir, expname), exist_ok=True) f = os.path.join(basedir, expname, 'args.txt') with open(f, 'w') as file: for arg in sorted(vars(args)): attr = getattr(args, arg) file.write('{} = {}\n'.format(arg, attr)) if args.config is not None: f = os.path.join(basedir, expname, 'config.txt') with open(f, 'w') as file: file.write(open(args.config, 'r').read()) if args.linear: print('Linear Spline Model Loading!') model = optimize_pose_linear.Model(poses_start_se3, poses_end_se3) else: print('Cubic Spline Model Loading!') model = optimize_pose_cubic.Model(poses_start_se3, poses_start_se3, poses_start_se3, poses_start_se3) graph = model.build_network(args) optimizer, optimizer_se3 = model.setup_optimizer(args) path = os.path.join(basedir, expname, '{:06d}.tar'.format(args.weight_iter)) graph_ckpt = torch.load(path) graph.load_state_dict(graph_ckpt['graph']) optimizer.load_state_dict(graph_ckpt['optimizer']) optimizer_se3.load_state_dict(graph_ckpt['optimizer_se3']) global_step = graph_ckpt['global_step'] if args.deblur_images % 2 == 0: all_poses = graph.get_pose_even(0, torch.arange(graph.se3.weight.shape[0]), args.deblur_images) else: all_poses = graph.get_pose(0, torch.arange(graph.se3.weight.shape[0]), args) # Turn on testing mode with torch.no_grad(): if args.deblur_images % 2 == 0: i_render = torch.arange(i_train.shape[0]) * (args.deblur_images + 1) + args.deblur_images // 2 else: i_render = torch.arange(i_train.shape[0]) * args.deblur_images + args.deblur_images // 2 imgs_render = render_image_test(0, graph, all_poses[i_render], H, W, K, args) mse_render = compute_img_metric(imgs_sharp, imgs_render, 'mse') psnr_render = compute_img_metric(imgs_sharp, imgs_render, 'psnr') ssim_render = compute_img_metric(imgs_sharp, imgs_render, 'ssim') lpips_render = compute_img_metric(imgs_sharp, imgs_render, 'lpips') with open(test_metric_file, 'a') as outfile: outfile.write(f"test: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}" f" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\n") # Turn on novel view testing mode if args.novel_view: i_ = torch.arange(0, images.shape[0], args.llffhold - 1) poses_test_se3_ = graph.se3.weight[i_, :6] model_test = novel_view_test.Model(poses_test_se3_, graph) graph_test = model_test.build_network(args) optimizer_test = model_test.setup_optimizer(args) for j in range(args.N_novel_view): ret_sharp, ray_idx_sharp, poses_sharp = graph_test.forward(0, 0, 0, H, W, K, args, novel_view=True) target_s_novel = images_novel.reshape(-1, H * W, 3)[:, ray_idx_sharp] target_s_novel = target_s_novel.reshape(-1, 3) loss_sharp = img2mse(ret_sharp['rgb_map'], target_s_novel) psnr_sharp = mse2psnr(loss_sharp) if 'rgb0' in ret_sharp: img_loss0 = img2mse(ret_sharp['rgb0'], target_s_novel) loss_sharp = loss_sharp + img_loss0 if j % 100 == 0: print(psnr_sharp.item(), loss_sharp.item()) optimizer_test.zero_grad() loss_sharp.backward() optimizer_test.step() decay_rate_sharp = 0.01 
decay_steps_sharp = args.lrate_decay * 100 new_lrate_novel = args.pose_lrate * (decay_rate_sharp ** (j / decay_steps_sharp)) for param_group in optimizer_test.param_groups: if (j / decay_steps_sharp) <= 1.: param_group['lr'] = new_lrate_novel * args.factor_pose_novel with torch.no_grad(): imgs_render_novel = render_image_test(0, graph, poses_sharp, H, W, K, args, novel_view=True) mse_render = compute_img_metric(images_novel, imgs_render_novel, 'mse') psnr_render = compute_img_metric(images_novel, imgs_render_novel, 'psnr') ssim_render = compute_img_metric(images_novel, imgs_render_novel, 'ssim') lpips_render = compute_img_metric(images_novel, imgs_render_novel, 'lpips') with open(test_metric_file_novel, 'a') as outfile: outfile.write(f"novel view test: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}" f" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\n") return 0 if __name__=='__main__': torch.set_default_tensor_type('torch.cuda.FloatTensor') test() ================================================ FILE: train.py ================================================ import torch from nerf import * import optimize_pose_linear, optimize_pose_cubic import torchvision.transforms.functional as torchvision_F import matplotlib.pyplot as plt from metrics import compute_img_metric import novel_view_test def train(): parser = config_parser() args = parser.parse_args() print('spline numbers: ', args.deblur_images) imgs_sharp_dir = os.path.join(args.datadir, 'images_test') imgs_sharp = load_imgs(imgs_sharp_dir) # Load data images and groundtruth K = None if args.dataset_type == 'llff': images_all, poses_start, bds_start, render_poses = load_llff_data(args.datadir, pose_state=None, factor=args.factor, recenter=True, bd_factor=.75, spherify=args.spherify) hwf = poses_start[0, :3, -1] # split train/val/test if args.novel_view: i_test = torch.arange(0, images_all.shape[0], args.llffhold) else: i_test = torch.tensor([100]).long() i_val = i_test i_train = torch.Tensor([i for i in torch.arange(int(images_all.shape[0])) if (i not in i_test and i not in i_val)]).long() # train data images = images_all[i_train] # novel view data if args.novel_view: images_novel = images_all[i_test] # gt data imgs_sharp = imgs_sharp # get poses poses_end = poses_start poses_start_se3 = SE3_to_se3_N(poses_start[:, :3, :4]) poses_end_se3 = poses_start_se3 poses_org = poses_start.repeat(args.deblur_images, 1, 1) poses = poses_org[:, :, :4] print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) print('DEFINING BOUNDS') if args.no_ndc: near = torch.min(bds_start) * .9 far = torch.max(bds_start) * 1. else: near = 0. far = 1. 
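# Pose model built below: every blurry image is assigned start/end camera
# poses in se(3), and Spline.SplineN_linear interpolates args.deblur_images
# virtual sharp poses along that trajectory. Conceptually (a sketch only --
# Spline.py is not reproduced here and may differ in details):
#   t_k  = k / (deblur_images - 1),        k = 0, ..., deblur_images - 1
#   xi_k = (1 - t_k) * xi_start + t_k * xi_end   # interpolate in se(3)
#   T_k  = se3_to_SE3(xi_k)                      # 3x4 camera-to-world matrix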
print('NEAR FAR', near, far) else: print('Unknown dataset type', args.dataset_type, 'exiting') return # Cast intrinsics to right types H, W, focal = hwf H, W = int(H), int(W) hwf = [H, W, focal] if K is None: K = torch.Tensor([ [focal, 0, 0.5 * W], [0, focal, 0.5 * H], [0, 0, 1] ]) # Create log dir and copy the config file basedir = args.basedir expname = args.expname test_metric_file = os.path.join(basedir, expname, 'test_metrics.txt') test_metric_file_novel = os.path.join(basedir, expname, 'test_metrics_novel.txt') print_file = os.path.join(basedir, expname, 'print.txt') os.makedirs(os.path.join(basedir, expname), exist_ok=True) f = os.path.join(basedir, expname, 'args.txt') with open(f, 'w') as file: for arg in sorted(vars(args)): attr = getattr(args, arg) file.write('{} = {}\n'.format(arg, attr)) if args.config is not None: f = os.path.join(basedir, expname, 'config.txt') with open(f, 'w') as file: file.write(open(args.config, 'r').read()) if args.load_weights: if args.linear: print('Linear Spline Model Loading!') model = optimize_pose_linear.Model(poses_start_se3, poses_end_se3) else: print('Cubic Spline Model Loading!') model = optimize_pose_cubic.Model(poses_start_se3, poses_start_se3, poses_start_se3, poses_start_se3) graph = model.build_network(args) optimizer, optimizer_se3 = model.setup_optimizer(args) path = os.path.join(basedir, expname, '{:06d}.tar'.format(args.weight_iter)) # here graph_ckpt = torch.load(path) graph.load_state_dict(graph_ckpt['graph']) optimizer.load_state_dict(graph_ckpt['optimizer']) optimizer_se3.load_state_dict(graph_ckpt['optimizer_se3']) global_step = graph_ckpt['global_step'] else: if args.linear: low, high = 0.0001, 0.005 rand = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low poses_start_se3 = poses_start_se3 + rand model = optimize_pose_linear.Model(poses_start_se3, poses_end_se3) else: low, high = 0.0001, 0.01 rand1 = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low rand2 = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low rand3 = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low poses_se3_1 = poses_start_se3 + rand1 poses_se3_2 = poses_start_se3 + rand2 poses_se3_3 = poses_start_se3 + rand3 model = optimize_pose_cubic.Model(poses_start_se3, poses_se3_1, poses_se3_2, poses_se3_3) graph = model.build_network(args) # nerf, nerf_fine, forward optimizer, optimizer_se3 = model.setup_optimizer(args) N_iters = args.N_iters + 1 print('Begin') print('TRAIN views are', i_train) print('TEST views are', i_test) print('VAL views are', i_val) start = 0 if not args.load_weights: global_step = start global_step_ = global_step threshold = N_iters + 1 poses_num = poses.shape[0] for i in trange(start, threshold): ### core optimization loop ### i = i+global_step_ if i == 0: init_nerf(graph.nerf) init_nerf(graph.nerf_fine) img_idx = torch.randperm(images.shape[0]) if (i % args.i_img == 0 or i % args.i_novel_view == 0) and i > 0: ret, ray_idx, spline_poses, all_poses = graph.forward(i, img_idx, poses_num, H, W, K, args) else: ret, ray_idx, spline_poses = graph.forward(i, img_idx, poses_num, H, W, K, args) # get image ground truth target_s = images[img_idx].reshape(-1, H * W, 3) target_s = target_s[:, ray_idx] target_s = target_s.reshape(-1, 3) # average shape0 = img_idx.shape[0] interval = target_s.shape[0] // shape0 rgb_list = [] extras_list = [] rgb_ = 0 extras_ = 0 for j in range(0, shape0 * args.deblur_images): rgb_ += ret['rgb_map'][j * interval:(j + 1) * interval] if 'rgb0' in ret: extras_ += ret['rgb0'][j * 
interval:(j + 1) * interval] if (j + 1) % args.deblur_images == 0: rgb_ = rgb_ / args.deblur_images rgb_list.append(rgb_) rgb_ = 0 if 'rgb0' in ret: extras_ = extras_ / args.deblur_images extras_list.append(extras_) extras_ = 0 rgb_blur = torch.stack(rgb_list, 0) rgb_blur = rgb_blur.reshape(-1, 3) if 'rgb0' in ret: extras_blur = torch.stack(extras_list, 0) extras_blur = extras_blur.reshape(-1, 3) # backward optimizer_se3.zero_grad() optimizer.zero_grad() img_loss = img2mse(rgb_blur, target_s) loss = img_loss psnr = mse2psnr(img_loss) if 'rgb0' in ret: img_loss0 = img2mse(extras_blur, target_s) loss = loss + img_loss0 psnr0 = mse2psnr(img_loss0) loss.backward() optimizer.step() optimizer_se3.step() # NOTE: IMPORTANT! ### update learning rate ### decay_rate = 0.1 decay_steps = args.lrate_decay * 1000 new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps)) for param_group in optimizer.param_groups: param_group['lr'] = new_lrate decay_rate_pose = 0.01 new_lrate_pose = args.pose_lrate * (decay_rate_pose ** (global_step / decay_steps)) for param_group in optimizer_se3.param_groups: param_group['lr'] = new_lrate_pose ############################### if i%args.i_print==0: tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} coarse_loss:, {img_loss0.item()}, PSNR: {psnr.item()}") with open(print_file, 'a') as outfile: outfile.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} coarse_loss:, {img_loss0.item()}, PSNR: {psnr.item()}\n") if i < 10: print('coarse_loss:', img_loss0.item()) with open(print_file, 'a') as outfile: outfile.write(f"coarse loss: {img_loss0.item()}\n") if i % args.i_weights == 0 and i > 0: path = os.path.join(basedir, expname, '{:06d}.tar'.format(i)) torch.save({ 'global_step': global_step, 'graph': graph.state_dict(), 'optimizer': optimizer.state_dict(), 'optimizer_se3': optimizer_se3.state_dict(), }, path) print('Saved checkpoints at', path) if i % args.i_img == 0 and i > 0: # Turn on testing mode with torch.no_grad(): if args.deblur_images % 2 == 0: i_render = torch.arange(i_train.shape[0]) * (args.deblur_images+1) + args.deblur_images // 2 else: i_render = torch.arange(i_train.shape[0]) * args.deblur_images + args.deblur_images // 2 imgs_render = render_image_test(i, graph, all_poses[i_render], H, W, K, args) mse_render = compute_img_metric(imgs_sharp, imgs_render, 'mse') psnr_render = compute_img_metric(imgs_sharp, imgs_render, 'psnr') ssim_render = compute_img_metric(imgs_sharp, imgs_render, 'ssim') lpips_render = compute_img_metric(imgs_sharp, imgs_render, 'lpips') with open(test_metric_file, 'a') as outfile: outfile.write(f"iter{i}: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}" f" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\n") if i % args.i_video == 0 and i > 0: # Turn on testing mode with torch.no_grad(): rgbs, disps = render_video_test(i, graph, render_poses, H, W, K, args) print('Done, saving', rgbs.shape, disps.shape) moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i)) imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8) imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8) if args.novel_view and i % args.i_novel_view == 0 and i > 0: # Turn on novel view testing mode i_ = torch.arange(0, images.shape[0], args.llffhold-1) poses_test_se3_ = graph.se3.weight[i_,:6] model_test = novel_view_test.Model(poses_test_se3_, graph) graph_test = model_test.build_network(args) optimizer_test = model_test.setup_optimizer(args) for j in range(args.N_novel_view): 
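# Novel-view pose refinement: the NeRF itself stays frozen (see
# novel_view_test.Model, whose optimizer only covers the se3_sharp embedding),
# and graph_test.forward(..., novel_view=True) samples 300 random rays per
# held-out image. The photometric loss below therefore only moves the test
# poses toward the trained radiance field before the final metrics are run.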
ret_sharp, ray_idx_sharp, poses_sharp = graph_test.forward(i, img_idx, poses_num, H, W, K, args, novel_view=True) target_s_novel = images_novel.reshape(-1, H*W, 3)[:, ray_idx_sharp] target_s_novel = target_s_novel.reshape(-1, 3) loss_sharp = img2mse(ret_sharp['rgb_map'], target_s_novel) psnr_sharp = mse2psnr(loss_sharp) if 'rgb0' in ret_sharp: img_loss0 = img2mse(ret_sharp['rgb0'], target_s_novel) loss_sharp = loss_sharp + img_loss0 if j%100==0: print(psnr_sharp.item(), loss_sharp.item()) optimizer_test.zero_grad() loss_sharp.backward() optimizer_test.step() decay_rate_sharp = 0.01 decay_steps_sharp = args.lrate_decay * 100 new_lrate_novel = args.pose_lrate * (decay_rate_sharp ** (j / decay_steps_sharp)) for param_group in optimizer_test.param_groups: if (j / decay_steps_sharp) <= 1.: param_group['lr'] = new_lrate_novel * args.factor_pose_novel with torch.no_grad(): imgs_render_novel = render_image_test(i, graph, poses_sharp, H, W, K, args, novel_view=True) mse_render = compute_img_metric(images_novel, imgs_render_novel, 'mse') psnr_render = compute_img_metric(images_novel, imgs_render_novel, 'psnr') ssim_render = compute_img_metric(images_novel, imgs_render_novel, 'ssim') lpips_render = compute_img_metric(images_novel, imgs_render_novel, 'lpips') with open(test_metric_file_novel, 'a') as outfile: outfile.write(f"iter{i}: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}" f" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\n") if i % args.N_iters == 0 and i > 0: # Turn on testing mode with torch.no_grad(): path_pose = os.path.join(basedir, expname) i_render_pose = torch.arange(i_train.shape[0]) * args.deblur_images + args.deblur_images // 2 render_poses_final = all_poses[i_render_pose] save_render_pose(render_poses_final, path_pose) global_step += 1 if __name__=='__main__': torch.set_default_tensor_type('torch.cuda.FloatTensor') train()
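# Example invocation, assuming a dataset prepared under ./data as referenced
# by the chosen config (configs/cozy2room.txt is the parser's default):
#   python train.py --config configs/cozy2room.txt
# and, once a checkpoint such as 200000.tar exists in the log directory:
#   python test.py --config configs/cozy2room.txt --weight_iter 200000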