Full Code of rover-xingyu/L2G-NeRF for AI

Repository: rover-xingyu/L2G-NeRF
Branch: main
Commit: 08d06597a233
Files: 40
Total size: 218.7 KB

Directory structure:
L2G-NeRF/

├── LICENSE
├── README.md
├── camera.py
├── data/
│   ├── base.py
│   ├── blender.py
│   ├── iphone.py
│   └── llff.py
├── evaluate.py
├── external/
│   └── pohsun_ssim/
│       ├── LICENSE.txt
│       ├── README.md
│       ├── max_ssim.py
│       ├── pytorch_ssim/
│       │   └── __init__.py
│       ├── setup.cfg
│       └── setup.py
├── extract_mesh.py
├── model/
│   ├── barf.py
│   ├── base.py
│   ├── l2g_nerf.py
│   ├── l2g_planar.py
│   ├── nerf.py
│   └── planar.py
├── options/
│   ├── barf_blender.yaml
│   ├── barf_iphone.yaml
│   ├── barf_llff.yaml
│   ├── base.yaml
│   ├── l2g_nerf_blender.yaml
│   ├── l2g_nerf_iphone.yaml
│   ├── l2g_nerf_llff.yaml
│   ├── l2g_planar.yaml
│   ├── nerf_blender.yaml
│   ├── nerf_blender_repr.yaml
│   ├── nerf_llff.yaml
│   ├── nerf_llff_repr.yaml
│   └── planar.yaml
├── options.py
├── requirements.yaml
├── train.py
├── util.py
├── util_vis.py
└── warp.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) [year] [fullname]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# L2G-NeRF: Local-to-Global Registration for Bundle-Adjusting Neural Radiance Fields
**[Project Page](https://rover-xingyu.github.io/L2G-NeRF/) |
[Paper](https://arxiv.org/pdf/2211.11505.pdf) |
[Video](https://www.youtube.com/watch?v=y8XP9Umt6Mw)**

[Yue Chen¹](https://scholar.google.com/citations?user=M2hq1_UAAAAJ&hl=en), 
[Xingyu Chen¹](https://scholar.google.com/citations?user=gDHPrWEAAAAJ&hl=en), 
[Xuan Wang²](https://scholar.google.com/citations?user=h-3xd3EAAAAJ&hl=en),
[Qi Zhang³](https://scholar.google.com/citations?user=2vFjhHMAAAAJ&hl=en), 
[Yu Guo¹](https://scholar.google.com/citations?user=OemeiSIAAAAJ&hl=en), 
[Ying Shan³](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en), 
[Fei Wang¹](https://scholar.google.com/citations?user=uU2JTpUAAAAJ&hl=en). 

[¹Xi'an Jiaotong University](http://en.xjtu.edu.cn/),
[²Ant Group](https://www.antgroup.com/en),
[³Tencent AI Lab](https://ai.tencent.com/ailab/en/index/). 

This repository is an official implementation of [L2G-NeRF](https://rover-xingyu.github.io/L2G-NeRF/) using [PyTorch](https://pytorch.org/). 

# :computer: Installation

## Hardware

* We implement all experiments on a single NVIDIA GeForce RTX 2080 Ti GPU. 
* L2G-NeRF takes about 4.5 and 8 hours to train on synthetic objects and real-world scenes, respectively, while BARF takes about 8 and 10.5 hours.

## Software

* Clone this repo by `git clone https://github.com/rover-xingyu/L2G-NeRF`
* This code is developed with Python 3; PyTorch 1.9+ is required. We recommend using [Anaconda](https://www.anaconda.com/products/individual) to set up the environment: install the dependencies with `conda env create --file requirements.yaml python=3`, then activate the environment with `conda activate L2G-NeRF`.
--------------------------------------

# :key: Training and Evaluation

## Data download

Both the Blender synthetic data and LLFF real-world data can be found in the [NeRF Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1).
For convenience, you can download them with the following script:
  ```bash
  # Blender
  gdown --id 18JxhpWD-4ZmuFKLzKlAw-w5PpzZxXOcG # download nerf_synthetic.zip
  unzip nerf_synthetic.zip
  rm -f nerf_synthetic.zip
  mv nerf_synthetic data/blender
  # LLFF
  gdown --id 16VnMcF1KJYxN9QId6TClMsZRahHNMW5g # download nerf_llff_data.zip
  unzip nerf_llff_data.zip
  rm -f nerf_llff_data.zip
  mv nerf_llff_data data/llff
  ```
--------------------------------------

## Running L2G-NeRF
To train and evaluate L2G-NeRF:
```bash
# <GROUP> and <NAME> can be set to your liking, while <SCENE> is specific to datasets

# NeRF (3D): Synthetic Objects
# Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})
python3 train.py \
--model=l2g_nerf --yaml=l2g_nerf_blender \
--group=exp_synthetic --name=l2g_lego \
--data.scene=lego --gpu=3 \
--data.root=/the/data/path/of/nerf_synthetic/ \
--camera.noise_r=0.07 --camera.noise_t=0.5

python3 evaluate.py \
--model=l2g_nerf --yaml=l2g_nerf_blender \
--group=exp_synthetic --name=l2g_lego \
--data.scene=lego --gpu=3 \
--data.root=/the/data/path/of/nerf_synthetic/ \
--data.val_sub= --resume

# NeRF (3D): Real-World Scenes
# LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})
python3 train.py \
--model=l2g_nerf --yaml=l2g_nerf_llff \
--group=exp_LLFF --name=l2g_fern \
--data.scene=fern --gpu=3 \
--data.root=/the/data/path/of/nerf_llff_data/ \
--loss_weight.global_alignment=2

python3 evaluate.py \
--model=l2g_nerf --yaml=l2g_nerf_llff \
--group=exp_LLFF --name=l2g_fern \
--data.scene=fern --gpu=3 \
--data.root=/the/data/path/of/nerf_llff_data/ \
--loss_weight.global_alignment=2 \
--resume

# Neural image Alignment (2D): Rigid
# use the image of “Girl With a Pearl Earring” renovation ©Koorosh Orooj (CC BY-SA 4.0) for rigid image alignment
python3 train.py \
--model=l2g_planar --yaml=l2g_planar \
--group=exp_planar --name=l2g_girl  \
--warp.type=rigid --warp.dof=3 \
--data.image_fname=data/girl.jpg \
--data.image_size=[595,512] \
--data.patch_crop=[260,260] \
--seed=1 --gpu=3

# Neural image Alignment (2D): Homography
# use the image of “cat” from ImageNet for homography image alignment
python3 train.py \
--model=l2g_planar --yaml=l2g_planar \
--group=exp_planar --name=l2g_cat \
--warp.type=homography --warp.dof=8 \
--data.image_fname=data/cat.jpg \
--data.image_size=[360,480] \
--data.patch_crop=[180,180] \
--gpu=0
```
--------------------------------------

## Running BARF
If you want to train and evaluate the BARF extension of the original NeRF model that jointly optimizes poses (coarse-to-fine positional encoding):
```bash
# <GROUP> and <NAME> can be set to your liking, while <SCENE> is specific to datasets

# NeRF (3D): Synthetic Objects
# Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})
python3 train.py \
--model=barf --yaml=barf_blender \
--group=exp_synthetic --name=barf_lego \
--data.scene=lego --gpu=1 \
--data.root=/the/data/path/of/nerf_synthetic/ \
--camera.noise_r=0.07 --camera.noise_t=0.5

python3 evaluate.py \
--model=barf --yaml=barf_blender \
--group=exp_synthetic --name=barf_lego \
--data.scene=lego --gpu=1 \
--data.root=/the/data/path/of/nerf_synthetic/ \
--data.val_sub= --resume

# NeRF (3D): Real-World Scenes
# LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})
python3 train.py \
--model=barf --yaml=barf_llff \
--group=exp_LLFF --name=barf_fern \
--data.scene=fern --gpu=1 \
--data.root=/the/data/path/of/nerf_llff_data/

python3 evaluate.py \
--model=barf --yaml=barf_llff \
--group=exp_LLFF --name=barf_fern \
--data.scene=fern --gpu=1 \
--data.root=/the/data/path/of/nerf_llff_data/ \
--resume

# Neural image Alignment (2D): Rigid
# use the image of “Girl With a Pearl Earring” renovation ©Koorosh Orooj (CC BY-SA 4.0) for rigid image alignment
python3 train.py \
--model=planar --yaml=planar \
--group=exp_planar --name=barf_girl \
--warp.type=rigid --warp.dof=3 \
--data.image_fname=data/girl.jpg \
--data.image_size=[595,512] \
--data.patch_crop=[260,260] \
--seed=1 --gpu=1

# Neural image Alignment (2D): Homography
# use the image of “cat” from ImageNet for homography image alignment
python3 train.py \
--model=planar --yaml=planar \
--group=exp_planar --name=barf_cat \
--warp.type=homography --warp.dof=8 \
--data.image_fname=data/cat.jpg \
--data.image_size=[360,480] \
--data.patch_crop=[180,180] \
--gpu=2
```
--------------------------------------

## Running Naive
If you want to train and evaluate the Naive extension of the original NeRF model that jointly optimizes poses (full positional encoding):

```bash
# <GROUP> and <NAME> can be set to your liking, while <SCENE> is specific to datasets

# NeRF (3D): Synthetic Objects
# Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})
python3 train.py \
--model=barf --yaml=barf_blender \
--group=exp_synthetic --name=nerf_lego \
--data.scene=lego --gpu=2 \
--data.root=/the/data/path/of/nerf_synthetic/ \
--barf_c2f=null \
--camera.noise_r=0.07 --camera.noise_t=0.5

python3 evaluate.py \
--model=barf --yaml=barf_blender \
--group=exp_synthetic --name=nerf_lego \
--data.scene=lego --gpu=2 \
--data.root=/the/data/path/of/nerf_synthetic/ \
--barf_c2f=null \
--data.val_sub= --resume

# NeRF (3D): Real-World Scenes
# LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})
python3 train.py \
--model=barf --yaml=barf_llff \
--group=exp_LLFF --name=nerf_fern \
--data.scene=fern --gpu=2 \
--data.root=/the/data/path/of/nerf_llff_data/ \
--barf_c2f=null

python3 evaluate.py \
--model=barf --yaml=barf_llff \
--group=exp_LLFF --name=nerf_fern \
--data.scene=fern --gpu=2 \
--data.root=/the/data/path/of/nerf_llff_data/ \
--barf_c2f=null --resume 

# Neural image Alignment (2D): Rigid
# use the image of “Girl With a Pearl Earring” renovation ©Koorosh Orooj (CC BY-SA 4.0) for rigid image alignment
python3 train.py \
--model=planar --yaml=planar \
--group=exp_planar --name=naive_girl \
--warp.type=rigid --warp.dof=3 \
--data.image_fname=data/girl.jpg \
--data.image_size=[595,512] \
--data.patch_crop=[260,260] \
--seed=1 --gpu=2 --barf_c2f=null

# Neural image Alignment (2D): Homography
# use the image of “cat” from ImageNet for homography image alignment
python3 train.py \
--model=planar --yaml=planar \
--group=exp_planar --name=naive_cat \
--warp.type=homography --warp.dof=8 \
--data.image_fname=data/cat.jpg \
--data.image_size=[360,480] \
--data.patch_crop=[180,180] \
--gpu=3 --barf_c2f=null
```
--------------------------------------

## Running reference NeRF
If you want to train and evaluate the reference NeRF models (assuming known camera poses):
```bash
# <GROUP> and <NAME> can be set to your liking, while <SCENE> is specific to datasets

# NeRF (3D): Synthetic Objects
# Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})
python3 train.py \
--model=nerf --yaml=nerf_blender \
--group=exp_synthetic --name=ref_lego \
--data.scene=lego --gpu=0 \
--data.root=/the/data/path/of/nerf_synthetic/

python3 evaluate.py \
--model=nerf --yaml=nerf_blender \
--group=exp_synthetic --name=ref_lego \
--data.scene=lego --gpu=0 \
--data.root=/the/data/path/of/nerf_synthetic/ \
--data.val_sub= --resume

# NeRF (3D): Real-World Scenes
# LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})
python3 train.py \
--model=nerf --yaml=nerf_llff \
--group=exp_LLFF --name=ref_fern \
--data.scene=fern --gpu=0 \
--data.root=/the/data/path/of/nerf_llff_data/


python3 evaluate.py \
--model=nerf --yaml=nerf_llff \
--group=exp_LLFF --name=ref_fern \
--data.scene=fern --gpu=0 \
--data.root=/the/data/path/of/nerf_llff_data/ \
--resume
```
--------------------------------------

# :laughing: Visualization

## Results and Videos
All the results will be stored in the directory `output/<GROUP>/<NAME>`.
You may want to organize your experiments by grouping different runs in the same group. Many videos will be created to visualize the pose optimization process and novel view synthesis.

## TensorBoard
The TensorBoard events include the following:
- **SCALARS**: the rendering losses and PSNR over the course of optimization. For L2G-NeRF/BARF/Naive, the rotational/translational errors with respect to the given poses are also computed.
- **IMAGES**: visualization of the RGB images and the RGB/depth rendering.

## Visdom
The visualization of 3D camera poses is provided in Visdom:
Run `visdom -port 8600` to start the Visdom server. The Visdom server host defaults to `localhost`; this can be overridden with `--visdom.server` (see `options/base.yaml` for details). To disable Visdom visualization, add `--visdom!`.

## Mesh
The `extract_mesh.py` script provides a simple way to extract the underlying 3D geometry using marching cubes (supported for the Blender dataset). Run it as follows:
```bash
python3 extract_mesh.py \
--model=l2g_nerf --yaml=l2g_nerf_blender \
--group=exp_synthetic --name=l2g_lego \
--data.scene=lego --gpu=3 \
--data.root=/the/data/path/of/nerf_synthetic/ \
--data.val_sub= --resume
```
--------------------------------------

# :mag_right: Codebase structure

The main engine and network architecture in `model/l2g_nerf.py` inherit those from `model/nerf.py`.

Some tips on using and understanding the codebase:
- The computation graph for forward/backprop is stored in `var` throughout the codebase.
- The losses are stored in `loss`. To add a new loss function, just implement it in `compute_loss()` and add its weight to `opt.loss_weight.<name>`. It will automatically be added to the overall loss and logged to TensorBoard.
- If you are using a multi-GPU machine, you can set `--gpu=<gpu_number>` to specify which GPU to use. Multi-GPU training/evaluation is currently not supported.
- To resume from a previous checkpoint, add `--resume=<ITER_NUMBER>`, or just `--resume` to resume from the latest checkpoint.
- To eliminate the global alignment objective, set `--loss_weight.global_alignment=null`; this ablation is equivalent to a local registration method.
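
The loss-weighting mechanism described above can be sketched as follows. This is a minimal illustrative mimic, not the repository's actual code: the helper `summarize_loss`, the scalar `var` fields, and the use of `SimpleNamespace` in place of the codebase's `edict` options are all assumptions for the sake of a self-contained example.

```python
from types import SimpleNamespace

def compute_loss(var, opt):
    # Each named loss term is stored under its own key; adding a new loss
    # means adding one entry here plus a weight under opt.loss_weight.
    loss = {}
    loss["render"] = (var.rgb - var.rgb_gt) ** 2
    loss["global_alignment"] = abs(var.pose_local - var.pose_global)
    return loss

def summarize_loss(loss, opt):
    # Weighted sum over all terms; a weight of None (yaml `null`)
    # disables that term, matching the ablation flag described above.
    total = 0.0
    for name, value in loss.items():
        weight = getattr(opt.loss_weight, name, None)
        if weight is not None:
            total += weight * value
    return total

# toy scalars standing in for rendered/ground-truth values and pose params
opt = SimpleNamespace(loss_weight=SimpleNamespace(render=1.0, global_alignment=None))
var = SimpleNamespace(rgb=0.8, rgb_gt=0.5, pose_local=1.0, pose_global=0.9)
total = summarize_loss(compute_loss(var, opt), opt)  # only the render term contributes
```

Setting `global_alignment` to a number instead of `None` would re-enable that term, mirroring `--loss_weight.global_alignment=2` in the LLFF commands above.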

--------------------------------------
# Citation

If you find this project useful for your research, please use the following BibTeX entry.

```bibtex
@inproceedings{chen2023local,
  title={Local-to-global registration for bundle-adjusting neural radiance fields},
  author={Chen, Yue and Chen, Xingyu and Wang, Xuan and Zhang, Qi and Guo, Yu and Shan, Ying and Wang, Fei},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={8264--8273},
  year={2023}
}
```
--------------------------------------

# Acknowledgements
Our code is based on the awesome PyTorch implementation of Bundle-Adjusting Neural Radiance Fields ([BARF](https://github.com/chenhsuanlin/bundle-adjusting-NeRF)). We appreciate all the contributors.


================================================
FILE: camera.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import collections
from easydict import EasyDict as edict

import util
from util import log,debug

class Pose():
    """
    A class of operations on camera poses (PyTorch tensors with shape [...,3,4])
    each [3,4] camera pose takes the form of [R|t]
    """

    def __call__(self,R=None,t=None):
        # construct a camera pose from the given R and/or t
        assert(R is not None or t is not None)
        if R is None:
            if not isinstance(t,torch.Tensor): t = torch.tensor(t)
            R = torch.eye(3,device=t.device).repeat(*t.shape[:-1],1,1)
        elif t is None:
            if not isinstance(R,torch.Tensor): R = torch.tensor(R)
            t = torch.zeros(R.shape[:-1],device=R.device)
        else:
            if not isinstance(R,torch.Tensor): R = torch.tensor(R)
            if not isinstance(t,torch.Tensor): t = torch.tensor(t)
        assert(R.shape[:-1]==t.shape and R.shape[-2:]==(3,3))
        R = R.float()
        t = t.float()
        pose = torch.cat([R,t[...,None]],dim=-1) # [...,3,4]
        assert(pose.shape[-2:]==(3,4))
        return pose

    def invert(self,pose,use_inverse=False):
        # invert a camera pose
        R,t = pose[...,:3],pose[...,3:]
        R_inv = R.inverse() if use_inverse else R.transpose(-1,-2)
        t_inv = (-R_inv@t)[...,0]
        pose_inv = self(R=R_inv,t=t_inv)
        return pose_inv

    def compose(self,pose_list):
        # compose a sequence of poses together
        # pose_new(x) = poseN o ... o pose2 o pose1(x)
        pose_new = pose_list[0]
        for pose in pose_list[1:]:
            pose_new = self.compose_pair(pose_new,pose)
        return pose_new

    def compose_pair(self,pose_a,pose_b):
        # pose_new(x) = pose_b o pose_a(x)
        R_a,t_a = pose_a[...,:3],pose_a[...,3:]
        R_b,t_b = pose_b[...,:3],pose_b[...,3:]
        R_new = R_b@R_a
        t_new = (R_b@t_a+t_b)[...,0]
        pose_new = self(R=R_new,t=t_new)
        return pose_new

class Lie():
    """
    Lie algebra for SO(3) and SE(3) operations in PyTorch
    """

    def so3_to_SO3(self,w): # [...,3]
        wx = self.skew_symmetric(w)
        theta = w.norm(dim=-1)[...,None,None]
        I = torch.eye(3,device=w.device,dtype=torch.float32)
        A = self.taylor_A(theta)
        B = self.taylor_B(theta)
        R = I+A*wx+B*wx@wx
        return R

    def SO3_to_so3(self,R,eps=1e-7): # [...,3,3]
        trace = R[...,0,0]+R[...,1,1]+R[...,2,2]
        theta = ((trace-1)/2).clamp(-1+eps,1-eps).acos_()[...,None,None]%np.pi # ln(R) will explode if theta==pi
        lnR = 1/(2*self.taylor_A(theta)+1e-8)*(R-R.transpose(-2,-1)) # FIXME: wei-chiu finds it weird
        w0,w1,w2 = lnR[...,2,1],lnR[...,0,2],lnR[...,1,0]
        w = torch.stack([w0,w1,w2],dim=-1)
        return w

    def se3_to_SE3(self,wu): # [...,3]
        w,u = wu.split([3,3],dim=-1)
        wx = self.skew_symmetric(w)
        theta = w.norm(dim=-1)[...,None,None]
        I = torch.eye(3,device=w.device,dtype=torch.float32)
        A = self.taylor_A(theta)
        B = self.taylor_B(theta)
        C = self.taylor_C(theta)
        R = I+A*wx+B*wx@wx
        V = I+B*wx+C*wx@wx
        Rt = torch.cat([R,(V@u[...,None])],dim=-1)
        return Rt

    def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4]
        R,t = Rt.split([3,1],dim=-1)
        w = self.SO3_to_so3(R)
        wx = self.skew_symmetric(w)
        theta = w.norm(dim=-1)[...,None,None]
        I = torch.eye(3,device=w.device,dtype=torch.float32)
        A = self.taylor_A(theta)
        B = self.taylor_B(theta)
        invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx
        u = (invV@t)[...,0]
        wu = torch.cat([w,u],dim=-1)
        return wu    

    def skew_symmetric(self,w):
        w0,w1,w2 = w.unbind(dim=-1)
        O = torch.zeros_like(w0)
        wx = torch.stack([torch.stack([O,-w2,w1],dim=-1),
                          torch.stack([w2,O,-w0],dim=-1),
                          torch.stack([-w1,w0,O],dim=-1)],dim=-2)
        return wx

    def taylor_A(self,x,nth=10):
        # Taylor expansion of sin(x)/x
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            if i>0: denom *= (2*i)*(2*i+1)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans
    def taylor_B(self,x,nth=10):
        # Taylor expansion of (1-cos(x))/x**2
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            denom *= (2*i+1)*(2*i+2)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans
    def taylor_C(self,x,nth=10):
        # Taylor expansion of (x-sin(x))/x**3
        ans = torch.zeros_like(x)
        denom = 1.
        for i in range(nth+1):
            denom *= (2*i+2)*(2*i+3)
            ans = ans+(-1)**i*x**(2*i)/denom
        return ans

class Quaternion():

    def q_to_R(self,q):
        # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
        qa,qb,qc,qd = q.unbind(dim=-1)
        R = torch.stack([torch.stack([1-2*(qc**2+qd**2),2*(qb*qc-qa*qd),2*(qa*qc+qb*qd)],dim=-1),
                         torch.stack([2*(qb*qc+qa*qd),1-2*(qb**2+qd**2),2*(qc*qd-qa*qb)],dim=-1),
                         torch.stack([2*(qb*qd-qa*qc),2*(qa*qb+qc*qd),1-2*(qb**2+qc**2)],dim=-1)],dim=-2)
        return R

    def R_to_q(self,R,eps=1e-8): # [B,3,3]
        # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion
        # FIXME: this function seems a bit problematic, need to double-check
        row0,row1,row2 = R.unbind(dim=-2)
        R00,R01,R02 = row0.unbind(dim=-1)
        R10,R11,R12 = row1.unbind(dim=-1)
        R20,R21,R22 = row2.unbind(dim=-1)
        t = R[...,0,0]+R[...,1,1]+R[...,2,2]
        r = (1+t+eps).sqrt()
        qa = 0.5*r
        qb = (R21-R12).sign()*0.5*(1+R00-R11-R22+eps).sqrt()
        qc = (R02-R20).sign()*0.5*(1-R00+R11-R22+eps).sqrt()
        qd = (R10-R01).sign()*0.5*(1-R00-R11+R22+eps).sqrt()
        q = torch.stack([qa,qb,qc,qd],dim=-1)
        for i,qi in enumerate(q):
            if torch.isnan(qi).any():
                K = torch.stack([torch.stack([R00-R11-R22,R10+R01,R20+R02,R12-R21],dim=-1),
                                 torch.stack([R10+R01,R11-R00-R22,R21+R12,R20-R02],dim=-1),
                                 torch.stack([R20+R02,R21+R12,R22-R00-R11,R01-R10],dim=-1),
                                 torch.stack([R12-R21,R20-R02,R01-R10,R00+R11+R22],dim=-1)],dim=-2)/3.0
                K = K[i]
                eigval,eigvec = torch.linalg.eigh(K)
                V = eigvec[:,eigval.argmax()]
                q[i] = torch.stack([V[3],V[0],V[1],V[2]])
        return q

    def invert(self,q):
        qa,qb,qc,qd = q.unbind(dim=-1)
        norm = q.norm(dim=-1,keepdim=True)
        q_inv = torch.stack([qa,-qb,-qc,-qd],dim=-1)/norm**2
        return q_inv

    def product(self,q1,q2): # [B,4]
        q1a,q1b,q1c,q1d = q1.unbind(dim=-1)
        q2a,q2b,q2c,q2d = q2.unbind(dim=-1)
        hamil_prod = torch.stack([q1a*q2a-q1b*q2b-q1c*q2c-q1d*q2d,
                                  q1a*q2b+q1b*q2a+q1c*q2d-q1d*q2c,
                                  q1a*q2c-q1b*q2d+q1c*q2a+q1d*q2b,
                                  q1a*q2d+q1b*q2c-q1c*q2b+q1d*q2a],dim=-1)
        return hamil_prod

pose = Pose()
lie = Lie()
quaternion = Quaternion()

def to_hom(X):
    # get homogeneous coordinates of the input
    X_hom = torch.cat([X,torch.ones_like(X[...,:1])],dim=-1)
    return X_hom

# basic operations of transforming 3D points between world/camera/image coordinates
def world2cam(X,pose): # [B,N,3]
    X_hom = to_hom(X)
    return X_hom@pose.transpose(-1,-2)
def cam2img(X,cam_intr):
    return X@cam_intr.transpose(-1,-2)
def img2cam(X,cam_intr):
    return X@cam_intr.inverse().transpose(-1,-2)
def cam2world(X,pose):
    X_hom = to_hom(X)
    pose_inv = Pose().invert(pose)
    return X_hom@pose_inv.transpose(-1,-2)

def angle_to_rotation_matrix(a,axis):
    # get the rotation matrix from Euler angle around specific axis
    roll = dict(X=1,Y=2,Z=0)[axis]
    O = torch.zeros_like(a)
    I = torch.ones_like(a)
    M = torch.stack([torch.stack([a.cos(),-a.sin(),O],dim=-1),
                     torch.stack([a.sin(),a.cos(),O],dim=-1),
                     torch.stack([O,O,I],dim=-1)],dim=-2)
    M = M.roll((roll,roll),dims=(-2,-1))
    return M

def get_center_and_ray(opt,pose,intr=None): # [HW,2]
    # given the intrinsic/extrinsic matrices, get the camera center and ray directions
    assert(opt.camera.model=="perspective")
    with torch.no_grad():
        # compute image coordinate grid
        y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5)
        x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5)
        Y,X = torch.meshgrid(y_range,x_range) # [H,W]
        xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]
    # compute center and ray
    batch_size = len(pose)
    xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2]
    grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3]
    center_3D = torch.zeros_like(grid_3D) # [B,HW,3]
    # transform from camera to world coordinates
    grid_3D = cam2world(grid_3D,pose) # [B,HW,3]
    center_3D = cam2world(center_3D,pose) # [B,HW,3]
    ray = grid_3D-center_3D # [B,HW,3]
    return center_3D,ray

def get_camera_cords_grid_3D(opt,batch_size,intr=None,ray_idx=None): # [HW,2]
    # given the intrinsic matrices, get the grid_3D
    assert(opt.camera.model=="perspective")
    with torch.no_grad():
        # compute image coordinate grid
        y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5)
        x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5)
        Y,X = torch.meshgrid(y_range,x_range) # [H,W]
        xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]
        # compute grid_3D
        xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2]
        grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3]
        if ray_idx is not None:
            # consider only subset of rays
            grid_3D = grid_3D[:,ray_idx]
    return grid_3D

def gather_camera_cords_grid_3D(opt,batch_size,intr=None,ray_idx=None): # [HW,2]
    # given the intrinsic matrices, get the grid_3D
    assert(opt.camera.model=="perspective")
    with torch.no_grad():
        # compute image coordinate grid
        y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5)
        x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5)
        Y,X = torch.meshgrid(y_range,x_range) # [H,W]
        xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]
        # compute grid_3D
        xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2]
        grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3]
        if ray_idx is not None:
            # consider only subset of rays
            grid_3D = torch.gather(grid_3D, 1, ray_idx[...,None].expand(-1,-1,3))
    return grid_3D

def get_3D_points_from_depth(opt,center,ray,depth,multi_samples=False):
    if multi_samples: center,ray = center[:,:,None],ray[:,:,None]
    # x = c+dv
    points_3D = center+ray*depth # [B,HW,3]/[B,HW,N,3]/[N,3]
    return points_3D

def convert_NDC(opt,center,ray,intr,near=1):
    # shift camera center (ray origins) to near plane (z=1)
    # (unlike conventional NDC, we assume the cameras are facing towards the +z direction)
    center = center+(near-center[...,2:])/ray[...,2:]*ray
    # projection
    cx,cy,cz = center.unbind(dim=-1) # [B,HW]
    rx,ry,rz = ray.unbind(dim=-1) # [B,HW]
    scale_x = intr[:,0,0]/intr[:,0,2] # [B]
    scale_y = intr[:,1,1]/intr[:,1,2] # [B]
    cnx = scale_x[:,None]*(cx/cz)
    cny = scale_y[:,None]*(cy/cz)
    cnz = 1-2*near/cz
    rnx = scale_x[:,None]*(rx/rz-cx/cz)
    rny = scale_y[:,None]*(ry/rz-cy/cz)
    rnz = 2*near/cz
    center_ndc = torch.stack([cnx,cny,cnz],dim=-1) # [B,HW,3]
    ray_ndc = torch.stack([rnx,rny,rnz],dim=-1) # [B,HW,3]
    return center_ndc,ray_ndc

def rotation_distance(R1,R2,eps=1e-7):
    # http://www.boris-belousov.net/2016/12/01/quat-dist/
    R_diff = R1@R2.transpose(-2,-1)
    trace = R_diff[...,0,0]+R_diff[...,1,1]+R_diff[...,2,2]
    angle = ((trace-1)/2).clamp(-1+eps,1-eps).acos_() # numerical stability near -1/+1
    return angle

def procrustes_analysis(X0,X1): # [N,3]
    # translation
    t0 = X0.mean(dim=0,keepdim=True)
    t1 = X1.mean(dim=0,keepdim=True)
    X0c = X0-t0
    X1c = X1-t1
    # scale
    s0 = (X0c**2).sum(dim=-1).mean().sqrt()
    s1 = (X1c**2).sum(dim=-1).mean().sqrt()
    X0cs = X0c/s0
    X1cs = X1c/s1
    # rotation (use double for SVD, float loses precision)
    U,S,V = (X0cs.t()@X1cs).double().svd(some=True)
    R = (U@V.t()).float()
    if R.det()<0: R[2] *= -1
    # align X1 to X0: X1to0 = (X1-t1)/s1@R.t()*s0+t0
    sim3 = edict(t0=t0[0],t1=t1[0],s0=s0,s1=s1,R=R)
    return sim3

def get_novel_view_poses(opt,pose_anchor,N=60,scale=1):
    # create circular viewpoints (small oscillations)
    theta = torch.arange(N)/N*2*np.pi
    R_x = angle_to_rotation_matrix((theta.sin()*0.05).asin(),"X")
    R_y = angle_to_rotation_matrix((theta.cos()*0.05).asin(),"Y")
    pose_rot = pose(R=R_y@R_x)
    pose_shift = pose(t=[0,0,-4*scale])
    pose_shift2 = pose(t=[0,0,3.8*scale])
    pose_oscil = pose.compose([pose_shift,pose_rot,pose_shift2])
    pose_novel = pose.compose([pose_oscil,pose_anchor.cpu()[None]])
    return pose_novel


================================================
FILE: data/base.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import torch.multiprocessing as mp
import PIL
import tqdm
import threading,queue
from easydict import EasyDict as edict

import util
from util import log,debug

class Dataset(torch.utils.data.Dataset):

    def __init__(self,opt,split="train"):
        super().__init__()
        self.opt = opt
        self.split = split
        self.augment = split=="train" and opt.data.augment
        # define image sizes
        if opt.data.center_crop is not None:
            self.crop_H = int(self.raw_H*opt.data.center_crop)
            self.crop_W = int(self.raw_W*opt.data.center_crop)
        else: self.crop_H,self.crop_W = self.raw_H,self.raw_W
        if not opt.H or not opt.W:
            opt.H,opt.W = self.crop_H,self.crop_W

    def setup_loader(self,opt,shuffle=False,drop_last=False):
        loader = torch.utils.data.DataLoader(self,
            batch_size=opt.batch_size or 1,
            num_workers=opt.data.num_workers,
            shuffle=shuffle,
            drop_last=drop_last,
            pin_memory=False, # spews warnings in PyTorch 1.9 but should be True in general
        )
        print("number of samples: {}".format(len(self)))
        return loader

    def get_list(self,opt):
        raise NotImplementedError

    def preload_worker(self,data_list,load_func,q,lock,idx_tqdm):
        while True:
            idx = q.get()
            data_list[idx] = load_func(self.opt,idx)
            with lock:
                idx_tqdm.update()
            q.task_done()

    def preload_threading(self,opt,load_func,data_str="images"):
        data_list = [None]*len(self)
        q = queue.Queue(maxsize=len(self))
        idx_tqdm = tqdm.tqdm(range(len(self)),desc="preloading {}".format(data_str),leave=False)
        for i in range(len(self)): q.put(i)
        lock = threading.Lock()
        for ti in range(opt.data.num_workers):
            t = threading.Thread(target=self.preload_worker,
                                 args=(data_list,load_func,q,lock,idx_tqdm),daemon=True)
            t.start()
        q.join()
        idx_tqdm.close()
        assert(all(map(lambda x: x is not None,data_list)))
        return data_list

    def __getitem__(self,idx):
        raise NotImplementedError

    def get_image(self,opt,idx):
        raise NotImplementedError

    def generate_augmentation(self,opt):
        brightness = opt.data.augment.brightness or 0.
        contrast = opt.data.augment.contrast or 0.
        saturation = opt.data.augment.saturation or 0.
        hue = opt.data.augment.hue or 0.
        color_jitter = torchvision.transforms.ColorJitter.get_params(
            brightness=(1-brightness,1+brightness),
            contrast=(1-contrast,1+contrast),
            saturation=(1-saturation,1+saturation),
            hue=(-hue,hue),
        )
        aug = edict(
            color_jitter=color_jitter,
            flip=np.random.randn()>0 if opt.data.augment.hflip else False,
            rot_angle=(np.random.rand()*2-1)*opt.data.augment.rotate if opt.data.augment.rotate else 0,
        )
        return aug

    def preprocess_image(self,opt,image,aug=None):
        if aug is not None:
            image = self.apply_color_jitter(opt,image,aug.color_jitter)
            image = torchvision_F.hflip(image) if aug.flip else image
            image = image.rotate(aug.rot_angle,resample=PIL.Image.BICUBIC)
        # center crop
        if opt.data.center_crop is not None:
            self.crop_H = int(self.raw_H*opt.data.center_crop)
            self.crop_W = int(self.raw_W*opt.data.center_crop)
            image = torchvision_F.center_crop(image,(self.crop_H,self.crop_W))
        else: self.crop_H,self.crop_W = self.raw_H,self.raw_W
        # resize
        if opt.data.image_size[0] is not None:
            image = image.resize((opt.W,opt.H))
        image = torchvision_F.to_tensor(image)
        return image

    def preprocess_camera(self,opt,intr,pose,aug=None):
        intr,pose = intr.clone(),pose.clone()
        # center crop
        intr[0,2] -= (self.raw_W-self.crop_W)/2
        intr[1,2] -= (self.raw_H-self.crop_H)/2
        # resize
        intr[0] *= opt.W/self.crop_W
        intr[1] *= opt.H/self.crop_H
        return intr,pose

    def apply_color_jitter(self,opt,image,color_jitter):
        mode = image.mode
        if mode!="L":
            chan = image.split()
            rgb = PIL.Image.merge("RGB",chan[:3])
            rgb = color_jitter(rgb)
            rgb_chan = rgb.split()
            image = PIL.Image.merge(mode,rgb_chan+chan[3:])
        return image

    def __len__(self):
        return len(self.list)
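`preload_threading` above drains a work queue from several daemon threads, each writing its result into a shared list at a distinct index, then blocks on `q.join()` until everything is loaded. A minimal self-contained sketch of the same pattern, with a hypothetical `load_fn` standing in for the dataset's `get_image`:

```python
import queue
import threading

def preload(load_fn, n, num_workers=4):
    # shared output list; each worker writes to a distinct index
    data = [None] * n
    q = queue.Queue()
    for i in range(n):
        q.put(i)

    def worker():
        while True:
            idx = q.get()
            data[idx] = load_fn(idx)
            q.task_done()

    for _ in range(num_workers):
        threading.Thread(target=worker, daemon=True).start()
    q.join()  # returns once every queued index has been processed
    assert all(x is not None for x in data)
    return data

loaded = preload(lambda i: i * i, 10)  # stand-in for get_image(opt, idx)
```

Because each worker writes to a different list slot, no lock is needed around `data` itself; the lock in the real code only serializes the tqdm progress updates.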


================================================
FILE: data/blender.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import PIL
import imageio
from easydict import EasyDict as edict
import json
import pickle

from . import base
import camera
from util import log,debug

class Dataset(base.Dataset):

    def __init__(self,opt,split="train",subset=None):
        self.raw_H,self.raw_W = 800,800
        super().__init__(opt,split)
        self.root = opt.data.root or "data/blender"
        self.path = "{}/{}".format(self.root,opt.data.scene)
        # load/parse metadata
        meta_fname = "{}/transforms_{}.json".format(self.path,split)
        with open(meta_fname) as file:
            self.meta = json.load(file)
        self.list = self.meta["frames"]
        self.focal = 0.5*self.raw_W/np.tan(0.5*self.meta["camera_angle_x"])
        if subset: self.list = self.list[:subset]
        # preload dataset
        if opt.data.preload:
            self.images = self.preload_threading(opt,self.get_image)
            self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras")

    def prefetch_all_data(self,opt):
        assert(not opt.data.augment)
        # pre-iterate through all samples and group together
        self.all = torch.utils.data._utils.collate.default_collate([s for s in self])

    def get_all_camera_poses(self,opt):
        pose_raw_all = [torch.tensor(f["transform_matrix"],dtype=torch.float32) for f in self.list]
        pose_canon_all = torch.stack([self.parse_raw_camera(opt,p) for p in pose_raw_all],dim=0)
        return pose_canon_all

    def __getitem__(self,idx):
        opt = self.opt
        sample = dict(idx=idx)
        aug = self.generate_augmentation(opt) if self.augment else None
        image = self.images[idx] if opt.data.preload else self.get_image(opt,idx)
        image = self.preprocess_image(opt,image,aug=aug)
        intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx)
        intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug)
        sample.update(
            image=image,
            intr=intr,
            pose=pose,
        )
        return sample

    def get_image(self,opt,idx):
        image_fname = "{}/{}.png".format(self.path,self.list[idx]["file_path"])
        image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption....
        return image

    def preprocess_image(self,opt,image,aug=None):
        image = super().preprocess_image(opt,image,aug=aug)
        rgb,mask = image[:3],image[3:]
        if opt.data.bgcolor is not None:
            rgb = rgb*mask+opt.data.bgcolor*(1-mask)
        return rgb

    def get_camera(self,opt,idx):
        intr = torch.tensor([[self.focal,0,self.raw_W/2],
                             [0,self.focal,self.raw_H/2],
                             [0,0,1]]).float()
        pose_raw = torch.tensor(self.list[idx]["transform_matrix"],dtype=torch.float32)
        pose = self.parse_raw_camera(opt,pose_raw)
        return intr,pose

    def parse_raw_camera(self,opt,pose_raw):
        pose_flip = camera.pose(R=torch.diag(torch.tensor([1,-1,-1])))
        pose = camera.pose.compose([pose_flip,pose_raw[:3]])
        pose = camera.pose.invert(pose)
        return pose
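`self.focal` in the constructor above converts the metadata's horizontal field of view into a focal length in pixels via the pinhole relation tan(fov/2) = (W/2)/f. A small sketch; the FOV value below is the one the NeRF synthetic scenes typically ship with, quoted here only for illustration:

```python
import math

def focal_from_fov(width_px, camera_angle_x):
    # pinhole camera: tan(camera_angle_x / 2) = (width_px / 2) / focal
    return 0.5 * width_px / math.tan(0.5 * camera_angle_x)

f = focal_from_fov(800, 0.6911112070083618)  # ~1111.1 px for 800px-wide frames
```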


================================================
FILE: data/iphone.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import PIL
import imageio
from easydict import EasyDict as edict
import json
import pickle

from . import base
import camera
from util import log,debug

class Dataset(base.Dataset):

    def __init__(self,opt,split="train",subset=None):
        self.raw_H,self.raw_W = 1080,1920
        super().__init__(opt,split)
        self.root = opt.data.root or "data/iphone"
        self.path = "{}/{}".format(self.root,opt.data.scene)
        self.path_image = "{}/images".format(self.path)
        # self.list = sorted(os.listdir(self.path_image),key=lambda f: int(f.split(".")[0]))
        self.list = os.listdir(self.path_image)
        # manually split train/val subsets
        num_val_split = int(len(self)*opt.data.val_ratio)
        self.list = self.list[:-num_val_split] if split=="train" else self.list[-num_val_split:]
        if subset: self.list = self.list[:subset]
        # preload dataset
        if opt.data.preload:
            self.images = self.preload_threading(opt,self.get_image)
            self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras")

    def prefetch_all_data(self,opt):
        assert(not opt.data.augment)
        # pre-iterate through all samples and group together
        self.all = torch.utils.data._utils.collate.default_collate([s for s in self])

    def get_all_camera_poses(self,opt):
        # poses are unknown, so just return some dummy poses (identity transform)
        return camera.pose(t=torch.zeros(len(self),3))

    def __getitem__(self,idx):
        opt = self.opt
        sample = dict(idx=idx)
        aug = self.generate_augmentation(opt) if self.augment else None
        image = self.images[idx] if opt.data.preload else self.get_image(opt,idx)
        image = self.preprocess_image(opt,image,aug=aug)
        intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx)
        intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug)
        sample.update(
            image=image,
            intr=intr,
            pose=pose,
        )
        return sample

    def get_image(self,opt,idx):
        image_fname = "{}/{}".format(self.path_image,self.list[idx])
        image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption....
        return image

    def get_camera(self,opt,idx):
        self.focal = self.raw_W*4.2/(12.8/2.55)
        intr = torch.tensor([[self.focal,0,self.raw_W/2],
                             [0,self.focal,self.raw_H/2],
                             [0,0,1]]).float()
        pose = camera.pose(t=torch.zeros(3)) # dummy pose, won't be used
        return intr,pose
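`get_camera` above hard-codes the intrinsics from a lens spec: a metric focal length is scaled by the sensor's pixels-per-millimetre. The constants (a 4.2 mm lens and 12.8/2.55 ≈ 5.02 mm, which the code appears to treat as the sensor width) are the ones baked in above:

```python
def focal_pixels(width_px, focal_mm, sensor_width_mm):
    # pixels-per-mm across the sensor, times the metric focal length
    return width_px * focal_mm / sensor_width_mm

f = focal_pixels(1920, 4.2, 12.8 / 2.55)  # matches self.focal above
```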


================================================
FILE: data/llff.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import PIL
import imageio
from easydict import EasyDict as edict
import json
import pickle

from . import base
import camera
from util import log,debug

class Dataset(base.Dataset):

    def __init__(self,opt,split="train",subset=None):
        self.raw_H,self.raw_W = 3024,4032
        super().__init__(opt,split)
        self.root = opt.data.root or "data/llff"
        self.path = "{}/{}".format(self.root,opt.data.scene)
        self.path_image = "{}/images".format(self.path)
        image_fnames = sorted(os.listdir(self.path_image))
        poses_raw,bounds = self.parse_cameras_and_bounds(opt)
        self.list = list(zip(image_fnames,poses_raw,bounds))
        # manually split train/val subsets
        num_val_split = int(len(self)*opt.data.val_ratio)
        self.list = self.list[:-num_val_split] if split=="train" else self.list[-num_val_split:]
        if subset: self.list = self.list[:subset]
        # preload dataset
        if opt.data.preload:
            self.images = self.preload_threading(opt,self.get_image)
            self.cameras = self.preload_threading(opt,self.get_camera,data_str="cameras")

    def prefetch_all_data(self,opt):
        assert(not opt.data.augment)
        # pre-iterate through all samples and group together
        self.all = torch.utils.data._utils.collate.default_collate([s for s in self])

    def parse_cameras_and_bounds(self,opt):
        fname = "{}/poses_bounds.npy".format(self.path)
        data = torch.tensor(np.load(fname),dtype=torch.float32)
        # parse cameras (intrinsics and poses)
        cam_data = data[:,:-2].view([-1,3,5]) # [N,3,5]
        poses_raw = cam_data[...,:4] # [N,3,4]
        poses_raw[...,0],poses_raw[...,1] = poses_raw[...,1],-poses_raw[...,0]
        raw_H,raw_W,self.focal = cam_data[0,:,-1]
        assert(self.raw_H==raw_H and self.raw_W==raw_W)
        # parse depth bounds
        bounds = data[:,-2:] # [N,2]
        scale = 1./(bounds.min()*0.75) # not sure how this was determined
        poses_raw[...,3] *= scale
        bounds *= scale
        # roughly center camera poses
        poses_raw = self.center_camera_poses(opt,poses_raw)
        return poses_raw,bounds

    def center_camera_poses(self,opt,poses):
        # compute average pose
        center = poses[...,3].mean(dim=0)
        v1 = torch_F.normalize(poses[...,1].mean(dim=0),dim=0)
        v2 = torch_F.normalize(poses[...,2].mean(dim=0),dim=0)
        v0 = v1.cross(v2)
        pose_avg = torch.stack([v0,v1,v2,center],dim=-1)[None] # [1,3,4]
        # apply inverse of averaged pose
        poses = camera.pose.compose([poses,camera.pose.invert(pose_avg)])
        return poses

    def get_all_camera_poses(self,opt):
        pose_raw_all = [tup[1] for tup in self.list]
        pose_all = torch.stack([self.parse_raw_camera(opt,p) for p in pose_raw_all],dim=0)
        return pose_all

    def __getitem__(self,idx):
        opt = self.opt
        sample = dict(idx=idx)
        aug = self.generate_augmentation(opt) if self.augment else None
        image = self.images[idx] if opt.data.preload else self.get_image(opt,idx)
        image = self.preprocess_image(opt,image,aug=aug)
        intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx)
        intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug)
        sample.update(
            image=image,
            intr=intr,
            pose=pose,
        )
        return sample

    def get_image(self,opt,idx):
        image_fname = "{}/{}".format(self.path_image,self.list[idx][0])
        image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption....
        return image

    def get_camera(self,opt,idx):
        intr = torch.tensor([[self.focal,0,self.raw_W/2],
                             [0,self.focal,self.raw_H/2],
                             [0,0,1]]).float()
        pose_raw = self.list[idx][1]
        pose = self.parse_raw_camera(opt,pose_raw)
        return intr,pose

    def parse_raw_camera(self,opt,pose_raw):
        pose_flip = camera.pose(R=torch.diag(torch.tensor([1,-1,-1])))
        pose = camera.pose.compose([pose_flip,pose_raw[:3]])
        pose = camera.pose.invert(pose)
        pose = camera.pose.compose([pose_flip,pose])
        return pose
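`parse_cameras_and_bounds` above reads `poses_bounds.npy`, an [N,17] array: the first 15 entries of each row are a flattened 3x5 matrix (a 3x4 camera-to-world pose with an extra column holding [H, W, focal]), and the last 2 are the near/far depth bounds. A plain-Python sketch of the same slicing on one synthetic row:

```python
def split_pose_row(row):
    # row: 17 floats = flattened 3x5 camera matrix + [near, far]
    assert len(row) == 17
    cam = [row[i * 5:(i + 1) * 5] for i in range(3)]  # 3x5, row-major
    pose = [r[:4] for r in cam]   # 3x4 camera-to-world pose
    hwf = [r[4] for r in cam]     # last column: [H, W, focal]
    bounds = row[15:]             # [near, far] depth bounds
    return pose, hwf, bounds

row = list(range(15)) + [2.0, 6.0]  # toy values, not a real camera
pose, hwf, bounds = split_pose_row(row)
```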


================================================
FILE: evaluate.py
================================================
import numpy as np
import os,sys,time
import torch
import importlib

import options
from util import log

import warnings
warnings.filterwarnings('ignore')

def main():

    log.process(os.getpid())
    log.title("[{}] (PyTorch code for evaluating NeRF/BARF/L2G_NeRF)".format(sys.argv[0]))

    opt_cmd = options.parse_arguments(sys.argv[1:])
    opt = options.set(opt_cmd=opt_cmd)

    with torch.cuda.device(opt.device):

        model = importlib.import_module("model.{}".format(opt.model))
        m = model.Model(opt)

        m.load_dataset(opt,eval_split="test")
        m.build_networks(opt)

        if opt.model in ["barf", "l2g_nerf"]:
            m.generate_videos_pose(opt)

        m.restore_checkpoint(opt)
        if opt.data.dataset in ["blender","llff"]:
            m.evaluate_full(opt)
        m.generate_videos_synthesis(opt)

if __name__=="__main__":
    main()
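`evaluate.py` picks the model implementation at runtime with `importlib.import_module("model.{}".format(opt.model))`, so e.g. `--model=l2g_nerf` resolves to `model/l2g_nerf.py`. The same stdlib mechanism, shown with `json` as a stand-in module:

```python
import importlib

# dotted-name lookup at runtime; "json" stands in for "model.l2g_nerf"
mod = importlib.import_module("json")
s = mod.dumps({"model": "l2g_nerf"})
```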


================================================
FILE: external/pohsun_ssim/LICENSE.txt
================================================
MIT


================================================
FILE: external/pohsun_ssim/README.md
================================================
# pytorch-ssim

### Differentiable structural similarity (SSIM) index.
![einstein](https://raw.githubusercontent.com/Po-Hsun-Su/pytorch-ssim/master/einstein.png) ![Max_ssim](https://raw.githubusercontent.com/Po-Hsun-Su/pytorch-ssim/master/max_ssim.gif)

## Installation
1. Clone this repo.
2. Copy "pytorch_ssim" folder in your project.

## Example
### basic usage
```python
import pytorch_ssim
import torch
from torch.autograd import Variable

img1 = Variable(torch.rand(1, 1, 256, 256))
img2 = Variable(torch.rand(1, 1, 256, 256))

if torch.cuda.is_available():
    img1 = img1.cuda()
    img2 = img2.cuda()

print(pytorch_ssim.ssim(img1, img2))

ssim_loss = pytorch_ssim.SSIM(window_size = 11)

print(ssim_loss(img1, img2))

```
### maximize ssim
```python
import pytorch_ssim
import torch
from torch.autograd import Variable
from torch import optim
import cv2
import numpy as np

npImg1 = cv2.imread("einstein.png")

img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0
img2 = torch.rand(img1.size())

if torch.cuda.is_available():
    img1 = img1.cuda()
    img2 = img2.cuda()


img1 = Variable( img1,  requires_grad=False)
img2 = Variable( img2, requires_grad = True)


# Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True)
ssim_value = pytorch_ssim.ssim(img1, img2).item() # .data[0] fails on 0-dim tensors in modern PyTorch
print("Initial ssim:", ssim_value)

# Module: pytorch_ssim.SSIM(window_size = 11, size_average = True)
ssim_loss = pytorch_ssim.SSIM()

optimizer = optim.Adam([img2], lr=0.01)

while ssim_value < 0.95:
    optimizer.zero_grad()
    ssim_out = -ssim_loss(img1, img2)
    ssim_value = -ssim_out.item()
    print(ssim_value)
    ssim_out.backward()
    optimizer.step()

```

## Reference
https://ece.uwaterloo.ca/~z70wang/research/ssim/


================================================
FILE: external/pohsun_ssim/max_ssim.py
================================================
import pytorch_ssim
import torch
from torch.autograd import Variable
from torch import optim
import cv2
import numpy as np

npImg1 = cv2.imread("einstein.png")

img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0
img2 = torch.rand(img1.size())

if torch.cuda.is_available():
    img1 = img1.cuda()
    img2 = img2.cuda()


img1 = Variable( img1,  requires_grad=False)
img2 = Variable( img2, requires_grad = True)


# Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True)
ssim_value = pytorch_ssim.ssim(img1, img2).item() # .data[0] fails on 0-dim tensors in modern PyTorch
print("Initial ssim:", ssim_value)

# Module: pytorch_ssim.SSIM(window_size = 11, size_average = True)
ssim_loss = pytorch_ssim.SSIM()

optimizer = optim.Adam([img2], lr=0.01)

while ssim_value < 0.95:
    optimizer.zero_grad()
    ssim_out = -ssim_loss(img1, img2)
    ssim_value = -ssim_out.item()
    print(ssim_value)
    ssim_out.backward()
    optimizer.step()


================================================
FILE: external/pohsun_ssim/pytorch_ssim/__init__.py
================================================
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)
            
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)
            
            self.window = window
            self.channel = channel


        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size = 11, size_average = True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)
    
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)
    
    return _ssim(img1, img2, window, window_size, channel, size_average)
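`create_window` above builds the 2-D SSIM filter as the outer product of the 1-D kernel from `gaussian()` with itself. A torch-free sketch of that 1-D kernel, normalized to sum to one:

```python
from math import exp

def gaussian_1d(window_size, sigma):
    # same formula as gaussian() above, without torch
    g = [exp(-(x - window_size // 2) ** 2 / (2.0 * sigma ** 2))
         for x in range(window_size)]
    s = sum(g)
    return [v / s for v in g]

w = gaussian_1d(11, 1.5)  # the default 11-tap SSIM window
```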


================================================
FILE: external/pohsun_ssim/setup.cfg
================================================
[metadata]
description-file = README.md


================================================
FILE: external/pohsun_ssim/setup.py
================================================
from distutils.core import setup
setup(
  name = 'pytorch_ssim',
  packages = ['pytorch_ssim'], # this must be the same as the name above
  version = '0.1',
  description = 'Differentiable structural similarity (SSIM) index',
  author = 'Po-Hsun (Evan) Su',
  author_email = 'evan.pohsun.su@gmail.com',
  url = 'https://github.com/Po-Hsun-Su/pytorch-ssim', # use the URL to the github repo
  download_url = 'https://github.com/Po-Hsun-Su/pytorch-ssim/archive/0.1.tar.gz', # I'll explain this in a second
  keywords = ['pytorch', 'image-processing', 'deep-learning'], # arbitrary keywords
  classifiers = [],
)


================================================
FILE: extract_mesh.py
================================================
"""Extracts a 3D mesh from a pretrained model using marching cubes."""

import importlib
import sys

import numpy as np
import options
import torch
import tqdm
import trimesh
import mcubes

from util import log,debug

opt_cmd = options.parse_arguments(sys.argv[1:])
opt = options.set(opt_cmd=opt_cmd)

with torch.cuda.device(opt.device),torch.no_grad():

    model = importlib.import_module("model.{}".format(opt.model))
    m = model.Model(opt)

    m.load_dataset(opt)
    m.build_networks(opt)
    m.restore_checkpoint(opt)

    t = torch.linspace(*opt.trimesh.range,opt.trimesh.res+1) # the best range might vary from model to model
    query = torch.stack(torch.meshgrid(t,t,t),dim=-1)
    query_flat = query.view(-1,3)

    density_all = []
    for i in tqdm.trange(0,len(query_flat),opt.trimesh.chunk_size,leave=False):
        points = query_flat[None,i:i+opt.trimesh.chunk_size].to(opt.device)
        ray_unit = torch.zeros_like(points) # dummy ray to comply with interface, not used
        _,density_samples = m.graph.nerf.forward(opt,points,ray_unit=ray_unit,mode=None)
        density_all.append(density_samples.cpu())
    density_all = torch.cat(density_all,dim=1)[0]
    density_all = density_all.view(*query.shape[:-1]).numpy()

    log.info("running marching cubes...")
    vertices,triangles = mcubes.marching_cubes(density_all,opt.trimesh.thres)
    vertices_centered = vertices/opt.trimesh.res-0.5
    mesh = trimesh.Trimesh(vertices_centered,triangles)

    obj_fname = "{}/mesh.obj".format(opt.output_path)
    log.info("saving 3D mesh to {}...".format(obj_fname))
    mesh.export(obj_fname)
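The loop above evaluates the density network over the dense grid in fixed-size chunks so that only `chunk_size` query points are resident at once. A minimal sketch of that chunking, with a cheap stand-in for the network:

```python
def query_in_chunks(points, fn, chunk_size):
    # mirrors the range(0, len(query_flat), chunk_size) loop above
    out = []
    for i in range(0, len(points), chunk_size):
        out.extend(fn(p) for p in points[i:i + chunk_size])
    return out

pts = [(x, y, z) for x in range(4) for y in range(4) for z in range(4)]
dens = query_in_chunks(pts, lambda p: sum(p), chunk_size=10)  # 64 points, 7 chunks
```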


================================================
FILE: model/barf.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import tqdm
from easydict import EasyDict as edict
import visdom
import matplotlib.pyplot as plt

import util,util_vis
from util import log,debug
from . import nerf
import camera

# ============================ main engine for training and evaluation ============================

class Model(nerf.Model):

    def __init__(self,opt):
        super().__init__(opt)

    def build_networks(self,opt):
        super().build_networks(opt)
        if opt.camera.noise:
            # pre-generate synthetic pose perturbation
            so3_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_r
            t_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_t
            self.graph.pose_noise = torch.cat([camera.lie.so3_to_SO3(so3_noise),t_noise[...,None]],dim=-1) # [...,3,4]
            
        self.graph.se3_refine = torch.nn.Embedding(len(self.train_data),6).to(opt.device)
        torch.nn.init.zeros_(self.graph.se3_refine.weight)

        pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)
        # add synthetic pose perturbation to all training data
        if opt.data.dataset=="blender":
            pose = pose_GT
            if opt.camera.noise:
                pose = camera.pose.compose([pose, self.graph.pose_noise])
        else: pose = self.graph.pose_eye[None].repeat(len(self.train_data),1,1)
        # use Embedding so it can be checkpointed
        self.graph.optimised_training_poses = torch.nn.Embedding(len(self.train_data),12,_weight=pose.view(-1,12)).to(opt.device) 

        idx_range = torch.arange(len(self.train_data),dtype=torch.long,device=opt.device)
        idx_X,idx_Y = torch.meshgrid(idx_range,idx_range)
        self.graph.idx_grid = torch.stack([idx_X,idx_Y],dim=-1).view(-1,2)

    def setup_optimizer(self,opt):
        super().setup_optimizer(opt)
        optimizer = getattr(torch.optim,opt.optim.algo)
        self.optim_pose = optimizer([dict(params=self.graph.se3_refine.parameters(),lr=opt.optim.lr_pose)])
        # set up scheduler
        if opt.optim.sched_pose:
            scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched_pose.type)
            if opt.optim.lr_pose_end:
                assert(opt.optim.sched_pose.type=="ExponentialLR")
                opt.optim.sched_pose.gamma = (opt.optim.lr_pose_end/opt.optim.lr_pose)**(1./opt.max_iter)
            kwargs = { k:v for k,v in opt.optim.sched_pose.items() if k!="type" }
            self.sched_pose = scheduler(self.optim_pose,**kwargs)

    def train_iteration(self,opt,var,loader):
        self.optim_pose.zero_grad()
        if opt.optim.warmup_pose:
            # simple linear warmup of pose learning rate
            self.optim_pose.param_groups[0]["lr_orig"] = self.optim_pose.param_groups[0]["lr"] # cache the original learning rate
            self.optim_pose.param_groups[0]["lr"] *= min(1,self.it/opt.optim.warmup_pose)
        loss = super().train_iteration(opt,var,loader)
        self.optim_pose.step()
        if opt.optim.warmup_pose:
            self.optim_pose.param_groups[0]["lr"] = self.optim_pose.param_groups[0]["lr_orig"] # reset learning rate
        if opt.optim.sched_pose: self.sched_pose.step()
        self.graph.nerf.progress.data.fill_(self.it/opt.max_iter)
        if opt.nerf.fine_sampling:
            self.graph.nerf_fine.progress.data.fill_(self.it/opt.max_iter)
        return loss

    @torch.no_grad()
    def validate(self,opt,ep=None):
        pose,pose_GT = self.get_all_training_poses(opt)
        _,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT)
        super().validate(opt,ep=ep)

    @torch.no_grad()
    def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
        super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)
        if split=="train":
            # log learning rate
            lr = self.optim_pose.param_groups[0]["lr"]
            self.tb.add_scalar("{0}/{1}".format(split,"lr_pose"),lr,step)
        # compute pose error
        if split=="train" and opt.data.dataset in ["blender","llff"]:
            pose,pose_GT = self.get_all_training_poses(opt)
            pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT)
            error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT)
            self.tb.add_scalar("{0}/error_R".format(split),error.R.mean(),step)
            self.tb.add_scalar("{0}/error_t".format(split),error.t.mean(),step)

    @torch.no_grad()
    def visualize(self,opt,var,step=0,split="train"):
        super().visualize(opt,var,step=step,split=split)
        if opt.visdom:
            if split=="val":
                pose,pose_GT = self.get_all_training_poses(opt)
                pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT)
                util_vis.vis_cameras(opt,self.vis,step=step,poses=[pose_aligned,pose_GT])

    @torch.no_grad()
    def get_all_training_poses(self,opt):
        # get ground-truth (canonical) camera poses
        pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)
        pose = self.graph.optimised_training_poses.weight.data.detach().clone().view(-1,3,4)
        return pose,pose_GT

    @torch.no_grad()
    def prealign_cameras(self,opt,pose,pose_GT):
        # compute 3D similarity transform via Procrustes analysis
        center = torch.zeros(1,1,3,device=opt.device)
        center_pred = camera.cam2world(center,pose)[:,0] # [N,3]
        center_GT = camera.cam2world(center,pose_GT)[:,0] # [N,3]
        try:
            sim3 = camera.procrustes_analysis(center_GT,center_pred)
        except:
            print("warning: SVD did not converge...")
            sim3 = edict(t0=0,t1=0,s0=1,s1=1,R=torch.eye(3,device=opt.device))
        # align the camera poses
        center_aligned = (center_pred-sim3.t1)/sim3.s1@sim3.R.t()*sim3.s0+sim3.t0
        R_aligned = pose[...,:3]@sim3.R.t()
        t_aligned = (-R_aligned@center_aligned[...,None])[...,0]
        pose_aligned = camera.pose(R=R_aligned,t=t_aligned)
        return pose_aligned,sim3

    @torch.no_grad()
    def evaluate_camera_alignment(self,opt,pose_aligned,pose_GT):
        # measure errors in rotation and translation
        R_aligned,t_aligned = pose_aligned.split([3,1],dim=-1)
        R_GT,t_GT = pose_GT.split([3,1],dim=-1)
        R_error = camera.rotation_distance(R_aligned,R_GT)
        t_error = (t_aligned-t_GT)[...,0].norm(dim=-1)
        error = edict(R=R_error,t=t_error)
        return error

    @torch.no_grad()
    def evaluate_full(self,opt):
        self.graph.eval()
        # evaluate rotation/translation
        pose,pose_GT = self.get_all_training_poses(opt)
        pose_aligned,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT)
        error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT)
        print("--------------------------")
        print("rot:   {:8.3f}".format(np.rad2deg(error.R.mean().cpu())))
        print("trans: {:10.5f}".format(error.t.mean()))
        print("--------------------------")
        # dump numbers
        quant_fname = "{}/quant_pose.txt".format(opt.output_path)
        with open(quant_fname,"w") as file:
            for i,(err_R,err_t) in enumerate(zip(error.R,error.t)):
                file.write("{} {} {}\n".format(i,err_R.item(),err_t.item()))
        # evaluate novel view synthesis
        super().evaluate_full(opt)

    @torch.enable_grad()
    def evaluate_test_time_photometric_optim(self,opt,var):
        # use another se3 Parameter to absorb the remaining pose errors
        var.se3_refine_test = torch.nn.Parameter(torch.zeros(1,6,device=opt.device))
        optimizer = getattr(torch.optim,opt.optim.algo)
        optim_pose = optimizer([dict(params=[var.se3_refine_test],lr=opt.optim.lr_pose)])
        iterator = tqdm.trange(opt.optim.test_iter,desc="test-time optim.",leave=False,position=1)
        for it in iterator:
            optim_pose.zero_grad()
            var.pose_refine_test = camera.lie.se3_to_SE3(var.se3_refine_test)
            var = self.graph.forward(opt,var,mode="test-optim")
            loss = self.graph.compute_loss(opt,var,mode="test-optim")
            loss = self.summarize_loss(opt,var,loss)
            loss.all.backward()
            optim_pose.step()
            iterator.set_postfix(loss="{:.3f}".format(loss.all))
        return var

    @torch.no_grad()
    def generate_videos_pose(self,opt):
        self.graph.eval()
        fig = plt.figure(figsize=(10,10) if opt.data.dataset=="blender" else (16,8))
        cam_path = "{}/poses".format(opt.output_path)
        os.makedirs(cam_path,exist_ok=True)
        ep_list = []
        for ep in range(0,opt.max_iter+1,opt.freq.ckpt):
            # load checkpoint (0 is random init)
            if ep!=0:
                try: util.restore_checkpoint(opt,self,resume=ep)
                except: continue
            # get the camera poses
            pose,pose_ref = self.get_all_training_poses(opt)
            if opt.data.dataset in ["blender","llff"]:
                pose_aligned,_ = self.prealign_cameras(opt,pose,pose_ref)
                pose_aligned,pose_ref = pose_aligned.detach().cpu(),pose_ref.detach().cpu()
                dict(
                    blender=util_vis.plot_save_poses_blender,
                    llff=util_vis.plot_save_poses,
                )[opt.data.dataset](opt,fig,pose_aligned,pose_ref=pose_ref,path=cam_path,ep=ep)
            else:
                pose = pose.detach().cpu()
                util_vis.plot_save_poses(opt,fig,pose,pose_ref=None,path=cam_path,ep=ep)
            ep_list.append(ep)
        plt.close()
        # write videos
        print("writing videos...")
        list_fname = "{}/temp.list".format(cam_path)
        with open(list_fname,"w") as file:
            for ep in ep_list: file.write("file {}.png\n".format(ep))
        cam_vid_fname = "{}/poses.mp4".format(opt.output_path)
        os.system("ffmpeg -y -r 4 -f concat -i {0} -pix_fmt yuv420p {1} >/dev/null 2>&1".format(list_fname,cam_vid_fname))
        os.remove(list_fname)
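The test-time optimization above (`evaluate_test_time_photometric_optim`) refines a 6-vector in se(3) and maps it to a camera pose with `camera.lie.se3_to_SE3`. A minimal NumPy sketch of that exponential map, assuming the standard Rodrigues formula plus the left Jacobian for the translation part (function name and `eps` tolerance are illustrative, not part of this repo):

```python
import numpy as np

def se3_to_SE3(wu, eps=1e-8):
    """Exponential map from a 6-vector (w: rotation, u: translation) to a [3,4] pose."""
    w, u = wu[:3], wu[3:]
    theta = np.linalg.norm(w)
    # skew-symmetric matrix of w
    W = np.array([[0., -w[2], w[1]],
                  [w[2], 0., -w[0]],
                  [-w[1], w[0], 0.]])
    if theta < eps:
        R, V = np.eye(3), np.eye(3)
    else:
        A = np.sin(theta)/theta
        B = (1.-np.cos(theta))/theta**2
        C = (1.-A)/theta**2
        R = np.eye(3) + A*W + B*(W@W)  # Rodrigues' formula
        V = np.eye(3) + B*W + C*(W@W)  # left Jacobian, maps u to the translation column
    return np.concatenate([R, (V@u)[:, None]], axis=1)
```

A zero 6-vector maps to the identity pose, which is why `var.se3_refine_test` can be initialized to zeros and composed onto the aligned test pose without perturbing it at iteration 0.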

# ============================ computation graph for forward/backprop ============================

class Graph(nerf.Graph):

    def __init__(self,opt):
        super().__init__(opt)
        self.nerf = NeRF(opt)
        if opt.nerf.fine_sampling:
            self.nerf_fine = NeRF(opt)
        self.pose_eye = torch.eye(3,4).to(opt.device)

    def forward(self,opt,var,mode=None):
        # rescale the size of the scene conditioned on the current pose estimates
        if opt.data.dataset=="blender":
            depth_min,depth_max = opt.nerf.depth.range
            position = camera.Pose().invert(self.optimised_training_poses.weight.data.detach().clone().view(-1,3,4))[...,-1]
            diameter = ((position[self.idx_grid[...,0]]-position[self.idx_grid[...,1]]).norm(dim=-1)).max()
            depth_min_new = (depth_min/(depth_max+depth_min))*diameter
            depth_max_new = (depth_max/(depth_max+depth_min))*diameter
            opt.nerf.depth.range = [depth_min_new, depth_max_new]
        # render images
        batch_size = len(var.idx)
        pose = self.get_pose(opt,var,mode=mode)
        if opt.nerf.rand_rays and mode in ["train","test-optim"]:
            # sample random rays for optimization
            var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size]
            ret = self.render(opt,pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1]
        else:
            # render full image (process in slices)
            ret = self.render_by_slices(opt,pose,intr=var.intr,mode=mode) if opt.nerf.rand_rays else \
                  self.render(opt,pose,intr=var.intr,mode=mode) # [B,HW,3],[B,HW,1]
        var.update(ret)
        return var

    def get_pose(self,opt,var,mode=None):
        if mode=="train":
            # add the pre-generated pose perturbations
            if opt.data.dataset=="blender":
                if opt.camera.noise:
                    var.pose_noise = self.pose_noise[var.idx]
                    pose = camera.pose.compose([var.pose, var.pose_noise])
                else: pose = var.pose
            else: pose = self.pose_eye
            # add learnable pose correction
            var.se3_refine = self.se3_refine.weight[var.idx]
            pose_refine = camera.lie.se3_to_SE3(var.se3_refine)
            pose = camera.pose.compose([pose_refine, pose])
            self.optimised_training_poses.weight.data = pose.detach().clone().view(-1,12)
        elif mode in ["val","eval","test-optim"]:
            # align test pose to refined coordinate system (up to sim3)
            sim3 = self.sim3
            center = torch.zeros(1,1,3,device=opt.device)
            center = camera.cam2world(center,var.pose)[:,0] # [N,3]
            center_aligned = (center-sim3.t0)/sim3.s0@sim3.R*sim3.s1+sim3.t1
            R_aligned = var.pose[...,:3]@self.sim3.R
            t_aligned = (-R_aligned@center_aligned[...,None])[...,0]
            pose = camera.pose(R=R_aligned,t=t_aligned)
            # additionally factorize the remaining pose imperfection
            if opt.optim.test_photo and mode!="val":
                pose = camera.pose.compose([var.pose_refine_test, pose])
        else: pose = var.pose
        return pose

class NeRF(nerf.NeRF):

    def __init__(self,opt):
        super().__init__(opt)
        self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it can be checkpointed

    def positional_encoding(self,opt,input,L): # [B,...,N]
        input_enc = super().positional_encoding(opt,input,L=L) # [B,...,2NL]
        # coarse-to-fine: smoothly mask positional encoding for BARF
        if opt.barf_c2f is not None:
            # set weights for different frequency bands
            start,end = opt.barf_c2f
            alpha = (self.progress.data-start)/(end-start)*L
            k = torch.arange(L,dtype=torch.float32,device=opt.device)
            weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2
            # apply weights
            shape = input_enc.shape
            input_enc = (input_enc.view(-1,L)*weight).view(*shape)
        return input_enc
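The `positional_encoding` override above implements BARF's coarse-to-fine schedule: frequency band k is gated by a cosine-eased weight driven by `self.progress`. A standalone NumPy sketch of just the weight computation (function name illustrative):

```python
import numpy as np

def c2f_weights(progress, start, end, L):
    """Per-frequency-band weights for the coarse-to-fine positional encoding."""
    # alpha sweeps linearly from 0 to L as progress goes from start to end
    alpha = (progress-start)/(end-start)*L
    k = np.arange(L)
    # band k ramps on with a cosine easing while alpha crosses [k, k+1]
    return (1-np.cos(np.clip(alpha-k, 0., 1.)*np.pi))/2
```

Before `start` all bands are off (the raw input still passes through untouched); after `end` every band is fully on, recovering the standard positional encoding.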


================================================
FILE: model/base.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import torch.utils.tensorboard
import visdom
import importlib
import tqdm
from easydict import EasyDict as edict

import util,util_vis
from util import log,debug

# ============================ main engine for training and evaluation ============================

class Model():

    def __init__(self,opt):
        super().__init__()
        os.makedirs(opt.output_path,exist_ok=True)

    def load_dataset(self,opt,eval_split="val"):
        data = importlib.import_module("data.{}".format(opt.data.dataset))
        log.info("loading training data...")
        self.train_data = data.Dataset(opt,split="train",subset=opt.data.train_sub)
        self.train_loader = self.train_data.setup_loader(opt,shuffle=True)
        log.info("loading test data...")
        if opt.data.val_on_test: eval_split = "test"
        self.test_data = data.Dataset(opt,split=eval_split,subset=opt.data.val_sub)
        self.test_loader = self.test_data.setup_loader(opt,shuffle=False)

    def build_networks(self,opt):
        graph = importlib.import_module("model.{}".format(opt.model))
        log.info("building networks...")
        self.graph = graph.Graph(opt).to(opt.device)

    def setup_optimizer(self,opt):
        log.info("setting up optimizers...")
        optimizer = getattr(torch.optim,opt.optim.algo)
        self.optim = optimizer([dict(params=self.graph.parameters(),lr=opt.optim.lr)])
        # set up scheduler
        if opt.optim.sched:
            scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)
            kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" }
            self.sched = scheduler(self.optim,**kwargs)

    def restore_checkpoint(self,opt):
        epoch_start,iter_start = None,None
        if opt.resume:
            log.info("resuming from previous checkpoint...")
            epoch_start,iter_start = util.restore_checkpoint(opt,self,resume=opt.resume)
        elif opt.load is not None:
            log.info("loading weights from checkpoint {}...".format(opt.load))
            epoch_start,iter_start = util.restore_checkpoint(opt,self,load_name=opt.load)
        else:
            log.info("initializing weights from scratch...")
        self.epoch_start = epoch_start or 0
        self.iter_start = iter_start or 0

    def setup_visualizer(self,opt):
        log.info("setting up visualizers...")
        if opt.tb:
            self.tb = torch.utils.tensorboard.SummaryWriter(log_dir=opt.output_path,flush_secs=10)
        if opt.visdom:
            # check if visdom server is running
            is_open = util.check_socket_open(opt.visdom.server,opt.visdom.port)
            retry = None
            while not is_open:
                retry = input("visdom port ({}) not open, retry? (y/n) ".format(opt.visdom.port))
                if retry not in ["y","n"]: continue
                if retry=="y":
                    is_open = util.check_socket_open(opt.visdom.server,opt.visdom.port)
                else: break
            self.vis = visdom.Visdom(server=opt.visdom.server,port=opt.visdom.port,env=opt.group)

    def train(self,opt):
        # before training
        log.title("TRAINING START")
        self.timer = edict(start=time.time(),it_mean=None)
        self.it = self.iter_start
        # training
        if self.iter_start==0: self.validate(opt,ep=0)
        for self.ep in range(self.epoch_start,opt.max_epoch):
            self.train_epoch(opt)
        # after training
        if opt.tb:
            self.tb.flush()
            self.tb.close()
        if opt.visdom: self.vis.close()
        log.title("TRAINING DONE")

    def train_epoch(self,opt):
        # before train epoch
        self.graph.train()
        # train epoch
        loader = tqdm.tqdm(self.train_loader,desc="training epoch {}".format(self.ep+1),leave=False)
        for batch in loader:
            # train iteration
            var = edict(batch)
            var = util.move_to_device(var,opt.device)
            loss = self.train_iteration(opt,var,loader)
        # after train epoch
        lr = self.sched.get_last_lr()[0] if opt.optim.sched else opt.optim.lr
        log.loss_train(opt,self.ep+1,lr,loss.all,self.timer)
        if opt.optim.sched: self.sched.step()
        if (self.ep+1)%opt.freq.val==0: self.validate(opt,ep=self.ep+1)
        if (self.ep+1)%opt.freq.ckpt==0: self.save_checkpoint(opt,ep=self.ep+1,it=self.it)

    def train_iteration(self,opt,var,loader):
        # before train iteration
        self.timer.it_start = time.time()
        # train iteration
        self.optim.zero_grad()
        var = self.graph.forward(opt,var,mode="train")
        loss = self.graph.compute_loss(opt,var,mode="train")
        loss = self.summarize_loss(opt,var,loss)
        loss.all.backward()
        self.optim.step()
        # after train iteration
        if (self.it+1)%opt.freq.scalar==0: self.log_scalars(opt,var,loss,step=self.it+1,split="train")
        if (self.it+1)%opt.freq.vis==0: self.visualize(opt,var,step=self.it+1,split="train")
        self.it += 1
        loader.set_postfix(it=self.it,loss="{:.3f}".format(loss.all))
        self.timer.it_end = time.time()
        util.update_timer(opt,self.timer,self.ep,len(loader))
        return loss

    def summarize_loss(self,opt,var,loss):
        loss_all = 0.
        assert("all" not in loss)
        # weigh losses
        for key in loss:
            assert(key in opt.loss_weight)
            assert(loss[key].shape==())
            if opt.loss_weight[key] is not None:
                assert not torch.isinf(loss[key]),"loss {} is Inf".format(key)
                assert not torch.isnan(loss[key]),"loss {} is NaN".format(key)
                loss_all += 10**float(opt.loss_weight[key])*loss[key]
        loss.update(all=loss_all)
        return loss

    @torch.no_grad()
    def validate(self,opt,ep=None):
        self.graph.eval()
        loss_val = edict()
        loader = tqdm.tqdm(self.test_loader,desc="validating",leave=False)
        for it,batch in enumerate(loader):
            var = edict(batch)
            var = util.move_to_device(var,opt.device)
            var = self.graph.forward(opt,var,mode="val")
            loss = self.graph.compute_loss(opt,var,mode="val")
            loss = self.summarize_loss(opt,var,loss)
            for key in loss:
                loss_val.setdefault(key,0.)
                loss_val[key] += loss[key]*len(var.idx)
            loader.set_postfix(loss="{:.3f}".format(loss.all))
            if it==0: self.visualize(opt,var,step=ep,split="val")
        for key in loss_val: loss_val[key] /= len(self.test_data)
        self.log_scalars(opt,var,loss_val,step=ep,split="val")
        log.loss_val(opt,loss_val.all)

    @torch.no_grad()
    def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
        for key,value in loss.items():
            if key=="all": continue
            if opt.loss_weight[key] is not None:
                self.tb.add_scalar("{0}/loss_{1}".format(split,key),value,step)
        if metric is not None:
            for key,value in metric.items():
                self.tb.add_scalar("{0}/{1}".format(split,key),value,step)

    @torch.no_grad()
    def visualize(self,opt,var,step=0,split="train"):
        raise NotImplementedError

    def save_checkpoint(self,opt,ep=0,it=0,latest=False):
        util.save_checkpoint(opt,self,ep=ep,it=it,latest=latest)
        if not latest:
            log.info("checkpoint saved: ({0}) {1}, epoch {2} (iteration {3})".format(opt.group,opt.name,ep,it))

# ============================ computation graph for forward/backprop ============================

class Graph(torch.nn.Module):

    def __init__(self,opt):
        super().__init__()

    def forward(self,opt,var,mode=None):
        raise NotImplementedError
        return var

    def compute_loss(self,opt,var,mode=None):
        loss = edict()
        raise NotImplementedError
        return loss

    def L1_loss(self,pred,label=0):
        loss = (pred.contiguous()-label).abs()
        return loss.mean()
    def MSE_loss(self,pred,label=0):
        loss = (pred.contiguous()-label)**2
        return loss.mean()
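`summarize_loss` in the `Model` class above combines the per-term losses on a log10 scale: a configured weight w contributes `10**w` times its loss term, and a weight of `None` disables the term entirely. A dict-based sketch of the same rule (function and argument names are illustrative):

```python
def summarize_loss(losses, weights):
    """Weighted sum of loss terms; weight w scales its term by 10**w, None disables it."""
    total = 0.
    for key, value in losses.items():
        w = weights.get(key)
        if w is not None:
            total += 10**float(w)*value
    return total
```

This is why the YAML configs specify loss weights as exponents (e.g. 0 means a factor of 1, -1 means 0.1) rather than raw multipliers.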


================================================
FILE: model/l2g_nerf.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import tqdm
from easydict import EasyDict as edict
import visdom
import matplotlib.pyplot as plt

import util,util_vis
from util import log,debug
from . import nerf
import camera

import roma
# ============================ main engine for training and evaluation ============================

class Model(nerf.Model):

    def __init__(self,opt):
        super().__init__(opt)

    def build_networks(self,opt):
        super().build_networks(opt)
        if opt.camera.noise:
            # pre-generate synthetic pose perturbation
            so3_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_r
            t_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_t
            self.graph.pose_noise = torch.cat([camera.lie.so3_to_SO3(so3_noise),t_noise[...,None]],dim=-1) # [...,3,4]

        self.graph.warp_embedding = torch.nn.Embedding(len(self.train_data),opt.arch.embedding_dim).to(opt.device)
        self.graph.warp_mlp = localWarp(opt).to(opt.device)

        pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)
        # add synthetic pose perturbation to all training data
        if opt.data.dataset=="blender":
            pose = pose_GT
            if opt.camera.noise:
                pose = camera.pose.compose([pose, self.graph.pose_noise])
        else: pose = self.graph.pose_eye[None].repeat(len(self.train_data),1,1)
        # use Embedding so it can be checkpointed
        self.graph.optimised_training_poses = torch.nn.Embedding(len(self.train_data),12,_weight=pose.view(-1,12)).to(opt.device)

        # auto near/far for blender dataset
        if opt.data.dataset=="blender":
            idx_range = torch.arange(len(self.train_data),dtype=torch.long,device=opt.device)
            idx_X,idx_Y = torch.meshgrid(idx_range,idx_range)
            self.graph.idx_grid = torch.stack([idx_X,idx_Y],dim=-1).view(-1,2)

        if opt.error_map_size:
            self.graph.error_map = torch.ones([len(self.train_data), opt.error_map_size*opt.error_map_size], dtype=torch.float).to(opt.device)

    def setup_optimizer(self,opt):
        super().setup_optimizer(opt)
        optimizer = getattr(torch.optim,opt.optim.algo)
        self.optim_pose = optimizer([dict(params=self.graph.warp_embedding.parameters(),lr=opt.optim.lr_pose)])
        self.optim_pose.add_param_group(dict(params=self.graph.warp_mlp.parameters(),lr=opt.optim.lr_pose))
        # set up scheduler
        if opt.optim.sched_pose:
            scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched_pose.type)
            if opt.optim.lr_pose_end:
                assert(opt.optim.sched_pose.type=="ExponentialLR")
                opt.optim.sched_pose.gamma = (opt.optim.lr_pose_end/opt.optim.lr_pose)**(1./opt.max_iter)
            kwargs = { k:v for k,v in opt.optim.sched_pose.items() if k!="type" }
            self.sched_pose = scheduler(self.optim_pose,**kwargs)

    def train_iteration(self,opt,var,loader):
        self.optim_pose.zero_grad()
        if opt.optim.warmup_pose:
            # simple linear warmup of pose learning rate
            self.optim_pose.param_groups[0]["lr_orig"] = self.optim_pose.param_groups[0]["lr"] # cache the original learning rate
            self.optim_pose.param_groups[0]["lr"] *= min(1,self.it/opt.optim.warmup_pose)
        loss = super().train_iteration(opt,var,loader)
        self.optim_pose.step()
        if opt.optim.warmup_pose:
            self.optim_pose.param_groups[0]["lr"] = self.optim_pose.param_groups[0]["lr_orig"] # reset learning rate
        if opt.optim.sched_pose: self.sched_pose.step()
        self.graph.nerf.progress.data.fill_(self.it/opt.max_iter)
        if opt.nerf.fine_sampling:
            self.graph.nerf_fine.progress.data.fill_(self.it/opt.max_iter)
        return loss

    @torch.no_grad()
    def validate(self,opt,ep=None):
        pose,pose_GT = self.get_all_training_poses(opt)
        _,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT)
        super().validate(opt,ep=ep)

    @torch.no_grad()
    def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
        super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)
        if split=="train":
            # log learning rate
            lr = self.optim_pose.param_groups[0]["lr"]
            self.tb.add_scalar("{0}/{1}".format(split,"lr_pose"),lr,step)
        # compute pose error
        if split=="train" and opt.data.dataset in ["blender","llff"]:
            pose,pose_GT = self.get_all_training_poses(opt)
            pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT)
            error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT)
            self.tb.add_scalar("{0}/error_R".format(split),error.R.mean(),step)
            self.tb.add_scalar("{0}/error_t".format(split),error.t.mean(),step)

    @torch.no_grad()
    def visualize(self,opt,var,step=0,split="train"):
        super().visualize(opt,var,step=step,split=split)
        if opt.visdom:
            if split=="val":
                pose,pose_GT = self.get_all_training_poses(opt)
                pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT)
                util_vis.vis_cameras(opt,self.vis,step=step,poses=[pose_aligned,pose_GT])

    @torch.no_grad()
    def get_all_training_poses(self,opt):
        # get ground-truth (canonical) camera poses
        pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)
        pose = self.graph.optimised_training_poses.weight.data.detach().clone().view(-1,3,4)
        return pose,pose_GT

    @torch.no_grad()
    def prealign_cameras(self,opt,pose,pose_GT):
        # compute 3D similarity transform via Procrustes analysis
        center = torch.zeros(1,1,3,device=opt.device)
        center_pred = camera.cam2world(center,pose)[:,0] # [N,3]
        center_GT = camera.cam2world(center,pose_GT)[:,0] # [N,3]
        try:
            sim3 = camera.procrustes_analysis(center_GT,center_pred)
        except Exception:
            print("warning: SVD did not converge...")
            sim3 = edict(t0=0,t1=0,s0=1,s1=1,R=torch.eye(3,device=opt.device))
        # align the camera poses
        center_aligned = (center_pred-sim3.t1)/sim3.s1@sim3.R.t()*sim3.s0+sim3.t0
        R_aligned = pose[...,:3]@sim3.R.t()
        t_aligned = (-R_aligned@center_aligned[...,None])[...,0]
        pose_aligned = camera.pose(R=R_aligned,t=t_aligned)
        return pose_aligned,sim3

    @torch.no_grad()
    def evaluate_camera_alignment(self,opt,pose_aligned,pose_GT):
        # measure errors in rotation and translation
        R_aligned,t_aligned = pose_aligned.split([3,1],dim=-1)
        R_GT,t_GT = pose_GT.split([3,1],dim=-1)
        R_error = camera.rotation_distance(R_aligned,R_GT)
        t_error = (t_aligned-t_GT)[...,0].norm(dim=-1)
        error = edict(R=R_error,t=t_error)
        return error

    @torch.no_grad()
    def evaluate_full(self,opt):
        self.graph.eval()
        # evaluate rotation/translation
        pose,pose_GT = self.get_all_training_poses(opt)
        pose_aligned,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT)
        error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT)
        print("--------------------------")
        print("rot:   {:8.3f}".format(np.rad2deg(error.R.mean().cpu())))
        print("trans: {:10.5f}".format(error.t.mean()))
        print("--------------------------")
        # dump numbers
        quant_fname = "{}/quant_pose.txt".format(opt.output_path)
        with open(quant_fname,"w") as file:
            for i,(err_R,err_t) in enumerate(zip(error.R,error.t)):
                file.write("{} {} {}\n".format(i,err_R.item(),err_t.item()))
        # evaluate novel view synthesis
        super().evaluate_full(opt)

    @torch.enable_grad()
    def evaluate_test_time_photometric_optim(self,opt,var):
        # use another se3 Parameter to absorb the remaining pose errors
        var.se3_refine_test = torch.nn.Parameter(torch.zeros(1,6,device=opt.device))
        optimizer = getattr(torch.optim,opt.optim.algo)
        optim_pose = optimizer([dict(params=[var.se3_refine_test],lr=opt.optim.lr_pose)])
        iterator = tqdm.trange(opt.optim.test_iter,desc="test-time optim.",leave=False,position=1)
        for it in iterator:
            optim_pose.zero_grad()
            var.pose_refine_test = camera.lie.se3_to_SE3(var.se3_refine_test)
            var = self.graph.forward(opt,var,mode="test-optim")
            loss = self.graph.compute_loss(opt,var,mode="test-optim")
            loss = self.summarize_loss(opt,var,loss)
            loss.all.backward()
            optim_pose.step()
            iterator.set_postfix(loss="{:.3f}".format(loss.all))
        return var

    @torch.no_grad()
    def generate_videos_pose(self,opt):
        self.graph.eval()
        fig = plt.figure(figsize=(10,10) if opt.data.dataset=="blender" else (16,8))
        cam_path = "{}/poses".format(opt.output_path)
        os.makedirs(cam_path,exist_ok=True)
        ep_list = []
        for ep in range(0,opt.max_iter+1,opt.freq.ckpt):
            # load checkpoint (0 is random init)
            if ep!=0:
                try: util.restore_checkpoint(opt,self,resume=ep)
                except Exception: continue
            # get the camera poses
            pose,pose_ref = self.get_all_training_poses(opt)
            if opt.data.dataset in ["blender","llff"]:
                pose_aligned,_ = self.prealign_cameras(opt,pose,pose_ref)
                pose_aligned,pose_ref = pose_aligned.detach().cpu(),pose_ref.detach().cpu()
                dict(
                    blender=util_vis.plot_save_poses_blender,
                    llff=util_vis.plot_save_poses,
                )[opt.data.dataset](opt,fig,pose_aligned,pose_ref=pose_ref,path=cam_path,ep=ep)
            else:
                pose = pose.detach().cpu()
                util_vis.plot_save_poses(opt,fig,pose,pose_ref=None,path=cam_path,ep=ep)
            ep_list.append(ep)
        plt.close()
        # write videos
        print("writing videos...")
        list_fname = "{}/temp.list".format(cam_path)
        with open(list_fname,"w") as file:
            for ep in ep_list: file.write("file {}.png\n".format(ep))
        cam_vid_fname = "{}/poses.mp4".format(opt.output_path)
        os.system("ffmpeg -y -r 4 -f concat -i {0} -pix_fmt yuv420p {1} >/dev/null 2>&1".format(list_fname,cam_vid_fname))
        os.remove(list_fname)

# ============================ computation graph for forward/backprop ============================

class Graph(nerf.Graph):

    def __init__(self,opt):
        super().__init__(opt)
        self.nerf = NeRF(opt)
        if opt.nerf.fine_sampling:
            self.nerf_fine = NeRF(opt)
        self.pose_eye = torch.eye(3,4).to(opt.device)

    def get_pose(self,opt,var,mode=None):
        if mode=="train":
            # add the pre-generated pose perturbations
            if opt.data.dataset=="blender":
                if opt.camera.noise:
                    var.pose_noise = self.pose_noise[var.idx]
                    pose = camera.pose.compose([var.pose, var.pose_noise])
                else: pose = var.pose
            else: pose = self.pose_eye[None]
            # add learnable pose correction
            batch_size = len(var.idx)

            if opt.error_map_size:
                num_points = var.ray_idx.shape[1]
                camera_cords_grid_3D = camera.gather_camera_cords_grid_3D(opt,batch_size,intr=var.intr,ray_idx=var.ray_idx).detach()
            else:
                num_points = len(var.ray_idx)
                camera_cords_grid_3D = camera.get_camera_cords_grid_3D(opt,batch_size,intr=var.intr,ray_idx=var.ray_idx).detach()

            camera_cords_grid_2D = camera_cords_grid_3D[...,:2]
            embedding = self.warp_embedding.weight[var.idx,None,:].expand(-1,num_points,-1)
            local_se3_refine = self.warp_mlp(opt,torch.cat((camera_cords_grid_2D,embedding),dim=-1))
            local_pose_refine = camera.lie.se3_to_SE3(local_se3_refine)
            local_pose = camera.pose.compose([local_pose_refine, pose[:,None,...]])
            return local_pose

        elif mode in ["val","eval","test-optim"]:
            # align test pose to refined coordinate system (up to sim3)
            sim3 = self.sim3
            center = torch.zeros(1,1,3,device=opt.device)
            center = camera.cam2world(center,var.pose)[:,0] # [N,3]
            center_aligned = (center-sim3.t0)/sim3.s0@sim3.R*sim3.s1+sim3.t1
            R_aligned = var.pose[...,:3]@self.sim3.R
            t_aligned = (-R_aligned@center_aligned[...,None])[...,0]
            pose = camera.pose(R=R_aligned,t=t_aligned)
            if opt.optim.test_photo and mode!="val":
                pose = camera.pose.compose([var.pose_refine_test, pose])

        else: pose = var.pose
        return pose

    def forward(self,opt,var,mode=None):
        # rescale the size of the scene conditioned on the current pose estimates
        if opt.data.dataset=="blender":
            depth_min,depth_max = opt.nerf.depth.range
            position = camera.Pose().invert(self.optimised_training_poses.weight.data.detach().clone().view(-1,3,4))[...,-1]
            diameter = ((position[self.idx_grid[...,0]]-position[self.idx_grid[...,1]]).norm(dim=-1)).max()
            depth_min_new = (depth_min/(depth_max+depth_min))*diameter
            depth_max_new = (depth_max/(depth_max+depth_min))*diameter
            opt.nerf.depth.range = [depth_min_new, depth_max_new]

        # render images
        batch_size = len(var.idx)
        if opt.nerf.rand_rays and mode in ["train"]:
            # sample rays for optimization
            if opt.error_map_size:
                sample_weight = self.error_map + 2*self.error_map.mean(-1,keepdim=True) # 1/3 importance + 2/3 random
                var.ray_idx_coarse = torch.multinomial(sample_weight, opt.nerf.rand_rays//batch_size, replacement=False) # [B, N], but in [0, opt.error_map_size*opt.error_map_size)
                inds_x, inds_y = var.ray_idx_coarse // opt.error_map_size, var.ray_idx_coarse % opt.error_map_size # `//` will throw a warning in torch 1.10... anyway.
                sx, sy = opt.H / opt.error_map_size, opt.W / opt.error_map_size
                inds_x = (inds_x * sx + torch.rand(batch_size, opt.nerf.rand_rays//batch_size, device=opt.device) * sx).long().clamp(max=opt.H - 1)
                inds_y = (inds_y * sy + torch.rand(batch_size, opt.nerf.rand_rays//batch_size, device=opt.device) * sy).long().clamp(max=opt.W - 1)
                var.ray_idx = inds_x * opt.W + inds_y
            else:
                var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size] # 3/3 random

            local_pose = self.get_pose(opt,var,mode=mode)
            ret = self.local_render(opt,local_pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1]
        elif opt.nerf.rand_rays and mode in ["test-optim"]:
            # sample random rays for optimization
            var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size]
            pose = self.get_pose(opt,var,mode=mode)
            ret = self.render(opt,pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1]
        else:
            # render full image (process in slices)
            pose = self.get_pose(opt,var,mode=mode)
            ret = self.render_by_slices(opt,pose,intr=var.intr,mode=mode) if opt.nerf.rand_rays else \
                  self.render(opt,pose,intr=var.intr,mode=mode) # [B,HW,3],[B,HW,1]
        var.update(ret)
        return var

    def compute_loss(self,opt,var,mode=None):
        loss = edict()
        batch_size = len(var.idx)
        image = var.image.view(batch_size,3,opt.H*opt.W).permute(0,2,1)
        if opt.nerf.rand_rays and mode in ["train","test-optim"]:
            if opt.error_map_size:
                image = torch.gather(image, 1, var.ray_idx[...,None].expand(-1,-1,3))
            else:
                image = image[:,var.ray_idx]

        # compute image losses
        if opt.loss_weight.render is not None:
            render_error = ((var.rgb-image)**2).mean(-1)
            loss.render = render_error.mean()
            if mode in ["train"] and opt.error_map_size:
                ema_error = 0.1 * torch.gather(self.error_map, 1, var.ray_idx_coarse) + 0.9 * render_error.detach()
                self.error_map.scatter_(1, var.ray_idx_coarse, ema_error)

        if opt.loss_weight.render_fine is not None:
            assert(opt.nerf.fine_sampling)
            loss.render_fine = self.MSE_loss(var.rgb_fine,image)

        # global alignment
        if mode in ["train"]:
            source = torch.cat((var.camera_grid_3D,var.camera_center),dim=1)
            target = torch.cat((var.grid_3D,var.center),dim=1)
            R_global, t_global = roma.rigid_points_registration(target, source)
            svd_poses = torch.cat((R_global,t_global[...,None]),-1)
            self.optimised_training_poses.weight.data = svd_poses.detach().clone().view(-1,12)
            if opt.loss_weight.global_alignment is not None:
                loss.global_alignment = self.MSE_loss(target,camera.cam2world(source,svd_poses))

        return loss

    def local_render(self,opt,local_pose,intr=None,ray_idx=None,mode=None):
        batch_size = len(local_pose)
        if opt.error_map_size:
            camera_grid_3D = camera.gather_camera_cords_grid_3D(opt,batch_size,intr=intr,ray_idx=ray_idx).detach()
        else:
            camera_grid_3D = camera.get_camera_cords_grid_3D(opt,batch_size,intr=intr,ray_idx=ray_idx).detach()

        camera_center = torch.zeros_like(camera_grid_3D) # [B,HW,3]
        grid_3D = camera.cam2world(camera_grid_3D[...,None,:],local_pose)[...,0,:] # [B,HW,3]
        center = camera.cam2world(camera_center[...,None,:],local_pose)[...,0,:] # [B,HW,3]
        ray = grid_3D-center # [B,HW,3]
        ret = edict(camera_grid_3D=camera_grid_3D, camera_center=camera_center, grid_3D=grid_3D, center=center) # [B,HW,3] each, used for global alignment
        if opt.camera.ndc:
            # convert center/ray representations to NDC
            center,ray = camera.convert_NDC(opt,center,ray,intr=intr)
        # render with main MLP
        depth_samples = self.sample_depth(opt,batch_size,num_rays=ray.shape[1]) # [B,HW,N,1]
        rgb_samples,density_samples = self.nerf.forward_samples(opt,center,ray,depth_samples,mode=mode)
        rgb,depth,opacity,prob = self.nerf.composite(opt,ray,rgb_samples,density_samples,depth_samples)
        ret.update(rgb=rgb,depth=depth,opacity=opacity) # [B,HW,K]
        # render with fine MLP from coarse MLP
        if opt.nerf.fine_sampling:
            with torch.no_grad():
                # resample depth according to coarse empirical distribution
                depth_samples_fine = self.sample_depth_from_pdf(opt,pdf=prob[...,0]) # [B,HW,Nf,1]
                depth_samples = torch.cat([depth_samples,depth_samples_fine],dim=2) # [B,HW,N+Nf,1]
                depth_samples = depth_samples.sort(dim=2).values
            rgb_samples,density_samples = self.nerf_fine.forward_samples(opt,center,ray,depth_samples,mode=mode)
            rgb_fine,depth_fine,opacity_fine,_ = self.nerf_fine.composite(opt,ray,rgb_samples,density_samples,depth_samples)
            ret.update(rgb_fine=rgb_fine,depth_fine=depth_fine,opacity_fine=opacity_fine) # [B,HW,K]
        return ret

class NeRF(nerf.NeRF):

    def __init__(self,opt):
        super().__init__(opt)
        self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it can be checkpointed

    def positional_encoding(self,opt,input,L): # [B,...,N]
        input_enc = super().positional_encoding(opt,input,L=L) # [B,...,2NL]
        # coarse-to-fine: smoothly mask positional encoding for BARF
        if opt.barf_c2f is not None:
            # set weights for different frequency bands
            start,end = opt.barf_c2f
            alpha = (self.progress.data-start)/(end-start)*L
            k = torch.arange(L,dtype=torch.float32,device=opt.device)
            weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2
            # apply weights
            shape = input_enc.shape
            input_enc = (input_enc.view(-1,L)*weight).view(*shape)
        return input_enc

class localWarp(torch.nn.Module):
    def __init__(self, opt):
        super().__init__()
        # point-wise se3 prediction
        input_2D_dim = 2
        self.mlp_warp = torch.nn.ModuleList()
        L = util.get_layer_dims(opt.arch.layers_warp)
        for li,(k_in,k_out) in enumerate(L):
            if li==0: k_in = input_2D_dim+opt.arch.embedding_dim
            if li in opt.arch.skip_warp: k_in += input_2D_dim+opt.arch.embedding_dim
            linear = torch.nn.Linear(k_in,k_out)
            self.mlp_warp.append(linear)

    def forward(self,opt,uvf):
        feat = uvf
        for li,layer in enumerate(self.mlp_warp):
            if li in opt.arch.skip_warp: feat = torch.cat([feat,uvf],dim=-1)
            feat = layer(feat)
            if li!=len(self.mlp_warp)-1:
                feat = torch_F.relu(feat)
        warp = feat
        return warp
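
# The global-alignment loss in Graph.compute_loss above relies on
# roma.rigid_points_registration, which solves an orthogonal-Procrustes
# (Kabsch) problem: find the single rigid transform that best explains all
# per-ray local warps. A minimal NumPy sketch of that solve follows;
# `rigid_registration` is a hypothetical standalone helper for illustration,
# not the roma API, and assumes row-stacked point sets.

```python
import numpy as np

def rigid_registration(source, target):
    # Kabsch/Procrustes: find rotation R and translation t minimizing
    # sum_i || R @ source[i] + t - target[i] ||^2 (points stacked in rows).
    src_c = source - source.mean(axis=0)
    tgt_c = target - target.mean(axis=0)
    U, _, Vt = np.linalg.svd(tgt_c.T @ src_c)  # SVD of the cross-covariance
    d = np.sign(np.linalg.det(U @ Vt))         # guard against reflections
    D = np.diag([1.0] * (source.shape[1] - 1) + [d])
    R = U @ D @ Vt
    t = target.mean(axis=0) - R @ source.mean(axis=0)
    return R, t
```

# With perfect correspondences the solve is exact; with the noisy per-pixel
# warps produced by localWarp, the SVD returns the least-squares rigid fit
# that the global_alignment MSE then pulls the local warps toward.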

================================================
FILE: model/l2g_planar.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import tqdm
from easydict import EasyDict as edict
import PIL
import PIL.Image,PIL.ImageDraw
import imageio

import util,util_vis
from util import log,debug
from . import base
import warp
import roma
from kornia.geometry.homography import find_homography_dlt
# ============================ main engine for training and evaluation ============================

class Model(base.Model):

    def __init__(self,opt):
        super().__init__(opt)
        opt.H_crop,opt.W_crop = opt.data.patch_crop

    def load_dataset(self,opt,eval_split=None):
        image_raw = PIL.Image.open(opt.data.image_fname).convert('RGB')
        self.image_raw = torchvision_F.to_tensor(image_raw).to(opt.device)

    def build_networks(self,opt):
        super().build_networks(opt)
        self.graph.warp_embedding = torch.nn.Embedding(opt.batch_size,opt.arch.embedding_dim).to(opt.device)
        self.graph.warp_mlp = localWarp(opt).to(opt.device)

    def setup_optimizer(self,opt):
        log.info("setting up optimizers...")
        optim_list = [
            dict(params=self.graph.neural_image.parameters(),lr=opt.optim.lr),
        ]
        optimizer = getattr(torch.optim,opt.optim.algo)
        self.optim = optimizer(optim_list)
        # set up scheduler
        if opt.optim.sched:
            scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)
            if opt.optim.lr_end:
                assert(opt.optim.sched.type=="ExponentialLR")
                opt.optim.sched.gamma = (opt.optim.lr_end/opt.optim.lr)**(1./opt.max_iter)
            kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" }
            self.sched = scheduler(self.optim,**kwargs)

        # set up optimizer for warp parameters
        self.optim_warp = optimizer([dict(params=self.graph.warp_embedding.parameters(),lr=opt.optim.lr_warp)])
        self.optim_warp.add_param_group(dict(params=self.graph.warp_mlp.parameters(),lr=opt.optim.lr_warp))
        # set up scheduler
        if opt.optim.sched_warp:
            scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched_warp.type)
            if opt.optim.lr_warp_end:
                assert(opt.optim.sched_warp.type=="ExponentialLR")
                opt.optim.sched_warp.gamma = (opt.optim.lr_warp_end/opt.optim.lr_warp)**(1./opt.max_iter)
            kwargs = { k:v for k,v in opt.optim.sched_warp.items() if k!="type" }
            self.sched_warp = scheduler(self.optim_warp,**kwargs)


    def setup_visualizer(self,opt):
        super().setup_visualizer(opt)
        # set colors for visualization
        box_colors = ["#ff0000","#40afff","#9314ff","#ffd700","#00ff00"]
        box_colors = list(map(util.colorcode_to_number,box_colors))
        self.box_colors = np.array(box_colors).astype(int)
        assert(len(self.box_colors)==opt.batch_size)
        # create visualization directory
        self.vis_path = "{}/vis".format(opt.output_path)
        os.makedirs(self.vis_path,exist_ok=True)
        self.video_fname = "{}/vis.mp4".format(opt.output_path)

    def train(self,opt):
        # before training
        log.title("TRAINING START")
        self.timer = edict(start=time.time(),it_mean=None)
        self.ep = self.it = self.vis_it = 0
        self.graph.train()
        var = edict(idx=torch.arange(opt.batch_size))
        # pre-generate perturbations
        self.homo_pert, self.rot_pert, self.trans_pert, var.image_pert = self.generate_warp_perturbation(opt)
        # train
        var = util.move_to_device(var,opt.device)
        loader = tqdm.trange(opt.max_iter,desc="training",leave=False)
        # visualize initial state
        var = self.graph.forward(opt,var)
        self.visualize(opt,var,step=0)
        for it in loader:
            # train iteration
            loss = self.train_iteration(opt,var,loader)
        # after training
        os.system("ffmpeg -y -framerate 30 -i {}/%d.png -pix_fmt yuv420p {}".format(self.vis_path,self.video_fname))
        self.save_checkpoint(opt,ep=None,it=self.it)
        if opt.tb:
            self.tb.flush()
            self.tb.close()
        if opt.visdom: self.vis.close()
        log.title("TRAINING DONE")

    def train_iteration(self,opt,var,loader):
        self.optim_warp.zero_grad()
        loss = super().train_iteration(opt,var,loader)
        self.graph.neural_image.progress.data.fill_(self.it/opt.max_iter)
        self.optim_warp.step()
        if opt.optim.sched_warp: self.sched_warp.step()
        if opt.optim.sched: self.sched.step()
        return loss

    def generate_warp_perturbation(self,opt):
        # pre-generate perturbations (homography/rotation/translation noise, depending on opt.warp.dof)
        def create_random_perturbation(batch_size):
            # zero-initialize all components, then add noise only to the active DoFs
            homo_pert = torch.zeros(batch_size,8,device=opt.device)
            rot_pert = torch.zeros(batch_size,1,device=opt.device)
            trans_pert = torch.zeros(batch_size,2,device=opt.device)
            if opt.warp.dof==1: # rotation only
                rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r
            elif opt.warp.dof==2: # translation only
                trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t
            elif opt.warp.dof==3: # rigid: rotation + translation
                rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r
                trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t
            elif opt.warp.dof==8: # homography
                homo_pert = torch.randn(batch_size,8,device=opt.device)*opt.warp.noise_h
                homo_pert[:,:2] = 0
                rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r
                trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t
            else: raise ValueError("unsupported opt.warp.dof: {}".format(opt.warp.dof))
            return homo_pert, rot_pert, trans_pert

        homo_pert = torch.zeros(opt.batch_size,8,device=opt.device)
        rot_pert = torch.zeros(opt.batch_size,1,device=opt.device)
        trans_pert = torch.zeros(opt.batch_size,2,device=opt.device)
        for i in range(opt.batch_size):
            homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1)
            while not warp.check_corners_in_range_compose(opt, homo_pert_i, rot_pert_i, trans_pert_i):
                homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1)
            homo_pert[i], rot_pert[i], trans_pert[i] = homo_pert_i, rot_pert_i, trans_pert_i

        if opt.warp.fix_first:
            homo_pert[0],rot_pert[0],trans_pert[0] = 0,0,0

        # create warped image patches
        xy_grid = warp.get_normalized_pixel_grid_crop(opt) # [B,HW,2]

        xy_grid_hom = warp.camera.to_hom(xy_grid)
        warp_matrix = warp.lie.sl3_to_SL3(homo_pert)
        warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1)
        xy_grid_warped = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2]
        warp_matrix = warp.lie.so2_to_SO2(rot_pert)
        xy_grid_warped = xy_grid_warped@warp_matrix.transpose(-2,-1) # [B,HW,2]
        xy_grid_warped = xy_grid_warped+trans_pert[...,None,:]

        xy_grid_warped = xy_grid_warped.view([opt.batch_size,opt.H_crop,opt.W_crop,2])
        xy_grid_warped = torch.stack([xy_grid_warped[...,0]*max(opt.H,opt.W)/opt.W,
                                      xy_grid_warped[...,1]*max(opt.H,opt.W)/opt.H],dim=-1)
        image_raw_batch = self.image_raw.repeat(opt.batch_size,1,1,1)
        image_pert_all = torch_F.grid_sample(image_raw_batch,xy_grid_warped,align_corners=False)

        return homo_pert, rot_pert, trans_pert, image_pert_all

    def visualize_patches_compose(self,opt, homo_pert, rot_pert, trans_pert):
        image_pil = torchvision_F.to_pil_image(self.image_raw).convert("RGBA")
        draw_pil = PIL.Image.new("RGBA",image_pil.size,(0,0,0,0))
        draw = PIL.ImageDraw.Draw(draw_pil)
        corners_all = warp.warp_corners_compose(opt, homo_pert, rot_pert, trans_pert)
        corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
        corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
        for i,corners in enumerate(corners_all):
            P = [tuple(float(n) for n in corners[j]) for j in range(4)]
            draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3)
        image_pil.alpha_composite(draw_pil)
        image_tensor = torchvision_F.to_tensor(image_pil.convert("RGB"))
        return image_tensor

    def visualize_patches_use_matrix(self,opt,warp_matrix):
        image_pil = torchvision_F.to_pil_image(self.image_raw).convert("RGBA")
        draw_pil = PIL.Image.new("RGBA",image_pil.size,(0,0,0,0))
        draw = PIL.ImageDraw.Draw(draw_pil)
        corners_all = warp.warp_corners_use_matrix(opt,warp_matrix)
        corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
        corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
        for i,corners in enumerate(corners_all):
            P = [tuple(float(n) for n in corners[j]) for j in range(4)]
            draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3)
        image_pil.alpha_composite(draw_pil)
        image_tensor = torchvision_F.to_tensor(image_pil.convert("RGB"))
        return image_tensor

    @torch.no_grad()
    def predict_entire_image(self,opt):
        xy_grid = warp.get_normalized_pixel_grid(opt)[:1]
        rgb = self.graph.neural_image.forward(opt,xy_grid) # [B,HW,3]
        image = rgb.view(opt.H,opt.W,3).detach().cpu().permute(2,0,1)
        return image

    @torch.no_grad()
    def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
        super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)
        # compute PSNR
        psnr = -10*loss.render.log10()
        self.tb.add_scalar("{0}/{1}".format(split,"PSNR"),psnr,step)
        # warp error
        pred_corners = warp.warp_corners_use_matrix(opt,self.graph.warp_matrix)
        pred_corners[...,0] = (pred_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
        pred_corners[...,1] = (pred_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
        
        gt_corners = warp.warp_corners_compose(opt, self.homo_pert, self.rot_pert, self.trans_pert)
        gt_corners[...,0] = (gt_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
        gt_corners[...,1] = (gt_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5

        warp_error = (pred_corners-gt_corners).norm(dim=-1).mean()
        self.tb.add_scalar("{0}/{1}".format(split,"warp error"),warp_error,step)


    @torch.no_grad()
    def visualize(self,opt,var,step=0,split="train"):
        # dump frames for writing to video
        frame_GT = self.visualize_patches_compose(opt,self.homo_pert, self.rot_pert, self.trans_pert)
        frame = self.visualize_patches_use_matrix(opt,self.graph.warp_matrix)
        frame2 = self.predict_entire_image(opt)
        frame_cat = (torch.cat([frame,frame2],dim=1)*255).byte().permute(1,2,0).numpy()
        imageio.imsave("{}/{}.png".format(self.vis_path,self.vis_it),frame_cat)
        self.vis_it += 1
        # visualize in Tensorboard
        if opt.tb:
            colors = self.box_colors
            util_vis.tb_image(opt,self.tb,step,split,"image_pert",util_vis.color_border(var.image_pert,colors))
            util_vis.tb_image(opt,self.tb,step,split,"rgb_warped",util_vis.color_border(var.rgb_warped_map,colors))
            util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes",frame[None])
            util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes_GT",frame_GT[None])
            util_vis.tb_image(opt,self.tb,self.it+1,"train","image_entire",frame2[None])
            local_warp_field = torch.cat([(var.local_warp_field+1)/2,torch.zeros_like(var.local_warp_field[:,:1,...])],dim=1) # [B,3,H,W]
            global_warp_field = torch.cat([(var.global_warp_field+1)/2,torch.zeros_like(var.global_warp_field[:,:1,...])],dim=1) # [B,3,H,W]            
            util_vis.tb_image(opt,self.tb,step,split,"warp_field_local",util_vis.color_border(local_warp_field,colors))
            util_vis.tb_image(opt,self.tb,step,split,"warp_field_global",util_vis.color_border(global_warp_field,colors))

# ============================ computation graph for forward/backprop ============================

class Graph(base.Graph):

    def __init__(self,opt):
        super().__init__(opt)
        self.neural_image = NeuralImageFunction(opt)

    def forward(self,opt,var,mode=None):
        # warp
        xy_grid = warp.get_normalized_pixel_grid_crop(opt)
        warp_embedding = self.warp_embedding.weight[:,None,:].expand(-1,xy_grid.shape[1],-1)
        local_warp_param = self.warp_mlp(opt,torch.cat((xy_grid,warp_embedding),dim=-1))
        if opt.warp.fix_first:
            local_warp_param[0] = 0

        if opt.warp.dof==1:
            local_warp_matrix = warp.lie.so2_to_SO2(local_warp_param)
            var.local_warped_grid = (xy_grid[...,None,:]@local_warp_matrix.transpose(-2,-1))[...,0,:] # [B,HW,2]
            self.warp_matrix = roma.rigid_vectors_registration(xy_grid, var.local_warped_grid)
            var.global_warped_grid = xy_grid@self.warp_matrix.transpose(-2,-1) # [B,HW,2]
        elif opt.warp.dof==2:
            var.local_warped_grid = xy_grid + local_warp_param
            self.warp_matrix = var.local_warped_grid.mean(-2)-xy_grid.mean(-2)
            var.global_warped_grid = xy_grid + self.warp_matrix[...,None,:]
        elif opt.warp.dof==3:
            xy_grid_hom = warp.camera.to_hom(xy_grid[...,None,:])
            local_warp_matrix = warp.lie.se2_to_SE2(local_warp_param)
            var.local_warped_grid = (xy_grid_hom@local_warp_matrix.transpose(-2,-1))[...,0,:] # [B,HW,2]
            R_global, t_global = roma.rigid_points_registration(xy_grid, var.local_warped_grid)
            self.warp_matrix = torch.cat((R_global,t_global[...,None]),-1)
            xy_grid_hom = warp.camera.to_hom(xy_grid)
            var.global_warped_grid = xy_grid_hom@self.warp_matrix.transpose(-2,-1) # [B,HW,2]
        elif opt.warp.dof==8:
            xy_grid_hom = warp.camera.to_hom(xy_grid[...,None,:])
            local_warp_matrix = warp.lie.se2_to_SE2(local_warp_param)
            var.local_warped_grid = (xy_grid_hom@local_warp_matrix.transpose(-2,-1))[...,0,:] # [B,HW,2]
            self.warp_matrix = find_homography_dlt(xy_grid, var.local_warped_grid)
            xy_grid_hom = warp.camera.to_hom(xy_grid)
            global_warped_grid = xy_grid_hom@self.warp_matrix.transpose(-2,-1)
            var.global_warped_grid = global_warped_grid[...,:2]/(global_warped_grid[...,2:]+1e-8) # [B,HW,2]

        # render images
        var.rgb_warped = self.neural_image.forward(opt,var.local_warped_grid) # [B,HW,3]
        var.rgb_warped_map = var.rgb_warped.view(opt.batch_size,opt.H_crop,opt.W_crop,3).permute(0,3,1,2) # [B,3,H,W]
        var.local_warp_field = var.local_warped_grid.view(opt.batch_size,opt.H_crop,opt.W_crop,2).permute(0,3,1,2) # [B,2,H,W]
        var.global_warp_field = var.global_warped_grid.view(opt.batch_size,opt.H_crop,opt.W_crop,2).permute(0,3,1,2) # [B,2,H,W]
        return var

    def compute_loss(self,opt,var,mode=None):
        loss = edict()
        if opt.loss_weight.render is not None:
            image_pert = var.image_pert.view(opt.batch_size,3,opt.H_crop*opt.W_crop).permute(0,2,1)
            loss.render = self.MSE_loss(var.rgb_warped,image_pert)
        if opt.loss_weight.global_alignment is not None:
            loss.global_alignment = self.MSE_loss(var.local_warped_grid,var.global_warped_grid)
        return loss

class NeuralImageFunction(torch.nn.Module):

    def __init__(self,opt):
        super().__init__()
        self.define_network(opt)
        self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it can be checkpointed

    def define_network(self,opt):
        input_2D_dim = 2+4*opt.arch.posenc.L_2D if opt.arch.posenc else 2
        # point-wise RGB prediction
        self.mlp = torch.nn.ModuleList()
        L = util.get_layer_dims(opt.arch.layers)
        for li,(k_in,k_out) in enumerate(L):
            if li==0: k_in = input_2D_dim
            if li in opt.arch.skip: k_in += input_2D_dim
            linear = torch.nn.Linear(k_in,k_out)
            if opt.barf_c2f and li==0:
                # rescale first-layer init (default init assumed the full pos.enc. input, but only the raw xy coordinates are active at first)
                scale = np.sqrt(input_2D_dim/2.)
                linear.weight.data *= scale
                linear.bias.data *= scale
            self.mlp.append(linear)

    def forward(self,opt,coord_2D): # [B,...,2]
        if opt.arch.posenc:
            points_enc = self.positional_encoding(opt,coord_2D,L=opt.arch.posenc.L_2D)
            points_enc = torch.cat([coord_2D,points_enc],dim=-1) # [B,...,4L+2]
        else: points_enc = coord_2D
        feat = points_enc
        # extract implicit features
        for li,layer in enumerate(self.mlp):
            if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1)
            feat = layer(feat)
            if li!=len(self.mlp)-1:
                feat = torch_F.relu(feat)
        rgb = feat.sigmoid_() # [B,...,3]
        return rgb

    def positional_encoding(self,opt,input,L): # [B,...,N]
        shape = input.shape
        freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L]
        spectrum = input[...,None]*freq # [B,...,N,L]
        sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L]
        input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L]
        input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL]
        # coarse-to-fine: smoothly mask positional encoding for BARF
        if opt.barf_c2f is not None:
            # set weights for different frequency bands
            start,end = opt.barf_c2f
            alpha = (self.progress.data-start)/(end-start)*L
            k = torch.arange(L,dtype=torch.float32,device=opt.device)
            weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2
            # apply weights
            shape = input_enc.shape
            input_enc = (input_enc.view(-1,L)*weight).view(*shape)
        return input_enc

class localWarp(torch.nn.Module):
    def __init__(self, opt):
        super().__init__()
        # point-wise 2D warp parameter prediction
        input_2D_dim = 2
        self.mlp_warp = torch.nn.ModuleList()
        L = util.get_layer_dims(opt.arch.layers_warp)
        for li,(k_in,k_out) in enumerate(L):
            if li==0: k_in = input_2D_dim+opt.arch.embedding_dim
            if li in opt.arch.skip_warp: k_in += input_2D_dim+opt.arch.embedding_dim
            linear = torch.nn.Linear(k_in,k_out)
            self.mlp_warp.append(linear)

    def forward(self,opt,uvf):
        feat = uvf
        for li,layer in enumerate(self.mlp_warp):
            if li in opt.arch.skip_warp: feat = torch.cat([feat,uvf],dim=-1)
            feat = layer(feat)
            if li!=len(self.mlp_warp)-1:
                feat = torch_F.relu(feat)
        warp = feat
        return warp
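
# The half-cosine coarse-to-fine schedule used in positional_encoding above
# can be checked in isolation. A NumPy sketch of the per-frequency weights
# follows; `c2f_weights` is a hypothetical standalone name, and the formula
# mirrors the torch code.

```python
import numpy as np

def c2f_weights(progress, start, end, L):
    # BARF-style coarse-to-fine schedule: frequency band k is masked (0)
    # until the normalized training progress reaches it, ramps up with a
    # half-cosine, then stays fully active (1).
    alpha = (progress - start) / (end - start) * L
    k = np.arange(L, dtype=np.float32)
    return (1 - np.cos(np.clip(alpha - k, 0.0, 1.0) * np.pi)) / 2
```

# Before `start` every band is fully masked; after `end` every band is
# active; in between, low frequencies switch on first, which lets the model
# register the warp on smooth signals before fitting high-frequency detail.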

================================================
FILE: model/nerf.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import tqdm
from easydict import EasyDict as edict

import lpips
from external.pohsun_ssim import pytorch_ssim

import util,util_vis
from util import log,debug
from . import base
import camera
# ============================ main engine for training and evaluation ============================

class Model(base.Model):

    def __init__(self,opt):
        super().__init__(opt)
        self.lpips_loss = lpips.LPIPS(net="alex").to(opt.device)

    def load_dataset(self,opt,eval_split="val"):
        super().load_dataset(opt,eval_split=eval_split)
        # prefetch all training data
        self.train_data.prefetch_all_data(opt)
        self.train_data.all = edict(util.move_to_device(self.train_data.all,opt.device))

    def setup_optimizer(self,opt):
        log.info("setting up optimizers...")
        optimizer = getattr(torch.optim,opt.optim.algo)
        self.optim = optimizer([dict(params=self.graph.nerf.parameters(),lr=opt.optim.lr)])
        if opt.nerf.fine_sampling:
            self.optim.add_param_group(dict(params=self.graph.nerf_fine.parameters(),lr=opt.optim.lr))
        # set up scheduler
        if opt.optim.sched:
            scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)
            if opt.optim.lr_end:
                assert(opt.optim.sched.type=="ExponentialLR")
                opt.optim.sched.gamma = (opt.optim.lr_end/opt.optim.lr)**(1./opt.max_iter)
            kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" }
            self.sched = scheduler(self.optim,**kwargs)

    def train(self,opt):
        # before training
        log.title("TRAINING START")
        self.timer = edict(start=time.time(),it_mean=None)
        self.graph.train()
        self.ep = 0 # dummy for timer
        # training
        if self.iter_start==0: self.validate(opt,0)
        loader = tqdm.trange(opt.max_iter,desc="training",leave=False)
        for self.it in loader:
            if self.it<self.iter_start: continue
            # set var to all available images
            var = self.train_data.all
            self.train_iteration(opt,var,loader)
            if opt.optim.sched: self.sched.step()
            if self.it%opt.freq.val==0: self.validate(opt,self.it)
            if self.it%opt.freq.ckpt==0: self.save_checkpoint(opt,ep=None,it=self.it)
        # after training
        if opt.tb:
            self.tb.flush()
            self.tb.close()
        if opt.visdom: self.vis.close()
        log.title("TRAINING DONE")

    @torch.no_grad()
    def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
        super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)
        # log learning rate
        if split=="train":
            lr = self.optim.param_groups[0]["lr"]
            self.tb.add_scalar("{0}/{1}".format(split,"lr"),lr,step)
            if opt.nerf.fine_sampling:
                lr = self.optim.param_groups[1]["lr"]
                self.tb.add_scalar("{0}/{1}".format(split,"lr_fine"),lr,step)
        # compute PSNR
        psnr = -10*loss.render.log10()
        self.tb.add_scalar("{0}/{1}".format(split,"PSNR"),psnr,step)
        if opt.nerf.fine_sampling:
            psnr = -10*loss.render_fine.log10()
            self.tb.add_scalar("{0}/{1}".format(split,"PSNR_fine"),psnr,step)

    @torch.no_grad()
    def visualize(self,opt,var,step=0,split="train",eps=1e-10):
        if opt.tb:
            util_vis.tb_image(opt,self.tb,step,split,"image",var.image)
            if not opt.nerf.rand_rays or split!="train":
                if opt.model in ["barf","l2g_nerf"]:
                    scale = self.graph.sim3.s1/self.graph.sim3.s0
                else: scale = 1
                depth=var.depth/scale
                if opt.data.dataset=="blender": depth = torch.clip(depth - 1.5, min=1)
                invdepth = (1-depth)/var.opacity if opt.camera.ndc else 1/(depth/var.opacity+eps)
                invdepth = util_vis.apply_depth_colormap(invdepth, cmap="inferno")
                rgb_map = var.rgb.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]
                invdepth_map = invdepth.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]
                util_vis.tb_image(opt,self.tb,step,split,"rgb",rgb_map)
                util_vis.tb_image(opt,self.tb,step,split,"invdepth",invdepth_map)
                if opt.nerf.fine_sampling:
                    depth=var.depth_fine/scale
                    if opt.data.dataset=="blender": depth = torch.clip(depth - 1.5, min=1)
                    invdepth = (1-depth)/var.opacity_fine if opt.camera.ndc else 1/(depth/var.opacity_fine+eps)
                    invdepth = util_vis.apply_depth_colormap(invdepth, cmap="inferno")
                    rgb_map = var.rgb_fine.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]
                    invdepth_map = invdepth.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]
                    util_vis.tb_image(opt,self.tb,step,split,"rgb_fine",rgb_map)
                    util_vis.tb_image(opt,self.tb,step,split,"invdepth_fine",invdepth_map)

    @torch.no_grad()
    def get_all_training_poses(self,opt):
        # get ground-truth (canonical) camera poses
        pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)
        return None,pose_GT

    @torch.no_grad()
    def evaluate_full(self,opt,eps=1e-10):
        self.graph.eval()
        loader = tqdm.tqdm(self.test_loader,desc="evaluating",leave=False)
        res = []
        test_path = "{}/test_view".format(opt.output_path)
        os.makedirs(test_path,exist_ok=True)
        for i,batch in enumerate(loader):
            var = edict(batch)
            var = util.move_to_device(var,opt.device)
            if opt.model in ["barf","l2g_nerf"] and opt.optim.test_photo:
                # run test-time optimization to factorize imperfection in optimized poses from view synthesis evaluation
                var = self.evaluate_test_time_photometric_optim(opt,var)
            var = self.graph.forward(opt,var,mode="eval")
            # evaluate view synthesis
            if opt.model in ["barf","l2g_nerf"]:
                scale = self.graph.sim3.s1/self.graph.sim3.s0
            else: scale = 1
            depth = var.depth/scale
            if opt.data.dataset=="blender": depth = torch.clip(depth - 1.5, min=1)
            invdepth = (1-depth)/var.opacity if opt.camera.ndc else 1/(depth/var.opacity+eps)
            invdepth = util_vis.apply_depth_colormap(invdepth, cmap="inferno")
            rgb_map = var.rgb.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]
            invdepth_map = invdepth.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]
            psnr = -10*self.graph.MSE_loss(rgb_map,var.image).log10().item()
            ssim = pytorch_ssim.ssim(rgb_map,var.image).item()
            lpips = self.lpips_loss(rgb_map*2-1,var.image*2-1).item()
            res.append(edict(psnr=psnr,ssim=ssim,lpips=lpips))
            # dump novel views
            torchvision_F.to_pil_image(rgb_map.cpu()[0]).save("{}/rgb_{}.png".format(test_path,i))
            torchvision_F.to_pil_image(var.image.cpu()[0]).save("{}/rgb_GT_{}.png".format(test_path,i))
            torchvision_F.to_pil_image(invdepth_map.cpu()[0]).save("{}/depth_{}.png".format(test_path,i))
        # show results in terminal
        print("--------------------------")
        print("PSNR:  {:8.2f}".format(np.mean([r.psnr for r in res])))
        print("SSIM:  {:8.2f}".format(np.mean([r.ssim for r in res])))
        print("LPIPS: {:8.2f}".format(np.mean([r.lpips for r in res])))
        print("--------------------------")
        # dump numbers to file
        quant_fname = "{}/quant.txt".format(opt.output_path)
        with open(quant_fname,"w") as file:
            for i,r in enumerate(res):
                file.write("{} {} {} {}\n".format(i,r.psnr,r.ssim,r.lpips))

    @torch.no_grad()
    def generate_videos_synthesis(self,opt,eps=1e-10):
        self.graph.eval()
        if opt.data.dataset=="blender":
            test_path = "{}/test_view".format(opt.output_path)
            # assume the test-view renderings have already been generated
            print("writing videos...")
            rgb_vid_fname = "{}/test_view_rgb.mp4".format(opt.output_path)
            depth_vid_fname = "{}/test_view_depth.mp4".format(opt.output_path)
            os.system("ffmpeg -y -framerate 30 -i {0}/rgb_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(test_path,rgb_vid_fname))
            os.system("ffmpeg -y -framerate 30 -i {0}/depth_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(test_path,depth_vid_fname))
        else:
            pose_pred,pose_GT = self.get_all_training_poses(opt)
            poses = pose_pred if opt.model in ["barf","l2g_nerf"] else pose_GT
            if opt.model in ["barf","l2g_nerf"] and opt.data.dataset=="llff":
                _,sim3 = self.prealign_cameras(opt,pose_pred,pose_GT)
                scale = sim3.s1/sim3.s0
            else: scale = 1
            # rotate novel views around the "center" camera of all poses
            idx_center = (poses-poses.mean(dim=0,keepdim=True))[...,3].norm(dim=-1).argmin()
            pose_novel = camera.get_novel_view_poses(opt,poses[idx_center],N=60,scale=scale).to(opt.device)
            # render the novel views
            novel_path = "{}/novel_view".format(opt.output_path)
            os.makedirs(novel_path,exist_ok=True)
            pose_novel_tqdm = tqdm.tqdm(pose_novel,desc="rendering novel views",leave=False)
            intr = edict(next(iter(self.test_loader))).intr[:1].to(opt.device) # grab intrinsics
            for i,pose in enumerate(pose_novel_tqdm):
                ret = self.graph.render_by_slices(opt,pose[None],intr=intr) if opt.nerf.rand_rays else \
                      self.graph.render(opt,pose[None],intr=intr)
                depth=ret.depth/scale
                if opt.data.dataset=="blender": depth = torch.clip(depth - 1.5, min=1)
                invdepth = (1-depth)/ret.opacity if opt.camera.ndc else 1/(depth/ret.opacity+eps)
                invdepth = util_vis.apply_depth_colormap(invdepth, cmap="inferno")
                rgb_map = ret.rgb.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]
                invdepth_map = invdepth.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]
                torchvision_F.to_pil_image(rgb_map.cpu()[0]).save("{}/rgb_{}.png".format(novel_path,i))
                torchvision_F.to_pil_image(invdepth_map.cpu()[0]).save("{}/depth_{}.png".format(novel_path,i))

            # write videos
            print("writing videos...")
            rgb_vid_fname = "{}/novel_view_rgb.mp4".format(opt.output_path)
            depth_vid_fname = "{}/novel_view_depth.mp4".format(opt.output_path)
            os.system("ffmpeg -y -framerate 30 -i {0}/rgb_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(novel_path,rgb_vid_fname))
            os.system("ffmpeg -y -framerate 30 -i {0}/depth_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1".format(novel_path,depth_vid_fname))

# ============================ computation graph for forward/backprop ============================

class Graph(base.Graph):

    def __init__(self,opt):
        super().__init__(opt)
        self.nerf = NeRF(opt)
        if opt.nerf.fine_sampling:
            self.nerf_fine = NeRF(opt)

    def forward(self,opt,var,mode=None):
        batch_size = len(var.idx)
        pose = self.get_pose(opt,var,mode=mode)
        # render images
        if opt.nerf.rand_rays and mode in ["train","test-optim"]:
            # sample random rays for optimization
            var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size]
            ret = self.render(opt,pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1]
        else:
            # render full image (process in slices)
            ret = self.render_by_slices(opt,pose,intr=var.intr,mode=mode) if opt.nerf.rand_rays else \
                  self.render(opt,pose,intr=var.intr,mode=mode) # [B,HW,3],[B,HW,1]
        var.update(ret)
        return var

    def compute_loss(self,opt,var,mode=None):
        loss = edict()
        batch_size = len(var.idx)
        image = var.image.view(batch_size,3,opt.H*opt.W).permute(0,2,1)
        if opt.nerf.rand_rays and mode in ["train","test-optim"]:
            image = image[:,var.ray_idx]
        # compute image losses
        if opt.loss_weight.render is not None:
            loss.render = self.MSE_loss(var.rgb,image)
        if opt.loss_weight.render_fine is not None:
            assert(opt.nerf.fine_sampling)
            loss.render_fine = self.MSE_loss(var.rgb_fine,image)
        return loss

    def get_pose(self,opt,var,mode=None):
        return var.pose

    def render(self,opt,pose,intr=None,ray_idx=None,mode=None):
        batch_size = len(pose)
        center,ray = camera.get_center_and_ray(opt,pose,intr=intr) # [B,HW,3]
        while ray.isnan().any(): # TODO: weird bug, ray becomes NaN arbitrarily if batch_size>1, not deterministically reproducible
            center,ray = camera.get_center_and_ray(opt,pose,intr=intr) # [B,HW,3]
        if ray_idx is not None:
            # consider only subset of rays
            center,ray = center[:,ray_idx],ray[:,ray_idx]
        if opt.camera.ndc:
            # convert center/ray representations to NDC
            center,ray = camera.convert_NDC(opt,center,ray,intr=intr)
        # render with main MLP
        depth_samples = self.sample_depth(opt,batch_size,num_rays=ray.shape[1]) # [B,HW,N,1]
        rgb_samples,density_samples = self.nerf.forward_samples(opt,center,ray,depth_samples,mode=mode)
        rgb,depth,opacity,prob = self.nerf.composite(opt,ray,rgb_samples,density_samples,depth_samples)
        ret = edict(rgb=rgb,depth=depth,opacity=opacity) # [B,HW,K]
        # render with fine MLP from coarse MLP
        if opt.nerf.fine_sampling:
            with torch.no_grad():
                # resample depth according to coarse empirical distribution
                depth_samples_fine = self.sample_depth_from_pdf(opt,pdf=prob[...,0]) # [B,HW,Nf,1]
                depth_samples = torch.cat([depth_samples,depth_samples_fine],dim=2) # [B,HW,N+Nf,1]
                depth_samples = depth_samples.sort(dim=2).values
            rgb_samples,density_samples = self.nerf_fine.forward_samples(opt,center,ray,depth_samples,mode=mode)
            rgb_fine,depth_fine,opacity_fine,_ = self.nerf_fine.composite(opt,ray,rgb_samples,density_samples,depth_samples)
            ret.update(rgb_fine=rgb_fine,depth_fine=depth_fine,opacity_fine=opacity_fine) # [B,HW,K]
        return ret

    def render_by_slices(self,opt,pose,intr=None,mode=None):
        ret_all = edict(rgb=[],depth=[],opacity=[])
        if opt.nerf.fine_sampling:
            ret_all.update(rgb_fine=[],depth_fine=[],opacity_fine=[])
        # render the image by slices for memory considerations
        for c in range(0,opt.H*opt.W,opt.nerf.rand_rays):
            ray_idx = torch.arange(c,min(c+opt.nerf.rand_rays,opt.H*opt.W),device=opt.device)
            ret = self.render(opt,pose,intr=intr,ray_idx=ray_idx,mode=mode) # [B,R,3],[B,R,1]
            for k in ret: ret_all[k].append(ret[k])
        # group all slices of images
        for k in ret_all: ret_all[k] = torch.cat(ret_all[k],dim=1)
        return ret_all

    def sample_depth(self,opt,batch_size,num_rays=None):
        depth_min,depth_max = opt.nerf.depth.range
        num_rays = num_rays or opt.H*opt.W
        rand_samples = torch.rand(batch_size,num_rays,opt.nerf.sample_intvs,1,device=opt.device) if opt.nerf.sample_stratified else 0.5
        rand_samples += torch.arange(opt.nerf.sample_intvs,device=opt.device)[None,None,:,None].float() # [B,HW,N,1]
        depth_samples = rand_samples/opt.nerf.sample_intvs*(depth_max-depth_min)+depth_min # [B,HW,N,1]
        depth_samples = dict(
            metric=depth_samples,
            inverse=1/(depth_samples+1e-8),
        )[opt.nerf.depth.param]
        return depth_samples

    def sample_depth_from_pdf(self,opt,pdf):
        depth_min,depth_max = opt.nerf.depth.range
        # get CDF from PDF (along last dimension)
        cdf = pdf.cumsum(dim=-1) # [B,HW,N]
        cdf = torch.cat([torch.zeros_like(cdf[...,:1]),cdf],dim=-1) # [B,HW,N+1]
        # take uniform samples
        grid = torch.linspace(0,1,opt.nerf.sample_intvs_fine+1,device=opt.device) # [Nf+1]
        unif = 0.5*(grid[:-1]+grid[1:]).repeat(*cdf.shape[:-1],1) # [B,HW,Nf]
        idx = torch.searchsorted(cdf,unif,right=True) # [B,HW,Nf] \in {1...N}
        # inverse transform sampling from CDF
        depth_bin = torch.linspace(depth_min,depth_max,opt.nerf.sample_intvs+1,device=opt.device) # [N+1]
        depth_bin = depth_bin.repeat(*cdf.shape[:-1],1) # [B,HW,N+1]
        depth_low = depth_bin.gather(dim=2,index=(idx-1).clamp(min=0)) # [B,HW,Nf]
        depth_high = depth_bin.gather(dim=2,index=idx.clamp(max=opt.nerf.sample_intvs)) # [B,HW,Nf]
        cdf_low = cdf.gather(dim=2,index=(idx-1).clamp(min=0)) # [B,HW,Nf]
        cdf_high = cdf.gather(dim=2,index=idx.clamp(max=opt.nerf.sample_intvs)) # [B,HW,Nf]
        # linear interpolation
        t = (unif-cdf_low)/(cdf_high-cdf_low+1e-8) # [B,HW,Nf]
        depth_samples = depth_low+t*(depth_high-depth_low) # [B,HW,Nf]
        return depth_samples[...,None] # [B,HW,Nf,1]
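The inverse transform sampling above can be checked in isolation. The following sketch uses hypothetical toy numbers (a single ray, N=4 coarse bins over a made-up depth range, Nf=8 fine samples) instead of the `opt` config, but mirrors the same CDF/searchsorted/interpolation steps:

```python
import torch

# Toy setup: 1 ray, N=4 coarse bins over depth range [2, 6] (hypothetical numbers).
depth_min, depth_max, N, Nf = 2.0, 6.0, 4, 8
pdf = torch.tensor([[0.1, 0.2, 0.3, 0.4]])  # [B=1, N], sums to 1

# CDF with a leading zero, as in sample_depth_from_pdf
cdf = pdf.cumsum(dim=-1)                                        # [1, N]
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], dim=-1)  # [1, N+1]

# midpoints of Nf uniform strata on [0, 1]
grid = torch.linspace(0, 1, Nf + 1)
unif = 0.5 * (grid[:-1] + grid[1:]).repeat(cdf.shape[0], 1)  # [1, Nf]
idx = torch.searchsorted(cdf, unif, right=True)              # bin index in {1..N}

# linearly interpolate depth within the selected CDF bin
depth_bin = torch.linspace(depth_min, depth_max, N + 1).repeat(cdf.shape[0], 1)
low, high = (idx - 1).clamp(min=0), idx.clamp(max=N)
t = (unif - cdf.gather(1, low)) / (cdf.gather(1, high) - cdf.gather(1, low) + 1e-8)
depth = depth_bin.gather(1, low) + t * (depth_bin.gather(1, high) - depth_bin.gather(1, low))

# high-probability bins receive more of the Nf samples; depths stay in range
print(depth)
```

Because `unif` is increasing and the CDF is monotone, the resulting depths come out sorted without an explicit sort.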

class NeRF(torch.nn.Module):

    def __init__(self,opt):
        super().__init__()
        self.define_network(opt)

    def define_network(self,opt):
        input_3D_dim = 3+6*opt.arch.posenc.L_3D if opt.arch.posenc else 3
        if opt.nerf.view_dep:
            input_view_dim = 3+6*opt.arch.posenc.L_view if opt.arch.posenc else 3
        # point-wise feature
        self.mlp_feat = torch.nn.ModuleList()
        L = util.get_layer_dims(opt.arch.layers_feat)
        for li,(k_in,k_out) in enumerate(L):
            if li==0: k_in = input_3D_dim
            if li in opt.arch.skip: k_in += input_3D_dim
            if li==len(L)-1: k_out += 1
            linear = torch.nn.Linear(k_in,k_out)
            if opt.arch.tf_init:
                self.tensorflow_init_weights(opt,linear,out="first" if li==len(L)-1 else None)
            self.mlp_feat.append(linear)
        # RGB prediction
        self.mlp_rgb = torch.nn.ModuleList()
        L = util.get_layer_dims(opt.arch.layers_rgb)
        feat_dim = opt.arch.layers_feat[-1]
        for li,(k_in,k_out) in enumerate(L):
            if li==0: k_in = feat_dim+(input_view_dim if opt.nerf.view_dep else 0)
            linear = torch.nn.Linear(k_in,k_out)
            if opt.arch.tf_init:
                self.tensorflow_init_weights(opt,linear,out="all" if li==len(L)-1 else None)
            self.mlp_rgb.append(linear)

    def tensorflow_init_weights(self,opt,linear,out=None):
        # use Xavier init instead of Kaiming init
        relu_gain = torch.nn.init.calculate_gain("relu") # sqrt(2)
        if out=="all":
            torch.nn.init.xavier_uniform_(linear.weight)
        elif out=="first":
            torch.nn.init.xavier_uniform_(linear.weight[:1])
            torch.nn.init.xavier_uniform_(linear.weight[1:],gain=relu_gain)
        else:
            torch.nn.init.xavier_uniform_(linear.weight,gain=relu_gain)
        torch.nn.init.zeros_(linear.bias)

    def forward(self,opt,points_3D,ray_unit=None,mode=None): # [B,...,3]
        if opt.arch.posenc:
            points_enc = self.positional_encoding(opt,points_3D,L=opt.arch.posenc.L_3D)
            points_enc = torch.cat([points_3D,points_enc],dim=-1) # [B,...,6L+3]
        else: points_enc = points_3D
        feat = points_enc
        # extract coordinate-based features
        for li,layer in enumerate(self.mlp_feat):
            if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1)
            feat = layer(feat)
            if li==len(self.mlp_feat)-1:
                density = feat[...,0]
                if opt.nerf.density_noise_reg and mode=="train":
                    density += torch.randn_like(density)*opt.nerf.density_noise_reg
                density_activ = getattr(torch_F,opt.arch.density_activ) # relu_,abs_,sigmoid_,exp_....
                density = density_activ(density)
                feat = feat[...,1:]
            feat = torch_F.relu(feat)
        # predict RGB values
        if opt.nerf.view_dep:
            assert(ray_unit is not None)
            if opt.arch.posenc:
                ray_enc = self.positional_encoding(opt,ray_unit,L=opt.arch.posenc.L_view)
                ray_enc = torch.cat([ray_unit,ray_enc],dim=-1) # [B,...,6L+3]
            else: ray_enc = ray_unit
            feat = torch.cat([feat,ray_enc],dim=-1)
        for li,layer in enumerate(self.mlp_rgb):
            feat = layer(feat)
            if li!=len(self.mlp_rgb)-1:
                feat = torch_F.relu(feat)
        rgb = feat.sigmoid_() # [B,...,3]
        return rgb,density

    def forward_samples(self,opt,center,ray,depth_samples,mode=None):
        points_3D_samples = camera.get_3D_points_from_depth(opt,center,ray,depth_samples,multi_samples=True) # [B,HW,N,3]
        if opt.nerf.view_dep:
            ray_unit = torch_F.normalize(ray,dim=-1) # [B,HW,3]
            ray_unit_samples = ray_unit[...,None,:].expand_as(points_3D_samples) # [B,HW,N,3]
        else: ray_unit_samples = None
        rgb_samples,density_samples = self.forward(opt,points_3D_samples,ray_unit=ray_unit_samples,mode=mode) # [B,HW,N,3],[B,HW,N]
        return rgb_samples,density_samples

    def composite(self,opt,ray,rgb_samples,density_samples,depth_samples):
        ray_length = ray.norm(dim=-1,keepdim=True) # [B,HW,1]
        # volume rendering: compute probability (using quadrature)
        depth_intv_samples = depth_samples[...,1:,0]-depth_samples[...,:-1,0] # [B,HW,N-1]
        depth_intv_samples = torch.cat([depth_intv_samples,torch.empty_like(depth_intv_samples[...,:1]).fill_(1e10)],dim=2) # [B,HW,N]
        dist_samples = depth_intv_samples*ray_length # [B,HW,N]
        sigma_delta = density_samples*dist_samples # [B,HW,N]
        alpha = 1-(-sigma_delta).exp_() # [B,HW,N]
        T = (-torch.cat([torch.zeros_like(sigma_delta[...,:1]),sigma_delta[...,:-1]],dim=2).cumsum(dim=2)).exp_() # [B,HW,N]
        prob = (T*alpha)[...,None] # [B,HW,N,1]
        # integrate RGB and depth weighted by probability
        depth = (depth_samples*prob).sum(dim=2) # [B,HW,1]
        rgb = (rgb_samples*prob).sum(dim=2) # [B,HW,3]
        opacity = prob.sum(dim=2) # [B,HW,1]
        if opt.nerf.setbg_opaque:
            rgb = rgb+opt.data.bgcolor*(1-opacity)
        return rgb,depth,opacity,prob # [B,HW,K]
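The compositing above is the standard quadrature-based volume rendering: per-sample alpha from the optical depth of each segment, transmittance from the cumulative optical depth before it. A self-contained toy check (hypothetical densities, single ray) of the telescoping identity `opacity = 1 - exp(-total optical depth)`:

```python
import torch

# Toy ray: N=3 samples with densities and segment lengths (hypothetical values)
sigma = torch.tensor([0.5, 1.0, 2.0])   # density at each sample
delta = torch.tensor([0.2, 0.2, 0.2])   # distance between consecutive samples
sigma_delta = sigma * delta             # per-segment optical depth
alpha = 1 - torch.exp(-sigma_delta)     # probability of termination in each segment
# transmittance: probability the ray reaches sample i unoccluded
T = torch.exp(-torch.cat([torch.zeros(1), sigma_delta[:-1]]).cumsum(0))
w = T * alpha                            # compositing weights (prob in the code)
opacity = w.sum()                        # telescopes to 1 - exp(-sigma_delta.sum())
print(w, opacity)
```

Each weight is `exp(-cum_before) - exp(-cum_through)`, so the sum telescopes and the ray's opacity depends only on the total optical depth.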

    def positional_encoding(self,opt,input,L): # [B,...,N]
        shape = input.shape
        freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L]
        spectrum = input[...,None]*freq # [B,...,N,L]
        sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L]
        input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L]
        input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL]
        return input_enc
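The positional encoding maps each coordinate channel to sin/cos pairs at frequencies 2^k·π, then the model concatenates the raw input, giving 3 + 6L channels for a 3D point. A minimal standalone check with toy values (no `opt` object), mirroring the stack/view layout above:

```python
import numpy as np
import torch

# Encode a single 3D point with L=4 frequency bands (toy values).
L = 4
x = torch.tensor([[0.5, -0.25, 1.0]])                     # [B=1, 3]
freq = 2 ** torch.arange(L, dtype=torch.float32) * np.pi  # [L]: pi, 2pi, 4pi, 8pi
spectrum = x[..., None] * freq                            # [1, 3, L]
enc = torch.stack([spectrum.sin(), spectrum.cos()], dim=-2)  # [1, 3, 2, L]
enc = enc.view(1, -1)                                     # [1, 3*2*L] = [1, 24]
# the model concatenates the raw input, giving 3 + 6L = 27 channels
full = torch.cat([x, enc], dim=-1)
print(full.shape)  # torch.Size([1, 27])
```

The flattened layout groups all L sines, then all L cosines, per coordinate channel; e.g. the first encoded entry is sin(0.5·π) = 1.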


================================================
FILE: model/planar.py
================================================
import numpy as np
import os,sys,time
import torch
import torch.nn.functional as torch_F
import torchvision
import torchvision.transforms.functional as torchvision_F
import tqdm
from easydict import EasyDict as edict
import PIL
import PIL.Image,PIL.ImageDraw
import imageio

import util,util_vis
from util import log,debug
from . import base
import warp

# ============================ main engine for training and evaluation ============================

class Model(base.Model):

    def __init__(self,opt):
        super().__init__(opt)
        opt.H_crop,opt.W_crop = opt.data.patch_crop

    def load_dataset(self,opt,eval_split=None):
        image_raw = PIL.Image.open(opt.data.image_fname)
        self.image_raw = torchvision_F.to_tensor(image_raw).to(opt.device)

    def build_networks(self,opt):
        super().build_networks(opt)
        self.graph.warp_param = torch.nn.Embedding(opt.batch_size,opt.warp.dof).to(opt.device)
        torch.nn.init.zeros_(self.graph.warp_param.weight)

    def setup_optimizer(self,opt):
        log.info("setting up optimizers...")
        optim_list = [
            dict(params=self.graph.neural_image.parameters(),lr=opt.optim.lr),
            dict(params=self.graph.warp_param.parameters(),lr=opt.optim.lr_warp),
        ]
        optimizer = getattr(torch.optim,opt.optim.algo)
        self.optim = optimizer(optim_list)
        # set up scheduler
        if opt.optim.sched:
            scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)
            kwargs = { k:v for k,v in opt.optim.sched.items() if k!="type" }
            self.sched = scheduler(self.optim,**kwargs)

    def setup_visualizer(self,opt):
        super().setup_visualizer(opt)
        # set colors for visualization
        box_colors = ["#ff0000","#40afff","#9314ff","#ffd700","#00ff00"]
        box_colors = list(map(util.colorcode_to_number,box_colors))
        self.box_colors = np.array(box_colors).astype(int)
        assert(len(self.box_colors)==opt.batch_size)
        # create visualization directory
        self.vis_path = "{}/vis".format(opt.output_path)
        os.makedirs(self.vis_path,exist_ok=True)
        self.video_fname = "{}/vis.mp4".format(opt.output_path)

    def train(self,opt):
        # before training
        log.title("TRAINING START")
        self.timer = edict(start=time.time(),it_mean=None)
        self.ep = self.it = self.vis_it = 0
        self.graph.train()
        var = edict(idx=torch.arange(opt.batch_size))
        # pre-generate perturbations
        self.homo_pert, self.rot_pert, self.trans_pert, var.image_pert = self.generate_warp_perturbation(opt)
        # train
        var = util.move_to_device(var,opt.device)
        loader = tqdm.trange(opt.max_iter,desc="training",leave=False)
        # visualize initial state
        var = self.graph.forward(opt,var)
        self.visualize(opt,var,step=0)
        for it in loader:
            # train iteration
            loss = self.train_iteration(opt,var,loader)
            if opt.warp.fix_first:
                self.graph.warp_param.weight.data[0] = 0
        # after training
        os.system("ffmpeg -y -framerate 30 -i {}/%d.png -pix_fmt yuv420p {}".format(self.vis_path,self.video_fname))
        self.save_checkpoint(opt,ep=None,it=self.it)
        if opt.tb:
            self.tb.flush()
            self.tb.close()
        if opt.visdom: self.vis.close()
        log.title("TRAINING DONE")

    def train_iteration(self,opt,var,loader):
        loss = super().train_iteration(opt,var,loader)
        self.graph.neural_image.progress.data.fill_(self.it/opt.max_iter)
        return loss

    def generate_warp_perturbation(self,opt):
        # pre-generate perturbations (homography / rotational / translational noise)
        def create_random_perturbation(batch_size):
            if opt.warp.dof==1:
                homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h
                rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r
                trans_pert = torch.zeros(batch_size,2,device=opt.device)*opt.warp.noise_t
            elif opt.warp.dof==2:
                homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h
                rot_pert = torch.zeros(batch_size,1,device=opt.device)*opt.warp.noise_r
                trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t
            elif opt.warp.dof==3:
                homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h
                rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r
                trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t
            elif opt.warp.dof==8:
                homo_pert = torch.randn(batch_size,8,device=opt.device)*opt.warp.noise_h
                homo_pert[:,:2]=0
                rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r
                trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t
            else: assert(False)
            return homo_pert, rot_pert, trans_pert

        homo_pert = torch.zeros(opt.batch_size,8,device=opt.device)
        rot_pert = torch.zeros(opt.batch_size,1,device=opt.device)
        trans_pert = torch.zeros(opt.batch_size,2,device=opt.device)
        for i in range(opt.batch_size):
            homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1)
            while not warp.check_corners_in_range_compose(opt, homo_pert_i, rot_pert_i, trans_pert_i):
                homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1)
            homo_pert[i], rot_pert[i], trans_pert[i] = homo_pert_i, rot_pert_i, trans_pert_i

        if opt.warp.fix_first:
            homo_pert[0],rot_pert[0],trans_pert[0] = 0,0,0

        # create warped image patches
        xy_grid = warp.get_normalized_pixel_grid_crop(opt) # [B,HW,2]

        xy_grid_hom = warp.camera.to_hom(xy_grid)
        warp_matrix = warp.lie.sl3_to_SL3(homo_pert)
        warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1)
        xy_grid_warped = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2]
        warp_matrix = warp.lie.so2_to_SO2(rot_pert)
        xy_grid_warped = xy_grid_warped@warp_matrix.transpose(-2,-1) # [B,HW,2]
        xy_grid_warped = xy_grid_warped+trans_pert[...,None,:]

        xy_grid_warped = xy_grid_warped.view([opt.batch_size,opt.H_crop,opt.W_crop,2])
        xy_grid_warped = torch.stack([xy_grid_warped[...,0]*max(opt.H,opt.W)/opt.W,
                                      xy_grid_warped[...,1]*max(opt.H,opt.W)/opt.H],dim=-1)
        image_raw_batch = self.image_raw.repeat(opt.batch_size,1,1,1)
        image_pert_all = torch_F.grid_sample(image_raw_batch,xy_grid_warped,align_corners=False)

        return homo_pert, rot_pert, trans_pert, image_pert_all

    def visualize_patches(self,opt,warp_param):
        image_pil = torchvision_F.to_pil_image(self.image_raw).convert("RGBA")
        draw_pil = PIL.Image.new("RGBA",image_pil.size,(0,0,0,0))
        draw = PIL.ImageDraw.Draw(draw_pil)
        corners_all = warp.warp_corners(opt,warp_param)
        corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
        corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
        for i,corners in enumerate(corners_all):
            P = [tuple(float(n) for n in corners[j]) for j in range(4)]
            draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3)
        image_pil.alpha_composite(draw_pil)
        image_tensor = torchvision_F.to_tensor(image_pil.convert("RGB"))
        return image_tensor

    def visualize_patches_compose(self,opt, homo_pert, rot_pert, trans_pert):
        image_pil = torchvision_F.to_pil_image(self.image_raw).convert("RGBA")
        draw_pil = PIL.Image.new("RGBA",image_pil.size,(0,0,0,0))
        draw = PIL.ImageDraw.Draw(draw_pil)
        corners_all = warp.warp_corners_compose(opt, homo_pert, rot_pert, trans_pert)
        corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
        corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
        for i,corners in enumerate(corners_all):
            P = [tuple(float(n) for n in corners[j]) for j in range(4)]
            draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3)
        image_pil.alpha_composite(draw_pil)
        image_tensor = torchvision_F.to_tensor(image_pil.convert("RGB"))
        return image_tensor

    @torch.no_grad()
    def predict_entire_image(self,opt):
        xy_grid = warp.get_normalized_pixel_grid(opt)[:1]
        rgb = self.graph.neural_image.forward(opt,xy_grid) # [B,HW,3]
        image = rgb.view(opt.H,opt.W,3).detach().cpu().permute(2,0,1)
        return image

    @torch.no_grad()
    def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
        super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)
        # compute PSNR
        psnr = -10*loss.render.log10()
        self.tb.add_scalar("{0}/{1}".format(split,"PSNR"),psnr,step)
        # warp error
        # pred_corners = warp.warp_corners_use_matrix(opt,self.graph.warp_matrix)
        pred_corners = warp.warp_corners(opt,self.graph.warp_param.weight)
        pred_corners[...,0] = (pred_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
        pred_corners[...,1] = (pred_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5
        
        gt_corners = warp.warp_corners_compose(opt, self.homo_pert, self.rot_pert, self.trans_pert)
        gt_corners[...,0] = (gt_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5
        gt_corners[...,1] = (gt_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5

        warp_error = (pred_corners-gt_corners).norm(dim=-1).mean()
        self.tb.add_scalar("{0}/{1}".format(split,"warp error"),warp_error,step)


    @torch.no_grad()
    def visualize(self,opt,var,step=0,split="train"):
        # dump frames for writing to video
        frame_GT = self.visualize_patches_compose(opt,self.homo_pert, self.rot_pert, self.trans_pert)
        frame = self.visualize_patches(opt,self.graph.warp_param.weight)
        frame2 = self.predict_entire_image(opt)
        frame_cat = (torch.cat([frame,frame2],dim=1)*255).byte().permute(1,2,0).numpy()
        imageio.imsave("{}/{}.png".format(self.vis_path,self.vis_it),frame_cat)
        self.vis_it += 1
        # visualize in Tensorboard
        if opt.tb:
            colors = self.box_colors
            util_vis.tb_image(opt,self.tb,step,split,"image_pert",util_vis.color_border(var.image_pert,colors))
            util_vis.tb_image(opt,self.tb,step,split,"rgb_warped",util_vis.color_border(var.rgb_warped_map,colors))
            util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes",frame[None])
            util_vis.tb_image(opt,self.tb,self.it+1,"train","image_boxes_GT",frame_GT[None])
            util_vis.tb_image(opt,self.tb,self.it+1,"train","image_entire",frame2[None])

# ============================ computation graph for forward/backprop ============================

class Graph(base.Graph):

    def __init__(self,opt):
        super().__init__(opt)
        self.neural_image = NeuralImageFunction(opt)

    def forward(self,opt,var,mode=None):
        xy_grid = warp.get_normalized_pixel_grid_crop(opt)
        xy_grid_warped = warp.warp_grid(opt,xy_grid,self.warp_param.weight)
        # render images
        var.rgb_warped = self.neural_image.forward(opt,xy_grid_warped) # [B,HW,3]
        var.rgb_warped_map = var.rgb_warped.view(opt.batch_size,opt.H_crop,opt.W_crop,3).permute(0,3,1,2) # [B,3,H,W]
        return var

    def compute_loss(self,opt,var,mode=None):
        loss = edict()
        if opt.loss_weight.render is not None:
            image_pert = var.image_pert.view(opt.batch_size,3,opt.H_crop*opt.W_crop).permute(0,2,1)
            loss.render = self.MSE_loss(var.rgb_warped,image_pert)
        return loss

class NeuralImageFunction(torch.nn.Module):

    def __init__(self,opt):
        super().__init__()
        self.define_network(opt)
        self.progress = torch.nn.Parameter(torch.tensor(0.)) # use a Parameter so it can be checkpointed

    def define_network(self,opt):
        input_2D_dim = 2+4*opt.arch.posenc.L_2D if opt.arch.posenc else 2
        # point-wise RGB prediction
        self.mlp = torch.nn.ModuleList()
        L = util.get_layer_dims(opt.arch.layers)
        for li,(k_in,k_out) in enumerate(L):
            if li==0: k_in = input_2D_dim
            if li in opt.arch.skip: k_in += input_2D_dim
            linear = torch.nn.Linear(k_in,k_out)
            if opt.barf_c2f and li==0:
                # rescale first-layer init (the init distribution assumes the full pos.enc. input, but only the raw xy channels are active at first)
                scale = np.sqrt(input_2D_dim/2.)
                linear.weight.data *= scale
                linear.bias.data *= scale
            self.mlp.append(linear)

    def forward(self,opt,coord_2D): # [B,...,2]
        if opt.arch.posenc:
            points_enc = self.positional_encoding(opt,coord_2D,L=opt.arch.posenc.L_2D)
            points_enc = torch.cat([coord_2D,points_enc],dim=-1) # [B,...,4L+2]
        else: points_enc = coord_2D
        feat = points_enc
        # extract implicit features
        for li,layer in enumerate(self.mlp):
            if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1)
            feat = layer(feat)
            if li!=len(self.mlp)-1:
                feat = torch_F.relu(feat)
        rgb = feat.sigmoid_() # [B,...,3]
        return rgb

    def positional_encoding(self,opt,input,L): # [B,...,N]
        shape = input.shape
        freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L]
        spectrum = input[...,None]*freq # [B,...,N,L]
        sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L]
        input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L]
        input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL]
        # coarse-to-fine: smoothly mask positional encoding for BARF
        if opt.barf_c2f is not None:
            # set weights for different frequency bands
            start,end = opt.barf_c2f
            alpha = (self.progress.data-start)/(end-start)*L
            k = torch.arange(L,dtype=torch.float32,device=opt.device)
            weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2
            # apply weights
            shape = input_enc.shape
            input_enc = (input_enc.view(-1,L)*weight).view(*shape)
        return input_enc
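The coarse-to-fine masking above ramps each frequency band's weight from 0 to 1 with a raised-cosine as training progresses. A standalone sketch of the schedule, assuming `barf_c2f = [0.1, 0.5]` (as in the shipped configs) and a toy L = 4 bands:

```python
import numpy as np
import torch

# Frequency-band weights of the coarse-to-fine schedule (assumed
# barf_c2f = [0.1, 0.5]; L = 4 bands chosen for illustration).
def c2f_weights(progress, start=0.1, end=0.5, L=4):
    alpha = (progress - start) / (end - start) * L
    k = torch.arange(L, dtype=torch.float32)
    # raised-cosine ramp: 0 while alpha < k, 1 once alpha > k+1
    return (1 - (alpha - k).clamp(min=0, max=1).mul(np.pi).cos()) / 2

print(c2f_weights(0.0))    # all bands masked out
print(c2f_weights(0.35))   # low frequencies on, higher ones fading in
print(c2f_weights(1.0))    # all bands fully active
```

Before `start` every band is off (only the raw coordinates reach the MLP); after `end` the encoding is the full positional encoding, and in between one band at a time fades in smoothly.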


================================================
FILE: options/barf_blender.yaml
================================================
_parent_: options/nerf_blender.yaml

barf_c2f: [0.1,0.5]                                         # coarse-to-fine scheduling on positional encoding

camera:                                                     # camera options
    noise: true                                             # synthetic perturbations on the camera poses (Blender only)
    noise_r: 0.03                                           # synthetic perturbations on the camera poses (Blender only)
    noise_t: 0.3                                            # synthetic perturbations on the camera poses (Blender only)

optim:                                                      # optimization options
    lr_pose: 1.e-3                                          # learning rate of camera poses
    lr_pose_end: 1.e-5                                      # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)
    sched_pose:                                             # learning rate scheduling options
        type: ExponentialLR                                 # scheduler (see PyTorch doc)
        gamma:                                              # decay rate (can be left empty if lr_pose_end is specified)
    warmup_pose:                                            # linear warmup of the pose learning rate (N iterations)
    test_photo: true                                        # test-time photometric optimization for evaluation
    test_iter: 100                                          # number of iterations for test-time optimization

visdom:                                                     # Visdom options
    cam_depth: 0.5                                          # size of visualized cameras


================================================
FILE: options/barf_iphone.yaml
================================================
_parent_: options/barf_llff.yaml

data:                                                       # data options
    dataset: iphone                                         # dataset name
    scene: IMG_0239                                         # scene name
    image_size: [480,640]                                   # input image sizes [height,width]


================================================
FILE: options/barf_llff.yaml
================================================
_parent_: options/nerf_llff.yaml

barf_c2f: [0.1,0.5]                                         # coarse-to-fine scheduling on positional encoding

camera:                                                     # camera options
    noise:                                                  # synthetic perturbations on the camera poses (Blender only)

optim:                                                      # optimization options
    lr_pose: 3.e-3                                          # learning rate of camera poses
    lr_pose_end: 1.e-5                                      # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)
    sched_pose:                                             # learning rate scheduling options
        type: ExponentialLR                                 # scheduler (see PyTorch doc)
        gamma:                                              # decay rate (can be left empty if lr_pose_end is specified)
    warmup_pose:                                            # linear warmup of the pose learning rate (N iterations)
    test_photo: true                                        # test-time photometric optimization for evaluation
    test_iter: 100                                          # number of iterations for test-time optimization

visdom:                                                     # Visdom options
    cam_depth: 0.2                                          # size of visualized cameras
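The `barf_c2f: [0.1,0.5]` entry above schedules a coarse-to-fine mask over the positional-encoding frequency bands. A minimal sketch of the band weighting described in the BARF paper, assuming this repo follows the same cosine ramp (`c2f_weights` is an illustrative name, not the repo's API):

```python
import math

def c2f_weights(progress, L, c2f=(0.1, 0.5)):
    # progress: training progress in [0,1]; L: number of frequency bands.
    # Bands switch on one by one between c2f[0] and c2f[1] of training,
    # each with a smooth cosine ramp; before c2f[0] all bands are off,
    # after c2f[1] all bands are fully on.
    start, end = c2f
    alpha = (progress - start) / (end - start) * L
    weights = []
    for k in range(L):
        if alpha - k <= 0:
            weights.append(0.0)  # band k not yet active
        elif alpha - k >= 1:
            weights.append(1.0)  # band k fully active
        else:
            weights.append((1 - math.cos((alpha - k) * math.pi)) / 2)
    return weights
```

With `[0.1,0.5]` and `L_3D: 10`, all ten bands are masked out for the first 10% of training and fully enabled after 50%.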


================================================
FILE: options/base.yaml
================================================
# default

group: 0_test                                               # name of experiment group
name: debug                                                 # name of experiment run
model:                                                      # type of model (must be specified from command line)
yaml:                                                       # config file (must be specified from command line)
seed: 0                                                     # seed number (for both numpy and pytorch)
gpu: 0                                                      # GPU index number
cpu: false                                                  # run only on CPU (not supported now)
load:                                                       # load checkpoint from filename

arch: {}                                                    # architectural options

data:                                                       # data options
    root:                                                   # root path to dataset
    dataset:                                                # dataset name
    image_size: [null,null]                                 # input image sizes [height,width]
    num_workers: 8                                          # number of parallel workers for data loading
    preload: false                                          # preload the entire dataset into the memory
    augment: {}                                             # data augmentation (training only)
        # rotate:                                           # random rotation
        # brightness: # 0.2                                 # random brightness jitter
        # contrast: # 0.2                                   # random contrast jitter
        # saturation: # 0.2                                 # random saturation jitter
        # hue: # 0.1                                        # random hue jitter
        # hflip: # True                                     # random horizontal flip
    center_crop:                                            # center crop the image by ratio
    val_on_test: false                                      # validate on test set during training
    train_sub:                                              # consider a subset of N training samples
    val_sub:                                                # consider a subset of N validation samples

loss_weight: {}                                             # loss weights (in log scale)

optim:                                                      # optimization options
    lr: 1.e-3                                               # learning rate (main)
    lr_end:                                                 # terminal learning rate (only used with sched.type=ExponentialLR)
    algo: Adam                                              # optimizer (see PyTorch doc)
    sched: {}                                               # learning rate scheduling options
        # type: StepLR                                      # scheduler (see PyTorch doc)
        # steps:                                            # decay every N epochs
        # gamma: 0.1                                      # decay rate (can be empty if lr_end is specified)

batch_size: 16                                              # batch size
max_epoch: 1000                                             # train to maximum number of epochs
resume: false                                               # resume training (true for latest checkpoint, or number for specific epoch number)

output_root: output                                         # root path for output files (checkpoints and results)
tb:                                                         # TensorBoard options
    num_images: [4,8]                                       # number of (tiled) images to visualize in TensorBoard
visdom:                                                     # Visdom options
    server: localhost                                       # server to host Visdom
    port: 8600                                              # port number for Visdom

freq:                                                       # periodic actions during training
    scalar: 200                                             # log losses and scalar states (every N iterations)
    vis: 1000                                               # visualize results (every N iterations)
    val: 20                                                 # validate on val set (every N epochs)
    ckpt: 50                                                # save checkpoint (every N epochs)
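The `_parent_:` key at the top of the other option files chains back to this `base.yaml`. A hedged sketch of how such a chain can be resolved once each file is parsed into a dict (illustrative only; the repo's actual loader lives in `options.py`):

```python
def deep_merge(base, child):
    # Child keys override parent keys; nested dicts are merged recursively,
    # so a child file can override a single sub-key (e.g. optim.lr_pose)
    # without discarding the rest of the parent's optim block.
    merged = dict(base)
    for key, value in child.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

# e.g. barf_llff.yaml layered on top of an already-resolved nerf_llff.yaml
parent = {"optim": {"lr": 1e-3, "lr_end": 1e-4}, "data": {"scene": "fern"}}
child = {"optim": {"lr_pose": 3e-3}, "barf_c2f": [0.1, 0.5]}
merged = deep_merge(parent, child)
```

The recursive merge is what lets `l2g_nerf_iphone.yaml` change only `data.dataset` and `optim.test_iter` while inheriting everything else down the chain.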


================================================
FILE: options/l2g_nerf_blender.yaml
================================================
_parent_: options/barf_blender.yaml

arch:                                                       # architectural options
    layers_warp: [null,256,256,256,256,256,256,6]           # hidden layers for MLP
    skip_warp: [4]                                          # skip connections
    embedding_dim: 128                                      # embedding dim

optim:                                                      # optimization options
    lr_pose: 1.e-3                                          # learning rate of camera poses
    lr_pose_end: 1.e-8                                      # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)

loss_weight:                                                # loss weights (in log scale)
    global_alignment: 2                                     # global alignment loss

error_map_size:
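The loss weights above are noted as being "in log scale": a hedged reading, consistent with the comments, is that a configured weight w scales its loss term by 10**w, so `global_alignment: 2` weighs that term 100x relative to `render: 0`. A small sketch of that combination (`combined_loss` is an illustrative name; the exact reduction lives in the model code):

```python
def combined_loss(losses, log_weights):
    # Each configured weight w scales its loss term by 10**w; terms whose
    # weight is left empty (None) are excluded from the total.
    return sum(10 ** log_weights[name] * value
               for name, value in losses.items()
               if log_weights.get(name) is not None)

total = combined_loss({"render": 0.5, "global_alignment": 0.01},
                      {"render": 0, "global_alignment": 2})
# render contributes 0.5 * 10**0, global_alignment 0.01 * 10**2
```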

================================================
FILE: options/l2g_nerf_iphone.yaml
================================================
_parent_: options/l2g_nerf_llff.yaml

data:                                                       # data options
    dataset: iphone                                         # dataset name
    scene: IMG_0239                                         # scene name
    image_size: [480,640]                                   # input image sizes [height,width]

optim:                                                      # optimization options
    test_iter: 1000                                         # number of iterations for test-time optimization

================================================
FILE: options/l2g_nerf_llff.yaml
================================================
_parent_: options/barf_llff.yaml

arch:                                                       # architectural options
    layers_warp: [null,256,256,256,256,256,256,6]           # hidden layers for MLP
    skip_warp: [4]                                          # skip connections
    embedding_dim: 128                                      # embedding dim

optim:                                                      # optimization options
    lr_pose: 3.e-3                                          # learning rate of camera poses
    lr_pose_end: 1.e-8                                      # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)

loss_weight:                                                # loss weights (in log scale)
    global_alignment: 2                                     # global alignment loss

error_map_size:

================================================
FILE: options/l2g_planar.yaml
================================================
_parent_: options/planar.yaml

arch:                                                       # architectural options
    layers_warp: [null,256,256,256,256,256,256,3]           # hidden layers for MLP
    skip_warp: [4]                                          # skip connections
    embedding_dim: 128                                      # embedding dim

loss_weight:                                                # loss weights (in log scale)
    global_alignment: 2                                     # global alignment loss

optim:                                                      # optimization options
    lr: 1.e-3                                               # learning rate (main)
    lr_end: 1.e-4                                           # terminal learning rate (only used with sched.type=ExponentialLR)
    sched:                                                  # learning rate scheduling options
        type: ExponentialLR                                 # scheduler (see PyTorch doc)
        gamma:                                              # decay rate (can be empty if lr_end is specified)
    lr_warp: 1.e-3                                          # learning rate of warp parameters
    lr_warp_end: 1.e-5                                      # terminal learning rate of warp parameters (only used with sched_warp.type=ExponentialLR)
    sched_warp:                                             # learning rate scheduling options
        type: ExponentialLR                                 # scheduler (see PyTorch doc)
        gamma:                                              # decay rate (can be empty if lr_warp_end is specified)

================================================
FILE: options/nerf_blender.yaml
================================================
_parent_: options/base.yaml

arch:                                                       # architectural options
    layers_feat: [null,256,256,256,256,256,256,256,256]     # hidden layers for feature/density MLP
    layers_rgb: [null,128,3]                                # hidden layers for color MLP
    skip: [4]                                               # skip connections
    posenc:                                                 # positional encoding
        L_3D: 10                                            # number of bases (3D point)
        L_view: 4                                           # number of bases (viewpoint)
    density_activ: softplus                                 # activation function for output volume density
    tf_init: true                                           # initialize network weights in TensorFlow style

nerf:                                                       # NeRF-specific options
    view_dep: true                                          # condition MLP on viewpoint
    depth:                                                  # depth-related options
        param: metric                                       # depth parametrization (for sampling along the ray)
        range: [2,6]                                        # near/far bounds for depth sampling
    sample_intvs: 128                                       # number of samples
    sample_stratified: true                                 # stratified sampling
    fine_sampling: false                                    # hierarchical sampling with another NeRF
    sample_intvs_fine:                                      # number of samples for the fine NeRF
    rand_rays: 1024                                         # number of random rays for each step
    density_noise_reg:                                      # Gaussian noise on density output as regularization
    setbg_opaque: false                                     # fill transparent rendering with known background color (Blender only)

data:                                                       # data options
    dataset: blender                                        # dataset name
    scene: lego                                             # scene name
    image_size: [400,400]                                   # input image sizes [height,width]
    num_workers: 4                                          # number of parallel workers for data loading
    preload: true                                           # preload the entire dataset into the memory
    bgcolor: 1                                              # background color (Blender only)
    val_sub: 4                                              # consider a subset of N validation samples

camera:                                                     # camera options
    model: perspective                                      # type of camera model
    ndc: false                                              # reparametrize as normalized device coordinates (NDC)

loss_weight:                                                # loss weights (in log scale)
    render: 0                                               # RGB rendering loss
    render_fine:                                            # RGB rendering loss (for fine NeRF)

optim:                                                      # optimization options
    lr: 5.e-4                                               # learning rate (main)
    lr_end: 1.e-4                                           # terminal learning rate (only used with sched.type=ExponentialLR)
    sched:                                                  # learning rate scheduling options
        type: ExponentialLR                                 # scheduler (see PyTorch doc)
        gamma:                                              # decay rate (can be empty if lr_end is specified)

batch_size:                                                 # batch size (not used for NeRF/BARF)
max_epoch:                                                  # train to maximum number of epochs (not used for NeRF/BARF)
max_iter: 200000                                            # train to maximum number of iterations

trimesh:                                                    # options for marching cubes to extract 3D mesh
    res: 128                                                # 3D sampling resolution
    range: [-1.2,1.2]                                       # 3D range of interest (assuming same for x,y,z)
    thres: 25.                                              # volume density threshold for marching cubes
    chunk_size: 16384                                       # chunk size of dense samples to be evaluated at a time

freq:                                                       # periodic actions during training
    scalar: 200                                             # log losses and scalar states (every N iterations)
    vis: 1000                                               # visualize results (every N iterations)
    val: 2000                                               # validate on val set (every N iterations)
    ckpt: 5000                                              # save checkpoint (every N iterations)


================================================
FILE: options/nerf_blender_repr.yaml
================================================
_parent_: options/base.yaml

arch:                                                       # architectural options
    layers_feat: [null,256,256,256,256,256,256,256,256]     # hidden layers for feature/density MLP
    layers_rgb: [null,128,3]                                # hidden layers for color MLP
    skip: [4]                                               # skip connections
    posenc:                                                 # positional encoding
        L_3D: 10                                            # number of bases (3D point)
        L_view: 4                                           # number of bases (viewpoint)
    density_activ: relu                                     # activation function for output volume density
    tf_init: true                                           # initialize network weights in TensorFlow style

nerf:                                                       # NeRF-specific options
    view_dep: true                                          # condition MLP on viewpoint
    depth:                                                  # depth-related options
        param: metric                                       # depth parametrization (for sampling along the ray)
        range: [2,6]                                        # near/far bounds for depth sampling
    sample_intvs: 64                                        # number of samples
    sample_stratified: true                                 # stratified sampling
    fine_sampling: true                                     # hierarchical sampling with another NeRF
    sample_intvs_fine: 128                                  # number of samples for the fine NeRF
    rand_rays: 1024                                         # number of random rays for each step
    density_noise_reg: 0                                    # Gaussian noise on density output as regularization
    setbg_opaque: true                                      # fill transparent rendering with known background color (Blender only)

data:                                                       # data options
    dataset: blender                                        # dataset name
    scene: lego                                             # scene name
    image_size: [400,400]                                   # input image sizes [height,width]
    num_workers: 4                                          # number of parallel workers for data loading
    preload: true                                           # preload the entire dataset into the memory
    bgcolor: 1                                              # background color (Blender only)
    val_sub: 4                                              # consider a subset of N validation samples

camera:                                                     # camera options
    model: perspective                                      # type of camera model
    ndc: false                                              # reparametrize as normalized device coordinates (NDC)

loss_weight:                                                # loss weights (in log scale)
    render: 0                                               # RGB rendering loss
    render_fine: 0                                          # RGB rendering loss (for fine NeRF)

optim:                                                      # optimization options
    lr: 5.e-4                                               # learning rate (main)
    lr_end: 5.e-5                                           # terminal learning rate (only used with sched.type=ExponentialLR)
    sched:                                                  # learning rate scheduling options
        type: ExponentialLR                                 # scheduler (see PyTorch doc)
        gamma:                                              # decay rate (can be empty if lr_end is specified)

batch_size:                                                 # batch size (not used for NeRF/BARF)
max_epoch:                                                  # train to maximum number of epochs (not used for NeRF/BARF)
max_iter: 500000                                            # train to maximum number of iterations

trimesh:                                                    # options for marching cubes to extract 3D mesh
    res: 128                                                # 3D sampling resolution
    range: [-1.2,1.2]                                       # 3D range of interest (assuming same for x,y,z)
    thres: 25.                                              # volume density threshold for marching cubes
    chunk_size: 16384                                       # chunk size of dense samples to be evaluated at a time

freq:                                                       # periodic actions during training
    scalar: 200                                             # log losses and scalar states (every N iterations)
    vis: 1000                                               # visualize results (every N iterations)
    val: 2000                                               # validate on val set (every N iterations)
    ckpt: 5000                                              # save checkpoint (every N iterations)


================================================
FILE: options/nerf_llff.yaml
================================================
_parent_: options/base.yaml

arch:                                                       # architectural options
    layers_feat: [null,256,256,256,256,256,256,256,256]     # hidden layers for feature/density MLP
    layers_rgb: [null,128,3]                                # hidden layers for color MLP
    skip: [4]                                               # skip connections
    posenc:                                                 # positional encoding
        L_3D: 10                                            # number of bases (3D point)
        L_view: 4                                           # number of bases (viewpoint)
    density_activ: softplus                                 # activation function for output volume density
    tf_init: true                                           # initialize network weights in TensorFlow style

nerf:                                                       # NeRF-specific options
    view_dep: true                                          # condition MLP on viewpoint
    depth:                                                  # depth-related options
        param: inverse                                      # depth parametrization (for sampling along the ray)
        range: [1,0]                                        # near/far bounds for depth sampling
    sample_intvs: 128                                       # number of samples
    sample_stratified: true                                 # stratified sampling
    fine_sampling: false                                    # hierarchical sampling with another NeRF
    sample_intvs_fine:                                      # number of samples for the fine NeRF
    rand_rays: 2048                                         # number of random rays for each step
    density_noise_reg:                                      # Gaussian noise on density output as regularization
    setbg_opaque:                                           # fill transparent rendering with known background color (Blender only)

data:                                                       # data options
    dataset: llff                                           # dataset name
    scene: fern                                             # scene name
    image_size: [480,640]                                   # input image sizes [height,width]
    num_workers: 4                                          # number of parallel workers for data loading
    preload: true                                           # preload the entire dataset into the memory
    val_ratio: 0.1                                          # ratio of sequence split for validation

camera:                                                     # camera options
    model: perspective                                      # type of camera model
    ndc: false                                              # reparametrize as normalized device coordinates (NDC)

loss_weight:                                                # loss weights (in log scale)
    render: 0                                               # RGB rendering loss
    render_fine:                                            # RGB rendering loss (for fine NeRF)

optim:                                                      # optimization options
    lr: 1.e-3                                               # learning rate (main)
    lr_end: 1.e-4                                           # terminal learning rate (only used with sched.type=ExponentialLR)
    sched:                                                  # learning rate scheduling options
        type: ExponentialLR                                 # scheduler (see PyTorch doc)
        gamma:                                              # decay rate (can be empty if lr_end is specified)

batch_size:                                                 # batch size (not used for NeRF/BARF)
max_epoch:                                                  # train to maximum number of epochs (not used for NeRF/BARF)
max_iter: 200000                                            # train to maximum number of iterations

freq:                                                       # periodic actions during training
    scalar: 200                                             # log losses and scalar states (every N iterations)
    vis: 1000                                               # visualize results (every N iterations)
    val: 2000                                               # validate on val set (every N iterations)
    ckpt: 5000                                              # save checkpoint (every N iterations)
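The `depth.param` / `depth.range` pair above controls how sample depths are spaced along each ray: `metric` with `[2,6]` (Blender) samples uniformly in depth, while `inverse` with `[1,0]` (LLFF) samples uniformly in inverse depth, concentrating samples near the camera and letting the far bound extend toward infinity. A hedged sketch of the idea, not the repo's exact sampler (which also applies stratified jittering):

```python
def sample_depths(param, depth_range, n):
    # Midpoint samples of the parametrized range; with param="inverse" the
    # range is in inverse depth, so depth = 1/s stretches toward infinity
    # as s approaches 0.
    near, far = depth_range
    s = [near + (far - near) * (i + 0.5) / n for i in range(n)]
    if param == "inverse":
        return [1.0 / si for si in s]
    return s
```

For example, `sample_depths("inverse", [1, 0], 4)` yields depths of roughly 1.14, 1.6, 2.67, and 8, growing unboundedly as n increases.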


================================================
FILE: options/nerf_llff_repr.yaml
================================================
_parent_: options/base.yaml

arch:                                                       # architectural options
    layers_feat: [null,256,256,256,256,256,256,256,256]     # hidden layers for feature/density MLP
    layers_rgb: [null,128,3]                                # hidden layers for color MLP
    skip: [4]                                               # skip connections
    posenc:                                                 # positional encoding
        L_3D: 10                                            # number of bases (3D point)
        L_view: 4                                           # number of bases (viewpoint)
    density_activ: relu                                     # activation function for output volume density
    tf_init: true                                           # initialize network weights in TensorFlow style

nerf:                                                       # NeRF-specific options
    view_dep: true                                          # condition MLP on viewpoint
    depth:                                                  # depth-related options
        param: metric                                       # depth parametrization (for sampling along the ray)
        range: [0,1]                                        # near/far bounds for depth sampling
    sample_intvs: 64                                        # number of samples
    sample_stratified: true                                 # stratified sampling
    fine_sampling: true                                     # hierarchical sampling with another NeRF
    sample_intvs_fine: 128                                  # number of samples for the fine NeRF
    rand_rays: 1024                                         # number of random rays for each step
    density_noise_reg: 1                                    # Gaussian noise on density output as regularization
    setbg_opaque:                                           # fill transparent rendering with known background color (Blender only)

data:                                                       # data options
    dataset: llff                                           # dataset name
    scene: fern                                             # scene name
    image_size: [480,640]                                   # input image sizes [height,width]
    num_workers: 4                                          # number of parallel workers for data loading
    preload: true                                           # preload the entire dataset into the memory
    val_ratio: 0.1                                          # ratio of sequence split for validation

camera:                                                     # camera options
    model: perspective                                      # type of camera model
    ndc: false                                              # reparametrize as normalized device coordinates (NDC)

loss_weight:                                                # loss weights (in log scale)
    render: 0                                               # RGB rendering loss
    render_fine: 0                                          # RGB rendering loss (for fine NeRF)

optim:                                                      # optimization options
    lr: 5.e-4                                               # learning rate (main)
    lr_end: 5.e-5                                           # terminal learning rate (only used with sched.type=ExponentialLR)
    sched:                                                  # learning rate scheduling options
        type: ExponentialLR                                 # scheduler (see PyTorch doc)
        gamma:                                              # decay rate (can be empty if lr_end is specified)

batch_size:                                                 # batch size (not used for NeRF/BARF)
max_epoch:                                                  # train to maximum number of epochs (not used for NeRF/BARF)
max_iter: 500000                                            # train to maximum number of iterations

freq:                                                       # periodic actions during training
    scalar: 200                                             # log losses and scalar states (every N iterations)
    vis: 1000                                               # visualize results (every N iterations)
    val: 2000                                               # validate on val set (every N iterations)
    ckpt: 5000                                              # save checkpoint (every N iterations)


================================================
FILE: options/planar.yaml
================================================
_parent_: options/base.yaml

arch:                                                       # architectural options
    layers: [null,256,256,256,256,3]                        # hidden layers for MLP
    skip: []                                                # skip connections
    posenc:                                                 # positional encoding
        L_2D: 8                                             # number of bases (2D point)

barf_c2f: [0,0.4]                                           # coarse-to-fine scheduling on positional encoding

data:                                                       # data options
    image_fname: data/cat.jpg                               # path to image file
    image_size: [360,480]                                   # original image size
    patch_crop: [180,180]                                   # crop size of image patches to align

warp:                                                       # image warping options
    type: homography                                        # type of warp function
    dof: 8                                                  # degrees of freedom of the warp function
    noise_h: 0.3                                            # scale of pre-generated warp perturbation (homography)
    noise_t: 0.1                                            # scale of pre-generated warp perturbation (translation)
    noise_r: 1.0                                            # scale of pre-generated warp perturbation (rotation)
    fix_first: true                                         # fix the first patch for uniqueness of solution

loss_weight:                                                # loss weights (in log scale)
    render: 0                                               # RGB rendering loss

optim:                                                      # optimization options
    lr: 1.e-3                                               # learning rate (main)
    lr_warp: 1.e-3                                          # learning rate of warp parameters

batch_size: 5                                               # batch size (set to number of patches to consider)
max_iter: 5000                                              # train to maximum number of iterations

visdom:                                                     # Visdom options (turned off)

freq:                                                       # periodic actions during training
    scalar: 20                                              # log losses and scalar states (every N iterations)
    vis: 100                                                # visualize results (every N iterations)


================================================
FILE: options.py
================================================
import numpy as np
import os,sys,time
import torch
import random
import string
import yaml
from easydict import EasyDict as edict

import util
from util import log

# torch.backends.cudnn.enabled = False
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

def parse_arguments(args):
    """
    Parse arguments from command line.
    Syntax: --key1.key2.key3=value --> value
            --key1.key2.key3=      --> None
            --key1.key2.key3       --> True
            --key1.key2.key3!      --> False
    """
    opt_cmd = {}
    for arg in args:
        assert(arg.startswith("--"))
        if "=" not in arg[2:]:
            key_str,value = (arg[2:-1],"false") if arg[-1]=="!" else (arg[2:],"true")
        else:
            key_str,value = arg[2:].split("=",1) # split on first "=" only, so values may contain "="
        keys_sub = key_str.split(".")
        opt_sub = opt_cmd
        for k in keys_sub[:-1]:
            if k not in opt_sub: opt_sub[k] = {}
            opt_sub = opt_sub[k]
        assert keys_sub[-1] not in opt_sub,keys_sub[-1]
        opt_sub[keys_sub[-1]] = yaml.safe_load(value)
    opt_cmd = edict(opt_cmd)
    return opt_cmd
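A minimal standalone sketch mirroring the parsing logic of `parse_arguments` above (plain dicts instead of `EasyDict`, so the snippet runs on its own). Note that PyYAML's implicit resolver needs the dot in `1.e-3` to parse it as a float, which is why the option files write learning rates that way:

```python
import yaml

def parse_cli(args):
    # Mirrors options.parse_arguments: dotted keys build nested dicts,
    # "--k" -> True, "--k!" -> False, "--k=" -> None, "--k=v" -> YAML-parsed value.
    opt = {}
    for arg in args:
        assert arg.startswith("--")
        body = arg[2:]
        if "=" not in body:
            key_str, value = (body[:-1], "false") if body.endswith("!") else (body, "true")
        else:
            key_str, value = body.split("=", 1)
        sub = opt
        keys = key_str.split(".")
        for k in keys[:-1]:
            sub = sub.setdefault(k, {})
        sub[keys[-1]] = yaml.safe_load(value)
    return opt

print(parse_cli(["--model=planar", "--optim.lr=1.e-3", "--visdom!", "--resume"]))
# {'model': 'planar', 'optim': {'lr': 0.001}, 'visdom': False, 'resume': True}
```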

def set(opt_cmd={}):
    log.info("setting configurations...")
    assert("model" in opt_cmd)
    # load config from yaml file
    assert("yaml" in opt_cmd)
    fname = "options/{}.yaml".format(opt_cmd.yaml)
    opt_base = load_options(fname)
    # override with command line arguments
    opt = override_options(opt_base,opt_cmd,key_stack=[],safe_check=True)
    process_options(opt)
    log.options(opt)
    return opt

def load_options(fname):
    with open(fname) as file:
        opt = edict(yaml.safe_load(file))
    if "_parent_" in opt:
        # load parent yaml file(s) as base options
        parent_fnames = opt.pop("_parent_")
        if type(parent_fnames) is str:
            parent_fnames = [parent_fnames]
        for parent_fname in parent_fnames:
            opt_parent = load_options(parent_fname)
            opt_parent = override_options(opt_parent,opt,key_stack=[])
            opt = opt_parent
    print("loading {}...".format(fname))
    return opt
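As a hypothetical illustration (file name and overridden keys invented for this example), a derived config chains to its parent via `_parent_`; `load_options` loads the parent first and then merges the child's keys over it with `override_options`:

```yaml
# options/my_experiment.yaml (hypothetical) -- inherits all keys from the
# parent file, then overrides a subset of them
_parent_: options/base.yaml

optim:
    lr: 5.e-4                                               # overrides the parent learning rate; sibling keys under optim are kept
max_iter: 10000                                             # new/overridden top-level key
```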

def override_options(opt,opt_over,key_stack=None,safe_check=False):
    for key,value in opt_over.items():
        if isinstance(value,dict):
            # parse child options (until leaf nodes are reached)
            opt[key] = override_options(opt.get(key,dict()),value,key_stack=key_stack+[key],safe_check=safe_check)
        else:
            # ensure command line argument to override is also in yaml file
            if safe_check and key not in opt:
                add_new = None
                while add_new not in ["y","n"]:
                    key_str = ".".join(key_stack+[key])
                    add_new = input("\"{}\" not found in original opt, add? (y/n) ".format(key_str))
                if add_new=="n":
                    print("safe exiting...")
                    exit()
            opt[key] = value
    return opt
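A minimal standalone sketch of the recursive merge performed by `override_options` above, with the interactive `safe_check` prompt omitted so it runs non-interactively:

```python
def merge(base, over):
    # Recursively merge `over` into `base`: dict values recurse until leaf
    # nodes are reached; leaf values in `over` replace (or add to) `base`.
    for key, value in over.items():
        if isinstance(value, dict):
            base[key] = merge(base.get(key, {}), value)
        else:
            base[key] = value
    return base

base = {"optim": {"lr": 1e-3, "lr_warp": 1e-3}, "max_iter": 5000}
merged = merge(base, {"optim": {"lr": 5e-4}, "seed": 0})
# optim.lr is overridden, optim.lr_warp is kept, seed is added
print(merged)
```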

def process_options(opt):
    # set seed
    if opt.seed is not None:
        random.seed(opt.seed)
        np.random.seed(opt.seed)
        torch.manual_seed(opt.seed)
        torch.cuda.manual_seed_all(opt.seed)
        if opt.seed!=0:
            opt.name = str(opt.name)+"_seed{}".format(opt.seed)
    else:
        # create random string as run ID
        randkey = "".join(random.choice(string.ascii_uppercase) for _ in range(4))
        opt.name = str(opt.name)+"_{}".format(randkey)
    # other default options
    opt.output_path = "{0}/{1}/{2}".format(opt.output_root,opt.group,opt.name)
    os.makedirs(opt.output_path,exist_ok=True)
    assert(isinstance(opt.gpu,int)) # disable multi-GPU support for now, single is enough
    opt.device = "cpu" if opt.cpu or not torch.cuda.is_available() else "cuda:{}".format(opt.gpu)
    opt.H,opt.W = opt.data.image_size

def save_options_file(opt):
    opt_fname = "{}/options.yaml".format(opt.output_path)
    if os.path.isfile(opt_fname):
        with open(opt_fname) as file:
            opt_old = yaml.safe_load(file)
        if opt!=opt_old:
            # prompt if options are not identical
            opt_new_fname = "{}/options_temp.yaml".format(opt.output_path)
            with open(opt_new_fname,"w") as file:
                yaml.safe_dump(util.to_dict(opt),file,default_flow_style=False,indent=4)
            print("existing options file found (different from current one)...")
            os.system("diff {} {}".format(opt_fname,opt_new_fname))
            os.system("rm {}".format(opt_new_fname))
            override = None
            while override not in ["y","n"]:
                override = input("override? (y/n) ")
            if override=="n":
                print("safe exiting...")
                exit()
        else: print("existing options file found (identical)")
    else: print("(creating new options file...)")
    with open(opt_fname,"w") as file:
        yaml.safe_dump(util.to_dict(opt),file,default_flow_style=False,indent=4)


================================================
FILE: requirements.yaml
================================================
name: L2G-NeRF
channels:
  - conda-forge
  - pytorch
dependencies:
  - numpy
  - scipy
  - tqdm
  - termcolor
  - easydict
  - imageio
  - ipdb
  - pytorch>=1.9.0
  - torchvision
  - tensorboard
  - visdom
  - matplotlib
  - scikit-video
  - trimesh
  - pyyaml
  - pip
  - gdown
  - pip:
    - lpips
    - pymcubes
    - roma
    - kornia


================================================
FILE: train.py
================================================
import numpy as np
import os,sys,time
import torch
import importlib

import options
from util import log

import warnings
warnings.filterwarnings('ignore')

def main():

    log.process(os.getpid())
    log.title("[{}] (PyTorch code for training NeRF/BARF/L2G_NeRF)".format(sys.argv[0]))

    opt_cmd = options.parse_arguments(sys.argv[1:])
    opt = options.set(opt_cmd=opt_cmd)
    options.save_options_file(opt)

    with torch.cuda.device(opt.device):

        model = importlib.import_module("model.{}".format(opt.model))
        m = model.Model(opt)

        m.load_dataset(opt)
        m.build_networks(opt)
        m.setup_optimizer(opt)
        m.restore_checkpoint(opt)
        m.setup_visualizer(opt)

        m.train(opt)

if __name__=="__main__":
    main()


================================================
FILE: util.py
================================================
import numpy as np
import os,sys,time
import shutil
import datetime
import torch
import torch.nn.functional as torch_F
import ipdb
import types
import termcolor
import socket
import contextlib
from easydict import EasyDict as edict

# convert to colored strings
def red(message,**kwargs): return termcolor.colored(str(message),color="red",attrs=[k for k,v in kwargs.items() if v is True])
def green(message,**kwargs): return termcolor.colored(str(message),color="green",attrs=[k for k,v in kwargs.items() if v is True])
def blue(message,**kwargs): return termcolor.colored(str(message),color="blue",attrs=[k for k,v in kwargs.items() if v is True])
def cyan(message,**kwargs): return termcolor.colored(str(message),color="cyan",attrs=[k for k,v in kwargs.items() if v is True])
def yellow(message,**kwargs): return termcolor.colored(str(message),color="yellow",attrs=[k for k,v in kwargs.items() if v is True])
def magenta(message,**kwargs): return termcolor.colored(str(message),color="magenta",attrs=[k for k,v in kwargs.items() if v is True])
def grey(message,**kwargs): return termcolor.colored(str(message),color="grey",attrs=[k for k,v in kwargs.items() if v is True])

def get_time(sec):
    d = int(sec//(24*60*60))
    h = int(sec//(60*60)%24)
    m = int((sec//60)%60)
    s = int(sec%60)
    return d,h,m,s
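A quick worked example of the seconds-to-(days, hours, minutes, seconds) split performed by `get_time`, reimplemented here so the snippet is self-contained:

```python
def get_time(sec):
    # Split a duration in seconds into (days, hours, minutes, seconds),
    # mirroring util.get_time.
    d = int(sec // (24 * 60 * 60))
    h = int(sec // (60 * 60) % 24)
    m = int((sec // 60) % 60)
    s = int(sec % 60)
    return d, h, m, s

# 1 day + 1 hour + 1 minute + 1 second = 90061 seconds
print(get_time(90061))  # (1, 1, 1, 1)
```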

def add_datetime(func):
    def wrapper(*args,**kwargs):
        datetime_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(grey("[{}] ".format(datetime_str),bold=True),end="")
        return func(*args,**kwargs)
    return wrapper

def add_functionname(func):
    def wrapper(*args,**kwargs):
        print(grey("[{}] ".format(func.__name__),bold=True))
        return func(*args,**kwargs)
    return wrapper

def pre_post_actions(pre=None,post=None):
    def func_decorator(func):
        def wrapper(*args,**kwargs):
            if pre: pre()
            retval = func(*args,**kwargs)
            if post: post()
            return retval
        return wrapper
    return func_decorator
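A self-contained usage sketch of the `pre_post_actions` pre/post-hook decorator (reimplemented here so the snippet runs on its own):

```python
def pre_post_actions(pre=None, post=None):
    # Decorator factory: call `pre` before and `post` after the wrapped
    # function, passing the wrapped function's return value through.
    def func_decorator(func):
        def wrapper(*args, **kwargs):
            if pre: pre()
            retval = func(*args, **kwargs)
            if post: post()
            return retval
        return wrapper
    return func_decorator

events = []

@pre_post_actions(pre=lambda: events.append("pre"), post=lambda: events.append("post"))
def step():
    events.append("run")
    return 42
```

Calling `step()` returns 42 and records the hooks firing around the body, in order.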
SYMBOL INDEX (296 symbols across 18 files)

FILE: camera.py
  class Pose (line 11) | class Pose():
    method __call__ (line 17) | def __call__(self,R=None,t=None):
    method invert (line 36) | def invert(self,pose,use_inverse=False):
    method compose (line 44) | def compose(self,pose_list):
    method compose_pair (line 52) | def compose_pair(self,pose_a,pose_b):
  class Lie (line 61) | class Lie():
    method so3_to_SO3 (line 66) | def so3_to_SO3(self,w): # [...,3]
    method SO3_to_so3 (line 75) | def SO3_to_so3(self,R,eps=1e-7): # [...,3,3]
    method se3_to_SE3 (line 83) | def se3_to_SE3(self,wu): # [...,3]
    method SE3_to_se3 (line 96) | def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4]
    method skew_symmetric (line 109) | def skew_symmetric(self,w):
    method taylor_A (line 117) | def taylor_A(self,x,nth=10):
    method taylor_B (line 125) | def taylor_B(self,x,nth=10):
    method taylor_C (line 133) | def taylor_C(self,x,nth=10):
  class Quaternion (line 142) | class Quaternion():
    method q_to_R (line 144) | def q_to_R(self,q):
    method R_to_q (line 152) | def R_to_q(self,R,eps=1e-8): # [B,3,3]
    method invert (line 178) | def invert(self,q):
    method product (line 184) | def product(self,q1,q2): # [B,4]
  function to_hom (line 197) | def to_hom(X):
  function world2cam (line 203) | def world2cam(X,pose): # [B,N,3]
  function cam2img (line 206) | def cam2img(X,cam_intr):
  function img2cam (line 208) | def img2cam(X,cam_intr):
  function cam2world (line 210) | def cam2world(X,pose):
  function angle_to_rotation_matrix (line 215) | def angle_to_rotation_matrix(a,axis):
  function get_center_and_ray (line 226) | def get_center_and_ray(opt,pose,intr=None): # [HW,2]
  function get_camera_cords_grid_3D (line 246) | def get_camera_cords_grid_3D(opt,batch_size,intr=None,ray_idx=None): # [...
  function gather_camera_cords_grid_3D (line 263) | def gather_camera_cords_grid_3D(opt,batch_size,intr=None,ray_idx=None): ...
  function get_3D_points_from_depth (line 280) | def get_3D_points_from_depth(opt,center,ray,depth,multi_samples=False):
  function convert_NDC (line 286) | def convert_NDC(opt,center,ray,intr,near=1):
  function rotation_distance (line 305) | def rotation_distance(R1,R2,eps=1e-7):
  function procrustes_analysis (line 312) | def procrustes_analysis(X0,X1): # [N,3]
  function get_novel_view_poses (line 331) | def get_novel_view_poses(opt,pose_anchor,N=60,scale=1):

FILE: data/base.py
  class Dataset (line 16) | class Dataset(torch.utils.data.Dataset):
    method __init__ (line 18) | def __init__(self,opt,split="train"):
    method setup_loader (line 31) | def setup_loader(self,opt,shuffle=False,drop_last=False):
    method get_list (line 42) | def get_list(self,opt):
    method preload_worker (line 45) | def preload_worker(self,data_list,load_func,q,lock,idx_tqdm):
    method preload_threading (line 53) | def preload_threading(self,opt,load_func,data_str="images"):
    method __getitem__ (line 68) | def __getitem__(self,idx):
    method get_image (line 71) | def get_image(self,opt,idx):
    method generate_augmentation (line 74) | def generate_augmentation(self,opt):
    method preprocess_image (line 92) | def preprocess_image(self,opt,image,aug=None):
    method preprocess_camera (line 109) | def preprocess_camera(self,opt,intr,pose,aug=None):
    method apply_color_jitter (line 119) | def apply_color_jitter(self,opt,image,color_jitter):
    method __len__ (line 129) | def __len__(self):

FILE: data/blender.py
  class Dataset (line 17) | class Dataset(base.Dataset):
    method __init__ (line 19) | def __init__(self,opt,split="train",subset=None):
    method prefetch_all_data (line 36) | def prefetch_all_data(self,opt):
    method get_all_camera_poses (line 41) | def get_all_camera_poses(self,opt):
    method __getitem__ (line 46) | def __getitem__(self,idx):
    method get_image (line 61) | def get_image(self,opt,idx):
    method preprocess_image (line 66) | def preprocess_image(self,opt,image,aug=None):
    method get_camera (line 73) | def get_camera(self,opt,idx):
    method parse_raw_camera (line 81) | def parse_raw_camera(self,opt,pose_raw):

FILE: data/iphone.py
  class Dataset (line 17) | class Dataset(base.Dataset):
    method __init__ (line 19) | def __init__(self,opt,split="train",subset=None):
    method prefetch_all_data (line 36) | def prefetch_all_data(self,opt):
    method get_all_camera_poses (line 41) | def get_all_camera_poses(self,opt):
    method __getitem__ (line 45) | def __getitem__(self,idx):
    method get_image (line 60) | def get_image(self,opt,idx):
    method get_camera (line 65) | def get_camera(self,opt,idx):

FILE: data/llff.py
  class Dataset (line 17) | class Dataset(base.Dataset):
    method __init__ (line 19) | def __init__(self,opt,split="train",subset=None):
    method prefetch_all_data (line 37) | def prefetch_all_data(self,opt):
    method parse_cameras_and_bounds (line 42) | def parse_cameras_and_bounds(self,opt):
    method center_camera_poses (line 60) | def center_camera_poses(self,opt,poses):
    method get_all_camera_poses (line 71) | def get_all_camera_poses(self,opt):
    method __getitem__ (line 76) | def __getitem__(self,idx):
    method get_image (line 91) | def get_image(self,opt,idx):
    method get_camera (line 96) | def get_camera(self,opt,idx):
    method parse_raw_camera (line 104) | def parse_raw_camera(self,opt,pose_raw):

FILE: evaluate.py
  function main (line 12) | def main():

FILE: external/pohsun_ssim/pytorch_ssim/__init__.py
  function gaussian (line 7) | def gaussian(window_size, sigma):
  function create_window (line 11) | def create_window(window_size, channel):
  function _ssim (line 17) | def _ssim(img1, img2, window, window_size, channel, size_average = True):
  class SSIM (line 39) | class SSIM(torch.nn.Module):
    method __init__ (line 40) | def __init__(self, window_size = 11, size_average = True):
    method forward (line 47) | def forward(self, img1, img2):
  function ssim (line 65) | def ssim(img1, img2, window_size = 11, size_average = True):

FILE: model/barf.py
  class Model (line 19) | class Model(nerf.Model):
    method __init__ (line 21) | def __init__(self,opt):
    method build_networks (line 24) | def build_networks(self,opt):
    method setup_optimizer (line 49) | def setup_optimizer(self,opt):
    method train_iteration (line 62) | def train_iteration(self,opt,var,loader):
    method validate (line 79) | def validate(self,opt,ep=None):
    method log_scalars (line 85) | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
    method visualize (line 100) | def visualize(self,opt,var,step=0,split="train"):
    method get_all_training_poses (line 109) | def get_all_training_poses(self,opt):
    method prealign_cameras (line 116) | def prealign_cameras(self,opt,pose,pose_GT):
    method evaluate_camera_alignment (line 134) | def evaluate_camera_alignment(self,opt,pose_aligned,pose_GT):
    method evaluate_full (line 144) | def evaluate_full(self,opt):
    method evaluate_test_time_photometric_optim (line 163) | def evaluate_test_time_photometric_optim(self,opt,var):
    method generate_videos_pose (line 181) | def generate_videos_pose(self,opt):
  class Graph (line 217) | class Graph(nerf.Graph):
    method __init__ (line 219) | def __init__(self,opt):
    method forward (line 226) | def forward(self,opt,var,mode=None):
    method get_pose (line 249) | def get_pose(self,opt,var,mode=None):
  class NeRF (line 278) | class NeRF(nerf.NeRF):
    method __init__ (line 280) | def __init__(self,opt):
    method positional_encoding (line 284) | def positional_encoding(self,opt,input,L): # [B,...,N]

FILE: model/base.py
  class Model (line 18) | class Model():
    method __init__ (line 20) | def __init__(self,opt):
    method load_dataset (line 24) | def load_dataset(self,opt,eval_split="val"):
    method build_networks (line 34) | def build_networks(self,opt):
    method setup_optimizer (line 39) | def setup_optimizer(self,opt):
    method restore_checkpoint (line 49) | def restore_checkpoint(self,opt):
    method setup_visualizer (line 62) | def setup_visualizer(self,opt):
    method train (line 78) | def train(self,opt):
    method train_epoch (line 94) | def train_epoch(self,opt):
    method train_iteration (line 111) | def train_iteration(self,opt,var,loader):
    method summarize_loss (line 130) | def summarize_loss(self,opt,var,loss):
    method validate (line 145) | def validate(self,opt,ep=None):
    method log_scalars (line 165) | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
    method visualize (line 175) | def visualize(self,opt,var,step=0,split="train"):
    method save_checkpoint (line 178) | def save_checkpoint(self,opt,ep=0,it=0,latest=False):
  class Graph (line 185) | class Graph(torch.nn.Module):
    method __init__ (line 187) | def __init__(self,opt):
    method forward (line 190) | def forward(self,opt,var,mode=None):
    method compute_loss (line 194) | def compute_loss(self,opt,var,mode=None):
    method L1_loss (line 199) | def L1_loss(self,pred,label=0):
    method MSE_loss (line 202) | def MSE_loss(self,pred,label=0):

FILE: model/l2g_nerf.py
  class Model (line 20) | class Model(nerf.Model):
    method __init__ (line 22) | def __init__(self,opt):
    method build_networks (line 25) | def build_networks(self,opt):
    method setup_optimizer (line 55) | def setup_optimizer(self,opt):
    method train_iteration (line 69) | def train_iteration(self,opt,var,loader):
    method validate (line 86) | def validate(self,opt,ep=None):
    method log_scalars (line 92) | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
    method visualize (line 107) | def visualize(self,opt,var,step=0,split="train"):
    method get_all_training_poses (line 116) | def get_all_training_poses(self,opt):
    method prealign_cameras (line 123) | def prealign_cameras(self,opt,pose,pose_GT):
    method evaluate_camera_alignment (line 141) | def evaluate_camera_alignment(self,opt,pose_aligned,pose_GT):
    method evaluate_full (line 151) | def evaluate_full(self,opt):
    method evaluate_test_time_photometric_optim (line 170) | def evaluate_test_time_photometric_optim(self,opt,var):
    method generate_videos_pose (line 188) | def generate_videos_pose(self,opt):
  class Graph (line 224) | class Graph(nerf.Graph):
    method __init__ (line 226) | def __init__(self,opt):
    method get_pose (line 233) | def get_pose(self,opt,var,mode=None):
    method forward (line 274) | def forward(self,opt,var,mode=None):
    method compute_loss (line 314) | def compute_loss(self,opt,var,mode=None):
    method local_render (line 348) | def local_render(self,opt,local_pose,intr=None,ray_idx=None,mode=None):
  class NeRF (line 380) | class NeRF(nerf.NeRF):
    method __init__ (line 382) | def __init__(self,opt):
    method positional_encoding (line 386) | def positional_encoding(self,opt,input,L): # [B,...,N]
  class localWarp (line 400) | class localWarp(torch.nn.Module):
    method __init__ (line 401) | def __init__(self, opt):
    method forward (line 413) | def forward(self,opt,uvf):

FILE: model/l2g_planar.py
  class Model (line 21) | class Model(base.Model):
    method __init__ (line 23) | def __init__(self,opt):
    method load_dataset (line 27) | def load_dataset(self,opt,eval_split=None):
    method build_networks (line 31) | def build_networks(self,opt):
    method setup_optimizer (line 36) | def setup_optimizer(self,opt):
    method setup_visualizer (line 65) | def setup_visualizer(self,opt):
    method train (line 77) | def train(self,opt):
    method train_iteration (line 104) | def train_iteration(self,opt,var,loader):
    method generate_warp_perturbation (line 113) | def generate_warp_perturbation(self,opt):
    method visualize_patches_compose (line 167) | def visualize_patches_compose(self,opt, homo_pert, rot_pert, trans_pert):
    method visualize_patches_use_matrix (line 181) | def visualize_patches_use_matrix(self,opt,warp_matrix):
    method predict_entire_image (line 196) | def predict_entire_image(self,opt):
    method log_scalars (line 203) | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
    method visualize (line 222) | def visualize(self,opt,var,step=0,split="train"):
  class Graph (line 245) | class Graph(base.Graph):
    method __init__ (line 247) | def __init__(self,opt):
    method forward (line 251) | def forward(self,opt,var,mode=None):
    method compute_loss (line 292) | def compute_loss(self,opt,var,mode=None):
  class NeuralImageFunction (line 301) | class NeuralImageFunction(torch.nn.Module):
    method __init__ (line 303) | def __init__(self,opt):
    method define_network (line 308) | def define_network(self,opt):
    method forward (line 324) | def forward(self,opt,coord_2D): # [B,...,3]
    method positional_encoding (line 339) | def positional_encoding(self,opt,input,L): # [B,...,N]
  class localWarp (line 358) | class localWarp(torch.nn.Module):
    method __init__ (line 359) | def __init__(self, opt):
    method forward (line 371) | def forward(self,opt,uvf):

FILE: model/nerf.py
  class Model (line 19) | class Model(base.Model):
    method __init__ (line 21) | def __init__(self,opt):
    method load_dataset (line 25) | def load_dataset(self,opt,eval_split="val"):
    method setup_optimizer (line 31) | def setup_optimizer(self,opt):
    method train (line 46) | def train(self,opt):
    method log_scalars (line 71) | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
    method visualize (line 88) | def visualize(self,opt,var,step=0,split="train",eps=1e-10):
    method get_all_training_poses (line 114) | def get_all_training_poses(self,opt):
    method evaluate_full (line 120) | def evaluate_full(self,opt,eps=1e-10):
    method generate_videos_synthesis (line 164) | def generate_videos_synthesis(self,opt,eps=1e-10):
  class Graph (line 210) | class Graph(base.Graph):
    method __init__ (line 212) | def __init__(self,opt):
    method forward (line 218) | def forward(self,opt,var,mode=None):
    method compute_loss (line 233) | def compute_loss(self,opt,var,mode=None):
    method get_pose (line 247) | def get_pose(self,opt,var,mode=None):
    method render (line 250) | def render(self,opt,pose,intr=None,ray_idx=None,mode=None):
    method render_by_slices (line 278) | def render_by_slices(self,opt,pose,intr=None,mode=None):
    method sample_depth (line 291) | def sample_depth(self,opt,batch_size,num_rays=None):
    method sample_depth_from_pdf (line 303) | def sample_depth_from_pdf(self,opt,pdf):
  class NeRF (line 324) | class NeRF(torch.nn.Module):
    method __init__ (line 326) | def __init__(self,opt):
    method define_network (line 330) | def define_network(self,opt):
    method tensorflow_init_weights (line 356) | def tensorflow_init_weights(self,opt,linear,out=None):
    method forward (line 368) | def forward(self,opt,points_3D,ray_unit=None,mode=None): # [B,...,3]
    method forward_samples (line 401) | def forward_samples(self,opt,center,ray,depth_samples,mode=None):
    method composite (line 410) | def composite(self,opt,ray,rgb_samples,density_samples,depth_samples):
    method positional_encoding (line 428) | def positional_encoding(self,opt,input,L): # [B,...,N]

FILE: model/planar.py
  class Model (line 20) | class Model(base.Model):
    method __init__ (line 22) | def __init__(self,opt):
    method load_dataset (line 26) | def load_dataset(self,opt,eval_split=None):
    method build_networks (line 30) | def build_networks(self,opt):
    method setup_optimizer (line 35) | def setup_optimizer(self,opt):
    method setup_visualizer (line 49) | def setup_visualizer(self,opt):
    method train (line 61) | def train(self,opt):
    method train_iteration (line 90) | def train_iteration(self,opt,var,loader):
    method generate_warp_perturbation (line 95) | def generate_warp_perturbation(self,opt):
    method visualize_patches (line 149) | def visualize_patches(self,opt,warp_param):
    method visualize_patches_compose (line 163) | def visualize_patches_compose(self,opt, homo_pert, rot_pert, trans_pert):
    method predict_entire_image (line 178) | def predict_entire_image(self,opt):
    method log_scalars (line 185) | def log_scalars(self,opt,var,loss,metric=None,step=0,split="train"):
    method visualize (line 205) | def visualize(self,opt,var,step=0,split="train"):
  class Graph (line 224) | class Graph(base.Graph):
    method __init__ (line 226) | def __init__(self,opt):
    method forward (line 230) | def forward(self,opt,var,mode=None):
    method compute_loss (line 238) | def compute_loss(self,opt,var,mode=None):
  class NeuralImageFunction (line 245) | class NeuralImageFunction(torch.nn.Module):
    method __init__ (line 247) | def __init__(self,opt):
    method define_network (line 252) | def define_network(self,opt):
    method forward (line 268) | def forward(self,opt,coord_2D): # [B,...,3]
    method positional_encoding (line 283) | def positional_encoding(self,opt,input,L): # [B,...,N]

FILE: options.py
  function parse_arguments (line 16) | def parse_arguments(args):
  function set (line 41) | def set(opt_cmd={}):
  function load_options (line 54) | def load_options(fname):
  function override_options (line 69) | def override_options(opt,opt_over,key_stack=None,safe_check=False):
  function process_options (line 87) | def process_options(opt):
  function save_options_file (line 107) | def save_options_file(opt):

FILE: train.py
  function main (line 12) | def main():

FILE: util.py
  function red (line 15) | def red(message,**kwargs): return termcolor.colored(str(message),color="...
  function green (line 16) | def green(message,**kwargs): return termcolor.colored(str(message),color...
  function blue (line 17) | def blue(message,**kwargs): return termcolor.colored(str(message),color=...
  function cyan (line 18) | def cyan(message,**kwargs): return termcolor.colored(str(message),color=...
  function yellow (line 19) | def yellow(message,**kwargs): return termcolor.colored(str(message),colo...
  function magenta (line 20) | def magenta(message,**kwargs): return termcolor.colored(str(message),col...
  function grey (line 21) | def grey(message,**kwargs): return termcolor.colored(str(message),color=...
  function get_time (line 23) | def get_time(sec):
  function add_datetime (line 30) | def add_datetime(func):
  function add_functionname (line 37) | def add_functionname(func):
  function pre_post_actions (line 43) | def pre_post_actions(pre=None,post=None):
  class Log (line 55) | class Log:
    method __init__ (line 56) | def __init__(self): pass
    method process (line 57) | def process(self,pid):
    method title (line 59) | def title(self,message):
    method info (line 61) | def info(self,message):
    method options (line 63) | def options(self,opt,level=0):
    method loss_train (line 70) | def loss_train(self,opt,ep,lr,loss,timer):
    method loss_val (line 79) | def loss_val(self,opt,loss):
  function update_timer (line 85) | def update_timer(opt,timer,ep,it_per_ep):
  function move_to_device (line 95) | def move_to_device(X,device):
  function to_dict (line 110) | def to_dict(D,dict_type=dict):
  function get_child_state_dict (line 117) | def get_child_state_dict(state_dict,key):
  function restore_checkpoint (line 120) | def restore_checkpoint(opt,model,load_name=None,resume=False):
  function save_checkpoint (line 143) | def save_checkpoint(opt,model,ep,it,latest=False,children=None):
  function check_socket_open (line 161) | def check_socket_open(hostname,port):
  function get_layer_dims (line 172) | def get_layer_dims(layers):
  function suppress (line 177) | def suppress(stdout=False,stderr=False):
  function colorcode_to_number (line 186) | def colorcode_to_number(code):

FILE: util_vis.py
  function tb_image (line 16) | def tb_image(opt,tb,step,group,name,images,num_vis=None,from_range=(0,1)...
  function preprocess_vis_image (line 27) | def preprocess_vis_image(opt,images,from_range=(0,1),cmap="gray"):
  function dump_images (line 35) | def dump_images(opt,idx,name,images,masks=None,from_range=(0,1),cmap="gr...
  function get_heatmap (line 43) | def get_heatmap(opt,gray,cmap): # [N,H,W]
  function color_border (line 48) | def color_border(images,colors,width=3):
  function vis_cameras (line 58) | def vis_cameras(opt,vis,step,poses=[],colors=["blue","magenta"],plot_dis...
  function get_camera_mesh (line 141) | def get_camera_mesh(pose,depth=1):
  function merge_wireframes (line 157) | def merge_wireframes(wireframe):
  function merge_meshes (line 164) | def merge_meshes(vertices,faces):
  function merge_centers (line 169) | def merge_centers(centers):
  function plot_save_poses (line 177) | def plot_save_poses(opt,fig,pose,pose_ref=None,path=None,ep=None):
  function plot_save_poses_blender (line 220) | def plot_save_poses_blender(opt,fig,pose,pose_ref=None,path=None,ep=None):
  function setup_3D_plot (line 257) | def setup_3D_plot(ax,elev,azim,lim=None):
  function apply_colormap (line 294) | def apply_colormap(image, cmap="viridis"):
  function apply_depth_colormap (line 317) | def apply_depth_colormap(

FILE: warp.py
  function get_normalized_pixel_grid (line 11) | def get_normalized_pixel_grid(opt):
  function get_normalized_pixel_grid_crop (line 19) | def get_normalized_pixel_grid_crop(opt):
  function warp_grid (line 29) | def warp_grid(opt,xy_grid,warp):
  function warp_grid_use_matrix (line 51) | def warp_grid_use_matrix(opt,xy_grid,warp_matrix):
  function warp_corners (line 70) | def warp_corners(opt,warp_param):
  function warp_corners_compose (line 80) | def warp_corners_compose(opt, homo_pert, rot_pert, trans_pert):
  function warp_corners_use_matrix (line 97) | def warp_corners_use_matrix(opt,warp_matrix):
  function check_corners_in_range (line 107) | def check_corners_in_range(opt,warp_param):
  function check_corners_in_range_compose (line 113) | def check_corners_in_range_compose(opt, homo_pert, rot_pert, trans_pert):
  class Lie (line 120) | class Lie():
    method so2_to_SO2 (line 122) | def so2_to_SO2(self,theta): # [...,1]
    method SO2_to_so2 (line 128) | def SO2_to_so2(self,R): # [...,2,2]
    method so2_jacobian (line 132) | def so2_jacobian(self,X,theta): # [...,N,2],[...,1]
    method se2_to_SE2 (line 138) | def se2_to_SE2(self,delta): # [...,3]
    method SE2_to_se2 (line 148) | def SE2_to_se2(self,Rt,eps=1e-7): # [...,2,3]
    method se2_jacobian (line 160) | def se2_jacobian(self,X,delta): # [...,N,2],[...,3]
    method sl3_to_SL3 (line 178) | def sl3_to_SL3(self,h):
    method taylor_A (line 188) | def taylor_A(self,x,nth=10):
    method taylor_B (line 196) | def taylor_B(self,x,nth=10):
    method taylor_C (line 205) | def taylor_C(self,x,nth=10):
    method taylor_D (line 214) | def taylor_D(self,x,nth=10):
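The `Lie` class above exposes exponential/logarithm maps between Lie algebras and Lie groups (so2/SO2, se2/SE2, sl3/SL3), with `taylor_A`/`taylor_B` style helpers that evaluate the trigonometric coefficients by Taylor series so they stay numerically stable near a zero rotation angle. The repo's version is written in PyTorch and batched; as a rough illustration only, here is a minimal NumPy sketch of the standard SE(2) exponential map that `se2_to_SE2` plausibly computes (not the repo's exact implementation):

```python
import numpy as np

def taylor_A(x, nth=10):
    # Taylor series of sin(x)/x, stable at x = 0.
    ans, denom = 0.0, 1.0
    for i in range(nth + 1):
        if i > 0:
            denom *= (2 * i) * (2 * i + 1)
        ans += (-1) ** i * x ** (2 * i) / denom
    return ans

def taylor_B(x, nth=10):
    # Taylor series of (1 - cos(x)) / x^2, stable at x = 0.
    ans, denom = 0.0, 1.0
    for i in range(nth + 1):
        denom *= (2 * i + 1) * (2 * i + 2)
        ans += (-1) ** i * x ** (2 * i) / denom
    return ans

def se2_to_SE2(delta):
    # delta = (u1, u2, theta): translational twist and rotation angle.
    u, theta = delta[:2], delta[2]
    A, B = taylor_A(theta), taylor_B(theta)
    # V maps the twist's translation part to the group translation.
    V = np.array([[A, -B * theta],
                  [B * theta, A]])
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta),  np.cos(theta)]])
    # Return a [2,3] rigid transform: [R | V @ u].
    return np.concatenate([R, (V @ u)[:, None]], axis=1)
```

For a zero twist this reduces to the identity transform, and for a pure-translation twist the output translation equals the twist's translation part, which is a quick sanity check on the Taylor coefficients.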
Condensed preview — 40 files, each showing path, character count, and a content snippet.
[
  {
    "path": "LICENSE",
    "chars": 1069,
    "preview": "MIT License\n\nCopyright (c) [year] [fullname]\n\nPermission is hereby granted, free of charge, to any person obtaining a co"
  },
  {
    "path": "README.md",
    "chars": 12824,
    "preview": "# L2G-NeRF: Local-to-Global Registration for Bundle-Adjusting Neural Radiance Fields\n**[Project Page](https://rover-xing"
  },
  {
    "path": "camera.py",
    "chars": 13464,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport collections\nfrom easydic"
  },
  {
    "path": "data/base.py",
    "chars": 4765,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
  },
  {
    "path": "data/blender.py",
    "chars": 3345,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
  },
  {
    "path": "data/iphone.py",
    "chars": 2826,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
  },
  {
    "path": "data/llff.py",
    "chars": 4465,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
  },
  {
    "path": "evaluate.py",
    "chars": 884,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport importlib\n\nimport options\nfrom util import log\n\nimport warning"
  },
  {
    "path": "external/pohsun_ssim/LICENSE.txt",
    "chars": 4,
    "preview": "MIT\n"
  },
  {
    "path": "external/pohsun_ssim/README.md",
    "chars": 1769,
    "preview": "# pytorch-ssim\n\n### Differentiable structural similarity (SSIM) index.\n![einstein](https://raw.githubusercontent.com/Po-"
  },
  {
    "path": "external/pohsun_ssim/max_ssim.py",
    "chars": 941,
    "preview": "import pytorch_ssim\nimport torch\nfrom torch.autograd import Variable\nfrom torch import optim\nimport cv2\nimport numpy as "
  },
  {
    "path": "external/pohsun_ssim/pytorch_ssim/__init__.py",
    "chars": 2635,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nimport numpy as np\nfrom math import exp"
  },
  {
    "path": "external/pohsun_ssim/setup.cfg",
    "chars": 40,
    "preview": "[metadata]\ndescription-file = README.md\n"
  },
  {
    "path": "external/pohsun_ssim/setup.py",
    "chars": 610,
    "preview": "from distutils.core import setup\nsetup(\n  name = 'pytorch_ssim',\n  packages = ['pytorch_ssim'], # this must be the same "
  },
  {
    "path": "extract_mesh.py",
    "chars": 1615,
    "preview": "\"\"\"Extracts a 3D mesh from a pretrained model using marching cubes.\"\"\"\n\nimport importlib\nimport sys\n\nimport numpy as np\n"
  },
  {
    "path": "model/barf.py",
    "chars": 14328,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
  },
  {
    "path": "model/base.py",
    "chars": 8318,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
  },
  {
    "path": "model/l2g_nerf.py",
    "chars": 21385,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
  },
  {
    "path": "model/l2g_planar.py",
    "chars": 19457,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
  },
  {
    "path": "model/nerf.py",
    "chars": 23224,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
  },
  {
    "path": "model/planar.py",
    "chars": 14646,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
  },
  {
    "path": "options/barf_blender.yaml",
    "chars": 1722,
    "preview": "_parent_: options/nerf_blender.yaml\n\nbarf_c2f: [0.1,0.5]                                         # coarse-to-fine schedu"
  },
  {
    "path": "options/barf_iphone.yaml",
    "chars": 352,
    "preview": "_parent_: options/barf_llff.yaml\n\ndata:                                                       # data options\n    dataset"
  },
  {
    "path": "options/barf_llff.yaml",
    "chars": 1477,
    "preview": "_parent_: options/nerf_llff.yaml\n\nbarf_c2f: [0.1,0.5]                                         # coarse-to-fine schedulin"
  },
  {
    "path": "options/base.yaml",
    "chars": 4634,
    "preview": "# default\n\ngroup: 0_test                                               # name of experiment group\nname: debug           "
  },
  {
    "path": "options/l2g_nerf_blender.yaml",
    "chars": 875,
    "preview": "_parent_: options/barf_blender.yaml\n\narch:                                                       # architectural options"
  },
  {
    "path": "options/l2g_nerf_iphone.yaml",
    "chars": 549,
    "preview": "_parent_: options/l2g_nerf_llff.yaml\n\ndata:                                                       # data options\n    dat"
  },
  {
    "path": "options/l2g_nerf_llff.yaml",
    "chars": 872,
    "preview": "_parent_: options/barf_llff.yaml\n\narch:                                                       # architectural options\n  "
  },
  {
    "path": "options/l2g_planar.yaml",
    "chars": 1613,
    "preview": "_parent_: options/planar.yaml\n\narch:                                                       # architectural options\n    l"
  },
  {
    "path": "options/nerf_blender.yaml",
    "chars": 5254,
    "preview": "_parent_: options/base.yaml\n\narch:                                                       # architectural options\n    lay"
  },
  {
    "path": "options/nerf_blender_repr.yaml",
    "chars": 5254,
    "preview": "_parent_: options/base.yaml\n\narch:                                                       # architectural options\n    lay"
  },
  {
    "path": "options/nerf_llff.yaml",
    "chars": 4636,
    "preview": "_parent_: options/base.yaml\n\narch:                                                       # architectural optionss\n    la"
  },
  {
    "path": "options/nerf_llff_repr.yaml",
    "chars": 4632,
    "preview": "_parent_: options/base.yaml\n\narch:                                                       # architectural options\n    lay"
  },
  {
    "path": "options/planar.yaml",
    "chars": 2670,
    "preview": "_parent_: options/base.yaml\n\narch:                                                       # architectural options\n    lay"
  },
  {
    "path": "options.py",
    "chars": 4975,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport random\nimport string\nimport yaml\nfrom easydict import EasyDict"
  },
  {
    "path": "requirements.yaml",
    "chars": 339,
    "preview": "name: L2G-NeRF\nchannels:\n  - conda-forge\n  - pytorch\ndependencies:\n  - numpy\n  - scipy\n  - tqdm\n  - termcolor\n  - easydi"
  },
  {
    "path": "train.py",
    "chars": 771,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport importlib\n\nimport options\nfrom util import log\n\nimport warning"
  },
  {
    "path": "util.py",
    "chars": 7652,
    "preview": "import numpy as np\nimport os,sys,time\nimport shutil\nimport datetime\nimport torch\nimport torch.nn.functional as torch_F\ni"
  },
  {
    "path": "util_vis.py",
    "chars": 13340,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torch"
  },
  {
    "path": "warp.py",
    "chars": 9696,
    "preview": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\n\nimport util\nfrom util import l"
  }
]
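Several of the option-file previews above begin with a `_parent_:` key (e.g. `options/barf_blender.yaml` points at `options/nerf_blender.yaml`, which in turn points at `options/base.yaml`), so configs form an inheritance chain where each child overrides its parent's values. As a rough sketch only — the helper names `load_options` and `deep_merge` are hypothetical, not the repo's API in `options.py` — such hierarchical loading could look like:

```python
import yaml

def deep_merge(base, override):
    # Recursively overlay override onto base: nested dicts merge, scalars replace.
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

def load_options(path):
    # Hypothetical loader: resolve the _parent_ chain, then merge child over parent.
    with open(path) as f:
        opts = yaml.safe_load(f) or {}
    parent_path = opts.pop("_parent_", None)
    if parent_path is not None:
        opts = deep_merge(load_options(parent_path), opts)
    return opts
```

With this scheme a child file only needs to state the keys it changes (e.g. a coarse-to-fine schedule or a dataset name), and everything else falls through from the base config.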

About this extraction

This page contains the full source code of the rover-xingyu/L2G-NeRF GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 40 files (218.7 KB), approximately 58.3k tokens, and a symbol index with 296 extracted functions, classes, methods, constants, and types. The output can be fed to OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract, a free GitHub-repo-to-text converter for AI, built by Nikandr Surkov.