Full Code of gaochen315/DynamicNeRF for AI

main c417fb207ef3 cached
39 files
245.4 KB
69.2k tokens
265 symbols
1 requests
Download .txt
Showing preview only (258K chars total). Download the full file or copy to clipboard to get everything.
Repository: gaochen315/DynamicNeRF
Branch: main
Commit: c417fb207ef3
Files: 39
Total size: 245.4 KB

Directory structure:
gitextract_avcck_b3/

├── LICENSE
├── README.md
├── configs/
│   ├── config.txt
│   ├── config_Balloon1.txt
│   ├── config_Balloon2.txt
│   ├── config_Jumping.txt
│   ├── config_Playground.txt
│   ├── config_Skating.txt
│   ├── config_Truck.txt
│   └── config_Umbrella.txt
├── load_llff.py
├── render_utils.py
├── run_nerf.py
├── run_nerf_helpers.py
└── utils/
    ├── RAFT/
    │   ├── __init__.py
    │   ├── corr.py
    │   ├── datasets.py
    │   ├── demo.py
    │   ├── extractor.py
    │   ├── raft.py
    │   ├── update.py
    │   └── utils/
    │       ├── __init__.py
    │       ├── augmentor.py
    │       ├── flow_viz.py
    │       ├── frame_utils.py
    │       └── utils.py
    ├── colmap_utils.py
    ├── evaluation.py
    ├── flow_utils.py
    ├── generate_data.py
    ├── generate_depth.py
    ├── generate_flow.py
    ├── generate_motion_mask.py
    ├── generate_pose.py
    └── midas/
        ├── base_model.py
        ├── blocks.py
        ├── midas_net.py
        ├── transforms.py
        └── vit.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================

MIT License

Copyright (c) 2020 Virginia Tech Vision and Learning Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------- LICENSE FOR EdgeConnect --------------------------------

Attribution-NonCommercial 4.0 International


================================================
FILE: README.md
================================================
# Dynamic View Synthesis from Dynamic Monocular Video

[![arXiv](https://img.shields.io/badge/arXiv-2108.00946-b31b1b.svg)](https://arxiv.org/abs/2105.06468)

[Project Website](https://free-view-video.github.io/) | [Video](https://youtu.be/j8CUzIR0f8M) | [Paper](https://arxiv.org/abs/2105.06468)

> **Dynamic View Synthesis from Dynamic Monocular Video**<br>
> [Chen Gao](http://chengao.vision), [Ayush Saraf](#), [Johannes Kopf](https://johanneskopf.de/), [Jia-Bin Huang](https://filebox.ece.vt.edu/~jbhuang/) <br>
in ICCV 2021 <br>

## Setup
The code is test with
* Linux (tested on CentOS Linux release 7.4.1708)
* Anaconda 3
* Python 3.7.11
* CUDA 10.1
* 1 V100 GPU


To get started, please create the conda environment `dnerf` by running
```
conda create --name dnerf python=3.7
conda activate dnerf
conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
pip install imageio scikit-image configargparse timm lpips
```
and install [COLMAP](https://colmap.github.io/install.html) manually. Then download MiDaS and RAFT weights
```
ROOT_PATH=/path/to/the/DynamicNeRF/folder
cd $ROOT_PATH
wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/weights.zip
unzip weights.zip
rm weights.zip
```

## Dynamic Scene Dataset
The [Dynamic Scene Dataset](https://www-users.cse.umn.edu/~jsyoon/dynamic_synth/) is used to
quantitatively evaluate our method. Please download the pre-processed data by running:
```
cd $ROOT_PATH
wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/data.zip
unzip data.zip
rm data.zip
```

### Training
You can train a model from scratch by running:
```
cd $ROOT_PATH/
python run_nerf.py --config configs/config_Balloon2.txt
```

Every 100k iterations, you should get videos like the following examples

The novel view-time synthesis results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/novelviewtime`.
![novelviewtime](https://filebox.ece.vt.edu/~chengao/free-view-video/gif/novelviewtime_Balloon2.gif)
<!-- <img src="https://filebox.ece.vt.edu/~chengao/free-view-video/gif/novelviewtime.gif" height="270" /> -->

The reconstruction results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset`.
![testset](https://filebox.ece.vt.edu/~chengao/free-view-video/gif/testset_Balloon2.gif)

The fix-view-change-time results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset_view000`.
![testset_view000](https://filebox.ece.vt.edu/~chengao/free-view-video/gif/testset_view000_Balloon2.gif)

The fix-time-change-view results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset_time000`.
![testset_time000](https://filebox.ece.vt.edu/~chengao/free-view-video/gif/testset_time000_Balloon2.gif)


### Rendering from pre-trained models
We also provide pre-trained models. You can download them by running:
```
cd $ROOT_PATH/
wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/logs.zip
unzip logs.zip
rm logs.zip
```

Then you can render the results directly by running:
```
python run_nerf.py --config configs/config_Balloon2.txt --render_only --ft_path $ROOT_PATH/logs/Balloon2_H270_DyNeRF_pretrain/300000.tar
```

### Evaluating our method and others
Our goal is to make the evaluation as simple as possible for you. We have collected the fix-view-change-time results of the following methods:

`NeRF` \
`NeRF + t` \
`Yoon et al.` \
`Non-Rigid NeRF` \
`NSFF` \
`DynamicNeRF (ours)`

Please download the results by running:
```
cd $ROOT_PATH/
wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/results.zip
unzip results.zip
rm results.zip
```

Then you can calculate the PSNR/SSIM/LPIPS by running:
```
cd $ROOT_PATH/utils
python evaluation.py
```

| PSNR / LPIPS |    Jumping    |    Skating    |     Truck     |    Umbrella   |    Balloon1   |    Balloon2   |   Playground  |    Average    |
|:-------------|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|
| NeRF         | 20.99 / 0.305 | 23.67 / 0.311 | 22.73 / 0.229 | 21.29 / 0.440 | 19.82 / 0.205 | 24.37 / 0.098 | 21.07 / 0.165 | 21.99 / 0.250 |
| NeRF + t     | 18.04 / 0.455 | 20.32 / 0.512 | 18.33 / 0.382 | 17.69 / 0.728 | 18.54 / 0.275 | 20.69 / 0.216 | 14.68 / 0.421 | 18.33 / 0.427 |
| NR NeRF      | 20.09 / 0.287 | 23.95 / 0.227 | 19.33 / 0.446 | 19.63 / 0.421 | 17.39 / 0.348 | 22.41 / 0.213 | 15.06 / 0.317 | 19.69 / 0.323 |
| NSFF         | 24.65 / 0.151 | 29.29 / 0.129 | 25.96 / 0.167 | 22.97 / 0.295 | 21.96 / 0.215 | 24.27 / 0.222 | 21.22 / 0.212 | 24.33 / 0.199 |
| Ours         | 24.68 / 0.090 | 32.66 / 0.035 | 28.56 / 0.082 | 23.26 / 0.137 | 22.36 / 0.104 | 27.06 / 0.049 | 24.15 / 0.080 | 26.10 / 0.082 |


Please note:
1. The numbers reported in the paper are calculated using TF code. The numbers here are calculated using this improved Pytorch version.
2. In Yoon's results, the first frame and the last frame are missing. To compare with Yoon's results, we have to omit the first frame and the last frame. To do so, please uncomment line 72 and comment line 73 in `evaluation.py`.
3. We obtain the results of NSFF and NR NeRF using the official implementation with default parameters.


## Train a model on your sequence
0. Set some paths

```
ROOT_PATH=/path/to/the/DynamicNeRF/folder
DATASET_NAME=name_of_the_video_without_extension
DATASET_PATH=$ROOT_PATH/data/$DATASET_NAME
```

1. Prepare training images and background masks from a video.

```
cd $ROOT_PATH/utils
python generate_data.py --videopath /path/to/the/video
```

2. Use COLMAP to obtain camera poses.

```
colmap feature_extractor \
--database_path $DATASET_PATH/database.db \
--image_path $DATASET_PATH/images_colmap \
--ImageReader.mask_path $DATASET_PATH/background_mask \
--ImageReader.single_camera 1

colmap exhaustive_matcher \
--database_path $DATASET_PATH/database.db

mkdir $DATASET_PATH/sparse
colmap mapper \
    --database_path $DATASET_PATH/database.db \
    --image_path $DATASET_PATH/images_colmap \
    --output_path $DATASET_PATH/sparse \
    --Mapper.num_threads 16 \
    --Mapper.init_min_tri_angle 4 \
    --Mapper.multiple_models 0 \
    --Mapper.extract_colors 0
```

3. Save camera poses into the format that NeRF reads.

```
cd $ROOT_PATH/utils
python generate_pose.py --dataset_path $DATASET_PATH
```

4. Estimate monocular depth.

```
cd $ROOT_PATH/utils
python generate_depth.py --dataset_path $DATASET_PATH --model $ROOT_PATH/weights/midas_v21-f6b98070.pt
```

5. Predict optical flows.

```
cd $ROOT_PATH/utils
python generate_flow.py --dataset_path $DATASET_PATH --model $ROOT_PATH/weights/raft-things.pth
```

6. Obtain motion mask (code adapted from NSFF).

```
cd $ROOT_PATH/utils
python generate_motion_mask.py --dataset_path $DATASET_PATH
```

7. Train a model. Please change `expname` and `datadir` in `configs/config.txt`.

```
cd $ROOT_PATH/
python run_nerf.py --config configs/config.txt
```

Explanation of each parameter:

- `expname`: experiment name
- `basedir`: where to store ckpts and logs
- `datadir`: input data directory
- `factor`: downsample factor for the input images
- `N_rand`: number of random rays per gradient step
- `N_samples`: number of samples per ray
- `netwidth`: channels per layer
- `use_viewdirs`: whether enable view-dependency for StaticNeRF
- `use_viewdirsDyn`: whether enable view-dependency for DynamicNeRF
- `raw_noise_std`: std dev of noise added to regularize sigma_a output
- `no_ndc`: do not use normalized device coordinates
- `lindisp`: sampling linearly in disparity rather than depth
- `i_video`: frequency of novel view-time synthesis video saving
- `i_testset`: frequency of testset video saving
- `N_iters`: number of training iterations
- `i_img`: frequency of tensorboard image logging
- `DyNeRF_blending`: whether use DynamicNeRF to predict blending weight
- `pretrain`: whether pre-train StaticNeRF

## License
This work is licensed under MIT License. See [LICENSE](LICENSE) for details.

If you find this code useful for your research, please consider citing the following paper:

	@inproceedings{Gao-ICCV-DynNeRF,
	    author    = {Gao, Chen and Saraf, Ayush and Kopf, Johannes and Huang, Jia-Bin},
	    title     = {Dynamic View Synthesis from Dynamic Monocular Video},
	    booktitle = {Proceedings of the IEEE International Conference on Computer Vision},
	    year      = {2021}
	}

## Acknowledgments
Our training code is build upon
[NeRF](https://github.com/bmild/nerf),
[NeRF-pytorch](https://github.com/yenchenlin/nerf-pytorch), and
[NSFF](https://github.com/zl548/Neural-Scene-Flow-Fields).
Our flow prediction code is modified from [RAFT](https://github.com/princeton-vl/RAFT).
Our depth prediction code is modified from [MiDaS](https://github.com/isl-org/MiDaS).


================================================
FILE: configs/config.txt
================================================
expname = xxxxxx_DyNeRF_pretrain_test
basedir = ./logs
datadir = ./data/xxxxxx/

dataset_type = llff

factor = 4
N_rand = 1024
N_samples = 64
netwidth = 256

i_video = 100000
i_testset = 100000
N_iters = 500001
i_img = 500

use_viewdirs = True
use_viewdirsDyn = True
raw_noise_std = 1e0
no_ndc = False
lindisp = False

dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.01
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True


================================================
FILE: configs/config_Balloon1.txt
================================================
expname = Balloon1_H270_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/Balloon1/

dataset_type = llff

factor = 2
N_rand = 1024
N_samples = 64
N_importance = 0
netwidth = 256

i_video = 100000
i_testset = 100000
N_iters = 300001
i_img = 500

use_viewdirs = True
use_viewdirsDyn = False
raw_noise_std = 1e0
no_ndc = False
lindisp = False

dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.1
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True


================================================
FILE: configs/config_Balloon2.txt
================================================
expname = Balloon2_H270_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/Balloon2/

dataset_type = llff

factor = 2
N_rand = 1024
N_samples = 64
N_importance = 0
netwidth = 256

i_video = 100000
i_testset = 100000
N_iters = 300001
i_img = 500

use_viewdirs = True
use_viewdirsDyn = True
raw_noise_std = 1e0
no_ndc = False
lindisp = False

dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.1
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True


================================================
FILE: configs/config_Jumping.txt
================================================
expname = Jumping_H270_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/Jumping/

dataset_type = llff

factor = 2
N_rand = 1024
N_samples = 64
N_importance = 0
netwidth = 256

i_video = 100000
i_testset = 100000
N_iters = 300001
i_img = 500

use_viewdirs = True
use_viewdirsDyn = False
raw_noise_std = 1e0
no_ndc = False
lindisp = False

dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.1
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True


================================================
FILE: configs/config_Playground.txt
================================================
expname = Playground_H270_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/Playground/

dataset_type = llff

factor = 2
N_rand = 1024
N_samples = 64
N_importance = 0
netwidth = 256

i_video = 100000
i_testset = 100000
N_iters = 300001
i_img = 500

use_viewdirs = True
use_viewdirsDyn = True
raw_noise_std = 1e0
no_ndc = False
lindisp = False

dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.1
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True


================================================
FILE: configs/config_Skating.txt
================================================
expname = Skating_H270_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/Skating/

dataset_type = llff

factor = 2
N_rand = 1024
N_samples = 64
N_importance = 0
netwidth = 256

i_video = 100000
i_testset = 100000
N_iters = 300001
i_img = 500

use_viewdirs = True
use_viewdirsDyn = True
raw_noise_std = 1e0
no_ndc = False
lindisp = False

dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.1
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True


================================================
FILE: configs/config_Truck.txt
================================================
expname = Truck_H270_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/Truck/

dataset_type = llff

factor = 2
N_rand = 1024
N_samples = 64
N_importance = 0
netwidth = 256

i_video = 100000
i_testset = 100000
N_iters = 300001
i_img = 500

use_viewdirs = True
use_viewdirsDyn = True
raw_noise_std = 1e0
no_ndc = False
lindisp = False

dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.1
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True


================================================
FILE: configs/config_Umbrella.txt
================================================
expname = Umbrella_H270_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/Umbrella/

dataset_type = llff

factor = 2
N_rand = 1024
N_samples = 64
N_importance = 0
netwidth = 256

i_video = 100000
i_testset = 100000
N_iters = 300001
i_img = 500

use_viewdirs = True
use_viewdirsDyn = True
raw_noise_std = 1e0
no_ndc = False
lindisp = False

dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.1
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True


================================================
FILE: load_llff.py
================================================
import os
import cv2
import imageio
import numpy as np

from utils.flow_utils import resize_flow
from run_nerf_helpers import get_grid


def _minify(basedir, factors=[], resolutions=[]):
    needtoload = False
    for r in factors:
        imgdir = os.path.join(basedir, 'images_{}'.format(r))
        if not os.path.exists(imgdir):
            needtoload = True
    for r in resolutions:
        imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))
        if not os.path.exists(imgdir):
            needtoload = True
    if not needtoload:
        return

    from shutil import copy
    from subprocess import check_output

    imgdir = os.path.join(basedir, 'images')
    imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]
    imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]
    imgdir_orig = imgdir

    wd = os.getcwd()

    for r in factors + resolutions:
        if isinstance(r, int):
            name = 'images_{}'.format(r)
            resizearg = '{}%'.format(100./r)
        else:
            name = 'images_{}x{}'.format(r[1], r[0])
            resizearg = '{}x{}'.format(r[1], r[0])
        imgdir = os.path.join(basedir, name)
        if os.path.exists(imgdir):
            continue

        print('Minifying', r, basedir)

        os.makedirs(imgdir)
        check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)

        ext = imgs[0].split('.')[-1]
        args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])
        print(args)
        os.chdir(imgdir)
        check_output(args, shell=True)
        os.chdir(wd)

        if ext != 'png':
            check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)
            print('Removed duplicates')
        print('Done')


def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True):
    print('factor ', factor)
    poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))
    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0])
    bds = poses_arr[:, -2:].transpose([1,0])

    img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \
            if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]
    sh = imageio.imread(img0).shape

    sfx = ''

    if factor is not None:
        sfx = '_{}'.format(factor)
        _minify(basedir, factors=[factor])
        factor = factor
    elif height is not None:
        factor = sh[0] / float(height)
        width = int(sh[1] / factor)
        if width % 2 == 1:
            width -= 1
        _minify(basedir, resolutions=[[height, width]])
        sfx = '_{}x{}'.format(width, height)
    elif width is not None:
        factor = sh[1] / float(width)
        height = int(sh[0] / factor)
        if height % 2 == 1:
            height -= 1
        _minify(basedir, resolutions=[[height, width]])
        sfx = '_{}x{}'.format(width, height)
    else:
        factor = 1

    imgdir = os.path.join(basedir, 'images' + sfx)
    if not os.path.exists(imgdir):
        print( imgdir, 'does not exist, returning' )
        return

    imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) \
                if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]
    if poses.shape[-1] != len(imgfiles):
        print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) )
        return

    sh = imageio.imread(imgfiles[0]).shape
    num_img = len(imgfiles)
    poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
    poses[2, 4, :] = poses[2, 4, :] * 1./factor

    if not load_imgs:
        return poses, bds

    def imread(f):
        if f.endswith('png'):
            return imageio.imread(f, ignoregamma=True)
        else:
            return imageio.imread(f)

    imgs = [imread(f)[..., :3] / 255. for f in imgfiles]
    imgs = np.stack(imgs, -1)

    assert imgs.shape[0] == sh[0]
    assert imgs.shape[1] == sh[1]

    disp_dir = os.path.join(basedir, 'disp')

    dispfiles = [os.path.join(disp_dir, f) \
                for f in sorted(os.listdir(disp_dir)) if f.endswith('npy')]

    disp = [cv2.resize(np.load(f),
                    (sh[1], sh[0]),
                    interpolation=cv2.INTER_NEAREST) for f in dispfiles]
    disp = np.stack(disp, -1)

    mask_dir = os.path.join(basedir, 'motion_masks')
    maskfiles = [os.path.join(mask_dir, f) \
                for f in sorted(os.listdir(mask_dir)) if f.endswith('png')]

    masks = [cv2.resize(imread(f)/255., (sh[1], sh[0]),
                        interpolation=cv2.INTER_NEAREST) for f in maskfiles]
    masks = np.stack(masks, -1)
    masks = np.float32(masks > 1e-3)

    flow_dir = os.path.join(basedir, 'flow')
    flows_f = []
    flow_masks_f = []
    flows_b = []
    flow_masks_b = []
    for i in range(num_img):
        if i == num_img - 1:
            fwd_flow, fwd_mask = np.zeros((sh[0], sh[1], 2)), np.zeros((sh[0], sh[1]))
        else:
            fwd_flow_path = os.path.join(flow_dir, '%03d_fwd.npz'%i)
            fwd_data = np.load(fwd_flow_path)
            fwd_flow, fwd_mask = fwd_data['flow'], fwd_data['mask']
            fwd_flow = resize_flow(fwd_flow, sh[0], sh[1])
            fwd_mask = np.float32(fwd_mask)
            fwd_mask = cv2.resize(fwd_mask, (sh[1], sh[0]),
                                interpolation=cv2.INTER_NEAREST)
        flows_f.append(fwd_flow)
        flow_masks_f.append(fwd_mask)

        if i == 0:
            bwd_flow, bwd_mask = np.zeros((sh[0], sh[1], 2)), np.zeros((sh[0], sh[1]))
        else:
            bwd_flow_path = os.path.join(flow_dir, '%03d_bwd.npz'%i)
            bwd_data = np.load(bwd_flow_path)
            bwd_flow, bwd_mask = bwd_data['flow'], bwd_data['mask']
            bwd_flow = resize_flow(bwd_flow, sh[0], sh[1])
            bwd_mask = np.float32(bwd_mask)
            bwd_mask = cv2.resize(bwd_mask, (sh[1], sh[0]),
                                interpolation=cv2.INTER_NEAREST)
        flows_b.append(bwd_flow)
        flow_masks_b.append(bwd_mask)

    flows_f = np.stack(flows_f, -1)
    flow_masks_f = np.stack(flow_masks_f, -1)
    flows_b = np.stack(flows_b, -1)
    flow_masks_b = np.stack(flow_masks_b, -1)

    print(imgs.shape)
    print(disp.shape)
    print(masks.shape)
    print(flows_f.shape)
    print(flow_masks_f.shape)

    assert(imgs.shape[0] == disp.shape[0])
    assert(imgs.shape[0] == masks.shape[0])
    assert(imgs.shape[0] == flows_f.shape[0])
    assert(imgs.shape[0] == flow_masks_f.shape[0])

    assert(imgs.shape[1] == disp.shape[1])
    assert(imgs.shape[1] == masks.shape[1])

    return poses, bds, imgs, disp, masks, flows_f, flow_masks_f, flows_b, flow_masks_b


def normalize(x):
    return x / np.linalg.norm(x)

def viewmatrix(z, up, pos):
    vec2 = normalize(z)
    vec1_avg = up
    vec0 = normalize(np.cross(vec1_avg, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    m = np.stack([vec0, vec1, vec2, pos], 1)
    return m


def poses_avg(poses):

    hwf = poses[0, :3, -1:]

    center = poses[:, :3, 3].mean(0)
    vec2 = normalize(poses[:, :3, 2].sum(0))
    up = poses[:, :3, 1].sum(0)
    c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)

    return c2w



def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
    render_poses = []
    rads = np.array(list(rads) + [1.])
    hwf = c2w[:,4:5]

    for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
        c = np.dot(c2w[:3, :4],
                    np.array([np.cos(theta),
                             -np.sin(theta),
                             -np.sin(theta*zrate),
                              1.]) * rads)
        z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))
        render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
    return render_poses



