[
  {
    "path": ".gitignore",
    "content": ".idea/\n.vscode/\n\n__pycache__/\n\nlogs*/\n\ndata/\nconfigs_test/\nweights/"
  },
  {
    "path": "Dockerfile",
    "content": "FROM nvcr.io/nvidia/pytorch:23.02-py3\n\nARG DEBIAN_FRONTEND=noninteractive\nENV TZ=Asia/Shanghai LANG=C.UTF-8 LC_ALL=C.UTF-8 PIP_NO_CACHE_DIR=1 PIP_CACHE_DIR=/tmp/ PYTHONUNBUFFERED=1 PYTHONFAULTHANDLER=1 PYTHONHASHSEED=0\n\nRUN \\\n    # uncomment to use apt mirror\n    # sed -i \"s/archive.ubuntu.com/mirrors.ustc.edu.cn/g\" /etc/apt/sources.list &&\\\n    # sed -i \"s/security.ubuntu.com/mirrors.ustc.edu.cn/g\" /etc/apt/sources.list &&\\\n    rm -f /etc/apt/sources.list.d/* &&\\\n    rm -rf /opt/hpcx/ &&\\\n    apt-get update && apt-get upgrade -y &&\\\n    apt-get install -y --no-install-recommends \\\n        autoconf automake autotools-dev build-essential ca-certificates \\\n        make cmake ninja-build pkg-config g++ ccache yasm openmpi-bin \\\n        git curl wget unzip nano net-tools htop iotop \\\n        cloc rsync xz-utils software-properties-common tzdata \\\n    && apt-get purge -y unattended-upgrades \\\n    && apt-get clean \\\n    && rm -rf /var/lib/apt/lists/*\n\nADD requirements.txt /tmp\nRUN \\\n    # uncomment to use pypi mirror\n    # pip config set global.index-url https://mirrors.bfsu.edu.cn/pypi/web/simple &&\\\n    pip install -U pip &&\\\n    pip install -r /tmp/requirements.txt &&\\\n    pip install \"jupyterlab~=3.5.0\" \"jupyter-archive~=3.2\" &&\\\n    rm -rf /tmp/*\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 Peng Wang, Lingzhe Zhao, Ruijie Ma, Peidong Liu\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# 😈BAD-NeRF\n\n<a href=\"https://arxiv.org/abs/2211.12853\"><img src=\"https://img.shields.io/badge/arXiv-2211.12853-b31b1b.svg\" height=22.5></a>\n<a href=\"https://opensource.org/licenses/MIT\"><img src=\"https://img.shields.io/github/license/WU-CVGL/BAD-NeRF\" height=22.5></a>\n<a href=\"https://www.youtube.com/watch?v=xoES4eONYoA\"><img src=\"https://img.shields.io/badge/YouTube-%23FF0000.svg?style=flat&logo=YouTube&logoColor=white\" height=22.5></a>\n<a href=\"https://www.bilibili.com/video/BV1Gz4y1e7oH/\"><img src=\"https://img.shields.io/badge/Bilibili-0696c6\" height=22.5></a>\n\nThis is an official PyTorch implementation of the paper [BAD-NeRF: Bundle Adjusted Deblur Neural Radiance Fields](https://arxiv.org/abs/2211.12853) (CVPR 2023). Authors: [Peng Wang](https://github.com/wangpeng000), [Lingzhe Zhao](https://github.com/LingzheZhao), Ruijie Ma and [Peidong Liu](https://ethliup.github.io/).\n\nBAD-NeRF jointly learns the 3D representation and optimizes the camera motion trajectories within exposure time from blurry images and inaccurate initial poses.\n\nHere is the [Project page](https://wangpeng000.github.io/BAD-NeRF/).\n\n## ✨News\n📺 **[2023.12]** We were invited to give an online talk on Bilibili and Wechat, hosted by **计算机视觉life**. Archive of our live stream (in Chinese): [[link to Bilibili]](https://www.bilibili.com/video/BV1Rb4y1G7Ek/)\n\n⚡ **[2023.11]** We made our **nerfstudio**-framework-based implementation: [[BAD-NeRFstudio]](https://github.com/WU-CVGL/BAD-NeRFstudio) public. Now you can train a scene from blurry images in minutes!\n\n## Novel View Synthesis\n<div><video autoplay loop controls src=\"https://user-images.githubusercontent.com/43722188/232816090-ced1fbbc-4246-45c6-a265-e7424e754c7b.mp4\" muted=\"true\"></video></div>\n\n## Deblurring Result\n![teaser](./doc/bad-nerf.jpg)\n\n## Pose Estimation Result\n![pose estimation](./doc/pose-estimation.jpg)\n\n## Method overview\n![method](./doc/overview.jpg)\n\nWe follow the real physical image formation process of a motion-blurred image to synthesize blurry images from NeRF. Both NeRF and the motion trajectories are estimated by maximizing the photometric consistency between the synthesized blurry images and the real blurry images.\n\n## Quickstart\n\n### 1. Setup environment\n\n```\ngit clone https://github.com/WU-CVGL/BAD-NeRF\ncd BAD-NeRF\npip install -r requirements.txt\n```\n\n### 2. Download datasets\n\nYou can download the data and weights [here](https://westlakeu-my.sharepoint.com/:f:/g/personal/cvgl_westlake_edu_cn/EsgdW2cRic5JqerhNbTsxtkBqy9m6cbnb2ugYZtvaib3qA?e=bjK7op).\n\nFor the scenes of Deblur-NeRF (*cozy2room*, *factory* etc.), the folder `images` only includes blurry images and the folder `images_1` additionally includes novel view images. But for our scenes (*room-low*, *room-high* and *dark*), there are no novel view images. Note that the images in the *dark* scene are undistorted, causing that there are some useless pixels, and you should uncomment the code of `Graph.forward` in `nerf.py`.\n\n### 3. Configs\n\nChange the data path and other parameters (if needed) in `configs/cozy2room.txt`. We use *cozy2room* scene as an example.\n\n### 4. Demo with our pre-trained model\n\nYou can test our code and render sharp images with the provided weight files. 
## Quickstart\n\n### 1. Setup environment\n\n```\ngit clone https://github.com/WU-CVGL/BAD-NeRF\ncd BAD-NeRF\npip install -r requirements.txt\n```\n\n### 2. Download datasets\n\nYou can download the data and weights [here](https://westlakeu-my.sharepoint.com/:f:/g/personal/cvgl_westlake_edu_cn/EsgdW2cRic5JqerhNbTsxtkBqy9m6cbnb2ugYZtvaib3qA?e=bjK7op).\n\nFor the Deblur-NeRF scenes (*cozy2room*, *factory*, etc.), the folder `images` contains only blurry images, while the folder `images_1` additionally contains novel-view images. For our own scenes (*room-low*, *room-high* and *dark*), there are no novel-view images. Note that the images of the *dark* scene have been undistorted, which leaves some useless pixels; for this scene you should uncomment the corresponding code in `Graph.forward` in `nerf.py`.\n\n### 3. Configs\n\nChange the data path and other parameters (if needed) in `configs/cozy2room.txt`. We use the *cozy2room* scene as an example.\n\n### 4. Demo with our pre-trained model\n\nYou can test our code and render sharp images with the provided weight files. To do this, first put the weight file under the corresponding logs folder `./logs/cozy2room-linear`, then set `load_weights = True` in `cozy2room.txt`, and finally run\n\n```\npython test.py --config configs/cozy2room.txt\n```\n\n### 5. Training\n\n```\npython train.py --config configs/cozy2room.txt\n```\n\nAfter training, you will get deblurred images, optimized camera poses and synthesized novel-view images.\n\n## Notes\n\n### Camera poses\n\nThe poses (`poses_bounds.npy`) are estimated by COLMAP from the blurry images only (folder `images`).\n\n### Spline model\n\nWe use linear interpolation as the default spline model in our experiments. You can simply set the parameter `linear` (all parameters can be changed in `configs/***.txt` or `run_nerf.py`) to `False` to use the higher-order spline model (i.e., cubic B-spline).\n\n
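Conceptually, the linear model obtains the camera pose at each normalized time u in [0, 1] within the exposure by interpolating between the start and end poses: translations linearly, rotations by quaternion SLERP (this is what `SplineN_linear` in `Spline.py` implements on batched tensors). Below is a minimal standalone sketch of the same idea, assuming SciPy is available (the repository itself uses its own quaternion routines instead):\n\n```python\nimport numpy as np\nfrom scipy.spatial.transform import Rotation, Slerp\n\ndef interpolate_pose_linear(R0, t0, R1, t1, u):\n    # Pose at normalized exposure time u in [0, 1]: the translation is\n    # interpolated linearly, the rotation by spherical linear\n    # interpolation (SLERP) between the start and end rotations.\n    t = (1.0 - u) * t0 + u * t1\n    slerp = Slerp([0.0, 1.0], Rotation.from_matrix(np.stack([R0, R1])))\n    return slerp(u).as_matrix(), t\n```\n\nThe cubic B-spline model instead blends four control poses with the cumulative B-spline basis (see `SplineN_cubic` in `Spline.py`), which can represent more complex trajectories within the exposure time.\n\n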
### Virtual images\n\nYou can change the important parameter `deblur_images` to a smaller/larger value for lightly/severely blurred images.\n\n### Learning rate\n\nAfter the rebuttal, we found that the gradients can sometimes become NaN when the cubic B-spline model is used with `pose_lrate = 1e-3`. We therefore set the initial pose learning rate to 1e-4, which may achieve better performance than that reported in our paper. If the gradients do become NaN in your experiments, just kill the run and try again, or decrease `pose_lrate`.\n\n## Your own data\n\n`images`: This folder is used to estimate initial camera poses from blurry images. Put your own data (blurry images only) into the folder `images` and run the `imgs2poses.py` script from the [LLFF code](https://github.com/fyusion/llff) to estimate camera poses and generate `poses_bounds.npy`.\n\n`images_1`: This is the default training folder. It contains the same blurry images as the `images` folder plus, optionally, several sharp novel-view images. If you want to add novel-view (sharp) images, put them into the folder `images_1` at an interval of `llffhold` (a parameter used for novel-view testing) and remember to set the parameter `novel_view` to `True`. Otherwise, if there are no novel-view images, directly put the blurry images into the folder `images_1` and set `novel_view` to `False`.\n\n`images_test`: To compute the deblurring metrics, this folder should contain the ground-truth images. If you don't have ground-truth images, you can copy the blurry images from the `images` folder into `images_test`; this is the easiest way to run the code correctly (but remember that the computed metrics are then meaningless).\n```\n#-----------------------------------------------------------------------------------------#\n# images folder: img_blur_*.png is the blurry image.                                      #\n#-----------------------------------------------------------------------------------------#\n# images_1 folder: img_blur_*.png is the same as that in `images` and (optional)          #\n# img_novel_*.png is the sharp novel view image.                                          #\n#-----------------------------------------------------------------------------------------#\n# images_test folder: img_test_*.png should be the ground truth image corresponding to    #\n# img_blur_*.png to compute the PSNR metric. Of course, you can directly put img_blur_*.png#\n# to run the code if you don't have gt images (then the metrics are wrong).               #\n#-----------------------------------------------------------------------------------------#\nimages folder: (suppose 10 images)\nimg_blur_0.png\nimg_blur_1.png\n.\n.\n.\nimg_blur_9.png\n#-----------------------------------------------------------------------------------------#\nimages_1 folder: (suppose novel view images are placed with an `llffhold=5` interval)\nimg_novel_0.png (optional)\nimg_blur_0.png\nimg_blur_1.png\n.\nimg_blur_4.png\nimg_novel_1.png (optional)\nimg_blur_5.png\n.\nimg_blur_9.png\nimg_novel_2.png (optional)\n#-----------------------------------------------------------------------------------------#\nimages_test folder: (ground-truth images in theory, but can be other images)\nimg_test_0.png\nimg_test_1.png\n.\n.\n.\nimg_test_9.png\n```\n## Citation\n\nIf you find this useful, please consider citing our paper:\n\n```bibtex\n@InProceedings{wang2023badnerf,\n    author    = {Wang, Peng and Zhao, Lingzhe and Ma, Ruijie and Liu, Peidong},\n    title     = {{BAD-NeRF: Bundle Adjusted Deblur Neural Radiance Fields}},\n    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\n    month     = {June},\n    year      = {2023},\n    pages     = {4170-4179}\n}\n```\n\n## Acknowledgment\n\nThe overall framework, the metrics computation and the camera transformations are derived from [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch/), [Deblur-NeRF](https://github.com/limacv/Deblur-NeRF) and [BARF](https://github.com/chenhsuanlin/bundle-adjusting-NeRF), respectively. We appreciate the effort of the contributors to these repositories.\n"
  },
  {
    "path": "Spline.py",
    "content": "import torch\nimport numpy as np\n\ndelt = 0\n\n\ndef skew_symmetric(w):\n    w0, w1, w2 = w.unbind(dim=-1)\n    O = torch.zeros_like(w0)\n    wx = torch.stack(\n        [\n            torch.stack([O, -w2, w1], dim=-1),\n            torch.stack([w2, O, -w0], dim=-1),\n            torch.stack([-w1, w0, O], dim=-1),\n        ],\n        dim=-2,\n    )\n    return wx\n\n\ndef taylor_A(x, nth=10):\n    # Taylor expansion of sin(x)/x\n    ans = torch.zeros_like(x)\n    denom = 1.0\n    for i in range(nth + 1):\n        if i > 0:\n            denom *= (2 * i) * (2 * i + 1)\n        ans = ans + (-1) ** i * x ** (2 * i) / denom\n    return ans\n\n\ndef taylor_B(x, nth=10):\n    # Taylor expansion of (1-cos(x))/x**2\n    ans = torch.zeros_like(x)\n    denom = 1.0\n    for i in range(nth + 1):\n        denom *= (2 * i + 1) * (2 * i + 2)\n        ans = ans + (-1) ** i * x ** (2 * i) / denom\n    return ans\n\n\ndef taylor_C(x, nth=10):\n    # Taylor expansion of (x-sin(x))/x**3\n    ans = torch.zeros_like(x)\n    denom = 1.0\n    for i in range(nth + 1):\n        denom *= (2 * i + 2) * (2 * i + 3)\n        ans = ans + (-1) ** i * x ** (2 * i) / denom\n    return ans\n\n\ndef exp_r2q_parallel(r, eps=1e-9):\n    x, y, z = r[..., 0], r[..., 1], r[..., 2]\n    theta = 0.5 * torch.sqrt(x**2 + y**2 + z**2)\n    bool_criterion = (theta < eps).unsqueeze(-1).repeat(1, 1, 4)\n    return torch.where(\n        bool_criterion, exp_r2q_taylor(x, y, z, theta), exp_r2q(x, y, z, theta)\n    )\n\n\ndef exp_r2q(x, y, z, theta):\n    lambda_ = torch.sin(theta) / (2.0 * theta)\n    qx = lambda_ * x\n    qy = lambda_ * y\n    qz = lambda_ * z\n    qw = torch.cos(theta)\n    return torch.stack([qx, qy, qz, qw], -1)\n\n\ndef exp_r2q_taylor(x, y, z, theta):\n    qx = (1.0 / 2.0 - 1.0 / 12.0 * theta**2 - 1.0 / 240.0 * theta**4) * x\n    qy = (1.0 / 2.0 - 1.0 / 12.0 * theta**2 - 1.0 / 240.0 * theta**4) * y\n    qz = (1.0 / 2.0 - 1.0 / 12.0 * theta**2 - 1.0 / 240.0 * theta**4) * z\n    qw = 1.0 - 1.0 / 2.0 * theta**2 + 1.0 / 24.0 * theta**4\n    return torch.stack([qx, qy, qz, qw], -1)\n\n\ndef q_to_R_parallel(q):\n    qb, qc, qd, qa = q.unbind(dim=-1)\n    R = torch.stack(\n        [\n            torch.stack(\n                [\n                    1 - 2 * (qc**2 + qd**2),\n                    2 * (qb * qc - qa * qd),\n                    2 * (qa * qc + qb * qd),\n                ],\n                dim=-1,\n            ),\n            torch.stack(\n                [\n                    2 * (qb * qc + qa * qd),\n                    1 - 2 * (qb**2 + qd**2),\n                    2 * (qc * qd - qa * qb),\n                ],\n                dim=-1,\n            ),\n            torch.stack(\n                [\n                    2 * (qb * qd - qa * qc),\n                    2 * (qa * qb + qc * qd),\n                    1 - 2 * (qb**2 + qc**2),\n                ],\n                dim=-1,\n            ),\n        ],\n        dim=-2,\n    )\n    return R\n\n\ndef q_to_Q_parallel(q):\n    x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]\n    Q_0 = torch.stack([w, -z, y, x], -1).unsqueeze(-2)\n    Q_1 = torch.stack([z, w, -x, y], -1).unsqueeze(-2)\n    Q_2 = torch.stack([-y, x, w, z], -1).unsqueeze(-2)\n    Q_3 = torch.stack([-x, -y, -z, w], -1).unsqueeze(-2)\n    Q_ = torch.cat([Q_0, Q_1, Q_2, Q_3], -2)\n    return Q_\n\n\ndef q_to_q_conj_parallel(q):\n    x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]\n    q_conj_ = torch.stack([-x, -y, -z, w], -1)\n    return q_conj_\n\n\ndef log_q2r_parallel(q, 
eps_theta=1e-20, eps_w=1e-10):\n    x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]\n\n    theta = torch.sqrt(x**2 + y**2 + z**2)\n\n    bool_theta_0 = theta < eps_theta\n    bool_w_0 = torch.abs(w) < eps_w\n    bool_w_0_left = torch.logical_and(bool_w_0, w < 0)\n\n    lambda_ = torch.where(\n        bool_w_0,\n        torch.where(\n            bool_w_0_left,\n            log_q2r_lim_w_0_left(theta),\n            log_q2r_lim_w_0_right(theta)\n        ),\n        torch.where(\n            bool_theta_0,\n            log_q2r_taylor_theta_0(w, theta),\n            log_q2r(w, theta)\n        ),\n    )\n\n    r_ = torch.stack([lambda_ * x, lambda_ * y, lambda_ * z], -1)\n\n    return r_\n\n\ndef log_q2r(w, theta):\n    return 2.0 * (torch.arctan(theta / w)) / theta\n\n\ndef log_q2r_taylor_theta_0(w, theta):\n    return 2.0 / w - 2.0 / 3.0 * (theta**2) / (w * w * w)\n\n\ndef log_q2r_lim_w_0_left(theta):\n    return -torch.pi / theta\n\n\ndef log_q2r_lim_w_0_right(theta):\n    return torch.pi / theta\n\n\ndef SE3_to_se3(Rt, eps=1e-8):  # [...,3,4]\n    R, t = Rt.split([3, 1], dim=-1)\n    w = SO3_to_so3(R)\n    wx = skew_symmetric(w)\n    theta = w.norm(dim=-1)[..., None, None]\n    I = torch.eye(3, device=w.device, dtype=torch.float32)\n    A = taylor_A(theta)\n    B = taylor_B(theta)\n    invV = I - 0.5 * wx + (1 - A / (2 * B)) / (theta**2 + eps) * wx @ wx\n    u = (invV @ t)[..., 0]\n    wu = torch.cat([w, u], dim=-1)\n    return wu\n\n\ndef SO3_to_so3(R, eps=1e-7):  # [...,3,3]\n    trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]\n    theta = ((trace - 1) / 2).clamp(-1 + eps, 1 - eps).acos_()[\n        ..., None, None\n    ] % np.pi  # ln(R) will explode if theta==pi\n    lnR = (\n        1 / (2 * taylor_A(theta) + 1e-8) * (R - R.transpose(-2, -1))\n    )  # FIXME: wei-chiu finds it weird\n    w0, w1, w2 = lnR[..., 2, 1], lnR[..., 0, 2], lnR[..., 1, 0]\n    w = torch.stack([w0, w1, w2], dim=-1)\n    return w\n\n\ndef se3_to_SE3(wu):  # [...,3]\n    w, u = wu.split([3, 3], dim=-1)\n    wx = skew_symmetric(w)  # wx=[0 -w(2) w(1);w(2) 0 -w(0);-w(1) w(0) 0]\n    theta = w.norm(dim=-1)[..., None, None]  # theta=sqrt(w'*w)\n    I = torch.eye(3, device=w.device, dtype=torch.float32)\n    A = taylor_A(theta)\n    B = taylor_B(theta)\n    C = taylor_C(theta)\n    R = I + A * wx + B * wx @ wx\n    V = I + B * wx + C * wx @ wx\n    Rt = torch.cat([R, (V @ u[..., None])], dim=-1)\n    return Rt\n\n\ndef SE3_to_se3_N(poses_rt):\n    poses_se3_list = []\n    for i in range(poses_rt.shape[0]):\n        pose_se3 = SE3_to_se3(poses_rt[i])\n        poses_se3_list.append(pose_se3)\n    poses = torch.stack(poses_se3_list, 0)\n    return poses\n\n\ndef se3_to_SE3_N(poses_wu):\n    poses_se3_list = []\n    for i in range(poses_wu.shape[0]):\n        pose_se3 = se3_to_SE3(poses_wu[i])\n        poses_se3_list.append(pose_se3)\n    poses = torch.stack(poses_se3_list, 0)\n    return poses\n\n\ndef se3_2_qt_parallel(wu):\n    w, u = wu.split([3, 3], dim=-1)\n    wx = skew_symmetric(w)\n    theta = w.norm(dim=-1)[..., None, None]\n    I = torch.eye(3, device=w.device, dtype=torch.float32)\n    # A = taylor_A(theta)\n    B = taylor_B(theta)\n    C = taylor_C(theta)\n    # R = I + A * wx + B * wx @ wx\n    V = I + B * wx + C * wx @ wx\n    t = V @ u[..., None]\n    q = exp_r2q_parallel(w)\n    return q, t.squeeze(-1)\n\n\ndef SplineN_linear(start_pose, end_pose, poses_number, NUM, device=None):\n    pose_time = poses_number / (NUM - 1)\n\n    # parallel\n    pos_0 = torch.where(pose_time == 0)\n    
pose_time[pos_0] = pose_time[pos_0] + 0.000001\n    pos_1 = torch.where(pose_time == 1)\n    pose_time[pos_1] = pose_time[pos_1] - 0.000001\n\n    q_start, t_start = se3_2_qt_parallel(start_pose)\n    q_end, t_end = se3_2_qt_parallel(end_pose)\n    # sample t_vector\n    t_t = (1 - pose_time)[..., None] * t_start + pose_time[..., None] * t_end\n\n    # sample rotation_vector\n    q_tau_0 = q_to_Q_parallel(q_to_q_conj_parallel(q_start)) @ q_end[..., None]\n    r = pose_time[..., None] * log_q2r_parallel(q_tau_0.squeeze(-1))\n    q_t_0 = exp_r2q_parallel(r)\n    q_t = q_to_Q_parallel(q_start) @ q_t_0[..., None]\n\n    # convert q&t to RT\n    R = q_to_R_parallel(q_t.squeeze(dim=-1))\n    t = t_t.unsqueeze(dim=-1)\n    pose_spline = torch.cat([R, t], -1)\n\n    poses = pose_spline.reshape([-1, 3, 4])\n\n    return poses\n\n\ndef SplineN_cubic(pose0, pose1, pose2, pose3, poses_number, NUM):\n    sample_time = poses_number / (NUM - 1)\n    # parallel\n\n    pos_0 = torch.where(sample_time == 0)\n    sample_time[pos_0] = sample_time[pos_0] + 0.000001\n    pos_1 = torch.where(sample_time == 1)\n    sample_time[pos_1] = sample_time[pos_1] - 0.000001\n\n    sample_time = sample_time.unsqueeze(-1)\n\n    q0, t0 = se3_2_qt_parallel(pose0)\n    q1, t1 = se3_2_qt_parallel(pose1)\n    q2, t2 = se3_2_qt_parallel(pose2)\n    q3, t3 = se3_2_qt_parallel(pose3)\n\n    u = sample_time\n    uu = sample_time**2\n    uuu = sample_time**3\n    one_over_six = 1.0 / 6.0\n    half_one = 0.5\n\n    # t\n    coeff0 = one_over_six - half_one * u + half_one * uu - one_over_six * uuu\n    coeff1 = 4 * one_over_six - uu + half_one * uuu\n    coeff2 = one_over_six + half_one * u + half_one * uu - half_one * uuu\n    coeff3 = one_over_six * uuu\n\n    # spline t\n    t_t = coeff0 * t0 + coeff1 * t1 + coeff2 * t2 + coeff3 * t3\n\n    # R\n    coeff1_r = 5 * one_over_six + half_one * u - half_one * uu + one_over_six * uuu\n    coeff2_r = one_over_six + half_one * u + half_one * uu - 2 * one_over_six * uuu\n    coeff3_r = one_over_six * uuu\n\n    # spline R\n    q_01 = q_to_Q_parallel(q_to_q_conj_parallel(q0)) @ q1[..., None]  # [1]\n    q_12 = q_to_Q_parallel(q_to_q_conj_parallel(q1)) @ q2[..., None]  # [2]\n    q_23 = q_to_Q_parallel(q_to_q_conj_parallel(q2)) @ q3[..., None]  # [3]\n\n    r_01 = log_q2r_parallel(q_01.squeeze(-1)) * coeff1_r  # [4]\n    r_12 = log_q2r_parallel(q_12.squeeze(-1)) * coeff2_r  # [5]\n    r_23 = log_q2r_parallel(q_23.squeeze(-1)) * coeff3_r  # [6]\n\n    q_t_0 = exp_r2q_parallel(r_01)  # [7]\n    q_t_1 = exp_r2q_parallel(r_12)  # [8]\n    q_t_2 = exp_r2q_parallel(r_23)  # [9]\n\n    q_product1 = q_to_Q_parallel(q_t_1) @ q_t_2[..., None]  # [10]\n    q_product2 = q_to_Q_parallel(q_t_0) @ q_product1  # [10]\n    q_t = q_to_Q_parallel(q0) @ q_product2  # [10]\n\n    R = q_to_R_parallel(q_t.squeeze(-1))\n    t = t_t.unsqueeze(dim=-1)\n\n    pose_spline = torch.cat([R, t], -1)\n\n    poses = pose_spline.reshape([-1, 3, 4])\n\n    return poses\n"
  },
  {
    "path": "configs/cozy2room-cubic.txt",
    "content": "expname = cozy2room-cubic\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/blurcozy2room\ndataset_type = llff\n\nfactor = 1\n\nlinear = False\npose_lrate = 1e-4\n\nnovel_view = True\nfactor_pose_novel = 20.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/cozy2room.txt",
    "content": "expname = cozy2room-linear\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/blurcozy2room\ndataset_type = llff\n\nfactor = 1\n\nlinear = True\n\nnovel_view = True\nfactor_pose_novel = 2.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/factory-cubic.txt",
    "content": "expname = factory-cubic\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/blurfactory\ndataset_type = llff\n\nfactor = 1\n\nlinear = False\npose_lrate = 1e-4\n\nnovel_view = True\nfactor_pose_novel = 200.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/factory.txt",
    "content": "expname = factory-linear\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/blurfactory\ndataset_type = llff\n\nfactor = 1\n\nlinear = True\n\nnovel_view = True\nfactor_pose_novel = 20.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/pool-cubic.txt",
    "content": "expname = pool-cubic\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/blurpool\ndataset_type = llff\n\nfactor = 1\n\nlinear = False\npose_lrate = 1e-4\n\nnovel_view = True\nfactor_pose_novel = 20.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/pool.txt",
    "content": "expname = pool-linear\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/blurpool\ndataset_type = llff\n\nfactor = 1\n\nlinear = True\n\nnovel_view = True\nfactor_pose_novel = 2.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/roomdark-cubic.txt",
    "content": "expname = roomdark-cubic\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/dark\ndataset_type = llff\n\nfactor = 1\n\nlinear = False\npose_lrate = 1e-4\n\nnovel_view = False\nfactor_pose_novel = 2.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/roomdark.txt",
    "content": "expname = roomdark-linear\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/dark\ndataset_type = llff\n\nfactor = 1\n\nlinear = True\n\nnovel_view = False\nfactor_pose_novel = 2.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/roomhigh-cubic.txt",
    "content": "expname = roomhigh-cubic\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/roomblur_high\ndataset_type = llff\n\nfactor = 1\n\nlinear = False\npose_lrate = 1e-4\n\nnovel_view = False\nfactor_pose_novel = 2.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/roomhigh.txt",
    "content": "expname = roomhigh-linear\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/roomblur_high\ndataset_type = llff\n\nfactor = 1\n\nlinear = True\n\nnovel_view = False\nfactor_pose_novel = 2.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/roomlow-cubic.txt",
    "content": "expname = roomlow-cubic\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/roomblur_low\ndataset_type = llff\n\nfactor = 1\n\nlinear = False\npose_lrate = 1e-4\n\nnovel_view = False\nfactor_pose_novel = 2.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/roomlow.txt",
    "content": "expname = roomlow-linear\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/roomblur_low\ndataset_type = llff\n\nfactor = 1\n\nlinear = True\n\nnovel_view = False\nfactor_pose_novel = 2.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/tanabata-cubic.txt",
    "content": "expname = tanabata-cubic\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/blurtanabata\ndataset_type = llff\n\nfactor = 1\n\nlinear = False\npose_lrate = 1e-4\n\nnovel_view = True\nfactor_pose_novel = 50.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/tanabata.txt",
    "content": "expname = tanabata-linear\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/blurtanabata\ndataset_type = llff\n\nfactor = 1\n\nlinear = True\n\nnovel_view = True\nfactor_pose_novel = 5.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/wine-cubic.txt",
    "content": "expname = wine-cubic\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/blurwine\ndataset_type = llff\n\nfactor = 1\n\nlinear = False\npose_lrate = 1e-4\n\nnovel_view = True\nfactor_pose_novel = 20.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "configs/wine.txt",
    "content": "expname = wine-linear\nbasedir = ./logs\ndatadir = ./data/nerf_llff_data/blurwine\ndataset_type = llff\n\nfactor = 1\n\nlinear = True\n\nnovel_view = True\nfactor_pose_novel = 2.0\ni_novel_view = 200000\n\nN_rand = 5000\ndeblur_images = 7\n\nN_samples = 64\nN_importance = 64\n\nuse_viewdirs = True\nraw_noise_std = 1.0\n\nload_weights = False\nweight_iter = 200000\n\ni_img = 25000\ni_video = 200000\ni_weights = 10000\n"
  },
  {
    "path": "load_llff.py",
    "content": "import numpy as np\nimport os, imageio\nimport torch\n\n\n########## Slightly modified version of LLFF data loading code\n##########  see https://github.com/Fyusion/LLFF for original\n\n# downsample\ndef _minify(basedir, factors=[], resolutions=[]):  # basedir: ./data/nerf_llff_data/fern\n    needtoload = False\n    for r in factors:\n        imgdir = os.path.join(basedir, 'images_{}'.format(r))\n        if not os.path.exists(imgdir):\n            needtoload = True\n    for r in resolutions:\n        imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))\n        if not os.path.exists(imgdir):\n            needtoload = True\n    if not needtoload:\n        return\n\n    from shutil import copy\n    from subprocess import check_output\n\n    imgdir = os.path.join(basedir, 'images')\n    imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]\n    imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]\n    imgdir_orig = imgdir\n\n    wd = os.getcwd()\n\n    for r in factors + resolutions:\n        if isinstance(r, int):\n            name = 'images_{}'.format(r)\n            resizearg = '{}%'.format(100. / r)\n        else:\n            name = 'images_{}x{}'.format(r[1], r[0])\n            resizearg = '{}x{}'.format(r[1], r[0])\n        imgdir = os.path.join(basedir, name)\n        if os.path.exists(imgdir):\n            continue\n\n        print('Minifying', r, basedir)\n\n        os.makedirs(imgdir)\n        check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)\n\n        ext = imgs[0].split('.')[-1]\n        args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])\n        print(args)\n        os.chdir(imgdir)\n        check_output(args, shell=True)\n        os.chdir(wd)\n\n        if ext != 'png':\n            check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)\n            print('Removed duplicates')\n        print('Done')\n\n\ndef _load_data(basedir, pose_state, factor=None, width=None, height=None, load_imgs=True):\n\n    poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))\n\n    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])\n    bds = poses_arr[:, -2:].transpose([1, 0])\n\n    img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \\\n            if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]\n    sh = imageio.imread(img0).shape\n\n    sfx = ''\n\n    if factor is not None:\n        sfx = '_{}'.format(factor)\n        _minify(basedir, factors=[factor])\n        factor = factor\n    elif height is not None:\n        factor = sh[0] / float(height)\n        width = int(sh[1] / factor)\n        _minify(basedir, resolutions=[[height, width]])\n        sfx = '_{}x{}'.format(width, height)\n    elif width is not None:\n        factor = sh[1] / float(width)\n        height = int(sh[0] / factor)\n        _minify(basedir, resolutions=[[height, width]])\n        sfx = '_{}x{}'.format(width, height)\n    else:\n        factor = 1\n\n    imgdir = os.path.join(basedir, 'images' + sfx)\n    if not os.path.exists(imgdir):\n        print(imgdir, 'does not exist, returning')\n        return\n\n    imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if\n                f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]\n\n    if poses.shape[-1] != len(imgfiles):\n        print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), 
poses.shape[-1]) )\n        # return\n\n    sh = imageio.imread(imgfiles[0]).shape\n    poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])\n    poses[2, 4, :] = poses[2, 4, :] * 1. / factor\n\n    if not load_imgs:\n        return poses, bds\n\n    def imread(f):\n        if f.endswith('png'):\n            return imageio.imread(f, ignoregamma=True)\n        else:\n            return imageio.imread(f)\n\n    imgs = [imread(f)[..., :3] / 255. for f in imgfiles]\n    imgs = np.stack(imgs, -1)\n\n    print('Loaded image data', imgs.shape, poses[:, -1, 0])\n\n    return poses, bds, imgs\n\n\ndef normalize(x):\n    return x / np.linalg.norm(x)\n\n\ndef viewmatrix(z, up, pos):\n    vec2 = normalize(z)\n    vec1_avg = up\n    vec0 = normalize(np.cross(vec1_avg, vec2))\n    vec1 = normalize(np.cross(vec2, vec0))\n    m = np.stack([vec0, vec1, vec2, pos], 1)\n    return m\n\n\ndef ptstocam(pts, c2w):\n    tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0]\n    return tt\n\n\ndef poses_avg(poses):\n    hwf = poses[0, :3, -1:]\n\n    center = poses[:, :3, 3].mean(0)\n    vec2 = normalize(poses[:, :3, 2].sum(0))\n    up = poses[:, :3, 1].sum(0)\n    c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)\n\n    return c2w\n\n\ndef render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):\n    render_poses = []\n    rads = np.array(list(rads) + [1.])\n    hwf = c2w[:, 4:5]\n\n    for theta in np.linspace(0., 2. * np.pi * rots, N + 1)[:-1]:\n        c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)\n        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))\n        render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))\n    return render_poses\n\n\ndef recenter_poses(poses):\n    poses_ = poses + 0\n    bottom = np.reshape([0, 0, 0, 1.], [1, 4])\n    c2w = poses_avg(poses)\n    c2w = np.concatenate([c2w[:3, :4], bottom], -2)\n    bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])\n    poses = np.concatenate([poses[:, :3, :4], bottom], -2)\n\n    poses = np.linalg.inv(c2w) @ poses\n    poses_[:, :3, :4] = poses[:, :3, :4]\n    poses = poses_\n    return poses\n\n\n#####################\n\n\ndef spherify_poses(poses, bds):\n    p34_to_44 = lambda p: np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1)\n\n    rays_d = poses[:, :3, 2:3]\n    rays_o = poses[:, :3, 3:4]\n\n    def min_line_dist(rays_o, rays_d):\n        A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])\n        b_i = -A_i @ rays_o\n        pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0))\n        return pt_mindist\n\n    pt_mindist = min_line_dist(rays_o, rays_d)\n\n    center = pt_mindist\n    up = (poses[:, :3, 3] - center).mean(0)\n\n    vec0 = normalize(up)\n    vec1 = normalize(np.cross([.1, .2, .3], vec0))\n    vec2 = normalize(np.cross(vec0, vec1))\n    pos = center\n    c2w = np.stack([vec1, vec2, vec0, pos], 1)\n\n    poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])\n\n    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))\n\n    sc = 1. / rad\n    poses_reset[:, :3, 3] *= sc\n    bds *= sc\n    rad *= sc\n\n    centroid = np.mean(poses_reset[:, :3, 3], 0)\n    zh = centroid[2]\n    radcircle = np.sqrt(rad ** 2 - zh ** 2)\n    new_poses = []\n\n    for th in np.linspace(0., 2. 
* np.pi, 120):\n        camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])\n        up = np.array([0, 0, -1.])\n\n        vec2 = normalize(camorigin)\n        vec0 = normalize(np.cross(vec2, up))\n        vec1 = normalize(np.cross(vec2, vec0))\n        pos = camorigin\n        p = np.stack([vec0, vec1, vec2, pos], 1)\n\n        new_poses.append(p)\n\n    new_poses = np.stack(new_poses, 0)\n\n    new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1)\n    poses_reset = np.concatenate(\n        [poses_reset[:, :3, :4], np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)], -1)\n\n    return poses_reset, new_poses, bds\n\n\ndef load_llff_data(basedir, pose_state=None, factor=1, recenter=True, bd_factor=.75, spherify=False,\n                   path_zflat=False):\n    poses, bds, imgs = _load_data(basedir, pose_state=pose_state,factor=factor)\n    print('Loaded', basedir, bds.min(), bds.max(), 'Pose State: ', pose_state)\n\n    poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)\n    poses = np.moveaxis(poses, -1, 0).astype(np.float32)\n    imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)\n    bds = np.moveaxis(bds, -1, 0).astype(np.float32)\n\n    # Rescale if bd_factor is provided\n    sc = 1. if bd_factor is None else 1. / (bds.min() * bd_factor)\n    poses[:, :3, 3] *= sc  # T\n    bds *= sc\n\n    if recenter:\n        poses = recenter_poses(poses)\n\n    if spherify:\n        poses, render_poses, bds = spherify_poses(poses, bds)\n\n    else:\n        c2w = poses_avg(poses)\n        up = normalize(poses[:, :3, 1].sum(0))\n\n        # Find a reasonable \"focus depth\" for this dataset\n        close_depth, inf_depth = bds.min() * .9, bds.max() * 5.\n        dt = .75\n        mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth))\n        focal = mean_dz\n\n        # Get radii for spiral path\n        shrink_factor = .8\n        zdelta = close_depth * .2\n        tt = poses[:, :3, 3]  # ptstocam(poses[:3,3,:].T, c2w).T\n        rads = np.percentile(np.abs(tt), 90, 0)\n        c2w_path = c2w\n        N_views = 120\n        N_rots = 2\n        if path_zflat:\n            #             zloc = np.percentile(tt, 10, 0)[2]\n            zloc = -close_depth * .1\n            c2w_path[:3, 3] = c2w_path[:3, 3] + zloc * c2w_path[:3, 2]\n            rads[2] = 0.\n            N_rots = 1\n            N_views /= 2\n\n        # Generate poses for spiral path\n        render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views)\n\n    render_poses = np.array(render_poses).astype(np.float32)\n\n    imgs = torch.Tensor(imgs)\n    poses = torch.Tensor(poses)\n    bds = torch.Tensor(bds)\n    render_poses = torch.Tensor(render_poses)\n\n    return imgs, poses, bds, render_poses\n\n"
  },
  {
    "path": "lpips/__init__.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport torch\n# from torch.autograd import Variable\n\nfrom .lpips import *\n"
  },
  {
    "path": "lpips/lpips.py",
    "content": "from __future__ import absolute_import\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom torch.autograd import Variable\nimport numpy as np\nfrom . import pretrained_networks as pn\nimport torch.nn\n\n\ndef normalize_tensor(in_feat, eps=1e-10):\n    norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True))\n    return in_feat / (norm_factor + eps)\n\n\ndef l2(p0, p1, range=255.):\n    return .5 * np.mean((p0 / range - p1 / range) ** 2)\n\n\ndef psnr(p0, p1, peak=255.):\n    return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2))\n\n\ndef dssim(p0, p1, range=255.):\n    from skimage.measure import compare_ssim\n    return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.\n\n\ndef rgb2lab(in_img, mean_cent=False):\n    from skimage import color\n    img_lab = color.rgb2lab(in_img)\n    if mean_cent:\n        img_lab[:, :, 0] = img_lab[:, :, 0] - 50\n    return img_lab\n\n\ndef tensor2np(tensor_obj):\n    # change dimension of a tensor object into a numpy array\n    return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))\n\n\ndef np2tensor(np_obj):\n    # change dimenion of np array into tensor array\n    return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))\n\n\ndef tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):\n    # image tensor to lab tensor\n    from skimage import color\n\n    img = tensor2im(image_tensor)\n    img_lab = color.rgb2lab(img)\n    if mc_only:\n        img_lab[:, :, 0] = img_lab[:, :, 0] - 50\n    if to_norm and not mc_only:\n        img_lab[:, :, 0] = img_lab[:, :, 0] - 50\n        img_lab = img_lab / 100.\n\n    return np2tensor(img_lab)\n\n\ndef tensorlab2tensor(lab_tensor, return_inbnd=False):\n    from skimage import color\n    import warnings\n    warnings.filterwarnings(\"ignore\")\n\n    lab = tensor2np(lab_tensor) * 100.\n    lab[:, :, 0] = lab[:, :, 0] + 50\n\n    rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1)\n    if return_inbnd:\n        # convert back to lab, see if we match\n        lab_back = color.rgb2lab(rgb_back.astype('uint8'))\n        mask = 1. * np.isclose(lab_back, lab, atol=2.)\n        mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])\n        return im2tensor(rgb_back), mask\n    else:\n        return im2tensor(rgb_back)\n\n\ndef tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):\n    image_numpy = image_tensor[0].cpu().float().numpy()\n    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor\n    return image_numpy.astype(imtype)\n\n\ndef im2tensor(image, imtype=np.uint8, cent=1., factor=255. 
/ 2.):\n    return torch.tensor((image / factor - cent)\n                        [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))\n\n\ndef tensor2vec(vector_tensor):\n    return vector_tensor.data.cpu().numpy()[:, :, 0, 0]\n\n\ndef voc_ap(rec, prec, use_07_metric=False):\n    \"\"\" ap = voc_ap(rec, prec, [use_07_metric])\n    Compute VOC AP given precision and recall.\n    If use_07_metric is true, uses the\n    VOC 07 11 point method (default:False).\n    \"\"\"\n    if use_07_metric:\n        # 11 point metric\n        ap = 0.\n        for t in np.arange(0., 1.1, 0.1):\n            if np.sum(rec >= t) == 0:\n                p = 0\n            else:\n                p = np.max(prec[rec >= t])\n            ap = ap + p / 11.\n    else:\n        # correct AP calculation\n        # first append sentinel values at the end\n        mrec = np.concatenate(([0.], rec, [1.]))\n        mpre = np.concatenate(([0.], prec, [0.]))\n\n        # compute the precision envelope\n        for i in range(mpre.size - 1, 0, -1):\n            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])\n\n        # to calculate area under PR curve, look for points\n        # where X axis (recall) changes value\n        i = np.where(mrec[1:] != mrec[:-1])[0]\n\n        # and sum (\\Delta recall) * prec\n        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])\n    return ap\n\ndef spatial_average(in_tens, keepdim=True):\n    return in_tens.mean([2, 3], keepdim=keepdim)\n\n\ndef upsample(in_tens, out_HW=(64, 64)):  # assumes scale factor is same for H and W\n    in_H, in_W = in_tens.shape[2], in_tens.shape[3]\n    return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)\n\n\n# Learned perceptual metric\nclass LPIPS(nn.Module):\n    def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,\n                 pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True):\n        # lpips - [True] means with linear calibration on top of base network\n        # pretrained - [True] means load linear weights\n\n        super(LPIPS, self).__init__()\n        if verbose:\n            print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' %\n                  ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))\n\n        self.pnet_type = net\n        self.pnet_tune = pnet_tune\n        self.pnet_rand = pnet_rand\n        self.spatial = spatial\n        self.lpips = lpips  # false means baseline of just averaging all layers\n        self.version = version\n        self.scaling_layer = ScalingLayer()\n\n        if self.pnet_type in ['vgg', 'vgg16']:\n            net_type = pn.vgg16\n            self.chns = [64, 128, 256, 512, 512]\n        elif self.pnet_type == 'alex':\n            net_type = pn.alexnet\n            self.chns = [64, 192, 384, 256, 256]\n        elif self.pnet_type == 'squeeze':\n            net_type = pn.squeezenet\n            self.chns = [64, 128, 256, 384, 384, 512, 512]\n        self.L = len(self.chns)\n\n        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)\n\n        if lpips:\n            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)\n            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)\n            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)\n            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)\n            self.lin4 = NetLinLayer(self.chns[4], 
use_dropout=use_dropout)\n            self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]\n            if self.pnet_type == 'squeeze':  # 7 layers for squeezenet\n                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)\n                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)\n                self.lins += [self.lin5, self.lin6]\n            self.lins = nn.ModuleList(self.lins)\n\n            if pretrained:\n                if model_path is None:\n                    import inspect\n                    import os\n                    model_path = os.path.abspath(\n                        os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))\n\n                if verbose:\n                    print('Loading model from: %s' % model_path)\n                self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)\n\n        if eval_mode:\n            self.eval()\n\n    def forward(self, in0, in1, retPerLayer=False, normalize=False):\n        if normalize:  # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]\n            in0 = 2 * in0 - 1\n            in1 = 2 * in1 - 1\n\n        # v0.0 - original release had a bug, where input was not scaled\n        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else (\n        in0, in1)\n        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)\n        feats0, feats1, diffs = {}, {}, {}\n\n        for kk in range(self.L):\n            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])\n            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2\n\n        if self.lpips:\n            if self.spatial:\n                res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]\n            else:\n                res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]\n        else:\n            if self.spatial:\n                res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]\n            else:\n                res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]\n\n        val = res[0]\n        for l in range(1, self.L):\n            val += res[l]\n\n        # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True)\n        # b = torch.max(self.lins[kk](feats0[kk]**2))\n        # for kk in range(self.L):\n        #     a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True)\n        #     b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2)))\n        # a = a/self.L\n        # from IPython import embed\n        # embed()\n        # return 10*torch.log10(b/a)\n\n        if retPerLayer:\n            return val, res\n        else:\n            return val\n\n\nclass ScalingLayer(nn.Module):\n    def __init__(self):\n        super(ScalingLayer, self).__init__()\n        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])\n        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])\n\n    def forward(self, inp):\n        return (inp - self.shift) / self.scale\n\n\nclass NetLinLayer(nn.Module):\n    ''' A single linear layer which does a 1x1 conv '''\n\n    def __init__(self, chn_in, chn_out=1, use_dropout=False):\n        super(NetLinLayer, self).__init__()\n\n        layers = [nn.Dropout(), ] if 
use_dropout else []\n        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]\n        self.model = nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.model(x)\n\n\nclass Dist2LogitLayer(nn.Module):\n    ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''\n\n    def __init__(self, chn_mid=32, use_sigmoid=True):\n        super(Dist2LogitLayer, self).__init__()\n\n        layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ]\n        layers += [nn.LeakyReLU(0.2, True), ]\n        layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ]\n        layers += [nn.LeakyReLU(0.2, True), ]\n        layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ]\n        if use_sigmoid:\n            layers += [nn.Sigmoid(), ]\n        self.model = nn.Sequential(*layers)\n\n    def forward(self, d0, d1, eps=0.1):\n        return self.model.forward(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1))\n\n\nclass BCERankingLoss(nn.Module):\n    def __init__(self, chn_mid=32):\n        super(BCERankingLoss, self).__init__()\n        self.net = Dist2LogitLayer(chn_mid=chn_mid)\n        # self.parameters = list(self.net.parameters())\n        self.loss = torch.nn.BCELoss()\n\n    def forward(self, d0, d1, judge):\n        per = (judge + 1.) / 2.\n        self.logit = self.net.forward(d0, d1)\n        return self.loss(self.logit, per)\n\n\n# L2, DSSIM metrics\nclass FakeNet(nn.Module):\n    def __init__(self, use_gpu=True, colorspace='Lab'):\n        super(FakeNet, self).__init__()\n        self.use_gpu = use_gpu\n        self.colorspace = colorspace\n\n\nclass L2(FakeNet):\n    def forward(self, in0, in1, retPerLayer=None):\n        assert (in0.size()[0] == 1)  # currently only supports batchSize 1\n\n        if self.colorspace == 'RGB':\n            (N, C, X, Y) = in0.size()\n            value = torch.mean(torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y),\n                               dim=3).view(N)\n            return value\n        elif self.colorspace == 'Lab':\n            value = l2(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),\n                             tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype(\n                'float')\n            ret_var = Variable(torch.Tensor((value,)))\n            if self.use_gpu:\n                ret_var = ret_var.cuda()\n            return ret_var\n\n\nclass DSSIM(FakeNet):\n\n    def forward(self, in0, in1, retPerLayer=None):\n        assert (in0.size()[0] == 1)  # currently only supports batchSize 1\n\n        if self.colorspace == 'RGB':\n            value = dssim(1. * tensor2im(in0.data), 1. * tensor2im(in1.data), range=255.).astype(\n                'float')\n        elif self.colorspace == 'Lab':\n            value = dssim(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),\n                                tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype(\n                'float')\n        ret_var = Variable(torch.Tensor((value,)))\n        if self.use_gpu:\n            ret_var = ret_var.cuda()\n        return ret_var\n\n\ndef print_network(net):\n    num_params = 0\n    for param in net.parameters():\n        num_params += param.numel()\n    print('Network', net)\n    print('Total number of parameters: %d' % num_params)\n"
  },
  {
    "path": "lpips/pretrained_networks.py",
    "content": "from collections import namedtuple\nimport torch\nfrom torchvision import models as tv\n\n\nclass squeezenet(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(squeezenet, self).__init__()\n        pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.slice6 = torch.nn.Sequential()\n        self.slice7 = torch.nn.Sequential()\n        self.N_slices = 7\n        for x in range(2):\n            self.slice1.add_module(str(x), pretrained_features[x])\n        for x in range(2, 5):\n            self.slice2.add_module(str(x), pretrained_features[x])\n        for x in range(5, 8):\n            self.slice3.add_module(str(x), pretrained_features[x])\n        for x in range(8, 10):\n            self.slice4.add_module(str(x), pretrained_features[x])\n        for x in range(10, 11):\n            self.slice5.add_module(str(x), pretrained_features[x])\n        for x in range(11, 12):\n            self.slice6.add_module(str(x), pretrained_features[x])\n        for x in range(12, 13):\n            self.slice7.add_module(str(x), pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1 = h\n        h = self.slice2(h)\n        h_relu2 = h\n        h = self.slice3(h)\n        h_relu3 = h\n        h = self.slice4(h)\n        h_relu4 = h\n        h = self.slice5(h)\n        h_relu5 = h\n        h = self.slice6(h)\n        h_relu6 = h\n        h = self.slice7(h)\n        h_relu7 = h\n        vgg_outputs = namedtuple(\"SqueezeOutputs\", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'])\n        out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)\n\n        return out\n\n\nclass alexnet(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(alexnet, self).__init__()\n        alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.N_slices = 5\n        for x in range(2):\n            self.slice1.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(2, 5):\n            self.slice2.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(5, 8):\n            self.slice3.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(8, 10):\n            self.slice4.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(10, 12):\n            self.slice5.add_module(str(x), alexnet_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1 = h\n        h = self.slice2(h)\n        h_relu2 = h\n        h = self.slice3(h)\n        h_relu3 = h\n        h = self.slice4(h)\n        h_relu4 = h\n        h = self.slice5(h)\n        h_relu5 = h\n        alexnet_outputs = 
namedtuple(\"AlexnetOutputs\", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])\n        out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)\n\n        return out\n\n\nclass vgg16(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(vgg16, self).__init__()\n        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.N_slices = 5\n        for x in range(4):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(4, 9):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(9, 16):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(16, 23):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(23, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1_2 = h\n        h = self.slice2(h)\n        h_relu2_2 = h\n        h = self.slice3(h)\n        h_relu3_3 = h\n        h = self.slice4(h)\n        h_relu4_3 = h\n        h = self.slice5(h)\n        h_relu5_3 = h\n        vgg_outputs = namedtuple(\"VggOutputs\", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])\n        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)\n\n        return out\n\n\nclass resnet(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True, num=18):\n        super(resnet, self).__init__()\n        if (num == 18):\n            self.net = tv.resnet18(pretrained=pretrained)\n        elif (num == 34):\n            self.net = tv.resnet34(pretrained=pretrained)\n        elif (num == 50):\n            self.net = tv.resnet50(pretrained=pretrained)\n        elif (num == 101):\n            self.net = tv.resnet101(pretrained=pretrained)\n        elif (num == 152):\n            self.net = tv.resnet152(pretrained=pretrained)\n        self.N_slices = 5\n\n        self.conv1 = self.net.conv1\n        self.bn1 = self.net.bn1\n        self.relu = self.net.relu\n        self.maxpool = self.net.maxpool\n        self.layer1 = self.net.layer1\n        self.layer2 = self.net.layer2\n        self.layer3 = self.net.layer3\n        self.layer4 = self.net.layer4\n\n    def forward(self, X):\n        h = self.conv1(X)\n        h = self.bn1(h)\n        h = self.relu(h)\n        h_relu1 = h\n        h = self.maxpool(h)\n        h = self.layer1(h)\n        h_conv2 = h\n        h = self.layer2(h)\n        h_conv3 = h\n        h = self.layer3(h)\n        h_conv4 = h\n        h = self.layer4(h)\n        h_conv5 = h\n\n        outputs = namedtuple(\"Outputs\", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])\n        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)\n\n        return out\n"
  },
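  {
    "path": "doc/examples/perceptual_features_sketch.py",
    "content": "# Illustrative sketch (not part of the training pipeline): how the backbone\n# wrappers above expose intermediate activations for an LPIPS-style perceptual\n# metric. The module path `lpips.pretrained_networks` is an assumption based on\n# the `from lpips.lpips import LPIPS` import in metrics.py; shapes in the\n# comment are for a 64x64 RGB input.\nimport torch\nfrom lpips.pretrained_networks import vgg16\n\nnet = vgg16(requires_grad=False, pretrained=True).eval()\nx = torch.randn(1, 3, 64, 64)  # dummy RGB batch\nwith torch.no_grad():\n    feats = net(x)\nfor name, f in zip(feats._fields, feats):\n    print(name, tuple(f.shape))  # relu1_2 (1, 64, 64, 64) ... relu5_3 (1, 512, 4, 4)\n"
  },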
  {
    "path": "metrics.py",
    "content": "from skimage import metrics\nimport torch\nimport torch.hub\nfrom lpips.lpips import LPIPS\nimport os\nimport numpy as np\n\nphotometric = {\n    \"mse\": None,\n    \"ssim\": None,\n    \"psnr\": None,\n    \"lpips\": None\n}\n\n\ndef compute_img_metric(im1t: torch.Tensor, im2t: torch.Tensor,\n                       metric=\"mse\", margin=0, mask=None):\n    \"\"\"\n    im1t, im2t: torch.tensors with batched imaged shape, range from (0, 1)\n    \"\"\"\n    if metric not in photometric.keys():\n        raise RuntimeError(f\"img_utils:: metric {metric} not recognized\")\n    if photometric[metric] is None:\n        if metric == \"mse\":\n            photometric[metric] = metrics.mean_squared_error\n        elif metric == \"ssim\":\n            photometric[metric] = metrics.structural_similarity\n        elif metric == \"psnr\":\n            photometric[metric] = metrics.peak_signal_noise_ratio\n        elif metric == \"lpips\":\n            photometric[metric] = LPIPS().cpu()\n\n    if mask is not None:\n        if mask.dim() == 3:\n            mask = mask.unsqueeze(1)\n        if mask.shape[1] == 1:\n            mask = mask.expand(-1, 3, -1, -1)\n        mask = mask.permute(0, 2, 3, 1).numpy()\n        batchsz, hei, wid, _ = mask.shape\n        if margin > 0:\n            marginh = int(hei * margin) + 1\n            marginw = int(wid * margin) + 1\n            mask = mask[:, marginh:hei - marginh, marginw:wid - marginw]\n\n    # convert from [0, 1] to [-1, 1]\n    im1t = (im1t * 2 - 1).clamp(-1, 1)\n    im2t = (im2t * 2 - 1).clamp(-1, 1)\n\n    if im1t.dim() == 3:\n        im1t = im1t.unsqueeze(0)\n        im2t = im2t.unsqueeze(0)\n    im1t = im1t.detach().cpu()\n    im2t = im2t.detach().cpu()\n\n    if im1t.shape[-1] == 3:\n        im1t = im1t.permute(0, 3, 1, 2)\n        im2t = im2t.permute(0, 3, 1, 2)\n\n    im1 = im1t.permute(0, 2, 3, 1).numpy()\n    im2 = im2t.permute(0, 2, 3, 1).numpy()\n    batchsz, hei, wid, _ = im1.shape\n    if margin > 0:\n        marginh = int(hei * margin) + 1\n        marginw = int(wid * margin) + 1\n        im1 = im1[:, marginh:hei - marginh, marginw:wid - marginw]\n        im2 = im2[:, marginh:hei - marginh, marginw:wid - marginw]\n    values = []\n\n    for i in range(batchsz):\n        if metric in [\"mse\", \"psnr\"]:\n            if mask is not None:\n                im1 = im1 * mask[i]\n                im2 = im2 * mask[i]\n            value = photometric[metric](\n                im1[i], im2[i]\n            )\n            if mask is not None:\n                hei, wid, _ = im1[i].shape\n                pixelnum = mask[i, ..., 0].sum()\n                value = value - 10 * np.log10(hei * wid / pixelnum)\n        elif metric in [\"ssim\"]:\n            value, ssimmap = photometric[\"ssim\"](\n                im1[i], im2[i], multichannel=True, full=True\n            )\n            if mask is not None:\n                value = (ssimmap * mask[i]).sum() / mask[i].sum()\n        elif metric in [\"lpips\"]:\n            value = photometric[metric](\n                im1t[i:i + 1], im2t[i:i + 1]\n            )\n        else:\n            raise NotImplementedError\n        values.append(value)\n\n    return sum(values) / len(values)\n"
  },
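  {
    "path": "doc/examples/metrics_usage_sketch.py",
    "content": "# Minimal usage sketch for compute_img_metric (illustrative only). Inputs are\n# float tensors in [0, 1], either a single (H, W, 3) image or a batched\n# (N, H, W, 3) stack; the LPIPS backbone weights are typically downloaded on\n# first use. The random tensors here are placeholders for real renders.\nimport torch\nfrom metrics import compute_img_metric\n\npred = torch.rand(2, 64, 64, 3)  # e.g. rendered images\ngt = torch.rand(2, 64, 64, 3)    # e.g. ground-truth sharp images\nfor m in ('mse', 'psnr', 'ssim', 'lpips'):\n    v = compute_img_metric(gt, pred, metric=m)\n    print(m, float(v))\n"
  },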
  {
    "path": "nerf.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom run_nerf import *\n\nfrom Spline import se3_to_SE3\n\n\nmax_iter = 200000\nT = max_iter+1\nBOUNDARY = 20\n\n\nclass Embedder:\n    def __init__(self, **kwargs):\n        self.kwargs = kwargs\n        self.create_embedding_fn()\n\n    def create_embedding_fn(self):\n        embed_fns = []\n        d = self.kwargs['input_dims']\n        out_dim = 0\n        if self.kwargs['include_input']:\n            embed_fns.append(lambda x: x)\n            out_dim += d\n\n        max_freq = self.kwargs['max_freq_log2']\n        N_freqs = self.kwargs['num_freqs']\n\n        if self.kwargs['log_sampling']:\n            freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)\n        else:\n            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)\n\n        for freq in freq_bands:\n            for p_fn in self.kwargs['periodic_fns']:\n                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))\n                out_dim += d\n\n        self.embed_fns = embed_fns\n        self.out_dim = out_dim\n\n    def embed(self, inputs):\n        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)\n\n\ndef get_embedder(args, multires, i=0):\n    if i == -1:\n        return nn.Identity(), 3\n\n    embed_kwargs = {\n        'include_input': False if args.barf else True,\n        'input_dims': 3,\n        'max_freq_log2': multires - 1,\n        'num_freqs': multires,\n        'log_sampling': True,\n        'periodic_fns': [torch.sin, torch.cos],\n    }\n\n    embedder_obj = Embedder(**embed_kwargs)\n    embed = lambda x, eo=embedder_obj: eo.embed(x)\n    return embed, embedder_obj.out_dim\n\n\nclass Model():\n    def __init__(self):\n        super().__init__()\n\n    def build_network(self, args):\n        self.graph = Graph(args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True)\n\n        return self.graph\n\n    def setup_optimizer(self, args):\n        grad_vars = list(self.graph.nerf.parameters())\n        if args.N_importance>0:\n            grad_vars += list(self.graph.nerf_fine.parameters())\n        self.optim = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))\n\n        return self.optim\n\n\nclass NeRF(nn.Module):\n    def __init__(self, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=False):\n        super().__init__()\n        self.D = D\n        self.W = W\n        self.input_ch = input_ch\n        self.input_ch_views = input_ch_views\n        self.skips = skips\n        self.use_viewdirs = use_viewdirs\n\n        # network\n        self.pts_linears = nn.ModuleList(\n            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in\n                                        range(D - 1)])\n\n        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)])\n\n        if use_viewdirs:\n            self.feature_linear = nn.Linear(W, W)\n            self.alpha_linear = nn.Linear(W, 1)\n            self.rgb_linear = nn.Linear(W // 2, 3)\n        else:\n            self.output_linear = nn.Linear(W, output_ch)\n\n    def forward(self, barf_i, pts, viewdirs, args):\n        embed_fn, input_ch = get_embedder(args, args.multires, args.i_embed)\n        embeddirs_fn = None\n        if args.use_viewdirs:\n            embeddirs_fn, input_ch_views = get_embedder(args, args.multires_views, args.i_embed)\n        
pts_flat = torch.reshape(pts, [-1, pts.shape[-1]])\n        embedded = embed_fn(pts_flat)\n\n        if viewdirs is not None:\n            input_dirs = viewdirs[:, None].expand(pts.shape)\n            input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])\n            embedded_dirs = embeddirs_fn(input_dirs_flat)\n            embedded = torch.cat([embedded, embedded_dirs], -1)\n\n        input_pts, input_views = torch.split(embedded, [self.input_ch, self.input_ch_views], dim=-1)\n        h = input_pts\n        for i, l in enumerate(self.pts_linears):\n            h = self.pts_linears[i](h)\n            h = F.relu(h)\n            if i in self.skips:\n                h = torch.cat([input_pts, h], -1)\n\n        if self.use_viewdirs:\n            alpha = self.alpha_linear(h)\n            feature = self.feature_linear(h)\n            h = torch.cat([feature, input_views], -1)\n\n            for i, l in enumerate(self.views_linears):\n                h = self.views_linears[i](h)\n                h = F.relu(h)\n\n            rgb = self.rgb_linear(h)\n            outputs = torch.cat([rgb, alpha], -1)\n        else:\n            outputs = self.output_linear(h)\n\n        outputs = torch.reshape(outputs, list(pts.shape[:-1]) + [outputs.shape[-1]])\n\n        return outputs\n\n    def raw2output(self, raw, z_vals, rays_d, raw_noise_std=0.0):\n\n        raw2alpha = lambda raw, dists, act_fn=F.relu: 1. - torch.exp(-act_fn(raw) * dists)\n\n        dists = z_vals[..., 1:] - z_vals[..., :-1]\n        dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)], -1)\n\n        dists = dists * torch.norm(rays_d[..., None, :], dim=-1)\n\n        rgb = torch.sigmoid(raw[..., :3])\n        noise = 0.\n        if raw_noise_std > 0.:\n            noise = torch.randn(raw[..., 3].shape) * raw_noise_std\n\n        alpha = raw2alpha(raw[..., 3] + noise, dists)\n        weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1. - alpha + 1e-10], -1), -1)[:,:-1]\n        rgb_map = torch.sum(weights[..., None] * rgb, -2)\n\n        depth_map = torch.sum(weights * z_vals, -1)\n        # disp_map = 1. 
/ torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))\n        disp_map = torch.max(1e-6 * torch.ones_like(depth_map), depth_map / (torch.sum(weights, -1)+1e-6))\n        acc_map = torch.sum(weights, -1)\n\n        sigma = F.relu(raw[..., 3] + noise)\n\n        return rgb_map, disp_map, acc_map, weights, depth_map, sigma\n\n\nclass Graph(nn.Module):\n\n    def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=False):\n        super().__init__()\n        self.nerf = NeRF(D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs)\n        if args.N_importance > 0:\n            self.nerf_fine = NeRF(D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs)\n\n    def forward(self, i, img_idx, poses_num, H, W, K, args, novel_view=False):\n        if novel_view:\n            poses_sharp = se3_to_SE3(self.se3_sharp.weight)\n            ray_idx_sharp = torch.randperm(H * W)[:300]\n            ret = self.render(i, poses_sharp, ray_idx_sharp, H, W, K, args)\n            return ret, ray_idx_sharp, poses_sharp\n\n        spline_poses = self.get_pose(i, img_idx, args)\n        ray_idx = torch.randperm(H * W)[:args.N_rand // poses_num]\n\n        '''\n        # only used in distorted data\n        # aims to prevent the ray_idx lying on the edges\n        for j in range(ray_idx.shape[0]):\n            h = torch.randperm(H - 1)[0]\n            w = torch.randperm(W - 1)[0]\n            while (h < BOUNDARY or h > (H - 1 - BOUNDARY) or w < BOUNDARY or w > (W - 1 - BOUNDARY)):\n                h = torch.randperm(H - 1)[0]\n                w = torch.randperm(W - 1)[0]\n            index = h * W + w\n            ray_idx[j] = index\n        '''\n\n        ret = self.render(i, spline_poses, ray_idx, H, W, K, args, near=0, far=1.0, ray_idx_tv=None, training=True)\n        if (i % args.i_img == 0 or i % args.i_novel_view == 0) and i > 0:\n            if args.deblur_images % 2 == 0:\n                all_poses = self.get_pose_even(i, torch.arange(self.se3.weight.shape[0]), args.deblur_images)\n            else:\n                all_poses = self.get_pose(i, torch.arange(self.se3.weight.shape[0]), args)\n            return ret, ray_idx, spline_poses, all_poses\n        else:\n            return ret, ray_idx, spline_poses\n\n    def get_pose(self, i, img_idx, args):\n\n        return i\n\n    def get_gt_pose(self, poses, args):\n\n        return poses\n\n    def render(self, barf_i, poses, ray_idx, H, W, K, args, near=0., far=1., ray_idx_tv=None, training=False):\n        if training:\n            ray_idx_ = ray_idx.repeat(poses.shape[0])\n            poses = poses.unsqueeze(1).repeat(1, ray_idx.shape[0], 1, 1).reshape(-1, 3, 4)\n            j = ray_idx_.reshape(-1, 1).squeeze() // W\n            i = ray_idx_.reshape(-1, 1).squeeze() % W\n            rays_o_, rays_d_ = get_specific_rays(i, j, K, poses)\n            rays_o_d = torch.stack([rays_o_, rays_d_], 0)\n            batch_rays = torch.permute(rays_o_d, [1, 0, 2])\n        \n        else:\n            rays_list = []\n            for p in poses[:, :3, :4]:\n                rays_o_, rays_d_ = get_rays(H, W, K, p)\n                rays_o_d = torch.stack([rays_o_, rays_d_], 0)\n                rays_list.append(rays_o_d)\n\n            rays = torch.stack(rays_list, 0)\n            rays = rays.reshape(-1, 2, H * W, 3)\n            rays = torch.permute(rays, [0, 2, 1, 3])\n\n            batch_rays = rays[:, ray_idx]\n        batch_rays = batch_rays.reshape(-1, 2, 3)\n        
batch_rays = torch.transpose(batch_rays, 0, 1)\n\n        # get standard rays\n        rays_o, rays_d = batch_rays\n        if args.use_viewdirs:\n            viewdirs = rays_d\n            viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)\n            viewdirs = torch.reshape(viewdirs, [-1, 3]).float()\n\n        sh = rays_d.shape\n        if args.ndc:\n            rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)\n\n        # Create ray batch\n        rays_o = torch.reshape(rays_o, [-1, 3]).float()\n        rays_d = torch.reshape(rays_d, [-1, 3]).float()\n\n        near, far = near * torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1])\n        rays = torch.cat([rays_o, rays_d, near, far], -1)\n\n        if args.use_viewdirs:\n            rays = torch.cat([rays, viewdirs], -1)\n\n        N_rays = rays.shape[0]\n        rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]\n        viewdirs = rays[:, -3:] if rays.shape[-1] > 8 else None\n        bounds = torch.reshape(rays[..., 6:8], [-1, 1, 2])\n        near, far = bounds[..., 0], bounds[..., 1]\n\n        t_vals = torch.linspace(0., 1., steps=args.N_samples)\n        z_vals = near * (1. - t_vals) + far * (t_vals)\n        z_vals = z_vals.expand([N_rays, args.N_samples])\n\n        # perturb\n        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])\n        upper = torch.cat([mids, z_vals[..., -1:]], -1)\n        lower = torch.cat([z_vals[..., :1], mids], -1)\n        # stratified samples in those intervals\n        t_rand = torch.rand(z_vals.shape)\n        z_vals = lower + (upper - lower) * t_rand\n        pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]\n\n        raw_output = self.nerf.forward(barf_i, pts, viewdirs, args)\n\n        rgb_map, disp_map, acc_map, weights, depth_map, sigma = self.nerf.raw2output(raw_output, z_vals, rays_d)\n\n        if args.N_importance > 0:\n\n            rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map\n\n            z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])\n            z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], args.N_importance)\n            z_samples = z_samples.detach()\n\n            z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)\n            pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]\n\n            raw_output = self.nerf_fine.forward(barf_i, pts, viewdirs, args)\n            rgb_map, disp_map, acc_map, weights, depth_map, sigma = self.nerf_fine.raw2output(raw_output, z_vals, rays_d)\n\n        ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map}\n        if args.N_importance > 0:\n            ret['rgb0'] = rgb_map_0\n            ret['disp0'] = disp_map_0\n            ret['acc0'] = acc_map_0\n            ret['sigma'] = sigma\n\n        return ret\n\n    @torch.no_grad()\n    def render_video(self, barf_i, poses, H, W, K, args):\n        all_ret = {}\n        ray_idx = torch.arange(0, H*W)\n        for i in range(0, ray_idx.shape[0], args.chunk):\n            ret = self.render(barf_i, poses, ray_idx[i:i+args.chunk], H, W, K, args)\n            for k in ret:\n                if k not in all_ret:\n                    all_ret[k] = []\n                all_ret[k].append(ret[k])\n        all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}\n\n        for k in all_ret:\n            k_sh = list([H, W]) + list(all_ret[k].shape[1:])\n            all_ret[k] = torch.reshape(all_ret[k], k_sh)\n        return all_ret\n"
  },
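  {
    "path": "doc/examples/encoding_dims_sketch.py",
    "content": "# Illustrative check of why Graph uses input_ch=63 and input_ch_views=27:\n# the positional encoding in nerf.py outputs d * (include_input + 2 * num_freqs)\n# channels (one sin and one cos per frequency band and axis), with\n# include_input=True unless --barf is set. multires=10 and multires_views=4\n# are the defaults in run_nerf.py.\n\ndef encoding_dim(num_freqs, include_input=True, d=3):\n    # two periodic functions (sin, cos) per frequency band, per input axis\n    return d * (int(include_input) + 2 * num_freqs)\n\nassert encoding_dim(10) == 63  # 3D locations -> input_ch\nassert encoding_dim(4) == 27   # view directions -> input_ch_views\n"
  },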
  {
    "path": "novel_view_test.py",
    "content": "import nerf\nimport torch.nn\n\n\nclass Model(nerf.Model):\n    def __init__(self, se3_start, graph):\n        super().__init__()\n        self.start = se3_start\n        self.graph_fixed = graph\n\n    def build_network(self, args):\n        self.graph_fixed.se3_sharp = torch.nn.Embedding(self.start.shape[0], 6)  # 22和25\n        self.graph_fixed.se3_sharp.weight.data = torch.nn.Parameter(self.start)\n\n        return self.graph_fixed\n\n    def setup_optimizer(self, args):\n        grad_vars_se3 = list(self.graph_fixed.se3_sharp.parameters())\n        self.optim_se3_sharp = torch.optim.Adam(params=grad_vars_se3, lr=args.lrate)\n\n        return self.optim_se3_sharp"
  },
  {
    "path": "optimize_pose_cubic.py",
    "content": "import torch.nn\n\nimport Spline\nimport nerf\n\n\nclass Model(nerf.Model):\n    def __init__(self, se3_0, se3_1, se3_2, se3_3):\n        super().__init__()\n        self.se3_0 = se3_0\n        self.se3_1 = se3_1\n        self.se3_2 = se3_2\n        self.se3_3 = se3_3\n\n    def build_network(self, args):\n        self.graph = Graph(args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True)\n        self.graph.se3 = torch.nn.Embedding(self.se3_0.shape[0], 6*4)\n\n        start_end = torch.cat([self.se3_0, self.se3_1, self.se3_2, self.se3_3], -1)\n        self.graph.se3.weight.data = torch.nn.Parameter(start_end)\n\n        return self.graph\n\n    def setup_optimizer(self, args):\n        grad_vars = list(self.graph.nerf.parameters())\n        if args.N_importance > 0:\n            grad_vars += list(self.graph.nerf_fine.parameters())\n        self.optim = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))\n\n        grad_vars_se3 = list(self.graph.se3.parameters())\n        self.optim_se3 = torch.optim.Adam(params=grad_vars_se3, lr=args.lrate)\n\n        return self.optim, self.optim_se3\n\n\nclass Graph(nerf.Graph):\n    def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True):\n        super().__init__(args, D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs)\n        self.pose_eye = torch.eye(3, 4)\n        self.se3_start = None\n        self.se3_end = None\n\n    def get_pose(self, i, img_idx, args):\n        se3_0 = self.se3.weight[:, :6][img_idx]\n        se3_1 = self.se3.weight[:, 6:12][img_idx]\n        se3_2 = self.se3.weight[:, 12:18][img_idx]\n        se3_3 = self.se3.weight[:, 18:][img_idx]\n\n        pose_nums = torch.arange(args.deblur_images).reshape(1, -1).repeat(se3_0.shape[0], 1)\n        seg_pos_x = torch.arange(se3_0.shape[0]).reshape([se3_0.shape[0], 1]).repeat(1, args.deblur_images)\n\n        se3_0 = se3_0[seg_pos_x, :]\n        se3_1 = se3_1[seg_pos_x, :]\n        se3_2 = se3_2[seg_pos_x, :]\n        se3_3 = se3_3[seg_pos_x, :]\n\n        spline_poses = Spline.SplineN_cubic(se3_0, se3_1, se3_2, se3_3, pose_nums, args.deblur_images)\n        return spline_poses\n\n    def get_pose_even(self, i, img_idx, num):\n        deblur_images_num = num+1\n        se3_0 = self.se3.weight[:, :6][img_idx]\n        se3_1 = self.se3.weight[:, 6:12][img_idx]\n        se3_2 = self.se3.weight[:, 12:18][img_idx]\n        se3_3 = self.se3.weight[:, 18:][img_idx]\n\n        pose_nums = torch.arange(deblur_images_num).reshape(1, -1).repeat(se3_0.shape[0], 1)\n        seg_pos_x = torch.arange(se3_0.shape[0]).reshape([se3_0.shape[0], 1]).repeat(1, deblur_images_num)\n\n        se3_0 = se3_0[seg_pos_x, :]\n        se3_1 = se3_1[seg_pos_x, :]\n        se3_2 = se3_2[seg_pos_x, :]\n        se3_3 = se3_3[seg_pos_x, :]\n\n        spline_poses = Spline.SplineN_cubic(se3_0, se3_1, se3_2, se3_3, pose_nums, deblur_images_num)\n        return spline_poses\n\n    def get_gt_pose(self, poses, args):\n        a = self.pose_eye\n        return poses\n"
  },
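  {
    "path": "doc/examples/cubic_knots_sketch.py",
    "content": "# Illustrative sketch of the cubic-spline pose parametrization above: each\n# blurry image owns four se(3) control knots packed into one 24-dim embedding\n# row, and get_pose evaluates the spline at args.deblur_images timestamps\n# inside the exposure window. The numbers here are made up for shape checking\n# only; Spline.SplineN_cubic does the actual interpolation.\nimport torch\n\nnum_images, deblur_images = 3, 5\nse3 = torch.nn.Embedding(num_images, 6 * 4)\nknots = se3.weight.view(num_images, 4, 6)  # (image, control knot, se(3) dim)\nu = torch.arange(deblur_images) / (deblur_images - 1)  # spline parameter in [0, 1]\nprint(knots.shape, u)\n"
  },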
  {
    "path": "optimize_pose_linear.py",
    "content": "import torch.nn\n\nimport Spline\nimport nerf\n\n\nclass Model(nerf.Model):\n    def __init__(self, se3_start, se3_end):\n        super().__init__()\n        self.start = se3_start\n        self.end = se3_end\n\n    def build_network(self, args):\n        self.graph = Graph(args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True)\n        self.graph.se3 = torch.nn.Embedding(self.start.shape[0], 6*2)\n\n        start_end = torch.cat([self.start, self.end], -1)\n        self.graph.se3.weight.data = torch.nn.Parameter(start_end)\n\n        return self.graph\n\n    def setup_optimizer(self, args):\n        grad_vars = list(self.graph.nerf.parameters())\n        if args.N_importance > 0:\n            grad_vars += list(self.graph.nerf_fine.parameters())\n        self.optim = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))\n\n        grad_vars_se3 = list(self.graph.se3.parameters())\n        self.optim_se3 = torch.optim.Adam(params=grad_vars_se3, lr=args.lrate)\n\n        return self.optim, self.optim_se3\n\n\nclass Graph(nerf.Graph):\n    def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True):\n        super().__init__(args, D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs)\n        self.pose_eye = torch.eye(3, 4)\n        self.se3_start = None\n        self.se3_end = None\n\n    def get_pose(self, i, img_idx, args):\n        se3_start = self.se3.weight[:, :6][img_idx]\n        se3_end = self.se3.weight[:, 6:][img_idx]\n        pose_nums = torch.arange(args.deblur_images).reshape(1, -1).repeat(se3_start.shape[0], 1)\n        seg_pos_x = torch.arange(se3_start.shape[0]).reshape([se3_start.shape[0], 1]).repeat(1, args.deblur_images)\n\n        se3_start = se3_start[seg_pos_x, :]\n        se3_end = se3_end[seg_pos_x, :]\n\n        spline_poses = Spline.SplineN_linear(se3_start, se3_end, pose_nums, args.deblur_images)\n        return spline_poses\n    \n    def get_pose_even(self, i, img_idx, num):\n        deblur_images_num = num+1\n        se3_start = self.se3.weight[:, :6][img_idx]\n        se3_end = self.se3.weight[:, 6:][img_idx]\n        pose_nums = torch.arange(deblur_images_num).reshape(1, -1).repeat(se3_start.shape[0],1)\n        seg_pos_x = torch.arange(se3_start.shape[0]).reshape([se3_start.shape[0], 1]).repeat(1, deblur_images_num)\n\n        se3_start = se3_start[seg_pos_x, :]\n        se3_end = se3_end[seg_pos_x, :]\n\n        spline_poses = Spline.SplineN_linear(se3_start, se3_end, pose_nums, deblur_images_num)\n        return spline_poses\n\n    def get_gt_pose(self, poses, args):\n        a = self.pose_eye\n        return poses\n"
  },
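  {
    "path": "doc/examples/linear_interp_sketch.py",
    "content": "# Illustrative sketch of the linear model above: one [se3_start | se3_end]\n# pair per blurry image, with virtual sharp-image poses spread between the two\n# knots. The straight interpolation in se(3) shown here is only a stand-in for\n# Spline.SplineN_linear, which additionally maps the result to SE(3) pose\n# matrices.\nimport torch\n\nnum_images, deblur_images = 2, 5\nse3 = torch.nn.Embedding(num_images, 6 * 2)\nstart, end = se3.weight[:, :6], se3.weight[:, 6:]\nt = (torch.arange(deblur_images) / (deblur_images - 1)).view(1, -1, 1)\nse3_interp = start.unsqueeze(1) * (1 - t) + end.unsqueeze(1) * t\nprint(se3_interp.shape)  # (num_images, deblur_images, 6)\n"
  },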
  {
    "path": "requirements.txt",
    "content": "configargparse\nimageio<2.28.0,>=2.26.0\nimageio-ffmpeg\nmatplotlib\nnumpy\nscikit-learn<1\nscikit-image<0.20,>=0.19\ntorch>=1.8\ntorchvision>=0.9.1\ntqdm"
  },
  {
    "path": "run_nerf.py",
    "content": "import os, sys\nimport numpy as np\nimport imageio\nimport json\nimport random\nimport time\nimport torch\nfrom run_nerf_helpers import *\nfrom Spline import *\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom tqdm import tqdm, trange\n\nimport matplotlib.pyplot as plt\nfrom load_llff import load_llff_data\nimport torchvision.transforms.functional as torchvision_F\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nnp.random.seed(0)\nDEBUG = False\n\n\ndef config_parser():\n\n    import configargparse\n    parser = configargparse.ArgumentParser()\n    parser.add_argument('--config', is_config_file=True, default='configs/cozy2room.txt',\n                        help='config file path')\n    parser.add_argument(\"--expname\", type=str,\n                        help='experiment name')\n    parser.add_argument(\"--basedir\", type=str, default='./logs/',\n                        help='where to store ckpts and logs')\n    parser.add_argument(\"--datadir\", type=str, default='./data/llff/fern',\n                        help='input data directory')\n\n    # training options\n    parser.add_argument(\"--N_iters\", type=int, default=200000,\n                        help='the number of sharp images one blur image corresponds to')\n    parser.add_argument(\"--deblur_images\", type=int, default=5,\n                        help='the number of sharp images one blur image corresponds to')\n    parser.add_argument(\"--skip\", type=int, default=8,\n                        help='original llffhold before concatenate images')\n    parser.add_argument(\"--netdepth\", type=int, default=8,\n                        help='layers in network')\n    parser.add_argument(\"--netwidth\", type=int, default=256,\n                        help='channels per layer')\n    parser.add_argument(\"--netdepth_fine\", type=int, default=8,\n                        help='layers in fine network')\n    parser.add_argument(\"--netwidth_fine\", type=int, default=256,\n                        help='channels per layer in fine network')\n    parser.add_argument(\"--N_rand\", type=int, default=32*32*4,\n                        help='batch size (number of random rays per gradient step)')\n    parser.add_argument(\"--lrate\", type=float, default=5e-4,\n                        help='learning rate')\n    parser.add_argument(\"--pose_lrate\", type=float, default=1e-3,\n                        help='learning rate')\n    parser.add_argument(\"--lrate_decay\", type=int, default=200,\n                        help='exponential learning rate decay (in 1000 steps)')\n    parser.add_argument(\"--chunk\", type=int, default=1024*2,\n                        help='number of rays processed in parallel, decrease if running out of memory')\n    parser.add_argument(\"--netchunk\", type=int, default=1024*32,\n                        help='number of pts sent through network in parallel, decrease if running out of memory')\n    parser.add_argument(\"--no_batching\", action='store_true',\n                        help='only take random rays from 1 image at a time')\n    parser.add_argument(\"--no_reload\", action='store_true',\n                        help='do not reload weights from saved ckpt')\n    parser.add_argument(\"--ft_path\", type=str, default=None,\n                        help='specific weights npy file to reload for coarse network')\n\n    # rendering options\n    parser.add_argument(\"--N_samples\", type=int, default=64,\n                        help='number of coarse samples per ray')\n    
parser.add_argument(\"--N_importance\", type=int, default=0,\n                        help='number of additional fine samples per ray')\n    parser.add_argument(\"--perturb\", type=float, default=1.,\n                        help='set to 0. for no jitter, 1. for jitter')\n    parser.add_argument(\"--use_viewdirs\", action='store_true',\n                        help='use full 5D input instead of 3D')\n    parser.add_argument(\"--i_embed\", type=int, default=0,\n                        help='set 0 for default positional encoding, -1 for none')\n    parser.add_argument(\"--multires\", type=int, default=10,\n                        help='log2 of max freq for positional encoding (3D location)')\n    parser.add_argument(\"--multires_views\", type=int, default=4,\n                        help='log2 of max freq for positional encoding (2D direction)')\n    parser.add_argument(\"--raw_noise_std\", type=float, default=0.,\n                        help='std dev of noise added to regularize sigma_a output, 1e0 recommended')\n\n    parser.add_argument(\"--render_only\", action='store_true',\n                        help='do not optimize, reload weights and render out render_poses path')\n    parser.add_argument(\"--render_test\", action='store_true',\n                        help='render the test set instead of render_poses path')\n    parser.add_argument(\"--render_factor\", type=int, default=0,\n                        help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')\n    parser.add_argument(\"--ndc\", type=bool, default=True,\n                        help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')\n\n    # training options\n    parser.add_argument(\"--precrop_iters\", type=int, default=0,\n                        help='number of steps to train on central crops')\n    parser.add_argument(\"--precrop_frac\", type=float,\n                        default=.5, help='fraction of img taken for central crops')\n\n    # dataset options\n    parser.add_argument(\"--dataset_type\", type=str, default='llff',\n                        help='options: llff / blender / deepvoxels')\n    parser.add_argument(\"--testskip\", type=int, default=8,\n                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')\n\n    ## deepvoxels flags\n    parser.add_argument(\"--shape\", type=str, default='greek',\n                        help='options : armchair / cube / greek / vase')\n\n    ## blender flags\n    parser.add_argument(\"--white_bkgd\", action='store_true',\n                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')\n    parser.add_argument(\"--half_res\", action='store_true',\n                        help='load blender synthetic data at 400x400 instead of 800x800')\n\n    ## llff flags\n    parser.add_argument(\"--factor\", type=int, default=8,\n                        help='downsample factor for LLFF images')\n    parser.add_argument(\"--no_ndc\", action='store_true',\n                        help='do not use normalized device coordinates (set for non-forward facing scenes)')\n    parser.add_argument(\"--lindisp\", action='store_true',\n                        help='sampling linearly in disparity rather than depth')\n    parser.add_argument(\"--spherify\", action='store_true',\n                        help='set for spherical 360 scenes')\n    parser.add_argument(\"--llffhold\", type=int, default=8,\n                        help='will take every 1/N images as LLFF 
test set, paper uses 8')\n\n    # logging/saving options\n    parser.add_argument(\"--i_print\",   type=int, default=100,\n                        help='frequency of console printout and metric loggin')\n    parser.add_argument(\"--i_img\",     type=int, default=15000,\n                        help='frequency of tensorboard image logging')\n    parser.add_argument(\"--i_weights\", type=int, default=10000,\n                        help='frequency of weight ckpt saving')\n    parser.add_argument(\"--i_testset\", type=int, default=50000,\n                        help='frequency of testset saving')\n    parser.add_argument(\"--i_video\",   type=int, default=30000,\n                        help='frequency of render_poses video saving')\n    parser.add_argument(\"--load_weights\", action='store_true',\n                        help='frequency of weight ckpt loading')\n\n    # barf: up & down\n    parser.add_argument(\"--barf\", action='store_true',\n                        help='barf')\n    parser.add_argument(\"--barf_start\", type=float, default=0.1,\n                        help='barf start')\n    parser.add_argument(\"--barf_end\", type=float, default=0.9,\n                        help='barf start')\n\n    # test option\n    parser.add_argument(\"--only_optim_one\", action='store_true', default=False,\n                        help='frequency of weight ckpt loading')\n\n    parser.add_argument(\"--split_train_data\", action='store_true', default=False,\n                        help='frequency of weight ckpt loading')\n    \n    parser.add_argument(\"--weight_iter\", type=int, default=20000,\n                        help='weight_iter')\n    \n    # pose noise\n    parser.add_argument(\"--pose_noise\", type=float, default=0.1,\n                        help='random noise of pose')\n\n    # linear\n    parser.add_argument(\"--linear\", action='store_true', default=False,\n                        help='linear or cubic spline')\n\n    # novel view\n    parser.add_argument(\"--novel_view\", action='store_true', default=False,\n                        help='novel view test')\n    parser.add_argument(\"--i_novel_view\", type=int, default=200000,\n                        help='novel view iter')\n    parser.add_argument(\"--factor_pose_novel\", type=float, default=2.0,\n                        help='factor of learning rate')\n    parser.add_argument(\"--N_novel_view\", type=int, default=20000,\n                        help='novel view iter for optimizing poses')\n\n    return parser\n\n\n"
  },
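  {
    "path": "doc/examples/config_parser_sketch.py",
    "content": "# Minimal sketch of driving config_parser programmatically (illustrative).\n# configargparse merges values from the config file with explicit flags, so a\n# CLI override like --deblur_images wins over the file. Assumes the repo's\n# configs/cozy2room.txt exists relative to the working directory.\nfrom run_nerf import config_parser\n\nparser = config_parser()\nargs = parser.parse_args(['--config', 'configs/cozy2room.txt',\n                          '--deblur_images', '7'])\nprint(args.expname, args.deblur_images, args.lrate)\n"
  },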
  {
    "path": "run_nerf_helpers.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nfrom tqdm import tqdm, trange\nimport os\nimport imageio\n\n\n# Misc\nimg2mse = lambda x, y : torch.mean((x - y) ** 2)\nimg2se = lambda x, y : (x - y) ** 2\nmse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))  # logab = logcb / logca\nto8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)\nto8b_tensor = lambda x : (255*torch.clip(x,0,1)).type(torch.int)\n\n\ndef imread(f):\n    if f.endswith('png'):\n        return imageio.imread(f, ignoregamma=True)\n    else:\n        return imageio.imread(f)\n\ndef load_imgs(path):\n    imgfiles = [os.path.join(path, f) for f in sorted(os.listdir(path)) if\n                f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]\n    imgs = [imread(f)[..., :3] / 255. for f in imgfiles]\n    imgs = np.stack(imgs, -1)\n    imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)\n    imgs = imgs.astype(np.float32)\n    imgs = torch.tensor(imgs).cuda()\n\n    return imgs\n\n\n# Ray helpers\ndef get_rays(H, W, K, c2w):\n    i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H))  # pytorch's meshgrid has indexing='ij'\n    i = i.t()\n    j = j.t()\n    dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)\n    # Rotate ray directions from camera frame to the world frame\n    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]\n    # Translate camera frame's origin to the world frame. It is the origin of all rays.\n    rays_o = c2w[:3,-1].expand(rays_d.shape)\n    return rays_o, rays_d\n\n\ndef ndc_rays(H, W, focal, near, rays_o, rays_d):\n    # Shift ray origins to near plane\n    t = -(near + rays_o[...,2]) / rays_d[...,2]\n    rays_o = rays_o + t[...,None] * rays_d\n    \n    # Projection\n    o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]\n    o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]\n    o2 = 1. + 2. * near / rays_o[...,2]\n\n    d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])\n    d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])\n    d2 = -2. 
* near / rays_o[...,2]\n    \n    rays_o = torch.stack([o0,o1,o2], -1)\n    rays_d = torch.stack([d0,d1,d2], -1)\n    \n    return rays_o, rays_d\n\n\n# Hierarchical sampling (section 5.2)\ndef sample_pdf(bins, weights, N_samples, det=False, pytest=False):\n    # Get pdf\n    weights = weights + 1e-5 # prevent nans\n    pdf = weights / torch.sum(weights, -1, keepdim=True)\n    cdf = torch.cumsum(pdf, -1)\n    cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1)  # (batch, len(bins))\n\n    # Take uniform samples\n    if det:\n        u = torch.linspace(0., 1., steps=N_samples)\n        u = u.expand(list(cdf.shape[:-1]) + [N_samples])\n    else:\n        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])\n\n    # Pytest, overwrite u with numpy's fixed random numbers\n    if pytest:\n        np.random.seed(0)\n        new_shape = list(cdf.shape[:-1]) + [N_samples]\n        if det:\n            u = np.linspace(0., 1., N_samples)\n            u = np.broadcast_to(u, new_shape)\n        else:\n            u = np.random.rand(*new_shape)\n        u = torch.Tensor(u)\n\n    # Invert CDF\n    u = u.contiguous()\n    inds = torch.searchsorted(cdf, u, right=True)\n    below = torch.max(torch.zeros_like(inds-1), inds-1)\n    above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)\n    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)\n\n    # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)\n    # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)\n    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]\n    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)\n    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)\n\n    denom = (cdf_g[...,1]-cdf_g[...,0])\n    denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)\n    t = (u-cdf_g[...,0])/denom\n    samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])\n\n    return samples\n\n\ndef render_video_test(i_, graph, render_poses, H, W, K, args):\n    rgbs = []\n    disps = []\n    # t = time.time()\n    for i, pose in enumerate(tqdm(render_poses)):\n        # print(i, time.time() - t)\n        # t = time.time()\n        pose = pose[None, :3, :4]\n        ret = graph.render_video(i_, pose[:3, :4], H, W, K, args)\n        rgbs.append(ret['rgb_map'].cpu().numpy())\n        disps.append(ret['disp_map'].cpu().numpy())\n        if i==0:\n            print(ret['rgb_map'].shape, ret['disp_map'].shape)\n    rgbs = np.stack(rgbs, 0)\n    disps = np.stack(disps, 0)\n\n    return rgbs, disps\n\n\nto8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)\n\n\ndef render_image_test(i, graph, render_poses, H, W, K, args, novel_view=False, need_depth=False):\n    if novel_view:\n        img_dir = os.path.join(args.basedir, args.expname, 'img_novel_{:06d}'.format(i))\n    else:\n        img_dir = os.path.join(args.basedir, args.expname, 'img_test_{:06d}'.format(i))\n    os.makedirs(img_dir, exist_ok=True)\n    imgs =[]\n\n    for j, pose in enumerate(tqdm(render_poses)):\n        # print(i, time.time() - t)\n        # t = time.time()\n        pose = pose[None, :3, :4]\n        ret = graph.render_video(i, pose[:3, :4], H, W, K, args)\n        imgs.append(ret['rgb_map'])\n        rgbs = ret['rgb_map'].cpu().numpy()\n        rgb8 = to8b(rgbs)\n        imageio.imwrite(os.path.join(img_dir, 'rgb_{:03d}.png'.format(j)), rgb8)\n        if need_depth:\n            depths = ret['disp_map'].cpu().numpy()\n            depths_ = 
depths/np.max(depths)\n            depth8 = to8b(depths_)\n            imageio.imwrite(os.path.join(img_dir, 'depth_{:03d}.png'.format(j)), depth8)\n    imgs = torch.stack(imgs, 0)\n    return imgs\n\n\ndef init_weights(linear):\n    # Kaiming (He) normal init for weights, zeros for biases\n    torch.nn.init.kaiming_normal_(linear.weight)\n    torch.nn.init.zeros_(linear.bias)\n\n\ndef init_nerf(nerf):\n    for linear_pt in nerf.pts_linears:\n        init_weights(linear_pt)\n\n    for linear_view in nerf.views_linears:\n        init_weights(linear_view)\n\n    init_weights(nerf.feature_linear)\n\n    init_weights(nerf.alpha_linear)\n\n    init_weights(nerf.rgb_linear)\n\n# Ray helper that generates only the requested pixel rays\ndef get_specific_rays(i, j, K, c2w):\n    # i, j = torch.meshgrid(torch.linspace(0, W - 1, W),\n    #                       torch.linspace(0, H - 1, H))  # pytorch's meshgrid has indexing='ij'\n    # i = i.t()\n    # j = j.t()\n    dirs = torch.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -torch.ones_like(i)], -1)\n    # Rotate ray directions from camera frame to the world frame\n    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[..., :3, :3], -1)\n    # dot product, equals to: [c2w.dot(dir) for dir in dirs]\n    # Translate camera frame's origin to the world frame. It is the origin of all rays.\n    rays_o = c2w[..., :3, -1]\n    return rays_o, rays_d\n\n\ndef save_render_pose(poses, path):\n    poses_np = poses.cpu().detach().numpy()\n    N = poses_np.shape[0]\n    bottom = np.reshape([0., 0., 0., 1.], [1, 4])\n    bottom_all = np.expand_dims(bottom, 0).repeat(N, axis=0)\n    poses_Rt = np.concatenate([poses_np, bottom_all], 1)\n\n    poses_txt = os.path.join(path, 'poses_render.txt')\n\n    # write one flattened 4x4 pose per line: \"pose{j} v0 v1 ... v15\"\n    with open(poses_txt, 'a') as outfile:\n        for j in range(N):\n            poses_flat = poses_Rt[j].reshape(-1)\n            outfile.write(f\"pose{j} \" + \" \".join(str(v) for v in poses_flat) + \"\\n\")\n"
  },
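  {
    "path": "doc/examples/sample_pdf_sketch.py",
    "content": "# Illustrative check of sample_pdf (hierarchical sampling): fine samples are\n# drawn where the coarse weights put probability mass. det=True makes the draw\n# deterministic (inverse-CDF at evenly spaced quantiles).\nimport torch\nfrom run_nerf_helpers import sample_pdf\n\nbins = torch.linspace(0., 1., steps=9).expand(1, 9)         # 8 depth intervals\nweights = torch.tensor([[0., 0., 0., 1., 1., 0., 0., 0.]])  # mass in the middle\nsamples = sample_pdf(bins, weights, N_samples=16, det=True)\nfrac = ((samples > 0.375) & (samples < 0.625)).float().mean()\nprint(frac.item())  # 0.875: most fine samples land in the high-weight bins\n"
  },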
  {
    "path": "test.py",
    "content": "import torch\n\nfrom nerf import *\nimport optimize_pose_linear, optimize_pose_cubic\nimport torchvision.transforms.functional as torchvision_F\n\nimport matplotlib.pyplot as plt\n\nfrom metrics import compute_img_metric\nimport novel_view_test\n\n\ndef test():\n    parser = config_parser()\n    args = parser.parse_args()\n    print('spline numbers: ', args.deblur_images)\n\n    imgs_sharp_dir = os.path.join(args.datadir, 'images_test')\n    imgs_sharp = load_imgs(imgs_sharp_dir)\n\n    # Load data images and groundtruth\n    K = None\n    if args.dataset_type == 'llff':\n        images_all, poses_start, bds_start, render_poses = load_llff_data(args.datadir, pose_state=None,\n                                                                      factor=args.factor, recenter=True,\n                                                                      bd_factor=.75, spherify=args.spherify)\n        hwf = poses_start[0, :3, -1]\n\n        # split train/val/test\n        if args.novel_view:\n            i_test = torch.arange(0, images_all.shape[0], args.llffhold)\n        else:\n            i_test = torch.tensor([100]).long()\n        i_val = i_test\n        i_train = torch.Tensor([i for i in torch.arange(int(images_all.shape[0])) if\n                                (i not in i_test and i not in i_val)]).long()\n\n        # train data\n        images = images_all[i_train]\n        # novel view data\n        if args.novel_view:\n            images_novel = images_all[i_test]\n        # gt data\n        imgs_sharp = imgs_sharp\n\n        # get poses\n        poses_end = poses_start\n        poses_start_se3 = SE3_to_se3_N(poses_start[:, :3, :4])\n        poses_end_se3 = poses_start_se3\n        poses_org = poses_start.repeat(args.deblur_images, 1, 1)\n        poses = poses_org[:, :, :4]\n\n        print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)\n\n        print('DEFINING BOUNDS')\n        if args.no_ndc:\n            near = torch.min(bds_start) * .9\n            far = torch.max(bds_start) * 1.\n\n        else:\n            near = 0.\n            far = 1.\n        print('NEAR FAR', near, far)\n\n    else:\n        print('Unknown dataset type', args.dataset_type, 'exiting')\n        return\n\n    # Cast intrinsics to right types\n    H, W, focal = hwf\n    H, W = int(H), int(W)\n    hwf = [H, W, focal]\n\n    if K is None:\n        K = torch.Tensor([\n            [focal, 0, 0.5 * W],\n            [0, focal, 0.5 * H],\n            [0, 0, 1]\n        ])\n\n    # Create log dir and copy the config file\n    basedir = args.basedir\n    expname = args.expname\n    test_metric_file = os.path.join(basedir, expname, 'test_metrics.txt')\n    test_metric_file_novel = os.path.join(basedir, expname, 'test_metrics_novel.txt')\n    # print_file = os.path.join(basedir, expname, 'print.txt')\n    os.makedirs(os.path.join(basedir, expname), exist_ok=True)\n    f = os.path.join(basedir, expname, 'args.txt')\n    with open(f, 'w') as file:\n        for arg in sorted(vars(args)):\n            attr = getattr(args, arg)\n            file.write('{} = {}\\n'.format(arg, attr))\n    if args.config is not None:\n        f = os.path.join(basedir, expname, 'config.txt')\n        with open(f, 'w') as file:\n            file.write(open(args.config, 'r').read())\n\n    if args.linear:\n        print('Linear Spline Model Loading!')\n        model = optimize_pose_linear.Model(poses_start_se3, poses_end_se3)\n    else:\n        print('Cubic Spline Model Loading!')\n        model = 
optimize_pose_cubic.Model(poses_start_se3, poses_start_se3, poses_start_se3, poses_start_se3)\n    graph = model.build_network(args)\n    optimizer, optimizer_se3 = model.setup_optimizer(args)\n    path = os.path.join(basedir, expname, '{:06d}.tar'.format(args.weight_iter))\n    graph_ckpt = torch.load(path)\n    graph.load_state_dict(graph_ckpt['graph'])\n    optimizer.load_state_dict(graph_ckpt['optimizer'])\n    optimizer_se3.load_state_dict(graph_ckpt['optimizer_se3'])\n    global_step = graph_ckpt['global_step']\n\n    if args.deblur_images % 2 == 0:\n        all_poses = graph.get_pose_even(0, torch.arange(graph.se3.weight.shape[0]), args.deblur_images)\n    else:\n        all_poses = graph.get_pose(0, torch.arange(graph.se3.weight.shape[0]), args)\n    # Turn on testing mode\n    with torch.no_grad():\n        if args.deblur_images % 2 == 0:\n            i_render = torch.arange(i_train.shape[0]) * (args.deblur_images + 1) + args.deblur_images // 2\n        else:\n            i_render = torch.arange(i_train.shape[0]) * args.deblur_images + args.deblur_images // 2\n        imgs_render = render_image_test(0, graph, all_poses[i_render], H, W, K, args)\n    mse_render = compute_img_metric(imgs_sharp, imgs_render, 'mse')\n    psnr_render = compute_img_metric(imgs_sharp, imgs_render, 'psnr')\n    ssim_render = compute_img_metric(imgs_sharp, imgs_render, 'ssim')\n    lpips_render = compute_img_metric(imgs_sharp, imgs_render, 'lpips')\n    with open(test_metric_file, 'a') as outfile:\n        outfile.write(f\"test: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}\"\n\n              f\" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\\n\")\n\n    # Turn on novel view testing mode\n    if args.novel_view:\n        i_ = torch.arange(0, images.shape[0], args.llffhold - 1)\n        poses_test_se3_ = graph.se3.weight[i_, :6]\n        model_test = novel_view_test.Model(poses_test_se3_, graph)\n        graph_test = model_test.build_network(args)\n        optimizer_test = model_test.setup_optimizer(args)\n        for j in range(args.N_novel_view):\n            ret_sharp, ray_idx_sharp, poses_sharp = graph_test.forward(0, 0, 0, H, W, K, args,\n                                                                       novel_view=True)\n            target_s_novel = images_novel.reshape(-1, H * W, 3)[:, ray_idx_sharp]\n            target_s_novel = target_s_novel.reshape(-1, 3)\n            loss_sharp = img2mse(ret_sharp['rgb_map'], target_s_novel)\n            psnr_sharp = mse2psnr(loss_sharp)\n            if 'rgb0' in ret_sharp:\n                img_loss0 = img2mse(ret_sharp['rgb0'], target_s_novel)\n                loss_sharp = loss_sharp + img_loss0\n            if j % 100 == 0:\n                print(psnr_sharp.item(), loss_sharp.item())\n            optimizer_test.zero_grad()\n            loss_sharp.backward()\n            optimizer_test.step()\n            decay_rate_sharp = 0.01\n            decay_steps_sharp = args.lrate_decay * 100\n            new_lrate_novel = args.pose_lrate * (decay_rate_sharp ** (j / decay_steps_sharp))\n            for param_group in optimizer_test.param_groups:\n                if (j / decay_steps_sharp) <= 1.:\n                    param_group['lr'] = new_lrate_novel * args.factor_pose_novel\n        with torch.no_grad():\n            imgs_render_novel = render_image_test(0, graph, poses_sharp, H, W, K, args, novel_view=True)\n\n            mse_render = compute_img_metric(images_novel, imgs_render_novel, 'mse')\n            psnr_render = 
compute_img_metric(images_novel, imgs_render_novel, 'psnr')\n            ssim_render = compute_img_metric(images_novel, imgs_render_novel, 'ssim')\n            lpips_render = compute_img_metric(images_novel, imgs_render_novel, 'lpips')\n            with open(test_metric_file_novel, 'a') as outfile:\n                outfile.write(f\"novel view test: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}\"\n                              f\" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\\n\")\n\n    return 0\n\n\nif __name__=='__main__':\n    torch.set_default_tensor_type('torch.cuda.FloatTensor')\n\n    test()\n"
  },
  {
    "path": "train.py",
    "content": "import torch\n\nfrom nerf import *\nimport optimize_pose_linear, optimize_pose_cubic\nimport torchvision.transforms.functional as torchvision_F\n\nimport matplotlib.pyplot as plt\n\nfrom metrics import compute_img_metric\nimport novel_view_test\n\n\ndef train():\n    parser = config_parser()\n    args = parser.parse_args()\n    print('spline numbers: ', args.deblur_images)\n\n    imgs_sharp_dir = os.path.join(args.datadir, 'images_test')\n    imgs_sharp = load_imgs(imgs_sharp_dir)\n\n    # Load data images and groundtruth\n    K = None\n    if args.dataset_type == 'llff':\n        images_all, poses_start, bds_start, render_poses = load_llff_data(args.datadir, pose_state=None,\n                                                                      factor=args.factor, recenter=True,\n                                                                      bd_factor=.75, spherify=args.spherify)\n        hwf = poses_start[0, :3, -1]\n\n        # split train/val/test\n        if args.novel_view:\n            i_test = torch.arange(0, images_all.shape[0], args.llffhold)\n        else:\n            i_test = torch.tensor([100]).long()\n        i_val = i_test\n        i_train = torch.Tensor([i for i in torch.arange(int(images_all.shape[0])) if\n                                (i not in i_test and i not in i_val)]).long()\n\n        # train data\n        images = images_all[i_train]\n        # novel view data\n        if args.novel_view:\n            images_novel = images_all[i_test]\n        # gt data\n        imgs_sharp = imgs_sharp\n\n        # get poses\n        poses_end = poses_start\n        poses_start_se3 = SE3_to_se3_N(poses_start[:, :3, :4])\n        poses_end_se3 = poses_start_se3\n        poses_org = poses_start.repeat(args.deblur_images, 1, 1)\n        poses = poses_org[:, :, :4]\n\n        print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)\n\n        print('DEFINING BOUNDS')\n        if args.no_ndc:\n            near = torch.min(bds_start) * .9\n            far = torch.max(bds_start) * 1.\n\n        else:\n            near = 0.\n            far = 1.\n        print('NEAR FAR', near, far)\n\n    else:\n        print('Unknown dataset type', args.dataset_type, 'exiting')\n        return\n\n    # Cast intrinsics to right types\n    H, W, focal = hwf\n    H, W = int(H), int(W)\n    hwf = [H, W, focal]\n\n    if K is None:\n        K = torch.Tensor([\n            [focal, 0, 0.5 * W],\n            [0, focal, 0.5 * H],\n            [0, 0, 1]\n        ])\n\n    # Create log dir and copy the config file\n    basedir = args.basedir\n    expname = args.expname\n    test_metric_file = os.path.join(basedir, expname, 'test_metrics.txt')\n    test_metric_file_novel = os.path.join(basedir, expname, 'test_metrics_novel.txt')\n    print_file = os.path.join(basedir, expname, 'print.txt')\n    os.makedirs(os.path.join(basedir, expname), exist_ok=True)\n    f = os.path.join(basedir, expname, 'args.txt')\n    with open(f, 'w') as file:\n        for arg in sorted(vars(args)):\n            attr = getattr(args, arg)\n            file.write('{} = {}\\n'.format(arg, attr))\n    if args.config is not None:\n        f = os.path.join(basedir, expname, 'config.txt')\n        with open(f, 'w') as file:\n            file.write(open(args.config, 'r').read())\n\n    if args.load_weights:\n        if args.linear:\n            print('Linear Spline Model Loading!')\n            model = optimize_pose_linear.Model(poses_start_se3, poses_end_se3)\n        else:\n            print('Cubic Spline 
Model Loading!')\n            model = optimize_pose_cubic.Model(poses_start_se3, poses_start_se3, poses_start_se3, poses_start_se3)\n        graph = model.build_network(args)\n        optimizer, optimizer_se3 = model.setup_optimizer(args)\n        path = os.path.join(basedir, expname, '{:06d}.tar'.format(args.weight_iter))  # here\n        graph_ckpt = torch.load(path)\n        graph.load_state_dict(graph_ckpt['graph'])\n        optimizer.load_state_dict(graph_ckpt['optimizer'])\n        optimizer_se3.load_state_dict(graph_ckpt['optimizer_se3'])\n        global_step = graph_ckpt['global_step']\n\n    else:\n        if args.linear:\n            low, high = 0.0001, 0.005\n            rand = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low\n            poses_start_se3 = poses_start_se3 + rand\n\n            model = optimize_pose_linear.Model(poses_start_se3, poses_end_se3)\n        else:\n            low, high = 0.0001, 0.01\n            rand1 = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low\n            rand2 = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low\n            rand3 = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low\n            poses_se3_1 = poses_start_se3 + rand1\n            poses_se3_2 = poses_start_se3 + rand2\n            poses_se3_3 = poses_start_se3 + rand3\n\n            model = optimize_pose_cubic.Model(poses_start_se3, poses_se3_1, poses_se3_2, poses_se3_3)\n\n        graph = model.build_network(args)  # nerf, nerf_fine, forward\n        optimizer, optimizer_se3 = model.setup_optimizer(args)\n\n    N_iters = args.N_iters + 1\n    print('Begin')\n    print('TRAIN views are', i_train)\n    print('TEST views are', i_test)\n    print('VAL views are', i_val)\n\n    start = 0\n    if not args.load_weights:\n        global_step = start\n    global_step_ = global_step\n    threshold = N_iters + 1\n\n    poses_num = poses.shape[0]\n\n    for i in trange(start, threshold):\n    ### core optimization loop ###\n        i = i+global_step_\n        if i == 0:\n            init_nerf(graph.nerf)\n            init_nerf(graph.nerf_fine)\n\n        img_idx = torch.randperm(images.shape[0])\n\n        if (i % args.i_img == 0 or i % args.i_novel_view == 0) and i > 0:\n            ret, ray_idx, spline_poses, all_poses = graph.forward(i, img_idx, poses_num, H, W, K, args)\n        else:\n            ret, ray_idx, spline_poses = graph.forward(i, img_idx, poses_num, H, W, K, args)\n\n        # get image ground truth\n        target_s = images[img_idx].reshape(-1, H * W, 3)\n        target_s = target_s[:, ray_idx]\n        target_s = target_s.reshape(-1, 3)\n\n        # average\n        shape0 = img_idx.shape[0]\n        interval = target_s.shape[0] // shape0\n        rgb_list = []\n        extras_list = []\n        rgb_ = 0\n        extras_ = 0\n\n        for j in range(0, shape0 * args.deblur_images):\n            rgb_ += ret['rgb_map'][j * interval:(j + 1) * interval]\n            if 'rgb0' in ret:\n                extras_ += ret['rgb0'][j * interval:(j + 1) * interval]\n            if (j + 1) % args.deblur_images == 0:\n                rgb_ = rgb_ / args.deblur_images\n                rgb_list.append(rgb_)\n                rgb_ = 0\n                if 'rgb0' in ret:\n                    extras_ = extras_ / args.deblur_images\n                    extras_list.append(extras_)\n                    extras_ = 0\n\n        rgb_blur = torch.stack(rgb_list, 0)\n        rgb_blur = rgb_blur.reshape(-1, 3)\n\n        if 'rgb0' in ret:\n            
extras_blur = torch.stack(extras_list, 0)\n            extras_blur = extras_blur.reshape(-1, 3)\n\n        # backward\n        optimizer_se3.zero_grad()\n        optimizer.zero_grad()\n        img_loss = img2mse(rgb_blur, target_s)\n        loss = img_loss\n        psnr = mse2psnr(img_loss)\n\n        img_loss0 = torch.tensor(0.)  # coarse (rgb0) loss; stays 0 when there is no fine network\n        if 'rgb0' in ret:\n            img_loss0 = img2mse(extras_blur, target_s)\n            loss = loss + img_loss0\n            psnr0 = mse2psnr(img_loss0)\n\n        loss.backward()\n\n        optimizer.step()\n        optimizer_se3.step()\n\n        # NOTE: IMPORTANT!\n        ###   update learning rate   ###\n        decay_rate = 0.1\n        decay_steps = args.lrate_decay * 1000\n        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = new_lrate\n\n        decay_rate_pose = 0.01\n        new_lrate_pose = args.pose_lrate * (decay_rate_pose ** (global_step / decay_steps))\n        for param_group in optimizer_se3.param_groups:\n            param_group['lr'] = new_lrate_pose\n        ###############################\n\n        if i%args.i_print==0:\n            tqdm.write(f\"[TRAIN] Iter: {i} Loss: {loss.item()}  coarse_loss: {img_loss0.item()}  PSNR: {psnr.item()}\")\n            with open(print_file, 'a') as outfile:\n                outfile.write(f\"[TRAIN] Iter: {i} Loss: {loss.item()}  coarse_loss: {img_loss0.item()}  PSNR: {psnr.item()}\\n\")\n\n        if i < 10:\n            print('coarse_loss:', img_loss0.item())\n            with open(print_file, 'a') as outfile:\n                outfile.write(f\"coarse loss: {img_loss0.item()}\\n\")\n\n        if i % args.i_weights == 0 and i > 0:\n            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))\n            torch.save({\n                'global_step': global_step,\n                'graph': graph.state_dict(),\n                'optimizer': optimizer.state_dict(),\n                'optimizer_se3': optimizer_se3.state_dict(),\n            }, path)\n            print('Saved checkpoints at', path)\n\n        if i % args.i_img == 0 and i > 0:\n            # Turn on testing mode\n            with torch.no_grad():\n                if args.deblur_images % 2 == 0:\n                    i_render = torch.arange(i_train.shape[0]) * (args.deblur_images+1) + args.deblur_images // 2\n                else:\n                    i_render = torch.arange(i_train.shape[0]) * args.deblur_images + args.deblur_images // 2\n                imgs_render = render_image_test(i, graph, all_poses[i_render], H, W, K, args)\n            mse_render = compute_img_metric(imgs_sharp, imgs_render, 'mse')\n            psnr_render = compute_img_metric(imgs_sharp, imgs_render, 'psnr')\n            ssim_render = compute_img_metric(imgs_sharp, imgs_render, 'ssim')\n            lpips_render = compute_img_metric(imgs_sharp, imgs_render, 'lpips')\n            with open(test_metric_file, 'a') as outfile:\n                outfile.write(f\"iter{i}: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}\"\n                              f\" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\\n\")\n\n        if i % args.i_video == 0 and i > 0:\n            # Turn on testing mode\n            with torch.no_grad():\n                rgbs, disps = render_video_test(i, graph, render_poses, H, W, K, args)\n            print('Done, saving', rgbs.shape, disps.shape)\n            moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))\n            imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)\n            imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)\n\n        if args.novel_view and i % args.i_novel_view == 0 and i > 0:\n            # Turn on novel view testing mode\n            i_ = torch.arange(0, images.shape[0], args.llffhold-1)\n            poses_test_se3_ = graph.se3.weight[i_,:6]\n            model_test = novel_view_test.Model(poses_test_se3_, graph)\n            graph_test = model_test.build_network(args)\n            optimizer_test = model_test.setup_optimizer(args)\n            for j in range(args.N_novel_view):\n                ret_sharp, ray_idx_sharp, poses_sharp = graph_test.forward(i, img_idx, poses_num, H, W, K, args, novel_view=True)\n                target_s_novel = images_novel.reshape(-1, H*W, 3)[:, ray_idx_sharp]\n                target_s_novel = target_s_novel.reshape(-1, 3)\n                loss_sharp = img2mse(ret_sharp['rgb_map'], target_s_novel)\n                psnr_sharp = mse2psnr(loss_sharp)\n                if 'rgb0' in ret_sharp:\n                    img_loss0 = img2mse(ret_sharp['rgb0'], target_s_novel)\n                    loss_sharp = loss_sharp + img_loss0\n                if j%100==0:\n                    print(psnr_sharp.item(), loss_sharp.item())\n                optimizer_test.zero_grad()\n                loss_sharp.backward()\n                optimizer_test.step()\n                decay_rate_sharp = 0.01\n                decay_steps_sharp = args.lrate_decay * 100\n                new_lrate_novel = args.pose_lrate * (decay_rate_sharp ** (j / decay_steps_sharp))\n                for param_group in optimizer_test.param_groups:\n                    if (j / decay_steps_sharp) <= 1.:\n                        param_group['lr'] = new_lrate_novel * args.factor_pose_novel\n            with torch.no_grad():\n                imgs_render_novel = render_image_test(i, graph, poses_sharp, H, W, K, args, novel_view=True)\n\n                mse_render = compute_img_metric(images_novel, imgs_render_novel, 'mse')\n                psnr_render = compute_img_metric(images_novel, imgs_render_novel, 'psnr')\n                ssim_render = compute_img_metric(images_novel, imgs_render_novel, 'ssim')\n                lpips_render = compute_img_metric(images_novel, imgs_render_novel, 'lpips')\n                with open(test_metric_file_novel, 'a') as outfile:\n                    outfile.write(f\"iter{i}: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}\"\n                                  f\" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\\n\")\n\n        if i % args.N_iters == 0 and i > 0:\n            # Turn on testing mode\n            with torch.no_grad():\n                path_pose = os.path.join(basedir, expname)\n                i_render_pose = torch.arange(i_train.shape[0]) * args.deblur_images + args.deblur_images // 2\n                render_poses_final = all_poses[i_render_pose]\n                save_render_pose(render_poses_final, path_pose)\n\n        global_step += 1\n\n\nif __name__=='__main__':\n    torch.set_default_tensor_type('torch.cuda.FloatTensor')\n\n    train()\n"
  }
]