Repository: rover-xingyu/L2G-NeRF
Branch: main
Commit: 08d06597a233
Files: 40
Total size: 218.7 KB
Directory structure:
gitextract_16lz1fht/
├── LICENSE
├── README.md
├── camera.py
├── data/
│ ├── base.py
│ ├── blender.py
│ ├── iphone.py
│ └── llff.py
├── evaluate.py
├── external/
│ └── pohsun_ssim/
│ ├── LICENSE.txt
│ ├── README.md
│ ├── max_ssim.py
│ ├── pytorch_ssim/
│ │ └── __init__.py
│ ├── setup.cfg
│ └── setup.py
├── extract_mesh.py
├── model/
│ ├── barf.py
│ ├── base.py
│ ├── l2g_nerf.py
│ ├── l2g_planar.py
│ ├── nerf.py
│ └── planar.py
├── options/
│ ├── barf_blender.yaml
│ ├── barf_iphone.yaml
│ ├── barf_llff.yaml
│ ├── base.yaml
│ ├── l2g_nerf_blender.yaml
│ ├── l2g_nerf_iphone.yaml
│ ├── l2g_nerf_llff.yaml
│ ├── l2g_planar.yaml
│ ├── nerf_blender.yaml
│ ├── nerf_blender_repr.yaml
│ ├── nerf_llff.yaml
│ ├── nerf_llff_repr.yaml
│ └── planar.yaml
├── options.py
├── requirements.yaml
├── train.py
├── util.py
├── util_vis.py
└── warp.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) [year] [fullname]
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# L2G-NeRF: Local-to-Global Registration for Bundle-Adjusting Neural Radiance Fields
**[Project Page](https://rover-xingyu.github.io/L2G-NeRF/) |
[Paper](https://arxiv.org/pdf/2211.11505.pdf) |
[Video](https://www.youtube.com/watch?v=y8XP9Umt6Mw)**
[Yue Chen¹](https://scholar.google.com/citations?user=M2hq1_UAAAAJ&hl=en),
[Xingyu Chen¹](https://scholar.google.com/citations?user=gDHPrWEAAAAJ&hl=en),
[Xuan Wang²](https://scholar.google.com/citations?user=h-3xd3EAAAAJ&hl=en),
[Qi Zhang³](https://scholar.google.com/citations?user=2vFjhHMAAAAJ&hl=en),
[Yu Guo¹](https://scholar.google.com/citations?user=OemeiSIAAAAJ&hl=en),
[Ying Shan³](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en),
[Fei Wang¹](https://scholar.google.com/citations?user=uU2JTpUAAAAJ&hl=en).
[¹Xi'an Jiaotong University](http://en.xjtu.edu.cn/),
[²Ant Group](https://www.antgroup.com/en),
[³Tencent AI Lab](https://ai.tencent.com/ailab/en/index/).
This repository is an official implementation of [L2G-NeRF](https://rover-xingyu.github.io/L2G-NeRF/) using [PyTorch](https://pytorch.org/).
# :computer: Installation
## Hardware
* We implement all experiments on a single NVIDIA GeForce RTX 2080 Ti GPU.
* L2G-NeRF takes about 4.5 and 8 hours to train on synthetic objects and real-world scenes, respectively, while BARF takes about 8 and 10.5 hours.
## Software
* Clone this repo by `git clone https://github.com/rover-xingyu/L2G-NeRF`
* This code is developed with Python 3; PyTorch 1.9+ is required. We recommend using [Anaconda](https://www.anaconda.com/products/individual) to set up the environment: run `conda env create --file requirements.yaml python=3` to install the dependencies, then activate the environment with `conda activate L2G-NeRF`.
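
The same steps, collected into one snippet for convenience (this assumes Anaconda is already installed and that the clone directory keeps the default repository name):
```bash
# example setup snippet; adjust paths/names to your machine
git clone https://github.com/rover-xingyu/L2G-NeRF
cd L2G-NeRF
conda env create --file requirements.yaml python=3
conda activate L2G-NeRF
```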
--------------------------------------
# :key: Training and Evaluation
## Data download
Both the Blender synthetic data and LLFF real-world data can be found in the [NeRF Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1).
For convenience, you can download them with the following script:
```bash
# Blender
gdown --id 18JxhpWD-4ZmuFKLzKlAw-w5PpzZxXOcG # download nerf_synthetic.zip
unzip nerf_synthetic.zip
rm -f nerf_synthetic.zip
mv nerf_synthetic data/blender
# LLFF
gdown --id 16VnMcF1KJYxN9QId6TClMsZRahHNMW5g # download nerf_llff_data.zip
unzip nerf_llff_data.zip
rm -f nerf_llff_data.zip
mv nerf_llff_data data/llff
```
--------------------------------------
## Running L2G-NeRF
To train and evaluate L2G-NeRF:
```bash
# <GROUP> and <NAME> can be set as you like, while <SCENE> is specific to each dataset
# NeRF (3D): Synthetic Objects
# Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})
python3 train.py \
--model=l2g_nerf --yaml=l2g_nerf_blender \
--group=exp_synthetic --name=l2g_lego \
--data.scene=lego --gpu=3 \
--data.root=/the/data/path/of/nerf_synthetic/ \
--camera.noise_r=0.07 --camera.noise_t=0.5
python3 evaluate.py \
--model=l2g_nerf --yaml=l2g_nerf_blender \
--group=exp_synthetic --name=l2g_lego \
--data.scene=lego --gpu=3 \
--data.root=/the/data/path/of/nerf_synthetic/ \
--data.val_sub= --resume
# NeRF (3D): Real-World Scenes
# LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})
python3 train.py \
--model=l2g_nerf --yaml=l2g_nerf_llff \
--group=exp_LLFF --name=l2g_fern \
--data.scene=fern --gpu=3 \
--data.root=/the/data/path/of/nerf_llff_data/ \
--loss_weight.global_alignment=2
python3 evaluate.py \
--model=l2g_nerf --yaml=l2g_nerf_llff \
--group=exp_LLFF --name=l2g_fern \
--data.scene=fern --gpu=3 \
--data.root=/the/data/path/of/nerf_llff_data/ \
--loss_weight.global_alignment=2 \
--resume
# Neural image Alignment (2D): Rigid
# use the image of “Girl With a Pearl Earring” renovation ©Koorosh Orooj (CC BY-SA 4.0) for rigid image alignment
python3 train.py \
--model=l2g_planar --yaml=l2g_planar \
--group=exp_planar --name=l2g_girl \
--warp.type=rigid --warp.dof=3 \
--data.image_fname=data/girl.jpg \
--data.image_size=[595,512] \
--data.patch_crop=[260,260] \
--seed=1 --gpu=3
# Neural image Alignment (2D): Homography
# use the image of “cat” from ImageNet for homography image alignment
python3 train.py \
--model=l2g_planar --yaml=l2g_planar \
--group=exp_planar --name=l2g_cat \
--warp.type=homography --warp.dof=8 \
--data.image_fname=data/cat.jpg \
--data.image_size=[360,480] \
--data.patch_crop=[180,180] \
--gpu=0
```
--------------------------------------
## Running BARF
If you want to train and evaluate the BARF extension of the original NeRF model that jointly optimizes poses (coarse-to-fine positional encoding):
```bash
# <GROUP> and <NAME> can be set as you like, while <SCENE> is specific to each dataset
# NeRF (3D): Synthetic Objects
# Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})
python3 train.py \
--model=barf --yaml=barf_blender \
--group=exp_synthetic --name=barf_lego \
--data.scene=lego --gpu=1 \
--data.root=/the/data/path/of/nerf_synthetic/ \
--camera.noise_r=0.07 --camera.noise_t=0.5
python3 evaluate.py \
--model=barf --yaml=barf_blender \
--group=exp_synthetic --name=barf_lego \
--data.scene=lego --gpu=1 \
--data.root=/the/data/path/of/nerf_synthetic/ \
--data.val_sub= --resume
# NeRF (3D): Real-World Scenes
# LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})
python3 train.py \
--model=barf --yaml=barf_llff \
--group=exp_LLFF --name=barf_fern \
--data.scene=fern --gpu=1 \
--data.root=/the/data/path/of/nerf_llff_data/
python3 evaluate.py \
--model=barf --yaml=barf_llff \
--group=exp_LLFF --name=barf_fern \
--data.scene=fern --gpu=1 \
--data.root=/the/data/path/of/nerf_llff_data/ \
--resume
# Neural image Alignment (2D): Rigid
# use the image of “Girl With a Pearl Earring” renovation ©Koorosh Orooj (CC BY-SA 4.0) for rigid image alignment
python3 train.py \
--model=planar --yaml=planar \
--group=exp_planar --name=barf_girl \
--warp.type=rigid --warp.dof=3 \
--data.image_fname=data/girl.jpg \
--data.image_size=[595,512] \
--data.patch_crop=[260,260] \
--seed=1 --gpu=1
# Neural image Alignment (2D): Homography
# use the image of “cat” from ImageNet for homography image alignment
python3 train.py \
--model=planar --yaml=planar \
--group=exp_planar --name=barf_cat \
--warp.type=homography --warp.dof=8 \
--data.image_fname=data/cat.jpg \
--data.image_size=[360,480] \
--data.patch_crop=[180,180] \
--gpu=2
```
--------------------------------------
## Running Naive
If you want to train and evaluate the Naive extension of the original NeRF model that jointly optimizes poses (full positional encoding):
```bash
# <GROUP> and <NAME> can be set as you like, while <SCENE> is specific to each dataset
# NeRF (3D): Synthetic Objects
# Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})
python3 train.py \
--model=barf --yaml=barf_blender \
--group=exp_synthetic --name=nerf_lego \
--data.scene=lego --gpu=2 \
--data.root=/home/cy/PNW/datasets/nerf_synthetic/ \
--barf_c2f=null \
--camera.noise_r=0.07 --camera.noise_t=0.5
python3 evaluate.py \
--model=barf --yaml=barf_blender \
--group=exp_synthetic --name=nerf_lego \
--data.scene=lego --gpu=2 \
--data.root=/home/cy/PNW/datasets/nerf_synthetic/ \
--barf_c2f=null \
--data.val_sub= --resume
# NeRF (3D): Real-World Scenes
# LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})
python3 train.py \
--model=barf --yaml=barf_llff \
--group=exp_LLFF --name=nerf_fern \
--data.scene=fern --gpu=2 \
--data.root=/home/cy/PNW/datasets/nerf_llff_data/ \
--barf_c2f=null
python3 evaluate.py \
--model=barf --yaml=barf_llff \
--group=exp_LLFF --name=nerf_fern \
--data.scene=fern --gpu=2 \
--data.root=/home/cy/PNW/datasets/nerf_llff_data/ \
--barf_c2f=null --resume
# Neural image Alignment (2D): Rigid
# use the image of “Girl With a Pearl Earring” renovation ©Koorosh Orooj (CC BY-SA 4.0) for rigid image alignment
python3 train.py \
--model=planar --yaml=planar \
--group=exp_planar --name=naive_girl \
--warp.type=rigid --warp.dof=3 \
--data.image_fname=data/girl.jpg \
--data.image_size=[595,512] \
--data.patch_crop=[260,260] \
--seed=1 --gpu=2 --barf_c2f=null
# Neural image Alignment (2D): Homography
# use the image of “cat” from ImageNet for homography image alignment
python3 train.py \
--model=planar --yaml=planar \
--group=exp_planar --name=naive_cat \
--warp.type=homography --warp.dof=8 \
--data.image_fname=data/cat.jpg \
--data.image_size=[360,480] \
--data.patch_crop=[180,180] \
--gpu=3 --barf_c2f=null
```
--------------------------------------
## Running reference NeRF
If you want to train and evaluate the reference NeRF models (assuming known camera poses):
```bash
# <GROUP> and <NAME> can be set as you like, while <SCENE> is specific to each dataset
# NeRF (3D): Synthetic Objects
# Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})
python3 train.py \
--model=nerf --yaml=nerf_blender \
--group=exp_synthetic --name=ref_lego \
--data.scene=lego --gpu=0 \
--data.root=/home/cy/PNW/datasets/nerf_synthetic/
python3 evaluate.py \
--model=nerf --yaml=nerf_blender \
--group=exp_synthetic --name=ref_lego \
--data.scene=lego --gpu=0 \
--data.root=/home/cy/PNW/datasets/nerf_synthetic/ \
--data.val_sub= --resume
# NeRF (3D): Real-World Scenes
# LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})
python3 train.py \
--model=nerf --yaml=nerf_llff \
--group=exp_LLFF --name=ref_fern \
--data.scene=fern --gpu=0 \
--data.root=/home/cy/PNW/datasets/nerf_llff_data/
python3 evaluate.py \
--model=nerf --yaml=nerf_llff \
--group=exp_LLFF --name=ref_fern \
--data.scene=fern --gpu=0 \
--data.root=/home/cy/PNW/datasets/nerf_llff_data/ \
--resume
```
--------------------------------------
# :laughing: Visualization
## Results and Videos
All the results will be stored in the directory `output/<GROUP>/<NAME>`.
You may want to organize your experiments by grouping different runs in the same group. Many videos will be created to visualize the pose optimization process and novel view synthesis.
## TensorBoard
The TensorBoard events include the following:
- **SCALARS**: the rendering losses and PSNR over the course of optimization. For L2G_NeRF/BARF/Naive, the rotational/translational errors with respect to the given poses are also computed.
- **IMAGES**: visualization of the RGB images and the RGB/depth rendering.
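
For example, a typical way to inspect these events is to point TensorBoard at the run's output directory described above (the exact `--logdir` and port are up to you):
```bash
# example only: launch TensorBoard on one experiment's output directory
tensorboard --logdir=output/<GROUP>/<NAME> --port=6006
```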
## Visdom
The visualization of 3D camera poses is provided in Visdom:
Run `visdom -port 8600` to start the Visdom server. The Visdom host server defaults to `localhost`; this can be overridden with `--visdom.server` (see `options/base.yaml` for details). If you want to disable Visdom visualization, add `--visdom!`.
## Mesh
The `extract_mesh.py` script provides a simple way to extract the underlying 3D geometry using marching cubes (supported for the Blender dataset). Run as follows:
```bash
python3 extract_mesh.py \
--model=l2g_nerf --yaml=l2g_nerf_blender \
--group=exp_synthetic --name=l2g_lego \
--data.scene=lego --gpu=3 \
--data.root=/home/cy/PNW/datasets/nerf_synthetic/ \
--data.val_sub= --resume
```
--------------------------------------
# :mag_right: Codebase structure
The main engine and network architecture in `model/l2g_nerf.py` inherit those from `model/nerf.py`.
Some tips on using and understanding the codebase:
- The computation graph for forward/backprop is stored in `var` throughout the codebase.
- The losses are stored in `loss`. To add a new loss function, just implement it in `compute_loss()` and add its weight to `opt.loss_weight.<name>`. It will automatically be added to the overall loss and logged to TensorBoard (see the sketch after this list).
- If you are using a multi-GPU machine, you can set `--gpu=<gpu_number>` to specify which GPU to use. Multi-GPU training/evaluation is currently not supported.
- To resume from a previous checkpoint, add `--resume=<ITER_NUMBER>`, or just `--resume` to resume from the latest checkpoint.
- To eliminate the global alignment objective, set `--loss_weight.global_alignment=null`; this ablation is equivalent to a local registration method.
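
As a concrete illustration of the loss convention above, here is a minimal sketch of adding a new loss term. The `my_smoothness` name and the depth-difference penalty are hypothetical examples; only the `compute_loss()` / `opt.loss_weight.<name>` / `summarize_loss()` plumbing follows the codebase.
```python
class Graph(nerf.Graph):  # e.g. in your own model file
    def compute_loss(self, opt, var, mode=None):
        loss = super().compute_loss(opt, var, mode=mode)  # keep the existing losses
        if opt.loss_weight.my_smoothness is not None:
            # any tensor stored in var by forward() can be penalized here;
            # "my_smoothness" and this depth-difference term are hypothetical
            loss.my_smoothness = (var.depth[..., 1:] - var.depth[..., :-1]).abs().mean()
        return loss
```
The corresponding weight is then set in the options YAML or on the command line (e.g. `--loss_weight.my_smoothness=<weight>`), and `summarize_loss()` folds the new term into the overall loss and the TensorBoard logs automatically.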
--------------------------------------
# Citation
If you find this project useful for your research, please use the following BibTeX entry.
```bibtex
@inproceedings{chen2023local,
title={Local-to-global registration for bundle-adjusting neural radiance fields},
author={Chen, Yue and Chen, Xingyu and Wang, Xuan and Zhang, Qi and Guo, Yu and Shan, Ying and Wang, Fei},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={8264--8273},
year={2023}
}
```
--------------------------------------
# Acknowledgements
Our code is based on the awesome PyTorch implementation of Bundle-Adjusting Neural Radiance Fields ([BARF](https://github.com/chenhsuanlin/bundle-adjusting-NeRF)). We appreciate all the contributors.
================================================
FILE: camera.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import collections
from easydict import EasyDict as edict
import util
from util import log,debug
class Pose():
"""
A class of operations on camera poses (PyTorch tensors with shape [...,3,4])
each [3,4] camera pose takes the form of [R|t]
"""
def __call__(self,R=None,t=None):
# construct a camera pose from the given R and/or t
assert(R is not None or t is not None)
if R is None:
if not isinstance(t,torch.Tensor): t = torch.tensor(t)
R = torch.eye(3,device=t.device).repeat(*t.shape[:-1],1,1)
elif t is None:
if not isinstance(R,torch.Tensor): R = torch.tensor(R)
t = torch.zeros(R.shape[:-1],device=R.device)
else:
if not isinstance(R,torch.Tensor): R = torch.tensor(R)
if not isinstance(t,torch.Tensor): t = torch.tensor(t)
assert(R.shape[:-1]==t.shape and R.shape[-2:]==(3,3))
R = R.float()
t = t.float()
pose = torch.cat([R,t[...,None]],dim=-1) # [...,3,4]
assert(pose.shape[-2:]==(3,4))
return pose
def invert(self,pose,use_inverse=False):
# invert a camera pose
R,t = pose[...,:3],pose[...,3:]
R_inv = R.inverse() if use_inverse else R.transpose(-1,-2)
t_inv = (-R_inv@t)[...,0]
pose_inv = self(R=R_inv,t=t_inv)
return pose_inv
def compose(self,pose_list):
# compose a sequence of poses together
# pose_new(x) = poseN o ... o pose2 o pose1(x)
pose_new = pose_list[0]
for pose in pose_list[1:]:
pose_new = self.compose_pair(pose_new,pose)
return pose_new
def compose_pair(self,pose_a,pose_b):
# pose_new(x) = pose_b o pose_a(x)
R_a,t_a = pose_a[...,:3],pose_a[...,3:]
R_b,t_b = pose_b[...,:3],pose_b[...,3:]
R_new = R_b@R_a
t_new = (R_b@t_a+t_b)[...,0]
pose_new = self(R=R_new,t=t_new)
return pose_new
class Lie():
"""
Lie algebra for SO(3) and SE(3) operations in PyTorch
"""
def so3_to_SO3(self,w): # [...,3]
wx = self.skew_symmetric(w)
theta = w.norm(dim=-1)[...,None,None]
I = torch.eye(3,device=w.device,dtype=torch.float32)
A = self.taylor_A(theta)
B = self.taylor_B(theta)
R = I+A*wx+B*wx@wx
return R
def SO3_to_so3(self,R,eps=1e-7): # [...,3,3]
trace = R[...,0,0]+R[...,1,1]+R[...,2,2]
theta = ((trace-1)/2).clamp(-1+eps,1-eps).acos_()[...,None,None]%np.pi # ln(R) will explode if theta==pi
lnR = 1/(2*self.taylor_A(theta)+1e-8)*(R-R.transpose(-2,-1)) # FIXME: wei-chiu finds it weird
w0,w1,w2 = lnR[...,2,1],lnR[...,0,2],lnR[...,1,0]
w = torch.stack([w0,w1,w2],dim=-1)
return w
def se3_to_SE3(self,wu): # [...,3]
w,u = wu.split([3,3],dim=-1)
wx = self.skew_symmetric(w)
theta = w.norm(dim=-1)[...,None,None]
I = torch.eye(3,device=w.device,dtype=torch.float32)
A = self.taylor_A(theta)
B = self.taylor_B(theta)
C = self.taylor_C(theta)
R = I+A*wx+B*wx@wx
V = I+B*wx+C*wx@wx
Rt = torch.cat([R,(V@u[...,None])],dim=-1)
return Rt
def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4]
R,t = Rt.split([3,1],dim=-1)
w = self.SO3_to_so3(R)
wx = self.skew_symmetric(w)
theta = w.norm(dim=-1)[...,None,None]
I = torch.eye(3,device=w.device,dtype=torch.float32)
A = self.taylor_A(theta)
B = self.taylor_B(theta)
invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx
u = (invV@t)[...,0]
wu = torch.cat([w,u],dim=-1)
return wu
def skew_symmetric(self,w):
w0,w1,w2 = w.unbind(dim=-1)
O = torch.zeros_like(w0)
wx = torch.stack([torch.stack([O,-w2,w1],dim=-1),
torch.stack([w2,O,-w0],dim=-1),
torch.stack([-w1,w0,O],dim=-1)],dim=-2)
return wx
def taylor_A(self,x,nth=10):
# Taylor expansion of sin(x)/x
ans = torch.zeros_like(x)
denom = 1.
for i in range(nth+1):
if i>0: denom *= (2*i)*(2*i+1)
ans = ans+(-1)**i*x**(2*i)/denom
return ans
def taylor_B(self,x,nth=10):
# Taylor expansion of (1-cos(x))/x**2
ans = torch.zeros_like(x)
denom = 1.
for i in range(nth+1):
denom *= (2*i+1)*(2*i+2)
ans = ans+(-1)**i*x**(2*i)/denom
return ans
def taylor_C(self,x,nth=10):
# Taylor expansion of (x-sin(x))/x**3
ans = torch.zeros_like(x)
denom = 1.
for i in range(nth+1):
denom *= (2*i+2)*(2*i+3)
ans = ans+(-1)**i*x**(2*i)/denom
return ans
class Quaternion():
def q_to_R(self,q):
# https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
qa,qb,qc,qd = q.unbind(dim=-1)
R = torch.stack([torch.stack([1-2*(qc**2+qd**2),2*(qb*qc-qa*qd),2*(qa*qc+qb*qd)],dim=-1),
torch.stack([2*(qb*qc+qa*qd),1-2*(qb**2+qd**2),2*(qc*qd-qa*qb)],dim=-1),
torch.stack([2*(qb*qd-qa*qc),2*(qa*qb+qc*qd),1-2*(qb**2+qc**2)],dim=-1)],dim=-2)
return R
def R_to_q(self,R,eps=1e-8): # [B,3,3]
# https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
# FIXME: this function seems a bit problematic, need to double-check
row0,row1,row2 = R.unbind(dim=-2)
R00,R01,R02 = row0.unbind(dim=-1)
R10,R11,R12 = row1.unbind(dim=-1)
R20,R21,R22 = row2.unbind(dim=-1)
t = R[...,0,0]+R[...,1,1]+R[...,2,2]
r = (1+t+eps).sqrt()
qa = 0.5*r
qb = (R21-R12).sign()*0.5*(1+R00-R11-R22+eps).sqrt()
qc = (R02-R20).sign()*0.5*(1-R00+R11-R22+eps).sqrt()
qd = (R10-R01).sign()*0.5*(1-R00-R11+R22+eps).sqrt()
q = torch.stack([qa,qb,qc,qd],dim=-1)
for i,qi in enumerate(q):
if torch.isnan(qi).any():
K = torch.stack([torch.stack([R00-R11-R22,R10+R01,R20+R02,R12-R21],dim=-1),
torch.stack([R10+R01,R11-R00-R22,R21+R12,R20-R02],dim=-1),
torch.stack([R20+R02,R21+R12,R22-R00-R11,R01-R10],dim=-1),
torch.stack([R12-R21,R20-R02,R01-R10,R00+R11+R22],dim=-1)],dim=-2)/3.0
K = K[i]
eigval,eigvec = torch.linalg.eigh(K)
V = eigvec[:,eigval.argmax()]
q[i] = torch.stack([V[3],V[0],V[1],V[2]])
return q
def invert(self,q):
qa,qb,qc,qd = q.unbind(dim=-1)
norm = q.norm(dim=-1,keepdim=True)
q_inv = torch.stack([qa,-qb,-qc,-qd],dim=-1)/norm**2
return q_inv
def product(self,q1,q2): # [B,4]
q1a,q1b,q1c,q1d = q1.unbind(dim=-1)
q2a,q2b,q2c,q2d = q2.unbind(dim=-1)
hamil_prod = torch.stack([q1a*q2a-q1b*q2b-q1c*q2c-q1d*q2d,
q1a*q2b+q1b*q2a+q1c*q2d-q1d*q2c,
q1a*q2c-q1b*q2d+q1c*q2a+q1d*q2b,
q1a*q2d+q1b*q2c-q1c*q2b+q1d*q2a],dim=-1)
return hamil_prod
pose = Pose()
lie = Lie()
quaternion = Quaternion()
def to_hom(X):
# get homogeneous coordinates of the input
X_hom = torch.cat([X,torch.ones_like(X[...,:1])],dim=-1)
return X_hom
# basic operations of transforming 3D points between world/camera/image coordinates
def world2cam(X,pose): # [B,N,3]
X_hom = to_hom(X)
return X_hom@pose.transpose(-1,-2)
def cam2img(X,cam_intr):
return X@cam_intr.transpose(-1,-2)
def img2cam(X,cam_intr):
return X@cam_intr.inverse().transpose(-1,-2)
def cam2world(X,pose):
X_hom = to_hom(X)
pose_inv = Pose().invert(pose)
return X_hom@pose_inv.transpose(-1,-2)
def angle_to_rotation_matrix(a,axis):
# get the rotation matrix from Euler angle around specific axis
roll = dict(X=1,Y=2,Z=0)[axis]
O = torch.zeros_like(a)
I = torch.ones_like(a)
M = torch.stack([torch.stack([a.cos(),-a.sin(),O],dim=-1),
torch.stack([a.sin(),a.cos(),O],dim=-1),
torch.stack([O,O,I],dim=-1)],dim=-2)
M = M.roll((roll,roll),dims=(-2,-1))
return M
def get_center_and_ray(opt,pose,intr=None): # [HW,2]
# given the intrinsic/extrinsic matrices, get the camera center and ray directions
assert(opt.camera.model=="perspective")
with torch.no_grad():
# compute image coordinate grid
y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5)
x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5)
Y,X = torch.meshgrid(y_range,x_range) # [H,W]
xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]
# compute center and ray
batch_size = len(pose)
xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2]
grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3]
center_3D = torch.zeros_like(grid_3D) # [B,HW,3]
# transform from camera to world coordinates
grid_3D = cam2world(grid_3D,pose) # [B,HW,3]
center_3D = cam2world(center_3D,pose) # [B,HW,3]
ray = grid_3D-center_3D # [B,HW,3]
return center_3D,ray
def get_camera_cords_grid_3D(opt,batch_size,intr=None,ray_idx=None): # [HW,2]
# given the intrinsic matrices, get the grid_3D
assert(opt.camera.model=="perspective")
with torch.no_grad():
# compute image coordinate grid
y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5)
x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5)
Y,X = torch.meshgrid(y_range,x_range) # [H,W]
xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]
# compute grid_3D
xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2]
grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3]
if ray_idx is not None:
# consider only subset of rays
grid_3D = grid_3D[:,ray_idx]
return grid_3D
def gather_camera_cords_grid_3D(opt,batch_size,intr=None,ray_idx=None): # [HW,2]
# given the intrinsic matrices, get the grid_3D
assert(opt.camera.model=="perspective")
with torch.no_grad():
# compute image coordinate grid
y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5)
x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5)
Y,X = torch.meshgrid(y_range,x_range) # [H,W]
xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]
# compute grid_3D
xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2]
grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3]
if ray_idx is not None:
# consider only subset of rays
grid_3D = torch.gather(grid_3D, 1, ray_idx[...,None].expand(-1,-1,3))
return grid_3D
def get_3D_points_from_depth(opt,center,ray,depth,multi_samples=False):
if multi_samples: center,ray = center[:,:,None],ray[:,:,None]
# x = c+dv
points_3D = center+ray*depth # [B,HW,3]/[B,HW,N,3]/[N,3]
return points_3D
def convert_NDC(opt,center,ray,intr,near=1):
# shift camera center (ray origins) to near plane (z=1)
# (unlike conventional NDC, we assume the cameras are facing towards the +z direction)
center = center+(near-center[...,2:])/ray[...,2:]*ray
# projection
cx,cy,cz = center.unbind(dim=-1) # [B,HW]
rx,ry,rz = ray.unbind(dim=-1) # [B,HW]
scale_x = intr[:,0,0]/intr[:,0,2] # [B]
scale_y = intr[:,1,1]/intr[:,1,2] # [B]
cnx = scale_x[:,None]*(cx/cz)
cny = scale_y[:,None]*(cy/cz)
cnz = 1-2*near/cz
rnx = scale_x[:,None]*(rx/rz-cx/cz)
rny = scale_y[:,None]*(ry/rz-cy/cz)
rnz = 2*near/cz
center_ndc = torch.stack([cnx,cny,cnz],dim=-1) # [B,HW,3]
ray_ndc = torch.stack([rnx,rny,rnz],dim=-1) # [B,HW,3]
return center_ndc,ray_ndc
def rotation_distance(R1,R2,eps=1e-7):
# http://www.boris-belousov.net/2016/12/01/quat-dist/
R_diff = R1@R2.transpose(-2,-1)
trace = R_diff[...,0,0]+R_diff[...,1,1]+R_diff[...,2,2]
angle = ((trace-1)/2).clamp(-1+eps,1-eps).acos_() # numerical stability near -1/+1
return angle
def procrustes_analysis(X0,X1): # [N,3]
# translation
t0 = X0.mean(dim=0,keepdim=True)
t1 = X1.mean(dim=0,keepdim=True)
X0c = X0-t0
X1c = X1-t1
# scale
s0 = (X0c**2).sum(dim=-1).mean().sqrt()
s1 = (X1c**2).sum(dim=-1).mean().sqrt()
X0cs = X0c/s0
X1cs = X1c/s1
# rotation (use double for SVD, float loses precision)
U,S,V = (X0cs.t()@X1cs).double().svd(some=True)
R = (U@V.t()).float()
if R.det()<0: R[2] *= -1
# align X1 to X0: X1to0 = (X1-t1)/s1@R.t()*s0+t0
sim3 = edict(t0=t0[0],t1=t1[0],s0=s0,s1=s1,R=R)
return sim3
def get_novel_view_poses(opt,pose_anchor,N=60,scale=1):
# create circular viewpoints (small oscillations)
theta = torch.arange(N)/N*2*np.pi
R_x = angle_to_rotation_matrix((theta.sin()*0.05).asin(),"X")
R_y = angle_to_rotation_matrix((theta.cos()*0.05).asin(),"Y")
pose_rot = pose(R=R_y@R_x)
pose_shift = pose(t=[0,0,-4*scale])
pose_shift2 = pose(t=[0,0,3.8*scale])
pose_oscil = pose.compose([pose_shift,pose_rot,pose_shift2])
pose_novel = pose.compose([pose_oscil,pose_anchor.cpu()[None]])
return pose_novel
================================================
FILE: data/base.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import torch.multiprocessing as mp
import PIL
import tqdm
import threading,queue
from easydict import EasyDict as edict
import util
from util import log,debug
class Dataset(torch.utils.data.Dataset):
def __init__(self,opt,split="train"):
super().__init__()
self.opt = opt
self.split = split
self.augment = split=="train" and opt.data.augment
# define image sizes
if opt.data.center_crop is not None:
self.crop_H = int(self.raw_H*opt.data.center_crop)
self.crop_W = int(self.raw_W*opt.data.center_crop)
else: self.crop_H,self.crop_W = self.raw_H,self.raw_W
if not opt.H or not opt.W:
opt.H,opt.W = self.crop_H,self.crop_W
def setup_loader(self,opt,shuffle=False,drop_last=False):
loader = torch.utils.data.DataLoader(self,
batch_size=opt.batch_size or 1,
num_workers=opt.data.num_workers,
shuffle=shuffle,
drop_last=drop_last,
pin_memory=False, # spews warnings in PyTorch 1.9 but should be True in general
)
print("number of samples: {}".format(len(self)))
return loader
def get_list(self,opt):
raise NotImplementedError
def preload_worker(self,data_list,load_func,q,lock,idx_tqdm):
while True:
idx = q.get()
data_list[idx] = load_func(self.opt,idx)
with lock:
idx_tqdm.update()
q.task_done()
def preload_threading(self,opt,load_func,data_str="images"):
data_list = [None]*len(self)
q = queue.Queue(maxsize=len(self))
idx_tqdm = tqdm.tqdm(range(len(self)),desc="preloading {}".format(data_str),leave=False)
for i in range(len(self)): q.put(i)
lock = threading.Lock()
for ti in range(opt.data.num_workers):
t = threading.Thread(target=self.preload_worker,
args=(data_list,load_func,q,lock,idx_tqdm),daemon=True)
t.start()
q.join()
idx_tqdm.close()
assert(all(map(lambda x: x is not None,data_list)))
return data_list
def __getitem__(self,idx):
raise NotImplementedError
def get_image(self,opt,idx):
raise NotImplementedError
def generate_augmentation(self,opt):
brightness = opt.data.augment.brightness or 0.
contrast = opt.data.augment.contrast or 0.
saturation = opt.data.augment.saturation or 0.
hue = opt.data.augment.hue or 0.
color_jitter = torchvision.transforms.ColorJitter.get_params(
brightness=(1-brightness,1+brightness),
contrast=(1-contrast,1+contrast),
saturation=(1-saturation,1+saturation),
hue=(-hue,hue),
)
aug = edict(
color_jitter=color_jitter,
flip=np.random.randn()>0 if opt.data.augment.hflip else False,
rot_angle=(np.random.rand()*2-1)*opt.data.augment.rotate if opt.data.augment.rotate else 0,
)
return aug
def preprocess_image(self,opt,image,aug=None):
if aug is not None:
image = self.apply_color_jitter(opt,image,aug.color_jitter)
image = torchvision_F.hflip(image) if aug.flip else image
image = image.rotate(aug.rot_angle,resample=PIL.Image.BICUBIC)
# center crop
if opt.data.center_crop is not None:
self.crop_H = int(self.raw_H*opt.data.center_crop)
self.crop_W = int(self.raw_W*opt.data.center_crop)
image = torchvision_F.center_crop(image,(self.crop_H,self.crop_W))
else: self.crop_H,self.crop_W = self.raw_H,self.raw_W
# resize
if opt.data.image_size[0] is not None:
image = image.resize((opt.W,opt.H))
image = torchvision_F.to_tensor(image)
return image
def preprocess_camera(self,opt,intr,pose,aug=None):
intr,pose = intr.clone(),pose.clone()
# center crop
intr[0,2] -= (self.raw_W-self.crop_W)/2
intr[1,2] -= (self.raw_H-self.crop_H)/2
# resize
intr[0] *= opt.W/self.crop_W
intr[1] *= opt.H/self.crop_H
return intr,pose
def apply_color_jitter(self,opt,image,color_jitter):
mode = image.mode
if mode!="L":
chan = image.split()
rgb = PIL.Image.merge("RGB",chan[:3])
rgb = color_jitter(rgb)
rgb_chan = rgb.split()
image = PIL.Image.merge(mode,rgb_chan+chan[3:])
return image
def __len__(self):
return len(self.list)
================================================
FILE: data/blender.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import PIL
import imageio
from easydict import EasyDict as edict
import json
import pickle
from . import base
import camera
from util import log,debug
class Dataset(base.Dataset):
def __init__(self,opt,split="train",subset=None):
self.raw_H,self.raw_W = 800,800
super().__init__(opt,split)
self.root = opt.data.root or "data/blender"
self.path = "{}/{}".format(self.root,opt.data.scene)
# load/parse metadata
meta_fname = "{}/transforms_{}.json".format(self.path,split)
with open(meta_fname) as file:
self.meta = json.load(file)
self.list = self.meta["frames"]
self.focal = 0.5*self.raw_W/np.tan(0.5*self.meta["camera_angle_x"])
if subset: self.list = self.list[:subset]
# preload dataset
if opt.data.preload:
self.images = self.preload_threading(opt,self.get_image)
self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras")
def prefetch_all_data(self,opt):
assert(not opt.data.augment)
# pre-iterate through all samples and group together
self.all = torch.utils.data._utils.collate.default_collate([s for s in self])
def get_all_camera_poses(self,opt):
pose_raw_all = [torch.tensor(f["transform_matrix"],dtype=torch.float32) for f in self.list]
pose_canon_all = torch.stack([self.parse_raw_camera(opt,p) for p in pose_raw_all],dim=0)
return pose_canon_all
def __getitem__(self,idx):
opt = self.opt
sample = dict(idx=idx)
aug = self.generate_augmentation(opt) if self.augment else None
image = self.images[idx] if opt.data.preload else self.get_image(opt,idx)
image = self.preprocess_image(opt,image,aug=aug)
intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx)
intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug)
sample.update(
image=image,
intr=intr,
pose=pose,
)
return sample
def get_image(self,opt,idx):
image_fname = "{}/{}.png".format(self.path,self.list[idx]["file_path"])
image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption....
return image
def preprocess_image(self,opt,image,aug=None):
image = super().preprocess_image(opt,image,aug=aug)
rgb,mask = image[:3],image[3:]
if opt.data.bgcolor is not None:
rgb = rgb*mask+opt.data.bgcolor*(1-mask)
return rgb
def get_camera(self,opt,idx):
intr = torch.tensor([[self.focal,0,self.raw_W/2],
[0,self.focal,self.raw_H/2],
[0,0,1]]).float()
pose_raw = torch.tensor(self.list[idx]["transform_matrix"],dtype=torch.float32)
pose = self.parse_raw_camera(opt,pose_raw)
return intr,pose
def parse_raw_camera(self,opt,pose_raw):
pose_flip = camera.pose(R=torch.diag(torch.tensor([1,-1,-1])))
pose = camera.pose.compose([pose_flip,pose_raw[:3]])
pose = camera.pose.invert(pose)
return pose
================================================
FILE: data/iphone.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import PIL
import imageio
from easydict import EasyDict as edict
import json
import pickle
from . import base
import camera
from util import log,debug
class Dataset(base.Dataset):
def __init__(self,opt,split="train",subset=None):
self.raw_H,self.raw_W = 1080,1920
super().__init__(opt,split)
self.root = opt.data.root or "data/iphone"
self.path = "{}/{}".format(self.root,opt.data.scene)
self.path_image = "{}/images".format(self.path)
# self.list = sorted(os.listdir(self.path_image),key=lambda f: int(f.split(".")[0]))
self.list = os.listdir(self.path_image)
# manually split train/val subsets
num_val_split = int(len(self)*opt.data.val_ratio)
self.list = self.list[:-num_val_split] if split=="train" else self.list[-num_val_split:]
if subset: self.list = self.list[:subset]
# preload dataset
if opt.data.preload:
self.images = self.preload_threading(opt,self.get_image)
self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras")
def prefetch_all_data(self,opt):
assert(not opt.data.augment)
# pre-iterate through all samples and group together
self.all = torch.utils.data._utils.collate.default_collate([s for s in self])
def get_all_camera_poses(self,opt):
# poses are unknown, so just return some dummy poses (identity transform)
return camera.pose(t=torch.zeros(len(self),3))
def __getitem__(self,idx):
opt = self.opt
sample = dict(idx=idx)
aug = self.generate_augmentation(opt) if self.augment else None
image = self.images[idx] if opt.data.preload else self.get_image(opt,idx)
image = self.preprocess_image(opt,image,aug=aug)
intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx)
intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug)
sample.update(
image=image,
intr=intr,
pose=pose,
)
return sample
def get_image(self,opt,idx):
image_fname = "{}/{}".format(self.path_image,self.list[idx])
image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption....
return image
def get_camera(self,opt,idx):
self.focal = self.raw_W*4.2/(12.8/2.55)
intr = torch.tensor([[self.focal,0,self.raw_W/2],
[0,self.focal,self.raw_H/2],
[0,0,1]]).float()
pose = camera.pose(t=torch.zeros(3)) # dummy pose, won't be used
return intr,pose
================================================
FILE: data/llff.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import PIL
import imageio
from easydict import EasyDict as edict
import json
import pickle
from . import base
import camera
from util import log,debug
class Dataset(base.Dataset):
def __init__(self,opt,split="train",subset=None):
self.raw_H,self.raw_W = 3024,4032
super().__init__(opt,split)
self.root = opt.data.root or "data/llff"
self.path = "{}/{}".format(self.root,opt.data.scene)
self.path_image = "{}/images".format(self.path)
image_fnames = sorted(os.listdir(self.path_image))
poses_raw,bounds = self.parse_cameras_and_bounds(opt)
self.list = list(zip(image_fnames,poses_raw,bounds))
# manually split train/val subsets
num_val_split = int(len(self)*opt.data.val_ratio)
self.list = self.list[:-num_val_split] if split=="train" else self.list[-num_val_split:]
if subset: self.list = self.list[:subset]
# preload dataset
if opt.data.preload:
self.images = self.preload_threading(opt,self.get_image)
self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras")
def prefetch_all_data(self,opt):
assert(not opt.data.augment)
# pre-iterate through all samples and group together
self.all = torch.utils.data._utils.collate.default_collate([s for s in self])
def parse_cameras_and_bounds(self,opt):
fname = "{}/poses_bounds.npy".format(self.path)
data = torch.tensor(np.load(fname),dtype=torch.float32)
# parse cameras (intrinsics and poses)
cam_data = data[:,:-2].view([-1,3,5]) # [N,3,5]
poses_raw = cam_data[...,:4] # [N,3,4]
poses_raw[...,0],poses_raw[...,1] = poses_raw[...,1],-poses_raw[...,0]
raw_H,raw_W,self.focal = cam_data[0,:,-1]
assert(self.raw_H==raw_H and self.raw_W==raw_W)
# parse depth bounds
bounds = data[:,-2:] # [N,2]
scale = 1./(bounds.min()*0.75) # not sure how this was determined
poses_raw[...,3] *= scale
bounds *= scale
# roughly center camera poses
poses_raw = self.center_camera_poses(opt,poses_raw)
return poses_raw,bounds
def center_camera_poses(self,opt,poses):
# compute average pose
center = poses[...,3].mean(dim=0)
v1 = torch_F.normalize(poses[...,1].mean(dim=0),dim=0)
v2 = torch_F.normalize(poses[...,2].mean(dim=0),dim=0)
v0 = v1.cross(v2)
pose_avg = torch.stack([v0,v1,v2,center],dim=-1)[None] # [1,3,4]
# apply inverse of averaged pose
poses = camera.pose.compose([poses,camera.pose.invert(pose_avg)])
return poses
def get_all_camera_poses(self,opt):
pose_raw_all = [tup[1] for tup in self.list]
pose_all = torch.stack([self.parse_raw_camera(opt,p) for p in pose_raw_all],dim=0)
return pose_all
def __getitem__(self,idx):
opt = self.opt
sample = dict(idx=idx)
aug = self.generate_augmentation(opt) if self.augment else None
image = self.images[idx] if opt.data.preload else self.get_image(opt,idx)
image = self.preprocess_image(opt,image,aug=aug)
intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx)
intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug)
sample.update(
image=image,
intr=intr,
pose=pose,
)
return sample
def get_image(self,opt,idx):
image_fname = "{}/{}".format(self.path_image,self.list[idx][0])
image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption....
return image
def get_camera(self,opt,idx):
intr = torch.tensor([[self.focal,0,self.raw_W/2],
[0,self.focal,self.raw_H/2],
[0,0,1]]).float()
pose_raw = self.list[idx][1]
pose = self.parse_raw_camera(opt,pose_raw)
return intr,pose
def parse_raw_camera(self,opt,pose_raw):
pose_flip = camera.pose(R=torch.diag(torch.tensor([1,-1,-1])))
pose = camera.pose.compose([pose_flip,pose_raw[:3]])
pose = camera.pose.invert(pose)
pose = camera.pose.compose([pose_flip,pose])
return pose
================================================
FILE: evaluate.py
================================================
import numpy as np
import os,sys,time
import torch
import importlib
import options
from util import log
import warnings
warnings.filterwarnings('ignore')
def main():
log.process(os.getpid())
log.title("[{}] (PyTorch code for evaluating NeRF/BARF/L2G_NeRF)".format(sys.argv[0]))
opt_cmd = options.parse_arguments(sys.argv[1:])
opt = options.set(opt_cmd=opt_cmd)
with torch.cuda.device(opt.device):
model = importlib.import_module("model.{}".format(opt.model))
m = model.Model(opt)
m.load_dataset(opt,eval_split="test")
m.build_networks(opt)
if opt.model in ["barf", "l2g_nerf"]:
m.generate_videos_pose(opt)
m.restore_checkpoint(opt)
if opt.data.dataset in ["blender","llff"]:
m.evaluate_full(opt)
m.generate_videos_synthesis(opt)
if __name__=="__main__":
main()
================================================
FILE: external/pohsun_ssim/LICENSE.txt
================================================
MIT
================================================
FILE: external/pohsun_ssim/README.md
================================================
# pytorch-ssim
### Differentiable structural similarity (SSIM) index.
 
## Installation
1. Clone this repo.
2. Copy "pytorch_ssim" folder in your project.
## Example
### basic usage
```python
import pytorch_ssim
import torch
from torch.autograd import Variable
img1 = Variable(torch.rand(1, 1, 256, 256))
img2 = Variable(torch.rand(1, 1, 256, 256))
if torch.cuda.is_available():
img1 = img1.cuda()
img2 = img2.cuda()
print(pytorch_ssim.ssim(img1, img2))
ssim_loss = pytorch_ssim.SSIM(window_size = 11)
print(ssim_loss(img1, img2))
```
### maximize ssim
```python
import pytorch_ssim
import torch
from torch.autograd import Variable
from torch import optim
import cv2
import numpy as np
npImg1 = cv2.imread("einstein.png")
img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0
img2 = torch.rand(img1.size())
if torch.cuda.is_available():
img1 = img1.cuda()
img2 = img2.cuda()
img1 = Variable( img1, requires_grad=False)
img2 = Variable( img2, requires_grad = True)
# Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True)
ssim_value = pytorch_ssim.ssim(img1, img2).data[0]
print("Initial ssim:", ssim_value)
# Module: pytorch_ssim.SSIM(window_size = 11, size_average = True)
ssim_loss = pytorch_ssim.SSIM()
optimizer = optim.Adam([img2], lr=0.01)
while ssim_value < 0.95:
optimizer.zero_grad()
ssim_out = -ssim_loss(img1, img2)
ssim_value = - ssim_out.data[0]
print(ssim_value)
ssim_out.backward()
optimizer.step()
```
## Reference
https://ece.uwaterloo.ca/~z70wang/research/ssim/
================================================
FILE: external/pohsun_ssim/max_ssim.py
================================================
import pytorch_ssim
import torch
from torch.autograd import Variable
from torch import optim
import cv2
import numpy as np
npImg1 = cv2.imread("einstein.png")
img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0
img2 = torch.rand(img1.size())
if torch.cuda.is_available():
img1 = img1.cuda()
img2 = img2.cuda()
img1 = Variable( img1, requires_grad=False)
img2 = Variable( img2, requires_grad = True)
# Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True)
ssim_value = pytorch_ssim.ssim(img1, img2).data[0]
print("Initial ssim:", ssim_value)
# Module: pytorch_ssim.SSIM(window_size = 11, size_average = True)
ssim_loss = pytorch_ssim.SSIM()
optimizer = optim.Adam([img2], lr=0.01)
while ssim_value < 0.95:
optimizer.zero_grad()
ssim_out = -ssim_loss(img1, img2)
ssim_value = - ssim_out.data[0]
print(ssim_value)
ssim_out.backward()
optimizer.step()
================================================
FILE: external/pohsun_ssim/pytorch_ssim/__init__.py
================================================
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()
def create_window(window_size, channel):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
return window
def _ssim(img1, img2, window, window_size, channel, size_average = True):
mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1*mu2
sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
def __init__(self, window_size = 11, size_average = True):
super(SSIM, self).__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 1
self.window = create_window(window_size, self.channel)
def forward(self, img1, img2):
(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.data.type() == img1.data.type():
window = self.window
else:
window = create_window(self.window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
self.window = window
self.channel = channel
return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def ssim(img1, img2, window_size = 11, size_average = True):
(_, channel, _, _) = img1.size()
window = create_window(window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
return _ssim(img1, img2, window, window_size, channel, size_average)
================================================
FILE: external/pohsun_ssim/setup.cfg
================================================
[metadata]
description-file = README.md
================================================
FILE: external/pohsun_ssim/setup.py
================================================
from distutils.core import setup
setup(
name = 'pytorch_ssim',
packages = ['pytorch_ssim'], # this must be the same as the name above
version = '0.1',
description = 'Differentiable structural similarity (SSIM) index',
author = 'Po-Hsun (Evan) Su',
author_email = 'evan.pohsun.su@gmail.com',
url = 'https://github.com/Po-Hsun-Su/pytorch-ssim', # use the URL to the github repo
download_url = 'https://github.com/Po-Hsun-Su/pytorch-ssim/archive/0.1.tar.gz', # I'll explain this in a second
keywords = ['pytorch', 'image-processing', 'deep-learning'], # arbitrary keywords
classifiers = [],
)
================================================
FILE: extract_mesh.py
================================================
"""Extracts a 3D mesh from a pretrained model using marching cubes."""
import importlib
import sys
import numpy as np
import options
import torch
import tqdm
import trimesh
import mcubes
from util import log,debug
opt_cmd = options.parse_arguments(sys.argv[1:])
opt = options.set(opt_cmd=opt_cmd)
with torch.cuda.device(opt.device),torch.no_grad():
model = importlib.import_module("model.{}".format(opt.model))
m = model.Model(opt)
m.load_dataset(opt)
m.build_networks(opt)
m.restore_checkpoint(opt)
t = torch.linspace(*opt.trimesh.range,opt.trimesh.res+1) # the best range might vary from model to model
query = torch.stack(torch.meshgrid(t,t,t),dim=-1)
query_flat = query.view(-1,3)
density_all = []
for i in tqdm.trange(0,len(query_flat),opt.trimesh.chunk_size,leave=False):
points = query_flat[None,i:i+opt.trimesh.chunk_size].to(opt.device)
ray_unit = torch.zeros_like(points) # dummy ray to comply with interface, not used
_,density_samples = m.graph.nerf.forward(opt,points,ray_unit=ray_unit,mode=None)
density_all.append(density_samples.cpu())
density_all = torch.cat(density_all,dim=1)[0]
density_all = density_all.view(*query.shape[:-1]).numpy()
log.info("running marching cubes...")
vertices,triangles = mcubes.marching_cubes(density_all,opt.trimesh.thres)
vertices_centered = vertices/opt.trimesh.res-0.5
mesh = trimesh.Trimesh(vertices_centered,triangles)
obj_fname = "{}/mesh.obj".format(opt.output_path)
log.info("saving 3D mesh to {}...".format(obj_fname))
mesh.export(obj_fname)
================================================
FILE: model/barf.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import tqdm
from easydict import EasyDict as edict
import visdom
import matplotlib.pyplot as plt
import util,util_vis
from util import log,debug
from . import nerf
import camera
# ============================ main engine for training and evaluation ============================
class Model(nerf.Model):
def __init__(self,opt):
super().__init__(opt)
def build_networks(self,opt):
super().build_networks(opt)
if opt.camera.noise:
# pre-generate synthetic pose perturbation
so3_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_r
t_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_t
self.graph.pose_noise = torch.cat([camera.lie.so3_to_SO3(so3_noise),t_noise[...,None]],dim=-1) # [...,3,4]
self.graph.se3_refine = torch.nn.Embedding(len(self.train_data),6).to(opt.device)
torch.nn.init.zeros_(self.graph.se3_refine.weight)
pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)
# add synthetic pose perturbation to all training data
if opt.data.dataset=="blender":
pose = pose_GT
if opt.camera.noise:
pose = camera.pose.compose([pose, self.graph.pose_noise])
else: pose = self.graph.pose_eye[None].repeat(len(self.train_data),1,1)
# use Embedding so it could be checkpointed
self.graph.optimised_training_poses = torch.nn.Embedding(len(self.train_data),12,_weight=pose.view(-1,12)).to(opt.device)
idx_range = torch.arange(len(self.train_data),dtype=torch.long,device=opt.device)
idx_X,idx_Y = torch.meshgrid(idx_range,idx_range)
self.graph.idx_grid = torch.stack([idx_X,idx_Y],dim=-1).view(-1,2)
def setup_optimizer(self,opt):
super().setup_optimizer(opt)
optimizer = getattr(torch.optim,opt.optim.algo)
self.optim_pose = optimizer([dict(params=self.graph.se3_refine.parameters(),lr=opt.optim.lr_pose)])
# set up scheduler
if opt.optim.sched_pose:
scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched_pose.type)
if opt.optim.lr_pose_end:
assert(opt.optim.sched_pose.type=="ExponentialLR")
opt.optim.sched_pose.gamma = (opt.optim.lr_pose_end/opt.optim.lr_pose)**(1./opt.max_iter)
kwargs = { k:v for k,v in opt.optim.sched_pose.items() if k!="type" }
self.sched_pose = scheduler(self.optim_pose,**kwargs)
def train_iteration(self,opt,var,loader):
self.optim_pose.zero_grad()
if opt.optim.warmup_pose:
# simple linear warmup of pose learning rate
self.optim_pose.param_groups[0]["lr_orig"] = self.optim_pose.param_groups[0]["lr"] # cache the original learning rate
self.optim_pose.param_groups[0]["lr"] *= min(1,self.it/opt.optim.warmup_pose)
loss = super().train_iteration(opt,var,loader)
self.optim_pose.step()
if opt.optim.warmup_pose:
self.optim_pose.param_groups[0]["lr"] = self.optim_pose.param_groups[0]["lr_orig"] # reset learning rate
if opt.optim.sched_pose: self.sched_pose.step()
self.graph.nerf.progress.data.fill_(self.it/opt.max_iter)
if opt.nerf.fine_sampling:
self.graph.nerf_fine.progress.data.fill_(self.it/opt.max_iter)
return loss
@torch.no_grad()
def validate(self,opt,ep=None):
pose,pose_GT = self.get_all_training_poses(opt)
_,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT)
super().validate(opt,ep=ep)
@torch.no_grad()
def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)
if split=="train":
# log learning rate
lr = self.optim_pose.param_groups[0]["lr"]
self.tb.add_scalar("{0}/{1}".format(split,"lr_pose"),lr,step)
# compute pose error
if split=="train" and opt.data.dataset in ["blender","llff"]:
pose,pose_GT = self.get_all_training_poses(opt)
pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT)
error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT)
self.tb.add_scalar("{0}/error_R".format(split),error.R.mean(),step)
self.tb.add_scalar("{0}/error_t".format(split),error.t.mean(),step)
@torch.no_grad()
def visualize(self,opt,var,step=0,split="train"):
super().visualize(opt,var,step=step,split=split)
if opt.visdom:
if split=="val":
pose,pose_GT = self.get_all_training_poses(opt)
pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT)
util_vis.vis_cameras(opt,self.vis,step=step,poses=[pose_aligned,pose_GT])
@torch.no_grad()
def get_all_training_poses(self,opt):
# get ground-truth (canonical) camera poses
pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)
pose = self.graph.optimised_training_poses.weight.data.detach().clone().view(-1,3,4)
return pose,pose_GT
@torch.no_grad()
def prealign_cameras(self,opt,pose,pose_GT):
# compute 3D similarity transform via Procrustes analysis
center = torch.zeros(1,1,3,device=opt.device)
center_pred = camera.cam2world(center,pose)[:,0] # [N,3]
center_GT = camera.cam2world(center,pose_GT)[:,0] # [N,3]
try:
sim3 = camera.procrustes_analysis(center_GT,center_pred)
except:
print("warning: SVD did not converge...")
sim3 = edict(t0=0,t1=0,s0=1,s1=1,R=torch.eye(3,device=opt.device))
# align the camera poses
center_aligned = (center_pred-sim3.t1)/sim3.s1@sim3.R.t()*sim3.s0+sim3.t0
R_aligned = pose[...,:3]@sim3.R.t()
t_aligned = (-R_aligned@center_aligned[...,None])[...,0]
pose_aligned = camera.pose(R=R_aligned,t=t_aligned)
return pose_aligned,sim3
@torch.no_grad()
def evaluate_camera_alignment(self,opt,pose_aligned,pose_GT):
# measure errors in rotation and translation
R_aligned,t_aligned = pose_aligned.split([3,1],dim=-1)
R_GT,t_GT = pose_GT.split([3,1],dim=-1)
R_error = camera.rotation_distance(R_aligned,R_GT)
t_error = (t_aligned-t_GT)[...,0].norm(dim=-1)
error = edict(R=R_error,t=t_error)
return error
@torch.no_grad()
def evaluate_full(self,opt):
self.graph.eval()
# evaluate rotation/translation
pose,pose_GT = self.get_all_training_poses(opt)
pose_aligned,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT)
error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT)
print("--------------------------")
print("rot: {:8.3f}".format(np.rad2deg(error.R.mean().cpu())))
print("trans: {:10.5f}".format(error.t.mean()))
print("--------------------------")
# dump numbers
quant_fname = "{}/quant_pose.txt".format(opt.output_path)
with open(quant_fname,"w") as file:
for i,(err_R,err_t) in enumerate(zip(error.R,error.t)):
file.write("{} {} {}\n".format(i,err_R.item(),err_t.item()))
# evaluate novel view synthesis
super().evaluate_full(opt)
@torch.enable_grad()
def evaluate_test_time_photometric_optim(self,opt,var):
# use another se3 Parameter to absorb the remaining pose errors
var.se3_refine_test = torch.nn.Parameter(torch.zeros(1,6,device=opt.device))
optimizer = getattr(torch.optim,opt.optim.algo)
optim_pose = optimizer([dict(params=[var.se3_refine_test],lr=opt.optim.lr_pose)])
iterator = tqdm.trange(opt.optim.test_iter,desc="test-time optim.",leave=False,position=1)
for it in iterator:
optim_pose.zero_grad()
var.pose_refine_test = camera.lie.se3_to_SE3(var.se3_refine_test)
var = self.graph.forward(opt,var,mode="test-optim")
loss = self.graph.compute_loss(opt,var,mode="test-optim")
loss = self.summarize_loss(opt,var,loss)
loss.all.backward()
optim_pose.step()
iterator.set_postfix(loss="{:.3f}".format(loss.all))
return var
@torch.no_grad()
def generate_videos_pose(self,opt):
self.graph.eval()
fig = plt.figure(figsize=(10,10) if opt.data.dataset=="blender" else (16,8))
cam_path = "{}/poses".format(opt.output_path)
os.makedirs(cam_path,exist_ok=True)
ep_list = []
for ep in range(0,opt.max_iter+1,opt.freq.ckpt):
# load checkpoint (0 is random init)
if ep!=0:
try: util.restore_checkpoint(opt,self,resume=ep)
except: continue
# get the camera poses
pose,pose_ref = self.get_all_training_poses(opt)
if opt.data.dataset in ["blender","llff"]:
pose_aligned,_ = self.prealign_cameras(opt,pose,pose_ref)
pose_aligned,pose_ref = pose_aligned.detach().cpu(),pose_ref.detach().cpu()
dict(
blender=util_vis.plot_save_poses_blender,
llff=util_vis.plot_save_poses,
)[opt.data.dataset](opt,fig,pose_aligned,pose_ref=pose_ref,path=cam_path,ep=ep)
else:
pose = pose.detach().cpu()
util_vis.plot_save_poses(opt,fig,pose,pose_ref=None,path=cam_path,ep=ep)
ep_list.append(ep)
plt.close()
# write videos
print("writing videos...")
list_fname = "{}/temp.list".format(cam_path)
with open(list_fname,"w") as file:
for ep in ep_list: file.write("file {}.png\n".format(ep))
cam_vid_fname = "{}/poses.mp4".format(opt.output_path)
os.system("ffmpeg -y -r 4 -f concat -i {0} -pix_fmt yuv420p {1} >/dev/null 2>&1".format(list_fname,cam_vid_fname))
os.remove(list_fname)
# ============================ computation graph for forward/backprop ============================
class Graph(nerf.Graph):
def __init__(self,opt):
super().__init__(opt)
self.nerf = NeRF(opt)
if opt.nerf.fine_sampling:
self.nerf_fine = NeRF(opt)
self.pose_eye = torch.eye(3,4).to(opt.device)
def forward(self,opt,var,mode=None):
        # rescale the size of the scene conditioned on the current pose estimates
if opt.data.dataset=="blender":
depth_min,depth_max = opt.nerf.depth.range
position = camera.Pose().invert(self.optimised_training_poses.weight.data.detach().clone().view(-1,3,4))[...,-1]
diameter = ((position[self.idx_grid[...,0]]-position[self.idx_grid[...,1]]).norm(dim=-1)).max()
depth_min_new = (depth_min/(depth_max+depth_min))*diameter
depth_max_new = (depth_max/(depth_max+depth_min))*diameter
opt.nerf.depth.range = [depth_min_new, depth_max_new]
# render images
batch_size = len(var.idx)
pose = self.get_pose(opt,var,mode=mode)
if opt.nerf.rand_rays and mode in ["train","test-optim"]:
# sample random rays for optimization
var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size]
ret = self.render(opt,pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1]
else:
# render full image (process in slices)
ret = self.render_by_slices(opt,pose,intr=var.intr,mode=mode) if opt.nerf.rand_rays else \
self.render(opt,pose,intr=var.intr,mode=mode) # [B,HW,3],[B,HW,1]
var.update(ret)
return var
def get_pose(self,opt,var,mode=None):
if mode=="train":
# add the pre-generated pose perturbations
if opt.data.dataset=="blender":
if opt.camera.noise:
var.pose_noise = self.pose_noise[var.idx]
pose = camera.pose.compose([var.pose, var.pose_noise])
else: pose = var.pose
else: pose = self.pose_eye
# add learnable pose correction
var.se3_refine = self.se3_refine.weight[var.idx]
pose_refine = camera.lie.se3_to_SE3(var.se3_refine)
pose = camera.pose.compose([pose_refine, pose])
self.optimised_training_poses.weight.data = pose.detach().clone().view(-1,12)
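            # write the composed per-image poses (detached) into the Embedding's weights so the current
            # optimized poses are saved with checkpoints and can be read back by get_all_training_poses()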
elif mode in ["val","eval","test-optim"]:
# align test pose to refined coordinate system (up to sim3)
sim3 = self.sim3
center = torch.zeros(1,1,3,device=opt.device)
center = camera.cam2world(center,var.pose)[:,0] # [N,3]
center_aligned = (center-sim3.t0)/sim3.s0@sim3.R*sim3.s1+sim3.t1
R_aligned = var.pose[...,:3]@self.sim3.R
t_aligned = (-R_aligned@center_aligned[...,None])[...,0]
pose = camera.pose(R=R_aligned,t=t_aligned)
            # additionally factor out the remaining pose imperfections
if opt.optim.test_photo and mode!="val":
pose = camera.pose.compose([var.pose_refine_test, pose])
else: pose = var.pose
return pose
class NeRF(nerf.NeRF):
def __init__(self,opt):
super().__init__(opt)
self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed
def positional_encoding(self,opt,input,L): # [B,...,N]
input_enc = super().positional_encoding(opt,input,L=L) # [B,...,2NL]
# coarse-to-fine: smoothly mask positional encoding for BARF
if opt.barf_c2f is not None:
# set weights for different frequency bands
start,end = opt.barf_c2f
alpha = (self.progress.data-start)/(end-start)*L
k = torch.arange(L,dtype=torch.float32,device=opt.device)
weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2
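            # the weight of frequency band k ramps smoothly from 0 to 1 as alpha crosses k (cosine easing), so
            # higher-frequency bands are activated progressively; e.g. with L=10 and barf_c2f=[0.1,0.5], at
            # progress=0.45 we get alpha=8.75: bands 0-7 fully on, band 8 at (1-cos(0.75*pi))/2~=0.85, band 9 off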
# apply weights
shape = input_enc.shape
input_enc = (input_enc.view(-1,L)*weight).view(*shape)
return input_enc
================================================
FILE: model/base.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import torch.utils.tensorboard
import visdom
import importlib
import tqdm
from easydict import EasyDict as edict
import util,util_vis
from util import log,debug
# ============================ main engine for training and evaluation ============================
class Model():
def __init__(self,opt):
super().__init__()
os.makedirs(opt.output_path,exist_ok=True)
def load_dataset(self,opt,eval_split="val"):
data = importlib.import_module("data.{}".format(opt.data.dataset))
log.info("loading training data...")
self.train_data = data.Dataset(opt,split="train",subset=opt.data.train_sub)
self.train_loader = self.train_data.setup_loader(opt,shuffle=True)
log.info("loading test data...")
if opt.data.val_on_test: eval_split = "test"
self.test_data = data.Dataset(opt,split=eval_split,subset=opt.data.val_sub)
self.test_loader = self.test_data.setup_loader(opt,shuffle=False)
def build_networks(self,opt):
graph = importlib.import_module("model.{}".format(opt.model))
log.info("building networks...")
self.graph = graph.Graph(opt).to(opt.device)
def setup_optimizer(self,opt):
log.info("setting up optimizers...")
optimizer = getattr(torch.optim,opt.optim.algo)
self.optim = optimizer([dict(params=self.graph.parameters(),lr=opt.optim.lr)])
# set up scheduler
if opt.optim.sched:
scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)
kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" }
self.sched = scheduler(self.optim,**kwargs)
def restore_checkpoint(self,opt):
epoch_start,iter_start = None,None
if opt.resume:
log.info("resuming from previous checkpoint...")
epoch_start,iter_start = util.restore_checkpoint(opt,self,resume=opt.resume)
elif opt.load is not None:
log.info("loading weights from checkpoint {}...".format(opt.load))
epoch_start,iter_start = util.restore_checkpoint(opt,self,load_name=opt.load)
else:
log.info("initializing weights from scratch...")
self.epoch_start = epoch_start or 0
self.iter_start = iter_start or 0
def setup_visualizer(self,opt):
log.info("setting up visualizers...")
if opt.tb:
self.tb = torch.utils.tensorboard.SummaryWriter(log_dir=opt.output_path,flush_secs=10)
if opt.visdom:
            # check if the visdom server is running
is_open = util.check_socket_open(opt.visdom.server,opt.visdom.port)
retry = None
while not is_open:
retry = input("visdom port ({}) not open, retry? (y/n) ".format(opt.visdom.port))
if retry not in ["y","n"]: continue
if retry=="y":
is_open = util.check_socket_open(opt.visdom.server,opt.visdom.port)
else: break
self.vis = visdom.Visdom(server=opt.visdom.server,port=opt.visdom.port,env=opt.group)
def train(self,opt):
# before training
log.title("TRAINING START")
self.timer = edict(start=time.time(),it_mean=None)
self.it = self.iter_start
# training
if self.iter_start==0: self.validate(opt,ep=0)
for self.ep in range(self.epoch_start,opt.max_epoch):
self.train_epoch(opt)
# after training
if opt.tb:
self.tb.flush()
self.tb.close()
if opt.visdom: self.vis.close()
log.title("TRAINING DONE")
def train_epoch(self,opt):
# before train epoch
self.graph.train()
# train epoch
loader = tqdm.tqdm(self.train_loader,desc="training epoch {}".format(self.ep+1),leave=False)
for batch in loader:
# train iteration
var = edict(batch)
var = util.move_to_device(var,opt.device)
loss = self.train_iteration(opt,var,loader)
# after train epoch
lr = self.sched.get_last_lr()[0] if opt.optim.sched else opt.optim.lr
log.loss_train(opt,self.ep+1,lr,loss.all,self.timer)
if opt.optim.sched: self.sched.step()
if (self.ep+1)%opt.freq.val==0: self.validate(opt,ep=self.ep+1)
if (self.ep+1)%opt.freq.ckpt==0: self.save_checkpoint(opt,ep=self.ep+1,it=self.it)
def train_iteration(self,opt,var,loader):
# before train iteration
self.timer.it_start = time.time()
# train iteration
self.optim.zero_grad()
var = self.graph.forward(opt,var,mode="train")
loss = self.graph.compute_loss(opt,var,mode="train")
loss = self.summarize_loss(opt,var,loss)
loss.all.backward()
self.optim.step()
# after train iteration
if (self.it+1)%opt.freq.scalar==0: self.log_scalars(opt,var,loss,step=self.it+1,split="train")
if (self.it+1)%opt.freq.vis==0: self.visualize(opt,var,step=self.it+1,split="train")
self.it += 1
loader.set_postfix(it=self.it,loss="{:.3f}".format(loss.all))
self.timer.it_end = time.time()
util.update_timer(opt,self.timer,self.ep,len(loader))
return loss
def summarize_loss(self,opt,var,loss):
loss_all = 0.
assert("all" not in loss)
# weigh losses
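        # loss weights are specified as log10 exponents: a weight w scales its term by 10**w (e.g. 0 -> 1x,
        # -2 -> 0.01x), and a weight of None disables the term entirely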
for key in loss:
assert(key in opt.loss_weight)
assert(loss[key].shape==())
if opt.loss_weight[key] is not None:
assert not torch.isinf(loss[key]),"loss {} is Inf".format(key)
assert not torch.isnan(loss[key]),"loss {} is NaN".format(key)
loss_all += 10**float(opt.loss_weight[key])*loss[key]
loss.update(all=loss_all)
return loss
@torch.no_grad()
def validate(self,opt,ep=None):
self.graph.eval()
loss_val = edict()
loader = tqdm.tqdm(self.test_loader,desc="validating",leave=False)
for it,batch in enumerate(loader):
var = edict(batch)
var = util.move_to_device(var,opt.device)
var = self.graph.forward(opt,var,mode="val")
loss = self.graph.compute_loss(opt,var,mode="val")
loss = self.summarize_loss(opt,var,loss)
for key in loss:
loss_val.setdefault(key,0.)
loss_val[key] += loss[key]*len(var.idx)
loader.set_postfix(loss="{:.3f}".format(loss.all))
if it==0: self.visualize(opt,var,step=ep,split="val")
for key in loss_val: loss_val[key] /= len(self.test_data)
self.log_scalars(opt,var,loss_val,step=ep,split="val")
log.loss_val(opt,loss_val.all)
@torch.no_grad()
def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
for key,value in loss.items():
if key=="all": continue
if opt.loss_weight[key] is not None:
self.tb.add_scalar("{0}/loss_{1}".format(split,key),value,step)
if metric is not None:
for key,value in metric.items():
self.tb.add_scalar("{0}/{1}".format(split,key),value,step)
@torch.no_grad()
def visualize(self,opt,var,step=0,split="train"):
raise NotImplementedError
def save_checkpoint(self,opt,ep=0,it=0,latest=False):
util.save_checkpoint(opt,self,ep=ep,it=it,latest=latest)
if not latest:
log.info("checkpoint saved: ({0}) {1}, epoch {2} (iteration {3})".format(opt.group,opt.name,ep,it))
# ============================ computation graph for forward/backprop ============================
class Graph(torch.nn.Module):
def __init__(self,opt):
super().__init__()
def forward(self,opt,var,mode=None):
raise NotImplementedError
return var
def compute_loss(self,opt,var,mode=None):
loss = edict()
raise NotImplementedError
return loss
def L1_loss(self,pred,label=0):
loss = (pred.contiguous()-label).abs()
return loss.mean()
def MSE_loss(self,pred,label=0):
loss = (pred.contiguous()-label)**2
return loss.mean()
================================================
FILE: model/l2g_nerf.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import tqdm
from easydict import EasyDict as edict
import visdom
import matplotlib.pyplot as plt
import util,util_vis
from util import log,debug
from . import nerf
import camera
import roma
# ============================ main engine for training and evaluation ============================
class Model(nerf.Model):
def __init__(self,opt):
super().__init__(opt)
def build_networks(self,opt):
super().build_networks(opt)
if opt.camera.noise:
# pre-generate synthetic pose perturbation
so3_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_r
t_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_t
self.graph.pose_noise = torch.cat([camera.lie.so3_to_SO3(so3_noise),t_noise[...,None]],dim=-1) # [...,3,4]
self.graph.warp_embedding = torch.nn.Embedding(len(self.train_data),opt.arch.embedding_dim).to(opt.device)
self.graph.warp_mlp = localWarp(opt).to(opt.device)
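        # local warp head: each training image owns a latent code (warp_embedding), and the warp MLP maps
        # (2D pixel coordinates, latent code) to a per-ray se(3) correction, so every ray gets its own local
        # pose instead of a single rigid pose per image (see Graph.get_pose)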
pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)
# add synthetic pose perturbation to all training data
if opt.data.dataset=="blender":
pose = pose_GT
if opt.camera.noise:
pose = camera.pose.compose([pose, self.graph.pose_noise])
else: pose = self.graph.pose_eye[None].repeat(len(self.train_data),1,1)
        # use an Embedding so the optimized poses can be checkpointed
self.graph.optimised_training_poses = torch.nn.Embedding(len(self.train_data),12,_weight=pose.view(-1,12)).to(opt.device)
# auto near/far for blender dataset
if opt.data.dataset=="blender":
idx_range = torch.arange(len(self.train_data),dtype=torch.long,device=opt.device)
idx_X,idx_Y = torch.meshgrid(idx_range,idx_range)
self.graph.idx_grid = torch.stack([idx_X,idx_Y],dim=-1).view(-1,2)
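            # idx_grid lists all ordered pairs of training-image indices; Graph.forward uses it to compute the
            # maximum pairwise camera-center distance (scene diameter) when rescaling the near/far depth range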
if opt.error_map_size:
self.graph.error_map = torch.ones([len(self.train_data), opt.error_map_size*opt.error_map_size], dtype=torch.float).to(opt.device)
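            # per-image error map at a coarse error_map_size x error_map_size resolution; it keeps an EMA of the
            # per-ray rendering error and drives importance sampling of rays during training (see Graph.forward)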
def setup_optimizer(self,opt):
super().setup_optimizer(opt)
optimizer = getattr(torch.optim,opt.optim.algo)
self.optim_pose = optimizer([dict(params=self.graph.warp_embedding.parameters(),lr=opt.optim.lr_pose)])
self.optim_pose.add_param_group(dict(params=self.graph.warp_mlp.parameters(),lr=opt.optim.lr_pose))
# set up scheduler
if opt.optim.sched_pose:
scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched_pose.type)
if opt.optim.lr_pose_end:
assert(opt.optim.sched_pose.type=="ExponentialLR")
opt.optim.sched_pose.gamma = (opt.optim.lr_pose_end/opt.optim.lr_pose)**(1./opt.max_iter)
kwargs = { k:v for k,v in opt.optim.sched_pose.items() if k!="type" }
self.sched_pose = scheduler(self.optim_pose,**kwargs)
def train_iteration(self,opt,var,loader):
self.optim_pose.zero_grad()
if opt.optim.warmup_pose:
# simple linear warmup of pose learning rate
self.optim_pose.param_groups[0]["lr_orig"] = self.optim_pose.param_groups[0]["lr"] # cache the original learning rate
self.optim_pose.param_groups[0]["lr"] *= min(1,self.it/opt.optim.warmup_pose)
loss = super().train_iteration(opt,var,loader)
self.optim_pose.step()
if opt.optim.warmup_pose:
self.optim_pose.param_groups[0]["lr"] = self.optim_pose.param_groups[0]["lr_orig"] # reset learning rate
if opt.optim.sched_pose: self.sched_pose.step()
self.graph.nerf.progress.data.fill_(self.it/opt.max_iter)
if opt.nerf.fine_sampling:
self.graph.nerf_fine.progress.data.fill_(self.it/opt.max_iter)
return loss
@torch.no_grad()
def validate(self,opt,ep=None):
pose,pose_GT = self.get_all_training_poses(opt)
_,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT)
super().validate(opt,ep=ep)
@torch.no_grad()
def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)
if split=="train":
# log learning rate
lr = self.optim_pose.param_groups[0]["lr"]
self.tb.add_scalar("{0}/{1}".format(split,"lr_pose"),lr,step)
# compute pose error
if split=="train" and opt.data.dataset in ["blender","llff"]:
pose,pose_GT = self.get_all_training_poses(opt)
pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT)
error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT)
self.tb.add_scalar("{0}/error_R".format(split),error.R.mean(),step)
self.tb.add_scalar("{0}/error_t".format(split),error.t.mean(),step)
@torch.no_grad()
def visualize(self,opt,var,step=0,split="train"):
super().visualize(opt,var,step=step,split=split)
if opt.visdom:
if split=="val":
pose,pose_GT = self.get_all_training_poses(opt)
pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT)
util_vis.vis_cameras(opt,self.vis,step=step,poses=[pose_aligned,pose_GT])
@torch.no_grad()
def get_all_training_poses(self,opt):
# get ground-truth (canonical) camera poses
pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)
pose = self.graph.optimised_training_poses.weight.data.detach().clone().view(-1,3,4)
return pose,pose_GT
@torch.no_grad()
def prealign_cameras(self,opt,pose,pose_GT):
# compute 3D similarity transform via Procrustes analysis
center = torch.zeros(1,1,3,device=opt.device)
center_pred = camera.cam2world(center,pose)[:,0] # [N,3]
center_GT = camera.cam2world(center,pose_GT)[:,0] # [N,3]
try:
sim3 = camera.procrustes_analysis(center_GT,center_pred)
except:
print("warning: SVD did not converge...")
sim3 = edict(t0=0,t1=0,s0=1,s1=1,R=torch.eye(3,device=opt.device))
# align the camera poses
center_aligned = (center_pred-sim3.t1)/sim3.s1@sim3.R.t()*sim3.s0+sim3.t0
R_aligned = pose[...,:3]@sim3.R.t()
t_aligned = (-R_aligned@center_aligned[...,None])[...,0]
pose_aligned = camera.pose(R=R_aligned,t=t_aligned)
return pose_aligned,sim3
@torch.no_grad()
def evaluate_camera_alignment(self,opt,pose_aligned,pose_GT):
# measure errors in rotation and translation
R_aligned,t_aligned = pose_aligned.split([3,1],dim=-1)
R_GT,t_GT = pose_GT.split([3,1],dim=-1)
R_error = camera.rotation_distance(R_aligned,R_GT)
t_error = (t_aligned-t_GT)[...,0].norm(dim=-1)
error = edict(R=R_error,t=t_error)
return error
@torch.no_grad()
def evaluate_full(self,opt):
self.graph.eval()
# evaluate rotation/translation
pose,pose_GT = self.get_all_training_poses(opt)
pose_aligned,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT)
error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT)
print("--------------------------")
print("rot: {:8.3f}".format(np.rad2deg(error.R.mean().cpu())))
print("trans: {:10.5f}".format(error.t.mean()))
print("--------------------------")
# dump numbers
quant_fname = "{}/quant_pose.txt".format(opt.output_path)
with open(quant_fname,"w") as file:
for i,(err_R,err_t) in enumerate(zip(error.R,error.t)):
file.write("{} {} {}\n".format(i,err_R.item(),err_t.item()))
# evaluate novel view synthesis
super().evaluate_full(opt)
@torch.enable_grad()
def evaluate_test_time_photometric_optim(self,opt,var):
# use another se3 Parameter to absorb the remaining pose errors
var.se3_refine_test = torch.nn.Parameter(torch.zeros(1,6,device=opt.device))
optimizer = getattr(torch.optim,opt.optim.algo)
optim_pose = optimizer([dict(params=[var.se3_refine_test],lr=opt.optim.lr_pose)])
iterator = tqdm.trange(opt.optim.test_iter,desc="test-time optim.",leave=False,position=1)
for it in iterator:
optim_pose.zero_grad()
var.pose_refine_test = camera.lie.se3_to_SE3(var.se3_refine_test)
var = self.graph.forward(opt,var,mode="test-optim")
loss = self.graph.compute_loss(opt,var,mode="test-optim")
loss = self.summarize_loss(opt,var,loss)
loss.all.backward()
optim_pose.step()
iterator.set_postfix(loss="{:.3f}".format(loss.all))
return var
@torch.no_grad()
def generate_videos_pose(self,opt):
self.graph.eval()
fig = plt.figure(figsize=(10,10) if opt.data.dataset=="blender" else (16,8))
cam_path = "{}/poses".format(opt.output_path)
os.makedirs(cam_path,exist_ok=True)
ep_list = []
for ep in range(0,opt.max_iter+1,opt.freq.ckpt):
# load checkpoint (0 is random init)
if ep!=0:
try: util.restore_checkpoint(opt,self,resume=ep)
except: continue
# get the camera poses
pose,pose_ref = self.get_all_training_poses(opt)
if opt.data.dataset in ["blender","llff"]:
pose_aligned,_ = self.prealign_cameras(opt,pose,pose_ref)
pose_aligned,pose_ref = pose_aligned.detach().cpu(),pose_ref.detach().cpu()
dict(
blender=util_vis.plot_save_poses_blender,
llff=util_vis.plot_save_poses,
)[opt.data.dataset](opt,fig,pose_aligned,pose_ref=pose_ref,path=cam_path,ep=ep)
else:
pose = pose.detach().cpu()
util_vis.plot_save_poses(opt,fig,pose,pose_ref=None,path=cam_path,ep=ep)
ep_list.append(ep)
plt.close()
# write videos
print("writing videos...")
list_fname = "{}/temp.list".format(cam_path)
with open(list_fname,"w") as file:
for ep in ep_list: file.write("file {}.png\n".format(ep))
cam_vid_fname = "{}/poses.mp4".format(opt.output_path)
os.system("ffmpeg -y -r 4 -f concat -i {0} -pix_fmt yuv420p {1} >/dev/null 2>&1".format(list_fname,cam_vid_fname))
os.remove(list_fname)
# ============================ computation graph for forward/backprop ============================
class Graph(nerf.Graph):
def __init__(self,opt):
super().__init__(opt)
self.nerf = NeRF(opt)
if opt.nerf.fine_sampling:
self.nerf_fine = NeRF(opt)
self.pose_eye = torch.eye(3,4).to(opt.device)
def get_pose(self,opt,var,mode=None):
if mode=="train":
# add the pre-generated pose perturbations
if opt.data.dataset=="blender":
if opt.camera.noise:
var.pose_noise = self.pose_noise[var.idx]
pose = camera.pose.compose([var.pose, var.pose_noise])
else: pose = var.pose
else: pose = self.pose_eye[None]
# add learnable pose correction
batch_size = len(var.idx)
if opt.error_map_size:
num_points = var.ray_idx.shape[1]
camera_cords_grid_3D = camera.gather_camera_cords_grid_3D(opt,batch_size,intr=var.intr,ray_idx=var.ray_idx).detach()
else:
num_points = len(var.ray_idx)
camera_cords_grid_3D = camera.get_camera_cords_grid_3D(opt,batch_size,intr=var.intr,ray_idx=var.ray_idx).detach()
camera_cords_grid_2D = camera_cords_grid_3D[...,:2]
embedding = self.warp_embedding.weight[var.idx,None,:].expand(-1,num_points,-1)
local_se3_refine = self.warp_mlp(opt,torch.cat((camera_cords_grid_2D,embedding),dim=-1))
local_pose_refine = camera.lie.se3_to_SE3(local_se3_refine)
local_pose = camera.pose.compose([local_pose_refine, pose[:,None,...]])
return local_pose
elif mode in ["val","eval","test-optim"]:
# align test pose to refined coordinate system (up to sim3)
sim3 = self.sim3
center = torch.zeros(1,1,3,device=opt.device)
center = camera.cam2world(center,var.pose)[:,0] # [N,3]
center_aligned = (center-sim3.t0)/sim3.s0@sim3.R*sim3.s1+sim3.t1
R_aligned = var.pose[...,:3]@self.sim3.R
t_aligned = (-R_aligned@center_aligned[...,None])[...,0]
pose = camera.pose(R=R_aligned,t=t_aligned)
if opt.optim.test_photo and mode!="val":
pose = camera.pose.compose([var.pose_refine_test, pose])
else: pose = var.pose
return pose
def forward(self,opt,var,mode=None):
        # rescale the size of the scene conditioned on the current pose estimates
if opt.data.dataset=="blender":
depth_min,depth_max = opt.nerf.depth.range
position = camera.Pose().invert(self.optimised_training_poses.weight.data.detach().clone().view(-1,3,4))[...,-1]
diameter = ((position[self.idx_grid[...,0]]-position[self.idx_grid[...,1]]).norm(dim=-1)).max()
depth_min_new = (depth_min/(depth_max+depth_min))*diameter
depth_max_new = (depth_max/(depth_max+depth_min))*diameter
opt.nerf.depth.range = [depth_min_new, depth_max_new]
# render images
batch_size = len(var.idx)
if opt.nerf.rand_rays and mode in ["train"]:
# sample rays for optimization
if opt.error_map_size:
sample_weight = self.error_map + 2*self.error_map.mean(-1,keepdim=True) # 1/3 importance + 2/3 random
var.ray_idx_coarse = torch.multinomial(sample_weight, opt.nerf.rand_rays//batch_size, replacement=False) # [B, N], but in [0, opt.error_map_size*opt.error_map_size)
                inds_x, inds_y = var.ray_idx_coarse // opt.error_map_size, var.ray_idx_coarse % opt.error_map_size # tensor `//` emits a deprecation warning in torch 1.10; safe to ignore here
sx, sy = opt.H / opt.error_map_size, opt.W / opt.error_map_size
inds_x = (inds_x * sx + torch.rand(batch_size, opt.nerf.rand_rays//batch_size, device=opt.device) * sx).long().clamp(max=opt.H - 1)
inds_y = (inds_y * sy + torch.rand(batch_size, opt.nerf.rand_rays//batch_size, device=opt.device) * sy).long().clamp(max=opt.W - 1)
var.ray_idx = inds_x * opt.W + inds_y
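                # each sampled coarse cell is jittered to a uniformly random full-resolution pixel inside that
                # cell, then flattened back into a pixel index in [0, H*W)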
else:
                var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size] # fully random sampling (no error map)
local_pose = self.get_pose(opt,var,mode=mode)
ret = self.local_render(opt,local_pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1]
elif opt.nerf.rand_rays and mode in ["test-optim"]:
# sample random rays for optimization
var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size]
pose = self.get_pose(opt,var,mode=mode)
ret = self.render(opt,pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1]
else:
# render full image (process in slices)
pose = self.get_pose(opt,var,mode=mode)
ret = self.render_by_slices(opt,pose,intr=var.intr,mode=mode) if opt.nerf.rand_rays else \
self.render(opt,pose,intr=var.intr,mode=mode) # [B,HW,3],[B,HW,1]
var.update(ret)
return var
def compute_loss(self,opt,var,mode=None):
loss = edict()
batch_size = len(var.idx)
image = var.image.view(batch_size,3,opt.H*opt.W).permute(0,2,1)
if opt.nerf.rand_rays and mode in ["train","test-optim"]:
if opt.error_map_size:
image = torch.gather(image, 1, var.ray_idx[...,None].expand(-1,-1,3))
else:
image = image[:,var.ray_idx]
# compute image losses
if opt.loss_weight.render is not None:
render_error = ((var.rgb-image)**2).mean(-1)
loss.render = render_error.mean()
if mode in ["train"] and opt.error_map_size:
ema_error = 0.1 * torch.gather(self.error_map, 1, var.ray_idx_coarse) + 0.9 * render_error.detach()
self.error_map.scatter_(1, var.ray_idx_coarse, ema_error)
if opt.loss_weight.render_fine is not None:
assert(opt.nerf.fine_sampling)
loss.render_fine = self.MSE_loss(var.rgb_fine,image)
# global alignment
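        # local-to-global registration: for each image, fit the single rigid pose (R,t) that best registers the
        # camera-frame ray points onto their locally-warped world positions (Procrustes/SVD via roma), cache it
        # as the image's optimized pose, and penalize the residual between the local warps and this rigid fit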
if mode in ["train"]:
source = torch.cat((var.camera_grid_3D,var.camera_center),dim=1)
target = torch.cat((var.grid_3D,var.center),dim=1)
R_global, t_global = roma.rigid_points_registration(target, source)
svd_poses = torch.cat((R_global,t_global[...,None]),-1)
self.optimised_training_poses.weight.data = svd_poses.detach().clone().view(-1,12)
if opt.loss_weight.global_alignment is not None:
loss.global_alignment = self.MSE_loss(target,camera.cam2world(source,svd_poses))
return loss
def local_render(self,opt,local_pose,intr=None,ray_idx=None,mode=None):
batch_size = len(local_pose)
if opt.error_map_size:
camera_grid_3D = camera.gather_camera_cords_grid_3D(opt,batch_size,intr=intr,ray_idx=ray_idx).detach()
else:
camera_grid_3D = camera.get_camera_cords_grid_3D(opt,batch_size,intr=intr,ray_idx=ray_idx).detach()
camera_center = torch.zeros_like(camera_grid_3D) # [B,HW,3]
grid_3D = camera.cam2world(camera_grid_3D[...,None,:],local_pose)[...,0,:] # [B,HW,3]
center = camera.cam2world(camera_center[...,None,:],local_pose)[...,0,:] # [B,HW,3]
ray = grid_3D-center # [B,HW,3]
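        # every ray carries its own local pose: both the back-projected pixel point and the camera origin are
        # transformed by the per-ray pose, and the ray direction is their difference in world space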
ret = edict(camera_grid_3D=camera_grid_3D, camera_center=camera_center, grid_3D=grid_3D, center=center) # [B,HW,K] use for global alignment
if opt.camera.ndc:
# convert center/ray representations to NDC
center,ray = camera.convert_NDC(opt,center,ray,intr=intr)
# render with main MLP
depth_samples = self.sample_depth(opt,batch_size,num_rays=ray.shape[1]) # [B,HW,N,1]
rgb_samples,density_samples = self.nerf.forward_samples(opt,center,ray,depth_samples,mode=mode)
rgb,depth,opacity,prob = self.nerf.composite(opt,ray,rgb_samples,density_samples,depth_samples)
ret.update(rgb=rgb,depth=depth,opacity=opacity) # [B,HW,K]
# render with fine MLP from coarse MLP
if opt.nerf.fine_sampling:
with torch.no_grad():
                # resample depth according to coarse empirical distribution
depth_samples_fine = self.sample_depth_from_pdf(opt,pdf=prob[...,0]) # [B,HW,Nf,1]
depth_samples = torch.cat([depth_samples,depth_samples_fine],dim=2) # [B,HW,N+Nf,1]
depth_samples = depth_samples.sort(dim=2).values
rgb_samples,density_samples = self.nerf_fine.forward_samples(opt,center,ray,depth_samples,mode=mode)
rgb_fine,depth_fine,opacity_fine,_ = self.nerf_fine.composite(opt,ray,rgb_samples,density_samples,depth_samples)
ret.update(rgb_fine=rgb_fine,depth_fine=depth_fine,opacity_fine=opacity_fine) # [B,HW,K]
return ret
class NeRF(nerf.NeRF):
def __init__(self,opt):
super().__init__(opt)
self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed
def positional_encoding(self,opt,input,L): # [B,...,N]
input_enc = super().positional_encoding(opt,input,L=L) # [B,...,2NL]
# coarse-to-fine: smoothly mask positional encoding for BARF
if opt.barf_c2f is not None:
# set weights for different frequency bands
start,end = opt.barf_c2f
alpha = (self.progress.data-start)/(end-start)*L
k = torch.arange(L,dtype=torch.float32,device=opt.device)
weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2
# apply weights
shape = input_enc.shape
input_enc = (input_enc.view(-1,L)*weight).view(*shape)
return input_enc
class localWarp(torch.nn.Module):
def __init__(self, opt):
super().__init__()
# point-wise se3 prediction
input_2D_dim = 2
self.mlp_warp = torch.nn.ModuleList()
L = util.get_layer_dims(opt.arch.layers_warp)
for li,(k_in,k_out) in enumerate(L):
if li==0: k_in = input_2D_dim+opt.arch.embedding_dim
if li in opt.arch.skip_warp: k_in += input_2D_dim+opt.arch.embedding_dim
linear = torch.nn.Linear(k_in,k_out)
self.mlp_warp.append(linear)
def forward(self,opt,uvf):
feat = uvf
for li,layer in enumerate(self.mlp_warp):
if li in opt.arch.skip_warp: feat = torch.cat([feat,uvf],dim=-1)
feat = layer(feat)
if li!=len(self.mlp_warp)-1:
feat = torch_F.relu(feat)
warp = feat
return warp
================================================
FILE: model/l2g_planar.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import tqdm
from easydict import EasyDict as edict
import PIL
import PIL.Image,PIL.ImageDraw
import imageio
import util,util_vis
from util import log,debug
from . import base
import warp
import roma
from kornia.geometry.homography import find_homography_dlt
# ============================ main engine for training and evaluation ============================
class Model(base.Model):
def __init__(self,opt):
super().__init__(opt)
opt.H_crop,opt.W_crop = opt.data.patch_crop
def load_dataset(self,opt,eval_split=None):
image_raw = PIL.Image.open(opt.data.image_fname).convert('RGB')
self.image_raw = torchvision_F.to_tensor(image_raw).to(opt.device)
def build_networks(self,opt):
super().build_networks(opt)
self.graph.warp_embedding = torch.nn.Embedding(opt.batch_size,opt.arch.embedding_dim).to(opt.device)
self.graph.warp_mlp = localWarp(opt).to(opt.device)
def setup_optimizer(self,opt):
log.info("setting up optimizers...")
optim_list = [
dict(params=self.graph.neural_image.parameters(),lr=opt.optim.lr),
]
optimizer = getattr(torch.optim,opt.optim.algo)
self.optim = optimizer(optim_list)
# set up scheduler
if opt.optim.sched:
scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)
if opt.optim.lr_end:
assert(opt.optim.sched.type=="ExponentialLR")
opt.optim.sched.gamma = (opt.optim.lr_end/opt.optim.lr)**(1./opt.max_iter)
kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" }
self.sched = scheduler(self.optim,**kwargs)
# warp
self.optim_warp = optimizer([dict(params=self.graph.warp_embedding.parameters(),lr=opt.optim.lr_warp)])
self.optim_warp.add_param_group(dict(params=self.graph.warp_mlp.parameters(),lr=opt.optim.lr_warp))
# set up scheduler
if opt.optim.sched_warp:
scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched_warp.type)
if opt.optim.lr_warp_end:
assert(opt.optim.sched_warp.type=="ExponentialLR")
opt.optim.sched_warp.gamma = (opt.optim.lr_warp_end/opt.optim.lr_warp)**(1./opt.max_iter)
kwargs = { k:v for k,v in opt.optim.sched_warp.items() if k!="type" }
self.sched_warp = scheduler(self.optim_warp,**kwargs)
def setup_visualizer(self,opt):
super().setup_visualizer(opt)
# set colors for visualization
box_colors = ["#ff0000","#40afff","#9314ff","#ffd700","#00ff00"]
box_colors = list(map(util.colorcode_to_number,box_colors))
self.box_colors = np.array(box_colors).astype(int)
assert(len(self.box_colors)==opt.batch_size)
# create visualization directory
self.vis_path = "{}/vis".format(opt.output_path)
os.makedirs(self.vis_path,exist_ok=True)
self.video_fname = "{}/vis.mp4".format(opt.output_path)
def train(self,opt):
# before training
log.title("TRAINING START")
self.timer = edict(start=time.time(),it_mean=None)
self.ep = self.it = self.vis_it = 0
self.graph.train()
var = edict(idx=torch.arange(opt.batch_size))
# pre-generate perturbations
self.homo_pert, self.rot_pert, self.trans_pert, var.image_pert = self.generate_warp_perturbation(opt)
# train
var = util.move_to_device(var,opt.device)
loader = tqdm.trange(opt.max_iter,desc="training",leave=False)
# visualize initial state
var = self.graph.forward(opt,var)
self.visualize(opt,var,step=0)
for it in loader:
# train iteration
loss = self.train_iteration(opt,var,loader)
# after training
os.system("ffmpeg -y -framerate 30 -i {}/%d.png -pix_fmt yuv420p {}".format(self.vis_path,self.video_fname))
self.save_checkpoint(opt,ep=None,it=self.it)
if opt.tb:
self.tb.flush()
self.tb.close()
if opt.visdom: self.vis.close()
log.title("TRAINING DONE")
def train_iteration(self,opt,var,loader):
self.optim_warp.zero_grad()
loss = super().train_iteration(opt,var,loader)
self.graph.neural_image.progress.data.fill_(self.it/opt.max_iter)
self.optim_warp.step()
if opt.optim.sched_warp: self.sched_warp.step()
if opt.optim.sched: self.sched.step()
return loss
def generate_warp_perturbation(self,opt):
        # pre-generate perturbations (homography/rotation/translation noise, depending on opt.warp.dof)
def create_random_perturbation(batch_size):
if opt.warp.dof==1:
homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h
rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r
trans_pert = torch.zeros(batch_size,2,device=opt.device)*opt.warp.noise_t
elif opt.warp.dof==2:
homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h
rot_pert = torch.zeros(batch_size,1,device=opt.device)*opt.warp.noise_r
trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t
elif opt.warp.dof==3:
homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h
rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r
trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t
elif opt.warp.dof==8:
homo_pert = torch.randn(batch_size,8,device=opt.device)*opt.warp.noise_h
homo_pert[:,:2]=0
rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r
trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t
else: assert(False)
return homo_pert, rot_pert, trans_pert
homo_pert = torch.zeros(opt.batch_size,8,device=opt.device)
rot_pert = torch.zeros(opt.batch_size,1,device=opt.device)
trans_pert = torch.zeros(opt.batch_size,2,device=opt.device)
for i in range(opt.batch_size):
homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1)
while not warp.check_corners_in_range_compose(opt, homo_pert_i, rot_pert_i, trans_pert_i):
homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1)
homo_pert[i], rot_pert[i], trans_pert[i] = homo_pert_i, rot_pert_i, trans_pert_i
if opt.warp.fix_first:
homo_pert[0],rot_pert[0],trans_pert[0] = 0,0,0
# create warped image patches
xy_grid = warp.get_normalized_pixel_grid_crop(opt) # [B,HW,2]
xy_grid_hom = warp.camera.to_hom(xy_grid)
warp_matrix = warp.lie.sl3_to_SL3(homo_pert)
warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1)
xy_grid_warped = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2]
warp_matrix = warp.lie.so2_to_SO2(rot_pert)
xy_grid_warped = xy_grid_warped@warp_matrix.transpose(-2,-1) # [B,HW,2]
xy_grid_warped = xy_grid_warped+trans_pert[...,None,:]
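        # the synthetic perturbation composes a homography (SL(3)), a 2D rotation (SO(2)) and a translation on the
        # normalized crop grid; the perturbed grid is then used to resample patches from the raw image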
xy_grid_warped = xy_grid_warped.view([opt.batch_size,opt.H_crop,opt.W_crop,2])
xy_grid_warped = torch.stack([xy_grid_warped[...,0]*max(opt.H,opt.W)/opt.W,
xy_grid_warped[...,1]*max(opt.H,opt.W)/opt.H],dim=-1)
image_raw_batch = self.image_raw.repeat(opt.batch_size,1,1,1)
image_pert_all = torch_F.grid_sample(image_raw_batch,xy_grid_warped,align_corners=False)
return homo_pert, rot_pert, trans_pert, image_pert_all
def visualize_patches_compose(self,opt, homo_pert, rot_pert, trans_pert):
image_pil = torchvision_F.to_pil_image(self.image_raw).convert("RGBA")
draw_pil = PIL.Image.new("RGBA",image_pil.size,(0,0,0,0))
draw = PIL.ImageDraw.Draw(draw_pil)
corners_all = warp.warp_corners_compose(opt, homo_pert, rot_pert, trans_pert)
corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
for i,corners in enumerate(corners_all):
P = [tuple(float(n) for n in corners[j]) for j in range(4)]
draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3)
image_pil.alpha_composite(draw_pil)
image_tensor = torchvision_F.to_tensor(image_pil.convert("RGB"))
return image_tensor
def visualize_patches_use_matrix(self,opt,warp_matrix):
image_pil = torchvision_F.to_pil_image(self.image_raw).convert("RGBA")
draw_pil = PIL.Image.new("RGBA",image_pil.size,(0,0,0,0))
draw = PIL.ImageDraw.Draw(draw_pil)
corners_all = warp.warp_corners_use_matrix(opt,warp_matrix)
corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
for i,corners in enumerate(corners_all):
P = [tuple(float(n) for n in corners[j]) for j in range(4)]
draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3)
image_pil.alpha_composite(draw_pil)
image_tensor = torchvision_F.to_tensor(image_pil.convert("RGB"))
return image_tensor
@torch.no_grad()
def predict_entire_image(self,opt):
xy_grid = warp.get_normalized_pixel_grid(opt)[:1]
rgb = self.graph.neural_image.forward(opt,xy_grid) # [B,HW,3]
image = rgb.view(opt.H,opt.W,3).detach().cpu().permute(2,0,1)
return image
@torch.no_grad()
def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)
# compute PSNR
psnr = -10*loss.render.log10()
self.tb.add_scalar("{0}/{1}".format(split,"PSNR"),psnr,step)
# warp error
pred_corners = warp.warp_corners_use_matrix(opt,self.graph.warp_matrix)
pred_corners[...,0] = (pred_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
pred_corners[...,1] = (pred_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
gt_corners = warp.warp_corners_compose(opt, self.homo_pert, self.rot_pert, self.trans_pert)
gt_corners[...,0] = (gt_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
gt_corners[...,1] = (gt_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
warp_error = (pred_corners-gt_corners).norm(dim=-1).mean()
self.tb.add_scalar("{0}/{1}".format(split,"warp error"),warp_error,step)
@torch.no_grad()
def visualize(self,opt,var,step=0,split="train"):
# dump frames for writing to video
frame_GT = self.visualize_patches_compose(opt,self.homo_pert, self.rot_pert, self.trans_pert)
frame = self.visualize_patches_use_matrix(opt,self.graph.warp_matrix)
frame2 = self.predict_entire_image(opt)
frame_cat = (torch.cat([frame,frame2],dim=1)*255).byte().permute(1,2,0).numpy()
imageio.imsave("{}/{}.png".format(self.vis_path,self.vis_it),frame_cat)
self.vis_it += 1
# visualize in Tensorboard
if opt.tb:
colors = self.box_colors
util_vis.tb_image(opt,self.tb,step,split,"image_pert",util_vis.color_border(var.image_pert,colors))
util_vis.tb_image(opt,self.tb,step,split,"rgb_warped",util_vis.color_border(var.rgb_warped_map,colors))
util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes",frame[None])
util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes_GT",frame_GT[None])
util_vis.tb_image(opt,self.tb,self.it+1,"train","image_entire",frame2[None])
local_warp_field = torch.cat([(var.local_warp_field+1)/2,torch.zeros_like(var.local_warp_field[:,:1,...])],dim=1) # [B,3,H,W]
global_warp_field = torch.cat([(var.global_warp_field+1)/2,torch.zeros_like(var.global_warp_field[:,:1,...])],dim=1) # [B,3,H,W]
util_vis.tb_image(opt,self.tb,step,split,"warp_field_local",util_vis.color_border(local_warp_field,colors))
util_vis.tb_image(opt,self.tb,step,split,"warp_field_global",util_vis.color_border(global_warp_field,colors))
# ============================ computation graph for forward/backprop ============================
class Graph(base.Graph):
def __init__(self,opt):
super().__init__(opt)
self.neural_image = NeuralImageFunction(opt)
def forward(self,opt,var,mode=None):
# warp
xy_grid = warp.get_normalized_pixel_grid_crop(opt)
warp_embedding = self.warp_embedding.weight[:,None,:].expand(-1,xy_grid.shape[1],-1)
local_warp_param = self.warp_mlp(opt,torch.cat((xy_grid,warp_embedding),dim=-1))
if opt.warp.fix_first:
local_warp_param[0] = 0
if opt.warp.dof==1:
local_warp_matrix = warp.lie.so2_to_SO2(local_warp_param)
var.local_warped_grid = (xy_grid[...,None,:]@local_warp_matrix.transpose(-2,-1))[...,0,:] # [B,HW,2]
self.warp_matrix = roma.rigid_vectors_registration(xy_grid, var.local_warped_grid)
var.global_warped_grid = xy_grid@self.warp_matrix.transpose(-2,-1) # [B,HW,2]
elif opt.warp.dof==2:
var.local_warped_grid = xy_grid + local_warp_param
self.warp_matrix = var.local_warped_grid.mean(-2)-xy_grid.mean(-2)
var.global_warped_grid = xy_grid + self.warp_matrix[...,None,:]
elif opt.warp.dof==3:
xy_grid_hom = warp.camera.to_hom(xy_grid[...,None,:])
local_warp_matrix = warp.lie.se2_to_SE2(local_warp_param)
var.local_warped_grid = (xy_grid_hom@local_warp_matrix.transpose(-2,-1))[...,0,:] # [B,HW,2]
R_global, t_global = roma.rigid_points_registration(xy_grid, var.local_warped_grid)
self.warp_matrix = torch.cat((R_global,t_global[...,None]),-1)
xy_grid_hom = warp.camera.to_hom(xy_grid)
var.global_warped_grid = xy_grid_hom@self.warp_matrix.transpose(-2,-1) # [B,HW,2]
elif opt.warp.dof==8:
xy_grid_hom = warp.camera.to_hom(xy_grid[...,None,:])
local_warp_matrix = warp.lie.se2_to_SE2(local_warp_param)
var.local_warped_grid = (xy_grid_hom@local_warp_matrix.transpose(-2,-1))[...,0,:] # [B,HW,2]
self.warp_matrix = find_homography_dlt(xy_grid, var.local_warped_grid)
xy_grid_hom = warp.camera.to_hom(xy_grid)
global_warped_grid = xy_grid_hom@self.warp_matrix.transpose(-2,-1)
var.global_warped_grid = global_warped_grid[...,:2]/(global_warped_grid[...,2:]+1e-8) # [B,HW,2]
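            # for dof==8 the local warps are still per-pixel SE(2) transforms, but the global registration fits a
            # full homography to the local warp field via DLT (kornia's find_homography_dlt)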
# render images
var.rgb_warped = self.neural_image.forward(opt,var.local_warped_grid) # [B,HW,3]
var.rgb_warped_map = var.rgb_warped.view(opt.batch_size,opt.H_crop,opt.W_crop,3).permute(0,3,1,2) # [B,3,H,W]
var.local_warp_field = var.local_warped_grid.view(opt.batch_size,opt.H_crop,opt.W_crop,2).permute(0,3,1,2) # [B,2,H,W]
var.global_warp_field = var.global_warped_grid.view(opt.batch_size,opt.H_crop,opt.W_crop,2).permute(0,3,1,2) # [B,2,H,W]
return var
def compute_loss(self,opt,var,mode=None):
loss = edict()
if opt.loss_weight.render is not None:
image_pert = var.image_pert.view(opt.batch_size,3,opt.H_crop*opt.W_crop).permute(0,2,1)
loss.render = self.MSE_loss(var.rgb_warped,image_pert)
if opt.loss_weight.global_alignment is not None:
loss.global_alignment = self.MSE_loss(var.local_warped_grid,var.global_warped_grid)
return loss
class NeuralImageFunction(torch.nn.Module):
def __init__(self,opt):
super().__init__()
self.define_network(opt)
self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed
def define_network(self,opt):
input_2D_dim = 2+4*opt.arch.posenc.L_2D if opt.arch.posenc else 2
# point-wise RGB prediction
self.mlp = torch.nn.ModuleList()
L = util.get_layer_dims(opt.arch.layers)
for li,(k_in,k_out) in enumerate(L):
if li==0: k_in = input_2D_dim
if li in opt.arch.skip: k_in += input_2D_dim
linear = torch.nn.Linear(k_in,k_out)
if opt.barf_c2f and li==0:
                # rescale first-layer init: the default init assumes the full positional-encoding fan-in, but at the
                # start of coarse-to-fine only the raw xy coordinates are unmasked
scale = np.sqrt(input_2D_dim/2.)
linear.weight.data *= scale
linear.bias.data *= scale
self.mlp.append(linear)
    def forward(self,opt,coord_2D): # [B,...,2]
if opt.arch.posenc:
points_enc = self.positional_encoding(opt,coord_2D,L=opt.arch.posenc.L_2D)
            points_enc = torch.cat([coord_2D,points_enc],dim=-1) # [B,...,4L+2]
else: points_enc = coord_2D
feat = points_enc
# extract implicit features
for li,layer in enumerate(self.mlp):
if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1)
feat = layer(feat)
if li!=len(self.mlp)-1:
feat = torch_F.relu(feat)
rgb = feat.sigmoid_() # [B,...,3]
return rgb
def positional_encoding(self,opt,input,L): # [B,...,N]
shape = input.shape
freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L]
spectrum = input[...,None]*freq # [B,...,N,L]
sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L]
input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L]
input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL]
# coarse-to-fine: smoothly mask positional encoding for BARF
if opt.barf_c2f is not None:
# set weights for different frequency bands
start,end = opt.barf_c2f
alpha = (self.progress.data-start)/(end-start)*L
k = torch.arange(L,dtype=torch.float32,device=opt.device)
weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2
# apply weights
shape = input_enc.shape
input_enc = (input_enc.view(-1,L)*weight).view(*shape)
return input_enc
class localWarp(torch.nn.Module):
def __init__(self, opt):
super().__init__()
        # point-wise local warp parameter prediction (interpreted according to opt.warp.dof in Graph.forward)
input_2D_dim = 2
self.mlp_warp = torch.nn.ModuleList()
L = util.get_layer_dims(opt.arch.layers_warp)
for li,(k_in,k_out) in enumerate(L):
if li==0: k_in = input_2D_dim+opt.arch.embedding_dim
if li in opt.arch.skip_warp: k_in += input_2D_dim+opt.arch.embedding_dim
linear = torch.nn.Linear(k_in,k_out)
self.mlp_warp.append(linear)
def forward(self,opt,uvf):
feat = uvf
for li,layer in enumerate(self.mlp_warp):
if li in opt.arch.skip_warp: feat = torch.cat([feat,uvf],dim=-1)
feat = layer(feat)
if li!=len(self.mlp_warp)-1:
feat = torch_F.relu(feat)
warp = feat
return warp
================================================
FILE: model/nerf.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import tqdm
from easydict import EasyDict as edict
import lpips
from external.pohsun_ssim import pytorch_ssim
import util,util_vis
from util import log,debug
from . import base
import camera
# ============================ main engine for training and evaluation ============================
class Model(base.Model):
def __init__(self,opt):
super().__init__(opt)
self.lpips_loss = lpips.LPIPS(net="alex").to(opt.device)
def load_dataset(self,opt,eval_split="val"):
super().load_dataset(opt,eval_split=eval_split)
# prefetch all training data
self.train_data.prefetch_all_data(opt)
self.train_data.all = edict(util.move_to_device(self.train_data.all,opt.device))
def setup_optimizer(self,opt):
log.info("setting up optimizers...")
optimizer = getattr(torch.optim,opt.optim.algo)
self.optim = optimizer([dict(params=self.graph.nerf.parameters(),lr=opt.optim.lr)])
if opt.nerf.fine_sampling:
self.optim.add_param_group(dict(params=self.graph.nerf_fine.parameters(),lr=opt.optim.lr))
# set up scheduler
if opt.optim.sched:
scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)
if opt.optim.lr_end:
assert(opt.optim.sched.type=="ExponentialLR")
opt.optim.sched.gamma = (opt.optim.lr_end/opt.optim.lr)**(1./opt.max_iter)
kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" }
self.sched = scheduler(self.optim,**kwargs)
def train(self,opt):
# before training
log.title("TRAINING START")
self.timer = edict(start=time.time(),it_mean=None)
self.graph.train()
self.ep = 0 # dummy for timer
# training
if self.iter_start==0: self.validate(opt,0)
loader = tqdm.trange(opt.max_iter,desc="training",leave=False)
for self.it in loader:
if self.it<self.iter_start: continue
# set var to all available images
var = self.train_data.all
self.train_iteration(opt,var,loader)
if opt.optim.sched: self.sched.step()
if self.it%opt.freq.val==0: self.validate(opt,self.it)
if self.it%opt.freq.ckpt==0: self.save_checkpoint(opt,ep=None,it=self.it)
# after training
if opt.tb:
self.tb.flush()
self.tb.close()
if opt.visdom: self.vis.close()
log.title("TRAINING DONE")
@torch.no_grad()
def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)
# log learning rate
if split=="train":
lr = self.optim.param_groups[0]["lr"]
self.tb.add_scalar("{0}/{1}".format(split,"lr"),lr,step)
if opt.nerf.fine_sampling:
lr = self.optim.param_groups[1]["lr"]
self.tb.add_scalar("{0}/{1}".format(split,"lr_fine"),lr,step)
# compute PSNR
psnr = -10*loss.render.log10()
self.tb.add_scalar("{0}/{1}".format(split,"PSNR"),psnr,step)
if opt.nerf.fine_sampling:
psnr = -10*loss.render_fine.log10()
self.tb.add_scalar("{0}/{1}".format(split,"PSNR_fine"),psnr,step)
@torch.no_grad()
def visualize(self,opt,var,step=0,split="train",eps=1e-10):
if opt.tb:
util_vis.tb_image(opt,self.tb,step,split,"image",var.image)
if not opt.nerf.rand_rays or split!="train":
if opt.model in ["barf","l2g_nerf"]:
scale = self.graph.sim3.s1/self.graph.sim3.s0
else: scale = 1
depth=var.depth/scale
if opt.data.dataset=="blender": depth = torch.clip(depth - 1.5, min=1)
invdepth = (1-depth)/var.opacity if opt.camera.ndc else 1/(depth/var.opacity+eps)
invdepth = util_vis.apply_depth_colormap(invdepth, cmap="inferno")
rgb_map = var.rgb.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]
invdepth_map = invdepth.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,1,H,W]
util_vis.tb_image(opt,self.tb,step,split,"rgb",rgb_map)
util_vis.tb_image(opt,self.tb,step,split,"invdepth",invdepth_map)
if opt.nerf.fine_sampling:
depth=var.depth_fine/scale
if opt.data.dataset=="blender": depth = torch.clip(depth - 1.5, min=1)
invdepth = (1-depth)/var.opacity_fine if opt.camera.ndc else 1/(depth/var.opacity_fine+eps)
invdepth = util_vis.apply_depth_colormap(invdepth, cmap="inferno")
rgb_map = var.rgb_fine.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]
invdepth_map = invdepth.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,1,H,W]
util_vis.tb_image(opt,self.tb,step,split,"rgb_fine",rgb_map)
util_vis.tb_image(opt,self.tb,step,split,"invdepth_fine",invdepth_map)
@torch.no_grad()
def get_all_training_poses(self,opt):
# get ground-truth (canonical) camera poses
pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)
return None,pose_GT
@torch.no_grad()
def evaluate_full(self,opt,eps=1e-10):
self.graph.eval()
loader = tqdm.tqdm(self.test_loader,desc="evaluating",leave=False)
res = []
test_path = "{}/test_view".format(opt.output_path)
os.makedirs(test_path,exist_ok=True)
for i,batch in enumerate(loader):
var = edict(batch)
var = util.move_to_device(var,opt.device)
if opt.model in ["barf","l2g_nerf"] and opt.optim.test_photo:
                # run test-time pose optimization so imperfections in the optimized poses are factored out of the
                # view-synthesis evaluation
var = self.evaluate_test_time_photometric_optim(opt,var)
var = self.graph.forward(opt,var,mode="eval")
# evaluate view synthesis
if opt.model in ["barf","l2g_nerf"]:
scale = self.graph.sim3.s1/self.graph.sim3.s0
else: scale = 1
depth = var.depth/scale
if opt.data.dataset=="blender": depth = torch.clip(depth - 1.5, min=1)
invdepth = (1-depth)/var.opacity if opt.camera.ndc else 1/(depth/var.opacity+eps)
invdepth = util_vis.apply_depth_colormap(invdepth, cmap="inferno")
rgb_map = var.rgb.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]
invdepth_map = invdepth.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,1,H,W]
psnr = -10*self.graph.MSE_loss(rgb_map,var.image).log10().item()
ssim = pytorch_ssim.ssim(rgb_map,var.image).item()
lpips = self.lpips_loss(rgb_map*2-1,var.image*2-1).item()
res.append(edict(psnr=psnr,ssim=ssim,lpips=lpips))
# dump novel views
torchvision_F.to_pil_image(rgb_map.cpu()[0]).save("{}/rgb_{}.png".format(test_path,i))
torchvision_F.to_pil_image(var.image.cpu()[0]).save("{}/rgb_GT_{}.png".format(test_path,i))
torchvision_F.to_pil_image(invdepth_map.cpu()[0]).save("{}/depth_{}.png".format(test_path,i))
# show results in terminal
print("--------------------------")
print("PSNR: {:8.2f}".format(np.mean([r.psnr for r in res])))
print("SSIM: {:8.2f}".format(np.mean([r.ssim for r in res])))
print("LPIPS: {:8.2f}".format(np.mean([r.lpips for r in res])))
print("--------------------------")
# dump numbers to file
quant_fname = "{}/quant.txt".format(opt.output_path)
with open(quant_fname,"w") as file:
for i,r in enumerate(res):
file.write("{} {} {} {}\n".format(i,r.psnr,r.ssim,r.lpips))
@torch.no_grad()
def generate_videos_synthesis(self,opt,eps=1e-10):
self.graph.eval()
if opt.data.dataset=="blender":
test_path = "{}/test_view".format(opt.output_path)
            # assume the test-view renderings have already been generated
print("writing videos...")
rgb_vid_fname = "{}/test_view_rgb.mp4".format(opt.output_path)
depth_vid_fname = "{}/test_view_depth.mp4".format(opt.output_path)
os.system("ffmpeg -y -framerate 30 -i {0}/rgb_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(test_path,rgb_vid_fname))
os.system("ffmpeg -y -framerate 30 -i {0}/depth_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(test_path,depth_vid_fname))
else:
pose_pred,pose_GT = self.get_all_training_poses(opt)
poses = pose_pred if opt.model in ["barf","l2g_nerf"] else pose_GT
if opt.model in ["barf","l2g_nerf"] and opt.data.dataset=="llff":
_,sim3 = self.prealign_cameras(opt,pose_pred,pose_GT)
scale = sim3.s1/sim3.s0
else: scale = 1
# rotate novel views around the "center" camera of all poses
idx_center = (poses-poses.mean(dim=0,keepdim=True))[...,3].norm(dim=-1).argmin()
pose_novel = camera.get_novel_view_poses(opt,poses[idx_center],N=60,scale=scale).to(opt.device)
# render the novel views
novel_path = "{}/novel_view".format(opt.output_path)
os.makedirs(novel_path,exist_ok=True)
pose_novel_tqdm = tqdm.tqdm(pose_novel,desc="rendering novel views",leave=False)
intr = edict(next(iter(self.test_loader))).intr[:1].to(opt.device) # grab intrinsics
for i,pose in enumerate(pose_novel_tqdm):
ret = self.graph.render_by_slices(opt,pose[None],intr=intr) if opt.nerf.rand_rays else \
self.graph.render(opt,pose[None],intr=intr)
depth=ret.depth/scale
if opt.data.dataset=="blender": depth = torch.clip(depth - 1.5, min=1)
invdepth = (1-depth)/ret.opacity if opt.camera.ndc else 1/(depth/ret.opacity+eps)
invdepth = util_vis.apply_depth_colormap(invdepth, cmap="inferno")
rgb_map = ret.rgb.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]
invdepth_map = invdepth.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,1,H,W]
torchvision_F.to_pil_image(rgb_map.cpu()[0]).save("{}/rgb_{}.png".format(novel_path,i))
torchvision_F.to_pil_image(invdepth_map.cpu()[0]).save("{}/depth_{}.png".format(novel_path,i))
# write videos
print("writing videos...")
rgb_vid_fname = "{}/novel_view_rgb.mp4".format(opt.output_path)
depth_vid_fname = "{}/novel_view_depth.mp4".format(opt.output_path)
os.system("ffmpeg -y -framerate 30 -i {0}/rgb_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(novel_path,rgb_vid_fname))
os.system("ffmpeg -y -framerate 30 -i {0}/depth_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(novel_path,depth_vid_fname))
# ============================ computation graph for forward/backprop ============================
class Graph(base.Graph):
def __init__(self,opt):
super().__init__(opt)
self.nerf = NeRF(opt)
if opt.nerf.fine_sampling:
self.nerf_fine = NeRF(opt)
def forward(self,opt,var,mode=None):
batch_size = len(var.idx)
pose = self.get_pose(opt,var,mode=mode)
# render images
if opt.nerf.rand_rays and mode in ["train","test-optim"]:
# sample random rays for optimization
var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size]
ret = self.render(opt,pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1]
else:
# render full image (process in slices)
ret = self.render_by_slices(opt,pose,intr=var.intr,mode=mode) if opt.nerf.rand_rays else \
self.render(opt,pose,intr=var.intr,mode=mode) # [B,HW,3],[B,HW,1]
var.update(ret)
return var
def compute_loss(self,opt,var,mode=None):
loss = edict()
batch_size = len(var.idx)
image = var.image.view(batch_size,3,opt.H*opt.W).permute(0,2,1)
if opt.nerf.rand_rays and mode in ["train","test-optim"]:
image = image[:,var.ray_idx]
# compute image losses
if opt.loss_weight.render is not None:
loss.render = self.MSE_loss(var.rgb,image)
if opt.loss_weight.render_fine is not None:
assert(opt.nerf.fine_sampling)
loss.render_fine = self.MSE_loss(var.rgb_fine,image)
return loss
def get_pose(self,opt,var,mode=None):
return var.pose
def render(self,opt,pose,intr=None,ray_idx=None,mode=None):
batch_size = len(pose)
center,ray = camera.get_center_and_ray(opt,pose,intr=intr) # [B,HW,3]
        while ray.isnan().any(): # TODO: weird bug, ray becomes NaN arbitrarily if batch_size>1, not deterministically reproducible
center,ray = camera.get_center_and_ray(opt,pose,intr=intr) # [B,HW,3]
if ray_idx is not None:
# consider only subset of rays
center,ray = center[:,ray_idx],ray[:,ray_idx]
if opt.camera.ndc:
# convert center/ray representations to NDC
center,ray = camera.convert_NDC(opt,center,ray,intr=intr)
# render with main MLP
depth_samples = self.sample_depth(opt,batch_size,num_rays=ray.shape[1]) # [B,HW,N,1]
rgb_samples,density_samples = self.nerf.forward_samples(opt,center,ray,depth_samples,mode=mode)
rgb,depth,opacity,prob = self.nerf.composite(opt,ray,rgb_samples,density_samples,depth_samples)
ret = edict(rgb=rgb,depth=depth,opacity=opacity) # [B,HW,K]
# render with fine MLP from coarse MLP
if opt.nerf.fine_sampling:
with torch.no_grad():
                # resample depth according to coarse empirical distribution
depth_samples_fine = self.sample_depth_from_pdf(opt,pdf=prob[...,0]) # [B,HW,Nf,1]
depth_samples = torch.cat([depth_samples,depth_samples_fine],dim=2) # [B,HW,N+Nf,1]
depth_samples = depth_samples.sort(dim=2).values
rgb_samples,density_samples = self.nerf_fine.forward_samples(opt,center,ray,depth_samples,mode=mode)
rgb_fine,depth_fine,opacity_fine,_ = self.nerf_fine.composite(opt,ray,rgb_samples,density_samples,depth_samples)
ret.update(rgb_fine=rgb_fine,depth_fine=depth_fine,opacity_fine=opacity_fine) # [B,HW,K]
return ret
def render_by_slices(self,opt,pose,intr=None,mode=None):
ret_all = edict(rgb=[],depth=[],opacity=[])
if opt.nerf.fine_sampling:
ret_all.update(rgb_fine=[],depth_fine=[],opacity_fine=[])
# render the image by slices for memory considerations
for c in range(0,opt.H*opt.W,opt.nerf.rand_rays):
ray_idx = torch.arange(c,min(c+opt.nerf.rand_rays,opt.H*opt.W),device=opt.device)
ret = self.render(opt,pose,intr=intr,ray_idx=ray_idx,mode=mode) # [B,R,3],[B,R,1]
for k in ret: ret_all[k].append(ret[k])
# group all slices of images
for k in ret_all: ret_all[k] = torch.cat(ret_all[k],dim=1)
return ret_all
def sample_depth(self,opt,batch_size,num_rays=None):
depth_min,depth_max = opt.nerf.depth.range
num_rays = num_rays or opt.H*opt.W
rand_samples = torch.rand(batch_size,num_rays,opt.nerf.sample_intvs,1,device=opt.device) if opt.nerf.sample_stratified else 0.5
rand_samples += torch.arange(opt.nerf.sample_intvs,device=opt.device)[None,None,:,None].float() # [B,HW,N,1]
depth_samples = rand_samples/opt.nerf.sample_intvs*(depth_max-depth_min)+depth_min # [B,HW,N,1]
depth_samples = dict(
metric=depth_samples,
inverse=1/(depth_samples+1e-8),
)[opt.nerf.depth.param]
return depth_samples
def sample_depth_from_pdf(self,opt,pdf):
depth_min,depth_max = opt.nerf.depth.range
# get CDF from PDF (along last dimension)
cdf = pdf.cumsum(dim=-1) # [B,HW,N]
cdf = torch.cat([torch.zeros_like(cdf[...,:1]),cdf],dim=-1) # [B,HW,N+1]
# take uniform samples
grid = torch.linspace(0,1,opt.nerf.sample_intvs_fine+1,device=opt.device) # [Nf+1]
unif = 0.5*(grid[:-1]+grid[1:]).repeat(*cdf.shape[:-1],1) # [B,HW,Nf]
idx = torch.searchsorted(cdf,unif,right=True) # [B,HW,Nf] \in {1...N}
# inverse transform sampling from CDF
depth_bin = torch.linspace(depth_min,depth_max,opt.nerf.sample_intvs+1,device=opt.device) # [N+1]
depth_bin = depth_bin.repeat(*cdf.shape[:-1],1) # [B,HW,N+1]
depth_low = depth_bin.gather(dim=2,index=(idx-1).clamp(min=0)) # [B,HW,Nf]
depth_high = depth_bin.gather(dim=2,index=idx.clamp(max=opt.nerf.sample_intvs)) # [B,HW,Nf]
cdf_low = cdf.gather(dim=2,index=(idx-1).clamp(min=0)) # [B,HW,Nf]
cdf_high = cdf.gather(dim=2,index=idx.clamp(max=opt.nerf.sample_intvs)) # [B,HW,Nf]
# linear interpolation
t = (unif-cdf_low)/(cdf_high-cdf_low+1e-8) # [B,HW,Nf]
depth_samples = depth_low+t*(depth_high-depth_low) # [B,HW,Nf]
return depth_samples[...,None] # [B,HW,Nf,1]
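# Illustrative sketch (added for clarity, with toy values): the hierarchical
# sampling above is inverse transform sampling from a piecewise-constant PDF
# over depth bins. A minimal 1D version of the same steps:
import torch
pdf = torch.tensor([0.1,0.2,0.3,0.4]) # toy per-bin weights along one ray (sums to 1)
cdf = torch.cat([torch.zeros(1),pdf.cumsum(dim=0)]) # [N+1], from 0 to 1
bins = torch.linspace(2,6,5) # [N+1] depth bin edges (toy near/far of 2 and 6)
unif = torch.tensor([0.125,0.375,0.625,0.875]) # uniform samples in [0,1]
idx = torch.searchsorted(cdf,unif,right=True) # bin index for each sample
t = (unif-cdf[idx-1])/(cdf[idx]-cdf[idx-1]+1e-8) # position within the bin
depth = bins[idx-1]+t*(bins[idx]-bins[idx-1]) # more samples land in high-PDF bins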
class NeRF(torch.nn.Module):
def __init__(self,opt):
super().__init__()
self.define_network(opt)
def define_network(self,opt):
input_3D_dim = 3+6*opt.arch.posenc.L_3D if opt.arch.posenc else 3
if opt.nerf.view_dep:
input_view_dim = 3+6*opt.arch.posenc.L_view if opt.arch.posenc else 3
# point-wise feature
self.mlp_feat = torch.nn.ModuleList()
L = util.get_layer_dims(opt.arch.layers_feat)
for li,(k_in,k_out) in enumerate(L):
if li==0: k_in = input_3D_dim
if li in opt.arch.skip: k_in += input_3D_dim
if li==len(L)-1: k_out += 1
linear = torch.nn.Linear(k_in,k_out)
if opt.arch.tf_init:
self.tensorflow_init_weights(opt,linear,out="first" if li==len(L)-1 else None)
self.mlp_feat.append(linear)
# RGB prediction
self.mlp_rgb = torch.nn.ModuleList()
L = util.get_layer_dims(opt.arch.layers_rgb)
feat_dim = opt.arch.layers_feat[-1]
for li,(k_in,k_out) in enumerate(L):
if li==0: k_in = feat_dim+(input_view_dim if opt.nerf.view_dep else 0)
linear = torch.nn.Linear(k_in,k_out)
if opt.arch.tf_init:
self.tensorflow_init_weights(opt,linear,out="all" if li==len(L)-1 else None)
self.mlp_rgb.append(linear)
def tensorflow_init_weights(self,opt,linear,out=None):
# use Xavier init instead of Kaiming init
relu_gain = torch.nn.init.calculate_gain("relu") # sqrt(2)
if out=="all":
torch.nn.init.xavier_uniform_(linear.weight)
elif out=="first":
torch.nn.init.xavier_uniform_(linear.weight[:1])
torch.nn.init.xavier_uniform_(linear.weight[1:],gain=relu_gain)
else:
torch.nn.init.xavier_uniform_(linear.weight,gain=relu_gain)
torch.nn.init.zeros_(linear.bias)
def forward(self,opt,points_3D,ray_unit=None,mode=None): # [B,...,3]
if opt.arch.posenc:
points_enc = self.positional_encoding(opt,points_3D,L=opt.arch.posenc.L_3D)
points_enc = torch.cat([points_3D,points_enc],dim=-1) # [B,...,6L+3]
else: points_enc = points_3D
feat = points_enc
# extract coordinate-based features
for li,layer in enumerate(self.mlp_feat):
if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1)
feat = layer(feat)
if li==len(self.mlp_feat)-1:
density = feat[...,0]
if opt.nerf.density_noise_reg and mode=="train":
density += torch.randn_like(density)*opt.nerf.density_noise_reg
density_activ = getattr(torch_F,opt.arch.density_activ) # relu_,abs_,sigmoid_,exp_....
density = density_activ(density)
feat = feat[...,1:]
feat = torch_F.relu(feat)
# predict RGB values
if opt.nerf.view_dep:
assert(ray_unit is not None)
if opt.arch.posenc:
ray_enc = self.positional_encoding(opt,ray_unit,L=opt.arch.posenc.L_view)
ray_enc = torch.cat([ray_unit,ray_enc],dim=-1) # [B,...,6L+3]
else: ray_enc = ray_unit
feat = torch.cat([feat,ray_enc],dim=-1)
for li,layer in enumerate(self.mlp_rgb):
feat = layer(feat)
if li!=len(self.mlp_rgb)-1:
feat = torch_F.relu(feat)
rgb = feat.sigmoid_() # [B,...,3]
return rgb,density
def forward_samples(self,opt,center,ray,depth_samples,mode=None):
points_3D_samples = camera.get_3D_points_from_depth(opt,center,ray,depth_samples,multi_samples=True) # [B,HW,N,3]
if opt.nerf.view_dep:
ray_unit = torch_F.normalize(ray,dim=-1) # [B,HW,3]
ray_unit_samples = ray_unit[...,None,:].expand_as(points_3D_samples) # [B,HW,N,3]
else: ray_unit_samples = None
rgb_samples,density_samples = self.forward(opt,points_3D_samples,ray_unit=ray_unit_samples,mode=mode) # [B,HW,N,3],[B,HW,N]
return rgb_samples,density_samples
def composite(self,opt,ray,rgb_samples,density_samples,depth_samples):
ray_length = ray.norm(dim=-1,keepdim=True) # [B,HW,1]
# volume rendering: compute probability (using quadrature)
depth_intv_samples = depth_samples[...,1:,0]-depth_samples[...,:-1,0] # [B,HW,N-1]
depth_intv_samples = torch.cat([depth_intv_samples,torch.empty_like(depth_intv_samples[...,:1]).fill_(1e10)],dim=2) # [B,HW,N]
dist_samples = depth_intv_samples*ray_length # [B,HW,N]
sigma_delta = density_samples*dist_samples # [B,HW,N]
alpha = 1-(-sigma_delta).exp_() # [B,HW,N]
T = (-torch.cat([torch.zeros_like(sigma_delta[...,:1]),sigma_delta[...,:-1]],dim=2).cumsum(dim=2)).exp_() # [B,HW,N]
prob = (T*alpha)[...,None] # [B,HW,N,1]
# integrate RGB and depth weighted by probability
depth = (depth_samples*prob).sum(dim=2) # [B,HW,1]
rgb = (rgb_samples*prob).sum(dim=2) # [B,HW,3]
opacity = prob.sum(dim=2) # [B,HW,1]
if opt.nerf.setbg_opaque:
rgb = rgb+opt.data.bgcolor*(1-opacity)
return rgb,depth,opacity,prob # [B,HW,K]
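# Illustrative sketch of the quadrature above (toy numbers, not from the repo):
# each sample contributes alpha_i = 1-exp(-sigma_i*delta_i), attenuated by the
# transmittance T_i = exp(-sum_{j<i} sigma_j*delta_j), giving weights w_i = T_i*alpha_i.
import torch
sigma = torch.tensor([0.0,1.0,5.0,0.5]) # densities at 4 samples along one ray
delta = torch.tensor([0.5,0.5,0.5,1e10]) # interval lengths (last one effectively infinite)
sigma_delta = sigma*delta
alpha = 1-(-sigma_delta).exp() # opacity contributed by each interval
T = (-torch.cat([torch.zeros(1),sigma_delta[:-1]]).cumsum(dim=0)).exp() # transmittance
weights = T*alpha # compositing weights; rgb/depth/opacity are weighted sums as above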
def positional_encoding(self,opt,input,L): # [B,...,N]
shape = input.shape
freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L]
spectrum = input[...,None]*freq # [B,...,N,L]
sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L]
input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L]
input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL]
return input_enc
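# Illustrative sketch of the positional encoding above (toy values): each input
# coordinate x is lifted to [sin(2^k*pi*x),cos(2^k*pi*x)] for k=0..L-1, and the
# raw coordinates are concatenated back in forward().
import numpy as np
import torch
x = torch.tensor([[0.25,-0.5,1.0]]) # [B,3] toy 3D point
L = 2
freq = 2**torch.arange(L,dtype=torch.float32)*np.pi # [L] = [pi,2*pi]
spectrum = x[...,None]*freq # [B,3,L]
enc = torch.stack([spectrum.sin(),spectrum.cos()],dim=-2) # [B,3,2,L]
enc = enc.view(1,-1) # [B,3*2*L] = [1,12]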
================================================
FILE: model/planar.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import tqdm
from easydict import EasyDict as edict
import PIL
import PIL.Image,PIL.ImageDraw
import imageio
import util,util_vis
from util import log,debug
from . import base
import warp
# ============================ main engine for training and evaluation ============================
class Model(base.Model):
def __init__(self,opt):
super().__init__(opt)
opt.H_crop,opt.W_crop = opt.data.patch_crop
def load_dataset(self,opt,eval_split=None):
image_raw = PIL.Image.open(opt.data.image_fname)
self.image_raw = torchvision_F.to_tensor(image_raw).to(opt.device)
def build_networks(self,opt):
super().build_networks(opt)
self.graph.warp_param = torch.nn.Embedding(opt.batch_size,opt.warp.dof).to(opt.device)
torch.nn.init.zeros_(self.graph.warp_param.weight)
def setup_optimizer(self,opt):
log.info("setting up optimizers...")
optim_list = [
dict(params=self.graph.neural_image.parameters(),lr=opt.optim.lr),
dict(params=self.graph.warp_param.parameters(),lr=opt.optim.lr_warp),
]
optimizer = getattr(torch.optim,opt.optim.algo)
self.optim = optimizer(optim_list)
# set up scheduler
if opt.optim.sched:
scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)
kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" }
self.sched = scheduler(self.optim,**kwargs)
def setup_visualizer(self,opt):
super().setup_visualizer(opt)
# set colors for visualization
box_colors = ["#ff0000","#40afff","#9314ff","#ffd700","#00ff00"]
box_colors = list(map(util.colorcode_to_number,box_colors))
self.box_colors = np.array(box_colors).astype(int)
assert(len(self.box_colors)==opt.batch_size)
# create visualization directory
self.vis_path = "{}/vis".format(opt.output_path)
os.makedirs(self.vis_path,exist_ok=True)
self.video_fname = "{}/vis.mp4".format(opt.output_path)
def train(self,opt):
# before training
log.title("TRAINING START")
self.timer = edict(start=time.time(),it_mean=None)
self.ep = self.it = self.vis_it = 0
self.graph.train()
var = edict(idx=torch.arange(opt.batch_size))
# pre-generate perturbations
self.homo_pert, self.rot_pert, self.trans_pert, var.image_pert = self.generate_warp_perturbation(opt)
# train
var = util.move_to_device(var,opt.device)
loader = tqdm.trange(opt.max_iter,desc="training",leave=False)
# visualize initial state
var = self.graph.forward(opt,var)
self.visualize(opt,var,step=0)
for it in loader:
# train iteration
loss = self.train_iteration(opt,var,loader)
if opt.warp.fix_first:
self.graph.warp_param.weight.data[0] = 0
# after training
os.system("ffmpeg -y -framerate 30 -i {}/%d.png -pix_fmt yuv420p {}".format(self.vis_path,self.video_fname))
self.save_checkpoint(opt,ep=None,it=self.it)
if opt.tb:
self.tb.flush()
self.tb.close()
if opt.visdom: self.vis.close()
log.title("TRAINING DONE")
def train_iteration(self,opt,var,loader):
loss = super().train_iteration(opt,var,loader)
self.graph.neural_image.progress.data.fill_(self.it/opt.max_iter)
return loss
def generate_warp_perturbation(self,opt):
# pre-generate perturbations (translational noise + homography noise)
def create_random_perturbation(batch_size):
if opt.warp.dof==1:
homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h
rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r
trans_pert = torch.zeros(batch_size,2,device=opt.device)*opt.warp.noise_t
elif opt.warp.dof==2:
homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h
rot_pert = torch.zeros(batch_size,1,device=opt.device)*opt.warp.noise_r
trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t
elif opt.warp.dof==3:
homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h
rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r
trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t
elif opt.warp.dof==8:
homo_pert = torch.randn(batch_size,8,device=opt.device)*opt.warp.noise_h
homo_pert[:,:2]=0
rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r
trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t
else: assert(False)
return homo_pert, rot_pert, trans_pert
homo_pert = torch.zeros(opt.batch_size,8,device=opt.device)
rot_pert = torch.zeros(opt.batch_size,1,device=opt.device)
trans_pert = torch.zeros(opt.batch_size,2,device=opt.device)
for i in range(opt.batch_size):
homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1)
while not warp.check_corners_in_range_compose(opt, homo_pert_i, rot_pert_i, trans_pert_i):
homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1)
homo_pert[i], rot_pert[i], trans_pert[i] = homo_pert_i, rot_pert_i, trans_pert_i
if opt.warp.fix_first:
homo_pert[0],rot_pert[0],trans_pert[0] = 0,0,0
# create warped image patches
xy_grid = warp.get_normalized_pixel_grid_crop(opt) # [B,HW,2]
xy_grid_hom = warp.camera.to_hom(xy_grid)
warp_matrix = warp.lie.sl3_to_SL3(homo_pert)
warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1)
xy_grid_warped = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2]
warp_matrix = warp.lie.so2_to_SO2(rot_pert)
xy_grid_warped = xy_grid_warped@warp_matrix.transpose(-2,-1) # [B,HW,2]
xy_grid_warped = xy_grid_warped+trans_pert[...,None,:]
xy_grid_warped = xy_grid_warped.view([opt.batch_size,opt.H_crop,opt.W_crop,2])
xy_grid_warped = torch.stack([xy_grid_warped[...,0]*max(opt.H,opt.W)/opt.W,
xy_grid_warped[...,1]*max(opt.H,opt.W)/opt.H],dim=-1)
image_raw_batch = self.image_raw.repeat(opt.batch_size,1,1,1)
image_pert_all = torch_F.grid_sample(image_raw_batch,xy_grid_warped,align_corners=False)
return homo_pert, rot_pert, trans_pert, image_pert_all
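# Illustrative sketch (toy example, independent of the warp module): the patch
# generation above boils down to warping a normalized sampling grid and
# resampling the raw image with grid_sample, e.g. for a pure translation:
import torch
import torch.nn.functional as torch_F
image = torch.rand(1,3,64,64) # [B,3,H,W] toy image
lin = torch.linspace(-0.5,0.5,32)
xy_grid = torch.stack([lin[None,:].expand(32,32),lin[:,None].expand(32,32)],dim=-1)[None] # [1,32,32,2], (x,y) in [-1,1]
trans = torch.tensor([0.1,-0.2]) # toy translational perturbation
patch = torch_F.grid_sample(image,xy_grid+trans,align_corners=False) # [1,3,32,32] warped crop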
def visualize_patches(self,opt,warp_param):
image_pil = torchvision_F.to_pil_image(self.image_raw).convert("RGBA")
draw_pil = PIL.Image.new("RGBA",image_pil.size,(0,0,0,0))
draw = PIL.ImageDraw.Draw(draw_pil)
corners_all = warp.warp_corners(opt,warp_param)
corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
for i,corners in enumerate(corners_all):
P = [tuple(float(n) for n in corners[j]) for j in range(4)]
draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3)
image_pil.alpha_composite(draw_pil)
image_tensor = torchvision_F.to_tensor(image_pil.convert("RGB"))
return image_tensor
def visualize_patches_compose(self,opt, homo_pert, rot_pert, trans_pert):
image_pil = torchvision_F.to_pil_image(self.image_raw).convert("RGBA")
draw_pil = PIL.Image.new("RGBA",image_pil.size,(0,0,0,0))
draw = PIL.ImageDraw.Draw(draw_pil)
corners_all = warp.warp_corners_compose(opt, homo_pert, rot_pert, trans_pert)
corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
for i,corners in enumerate(corners_all):
P = [tuple(float(n) for n in corners[j]) for j in range(4)]
draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3)
image_pil.alpha_composite(draw_pil)
image_tensor = torchvision_F.to_tensor(image_pil.convert("RGB"))
return image_tensor
@torch.no_grad()
def predict_entire_image(self,opt):
xy_grid = warp.get_normalized_pixel_grid(opt)[:1]
rgb = self.graph.neural_image.forward(opt,xy_grid) # [B,HW,3]
image = rgb.view(opt.H,opt.W,3).detach().cpu().permute(2,0,1)
return image
@torch.no_grad()
def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)
# compute PSNR
psnr = -10*loss.render.log10()
self.tb.add_scalar("{0}/{1}".format(split,"PSNR"),psnr,step)
# warp error
# pred_corners = warp.warp_corners_use_matrix(opt,self.graph.warp_matrix)
pred_corners = warp.warp_corners(opt,self.graph.warp_param.weight)
pred_corners[...,0] = (pred_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
pred_corners[...,1] = (pred_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
gt_corners = warp.warp_corners_compose(opt, self.homo_pert, self.rot_pert, self.trans_pert)
gt_corners[...,0] = (gt_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
gt_corners[...,1] = (gt_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
warp_error = (pred_corners-gt_corners).norm(dim=-1).mean()
self.tb.add_scalar("{0}/{1}".format(split,"warp error"),warp_error,step)
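# Worked example of the PSNR above (toy value, for illustration only): a render
# MSE of 1e-3 on [0,1] images gives PSNR = -10*log10(1e-3) = 30 dB.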
@torch.no_grad()
def visualize(self,opt,var,step=0,split="train"):
# dump frames for writing to video
frame_GT = self.visualize_patches_compose(opt,self.homo_pert, self.rot_pert, self.trans_pert)
frame = self.visualize_patches(opt,self.graph.warp_param.weight)
frame2 = self.predict_entire_image(opt)
frame_cat = (torch.cat([frame,frame2],dim=1)*255).byte().permute(1,2,0).numpy()
imageio.imsave("{}/{}.png".format(self.vis_path,self.vis_it),frame_cat)
self.vis_it += 1
# visualize in Tensorboard
if opt.tb:
colors = self.box_colors
util_vis.tb_image(opt,self.tb,step,split,"image_pert",util_vis.color_border(var.image_pert,colors))
util_vis.tb_image(opt,self.tb,step,split,"rgb_warped",util_vis.color_border(var.rgb_warped_map,colors))
util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes",frame[None])
util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes_GT",frame_GT[None])
util_vis.tb_image(opt,self.tb,self.it+1,"train","image_entire",frame2[None])
# ============================ computation graph for forward/backprop ============================
class Graph(base.Graph):
def __init__(self,opt):
super().__init__(opt)
self.neural_image = NeuralImageFunction(opt)
def forward(self,opt,var,mode=None):
xy_grid = warp.get_normalized_pixel_grid_crop(opt)
xy_grid_warped = warp.warp_grid(opt,xy_grid,self.warp_param.weight)
# render images
var.rgb_warped = self.neural_image.forward(opt,xy_grid_warped) # [B,HW,3]
var.rgb_warped_map = var.rgb_warped.view(opt.batch_size,opt.H_crop,opt.W_crop,3).permute(0,3,1,2) # [B,3,H,W]
return var
def compute_loss(self,opt,var,mode=None):
loss = edict()
if opt.loss_weight.render is not None:
image_pert = var.image_pert.view(opt.batch_size,3,opt.H_crop*opt.W_crop).permute(0,2,1)
loss.render = self.MSE_loss(var.rgb_warped,image_pert)
return loss
class NeuralImageFunction(torch.nn.Module):
def __init__(self,opt):
super().__init__()
self.define_network(opt)
self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed
def define_network(self,opt):
input_2D_dim = 2+4*opt.arch.posenc.L_2D if opt.arch.posenc else 2
# point-wise RGB prediction
self.mlp = torch.nn.ModuleList()
L = util.get_layer_dims(opt.arch.layers)
for li,(k_in,k_out) in enumerate(L):
if li==0: k_in = input_2D_dim
if li in opt.arch.skip: k_in += input_2D_dim
linear = torch.nn.Linear(k_in,k_out)
if opt.barf_c2f and li==0:
# rescale first-layer init (the initialization assumed a positional-encoding input, but only the raw xy coordinates are active at the start of coarse-to-fine)
scale = np.sqrt(input_2D_dim/2.)
linear.weight.data *= scale
linear.bias.data *= scale
self.mlp.append(linear)
def forward(self,opt,coord_2D): # [B,...,2]
if opt.arch.posenc:
points_enc = self.positional_encoding(opt,coord_2D,L=opt.arch.posenc.L_2D)
points_enc = torch.cat([coord_2D,points_enc],dim=-1) # [B,...,4L+2]
else: points_enc = coord_2D
feat = points_enc
# extract implicit features
for li,layer in enumerate(self.mlp):
if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1)
feat = layer(feat)
if li!=len(self.mlp)-1:
feat = torch_F.relu(feat)
rgb = feat.sigmoid_() # [B,...,3]
return rgb
def positional_encoding(self,opt,input,L): # [B,...,N]
shape = input.shape
freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L]
spectrum = input[...,None]*freq # [B,...,N,L]
sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L]
input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L]
input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL]
# coarse-to-fine: smoothly mask positional encoding for BARF
if opt.barf_c2f is not None:
# set weights for different frequency bands
start,end = opt.barf_c2f
alpha = (self.progress.data-start)/(end-start)*L
k = torch.arange(L,dtype=torch.float32,device=opt.device)
weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2
# apply weights
shape = input_enc.shape
input_enc = (input_enc.view(-1,L)*weight).view(*shape)
return input_enc
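# Illustrative sketch of the coarse-to-fine schedule above (toy values matching
# planar.yaml's barf_c2f=[0,0.4] and L_2D=8): frequency band k gets weight
# (1-cos(pi*clamp(alpha-k,0,1)))/2 with alpha = (progress-start)/(end-start)*L,
# so low frequencies are enabled first and higher ones blend in smoothly.
import numpy as np
import torch
L,start,end,progress = 8,0.0,0.4,0.2
alpha = (progress-start)/(end-start)*L # = 4.0 halfway through the schedule
k = torch.arange(L,dtype=torch.float32)
weight = (1-((alpha-k).clamp(min=0,max=1)*np.pi).cos())/2 # = [1,1,1,1,0,0,0,0]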
================================================
FILE: options/barf_blender.yaml
================================================
_parent_: options/nerf_blender.yaml
barf_c2f: [0.1,0.5] # coarse-to-fine scheduling on positional encoding
camera: # camera options
noise: true # synthetic perturbations on the camera poses (Blender only)
noise_r: 0.03 # synthetic perturbations on the camera poses (Blender only)
noise_t: 0.3 # synthetic perturbations on the camera poses (Blender only)
optim: # optimization options
lr_pose: 1.e-3 # learning rate of camera poses
lr_pose_end: 1.e-5 # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)
sched_pose: # learning rate scheduling options
type: ExponentialLR # scheduler (see PyTorch doc)
gamma: # decay rate (can be empty if lr_pose_end were specified)
warmup_pose: # linear warmup of the pose learning rate (N iterations)
test_photo: true # test-time photometric optimization for evaluation
test_iter: 100 # number of iterations for test-time optimization
visdom: # Visdom options
cam_depth: 0.5 # size of visualized cameras
================================================
FILE: options/barf_iphone.yaml
================================================
_parent_: options/barf_llff.yaml
data: # data options
dataset: iphone # dataset name
scene: IMG_0239 # scene name
image_size: [480,640] # input image sizes [height,width]
================================================
FILE: options/barf_llff.yaml
================================================
_parent_: options/nerf_llff.yaml
barf_c2f: [0.1,0.5] # coarse-to-fine scheduling on positional encoding
camera: # camera options
noise: # synthetic perturbations on the camera poses (Blender only)
optim: # optimization options
lr_pose: 3.e-3 # learning rate of camera poses
lr_pose_end: 1.e-5 # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)
sched_pose: # learning rate scheduling options
type: ExponentialLR # scheduler (see PyTorch doc)
gamma: # decay rate (can be empty if lr_pose_end were specified)
warmup_pose: # linear warmup of the pose learning rate (N iterations)
test_photo: true # test-time photometric optimization for evaluation
test_iter: 100 # number of iterations for test-time optimization
visdom: # Visdom options
cam_depth: 0.2 # size of visualized cameras
================================================
FILE: options/base.yaml
================================================
# default
group: 0_test # name of experiment group
name: debug # name of experiment run
model: # type of model (must be specified from command line)
yaml: # config file (must be specified from command line)
seed: 0 # seed number (for both numpy and pytorch)
gpu: 0 # GPU index number
cpu: false # run only on CPU (not supported now)
load: # load checkpoint from filename
arch: {} # architectural options
data: # data options
root: # root path to dataset
dataset: # dataset name
image_size: [null,null] # input image sizes [height,width]
num_workers: 8 # number of parallel workers for data loading
preload: false # preload the entire dataset into the memory
augment: {} # data augmentation (training only)
# rotate: # random rotation
# brightness: # 0.2 # random brightness jitter
# contrast: # 0.2 # random contrast jitter
# saturation: # 0.2 # random saturation jitter
# hue: # 0.1 # random hue jitter
# hflip: # True # random horizontal flip
center_crop: # center crop the image by ratio
val_on_test: false # validate on test set during training
train_sub: # consider a subset of N training samples
val_sub: # consider a subset of N validation samples
loss_weight: {} # loss weights (in log scale)
optim: # optimization options
lr: 1.e-3 # learning rate (main)
lr_end: # terminal learning rate (only used with sched.type=ExponentialLR)
algo: Adam # optimizer (see PyTorch doc)
sched: {} # learning rate scheduling options
# type: StepLR # scheduler (see PyTorch doc)
# steps: # decay every N epochs
# gamma: 0.1 # decay rate (can be empty if lr_end were specified)
batch_size: 16 # batch size
max_epoch: 1000 # train to maximum number of epochs
resume: false # resume training (true for latest checkpoint, or number for specific epoch number)
output_root: output # root path for output files (checkpoints and results)
tb: # TensorBoard options
num_images: [4,8] # number of (tiled) images to visualize in TensorBoard
visdom: # Visdom options
server: localhost # server to host Visdom
port: 8600 # port number for Visdom
freq: # periodic actions during training
scalar: 200 # log losses and scalar states (every N iterations)
vis: 1000 # visualize results (every N iterations)
val: 20 # validate on val set (every N epochs)
ckpt: 50 # save checkpoint (every N epochs)
================================================
FILE: options/l2g_nerf_blender.yaml
================================================
_parent_: options/barf_blender.yaml
arch: # architectural options
layers_warp: [null,256,256,256,256,256,256,6] # hidden layers for MLP
skip_warp: [4] # skip connections
embedding_dim: 128 # embedding dim
optim: # optimization options
lr_pose: 1.e-3 # learning rate of camera poses
lr_pose_end: 1.e-8 # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)
loss_weight: # loss weights (in log scale)
global_alignment: 2 # global alignment loss
error_map_size:
================================================
FILE: options/l2g_nerf_iphone.yaml
================================================
_parent_: options/l2g_nerf_llff.yaml
data: # data options
dataset: iphone # dataset name
scene: IMG_0239 # scene name
image_size: [480,640] # input image sizes [height,width]
optim: # optimization options
test_iter: 1000 # number of iterations for test-time optimization
================================================
FILE: options/l2g_nerf_llff.yaml
================================================
_parent_: options/barf_llff.yaml
arch: # architectural options
layers_warp: [null,256,256,256,256,256,256,6] # hidden layers for MLP
skip_warp: [4] # skip connections
embedding_dim: 128 # embedding dim
optim: # optimization options
lr_pose: 3.e-3 # learning rate of camera poses
lr_pose_end: 1.e-8 # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)
loss_weight: # loss weights (in log scale)
global_alignment: 2 # global alignment loss
error_map_size:
================================================
FILE: options/l2g_planar.yaml
================================================
_parent_: options/planar.yaml
arch: # architectural options
layers_warp: [null,256,256,256,256,256,256,3] # hidden layers for MLP
skip_warp: [4] # skip connections
embedding_dim: 128 # embedding dim
loss_weight: # loss weights (in log scale)
global_alignment: 2 # global alignment loss
optim:
lr: 1.e-3 # learning rate (main)
lr_end: 1.e-4 # terminal learning rate (only used with sched.type=ExponentialLR)
sched: # learning rate scheduling options
type: ExponentialLR # scheduler (see PyTorch doc)
gamma: # decay rate (can be empty if lr_end were specified)
lr_warp: 1.e-3 # learning rate of warp parameters
lr_warp_end: 1.e-5 # terminal learning rate of warp parameters (only used with sched_warp.type=ExponentialLR)
sched_warp: # learning rate scheduling options
type: ExponentialLR # scheduler (see PyTorch doc)
gamma: # decay rate (can be empty if lr_warp_end were specified)
================================================
FILE: options/nerf_blender.yaml
================================================
_parent_: options/base.yaml
arch: # architectural options
layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP
layers_rgb: [null,128,3] # hidden layers for color MLP
skip: [4] # skip connections
posenc: # positional encoding
L_3D: 10 # number of bases (3D point)
L_view: 4 # number of bases (viewpoint)
density_activ: softplus # activation function for output volume density
tf_init: true # initialize network weights in TensorFlow style
nerf: # NeRF-specific options
view_dep: true # condition MLP on viewpoint
depth: # depth-related options
param: metric # depth parametrization (for sampling along the ray)
range: [2,6] # near/far bounds for depth sampling
sample_intvs: 128 # number of samples
sample_stratified: true # stratified sampling
fine_sampling: false # hierarchical sampling with another NeRF
sample_intvs_fine: # number of samples for the fine NeRF
rand_rays: 1024 # number of random rays for each step
density_noise_reg: # Gaussian noise on density output as regularization
setbg_opaque: false # fill transparent rendering with known background color (Blender only)
data: # data options
dataset: blender # dataset name
scene: lego # scene name
image_size: [400,400] # input image sizes [height,width]
num_workers: 4 # number of parallel workers for data loading
preload: true # preload the entire dataset into the memory
bgcolor: 1 # background color (Blender only)
val_sub: 4 # consider a subset of N validation samples
camera: # camera options
model: perspective # type of camera model
ndc: false # reparametrize as normalized device coordinates (NDC)
loss_weight: # loss weights (in log scale)
render: 0 # RGB rendering loss
render_fine: # RGB rendering loss (for fine NeRF)
optim: # optimization options
lr: 5.e-4 # learning rate (main)
lr_end: 1.e-4 # terminal learning rate (only used with sched.type=ExponentialLR)
sched: # learning rate scheduling options
type: ExponentialLR # scheduler (see PyTorch doc)
gamma: # decay rate (can be empty if lr_end were specified)
batch_size: # batch size (not used for NeRF/BARF)
max_epoch: # train to maximum number of epochs (not used for NeRF/BARF)
max_iter: 200000 # train to maximum number of iterations
trimesh: # options for marching cubes to extract 3D mesh
res: 128 # 3D sampling resolution
range: [-1.2,1.2] # 3D range of interest (assuming same for x,y,z)
thres: 25. # volume density threshold for marching cubes
chunk_size: 16384 # chunk size of dense samples to be evaluated at a time
freq: # periodic actions during training
scalar: 200 # log losses and scalar states (every N iterations)
vis: 1000 # visualize results (every N iterations)
val: 2000 # validate on val set (every N iterations)
ckpt: 5000 # save checkpoint (every N iterations)
================================================
FILE: options/nerf_blender_repr.yaml
================================================
_parent_: options/base.yaml
arch: # architectural options
layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP
layers_rgb: [null,128,3] # hidden layers for color MLP
skip: [4] # skip connections
posenc: # positional encoding
L_3D: 10 # number of bases (3D point)
L_view: 4 # number of bases (viewpoint)
density_activ: relu # activation function for output volume density
tf_init: true # initialize network weights in TensorFlow style
nerf: # NeRF-specific options
view_dep: true # condition MLP on viewpoint
depth: # depth-related options
param: metric # depth parametrization (for sampling along the ray)
range: [2,6] # near/far bounds for depth sampling
sample_intvs: 64 # number of samples
sample_stratified: true # stratified sampling
fine_sampling: true # hierarchical sampling with another NeRF
sample_intvs_fine: 128 # number of samples for the fine NeRF
rand_rays: 1024 # number of random rays for each step
density_noise_reg: 0 # Gaussian noise on density output as regularization
setbg_opaque: true # fill transparent rendering with known background color (Blender only)
data: # data options
dataset: blender # dataset name
scene: lego # scene name
image_size: [400,400] # input image sizes [height,width]
num_workers: 4 # number of parallel workers for data loading
preload: true # preload the entire dataset into the memory
bgcolor: 1 # background color (Blender only)
val_sub: 4 # consider a subset of N validation samples
camera: # camera options
model: perspective # type of camera model
ndc: false # reparametrize as normalized device coordinates (NDC)
loss_weight: # loss weights (in log scale)
render: 0 # RGB rendering loss
render_fine: 0 # RGB rendering loss (for fine NeRF)
optim: # optimization options
lr: 5.e-4 # learning rate (main)
lr_end: 5.e-5 # terminal learning rate (only used with sched.type=ExponentialLR)
sched: # learning rate scheduling options
type: ExponentialLR # scheduler (see PyTorch doc)
gamma: # decay rate (can be empty if lr_end were specified)
batch_size: # batch size (not used for NeRF/BARF)
max_epoch: # train to maximum number of epochs (not used for NeRF/BARF)
max_iter: 500000 # train to maximum number of iterations
trimesh: # options for marching cubes to extract 3D mesh
res: 128 # 3D sampling resolution
range: [-1.2,1.2] # 3D range of interest (assuming same for x,y,z)
thres: 25. # volume density threshold for marching cubes
chunk_size: 16384 # chunk size of dense samples to be evaluated at a time
freq: # periodic actions during training
scalar: 200 # log losses and scalar states (every N iterations)
vis: 1000 # visualize results (every N iterations)
val: 2000 # validate on val set (every N iterations)
ckpt: 5000 # save checkpoint (every N iterations)
================================================
FILE: options/nerf_llff.yaml
================================================
_parent_: options/base.yaml
arch: # architectural options
layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP
layers_rgb: [null,128,3] # hidden layers for color MLP
skip: [4] # skip connections
posenc: # positional encoding
L_3D: 10 # number of bases (3D point)
L_view: 4 # number of bases (viewpoint)
density_activ: softplus # activation function for output volume density
tf_init: true # initialize network weights in TensorFlow style
nerf: # NeRF-specific options
view_dep: true # condition MLP on viewpoint
depth: # depth-related options
param: inverse # depth parametrization (for sampling along the ray)
range: [1,0] # near/far bounds for depth sampling
sample_intvs: 128 # number of samples
sample_stratified: true # stratified sampling
fine_sampling: false # hierarchical sampling with another NeRF
sample_intvs_fine: # number of samples for the fine NeRF
rand_rays: 2048 # number of random rays for each step
density_noise_reg: # Gaussian noise on density output as regularization
setbg_opaque: # fill transparent rendering with known background color (Blender only)
data: # data options
dataset: llff # dataset name
scene: fern # scene name
image_size: [480,640] # input image sizes [height,width]
num_workers: 4 # number of parallel workers for data loading
preload: true # preload the entire dataset into the memory
val_ratio: 0.1 # ratio of sequence split for validation
camera: # camera options
model: perspective # type of camera model
ndc: false # reparametrize as normalized device coordinates (NDC)
loss_weight: # loss weights (in log scale)
render: 0 # RGB rendering loss
render_fine: # RGB rendering loss (for fine NeRF)
optim: # optimization options
lr: 1.e-3 # learning rate (main)
lr_end: 1.e-4 # terminal learning rate (only used with sched.type=ExponentialLR)
sched: # learning rate scheduling options
type: ExponentialLR # scheduler (see PyTorch doc)
gamma: # decay rate (can be empty if lr_end were specified)
batch_size: # batch size (not used for NeRF/BARF)
max_epoch: # train to maximum number of epochs (not used for NeRF/BARF)
max_iter: 200000 # train to maximum number of iterations
freq: # periodic actions during training
scalar: 200 # log losses and scalar states (every N iterations)
vis: 1000 # visualize results (every N iterations)
val: 2000 # validate on val set (every N iterations)
ckpt: 5000 # save checkpoint (every N iterations)
================================================
FILE: options/nerf_llff_repr.yaml
================================================
_parent_: options/base.yaml
arch: # architectural options
layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP
layers_rgb: [null,128,3] # hidden layers for color MLP
skip: [4] # skip connections
posenc: # positional encoding
L_3D: 10 # number of bases (3D point)
L_view: 4 # number of bases (viewpoint)
density_activ: relu # activation function for output volume density
tf_init: true # initialize network weights in TensorFlow style
nerf: # NeRF-specific options
view_dep: true # condition MLP on viewpoint
depth: # depth-related options
param: metric # depth parametrization (for sampling along the ray)
range: [0,1] # near/far bounds for depth sampling
sample_intvs: 64 # number of samples
sample_stratified: true # stratified sampling
fine_sampling: true # hierarchical sampling with another NeRF
sample_intvs_fine: 128 # number of samples for the fine NeRF
rand_rays: 1024 # number of random rays for each step
density_noise_reg: 1 # Gaussian noise on density output as regularization
setbg_opaque: # fill transparent rendering with known background color (Blender only)
data: # data options
dataset: llff # dataset name
scene: fern # scene name
image_size: [480,640] # input image sizes [height,width]
num_workers: 4 # number of parallel workers for data loading
preload: true # preload the entire dataset into the memory
val_ratio: 0.1 # ratio of sequence split for validation
camera: # camera options
model: perspective # type of camera model
ndc: false # reparametrize as normalized device coordinates (NDC)
loss_weight: # loss weights (in log scale)
render: 0 # RGB rendering loss
render_fine: 0 # RGB rendering loss (for fine NeRF)
optim: # optimization options
lr: 5.e-4 # learning rate (main)
lr_end: 5.e-5 # terminal learning rate (only used with sched.type=ExponentialLR)
sched: # learning rate scheduling options
type: ExponentialLR # scheduler (see PyTorch doc)
gamma: # decay rate (can be empty if lr_end were specified)
batch_size: # batch size (not used for NeRF/BARF)
max_epoch: # train to maximum number of epochs (not used for NeRF/BARF)
max_iter: 500000 # train to maximum number of iterations
freq: # periodic actions during training
scalar: 200 # log losses and scalar states (every N iterations)
vis: 1000 # visualize results (every N iterations)
val: 2000 # validate on val set (every N iterations)
ckpt: 5000 # save checkpoint (every N iterations)
================================================
FILE: options/planar.yaml
================================================
_parent_: options/base.yaml
arch: # architectural options
layers: [null,256,256,256,256,3] # hidden layers for MLP
skip: [] # skip connections
posenc: # positional encoding
L_2D: 8 # number of bases (2D point)
barf_c2f: [0,0.4] # coarse-to-fine scheduling on positional encoding
data: # data options
image_fname: data/cat.jpg # path to image file
image_size: [360,480] # original image size
patch_crop: [180,180] # crop size of image patches to align
warp: # image warping options
type: homography # type of warp function
dof: 8 # degrees of freedom of the warp function
noise_h: 0.3 # scale of pre-generated warp perturbation (homography)
noise_t: 0.1 # scale of pre-generated warp perturbation (translation)
noise_r: 1.0 # scale of pre-generated warp perturbation (rotation)
fix_first: true # fix the first patch for uniqueness of solution
loss_weight: # loss weights (in log scale)
render: 0 # RGB rendering loss
optim: # optimization options
lr: 1.e-3 # learning rate (main)
lr_warp: 1.e-3 # learning rate of warp parameters
batch_size: 5 # batch size (set to number of patches to consider)
max_iter: 5000 # train to maximum number of iterations
visdom: # Visdom options (turned off)
freq: # periodic actions during training
scalar: 20 # log losses and scalar states (every N iterations)
vis: 100 # visualize results (every N iterations)
================================================
FILE: options.py
================================================
import numpy as np
import os,sys,time
import torch
import random
import string
import yaml
from easydict import EasyDict as edict
import util
from util import log
# torch.backends.cudnn.enabled = False
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True
def parse_arguments(args):
"""
Parse arguments from command line.
Syntax: --key1.key2.key3=value --> value
--key1.key2.key3= --> None
--key1.key2.key3 --> True
--key1.key2.key3! --> False
"""
opt_cmd = {}
for arg in args:
assert(arg.startswith("--"))
if "=" not in arg[2:]:
key_str,value = (arg[2:-1],"false") if arg[-1]=="!" else (arg[2:],"true")
else:
key_str,value = arg[2:].split("=")
keys_sub = key_str.split(".")
opt_sub = opt_cmd
for k in keys_sub[:-1]:
if k not in opt_sub: opt_sub[k] = {}
opt_sub = opt_sub[k]
assert keys_sub[-1] not in opt_sub,keys_sub[-1]
opt_sub[keys_sub[-1]] = yaml.safe_load(value)
opt_cmd = edict(opt_cmd)
return opt_cmd
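# Usage sketch for the syntax documented above (hypothetical arguments, shown
# only for illustration):
_example = parse_arguments(["--data.scene=fern","--optim.lr=1.e-3","--nerf.view_dep!","--resume"])
# _example == {"data":{"scene":"fern"},"optim":{"lr":0.001},"nerf":{"view_dep":False},"resume":True}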
def set(opt_cmd={}):
log.info("setting configurations...")
assert("model" in opt_cmd)
# load config from yaml file
assert("yaml" in opt_cmd)
fname = "options/{}.yaml".format(opt_cmd.yaml)
opt_base = load_options(fname)
# override with command line arguments
opt = override_options(opt_base,opt_cmd,key_stack=[],safe_check=True)
process_options(opt)
log.options(opt)
return opt
def load_options(fname):
with open(fname) as file:
opt = edict(yaml.safe_load(file))
if "_parent_" in opt:
# load parent yaml file(s) as base options
parent_fnames = opt.pop("_parent_")
if type(parent_fnames) is str:
parent_fnames = [parent_fnames]
for parent_fname in parent_fnames:
opt_parent = load_options(parent_fname)
opt_parent = override_options(opt_parent,opt,key_stack=[])
opt = opt_parent
print("loading {}...".format(fname))
return opt
def override_options(opt,opt_over,key_stack=None,safe_check=False):
for key,value in opt_over.items():
if isinstance(value,dict):
# parse child options (until leaf nodes are reached)
opt[key] = override_options(opt.get(key,dict()),value,key_stack=key_stack+[key],safe_check=safe_check)
else:
# ensure command line argument to override is also in yaml file
if safe_check and key not in opt:
add_new = None
while add_new not in ["y","n"]:
key_str = ".".join(key_stack+[key])
add_new = input("\"{}\" not found in original opt, add? (y/n) ".format(key_str))
if add_new=="n":
print("safe exiting...")
exit()
opt[key] = value
return opt
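# Minimal sketch of the recursive merge above (toy dicts, safe_check disabled):
# leaves present in opt_over replace those in opt, while untouched branches are kept.
_base = edict({"optim":{"lr":1.e-3,"algo":"Adam"},"seed":0})
_over = edict({"optim":{"lr":5.e-4}})
_merged = override_options(_base,_over,key_stack=[])
# _merged == {"optim":{"lr":0.0005,"algo":"Adam"},"seed":0}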
def process_options(opt):
# set seed
if opt.seed is not None:
random.seed(opt.seed)
np.random.seed(opt.seed)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed_all(opt.seed)
if opt.seed!=0:
opt.name = str(opt.name)+"_seed{}".format(opt.seed)
else:
# create random string as run ID
randkey = "".join(random.choice(string.ascii_uppercase) for _ in range(4))
opt.name = str(opt.name)+"_{}".format(randkey)
# other default options
opt.output_path = "{0}/{1}/{2}".format(opt.output_root,opt.group,opt.name)
os.makedirs(opt.output_path,exist_ok=True)
assert(isinstance(opt.gpu,int)) # disable multi-GPU support for now, single is enough
opt.device = "cpu" if opt.cpu or not torch.cuda.is_available() else "cuda:{}".format(opt.gpu)
opt.H,opt.W = opt.data.image_size
def save_options_file(opt):
opt_fname = "{}/options.yaml".format(opt.output_path)
if os.path.isfile(opt_fname):
with open(opt_fname) as file:
opt_old = yaml.safe_load(file)
if opt!=opt_old:
# prompt if options are not identical
opt_new_fname = "{}/options_temp.yaml".format(opt.output_path)
with open(opt_new_fname,"w") as file:
yaml.safe_dump(util.to_dict(opt),file,default_flow_style=False,indent=4)
print("existing options file found (different from current one)...")
os.system("diff {} {}".format(opt_fname,opt_new_fname))
os.system("rm {}".format(opt_new_fname))
override = None
while override not in ["y","n"]:
override = input("override? (y/n) ")
if override=="n":
print("safe exiting...")
exit()
else: print("existing options file found (identical)")
else: print("(creating new options file...)")
with open(opt_fname,"w") as file:
yaml.safe_dump(util.to_dict(opt),file,default_flow_style=False,indent=4)
================================================
FILE: requirements.yaml
================================================
name: L2G-NeRF
channels:
- conda-forge
- pytorch
dependencies:
- numpy
- scipy
- tqdm
- termcolor
- easydict
- imageio
- ipdb
- pytorch>=1.9.0
- torchvision
- tensorboard
- visdom
- matplotlib
- scikit-video
- trimesh
- pyyaml
- pip
- gdown
- pip:
- lpips
- pymcubes
- roma
- kornia
================================================
FILE: train.py
================================================
import numpy as np
import os,sys,time
import torch
import importlib
import options
from util import log
import warnings
warnings.filterwarnings('ignore')
def main():
log.process(os.getpid())
log.title("[{}] (PyTorch code for training NeRF/BARF/L2G_NeRF)".format(sys.argv[0]))
opt_cmd = options.parse_arguments(sys.argv[1:])
opt = options.set(opt_cmd=opt_cmd)
options.save_options_file(opt)
with torch.cuda.device(opt.device):
model = importlib.import_module("model.{}".format(opt.model))
m = model.Model(opt)
m.load_dataset(opt)
m.build_networks(opt)
m.setup_optimizer(opt)
m.restore_checkpoint(opt)
m.setup_visualizer(opt)
m.train(opt)
if __name__=="__main__":
main()
================================================
FILE: util.py
================================================
import numpy as np
import os,sys,time
import shutil
import datetime
import torch
import torch.nn.functional as torch_F
import ipdb
import types
import termcolor
import socket
import contextlib
from easydict import EasyDict as edict
# convert to colored strings
def red(message,**kwargs): return termcolor.colored(str(message),color="red",attrs=[k for k,v in kwargs.items() if v is True])
def green(message,**kwargs): return termcolor.colored(str(message),color="green",attrs=[k for k,v in kwargs.items() if v is True])
def blue(message,**kwargs): return termcolor.colored(str(message),color="blue",attrs=[k for k,v in kwargs.items() if v is True])
def cyan(message,**kwargs): return termcolor.colored(str(message),color="cyan",attrs=[k for k,v in kwargs.items() if v is True])
def yellow(message,**kwargs): return termcolor.colored(str(message),color="yellow",attrs=[k for k,v in kwargs.items() if v is True])
def magenta(message,**kwargs): return termcolor.colored(str(message),color="magenta",attrs=[k for k,v in kwargs.items() if v is True])
def grey(message,**kwargs): return termcolor.colored(str(message),color="grey",attrs=[k for k,v in kwargs.items() if v is True])
def get_time(sec):
d = int(sec//(24*60*60))
h = int(sec//(60*60)%24)
m = int((sec//60)%60)
s = int(sec%60)
return d,h,m,s
def add_datetime(func):
def wrapper(*args,**kwargs):
datetime_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(grey("[{}] ".format(datetime_str),bold=True),end="")
return func(*args,**kwargs)
return wrapper
def add_functionname(func):
def wrapper(*args,**kwargs):
print(grey("[{}] ".format(func.__name__),bold=True))
return func(*args,**kwargs)
return wrapper
def pre_post_actions(pre=None,post=None):
def func_decorator(func):
def wrapper(*args,**kwargs):
if pre: pre()
retval = func(*args,**kwargs)
if post: post()
return retval
return wrapper
return func_decorator
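# Usage sketch for the decorators above (hypothetical function, added for
# illustration): add_datetime prefixes each call with a timestamp, and
# pre_post_actions runs the given callables around the wrapped function.
@add_datetime
@pre_post_actions(pre=lambda: print("start"),post=lambda: print("done"))
def hypothetical_step():
    print("running one step")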
SYMBOL INDEX (296 symbols across 18 files)
FILE: camera.py
class Pose (line 11) | class Pose():
method __call__ (line 17) | def __call__(self,R=None,t=None):
method invert (line 36) | def invert(self,pose,use_inverse=False):
method compose (line 44) | def compose(self,pose_list):
method compose_pair (line 52) | def compose_pair(self,pose_a,pose_b):
class Lie (line 61) | class Lie():
method so3_to_SO3 (line 66) | def so3_to_SO3(self,w): # [...,3]
method SO3_to_so3 (line 75) | def SO3_to_so3(self,R,eps=1e-7): # [...,3,3]
method se3_to_SE3 (line 83) | def se3_to_SE3(self,wu): # [...,3]
method SE3_to_se3 (line 96) | def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4]
method skew_symmetric (line 109) | def skew_symmetric(self,w):
method taylor_A (line 117) | def taylor_A(self,x,nth=10):
method taylor_B (line 125) | def taylor_B(self,x,nth=10):
method taylor_C (line 133) | def taylor_C(self,x,nth=10):
class Quaternion (line 142) | class Quaternion():
method q_to_R (line 144) | def q_to_R(self,q):
method R_to_q (line 152) | def R_to_q(self,R,eps=1e-8): # [B,3,3]
method invert (line 178) | def invert(self,q):
method product (line 184) | def product(self,q1,q2): # [B,4]
function to_hom (line 197) | def to_hom(X):
function world2cam (line 203) | def world2cam(X,pose): # [B,N,3]
function cam2img (line 206) | def cam2img(X,cam_intr):
function img2cam (line 208) | def img2cam(X,cam_intr):
function cam2world (line 210) | def cam2world(X,pose):
function angle_to_rotation_matrix (line 215) | def angle_to_rotation_matrix(a,axis):
function get_center_and_ray (line 226) | def get_center_and_ray(opt,pose,intr=None): # [HW,2]
function get_camera_cords_grid_3D (line 246) | def get_camera_cords_grid_3D(opt,batch_size,intr=None,ray_idx=None): # [...
function gather_camera_cords_grid_3D (line 263) | def gather_camera_cords_grid_3D(opt,batch_size,intr=None,ray_idx=None): ...
function get_3D_points_from_depth (line 280) | def get_3D_points_from_depth(opt,center,ray,depth,multi_samples=False):
function convert_NDC (line 286) | def convert_NDC(opt,center,ray,intr,near=1):
function rotation_distance (line 305) | def rotation_distance(R1,R2,eps=1e-7):
function procrustes_analysis (line 312) | def procrustes_analysis(X0,X1): # [N,3]
function get_novel_view_poses (line 331) | def get_novel_view_poses(opt,pose_anchor,N=60,scale=1):
FILE: data/base.py
class Dataset (line 16) | class Dataset(torch.utils.data.Dataset):
method __init__ (line 18) | def __init__(self,opt,split="train"):
method setup_loader (line 31) | def setup_loader(self,opt,shuffle=False,drop_last=False):
method get_list (line 42) | def get_list(self,opt):
method preload_worker (line 45) | def preload_worker(self,data_list,load_func,q,lock,idx_tqdm):
method preload_threading (line 53) | def preload_threading(self,opt,load_func,data_str="images"):
method __getitem__ (line 68) | def __getitem__(self,idx):
method get_image (line 71) | def get_image(self,opt,idx):
method generate_augmentation (line 74) | def generate_augmentation(self,opt):
method preprocess_image (line 92) | def preprocess_image(self,opt,image,aug=None):
method preprocess_camera (line 109) | def preprocess_camera(self,opt,intr,pose,aug=None):
method apply_color_jitter (line 119) | def apply_color_jitter(self,opt,image,color_jitter):
method __len__ (line 129) | def __len__(self):
FILE: data/blender.py
class Dataset (line 17) | class Dataset(base.Dataset):
method __init__ (line 19) | def __init__(self,opt,split="train",subset=None):
method prefetch_all_data (line 36) | def prefetch_all_data(self,opt):
method get_all_camera_poses (line 41) | def get_all_camera_poses(self,opt):
method __getitem__ (line 46) | def __getitem__(self,idx):
method get_image (line 61) | def get_image(self,opt,idx):
method preprocess_image (line 66) | def preprocess_image(self,opt,image,aug=None):
method get_camera (line 73) | def get_camera(self,opt,idx):
method parse_raw_camera (line 81) | def parse_raw_camera(self,opt,pose_raw):
FILE: data/iphone.py
class Dataset (line 17) | class Dataset(base.Dataset):
method __init__ (line 19) | def __init__(self,opt,split="train",subset=None):
method prefetch_all_data (line 36) | def prefetch_all_data(self,opt):
method get_all_camera_poses (line 41) | def get_all_camera_poses(self,opt):
method __getitem__ (line 45) | def __getitem__(self,idx):
method get_image (line 60) | def get_image(self,opt,idx):
method get_camera (line 65) | def get_camera(self,opt,idx):
FILE: data/llff.py
class Dataset (line 17) | class Dataset(base.Dataset):
method __init__ (line 19) | def __init__(self,opt,split="train",subset=None):
method prefetch_all_data (line 37) | def prefetch_all_data(self,opt):
method parse_cameras_and_bounds (line 42) | def parse_cameras_and_bounds(self,opt):
method center_camera_poses (line 60) | def center_camera_poses(self,opt,poses):
method get_all_camera_poses (line 71) | def get_all_camera_poses(self,opt):
method __getitem__ (line 76) | def __getitem__(self,idx):
method get_image (line 91) | def get_image(self,opt,idx):
method get_camera (line 96) | def get_camera(self,opt,idx):
method parse_raw_camera (line 104) | def parse_raw_camera(self,opt,pose_raw):
FILE: evaluate.py
function main (line 12) | def main():
FILE: external/pohsun_ssim/pytorch_ssim/__init__.py
function gaussian (line 7) | def gaussian(window_size, sigma):
function create_window (line 11) | def create_window(window_size, channel):
function _ssim (line 17) | def _ssim(img1, img2, window, window_size, channel, size_average = True):
class SSIM (line 39) | class SSIM(torch.nn.Module):
method __init__ (line 40) | def __init__(self, window_size = 11, size_average = True):
method forward (line 47) | def forward(self, img1, img2):
function ssim (line 65) | def ssim(img1, img2, window_size = 11, size_average = True):
FILE: model/barf.py
class Model (line 19) | class Model(nerf.Model):
method __init__ (line 21) | def __init__(self,opt):
method build_networks (line 24) | def build_networks(self,opt):
method setup_optimizer (line 49) | def setup_optimizer(self,opt):
method train_iteration (line 62) | def train_iteration(self,opt,var,loader):
method validate (line 79) | def validate(self,opt,ep=None):
method log_scalars (line 85) | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
method visualize (line 100) | def visualize(self,opt,var,step=0,split="train"):
method get_all_training_poses (line 109) | def get_all_training_poses(self,opt):
method prealign_cameras (line 116) | def prealign_cameras(self,opt,pose,pose_GT):
method evaluate_camera_alignment (line 134) | def evaluate_camera_alignment(self,opt,pose_aligned,pose_GT):
method evaluate_full (line 144) | def evaluate_full(self,opt):
method evaluate_test_time_photometric_optim (line 163) | def evaluate_test_time_photometric_optim(self,opt,var):
method generate_videos_pose (line 181) | def generate_videos_pose(self,opt):
class Graph (line 217) | class Graph(nerf.Graph):
method __init__ (line 219) | def __init__(self,opt):
method forward (line 226) | def forward(self,opt,var,mode=None):
method get_pose (line 249) | def get_pose(self,opt,var,mode=None):
class NeRF (line 278) | class NeRF(nerf.NeRF):
method __init__ (line 280) | def __init__(self,opt):
method positional_encoding (line 284) | def positional_encoding(self,opt,input,L): # [B,...,N]
FILE: model/base.py
class Model (line 18) | class Model():
method __init__ (line 20) | def __init__(self,opt):
method load_dataset (line 24) | def load_dataset(self,opt,eval_split="val"):
method build_networks (line 34) | def build_networks(self,opt):
method setup_optimizer (line 39) | def setup_optimizer(self,opt):
method restore_checkpoint (line 49) | def restore_checkpoint(self,opt):
method setup_visualizer (line 62) | def setup_visualizer(self,opt):
method train (line 78) | def train(self,opt):
method train_epoch (line 94) | def train_epoch(self,opt):
method train_iteration (line 111) | def train_iteration(self,opt,var,loader):
method summarize_loss (line 130) | def summarize_loss(self,opt,var,loss):
method validate (line 145) | def validate(self,opt,ep=None):
method log_scalars (line 165) | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
method visualize (line 175) | def visualize(self,opt,var,step=0,split="train"):
method save_checkpoint (line 178) | def save_checkpoint(self,opt,ep=0,it=0,latest=False):
class Graph (line 185) | class Graph(torch.nn.Module):
method __init__ (line 187) | def __init__(self,opt):
method forward (line 190) | def forward(self,opt,var,mode=None):
method compute_loss (line 194) | def compute_loss(self,opt,var,mode=None):
method L1_loss (line 199) | def L1_loss(self,pred,label=0):
method MSE_loss (line 202) | def MSE_loss(self,pred,label=0):
FILE: model/l2g_nerf.py
class Model (line 20) | class Model(nerf.Model):
method __init__ (line 22) | def __init__(self,opt):
method build_networks (line 25) | def build_networks(self,opt):
method setup_optimizer (line 55) | def setup_optimizer(self,opt):
method train_iteration (line 69) | def train_iteration(self,opt,var,loader):
method validate (line 86) | def validate(self,opt,ep=None):
method log_scalars (line 92) | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
method visualize (line 107) | def visualize(self,opt,var,step=0,split="train"):
method get_all_training_poses (line 116) | def get_all_training_poses(self,opt):
method prealign_cameras (line 123) | def prealign_cameras(self,opt,pose,pose_GT):
method evaluate_camera_alignment (line 141) | def evaluate_camera_alignment(self,opt,pose_aligned,pose_GT):
method evaluate_full (line 151) | def evaluate_full(self,opt):
method evaluate_test_time_photometric_optim (line 170) | def evaluate_test_time_photometric_optim(self,opt,var):
method generate_videos_pose (line 188) | def generate_videos_pose(self,opt):
class Graph (line 224) | class Graph(nerf.Graph):
method __init__ (line 226) | def __init__(self,opt):
method get_pose (line 233) | def get_pose(self,opt,var,mode=None):
method forward (line 274) | def forward(self,opt,var,mode=None):
method compute_loss (line 314) | def compute_loss(self,opt,var,mode=None):
method local_render (line 348) | def local_render(self,opt,local_pose,intr=None,ray_idx=None,mode=None):
class NeRF (line 380) | class NeRF(nerf.NeRF):
method __init__ (line 382) | def __init__(self,opt):
method positional_encoding (line 386) | def positional_encoding(self,opt,input,L): # [B,...,N]
class localWarp (line 400) | class localWarp(torch.nn.Module):
method __init__ (line 401) | def __init__(self, opt):
method forward (line 413) | def forward(self,opt,uvf):
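model/l2g_nerf.py is the core of the method: a localWarp head predicts a per-pixel warp, local_render renders with those local poses, and compute_loss adds a global-alignment term that fits a single rigid transform to the local warps and penalizes deviation from it. A hypothetical sketch of the two pieces, assuming a per-frame embedding input and a least-squares (Procrustes) fit; layer widths, output parameterization, and names are illustrative, not the repo's exact design.

```python
import torch

class LocalWarp(torch.nn.Module):
    # Hypothetical per-pixel warp head: pixel coordinate + per-frame embedding -> warp parameters.
    def __init__(self, embed_dim=128, hidden=256, out_dim=6):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(2 + embed_dim, hidden), torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden), torch.nn.ReLU(),
            torch.nn.Linear(hidden, out_dim))

    def forward(self, uv, frame_embedding):
        # uv: [R, 2] normalized pixel coordinates, frame_embedding: [embed_dim]
        feat = torch.cat([uv, frame_embedding.expand(uv.shape[0], -1)], dim=-1)
        return self.mlp(feat)                                          # [R, out_dim]

def global_alignment_loss(x_local, x_canonical):
    # Fit one rigid transform (R, t) mapping canonical points to the locally warped
    # points in the least-squares (Procrustes/Kabsch) sense, then penalize how far
    # the per-pixel warps stray from that single global motion. x_*: [N, 3].
    P = x_canonical - x_canonical.mean(0)
    Q = x_local - x_local.mean(0)
    U, _, Vt = torch.linalg.svd(P.t() @ Q)
    D = torch.eye(3, device=P.device)
    D[2, 2] = torch.det(Vt.t() @ U.t()).sign()       # guard against reflections
    R = Vt.t() @ D @ U.t()
    t = x_local.mean(0) - R @ x_canonical.mean(0)
    return ((x_local - (x_canonical @ R.t() + t)) ** 2).mean()
```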
FILE: model/l2g_planar.py
class Model (line 21) | class Model(base.Model):
method __init__ (line 23) | def __init__(self,opt):
method load_dataset (line 27) | def load_dataset(self,opt,eval_split=None):
method build_networks (line 31) | def build_networks(self,opt):
method setup_optimizer (line 36) | def setup_optimizer(self,opt):
method setup_visualizer (line 65) | def setup_visualizer(self,opt):
method train (line 77) | def train(self,opt):
method train_iteration (line 104) | def train_iteration(self,opt,var,loader):
method generate_warp_perturbation (line 113) | def generate_warp_perturbation(self,opt):
method visualize_patches_compose (line 167) | def visualize_patches_compose(self,opt, homo_pert, rot_pert, trans_pert):
method visualize_patches_use_matrix (line 181) | def visualize_patches_use_matrix(self,opt,warp_matrix):
method predict_entire_image (line 196) | def predict_entire_image(self,opt):
method log_scalars (line 203) | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
method visualize (line 222) | def visualize(self,opt,var,step=0,split="train"):
class Graph (line 245) | class Graph(base.Graph):
method __init__ (line 247) | def __init__(self,opt):
method forward (line 251) | def forward(self,opt,var,mode=None):
method compute_loss (line 292) | def compute_loss(self,opt,var,mode=None):
class NeuralImageFunction (line 301) | class NeuralImageFunction(torch.nn.Module):
method __init__ (line 303) | def __init__(self,opt):
method define_network (line 308) | def define_network(self,opt):
method forward (line 324) | def forward(self,opt,coord_2D): # [B,...,3]
method positional_encoding (line 339) | def positional_encoding(self,opt,input,L): # [B,...,N]
class localWarp (line 358) | class localWarp(torch.nn.Module):
method __init__ (line 359) | def __init__(self, opt):
method forward (line 371) | def forward(self,opt,uvf):
FILE: model/nerf.py
class Model (line 19) | class Model(base.Model):
method __init__ (line 21) | def __init__(self,opt):
method load_dataset (line 25) | def load_dataset(self,opt,eval_split="val"):
method setup_optimizer (line 31) | def setup_optimizer(self,opt):
method train (line 46) | def train(self,opt):
method log_scalars (line 71) | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
method visualize (line 88) | def visualize(self,opt,var,step=0,split="train",eps=1e-10):
method get_all_training_poses (line 114) | def get_all_training_poses(self,opt):
method evaluate_full (line 120) | def evaluate_full(self,opt,eps=1e-10):
method generate_videos_synthesis (line 164) | def generate_videos_synthesis(self,opt,eps=1e-10):
class Graph (line 210) | class Graph(base.Graph):
method __init__ (line 212) | def __init__(self,opt):
method forward (line 218) | def forward(self,opt,var,mode=None):
method compute_loss (line 233) | def compute_loss(self,opt,var,mode=None):
method get_pose (line 247) | def get_pose(self,opt,var,mode=None):
method render (line 250) | def render(self,opt,pose,intr=None,ray_idx=None,mode=None):
method render_by_slices (line 278) | def render_by_slices(self,opt,pose,intr=None,mode=None):
method sample_depth (line 291) | def sample_depth(self,opt,batch_size,num_rays=None):
method sample_depth_from_pdf (line 303) | def sample_depth_from_pdf(self,opt,pdf):
class NeRF (line 324) | class NeRF(torch.nn.Module):
method __init__ (line 326) | def __init__(self,opt):
method define_network (line 330) | def define_network(self,opt):
method tensorflow_init_weights (line 356) | def tensorflow_init_weights(self,opt,linear,out=None):
method forward (line 368) | def forward(self,opt,points_3D,ray_unit=None,mode=None): # [B,...,3]
method forward_samples (line 401) | def forward_samples(self,opt,center,ray,depth_samples,mode=None):
method composite (line 410) | def composite(self,opt,ray,rgb_samples,density_samples,depth_samples):
method positional_encoding (line 428) | def positional_encoding(self,opt,input,L): # [B,...,N]
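The composite method in model/nerf.py turns per-sample densities and colors into pixel values via standard NeRF volume rendering: alpha from density times interval length, transmittance as the cumulative product of (1 − alpha). A sketch of that step under the usual formulation; tensor names and the exact distance handling are illustrative.

```python
import torch

def composite(rgb_samples, density_samples, depth_samples):
    # rgb_samples:     [B, R, N, 3] colors per depth sample
    # density_samples: [B, R, N]    volume densities (sigma)
    # depth_samples:   [B, R, N, 1] sample depths along each ray
    depth = depth_samples[..., 0]                                   # [B, R, N]
    # Interval lengths between consecutive samples; the last one is open-ended.
    dist = torch.cat([depth[..., 1:] - depth[..., :-1],
                      torch.full_like(depth[..., :1], 1e10)], dim=-1)
    alpha = 1 - torch.exp(-density_samples * dist)                  # [B, R, N]
    # Transmittance: probability the ray reaches sample i without being absorbed.
    T = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]),
                                 1 - alpha[..., :-1] + 1e-10], dim=-1), dim=-1)
    weights = (T * alpha)[..., None]                                # [B, R, N, 1]
    rgb = (weights * rgb_samples).sum(dim=-2)                       # [B, R, 3]
    depth_map = (weights[..., 0] * depth).sum(dim=-1)               # [B, R]
    opacity = weights.sum(dim=-2)                                   # [B, R, 1]
    return rgb, depth_map, opacity
```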
FILE: model/planar.py
class Model (line 20) | class Model(base.Model):
method __init__ (line 22) | def __init__(self,opt):
method load_dataset (line 26) | def load_dataset(self,opt,eval_split=None):
method build_networks (line 30) | def build_networks(self,opt):
method setup_optimizer (line 35) | def setup_optimizer(self,opt):
method setup_visualizer (line 49) | def setup_visualizer(self,opt):
method train (line 61) | def train(self,opt):
method train_iteration (line 90) | def train_iteration(self,opt,var,loader):
method generate_warp_perturbation (line 95) | def generate_warp_perturbation(self,opt):
method visualize_patches (line 149) | def visualize_patches(self,opt,warp_param):
method visualize_patches_compose (line 163) | def visualize_patches_compose(self,opt, homo_pert, rot_pert, trans_pert):
method predict_entire_image (line 178) | def predict_entire_image(self,opt):
method log_scalars (line 185) | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
method visualize (line 205) | def visualize(self,opt,var,step=0,split="train"):
class Graph (line 224) | class Graph(base.Graph):
method __init__ (line 226) | def __init__(self,opt):
method forward (line 230) | def forward(self,opt,var,mode=None):
method compute_loss (line 238) | def compute_loss(self,opt,var,mode=None):
class NeuralImageFunction (line 245) | class NeuralImageFunction(torch.nn.Module):
method __init__ (line 247) | def __init__(self,opt):
method define_network (line 252) | def define_network(self,opt):
method forward (line 268) | def forward(self,opt,coord_2D): # [B,...,3]
method positional_encoding (line 283) | def positional_encoding(self,opt,input,L): # [B,...,N]
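planar.py's NeuralImageFunction fits a coordinate MLP to a single image: positionally encoded 2D pixel coordinates go in, RGB comes out, and the warp parameters are optimized jointly so the perturbed patches register. A minimal sketch of such a coordinate MLP, assuming a fixed (non-progressive) positional encoding; layer widths and the sigmoid output are placeholders.

```python
import numpy as np
import torch

class NeuralImageFunction(torch.nn.Module):
    # Coordinate MLP: encoded 2D position -> RGB. Widths are illustrative.
    def __init__(self, L=8, hidden=256, depth=4):
        super().__init__()
        in_dim = 2 + 2 * 2 * L                    # raw (u, v) plus sin/cos bands
        layers, d = [], in_dim
        for _ in range(depth):
            layers += [torch.nn.Linear(d, hidden), torch.nn.ReLU()]
            d = hidden
        layers += [torch.nn.Linear(d, 3)]
        self.mlp = torch.nn.Sequential(*layers)
        self.L = L

    def positional_encoding(self, x):
        freq = 2 ** torch.arange(self.L, dtype=torch.float32, device=x.device) * np.pi
        spectrum = x[..., None] * freq            # [..., 2, L]
        return torch.cat([spectrum.sin(), spectrum.cos()], dim=-1).flatten(-2)

    def forward(self, coord_2D):                  # [..., 2] in [-1, 1]
        feat = torch.cat([coord_2D, self.positional_encoding(coord_2D)], dim=-1)
        return self.mlp(feat).sigmoid()           # [..., 3] RGB in [0, 1]
```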
FILE: options.py
function parse_arguments (line 16) | def parse_arguments(args):
function set (line 41) | def set(opt_cmd={}):
function load_options (line 54) | def load_options(fname):
function override_options (line 69) | def override_options(opt,opt_over,key_stack=None,safe_check=False):
function process_options (line 87) | def process_options(opt):
function save_options_file (line 107) | def save_options_file(opt):
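options.py loads a YAML config, layers command-line overrides on top of it, post-processes the result, and saves the merged options. The override step is essentially a recursive nested-dict merge; a sketch under that assumption (the repo's exact safety checks and easydict handling are omitted):

```python
def override_options(opt, opt_over, safe_check=False):
    # Recursively overlay opt_over onto opt: nested dicts are merged key by key,
    # everything else is overwritten. With safe_check, refuse to introduce unknown keys.
    for key, value in opt_over.items():
        if isinstance(value, dict) and isinstance(opt.get(key), dict):
            override_options(opt[key], value, safe_check=safe_check)
        else:
            if safe_check and key not in opt:
                raise KeyError(f"unknown option: {key}")
            opt[key] = value
    return opt

# Hypothetical usage: base YAML options overridden by parsed command-line arguments.
# opt = override_options(load_options("options/l2g_nerf_blender.yaml"), opt_cmd)
```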
FILE: train.py
function main (line 12) | def main():
FILE: util.py
function red (line 15) | def red(message,**kwargs): return termcolor.colored(str(message),color="...
function green (line 16) | def green(message,**kwargs): return termcolor.colored(str(message),color...
function blue (line 17) | def blue(message,**kwargs): return termcolor.colored(str(message),color=...
function cyan (line 18) | def cyan(message,**kwargs): return termcolor.colored(str(message),color=...
function yellow (line 19) | def yellow(message,**kwargs): return termcolor.colored(str(message),colo...
function magenta (line 20) | def magenta(message,**kwargs): return termcolor.colored(str(message),col...
function grey (line 21) | def grey(message,**kwargs): return termcolor.colored(str(message),color=...
function get_time (line 23) | def get_time(sec):
function add_datetime (line 30) | def add_datetime(func):
function add_functionname (line 37) | def add_functionname(func):
function pre_post_actions (line 43) | def pre_post_actions(pre=None,post=None):
class Log (line 55) | class Log:
method __init__ (line 56) | def __init__(self): pass
method process (line 57) | def process(self,pid):
method title (line 59) | def title(self,message):
method info (line 61) | def info(self,message):
method options (line 63) | def options(self,opt,level=0):
method loss_train (line 70) | def loss_train(self,opt,ep,lr,loss,timer):
method loss_val (line 79) | def loss_val(self,opt,loss):
function update_timer (line 85) | def update_timer(opt,timer,ep,it_per_ep):
function move_to_device (line 95) | def move_to_device(X,device):
function to_dict (line 110) | def to_dict(D,dict_type=dict):
function get_child_state_dict (line 117) | def get_child_state_dict(state_dict,key):
function restore_checkpoint (line 120) | def restore_checkpoint(opt,model,load_name=None,resume=False):
function save_checkpoint (line 143) | def save_checkpoint(opt,model,ep,it,latest=False,children=None):
function check_socket_open (line 161) | def check_socket_open(hostname,port):
function get_layer_dims (line 172) | def get_layer_dims(layers):
function suppress (line 177) | def suppress(stdout=False,stderr=False):
function colorcode_to_number (line 186) | def colorcode_to_number(code):
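util.py's checkpoint helpers save the full graph state and restore sub-modules selectively; get_child_state_dict filters the flat state_dict down to one child module by key prefix so it can be loaded on its own. A sketch of that filtering (the key names in the usage comment are placeholders):

```python
import torch

def get_child_state_dict(state_dict, key):
    # Keep only entries belonging to child module `key`, stripping the prefix so the
    # result loads directly into that child with load_state_dict().
    prefix = f"{key}."
    return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

# Hypothetical usage: restore just the radiance-field weights, leaving pose
# parameters untouched.
# checkpoint = torch.load("model.ckpt", map_location="cpu")
# graph.nerf.load_state_dict(get_child_state_dict(checkpoint["graph"], "nerf"))
```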
FILE: util_vis.py
function tb_image (line 16) | def tb_image(opt,tb,step,group,name,images,num_vis=None,from_range=(0,1)...
function preprocess_vis_image (line 27) | def preprocess_vis_image(opt,images,from_range=(0,1),cmap="gray"):
function dump_images (line 35) | def dump_images(opt,idx,name,images,masks=None,from_range=(0,1),cmap="gr...
function get_heatmap (line 43) | def get_heatmap(opt,gray,cmap): # [N,H,W]
function color_border (line 48) | def color_border(images,colors,width=3):
function vis_cameras (line 58) | def vis_cameras(opt,vis,step,poses=[],colors=["blue","magenta"],plot_dis...
function get_camera_mesh (line 141) | def get_camera_mesh(pose,depth=1):
function merge_wireframes (line 157) | def merge_wireframes(wireframe):
function merge_meshes (line 164) | def merge_meshes(vertices,faces):
function merge_centers (line 169) | def merge_centers(centers):
function plot_save_poses (line 177) | def plot_save_poses(opt,fig,pose,pose_ref=None,path=None,ep=None):
function plot_save_poses_blender (line 220) | def plot_save_poses_blender(opt,fig,pose,pose_ref=None,path=None,ep=None):
function setup_3D_plot (line 257) | def setup_3D_plot(ax,elev,azim,lim=None):
function apply_colormap (line 294) | def apply_colormap(image, cmap="viridis"):
function apply_depth_colormap (line 317) | def apply_depth_colormap(
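util_vis.py draws the camera poses as small frustum wireframes (get_camera_mesh, merge_wireframes) for the pose-registration videos. A rough sketch of building one such frustum per camera, assuming camera-to-world [R | t] poses with the image plane at unit depth; the repo's pose convention and scaling may differ.

```python
import torch

def get_camera_mesh(pose, depth=1.0):
    # pose: [N, 3, 4] camera-to-world [R | t]; returns frustum vertices and a line
    # strip per camera for wireframe plotting.
    vertices = torch.tensor([[-0.5, -0.5, 1], [0.5, -0.5, 1],
                             [0.5,  0.5, 1], [-0.5, 0.5, 1],
                             [0, 0, 0]], dtype=torch.float32) * depth   # [5, 3]
    R, t = pose[..., :3], pose[..., 3:]
    world = vertices @ R.transpose(-1, -2) + t.transpose(-1, -2)        # [N, 5, 3]
    wireframe = world[:, [0, 1, 2, 3, 0, 4, 1, 2, 4, 3]]                # [N, 10, 3]
    return world, wireframe
```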
FILE: warp.py
function get_normalized_pixel_grid (line 11) | def get_normalized_pixel_grid(opt):
function get_normalized_pixel_grid_crop (line 19) | def get_normalized_pixel_grid_crop(opt):
function warp_grid (line 29) | def warp_grid(opt,xy_grid,warp):
function warp_grid_use_matrix (line 51) | def warp_grid_use_matrix(opt,xy_grid,warp_matrix):
function warp_corners (line 70) | def warp_corners(opt,warp_param):
function warp_corners_compose (line 80) | def warp_corners_compose(opt, homo_pert, rot_pert, trans_pert):
function warp_corners_use_matrix (line 97) | def warp_corners_use_matrix(opt,warp_matrix):
function check_corners_in_range (line 107) | def check_corners_in_range(opt,warp_param):
function check_corners_in_range_compose (line 113) | def check_corners_in_range_compose(opt, homo_pert, rot_pert, trans_pert):
class Lie (line 120) | class Lie():
method so2_to_SO2 (line 122) | def so2_to_SO2(self,theta): # [...,1]
method SO2_to_so2 (line 128) | def SO2_to_so2(self,R): # [...,2,2]
method so2_jacobian (line 132) | def so2_jacobian(self,X,theta): # [...,N,2],[...,1]
method se2_to_SE2 (line 138) | def se2_to_SE2(self,delta): # [...,3]
method SE2_to_se2 (line 148) | def SE2_to_se2(self,Rt,eps=1e-7): # [...,2,3]
method se2_jacobian (line 160) | def se2_jacobian(self,X,delta): # [...,N,2],[...,3]
method sl3_to_SL3 (line 178) | def sl3_to_SL3(self,h):
method taylor_A (line 188) | def taylor_A(self,x,nth=10):
method taylor_B (line 196) | def taylor_B(self,x,nth=10):
method taylor_C (line 205) | def taylor_C(self,x,nth=10):
method taylor_D (line 214) | def taylor_D(self,x,nth=10):
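warp.py's Lie class implements the 2D warp algebra used by the planar experiments (so(2)/SO(2), se(2)/SE(2), sl(3)/SL(3)), with Taylor-series coefficients A(x) = sin x / x and B(x) = (1 − cos x) / x² for stability near zero. A sketch of the se(2) exponential map under the common convention delta = (u₁, u₂, θ); the repo's ordering of translation and rotation in delta is an assumption.

```python
import torch

def taylor_A(x, nth=10):
    # sin(x)/x via its Taylor series, stable at x = 0.
    ans, denom = torch.zeros_like(x), 1.0
    for i in range(nth + 1):
        if i > 0:
            denom *= (2 * i) * (2 * i + 1)
        ans = ans + (-1) ** i * x ** (2 * i) / denom
    return ans

def taylor_B(x, nth=10):
    # (1 - cos(x))/x^2 via its Taylor series, stable at x = 0.
    ans, denom = torch.zeros_like(x), 1.0
    for i in range(nth + 1):
        denom *= (2 * i + 1) * (2 * i + 2)
        ans = ans + (-1) ** i * x ** (2 * i) / denom
    return ans

def se2_to_SE2(delta):
    # delta: [..., 3] = (u1, u2, theta); returns [..., 2, 3] = [R(theta) | V(theta) u].
    u, theta = delta[..., :2], delta[..., 2:]
    A = taylor_A(theta)                         # sin(theta)/theta
    B = taylor_B(theta) * theta                 # (1 - cos(theta))/theta
    cos, sin = theta.cos(), theta.sin()
    R = torch.stack([torch.cat([cos, -sin], dim=-1),
                     torch.cat([sin,  cos], dim=-1)], dim=-2)        # [..., 2, 2]
    V = torch.stack([torch.cat([A, -B], dim=-1),
                     torch.cat([B,  A], dim=-1)], dim=-2)            # [..., 2, 2]
    t = V @ u[..., None]                                             # [..., 2, 1]
    return torch.cat([R, t], dim=-1)                                 # [..., 2, 3]
```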
Condensed preview — 40 files, each showing path, character count, and a content snippet.
[
{
"path": "LICENSE",
"chars": 1069,
"preview": "MIT License\n\nCopyright (c) [year] [fullname]\n\nPermission is hereby granted, free of charge, to any person obtaining a co"
},
{
"path": "README.md",
"chars": 12824,
"preview": "# L2G-NeRF: Local-to-Global Registration for Bundle-Adjusting Neural Radiance Fields\n**[Project Page](https://rover-xing"
},
{
"path": "camera.py",
"chars": 13464,
"preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport collections\nfrom easydic"
},
{
"path": "data/base.py",
"chars": 4765,
"preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
},
{
"path": "data/blender.py",
"chars": 3345,
"preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
},
{
"path": "data/iphone.py",
"chars": 2826,
"preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
},
{
"path": "data/llff.py",
"chars": 4465,
"preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
},
{
"path": "evaluate.py",
"chars": 884,
"preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport importlib\n\nimport options\nfrom util import log\n\nimport warning"
},
{
"path": "external/pohsun_ssim/LICENSE.txt",
"chars": 4,
"preview": "MIT\n"
},
{
"path": "external/pohsun_ssim/README.md",
"chars": 1769,
"preview": "# pytorch-ssim\n\n### Differentiable structural similarity (SSIM) index.\n. The extraction includes 40 files (218.7 KB), approximately 58.3k tokens, and a symbol index with 296 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.