def recenter_poses(poses):

    poses_ = poses+0
    bottom = np.reshape([0,0,0,1.], [1,4])
    c2w = poses_avg(poses)
    c2w = np.concatenate([c2w[:3,:4], bottom], -2)
    bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1])
    poses = np.concatenate([poses[:,:3,:4], bottom], -2)

    poses = np.linalg.inv(c2w) @ poses
    poses_[:,:3,:4] = poses[:,:3,:4]
    poses = poses_
    return poses


def spherify_poses(poses, bds):

    p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1)

    rays_d = poses[:,:3,2:3]
    rays_o = poses[:,:3,3:4]

    def min_line_dist(rays_o, rays_d):
        A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1])
        b_i = -A_i @ rays_o
        pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0))
        return pt_mindist

    pt_mindist = min_line_dist(rays_o, rays_d)

    center = pt_mindist
    up = (poses[:,:3,3] - center).mean(0)

    vec0 = normalize(up)
    vec1 = normalize(np.cross([.1,.2,.3], vec0))
    vec2 = normalize(np.cross(vec0, vec1))
    pos = center
    c2w = np.stack([vec1, vec2, vec0, pos], 1)

    poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4])

    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1)))

    sc = 1./rad
    poses_reset[:,:3,3] *= sc
    bds *= sc
    rad *= sc

    centroid = np.mean(poses_reset[:,:3,3], 0)
    zh = centroid[2]
    radcircle = np.sqrt(rad**2-zh**2)
    new_poses = []

    for th in np.linspace(0.,2.*np.pi, 120):

        camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
        up = np.array([0,0,-1.])

        vec2 = normalize(camorigin)
        vec0 = normalize(np.cross(vec2, up))
        vec1 = normalize(np.cross(vec2, vec0))
        pos = camorigin
        p = np.stack([vec0, vec1, vec2, pos], 1)

        new_poses.append(p)

    new_poses = np.stack(new_poses, 0)

    new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1)
    poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1)

    return poses_reset, new_poses, bds


def load_llff_data(args, basedir,
                   factor=2,
                   recenter=True, bd_factor=.75,
                   spherify=False, path_zflat=False,
                   frame2dolly=10):

    poses, bds, imgs, disp, masks, flows_f, flow_masks_f, flows_b, flow_masks_b = \
        _load_data(basedir, factor=factor) # factor=2 downsamples original imgs by 2x

    print('Loaded', basedir, bds.min(), bds.max())

    # Correct rotation matrix ordering and move variable dim to axis 0
    poses = np.concatenate([poses[:, 1:2, :],
                           -poses[:, 0:1, :],
                            poses[:, 2:, :]], 1)
    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
    images = np.moveaxis(imgs, -1, 0).astype(np.float32)
    bds = np.moveaxis(bds, -1, 0).astype(np.float32)
    disp = np.moveaxis(disp, -1, 0).astype(np.float32)
    masks = np.moveaxis(masks, -1, 0).astype(np.float32)
    flows_f = np.moveaxis(flows_f, -1, 0).astype(np.float32)
    flow_masks_f = np.moveaxis(flow_masks_f, -1, 0).astype(np.float32)
    flows_b = np.moveaxis(flows_b, -1, 0).astype(np.float32)
    flow_masks_b = np.moveaxis(flow_masks_b, -1, 0).astype(np.float32)

    # Rescale if bd_factor is provided
    sc = 1. if bd_factor is None else 1./(np.percentile(bds[:, 0], 5) * bd_factor)

    poses[:, :3, 3] *= sc
    bds *= sc

    if recenter:
        poses = recenter_poses(poses)

    # Only for rendering
    if frame2dolly == -1:
        c2w = poses_avg(poses)
    else:
        c2w = poses[frame2dolly, :, :]

    H, W, _ = c2w[:, -1]

    # Generate poses for novel views
    render_poses, render_focals = generate_path(c2w, args)
    render_poses = np.array(render_poses).astype(np.float32)

    grids = get_grid(int(H), int(W), len(poses), flows_f, flow_masks_f, flows_b, flow_masks_b) # [N, H, W, 8]

    return images, disp, masks, poses, bds,\
        render_poses, render_focals, grids


def generate_path(c2w, args):
    hwf = c2w[:, 4:5]
    num_novelviews = args.num_novelviews
    max_disp = 48.0
    H, W, focal = hwf[:, 0]

    max_trans = max_disp / focal
    output_poses = []
    output_focals = []

    # Rendering teaser. Add translation.
    for i in range(num_novelviews):
        x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_novelviews)) * args.x_trans_multiplier
        y_trans = max_trans * (np.cos(2.0 * np.pi * float(i) / float(num_novelviews)) - 1.) * args.y_trans_multiplier
        z_trans = 0.

        i_pose = np.concatenate([
            np.concatenate(
                [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),
            np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
        ],axis=0)

        i_pose = np.linalg.inv(i_pose)

        ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)

        render_pose = np.dot(ref_pose, i_pose)
        output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
        output_focals.append(focal)

    # Rendering teaser. Add zooming.
    if args.frame2dolly != -1:
        for i in range(num_novelviews // 2 + 1):
            x_trans = 0.
            y_trans = 0.
            # z_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_novelviews)) * args.z_trans_multiplier
            z_trans = max_trans * args.z_trans_multiplier * i / float(num_novelviews // 2)
            i_pose = np.concatenate([
                np.concatenate(
                    [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),
                np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
            ],axis=0)

            i_pose = np.linalg.inv(i_pose) #torch.tensor(np.linalg.inv(i_pose)).float()

            ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)

            render_pose = np.dot(ref_pose, i_pose)
            output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
            output_focals.append(focal)
            print(z_trans / max_trans / args.z_trans_multiplier)

    # Rendering teaser. Add dolly zoom.
    if args.frame2dolly != -1:
        for i in range(num_novelviews // 2 + 1):
            x_trans = 0.
            y_trans = 0.
            z_trans = max_trans * args.z_trans_multiplier * i / float(num_novelviews // 2)
            i_pose = np.concatenate([
                np.concatenate(
                    [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),
                np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
            ],axis=0)

            i_pose = np.linalg.inv(i_pose)

            ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)

            render_pose = np.dot(ref_pose, i_pose)
            output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
            new_focal = focal - args.focal_decrease * z_trans / max_trans / args.z_trans_multiplier
            output_focals.append(new_focal)
            print(z_trans / max_trans / args.z_trans_multiplier, new_focal)

    return output_poses, output_focals


================================================
FILE: render_utils.py
================================================
import os
import time
import torch
import imageio
import numpy as np
import torch.nn.functional as F

from run_nerf_helpers import *
from utils.flow_utils import flow_to_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def batchify_rays(t, chain_5frames,
                rays_flat, chunk=1024*16, **kwargs):
    """Render rays in smaller minibatches to avoid OOM.
    """
    all_ret = {}
    for i in range(0, rays_flat.shape[0], chunk):
        ret = render_rays(t, chain_5frames, rays_flat[i:i+chunk], **kwargs)
        for k in ret:
            if k not in all_ret:
                all_ret[k] = []
            all_ret[k].append(ret[k])

    all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}
    return all_ret


def render(t, chain_5frames,
           H, W, focal, focal_render=None,
           chunk=1024*16, rays=None, c2w=None, ndc=True,
           near=0., far=1.,
           use_viewdirs=False, c2w_staticcam=None,
           **kwargs):
    """Render rays
    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      focal: float. Focal length of pinhole camera.
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      rays: array of shape [2, batch_size, 3]. Ray origin and direction for
        each example in batch.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
      near: float or array of shape [batch_size]. Nearest distance for a ray.
      far: float or array of shape [batch_size]. Farthest distance for a ray.
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
       camera while using other c2w argument for viewing directions.
    Returns:
      rgb_map: [batch_size, 3]. Predicted RGB values for rays.
      disp_map: [batch_size]. Disparity map. Inverse of depth.
      acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
      extras: dict with everything returned by render_rays().
    """

    if c2w is not None:
        # special case to render full image
        if focal_render is not None:
            # Render full image using different focal length for dolly zoom. Inference only.
            rays_o, rays_d = get_rays(H, W, focal_render, c2w)
        else:
            rays_o, rays_d = get_rays(H, W, focal, c2w)
    else:
        # use provided ray batch
        rays_o, rays_d = rays

    if use_viewdirs:
        # provide ray directions as input
        viewdirs = rays_d
        if c2w_staticcam is not None:
            raise NotImplementedError
        # Make all directions unit magnitude.
        # shape: [batch_size, 3]
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1, 3]).float()

    sh = rays_d.shape # [..., 3]
    if ndc:
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)

    # Create ray batch
    rays_o = torch.reshape(rays_o, [-1, 3]).float()
    rays_d = torch.reshape(rays_d, [-1, 3]).float()
    near, far = near * \
        torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1])

    # (ray origin, ray direction, min dist, max dist) for each ray
    rays = torch.cat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)

    # Render and reshape
    all_ret = batchify_rays(t, chain_5frames,
                        rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = torch.reshape(all_ret[k], k_sh)

    return all_ret


def render_path_batch(render_poses, time2render,
                    hwf, chunk, render_kwargs, savedir=None, focal2render=None):
    """Render frames using batch.

    Args:
      render_poses: array of shape [num_frame, 3, 4]. Camera-to-world transformation matrix of each frame.
      time2render: array of shape [num_frame]. Time of each frame.
      hwf: list. [Height of image in pixels, Width of image in pixels, Focal length of pinhole camera]
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      render_kwargs: dictionary. args for the render function.
      savedir: string. Directory to save results.
      focal2render: list. Only used to perform dolly-zoom.
    Returns:
      ret_dict: dictionary. Final and intermediate results.
    """
    H, W, focal = hwf

    ret_dict = {}
    rgbs = []
    rgbs_d = []
    rgbs_s = []
    dynamicnesses = []

    time_curr = time.time()
    for i, c2w in enumerate(render_poses):

        print(i, time.time() - time_curr)
        time_curr = time.time()

        t = time2render[i]

        if focal2render is not None:
            # Render full image using different focal length
            rays_o, rays_d = get_rays(H, W, focal2render[i], c2w)
        else:
            rays_o, rays_d = get_rays(H, W, focal, c2w)
        rays_o = torch.reshape(rays_o, (-1, 3))
        rays_d = torch.reshape(rays_d, (-1, 3))
        batch_rays = torch.stack([rays_o, rays_d], 0)
        rgb = []
        rgb_d = []
        rgb_s = []
        dynamicness = []
        for j in range(0, batch_rays.shape[1], chunk):
            # print(j, '/', batch_rays.shape[1])
            ret = render(t, False,
                         H, W, focal,
                         chunk=chunk, rays=batch_rays[:, j:j+chunk, :],
                         **render_kwargs)
            rgb.append(ret['rgb_map_full'].cpu())
            rgb_d.append(ret['rgb_map_d'].cpu())
            rgb_s.append(ret['rgb_map_s'].cpu())
            dynamicness.append(ret['dynamicness_map'].cpu())
        rgb = torch.reshape(torch.cat(rgb, 0), (H, W, 3)).numpy()
        rgb_d = torch.reshape(torch.cat(rgb_d, 0), (H, W, 3)).numpy()
        rgb_s = torch.reshape(torch.cat(rgb_s, 0), (H, W, 3)).numpy()
        dynamicness = torch.reshape(torch.cat(dynamicness, 0), (H, W)).numpy()

        # Not a good solution. Should take care of this when preparing the data.
        if W%2 == 1:
            # rgb = cv2.resize(rgb, (W - 1, H))
            rgb = rgb[:, :-1, :]
            rgb_d = rgb_d[:, :-1, :]
            rgb_s = rgb_s[:, :-1, :]
            dynamicness = dynamicness[:, :-1]
        rgbs.append(rgb)
        rgbs_d.append(rgb_d)
        rgbs_s.append(rgb_s)
        dynamicnesses.append(dynamicness)

        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            filename = os.path.join(savedir, '{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

    ret_dict['rgbs'] = np.stack(rgbs, 0)
    ret_dict['rgbs_d'] = np.stack(rgbs_d, 0)
    ret_dict['rgbs_s'] = np.stack(rgbs_s, 0)
    ret_dict['dynamicnesses'] = np.stack(dynamicnesses, 0)

    return ret_dict


def render_path(render_poses,
                time2render,
                hwf,
                chunk,
                render_kwargs,
                savedir=None,
                flows_gt_f=None,
                flows_gt_b=None,
                focal2render=None):
    """Render frames.

    Args:
      render_poses: array of shape [num_frame, 3, 4]. Camera-to-world transformation matrix of each frame.
      time2render: array of shape [num_frame]. Time of each frame.
      hwf: list. [Height of image in pixels, Width of image in pixels, Focal length of pinhole camera]
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      render_kwargs: dictionary. args for the render function.
      savedir: string. Directory to save results.
      focal2render: list. Only used to perform dolly-zoom.
    Returns:
      ret_dict: dictionary. Final and intermediate results.
    """
    H, W, focal = hwf

    ret_dict = {}
    rgbs = []
    rgbs_d = []
    rgbs_s = []
    depths = []
    depths_d = []
    depths_s = []
    flows_f = []
    flows_b = []
    dynamicness = []
    blending = []

    grid = np.stack(np.meshgrid(np.arange(W, dtype=np.float32),
                       np.arange(H, dtype=np.float32), indexing='xy'), -1)
    grid = torch.Tensor(grid)
    time_curr = time.time()
    for i, c2w in enumerate(render_poses):
        t = time2render[i]
        pose = c2w[:3, :4]
        print(i, time.time() - time_curr)
        time_curr = time.time()

        if focal2render is None:
            # Normal rendering.
            ret = render(t, False,
                         H, W, focal,
                         chunk=1024*32, c2w=pose,
                         **render_kwargs)
        else:
            # Render image using different focal length.
            ret = render(t, False,
                         H, W, focal, focal_render=focal2render[i],
                         chunk=1024*32, c2w=pose,
                         **render_kwargs)

        rgbs.append(ret['rgb_map_full'].cpu().numpy())
        rgbs_d.append(ret['rgb_map_d'].cpu().numpy())
        rgbs_s.append(ret['rgb_map_s'].cpu().numpy())

        depths.append(ret['depth_map_full'].cpu().numpy())
        depths_d.append(ret['depth_map_d'].cpu().numpy())
        depths_s.append(ret['depth_map_s'].cpu().numpy())

        dynamicness.append(ret['dynamicness_map'].cpu().numpy())

        if flows_gt_f is not None:
            # Reconstruction. Flow is caused by both changing camera and changing time.
            pose_f = render_poses[min(i + 1, int(len(render_poses)) - 1), :3, :4]
            pose_b = render_poses[max(i - 1, 0), :3, :4]
        else:
            # Non training view-time. Flow is caused by changing time (just for visualization).
            pose_f = render_poses[i, :3, :4]
            pose_b = render_poses[i, :3, :4]

        # Sceneflow induced optical flow
        induced_flow_f_ = induce_flow(H, W, focal, pose_f, ret['weights_d'], ret['raw_pts_f'], grid[..., :2])
        induced_flow_b_ = induce_flow(H, W, focal, pose_b, ret['weights_d'], ret['raw_pts_b'], grid[..., :2])

        if (i + 1) >= len(render_poses):
            induced_flow_f = np.zeros((H, W, 2))
        else:
            induced_flow_f = induced_flow_f_.cpu().numpy()
        if flows_gt_f is not None:
            flow_gt_f = flows_gt_f[i].cpu().numpy()
            induced_flow_f = np.concatenate((induced_flow_f, flow_gt_f), 0)
        induced_flow_f_img = flow_to_image(induced_flow_f)
        flows_f.append(induced_flow_f_img)

        if (i - 1) < 0:
            induced_flow_b = np.zeros((H, W, 2))
        else:
            induced_flow_b = induced_flow_b_.cpu().numpy()
        if flows_gt_b is not None:
            flow_gt_b = flows_gt_b[i].cpu().numpy()
            induced_flow_b = np.concatenate((induced_flow_b, flow_gt_b), 0)
        induced_flow_b_img = flow_to_image(induced_flow_b)
        flows_b.append(induced_flow_b_img)

        if i == 0:
            ret_dict['sceneflow_f_NDC'] = ret['sceneflow_f'].cpu().numpy()
            ret_dict['sceneflow_b_NDC'] = ret['sceneflow_b'].cpu().numpy()
            ret_dict['blending'] = ret['blending'].cpu().numpy()

            weights = np.concatenate((ret['weights_d'][..., None].cpu().numpy(),
                                      ret['weights_s'][..., None].cpu().numpy(),
                                      ret['blending'][..., None].cpu().numpy(),
                                      ret['weights_full'][..., None].cpu().numpy()))
            ret_dict['weights'] = np.moveaxis(weights, [0, 1, 2, 3], [1, 2, 0, 3])

        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            filename = os.path.join(savedir, '{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

    ret_dict['rgbs'] = np.stack(rgbs, 0)
    ret_dict['rgbs_d'] = np.stack(rgbs_d, 0)
    ret_dict['rgbs_s'] = np.stack(rgbs_s, 0)
    ret_dict['depths'] = np.stack(depths, 0)
    ret_dict['depths_d'] = np.stack(depths_d, 0)
    ret_dict['depths_s'] = np.stack(depths_s, 0)
    ret_dict['dynamicness'] = np.stack(dynamicness, 0)
    ret_dict['flows_f'] = np.stack(flows_f, 0)
    ret_dict['flows_b'] = np.stack(flows_b, 0)

    return ret_dict


def raw2outputs(raw_s,
                raw_d,
                blending,
                z_vals,
                rays_d,
                raw_noise_std):
    """Transforms model's predictions to semantically meaningful values.

    Args:
      raw_d: [num_rays, num_samples along ray, 4]. Prediction from Dynamic model.
      raw_s: [num_rays, num_samples along ray, 4]. Prediction from Static model.
      z_vals: [num_rays, num_samples along ray]. Integration time.
      rays_d: [num_rays, 3]. Direction of each ray.

    Returns:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
      disp_map: [num_rays]. Disparity map. Inverse of depth map.
      acc_map: [num_rays]. Sum of weights along each ray.
      weights: [num_rays, num_samples]. Weights assigned to each sampled color.
      depth_map: [num_rays]. Estimated distance to object.
    """
    # Function for computing density from model prediction. This value is
    # strictly between [0, 1].
    def raw2alpha(raw, dists, act_fn=F.relu): return 1.0 - \
        torch.exp(-act_fn(raw) * dists)

    # Compute 'distance' (in time) between each integration time along a ray.
    dists = z_vals[..., 1:] - z_vals[..., :-1]

    # The 'distance' from the last integration time is infinity.
    dists = torch.cat(
        [dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)],
         -1) # [N_rays, N_samples]

    # Multiply each distance by the norm of its corresponding direction ray
    # to convert to real world distance (accounts for non-unit directions).
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    # Extract RGB of each sample position along each ray.
    rgb_d = torch.sigmoid(raw_d[..., :3])  # [N_rays, N_samples, 3]
    rgb_s = torch.sigmoid(raw_s[..., :3])  # [N_rays, N_samples, 3]

    # Add noise to model's predictions for density. Can be used to
    # regularize network during training (prevents floater artifacts).
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw_d[..., 3].shape) * raw_noise_std

    # Predict density of each sample along each ray. Higher values imply
    # higher likelihood of being absorbed at this point.
    alpha_d = raw2alpha(raw_d[..., 3] + noise, dists) # [N_rays, N_samples]
    alpha_s = raw2alpha(raw_s[..., 3] + noise, dists) # [N_rays, N_samples]
    alphas  = 1. - (1. - alpha_s) * (1. - alpha_d) # [N_rays, N_samples]

    T_d    = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), 1. - alpha_d + 1e-10], -1), -1)[:, :-1]
    T_s    = torch.cumprod(torch.cat([torch.ones((alpha_s.shape[0], 1)), 1. - alpha_s + 1e-10], -1), -1)[:, :-1]
    T_full = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), (1. - alpha_d * blending) * (1. - alpha_s * (1. - blending)) + 1e-10], -1), -1)[:, :-1]
    # T_full = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), torch.pow(1. - alpha_d + 1e-10, blending) * torch.pow(1. - alpha_s + 1e-10, 1. - blending)], -1), -1)[:, :-1]
    # T_full = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), (1. - alpha_d) * (1. - alpha_s) + 1e-10], -1), -1)[:, :-1]

    # Compute weight for RGB of each sample along each ray.  A cumprod() is
    # used to express the idea of the ray not having reflected up to this
    # sample yet.
    weights_d = alpha_d * T_d
    weights_s = alpha_s * T_s
    weights_full = (alpha_d * blending + alpha_s * (1. - blending)) * T_full
    # weights_full = alphas * T_full

    # Computed weighted color of each sample along each ray.
    rgb_map_d = torch.sum(weights_d[..., None] * rgb_d, -2)
    rgb_map_s = torch.sum(weights_s[..., None] * rgb_s, -2)
    rgb_map_full = torch.sum(
        (T_full * alpha_d * blending)[..., None] * rgb_d + \
        (T_full * alpha_s * (1. - blending))[..., None] * rgb_s, -2)

    # Estimated depth map is expected distance.
    depth_map_d = torch.sum(weights_d * z_vals, -1)
    depth_map_s = torch.sum(weights_s * z_vals, -1)
    depth_map_full = torch.sum(weights_full * z_vals, -1)

    # Sum of weights along each ray. This value is in [0, 1] up to numerical error.
    acc_map_d = torch.sum(weights_d, -1)
    acc_map_s = torch.sum(weights_s, -1)
    acc_map_full = torch.sum(weights_full, -1)

    # Computed dynamicness
    dynamicness_map = torch.sum(weights_full * blending, -1)
    # dynamicness_map = 1 - T_d[..., -1]

    return rgb_map_full, depth_map_full, acc_map_full, weights_full, \
           rgb_map_s, depth_map_s, acc_map_s, weights_s, \
           rgb_map_d, depth_map_d, acc_map_d, weights_d, dynamicness_map


