Full Code of WU-CVGL/BAD-NeRF for AI

Repository: WU-CVGL/BAD-NeRF
Branch: main
Commit: aed2c4a4b230
Files: 35
Total size: 113.4 KB

Directory structure:
BAD-NeRF/

├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── Spline.py
├── configs/
│   ├── cozy2room-cubic.txt
│   ├── cozy2room.txt
│   ├── factory-cubic.txt
│   ├── factory.txt
│   ├── pool-cubic.txt
│   ├── pool.txt
│   ├── roomdark-cubic.txt
│   ├── roomdark.txt
│   ├── roomhigh-cubic.txt
│   ├── roomhigh.txt
│   ├── roomlow-cubic.txt
│   ├── roomlow.txt
│   ├── tanabata-cubic.txt
│   ├── tanabata.txt
│   ├── wine-cubic.txt
│   └── wine.txt
├── load_llff.py
├── lpips/
│   ├── __init__.py
│   ├── lpips.py
│   └── pretrained_networks.py
├── metrics.py
├── nerf.py
├── novel_view_test.py
├── optimize_pose_cubic.py
├── optimize_pose_linear.py
├── requirements.txt
├── run_nerf.py
├── run_nerf_helpers.py
├── test.py
└── train.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.idea/
.vscode/

__pycache__/

logs*/

data/
configs_test/
weights/

================================================
FILE: Dockerfile
================================================
FROM nvcr.io/nvidia/pytorch:23.02-py3

ARG DEBIAN_FRONTEND=noninteractive
ENV TZ=Asia/Shanghai LANG=C.UTF-8 LC_ALL=C.UTF-8 PIP_NO_CACHE_DIR=1 PIP_CACHE_DIR=/tmp/ PYTHONUNBUFFERED=1 PYTHONFAULTHANDLER=1 PYTHONHASHSEED=0

RUN \
    # uncomment to use apt mirror
    # sed -i "s/archive.ubuntu.com/mirrors.ustc.edu.cn/g" /etc/apt/sources.list &&\
    # sed -i "s/security.ubuntu.com/mirrors.ustc.edu.cn/g" /etc/apt/sources.list &&\
    rm -f /etc/apt/sources.list.d/* &&\
    rm -rf /opt/hpcx/ &&\
    apt-get update && apt-get upgrade -y &&\
    apt-get install -y --no-install-recommends \
        autoconf automake autotools-dev build-essential ca-certificates \
        make cmake ninja-build pkg-config g++ ccache yasm openmpi-bin \
        git curl wget unzip nano net-tools htop iotop \
        cloc rsync xz-utils software-properties-common tzdata \
    && apt-get purge -y unattended-upgrades \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

ADD requirements.txt /tmp
RUN \
    # uncomment to use pypi mirror
    # pip config set global.index-url https://mirrors.bfsu.edu.cn/pypi/web/simple &&\
    pip install -U pip &&\
    pip install -r /tmp/requirements.txt &&\
    pip install "jupyterlab~=3.5.0" "jupyter-archive~=3.2" &&\
    rm -rf /tmp/*


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 Peng Wang, Lingzhe Zhao, Ruijie Ma, Peidong Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# 😈BAD-NeRF

<a href="https://arxiv.org/abs/2211.12853"><img src="https://img.shields.io/badge/arXiv-2211.12853-b31b1b.svg" height=22.5></a>
<a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/github/license/WU-CVGL/BAD-NeRF" height=22.5></a>
<a href="https://www.youtube.com/watch?v=xoES4eONYoA"><img src="https://img.shields.io/badge/YouTube-%23FF0000.svg?style=flat&logo=YouTube&logoColor=white" height=22.5></a>
<a href="https://www.bilibili.com/video/BV1Gz4y1e7oH/"><img src="https://img.shields.io/badge/Bilibili-0696c6" height=22.5></a>

This is an official PyTorch implementation of the paper [BAD-NeRF: Bundle Adjusted Deblur Neural Radiance Fields](https://arxiv.org/abs/2211.12853) (CVPR 2023). Authors: [Peng Wang](https://github.com/wangpeng000), [Lingzhe Zhao](https://github.com/LingzheZhao), Ruijie Ma and [Peidong Liu](https://ethliup.github.io/).

BAD-NeRF jointly learns the 3D representation and optimizes the camera motion trajectories within the exposure time, given blurry images and inaccurate initial poses.

Here is the [Project page](https://wangpeng000.github.io/BAD-NeRF/).

## ✨News
📺 **[2023.12]** We were invited to give an online talk on Bilibili and Wechat, hosted by **计算机视觉life**. Archive of our live stream (in Chinese): [[link to Bilibili]](https://www.bilibili.com/video/BV1Rb4y1G7Ek/)

⚡ **[2023.11]** We have made our **nerfstudio**-based implementation [[BAD-NeRFstudio]](https://github.com/WU-CVGL/BAD-NeRFstudio) public. Now you can train a scene from blurry images in minutes!

## Novel View Synthesis
<div><video autoplay loop controls src="https://user-images.githubusercontent.com/43722188/232816090-ced1fbbc-4246-45c6-a265-e7424e754c7b.mp4" muted="true"></video></div>

## Deblurring Result
![teaser](./doc/bad-nerf.jpg)

## Pose Estimation Result
![pose estimation](./doc/pose-estimation.jpg)

## Method overview
![method](./doc/overview.jpg)

We follow the real physical image formation process of a motion-blurred image to synthesize blurry images from NeRF. Both NeRF and the motion trajectories are estimated by maximizing the photometric consistency between the synthesized blurry images and the real blurry images.
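
In code, this formation model amounts to averaging the sharp images rendered at virtual camera poses sampled along the trajectory within the exposure time. A minimal sketch (not the repository's exact training code; `render_fn` is a hypothetical renderer):

```python
import torch

def synthesize_blurry(render_fn, virtual_poses):
    """Average the sharp renders at the virtual poses of one exposure.

    render_fn: maps a 3x4 camera-to-world pose to an (H, W, 3) image tensor.
    virtual_poses: iterable of such poses (their count is `deblur_images`).
    """
    sharp = torch.stack([render_fn(pose) for pose in virtual_poses], dim=0)
    return sharp.mean(dim=0)  # photometric average over the exposure time
```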

## Quickstart

### 1. Setup environment

```
git clone https://github.com/WU-CVGL/BAD-NeRF
cd BAD-NeRF
pip install -r requirements.txt
```

### 2. Download datasets

You can download the data and weights [here](https://westlakeu-my.sharepoint.com/:f:/g/personal/cvgl_westlake_edu_cn/EsgdW2cRic5JqerhNbTsxtkBqy9m6cbnb2ugYZtvaib3qA?e=bjK7op).

For the scenes from Deblur-NeRF (*cozy2room*, *factory*, etc.), the folder `images` contains only blurry images, while the folder `images_1` additionally contains novel-view images. For our own scenes (*room-low*, *room-high* and *dark*), there are no novel-view images. Note that the images in the *dark* scene are undistorted, which leaves some invalid pixels, so for this scene you should uncomment the corresponding code in `Graph.forward` in `nerf.py`.

### 3. Configs

Change the data path and other parameters (if needed) in `configs/cozy2room.txt`. We use the *cozy2room* scene as an example.

### 4. Demo with our pre-trained model

You can test our code and render sharp images with the provided weight files. To do this, first put the weight file under the corresponding logs folder `./logs/cozy2room-linear`, then set the parameter `load_weights=True` in `cozy2room.txt`, and finally run

```
python test.py --config configs/cozy2room.txt
```

### 5. Training

```
python train.py --config configs/cozy2room.txt
```

After training, you can get deblurred images, optimized camera poses and synthesized novel view images.

## Notes

### Camera poses

The poses (`poses_bounds.npy`) are generated by COLMAP from the blurry images only (folder `images`).
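
For reference, `poses_bounds.npy` follows the LLFF layout that `_load_data` in `load_llff.py` parses: one row per image, 15 pose values (a 3x5 matrix holding rotation, translation and the `[height, width, focal]` column) followed by the two depth bounds. A minimal sketch (the path is illustrative):

```python
import numpy as np

arr = np.load('poses_bounds.npy')       # shape (N, 17), N = number of images
poses = arr[:, :-2].reshape(-1, 3, 5)   # per-image 3x5 [R | t | hwf] matrix
bounds = arr[:, -2:]                    # per-image near/far scene bounds
```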

### Spline model

We use linear interpolation as the default spline model in our experiments. You can simply change the parameter `linear` (all parameters can be set in `configs/***.txt` or `run_nerf.py`) to `False` to use the higher-order spline model (i.e. a cubic B-Spline).
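
A minimal sketch of sampling virtual camera poses with the linear model in `Spline.py` (the tensor shapes here are assumptions for illustration; in training, the start and end se(3) poses are learned parameters):

```python
import torch
from Spline import SplineN_linear

deblur_images = 7  # number of virtual sharp images per blurry image
# se(3) start/end poses for a batch of 2 images, replicated per virtual sample
se3_start = 0.01 * torch.randn(2, deblur_images, 6)
se3_end = 0.01 * torch.randn(2, deblur_images, 6)
# virtual-sample indices 0 .. deblur_images-1 for each image
idx = torch.arange(deblur_images, dtype=torch.float32).expand(2, -1)
poses = SplineN_linear(se3_start, se3_end, idx, deblur_images)
print(poses.shape)  # torch.Size([14, 3, 4]): camera-to-world matrices
```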

### Virtual images

You can change the important parameter `deblur_images`: use a smaller value for lightly blurred images and a larger value for severely blurred ones.

### Learning rate

After the rebuttal, we found that the gradients sometimes become NaN when the cubic B-Spline model is used with `pose_lrate=1e-3`. We therefore set the initial pose learning rate to 1e-4, which may achieve better performance than that reported in our paper. If the gradients become NaN in your experiments, just kill the run and try again, or decrease `pose_lrate`.
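
A minimal sketch of guarding a training step against such NaN gradients (an illustration, not part of `train.py`; `pose_params` and `pose_optimizer` are hypothetical names):

```python
import torch

def grads_are_finite(params):
    """True if every parameter gradient is finite (no NaN/Inf)."""
    return all(p.grad is None or torch.isfinite(p.grad).all() for p in params)

# loss.backward()
# if grads_are_finite(pose_params):
#     pose_optimizer.step()
# else:
#     pose_optimizer.zero_grad()  # skip this step, or decrease pose_lrate
```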

## Your own data

`images`: This folder is used to estimate initial camera poses from blurry images. Put your own data (blurry images only) into the folder `images`, then run the `imgs2poses.py` script from the [LLFF code](https://github.com/fyusion/llff) to estimate camera poses and generate `poses_bounds.npy`.

`images_1`: This is the default training folder. It contains the same blurry images as the `images` folder plus (optionally) several sharp novel-view images. If you want to add novel-view (sharp) images, put them into the folder `images_1` at an interval of `llffhold` (a parameter used for novel-view testing) and set the parameter `novel_view` to `True`. Otherwise, if there are no novel-view images, put only the blurry images into `images_1` and set `novel_view` to `False`.

`images_test`: To compute deblurring metrics, this folder should contain the ground-truth images corresponding to the blurry ones. If you don't have ground-truth images, you can copy the blurry images from the `images` folder into `images_test`; this is the easiest way to run the code correctly, but remember that the computed metrics are then meaningless (see the PSNR sketch after the listing below).
```
#-----------------------------------------------------------------------------------------#
# images folder: img_blur_*.png is the blurry image.                                      #
#-----------------------------------------------------------------------------------------#
# images_1 folder: img_blur_*.png is the same as that in `images` and (optional)          #
# img_novel_*.png is the sharp novel view image.                                          #
#-----------------------------------------------------------------------------------------#
# images_test folder: img_test_*.png should be the ground truth image corresponding to    #
# img_blur_*.png to compute PSNR metric. Of course, you can directly put img_blur_*.png   #
# to run the code if you don't have gt images (then the metrics are wrong).               #
#-----------------------------------------------------------------------------------------#
images folder: (suppose 10 images)
img_blur_0.png
img_blur_1.png
.
.
.
img_blur_9.png
#-----------------------------------------------------------------------------------------#
images_1 folder: (suppose novel view images are placed with an `llffhold=5` interval.)
img_novel_0.png (optional)
img_blur_0.png
img_blur_1.png
.
img_blur_4.png
img_novel_1.png (optional)
img_blur_5.png
.
img_blur_9.png
img_novel_2.png (optional)
#-----------------------------------------------------------------------------------------#
images_test folder: (theoretically gt images, but can be other images)
img_test_0.png
img_test_1.png
.
.
.
img_test_9.png
```
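
For reference, the PSNR metric is computed as in `lpips/lpips.py`:

```python
import numpy as np

def psnr(p0, p1, peak=255.):
    # 10 * log10(peak^2 / MSE) between two images with values in [0, peak]
    return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2))
```
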
## Citation

If you find this useful, please consider citing our paper:

```bibtex
@InProceedings{wang2023badnerf,
    author    = {Wang, Peng and Zhao, Lingzhe and Ma, Ruijie and Liu, Peidong},
    title     = {{BAD-NeRF: Bundle Adjusted Deblur Neural Radiance Fields}},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2023},
    pages     = {4170-4179}
}
```

## Acknowledgment

The overall framework, the metric computation, and the camera transformations are derived from [nerf-pytorch](https://github.com/yenchenlin/nerf-pytorch/), [Deblur-NeRF](https://github.com/limacv/Deblur-NeRF) and [BARF](https://github.com/chenhsuanlin/bundle-adjusting-NeRF), respectively. We appreciate the efforts of the contributors to these repositories.


================================================
FILE: Spline.py
================================================
import torch
import numpy as np

delt = 0


def skew_symmetric(w):
    w0, w1, w2 = w.unbind(dim=-1)
    O = torch.zeros_like(w0)
    wx = torch.stack(
        [
            torch.stack([O, -w2, w1], dim=-1),
            torch.stack([w2, O, -w0], dim=-1),
            torch.stack([-w1, w0, O], dim=-1),
        ],
        dim=-2,
    )
    return wx


def taylor_A(x, nth=10):
    # Taylor expansion of sin(x)/x
    ans = torch.zeros_like(x)
    denom = 1.0
    for i in range(nth + 1):
        if i > 0:
            denom *= (2 * i) * (2 * i + 1)
        ans = ans + (-1) ** i * x ** (2 * i) / denom
    return ans


def taylor_B(x, nth=10):
    # Taylor expansion of (1-cos(x))/x**2
    ans = torch.zeros_like(x)
    denom = 1.0
    for i in range(nth + 1):
        denom *= (2 * i + 1) * (2 * i + 2)
        ans = ans + (-1) ** i * x ** (2 * i) / denom
    return ans


def taylor_C(x, nth=10):
    # Taylor expansion of (x-sin(x))/x**3
    ans = torch.zeros_like(x)
    denom = 1.0
    for i in range(nth + 1):
        denom *= (2 * i + 2) * (2 * i + 3)
        ans = ans + (-1) ** i * x ** (2 * i) / denom
    return ans


def exp_r2q_parallel(r, eps=1e-9):
    x, y, z = r[..., 0], r[..., 1], r[..., 2]
    theta = 0.5 * torch.sqrt(x**2 + y**2 + z**2)
    bool_criterion = (theta < eps).unsqueeze(-1).repeat(1, 1, 4)
    return torch.where(
        bool_criterion, exp_r2q_taylor(x, y, z, theta), exp_r2q(x, y, z, theta)
    )


def exp_r2q(x, y, z, theta):
    lambda_ = torch.sin(theta) / (2.0 * theta)
    qx = lambda_ * x
    qy = lambda_ * y
    qz = lambda_ * z
    qw = torch.cos(theta)
    return torch.stack([qx, qy, qz, qw], -1)


def exp_r2q_taylor(x, y, z, theta):
    qx = (1.0 / 2.0 - 1.0 / 12.0 * theta**2 - 1.0 / 240.0 * theta**4) * x
    qy = (1.0 / 2.0 - 1.0 / 12.0 * theta**2 - 1.0 / 240.0 * theta**4) * y
    qz = (1.0 / 2.0 - 1.0 / 12.0 * theta**2 - 1.0 / 240.0 * theta**4) * z
    qw = 1.0 - 1.0 / 2.0 * theta**2 + 1.0 / 24.0 * theta**4
    return torch.stack([qx, qy, qz, qw], -1)


def q_to_R_parallel(q):
    qb, qc, qd, qa = q.unbind(dim=-1)
    R = torch.stack(
        [
            torch.stack(
                [
                    1 - 2 * (qc**2 + qd**2),
                    2 * (qb * qc - qa * qd),
                    2 * (qa * qc + qb * qd),
                ],
                dim=-1,
            ),
            torch.stack(
                [
                    2 * (qb * qc + qa * qd),
                    1 - 2 * (qb**2 + qd**2),
                    2 * (qc * qd - qa * qb),
                ],
                dim=-1,
            ),
            torch.stack(
                [
                    2 * (qb * qd - qa * qc),
                    2 * (qa * qb + qc * qd),
                    1 - 2 * (qb**2 + qc**2),
                ],
                dim=-1,
            ),
        ],
        dim=-2,
    )
    return R


def q_to_Q_parallel(q):
    x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    Q_0 = torch.stack([w, -z, y, x], -1).unsqueeze(-2)
    Q_1 = torch.stack([z, w, -x, y], -1).unsqueeze(-2)
    Q_2 = torch.stack([-y, x, w, z], -1).unsqueeze(-2)
    Q_3 = torch.stack([-x, -y, -z, w], -1).unsqueeze(-2)
    Q_ = torch.cat([Q_0, Q_1, Q_2, Q_3], -2)
    return Q_


def q_to_q_conj_parallel(q):
    x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    q_conj_ = torch.stack([-x, -y, -z, w], -1)
    return q_conj_


def log_q2r_parallel(q, eps_theta=1e-20, eps_w=1e-10):
    x, y, z, w = q[..., 0], q[..., 1], q[..., 2], q[..., 3]

    theta = torch.sqrt(x**2 + y**2 + z**2)

    bool_theta_0 = theta < eps_theta
    bool_w_0 = torch.abs(w) < eps_w
    bool_w_0_left = torch.logical_and(bool_w_0, w < 0)

    lambda_ = torch.where(
        bool_w_0,
        torch.where(
            bool_w_0_left,
            log_q2r_lim_w_0_left(theta),
            log_q2r_lim_w_0_right(theta)
        ),
        torch.where(
            bool_theta_0,
            log_q2r_taylor_theta_0(w, theta),
            log_q2r(w, theta)
        ),
    )

    r_ = torch.stack([lambda_ * x, lambda_ * y, lambda_ * z], -1)

    return r_


def log_q2r(w, theta):
    return 2.0 * (torch.arctan(theta / w)) / theta


def log_q2r_taylor_theta_0(w, theta):
    return 2.0 / w - 2.0 / 3.0 * (theta**2) / (w * w * w)


def log_q2r_lim_w_0_left(theta):
    return -torch.pi / theta


def log_q2r_lim_w_0_right(theta):
    return torch.pi / theta


def SE3_to_se3(Rt, eps=1e-8):  # [...,3,4]
    R, t = Rt.split([3, 1], dim=-1)
    w = SO3_to_so3(R)
    wx = skew_symmetric(w)
    theta = w.norm(dim=-1)[..., None, None]
    I = torch.eye(3, device=w.device, dtype=torch.float32)
    A = taylor_A(theta)
    B = taylor_B(theta)
    invV = I - 0.5 * wx + (1 - A / (2 * B)) / (theta**2 + eps) * wx @ wx
    u = (invV @ t)[..., 0]
    wu = torch.cat([w, u], dim=-1)
    return wu


def SO3_to_so3(R, eps=1e-7):  # [...,3,3]
    trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]
    theta = ((trace - 1) / 2).clamp(-1 + eps, 1 - eps).acos_()[
        ..., None, None
    ] % np.pi  # ln(R) will explode if theta==pi
    lnR = (
        1 / (2 * taylor_A(theta) + 1e-8) * (R - R.transpose(-2, -1))
    )  # FIXME: wei-chiu finds it weird
    w0, w1, w2 = lnR[..., 2, 1], lnR[..., 0, 2], lnR[..., 1, 0]
    w = torch.stack([w0, w1, w2], dim=-1)
    return w


def se3_to_SE3(wu):  # [...,3]
    w, u = wu.split([3, 3], dim=-1)
    wx = skew_symmetric(w)  # wx=[0 -w(2) w(1);w(2) 0 -w(0);-w(1) w(0) 0]
    theta = w.norm(dim=-1)[..., None, None]  # theta=sqrt(w'*w)
    I = torch.eye(3, device=w.device, dtype=torch.float32)
    A = taylor_A(theta)
    B = taylor_B(theta)
    C = taylor_C(theta)
    R = I + A * wx + B * wx @ wx
    V = I + B * wx + C * wx @ wx
    Rt = torch.cat([R, (V @ u[..., None])], dim=-1)
    return Rt


def SE3_to_se3_N(poses_rt):
    poses_se3_list = []
    for i in range(poses_rt.shape[0]):
        pose_se3 = SE3_to_se3(poses_rt[i])
        poses_se3_list.append(pose_se3)
    poses = torch.stack(poses_se3_list, 0)
    return poses


def se3_to_SE3_N(poses_wu):
    poses_se3_list = []
    for i in range(poses_wu.shape[0]):
        pose_se3 = se3_to_SE3(poses_wu[i])
        poses_se3_list.append(pose_se3)
    poses = torch.stack(poses_se3_list, 0)
    return poses


def se3_2_qt_parallel(wu):
    w, u = wu.split([3, 3], dim=-1)
    wx = skew_symmetric(w)
    theta = w.norm(dim=-1)[..., None, None]
    I = torch.eye(3, device=w.device, dtype=torch.float32)
    # A = taylor_A(theta)
    B = taylor_B(theta)
    C = taylor_C(theta)
    # R = I + A * wx + B * wx @ wx
    V = I + B * wx + C * wx @ wx
    t = V @ u[..., None]
    q = exp_r2q_parallel(w)
    return q, t.squeeze(-1)


def SplineN_linear(start_pose, end_pose, poses_number, NUM, device=None):
    pose_time = poses_number / (NUM - 1)

    # parallel: nudge samples off the exact endpoints (t = 0 and t = 1)
    # to avoid degenerate values in the interpolation below
    pos_0 = torch.where(pose_time == 0)
    pose_time[pos_0] = pose_time[pos_0] + 0.000001
    pos_1 = torch.where(pose_time == 1)
    pose_time[pos_1] = pose_time[pos_1] - 0.000001

    q_start, t_start = se3_2_qt_parallel(start_pose)
    q_end, t_end = se3_2_qt_parallel(end_pose)
    # sample t_vector
    t_t = (1 - pose_time)[..., None] * t_start + pose_time[..., None] * t_end

    # sample rotation_vector
    q_tau_0 = q_to_Q_parallel(q_to_q_conj_parallel(q_start)) @ q_end[..., None]
    r = pose_time[..., None] * log_q2r_parallel(q_tau_0.squeeze(-1))
    q_t_0 = exp_r2q_parallel(r)
    q_t = q_to_Q_parallel(q_start) @ q_t_0[..., None]

    # convert q&t to RT
    R = q_to_R_parallel(q_t.squeeze(dim=-1))
    t = t_t.unsqueeze(dim=-1)
    pose_spline = torch.cat([R, t], -1)

    poses = pose_spline.reshape([-1, 3, 4])

    return poses


def SplineN_cubic(pose0, pose1, pose2, pose3, poses_number, NUM):
    sample_time = poses_number / (NUM - 1)
    # parallel: nudge samples off the exact endpoints, as in SplineN_linear
    pos_0 = torch.where(sample_time == 0)
    sample_time[pos_0] = sample_time[pos_0] + 0.000001
    pos_1 = torch.where(sample_time == 1)
    sample_time[pos_1] = sample_time[pos_1] - 0.000001

    sample_time = sample_time.unsqueeze(-1)

    q0, t0 = se3_2_qt_parallel(pose0)
    q1, t1 = se3_2_qt_parallel(pose1)
    q2, t2 = se3_2_qt_parallel(pose2)
    q3, t3 = se3_2_qt_parallel(pose3)

    u = sample_time
    uu = sample_time**2
    uuu = sample_time**3
    one_over_six = 1.0 / 6.0
    half_one = 0.5

    # translation: uniform cubic B-spline basis coefficients
    coeff0 = one_over_six - half_one * u + half_one * uu - one_over_six * uuu
    coeff1 = 4 * one_over_six - uu + half_one * uuu
    coeff2 = one_over_six + half_one * u + half_one * uu - half_one * uuu
    coeff3 = one_over_six * uuu

    # spline t
    t_t = coeff0 * t0 + coeff1 * t1 + coeff2 * t2 + coeff3 * t3

    # rotation: cumulative cubic B-spline basis coefficients
    coeff1_r = 5 * one_over_six + half_one * u - half_one * uu + one_over_six * uuu
    coeff2_r = one_over_six + half_one * u + half_one * uu - 2 * one_over_six * uuu
    coeff3_r = one_over_six * uuu

    # spline R
    q_01 = q_to_Q_parallel(q_to_q_conj_parallel(q0)) @ q1[..., None]  # [1]
    q_12 = q_to_Q_parallel(q_to_q_conj_parallel(q1)) @ q2[..., None]  # [2]
    q_23 = q_to_Q_parallel(q_to_q_conj_parallel(q2)) @ q3[..., None]  # [3]

    r_01 = log_q2r_parallel(q_01.squeeze(-1)) * coeff1_r  # [4]
    r_12 = log_q2r_parallel(q_12.squeeze(-1)) * coeff2_r  # [5]
    r_23 = log_q2r_parallel(q_23.squeeze(-1)) * coeff3_r  # [6]

    q_t_0 = exp_r2q_parallel(r_01)  # [7]
    q_t_1 = exp_r2q_parallel(r_12)  # [8]
    q_t_2 = exp_r2q_parallel(r_23)  # [9]

    q_product1 = q_to_Q_parallel(q_t_1) @ q_t_2[..., None]  # [10]
    q_product2 = q_to_Q_parallel(q_t_0) @ q_product1  # [10]
    q_t = q_to_Q_parallel(q0) @ q_product2  # [10]

    R = q_to_R_parallel(q_t.squeeze(-1))
    t = t_t.unsqueeze(dim=-1)

    pose_spline = torch.cat([R, t], -1)

    poses = pose_spline.reshape([-1, 3, 4])

    return poses


================================================
FILE: configs/cozy2room-cubic.txt
================================================
expname = cozy2room-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/blurcozy2room
dataset_type = llff

factor = 1

linear = False
pose_lrate = 1e-4

novel_view = True
factor_pose_novel = 20.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/cozy2room.txt
================================================
expname = cozy2room-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/blurcozy2room
dataset_type = llff

factor = 1

linear = True

novel_view = True
factor_pose_novel = 2.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/factory-cubic.txt
================================================
expname = factory-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/blurfactory
dataset_type = llff

factor = 1

linear = False
pose_lrate = 1e-4

novel_view = True
factor_pose_novel = 200.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/factory.txt
================================================
expname = factory-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/blurfactory
dataset_type = llff

factor = 1

linear = True

novel_view = True
factor_pose_novel = 20.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/pool-cubic.txt
================================================
expname = pool-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/blurpool
dataset_type = llff

factor = 1

linear = False
pose_lrate = 1e-4

novel_view = True
factor_pose_novel = 20.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/pool.txt
================================================
expname = pool-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/blurpool
dataset_type = llff

factor = 1

linear = True

novel_view = True
factor_pose_novel = 2.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/roomdark-cubic.txt
================================================
expname = roomdark-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/dark
dataset_type = llff

factor = 1

linear = False
pose_lrate = 1e-4

novel_view = False
factor_pose_novel = 2.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/roomdark.txt
================================================
expname = roomdark-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/dark
dataset_type = llff

factor = 1

linear = True

novel_view = False
factor_pose_novel = 2.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/roomhigh-cubic.txt
================================================
expname = roomhigh-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/roomblur_high
dataset_type = llff

factor = 1

linear = False
pose_lrate = 1e-4

novel_view = False
factor_pose_novel = 2.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/roomhigh.txt
================================================
expname = roomhigh-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/roomblur_high
dataset_type = llff

factor = 1

linear = True

novel_view = False
factor_pose_novel = 2.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/roomlow-cubic.txt
================================================
expname = roomlow-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/roomblur_low
dataset_type = llff

factor = 1

linear = False
pose_lrate = 1e-4

novel_view = False
factor_pose_novel = 2.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/roomlow.txt
================================================
expname = roomlow-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/roomblur_low
dataset_type = llff

factor = 1

linear = True

novel_view = False
factor_pose_novel = 2.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/tanabata-cubic.txt
================================================
expname = tanabata-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/blurtanabata
dataset_type = llff

factor = 1

linear = False
pose_lrate = 1e-4

novel_view = True
factor_pose_novel = 50.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/tanabata.txt
================================================
expname = tanabata-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/blurtanabata
dataset_type = llff

factor = 1

linear = True

novel_view = True
factor_pose_novel = 5.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/wine-cubic.txt
================================================
expname = wine-cubic
basedir = ./logs
datadir = ./data/nerf_llff_data/blurwine
dataset_type = llff

factor = 1

linear = False
pose_lrate = 1e-4

novel_view = True
factor_pose_novel = 20.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: configs/wine.txt
================================================
expname = wine-linear
basedir = ./logs
datadir = ./data/nerf_llff_data/blurwine
dataset_type = llff

factor = 1

linear = True

novel_view = True
factor_pose_novel = 2.0
i_novel_view = 200000

N_rand = 5000
deblur_images = 7

N_samples = 64
N_importance = 64

use_viewdirs = True
raw_noise_std = 1.0

load_weights = False
weight_iter = 200000

i_img = 25000
i_video = 200000
i_weights = 10000


================================================
FILE: load_llff.py
================================================
import numpy as np
import os, imageio
import torch


########## Slightly modified version of LLFF data loading code
##########  see https://github.com/Fyusion/LLFF for original

# downsample
def _minify(basedir, factors=[], resolutions=[]):  # basedir: ./data/nerf_llff_data/fern
    needtoload = False
    for r in factors:
        imgdir = os.path.join(basedir, 'images_{}'.format(r))
        if not os.path.exists(imgdir):
            needtoload = True
    for r in resolutions:
        imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))
        if not os.path.exists(imgdir):
            needtoload = True
    if not needtoload:
        return

    from shutil import copy
    from subprocess import check_output

    imgdir = os.path.join(basedir, 'images')
    imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]
    imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]
    imgdir_orig = imgdir

    wd = os.getcwd()

    for r in factors + resolutions:
        if isinstance(r, int):
            name = 'images_{}'.format(r)
            resizearg = '{}%'.format(100. / r)
        else:
            name = 'images_{}x{}'.format(r[1], r[0])
            resizearg = '{}x{}'.format(r[1], r[0])
        imgdir = os.path.join(basedir, name)
        if os.path.exists(imgdir):
            continue

        print('Minifying', r, basedir)

        os.makedirs(imgdir)
        check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)

        ext = imgs[0].split('.')[-1]
        args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])
        print(args)
        os.chdir(imgdir)
        check_output(args, shell=True)
        os.chdir(wd)

        if ext != 'png':
            check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)
            print('Removed duplicates')
        print('Done')


def _load_data(basedir, pose_state, factor=None, width=None, height=None, load_imgs=True):

    poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))

    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
    bds = poses_arr[:, -2:].transpose([1, 0])

    img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \
            if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]
    sh = imageio.imread(img0).shape

    sfx = ''

    if factor is not None:
        sfx = '_{}'.format(factor)
        _minify(basedir, factors=[factor])
        factor = factor
    elif height is not None:
        factor = sh[0] / float(height)
        width = int(sh[1] / factor)
        _minify(basedir, resolutions=[[height, width]])
        sfx = '_{}x{}'.format(width, height)
    elif width is not None:
        factor = sh[1] / float(width)
        height = int(sh[0] / factor)
        _minify(basedir, resolutions=[[height, width]])
        sfx = '_{}x{}'.format(width, height)
    else:
        factor = 1

    imgdir = os.path.join(basedir, 'images' + sfx)
    if not os.path.exists(imgdir):
        print(imgdir, 'does not exist, returning')
        return

    imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if
                f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]

    if poses.shape[-1] != len(imgfiles):
        print('Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]))
        # return

    sh = imageio.imread(imgfiles[0]).shape
    poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
    poses[2, 4, :] = poses[2, 4, :] * 1. / factor

    if not load_imgs:
        return poses, bds

    def imread(f):
        if f.endswith('png'):
            return imageio.imread(f, ignoregamma=True)
        else:
            return imageio.imread(f)

    imgs = [imread(f)[..., :3] / 255. for f in imgfiles]
    imgs = np.stack(imgs, -1)

    print('Loaded image data', imgs.shape, poses[:, -1, 0])

    return poses, bds, imgs


def normalize(x):
    return x / np.linalg.norm(x)


def viewmatrix(z, up, pos):
    vec2 = normalize(z)
    vec1_avg = up
    vec0 = normalize(np.cross(vec1_avg, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    m = np.stack([vec0, vec1, vec2, pos], 1)
    return m


def ptstocam(pts, c2w):
    tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0]
    return tt


def poses_avg(poses):
    hwf = poses[0, :3, -1:]

    center = poses[:, :3, 3].mean(0)
    vec2 = normalize(poses[:, :3, 2].sum(0))
    up = poses[:, :3, 1].sum(0)
    c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)

    return c2w


def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
    render_poses = []
    rads = np.array(list(rads) + [1.])
    hwf = c2w[:, 4:5]

    for theta in np.linspace(0., 2. * np.pi * rots, N + 1)[:-1]:
        c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)
        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
        render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
    return render_poses


def recenter_poses(poses):
    poses_ = poses + 0
    bottom = np.reshape([0, 0, 0, 1.], [1, 4])
    c2w = poses_avg(poses)
    c2w = np.concatenate([c2w[:3, :4], bottom], -2)
    bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
    poses = np.concatenate([poses[:, :3, :4], bottom], -2)

    poses = np.linalg.inv(c2w) @ poses
    poses_[:, :3, :4] = poses[:, :3, :4]
    poses = poses_
    return poses


#####################


def spherify_poses(poses, bds):
    p34_to_44 = lambda p: np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1)

    rays_d = poses[:, :3, 2:3]
    rays_o = poses[:, :3, 3:4]

    def min_line_dist(rays_o, rays_d):
        A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
        b_i = -A_i @ rays_o
        pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0))
        return pt_mindist

    pt_mindist = min_line_dist(rays_o, rays_d)

    center = pt_mindist
    up = (poses[:, :3, 3] - center).mean(0)

    vec0 = normalize(up)
    vec1 = normalize(np.cross([.1, .2, .3], vec0))
    vec2 = normalize(np.cross(vec0, vec1))
    pos = center
    c2w = np.stack([vec1, vec2, vec0, pos], 1)

    poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])

    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))

    sc = 1. / rad
    poses_reset[:, :3, 3] *= sc
    bds *= sc
    rad *= sc

    centroid = np.mean(poses_reset[:, :3, 3], 0)
    zh = centroid[2]
    radcircle = np.sqrt(rad ** 2 - zh ** 2)
    new_poses = []

    for th in np.linspace(0., 2. * np.pi, 120):
        camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
        up = np.array([0, 0, -1.])

        vec2 = normalize(camorigin)
        vec0 = normalize(np.cross(vec2, up))
        vec1 = normalize(np.cross(vec2, vec0))
        pos = camorigin
        p = np.stack([vec0, vec1, vec2, pos], 1)

        new_poses.append(p)

    new_poses = np.stack(new_poses, 0)

    new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1)
    poses_reset = np.concatenate(
        [poses_reset[:, :3, :4], np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)], -1)

    return poses_reset, new_poses, bds


def load_llff_data(basedir, pose_state=None, factor=1, recenter=True, bd_factor=.75, spherify=False,
                   path_zflat=False):
    poses, bds, imgs = _load_data(basedir, pose_state=pose_state, factor=factor)
    print('Loaded', basedir, bds.min(), bds.max(), 'Pose State: ', pose_state)

    poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
    imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
    bds = np.moveaxis(bds, -1, 0).astype(np.float32)

    # Rescale if bd_factor is provided
    sc = 1. if bd_factor is None else 1. / (bds.min() * bd_factor)
    poses[:, :3, 3] *= sc  # T
    bds *= sc

    if recenter:
        poses = recenter_poses(poses)

    if spherify:
        poses, render_poses, bds = spherify_poses(poses, bds)

    else:
        c2w = poses_avg(poses)
        up = normalize(poses[:, :3, 1].sum(0))

        # Find a reasonable "focus depth" for this dataset
        close_depth, inf_depth = bds.min() * .9, bds.max() * 5.
        dt = .75
        mean_dz = 1. / (((1. - dt) / close_depth + dt / inf_depth))
        focal = mean_dz

        # Get radii for spiral path
        shrink_factor = .8
        zdelta = close_depth * .2
        tt = poses[:, :3, 3]  # ptstocam(poses[:3,3,:].T, c2w).T
        rads = np.percentile(np.abs(tt), 90, 0)
        c2w_path = c2w
        N_views = 120
        N_rots = 2
        if path_zflat:
            #             zloc = np.percentile(tt, 10, 0)[2]
            zloc = -close_depth * .1
            c2w_path[:3, 3] = c2w_path[:3, 3] + zloc * c2w_path[:3, 2]
            rads[2] = 0.
            N_rots = 1
            N_views /= 2

        # Generate poses for spiral path
        render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views)

    render_poses = np.array(render_poses).astype(np.float32)

    imgs = torch.Tensor(imgs)
    poses = torch.Tensor(poses)
    bds = torch.Tensor(bds)
    render_poses = torch.Tensor(render_poses)

    return imgs, poses, bds, render_poses



================================================
FILE: lpips/__init__.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import torch
# from torch.autograd import Variable

from .lpips import *


================================================
FILE: lpips/lpips.py
================================================
from __future__ import absolute_import

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import numpy as np
from . import pretrained_networks as pn
import torch.nn


def normalize_tensor(in_feat, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True))
    return in_feat / (norm_factor + eps)


def l2(p0, p1, range=255.):
    return .5 * np.mean((p0 / range - p1 / range) ** 2)


def psnr(p0, p1, peak=255.):
    return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2))


def dssim(p0, p1, range=255.):
    from skimage.measure import compare_ssim
    return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.


def rgb2lab(in_img, mean_cent=False):
    from skimage import color
    img_lab = color.rgb2lab(in_img)
    if mean_cent:
        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
    return img_lab


def tensor2np(tensor_obj):
    # change dimension of a tensor object into a numpy array
    return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))


def np2tensor(np_obj):
    # change dimenion of np array into tensor array
    return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))


def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
    # image tensor to lab tensor
    from skimage import color

    img = tensor2im(image_tensor)
    img_lab = color.rgb2lab(img)
    if mc_only:
        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
    if to_norm and not mc_only:
        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
        img_lab = img_lab / 100.

    return np2tensor(img_lab)


def tensorlab2tensor(lab_tensor, return_inbnd=False):
    from skimage import color
    import warnings
    warnings.filterwarnings("ignore")

    lab = tensor2np(lab_tensor) * 100.
    lab[:, :, 0] = lab[:, :, 0] + 50

    rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1)
    if return_inbnd:
        # convert back to lab, see if we match
        lab_back = color.rgb2lab(rgb_back.astype('uint8'))
        mask = 1. * np.isclose(lab_back, lab, atol=2.)
        mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis])
        return im2tensor(rgb_back), mask
    else:
        return im2tensor(rgb_back)


def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
    image_numpy = image_tensor[0].cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
    return image_numpy.astype(imtype)


def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
    return torch.tensor((image / factor - cent)
                        [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))


def tensor2vec(vector_tensor):
    return vector_tensor.data.cpu().numpy()[:, :, 0, 0]


def voc_ap(rec, prec, use_07_metric=False):
    """ ap = voc_ap(rec, prec, [use_07_metric])
    Compute VOC AP given precision and recall.
    If use_07_metric is true, uses the
    VOC 07 11 point method (default:False).
    """
    if use_07_metric:
        # 11 point metric
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(rec >= t) == 0:
                p = 0
            else:
                p = np.max(prec[rec >= t])
            ap = ap + p / 11.
    else:
        # correct AP calculation
        # first append sentinel values at the end
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))

        # compute the precision envelope
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

        # to calculate area under PR curve, look for points
        # where X axis (recall) changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]

        # and sum (\Delta recall) * prec
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap

def spatial_average(in_tens, keepdim=True):
    return in_tens.mean([2, 3], keepdim=keepdim)


def upsample(in_tens, out_HW=(64, 64)):  # assumes scale factor is same for H and W
    in_H, in_W = in_tens.shape[2], in_tens.shape[3]
    return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)


# Learned perceptual metric
class LPIPS(nn.Module):
    def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
                 pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True):
        # lpips - [True] means with linear calibration on top of base network
        # pretrained - [True] means load linear weights

        super(LPIPS, self).__init__()
        if verbose:
            print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' %
                  ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips  # false means baseline of just averaging all layers
        self.version = version
        self.scaling_layer = ScalingLayer()

        if self.pnet_type in ['vgg', 'vgg16']:
            net_type = pn.vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == 'alex':
            net_type = pn.alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == 'squeeze':
            net_type = pn.squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if lpips:
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
            if self.pnet_type == 'squeeze':  # 7 layers for squeezenet
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                self.lins += [self.lin5, self.lin6]
            self.lins = nn.ModuleList(self.lins)

            if pretrained:
                if model_path is None:
                    import inspect
                    import os
                    model_path = os.path.abspath(
                        os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))

                if verbose:
                    print('Loading model from: %s' % model_path)
                self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)

        if eval_mode:
            self.eval()

    def forward(self, in0, in1, retPerLayer=False, normalize=False):
        if normalize:  # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        # v0.0 - original release had a bug, where input was not scaled
        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else (
        in0, in1)
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        for kk in range(self.L):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

        if self.lpips:
            if self.spatial:
                res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if self.spatial:
                res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]

        val = res[0]
        for l in range(1, self.L):
            val += res[l]

        # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
        # b = torch.max(self.lins[kk](feats0[kk]**2))
        # for kk in range(self.L):
        #     a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
        #     b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2)))
        # a = a/self.L
        # from IPython import embed
        # embed()
        # return 10*torch.log10(b/a)

        if retPerLayer:
            return val, res
        else:
            return val


class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
    ''' A single linear layer which does a 1x1 conv '''

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()

        layers = [nn.Dropout(), ] if use_dropout else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class Dist2LogitLayer(nn.Module):
    ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''

    def __init__(self, chn_mid=32, use_sigmoid=True):
        super(Dist2LogitLayer, self).__init__()

        layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ]
        layers += [nn.LeakyReLU(0.2, True), ]
        layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ]
        layers += [nn.LeakyReLU(0.2, True), ]
        layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ]
        if use_sigmoid:
            layers += [nn.Sigmoid(), ]
        self.model = nn.Sequential(*layers)

    def forward(self, d0, d1, eps=0.1):
        return self.model.forward(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1))


class BCERankingLoss(nn.Module):
    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        # self.parameters = list(self.net.parameters())
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        per = (judge + 1.) / 2.
        self.logit = self.net.forward(d0, d1)
        return self.loss(self.logit, per)


# L2, DSSIM metrics
class FakeNet(nn.Module):
    def __init__(self, use_gpu=True, colorspace='Lab'):
        super(FakeNet, self).__init__()
        self.use_gpu = use_gpu
        self.colorspace = colorspace


class L2(FakeNet):
    def forward(self, in0, in1, retPerLayer=None):
        assert (in0.size()[0] == 1)  # currently only supports batchSize 1

        if self.colorspace == 'RGB':
            (N, C, X, Y) = in0.size()
            value = torch.mean(torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y),
                               dim=3).view(N)
            return value
        elif self.colorspace == 'Lab':
            value = l2(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
                             tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype(
                'float')
            ret_var = Variable(torch.Tensor((value,)))
            if self.use_gpu:
                ret_var = ret_var.cuda()
            return ret_var


class DSSIM(FakeNet):

    def forward(self, in0, in1, retPerLayer=None):
        assert (in0.size()[0] == 1)  # currently only supports batchSize 1

        if self.colorspace == 'RGB':
            value = dssim(1. * tensor2im(in0.data), 1. * tensor2im(in1.data), range=255.).astype(
                'float')
        elif self.colorspace == 'Lab':
            value = dssim(tensor2np(tensor2tensorlab(in0.data, to_norm=False)),
                                tensor2np(tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype(
                'float')
        ret_var = Variable(torch.Tensor((value,)))
        if self.use_gpu:
            ret_var = ret_var.cuda()
        return ret_var


def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print('Network', net)
    print('Total number of parameters: %d' % num_params)


================================================
FILE: lpips/pretrained_networks.py
================================================
from collections import namedtuple
import torch
from torchvision import models as tv


class squeezenet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(squeezenet, self).__init__()
        pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.slice6 = torch.nn.Sequential()
        self.slice7 = torch.nn.Sequential()
        self.N_slices = 7
        for x in range(2):
            self.slice1.add_module(str(x), pretrained_features[x])
        for x in range(2, 5):
            self.slice2.add_module(str(x), pretrained_features[x])
        for x in range(5, 8):
            self.slice3.add_module(str(x), pretrained_features[x])
        for x in range(8, 10):
            self.slice4.add_module(str(x), pretrained_features[x])
        for x in range(10, 11):
            self.slice5.add_module(str(x), pretrained_features[x])
        for x in range(11, 12):
            self.slice6.add_module(str(x), pretrained_features[x])
        for x in range(12, 13):
            self.slice7.add_module(str(x), pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1 = h
        h = self.slice2(h)
        h_relu2 = h
        h = self.slice3(h)
        h_relu3 = h
        h = self.slice4(h)
        h_relu4 = h
        h = self.slice5(h)
        h_relu5 = h
        h = self.slice6(h)
        h_relu6 = h
        h = self.slice7(h)
        h_relu7 = h
        vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7'])
        out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7)

        return out


class alexnet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(alexnet, self).__init__()
        alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(2):
            self.slice1.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(2, 5):
            self.slice2.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(5, 8):
            self.slice3.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(8, 10):
            self.slice4.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(10, 12):
            self.slice5.add_module(str(x), alexnet_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1 = h
        h = self.slice2(h)
        h_relu2 = h
        h = self.slice3(h)
        h_relu3 = h
        h = self.slice4(h)
        h_relu4 = h
        h = self.slice5(h)
        h_relu5 = h
        alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
        out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)

        return out


class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)

        return out


class resnet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True, num=18):
        super(resnet, self).__init__()
        if (num == 18):
            self.net = tv.resnet18(pretrained=pretrained)
        elif (num == 34):
            self.net = tv.resnet34(pretrained=pretrained)
        elif (num == 50):
            self.net = tv.resnet50(pretrained=pretrained)
        elif (num == 101):
            self.net = tv.resnet101(pretrained=pretrained)
        elif (num == 152):
            self.net = tv.resnet152(pretrained=pretrained)
        self.N_slices = 5

        self.conv1 = self.net.conv1
        self.bn1 = self.net.bn1
        self.relu = self.net.relu
        self.maxpool = self.net.maxpool
        self.layer1 = self.net.layer1
        self.layer2 = self.net.layer2
        self.layer3 = self.net.layer3
        self.layer4 = self.net.layer4

    def forward(self, X):
        h = self.conv1(X)
        h = self.bn1(h)
        h = self.relu(h)
        h_relu1 = h
        h = self.maxpool(h)
        h = self.layer1(h)
        h_conv2 = h
        h = self.layer2(h)
        h_conv3 = h
        h = self.layer3(h)
        h_conv4 = h
        h = self.layer4(h)
        h_conv5 = h

        outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)

        return out
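
# Illustrative sketch (editorial, not part of the original file): run a dummy
# batch through the vgg16 wrapper above and print the shape of each of the
# five feature slices that LPIPS compares. Downloads torchvision weights on
# first use.
if __name__ == '__main__':
    _net = vgg16(requires_grad=False, pretrained=True).eval()
    _feats = _net(torch.randn(1, 3, 64, 64))  # an already-normalized batch
    for _name, _f in zip(_feats._fields, _feats):
        print(_name, tuple(_f.shape))  # e.g. relu1_2 (1, 64, 64, 64)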


================================================
FILE: metrics.py
================================================
from skimage import metrics
import torch
import torch.hub
from lpips.lpips import LPIPS
import os
import numpy as np

photometric = {
    "mse": None,
    "ssim": None,
    "psnr": None,
    "lpips": None
}


def compute_img_metric(im1t: torch.Tensor, im2t: torch.Tensor,
                       metric="mse", margin=0, mask=None):
    """
    im1t, im2t: torch.tensors with batched imaged shape, range from (0, 1)
    """
    if metric not in photometric.keys():
        raise RuntimeError(f"img_utils:: metric {metric} not recognized")
    if photometric[metric] is None:
        if metric == "mse":
            photometric[metric] = metrics.mean_squared_error
        elif metric == "ssim":
            photometric[metric] = metrics.structural_similarity
        elif metric == "psnr":
            photometric[metric] = metrics.peak_signal_noise_ratio
        elif metric == "lpips":
            photometric[metric] = LPIPS().cpu()

    if mask is not None:
        if mask.dim() == 3:
            mask = mask.unsqueeze(1)
        if mask.shape[1] == 1:
            mask = mask.expand(-1, 3, -1, -1)
        mask = mask.permute(0, 2, 3, 1).numpy()
        batchsz, hei, wid, _ = mask.shape
        if margin > 0:
            marginh = int(hei * margin) + 1
            marginw = int(wid * margin) + 1
            mask = mask[:, marginh:hei - marginh, marginw:wid - marginw]

    # convert from [0, 1] to [-1, 1]
    im1t = (im1t * 2 - 1).clamp(-1, 1)
    im2t = (im2t * 2 - 1).clamp(-1, 1)

    if im1t.dim() == 3:
        im1t = im1t.unsqueeze(0)
        im2t = im2t.unsqueeze(0)
    im1t = im1t.detach().cpu()
    im2t = im2t.detach().cpu()

    if im1t.shape[-1] == 3:
        im1t = im1t.permute(0, 3, 1, 2)
        im2t = im2t.permute(0, 3, 1, 2)

    im1 = im1t.permute(0, 2, 3, 1).numpy()
    im2 = im2t.permute(0, 2, 3, 1).numpy()
    batchsz, hei, wid, _ = im1.shape
    if margin > 0:
        marginh = int(hei * margin) + 1
        marginw = int(wid * margin) + 1
        im1 = im1[:, marginh:hei - marginh, marginw:wid - marginw]
        im2 = im2[:, marginh:hei - marginh, marginw:wid - marginw]
    values = []

    for i in range(batchsz):
        if metric in ["mse", "psnr"]:
            if mask is not None:
                im1 = im1 * mask[i]
                im2 = im2 * mask[i]
            value = photometric[metric](
                im1[i], im2[i]
            )
            if mask is not None:
                hei, wid, _ = im1[i].shape
                pixelnum = mask[i, ..., 0].sum()
                value = value - 10 * np.log10(hei * wid / pixelnum)
        elif metric in ["ssim"]:
            value, ssimmap = photometric["ssim"](
                im1[i], im2[i], multichannel=True, full=True
            )
            if mask is not None:
                value = (ssimmap * mask[i]).sum() / mask[i].sum()
        elif metric in ["lpips"]:
            value = photometric[metric](
                im1t[i:i + 1], im2t[i:i + 1]
            )
        else:
            raise NotImplementedError
        values.append(value)

    return sum(values) / len(values)
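
# Illustrative sketch (editorial, not part of the original file): comparing a
# random batch against a slightly brightened copy. Inputs may be (N, H, W, 3)
# or (N, 3, H, W) tensors with values in [0, 1].
if __name__ == '__main__':
    _a = torch.rand(1, 16, 16, 3)
    _b = (_a + 0.1).clamp(0, 1)
    print('mse :', compute_img_metric(_a, _b, metric='mse'))
    print('psnr:', compute_img_metric(_a, _b, metric='psnr'))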


================================================
FILE: nerf.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

from run_nerf import *

from Spline import se3_to_SE3


max_iter = 200000
T = max_iter+1
BOUNDARY = 20


class Embedder:
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


def get_embedder(args, multires, i=0):
    if i == -1:
        return nn.Identity(), 3

    embed_kwargs = {
        'include_input': False if args.barf else True,
        'input_dims': 3,
        'max_freq_log2': multires - 1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)
    embed = lambda x, eo=embedder_obj: eo.embed(x)
    return embed, embedder_obj.out_dim
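
# Illustrative sketch (editorial, not part of the original file): with
# multires=10 and include_input=True (i.e. args.barf is False) the encoding
# maps 3-D points to 3 + 3*2*10 = 63 channels, matching input_ch=63 below.
# `_demo_args` is a stand-in for the real parsed config.
if __name__ == '__main__':
    from types import SimpleNamespace
    _demo_args = SimpleNamespace(barf=False)
    _embed, _out_dim = get_embedder(_demo_args, multires=10, i=0)
    print(_out_dim, _embed(torch.rand(5, 3)).shape)  # 63 torch.Size([5, 63])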


class Model():
    def __init__(self):
        super().__init__()

    def build_network(self, args):
        self.graph = Graph(args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True)

        return self.graph

    def setup_optimizer(self, args):
        grad_vars = list(self.graph.nerf.parameters())
        if args.N_importance>0:
            grad_vars += list(self.graph.nerf_fine.parameters())
        self.optim = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

        return self.optim


class NeRF(nn.Module):
    def __init__(self, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=False):
        super().__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.skips = skips
        self.use_viewdirs = use_viewdirs

        # network
        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in
                                        range(D - 1)])

        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)])

        if use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W // 2, 3)
        else:
            self.output_linear = nn.Linear(W, output_ch)

    def forward(self, barf_i, pts, viewdirs, args):
        embed_fn, input_ch = get_embedder(args, args.multires, args.i_embed)
        embeddirs_fn = None
        if args.use_viewdirs:
            embeddirs_fn, input_ch_views = get_embedder(args, args.multires_views, args.i_embed)
        pts_flat = torch.reshape(pts, [-1, pts.shape[-1]])
        embedded = embed_fn(pts_flat)

        if viewdirs is not None:
            input_dirs = viewdirs[:, None].expand(pts.shape)
            input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
            embedded_dirs = embeddirs_fn(input_dirs_flat)
            embedded = torch.cat([embedded, embedded_dirs], -1)

        input_pts, input_views = torch.split(embedded, [self.input_ch, self.input_ch_views], dim=-1)
        h = input_pts
        for i, l in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = F.relu(h)
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        if self.use_viewdirs:
            alpha = self.alpha_linear(h)
            feature = self.feature_linear(h)
            h = torch.cat([feature, input_views], -1)

            for i, l in enumerate(self.views_linears):
                h = self.views_linears[i](h)
                h = F.relu(h)

            rgb = self.rgb_linear(h)
            outputs = torch.cat([rgb, alpha], -1)
        else:
            outputs = self.output_linear(h)

        outputs = torch.reshape(outputs, list(pts.shape[:-1]) + [outputs.shape[-1]])

        return outputs

    def raw2output(self, raw, z_vals, rays_d, raw_noise_std=0.0):

        raw2alpha = lambda raw, dists, act_fn=F.relu: 1. - torch.exp(-act_fn(raw) * dists)

        dists = z_vals[..., 1:] - z_vals[..., :-1]
        dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)], -1)

        dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

        rgb = torch.sigmoid(raw[..., :3])
        noise = 0.
        if raw_noise_std > 0.:
            noise = torch.randn(raw[..., 3].shape) * raw_noise_std

        alpha = raw2alpha(raw[..., 3] + noise, dists)
        weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1. - alpha + 1e-10], -1), -1)[:,:-1]
        rgb_map = torch.sum(weights[..., None] * rgb, -2)

        depth_map = torch.sum(weights * z_vals, -1)
        # disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1))
        disp_map = torch.max(1e-6 * torch.ones_like(depth_map), depth_map / (torch.sum(weights, -1)+1e-6))
        acc_map = torch.sum(weights, -1)

        sigma = F.relu(raw[..., 3] + noise)

        return rgb_map, disp_map, acc_map, weights, depth_map, sigma
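
# Illustrative sketch (editorial, not part of the original file): the
# compositing rule used in raw2output. Sample k gets weight
# alpha_k * prod_{j<k}(1 - alpha_j), so a fully opaque first sample takes
# all the weight and occludes the samples behind it.
if __name__ == '__main__':
    _alpha = torch.tensor([[1.0, 0.5, 0.5]])
    _trans = torch.cumprod(
        torch.cat([torch.ones((1, 1)), 1. - _alpha + 1e-10], -1), -1)[:, :-1]
    print(_alpha * _trans)  # ~tensor([[1., 0., 0.]])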


class Graph(nn.Module):

    def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=False):
        super().__init__()
        self.nerf = NeRF(D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs)
        if args.N_importance > 0:
            self.nerf_fine = NeRF(D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs)

    def forward(self, i, img_idx, poses_num, H, W, K, args, novel_view=False):
        if novel_view:
            poses_sharp = se3_to_SE3(self.se3_sharp.weight)
            ray_idx_sharp = torch.randperm(H * W)[:300]
            ret = self.render(i, poses_sharp, ray_idx_sharp, H, W, K, args)
            return ret, ray_idx_sharp, poses_sharp

        spline_poses = self.get_pose(i, img_idx, args)
        ray_idx = torch.randperm(H * W)[:args.N_rand // poses_num]

        '''
        # only used in distorted data
        # aims to prevent the ray_idx lying on the edges
        for j in range(ray_idx.shape[0]):
            h = torch.randperm(H - 1)[0]
            w = torch.randperm(W - 1)[0]
            while (h < BOUNDARY or h > (H - 1 - BOUNDARY) or w < BOUNDARY or w > (W - 1 - BOUNDARY)):
                h = torch.randperm(H - 1)[0]
                w = torch.randperm(W - 1)[0]
            index = h * W + w
            ray_idx[j] = index
        '''

        ret = self.render(i, spline_poses, ray_idx, H, W, K, args, near=0, far=1.0, ray_idx_tv=None, training=True)
        if (i % args.i_img == 0 or i % args.i_novel_view == 0) and i > 0:
            if args.deblur_images % 2 == 0:
                all_poses = self.get_pose_even(i, torch.arange(self.se3.weight.shape[0]), args.deblur_images)
            else:
                all_poses = self.get_pose(i, torch.arange(self.se3.weight.shape[0]), args)
            return ret, ray_idx, spline_poses, all_poses
        else:
            return ret, ray_idx, spline_poses

    def get_pose(self, i, img_idx, args):
        # placeholder; overridden by the spline graphs in optimize_pose_*.py
        return i

    def get_gt_pose(self, poses, args):

        return poses

    def render(self, barf_i, poses, ray_idx, H, W, K, args, near=0., far=1., ray_idx_tv=None, training=False):
        if training:
            ray_idx_ = ray_idx.repeat(poses.shape[0])
            poses = poses.unsqueeze(1).repeat(1, ray_idx.shape[0], 1, 1).reshape(-1, 3, 4)
            j = ray_idx_.reshape(-1, 1).squeeze() // W
            i = ray_idx_.reshape(-1, 1).squeeze() % W
            rays_o_, rays_d_ = get_specific_rays(i, j, K, poses)
            rays_o_d = torch.stack([rays_o_, rays_d_], 0)
            batch_rays = torch.permute(rays_o_d, [1, 0, 2])
        
        else:
            rays_list = []
            for p in poses[:, :3, :4]:
                rays_o_, rays_d_ = get_rays(H, W, K, p)
                rays_o_d = torch.stack([rays_o_, rays_d_], 0)
                rays_list.append(rays_o_d)

            rays = torch.stack(rays_list, 0)
            rays = rays.reshape(-1, 2, H * W, 3)
            rays = torch.permute(rays, [0, 2, 1, 3])

            batch_rays = rays[:, ray_idx]
        batch_rays = batch_rays.reshape(-1, 2, 3)
        batch_rays = torch.transpose(batch_rays, 0, 1)

        # get standard rays
        rays_o, rays_d = batch_rays
        if args.use_viewdirs:
            viewdirs = rays_d
            viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
            viewdirs = torch.reshape(viewdirs, [-1, 3]).float()

        sh = rays_d.shape
        if args.ndc:
            rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

        # Create ray batch
        rays_o = torch.reshape(rays_o, [-1, 3]).float()
        rays_d = torch.reshape(rays_d, [-1, 3]).float()

        near, far = near * torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1])
        rays = torch.cat([rays_o, rays_d, near, far], -1)

        if args.use_viewdirs:
            rays = torch.cat([rays, viewdirs], -1)

        N_rays = rays.shape[0]
        rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]
        viewdirs = rays[:, -3:] if rays.shape[-1] > 8 else None
        bounds = torch.reshape(rays[..., 6:8], [-1, 1, 2])
        near, far = bounds[..., 0], bounds[..., 1]

        t_vals = torch.linspace(0., 1., steps=args.N_samples)
        z_vals = near * (1. - t_vals) + far * (t_vals)
        z_vals = z_vals.expand([N_rays, args.N_samples])

        # perturb
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([mids, z_vals[..., -1:]], -1)
        lower = torch.cat([z_vals[..., :1], mids], -1)
        # stratified samples in those intervals
        t_rand = torch.rand(z_vals.shape)
        z_vals = lower + (upper - lower) * t_rand
        pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]

        raw_output = self.nerf.forward(barf_i, pts, viewdirs, args)

        rgb_map, disp_map, acc_map, weights, depth_map, sigma = self.nerf.raw2output(raw_output, z_vals, rays_d)

        if args.N_importance > 0:

            rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map

            z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
            z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], args.N_importance)
            z_samples = z_samples.detach()

            z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
            pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]

            raw_output = self.nerf_fine.forward(barf_i, pts, viewdirs, args)
            rgb_map, disp_map, acc_map, weights, depth_map, sigma = self.nerf_fine.raw2output(raw_output, z_vals, rays_d)

        ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map}
        if args.N_importance > 0:
            ret['rgb0'] = rgb_map_0
            ret['disp0'] = disp_map_0
            ret['acc0'] = acc_map_0
            ret['sigma'] = sigma

        return ret

    @torch.no_grad()
    def render_video(self, barf_i, poses, H, W, K, args):
        all_ret = {}
        ray_idx = torch.arange(0, H*W)
        for i in range(0, ray_idx.shape[0], args.chunk):
            ret = self.render(barf_i, poses, ray_idx[i:i+args.chunk], H, W, K, args)
            for k in ret:
                if k not in all_ret:
                    all_ret[k] = []
                all_ret[k].append(ret[k])
        all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}

        for k in all_ret:
            k_sh = list([H, W]) + list(all_ret[k].shape[1:])
            all_ret[k] = torch.reshape(all_ret[k], k_sh)
        return all_ret


================================================
FILE: novel_view_test.py
================================================
import nerf
import torch.nn


class Model(nerf.Model):
    def __init__(self, se3_start, graph):
        super().__init__()
        self.start = se3_start
        self.graph_fixed = graph

    def build_network(self, args):
        self.graph_fixed.se3_sharp = torch.nn.Embedding(self.start.shape[0], 6)  # 22 and 25
        self.graph_fixed.se3_sharp.weight.data = torch.nn.Parameter(self.start)

        return self.graph_fixed

    def setup_optimizer(self, args):
        grad_vars_se3 = list(self.graph_fixed.se3_sharp.parameters())
        self.optim_se3_sharp = torch.optim.Adam(params=grad_vars_se3, lr=args.lrate)

        return self.optim_se3_sharp

================================================
FILE: optimize_pose_cubic.py
================================================
import torch.nn

import Spline
import nerf


class Model(nerf.Model):
    def __init__(self, se3_0, se3_1, se3_2, se3_3):
        super().__init__()
        self.se3_0 = se3_0
        self.se3_1 = se3_1
        self.se3_2 = se3_2
        self.se3_3 = se3_3

    def build_network(self, args):
        self.graph = Graph(args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True)
        self.graph.se3 = torch.nn.Embedding(self.se3_0.shape[0], 6*4)

        start_end = torch.cat([self.se3_0, self.se3_1, self.se3_2, self.se3_3], -1)
        self.graph.se3.weight.data = torch.nn.Parameter(start_end)

        return self.graph

    def setup_optimizer(self, args):
        grad_vars = list(self.graph.nerf.parameters())
        if args.N_importance > 0:
            grad_vars += list(self.graph.nerf_fine.parameters())
        self.optim = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

        grad_vars_se3 = list(self.graph.se3.parameters())
        self.optim_se3 = torch.optim.Adam(params=grad_vars_se3, lr=args.lrate)

        return self.optim, self.optim_se3


class Graph(nerf.Graph):
    def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True):
        super().__init__(args, D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs)
        self.pose_eye = torch.eye(3, 4)
        self.se3_start = None
        self.se3_end = None

    def get_pose(self, i, img_idx, args):
        se3_0 = self.se3.weight[:, :6][img_idx]
        se3_1 = self.se3.weight[:, 6:12][img_idx]
        se3_2 = self.se3.weight[:, 12:18][img_idx]
        se3_3 = self.se3.weight[:, 18:][img_idx]

        pose_nums = torch.arange(args.deblur_images).reshape(1, -1).repeat(se3_0.shape[0], 1)
        seg_pos_x = torch.arange(se3_0.shape[0]).reshape([se3_0.shape[0], 1]).repeat(1, args.deblur_images)

        se3_0 = se3_0[seg_pos_x, :]
        se3_1 = se3_1[seg_pos_x, :]
        se3_2 = se3_2[seg_pos_x, :]
        se3_3 = se3_3[seg_pos_x, :]

        spline_poses = Spline.SplineN_cubic(se3_0, se3_1, se3_2, se3_3, pose_nums, args.deblur_images)
        return spline_poses

    def get_pose_even(self, i, img_idx, num):
        deblur_images_num = num+1
        se3_0 = self.se3.weight[:, :6][img_idx]
        se3_1 = self.se3.weight[:, 6:12][img_idx]
        se3_2 = self.se3.weight[:, 12:18][img_idx]
        se3_3 = self.se3.weight[:, 18:][img_idx]

        pose_nums = torch.arange(deblur_images_num).reshape(1, -1).repeat(se3_0.shape[0], 1)
        seg_pos_x = torch.arange(se3_0.shape[0]).reshape([se3_0.shape[0], 1]).repeat(1, deblur_images_num)

        se3_0 = se3_0[seg_pos_x, :]
        se3_1 = se3_1[seg_pos_x, :]
        se3_2 = se3_2[seg_pos_x, :]
        se3_3 = se3_3[seg_pos_x, :]

        spline_poses = Spline.SplineN_cubic(se3_0, se3_1, se3_2, se3_3, pose_nums, deblur_images_num)
        return spline_poses

    def get_gt_pose(self, poses, args):
        a = self.pose_eye
        return poses


================================================
FILE: optimize_pose_linear.py
================================================
import torch.nn

import Spline
import nerf


class Model(nerf.Model):
    def __init__(self, se3_start, se3_end):
        super().__init__()
        self.start = se3_start
        self.end = se3_end

    def build_network(self, args):
        self.graph = Graph(args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True)
        self.graph.se3 = torch.nn.Embedding(self.start.shape[0], 6*2)

        start_end = torch.cat([self.start, self.end], -1)
        self.graph.se3.weight.data = torch.nn.Parameter(start_end)

        return self.graph

    def setup_optimizer(self, args):
        grad_vars = list(self.graph.nerf.parameters())
        if args.N_importance > 0:
            grad_vars += list(self.graph.nerf_fine.parameters())
        self.optim = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

        grad_vars_se3 = list(self.graph.se3.parameters())
        self.optim_se3 = torch.optim.Adam(params=grad_vars_se3, lr=args.lrate)

        return self.optim, self.optim_se3


class Graph(nerf.Graph):
    def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, output_ch=4, skips=[4], use_viewdirs=True):
        super().__init__(args, D, W, input_ch, input_ch_views, output_ch, skips, use_viewdirs)
        self.pose_eye = torch.eye(3, 4)
        self.se3_start = None
        self.se3_end = None

    def get_pose(self, i, img_idx, args):
        se3_start = self.se3.weight[:, :6][img_idx]
        se3_end = self.se3.weight[:, 6:][img_idx]
        pose_nums = torch.arange(args.deblur_images).reshape(1, -1).repeat(se3_start.shape[0], 1)
        seg_pos_x = torch.arange(se3_start.shape[0]).reshape([se3_start.shape[0], 1]).repeat(1, args.deblur_images)

        se3_start = se3_start[seg_pos_x, :]
        se3_end = se3_end[seg_pos_x, :]

        spline_poses = Spline.SplineN_linear(se3_start, se3_end, pose_nums, args.deblur_images)
        return spline_poses
    
    def get_pose_even(self, i, img_idx, num):
        deblur_images_num = num+1
        se3_start = self.se3.weight[:, :6][img_idx]
        se3_end = self.se3.weight[:, 6:][img_idx]
        pose_nums = torch.arange(deblur_images_num).reshape(1, -1).repeat(se3_start.shape[0],1)
        seg_pos_x = torch.arange(se3_start.shape[0]).reshape([se3_start.shape[0], 1]).repeat(1, deblur_images_num)

        se3_start = se3_start[seg_pos_x, :]
        se3_end = se3_end[seg_pos_x, :]

        spline_poses = Spline.SplineN_linear(se3_start, se3_end, pose_nums, deblur_images_num)
        return spline_poses

    def get_gt_pose(self, poses, args):
        a = self.pose_eye
        return poses
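
# Illustrative sketch (editorial, not part of the original file): the index
# grids built in get_pose above. For 2 blurry images with deblur_images=3,
# each image's start/end se(3) vector is expanded to one row per virtual
# sharp pose before Spline.SplineN_linear interpolates between them.
if __name__ == '__main__':
    import torch
    _n_img, _n_virtual = 2, 3
    _pose_nums = torch.arange(_n_virtual).reshape(1, -1).repeat(_n_img, 1)
    _seg_pos_x = torch.arange(_n_img).reshape([_n_img, 1]).repeat(1, _n_virtual)
    print(_pose_nums)  # tensor([[0, 1, 2], [0, 1, 2]])
    print(_seg_pos_x)  # tensor([[0, 0, 0], [1, 1, 1]])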


================================================
FILE: requirements.txt
================================================
configargparse
imageio<2.28.0,>=2.26.0
imageio-ffmpeg
matplotlib
numpy
scikit-learn<1
scikit-image<0.20,>=0.19
torch>=1.8
torchvision>=0.9.1
tqdm

================================================
FILE: run_nerf.py
================================================
import os, sys
import numpy as np
import imageio
import json
import random
import time
import torch
from run_nerf_helpers import *
from Spline import *
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange

import matplotlib.pyplot as plt
from load_llff import load_llff_data
import torchvision.transforms.functional as torchvision_F


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
np.random.seed(0)
DEBUG = False


def config_parser():

    import configargparse
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True, default='configs/cozy2room.txt',
                        help='config file path')
    parser.add_argument("--expname", type=str,
                        help='experiment name')
    parser.add_argument("--basedir", type=str, default='./logs/',
                        help='where to store ckpts and logs')
    parser.add_argument("--datadir", type=str, default='./data/llff/fern',
                        help='input data directory')

    # training options
    parser.add_argument("--N_iters", type=int, default=200000,
                        help='the number of sharp images one blur image corresponds to')
    parser.add_argument("--deblur_images", type=int, default=5,
                        help='the number of sharp images one blur image corresponds to')
    parser.add_argument("--skip", type=int, default=8,
                        help='original llffhold before concatenate images')
    parser.add_argument("--netdepth", type=int, default=8,
                        help='layers in network')
    parser.add_argument("--netwidth", type=int, default=256,
                        help='channels per layer')
    parser.add_argument("--netdepth_fine", type=int, default=8,
                        help='layers in fine network')
    parser.add_argument("--netwidth_fine", type=int, default=256,
                        help='channels per layer in fine network')
    parser.add_argument("--N_rand", type=int, default=32*32*4,
                        help='batch size (number of random rays per gradient step)')
    parser.add_argument("--lrate", type=float, default=5e-4,
                        help='learning rate')
    parser.add_argument("--pose_lrate", type=float, default=1e-3,
                        help='learning rate')
    parser.add_argument("--lrate_decay", type=int, default=200,
                        help='exponential learning rate decay (in 1000 steps)')
    parser.add_argument("--chunk", type=int, default=1024*2,
                        help='number of rays processed in parallel, decrease if running out of memory')
    parser.add_argument("--netchunk", type=int, default=1024*32,
                        help='number of pts sent through network in parallel, decrease if running out of memory')
    parser.add_argument("--no_batching", action='store_true',
                        help='only take random rays from 1 image at a time')
    parser.add_argument("--no_reload", action='store_true',
                        help='do not reload weights from saved ckpt')
    parser.add_argument("--ft_path", type=str, default=None,
                        help='specific weights npy file to reload for coarse network')

    # rendering options
    parser.add_argument("--N_samples", type=int, default=64,
                        help='number of coarse samples per ray')
    parser.add_argument("--N_importance", type=int, default=0,
                        help='number of additional fine samples per ray')
    parser.add_argument("--perturb", type=float, default=1.,
                        help='set to 0. for no jitter, 1. for jitter')
    parser.add_argument("--use_viewdirs", action='store_true',
                        help='use full 5D input instead of 3D')
    parser.add_argument("--i_embed", type=int, default=0,
                        help='set 0 for default positional encoding, -1 for none')
    parser.add_argument("--multires", type=int, default=10,
                        help='log2 of max freq for positional encoding (3D location)')
    parser.add_argument("--multires_views", type=int, default=4,
                        help='log2 of max freq for positional encoding (2D direction)')
    parser.add_argument("--raw_noise_std", type=float, default=0.,
                        help='std dev of noise added to regularize sigma_a output, 1e0 recommended')

    parser.add_argument("--render_only", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')
    parser.add_argument("--render_test", action='store_true',
                        help='render the test set instead of render_poses path')
    parser.add_argument("--render_factor", type=int, default=0,
                        help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
    parser.add_argument("--ndc", type=bool, default=True,
                        help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')

    # training options
    parser.add_argument("--precrop_iters", type=int, default=0,
                        help='number of steps to train on central crops')
    parser.add_argument("--precrop_frac", type=float,
                        default=.5, help='fraction of img taken for central crops')

    # dataset options
    parser.add_argument("--dataset_type", type=str, default='llff',
                        help='options: llff / blender / deepvoxels')
    parser.add_argument("--testskip", type=int, default=8,
                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')

    ## deepvoxels flags
    parser.add_argument("--shape", type=str, default='greek',
                        help='options : armchair / cube / greek / vase')

    ## blender flags
    parser.add_argument("--white_bkgd", action='store_true',
                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')
    parser.add_argument("--half_res", action='store_true',
                        help='load blender synthetic data at 400x400 instead of 800x800')

    ## llff flags
    parser.add_argument("--factor", type=int, default=8,
                        help='downsample factor for LLFF images')
    parser.add_argument("--no_ndc", action='store_true',
                        help='do not use normalized device coordinates (set for non-forward facing scenes)')
    parser.add_argument("--lindisp", action='store_true',
                        help='sampling linearly in disparity rather than depth')
    parser.add_argument("--spherify", action='store_true',
                        help='set for spherical 360 scenes')
    parser.add_argument("--llffhold", type=int, default=8,
                        help='will take every 1/N images as LLFF test set, paper uses 8')

    # logging/saving options
    parser.add_argument("--i_print",   type=int, default=100,
                        help='frequency of console printout and metric loggin')
    parser.add_argument("--i_img",     type=int, default=15000,
                        help='frequency of tensorboard image logging')
    parser.add_argument("--i_weights", type=int, default=10000,
                        help='frequency of weight ckpt saving')
    parser.add_argument("--i_testset", type=int, default=50000,
                        help='frequency of testset saving')
    parser.add_argument("--i_video",   type=int, default=30000,
                        help='frequency of render_poses video saving')
    parser.add_argument("--load_weights", action='store_true',
                        help='frequency of weight ckpt loading')

    # barf: up & down
    parser.add_argument("--barf", action='store_true',
                        help='barf')
    parser.add_argument("--barf_start", type=float, default=0.1,
                        help='barf start')
    parser.add_argument("--barf_end", type=float, default=0.9,
                        help='barf start')

    # test option
    parser.add_argument("--only_optim_one", action='store_true', default=False,
                        help='only optimize a single pose (test option)')

    parser.add_argument("--split_train_data", action='store_true', default=False,
                        help='split the training data (test option)')
    
    parser.add_argument("--weight_iter", type=int, default=20000,
                        help='weight_iter')
    
    # pose noise
    parser.add_argument("--pose_noise", type=float, default=0.1,
                        help='random noise of pose')

    # linear
    parser.add_argument("--linear", action='store_true', default=False,
                        help='linear or cubic spline')

    # novel view
    parser.add_argument("--novel_view", action='store_true', default=False,
                        help='novel view test')
    parser.add_argument("--i_novel_view", type=int, default=200000,
                        help='novel view iter')
    parser.add_argument("--factor_pose_novel", type=float, default=2.0,
                        help='factor of learning rate')
    parser.add_argument("--N_novel_view", type=int, default=20000,
                        help='novel view iter for optimizing poses')

    return parser
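
# Illustrative note (editorial): train.py and test.py consume this parser as
#     parser = config_parser()
#     args = parser.parse_args()
# with defaults layered under the --config file and the command line.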




================================================
FILE: run_nerf_helpers.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from tqdm import tqdm, trange
import os
import imageio


# Misc
img2mse = lambda x, y : torch.mean((x - y) ** 2)
img2se = lambda x, y : (x - y) ** 2
mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))  # logab = logcb / logca
to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
to8b_tensor = lambda x : (255*torch.clip(x,0,1)).type(torch.int)
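
# Illustrative check (editorial, not part of the original file): for images
# in [0, 1], an MSE of 0.01 gives -10*log10(0.01) = 20 dB, and to8b maps
# 1.0 -> 255:
#     mse2psnr(torch.tensor([0.01]))  # tensor([20.])
#     to8b(np.array([0.0, 1.0]))      # array([  0, 255], dtype=uint8)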


def imread(f):
    if f.endswith('png'):
        return imageio.imread(f, ignoregamma=True)
    else:
        return imageio.imread(f)

def load_imgs(path):
    imgfiles = [os.path.join(path, f) for f in sorted(os.listdir(path)) if
                f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]
    imgs = [imread(f)[..., :3] / 255. for f in imgfiles]
    imgs = np.stack(imgs, -1)
    imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
    imgs = imgs.astype(np.float32)
    imgs = torch.tensor(imgs).cuda()

    return imgs


# Ray helpers
def get_rays(H, W, K, c2w):
    i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H))  # pytorch's meshgrid has indexing='ij'
    i = i.t()
    j = j.t()
    dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1)
    # Rotate ray directions from camera frame to the world frame
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = c2w[:3,-1].expand(rays_d.shape)
    return rays_o, rays_d
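
# Illustrative sketch (editorial, not part of the original file): rays for a
# 2x2 image under an identity camera-to-world pose; all ray origins collapse
# to the camera center at the world origin.
if __name__ == '__main__':
    _K = torch.tensor([[2., 0., 1.], [0., 2., 1.], [0., 0., 1.]])
    _rays_o, _rays_d = get_rays(2, 2, _K, torch.eye(3, 4))
    print(_rays_o.shape, _rays_d.shape)  # torch.Size([2, 2, 3]) twice
    print(_rays_o.unique())              # tensor([0.])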


def ndc_rays(H, W, focal, near, rays_o, rays_d):
    # Shift ray origins to near plane
    t = -(near + rays_o[...,2]) / rays_d[...,2]
    rays_o = rays_o + t[...,None] * rays_d
    
    # Projection
    o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2]
    o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2]
    o2 = 1. + 2. * near / rays_o[...,2]

    d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2])
    d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2])
    d2 = -2. * near / rays_o[...,2]
    
    rays_o = torch.stack([o0,o1,o2], -1)
    rays_d = torch.stack([d0,d1,d2], -1)
    
    return rays_o, rays_d


# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    # Get pdf
    weights = weights + 1e-5 # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1)  # (batch, len(bins))

    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])

    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)

    # Invert CDF
    u = u.contiguous()
    inds = torch.searchsorted(cdf, u, right=True)
    below = torch.max(torch.zeros_like(inds-1), inds-1)
    above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[...,1]-cdf_g[...,0])
    denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
    t = (u-cdf_g[...,0])/denom
    samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])

    return samples
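
# Illustrative sketch (editorial, not part of the original file):
# inverse-CDF sampling with uniform weights and det=True spreads the new
# samples evenly over the bin range.
if __name__ == '__main__':
    _bins = torch.linspace(0., 1., steps=5).expand(1, 5)
    _w = torch.ones(1, 4)
    print(sample_pdf(_bins, _w, N_samples=4, det=True))
    # -> roughly tensor([[0.0000, 0.3333, 0.6667, 1.0000]])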


def render_video_test(i_, graph, render_poses, H, W, K, args):
    rgbs = []
    disps = []
    # t = time.time()
    for i, pose in enumerate(tqdm(render_poses)):
        # print(i, time.time() - t)
        # t = time.time()
        pose = pose[None, :3, :4]
        ret = graph.render_video(i_, pose[:3, :4], H, W, K, args)
        rgbs.append(ret['rgb_map'].cpu().numpy())
        disps.append(ret['disp_map'].cpu().numpy())
        if i==0:
            print(ret['rgb_map'].shape, ret['disp_map'].shape)
    rgbs = np.stack(rgbs, 0)
    disps = np.stack(disps, 0)

    return rgbs, disps


def render_image_test(i, graph, render_poses, H, W, K, args, novel_view=False, need_depth=False):
    if novel_view:
        img_dir = os.path.join(args.basedir, args.expname, 'img_novel_{:06d}'.format(i))
    else:
        img_dir = os.path.join(args.basedir, args.expname, 'img_test_{:06d}'.format(i))
    os.makedirs(img_dir, exist_ok=True)
    imgs = []

    for j, pose in enumerate(tqdm(render_poses)):
        # print(i, time.time() - t)
        # t = time.time()
        pose = pose[None, :3, :4]
        ret = graph.render_video(i, pose[:3, :4], H, W, K, args)
        imgs.append(ret['rgb_map'])
        rgbs = ret['rgb_map'].cpu().numpy()
        rgb8 = to8b(rgbs)
        imageio.imwrite(os.path.join(img_dir, 'rgb_{:03d}.png'.format(j)), rgb8)
        if need_depth:
            depths = ret['disp_map'].cpu().numpy()
            depths_ = depths/np.max(depths)
            depth8 = to8b(depths_)
            imageio.imwrite(os.path.join(img_dir, 'depth_{:03d}.png'.format(j)), depth8)
    imgs = torch.stack(imgs, 0)
    return imgs


def init_weights(linear):
    # Kaiming-normal init for weights, zeros for biases
    torch.nn.init.kaiming_normal_(linear.weight)
    torch.nn.init.zeros_(linear.bias)


def init_nerf(nerf):
    for linear_pt in nerf.pts_linears:
        init_weights(linear_pt)

    for linear_view in nerf.views_linears:
        init_weights(linear_view)

    init_weights(nerf.feature_linear)

    init_weights(nerf.alpha_linear)

    init_weights(nerf.rgb_linear)

# Ray helper that returns only the rays for specific pixel indices (i, j)
def get_specific_rays(i, j, K, c2w):
    # i, j = torch.meshgrid(torch.linspace(0, W - 1, W),
    #                       torch.linspace(0, H - 1, H))  # pytorch's meshgrid has indexing='ij'
    # i = i.t()
    # j = j.t()
    dirs = torch.stack([(i - K[0][2]) / K[0][0], -(j - K[1][2]) / K[1][1], -torch.ones_like(i)], -1)
    # Rotate ray directions from camera frame to the world frame
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[..., :3, :3], -1)
    # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = c2w[..., :3, -1]
    return rays_o, rays_d


def save_render_pose(poses, path):
    poses_np = poses.cpu().detach().numpy()
    N = poses_np.shape[0]
    bottom = np.reshape([0., 0., 0., 1.], [1, 4])
    bottom_all = np.expand_dims(bottom, 0).repeat(N, axis=0)
    poses_Rt = np.concatenate([poses_np, bottom_all], 1)

    poses_txt = os.path.join(path, 'poses_render.txt')

    # write each homogeneous 4x4 pose as one line: "pose{j} m00 m01 ... m33"
    with open(poses_txt, 'a') as outfile:
        for j in range(N):
            poses_flat = poses_Rt[j].reshape(16)
            outfile.write(f"pose{j} " + " ".join(str(v) for v in poses_flat) + "\n")


================================================
FILE: test.py
================================================
import torch

from nerf import *
import optimize_pose_linear, optimize_pose_cubic
import torchvision.transforms.functional as torchvision_F

import matplotlib.pyplot as plt

from metrics import compute_img_metric
import novel_view_test


def test():
    parser = config_parser()
    args = parser.parse_args()
    print('spline numbers: ', args.deblur_images)

    imgs_sharp_dir = os.path.join(args.datadir, 'images_test')
    imgs_sharp = load_imgs(imgs_sharp_dir)

    # Load data images and groundtruth
    K = None
    if args.dataset_type == 'llff':
        images_all, poses_start, bds_start, render_poses = load_llff_data(args.datadir, pose_state=None,
                                                                      factor=args.factor, recenter=True,
                                                                      bd_factor=.75, spherify=args.spherify)
        hwf = poses_start[0, :3, -1]

        # split train/val/test
        if args.novel_view:
            i_test = torch.arange(0, images_all.shape[0], args.llffhold)
        else:
            i_test = torch.tensor([100]).long()
        i_val = i_test
        i_train = torch.Tensor([i for i in torch.arange(int(images_all.shape[0])) if
                                (i not in i_test and i not in i_val)]).long()

        # train data
        images = images_all[i_train]
        # novel view data
        if args.novel_view:
            images_novel = images_all[i_test]
        # gt data: the sharp test images are already in imgs_sharp

        # get poses
        poses_end = poses_start
        poses_start_se3 = SE3_to_se3_N(poses_start[:, :3, :4])
        poses_end_se3 = poses_start_se3
        poses_org = poses_start.repeat(args.deblur_images, 1, 1)
        poses = poses_org[:, :, :4]

        print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = torch.min(bds_start) * .9
            far = torch.max(bds_start) * 1.

        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if K is None:
        K = torch.Tensor([
            [focal, 0, 0.5 * W],
            [0, focal, 0.5 * H],
            [0, 0, 1]
        ])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    test_metric_file = os.path.join(basedir, expname, 'test_metrics.txt')
    test_metric_file_novel = os.path.join(basedir, expname, 'test_metrics_novel.txt')
    # print_file = os.path.join(basedir, expname, 'print.txt')
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    if args.linear:
        print('Linear Spline Model Loading!')
        model = optimize_pose_linear.Model(poses_start_se3, poses_end_se3)
    else:
        print('Cubic Spline Model Loading!')
        model = optimize_pose_cubic.Model(poses_start_se3, poses_start_se3, poses_start_se3, poses_start_se3)
    graph = model.build_network(args)
    optimizer, optimizer_se3 = model.setup_optimizer(args)
    path = os.path.join(basedir, expname, '{:06d}.tar'.format(args.weight_iter))
    graph_ckpt = torch.load(path)
    graph.load_state_dict(graph_ckpt['graph'])
    optimizer.load_state_dict(graph_ckpt['optimizer'])
    optimizer_se3.load_state_dict(graph_ckpt['optimizer_se3'])
    global_step = graph_ckpt['global_step']

    if args.deblur_images % 2 == 0:
        all_poses = graph.get_pose_even(0, torch.arange(graph.se3.weight.shape[0]), args.deblur_images)
    else:
        all_poses = graph.get_pose(0, torch.arange(graph.se3.weight.shape[0]), args)
    # Turn on testing mode
    with torch.no_grad():
        if args.deblur_images % 2 == 0:
            i_render = torch.arange(i_train.shape[0]) * (args.deblur_images + 1) + args.deblur_images // 2
        else:
            i_render = torch.arange(i_train.shape[0]) * args.deblur_images + args.deblur_images // 2
        imgs_render = render_image_test(0, graph, all_poses[i_render], H, W, K, args)
    mse_render = compute_img_metric(imgs_sharp, imgs_render, 'mse')
    psnr_render = compute_img_metric(imgs_sharp, imgs_render, 'psnr')
    ssim_render = compute_img_metric(imgs_sharp, imgs_render, 'ssim')
    lpips_render = compute_img_metric(imgs_sharp, imgs_render, 'lpips')
    with open(test_metric_file, 'a') as outfile:
        outfile.write(f"test: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}"

              f" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\n")

    # Turn on novel view testing mode
    if args.novel_view:
        i_ = torch.arange(0, images.shape[0], args.llffhold - 1)
        poses_test_se3_ = graph.se3.weight[i_, :6]
        model_test = novel_view_test.Model(poses_test_se3_, graph)
        graph_test = model_test.build_network(args)
        optimizer_test = model_test.setup_optimizer(args)
        for j in range(args.N_novel_view):
            ret_sharp, ray_idx_sharp, poses_sharp = graph_test.forward(0, 0, 0, H, W, K, args,
                                                                       novel_view=True)
            target_s_novel = images_novel.reshape(-1, H * W, 3)[:, ray_idx_sharp]
            target_s_novel = target_s_novel.reshape(-1, 3)
            loss_sharp = img2mse(ret_sharp['rgb_map'], target_s_novel)
            psnr_sharp = mse2psnr(loss_sharp)
            if 'rgb0' in ret_sharp:
                img_loss0 = img2mse(ret_sharp['rgb0'], target_s_novel)
                loss_sharp = loss_sharp + img_loss0
            if j % 100 == 0:
                print(psnr_sharp.item(), loss_sharp.item())
            optimizer_test.zero_grad()
            loss_sharp.backward()
            optimizer_test.step()
            decay_rate_sharp = 0.01
            decay_steps_sharp = args.lrate_decay * 100
            new_lrate_novel = args.pose_lrate * (decay_rate_sharp ** (j / decay_steps_sharp))
            for param_group in optimizer_test.param_groups:
                if (j / decay_steps_sharp) <= 1.:
                    param_group['lr'] = new_lrate_novel * args.factor_pose_novel
        with torch.no_grad():
            imgs_render_novel = render_image_test(0, graph, poses_sharp, H, W, K, args, novel_view=True)

            mse_render = compute_img_metric(images_novel, imgs_render_novel, 'mse')
            psnr_render = compute_img_metric(images_novel, imgs_render_novel, 'psnr')
            ssim_render = compute_img_metric(images_novel, imgs_render_novel, 'ssim')
            lpips_render = compute_img_metric(images_novel, imgs_render_novel, 'lpips')
            with open(test_metric_file_novel, 'a') as outfile:
                outfile.write(f"novel view test: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}"
                              f" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\n")

    return 0


if __name__=='__main__':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    test()


================================================
FILE: train.py
================================================
import torch

from nerf import *
import optimize_pose_linear, optimize_pose_cubic
import torchvision.transforms.functional as torchvision_F

import matplotlib.pyplot as plt

from metrics import compute_img_metric
import novel_view_test


def train():
    parser = config_parser()
    args = parser.parse_args()
    print('spline numbers: ', args.deblur_images)

    imgs_sharp_dir = os.path.join(args.datadir, 'images_test')
    imgs_sharp = load_imgs(imgs_sharp_dir)

    # Load data images and groundtruth
    K = None
    if args.dataset_type == 'llff':
        images_all, poses_start, bds_start, render_poses = load_llff_data(args.datadir, pose_state=None,
                                                                      factor=args.factor, recenter=True,
                                                                      bd_factor=.75, spherify=args.spherify)
        hwf = poses_start[0, :3, -1]

        # split train/val/test
        if args.novel_view:
            i_test = torch.arange(0, images_all.shape[0], args.llffhold)
        else:
            i_test = torch.tensor([100]).long()
        i_val = i_test
        i_train = torch.Tensor([i for i in torch.arange(int(images_all.shape[0])) if
                                (i not in i_test and i not in i_val)]).long()

        # train data
        images = images_all[i_train]
        # novel view data
        if args.novel_view:
            images_novel = images_all[i_test]
        # gt data: the sharp test images are already in imgs_sharp

        # get poses
        poses_end = poses_start
        poses_start_se3 = SE3_to_se3_N(poses_start[:, :3, :4])
        poses_end_se3 = poses_start_se3
        poses_org = poses_start.repeat(args.deblur_images, 1, 1)
        poses = poses_org[:, :, :4]

        print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)

        print('DEFINING BOUNDS')
        if args.no_ndc:
            near = torch.min(bds_start) * .9
            far = torch.max(bds_start) * 1.

        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if K is None:
        K = torch.Tensor([
            [focal, 0, 0.5 * W],
            [0, focal, 0.5 * H],
            [0, 0, 1]
        ])

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    test_metric_file = os.path.join(basedir, expname, 'test_metrics.txt')
    test_metric_file_novel = os.path.join(basedir, expname, 'test_metrics_novel.txt')
    print_file = os.path.join(basedir, expname, 'print.txt')
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    if args.load_weights:
        if args.linear:
            print('Linear Spline Model Loading!')
            model = optimize_pose_linear.Model(poses_start_se3, poses_end_se3)
        else:
            print('Cubic Spline Model Loading!')
            model = optimize_pose_cubic.Model(poses_start_se3, poses_start_se3, poses_start_se3, poses_start_se3)
        graph = model.build_network(args)
        optimizer, optimizer_se3 = model.setup_optimizer(args)
        path = os.path.join(basedir, expname, '{:06d}.tar'.format(args.weight_iter))  # here
        graph_ckpt = torch.load(path)
        graph.load_state_dict(graph_ckpt['graph'])
        optimizer.load_state_dict(graph_ckpt['optimizer'])
        optimizer_se3.load_state_dict(graph_ckpt['optimizer_se3'])
        global_step = graph_ckpt['global_step']

    else:
        if args.linear:
            low, high = 0.0001, 0.005
            rand = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low
            poses_start_se3 = poses_start_se3 + rand

            model = optimize_pose_linear.Model(poses_start_se3, poses_end_se3)
        else:
            low, high = 0.0001, 0.01
            rand1 = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low
            rand2 = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low
            rand3 = (high - low) * torch.rand(poses_start_se3.shape[0], 6) + low
            poses_se3_1 = poses_start_se3 + rand1
            poses_se3_2 = poses_start_se3 + rand2
            poses_se3_3 = poses_start_se3 + rand3

            model = optimize_pose_cubic.Model(poses_start_se3, poses_se3_1, poses_se3_2, poses_se3_3)

        graph = model.build_network(args)  # nerf, nerf_fine, forward
        optimizer, optimizer_se3 = model.setup_optimizer(args)

    N_iters = args.N_iters + 1
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    start = 0
    if not args.load_weights:
        global_step = start
    global_step_ = global_step
    threshold = N_iters + 1

    poses_num = poses.shape[0]

    for i in trange(start, threshold):
        ### core optimization loop ###
        i = i + global_step_
        if i == 0:
            init_nerf(graph.nerf)
            init_nerf(graph.nerf_fine)

        img_idx = torch.randperm(images.shape[0])

        if (i % args.i_img == 0 or i % args.i_novel_view == 0) and i > 0:
            ret, ray_idx, spline_poses, all_poses = graph.forward(i, img_idx, poses_num, H, W, K, args)
        else:
            ret, ray_idx, spline_poses = graph.forward(i, img_idx, poses_num, H, W, K, args)

        # get image ground truth
        target_s = images[img_idx].reshape(-1, H * W, 3)
        target_s = target_s[:, ray_idx]
        target_s = target_s.reshape(-1, 3)

        # synthesize motion blur: average the deblur_images virtual sharp renderings of each image
        shape0 = img_idx.shape[0]
        interval = target_s.shape[0] // shape0
        rgb_list = []
        extras_list = []
        rgb_ = 0
        extras_ = 0

        for j in range(0, shape0 * args.deblur_images):
            rgb_ += ret['rgb_map'][j * interval:(j + 1) * interval]
            if 'rgb0' in ret:
                extras_ += ret['rgb0'][j * interval:(j + 1) * interval]
            if (j + 1) % args.deblur_images == 0:
                rgb_ = rgb_ / args.deblur_images
                rgb_list.append(rgb_)
                rgb_ = 0
                if 'rgb0' in ret:
                    extras_ = extras_ / args.deblur_images
                    extras_list.append(extras_)
                    extras_ = 0

        rgb_blur = torch.stack(rgb_list, 0)
        rgb_blur = rgb_blur.reshape(-1, 3)

        if 'rgb0' in ret:
            extras_blur = torch.stack(extras_list, 0)
            extras_blur = extras_blur.reshape(-1, 3)

        # backward
        optimizer_se3.zero_grad()
        optimizer.zero_grad()
        img_loss = img2mse(rgb_blur, target_s)
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if 'rgb0' in ret:
            img_loss0 = img2mse(extras_blur, target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)

        loss.backward()

        optimizer.step()
        optimizer_se3.step()

        # NOTE: IMPORTANT!
        ###   update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate

        decay_rate_pose = 0.01
        new_lrate_pose = args.pose_lrate * (decay_rate_pose ** (global_step / decay_steps))
        for param_group in optimizer_se3.param_groups:
            param_group['lr'] = new_lrate_pose
        ###############################
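        # example schedule, assuming lrate = 5e-4, pose_lrate = 1e-3 and
        # lrate_decay = 250 (decay_steps = 250000): the NeRF lr shrinks by
        # x0.1 per 250k steps, the pose lr by x0.01 over the same span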

        if i % args.i_print == 0:
            tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} coarse_loss: {img_loss0.item()} PSNR: {psnr.item()}")
            with open(print_file, 'a') as outfile:
                outfile.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} coarse_loss: {img_loss0.item()} PSNR: {psnr.item()}\n")

        if i < 10:
            print('coarse_loss:', img_loss0.item())
            with open(print_file, 'a') as outfile:
                outfile.write(f"coarse loss: {img_loss0.item()}\n")

        if i % args.i_weights == 0 and i > 0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            torch.save({
                'global_step': global_step,
                'graph': graph.state_dict(),
                'optimizer': optimizer.state_dict(),
                'optimizer_se3': optimizer_se3.state_dict(),
            }, path)
            print('Saved checkpoints at', path)

        if i % args.i_img == 0 and i > 0:
            # Turn on testing mode
            with torch.no_grad():
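                # evaluate at the middle virtual pose of each training image
                # (the even case strides by deblur_images + 1, one extra
                # mid-exposure pose per image)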
                if args.deblur_images % 2 == 0:
                    i_render = torch.arange(i_train.shape[0]) * (args.deblur_images+1) + args.deblur_images // 2
                else:
                    i_render = torch.arange(i_train.shape[0]) * args.deblur_images + args.deblur_images // 2
                imgs_render = render_image_test(i, graph, all_poses[i_render], H, W, K, args)
            mse_render = compute_img_metric(imgs_sharp, imgs_render, 'mse')
            psnr_render = compute_img_metric(imgs_sharp, imgs_render, 'psnr')
            ssim_render = compute_img_metric(imgs_sharp, imgs_render, 'ssim')
            lpips_render = compute_img_metric(imgs_sharp, imgs_render, 'lpips')
            with open(test_metric_file, 'a') as outfile:
                outfile.write(f"iter{i}: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}"
                              f" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\n")

        if i % args.i_video == 0 and i > 0:
            # Turn on testing mode
            with torch.no_grad():
                rgbs, disps = render_video_test(i, graph, render_poses, H, W, K, args)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)

        if args.novel_view and i % args.i_novel_view == 0 and i > 0:
            # Turn on novel view testing mode
            i_ = torch.arange(0, images.shape[0], args.llffhold-1)
            poses_test_se3_ = graph.se3.weight[i_,:6]
            model_test = novel_view_test.Model(poses_test_se3_, graph)
            graph_test = model_test.build_network(args)
            optimizer_test = model_test.setup_optimizer(args)
            for j in range(args.N_novel_view):
                ret_sharp, ray_idx_sharp, poses_sharp = graph_test.forward(i, img_idx, poses_num, H, W, K, args, novel_view=True)
                target_s_novel = images_novel.reshape(-1, H*W, 3)[:, ray_idx_sharp]
                target_s_novel = target_s_novel.reshape(-1, 3)
                loss_sharp = img2mse(ret_sharp['rgb_map'], target_s_novel)
                psnr_sharp = mse2psnr(loss_sharp)
                if 'rgb0' in ret_sharp:
                    img_loss0 = img2mse(ret_sharp['rgb0'], target_s_novel)
                    loss_sharp = loss_sharp + img_loss0
                if j % 100 == 0:
                    print('novel view PSNR:', psnr_sharp.item(), 'loss:', loss_sharp.item())
                optimizer_test.zero_grad()
                loss_sharp.backward()
                optimizer_test.step()
                decay_rate_sharp = 0.01
                decay_steps_sharp = args.lrate_decay * 100
                new_lrate_novel = args.pose_lrate * (decay_rate_sharp ** (j / decay_steps_sharp))
                for param_group in optimizer_test.param_groups:
                    if (j / decay_steps_sharp) <= 1.:
                        param_group['lr'] = new_lrate_novel * args.factor_pose_novel
            with torch.no_grad():
                imgs_render_novel = render_image_test(i, graph, poses_sharp, H, W, K, args, novel_view=True)

                mse_render = compute_img_metric(images_novel, imgs_render_novel, 'mse')
                psnr_render = compute_img_metric(images_novel, imgs_render_novel, 'psnr')
                ssim_render = compute_img_metric(images_novel, imgs_render_novel, 'ssim')
                lpips_render = compute_img_metric(images_novel, imgs_render_novel, 'lpips')
                with open(test_metric_file_novel, 'a') as outfile:
                    outfile.write(f"iter{i}: MSE:{mse_render.item():.8f} PSNR:{psnr_render.item():.8f}"
                                  f" SSIM:{ssim_render.item():.8f} LPIPS:{lpips_render.item():.8f}\n")

        if i % args.N_iters == 0 and i > 0:
            # Turn on testing mode
            with torch.no_grad():
                path_pose = os.path.join(basedir, expname)
                i_render_pose = torch.arange(i_train.shape[0]) * args.deblur_images + args.deblur_images // 2
                render_poses_final = all_poses[i_render_pose]
                save_render_pose(render_poses_final, path_pose)

        global_step += 1


if __name__ == '__main__':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    train()
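
Note on the averaging step in train() above: the Python loop that builds rgb_blur sums deblur_images consecutive chunks of interval rendered rays each, then divides by the chunk count. For reference, the same computation can be written without the loop. The sketch below is illustrative only, not code from the repository; it assumes rgb holds ret['rgb_map'] laid out image-major, then virtual-pose-major, as produced in train():

import torch

def synthesize_blur(rgb, n_img, deblur_images, interval):
    # rgb: [n_img * deblur_images * interval, 3] rendered virtual-frame rays
    rgb = rgb.reshape(n_img, deblur_images, interval, 3)
    # averaging over the virtual poses reproduces rgb_blur from the loop
    return rgb.mean(dim=1).reshape(-1, 3)

Averaging the virtual sharp renderings models the physical blur formation, so the photometric loss can be taken directly against the blurry input pixels.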
SYMBOL INDEX (139 symbols across 13 files)

FILE: Spline.py
  function skew_symmetric (line 7) | def skew_symmetric(w):
  function taylor_A (line 21) | def taylor_A(x, nth=10):
  function taylor_B (line 32) | def taylor_B(x, nth=10):
  function taylor_C (line 42) | def taylor_C(x, nth=10):
  function exp_r2q_parallel (line 52) | def exp_r2q_parallel(r, eps=1e-9):
  function exp_r2q (line 61) | def exp_r2q(x, y, z, theta):
  function exp_r2q_taylor (line 70) | def exp_r2q_taylor(x, y, z, theta):
  function q_to_R_parallel (line 78) | def q_to_R_parallel(q):
  function q_to_Q_parallel (line 112) | def q_to_Q_parallel(q):
  function q_to_q_conj_parallel (line 122) | def q_to_q_conj_parallel(q):
  function log_q2r_parallel (line 128) | def log_q2r_parallel(q, eps_theta=1e-20, eps_w=1e-10):
  function log_q2r (line 156) | def log_q2r(w, theta):
  function log_q2r_taylor_theta_0 (line 160) | def log_q2r_taylor_theta_0(w, theta):
  function log_q2r_lim_w_0_left (line 164) | def log_q2r_lim_w_0_left(theta):
  function log_q2r_lim_w_0_right (line 168) | def log_q2r_lim_w_0_right(theta):
  function SE3_to_se3 (line 172) | def SE3_to_se3(Rt, eps=1e-8):  # [...,3,4]
  function SO3_to_so3 (line 186) | def SO3_to_so3(R, eps=1e-7):  # [...,3,3]
  function se3_to_SE3 (line 199) | def se3_to_SE3(wu):  # [...,3]
  function SE3_to_se3_N (line 213) | def SE3_to_se3_N(poses_rt):
  function se3_to_SE3_N (line 222) | def se3_to_SE3_N(poses_wu):
  function se3_2_qt_parallel (line 231) | def se3_2_qt_parallel(wu):
  function SplineN_linear (line 246) | def SplineN_linear(start_pose, end_pose, poses_number, NUM, device=None):
  function SplineN_cubic (line 276) | def SplineN_cubic(pose0, pose1, pose2, pose3, poses_number, NUM):

FILE: load_llff.py
  function _minify (line 10) | def _minify(basedir, factors=[], resolutions=[]):  # basedir: ./data/ner...
  function _load_data (line 62) | def _load_data(basedir, pose_state, factor=None, width=None, height=None...
  function normalize (line 125) | def normalize(x):
  function viewmatrix (line 129) | def viewmatrix(z, up, pos):
  function ptstocam (line 138) | def ptstocam(pts, c2w):
  function poses_avg (line 143) | def poses_avg(poses):
  function render_path_spiral (line 154) | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
  function recenter_poses (line 166) | def recenter_poses(poses):
  function spherify_poses (line 183) | def spherify_poses(poses, bds):
  function load_llff_data (line 241) | def load_llff_data(basedir, pose_state=None, factor=1, recenter=True, bd...

FILE: lpips/lpips.py
  function normalize_tensor (line 12) | def normalize_tensor(in_feat, eps=1e-10):
  function l2 (line 17) | def l2(p0, p1, range=255.):
  function psnr (line 21) | def psnr(p0, p1, peak=255.):
  function dssim (line 25) | def dssim(p0, p1, range=255.):
  function rgb2lab (line 30) | def rgb2lab(in_img, mean_cent=False):
  function tensor2np (line 38) | def tensor2np(tensor_obj):
  function np2tensor (line 43) | def np2tensor(np_obj):
  function tensor2tensorlab (line 48) | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
  function tensorlab2tensor (line 63) | def tensorlab2tensor(lab_tensor, return_inbnd=False):
  function tensor2im (line 82) | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
  function im2tensor (line 88) | def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
  function tensor2vec (line 93) | def tensor2vec(vector_tensor):
  function voc_ap (line 97) | def voc_ap(rec, prec, use_07_metric=False):
  function spatial_average (line 130) | def spatial_average(in_tens, keepdim=True):
  function upsample (line 134) | def upsample(in_tens, out_HW=(64, 64)):  # assumes scale factor is same ...
  class LPIPS (line 140) | class LPIPS(nn.Module):
    method __init__ (line 141) | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=T...
    method forward (line 199) | def forward(self, in0, in1, retPerLayer=False, normalize=False):
  class ScalingLayer (line 245) | class ScalingLayer(nn.Module):
    method __init__ (line 246) | def __init__(self):
    method forward (line 251) | def forward(self, inp):
  class NetLinLayer (line 255) | class NetLinLayer(nn.Module):
    method __init__ (line 258) | def __init__(self, chn_in, chn_out=1, use_dropout=False):
    method forward (line 265) | def forward(self, x):
  class Dist2LogitLayer (line 269) | class Dist2LogitLayer(nn.Module):
    method __init__ (line 272) | def __init__(self, chn_mid=32, use_sigmoid=True):
    method forward (line 284) | def forward(self, d0, d1, eps=0.1):
  class BCERankingLoss (line 288) | class BCERankingLoss(nn.Module):
    method __init__ (line 289) | def __init__(self, chn_mid=32):
    method forward (line 295) | def forward(self, d0, d1, judge):
  class FakeNet (line 302) | class FakeNet(nn.Module):
    method __init__ (line 303) | def __init__(self, use_gpu=True, colorspace='Lab'):
  class L2 (line 309) | class L2(FakeNet):
    method forward (line 310) | def forward(self, in0, in1, retPerLayer=None):
  class DSSIM (line 328) | class DSSIM(FakeNet):
    method forward (line 330) | def forward(self, in0, in1, retPerLayer=None):
  function print_network (line 346) | def print_network(net):

FILE: lpips/pretrained_networks.py
  class squeezenet (line 6) | class squeezenet(torch.nn.Module):
    method __init__ (line 7) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 36) | def forward(self, X):
  class alexnet (line 57) | class alexnet(torch.nn.Module):
    method __init__ (line 58) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 81) | def forward(self, X):
  class vgg16 (line 98) | class vgg16(torch.nn.Module):
    method __init__ (line 99) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 122) | def forward(self, X):
  class resnet (line 139) | class resnet(torch.nn.Module):
    method __init__ (line 140) | def __init__(self, requires_grad=False, pretrained=True, num=18):
    method forward (line 163) | def forward(self, X):

FILE: metrics.py
  function compute_img_metric (line 16) | def compute_img_metric(im1t: torch.Tensor, im2t: torch.Tensor,

FILE: nerf.py
  class Embedder (line 15) | class Embedder:
    method __init__ (line 16) | def __init__(self, **kwargs):
    method create_embedding_fn (line 20) | def create_embedding_fn(self):
    method embed (line 44) | def embed(self, inputs):
  function get_embedder (line 48) | def get_embedder(args, multires, i=0):
  class Model (line 66) | class Model():
    method __init__ (line 67) | def __init__(self):
    method build_network (line 70) | def build_network(self, args):
    method setup_optimizer (line 75) | def setup_optimizer(self, args):
  class NeRF (line 84) | class NeRF(nn.Module):
    method __init__ (line 85) | def __init__(self, D=8, W=256, input_ch=63, input_ch_views=27, output_...
    method forward (line 108) | def forward(self, barf_i, pts, viewdirs, args):
    method raw2output (line 148) | def raw2output(self, raw, z_vals, rays_d, raw_noise_std=0.0):
  class Graph (line 176) | class Graph(nn.Module):
    method __init__ (line 178) | def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, o...
    method forward (line 184) | def forward(self, i, img_idx, poses_num, H, W, K, args, novel_view=Fal...
    method get_pose (line 217) | def get_pose(self, i, img_idx, args):
    method get_gt_pose (line 221) | def get_gt_pose(self, poses, args):
    method render (line 225) | def render(self, barf_i, poses, ray_idx, H, W, K, args, near=0., far=1...
    method render_video (line 318) | def render_video(self, barf_i, poses, H, W, K, args):

FILE: novel_view_test.py
  class Model (line 5) | class Model(nerf.Model):
    method __init__ (line 6) | def __init__(self, se3_start, graph):
    method build_network (line 11) | def build_network(self, args):
    method setup_optimizer (line 17) | def setup_optimizer(self, args):

FILE: optimize_pose_cubic.py
  class Model (line 7) | class Model(nerf.Model):
    method __init__ (line 8) | def __init__(self, se3_0, se3_1, se3_2, se3_3):
    method build_network (line 15) | def build_network(self, args):
    method setup_optimizer (line 24) | def setup_optimizer(self, args):
  class Graph (line 36) | class Graph(nerf.Graph):
    method __init__ (line 37) | def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, o...
    method get_pose (line 43) | def get_pose(self, i, img_idx, args):
    method get_pose_even (line 60) | def get_pose_even(self, i, img_idx, num):
    method get_gt_pose (line 78) | def get_gt_pose(self, poses, args):

FILE: optimize_pose_linear.py
  class Model (line 7) | class Model(nerf.Model):
    method __init__ (line 8) | def __init__(self, se3_start, se3_end):
    method build_network (line 13) | def build_network(self, args):
    method setup_optimizer (line 22) | def setup_optimizer(self, args):
  class Graph (line 34) | class Graph(nerf.Graph):
    method __init__ (line 35) | def __init__(self, args, D=8, W=256, input_ch=63, input_ch_views=27, o...
    method get_pose (line 41) | def get_pose(self, i, img_idx, args):
    method get_pose_even (line 53) | def get_pose_even(self, i, img_idx, num):
    method get_gt_pose (line 66) | def get_gt_pose(self, poses, args):

FILE: run_nerf.py
  function config_parser (line 24) | def config_parser():

FILE: run_nerf_helpers.py
  function imread (line 19) | def imread(f):
  function load_imgs (line 25) | def load_imgs(path):
  function get_rays (line 38) | def get_rays(H, W, K, c2w):
  function ndc_rays (line 50) | def ndc_rays(H, W, focal, near, rays_o, rays_d):
  function sample_pdf (line 71) | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
  function render_video_test (line 117) | def render_video_test(i_, graph, render_poses, H, W, K, args):
  function render_image_test (line 139) | def render_image_test(i, graph, render_poses, H, W, K, args, novel_view=...
  function init_weights (line 165) | def init_weights(linear):
  function init_nerf (line 171) | def init_nerf(nerf):
  function get_specific_rays (line 185) | def get_specific_rays(i, j, K, c2w):
  function save_render_pose (line 199) | def save_render_pose(poses, path):

FILE: test.py
  function test (line 13) | def test():

FILE: train.py
  function train (line 13) | def train():
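
Note on the spline functions indexed above: SplineN_linear and SplineN_cubic generate the virtual camera poses within each exposure. As a conceptual illustration only (the repository works with batched quaternion arithmetic; the helper names below are hypothetical), linear pose interpolation in SE(3) amounts to blending the translation linearly and the rotation along its geodesic:

import torch

def skew(w):
    # 3-vector -> skew-symmetric matrix
    wx, wy, wz = w
    zero = torch.zeros((), dtype=w.dtype)
    return torch.stack([
        torch.stack([zero, -wz, wy]),
        torch.stack([wz, zero, -wx]),
        torch.stack([-wy, wx, zero]),
    ])

def so3_exp(w):
    # Rodrigues' formula: axis-angle 3-vector -> rotation matrix
    theta = w.norm().clamp(min=1e-8)
    K = skew(w / theta)
    I = torch.eye(3, dtype=w.dtype)
    return I + torch.sin(theta) * K + (1 - torch.cos(theta)) * (K @ K)

def so3_log(R):
    # rotation matrix -> axis-angle 3-vector
    cos = ((R.trace() - 1) / 2).clamp(-1 + 1e-7, 1 - 1e-7)
    theta = torch.acos(cos)
    w = torch.stack([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    return w * theta / (2 * torch.sin(theta))

def lerp_pose(R0, t0, R1, t1, u):
    # u in [0, 1]: 0 gives the start pose, 1 the end pose
    R = R0 @ so3_exp(u * so3_log(R0.transpose(0, 1) @ R1))
    t = (1 - u) * t0 + u * t1
    return R, t

Evaluating lerp_pose at deblur_images values of u spread over [0, 1] yields the virtual trajectory whose renderings are averaged into the blurry image in train().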

About this extraction

This page contains the full source code of the WU-CVGL/BAD-NeRF GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 35 files (113.4 KB), approximately 34.4k tokens, and a symbol index with 139 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
