[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) [year] [fullname]\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# L2G-NeRF: Local-to-Global Registration for Bundle-Adjusting Neural Radiance Fields\n**[Project Page](https://rover-xingyu.github.io/L2G-NeRF/) |\n[Paper](https://arxiv.org/pdf/2211.11505.pdf) |\n[Video](https://www.youtube.com/watch?v=y8XP9Umt6Mw)**\n\n[Yue Chen¹](https://scholar.google.com/citations?user=M2hq1_UAAAAJ&hl=en), \n[Xingyu Chen¹](https://scholar.google.com/citations?user=gDHPrWEAAAAJ&hl=en), \n[Xuan Wang²](https://scholar.google.com/citations?user=h-3xd3EAAAAJ&hl=en),\n[Qi Zhang³](https://scholar.google.com/citations?user=2vFjhHMAAAAJ&hl=en), \n[Yu Guo¹](https://scholar.google.com/citations?user=OemeiSIAAAAJ&hl=en), \n[Ying Shan³](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en), \n[Fei Wang¹](https://scholar.google.com/citations?user=uU2JTpUAAAAJ&hl=en). \n\n[¹Xi'an Jiaotong University](http://en.xjtu.edu.cn/),\n[²Ant Group](https://www.antgroup.com/en),\n[³Tencent AI Lab](https://ai.tencent.com/ailab/en/index/). \n\nThis repository is an official implementation of [L2G-NeRF](https://rover-xingyu.github.io/L2G-NeRF/) using [pytorch](https://pytorch.org/). \n\n# :computer: Installation\n\n## Hardware\n\n* We implement all experiments on a single NVIDIA GeForce RTX 2080 Ti GPU. \n* L2G-NeRF takes about 4.5 and 8 hours for training in synthetic objects and real-world scenes, respectively, while training BARF takes about 8 and 10.5 hours.\n\n## Software\n\n* Clone this repo by `git clone https://github.com/rover-xingyu/L2G-NeRF`\n* This code is developed with Python3. PyTorch 1.9+ is required. It is recommended use [Anaconda](https://www.anaconda.com/products/individual) to set up the environment, use `conda env create --file requirements.yaml python=3` to install the dependencies and activate it by `conda activate L2G-NeRF`\n--------------------------------------\n\n# :key: Training and Evaluation\n\n## Data download\n\nBoth the Blender synthetic data and LLFF real-world data can be found in the [NeRF Google Drive](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1).\nFor convenience, you can download them with the following script:\n  ```bash\n  # Blender\n  gdown --id 18JxhpWD-4ZmuFKLzKlAw-w5PpzZxXOcG # download nerf_synthetic.zip\n  unzip nerf_synthetic.zip\n  rm -f nerf_synthetic.zip\n  mv nerf_synthetic data/blender\n  # LLFF\n  gdown --id 16VnMcF1KJYxN9QId6TClMsZRahHNMW5g # download nerf_llff_data.zip\n  unzip nerf_llff_data.zip\n  rm -f nerf_llff_data.zip\n  mv nerf_llff_data data/llff\n  ```\n--------------------------------------\n\n## Running L2G-NeRF\nTo train and evaluate L2G-NeRF:\n```bash\n# <GROUP> and <NAME> can be set to your likes, while <SCENE> is specific to datasets\n\n# NeRF (3D): Synthetic Objects\n# Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})\npython3 train.py \\\n--model=l2g_nerf --yaml=l2g_nerf_blender \\\n--group=exp_synthetic --name=l2g_lego \\\n--data.scene=lego --gpu=3 \\\n--data.root=/the/data/path/of/nerf_synthetic/ \\\n--camera.noise_r=0.07 --camera.noise_t=0.5\n\npython3 evaluate.py \\\n--model=l2g_nerf --yaml=l2g_nerf_blender \\\n--group=exp_synthetic --name=l2g_lego \\\n--data.scene=lego --gpu=3 \\\n--data.root=/the/data/path/of/nerf_synthetic/ \\\n--data.val_sub= --resume\n\n# NeRF (3D): Real-World Scenes\n# LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})\npython3 train.py \\\n--model=l2g_nerf --yaml=l2g_nerf_llff \\\n--group=exp_LLFF --name=l2g_fern \\\n--data.scene=fern --gpu=3 \\\n--data.root=/the/data/path/of/nerf_llff_data/ 
--------------------------------------\n\n## Running L2G-NeRF\nTo train and evaluate L2G-NeRF:\n```bash\n# <GROUP> and <NAME> can be set to your liking, while <SCENE> is specific to each dataset\n\n# NeRF (3D): Synthetic Objects\n# Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})\npython3 train.py \\\n--model=l2g_nerf --yaml=l2g_nerf_blender \\\n--group=exp_synthetic --name=l2g_lego \\\n--data.scene=lego --gpu=3 \\\n--data.root=/the/data/path/of/nerf_synthetic/ \\\n--camera.noise_r=0.07 --camera.noise_t=0.5\n\npython3 evaluate.py \\\n--model=l2g_nerf --yaml=l2g_nerf_blender \\\n--group=exp_synthetic --name=l2g_lego \\\n--data.scene=lego --gpu=3 \\\n--data.root=/the/data/path/of/nerf_synthetic/ \\\n--data.val_sub= --resume\n\n# NeRF (3D): Real-World Scenes\n# LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})\npython3 train.py \\\n--model=l2g_nerf --yaml=l2g_nerf_llff \\\n--group=exp_LLFF --name=l2g_fern \\\n--data.scene=fern --gpu=3 \\\n--data.root=/the/data/path/of/nerf_llff_data/ \\\n--loss_weight.global_alignment=2\n\npython3 evaluate.py \\\n--model=l2g_nerf --yaml=l2g_nerf_llff \\\n--group=exp_LLFF --name=l2g_fern \\\n--data.scene=fern --gpu=3 \\\n--data.root=/the/data/path/of/nerf_llff_data/ \\\n--loss_weight.global_alignment=2 \\\n--resume\n\n# Neural image Alignment (2D): Rigid\n# use the image of “Girl With a Pearl Earring” renovation ©Koorosh Orooj (CC BY-SA 4.0) for rigid image alignment\npython3 train.py \\\n--model=l2g_planar --yaml=l2g_planar \\\n--group=exp_planar --name=l2g_girl \\\n--warp.type=rigid --warp.dof=3 \\\n--data.image_fname=data/girl.jpg \\\n--data.image_size=[595,512] \\\n--data.patch_crop=[260,260] \\\n--seed=1 --gpu=3\n\n# Neural image Alignment (2D): Homography\n# use the image of “cat” from ImageNet for homography image alignment\npython3 train.py \\\n--model=l2g_planar --yaml=l2g_planar \\\n--group=exp_planar --name=l2g_cat \\\n--warp.type=homography --warp.dof=8 \\\n--data.image_fname=data/cat.jpg \\\n--data.image_size=[360,480] \\\n--data.patch_crop=[180,180] \\\n--gpu=0\n```\n--------------------------------------\n\n## Running BARF\nIf you want to train and evaluate the BARF extension of the original NeRF model that jointly optimizes poses (coarse-to-fine positional encoding):\n```bash\n# <GROUP> and <NAME> can be set to your liking, while <SCENE> is specific to each dataset\n\n# NeRF (3D): Synthetic Objects\n# Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})\npython3 train.py \\\n--model=barf --yaml=barf_blender \\\n--group=exp_synthetic --name=barf_lego \\\n--data.scene=lego --gpu=1 \\\n--data.root=/the/data/path/of/nerf_synthetic/ \\\n--camera.noise_r=0.07 --camera.noise_t=0.5\n\npython3 evaluate.py \\\n--model=barf --yaml=barf_blender \\\n--group=exp_synthetic --name=barf_lego \\\n--data.scene=lego --gpu=1 \\\n--data.root=/the/data/path/of/nerf_synthetic/ \\\n--data.val_sub= --resume\n\n# NeRF (3D): Real-World Scenes\n# LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})\npython3 train.py \\\n--model=barf --yaml=barf_llff \\\n--group=exp_LLFF --name=barf_fern \\\n--data.scene=fern --gpu=1 \\\n--data.root=/the/data/path/of/nerf_llff_data/\n\npython3 evaluate.py \\\n--model=barf --yaml=barf_llff \\\n--group=exp_LLFF --name=barf_fern \\\n--data.scene=fern --gpu=1 \\\n--data.root=/the/data/path/of/nerf_llff_data/ \\\n--resume\n\n# Neural image Alignment (2D): Rigid\n# use the image of “Girl With a Pearl Earring” renovation ©Koorosh Orooj (CC BY-SA 4.0) for rigid image alignment\npython3 train.py \\\n--model=planar --yaml=planar \\\n--group=exp_planar --name=barf_girl \\\n--warp.type=rigid --warp.dof=3 \\\n--data.image_fname=data/girl.jpg \\\n--data.image_size=[595,512] \\\n--data.patch_crop=[260,260] \\\n--seed=1 --gpu=1\n\n# Neural image Alignment (2D): Homography\n# use the image of “cat” from ImageNet for homography image alignment\npython3 train.py \\\n--model=planar --yaml=planar \\\n--group=exp_planar --name=barf_cat \\\n--warp.type=homography --warp.dof=8 \\\n--data.image_fname=data/cat.jpg \\\n--data.image_size=[360,480] \\\n--data.patch_crop=[180,180] \\\n--gpu=2\n```\n--------------------------------------\n\n## Running Naive\nIf you want to train and evaluate the Naive extension of the original NeRF model that jointly optimizes poses (full positional encoding):\n\n```bash\n# <GROUP> and <NAME> can be set to your liking, while <SCENE> is specific to each dataset\n\n# NeRF (3D): Synthetic Objects\n# Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})\npython3 train.py 
\\\n--model=barf --yaml=barf_blender \\\n--group=exp_synthetic --name=nerf_lego \\\n--data.scene=lego --gpu=2 \\\n--data.root=/the/data/path/of/nerf_synthetic/ \\\n--barf_c2f=null \\\n--camera.noise_r=0.07 --camera.noise_t=0.5\n\npython3 evaluate.py \\\n--model=barf --yaml=barf_blender \\\n--group=exp_synthetic --name=nerf_lego \\\n--data.scene=lego --gpu=2 \\\n--data.root=/the/data/path/of/nerf_synthetic/ \\\n--barf_c2f=null \\\n--data.val_sub= --resume\n\n# NeRF (3D): Real-World Scenes\n# LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})\npython3 train.py \\\n--model=barf --yaml=barf_llff \\\n--group=exp_LLFF --name=nerf_fern \\\n--data.scene=fern --gpu=2 \\\n--data.root=/the/data/path/of/nerf_llff_data/ \\\n--barf_c2f=null\n\npython3 evaluate.py \\\n--model=barf --yaml=barf_llff \\\n--group=exp_LLFF --name=nerf_fern \\\n--data.scene=fern --gpu=2 \\\n--data.root=/the/data/path/of/nerf_llff_data/ \\\n--barf_c2f=null --resume\n\n# Neural image Alignment (2D): Rigid\n# use the image of “Girl With a Pearl Earring” renovation ©Koorosh Orooj (CC BY-SA 4.0) for rigid image alignment\npython3 train.py \\\n--model=planar --yaml=planar \\\n--group=exp_planar --name=naive_girl \\\n--warp.type=rigid --warp.dof=3 \\\n--data.image_fname=data/girl.jpg \\\n--data.image_size=[595,512] \\\n--data.patch_crop=[260,260] \\\n--seed=1 --gpu=2 --barf_c2f=null\n\n# Neural image Alignment (2D): Homography\n# use the image of “cat” from ImageNet for homography image alignment\npython3 train.py \\\n--model=planar --yaml=planar \\\n--group=exp_planar --name=naive_cat \\\n--warp.type=homography --warp.dof=8 \\\n--data.image_fname=data/cat.jpg \\\n--data.image_size=[360,480] \\\n--data.patch_crop=[180,180] \\\n--gpu=3 --barf_c2f=null\n```\n--------------------------------------\n\n## Running reference NeRF\nIf you want to train and evaluate the reference NeRF models (assuming known camera poses):\n```bash\n# <GROUP> and <NAME> can be set to your liking, while <SCENE> is specific to each dataset\n\n# NeRF (3D): Synthetic Objects\n# Blender (<SCENE>={chair,drums,ficus,hotdog,lego,materials,mic,ship})\npython3 train.py \\\n--model=nerf --yaml=nerf_blender \\\n--group=exp_synthetic --name=ref_lego \\\n--data.scene=lego --gpu=0 \\\n--data.root=/the/data/path/of/nerf_synthetic/\n\npython3 evaluate.py \\\n--model=nerf --yaml=nerf_blender \\\n--group=exp_synthetic --name=ref_lego \\\n--data.scene=lego --gpu=0 \\\n--data.root=/the/data/path/of/nerf_synthetic/ \\\n--data.val_sub= --resume\n\n# NeRF (3D): Real-World Scenes\n# LLFF (<SCENE>={fern,flower,fortress,horns,leaves,orchids,room,trex})\npython3 train.py \\\n--model=nerf --yaml=nerf_llff \\\n--group=exp_LLFF --name=ref_fern \\\n--data.scene=fern --gpu=0 \\\n--data.root=/the/data/path/of/nerf_llff_data/\n\npython3 evaluate.py \\\n--model=nerf --yaml=nerf_llff \\\n--group=exp_LLFF --name=ref_fern \\\n--data.scene=fern --gpu=0 \\\n--data.root=/the/data/path/of/nerf_llff_data/ \\\n--resume\n```\n--------------------------------------\n\n# :laughing: Visualization\n\n## Results and Videos\nAll the results will be stored in the directory `output/<GROUP>/<NAME>`.\nYou may want to organize your experiments by grouping different runs in the same group. Many videos will be created to visualize the pose optimization process and novel view synthesis.\n\n## TensorBoard\nThe TensorBoard events include the following:\n- **SCALARS**: the rendering losses and PSNR over the course of optimization. For L2G-NeRF/BARF/Naive, the rotational/translational errors with respect to the given poses are also computed.\n- **IMAGES**: visualization of the RGB images and the RGB/depth rendering.\n
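\nThe event files can be inspected with a standard TensorBoard launch (adjust the path if your outputs live elsewhere):\n```bash\ntensorboard --logdir=output/<GROUP> --port=6006\n```\n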
\n## Visdom\nThe visualization of 3D camera poses is provided in Visdom:\nRun `visdom -port 8600` to start the Visdom server. The Visdom host server defaults to `localhost`; this can be overridden with `--visdom.server` (see `options/base.yaml` for details). If you want to disable Visdom visualization, add `--visdom!`.\n\n## Mesh\nThe `extract_mesh.py` script provides a simple way to extract the underlying 3D geometry using marching cubes (supported for the Blender dataset). Run as follows:\n```bash\npython3 extract_mesh.py \\\n--model=l2g_nerf --yaml=l2g_nerf_blender \\\n--group=exp_synthetic --name=l2g_lego \\\n--data.scene=lego --gpu=3 \\\n--data.root=/the/data/path/of/nerf_synthetic/ \\\n--data.val_sub= --resume\n```\n--------------------------------------\n\n# :mag_right: Codebase structure\n\nThe main engine and network architecture in `model/l2g_nerf.py` inherit those from `model/nerf.py`.\n\nSome tips on using and understanding the codebase:\n- The computation graph for forward/backprop is stored in `var` throughout the codebase.\n- The losses are stored in `loss`. To add a new loss function, just implement it in `compute_loss()` and add its weight to `opt.loss_weight.<name>`. It will automatically be added to the overall loss and logged to TensorBoard (see the sketch below).\n- If you are using a multi-GPU machine, you can set `--gpu=<gpu_number>` to specify which GPU to use. Multi-GPU training/evaluation is currently not supported.\n- To resume from a previous checkpoint, add `--resume=<ITER_NUMBER>`, or just `--resume` to resume from the latest checkpoint.\n- To eliminate the global alignment objective, set `--loss_weight.global_alignment=null`; this ablation is equivalent to a purely local registration method.\n
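\nFor illustration, a hypothetical extra loss term could be wired in roughly as follows (a sketch following the conventions above, not code that ships with this repo; `depth_reg` and the use of `var.depth` are assumptions):\n```python\n# sketch: subclassing the computation graph in model/nerf.py\nclass Graph(nerf.Graph):\n    def compute_loss(self,opt,var,mode=None):\n        loss = super().compute_loss(opt,var,mode=mode)\n        # hypothetical L1 regularizer on the rendered depth;\n        # its weight opt.loss_weight.depth_reg would be added\n        # to the YAML config under loss_weight\n        if opt.loss_weight.depth_reg is not None:\n            loss.depth_reg = var.depth.abs().mean()\n        return loss\n```\n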
\n--------------------------------------\n# Citation\n\nIf you find this project useful for your research, please use the following BibTeX entry.\n\n```bibtex\n@inproceedings{chen2023local,\n  title={Local-to-global registration for bundle-adjusting neural radiance fields},\n  author={Chen, Yue and Chen, Xingyu and Wang, Xuan and Zhang, Qi and Guo, Yu and Shan, Ying and Wang, Fei},\n  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},\n  pages={8264--8273},\n  year={2023}\n}\n```\n--------------------------------------\n\n# Acknowledgements\nOur code is based on the awesome PyTorch implementation of Bundle-Adjusting Neural Radiance Fields ([BARF](https://github.com/chenhsuanlin/bundle-adjusting-NeRF)). We appreciate all the contributors.\n"
  },
  {
    "path": "camera.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport collections\nfrom easydict import EasyDict as edict\n\nimport util\nfrom util import log,debug\n\nclass Pose():\n    \"\"\"\n    A class of operations on camera poses (PyTorch tensors with shape [...,3,4])\n    each [3,4] camera pose takes the form of [R|t]\n    \"\"\"\n\n    def __call__(self,R=None,t=None):\n        # construct a camera pose from the given R and/or t\n        assert(R is not None or t is not None)\n        if R is None:\n            if not isinstance(t,torch.Tensor): t = torch.tensor(t)\n            R = torch.eye(3,device=t.device).repeat(*t.shape[:-1],1,1)\n        elif t is None:\n            if not isinstance(R,torch.Tensor): R = torch.tensor(R)\n            t = torch.zeros(R.shape[:-1],device=R.device)\n        else:\n            if not isinstance(R,torch.Tensor): R = torch.tensor(R)\n            if not isinstance(t,torch.Tensor): t = torch.tensor(t)\n        assert(R.shape[:-1]==t.shape and R.shape[-2:]==(3,3))\n        R = R.float()\n        t = t.float()\n        pose = torch.cat([R,t[...,None]],dim=-1) # [...,3,4]\n        assert(pose.shape[-2:]==(3,4))\n        return pose\n\n    def invert(self,pose,use_inverse=False):\n        # invert a camera pose\n        R,t = pose[...,:3],pose[...,3:]\n        R_inv = R.inverse() if use_inverse else R.transpose(-1,-2)\n        t_inv = (-R_inv@t)[...,0]\n        pose_inv = self(R=R_inv,t=t_inv)\n        return pose_inv\n\n    def compose(self,pose_list):\n        # compose a sequence of poses together\n        # pose_new(x) = poseN o ... o pose2 o pose1(x)\n        pose_new = pose_list[0]\n        for pose in pose_list[1:]:\n            pose_new = self.compose_pair(pose_new,pose)\n        return pose_new\n\n    def compose_pair(self,pose_a,pose_b):\n        # pose_new(x) = pose_b o pose_a(x)\n        R_a,t_a = pose_a[...,:3],pose_a[...,3:]\n        R_b,t_b = pose_b[...,:3],pose_b[...,3:]\n        R_new = R_b@R_a\n        t_new = (R_b@t_a+t_b)[...,0]\n        pose_new = self(R=R_new,t=t_new)\n        return pose_new\n\nclass Lie():\n    \"\"\"\n    Lie algebra for SO(3) and SE(3) operations in PyTorch\n    \"\"\"\n\n    def so3_to_SO3(self,w): # [...,3]\n        wx = self.skew_symmetric(w)\n        theta = w.norm(dim=-1)[...,None,None]\n        I = torch.eye(3,device=w.device,dtype=torch.float32)\n        A = self.taylor_A(theta)\n        B = self.taylor_B(theta)\n        R = I+A*wx+B*wx@wx\n        return R\n\n    def SO3_to_so3(self,R,eps=1e-7): # [...,3,3]\n        trace = R[...,0,0]+R[...,1,1]+R[...,2,2]\n        theta = ((trace-1)/2).clamp(-1+eps,1-eps).acos_()[...,None,None]%np.pi # ln(R) will explode if theta==pi\n        lnR = 1/(2*self.taylor_A(theta)+1e-8)*(R-R.transpose(-2,-1)) # FIXME: wei-chiu finds it weird\n        w0,w1,w2 = lnR[...,2,1],lnR[...,0,2],lnR[...,1,0]\n        w = torch.stack([w0,w1,w2],dim=-1)\n        return w\n\n    def se3_to_SE3(self,wu): # [...,3]\n        w,u = wu.split([3,3],dim=-1)\n        wx = self.skew_symmetric(w)\n        theta = w.norm(dim=-1)[...,None,None]\n        I = torch.eye(3,device=w.device,dtype=torch.float32)\n        A = self.taylor_A(theta)\n        B = self.taylor_B(theta)\n        C = self.taylor_C(theta)\n        R = I+A*wx+B*wx@wx\n        V = I+B*wx+C*wx@wx\n        Rt = torch.cat([R,(V@u[...,None])],dim=-1)\n        return Rt\n\n    def SE3_to_se3(self,Rt,eps=1e-8): # [...,3,4]\n        R,t = Rt.split([3,1],dim=-1)\n        w = self.SO3_to_so3(R)\n   
     wx = self.skew_symmetric(w)\n        theta = w.norm(dim=-1)[...,None,None]\n        I = torch.eye(3,device=w.device,dtype=torch.float32)\n        A = self.taylor_A(theta)\n        B = self.taylor_B(theta)\n        invV = I-0.5*wx+(1-A/(2*B))/(theta**2+eps)*wx@wx\n        u = (invV@t)[...,0]\n        wu = torch.cat([w,u],dim=-1)\n        return wu    \n\n    def skew_symmetric(self,w):\n        w0,w1,w2 = w.unbind(dim=-1)\n        O = torch.zeros_like(w0)\n        wx = torch.stack([torch.stack([O,-w2,w1],dim=-1),\n                          torch.stack([w2,O,-w0],dim=-1),\n                          torch.stack([-w1,w0,O],dim=-1)],dim=-2)\n        return wx\n\n    def taylor_A(self,x,nth=10):\n        # Taylor expansion of sin(x)/x\n        ans = torch.zeros_like(x)\n        denom = 1.\n        for i in range(nth+1):\n            if i>0: denom *= (2*i)*(2*i+1)\n            ans = ans+(-1)**i*x**(2*i)/denom\n        return ans\n    def taylor_B(self,x,nth=10):\n        # Taylor expansion of (1-cos(x))/x**2\n        ans = torch.zeros_like(x)\n        denom = 1.\n        for i in range(nth+1):\n            denom *= (2*i+1)*(2*i+2)\n            ans = ans+(-1)**i*x**(2*i)/denom\n        return ans\n    def taylor_C(self,x,nth=10):\n        # Taylor expansion of (x-sin(x))/x**3\n        ans = torch.zeros_like(x)\n        denom = 1.\n        for i in range(nth+1):\n            denom *= (2*i+2)*(2*i+3)\n            ans = ans+(-1)**i*x**(2*i)/denom\n        return ans\n\nclass Quaternion():\n\n    def q_to_R(self,q):\n        # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion\n        qa,qb,qc,qd = q.unbind(dim=-1)\n        R = torch.stack([torch.stack([1-2*(qc**2+qd**2),2*(qb*qc-qa*qd),2*(qa*qc+qb*qd)],dim=-1),\n                         torch.stack([2*(qb*qc+qa*qd),1-2*(qb**2+qd**2),2*(qc*qd-qa*qb)],dim=-1),\n                         torch.stack([2*(qb*qd-qa*qc),2*(qa*qb+qc*qd),1-2*(qb**2+qc**2)],dim=-1)],dim=-2)\n        return R\n\n    def R_to_q(self,R,eps=1e-8): # [B,3,3]\n        # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion\n        # FIXME: this function seems a bit problematic, need to double-check\n        row0,row1,row2 = R.unbind(dim=-2)\n        R00,R01,R02 = row0.unbind(dim=-1)\n        R10,R11,R12 = row1.unbind(dim=-1)\n        R20,R21,R22 = row2.unbind(dim=-1)\n        t = R[...,0,0]+R[...,1,1]+R[...,2,2]\n        r = (1+t+eps).sqrt()\n        qa = 0.5*r\n        qb = (R21-R12).sign()*0.5*(1+R00-R11-R22+eps).sqrt()\n        qc = (R02-R20).sign()*0.5*(1-R00+R11-R22+eps).sqrt()\n        qd = (R10-R01).sign()*0.5*(1-R00-R11+R22+eps).sqrt()\n        q = torch.stack([qa,qb,qc,qd],dim=-1)\n        for i,qi in enumerate(q):\n            if torch.isnan(qi).any():\n                K = torch.stack([torch.stack([R00-R11-R22,R10+R01,R20+R02,R12-R21],dim=-1),\n                                 torch.stack([R10+R01,R11-R00-R22,R21+R12,R20-R02],dim=-1),\n                                 torch.stack([R20+R02,R21+R12,R22-R00-R11,R01-R10],dim=-1),\n                                 torch.stack([R12-R21,R20-R02,R01-R10,R00+R11+R22],dim=-1)],dim=-2)/3.0\n                K = K[i]\n                eigval,eigvec = torch.linalg.eigh(K)\n                V = eigvec[:,eigval.argmax()]\n                q[i] = torch.stack([V[3],V[0],V[1],V[2]])\n        return q\n\n    def invert(self,q):\n        qa,qb,qc,qd = q.unbind(dim=-1)\n        norm = q.norm(dim=-1,keepdim=True)\n        q_inv = torch.stack([qa,-qb,-qc,-qd],dim=-1)/norm**2\n        return q_inv\n\n    def product(self,q1,q2): # 
[B,4]\n        q1a,q1b,q1c,q1d = q1.unbind(dim=-1)\n        q2a,q2b,q2c,q2d = q2.unbind(dim=-1)\n        hamil_prod = torch.stack([q1a*q2a-q1b*q2b-q1c*q2c-q1d*q2d,\n                                  q1a*q2b+q1b*q2a+q1c*q2d-q1d*q2c,\n                                  q1a*q2c-q1b*q2d+q1c*q2a+q1d*q2b,\n                                  q1a*q2d+q1b*q2c-q1c*q2b+q1d*q2a],dim=-1)\n        return hamil_prod\n\npose = Pose()\nlie = Lie()\nquaternion = Quaternion()\n\ndef to_hom(X):\n    # get homogeneous coordinates of the input\n    X_hom = torch.cat([X,torch.ones_like(X[...,:1])],dim=-1)\n    return X_hom\n\n# basic operations of transforming 3D points between world/camera/image coordinates\ndef world2cam(X,pose): # [B,N,3]\n    X_hom = to_hom(X)\n    return X_hom@pose.transpose(-1,-2)\ndef cam2img(X,cam_intr):\n    return X@cam_intr.transpose(-1,-2)\ndef img2cam(X,cam_intr):\n    return X@cam_intr.inverse().transpose(-1,-2)\ndef cam2world(X,pose):\n    X_hom = to_hom(X)\n    pose_inv = Pose().invert(pose)\n    return X_hom@pose_inv.transpose(-1,-2)\n\ndef angle_to_rotation_matrix(a,axis):\n    # get the rotation matrix from Euler angle around specific axis\n    roll = dict(X=1,Y=2,Z=0)[axis]\n    O = torch.zeros_like(a)\n    I = torch.ones_like(a)\n    M = torch.stack([torch.stack([a.cos(),-a.sin(),O],dim=-1),\n                     torch.stack([a.sin(),a.cos(),O],dim=-1),\n                     torch.stack([O,O,I],dim=-1)],dim=-2)\n    M = M.roll((roll,roll),dims=(-2,-1))\n    return M\n\ndef get_center_and_ray(opt,pose,intr=None): # [HW,2]\n    # given the intrinsic/extrinsic matrices, get the camera center and ray directions]\n    assert(opt.camera.model==\"perspective\")\n    with torch.no_grad():\n        # compute image coordinate grid\n        y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5)\n        x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5)\n        Y,X = torch.meshgrid(y_range,x_range) # [H,W]\n        xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]\n    # compute center and ray\n    batch_size = len(pose)\n    xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2]\n    grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3]\n    center_3D = torch.zeros_like(grid_3D) # [B,HW,3]\n    # transform from camera to world coordinates\n    grid_3D = cam2world(grid_3D,pose) # [B,HW,3]\n    center_3D = cam2world(center_3D,pose) # [B,HW,3]\n    ray = grid_3D-center_3D # [B,HW,3]\n    return center_3D,ray\n\ndef get_camera_cords_grid_3D(opt,batch_size,intr=None,ray_idx=None): # [HW,2]\n    # given the intrinsic matrices, get the grid_3D]\n    assert(opt.camera.model==\"perspective\")\n    with torch.no_grad():\n        # compute image coordinate grid\n        y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5)\n        x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5)\n        Y,X = torch.meshgrid(y_range,x_range) # [H,W]\n        xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]\n        # compute grid_3D\n        xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2]\n        grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3]\n        if ray_idx is not None:\n            # consider only subset of rays\n            grid_3D = grid_3D[:,ray_idx]\n    return grid_3D\n\ndef gather_camera_cords_grid_3D(opt,batch_size,intr=None,ray_idx=None): # [HW,2]\n    # given the intrinsic matrices, get the grid_3D]\n    assert(opt.camera.model==\"perspective\")\n    with torch.no_grad():\n        
# compute image coordinate grid\n        y_range = torch.arange(opt.H,dtype=torch.float32,device=opt.device).add_(0.5)\n        x_range = torch.arange(opt.W,dtype=torch.float32,device=opt.device).add_(0.5)\n        Y,X = torch.meshgrid(y_range,x_range) # [H,W]\n        xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]\n        # compute grid_3D\n        xy_grid = xy_grid.repeat(batch_size,1,1) # [B,HW,2]\n        grid_3D = img2cam(to_hom(xy_grid),intr) # [B,HW,3]\n        if ray_idx is not None:\n            # consider only subset of rays\n            grid_3D = torch.gather(grid_3D, 1, ray_idx[...,None].expand(-1,-1,3))\n    return grid_3D\n\ndef get_3D_points_from_depth(opt,center,ray,depth,multi_samples=False):\n    if multi_samples: center,ray = center[:,:,None],ray[:,:,None]\n    # x = c+dv\n    points_3D = center+ray*depth # [B,HW,3]/[B,HW,N,3]/[N,3]\n    return points_3D\n\ndef convert_NDC(opt,center,ray,intr,near=1):\n    # shift camera center (ray origins) to near plane (z=1)\n    # (unlike conventional NDC, we assume the cameras are facing towards the +z direction)\n    center = center+(near-center[...,2:])/ray[...,2:]*ray\n    # projection\n    cx,cy,cz = center.unbind(dim=-1) # [B,HW]\n    rx,ry,rz = ray.unbind(dim=-1) # [B,HW]\n    scale_x = intr[:,0,0]/intr[:,0,2] # [B]\n    scale_y = intr[:,1,1]/intr[:,1,2] # [B]\n    cnx = scale_x[:,None]*(cx/cz)\n    cny = scale_y[:,None]*(cy/cz)\n    cnz = 1-2*near/cz\n    rnx = scale_x[:,None]*(rx/rz-cx/cz)\n    rny = scale_y[:,None]*(ry/rz-cy/cz)\n    rnz = 2*near/cz\n    center_ndc = torch.stack([cnx,cny,cnz],dim=-1) # [B,HW,3]\n    ray_ndc = torch.stack([rnx,rny,rnz],dim=-1) # [B,HW,3]\n    return center_ndc,ray_ndc\n\ndef rotation_distance(R1,R2,eps=1e-7):\n    # http://www.boris-belousov.net/2016/12/01/quat-dist/\n    R_diff = R1@R2.transpose(-2,-1)\n    trace = R_diff[...,0,0]+R_diff[...,1,1]+R_diff[...,2,2]\n    angle = ((trace-1)/2).clamp(-1+eps,1-eps).acos_() # numerical stability near -1/+1\n    return angle\n\ndef procrustes_analysis(X0,X1): # [N,3]\n    # translation\n    t0 = X0.mean(dim=0,keepdim=True)\n    t1 = X1.mean(dim=0,keepdim=True)\n    X0c = X0-t0\n    X1c = X1-t1\n    # scale\n    s0 = (X0c**2).sum(dim=-1).mean().sqrt()\n    s1 = (X1c**2).sum(dim=-1).mean().sqrt()\n    X0cs = X0c/s0\n    X1cs = X1c/s1\n    # rotation (use double for SVD, float loses precision)\n    U,S,V = (X0cs.t()@X1cs).double().svd(some=True)\n    R = (U@V.t()).float()\n    if R.det()<0: R[2] *= -1\n    # align X1 to X0: X1to0 = (X1-t1)/s1@R.t()*s0+t0\n    sim3 = edict(t0=t0[0],t1=t1[0],s0=s0,s1=s1,R=R)\n    return sim3\n\ndef get_novel_view_poses(opt,pose_anchor,N=60,scale=1):\n    # create circular viewpoints (small oscillations)\n    theta = torch.arange(N)/N*2*np.pi\n    R_x = angle_to_rotation_matrix((theta.sin()*0.05).asin(),\"X\")\n    R_y = angle_to_rotation_matrix((theta.cos()*0.05).asin(),\"Y\")\n    pose_rot = pose(R=R_y@R_x)\n    pose_shift = pose(t=[0,0,-4*scale])\n    pose_shift2 = pose(t=[0,0,3.8*scale])\n    pose_oscil = pose.compose([pose_shift,pose_rot,pose_shift2])\n    pose_novel = pose.compose([pose_oscil,pose_anchor.cpu()[None]])\n    return pose_novel\n"
  },
  {
    "path": "data/base.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torchvision.transforms.functional as torchvision_F\nimport torch.multiprocessing as mp\nimport PIL\nimport tqdm\nimport threading,queue\nfrom easydict import EasyDict as edict\n\nimport util\nfrom util import log,debug\n\nclass Dataset(torch.utils.data.Dataset):\n\n    def __init__(self,opt,split=\"train\"):\n        super().__init__()\n        self.opt = opt\n        self.split = split\n        self.augment = split==\"train\" and opt.data.augment\n        # define image sizes\n        if opt.data.center_crop is not None:\n            self.crop_H = int(self.raw_H*opt.data.center_crop)\n            self.crop_W = int(self.raw_W*opt.data.center_crop)\n        else: self.crop_H,self.crop_W = self.raw_H,self.raw_W\n        if not opt.H or not opt.W:\n            opt.H,opt.W = self.crop_H,self.crop_W\n\n    def setup_loader(self,opt,shuffle=False,drop_last=False):\n        loader = torch.utils.data.DataLoader(self,\n            batch_size=opt.batch_size or 1,\n            num_workers=opt.data.num_workers,\n            shuffle=shuffle,\n            drop_last=drop_last,\n            pin_memory=False, # spews warnings in PyTorch 1.9 but should be True in general\n        )\n        print(\"number of samples: {}\".format(len(self)))\n        return loader\n\n    def get_list(self,opt):\n        raise NotImplementedError\n\n    def preload_worker(self,data_list,load_func,q,lock,idx_tqdm):\n        while True:\n            idx = q.get()\n            data_list[idx] = load_func(self.opt,idx)\n            with lock:\n                idx_tqdm.update()\n            q.task_done()\n\n    def preload_threading(self,opt,load_func,data_str=\"images\"):\n        data_list = [None]*len(self)\n        q = queue.Queue(maxsize=len(self))\n        idx_tqdm = tqdm.tqdm(range(len(self)),desc=\"preloading {}\".format(data_str),leave=False)\n        for i in range(len(self)): q.put(i)\n        lock = threading.Lock()\n        for ti in range(opt.data.num_workers):\n            t = threading.Thread(target=self.preload_worker,\n                                 args=(data_list,load_func,q,lock,idx_tqdm),daemon=True)\n            t.start()\n        q.join()\n        idx_tqdm.close()\n        assert(all(map(lambda x: x is not None,data_list)))\n        return data_list\n\n    def __getitem__(self,idx):\n        raise NotImplementedError\n\n    def get_image(self,opt,idx):\n        raise NotImplementedError\n\n    def generate_augmentation(self,opt):\n        brightness = opt.data.augment.brightness or 0.\n        contrast = opt.data.augment.contrast or 0.\n        saturation = opt.data.augment.saturation or 0.\n        hue = opt.data.augment.hue or 0.\n        color_jitter = torchvision.transforms.ColorJitter.get_params(\n            brightness=(1-brightness,1+brightness),\n            contrast=(1-contrast,1+contrast),\n            saturation=(1-saturation,1+saturation),\n            hue=(-hue,hue),\n        )\n        aug = edict(\n            color_jitter=color_jitter,\n            flip=np.random.randn()>0 if opt.data.augment.hflip else False,\n            rot_angle=(np.random.rand()*2-1)*opt.data.augment.rotate if opt.data.augment.rotate else 0,\n        )\n        return aug\n\n    def preprocess_image(self,opt,image,aug=None):\n        if aug is not None:\n            image = self.apply_color_jitter(opt,image,aug.color_jitter)\n            image = torchvision_F.hflip(image) if aug.flip 
else image\n            image = image.rotate(aug.rot_angle,resample=PIL.Image.BICUBIC)\n        # center crop\n        if opt.data.center_crop is not None:\n            self.crop_H = int(self.raw_H*opt.data.center_crop)\n            self.crop_W = int(self.raw_W*opt.data.center_crop)\n            image = torchvision_F.center_crop(image,(self.crop_H,self.crop_W))\n        else: self.crop_H,self.crop_W = self.raw_H,self.raw_W\n        # resize\n        if opt.data.image_size[0] is not None:\n            image = image.resize((opt.W,opt.H))\n        image = torchvision_F.to_tensor(image)\n        return image\n\n    def preprocess_camera(self,opt,intr,pose,aug=None):\n        intr,pose = intr.clone(),pose.clone()\n        # center crop\n        intr[0,2] -= (self.raw_W-self.crop_W)/2\n        intr[1,2] -= (self.raw_H-self.crop_H)/2\n        # resize\n        intr[0] *= opt.W/self.crop_W\n        intr[1] *= opt.H/self.crop_H\n        return intr,pose\n\n    def apply_color_jitter(self,opt,image,color_jitter):\n        mode = image.mode\n        if mode!=\"L\":\n            chan = image.split()\n            rgb = PIL.Image.merge(\"RGB\",chan[:3])\n            rgb = color_jitter(rgb)\n            rgb_chan = rgb.split()\n            image = PIL.Image.merge(mode,rgb_chan+chan[3:])\n        return image\n\n    def __len__(self):\n        return len(self.list)\n"
  },
  {
    "path": "data/blender.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torchvision.transforms.functional as torchvision_F\nimport PIL\nimport imageio\nfrom easydict import EasyDict as edict\nimport json\nimport pickle\n\nfrom . import base\nimport camera\nfrom util import log,debug\n\nclass Dataset(base.Dataset):\n\n    def __init__(self,opt,split=\"train\",subset=None):\n        self.raw_H,self.raw_W = 800,800\n        super().__init__(opt,split)\n        self.root = opt.data.root or \"data/blender\"\n        self.path = \"{}/{}\".format(self.root,opt.data.scene)\n        # load/parse metadata\n        meta_fname = \"{}/transforms_{}.json\".format(self.path,split)\n        with open(meta_fname) as file:\n            self.meta = json.load(file)\n        self.list = self.meta[\"frames\"]\n        self.focal = 0.5*self.raw_W/np.tan(0.5*self.meta[\"camera_angle_x\"])\n        if subset: self.list = self.list[:subset]\n        # preload dataset\n        if opt.data.preload:\n            self.images = self.preload_threading(opt,self.get_image)\n            self.cameras = self.preload_threading(opt,self.get_camera,data_str=\"cameras\")\n\n    def prefetch_all_data(self,opt):\n        assert(not opt.data.augment)\n        # pre-iterate through all samples and group together\n        self.all = torch.utils.data._utils.collate.default_collate([s for s in self])\n\n    def get_all_camera_poses(self,opt):\n        pose_raw_all = [torch.tensor(f[\"transform_matrix\"],dtype=torch.float32) for f in self.list]\n        pose_canon_all = torch.stack([self.parse_raw_camera(opt,p) for p in pose_raw_all],dim=0)\n        return pose_canon_all\n\n    def __getitem__(self,idx):\n        opt = self.opt\n        sample = dict(idx=idx)\n        aug = self.generate_augmentation(opt) if self.augment else None\n        image = self.images[idx] if opt.data.preload else self.get_image(opt,idx)\n        image = self.preprocess_image(opt,image,aug=aug)\n        intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx)\n        intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug)\n        sample.update(\n            image=image,\n            intr=intr,\n            pose=pose,\n        )\n        return sample\n\n    def get_image(self,opt,idx):\n        image_fname = \"{}/{}.png\".format(self.path,self.list[idx][\"file_path\"])\n        image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption....\n        return image\n\n    def preprocess_image(self,opt,image,aug=None):\n        image = super().preprocess_image(opt,image,aug=aug)\n        rgb,mask = image[:3],image[3:]\n        if opt.data.bgcolor is not None:\n            rgb = rgb*mask+opt.data.bgcolor*(1-mask)\n        return rgb\n\n    def get_camera(self,opt,idx):\n        intr = torch.tensor([[self.focal,0,self.raw_W/2],\n                             [0,self.focal,self.raw_H/2],\n                             [0,0,1]]).float()\n        pose_raw = torch.tensor(self.list[idx][\"transform_matrix\"],dtype=torch.float32)\n        pose = self.parse_raw_camera(opt,pose_raw)\n        return intr,pose\n\n    def parse_raw_camera(self,opt,pose_raw):\n        pose_flip = camera.pose(R=torch.diag(torch.tensor([1,-1,-1])))\n        pose = camera.pose.compose([pose_flip,pose_raw[:3]])\n        pose = camera.pose.invert(pose)\n        return pose\n"
  },
  {
    "path": "data/iphone.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torchvision.transforms.functional as torchvision_F\nimport PIL\nimport imageio\nfrom easydict import EasyDict as edict\nimport json\nimport pickle\n\nfrom . import base\nimport camera\nfrom util import log,debug\n\nclass Dataset(base.Dataset):\n\n    def __init__(self,opt,split=\"train\",subset=None):\n        self.raw_H,self.raw_W = 1080,1920\n        super().__init__(opt,split)\n        self.root = opt.data.root or \"data/iphone\"\n        self.path = \"{}/{}\".format(self.root,opt.data.scene)\n        self.path_image = \"{}/images\".format(self.path)\n        # self.list = sorted(os.listdir(self.path_image),key=lambda f: int(f.split(\".\")[0]))\n        self.list = os.listdir(self.path_image)\n        # manually split train/val subsets\n        num_val_split = int(len(self)*opt.data.val_ratio)\n        self.list = self.list[:-num_val_split] if split==\"train\" else self.list[-num_val_split:]\n        if subset: self.list = self.list[:subset]\n        # preload dataset\n        if opt.data.preload:\n            self.images = self.preload_threading(opt,self.get_image)\n            self.cameras = self.preload_threading(opt,self.get_camera,data_str=\"cameras\")\n\n    def prefetch_all_data(self,opt):\n        assert(not opt.data.augment)\n        # pre-iterate through all samples and group together\n        self.all = torch.utils.data._utils.collate.default_collate([s for s in self])\n\n    def get_all_camera_poses(self,opt):\n        # poses are unknown, so just return some dummy poses (identity transform)\n        return camera.pose(t=torch.zeros(len(self),3))\n\n    def __getitem__(self,idx):\n        opt = self.opt\n        sample = dict(idx=idx)\n        aug = self.generate_augmentation(opt) if self.augment else None\n        image = self.images[idx] if opt.data.preload else self.get_image(opt,idx)\n        image = self.preprocess_image(opt,image,aug=aug)\n        intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx)\n        intr,pose = self.preprocess_camera(opt,intr,pose,aug=aug)\n        sample.update(\n            image=image,\n            intr=intr,\n            pose=pose,\n        )\n        return sample\n\n    def get_image(self,opt,idx):\n        image_fname = \"{}/{}\".format(self.path_image,self.list[idx])\n        image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption....\n        return image\n\n    def get_camera(self,opt,idx):\n        self.focal = self.raw_W*4.2/(12.8/2.55)\n        intr = torch.tensor([[self.focal,0,self.raw_W/2],\n                             [0,self.focal,self.raw_H/2],\n                             [0,0,1]]).float()\n        pose = camera.pose(t=torch.zeros(3)) # dummy pose, won't be used\n        return intr,pose\n"
  },
  {
    "path": "data/llff.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torchvision.transforms.functional as torchvision_F\nimport PIL\nimport imageio\nfrom easydict import EasyDict as edict\nimport json\nimport pickle\n\nfrom . import base\nimport camera\nfrom util import log,debug\n\nclass Dataset(base.Dataset):\n\n    def __init__(self,opt,split=\"train\",subset=None):\n        self.raw_H,self.raw_W = 3024,4032\n        super().__init__(opt,split)\n        self.root = opt.data.root or \"data/llff\"\n        self.path = \"{}/{}\".format(self.root,opt.data.scene)\n        self.path_image = \"{}/images\".format(self.path)\n        image_fnames = sorted(os.listdir(self.path_image))\n        poses_raw,bounds = self.parse_cameras_and_bounds(opt)\n        self.list = list(zip(image_fnames,poses_raw,bounds))\n        # manually split train/val subsets\n        num_val_split = int(len(self)*opt.data.val_ratio)\n        self.list = self.list[:-num_val_split] if split==\"train\" else self.list[-num_val_split:]\n        if subset: self.list = self.list[:subset]\n        # preload dataset\n        if opt.data.preload:\n            self.images = self.preload_threading(opt,self.get_image)\n            self.cameras = self.preload_threading(opt,self.get_camera,data_str=\"cameras\")\n\n    def prefetch_all_data(self,opt):\n        assert(not opt.data.augment)\n        # pre-iterate through all samples and group together\n        self.all = torch.utils.data._utils.collate.default_collate([s for s in self])\n\n    def parse_cameras_and_bounds(self,opt):\n        fname = \"{}/poses_bounds.npy\".format(self.path)\n        data = torch.tensor(np.load(fname),dtype=torch.float32)\n        # parse cameras (intrinsics and poses)\n        cam_data = data[:,:-2].view([-1,3,5]) # [N,3,5]\n        poses_raw = cam_data[...,:4] # [N,3,4]\n        poses_raw[...,0],poses_raw[...,1] = poses_raw[...,1],-poses_raw[...,0]\n        raw_H,raw_W,self.focal = cam_data[0,:,-1]\n        assert(self.raw_H==raw_H and self.raw_W==raw_W)\n        # parse depth bounds\n        bounds = data[:,-2:] # [N,2]\n        scale = 1./(bounds.min()*0.75) # not sure how this was determined\n        poses_raw[...,3] *= scale\n        bounds *= scale\n        # roughly center camera poses\n        poses_raw = self.center_camera_poses(opt,poses_raw)\n        return poses_raw,bounds\n\n    def center_camera_poses(self,opt,poses):\n        # compute average pose\n        center = poses[...,3].mean(dim=0)\n        v1 = torch_F.normalize(poses[...,1].mean(dim=0),dim=0)\n        v2 = torch_F.normalize(poses[...,2].mean(dim=0),dim=0)\n        v0 = v1.cross(v2)\n        pose_avg = torch.stack([v0,v1,v2,center],dim=-1)[None] # [1,3,4]\n        # apply inverse of averaged pose\n        poses = camera.pose.compose([poses,camera.pose.invert(pose_avg)])\n        return poses\n\n    def get_all_camera_poses(self,opt):\n        pose_raw_all = [tup[1] for tup in self.list]\n        pose_all = torch.stack([self.parse_raw_camera(opt,p) for p in pose_raw_all],dim=0)\n        return pose_all\n\n    def __getitem__(self,idx):\n        opt = self.opt\n        sample = dict(idx=idx)\n        aug = self.generate_augmentation(opt) if self.augment else None\n        image = self.images[idx] if opt.data.preload else self.get_image(opt,idx)\n        image = self.preprocess_image(opt,image,aug=aug)\n        intr,pose = self.cameras[idx] if opt.data.preload else self.get_camera(opt,idx)\n        intr,pose = 
self.preprocess_camera(opt,intr,pose,aug=aug)\n        sample.update(\n            image=image,\n            intr=intr,\n            pose=pose,\n        )\n        return sample\n\n    def get_image(self,opt,idx):\n        image_fname = \"{}/{}\".format(self.path_image,self.list[idx][0])\n        image = PIL.Image.fromarray(imageio.imread(image_fname)) # directly using PIL.Image.open() leads to weird corruption....\n        return image\n\n    def get_camera(self,opt,idx):\n        intr = torch.tensor([[self.focal,0,self.raw_W/2],\n                             [0,self.focal,self.raw_H/2],\n                             [0,0,1]]).float()\n        pose_raw = self.list[idx][1]\n        pose = self.parse_raw_camera(opt,pose_raw)\n        return intr,pose\n\n    def parse_raw_camera(self,opt,pose_raw):\n        pose_flip = camera.pose(R=torch.diag(torch.tensor([1,-1,-1])))\n        pose = camera.pose.compose([pose_flip,pose_raw[:3]])\n        pose = camera.pose.invert(pose)\n        pose = camera.pose.compose([pose_flip,pose])\n        return pose\n"
  },
  {
    "path": "evaluate.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport importlib\n\nimport options\nfrom util import log\n\nimport warnings\nwarnings.filterwarnings('ignore')\n\ndef main():\n\n    log.process(os.getpid())\n    log.title(\"[{}] (PyTorch code for evaluating NeRF/BARF/L2G_NeRF)\".format(sys.argv[0]))\n\n    opt_cmd = options.parse_arguments(sys.argv[1:])\n    opt = options.set(opt_cmd=opt_cmd)\n\n    with torch.cuda.device(opt.device):\n\n        model = importlib.import_module(\"model.{}\".format(opt.model))\n        m = model.Model(opt)\n\n        m.load_dataset(opt,eval_split=\"test\")\n        m.build_networks(opt)\n\n        if opt.model in [\"barf\", \"l2g_nerf\"]:\n            m.generate_videos_pose(opt)\n\n        m.restore_checkpoint(opt)\n        if opt.data.dataset in [\"blender\",\"llff\"]:\n            m.evaluate_full(opt)\n        m.generate_videos_synthesis(opt)\n\nif __name__==\"__main__\":\n    main()\n"
  },
  {
    "path": "external/pohsun_ssim/LICENSE.txt",
    "content": "MIT\n"
  },
  {
    "path": "external/pohsun_ssim/README.md",
    "content": "# pytorch-ssim\n\n### Differentiable structural similarity (SSIM) index.\n![einstein](https://raw.githubusercontent.com/Po-Hsun-Su/pytorch-ssim/master/einstein.png) ![Max_ssim](https://raw.githubusercontent.com/Po-Hsun-Su/pytorch-ssim/master/max_ssim.gif)\n\n## Installation\n1. Clone this repo.\n2. Copy \"pytorch_ssim\" folder in your project.\n\n## Example\n### basic usage\n```python\nimport pytorch_ssim\nimport torch\nfrom torch.autograd import Variable\n\nimg1 = Variable(torch.rand(1, 1, 256, 256))\nimg2 = Variable(torch.rand(1, 1, 256, 256))\n\nif torch.cuda.is_available():\n    img1 = img1.cuda()\n    img2 = img2.cuda()\n\nprint(pytorch_ssim.ssim(img1, img2))\n\nssim_loss = pytorch_ssim.SSIM(window_size = 11)\n\nprint(ssim_loss(img1, img2))\n\n```\n### maximize ssim\n```python\nimport pytorch_ssim\nimport torch\nfrom torch.autograd import Variable\nfrom torch import optim\nimport cv2\nimport numpy as np\n\nnpImg1 = cv2.imread(\"einstein.png\")\n\nimg1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0\nimg2 = torch.rand(img1.size())\n\nif torch.cuda.is_available():\n    img1 = img1.cuda()\n    img2 = img2.cuda()\n\n\nimg1 = Variable( img1,  requires_grad=False)\nimg2 = Variable( img2, requires_grad = True)\n\n\n# Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True)\nssim_value = pytorch_ssim.ssim(img1, img2).data[0]\nprint(\"Initial ssim:\", ssim_value)\n\n# Module: pytorch_ssim.SSIM(window_size = 11, size_average = True)\nssim_loss = pytorch_ssim.SSIM()\n\noptimizer = optim.Adam([img2], lr=0.01)\n\nwhile ssim_value < 0.95:\n    optimizer.zero_grad()\n    ssim_out = -ssim_loss(img1, img2)\n    ssim_value = - ssim_out.data[0]\n    print(ssim_value)\n    ssim_out.backward()\n    optimizer.step()\n\n```\n\n## Reference\nhttps://ece.uwaterloo.ca/~z70wang/research/ssim/\n"
  },
  {
    "path": "external/pohsun_ssim/max_ssim.py",
    "content": "import pytorch_ssim\nimport torch\nfrom torch.autograd import Variable\nfrom torch import optim\nimport cv2\nimport numpy as np\n\nnpImg1 = cv2.imread(\"einstein.png\")\n\nimg1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0)/255.0\nimg2 = torch.rand(img1.size())\n\nif torch.cuda.is_available():\n    img1 = img1.cuda()\n    img2 = img2.cuda()\n\n\nimg1 = Variable( img1,  requires_grad=False)\nimg2 = Variable( img2, requires_grad = True)\n\n\n# Functional: pytorch_ssim.ssim(img1, img2, window_size = 11, size_average = True)\nssim_value = pytorch_ssim.ssim(img1, img2).data[0]\nprint(\"Initial ssim:\", ssim_value)\n\n# Module: pytorch_ssim.SSIM(window_size = 11, size_average = True)\nssim_loss = pytorch_ssim.SSIM()\n\noptimizer = optim.Adam([img2], lr=0.01)\n\nwhile ssim_value < 0.95:\n    optimizer.zero_grad()\n    ssim_out = -ssim_loss(img1, img2)\n    ssim_value = - ssim_out.data[0]\n    print(ssim_value)\n    ssim_out.backward()\n    optimizer.step()\n"
  },
  {
    "path": "external/pohsun_ssim/pytorch_ssim/__init__.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nimport numpy as np\nfrom math import exp\n\ndef gaussian(window_size, sigma):\n    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])\n    return gauss/gauss.sum()\n\ndef create_window(window_size, channel):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())\n    return window\n\ndef _ssim(img1, img2, window, window_size, channel, size_average = True):\n    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)\n    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1*mu2\n\n    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq\n    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq\n    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2\n\n    C1 = 0.01**2\n    C2 = 0.03**2\n\n    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))\n\n    if size_average:\n        return ssim_map.mean()\n    else:\n        return ssim_map.mean(1).mean(1).mean(1)\n\nclass SSIM(torch.nn.Module):\n    def __init__(self, window_size = 11, size_average = True):\n        super(SSIM, self).__init__()\n        self.window_size = window_size\n        self.size_average = size_average\n        self.channel = 1\n        self.window = create_window(window_size, self.channel)\n\n    def forward(self, img1, img2):\n        (_, channel, _, _) = img1.size()\n\n        if channel == self.channel and self.window.data.type() == img1.data.type():\n            window = self.window\n        else:\n            window = create_window(self.window_size, channel)\n            \n            if img1.is_cuda:\n                window = window.cuda(img1.get_device())\n            window = window.type_as(img1)\n            \n            self.window = window\n            self.channel = channel\n\n\n        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)\n\ndef ssim(img1, img2, window_size = 11, size_average = True):\n    (_, channel, _, _) = img1.size()\n    window = create_window(window_size, channel)\n    \n    if img1.is_cuda:\n        window = window.cuda(img1.get_device())\n    window = window.type_as(img1)\n    \n    return _ssim(img1, img2, window, window_size, channel, size_average)\n"
  },
  {
    "path": "external/pohsun_ssim/setup.cfg",
    "content": "[metadata]\ndescription-file = README.md\n"
  },
  {
    "path": "external/pohsun_ssim/setup.py",
    "content": "from distutils.core import setup\nsetup(\n  name = 'pytorch_ssim',\n  packages = ['pytorch_ssim'], # this must be the same as the name above\n  version = '0.1',\n  description = 'Differentiable structural similarity (SSIM) index',\n  author = 'Po-Hsun (Evan) Su',\n  author_email = 'evan.pohsun.su@gmail.com',\n  url = 'https://github.com/Po-Hsun-Su/pytorch-ssim', # use the URL to the github repo\n  download_url = 'https://github.com/Po-Hsun-Su/pytorch-ssim/archive/0.1.tar.gz', # I'll explain this in a second\n  keywords = ['pytorch', 'image-processing', 'deep-learning'], # arbitrary keywords\n  classifiers = [],\n)\n"
  },
  {
    "path": "extract_mesh.py",
    "content": "\"\"\"Extracts a 3D mesh from a pretrained model using marching cubes.\"\"\"\n\nimport importlib\nimport sys\n\nimport numpy as np\nimport options\nimport torch\nimport tqdm\nimport trimesh\nimport mcubes\n\nfrom util import log,debug\n\nopt_cmd = options.parse_arguments(sys.argv[1:])\nopt = options.set(opt_cmd=opt_cmd)\n\nwith torch.cuda.device(opt.device),torch.no_grad():\n\n    model = importlib.import_module(\"model.{}\".format(opt.model))\n    m = model.Model(opt)\n\n    m.load_dataset(opt)\n    m.build_networks(opt)\n    m.restore_checkpoint(opt)\n\n    t = torch.linspace(*opt.trimesh.range,opt.trimesh.res+1) # the best range might vary from model to model\n    query = torch.stack(torch.meshgrid(t,t,t),dim=-1)\n    query_flat = query.view(-1,3)\n\n    density_all = []\n    for i in tqdm.trange(0,len(query_flat),opt.trimesh.chunk_size,leave=False):\n        points = query_flat[None,i:i+opt.trimesh.chunk_size].to(opt.device)\n        ray_unit = torch.zeros_like(points) # dummy ray to comply with interface, not used\n        _,density_samples = m.graph.nerf.forward(opt,points,ray_unit=ray_unit,mode=None)\n        density_all.append(density_samples.cpu())\n    density_all = torch.cat(density_all,dim=1)[0]\n    density_all = density_all.view(*query.shape[:-1]).numpy()\n\n    log.info(\"running marching cubes...\")\n    vertices,triangles = mcubes.marching_cubes(density_all,opt.trimesh.thres)\n    vertices_centered = vertices/opt.trimesh.res-0.5\n    mesh = trimesh.Trimesh(vertices_centered,triangles)\n\n    obj_fname = \"{}/mesh.obj\".format(opt.output_path)\n    log.info(\"saving 3D mesh to {}...\".format(obj_fname))\n    mesh.export(obj_fname)\n"
"
  },
  {
    "path": "model/barf.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torchvision.transforms.functional as torchvision_F\nimport tqdm\nfrom easydict import EasyDict as edict\nimport visdom\nimport matplotlib.pyplot as plt\n\nimport util,util_vis\nfrom util import log,debug\nfrom . import nerf\nimport camera\n\n# ============================ main engine for training and evaluation ============================\n\nclass Model(nerf.Model):\n\n    def __init__(self,opt):\n        super().__init__(opt)\n\n    def build_networks(self,opt):\n        super().build_networks(opt)\n        if opt.camera.noise:\n            # pre-generate synthetic pose perturbation\n            so3_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_r\n            t_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_t\n            self.graph.pose_noise = torch.cat([camera.lie.so3_to_SO3(so3_noise),t_noise[...,None]],dim=-1) # [...,3,4]\n            \n        self.graph.se3_refine = torch.nn.Embedding(len(self.train_data),6).to(opt.device)\n        torch.nn.init.zeros_(self.graph.se3_refine.weight)\n\n        pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)\n        # add synthetic pose perturbation to all training data\n        if opt.data.dataset==\"blender\":\n            pose = pose_GT\n            if opt.camera.noise:\n                pose = camera.pose.compose([pose, self.graph.pose_noise])\n        else: pose = self.graph.pose_eye[None].repeat(len(self.train_data),1,1)\n        # use Embedding so it could be checkpointed\n        self.graph.optimised_training_poses = torch.nn.Embedding(len(self.train_data),12,_weight=pose.view(-1,12)).to(opt.device) \n\n        idx_range = torch.arange(len(self.train_data),dtype=torch.long,device=opt.device)\n        idx_X,idx_Y = torch.meshgrid(idx_range,idx_range)\n        self.graph.idx_grid = torch.stack([idx_X,idx_Y],dim=-1).view(-1,2)\n\n    def setup_optimizer(self,opt):\n        super().setup_optimizer(opt)\n        optimizer = getattr(torch.optim,opt.optim.algo)\n        self.optim_pose = optimizer([dict(params=self.graph.se3_refine.parameters(),lr=opt.optim.lr_pose)])\n        # set up scheduler\n        if opt.optim.sched_pose:\n            scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched_pose.type)\n            if opt.optim.lr_pose_end:\n                assert(opt.optim.sched_pose.type==\"ExponentialLR\")\n                opt.optim.sched_pose.gamma = (opt.optim.lr_pose_end/opt.optim.lr_pose)**(1./opt.max_iter)\n            kwargs = { k:v for k,v in opt.optim.sched_pose.items() if k!=\"type\" }\n            self.sched_pose = scheduler(self.optim_pose,**kwargs)\n\n    def train_iteration(self,opt,var,loader):\n        self.optim_pose.zero_grad()\n        if opt.optim.warmup_pose:\n            # simple linear warmup of pose learning rate\n            self.optim_pose.param_groups[0][\"lr_orig\"] = self.optim_pose.param_groups[0][\"lr\"] # cache the original learning rate\n            self.optim_pose.param_groups[0][\"lr\"] *= min(1,self.it/opt.optim.warmup_pose)\n        loss = super().train_iteration(opt,var,loader)\n        self.optim_pose.step()\n        if opt.optim.warmup_pose:\n            self.optim_pose.param_groups[0][\"lr\"] = self.optim_pose.param_groups[0][\"lr_orig\"] # reset learning rate\n        if opt.optim.sched_pose: self.sched_pose.step()\n        
self.graph.nerf.progress.data.fill_(self.it/opt.max_iter)\n        if opt.nerf.fine_sampling:\n            self.graph.nerf_fine.progress.data.fill_(self.it/opt.max_iter)\n        return loss\n\n    @torch.no_grad()\n    def validate(self,opt,ep=None):\n        pose,pose_GT = self.get_all_training_poses(opt)\n        _,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT)\n        super().validate(opt,ep=ep)\n\n    @torch.no_grad()\n    def log_scalars(self,opt,var,loss,metric=None,step=0,split=\"train\"):\n        super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)\n        if split==\"train\":\n            # log learning rate\n            lr = self.optim_pose.param_groups[0][\"lr\"]\n            self.tb.add_scalar(\"{0}/{1}\".format(split,\"lr_pose\"),lr,step)\n        # compute pose error\n        if split==\"train\" and opt.data.dataset in [\"blender\",\"llff\"]:\n            pose,pose_GT = self.get_all_training_poses(opt)\n            pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT)\n            error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT)\n            self.tb.add_scalar(\"{0}/error_R\".format(split),error.R.mean(),step)\n            self.tb.add_scalar(\"{0}/error_t\".format(split),error.t.mean(),step)\n\n    @torch.no_grad()\n    def visualize(self,opt,var,step=0,split=\"train\"):\n        super().visualize(opt,var,step=step,split=split)\n        if opt.visdom:\n            if split==\"val\":\n                pose,pose_GT = self.get_all_training_poses(opt)\n                pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT)\n                util_vis.vis_cameras(opt,self.vis,step=step,poses=[pose_aligned,pose_GT])\n\n    @torch.no_grad()\n    def get_all_training_poses(self,opt):\n        # get ground-truth (canonical) camera poses\n        pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)\n        pose = self.graph.optimised_training_poses.weight.data.detach().clone().view(-1,3,4)\n        return pose,pose_GT\n\n    @torch.no_grad()\n    def prealign_cameras(self,opt,pose,pose_GT):\n        # compute 3D similarity transform via Procrustes analysis\n        center = torch.zeros(1,1,3,device=opt.device)\n        center_pred = camera.cam2world(center,pose)[:,0] # [N,3]\n        center_GT = camera.cam2world(center,pose_GT)[:,0] # [N,3]\n        try:\n            sim3 = camera.procrustes_analysis(center_GT,center_pred)\n        except:\n            print(\"warning: SVD did not converge...\")\n            sim3 = edict(t0=0,t1=0,s0=1,s1=1,R=torch.eye(3,device=opt.device))\n        # align the camera poses\n        center_aligned = (center_pred-sim3.t1)/sim3.s1@sim3.R.t()*sim3.s0+sim3.t0\n        R_aligned = pose[...,:3]@sim3.R.t()\n        t_aligned = (-R_aligned@center_aligned[...,None])[...,0]\n        pose_aligned = camera.pose(R=R_aligned,t=t_aligned)\n        return pose_aligned,sim3\n\n    @torch.no_grad()\n    def evaluate_camera_alignment(self,opt,pose_aligned,pose_GT):\n        # measure errors in rotation and translation\n        R_aligned,t_aligned = pose_aligned.split([3,1],dim=-1)\n        R_GT,t_GT = pose_GT.split([3,1],dim=-1)\n        R_error = camera.rotation_distance(R_aligned,R_GT)\n        t_error = (t_aligned-t_GT)[...,0].norm(dim=-1)\n        error = edict(R=R_error,t=t_error)\n        return error\n\n    @torch.no_grad()\n    def evaluate_full(self,opt):\n        self.graph.eval()\n        # evaluate rotation/translation\n        pose,pose_GT = self.get_all_training_poses(opt)\n        
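# align estimated poses to ground truth with a sim(3) Procrustes fit before measuring pose errors\n        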
pose_aligned,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT)\n        error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT)\n        print(\"--------------------------\")\n        print(\"rot:   {:8.3f}\".format(np.rad2deg(error.R.mean().cpu())))\n        print(\"trans: {:10.5f}\".format(error.t.mean()))\n        print(\"--------------------------\")\n        # dump numbers\n        quant_fname = \"{}/quant_pose.txt\".format(opt.output_path)\n        with open(quant_fname,\"w\") as file:\n            for i,(err_R,err_t) in enumerate(zip(error.R,error.t)):\n                file.write(\"{} {} {}\\n\".format(i,err_R.item(),err_t.item()))\n        # evaluate novel view synthesis\n        super().evaluate_full(opt)\n\n    @torch.enable_grad()\n    def evaluate_test_time_photometric_optim(self,opt,var):\n        # use another se3 Parameter to absorb the remaining pose errors\n        var.se3_refine_test = torch.nn.Parameter(torch.zeros(1,6,device=opt.device))\n        optimizer = getattr(torch.optim,opt.optim.algo)\n        optim_pose = optimizer([dict(params=[var.se3_refine_test],lr=opt.optim.lr_pose)])\n        iterator = tqdm.trange(opt.optim.test_iter,desc=\"test-time optim.\",leave=False,position=1)\n        for it in iterator:\n            optim_pose.zero_grad()\n            var.pose_refine_test = camera.lie.se3_to_SE3(var.se3_refine_test)\n            var = self.graph.forward(opt,var,mode=\"test-optim\")\n            loss = self.graph.compute_loss(opt,var,mode=\"test-optim\")\n            loss = self.summarize_loss(opt,var,loss)\n            loss.all.backward()\n            optim_pose.step()\n            iterator.set_postfix(loss=\"{:.3f}\".format(loss.all))\n        return var\n\n    @torch.no_grad()\n    def generate_videos_pose(self,opt):\n        self.graph.eval()\n        fig = plt.figure(figsize=(10,10) if opt.data.dataset==\"blender\" else (16,8))\n        cam_path = \"{}/poses\".format(opt.output_path)\n        os.makedirs(cam_path,exist_ok=True)\n        ep_list = []\n        for ep in range(0,opt.max_iter+1,opt.freq.ckpt):\n            # load checkpoint (0 is random init)\n            if ep!=0:\n                try: util.restore_checkpoint(opt,self,resume=ep)\n                except: continue\n            # get the camera poses\n            pose,pose_ref = self.get_all_training_poses(opt)\n            if opt.data.dataset in [\"blender\",\"llff\"]:\n                pose_aligned,_ = self.prealign_cameras(opt,pose,pose_ref)\n                pose_aligned,pose_ref = pose_aligned.detach().cpu(),pose_ref.detach().cpu()\n                dict(\n                    blender=util_vis.plot_save_poses_blender,\n                    llff=util_vis.plot_save_poses,\n                )[opt.data.dataset](opt,fig,pose_aligned,pose_ref=pose_ref,path=cam_path,ep=ep)\n            else:\n                pose = pose.detach().cpu()\n                util_vis.plot_save_poses(opt,fig,pose,pose_ref=None,path=cam_path,ep=ep)\n            ep_list.append(ep)\n        plt.close()\n        # write videos\n        print(\"writing videos...\")\n        list_fname = \"{}/temp.list\".format(cam_path)\n        with open(list_fname,\"w\") as file:\n            for ep in ep_list: file.write(\"file {}.png\\n\".format(ep))\n        cam_vid_fname = \"{}/poses.mp4\".format(opt.output_path)\n        os.system(\"ffmpeg -y -r 4 -f concat -i {0} -pix_fmt yuv420p {1} >/dev/null 2>&1\".format(list_fname,cam_vid_fname))\n        os.remove(list_fname)\n\n# ============================ computation graph for 
forward/backprop ============================\n\nclass Graph(nerf.Graph):\n\n    def __init__(self,opt):\n        super().__init__(opt)\n        self.nerf = NeRF(opt)\n        if opt.nerf.fine_sampling:\n            self.nerf_fine = NeRF(opt)\n        self.pose_eye = torch.eye(3,4).to(opt.device)\n\n    def forward(self,opt,var,mode=None):\n        # rescale the size of the scene (near/far depth range) conditioned on the current pose estimates\n        if opt.data.dataset==\"blender\":\n            depth_min,depth_max = opt.nerf.depth.range\n            position = camera.Pose().invert(self.optimised_training_poses.weight.data.detach().clone().view(-1,3,4))[...,-1]\n            diameter = ((position[self.idx_grid[...,0]]-position[self.idx_grid[...,1]]).norm(dim=-1)).max()\n            depth_min_new = (depth_min/(depth_max+depth_min))*diameter\n            depth_max_new = (depth_max/(depth_max+depth_min))*diameter\n            opt.nerf.depth.range = [depth_min_new, depth_max_new]\n        # render images\n        batch_size = len(var.idx)\n        pose = self.get_pose(opt,var,mode=mode)\n        if opt.nerf.rand_rays and mode in [\"train\",\"test-optim\"]:\n            # sample random rays for optimization\n            var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size]\n            ret = self.render(opt,pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1]\n        else:\n            # render full image (process in slices)\n            ret = self.render_by_slices(opt,pose,intr=var.intr,mode=mode) if opt.nerf.rand_rays else \\\n                  self.render(opt,pose,intr=var.intr,mode=mode) # [B,HW,3],[B,HW,1]\n        var.update(ret)\n        return var\n\n    def get_pose(self,opt,var,mode=None):\n        if mode==\"train\":\n            # add the pre-generated pose perturbations\n            if opt.data.dataset==\"blender\":\n                if opt.camera.noise:\n                    var.pose_noise = self.pose_noise[var.idx]\n                    pose = camera.pose.compose([var.pose, var.pose_noise])\n                else: pose = var.pose\n            else: pose = self.pose_eye\n            # add learnable pose correction\n            var.se3_refine = self.se3_refine.weight[var.idx]\n            pose_refine = camera.lie.se3_to_SE3(var.se3_refine)\n            pose = camera.pose.compose([pose_refine, pose])\n            self.optimised_training_poses.weight.data = pose.detach().clone().view(-1,12)\n        elif mode in [\"val\",\"eval\",\"test-optim\"]:\n            # align test pose to refined coordinate system (up to sim3)\n            sim3 = self.sim3\n            center = torch.zeros(1,1,3,device=opt.device)\n            center = camera.cam2world(center,var.pose)[:,0] # [N,3]\n            center_aligned = (center-sim3.t0)/sim3.s0@sim3.R*sim3.s1+sim3.t1\n            R_aligned = var.pose[...,:3]@self.sim3.R\n            t_aligned = (-R_aligned@center_aligned[...,None])[...,0]\n            pose = camera.pose(R=R_aligned,t=t_aligned)\n            # additionally factor out the remaining pose imperfection with test-time optimization\n            if opt.optim.test_photo and mode!=\"val\":\n                pose = camera.pose.compose([var.pose_refine_test, pose])\n        else: pose = var.pose\n        return pose\n\nclass NeRF(nerf.NeRF):\n\n    def __init__(self,opt):\n        super().__init__(opt)\n        self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed\n\n    def positional_encoding(self,opt,input,L): # [B,...,N]\n        input_enc = 
super().positional_encoding(opt,input,L=L) # [B,...,2NL]\n        # coarse-to-fine: smoothly mask positional encoding for BARF\n        if opt.barf_c2f is not None:\n            # set weights for different frequency bands\n            start,end = opt.barf_c2f\n            alpha = (self.progress.data-start)/(end-start)*L\n            k = torch.arange(L,dtype=torch.float32,device=opt.device)\n            weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2\n            # apply weights\n            shape = input_enc.shape\n            input_enc = (input_enc.view(-1,L)*weight).view(*shape)\n        return input_enc\n"
  },
  {
    "path": "model/base.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torchvision.transforms.functional as torchvision_F\nimport torch.utils.tensorboard\nimport visdom\nimport importlib\nimport tqdm\nfrom easydict import EasyDict as edict\n\nimport util,util_vis\nfrom util import log,debug\n\n# ============================ main engine for training and evaluation ============================\n\nclass Model():\n\n    def __init__(self,opt):\n        super().__init__()\n        os.makedirs(opt.output_path,exist_ok=True)\n\n    def load_dataset(self,opt,eval_split=\"val\"):\n        data = importlib.import_module(\"data.{}\".format(opt.data.dataset))\n        log.info(\"loading training data...\")\n        self.train_data = data.Dataset(opt,split=\"train\",subset=opt.data.train_sub)\n        self.train_loader = self.train_data.setup_loader(opt,shuffle=True)\n        log.info(\"loading test data...\")\n        if opt.data.val_on_test: eval_split = \"test\"\n        self.test_data = data.Dataset(opt,split=eval_split,subset=opt.data.val_sub)\n        self.test_loader = self.test_data.setup_loader(opt,shuffle=False)\n\n    def build_networks(self,opt):\n        graph = importlib.import_module(\"model.{}\".format(opt.model))\n        log.info(\"building networks...\")\n        self.graph = graph.Graph(opt).to(opt.device)\n\n    def setup_optimizer(self,opt):\n        log.info(\"setting up optimizers...\")\n        optimizer = getattr(torch.optim,opt.optim.algo)\n        self.optim = optimizer([dict(params=self.graph.parameters(),lr=opt.optim.lr)])\n        # set up scheduler\n        if opt.optim.sched:\n            scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)\n            kwargs = { k:v for k,v in opt.optim.sched.items() if k!=\"type\" }\n            self.sched = scheduler(self.optim,**kwargs)\n\n    def restore_checkpoint(self,opt):\n        epoch_start,iter_start = None,None\n        if opt.resume:\n            log.info(\"resuming from previous checkpoint...\")\n            epoch_start,iter_start = util.restore_checkpoint(opt,self,resume=opt.resume)\n        elif opt.load is not None:\n            log.info(\"loading weights from checkpoint {}...\".format(opt.load))\n            epoch_start,iter_start = util.restore_checkpoint(opt,self,load_name=opt.load)\n        else:\n            log.info(\"initializing weights from scratch...\")\n        self.epoch_start = epoch_start or 0\n        self.iter_start = iter_start or 0\n\n    def setup_visualizer(self,opt):\n        log.info(\"setting up visualizers...\")\n        if opt.tb:\n            self.tb = torch.utils.tensorboard.SummaryWriter(log_dir=opt.output_path,flush_secs=10)\n        if opt.visdom:\n            # check if visdom server is runninng\n            is_open = util.check_socket_open(opt.visdom.server,opt.visdom.port)\n            retry = None\n            while not is_open:\n                retry = input(\"visdom port ({}) not open, retry? 
(y/n) \".format(opt.visdom.port))\n                if retry not in [\"y\",\"n\"]: continue\n                if retry==\"y\":\n                    is_open = util.check_socket_open(opt.visdom.server,opt.visdom.port)\n                else: break\n            self.vis = visdom.Visdom(server=opt.visdom.server,port=opt.visdom.port,env=opt.group)\n\n    def train(self,opt):\n        # before training\n        log.title(\"TRAINING START\")\n        self.timer = edict(start=time.time(),it_mean=None)\n        self.it = self.iter_start\n        # training\n        if self.iter_start==0: self.validate(opt,ep=0)\n        for self.ep in range(self.epoch_start,opt.max_epoch):\n            self.train_epoch(opt)\n        # after training\n        if opt.tb:\n            self.tb.flush()\n            self.tb.close()\n        if opt.visdom: self.vis.close()\n        log.title(\"TRAINING DONE\")\n\n    def train_epoch(self,opt):\n        # before train epoch\n        self.graph.train()\n        # train epoch\n        loader = tqdm.tqdm(self.train_loader,desc=\"training epoch {}\".format(self.ep+1),leave=False)\n        for batch in loader:\n            # train iteration\n            var = edict(batch)\n            var = util.move_to_device(var,opt.device)\n            loss = self.train_iteration(opt,var,loader)\n        # after train epoch\n        lr = self.sched.get_last_lr()[0] if opt.optim.sched else opt.optim.lr\n        log.loss_train(opt,self.ep+1,lr,loss.all,self.timer)\n        if opt.optim.sched: self.sched.step()\n        if (self.ep+1)%opt.freq.val==0: self.validate(opt,ep=self.ep+1)\n        if (self.ep+1)%opt.freq.ckpt==0: self.save_checkpoint(opt,ep=self.ep+1,it=self.it)\n\n    def train_iteration(self,opt,var,loader):\n        # before train iteration\n        self.timer.it_start = time.time()\n        # train iteration\n        self.optim.zero_grad()\n        var = self.graph.forward(opt,var,mode=\"train\")\n        loss = self.graph.compute_loss(opt,var,mode=\"train\")\n        loss = self.summarize_loss(opt,var,loss)\n        loss.all.backward()\n        self.optim.step()\n        # after train iteration\n        if (self.it+1)%opt.freq.scalar==0: self.log_scalars(opt,var,loss,step=self.it+1,split=\"train\")\n        if (self.it+1)%opt.freq.vis==0: self.visualize(opt,var,step=self.it+1,split=\"train\")\n        self.it += 1\n        loader.set_postfix(it=self.it,loss=\"{:.3f}\".format(loss.all))\n        self.timer.it_end = time.time()\n        util.update_timer(opt,self.timer,self.ep,len(loader))\n        return loss\n\n    def summarize_loss(self,opt,var,loss):\n        loss_all = 0.\n        assert(\"all\" not in loss)\n        # weigh losses\n        for key in loss:\n            assert(key in opt.loss_weight)\n            assert(loss[key].shape==())\n            if opt.loss_weight[key] is not None:\n                assert not torch.isinf(loss[key]),\"loss {} is Inf\".format(key)\n                assert not torch.isnan(loss[key]),\"loss {} is NaN\".format(key)\n                loss_all += 10**float(opt.loss_weight[key])*loss[key]\n        loss.update(all=loss_all)\n        return loss\n\n    @torch.no_grad()\n    def validate(self,opt,ep=None):\n        self.graph.eval()\n        loss_val = edict()\n        loader = tqdm.tqdm(self.test_loader,desc=\"validating\",leave=False)\n        for it,batch in enumerate(loader):\n            var = edict(batch)\n            var = util.move_to_device(var,opt.device)\n            var = self.graph.forward(opt,var,mode=\"val\")\n            loss = 
self.graph.compute_loss(opt,var,mode=\"val\")\n            loss = self.summarize_loss(opt,var,loss)\n            for key in loss:\n                loss_val.setdefault(key,0.)\n                loss_val[key] += loss[key]*len(var.idx)\n            loader.set_postfix(loss=\"{:.3f}\".format(loss.all))\n            if it==0: self.visualize(opt,var,step=ep,split=\"val\")\n        for key in loss_val: loss_val[key] /= len(self.test_data)\n        self.log_scalars(opt,var,loss_val,step=ep,split=\"val\")\n        log.loss_val(opt,loss_val.all)\n\n    @torch.no_grad()\n    def log_scalars(self,opt,var,loss,metric=None,step=0,split=\"train\"):\n        for key,value in loss.items():\n            if key==\"all\": continue\n            if opt.loss_weight[key] is not None:\n                self.tb.add_scalar(\"{0}/loss_{1}\".format(split,key),value,step)\n        if metric is not None:\n            for key,value in metric.items():\n                self.tb.add_scalar(\"{0}/{1}\".format(split,key),value,step)\n\n    @torch.no_grad()\n    def visualize(self,opt,var,step=0,split=\"train\"):\n        raise NotImplementedError\n\n    def save_checkpoint(self,opt,ep=0,it=0,latest=False):\n        util.save_checkpoint(opt,self,ep=ep,it=it,latest=latest)\n        if not latest:\n            log.info(\"checkpoint saved: ({0}) {1}, epoch {2} (iteration {3})\".format(opt.group,opt.name,ep,it))\n\n# ============================ computation graph for forward/backprop ============================\n\nclass Graph(torch.nn.Module):\n\n    def __init__(self,opt):\n        super().__init__()\n\n    def forward(self,opt,var,mode=None):\n        raise NotImplementedError\n        return var\n\n    def compute_loss(self,opt,var,mode=None):\n        loss = edict()\n        raise NotImplementedError\n        return loss\n\n    def L1_loss(self,pred,label=0):\n        loss = (pred.contiguous()-label).abs()\n        return loss.mean()\n    def MSE_loss(self,pred,label=0):\n        loss = (pred.contiguous()-label)**2\n        return loss.mean()\n"
  },
  {
    "path": "model/l2g_nerf.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torchvision.transforms.functional as torchvision_F\nimport tqdm\nfrom easydict import EasyDict as edict\nimport visdom\nimport matplotlib.pyplot as plt\n\nimport util,util_vis\nfrom util import log,debug\nfrom . import nerf\nimport camera\n\nimport roma\n# ============================ main engine for training and evaluation ============================\n\nclass Model(nerf.Model):\n\n    def __init__(self,opt):\n        super().__init__(opt)\n\n    def build_networks(self,opt):\n        super().build_networks(opt)\n        if opt.camera.noise:\n            # pre-generate synthetic pose perturbation\n            so3_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_r\n            t_noise = torch.randn(len(self.train_data),3,device=opt.device)*opt.camera.noise_t\n            self.graph.pose_noise = torch.cat([camera.lie.so3_to_SO3(so3_noise),t_noise[...,None]],dim=-1) # [...,3,4]\n\n        self.graph.warp_embedding = torch.nn.Embedding(len(self.train_data),opt.arch.embedding_dim).to(opt.device)\n        self.graph.warp_mlp = localWarp(opt).to(opt.device)\n\n        pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)\n        # add synthetic pose perturbation to all training data\n        if opt.data.dataset==\"blender\":\n            pose = pose_GT\n            if opt.camera.noise:\n                pose = camera.pose.compose([pose, self.graph.pose_noise])\n        else: pose = self.graph.pose_eye[None].repeat(len(self.train_data),1,1)\n        # use Embedding so it could be checkpointed\n        self.graph.optimised_training_poses = torch.nn.Embedding(len(self.train_data),12,_weight=pose.view(-1,12)).to(opt.device)\n\n        # auto near/far for blender dataset\n        if opt.data.dataset==\"blender\":\n            idx_range = torch.arange(len(self.train_data),dtype=torch.long,device=opt.device)\n            idx_X,idx_Y = torch.meshgrid(idx_range,idx_range)\n            self.graph.idx_grid = torch.stack([idx_X,idx_Y],dim=-1).view(-1,2)\n\n        if opt.error_map_size:\n            self.graph.error_map = torch.ones([len(self.train_data), opt.error_map_size*opt.error_map_size], dtype=torch.float).to(opt.device)\n\n    def setup_optimizer(self,opt):\n        super().setup_optimizer(opt)\n        optimizer = getattr(torch.optim,opt.optim.algo)\n        self.optim_pose = optimizer([dict(params=self.graph.warp_embedding.parameters(),lr=opt.optim.lr_pose)])\n        self.optim_pose.add_param_group(dict(params=self.graph.warp_mlp.parameters(),lr=opt.optim.lr_pose))\n        # set up scheduler\n        if opt.optim.sched_pose:\n            scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched_pose.type)\n            if opt.optim.lr_pose_end:\n                assert(opt.optim.sched_pose.type==\"ExponentialLR\")\n                opt.optim.sched_pose.gamma = (opt.optim.lr_pose_end/opt.optim.lr_pose)**(1./opt.max_iter)\n            kwargs = { k:v for k,v in opt.optim.sched_pose.items() if k!=\"type\" }\n            self.sched_pose = scheduler(self.optim_pose,**kwargs)\n\n    def train_iteration(self,opt,var,loader):\n        self.optim_pose.zero_grad()\n        if opt.optim.warmup_pose:\n            # simple linear warmup of pose learning rate\n            self.optim_pose.param_groups[0][\"lr_orig\"] = self.optim_pose.param_groups[0][\"lr\"] # cache the original learning rate\n            
self.optim_pose.param_groups[0][\"lr\"] *= min(1,self.it/opt.optim.warmup_pose)\n        loss = super().train_iteration(opt,var,loader)\n        self.optim_pose.step()\n        if opt.optim.warmup_pose:\n            self.optim_pose.param_groups[0][\"lr\"] = self.optim_pose.param_groups[0][\"lr_orig\"] # reset learning rate\n        if opt.optim.sched_pose: self.sched_pose.step()\n        self.graph.nerf.progress.data.fill_(self.it/opt.max_iter)\n        if opt.nerf.fine_sampling:\n            self.graph.nerf_fine.progress.data.fill_(self.it/opt.max_iter)\n        return loss\n\n    @torch.no_grad()\n    def validate(self,opt,ep=None):\n        pose,pose_GT = self.get_all_training_poses(opt)\n        _,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT)\n        super().validate(opt,ep=ep)\n\n    @torch.no_grad()\n    def log_scalars(self,opt,var,loss,metric=None,step=0,split=\"train\"):\n        super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)\n        if split==\"train\":\n            # log learning rate\n            lr = self.optim_pose.param_groups[0][\"lr\"]\n            self.tb.add_scalar(\"{0}/{1}\".format(split,\"lr_pose\"),lr,step)\n        # compute pose error\n        if split==\"train\" and opt.data.dataset in [\"blender\",\"llff\"]:\n            pose,pose_GT = self.get_all_training_poses(opt)\n            pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT)\n            error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT)\n            self.tb.add_scalar(\"{0}/error_R\".format(split),error.R.mean(),step)\n            self.tb.add_scalar(\"{0}/error_t\".format(split),error.t.mean(),step)\n\n    @torch.no_grad()\n    def visualize(self,opt,var,step=0,split=\"train\"):\n        super().visualize(opt,var,step=step,split=split)\n        if opt.visdom:\n            if split==\"val\":\n                pose,pose_GT = self.get_all_training_poses(opt)\n                pose_aligned,_ = self.prealign_cameras(opt,pose,pose_GT)\n                util_vis.vis_cameras(opt,self.vis,step=step,poses=[pose_aligned,pose_GT])\n\n    @torch.no_grad()\n    def get_all_training_poses(self,opt):\n        # get ground-truth (canonical) camera poses\n        pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)\n        pose = self.graph.optimised_training_poses.weight.data.detach().clone().view(-1,3,4)\n        return pose,pose_GT\n\n    @torch.no_grad()\n    def prealign_cameras(self,opt,pose,pose_GT):\n        # compute 3D similarity transform via Procrustes analysis\n        center = torch.zeros(1,1,3,device=opt.device)\n        center_pred = camera.cam2world(center,pose)[:,0] # [N,3]\n        center_GT = camera.cam2world(center,pose_GT)[:,0] # [N,3]\n        try:\n            sim3 = camera.procrustes_analysis(center_GT,center_pred)\n        except:\n            print(\"warning: SVD did not converge...\")\n            sim3 = edict(t0=0,t1=0,s0=1,s1=1,R=torch.eye(3,device=opt.device))\n        # align the camera poses\n        center_aligned = (center_pred-sim3.t1)/sim3.s1@sim3.R.t()*sim3.s0+sim3.t0\n        R_aligned = pose[...,:3]@sim3.R.t()\n        t_aligned = (-R_aligned@center_aligned[...,None])[...,0]\n        pose_aligned = camera.pose(R=R_aligned,t=t_aligned)\n        return pose_aligned,sim3\n\n    @torch.no_grad()\n    def evaluate_camera_alignment(self,opt,pose_aligned,pose_GT):\n        # measure errors in rotation and translation\n        R_aligned,t_aligned = pose_aligned.split([3,1],dim=-1)\n        R_GT,t_GT = 
pose_GT.split([3,1],dim=-1)\n        R_error = camera.rotation_distance(R_aligned,R_GT)\n        t_error = (t_aligned-t_GT)[...,0].norm(dim=-1)\n        error = edict(R=R_error,t=t_error)\n        return error\n\n    @torch.no_grad()\n    def evaluate_full(self,opt):\n        self.graph.eval()\n        # evaluate rotation/translation\n        pose,pose_GT = self.get_all_training_poses(opt)\n        pose_aligned,self.graph.sim3 = self.prealign_cameras(opt,pose,pose_GT)\n        error = self.evaluate_camera_alignment(opt,pose_aligned,pose_GT)\n        print(\"--------------------------\")\n        print(\"rot:   {:8.3f}\".format(np.rad2deg(error.R.mean().cpu())))\n        print(\"trans: {:10.5f}\".format(error.t.mean()))\n        print(\"--------------------------\")\n        # dump numbers\n        quant_fname = \"{}/quant_pose.txt\".format(opt.output_path)\n        with open(quant_fname,\"w\") as file:\n            for i,(err_R,err_t) in enumerate(zip(error.R,error.t)):\n                file.write(\"{} {} {}\\n\".format(i,err_R.item(),err_t.item()))\n        # evaluate novel view synthesis\n        super().evaluate_full(opt)\n\n    @torch.enable_grad()\n    def evaluate_test_time_photometric_optim(self,opt,var):\n        # use another se3 Parameter to absorb the remaining pose errors\n        var.se3_refine_test = torch.nn.Parameter(torch.zeros(1,6,device=opt.device))\n        optimizer = getattr(torch.optim,opt.optim.algo)\n        optim_pose = optimizer([dict(params=[var.se3_refine_test],lr=opt.optim.lr_pose)])\n        iterator = tqdm.trange(opt.optim.test_iter,desc=\"test-time optim.\",leave=False,position=1)\n        for it in iterator:\n            optim_pose.zero_grad()\n            var.pose_refine_test = camera.lie.se3_to_SE3(var.se3_refine_test)\n            var = self.graph.forward(opt,var,mode=\"test-optim\")\n            loss = self.graph.compute_loss(opt,var,mode=\"test-optim\")\n            loss = self.summarize_loss(opt,var,loss)\n            loss.all.backward()\n            optim_pose.step()\n            iterator.set_postfix(loss=\"{:.3f}\".format(loss.all))\n        return var\n\n    @torch.no_grad()\n    def generate_videos_pose(self,opt):\n        self.graph.eval()\n        fig = plt.figure(figsize=(10,10) if opt.data.dataset==\"blender\" else (16,8))\n        cam_path = \"{}/poses\".format(opt.output_path)\n        os.makedirs(cam_path,exist_ok=True)\n        ep_list = []\n        for ep in range(0,opt.max_iter+1,opt.freq.ckpt):\n            # load checkpoint (0 is random init)\n            if ep!=0:\n                try: util.restore_checkpoint(opt,self,resume=ep)\n                except: continue\n            # get the camera poses\n            pose,pose_ref = self.get_all_training_poses(opt)\n            if opt.data.dataset in [\"blender\",\"llff\"]:\n                pose_aligned,_ = self.prealign_cameras(opt,pose,pose_ref)\n                pose_aligned,pose_ref = pose_aligned.detach().cpu(),pose_ref.detach().cpu()\n                dict(\n                    blender=util_vis.plot_save_poses_blender,\n                    llff=util_vis.plot_save_poses,\n                )[opt.data.dataset](opt,fig,pose_aligned,pose_ref=pose_ref,path=cam_path,ep=ep)\n            else:\n                pose = pose.detach().cpu()\n                util_vis.plot_save_poses(opt,fig,pose,pose_ref=None,path=cam_path,ep=ep)\n            ep_list.append(ep)\n        plt.close()\n        # write videos\n        print(\"writing videos...\")\n        list_fname = \"{}/temp.list\".format(cam_path)\n   
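     # write an ffmpeg concat list of the per-checkpoint pose plots\n   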
     with open(list_fname,\"w\") as file:\n            for ep in ep_list: file.write(\"file {}.png\\n\".format(ep))\n        cam_vid_fname = \"{}/poses.mp4\".format(opt.output_path)\n        os.system(\"ffmpeg -y -r 4 -f concat -i {0} -pix_fmt yuv420p {1} >/dev/null 2>&1\".format(list_fname,cam_vid_fname))\n        os.remove(list_fname)\n\n# ============================ computation graph for forward/backprop ============================\n\nclass Graph(nerf.Graph):\n\n    def __init__(self,opt):\n        super().__init__(opt)\n        self.nerf = NeRF(opt)\n        if opt.nerf.fine_sampling:\n            self.nerf_fine = NeRF(opt)\n        self.pose_eye = torch.eye(3,4).to(opt.device)\n\n    def get_pose(self,opt,var,mode=None):\n        if mode==\"train\":\n            # add the pre-generated pose perturbations\n            if opt.data.dataset==\"blender\":\n                if opt.camera.noise:\n                    var.pose_noise = self.pose_noise[var.idx]\n                    pose = camera.pose.compose([var.pose, var.pose_noise])\n                else: pose = var.pose\n            else: pose = self.pose_eye[None]\n            # add learnable pose correction\n            batch_size = len(var.idx)\n\n            if opt.error_map_size:\n                num_points = var.ray_idx.shape[1]\n                camera_cords_grid_3D = camera.gather_camera_cords_grid_3D(opt,batch_size,intr=var.intr,ray_idx=var.ray_idx).detach()\n            else:\n                num_points = len(var.ray_idx)\n                camera_cords_grid_3D = camera.get_camera_cords_grid_3D(opt,batch_size,intr=var.intr,ray_idx=var.ray_idx).detach()\n\n            camera_cords_grid_2D = camera_cords_grid_3D[...,:2]\n            embedding = self.warp_embedding.weight[var.idx,None,:].expand(-1,num_points,-1)\n            local_se3_refine = self.warp_mlp(opt,torch.cat((camera_cords_grid_2D,embedding),dim=-1))\n            local_pose_refine = camera.lie.se3_to_SE3(local_se3_refine)\n            local_pose = camera.pose.compose([local_pose_refine, pose[:,None,...]])\n            return local_pose\n\n        elif mode in [\"val\",\"eval\",\"test-optim\"]:\n            # align test pose to refined coordinate system (up to sim3)\n            sim3 = self.sim3\n            center = torch.zeros(1,1,3,device=opt.device)\n            center = camera.cam2world(center,var.pose)[:,0] # [N,3]\n            center_aligned = (center-sim3.t0)/sim3.s0@sim3.R*sim3.s1+sim3.t1\n            R_aligned = var.pose[...,:3]@self.sim3.R\n            t_aligned = (-R_aligned@center_aligned[...,None])[...,0]\n            pose = camera.pose(R=R_aligned,t=t_aligned)\n            if opt.optim.test_photo and mode!=\"val\":\n                pose = camera.pose.compose([var.pose_refine_test, pose])\n\n        else: pose = var.pose\n        return pose\n\n    def forward(self,opt,var,mode=None):\n        # rescale the size of the scene (near/far depth range) conditioned on the current pose estimates\n        if opt.data.dataset==\"blender\":\n            depth_min,depth_max = opt.nerf.depth.range\n            position = camera.Pose().invert(self.optimised_training_poses.weight.data.detach().clone().view(-1,3,4))[...,-1]\n            diameter = ((position[self.idx_grid[...,0]]-position[self.idx_grid[...,1]]).norm(dim=-1)).max()\n            depth_min_new = (depth_min/(depth_max+depth_min))*diameter\n            depth_max_new = (depth_max/(depth_max+depth_min))*diameter\n            opt.nerf.depth.range = [depth_min_new, depth_max_new]\n\n        # render images\n        batch_size = len(var.idx)\n        if 
opt.nerf.rand_rays and mode in [\"train\"]:\n            # sample rays for optimization\n            if opt.error_map_size:\n                sample_weight = self.error_map + 2*self.error_map.mean(-1,keepdim=True) # 1/3 importance + 2/3 random\n                var.ray_idx_coarse = torch.multinomial(sample_weight, opt.nerf.rand_rays//batch_size, replacement=False) # [B, N], but in [0, opt.error_map_size*opt.error_map_size)\n                inds_x, inds_y = var.ray_idx_coarse // opt.error_map_size, var.ray_idx_coarse % opt.error_map_size # `//` on tensors throws a deprecation warning in torch 1.10; harmless here\n                sx, sy = opt.H / opt.error_map_size, opt.W / opt.error_map_size\n                inds_x = (inds_x * sx + torch.rand(batch_size, opt.nerf.rand_rays//batch_size, device=opt.device) * sx).long().clamp(max=opt.H - 1)\n                inds_y = (inds_y * sy + torch.rand(batch_size, opt.nerf.rand_rays//batch_size, device=opt.device) * sy).long().clamp(max=opt.W - 1)\n                var.ray_idx = inds_x * opt.W + inds_y\n            else:\n                var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size] # 3/3 random (no importance sampling)\n\n            local_pose = self.get_pose(opt,var,mode=mode)\n            ret = self.local_render(opt,local_pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1]\n        elif opt.nerf.rand_rays and mode in [\"test-optim\"]:\n            # sample random rays for optimization\n            var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size]\n            pose = self.get_pose(opt,var,mode=mode)\n            ret = self.render(opt,pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1]\n        else:\n            # render full image (process in slices)\n            pose = self.get_pose(opt,var,mode=mode)\n            ret = self.render_by_slices(opt,pose,intr=var.intr,mode=mode) if opt.nerf.rand_rays else \\\n                  self.render(opt,pose,intr=var.intr,mode=mode) # [B,HW,3],[B,HW,1]\n        var.update(ret)\n        return var\n\n    def compute_loss(self,opt,var,mode=None):\n        loss = edict()\n        batch_size = len(var.idx)\n        image = var.image.view(batch_size,3,opt.H*opt.W).permute(0,2,1)\n        if opt.nerf.rand_rays and mode in [\"train\",\"test-optim\"]:\n            if opt.error_map_size:\n                image = torch.gather(image, 1, var.ray_idx[...,None].expand(-1,-1,3))\n            else:\n                image = image[:,var.ray_idx]\n\n        # compute image losses\n        if opt.loss_weight.render is not None:\n            render_error = ((var.rgb-image)**2).mean(-1)\n            loss.render = render_error.mean()\n            if mode in [\"train\"] and opt.error_map_size:\n                ema_error = 0.1 * torch.gather(self.error_map, 1, var.ray_idx_coarse) + 0.9 * render_error.detach()\n                self.error_map.scatter_(1, var.ray_idx_coarse, ema_error)\n\n        if opt.loss_weight.render_fine is not None:\n            assert(opt.nerf.fine_sampling)\n            loss.render_fine = self.MSE_loss(var.rgb_fine,image)\n\n        # global alignment\n        if mode in [\"train\"]:\n            # fit one rigid pose per image to the locally-warped points (rigid point registration)\n            source = torch.cat((var.camera_grid_3D,var.camera_center),dim=1)\n            target = torch.cat((var.grid_3D,var.center),dim=1)\n            R_global, t_global = roma.rigid_points_registration(target, source)\n            svd_poses = torch.cat((R_global,t_global[...,None]),-1)\n            self.optimised_training_poses.weight.data = 
svd_poses.detach().clone().view(-1,12)\n            if opt.loss_weight.global_alignment is not None:\n                loss.global_alignment = self.MSE_loss(target,camera.cam2world(source,svd_poses))\n\n        return loss\n\n    def local_render(self,opt,local_pose,intr=None,ray_idx=None,mode=None):\n        batch_size = len(local_pose)\n        if opt.error_map_size:\n            camera_grid_3D = camera.gather_camera_cords_grid_3D(opt,batch_size,intr=intr,ray_idx=ray_idx).detach()\n        else:\n            camera_grid_3D = camera.get_camera_cords_grid_3D(opt,batch_size,intr=intr,ray_idx=ray_idx).detach()\n\n        camera_center = torch.zeros_like(camera_grid_3D) # [B,HW,3]\n        grid_3D = camera.cam2world(camera_grid_3D[...,None,:],local_pose)[...,0,:] # [B,HW,3]\n        center = camera.cam2world(camera_center[...,None,:],local_pose)[...,0,:] # [B,HW,3]\n        ray = grid_3D-center # [B,HW,3]\n        ret = edict(camera_grid_3D=camera_grid_3D, camera_center=camera_center, grid_3D=grid_3D, center=center) # [B,HW,3] each, used for global alignment\n        if opt.camera.ndc:\n            # convert center/ray representations to NDC\n            center,ray = camera.convert_NDC(opt,center,ray,intr=intr)\n        # render with main MLP\n        depth_samples = self.sample_depth(opt,batch_size,num_rays=ray.shape[1]) # [B,HW,N,1]\n        rgb_samples,density_samples = self.nerf.forward_samples(opt,center,ray,depth_samples,mode=mode)\n        rgb,depth,opacity,prob = self.nerf.composite(opt,ray,rgb_samples,density_samples,depth_samples)\n        ret.update(rgb=rgb,depth=depth,opacity=opacity) # [B,HW,K]\n        # render with fine MLP from coarse MLP\n        if opt.nerf.fine_sampling:\n            with torch.no_grad():\n                # resample depth according to coarse empirical distribution\n                depth_samples_fine = self.sample_depth_from_pdf(opt,pdf=prob[...,0]) # [B,HW,Nf,1]\n                depth_samples = torch.cat([depth_samples,depth_samples_fine],dim=2) # [B,HW,N+Nf,1]\n                depth_samples = depth_samples.sort(dim=2).values\n            rgb_samples,density_samples = self.nerf_fine.forward_samples(opt,center,ray,depth_samples,mode=mode)\n            rgb_fine,depth_fine,opacity_fine,_ = self.nerf_fine.composite(opt,ray,rgb_samples,density_samples,depth_samples)\n            ret.update(rgb_fine=rgb_fine,depth_fine=depth_fine,opacity_fine=opacity_fine) # [B,HW,K]\n        return ret\n\nclass NeRF(nerf.NeRF):\n\n    def __init__(self,opt):\n        super().__init__(opt)\n        self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed\n\n    def positional_encoding(self,opt,input,L): # [B,...,N]\n        input_enc = super().positional_encoding(opt,input,L=L) # [B,...,2NL]\n        # coarse-to-fine: smoothly mask positional encoding for BARF\n        if opt.barf_c2f is not None:\n            # set weights for different frequency bands\n            start,end = opt.barf_c2f\n            alpha = (self.progress.data-start)/(end-start)*L\n            k = torch.arange(L,dtype=torch.float32,device=opt.device)\n            weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2\n            # apply weights\n            shape = input_enc.shape\n            input_enc = (input_enc.view(-1,L)*weight).view(*shape)\n        return input_enc\n\nclass localWarp(torch.nn.Module):\n    def __init__(self, opt):\n        super().__init__()\n        # point-wise se3 prediction\n        input_2D_dim = 2\n        self.mlp_warp = 
torch.nn.ModuleList()\n        L = util.get_layer_dims(opt.arch.layers_warp)\n        for li,(k_in,k_out) in enumerate(L):\n            if li==0: k_in = input_2D_dim+opt.arch.embedding_dim\n            if li in opt.arch.skip_warp: k_in += input_2D_dim+opt.arch.embedding_dim\n            linear = torch.nn.Linear(k_in,k_out)\n            self.mlp_warp.append(linear)\n\n    def forward(self,opt,uvf):\n        feat = uvf\n        for li,layer in enumerate(self.mlp_warp):\n            if li in opt.arch.skip_warp: feat = torch.cat([feat,uvf],dim=-1)\n            feat = layer(feat)\n            if li!=len(self.mlp_warp)-1:\n                feat = torch_F.relu(feat)\n        warp = feat\n        return warp"
  },
  {
    "path": "model/l2g_planar.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torchvision.transforms.functional as torchvision_F\nimport tqdm\nfrom easydict import EasyDict as edict\nimport PIL\nimport PIL.Image,PIL.ImageDraw\nimport imageio\n\nimport util,util_vis\nfrom util import log,debug\nfrom . import base\nimport warp\nimport roma\nfrom kornia.geometry.homography import find_homography_dlt\n# ============================ main engine for training and evaluation ============================\n\nclass Model(base.Model):\n\n    def __init__(self,opt):\n        super().__init__(opt)\n        opt.H_crop,opt.W_crop = opt.data.patch_crop\n\n    def load_dataset(self,opt,eval_split=None):\n        image_raw = PIL.Image.open(opt.data.image_fname).convert('RGB')\n        self.image_raw = torchvision_F.to_tensor(image_raw).to(opt.device)\n\n    def build_networks(self,opt):\n        super().build_networks(opt)\n        self.graph.warp_embedding = torch.nn.Embedding(opt.batch_size,opt.arch.embedding_dim).to(opt.device)\n        self.graph.warp_mlp = localWarp(opt).to(opt.device)\n\n    def setup_optimizer(self,opt):\n        log.info(\"setting up optimizers...\")\n        optim_list = [\n            dict(params=self.graph.neural_image.parameters(),lr=opt.optim.lr),\n        ]\n        optimizer = getattr(torch.optim,opt.optim.algo)\n        self.optim = optimizer(optim_list)\n        # set up scheduler\n        if opt.optim.sched:\n            scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)\n            if opt.optim.lr_end:\n                assert(opt.optim.sched.type==\"ExponentialLR\")\n                opt.optim.sched.gamma = (opt.optim.lr_end/opt.optim.lr)**(1./opt.max_iter)\n            kwargs = { k:v for k,v in opt.optim.sched.items() if k!=\"type\" }\n            self.sched = scheduler(self.optim,**kwargs)\n\n        # warp \n        self.optim_warp = optimizer([dict(params=self.graph.warp_embedding.parameters(),lr=opt.optim.lr_warp)])\n        self.optim_warp.add_param_group(dict(params=self.graph.warp_mlp.parameters(),lr=opt.optim.lr_warp))\n        # set up scheduler\n        if opt.optim.sched_warp:\n            scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched_warp.type)\n            if opt.optim.lr_warp_end:\n                assert(opt.optim.sched_warp.type==\"ExponentialLR\")\n                opt.optim.sched_warp.gamma = (opt.optim.lr_warp_end/opt.optim.lr_warp)**(1./opt.max_iter)\n            kwargs = { k:v for k,v in opt.optim.sched_warp.items() if k!=\"type\" }\n            self.sched_warp = scheduler(self.optim_warp,**kwargs)\n\n\n    def setup_visualizer(self,opt):\n        super().setup_visualizer(opt)\n        # set colors for visualization\n        box_colors = [\"#ff0000\",\"#40afff\",\"#9314ff\",\"#ffd700\",\"#00ff00\"]\n        box_colors = list(map(util.colorcode_to_number,box_colors))\n        self.box_colors = np.array(box_colors).astype(int)\n        assert(len(self.box_colors)==opt.batch_size)\n        # create visualization directory\n        self.vis_path = \"{}/vis\".format(opt.output_path)\n        os.makedirs(self.vis_path,exist_ok=True)\n        self.video_fname = \"{}/vis.mp4\".format(opt.output_path)\n\n    def train(self,opt):\n        # before training\n        log.title(\"TRAINING START\")\n        self.timer = edict(start=time.time(),it_mean=None)\n        self.ep = self.it = self.vis_it = 0\n        self.graph.train()\n        var = 
edict(idx=torch.arange(opt.batch_size))\n        # pre-generate perturbations\n        self.homo_pert, self.rot_pert, self.trans_pert, var.image_pert = self.generate_warp_perturbation(opt)\n        # train\n        var = util.move_to_device(var,opt.device)\n        loader = tqdm.trange(opt.max_iter,desc=\"training\",leave=False)\n        # visualize initial state\n        var = self.graph.forward(opt,var)\n        self.visualize(opt,var,step=0)\n        for it in loader:\n            # train iteration\n            loss = self.train_iteration(opt,var,loader)\n        # after training\n        os.system(\"ffmpeg -y -framerate 30 -i {}/%d.png -pix_fmt yuv420p {}\".format(self.vis_path,self.video_fname))\n        self.save_checkpoint(opt,ep=None,it=self.it)\n        if opt.tb:\n            self.tb.flush()\n            self.tb.close()\n        if opt.visdom: self.vis.close()\n        log.title(\"TRAINING DONE\")\n\n    def train_iteration(self,opt,var,loader):\n        self.optim_warp.zero_grad()\n        loss = super().train_iteration(opt,var,loader)\n        self.graph.neural_image.progress.data.fill_(self.it/opt.max_iter)\n        self.optim_warp.step()\n        if opt.optim.sched_warp: self.sched_warp.step()\n        if opt.optim.sched: self.sched.step()\n        return loss\n\n    def generate_warp_perturbation(self,opt):\n        # pre-generate perturbations (homography/rotation/translation noise, depending on opt.warp.dof)\n        def create_random_perturbation(batch_size):\n            if opt.warp.dof==1:\n                homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h\n                rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r\n                trans_pert = torch.zeros(batch_size,2,device=opt.device)*opt.warp.noise_t\n            elif opt.warp.dof==2:\n                homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h\n                rot_pert = torch.zeros(batch_size,1,device=opt.device)*opt.warp.noise_r\n                trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t\n            elif opt.warp.dof==3:\n                homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h\n                rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r\n                trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t\n            elif opt.warp.dof==8:\n                homo_pert = torch.randn(batch_size,8,device=opt.device)*opt.warp.noise_h\n                homo_pert[:,:2]=0\n                rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r\n                trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t\n            else: assert(False)\n            return homo_pert, rot_pert, trans_pert\n\n        homo_pert = torch.zeros(opt.batch_size,8,device=opt.device)\n        rot_pert = torch.zeros(opt.batch_size,1,device=opt.device)\n        trans_pert = torch.zeros(opt.batch_size,2,device=opt.device)\n        for i in range(opt.batch_size):\n            homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1)\n            # resample until the warped patch corners stay in range\n            while not warp.check_corners_in_range_compose(opt, homo_pert_i, rot_pert_i, trans_pert_i):\n                homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1)\n            homo_pert[i], rot_pert[i], trans_pert[i] = homo_pert_i, rot_pert_i, trans_pert_i\n\n        if opt.warp.fix_first:\n            homo_pert[0],rot_pert[0],trans_pert[0] = 0,0,0\n\n        # create warped 
image patches\n        xy_grid = warp.get_normalized_pixel_grid_crop(opt) # [B,HW,2]\n\n        xy_grid_hom = warp.camera.to_hom(xy_grid)\n        warp_matrix = warp.lie.sl3_to_SL3(homo_pert)\n        warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1)\n        xy_grid_warped = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2]\n        warp_matrix = warp.lie.so2_to_SO2(rot_pert)\n        xy_grid_warped = xy_grid_warped@warp_matrix.transpose(-2,-1) # [B,HW,2]\n        xy_grid_warped = xy_grid_warped+trans_pert[...,None,:]\n\n        xy_grid_warped = xy_grid_warped.view([opt.batch_size,opt.H_crop,opt.W_crop,2])\n        xy_grid_warped = torch.stack([xy_grid_warped[...,0]*max(opt.H,opt.W)/opt.W,\n                                      xy_grid_warped[...,1]*max(opt.H,opt.W)/opt.H],dim=-1)\n        image_raw_batch = self.image_raw.repeat(opt.batch_size,1,1,1)\n        image_pert_all = torch_F.grid_sample(image_raw_batch,xy_grid_warped,align_corners=False)\n\n        return homo_pert, rot_pert, trans_pert, image_pert_all\n\n    def visualize_patches_compose(self,opt, homo_pert, rot_pert, trans_pert):\n        image_pil = torchvision_F.to_pil_image(self.image_raw).convert(\"RGBA\")\n        draw_pil = PIL.Image.new(\"RGBA\",image_pil.size,(0,0,0,0))\n        draw = PIL.ImageDraw.Draw(draw_pil)\n        corners_all = warp.warp_corners_compose(opt, homo_pert, rot_pert, trans_pert)\n        corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5\n        corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5\n        for i,corners in enumerate(corners_all):\n            P = [tuple(float(n) for n in corners[j]) for j in range(4)]\n            draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3)\n        image_pil.alpha_composite(draw_pil)\n        image_tensor = torchvision_F.to_tensor(image_pil.convert(\"RGB\"))\n        return image_tensor\n\n    def visualize_patches_use_matrix(self,opt,warp_matrix):\n        image_pil = torchvision_F.to_pil_image(self.image_raw).convert(\"RGBA\")\n        draw_pil = PIL.Image.new(\"RGBA\",image_pil.size,(0,0,0,0))\n        draw = PIL.ImageDraw.Draw(draw_pil)\n        corners_all = warp.warp_corners_use_matrix(opt,warp_matrix)\n        corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5\n        corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5\n        for i,corners in enumerate(corners_all):\n            P = [tuple(float(n) for n in corners[j]) for j in range(4)]\n            draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3)\n        image_pil.alpha_composite(draw_pil)\n        image_tensor = torchvision_F.to_tensor(image_pil.convert(\"RGB\"))\n        return image_tensor\n\n    @torch.no_grad()\n    def predict_entire_image(self,opt):\n        xy_grid = warp.get_normalized_pixel_grid(opt)[:1]\n        rgb = self.graph.neural_image.forward(opt,xy_grid) # [B,HW,3]\n        image = rgb.view(opt.H,opt.W,3).detach().cpu().permute(2,0,1)\n        return image\n\n    @torch.no_grad()\n    def log_scalars(self,opt,var,loss,metric=None,step=0,split=\"train\"):\n        super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)\n        # compute PSNR\n        psnr = -10*loss.render.log10()\n        self.tb.add_scalar(\"{0}/{1}\".format(split,\"PSNR\"),psnr,step)\n        # warp error\n        pred_corners = warp.warp_corners_use_matrix(opt,self.graph.warp_matrix)\n        
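# map the patch corners from normalized grid coordinates back to pixel coordinates\n        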
pred_corners[...,0] = (pred_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5\n        pred_corners[...,1] = (pred_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5\n        \n        gt_corners = warp.warp_corners_compose(opt, self.homo_pert, self.rot_pert, self.trans_pert)\n        gt_corners[...,0] = (gt_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5\n        gt_corners[...,1] = (gt_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5\n\n        warp_error = (pred_corners-gt_corners).norm(dim=-1).mean()\n        self.tb.add_scalar(\"{0}/{1}\".format(split,\"warp error\"),warp_error,step)\n\n\n    @torch.no_grad()\n    def visualize(self,opt,var,step=0,split=\"train\"):\n        # dump frames for writing to video\n        frame_GT = self.visualize_patches_compose(opt,self.homo_pert, self.rot_pert, self.trans_pert)\n        frame = self.visualize_patches_use_matrix(opt,self.graph.warp_matrix)\n        frame2 = self.predict_entire_image(opt)\n        frame_cat = (torch.cat([frame,frame2],dim=1)*255).byte().permute(1,2,0).numpy()\n        imageio.imsave(\"{}/{}.png\".format(self.vis_path,self.vis_it),frame_cat)\n        self.vis_it += 1\n        # visualize in Tensorboard\n        if opt.tb:\n            colors = self.box_colors\n            util_vis.tb_image(opt,self.tb,step,split,\"image_pert\",util_vis.color_border(var.image_pert,colors))\n            util_vis.tb_image(opt,self.tb,step,split,\"rgb_warped\",util_vis.color_border(var.rgb_warped_map,colors))\n            util_vis.tb_image(opt,self.tb,self.it+1,\"train\",\"image_boxes\",frame[None])\n            util_vis.tb_image(opt,self.tb,self.it+1,\"train\",\"image_boxes_GT\",frame_GT[None])\n            util_vis.tb_image(opt,self.tb,self.it+1,\"train\",\"image_entire\",frame2[None])\n            local_warp_field = torch.cat([(var.local_warp_field+1)/2,torch.zeros_like(var.local_warp_field[:,:1,...])],dim=1) # [B,3,H,W]\n            global_warp_field = torch.cat([(var.global_warp_field+1)/2,torch.zeros_like(var.global_warp_field[:,:1,...])],dim=1) # [B,3,H,W]            \n            util_vis.tb_image(opt,self.tb,step,split,\"warp_field_local\",util_vis.color_border(local_warp_field,colors))\n            util_vis.tb_image(opt,self.tb,step,split,\"warp_field_global\",util_vis.color_border(global_warp_field,colors))\n\n# ============================ computation graph for forward/backprop ============================\n\nclass Graph(base.Graph):\n\n    def __init__(self,opt):\n        super().__init__(opt)\n        self.neural_image = NeuralImageFunction(opt)\n\n    def forward(self,opt,var,mode=None):\n        # warp\n        xy_grid = warp.get_normalized_pixel_grid_crop(opt)\n        warp_embedding = self.warp_embedding.weight[:,None,:].expand(-1,xy_grid.shape[1],-1)\n        local_warp_param = self.warp_mlp(opt,torch.cat((xy_grid,warp_embedding),dim=-1))\n        if opt.warp.fix_first:\n            local_warp_param[0] = 0\n\n        if opt.warp.dof==1:\n            local_warp_matrix = warp.lie.so2_to_SO2(local_warp_param)\n            var.local_warped_grid = (xy_grid[...,None,:]@local_warp_matrix.transpose(-2,-1))[...,0,:] # [B,HW,2]\n            self.warp_matrix = roma.rigid_vectors_registration(xy_grid, var.local_warped_grid)\n            var.global_warped_grid = xy_grid@self.warp_matrix.transpose(-2,-1) # [B,HW,2]\n        elif opt.warp.dof==2:\n            var.local_warped_grid = xy_grid + local_warp_param\n            self.warp_matrix = var.local_warped_grid.mean(-2)-xy_grid.mean(-2)\n            var.global_warped_grid = 
xy_grid + self.warp_matrix[...,None,:]\n        elif opt.warp.dof==3:\n            xy_grid_hom = warp.camera.to_hom(xy_grid[...,None,:])\n            local_warp_matrix = warp.lie.se2_to_SE2(local_warp_param)\n            var.local_warped_grid = (xy_grid_hom@local_warp_matrix.transpose(-2,-1))[...,0,:] # [B,HW,2]\n            R_global, t_global = roma.rigid_points_registration(xy_grid, var.local_warped_grid)\n            self.warp_matrix = torch.cat((R_global,t_global[...,None]),-1)\n            xy_grid_hom = warp.camera.to_hom(xy_grid)\n            var.global_warped_grid = xy_grid_hom@self.warp_matrix.transpose(-2,-1) # [B,HW,2]\n        elif opt.warp.dof==8:\n            xy_grid_hom = warp.camera.to_hom(xy_grid[...,None,:])\n            local_warp_matrix = warp.lie.se2_to_SE2(local_warp_param)\n            var.local_warped_grid = (xy_grid_hom@local_warp_matrix.transpose(-2,-1))[...,0,:] # [B,HW,2]\n            self.warp_matrix = find_homography_dlt(xy_grid, var.local_warped_grid)\n            xy_grid_hom = warp.camera.to_hom(xy_grid)\n            global_warped_grid = xy_grid_hom@self.warp_matrix.transpose(-2,-1)\n            var.global_warped_grid = global_warped_grid[...,:2]/(global_warped_grid[...,2:]+1e-8) # [B,HW,2]\n\n        # render images\n        var.rgb_warped = self.neural_image.forward(opt,var.local_warped_grid) # [B,HW,3]\n        var.rgb_warped_map = var.rgb_warped.view(opt.batch_size,opt.H_crop,opt.W_crop,3).permute(0,3,1,2) # [B,3,H,W]\n        var.local_warp_field = var.local_warped_grid.view(opt.batch_size,opt.H_crop,opt.W_crop,2).permute(0,3,1,2) # [B,2,H,W]\n        var.global_warp_field = var.global_warped_grid.view(opt.batch_size,opt.H_crop,opt.W_crop,2).permute(0,3,1,2) # [B,2,H,W]\n        return var\n\n    def compute_loss(self,opt,var,mode=None):\n        loss = edict()\n        if opt.loss_weight.render is not None:\n            image_pert = var.image_pert.view(opt.batch_size,3,opt.H_crop*opt.W_crop).permute(0,2,1)\n            loss.render = self.MSE_loss(var.rgb_warped,image_pert)\n        if opt.loss_weight.global_alignment is not None:\n            loss.global_alignment = self.MSE_loss(var.local_warped_grid,var.global_warped_grid)\n        return loss\n\nclass NeuralImageFunction(torch.nn.Module):\n\n    def __init__(self,opt):\n        super().__init__()\n        self.define_network(opt)\n        self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed\n\n    def define_network(self,opt):\n        input_2D_dim = 2+4*opt.arch.posenc.L_2D if opt.arch.posenc else 2\n        # point-wise RGB prediction\n        self.mlp = torch.nn.ModuleList()\n        L = util.get_layer_dims(opt.arch.layers)\n        for li,(k_in,k_out) in enumerate(L):\n            if li==0: k_in = input_2D_dim\n            if li in opt.arch.skip: k_in += input_2D_dim\n            linear = torch.nn.Linear(k_in,k_out)\n            if opt.barf_c2f and li==0:\n                # rescale first layer init (distribution was for pos.enc. 
but only xy is first used)\n                scale = np.sqrt(input_2D_dim/2.)\n                linear.weight.data *= scale\n                linear.bias.data *= scale\n            self.mlp.append(linear)\n\n    def forward(self,opt,coord_2D): # [B,...,2]\n        if opt.arch.posenc:\n            points_enc = self.positional_encoding(opt,coord_2D,L=opt.arch.posenc.L_2D)\n            points_enc = torch.cat([coord_2D,points_enc],dim=-1) # [B,...,4L+2]\n        else: points_enc = coord_2D\n        feat = points_enc\n        # extract implicit features\n        for li,layer in enumerate(self.mlp):\n            if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1)\n            feat = layer(feat)\n            if li!=len(self.mlp)-1:\n                feat = torch_F.relu(feat)\n        rgb = feat.sigmoid_() # [B,...,3]\n        return rgb\n\n    def positional_encoding(self,opt,input,L): # [B,...,N]\n        shape = input.shape\n        freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L]\n        spectrum = input[...,None]*freq # [B,...,N,L]\n        sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L]\n        input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L]\n        input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL]\n        # coarse-to-fine: smoothly mask positional encoding for BARF\n        if opt.barf_c2f is not None:\n            # set weights for different frequency bands\n            start,end = opt.barf_c2f\n            alpha = (self.progress.data-start)/(end-start)*L\n            k = torch.arange(L,dtype=torch.float32,device=opt.device)\n            weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2\n            # apply weights\n            shape = input_enc.shape\n            input_enc = (input_enc.view(-1,L)*weight).view(*shape)\n        return input_enc\n\nclass localWarp(torch.nn.Module):\n    def __init__(self, opt):\n        super().__init__()\n        # point-wise warp parameter prediction\n        input_2D_dim = 2\n        self.mlp_warp = torch.nn.ModuleList()\n        L = util.get_layer_dims(opt.arch.layers_warp)\n        for li,(k_in,k_out) in enumerate(L):\n            if li==0: k_in = input_2D_dim+opt.arch.embedding_dim\n            if li in opt.arch.skip_warp: k_in += input_2D_dim+opt.arch.embedding_dim\n            linear = torch.nn.Linear(k_in,k_out)\n            self.mlp_warp.append(linear)\n\n    def forward(self,opt,uvf):\n        feat = uvf\n        for li,layer in enumerate(self.mlp_warp):\n            if li in opt.arch.skip_warp: feat = torch.cat([feat,uvf],dim=-1)\n            feat = layer(feat)\n            if li!=len(self.mlp_warp)-1:\n                feat = torch_F.relu(feat)\n        warp = feat\n        return warp"
  },
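In the `dof==3` branch above, `roma.rigid_points_registration` returns the best-fit global rigid transform to the per-pixel local warps, and the `global_alignment` loss then pulls the local field toward this global fit. The closed form behind such a fit is the classical Kabsch/Procrustes alignment; a minimal 2D sketch of that idea (illustrative only, not the `roma` implementation, and without `roma`'s batching):

```python
import math
import torch

def fit_rigid_2d(X, Y):                      # X,Y: [N,2] point sets
    Xc, Yc = X - X.mean(0), Y - Y.mean(0)    # center both sets
    U, _, Vh = torch.linalg.svd(Xc.T @ Yc)   # SVD of the 2x2 cross-covariance
    d = torch.sign(torch.det(Vh.T @ U.T))    # reflection guard: keep det(R)=+1
    S = torch.diag(torch.stack([torch.ones(()), d]))
    R = Vh.T @ S @ U.T                       # optimal rotation
    t = Y.mean(0) - X.mean(0) @ R.T          # optimal translation
    return R, t                              # Y ~ X @ R.T + t

# sanity check: recover a known rotation + translation
c, s = math.cos(0.3), math.sin(0.3)
R_gt = torch.tensor([[c, -s], [s, c]])
X = torch.randn(100, 2)
Y = X @ R_gt.T + torch.tensor([0.5, -0.2])
R, t = fit_rigid_2d(X, Y)                    # R ~ R_gt, t ~ (0.5, -0.2)
```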
  {
    "path": "model/nerf.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torchvision.transforms.functional as torchvision_F\nimport tqdm\nfrom easydict import EasyDict as edict\n\nimport lpips\nfrom external.pohsun_ssim import pytorch_ssim\n\nimport util,util_vis\nfrom util import log,debug\nfrom . import base\nimport camera\n# ============================ main engine for training and evaluation ============================\n\nclass Model(base.Model):\n\n    def __init__(self,opt):\n        super().__init__(opt)\n        self.lpips_loss = lpips.LPIPS(net=\"alex\").to(opt.device)\n\n    def load_dataset(self,opt,eval_split=\"val\"):\n        super().load_dataset(opt,eval_split=eval_split)\n        # prefetch all training data\n        self.train_data.prefetch_all_data(opt)\n        self.train_data.all = edict(util.move_to_device(self.train_data.all,opt.device))\n\n    def setup_optimizer(self,opt):\n        log.info(\"setting up optimizers...\")\n        optimizer = getattr(torch.optim,opt.optim.algo)\n        self.optim = optimizer([dict(params=self.graph.nerf.parameters(),lr=opt.optim.lr)])\n        if opt.nerf.fine_sampling:\n            self.optim.add_param_group(dict(params=self.graph.nerf_fine.parameters(),lr=opt.optim.lr))\n        # set up scheduler\n        if opt.optim.sched:\n            scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)\n            if opt.optim.lr_end:\n                assert(opt.optim.sched.type==\"ExponentialLR\")\n                opt.optim.sched.gamma = (opt.optim.lr_end/opt.optim.lr)**(1./opt.max_iter)\n            kwargs = { k:v for k,v in opt.optim.sched.items() if k!=\"type\" }\n            self.sched = scheduler(self.optim,**kwargs)\n\n    def train(self,opt):\n        # before training\n        log.title(\"TRAINING START\")\n        self.timer = edict(start=time.time(),it_mean=None)\n        self.graph.train()\n        self.ep = 0 # dummy for timer\n        # training\n        if self.iter_start==0: self.validate(opt,0)\n        loader = tqdm.trange(opt.max_iter,desc=\"training\",leave=False)\n        for self.it in loader:\n            if self.it<self.iter_start: continue\n            # set var to all available images\n            var = self.train_data.all\n            self.train_iteration(opt,var,loader)\n            if opt.optim.sched: self.sched.step()\n            if self.it%opt.freq.val==0: self.validate(opt,self.it)\n            if self.it%opt.freq.ckpt==0: self.save_checkpoint(opt,ep=None,it=self.it)\n        # after training\n        if opt.tb:\n            self.tb.flush()\n            self.tb.close()\n        if opt.visdom: self.vis.close()\n        log.title(\"TRAINING DONE\")\n\n    @torch.no_grad()\n    def log_scalars(self,opt,var,loss,metric=None,step=0,split=\"train\"):\n        super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)\n        # log learning rate\n        if split==\"train\":\n            lr = self.optim.param_groups[0][\"lr\"]\n            self.tb.add_scalar(\"{0}/{1}\".format(split,\"lr\"),lr,step)\n            if opt.nerf.fine_sampling:\n                lr = self.optim.param_groups[1][\"lr\"]\n                self.tb.add_scalar(\"{0}/{1}\".format(split,\"lr_fine\"),lr,step)\n        # compute PSNR\n        psnr = -10*loss.render.log10()\n        self.tb.add_scalar(\"{0}/{1}\".format(split,\"PSNR\"),psnr,step)\n        if opt.nerf.fine_sampling:\n            psnr = -10*loss.render_fine.log10()\n            
self.tb.add_scalar(\"{0}/{1}\".format(split,\"PSNR_fine\"),psnr,step)\n\n    @torch.no_grad()\n    def visualize(self,opt,var,step=0,split=\"train\",eps=1e-10):\n        if opt.tb:\n            util_vis.tb_image(opt,self.tb,step,split,\"image\",var.image)\n            if not opt.nerf.rand_rays or split!=\"train\":\n                if opt.model in [\"barf\",\"l2g_nerf\"]:\n                    scale = self.graph.sim3.s1/self.graph.sim3.s0\n                else: scale = 1\n                depth=var.depth/scale\n                if opt.data.dataset==\"blender\": depth = torch.clip(depth - 1.5, min=1)\n                invdepth = (1-depth)/var.opacity if opt.camera.ndc else 1/(depth/var.opacity+eps)\n                invdepth = util_vis.apply_depth_colormap(invdepth, cmap=\"inferno\")\n                rgb_map = var.rgb.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]\n                invdepth_map = invdepth.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]\n                util_vis.tb_image(opt,self.tb,step,split,\"rgb\",rgb_map)\n                util_vis.tb_image(opt,self.tb,step,split,\"invdepth\",invdepth_map)\n                if opt.nerf.fine_sampling:\n                    depth=var.depth_fine/scale\n                    if opt.data.dataset==\"blender\": depth = torch.clip(depth - 1.5, min=1)\n                    invdepth = (1-depth)/var.opacity_fine if opt.camera.ndc else 1/(depth/var.opacity_fine+eps)\n                    invdepth = util_vis.apply_depth_colormap(invdepth, cmap=\"inferno\")\n                    rgb_map = var.rgb_fine.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]\n                    invdepth_map = invdepth.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]\n                    util_vis.tb_image(opt,self.tb,step,split,\"rgb_fine\",rgb_map)\n                    util_vis.tb_image(opt,self.tb,step,split,\"invdepth_fine\",invdepth_map)\n\n    @torch.no_grad()\n    def get_all_training_poses(self,opt):\n        # get ground-truth (canonical) camera poses\n        pose_GT = self.train_data.get_all_camera_poses(opt).to(opt.device)\n        return None,pose_GT\n\n    @torch.no_grad()\n    def evaluate_full(self,opt,eps=1e-10):\n        self.graph.eval()\n        loader = tqdm.tqdm(self.test_loader,desc=\"evaluating\",leave=False)\n        res = []\n        test_path = \"{}/test_view\".format(opt.output_path)\n        os.makedirs(test_path,exist_ok=True)\n        for i,batch in enumerate(loader):\n            var = edict(batch)\n            var = util.move_to_device(var,opt.device)\n            if opt.model in [\"barf\",\"l2g_nerf\"] and opt.optim.test_photo:\n                # run test-time optimization to factor imperfections in the optimized poses out of the view-synthesis evaluation\n                var = self.evaluate_test_time_photometric_optim(opt,var)\n            var = self.graph.forward(opt,var,mode=\"eval\")\n            # evaluate view synthesis\n            if opt.model in [\"barf\",\"l2g_nerf\"]:\n                scale = self.graph.sim3.s1/self.graph.sim3.s0\n            else: scale = 1\n            depth = var.depth/scale\n            if opt.data.dataset==\"blender\": depth = torch.clip(depth - 1.5, min=1)\n            invdepth = (1-depth)/var.opacity if opt.camera.ndc else 1/(depth/var.opacity+eps)\n            invdepth = util_vis.apply_depth_colormap(invdepth, cmap=\"inferno\")\n            rgb_map = var.rgb.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]\n            invdepth_map = invdepth.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]\n            psnr = 
-10*self.graph.MSE_loss(rgb_map,var.image).log10().item()\n            ssim = pytorch_ssim.ssim(rgb_map,var.image).item()\n            lpips = self.lpips_loss(rgb_map*2-1,var.image*2-1).item()\n            res.append(edict(psnr=psnr,ssim=ssim,lpips=lpips))\n            # dump test views\n            torchvision_F.to_pil_image(rgb_map.cpu()[0]).save(\"{}/rgb_{}.png\".format(test_path,i))\n            torchvision_F.to_pil_image(var.image.cpu()[0]).save(\"{}/rgb_GT_{}.png\".format(test_path,i))\n            torchvision_F.to_pil_image(invdepth_map.cpu()[0]).save(\"{}/depth_{}.png\".format(test_path,i))\n        # show results in terminal\n        print(\"--------------------------\")\n        print(\"PSNR:  {:8.2f}\".format(np.mean([r.psnr for r in res])))\n        print(\"SSIM:  {:8.2f}\".format(np.mean([r.ssim for r in res])))\n        print(\"LPIPS: {:8.2f}\".format(np.mean([r.lpips for r in res])))\n        print(\"--------------------------\")\n        # dump numbers to file\n        quant_fname = \"{}/quant.txt\".format(opt.output_path)\n        with open(quant_fname,\"w\") as file:\n            for i,r in enumerate(res):\n                file.write(\"{} {} {} {}\\n\".format(i,r.psnr,r.ssim,r.lpips))\n\n    @torch.no_grad()\n    def generate_videos_synthesis(self,opt,eps=1e-10):\n        self.graph.eval()\n        if opt.data.dataset==\"blender\":\n            test_path = \"{}/test_view\".format(opt.output_path)\n            # assume the test-view syntheses have already been generated\n            print(\"writing videos...\")\n            rgb_vid_fname = \"{}/test_view_rgb.mp4\".format(opt.output_path)\n            depth_vid_fname = \"{}/test_view_depth.mp4\".format(opt.output_path)\n            os.system(\"ffmpeg -y -framerate 30 -i {0}/rgb_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1\".format(test_path,rgb_vid_fname))\n            os.system(\"ffmpeg -y -framerate 30 -i {0}/depth_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1\".format(test_path,depth_vid_fname))\n        else:\n            pose_pred,pose_GT = self.get_all_training_poses(opt)\n            poses = pose_pred if opt.model in [\"barf\",\"l2g_nerf\"] else pose_GT\n            if opt.model in [\"barf\",\"l2g_nerf\"] and opt.data.dataset==\"llff\":\n                _,sim3 = self.prealign_cameras(opt,pose_pred,pose_GT)\n                scale = sim3.s1/sim3.s0\n            else: scale = 1\n            # rotate novel views around the \"center\" camera of all poses\n            idx_center = (poses-poses.mean(dim=0,keepdim=True))[...,3].norm(dim=-1).argmin()\n            pose_novel = camera.get_novel_view_poses(opt,poses[idx_center],N=60,scale=scale).to(opt.device)\n            # render the novel views\n            novel_path = \"{}/novel_view\".format(opt.output_path)\n            os.makedirs(novel_path,exist_ok=True)\n            pose_novel_tqdm = tqdm.tqdm(pose_novel,desc=\"rendering novel views\",leave=False)\n            intr = edict(next(iter(self.test_loader))).intr[:1].to(opt.device) # grab intrinsics\n            for i,pose in enumerate(pose_novel_tqdm):\n                ret = self.graph.render_by_slices(opt,pose[None],intr=intr) if opt.nerf.rand_rays else \\\n                      self.graph.render(opt,pose[None],intr=intr)\n                depth=ret.depth/scale\n                if opt.data.dataset==\"blender\": depth = torch.clip(depth - 1.5, min=1)\n                invdepth = (1-depth)/ret.opacity if opt.camera.ndc else 1/(depth/ret.opacity+eps)\n                invdepth = util_vis.apply_depth_colormap(invdepth, 
cmap=\"inferno\")\n                rgb_map = ret.rgb.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]\n                invdepth_map = invdepth.view(-1,opt.H,opt.W,3).permute(0,3,1,2) # [B,3,H,W]\n                torchvision_F.to_pil_image(rgb_map.cpu()[0]).save(\"{}/rgb_{}.png\".format(novel_path,i))\n                torchvision_F.to_pil_image(invdepth_map.cpu()[0]).save(\"{}/depth_{}.png\".format(novel_path,i))\n\n            # write videos\n            print(\"writing videos...\")\n            rgb_vid_fname = \"{}/novel_view_rgb.mp4\".format(opt.output_path)\n            depth_vid_fname = \"{}/novel_view_depth.mp4\".format(opt.output_path)\n            os.system(\"ffmpeg -y -framerate 30 -i {0}/rgb_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1\".format(novel_path,rgb_vid_fname))\n            os.system(\"ffmpeg -y -framerate 30 -i {0}/depth_%d.png -pix_fmt yuv420p {1} >/dev/null 2>&1\".format(novel_path,depth_vid_fname))\n\n# ============================ computation graph for forward/backprop ============================\n\nclass Graph(base.Graph):\n\n    def __init__(self,opt):\n        super().__init__(opt)\n        self.nerf = NeRF(opt)\n        if opt.nerf.fine_sampling:\n            self.nerf_fine = NeRF(opt)\n\n    def forward(self,opt,var,mode=None):\n        batch_size = len(var.idx)\n        pose = self.get_pose(opt,var,mode=mode)\n        # render images\n        if opt.nerf.rand_rays and mode in [\"train\",\"test-optim\"]:\n            # sample random rays for optimization\n            var.ray_idx = torch.randperm(opt.H*opt.W,device=opt.device)[:opt.nerf.rand_rays//batch_size]\n            ret = self.render(opt,pose,intr=var.intr,ray_idx=var.ray_idx,mode=mode) # [B,N,3],[B,N,1]\n        else:\n            # render full image (process in slices)\n            ret = self.render_by_slices(opt,pose,intr=var.intr,mode=mode) if opt.nerf.rand_rays else \\\n                  self.render(opt,pose,intr=var.intr,mode=mode) # [B,HW,3],[B,HW,1]\n        var.update(ret)\n        return var\n\n    def compute_loss(self,opt,var,mode=None):\n        loss = edict()\n        batch_size = len(var.idx)\n        image = var.image.view(batch_size,3,opt.H*opt.W).permute(0,2,1)\n        if opt.nerf.rand_rays and mode in [\"train\",\"test-optim\"]:\n            image = image[:,var.ray_idx]\n        # compute image losses\n        if opt.loss_weight.render is not None:\n            loss.render = self.MSE_loss(var.rgb,image)\n        if opt.loss_weight.render_fine is not None:\n            assert(opt.nerf.fine_sampling)\n            loss.render_fine = self.MSE_loss(var.rgb_fine,image)\n        return loss\n\n    def get_pose(self,opt,var,mode=None):\n        return var.pose\n\n    def render(self,opt,pose,intr=None,ray_idx=None,mode=None):\n        batch_size = len(pose)\n        center,ray = camera.get_center_and_ray(opt,pose,intr=intr) # [B,HW,3]\n        while ray.isnan().any(): # TODO: weird bug, ray becomes NaN arbitrarily if batch_size>1, not deterministically reproducible\n            center,ray = camera.get_center_and_ray(opt,pose,intr=intr) # [B,HW,3]\n        if ray_idx is not None:\n            # consider only subset of rays\n            center,ray = center[:,ray_idx],ray[:,ray_idx]\n        if opt.camera.ndc:\n            # convert center/ray representations to NDC\n            center,ray = camera.convert_NDC(opt,center,ray,intr=intr)\n        # render with main MLP\n        depth_samples = self.sample_depth(opt,batch_size,num_rays=ray.shape[1]) # [B,HW,N,1]\n        rgb_samples,density_samples = 
self.nerf.forward_samples(opt,center,ray,depth_samples,mode=mode)\n        rgb,depth,opacity,prob = self.nerf.composite(opt,ray,rgb_samples,density_samples,depth_samples)\n        ret = edict(rgb=rgb,depth=depth,opacity=opacity) # [B,HW,K]\n        # render with fine MLP from coarse MLP\n        if opt.nerf.fine_sampling:\n            with torch.no_grad():\n                # resample depth according to coarse empirical distribution\n                depth_samples_fine = self.sample_depth_from_pdf(opt,pdf=prob[...,0]) # [B,HW,Nf,1]\n                depth_samples = torch.cat([depth_samples,depth_samples_fine],dim=2) # [B,HW,N+Nf,1]\n                depth_samples = depth_samples.sort(dim=2).values\n            rgb_samples,density_samples = self.nerf_fine.forward_samples(opt,center,ray,depth_samples,mode=mode)\n            rgb_fine,depth_fine,opacity_fine,_ = self.nerf_fine.composite(opt,ray,rgb_samples,density_samples,depth_samples)\n            ret.update(rgb_fine=rgb_fine,depth_fine=depth_fine,opacity_fine=opacity_fine) # [B,HW,K]\n        return ret\n\n    def render_by_slices(self,opt,pose,intr=None,mode=None):\n        ret_all = edict(rgb=[],depth=[],opacity=[])\n        if opt.nerf.fine_sampling:\n            ret_all.update(rgb_fine=[],depth_fine=[],opacity_fine=[])\n        # render the image by slices for memory considerations\n        for c in range(0,opt.H*opt.W,opt.nerf.rand_rays):\n            ray_idx = torch.arange(c,min(c+opt.nerf.rand_rays,opt.H*opt.W),device=opt.device)\n            ret = self.render(opt,pose,intr=intr,ray_idx=ray_idx,mode=mode) # [B,R,3],[B,R,1]\n            for k in ret: ret_all[k].append(ret[k])\n        # group all slices of images\n        for k in ret_all: ret_all[k] = torch.cat(ret_all[k],dim=1)\n        return ret_all\n\n    def sample_depth(self,opt,batch_size,num_rays=None):\n        depth_min,depth_max = opt.nerf.depth.range\n        num_rays = num_rays or opt.H*opt.W\n        rand_samples = torch.rand(batch_size,num_rays,opt.nerf.sample_intvs,1,device=opt.device) if opt.nerf.sample_stratified else 0.5\n        rand_samples += torch.arange(opt.nerf.sample_intvs,device=opt.device)[None,None,:,None].float() # [B,HW,N,1]\n        depth_samples = rand_samples/opt.nerf.sample_intvs*(depth_max-depth_min)+depth_min # [B,HW,N,1]\n        depth_samples = dict(\n            metric=depth_samples,\n            inverse=1/(depth_samples+1e-8),\n        )[opt.nerf.depth.param]\n        return depth_samples\n\n    def sample_depth_from_pdf(self,opt,pdf):\n        depth_min,depth_max = opt.nerf.depth.range\n        # get CDF from PDF (along last dimension)\n        cdf = pdf.cumsum(dim=-1) # [B,HW,N]\n        cdf = torch.cat([torch.zeros_like(cdf[...,:1]),cdf],dim=-1) # [B,HW,N+1]\n        # take uniform samples\n        grid = torch.linspace(0,1,opt.nerf.sample_intvs_fine+1,device=opt.device) # [Nf+1]\n        unif = 0.5*(grid[:-1]+grid[1:]).repeat(*cdf.shape[:-1],1) # [B,HW,Nf]\n        idx = torch.searchsorted(cdf,unif,right=True) # [B,HW,Nf] \\in {1...N}\n        # inverse transform sampling from CDF\n        depth_bin = torch.linspace(depth_min,depth_max,opt.nerf.sample_intvs+1,device=opt.device) # [N+1]\n        depth_bin = depth_bin.repeat(*cdf.shape[:-1],1) # [B,HW,N+1]\n        depth_low = depth_bin.gather(dim=2,index=(idx-1).clamp(min=0)) # [B,HW,Nf]\n        depth_high = depth_bin.gather(dim=2,index=idx.clamp(max=opt.nerf.sample_intvs)) # [B,HW,Nf]\n        cdf_low = cdf.gather(dim=2,index=(idx-1).clamp(min=0)) # [B,HW,Nf]\n        cdf_high = 
cdf.gather(dim=2,index=idx.clamp(max=opt.nerf.sample_intvs)) # [B,HW,Nf]\n        # linear interpolation\n        t = (unif-cdf_low)/(cdf_high-cdf_low+1e-8) # [B,HW,Nf]\n        depth_samples = depth_low+t*(depth_high-depth_low) # [B,HW,Nf]\n        return depth_samples[...,None] # [B,HW,Nf,1]\n\nclass NeRF(torch.nn.Module):\n\n    def __init__(self,opt):\n        super().__init__()\n        self.define_network(opt)\n\n    def define_network(self,opt):\n        input_3D_dim = 3+6*opt.arch.posenc.L_3D if opt.arch.posenc else 3\n        if opt.nerf.view_dep:\n            input_view_dim = 3+6*opt.arch.posenc.L_view if opt.arch.posenc else 3\n        # point-wise feature\n        self.mlp_feat = torch.nn.ModuleList()\n        L = util.get_layer_dims(opt.arch.layers_feat)\n        for li,(k_in,k_out) in enumerate(L):\n            if li==0: k_in = input_3D_dim\n            if li in opt.arch.skip: k_in += input_3D_dim\n            if li==len(L)-1: k_out += 1\n            linear = torch.nn.Linear(k_in,k_out)\n            if opt.arch.tf_init:\n                self.tensorflow_init_weights(opt,linear,out=\"first\" if li==len(L)-1 else None)\n            self.mlp_feat.append(linear)\n        # RGB prediction\n        self.mlp_rgb = torch.nn.ModuleList()\n        L = util.get_layer_dims(opt.arch.layers_rgb)\n        feat_dim = opt.arch.layers_feat[-1]\n        for li,(k_in,k_out) in enumerate(L):\n            if li==0: k_in = feat_dim+(input_view_dim if opt.nerf.view_dep else 0)\n            linear = torch.nn.Linear(k_in,k_out)\n            if opt.arch.tf_init:\n                self.tensorflow_init_weights(opt,linear,out=\"all\" if li==len(L)-1 else None)\n            self.mlp_rgb.append(linear)\n\n    def tensorflow_init_weights(self,opt,linear,out=None):\n        # use Xavier init instead of Kaiming init\n        relu_gain = torch.nn.init.calculate_gain(\"relu\") # sqrt(2)\n        if out==\"all\":\n            torch.nn.init.xavier_uniform_(linear.weight)\n        elif out==\"first\":\n            torch.nn.init.xavier_uniform_(linear.weight[:1])\n            torch.nn.init.xavier_uniform_(linear.weight[1:],gain=relu_gain)\n        else:\n            torch.nn.init.xavier_uniform_(linear.weight,gain=relu_gain)\n        torch.nn.init.zeros_(linear.bias)\n\n    def forward(self,opt,points_3D,ray_unit=None,mode=None): # [B,...,3]\n        if opt.arch.posenc:\n            points_enc = self.positional_encoding(opt,points_3D,L=opt.arch.posenc.L_3D)\n            points_enc = torch.cat([points_3D,points_enc],dim=-1) # [B,...,6L+3]\n        else: points_enc = points_3D\n        feat = points_enc\n        # extract coordinate-based features\n        for li,layer in enumerate(self.mlp_feat):\n            if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1)\n            feat = layer(feat)\n            if li==len(self.mlp_feat)-1:\n                density = feat[...,0]\n                if opt.nerf.density_noise_reg and mode==\"train\":\n                    density += torch.randn_like(density)*opt.nerf.density_noise_reg\n                density_activ = getattr(torch_F,opt.arch.density_activ) # relu_,abs_,sigmoid_,exp_....\n                density = density_activ(density)\n                feat = feat[...,1:]\n            feat = torch_F.relu(feat)\n        # predict RGB values\n        if opt.nerf.view_dep:\n            assert(ray_unit is not None)\n            if opt.arch.posenc:\n                ray_enc = self.positional_encoding(opt,ray_unit,L=opt.arch.posenc.L_view)\n                ray_enc = 
torch.cat([ray_unit,ray_enc],dim=-1) # [B,...,6L+3]\n            else: ray_enc = ray_unit\n            feat = torch.cat([feat,ray_enc],dim=-1)\n        for li,layer in enumerate(self.mlp_rgb):\n            feat = layer(feat)\n            if li!=len(self.mlp_rgb)-1:\n                feat = torch_F.relu(feat)\n        rgb = feat.sigmoid_() # [B,...,3]\n        return rgb,density\n\n    def forward_samples(self,opt,center,ray,depth_samples,mode=None):\n        points_3D_samples = camera.get_3D_points_from_depth(opt,center,ray,depth_samples,multi_samples=True) # [B,HW,N,3]\n        if opt.nerf.view_dep:\n            ray_unit = torch_F.normalize(ray,dim=-1) # [B,HW,3]\n            ray_unit_samples = ray_unit[...,None,:].expand_as(points_3D_samples) # [B,HW,N,3]\n        else: ray_unit_samples = None\n        rgb_samples,density_samples = self.forward(opt,points_3D_samples,ray_unit=ray_unit_samples,mode=mode) # [B,HW,N,3],[B,HW,N]\n        return rgb_samples,density_samples\n\n    def composite(self,opt,ray,rgb_samples,density_samples,depth_samples):\n        ray_length = ray.norm(dim=-1,keepdim=True) # [B,HW,1]\n        # volume rendering: compute probability (using quadrature)\n        depth_intv_samples = depth_samples[...,1:,0]-depth_samples[...,:-1,0] # [B,HW,N-1]\n        depth_intv_samples = torch.cat([depth_intv_samples,torch.empty_like(depth_intv_samples[...,:1]).fill_(1e10)],dim=2) # [B,HW,N]\n        dist_samples = depth_intv_samples*ray_length # [B,HW,N]\n        sigma_delta = density_samples*dist_samples # [B,HW,N]\n        alpha = 1-(-sigma_delta).exp_() # [B,HW,N]\n        T = (-torch.cat([torch.zeros_like(sigma_delta[...,:1]),sigma_delta[...,:-1]],dim=2).cumsum(dim=2)).exp_() # [B,HW,N]\n        prob = (T*alpha)[...,None] # [B,HW,N,1]\n        # integrate RGB and depth weighted by probability\n        depth = (depth_samples*prob).sum(dim=2) # [B,HW,1]\n        rgb = (rgb_samples*prob).sum(dim=2) # [B,HW,3]\n        opacity = prob.sum(dim=2) # [B,HW,1]\n        if opt.nerf.setbg_opaque:\n            rgb = rgb+opt.data.bgcolor*(1-opacity)\n        return rgb,depth,opacity,prob # [B,HW,K]\n\n    def positional_encoding(self,opt,input,L): # [B,...,N]\n        shape = input.shape\n        freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L]\n        spectrum = input[...,None]*freq # [B,...,N,L]\n        sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L]\n        input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L]\n        input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL]\n        return input_enc\n"
  },
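`sample_depth_from_pdf` above is inverse transform sampling: the coarse compositing weights form a piecewise-constant PDF over the depth bins, its CDF is inverted at mid-point uniform positions, and the hit bin is interpolated linearly. A stripped-down single-ray version of the same idea (a hypothetical standalone helper, using random rather than mid-point uniforms):

```python
import torch

def sample_from_pdf(pdf, bins, n_samples):
    # pdf: [N] nonnegative weights per depth bin; bins: [N+1] bin edges
    cdf = torch.cat([torch.zeros(1), (pdf / pdf.sum()).cumsum(0)])  # [N+1]
    u = torch.rand(n_samples)                     # uniform draws in [0,1)
    idx = torch.searchsorted(cdf, u, right=True)  # hit bin, in {1..N}
    idx = idx.clamp(max=pdf.numel())              # guard against fp round-off
    t = (u - cdf[idx - 1]) / (cdf[idx] - cdf[idx - 1] + 1e-8)  # position in bin
    return bins[idx - 1] + t * (bins[idx] - bins[idx - 1])     # [n_samples]

# e.g. draw 64 fine depths from an 8-bin coarse profile over [2,6]:
pdf = torch.tensor([0., 0.1, 0.5, 1., 0.5, 0.1, 0., 0.])
fine_depths = sample_from_pdf(pdf, torch.linspace(2, 6, 9), 64)
```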
  {
    "path": "model/planar.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torchvision.transforms.functional as torchvision_F\nimport tqdm\nfrom easydict import EasyDict as edict\nimport PIL\nimport PIL.Image,PIL.ImageDraw\nimport imageio\n\nimport util,util_vis\nfrom util import log,debug\nfrom . import base\nimport warp\n\n# ============================ main engine for training and evaluation ============================\n\nclass Model(base.Model):\n\n    def __init__(self,opt):\n        super().__init__(opt)\n        opt.H_crop,opt.W_crop = opt.data.patch_crop\n\n    def load_dataset(self,opt,eval_split=None):\n        image_raw = PIL.Image.open(opt.data.image_fname)\n        self.image_raw = torchvision_F.to_tensor(image_raw).to(opt.device)\n\n    def build_networks(self,opt):\n        super().build_networks(opt)\n        self.graph.warp_param = torch.nn.Embedding(opt.batch_size,opt.warp.dof).to(opt.device)\n        torch.nn.init.zeros_(self.graph.warp_param.weight)\n\n    def setup_optimizer(self,opt):\n        log.info(\"setting up optimizers...\")\n        optim_list = [\n            dict(params=self.graph.neural_image.parameters(),lr=opt.optim.lr),\n            dict(params=self.graph.warp_param.parameters(),lr=opt.optim.lr_warp),\n        ]\n        optimizer = getattr(torch.optim,opt.optim.algo)\n        self.optim = optimizer(optim_list)\n        # set up scheduler\n        if opt.optim.sched:\n            scheduler = getattr(torch.optim.lr_scheduler,opt.optim.sched.type)\n            kwargs = { k:v for k,v in opt.optim.sched.items() if k!=\"type\" }\n            self.sched = scheduler(self.optim,**kwargs)\n\n    def setup_visualizer(self,opt):\n        super().setup_visualizer(opt)\n        # set colors for visualization\n        box_colors = [\"#ff0000\",\"#40afff\",\"#9314ff\",\"#ffd700\",\"#00ff00\"]\n        box_colors = list(map(util.colorcode_to_number,box_colors))\n        self.box_colors = np.array(box_colors).astype(int)\n        assert(len(self.box_colors)==opt.batch_size)\n        # create visualization directory\n        self.vis_path = \"{}/vis\".format(opt.output_path)\n        os.makedirs(self.vis_path,exist_ok=True)\n        self.video_fname = \"{}/vis.mp4\".format(opt.output_path)\n\n    def train(self,opt):\n        # before training\n        log.title(\"TRAINING START\")\n        self.timer = edict(start=time.time(),it_mean=None)\n        self.ep = self.it = self.vis_it = 0\n        self.graph.train()\n        var = edict(idx=torch.arange(opt.batch_size))\n        # pre-generate perturbations\n        self.homo_pert, self.rot_pert, self.trans_pert, var.image_pert = self.generate_warp_perturbation(opt)\n        # train\n        var = util.move_to_device(var,opt.device)\n        loader = tqdm.trange(opt.max_iter,desc=\"training\",leave=False)\n        # visualize initial state\n        var = self.graph.forward(opt,var)\n        self.visualize(opt,var,step=0)\n        for it in loader:\n            # train iteration\n            loss = self.train_iteration(opt,var,loader)\n            if opt.warp.fix_first:\n                self.graph.warp_param.weight.data[0] = 0\n        # after training\n        os.system(\"ffmpeg -y -framerate 30 -i {}/%d.png -pix_fmt yuv420p {}\".format(self.vis_path,self.video_fname))\n        self.save_checkpoint(opt,ep=None,it=self.it)\n        if opt.tb:\n            self.tb.flush()\n            self.tb.close()\n        if opt.visdom: self.vis.close()\n        
log.title(\"TRAINING DONE\")\n\n    def train_iteration(self,opt,var,loader):\n        loss = super().train_iteration(opt,var,loader)\n        self.graph.neural_image.progress.data.fill_(self.it/opt.max_iter)\n        return loss\n\n    def generate_warp_perturbation(self,opt):\n        # pre-generate perturbations (translational noise + homography noise)\n        def create_random_perturbation(batch_size):\n            if opt.warp.dof==1:\n                homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h\n                rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r\n                trans_pert = torch.zeros(batch_size,2,device=opt.device)*opt.warp.noise_t\n            elif opt.warp.dof==2:\n                homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h\n                rot_pert = torch.zeros(batch_size,1,device=opt.device)*opt.warp.noise_r\n                trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t\n            elif opt.warp.dof==3:\n                homo_pert = torch.zeros(batch_size,8,device=opt.device)*opt.warp.noise_h\n                rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r\n                trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t\n            elif opt.warp.dof==8:\n                homo_pert = torch.randn(batch_size,8,device=opt.device)*opt.warp.noise_h\n                homo_pert[:,:2]=0\n                rot_pert = torch.randn(batch_size,1,device=opt.device)*opt.warp.noise_r\n                trans_pert = torch.randn(batch_size,2,device=opt.device)*opt.warp.noise_t\n            else: assert(False)\n            return homo_pert, rot_pert, trans_pert\n\n        homo_pert = torch.zeros(opt.batch_size,8,device=opt.device)\n        rot_pert = torch.zeros(opt.batch_size,1,device=opt.device)\n        trans_pert = torch.zeros(opt.batch_size,2,device=opt.device)\n        for i in range(opt.batch_size):\n            homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1)\n            while not warp.check_corners_in_range_compose(opt, homo_pert_i, rot_pert_i, trans_pert_i):\n                homo_pert_i, rot_pert_i, trans_pert_i = create_random_perturbation(1)\n            homo_pert[i], rot_pert[i], trans_pert[i] = homo_pert_i, rot_pert_i, trans_pert_i\n\n        if opt.warp.fix_first:\n            homo_pert[0],rot_pert[0],trans_pert[0] = 0,0,0\n\n        # create warped image patches\n        xy_grid = warp.get_normalized_pixel_grid_crop(opt) # [B,HW,2]\n\n        xy_grid_hom = warp.camera.to_hom(xy_grid)\n        warp_matrix = warp.lie.sl3_to_SL3(homo_pert)\n        warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1)\n        xy_grid_warped = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2]\n        warp_matrix = warp.lie.so2_to_SO2(rot_pert)\n        xy_grid_warped = xy_grid_warped@warp_matrix.transpose(-2,-1) # [B,HW,2]\n        xy_grid_warped = xy_grid_warped+trans_pert[...,None,:]\n\n        xy_grid_warped = xy_grid_warped.view([opt.batch_size,opt.H_crop,opt.W_crop,2])\n        xy_grid_warped = torch.stack([xy_grid_warped[...,0]*max(opt.H,opt.W)/opt.W,\n                                      xy_grid_warped[...,1]*max(opt.H,opt.W)/opt.H],dim=-1)\n        image_raw_batch = self.image_raw.repeat(opt.batch_size,1,1,1)\n        image_pert_all = torch_F.grid_sample(image_raw_batch,xy_grid_warped,align_corners=False)\n\n        return homo_pert, rot_pert, trans_pert, image_pert_all\n\n    def 
visualize_patches(self,opt,warp_param):\n        image_pil = torchvision_F.to_pil_image(self.image_raw).convert(\"RGBA\")\n        draw_pil = PIL.Image.new(\"RGBA\",image_pil.size,(0,0,0,0))\n        draw = PIL.ImageDraw.Draw(draw_pil)\n        corners_all = warp.warp_corners(opt,warp_param)\n        corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5\n        corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5\n        for i,corners in enumerate(corners_all):\n            P = [tuple(float(n) for n in corners[j]) for j in range(4)]\n            draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3)\n        image_pil.alpha_composite(draw_pil)\n        image_tensor = torchvision_F.to_tensor(image_pil.convert(\"RGB\"))\n        return image_tensor\n\n    def visualize_patches_compose(self,opt, homo_pert, rot_pert, trans_pert):\n        image_pil = torchvision_F.to_pil_image(self.image_raw).convert(\"RGBA\")\n        draw_pil = PIL.Image.new(\"RGBA\",image_pil.size,(0,0,0,0))\n        draw = PIL.ImageDraw.Draw(draw_pil)\n        corners_all = warp.warp_corners_compose(opt, homo_pert, rot_pert, trans_pert)\n        corners_all[...,0] = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5\n        corners_all[...,1] = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5\n        for i,corners in enumerate(corners_all):\n            P = [tuple(float(n) for n in corners[j]) for j in range(4)]\n            draw.line([P[0],P[1],P[2],P[3],P[0]],fill=tuple(self.box_colors[i]),width=3)\n        image_pil.alpha_composite(draw_pil)\n        image_tensor = torchvision_F.to_tensor(image_pil.convert(\"RGB\"))\n        return image_tensor\n\n    @torch.no_grad()\n    def predict_entire_image(self,opt):\n        xy_grid = warp.get_normalized_pixel_grid(opt)[:1]\n        rgb = self.graph.neural_image.forward(opt,xy_grid) # [B,HW,3]\n        image = rgb.view(opt.H,opt.W,3).detach().cpu().permute(2,0,1)\n        return image\n\n    @torch.no_grad()\n    def log_scalars(self,opt,var,loss,metric=None,step=0,split=\"train\"):\n        super().log_scalars(opt,var,loss,metric=metric,step=step,split=split)\n        # compute PSNR\n        psnr = -10*loss.render.log10()\n        self.tb.add_scalar(\"{0}/{1}\".format(split,\"PSNR\"),psnr,step)\n        # warp error\n        # pred_corners = warp.warp_corners_use_matrix(opt,self.graph.warp_matrix)\n        pred_corners = warp.warp_corners(opt,self.graph.warp_param.weight)\n        pred_corners[...,0] = (pred_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5\n        pred_corners[...,1] = (pred_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5\n        \n        gt_corners = warp.warp_corners_compose(opt, self.homo_pert, self.rot_pert, self.trans_pert)\n        gt_corners[...,0] = (gt_corners[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5\n        gt_corners[...,1] = (gt_corners[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5\n\n        warp_error = (pred_corners-gt_corners).norm(dim=-1).mean()\n        self.tb.add_scalar(\"{0}/{1}\".format(split,\"warp error\"),warp_error,step)\n\n\n    @torch.no_grad()\n    def visualize(self,opt,var,step=0,split=\"train\"):\n        # dump frames for writing to video\n        frame_GT = self.visualize_patches_compose(opt,self.homo_pert, self.rot_pert, self.trans_pert)\n        frame = self.visualize_patches(opt,self.graph.warp_param.weight)\n        frame2 = self.predict_entire_image(opt)\n        frame_cat = 
(torch.cat([frame,frame2],dim=1)*255).byte().permute(1,2,0).numpy()\n        imageio.imsave(\"{}/{}.png\".format(self.vis_path,self.vis_it),frame_cat)\n        self.vis_it += 1\n        # visualize in Tensorboard\n        if opt.tb:\n            colors = self.box_colors\n            util_vis.tb_image(opt,self.tb,step,split,\"image_pert\",util_vis.color_border(var.image_pert,colors))\n            util_vis.tb_image(opt,self.tb,step,split,\"rgb_warped\",util_vis.color_border(var.rgb_warped_map,colors))\n            util_vis.tb_image(opt,self.tb,self.it+1,\"train\",\"image_boxes\",frame[None])\n            util_vis.tb_image(opt,self.tb,self.it+1,\"train\",\"image_boxes_GT\",frame_GT[None])\n            util_vis.tb_image(opt,self.tb,self.it+1,\"train\",\"image_entire\",frame2[None])\n\n# ============================ computation graph for forward/backprop ============================\n\nclass Graph(base.Graph):\n\n    def __init__(self,opt):\n        super().__init__(opt)\n        self.neural_image = NeuralImageFunction(opt)\n\n    def forward(self,opt,var,mode=None):\n        xy_grid = warp.get_normalized_pixel_grid_crop(opt)\n        xy_grid_warped = warp.warp_grid(opt,xy_grid,self.warp_param.weight)\n        # render images\n        var.rgb_warped = self.neural_image.forward(opt,xy_grid_warped) # [B,HW,3]\n        var.rgb_warped_map = var.rgb_warped.view(opt.batch_size,opt.H_crop,opt.W_crop,3).permute(0,3,1,2) # [B,3,H,W]\n        return var\n\n    def compute_loss(self,opt,var,mode=None):\n        loss = edict()\n        if opt.loss_weight.render is not None:\n            image_pert = var.image_pert.view(opt.batch_size,3,opt.H_crop*opt.W_crop).permute(0,2,1)\n            loss.render = self.MSE_loss(var.rgb_warped,image_pert)\n        return loss\n\nclass NeuralImageFunction(torch.nn.Module):\n\n    def __init__(self,opt):\n        super().__init__()\n        self.define_network(opt)\n        self.progress = torch.nn.Parameter(torch.tensor(0.)) # use Parameter so it could be checkpointed\n\n    def define_network(self,opt):\n        input_2D_dim = 2+4*opt.arch.posenc.L_2D if opt.arch.posenc else 2\n        # point-wise RGB prediction\n        self.mlp = torch.nn.ModuleList()\n        L = util.get_layer_dims(opt.arch.layers)\n        for li,(k_in,k_out) in enumerate(L):\n            if li==0: k_in = input_2D_dim\n            if li in opt.arch.skip: k_in += input_2D_dim\n            linear = torch.nn.Linear(k_in,k_out)\n            if opt.barf_c2f and li==0:\n                # rescale first layer init (distribution was for pos.enc. 
but only xy is first used)\n                scale = np.sqrt(input_2D_dim/2.)\n                linear.weight.data *= scale\n                linear.bias.data *= scale\n            self.mlp.append(linear)\n\n    def forward(self,opt,coord_2D): # [B,...,2]\n        if opt.arch.posenc:\n            points_enc = self.positional_encoding(opt,coord_2D,L=opt.arch.posenc.L_2D)\n            points_enc = torch.cat([coord_2D,points_enc],dim=-1) # [B,...,4L+2]\n        else: points_enc = coord_2D\n        feat = points_enc\n        # extract implicit features\n        for li,layer in enumerate(self.mlp):\n            if li in opt.arch.skip: feat = torch.cat([feat,points_enc],dim=-1)\n            feat = layer(feat)\n            if li!=len(self.mlp)-1:\n                feat = torch_F.relu(feat)\n        rgb = feat.sigmoid_() # [B,...,3]\n        return rgb\n\n    def positional_encoding(self,opt,input,L): # [B,...,N]\n        shape = input.shape\n        freq = 2**torch.arange(L,dtype=torch.float32,device=opt.device)*np.pi # [L]\n        spectrum = input[...,None]*freq # [B,...,N,L]\n        sin,cos = spectrum.sin(),spectrum.cos() # [B,...,N,L]\n        input_enc = torch.stack([sin,cos],dim=-2) # [B,...,N,2,L]\n        input_enc = input_enc.view(*shape[:-1],-1) # [B,...,2NL]\n        # coarse-to-fine: smoothly mask positional encoding for BARF\n        if opt.barf_c2f is not None:\n            # set weights for different frequency bands\n            start,end = opt.barf_c2f\n            alpha = (self.progress.data-start)/(end-start)*L\n            k = torch.arange(L,dtype=torch.float32,device=opt.device)\n            weight = (1-(alpha-k).clamp_(min=0,max=1).mul_(np.pi).cos_())/2\n            # apply weights\n            shape = input_enc.shape\n            input_enc = (input_enc.view(-1,L)*weight).view(*shape)\n        return input_enc\n"
  },
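Both planar models share the BARF-style coarse-to-fine mask in `positional_encoding`: frequency band k is weighted by w_k(alpha) = (1 - cos(pi * clamp(alpha - k, 0, 1))) / 2, where alpha ramps from 0 to L as `progress` sweeps through the `[start, end]` window of `barf_c2f`. A small numeric sketch of that schedule:

```python
import numpy as np

def c2f_weights(progress, start, end, L):
    # alpha ramps linearly from 0 to L across the [start, end] window
    alpha = (progress - start) / (end - start) * L
    k = np.arange(L)
    return (1 - np.cos(np.clip(alpha - k, 0, 1) * np.pi)) / 2

# with barf_c2f=[0.1,0.5] and L=8: halfway through the window the four
# lowest bands are fully on and the four highest are still masked
print(c2f_weights(0.3, 0.1, 0.5, 8))  # [1. 1. 1. 1. 0. 0. 0. 0.]
```

Before `progress` reaches `start` every band is masked (alpha <= 0), and after `end` all bands are fully enabled (alpha = L).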
  {
    "path": "options/barf_blender.yaml",
    "content": "_parent_: options/nerf_blender.yaml\n\nbarf_c2f: [0.1,0.5]                                         # coarse-to-fine scheduling on positional encoding\n\ncamera:                                                     # camera options\n    noise: true                                             # synthetic perturbations on the camera poses (Blender only)\n    noise_r: 0.03                                           # synthetic perturbations on the camera poses (Blender only)\n    noise_t: 0.3                                            # synthetic perturbations on the camera poses (Blender only)\n\noptim:                                                      # optimization options\n    lr_pose: 1.e-3                                          # learning rate of camera poses\n    lr_pose_end: 1.e-5                                      # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)\n    sched_pose:                                             # learning rate scheduling options\n        type: ExponentialLR                                 # scheduler (see PyTorch doc)\n        gamma:                                              # decay rate (can be empty if lr_pose_end were specified)\n    warmup_pose:                                            # linear warmup of the pose learning rate (N iterations)\n    test_photo: true                                        # test-time photometric optimization for evaluation\n    test_iter: 100                                          # number of iterations for test-time optimization\n\nvisdom:                                                     # Visdom options\n    cam_depth: 0.5                                          # size of visualized cameras\n"
  },
  {
    "path": "options/barf_iphone.yaml",
    "content": "_parent_: options/barf_llff.yaml\n\ndata:                                                       # data options\n    dataset: iphone                                         # dataset name\n    scene: IMG_0239                                         # scene name\n    image_size: [480,640]                                   # input image sizes [height,width]\n"
  },
  {
    "path": "options/barf_llff.yaml",
    "content": "_parent_: options/nerf_llff.yaml\n\nbarf_c2f: [0.1,0.5]                                         # coarse-to-fine scheduling on positional encoding\n\ncamera:                                                     # camera options\n    noise:                                                  # synthetic perturbations on the camera poses (Blender only)\n\noptim:                                                      # optimization options\n    lr_pose: 3.e-3                                          # learning rate of camera poses\n    lr_pose_end: 1.e-5                                      # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)\n    sched_pose:                                             # learning rate scheduling options\n        type: ExponentialLR                                 # scheduler (see PyTorch doc)\n        gamma:                                              # decay rate (can be empty if lr_pose_end were specified)\n    warmup_pose:                                            # linear warmup of the pose learning rate (N iterations)\n    test_photo: true                                        # test-time photometric optimization for evaluation\n    test_iter: 100                                          # number of iterations for test-time optimization\n\nvisdom:                                                     # Visdom options\n    cam_depth: 0.2                                          # size of visualized cameras\n"
  },
  {
    "path": "options/base.yaml",
    "content": "# default\n\ngroup: 0_test                                               # name of experiment group\nname: debug                                                 # name of experiment run\nmodel:                                                      # type of model (must be specified from command line)\nyaml:                                                       # config file (must be specified from command line)\nseed: 0                                                     # seed number (for both numpy and pytorch)\ngpu: 0                                                      # GPU index number\ncpu: false                                                  # run only on CPU (not supported now)\nload:                                                       # load checkpoint from filename\n\narch: {}                                                    # architectural options\n\ndata:                                                       # data options\n    root:                                                   # root path to dataset\n    dataset:                                                # dataset name\n    image_size: [null,null]                                 # input image sizes [height,width]\n    num_workers: 8                                          # number of parallel workers for data loading\n    preload: false                                          # preload the entire dataset into the memory\n    augment: {}                                             # data augmentation (training only)\n        # rotate:                                           # random rotation\n        # brightness: # 0.2                                 # random brightness jitter\n        # contrast: # 0.2                                   # random contrast jitter\n        # saturation: # 0.2                                 # random saturation jitter\n        # hue: # 0.1                                        # random hue jitter\n        # hflip: # True                                     # random horizontal flip\n    center_crop:                                            # center crop the image by ratio\n    val_on_test: false                                      # validate on test set during training\n    train_sub:                                              # consider a subset of N training samples\n    val_sub:                                                # consider a subset of N validation samples\n\nloss_weight: {}                                             # loss weights (in log scale)\n\noptim:                                                      # optimization options\n    lr: 1.e-3                                               # learning rate (main)\n    lr_end:                                                 # terminal learning rate (only used with sched.type=ExponentialLR)\n    algo: Adam                                              # optimizer (see PyTorch doc)\n    sched: {}                                               # learning rate scheduling options\n        # type: StepLR                                      # scheduler (see PyTorch doc)\n        # steps:                                            # decay every N epochs\n        # gamma: 0.1                                        # decay rate (can be empty if lr_end were specified)\n\nbatch_size: 16                                              # batch size\nmax_epoch: 1000                                             # train to maximum number of epochs\nresume: false                                               # 
resume training (true for latest checkpoint, or number for specific epoch number)\n\noutput_root: output                                         # root path for output files (checkpoints and results)\ntb:                                                         # TensorBoard options\n    num_images: [4,8]                                       # number of (tiled) images to visualize in TensorBoard\nvisdom:                                                     # Visdom options\n    server: localhost                                       # server to host Visdom\n    port: 8600                                              # port number for Visdom\n\nfreq:                                                       # periodic actions during training\n    scalar: 200                                             # log losses and scalar states (every N iterations)\n    vis: 1000                                               # visualize results (every N iterations)\n    val: 20                                                 # validate on val set (every N epochs)\n    ckpt: 50                                                # save checkpoint (every N epochs)\n"
  },
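`options/base.yaml` is the root of the config tree; every other option file below points at a parent via `_parent_`. A minimal sketch of how such chains are typically resolved (hypothetical helper; the repository's actual loader is not part of this listing and may differ):

```python
import yaml

def load_options(fname):
    with open(fname) as f:
        opts = yaml.safe_load(f)
    parent = opts.pop("_parent_", None)
    if parent is None:
        return opts
    return deep_merge(load_options(parent), opts)  # child overrides parent

def deep_merge(base, child):
    out = dict(base)
    for k, v in child.items():
        if isinstance(v, dict) and isinstance(out.get(k), dict):
            out[k] = deep_merge(out[k], v)   # merge nested sections
        else:
            out[k] = v                       # scalars/lists: plain override
    return out

# e.g. load_options("options/l2g_nerf_blender.yaml") would walk the chain
# base.yaml -> nerf_blender.yaml -> barf_blender.yaml -> l2g_nerf_blender.yaml
```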
  {
    "path": "options/l2g_nerf_blender.yaml",
    "content": "_parent_: options/barf_blender.yaml\n\narch:                                                       # architectural options\n    layers_warp: [null,256,256,256,256,256,256,6]           # hidden layers for MLP\n    skip_warp: [4]                                          # skip connections\n    embedding_dim: 128                                      # embedding dim\n\noptim:                                                      # optimization options\n    lr_pose: 1.e-3                                          # learning rate of camera poses\n    lr_pose_end: 1.e-8                                      # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)\n\nloss_weight:                                                # loss weights (in log scale)\n    global_alignment: 2                                     # global alignment loss\n\nerror_map_size:"
  },
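The `layers_warp` lists follow the same convention as the other layer lists: consecutive entries give (input, output) widths, with the leading `null` standing in for an input dimension that is only known at build time. A plausible reconstruction of `util.get_layer_dims`, consistent with how `define_network` consumes it (hypothetical; the real helper lives in `util.py`):

```python
def get_layer_dims(layers):
    # consecutive (k_in, k_out) pairs; the leading None is replaced by the
    # actual input dimension when the network is built
    return list(zip(layers[:-1], layers[1:]))

print(get_layer_dims([None, 256, 256, 6]))
# -> [(None, 256), (256, 256), (256, 6)]
```

The trailing output width is the per-point warp parameter count: 6 for the NeRF models (presumably an se(3) vector) and 3 for the SE(2) planar case below.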
  {
    "path": "options/l2g_nerf_iphone.yaml",
    "content": "_parent_: options/l2g_nerf_llff.yaml\n\ndata:                                                       # data options\n    dataset: iphone                                         # dataset name\n    scene: IMG_0239                                         # scene name\n    image_size: [480,640]                                   # input image sizes [height,width]\n\noptim:                                                      # optimization options\n    test_iter: 1000                                         # number of iterations for test-time optimization"
  },
  {
    "path": "options/l2g_nerf_llff.yaml",
    "content": "_parent_: options/barf_llff.yaml\n\narch:                                                       # architectural options\n    layers_warp: [null,256,256,256,256,256,256,6]           # hidden layers for MLP\n    skip_warp: [4]                                          # skip connections\n    embedding_dim: 128                                      # embedding dim\n\noptim:                                                      # optimization options\n    lr_pose: 3.e-3                                          # learning rate of camera poses\n    lr_pose_end: 1.e-8                                      # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)\n\nloss_weight:                                                # loss weights (in log scale)\n    global_alignment: 2                                     # global alignment loss\n\nerror_map_size:"
  },
  {
    "path": "options/l2g_planar.yaml",
    "content": "_parent_: options/planar.yaml\n\narch:                                                       # architectural options\n    layers_warp: [null,256,256,256,256,256,256,3]           # hidden layers for MLP\n    skip_warp: [4]                                          # skip connections\n    embedding_dim: 128                                      # embedding dim\n\nloss_weight:                                                # loss weights (in log scale)\n    global_alignment: 2                                     # global alignment loss\n\noptim:\n    lr: 1.e-3                                               # learning rate (main)\n    lr_end: 1.e-4                                           # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)\n    sched:                                                  # learning rate scheduling options\n        type: ExponentialLR                                 # scheduler (see PyTorch doc)\n        gamma:                                              # decay rate (can be empty if lr_pose_end were specified)\n    lr_warp: 1.e-3                                          # learning rate of camera poses\n    lr_warp_end: 1.e-5                                      # terminal learning rate of camera poses (only used with sched_pose.type=ExponentialLR)\n    sched_warp:                                             # learning rate scheduling options\n        type: ExponentialLR                                 # scheduler (see PyTorch doc)\n        gamma:                                              # decay rate (can be empty if lr_pose_end were specified)"
  },
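The `loss_weight` comments mark the values as log-scale. Assuming base 10 (consistent with `render: 0` acting as a unit weight), `global_alignment: 2` would scale that term by 10^2 = 100 when the per-term losses are summed, along the lines of this hypothetical illustration (the actual weighting is applied in the base model class, which is outside this listing):

```python
loss = {"render": 0.01, "global_alignment": 0.002}   # example raw loss values
loss_weight = {"render": 0, "global_alignment": 2}   # from the YAML above
total = sum(10 ** loss_weight[k] * v for k, v in loss.items())
print(total)  # 0.21: render weighted by 10**0 = 1, global_alignment by 10**2 = 100
```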
  {
    "path": "options/nerf_blender.yaml",
    "content": "_parent_: options/base.yaml\n\narch:                                                       # architectural options\n    layers_feat: [null,256,256,256,256,256,256,256,256]     # hidden layers for feature/density MLP\n    layers_rgb: [null,128,3]                                # hidden layers for color MLP\n    skip: [4]                                               # skip connections\n    posenc:                                                 # positional encoding\n        L_3D: 10                                            # number of bases (3D point)\n        L_view: 4                                           # number of bases (viewpoint)\n    density_activ: softplus                                 # activation function for output volume density\n    tf_init: true                                           # initialize network weights in TensorFlow style\n\nnerf:                                                       # NeRF-specific options\n    view_dep: true                                          # condition MLP on viewpoint\n    depth:                                                  # depth-related options\n        param: metric                                       # depth parametrization (for sampling along the ray)\n        range: [2,6]                                        # near/far bounds for depth sampling\n    sample_intvs: 128                                       # number of samples\n    sample_stratified: true                                 # stratified sampling\n    fine_sampling: false                                    # hierarchical sampling with another NeRF\n    sample_intvs_fine:                                      # number of samples for the fine NeRF\n    rand_rays: 1024                                         # number of random rays for each step\n    density_noise_reg:                                      # Gaussian noise on density output as regularization\n    setbg_opaque: false                                     # fill transparent rendering with known background color (Blender only)\n\ndata:                                                       # data options\n    dataset: blender                                        # dataset name\n    scene: lego                                             # scene name\n    image_size: [400,400]                                   # input image sizes [height,width]\n    num_workers: 4                                          # number of parallel workers for data loading\n    preload: true                                           # preload the entire dataset into the memory\n    bgcolor: 1                                              # background color (Blender only)\n    val_sub: 4                                              # consider a subset of N validation samples\n\ncamera:                                                     # camera options\n    model: perspective                                      # type of camera model\n    ndc: false                                              # reparametrize as normalized device coordinates (NDC)\n\nloss_weight:                                                # loss weights (in log scale)\n    render: 0                                               # RGB rendering loss\n    render_fine:                                            # RGB rendering loss (for fine NeRF)\n\noptim:                                                      # optimization options\n    lr: 5.e-4                                               # learning rate (main)\n    lr_end: 1.e-4     
                                      # terminal learning rate (only used with sched.type=ExponentialLR)\n    sched:                                                  # learning rate scheduling options\n        type: ExponentialLR                                 # scheduler (see PyTorch doc)\n        gamma:                                              # decay rate (can be empty if lr_end were specified)\n\nbatch_size:                                                 # batch size (not used for NeRF/BARF)\nmax_epoch:                                                  # train to maximum number of epochs (not used for NeRF/BARF)\nmax_iter: 200000                                            # train to maximum number of iterations\n\ntrimesh:                                                    # options for marching cubes to extract 3D mesh\n    res: 128                                                # 3D sampling resolution\n    range: [-1.2,1.2]                                       # 3D range of interest (assuming same for x,y,z)\n    thres: 25.                                              # volume density threshold for marching cubes\n    chunk_size: 16384                                       # chunk size of dense samples to be evaluated at a time\n\nfreq:                                                       # periodic actions during training\n    scalar: 200                                             # log losses and scalar states (every N iterations)\n    vis: 1000                                               # visualize results (every N iterations)\n    val: 2000                                               # validate on val set (every N iterations)\n    ckpt: 5000                                              # save checkpoint (every N iterations)\n"
  },
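The Blender config above gives `optim.lr` and `optim.lr_end` but leaves `sched.gamma` empty. A minimal sketch of the implied relationship, under the assumption that an empty `gamma` is derived so the learning rate decays from `lr` to `lr_end` over exactly `max_iter` steps (variable names here are illustrative):

```python
# Sketch: deriving ExponentialLR's per-step gamma from lr/lr_end/max_iter.
# Assumption: an empty sched.gamma is computed from lr_end in this way.
lr, lr_end, max_iter = 5e-4, 1e-4, 200000       # values from the config above
gamma = (lr_end / lr) ** (1.0 / max_iter)       # per-iteration decay factor
print(f"gamma = {gamma:.8f}")                   # ~0.99999195
print(f"lr after max_iter steps = {lr * gamma**max_iter:.2e}")  # ~1.00e-04
```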
  {
    "path": "options/nerf_blender_repr.yaml",
    "content": "_parent_: options/base.yaml\n\narch:                                                       # architectural options\n    layers_feat: [null,256,256,256,256,256,256,256,256]     # hidden layers for feature/density MLP\n    layers_rgb: [null,128,3]                                # hidden layers for color MLP\n    skip: [4]                                               # skip connections\n    posenc:                                                 # positional encoding\n        L_3D: 10                                            # number of bases (3D point)\n        L_view: 4                                           # number of bases (viewpoint)\n    density_activ: relu                                     # activation function for output volume density\n    tf_init: true                                           # initialize network weights in TensorFlow style\n\nnerf:                                                       # NeRF-specific options\n    view_dep: true                                          # condition MLP on viewpoint\n    depth:                                                  # depth-related options\n        param: metric                                       # depth parametrization (for sampling along the ray)\n        range: [2,6]                                        # near/far bounds for depth sampling\n    sample_intvs: 64                                        # number of samples\n    sample_stratified: true                                 # stratified sampling\n    fine_sampling: true                                     # hierarchical sampling with another NeRF\n    sample_intvs_fine: 128                                  # number of samples for the fine NeRF\n    rand_rays: 1024                                         # number of random rays for each step\n    density_noise_reg: 0                                    # Gaussian noise on density output as regularization\n    setbg_opaque: true                                      # fill transparent rendering with known background color (Blender only)\n\ndata:                                                       # data options\n    dataset: blender                                        # dataset name\n    scene: lego                                             # scene name\n    image_size: [400,400]                                   # input image sizes [height,width]\n    num_workers: 4                                          # number of parallel workers for data loading\n    preload: true                                           # preload the entire dataset into the memory\n    bgcolor: 1                                              # background color (Blender only)\n    val_sub: 4                                              # consider a subset of N validation samples\n\ncamera:                                                     # camera options\n    model: perspective                                      # type of camera model\n    ndc: false                                              # reparametrize as normalized device coordinates (NDC)\n\nloss_weight:                                                # loss weights (in log scale)\n    render: 0                                               # RGB rendering loss\n    render_fine: 0                                          # RGB rendering loss (for fine NeRF)\n\noptim:                                                      # optimization options\n    lr: 5.e-4                                               # learning rate (main)\n    lr_end: 5.e-5     
                                      # terminal learning rate (only used with sched.type=ExponentialLR)\n    sched:                                                  # learning rate scheduling options\n        type: ExponentialLR                                 # scheduler (see PyTorch doc)\n        gamma:                                              # decay rate (can be empty if lr_end were specified)\n\nbatch_size:                                                 # batch size (not used for NeRF/BARF)\nmax_epoch:                                                  # train to maximum number of epochs (not used for NeRF/BARF)\nmax_iter: 500000                                            # train to maximum number of iterations\n\ntrimesh:                                                    # options for marching cubes to extract 3D mesh\n    res: 128                                                # 3D sampling resolution\n    range: [-1.2,1.2]                                       # 3D range of interest (assuming same for x,y,z)\n    thres: 25.                                              # volume density threshold for marching cubes\n    chunk_size: 16384                                       # chunk size of dense samples to be evaluated at a time\n\nfreq:                                                       # periodic actions during training\n    scalar: 200                                             # log losses and scalar states (every N iterations)\n    vis: 1000                                               # visualize results (every N iterations)\n    val: 2000                                               # validate on val set (every N iterations)\n    ckpt: 5000                                              # save checkpoint (every N iterations)\n"
  },
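This `repr` config enables `fine_sampling` with `sample_intvs_fine: 128`. In the original NeRF, hierarchical sampling draws those extra depths by inverse-transform sampling from the coarse network's ray weights; a minimal sketch of the idea (illustrative names, not this repo's API):

```python
import torch

# Inverse-transform sampling of fine depths from coarse weights (NeRF-style).
# bin_edges: [N+1] depth bin boundaries, weights: [N] coarse ray weights.
def sample_fine_depths(bin_edges, weights, n_fine=128):
    pdf = weights / (weights.sum() + 1e-8)               # normalize to a PDF
    cdf = torch.cat([torch.zeros(1), torch.cumsum(pdf, dim=0)])  # [N+1]
    u = torch.rand(n_fine)                               # uniform draws
    idx = torch.searchsorted(cdf, u, right=True).clamp(1, len(weights))
    t = (u - cdf[idx - 1]) / (cdf[idx] - cdf[idx - 1] + 1e-8)
    return bin_edges[idx - 1] + t * (bin_edges[idx] - bin_edges[idx - 1])

edges = torch.linspace(2, 6, 65)                         # depth.range: [2,6]
w = torch.rand(64)                                       # fake coarse weights
print(sample_fine_depths(edges, w).shape)                # torch.Size([128])
```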
  {
    "path": "options/nerf_llff.yaml",
    "content": "_parent_: options/base.yaml\n\narch:                                                       # architectural optionss\n    layers_feat: [null,256,256,256,256,256,256,256,256]     # hidden layers for feature/density MLP]\n    layers_rgb: [null,128,3]                                # hidden layers for color MLP]\n    skip: [4]                                               # skip connections\n    posenc:                                                 # positional encoding:\n        L_3D: 10                                            # number of bases (3D point)\n        L_view: 4                                           # number of bases (viewpoint)\n    density_activ: softplus                                 # activation function for output volume density\n    tf_init: true                                           # initialize network weights in TensorFlow style\n\nnerf:                                                       # NeRF-specific options\n    view_dep: true                                          # condition MLP on viewpoint\n    depth:                                                  # depth-related options\n        param: inverse                                      # depth parametrization (for sampling along the ray)\n        range: [1,0]                                        # near/far bounds for depth sampling\n    sample_intvs: 128                                       # number of samples\n    sample_stratified: true                                 # stratified sampling\n    fine_sampling: false                                    # hierarchical sampling with another NeRF\n    sample_intvs_fine:                                      # number of samples for the fine NeRF\n    rand_rays: 2048                                         # number of random rays for each step\n    density_noise_reg:                                      # Gaussian noise on density output as regularization\n    setbg_opaque:                                           # fill transparent rendering with known background color (Blender only)\n\ndata:                                                       # data options\n    dataset: llff                                           # dataset name\n    scene: fern                                             # scene name\n    image_size: [480,640]                                   # input image sizes [height,width]\n    num_workers: 4                                          # number of parallel workers for data loading\n    preload: true                                           # preload the entire dataset into the memory\n    val_ratio: 0.1                                          # ratio of sequence split for validation\n\ncamera:                                                     # camera options\n    model: perspective                                      # type of camera model\n    ndc: false                                              # reparametrize as normalized device coordinates (NDC)\n\nloss_weight:                                                # loss weights (in log scale)\n    render: 0                                               # RGB rendering loss\n    render_fine:                                            # RGB rendering loss (for fine NeRF)\n\noptim:                                                      # optimization options\n    lr: 1.e-3                                               # learning rate (main)\n    lr_end: 1.e-4                                           # terminal learning rate (only used with 
sched.type=ExponentialLR)\n    sched:                                                  # learning rate scheduling options\n        type: ExponentialLR                                 # scheduler (see PyTorch doc)\n        gamma:                                              # decay rate (can be empty if lr_end were specified)\n\nbatch_size:                                                 # batch size (not used for NeRF/BARF)\nmax_epoch:                                                  # train to maximum number of epochs (not used for NeRF/BARF)\nmax_iter: 200000                                            # train to maximum number of iterations\n\nfreq:                                                       # periodic actions during training\n    scalar: 200                                             # log losses and scalar states (every N iterations)\n    vis: 1000                                               # visualize results (every N iterations)\n    val: 2000                                               # validate on val set (every N iterations)\n    ckpt: 5000                                              # save checkpoint (every N iterations)\n"
  },
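Unlike the Blender configs, this one sets `depth.param: inverse` with `range: [1,0]`. A common reading, and an assumption here about this codebase, is that samples are spaced uniformly in inverse depth with the range giving the inverse-depth bounds, so the far plane sits at infinity, which suits unbounded forward-facing LLFF scenes:

```python
import torch

# Sketch: uniform sampling in inverse depth. With bounds [1,0] the samples
# span metric depths from 1 up to (numerically) infinity.
# Assumed convention for illustration, not this repo's API.
inv_near, inv_far = 1.0, 0.0             # depth.range: [1,0], read as 1/depth
t = torch.linspace(0, 1, 128)            # sample_intvs: 128
inv_depth = inv_near * (1 - t) + inv_far * t
depth = 1.0 / inv_depth.clamp(min=1e-8)  # metric depth along the ray
print(depth[0].item(), depth[-1].item()) # 1.0 ... 1e8 (numerical infinity)
```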
  {
    "path": "options/nerf_llff_repr.yaml",
    "content": "_parent_: options/base.yaml\n\narch:                                                       # architectural options\n    layers_feat: [null,256,256,256,256,256,256,256,256]     # hidden layers for feature/density MLP\n    layers_rgb: [null,128,3]                                # hidden layers for color MLP\n    skip: [4]                                               # skip connections\n    posenc:                                                 # positional encoding\n        L_3D: 10                                            # number of bases (3D point)\n        L_view: 4                                           # number of bases (viewpoint)\n    density_activ: relu                                     # activation function for output volume density\n    tf_init: true                                           # initialize network weights in TensorFlow style\n\nnerf:                                                       # NeRF-specific options\n    view_dep: true                                          # condition MLP on viewpoint\n    depth:                                                  # depth-related options\n        param: metric                                       # depth parametrization (for sampling along the ray)\n        range: [0,1]                                        # near/far bounds for depth sampling\n    sample_intvs: 64                                        # number of samples\n    sample_stratified: true                                 # stratified sampling\n    fine_sampling: true                                     # hierarchical sampling with another NeRF\n    sample_intvs_fine: 128                                  # number of samples for the fine NeRF\n    rand_rays: 1024                                         # number of random rays for each step\n    density_noise_reg: 1                                    # Gaussian noise on density output as regularization\n    setbg_opaque:                                           # fill transparent rendering with known background color (Blender only)\n\ndata:                                                       # data options\n    dataset: llff                                           # dataset name\n    scene: fern                                             # scene name\n    image_size: [480,640]                                   # input image sizes [height,width]\n    num_workers: 4                                          # number of parallel workers for data loading\n    preload: true                                           # preload the entire dataset into the memory\n    val_ratio: 0.1                                          # ratio of sequence split for validation\n\ncamera:                                                     # camera options\n    model: perspective                                      # type of camera model\n    ndc: false                                              # reparametrize as normalized device coordinates (NDC)\n\nloss_weight:                                                # loss weights (in log scale)\n    render: 0                                               # RGB rendering loss\n    render_fine: 0                                          # RGB rendering loss (for fine NeRF)\n\noptim:                                                      # optimization options\n    lr: 5.e-4                                               # learning rate (main)\n    lr_end: 5.e-5                                           # terminal learning rate (only used with 
sched.type=ExponentialLR)\n    sched:                                                  # learning rate scheduling options\n        type: ExponentialLR                                 # scheduler (see PyTorch doc)\n        gamma:                                              # decay rate (can be empty if lr_end were specified)\n\nbatch_size:                                                 # batch size (not used for NeRF/BARF)\nmax_epoch:                                                  # train to maximum number of epochs (not used for NeRF/BARF)\nmax_iter: 500000                                            # train to maximum number of iterations\n\nfreq:                                                       # periodic actions during training\n    scalar: 200                                             # log losses and scalar states (every N iterations)\n    vis: 1000                                               # visualize results (every N iterations)\n    val: 2000                                               # validate on val set (every N iterations)\n    ckpt: 5000                                              # save checkpoint (every N iterations)\n"
  },
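`density_noise_reg: 1` follows the original NeRF recipe of perturbing the raw density prediction with Gaussian noise before the activation during training, which discourages semi-transparent floaters. A minimal sketch of the mechanism (shapes are illustrative):

```python
import torch

# Gaussian noise on raw density before the activation (train-time only).
raw_density = torch.randn(1024, 64)        # [rays, samples], raw MLP output
noise_std = 1.0                            # density_noise_reg: 1
raw_density = raw_density + torch.randn_like(raw_density) * noise_std
density = torch.relu(raw_density)          # density_activ: relu
print(density.shape)                       # torch.Size([1024, 64])
```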
  {
    "path": "options/planar.yaml",
    "content": "_parent_: options/base.yaml\n\narch:                                                       # architectural options\n    layers: [null,256,256,256,256,3]                        # hidden layers for MLP\n    skip: []                                                # skip connections\n    posenc:                                                 # positional encoding\n        L_2D: 8                                             # number of bases (3D point)\n\nbarf_c2f: [0,0.4]                                           # coarse-to-fine scheduling on positional encoding\n\ndata:                                                       # data options\n    image_fname: data/cat.jpg                               # path to image file\n    image_size: [360,480]                                   # original image size\n    patch_crop: [180,180]                                   # crop size of image patches to align\n\nwarp:                                                       # image warping options\n    type: homography                                        # type of warp function\n    dof: 8                                                  # degrees of freedom of the warp function\n    noise_h: 0.3                                            # scale of pre-generated warp perturbation (homography)\n    noise_t: 0.1                                            # scale of pre-generated warp perturbation (translation)\n    noise_r: 1.0                                            # scale of pre-generated warp perturbation (rotation)\n    fix_first: true                                         # fix the first patch for uniqueness of solution\n\nloss_weight:                                                # loss weights (in log scale)\n    render: 0                                               # RGB rendering loss\n\noptim:                                                      # optimization options\n    lr: 1.e-3                                               # learning rate (main)\n    lr_warp: 1.e-3                                          # learning rate of warp parameters\n\nbatch_size: 5                                               # batch size (set to number of patches to consider)\nmax_iter: 5000                                              # train to maximum number of iterations\n\nvisdom:                                                     # Visdom options (turned off)\n\nfreq:                                                       # periodic actions during training\n    scalar: 20                                              # log losses and scalar states (every N iterations)\n    vis: 100                                                # visualize results (every N iterations)\n"
  },
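`barf_c2f: [0,0.4]` is BARF's coarse-to-fine schedule: each positional-encoding frequency band is gated by a smooth weight that opens progressively over the first 40% of training. A sketch of the standard BARF weighting, assuming `progress` is the fraction of completed iterations:

```python
import math
import torch

# BARF-style coarse-to-fine gating of positional-encoding frequency bands.
def c2f_weights(progress, L=8, c2f=(0.0, 0.4)):    # L_2D: 8, barf_c2f: [0,0.4]
    start, end = c2f
    alpha = (progress - start) / (end - start) * L # ramps 0 -> L over window
    k = torch.arange(L, dtype=torch.float32)       # band indices 0..L-1
    x = (alpha - k).clamp(0, 1)                    # per-band ramp position
    return (1 - torch.cos(x * math.pi)) / 2       # cosine ease-in, in [0,1]

print(c2f_weights(0.0))  # all zeros: only raw coordinates at the start
print(c2f_weights(0.2))  # low frequencies on, high frequencies still gated
print(c2f_weights(0.4))  # all ones: full positional encoding from here on
```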
  {
    "path": "options.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport random\nimport string\nimport yaml\nfrom easydict import EasyDict as edict\n\nimport util\nfrom util import log\n\n# torch.backends.cudnn.enabled = False\n# torch.backends.cudnn.benchmark = False\n# torch.backends.cudnn.deterministic = True\n\ndef parse_arguments(args):\n    \"\"\"\n    Parse arguments from command line.\n    Syntax: --key1.key2.key3=value --> value\n            --key1.key2.key3=      --> None\n            --key1.key2.key3       --> True\n            --key1.key2.key3!      --> False\n    \"\"\"\n    opt_cmd = {}\n    for arg in args:\n        assert(arg.startswith(\"--\"))\n        if \"=\" not in arg[2:]:\n            key_str,value = (arg[2:-1],\"false\") if arg[-1]==\"!\" else (arg[2:],\"true\")\n        else:\n            key_str,value = arg[2:].split(\"=\")\n        keys_sub = key_str.split(\".\")\n        opt_sub = opt_cmd\n        for k in keys_sub[:-1]:\n            if k not in opt_sub: opt_sub[k] = {}\n            opt_sub = opt_sub[k]\n        assert keys_sub[-1] not in opt_sub,keys_sub[-1]\n        opt_sub[keys_sub[-1]] = yaml.safe_load(value)\n    opt_cmd = edict(opt_cmd)\n    return opt_cmd\n\ndef set(opt_cmd={}):\n    log.info(\"setting configurations...\")\n    assert(\"model\" in opt_cmd)\n    # load config from yaml file\n    assert(\"yaml\" in opt_cmd)\n    fname = \"options/{}.yaml\".format(opt_cmd.yaml)\n    opt_base = load_options(fname)\n    # override with command line arguments\n    opt = override_options(opt_base,opt_cmd,key_stack=[],safe_check=True)\n    process_options(opt)\n    log.options(opt)\n    return opt\n\ndef load_options(fname):\n    with open(fname) as file:\n        opt = edict(yaml.safe_load(file))\n    if \"_parent_\" in opt:\n        # load parent yaml file(s) as base options\n        parent_fnames = opt.pop(\"_parent_\")\n        if type(parent_fnames) is str:\n            parent_fnames = [parent_fnames]\n        for parent_fname in parent_fnames:\n            opt_parent = load_options(parent_fname)\n            opt_parent = override_options(opt_parent,opt,key_stack=[])\n            opt = opt_parent\n    print(\"loading {}...\".format(fname))\n    return opt\n\ndef override_options(opt,opt_over,key_stack=None,safe_check=False):\n    for key,value in opt_over.items():\n        if isinstance(value,dict):\n            # parse child options (until leaf nodes are reached)\n            opt[key] = override_options(opt.get(key,dict()),value,key_stack=key_stack+[key],safe_check=safe_check)\n        else:\n            # ensure command line argument to override is also in yaml file\n            if safe_check and key not in opt:\n                add_new = None\n                while add_new not in [\"y\",\"n\"]:\n                    key_str = \".\".join(key_stack+[key])\n                    add_new = input(\"\\\"{}\\\" not found in original opt, add? 
(y/n) \".format(key_str))\n                if add_new==\"n\":\n                    print(\"safe exiting...\")\n                    exit()\n            opt[key] = value\n    return opt\n\ndef process_options(opt):\n    # set seed\n    if opt.seed is not None:\n        random.seed(opt.seed)\n        np.random.seed(opt.seed)\n        torch.manual_seed(opt.seed)\n        torch.cuda.manual_seed_all(opt.seed)\n        if opt.seed!=0:\n            opt.name = str(opt.name)+\"_seed{}\".format(opt.seed)\n    else:\n        # create random string as run ID\n        randkey = \"\".join(random.choice(string.ascii_uppercase) for _ in range(4))\n        opt.name = str(opt.name)+\"_{}\".format(randkey)\n    # other default options\n    opt.output_path = \"{0}/{1}/{2}\".format(opt.output_root,opt.group,opt.name)\n    os.makedirs(opt.output_path,exist_ok=True)\n    assert(isinstance(opt.gpu,int)) # disable multi-GPU support for now, single is enough\n    opt.device = \"cpu\" if opt.cpu or not torch.cuda.is_available() else \"cuda:{}\".format(opt.gpu)\n    opt.H,opt.W = opt.data.image_size\n\ndef save_options_file(opt):\n    opt_fname = \"{}/options.yaml\".format(opt.output_path)\n    if os.path.isfile(opt_fname):\n        with open(opt_fname) as file:\n            opt_old = yaml.safe_load(file)\n        if opt!=opt_old:\n            # prompt if options are not identical\n            opt_new_fname = \"{}/options_temp.yaml\".format(opt.output_path)\n            with open(opt_new_fname,\"w\") as file:\n                yaml.safe_dump(util.to_dict(opt),file,default_flow_style=False,indent=4)\n            print(\"existing options file found (different from current one)...\")\n            os.system(\"diff {} {}\".format(opt_fname,opt_new_fname))\n            os.system(\"rm {}\".format(opt_new_fname))\n            override = None\n            while override not in [\"y\",\"n\"]:\n                override = input(\"override? (y/n) \")\n            if override==\"n\":\n                print(\"safe exiting...\")\n                exit()\n        else: print(\"existing options file found (identical)\")\n    else: print(\"(creating new options file...)\")\n    with open(opt_fname,\"w\") as file:\n        yaml.safe_dump(util.to_dict(opt),file,default_flow_style=False,indent=4)\n"
  },
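The four command-line forms documented in `parse_arguments` map onto nested `edict` entries; for example (assuming the repo root is on `sys.path`):

```python
import options

# One example of each form from the docstring: =value, trailing =, bare flag,
# and trailing bang.
opt_cmd = options.parse_arguments([
    "--model=l2g_nerf",    # --key=value  -> YAML-parsed value
    "--data.val_sub=",     # --key=       -> None
    "--resume",            # --key        -> True
    "--camera.ndc!",       # --key!       -> False
])
print(opt_cmd.model, opt_cmd.data.val_sub, opt_cmd.resume, opt_cmd.camera.ndc)
# l2g_nerf None True False
```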
  {
    "path": "requirements.yaml",
    "content": "name: L2G-NeRF\nchannels:\n  - conda-forge\n  - pytorch\ndependencies:\n  - numpy\n  - scipy\n  - tqdm\n  - termcolor\n  - easydict\n  - imageio\n  - ipdb\n  - pytorch>=1.9.0\n  - torchvision\n  - tensorboard\n  - visdom\n  - matplotlib\n  - scikit-video\n  - trimesh\n  - pyyaml\n  - pip\n  - gdown\n  - pip:\n    - lpips\n    - pymcubes\n    - roma\n    - kornia\n"
  },
  {
    "path": "train.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport importlib\n\nimport options\nfrom util import log\n\nimport warnings\nwarnings.filterwarnings('ignore')\n\ndef main():\n\n    log.process(os.getpid())\n    log.title(\"[{}] (PyTorch code for training NeRF/BARF/L2G_NeRF)\".format(sys.argv[0]))\n\n    opt_cmd = options.parse_arguments(sys.argv[1:])\n    opt = options.set(opt_cmd=opt_cmd)\n    options.save_options_file(opt)\n\n    with torch.cuda.device(opt.device):\n\n        model = importlib.import_module(\"model.{}\".format(opt.model))\n        m = model.Model(opt)\n\n        m.load_dataset(opt)\n        m.build_networks(opt)\n        m.setup_optimizer(opt)\n        m.restore_checkpoint(opt)\n        m.setup_visualizer(opt)\n\n        m.train(opt)\n\nif __name__==\"__main__\":\n    main()\n"
  },
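`train.py` resolves the model implementation dynamically from `--model`, so adding a new model only requires dropping a module with a `Model` class into `model/`. Conceptually (a sketch run from the repo root; the full options object from `options.set` is still needed to instantiate it):

```python
import importlib

# --model=l2g_nerf makes train.py import model/l2g_nerf.py and use its Model.
model_name = "l2g_nerf"                               # from opt.model
module = importlib.import_module("model.{}".format(model_name))
# m = module.Model(opt)  # then load_dataset/build_networks/train as above
```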
  {
    "path": "util.py",
    "content": "import numpy as np\nimport os,sys,time\nimport shutil\nimport datetime\nimport torch\nimport torch.nn.functional as torch_F\nimport ipdb\nimport types\nimport termcolor\nimport socket\nimport contextlib\nfrom easydict import EasyDict as edict\n\n# convert to colored strings\ndef red(message,**kwargs): return termcolor.colored(str(message),color=\"red\",attrs=[k for k,v in kwargs.items() if v is True])\ndef green(message,**kwargs): return termcolor.colored(str(message),color=\"green\",attrs=[k for k,v in kwargs.items() if v is True])\ndef blue(message,**kwargs): return termcolor.colored(str(message),color=\"blue\",attrs=[k for k,v in kwargs.items() if v is True])\ndef cyan(message,**kwargs): return termcolor.colored(str(message),color=\"cyan\",attrs=[k for k,v in kwargs.items() if v is True])\ndef yellow(message,**kwargs): return termcolor.colored(str(message),color=\"yellow\",attrs=[k for k,v in kwargs.items() if v is True])\ndef magenta(message,**kwargs): return termcolor.colored(str(message),color=\"magenta\",attrs=[k for k,v in kwargs.items() if v is True])\ndef grey(message,**kwargs): return termcolor.colored(str(message),color=\"grey\",attrs=[k for k,v in kwargs.items() if v is True])\n\ndef get_time(sec):\n    d = int(sec//(24*60*60))\n    h = int(sec//(60*60)%24)\n    m = int((sec//60)%60)\n    s = int(sec%60)\n    return d,h,m,s\n\ndef add_datetime(func):\n    def wrapper(*args,**kwargs):\n        datetime_str = datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n        print(grey(\"[{}] \".format(datetime_str),bold=True),end=\"\")\n        return func(*args,**kwargs)\n    return wrapper\n\ndef add_functionname(func):\n    def wrapper(*args,**kwargs):\n        print(grey(\"[{}] \".format(func.__name__),bold=True))\n        return func(*args,**kwargs)\n    return wrapper\n\ndef pre_post_actions(pre=None,post=None):\n    def func_decorator(func):\n        def wrapper(*args,**kwargs):\n            if pre: pre()\n            retval = func(*args,**kwargs)\n            if post: post()\n            return retval\n        return wrapper\n    return func_decorator\n\ndebug = ipdb.set_trace\n\nclass Log:\n    def __init__(self): pass\n    def process(self,pid):\n        print(grey(\"Process ID: {}\".format(pid),bold=True))\n    def title(self,message):\n        print(yellow(message,bold=True,underline=True))\n    def info(self,message):\n        print(magenta(message,bold=True))\n    def options(self,opt,level=0):\n        for key,value in sorted(opt.items()):\n            if isinstance(value,(dict,edict)):\n                print(\"   \"*level+cyan(\"* \")+green(key)+\":\")\n                self.options(value,level+1)\n            else:\n                print(\"   \"*level+cyan(\"* \")+green(key)+\":\",yellow(value))\n    def loss_train(self,opt,ep,lr,loss,timer):\n        if not opt.max_epoch: return\n        message = grey(\"[train] \",bold=True)\n        message += \"epoch {}/{}\".format(cyan(ep,bold=True),opt.max_epoch)\n        message += \", lr:{}\".format(yellow(\"{:.2e}\".format(lr),bold=True))\n        message += \", loss:{}\".format(red(\"{:.3e}\".format(loss),bold=True))\n        message += \", time:{}\".format(blue(\"{0}-{1:02d}:{2:02d}:{3:02d}\".format(*get_time(timer.elapsed)),bold=True))\n        message += \" (ETA:{})\".format(blue(\"{0}-{1:02d}:{2:02d}:{3:02d}\".format(*get_time(timer.arrival))))\n        print(message)\n    def loss_val(self,opt,loss):\n        message = grey(\"[val] \",bold=True)\n        message += 
\"loss:{}\".format(red(\"{:.3e}\".format(loss),bold=True))\n        print(message)\nlog = Log()\n\ndef update_timer(opt,timer,ep,it_per_ep):\n    if not opt.max_epoch: return\n    momentum = 0.99\n    timer.elapsed = time.time()-timer.start\n    timer.it = timer.it_end-timer.it_start\n    # compute speed with moving average\n    timer.it_mean = timer.it_mean*momentum+timer.it*(1-momentum) if timer.it_mean is not None else timer.it\n    timer.arrival = timer.it_mean*it_per_ep*(opt.max_epoch-ep)\n\n# move tensors to device in-place\ndef move_to_device(X,device):\n    if isinstance(X,dict):\n        for k,v in X.items():\n            X[k] = move_to_device(v,device)\n    elif isinstance(X,list):\n        for i,e in enumerate(X):\n            X[i] = move_to_device(e,device)\n    elif isinstance(X,tuple) and hasattr(X,\"_fields\"): # collections.namedtuple\n        dd = X._asdict()\n        dd = move_to_device(dd,device)\n        return type(X)(**dd)\n    elif isinstance(X,torch.Tensor):\n        return X.to(device=device)\n    return X\n\ndef to_dict(D,dict_type=dict):\n    D = dict_type(D)\n    for k,v in D.items():\n        if isinstance(v,dict):\n            D[k] = to_dict(v,dict_type)\n    return D\n\ndef get_child_state_dict(state_dict,key):\n    return { \".\".join(k.split(\".\")[1:]): v for k,v in state_dict.items() if k.startswith(\"{}.\".format(key)) }\n\ndef restore_checkpoint(opt,model,load_name=None,resume=False):\n    assert((load_name is None)==(resume is not False)) # resume can be True/False or epoch numbers\n    if resume:\n        load_name = \"{0}/model.ckpt\".format(opt.output_path) if resume is True else \\\n                    \"{0}/model/{1}.ckpt\".format(opt.output_path,resume)\n    checkpoint = torch.load(load_name,map_location=opt.device)\n    # load individual (possibly partial) children modules\n    for name,child in model.graph.named_children():\n        child_state_dict = get_child_state_dict(checkpoint[\"graph\"],name)\n        if child_state_dict:\n            print(\"restoring {}...\".format(name))\n            child.load_state_dict(child_state_dict)\n    for key in model.__dict__:\n        if key.split(\"_\")[0] in [\"optim\",\"sched\"] and key in checkpoint and resume:\n            print(\"restoring {}...\".format(key))\n            getattr(model,key).load_state_dict(checkpoint[key])\n    if resume:\n        ep,it = checkpoint[\"epoch\"],checkpoint[\"iter\"]\n        if resume is not True: assert(resume==(ep or it))\n        print(\"resuming from epoch {0} (iteration {1})\".format(ep,it))\n    else: ep,it = None,None\n    return ep,it\n\ndef save_checkpoint(opt,model,ep,it,latest=False,children=None):\n    os.makedirs(\"{0}/model\".format(opt.output_path),exist_ok=True)\n    if children is not None:\n        graph_state_dict = { k: v for k,v in model.graph.state_dict().items() if k.startswith(children) }\n    else: graph_state_dict = model.graph.state_dict()\n    checkpoint = dict(\n        epoch=ep,\n        iter=it,\n        graph=graph_state_dict,\n    )\n    for key in model.__dict__:\n        if key.split(\"_\")[0] in [\"optim\",\"sched\"]:\n            checkpoint.update({ key: getattr(model,key).state_dict() })\n    torch.save(checkpoint,\"{0}/model.ckpt\".format(opt.output_path))\n    if not latest:\n        shutil.copy(\"{0}/model.ckpt\".format(opt.output_path),\n                    \"{0}/model/{1}.ckpt\".format(opt.output_path,ep or it)) # if ep is None, track it instead\n\ndef check_socket_open(hostname,port):\n    s = 
socket.socket(socket.AF_INET,socket.SOCK_STREAM)\n    is_open = False\n    try:\n        s.bind((hostname,port))\n    except socket.error:\n        is_open = True\n    finally:\n        s.close()\n    return is_open\n\ndef get_layer_dims(layers):\n    # return a list of tuples (k_in,k_out)\n    return list(zip(layers[:-1],layers[1:]))\n\n@contextlib.contextmanager\ndef suppress(stdout=False,stderr=False):\n    with open(os.devnull,\"w\") as devnull:\n        if stdout: old_stdout,sys.stdout = sys.stdout,devnull\n        if stderr: old_stderr,sys.stderr = sys.stderr,devnull\n        try: yield\n        finally:\n            if stdout: sys.stdout = old_stdout\n            if stderr: sys.stderr = old_stderr\n\ndef colorcode_to_number(code):\n    ords = [ord(c) for c in code[1:]]\n    ords = [n-48 if n<58 else n-87 for n in ords]\n    rgb = (ords[0]*16+ords[1],ords[2]*16+ords[3],ords[4]*16+ords[5])\n    return rgb\n"
  },
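A small illustration of how `get_layer_dims` turns the `layers_feat` lists from the YAML configs into per-layer `(k_in, k_out)` pairs (the leading `null` becomes `None`, to be filled with the encoded input width elsewhere); run with the repo's environment active:

```python
from util import get_layer_dims

layers_feat = [None, 256, 256, 256, 256, 256, 256, 256, 256]  # from the YAMLs
print(get_layer_dims(layers_feat))
# -> [(None, 256)] followed by seven (256, 256) pairs, i.e. 8 layers in total
```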
  {
    "path": "util_vis.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\nimport torchvision\nimport torchvision.transforms.functional as torchvision_F\nimport matplotlib.pyplot as plt\nfrom mpl_toolkits.mplot3d.art3d import Poly3DCollection\nimport PIL\nimport imageio\nfrom easydict import EasyDict as edict\n\nimport camera\n\n@torch.no_grad()\ndef tb_image(opt,tb,step,group,name,images,num_vis=None,from_range=(0,1),cmap=\"gray\"):\n    images = preprocess_vis_image(opt,images,from_range=from_range,cmap=cmap)\n    num_H,num_W = num_vis or opt.tb.num_images\n    images = images[:num_H*num_W]\n    image_grid = torchvision.utils.make_grid(images[:,:3],nrow=num_W,pad_value=1.)\n    if images.shape[1]==4:\n        mask_grid = torchvision.utils.make_grid(images[:,3:],nrow=num_W,pad_value=1.)[:1]\n        image_grid = torch.cat([image_grid,mask_grid],dim=0)\n    tag = \"{0}/{1}\".format(group,name)\n    tb.add_image(tag,image_grid,step)\n\ndef preprocess_vis_image(opt,images,from_range=(0,1),cmap=\"gray\"):\n    min,max = from_range\n    images = (images-min)/(max-min)\n    images = images.clamp(min=0,max=1).cpu()\n    if images.shape[1]==1:\n        images = get_heatmap(opt,images[:,0].cpu(),cmap=cmap)\n    return images\n\ndef dump_images(opt,idx,name,images,masks=None,from_range=(0,1),cmap=\"gray\"):\n    images = preprocess_vis_image(opt,images,masks=masks,from_range=from_range,cmap=cmap) # [B,3,H,W]\n    images = images.cpu().permute(0,2,3,1).numpy() # [B,H,W,3]\n    for i,img in zip(idx,images):\n        fname = \"{}/dump/{}_{}.png\".format(opt.output_path,i,name)\n        img_uint8 = (img*255).astype(np.uint8)\n        imageio.imsave(fname,img_uint8)\n\ndef get_heatmap(opt,gray,cmap): # [N,H,W]\n    color = plt.get_cmap(cmap)(gray.numpy())\n    color = torch.from_numpy(color[...,:3]).permute(0,3,1,2).float() # [N,3,H,W]\n    return color\n\ndef color_border(images,colors,width=3):\n    images_pad = []\n    for i,image in enumerate(images):\n        image_pad = torch.ones(3,image.shape[1]+width*2,image.shape[2]+width*2)*(colors[i,:,None,None]/255.0)\n        image_pad[:,width:-width,width:-width] = image\n        images_pad.append(image_pad)\n    images_pad = torch.stack(images_pad,dim=0)\n    return images_pad\n\n@torch.no_grad()\ndef vis_cameras(opt,vis,step,poses=[],colors=[\"blue\",\"magenta\"],plot_dist=True):\n    win_name = \"{}/{}\".format(opt.group,opt.name)\n    data = []\n    # set up plots\n    centers = []\n    for pose,color in zip(poses,colors):\n        pose = pose.detach().cpu()\n        vertices,faces,wireframe = get_camera_mesh(pose,depth=opt.visdom.cam_depth)\n        center = vertices[:,-1]\n        centers.append(center)\n        # camera centers\n        data.append(dict(\n            type=\"scatter3d\",\n            x=[float(n) for n in center[:,0]],\n            y=[float(n) for n in center[:,1]],\n            z=[float(n) for n in center[:,2]],\n            mode=\"markers\",\n            marker=dict(color=color,size=3),\n        ))\n        # colored camera mesh\n        vertices_merged,faces_merged = merge_meshes(vertices,faces)\n        data.append(dict(\n            type=\"mesh3d\",\n            x=[float(n) for n in vertices_merged[:,0]],\n            y=[float(n) for n in vertices_merged[:,1]],\n            z=[float(n) for n in vertices_merged[:,2]],\n            i=[int(n) for n in faces_merged[:,0]],\n            j=[int(n) for n in faces_merged[:,1]],\n            k=[int(n) for n in faces_merged[:,2]],\n            
flatshading=True,\n            color=color,\n            opacity=0.05,\n        ))\n        # camera wireframe\n        wireframe_merged = merge_wireframes(wireframe)\n        data.append(dict(\n            type=\"scatter3d\",\n            x=wireframe_merged[0],\n            y=wireframe_merged[1],\n            z=wireframe_merged[2],\n            mode=\"lines\",\n            line=dict(color=color,),\n            opacity=0.3,\n        ))\n    if plot_dist:\n        # distance between two poses (camera centers)\n        center_merged = merge_centers(centers[:2])\n        data.append(dict(\n            type=\"scatter3d\",\n            x=center_merged[0],\n            y=center_merged[1],\n            z=center_merged[2],\n            mode=\"lines\",\n            line=dict(color=\"red\",width=4,),\n        ))\n        if len(centers)==4:\n            center_merged = merge_centers(centers[2:4])\n            data.append(dict(\n                type=\"scatter3d\",\n                x=center_merged[0],\n                y=center_merged[1],\n                z=center_merged[2],\n                mode=\"lines\",\n                line=dict(color=\"red\",width=4,),\n            ))\n    # send data to visdom\n    vis._send(dict(\n        data=data,\n        win=\"poses\",\n        eid=win_name,\n        layout=dict(\n            title=\"({})\".format(step),\n            autosize=True,\n            margin=dict(l=30,r=30,b=30,t=30,),\n            showlegend=False,\n            yaxis=dict(\n                scaleanchor=\"x\",\n                scaleratio=1,\n            )\n        ),\n        opts=dict(title=\"{} poses ({})\".format(win_name,step),),\n    ))\n\ndef get_camera_mesh(pose,depth=1):\n    vertices = torch.tensor([[-0.5,-0.5,1],\n                             [0.5,-0.5,1],\n                             [0.5,0.5,1],\n                             [-0.5,0.5,1],\n                             [0,0,0]])*depth\n    faces = torch.tensor([[0,1,2],\n                          [0,2,3],\n                          [0,1,4],\n                          [1,2,4],\n                          [2,3,4],\n                          [3,0,4]])\n    vertices = camera.cam2world(vertices[None],pose)\n    wireframe = vertices[:,[0,1,2,3,0,4,1,2,4,3]]\n    return vertices,faces,wireframe\n\ndef merge_wireframes(wireframe):\n    wireframe_merged = [[],[],[]]\n    for w in wireframe:\n        wireframe_merged[0] += [float(n) for n in w[:,0]]+[None]\n        wireframe_merged[1] += [float(n) for n in w[:,1]]+[None]\n        wireframe_merged[2] += [float(n) for n in w[:,2]]+[None]\n    return wireframe_merged\ndef merge_meshes(vertices,faces):\n    mesh_N,vertex_N = vertices.shape[:2]\n    faces_merged = torch.cat([faces+i*vertex_N for i in range(mesh_N)],dim=0)\n    vertices_merged = vertices.view(-1,vertices.shape[-1])\n    return vertices_merged,faces_merged\ndef merge_centers(centers):\n    center_merged = [[],[],[]]\n    for c1,c2 in zip(*centers):\n        center_merged[0] += [float(c1[0]),float(c2[0]),None]\n        center_merged[1] += [float(c1[1]),float(c2[1]),None]\n        center_merged[2] += [float(c1[2]),float(c2[2]),None]\n    return center_merged\n\ndef plot_save_poses(opt,fig,pose,pose_ref=None,path=None,ep=None):\n    # get the camera meshes\n    _,_,cam = get_camera_mesh(pose,depth=opt.visdom.cam_depth)\n    cam = cam.numpy()\n    if pose_ref is not None:\n        _,_,cam_ref = get_camera_mesh(pose_ref,depth=opt.visdom.cam_depth)\n        cam_ref = cam_ref.numpy()\n    # set up plot window(s)\n    plt.title(\"epoch 
{}\".format(ep))\n    ax1 = fig.add_subplot(121,projection=\"3d\")\n    ax2 = fig.add_subplot(122,projection=\"3d\")\n    setup_3D_plot(ax1,elev=-90,azim=-90,lim=edict(x=(-1,1),y=(-1,1),z=(-1,1)))\n    setup_3D_plot(ax2,elev=0,azim=-90,lim=edict(x=(-1,1),y=(-1,1),z=(-1,1)))\n    ax1.set_title(\"forward-facing view\",pad=0)\n    ax2.set_title(\"top-down view\",pad=0)\n    plt.subplots_adjust(left=0,right=1,bottom=0,top=0.95,wspace=0,hspace=0)\n    plt.margins(tight=True,x=0,y=0)\n    # plot the cameras\n    N = len(cam)\n    color = plt.get_cmap(\"gist_rainbow\")\n\n    if pose_ref is not None:\n        for i in range(N):\n            ax1.plot(cam_ref[i,:,0],cam_ref[i,:,1],cam_ref[i,:,2],color=(0.3,0.3,0.3),linewidth=1)\n            ax2.plot(cam_ref[i,:,0],cam_ref[i,:,1],cam_ref[i,:,2],color=(0.3,0.3,0.3),linewidth=1)\n            ax1.scatter(cam_ref[i,5,0],cam_ref[i,5,1],cam_ref[i,5,2],color=(0.3,0.3,0.3),s=40)\n            ax2.scatter(cam_ref[i,5,0],cam_ref[i,5,1],cam_ref[i,5,2],color=(0.3,0.3,0.3),s=40)\n    if ep==0:\n        png_fname = \"{}/GT.png\".format(path)\n        plt.savefig(png_fname,dpi=75)\n\n    for i in range(N):\n        c = np.array(color(float(i)/N))*0.8\n        ax1.plot(cam[i,:,0],cam[i,:,1],cam[i,:,2],color=c)\n        ax2.plot(cam[i,:,0],cam[i,:,1],cam[i,:,2],color=c)\n        ax1.scatter(cam[i,5,0],cam[i,5,1],cam[i,5,2],color=c,s=40)\n        ax2.scatter(cam[i,5,0],cam[i,5,1],cam[i,5,2],color=c,s=40)\n        \n    png_fname = \"{}/{}.png\".format(path,ep)\n    plt.savefig(png_fname,dpi=75)\n    # clean up\n    plt.clf()\n\ndef plot_save_poses_blender(opt,fig,pose,pose_ref=None,path=None,ep=None):\n    # get the camera meshes\n    _,_,cam = get_camera_mesh(pose,depth=opt.visdom.cam_depth)\n    cam = cam.numpy()\n    if pose_ref is not None:\n        _,_,cam_ref = get_camera_mesh(pose_ref,depth=opt.visdom.cam_depth)\n        cam_ref = cam_ref.numpy()\n    # set up plot window(s)\n    ax = fig.add_subplot(111,projection=\"3d\")\n    ax.set_title(\"epoch {}\".format(ep),pad=0)\n    setup_3D_plot(ax,elev=45,azim=35,lim=edict(x=(-3,3),y=(-3,3),z=(-3,2.4)))\n    plt.subplots_adjust(left=0,right=1,bottom=0,top=0.95,wspace=0,hspace=0)\n    plt.margins(tight=True,x=0,y=0)\n    # plot the cameras\n    N = len(cam)\n    ref_color = (0.7,0.2,0.7)\n    pred_color = (0,0.6,0.7)\n    ax.add_collection3d(Poly3DCollection([v[:4] for v in cam_ref],alpha=0.2,facecolor=ref_color))\n    for i in range(N):\n        ax.plot(cam_ref[i,:,0],cam_ref[i,:,1],cam_ref[i,:,2],color=ref_color,linewidth=0.5)\n        ax.scatter(cam_ref[i,5,0],cam_ref[i,5,1],cam_ref[i,5,2],color=ref_color,s=20)\n    if ep==0:\n        png_fname = \"{}/GT.png\".format(path)\n        plt.savefig(png_fname,dpi=75)\n    ax.add_collection3d(Poly3DCollection([v[:4] for v in cam],alpha=0.2,facecolor=pred_color))\n    for i in range(N):\n        ax.plot(cam[i,:,0],cam[i,:,1],cam[i,:,2],color=pred_color,linewidth=1)\n        ax.scatter(cam[i,5,0],cam[i,5,1],cam[i,5,2],color=pred_color,s=20)\n    for i in range(N):\n        ax.plot([cam[i,5,0],cam_ref[i,5,0]],\n                [cam[i,5,1],cam_ref[i,5,1]],\n                [cam[i,5,2],cam_ref[i,5,2]],color=(1,0,0),linewidth=3)\n    png_fname = \"{}/{}.png\".format(path,ep)\n    plt.savefig(png_fname,dpi=75)\n    # clean up\n    plt.clf()\n\ndef setup_3D_plot(ax,elev,azim,lim=None):\n    ax.xaxis.set_pane_color((1.0,1.0,1.0,0.0))\n    ax.yaxis.set_pane_color((1.0,1.0,1.0,0.0))\n    ax.zaxis.set_pane_color((1.0,1.0,1.0,0.0))\n    ax.xaxis._axinfo[\"grid\"][\"color\"] = 
(0.9,0.9,0.9,1)\n    ax.yaxis._axinfo[\"grid\"][\"color\"] = (0.9,0.9,0.9,1)\n    ax.zaxis._axinfo[\"grid\"][\"color\"] = (0.9,0.9,0.9,1)\n    ax.xaxis.set_tick_params(labelsize=8)\n    ax.yaxis.set_tick_params(labelsize=8)\n    ax.zaxis.set_tick_params(labelsize=8)\n    ax.set_xlabel(\"X\",fontsize=16)\n    ax.set_ylabel(\"Y\",fontsize=16)\n    ax.set_zlabel(\"Z\",fontsize=16)\n    ax.set_xlim(lim.x[0],lim.x[1])\n    ax.set_ylim(lim.y[0],lim.y[1])\n    ax.set_zlim(lim.z[0],lim.z[1])\n    ax.view_init(elev=elev,azim=azim)\n\n\n# Copyright 2022 The Nerfstudio Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\" Helper functions for visualizing outputs \"\"\"\nimport torch\nfrom matplotlib import cm\n\ndef apply_colormap(image, cmap=\"viridis\"):\n    \"\"\"Convert single channel to a color image.\n    Args:\n        image: Single channel image. : TensorType[\"bs\":..., 1]\n        cmap: Colormap for image.\n    Returns:\n        TensorType: Colored image  -> TensorType[\"bs\":..., \"rgb\":3]\n    \"\"\"\n\n    colormap = cm.get_cmap(cmap)\n    colormap = torch.tensor(colormap.colors).to(image.device)  # type: ignore\n    image_long = (image * 255).long()\n\n    image_long[image_long<0]=0\n    image_long[image_long>255]=255\n\n    image_long_min = torch.min(image_long)\n    image_long_max = torch.max(image_long)\n    assert image_long_min >= 0, f\"the min value is {image_long_min}\"\n    assert image_long_max <= 255, f\"the max value is {image_long_max}\"\n    return colormap[image_long[..., 0]]\n\n\ndef apply_depth_colormap(\n    depth,\n    accumulation= None,\n    near_plane= None,\n    far_plane= None,\n    cmap=\"turbo\",\n):\n    \"\"\"Converts a depth image to color for easier analysis.\n    Args:\n        depth: Depth image.: TensorType[\"bs\":..., 1]\n        accumulation: Ray accumulation used for masking vis. : Optional[TensorType[\"bs\":..., 1]] \n        near_plane: Closest depth to consider. If None, use min image value. : Optional[float] \n        far_plane: Furthest depth to consider. If None, use max image value. : Optional[float] \n        cmap: Colormap to apply. # inferno turbo viridis\n    Returns:\n        Colored depth image  -> TensorType[\"bs\":..., \"rgb\":3]\n    \"\"\"\n    # near_plane = near_plane or float(torch.min(depth))\n    # far_plane = far_plane or float(torch.max(depth))\n\n    # depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)\n    depth = torch.clip(depth, 0 , 1)\n    colored_image = apply_colormap(depth, cmap=cmap)\n\n    if accumulation is not None:\n        colored_image = colored_image * accumulation + (1 - accumulation)\n\n    return colored_image"
  },
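For the colormap helpers at the bottom of `util_vis.py`, a typical call colorizes a rendered depth map that has already been scaled to [0,1] (the clip inside `apply_depth_colormap` assumes this), optionally whitening pixels with low ray accumulation; run with the repo's environment active:

```python
import torch
from util_vis import apply_depth_colormap

depth = torch.rand(400, 400, 1)    # [H,W,1] depth, already scaled to [0,1]
accum = torch.rand(400, 400, 1)    # per-pixel ray accumulation (opacity)
rgb = apply_depth_colormap(depth, accumulation=accum, cmap="turbo")
print(rgb.shape)                   # torch.Size([400, 400, 3])
```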
  {
    "path": "warp.py",
    "content": "import numpy as np\nimport os,sys,time\nimport torch\nimport torch.nn.functional as torch_F\n\nimport util\nfrom util import log,debug\nimport camera\nimport warnings\n\ndef get_normalized_pixel_grid(opt):\n    y_range = ((torch.arange(opt.H,dtype=torch.float32,device=opt.device)+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W))\n    x_range = ((torch.arange(opt.W,dtype=torch.float32,device=opt.device)+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W))\n    Y,X = torch.meshgrid(y_range,x_range) # [H,W]\n    xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]\n    xy_grid = xy_grid.repeat(opt.batch_size,1,1) # [B,HW,2]\n    return xy_grid\n\ndef get_normalized_pixel_grid_crop(opt):\n    y_crop = (opt.H//2-opt.H_crop//2,opt.H//2+opt.H_crop//2)\n    x_crop = (opt.W//2-opt.W_crop//2,opt.W//2+opt.W_crop//2)\n    y_range = ((torch.arange(*(y_crop),dtype=torch.float32,device=opt.device)+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W))\n    x_range = ((torch.arange(*(x_crop),dtype=torch.float32,device=opt.device)+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W))\n    Y,X = torch.meshgrid(y_range,x_range) # [H,W]\n    xy_grid = torch.stack([X,Y],dim=-1).view(-1,2) # [HW,2]\n    xy_grid = xy_grid.repeat(opt.batch_size,1,1) # [B,HW,2]\n    return xy_grid\n\ndef warp_grid(opt,xy_grid,warp):\n    if opt.warp.type==\"translation\":\n        assert(opt.warp.dof==2)\n        warped_grid = xy_grid+warp[...,None,:]\n    elif opt.warp.type==\"rotation\":\n        assert(opt.warp.dof==1)\n        warp_matrix = lie.so2_to_SO2(warp)\n        warped_grid = xy_grid@warp_matrix.transpose(-2,-1) # [B,HW,2]\n    elif opt.warp.type==\"rigid\":\n        assert(opt.warp.dof==3)\n        xy_grid_hom = camera.to_hom(xy_grid)\n        warp_matrix = lie.se2_to_SE2(warp)\n        warped_grid = xy_grid_hom@warp_matrix.transpose(-2,-1) # [B,HW,2]\n    elif opt.warp.type==\"homography\":\n        assert(opt.warp.dof==8)\n        xy_grid_hom = camera.to_hom(xy_grid)\n        warp_matrix = lie.sl3_to_SL3(warp)\n        warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1)\n        warped_grid = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2]\n    else: assert(False)\n    return warped_grid\n\ndef warp_grid_use_matrix(opt,xy_grid,warp_matrix):\n    if opt.warp.type==\"translation\":\n        assert(opt.warp.dof==2)\n        warped_grid = xy_grid+warp_matrix[...,None,:]\n    elif opt.warp.type==\"rotation\":\n        assert(opt.warp.dof==1)\n        warped_grid = xy_grid@warp_matrix.transpose(-2,-1) # [B,HW,2]\n    elif opt.warp.type==\"rigid\":\n        assert(opt.warp.dof==3)\n        xy_grid_hom = camera.to_hom(xy_grid)\n        warped_grid = xy_grid_hom@warp_matrix.transpose(-2,-1) # [B,HW,2]\n    elif opt.warp.type==\"homography\":\n        assert(opt.warp.dof==8)\n        xy_grid_hom = camera.to_hom(xy_grid)\n        warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1)\n        warped_grid = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2]\n    else: assert(False)\n    return warped_grid\n\ndef warp_corners(opt,warp_param):\n    y_crop = (opt.H//2-opt.H_crop//2,opt.H//2+opt.H_crop//2)\n    x_crop = (opt.W//2-opt.W_crop//2,opt.W//2+opt.W_crop//2)\n    Y = [((y+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W)) for y in y_crop]\n    X = [((x+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W)) for x in x_crop]\n    corners = [(X[0],Y[0]),(X[0],Y[1]),(X[1],Y[1]),(X[1],Y[0])]\n    corners = torch.tensor(corners,dtype=torch.float32,device=opt.device).repeat(opt.batch_size,1,1)\n    corners_warped = 
warp_grid(opt,corners,warp_param)\n    return corners_warped\n\ndef warp_corners_compose(opt, homo_pert, rot_pert, trans_pert):\n    y_crop = (opt.H//2-opt.H_crop//2,opt.H//2+opt.H_crop//2)\n    x_crop = (opt.W//2-opt.W_crop//2,opt.W//2+opt.W_crop//2)\n    Y = [((y+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W)) for y in y_crop]\n    X = [((x+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W)) for x in x_crop]\n    corners = [(X[0],Y[0]),(X[0],Y[1]),(X[1],Y[1]),(X[1],Y[0])]\n    corners = torch.tensor(corners,dtype=torch.float32,device=opt.device).repeat(opt.batch_size,1,1)\n\n    xy_grid_hom = camera.to_hom(corners)\n    warp_matrix = lie.sl3_to_SL3(homo_pert)\n    warped_grid_hom = xy_grid_hom@warp_matrix.transpose(-2,-1)\n    corners_warped = warped_grid_hom[...,:2]/(warped_grid_hom[...,2:]+1e-8) # [B,HW,2]\n    warp_matrix = lie.so2_to_SO2(rot_pert)\n    corners_warped = corners_warped@warp_matrix.transpose(-2,-1) # [B,HW,2]\n    corners_warped = corners_warped+trans_pert[...,None,:]\n    return corners_warped\n\ndef warp_corners_use_matrix(opt,warp_matrix):\n    y_crop = (opt.H//2-opt.H_crop//2,opt.H//2+opt.H_crop//2)\n    x_crop = (opt.W//2-opt.W_crop//2,opt.W//2+opt.W_crop//2)\n    Y = [((y+0.5)/opt.H*2-1)*(opt.H/max(opt.H,opt.W)) for y in y_crop]\n    X = [((x+0.5)/opt.W*2-1)*(opt.W/max(opt.H,opt.W)) for x in x_crop]\n    corners = [(X[0],Y[0]),(X[0],Y[1]),(X[1],Y[1]),(X[1],Y[0])]\n    corners = torch.tensor(corners,dtype=torch.float32,device=opt.device).repeat(opt.batch_size,1,1)\n    corners_warped = warp_grid_use_matrix(opt,corners,warp_matrix)\n    return corners_warped\n\ndef check_corners_in_range(opt,warp_param):\n    corners_all = warp_corners(opt,warp_param)\n    X = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5\n    Y = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5\n    return (0<=X).all() and (X<opt.W).all() and (0<=Y).all() and (Y<opt.H).all()\n\ndef check_corners_in_range_compose(opt, homo_pert, rot_pert, trans_pert):\n    corners_all = warp_corners_compose(opt,homo_pert, rot_pert, trans_pert)\n    X = (corners_all[...,0]/opt.W*max(opt.H,opt.W)+1)/2*opt.W-0.5\n    Y = (corners_all[...,1]/opt.H*max(opt.H,opt.W)+1)/2*opt.H-0.5\n    return (((0<=X) * (X<(opt.W//2-opt.W_crop//8)) + ((opt.W//2+opt.W_crop//8)<=X) * (X<opt.W)) \\\n      * ((0<=Y) * (Y<(opt.H//2-opt.H_crop//8)) + ((opt.H//2+opt.H_crop//8)<=Y) * (Y<opt.H))).all()\n\nclass Lie():\n\n    def so2_to_SO2(self,theta): # [...,1]\n        thetax = torch.stack([torch.cat([theta.cos(),-theta.sin()],dim=-1),\n                              torch.cat([theta.sin(),theta.cos()],dim=-1)],dim=-2)\n        R = thetax\n        return R\n\n    def SO2_to_so2(self,R): # [...,2,2]\n        theta = torch.atan2(R[...,1,0],R[...,0,0])\n        return theta[...,None]\n\n    def so2_jacobian(self,X,theta): # [...,N,2],[...,1]\n        dR_dtheta = torch.stack([torch.cat([-theta.sin(),-theta.cos()],dim=-1),\n                                 torch.cat([theta.cos(),-theta.sin()],dim=-1)],dim=-2) # [...,2,2]\n        J = X@dR_dtheta.transpose(-2,-1)\n        return J[...,None] # [...,N,2,1]\n\n    def se2_to_SE2(self,delta): # [...,3]\n        u,theta = delta.split([2,1],dim=-1)\n        A = self.taylor_A(theta)\n        B = self.taylor_B(theta)\n        V = torch.stack([torch.cat([A,-B],dim=-1),\n                         torch.cat([B,A],dim=-1)],dim=-2)\n        R = self.so2_to_SO2(theta)\n        Rt = torch.cat([R,V@u[...,None]],dim=-1)\n        return Rt\n\n    def SE2_to_se2(self,Rt,eps=1e-7): # [...,2,3]\n        R,t = 
Rt.split([2,1],dim=-1)\n        theta = self.SO2_to_so2(R)\n        A = self.taylor_A(theta)\n        B = self.taylor_B(theta)\n        denom = (A**2+B**2+eps)[...,None]\n        invV = torch.stack([torch.cat([A,B],dim=-1),\n                            torch.cat([-B,A],dim=-1)],dim=-2)/denom\n        u = (invV@t)[...,0]\n        delta = torch.cat([u,theta],dim=-1)\n        return delta\n\n    def se2_jacobian(self,X,delta): # [...,N,2],[...,3]\n        u,theta = delta.split([2,1],dim=-1)\n        A = self.taylor_A(theta)\n        B = self.taylor_B(theta)\n        C = self.taylor_C(theta)\n        D = self.taylor_D(theta)\n        V = torch.stack([torch.cat([A,-B],dim=-1),\n                         torch.cat([B,A],dim=-1)],dim=-2)\n        R = self.so2_to_SO2(theta)\n        dV_dtheta = torch.stack([torch.cat([C,-D],dim=-1),\n                                 torch.cat([D,C],dim=-1)],dim=-2) # [...,2,2]\n        dt_dtheta = dV_dtheta@u[...,None] # [...,2,1]\n        J_so2 = self.so2_jacobian(X,theta) # [...,N,2,1]\n        dX_dtheta = J_so2+dt_dtheta[...,None,:,:] # [...,N,2,1]\n        dX_du = V[...,None,:,:].repeat(*[1]*(len(dX_dtheta.shape)-3),dX_dtheta.shape[-3],1,1)\n        J = torch.cat([dX_du,dX_dtheta],dim=-1)\n        return J # [...,N,2,3]\n\n    def sl3_to_SL3(self,h):\n        # homography: directly expand matrix exponential\n        # https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.61.6151&rep=rep1&type=pdf\n        h1,h2,h3,h4,h5,h6,h7,h8 = h.chunk(8,dim=-1)\n        A = torch.stack([torch.cat([h5,h3,h1],dim=-1),\n                         torch.cat([h4,-h5-h6,h2],dim=-1),\n                         torch.cat([h7,h8,h6],dim=-1)],dim=-2)\n        H = A.matrix_exp()\n        return H\n\n    def taylor_A(self,x,nth=10):\n        # Taylor expansion of sin(x)/x\n        ans = torch.zeros_like(x)\n        denom = 1.\n        for i in range(nth+1):\n            if i>0: denom *= (2*i)*(2*i+1)\n            ans = ans+(-1)**i*x**(2*i)/denom\n        return ans\n    def taylor_B(self,x,nth=10):\n        # Taylor expansion of (1-cos(x))/x\n        ans = torch.zeros_like(x)\n        denom = 1.\n        for i in range(nth+1):\n            denom *= (2*i+1)*(2*i+2)\n            ans = ans+(-1)**i*x**(2*i+1)/denom\n        return ans\n\n    def taylor_C(self,x,nth=10):\n        # Taylor expansion of (x*cos(x)-sin(x))/x**2\n        ans = torch.zeros_like(x)\n        denom = 1.\n        for i in range(nth+1):\n            denom *= (2*i+2)*(2*i+3)\n            ans = ans+(-1)**(i+1)*x**(2*i+1)*(2*i+2)/denom\n        return ans\n\n    def taylor_D(self,x,nth=10):\n        # Taylor expansion of (x*sin(x)+cos(x)-1)/x**2\n        ans = torch.zeros_like(x)\n        denom = 1.\n        for i in range(nth+1):\n            denom *= (2*i+1)*(2*i+2)\n            ans = ans+(-1)**i*x**(2*i)*(2*i+1)/denom\n        return ans\n\nlie = Lie()\n"
  }
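The `Lie` helper in `warp.py` builds the SO(2)/SE(2)/SL(3) maps from Taylor expansions that stay differentiable at θ = 0. A quick numerical check of two of them, plus a round trip through the SO(2) exponential and logarithm maps (run from the repo root):

```python
import torch
from warp import lie

theta = torch.tensor([0.3])
print(lie.taylor_A(theta), torch.sin(theta) / theta)        # both ~0.9851
print(lie.taylor_B(theta), (1 - torch.cos(theta)) / theta)  # both ~0.1489

# so2 -> SO2 -> so2 should recover the input angle.
R = lie.so2_to_SO2(theta)          # [2,2] rotation matrix
print(lie.SO2_to_so2(R))           # tensor([0.3000])
```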
]