def raw2outputs_d(raw_d,
                  z_vals,
                  rays_d,
                  raw_noise_std):

    # Function for computing density from model prediction. This value is
    # strictly between [0, 1].
    def raw2alpha(raw, dists, act_fn=F.relu): return 1.0 - \
        torch.exp(-act_fn(raw) * dists)

    # Compute 'distance' (in time) between each integration time along a ray.
    dists = z_vals[..., 1:] - z_vals[..., :-1]

    # The 'distance' from the last integration time is infinity.
    dists = torch.cat(
        [dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)],
        -1)  # [N_rays, N_samples]

    # Multiply each distance by the norm of its corresponding direction ray
    # to convert to real world distance (accounts for non-unit directions).
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    # Extract RGB of each sample position along each ray.
    rgb_d = torch.sigmoid(raw_d[..., :3])  # [N_rays, N_samples, 3]

    # Add noise to model's predictions for density. Can be used to
    # regularize network during training (prevents floater artifacts).
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw_d[..., 3].shape) * raw_noise_std

    # Predict density of each sample along each ray. Higher values imply
    # higher likelihood of being absorbed at this point.
    alpha_d = raw2alpha(raw_d[..., 3] + noise, dists)  # [N_rays, N_samples]

    T_d = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), 1. - alpha_d + 1e-10], -1), -1)[:, :-1]
    # Compute weight for RGB of each sample along each ray.  A cumprod() is
    # used to express the idea of the ray not having reflected up to this
    # sample yet.
    weights_d = alpha_d * T_d

    # Computed weighted color of each sample along each ray.
    rgb_map_d = torch.sum(weights_d[..., None] * rgb_d, -2)

    return rgb_map_d, weights_d


def render_rays(t,
                chain_5frames,
                ray_batch,
                network_fn_d,
                network_fn_s,
                network_query_fn_d,
                network_query_fn_s,
                N_samples,
                num_img,
                DyNeRF_blending,
                pretrain=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                raw_noise_std=0.,
                inference=False):

    """Volumetric rendering.
    Args:
      ray_batch: array of shape [batch_size, ...]. All information necessary
        for sampling along a ray, including: ray origin, ray direction, min
        dist, max dist, and unit-magnitude viewing direction.
      network_fn_d: function. Model for predicting RGB and density at each point
        in space.
      network_query_fn_d: function used for passing queries to network_fn_d.
      N_samples: int. Number of different times to sample along each ray.
      lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
        random points in time.
      N_importance: int. Number of additional times to sample along each ray.
        These samples are only passed to network_fine.
      network_fine: "fine" network with same spec as network_fn.
      raw_noise_std: ...
    Returns:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
      disp_map: [num_rays]. Disparity map. 1 / depth.
      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
      raw: [num_rays, num_samples, 4]. Raw predictions from model.
      rgb0: See rgb_map. Output for coarse model.
      disp0: See disp_map. Output for coarse model.
      acc0: See acc_map. Output for coarse model.
      z_std: [num_rays]. Standard deviation of distances along ray for each
        sample.
    """

    # batch size
    N_rays = ray_batch.shape[0]

    # ray_batch: [N_rays, 11]
    # rays_o:    [N_rays, 0:3]
    # rays_d:    [N_rays, 3:6]
    # near:      [N_rays, 6:7]
    # far:       [N_rays, 7:8]
    # viewdirs:  [N_rays, 8:11]

    # Extract ray origin, direction.
    rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] # [N_rays, 3] each

    # Extract unit-normalized viewing direction.
    viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None

    # Extract lower, upper bound for ray distance.
    bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2])
    near, far = bounds[..., 0], bounds[..., 1]

    # Decide where to sample along each ray. Under the logic, all rays will be sampled at
    # the same times.
    t_vals = torch.linspace(0., 1., steps=N_samples)
    if not lindisp:
        # Space integration times linearly between 'near' and 'far'. Same
        # integration points will be used for all rays.
        z_vals = near * (1.-t_vals) + far * (t_vals)
    else:
        # Sample linearly in inverse depth (disparity).
        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
    z_vals = z_vals.expand([N_rays, N_samples])

    # Perturb sampling time along each ray.
    if perturb > 0.:
        # get intervals between samples
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([mids, z_vals[..., -1:]], -1)
        lower = torch.cat([z_vals[..., :1], mids], -1)
        # stratified samples in those intervals
        t_rand = torch.rand(z_vals.shape)
        z_vals = lower + (upper - lower) * t_rand

    # Points in space to evaluate model at.
    pts = rays_o[..., None, :] + rays_d[..., None, :] * \
        z_vals[..., :, None] # [N_rays, N_samples, 3]

    # Add the time dimension to xyz.
    pts_ref = torch.cat([pts, torch.ones_like(pts[..., 0:1]) * t], -1)

    # First pass: we have the staticNeRF results
    raw_s = network_query_fn_s(pts_ref[..., :3], viewdirs, network_fn_s)
    # raw_s:          [N_rays, N_samples, 5]
    # raw_s_rgb:      [N_rays, N_samples, 0:3]
    # raw_s_a:        [N_rays, N_samples, 3:4]
    # raw_s_blending: [N_rays, N_samples, 4:5]

    # Second pass: we have the DyanmicNeRF results and the blending weight
    raw_d = network_query_fn_d(pts_ref, viewdirs, network_fn_d)
    # raw_d:          [N_rays, N_samples, 11]
    # raw_d_rgb:      [N_rays, N_samples, 0:3]
    # raw_d_a:        [N_rays, N_samples, 3:4]
    # sceneflow_b:    [N_rays, N_samples, 4:7]
    # sceneflow_f:    [N_rays, N_samples, 7:10]
    # raw_d_blending: [N_rays, N_samples, 10:11]

    if pretrain:
        rgb_map_s, _ = raw2outputs_d(raw_s[..., :4],
                                     z_vals,
                                     rays_d,
                                     raw_noise_std)
        ret = {'rgb_map_s': rgb_map_s}
        return ret

    raw_s_rgba = raw_s[..., :4]
    raw_d_rgba = raw_d[..., :4]

    # We need the sceneflow from the dynamicNeRF.
    sceneflow_b = raw_d[..., 4:7]
    sceneflow_f = raw_d[..., 7:10]

    if DyNeRF_blending:
        blending = raw_d[..., 10]
    else:
        blending = raw_s[..., 4]

    # if sfmask:
    #     sceneflow_f = sceneflow_f * blending.detach()[..., None]
    #     sceneflow_b = sceneflow_b * blending.detach()[..., None]

    # Rerndering.
    rgb_map_full, depth_map_full, acc_map_full, weights_full, \
    rgb_map_s, depth_map_s, acc_map_s, weights_s, \
    rgb_map_d, depth_map_d, acc_map_d, weights_d, \
    dynamicness_map = raw2outputs(raw_s_rgba,
                                  raw_d_rgba,
                                  blending,
                                  z_vals,
                                  rays_d,
                                  raw_noise_std)

    ret = {'rgb_map_full': rgb_map_full, 'depth_map_full': depth_map_full, 'acc_map_full': acc_map_full, 'weights_full': weights_full,
           'rgb_map_s': rgb_map_s, 'depth_map_s': depth_map_s, 'acc_map_s': acc_map_s, 'weights_s': weights_s,
           'rgb_map_d': rgb_map_d, 'depth_map_d': depth_map_d, 'acc_map_d': acc_map_d, 'weights_d': weights_d,
           'dynamicness_map': dynamicness_map}

    t_interval = 1. / num_img * 2.
    pts_f = torch.cat([pts + sceneflow_f, torch.ones_like(pts[..., 0:1]) * (t + t_interval)], -1)
    pts_b = torch.cat([pts + sceneflow_b, torch.ones_like(pts[..., 0:1]) * (t - t_interval)], -1)

    ret['sceneflow_b'] = sceneflow_b
    ret['sceneflow_f'] = sceneflow_f
    ret['raw_pts'] = pts_ref[..., :3]
    ret['raw_pts_f'] = pts_f[..., :3]
    ret['raw_pts_b'] = pts_b[..., :3]
    ret['blending'] = blending

    # Third pass: we have the DyanmicNeRF results at time t - 1
    raw_d_b = network_query_fn_d(pts_b, viewdirs, network_fn_d)
    raw_d_b_rgba = raw_d_b[..., :4]
    sceneflow_b_b = raw_d_b[..., 4:7]
    sceneflow_b_f = raw_d_b[..., 7:10]

    # Rerndering t - 1
    rgb_map_d_b, weights_d_b = raw2outputs_d(raw_d_b_rgba,
                                             z_vals,
                                             rays_d,
                                             raw_noise_std)

    ret['sceneflow_b_f'] = sceneflow_b_f
    ret['rgb_map_d_b'] = rgb_map_d_b
    ret['acc_map_d_b'] = torch.abs(torch.sum(weights_d_b - weights_d, -1))

    # Fourth pass: we have the DyanmicNeRF results at time t + 1
    raw_d_f = network_query_fn_d(pts_f, viewdirs, network_fn_d)
    raw_d_f_rgba = raw_d_f[..., :4]
    sceneflow_f_b = raw_d_f[..., 4:7]
    sceneflow_f_f = raw_d_f[..., 7:10]

    rgb_map_d_f, weights_d_f = raw2outputs_d(raw_d_f_rgba,
                                             z_vals,
                                             rays_d,
                                             raw_noise_std)

    ret['sceneflow_f_b'] = sceneflow_f_b
    ret['rgb_map_d_f'] = rgb_map_d_f
    ret['acc_map_d_f'] = torch.abs(torch.sum(weights_d_f - weights_d, -1))

    if inference:
        return ret

    # Also consider time t - 2 and t + 2 (Learn from NSFF)

    # Fifth pass: we have the DyanmicNeRF results at time t - 2
    pts_b_b = torch.cat([pts_b[..., :3] + sceneflow_b_b, torch.ones_like(pts[..., 0:1]) * (t - t_interval * 2)], -1)
    ret['raw_pts_b_b'] = pts_b_b[..., :3]

    if chain_5frames:
        raw_d_b_b = network_query_fn_d(pts_b_b, viewdirs, network_fn_d)
        raw_d_b_b_rgba = raw_d_b_b[..., :4]
        rgb_map_d_b_b, _ = raw2outputs_d(raw_d_b_b_rgba,
                                      z_vals,
                                      rays_d,
                                      raw_noise_std)

        ret['rgb_map_d_b_b'] = rgb_map_d_b_b

    # Sixth pass: we have the DyanmicNeRF results at time t + 2
    pts_f_f = torch.cat([pts_f[..., :3] + sceneflow_f_f, torch.ones_like(pts[..., 0:1]) * (t + t_interval * 2)], -1)
    ret['raw_pts_f_f'] = pts_f_f[..., :3]

    if chain_5frames:
        raw_d_f_f = network_query_fn_d(pts_f_f, viewdirs, network_fn_d)
        raw_d_f_f_rgba = raw_d_f_f[..., :4]
        rgb_map_d_f_f, _ = raw2outputs_d(raw_d_f_f_rgba,
                                      z_vals,
                                      rays_d,
                                      raw_noise_std)

        ret['rgb_map_d_f_f'] = rgb_map_d_f_f

    for k in ret:
        if torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any():
            print(f"! [Numerical Error] {k} contains nan or inf.")
            import ipdb; ipdb.set_trace()

    return ret


