Repository: rover-xingyu/L2G-NeRF Branch: main Commit: 08d06597a233 Files: 40 Total size: 218.7 KB Directory structure: gitextract_16lz1fht/ ├── LICENSE ├── README.md ├── camera.py ├── data/ │ ├── base.py │ ├── blender.py │ ├── iphone.py │ └── llff.py ├── evaluate.py ├── external/ │ └── pohsun_ssim/ │ ├── LICENSE.txt │ ├── README.md │ ├── max_ssim.py │ ├── pytorch_ssim/ │ │ └── __init__.py │ ├── setup.cfg │ └── setup.py ├── extract_mesh.py ├── model/ │ ├── barf.py │ ├── base.py │ ├── l2g_nerf.py │ ├── l2g_planar.py │ ├── nerf.py │ └── planar.py ├── options/ │ ├── barf_blender.yaml │ ├── barf_iphone.yaml │ ├── barf_llff.yaml │ ├── base.yaml │ ├── l2g_nerf_blender.yaml │ ├── l2g_nerf_iphone.yaml │ ├── l2g_nerf_llff.yaml │ ├── l2g_planar.yaml │ ├── nerf_blender.yaml │ ├── nerf_blender_repr.yaml │ ├── nerf_llff.yaml │ ├── nerf_llff_repr.yaml │ └── planar.yaml ├── options.py ├── requirements.yaml ├── train.py ├── util.py ├── util_vis.py └── warp.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) [year] [fullname] Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # L2G-NeRF: Local-to-Global Registration for Bundle-Adjusting Neural Radiance Fields **[Project Page](https://rover-xingyu.github.io/L2G-NeRF/) | [Paper](https://arxiv.org/pdf/2211.11505.pdf) | [Video](https://www.youtube.com/watch?v=y8XP9Umt6Mw)** [Yue Chen¹](https://scholar.google.com/citations?user=M2hq1_UAAAAJ&hl=en), [Xingyu Chen¹](https://scholar.google.com/citations?user=gDHPrWEAAAAJ&hl=en), [Xuan Wang²](https://scholar.google.com/citations?user=h-3xd3EAAAAJ&hl=en), [Qi Zhang³](https://scholar.google.com/citations?user=2vFjhHMAAAAJ&hl=en), [Yu Guo¹](https://scholar.google.com/citations?user=OemeiSIAAAAJ&hl=en), [Ying Shan³](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en), [Fei Wang¹](https://scholar.google.com/citations?user=uU2JTpUAAAAJ&hl=en). [¹Xi'an Jiaotong University](http://en.xjtu.edu.cn/), [²Ant Group](https://www.antgroup.com/en), [³Tencent AI Lab](https://ai.tencent.com/ailab/en/index/). This repository is an official implementation of [L2G-NeRF](https://rover-xingyu.github.io/L2G-NeRF/) using [pytorch](https://pytorch.org/). # :computer: Installation ## Hardware * We implement all experiments on a single NVIDIA GeForce RTX 2080 Ti GPU. 
* L2G-NeRF takes about 4.5 and 8 hours for training in synthetic objects and real-world scenes, respectively, while training BARF takes about 8 and 10.5 hours. ## Software * Clone this repo by `git clone https://github.com/rover-xingyu/L2G-NeRF` * This code is developed with Python3. PyTorch 1.9+ is required. It is recommended use [Anaconda](https://www.anaconda.com/products/individual) to set up the environment, use `conda env create --file requirements.yaml python=3` to install the dependencies and activate it by `conda activate L2G-NeRF` -------------------------------------- # :key: Training and Evaluation ## Data download Both the Blender synthetic data and LLFF real-world data can be found in the [NeRF Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1). For convenience, you can download them with the following script: ```bash # Blender gdown --id 18JxhpWD-4ZmuFKLzKlAw-w5PpzZxXOcG # download nerf_synthetic.zip unzip nerf_synthetic.zip rm -f nerf_synthetic.zip mv nerf_synthetic data/blender # LLFF gdown --id 16VnMcF1KJYxN9QId6TClMsZRahHNMW5g # download nerf_llff_data.zip unzip nerf_llff_data.zip rm -f nerf_llff_data.zip mv nerf_llff_data data/llff ``` -------------------------------------- ## Running L2G-NeRF To train and evaluate L2G-NeRF: ```bash # and can be set to your likes, while is specific to datasets # NeRF (3D): Synthetic Objects # Blender (={chair,drums,ficus,hotdog,lego,materials,mic,ship}) python3 train.py \ --model=l2g_nerf --yaml=l2g_nerf_blender \ --group=exp_synthetic --name=l2g_lego \ --data.scene=lego --gpu=3 \ --data.root=/the/data/path/of/nerf_synthetic/ \ --camera.noise_r=0.07 --camera.noise_t=0.5 python3 evaluate.py \ --model=l2g_nerf --yaml=l2g_nerf_blender \ --group=exp_synthetic --name=l2g_lego \ --data.scene=lego --gpu=3 \ --data.root=/the/data/path/of/nerf_synthetic/ \ --data.val_sub= --resume # NeRF (3D): Real-World Scenes # LLFF (={fern,flower,fortress,horns,leaves,orchids,room,trex}) python3 train.py \ --model=l2g_nerf --yaml=l2g_nerf_llff \ --group=exp_LLFF --name=l2g_fern \ --data.scene=fern --gpu=3 \ --data.root=/the/data/path/of/nerf_llff_data/ \ --loss_weight.global_alignment=2 python3 evaluate.py \ --model=l2g_nerf --yaml=l2g_nerf_llff \ --group=exp_LLFF --name=l2g_fern \ --data.scene=fern --gpu=3 \ --data.root=/the/data/path/of/nerf_llff_data/ \ --loss_weight.global_alignment=2 \ --resume # Neural image Alignment (2D): Rigid # use the image of “Girl With a Pearl Earring” renovation ©Koorosh Orooj (CC BY-SA 4.0) for rigid image alignment python3 train.py \ --model=l2g_planar --yaml=l2g_planar \ --group=exp_planar --name=l2g_girl \ --warp.type=rigid --warp.dof=3 \ --data.image_fname=data/girl.jpg \ --data.image_size=[595,512] \ --data.patch_crop=[260,260] \ --seed=1 --gpu=3 # Neural image Alignment (2D): Homography # use the image of “cat” from ImageNet for homography image alignment python3 train.py \ --model=l2g_planar --yaml=l2g_planar \ --group=exp_planar --name=l2g_cat \ --warp.type=homography --warp.dof=8 \ --data.image_fname=data/cat.jpg \ --data.image_size=[360,480] \ --data.patch_crop=[180,180] \ --gpu=0 ``` -------------------------------------- ## Running BARF If you want to train and evaluate the BARF extension of the original NeRF model that jointly optimizes poses (coarse-to-fine positional encoding): ```bash # and can be set to your likes, while is specific to datasets # NeRF (3D): Synthetic Objects # Blender (={chair,drums,ficus,hotdog,lego,materials,mic,ship}) python3 train.py \ --model=barf 
--yaml=barf_blender \ --group=exp_synthetic --name=barf_lego \ --data.scene=lego --gpu=1 \ --data.root=/the/data/path/of/nerf_synthetic/ \ --camera.noise_r=0.07 --camera.noise_t=0.5 python3 evaluate.py \ --model=barf --yaml=barf_blender \ --group=exp_synthetic --name=barf_lego \ --data.scene=lego --gpu=1 \ --data.root=/the/data/path/of/nerf_synthetic/ \ --data.val_sub= --resume # NeRF (3D): Real-World Scenes # LLFF (={fern,flower,fortress,horns,leaves,orchids,room,trex}) python3 train.py \ --model=barf --yaml=barf_llff \ --group=exp_LLFF --name=barf_fern \ --data.scene=fern --gpu=1 \ --data.root=/the/data/path/of/nerf_llff_data/ python3 evaluate.py \ --model=barf --yaml=barf_llff \ --group=exp_LLFF --name=barf_fern \ --data.scene=fern --gpu=1 \ --data.root=/the/data/path/of/nerf_llff_data/ \ --resume # Neural image Alignment (2D): Rigid # use the image of “Girl With a Pearl Earring” renovation ©Koorosh Orooj (CC BY-SA 4.0) for rigid image alignment python3 train.py \ --model=planar --yaml=planar \ --group=exp_planar --name=barf_girl \ --warp.type=rigid --warp.dof=3 \ --data.image_fname=data/girl.jpg \ --data.image_size=[595,512] \ --data.patch_crop=[260,260] \ --seed=1 --gpu=1 # Neural image Alignment (2D): Homography # use the image of “cat” from ImageNet for homography image alignment python3 train.py \ --model=planar --yaml=planar \ --group=exp_planar --name=barf_cat \ --warp.type=homography --warp.dof=8 \ --data.image_fname=data/cat.jpg \ --data.image_size=[360,480] \ --data.patch_crop=[180,180] \ --gpu=2 ``` -------------------------------------- ## Running Naive If you want to train and evaluate the Naive extension of the original NeRF model that jointly optimizes poses (full positional encoding): ```bash # and can be set to your likes, while is specific to datasets # NeRF (3D): Synthetic Objects # Blender (={chair,drums,ficus,hotdog,lego,materials,mic,ship}) python3 train.py \ --model=barf --yaml=barf_blender \ --group=exp_synthetic --name=nerf_lego \ --data.scene=lego --gpu=2 \ --data.root=/home/cy/PNW/datasets/nerf_synthetic/ \ --barf_c2f=null \ --camera.noise_r=0.07 --camera.noise_t=0.5 python3 evaluate.py \ --model=barf --yaml=barf_blender \ --group=exp_synthetic --name=nerf_lego \ --data.scene=lego --gpu=2 \ --data.root=/home/cy/PNW/datasets/nerf_synthetic/ \ --barf_c2f=null \ --data.val_sub= --resume # NeRF (3D): Real-World Scenes # LLFF (={fern,flower,fortress,horns,leaves,orchids,room,trex}) python3 train.py \ --model=barf --yaml=barf_llff \ --group=exp_LLFF --name=nerf_fern \ --data.scene=fern --gpu=2 \ --data.root=/home/cy/PNW/datasets/nerf_llff_data/ \ --barf_c2f=null python3 evaluate.py \ --model=barf --yaml=barf_llff \ --group=exp_LLFF --name=nerf_fern \ --data.scene=fern --gpu=2 \ --data.root=/home/cy/PNW/datasets/nerf_llff_data/ \ --barf_c2f=null --resume # Neural image Alignment (2D): Rigid # use the image of “Girl With a Pearl Earring” renovation ©Koorosh Orooj (CC BY-SA 4.0) for rigid image alignment python3 train.py \ --model=planar --yaml=planar \ --group=exp_planar --name=naive_girl \ --warp.type=rigid --warp.dof=3 \ --data.image_fname=data/girl.jpg \ --data.image_size=[595,512] \ --data.patch_crop=[260,260] \ --seed=1 --gpu=2 --barf_c2f=null # Neural image Alignment (2D): Homography # use the image of “cat” from ImageNet for homography image alignment python3 train.py \ --model=planar --yaml=planar \ --group=exp_planar --name=naive_cat \ --warp.type=homography --warp.dof=8 \ --data.image_fname=data/cat.jpg \ --data.image_size=[360,480] \ 
--data.patch_crop=[180,180] \
    --gpu=3 --barf_c2f=null
```
--------------------------------------

## Running reference NeRF

If you want to train and evaluate the reference NeRF models (assuming known camera poses):
```bash
# <GROUP> and <NAME> can be set to your likes, while <SCENE> is specific to datasets

# NeRF (3D): Synthetic Objects
# Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})
python3 train.py \
    --model=nerf --yaml=nerf_blender \
    --group=exp_synthetic --name=ref_lego \
    --data.scene=lego --gpu=0 \
    --data.root=/home/cy/PNW/datasets/nerf_synthetic/

python3 evaluate.py \
    --model=nerf --yaml=nerf_blender \
    --group=exp_synthetic --name=ref_lego \
    --data.scene=lego --gpu=0 \
    --data.root=/home/cy/PNW/datasets/nerf_synthetic/ \
    --data.val_sub= --resume

# NeRF (3D): Real-World Scenes
# LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})
python3 train.py \
    --model=nerf --yaml=nerf_llff \
    --group=exp_LLFF --name=ref_fern \
    --data.scene=fern --gpu=0 \
    --data.root=/home/cy/PNW/datasets/nerf_llff_data/

python3 evaluate.py \
    --model=nerf --yaml=nerf_llff \
    --group=exp_LLFF --name=ref_fern \
    --data.scene=fern --gpu=0 \
    --data.root=/home/cy/PNW/datasets/nerf_llff_data/ \
    --resume
```
--------------------------------------

# :laughing: Visualization

## Results and Videos
All the results will be stored in the directory `output/<GROUP>/<NAME>`. You may want to organize your experiments by grouping different runs in the same group.
Many videos will be created to visualize the pose optimization process and novel view synthesis.

## TensorBoard
The TensorBoard events include the following:
- **SCALARS**: the rendering losses and PSNR over the course of optimization. For L2G_NeRF/BARF/Naive, the rotational/translational errors with respect to the given poses are also computed.
- **IMAGES**: visualization of the RGB images and the RGB/depth rendering.

## Visdom
The visualization of 3D camera poses is provided in Visdom. Run `visdom -port 8600` to start the Visdom server.
The Visdom host server defaults to `localhost`; this can be overridden with `--visdom.server` (see `options/base.yaml` for details). If you want to disable Visdom visualization, add `--visdom!`.

## Mesh
The `extract_mesh.py` script provides a simple way to extract the underlying 3D geometry using marching cubes (supported for the Blender dataset). Run as follows:
```bash
python3 extract_mesh.py \
    --model=l2g_nerf --yaml=l2g_nerf_blender \
    --group=exp_synthetic --name=l2g_lego \
    --data.scene=lego --gpu=3 \
    --data.root=/home/cy/PNW/datasets/nerf_synthetic/ \
    --data.val_sub= --resume
```
--------------------------------------

# :mag_right: Codebase structure

The main engine and network architecture in `model/l2g_nerf.py` inherit those from `model/nerf.py`.
Some tips on using and understanding the codebase:
- The computation graph for forward/backprop is stored in `var` throughout the codebase.
- The losses are stored in `loss`. To add a new loss function, just implement it in `compute_loss()` and add its weight to `opt.loss_weight.<name>`; it will automatically be added to the overall loss and logged to TensorBoard (see the sketch after this list).
- If you are using a multi-GPU machine, you can set `--gpu=<gpu_number>` to specify which GPU to use. Multi-GPU training/evaluation is currently not supported.
- To resume from a previous checkpoint, add `--resume=<ep>`, or just `--resume` to resume from the latest checkpoint.
- To eliminate the global alignment objective, set `--loss_weight.global_alignment=null`; this ablation is equivalent to a local registration method.
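For example, a new term could be wired in as follows. This is a minimal, hypothetical sketch of the convention above, not code from this repository: the key `my_new_term` and the quantity it penalizes are placeholders, and a matching `loss_weight.my_new_term` entry is assumed to exist in the options yaml (it can also be set from the command line, e.g. `--loss_weight.my_new_term=<value>`, just like `--loss_weight.global_alignment` above).

```python
# Hypothetical sketch (e.g. a new model/my_model.py): adding a loss term following the
# codebase convention. compute_loss() returns a dict of named terms; summarize_loss() in
# model/base.py is then expected to weigh each term by its opt.loss_weight entry and sum
# it into loss.all, and every entry is logged to TensorBoard.
from . import nerf

class Graph(nerf.Graph):

    def compute_loss(self, opt, var, mode=None):
        loss = super().compute_loss(opt, var, mode=mode)  # keep the existing render loss
        if opt.loss_weight.my_new_term is not None:       # weight defined in the options yaml
            # any scalar computed from the forward outputs stored in var
            loss.my_new_term = var.rgb.abs().mean()
        return loss
```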
-------------------------------------- # Citation If you find this project useful for your research, please use the following BibTeX entry. ```bibtex @inproceedings{chen2023local, title={Local-to-global registration for bundle-adjusting neural radiance fields}, author={Chen, Yue and Chen, Xingyu and Wang, Xuan and Zhang, Qi and Guo, Yu and Shan, Ying and Wang, Fei}, booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, pages={8264--8273}, year={2023} } ``` -------------------------------------- # Acknowledge Our code is based on the awesome pytorch implementation of Bundle-Adjusting Neural Radiance Fields ([BARF](https://github.com/chenhsuanlin/bundle-adjusting-NeRF)). We appreciate all the contributors. ================================================ FILE: camera.py ================================================ import numpy as np import os,sys,time import torch import torch.nn.functional as torch_F import collections from easydict import EasyDict as edict import util from util import log,debug class Pose(): """ A class of operations on camera poses (PyTorch tensors with shape [...,3,4]) each [3,4] camera pose takes the form of [R|t] """ def __call__(self,R=None,t=None): # construct a camera pose from the given R and/or t assert(R is not None or t is not None) if R is None: if not isinstance(t,torch.Tensor): t = torch.tensor(t) R = torch.eye(3,device=t.device).repeat(*t.shape[:-1],1,1) elif t is None: if not isinstance(R,torch.Tensor): R = torch.tensor(R) t = torch.zeros(R.shape[:-1],device=R.device) else: if not isinstance(R,torch.Tensor): R = torch.tensor(R) if not isinstance(t,torch.Tensor): t = torch.tensor(t) assert(R.shape[:-1]==t.shape and R.shape[-2:]==(3,3)) R = R.float() t = t.float() pose = torch.cat([R,t[...,None]],dim=-1) # [...,3,4] assert(pose.shape[-2:]==(3,4)) return pose def invert(self,pose,use_inverse=False): # invert a camera pose R,t = pose[...,:3],pose[...,3:] R_inv = R.inverse() if use_inverse else R.transpose(-1,-2) t_inv = (-R_inv@t)[...,0] pose_inv = self(R=R_inv,t=t_inv) return pose_inv def compose(self,pose_list): # compose a sequence of poses together # pose_new(x) = poseN o ... 
o pose2 o pose1(x) pose_new = pose_list[0] for pose in pose_list[1:]: pose_new = self.compose_pair(pose_new,pose) return pose_new def compose_pair(self,pose_a,pose_b): # pose_new(x) = pose_b o pose_a(x) R_a,t_a = pose_a[...,:3],pose_a[...,3:] R_b,t_b = pose_b[...,:3],pose_b[...,3:] R_new = R_b@R_a t_new = (R_b@t_a+t_b)[...,0] pose_new = self(R=R_new,t=t_new) return pose_new class Lie(): """ Lie algebra for SO(3) and SE(3) operations in PyTorch """ def so3_to_SO3(self,w): # [...,3] wx = self.skew_symmetric(w) theta = w.norm(dim=-1)[...,None,None] I = torch.eye(3,device=w.device,dtype=torch.float32) A = self.taylor_A(theta) B = self.taylor_B(theta) R = I+A*wx+B*wx@wx return R def SO3_to_so3(self,R,eps=1e-7): # [...,3,3] trace = R[...,0,0]+R[...,1,1]+R[...,2,2] theta = ((trace-1)/2).clamp(-1+eps,1-eps).acos_()[...,None,None]%np.pi # ln(R) will explode if theta==pi lnR = 1/(2*self.taylor_A(theta)+1e-8)*(R-R.transpose(-2,-1)) # FIXME: wei-chiu finds it weird w0,w1,w2 = lnR[...,2,1],lnR[...,0,2],lnR[...,1,0] w = torch.stack([w0,w1,w2],dim=-1) return w def se3_to_SE3(self,wu): # [...,3] w,u = wu.split([3,3],dim=-1) wx = self.skew_symmetric(w) theta = w.norm(dim=-1)[...,None,None] I = torch.eye(3,device=w.device,dtype=torch.float32) A = self.taylor_A(theta) B = self.taylor_B(theta) C = self.taylor_C(theta) R = I+A*wx+B*wx@wx V = I+B*wx+C*wx@wx Rt = torch.cat([R,(V@u[...,None])],dim=-1) return Rt def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4] R,t = Rt.split([3,1],dim=-1) w = self.SO3_to_so3(R) wx = self.skew_symmetric(w) theta = w.norm(dim=-1)[...,None,None] I = torch.eye(3,device=w.device,dtype=torch.float32) A = self.taylor_A(theta) B = self.taylor_B(theta) invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx u = (invV@t)[...,0] wu = torch.cat([w,u],dim=-1) return wu def skew_symmetric(self,w): w0,w1,w2 = w.unbind(dim=-1) O = torch.zeros_like(w0) wx = torch.stack([torch.stack([O,-w2,w1],dim=-1), torch.stack([w2,O,-w0],dim=-1), torch.stack([-w1,w0,O],dim=-1)],dim=-2) return wx def taylor_A(self,x,nth=10): # Taylor expansion of sin(x)/x ans = torch.zeros_like(x) denom = 1. for i in range(nth+1): if i>0: denom *= (2*i)*(2*i+1) ans = ans+(-1)**i*x**(2*i)/denom return ans def taylor_B(self,x,nth=10): # Taylor expansion of (1-cos(x))/x**2 ans = torch.zeros_like(x) denom = 1. for i in range(nth+1): denom *= (2*i+1)*(2*i+2) ans = ans+(-1)**i*x**(2*i)/denom return ans def taylor_C(self,x,nth=10): # Taylor expansion of (x-sin(x))/x**3 ans = torch.zeros_like(x) denom = 1. 
for i in range(nth+1): denom *= (2*i+2)*(2*i+3) ans = ans+(-1)**i*x**(2*i)/denom return ans class Quaternion(): def q_to_R(self,q): # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion qa,qb,qc,qd = q.unbind(dim=-1) R = torch.stack([torch.stack([1-2*(qc**2+qd**2),2*(qb*qc-qa*qd),2*(qa*qc+qb*qd)],dim=-1), torch.stack([2*(qb*qc+qa*qd),1-2*(qb**2+qd**2),2*(qc*qd-qa*qb)],dim=-1), torch.stack([2*(qb*qd-qa*qc),2*(qa*qb+qc*qd),1-2*(qb**2+qc**2)],dim=-1)],dim=-2) return R def R_to_q(self,R,eps=1e-8): # [B,3,3] # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion # FIXME: this function seems a bit problematic, need to double-check row0,row1,row2 = R.unbind(dim=-2) R00,R01,R02 = row0.unbind(dim=-1) R10,R11,R12 = row1.unbind(dim=-1) R20,R21,R22 = row2.unbind(dim=-1) t = R[...,0,0]+R[...,1,1]+R[...,2,2] r = (1+t+eps).sqrt() qa = 0.5*r qb = (R21-R12).sign()*0.5*(1+R00-R11-R22+eps).sqrt() qc = (R02-R20).sign()*0.5*(1-R00+R11-R22+eps).sqrt() qd = (R10-R01).sign()*0.5*(1-R00-R11+R22+eps).sqrt() q = torch.stack([qa,qb,qc,qd],dim=-1) for i,qi in enumerate(q): if torch.isnan(qi).any(): K = torch.stack([torch.stack([R00-R11-R22,R10+R01,R20+R02,R12-R21],dim=-1), torch.stack([R10+R01,R11-R00-R22,R21+R12,R20-R02],dim=-1), torch.stack([R20+R02,R21+R12,R22-R00-R11,R01-R10],dim=-1), torch.stack([R12-R21,R20-R02,R01-R10,R00+R11+R22],dim=-1)],dim=-2)/3.0 K = K[i] eigval,eigvec = torch.linalg.eigh(K) V = eigvec[:,eigval.argmax()] q[i] = torch.stack([V[3],V[0],V[1],V[2]]) return q def invert(self,q): qa,qb,qc,qd = q.unbind(dim=-1) norm = q.norm(dim=-1,keepdim=True) q_inv = torch.stack([qa,-qb,-qc,-qd],dim=-1)/norm**2 return q_inv def product(self,q1,q2): # [B,4] q1a,q1b,q1c,q1d = q1.unbind(dim=-1) q2a,q2b,q2c,q2d = q2.unbind(dim=-1) hamil_prod = torch.stack([q1a*q2a-q1b*q2b-q1c*q2c-q1d*q2d, q1a*q2b+q1b*q2a+q1c*q2d-q1d*q2c, q1a*q2c-q1b*q2d+q1c*q2a+q1d*q2b, q1a*q2d+q1b*q2c-q1c*q2b+q1d*q2a],dim=-1) return hamil_prod pose = Pose() lie = Lie() quaternion = Quaternion() def to_hom(X): # get homogeneous coordinates of the input X_hom = torch.cat([X,torch.ones_like(X[...,:1])],dim=-1) return X_hom # basic operations of transforming 3D points between world/camera/image coordinates def world2cam(X,pose): # [B,N,3] X_hom = to_hom(X) return X_hom@pose.transpose(-1,-2) def cam2img(X,cam_intr): return X@cam_intr.transpose(-1,-2) def img2cam(X,cam_intr): return X@cam_intr.inverse().transpose(-1,-2) def cam2world(X,pose): X_hom = to_hom(X) pose_inv = Pose().invert(pose) return X_hom@pose_inv.transpose(-1,-2) def angle_to_rotation_matrix(a,axis): # get the rotation matrix from Euler angle around specific axis roll = dict(X=1,Y=2,Z=0)[axis] O = torch.zeros_like(a) I = torch.ones_like(a) M = torch.stack([torch.stack([a.cos(),-a.sin(),O],dim=-1), torch.stack([a.sin(),a.cos(),O],dim=-1), torch.stack([O,O,I],dim=-1)],dim=-2) M = M.roll((roll,roll),dims=(-2,-1)) return M def get_center_and_ray(opt,pose,intr=None): # [HW,2] # given the intrinsic/extrinsic matrices, get the camera center and ray directions] assert(opt.camera.model=="perspective") with torch.no_grad(): # compute image coordinate grid y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5) x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5) Y,X = torch.meshgrid(y_range,x_range) # [H,W] xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2] # compute center and ray batch_size = len(pose) xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2] grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3] center_3D = 
torch.zeros_like(grid_3D) # [B,HW,3] # transform from camera to world coordinates grid_3D = cam2world(grid_3D,pose) # [B,HW,3] center_3D = cam2world(center_3D,pose) # [B,HW,3] ray = grid_3D-center_3D # [B,HW,3] return center_3D,ray def get_camera_cords_grid_3D(opt,batch_size,intr=None,ray_idx=None): # [HW,2] # given the intrinsic matrices, get the grid_3D] assert(opt.camera.model=="perspective") with torch.no_grad(): # compute image coordinate grid y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5) x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5) Y,X = torch.meshgrid(y_range,x_range) # [H,W] xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2] # compute grid_3D xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2] grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3] if ray_idx is not None: # consider only subset of rays grid_3D = grid_3D[:,ray_idx] return grid_3D def gather_camera_cords_grid_3D(opt,batch_size,intr=None,ray_idx=None): # [HW,2] # given the intrinsic matrices, get the grid_3D] assert(opt.camera.model=="perspective") with torch.no_grad(): # compute image coordinate grid y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5) x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5) Y,X = torch.meshgrid(y_range,x_range) # [H,W] xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2] # compute grid_3D xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2] grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3] if ray_idx is not None: # consider only subset of rays grid_3D = torch.gather(grid_3D, 1, ray_idx[...,None].expand(-1,-1,3)) return grid_3D def get_3D_points_from_depth(opt,center,ray,depth,multi_samples=False): if multi_samples: center,ray = center[:,:,None],ray[:,:,None] # x = c+dv points_3D = center+ray*depth # [B,HW,3]/[B,HW,N,3]/[N,3] return points_3D def convert_NDC(opt,center,ray,intr,near=1): # shift camera center (ray origins) to near plane (z=1) # (unlike conventional NDC, we assume the cameras are facing towards the +z direction) center = center+(near-center[...,2:])/ray[...,2:]*ray # projection cx,cy,cz = center.unbind(dim=-1) # [B,HW] rx,ry,rz = ray.unbind(dim=-1) # [B,HW] scale_x = intr[:,0,0]/intr[:,0,2] # [B] scale_y = intr[:,1,1]/intr[:,1,2] # [B] cnx = scale_x[:,None]*(cx/cz) cny = scale_y[:,None]*(cy/cz) cnz = 1-2*near/cz rnx = scale_x[:,None]*(rx/rz-cx/cz) rny = scale_y[:,None]*(ry/rz-cy/cz) rnz = 2*near/cz center_ndc = torch.stack([cnx,cny,cnz],dim=-1) # [B,HW,3] ray_ndc = torch.stack([rnx,rny,rnz],dim=-1) # [B,HW,3] return center_ndc,ray_ndc def rotation_distance(R1,R2,eps=1e-7): # http://www.boris-belousov.net/2016/12/01/quat-dist/ R_diff = R1@R2.transpose(-2,-1) trace = R_diff[...,0,0]+R_diff[...,1,1]+R_diff[...,2,2] angle = ((trace-1)/2).clamp(-1+eps,1-eps).acos_() # numerical stability near -1/+1 return angle def procrustes_analysis(X0,X1): # [N,3] # translation t0 = X0.mean(dim=0,keepdim=True) t1 = X1.mean(dim=0,keepdim=True) X0c = X0-t0 X1c = X1-t1 # scale s0 = (X0c**2).sum(dim=-1).mean().sqrt() s1 = (X1c**2).sum(dim=-1).mean().sqrt() X0cs = X0c/s0 X1cs = X1c/s1 # rotation (use double for SVD, float loses precision) U,S,V = (X0cs.t()@X1cs).double().svd(some=True) R = (U@V.t()).float() if R.det()<0: R[2] *= -1 # align X1 to X0: X1to0 = (X1-t1)/s1@R.t()*s0+t0 sim3 = edict(t0=t0[0],t1=t1[0],s0=s0,s1=s1,R=R) return sim3 def get_novel_view_poses(opt,pose_anchor,N=60,scale=1): # create circular viewpoints (small oscillations) theta = torch.arange(N)/N*2*np.pi R_x 
= angle_to_rotation_matrix((theta.sin()*0.05).asin(),"X") R_y = angle_to_rotation_matrix((theta.cos()*0.05).asin(),"Y") pose_rot = pose(R=R_y@R_x) pose_shift = pose(t=[0,0,-4*scale]) pose_shift2 = pose(t=[0,0,3.8*scale]) pose_oscil = pose.compose([pose_shift,pose_rot,pose_shift2]) pose_novel = pose.compose([pose_oscil,pose_anchor.cpu()[None]]) return pose_novel ================================================ FILE: data/base.py ================================================ import numpy as np import os,sys,time import torch import torch.nn.functional as torch_F import torchvision import torchvision.transforms.functional as torchvision_F import torch.multiprocessing as mp import PIL import tqdm import threading,queue from easydict import EasyDict as edict import util from util import log,debug class Dataset(torch.utils.data.Dataset): def __init__(self,opt,split="train"): super().__init__() self.opt = opt self.split = split self.augment = split=="train" and opt.data.augment # define image sizes if opt.data.center_crop is not None: self.crop_H = int(self.raw_H*opt.data.center_crop) self.crop_W = int(self.raw_W*opt.data.center_crop) else: self.crop_H,self.crop_W = self.raw_H,self.raw_W if not opt.H or not opt.W: opt.H,opt.W = self.crop_H,self.crop_W def setup_loader(self,opt,shuffle=False,drop_last=False): loader = torch.utils.data.DataLoader(self, batch_size=opt.batch_size or 1, num_workers=opt.data.num_workers, shuffle=shuffle, drop_last=drop_last, pin_memory=False, # spews warnings in PyTorch 1.9 but should be True in general ) print("number of samples: {}".format(len(self))) return loader def get_list(self,opt): raise NotImplementedError def preload_worker(self,data_list,load_func,q,lock,idx_tqdm): while True: idx = q.get() data_list[idx] = load_func(self.opt,idx) with lock: idx_tqdm.update() q.task_done() def preload_threading(self,opt,load_func,data_str="images"): data_list = [None]*len(self) q = queue.Queue(maxsize=len(self)) idx_tqdm = tqdm.tqdm(range(len(self)),desc="preloading {}".format(data_str),leave=False) for i in range(len(self)): q.put(i) lock = threading.Lock() for ti in range(opt.data.num_workers): t = threading.Thread(target=self.preload_worker, args=(data_list,load_func,q,lock,idx_tqdm),daemon=True) t.start() q.join() idx_tqdm.close() assert(all(map(lambda x: x is not None,data_list))) return data_list def __getitem__(self,idx): raise NotImplementedError def get_image(self,opt,idx): raise NotImplementedError def generate_augmentation(self,opt): brightness = opt.data.augment.brightness or 0. contrast = opt.data.augment.contrast or 0. saturation = opt.data.augment.saturation or 0. hue = opt.data.augment.hue or 0. 
color_jitter = torchvision.transforms.ColorJitter.get_params( brightness=(1-brightness,1+brightness), contrast=(1-contrast,1+contrast), saturation=(1-saturation,1+saturation), hue=(-hue,hue), ) aug = edict( color_jitter=color_jitter, flip=np.random.randn()>0 if opt.data.augment.hflip else False, rot_angle=(np.random.rand()*2-1)*opt.data.augment.rotate if opt.data.augment.rotate else 0, ) return aug def preprocess_image(self,opt,image,aug=None): if aug is not None: image = self.apply_color_jitter(opt,image,aug.color_jitter) image = torchvision_F.hflip(image) if aug.flip else image image = image.rotate(aug.rot_angle,resample=PIL.Image.BICUBIC) # center crop if opt.data.center_crop is not None: self.crop_H = int(self.raw_H*opt.data.center_crop) self.crop_W = int(self.raw_W*opt.data.center_crop) image = torchvision_F.center_crop(image,(self.crop_H,self.crop_W)) else: self.crop_H,self.crop_W = self.raw_H,self.raw_W # resize if opt.data.image_size[0] is not None: image = image.resize((opt.W,opt.H)) image = torchvision_F.to_tensor(image) return image def preprocess_camera(self,opt,intr,pose,aug=None): intr,pose = intr.clone(),pose.clone() # center crop intr[0,2] -= (self.raw_W-self.crop_W)/2 intr[1,2] -= (self.raw_H-self.crop_H)/2 # resize intr[0] *= opt.W/self.crop_W intr[1] *= opt.H/self.crop_H return intr,pose def apply_color_jitter(self,opt,image,color_jitter): mode = image.mode if mode!="L": chan = image.split() rgb = PIL.Image.merge("RGB",chan[:3]) rgb = color_jitter(rgb) rgb_chan = rgb.split() image = PIL.Image.merge(mode,rgb_chan+chan[3:]) return image def __len__(self): return len(self.list) ================================================ FILE: data/blender.py ================================================ import numpy as np import os,sys,time import torch import torch.nn.functional as torch_F import torchvision import torchvision.transforms.functional as torchvision_F import PIL import imageio from easydict import EasyDict as edict import json import pickle from . 
import base import camera from util import log,debug class Dataset(base.Dataset): def __init__(self,opt,split="train",subset=None): self.raw_H,self.raw_W = 800,800 super().__init__(opt,split) self.root = opt.data.root or "data/blender" self.path = "{}/{}".format(self.root,opt.data.scene) # load/parse metadata meta_fname = "{}/transforms_{}.json".format(self.path,split) with open(meta_fname) as file: self.meta = json.load(file) self.list = self.meta["frames"] self.focal = 0.5*self.raw_W/np.tan(0.5*self.meta["camera_angle_x"]) if subset: self.list = self.list[:subset] # preload dataset if opt.data.preload: self.images = self.preload_threading(opt,self.get_image) self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras") def prefetch_all_data(self,opt): assert(not opt.data.augment) # pre-iterate through all samples and group together self.all = torch.utils.data._utils.collate.default_collate([s for s in self]) def get_all_camera_poses(self,opt): pose_raw_all = [torch.tensor(f["transform_matrix"],dtype=torch.float32) for f in self.list] pose_canon_all = torch.stack([self.parse_raw_camera(opt,p) for p in pose_raw_all],dim=0) return pose_canon_all def __getitem__(self,idx): opt = self.opt sample = dict(idx=idx) aug = self.generate_augmentation(opt) if self.augment else None image = self.images[idx] if opt.data.preload else self.get_image(opt,idx) image = self.preprocess_image(opt,image,aug=aug) intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx) intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug) sample.update( image=image, intr=intr, pose=pose, ) return sample def get_image(self,opt,idx): image_fname = "{}/{}.png".format(self.path,self.list[idx]["file_path"]) image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption.... return image def preprocess_image(self,opt,image,aug=None): image = super().preprocess_image(opt,image,aug=aug) rgb,mask = image[:3],image[3:] if opt.data.bgcolor is not None: rgb = rgb*mask+opt.data.bgcolor*(1-mask) return rgb def get_camera(self,opt,idx): intr = torch.tensor([[self.focal,0,self.raw_W/2], [0,self.focal,self.raw_H/2], [0,0,1]]).float() pose_raw = torch.tensor(self.list[idx]["transform_matrix"],dtype=torch.float32) pose = self.parse_raw_camera(opt,pose_raw) return intr,pose def parse_raw_camera(self,opt,pose_raw): pose_flip = camera.pose(R=torch.diag(torch.tensor([1,-1,-1]))) pose = camera.pose.compose([pose_flip,pose_raw[:3]]) pose = camera.pose.invert(pose) return pose ================================================ FILE: data/iphone.py ================================================ import numpy as np import os,sys,time import torch import torch.nn.functional as torch_F import torchvision import torchvision.transforms.functional as torchvision_F import PIL import imageio from easydict import EasyDict as edict import json import pickle from . 
import base import camera from util import log,debug class Dataset(base.Dataset): def __init__(self,opt,split="train",subset=None): self.raw_H,self.raw_W = 1080,1920 super().__init__(opt,split) self.root = opt.data.root or "data/iphone" self.path = "{}/{}".format(self.root,opt.data.scene) self.path_image = "{}/images".format(self.path) # self.list = sorted(os.listdir(self.path_image),key=lambda f: int(f.split(".")[0])) self.list = os.listdir(self.path_image) # manually split train/val subsets num_val_split = int(len(self)*opt.data.val_ratio) self.list = self.list[:-num_val_split] if split=="train" else self.list[-num_val_split:] if subset: self.list = self.list[:subset] # preload dataset if opt.data.preload: self.images = self.preload_threading(opt,self.get_image) self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras") def prefetch_all_data(self,opt): assert(not opt.data.augment) # pre-iterate through all samples and group together self.all = torch.utils.data._utils.collate.default_collate([s for s in self]) def get_all_camera_poses(self,opt): # poses are unknown, so just return some dummy poses (identity transform) return camera.pose(t=torch.zeros(len(self),3)) def __getitem__(self,idx): opt = self.opt sample = dict(idx=idx) aug = self.generate_augmentation(opt) if self.augment else None image = self.images[idx] if opt.data.preload else self.get_image(opt,idx) image = self.preprocess_image(opt,image,aug=aug) intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx) intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug) sample.update( image=image, intr=intr, pose=pose, ) return sample def get_image(self,opt,idx): image_fname = "{}/{}".format(self.path_image,self.list[idx]) image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption.... return image def get_camera(self,opt,idx): self.focal = self.raw_W*4.2/(12.8/2.55) intr = torch.tensor([[self.focal,0,self.raw_W/2], [0,self.focal,self.raw_H/2], [0,0,1]]).float() pose = camera.pose(t=torch.zeros(3)) # dummy pose, won't be used return intr,pose ================================================ FILE: data/llff.py ================================================ import numpy as np import os,sys,time import torch import torch.nn.functional as torch_F import torchvision import torchvision.transforms.functional as torchvision_F import PIL import imageio from easydict import EasyDict as edict import json import pickle from . 
import base import camera from util import log,debug class Dataset(base.Dataset): def __init__(self,opt,split="train",subset=None): self.raw_H,self.raw_W = 3024,4032 super().__init__(opt,split) self.root = opt.data.root or "data/llff" self.path = "{}/{}".format(self.root,opt.data.scene) self.path_image = "{}/images".format(self.path) image_fnames = sorted(os.listdir(self.path_image)) poses_raw,bounds = self.parse_cameras_and_bounds(opt) self.list = list(zip(image_fnames,poses_raw,bounds)) # manually split train/val subsets num_val_split = int(len(self)*opt.data.val_ratio) self.list = self.list[:-num_val_split] if split=="train" else self.list[-num_val_split:] if subset: self.list = self.list[:subset] # preload dataset if opt.data.preload: self.images = self.preload_threading(opt,self.get_image) self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras") def prefetch_all_data(self,opt): assert(not opt.data.augment) # pre-iterate through all samples and group together self.all = torch.utils.data._utils.collate.default_collate([s for s in self]) def parse_cameras_and_bounds(self,opt): fname = "{}/poses_bounds.npy".format(self.path) data = torch.tensor(np.load(fname),dtype=torch.float32) # parse cameras (intrinsics and poses) cam_data = data[:,:-2].view([-1,3,5]) # [N,3,5] poses_raw = cam_data[...,:4] # [N,3,4] poses_raw[...,0],poses_raw[...,1] = poses_raw[...,1],-poses_raw[...,0] raw_H,raw_W,self.focal = cam_data[0,:,-1] assert(self.raw_H==raw_H and self.raw_W==raw_W) # parse depth bounds bounds = data[:,-2:] # [N,2] scale = 1./(bounds.min()*0.75) # not sure how this was determined poses_raw[...,3] *= scale bounds *= scale # roughly center camera poses poses_raw = self.center_camera_poses(opt,poses_raw) return poses_raw,bounds def center_camera_poses(self,opt,poses): # compute average pose center = poses[...,3].mean(dim=0) v1 = torch_F.normalize(poses[...,1].mean(dim=0),dim=0) v2 = torch_F.normalize(poses[...,2].mean(dim=0),dim=0) v0 = v1.cross(v2) pose_avg = torch.stack([v0,v1,v2,center],dim=-1)[None] # [1,3,4] # apply inverse of averaged pose poses = camera.pose.compose([poses,camera.pose.invert(pose_avg)]) return poses def get_all_camera_poses(self,opt): pose_raw_all = [tup[1] for tup in self.list] pose_all = torch.stack([self.parse_raw_camera(opt,p) for p in pose_raw_all],dim=0) return pose_all def __getitem__(self,idx): opt = self.opt sample = dict(idx=idx) aug = self.generate_augmentation(opt) if self.augment else None image = self.images[idx] if opt.data.preload else self.get_image(opt,idx) image = self.preprocess_image(opt,image,aug=aug) intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx) intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug) sample.update( image=image, intr=intr, pose=pose, ) return sample def get_image(self,opt,idx): image_fname = "{}/{}".format(self.path_image,self.list[idx][0]) image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption.... 
return image def get_camera(self,opt,idx): intr = torch.tensor([[self.focal,0,self.raw_W/2], [0,self.focal,self.raw_H/2], [0,0,1]]).float() pose_raw = self.list[idx][1] pose = self.parse_raw_camera(opt,pose_raw) return intr,pose def parse_raw_camera(self,opt,pose_raw): pose_flip = camera.pose(R=torch.diag(torch.tensor([1,-1,-1]))) pose = camera.pose.compose([pose_flip,pose_raw[:3]]) pose = camera.pose.invert(pose) pose = camera.pose.compose([pose_flip,pose]) return pose ================================================ FILE: evaluate.py ================================================ import numpy as np import os,sys,time import torch import importlib import options from util import log import warnings warnings.filterwarnings('ignore') def main(): log.process(os.getpid()) log.title("[{}] (PyTorch code for evaluating NeRF/BARF/L2G_NeRF)".format(sys.argv[0])) opt_cmd = options.parse_arguments(sys.argv[1:]) opt = options.set(opt_cmd=opt_cmd) with torch.cuda.device(opt.device): model = importlib.import_module("model.{}".format(opt.model)) m = model.Model(opt) m.load_dataset(opt,eval_split="test") m.build_networks(opt) if opt.model in ["barf", "l2g_nerf"]: m.generate_videos_pose(opt) m.restore_checkpoint(opt) if opt.data.dataset in ["blender","llff"]: m.evaluate_full(opt) m.generate_videos_synthesis(opt) if __name__=="__main__": main() ================================================ FILE: external/pohsun_ssim/LICENSE.txt ================================================ MIT ================================================ FILE: external/pohsun_ssim/README.md ================================================ # pytorch-ssim ### Differentiable structural similarity (SSIM) index. ![einstein](https://raw.githubusercontent.com/Po-Hsun-Su/pytorch-ssim/master/einstein.png) ![Max_ssim](https://raw.githubusercontent.com/Po-Hsun-Su/pytorch-ssim/master/max_ssim.gif) ## Installation 1. Clone this repo. 2. Copy "pytorch_ssim" folder in your project. 
## Example ### basic usage ```python import pytorch_ssim import torch from torch.autograd import Variable img1 = Variable(torch.rand(1, 1, 256, 256)) img2 = Variable(torch.rand(1, 1, 256, 256)) if torch.cuda.is_available(): img1 = img1.cuda() img2 = img2.cuda() print(pytorch_ssim.ssim(img1, img2)) ssim_loss = pytorch_ssim.SSIM(window_size = 11) print(ssim_loss(img1, img2)) ``` ### maximize ssim ```python import pytorch_ssim import torch from torch.autograd import Variable from torch import optim import cv2 import numpy as np npImg1 = cv2.imread("einstein.png") img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0 img2 = torch.rand(img1.size()) if torch.cuda.is_available(): img1 = img1.cuda() img2 = img2.cuda() img1 = Variable( img1, requires_grad=False) img2 = Variable( img2, requires_grad = True) # Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True) ssim_value = pytorch_ssim.ssim(img1, img2).data[0] print("Initial ssim:", ssim_value) # Module: pytorch_ssim.SSIM(window_size = 11, size_average = True) ssim_loss = pytorch_ssim.SSIM() optimizer = optim.Adam([img2], lr=0.01) while ssim_value < 0.95: optimizer.zero_grad() ssim_out = -ssim_loss(img1, img2) ssim_value = - ssim_out.data[0] print(ssim_value) ssim_out.backward() optimizer.step() ``` ## Reference https://ece.uwaterloo.ca/~z70wang/research/ssim/ ================================================ FILE: external/pohsun_ssim/max_ssim.py ================================================ import pytorch_ssim import torch from torch.autograd import Variable from torch import optim import cv2 import numpy as np npImg1 = cv2.imread("einstein.png") img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0 img2 = torch.rand(img1.size()) if torch.cuda.is_available(): img1 = img1.cuda() img2 = img2.cuda() img1 = Variable( img1, requires_grad=False) img2 = Variable( img2, requires_grad = True) # Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True) ssim_value = pytorch_ssim.ssim(img1, img2).data[0] print("Initial ssim:", ssim_value) # Module: pytorch_ssim.SSIM(window_size = 11, size_average = True) ssim_loss = pytorch_ssim.SSIM() optimizer = optim.Adam([img2], lr=0.01) while ssim_value < 0.95: optimizer.zero_grad() ssim_out = -ssim_loss(img1, img2) ssim_value = - ssim_out.data[0] print(ssim_value) ssim_out.backward() optimizer.step() ================================================ FILE: external/pohsun_ssim/pytorch_ssim/__init__.py ================================================ import torch import torch.nn.functional as F from torch.autograd import Variable import numpy as np from math import exp def gaussian(window_size, sigma): gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) return gauss/gauss.sum() def create_window(window_size, channel): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) return window def _ssim(img1, img2, window, window_size, channel, size_average = True): mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) mu1_mu2 = mu1*mu2 sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups 
= channel) - mu2_sq sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 C1 = 0.01**2 C2 = 0.03**2 ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) if size_average: return ssim_map.mean() else: return ssim_map.mean(1).mean(1).mean(1) class SSIM(torch.nn.Module): def __init__(self, window_size = 11, size_average = True): super(SSIM, self).__init__() self.window_size = window_size self.size_average = size_average self.channel = 1 self.window = create_window(window_size, self.channel) def forward(self, img1, img2): (_, channel, _, _) = img1.size() if channel == self.channel and self.window.data.type() == img1.data.type(): window = self.window else: window = create_window(self.window_size, channel) if img1.is_cuda: window = window.cuda(img1.get_device()) window = window.type_as(img1) self.window = window self.channel = channel return _ssim(img1, img2, window, self.window_size, channel, self.size_average) def ssim(img1, img2, window_size = 11, size_average = True): (_, channel, _, _) = img1.size() window = create_window(window_size, channel) if img1.is_cuda: window = window.cuda(img1.get_device()) window = window.type_as(img1) return _ssim(img1, img2, window, window_size, channel, size_average) ================================================ FILE: external/pohsun_ssim/setup.cfg ================================================ [metadata] description-file = README.md ================================================ FILE: external/pohsun_ssim/setup.py ================================================ from distutils.core import setup setup( name = 'pytorch_ssim', packages = ['pytorch_ssim'], # this must be the same as the name above version = '0.1', description = 'Differentiable structural similarity (SSIM) index', author = 'Po-Hsun (Evan) Su', author_email = 'evan.pohsun.su@gmail.com', url = 'https://github.com/Po-Hsun-Su/pytorch-ssim', # use the URL to the github repo download_url = 'https://github.com/Po-Hsun-Su/pytorch-ssim/archive/0.1.tar.gz', # I'll explain this in a second keywords = ['pytorch', 'image-processing', 'deep-learning'], # arbitrary keywords classifiers = [], ) ================================================ FILE: extract_mesh.py ================================================ """Extracts a 3D mesh from a pretrained model using marching cubes.""" import importlib import sys import numpy as np import options import torch import tqdm import trimesh import mcubes from util import log,debug opt_cmd = options.parse_arguments(sys.argv[1:]) opt = options.set(opt_cmd=opt_cmd) with torch.cuda.device(opt.device),torch.no_grad(): model = importlib.import_module("model.{}".format(opt.model)) m = model.Model(opt) m.load_dataset(opt) m.build_networks(opt) m.restore_checkpoint(opt) t = torch.linspace(*opt.trimesh.range,opt.trimesh.res+1) # the best range might vary from model to model query = torch.stack(torch.meshgrid(t,t,t),dim=-1) query_flat = query.view(-1,3) density_all = [] for i in tqdm.trange(0,len(query_flat),opt.trimesh.chunk_size,leave=False): points = query_flat[None,i:i+opt.trimesh.chunk_size].to(opt.device) ray_unit = torch.zeros_like(points) # dummy ray to comply with interface, not used _,density_samples = m.graph.nerf.forward(opt,points,ray_unit=ray_unit,mode=None) density_all.append(density_samples.cpu()) density_all = torch.cat(density_all,dim=1)[0] density_all = density_all.view(*query.shape[:-1]).numpy() log.info("running marching cubes...") vertices,triangles = 
mcubes.marching_cubes(density_all,opt.trimesh.thres) vertices_centered = vertices/opt.trimesh.res-0.5 mesh = trimesh.Trimesh(vertices_centered,triangles) obj_fname = "{}/mesh.obj".format(opt.output_path) log.info("saving 3D mesh to {}...".format(obj_fname)) mesh.export(obj_fname) ================================================ FILE: model/barf.py ================================================ import numpy as np import os,sys,time import torch import torch.nn.functional as torch_F import torchvision import torchvision.transforms.functional as torchvision_F import tqdm from easydict import EasyDict as edict import visdom import matplotlib.pyplot as plt import util,util_vis from util import log,debug from . import nerf import camera # ============================ main engine for training and evaluation ============================ class Model(nerf.Model): def __init__(self,opt): super().__init__(opt) def build_networks(self,opt): super().build_networks(opt) if opt.camera.noise: # pre-generate synthetic pose perturbation so3_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_r t_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_t self.graph.pose_noise = torch.cat([camera.lie.so3_to_SO3(so3_noise),t_noise[...,None]],dim=-1) # [...,3,4] self.graph.se3_refine = torch.nn.Embedding(len(self.train_data),6).to(opt.device) torch.nn.init.zeros_(self.graph.se3_refine.weight) pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device) # add synthetic pose perturbation to all training data if opt.data.dataset=="blender": pose = pose_GT if opt.camera.noise: pose = camera.pose.compose([pose, self.graph.pose_noise]) else: pose = self.graph.pose_eye[None].repeat(len(self.train_data),1,1) # use Embedding so it could be checkpointed self.graph.optimised_training_poses = torch.nn.Embedding(len(self.train_data),12,_weight=pose.view(-1,12)).to(opt.device) idx_range = torch.arange(len(self.train_data),dtype=torch.long,device=opt.device) idx_X,idx_Y = torch.meshgrid(idx_range,idx_range) self.graph.idx_grid = torch.stack([idx_X,idx_Y],dim=-1).view(-1,2) def setup_optimizer(self,opt): super().setup_optimizer(opt) optimizer = getattr(torch.optim,opt.optim.algo) self.optim_pose = optimizer([dict(params=self.graph.se3_refine.parameters(),lr=opt.optim.lr_pose)]) # set up scheduler if opt.optim.sched_pose: scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched_pose.type) if opt.optim.lr_pose_end: assert(opt.optim.sched_pose.type=="ExponentialLR") opt.optim.sched_pose.gamma = (opt.optim.lr_pose_end/opt.optim.lr_pose)**(1./opt.max_iter) kwargs = { k:v for k,v in opt.optim.sched_pose.items() if k!="type" } self.sched_pose = scheduler(self.optim_pose,**kwargs) def train_iteration(self,opt,var,loader): self.optim_pose.zero_grad() if opt.optim.warmup_pose: # simple linear warmup of pose learning rate self.optim_pose.param_groups[0]["lr_orig"] = self.optim_pose.param_groups[0]["lr"] # cache the original learning rate self.optim_pose.param_groups[0]["lr"] *= min(1,self.it/opt.optim.warmup_pose) loss = super().train_iteration(opt,var,loader) self.optim_pose.step() if opt.optim.warmup_pose: self.optim_pose.param_groups[0]["lr"] = self.optim_pose.param_groups[0]["lr_orig"] # reset learning rate if opt.optim.sched_pose: self.sched_pose.step() self.graph.nerf.progress.data.fill_(self.it/opt.max_iter) if opt.nerf.fine_sampling: self.graph.nerf_fine.progress.data.fill_(self.it/opt.max_iter) return loss @torch.no_grad() def validate(self,opt,ep=None): pose,pose_GT = 
self.get_all_training_poses(opt) _,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT) super().validate(opt,ep=ep) @torch.no_grad() def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"): super().log_scalars(opt,var,loss,metric=metric,step=step,split=split) if split=="train": # log learning rate lr = self.optim_pose.param_groups[0]["lr"] self.tb.add_scalar("{0}/{1}".format(split,"lr_pose"),lr,step) # compute pose error if split=="train" and opt.data.dataset in ["blender","llff"]: pose,pose_GT = self.get_all_training_poses(opt) pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT) error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT) self.tb.add_scalar("{0}/error_R".format(split),error.R.mean(),step) self.tb.add_scalar("{0}/error_t".format(split),error.t.mean(),step) @torch.no_grad() def visualize(self,opt,var,step=0,split="train"): super().visualize(opt,var,step=step,split=split) if opt.visdom: if split=="val": pose,pose_GT = self.get_all_training_poses(opt) pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT) util_vis.vis_cameras(opt,self.vis,step=step,poses=[pose_aligned,pose_GT]) @torch.no_grad() def get_all_training_poses(self,opt): # get ground-truth (canonical) camera poses pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device) pose = self.graph.optimised_training_poses.weight.data.detach().clone().view(-1,3,4) return pose,pose_GT @torch.no_grad() def prealign_cameras(self,opt,pose,pose_GT): # compute 3D similarity transform via Procrustes analysis center = torch.zeros(1,1,3,device=opt.device) center_pred = camera.cam2world(center,pose)[:,0] # [N,3] center_GT = camera.cam2world(center,pose_GT)[:,0] # [N,3] try: sim3 = camera.procrustes_analysis(center_GT,center_pred) except: print("warning: SVD did not converge...") sim3 = edict(t0=0,t1=0,s0=1,s1=1,R=torch.eye(3,device=opt.device)) # align the camera poses center_aligned = (center_pred-sim3.t1)/sim3.s1@sim3.R.t()*sim3.s0+sim3.t0 R_aligned = pose[...,:3]@sim3.R.t() t_aligned = (-R_aligned@center_aligned[...,None])[...,0] pose_aligned = camera.pose(R=R_aligned,t=t_aligned) return pose_aligned,sim3 @torch.no_grad() def evaluate_camera_alignment(self,opt,pose_aligned,pose_GT): # measure errors in rotation and translation R_aligned,t_aligned = pose_aligned.split([3,1],dim=-1) R_GT,t_GT = pose_GT.split([3,1],dim=-1) R_error = camera.rotation_distance(R_aligned,R_GT) t_error = (t_aligned-t_GT)[...,0].norm(dim=-1) error = edict(R=R_error,t=t_error) return error @torch.no_grad() def evaluate_full(self,opt): self.graph.eval() # evaluate rotation/translation pose,pose_GT = self.get_all_training_poses(opt) pose_aligned,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT) error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT) print("--------------------------") print("rot: {:8.3f}".format(np.rad2deg(error.R.mean().cpu()))) print("trans: {:10.5f}".format(error.t.mean())) print("--------------------------") # dump numbers quant_fname = "{}/quant_pose.txt".format(opt.output_path) with open(quant_fname,"w") as file: for i,(err_R,err_t) in enumerate(zip(error.R,error.t)): file.write("{} {} {}\n".format(i,err_R.item(),err_t.item())) # evaluate novel view synthesis super().evaluate_full(opt) @torch.enable_grad() def evaluate_test_time_photometric_optim(self,opt,var): # use another se3 Parameter to absorb the remaining pose errors var.se3_refine_test = torch.nn.Parameter(torch.zeros(1,6,device=opt.device)) optimizer = getattr(torch.optim,opt.optim.algo) optim_pose = 
optimizer([dict(params=[var.se3_refine_test],lr=opt.optim.lr_pose)]) iterator = tqdm.trange(opt.optim.test_iter,desc="test-time optim.",leave=False,position=1) for it in iterator: optim_pose.zero_grad() var.pose_refine_test = camera.lie.se3_to_SE3(var.se3_refine_test) var = self.graph.forward(opt,var,mode="test-optim") loss = self.graph.compute_loss(opt,var,mode="test-optim") loss = self.summarize_loss(opt,var,loss) loss.all.backward() optim_pose.step() iterator.set_postfix(loss="{:.3f}".format(loss.all)) return var @torch.no_grad() def generate_videos_pose(self,opt): self.graph.eval() fig = plt.figure(figsize=(10,10) if opt.data.dataset=="blender" else (16,8)) cam_path = "{}/poses".format(opt.output_path) os.makedirs(cam_path,exist_ok=True) ep_list = [] for ep in range(0,opt.max_iter+1,opt.freq.ckpt): # load checkpoint (0 is random init) if ep!=0: try: util.restore_checkpoint(opt,self,resume=ep) except: continue # get the camera poses pose,pose_ref = self.get_all_training_poses(opt) if opt.data.dataset in ["blender","llff"]: pose_aligned,_ = self.prealign_cameras(opt,pose,pose_ref) pose_aligned,pose_ref = pose_aligned.detach().cpu(),pose_ref.detach().cpu() dict( blender=util_vis.plot_save_poses_blender, llff=util_vis.plot_save_poses, )[opt.data.dataset](opt,fig,pose_aligned,pose_ref=pose_ref,path=cam_path,ep=ep) else: pose = pose.detach().cpu() util_vis.plot_save_poses(opt,fig,pose,pose_ref=None,path=cam_path,ep=ep) ep_list.append(ep) plt.close() # write videos print("writing videos...") list_fname = "{}/temp.list".format(cam_path) with open(list_fname,"w") as file: for ep in ep_list: file.write("file {}.png\n".format(ep)) cam_vid_fname = "{}/poses.mp4".format(opt.output_path) os.system("ffmpeg -y -r 4 -f concat -i {0} -pix_fmt yuv420p {1} >/dev/null 2>&1".format(list_fname,cam_vid_fname)) os.remove(list_fname) # ============================ computation graph for forward/backprop ============================ class Graph(nerf.Graph): def __init__(self,opt): super().__init__(opt) self.nerf = NeRF(opt) if opt.nerf.fine_sampling: self.nerf_fine = NeRF(opt) self.pose_eye = torch.eye(3,4).to(opt.device) def forward(self,opt,var,mode=None): # rescale the size of the scene condition on the pose if opt.data.dataset=="blender": depth_min,depth_max = opt.nerf.depth.range position = camera.Pose().invert(self.optimised_training_poses.weight.data.detach().clone().view(-1,3,4))[...,-1] diameter = ((position[self.idx_grid[...,0]]-position[self.idx_grid[...,1]]).norm(dim=-1)).max() depth_min_new = (depth_min/(depth_max+depth_min))*diameter depth_max_new = (depth_max/(depth_max+depth_min))*diameter opt.nerf.depth.range = [depth_min_new, depth_max_new] # render images batch_size = len(var.idx) pose = self.get_pose(opt,var,mode=mode) if opt.nerf.rand_rays and mode in ["train","test-optim"]: # sample random rays for optimization var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size] ret = self.render(opt,pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1] else: # render full image (process in slices) ret = self.render_by_slices(opt,pose,intr=var.intr,mode=mode) if opt.nerf.rand_rays else \ self.render(opt,pose,intr=var.intr,mode=mode) # [B,HW,3],[B,HW,1] var.update(ret) return var def get_pose(self,opt,var,mode=None): if mode=="train": # add the pre-generated pose perturbations if opt.data.dataset=="blender": if opt.camera.noise: var.pose_noise = self.pose_noise[var.idx] pose = camera.pose.compose([var.pose, var.pose_noise]) else: pose = var.pose else: 
pose = self.pose_eye # add learnable pose correction var.se3_refine = self.se3_refine.weight[var.idx] pose_refine = camera.lie.se3_to_SE3(var.se3_refine) pose = camera.pose.compose([pose_refine, pose]) self.optimised_training_poses.weight.data = pose.detach().clone().view(-1,12) elif mode in ["val","eval","test-optim"]: # align test pose to refined coordinate system (up to sim3) sim3 = self.sim3 center = torch.zeros(1,1,3,device=opt.device) center = camera.cam2world(center,var.pose)[:,0] # [N,3] center_aligned = (center-sim3.t0)/sim3.s0@sim3.R*sim3.s1+sim3.t1 R_aligned = var.pose[...,:3]@self.sim3.R t_aligned = (-R_aligned@center_aligned[...,None])[...,0] pose = camera.pose(R=R_aligned,t=t_aligned) # additionally factorize the remaining pose imperfection if opt.optim.test_photo and mode!="val": pose = camera.pose.compose([var.pose_refine_test, pose]) else: pose = var.pose return pose class NeRF(nerf.NeRF): def __init__(self,opt): super().__init__(opt) self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed def positional_encoding(self,opt,input,L): # [B,...,N] input_enc = super().positional_encoding(opt,input,L=L) # [B,...,2NL] # coarse-to-fine: smoothly mask positional encoding for BARF if opt.barf_c2f is not None: # set weights for different frequency bands start,end = opt.barf_c2f alpha = (self.progress.data-start)/(end-start)*L k = torch.arange(L,dtype=torch.float32,device=opt.device) weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2 # apply weights shape = input_enc.shape input_enc = (input_enc.view(-1,L)*weight).view(*shape) return input_enc ================================================ FILE: model/base.py ================================================ import numpy as np import os,sys,time import torch import torch.nn.functional as torch_F import torchvision import torchvision.transforms.functional as torchvision_F import torch.utils.tensorboard import visdom import importlib import tqdm from easydict import EasyDict as edict import util,util_vis from util import log,debug # ============================ main engine for training and evaluation ============================ class Model(): def __init__(self,opt): super().__init__() os.makedirs(opt.output_path,exist_ok=True) def load_dataset(self,opt,eval_split="val"): data = importlib.import_module("data.{}".format(opt.data.dataset)) log.info("loading training data...") self.train_data = data.Dataset(opt,split="train",subset=opt.data.train_sub) self.train_loader = self.train_data.setup_loader(opt,shuffle=True) log.info("loading test data...") if opt.data.val_on_test: eval_split = "test" self.test_data = data.Dataset(opt,split=eval_split,subset=opt.data.val_sub) self.test_loader = self.test_data.setup_loader(opt,shuffle=False) def build_networks(self,opt): graph = importlib.import_module("model.{}".format(opt.model)) log.info("building networks...") self.graph = graph.Graph(opt).to(opt.device) def setup_optimizer(self,opt): log.info("setting up optimizers...") optimizer = getattr(torch.optim,opt.optim.algo) self.optim = optimizer([dict(params=self.graph.parameters(),lr=opt.optim.lr)]) # set up scheduler if opt.optim.sched: scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type) kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" } self.sched = scheduler(self.optim,**kwargs) def restore_checkpoint(self,opt): epoch_start,iter_start = None,None if opt.resume: log.info("resuming from previous checkpoint...") epoch_start,iter_start = 
util.restore_checkpoint(opt,self,resume=opt.resume) elif opt.load is not None: log.info("loading weights from checkpoint {}...".format(opt.load)) epoch_start,iter_start = util.restore_checkpoint(opt,self,load_name=opt.load) else: log.info("initializing weights from scratch...") self.epoch_start = epoch_start or 0 self.iter_start = iter_start or 0 def setup_visualizer(self,opt): log.info("setting up visualizers...") if opt.tb: self.tb = torch.utils.tensorboard.SummaryWriter(log_dir=opt.output_path,flush_secs=10) if opt.visdom: # check if visdom server is runninng is_open = util.check_socket_open(opt.visdom.server,opt.visdom.port) retry = None while not is_open: retry = input("visdom port ({}) not open, retry? (y/n) ".format(opt.visdom.port)) if retry not in ["y","n"]: continue if retry=="y": is_open = util.check_socket_open(opt.visdom.server,opt.visdom.port) else: break self.vis = visdom.Visdom(server=opt.visdom.server,port=opt.visdom.port,env=opt.group) def train(self,opt): # before training log.title("TRAINING START") self.timer = edict(start=time.time(),it_mean=None) self.it = self.iter_start # training if self.iter_start==0: self.validate(opt,ep=0) for self.ep in range(self.epoch_start,opt.max_epoch): self.train_epoch(opt) # after training if opt.tb: self.tb.flush() self.tb.close() if opt.visdom: self.vis.close() log.title("TRAINING DONE") def train_epoch(self,opt): # before train epoch self.graph.train() # train epoch loader = tqdm.tqdm(self.train_loader,desc="training epoch {}".format(self.ep+1),leave=False) for batch in loader: # train iteration var = edict(batch) var = util.move_to_device(var,opt.device) loss = self.train_iteration(opt,var,loader) # after train epoch lr = self.sched.get_last_lr()[0] if opt.optim.sched else opt.optim.lr log.loss_train(opt,self.ep+1,lr,loss.all,self.timer) if opt.optim.sched: self.sched.step() if (self.ep+1)%opt.freq.val==0: self.validate(opt,ep=self.ep+1) if (self.ep+1)%opt.freq.ckpt==0: self.save_checkpoint(opt,ep=self.ep+1,it=self.it) def train_iteration(self,opt,var,loader): # before train iteration self.timer.it_start = time.time() # train iteration self.optim.zero_grad() var = self.graph.forward(opt,var,mode="train") loss = self.graph.compute_loss(opt,var,mode="train") loss = self.summarize_loss(opt,var,loss) loss.all.backward() self.optim.step() # after train iteration if (self.it+1)%opt.freq.scalar==0: self.log_scalars(opt,var,loss,step=self.it+1,split="train") if (self.it+1)%opt.freq.vis==0: self.visualize(opt,var,step=self.it+1,split="train") self.it += 1 loader.set_postfix(it=self.it,loss="{:.3f}".format(loss.all)) self.timer.it_end = time.time() util.update_timer(opt,self.timer,self.ep,len(loader)) return loss def summarize_loss(self,opt,var,loss): loss_all = 0. 
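        # Note: the values in opt.loss_weight are log10 exponents rather than direct
        # multipliers. Each enabled loss term below is scaled by 10**weight, so a
        # weight of 0 scales the term by 1, a weight of 1 by 10, and a weight of 2
        # by 100, while a weight of None drops the term from the total entirely
        # (see the `is not None` check below).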
assert("all" not in loss) # weigh losses for key in loss: assert(key in opt.loss_weight) assert(loss[key].shape==()) if opt.loss_weight[key] is not None: assert not torch.isinf(loss[key]),"loss {} is Inf".format(key) assert not torch.isnan(loss[key]),"loss {} is NaN".format(key) loss_all += 10**float(opt.loss_weight[key])*loss[key] loss.update(all=loss_all) return loss @torch.no_grad() def validate(self,opt,ep=None): self.graph.eval() loss_val = edict() loader = tqdm.tqdm(self.test_loader,desc="validating",leave=False) for it,batch in enumerate(loader): var = edict(batch) var = util.move_to_device(var,opt.device) var = self.graph.forward(opt,var,mode="val") loss = self.graph.compute_loss(opt,var,mode="val") loss = self.summarize_loss(opt,var,loss) for key in loss: loss_val.setdefault(key,0.) loss_val[key] += loss[key]*len(var.idx) loader.set_postfix(loss="{:.3f}".format(loss.all)) if it==0: self.visualize(opt,var,step=ep,split="val") for key in loss_val: loss_val[key] /= len(self.test_data) self.log_scalars(opt,var,loss_val,step=ep,split="val") log.loss_val(opt,loss_val.all) @torch.no_grad() def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"): for key,value in loss.items(): if key=="all": continue if opt.loss_weight[key] is not None: self.tb.add_scalar("{0}/loss_{1}".format(split,key),value,step) if metric is not None: for key,value in metric.items(): self.tb.add_scalar("{0}/{1}".format(split,key),value,step) @torch.no_grad() def visualize(self,opt,var,step=0,split="train"): raise NotImplementedError def save_checkpoint(self,opt,ep=0,it=0,latest=False): util.save_checkpoint(opt,self,ep=ep,it=it,latest=latest) if not latest: log.info("checkpoint saved: ({0}) {1}, epoch {2} (iteration {3})".format(opt.group,opt.name,ep,it)) # ============================ computation graph for forward/backprop ============================ class Graph(torch.nn.Module): def __init__(self,opt): super().__init__() def forward(self,opt,var,mode=None): raise NotImplementedError return var def compute_loss(self,opt,var,mode=None): loss = edict() raise NotImplementedError return loss def L1_loss(self,pred,label=0): loss = (pred.contiguous()-label).abs() return loss.mean() def MSE_loss(self,pred,label=0): loss = (pred.contiguous()-label)**2 return loss.mean() ================================================ FILE: model/l2g_nerf.py ================================================ import numpy as np import os,sys,time import torch import torch.nn.functional as torch_F import torchvision import torchvision.transforms.functional as torchvision_F import tqdm from easydict import EasyDict as edict import visdom import matplotlib.pyplot as plt import util,util_vis from util import log,debug from . 
import nerf import camera import roma # ============================ main engine for training and evaluation ============================ class Model(nerf.Model): def __init__(self,opt): super().__init__(opt) def build_networks(self,opt): super().build_networks(opt) if opt.camera.noise: # pre-generate synthetic pose perturbation so3_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_r t_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_t self.graph.pose_noise = torch.cat([camera.lie.so3_to_SO3(so3_noise),t_noise[...,None]],dim=-1) # [...,3,4] self.graph.warp_embedding = torch.nn.Embedding(len(self.train_data),opt.arch.embedding_dim).to(opt.device) self.graph.warp_mlp = localWarp(opt).to(opt.device) pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device) # add synthetic pose perturbation to all training data if opt.data.dataset=="blender": pose = pose_GT if opt.camera.noise: pose = camera.pose.compose([pose, self.graph.pose_noise]) else: pose = self.graph.pose_eye[None].repeat(len(self.train_data),1,1) # use Embedding so it could be checkpointed self.graph.optimised_training_poses = torch.nn.Embedding(len(self.train_data),12,_weight=pose.view(-1,12)).to(opt.device) # auto near/far for blender dataset if opt.data.dataset=="blender": idx_range = torch.arange(len(self.train_data),dtype=torch.long,device=opt.device) idx_X,idx_Y = torch.meshgrid(idx_range,idx_range) self.graph.idx_grid = torch.stack([idx_X,idx_Y],dim=-1).view(-1,2) if opt.error_map_size: self.graph.error_map = torch.ones([len(self.train_data), opt.error_map_size*opt.error_map_size], dtype=torch.float).to(opt.device) def setup_optimizer(self,opt): super().setup_optimizer(opt) optimizer = getattr(torch.optim,opt.optim.algo) self.optim_pose = optimizer([dict(params=self.graph.warp_embedding.parameters(),lr=opt.optim.lr_pose)]) self.optim_pose.add_param_group(dict(params=self.graph.warp_mlp.parameters(),lr=opt.optim.lr_pose)) # set up scheduler if opt.optim.sched_pose: scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched_pose.type) if opt.optim.lr_pose_end: assert(opt.optim.sched_pose.type=="ExponentialLR") opt.optim.sched_pose.gamma = (opt.optim.lr_pose_end/opt.optim.lr_pose)**(1./opt.max_iter) kwargs = { k:v for k,v in opt.optim.sched_pose.items() if k!="type" } self.sched_pose = scheduler(self.optim_pose,**kwargs) def train_iteration(self,opt,var,loader): self.optim_pose.zero_grad() if opt.optim.warmup_pose: # simple linear warmup of pose learning rate self.optim_pose.param_groups[0]["lr_orig"] = self.optim_pose.param_groups[0]["lr"] # cache the original learning rate self.optim_pose.param_groups[0]["lr"] *= min(1,self.it/opt.optim.warmup_pose) loss = super().train_iteration(opt,var,loader) self.optim_pose.step() if opt.optim.warmup_pose: self.optim_pose.param_groups[0]["lr"] = self.optim_pose.param_groups[0]["lr_orig"] # reset learning rate if opt.optim.sched_pose: self.sched_pose.step() self.graph.nerf.progress.data.fill_(self.it/opt.max_iter) if opt.nerf.fine_sampling: self.graph.nerf_fine.progress.data.fill_(self.it/opt.max_iter) return loss @torch.no_grad() def validate(self,opt,ep=None): pose,pose_GT = self.get_all_training_poses(opt) _,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT) super().validate(opt,ep=ep) @torch.no_grad() def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"): super().log_scalars(opt,var,loss,metric=metric,step=step,split=split) if split=="train": # log learning rate lr = 
self.optim_pose.param_groups[0]["lr"] self.tb.add_scalar("{0}/{1}".format(split,"lr_pose"),lr,step) # compute pose error if split=="train" and opt.data.dataset in ["blender","llff"]: pose,pose_GT = self.get_all_training_poses(opt) pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT) error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT) self.tb.add_scalar("{0}/error_R".format(split),error.R.mean(),step) self.tb.add_scalar("{0}/error_t".format(split),error.t.mean(),step) @torch.no_grad() def visualize(self,opt,var,step=0,split="train"): super().visualize(opt,var,step=step,split=split) if opt.visdom: if split=="val": pose,pose_GT = self.get_all_training_poses(opt) pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT) util_vis.vis_cameras(opt,self.vis,step=step,poses=[pose_aligned,pose_GT]) @torch.no_grad() def get_all_training_poses(self,opt): # get ground-truth (canonical) camera poses pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device) pose = self.graph.optimised_training_poses.weight.data.detach().clone().view(-1,3,4) return pose,pose_GT @torch.no_grad() def prealign_cameras(self,opt,pose,pose_GT): # compute 3D similarity transform via Procrustes analysis center = torch.zeros(1,1,3,device=opt.device) center_pred = camera.cam2world(center,pose)[:,0] # [N,3] center_GT = camera.cam2world(center,pose_GT)[:,0] # [N,3] try: sim3 = camera.procrustes_analysis(center_GT,center_pred) except: print("warning: SVD did not converge...") sim3 = edict(t0=0,t1=0,s0=1,s1=1,R=torch.eye(3,device=opt.device)) # align the camera poses center_aligned = (center_pred-sim3.t1)/sim3.s1@sim3.R.t()*sim3.s0+sim3.t0 R_aligned = pose[...,:3]@sim3.R.t() t_aligned = (-R_aligned@center_aligned[...,None])[...,0] pose_aligned = camera.pose(R=R_aligned,t=t_aligned) return pose_aligned,sim3 @torch.no_grad() def evaluate_camera_alignment(self,opt,pose_aligned,pose_GT): # measure errors in rotation and translation R_aligned,t_aligned = pose_aligned.split([3,1],dim=-1) R_GT,t_GT = pose_GT.split([3,1],dim=-1) R_error = camera.rotation_distance(R_aligned,R_GT) t_error = (t_aligned-t_GT)[...,0].norm(dim=-1) error = edict(R=R_error,t=t_error) return error @torch.no_grad() def evaluate_full(self,opt): self.graph.eval() # evaluate rotation/translation pose,pose_GT = self.get_all_training_poses(opt) pose_aligned,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT) error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT) print("--------------------------") print("rot: {:8.3f}".format(np.rad2deg(error.R.mean().cpu()))) print("trans: {:10.5f}".format(error.t.mean())) print("--------------------------") # dump numbers quant_fname = "{}/quant_pose.txt".format(opt.output_path) with open(quant_fname,"w") as file: for i,(err_R,err_t) in enumerate(zip(error.R,error.t)): file.write("{} {} {}\n".format(i,err_R.item(),err_t.item())) # evaluate novel view synthesis super().evaluate_full(opt) @torch.enable_grad() def evaluate_test_time_photometric_optim(self,opt,var): # use another se3 Parameter to absorb the remaining pose errors var.se3_refine_test = torch.nn.Parameter(torch.zeros(1,6,device=opt.device)) optimizer = getattr(torch.optim,opt.optim.algo) optim_pose = optimizer([dict(params=[var.se3_refine_test],lr=opt.optim.lr_pose)]) iterator = tqdm.trange(opt.optim.test_iter,desc="test-time optim.",leave=False,position=1) for it in iterator: optim_pose.zero_grad() var.pose_refine_test = camera.lie.se3_to_SE3(var.se3_refine_test) var = self.graph.forward(opt,var,mode="test-optim") loss = 
self.graph.compute_loss(opt,var,mode="test-optim") loss = self.summarize_loss(opt,var,loss) loss.all.backward() optim_pose.step() iterator.set_postfix(loss="{:.3f}".format(loss.all)) return var @torch.no_grad() def generate_videos_pose(self,opt): self.graph.eval() fig = plt.figure(figsize=(10,10) if opt.data.dataset=="blender" else (16,8)) cam_path = "{}/poses".format(opt.output_path) os.makedirs(cam_path,exist_ok=True) ep_list = [] for ep in range(0,opt.max_iter+1,opt.freq.ckpt): # load checkpoint (0 is random init) if ep!=0: try: util.restore_checkpoint(opt,self,resume=ep) except: continue # get the camera poses pose,pose_ref = self.get_all_training_poses(opt) if opt.data.dataset in ["blender","llff"]: pose_aligned,_ = self.prealign_cameras(opt,pose,pose_ref) pose_aligned,pose_ref = pose_aligned.detach().cpu(),pose_ref.detach().cpu() dict( blender=util_vis.plot_save_poses_blender, llff=util_vis.plot_save_poses, )[opt.data.dataset](opt,fig,pose_aligned,pose_ref=pose_ref,path=cam_path,ep=ep) else: pose = pose.detach().cpu() util_vis.plot_save_poses(opt,fig,pose,pose_ref=None,path=cam_path,ep=ep) ep_list.append(ep) plt.close() # write videos print("writing videos...") list_fname = "{}/temp.list".format(cam_path) with open(list_fname,"w") as file: for ep in ep_list: file.write("file {}.png\n".format(ep)) cam_vid_fname = "{}/poses.mp4".format(opt.output_path) os.system("ffmpeg -y -r 4 -f concat -i {0} -pix_fmt yuv420p {1} >/dev/null 2>&1".format(list_fname,cam_vid_fname)) os.remove(list_fname) # ============================ computation graph for forward/backprop ============================ class Graph(nerf.Graph): def __init__(self,opt): super().__init__(opt) self.nerf = NeRF(opt) if opt.nerf.fine_sampling: self.nerf_fine = NeRF(opt) self.pose_eye = torch.eye(3,4).to(opt.device) def get_pose(self,opt,var,mode=None): if mode=="train": # add the pre-generated pose perturbations if opt.data.dataset=="blender": if opt.camera.noise: var.pose_noise = self.pose_noise[var.idx] pose = camera.pose.compose([var.pose, var.pose_noise]) else: pose = var.pose else: pose = self.pose_eye[None] # add learnable pose correction batch_size = len(var.idx) if opt.error_map_size: num_points = var.ray_idx.shape[1] camera_cords_grid_3D = camera.gather_camera_cords_grid_3D(opt,batch_size,intr=var.intr,ray_idx=var.ray_idx).detach() else: num_points = len(var.ray_idx) camera_cords_grid_3D = camera.get_camera_cords_grid_3D(opt,batch_size,intr=var.intr,ray_idx=var.ray_idx).detach() camera_cords_grid_2D = camera_cords_grid_3D[...,:2] embedding = self.warp_embedding.weight[var.idx,None,:].expand(-1,num_points,-1) local_se3_refine = self.warp_mlp(opt,torch.cat((camera_cords_grid_2D,embedding),dim=-1)) local_pose_refine = camera.lie.se3_to_SE3(local_se3_refine) local_pose = camera.pose.compose([local_pose_refine, pose[:,None,...]]) return local_pose elif mode in ["val","eval","test-optim"]: # align test pose to refined coordinate system (up to sim3) sim3 = self.sim3 center = torch.zeros(1,1,3,device=opt.device) center = camera.cam2world(center,var.pose)[:,0] # [N,3] center_aligned = (center-sim3.t0)/sim3.s0@sim3.R*sim3.s1+sim3.t1 R_aligned = var.pose[...,:3]@self.sim3.R t_aligned = (-R_aligned@center_aligned[...,None])[...,0] pose = camera.pose(R=R_aligned,t=t_aligned) if opt.optim.test_photo and mode!="val": pose = camera.pose.compose([var.pose_refine_test, pose]) else: pose = var.pose return pose def forward(self,opt,var,mode=None): # rescale the size of the scene condition on the pose if opt.data.dataset=="blender": 
depth_min,depth_max = opt.nerf.depth.range position = camera.Pose().invert(self.optimised_training_poses.weight.data.detach().clone().view(-1,3,4))[...,-1] diameter = ((position[self.idx_grid[...,0]]-position[self.idx_grid[...,1]]).norm(dim=-1)).max() depth_min_new = (depth_min/(depth_max+depth_min))*diameter depth_max_new = (depth_max/(depth_max+depth_min))*diameter opt.nerf.depth.range = [depth_min_new, depth_max_new] # render images batch_size = len(var.idx) if opt.nerf.rand_rays and mode in ["train"]: # sample rays for optimization if opt.error_map_size: sample_weight = self.error_map + 2*self.error_map.mean(-1,keepdim=True) # 1/3 importance + 2/3 random var.ray_idx_coarse = torch.multinomial(sample_weight, opt.nerf.rand_rays//batch_size, replacement=False) # [B, N], but in [0, opt.error_map_size*opt.error_map_size) inds_x, inds_y = var.ray_idx_coarse // opt.error_map_size, var.ray_idx_coarse % opt.error_map_size # `//` will throw a warning in torch 1.10... anyway. sx, sy = opt.H / opt.error_map_size, opt.W / opt.error_map_size inds_x = (inds_x * sx + torch.rand(batch_size, opt.nerf.rand_rays//batch_size, device=opt.device) * sx).long().clamp(max=opt.H - 1) inds_y = (inds_y * sy + torch.rand(batch_size, opt.nerf.rand_rays//batch_size, device=opt.device) * sy).long().clamp(max=opt.W - 1) var.ray_idx = inds_x * opt.W + inds_y else: var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size]# 3/3 random local_pose = self.get_pose(opt,var,mode=mode) ret = self.local_render(opt,local_pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1] elif opt.nerf.rand_rays and mode in ["test-optim"]: # sample random rays for optimization var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size] pose = self.get_pose(opt,var,mode=mode) ret = self.render(opt,pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1] else: # render full image (process in slices) pose = self.get_pose(opt,var,mode=mode) ret = self.render_by_slices(opt,pose,intr=var.intr,mode=mode) if opt.nerf.rand_rays else \ self.render(opt,pose,intr=var.intr,mode=mode) # [B,HW,3],[B,HW,1] var.update(ret) return var def compute_loss(self,opt,var,mode=None): loss = edict() batch_size = len(var.idx) image = var.image.view(batch_size,3,opt.H*opt.W).permute(0,2,1) if opt.nerf.rand_rays and mode in ["train","test-optim"]: if opt.error_map_size: image = torch.gather(image, 1, var.ray_idx[...,None].expand(-1,-1,3)) else: image = image[:,var.ray_idx] # compute image losses if opt.loss_weight.render is not None: render_error = ((var.rgb-image)**2).mean(-1) loss.render = render_error.mean() if mode in ["train"] and opt.error_map_size: ema_error = 0.1 * torch.gather(self.error_map, 1, var.ray_idx_coarse) + 0.9 * render_error.detach() self.error_map.scatter_(1, var.ray_idx_coarse, ema_error) if opt.loss_weight.render_fine is not None: assert(opt.nerf.fine_sampling) loss.render_fine = self.MSE_loss(var.rgb_fine,image) # global alignment if mode in ["train"]: source = torch.cat((var.camera_grid_3D,var.camera_center),dim=1) target = torch.cat((var.grid_3D,var.center),dim=1) R_global, t_global = roma.rigid_points_registration(target, source) svd_poses = torch.cat((R_global,t_global[...,None]),-1) self.optimised_training_poses.weight.data = svd_poses.detach().clone().view(-1,12) if opt.loss_weight.global_alignment is not None: loss.global_alignment = self.MSE_loss(target,camera.cam2world(source,svd_poses)) return loss def 
local_render(self,opt,local_pose,intr=None,ray_idx=None,mode=None): batch_size = len(local_pose) if opt.error_map_size: camera_grid_3D = camera.gather_camera_cords_grid_3D(opt,batch_size,intr=intr,ray_idx=ray_idx).detach() else: camera_grid_3D = camera.get_camera_cords_grid_3D(opt,batch_size,intr=intr,ray_idx=ray_idx).detach() camera_center = torch.zeros_like(camera_grid_3D) # [B,HW,3] grid_3D = camera.cam2world(camera_grid_3D[...,None,:],local_pose)[...,0,:] # [B,HW,3] center = camera.cam2world(camera_center[...,None,:],local_pose)[...,0,:] # [B,HW,3] ray = grid_3D-center # [B,HW,3] ret = edict(camera_grid_3D=camera_grid_3D, camera_center=camera_center, grid_3D=grid_3D, center=center) # [B,HW,K] use for global alignment if opt.camera.ndc: # convert center/ray representations to NDC center,ray = camera.convert_NDC(opt,center,ray,intr=intr) # render with main MLP depth_samples = self.sample_depth(opt,batch_size,num_rays=ray.shape[1]) # [B,HW,N,1] rgb_samples,density_samples = self.nerf.forward_samples(opt,center,ray,depth_samples,mode=mode) rgb,depth,opacity,prob = self.nerf.composite(opt,ray,rgb_samples,density_samples,depth_samples) ret.update(rgb=rgb,depth=depth,opacity=opacity) # [B,HW,K] # render with fine MLP from coarse MLP if opt.nerf.fine_sampling: with torch.no_grad(): # resample depth acoording to coarse empirical distribution depth_samples_fine = self.sample_depth_from_pdf(opt,pdf=prob[...,0]) # [B,HW,Nf,1] depth_samples = torch.cat([depth_samples,depth_samples_fine],dim=2) # [B,HW,N+Nf,1] depth_samples = depth_samples.sort(dim=2).values rgb_samples,density_samples = self.nerf_fine.forward_samples(opt,center,ray,depth_samples,mode=mode) rgb_fine,depth_fine,opacity_fine,_ = self.nerf_fine.composite(opt,ray,rgb_samples,density_samples,depth_samples) ret.update(rgb_fine=rgb_fine,depth_fine=depth_fine,opacity_fine=opacity_fine) # [B,HW,K] return ret class NeRF(nerf.NeRF): def __init__(self,opt): super().__init__(opt) self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed def positional_encoding(self,opt,input,L): # [B,...,N] input_enc = super().positional_encoding(opt,input,L=L) # [B,...,2NL] # coarse-to-fine: smoothly mask positional encoding for BARF if opt.barf_c2f is not None: # set weights for different frequency bands start,end = opt.barf_c2f alpha = (self.progress.data-start)/(end-start)*L k = torch.arange(L,dtype=torch.float32,device=opt.device) weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2 # apply weights shape = input_enc.shape input_enc = (input_enc.view(-1,L)*weight).view(*shape) return input_enc class localWarp(torch.nn.Module): def __init__(self, opt): super().__init__() # point-wise se3 prediction input_2D_dim = 2 self.mlp_warp = torch.nn.ModuleList() L = util.get_layer_dims(opt.arch.layers_warp) for li,(k_in,k_out) in enumerate(L): if li==0: k_in = input_2D_dim+opt.arch.embedding_dim if li in opt.arch.skip_warp: k_in += input_2D_dim+opt.arch.embedding_dim linear = torch.nn.Linear(k_in,k_out) self.mlp_warp.append(linear) def forward(self,opt,uvf): feat = uvf for li,layer in enumerate(self.mlp_warp): if li in opt.arch.skip_warp: feat = torch.cat([feat,uvf],dim=-1) feat = layer(feat) if li!=len(self.mlp_warp)-1: feat = torch_F.relu(feat) warp = feat return warp ================================================ FILE: model/l2g_planar.py ================================================ import numpy as np import os,sys,time import torch import torch.nn.functional as torch_F import torchvision import 
torchvision.transforms.functional as torchvision_F import tqdm from easydict import EasyDict as edict import PIL import PIL.Image,PIL.ImageDraw import imageio import util,util_vis from util import log,debug from . import base import warp import roma from kornia.geometry.homography import find_homography_dlt # ============================ main engine for training and evaluation ============================ class Model(base.Model): def __init__(self,opt): super().__init__(opt) opt.H_crop,opt.W_crop = opt.data.patch_crop def load_dataset(self,opt,eval_split=None): image_raw = PIL.Image.open(opt.data.image_fname).convert('RGB') self.image_raw = torchvision_F.to_tensor(image_raw).to(opt.device) def build_networks(self,opt): super().build_networks(opt) self.graph.warp_embedding = torch.nn.Embedding(opt.batch_size,opt.arch.embedding_dim).to(opt.device) self.graph.warp_mlp = localWarp(opt).to(opt.device) def setup_optimizer(self,opt): log.info("setting up optimizers...") optim_list = [ dict(params=self.graph.neural_image.parameters(),lr=opt.optim.lr), ] optimizer = getattr(torch.optim,opt.optim.algo) self.optim = optimizer(optim_list) # set up scheduler if opt.optim.sched: scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type) if opt.optim.lr_end: assert(opt.optim.sched.type=="ExponentialLR") opt.optim.sched.gamma = (opt.optim.lr_end/opt.optim.lr)**(1./opt.max_iter) kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" } self.sched = scheduler(self.optim,**kwargs) # warp self.optim_warp = optimizer([dict(params=self.graph.warp_embedding.parameters(),lr=opt.optim.lr_warp)]) self.optim_warp.add_param_group(dict(params=self.graph.warp_mlp.parameters(),lr=opt.optim.lr_warp)) # set up scheduler if opt.optim.sched_warp: scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched_warp.type) if opt.optim.lr_warp_end: assert(opt.optim.sched_warp.type=="ExponentialLR") opt.optim.sched_warp.gamma = (opt.optim.lr_warp_end/opt.optim.lr_warp)**(1./opt.max_iter) kwargs = { k:v for k,v in opt.optim.sched_warp.items() if k!="type" } self.sched_warp = scheduler(self.optim_warp,**kwargs) def setup_visualizer(self,opt): super().setup_visualizer(opt) # set colors for visualization box_colors = ["#ff0000","#40afff","#9314ff","#ffd700","#00ff00"] box_colors = list(map(util.colorcode_to_number,box_colors)) self.box_colors = np.array(box_colors).astype(int) assert(len(self.box_colors)==opt.batch_size) # create visualization directory self.vis_path = "{}/vis".format(opt.output_path) os.makedirs(self.vis_path,exist_ok=True) self.video_fname = "{}/vis.mp4".format(opt.output_path) def train(self,opt): # before training log.title("TRAINING START") self.timer = edict(start=time.time(),it_mean=None) self.ep = self.it = self.vis_it = 0 self.graph.train() var = edict(idx=torch.arange(opt.batch_size)) # pre-generate perturbations self.homo_pert, self.rot_pert, self.trans_pert, var.image_pert = self.generate_warp_perturbation(opt) # train var = util.move_to_device(var,opt.device) loader = tqdm.trange(opt.max_iter,desc="training",leave=False) # visualize initial state var = self.graph.forward(opt,var) self.visualize(opt,var,step=0) for it in loader: # train iteration loss = self.train_iteration(opt,var,loader) # after training os.system("ffmpeg -y -framerate 30 -i {}/%d.png -pix_fmt yuv420p {}".format(self.vis_path,self.video_fname)) self.save_checkpoint(opt,ep=None,it=self.it) if opt.tb: self.tb.flush() self.tb.close() if opt.visdom: self.vis.close() log.title("TRAINING DONE") def 
train_iteration(self,opt,var,loader): self.optim_warp.zero_grad() loss = super().train_iteration(opt,var,loader) self.graph.neural_image.progress.data.fill_(self.it/opt.max_iter) self.optim_warp.step() if opt.optim.sched_warp: self.sched_warp.step() if opt.optim.sched: self.sched.step() return loss def generate_warp_perturbation(self,opt): # pre-generate perturbations (translational noise + homography noise) def create_random_perturbation(batch_size): if opt.warp.dof==1: homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r trans_pert = torch.zeros(batch_size,2,device=opt.device)*opt.warp.noise_t elif opt.warp.dof==2: homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h rot_pert = torch.zeros(batch_size,1,device=opt.device)*opt.warp.noise_r trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t elif opt.warp.dof==3: homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t elif opt.warp.dof==8: homo_pert = torch.randn(batch_size,8,device=opt.device)*opt.warp.noise_h homo_pert[:,:2]=0 rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t else: assert(False) return homo_pert, rot_pert, trans_pert homo_pert = torch.zeros(opt.batch_size,8,device=opt.device) rot_pert = torch.zeros(opt.batch_size,1,device=opt.device) trans_pert = torch.zeros(opt.batch_size,2,device=opt.device) for i in range(opt.batch_size): homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1) while not warp.check_corners_in_range_compose(opt, homo_pert_i, rot_pert_i, trans_pert_i): homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1) homo_pert[i], rot_pert[i], trans_pert[i] = homo_pert_i, rot_pert_i, trans_pert_i if opt.warp.fix_first: homo_pert[0],rot_pert[0],trans_pert[0] = 0,0,0 # create warped image patches xy_grid = warp.get_normalized_pixel_grid_crop(opt) # [B,HW,2] xy_grid_hom = warp.camera.to_hom(xy_grid) warp_matrix = warp.lie.sl3_to_SL3(homo_pert) warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1) xy_grid_warped = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2] warp_matrix = warp.lie.so2_to_SO2(rot_pert) xy_grid_warped = xy_grid_warped@warp_matrix.transpose(-2,-1) # [B,HW,2] xy_grid_warped = xy_grid_warped+trans_pert[...,None,:] xy_grid_warped = xy_grid_warped.view([opt.batch_size,opt.H_crop,opt.W_crop,2]) xy_grid_warped = torch.stack([xy_grid_warped[...,0]*max(opt.H,opt.W)/opt.W, xy_grid_warped[...,1]*max(opt.H,opt.W)/opt.H],dim=-1) image_raw_batch = self.image_raw.repeat(opt.batch_size,1,1,1) image_pert_all = torch_F.grid_sample(image_raw_batch,xy_grid_warped,align_corners=False) return homo_pert, rot_pert, trans_pert, image_pert_all def visualize_patches_compose(self,opt, homo_pert, rot_pert, trans_pert): image_pil = torchvision_F.to_pil_image(self.image_raw).convert("RGBA") draw_pil = PIL.Image.new("RGBA",image_pil.size,(0,0,0,0)) draw = PIL.ImageDraw.Draw(draw_pil) corners_all = warp.warp_corners_compose(opt, homo_pert, rot_pert, trans_pert) corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5 corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5 for i,corners in enumerate(corners_all): P = [tuple(float(n) for n in corners[j]) 
for j in range(4)] draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3) image_pil.alpha_composite(draw_pil) image_tensor = torchvision_F.to_tensor(image_pil.convert("RGB")) return image_tensor def visualize_patches_use_matrix(self,opt,warp_matrix): image_pil = torchvision_F.to_pil_image(self.image_raw).convert("RGBA") draw_pil = PIL.Image.new("RGBA",image_pil.size,(0,0,0,0)) draw = PIL.ImageDraw.Draw(draw_pil) corners_all = warp.warp_corners_use_matrix(opt,warp_matrix) corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5 corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5 for i,corners in enumerate(corners_all): P = [tuple(float(n) for n in corners[j]) for j in range(4)] draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3) image_pil.alpha_composite(draw_pil) image_tensor = torchvision_F.to_tensor(image_pil.convert("RGB")) return image_tensor @torch.no_grad() def predict_entire_image(self,opt): xy_grid = warp.get_normalized_pixel_grid(opt)[:1] rgb = self.graph.neural_image.forward(opt,xy_grid) # [B,HW,3] image = rgb.view(opt.H,opt.W,3).detach().cpu().permute(2,0,1) return image @torch.no_grad() def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"): super().log_scalars(opt,var,loss,metric=metric,step=step,split=split) # compute PSNR psnr = -10*loss.render.log10() self.tb.add_scalar("{0}/{1}".format(split,"PSNR"),psnr,step) # warp error pred_corners = warp.warp_corners_use_matrix(opt,self.graph.warp_matrix) pred_corners[...,0] = (pred_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5 pred_corners[...,1] = (pred_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5 gt_corners = warp.warp_corners_compose(opt, self.homo_pert, self.rot_pert, self.trans_pert) gt_corners[...,0] = (gt_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5 gt_corners[...,1] = (gt_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5 warp_error = (pred_corners-gt_corners).norm(dim=-1).mean() self.tb.add_scalar("{0}/{1}".format(split,"warp error"),warp_error,step) @torch.no_grad() def visualize(self,opt,var,step=0,split="train"): # dump frames for writing to video frame_GT = self.visualize_patches_compose(opt,self.homo_pert, self.rot_pert, self.trans_pert) frame = self.visualize_patches_use_matrix(opt,self.graph.warp_matrix) frame2 = self.predict_entire_image(opt) frame_cat = (torch.cat([frame,frame2],dim=1)*255).byte().permute(1,2,0).numpy() imageio.imsave("{}/{}.png".format(self.vis_path,self.vis_it),frame_cat) self.vis_it += 1 # visualize in Tensorboard if opt.tb: colors = self.box_colors util_vis.tb_image(opt,self.tb,step,split,"image_pert",util_vis.color_border(var.image_pert,colors)) util_vis.tb_image(opt,self.tb,step,split,"rgb_warped",util_vis.color_border(var.rgb_warped_map,colors)) util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes",frame[None]) util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes_GT",frame_GT[None]) util_vis.tb_image(opt,self.tb,self.it+1,"train","image_entire",frame2[None]) local_warp_field = torch.cat([(var.local_warp_field+1)/2,torch.zeros_like(var.local_warp_field[:,:1,...])],dim=1) # [B,3,H,W] global_warp_field = torch.cat([(var.global_warp_field+1)/2,torch.zeros_like(var.global_warp_field[:,:1,...])],dim=1) # [B,3,H,W] util_vis.tb_image(opt,self.tb,step,split,"warp_field_local",util_vis.color_border(local_warp_field,colors)) util_vis.tb_image(opt,self.tb,step,split,"warp_field_global",util_vis.color_border(global_warp_field,colors)) # 
============================ computation graph for forward/backprop ============================ class Graph(base.Graph): def __init__(self,opt): super().__init__(opt) self.neural_image = NeuralImageFunction(opt) def forward(self,opt,var,mode=None): # warp xy_grid = warp.get_normalized_pixel_grid_crop(opt) warp_embedding = self.warp_embedding.weight[:,None,:].expand(-1,xy_grid.shape[1],-1) local_warp_param = self.warp_mlp(opt,torch.cat((xy_grid,warp_embedding),dim=-1)) if opt.warp.fix_first: local_warp_param[0] = 0 if opt.warp.dof==1: local_warp_matrix = warp.lie.so2_to_SO2(local_warp_param) var.local_warped_grid = (xy_grid[...,None,:]@local_warp_matrix.transpose(-2,-1))[...,0,:] # [B,HW,2] self.warp_matrix = roma.rigid_vectors_registration(xy_grid, var.local_warped_grid) var.global_warped_grid = xy_grid@self.warp_matrix.transpose(-2,-1) # [B,HW,2] elif opt.warp.dof==2: var.local_warped_grid = xy_grid + local_warp_param self.warp_matrix = var.local_warped_grid.mean(-2)-xy_grid.mean(-2) var.global_warped_grid = xy_grid + self.warp_matrix[...,None,:] elif opt.warp.dof==3: xy_grid_hom = warp.camera.to_hom(xy_grid[...,None,:]) local_warp_matrix = warp.lie.se2_to_SE2(local_warp_param) var.local_warped_grid = (xy_grid_hom@local_warp_matrix.transpose(-2,-1))[...,0,:] # [B,HW,2] R_global, t_global = roma.rigid_points_registration(xy_grid, var.local_warped_grid) self.warp_matrix = torch.cat((R_global,t_global[...,None]),-1) xy_grid_hom = warp.camera.to_hom(xy_grid) var.global_warped_grid = xy_grid_hom@self.warp_matrix.transpose(-2,-1) # [B,HW,2] elif opt.warp.dof==8: xy_grid_hom = warp.camera.to_hom(xy_grid[...,None,:]) local_warp_matrix = warp.lie.se2_to_SE2(local_warp_param) var.local_warped_grid = (xy_grid_hom@local_warp_matrix.transpose(-2,-1))[...,0,:] # [B,HW,2] self.warp_matrix = find_homography_dlt(xy_grid, var.local_warped_grid) xy_grid_hom = warp.camera.to_hom(xy_grid) global_warped_grid = xy_grid_hom@self.warp_matrix.transpose(-2,-1) var.global_warped_grid = global_warped_grid[...,:2]/(global_warped_grid[...,2:]+1e-8) # [B,HW,2] # render images var.rgb_warped = self.neural_image.forward(opt,var.local_warped_grid) # [B,HW,3] var.rgb_warped_map = var.rgb_warped.view(opt.batch_size,opt.H_crop,opt.W_crop,3).permute(0,3,1,2) # [B,3,H,W] var.local_warp_field = var.local_warped_grid.view(opt.batch_size,opt.H_crop,opt.W_crop,2).permute(0,3,1,2) # [B,2,H,W] var.global_warp_field = var.global_warped_grid.view(opt.batch_size,opt.H_crop,opt.W_crop,2).permute(0,3,1,2) # [B,2,H,W] return var def compute_loss(self,opt,var,mode=None): loss = edict() if opt.loss_weight.render is not None: image_pert = var.image_pert.view(opt.batch_size,3,opt.H_crop*opt.W_crop).permute(0,2,1) loss.render = self.MSE_loss(var.rgb_warped,image_pert) if opt.loss_weight.global_alignment is not None: loss.global_alignment = self.MSE_loss(var.local_warped_grid,var.global_warped_grid) return loss class NeuralImageFunction(torch.nn.Module): def __init__(self,opt): super().__init__() self.define_network(opt) self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed def define_network(self,opt): input_2D_dim = 2+4*opt.arch.posenc.L_2D if opt.arch.posenc else 2 # point-wise RGB prediction self.mlp = torch.nn.ModuleList() L = util.get_layer_dims(opt.arch.layers) for li,(k_in,k_out) in enumerate(L): if li==0: k_in = input_2D_dim if li in opt.arch.skip: k_in += input_2D_dim linear = torch.nn.Linear(k_in,k_out) if opt.barf_c2f and li==0: # rescale first layer init (distribution was for pos.enc. 
but only xy is first used) scale = np.sqrt(input_2D_dim/2.) linear.weight.data *= scale linear.bias.data *= scale self.mlp.append(linear) def forward(self,opt,coord_2D): # [B,...,3] if opt.arch.posenc: points_enc = self.positional_encoding(opt,coord_2D,L=opt.arch.posenc.L_2D) points_enc = torch.cat([coord_2D,points_enc],dim=-1) # [B,...,6L+3] else: points_enc = coord_2D feat = points_enc # extract implicit features for li,layer in enumerate(self.mlp): if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1) feat = layer(feat) if li!=len(self.mlp)-1: feat = torch_F.relu(feat) rgb = feat.sigmoid_() # [B,...,3] return rgb def positional_encoding(self,opt,input,L): # [B,...,N] shape = input.shape freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L] spectrum = input[...,None]*freq # [B,...,N,L] sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L] input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L] input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL] # coarse-to-fine: smoothly mask positional encoding for BARF if opt.barf_c2f is not None: # set weights for different frequency bands start,end = opt.barf_c2f alpha = (self.progress.data-start)/(end-start)*L k = torch.arange(L,dtype=torch.float32,device=opt.device) weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2 # apply weights shape = input_enc.shape input_enc = (input_enc.view(-1,L)*weight).view(*shape) return input_enc class localWarp(torch.nn.Module): def __init__(self, opt): super().__init__() # point-wise se3 prediction input_2D_dim = 2 self.mlp_warp = torch.nn.ModuleList() L = util.get_layer_dims(opt.arch.layers_warp) for li,(k_in,k_out) in enumerate(L): if li==0: k_in = input_2D_dim+opt.arch.embedding_dim if li in opt.arch.skip_warp: k_in += input_2D_dim+opt.arch.embedding_dim linear = torch.nn.Linear(k_in,k_out) self.mlp_warp.append(linear) def forward(self,opt,uvf): feat = uvf for li,layer in enumerate(self.mlp_warp): if li in opt.arch.skip_warp: feat = torch.cat([feat,uvf],dim=-1) feat = layer(feat) if li!=len(self.mlp_warp)-1: feat = torch_F.relu(feat) warp = feat return warp ================================================ FILE: model/nerf.py ================================================ import numpy as np import os,sys,time import torch import torch.nn.functional as torch_F import torchvision import torchvision.transforms.functional as torchvision_F import tqdm from easydict import EasyDict as edict import lpips from external.pohsun_ssim import pytorch_ssim import util,util_vis from util import log,debug from . 
import base import camera # ============================ main engine for training and evaluation ============================ class Model(base.Model): def __init__(self,opt): super().__init__(opt) self.lpips_loss = lpips.LPIPS(net="alex").to(opt.device) def load_dataset(self,opt,eval_split="val"): super().load_dataset(opt,eval_split=eval_split) # prefetch all training data self.train_data.prefetch_all_data(opt) self.train_data.all = edict(util.move_to_device(self.train_data.all,opt.device)) def setup_optimizer(self,opt): log.info("setting up optimizers...") optimizer = getattr(torch.optim,opt.optim.algo) self.optim = optimizer([dict(params=self.graph.nerf.parameters(),lr=opt.optim.lr)]) if opt.nerf.fine_sampling: self.optim.add_param_group(dict(params=self.graph.nerf_fine.parameters(),lr=opt.optim.lr)) # set up scheduler if opt.optim.sched: scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type) if opt.optim.lr_end: assert(opt.optim.sched.type=="ExponentialLR") opt.optim.sched.gamma = (opt.optim.lr_end/opt.optim.lr)**(1./opt.max_iter) kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" } self.sched = scheduler(self.optim,**kwargs) def train(self,opt): # before training log.title("TRAINING START") self.timer = edict(start=time.time(),it_mean=None) self.graph.train() self.ep = 0 # dummy for timer # training if self.iter_start==0: self.validate(opt,0) loader = tqdm.trange(opt.max_iter,desc="training",leave=False) for self.it in loader: if self.it/dev/null 2>&1".format(test_path,rgb_vid_fname)) os.system("ffmpeg -y -framerate 30 -i {0}/depth_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(test_path,depth_vid_fname)) else: pose_pred,pose_GT = self.get_all_training_poses(opt) poses = pose_pred if opt.model in ["barf","l2g_nerf"] else pose_GT if opt.model in ["barf","l2g_nerf"] and opt.data.dataset=="llff": _,sim3 = self.prealign_cameras(opt,pose_pred,pose_GT) scale = sim3.s1/sim3.s0 else: scale = 1 # rotate novel views around the "center" camera of all poses idx_center = (poses-poses.mean(dim=0,keepdim=True))[...,3].norm(dim=-1).argmin() pose_novel = camera.get_novel_view_poses(opt,poses[idx_center],N=60,scale=scale).to(opt.device) # render the novel views novel_path = "{}/novel_view".format(opt.output_path) os.makedirs(novel_path,exist_ok=True) pose_novel_tqdm = tqdm.tqdm(pose_novel,desc="rendering novel views",leave=False) intr = edict(next(iter(self.test_loader))).intr[:1].to(opt.device) # grab intrinsics for i,pose in enumerate(pose_novel_tqdm): ret = self.graph.render_by_slices(opt,pose[None],intr=intr) if opt.nerf.rand_rays else \ self.graph.render(opt,pose[None],intr=intr) depth=ret.depth/scale if opt.data.dataset=="blender": depth = torch.clip(depth - 1.5, min=1) invdepth = (1-depth)/ret.opacity if opt.camera.ndc else 1/(depth/ret.opacity+eps) invdepth = util_vis.apply_depth_colormap(invdepth, cmap="inferno") rgb_map = ret.rgb.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W] invdepth_map = invdepth.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,1,H,W] torchvision_F.to_pil_image(rgb_map.cpu()[0]).save("{}/rgb_{}.png".format(novel_path,i)) torchvision_F.to_pil_image(invdepth_map.cpu()[0]).save("{}/depth_{}.png".format(novel_path,i)) # write videos print("writing videos...") rgb_vid_fname = "{}/novel_view_rgb.mp4".format(opt.output_path) depth_vid_fname = "{}/novel_view_depth.mp4".format(opt.output_path) os.system("ffmpeg -y -framerate 30 -i {0}/rgb_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(novel_path,rgb_vid_fname)) os.system("ffmpeg -y -framerate 30 -i 
{0}/depth_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(novel_path,depth_vid_fname)) # ============================ computation graph for forward/backprop ============================ class Graph(base.Graph): def __init__(self,opt): super().__init__(opt) self.nerf = NeRF(opt) if opt.nerf.fine_sampling: self.nerf_fine = NeRF(opt) def forward(self,opt,var,mode=None): batch_size = len(var.idx) pose = self.get_pose(opt,var,mode=mode) # render images if opt.nerf.rand_rays and mode in ["train","test-optim"]: # sample random rays for optimization var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size] ret = self.render(opt,pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1] else: # render full image (process in slices) ret = self.render_by_slices(opt,pose,intr=var.intr,mode=mode) if opt.nerf.rand_rays else \ self.render(opt,pose,intr=var.intr,mode=mode) # [B,HW,3],[B,HW,1] var.update(ret) return var def compute_loss(self,opt,var,mode=None): loss = edict() batch_size = len(var.idx) image = var.image.view(batch_size,3,opt.H*opt.W).permute(0,2,1) if opt.nerf.rand_rays and mode in ["train","test-optim"]: image = image[:,var.ray_idx] # compute image losses if opt.loss_weight.render is not None: loss.render = self.MSE_loss(var.rgb,image) if opt.loss_weight.render_fine is not None: assert(opt.nerf.fine_sampling) loss.render_fine = self.MSE_loss(var.rgb_fine,image) return loss def get_pose(self,opt,var,mode=None): return var.pose def render(self,opt,pose,intr=None,ray_idx=None,mode=None): batch_size = len(pose) center,ray = camera.get_center_and_ray(opt,pose,intr=intr) # [B,HW,3] while ray.isnan().any(): # TODO: weird bug, ray becomes NaN arbitrarily if batch_size>1, not deterministic reproducible center,ray = camera.get_center_and_ray(opt,pose,intr=intr) # [B,HW,3] if ray_idx is not None: # consider only subset of rays center,ray = center[:,ray_idx],ray[:,ray_idx] if opt.camera.ndc: # convert center/ray representations to NDC center,ray = camera.convert_NDC(opt,center,ray,intr=intr) # render with main MLP depth_samples = self.sample_depth(opt,batch_size,num_rays=ray.shape[1]) # [B,HW,N,1] rgb_samples,density_samples = self.nerf.forward_samples(opt,center,ray,depth_samples,mode=mode) rgb,depth,opacity,prob = self.nerf.composite(opt,ray,rgb_samples,density_samples,depth_samples) ret = edict(rgb=rgb,depth=depth,opacity=opacity) # [B,HW,K] # render with fine MLP from coarse MLP if opt.nerf.fine_sampling: with torch.no_grad(): # resample depth acoording to coarse empirical distribution depth_samples_fine = self.sample_depth_from_pdf(opt,pdf=prob[...,0]) # [B,HW,Nf,1] depth_samples = torch.cat([depth_samples,depth_samples_fine],dim=2) # [B,HW,N+Nf,1] depth_samples = depth_samples.sort(dim=2).values rgb_samples,density_samples = self.nerf_fine.forward_samples(opt,center,ray,depth_samples,mode=mode) rgb_fine,depth_fine,opacity_fine,_ = self.nerf_fine.composite(opt,ray,rgb_samples,density_samples,depth_samples) ret.update(rgb_fine=rgb_fine,depth_fine=depth_fine,opacity_fine=opacity_fine) # [B,HW,K] return ret def render_by_slices(self,opt,pose,intr=None,mode=None): ret_all = edict(rgb=[],depth=[],opacity=[]) if opt.nerf.fine_sampling: ret_all.update(rgb_fine=[],depth_fine=[],opacity_fine=[]) # render the image by slices for memory considerations for c in range(0,opt.H*opt.W,opt.nerf.rand_rays): ray_idx = torch.arange(c,min(c+opt.nerf.rand_rays,opt.H*opt.W),device=opt.device) ret = self.render(opt,pose,intr=intr,ray_idx=ray_idx,mode=mode) # 
[B,R,3],[B,R,1] for k in ret: ret_all[k].append(ret[k]) # group all slices of images for k in ret_all: ret_all[k] = torch.cat(ret_all[k],dim=1) return ret_all def sample_depth(self,opt,batch_size,num_rays=None): depth_min,depth_max = opt.nerf.depth.range num_rays = num_rays or opt.H*opt.W rand_samples = torch.rand(batch_size,num_rays,opt.nerf.sample_intvs,1,device=opt.device) if opt.nerf.sample_stratified else 0.5 rand_samples += torch.arange(opt.nerf.sample_intvs,device=opt.device)[None,None,:,None].float() # [B,HW,N,1] depth_samples = rand_samples/opt.nerf.sample_intvs*(depth_max-depth_min)+depth_min # [B,HW,N,1] depth_samples = dict( metric=depth_samples, inverse=1/(depth_samples+1e-8), )[opt.nerf.depth.param] return depth_samples def sample_depth_from_pdf(self,opt,pdf): depth_min,depth_max = opt.nerf.depth.range # get CDF from PDF (along last dimension) cdf = pdf.cumsum(dim=-1) # [B,HW,N] cdf = torch.cat([torch.zeros_like(cdf[...,:1]),cdf],dim=-1) # [B,HW,N+1] # take uniform samples grid = torch.linspace(0,1,opt.nerf.sample_intvs_fine+1,device=opt.device) # [Nf+1] unif = 0.5*(grid[:-1]+grid[1:]).repeat(*cdf.shape[:-1],1) # [B,HW,Nf] idx = torch.searchsorted(cdf,unif,right=True) # [B,HW,Nf] \in {1...N} # inverse transform sampling from CDF depth_bin = torch.linspace(depth_min,depth_max,opt.nerf.sample_intvs+1,device=opt.device) # [N+1] depth_bin = depth_bin.repeat(*cdf.shape[:-1],1) # [B,HW,N+1] depth_low = depth_bin.gather(dim=2,index=(idx-1).clamp(min=0)) # [B,HW,Nf] depth_high = depth_bin.gather(dim=2,index=idx.clamp(max=opt.nerf.sample_intvs)) # [B,HW,Nf] cdf_low = cdf.gather(dim=2,index=(idx-1).clamp(min=0)) # [B,HW,Nf] cdf_high = cdf.gather(dim=2,index=idx.clamp(max=opt.nerf.sample_intvs)) # [B,HW,Nf] # linear interpolation t = (unif-cdf_low)/(cdf_high-cdf_low+1e-8) # [B,HW,Nf] depth_samples = depth_low+t*(depth_high-depth_low) # [B,HW,Nf] return depth_samples[...,None] # [B,HW,Nf,1] class NeRF(torch.nn.Module): def __init__(self,opt): super().__init__() self.define_network(opt) def define_network(self,opt): input_3D_dim = 3+6*opt.arch.posenc.L_3D if opt.arch.posenc else 3 if opt.nerf.view_dep: input_view_dim = 3+6*opt.arch.posenc.L_view if opt.arch.posenc else 3 # point-wise feature self.mlp_feat = torch.nn.ModuleList() L = util.get_layer_dims(opt.arch.layers_feat) for li,(k_in,k_out) in enumerate(L): if li==0: k_in = input_3D_dim if li in opt.arch.skip: k_in += input_3D_dim if li==len(L)-1: k_out += 1 linear = torch.nn.Linear(k_in,k_out) if opt.arch.tf_init: self.tensorflow_init_weights(opt,linear,out="first" if li==len(L)-1 else None) self.mlp_feat.append(linear) # RGB prediction self.mlp_rgb = torch.nn.ModuleList() L = util.get_layer_dims(opt.arch.layers_rgb) feat_dim = opt.arch.layers_feat[-1] for li,(k_in,k_out) in enumerate(L): if li==0: k_in = feat_dim+(input_view_dim if opt.nerf.view_dep else 0) linear = torch.nn.Linear(k_in,k_out) if opt.arch.tf_init: self.tensorflow_init_weights(opt,linear,out="all" if li==len(L)-1 else None) self.mlp_rgb.append(linear) def tensorflow_init_weights(self,opt,linear,out=None): # use Xavier init instead of Kaiming init relu_gain = torch.nn.init.calculate_gain("relu") # sqrt(2) if out=="all": torch.nn.init.xavier_uniform_(linear.weight) elif out=="first": torch.nn.init.xavier_uniform_(linear.weight[:1]) torch.nn.init.xavier_uniform_(linear.weight[1:],gain=relu_gain) else: torch.nn.init.xavier_uniform_(linear.weight,gain=relu_gain) torch.nn.init.zeros_(linear.bias) def forward(self,opt,points_3D,ray_unit=None,mode=None): # [B,...,3] if 
opt.arch.posenc: points_enc = self.positional_encoding(opt,points_3D,L=opt.arch.posenc.L_3D) points_enc = torch.cat([points_3D,points_enc],dim=-1) # [B,...,6L+3] else: points_enc = points_3D feat = points_enc # extract coordinate-based features for li,layer in enumerate(self.mlp_feat): if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1) feat = layer(feat) if li==len(self.mlp_feat)-1: density = feat[...,0] if opt.nerf.density_noise_reg and mode=="train": density += torch.randn_like(density)*opt.nerf.density_noise_reg density_activ = getattr(torch_F,opt.arch.density_activ) # relu_,abs_,sigmoid_,exp_.... density = density_activ(density) feat = feat[...,1:] feat = torch_F.relu(feat) # predict RGB values if opt.nerf.view_dep: assert(ray_unit is not None) if opt.arch.posenc: ray_enc = self.positional_encoding(opt,ray_unit,L=opt.arch.posenc.L_view) ray_enc = torch.cat([ray_unit,ray_enc],dim=-1) # [B,...,6L+3] else: ray_enc = ray_unit feat = torch.cat([feat,ray_enc],dim=-1) for li,layer in enumerate(self.mlp_rgb): feat = layer(feat) if li!=len(self.mlp_rgb)-1: feat = torch_F.relu(feat) rgb = feat.sigmoid_() # [B,...,3] return rgb,density def forward_samples(self,opt,center,ray,depth_samples,mode=None): points_3D_samples = camera.get_3D_points_from_depth(opt,center,ray,depth_samples,multi_samples=True) # [B,HW,N,3] if opt.nerf.view_dep: ray_unit = torch_F.normalize(ray,dim=-1) # [B,HW,3] ray_unit_samples = ray_unit[...,None,:].expand_as(points_3D_samples) # [B,HW,N,3] else: ray_unit_samples = None rgb_samples,density_samples = self.forward(opt,points_3D_samples,ray_unit=ray_unit_samples,mode=mode) # [B,HW,N],[B,HW,N,3] return rgb_samples,density_samples def composite(self,opt,ray,rgb_samples,density_samples,depth_samples): ray_length = ray.norm(dim=-1,keepdim=True) # [B,HW,1] # volume rendering: compute probability (using quadrature) depth_intv_samples = depth_samples[...,1:,0]-depth_samples[...,:-1,0] # [B,HW,N-1] depth_intv_samples = torch.cat([depth_intv_samples,torch.empty_like(depth_intv_samples[...,:1]).fill_(1e10)],dim=2) # [B,HW,N] dist_samples = depth_intv_samples*ray_length # [B,HW,N] sigma_delta = density_samples*dist_samples # [B,HW,N] alpha = 1-(-sigma_delta).exp_() # [B,HW,N] T = (-torch.cat([torch.zeros_like(sigma_delta[...,:1]),sigma_delta[...,:-1]],dim=2).cumsum(dim=2)).exp_() # [B,HW,N] prob = (T*alpha)[...,None] # [B,HW,N,1] # integrate RGB and depth weighted by probability depth = (depth_samples*prob).sum(dim=2) # [B,HW,1] rgb = (rgb_samples*prob).sum(dim=2) # [B,HW,3] opacity = prob.sum(dim=2) # [B,HW,1] if opt.nerf.setbg_opaque: rgb = rgb+opt.data.bgcolor*(1-opacity) return rgb,depth,opacity,prob # [B,HW,K] def positional_encoding(self,opt,input,L): # [B,...,N] shape = input.shape freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L] spectrum = input[...,None]*freq # [B,...,N,L] sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L] input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L] input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL] return input_enc ================================================ FILE: model/planar.py ================================================ import numpy as np import os,sys,time import torch import torch.nn.functional as torch_F import torchvision import torchvision.transforms.functional as torchvision_F import tqdm from easydict import EasyDict as edict import PIL import PIL.Image,PIL.ImageDraw import imageio import util,util_vis from util import log,debug from . 
import base import warp # ============================ main engine for training and evaluation ============================ class Model(base.Model): def __init__(self,opt): super().__init__(opt) opt.H_crop,opt.W_crop = opt.data.patch_crop def load_dataset(self,opt,eval_split=None): image_raw = PIL.Image.open(opt.data.image_fname) self.image_raw = torchvision_F.to_tensor(image_raw).to(opt.device) def build_networks(self,opt): super().build_networks(opt) self.graph.warp_param = torch.nn.Embedding(opt.batch_size,opt.warp.dof).to(opt.device) torch.nn.init.zeros_(self.graph.warp_param.weight) def setup_optimizer(self,opt): log.info("setting up optimizers...") optim_list = [ dict(params=self.graph.neural_image.parameters(),lr=opt.optim.lr), dict(params=self.graph.warp_param.parameters(),lr=opt.optim.lr_warp), ] optimizer = getattr(torch.optim,opt.optim.algo) self.optim = optimizer(optim_list) # set up scheduler if opt.optim.sched: scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type) kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" } self.sched = scheduler(self.optim,**kwargs) def setup_visualizer(self,opt): super().setup_visualizer(opt) # set colors for visualization box_colors = ["#ff0000","#40afff","#9314ff","#ffd700","#00ff00"] box_colors = list(map(util.colorcode_to_number,box_colors)) self.box_colors = np.array(box_colors).astype(int) assert(len(self.box_colors)==opt.batch_size) # create visualization directory self.vis_path = "{}/vis".format(opt.output_path) os.makedirs(self.vis_path,exist_ok=True) self.video_fname = "{}/vis.mp4".format(opt.output_path) def train(self,opt): # before training log.title("TRAINING START") self.timer = edict(start=time.time(),it_mean=None) self.ep = self.it = self.vis_it = 0 self.graph.train() var = edict(idx=torch.arange(opt.batch_size)) # pre-generate perturbations self.homo_pert, self.rot_pert, self.trans_pert, var.image_pert = self.generate_warp_perturbation(opt) # train var = util.move_to_device(var,opt.device) loader = tqdm.trange(opt.max_iter,desc="training",leave=False) # visualize initial state var = self.graph.forward(opt,var) self.visualize(opt,var,step=0) for it in loader: # train iteration loss = self.train_iteration(opt,var,loader) if opt.warp.fix_first: self.graph.warp_param.weight.data[0] = 0 # after training os.system("ffmpeg -y -framerate 30 -i {}/%d.png -pix_fmt yuv420p {}".format(self.vis_path,self.video_fname)) self.save_checkpoint(opt,ep=None,it=self.it) if opt.tb: self.tb.flush() self.tb.close() if opt.visdom: self.vis.close() log.title("TRAINING DONE") def train_iteration(self,opt,var,loader): loss = super().train_iteration(opt,var,loader) self.graph.neural_image.progress.data.fill_(self.it/opt.max_iter) return loss def generate_warp_perturbation(self,opt): # pre-generate perturbations (translational noise + homography noise) def create_random_perturbation(batch_size): if opt.warp.dof==1: homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r trans_pert = torch.zeros(batch_size,2,device=opt.device)*opt.warp.noise_t elif opt.warp.dof==2: homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h rot_pert = torch.zeros(batch_size,1,device=opt.device)*opt.warp.noise_r trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t elif opt.warp.dof==3: homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r 
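The rotation noise drawn just above is a scalar so(2) angle per patch; later in `generate_warp_perturbation` it is turned into a matrix with `warp.lie.so2_to_SO2` and applied to the normalized pixel grid. The repository's `Lie` conversion code is not fully visible in this dump, so the following is only a minimal illustrative sketch (the `so2_to_SO2_sketch` helper is a stand-in, not the repo's implementation) of how that angle becomes a 2x2 rotation applied batch-wise:

```python
# Illustrative sketch, not repository code: turn a batch of scalar so(2)
# angles into 2x2 rotation matrices and apply them to normalized 2D points,
# mirroring `xy_grid_warped = xy_grid_warped @ warp_matrix.transpose(-2,-1)`.
import torch

def so2_to_SO2_sketch(theta):                    # theta: [B,1]
    cos, sin = theta.cos(), theta.sin()
    row0 = torch.cat([cos, -sin], dim=-1)        # [B,2]
    row1 = torch.cat([sin, cos], dim=-1)         # [B,2]
    return torch.stack([row0, row1], dim=-2)     # [B,2,2]

if __name__ == "__main__":
    B, HW = 5, 4
    rot_pert = torch.randn(B, 1) * 0.1           # small rotation noise (like opt.warp.noise_r)
    xy_grid = torch.rand(B, HW, 2) * 2 - 1       # normalized pixel coordinates in [-1,1]
    R = so2_to_SO2_sketch(rot_pert)              # [B,2,2]
    xy_rot = xy_grid @ R.transpose(-2, -1)       # rotate every point of every patch
    print(xy_rot.shape)                          # torch.Size([5, 4, 2])
```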
trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t elif opt.warp.dof==8: homo_pert = torch.randn(batch_size,8,device=opt.device)*opt.warp.noise_h homo_pert[:,:2]=0 rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t else: assert(False) return homo_pert, rot_pert, trans_pert homo_pert = torch.zeros(opt.batch_size,8,device=opt.device) rot_pert = torch.zeros(opt.batch_size,1,device=opt.device) trans_pert = torch.zeros(opt.batch_size,2,device=opt.device) for i in range(opt.batch_size): homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1) while not warp.check_corners_in_range_compose(opt, homo_pert_i, rot_pert_i, trans_pert_i): homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1) homo_pert[i], rot_pert[i], trans_pert[i] = homo_pert_i, rot_pert_i, trans_pert_i if opt.warp.fix_first: homo_pert[0],rot_pert[0],trans_pert[0] = 0,0,0 # create warped image patches xy_grid = warp.get_normalized_pixel_grid_crop(opt) # [B,HW,2] xy_grid_hom = warp.camera.to_hom(xy_grid) warp_matrix = warp.lie.sl3_to_SL3(homo_pert) warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1) xy_grid_warped = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2] warp_matrix = warp.lie.so2_to_SO2(rot_pert) xy_grid_warped = xy_grid_warped@warp_matrix.transpose(-2,-1) # [B,HW,2] xy_grid_warped = xy_grid_warped+trans_pert[...,None,:] xy_grid_warped = xy_grid_warped.view([opt.batch_size,opt.H_crop,opt.W_crop,2]) xy_grid_warped = torch.stack([xy_grid_warped[...,0]*max(opt.H,opt.W)/opt.W, xy_grid_warped[...,1]*max(opt.H,opt.W)/opt.H],dim=-1) image_raw_batch = self.image_raw.repeat(opt.batch_size,1,1,1) image_pert_all = torch_F.grid_sample(image_raw_batch,xy_grid_warped,align_corners=False) return homo_pert, rot_pert, trans_pert, image_pert_all def visualize_patches(self,opt,warp_param): image_pil = torchvision_F.to_pil_image(self.image_raw).convert("RGBA") draw_pil = PIL.Image.new("RGBA",image_pil.size,(0,0,0,0)) draw = PIL.ImageDraw.Draw(draw_pil) corners_all = warp.warp_corners(opt,warp_param) corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5 corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5 for i,corners in enumerate(corners_all): P = [tuple(float(n) for n in corners[j]) for j in range(4)] draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3) image_pil.alpha_composite(draw_pil) image_tensor = torchvision_F.to_tensor(image_pil.convert("RGB")) return image_tensor def visualize_patches_compose(self,opt, homo_pert, rot_pert, trans_pert): image_pil = torchvision_F.to_pil_image(self.image_raw).convert("RGBA") draw_pil = PIL.Image.new("RGBA",image_pil.size,(0,0,0,0)) draw = PIL.ImageDraw.Draw(draw_pil) corners_all = warp.warp_corners_compose(opt, homo_pert, rot_pert, trans_pert) corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5 corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5 for i,corners in enumerate(corners_all): P = [tuple(float(n) for n in corners[j]) for j in range(4)] draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3) image_pil.alpha_composite(draw_pil) image_tensor = torchvision_F.to_tensor(image_pil.convert("RGB")) return image_tensor @torch.no_grad() def predict_entire_image(self,opt): xy_grid = warp.get_normalized_pixel_grid(opt)[:1] rgb = self.graph.neural_image.forward(opt,xy_grid) # [B,HW,3] image = 
rgb.view(opt.H,opt.W,3).detach().cpu().permute(2,0,1) return image @torch.no_grad() def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"): super().log_scalars(opt,var,loss,metric=metric,step=step,split=split) # compute PSNR psnr = -10*loss.render.log10() self.tb.add_scalar("{0}/{1}".format(split,"PSNR"),psnr,step) # warp error # pred_corners = warp.warp_corners_use_matrix(opt,self.graph.warp_matrix) pred_corners = warp.warp_corners(opt,self.graph.warp_param.weight) pred_corners[...,0] = (pred_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5 pred_corners[...,1] = (pred_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5 gt_corners = warp.warp_corners_compose(opt, self.homo_pert, self.rot_pert, self.trans_pert) gt_corners[...,0] = (gt_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5 gt_corners[...,1] = (gt_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5 warp_error = (pred_corners-gt_corners).norm(dim=-1).mean() self.tb.add_scalar("{0}/{1}".format(split,"warp error"),warp_error,step) @torch.no_grad() def visualize(self,opt,var,step=0,split="train"): # dump frames for writing to video frame_GT = self.visualize_patches_compose(opt,self.homo_pert, self.rot_pert, self.trans_pert) frame = self.visualize_patches(opt,self.graph.warp_param.weight) frame2 = self.predict_entire_image(opt) frame_cat = (torch.cat([frame,frame2],dim=1)*255).byte().permute(1,2,0).numpy() imageio.imsave("{}/{}.png".format(self.vis_path,self.vis_it),frame_cat) self.vis_it += 1 # visualize in Tensorboard if opt.tb: colors = self.box_colors util_vis.tb_image(opt,self.tb,step,split,"image_pert",util_vis.color_border(var.image_pert,colors)) util_vis.tb_image(opt,self.tb,step,split,"rgb_warped",util_vis.color_border(var.rgb_warped_map,colors)) util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes",frame[None]) util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes_GT",frame_GT[None]) util_vis.tb_image(opt,self.tb,self.it+1,"train","image_entire",frame2[None]) # ============================ computation graph for forward/backprop ============================ class Graph(base.Graph): def __init__(self,opt): super().__init__(opt) self.neural_image = NeuralImageFunction(opt) def forward(self,opt,var,mode=None): xy_grid = warp.get_normalized_pixel_grid_crop(opt) xy_grid_warped = warp.warp_grid(opt,xy_grid,self.warp_param.weight) # render images var.rgb_warped = self.neural_image.forward(opt,xy_grid_warped) # [B,HW,3] var.rgb_warped_map = var.rgb_warped.view(opt.batch_size,opt.H_crop,opt.W_crop,3).permute(0,3,1,2) # [B,3,H,W] return var def compute_loss(self,opt,var,mode=None): loss = edict() if opt.loss_weight.render is not None: image_pert = var.image_pert.view(opt.batch_size,3,opt.H_crop*opt.W_crop).permute(0,2,1) loss.render = self.MSE_loss(var.rgb_warped,image_pert) return loss class NeuralImageFunction(torch.nn.Module): def __init__(self,opt): super().__init__() self.define_network(opt) self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed def define_network(self,opt): input_2D_dim = 2+4*opt.arch.posenc.L_2D if opt.arch.posenc else 2 # point-wise RGB prediction self.mlp = torch.nn.ModuleList() L = util.get_layer_dims(opt.arch.layers) for li,(k_in,k_out) in enumerate(L): if li==0: k_in = input_2D_dim if li in opt.arch.skip: k_in += input_2D_dim linear = torch.nn.Linear(k_in,k_out) if opt.barf_c2f and li==0: # rescale first layer init (distribution was for pos.enc. but only xy is first used) scale = np.sqrt(input_2D_dim/2.) 
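The `scale` factor computed just above (and applied on the next line) compensates for the fact that the first layer was initialized for the full positional-encoding width, while at the start of coarse-to-fine training only the 2 raw (x,y) channels carry signal. A small numerical check of that reasoning, written as an illustrative sketch rather than repository code (layer sizes are taken from the planar option files above):

```python
# Illustrative check: with only 2 of input_2D_dim channels active, the
# pre-activation variance shrinks by roughly 2/input_2D_dim; multiplying the
# weights by sqrt(input_2D_dim/2) brings it back to the full-width regime.
import torch

torch.manual_seed(0)
L_2D = 8
input_2D_dim = 2 + 4 * L_2D                      # 34, as in options/planar.yaml
linear = torch.nn.Linear(input_2D_dim, 256)

x = torch.zeros(10000, input_2D_dim)
x[:, :2] = torch.randn(10000, 2)                 # only the raw xy channels are active

print(float(linear(x).std()))                    # noticeably smaller than with all inputs active
scale = (input_2D_dim / 2.) ** 0.5
linear.weight.data *= scale                      # the rescaling applied in define_network
print(float(linear(x).std()))                    # roughly back to the full-width variance
```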
linear.weight.data *= scale linear.bias.data *= scale self.mlp.append(linear) def forward(self,opt,coord_2D): # [B,...,3] if opt.arch.posenc: points_enc = self.positional_encoding(opt,coord_2D,L=opt.arch.posenc.L_2D) points_enc = torch.cat([coord_2D,points_enc],dim=-1) # [B,...,6L+3] else: points_enc = coord_2D feat = points_enc # extract implicit features for li,layer in enumerate(self.mlp): if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1) feat = layer(feat) if li!=len(self.mlp)-1: feat = torch_F.relu(feat) rgb = feat.sigmoid_() # [B,...,3] return rgb def positional_encoding(self,opt,input,L): # [B,...,N] shape = input.shape freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L] spectrum = input[...,None]*freq # [B,...,N,L] sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L] input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L] input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL] # coarse-to-fine: smoothly mask positional encoding for BARF if opt.barf_c2f is not None: # set weights for different frequency bands start,end = opt.barf_c2f alpha = (self.progress.data-start)/(end-start)*L k = torch.arange(L,dtype=torch.float32,device=opt.device) weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2 # apply weights shape = input_enc.shape input_enc = (input_enc.view(-1,L)*weight).view(*shape) return input_enc ================================================ FILE: options/barf_blender.yaml ================================================ _parent_: options/nerf_blender.yaml barf_c2f: [0.1,0.5] # coarse-to-fine scheduling on positional encoding camera: # camera options noise: true # synthetic perturbations on the camera poses (Blender only) noise_r: 0.03 # synthetic perturbations on the camera poses (Blender only) noise_t: 0.3 # synthetic perturbations on the camera poses (Blender only) optim: # optimization options lr_pose: 1.e-3 # learning rate of camera poses lr_pose_end: 1.e-5 # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR) sched_pose: # learning rate scheduling options type: ExponentialLR # scheduler (see PyTorch doc) gamma: # decay rate (can be empty if lr_pose_end were specified) warmup_pose: # linear warmup of the pose learning rate (N iterations) test_photo: true # test-time photometric optimization for evaluation test_iter: 100 # number of iterations for test-time optimization visdom: # Visdom options cam_depth: 0.5 # size of visualized cameras ================================================ FILE: options/barf_iphone.yaml ================================================ _parent_: options/barf_llff.yaml data: # data options dataset: iphone # dataset name scene: IMG_0239 # scene name image_size: [480,640] # input image sizes [height,width] ================================================ FILE: options/barf_llff.yaml ================================================ _parent_: options/nerf_llff.yaml barf_c2f: [0.1,0.5] # coarse-to-fine scheduling on positional encoding camera: # camera options noise: # synthetic perturbations on the camera poses (Blender only) optim: # optimization options lr_pose: 3.e-3 # learning rate of camera poses lr_pose_end: 1.e-5 # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR) sched_pose: # learning rate scheduling options type: ExponentialLR # scheduler (see PyTorch doc) gamma: # decay rate (can be empty if lr_pose_end were specified) warmup_pose: # linear warmup of the pose learning rate (N iterations) test_photo: true # 
test-time photometric optimization for evaluation test_iter: 100 # number of iterations for test-time optimization visdom: # Visdom options cam_depth: 0.2 # size of visualized cameras ================================================ FILE: options/base.yaml ================================================ # default group: 0_test # name of experiment group name: debug # name of experiment run model: # type of model (must be specified from command line) yaml: # config file (must be specified from command line) seed: 0 # seed number (for both numpy and pytorch) gpu: 0 # GPU index number cpu: false # run only on CPU (not supported now) load: # load checkpoint from filename arch: {} # architectural options data: # data options root: # root path to dataset dataset: # dataset name image_size: [null,null] # input image sizes [height,width] num_workers: 8 # number of parallel workers for data loading preload: false # preload the entire dataset into the memory augment: {} # data augmentation (training only) # rotate: # random rotation # brightness: # 0.2 # random brightness jitter # contrast: # 0.2 # random contrast jitter # saturation: # 0.2 # random saturation jitter # hue: # 0.1 # random hue jitter # hflip: # True # random horizontal flip center_crop: # center crop the image by ratio val_on_test: false # validate on test set during training train_sub: # consider a subset of N training samples val_sub: # consider a subset of N validation samples loss_weight: {} # loss weights (in log scale) optim: # optimization options lr: 1.e-3 # learning rate (main) lr_end: # terminal learning rate (only used with sched.type=ExponentialLR) algo: Adam # optimizer (see PyTorch doc) sched: {} # learning rate scheduling options # type: StepLR # scheduler (see PyTorch doc) # steps: # decay every N epochs # gamma: 0.1 # decay rate (can be empty if lr_end were specified) batch_size: 16 # batch size max_epoch: 1000 # train to maximum number of epochs resume: false # resume training (true for latest checkpoint, or number for specific epoch number) output_root: output # root path for output files (checkpoints and results) tb: # TensorBoard options num_images: [4,8] # number of (tiled) images to visualize in TensorBoard visdom: # Visdom options server: localhost # server to host Visdom port: 8600 # port number for Visdom freq: # periodic actions during training scalar: 200 # log losses and scalar states (every N iterations) vis: 1000 # visualize results (every N iterations) val: 20 # validate on val set (every N epochs) ckpt: 50 # save checkpoint (every N epochs) ================================================ FILE: options/l2g_nerf_blender.yaml ================================================ _parent_: options/barf_blender.yaml arch: # architectural options layers_warp: [null,256,256,256,256,256,256,6] # hidden layers for MLP skip_warp: [4] # skip connections embedding_dim: 128 # embedding dim optim: # optimization options lr_pose: 1.e-3 # learning rate of camera poses lr_pose_end: 1.e-8 # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR) loss_weight: # loss weights (in log scale) global_alignment: 2 # global alignment loss error_map_size: ================================================ FILE: options/l2g_nerf_iphone.yaml ================================================ _parent_: options/l2g_nerf_llff.yaml data: # data options dataset: iphone # dataset name scene: IMG_0239 # scene name image_size: [480,640] # input image sizes [height,width] optim: # optimization options test_iter: 
1000 # number of iterations for test-time optimization ================================================ FILE: options/l2g_nerf_llff.yaml ================================================ _parent_: options/barf_llff.yaml arch: # architectural options layers_warp: [null,256,256,256,256,256,256,6] # hidden layers for MLP skip_warp: [4] # skip connections embedding_dim: 128 # embedding dim optim: # optimization options lr_pose: 3.e-3 # learning rate of camera poses lr_pose_end: 1.e-8 # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR) loss_weight: # loss weights (in log scale) global_alignment: 2 # global alignment loss error_map_size: ================================================ FILE: options/l2g_planar.yaml ================================================ _parent_: options/planar.yaml arch: # architectural options layers_warp: [null,256,256,256,256,256,256,3] # hidden layers for MLP skip_warp: [4] # skip connections embedding_dim: 128 # embedding dim loss_weight: # loss weights (in log scale) global_alignment: 2 # global alignment loss optim: lr: 1.e-3 # learning rate (main) lr_end: 1.e-4 # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR) sched: # learning rate scheduling options type: ExponentialLR # scheduler (see PyTorch doc) gamma: # decay rate (can be empty if lr_pose_end were specified) lr_warp: 1.e-3 # learning rate of camera poses lr_warp_end: 1.e-5 # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR) sched_warp: # learning rate scheduling options type: ExponentialLR # scheduler (see PyTorch doc) gamma: # decay rate (can be empty if lr_pose_end were specified) ================================================ FILE: options/nerf_blender.yaml ================================================ _parent_: options/base.yaml arch: # architectural options layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP layers_rgb: [null,128,3] # hidden layers for color MLP skip: [4] # skip connections posenc: # positional encoding L_3D: 10 # number of bases (3D point) L_view: 4 # number of bases (viewpoint) density_activ: softplus # activation function for output volume density tf_init: true # initialize network weights in TensorFlow style nerf: # NeRF-specific options view_dep: true # condition MLP on viewpoint depth: # depth-related options param: metric # depth parametrization (for sampling along the ray) range: [2,6] # near/far bounds for depth sampling sample_intvs: 128 # number of samples sample_stratified: true # stratified sampling fine_sampling: false # hierarchical sampling with another NeRF sample_intvs_fine: # number of samples for the fine NeRF rand_rays: 1024 # number of random rays for each step density_noise_reg: # Gaussian noise on density output as regularization setbg_opaque: false # fill transparent rendering with known background color (Blender only) data: # data options dataset: blender # dataset name scene: lego # scene name image_size: [400,400] # input image sizes [height,width] num_workers: 4 # number of parallel workers for data loading preload: true # preload the entire dataset into the memory bgcolor: 1 # background color (Blender only) val_sub: 4 # consider a subset of N validation samples camera: # camera options model: perspective # type of camera model ndc: false # reparametrize as normalized device coordinates (NDC) loss_weight: # loss weights (in log scale) render: 0 # RGB rendering loss render_fine: # RGB rendering loss 
(for fine NeRF) optim: # optimization options lr: 5.e-4 # learning rate (main) lr_end: 1.e-4 # terminal learning rate (only used with sched.type=ExponentialLR) sched: # learning rate scheduling options type: ExponentialLR # scheduler (see PyTorch doc) gamma: # decay rate (can be empty if lr_end were specified) batch_size: # batch size (not used for NeRF/BARF) max_epoch: # train to maximum number of epochs (not used for NeRF/BARF) max_iter: 200000 # train to maximum number of iterations trimesh: # options for marching cubes to extract 3D mesh res: 128 # 3D sampling resolution range: [-1.2,1.2] # 3D range of interest (assuming same for x,y,z) thres: 25. # volume density threshold for marching cubes chunk_size: 16384 # chunk size of dense samples to be evaluated at a time freq: # periodic actions during training scalar: 200 # log losses and scalar states (every N iterations) vis: 1000 # visualize results (every N iterations) val: 2000 # validate on val set (every N iterations) ckpt: 5000 # save checkpoint (every N iterations) ================================================ FILE: options/nerf_blender_repr.yaml ================================================ _parent_: options/base.yaml arch: # architectural options layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP layers_rgb: [null,128,3] # hidden layers for color MLP skip: [4] # skip connections posenc: # positional encoding L_3D: 10 # number of bases (3D point) L_view: 4 # number of bases (viewpoint) density_activ: relu # activation function for output volume density tf_init: true # initialize network weights in TensorFlow style nerf: # NeRF-specific options view_dep: true # condition MLP on viewpoint depth: # depth-related options param: metric # depth parametrization (for sampling along the ray) range: [2,6] # near/far bounds for depth sampling sample_intvs: 64 # number of samples sample_stratified: true # stratified sampling fine_sampling: true # hierarchical sampling with another NeRF sample_intvs_fine: 128 # number of samples for the fine NeRF rand_rays: 1024 # number of random rays for each step density_noise_reg: 0 # Gaussian noise on density output as regularization setbg_opaque: true # fill transparent rendering with known background color (Blender only) data: # data options dataset: blender # dataset name scene: lego # scene name image_size: [400,400] # input image sizes [height,width] num_workers: 4 # number of parallel workers for data loading preload: true # preload the entire dataset into the memory bgcolor: 1 # background color (Blender only) val_sub: 4 # consider a subset of N validation samples camera: # camera options model: perspective # type of camera model ndc: false # reparametrize as normalized device coordinates (NDC) loss_weight: # loss weights (in log scale) render: 0 # RGB rendering loss render_fine: 0 # RGB rendering loss (for fine NeRF) optim: # optimization options lr: 5.e-4 # learning rate (main) lr_end: 5.e-5 # terminal learning rate (only used with sched.type=ExponentialLR) sched: # learning rate scheduling options type: ExponentialLR # scheduler (see PyTorch doc) gamma: # decay rate (can be empty if lr_end were specified) batch_size: # batch size (not used for NeRF/BARF) max_epoch: # train to maximum number of epochs (not used for NeRF/BARF) max_iter: 500000 # train to maximum number of iterations trimesh: # options for marching cubes to extract 3D mesh res: 128 # 3D sampling resolution range: [-1.2,1.2] # 3D range of interest (assuming same for x,y,z) thres: 25. 
# volume density threshold for marching cubes chunk_size: 16384 # chunk size of dense samples to be evaluated at a time freq: # periodic actions during training scalar: 200 # log losses and scalar states (every N iterations) vis: 1000 # visualize results (every N iterations) val: 2000 # validate on val set (every N iterations) ckpt: 5000 # save checkpoint (every N iterations) ================================================ FILE: options/nerf_llff.yaml ================================================ _parent_: options/base.yaml arch: # architectural optionss layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP] layers_rgb: [null,128,3] # hidden layers for color MLP] skip: [4] # skip connections posenc: # positional encoding: L_3D: 10 # number of bases (3D point) L_view: 4 # number of bases (viewpoint) density_activ: softplus # activation function for output volume density tf_init: true # initialize network weights in TensorFlow style nerf: # NeRF-specific options view_dep: true # condition MLP on viewpoint depth: # depth-related options param: inverse # depth parametrization (for sampling along the ray) range: [1,0] # near/far bounds for depth sampling sample_intvs: 128 # number of samples sample_stratified: true # stratified sampling fine_sampling: false # hierarchical sampling with another NeRF sample_intvs_fine: # number of samples for the fine NeRF rand_rays: 2048 # number of random rays for each step density_noise_reg: # Gaussian noise on density output as regularization setbg_opaque: # fill transparent rendering with known background color (Blender only) data: # data options dataset: llff # dataset name scene: fern # scene name image_size: [480,640] # input image sizes [height,width] num_workers: 4 # number of parallel workers for data loading preload: true # preload the entire dataset into the memory val_ratio: 0.1 # ratio of sequence split for validation camera: # camera options model: perspective # type of camera model ndc: false # reparametrize as normalized device coordinates (NDC) loss_weight: # loss weights (in log scale) render: 0 # RGB rendering loss render_fine: # RGB rendering loss (for fine NeRF) optim: # optimization options lr: 1.e-3 # learning rate (main) lr_end: 1.e-4 # terminal learning rate (only used with sched.type=ExponentialLR) sched: # learning rate scheduling options type: ExponentialLR # scheduler (see PyTorch doc) gamma: # decay rate (can be empty if lr_end were specified) batch_size: # batch size (not used for NeRF/BARF) max_epoch: # train to maximum number of epochs (not used for NeRF/BARF) max_iter: 200000 # train to maximum number of iterations freq: # periodic actions during training scalar: 200 # log losses and scalar states (every N iterations) vis: 1000 # visualize results (every N iterations) val: 2000 # validate on val set (every N iterations) ckpt: 5000 # save checkpoint (every N iterations) ================================================ FILE: options/nerf_llff_repr.yaml ================================================ _parent_: options/base.yaml arch: # architectural options layers_feat: [null,256,256,256,256,256,256,256,256] # hidden layers for feature/density MLP layers_rgb: [null,128,3] # hidden layers for color MLP skip: [4] # skip connections posenc: # positional encoding L_3D: 10 # number of bases (3D point) L_view: 4 # number of bases (viewpoint) density_activ: relu # activation function for output volume density tf_init: true # initialize network weights in TensorFlow style nerf: # NeRF-specific 
options view_dep: true # condition MLP on viewpoint depth: # depth-related options param: metric # depth parametrization (for sampling along the ray) range: [0,1] # near/far bounds for depth sampling sample_intvs: 64 # number of samples sample_stratified: true # stratified sampling fine_sampling: true # hierarchical sampling with another NeRF sample_intvs_fine: 128 # number of samples for the fine NeRF rand_rays: 1024 # number of random rays for each step density_noise_reg: 1 # Gaussian noise on density output as regularization setbg_opaque: # fill transparent rendering with known background color (Blender only) data: # data options dataset: llff # dataset name scene: fern # scene name image_size: [480,640] # input image sizes [height,width] num_workers: 4 # number of parallel workers for data loading preload: true # preload the entire dataset into the memory val_ratio: 0.1 # ratio of sequence split for validation camera: # camera options model: perspective # type of camera model ndc: false # reparametrize as normalized device coordinates (NDC) loss_weight: # loss weights (in log scale) render: 0 # RGB rendering loss render_fine: 0 # RGB rendering loss (for fine NeRF) optim: # optimization options lr: 5.e-4 # learning rate (main) lr_end: 5.e-5 # terminal learning rate (only used with sched.type=ExponentialLR) sched: # learning rate scheduling options type: ExponentialLR # scheduler (see PyTorch doc) gamma: # decay rate (can be empty if lr_end were specified) batch_size: # batch size (not used for NeRF/BARF) max_epoch: # train to maximum number of epochs (not used for NeRF/BARF) max_iter: 500000 # train to maximum number of iterations freq: # periodic actions during training scalar: 200 # log losses and scalar states (every N iterations) vis: 1000 # visualize results (every N iterations) val: 2000 # validate on val set (every N iterations) ckpt: 5000 # save checkpoint (every N iterations) ================================================ FILE: options/planar.yaml ================================================ _parent_: options/base.yaml arch: # architectural options layers: [null,256,256,256,256,3] # hidden layers for MLP skip: [] # skip connections posenc: # positional encoding L_2D: 8 # number of bases (3D point) barf_c2f: [0,0.4] # coarse-to-fine scheduling on positional encoding data: # data options image_fname: data/cat.jpg # path to image file image_size: [360,480] # original image size patch_crop: [180,180] # crop size of image patches to align warp: # image warping options type: homography # type of warp function dof: 8 # degrees of freedom of the warp function noise_h: 0.3 # scale of pre-generated warp perturbation (homography) noise_t: 0.1 # scale of pre-generated warp perturbation (translation) noise_r: 1.0 # scale of pre-generated warp perturbation (rotation) fix_first: true # fix the first patch for uniqueness of solution loss_weight: # loss weights (in log scale) render: 0 # RGB rendering loss optim: # optimization options lr: 1.e-3 # learning rate (main) lr_warp: 1.e-3 # learning rate of warp parameters batch_size: 5 # batch size (set to number of patches to consider) max_iter: 5000 # train to maximum number of iterations visdom: # Visdom options (turned off) freq: # periodic actions during training scalar: 20 # log losses and scalar states (every N iterations) vis: 100 # visualize results (every N iterations) ================================================ FILE: options.py ================================================ import numpy as np import os,sys,time import torch 
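The option files above form inheritance chains through `_parent_` (e.g. l2g_nerf_blender.yaml -> barf_blender.yaml -> nerf_blender.yaml -> base.yaml). A minimal sketch of the recursive merge this implies is shown below; it is simplified (no YAML I/O or safety prompts) and is not the repository's code, whose full version is `load_options`/`override_options` in options.py that follows:

```python
# Simplified sketch of the `_parent_` merge: child values override parent
# values key by key, recursing into nested option dictionaries.
def merge_options(parent, child):
    merged = dict(parent)
    for key, value in child.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_options(merged[key], value)   # recurse into sub-options
        else:
            merged[key] = value                               # leaf values override
    return merged

if __name__ == "__main__":
    base = {"optim": {"lr": 1e-3, "algo": "Adam"}, "max_iter": 200000}
    child = {"optim": {"lr": 5e-4}}               # e.g. a child yaml overriding only the learning rate
    print(merge_options(base, child))
    # {'optim': {'lr': 0.0005, 'algo': 'Adam'}, 'max_iter': 200000}
```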
import random import string import yaml from easydict import EasyDict as edict import util from util import log # torch.backends.cudnn.enabled = False # torch.backends.cudnn.benchmark = False # torch.backends.cudnn.deterministic = True def parse_arguments(args): """ Parse arguments from command line. Syntax: --key1.key2.key3=value --> value --key1.key2.key3= --> None --key1.key2.key3 --> True --key1.key2.key3! --> False """ opt_cmd = {} for arg in args: assert(arg.startswith("--")) if "=" not in arg[2:]: key_str,value = (arg[2:-1],"false") if arg[-1]=="!" else (arg[2:],"true") else: key_str,value = arg[2:].split("=") keys_sub = key_str.split(".") opt_sub = opt_cmd for k in keys_sub[:-1]: if k not in opt_sub: opt_sub[k] = {} opt_sub = opt_sub[k] assert keys_sub[-1] not in opt_sub,keys_sub[-1] opt_sub[keys_sub[-1]] = yaml.safe_load(value) opt_cmd = edict(opt_cmd) return opt_cmd def set(opt_cmd={}): log.info("setting configurations...") assert("model" in opt_cmd) # load config from yaml file assert("yaml" in opt_cmd) fname = "options/{}.yaml".format(opt_cmd.yaml) opt_base = load_options(fname) # override with command line arguments opt = override_options(opt_base,opt_cmd,key_stack=[],safe_check=True) process_options(opt) log.options(opt) return opt def load_options(fname): with open(fname) as file: opt = edict(yaml.safe_load(file)) if "_parent_" in opt: # load parent yaml file(s) as base options parent_fnames = opt.pop("_parent_") if type(parent_fnames) is str: parent_fnames = [parent_fnames] for parent_fname in parent_fnames: opt_parent = load_options(parent_fname) opt_parent = override_options(opt_parent,opt,key_stack=[]) opt = opt_parent print("loading {}...".format(fname)) return opt def override_options(opt,opt_over,key_stack=None,safe_check=False): for key,value in opt_over.items(): if isinstance(value,dict): # parse child options (until leaf nodes are reached) opt[key] = override_options(opt.get(key,dict()),value,key_stack=key_stack+[key],safe_check=safe_check) else: # ensure command line argument to override is also in yaml file if safe_check and key not in opt: add_new = None while add_new not in ["y","n"]: key_str = ".".join(key_stack+[key]) add_new = input("\"{}\" not found in original opt, add? 
(y/n) ".format(key_str)) if add_new=="n": print("safe exiting...") exit() opt[key] = value return opt def process_options(opt): # set seed if opt.seed is not None: random.seed(opt.seed) np.random.seed(opt.seed) torch.manual_seed(opt.seed) torch.cuda.manual_seed_all(opt.seed) if opt.seed!=0: opt.name = str(opt.name)+"_seed{}".format(opt.seed) else: # create random string as run ID randkey = "".join(random.choice(string.ascii_uppercase) for _ in range(4)) opt.name = str(opt.name)+"_{}".format(randkey) # other default options opt.output_path = "{0}/{1}/{2}".format(opt.output_root,opt.group,opt.name) os.makedirs(opt.output_path,exist_ok=True) assert(isinstance(opt.gpu,int)) # disable multi-GPU support for now, single is enough opt.device = "cpu" if opt.cpu or not torch.cuda.is_available() else "cuda:{}".format(opt.gpu) opt.H,opt.W = opt.data.image_size def save_options_file(opt): opt_fname = "{}/options.yaml".format(opt.output_path) if os.path.isfile(opt_fname): with open(opt_fname) as file: opt_old = yaml.safe_load(file) if opt!=opt_old: # prompt if options are not identical opt_new_fname = "{}/options_temp.yaml".format(opt.output_path) with open(opt_new_fname,"w") as file: yaml.safe_dump(util.to_dict(opt),file,default_flow_style=False,indent=4) print("existing options file found (different from current one)...") os.system("diff {} {}".format(opt_fname,opt_new_fname)) os.system("rm {}".format(opt_new_fname)) override = None while override not in ["y","n"]: override = input("override? (y/n) ") if override=="n": print("safe exiting...") exit() else: print("existing options file found (identical)") else: print("(creating new options file...)") with open(opt_fname,"w") as file: yaml.safe_dump(util.to_dict(opt),file,default_flow_style=False,indent=4) ================================================ FILE: requirements.yaml ================================================ name: L2G-NeRF channels: - conda-forge - pytorch dependencies: - numpy - scipy - tqdm - termcolor - easydict - imageio - ipdb - pytorch>=1.9.0 - torchvision - tensorboard - visdom - matplotlib - scikit-video - trimesh - pyyaml - pip - gdown - pip: - lpips - pymcubes - roma - kornia ================================================ FILE: train.py ================================================ import numpy as np import os,sys,time import torch import importlib import options from util import log import warnings warnings.filterwarnings('ignore') def main(): log.process(os.getpid()) log.title("[{}] (PyTorch code for training NeRF/BARF/L2G_NeRF)".format(sys.argv[0])) opt_cmd = options.parse_arguments(sys.argv[1:]) opt = options.set(opt_cmd=opt_cmd) options.save_options_file(opt) with torch.cuda.device(opt.device): model = importlib.import_module("model.{}".format(opt.model)) m = model.Model(opt) m.load_dataset(opt) m.build_networks(opt) m.setup_optimizer(opt) m.restore_checkpoint(opt) m.setup_visualizer(opt) m.train(opt) if __name__=="__main__": main() ================================================ FILE: util.py ================================================ import numpy as np import os,sys,time import shutil import datetime import torch import torch.nn.functional as torch_F import ipdb import types import termcolor import socket import contextlib from easydict import EasyDict as edict # convert to colored strings def red(message,**kwargs): return termcolor.colored(str(message),color="red",attrs=[k for k,v in kwargs.items() if v is True]) def green(message,**kwargs): return termcolor.colored(str(message),color="green",attrs=[k for 
k,v in kwargs.items() if v is True]) def blue(message,**kwargs): return termcolor.colored(str(message),color="blue",attrs=[k for k,v in kwargs.items() if v is True]) def cyan(message,**kwargs): return termcolor.colored(str(message),color="cyan",attrs=[k for k,v in kwargs.items() if v is True]) def yellow(message,**kwargs): return termcolor.colored(str(message),color="yellow",attrs=[k for k,v in kwargs.items() if v is True]) def magenta(message,**kwargs): return termcolor.colored(str(message),color="magenta",attrs=[k for k,v in kwargs.items() if v is True]) def grey(message,**kwargs): return termcolor.colored(str(message),color="grey",attrs=[k for k,v in kwargs.items() if v is True]) def get_time(sec): d = int(sec//(24*60*60)) h = int(sec//(60*60)%24) m = int((sec//60)%60) s = int(sec%60) return d,h,m,s def add_datetime(func): def wrapper(*args,**kwargs): datetime_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") print(grey("[{}] ".format(datetime_str),bold=True),end="") return func(*args,**kwargs) return wrapper def add_functionname(func): def wrapper(*args,**kwargs): print(grey("[{}] ".format(func.__name__),bold=True)) return func(*args,**kwargs) return wrapper def pre_post_actions(pre=None,post=None): def func_decorator(func): def wrapper(*args,**kwargs): if pre: pre() retval = func(*args,**kwargs) if post: post() return retval return wrapper return func_decorator debug = ipdb.set_trace class Log: def __init__(self): pass def process(self,pid): print(grey("Process ID: {}".format(pid),bold=True)) def title(self,message): print(yellow(message,bold=True,underline=True)) def info(self,message): print(magenta(message,bold=True)) def options(self,opt,level=0): for key,value in sorted(opt.items()): if isinstance(value,(dict,edict)): print(" "*level+cyan("* ")+green(key)+":") self.options(value,level+1) else: print(" "*level+cyan("* ")+green(key)+":",yellow(value)) def loss_train(self,opt,ep,lr,loss,timer): if not opt.max_epoch: return message = grey("[train] ",bold=True) message += "epoch {}/{}".format(cyan(ep,bold=True),opt.max_epoch) message += ", lr:{}".format(yellow("{:.2e}".format(lr),bold=True)) message += ", loss:{}".format(red("{:.3e}".format(loss),bold=True)) message += ", time:{}".format(blue("{0}-{1:02d}:{2:02d}:{3:02d}".format(*get_time(timer.elapsed)),bold=True)) message += " (ETA:{})".format(blue("{0}-{1:02d}:{2:02d}:{3:02d}".format(*get_time(timer.arrival)))) print(message) def loss_val(self,opt,loss): message = grey("[val] ",bold=True) message += "loss:{}".format(red("{:.3e}".format(loss),bold=True)) print(message) log = Log() def update_timer(opt,timer,ep,it_per_ep): if not opt.max_epoch: return momentum = 0.99 timer.elapsed = time.time()-timer.start timer.it = timer.it_end-timer.it_start # compute speed with moving average timer.it_mean = timer.it_mean*momentum+timer.it*(1-momentum) if timer.it_mean is not None else timer.it timer.arrival = timer.it_mean*it_per_ep*(opt.max_epoch-ep) # move tensors to device in-place def move_to_device(X,device): if isinstance(X,dict): for k,v in X.items(): X[k] = move_to_device(v,device) elif isinstance(X,list): for i,e in enumerate(X): X[i] = move_to_device(e,device) elif isinstance(X,tuple) and hasattr(X,"_fields"): # collections.namedtuple dd = X._asdict() dd = move_to_device(dd,device) return type(X)(**dd) elif isinstance(X,torch.Tensor): return X.to(device=device) return X def to_dict(D,dict_type=dict): D = dict_type(D) for k,v in D.items(): if isinstance(v,dict): D[k] = to_dict(v,dict_type) return D def 
get_child_state_dict(state_dict,key): return { ".".join(k.split(".")[1:]): v for k,v in state_dict.items() if k.startswith("{}.".format(key)) } def restore_checkpoint(opt,model,load_name=None,resume=False): assert((load_name is None)==(resume is not False)) # resume can be True/False or epoch numbers if resume: load_name = "{0}/model.ckpt".format(opt.output_path) if resume is True else \ "{0}/model/{1}.ckpt".format(opt.output_path,resume) checkpoint = torch.load(load_name,map_location=opt.device) # load individual (possibly partial) children modules for name,child in model.graph.named_children(): child_state_dict = get_child_state_dict(checkpoint["graph"],name) if child_state_dict: print("restoring {}...".format(name)) child.load_state_dict(child_state_dict) for key in model.__dict__: if key.split("_")[0] in ["optim","sched"] and key in checkpoint and resume: print("restoring {}...".format(key)) getattr(model,key).load_state_dict(checkpoint[key]) if resume: ep,it = checkpoint["epoch"],checkpoint["iter"] if resume is not True: assert(resume==(ep or it)) print("resuming from epoch {0} (iteration {1})".format(ep,it)) else: ep,it = None,None return ep,it def save_checkpoint(opt,model,ep,it,latest=False,children=None): os.makedirs("{0}/model".format(opt.output_path),exist_ok=True) if children is not None: graph_state_dict = { k: v for k,v in model.graph.state_dict().items() if k.startswith(children) } else: graph_state_dict = model.graph.state_dict() checkpoint = dict( epoch=ep, iter=it, graph=graph_state_dict, ) for key in model.__dict__: if key.split("_")[0] in ["optim","sched"]: checkpoint.update({ key: getattr(model,key).state_dict() }) torch.save(checkpoint,"{0}/model.ckpt".format(opt.output_path)) if not latest: shutil.copy("{0}/model.ckpt".format(opt.output_path), "{0}/model/{1}.ckpt".format(opt.output_path,ep or it)) # if ep is None, track it instead def check_socket_open(hostname,port): s = socket.socket(socket.AF_INET,socket.SOCK_STREAM) is_open = False try: s.bind((hostname,port)) except socket.error: is_open = True finally: s.close() return is_open def get_layer_dims(layers): # return a list of tuples (k_in,k_out) return list(zip(layers[:-1],layers[1:])) @contextlib.contextmanager def suppress(stdout=False,stderr=False): with open(os.devnull,"w") as devnull: if stdout: old_stdout,sys.stdout = sys.stdout,devnull if stderr: old_stderr,sys.stderr = sys.stderr,devnull try: yield finally: if stdout: sys.stdout = old_stdout if stderr: sys.stderr = old_stderr def colorcode_to_number(code): ords = [ord(c) for c in code[1:]] ords = [n-48 if n<58 else n-87 for n in ords] rgb = (ords[0]*16+ords[1],ords[2]*16+ords[3],ords[4]*16+ords[5]) return rgb ================================================ FILE: util_vis.py ================================================ import numpy as np import os,sys,time import torch import torch.nn.functional as torch_F import torchvision import torchvision.transforms.functional as torchvision_F import matplotlib.pyplot as plt from mpl_toolkits.mplot3d.art3d import Poly3DCollection import PIL import imageio from easydict import EasyDict as edict import camera @torch.no_grad() def tb_image(opt,tb,step,group,name,images,num_vis=None,from_range=(0,1),cmap="gray"): images = preprocess_vis_image(opt,images,from_range=from_range,cmap=cmap) num_H,num_W = num_vis or opt.tb.num_images images = images[:num_H*num_W] image_grid = torchvision.utils.make_grid(images[:,:3],nrow=num_W,pad_value=1.) 
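Two of the small helpers defined in util.py above are easy to misread from the flattened dump, so here is a short usage sketch (illustrative only): `get_layer_dims` pairs consecutive widths from the `layers_*` option lists, which is how the MLPs in model/nerf.py and model/planar.py are assembled, and `colorcode_to_number` converts the "#rrggbb" strings used for the patch box colors.

```python
# Usage sketch for util.get_layer_dims and util.colorcode_to_number.
layers_rgb = [None, 128, 3]                      # as in the option files (yaml null -> None)
dims = list(zip(layers_rgb[:-1], layers_rgb[1:]))
print(dims)                                      # [(None, 128), (128, 3)]
# the leading None (k_in of the first layer) is replaced by the encoded input
# dimension inside define_network before the Linear layer is created

def colorcode_to_number(code):                   # copy of the util.py helper
    ords = [ord(c) for c in code[1:]]
    ords = [n - 48 if n < 58 else n - 87 for n in ords]
    return (ords[0]*16 + ords[1], ords[2]*16 + ords[3], ords[4]*16 + ords[5])

print(colorcode_to_number("#ff0000"))            # (255, 0, 0)
```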
if images.shape[1]==4: mask_grid = torchvision.utils.make_grid(images[:,3:],nrow=num_W,pad_value=1.)[:1] image_grid = torch.cat([image_grid,mask_grid],dim=0) tag = "{0}/{1}".format(group,name) tb.add_image(tag,image_grid,step) def preprocess_vis_image(opt,images,from_range=(0,1),cmap="gray"): min,max = from_range images = (images-min)/(max-min) images = images.clamp(min=0,max=1).cpu() if images.shape[1]==1: images = get_heatmap(opt,images[:,0].cpu(),cmap=cmap) return images def dump_images(opt,idx,name,images,masks=None,from_range=(0,1),cmap="gray"): images = preprocess_vis_image(opt,images,masks=masks,from_range=from_range,cmap=cmap) # [B,3,H,W] images = images.cpu().permute(0,2,3,1).numpy() # [B,H,W,3] for i,img in zip(idx,images): fname = "{}/dump/{}_{}.png".format(opt.output_path,i,name) img_uint8 = (img*255).astype(np.uint8) imageio.imsave(fname,img_uint8) def get_heatmap(opt,gray,cmap): # [N,H,W] color = plt.get_cmap(cmap)(gray.numpy()) color = torch.from_numpy(color[...,:3]).permute(0,3,1,2).float() # [N,3,H,W] return color def color_border(images,colors,width=3): images_pad = [] for i,image in enumerate(images): image_pad = torch.ones(3,image.shape[1]+width*2,image.shape[2]+width*2)*(colors[i,:,None,None]/255.0) image_pad[:,width:-width,width:-width] = image images_pad.append(image_pad) images_pad = torch.stack(images_pad,dim=0) return images_pad @torch.no_grad() def vis_cameras(opt,vis,step,poses=[],colors=["blue","magenta"],plot_dist=True): win_name = "{}/{}".format(opt.group,opt.name) data = [] # set up plots centers = [] for pose,color in zip(poses,colors): pose = pose.detach().cpu() vertices,faces,wireframe = get_camera_mesh(pose,depth=opt.visdom.cam_depth) center = vertices[:,-1] centers.append(center) # camera centers data.append(dict( type="scatter3d", x=[float(n) for n in center[:,0]], y=[float(n) for n in center[:,1]], z=[float(n) for n in center[:,2]], mode="markers", marker=dict(color=color,size=3), )) # colored camera mesh vertices_merged,faces_merged = merge_meshes(vertices,faces) data.append(dict( type="mesh3d", x=[float(n) for n in vertices_merged[:,0]], y=[float(n) for n in vertices_merged[:,1]], z=[float(n) for n in vertices_merged[:,2]], i=[int(n) for n in faces_merged[:,0]], j=[int(n) for n in faces_merged[:,1]], k=[int(n) for n in faces_merged[:,2]], flatshading=True, color=color, opacity=0.05, )) # camera wireframe wireframe_merged = merge_wireframes(wireframe) data.append(dict( type="scatter3d", x=wireframe_merged[0], y=wireframe_merged[1], z=wireframe_merged[2], mode="lines", line=dict(color=color,), opacity=0.3, )) if plot_dist: # distance between two poses (camera centers) center_merged = merge_centers(centers[:2]) data.append(dict( type="scatter3d", x=center_merged[0], y=center_merged[1], z=center_merged[2], mode="lines", line=dict(color="red",width=4,), )) if len(centers)==4: center_merged = merge_centers(centers[2:4]) data.append(dict( type="scatter3d", x=center_merged[0], y=center_merged[1], z=center_merged[2], mode="lines", line=dict(color="red",width=4,), )) # send data to visdom vis._send(dict( data=data, win="poses", eid=win_name, layout=dict( title="({})".format(step), autosize=True, margin=dict(l=30,r=30,b=30,t=30,), showlegend=False, yaxis=dict( scaleanchor="x", scaleratio=1, ) ), opts=dict(title="{} poses ({})".format(win_name,step),), )) def get_camera_mesh(pose,depth=1): vertices = torch.tensor([[-0.5,-0.5,1], [0.5,-0.5,1], [0.5,0.5,1], [-0.5,0.5,1], [0,0,0]])*depth faces = torch.tensor([[0,1,2], [0,2,3], [0,1,4], [1,2,4], [2,3,4], [3,0,4]]) 
vertices = camera.cam2world(vertices[None],pose) wireframe = vertices[:,[0,1,2,3,0,4,1,2,4,3]] return vertices,faces,wireframe def merge_wireframes(wireframe): wireframe_merged = [[],[],[]] for w in wireframe: wireframe_merged[0] += [float(n) for n in w[:,0]]+[None] wireframe_merged[1] += [float(n) for n in w[:,1]]+[None] wireframe_merged[2] += [float(n) for n in w[:,2]]+[None] return wireframe_merged def merge_meshes(vertices,faces): mesh_N,vertex_N = vertices.shape[:2] faces_merged = torch.cat([faces+i*vertex_N for i in range(mesh_N)],dim=0) vertices_merged = vertices.view(-1,vertices.shape[-1]) return vertices_merged,faces_merged def merge_centers(centers): center_merged = [[],[],[]] for c1,c2 in zip(*centers): center_merged[0] += [float(c1[0]),float(c2[0]),None] center_merged[1] += [float(c1[1]),float(c2[1]),None] center_merged[2] += [float(c1[2]),float(c2[2]),None] return center_merged def plot_save_poses(opt,fig,pose,pose_ref=None,path=None,ep=None): # get the camera meshes _,_,cam = get_camera_mesh(pose,depth=opt.visdom.cam_depth) cam = cam.numpy() if pose_ref is not None: _,_,cam_ref = get_camera_mesh(pose_ref,depth=opt.visdom.cam_depth) cam_ref = cam_ref.numpy() # set up plot window(s) plt.title("epoch {}".format(ep)) ax1 = fig.add_subplot(121,projection="3d") ax2 = fig.add_subplot(122,projection="3d") setup_3D_plot(ax1,elev=-90,azim=-90,lim=edict(x=(-1,1),y=(-1,1),z=(-1,1))) setup_3D_plot(ax2,elev=0,azim=-90,lim=edict(x=(-1,1),y=(-1,1),z=(-1,1))) ax1.set_title("forward-facing view",pad=0) ax2.set_title("top-down view",pad=0) plt.subplots_adjust(left=0,right=1,bottom=0,top=0.95,wspace=0,hspace=0) plt.margins(tight=True,x=0,y=0) # plot the cameras N = len(cam) color = plt.get_cmap("gist_rainbow") if pose_ref is not None: for i in range(N): ax1.plot(cam_ref[i,:,0],cam_ref[i,:,1],cam_ref[i,:,2],color=(0.3,0.3,0.3),linewidth=1) ax2.plot(cam_ref[i,:,0],cam_ref[i,:,1],cam_ref[i,:,2],color=(0.3,0.3,0.3),linewidth=1) ax1.scatter(cam_ref[i,5,0],cam_ref[i,5,1],cam_ref[i,5,2],color=(0.3,0.3,0.3),s=40) ax2.scatter(cam_ref[i,5,0],cam_ref[i,5,1],cam_ref[i,5,2],color=(0.3,0.3,0.3),s=40) if ep==0: png_fname = "{}/GT.png".format(path) plt.savefig(png_fname,dpi=75) for i in range(N): c = np.array(color(float(i)/N))*0.8 ax1.plot(cam[i,:,0],cam[i,:,1],cam[i,:,2],color=c) ax2.plot(cam[i,:,0],cam[i,:,1],cam[i,:,2],color=c) ax1.scatter(cam[i,5,0],cam[i,5,1],cam[i,5,2],color=c,s=40) ax2.scatter(cam[i,5,0],cam[i,5,1],cam[i,5,2],color=c,s=40) png_fname = "{}/{}.png".format(path,ep) plt.savefig(png_fname,dpi=75) # clean up plt.clf() def plot_save_poses_blender(opt,fig,pose,pose_ref=None,path=None,ep=None): # get the camera meshes _,_,cam = get_camera_mesh(pose,depth=opt.visdom.cam_depth) cam = cam.numpy() if pose_ref is not None: _,_,cam_ref = get_camera_mesh(pose_ref,depth=opt.visdom.cam_depth) cam_ref = cam_ref.numpy() # set up plot window(s) ax = fig.add_subplot(111,projection="3d") ax.set_title("epoch {}".format(ep),pad=0) setup_3D_plot(ax,elev=45,azim=35,lim=edict(x=(-3,3),y=(-3,3),z=(-3,2.4))) plt.subplots_adjust(left=0,right=1,bottom=0,top=0.95,wspace=0,hspace=0) plt.margins(tight=True,x=0,y=0) # plot the cameras N = len(cam) ref_color = (0.7,0.2,0.7) pred_color = (0,0.6,0.7) ax.add_collection3d(Poly3DCollection([v[:4] for v in cam_ref],alpha=0.2,facecolor=ref_color)) for i in range(N): ax.plot(cam_ref[i,:,0],cam_ref[i,:,1],cam_ref[i,:,2],color=ref_color,linewidth=0.5) ax.scatter(cam_ref[i,5,0],cam_ref[i,5,1],cam_ref[i,5,2],color=ref_color,s=20) if ep==0: png_fname = "{}/GT.png".format(path) 
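The `merge_meshes` helper above flattens all camera frusta into a single vertex list, so the shared face template must be offset by `i*vertex_N` for the i-th camera. A small self-contained check of that index bookkeeping (illustrative, not repository code):

```python
# Each frustum has 5 vertices; after flattening, the faces of camera i must
# point at vertices [5*i, 5*i+4], hence the `faces + i*vertex_N` offset.
import torch

vertices = torch.rand(2, 5, 3)                   # 2 cameras, 5 frustum vertices each
faces = torch.tensor([[0,1,2],[0,2,3],[0,1,4],[1,2,4],[2,3,4],[3,0,4]])
mesh_N, vertex_N = vertices.shape[:2]
faces_merged = torch.cat([faces + i*vertex_N for i in range(mesh_N)], dim=0)
vertices_merged = vertices.view(-1, 3)

print(vertices_merged.shape)                     # torch.Size([10, 3])
print(faces_merged[6])                           # tensor([5, 6, 7]) -- first face of camera 1
```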
plt.savefig(png_fname,dpi=75) ax.add_collection3d(Poly3DCollection([v[:4] for v in cam],alpha=0.2,facecolor=pred_color)) for i in range(N): ax.plot(cam[i,:,0],cam[i,:,1],cam[i,:,2],color=pred_color,linewidth=1) ax.scatter(cam[i,5,0],cam[i,5,1],cam[i,5,2],color=pred_color,s=20) for i in range(N): ax.plot([cam[i,5,0],cam_ref[i,5,0]], [cam[i,5,1],cam_ref[i,5,1]], [cam[i,5,2],cam_ref[i,5,2]],color=(1,0,0),linewidth=3) png_fname = "{}/{}.png".format(path,ep) plt.savefig(png_fname,dpi=75) # clean up plt.clf() def setup_3D_plot(ax,elev,azim,lim=None): ax.xaxis.set_pane_color((1.0,1.0,1.0,0.0)) ax.yaxis.set_pane_color((1.0,1.0,1.0,0.0)) ax.zaxis.set_pane_color((1.0,1.0,1.0,0.0)) ax.xaxis._axinfo["grid"]["color"] = (0.9,0.9,0.9,1) ax.yaxis._axinfo["grid"]["color"] = (0.9,0.9,0.9,1) ax.zaxis._axinfo["grid"]["color"] = (0.9,0.9,0.9,1) ax.xaxis.set_tick_params(labelsize=8) ax.yaxis.set_tick_params(labelsize=8) ax.zaxis.set_tick_params(labelsize=8) ax.set_xlabel("X",fontsize=16) ax.set_ylabel("Y",fontsize=16) ax.set_zlabel("Z",fontsize=16) ax.set_xlim(lim.x[0],lim.x[1]) ax.set_ylim(lim.y[0],lim.y[1]) ax.set_zlim(lim.z[0],lim.z[1]) ax.view_init(elev=elev,azim=azim) # Copyright 2022 The Nerfstudio Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Helper functions for visualizing outputs """ import torch from matplotlib import cm def apply_colormap(image, cmap="viridis"): """Convert single channel to a color image. Args: image: Single channel image. : TensorType["bs":..., 1] cmap: Colormap for image. Returns: TensorType: Colored image -> TensorType["bs":..., "rgb":3] """ colormap = cm.get_cmap(cmap) colormap = torch.tensor(colormap.colors).to(image.device) # type: ignore image_long = (image * 255).long() image_long[image_long<0]=0 image_long[image_long>255]=255 image_long_min = torch.min(image_long) image_long_max = torch.max(image_long) assert image_long_min >= 0, f"the min value is {image_long_min}" assert image_long_max <= 255, f"the max value is {image_long_max}" return colormap[image_long[..., 0]] def apply_depth_colormap( depth, accumulation= None, near_plane= None, far_plane= None, cmap="turbo", ): """Converts a depth image to color for easier analysis. Args: depth: Depth image.: TensorType["bs":..., 1] accumulation: Ray accumulation used for masking vis. : Optional[TensorType["bs":..., 1]] near_plane: Closest depth to consider. If None, use min image value. : Optional[float] far_plane: Furthest depth to consider. If None, use max image value. : Optional[float] cmap: Colormap to apply. 
# inferno turbo viridis Returns: Colored depth image -> TensorType["bs":..., "rgb":3] """ # near_plane = near_plane or float(torch.min(depth)) # far_plane = far_plane or float(torch.max(depth)) # depth = (depth - near_plane) / (far_plane - near_plane + 1e-10) depth = torch.clip(depth, 0 , 1) colored_image = apply_colormap(depth, cmap=cmap) if accumulation is not None: colored_image = colored_image * accumulation + (1 - accumulation) return colored_image ================================================ FILE: warp.py ================================================ import numpy as np import os,sys,time import torch import torch.nn.functional as torch_F import util from util import log,debug import camera import warnings def get_normalized_pixel_grid(opt): y_range = ((torch.arange(opt.H,dtype=torch.float32,device=opt.device)+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W)) x_range = ((torch.arange(opt.W,dtype=torch.float32,device=opt.device)+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W)) Y,X = torch.meshgrid(y_range,x_range) # [H,W] xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2] xy_grid = xy_grid.repeat(opt.batch_size,1,1) # [B,HW,2] return xy_grid def get_normalized_pixel_grid_crop(opt): y_crop = (opt.H//2-opt.H_crop//2,opt.H//2+opt.H_crop//2) x_crop = (opt.W//2-opt.W_crop//2,opt.W//2+opt.W_crop//2) y_range = ((torch.arange(*(y_crop),dtype=torch.float32,device=opt.device)+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W)) x_range = ((torch.arange(*(x_crop),dtype=torch.float32,device=opt.device)+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W)) Y,X = torch.meshgrid(y_range,x_range) # [H,W] xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2] xy_grid = xy_grid.repeat(opt.batch_size,1,1) # [B,HW,2] return xy_grid def warp_grid(opt,xy_grid,warp): if opt.warp.type=="translation": assert(opt.warp.dof==2) warped_grid = xy_grid+warp[...,None,:] elif opt.warp.type=="rotation": assert(opt.warp.dof==1) warp_matrix = lie.so2_to_SO2(warp) warped_grid = xy_grid@warp_matrix.transpose(-2,-1) # [B,HW,2] elif opt.warp.type=="rigid": assert(opt.warp.dof==3) xy_grid_hom = camera.to_hom(xy_grid) warp_matrix = lie.se2_to_SE2(warp) warped_grid = xy_grid_hom@warp_matrix.transpose(-2,-1) # [B,HW,2] elif opt.warp.type=="homography": assert(opt.warp.dof==8) xy_grid_hom = camera.to_hom(xy_grid) warp_matrix = lie.sl3_to_SL3(warp) warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1) warped_grid = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2] else: assert(False) return warped_grid def warp_grid_use_matrix(opt,xy_grid,warp_matrix): if opt.warp.type=="translation": assert(opt.warp.dof==2) warped_grid = xy_grid+warp_matrix[...,None,:] elif opt.warp.type=="rotation": assert(opt.warp.dof==1) warped_grid = xy_grid@warp_matrix.transpose(-2,-1) # [B,HW,2] elif opt.warp.type=="rigid": assert(opt.warp.dof==3) xy_grid_hom = camera.to_hom(xy_grid) warped_grid = xy_grid_hom@warp_matrix.transpose(-2,-1) # [B,HW,2] elif opt.warp.type=="homography": assert(opt.warp.dof==8) xy_grid_hom = camera.to_hom(xy_grid) warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1) warped_grid = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2] else: assert(False) return warped_grid def warp_corners(opt,warp_param): y_crop = (opt.H//2-opt.H_crop//2,opt.H//2+opt.H_crop//2) x_crop = (opt.W//2-opt.W_crop//2,opt.W//2+opt.W_crop//2) Y = [((y+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W)) for y in y_crop] X = [((x+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W)) for x in x_crop] corners = [(X[0],Y[0]),(X[0],Y[1]),(X[1],Y[1]),(X[1],Y[0])] corners = 
torch.tensor(corners,dtype=torch.float32,device=opt.device).repeat(opt.batch_size,1,1) corners_warped = warp_grid(opt,corners,warp_param) return corners_warped def warp_corners_compose(opt, homo_pert, rot_pert, trans_pert): y_crop = (opt.H//2-opt.H_crop//2,opt.H//2+opt.H_crop//2) x_crop = (opt.W//2-opt.W_crop//2,opt.W//2+opt.W_crop//2) Y = [((y+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W)) for y in y_crop] X = [((x+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W)) for x in x_crop] corners = [(X[0],Y[0]),(X[0],Y[1]),(X[1],Y[1]),(X[1],Y[0])] corners = torch.tensor(corners,dtype=torch.float32,device=opt.device).repeat(opt.batch_size,1,1) xy_grid_hom = camera.to_hom(corners) warp_matrix = lie.sl3_to_SL3(homo_pert) warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1) corners_warped = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2] warp_matrix = lie.so2_to_SO2(rot_pert) corners_warped = corners_warped@warp_matrix.transpose(-2,-1) # [B,HW,2] corners_warped = corners_warped+trans_pert[...,None,:] return corners_warped def warp_corners_use_matrix(opt,warp_matrix): y_crop = (opt.H//2-opt.H_crop//2,opt.H//2+opt.H_crop//2) x_crop = (opt.W//2-opt.W_crop//2,opt.W//2+opt.W_crop//2) Y = [((y+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W)) for y in y_crop] X = [((x+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W)) for x in x_crop] corners = [(X[0],Y[0]),(X[0],Y[1]),(X[1],Y[1]),(X[1],Y[0])] corners = torch.tensor(corners,dtype=torch.float32,device=opt.device).repeat(opt.batch_size,1,1) corners_warped = warp_grid_use_matrix(opt,corners,warp_matrix) return corners_warped def check_corners_in_range(opt,warp_param): corners_all = warp_corners(opt,warp_param) X = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5 Y = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5 return (0<=X).all() and (X<opt.W).all() and (0<=Y).all() and (Y<opt.H).all() class Lie(): def taylor_A(self,x,nth=10): # Taylor expansion of sin(x)/x ans = torch.zeros_like(x) denom = 1. for i in range(nth+1): if i>0: denom *= (2*i)*(2*i+1) ans = ans+(-1)**i*x**(2*i)/denom return ans def taylor_B(self,x,nth=10): # Taylor expansion of (1-cos(x))/x ans = torch.zeros_like(x) denom = 1. for i in range(nth+1): denom *= (2*i+1)*(2*i+2) ans = ans+(-1)**i*x**(2*i+1)/denom return ans def taylor_C(self,x,nth=10): # Taylor expansion of (x*cos(x)-sin(x))/x**2 ans = torch.zeros_like(x) denom = 1. for i in range(nth+1): denom *= (2*i+2)*(2*i+3) ans = ans+(-1)**(i+1)*x**(2*i+1)*(2*i+2)/denom return ans def taylor_D(self,x,nth=10): # Taylor expansion of (x*sin(x)+cos(x)-1)/x**2 ans = torch.zeros_like(x) denom = 1. for i in range(nth+1): denom *= (2*i+1)*(2*i+2) ans = ans+(-1)**i*x**(2*i)*(2*i+1)/denom return ans lie = Lie()
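The Taylor helpers in the `Lie` class above exist because ratios such as (1-cos(x))/x appear when converting se(2)/sl(3) parameters to matrices, and the closed form divides by x, which is numerically fragile as x approaches 0. A small self-contained check (illustrative, not repository code) comparing the `taylor_B` series against the closed form:

```python
# The truncated series stays finite and accurate near x = 0, where the
# closed form (1-cos(x))/x loses precision in float32.
import torch

def taylor_B(x, nth=10):                         # same series as Lie.taylor_B above
    ans = torch.zeros_like(x)
    denom = 1.
    for i in range(nth + 1):
        denom *= (2*i+1)*(2*i+2)
        ans = ans + (-1)**i * x**(2*i+1) / denom
    return ans

x = torch.tensor([1e-8, 1e-3, 0.5])
closed_form = (1 - torch.cos(x)) / x             # underflows to 0 for the tiniest x
print(taylor_B(x))                               # ~[5.0e-09, 5.0e-04, 0.2448]
print(closed_form)                               # agrees for moderate x, degenerate for tiny x
```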