Repository: gaochen315/DynamicNeRF
Branch: main
Commit: c417fb207ef3
Files: 39
Total size: 245.4 KB
Directory structure:
gitextract_avcck_b3/
├── LICENSE
├── README.md
├── configs/
│ ├── config.txt
│ ├── config_Balloon1.txt
│ ├── config_Balloon2.txt
│ ├── config_Jumping.txt
│ ├── config_Playground.txt
│ ├── config_Skating.txt
│ ├── config_Truck.txt
│ └── config_Umbrella.txt
├── load_llff.py
├── render_utils.py
├── run_nerf.py
├── run_nerf_helpers.py
└── utils/
├── RAFT/
│ ├── __init__.py
│ ├── corr.py
│ ├── datasets.py
│ ├── demo.py
│ ├── extractor.py
│ ├── raft.py
│ ├── update.py
│ └── utils/
│ ├── __init__.py
│ ├── augmentor.py
│ ├── flow_viz.py
│ ├── frame_utils.py
│ └── utils.py
├── colmap_utils.py
├── evaluation.py
├── flow_utils.py
├── generate_data.py
├── generate_depth.py
├── generate_flow.py
├── generate_motion_mask.py
├── generate_pose.py
└── midas/
├── base_model.py
├── blocks.py
├── midas_net.py
├── transforms.py
└── vit.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2020 Virginia Tech Vision and Learning Lab
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------- LICENSE FOR EdgeConnect --------------------------------
Attribution-NonCommercial 4.0 International
================================================
FILE: README.md
================================================
# Dynamic View Synthesis from Dynamic Monocular Video
[Project Website](https://free-view-video.github.io/) | [Video](https://youtu.be/j8CUzIR0f8M) | [Paper](https://arxiv.org/abs/2105.06468)
> **Dynamic View Synthesis from Dynamic Monocular Video**
> [Chen Gao](http://chengao.vision), [Ayush Saraf](#), [Johannes Kopf](https://johanneskopf.de/), [Jia-Bin Huang](https://filebox.ece.vt.edu/~jbhuang/)
> in ICCV 2021
## Setup
The code is tested with:
* Linux (tested on CentOS Linux release 7.4.1708)
* Anaconda 3
* Python 3.7.11
* CUDA 10.1
* 1 V100 GPU
To get started, please create the conda environment `dnerf` by running:
```
conda create --name dnerf python=3.7
conda activate dnerf
conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch
pip install imageio scikit-image configargparse timm lpips
```
and install [COLMAP](https://colmap.github.io/install.html) manually. Then download the MiDaS and RAFT weights:
```
ROOT_PATH=/path/to/the/DynamicNeRF/folder
cd $ROOT_PATH
wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/weights.zip
unzip weights.zip
rm weights.zip
```
## Dynamic Scene Dataset
The [Dynamic Scene Dataset](https://www-users.cse.umn.edu/~jsyoon/dynamic_synth/) is used to
quantitatively evaluate our method. Please download the pre-processed data by running:
```
cd $ROOT_PATH
wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/data.zip
unzip data.zip
rm data.zip
```
### Training
You can train a model from scratch by running:
```
cd $ROOT_PATH/
python run_nerf.py --config configs/config_Balloon2.txt
```
Every 100k iterations, you should get videos like the following examples.
The novel view-time synthesis results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/novelviewtime`.

The reconstruction results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset`.

The fix-view-change-time results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset_view000`.

The fix-time-change-view results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset_time000`.

### Rendering from pre-trained models
We also provide pre-trained models. You can download them by running:
```
cd $ROOT_PATH/
wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/logs.zip
unzip logs.zip
rm logs.zip
```
Then you can render the results directly by running:
```
python run_nerf.py --config configs/config_Balloon2.txt --render_only --ft_path $ROOT_PATH/logs/Balloon2_H270_DyNeRF_pretrain/300000.tar
```
### Evaluating our method and others
Our goal is to make the evaluation as simple as possible for you. We have collected the fix-view-change-time results of the following methods:
`NeRF` \
`NeRF + t` \
`Yoon et al.` \
`Non-Rigid NeRF` \
`NSFF` \
`DynamicNeRF (ours)`
Please download the results by running:
```
cd $ROOT_PATH/
wget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/results.zip
unzip results.zip
rm results.zip
```
Then you can calculate the PSNR/SSIM/LPIPS by running:
```
cd $ROOT_PATH/utils
python evaluation.py
```
| PSNR / LPIPS | Jumping | Skating | Truck | Umbrella | Balloon1 | Balloon2 | Playground | Average |
|:-------------|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|
| NeRF | 20.99 / 0.305 | 23.67 / 0.311 | 22.73 / 0.229 | 21.29 / 0.440 | 19.82 / 0.205 | 24.37 / 0.098 | 21.07 / 0.165 | 21.99 / 0.250 |
| NeRF + t | 18.04 / 0.455 | 20.32 / 0.512 | 18.33 / 0.382 | 17.69 / 0.728 | 18.54 / 0.275 | 20.69 / 0.216 | 14.68 / 0.421 | 18.33 / 0.427 |
| NR NeRF | 20.09 / 0.287 | 23.95 / 0.227 | 19.33 / 0.446 | 19.63 / 0.421 | 17.39 / 0.348 | 22.41 / 0.213 | 15.06 / 0.317 | 19.69 / 0.323 |
| NSFF | 24.65 / 0.151 | 29.29 / 0.129 | 25.96 / 0.167 | 22.97 / 0.295 | 21.96 / 0.215 | 24.27 / 0.222 | 21.22 / 0.212 | 24.33 / 0.199 |
| Ours | 24.68 / 0.090 | 32.66 / 0.035 | 28.56 / 0.082 | 23.26 / 0.137 | 22.36 / 0.104 | 27.06 / 0.049 | 24.15 / 0.080 | 26.10 / 0.082 |
Please note:
1. The numbers reported in the paper are calculated using the TensorFlow code. The numbers here are calculated using this improved PyTorch version.
2. In Yoon's results, the first frame and the last frame are missing. To compare with Yoon's results, we have to omit the first frame and the last frame. To do so, please uncomment line 72 and comment line 73 in `evaluation.py`.
3. We obtain the results of NSFF and NR NeRF using the official implementation with default parameters.
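If you want to sanity-check the metrics outside of `evaluation.py`, the sketch below shows one plausible way to compute PSNR/SSIM/LPIPS per frame with scikit-image and the `lpips` package installed above. It is an illustration only (the helper name `compare_frame` and the image layout are assumptions), not the repo's implementation:
```
import torch
import lpips
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# LPIPS with the AlexNet backbone (downloads its weights on first use).
lpips_fn = lpips.LPIPS(net='alex')

def compare_frame(gt, pred):
    """gt, pred: float32 numpy images in [0, 1] with shape [H, W, 3]."""
    psnr = peak_signal_noise_ratio(gt, pred, data_range=1.0)
    # Older scikit-image releases use multichannel=True instead of channel_axis.
    ssim = structural_similarity(gt, pred, data_range=1.0, channel_axis=-1)
    # LPIPS expects NCHW tensors scaled to [-1, 1].
    to_tensor = lambda x: torch.from_numpy(x).permute(2, 0, 1)[None].float() * 2.0 - 1.0
    lp = lpips_fn(to_tensor(gt), to_tensor(pred)).item()
    return psnr, ssim, lp
```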
## Train a model on your sequence
0. Set some paths
```
ROOT_PATH=/path/to/the/DynamicNeRF/folder
DATASET_NAME=name_of_the_video_without_extension
DATASET_PATH=$ROOT_PATH/data/$DATASET_NAME
```
1. Prepare training images and background masks from a video.
```
cd $ROOT_PATH/utils
python generate_data.py --videopath /path/to/the/video
```
2. Use COLMAP to obtain camera poses.
```
colmap feature_extractor \
--database_path $DATASET_PATH/database.db \
--image_path $DATASET_PATH/images_colmap \
--ImageReader.mask_path $DATASET_PATH/background_mask \
--ImageReader.single_camera 1
colmap exhaustive_matcher \
--database_path $DATASET_PATH/database.db
mkdir $DATASET_PATH/sparse
colmap mapper \
--database_path $DATASET_PATH/database.db \
--image_path $DATASET_PATH/images_colmap \
--output_path $DATASET_PATH/sparse \
--Mapper.num_threads 16 \
--Mapper.init_min_tri_angle 4 \
--Mapper.multiple_models 0 \
--Mapper.extract_colors 0
```
3. Save camera poses into the format that NeRF reads.
```
cd $ROOT_PATH/utils
python generate_pose.py --dataset_path $DATASET_PATH
```
4. Estimate monocular depth.
```
cd $ROOT_PATH/utils
python generate_depth.py --dataset_path $DATASET_PATH --model $ROOT_PATH/weights/midas_v21-f6b98070.pt
```
5. Predict optical flows.
```
cd $ROOT_PATH/utils
python generate_flow.py --dataset_path $DATASET_PATH --model $ROOT_PATH/weights/raft-things.pth
```
6. Obtain motion mask (code adapted from NSFF).
```
cd $ROOT_PATH/utils
python generate_motion_mask.py --dataset_path $DATASET_PATH
```
7. Train a model. Please change `expname` and `datadir` in `configs/config.txt`.
```
cd $ROOT_PATH/
python run_nerf.py --config configs/config.txt
```
Explanation of each parameter (a minimal example config follows the list):
- `expname`: experiment name
- `basedir`: where to store ckpts and logs
- `datadir`: input data directory
- `factor`: downsample factor for the input images
- `N_rand`: number of random rays per gradient step
- `N_samples`: number of samples per ray
- `netwidth`: channels per layer
- `use_viewdirs`: whether to enable view dependence for StaticNeRF
- `use_viewdirsDyn`: whether to enable view dependence for DynamicNeRF
- `raw_noise_std`: std dev of noise added to regularize sigma_a output
- `no_ndc`: do not use normalized device coordinates
- `lindisp`: sampling linearly in disparity rather than depth
- `i_video`: frequency of novel view-time synthesis video saving
- `i_testset`: frequency of testset video saving
- `N_iters`: number of training iterations
- `i_img`: frequency of tensorboard image logging
- `DyNeRF_blending`: whether to use DynamicNeRF to predict the blending weight
- `pretrain`: whether to pre-train StaticNeRF
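For reference, a minimal config combining these parameters might look like the sketch below. The values mirror the provided `configs/config.txt`; `MySequence` is a placeholder for your own sequence name:
```
expname = MySequence_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/MySequence/
dataset_type = llff
factor = 2
N_rand = 1024
N_samples = 64
netwidth = 256
N_iters = 300001
use_viewdirs = True
use_viewdirsDyn = True
DyNeRF_blending = True
pretrain = True
```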
## License
This work is licensed under the MIT License. See [LICENSE](LICENSE) for details.
If you find this code useful for your research, please consider citing the following paper:
```
@inproceedings{Gao-ICCV-DynNeRF,
    author    = {Gao, Chen and Saraf, Ayush and Kopf, Johannes and Huang, Jia-Bin},
    title     = {Dynamic View Synthesis from Dynamic Monocular Video},
    booktitle = {Proceedings of the IEEE International Conference on Computer Vision},
    year      = {2021}
}
```
## Acknowledgments
Our training code is built upon
[NeRF](https://github.com/bmild/nerf),
[NeRF-pytorch](https://github.com/yenchenlin/nerf-pytorch), and
[NSFF](https://github.com/zl548/Neural-Scene-Flow-Fields).
Our flow prediction code is modified from [RAFT](https://github.com/princeton-vl/RAFT).
Our depth prediction code is modified from [MiDaS](https://github.com/isl-org/MiDaS).
================================================
FILE: configs/config.txt
================================================
expname = xxxxxx_DyNeRF_pretrain_test
basedir = ./logs
datadir = ./data/xxxxxx/
dataset_type = llff
factor = 4
N_rand = 1024
N_samples = 64
netwidth = 256
i_video = 100000
i_testset = 100000
N_iters = 500001
i_img = 500
use_viewdirs = True
use_viewdirsDyn = True
raw_noise_std = 1e0
no_ndc = False
lindisp = False
dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.01
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True
================================================
FILE: configs/config_Balloon1.txt
================================================
expname = Balloon1_H270_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/Balloon1/
dataset_type = llff
factor = 2
N_rand = 1024
N_samples = 64
N_importance = 0
netwidth = 256
i_video = 100000
i_testset = 100000
N_iters = 300001
i_img = 500
use_viewdirs = True
use_viewdirsDyn = False
raw_noise_std = 1e0
no_ndc = False
lindisp = False
dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.1
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True
================================================
FILE: configs/config_Balloon2.txt
================================================
expname = Balloon2_H270_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/Balloon2/
dataset_type = llff
factor = 2
N_rand = 1024
N_samples = 64
N_importance = 0
netwidth = 256
i_video = 100000
i_testset = 100000
N_iters = 300001
i_img = 500
use_viewdirs = True
use_viewdirsDyn = True
raw_noise_std = 1e0
no_ndc = False
lindisp = False
dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.1
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True
================================================
FILE: configs/config_Jumping.txt
================================================
expname = Jumping_H270_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/Jumping/
dataset_type = llff
factor = 2
N_rand = 1024
N_samples = 64
N_importance = 0
netwidth = 256
i_video = 100000
i_testset = 100000
N_iters = 300001
i_img = 500
use_viewdirs = True
use_viewdirsDyn = False
raw_noise_std = 1e0
no_ndc = False
lindisp = False
dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.1
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True
================================================
FILE: configs/config_Playground.txt
================================================
expname = Playground_H270_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/Playground/
dataset_type = llff
factor = 2
N_rand = 1024
N_samples = 64
N_importance = 0
netwidth = 256
i_video = 100000
i_testset = 100000
N_iters = 300001
i_img = 500
use_viewdirs = True
use_viewdirsDyn = True
raw_noise_std = 1e0
no_ndc = False
lindisp = False
dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.1
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True
================================================
FILE: configs/config_Skating.txt
================================================
expname = Skating_H270_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/Skating/
dataset_type = llff
factor = 2
N_rand = 1024
N_samples = 64
N_importance = 0
netwidth = 256
i_video = 100000
i_testset = 100000
N_iters = 300001
i_img = 500
use_viewdirs = True
use_viewdirsDyn = True
raw_noise_std = 1e0
no_ndc = False
lindisp = False
dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.1
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True
================================================
FILE: configs/config_Truck.txt
================================================
expname = Truck_H270_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/Truck/
dataset_type = llff
factor = 2
N_rand = 1024
N_samples = 64
N_importance = 0
netwidth = 256
i_video = 100000
i_testset = 100000
N_iters = 300001
i_img = 500
use_viewdirs = True
use_viewdirsDyn = True
raw_noise_std = 1e0
no_ndc = False
lindisp = False
dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.1
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True
================================================
FILE: configs/config_Umbrella.txt
================================================
expname = Umbrella_H270_DyNeRF_pretrain
basedir = ./logs
datadir = ./data/Umbrella/
dataset_type = llff
factor = 2
N_rand = 1024
N_samples = 64
N_importance = 0
netwidth = 256
i_video = 100000
i_testset = 100000
N_iters = 300001
i_img = 500
use_viewdirs = True
use_viewdirsDyn = True
raw_noise_std = 1e0
no_ndc = False
lindisp = False
dynamic_loss_lambda = 1.0
static_loss_lambda = 1.0
full_loss_lambda = 3.0
depth_loss_lambda = 0.04
order_loss_lambda = 0.1
flow_loss_lambda = 0.02
slow_loss_lambda = 0.01
smooth_loss_lambda = 0.1
consistency_loss_lambda = 1.0
mask_loss_lambda = 0.1
sparse_loss_lambda = 0.001
DyNeRF_blending = True
pretrain = True
================================================
FILE: load_llff.py
================================================
import os
import cv2
import imageio
import numpy as np
from utils.flow_utils import resize_flow
from run_nerf_helpers import get_grid
def _minify(basedir, factors=[], resolutions=[]):
needtoload = False
for r in factors:
imgdir = os.path.join(basedir, 'images_{}'.format(r))
if not os.path.exists(imgdir):
needtoload = True
for r in resolutions:
imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))
if not os.path.exists(imgdir):
needtoload = True
if not needtoload:
return
from shutil import copy
from subprocess import check_output
imgdir = os.path.join(basedir, 'images')
imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]
imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]
imgdir_orig = imgdir
wd = os.getcwd()
for r in factors + resolutions:
if isinstance(r, int):
name = 'images_{}'.format(r)
resizearg = '{}%'.format(100./r)
else:
name = 'images_{}x{}'.format(r[1], r[0])
resizearg = '{}x{}'.format(r[1], r[0])
imgdir = os.path.join(basedir, name)
if os.path.exists(imgdir):
continue
print('Minifying', r, basedir)
os.makedirs(imgdir)
check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)
ext = imgs[0].split('.')[-1]
args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])
print(args)
os.chdir(imgdir)
check_output(args, shell=True)
os.chdir(wd)
if ext != 'png':
check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)
print('Removed duplicates')
print('Done')
def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True):
print('factor ', factor)
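# poses_bounds.npy follows the LLFF convention: one row per image holding a
# flattened 3x5 matrix [R | t | (H, W, focal)] followed by the near/far depth bounds.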
poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))
poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0])
bds = poses_arr[:, -2:].transpose([1,0])
img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \
if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]
sh = imageio.imread(img0).shape
sfx = ''
if factor is not None:
sfx = '_{}'.format(factor)
_minify(basedir, factors=[factor])
factor = factor
elif height is not None:
factor = sh[0] / float(height)
width = int(sh[1] / factor)
if width % 2 == 1:
width -= 1
_minify(basedir, resolutions=[[height, width]])
sfx = '_{}x{}'.format(width, height)
elif width is not None:
factor = sh[1] / float(width)
height = int(sh[0] / factor)
if height % 2 == 1:
height -= 1
_minify(basedir, resolutions=[[height, width]])
sfx = '_{}x{}'.format(width, height)
else:
factor = 1
imgdir = os.path.join(basedir, 'images' + sfx)
if not os.path.exists(imgdir):
print( imgdir, 'does not exist, returning' )
return
imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) \
if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]
if poses.shape[-1] != len(imgfiles):
print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) )
return
sh = imageio.imread(imgfiles[0]).shape
num_img = len(imgfiles)
poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
poses[2, 4, :] = poses[2, 4, :] * 1./factor
if not load_imgs:
return poses, bds
def imread(f):
if f.endswith('png'):
return imageio.imread(f, ignoregamma=True)
else:
return imageio.imread(f)
imgs = [imread(f)[..., :3] / 255. for f in imgfiles]
imgs = np.stack(imgs, -1)
assert imgs.shape[0] == sh[0]
assert imgs.shape[1] == sh[1]
disp_dir = os.path.join(basedir, 'disp')
dispfiles = [os.path.join(disp_dir, f) \
for f in sorted(os.listdir(disp_dir)) if f.endswith('npy')]
disp = [cv2.resize(np.load(f),
(sh[1], sh[0]),
interpolation=cv2.INTER_NEAREST) for f in dispfiles]
disp = np.stack(disp, -1)
mask_dir = os.path.join(basedir, 'motion_masks')
maskfiles = [os.path.join(mask_dir, f) \
for f in sorted(os.listdir(mask_dir)) if f.endswith('png')]
masks = [cv2.resize(imread(f)/255., (sh[1], sh[0]),
interpolation=cv2.INTER_NEAREST) for f in maskfiles]
masks = np.stack(masks, -1)
masks = np.float32(masks > 1e-3)
flow_dir = os.path.join(basedir, 'flow')
flows_f = []
flow_masks_f = []
flows_b = []
flow_masks_b = []
for i in range(num_img):
if i == num_img - 1:
fwd_flow, fwd_mask = np.zeros((sh[0], sh[1], 2)), np.zeros((sh[0], sh[1]))
else:
fwd_flow_path = os.path.join(flow_dir, '%03d_fwd.npz'%i)
fwd_data = np.load(fwd_flow_path)
fwd_flow, fwd_mask = fwd_data['flow'], fwd_data['mask']
fwd_flow = resize_flow(fwd_flow, sh[0], sh[1])
fwd_mask = np.float32(fwd_mask)
fwd_mask = cv2.resize(fwd_mask, (sh[1], sh[0]),
interpolation=cv2.INTER_NEAREST)
flows_f.append(fwd_flow)
flow_masks_f.append(fwd_mask)
if i == 0:
bwd_flow, bwd_mask = np.zeros((sh[0], sh[1], 2)), np.zeros((sh[0], sh[1]))
else:
bwd_flow_path = os.path.join(flow_dir, '%03d_bwd.npz'%i)
bwd_data = np.load(bwd_flow_path)
bwd_flow, bwd_mask = bwd_data['flow'], bwd_data['mask']
bwd_flow = resize_flow(bwd_flow, sh[0], sh[1])
bwd_mask = np.float32(bwd_mask)
bwd_mask = cv2.resize(bwd_mask, (sh[1], sh[0]),
interpolation=cv2.INTER_NEAREST)
flows_b.append(bwd_flow)
flow_masks_b.append(bwd_mask)
flows_f = np.stack(flows_f, -1)
flow_masks_f = np.stack(flow_masks_f, -1)
flows_b = np.stack(flows_b, -1)
flow_masks_b = np.stack(flow_masks_b, -1)
print(imgs.shape)
print(disp.shape)
print(masks.shape)
print(flows_f.shape)
print(flow_masks_f.shape)
assert(imgs.shape[0] == disp.shape[0])
assert(imgs.shape[0] == masks.shape[0])
assert(imgs.shape[0] == flows_f.shape[0])
assert(imgs.shape[0] == flow_masks_f.shape[0])
assert(imgs.shape[1] == disp.shape[1])
assert(imgs.shape[1] == masks.shape[1])
return poses, bds, imgs, disp, masks, flows_f, flow_masks_f, flows_b, flow_masks_b
def normalize(x):
return x / np.linalg.norm(x)
def viewmatrix(z, up, pos):
vec2 = normalize(z)
vec1_avg = up
vec0 = normalize(np.cross(vec1_avg, vec2))
vec1 = normalize(np.cross(vec2, vec0))
m = np.stack([vec0, vec1, vec2, pos], 1)
return m
def poses_avg(poses):
hwf = poses[0, :3, -1:]
center = poses[:, :3, 3].mean(0)
vec2 = normalize(poses[:, :3, 2].sum(0))
up = poses[:, :3, 1].sum(0)
c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
return c2w
def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
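# Generate a spiral of novel-view camera poses around c2w: rads sets the per-axis
# radius of the spiral, zrate the relative speed of the depth oscillation,
# rots the number of full rotations, and N the number of poses.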
render_poses = []
rads = np.array(list(rads) + [1.])
hwf = c2w[:,4:5]
for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:
c = np.dot(c2w[:3, :4],
np.array([np.cos(theta),
-np.sin(theta),
-np.sin(theta*zrate),
1.]) * rads)
z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))
render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
return render_poses
def recenter_poses(poses):
poses_ = poses+0
bottom = np.reshape([0,0,0,1.], [1,4])
c2w = poses_avg(poses)
c2w = np.concatenate([c2w[:3,:4], bottom], -2)
bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1])
poses = np.concatenate([poses[:,:3,:4], bottom], -2)
poses = np.linalg.inv(c2w) @ poses
poses_[:,:3,:4] = poses[:,:3,:4]
poses = poses_
return poses
def spherify_poses(poses, bds):
p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1)
rays_d = poses[:,:3,2:3]
rays_o = poses[:,:3,3:4]
def min_line_dist(rays_o, rays_d):
A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1])
b_i = -A_i @ rays_o
pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0))
return pt_mindist
pt_mindist = min_line_dist(rays_o, rays_d)
center = pt_mindist
up = (poses[:,:3,3] - center).mean(0)
vec0 = normalize(up)
vec1 = normalize(np.cross([.1,.2,.3], vec0))
vec2 = normalize(np.cross(vec0, vec1))
pos = center
c2w = np.stack([vec1, vec2, vec0, pos], 1)
poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4])
rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1)))
sc = 1./rad
poses_reset[:,:3,3] *= sc
bds *= sc
rad *= sc
centroid = np.mean(poses_reset[:,:3,3], 0)
zh = centroid[2]
radcircle = np.sqrt(rad**2-zh**2)
new_poses = []
for th in np.linspace(0.,2.*np.pi, 120):
camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
up = np.array([0,0,-1.])
vec2 = normalize(camorigin)
vec0 = normalize(np.cross(vec2, up))
vec1 = normalize(np.cross(vec2, vec0))
pos = camorigin
p = np.stack([vec0, vec1, vec2, pos], 1)
new_poses.append(p)
new_poses = np.stack(new_poses, 0)
new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1)
poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1)
return poses_reset, new_poses, bds
def load_llff_data(args, basedir,
factor=2,
recenter=True, bd_factor=.75,
spherify=False, path_zflat=False,
frame2dolly=10):
poses, bds, imgs, disp, masks, flows_f, flow_masks_f, flows_b, flow_masks_b = \
_load_data(basedir, factor=factor) # factor=2 downsamples original imgs by 2x
print('Loaded', basedir, bds.min(), bds.max())
# Correct rotation matrix ordering and move variable dim to axis 0
poses = np.concatenate([poses[:, 1:2, :],
-poses[:, 0:1, :],
poses[:, 2:, :]], 1)
poses = np.moveaxis(poses, -1, 0).astype(np.float32)
images = np.moveaxis(imgs, -1, 0).astype(np.float32)
bds = np.moveaxis(bds, -1, 0).astype(np.float32)
disp = np.moveaxis(disp, -1, 0).astype(np.float32)
masks = np.moveaxis(masks, -1, 0).astype(np.float32)
flows_f = np.moveaxis(flows_f, -1, 0).astype(np.float32)
flow_masks_f = np.moveaxis(flow_masks_f, -1, 0).astype(np.float32)
flows_b = np.moveaxis(flows_b, -1, 0).astype(np.float32)
flow_masks_b = np.moveaxis(flow_masks_b, -1, 0).astype(np.float32)
# Rescale if bd_factor is provided
sc = 1. if bd_factor is None else 1./(np.percentile(bds[:, 0], 5) * bd_factor)
poses[:, :3, 3] *= sc
bds *= sc
if recenter:
poses = recenter_poses(poses)
# Only for rendering
if frame2dolly == -1:
c2w = poses_avg(poses)
else:
c2w = poses[frame2dolly, :, :]
H, W, _ = c2w[:, -1]
# Generate poses for novel views
render_poses, render_focals = generate_path(c2w, args)
render_poses = np.array(render_poses).astype(np.float32)
grids = get_grid(int(H), int(W), len(poses), flows_f, flow_masks_f, flows_b, flow_masks_b) # [N, H, W, 8]
return images, disp, masks, poses, bds,\
render_poses, render_focals, grids
def generate_path(c2w, args):
hwf = c2w[:, 4:5]
num_novelviews = args.num_novelviews
max_disp = 48.0
H, W, focal = hwf[:, 0]
max_trans = max_disp / focal
output_poses = []
output_focals = []
# Rendering teaser. Add translation.
for i in range(num_novelviews):
x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_novelviews)) * args.x_trans_multiplier
y_trans = max_trans * (np.cos(2.0 * np.pi * float(i) / float(num_novelviews)) - 1.) * args.y_trans_multiplier
z_trans = 0.
i_pose = np.concatenate([
np.concatenate(
[np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),
np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
],axis=0)
i_pose = np.linalg.inv(i_pose)
ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)
render_pose = np.dot(ref_pose, i_pose)
output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
output_focals.append(focal)
# Rendering teaser. Add zooming.
if args.frame2dolly != -1:
for i in range(num_novelviews // 2 + 1):
x_trans = 0.
y_trans = 0.
# z_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_novelviews)) * args.z_trans_multiplier
z_trans = max_trans * args.z_trans_multiplier * i / float(num_novelviews // 2)
i_pose = np.concatenate([
np.concatenate(
[np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),
np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
],axis=0)
i_pose = np.linalg.inv(i_pose) #torch.tensor(np.linalg.inv(i_pose)).float()
ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)
render_pose = np.dot(ref_pose, i_pose)
output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
output_focals.append(focal)
print(z_trans / max_trans / args.z_trans_multiplier)
# Rendering teaser. Add dolly zoom.
if args.frame2dolly != -1:
for i in range(num_novelviews // 2 + 1):
x_trans = 0.
y_trans = 0.
z_trans = max_trans * args.z_trans_multiplier * i / float(num_novelviews // 2)
i_pose = np.concatenate([
np.concatenate(
[np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),
np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]
],axis=0)
i_pose = np.linalg.inv(i_pose)
ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)
render_pose = np.dot(ref_pose, i_pose)
output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))
new_focal = focal - args.focal_decrease * z_trans / max_trans / args.z_trans_multiplier
output_focals.append(new_focal)
print(z_trans / max_trans / args.z_trans_multiplier, new_focal)
return output_poses, output_focals
================================================
FILE: render_utils.py
================================================
import os
import time
import torch
import imageio
import numpy as np
import torch.nn.functional as F
from run_nerf_helpers import *
from utils.flow_utils import flow_to_image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def batchify_rays(t, chain_5frames,
rays_flat, chunk=1024*16, **kwargs):
"""Render rays in smaller minibatches to avoid OOM.
"""
all_ret = {}
for i in range(0, rays_flat.shape[0], chunk):
ret = render_rays(t, chain_5frames, rays_flat[i:i+chunk], **kwargs)
for k in ret:
if k not in all_ret:
all_ret[k] = []
all_ret[k].append(ret[k])
all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}
return all_ret
def render(t, chain_5frames,
H, W, focal, focal_render=None,
chunk=1024*16, rays=None, c2w=None, ndc=True,
near=0., far=1.,
use_viewdirs=False, c2w_staticcam=None,
**kwargs):
"""Render rays
Args:
H: int. Height of image in pixels.
W: int. Width of image in pixels.
focal: float. Focal length of pinhole camera.
chunk: int. Maximum number of rays to process simultaneously. Used to
control maximum memory usage. Does not affect final results.
rays: array of shape [2, batch_size, 3]. Ray origin and direction for
each example in batch.
c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
ndc: bool. If True, represent ray origin, direction in NDC coordinates.
near: float or array of shape [batch_size]. Nearest distance for a ray.
far: float or array of shape [batch_size]. Farthest distance for a ray.
use_viewdirs: bool. If True, use viewing direction of a point in space in model.
c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
camera while using other c2w argument for viewing directions.
Returns:
rgb_map: [batch_size, 3]. Predicted RGB values for rays.
disp_map: [batch_size]. Disparity map. Inverse of depth.
acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
extras: dict with everything returned by render_rays().
"""
if c2w is not None:
# special case to render full image
if focal_render is not None:
# Render full image using different focal length for dolly zoom. Inference only.
rays_o, rays_d = get_rays(H, W, focal_render, c2w)
else:
rays_o, rays_d = get_rays(H, W, focal, c2w)
else:
# use provided ray batch
rays_o, rays_d = rays
if use_viewdirs:
# provide ray directions as input
viewdirs = rays_d
if c2w_staticcam is not None:
raise NotImplementedError
# Make all directions unit magnitude.
# shape: [batch_size, 3]
viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
viewdirs = torch.reshape(viewdirs, [-1, 3]).float()
sh = rays_d.shape # [..., 3]
if ndc:
# for forward facing scenes
rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)
# Create ray batch
rays_o = torch.reshape(rays_o, [-1, 3]).float()
rays_d = torch.reshape(rays_d, [-1, 3]).float()
near, far = near * \
torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1])
# (ray origin, ray direction, min dist, max dist) for each ray
rays = torch.cat([rays_o, rays_d, near, far], -1)
if use_viewdirs:
rays = torch.cat([rays, viewdirs], -1)
# Render and reshape
all_ret = batchify_rays(t, chain_5frames,
rays, chunk, **kwargs)
for k in all_ret:
k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
all_ret[k] = torch.reshape(all_ret[k], k_sh)
return all_ret
def render_path_batch(render_poses, time2render,
hwf, chunk, render_kwargs, savedir=None, focal2render=None):
"""Render frames using batch.
Args:
render_poses: array of shape [num_frame, 3, 4]. Camera-to-world transformation matrix of each frame.
time2render: array of shape [num_frame]. Time of each frame.
hwf: list. [Height of image in pixels, Width of image in pixels, Focal length of pinhole camera]
chunk: int. Maximum number of rays to process simultaneously. Used to
control maximum memory usage. Does not affect final results.
render_kwargs: dictionary. args for the render function.
savedir: string. Directory to save results.
focal2render: list. Only used to perform dolly-zoom.
Returns:
ret_dict: dictionary. Final and intermediate results.
"""
H, W, focal = hwf
ret_dict = {}
rgbs = []
rgbs_d = []
rgbs_s = []
dynamicnesses = []
time_curr = time.time()
for i, c2w in enumerate(render_poses):
print(i, time.time() - time_curr)
time_curr = time.time()
t = time2render[i]
if focal2render is not None:
# Render full image using different focal length
rays_o, rays_d = get_rays(H, W, focal2render[i], c2w)
else:
rays_o, rays_d = get_rays(H, W, focal, c2w)
rays_o = torch.reshape(rays_o, (-1, 3))
rays_d = torch.reshape(rays_d, (-1, 3))
batch_rays = torch.stack([rays_o, rays_d], 0)
rgb = []
rgb_d = []
rgb_s = []
dynamicness = []
for j in range(0, batch_rays.shape[1], chunk):
# print(j, '/', batch_rays.shape[1])
ret = render(t, False,
H, W, focal,
chunk=chunk, rays=batch_rays[:, j:j+chunk, :],
**render_kwargs)
rgb.append(ret['rgb_map_full'].cpu())
rgb_d.append(ret['rgb_map_d'].cpu())
rgb_s.append(ret['rgb_map_s'].cpu())
dynamicness.append(ret['dynamicness_map'].cpu())
rgb = torch.reshape(torch.cat(rgb, 0), (H, W, 3)).numpy()
rgb_d = torch.reshape(torch.cat(rgb_d, 0), (H, W, 3)).numpy()
rgb_s = torch.reshape(torch.cat(rgb_s, 0), (H, W, 3)).numpy()
dynamicness = torch.reshape(torch.cat(dynamicness, 0), (H, W)).numpy()
# Not a good solution. Should take care of this when preparing the data.
if W%2 == 1:
# rgb = cv2.resize(rgb, (W - 1, H))
rgb = rgb[:, :-1, :]
rgb_d = rgb_d[:, :-1, :]
rgb_s = rgb_s[:, :-1, :]
dynamicness = dynamicness[:, :-1]
rgbs.append(rgb)
rgbs_d.append(rgb_d)
rgbs_s.append(rgb_s)
dynamicnesses.append(dynamicness)
if savedir is not None:
rgb8 = to8b(rgbs[-1])
filename = os.path.join(savedir, '{:03d}.png'.format(i))
imageio.imwrite(filename, rgb8)
ret_dict['rgbs'] = np.stack(rgbs, 0)
ret_dict['rgbs_d'] = np.stack(rgbs_d, 0)
ret_dict['rgbs_s'] = np.stack(rgbs_s, 0)
ret_dict['dynamicnesses'] = np.stack(dynamicnesses, 0)
return ret_dict
def render_path(render_poses,
time2render,
hwf,
chunk,
render_kwargs,
savedir=None,
flows_gt_f=None,
flows_gt_b=None,
focal2render=None):
"""Render frames.
Args:
render_poses: array of shape [num_frame, 3, 4]. Camera-to-world transformation matrix of each frame.
time2render: array of shape [num_frame]. Time of each frame.
hwf: list. [Height of image in pixels, Width of image in pixels, Focal length of pinhole camera]
chunk: int. Maximum number of rays to process simultaneously. Used to
control maximum memory usage. Does not affect final results.
render_kwargs: dictionary. args for the render function.
savedir: string. Directory to save results.
focal2render: list. Only used to perform dolly-zoom.
Returns:
ret_dict: dictionary. Final and intermediate results.
"""
H, W, focal = hwf
ret_dict = {}
rgbs = []
rgbs_d = []
rgbs_s = []
depths = []
depths_d = []
depths_s = []
flows_f = []
flows_b = []
dynamicness = []
blending = []
grid = np.stack(np.meshgrid(np.arange(W, dtype=np.float32),
np.arange(H, dtype=np.float32), indexing='xy'), -1)
grid = torch.Tensor(grid)
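# Pixel-coordinate grid (x, y) for every image location; grid[..., :2] is passed to
# induce_flow below to project the predicted scene flow into a 2D induced optical flow.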
time_curr = time.time()
for i, c2w in enumerate(render_poses):
t = time2render[i]
pose = c2w[:3, :4]
print(i, time.time() - time_curr)
time_curr = time.time()
if focal2render is None:
# Normal rendering.
ret = render(t, False,
H, W, focal,
chunk=1024*32, c2w=pose,
**render_kwargs)
else:
# Render image using different focal length.
ret = render(t, False,
H, W, focal, focal_render=focal2render[i],
chunk=1024*32, c2w=pose,
**render_kwargs)
rgbs.append(ret['rgb_map_full'].cpu().numpy())
rgbs_d.append(ret['rgb_map_d'].cpu().numpy())
rgbs_s.append(ret['rgb_map_s'].cpu().numpy())
depths.append(ret['depth_map_full'].cpu().numpy())
depths_d.append(ret['depth_map_d'].cpu().numpy())
depths_s.append(ret['depth_map_s'].cpu().numpy())
dynamicness.append(ret['dynamicness_map'].cpu().numpy())
if flows_gt_f is not None:
# Reconstruction. Flow is caused by both changing camera and changing time.
pose_f = render_poses[min(i + 1, int(len(render_poses)) - 1), :3, :4]
pose_b = render_poses[max(i - 1, 0), :3, :4]
else:
# Non training view-time. Flow is caused by changing time (just for visualization).
pose_f = render_poses[i, :3, :4]
pose_b = render_poses[i, :3, :4]
# Sceneflow induced optical flow
induced_flow_f_ = induce_flow(H, W, focal, pose_f, ret['weights_d'], ret['raw_pts_f'], grid[..., :2])
induced_flow_b_ = induce_flow(H, W, focal, pose_b, ret['weights_d'], ret['raw_pts_b'], grid[..., :2])
if (i + 1) >= len(render_poses):
induced_flow_f = np.zeros((H, W, 2))
else:
induced_flow_f = induced_flow_f_.cpu().numpy()
if flows_gt_f is not None:
flow_gt_f = flows_gt_f[i].cpu().numpy()
induced_flow_f = np.concatenate((induced_flow_f, flow_gt_f), 0)
induced_flow_f_img = flow_to_image(induced_flow_f)
flows_f.append(induced_flow_f_img)
if (i - 1) < 0:
induced_flow_b = np.zeros((H, W, 2))
else:
induced_flow_b = induced_flow_b_.cpu().numpy()
if flows_gt_b is not None:
flow_gt_b = flows_gt_b[i].cpu().numpy()
induced_flow_b = np.concatenate((induced_flow_b, flow_gt_b), 0)
induced_flow_b_img = flow_to_image(induced_flow_b)
flows_b.append(induced_flow_b_img)
if i == 0:
ret_dict['sceneflow_f_NDC'] = ret['sceneflow_f'].cpu().numpy()
ret_dict['sceneflow_b_NDC'] = ret['sceneflow_b'].cpu().numpy()
ret_dict['blending'] = ret['blending'].cpu().numpy()
weights = np.concatenate((ret['weights_d'][..., None].cpu().numpy(),
ret['weights_s'][..., None].cpu().numpy(),
ret['blending'][..., None].cpu().numpy(),
ret['weights_full'][..., None].cpu().numpy()))
ret_dict['weights'] = np.moveaxis(weights, [0, 1, 2, 3], [1, 2, 0, 3])
if savedir is not None:
rgb8 = to8b(rgbs[-1])
filename = os.path.join(savedir, '{:03d}.png'.format(i))
imageio.imwrite(filename, rgb8)
ret_dict['rgbs'] = np.stack(rgbs, 0)
ret_dict['rgbs_d'] = np.stack(rgbs_d, 0)
ret_dict['rgbs_s'] = np.stack(rgbs_s, 0)
ret_dict['depths'] = np.stack(depths, 0)
ret_dict['depths_d'] = np.stack(depths_d, 0)
ret_dict['depths_s'] = np.stack(depths_s, 0)
ret_dict['dynamicness'] = np.stack(dynamicness, 0)
ret_dict['flows_f'] = np.stack(flows_f, 0)
ret_dict['flows_b'] = np.stack(flows_b, 0)
return ret_dict
def raw2outputs(raw_s,
raw_d,
blending,
z_vals,
rays_d,
raw_noise_std):
"""Transforms model's predictions to semantically meaningful values.
Args:
raw_d: [num_rays, num_samples along ray, 4]. Prediction from Dynamic model.
raw_s: [num_rays, num_samples along ray, 4]. Prediction from Static model.
z_vals: [num_rays, num_samples along ray]. Integration time.
rays_d: [num_rays, 3]. Direction of each ray.
Returns:
rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
disp_map: [num_rays]. Disparity map. Inverse of depth map.
acc_map: [num_rays]. Sum of weights along each ray.
weights: [num_rays, num_samples]. Weights assigned to each sampled color.
depth_map: [num_rays]. Estimated distance to object.
"""
# Function for computing density from model prediction. This value is
# strictly between [0, 1].
def raw2alpha(raw, dists, act_fn=F.relu): return 1.0 - \
torch.exp(-act_fn(raw) * dists)
# Compute 'distance' (in time) between each integration time along a ray.
dists = z_vals[..., 1:] - z_vals[..., :-1]
# The 'distance' from the last integration time is infinity.
dists = torch.cat(
[dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)],
-1) # [N_rays, N_samples]
# Multiply each distance by the norm of its corresponding direction ray
# to convert to real world distance (accounts for non-unit directions).
dists = dists * torch.norm(rays_d[..., None, :], dim=-1)
# Extract RGB of each sample position along each ray.
rgb_d = torch.sigmoid(raw_d[..., :3]) # [N_rays, N_samples, 3]
rgb_s = torch.sigmoid(raw_s[..., :3]) # [N_rays, N_samples, 3]
# Add noise to model's predictions for density. Can be used to
# regularize network during training (prevents floater artifacts).
noise = 0.
if raw_noise_std > 0.:
noise = torch.randn(raw_d[..., 3].shape) * raw_noise_std
# Predict density of each sample along each ray. Higher values imply
# higher likelihood of being absorbed at this point.
alpha_d = raw2alpha(raw_d[..., 3] + noise, dists) # [N_rays, N_samples]
alpha_s = raw2alpha(raw_s[..., 3] + noise, dists) # [N_rays, N_samples]
alphas = 1. - (1. - alpha_s) * (1. - alpha_d) # [N_rays, N_samples]
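# Accumulated transmittance along each ray: T_d and T_s treat the dynamic and static
# models in isolation, while T_full attenuates by both densities, each weighted by the
# predicted blending factor.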
T_d = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), 1. - alpha_d + 1e-10], -1), -1)[:, :-1]
T_s = torch.cumprod(torch.cat([torch.ones((alpha_s.shape[0], 1)), 1. - alpha_s + 1e-10], -1), -1)[:, :-1]
T_full = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), (1. - alpha_d * blending) * (1. - alpha_s * (1. - blending)) + 1e-10], -1), -1)[:, :-1]
# T_full = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), torch.pow(1. - alpha_d + 1e-10, blending) * torch.pow(1. - alpha_s + 1e-10, 1. - blending)], -1), -1)[:, :-1]
# T_full = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), (1. - alpha_d) * (1. - alpha_s) + 1e-10], -1), -1)[:, :-1]
# Compute weight for RGB of each sample along each ray. A cumprod() is
# used to express the idea of the ray not having reflected up to this
# sample yet.
weights_d = alpha_d * T_d
weights_s = alpha_s * T_s
weights_full = (alpha_d * blending + alpha_s * (1. - blending)) * T_full
# weights_full = alphas * T_full
# Computed weighted color of each sample along each ray.
rgb_map_d = torch.sum(weights_d[..., None] * rgb_d, -2)
rgb_map_s = torch.sum(weights_s[..., None] * rgb_s, -2)
rgb_map_full = torch.sum(
(T_full * alpha_d * blending)[..., None] * rgb_d + \
(T_full * alpha_s * (1. - blending))[..., None] * rgb_s, -2)
# Estimated depth map is expected distance.
depth_map_d = torch.sum(weights_d * z_vals, -1)
depth_map_s = torch.sum(weights_s * z_vals, -1)
depth_map_full = torch.sum(weights_full * z_vals, -1)
# Sum of weights along each ray. This value is in [0, 1] up to numerical error.
acc_map_d = torch.sum(weights_d, -1)
acc_map_s = torch.sum(weights_s, -1)
acc_map_full = torch.sum(weights_full, -1)
# Computed dynamicness
dynamicness_map = torch.sum(weights_full * blending, -1)
# dynamicness_map = 1 - T_d[..., -1]
return rgb_map_full, depth_map_full, acc_map_full, weights_full, \
rgb_map_s, depth_map_s, acc_map_s, weights_s, \
rgb_map_d, depth_map_d, acc_map_d, weights_d, dynamicness_map
def raw2outputs_d(raw_d,
z_vals,
rays_d,
raw_noise_std):
# Function for computing density from model prediction. This value is
# strictly between [0, 1].
def raw2alpha(raw, dists, act_fn=F.relu): return 1.0 - \
torch.exp(-act_fn(raw) * dists)
# Compute 'distance' (in time) between each integration time along a ray.
dists = z_vals[..., 1:] - z_vals[..., :-1]
# The 'distance' from the last integration time is infinity.
dists = torch.cat(
[dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)],
-1) # [N_rays, N_samples]
# Multiply each distance by the norm of its corresponding direction ray
# to convert to real world distance (accounts for non-unit directions).
dists = dists * torch.norm(rays_d[..., None, :], dim=-1)
# Extract RGB of each sample position along each ray.
rgb_d = torch.sigmoid(raw_d[..., :3]) # [N_rays, N_samples, 3]
# Add noise to model's predictions for density. Can be used to
# regularize network during training (prevents floater artifacts).
noise = 0.
if raw_noise_std > 0.:
noise = torch.randn(raw_d[..., 3].shape) * raw_noise_std
# Predict density of each sample along each ray. Higher values imply
# higher likelihood of being absorbed at this point.
alpha_d = raw2alpha(raw_d[..., 3] + noise, dists) # [N_rays, N_samples]
T_d = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), 1. - alpha_d + 1e-10], -1), -1)[:, :-1]
# Compute weight for RGB of each sample along each ray. A cumprod() is
# used to express the idea of the ray not having reflected up to this
# sample yet.
weights_d = alpha_d * T_d
# Computed weighted color of each sample along each ray.
rgb_map_d = torch.sum(weights_d[..., None] * rgb_d, -2)
return rgb_map_d, weights_d
def render_rays(t,
chain_5frames,
ray_batch,
network_fn_d,
network_fn_s,
network_query_fn_d,
network_query_fn_s,
N_samples,
num_img,
DyNeRF_blending,
pretrain=False,
lindisp=False,
perturb=0.,
N_importance=0,
raw_noise_std=0.,
inference=False):
"""Volumetric rendering.
Args:
ray_batch: array of shape [batch_size, ...]. All information necessary
for sampling along a ray, including: ray origin, ray direction, min
dist, max dist, and unit-magnitude viewing direction.
network_fn_d: function. Model for predicting RGB and density at each point
in space.
network_query_fn_d: function used for passing queries to network_fn_d.
N_samples: int. Number of different times to sample along each ray.
lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
random points in time.
N_importance: int. Number of additional times to sample along each ray.
These samples are only passed to network_fine.
network_fine: "fine" network with same spec as network_fn.
raw_noise_std: ...
Returns:
rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
disp_map: [num_rays]. Disparity map. 1 / depth.
acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
raw: [num_rays, num_samples, 4]. Raw predictions from model.
rgb0: See rgb_map. Output for coarse model.
disp0: See disp_map. Output for coarse model.
acc0: See acc_map. Output for coarse model.
z_std: [num_rays]. Standard deviation of distances along ray for each
sample.
"""
# batch size
N_rays = ray_batch.shape[0]
# ray_batch: [N_rays, 11]
# rays_o: [N_rays, 0:3]
# rays_d: [N_rays, 3:6]
# near: [N_rays, 6:7]
# far: [N_rays, 7:8]
# viewdirs: [N_rays, 8:11]
# Extract ray origin, direction.
rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] # [N_rays, 3] each
# Extract unit-normalized viewing direction.
viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
# Extract lower, upper bound for ray distance.
bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2])
near, far = bounds[..., 0], bounds[..., 1]
# Decide where to sample along each ray. In this setup, all rays are sampled at
# the same integration times.
t_vals = torch.linspace(0., 1., steps=N_samples)
if not lindisp:
# Space integration times linearly between 'near' and 'far'. Same
# integration points will be used for all rays.
z_vals = near * (1.-t_vals) + far * (t_vals)
else:
# Sample linearly in inverse depth (disparity).
z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
z_vals = z_vals.expand([N_rays, N_samples])
# Perturb sampling time along each ray.
if perturb > 0.:
# get intervals between samples
mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
upper = torch.cat([mids, z_vals[..., -1:]], -1)
lower = torch.cat([z_vals[..., :1], mids], -1)
# stratified samples in those intervals
t_rand = torch.rand(z_vals.shape)
z_vals = lower + (upper - lower) * t_rand
# Points in space to evaluate model at.
pts = rays_o[..., None, :] + rays_d[..., None, :] * \
z_vals[..., :, None] # [N_rays, N_samples, 3]
# Add the time dimension to xyz.
pts_ref = torch.cat([pts, torch.ones_like(pts[..., 0:1]) * t], -1)
# First pass: we have the StaticNeRF results
raw_s = network_query_fn_s(pts_ref[..., :3], viewdirs, network_fn_s)
# raw_s: [N_rays, N_samples, 5]
# raw_s_rgb: [N_rays, N_samples, 0:3]
# raw_s_a: [N_rays, N_samples, 3:4]
# raw_s_blending: [N_rays, N_samples, 4:5]
# Second pass: we have the DynamicNeRF results and the blending weight
raw_d = network_query_fn_d(pts_ref, viewdirs, network_fn_d)
# raw_d: [N_rays, N_samples, 11]
# raw_d_rgb: [N_rays, N_samples, 0:3]
# raw_d_a: [N_rays, N_samples, 3:4]
# sceneflow_b: [N_rays, N_samples, 4:7]
# sceneflow_f: [N_rays, N_samples, 7:10]
# raw_d_blending: [N_rays, N_samples, 10:11]
if pretrain:
rgb_map_s, _ = raw2outputs_d(raw_s[..., :4],
z_vals,
rays_d,
raw_noise_std)
ret = {'rgb_map_s': rgb_map_s}
return ret
raw_s_rgba = raw_s[..., :4]
raw_d_rgba = raw_d[..., :4]
# We need the sceneflow from the dynamicNeRF.
sceneflow_b = raw_d[..., 4:7]
sceneflow_f = raw_d[..., 7:10]
if DyNeRF_blending:
blending = raw_d[..., 10]
else:
blending = raw_s[..., 4]
# if sfmask:
# sceneflow_f = sceneflow_f * blending.detach()[..., None]
# sceneflow_b = sceneflow_b * blending.detach()[..., None]
# Rendering.
rgb_map_full, depth_map_full, acc_map_full, weights_full, \
rgb_map_s, depth_map_s, acc_map_s, weights_s, \
rgb_map_d, depth_map_d, acc_map_d, weights_d, \
dynamicness_map = raw2outputs(raw_s_rgba,
raw_d_rgba,
blending,
z_vals,
rays_d,
raw_noise_std)
ret = {'rgb_map_full': rgb_map_full, 'depth_map_full': depth_map_full, 'acc_map_full': acc_map_full, 'weights_full': weights_full,
'rgb_map_s': rgb_map_s, 'depth_map_s': depth_map_s, 'acc_map_s': acc_map_s, 'weights_s': weights_s,
'rgb_map_d': rgb_map_d, 'depth_map_d': depth_map_d, 'acc_map_d': acc_map_d, 'weights_d': weights_d,
'dynamicness_map': dynamicness_map}
t_interval = 1. / num_img * 2.
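# Offset in the normalized time coordinate corresponding to one frame step; used to
# query the DynamicNeRF at the neighboring time instants below.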
pts_f = torch.cat([pts + sceneflow_f, torch.ones_like(pts[..., 0:1]) * (t + t_interval)], -1)
pts_b = torch.cat([pts + sceneflow_b, torch.ones_like(pts[..., 0:1]) * (t - t_interval)], -1)
ret['sceneflow_b'] = sceneflow_b
ret['sceneflow_f'] = sceneflow_f
ret['raw_pts'] = pts_ref[..., :3]
ret['raw_pts_f'] = pts_f[..., :3]
ret['raw_pts_b'] = pts_b[..., :3]
ret['blending'] = blending
# Third pass: we have the DynamicNeRF results at time t - 1
raw_d_b = network_query_fn_d(pts_b, viewdirs, network_fn_d)
raw_d_b_rgba = raw_d_b[..., :4]
sceneflow_b_b = raw_d_b[..., 4:7]
sceneflow_b_f = raw_d_b[..., 7:10]
# Rendering t - 1
rgb_map_d_b, weights_d_b = raw2outputs_d(raw_d_b_rgba,
z_vals,
rays_d,
raw_noise_std)
ret['sceneflow_b_f'] = sceneflow_b_f
ret['rgb_map_d_b'] = rgb_map_d_b
ret['acc_map_d_b'] = torch.abs(torch.sum(weights_d_b - weights_d, -1))
# Fourth pass: we have the DynamicNeRF results at time t + 1
raw_d_f = network_query_fn_d(pts_f, viewdirs, network_fn_d)
raw_d_f_rgba = raw_d_f[..., :4]
sceneflow_f_b = raw_d_f[..., 4:7]
sceneflow_f_f = raw_d_f[..., 7:10]
rgb_map_d_f, weights_d_f = raw2outputs_d(raw_d_f_rgba,
z_vals,
rays_d,
raw_noise_std)
ret['sceneflow_f_b'] = sceneflow_f_b
ret['rgb_map_d_f'] = rgb_map_d_f
ret['acc_map_d_f'] = torch.abs(torch.sum(weights_d_f - weights_d, -1))
if inference:
return ret
# Also consider time t - 2 and t + 2 (Learn from NSFF)
# Fifth pass: we have the DynamicNeRF results at time t - 2
pts_b_b = torch.cat([pts_b[..., :3] + sceneflow_b_b, torch.ones_like(pts[..., 0:1]) * (t - t_interval * 2)], -1)
ret['raw_pts_b_b'] = pts_b_b[..., :3]
if chain_5frames:
raw_d_b_b = network_query_fn_d(pts_b_b, viewdirs, network_fn_d)
raw_d_b_b_rgba = raw_d_b_b[..., :4]
rgb_map_d_b_b, _ = raw2outputs_d(raw_d_b_b_rgba,
z_vals,
rays_d,
raw_noise_std)
ret['rgb_map_d_b_b'] = rgb_map_d_b_b
# Sixth pass: we have the DynamicNeRF results at time t + 2
pts_f_f = torch.cat([pts_f[..., :3] + sceneflow_f_f, torch.ones_like(pts[..., 0:1]) * (t + t_interval * 2)], -1)
ret['raw_pts_f_f'] = pts_f_f[..., :3]
if chain_5frames:
raw_d_f_f = network_query_fn_d(pts_f_f, viewdirs, network_fn_d)
raw_d_f_f_rgba = raw_d_f_f[..., :4]
rgb_map_d_f_f, _ = raw2outputs_d(raw_d_f_f_rgba,
z_vals,
rays_d,
raw_noise_std)
ret['rgb_map_d_f_f'] = rgb_map_d_f_f
for k in ret:
if torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any():
print(f"! [Numerical Error] {k} contains nan or inf.")
import ipdb; ipdb.set_trace()
return ret
================================================
FILE: run_nerf.py
================================================
import os
import time
import torch
import imageio
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from render_utils import *
from run_nerf_helpers import *
from load_llff import *
from utils.flow_utils import flow_to_image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def config_parser():
import configargparse
parser = configargparse.ArgumentParser()
parser.add_argument('--config', is_config_file=True,
help='config file path')
parser.add_argument("--expname", type=str,
help='experiment name')
parser.add_argument("--basedir", type=str, default='./logs/',
help='where to store ckpts and logs')
parser.add_argument("--datadir", type=str, default='./data/llff/fern',
help='input data directory')
# training options
parser.add_argument("--netdepth", type=int, default=8,
help='layers in network')
parser.add_argument("--netwidth", type=int, default=256,
help='channels per layer')
parser.add_argument("--netdepth_fine", type=int, default=8,
help='layers in fine network')
parser.add_argument("--netwidth_fine", type=int, default=256,
help='channels per layer in fine network')
parser.add_argument("--N_rand", type=int, default=32*32*4,
help='batch size (number of random rays per gradient step)')
parser.add_argument("--lrate", type=float, default=5e-4,
help='learning rate')
parser.add_argument("--lrate_decay", type=int, default=300000,
help='exponential learning rate decay')
parser.add_argument("--chunk", type=int, default=1024*128,
help='number of rays processed in parallel, decrease if running out of memory')
parser.add_argument("--netchunk", type=int, default=1024*128,
help='number of pts sent through network in parallel, decrease if running out of memory')
parser.add_argument("--no_reload", action='store_true',
help='do not reload weights from saved ckpt')
parser.add_argument("--ft_path", type=str, default=None,
help='specific weights npy file to reload for coarse network')
parser.add_argument("--random_seed", type=int, default=1,
help='fix random seed for repeatability')
# rendering options
parser.add_argument("--N_samples", type=int, default=64,
help='number of coarse samples per ray')
parser.add_argument("--N_importance", type=int, default=0,
help='number of additional fine samples per ray')
parser.add_argument("--perturb", type=float, default=1.,
help='set to 0. for no jitter, 1. for jitter')
parser.add_argument("--use_viewdirs", action='store_true',
help='use full 5D input instead of 3D')
parser.add_argument("--use_viewdirsDyn", action='store_true',
help='use full 5D input instead of 3D for D-NeRF')
parser.add_argument("--i_embed", type=int, default=0,
help='set 0 for default positional encoding, -1 for none')
parser.add_argument("--multires", type=int, default=10,
help='log2 of max freq for positional encoding (3D location)')
parser.add_argument("--multires_views", type=int, default=4,
help='log2 of max freq for positional encoding (2D direction)')
parser.add_argument("--raw_noise_std", type=float, default=0.,
help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
parser.add_argument("--render_only", action='store_true',
help='do not optimize, reload weights and render out render_poses path')
# dataset options
parser.add_argument("--dataset_type", type=str, default='llff',
help='options: llff')
# llff flags
parser.add_argument("--factor", type=int, default=8,
help='downsample factor for LLFF images')
parser.add_argument("--no_ndc", action='store_true',
help='do not use normalized device coordinates (set for non-forward facing scenes)')
parser.add_argument("--lindisp", action='store_true',
help='sampling linearly in disparity rather than depth')
parser.add_argument("--spherify", action='store_true',
help='set for spherical 360 scenes')
# logging/saving options
parser.add_argument("--i_print", type=int, default=500,
help='frequency of console printout and metric logging')
parser.add_argument("--i_img", type=int, default=500,
help='frequency of tensorboard image logging')
parser.add_argument("--i_weights", type=int, default=10000,
help='frequency of weight ckpt saving')
parser.add_argument("--i_testset", type=int, default=50000,
help='frequency of testset saving')
parser.add_argument("--i_video", type=int, default=50000,
help='frequency of render_poses video saving')
parser.add_argument("--N_iters", type=int, default=1000000,
help='number of training iterations')
# Dynamic NeRF lambdas
parser.add_argument("--dynamic_loss_lambda", type=float, default=1.,
help='lambda of dynamic loss')
parser.add_argument("--static_loss_lambda", type=float, default=1.,
help='lambda of static loss')
parser.add_argument("--full_loss_lambda", type=float, default=3.,
help='lambda of full loss')
parser.add_argument("--depth_loss_lambda", type=float, default=0.04,
help='lambda of depth loss')
parser.add_argument("--order_loss_lambda", type=float, default=0.1,
help='lambda of order loss')
parser.add_argument("--flow_loss_lambda", type=float, default=0.02,
help='lambda of optical flow loss')
parser.add_argument("--slow_loss_lambda", type=float, default=0.1,
help='lambda of sf slow regularization')
parser.add_argument("--smooth_loss_lambda", type=float, default=0.1,
help='lambda of sf smooth regularization')
parser.add_argument("--consistency_loss_lambda", type=float, default=0.1,
help='lambda of sf cycle consistency regularization')
parser.add_argument("--mask_loss_lambda", type=float, default=0.1,
help='lambda of the mask loss')
parser.add_argument("--sparse_loss_lambda", type=float, default=0.1,
help='lambda of sparse loss')
parser.add_argument("--DyNeRF_blending", action='store_true',
help='use Dynamic NeRF to predict blending weight')
parser.add_argument("--pretrain", action='store_true',
                        help='Pretrain the StaticNeRF')
parser.add_argument("--ft_path_S", type=str, default=None,
help='specific weights npy file to reload for StaticNeRF')
# For rendering teasers
parser.add_argument("--frame2dolly", type=int, default=-1,
help='choose frame to perform dolly zoom')
parser.add_argument("--x_trans_multiplier", type=float, default=1.,
help='x_trans_multiplier')
parser.add_argument("--y_trans_multiplier", type=float, default=0.33,
help='y_trans_multiplier')
parser.add_argument("--z_trans_multiplier", type=float, default=5.,
help='z_trans_multiplier')
parser.add_argument("--num_novelviews", type=int, default=60,
help='num_novelviews')
parser.add_argument("--focal_decrease", type=float, default=200,
help='focal_decrease')
return parser
def train():
parser = config_parser()
args = parser.parse_args()
if args.random_seed is not None:
print('Fixing random seed', args.random_seed)
np.random.seed(args.random_seed)
# Load data
if args.dataset_type == 'llff':
frame2dolly = args.frame2dolly
images, invdepths, masks, poses, bds, \
render_poses, render_focals, grids = load_llff_data(args, args.datadir,
args.factor,
frame2dolly=frame2dolly,
recenter=True, bd_factor=.9,
spherify=args.spherify)
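        # The LLFF loader packs [H, W, focal] into the last column of each 3x5 pose;
        # the leading 3x4 block is the camera-to-world matrix.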
hwf = poses[0, :3, -1]
poses = poses[:, :3, :4]
num_img = float(poses.shape[0])
assert len(poses) == len(images)
print('Loaded llff', images.shape,
render_poses.shape, hwf, args.datadir)
# Use all views to train
i_train = np.array([i for i in np.arange(int(images.shape[0]))])
print('DEFINING BOUNDS')
if args.no_ndc:
raise NotImplementedError
near = np.ndarray.min(bds) * .9
far = np.ndarray.max(bds) * 1.
else:
near = 0.
far = 1.
print('NEAR FAR', near, far)
else:
print('Unknown dataset type', args.dataset_type, 'exiting')
return
# Cast intrinsics to right types
H, W, focal = hwf
H, W = int(H), int(W)
hwf = [H, W, focal]
# Create log dir and copy the config file
basedir = args.basedir
expname = args.expname
os.makedirs(os.path.join(basedir, expname), exist_ok=True)
if not args.render_only:
f = os.path.join(basedir, expname, 'args.txt')
with open(f, 'w') as file:
for arg in sorted(vars(args)):
attr = getattr(args, arg)
file.write('{} = {}\n'.format(arg, attr))
if args.config is not None:
f = os.path.join(basedir, expname, 'config.txt')
with open(f, 'w') as file:
file.write(open(args.config, 'r').read())
# Create nerf model
render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
global_step = start
bds_dict = {
'near': near,
'far': far,
'num_img': num_img,
}
render_kwargs_train.update(bds_dict)
render_kwargs_test.update(bds_dict)
# Short circuit if only rendering out from trained model
if args.render_only:
print('RENDER ONLY')
i = start - 1
# Change time and change view at the same time.
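        # Timestamps sweep forward through the training frames and then backward
        # (ping-pong), each repeated 4x, so the rendered video loops smoothly in time.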
time2render = np.concatenate((np.repeat((i_train / float(num_img) * 2. - 1.0), 4),
np.repeat((i_train / float(num_img) * 2. - 1.0)[::-1][1:-1], 4)))
if len(time2render) > len(render_poses):
pose2render = np.tile(render_poses, (int(np.ceil(len(time2render) / len(render_poses))), 1, 1))
pose2render = pose2render[:len(time2render)]
pose2render = torch.Tensor(pose2render)
else:
time2render = np.tile(time2render, int(np.ceil(len(render_poses) / len(time2render))))
time2render = time2render[:len(render_poses)]
pose2render = torch.Tensor(render_poses)
result_type = 'novelviewtime'
testsavedir = os.path.join(
basedir, expname, result_type + '_{:06d}'.format(i))
os.makedirs(testsavedir, exist_ok=True)
with torch.no_grad():
ret = render_path(pose2render, time2render,
hwf, args.chunk, render_kwargs_test, savedir=testsavedir)
moviebase = os.path.join(
testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))
save_res(moviebase, ret)
# Fix view (first view) and change time.
pose2render = torch.Tensor(poses[0:1, ...]).expand([int(num_img), 3, 4])
time2render = i_train / float(num_img) * 2. - 1.0
result_type = 'testset_view000'
testsavedir = os.path.join(
basedir, expname, result_type + '_{:06d}'.format(i))
os.makedirs(testsavedir, exist_ok=True)
with torch.no_grad():
ret = render_path(pose2render, time2render,
hwf, args.chunk, render_kwargs_test, savedir=testsavedir)
moviebase = os.path.join(
testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))
save_res(moviebase, ret)
return
N_rand = args.N_rand
# Move training data to GPU
images = torch.Tensor(images)
invdepths = torch.Tensor(invdepths)
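    # Invert the loaded motion masks so that 1 marks static regions and 0 marks dynamic
    # ones (pixels with mask >= 0.5 are sampled as static below, mask < 0.5 as dynamic).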
masks = 1.0 - torch.Tensor(masks)
poses = torch.Tensor(poses)
grids = torch.Tensor(grids)
print('Begin')
print('TRAIN views are', i_train)
# Summary writers
writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))
decay_iteration = max(25, num_img)
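    # decay_iteration (in units of 1000 iterations) schedules training: it controls when
    # the flow/depth loss weights decay (Temp), when the mask loss is switched off, and
    # when 5-frame chaining starts in the main loop below.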
# Pre-train StaticNeRF
if args.pretrain:
render_kwargs_train.update({'pretrain': True})
# Pre-train StaticNeRF first and use DynamicNeRF to blend
assert args.DyNeRF_blending == True
if args.ft_path_S is not None and args.ft_path_S != 'None':
# Load Pre-trained StaticNeRF
ckpt_path = args.ft_path_S
print('Reloading StaticNeRF from', ckpt_path)
ckpt = torch.load(ckpt_path)
render_kwargs_train['network_fn_s'].load_state_dict(ckpt['network_fn_s_state_dict'])
else:
# Train StaticNeRF from scratch
for i in range(args.N_iters):
time0 = time.time()
# No raybatching as we need to take random rays from one image at a time
img_i = np.random.choice(i_train)
t = img_i / num_img * 2. - 1.0 # time of the current frame
target = images[img_i]
pose = poses[img_i, :3, :4]
mask = masks[img_i] # Static region mask
rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)
coords_s = torch.stack((torch.where(mask >= 0.5)), -1)
select_inds_s = np.random.choice(coords_s.shape[0], size=[N_rand], replace=False)
select_coords = coords_s[select_inds_s]
def select_batch(value, select_coords=select_coords):
return value[select_coords[:, 0], select_coords[:, 1]]
rays_o = select_batch(rays_o) # (N_rand, 3)
rays_d = select_batch(rays_d) # (N_rand, 3)
target_rgb = select_batch(target)
batch_mask = select_batch(mask[..., None])
batch_rays = torch.stack([rays_o, rays_d], 0)
##### Core optimization loop #####
ret = render(t,
False,
H, W, focal,
chunk=args.chunk,
rays=batch_rays,
**render_kwargs_train)
optimizer.zero_grad()
# Compute MSE loss between rgb_s and true RGB.
img_s_loss = img2mse(ret['rgb_map_s'], target_rgb)
psnr_s = mse2psnr(img_s_loss)
loss = args.static_loss_lambda * img_s_loss
loss.backward()
optimizer.step()
# Learning rate decay.
decay_rate = 0.1
decay_steps = args.lrate_decay
new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
for param_group in optimizer.param_groups:
param_group['lr'] = new_lrate
dt = time.time() - time0
print(f"Pretraining step: {global_step}, Loss: {loss}, Time: {dt}, expname: {expname}")
if i % args.i_print == 0:
writer.add_scalar("loss", loss.item(), i)
writer.add_scalar("lr", new_lrate, i)
writer.add_scalar("psnr_s", psnr_s.item(), i)
if i % args.i_img == 0:
target = images[img_i]
pose = poses[img_i, :3, :4]
mask = masks[img_i]
with torch.no_grad():
ret = render(t,
False,
H, W, focal,
chunk=1024*16,
c2w=pose,
**render_kwargs_test)
                    # Log the pretraining validation view to Tensorboard.
writer.add_image("rgb_holdout", target, global_step=i, dataformats='HWC')
writer.add_image("mask", mask, global_step=i, dataformats='HW')
writer.add_image("rgb_s", torch.clamp(ret['rgb_map_s'], 0., 1.), global_step=i, dataformats='HWC')
writer.add_image("depth_s", normalize_depth(ret['depth_map_s']), global_step=i, dataformats='HW')
writer.add_image("acc_s", ret['acc_map_s'], global_step=i, dataformats='HW')
global_step += 1
# Save the pretrained weight
torch.save({
'global_step': global_step,
'network_fn_s_state_dict': render_kwargs_train['network_fn_s'].state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, os.path.join(basedir, expname, 'Pretrained_S.tar'))
# Reset
render_kwargs_train.update({'pretrain': False})
global_step = start
# Fix the StaticNeRF and only train the DynamicNeRF
grad_vars_d = list(render_kwargs_train['network_fn_d'].parameters())
optimizer = torch.optim.Adam(params=grad_vars_d, lr=args.lrate, betas=(0.9, 0.999))
for i in range(start, args.N_iters):
time0 = time.time()
# Use frames at t-2, t-1, t, t+1, t+2 (adapted from NSFF)
if i < decay_iteration * 2000:
chain_5frames = False
else:
chain_5frames = True
# Lambda decay.
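        # Temp decays by 10x every (decay_iteration * 1000) iterations and progressively
        # down-weights the optical flow and depth losses.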
Temp = 1. / (10 ** (i // (decay_iteration * 1000)))
if i % (decay_iteration * 1000) == 0:
torch.cuda.empty_cache()
# No raybatching as we need to take random rays from one image at a time
img_i = np.random.choice(i_train)
t = img_i / num_img * 2. - 1.0 # time of the current frame
target = images[img_i]
pose = poses[img_i, :3, :4]
mask = masks[img_i] # Static region mask
invdepth = invdepths[img_i]
grid = grids[img_i]
rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)
coords_d = torch.stack((torch.where(mask < 0.5)), -1)
coords_s = torch.stack((torch.where(mask >= 0.5)), -1)
coords = torch.stack((torch.where(mask > -1)), -1)
# Evenly sample dynamic region and static region
select_inds_d = np.random.choice(coords_d.shape[0], size=[min(len(coords_d), N_rand//2)], replace=False)
select_inds_s = np.random.choice(coords_s.shape[0], size=[N_rand//2], replace=False)
select_coords = torch.cat([coords_s[select_inds_s],
coords_d[select_inds_d]], 0)
def select_batch(value, select_coords=select_coords):
return value[select_coords[:, 0], select_coords[:, 1]]
rays_o = select_batch(rays_o) # (N_rand, 3)
rays_d = select_batch(rays_d) # (N_rand, 3)
target_rgb = select_batch(target)
batch_grid = select_batch(grid) # (N_rand, 8)
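        # batch_grid channels (see get_grid): [0:2] pixel coords, [2:4] forward flow,
        # [4:5] forward flow mask, [5:7] backward flow, [7:8] backward flow mask.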
batch_mask = select_batch(mask[..., None])
batch_invdepth = select_batch(invdepth)
batch_rays = torch.stack([rays_o, rays_d], 0)
##### Core optimization loop #####
ret = render(t,
chain_5frames,
H, W, focal,
chunk=args.chunk,
rays=batch_rays,
**render_kwargs_train)
optimizer.zero_grad()
loss = 0
loss_dict = {}
# Compute MSE loss between rgb_full and true RGB.
img_loss = img2mse(ret['rgb_map_full'], target_rgb)
psnr = mse2psnr(img_loss)
loss_dict['psnr'] = psnr
loss_dict['img_loss'] = img_loss
loss += args.full_loss_lambda * loss_dict['img_loss']
# Compute MSE loss between rgb_s and true RGB.
img_s_loss = img2mse(ret['rgb_map_s'], target_rgb, batch_mask)
psnr_s = mse2psnr(img_s_loss)
loss_dict['psnr_s'] = psnr_s
loss_dict['img_s_loss'] = img_s_loss
loss += args.static_loss_lambda * loss_dict['img_s_loss']
# Compute MSE loss between rgb_d and true RGB.
img_d_loss = img2mse(ret['rgb_map_d'], target_rgb)
psnr_d = mse2psnr(img_d_loss)
loss_dict['psnr_d'] = psnr_d
loss_dict['img_d_loss'] = img_d_loss
loss += args.dynamic_loss_lambda * loss_dict['img_d_loss']
# Compute MSE loss between rgb_d_f and true RGB.
img_d_f_loss = img2mse(ret['rgb_map_d_f'], target_rgb)
psnr_d_f = mse2psnr(img_d_f_loss)
loss_dict['psnr_d_f'] = psnr_d_f
loss_dict['img_d_f_loss'] = img_d_f_loss
loss += args.dynamic_loss_lambda * loss_dict['img_d_f_loss']
# Compute MSE loss between rgb_d_b and true RGB.
img_d_b_loss = img2mse(ret['rgb_map_d_b'], target_rgb)
psnr_d_b = mse2psnr(img_d_b_loss)
loss_dict['psnr_d_b'] = psnr_d_b
loss_dict['img_d_b_loss'] = img_d_b_loss
loss += args.dynamic_loss_lambda * loss_dict['img_d_b_loss']
# Motion loss.
        # Compute EPE between induced flow and true flow (forward flow).
# The last frame does not have forward flow.
if img_i < num_img - 1:
pts_f = ret['raw_pts_f']
weight = ret['weights_d']
pose_f = poses[img_i + 1, :3, :4]
induced_flow_f = induce_flow(H, W, focal, pose_f, weight, pts_f, batch_grid[..., :2])
flow_f_loss = img2mae(induced_flow_f, batch_grid[:, 2:4], batch_grid[:, 4:5])
loss_dict['flow_f_loss'] = flow_f_loss
loss += args.flow_loss_lambda * Temp * loss_dict['flow_f_loss']
        # Compute EPE between induced flow and true flow (backward flow).
# The first frame does not have backward flow.
if img_i > 0:
pts_b = ret['raw_pts_b']
weight = ret['weights_d']
pose_b = poses[img_i - 1, :3, :4]
induced_flow_b = induce_flow(H, W, focal, pose_b, weight, pts_b, batch_grid[..., :2])
flow_b_loss = img2mae(induced_flow_b, batch_grid[:, 5:7], batch_grid[:, 7:8])
loss_dict['flow_b_loss'] = flow_b_loss
loss += args.flow_loss_lambda * Temp * loss_dict['flow_b_loss']
# Slow scene flow. The forward and backward sceneflow should be small.
slow_loss = L1(ret['sceneflow_b']) + L1(ret['sceneflow_f'])
loss_dict['slow_loss'] = slow_loss
loss += args.slow_loss_lambda * loss_dict['slow_loss']
# Smooth scene flow. The summation of the forward and backward sceneflow should be small.
smooth_loss = compute_sf_smooth_loss(ret['raw_pts'],
ret['raw_pts_f'],
ret['raw_pts_b'],
H, W, focal)
loss_dict['smooth_loss'] = smooth_loss
loss += args.smooth_loss_lambda * loss_dict['smooth_loss']
# Spatial smooth scene flow. (loss adapted from NSFF)
sp_smooth_loss = compute_sf_smooth_s_loss(ret['raw_pts'], ret['raw_pts_f'], H, W, focal) \
+ compute_sf_smooth_s_loss(ret['raw_pts'], ret['raw_pts_b'], H, W, focal)
loss_dict['sp_smooth_loss'] = sp_smooth_loss
loss += args.smooth_loss_lambda * loss_dict['sp_smooth_loss']
# Consistency loss.
consistency_loss = L1(ret['sceneflow_f'] + ret['sceneflow_f_b']) + \
L1(ret['sceneflow_b'] + ret['sceneflow_b_f'])
loss_dict['consistency_loss'] = consistency_loss
loss += args.consistency_loss_lambda * loss_dict['consistency_loss']
# Mask loss.
mask_loss = L1(ret['blending'][batch_mask[:, 0].type(torch.bool)]) + \
img2mae(ret['dynamicness_map'][..., None], 1 - batch_mask)
loss_dict['mask_loss'] = mask_loss
if i < decay_iteration * 1000:
loss += args.mask_loss_lambda * loss_dict['mask_loss']
# Sparsity loss.
sparse_loss = entropy(ret['weights_d']) + entropy(ret['blending'])
loss_dict['sparse_loss'] = sparse_loss
loss += args.sparse_loss_lambda * loss_dict['sparse_loss']
# Depth constraint
        # Depth in NDC space equals negative disparity in Euclidean space.
depth_loss = compute_depth_loss(ret['depth_map_d'], -batch_invdepth)
loss_dict['depth_loss'] = depth_loss
loss += args.depth_loss_lambda * Temp * loss_dict['depth_loss']
# Order loss
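        # In static regions the dynamic model's depth should agree with the (detached)
        # StaticNeRF depth.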
order_loss = torch.mean(torch.square(ret['depth_map_d'][batch_mask[:, 0].type(torch.bool)] - \
ret['depth_map_s'].detach()[batch_mask[:, 0].type(torch.bool)]))
loss_dict['order_loss'] = order_loss
loss += args.order_loss_lambda * loss_dict['order_loss']
sf_smooth_loss = compute_sf_smooth_loss(ret['raw_pts_b'],
ret['raw_pts'],
ret['raw_pts_b_b'],
H, W, focal) + \
compute_sf_smooth_loss(ret['raw_pts_f'],
ret['raw_pts_f_f'],
ret['raw_pts'],
H, W, focal)
loss_dict['sf_smooth_loss'] = sf_smooth_loss
loss += args.smooth_loss_lambda * loss_dict['sf_smooth_loss']
if chain_5frames:
img_d_b_b_loss = img2mse(ret['rgb_map_d_b_b'], target_rgb)
loss_dict['img_d_b_b_loss'] = img_d_b_b_loss
loss += args.dynamic_loss_lambda * loss_dict['img_d_b_b_loss']
img_d_f_f_loss = img2mse(ret['rgb_map_d_f_f'], target_rgb)
loss_dict['img_d_f_f_loss'] = img_d_f_f_loss
loss += args.dynamic_loss_lambda * loss_dict['img_d_f_f_loss']
loss.backward()
optimizer.step()
# Learning rate decay.
decay_rate = 0.1
decay_steps = args.lrate_decay
new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
for param_group in optimizer.param_groups:
param_group['lr'] = new_lrate
dt = time.time() - time0
print(f"Step: {global_step}, Loss: {loss}, Time: {dt}, chain_5frames: {chain_5frames}, expname: {expname}")
# Rest is logging
if i % args.i_weights==0:
path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
if args.N_importance > 0:
raise NotImplementedError
else:
torch.save({
'global_step': global_step,
'network_fn_d_state_dict': render_kwargs_train['network_fn_d'].state_dict(),
'network_fn_s_state_dict': render_kwargs_train['network_fn_s'].state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, path)
print('Saved weights at', path)
if i % args.i_video == 0 and i > 0:
# Change time and change view at the same time.
time2render = np.concatenate((np.repeat((i_train / float(num_img) * 2. - 1.0), 4),
np.repeat((i_train / float(num_img) * 2. - 1.0)[::-1][1:-1], 4)))
if len(time2render) > len(render_poses):
pose2render = np.tile(render_poses, (int(np.ceil(len(time2render) / len(render_poses))), 1, 1))
pose2render = pose2render[:len(time2render)]
pose2render = torch.Tensor(pose2render)
else:
time2render = np.tile(time2render, int(np.ceil(len(render_poses) / len(time2render))))
time2render = time2render[:len(render_poses)]
pose2render = torch.Tensor(render_poses)
result_type = 'novelviewtime'
testsavedir = os.path.join(
basedir, expname, result_type + '_{:06d}'.format(i))
os.makedirs(testsavedir, exist_ok=True)
with torch.no_grad():
ret = render_path(pose2render, time2render,
hwf, args.chunk, render_kwargs_test, savedir=testsavedir)
moviebase = os.path.join(
testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))
save_res(moviebase, ret)
if i % args.i_testset == 0 and i > 0:
# Change view and time.
pose2render = torch.Tensor(poses)
time2render = i_train / float(num_img) * 2. - 1.0
result_type = 'testset'
testsavedir = os.path.join(
basedir, expname, result_type + '_{:06d}'.format(i))
os.makedirs(testsavedir, exist_ok=True)
with torch.no_grad():
ret = render_path(pose2render, time2render,
hwf, args.chunk, render_kwargs_test, savedir=testsavedir,
flows_gt_f=grids[:, :, :, 2:4], flows_gt_b=grids[:, :, :, 5:7])
moviebase = os.path.join(
testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))
save_res(moviebase, ret)
# Fix view (first view) and change time.
pose2render = torch.Tensor(poses[0:1, ...].expand([int(num_img), 3, 4]))
time2render = i_train / float(num_img) * 2. - 1.0
result_type = 'testset_view000'
testsavedir = os.path.join(
basedir, expname, result_type + '_{:06d}'.format(i))
os.makedirs(testsavedir, exist_ok=True)
with torch.no_grad():
ret = render_path(pose2render, time2render,
hwf, args.chunk, render_kwargs_test, savedir=testsavedir)
moviebase = os.path.join(
testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))
save_res(moviebase, ret)
# Fix time (the first timestamp) and change view.
pose2render = torch.Tensor(poses)
time2render = np.tile(i_train[0], [int(num_img)]) / float(num_img) * 2. - 1.0
result_type = 'testset_time000'
testsavedir = os.path.join(
basedir, expname, result_type + '_{:06d}'.format(i))
os.makedirs(testsavedir, exist_ok=True)
with torch.no_grad():
ret = render_path(pose2render, time2render,
hwf, args.chunk, render_kwargs_test, savedir=testsavedir)
moviebase = os.path.join(
testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))
save_res(moviebase, ret)
if i % args.i_print == 0:
writer.add_scalar("loss", loss.item(), i)
writer.add_scalar("lr", new_lrate, i)
writer.add_scalar("Temp", Temp, i)
for loss_key in loss_dict:
writer.add_scalar(loss_key, loss_dict[loss_key].item(), i)
if i % args.i_img == 0:
# Log a rendered training view to Tensorboard.
# img_i = np.random.choice(i_train[1:-1])
target = images[img_i]
pose = poses[img_i, :3, :4]
mask = masks[img_i]
grid = grids[img_i]
invdepth = invdepths[img_i]
flow_f_img = flow_to_image(grid[..., 2:4].cpu().numpy())
flow_b_img = flow_to_image(grid[..., 5:7].cpu().numpy())
with torch.no_grad():
ret = render(t,
False,
H, W, focal,
chunk=1024*16,
c2w=pose,
**render_kwargs_test)
# The last frame does not have forward flow.
pose_f = poses[min(img_i + 1, int(num_img) - 1), :3, :4]
induced_flow_f = induce_flow(H, W, focal, pose_f, ret['weights_d'], ret['raw_pts_f'], grid[..., :2])
# The first frame does not have backward flow.
pose_b = poses[max(img_i - 1, 0), :3, :4]
induced_flow_b = induce_flow(H, W, focal, pose_b, ret['weights_d'], ret['raw_pts_b'], grid[..., :2])
induced_flow_f_img = flow_to_image(induced_flow_f.cpu().numpy())
induced_flow_b_img = flow_to_image(induced_flow_b.cpu().numpy())
psnr = mse2psnr(img2mse(ret['rgb_map_full'], target))
# Save out the validation image for Tensorboard-free monitoring
testimgdir = os.path.join(basedir, expname, 'tboard_val_imgs')
            os.makedirs(testimgdir, exist_ok=True)  # ensure the directory exists even when resuming
imageio.imwrite(os.path.join(testimgdir, '{:06d}.png'.format(i)), to8b(ret['rgb_map_full'].cpu().numpy()))
writer.add_scalar("psnr_holdout", psnr.item(), i)
writer.add_image("rgb_holdout", target, global_step=i, dataformats='HWC')
writer.add_image("mask", mask, global_step=i, dataformats='HW')
writer.add_image("disp", torch.clamp(invdepth / percentile(invdepth, 97), 0., 1.), global_step=i, dataformats='HW')
writer.add_image("rgb", torch.clamp(ret['rgb_map_full'], 0., 1.), global_step=i, dataformats='HWC')
writer.add_image("depth", normalize_depth(ret['depth_map_full']), global_step=i, dataformats='HW')
writer.add_image("acc", ret['acc_map_full'], global_step=i, dataformats='HW')
writer.add_image("rgb_s", torch.clamp(ret['rgb_map_s'], 0., 1.), global_step=i, dataformats='HWC')
writer.add_image("depth_s", normalize_depth(ret['depth_map_s']), global_step=i, dataformats='HW')
writer.add_image("acc_s", ret['acc_map_s'], global_step=i, dataformats='HW')
writer.add_image("rgb_d", torch.clamp(ret['rgb_map_d'], 0., 1.), global_step=i, dataformats='HWC')
writer.add_image("depth_d", normalize_depth(ret['depth_map_d']), global_step=i, dataformats='HW')
writer.add_image("acc_d", ret['acc_map_d'], global_step=i, dataformats='HW')
writer.add_image("induced_flow_f", induced_flow_f_img, global_step=i, dataformats='HWC')
writer.add_image("induced_flow_b", induced_flow_b_img, global_step=i, dataformats='HWC')
writer.add_image("flow_f_gt", flow_f_img, global_step=i, dataformats='HWC')
writer.add_image("flow_b_gt", flow_b_img, global_step=i, dataformats='HWC')
writer.add_image("dynamicness", ret['dynamicness_map'], global_step=i, dataformats='HW')
global_step += 1
if __name__ == '__main__':
torch.set_default_tensor_type('torch.cuda.FloatTensor')
train()
================================================
FILE: run_nerf_helpers.py
================================================
import os
import torch
import imageio
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Misc utils
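# The optional mask M in the loss helpers below weights the per-pixel error; the result
# is normalized by the mask sum (plus a small epsilon) and the number of channels.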
def img2mse(x, y, M=None):
    if M is None:
        return torch.mean((x - y) ** 2)
    else:
        return torch.sum((x - y) ** 2 * M) / (torch.sum(M) + 1e-8) / x.shape[-1]
def img2mae(x, y, M=None):
    if M is None:
        return torch.mean(torch.abs(x - y))
    else:
        return torch.sum(torch.abs(x - y) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]
def L1(x, M=None):
    if M is None:
        return torch.mean(torch.abs(x))
    else:
        return torch.sum(torch.abs(x) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]
def L2(x, M=None):
    if M is None:
        return torch.mean(x ** 2)
    else:
        return torch.sum((x ** 2) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]
def entropy(x):
return -torch.sum(x * torch.log(x + 1e-19)) / x.shape[0]
def mse2psnr(x): return -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
def to8b(x): return (255 * np.clip(x, 0, 1)).astype(np.uint8)
class Embedder:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs['input_dims']
out_dim = 0
if self.kwargs['include_input']:
embed_fns.append(lambda x: x)
out_dim += d
max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']
if self.kwargs['log_sampling']:
freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
else:
freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']:
embed_fns.append(lambda x, p_fn=p_fn,
freq=freq : p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, i=0, input_dims=3):
if i == -1:
return nn.Identity(), 3
embed_kwargs = {
'include_input': True,
'input_dims': input_dims,
'max_freq_log2': multires-1,
'num_freqs': multires,
'log_sampling': True,
'periodic_fns': [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
def embed(x, eo=embedder_obj): return eo.embed(x)
return embed, embedder_obj.out_dim
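# Illustrative example (not part of the original code): with multires=10 and 3-D inputs,
#   embed_fn, out_dim = get_embedder(10, 0, input_dims=3)
# yields out_dim = 3 + 10 * 2 * 3 = 63 channels (identity plus sin/cos at 10 frequencies
# per axis), matching the channel counts noted in create_nerf() below.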
# Dynamic NeRF model architecture
class NeRF_d(nn.Module):
def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirsDyn=True):
"""
"""
super(NeRF_d, self).__init__()
self.D = D
self.W = W
self.input_ch = input_ch
self.input_ch_views = input_ch_views
self.skips = skips
self.use_viewdirsDyn = use_viewdirsDyn
self.pts_linears = nn.ModuleList(
[nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
if self.use_viewdirsDyn:
self.feature_linear = nn.Linear(W, W)
self.alpha_linear = nn.Linear(W, 1)
self.rgb_linear = nn.Linear(W//2, 3)
else:
self.output_linear = nn.Linear(W, output_ch)
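        # 6 scene-flow channels: 3-D forward flow and 3-D backward flow per point.
        # The 1-channel blending weight is used to composite the dynamic and static models.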
self.sf_linear = nn.Linear(W, 6)
self.weight_linear = nn.Linear(W, 1)
def forward(self, x):
input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
h = input_pts
for i, l in enumerate(self.pts_linears):
h = self.pts_linears[i](h)
h = F.relu(h)
if i in self.skips:
h = torch.cat([input_pts, h], -1)
        # Scene flow should be unbounded. However, in NDC space the coordinates are
        # bounded in [-1, 1], so tanh keeps the predicted flow within that range.
sf = torch.tanh(self.sf_linear(h))
blending = torch.sigmoid(self.weight_linear(h))
if self.use_viewdirsDyn:
alpha = self.alpha_linear(h)
feature = self.feature_linear(h)
h = torch.cat([feature, input_views], -1)
for i, l in enumerate(self.views_linears):
h = self.views_linears[i](h)
h = F.relu(h)
rgb = self.rgb_linear(h)
outputs = torch.cat([rgb, alpha], -1)
else:
outputs = self.output_linear(h)
return torch.cat([outputs, sf, blending], dim=-1)
# Static NeRF model architecture
class NeRF_s(nn.Module):
def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=True):
"""
"""
super(NeRF_s, self).__init__()
self.D = D
self.W = W
self.input_ch = input_ch
self.input_ch_views = input_ch_views
self.skips = skips
self.use_viewdirs = use_viewdirs
self.pts_linears = nn.ModuleList(
[nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])
if self.use_viewdirs:
self.feature_linear = nn.Linear(W, W)
self.alpha_linear = nn.Linear(W, 1)
self.rgb_linear = nn.Linear(W//2, 3)
else:
self.output_linear = nn.Linear(W, output_ch)
self.weight_linear = nn.Linear(W, 1)
def forward(self, x):
input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
h = input_pts
for i, l in enumerate(self.pts_linears):
h = self.pts_linears[i](h)
h = F.relu(h)
if i in self.skips:
h = torch.cat([input_pts, h], -1)
blending = torch.sigmoid(self.weight_linear(h))
if self.use_viewdirs:
alpha = self.alpha_linear(h)
feature = self.feature_linear(h)
h = torch.cat([feature, input_views], -1)
for i, l in enumerate(self.views_linears):
h = self.views_linears[i](h)
h = F.relu(h)
rgb = self.rgb_linear(h)
outputs = torch.cat([rgb, alpha], -1)
else:
outputs = self.output_linear(h)
return torch.cat([outputs, blending], -1)
def batchify(fn, chunk):
"""Constructs a version of 'fn' that applies to smaller batches.
"""
if chunk is None:
return fn
def ret(inputs):
return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
return ret
def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
"""Prepares inputs and applies network 'fn'.
"""
inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
embedded = embed_fn(inputs_flat)
if viewdirs is not None:
input_dirs = viewdirs[:, None].expand(inputs[:, :, :3].shape)
input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
embedded_dirs = embeddirs_fn(input_dirs_flat)
embedded = torch.cat([embedded, embedded_dirs], -1)
outputs_flat = batchify(fn, netchunk)(embedded)
outputs = torch.reshape(outputs_flat, list(
inputs.shape[:-1]) + [outputs_flat.shape[-1]])
return outputs
def create_nerf(args):
"""Instantiate NeRF's MLP model.
"""
embed_fn_d, input_ch_d = get_embedder(args.multires, args.i_embed, 4)
# 10 * 2 * 4 + 4 = 84
# L * (sin, cos) * (x, y, z, t) + (x, y, z, t)
input_ch_views = 0
embeddirs_fn = None
if args.use_viewdirs:
embeddirs_fn, input_ch_views = get_embedder(
args.multires_views, args.i_embed, 3)
# 4 * 2 * 3 + 3 = 27
# L * (sin, cos) * (3 Cartesian viewing direction unit vector from [theta, phi]) + (3 Cartesian viewing direction unit vector from [theta, phi])
output_ch = 5 if args.N_importance > 0 else 4
skips = [4]
model_d = NeRF_d(D=args.netdepth, W=args.netwidth,
input_ch=input_ch_d, output_ch=output_ch, skips=skips,
input_ch_views=input_ch_views,
use_viewdirsDyn=args.use_viewdirsDyn).to(device)
device_ids = list(range(torch.cuda.device_count()))
model_d = torch.nn.DataParallel(model_d, device_ids=device_ids)
grad_vars = list(model_d.parameters())
embed_fn_s, input_ch_s = get_embedder(args.multires, args.i_embed, 3)
# 10 * 2 * 3 + 3 = 63
# L * (sin, cos) * (x, y, z) + (x, y, z)
model_s = NeRF_s(D=args.netdepth, W=args.netwidth,
input_ch=input_ch_s, output_ch=output_ch, skips=skips,
input_ch_views=input_ch_views,
use_viewdirs=args.use_viewdirs).to(device)
model_s = torch.nn.DataParallel(model_s, device_ids=device_ids)
grad_vars += list(model_s.parameters())
model_fine = None
if args.N_importance > 0:
raise NotImplementedError
def network_query_fn_d(inputs, viewdirs, network_fn): return run_network(
inputs, viewdirs, network_fn,
embed_fn=embed_fn_d,
embeddirs_fn=embeddirs_fn,
netchunk=args.netchunk)
def network_query_fn_s(inputs, viewdirs, network_fn): return run_network(
inputs, viewdirs, network_fn,
embed_fn=embed_fn_s,
embeddirs_fn=embeddirs_fn,
netchunk=args.netchunk)
render_kwargs_train = {
'network_query_fn_d': network_query_fn_d,
'network_query_fn_s': network_query_fn_s,
'network_fn_d': model_d,
'network_fn_s': model_s,
'perturb': args.perturb,
'N_importance': args.N_importance,
'N_samples': args.N_samples,
'use_viewdirs': args.use_viewdirs,
'raw_noise_std': args.raw_noise_std,
'inference': False,
'DyNeRF_blending': args.DyNeRF_blending,
}
# NDC only good for LLFF-style forward facing data
if args.dataset_type != 'llff' or args.no_ndc:
print('Not ndc!')
render_kwargs_train['ndc'] = False
render_kwargs_train['lindisp'] = args.lindisp
else:
render_kwargs_train['ndc'] = True
render_kwargs_test = {
k: render_kwargs_train[k] for k in render_kwargs_train}
render_kwargs_test['perturb'] = False
render_kwargs_test['raw_noise_std'] = 0.
render_kwargs_test['inference'] = True
# Create optimizer
optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
start = 0
basedir = args.basedir
expname = args.expname
if args.ft_path is not None and args.ft_path != 'None':
ckpts = [args.ft_path]
else:
ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]
print('Found ckpts', ckpts)
if len(ckpts) > 0 and not args.no_reload:
ckpt_path = ckpts[-1]
print('Reloading from', ckpt_path)
ckpt = torch.load(ckpt_path)
start = ckpt['global_step'] + 1
# optimizer.load_state_dict(ckpt['optimizer_state_dict'])
model_d.load_state_dict(ckpt['network_fn_d_state_dict'])
model_s.load_state_dict(ckpt['network_fn_s_state_dict'])
print('Resetting step to', start)
if model_fine is not None:
raise NotImplementedError
return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
# Ray helpers
def get_rays(H, W, focal, c2w):
"""Get ray origins, directions from a pinhole camera."""
i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij'
i = i.t()
j = j.t()
dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)  # dot product, equivalent to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
rays_o = c2w[:3, -1].expand(rays_d.shape)
return rays_o, rays_d
def ndc_rays(H, W, focal, near, rays_o, rays_d):
"""Normalized device coordinate rays.
Space such that the canvas is a cube with sides [-1, 1] in each axis.
Args:
H: int. Height in pixels.
W: int. Width in pixels.
focal: float. Focal length of pinhole camera.
near: float or array of shape[batch_size]. Near depth bound for the scene.
rays_o: array of shape [batch_size, 3]. Camera origin.
rays_d: array of shape [batch_size, 3]. Ray direction.
Returns:
rays_o: array of shape [batch_size, 3]. Camera origin in NDC.
rays_d: array of shape [batch_size, 3]. Ray direction in NDC.
"""
# Shift ray origins to near plane
t = -(near + rays_o[..., 2]) / rays_d[..., 2]
rays_o = rays_o + t[..., None] * rays_d
# Projection
o0 = -1./(W/(2.*focal)) * rays_o[..., 0] / rays_o[..., 2]
o1 = -1./(H/(2.*focal)) * rays_o[..., 1] / rays_o[..., 2]
o2 = 1. + 2. * near / rays_o[..., 2]
d0 = -1./(W/(2.*focal)) * \
(rays_d[..., 0]/rays_d[..., 2] - rays_o[..., 0]/rays_o[..., 2])
d1 = -1./(H/(2.*focal)) * \
(rays_d[..., 1]/rays_d[..., 2] - rays_o[..., 1]/rays_o[..., 2])
d2 = -2. * near / rays_o[..., 2]
rays_o = torch.stack([o0, o1, o2], -1)
rays_d = torch.stack([d0, d1, d2], -1)
return rays_o, rays_d
def get_grid(H, W, num_img, flows_f, flow_masks_f, flows_b, flow_masks_b):
# |--------------------| |--------------------|
# | j | | v |
# | i * | | u * |
# | | | |
# |--------------------| |--------------------|
i, j = np.meshgrid(np.arange(W, dtype=np.float32),
np.arange(H, dtype=np.float32), indexing='xy')
grid = np.empty((0, H, W, 8), np.float32)
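    # Channel layout per pixel: [i, j, flow_f_u, flow_f_v, flow_mask_f, flow_b_u, flow_b_v, flow_mask_b].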
for idx in range(num_img):
grid = np.concatenate((grid, np.stack([i,
j,
flows_f[idx, :, :, 0],
flows_f[idx, :, :, 1],
flow_masks_f[idx, :, :],
flows_b[idx, :, :, 0],
flows_b[idx, :, :, 1],
flow_masks_b[idx, :, :]], -1)[None, ...]))
return grid
def NDC2world(pts, H, W, f):
# NDC coordinate to world coordinate
pts_z = 2 / (torch.clamp(pts[..., 2:], min=-1., max=1-1e-3) - 1)
pts_x = - pts[..., 0:1] * pts_z * W / 2 / f
pts_y = - pts[..., 1:2] * pts_z * H / 2 / f
pts_world = torch.cat([pts_x, pts_y, pts_z], -1)
return pts_world
def render_3d_point(H, W, f, pose, weights, pts):
"""Render 3D position along each ray and project it to the image plane.
"""
c2w = pose
w2c = c2w[:3, :3].transpose(0, 1) # same as np.linalg.inv(c2w[:3, :3])
# Rendered 3D position in NDC coordinate
pts_map_NDC = torch.sum(weights[..., None] * pts, -2)
# NDC coordinate to world coordinate
pts_map_world = NDC2world(pts_map_NDC, H, W, f)
# World coordinate to camera coordinate
# Translate
pts_map_world = pts_map_world - c2w[:, 3]
# Rotate
pts_map_cam = torch.sum(pts_map_world[..., None, :] * w2c[:3, :3], -1)
# Camera coordinate to 2D image coordinate
pts_plane = torch.cat([pts_map_cam[..., 0:1] / (- pts_map_cam[..., 2:]) * f + W * .5,
- pts_map_cam[..., 1:2] / (- pts_map_cam[..., 2:]) * f + H * .5],
-1)
return pts_plane
def induce_flow(H, W, focal, pose_neighbor, weights, pts_3d_neighbor, pts_2d):
# Render 3D position along each ray and project it to the neighbor frame's image plane.
pts_2d_neighbor = render_3d_point(H, W, focal,
pose_neighbor,
weights,
pts_3d_neighbor)
induced_flow = pts_2d_neighbor - pts_2d
return induced_flow
def compute_depth_loss(dyn_depth, gt_depth):
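    # Scale- and shift-invariant comparison: both depth maps are normalized by their
    # median and mean absolute deviation before the MSE, since monocular depth is only
    # defined up to an unknown scale and shift (similar to the MiDaS alignment).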
t_d = torch.median(dyn_depth)
s_d = torch.mean(torch.abs(dyn_depth - t_d))
dyn_depth_norm = (dyn_depth - t_d) / s_d
t_gt = torch.median(gt_depth)
s_gt = torch.mean(torch.abs(gt_depth - t_gt))
gt_depth_norm = (gt_depth - t_gt) / s_gt
return torch.mean((dyn_depth_norm - gt_depth_norm) ** 2)
def normalize_depth(depth):
return torch.clamp(depth / percentile(depth, 97), 0., 1.)
def percentile(t, q):
"""
Return the ``q``-th percentile of the flattened input tensor's data.
CAUTION:
* Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
* Values are not interpolated, which corresponds to
``numpy.percentile(..., interpolation="nearest")``.
:param t: Input tensor.
:param q: Percentile to compute, which must be between 0 and 100 inclusive.
:return: Resulting value (scalar).
"""
k = 1 + round(.01 * float(q) * (t.numel() - 1))
result = t.view(-1).kthvalue(k).values.item()
return result
def save_res(moviebase, ret, fps=None):
    if fps is None:
if len(ret['rgbs']) < 25:
fps = 4
else:
fps = 24
for k in ret:
if 'rgbs' in k:
imageio.mimwrite(moviebase + k + '.mp4',
to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
# imageio.mimsave(moviebase + k + '.gif',
# to8b(ret[k]), format='gif', fps=fps)
elif 'depths' in k:
imageio.mimwrite(moviebase + k + '.mp4',
to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
# imageio.mimsave(moviebase + k + '.gif',
# to8b(ret[k]), format='gif', fps=fps)
elif 'disps' in k:
imageio.mimwrite(moviebase + k + '.mp4',
to8b(ret[k] / np.max(ret[k])), fps=fps, quality=8, macro_block_size=1)
# imageio.mimsave(moviebase + k + '.gif',
# to8b(ret[k] / np.max(ret[k])), format='gif', fps=fps)
elif 'sceneflow_' in k:
imageio.mimwrite(moviebase + k + '.mp4',
to8b(norm_sf(ret[k])), fps=fps, quality=8, macro_block_size=1)
# imageio.mimsave(moviebase + k + '.gif',
# to8b(norm_sf(ret[k])), format='gif', fps=fps)
elif 'flows' in k:
imageio.mimwrite(moviebase + k + '.mp4',
ret[k], fps=fps, quality=8, macro_block_size=1)
# imageio.mimsave(moviebase + k + '.gif',
# ret[k], format='gif', fps=fps)
elif 'dynamicness' in k:
imageio.mimwrite(moviebase + k + '.mp4',
to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
# imageio.mimsave(moviebase + k + '.gif',
# to8b(ret[k]), format='gif', fps=fps)
elif 'disocclusions' in k:
imageio.mimwrite(moviebase + k + '.mp4',
to8b(ret[k][..., 0]), fps=fps, quality=8, macro_block_size=1)
# imageio.mimsave(moviebase + k + '.gif',
# to8b(ret[k][..., 0]), format='gif', fps=fps)
elif 'blending' in k:
blending = ret[k][..., None]
blending = np.moveaxis(blending, [0, 1, 2, 3], [1, 2, 0, 3])
imageio.mimwrite(moviebase + k + '.mp4',
to8b(blending), fps=fps, quality=8, macro_block_size=1)
# imageio.mimsave(moviebase + k + '.gif',
# to8b(blending), format='gif', fps=fps)
elif 'weights' in k:
imageio.mimwrite(moviebase + k + '.mp4',
to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)
else:
raise NotImplementedError
def norm_sf_channel(sf_ch):
# Make sure zero scene flow is not shifted
sf_ch[sf_ch >= 0] = sf_ch[sf_ch >= 0] / sf_ch.max() / 2
sf_ch[sf_ch < 0] = sf_ch[sf_ch < 0] / np.abs(sf_ch.min()) / 2
sf_ch = sf_ch + 0.5
return sf_ch
def norm_sf(sf):
sf = np.concatenate((norm_sf_channel(sf[..., 0:1]),
norm_sf_channel(sf[..., 1:2]),
norm_sf_channel(sf[..., 2:3])), -1)
sf = np.moveaxis(sf, [0, 1, 2, 3], [1, 2, 0, 3])
return sf
# Spatial smoothness (adapted from NSFF)
def compute_sf_smooth_s_loss(pts1, pts2, H, W, f):
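    # Penalize differences between the scene flow of adjacent samples along each ray;
    # the last 5% of samples (near the far plane) are excluded from the comparison.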
N_samples = pts1.shape[1]
# NDC coordinate to world coordinate
pts1_world = NDC2world(pts1[..., :int(N_samples * 0.95), :], H, W, f)
pts2_world = NDC2world(pts2[..., :int(N_samples * 0.95), :], H, W, f)
# scene flow in world coordinate
scene_flow_world = pts1_world - pts2_world
return L1(scene_flow_world[..., :-1, :] - scene_flow_world[..., 1:, :])
# Temporal smoothness
def compute_sf_smooth_loss(pts, pts_f, pts_b, H, W, f):
N_samples = pts.shape[1]
pts_world = NDC2world(pts[..., :int(N_samples * 0.9), :], H, W, f)
pts_f_world = NDC2world(pts_f[..., :int(N_samples * 0.9), :], H, W, f)
pts_b_world = NDC2world(pts_b[..., :int(N_samples * 0.9), :], H, W, f)
# scene flow in world coordinate
sceneflow_f = pts_f_world - pts_world
sceneflow_b = pts_b_world - pts_world
# For a 3D point, its forward and backward sceneflow should be opposite.
return L2(sceneflow_f + sceneflow_b)
================================================
FILE: utils/RAFT/__init__.py
================================================
# from .demo import RAFT_infer
from .raft import RAFT
================================================
FILE: utils/RAFT/corr.py
================================================
import torch
import torch.nn.functional as F
from .utils.utils import bilinear_sampler, coords_grid
try:
import alt_cuda_corr
except:
# alt_cuda_corr is not compiled
pass
class CorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.corr_pyramid = []
# all pairs correlation
corr = CorrBlock.corr(fmap1, fmap2)
batch, h1, w1, dim, h2, w2 = corr.shape
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
self.corr_pyramid.append(corr)
for i in range(self.num_levels-1):
corr = F.avg_pool2d(corr, 2, stride=2)
self.corr_pyramid.append(corr)
def __call__(self, coords):
r = self.radius
coords = coords.permute(0, 2, 3, 1)
batch, h1, w1, _ = coords.shape
out_pyramid = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i]
dx = torch.linspace(-r, r, 2*r+1)
dy = torch.linspace(-r, r, 2*r+1)
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
coords_lvl = centroid_lvl + delta_lvl
corr = bilinear_sampler(corr, coords_lvl)
corr = corr.view(batch, h1, w1, -1)
out_pyramid.append(corr)
out = torch.cat(out_pyramid, dim=-1)
return out.permute(0, 3, 1, 2).contiguous().float()
@staticmethod
def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht*wd)
fmap2 = fmap2.view(batch, dim, ht*wd)
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
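# Note: CorrLayer below references a compiled CUDA correlation extension
# ('correlation_cudaz') that is not bundled with this repository; the default
# CorrBlock path above does not depend on it.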
class CorrLayer(torch.autograd.Function):
@staticmethod
def forward(ctx, fmap1, fmap2, coords, r):
fmap1 = fmap1.contiguous()
fmap2 = fmap2.contiguous()
coords = coords.contiguous()
ctx.save_for_backward(fmap1, fmap2, coords)
ctx.r = r
corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
return corr
@staticmethod
def backward(ctx, grad_corr):
fmap1, fmap2, coords = ctx.saved_tensors
grad_corr = grad_corr.contiguous()
fmap1_grad, fmap2_grad, coords_grad = \
correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
return fmap1_grad, fmap2_grad, coords_grad, None
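# AlternateCorrBlock relies on the optional 'alt_cuda_corr' extension imported at the
# top of this file and is only usable when that extension has been compiled.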
class AlternateCorrBlock:
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
self.num_levels = num_levels
self.radius = radius
self.pyramid = [(fmap1, fmap2)]
for i in range(self.num_levels):
fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
self.pyramid.append((fmap1, fmap2))
def __call__(self, coords):
coords = coords.permute(0, 2, 3, 1)
B, H, W, _ = coords.shape
corr_list = []
for i in range(self.num_levels):
r = self.radius
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
corr_list.append(corr.squeeze(1))
corr = torch.stack(corr_list, dim=1)
corr = corr.reshape(B, -1, H, W)
return corr / 16.0
================================================
FILE: utils/RAFT/datasets.py
================================================
# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
import os
import math
import random
from glob import glob
import os.path as osp
from utils import frame_utils
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
class FlowDataset(data.Dataset):
def __init__(self, aug_params=None, sparse=False):
self.augmentor = None
self.sparse = sparse
if aug_params is not None:
if sparse:
self.augmentor = SparseFlowAugmentor(**aug_params)
else:
self.augmentor = FlowAugmentor(**aug_params)
self.is_test = False
self.init_seed = False
self.flow_list = []
self.image_list = []
self.extra_info = []
def __getitem__(self, index):
if self.is_test:
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
img1 = np.array(img1).astype(np.uint8)[..., :3]
img2 = np.array(img2).astype(np.uint8)[..., :3]
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
return img1, img2, self.extra_info[index]
if not self.init_seed:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
torch.manual_seed(worker_info.id)
np.random.seed(worker_info.id)
random.seed(worker_info.id)
self.init_seed = True
index = index % len(self.image_list)
valid = None
if self.sparse:
flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
else:
flow = frame_utils.read_gen(self.flow_list[index])
img1 = frame_utils.read_gen(self.image_list[index][0])
img2 = frame_utils.read_gen(self.image_list[index][1])
flow = np.array(flow).astype(np.float32)
img1 = np.array(img1).astype(np.uint8)
img2 = np.array(img2).astype(np.uint8)
# grayscale images
if len(img1.shape) == 2:
img1 = np.tile(img1[...,None], (1, 1, 3))
img2 = np.tile(img2[...,None], (1, 1, 3))
else:
img1 = img1[..., :3]
img2 = img2[..., :3]
if self.augmentor is not None:
if self.sparse:
img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
else:
img1, img2, flow = self.augmentor(img1, img2, flow)
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
if valid is not None:
valid = torch.from_numpy(valid)
else:
valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
return img1, img2, flow, valid.float()
def __rmul__(self, v):
self.flow_list = v * self.flow_list
self.image_list = v * self.image_list
return self
def __len__(self):
return len(self.image_list)
class MpiSintel(FlowDataset):
def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
super(MpiSintel, self).__init__(aug_params)
flow_root = osp.join(root, split, 'flow')
image_root = osp.join(root, split, dstype)
if split == 'test':
self.is_test = True
for scene in os.listdir(image_root):
image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
for i in range(len(image_list)-1):
self.image_list += [ [image_list[i], image_list[i+1]] ]
self.extra_info += [ (scene, i) ] # scene and frame_id
if split != 'test':
self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
class FlyingChairs(FlowDataset):
def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
super(FlyingChairs, self).__init__(aug_params)
images = sorted(glob(osp.join(root, '*.ppm')))
flows = sorted(glob(osp.join(root, '*.flo')))
assert (len(images)//2 == len(flows))
split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
for i in range(len(flows)):
xid = split_list[i]
if (split=='training' and xid==1) or (split=='validation' and xid==2):
self.flow_list += [ flows[i] ]
self.image_list += [ [images[2*i], images[2*i+1]] ]
class FlyingThings3D(FlowDataset):
def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
super(FlyingThings3D, self).__init__(aug_params)
for cam in ['left']:
for direction in ['into_future', 'into_past']:
image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
for idir, fdir in zip(image_dirs, flow_dirs):
images = sorted(glob(osp.join(idir, '*.png')) )
flows = sorted(glob(osp.join(fdir, '*.pfm')) )
for i in range(len(flows)-1):
if direction == 'into_future':
self.image_list += [ [images[i], images[i+1]] ]
self.flow_list += [ flows[i] ]
elif direction == 'into_past':
self.image_list += [ [images[i+1], images[i]] ]
self.flow_list += [ flows[i+1] ]
class KITTI(FlowDataset):
def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
super(KITTI, self).__init__(aug_params, sparse=True)
if split == 'testing':
self.is_test = True
root = osp.join(root, split)
images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
for img1, img2 in zip(images1, images2):
frame_id = img1.split('/')[-1]
self.extra_info += [ [frame_id] ]
self.image_list += [ [img1, img2] ]
if split == 'training':
self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
class HD1K(FlowDataset):
def __init__(self, aug_params=None, root='datasets/HD1k'):
super(HD1K, self).__init__(aug_params, sparse=True)
seq_ix = 0
while 1:
flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
if len(flows) == 0:
break
for i in range(len(flows)-1):
self.flow_list += [flows[i]]
self.image_list += [ [images[i], images[i+1]] ]
seq_ix += 1
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
""" Create the data loader for the corresponding trainign set """
if args.stage == 'chairs':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
train_dataset = FlyingChairs(aug_params, split='training')
elif args.stage == 'things':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
train_dataset = clean_dataset + final_dataset
elif args.stage == 'sintel':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
sintel_final = MpiSintel(aug_params, split='training', dstype='final')
if TRAIN_DS == 'C+T+K+S+H':
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
elif TRAIN_DS == 'C+T+K/S':
train_dataset = 100*sintel_clean + 100*sintel_final + things
elif args.stage == 'kitti':
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
train_dataset = KITTI(aug_params, split='training')
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
print('Training with %d image pairs' % len(train_dataset))
return train_loader
================================================
FILE: utils/RAFT/demo.py
================================================
import sys
import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image
from .raft import RAFT
from .utils import flow_viz
from .utils.utils import InputPadder
DEVICE = 'cuda'
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
return img
def load_image_list(image_files):
images = []
for imfile in sorted(image_files):
images.append(load_image(imfile))
images = torch.stack(images, dim=0)
images = images.to(DEVICE)
padder = InputPadder(images.shape)
return padder.pad(images)[0]
def viz(img, flo):
img = img[0].permute(1,2,0).cpu().numpy()
flo = flo[0].permute(1,2,0).cpu().numpy()
# map flow to rgb image
flo = flow_viz.flow_to_image(flo)
# img_flo = np.concatenate([img, flo], axis=0)
img_flo = flo
cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]])
# cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
# cv2.waitKey()
def demo(args):
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model = model.module
model.to(DEVICE)
model.eval()
with torch.no_grad():
images = glob.glob(os.path.join(args.path, '*.png')) + \
glob.glob(os.path.join(args.path, '*.jpg'))
images = load_image_list(images)
for i in range(images.shape[0]-1):
image1 = images[i,None]
image2 = images[i+1,None]
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
viz(image1, flow_up)
def RAFT_infer(args):
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model = model.module
model.to(DEVICE)
model.eval()
return model
================================================
FILE: utils/RAFT/extractor.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class BottleneckBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(BottleneckBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes//4)
self.norm2 = nn.BatchNorm2d(planes//4)
self.norm3 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm4 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes//4)
self.norm2 = nn.InstanceNorm2d(planes//4)
self.norm3 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm4 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
self.norm3 = nn.Sequential()
if not stride == 1:
self.norm4 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
y = self.relu(self.norm3(self.conv3(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
class BasicEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(BasicEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(64)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(64)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
# output convolution
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
return x
class SmallEncoder(nn.Module):
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
super(SmallEncoder, self).__init__()
self.norm_fn = norm_fn
if self.norm_fn == 'group':
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
elif self.norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(32)
elif self.norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(32)
elif self.norm_fn == 'none':
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 32
self.layer1 = self._make_layer(32, stride=1)
self.layer2 = self._make_layer(64, stride=2)
self.layer3 = self._make_layer(96, stride=2)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.conv2(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
if is_list:
x = torch.split(x, [batch_dim, batch_dim], dim=0)
return x
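# --- Illustrative sanity check (added; not part of the original file) ---
# A minimal sketch with random data: both encoders map a 3 x H x W image to a
# feature map at 1/8 resolution with `output_dim` channels, which is the
# resolution RAFT builds its correlation volume at.
if __name__ == '__main__':
    dummy = torch.randn(1, 3, 64, 96)
    feats = BasicEncoder(output_dim=256, norm_fn='instance')(dummy)
    print(feats.shape)        # torch.Size([1, 256, 8, 12])
    feats_small = SmallEncoder(output_dim=128, norm_fn='instance')(dummy)
    print(feats_small.shape)  # torch.Size([1, 128, 8, 12])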
================================================
FILE: utils/RAFT/raft.py
================================================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .update import BasicUpdateBlock, SmallUpdateBlock
from .extractor import BasicEncoder, SmallEncoder
from .corr import CorrBlock, AlternateCorrBlock
from .utils.utils import bilinear_sampler, coords_grid, upflow8
try:
autocast = torch.cuda.amp.autocast
except:
# dummy autocast for PyTorch < 1.6
class autocast:
def __init__(self, enabled):
pass
def __enter__(self):
pass
def __exit__(self, *args):
pass
class RAFT(nn.Module):
def __init__(self, args):
super(RAFT, self).__init__()
self.args = args
if args.small:
self.hidden_dim = hdim = 96
self.context_dim = cdim = 64
args.corr_levels = 4
args.corr_radius = 3
else:
self.hidden_dim = hdim = 128
self.context_dim = cdim = 128
args.corr_levels = 4
args.corr_radius = 4
        # fill in optional args when the caller did not provide them
        if not hasattr(args, 'dropout'):
            args.dropout = 0
        if not hasattr(args, 'alternate_corr'):
            args.alternate_corr = False
# feature network, context network, and update block
if args.small:
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
else:
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
def freeze_bn(self):
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
def initialize_flow(self, img):
""" Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
N, C, H, W = img.shape
coords0 = coords_grid(N, H//8, W//8).to(img.device)
coords1 = coords_grid(N, H//8, W//8).to(img.device)
# optical flow computed as difference: flow = coords1 - coords0
return coords0, coords1
def upsample_flow(self, flow, mask):
""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
N, _, H, W = flow.shape
mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)
up_flow = F.unfold(8 * flow, [3,3], padding=1)
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
return up_flow.reshape(N, 2, 8*H, 8*W)
def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
""" Estimate optical flow between pair of frames """
image1 = 2 * (image1 / 255.0) - 1.0
image2 = 2 * (image2 / 255.0) - 1.0
image1 = image1.contiguous()
image2 = image2.contiguous()
hdim = self.hidden_dim
cdim = self.context_dim
# run the feature network
with autocast(enabled=self.args.mixed_precision):
fmap1, fmap2 = self.fnet([image1, image2])
fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
else:
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
# run the context network
with autocast(enabled=self.args.mixed_precision):
cnet = self.cnet(image1)
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
net = torch.tanh(net)
inp = torch.relu(inp)
coords0, coords1 = self.initialize_flow(image1)
if flow_init is not None:
coords1 = coords1 + flow_init
flow_predictions = []
for itr in range(iters):
coords1 = coords1.detach()
corr = corr_fn(coords1) # index correlation volume
flow = coords1 - coords0
with autocast(enabled=self.args.mixed_precision):
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
# F(t+1) = F(t) + \Delta(t)
coords1 = coords1 + delta_flow
# upsample predictions
if up_mask is None:
flow_up = upflow8(coords1 - coords0)
else:
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
flow_predictions.append(flow_up)
if test_mode:
return coords1 - coords0, flow_up
return flow_predictions
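# --- Illustrative smoke test (added; not part of the original RAFT code) ---
# A minimal sketch, assuming this module is imported as part of the package
# (the relative imports above prevent running the file directly) and that a
# plain argparse.Namespace stands in for the command-line arguments.
def _smoke_test(device='cpu'):
    import argparse
    args = argparse.Namespace(small=False, mixed_precision=False,
                              alternate_corr=False, dropout=0)
    model = RAFT(args).to(device).eval()
    # Inputs are expected in the 0-255 range; H and W should be divisible by 8
    # (see InputPadder in utils), since flow is predicted at 1/8 resolution
    # and then convex-upsampled.
    img1 = torch.rand(1, 3, 64, 96, device=device) * 255
    img2 = torch.rand(1, 3, 64, 96, device=device) * 255
    with torch.no_grad():
        flow_low, flow_up = model(img1, img2, iters=4, test_mode=True)
    return flow_low.shape, flow_up.shape  # (1, 2, 8, 12), (1, 2, 64, 96)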
================================================
FILE: utils/RAFT/update.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class FlowHead(nn.Module):
def __init__(self, input_dim=128, hidden_dim=256):
super(FlowHead, self).__init__()
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class ConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(ConvGRU, self).__init__()
self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
def forward(self, h, x):
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx))
r = torch.sigmoid(self.convr(hx))
q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
class SepConvGRU(nn.Module):
def __init__(self, hidden_dim=128, input_dim=192+128):
super(SepConvGRU, self).__init__()
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
def forward(self, h, x):
# horizontal
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz1(hx))
r = torch.sigmoid(self.convr1(hx))
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
# vertical
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz2(hx))
r = torch.sigmoid(self.convr2(hx))
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
h = (1-z) * h + z * q
return h
class SmallMotionEncoder(nn.Module):
def __init__(self, args):
super(SmallMotionEncoder, self).__init__()
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
self.conv = nn.Conv2d(128, 80, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class BasicMotionEncoder(nn.Module):
def __init__(self, args):
super(BasicMotionEncoder, self).__init__()
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
cor = F.relu(self.convc2(cor))
flo = F.relu(self.convf1(flow))
flo = F.relu(self.convf2(flo))
cor_flo = torch.cat([cor, flo], dim=1)
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
class SmallUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=96):
super(SmallUpdateBlock, self).__init__()
self.encoder = SmallMotionEncoder(args)
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
def forward(self, net, inp, corr, flow):
motion_features = self.encoder(flow, corr)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
return net, None, delta_flow
class BasicUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=128, input_dim=128):
super(BasicUpdateBlock, self).__init__()
self.args = args
self.encoder = BasicMotionEncoder(args)
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
self.mask = nn.Sequential(
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 64*9, 1, padding=0))
def forward(self, net, inp, corr, flow, upsample=True):
motion_features = self.encoder(flow, corr)
inp = torch.cat([inp, motion_features], dim=1)
net = self.gru(net, inp)
delta_flow = self.flow_head(net)
        # scale mask to balance gradients
mask = .25 * self.mask(net)
return net, mask, delta_flow
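# Note (added): the 64*9 = 576 channels produced by BasicUpdateBlock.mask
# encode, for every coarse-grid cell, 9 convex-combination weights for each of
# the 8x8 pixels it is upsampled to; RAFT.upsample_flow reshapes them to
# (N, 1, 9, 8, 8, H, W) and takes a softmax over the 9 weights.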
================================================
FILE: utils/RAFT/utils/__init__.py
================================================
from .flow_viz import flow_to_image
from .frame_utils import writeFlow
================================================
FILE: utils/RAFT/utils/augmentor.py
================================================
import numpy as np
import random
import math
from PIL import Image
import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
import torch
from torchvision.transforms import ColorJitter
import torch.nn.functional as F
class FlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
""" Photometric augmentation """
# asymmetric
if np.random.rand() < self.asymmetric_color_aug_prob:
img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
# symmetric
else:
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2, bounds=[50, 100]):
""" Occlusion augmentation """
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(bounds[0], bounds[1])
dy = np.random.randint(bounds[0], bounds[1])
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
return img1, img2
def spatial_transform(self, img1, img2, flow):
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 8) / float(ht),
(self.crop_size[1] + 8) / float(wd))
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = scale
scale_y = scale
if np.random.rand() < self.stretch_prob:
scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_x = np.clip(scale_x, min_scale, None)
scale_y = np.clip(scale_y, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow = flow * [scale_x, scale_y]
if self.do_flip:
if np.random.rand() < self.h_flip_prob: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
if np.random.rand() < self.v_flip_prob: # v-flip
img1 = img1[::-1, :]
img2 = img2[::-1, :]
flow = flow[::-1, :] * [1.0, -1.0]
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
return img1, img2, flow
def __call__(self, img1, img2, flow):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
img1, img2, flow = self.spatial_transform(img1, img2, flow)
img1 = np.ascontiguousarray(img1)
img2 = np.ascontiguousarray(img2)
flow = np.ascontiguousarray(flow)
return img1, img2, flow
class SparseFlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
self.max_scale = max_scale
self.spatial_aug_prob = 0.8
self.stretch_prob = 0.8
self.max_stretch = 0.2
# flip augmentation params
self.do_flip = do_flip
self.h_flip_prob = 0.5
self.v_flip_prob = 0.1
# photometric augmentation params
self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
image_stack = np.concatenate([img1, img2], axis=0)
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2):
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
for _ in range(np.random.randint(1, 3)):
x0 = np.random.randint(0, wd)
y0 = np.random.randint(0, ht)
dx = np.random.randint(50, 100)
dy = np.random.randint(50, 100)
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
return img1, img2
def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
ht, wd = flow.shape[:2]
coords = np.meshgrid(np.arange(wd), np.arange(ht))
coords = np.stack(coords, axis=-1)
coords = coords.reshape(-1, 2).astype(np.float32)
flow = flow.reshape(-1, 2).astype(np.float32)
valid = valid.reshape(-1).astype(np.float32)
coords0 = coords[valid>=1]
flow0 = flow[valid>=1]
ht1 = int(round(ht * fy))
wd1 = int(round(wd * fx))
coords1 = coords0 * [fx, fy]
flow1 = flow0 * [fx, fy]
xx = np.round(coords1[:,0]).astype(np.int32)
yy = np.round(coords1[:,1]).astype(np.int32)
v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
xx = xx[v]
yy = yy[v]
flow1 = flow1[v]
flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
valid_img = np.zeros([ht1, wd1], dtype=np.int32)
flow_img[yy, xx] = flow1
valid_img[yy, xx] = 1
return flow_img, valid_img
def spatial_transform(self, img1, img2, flow, valid):
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
(self.crop_size[0] + 1) / float(ht),
(self.crop_size[1] + 1) / float(wd))
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = np.clip(scale, min_scale, None)
scale_y = np.clip(scale, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
if self.do_flip:
if np.random.rand() < 0.5: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
valid = valid[:, ::-1]
margin_y = 20
margin_x = 50
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
return img1, img2, flow, valid
def __call__(self, img1, img2, flow, valid):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
img1 = np.ascontiguousarray(img1)
img2 = np.ascontiguousarray(img2)
flow = np.ascontiguousarray(flow)
valid = np.ascontiguousarray(valid)
return img1, img2, flow, valid
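# --- Illustrative usage (added; not part of the original file) ---
# A minimal sketch with synthetic data, assuming uint8 images and a float32
# flow field of the same spatial size; the augmentor returns crops of
# crop_size with the flow vectors rescaled and flipped consistently.
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    img1 = rng.randint(0, 255, (368, 768, 3), dtype=np.uint8)
    img2 = rng.randint(0, 255, (368, 768, 3), dtype=np.uint8)
    flow = rng.uniform(-5, 5, (368, 768, 2)).astype(np.float32)
    aug = FlowAugmentor(crop_size=(288, 496))
    img1_a, img2_a, flow_a = aug(img1, img2, flow)
    print(img1_a.shape, flow_a.shape)  # (288, 496, 3) (288, 496, 2)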
================================================
FILE: utils/RAFT/utils/flow_viz.py
================================================
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-08-03
import numpy as np
def make_colorwheel():
"""
Generates a color wheel for optical flow visualization as presented in:
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
Code follows the original C++ source code of Daniel Scharstein.
    Code follows the Matlab source code of Deqing Sun.
Returns:
np.ndarray: Color wheel
"""
RY = 15
YG = 6
GC = 4
CB = 11
BM = 13
MR = 6
ncols = RY + YG + GC + CB + BM + MR
colorwheel = np.zeros((ncols, 3))
col = 0
# RY
colorwheel[0:RY, 0] = 255
colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
col = col+RY
# YG
colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
colorwheel[col:col+YG, 1] = 255
col = col+YG
# GC
colorwheel[col:col+GC, 1] = 255
colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
col = col+GC
# CB
colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
colorwheel[col:col+CB, 2] = 255
col = col+CB
# BM
colorwheel[col:col+BM, 2] = 255
colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
col = col+BM
# MR
colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
colorwheel[col:col+MR, 0] = 255
return colorwheel
def flow_uv_to_colors(u, v, convert_to_bgr=False):
"""
Applies the flow color wheel to (possibly clipped) flow components u and v.
According to the C++ source code of Daniel Scharstein
According to the Matlab source code of Deqing Sun
Args:
u (np.ndarray): Input horizontal flow of shape [H,W]
v (np.ndarray): Input vertical flow of shape [H,W]
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
colorwheel = make_colorwheel() # shape [55x3]
ncols = colorwheel.shape[0]
rad = np.sqrt(np.square(u) + np.square(v))
a = np.arctan2(-v, -u)/np.pi
fk = (a+1) / 2*(ncols-1)
k0 = np.floor(fk).astype(np.int32)
k1 = k0 + 1
k1[k1 == ncols] = 0
f = fk - k0
for i in range(colorwheel.shape[1]):
tmp = colorwheel[:,i]
col0 = tmp[k0] / 255.0
col1 = tmp[k1] / 255.0
col = (1-f)*col0 + f*col1
idx = (rad <= 1)
col[idx] = 1 - rad[idx] * (1-col[idx])
col[~idx] = col[~idx] * 0.75 # out of range
# Note the 2-i => BGR instead of RGB
ch_idx = 2-i if convert_to_bgr else i
flow_image[:,:,ch_idx] = np.floor(255 * col)
return flow_image
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
"""
    Expects a two-dimensional flow image of shape [H,W,2].
Args:
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
if clip_flow is not None:
flow_uv = np.clip(flow_uv, 0, clip_flow)
u = flow_uv[:,:,0]
v = flow_uv[:,:,1]
rad = np.sqrt(np.square(u) + np.square(v))
rad_max = np.max(rad)
epsilon = 1e-5
u = u / (rad_max + epsilon)
v = v / (rad_max + epsilon)
return flow_uv_to_colors(u, v, convert_to_bgr)
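# --- Illustrative usage (added; not part of the original file) ---
# A minimal sketch: a random flow field is mapped to an H x W x 3 uint8 image,
# with hue encoding flow direction and saturation the normalized magnitude.
if __name__ == '__main__':
    flow = np.random.uniform(-10, 10, (120, 160, 2)).astype(np.float32)
    img = flow_to_image(flow)
    print(img.shape, img.dtype)  # (120, 160, 3) uint8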
================================================
FILE: utils/RAFT/utils/frame_utils.py
================================================
import numpy as np
from PIL import Image
from os.path import *
import re
import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
TAG_CHAR = np.array([202021.25], np.float32)
def readFlow(fn):
""" Read .flo file in Middlebury format"""
# Code adapted from:
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
# WARNING: this will work on little-endian architectures (eg Intel x86) only!
# print 'fn = %s'%(fn)
with open(fn, 'rb') as f:
magic = np.fromfile(f, np.float32, count=1)
if 202021.25 != magic:
print('Magic number incorrect. Invalid .flo file')
return None
else:
w = np.fromfile(f, np.int32, count=1)
h = np.fromfile(f, np.int32, count=1)
# print 'Reading %d x %d flo file\n' % (w, h)
data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
# Reshape data into 3D array (columns, rows, bands)
# The reshape here is for visualization, the original code is (w,h,2)
return np.resize(data, (int(h), int(w), 2))
def readPFM(file):
file = open(file, 'rb')
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().rstrip()
if header == b'PF':
color = True
elif header == b'Pf':
color = False
else:
raise Exception('Not a PFM file.')
dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
if dim_match:
width, height = map(int, dim_match.groups())
else:
raise Exception('Malformed PFM header.')
scale = float(file.readline().rstrip())
if scale < 0: # little-endian
endian = '<'
scale = -scale
else:
endian = '>' # big-endian
data = np.fromfile(file, endian + 'f')
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
return data
def writeFlow(filename,uv,v=None):
""" Write optical flow to file.
If v is None, uv is assumed to contain both u and v channels,
stacked in depth.
Original code by Deqing Sun, adapted from Daniel Scharstein.
"""
nBands = 2
if v is None:
assert(uv.ndim == 3)
assert(uv.shape[2] == 2)
u = uv[:,:,0]
v = uv[:,:,1]
else:
u = uv
assert(u.shape == v.shape)
height,width = u.shape
f = open(filename,'wb')
# write the header
f.write(TAG_CHAR)
np.array(width).astype(np.int32).tofile(f)
np.array(height).astype(np.int32).tofile(f)
# arrange into matrix form
tmp = np.zeros((height, width*nBands))
tmp[:,np.arange(width)*2] = u
tmp[:,np.arange(width)*2 + 1] = v
tmp.astype(np.float32).tofile(f)
f.close()
def readFlowKITTI(filename):
flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
flow = flow[:,:,::-1].astype(np.float32)
flow, valid = flow[:, :, :2], flow[:, :, 2]
flow = (flow - 2**15) / 64.0
return flow, valid
def readDispKITTI(filename):
disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
valid = disp > 0.0
flow = np.stack([-disp, np.zeros_like(disp)], -1)
return flow, valid
def writeFlowKITTI(filename, uv):
uv = 64.0 * uv + 2**15
valid = np.ones([uv.shape[0], uv.shape[1], 1])
uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
cv2.imwrite(filename, uv[..., ::-1])
def read_gen(file_name, pil=False):
ext = splitext(file_name)[-1]
if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
return Image.open(file_name)
elif ext == '.bin' or ext == '.raw':
return np.load(file_name)
elif ext == '.flo':
return readFlow(file_name).astype(np.float32)
elif ext == '.pfm':
flow = readPFM(file_name).astype(np.float32)
if len(flow.shape) == 2:
return flow
else:
return flow[:, :, :-1]
return []
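# --- Illustrative round trip (added; not part of the original file) ---
# A minimal sketch, assuming a little-endian machine (see the warning in
# readFlow): writeFlow and readFlow are inverses up to float32 precision.
if __name__ == '__main__':
    import os
    import tempfile
    flow = np.random.rand(48, 64, 2).astype(np.float32)
    path = os.path.join(tempfile.gettempdir(), 'roundtrip_example.flo')
    writeFlow(path, flow)
    print(np.allclose(flow, readFlow(path)))  # True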
================================================
FILE: utils/RAFT/utils/utils.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate
class InputPadder:
""" Pads images such that dimensions are divisible by 8 """
def __init__(self, dims, mode='sintel'):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
if mode == 'sintel':
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
else:
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
def pad(self, *inputs):
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
def unpad(self,x):
ht, wd = x.shape[-2:]
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
return x[..., c[0]:c[1], c[2]:c[3]]
def forward_interpolate(flow):
flow = flow.detach().cpu().numpy()
dx, dy = flow[0], flow[1]
ht, wd = dx.shape
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
x1 = x0 + dx
y1 = y0 + dy
x1 = x1.reshape(-1)
y1 = y1.reshape(-1)
dx = dx.reshape(-1)
dy = dy.reshape(-1)
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
x1 = x1[valid]
y1 = y1[valid]
dx = dx[valid]
dy = dy[valid]
flow_x = interpolate.griddata(
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
flow_y = interpolate.griddata(
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
flow = np.stack([flow_x, flow_y], axis=0)
return torch.from_numpy(flow).float()
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
""" Wrapper for grid_sample, uses pixel coordinates """
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1,1], dim=-1)
xgrid = 2*xgrid/(W-1) - 1
ygrid = 2*ygrid/(H-1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
def coords_grid(batch, ht, wd):
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)
def upflow8(flow, mode='bilinear'):
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
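# --- Illustrative usage (added; not part of the original file) ---
# A minimal sketch: InputPadder pads an arbitrary-size image tensor so both
# spatial dimensions become multiples of 8 (as RAFT expects) and can undo the
# padding on the network output.
if __name__ == '__main__':
    img = torch.rand(1, 3, 101, 235)
    padder = InputPadder(img.shape)
    img_padded, = padder.pad(img)
    print(img_padded.shape)                # torch.Size([1, 3, 104, 240])
    print(padder.unpad(img_padded).shape)  # torch.Size([1, 3, 101, 235])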
================================================
FILE: utils/colmap_utils.py
================================================
# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
# its contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Author: Johannes L. Schoenberger (jsch at inf.ethz.ch)
import os
import sys
import collections
import numpy as np
import struct
CameraModel = collections.namedtuple(
"CameraModel", ["model_id", "model_name", "num_params"])
Camera = collections.namedtuple(
"Camera", ["id", "model", "width", "height", "params"])
BaseImage = collections.namedtuple(
"Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
Point3D = collections.namedtuple(
"Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
class Image(BaseImage):
def qvec2rotmat(self):
return qvec2rotmat(self.qvec)
CAMERA_MODELS = {
CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
CameraModel(model_id=3, model_name="RADIAL", num_params=5),
CameraModel(model_id=4, model_name="OPENCV", num_params=8),
CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
CameraModel(model_id=7, model_name="FOV", num_params=5),
CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
}
CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \
for camera_model in CAMERA_MODELS])
def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
"""Read and unpack the next bytes from a binary file.
:param fid:
:param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
:param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
:param endian_character: Any of {@, =, <, >, !}
:return: Tuple of read and unpacked values.
"""
data = fid.read(num_bytes)
return struct.unpack(endian_character + format_char_sequence, data)
def read_cameras_text(path):
"""
see: src/base/reconstruction.cc
void Reconstruction::WriteCamerasText(const std::string& path)
void Reconstruction::ReadCamerasText(const std::string& path)
"""
cameras = {}
with open(path, "r") as fid:
while True:
line = fid.readline()
if not line:
break
line = line.strip()
if len(line) > 0 and line[0] != "#":
elems = line.split()
camera_id = int(elems[0])
model = elems[1]
width = int(elems[2])
height = int(elems[3])
params = np.array(tuple(map(float, elems[4:])))
cameras[camera_id] = Camera(id=camera_id, model=model,
width=width, height=height,
params=params)
return cameras
def read_cameras_binary(path_to_model_file):
"""
see: src/base/reconstruction.cc
void Reconstruction::WriteCamerasBinary(const std::string& path)
void Reconstruction::ReadCamerasBinary(const std::string& path)
"""
cameras = {}
with open(path_to_model_file, "rb") as fid:
num_cameras = read_next_bytes(fid, 8, "Q")[0]
for camera_line_index in range(num_cameras):
camera_properties = read_next_bytes(
fid, num_bytes=24, format_char_sequence="iiQQ")
camera_id = camera_properties[0]
model_id = camera_properties[1]
model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
width = camera_properties[2]
height = camera_properties[3]
num_params = CAMERA_MODEL_IDS[model_id].num_params
params = read_next_bytes(fid, num_bytes=8*num_params,
format_char_sequence="d"*num_params)
cameras[camera_id] = Camera(id=camera_id,
model=model_name,
width=width,
height=height,
params=np.array(params))
assert len(cameras) == num_cameras
return cameras
def read_images_text(path):
"""
see: src/base/reconstruction.cc
void Reconstruction::ReadImagesText(const std::string& path)
void Reconstruction::WriteImagesText(const std::string& path)
"""
images = {}
with open(path, "r") as fid:
while True:
line = fid.readline()
if not line:
break
line = line.strip()
if len(line) > 0 and line[0] != "#":
elems = line.split()
image_id = int(elems[0])
qvec = np.array(tuple(map(float, elems[1:5])))
tvec = np.array(tuple(map(float, elems[5:8])))
camera_id = int(elems[8])
image_name = elems[9]
elems = fid.readline().split()
xys = np.column_stack([tuple(map(float, elems[0::3])),
tuple(map(float, elems[1::3]))])
point3D_ids = np.array(tuple(map(int, elems[2::3])))
images[image_id] = Image(
id=image_id, qvec=qvec, tvec=tvec,
camera_id=camera_id, name=image_name,
xys=xys, point3D_ids=point3D_ids)
return images
def read_images_binary(path_to_model_file):
"""
see: src/base/reconstruction.cc
void Reconstruction::ReadImagesBinary(const std::string& path)
void Reconstruction::WriteImagesBinary(const std::string& path)
"""
images = {}
with open(path_to_model_file, "rb") as fid:
num_reg_images = read_next_bytes(fid, 8, "Q")[0]
for image_index in range(num_reg_images):
binary_image_properties = read_next_bytes(
fid, num_bytes=64, format_char_sequence="idddddddi")
image_id = binary_image_properties[0]
qvec = np.array(binary_image_properties[1:5])
tvec = np.array(binary_image_properties[5:8])
camera_id = binary_image_properties[8]
image_name = ""
current_char = read_next_bytes(fid, 1, "c")[0]
while current_char != b"\x00": # look for the ASCII 0 entry
image_name += current_char.decode("utf-8")
current_char = read_next_bytes(fid, 1, "c")[0]
num_points2D = read_next_bytes(fid, num_bytes=8,
format_char_sequence="Q")[0]
x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
format_char_sequence="ddq"*num_points2D)
xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
tuple(map(float, x_y_id_s[1::3]))])
point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
images[image_id] = Image(
id=image_id, qvec=qvec, tvec=tvec,
camera_id=camera_id, name=image_name,
xys=xys, point3D_ids=point3D_ids)
return images
def read_points3D_text(path):
"""
see: src/base/reconstruction.cc
void Reconstruction::ReadPoints3DText(const std::string& path)
void Reconstruction::WritePoints3DText(const std::string& path)
"""
points3D = {}
with open(path, "r") as fid:
while True:
line = fid.readline()
if not line:
break
line = line.strip()
if len(line) > 0 and line[0] != "#":
elems = line.split()
point3D_id = int(elems[0])
xyz = np.array(tuple(map(float, elems[1:4])))
rgb = np.array(tuple(map(int, elems[4:7])))
error = float(elems[7])
image_ids = np.array(tuple(map(int, elems[8::2])))
point2D_idxs = np.array(tuple(map(int, elems[9::2])))
points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
error=error, image_ids=image_ids,
point2D_idxs=point2D_idxs)
return points3D
def read_points3d_binary(path_to_model_file):
"""
see: src/base/reconstruction.cc
void Reconstruction::ReadPoints3DBinary(const std::string& path)
void Reconstruction::WritePoints3DBinary(const std::string& path)
"""
points3D = {}
with open(path_to_model_file, "rb") as fid:
num_points = read_next_bytes(fid, 8, "Q")[0]
for point_line_index in range(num_points):
binary_point_line_properties = read_next_bytes(
fid, num_bytes=43, format_char_sequence="QdddBBBd")
point3D_id = binary_point_line_properties[0]
xyz = np.array(binary_point_line_properties[1:4])
rgb = np.array(binary_point_line_properties[4:7])
error = np.array(binary_point_line_properties[7])
track_length = read_next_bytes(
fid, num_bytes=8, format_char_sequence="Q")[0]
track_elems = read_next_bytes(
fid, num_bytes=8*track_length,
format_char_sequence="ii"*track_length)
image_ids = np.array(tuple(map(int, track_elems[0::2])))
point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
points3D[point3D_id] = Point3D(
id=point3D_id, xyz=xyz, rgb=rgb,
error=error, image_ids=image_ids,
point2D_idxs=point2D_idxs)
return points3D
def read_model(path, ext):
if ext == ".txt":
cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
images = read_images_text(os.path.join(path, "images" + ext))
        points3D = read_points3D_text(os.path.join(path, "points3D" + ext))
else:
cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
images = read_images_binary(os.path.join(path, "images" + ext))
        points3D = read_points3d_binary(os.path.join(path, "points3D" + ext))
return cameras, images, points3D
def qvec2rotmat(qvec):
return np.array([
[1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
[2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
[2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
def rotmat2qvec(R):
Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
K = np.array([
[Rxx - Ryy - Rzz, 0, 0, 0],
[Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
[Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
[Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
eigvals, eigvecs = np.linalg.eigh(K)
qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
if qvec[0] < 0:
qvec *= -1
return qvec
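# Note (added): the quaternion convention throughout this file is COLMAP's
# (qw, qx, qy, qz), i.e. qvec[0] is the scalar part, as can be read off
# qvec2rotmat above; rotmat2qvec returns the same ordering and flips the sign
# so that qw >= 0.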
def main():
if len(sys.argv) != 3:
print("Usage: python read_model.py path/to/model/folder [.txt,.bin]")
return
cameras, images, points3D = read_model(path=sys.argv[1], ext=sys.argv[2])
print("num_cameras:", len(cameras))
print("num_images:", len(images))
print("num_points3D:", len(points3D))
if __name__ == "__main__":
main()
================================================
FILE: utils/evaluation.py
================================================
import os
import cv2
import lpips
import torch
import numpy as np
from skimage.metrics import structural_similarity
def im2tensor(img):
return torch.Tensor(img.transpose(2, 0, 1) / 127.5 - 1.0)[None, ...]
def create_dir(dir):
if not os.path.exists(dir):
os.makedirs(dir)
def readimage(data_dir, sequence, time, method):
img = cv2.imread(os.path.join(data_dir, method, sequence, 'v000_t' + str(time).zfill(3) + '.png'))
return img
def calculate_metrics(data_dir, sequence, methods, lpips_loss):
PSNRs = np.zeros((len(methods)))
SSIMs = np.zeros((len(methods)))
LPIPSs = np.zeros((len(methods)))
nFrame = 0
# Yoon's results do not include v000_t000 and v000_t011. Omit these two
# frames if evaluating Yoon's method.
if 'Yoon' in methods:
time_start = 1
time_end = 11
else:
time_start = 0
time_end = 12
for time in range(time_start, time_end): # Fix view v0, change time
nFrame += 1
img_true = readimage(data_dir, sequence, time, 'gt')
for method_idx, method in enumerate(methods):
if 'Yoon' in methods and sequence == 'Truck' and time == 10:
break
img = readimage(data_dir, sequence, time, method)
PSNR = cv2.PSNR(img_true, img)
SSIM = structural_similarity(img_true, img, multichannel=True)
LPIPS = lpips_loss.forward(im2tensor(img_true), im2tensor(img)).item()
PSNRs[method_idx] += PSNR
SSIMs[method_idx] += SSIM
LPIPSs[method_idx] += LPIPS
PSNRs = PSNRs / nFrame
SSIMs = SSIMs / nFrame
LPIPSs = LPIPSs / nFrame
return PSNRs, SSIMs, LPIPSs
if __name__ == '__main__':
lpips_loss = lpips.LPIPS(net='alex') # best forward scores
data_dir = '../results'
sequences = ['Balloon1', 'Balloon2', 'Jumping', 'Playground', 'Skating', 'Truck', 'Umbrella']
# methods = ['NeRF', 'NeRF_t', 'Yoon', 'NR', 'NSFF', 'Ours']
methods = ['NeRF', 'NeRF_t', 'NR', 'NSFF', 'Ours']
PSNRs_total = np.zeros((len(methods)))
SSIMs_total = np.zeros((len(methods)))
LPIPSs_total = np.zeros((len(methods)))
for sequence in sequences:
print(sequence)
PSNRs, SSIMs, LPIPSs = calculate_metrics(data_dir, sequence, methods, lpips_loss)
for method_idx, method in enumerate(methods):
print(method.ljust(7) + '%.2f'%(PSNRs[method_idx]) + ' / %.4f'%(SSIMs[method_idx]) + ' / %.3f'%(LPIPSs[method_idx]))
PSNRs_total += PSNRs
SSIMs_total += SSIMs
LPIPSs_total += LPIPSs
PSNRs_total = PSNRs_total / len(sequences)
SSIMs_total = SSIMs_total / len(sequences)
LPIPSs_total = LPIPSs_total / len(sequences)
print('Avg.')
for method_idx, method in enumerate(methods):
print(method.ljust(7) + '%.2f'%(PSNRs_total[method_idx]) + ' / %.4f'%(SSIMs_total[method_idx]) + ' / %.3f'%(LPIPSs_total[method_idx]))
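# Note (added): readimage above implies the expected layout of the results
# folder: <data_dir>/<method>/<sequence>/v000_tXXX.png, i.e. renderings of the
# fixed view v000 at times t000-t011, with ground truth stored under the
# method name 'gt'. PSNR/SSIM/LPIPS are averaged per method over those frames
# and then over the seven sequences.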
================================================
FILE: utils/flow_utils.py
================================================
import os
import cv2
import numpy as np
from PIL import Image
from os.path import *
UNKNOWN_FLOW_THRESH = 1e7
def flow_to_image(flow, global_max=None):
"""
Convert flow into middlebury color code image
:param flow: optical flow map
:return: optical flow image in middlebury color
"""
u = flow[:, :, 0]
v = flow[:, :, 1]
maxu = -999.
maxv = -999.
minu = 999.
minv = 999.
idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)
u[idxUnknow] = 0
v[idxUnknow] = 0
maxu = max(maxu, np.max(u))
minu = min(minu, np.min(u))
maxv = max(maxv, np.max(v))
minv = min(minv, np.min(v))
rad = np.sqrt(u ** 2 + v ** 2)
    if global_max is None:
maxrad = max(-1, np.max(rad))
else:
maxrad = global_max
u = u/(maxrad + np.finfo(float).eps)
v = v/(maxrad + np.finfo(float).eps)
img = compute_color(u, v)
idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)
img[idx] = 0
return np.uint8(img)
def compute_color(u, v):
"""
compute optical flow color map
:param u: optical flow horizontal map
:param v: optical flow vertical map
:return: optical flow in color code
"""
[h, w] = u.shape
img = np.zeros([h, w, 3])
nanIdx = np.isnan(u) | np.isnan(v)
u[nanIdx] = 0
v[nanIdx] = 0
colorwheel = make_color_wheel()
ncols = np.size(colorwheel, 0)
rad = np.sqrt(u**2+v**2)
a = np.arctan2(-v, -u) / np.pi
fk = (a+1) / 2 * (ncols - 1) + 1
k0 = np.floor(fk).astype(int)
k1 = k0 + 1
k1[k1 == ncols+1] = 1
f = fk - k0
for i in range(0, np.size(colorwheel,1)):
tmp = colorwheel[:, i]
col0 = tmp[k0-1] / 255
col1 = tmp[k1-1] / 255
col = (1-f) * col0 + f * col1
idx = rad <= 1
col[idx] = 1-rad[idx]*(1-col[idx])
notidx = np.logical_not(idx)
col[notidx] *= 0.75
img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
return img
def make_color_wheel():
"""
Generate color wheel according Middlebury color code
:return: Color wheel
"""
RY = 15
YG = 6
GC = 4
CB = 11
BM = 13
MR = 6
ncols = RY + YG + GC + CB + BM + MR
colorwheel = np.zeros([ncols, 3])
col = 0
# RY
colorwheel[0:RY, 0] = 255
colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
col += RY
# YG
colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
colorwheel[col:col+YG, 1] = 255
col += YG
# GC
colorwheel[col:col+GC, 1] = 255
colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
col += GC
# CB
colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
colorwheel[col:col+CB, 2] = 255
col += CB
# BM
colorwheel[col:col+BM, 2] = 255
colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
    col += BM
# MR
colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
colorwheel[col:col+MR, 0] = 255
return colorwheel
def resize_flow(flow, H_new, W_new):
H_old, W_old = flow.shape[0:2]
flow_resized = cv2.resize(flow, (W_new, H_new), interpolation=cv2.INTER_LINEAR)
    # channel 0 is the horizontal component (u), which scales with width;
    # channel 1 is the vertical component (v), which scales with height
    flow_resized[:, :, 0] *= W_new / W_old
    flow_resized[:, :, 1] *= H_new / H_old
return flow_resized
def warp_flow(img, flow):
h, w = flow.shape[:2]
flow_new = flow.copy()
flow_new[:,:,0] += np.arange(w)
flow_new[:,:,1] += np.arange(h)[:,np.newaxis]
res = cv2.remap(img, flow_new, None,
cv2.INTER_CUBIC,
borderMode=cv2.BORDER_CONSTANT)
return res
def consistCheck(flowB, flowF):
# |--------------------| |--------------------|
# | y | | v |
# | x * | | u * |
# | | | |
# |--------------------| |--------------------|
# sub: numPix * [y x t]
imgH, imgW, _ = flowF.shape
(fy, fx) = np.mgrid[0 : imgH, 0 : imgW].astype(np.float32)
fxx = fx + flowB[:, :, 0] # horizontal
fyy = fy + flowB[:, :, 1] # vertical
u = (fxx + cv2.remap(flowF[:, :, 0], fxx, fyy, cv2.INTER_LINEAR) - fx)
v = (fyy + cv2.remap(flowF[:, :, 1], fxx, fyy, cv2.INTER_LINEAR) - fy)
BFdiff = (u ** 2 + v ** 2) ** 0.5
return BFdiff, np.stack((u, v), axis=2)
def read_optical_flow(basedir, img_i_name, read_fwd):
flow_dir = os.path.join(basedir, 'flow')
fwd_flow_path = os.path.join(flow_dir, '%s_fwd.npz'%img_i_name[:-4])
bwd_flow_path = os.path.join(flow_dir, '%s_bwd.npz'%img_i_name[:-4])
if read_fwd:
fwd_data = np.load(fwd_flow_path)
fwd_flow, fwd_mask = fwd_data['flow'], fwd_data['mask']
return fwd_flow, fwd_mask
else:
bwd_data = np.load(bwd_flow_path)
bwd_flow, bwd_mask = bwd_data['flow'], bwd_data['mask']
return bwd_flow, bwd_mask
def compute_epipolar_distance(T_21, K, p_1, p_2):
R_21 = T_21[:3, :3]
t_21 = T_21[:3, 3]
E_mat = np.dot(skew(t_21), R_21)
# compute bearing vector
inv_K = np.linalg.inv(K)
F_mat = np.dot(np.dot(inv_K.T, E_mat), inv_K)
l_2 = np.dot(F_mat, p_1)
algebric_e_distance = np.sum(p_2 * l_2, axis=0)
n_term = np.sqrt(l_2[0, :]**2 + l_2[1, :]**2) + 1e-8
geometric_e_distance = algebric_e_distance/n_term
geometric_e_distance = np.abs(geometric_e_distance)
return geometric_e_distance
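# Note (added): compute_epipolar_distance builds the fundamental matrix
# F = K^-T [t]_x R K^-1 from the relative pose T_21 and the intrinsics K, then
# returns for each correspondence the distance from p_2 to its epipolar line
# l_2 = F p_1 (the algebraic residual p_2^T l_2 normalized by the length of
# the line's first two coefficients).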
def skew(x):
return np.array([[0, -x[2], x[1]],
[x[2], 0, -x[0]],
[-x[1], x[0], 0]])
================================================
FILE: utils/generate_data.py
================================================
import os
import numpy as np
import imageio
import glob
import torch
import torchvision
import skimage.morphology
import argparse
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def create_dir(dir):
if not os.path.exists(dir):
os.makedirs(dir)
def multi_view_multi_time(args):
"""
    Generate multi-view multi-time data: extract frames from the input video
    and save images, COLMAP inputs, and coarse background masks.
"""
Maskrcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).cuda().eval()
threshold = 0.5
videoname, ext = os.path.splitext(os.path.basename(args.videopath))
imgs = []
reader = imageio.get_reader(args.videopath)
for i, im in enumerate(reader):
imgs.append(im)
imgs = np.array(imgs)
num_frames, H, W, _ = imgs.shape
imgs = imgs[::int(np.ceil(num_frames / 100))]
create_dir(os.path.join(args.data_dir, videoname, 'images'))
create_dir(os.path.join(args.data_dir, videoname, 'images_colmap'))
create_dir(os.path.join(args.data_dir, videoname, 'background_mask'))
for idx, img in enumerate(imgs):
print(idx)
imageio.imwrite(os.path.join(args.data_dir, videoname, 'images', str(idx).zfill(3) + '.png'), img)
imageio.imwrite(os.path.join(args.data_dir, videoname, 'images_colmap', str(idx).zfill(3) + '.jpg'), img)
# Get coarse background mask
img = torchvision.transforms.functional.to_tensor(img).to(device)
background_mask = torch.FloatTensor(H, W).fill_(1.0).to(device)
objPredictions = Maskrcnn([img])[0]
for intMask in range(len(objPredictions['masks'])):
if objPredictions['scores'][intMask].item() > threshold:
if objPredictions['labels'][intMask].item() == 1: # person
background_mask[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
background_mask_np = ((background_mask.cpu().numpy() > 0.1) * 255).astype(np.uint8)
imageio.imwrite(os.path.join(args.data_dir, videoname, 'background_mask', str(idx).zfill(3) + '.jpg.png'), background_mask_np)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--videopath", type=str,
help='video path')
parser.add_argument("--data_dir", type=str, default='../data/',
help='where to store data')
args = parser.parse_args()
multi_view_multi_time(args)
================================================
FILE: utils/generate_depth.py
================================================
"""Compute depth maps for images in the input folder.
"""
import os
import cv2
import glob
import torch
import argparse
import numpy as np
from torchvision.transforms import Compose
from midas.midas_net import MidasNet
from midas.transforms import Resize, NormalizeImage, PrepareForNet
def create_dir(dir):
if not os.path.exists(dir):
os.makedirs(dir)
def read_image(path):
"""Read image and output RGB image (0-1).
Args:
path (str): path to file
Returns:
array: RGB image (0-1)
"""
img = cv2.imread(path)
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
return img
def run(input_path, output_path, output_img_path, model_path):
"""Run MonoDepthNN to compute depth maps.
Args:
input_path (str): path to input folder
        output_path (str): path to output folder for the .npy predictions
        output_img_path (str): path to output folder for PNG visualizations
        model_path (str): path to saved model
"""
print("initialize")
# select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: %s" % device)
# load network
model = MidasNet(model_path, non_negative=True)
sh = cv2.imread(sorted(glob.glob(os.path.join(input_path, "*")))[0]).shape
net_w, net_h = sh[1], sh[0]
resize_mode="upper_bound"
transform = Compose(
[
Resize(
net_w,
net_h,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method=resize_mode,
image_interpolation_method=cv2.INTER_CUBIC,
),
NormalizeImage(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
PrepareForNet(),
]
)
model.eval()
model.to(device)
# get input
img_names = sorted(glob.glob(os.path.join(input_path, "*")))
num_images = len(img_names)
# create output folder
os.makedirs(output_path, exist_ok=True)
print("start processing")
for ind, img_name in enumerate(img_names):
print(" processing {} ({}/{})".format(img_name, ind + 1, num_images))
# input
img = read_image(img_name)
img_input = transform({"image": img})["image"]
# compute
with torch.no_grad():
sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
prediction = model.forward(sample)
prediction = (
torch.nn.functional.interpolate(
prediction.unsqueeze(1),
size=[net_h, net_w],
mode="bicubic",
align_corners=False,
)
.squeeze()
.cpu()
.numpy()
)
# output
filename = os.path.join(
output_path, os.path.splitext(os.path.basename(img_name))[0]
)
print(filename + '.npy')
np.save(filename + '.npy', prediction.astype(np.float32))
depth_min = prediction.min()
depth_max = prediction.max()
max_val = (2**(8*2))-1
if depth_max - depth_min > np.finfo("float").eps:
out = max_val * (prediction - depth_min) / (depth_max - depth_min)
else:
            out = np.zeros(prediction.shape, dtype=prediction.dtype)
cv2.imwrite(os.path.join(output_img_path, os.path.splitext(os.path.basename(img_name))[0] + '.png'), out.astype("uint16"))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, help='Dataset path')
parser.add_argument('--model', help="restore midas checkpoint")
args = parser.parse_args()
input_path = os.path.join(args.dataset_path, 'images')
output_path = os.path.join(args.dataset_path, 'disp')
output_img_path = os.path.join(args.dataset_path, 'disp_png')
create_dir(output_path)
create_dir(output_img_path)
# set torch options
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# compute depth maps
run(input_path, output_path, output_img_path, args.model)
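# Usage note (added), derived from the argparse flags above:
#   python generate_depth.py --dataset_path /path/to/scene --model /path/to/midas_checkpoint.pt
# reads <dataset_path>/images, saves per-frame MiDaS predictions as .npy files
# in <dataset_path>/disp, and writes 16-bit PNG visualizations to
# <dataset_path>/disp_png.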
================================================
FILE: utils/generate_flow.py
================================================
import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image
from RAFT.raft import RAFT
from RAFT.utils import flow_viz
from RAFT.utils.utils import InputPadder
from flow_utils import *
DEVICE = 'cuda'
def create_dir(dir):
if not os.path.exists(dir):
os.makedirs(dir)
def load_image(imfile):
img = np.array(Image.open(imfile)).astype(np.uint8)
img = torch.from_numpy(img).permute(2, 0, 1).float()
return img[None].to(DEVICE)
def warp_flow(img, flow):
h, w = flow.shape[:2]
flow_new = flow.copy()
flow_new[:,:,0] += np.arange(w)
flow_new[:,:,1] += np.arange(h)[:,np.newaxis]
res = cv2.remap(img, flow_new, None, cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT)
return res
def compute_fwdbwd_mask(fwd_flow, bwd_flow):
alpha_1 = 0.5
alpha_2 = 0.5
bwd2fwd_flow = warp_flow(bwd_flow, fwd_flow)
fwd_lr_error = np.linalg.norm(fwd_flow + bwd2fwd_flow, axis=-1)
fwd_mask = fwd_lr_error < alpha_1 * (np.linalg.norm(fwd_flow, axis=-1) \
+ np.linalg.norm(bwd2fwd_flow, axis=-1)) + alpha_2
fwd2bwd_flow = warp_flow(fwd_flow, bwd_flow)
bwd_lr_error = np.linalg.norm(bwd_flow + fwd2bwd_flow, axis=-1)
bwd_mask = bwd_lr_error < alpha_1 * (np.linalg.norm(bwd_flow, axis=-1) \
+ np.linalg.norm(fwd2bwd_flow, axis=-1)) + alpha_2
return fwd_mask, bwd_mask
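# Note (added): compute_fwdbwd_mask is the usual forward-backward consistency
# check. A pixel p is kept in fwd_mask when
#   || F_fwd(p) + F_bwd(p + F_fwd(p)) || < alpha_1 * ( ||F_fwd(p)|| + ||F_bwd(p + F_fwd(p))|| ) + alpha_2
# and symmetrically for bwd_mask; the resulting masks are stored next to the
# flow below and mark where the estimated flow can be trusted.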
def run(args, input_path, output_path, output_img_path):
model = torch.nn.DataParallel(RAFT(args))
model.load_state_dict(torch.load(args.model))
model = model.module
model.to(DEVICE)
model.eval()
with torch.no_grad():
images = glob.glob(os.path.join(input_path, '*.png')) + \
glob.glob(os.path.join(input_path, '*.jpg'))
images = sorted(images)
for i in range(len(images) - 1):
print(i)
image1 = load_image(images[i])
image2 = load_image(images[i + 1])
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
_, flow_fwd = model(image1, image2, iters=20, test_mode=True)
_, flow_bwd = model(image2, image1, iters=20, test_mode=True)
flow_fwd = padder.unpad(flow_fwd[0]).cpu().numpy().transpose(1, 2, 0)
flow_bwd = padder.unpad(flow_bwd[0]).cpu().numpy().transpose(1, 2, 0)
mask_fwd, mask_bwd = compute_fwdbwd_mask(flow_fwd, flow_bwd)
# Save flow
np.savez(os.path.join(output_path, '%03d_fwd.npz'%i), flow=flow_fwd, mask=mask_fwd)
np.savez(os.path.join(output_path, '%03d_bwd.npz'%(i + 1)), flow=flow_bwd, mask=mask_bwd)
# Save flow_img
Image.fromarray(flow_viz.flow_to_image(flow_fwd)).save(os.path.join(output_img_path, '%03d_fwd.png'%i))
Image.fromarray(flow_viz.flow_to_image(flow_bwd)).save(os.path.join(output_img_path, '%03d_bwd.png'%(i + 1)))
Image.fromarray(mask_fwd).save(os.path.join(output_img_path, '%03d_fwd_mask.png'%i))
Image.fromarray(mask_bwd).save(os.path.join(output_img_path, '%03d_bwd_mask.png'%(i + 1)))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, help='Dataset path')
parser.add_argument('--model', help="restore RAFT checkpoint")
parser.add_argument('--small', action='store_true', help='use small model')
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
args = parser.parse_args()
input_path = os.path.join(args.dataset_path, 'images')
output_path = os.path.join(args.dataset_path, 'flow')
output_img_path = os.path.join(args.dataset_path, 'flow_png')
create_dir(output_path)
create_dir(output_img_path)
run(args, input_path, output_path, output_img_path)
================================================
FILE: utils/generate_motion_mask.py
================================================
import os
import cv2
import PIL
import glob
import torch
import argparse
import numpy as np
from colmap_utils import read_cameras_binary, read_images_binary, read_points3d_binary
import skimage.morphology
import torchvision
from flow_utils import read_optical_flow, compute_epipolar_distance, skew
def create_dir(dir):
if not os.path.exists(dir):
os.makedirs(dir)
def extract_poses(im):
R = im.qvec2rotmat()
t = im.tvec.reshape([3,1])
bottom = np.array([0,0,0,1.]).reshape([1,4])
m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
return m
def load_colmap_data(realdir):
camerasfile = os.path.join(realdir, 'sparse/0/cameras.bin')
camdata = read_cameras_binary(camerasfile)
list_of_keys = list(camdata.keys())
cam = camdata[list_of_keys[0]]
    print('Cameras', len(camdata))
h, w, f = cam.height, cam.width, cam.params[0]
# w, h, f = factor * w, factor * h, factor * f
hwf = np.array([h,w,f]).reshape([3,1])
imagesfile = os.path.join(realdir, 'sparse/0/images.bin')
imdata = read_images_binary(imagesfile)
w2c_mats = []
# bottom = np.array([0,0,0,1.]).reshape([1,4])
names = [imdata[k].name for k in imdata]
img_keys = [k for k in imdata]
print( 'Images #', len(names))
perm = np.argsort(names)
return imdata, perm, img_keys, hwf
def run_maskrcnn(model, img_path, intWidth=1024, intHeight=576):
# intHeight = 576
# intWidth = 1024
threshold = 0.5
o_image = PIL.Image.open(img_path)
image = o_image.resize((intWidth, intHeight), PIL.Image.ANTIALIAS)
image_tensor = torchvision.transforms.functional.to_tensor(image).cuda()
tenHumans = torch.FloatTensor(intHeight, intWidth).fill_(1.0).cuda()
objPredictions = model([image_tensor])[0]
    # COCO category ids treated as potentially dynamic foreground objects:
    # person (1), bicycle (2), motorcycle (4), truck (8), cat (17), dog (18),
    # umbrella (28), snowboard (36), skateboard (41).
    dynamic_labels = {1, 2, 4, 8, 17, 18, 28, 36, 41}
    for intMask in range(objPredictions['masks'].size(0)):
        if objPredictions['scores'][intMask].item() > threshold:
            if objPredictions['labels'][intMask].item() in dynamic_labels:
                tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0
npyMask = skimage.morphology.erosion(tenHumans.cpu().numpy(),
skimage.morphology.disk(1))
npyMask = ((npyMask < 1e-3) * 255.0).clip(0.0, 255.0).astype(np.uint8)
return npyMask
def motion_segmentation(basedir, threshold,
input_semantic_w=1024,
input_semantic_h=576):
points3dfile = os.path.join(basedir, 'sparse/0/points3D.bin')
pts3d = read_points3d_binary(points3dfile)
img_dir = glob.glob(basedir + '/images_colmap')[0]
img0 = glob.glob(glob.glob(img_dir)[0] + '/*jpg')[0]
shape_0 = cv2.imread(img0).shape
resized_height, resized_width = shape_0[0], shape_0[1]
imdata, perm, img_keys, hwf = load_colmap_data(basedir)
scale_x, scale_y = resized_width / float(hwf[1]), resized_height / float(hwf[0])
K = np.eye(3)
K[0, 0] = hwf[2]
K[0, 2] = hwf[1] / 2.
K[1, 1] = hwf[2]
K[1, 2] = hwf[0] / 2.
xx = range(0, resized_width)
yy = range(0, resized_height)
xv, yv = np.meshgrid(xx, yy)
p_ref = np.float32(np.stack((xv, yv), axis=-1))
p_ref_h = np.reshape(p_ref, (-1, 2))
p_ref_h = np.concatenate((p_ref_h, np.ones((p_ref_h.shape[0], 1))), axis=-1).T
num_frames = len(perm)
if os.path.isdir(os.path.join(basedir, 'images_colmap')):
num_colmap_frames = len(glob.glob(os.path.join(basedir, 'images_colmap', '*.jpg')))
num_data_frames = len(glob.glob(os.path.join(basedir, 'images', '*.png')))
if num_colmap_frames != num_data_frames:
num_frames = num_data_frames
save_mask_dir = os.path.join(basedir, 'motion_segmentation')
create_dir(save_mask_dir)
for i in range(0, num_frames):
im_prev = imdata[img_keys[perm[max(0, i - 1)]]]
im_ref = imdata[img_keys[perm[i]]]
im_post = imdata[img_keys[perm[min(num_frames -1, i + 1)]]]
print(im_prev.name, im_ref.name, im_post.name)
T_prev_G = extract_poses(im_prev)
T_ref_G = extract_poses(im_ref)
T_post_G = extract_poses(im_post)
T_ref2prev = np.dot(T_prev_G, np.linalg.inv(T_ref_G))
T_ref2post = np.dot(T_post_G, np.linalg.inv(T_ref_G))
# load optical flow
if i == 0:
fwd_flow, _ = read_optical_flow(basedir,
im_ref.name,
read_fwd=True)
bwd_flow = np.zeros_like(fwd_flow)
elif i == num_frames - 1:
bwd_flow, _ = read_optical_flow(basedir,
im_ref.name,
read_fwd=False)
fwd_flow = np.zeros_like(bwd_flow)
else:
fwd_flow, _ = read_optical_flow(basedir,
im_ref.name,
read_fwd=True)
bwd_flow, _ = read_optical_flow(basedir,
im_ref.name,
read_fwd=False)
p_post = p_ref + fwd_flow
p_post_h = np.reshape(p_post, (-1, 2))
p_post_h = np.concatenate((p_post_h, np.ones((p_post_h.shape[0], 1))), axis=-1).T
fwd_e_dist = compute_epipolar_distance(T_ref2post, K,
p_ref_h, p_post_h)
fwd_e_dist = np.reshape(fwd_e_dist, (fwd_flow.shape[0], fwd_flow.shape[1]))
p_prev = p_ref + bwd_flow
p_prev_h = np.reshape(p_prev, (-1, 2))
p_prev_h = np.concatenate((p_prev_h, np.ones((p_prev_h.shape[0], 1))), axis=-1).T
bwd_e_dist = compute_epipolar_distance(T_ref2prev, K,
p_ref_h, p_prev_h)
bwd_e_dist = np.reshape(bwd_e_dist, (bwd_flow.shape[0], bwd_flow.shape[1]))
e_dist = np.maximum(bwd_e_dist, fwd_e_dist)
motion_mask = skimage.morphology.binary_opening(e_dist > threshold, skimage.morphology.disk(1))
cv2.imwrite(os.path.join(save_mask_dir, im_ref.name.replace('.jpg', '.png')), np.uint8(255 * (0. + motion_mask)))
# RUN SEMANTIC SEGMENTATION
img_dir = os.path.join(basedir, 'images')
img_path_list = sorted(glob.glob(os.path.join(img_dir, '*.jpg'))) \
+ sorted(glob.glob(os.path.join(img_dir, '*.png')))
semantic_mask_dir = os.path.join(basedir, 'semantic_mask')
netMaskrcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).cuda().eval()
create_dir(semantic_mask_dir)
for i in range(0, len(img_path_list)):
img_path = img_path_list[i]
img_name = img_path.split('/')[-1]
semantic_mask = run_maskrcnn(netMaskrcnn, img_path,
input_semantic_w,
input_semantic_h)
cv2.imwrite(os.path.join(semantic_mask_dir,
img_name.replace('.jpg', '.png')),
semantic_mask)
# combine them
save_mask_dir = os.path.join(basedir, 'motion_masks')
create_dir(save_mask_dir)
mask_dir = os.path.join(basedir, 'motion_segmentation')
mask_path_list = sorted(glob.glob(os.path.join(mask_dir, '*.png')))
semantic_dir = os.path.join(basedir, 'semantic_mask')
for mask_path in mask_path_list:
print(mask_path)
motion_mask = cv2.imread(mask_path)
motion_mask = cv2.resize(motion_mask, (resized_width, resized_height),
interpolation=cv2.INTER_NEAREST)
motion_mask = motion_mask[:, :, 0] > 0.1
# combine from motion segmentation
semantic_mask = cv2.imread(os.path.join(semantic_dir, mask_path.split('/')[-1]))
semantic_mask = cv2.resize(semantic_mask, (resized_width, resized_height),
interpolation=cv2.INTER_NEAREST)
semantic_mask = semantic_mask[:, :, 0] > 0.1
motion_mask = semantic_mask | motion_mask
motion_mask = skimage.morphology.dilation(motion_mask, skimage.morphology.disk(2))
cv2.imwrite(os.path.join(save_mask_dir, '%s'%mask_path.split('/')[-1]),
np.uint8(np.clip((motion_mask), 0, 1) * 255) )
# delete old mask dir
os.system('rm -r %s'%mask_dir)
os.system('rm -r %s'%semantic_dir)
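# Background sketch (illustrative; the real implementation lives in flow_utils and may differ):
# compute_epipolar_distance(T_21, K, p1_h, p2_h) is called above with homogeneous pixel grids.
# A standard formulation builds the fundamental matrix from the relative pose and intrinsics
# and measures each flow-matched point's distance to its epipolar line; pixels violating this
# geometric constraint are treated as physically moving and end up in the motion mask.
def _epipolar_distance_sketch(T_21, K, p1_h, p2_h):
    # Assumed conventions: T_21 is a 4x4 transform from the reference camera to the other
    # camera, and skew() (imported from flow_utils) returns the 3x3 cross-product matrix.
    R, t = T_21[:3, :3], T_21[:3, 3]
    E = skew(t) @ R                                    # essential matrix
    F = np.linalg.inv(K).T @ E @ np.linalg.inv(K)      # fundamental matrix
    lines = F @ p1_h                                   # epipolar lines in the second view, (3, N)
    num = np.abs(np.sum(p2_h * lines, axis=0))
    den = np.sqrt(lines[0] ** 2 + lines[1] ** 2) + 1e-8
    return num / den                                   # point-to-line distance per pixel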
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str, help='Dataset path')
parser.add_argument("--epi_threshold", type=float,
default=1.0,
help='epipolar distance threshold for physical motion segmentation')
parser.add_argument("--input_flow_w", type=int,
default=768,
help='input image width for optical flow, \
the height will be computed based on original aspect ratio ')
parser.add_argument("--input_semantic_w", type=int,
default=1024,
help='input image width for semantic segmentation')
parser.add_argument("--input_semantic_h", type=int,
default=576,
help='input image height for semantic segmentation')
args = parser.parse_args()
motion_segmentation(args.dataset_path, args.epi_threshold,
args.input_semantic_w,
args.input_semantic_h)
================================================
FILE: utils/generate_pose.py
================================================
import os
import glob
import argparse
import numpy as np
from colmap_utils import read_cameras_binary, read_images_binary, read_points3d_binary
def load_colmap_data(realdir):
camerasfile = os.path.join(realdir, 'sparse/0/cameras.bin')
camdata = read_cameras_binary(camerasfile)
list_of_keys = list(camdata.keys())
cam = camdata[list_of_keys[0]]
    print('Cameras', len(camdata))
h, w, f = cam.height, cam.width, cam.params[0]
# w, h, f = factor * w, factor * h, factor * f
hwf = np.array([h,w,f]).reshape([3,1])
imagesfile = os.path.join(realdir, 'sparse/0/images.bin')
imdata = read_images_binary(imagesfile)
w2c_mats = []
bottom = np.array([0,0,0,1.]).reshape([1,4])
names = [imdata[k].name for k in imdata]
img_keys = [k for k in imdata]
print('Images #', len(names))
perm = np.argsort(names)
points3dfile = os.path.join(realdir, 'sparse/0/points3D.bin')
pts3d = read_points3d_binary(points3dfile)
bounds_mats = []
for i in perm[0:len(img_keys)]:
im = imdata[img_keys[i]]
print(im.name)
R = im.qvec2rotmat()
t = im.tvec.reshape([3,1])
m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
w2c_mats.append(m)
pts_3d_idx = im.point3D_ids
pts_3d_vis_idx = pts_3d_idx[pts_3d_idx >= 0]
#
depth_list = []
for k in range(len(pts_3d_vis_idx)):
point_info = pts3d[pts_3d_vis_idx[k]]
P_g = point_info.xyz
P_c = np.dot(R, P_g.reshape(3, 1)) + t.reshape(3, 1)
depth_list.append(P_c[2])
zs = np.array(depth_list)
close_depth, inf_depth = np.percentile(zs, 5), np.percentile(zs, 95)
bounds = np.array([close_depth, inf_depth])
bounds_mats.append(bounds)
w2c_mats = np.stack(w2c_mats, 0)
c2w_mats = np.linalg.inv(w2c_mats)
poses = c2w_mats[:, :3, :4].transpose([1,2,0])
poses = np.concatenate([poses, np.tile(hwf[..., np.newaxis],
[1,1,poses.shape[-1]])], 1)
# must switch to [-u, r, -t] from [r, -u, t], NOT [r, u, -t]
poses = np.concatenate([poses[:, 1:2, :],
poses[:, 0:1, :],
-poses[:, 2:3, :],
poses[:, 3:4, :],
poses[:, 4:5, :]], 1)
save_arr = []
for i in range((poses.shape[2])):
save_arr.append(np.concatenate([poses[..., i].ravel(), bounds_mats[i]], 0))
save_arr = np.array(save_arr)
print(save_arr.shape)
# Use all frames to calculate COLMAP camera poses.
if os.path.isdir(os.path.join(realdir, 'images_colmap')):
num_colmap_frames = len(glob.glob(os.path.join(realdir, 'images_colmap', '*.jpg')))
num_data_frames = len(glob.glob(os.path.join(realdir, 'images', '*.png')))
assert num_colmap_frames == save_arr.shape[0]
np.save(os.path.join(realdir, 'poses_bounds.npy'), save_arr[:num_data_frames, :])
else:
np.save(os.path.join(realdir, 'poses_bounds.npy'), save_arr)
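# Layout note and read-back sketch (illustrative): each row of poses_bounds.npy is the 3x5
# camera-to-world [R | t | hwf] matrix flattened row-major (15 values) followed by the
# near/far depth bounds (2 values), so the saved array has shape (num_frames, 17).
def _load_poses_bounds_example(realdir):
    arr = np.load(os.path.join(realdir, 'poses_bounds.npy'))
    poses = arr[:, :15].reshape(-1, 3, 5)   # per-frame 3x5 pose with [h, w, f] as last column
    bounds = arr[:, 15:]                    # per-frame [close_depth, inf_depth]
    return poses, bounds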
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_path", type=str,
help='Dataset path')
args = parser.parse_args()
load_colmap_data(args.dataset_path)
================================================
FILE: utils/midas/base_model.py
================================================
import torch
class BaseModel(torch.nn.Module):
def load(self, path):
"""Load model from file.
Args:
path (str): file path
"""
parameters = torch.load(path, map_location=torch.device('cpu'))
if "optimizer" in parameters:
parameters = parameters["model"]
self.load_state_dict(parameters)
================================================
FILE: utils/midas/blocks.py
================================================
import torch
import torch.nn as nn
from .vit import (
_make_pretrained_vitb_rn50_384,
_make_pretrained_vitl16_384,
_make_pretrained_vitb16_384,
forward_vit,
)
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
if backbone == "vitl16_384":
pretrained = _make_pretrained_vitl16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[256, 512, 1024, 1024], features, groups=groups, expand=expand
) # ViT-L/16 - 85.0% Top1 (backbone)
elif backbone == "vitb_rn50_384":
pretrained = _make_pretrained_vitb_rn50_384(
use_pretrained,
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)
scratch = _make_scratch(
[256, 512, 768, 768], features, groups=groups, expand=expand
        ) # ViT-B/16 + ResNet-50 hybrid (backbone)
elif backbone == "vitb16_384":
pretrained = _make_pretrained_vitb16_384(
use_pretrained, hooks=hooks, use_readout=use_readout
)
scratch = _make_scratch(
[96, 192, 384, 768], features, groups=groups, expand=expand
) # ViT-B/16 - 84.6% Top1 (backbone)
elif backbone == "resnext101_wsl":
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
elif backbone == "efficientnet_lite3":
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
else:
print(f"Backbone '{backbone}' not implemented")
assert False
return pretrained, scratch
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
out_shape4 = out_shape
if expand==True:
out_shape1 = out_shape
out_shape2 = out_shape*2
out_shape3 = out_shape*4
out_shape4 = out_shape*8
scratch.layer1_rn = nn.Conv2d(
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer2_rn = nn.Conv2d(
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer3_rn = nn.Conv2d(
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
scratch.layer4_rn = nn.Conv2d(
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
)
return scratch
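# Shape note (illustrative, not executed here): with expand=False every layerX_rn projects to
# out_shape channels; with expand=True the widths double with depth, e.g.
#   scratch = _make_scratch([256, 512, 1024, 2048], 256, expand=True)
#   # layer1_rn..layer4_rn then output 256, 512, 1024, 2048 channels respectively.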
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
efficientnet = torch.hub.load(
"rwightman/gen-efficientnet-pytorch",
"tf_efficientnet_lite3",
pretrained=use_pretrained,
exportable=exportable
)
return _make_efficientnet_backbone(efficientnet)
def _make_efficientnet_backbone(effnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
)
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
return pretrained
def _make_resnet_backbone(resnet):
pretrained = nn.Module()
pretrained.layer1 = nn.Sequential(
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
)
pretrained.layer2 = resnet.layer2
pretrained.layer3 = resnet.layer3
pretrained.layer4 = resnet.layer4
return pretrained
def _make_pretrained_resnext101_wsl(use_pretrained):
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
return _make_resnet_backbone(resnet)
class Interpolate(nn.Module):
"""Interpolation module.
"""
def __init__(self, scale_factor, mode, align_corners=False):
"""Init.
Args:
scale_factor (float): scaling
mode (str): interpolation mode
"""
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: interpolated data
"""
x = self.interp(
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
)
return x
class ResidualConvUnit(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.relu(x)
out = self.conv1(out)
out = self.relu(out)
out = self.conv2(out)
return out + x
class FeatureFusionBlock(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.resConfUnit1 = ResidualConvUnit(features)
self.resConfUnit2 = ResidualConvUnit(features)
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
output += self.resConfUnit1(xs[1])
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=True
)
return output
class ResidualConvUnit_custom(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups=1
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
)
if self.bn==True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn==True:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn==True:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
# return out + x
class FeatureFusionBlock_custom(nn.Module):
"""Feature fusion block.
"""
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock_custom, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups=1
self.expand = expand
out_features = features
if self.expand==True:
out_features = features//2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, *xs):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
# output += res
output = self.resConfUnit2(output)
output = nn.functional.interpolate(
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
)
output = self.out_conv(output)
return output
================================================
FILE: utils/midas/midas_net.py
================================================
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn
from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
class MidasNet(BaseModel):
"""Network for monocular depth estimation.
"""
def __init__(self, path=None, features=256, non_negative=True):
"""Init.
Args:
path (str, optional): Path to saved model. Defaults to None.
features (int, optional): Number of features. Defaults to 256.
            non_negative (bool, optional): Clamp the output to be non-negative. Defaults to True.
"""
print("Loading weights: ", path)
super(MidasNet, self).__init__()
use_pretrained = False if path is None else True
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
self.scratch.refinenet4 = FeatureFusionBlock(features)
self.scratch.refinenet3 = FeatureFusionBlock(features)
self.scratch.refinenet2 = FeatureFusionBlock(features)
self.scratch.refinenet1 = FeatureFusionBlock(features)
self.scratch.output_conv = nn.Sequential(
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear"),
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True) if non_negative else nn.Identity(),
)
if path:
self.load(path)
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
layer_1 = self.pretrained.layer1(x)
layer_2 = self.pretrained.layer2(layer_1)
layer_3 = self.pretrained.layer3(layer_2)
layer_4 = self.pretrained.layer4(layer_3)
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn)
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv(path_1)
return torch.squeeze(out, dim=1)
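# Usage sketch (illustrative; names are hypothetical and the preprocessing is assumed to match
# what utils/generate_depth.py does with the Resize / NormalizeImage / PrepareForNet transforms):
def _midas_inference_example(model_path, image_tensor):
    # image_tensor: (1, 3, H, W) float tensor, RGB, already resized so H and W are multiples
    # of 32 and normalized with the MiDaS mean/std.
    model = MidasNet(model_path, non_negative=True).cuda().eval()
    with torch.no_grad():
        prediction = model(image_tensor.cuda())  # (1, H, W) relative inverse depth
    return prediction.squeeze(0).cpu()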
================================================
FILE: utils/midas/transforms.py
================================================
import numpy as np
import cv2
import math
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
Args:
sample (dict): sample
size (tuple): image size
Returns:
tuple: new size
"""
shape = list(sample["disparity"].shape)
if shape[0] >= size[0] and shape[1] >= size[1]:
return sample
scale = [0, 0]
scale[0] = size[0] / shape[0]
scale[1] = size[1] / shape[1]
scale = max(scale)
shape[0] = math.ceil(scale * shape[0])
shape[1] = math.ceil(scale * shape[1])
# resize
sample["image"] = cv2.resize(
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
)
sample["disparity"] = cv2.resize(
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
tuple(shape[::-1]),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return tuple(shape)
class Resize(object):
"""Resize sample to given size (width, height).
"""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
                Output width and height are constrained to be a multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
                # scale as little as possible
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(
f"resize_method {self.__resize_method} not implemented"
)
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(
scale_height * height, min_val=self.__height
)
new_width = self.constrain_to_multiple_of(
scale_width * width, min_val=self.__width
)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(
scale_height * height, max_val=self.__height
)
new_width = self.constrain_to_multiple_of(
scale_width * width, max_val=self.__width
)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(
sample["image"].shape[1], sample["image"].shape[0]
)
# resize sample
sample["image"] = cv2.resize(
sample["image"],
(width, height),
interpolation=self.__image_interpolation_method,
)
if self.__resize_target:
if "disparity" in sample:
sample["disparity"] = cv2.resize(
sample["disparity"],
(width, height),
interpolation=cv2.INTER_NEAREST,
)
if "depth" in sample:
sample["depth"] = cv2.resize(
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
(width, height),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return sample
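# Worked example (illustrative, not executed here):
#   resize = Resize(384, 384, keep_aspect_ratio=True, ensure_multiple_of=32,
#                   resize_method="lower_bound")
#   resize.get_size(1920, 1080)  # -> (672, 384)
# The shorter side is scaled up to meet the 384 lower bound, both sides are snapped to
# multiples of 32, and the aspect ratio is approximately preserved.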
class NormalizeImage(object):
"""Normlize image by given mean and std.
"""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input.
"""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
if "disparity" in sample:
disparity = sample["disparity"].astype(np.float32)
sample["disparity"] = np.ascontiguousarray(disparity)
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
return sample
================================================
FILE: utils/midas/vit.py
================================================
import torch
import torch.nn as nn
import timm
import types
import math
import torch.nn.functional as F
class Slice(nn.Module):
def __init__(self, start_index=1):
super(Slice, self).__init__()
self.start_index = start_index
def forward(self, x):
return x[:, self.start_index :]
class AddReadout(nn.Module):
def __init__(self, start_index=1):
super(AddReadout, self).__init__()
self.start_index = start_index
def forward(self, x):
if self.start_index == 2:
readout = (x[:, 0] + x[:, 1]) / 2
else:
readout = x[:, 0]
return x[:, self.start_index :] + readout.unsqueeze(1)
class ProjectReadout(nn.Module):
def __init__(self, in_features, start_index=1):
super(ProjectReadout, self).__init__()
self.start_index = start_index
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
def forward(self, x):
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
features = torch.cat((x[:, self.start_index :], readout), -1)
return self.project(features)
class Transpose(nn.Module):
def __init__(self, dim0, dim1):
super(Transpose, self).__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
x = x.transpose(self.dim0, self.dim1)
return x
def forward_vit(pretrained, x):
b, c, h, w = x.shape
glob = pretrained.model.forward_flex(x)
layer_1 = pretrained.activations["1"]
layer_2 = pretrained.activations["2"]
layer_3 = pretrained.activations["3"]
layer_4 = pretrained.activations["4"]
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
unflatten = nn.Sequential(
nn.Unflatten(
2,
torch.Size(
[
h // pretrained.model.patch_size[1],
w // pretrained.model.patch_size[0],
]
),
)
)
if layer_1.ndim == 3:
layer_1 = unflatten(layer_1)
if layer_2.ndim == 3:
layer_2 = unflatten(layer_2)
if layer_3.ndim == 3:
layer_3 = unflatten(layer_3)
if layer_4.ndim == 3:
layer_4 = unflatten(layer_4)
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
return layer_1, layer_2, layer_3, layer_4
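# Shape sketch (illustrative, not executed here): for a 384x384 input to a ViT-B/16 backbone,
#   layer_1, layer_2, layer_3, layer_4 = forward_vit(pretrained, torch.randn(1, 3, 384, 384))
# each hooked token sequence is unflattened to a 24x24 grid (384 / patch size 16), and the
# act_postprocess heads rescale it (stride-4 / stride-2 ConvTranspose2d, identity, stride-2
# Conv2d) so the four returned maps form a feature pyramid for the RefineNet-style decoder.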
def _resize_pos_embed(self, posemb, gs_h, gs_w):
posemb_tok, posemb_grid = (
posemb[:, : self.start_index],
posemb[0, self.start_index :],
)
gs_old = int(math.sqrt(len(posemb_grid)))
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
def forward_flex(self, x):
b, c, h, w = x.shape
pos_embed = self._resize_pos_embed(
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
)
B = x.shape[0]
if hasattr(self.patch_embed, "backbone"):
x = self.patch_embed.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1] # last feature if backbone outputs list/tuple of features
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
if getattr(self, "dist_token", None) is not None:
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
else:
cls_tokens = self.cls_token.expand(
B, -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
x = x + pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x
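# Note (illustrative): _resize_pos_embed lets forward_flex accept inputs whose resolution
# differs from the 384x384 training size. For example, a 512x384 image with 16x16 patches
# yields a 32x24 token grid, so the pretrained 24x24 positional-embedding grid is bilinearly
# interpolated to 32x24 before being added to the patch tokens.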
activations = {}
def get_activation(name):
def hook(model, input, output):
activations[name] = output
return hook
def get_readout_oper(vit_features, features, use_readout, start_index=1):
if use_readout == "ignore":
readout_oper = [Slice(start_index)] * len(features)
elif use_readout == "add":
readout_oper = [AddReadout(start_index)] * len(features)
elif use_readout == "project":
readout_oper = [
ProjectReadout(vit_features, start_index) for out_feat in features
]
else:
assert (
False
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
return readout_oper
def _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
size=[384, 384],
hooks=[2, 5, 8, 11],
vit_features=768,
use_readout="ignore",
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
# 32, 48, 136, 384
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
hooks = [5, 11, 17, 23] if hooks == None else hooks
return _make_vit_b16_backbone(
model,
features=[256, 512, 1024, 1024],
hooks=hooks,
vit_features=1024,
use_readout=use_readout,
)
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
)
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
)
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
model = timm.create_model(
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
)
hooks = [2, 5, 8, 11] if hooks == None else hooks
return _make_vit_b16_backbone(
model,
features=[96, 192, 384, 768],
hooks=hooks,
use_readout=use_readout,
start_index=2,
)
def _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=[0, 1, 8, 11],
vit_features=768,
use_vit_only=False,
use_readout="ignore",
start_index=1,
):
pretrained = nn.Module()
pretrained.model = model
if use_vit_only == True:
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
else:
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
get_activation("1")
)
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
get_activation("2")
)
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
pretrained.activations = activations
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
if use_vit_only == True:
pretrained.act_postprocess1 = nn.Sequential(
readout_oper[0],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[0],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[0],
out_channels=features[0],
kernel_size=4,
stride=4,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
pretrained.act_postprocess2 = nn.Sequential(
readout_oper[1],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[1],
kernel_size=1,
stride=1,
padding=0,
),
nn.ConvTranspose2d(
in_channels=features[1],
out_channels=features[1],
kernel_size=2,
stride=2,
padding=0,
bias=True,
dilation=1,
groups=1,
),
)
else:
pretrained.act_postprocess1 = nn.Sequential(
nn.Identity(), nn.Identity(), nn.Identity()
)
pretrained.act_postprocess2 = nn.Sequential(
nn.Identity(), nn.Identity(), nn.Identity()
)
pretrained.act_postprocess3 = nn.Sequential(
readout_oper[2],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[2],
kernel_size=1,
stride=1,
padding=0,
),
)
pretrained.act_postprocess4 = nn.Sequential(
readout_oper[3],
Transpose(1, 2),
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
nn.Conv2d(
in_channels=vit_features,
out_channels=features[3],
kernel_size=1,
stride=1,
padding=0,
),
nn.Conv2d(
in_channels=features[3],
out_channels=features[3],
kernel_size=3,
stride=2,
padding=1,
),
)
pretrained.model.start_index = start_index
pretrained.model.patch_size = [16, 16]
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
# We inject this function into the VisionTransformer instances so that
# we can use it with interpolated position embeddings without modifying the library source.
pretrained.model._resize_pos_embed = types.MethodType(
_resize_pos_embed, pretrained.model
)
return pretrained
def _make_pretrained_vitb_rn50_384(
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
hooks = [0, 1, 8, 11] if hooks == None else hooks
return _make_vit_b_rn50_backbone(
model,
features=[256, 512, 768, 768],
size=[384, 384],
hooks=hooks,
use_vit_only=use_vit_only,
use_readout=use_readout,
)