================================================
FILE: run_nerf.py
================================================
import os
import time
import torch
import imageio
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from render_utils import *
from run_nerf_helpers import *
from load_llff import *
from utils.flow_utils import flow_to_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def config_parser():

    import configargparse
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True,
                        help='config file path')
    parser.add_argument("--expname", type=str,
                        help='experiment name')
    parser.add_argument("--basedir", type=str, default='./logs/',
                        help='where to store ckpts and logs')
    parser.add_argument("--datadir", type=str, default='./data/llff/fern',
                        help='input data directory')

    # training options
    parser.add_argument("--netdepth", type=int, default=8,
                        help='layers in network')
    parser.add_argument("--netwidth", type=int, default=256,
                        help='channels per layer')
    parser.add_argument("--netdepth_fine", type=int, default=8,
                        help='layers in fine network')
    parser.add_argument("--netwidth_fine", type=int, default=256,
                        help='channels per layer in fine network')
    parser.add_argument("--N_rand", type=int, default=32*32*4,
                        help='batch size (number of random rays per gradient step)')
    parser.add_argument("--lrate", type=float, default=5e-4,
                        help='learning rate')
    parser.add_argument("--lrate_decay", type=int, default=300000,
                        help='exponential learning rate decay')
    parser.add_argument("--chunk", type=int, default=1024*128,
                        help='number of rays processed in parallel, decrease if running out of memory')
    parser.add_argument("--netchunk", type=int, default=1024*128,
                        help='number of pts sent through network in parallel, decrease if running out of memory')
    parser.add_argument("--no_reload", action='store_true',
                        help='do not reload weights from saved ckpt')
    parser.add_argument("--ft_path", type=str, default=None,
                        help='specific weights npy file to reload for coarse network')
    parser.add_argument("--random_seed", type=int, default=1,
                        help='fix random seed for repeatability')

    # rendering options
    parser.add_argument("--N_samples", type=int, default=64,
                        help='number of coarse samples per ray')
    parser.add_argument("--N_importance", type=int, default=0,
                        help='number of additional fine samples per ray')
    parser.add_argument("--perturb", type=float, default=1.,
                        help='set to 0. for no jitter, 1. for jitter')
    parser.add_argument("--use_viewdirs", action='store_true',
                        help='use full 5D input instead of 3D')
    parser.add_argument("--use_viewdirsDyn", action='store_true',
                        help='use full 5D input instead of 3D for D-NeRF')
    parser.add_argument("--i_embed", type=int, default=0,
                        help='set 0 for default positional encoding, -1 for none')
    parser.add_argument("--multires", type=int, default=10,
                        help='log2 of max freq for positional encoding (3D location)')
    parser.add_argument("--multires_views", type=int, default=4,
                        help='log2 of max freq for positional encoding (2D direction)')
    parser.add_argument("--raw_noise_std", type=float, default=0.,
                        help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
    parser.add_argument("--render_only", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')

    # dataset options
    parser.add_argument("--dataset_type", type=str, default='llff',
                        help='options: llff')

    # llff flags
    parser.add_argument("--factor", type=int, default=8,
                        help='downsample factor for LLFF images')
    parser.add_argument("--no_ndc", action='store_true',
                        help='do not use normalized device coordinates (set for non-forward facing scenes)')
    parser.add_argument("--lindisp", action='store_true',
                        help='sampling linearly in disparity rather than depth')
    parser.add_argument("--spherify", action='store_true',
                        help='set for spherical 360 scenes')

    # logging/saving options
    parser.add_argument("--i_print",   type=int, default=500,
                        help='frequency of console printout and metric logging')
    parser.add_argument("--i_img",     type=int, default=500,
                        help='frequency of tensorboard image logging')
    parser.add_argument("--i_weights", type=int, default=10000,
                        help='frequency of weight ckpt saving')
    parser.add_argument("--i_testset", type=int, default=50000,
                        help='frequency of testset saving')
    parser.add_argument("--i_video",   type=int, default=50000,
                        help='frequency of render_poses video saving')
    parser.add_argument("--N_iters", type=int, default=1000000,
                        help='number of training iterations')

    # Dynamic NeRF lambdas
    parser.add_argument("--dynamic_loss_lambda", type=float, default=1.,
                        help='lambda of dynamic loss')
    parser.add_argument("--static_loss_lambda", type=float, default=1.,
                        help='lambda of static loss')
    parser.add_argument("--full_loss_lambda", type=float, default=3.,
                        help='lambda of full loss')
    parser.add_argument("--depth_loss_lambda", type=float, default=0.04,
                        help='lambda of depth loss')
    parser.add_argument("--order_loss_lambda", type=float, default=0.1,
                        help='lambda of order loss')
    parser.add_argument("--flow_loss_lambda", type=float, default=0.02,
                        help='lambda of optical flow loss')
    parser.add_argument("--slow_loss_lambda", type=float, default=0.1,
                        help='lambda of sf slow regularization')
    parser.add_argument("--smooth_loss_lambda", type=float, default=0.1,
                        help='lambda of sf smooth regularization')
    parser.add_argument("--consistency_loss_lambda", type=float, default=0.1,
                        help='lambda of sf cycle consistency regularization')
    parser.add_argument("--mask_loss_lambda", type=float, default=0.1,
                        help='lambda of the mask loss')
    parser.add_argument("--sparse_loss_lambda", type=float, default=0.1,
                        help='lambda of sparse loss')
    parser.add_argument("--DyNeRF_blending", action='store_true',
                        help='use Dynamic NeRF to predict blending weight')
    parser.add_argument("--pretrain", action='store_true',
                        help='Pretrain the StaticneRF')
    parser.add_argument("--ft_path_S", type=str, default=None,
                        help='specific weights npy file to reload for StaticNeRF')

    # For rendering teasers
    parser.add_argument("--frame2dolly", type=int, default=-1,
                        help='choose frame to perform dolly zoom')
    parser.add_argument("--x_trans_multiplier", type=float, default=1.,
                        help='x_trans_multiplier')
    parser.add_argument("--y_trans_multiplier", type=float, default=0.33,
                        help='y_trans_multiplier')
    parser.add_argument("--z_trans_multiplier", type=float, default=5.,
                        help='z_trans_multiplier')
    parser.add_argument("--num_novelviews", type=int, default=60,
                        help='num_novelviews')
    parser.add_argument("--focal_decrease", type=float, default=200,
                        help='focal_decrease')
    return parser


def train():

    parser = config_parser()
    args = parser.parse_args()

    if args.random_seed is not None:
        print('Fixing random seed', args.random_seed)
        np.random.seed(args.random_seed)

    # Load data
    if args.dataset_type == 'llff':
        frame2dolly = args.frame2dolly
        images, invdepths, masks, poses, bds, \
        render_poses, render_focals, grids = load_llff_data(args, args.datadir,
                                                            args.factor,
                                                            frame2dolly=frame2dolly,
                                                            recenter=True, bd_factor=.9,
                                                            spherify=args.spherify)

        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]
        num_img = float(poses.shape[0])
        assert len(poses) == len(images)
        print('Loaded llff', images.shape,
            render_poses.shape, hwf, args.datadir)

        # Use all views to train
        i_train = np.array([i for i in np.arange(int(images.shape[0]))])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            raise NotImplementedError
            near = np.ndarray.min(bds) * .9
            far = np.ndarray.max(bds) * 1.
        else:
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)
    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)

    if not args.render_only:
        f = os.path.join(basedir, expname, 'args.txt')
        with open(f, 'w') as file:
            for arg in sorted(vars(args)):
                attr = getattr(args, arg)
                file.write('{} = {}\n'.format(arg, attr))
        if args.config is not None:
            f = os.path.join(basedir, expname, 'config.txt')
            with open(f, 'w') as file:
                file.write(open(args.config, 'r').read())

    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
    global_step = start

    bds_dict = {
        'near': near,
        'far': far,
        'num_img': num_img,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Short circuit if only rendering out from trained model
    if args.render_only:
        print('RENDER ONLY')
        i = start - 1

        # Change time and change view at the same time.
        time2render = np.concatenate((np.repeat((i_train / float(num_img) * 2. - 1.0), 4),
                                      np.repeat((i_train / float(num_img) * 2. - 1.0)[::-1][1:-1], 4)))
        if len(time2render) > len(render_poses):
            pose2render = np.tile(render_poses, (int(np.ceil(len(time2render) / len(render_poses))), 1, 1))
            pose2render = pose2render[:len(time2render)]
            pose2render = torch.Tensor(pose2render)
        else:
            time2render = np.tile(time2render, int(np.ceil(len(render_poses) / len(time2render))))
            time2render = time2render[:len(render_poses)]
            pose2render = torch.Tensor(render_poses)
        result_type = 'novelviewtime'

        testsavedir = os.path.join(
            basedir, expname, result_type + '_{:06d}'.format(i))
        os.makedirs(testsavedir, exist_ok=True)
        with torch.no_grad():
            ret = render_path(pose2render, time2render,
                              hwf, args.chunk, render_kwargs_test, savedir=testsavedir)
        moviebase = os.path.join(
            testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))
        save_res(moviebase, ret)

        # Fix view (first view) and change time.
        pose2render = torch.Tensor(poses[0:1, ...]).expand([int(num_img), 3, 4])
        time2render = i_train / float(num_img) * 2. - 1.0
        result_type = 'testset_view000'

        testsavedir = os.path.join(
            basedir, expname, result_type + '_{:06d}'.format(i))
        os.makedirs(testsavedir, exist_ok=True)
        with torch.no_grad():
            ret = render_path(pose2render, time2render,
                              hwf, args.chunk, render_kwargs_test, savedir=testsavedir)
        moviebase = os.path.join(
            testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))
        save_res(moviebase, ret)

        return

    N_rand = args.N_rand

    # Move training data to GPU
    images = torch.Tensor(images)
    invdepths = torch.Tensor(invdepths)
    masks = 1.0 - torch.Tensor(masks)
    poses = torch.Tensor(poses)
    grids = torch.Tensor(grids)

    print('Begin')
    print('TRAIN views are', i_train)

    # Summary writers
    writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))

    decay_iteration = max(25, num_img)

    # Pre-train StaticNeRF
    if args.pretrain:
        render_kwargs_train.update({'pretrain': True})

        # Pre-train StaticNeRF first and use DynamicNeRF to blend
        assert args.DyNeRF_blending == True

        if args.ft_path_S is not None and args.ft_path_S != 'None':
            # Load Pre-trained StaticNeRF
            ckpt_path = args.ft_path_S
            print('Reloading StaticNeRF from', ckpt_path)
            ckpt = torch.load(ckpt_path)
            render_kwargs_train['network_fn_s'].load_state_dict(ckpt['network_fn_s_state_dict'])
        else:
            # Train StaticNeRF from scratch
            for i in range(args.N_iters):
                time0 = time.time()

                # No raybatching as we need to take random rays from one image at a time
                img_i = np.random.choice(i_train)
                t = img_i / num_img * 2. - 1.0 # time of the current frame
                target = images[img_i]
                pose = poses[img_i, :3, :4]
                mask = masks[img_i] # Static region mask

                rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)
                coords_s = torch.stack((torch.where(mask >= 0.5)), -1)
                select_inds_s = np.random.choice(coords_s.shape[0], size=[N_rand], replace=False)
                select_coords = coords_s[select_inds_s]

                def select_batch(value, select_coords=select_coords):
                    return value[select_coords[:, 0], select_coords[:, 1]]

                rays_o = select_batch(rays_o) # (N_rand, 3)
                rays_d = select_batch(rays_d) # (N_rand, 3)
                target_rgb = select_batch(target)
                batch_mask = select_batch(mask[..., None])
                batch_rays = torch.stack([rays_o, rays_d], 0)

                #####  Core optimization loop  #####
                ret = render(t,
                             False,
                             H, W, focal,
                             chunk=args.chunk,
                             rays=batch_rays,
                             **render_kwargs_train)

                optimizer.zero_grad()

                # Compute MSE loss between rgb_s and true RGB.
                img_s_loss = img2mse(ret['rgb_map_s'], target_rgb)
                psnr_s = mse2psnr(img_s_loss)
                loss = args.static_loss_lambda * img_s_loss

                loss.backward()
                optimizer.step()

                # Learning rate decay.
                decay_rate = 0.1
                decay_steps = args.lrate_decay
                new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
                for param_group in optimizer.param_groups:
                    param_group['lr'] = new_lrate

                dt = time.time() - time0

                print(f"Pretraining step: {global_step}, Loss: {loss}, Time: {dt}, expname: {expname}")

                if i % args.i_print == 0:
                    writer.add_scalar("loss", loss.item(), i)
                    writer.add_scalar("lr", new_lrate, i)
                    writer.add_scalar("psnr_s", psnr_s.item(), i)

                if i % args.i_img == 0:
                    target = images[img_i]
                    pose = poses[img_i, :3, :4]
                    mask = masks[img_i]

                    with torch.no_grad():
                        ret = render(t,
                                     False,
                                     H, W, focal,
                                     chunk=1024*16,
                                     c2w=pose,
                                     **render_kwargs_test)

                        # Save out the validation image for Tensorboard-free monitoring
                        writer.add_image("rgb_holdout", target, global_step=i, dataformats='HWC')
                        writer.add_image("mask", mask, global_step=i, dataformats='HW')
                        writer.add_image("rgb_s", torch.clamp(ret['rgb_map_s'], 0., 1.), global_step=i, dataformats='HWC')
                        writer.add_image("depth_s", normalize_depth(ret['depth_map_s']), global_step=i, dataformats='HW')
                        writer.add_image("acc_s", ret['acc_map_s'], global_step=i, dataformats='HW')

                global_step += 1

        # Save the pretrained weight
        torch.save({
            'global_step': global_step,
            'network_fn_s_state_dict': render_kwargs_train['network_fn_s'].state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, os.path.join(basedir, expname, 'Pretrained_S.tar'))

        # Reset
        render_kwargs_train.update({'pretrain': False})
        global_step = start

        # Fix the StaticNeRF and only train the DynamicNeRF
        grad_vars_d = list(render_kwargs_train['network_fn_d'].parameters())
        optimizer = torch.optim.Adam(params=grad_vars_d, lr=args.lrate, betas=(0.9, 0.999))

    for i in range(start, args.N_iters):
        time0 = time.time()

        # Use frames at t-2, t-1, t, t+1, t+2 (adapted from NSFF)
        if i < decay_iteration * 2000:
            chain_5frames = False
        else:
            chain_5frames = True

        # Lambda decay.
        Temp = 1. / (10 ** (i // (decay_iteration * 1000)))

        if i % (decay_iteration * 1000) == 0:
            torch.cuda.empty_cache()

        # No raybatching as we need to take random rays from one image at a time
        img_i = np.random.choice(i_train)
        t = img_i / num_img * 2. - 1.0 # time of the current frame
        target = images[img_i]
        pose = poses[img_i, :3, :4]
        mask = masks[img_i] # Static region mask
        invdepth = invdepths[img_i]
        grid = grids[img_i]

        rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)
        coords_d = torch.stack((torch.where(mask < 0.5)), -1)
        coords_s = torch.stack((torch.where(mask >= 0.5)), -1)
        coords = torch.stack((torch.where(mask > -1)), -1)

        # Evenly sample dynamic region and static region
        select_inds_d = np.random.choice(coords_d.shape[0], size=[min(len(coords_d), N_rand//2)], replace=False)
        select_inds_s = np.random.choice(coords_s.shape[0], size=[N_rand//2], replace=False)
        select_coords = torch.cat([coords_s[select_inds_s],
                                   coords_d[select_inds_d]], 0)

        def select_batch(value, select_coords=select_coords):
            return value[select_coords[:, 0], select_coords[:, 1]]

        rays_o = select_batch(rays_o) # (N_rand, 3)
        rays_d = select_batch(rays_d) # (N_rand, 3)
        target_rgb = select_batch(target)
        batch_grid = select_batch(grid) # (N_rand, 8)
        batch_mask = select_batch(mask[..., None])
        batch_invdepth = select_batch(invdepth)
        batch_rays = torch.stack([rays_o, rays_d], 0)

        #####  Core optimization loop  #####
        ret = render(t,
                     chain_5frames,
                     H, W, focal,
                     chunk=args.chunk,
                     rays=batch_rays,
                     **render_kwargs_train)

        optimizer.zero_grad()
        loss = 0
        loss_dict = {}

        # Compute MSE loss between rgb_full and true RGB.
        img_loss = img2mse(ret['rgb_map_full'], target_rgb)
        psnr = mse2psnr(img_loss)
        loss_dict['psnr'] = psnr
        loss_dict['img_loss'] = img_loss
        loss += args.full_loss_lambda * loss_dict['img_loss']

        # Compute MSE loss between rgb_s and true RGB.
        img_s_loss = img2mse(ret['rgb_map_s'], target_rgb, batch_mask)
        psnr_s = mse2psnr(img_s_loss)
        loss_dict['psnr_s'] = psnr_s
        loss_dict['img_s_loss'] = img_s_loss
        loss += args.static_loss_lambda * loss_dict['img_s_loss']

        # Compute MSE loss between rgb_d and true RGB.
        img_d_loss = img2mse(ret['rgb_map_d'], target_rgb)
        psnr_d = mse2psnr(img_d_loss)
        loss_dict['psnr_d'] = psnr_d
        loss_dict['img_d_loss'] = img_d_loss
        loss += args.dynamic_loss_lambda * loss_dict['img_d_loss']

        # Compute MSE loss between rgb_d_f and true RGB.
        img_d_f_loss = img2mse(ret['rgb_map_d_f'], target_rgb)
        psnr_d_f = mse2psnr(img_d_f_loss)
        loss_dict['psnr_d_f'] = psnr_d_f
        loss_dict['img_d_f_loss'] = img_d_f_loss
        loss += args.dynamic_loss_lambda * loss_dict['img_d_f_loss']

        # Compute MSE loss between rgb_d_b and true RGB.
        img_d_b_loss = img2mse(ret['rgb_map_d_b'], target_rgb)
        psnr_d_b = mse2psnr(img_d_b_loss)
        loss_dict['psnr_d_b'] = psnr_d_b
        loss_dict['img_d_b_loss'] = img_d_b_loss
        loss += args.dynamic_loss_lambda * loss_dict['img_d_b_loss']

        # Motion loss.
        # Compuate EPE between induced flow and true flow (forward flow).
        # The last frame does not have forward flow.
        if img_i < num_img - 1:
            pts_f = ret['raw_pts_f']
            weight = ret['weights_d']
            pose_f = poses[img_i + 1, :3, :4]
            induced_flow_f = induce_flow(H, W, focal, pose_f, weight, pts_f, batch_grid[..., :2])
            flow_f_loss = img2mae(induced_flow_f, batch_grid[:, 2:4], batch_grid[:, 4:5])
            loss_dict['flow_f_loss'] = flow_f_loss
            loss += args.flow_loss_lambda * Temp * loss_dict['flow_f_loss']

        # Compuate EPE between induced flow and true flow (backward flow).
        # The first frame does not have backward flow.
        if img_i > 0:
            pts_b = ret['raw_pts_b']
            weight = ret['weights_d']
            pose_b = poses[img_i - 1, :3, :4]
            induced_flow_b = induce_flow(H, W, focal, pose_b, weight, pts_b, batch_grid[..., :2])
            flow_b_loss = img2mae(induced_flow_b, batch_grid[:, 5:7], batch_grid[:, 7:8])
            loss_dict['flow_b_loss'] = flow_b_loss
            loss += args.flow_loss_lambda * Temp * loss_dict['flow_b_loss']

        # Slow scene flow. The forward and backward sceneflow should be small.
        slow_loss = L1(ret['sceneflow_b']) + L1(ret['sceneflow_f'])
        loss_dict['slow_loss'] = slow_loss
        loss += args.slow_loss_lambda * loss_dict['slow_loss']

        # Smooth scene flow. The summation of the forward and backward sceneflow should be small.
        smooth_loss = compute_sf_smooth_loss(ret['raw_pts'],
                                             ret['raw_pts_f'],
                                             ret['raw_pts_b'],
                                             H, W, focal)
        loss_dict['smooth_loss'] = smooth_loss
        loss += args.smooth_loss_lambda * loss_dict['smooth_loss']

        # Spatial smooth scene flow. (loss adapted from NSFF)
        sp_smooth_loss = compute_sf_smooth_s_loss(ret['raw_pts'], ret['raw_pts_f'], H, W, focal) \
                       + compute_sf_smooth_s_loss(ret['raw_pts'], ret['raw_pts_b'], H, W, focal)
        loss_dict['sp_smooth_loss'] = sp_smooth_loss
        loss += args.smooth_loss_lambda * loss_dict['sp_smooth_loss']

        # Consistency loss.
        consistency_loss = L1(ret['sceneflow_f'] + ret['sceneflow_f_b']) + \
                           L1(ret['sceneflow_b'] + ret['sceneflow_b_f'])
        loss_dict['consistency_loss'] = consistency_loss
        loss += args.consistency_loss_lambda * loss_dict['consistency_loss']

        # Mask loss.
        mask_loss = L1(ret['blending'][batch_mask[:, 0].type(torch.bool)]) + \
                    img2mae(ret['dynamicness_map'][..., None], 1 - batch_mask)
        loss_dict['mask_loss'] = mask_loss
        if i < decay_iteration * 1000:
            loss += args.mask_loss_lambda * loss_dict['mask_loss']

        # Sparsity loss.
        sparse_loss = entropy(ret['weights_d']) + entropy(ret['blending'])
        loss_dict['sparse_loss'] = sparse_loss
        loss += args.sparse_loss_lambda * loss_dict['sparse_loss']

        # Depth constraint
        # Depth in NDC space equals to negative disparity in Euclidean space.
        depth_loss = compute_depth_loss(ret['depth_map_d'], -batch_invdepth)
        loss_dict['depth_loss'] = depth_loss
        loss += args.depth_loss_lambda * Temp * loss_dict['depth_loss']

        # Order loss
        order_loss = torch.mean(torch.square(ret['depth_map_d'][batch_mask[:, 0].type(torch.bool)] - \
                                             ret['depth_map_s'].detach()[batch_mask[:, 0].type(torch.bool)]))
        loss_dict['order_loss'] = order_loss
        loss += args.order_loss_lambda * loss_dict['order_loss']

        sf_smooth_loss = compute_sf_smooth_loss(ret['raw_pts_b'],
                                                ret['raw_pts'],
                                                ret['raw_pts_b_b'],
                                                H, W, focal) + \
                         compute_sf_smooth_loss(ret['raw_pts_f'],
                                                ret['raw_pts_f_f'],
                                                ret['raw_pts'],
                                                H, W, focal)
        loss_dict['sf_smooth_loss'] = sf_smooth_loss
        loss += args.smooth_loss_lambda * loss_dict['sf_smooth_loss']

        if chain_5frames:
            img_d_b_b_loss = img2mse(ret['rgb_map_d_b_b'], target_rgb)
            loss_dict['img_d_b_b_loss'] = img_d_b_b_loss
            loss += args.dynamic_loss_lambda * loss_dict['img_d_b_b_loss']

            img_d_f_f_loss = img2mse(ret['rgb_map_d_f_f'], target_rgb)
            loss_dict['img_d_f_f_loss'] = img_d_f_f_loss
            loss += args.dynamic_loss_lambda * loss_dict['img_d_f_f_loss']

        loss.backward()
        optimizer.step()

        # Learning rate decay.
        decay_rate = 0.1
        decay_steps = args.lrate_decay
        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate

        dt = time.time() - time0

        print(f"Step: {global_step}, Loss: {loss}, Time: {dt}, chain_5frames: {chain_5frames}, expname: {expname}")

        # Rest is logging
        if i % args.i_weights==0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))

            if args.N_importance > 0:
                raise NotImplementedError
            else:
                torch.save({
                    'global_step': global_step,
                    'network_fn_d_state_dict': render_kwargs_train['network_fn_d'].state_dict(),
                    'network_fn_s_state_dict': render_kwargs_train['network_fn_s'].state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, path)

            print('Saved weights at', path)

        if i % args.i_video == 0 and i > 0:

            # Change time and change view at the same time.
            time2render = np.concatenate((np.repeat((i_train / float(num_img) * 2. - 1.0), 4),
                                          np.repeat((i_train / float(num_img) * 2. - 1.0)[::-1][1:-1], 4)))
            if len(time2render) > len(render_poses):
                pose2render = np.tile(render_poses, (int(np.ceil(len(time2render) / len(render_poses))), 1, 1))
                pose2render = pose2render[:len(time2render)]
                pose2render = torch.Tensor(pose2render)
            else:
                time2render = np.tile(time2render, int(np.ceil(len(render_poses) / len(time2render))))
                time2render = time2render[:len(render_poses)]
                pose2render = torch.Tensor(render_poses)
            result_type = 'novelviewtime'

            testsavedir = os.path.join(
                basedir, expname, result_type + '_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            with torch.no_grad():
                ret = render_path(pose2render, time2render,
                                  hwf, args.chunk, render_kwargs_test, savedir=testsavedir)
            moviebase = os.path.join(
                testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))
            save_res(moviebase, ret)

        if i % args.i_testset == 0 and i > 0:

            # Change view and time.
            pose2render = torch.Tensor(poses)
            time2render = i_train / float(num_img) * 2. - 1.0
            result_type = 'testset'

            testsavedir = os.path.join(
                basedir, expname, result_type + '_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            with torch.no_grad():
                ret = render_path(pose2render, time2render,
                                  hwf, args.chunk, render_kwargs_test, savedir=testsavedir,
                                  flows_gt_f=grids[:, :, :, 2:4], flows_gt_b=grids[:, :, :, 5:7])
            moviebase = os.path.join(
                testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))
            save_res(moviebase, ret)

            # Fix view (first view) and change time.
            pose2render = torch.Tensor(poses[0:1, ...].expand([int(num_img), 3, 4]))
            time2render = i_train / float(num_img) * 2. - 1.0
            result_type = 'testset_view000'

            testsavedir = os.path.join(
                basedir, expname, result_type + '_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            with torch.no_grad():
                ret = render_path(pose2render, time2render,
                                  hwf, args.chunk, render_kwargs_test, savedir=testsavedir)
            moviebase = os.path.join(
                testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))
            save_res(moviebase, ret)

            # Fix time (the first timestamp) and change view.
            pose2render = torch.Tensor(poses)
            time2render = np.tile(i_train[0], [int(num_img)]) / float(num_img) * 2. - 1.0
            result_type = 'testset_time000'

            testsavedir = os.path.join(
                basedir, expname, result_type + '_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            with torch.no_grad():
                ret = render_path(pose2render, time2render,
                                  hwf, args.chunk, render_kwargs_test, savedir=testsavedir)
            moviebase = os.path.join(
                testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))
            save_res(moviebase, ret)

        if i % args.i_print == 0:
            writer.add_scalar("loss", loss.item(), i)
            writer.add_scalar("lr", new_lrate, i)
            writer.add_scalar("Temp", Temp, i)
            for loss_key in loss_dict:
                writer.add_scalar(loss_key, loss_dict[loss_key].item(), i)

        if i % args.i_img == 0:
            # Log a rendered training view to Tensorboard.
            # img_i = np.random.choice(i_train[1:-1])
            target = images[img_i]
            pose = poses[img_i, :3, :4]
            mask = masks[img_i]
            grid = grids[img_i]
            invdepth = invdepths[img_i]

            flow_f_img = flow_to_image(grid[..., 2:4].cpu().numpy())
            flow_b_img = flow_to_image(grid[..., 5:7].cpu().numpy())

            with torch.no_grad():
                ret = render(t,
                             False,
                             H, W, focal,
                             chunk=1024*16,
                             c2w=pose,
                             **render_kwargs_test)

                # The last frame does not have forward flow.
                pose_f = poses[min(img_i + 1, int(num_img) - 1), :3, :4]
                induced_flow_f = induce_flow(H, W, focal, pose_f, ret['weights_d'], ret['raw_pts_f'], grid[..., :2])

                # The first frame does not have backward flow.
                pose_b = poses[max(img_i - 1, 0), :3, :4]
                induced_flow_b = induce_flow(H, W, focal, pose_b, ret['weights_d'], ret['raw_pts_b'], grid[..., :2])

                induced_flow_f_img = flow_to_image(induced_flow_f.cpu().numpy())
                induced_flow_b_img = flow_to_image(induced_flow_b.cpu().numpy())

                psnr = mse2psnr(img2mse(ret['rgb_map_full'], target))

                # Save out the validation image for Tensorboard-free monitoring
                testimgdir = os.path.join(basedir, expname, 'tboard_val_imgs')
                if i == 0:
                    os.makedirs(testimgdir, exist_ok=True)
                imageio.imwrite(os.path.join(testimgdir, '{:06d}.png'.format(i)), to8b(ret['rgb_map_full'].cpu().numpy()))

                writer.add_scalar("psnr_holdout", psnr.item(), i)
                writer.add_image("rgb_holdout", target, global_step=i, dataformats='HWC')
                writer.add_image("mask", mask, global_step=i, dataformats='HW')
                writer.add_image("disp", torch.clamp(invdepth / percentile(invdepth, 97), 0., 1.), global_step=i, dataformats='HW')

                writer.add_image("rgb", torch.clamp(ret['rgb_map_full'], 0., 1.), global_step=i, dataformats='HWC')
                writer.add_image("depth", normalize_depth(ret['depth_map_full']), global_step=i, dataformats='HW')
                writer.add_image("acc", ret['acc_map_full'], global_step=i, dataformats='HW')

                writer.add_image("rgb_s", torch.clamp(ret['rgb_map_s'], 0., 1.), global_step=i, dataformats='HWC')
                writer.add_image("depth_s", normalize_depth(ret['depth_map_s']), global_step=i, dataformats='HW')
                writer.add_image("acc_s", ret['acc_map_s'], global_step=i, dataformats='HW')

                writer.add_image("rgb_d", torch.clamp(ret['rgb_map_d'], 0., 1.), global_step=i, dataformats='HWC')
                writer.add_image("depth_d", normalize_depth(ret['depth_map_d']), global_step=i, dataformats='HW')
                writer.add_image("acc_d", ret['acc_map_d'], global_step=i, dataformats='HW')

                writer.add_image("induced_flow_f", induced_flow_f_img, global_step=i, dataformats='HWC')
                writer.add_image("induced_flow_b", induced_flow_b_img, global_step=i, dataformats='HWC')
                writer.add_image("flow_f_gt", flow_f_img, global_step=i, dataformats='HWC')
                writer.add_image("flow_b_gt", flow_b_img, global_step=i, dataformats='HWC')

                writer.add_image("dynamicness", ret['dynamicness_map'], global_step=i, dataformats='HW')

        global_step += 1


if __name__ == '__main__':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    train()


================================================
FILE: run_nerf_helpers.py
================================================
import os
import torch
import imageio
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Misc utils
def img2mse(x, y, M=None):
    if M == None:
        return torch.mean((x - y) ** 2)
    else:
        return torch.sum((x - y) ** 2 * M) / (torch.sum(M) + 1e-8) / x.shape[-1]


def img2mae(x, y, M=None):
    if M == None:
        return torch.mean(torch.abs(x - y))
    else:
        return torch.sum(torch.abs(x - y) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]


def L1(x, M=None):
    if M == None:
        return torch.mean(torch.abs(x))
    else:
        return torch.sum(torch.abs(x) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]


def L2(x, M=None):
    if M == None:
        return torch.mean(x ** 2)
    else:
        return torch.sum((x ** 2) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]


def entropy(x):
    return -torch.sum(x * torch.log(x + 1e-19)) / x.shape[0]


def mse2psnr(x): return -10. * torch.log(x) / torch.log(torch.Tensor([10.]))


def to8b(x): return (255 * np.clip(x, 0, 1)).astype(np.uint8)


class Embedder:

    def __init__(self, **kwargs):

        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):

        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn,
                                 freq=freq : p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)


def get_embedder(multires, i=0, input_dims=3):

    if i == -1:
        return nn.Identity(), 3

    embed_kwargs = {
        'include_input': True,
        'input_dims': input_dims,
        'max_freq_log2': multires-1,
        'num_freqs': multires,
        'log_sampling': True,
        'periodic_fns': [torch.sin, torch.cos],
    }

    embedder_obj = Embedder(**embed_kwargs)
    def embed(x, eo=embedder_obj): return eo.embed(x)
    return embed, embedder_obj.out_dim


# Dynamic NeRF model architecture
class NeRF_d(nn.Module):
    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirsDyn=True):
        """
        """
        super(NeRF_d, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.skips = skips
        self.use_viewdirsDyn = use_viewdirsDyn

        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])

        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])

        if self.use_viewdirsDyn:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W//2, 3)
        else:
            self.output_linear = nn.Linear(W, output_ch)

        self.sf_linear = nn.Linear(W, 6)
        self.weight_linear = nn.Linear(W, 1)

    def forward(self, x):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        h = input_pts
        for i, l in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = F.relu(h)
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        # Scene flow should be unbounded. However, in NDC space the coordinate is
        # bounded in [-1, 1].
        sf = torch.tanh(self.sf_linear(h))
        blending = torch.sigmoid(self.weight_linear(h))

        if self.use_viewdirsDyn:
            alpha = self.alpha_linear(h)
            feature = self.feature_linear(h)
            h = torch.cat([feature, input_views], -1)

            for i, l in enumerate(self.views_linears):
                h = self.views_linears[i](h)
                h = F.relu(h)

            rgb = self.rgb_linear(h)
            outputs = torch.cat([rgb, alpha], -1)
        else:
            outputs = self.output_linear(h)

        return torch.cat([outputs, sf, blending], dim=-1)


# Static NeRF model architecture
class NeRF_s(nn.Module):
    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=True):
        """
        """
        super(NeRF_s, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.skips = skips
        self.use_viewdirs = use_viewdirs

        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])

        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])

        if self.use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W//2, 3)
        else:
            self.output_linear = nn.Linear(W, output_ch)

        self.weight_linear = nn.Linear(W, 1)

    def forward(self, x):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        h = input_pts
        for i, l in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = F.relu(h)
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        blending = torch.sigmoid(self.weight_linear(h))
        if self.use_viewdirs:
            alpha = self.alpha_linear(h)
            feature = self.feature_linear(h)
            h = torch.cat([feature, input_views], -1)

            for i, l in enumerate(self.views_linears):
                h = self.views_linears[i](h)
                h = F.relu(h)

            rgb = self.rgb_linear(h)
            outputs = torch.cat([rgb, alpha], -1)
        else:
            outputs = self.output_linear(h)

        return torch.cat([outputs, blending], -1)


def batchify(fn, chunk):
    """Constructs a version of 'fn' that applies to smaller batches.
    """
    if chunk is None:
        return fn

    def ret(inputs):
        return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
    return ret


def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """Prepares inputs and applies network 'fn'.
    """

    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])

    embedded = embed_fn(inputs_flat)
    if viewdirs is not None:
        input_dirs = viewdirs[:, None].expand(inputs[:, :, :3].shape)
        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = embeddirs_fn(input_dirs_flat)
        embedded = torch.cat([embedded, embedded_dirs], -1)

    outputs_flat = batchify(fn, netchunk)(embedded)
    outputs = torch.reshape(outputs_flat, list(
        inputs.shape[:-1]) + [outputs_flat.shape[-1]])
    return outputs


def create_nerf(args):
    """Instantiate NeRF's MLP model.
    """

    embed_fn_d, input_ch_d = get_embedder(args.multires, args.i_embed, 4)
    # 10 * 2 * 4 + 4 = 84
    # L * (sin, cos) * (x, y, z, t) + (x, y, z, t)

    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(
            args.multires_views, args.i_embed, 3)
        # 4 * 2 * 3 + 3 = 27
        # L * (sin, cos) * (3 Cartesian viewing direction unit vector from [theta, phi]) + (3 Cartesian viewing direction unit vector from [theta, phi])
    output_ch = 5 if args.N_importance > 0 else 4
    skips = [4]
    model_d = NeRF_d(D=args.netdepth, W=args.netwidth,
                     input_ch=input_ch_d, output_ch=output_ch, skips=skips,
                     input_ch_views=input_ch_views,
                     use_viewdirsDyn=args.use_viewdirsDyn).to(device)

    device_ids = list(range(torch.cuda.device_count()))
    model_d = torch.nn.DataParallel(model_d, device_ids=device_ids)
    grad_vars = list(model_d.parameters())

    embed_fn_s, input_ch_s = get_embedder(args.multires, args.i_embed, 3)
    # 10 * 2 * 3 + 3 = 63
    # L * (sin, cos) * (x, y, z) + (x, y, z)

    model_s = NeRF_s(D=args.netdepth, W=args.netwidth,
                     input_ch=input_ch_s, output_ch=output_ch, skips=skips,
                     input_ch_views=input_ch_views,
                     use_viewdirs=args.use_viewdirs).to(device)

    model_s = torch.nn.DataParallel(model_s, device_ids=device_ids)
    grad_vars += list(model_s.parameters())

    model_fine = None
    if args.N_importance > 0:
        raise NotImplementedError

    def network_query_fn_d(inputs, viewdirs, network_fn): return run_network(
        inputs, viewdirs, network_fn,
        embed_fn=embed_fn_d,
        embeddirs_fn=embeddirs_fn,
        netchunk=args.netchunk)

    def network_query_fn_s(inputs, viewdirs, network_fn): return run_network(
        inputs, viewdirs, network_fn,
        embed_fn=embed_fn_s,
        embeddirs_fn=embeddirs_fn,
        netchunk=args.netchunk)

    render_kwargs_train = {
        'network_query_fn_d': network_query_fn_d,
        'network_query_fn_s': network_query_fn_s,
        'network_fn_d': model_d,
        'network_fn_s': model_s,
        'perturb': args.perturb,
        'N_importance': args.N_importance,
        'N_samples': args.N_samples,
        'use_viewdirs': args.use_viewdirs,
        'raw_noise_std': args.raw_noise_std,
        'inference': False,
        'DyNeRF_blending': args.DyNeRF_blending,
    }

    # NDC only good for LLFF-style forward facing data
    if args.dataset_type != 'llff' or args.no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp
    else:
        render_kwargs_train['ndc'] = True

    render_kwargs_test = {
        k: render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.
    render_kwargs_test['inference'] = True

    # Create optimizer
    optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

    start = 0
    basedir = args.basedir
    expname = args.expname

    if args.ft_path is not None and args.ft_path != 'None':
        ckpts = [args.ft_path]
    else:
        ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]
    print('Found ckpts', ckpts)
    if len(ckpts) > 0 and not args.no_reload:
        ckpt_path = ckpts[-1]
        print('Reloading from', ckpt_path)
        ckpt = torch.load(ckpt_path)

        start = ckpt['global_step'] + 1
        # optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        model_d.load_state_dict(ckpt['network_fn_d_state_dict'])
        model_s.load_state_dict(ckpt['network_fn_s_state_dict'])
        print('Resetting step to', start)

        if model_fine is not None:
            raise NotImplementedError

    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer


# Ray helpers
def get_rays(H, W, focal, c2w):
    """Get ray origins, directions from a pinhole camera."""
    i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij'
    i = i.t()
    j = j.t()
    dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
    # Rotate ray directions from camera frame to the world frame
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d


def ndc_rays(H, W, focal, near, rays_o, rays_d):
    """Normalized device coordinate rays.
    Space such that the canvas is a cube with sides [-1, 1] in each axis.
    Args:
      H: int. Height in pixels.
      W: int. Width in pixels.
      focal: float. Focal length of pinhole camera.
      near: float or array of shape[batch_size]. Near depth bound for the scene.
      rays_o: array of shape [batch_size, 3]. Camera origin.
      rays_d: array of shape [batch_size, 3]. Ray direction.
    Returns:
      rays_o: array of shape [batch_size, 3]. Camera origin in NDC.
      rays_d: array of shape [batch_size, 3]. Ray direction in NDC.
    """
    # Shift ray origins to near plane
    t = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d

    # Projection
    o0 = -1./(W/(2.*focal)) * rays_o[..., 0] / rays_o[..., 2]
    o1 = -1./(H/(2.*focal)) * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1. + 2. * near / rays_o[..., 2]

    d0 = -1./(W/(2.*focal)) * \
        (rays_d[..., 0]/rays_d[..., 2] - rays_o[..., 0]/rays_o[..., 2])
    d1 = -1./(H/(2.*focal)) * \
    (rays_d[..., 1]/rays_d[..., 2] - rays_o[..., 1]/rays_o[..., 2])
    d2 = -2. * near / rays_o[..., 2]

    rays_o = torch.stack([o0, o1, o2], -1)
    rays_d = torch.stack([d0, d1, d2], -1)

    return rays_o, rays_d


def get_grid(H, W, num_img, flows_f, flow_masks_f, flows_b, flow_masks_b):

    # |--------------------|  |--------------------|
    # |       j            |  |       v            |
    # |   i   *            |  |   u   *            |
    # |                    |  |                    |
    # |--------------------|  |--------------------|

    i, j = np.meshgrid(np.arange(W, dtype=np.float32),
                       np.arange(H, dtype=np.float32), indexing='xy')

    grid = np.empty((0, H, W, 8), np.float32)
    for idx in range(num_img):
        grid = np.concatenate((grid, np.stack([i,
                                               j,
                                               flows_f[idx, :, :, 0],
                                               flows_f[idx, :, :, 1],
                                               flow_masks_f[idx, :, :],
                                               flows_b[idx, :, :, 0],
                                               flows_b[idx, :, :, 1],
                                               flow_masks_b[idx, :, :]], -1)[None, ...]))
    return grid


def NDC2world(pts, H, W, f):

    # NDC coordinate to world coordinate
    pts_z = 2 / (torch.clamp(pts[..., 2:], min=-1., max=1-1e-3) - 1)
    pts_x = - pts[..., 0:1] * pts_z * W / 2 / f
    pts_y = - pts[..., 1:2] * pts_z * H / 2 / f
    pts_world = torch.cat([pts_x, pts_y, pts_z], -1)

    return pts_world


def render_3d_point(H, W, f, pose, weights, pts):
    """Render 3D position along each ray and project it to the image plane.
    """

    c2w = pose
    w2c = c2w[:3, :3].transpose(0, 1) # same as np.linalg.inv(c2w[:3, :3])

    # Rendered 3D position in NDC coordinate
    pts_map_NDC = torch.sum(weights[..., None] * pts, -2)

    # NDC coordinate to world coordinate
    pts_map_world = NDC2world(pts_map_NDC, H, W, f)

    # World coordinate to camera coordinate
    # Translate
    pts_map_world = pts_map_world - c2w[:, 3]
    # Rotate
    pts_map_cam = torch.sum(pts_map_world[..., None, :] * w2c[:3, :3], -1)

    # Camera coordinate to 2D image coordinate
    pts_plane = torch.cat([pts_map_cam[..., 0:1] / (- pts_map_cam[..., 2:]) * f + W * .5,
                         - pts_map_cam[..., 1:2] / (- pts_map_cam[..., 2:]) * f + H * .5],
                         -1)

    return pts_plane


def induce_flow(H, W, focal, pose_neighbor, weights, pts_3d_neighbor, pts_2d):

    # Render 3D position along each ray and project it to the neighbor frame's image plane.
    pts_2d_neighbor = render_3d_point(H, W, focal,
                                      pose_neighbor,
                                      weights,
                                      pts_3d_neighbor)
    induced_flow = pts_2d_neighbor - pts_2d

    return induced_flow


def compute_depth_loss(dyn_depth, gt_depth):

    t_d = torch.median(dyn_depth)
    s_d = torch.mean(torch.abs(dyn_depth - t_d))
    dyn_depth_norm = (dyn_depth - t_d) / s_d

    t_gt = torch.median(gt_depth)
    s_gt = torch.mean(torch.abs(gt_depth - t_gt))
    gt_depth_norm = (gt_depth - t_gt) / s_gt

    return torch.mean((dyn_depth_norm - gt_depth_norm) ** 2)


def normalize_depth(depth):
    return torch.clamp(depth / percentile(depth, 97), 0., 1.)


def percentile(t, q):
    """
    Return the ``q``-th percentile of the flattened input tensor's data.

    CAUTION:
     * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
     * Values are not interpolated, which corresponds to
       ``numpy.percentile(..., interpolation="nearest")``.

    :param t: Input tensor.
    :param q: Percentile to compute, which must be between 0 and 100 inclusive.
    :return: Resulting value (scalar).
    """

    k = 1 + round(.01 * float(q) * (t.numel() - 1))
    result = t.view(-1).kthvalue(k).values.item()
    return result


def save_res(moviebase, ret, fps=None):

    if fps == None:
        if len(ret['rgbs']) < 25:
            fps = 4
        else:
            fps = 24

    for k in ret:
        if 'rgbs' in k:
            imageio.mimwrite(moviebase + k + '.mp4',
                             to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
            # imageio.mimsave(moviebase + k + '.gif',
            #                  to8b(ret[k]), format='gif', fps=fps)
        elif 'depths' in k:
            imageio.mimwrite(moviebase + k + '.mp4',
                             to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
            # imageio.mimsave(moviebase + k + '.gif',
            #                  to8b(ret[k]), format='gif', fps=fps)
        elif 'disps' in k:
            imageio.mimwrite(moviebase + k + '.mp4',
                             to8b(ret[k] / np.max(ret[k])), fps=fps, quality=8, macro_block_size=1)
            # imageio.mimsave(moviebase + k + '.gif',
            #                  to8b(ret[k] / np.max(ret[k])), format='gif', fps=fps)
        elif 'sceneflow_' in k:
            imageio.mimwrite(moviebase + k + '.mp4',
                             to8b(norm_sf(ret[k])), fps=fps, quality=8, macro_block_size=1)
            # imageio.mimsave(moviebase + k + '.gif',
            #                  to8b(norm_sf(ret[k])), format='gif', fps=fps)
        elif 'flows' in k:
            imageio.mimwrite(moviebase + k + '.mp4',
                             ret[k], fps=fps, quality=8, macro_block_size=1)
            # imageio.mimsave(moviebase + k + '.gif',
            #                  ret[k], format='gif', fps=fps)
        elif 'dynamicness' in k:
            imageio.mimwrite(moviebase + k + '.mp4',
                             to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
            # imageio.mimsave(moviebase + k + '.gif',
            #                  to8b(ret[k]), format='gif', fps=fps)
        elif 'disocclusions' in k:
            imageio.mimwrite(moviebase + k + '.mp4',
                             to8b(ret[k][..., 0]), fps=fps, quality=8, macro_block_size=1)
            # imageio.mimsave(moviebase + k + '.gif',
            #                  to8b(ret[k][..., 0]), format='gif', fps=fps)
        elif 'blending' in k:
            blending = ret[k][..., None]
            blending = np.moveaxis(blending, [0, 1, 2, 3], [1, 2, 0, 3])
            imageio.mimwrite(moviebase + k + '.mp4',
                             to8b(blending), fps=fps, quality=8, macro_block_size=1)
            # imageio.mimsave(moviebase + k + '.gif',
            #                  to8b(blending), format='gif', fps=fps)
        elif 'weights' in k:
            imageio.mimwrite(moviebase + k + '.mp4',
                             to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
        else:
            raise NotImplementedError


def norm_sf_channel(sf_ch):

    # Make sure zero scene flow is not shifted
    sf_ch[sf_ch >= 0] = sf_ch[sf_ch >= 0] / sf_ch.max() / 2
    sf_ch[sf_ch < 0] = sf_ch[sf_ch < 0] / np.abs(sf_ch.min()) / 2
    sf_ch = sf_ch + 0.5
    return sf_ch


def norm_sf(sf):

    sf = np.concatenate((norm_sf_channel(sf[..., 0:1]),
                         norm_sf_channel(sf[..., 1:2]),
                         norm_sf_channel(sf[..., 2:3])), -1)
    sf = np.moveaxis(sf, [0, 1, 2, 3], [1, 2, 0, 3])
    return sf


# Spatial smoothness (adapted from NSFF)
def compute_sf_smooth_s_loss(pts1, pts2, H, W, f):

    N_samples = pts1.shape[1]

    # NDC coordinate to world coordinate
    pts1_world = NDC2world(pts1[..., :int(N_samples * 0.95), :], H, W, f)
    pts2_world = NDC2world(pts2[..., :int(N_samples * 0.95), :], H, W, f)

    # scene flow in world coordinate
    scene_flow_world = pts1_world - pts2_world

    return L1(scene_flow_world[..., :-1, :] - scene_flow_world[..., 1:, :])


# Temporal smoothness
def compute_sf_smooth_loss(pts, pts_f, pts_b, H, W, f):

    N_samples = pts.shape[1]

    pts_world   = NDC2world(pts[..., :int(N_samples * 0.9), :],   H, W, f)
    pts_f_world = NDC2world(pts_f[..., :int(N_samples * 0.9), :], H, W, f)
    pts_b_world = NDC2world(pts_b[..., :int(N_samples * 0.9), :], H, W, f)

    # scene flow in world coordinate
    sceneflow_f = pts_f_world - pts_world
    sceneflow_b = pts_b_world - pts_world

    # For a 3D point, its forward and backward sceneflow should be opposite.
    return L2(sceneflow_f + sceneflow_b)


================================================
FILE: utils/RAFT/__init__.py
================================================
# from .demo import RAFT_infer
from .raft import RAFT


================================================
FILE: utils/RAFT/corr.py
================================================
import torch
import torch.nn.functional as F
from .utils.utils import bilinear_sampler, coords_grid

try:
    import alt_cuda_corr
except:
    # alt_cuda_corr is not compiled
    pass


class CorrBlock:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch*h1*w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels-1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            dx = torch.linspace(-r, r, 2*r+1)
            dy = torch.linspace(-r, r, 2*r+1)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)

            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht*wd)
        fmap2 = fmap2.view(batch, dim, ht*wd)

        corr = torch.matmul(fmap1.transpose(1,2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr  / torch.sqrt(torch.tensor(dim).float())


class CorrLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fmap1, fmap2, coords, r):
        fmap1 = fmap1.contiguous()
        fmap2 = fmap2.contiguous()
        coords = coords.contiguous()
        ctx.save_for_backward(fmap1, fmap2, coords)
        ctx.r = r
        corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
        return corr

    @staticmethod
    def backward(ctx, grad_corr):
        fmap1, fmap2, coords = ctx.saved_tensors
        grad_corr = grad_corr.contiguous()
        fmap1_grad, fmap2_grad, coords_grad = \
            correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
        return fmap1_grad, fmap2_grad, coords_grad, None


class AlternateCorrBlock:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):

        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / 16.0


================================================
FILE: utils/RAFT/datasets.py
================================================
# Data loading based on https://github.com/NVIDIA/flownet2-pytorch

import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F

import os
import math
import random
from glob import glob
import os.path as osp

from utils import frame_utils
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor


class FlowDataset(data.Dataset):
    def __init__(self, aug_params=None, sparse=False):
        self.augmentor = None
        self.sparse = sparse
        if aug_params is not None:
            if sparse:
                self.augmentor = SparseFlowAugmentor(**aug_params)
            else:
                self.augmentor = FlowAugmentor(**aug_params)

        self.is_test = False
        self.init_seed = False
        self.flow_list = []
        self.image_list = []
        self.extra_info = []

    def __getitem__(self, index):

        if self.is_test:
            img1 = frame_utils.read_gen(self.image_list[index][0])
            img2 = frame_utils.read_gen(self.image_list[index][1])
            img1 = np.array(img1).astype(np.uint8)[..., :3]
            img2 = np.array(img2).astype(np.uint8)[..., :3]
            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
            return img1, img2, self.extra_info[index]

        if not self.init_seed:
            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                torch.manual_seed(worker_info.id)
                np.random.seed(worker_info.id)
                random.seed(worker_info.id)
                self.init_seed = True

        index = index % len(self.image_list)
        valid = None
        if self.sparse:
            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
        else:
            flow = frame_utils.read_gen(self.flow_list[index])

        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])

        flow = np.array(flow).astype(np.float32)
        img1 = np.array(img1).astype(np.uint8)
        img2 = np.array(img2).astype(np.uint8)

        # grayscale images
        if len(img1.shape) == 2:
            img1 = np.tile(img1[...,None], (1, 1, 3))
            img2 = np.tile(img2[...,None], (1, 1, 3))
        else:
            img1 = img1[..., :3]
            img2 = img2[..., :3]

        if self.augmentor is not None:
            if self.sparse:
                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
            else:
                img1, img2, flow = self.augmentor(img1, img2, flow)

        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
        flow = torch.from_numpy(flow).permute(2, 0, 1).float()

        if valid is not None:
            valid = torch.from_numpy(valid)
        else:
            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)

        return img1, img2, flow, valid.float()


    def __rmul__(self, v):
        self.flow_list = v * self.flow_list
        self.image_list = v * self.image_list
        return self
        
    def __len__(self):
        return len(self.image_list)
        

class MpiSintel(FlowDataset):
    def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
        super(MpiSintel, self).__init__(aug_params)
        flow_root = osp.join(root, split, 'flow')
        image_root = osp.join(root, split, dstype)

        if split == 'test':
            self.is_test = True

        for scene in os.listdir(image_root):
            image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
            for i in range(len(image_list)-1):
                self.image_list += [ [image_list[i], image_list[i+1]] ]
                self.extra_info += [ (scene, i) ] # scene and frame_id

            if split != 'test':
                self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))


class FlyingChairs(FlowDataset):
    def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
        super(FlyingChairs, self).__init__(aug_params)

        images = sorted(glob(osp.join(root, '*.ppm')))
        flows = sorted(glob(osp.join(root, '*.flo')))
        assert (len(images)//2 == len(flows))

        split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
        for i in range(len(flows)):
            xid = split_list[i]
            if (split=='training' and xid==1) or (split=='validation' and xid==2):
                self.flow_list += [ flows[i] ]
                self.image_list += [ [images[2*i], images[2*i+1]] ]


class FlyingThings3D(FlowDataset):
    def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
        super(FlyingThings3D, self).__init__(aug_params)

        for cam in ['left']:
            for direction in ['into_future', 'into_past']:
                image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
                image_dirs = sorted([osp.join(f, cam) for f in image_dirs])

                flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
                flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])

                for idir, fdir in zip(image_dirs, flow_dirs):
                    images = sorted(glob(osp.join(idir, '*.png')) )
                    flows = sorted(glob(osp.join(fdir, '*.pfm')) )
                    for i in range(len(flows)-1):
                        if direction == 'into_future':
                            self.image_list += [ [images[i], images[i+1]] ]
                            self.flow_list += [ flows[i] ]
                        elif direction == 'into_past':
                            self.image_list += [ [images[i+1], images[i]] ]
                            self.flow_list += [ flows[i+1] ]
      

class KITTI(FlowDataset):
    def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
        super(KITTI, self).__init__(aug_params, sparse=True)
        if split == 'testing':
            self.is_test = True

        root = osp.join(root, split)
        images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
        images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))

        for img1, img2 in zip(images1, images2):
            frame_id = img1.split('/')[-1]
            self.extra_info += [ [frame_id] ]
            self.image_list += [ [img1, img2] ]

        if split == 'training':
            self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))


class HD1K(FlowDataset):
    def __init__(self, aug_params=None, root='datasets/HD1k'):
        super(HD1K, self).__init__(aug_params, sparse=True)

        seq_ix = 0
        while 1:
            flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
            images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))

            if len(flows) == 0:
                break

            for i in range(len(flows)-1):
                self.flow_list += [flows[i]]
                self.image_list += [ [images[i], images[i+1]] ]

            seq_ix += 1


def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
    """ Create the data loader for the corresponding trainign set """

    if args.stage == 'chairs':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
        train_dataset = FlyingChairs(aug_params, split='training')
    
    elif args.stage == 'things':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
        train_dataset = clean_dataset + final_dataset

    elif args.stage == 'sintel':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
        sintel_final = MpiSintel(aug_params, split='training', dstype='final')        

        if TRAIN_DS == 'C+T+K+S+H':
            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things

        elif TRAIN_DS == 'C+T+K/S':
            train_dataset = 100*sintel_clean + 100*sintel_final + things

    elif args.stage == 'kitti':
        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
        train_dataset = KITTI(aug_params, split='training')

    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 
        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)

    print('Training with %d image pairs' % len(train_dataset))
    return train_loader



================================================
FILE: utils/RAFT/demo.py
================================================
import sys
import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image

from .raft import RAFT
from .utils import flow_viz
from .utils.utils import InputPadder



DEVICE = 'cuda'

def load_image(imfile):
    img = np.array(Image.open(imfile)).astype(np.uint8)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img


def load_image_list(image_files):
    images = []
    for imfile in sorted(image_files):
        images.append(load_image(imfile))

    images = torch.stack(images, dim=0)
    images = images.to(DEVICE)

    padder = InputPadder(images.shape)
    return padder.pad(images)[0]


def viz(img, flo):
    img = img[0].permute(1,2,0).cpu().numpy()
    flo = flo[0].permute(1,2,0).cpu().numpy()

    # map flow to rgb image
    flo = flow_viz.flow_to_image(flo)
    # img_flo = np.concatenate([img, flo], axis=0)
    img_flo = flo

    cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]])
    # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
    # cv2.waitKey()


def demo(args):
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))

    model = model.module
    model.to(DEVICE)
    model.eval()

    with torch.no_grad():
        images = glob.glob(os.path.join(args.path, '*.png')) + \
                 glob.glob(os.path.join(args.path, '*.jpg'))

        images = load_image_list(images)
        for i in range(images.shape[0]-1):
            image1 = images[i,None]
            image2 = images[i+1,None]

            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
            viz(image1, flow_up)


def RAFT_infer(args):
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))

    model = model.module
    model.to(DEVICE)
    model.eval()

    return model


================================================
FILE: utils/RAFT/extractor.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()
  
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
        
        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)
        
        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None
        
        else:    
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)


    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)



class BottleneckBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()
  
        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
        
        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes//4)
            self.norm2 = nn.BatchNorm2d(planes//4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)
        
        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes//4)
            self.norm2 = nn.InstanceNorm2d(planes//4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()

        if stride == 1:
            self.downsample = None
        
        else:    
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)


    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)

class BasicEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
            
        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64,  stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)
        
        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x


class SmallEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
            
        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32,  stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)
        
        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)
    
        self.in_planes = dim
        return nn.Sequential(*layers)


    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x


================================================
FILE: utils/RAFT/raft.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .update import BasicUpdateBlock, SmallUpdateBlock
from .extractor import BasicEncoder, SmallEncoder
from .corr import CorrBlock, AlternateCorrBlock
from .utils.utils import bilinear_sampler, coords_grid, upflow8

try:
    autocast = torch.cuda.amp.autocast
except:
    # dummy autocast for PyTorch < 1.6
    class autocast:
        def __init__(self, enabled):
            pass
        def __enter__(self):
            pass
        def __exit__(self, *args):
            pass


class RAFT(nn.Module):
    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            args.corr_levels = 4
            args.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            args.corr_levels = 4
            args.corr_radius = 4

        if 'dropout' not in args._get_kwargs():
            args.dropout = 0

        if 'alternate_corr' not in args._get_kwargs():
            args.alternate_corr = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)


    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H//8, W//8).to(img.device)
        coords1 = coords_grid(N, H//8, W//8).to(img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(8 * flow, [3,3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8*H, 8*W)


    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """

        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()
            corr = corr_fn(coords1) # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        if test_mode:
            return coords1 - coords0, flow_up

        return flow_predictions


================================================
FILE: utils/RAFT/update.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F


class FlowHead(nn.Module):
    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))

class ConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(ConvGRU, self).__init__()
        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)

        z = torch.sigmoid(self.convz(hx))
        r = torch.sigmoid(self.convr(hx))
        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))

        h = (1-z) * h + z * q
        return h

class SepConvGRU(nn.Module):
    def __init__(self, hidden_dim=128, input_dim=192+128):
        super(SepConvGRU, self).__init__()
        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))

        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))


    def forward(self, h, x):
        # horizontal
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))        
        h = (1-z) * h + z * q

        # vertical
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))       
        h = (1-z) * h + z * q

        return h

class SmallMotionEncoder(nn.Module):
    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        cor = F.relu(self.convc1(corr))
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))
        cor_flo = torch.cat([cor, flo], dim=1)
        out = F.relu(self.conv(cor_flo))
        return torch.cat([out, flow], dim=1)

class BasicMotionEncoder(nn.Module):
    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))

        cor_flo = torch.cat([cor, flo], dim=1)
        out = F.relu(self.conv(cor_flo))
        return torch.cat([out, flow], dim=1)

class SmallUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)
        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        return net, None, delta_flow

class BasicUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # scale mask to balence gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow





================================================
FILE: utils/RAFT/utils/__init__.py
================================================
from .flow_viz import flow_to_image
from .frame_utils import writeFlow


================================================
FILE: utils/RAFT/utils/augmentor.py
================================================
import numpy as np
import random
import math
from PIL import Image

import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

import torch
from torchvision.transforms import ColorJitter
import torch.nn.functional as F


class FlowAugmentor:
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
        
        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5

    def color_transform(self, img1, img2):
        """ Photometric augmentation """

        # asymmetric
        if np.random.rand() < self.asymmetric_color_aug_prob:
            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)

        # symmetric
        else:
            image_stack = np.concatenate([img1, img2], axis=0)
            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
            img1, img2 = np.split(image_stack, 2, axis=0)

        return img1, img2

    def eraser_transform(self, img1, img2, bounds=[50, 100]):
        """ Occlusion augmentation """

        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(bounds[0], bounds[1])
                dy = np.random.randint(bounds[0], bounds[1])
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def spatial_transform(self, img1, img2, flow):
        # randomly sample scale
        ht, wd = img1.shape[:2]
        min_scale = np.maximum(
            (self.crop_size[0] + 8) / float(ht), 
            (self.crop_size[1] + 8) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = scale
        scale_y = scale
        if np.random.rand() < self.stretch_prob:
            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
        
        scale_x = np.clip(scale_x, min_scale, None)
        scale_y = np.clip(scale_y, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow = flow * [scale_x, scale_y]

        if self.do_flip:
            if np.random.rand() < self.h_flip_prob: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]

            if np.random.rand() < self.v_flip_prob: # v-flip
                img1 = img1[::-1, :]
                img2 = img2[::-1, :]
                flow = flow[::-1, :] * [1.0, -1.0]

        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
        
        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]

        return img1, img2, flow

    def __call__(self, img1, img2, flow):
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow = self.spatial_transform(img1, img2, flow)

        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)

        return img1, img2, flow

class SparseFlowAugmentor:
    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
        # spatial augmentation params
        self.crop_size = crop_size
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.spatial_aug_prob = 0.8
        self.stretch_prob = 0.8
        self.max_stretch = 0.2

        # flip augmentation params
        self.do_flip = do_flip
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.1

        # photometric augmentation params
        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
        self.asymmetric_color_aug_prob = 0.2
        self.eraser_aug_prob = 0.5
        
    def color_transform(self, img1, img2):
        image_stack = np.concatenate([img1, img2], axis=0)
        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
        img1, img2 = np.split(image_stack, 2, axis=0)
        return img1, img2

    def eraser_transform(self, img1, img2):
        ht, wd = img1.shape[:2]
        if np.random.rand() < self.eraser_aug_prob:
            mean_color = np.mean(img2.reshape(-1, 3), axis=0)
            for _ in range(np.random.randint(1, 3)):
                x0 = np.random.randint(0, wd)
                y0 = np.random.randint(0, ht)
                dx = np.random.randint(50, 100)
                dy = np.random.randint(50, 100)
                img2[y0:y0+dy, x0:x0+dx, :] = mean_color

        return img1, img2

    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
        ht, wd = flow.shape[:2]
        coords = np.meshgrid(np.arange(wd), np.arange(ht))
        coords = np.stack(coords, axis=-1)

        coords = coords.reshape(-1, 2).astype(np.float32)
        flow = flow.reshape(-1, 2).astype(np.float32)
        valid = valid.reshape(-1).astype(np.float32)

        coords0 = coords[valid>=1]
        flow0 = flow[valid>=1]

        ht1 = int(round(ht * fy))
        wd1 = int(round(wd * fx))

        coords1 = coords0 * [fx, fy]
        flow1 = flow0 * [fx, fy]

        xx = np.round(coords1[:,0]).astype(np.int32)
        yy = np.round(coords1[:,1]).astype(np.int32)

        v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
        xx = xx[v]
        yy = yy[v]
        flow1 = flow1[v]

        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
        valid_img = np.zeros([ht1, wd1], dtype=np.int32)

        flow_img[yy, xx] = flow1
        valid_img[yy, xx] = 1

        return flow_img, valid_img

    def spatial_transform(self, img1, img2, flow, valid):
        # randomly sample scale

        ht, wd = img1.shape[:2]
        min_scale = np.maximum(
            (self.crop_size[0] + 1) / float(ht), 
            (self.crop_size[1] + 1) / float(wd))

        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
        scale_x = np.clip(scale, min_scale, None)
        scale_y = np.clip(scale, min_scale, None)

        if np.random.rand() < self.spatial_aug_prob:
            # rescale the images
            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
            flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)

        if self.do_flip:
            if np.random.rand() < 0.5: # h-flip
                img1 = img1[:, ::-1]
                img2 = img2[:, ::-1]
                flow = flow[:, ::-1] * [-1.0, 1.0]
                valid = valid[:, ::-1]

        margin_y = 20
        margin_x = 50

        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
        x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)

        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])

        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        return img1, img2, flow, valid


    def __call__(self, img1, img2, flow, valid):
        img1, img2 = self.color_transform(img1, img2)
        img1, img2 = self.eraser_transform(img1, img2)
        img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)

        img1 = np.ascontiguousarray(img1)
        img2 = np.ascontiguousarray(img2)
        flow = np.ascontiguousarray(flow)
        valid = np.ascontiguousarray(valid)

        return img1, img2, flow, valid


================================================
FILE: utils/RAFT/utils/flow_viz.py
================================================
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization


# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-08-03

import numpy as np

def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
    col = col+RY
    # YG
    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
    colorwheel[col:col+YG, 1] = 255
    col = col+YG
    # GC
    colorwheel[col:col+GC, 1] = 255
    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
    col = col+GC
    # CB
    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
    colorwheel[col:col+CB, 2] = 255
    col = col+CB
    # BM
    colorwheel[col:col+BM, 2] = 255
    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
    col = col+BM
    # MR
    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
    colorwheel[col:col+MR, 0] = 255
    return colorwheel


def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    colorwheel = make_colorwheel()  # shape [55x3]
    ncols = colorwheel.shape[0]
    rad = np.sqrt(np.square(u) + np.square(v))
    a = np.arctan2(-v, -u)/np.pi
    fk = (a+1) / 2*(ncols-1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0
    f = fk - k0
    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:,i]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1-f)*col0 + f*col1
        idx = (rad <= 1)
        col[idx]  = 1 - rad[idx] * (1-col[idx])
        col[~idx] = col[~idx] * 0.75   # out of range
        # Note the 2-i => BGR instead of RGB
        ch_idx = 2-i if convert_to_bgr else i
        flow_image[:,:,ch_idx] = np.floor(255 * col)
    return flow_image


def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Expects a two dimensional flow image of shape.

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)
    u = flow_uv[:,:,0]
    v = flow_uv[:,:,1]
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)
    epsilon = 1e-5
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)

================================================
FILE: utils/RAFT/utils/frame_utils.py
================================================
import numpy as np
from PIL import Image
from os.path import *
import re

import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

TAG_CHAR = np.array([202021.25], np.float32)

def readFlow(fn):
    """ Read .flo file in Middlebury format"""
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy

    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
    # print 'fn = %s'%(fn)
    with open(fn, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        if 202021.25 != magic:
            print('Magic number incorrect. Invalid .flo file')
            return None
        else:
            w = np.fromfile(f, np.int32, count=1)
            h = np.fromfile(f, np.int32, count=1)
            # print 'Reading %d x %d flo file\n' % (w, h)
            data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
            # Reshape data into 3D array (columns, rows, bands)
            # The reshape here is for visualization, the original code is (w,h,2)
            return np.resize(data, (int(h), int(w), 2))

def readPFM(file):
    file = open(file, 'rb')

    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().rstrip()
    if header == b'PF':
        color = True
    elif header == b'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0: # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>' # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data

def writeFlow(filename,uv,v=None):
    """ Write optical flow to file.
    
    If v is None, uv is assumed to contain both u and v channels,
    stacked in depth.
    Original code by Deqing Sun, adapted from Daniel Scharstein.
    """
    nBands = 2

    if v is None:
        assert(uv.ndim == 3)
        assert(uv.shape[2] == 2)
        u = uv[:,:,0]
        v = uv[:,:,1]
    else:
        u = uv

    assert(u.shape == v.shape)
    height,width = u.shape
    f = open(filename,'wb')
    # write the header
    f.write(TAG_CHAR)
    np.array(width).astype(np.int32).tofile(f)
    np.array(height).astype(np.int32).tofile(f)
    # arrange into matrix form
    tmp = np.zeros((height, width*nBands))
    tmp[:,np.arange(width)*2] = u
    tmp[:,np.arange(width)*2 + 1] = v
    tmp.astype(np.float32).tofile(f)
    f.close()


def readFlowKITTI(filename):
    flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
    flow = flow[:,:,::-1].astype(np.float32)
    flow, valid = flow[:, :, :2], flow[:, :, 2]
    flow = (flow - 2**15) / 64.0
    return flow, valid

def readDispKITTI(filename):
    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
    valid = disp > 0.0
    flow = np.stack([-disp, np.zeros_like(disp)], -1)
    return flow, valid


def writeFlowKITTI(filename, uv):
    uv = 64.0 * uv + 2**15
    valid = np.ones([uv.shape[0], uv.shape[1], 1])
    uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
    cv2.imwrite(filename, uv[..., ::-1])
    

def read_gen(file_name, pil=False):
    ext = splitext(file_name)[-1]
    if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
        return Image.open(file_name)
    elif ext == '.bin' or ext == '.raw':
        return np.load(file_name)
    elif ext == '.flo':
        return readFlow(file_name).astype(np.float32)
    elif ext == '.pfm':
        flow = readPFM(file_name).astype(np.float32)
        if len(flow.shape) == 2:
            return flow
        else:
            return flow[:, :, :-1]
    return []

================================================
FILE: utils/RAFT/utils/utils.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate


class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """
    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        if mode == 'sintel':
            self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
        else:
            self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]

    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    def unpad(self,x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]

def forward_interpolate(flow):
    flow = flow.detach().cpu().numpy()
    dx, dy = flow[0], flow[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    x1 = x0 + dx
    y1 = y0 + dy
    
    x1 = x1.reshape(-1)
    y1 = y1.reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1 = x1[valid]
    y1 = y1[valid]
    dx = dx[valid]
    dy = dy[valid]

    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)

    flow_y = interpolate.griddata(
        (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)

    flow = np.stack([flow_x, flow_y], axis=0)
    return torch.from_numpy(flow).float()


def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1,1], dim=-1)
    xgrid = 2*xgrid/(W-1) - 1
    ygrid = 2*ygrid/(H-1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img


def coords_grid(batch, ht, wd):
    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)


def upflow8(flow, mode='bilinear'):
    new_size = (8 * flow.shape[2], 8 * flow.shape[3])
    return  8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)


================================================
FILE: utils/colmap_utils.py
================================================
# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#
#     * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
#       its contributors may be used to endorse or promote products derived
#       from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Author: Johannes L. Schoenberger (jsch at inf.ethz.ch)

import os
import sys
import collections
import numpy as np
import struct


CameraModel = collections.namedtuple(
    "CameraModel", ["model_id", "model_name", "num_params"])
Camera = collections.namedtuple(
    "Camera", ["id", "model", "width", "height", "params"])
BaseImage = collections.namedtuple(
    "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
Point3D = collections.namedtuple(
    "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])

class Image(BaseImage):
    def qvec2rotmat(self):
        return qvec2rotmat(self.qvec)


CAMERA_MODELS = {
    CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
    CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
    CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
    CameraModel(model_id=3, model_name="RADIAL", num_params=5),
    CameraModel(model_id=4, model_name="OPENCV", num_params=8),
    CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
    CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
    CameraModel(model_id=7, model_name="FOV", num_params=5),
    CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
    CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
    CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
}
CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \
                         for camera_model in CAMERA_MODELS])


def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    """Read and unpack the next bytes from a binary file.
    :param fid:
    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
    :param endian_character: Any of {@, =, <, >, !}
    :return: Tuple of read and unpacked values.
    """
    data = fid.read(num_bytes)
    return struct.unpack(endian_character + format_char_sequence, data)


def read_cameras_text(path):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasText(const std::string& path)
        void Reconstruction::ReadCamerasText(const std::string& path)
    """
    cameras = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                camera_id = int(elems[0])
                model = elems[1]
                width = int(elems[2])
                height = int(elems[3])
                params = np.array(tuple(map(float, elems[4:])))
                cameras[camera_id] = Camera(id=camera_id, model=model,
                                            width=width, height=height,
                                            params=params)
    return cameras


def read_cameras_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasBinary(const std::string& path)
        void Reconstruction::ReadCamerasBinary(const std::string& path)
    """
    cameras = {}
    with open(path_to_model_file, "rb") as fid:
        num_cameras = read_next_bytes(fid, 8, "Q")[0]
        for camera_line_index in range(num_cameras):
            camera_properties = read_next_bytes(
                fid, num_bytes=24, format_char_sequence="iiQQ")
            camera_id = camera_properties[0]
            model_id = camera_properties[1]
            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
            width = camera_properties[2]
            height = camera_properties[3]
            num_params = CAMERA_MODEL_IDS[model_id].num_params
            params = read_next_bytes(fid, num_bytes=8*num_params,
                                     format_char_sequence="d"*num_params)
            cameras[camera_id] = Camera(id=camera_id,
                                        model=model_name,
                                        width=width,
                                        height=height,
                                        params=np.array(params))
        assert len(cameras) == num_cameras
    return cameras


def read_images_text(path):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesText(const std::string& path)
        void Reconstruction::WriteImagesText(const std::string& path)
    """
    images = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                image_id = int(elems[0])
                qvec = np.array(tuple(map(float, elems[1:5])))
                tvec = np.array(tuple(map(float, elems[5:8])))
                camera_id = int(elems[8])
                image_name = elems[9]
                elems = fid.readline().split()
                xys = np.column_stack([tuple(map(float, elems[0::3])),
                                       tuple(map(float, elems[1::3]))])
                point3D_ids = np.array(tuple(map(int, elems[2::3])))
                images[image_id] = Image(
                    id=image_id, qvec=qvec, tvec=tvec,
                    camera_id=camera_id, name=image_name,
                    xys=xys, point3D_ids=point3D_ids)
    return images


def read_images_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesBinary(const std::string& path)
        void Reconstruction::WriteImagesBinary(const std::string& path)
    """
    images = {}
    with open(path_to_model_file, "rb") as fid:
        num_reg_images = read_next_bytes(fid, 8, "Q")[0]
        for image_index in range(num_reg_images):
            binary_image_properties = read_next_bytes(
                fid, num_bytes=64, format_char_sequence="idddddddi")
            image_id = binary_image_properties[0]
            qvec = np.array(binary_image_properties[1:5])
            tvec = np.array(binary_image_properties[5:8])
            camera_id = binary_image_properties[8]
            image_name = ""
            current_char = read_next_bytes(fid, 1, "c")[0]
            while current_char != b"\x00":   # look for the ASCII 0 entry
                image_name += current_char.decode("utf-8")
                current_char = read_next_bytes(fid, 1, "c")[0]
            num_points2D = read_next_bytes(fid, num_bytes=8,
                                           format_char_sequence="Q")[0]
            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
                                       format_char_sequence="ddq"*num_points2D)
            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
                                   tuple(map(float, x_y_id_s[1::3]))])
            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
            images[image_id] = Image(
                id=image_id, qvec=qvec, tvec=tvec,
                camera_id=camera_id, name=image_name,
                xys=xys, point3D_ids=point3D_ids)
    return images


def read_points3D_text(path):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DText(const std::string& path)
        void Reconstruction::WritePoints3DText(const std::string& path)
    """
    points3D = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                point3D_id = int(elems[0])
                xyz = np.array(tuple(map(float, elems[1:4])))
                rgb = np.array(tuple(map(int, elems[4:7])))
                error = float(elems[7])
                image_ids = np.array(tuple(map(int, elems[8::2])))
                point2D_idxs = np.array(tuple(map(int, elems[9::2])))
                points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
                                               error=error, image_ids=image_ids,
                                               point2D_idxs=point2D_idxs)
    return points3D


def read_points3d_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)
    """
    points3D = {}
    with open(path_to_model_file, "rb") as fid:
        num_points = read_next_bytes(fid, 8, "Q")[0]
        for point_line_index in range(num_points):
            binary_point_line_properties = read_next_bytes(
                fid, num_bytes=43, format_char_sequence="QdddBBBd")
            point3D_id = binary_point_line_properties[0]
            xyz = np.array(binary_point_line_properties[1:4])
            rgb = np.array(binary_point_line_properties[4:7])
            error = np.array(binary_point_line_properties[7])
            track_length = read_next_bytes(
                fid, num_bytes=8, format_char_sequence="Q")[0]
            track_elems = read_next_bytes(
                fid, num_bytes=8*track_length,
                format_char_sequence="ii"*track_length)
            image_ids = np.array(tuple(map(int, track_elems[0::2])))
            point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
            points3D[point3D_id] = Point3D(
                id=point3D_id, xyz=xyz, rgb=rgb,
                error=error, image_ids=image_ids,
                point2D_idxs=point2D_idxs)
    return points3D


def read_model(path, ext):
    if ext == ".txt":
        cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
        images = read_images_text(os.path.join(path, "images" + ext))
        points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
    else:
        cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
        images = read_images_binary(os.path.join(path, "images" + ext))
        points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
    return cameras, images, points3D


def qvec2rotmat(qvec):
    return np.array([
        [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])


def rotmat2qvec(R):
    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
    K = np.array([
        [Rxx - Ryy - Rzz, 0, 0, 0],
        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
    eigvals, eigvecs = np.linalg.eigh(K)
    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
    if qvec[0] < 0:
        qvec *= -1
    return qvec


def main():
    if len(sys.argv) != 3:
        print("Usage: python read_model.py path/to/model/folder [.txt,.bin]")
        return

    cameras, images, points3D = read_model(path=sys.argv[1], ext=sys.argv[2])

    print("num_cameras:", len(cameras))
    print("num_images:", len(images))
    print("num_points3D:", len(points3D))


if __name__ == "__main__":
    main()


================================================
FILE: utils/evaluation.py
================================================
import os
import cv2
import lpips
import torch
import numpy as np
from skimage.metrics import structural_similarity


def im2tensor(img):
    return torch.Tensor(img.transpose(2, 0, 1) / 127.5 - 1.0)[None, ...]


def create_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)


def readimage(data_dir, sequence, time, method):
    img = cv2.imread(os.path.join(data_dir, method, sequence, 'v000_t' + str(time).zfill(3) + '.png'))
    return img


def calculate_metrics(data_dir, sequence, methods, lpips_loss):

    PSNRs = np.zeros((len(methods)))
    SSIMs = np.zeros((len(methods)))
    LPIPSs = np.zeros((len(methods)))

    nFrame = 0

    # Yoon's results do not include v000_t000 and v000_t011. Omit these two
    # frames if evaluating Yoon's method.
    if 'Yoon' in methods:
        time_start = 1
        time_end = 11
    else:
        time_start = 0
        time_end = 12

    for time in range(time_start, time_end): # Fix view v0, change time

        nFrame += 1

        img_true = readimage(data_dir, sequence, time, 'gt')

        for method_idx, method in enumerate(methods):

            if 'Yoon' in methods and sequence == 'Truck' and time == 10:
                break

            img = readimage(data_dir, sequence, time, method)
            PSNR = cv2.PSNR(img_true, img)
            SSIM = structural_similarity(img_true, img, multichannel=True)
            LPIPS = lpips_loss.forward(im2tensor(img_true), im2tensor(img)).item()

            PSNRs[method_idx] += PSNR
            SSIMs[method_idx] += SSIM
            LPIPSs[method_idx] += LPIPS

    PSNRs = PSNRs / nFrame
    SSIMs = SSIMs / nFrame
    LPIPSs = LPIPSs / nFrame

    return PSNRs, SSIMs, LPIPSs


if __name__ == '__main__':

    lpips_loss = lpips.LPIPS(net='alex') # best forward scores
    data_dir = '../results'
    sequences = ['Balloon1', 'Balloon2', 'Jumping', 'Playground', 'Skating', 'Truck', 'Umbrella']
    # methods = ['NeRF', 'NeRF_t', 'Yoon', 'NR', 'NSFF', 'Ours']
    methods = ['NeRF', 'NeRF_t', 'NR', 'NSFF', 'Ours']

    PSNRs_total = np.zeros((len(methods)))
    SSIMs_total = np.zeros((len(methods)))
    LPIPSs_total = np.zeros((len(methods)))
    for sequence in sequences:
        print(sequence)
        PSNRs, SSIMs, LPIPSs = calculate_metrics(data_dir, sequence, methods, lpips_loss)
        for method_idx, method in enumerate(methods):
            print(method.ljust(7) + '%.2f'%(PSNRs[method_idx]) + ' / %.4f'%(SSIMs[method_idx]) + ' / %.3f'%(LPIPSs[method_idx]))

        PSNRs_total += PSNRs
        SSIMs_total += SSIMs
        LPIPSs_total += LPIPSs

    PSNRs_total = PSNRs_total / len(sequences)
    SSIMs_total = SSIMs_total / len(sequences)
    LPIPSs_total = LPIPSs_total / len(sequences)
    print('Avg.')
    for method_idx, method in enumerate(methods):
        print(method.ljust(7) + '%.2f'%(PSNRs_total[method_idx]) + ' / %.4f'%(SSIMs_total[method_idx]) + ' / %.3f'%(LPIPSs_total[method_idx]))


================================================
FILE: utils/flow_utils.py
================================================
import os
import cv2
import numpy as np
from PIL import Image
from os.path import *
UNKNOWN_FLOW_THRESH = 1e7

def flow_to_image(flow, global_max=None):
    """
    Convert flow into middlebury color code image
    :param flow: optical flow map
    :return: optical flow image in middlebury color
    """
    u = flow[:, :, 0]
    v = flow[:, :, 1]

    maxu = -999.
    maxv = -999.
    minu = 999.
    minv = 999.

    idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
    u[idxUnknow] = 0
    v[idxUnknow] = 0

    maxu = max(maxu, np.max(u))
    minu = min(minu, np.min(u))

    maxv = max(maxv, np.max(v))
    minv = min(minv, np.min(v))

    rad = np.sqrt(u ** 2 + v ** 2)

    if global_max == None:
        maxrad = max(-1, np.max(rad))
    else:
        maxrad = global_max

    u = u/(maxrad + np.finfo(float).eps)
    v = v/(maxrad + np.finfo(float).eps)

    img = compute_color(u, v)

    idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
    img[idx] = 0

    return np.uint8(img)


def compute_color(u, v):
    """
    compute optical flow color map
    :param u: optical flow horizontal map
    :param v: optical flow vertical map
    :return: optical flow in color code
    """
    [h, w] = u.shape
    img = np.zeros([h, w, 3])
    nanIdx = np.isnan(u) | np.isnan(v)
    u[nanIdx] = 0
    v[nanIdx] = 0

    colorwheel = make_color_wheel()
    ncols = np.size(colorwheel, 0)

    rad = np.sqrt(u**2+v**2)

    a = np.arctan2(-v, -u) / np.pi

    fk = (a+1) / 2 * (ncols - 1) + 1

    k0 = np.floor(fk).astype(int)

    k1 = k0 + 1
    k1[k1 == ncols+1] = 1
    f = fk - k0

    for i in range(0, np.size(colorwheel,1)):
        tmp = colorwheel[:, i]
        col0 = tmp[k0-1] / 255
        col1 = tmp[k1-1] / 255
        col = (1-f) * col0 + f * col1

        idx = rad <= 1
        col[idx] = 1-rad[idx]*(1-col[idx])
        notidx = np.logical_not(idx)

        col[notidx] *= 0.75
        img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))

    return img


def make_color_wheel():
    """
    Generate color wheel according Middlebury color code
    :return: Color wheel
    """
    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR

    colorwheel = np.zeros([ncols, 3])

    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
    col += RY

    # YG
    colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
    colorwheel[col:col+YG, 1] = 255
    col += YG

    # GC
    colorwheel[col:col+GC, 1] = 255
    colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
    col += GC

    # CB
    colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
    colorwheel[col:col+CB, 2] = 255
    col += CB

    # BM
    colorwheel[col:col+BM, 2] = 255
    colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
    col += + BM

    # MR
    colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
    colorwheel[col:col+MR, 0] = 255

    return colorwheel


def resize_flow(flow, H_new, W_new):
    H_old, W_old = flow.shape[0:2]
    flow_resized = cv2.resize(flow, (W_new, H_new), interpolation=cv2.INTER_LINEAR)
    flow_resized[:, :, 0] *= H_new / H_old
    flow_resized[:, :, 1] *= W_new / W_old
    return flow_resized



def warp_flow(img, flow):
    h, w = flow.shape[:2]
    flow_new = flow.copy()
    flow_new[:,:,0] += np.arange(w)
    flow_new[:,:,1] += np.arange(h)[:,np.newaxis]

    res = cv2.remap(img, flow_new, None,
                    cv2.INTER_CUBIC,
                    borderMode=cv2.BORDER_CONSTANT)
    return res


def consistCheck(flowB, flowF):

    # |--------------------|  |--------------------|
    # |       y            |  |       v            |
    # |   x   *            |  |   u   *            |
    # |                    |  |                    |
    # |--------------------|  |--------------------|

    # sub: numPix * [y x t]

    imgH, imgW, _ = flowF.shape

    (fy, fx) = np.mgrid[0 : imgH, 0 : imgW].astype(np.float32)
    fxx = fx + flowB[:, :, 0]  # horizontal
    fyy = fy + flowB[:, :, 1]  # vertical

    u = (fxx + cv2.remap(flowF[:, :, 0], fxx, fyy, cv2.INTER_LINEAR) - fx)
    v = (fyy + cv2.remap(flowF[:, :, 1], fxx, fyy, cv2.INTER_LINEAR) - fy)
    BFdiff = (u ** 2 + v ** 2) ** 0.5

    return BFdiff, np.stack((u, v), axis=2)


def read_optical_flow(basedir, img_i_name, read_fwd):
    flow_dir = os.path.join(basedir, 'flow')

    fwd_flow_path = os.path.join(flow_dir, '%s_fwd.npz'%img_i_name[:-4])
    bwd_flow_path = os.path.join(flow_dir, '%s_bwd.npz'%img_i_name[:-4])

    if read_fwd:
      fwd_data = np.load(fwd_flow_path)
      fwd_flow, fwd_mask = fwd_data['flow'], fwd_data['mask']
      return fwd_flow, fwd_mask
    else:
      bwd_data = np.load(bwd_flow_path)
      bwd_flow, bwd_mask = bwd_data['flow'], bwd_data['mask']
      return bwd_flow, bwd_mask


def compute_epipolar_distance(T_21, K, p_1, p_2):
    R_21 = T_21[:3, :3]
    t_21 = T_21[:3, 3]

    E_mat = np.dot(skew(t_21), R_21)
    # compute bearing vector
    inv_K = np.linalg.inv(K)

    F_mat = np.dot(np.dot(inv_K.T, E_mat), inv_K)

    l_2 = np.dot(F_mat, p_1)
    algebric_e_distance = np.sum(p_2 * l_2, axis=0)
    n_term = np.sqrt(l_2[0, :]**2 + l_2[1, :]**2) + 1e-8
    geometric_e_distance = algebric_e_distance/n_term
    geometric_e_distance = np.abs(geometric_e_distance)

    return geometric_e_distance


def skew(x):
    return np.array([[0, -x[2], x[1]],
                     [x[2], 0, -x[0]],
                     [-x[1], x[0], 0]])


================================================
FILE: utils/generate_data.py
================================================
import os
import numpy as np
import imageio
import glob
import torch
import torchvision
import skimage.morphology
import argparse

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def create_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)


def multi_view_multi_time(args):
    """
    Generating multi view multi time data
    """

    Maskrcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).cuda().eval()
    threshold = 0.5

    videoname, ext = os.path.splitext(os.path.basename(args.videopath))

    imgs = []
    reader = imageio.get_reader(args.videopath)
    for i, im in enumerate(reader):
        imgs.append(im)

    imgs = np.array(imgs)
    num_frames, H, W, _ = imgs.shape
    imgs = imgs[::int(np.ceil(num_frames / 100))]

    create_dir(os.path.join(args.data_dir, videoname, 'images'))
    create_dir(os.path.join(args.data_dir, videoname, 'images_colmap'))
    create_dir(os.path.join(args.data_dir, videoname, 'background_mask'))

    for idx, img in enumerate(imgs):
        print(idx)
        imageio.imwrite(os.path.join(args.data_dir, videoname, 'images', str(idx).zfill(3) + '.png'), img)
        imageio.imwrite(os.path.join(args.data_dir, videoname, 'images_colmap', str(idx).zfill(3) + '.jpg'), img)

        # Get coarse background mask
        img = torchvision.transforms.functional.to_tensor(img).to(device)
        background_mask = torch.FloatTensor(H, W).fill_(1.0).to(device)
        objPredictions = Maskrcnn([img])[0]

        for intMask in range(len(objPredictions['masks'])):
            if objPredictions['scores'][intMask].item() > threshold:
                if objPredictions['labels'][intMask].item() == 1: # person
                    background_mask[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0

        background_mask_np = ((background_mask.cpu().numpy() > 0.1) * 255).astype(np.uint8)
        imageio.imwrite(os.path.join(args.data_dir, videoname, 'background_mask', str(idx).zfill(3) + '.jpg.png'), background_mask_np)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--videopath", type=str,
                        help='video path')
    parser.add_argument("--data_dir", type=str, default='../data/',
                        help='where to store data')

    args = parser.parse_args()

    multi_view_multi_time(args)


================================================
FILE: utils/generate_depth.py
================================================
"""Compute depth maps for images in the input folder.
"""
import os
import cv2
import glob
import torch
import argparse
import numpy as np

from torchvision.transforms import Compose
from midas.midas_net import MidasNet
from midas.transforms import Resize, NormalizeImage, PrepareForNet


def create_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)


def read_image(path):
    """Read image and output RGB image (0-1).

    Args:
        path (str): path to file

    Returns:
        array: RGB image (0-1)
    """
    img = cv2.imread(path)

    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

    return img


def run(input_path, output_path, output_img_path, model_path):
    """Run MonoDepthNN to compute depth maps.
    Args:
        input_path (str): path to input folder
        output_path (str): path to output folder
        model_path (str): path to saved model
    """
    print("initialize")

    # select device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device: %s" % device)

    # load network
    model = MidasNet(model_path, non_negative=True)
    sh = cv2.imread(sorted(glob.glob(os.path.join(input_path, "*")))[0]).shape
    net_w, net_h = sh[1], sh[0]

    resize_mode="upper_bound"

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_inter
Download .txt
gitextract_avcck_b3/

├── LICENSE
├── README.md
├── configs/
│   ├── config.txt
│   ├── config_Balloon1.txt
│   ├── config_Balloon2.txt
│   ├── config_Jumping.txt
│   ├── config_Playground.txt
│   ├── config_Skating.txt
│   ├── config_Truck.txt
│   └── config_Umbrella.txt
├── load_llff.py
├── render_utils.py
├── run_nerf.py
├── run_nerf_helpers.py
└── utils/
    ├── RAFT/
    │   ├── __init__.py
    │   ├── corr.py
    │   ├── datasets.py
    │   ├── demo.py
    │   ├── extractor.py
    │   ├── raft.py
    │   ├── update.py
    │   └── utils/
    │       ├── __init__.py
    │       ├── augmentor.py
    │       ├── flow_viz.py
    │       ├── frame_utils.py
    │       └── utils.py
    ├── colmap_utils.py
    ├── evaluation.py
    ├── flow_utils.py
    ├── generate_data.py
    ├── generate_depth.py
    ├── generate_flow.py
    ├── generate_motion_mask.py
    ├── generate_pose.py
    └── midas/
        ├── base_model.py
        ├── blocks.py
        ├── midas_net.py
        ├── transforms.py
        └── vit.py
Download .txt
SYMBOL INDEX (265 symbols across 27 files)

FILE: load_llff.py
  function _minify (line 10) | def _minify(basedir, factors=[], resolutions=[]):
  function _load_data (line 62) | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=...
  function normalize (line 199) | def normalize(x):
  function viewmatrix (line 202) | def viewmatrix(z, up, pos):
  function poses_avg (line 211) | def poses_avg(poses):
  function render_path_spiral (line 224) | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
  function recenter_poses (line 241) | def recenter_poses(poses):
  function spherify_poses (line 256) | def spherify_poses(poses, bds):
  function load_llff_data (line 315) | def load_llff_data(args, basedir,
  function generate_path (line 367) | def generate_path(c2w, args):

FILE: render_utils.py
  function batchify_rays (line 14) | def batchify_rays(t, chain_5frames,
  function render (line 30) | def render(t, chain_5frames,
  function render_path_batch (line 106) | def render_path_batch(render_poses, time2render,
  function render_path (line 190) | def render_path(render_poses,
  function raw2outputs (line 322) | def raw2outputs(raw_s,
  function raw2outputs_d (line 416) | def raw2outputs_d(raw_d,
  function render_rays (line 463) | def render_rays(t,

FILE: run_nerf.py
  function config_parser (line 16) | def config_parser():
  function train (line 151) | def train():

FILE: run_nerf_helpers.py
  function img2mse (line 12) | def img2mse(x, y, M=None):
  function img2mae (line 19) | def img2mae(x, y, M=None):
  function L1 (line 26) | def L1(x, M=None):
  function L2 (line 33) | def L2(x, M=None):
  function entropy (line 40) | def entropy(x):
  function mse2psnr (line 44) | def mse2psnr(x): return -10. * torch.log(x) / torch.log(torch.Tensor([10...
  function to8b (line 47) | def to8b(x): return (255 * np.clip(x, 0, 1)).astype(np.uint8)
  class Embedder (line 50) | class Embedder:
    method __init__ (line 52) | def __init__(self, **kwargs):
    method create_embedding_fn (line 57) | def create_embedding_fn(self):
    method embed (line 83) | def embed(self, inputs):
  function get_embedder (line 87) | def get_embedder(multires, i=0, input_dims=3):
  class NeRF_d (line 107) | class NeRF_d(nn.Module):
    method __init__ (line 108) | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch...
    method forward (line 134) | def forward(self, x):
  class NeRF_s (line 166) | class NeRF_s(nn.Module):
    method __init__ (line 167) | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch...
    method forward (line 192) | def forward(self, x):
  function batchify (line 219) | def batchify(fn, chunk):
  function run_network (line 230) | def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1...
  function create_nerf (line 249) | def create_nerf(args):
  function get_rays (line 361) | def get_rays(H, W, focal, c2w):
  function ndc_rays (line 374) | def ndc_rays(H, W, focal, near, rays_o, rays_d):
  function get_grid (line 409) | def get_grid(H, W, num_img, flows_f, flow_masks_f, flows_b, flow_masks_b):
  function NDC2world (line 433) | def NDC2world(pts, H, W, f):
  function render_3d_point (line 444) | def render_3d_point(H, W, f, pose, weights, pts):
  function induce_flow (line 471) | def induce_flow(H, W, focal, pose_neighbor, weights, pts_3d_neighbor, pt...
  function compute_depth_loss (line 483) | def compute_depth_loss(dyn_depth, gt_depth):
  function normalize_depth (line 496) | def normalize_depth(depth):
  function percentile (line 500) | def percentile(t, q):
  function save_res (line 519) | def save_res(moviebase, ret, fps=None):
  function norm_sf_channel (line 577) | def norm_sf_channel(sf_ch):
  function norm_sf (line 586) | def norm_sf(sf):
  function compute_sf_smooth_s_loss (line 596) | def compute_sf_smooth_s_loss(pts1, pts2, H, W, f):
  function compute_sf_smooth_loss (line 611) | def compute_sf_smooth_loss(pts, pts_f, pts_b, H, W, f):

FILE: utils/RAFT/corr.py
  class CorrBlock (line 12) | class CorrBlock:
    method __init__ (line 13) | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
    method __call__ (line 29) | def __call__(self, coords):
    method corr (line 53) | def corr(fmap1, fmap2):
  class CorrLayer (line 63) | class CorrLayer(torch.autograd.Function):
    method forward (line 65) | def forward(ctx, fmap1, fmap2, coords, r):
    method backward (line 75) | def backward(ctx, grad_corr):
  class AlternateCorrBlock (line 83) | class AlternateCorrBlock:
    method __init__ (line 84) | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
    method __call__ (line 94) | def __call__(self, coords):

FILE: utils/RAFT/datasets.py
  class FlowDataset (line 18) | class FlowDataset(data.Dataset):
    method __init__ (line 19) | def __init__(self, aug_params=None, sparse=False):
    method __getitem__ (line 34) | def __getitem__(self, index):
    method __rmul__ (line 93) | def __rmul__(self, v):
    method __len__ (line 98) | def __len__(self):
  class MpiSintel (line 102) | class MpiSintel(FlowDataset):
    method __init__ (line 103) | def __init__(self, aug_params=None, split='training', root='datasets/S...
  class FlyingChairs (line 121) | class FlyingChairs(FlowDataset):
    method __init__ (line 122) | def __init__(self, aug_params=None, split='train', root='datasets/Flyi...
  class FlyingThings3D (line 137) | class FlyingThings3D(FlowDataset):
    method __init__ (line 138) | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', ds...
  class KITTI (line 161) | class KITTI(FlowDataset):
    method __init__ (line 162) | def __init__(self, aug_params=None, split='training', root='datasets/K...
  class HD1K (line 180) | class HD1K(FlowDataset):
    method __init__ (line 181) | def __init__(self, aug_params=None, root='datasets/HD1k'):
  function fetch_dataloader (line 199) | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):

FILE: utils/RAFT/demo.py
  function load_image (line 18) | def load_image(imfile):
  function load_image_list (line 24) | def load_image_list(image_files):
  function viz (line 36) | def viz(img, flo):
  function demo (line 50) | def demo(args):
  function RAFT_infer (line 71) | def RAFT_infer(args):

FILE: utils/RAFT/extractor.py
  class ResidualBlock (line 6) | class ResidualBlock(nn.Module):
    method __init__ (line 7) | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
    method forward (line 48) | def forward(self, x):
  class BottleneckBlock (line 60) | class BottleneckBlock(nn.Module):
    method __init__ (line 61) | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
    method forward (line 107) | def forward(self, x):
  class BasicEncoder (line 118) | class BasicEncoder(nn.Module):
    method __init__ (line 119) | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
    method _make_layer (line 159) | def _make_layer(self, dim, stride=1):
    method forward (line 168) | def forward(self, x):
  class SmallEncoder (line 195) | class SmallEncoder(nn.Module):
    method __init__ (line 196) | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
    method _make_layer (line 235) | def _make_layer(self, dim, stride=1):
    method forward (line 244) | def forward(self, x):

FILE: utils/RAFT/raft.py
  class autocast (line 15) | class autocast:
    method __init__ (line 16) | def __init__(self, enabled):
    method __enter__ (line 18) | def __enter__(self):
    method __exit__ (line 20) | def __exit__(self, *args):
  class RAFT (line 24) | class RAFT(nn.Module):
    method __init__ (line 25) | def __init__(self, args):
    method freeze_bn (line 59) | def freeze_bn(self):
    method initialize_flow (line 64) | def initialize_flow(self, img):
    method upsample_flow (line 73) | def upsample_flow(self, flow, mask):
    method forward (line 87) | def forward(self, image1, image2, iters=12, flow_init=None, upsample=T...

FILE: utils/RAFT/update.py
  class FlowHead (line 6) | class FlowHead(nn.Module):
    method __init__ (line 7) | def __init__(self, input_dim=128, hidden_dim=256):
    method forward (line 13) | def forward(self, x):
  class ConvGRU (line 16) | class ConvGRU(nn.Module):
    method __init__ (line 17) | def __init__(self, hidden_dim=128, input_dim=192+128):
    method forward (line 23) | def forward(self, h, x):
  class SepConvGRU (line 33) | class SepConvGRU(nn.Module):
    method __init__ (line 34) | def __init__(self, hidden_dim=128, input_dim=192+128):
    method forward (line 45) | def forward(self, h, x):
  class SmallMotionEncoder (line 62) | class SmallMotionEncoder(nn.Module):
    method __init__ (line 63) | def __init__(self, args):
    method forward (line 71) | def forward(self, flow, corr):
  class BasicMotionEncoder (line 79) | class BasicMotionEncoder(nn.Module):
    method __init__ (line 80) | def __init__(self, args):
    method forward (line 89) | def forward(self, flow, corr):
  class SmallUpdateBlock (line 99) | class SmallUpdateBlock(nn.Module):
    method __init__ (line 100) | def __init__(self, args, hidden_dim=96):
    method forward (line 106) | def forward(self, net, inp, corr, flow):
  class BasicUpdateBlock (line 114) | class BasicUpdateBlock(nn.Module):
    method __init__ (line 115) | def __init__(self, args, hidden_dim=128, input_dim=128):
    method forward (line 127) | def forward(self, net, inp, corr, flow, upsample=True):

FILE: utils/RAFT/utils/augmentor.py
  class FlowAugmentor (line 15) | class FlowAugmentor:
    method __init__ (line 16) | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=T...
    method color_transform (line 36) | def color_transform(self, img1, img2):
    method eraser_transform (line 52) | def eraser_transform(self, img1, img2, bounds=[50, 100]):
    method spatial_transform (line 67) | def spatial_transform(self, img1, img2, flow):
    method __call__ (line 111) | def __call__(self, img1, img2, flow):
  class SparseFlowAugmentor (line 122) | class SparseFlowAugmentor:
    method __init__ (line 123) | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=F...
    method color_transform (line 142) | def color_transform(self, img1, img2):
    method eraser_transform (line 148) | def eraser_transform(self, img1, img2):
    method resize_sparse_flow_map (line 161) | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
    method spatial_transform (line 195) | def spatial_transform(self, img1, img2, flow, valid):
    method __call__ (line 236) | def __call__(self, img1, img2, flow, valid):

FILE: utils/RAFT/utils/flow_viz.py
  function make_colorwheel (line 20) | def make_colorwheel():
  function flow_uv_to_colors (line 70) | def flow_uv_to_colors(u, v, convert_to_bgr=False):
  function flow_to_image (line 109) | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):

FILE: utils/RAFT/utils/frame_utils.py
  function readFlow (line 12) | def readFlow(fn):
  function readPFM (line 33) | def readPFM(file):
  function writeFlow (line 70) | def writeFlow(filename,uv,v=None):
  function readFlowKITTI (line 102) | def readFlowKITTI(filename):
  function readDispKITTI (line 109) | def readDispKITTI(filename):
  function writeFlowKITTI (line 116) | def writeFlowKITTI(filename, uv):
  function read_gen (line 123) | def read_gen(file_name, pil=False):

FILE: utils/RAFT/utils/utils.py
  class InputPadder (line 7) | class InputPadder:
    method __init__ (line 9) | def __init__(self, dims, mode='sintel'):
    method pad (line 18) | def pad(self, *inputs):
    method unpad (line 21) | def unpad(self,x):
  function forward_interpolate (line 26) | def forward_interpolate(flow):
  function bilinear_sampler (line 57) | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
  function coords_grid (line 74) | def coords_grid(batch, ht, wd):
  function upflow8 (line 80) | def upflow8(flow, mode='bilinear'):

FILE: utils/colmap_utils.py
  class Image (line 48) | class Image(BaseImage):
    method qvec2rotmat (line 49) | def qvec2rotmat(self):
  function read_next_bytes (line 70) | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_charact...
  function read_cameras_text (line 82) | def read_cameras_text(path):
  function read_cameras_binary (line 108) | def read_cameras_binary(path_to_model_file):
  function read_images_text (line 137) | def read_images_text(path):
  function read_images_binary (line 168) | def read_images_binary(path_to_model_file):
  function read_points3D_text (line 203) | def read_points3D_text(path):
  function read_points3d_binary (line 230) | def read_points3d_binary(path_to_model_file):
  function read_model (line 260) | def read_model(path, ext):
  function qvec2rotmat (line 272) | def qvec2rotmat(qvec):
  function rotmat2qvec (line 285) | def rotmat2qvec(R):
  function main (line 299) | def main():

FILE: utils/evaluation.py
  function im2tensor (line 9) | def im2tensor(img):
  function create_dir (line 13) | def create_dir(dir):
  function readimage (line 18) | def readimage(data_dir, sequence, time, method):
  function calculate_metrics (line 23) | def calculate_metrics(data_dir, sequence, methods, lpips_loss):

FILE: utils/flow_utils.py
  function flow_to_image (line 8) | def flow_to_image(flow, global_max=None):
  function compute_color (line 50) | def compute_color(u, v):
  function make_color_wheel (line 94) | def make_color_wheel():
  function resize_flow (line 144) | def resize_flow(flow, H_new, W_new):
  function warp_flow (line 153) | def warp_flow(img, flow):
  function consistCheck (line 165) | def consistCheck(flowB, flowF):
  function read_optical_flow (line 188) | def read_optical_flow(basedir, img_i_name, read_fwd):
  function compute_epipolar_distance (line 204) | def compute_epipolar_distance(T_21, K, p_1, p_2):
  function skew (line 223) | def skew(x):

FILE: utils/generate_data.py
  function create_dir (line 13) | def create_dir(dir):
  function multi_view_multi_time (line 18) | def multi_view_multi_time(args):

FILE: utils/generate_depth.py
  function create_dir (line 15) | def create_dir(dir):
  function read_image (line 20) | def read_image(path):
  function run (line 39) | def run(input_path, output_path, output_img_path, model_path):

FILE: utils/generate_flow.py
  function create_dir (line 18) | def create_dir(dir):
  function load_image (line 23) | def load_image(imfile):
  function warp_flow (line 29) | def warp_flow(img, flow):
  function compute_fwdbwd_mask (line 39) | def compute_fwdbwd_mask(fwd_flow, bwd_flow):
  function run (line 56) | def run(args, input_path, output_path, output_img_path):

FILE: utils/generate_motion_mask.py
  function create_dir (line 17) | def create_dir(dir):
  function extract_poses (line 22) | def extract_poses(im):
  function load_colmap_data (line 32) | def load_colmap_data(realdir):
  function run_maskrcnn (line 60) | def run_maskrcnn(model, img_path, intWidth=1024, intHeight=576):
  function motion_segmentation (line 111) | def motion_segmentation(basedir, threshold,

FILE: utils/generate_pose.py
  function load_colmap_data (line 8) | def load_colmap_data(realdir):

FILE: utils/midas/base_model.py
  class BaseModel (line 4) | class BaseModel(torch.nn.Module):
    method load (line 5) | def load(self, path):

FILE: utils/midas/blocks.py
  function _make_encoder (line 11) | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=F...
  function _make_scratch (line 49) | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
  function _make_pretrained_efficientnet_lite3 (line 78) | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
  function _make_efficientnet_backbone (line 88) | def _make_efficientnet_backbone(effnet):
  function _make_resnet_backbone (line 101) | def _make_resnet_backbone(resnet):
  function _make_pretrained_resnext101_wsl (line 114) | def _make_pretrained_resnext101_wsl(use_pretrained):
  class Interpolate (line 120) | class Interpolate(nn.Module):
    method __init__ (line 124) | def __init__(self, scale_factor, mode, align_corners=False):
    method forward (line 138) | def forward(self, x):
  class ResidualConvUnit (line 155) | class ResidualConvUnit(nn.Module):
    method __init__ (line 159) | def __init__(self, features):
    method forward (line 177) | def forward(self, x):
  class FeatureFusionBlock (line 194) | class FeatureFusionBlock(nn.Module):
    method __init__ (line 198) | def __init__(self, features):
    method forward (line 209) | def forward(self, *xs):
  class ResidualConvUnit_custom (line 231) | class ResidualConvUnit_custom(nn.Module):
    method __init__ (line 235) | def __init__(self, features, activation, bn):
    method forward (line 263) | def forward(self, x):
  class FeatureFusionBlock_custom (line 291) | class FeatureFusionBlock_custom(nn.Module):
    method __init__ (line 295) | def __init__(self, features, activation, deconv=False, bn=False, expan...
    method forward (line 320) | def forward(self, *xs):

FILE: utils/midas/midas_net.py
  class MidasNet (line 12) | class MidasNet(BaseModel):
    method __init__ (line 16) | def __init__(self, path=None, features=256, non_negative=True):
    method forward (line 49) | def forward(self, x):

FILE: utils/midas/transforms.py
  function apply_min_size (line 6) | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AR...
  class Resize (line 48) | class Resize(object):
    method __init__ (line 52) | def __init__(
    method constrain_to_multiple_of (line 94) | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
    method get_size (line 105) | def get_size(self, width, height):
    method __call__ (line 162) | def __call__(self, sample):
  class NormalizeImage (line 197) | class NormalizeImage(object):
    method __init__ (line 201) | def __init__(self, mean, std):
    method __call__ (line 205) | def __call__(self, sample):
  class PrepareForNet (line 211) | class PrepareForNet(object):
    method __init__ (line 215) | def __init__(self):
    method __call__ (line 218) | def __call__(self, sample):

FILE: utils/midas/vit.py
  class Slice (line 9) | class Slice(nn.Module):
    method __init__ (line 10) | def __init__(self, start_index=1):
    method forward (line 14) | def forward(self, x):
  class AddReadout (line 18) | class AddReadout(nn.Module):
    method __init__ (line 19) | def __init__(self, start_index=1):
    method forward (line 23) | def forward(self, x):
  class ProjectReadout (line 31) | class ProjectReadout(nn.Module):
    method __init__ (line 32) | def __init__(self, in_features, start_index=1):
    method forward (line 38) | def forward(self, x):
  class Transpose (line 45) | class Transpose(nn.Module):
    method __init__ (line 46) | def __init__(self, dim0, dim1):
    method forward (line 51) | def forward(self, x):
  function forward_vit (line 56) | def forward_vit(pretrained, x):
  function _resize_pos_embed (line 100) | def _resize_pos_embed(self, posemb, gs_h, gs_w):
  function forward_flex (line 117) | def forward_flex(self, x):
  function get_activation (line 159) | def get_activation(name):
  function get_readout_oper (line 166) | def get_readout_oper(vit_features, features, use_readout, start_index=1):
  function _make_vit_b16_backbone (line 183) | def _make_vit_b16_backbone(
  function _make_pretrained_vitl16_384 (line 297) | def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=...
  function _make_pretrained_vitb16_384 (line 310) | def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=...
  function _make_pretrained_deitb16_384 (line 319) | def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks...
  function _make_pretrained_deitb16_distil_384 (line 328) | def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore"...
  function _make_vit_b_rn50_backbone (line 343) | def _make_vit_b_rn50_backbone(
  function _make_pretrained_vitb_rn50_384 (line 478) | def _make_pretrained_vitb_rn50_384(
Condensed preview — 39 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (261K chars).
[
  {
    "path": "LICENSE",
    "chars": 1226,
    "preview": "\nMIT License\n\nCopyright (c) 2020 Virginia Tech Vision and Learning Lab\n\nPermission is hereby granted, free of charge, to"
  },
  {
    "path": "README.md",
    "chars": 8818,
    "preview": "# Dynamic View Synthesis from Dynamic Monocular Video\n\n[![arXiv](https://img.shields.io/badge/arXiv-2108.00946-b31b1b.sv"
  },
  {
    "path": "configs/config.txt",
    "chars": 635,
    "preview": "expname = xxxxxx_DyNeRF_pretrain_test\nbasedir = ./logs\ndatadir = ./data/xxxxxx/\n\ndataset_type = llff\n\nfactor = 4\nN_rand "
  },
  {
    "path": "configs/config_Balloon1.txt",
    "chars": 656,
    "preview": "expname = Balloon1_H270_DyNeRF_pretrain\nbasedir = ./logs\ndatadir = ./data/Balloon1/\n\ndataset_type = llff\n\nfactor = 2\nN_r"
  },
  {
    "path": "configs/config_Balloon2.txt",
    "chars": 655,
    "preview": "expname = Balloon2_H270_DyNeRF_pretrain\nbasedir = ./logs\ndatadir = ./data/Balloon2/\n\ndataset_type = llff\n\nfactor = 2\nN_r"
  },
  {
    "path": "configs/config_Jumping.txt",
    "chars": 654,
    "preview": "expname = Jumping_H270_DyNeRF_pretrain\nbasedir = ./logs\ndatadir = ./data/Jumping/\n\ndataset_type = llff\n\nfactor = 2\nN_ran"
  },
  {
    "path": "configs/config_Playground.txt",
    "chars": 659,
    "preview": "expname = Playground_H270_DyNeRF_pretrain\nbasedir = ./logs\ndatadir = ./data/Playground/\n\ndataset_type = llff\n\nfactor = 2"
  },
  {
    "path": "configs/config_Skating.txt",
    "chars": 653,
    "preview": "expname = Skating_H270_DyNeRF_pretrain\nbasedir = ./logs\ndatadir = ./data/Skating/\n\ndataset_type = llff\n\nfactor = 2\nN_ran"
  },
  {
    "path": "configs/config_Truck.txt",
    "chars": 649,
    "preview": "expname = Truck_H270_DyNeRF_pretrain\nbasedir = ./logs\ndatadir = ./data/Truck/\n\ndataset_type = llff\n\nfactor = 2\nN_rand = "
  },
  {
    "path": "configs/config_Umbrella.txt",
    "chars": 655,
    "preview": "expname = Umbrella_H270_DyNeRF_pretrain\nbasedir = ./logs\ndatadir = ./data/Umbrella/\n\ndataset_type = llff\n\nfactor = 2\nN_r"
  },
  {
    "path": "load_llff.py",
    "chars": 15268,
    "preview": "import os\nimport cv2\nimport imageio\nimport numpy as np\n\nfrom utils.flow_utils import resize_flow\nfrom run_nerf_helpers i"
  },
  {
    "path": "render_utils.py",
    "chars": 28187,
    "preview": "import os\nimport time\nimport torch\nimport imageio\nimport numpy as np\nimport torch.nn.functional as F\n\nfrom run_nerf_help"
  },
  {
    "path": "run_nerf.py",
    "chars": 35756,
    "preview": "import os\nimport time\nimport torch\nimport imageio\nimport numpy as np\nfrom torch.utils.tensorboard import SummaryWriter\n\n"
  },
  {
    "path": "run_nerf_helpers.py",
    "chars": 21893,
    "preview": "import os\nimport torch\nimport imageio\nimport numpy as np\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ndevice ="
  },
  {
    "path": "utils/RAFT/__init__.py",
    "chars": 54,
    "preview": "# from .demo import RAFT_infer\nfrom .raft import RAFT\n"
  },
  {
    "path": "utils/RAFT/corr.py",
    "chars": 3640,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom .utils.utils import bilinear_sampler, coords_grid\n\ntry:\n    import alt"
  },
  {
    "path": "utils/RAFT/datasets.py",
    "chars": 9245,
    "preview": "# Data loading based on https://github.com/NVIDIA/flownet2-pytorch\n\nimport numpy as np\nimport torch\nimport torch.utils.d"
  },
  {
    "path": "utils/RAFT/demo.py",
    "chars": 1856,
    "preview": "import sys\nimport argparse\nimport os\nimport cv2\nimport glob\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom "
  },
  {
    "path": "utils/RAFT/extractor.py",
    "chars": 8847,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(se"
  },
  {
    "path": "utils/RAFT/raft.py",
    "chars": 4865,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .update import BasicUpdateBl"
  },
  {
    "path": "utils/RAFT/update.py",
    "chars": 5227,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass FlowHead(nn.Module):\n    def __init__(self, i"
  },
  {
    "path": "utils/RAFT/utils/__init__.py",
    "chars": 71,
    "preview": "from .flow_viz import flow_to_image\nfrom .frame_utils import writeFlow\n"
  },
  {
    "path": "utils/RAFT/utils/augmentor.py",
    "chars": 9108,
    "preview": "import numpy as np\nimport random\nimport math\nfrom PIL import Image\n\nimport cv2\ncv2.setNumThreads(0)\ncv2.ocl.setUseOpenCL"
  },
  {
    "path": "utils/RAFT/utils/flow_viz.py",
    "chars": 4318,
    "preview": "# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization\n\n\n# MIT License\n#\n# Copyright "
  },
  {
    "path": "utils/RAFT/utils/frame_utils.py",
    "chars": 4024,
    "preview": "import numpy as np\nfrom PIL import Image\nfrom os.path import *\nimport re\n\nimport cv2\ncv2.setNumThreads(0)\ncv2.ocl.setUse"
  },
  {
    "path": "utils/RAFT/utils/utils.py",
    "chars": 2451,
    "preview": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy import interpolate\n\n\nclass InputPadder:\n    \""
  },
  {
    "path": "utils/colmap_utils.py",
    "chars": 13273,
    "preview": "# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.\n# All rights reserved.\n#\n# Redistribution and use in source and bi"
  },
  {
    "path": "utils/evaluation.py",
    "chars": 2948,
    "preview": "import os\nimport cv2\nimport lpips\nimport torch\nimport numpy as np\nfrom skimage.metrics import structural_similarity\n\n\nde"
  },
  {
    "path": "utils/flow_utils.py",
    "chars": 5704,
    "preview": "import os\nimport cv2\nimport numpy as np\nfrom PIL import Image\nfrom os.path import *\nUNKNOWN_FLOW_THRESH = 1e7\n\ndef flow_"
  },
  {
    "path": "utils/generate_data.py",
    "chars": 2395,
    "preview": "import os\nimport numpy as np\nimport imageio\nimport glob\nimport torch\nimport torchvision\nimport skimage.morphology\nimport"
  },
  {
    "path": "utils/generate_depth.py",
    "chars": 4210,
    "preview": "\"\"\"Compute depth maps for images in the input folder.\n\"\"\"\nimport os\nimport cv2\nimport glob\nimport torch\nimport argparse\n"
  },
  {
    "path": "utils/generate_flow.py",
    "chars": 3867,
    "preview": "import argparse\nimport os\nimport cv2\nimport glob\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom RAFT.raft i"
  },
  {
    "path": "utils/generate_motion_mask.py",
    "chars": 10713,
    "preview": "import os\nimport cv2\nimport PIL\nimport glob\nimport torch\nimport argparse\nimport numpy as np\n\nfrom colmap_utils import re"
  },
  {
    "path": "utils/generate_pose.py",
    "chars": 3319,
    "preview": "import os\nimport glob\nimport argparse\nimport numpy as np\nfrom colmap_utils import read_cameras_binary, read_images_binar"
  },
  {
    "path": "utils/midas/base_model.py",
    "chars": 366,
    "preview": "import torch\n\n\nclass BaseModel(torch.nn.Module):\n    def load(self, path):\n        \"\"\"Load model from file.\n        Args"
  },
  {
    "path": "utils/midas/blocks.py",
    "chars": 9183,
    "preview": "import torch\nimport torch.nn as nn\n\nfrom .vit import (\n    _make_pretrained_vitb_rn50_384,\n    _make_pretrained_vitl16_3"
  },
  {
    "path": "utils/midas/midas_net.py",
    "chars": 2709,
    "preview": "\"\"\"MidashNet: Network for monocular depth estimation trained by mixing several datasets.\nThis file contains code that is"
  },
  {
    "path": "utils/midas/transforms.py",
    "chars": 7869,
    "preview": "import numpy as np\nimport cv2\nimport math\n\n\ndef apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):"
  },
  {
    "path": "utils/midas/vit.py",
    "chars": 14625,
    "preview": "import torch\nimport torch.nn as nn\nimport timm\nimport types\nimport math\nimport torch.nn.functional as F\n\n\nclass Slice(nn"
  }
]

About this extraction

This page contains the full source code of the gaochen315/DynamicNeRF GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 39 files (245.4 KB), approximately 69.2k tokens, and a symbol index with 265 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!